In [1]:
import transformers
import torch

In [3]:
tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
tok.add_tokens(['Aragorn', 'Frodo', 'Lothlorien'])
model.resize_token_embeddings(len(tok))

Embedding(50260, 768)

In [4]:
'''
Let’s go back to our running example. First, 
we instantiate a model and tokenizer, add our new tokens, 
and resize the embeddings.
'''
tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
tok.add_tokens(['Aragorn', 'Frodo', 'Lothlorien'])
model.resize_token_embeddings(len(tok))

Embedding(50260, 768)

In [6]:
'''
Next, we compute the distribution from which we’ll sample our new embeddings:
'''
params = model.state_dict()
embeddings = params['transformer.wte.weight']
pre_expansion_embeddings = embeddings[:-3,:]
mu = torch.mean(pre_expansion_embeddings, dim=0)
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(
        mu, covariance_matrix=1e-5*sigma)

In [7]:
'''
We’ll load in our new embeddings into the model:
'''
new_embeddings = torch.stack(tuple((dist.sample() for _ in range(3))), dim=0)
embeddings[-3:,:] = new_embeddings
params['transformer.wte.weight'][-3:,:] = new_embeddings
model.load_state_dict(params)

<All keys matched successfully>

In [8]:
'''
Finally, we sample from the model and observe that it does not just generate the new words we just added to the vocabulary.
'''

sent2 = 'Dogs are great because they are '
print(tok.decode(model.generate(**tok(sent2, return_tensors='pt'), do_sample=True)[0]))
#print(embeddings)

word = "kajsbfkasoebgkjwqenfndoow"  # Replace with the word you want to look up
token_id = tok.convert_tokens_to_ids(word)
embedding = embeddings[token_id]
print(embedding)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Dogs are great because they are _____.

Some people have actually used horses as an
tensor([ 5.1352e-02, -2.7689e-02,  4.9937e-02, -4.2212e-02, -6.1677e-02,
         3.2521e-02, -2.2412e-01, -8.7415e-02, -7.1382e-02, -2.0823e-02,
         6.2048e-02,  4.0809e-02, -6.9579e-02,  6.3005e-03,  9.2761e-03,
         1.5079e-02,  9.6145e-02, -1.4278e-01,  7.7547e-02,  5.8755e-02,
         8.2768e-02, -7.1086e-02, -3.8467e-02,  3.5799e-02, -8.9123e-02,
        -8.8032e-02, -3.0367e-02,  1.6997e-01,  4.5189e-02,  1.4124e-01,
         6.5241e-02,  7.6400e-02,  4.1002e-02, -7.2275e-02, -3.2274e-02,
        -3.7502e-02, -3.1738e-01,  5.6048e-02,  8.2341e-02,  3.1858e-02,
         1.1918e-02, -1.2181e-01,  8.7171e-03, -8.5096e-02, -2.0306e-02,
         3.1586e-03, -2.2322e-01,  2.5618e-02, -5.2556e-02, -1.7527e-01,
         1.1652e-01, -4.4113e-02,  6.5566e-02,  1.3448e-01, -1.2134e-01,
        -1.4695e-01, -1.9515e-02, -3.0667e-02,  5.8110e-02,  5.5833e-02,
        -2.0452e-02, -5.0032e-02,  8.803