In [1]:
import torch
from transformers import BertModel, BertTokenizer
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Load tokenizer and model from a single checkpoint name so they cannot drift apart:
# the ids produced by the tokenizer must index the same model's embedding table.
CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)
# .eval() makes inference mode explicit (dropout disabled), so repeated calls give identical embeddings.
model = BertModel.from_pretrained(CHECKPOINT).eval()

In [3]:
# Size of the WordPiece vocabulary the tokenizer maps text into.
print('Length of BERT vocabulary: {}'.format(len(tokenizer.vocab)))

Length of BERT vocabulary: 30522


In [4]:
# A short sentence whose words are all ordinary English vocabulary entries.
text = 'A simple sentence'

In [5]:
# encode() maps the sentence to vocabulary ids and wraps them in the special
# [CLS] (101) and [SEP] (102) markers, as the output below shows.
tokens = tokenizer.encode(text, add_special_tokens=True)
print(tokens)

[101, 1037, 3722, 6251, 102]


In [6]:
# Round-trip the ids back to text; special tokens are kept so [CLS]/[SEP] stay visible.
tokenizer.decode(tokens, skip_special_tokens=False)

'[CLS] a simple sentence [SEP]'

In [7]:
text_with_unknown_words = "Lallallero mangia gli sgaglioppi freschi"
tokens_with_unknown_words = tokenizer.encode(text_with_unknown_words)
# Words missing from the vocabulary are split into several WordPiece ids,
# so this list is much longer than the word count would suggest.
print(tokens_with_unknown_words)

[101, 21348, 13837, 3917, 2080, 2158, 10440, 1043, 3669, 22214, 8490, 12798, 9397, 2072, 10424, 2229, 5428, 102]


In [8]:
# Decoding glues the "##" continuation pieces back onto their words, so the
# round-trip reads as the (lowercased) original sentence.
tokenizer.decode(tokens_with_unknown_words, skip_special_tokens=False)

'[CLS] lallallero mangia gli sgaglioppi freschi [SEP]'

In [9]:
"Lallallero" in tokenizer.vocab

False

In [10]:
# Show how each id maps back to its WordPiece; "##" marks a continuation of the previous piece.
for token_id in tokens_with_unknown_words:
    subword = tokenizer.decode([token_id])
    print(f'Token: {token_id}, subword: {subword}')

Token: 101, subword: [CLS]
Token: 21348, subword: lal
Token: 13837, subword: ##lal
Token: 3917, subword: ##ler
Token: 2080, subword: ##o
Token: 2158, subword: man
Token: 10440, subword: ##gia
Token: 1043, subword: g
Token: 3669, subword: ##li
Token: 22214, subword: sg
Token: 8490, subword: ##ag
Token: 12798, subword: ##lio
Token: 9397, subword: ##pp
Token: 2072, subword: ##i
Token: 10424, subword: fr
Token: 2229, subword: ##es
Token: 5428, subword: ##chi
Token: 102, subword: [SEP]


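# Contextless and contextual embeddings

The same word, "python", is embedded inside two sentences where it has different meanings (a pet and a programming language). Each of those vectors is compared with the embeddings of "snake" and "programming", each encoded on its own (wrapped only in [CLS]/[SEP]). Because BERT embeddings depend on the surrounding tokens, the two "python" vectors are not expected to be identical.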

In [11]:
def get_embedding(text, index):
    '''
    Return the contextual BERT embedding of the token at position `index` in `text`.

    `index` is a position in the token sequence produced by `tokenizer.encode`, not a
    word position: position 0 is the [CLS] token, and a word split into several
    WordPiece subwords occupies several positions.

    Returns a numpy array of shape (1, hidden_size), i.e. a single-row matrix that can
    be passed straight to sklearn's cosine_similarity.
    '''
    # Encode and add a batch dimension of size 1.
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
    # Inference only: no_grad avoids building an autograd graph we would discard anyway.
    with torch.no_grad():
        outputs = model(input_ids)
    # outputs[0] holds one hidden vector per token: (batch, seq_len, hidden_size).
    return outputs[0][:, index, :].numpy()

In [12]:
# "python" is the 4th token in both examples

python_pet = get_embedding('I own a python', 4)
python_lan = get_embedding('I write Python code', 3)

In [13]:
# Each reference word is encoded on its own, so index 1 is the first token after [CLS].
# Assumes "snake" and "programming" are single WordPiece tokens — TODO confirm.
snake, programming = (get_embedding(word, 1) for word in ('snake', 'programming'))

In [14]:
# Pet-sense "python" compared against each reference word, in order: snake, programming.
for reference in (snake, programming):
    print(cosine_similarity(python_pet, reference))

[[0.66385233]]
[[0.56651187]]


In [15]:
# Language-sense "python" compared against each reference word, in order: snake, programming.
for reference in (snake, programming):
    print(cosine_similarity(python_lan, reference))

[[0.32544878]]
[[0.31438893]]


In [16]:
# The same surface word in two contexts: a value well below 1 shows the embeddings differ.
print(cosine_similarity(X=python_pet, Y=python_lan))

[[0.45950758]]
