# Word2vec context nearest neighbour model

In [318]:
#package to load word2vec vectors
import gensim
from gensim.models.keyedvectors import KeyedVectors
#self built functions
import utilities
#semcor corpus
import nltk
from nltk.corpus import semcor
from nltk.corpus import wordnet as wn
#Micellaneous
import numpy as np
import collections
from scipy.spatial.distance import cosine
import sys
import re
import time
import progressbar
import pickle
from functools import partial

In [184]:
#Load word2vec binary file
model = KeyedVectors.load_word2vec_format('../datasets/word2vec/GoogleNews-vectors-negative300.bin', binary=True)
#Define embedding_dict
embedding_dict = model.word_vec
#Load semcor
tagged_chunks = semcor.tagged_chunks(tag='sem')

In [207]:
#Progress bar utilities
class ProgressBar(object):
    DEFAULT = 'Progress: %(bar)s %(percent)3d%%'
    FULL = '%(bar)s %(current)d/%(total)d (%(percent)3d%%) %(remaining)d to go'

    def __init__(self, total, width=40, fmt=DEFAULT, symbol='=',
                 output=sys.stderr):
        assert len(symbol) == 1

        self.total = total
        self.width = width
        self.symbol = symbol
        self.output = output
        self.fmt = re.sub(r'(?P<name>%\(.+?\))d',
            r'\g<name>%dd' % len(str(total)), fmt)

        self.current = 0

    def __call__(self):
        percent = self.current / float(self.total)
        size = int(self.width * percent)
        remaining = self.total - self.current
        bar = '[' + self.symbol * size + ' ' * (self.width - size) + ']'

        args = {
            'total': self.total,
            'bar': bar,
            'current': self.current,
            'percent': percent * 100,
            'remaining': remaining
        }
        print('\r' + self.fmt % args, file=self.output, end='')

    def done(self):
        self.current = self.total
        self()
        print('', file=self.output)

In [325]:
#prediction method
def predict(context,predict_lemmas):
    #senses_choices must be a python list of sense label (wordnet synset)
    #get context embedding
    context_emb = getContextEmb_word2vec(context=context,emb_size=300,embedding_dict=embedding_dict)
    #get senses choice
    synsets = wn.synsets(predict_lemmas)
    senses_choices = [synset.name() for synset in synsets]
    #calculate cosine distance between each sense and context
    decision_chart = [[choice,cosine(context_emb,sense_embeddings[choice])] for choice in senses_choices]
    prediction = sorted(decision_chart,key=lambda x:x[1])[0][0]
    return prediction
def get_context(tagged_chunks,position,window_size):
    center = tagged_chunks[position].leaves()
    num_words = 0
    iter_position = position
    #unroll left size
    left = []
    right = []
    while (num_words<window_size):
        iter_position-=1
        if iter_position <0:
            break
        item = tagged_chunks[iter_position]
        if type(item)==list:
            num_words +=1
            left.insert(0,item[0])
        elif type(item)==nltk.tree.Tree:
            words2append = item.leaves()
            if num_words+len(words2append)>window_size:
                num_allowed = window_size-num_words
                words2append=words2append[-num_allowed:]
            num_words += len(words2append)
            left = words2append+left
            
    num_words = 0
    iter_position = position
    while (num_words<window_size):
        iter_position+=1
        if iter_position >=len(tagged_chunks):
            break
        item = tagged_chunks[iter_position]
        if type(item)==list:
            num_words +=1
            right.append(item[0])
        elif type(item)==nltk.tree.Tree:
            words2append = item.leaves()
            if num_words+len(words2append)>window_size:
                num_allowed = window_size-num_words
                words2append=words2append[:num_allowed]
            num_words += len(words2append)
            right = right+words2append
    return left+center+right

def getContextEmb_word2vec(context,embedding_dict,emb_size,unk_emb=np.zeros(300)):
    # Input introductions
    # sentence: an array of tokens of untagged sentence. 
    # center: position of the center word
    # window_size: size of context window
    # embedding_Dict: gensim model method
    ################################################################
        output_embedding = np.zeros(emb_size)
        for word in context:
            try:
                output_embedding+=embedding_dict(word)#use gensim model method
            except:
                output_embedding+=unk_emb
        return output_embedding
    
def buildSemEmb_word2vec(tagged_chunks,embedding_dict,emb_size=300,window_size=4):
    progress = progressbar.ProgressBar(max_value=len(tagged_chunks))
    output_dict = collections.defaultdict(partial(np.zeros,emb_size))
    for idx in range(len(tagged_chunks)):
        progress.update(idx)
        itm=tagged_chunks[idx]
        if(type(itm))==list:
            continue
        else:
            #Use try except handling since some of the label is broken
            try:
                sense_index = itm.label().synset().name()
            except:
                continue
            context = get_context(position=idx,tagged_chunks=tagged_chunks,window_size=window_size)
            context_emb = getContextEmb_word2vec(context,embedding_dict=embedding_dict,emb_size=300)
            output_dict[sense_index]+=context_emb
    return output_dict

In [210]:
#build sense embeddings
sense_embeddings = buildSemEmb_word2vec(tagged_chunks=tagged_chunks,embedding_dict=embedding_dict)



In [326]:
sense_mebeddings_win5 = buildSemEmb_word2vec(tagged_chunks=tagged_chunks,embedding_dict=embedding_dict,window_size=5)

 99% (778332 of 778587) |################# | Elapsed Time: 0:16:06 ETA: 0:00:00

In [328]:
f = open('sense_win5.pk','wb')
pickle.dump(obj=sense_mebeddings_win5,file=f)

In [329]:
f.close()

# Attempt: trying to classify some ambiguous word

In [285]:
#Example test
example = chunks[:30]
context = get_context(position=5,tagged_chunks=example,window_size=5)
print('lemma to predict:%s'%(example[5]))
print('context:%s'%(' '.join(context)))
senses_choices = example[5]
print()
print('Final decision: %s'%(predict(context=context,predict_lemmas='investigation')))

lemma to predict:(Lemma('probe.n.01.investigation') investigation)
context:Grand Jury said Friday an investigation of Atlanta 's recent primary

Final decision: probe.n.01


## Notice how closed possible senses are

In [259]:
print(wn.synsets('investigation')[0])
print(wn.synsets('investigation')[0].definition())

Synset('probe.n.01')
an inquiry into unfamiliar or questionable activities


In [260]:
print(wn.synsets('investigation')[1])
print(wn.synsets('investigation')[1].definition())

Synset('investigation.n.02')
the work of inquiring into something thoroughly and systematically


In [281]:
lemma = example[10]
lemma.label().name()

'primary_election'

# Perform all-words WSD

In [327]:
def predict_all(tagged_chunks,window_size = 4):
    progress = progressbar.ProgressBar(max_value=len(tagged_chunks))
    output_dict = collections.defaultdict(lambda: np.zeros(emb_size))
    #count_dict = collections.defaultdict(lambda: 0)
    num_correct = 0
    num_predicted = 0.0
    for idx in range(len(tagged_chunks)):
        progress.update(idx)
        itm=tagged_chunks[idx]
        if (idx%100000==0)&(num_predicted>0):
            print('correct: %s, predicted: %s, accuracy: %s'%(num_correct,num_predicted,num_correct/num_predicted))
        if(type(itm))==list:
            continue
        else:
            #Use try except handling since some of the label is broken
            try:
                lemma = itm.label().name()
            except:
                continue
            context = get_context(position=idx,tagged_chunks=tagged_chunks,window_size=window_size)
            prediction = predict(context=context,predict_lemmas=lemma)
            correct = itm.label().synset().name()
            num_predicted +=1
            if prediction == correct:
                num_correct+=1
    return num_correct/num_predicted

In [297]:
# Window size 4 should get 68.3% accuracy.
predict_all(tagged_chunks,window_size=4)

[                                        ]  10069/778587 (  1%) 768518 to go

correct: 3669, predicted: 4822.0, accuracy: 0.7608875985068436


[=                                       ]  20046/778587 (  2%) 758541 to go

correct: 7121, predicted: 9430.0, accuracy: 0.7551431601272535


[=                                       ]  30045/778587 (  3%) 748542 to go

correct: 10418, predicted: 13917.0, accuracy: 0.7485808723144356


[==                                      ]  40028/778587 (  5%) 738559 to go

correct: 13754, predicted: 18493.0, accuracy: 0.7437408749256476


[==                                      ]  50030/778587 (  6%) 728557 to go

correct: 17323, predicted: 23282.0, accuracy: 0.7440511983506571


[===                                     ]  60008/778587 (  7%) 718579 to go

correct: 20779, predicted: 27909.0, accuracy: 0.744526855136336


[===                                     ]  70065/778587 (  8%) 708522 to go

correct: 24439, predicted: 32551.0, accuracy: 0.7507910663266874


[====                                    ]  80063/778587 ( 10%) 698524 to go

correct: 28138, predicted: 37226.0, accuracy: 0.7558695535378499


[====                                    ]  90007/778587 ( 11%) 688580 to go

correct: 31786, predicted: 42011.0, accuracy: 0.756611363690462


[=====                                   ] 100089/778587 ( 12%) 678498 to go

correct: 35581, predicted: 46755.0, accuracy: 0.7610095176986419


[=====                                   ] 110045/778587 ( 14%) 668542 to go

correct: 38617, predicted: 50835.0, accuracy: 0.7596537818432183




correct: 42091, predicted: 55510.0, accuracy: 0.7582597730138714




correct: 45583, predicted: 60194.0, accuracy: 0.7572681662624182




correct: 49059, predicted: 64855.0, accuracy: 0.7564412921131756




correct: 52031, predicted: 69189.0, accuracy: 0.7520126031594617




correct: 54783, predicted: 73240.0, accuracy: 0.7479929000546149




correct: 57802, predicted: 77515.0, accuracy: 0.7456879313681223




correct: 60926, predicted: 81918.0, accuracy: 0.7437437437437437




correct: 63868, predicted: 86202.0, accuracy: 0.7409108837381964




correct: 66620, predicted: 90254.0, accuracy: 0.7381390298490925




correct: 69476, predicted: 94417.0, accuracy: 0.7358420623404683




correct: 72247, predicted: 98501.0, accuracy: 0.7334646348768032




correct: 75246, predicted: 102685.0, accuracy: 0.732784729999513




correct: 78947, predicted: 107471.0, accuracy: 0.7345888658335737




correct: 82636, predicted: 112325.0, accuracy: 0.7356866236367683




correct: 86026, predicted: 116803.0, accuracy: 0.7365050555208342




correct: 89341, predicted: 121312.0, accuracy: 0.7364564099182274




correct: 92840, predicted: 126024.0, accuracy: 0.7366850758585666




correct: 95911, predicted: 130297.0, accuracy: 0.7360952285931373




correct: 99234, predicted: 134788.0, accuracy: 0.73622280915215




correct: 102543, predicted: 139178.0, accuracy: 0.7367759272298783




correct: 106050, predicted: 143648.0, accuracy: 0.7382629761639563




correct: 109507, predicted: 148100.0, accuracy: 0.7394125590817016




correct: 112864, predicted: 152601.0, accuracy: 0.7396019685323163




correct: 116189, predicted: 157104.0, accuracy: 0.739567420307567




correct: 119652, predicted: 161807.0, accuracy: 0.7394735703646937




correct: 122377, predicted: 165859.0, accuracy: 0.7378375608197324




correct: 125196, predicted: 169939.0, accuracy: 0.7367114082111816




correct: 128089, predicted: 174109.0, accuracy: 0.73568281938326




correct: 130869, predicted: 178290.0, accuracy: 0.7340232205956587




correct: 133514, predicted: 182278.0, accuracy: 0.7324745718078979




correct: 134964, predicted: 184616.0, accuracy: 0.7310525631581228




correct: 135478, predicted: 185754.0, accuracy: 0.7293409563185719




correct: 135914, predicted: 186718.0, accuracy: 0.7279105388875202




correct: 136446, predicted: 187798.0, accuracy: 0.7265572583307597




correct: 136916, predicted: 188812.0, accuracy: 0.7251445882676949




correct: 137393, predicted: 189826.0, accuracy: 0.723783886295871




correct: 137913, predicted: 190894.0, accuracy: 0.7224585371986547




correct: 138419, predicted: 191978.0, accuracy: 0.7210149079582036




correct: 138922, predicted: 193085.0, accuracy: 0.7194862366315353




correct: 139446, predicted: 194236.0, accuracy: 0.7179204678844292




correct: 139979, predicted: 195382.0, accuracy: 0.716437542864747




correct: 140521, predicted: 196557.0, accuracy: 0.7149122137598762




correct: 140985, predicted: 197577.0, accuracy: 0.7135698993303876




correct: 141462, predicted: 198620.0, accuracy: 0.7122243480012084




correct: 141916, predicted: 199599.0, accuracy: 0.7110055661601511




correct: 142390, predicted: 200584.0, accuracy: 0.709877158696606




correct: 142905, predicted: 201638.0, accuracy: 0.708720578462393




correct: 143446, predicted: 202762.0, accuracy: 0.7074599777078545




correct: 143955, predicted: 203822.0, accuracy: 0.7062780269058296




correct: 144524, predicted: 205041.0, accuracy: 0.7048541511209954




correct: 145009, predicted: 206123.0, accuracy: 0.7035071292383673




correct: 145430, predicted: 207096.0, accuracy: 0.7022347124039093




correct: 145987, predicted: 208245.0, accuracy: 0.7010348387716392




correct: 146491, predicted: 209381.0, accuracy: 0.6996384581217971




correct: 147011, predicted: 210595.0, accuracy: 0.6980745031933332




correct: 147536, predicted: 211723.0, accuracy: 0.6968350155627873




correct: 148045, predicted: 212748.0, accuracy: 0.6958702314475341




correct: 148587, predicted: 213835.0, accuracy: 0.6948675380550424




correct: 149050, predicted: 214747.0, accuracy: 0.6940725598029309




correct: 149464, predicted: 215553.0, accuracy: 0.6933979114185375




correct: 150039, predicted: 216744.0, accuracy: 0.6922406156571809




correct: 150627, predicted: 218163.0, accuracy: 0.690433299872114




correct: 151203, predicted: 219431.0, accuracy: 0.6890685454653172




correct: 151771, predicted: 220901.0, accuracy: 0.6870543818271534




correct: 152346, predicted: 222304.0, accuracy: 0.6853048078307183




correct: 152870, predicted: 223624.0, accuracy: 0.6836028333273709




NameError: name 'num_predict' is not defined