# Sentence BERT

In [1]:
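# Reference pooling implementation (sentence-transformers `Pooling` module):
# https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Pooling.py
# Sentence-BERT paper (Reimers & Gurevych, 2019):
# https://arxiv.org/pdf/1908.10084.pdf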

## Token embeddings from ALBERT (`albert-base-v2`, a BERT variant)

In [2]:
from transformers import AlbertTokenizer, AlbertModel
import torch
# Load the pretrained ALBERT tokenizer/encoder and embed a single example sentence.
MODEL_NAME = 'albert-base-v2'
SENTENCE = "Hello, my dog is cute"

tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME)
model = AlbertModel.from_pretrained(MODEL_NAME, return_dict=True)

inputs = tokenizer(SENTENCE, return_tensors="pt")
# Inference only: no_grad skips building the autograd graph, so the returned
# tensors carry no grad_fn and intermediate activations are not retained.
with torch.no_grad():
    outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state

In [3]:
inputs  # tokenizer output: input_ids, token_type_ids and attention_mask tensors

{'input_ids': tensor([[    2, 10975,    15,    51,  1952,    25, 10901,     3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [4]:
outputs.keys()  # model output fields; last_hidden_state holds the per-token vectors used below

odict_keys(['last_hidden_state', 'pooler_output'])

In [5]:
# Per-token vectors from the encoder's final layer: (1, 8, 768) for this sentence.
token_embeddings = last_hidden_states
token_embeddings.size()

torch.Size([1, 8, 768])

In [6]:
# The tokenizer prepends a [CLS] token, so sequence position 0 is its embedding.
cls_token_embeddings = token_embeddings[:, 0]  # drops the sequence axis
cls_token_embeddings.shape

torch.Size([1, 768])

In [7]:
# Bundle the tensors under the keys the pooling cells below read back out.
features = {
    'token_embeddings': token_embeddings,
    'cls_token_embeddings': cls_token_embeddings,
    'attention_mask': inputs['attention_mask'],
}
features

{'token_embeddings': tensor([[[ 1.3997,  1.5700,  0.3336,  ..., -0.0686,  0.2804,  0.8287],
          [ 0.3306,  0.3647,  0.7145,  ..., -0.5266,  1.2512, -0.7154],
          [ 1.1538,  0.6781, -1.6579,  ...,  0.6821,  0.3878,  0.4889],
          ...,
          [ 1.5001, -0.4411,  1.2422,  ...,  1.3102,  0.0211, -1.0564],
          [ 0.4044, -0.0901,  1.0914,  ...,  0.4799,  0.6582, -1.0785],
          [ 0.0455,  0.1439, -0.0616,  ..., -0.0906,  0.1141,  0.2033]]],
        grad_fn=<NativeLayerNormBackward>),
 'cls_token_embeddings': tensor([[ 1.3997e+00,  1.5700e+00,  3.3358e-01,  7.2619e-01, -1.9359e+00,
          -8.3133e-01,  4.7694e-02, -8.3818e-01, -1.0198e-01,  1.0294e+00,
           8.7844e-01,  1.2128e+00, -1.8456e-01,  4.1507e-01,  1.1024e+00,
           7.9191e-01,  2.9623e-01,  4.5332e-01, -8.9668e-02,  7.9707e-01,
           1.3770e+00, -3.2824e+00,  1.0356e+00, -5.1440e-01,  3.8124e-01,
          -3.1587e-01,  2.6293e-01,  6.1062e-02,  1.0659e+00, -1.5400e+00,
           2.

## Pooling: mean of the token embeddings

In [8]:
# Pull the individual tensors back out of the features dict.
token_embeddings, cls_token, attention_mask = (
    features['token_embeddings'],
    features['cls_token_embeddings'],
    features['attention_mask'],
)

In [9]:
# Add a trailing axis to the (batch, seq_len) mask and broadcast it across the hidden
# dimension so it can multiply token_embeddings element-wise; cast to float for that product.
input_mask_expanded = attention_mask[..., None].expand_as(token_embeddings).float()

In [10]:
input_mask_expanded  # copies attention_mask along the hidden axis; all 1s here because the single sentence has no padding

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [11]:
input_mask_expanded.size()  # identical to token_embeddings' shape, so the two multiply element-wise

torch.Size([1, 8, 768])

In [12]:
# Zero the masked positions, then add the remaining token vectors along the sequence axis (dim 1).
sum_embeddings = (token_embeddings * input_mask_expanded).sum(dim=1)
sum_embeddings.size()

torch.Size([1, 768])

In [13]:
# Number of unmasked tokens per sentence, repeated across the hidden dimension
# (every entry is 8 here: all 8 tokens are unmasked).
sum_mask = torch.sum(input_mask_expanded, dim=1)
sum_mask

tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
         8., 8., 8., 8., 8.,

In [14]:
# Mean pooling: masked sum divided by the number of unmasked tokens.
# Clamp the denominator (as the reference Pooling.py linked at the top does) so a row
# whose mask is entirely zero yields a zero vector instead of NaN from 0/0.
# With at least one unmasked token the result is unchanged.
output = sum_embeddings / torch.clamp(sum_mask, min=1e-9)
output

tensor([[ 5.3711e-01,  3.8164e-01,  1.6220e-01, -2.3202e-01, -3.8620e-01,
         -3.1628e-01,  2.1457e-01,  5.8946e-02,  2.6022e-01,  7.2760e-01,
          4.8659e-01,  8.4328e-01, -1.8562e-02,  2.3552e-01,  1.9561e-01,
          1.0427e-01,  3.1211e-01, -1.6830e-01, -3.7542e-01,  6.5644e-02,
          7.2831e-01, -1.0729e+00,  1.0129e+00,  8.0353e-03, -1.6884e-01,
          1.2850e-01,  1.6900e-01,  5.4143e-02,  3.1661e-01, -1.6941e-01,
         -2.7125e-01,  3.6460e-01,  8.7100e-03,  3.6043e-01, -3.8949e+00,
         -3.0833e-01, -1.6282e-01,  4.5241e-01,  1.1155e-01, -1.9567e-01,
          1.2700e-01, -7.6327e-01,  1.8334e-01, -2.7725e-01,  9.4441e-02,
          8.4011e-02,  1.2114e-01, -7.8750e-01,  7.7898e-01,  7.2907e-01,
          2.2313e-01,  3.4782e-01, -3.9296e-01, -2.9293e-02, -1.1450e-02,
          2.5495e-01, -8.6097e-01, -9.5979e-01, -6.6943e-01,  3.7869e-01,
         -7.6500e-01, -1.6947e-01, -1.1520e-01, -1.4095e-01,  2.1193e-01,
          3.5699e-01,  2.0410e-01, -3.

In [15]:
output.size()  # the pooled sentence embedding: one 768-d vector per sentence, (1, 768) here

torch.Size([1, 768])