In [60]:
from transformers import BertTokenizer, BertForMaskedLM
import torch
import numpy as np

In [152]:
ambig = "[CLS] When the old panel was not replaced, people demanded fair elections within three months. [SEP]"
a_target = "panel"
control = "[CLS] When the old mayor was not replaced, people demanded fair elections within three months. [SEP]"
c_target = "mayor"

In [153]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_toks = tokenizer.tokenize(ambig)
masked_index = text_toks.index(a_target)
text_toks[masked_index] = "[MASK]"
indexed_tokens = tokenizer.convert_tokens_to_ids(text_toks)
segments_ids = [1] * len(indexed_tokens)

tokens_tensor, segments_tensors = torch.tensor([indexed_tokens]), torch.tensor([segments_ids])


In [154]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    predictions = outputs[0]

predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
predicted_token

'system'

In [155]:
predictions = torch.squeeze(predictions, dim = 0)
sorted_ids = torch.argsort(predictions[masked_index])
tokens_in_order = tokenizer.convert_ids_to_tokens(sorted_ids)

np.argwhere(np.array(tokens_in_order) == a_target)

array([[30155]])

In [156]:
np.argwhere(np.array(tokens_in_order) == c_target)

array([[30465]])

In [157]:
np.flip(tokens_in_order)[:10]

array(['system', 'constitution', 'parliament', 'government',
       'administration', 'regime', 'legislature', 'cabinet', 'president',
       'one'], dtype='<U18')

In [158]:
softmax_probs = torch.nn.Softmax(dim = -1)(predictions[masked_index])
[softmax_probs[i].item() for i in sorted_ids[-10:]]

[0.014033270999789238,
 0.014475952833890915,
 0.015449920669198036,
 0.018452061340212822,
 0.02065141871571541,
 0.02477291226387024,
 0.04757308214902878,
 0.11762050539255142,
 0.1922246515750885,
 0.3452235460281372]

In [147]:
len(softmax_probs) - 30155

264

In [161]:
len(softmax_probs) - 30465

57

In [159]:
softmax_probs[30155]

tensor(2.0226e-08)

In [160]:
softmax_probs[30465]

tensor(3.5374e-08)