In [73]:
import torch
import os.path
from spacy.lang.en import English
import torch.nn.functional as F
from src.models.nliclassifier import NLIClassifier
from src.dataset.dataloaders import get_embeddings_for_data

In [74]:
# Vocabulary + embedding vectors for the model.
# NOTE: `processed_data_dir` must be the parent directory of the embedding
# vocab+vector file that was used when training the checkpoint loaded below.
processed_data_dir = os.path.join("..", "snel_data", "processed")
emb_vocab, emb_vecs = get_embeddings_for_data(dataset_path=processed_data_dir)

In [75]:
# Location of the trained checkpoint, relative to this notebook's directory.
checkpoint_path = os.path.join(
    "..",
    "checkpoint",
    "checkpoint-snel",
    "real_blstmpme_train",
    "epoch=14-step=32190-val_acc=0.86.ckpt",
)

# Load the model along with the embeddings loaded in the previous cell.
# NOTE(review): strict=False tolerates state-dict key mismatches, which can
# silently leave weights uninitialised — confirm this is intended.
model = NLIClassifier.load_from_checkpoint(
    checkpoint_path,
    strict=False,
    embedding_mat=emb_vecs,
)
# Run on CPU in evaluation mode (disables dropout etc. for inference).
model = model.cpu().eval()

In [76]:
# Tokenizer from a blank spaCy English pipeline.
tokenizer = English().tokenizer


def process_input(sent_1, sent_2):
    """Convert a (premise, hypothesis) pair into the model's input tensors.

    Each sentence is lower-cased, split with the module-level spaCy
    ``tokenizer`` and mapped to vocabulary indices with ``emb_vocab``
    (presumably returning a list of ints per token list — confirm).

    Args:
        sent_1: Premise sentence as a raw string.
        sent_2: Hypothesis sentence as a raw string.

    Returns:
        A 4-element list ``[ids_1, ids_2, len_1, len_2]``: each ``ids_*`` is a
        ``torch.long`` tensor of shape ``(1, n_tokens)`` and each ``len_*`` is an
        ``int64`` tensor of shape ``(1,)`` holding that sentence's token count.
    """
    index_tensors = []
    length_tensors = []
    for sentence in (sent_1, sent_2):
        words = [tok.text for tok in tokenizer(sentence.lower())]
        # Outer list adds a batch dimension of size 1.
        index_tensors.append(torch.tensor([emb_vocab(words)], dtype=torch.long))
        length_tensors.append(torch.tensor([len(words)], dtype=torch.int64))
    return index_tensors + length_tensors

In [88]:
# Example (premise, hypothesis) pair to classify.
sent_1, sent_2 = (
    "An apple and an orange on a table",
    "There are fruits present on the table",
)
raw_inputs = process_input(sent_1, sent_2)

In [89]:
# Inspect the model inputs: two token-index tensors followed by two length tensors.
raw_inputs

[tensor([[ 1321,  1582,  1351,  1321, 20418, 20303,   419, 29776]]),
 tensor([[30381,  1713, 12123, 22922, 20303, 30330, 29776]]),
 tensor([8]),
 tensor([7])]

In [90]:
# Class index -> human-readable NLI label.
# NOTE(review): assumes training used this index order
# (0=entailment, 1=neutral, 2=contradiction) — confirm against the dataloader.
label_map = dict(enumerate(("entailment", "neutral", "contradiction")))

In [91]:
# Classify the pair. Gradients are not needed for inference, so autograd is
# disabled to avoid building (and keeping alive) a computation graph.
with torch.no_grad():
    out = model(*raw_inputs)
    # Class probabilities; kept so the prediction's confidence can be inspected.
    probs = F.softmax(out, dim=-1)
    # Most probable class over the same axis the softmax used. Assumes a single
    # sentence pair (out is (1, n_classes) — TODO confirm); `.item()` raises if
    # more than one element remains.
    label = torch.argmax(probs, dim=-1).item()

In [92]:
# Human-readable summary of the example pair and the model's prediction.
summary = f"""
Premise: "{sent_1}"
Hypothesis: "{sent_2}"
Model judgement: "{label_map[label]}"
"""
print(summary)


Premise: "An apple and an orange on a table"
Hypothesis: "There are fruits present on the table"
Model judgement: "entailment"



In [82]:
# TODO: Do not forget to replace the embeddings with those used to train the models on Snellius.