In [3]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertEmbeddings

In [4]:
class BertEmbeddingsHT(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
    
    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Apply non-linear transformation (e.g., sigmoid, tanh) to embeddings
        # to make them fat-tailed
        embeddings = torch.sigmoid(words_embeddings) + torch.tanh(position_embeddings) + torch.relu(token_type_embeddings)

        # embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [5]:
class BertModelHT(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddingsHT(config)
    
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        return super().forward(input_ids, token_type_ids, attention_mask, output_all_encoded_layers)

In [6]:
model = BertModelHT.from_pretrained('bert-base-multilingual-cased')

100%|██████████| 662804195/662804195 [01:41<00:00, 6535101.58B/s]


In [7]:
# Load pre-trained model tokenizer (vocabulary-multilingual)
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

100%|██████████| 995526/995526 [00:00<00:00, 2252454.15B/s]


In [10]:
text = "I am so happy, today is my birthday! I can't wait to dance with my friends!"
marked_text = "[CLS] " + text + " [SEP]"

# Tokenize our sentence with the BERT tokenizer.
tokenized_text = tokenizer.tokenize(marked_text)
segments_ids = [1] * len(tokenized_text)

# Map the token strings to their vocabulary indeces.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

# Print out the tokens.
print(tokenized_text)

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

['[CLS]', 'i', 'am', 'so', 'happy', ',', 'today', 'is', 'my', 'birthday', '!', 'i', 'can', "'", 't', 'wait', 'to', 'dance', 'with', 'my', 'friends', '!', '[SEP]']


In [9]:
model.eval()

BertModelHT(
  (embeddings): BertEmbeddingsHT(
    (word_embeddings): Embedding(119547, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linea

In [11]:
# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)

# Concatenate the tensors for all layers. We use `stack` here to
# create a new dimension in the tensor.
token_embeddings = torch.stack(encoded_layers, dim=0)

# Remove dimension 1, the "batches".
token_embeddings = torch.squeeze(token_embeddings, dim=1)

# Swap dimensions 0 and 1.
token_embeddings = token_embeddings.permute(1,0,2)

token_embeddings.size()

torch.Size([23, 12, 768])

In [12]:
# `encoded_layers` has shape [12 x 1 x 23 x 768]

# `token_vecs` is a tensor with shape [23 x 768]
token_vecs = encoded_layers[11][0]

# Calculate the average of all 23 token vectors.
sentence_embedding = torch.mean(token_vecs, dim=0)

In [13]:
print("Our final sentence embedding vector of shape:", sentence_embedding.size())
sentence_embedding

Our final sentence embedding vector of shape: torch.Size([768])


tensor([ 3.5379e-01, -5.2726e-01,  3.9038e-01,  4.6449e-02,  1.0117e-01,
        -2.8333e-01, -4.6904e-01,  1.0536e-01,  4.0712e-01,  4.4933e-01,
        -3.7807e-01,  2.4165e-01,  5.9514e-01,  3.8656e-01,  3.4816e-01,
        -5.9424e-01,  5.4353e-01, -1.6189e-01,  2.0134e-01,  4.1504e-01,
        -9.2185e-02, -3.7987e-02,  7.6645e-01,  6.4855e-01, -2.4880e-01,
         4.3472e-01, -4.5661e-01, -9.2187e-03,  1.9247e-01, -4.4045e-01,
        -2.9806e-01,  1.5573e-01, -3.5994e-02,  6.9275e-01, -1.5687e-01,
         1.4037e-01, -4.2200e-02,  4.2213e-01, -2.5957e-01, -5.3310e-01,
         1.6371e-01,  1.5584e-01,  3.9058e-01, -5.5510e-02, -1.5361e-01,
        -5.2311e-01,  6.8016e-01,  5.9693e-01, -7.5328e-03,  1.8516e-01,
         2.1970e-01, -1.6238e-01,  5.2092e-01,  2.4030e-01,  2.3908e-01,
         6.0401e-02,  2.5389e-01, -3.3801e-01,  1.2806e-01,  1.1195e-01,
         1.5407e-02,  3.8178e-01,  3.6687e-01, -2.1961e-01, -2.0979e-01,
         1.7412e-01,  5.7783e-02,  2.2307e-02,  2.0