In [424]:
import json
import urllib
import requests
import os
from sklearn.metrics import classification_report
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
import en_core_web_md

In [40]:
nlp = en_core_web_md.load()

In [18]:
with open('../../../../../corpora/snli_1.0/snli_1.0_train.jsonl') as f:
    train_data = [json.loads(line) for line in f.readlines()]
    
with open('../../../../../corpora/snli_1.0/snli_1.0_dev.jsonl') as f:
    dev_data = [json.loads(line) for line in f.readlines()]
    
with open('../../../../../corpora/snli_1.0/snli_1.0_test.jsonl') as f:
    test_data = [json.loads(line) for line in f.readlines()]

In [422]:
def compose(*funcs):
    def inner(*arg):
        res = {}
        for f in funcs:
            res.update(f(*arg))
        return res
    return inner


def get_classifier():
    pipe = Pipeline([
        ('dict_vect', DictVectorizer()),
        ('lrc', LogisticRegression(random_state=42, multi_class='multinomial',
                                   max_iter=100, solver='sag', n_jobs=20))])

    return pipe


def get_intersection(ents1, ents2):
    setA = set(ents1)
    setB = set(ents2)
    universe = set(doc1) | set(doc2)

    return len(setA & setB)/(len(universe))


def get_tokens_similarity(toks1, toks2):
    setA = set(toks1)
    setB = set(toks2)
    universe = set(toks1) | set(toks2)
    
    sim = [x.similarity(y) for x in setA for y in setB if x.vector_norm and y.vector_norm]
    return len(sim)/(len(universe))


def get_ngrams(text):
    res = []
    for i in range(0, len(text), 3):
        if i > 0 and i + 3 <= len(text):
            res.append(text[i:i + 3])
        elif i > 0 and i + 3 > len(text):
            res.append(text[i:i + 3] + '</S>')
        else:
            res.append('<S>' + text[i:i + 3])
    return res


def feature_extractor_base(doc1, doc2):
    feats = {}
    feats['similarity'] = doc1.similarity(doc2)
    
    return feats

# It makes it a bit worse
# TODO: investigate and improve
def feature_extractor_inter_ner(doc1, doc2):
    feats = {}

    feats['ner-inter'] = get_intersection(
        [x.ent_type_ for x in doc1],
        [x.ent_type_ for x in doc2]
    )
    
    return feats


def feature_extractor_inter_word(doc1, doc2):
    feats = {}

    feats['w-inter'] = get_intersection(
            [x.lemma_ for x in doc1],
            [x.lemma_ for x in doc2]
        )
    
    return feats


def feature_extractor_inter_noun(doc1, doc2):
    feats = {}

    feats['nn-inter'] = get_intersection(
            [x.lemma_ for x in doc1 if x.pos_ == 'NOUN'],
            [x.lemma_ for x in doc2 if x.pos_ == 'NOUN']
        )
    
    return feats


def feature_extractor_inter_verb(doc1, doc2):
    feats = {}

    feats['v-inter'] = get_intersection(
            [x.lemma_ for x in doc1 if x.pos_ == 'VERB'],
            [x.lemma_ for x in doc2 if x.pos_ == 'VERB']
        )
    
    return feats


def feature_extractor_inter_num(doc1, doc2):
    feats = {}

    feats['nm-inter'] = get_intersection(
            [x.lemma_ for x in doc1 if x.pos_ == 'NUM'],
            [x.lemma_ for x in doc2 if x.pos_ == 'NUM']
        )
    
    return feats


def feature_extractor_verb_nn_sim(doc1, doc2):
    def get_by_pos(pos):
        t1 = [x for x in doc1 if x.pos_ == pos]
        t2 = [x for x in doc2 if x.pos_ == pos]
        return t1, t2
        
    feats = {}
    
    sent1_verbs, sent2_verbs = get_by_pos('VERB')
    sent1_nouns, sent2_nouns = get_by_pos('NOUN')
    
    if sent1_verbs and sent2_verbs:
        feats['v-similar'] = get_tokens_similarity(sent1_verbs, sent2_verbs)
    
    if sent1_nouns and sent2_nouns:
        feats['nn-similar'] = get_tokens_similarity(sent1_nouns, sent2_nouns)

    return feats


def feature_extractor_inter_ngrams(doc1, doc2):
    feats = {}
    
    n1 = get_ngrams(' '.join([x.lemma_ for x in doc1]))
    n2 = get_ngrams(' '.join([x.lemma_ for x in doc2]))

    feats['ngr-inter'] = get_intersection(n1, n2)
    
    return feats


# TODO: make smth with unknown gold label:
# a) filter those
# b) mark as neutral
def get_data(dataset, feature_extractor):
    features = [feature_extractor(nlp(x['sentence1']), nlp(x['sentence2'])) for x in dataset]
    labels = [x['gold_label'] for x in dataset]
    return features, labels


def print_result(train_data, dev_data, feature_extractor):
    X_train, y_train = get_data(train_data[:500], feature_extractor)
    X_dev, y_dev = get_data(dev_data[:500], feature_extractor)
    clf.fit(X_train, y_train)
    print(classification_report(y_dev, clf.predict(X_dev)))

In [29]:
clf = get_classifier()

### Baseline (just simply use sentence similarity from spacy)

In [425]:
print_result(train_data, dev_data, feature_extractor_base)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.47      0.47      0.47       169
   entailment       0.43      0.65      0.52       165
      neutral       0.33      0.18      0.23       156

     accuracy                           0.43       500
    macro avg       0.31      0.32      0.31       500
 weighted avg       0.41      0.43      0.40       500



  _warn_prf(average, modifier, msg_start, len(result))


### 1. With NER intersection (-)

In [427]:
feature_extractor = compose(feature_extractor_base, feature_extractor_inter_ner)
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.47      0.47      0.47       169
   entailment       0.43      0.65      0.51       165
      neutral       0.31      0.16      0.21       156

     accuracy                           0.42       500
    macro avg       0.30      0.32      0.30       500
 weighted avg       0.40      0.42      0.39       500



  _warn_prf(average, modifier, msg_start, len(result))


### 2. With word intersection without NER intersection (?)

In [428]:
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_word)
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.41      0.46      0.43       169
   entailment       0.45      0.56      0.50       165
      neutral       0.29      0.19      0.23       156

     accuracy                           0.40       500
    macro avg       0.29      0.30      0.29       500
 weighted avg       0.38      0.40      0.38       500



  _warn_prf(average, modifier, msg_start, len(result))


### 3. With word intersection with NER intersection (+)

In [429]:
# TODO: NER words, lemma etc
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_ner,
                            feature_extractor_inter_word)
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.40      0.49      0.44       169
   entailment       0.46      0.58      0.51       165
      neutral       0.30      0.17      0.22       156

     accuracy                           0.41       500
    macro avg       0.29      0.31      0.29       500
 weighted avg       0.38      0.41      0.39       500



  _warn_prf(average, modifier, msg_start, len(result))


### 4. With NOUN intersection (+)

In [430]:
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_ner,
                            feature_extractor_inter_word,
                            feature_extractor_inter_noun)
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.42      0.49      0.45       169
   entailment       0.45      0.50      0.47       165
      neutral       0.26      0.21      0.23       156

     accuracy                           0.39       500
    macro avg       0.28      0.30      0.29       500
 weighted avg       0.37      0.39      0.38       500



  _warn_prf(average, modifier, msg_start, len(result))


In [431]:
# VERB inter (-)
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_ner,
                            feature_extractor_inter_word,
                            feature_extractor_inter_noun,
                            feature_extractor_inter_verb)
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.47      0.63      0.54       169
   entailment       0.48      0.69      0.57       165
      neutral       0.30      0.07      0.11       156

     accuracy                           0.46       500
    macro avg       0.31      0.35      0.31       500
 weighted avg       0.41      0.46      0.41       500



  _warn_prf(average, modifier, msg_start, len(result))


In [432]:
# NUM inter (-)
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_ner,
                            feature_extractor_inter_word,
                            feature_extractor_inter_noun,
                            feature_extractor_inter_verb,
                            feature_extractor_inter_num
                           )
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.47      0.62      0.53       169
   entailment       0.48      0.69      0.57       165
      neutral       0.30      0.08      0.12       156

     accuracy                           0.46       500
    macro avg       0.31      0.35      0.31       500
 weighted avg       0.41      0.46      0.41       500



  _warn_prf(average, modifier, msg_start, len(result))


In [433]:
# VERB & NOUN tokens similarity (+)
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_ner,
                            feature_extractor_inter_word,
                            feature_extractor_inter_noun,
                            feature_extractor_inter_verb,
                            feature_extractor_inter_num,
                            feature_extractor_verb_nn_sim,
                           )
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.51      0.52      0.52       169
   entailment       0.50      0.70      0.58       165
      neutral       0.38      0.23      0.29       156

     accuracy                           0.48       500
    macro avg       0.35      0.36      0.35       500
 weighted avg       0.46      0.48      0.46       500



  _warn_prf(average, modifier, msg_start, len(result))


In [434]:
# ngrams
feature_extractor = compose(feature_extractor_base,
                            feature_extractor_inter_ner,
                            feature_extractor_inter_word,
                            feature_extractor_inter_noun,
                            feature_extractor_inter_verb,
                            feature_extractor_inter_num,
                            feature_extractor_verb_nn_sim,
                            feature_extractor_inter_ngrams
                           )
print_result(train_data, dev_data, feature_extractor)

               precision    recall  f1-score   support

            -       0.00      0.00      0.00        10
contradiction       0.52      0.52      0.52       169
   entailment       0.52      0.69      0.59       165
      neutral       0.38      0.27      0.31       156

     accuracy                           0.49       500
    macro avg       0.35      0.37      0.36       500
 weighted avg       0.46      0.49      0.47       500



  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
[x for x in train_data if x['gold_label'] == '-']