In [None]:
import json
from lstm_baseline import LSTMBaseline
import torch

In [None]:
# Load the MultiNLI training file into (sentence1 tokens, sentence2 tokens, gold label) tuples.
MNLI_TRAIN_PATH = 'multinli_1.0_train.jsonl'
sentence_data = []
# Decode explicitly as UTF-8 rather than relying on the platform's default encoding,
# since the corpus text may contain non-ASCII characters.
with open(MNLI_TRAIN_PATH, 'r', encoding='utf-8') as jsonl:
    for line in jsonl:
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        json_line = json.loads(line)
        # '-' is presumably MultiNLI's "no consensus" label; it has no entry in tag_to_ix,
        # so keeping it would raise KeyError during evaluation/training.
        if json_line['gold_label'] == '-':
            continue
        # str.split() with no argument already ignores leading/trailing whitespace.
        sentence_data.append((json_line['sentence1'].split(), json_line['sentence2'].split(),
                              json_line['gold_label']))

In [None]:
# Deterministic 50/50 split: first half of the loaded examples for training, second half for evaluation.
# NOTE(review): no shuffling happens here, so if the file is ordered (e.g. by genre) the two halves
# may not share a distribution — confirm this is intended.
midpoint = len(sentence_data) // 2
training_data, test_data = sentence_data[:midpoint], sentence_data[midpoint:]

In [None]:
# Assign each distinct token (across both sentences of every example, train and test alike)
# a consecutive integer id in first-seen order.
word_to_ix = {}
for first_tokens, second_tokens, _label in sentence_data:
    for token in first_tokens + second_tokens:
        # len() is evaluated before insertion, so a new token receives the next free id.
        word_to_ix.setdefault(token, len(word_to_ix))
# Fixed label-to-index mapping for the three NLI classes.
tag_to_ix = {"entailment": 0, "neutral": 1, "contradiction": 2}

In [None]:
def test(model):
    classes = ["entailment", "neutral", "contradiction"]
    class_tp = list(0 for _ in range(3))
    class_fp = list(0 for _ in range(3))
    class_fn = list(0 for _ in range(3))
    for sentence_hypothesis, sentence_premise, tag in test_data:
        sentence_h = model.prepare_sequence(sentence_hypothesis, word_to_ix)
        sentence_p = model.prepare_sequence(sentence_premise, word_to_ix)
        tag = tag_to_ix[tag]
        output = model(sentence_h, sentence_p).data
        predicted = int((output == torch.max(output)).nonzero()[0])
        if predicted == tag:
            class_tp[tag] += 1
        else:
            class_fn[tag] += 1
            class_fp[predicted] += 1
    for i in range(3):
        prec = class_tp[i] / (class_tp[i] + class_fp[i])
        rec = class_tp[i] / (class_tp[i] + class_fn[i])
        print('F1 score of {0} : {1}, precision: {2}, recall: {3}'.format(classes[i], 2 / ((1/prec) + (1/rec)), prec, rec))

In [None]:
# Seed torch's global RNG so that (assuming LSTMBaseline draws its initial weights from it)
# the initial model and the "before training" scores are reproducible on Restart & Run All.
torch.manual_seed(0)
# NOTE(review): 64 and 32 are presumably the embedding and hidden sizes — confirm against
# LSTMBaseline's constructor signature.
lstm = LSTMBaseline(len(word_to_ix), 64, 32, len(tag_to_ix))

In [None]:
# Baseline metrics of the freshly constructed model, for comparison with the post-training run.
print("Scores before training")
test(model=lstm)

In [None]:
# NOTE(review): 100 is presumably the number of training epochs — confirm against
# LSTMBaseline.back_propagation.
N_EPOCHS = 100
lstm.back_propagation(N_EPOCHS, training_data, word_to_ix, tag_to_ix)

In [None]:
# Re-evaluate the same held-out split after training to compare against the baseline scores.
print("Scores after training")
test(model=lstm)