# Demo notebook
This notebook serves as a demo of my code for the InferSent practical assignment, made in the context of the SMNLS course taught at the UvA in the Spring 2019 semester. All code can also be found in [this](https://github.com/sgraaf/practical_1_learning_sentence_representations) GitHub repository.

Besides the required `train.py`, `eval.py` and `infer.py`, the `src` directory contains a few additional files:
```
data.py       Contains various functions for loading and preprocessing the relevant data for training
encoders.py   Contains the four implemented encoders (Baseline, LSTM, BiLSTM and MaxBiLSTM)
model.py      Contains the InferSent model (which uses one of the encoders in encoders.py)
utils.py      Contains various utility functions used throughout training, evaluating, testing, etc.
```

Below is a demo of how the model predicts the inference relation (entailment, contradiction or neutral) between a premise sentence and a hypothesis sentence.

## Imports
Below, you'll find the imports required to run this demo. As you'll see, many of the files mentioned previously are imported and used here.

```python
from pathlib import Path

from nltk import word_tokenize
import torch
import torch.nn as nn

from data import load_data
from encoders import Baseline, BiLSTM, LSTM, MaxBiLSTM
from model import InferSent
from utils import load_model_state

# run on the GPU when one is available
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

## Load the data
This cell loads the relevant data. Since we're not training, we don't need the `train`, `dev` and `test` splits of the SNLI dataset; we only need the `text_field`, which holds our vocabulary and its pretrained word vectors.

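```python
_, _, _, text_field, _ = load_data()

# from_pretrained freezes the embedding weights by default (freeze=True);
# setting `requires_grad` on the module itself would not affect its weights
embedding = nn.Embedding.from_pretrained(text_field.vocab.vectors, freeze=True)
```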
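
A common pitfall when freezing pretrained embeddings: assigning `requires_grad` on an `nn.Embedding` module only creates a plain Python attribute on the module, whereas gradient tracking is controlled by its `weight` parameter. `nn.Embedding.from_pretrained` already freezes that parameter by default (`freeze=True`). The minimal sketch below uses a small random matrix as a stand-in for `text_field.vocab.vectors` (an assumption for illustration) to show the difference.

```python
import torch
import torch.nn as nn

# Stand-in for text_field.vocab.vectors: 5 "words" with 4-dimensional vectors
vectors = torch.randn(5, 4)

# from_pretrained freezes the weight matrix by default (freeze=True)
frozen = nn.Embedding.from_pretrained(vectors)
print(frozen.weight.requires_grad)  # False

# Passing freeze=False makes the vectors trainable
trainable = nn.Embedding.from_pretrained(vectors, freeze=False)
print(trainable.weight.requires_grad)  # True

# Assigning the attribute on the module does NOT touch the parameter
trainable.requires_grad = False
print(trainable.weight.requires_grad)  # still True

# The parameter-level flag is what actually freezes the vectors
trainable.weight.requires_grad_(False)
print(trainable.weight.requires_grad)  # False
```

Looking up an index still returns the corresponding pretrained row, e.g. `frozen(torch.tensor([2]))[0]` equals `vectors[2]`.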

## Initialize the encoder and model
Below, we initialize the encoder of choice and the model. As mentioned above, the encoder is passed to the model as an argument, so that during training, evaluation, etc., only the model object is needed.

```python
# choose one encoder; the alternatives are commented out
encoder = Baseline()
# encoder = LSTM()
# encoder = BiLSTM()
# encoder = MaxBiLSTM()

model = InferSent(
    input_dim=4*encoder.output_dim,
    hidden_dim=512,
    output_dim=3,
    embedding=embedding,
    encoder=encoder
)
model.to(DEVICE)
```

```
InferSent(
  (embedding): Embedding(37241, 300)
  (encoder): Baseline()
  (model): Sequential(
    (0): Linear(in_features=1200, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=3, bias=True)
  )
)
```
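
The `4*encoder.output_dim` input size of the classifier matches the standard InferSent feature vector, which concatenates the premise encoding u and the hypothesis encoding v as [u; v; |u - v|; u * v]. The printed model is consistent with this: 300-dimensional embeddings give `in_features=1200`. Below is a minimal sketch of that step, assuming (not confirmed by the code shown here) that the `Baseline` encoder averages word embeddings and that `model.py` uses this feature order. The functions `mean_encode` and `combine` are hypothetical and not taken from `encoders.py` or `model.py`.

```python
import torch


def mean_encode(embedded: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Average word vectors over the time axis (hypothetical Baseline sketch).

    embedded: (seq_len, batch, embed_dim); padding positions assumed to be zero vectors
    lengths:  (batch,) true sentence lengths
    """
    summed = embedded.sum(dim=0)  # (batch, embed_dim)
    return summed / lengths.unsqueeze(1).to(summed.dtype)


def combine(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Build the InferSent feature vector [u; v; |u - v|; u * v] along dim 1."""
    return torch.cat([u, v, torch.abs(u - v), u * v], dim=1)


# Toy batch of size 1: a 3-token premise and a 2-token hypothesis, 300-dim vectors
premise = torch.randn(3, 1, 300)
hypothesis = torch.randn(2, 1, 300)

u = mean_encode(premise, torch.tensor([3]))
v = mean_encode(hypothesis, torch.tensor([2]))
features = combine(u, v)
print(features.shape)  # torch.Size([1, 1200])
```

Because every encoder is reduced to a fixed-size vector before `combine`, swapping `Baseline` for an LSTM-based encoder only changes `encoder.output_dim`, and therefore the classifier's input size.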

## Load the best model of this type
Below, we load the saved checkpoint of the best model for the chosen encoder type. This assumes training has completed and saved a checkpoint to `output/models`; otherwise, the cell raises an error.

```python
models_path = Path.cwd().parent / 'output' / 'models'
encoder_name = encoder.__class__.__name__

# checkpoint expected to have been saved during training for this encoder type
model_file = models_path / f'InferSent_{encoder_name}_model.pt'

if model_file.exists():
    load_model_state(model_file, model)
else:
    raise FileNotFoundError(f'No models with the {encoder_name} encoder exist in the models directory!')
```

```
Loading the model state... Done!
```

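The implementation of `utils.load_model_state` is not shown in this notebook. A common PyTorch pattern for this step is `torch.load` followed by `Module.load_state_dict`; the hypothetical `load_state` below is a minimal sketch of that pattern, assuming the checkpoint holds a plain `state_dict` saved with `torch.save(model.state_dict(), path)`. The real helper may store additional entries (for example optimizer state), so treat this as an illustration only.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_state(checkpoint_path: Path, model: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical stand-in for utils.load_model_state.

    Assumes the file holds a plain state_dict.
    """
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model


# Round trip with a small stand-in model instead of InferSent
source = nn.Linear(4, 3)
target = nn.Linear(4, 3)
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'InferSent_Baseline_model.pt'
    torch.save(source.state_dict(), path)
    load_state(path, target, torch.device('cpu'))

print(torch.equal(source.weight, target.weight))  # True
```

`load_state_dict` is strict by default, so a checkpoint trained with a different encoder (and thus different parameter names or shapes) fails loudly instead of loading silently.
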
## Predict inference
Below, we demonstrate how the model predicts the type of inference between a premise sentence and a hypothesis sentence, after both have been preprocessed to fit the model's input format.

```python
premise = 'Two chickens are eating corn outside'
hypothesis = 'A couple of birds are having a bite'

# tokenize the sentences
premise = word_tokenize(premise)
hypothesis = word_tokenize(hypothesis)

# map tokens to vocabulary indices
premise = torch.tensor([text_field.vocab.stoi[token] for token in premise]).to(DEVICE)
hypothesis = torch.tensor([text_field.vocab.stoi[token] for token in hypothesis]).to(DEVICE)

# predict entailment: each sentence is passed as a (seq_len, batch=1) tensor plus its length
model.eval()
with torch.no_grad():
    y_pred = model(
        (premise.unsqueeze(1), torch.tensor(len(premise)).to(DEVICE)),
        (hypothesis.unsqueeze(1), torch.tensor(len(hypothesis)).to(DEVICE))
    )

# determine the type of inference (index order: entailment, contradiction, neutral)
labels = ['Entailment', 'Contradiction', 'Neutral']
print(labels[y_pred.argmax().item()])
```

```
Contradiction
```
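
The preprocessing in the prediction step maps each token to its vocabulary index and reshapes each sentence into a `(seq_len, 1)` tensor, i.e. a batch of one. A minimal sketch with a toy vocabulary follows. The toy `stoi` dictionary, the `<unk>` fallback index of 0 and the helper `to_batch` are assumptions for illustration, and a whitespace split stands in for `nltk.word_tokenize`. The sketch also checks that `expand(1, -1).transpose(0, 1)` and `unsqueeze(1)` produce the same tensor, and decodes an argmax with the label order used in the notebook.

```python
from collections import defaultdict

import torch

# Toy vocabulary (hypothetical); unknown tokens fall back to index 0 ('<unk>')
stoi = defaultdict(int, {'<unk>': 0, '<pad>': 1, 'Two': 2, 'chickens': 3, 'are': 4, 'eating': 5})

# Label order taken from the notebook's prediction step
LABELS = ['Entailment', 'Contradiction', 'Neutral']


def to_batch(sentence: str):
    """Turn a sentence into a (seq_len, 1) index tensor plus a 0-d length tensor."""
    tokens = sentence.split()  # whitespace split stands in for nltk.word_tokenize
    indices = torch.tensor([stoi[token] for token in tokens])
    return indices.unsqueeze(1), torch.tensor(len(tokens))


batch, length = to_batch('Two chickens are eating rice')
print(batch.squeeze(1).tolist())  # [2, 3, 4, 5, 0]
print(batch.shape, length.item())  # torch.Size([5, 1]) 5

# expand(1, -1).transpose(0, 1) yields the same (seq_len, 1) tensor as unsqueeze(1)
flat = batch.squeeze(1)
print(torch.equal(flat.expand(1, -1).transpose(0, 1), flat.unsqueeze(1)))  # True

# Decode made-up logits of shape (1, 3) with the same label order
logits = torch.tensor([[0.1, 2.0, -1.0]])
print(LABELS[logits.argmax().item()])  # Contradiction
```

Out-of-vocabulary words such as "rice" silently map to the unknown index, so they contribute only the `<unk>` vector to the sentence encoding.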
