# Analysis of InferSent Model

In [1]:
import torch
from train import NLINet
from data import SNLIData
from utils import download_file_from_google_drive

## Download pretrained model

Use the following file IDs to download the corresponding models.


1. MeanEmbedding: `1q4ZRin0tKohQ504fi5HVkjDiolLZjuxg`
2. LSTM: `1lwClDt1cNaOtOo5h-bTx-rWr7ePeIyIO`
3. BiLSTM: `1zPZzm1EECkLdcbQ_SShhYPOrBXNu_zvz`
4. BiLSTM-maxpool: `12BzrDODCYjMZLhld1SFcyckwAa4Vj4fL`

Let's download the BiLSTM-maxpool checkpoint using its file ID from the list above.

In [None]:
# Local file name for the checkpoint; later cells refer to it via this variable.
checkpoint_path = 'bilstm-maxpool.ckpt'

# Fetch the BiLSTM-maxpool checkpoint (file ID 4 in the list above) into that file.
download_file_from_google_drive('12BzrDODCYjMZLhld1SFcyckwAa4Vj4fL', checkpoint_path)

## Evaluate on SNLI test data

In [3]:
from eval import snli, senteval, process_senteval_result

# Evaluate the downloaded checkpoint on SNLI.
snli(checkpoint_path)

# Settings for the SentEval run, which is currently disabled (see the
# commented-out call underneath).
params_senteval = dict(
    task_path='./SentEval/data/',
    usepytorch=True,
    kfold=5,
    classifier=dict(
        nhid=0,
        optim='rmsprop',
        batch_size=128,
        tenacity=3,
        epoch_size=2,
    ),
)

# result_dict = senteval(checkpoint_path,
#           params_senteval=params_senteval)
# process_senteval_result(result_dict)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

Final Test Accuracy: 0.8598330618892508
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc_epoch': 0.8598330618892508,
 'test_acc_step': 0.8598330616950989,
 'test_loss': 0.37223106622695923}
--------------------------------------------------------------------------------
[{'test_acc_step': 0.8598330616950989, 'test_loss': 0.37223106622695923, 'test_acc_epoch': 0.8598330618892508}]


## Making predictions

In [3]:
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Batch size handed to the SNLI data wrapper. The analysis section only looks at
# the first test batch, so this presumably also sets how many examples that
# analysis covers -- TODO confirm against SNLIData.get_iters().
BATCH_SIZE = 1000
data = SNLIData(batch_size=BATCH_SIZE)

In [4]:
# Load the checkpoint fetched in the download step rather than a machine-specific
# training-log path, so the notebook also runs on a fresh checkout.
# NOTE(review): assumes the downloaded file is the same BiLSTM-maxpool checkpoint
# as './logs/infersent-logs/BiLSTM-maxpool/epoch=2-step=12875.ckpt' -- confirm.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model = NLINet.load_from_checkpoint(checkpoint_path, map_location=device).to(device)
# Put the network in evaluation mode; eval() returns the module, so the cell
# displays its architecture.
model.eval()

NLINet(
  (model): InferSent(
    (encoder): BiLSTMEncoder(
      (embedding): Embedding(33893, 300)
      (linear): Linear(in_features=300, out_features=1028, bias=True)
      (relu): ReLU()
      (projection): Sequential(
        (0): Embedding(33893, 300)
        (1): Linear(in_features=300, out_features=1028, bias=True)
        (2): ReLU()
      )
      (lstm): LSTM(1028, 1028, batch_first=True, bidirectional=True)
    )
    (classifier): Classifier(
      (lin1): Linear(in_features=8224, out_features=512, bias=True)
      (lin2): Linear(in_features=512, out_features=512, bias=True)
      (lin3): Linear(in_features=512, out_features=3, bias=True)
      (relu): ReLU()
      (net): Sequential(
        (0): Linear(in_features=8224, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU()
        (4): Linear(in_features=512, out_features=3, bias=True)
      )
    )
  )
  (criterion): CrossEntropyLoss()
)

In [7]:
import spacy

# Vocabulary obtained from the SNLI data wrapper; the helpers below index it by token string.
vocab = data.get_vocab()

def tokenize_spacy(text):
    """Split ``text`` into a list of token strings using spaCy's English tokenizer.

    The ``en_core_web_sm`` pipeline is loaded on the first call only and reused
    afterwards; reloading it for every sentence is slow and does not change the
    result.
    """
    # Cache the loaded pipeline on the function object so repeated calls are cheap.
    spacy_en = getattr(tokenize_spacy, "_spacy_en", None)
    if spacy_en is None:
        spacy_en = spacy.load('en_core_web_sm')
        tokenize_spacy._spacy_en = spacy_en
    return [tok.text for tok in spacy_en.tokenizer(text)]

def get_encoded_text(text, tokenize=True, max_len=12, verbose=True):
    """Convert a sentence into a tensor of vocabulary indices.

    Args:
        text: Sentence string, or an already tokenized sequence of tokens when
            ``tokenize`` is False.
        tokenize: If True, ``text`` is first split with ``tokenize_spacy``.
        max_len: Sequences shorter than this are right-padded with the
            ``<pad>`` index up to this length; longer ones are left as is.
        verbose: If True, print the index list and the tokens it maps back to.

    Returns:
        A ``torch.LongTensor`` holding one row of indices, moved to ``device``.
    """
    tokens = tokenize_spacy(text) if tokenize else text

    # Look up every token in the vocabulary.
    sent_idxs = [vocab[tok] for tok in tokens]

    # Right-pad up to max_len; a non-positive gap adds nothing.
    gap = max_len - len(sent_idxs)
    sent_idxs.extend(vocab["<pad>"] for _ in range(gap))

    # Optional debugging output: raw indices, then the tokens they stand for.
    if verbose:
        print(sent_idxs)
        print([data.text.vocab.itos[i] for i in sent_idxs])

    # Wrap in an outer list so the result has a leading dimension of size 1.
    return torch.LongTensor([sent_idxs]).to(device)

def get_label(idx):
    """Return the NLI label name for class index ``idx`` (0, 1 or 2).

    Raises:
        KeyError: if ``idx`` is not one of the known class indices.
    """
    names = ('Entailment', 'Contradiction', 'Neutral')
    return dict(enumerate(names))[idx]

Now let's provide our own premise and hypothesis and see what the model predicts.

In [193]:
# Example pair marked as entailment.
premise, hypothesis = (
    'A soccer game with multiple males playing.',
    'Some men are playing a sport.',
)

In [8]:
# Example pair marked as contradiction.
premise, hypothesis = (
    'A black race car starts up in front of a crowd of people.',
    'A man is driving down a lonely road.',
)

In [211]:
# Example pair marked as neutral.
premise, hypothesis = (
    'An older and younger man smiling.',
    'Two men are smiling and laughing at the cats playing on the floor.',
)

In [204]:
# Example pair marked as entailment (same sentences as the first example cell).
premise, hypothesis = (
    'A soccer game with multiple males playing.',
    'Some men are playing a sport.',
)

In [41]:
# Free-form pair with no stated gold label.
premise, hypothesis = (
    'A man is walking a dog',
    'No cat is outside',
)

In [42]:
# Encode both sentences as index tensors (debug printing switched off).
p = get_encoded_text(premise, verbose = False)
h = get_encoded_text(hypothesis, verbose = False)

# Build a one-example batch. The second slot used to be filled with the ambient
# `_` variable, whose value depends on whatever ran before this cell; use an
# explicit None instead.
# NOTE(review): presumably the model only reads the (premise, hypothesis) part
# and ignores the label slot -- confirm against InferSent.forward.
test_data = ((p, h), None)

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    out = model.model(test_data)

# Highest-scoring class index, mapped to its label name.
pred = get_label(torch.argmax(torch.nn.functional.log_softmax(out, 1), 1).item())
print('Prediction:', pred)

Prediction: Contradiction


## Analysis of the results

In [11]:
# get_iters() is unpacked into three objects; only the middle one is used here.
# Named throwaways are used instead of `_` so that IPython's last-output
# variable is not overwritten by this cell.
_unused_first, test_loader, _unused_last = data.get_iters()
# Materialise the batches so that individual ones can be indexed below.
test_data = list(test_loader)

In [15]:

# Forward pass over the first test batch. This is evaluation only, so gradient
# tracking is disabled to avoid keeping an autograd graph for the whole batch.
with torch.no_grad():
    out = model.model(test_data[0])

# Predicted class index for every example in the batch.
pred = torch.argmax(torch.nn.functional.log_softmax(out, 1), 1)


In [21]:
# Gold labels of the first test batch as a NumPy array. .cpu() is a no-op for
# CPU tensors and makes the conversion work when the batch lives on a GPU.
y_true = test_data[0].label.cpu().numpy()

In [22]:
# Predicted labels as a NumPy array. .cpu() is a no-op for CPU tensors and makes
# the conversion work when the model ran on a GPU.
y_pred = pred.cpu().numpy()

In [38]:
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix, classification_report

# Pin the class order explicitly so that entry i of both results always refers
# to class id i (0, 1, 2 as mapped by get_label), even for a batch in which one
# of the classes never occurs.
class_ids = [0, 1, 2]
report = classification_report(y_true, y_pred, labels=class_ids, output_dict=True)
cf_mat = multilabel_confusion_matrix(y_true, y_pred, labels=class_ids)

In [39]:
report  # class keys '0', '1', '2' correspond to Entailment, Contradiction, Neutral

{'0': {'precision': 0.8956743002544529,
  'recall': 0.9142857142857143,
  'f1-score': 0.9048843187660669,
  'support': 385},
 '1': {'precision': 0.9113924050632911,
  'recall': 0.9113924050632911,
  'f1-score': 0.9113924050632911,
  'support': 316},
 '2': {'precision': 0.8213058419243986,
  'recall': 0.7993311036789298,
  'f1-score': 0.8101694915254237,
  'support': 299},
 'accuracy': 0.879,
 'macro avg': {'precision': 0.8761241824140477,
  'recall': 0.875003074342645,
  'f1-score': 0.8754820717849272,
  'support': 1000},
 'weighted avg': {'precision': 0.8784050523333595,
  'recall': 0.879,
  'f1-score': 0.8786211406910375,
  'support': 1000}}

In [40]:
# One 2x2 matrix per class, each laid out as:
# TN, FP
# FN, TP

cf_mat

array([[[574,  41],
        [ 33, 352]],

       [[656,  28],
        [ 28, 288]],

       [[649,  52],
        [ 60, 239]]])