## About

One of the things that makes BERT so flexible is its ability to handle out-of-vocabulary (OOV) words. When the model comes across a word that isn't in its vocabulary, it breaks that word into "subwords" that _are_ in the vocabulary. These subwords become the tokenized representation of the word.

But how many ways can a word be chopped up? What if we prevented BERT from ever using whole words? In what ways might a word's subwords differ or relate to one another in the embedding space? What, in short, would the embedding space of subwords look like?

In [None]:
import re
import pandas as pd
import numpy as np
from itertools import combinations, chain

import torch
from transformers import AutoTokenizer, AutoModel

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

Load and initialize the tokenizer and model. Get all subwords from the tokenizer vocabulary.

In [None]:
# Name the checkpoint once so the tokenizer and the model are always loaded
# from the same place and cannot drift apart.
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [None]:
# Continuation pieces carry a leading "##" marker. Select on that prefix
# rather than on "contains a '#'", so tokens that merely contain a hash
# (e.g. a bare "#") are not mistaken for subwords.
sub_vocab = [
    token for token in tokenizer.get_vocab()
    if token.startswith("##") and len(token) > 2
]
# Map each piece, with its "##" marker removed, back to the vocabulary token.
hash_dict = {subword[2:]: subword for subword in sub_vocab}

Set up all the functions for finding and encoding every possible subword combination of a given word.

In [None]:
def partition(word):
    """Return every way of cutting ``word`` into contiguous, non-empty pieces.

    Partitions are ordered by the number of cuts (fewest first) and, within
    the same number of cuts, by cut position. A word of length ``n`` yields
    ``2 ** (n - 1)`` partitions; the empty string yields none.

    Args:
        word: The string to split.

    Returns:
        A list of partitions, each a list of substrings that concatenate
        back to ``word``.
    """
    length = len(word)
    cut_points = range(1, length)
    result = []
    for n_cuts in range(length):
        for cuts in combinations(cut_points, n_cuts):
            # Consecutive boundaries delimit one piece each.
            bounds = (0, *cuts, length)
            result.append([word[lo:hi] for lo, hi in zip(bounds, bounds[1:])])
    return result

def get_subwords(word, hash_dict):
    """Collect the hashed vocabulary tokens whose stripped form occurs in ``word``.

    Args:
        word: The word to search inside.
        hash_dict: Mapping from a subword without its hash marker to the
            hashed vocabulary token.

    Returns:
        The hashed tokens, in ``hash_dict`` order, whose stripped key is a
        substring of ``word`` and which consist of one to three hashes
        followed by at least one word character.
    """
    # Raw string: ``\w`` inside a normal literal is an invalid escape sequence.
    hashed_word = re.compile(r'#{1,3}\w+')
    return [
        hash_token
        for subword, hash_token in hash_dict.items()
        if subword in word and hashed_word.search(hash_token) is not None
    ]

def check_partitions(partitions, subwords, hash_dict):
    """Keep only the partitions whose every piece is a known subword.

    Args:
        partitions: Iterable of partitions, each an iterable of substrings.
        subwords: Unused; retained so existing callers keep working.
        hash_dict: Mapping from a substring to its hashed vocabulary token.

    Returns:
        A list with one entry per valid partition, each a list of the hashed
        tokens for that partition's pieces. Partitions containing a piece
        that is missing from ``hash_dict`` are dropped.
    """
    valid_tokens = []
    for pieces in partitions:
        try:
            # Build the list eagerly: a lazy generator would defer the lookup
            # (and its KeyError) until long after this ``try`` has exited.
            tokens = [hash_dict[piece] for piece in pieces]
        except KeyError:
            continue
        valid_tokens.append(tokens)
    return valid_tokens

def pretty_print(valid_tokens):
    """Label each token sequence with a readable, dash-separated key.

    Args:
        valid_tokens: Iterable of token sequences.

    Returns:
        A dict mapping the tokens joined by ``-`` with every ``#`` removed
        to the token sequence itself.
    """
    labelled = {}
    for pieces in valid_tokens:
        label = '-'.join(re.sub('#', '', piece) for piece in pieces)
        labelled[label] = pieces
    return labelled

def append_special_tokens(valid_tokens):
    """Wrap every token list in ``[CLS]`` ... ``[SEP]`` markers.

    Args:
        valid_tokens: Dict mapping a label to a list of tokens.

    Returns:
        A new dict with the same keys whose values are new lists beginning
        with ``'[CLS]'`` and ending with ``'[SEP]'``.
    """
    wrapped = {}
    for label, tokens in valid_tokens.items():
        wrapped[label] = ['[CLS]'] + tokens + ['[SEP]']
    return wrapped

def encode(tokenized, pad_len=15):
    """Turn each token list into padded model inputs, in place.

    Every value of ``tokenized`` is replaced by a dict holding the original
    ``tokens`` plus ``input_ids``, ``token_type_ids`` and ``attention_mask``
    lists that all share one length.

    Args:
        tokenized: Dict mapping a label to a list of tokens. Mutated in place.
        pad_len: Minimum sequence length. When the longest token list is
            longer than this, all rows are padded to that longest length
            instead, so that rows stay uniform and nothing is cut off.

    Returns:
        The same ``tokenized`` dict, with its values replaced.
    """
    if not tokenized:
        return tokenized

    # Rows are later stacked into one tensor, so they must all be equally long.
    target_len = max([pad_len, *(len(tokens) for tokens in tokenized.values())])

    # Snapshot the items because the dict's values are rebound inside the loop.
    for subword, tokens in list(tokenized.items()):
        token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]
        # Id 0 is used as the filler, presumably the vocabulary's [PAD] id;
        # verify if a different checkpoint is loaded.
        to_pad = [0] * (target_len - len(tokens))
        tokenized[subword] = {
            'tokens': tokens,
            'input_ids': token_ids + to_pad,
            'token_type_ids': [0] * target_len,
            'attention_mask': [1] * len(tokens) + to_pad,
        }
    return tokenized

def stack_tokenized(tokenized, convert_to_tensors=True):
    """Gather per-label encodings into column-wise batches.

    Args:
        tokenized: Dict mapping a label to a dict that has ``input_ids``,
            ``token_type_ids`` and ``attention_mask`` entries.
        convert_to_tensors: When true, each batch is converted with
            ``torch.tensor``; otherwise the batches stay lists of rows.

    Returns:
        A ``(stacked, labels)`` tuple: ``stacked`` maps each of the three
        component names to its batch, and ``labels`` lists the keys of
        ``tokenized`` in the same row order.
    """
    components = ('input_ids', 'token_type_ids', 'attention_mask')
    stacked = {
        component: [encoding[component] for encoding in tokenized.values()]
        for component in components
    }

    if convert_to_tensors:
        stacked = {
            component: torch.tensor(rows) for component, rows in stacked.items()
        }

    return stacked, list(tokenized)

def prepare(word, pad_len=15):
    """Build model inputs for every valid subword partition of ``word``.

    Runs the whole pipeline: enumerate partitions, keep those made only of
    known subwords, label them, add the special tokens, encode and stack.

    Args:
        word: The word to split.
        pad_len: Padding length forwarded to ``encode``.

    Returns:
        A ``(to_model, substrings)`` tuple as produced by ``stack_tokenized``:
        the stacked model inputs and the dash-separated label of each row.
    """
    partitions = partition(word)
    # ``get_subwords`` needs the lookup table as well as the word.
    subwords = get_subwords(word, hash_dict)
    valid_tokens = check_partitions(partitions, subwords, hash_dict)
    valid_tokens = pretty_print(valid_tokens)
    tokenized = append_special_tokens(valid_tokens)
    tokenized = encode(tokenized, pad_len=pad_len)
    to_model, substrings = stack_tokenized(tokenized)
    return to_model, substrings

Transform the output of a model into graphable embeddings

In [None]:
def mean_pooled(output, to_model):
    """Average the last hidden states over the positions the mask keeps.

    Args:
        output: Model output exposing a ``last_hidden_state`` tensor
            (presumably shaped batch x sequence x hidden; confirm).
        to_model: Dict whose ``attention_mask`` tensor marks real tokens
            with 1 and padding with 0.

    Returns:
        A tensor with dimension 1 reduced away: the mask-weighted mean of
        the hidden states for each row.
    """
    embeddings = output.last_hidden_state
    # Broadcast the mask across the hidden dimension so it can zero out padding.
    mask = to_model['attention_mask'].unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    # Clamp so a row that is entirely masked divides by a tiny number, not zero.
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    return summed / summed_mask

def normalize(mean_pooled):
    """Convert a tensor to a NumPy array scaled by its overall norm.

    The divisor is a single norm taken over the whole array, not one norm
    per row.

    Args:
        mean_pooled: A torch tensor.

    Returns:
        The tensor's values as a NumPy array divided by that norm, or the
        unscaled array when the norm is zero.
    """
    as_array = mean_pooled.detach().cpu().numpy()
    magnitude = np.linalg.norm(as_array)
    return as_array if magnitude == 0 else as_array / magnitude

Select a word and run all of the above

In [None]:
word = "arrogant"
to_model, substrings = prepare(word)

print(len(substrings), "permutations to model")

# Inference only: skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    output = model(**to_model)
# NOTE: this rebinds the name ``mean_pooled`` from the function to its result,
# so the defining cell must be re-run before this cell can be run again.
mean_pooled = mean_pooled(output, to_model)
normalized = normalize(mean_pooled)

Graph the embeddings

In [None]:
def graph_embeddings(mean_pooled, substrings):
    """Project embeddings to 2-D with t-SNE and plot each one as a text label.

    Args:
        mean_pooled: A torch tensor with one embedding per row.
        substrings: Labels, one per row of ``mean_pooled``, in the same order.

    Returns:
        The matplotlib figure holding the plot.
    """
    # Move to the CPU first so tensors on an accelerator convert cleanly.
    vectors = mean_pooled.detach().cpu().numpy()

    # t-SNE needs perplexity < number of samples; the library default of 30
    # is kept whenever there are enough rows and lowered otherwise.
    perplexity = min(30.0, max(1.0, len(vectors) - 1))
    # NOTE(review): newer scikit-learn releases rename ``n_iter`` to
    # ``max_iter``; confirm against the installed version.
    tsne = TSNE(
        n_components=2,
        metric='cosine',
        init='pca',
        n_iter=1500,
        perplexity=perplexity,
    )
    bert_tsne = tsne.fit_transform(vectors)

    fig = plt.figure(figsize=(12, 9))
    for (x, y), substring in zip(bert_tsne, substrings):
        # Zero-size markers are invisible but make the axes scale to the data.
        plt.scatter(x, y, s=0)
        plt.text(x, y, substring, family='sans-serif')

    return fig

In [None]:
# Project the pooled embeddings to 2-D and draw one text label per partition.
fig = graph_embeddings(mean_pooled, substrings)
# A bare expression as the last line of a notebook cell displays its value.
fig