# Word Sense Disambiguation (WSD)
## Using BERT Masked Language Model (LM)
This notebook explores the a part of the idea proposed by Ajit Rakasekharan in his blog post 
[Examining BERT raw embeddings.](https://towardsdatascience.com/examining-berts-raw-embeddings-fd905cb22df7) 

The idea is that examining the predictions of a masked language model for a masked ambiguous word can yield insights into the semantic meaning of the ambiguous word.

We use the HuggingFace BERT for Masked LM with weights from a bert-base-cased pre-trained model for our experiment.

We mask the ambiguous word (here we have used bank for our test) in sentences, and then send them through a BERT MLM model. Output is an array of logits for each position of the input sequence. So assuming a sentence with T tokens and a vocabulary size of V, the predictions of the MLM is (1, T, V) where 1 is the batch size (1 input sentence at a time in our experiment).

In order to find the top k predictions, the logits for the masked position is softmaxed and the top k values chosen.



## Prepare your environment

As always, we highly recommend that you install all packages with a virtual environment manager, like [venv](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) or [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html), to prevent version conflicts of different packages.  

### Masked LM Model and Tokenizer 
[tutorial](https://huggingface.co/docs/transformers/tasks/language_modeling)  
Task is to predict words that are masked using BERT, so we will use BERTMaskedLM model and BERTTokenizer and use the pre-trained bert-base-uncased model.

In [1]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForMaskedLM.from_pretrained('bert-base-cased', return_dict=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


We are going to use the pre-trained BERT language model in inference mode only.

The tokenizer tokenizes the input sequence and pads it with the [CLS] and [SEP] tokens.

The output produced by the model has two components, loss and logits. The logits component has shape (1, number_of_tokens, vocab_size) where the leading 1 represents the single input sentence.

We will identify the logits corresponding to the position of our masked token, identify the top 5 vocabulary words predicted for that position, and return the softmax probabilities for each of the top 5 predicted words.

In [3]:
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
outputs = model(**inputs)

In [4]:
tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

['[CLS]', 'The', 'capital', 'of', 'France', 'is', '[MASK]', '.', '[SEP]']

In [5]:
outputs

MaskedLMOutput(loss=None, logits=tensor([[[ -7.1545,  -6.9931,  -7.1826,  ...,  -5.9124,  -5.6733,  -5.9854],
         [ -8.0190,  -8.1319,  -8.0509,  ...,  -6.5679,  -6.4058,  -6.8998],
         [ -4.9772,  -6.1781,  -6.0669,  ...,  -5.6362,  -4.6603,  -5.1241],
         ...,
         [ -3.4420,  -3.2557,  -3.5733,  ...,  -2.4606,  -2.6495,  -3.1952],
         [-10.5890, -10.4621, -11.7181,  ...,  -7.4646,  -9.9543,  -8.3927],
         [-14.8900, -14.8873, -14.4569,  ..., -11.6588, -13.0151, -11.6073]]],
       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)

In [6]:
def get_mask_index(input_ids, tokenizer):
  x = input_ids[0]
  is_masked = torch.where(x == tokenizer.mask_token_id, x, 0)
  mask_idx = torch.nonzero(is_masked)
  return mask_idx.item()

mask_idx = get_mask_index(inputs.input_ids, tokenizer)
mask_idx

6

In [7]:
def get_top_k_predictions(pred_logits, mask_idx, top_k):
  probs = torch.nn.functional.softmax(pred_logits[0, mask_idx, :], dim=-1)
  top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
  top_k_pct_weights = [100 * x.item() for x in top_k_weights]
  top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices)
  return list(zip(top_k_tokens, top_k_pct_weights))

get_top_k_predictions(outputs.logits, mask_idx, 5)

[('Paris', 44.46825087070465),
 ('Lyon', 9.396003931760788),
 ('Toulouse', 8.234527707099915),
 ('Lille', 7.515132427215576),
 ('Marseille', 5.692283064126968)]

### WSD Test Sentences
We take our pair of sentences for disambiguating the word bank and mask them, and extract the top 20 predictions from the pre-trained BERT MLM model.

As expected, the first set of predictions predominantly point to some sort of financial institution, whereas the second set of predictions predominantly point to some geographical formation around bodies of water.

In [8]:
sentences = [
  "Go to the [MASK] and deposit your pay check.",
  "Jim and Janet went down to the river [MASK] to admire the swans."
]

In [9]:
def get_predictions(sentence, tokenizer, model):
  inputs = tokenizer(sentence, return_tensors="pt")
  outputs = model(**inputs)
  mask_idx = get_mask_index(inputs.input_ids, tokenizer)
  top_preds = get_top_k_predictions(outputs.logits, mask_idx, 20)
  return top_preds

In [10]:
get_predictions(sentences[0], tokenizer, model)

[('bank', 70.31400203704834),
 ('office', 10.280580818653107),
 ('register', 1.7451910302042961),
 ('store', 1.6284741461277008),
 ('bathroom', 0.9394760243594646),
 ('library', 0.893483217805624),
 ('desk', 0.8724337443709373),
 ('counter', 0.7977298460900784),
 ('hotel', 0.5163736641407013),
 ('lobby', 0.4956950433552265),
 ('kitchen', 0.36370735615491867),
 ('garage', 0.34799189306795597),
 ('door', 0.341272191144526),
 ('car', 0.3311359556391835),
 ('house', 0.26490497402846813),
 ('airport', 0.25470268446952105),
 ('elevator', 0.2491130493581295),
 ('back', 0.24807583540678024),
 ('computer', 0.24019514676183462),
 ('banks', 0.23491380270570517)]

In [11]:
get_predictions(sentences[1], tokenizer, model)

[('##bank', 32.60223567485809),
 ('below', 13.03189992904663),
 ('bank', 11.94087341427803),
 (',', 5.626500770449638),
 ('##boat', 3.1638897955417633),
 ('##front', 2.7332188561558723),
 ('basin', 1.6210518777370453),
 ('##bed', 1.2178423814475536),
 ('together', 1.1841695755720139),
 ('bed', 0.9657143615186214),
 ('again', 0.8369861170649529),
 ('deck', 0.8356181904673576),
 ('valley', 0.7271438371390104),
 ('mouth', 0.7227543275803328),
 ('boat', 0.7151042111217976),
 ('pier', 0.6493269931524992),
 ('house', 0.6301570683717728),
 ('banks', 0.5700557492673397),
 ('pool', 0.53457235917449),
 ('Thames', 0.4995575174689293)]

## Assignment
In this week's assignment, you are tasked with processing SemCor data and feed the data into BERT masked-LM. After that, use the predictions to find the most likely sense of the target word using WordNet similarity.

### Data Preprocessing 
You can find a sample of SemCor dataset [here](https://drive.google.com/file/d/1inmv3rUcGrtiS4VQwTMsT9HF-iL8jc5V/view?usp=sharing) and load the data using the following methods.

In [12]:
import json
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet as wn
sents = []
tokens = []
wn_ids = []
lemmatizer = WordNetLemmatizer()

with open('semcor.sample.jsonl') as f:
  for line in f:
    data = json.loads(line)
    sents.append(data['sent'])
    tokens.append(data['tokens'])
    wn_ids.append(data['wnid'])

In [13]:
print(sents[10])
print(tokens[10])
print(wn_ids[10])

implementation of georgia 's automobile title law was also recommended by the outgoing jury . 
['implementation', 'of', 'georgia', "'s", 'automobile', 'title', 'law', 'was', 'also', 'recommended', 'by', 'the', 'outgoing', 'jury', '.']
['implementation%1:04:01::', 0, 'georgia%1:15:00::', 0, 'automobile%1:06:00::', 'title%1:10:04::', 'law%1:10:00::', 0, 'also%4:02:00::', 'recommend%2:32:01::', 0, 0, 'outgoing%3:00:00::', 'jury%1:14:00::', 0]


In [14]:
import nltk

nltk.download("wordnet")
nltk.download("omw-1.4")

[nltk_data] Downloading package wordnet to /home/bill/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/bill/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [15]:
# The WordNet ID can be converted to NLTK Lemma using the following function
wn.lemma_from_key('implementation%1:04:01::')

Lemma('execution.n.06.implementation')

### TODO 
Please implement a method to convert the data to BERT Masked-LM format and keep track of the headword. Store the data into the following lists

headword[i] = 'implementation'  
ground_truth[i] = 'implementation%1:04:01::'  
sent[i] = "[MASK] of georgia 's automobile title law was also recommended by the outgoing jury ."  



In [16]:
sents[0]

'the fulton_county_grand_jury said friday an investigation of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . '

In [17]:
tokens[0][:15]

['the',
 'fulton',
 'county',
 'grand',
 'jury',
 'said',
 'friday',
 'an',
 'investigation',
 'of',
 'atlanta',
 "'s",
 'recent',
 'primary',
 'election']

In [18]:
wn_ids[0][:15]

[0,
 0,
 0,
 0,
 0,
 'say%2:32:00::',
 'friday%1:28:00::',
 0,
 'investigation%1:09:00::',
 0,
 'atlanta%1:15:00::',
 0,
 'recent%3:00:00:past:00',
 'primary_election%1:04:00::',
 0]

In [19]:
headwords = []
ground_truths = []
for sent_idx in range(len(sents)):
  for word_idx in range(len(tokens[sent_idx])):
    try:
      ground_truth = wn.lemma_from_key(wn_ids[sent_idx][word_idx])
    except:
      continue
    if wn_ids[sent_idx][word_idx] == 0:
      continue
    headwords.append(tokens[sent_idx][word_idx])
    ground_truths.append(ground_truth)

In [20]:
headwords[:5]

['said', 'friday', 'investigation', 'atlanta', 'primary']

In [21]:
ground_truths[:5]

[Lemma('state.v.01.say'),
 Lemma('friday.n.01.Friday'),
 Lemma('probe.n.01.investigation'),
 Lemma('atlanta.n.01.Atlanta'),
 Lemma('primary.n.01.primary_election')]

# Generate masked sentences

In [22]:
def get_masked_sentence(sent, start_idx, end_idx):
  left_sent = sent[:start_idx]
  mid_sent = "[MASK]"
  right_sent = sent[end_idx + 1:]
  return left_sent + mid_sent + right_sent

get_masked_sentence("0123456789", 2, 4)

'01[MASK]56789'

In [23]:
def get_sent_masked_sentences(sent, sent_tokens, sent_wn_ids):
  masked_sentences = []
  sent_char_idx = 0
  for word_idx in range(len(sent_tokens)):
    try:
      ground_truth = wn.lemma_from_key(wn_ids[sent_idx][word_idx])
    except:
      continue
    if sent_wn_ids[word_idx] == 0:
      continue
    word = sent_tokens[word_idx]
    start_idx = sent.find(word, sent_char_idx)
    if start_idx == -1:
      print("ERROR: " + sent + " -> " + word)
    else:
      end_idx = start_idx + len(word) - 1
      masked_sentences.append(get_masked_sentence(sent, start_idx, end_idx))
      sent_char_idx = end_idx + 1
  return masked_sentences

get_sent_masked_sentences(
  "I said what !he said err test",
  ["I", "said", "what", "he", "said", "error", "test"],
  [0, 1, 0, 1, 1, 1, 1]
)

ERROR: I said what !he said err test -> error


['I said what !he said err [MASK]']

In [24]:
masked_sents = []
for sent_idx in range(len(sents)):
  masked_sents.extend(get_sent_masked_sentences(sents[sent_idx], tokens[sent_idx], wn_ids[sent_idx]))

masked_sents[:5]

['the fulton_county_grand_jury [MASK] friday an investigation of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . ',
 'the fulton_county_grand_jury said [MASK] an investigation of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . ',
 'the fulton_county_grand_jury said friday an [MASK] of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . ',
 'the fulton_county_grand_jury said friday an investigation of [MASK] \'s recent primary_election produced " no evidence " that any irregularities took_place . ',
 'the fulton_county_grand_jury said friday an investigation of atlanta \'s recent [MASK]_election produced " no evidence " that any irregularities took_place . ']

In [25]:
len(masked_sents)

971

#### Identify the top 5 predictions other than the headword using Masked-LM 
1. Use get_predictions to get the predicted words
2. Use lemmatizer to lemmatize the prediction
3. Remove headword
4. Keep top 5 unique predictions

In [26]:
headwords_candidate_lemmas = []
for headword_idx in range(len(headwords)):
  predictions = get_predictions(masked_sents[headword_idx], tokenizer, model)
  no_headword_predictions = [prediction for prediction in predictions if prediction[0] != headwords[headword_idx]]
  headword_candidate_lemmas = [no_headword_predictions[idx][0] for idx in range(5)]
  headwords_candidate_lemmas.append(headword_candidate_lemmas)

headwords_candidate_lemmas[:5]

[['found', 'reported', ',', 'told', '_'],
 ['that', 'after', ',', 'in', 'during'],
 ['analysis', 'examination', 'audit', 'evaluation', 'inspection'],
 ['California', 'Obama', 'Alabama', 'Virginia', 'Arizona'],
 ['re', 'recall', 'by', 'municipal', 'mayor']]

example:  
candidate_lemmas = ['office', 'register', 'store', 'bathroom', 'library']

# Calculate sense cost

In [27]:
test_synsets = wn.synsets("investigation")
test_synsets

[Synset('probe.n.01'), Synset('investigation.n.02')]

In [28]:
test_candidate_lemmas = headwords_candidate_lemmas[2]
test_candidate_lemmas

['analysis', 'examination', 'audit', 'evaluation', 'inspection']

In [29]:
def get_headword_synset_cost(headword_synset, headword_candidate_lemmas):
  cost = 0
  for candidate_lemma in headword_candidate_lemmas:
    for candidate_synset in wn.synsets(candidate_lemma):
      cost += headword_synset.wup_similarity(candidate_synset)
  return cost

In [30]:
get_headword_synset_cost(test_synsets[0], test_candidate_lemmas)

6.133268030185516

In [31]:
get_headword_synset_cost(test_synsets[1], test_candidate_lemmas)

9.668906617977827

In [32]:
import math

def get_headword_predicted_synset(headword, headword_candidate_lemmas):
  max_cost = -math.inf
  predicted_synset = None
  for headword_synset in wn.synsets(headword):
    cost = get_headword_synset_cost(headword_synset, headword_candidate_lemmas)
    if cost > max_cost:
      cost = max_cost
      predicted_synset = headword_synset
  return predicted_synset

get_headword_predicted_synset("investigation", test_candidate_lemmas)

Synset('investigation.n.02')

Identify the most similar sense of headword with relation to the 5 unique candidates

In [33]:
predicted_synsets = []
for headword_idx in range(len(headwords)):
  predicted_synset = get_headword_predicted_synset(headwords[headword_idx],
                                                   headwords_candidate_lemmas[headword_idx])
  predicted_synsets.append(predicted_synset)

predicted_synsets[:5]

[Synset('aforesaid.s.01'),
 Synset('friday.n.01'),
 Synset('investigation.n.02'),
 Synset('atlanta.n.02'),
 Synset('basal.s.03')]

For evaluation purpose, for i = 50, please run the process and print out the following:  
1. word[50]
2. ground_truth[50] (in synset or lemma)
3. sent[50]
4. candidate_lemmas
5. predicted_sense (in synset or lemma)    

Also, please print out the accuracy of the process over our dataset

In [34]:
print(headwords[50])
print(ground_truths[50])
print(sents[50])
print(headwords_candidate_lemmas[50])
print(predicted_synsets[50])

find
Lemma('rule.v.04.find')
he will be succeeded by rob_ledford of gainesville , who has been an assistant more than three years . 
['note', 'show', 'notice', 'believe', 'say']
Synset('find_oneself.v.01')


In [35]:
ground_truths[50]

Lemma('rule.v.04.find')

In [36]:
correct_nums = 0
total_nums = len(headwords)
for headword_idx in range(len(headwords)):
  if len(wn.synsets(headwords[headword_idx])) == 0:
    total_nums -= 1
    continue
  if ground_truths[headword_idx] in predicted_synsets[headword_idx].lemmas():
    correct_nums += 1
correct_nums / total_nums

0.8317757009345794

## TA's Note

Congratulations, you made it to the end of the tutorial! Make sure you make an appointment to show your work and turn in your finished assignment before next week's lesson. We will ask you to run your code, so double check that everything is working and that your model is saved. Don't worry if you didn't pass the evaluation requirements, you'll still get partial points for trying.