In this notebook, we'll explore the basics of token representations in BERT and use them to find nearest neighbors for individual tokens. You should open this notebook in Google Colab, or use smaller BERT models locally (as in previous notebooks).


In [None]:
!pip install transformers

In [None]:
from transformers import BertModel, BertTokenizer
import numpy as np

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

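BERT uses WordPiece tokenization, which breaks words that don't appear in its roughly 30K-token vocabulary into smaller subword pieces. The word "vaccinated", for instance, is tokenized as ["va", "##cci", "##nated"]; the "##" prefix marks a piece that continues the preceding piece rather than starting a new word.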
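To make the splitting concrete, here is a simplified sketch of WordPiece-style greedy longest-match-first splitting for a single word. `TOY_VOCAB` and `wordpiece` are hypothetical names invented for illustration, and the vocabulary is a tiny toy, not BERT's real one; the actual tokenizer also normalizes text and pre-splits on whitespace and punctuation before this step.

```python
# A simplified sketch of WordPiece-style greedy longest-match-first splitting.
# TOY_VOCAB is a hypothetical toy vocabulary, NOT BERT's real ~30K-token vocabulary.
TOY_VOCAB = {"va", "##cci", "##nated", "jam", "[UNK]"}


def wordpiece(word, vocab, unk="[UNK]"):
    """Split one word into the longest vocabulary pieces, scanning left to right."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # try the longest remaining substring first, then shrink it one character at a time
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # pieces after the first carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no piece matched, so the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


print(wordpiece("vaccinated", TOY_VOCAB))  # ['va', '##cci', '##nated']
print(wordpiece("jam", TOY_VOCAB))         # ['jam']
print(wordpiece("toast", TOY_VOCAB))       # ['[UNK]'] (not coverable by the toy vocabulary)
```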

In [None]:
inputs=tokenizer("New data shows 26 states have fully vaccinated more than half their residents.", return_tensors="pt")
tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

In [None]:
inputs=tokenizer("BERT is supercalifragilisticexpialidocious", return_tensors="pt")
tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

As mentioned in class, every sentence input to BERT is wrapped in two special tokens: [CLS] at the start and [SEP] at the end. BERT generates a representation for every WordPiece token, including these special [CLS] and [SEP] tokens.
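A small consequence of this wrapping is that every word piece shifts one position to the right, which is why later code looks up a term's index in the *tokenized* list rather than in the raw sentence. The helper `wrap_with_special_tokens` below is hypothetical and only mimics the single-sentence framing; in practice the tokenizer adds these tokens for you.

```python
# Hypothetical helper mimicking how a single sentence is framed for BERT.
def wrap_with_special_tokens(wordpieces):
    # prepend [CLS] and append [SEP]
    return ["[CLS]"] + wordpieces + ["[SEP]"]


toy_pieces = ["this", "jam", "is", "delicious"]
toy_wrapped = wrap_with_special_tokens(toy_pieces)
print(toy_wrapped)               # ['[CLS]', 'this', 'jam', 'is', 'delicious', '[SEP]']
print(len(toy_wrapped))          # 6
print(toy_wrapped.index("jam"))  # 2 (it was at index 1 before wrapping)
```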

To generate representations for the input tokens, simply pass the input through the BERT model:

In [None]:
inputs=tokenizer("This jam is delicious", return_tensors="pt")
outputs = model(**inputs)
tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

BERT computes a representation at each of its layers (12 in this model); passing `output_hidden_states=True` to the model makes all of them available as `outputs.hidden_states`, but here we'll explore just the outputs from the final layer. This BERT model has 768-dimensional representations, so this 6-token input ([CLS], this, jam, is, delicious, [SEP]) produces an output tensor of shape 1 x 6 x 768 (batch size x tokens x dimensions).
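The indexing used throughout the rest of the notebook can be previewed without loading BERT. This sketch fills NumPy arrays with random numbers as stand-ins for the model's outputs (all `toy_` names are hypothetical); only the shapes and indexing are meant to match. For `bert-base`, `outputs.hidden_states` holds 13 tensors (the embedding output plus one per layer), and its last entry corresponds to `last_hidden_state`.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_tokens = ["[CLS]", "this", "jam", "is", "delicious", "[SEP]"]

# 13 random stand-ins for outputs.hidden_states: embedding output + 12 layer outputs
toy_hidden_states = tuple(rng.normal(size=(1, len(toy_tokens), 768)) for _ in range(13))
toy_last_hidden = toy_hidden_states[-1]  # plays the role of outputs.last_hidden_state

print(toy_last_hidden.shape)             # (1, 6, 768): batch x tokens x dimensions
toy_sentence = toy_last_hidden[0]        # drop the batch dimension -> (6, 768)
toy_jam_vec = toy_sentence[toy_tokens.index("jam")]  # one 768-dim vector for "jam"
print(toy_jam_vec.shape)                 # (768,)
```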

In [None]:
last_hidden_states = outputs.last_hidden_state

In [None]:
print(outputs.last_hidden_state.shape)

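What can we do with just these representations? Earlier we used word2vec-style static embeddings to find nearest neighbors for word *types*; with BERT's contextual representations, we can do the same for individual word *tokens*, since the same word receives a different vector in each context.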

In [None]:
def cosine_similarity(a, b):
    return np.dot(a, b)/(np.linalg.norm(a)*np.linalg.norm(b))
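As a quick sanity check on the function above, cosine similarity depends only on direction, not length: it is 1 for vectors pointing the same way, 0 for perpendicular ones, and -1 for opposite ones. The definition is repeated here so the cell runs on its own; the small 2-D vectors are made up for illustration.

```python
import numpy as np


def cosine_similarity(a, b):
    # dot product divided by the product of the vectors' L2 norms
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


print(round(cosine_similarity(np.array([1.0, 2.0]), np.array([2.0, 4.0])), 4))    # same direction, different length
print(round(cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 3.0])), 4))    # perpendicular
print(round(cosine_similarity(np.array([1.0, 1.0]), np.array([-1.0, -1.0])), 4))  # opposite
print(round(cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 1.0])), 4))    # 45 degrees apart, about 0.7071
```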

In [None]:
query="I ate some jam with toast"

In [None]:
comp_sents=["She got me out of a real jam", "This jam is made of strawberries", "I sat in a traffic jam for 2 hours", "The Grateful Dead used to jam for like two days straight.", "My grandma makes the best jam.", "I had to jam on the brakes to avoid hitting him."]

In [None]:
import torch

def get_bert_for_token(string, term):

    # tokenize (the tokenizer adds [CLS] and [SEP] automatically)
    inputs = tokenizer(string, return_tensors="pt")

    # convert input ids back to WordPiece tokens
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # find the first location of the query term among those tokens (so we know which BERT rep to use).
    # This raises a ValueError if the term is absent or is split into several WordPieces.
    term_idx = tokens.index(term)

    # run BERT without tracking gradients, since we only need the representations
    with torch.no_grad():
        outputs = model(**inputs)

    # return the BERT rep for that token index.
    # The output is a PyTorch tensor, so convert it to a NumPy array to work with NumPy functions
    return outputs.last_hidden_state[0][term_idx].detach().numpy()
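Because `tokens.index(term)` only finds single WordPieces, a word like "vaccinated" would raise a `ValueError`. One common workaround, sketched here under assumptions, is to locate the run of pieces that make up the word and average their vectors. The helpers `find_subword_span` and `mean_of_span` are hypothetical, and the tokens and 4-dimensional "representations" are toy data; in the notebook, the pieces could come from `tokenizer.tokenize(term)` and the rows from `outputs.last_hidden_state[0]`. Averaging is only one choice (using the first piece is another).

```python
import numpy as np


def find_subword_span(tokens, pieces):
    """Return (start, end) of the first run of `tokens` equal to `pieces`, or None."""
    n = len(pieces)
    for start in range(len(tokens) - n + 1):
        if tokens[start:start + n] == pieces:
            return start, start + n
    return None


def mean_of_span(hidden, span):
    # hidden: (num_tokens, dim) array; average the rows covering the word's pieces
    start, end = span
    return hidden[start:end].mean(axis=0)


vax_tokens = ["[CLS]", "fully", "va", "##cci", "##nated", "[SEP]"]
# toy 4-dim "representations": row i is [4i, 4i+1, 4i+2, 4i+3]
vax_hidden = np.arange(len(vax_tokens) * 4, dtype=float).reshape(len(vax_tokens), 4)

vax_span = find_subword_span(vax_tokens, ["va", "##cci", "##nated"])
print(vax_span)                           # (2, 5)
print(mean_of_span(vax_hidden, vax_span))  # [12. 13. 14. 15.]
```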



In [None]:
query_rep=get_bert_for_token(query, "jam")
print(query_rep.shape)

In [None]:
vals = []
for sent in comp_sents:
    comp_rep = get_bert_for_token(sent, "jam")
    cos_sim = cosine_similarity(query_rep, comp_rep)
    vals.append((cos_sim, query, sent))

# print from most to least similar
for c, q, s in sorted(vals, reverse=True):
    print("%.3f\t%s\t%s" % (c, q, s))
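The loop above computes one similarity at a time. With many candidates, the same ranking can be done in one step by normalizing every vector to unit length and taking a matrix-vector product. This is a sketch with made-up 2-D vectors; `rank_by_cosine` is a hypothetical helper, and in the notebook `toy_cands` would be the stacked BERT vectors (for example via `np.stack`).

```python
import numpy as np


def rank_by_cosine(query_vec, cand_matrix):
    """Return candidate row indices from most to least similar, plus the scores."""
    q = query_vec / np.linalg.norm(query_vec)
    c = cand_matrix / np.linalg.norm(cand_matrix, axis=1, keepdims=True)
    scores = c @ q               # one cosine similarity per candidate row
    order = np.argsort(-scores)  # indices in descending order of similarity
    return order, scores


toy_query = np.array([1.0, 0.0])
toy_cands = np.array([[0.0, 1.0],    # perpendicular: 0.0
                      [2.0, 0.1],    # nearly the same direction
                      [-1.0, 0.0]])  # opposite: -1.0
toy_order, toy_scores = rank_by_cosine(toy_query, toy_cands)
print(toy_order)  # [1 0 2]
```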

**Q**: Brainstorm the variety of things you can do with token-level word representations like this, and we'll discuss them at the end of class. As one example, adapt the code above to find, in each of the following sentences, whichever token is most similar to *jam* in "I ate some jam with toast". Are you able to find token-level synonyms this way?

In [None]:
comp_sents=["My grandma makes the best jelly.", "Jelly is made of strawberries"]
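One possible shape for the adaptation is to compare the query vector against *every* token representation in a sentence and keep the best match, skipping the special tokens. This is a hedged sketch on toy 2-D vectors; `most_similar_token` is a hypothetical helper, and in the notebook `tokens` would come from `tokenizer.convert_ids_to_tokens(...)` and `reps` from `outputs.last_hidden_state[0].numpy()`.

```python
import numpy as np


def most_similar_token(query_vec, tokens, reps, skip=("[CLS]", "[SEP]")):
    """Return (token, score) for the row of `reps` closest to `query_vec` by cosine."""
    best_tok, best_score = None, -np.inf
    for tok, vec in zip(tokens, reps):
        if tok in skip:
            continue  # special tokens are not words, so leave them out
        score = np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec))
        if score > best_score:
            best_tok, best_score = tok, score
    return best_tok, best_score


toy_q = np.array([1.0, 0.0])
toy_toks = ["[CLS]", "best", "jelly", "[SEP]"]
toy_reps = np.array([[1.0, 1.0], [0.0, 1.0], [1.0, 0.1], [1.0, 0.0]])
best = most_similar_token(toy_q, toy_toks, toy_reps)
print(best[0])  # jelly ([SEP] would score higher but is skipped)
```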