Examining embeddings. 

This is a reimplementation of the example [from this notebook](https://colab.research.google.com/drive/1yFphU6PW9Uo6lmDly_ud9a6c4RCYlwdX#scrollTo=1RfUN_KolV-f), which works with BERT2.
[Original post is here.](http://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/) 

We'll use the BERT model to extract embeddings for a sentence for our fake social network, Flutter, and compare them to other embeddings for Flutter. 

In [1]:
# transformers for PyTorch
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.3 transformers-4.27.4


In [2]:
import torch
from transformers import BertTokenizer, BertModel

# Configure root logging at INFO level.
import logging

logging.basicConfig(level=logging.INFO)

import matplotlib.pyplot as plt

%matplotlib inline

# Load the pre-trained tokenizer (vocabulary) for the "bert-base-uncased"
# checkpoint; the vocabulary and config files are downloaded on first use.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [15]:
# Sentence to embed, wrapped in the special tokens: [CLS] in front, [SEP] at the end.
text = "Hold fast to dreams, for if dreams die, life is a broken-winged bird that cannot fly."
marked_text = f"[CLS] {text} [SEP]"

# Break the marked sentence into tokens with the BERT tokenizer.
tokenized_text = tokenizer.tokenize(marked_text)

# Show the tokens, then how many there are (one item per line).
print(tokenized_text, len(tokenized_text), sep="\n")

['[CLS]', 'hold', 'fast', 'to', 'dreams', ',', 'for', 'if', 'dreams', 'die', ',', 'life', 'is', 'a', 'broken', '-', 'winged', 'bird', 'that', 'cannot', 'fly', '.', '[SEP]']
23


In [4]:
# Show indices of the text
text = "Hold fast to dreams, for if dreams die, life is a broken-winged bird that cannot fly."

# Add the special tokens.
marked_text = "[CLS] " + text + " [SEP]"

# Split the sentence into tokens.
tokenized_text = tokenizer.tokenize(marked_text)

# Map the token strings to their vocabulary indeces.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

# Display the words with their indeces.
for tup in zip(tokenized_text, indexed_tokens):
    print("{:<12} {:>6,}".format(tup[0], tup[1]))

[CLS]           101
hold          2,907
fast          3,435
to            2,000
dreams        5,544
,             1,010
for           2,005
if            2,065
dreams        5,544
die           3,280
,             1,010
life          2,166
is            2,003
a             1,037
broken        3,714
-             1,011
winged       14,462
bird          4,743
that          2,008
cannot        3,685
fly           4,875
.             1,012
[SEP]           102


In [5]:
# Build one segment id per token, assigning every token to sentence "1".
# NOTE(review): for a single-sentence input BERT tooling usually uses segment 0;
# confirm that 1 is the intended value wherever this list is consumed.
segments_ids = [1 for _ in tokenized_text]

print(segments_ids)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [6]:
# Turn each id list into a PyTorch tensor and add a leading batch dimension of
# size 1, giving tensors of shape [1 x number_of_tokens].
tokens_tensor = torch.tensor(indexed_tokens).unsqueeze(0)
segments_tensors = torch.tensor(segments_ids).unsqueeze(0)

In [7]:
# Show the plain list of vocabulary ids that was wrapped into `tokens_tensor`.
print(indexed_tokens)

[101, 2907, 3435, 2000, 5544, 1010, 2005, 2065, 5544, 3280, 1010, 2166, 2003, 1037, 3714, 1011, 14462, 4743, 2008, 3685, 4875, 1012, 102]


In [8]:
# Load pre-trained model (weights)
model = BertModel.from_pretrained(
    "bert-base-uncased",
    output_hidden_states=True,  # Whether the model returns all hidden-states.
)

# Put the model in "evaluation" mode, meaning feed-forward operation.
model.eval()

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [9]:
# Run the text through BERT, and collect all of the hidden states produced
# from all 12 layers.
with torch.no_grad():
    outputs = model(tokens_tensor, segments_tensors)

    # Evaluating the model will return a different number of objects based on
    # how it's  configured in the `from_pretrained` call earlier. In this case,
    # becase we set `output_hidden_states = True`, the third item will be the
    # hidden states from all layers. See the documentation for more details:
    # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
    hidden_states = outputs[2]

In [10]:
# Combine the per-layer tensors into a single tensor. `stack` (rather than
# `cat`) adds a new leading dimension that indexes the layers.
token_embeddings = torch.stack(hidden_states)

token_embeddings.shape

torch.Size([13, 1, 23, 768])

In [11]:
# Drop dimension 1 (the batch dimension, which has size 1 here).
token_embeddings = token_embeddings.squeeze(1)

token_embeddings.shape

torch.Size([13, 23, 768])

In [12]:
# Exchange dimensions 0 and 1 so that tokens come first and layers second.
# Note: re-running this cell on its own output swaps them back again.
token_embeddings = torch.permute(token_embeddings, (1, 0, 2))

token_embeddings.shape

torch.Size([23, 13, 768])

In [13]:
# Build one vector per token by summing that token's last four layers.
#
# `token_embeddings` is a [23 x 13 x 768] tensor (tokens x layers x features),
# so iterating over it yields one [13 x 768] slice per token; summing the last
# four rows of a slice gives a single 768-value vector for that token.
embeddings = [layers[-4:].sum(dim=0) for layers in token_embeddings]

# Report the result as (number of tokens) x (vector length).
print("Shape is: %d x %d" % (len(embeddings), len(embeddings[0])))

Shape is: 23 x 768


In [14]:
# Display the full list of per-token vectors.
# NOTE(review): this renders every value of all 23 x 768 entries; a slice such
# as `embeddings[0][:5]` would keep the saved output short.
embeddings

[tensor([-3.0241e-01, -1.5066e+00, -9.6222e-01,  1.7986e-01, -2.7384e+00,
         -1.6749e-01,  7.4106e-01,  1.9655e+00,  4.9202e-01, -2.0871e+00,
         -5.8469e-01,  1.5016e+00,  8.2666e-01,  8.7033e-01,  8.5101e-01,
          5.5919e-01, -1.4336e+00,  2.4679e+00,  1.3920e+00, -3.9291e-01,
         -1.2054e+00,  1.4637e+00,  1.9681e+00,  3.6572e-01,  3.1503e+00,
         -4.4693e-01, -1.1637e+00,  2.8804e-01, -8.3749e-01,  1.5026e+00,
         -2.1318e+00,  1.9633e+00, -4.5096e-01, -1.8215e+00,  3.2744e+00,
          5.2591e-01,  1.0686e+00,  3.7893e-01, -1.0792e-01,  5.1342e-01,
         -1.0443e+00,  1.7513e+00,  1.3895e-01, -6.6757e-01, -4.8434e-01,
         -2.1621e+00, -1.5593e+01,  1.5249e+00,  1.6911e+00, -1.2916e+00,
          1.2339e+00, -3.6064e-01, -9.6036e-01,  1.3226e+00,  1.6427e+00,
          1.4588e+00, -1.8806e+00,  6.3620e-01,  1.1713e+00,  1.1050e+00,
         -6.0766e-01,  1.1026e+00, -1.0001e+00,  1.8021e+00, -1.3309e+00,
          9.0621e-01,  2.2228e+00,  5.