In [9]:
import spacy

In [13]:
nlp = spacy.load("en_core_web_lg")

## Load data

In [2]:
import json

In [3]:
def read_jsonnl(file):
    with open(file, "r") as f:
        data = [json.loads(line) for line in f.readlines()]
        return data

In [4]:
snli_path = "data/snli_1.0/snli_1.0"

In [5]:
train_data = read_jsonnl(f"{snli_path}/snli_1.0_train.jsonl")

In [7]:
dev_data = read_jsonnl(f"{snli_path}/snli_1.0_dev.jsonl")

In [6]:
test_data = read_jsonnl(f"{snli_path}/snli_1.0_test.jsonl")

In [8]:
train_data[0]

{'annotator_labels': ['neutral'],
 'captionID': '3416050480.jpg#4',
 'gold_label': 'neutral',
 'pairID': '3416050480.jpg#4r1n',
 'sentence1': 'A person on a horse jumps over a broken down airplane.',
 'sentence1_binary_parse': '( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )',
 'sentence1_parse': '(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))',
 'sentence2': 'A person is training his horse for a competition.',
 'sentence2_binary_parse': '( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )',
 'sentence2_parse': '(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))'}

In [57]:
item = None
for snli_item in dev_data:
    if snli_item['gold_label'] =='-':
        item = snli_item
        break

In [58]:
item

{'annotator_labels': ['entailment',
  'neutral',
  'entailment',
  'neutral',
  'contradiction'],
 'captionID': '3184031654.jpg#0',
 'gold_label': '-',
 'pairID': '3184031654.jpg#0r1e',
 'sentence1': 'The middle eastern woman wearing the pink headscarf is walking beside a woman in a purple headscarf.',
 'sentence1_binary_parse': '( ( ( The ( middle ( eastern woman ) ) ) ( wearing ( the ( pink headscarf ) ) ) ) ( ( is ( walking ( beside ( ( a woman ) ( in ( a ( purple headscarf ) ) ) ) ) ) ) . ) )',
 'sentence1_parse': '(ROOT (S (NP (NP (DT The) (JJ middle) (JJ eastern) (NN woman)) (VP (VBG wearing) (NP (DT the) (JJ pink) (NN headscarf)))) (VP (VBZ is) (VP (VBG walking) (PP (IN beside) (NP (NP (DT a) (NN woman)) (PP (IN in) (NP (DT a) (JJ purple) (NN headscarf))))))) (. .)))',
 'sentence2': 'Two women are walking together.',
 'sentence2_binary_parse': '( ( Two women ) ( ( are ( walking together ) ) . ) )',
 'sentence2_parse': '(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG wa

In [61]:
def convert_data_to_X_and_y(nlp, data, feature_extractor):
    X, y = [], []
    for i in range(0, len(data)):
        yi = data[i]['gold_label']
        if yi not in ['entailment', 'contradiction', 'neutral']:
            continue
        xi = feature_extractor(nlp, data[i])
        
        X.append(xi)
        y.append(yi)
        
    return X, y

## Baseline

In [51]:
from sklearn.pipeline import Pipeline
from sklearn import svm
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import classification_report

In [35]:
def get_jaccard_sim_by_lemma(sent_tokens1, sent_tokens2): 
    lemmas1 = set([token.lemma_ for token in sent_tokens1])
    lemmas2 = set([token.lemma_ for token in sent_tokens2])

    matched = lemmas1.intersection(lemmas2)
    return float(len(matched)) / (len(lemmas1) + len(lemmas2) - len(matched))

In [37]:
def get_jaccard_sim_by_verb(sent_tokens1, sent_tokens2): 
    verbs1 = set([token.lemma_ for token in sent_tokens1 if token.pos_ == "VERB" and not token.lemma == "be"])
    verbs2 = set([token.lemma_ for token in sent_tokens2 if token.pos_ == "VERB" and not token.lemma == "be"])

    matched = verbs1.intersection(verbs2)
    if len(verbs1) + len(verbs2) - len(matched) == 0:
        return None
        
    return float(len(matched)) / (len(verbs1) + len(verbs2) - len(matched))    

In [38]:
def find_all_verbs_in_sent(sentence, nlp):
    doc = nlp(sentence)
    return set([token.lemma_ for token in doc if token.pos_ == "VERB" and not token.lemma == "be"])

In [41]:
def exctract_initial_features(nlp, snli_item):
    features = {}
    text_sent = snli_item['sentence1']
    text_sent_tokens = nlp(text_sent)
#     print(text_sent)
    hypothesis_sent = snli_item['sentence2']
    hypothesis_sent_tokens = nlp(hypothesis_sent)
#     print(hypothesis_sent)
    
    features['text-hyp-sim'] = get_jaccard_sim_by_lemma(text_sent_tokens, hypothesis_sent_tokens)
    text_hyp_sim_verb = get_jaccard_sim_by_verb(text_sent_tokens, hypothesis_sent_tokens)
    if text_hyp_sim_verb:
        features['text-hyp-sim-verb'] = text_hyp_sim_verb

    features['text-len'] = len(text_sent)
    features['hyp-len'] = len(hypothesis_sent)
    
    
    return features

In [68]:
%time X_train, y_train = convert_data_to_X_and_y(nlp, train_data, exctract_initial_features)

CPU times: user 1h 40min 1s, sys: 12.3 s, total: 1h 40min 13s
Wall time: 1h 40min 19s


In [62]:
X_dev, y_dev = convert_data_to_X_and_y(nlp, dev_data, exctract_initial_features)

In [69]:
clf = Pipeline([
    ('vect', DictVectorizer()),
    ('svm', svm.SVC())
])

In [70]:
clf.fit(X_train, y_train)

Pipeline(memory=None,
         steps=[('vect',
                 DictVectorizer(dtype=<class 'numpy.float64'>, separator='=',
                                sort=True, sparse=True)),
                ('svm',
                 SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None,
                     coef0=0.0, decision_function_shape='ovr', degree=3,
                     gamma='scale', kernel='rbf', max_iter=-1,
                     probability=False, random_state=None, shrinking=True,
                     tol=0.001, verbose=False))],
         verbose=False)

In [71]:
y_dev_pred = clf.predict(X_dev)

In [72]:
print(classification_report(y_dev, y_dev_pred))

               precision    recall  f1-score   support

contradiction       0.42      0.54      0.47      3278
   entailment       0.57      0.49      0.53      3329
      neutral       0.47      0.41      0.44      3235

     accuracy                           0.48      9842
    macro avg       0.49      0.48      0.48      9842
 weighted avg       0.49      0.48      0.48      9842

