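# InferSent demo notebook

Loads a trained InferSent model and predicts whether a hypothesis sentence is entailed by, contradicts, or is neutral with respect to a premise sentence.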

## Imports

In [1]:
from pathlib import Path

from nltk import word_tokenize
import torch
import torch.nn as nn

from data import load_data
from encoders import LSTM, Baseline, BiLSTM, MaxBiLSTM
from model import InferSent
from utils import load_model_state

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Load the data

In [2]:
# Only the text field (vocabulary + pretrained word vectors) is needed for the
# demo; the other values returned by load_data() are discarded.
_, _, _, text_field, _ = load_data()

# freeze=True sets embedding.weight.requires_grad = False, so the pretrained
# vectors are never updated. Note that assigning `embedding.requires_grad` on
# the Module itself would freeze nothing: the flag lives on the weight tensor.
embedding = nn.Embedding.from_pretrained(text_field.vocab.vectors, freeze=True)

## Initialize the encoder and model

In [3]:
# Choose the sentence encoder by class name. The next cell looks up the trained
# weights for this encoder by the same class name.
encoder_type = 'Baseline'  # one of: 'Baseline', 'LSTM', 'BiLSTM', 'MaxBiLSTM'
ENCODER_CLASSES = {
    'Baseline': Baseline,
    'LSTM': LSTM,
    'BiLSTM': BiLSTM,
    'MaxBiLSTM': MaxBiLSTM,
}
encoder = ENCODER_CLASSES[encoder_type]()

HIDDEN_DIM = 512  # size of the classifier's hidden layer
N_CLASSES = 3     # entailment / contradiction / neutral

model = InferSent(
    # The classifier input is 4x the encoder output size; presumably the
    # InferSent feature vector [u, v, |u - v|, u * v] -- confirm in model.py.
    input_dim=4 * encoder.output_dim,
    hidden_dim=HIDDEN_DIM,
    output_dim=N_CLASSES,
    embedding=embedding,
    encoder=encoder,
)
model.to(DEVICE)

InferSent(
  (embedding): Embedding(37241, 300)
  (encoder): Baseline()
  (model): Sequential(
    (0): Linear(in_features=1200, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=3, bias=True)
  )
)

## Load the best model of this type

In [4]:
# Trained weights are stored as output/models/InferSent_<EncoderClass>_model.pt,
# relative to the parent of the notebook's working directory.
encoder_name = type(encoder).__name__
models_path = Path.cwd().parent / 'output' / 'models'
model_file = models_path / f'InferSent_{encoder_name}_model.pt'

# The original referenced an undefined `encoder_type` here, which turned the
# intended ValueError into a NameError; use the encoder's class name instead.
if not model_file.is_file():
    raise ValueError(f'No models with the {encoder_name} encoder exist in the models directory!')
load_model_state(model_file, model)

Loading the model state... Done!


## Predict the inference label

In [5]:
premise = 'Two chickens are eating corn outside'
hypothesis = 'A couple of birds are having a bite'

# NOTE(review): assumes this index order matches the label vocabulary used in
# training -- confirm against the label field returned by load_data().
LABELS = ('Entailment', 'Contradiction', 'Neutral')


def to_model_input(sentence):
    """Tokenize `sentence` into the (token_ids, length) pair the model consumes.

    Returns a LongTensor of vocabulary indices shaped (seq_len, 1) -- a batch of
    one sentence along dim 1 -- and a 0-d tensor holding seq_len, both on DEVICE.
    Raises ValueError if the sentence produces no tokens.
    """
    token_ids = [text_field.vocab.stoi[token] for token in word_tokenize(sentence)]
    if not token_ids:
        # torch.tensor([]) would silently be a float tensor the embedding rejects.
        raise ValueError(f'Sentence produced no tokens: {sentence!r}')
    ids = torch.tensor(token_ids, dtype=torch.long, device=DEVICE)
    return ids.unsqueeze(1), torch.tensor(len(token_ids), device=DEVICE)


# Inference only: eval() disables training-time behaviour such as dropout, and
# no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    # Call the module (not .forward) so nn.Module hooks are honoured.
    y_pred = model(to_model_input(premise), to_model_input(hypothesis))

# Determine the type of inference from the highest-scoring class.
predicted_class = y_pred.argmax().item()
if not 0 <= predicted_class < len(LABELS):
    raise ValueError('Invalid class!')
print(LABELS[predicted_class])

Contradiction
