# BERT (Bidirectional Encoder Representations from Transformers)
Paper: https://arxiv.org/pdf/1810.04805.pdf

Embedding extraction adapted from the BERT word-embeddings tutorial: https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/

In [None]:
import torch
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

In [None]:
def create_word_to_vec_map(
    vocab: list,
    model_name: str = "bert-base-uncased",
    layer: int = -2,
) -> dict[str, list[float]]:
    """
    Create a dictionary mapping from words of a given vocab to the respective word's embedding.

    Each word is wrapped once in the special tokens ("[CLS] word [SEP]") and run
    through the pre-trained model. Its embedding is the mean, over all tokens of
    the input, of the selected hidden layer. The tokens include [CLS], [SEP] and
    any WordPiece sub-tokens.

    Arguments:
    vocab -- list of words (strings) to embed
    model_name -- name or path of the pre-trained BERT checkpoint and tokenizer
                  (default "bert-base-uncased")
    layer -- index into the tuple of hidden states returned by the model; the
             default -2 selects the second-to-last layer

    Returns:
    bert -- dictionary mapping each word to its embedding as a list of floats
            (length = the model's hidden size, e.g. 768 for bert-base)
    """
    bert = {}

    # Load pre-trained model weights; output_hidden_states=True makes the model
    # return the hidden states of every layer, not only the last one.
    model = BertModel.from_pretrained(model_name, output_hidden_states=True)

    # Set evaluation mode (feed-forward operation without dropout regularization)
    model.eval()

    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained(model_name)

    for word in tqdm(vocab):
        # Add the special tags exactly once: [CLS] at the beginning, [SEP] at the end.
        marked_text = "[CLS] " + word + " [SEP]"

        # Tokenize with the BERT tokenizer and map token strings to vocab indices
        tokenized_text = tokenizer.tokenize(marked_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

        # Convert inputs to PyTorch tensors of shape [1, tokens]
        tokens_tensor = torch.tensor([indexed_tokens])
        # A single sentence belongs entirely to segment 0
        segments_tensor = torch.zeros_like(tokens_tensor)
        # No padding is used, so every token is attended to
        attention_mask = torch.ones_like(tokens_tensor)

        # Forward pass without building the autograd graph (not needed, saves memory).
        # Arguments are passed by keyword because the second positional parameter of
        # BertModel.forward is attention_mask, not token_type_ids.
        with torch.no_grad():
            outputs = model(
                tokens_tensor,
                attention_mask=attention_mask,
                token_type_ids=segments_tensor,
            )

        # Third item of outputs: tuple of hidden states, each [1, tokens, hidden_size]
        hidden_states = outputs[2]

        # Average the selected layer over its tokens -> [hidden_size]
        token_vecs = hidden_states[layer][0]
        embedding = torch.mean(token_vecs, dim=0)

        bert[word] = embedding.tolist()

    return bert