# BERT Step by Step: Contextual representations

Contextual representations are dynamic word embeddings that capture the meaning of a word based on its surrounding context. Unlike static embeddings, they adapt to different usages of the same word in different sentences.

In [1]:
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers import BertForPreTraining

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint name handed to every `from_pretrained` call in this cell, so the
# configuration, the weights and the vocabulary all come from the same source.
model_checkpoint = 'bert-base-uncased'

# Configuration object of the checkpoint.
config = AutoConfig.from_pretrained(model_checkpoint)
# Pre-training variant of BERT; later cells use its `model.bert` sub-module.
model = BertForPreTraining.from_pretrained(model_checkpoint)
# Tokenizer used to map between text / tokens and vocabulary ids.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [3]:
# Small experiment: similarity between *static* input embeddings.
# Each word is looked up directly in the model's word-embedding table
# (`model.bert.embeddings.word_embeddings`), so no surrounding sentence is involved.

# Cosine similarity module (also reused by the next cell). It computes the
# equivalent of:
#     embedding1 @ embedding2.T / torch.norm(embedding1) / torch.norm(embedding2)
cos = torch.nn.CosineSimilarity()

# This is pure inference, so disable autograd tracking instead of detaching
# the result afterwards.
with torch.no_grad():
    # The token is wrapped in a list so the lookup receives a batch of one id.
    token1 = tokenizer.convert_tokens_to_ids(['bank'])
    embedding1 = model.bert.embeddings.word_embeddings(torch.tensor(token1))

    token2 = tokenizer.convert_tokens_to_ids(['vault'])
    embedding2 = model.bert.embeddings.word_embeddings(torch.tensor(token2))

    # `similarity` stays a one-element tensor; `.item()` extracts the Python float.
    similarity = cos(embedding1, embedding2)

print(f"Similarity: {similarity.item():.2f}")

Similarity: 0.29


In [4]:
# Each sentence is paired with the word whose contextual vector is compared.
# When switching to one of the alternative sentences, switch the paired word
# too; otherwise the lookup below cannot find the target token.
sent1, word1 = "We deposited the check at the bank.", "bank"
# sent1, word1 = "They sat on the bank of the river.", "bank"
sent2, word2 = "All the valuables are safe in the vault.", "vault"
# sent2, word2 = "We deposited the check at the bank.", "bank"


def find_token_index(input_ids, word):
    """Return the position of `word` inside a 1-D tensor of token ids.

    Args:
        input_ids: 1-D tensor of token ids for a single tokenized sentence.
        word: token string to locate; it is converted to an id with the
            notebook's `tokenizer`.

    Returns:
        The integer position of the single matching token.

    Raises:
        ValueError: if `word` does not occur exactly once in `input_ids`
            (for example because the sentence does not contain it, contains
            it several times, or the tokenizer does not keep it as one token).
    """
    word_id = tokenizer.convert_tokens_to_ids(word)
    positions = (input_ids == word_id).nonzero().flatten()
    if positions.numel() != 1:
        raise ValueError(
            f"Expected exactly one occurrence of {word!r} (token id {word_id}) "
            f"in the tokenized sentence, found {positions.numel()}."
        )
    return positions.item()


inputs1 = tokenizer(sent1, return_tensors="pt")
inputs2 = tokenizer(sent2, return_tensors="pt")

# Names kept from the original cell: each holds the position of the *target
# word* of its sentence (word1 / word2), which is not necessarily "bank".
bank_idx1 = find_token_index(inputs1['input_ids'][0], word1)
bank_idx2 = find_token_index(inputs2['input_ids'][0], word2)

# Inference only: no gradients are needed.
with torch.no_grad():
    output1 = model.bert(**inputs1)
    output2 = model.bert(**inputs2)

# Hidden state of the target token in the first (only) sequence of each batch.
vec1 = output1.last_hidden_state[0, bank_idx1, :]
vec2 = output2.last_hidden_state[0, bank_idx2, :]

# `cos` is the CosineSimilarity module created in the previous cell; the
# vectors are unsqueezed to shape (1, hidden) to match how it was used there.
similarity = cos(vec1.unsqueeze(0), vec2.unsqueeze(0)).item()
print(f"Contextual similarity: {similarity:.2f}")

Contextual similarity: 0.52
