In [1]:
# Import libraries
import numpy
import torch
from transformers import BertTokenizer, BertModel

In [2]:
# Fetch the pre-trained 'bert-base-uncased' tokenizer and encoder.
# The encoder is asked to return the hidden states of every layer, and is
# switched to evaluation mode right away (eval() returns the module itself).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained(
    'bert-base-uncased',
    output_hidden_states=True,
).eval()
# Last expression of the cell: displays the module structure.
model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [6]:
# Sentence to embed, then the same sentence wrapped in the special tokens
# required by the model: [CLS] in front, [SEP] at the end.
input_text = "This is an example of a BERT model used for tokenization and embedding."
marked_text = ' '.join(('[CLS]', input_text, '[SEP]'))

In [7]:
# Split the marked sentence into sub-word tokens, then map every token to its
# vocabulary id.
tokenized_text = tokenizer.tokenize(marked_text)
token_index = tokenizer.convert_tokens_to_ids(tokenized_text)
# The sentence is printed verbatim: it already carries its own final
# punctuation, so appending another '.' rendered it as "embedding..".
print(f'Sentence: {input_text}\nTokens: {tokenized_text}.\nToken IDs: {token_index}')

Sentence: This is an example of a BERT model used for tokenization and embedding..
Tokens: ['[CLS]', 'this', 'is', 'an', 'example', 'of', 'a', 'bert', 'model', 'used', 'for', 'token', '##ization', 'and', 'em', '##bed', '##ding', '.', '[SEP]'].
Token IDs: [101, 2023, 2003, 2019, 2742, 1997, 1037, 14324, 2944, 2109, 2005, 19204, 3989, 1998, 7861, 8270, 4667, 1012, 102]


In [8]:
# Build the tensors handed to the model, plus the projection layer applied to
# its output.

# Token ids as a batch of one sequence: shape (1, number_of_tokens).
token_index_tensor = torch.tensor([token_index])

# All-ones companion tensor, one entry per token. It is created as int64
# instead of numpy's default float64: segment ids are embedding indices and
# must be integers, and an integer tensor of ones is equally valid as an
# attention mask.
token_segment = numpy.ones((1, len(tokenized_text)), dtype=numpy.int64)
token_segment_tensor = torch.tensor(token_segment)

# Seed before creating the layer so that its random initial weights (and hence
# the reduced embeddings printed by the next cell) are identical on every
# Restart & Run All.
torch.manual_seed(0)
# Linear projection from BERT's 768-dimensional hidden size down to 300
# dimensions. It is randomly initialised and not trained in the cells shown.
reduction_layer = torch.nn.Linear(768, 300)

In [9]:
# Run BERT on the token ids and collect the per-layer hidden states.
# Because the model was loaded with output_hidden_states=True, the hidden
# states are a tuple holding the embedding-layer output followed by one tensor
# per encoder layer (13 tensors for the 12-layer model). Index 0 is therefore
# the embedding-layer output, taken before any encoder layer -- it is the
# first entry, not the last one.
with torch.no_grad():
    # Keyword arguments on purpose: the second positional parameter of
    # BertModel.forward is `attention_mask`, not `token_type_ids`, so the
    # all-ones tensor was already being used as the attention mask (attend to
    # every token). Naming it keeps the results unchanged and makes the
    # binding explicit; segment ids are left at the model's default.
    outputs = model(
        input_ids=token_index_tensor,
        attention_mask=token_segment_tensor,
    )
    # Third output of the model: the tuple of hidden states.
    hidden_states = outputs[2]
    embeddings = hidden_states[0]
    # Project every token vector from 768 down to 300 dimensions.
    reduced_embeddings = reduction_layer(embeddings)
    print(reduced_embeddings)

tensor([[[-0.4222,  0.2527, -0.0535,  ..., -0.0994,  0.2135,  0.5342],
         [-0.1658,  0.1376, -0.4312,  ..., -0.2158, -0.1292,  0.0435],
         [ 0.1490, -0.0029, -0.3371,  ..., -0.2064,  0.0136,  0.2738],
         ...,
         [-0.3511, -0.0590, -0.0117,  ..., -0.2653,  0.3050, -0.4061],
         [ 0.4254,  0.5965, -0.3524,  ..., -0.0609,  0.1970,  0.3926],
         [-0.2399,  0.3370, -0.2370,  ..., -0.1239,  0.0091,  0.4286]]])


torch.Size([1, 19, 768])