In [27]:
#!/usr/bin/env python
from collections import defaultdict
import sys
import itertools

from lexsub_xml import read_lexsub_xml
from lexsub_xml import Context 

# suggested imports 
from nltk.corpus import wordnet as wn
from nltk.corpus import stopwords

import numpy as np
import tensorflow

#tensorflow.compat.v1.disable_eager_execution()
# import logging
# logging.disable(logging.INFO)

import gensim
import transformers

#from transformers.utils import hf_logging
#hf_logging.disable_progress_bar()
#transformers.utils.logging.set_verbosity(transformers.logging.ERROR)
# import evaluate
# evaluate.logging.set_verbosity_error()



from typing import List
import string

def tokenize(s): 
    """
    a naive tokenizer that splits on punctuation and whitespaces.  
    """
    s = "".join(" " if x in string.punctuation else x for x in s.lower())    
    return s.split() 

def get_candidates(lemma, pos) -> List[str]:
    # Part 1
    ret = set()
    for synset in wn.synsets(lemma, pos=pos):
        for lexeme in synset.lemmas():
            if lexeme.name() == lemma:
                continue
            elif lexeme.name().find("_") != -1:
                ret.add(lexeme.name().replace("_", " "))
            else:
                ret.add(lexeme.name())
    return list(ret)

def smurf_predictor(context : Context) -> str:
    """
    suggest 'smurf' as a substitute for all words.
    """
    return 'smurf'

def wn_frequency_predictor(context : Context) -> str:
    #return None # replace for part 2
    #Context obj: self.cid, self.word_form, self.lemma, self.pos, self.left_context, self.right_context
    synsets = wn.synsets(context.lemma, pos=context.pos)#Get the synonym set that the input word relates to
    frequency = defaultdict(int)
    for synset in synsets:
        for lexeme in synset.lemmas():
            if lexeme.name() != context.lemma:#Consider the lemmas that aren't the input lemma (synonyms)
                frequency[lexeme.name()] += lexeme.count()#Record the number of occurences for each synonym sharing word sense with input lemma
                
    return max(frequency, key=frequency.get).replace("_", " ")
        
        
        

def wn_simple_lesk_predictor(context : Context) -> str:
    synsets = wn.synsets(context.lemma, pos=context.pos)#Get the synonym set that the input word relates to
    stop_words = stopwords.words('english')
    max_overlap = 0
    overlap_dict = defaultdict(int)
    for synset in synsets:
        #Tokenize and filter out stop words of synset definition, left context, right context, and examples
        definitions = [ [word.lower() for word in tokenize(synset.definition()) if word.lower() not in stop_words] ]
        left_context = tokenize(" ".join([word.lower() for word in context.left_context if word.lower() not in stop_words]))
        right_context = tokenize(" ".join([word.lower() for word in context.right_context if word.lower() not in stop_words]))
        examples = []
        for example in synset.examples():
            examples.append([word.lower() for word in tokenize(example) if word.lower() not in stop_words])
        #Do same filtering and tokenization for hypernym definitions and examples
        for synset_hyper in synset.hypernyms():
            definitions.append( [word.lower() for word in tokenize(synset_hyper.definition()) if word.lower() not in stop_words] )
            for example in synset_hyper.examples():
                examples.append([word.lower() for word in tokenize(example) if word.lower() not in stop_words])
        #Get the overlap
        overlap = 0
        for gloss in definitions + examples:
            overlap += len(set(gloss) & set(left_context)) + len(set(gloss) & set(right_context))
        # if overlap > 0:
        #     overlap_dict[synset] = overlap
        overlap_dict[synset] = overlap
        if overlap > max_overlap:
            max_overlap = overlap
    #best_synsets = [synset for (synset, overlap) in overlap_dict.items() if overlap == max_overlap]
    best_synsets = [synset for (synset, overlap) in overlap_dict.items() if overlap == max_overlap and overlap != 0]
    
    lexemes = []
    for synset in (best_synsets if best_synsets else synsets):
        for lexeme in synset.lemmas():
            if lexeme.name() != context.lemma:
                lexemes.append(lexeme)
    if lexemes:
        return max(lexemes, key=lambda x: x.count()).name().replace("_", " ")
    else:
        return "smurf"
                
            
    #return None #replace for part 3
   

class Word2VecSubst(object):
        
    def __init__(self, filename):
        self.model = gensim.models.KeyedVectors.load_word2vec_format(filename, binary=True)    

    def predict_nearest(self,context : Context) -> str:
        syns = get_candidates(context.lemma, context.pos)
        lemmas = []
        for syn in syns:
            try:
                lemmas.append((syn, self.model.similarity(context.lemma, syn.replace(" ",  "_"))))
            except:
                continue
        if lemmas:
             return max(lemmas, key=lambda lemma: lemma[1])[0]
        else:
            return "smurf"
        #return None # replace for part 4


class BertPredictor(object):

    def __init__(self): 
        self.tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.model = transformers.TFDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')

    def predict(self, context : Context) -> str:
        syns = get_candidates(context.lemma, context.pos)
        left_context = ''
        for word in context.left_context:
            if word.isalpha():
                left_context = left_context + ' ' + word
            else:#Handle things like "do n't -> don't"    ;    Dubhghall , from -> Dubhghall, from
                left_context = left_context + word

        sentence = left_context + ' [MASK]'

        right_context = ''
        for word in context.right_context:
            if word.isalpha():
                right_context = right_context + ' ' + word
            else:#Handle things like "do n't -> don't"    ;    Dubhghall , from -> Dubhghall, from
                right_context = right_context + word
        
        sentence = sentence + right_context

        input_toks_encoded = self.tokenizer.encode(sentence)
        mask_index = self.tokenizer.convert_ids_to_tokens(input_toks_encoded).index('[MASK]')
        input_mat = np.array(input_toks_encoded).reshape((1,-1))
        outputs = self.model.predict(input_mat, verbose=0)
        predictions = outputs[0]
        best_words_indices = np.argsort(predictions[0][mask_index])[::-1] # Sort in increasing order
        best_words = self.tokenizer.convert_ids_to_tokens(best_words_indices)
        for word in best_words:
            if word.replace("_", " ") in syns:
                return word.replace("_", " ")
        return ""
        #return None # replace for part 5
        

def part3(context : Context) -> str:
    synsets = wn.synsets(context.lemma, pos=context.pos)#Get the synonym set that the input word relates to
    stop_words = stopwords.words('english')
    max_overlap = 0
    overlap_dict = defaultdict(int)
    for synset in synsets:
        #Tokenize and filter out stop words of synset definition, left context, right context, and examples
        definitions = [ [word.lower() for word in tokenize(synset.definition()) if word.lower() not in stop_words] ]
        left_context = tokenize(" ".join([word.lower() for word in context.left_context if word.lower() not in stop_words]))
        right_context = tokenize(" ".join([word.lower() for word in context.right_context if word.lower() not in stop_words]))
        examples = []
        for example in synset.examples():
            examples.append([word.lower() for word in tokenize(example) if word.lower() not in stop_words])
        #Do same filtering and tokenization for hypernym definitions and examples
        for synset_hyper in synset.hypernyms():
            definitions.append( [word.lower() for word in tokenize(synset_hyper.definition()) if word.lower() not in stop_words] )
            for example in synset_hyper.examples():
                examples.append([word.lower() for word in tokenize(example) if word.lower() not in stop_words])
        #Get the overlap
        overlap = 0
        for gloss in definitions + examples:
            overlap += len(set(gloss) & set(left_context)) + len(set(gloss) & set(right_context))
        overlap_dict[synset] = overlap
        if overlap > max_overlap:
            max_overlap = overlap
    #best_synsets = [synset for (synset, overlap) in overlap_dict.items() if overlap == max_overlap and overlap != 0]
    best_synsets = [synset for (synset, overlap) in overlap_dict.items() if overlap == max_overlap]
    lexemes = []
    best_synset_frequency = 0
    # lexeme_frequency_dictionary = defaultdict(int)
    # synset_dict = defaultdict(int)
    # for synset in (best_synsets if best_synsets else synsets):#Should handle if overlap/no overlap exists
    #     synset_dict[synset] = 0
    #     for lexeme in synset.lemmas():
    #         if lexeme.name() != context.lemma:#Consider the lemmas that aren't the input lemma (synonyms)
    #             lexeme_frequency_dictionary[lexeme.name()] += lexeme.count()#Record the number of occurences for each synonym sharing word sense with input lemma 
    #Get the most frequent synset(s) (synset that has the highest frequency counts of lexemes)
    synset_dict = defaultdict(int)
    #print(best_synsets)
    for synset in (best_synsets if best_synsets else synsets):#Should handle if overlap/no overlap exists
        frequency = sum([lexeme.count() for lexeme in synset.lemmas() if lexeme.name() != context.lemma])
        synset_dict[synset] = frequency
        if frequency > best_synset_frequency:
            best_synset_frequency = frequency
    most_frequent_synsets = [synset for synset, freq in synset_dict.items() if freq == best_synset_frequency]
    #Select most frequent lexeme from synset(s)
    if most_frequent_synsets:
        return max(itertools.chain(*[synset.lemmas() for synset in most_frequent_synsets]), key=lambda lexeme: lexeme.count()).name().replace("_", " ")
    else:
        return "smurf"
    
    # for lexeme in synset.lemmas():
    #     if lexeme.name() != context.lemma:
    #         lexemes.append(lexeme)
    # if lexemes:
    #     return max(lexemes, key=lambda x: x.count()).name().replace("_", " ")
    # else:
    #     return "smurf"
    
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = transformers.TFDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
wv_model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)

Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertForMaskedLM: ['activation_13']
- This IS expected if you are initializing TFDistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertForMaskedLM were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForMaskedLM for predictions without further training.


In [43]:
def predict(context : Context) -> str:
    global tokenizer
    global model
    
    syns = get_candidates(context.lemma, context.pos)
    left_context = ''
    for word in context.left_context:
        if word.isalpha():
            left_context = left_context + ' ' + word
        else:#Handle things like "do n't -> don't"    ;    Dubhghall , from -> Dubhghall, from
            left_context = left_context + word

    sentence = left_context + ' [MASK]'

    right_context = ''
    for word in context.right_context:
        if word.isalpha():
            right_context = right_context + ' ' + word
        else:#Handle things like "do n't -> don't"    ;    Dubhghall , from -> Dubhghall, from
            right_context = right_context + word

    sentence = sentence + right_context
    
    #Input CLS sentence SEP sentence SEP to BERT
    input_toks_encoded = tokenizer.encode(left_context + right_context) + tokenizer.encode(sentence)[1:]
    mask_index = tokenizer.convert_ids_to_tokens(input_toks_encoded).index('[MASK]')
    input_mat = np.array(input_toks_encoded).reshape((1,-1))
    outputs = model.predict(input_mat, verbose=0)
    predictions = outputs[0]
    best_words_indices = np.argsort(predictions[0][mask_index])[::-1] # Sort in increasing order
    best_words = tokenizer.convert_ids_to_tokens(best_words_indices)
    
    best_word = ""
    scores = []
    i = 0
    for word in best_words:
        if i > 10:
            break
        #print(i, word)
        word_ = best_words[i]
        #word_ = word.replace("_", " ")
        if word_.replace("_", " ") != context.lemma:
            score = 0
            for syn in syns:
                try:
                    score += wv_model.similarity(word_, syn.replace(" ",  "_"))
                except:
                    continue
            scores.append((word_, score/len(syns)))
            i += 1
        else:
            continue
    
    if scores:
        return max(scores)[0].replace("_", " ")
    else:
        return "smurf"
            
                    
            
#         if word.replace("_", " ") in syns and word.replace("_", " ") != context.lemma:
#             return word.replace("_", " ")
    
    #return ""
    #return None # replace for part 5

In [46]:
with open('smurf6.predict', 'w') as f:
  for context in read_lexsub_xml("lexsub_trial.xml"):
    #print(context)  # useful for debugging
    prediction = predict(context)
    print(prediction)
    print("{}.{} {} :: {}".format(context.lemma, context.pos, context.cid, prediction), file=f)
    #print("{}.{} {} :: {}".format(context.lemma, context.pos, context.cid, prediction))
#!perl score.pl smurf6.predict gold.trial

younger
thinner
tiny
strong
worst
youngest
true
deep
wide
wrong
productions
smurf
reality
smurf
music
smurf
serials
smurf
project
ventures
your
weigh
smurf
took
ignore
have
unfolded
get
smurf
took
wet
grown
were
yours
smaller
smurf
warm
too
up
open
quality
smurf
sky
system
powder
variable
waist
smurf
rooms
survive
product
v
smurf
hybrid
truth
whale
smurf
test
racial
used
was
）


UnicodeEncodeError: 'charmap' codec can't encode character '\uff09' in position 16: character maps to <undefined>

In [22]:
!perl score.pl smurf5.predict gold.trial

Total = 298, attempted = 298
precision = 0.117, recall = 0.117
Total with mode 206 attempted 206
precision = 0.175, recall = 0.175
