# Read data

In [1]:
import json
from snli import UnigramSNLIData
from config import IDENTITY_LABEL_EXT_FILE, PREMISE_KEY, HYPOTHESIS_KEY
from pmi import PMI

In [2]:
# Build the SNLI data wrapper once, then collect statistics for each sentence
# side in turn (premise first, hypothesis second) with bigram=True.
# NOTE(review): bigram=True presumably makes collect_stats count bigrams as
# well as unigrams -- confirm against snli.UnigramSNLIData.collect_stats.
data = UnigramSNLIData()
premise_stats, hyp_stats = (
    data.collect_stats(key=sentence_key, bigram=True)
    for sentence_key in (PREMISE_KEY, HYPOTHESIS_KEY)
)

# Read identity labels

In [28]:
# Load the identity labels: one label per line of IDENTITY_LABEL_EXT_FILE,
# with surrounding whitespace removed and blank lines skipped.
with open(IDENTITY_LABEL_EXT_FILE, encoding='utf-8') as f:
    stripped_lines = (raw_line.strip() for raw_line in f)
    identity_labels = [label for label in stripped_lines if label]

# Premise

In [29]:
# Rank the top 5 associated terms per identity label from the premise
# statistics, with include_target_bigrams switched off.
premise_pmi = PMI(premise_stats)
premise_pmi_stats = premise_pmi(
    identity_labels,
    top_k=5,
    include_target_bigrams=False,
)

# Save the result as indented JSON next to the notebook.
with open('pmi_premise.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(premise_pmi_stats, indent=2))

# Display the mapping as the cell output.
premise_pmi_stats

{'woman': ['mascara', 'lipstick', 'incense', 'clothesline', 'headscarf'],
 'women': ['saris', 'headscarves', 'bikinis', 'headdresses', 'coverings'],
 'man': ['girlfriend', 'shaves', 'mustached', 'breakdances', 'rearing'],
 'men': ['turbans', 'tuxedos', 'ladders', 'jumpsuits', 'wetsuits'],
 'girl': ['cereal', 'pigtails', 'fairy', 'stitch', 'kitty'],
 'girls': ['twin', 'dresses', 'cheerleading', 'leotards', 'barbie'],
 'boy': ['scouts', 'scout', 'slip', 'over-sized', 'pumpkins'],
 'boys': ['comic', 'tongues', 'unicef', 'twin', 'legos'],
 'female': ['vocalist', 'companion', 'stroke', 'dentist', 'athlete'],
 'male': ['bassist', 'vocalist', 'caucasian', 'entertainer', 'grin'],
 'mother': ['daughter', 'sons', 'totter', 'teeter', 'son'],
 'father': ['son', 'sons', 'grandfather', 'daughter', 'presumably'],
 'sister': ['brother', 'burying', 'crayon', 'stood', 'quietly'],
 'brother': ['sister', 'painters', 'unicef', 'stood', 'quietly'],
 'daughter': ['mother', 'father', 'puma', 'poncho', 'mom'],

In [30]:
# Same premise ranking as the unigram run, but with include_target_bigrams
# switched on. Note that this rebinds premise_pmi_stats to the new result.
premise_pmi = PMI(premise_stats)
premise_pmi_stats = premise_pmi(
    identity_labels,
    top_k=5,
    include_target_bigrams=True,
)

# Save the result as indented JSON under its own file name.
with open('pmi_premise_bigram.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(premise_pmi_stats, indent=2))

# Display the mapping as the cell output.
premise_pmi_stats

{'woman': ['applying mascara',
  'mascara',
  'yellow scarf',
  'skirt stands',
  'skirt walks'],
 'women': ['saris', 'head coverings', 'headscarves', 'bikinis', 'headdresses'],
 'man': ['beard sits',
  'bicycle near',
  'suit walking',
  'gray beard',
  'girlfriend'],
 'men': ['turbans',
  'tuxedos',
  'black suits',
  'military uniforms',
  'boxing match'],
 'girl': ['dress runs', 'pink bathing', 'cereal', 'pigtails', 'pink headband'],
 'girls': ['white dresses',
  'wearing dresses',
  'parallel bars',
  'dresses standing',
  'jump rope'],
 'boy': ['pumpkin patch', 'scouts', 'scout', 'chocolate ice', 'baseball bat'],
 'boys': ['flag football', 'comic', 'tongues', 'school age', 'unicef'],
 'female': ['free throw',
  'vocalist',
  'figure skater',
  'companion',
  'tennis players'],
 'male': ['female dancer',
  'older white',
  'bassist',
  'artist painting',
  'female singer'],
 'mother': ['young son', 'daughter', 'young daughter', 'sons', 'totter'],
 'father': ['son', 'sons', 'grandf

# Hypothesis

In [31]:
# Rank the top 5 associated terms per identity label from the hypothesis
# statistics, with include_target_bigrams switched off.
hyp_pmi = PMI(hyp_stats)
hyp_pmi_stats = hyp_pmi(
    identity_labels,
    top_k=5,
    include_target_bigrams=False,
)

# Save the result as indented JSON next to the notebook.
with open('pmi_hypothesis.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(hyp_pmi_stats, indent=2))

# Display the mapping as the cell output.
hyp_pmi_stats

{'woman': ['weaves', 'chevrolet', 'veil', 'bra', 'applies'],
 'women': ['burkas', 'husbands', 'saris', 'kimonos', 'bikinis'],
 'man': ['beared', 'refueling', 'mustache', 'shaven', 'afro'],
 'men': ['turbans', 'rickshaws', 'wives', 'cigars', 'tuxedos'],
 'girl': ['pigtails', 'diary', 'biscuits', 'totter', 'teeter'],
 'girls': ['slumber', 'lockers', 'cheerleading', 'sleepover', 'barbies'],
 'boy': ['hacky', 'scouts', 'scout', 'playroom', 'see-saw'],
 'boys': ['marbles', 'sergeant', 'missionary', 'twin', 'frat'],
 'female': ['contortionist', 'consisting', 'tanned', 'slender', 'barista'],
 'male': ['entry', 'dumps', 'youthful', 'vocalist', 'invention'],
 'mother': ['bonding', 'aunt', 'daughter', 'comforted', 'consoling'],
 'father': ['son', 'picnicking', 'wetsuits', 'sons', 'daughter'],
 'sister': ['brother', 'piggyback', 'mesmerized', 'laws', 'vacuums'],
 'brother': ['sister', 'teases', 'eyed', 'dislikes', 'frat'],
 'daughter': ['symbols', 'daddy', 'hailing', 'ice-cream', 'father'],
 'son

In [32]:
# Same hypothesis ranking as the unigram run, but with include_target_bigrams
# switched on. Note that this rebinds hyp_pmi_stats to the new result.
hyp_pmi = PMI(hyp_stats)
hyp_pmi_stats = hyp_pmi(
    identity_labels,
    top_k=5,
    include_target_bigrams=True,
)

# Save the result as indented JSON under its own file name.
with open('pmi_hypothesis_bigram.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(hyp_pmi_stats, indent=2))

# Display the mapping as the cell output.
hyp_pmi_stats

{'woman': ['black purse',
  'chevrolet car',
  'flower dress',
  'white gown',
  'weaves'],
 'women': ['wearing skirts', 'burkas', 'husbands', 'black dresses', 'saris'],
 'man': ['clown nose',
  'glasses looks',
  'beared',
  'gray beard',
  'hawaiian shirt'],
 'men': ['turbans', 'rickshaws', 'wives', 'yellow vests', 'holding guitars'],
 'girl': ['pigtails',
  'baby sister',
  'pink swimsuit',
  'little blond',
  'pink helmet'],
 'girls': ['play jump',
  'slumber party',
  'ballet class',
  'slumber',
  'playing house'],
 'boy': ['toy truck',
  'action figures',
  'pirate costume',
  'pumpkin patch',
  'hacky sack'],
 'boys': ['fruit cart', 'black swim', 'comic books', 'marbles', 'sergeant'],
 'female': ['male dog',
  'contortionist',
  'football fan',
  'consisting',
  'bicyclist rides'],
 'male': ['female dog',
  'tap dancer',
  'model poses',
  'gray tank',
  'female singer'],
 'mother': ['daughter play', 'bonding', 'aunt', 'son watch', 'daughter walk'],
 'father': ['son playing', '