In [None]:
from transformers import GPT2LMHeadModel

If we look at the weights of GPT-2, we can see the token embedding (transformer.wte.weight) and the position embedding (transformer.wpe.weight). Their shapes show that the model has a vocabulary size of 50257 and a context length of 1024 tokens.

In [None]:
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")  # 124M parameters; use "gpt2-xl" for the 1.5B model
sd_hf = model_hf.state_dict()

for k, v in sd_hf.items():
    print(k, v.shape)
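
As a rough sanity check on the "124M" figure, the printed shapes can be turned into a parameter count. The sketch below is a minimal illustration, not part of the original notebook: count_parameters is a hypothetical helper, and the width of 768 is the embedding size of the 124M checkpoint. Only the two embedding tables are counted here.

```python
from math import prod

def count_parameters(shapes):
    """Sum the number of elements over a mapping of tensor name -> shape."""
    return sum(prod(shape) for shape in shapes.values())

# The two embedding tables of the 124M model (hypothetical hand-copied shapes).
embedding_shapes = {
    "transformer.wte.weight": (50257, 768),  # token embeddings: vocab size x width
    "transformer.wpe.weight": (1024, 768),   # position embeddings: context length x width
}
print(count_parameters(embedding_shapes))
```

The same helper could be applied to the shapes in sd_hf, with one caveat: a state dict may list tied or non-trainable tensors under their own keys, so a naive sum over every entry can overcount.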

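Let's look at the first few values of the positional embeddings. Flattening the matrix and taking the first 20 entries gives the first 20 dimensions of the embedding for position 0.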

In [None]:
sd_hf["transformer.wpe.weight"].view(-1)[:20]

Next we can plot them. Each row is the embedding for one fixed position in our context window, from 0 to 1023. The model uses these to tell where tokens sit relative to each other, so it can attend to them depending on their position, not just their content.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(sd_hf["transformer.wpe.weight"], cmap="gray")
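
To make concrete how the position rows are used, here is a minimal NumPy sketch of the embedding step: each token's embedding is looked up by token id, the position embedding is looked up by index in the sequence, and the two are added. The table sizes are toy values and the weights are random stand-ins, so this is an illustration under those assumptions rather than the real model; embed is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes so the example runs instantly; GPT-2 (124M) uses 50257, 1024 and 768.
vocab_size, block_size, n_embd = 10, 8, 4
wte = rng.normal(size=(vocab_size, n_embd))  # stand-in for transformer.wte.weight
wpe = rng.normal(size=(block_size, n_embd))  # stand-in for transformer.wpe.weight

def embed(token_ids):
    """Return token embedding + position embedding for each token in the sequence."""
    token_ids = np.asarray(token_ids)
    positions = np.arange(len(token_ids))   # 0, 1, ..., T-1
    return wte[token_ids] + wpe[positions]  # shape (T, n_embd)

# The same token id (3) at positions 0 and 1 gets two different input vectors.
x = embed([3, 3, 7])
print(x.shape)
```

Because the position row is added to the token row, two occurrences of the same token differ only by the difference between their position embeddings, which is what lets later layers distinguish them.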

When we look at individual columns, we can see how they respond to different positions. For example, the green curve (the third one plotted, column 250) becomes more active for positions in the middle of the context (above ~250 and below ~800). The curves are fairly jagged, which suggests that the model is not fully trained. With more training, you would expect them to become smoother.

Note that in the original transformer paper the positional embedding weights were fixed, using sine and cosine curves of different frequencies, whereas in GPT-2 they are learned. It is interesting that the learned weights recover these periodic, wave-like structures.

In [None]:
plt.plot(sd_hf["transformer.wpe.weight"][:, 150])
plt.plot(sd_hf["transformer.wpe.weight"][:, 200])
plt.plot(sd_hf["transformer.wpe.weight"][:, 250])
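
For comparison with the learned curves, the fixed encoding from the original transformer paper can be sketched in a few lines. This is a minimal NumPy version of the standard sine/cosine formula, assuming an even embedding width; sinusoidal_positions is a hypothetical helper name and is not used by GPT-2 itself.

```python
import numpy as np

def sinusoidal_positions(n_positions, n_embd):
    """Fixed sinusoidal position encoding (n_embd is assumed to be even)."""
    positions = np.arange(n_positions)[:, None]   # (n_positions, 1)
    even_dims = np.arange(0, n_embd, 2)[None, :]  # (1, n_embd / 2), the values 2i
    # Each pair of columns shares one frequency: 1 / 10000^(2i / n_embd).
    angles = positions / (10000.0 ** (even_dims / n_embd))
    pe = np.zeros((n_positions, n_embd))
    pe[:, 0::2] = np.sin(angles)  # even columns
    pe[:, 1::2] = np.cos(angles)  # odd columns
    return pe

# Same shape as GPT-2's learned position table.
pe = sinusoidal_positions(1024, 768)
print(pe.shape)
```

Plotting a few columns of pe with plt.plot, in the same way as the learned weights, gives perfectly smooth waves to compare against.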

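We can visualize any of the other weight matrices in the same way, for example a 300x300 slice of the attention projection weights in the second block (h.1).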

In [None]:
plt.imshow(sd_hf["transformer.h.1.attn.c_attn.weight"][:300,:300], cmap="gray")

Our main interest here is to play with inference on the model, using the same pretrained weights that we loaded above.

In [None]:
from transformers import pipeline, set_seed
generator = pipeline('text-generation', model='gpt2')
set_seed(42)
generator("Hello, I'm a language model,", max_length=30, num_return_sequences=5)