## Week 6 workshop

This week, we continue exploring the GPT-style model trained on Shakespeare's text.

First we import the required dependencies:

In [None]:
import sys
import torch
import polars as pl
#from plotnine import ggplot, aes, geom_line, labs, theme_minimal

sys.path.append("scratch-llm")
from model.llm import LLM
from model.tokenizer import Tokenizer
from helpers.config import LLMConfig, get_device

Next, we set up the model configuration, load the tokenizer, create the model, and load its trained weights.

In [2]:
# the model setup (has to be consistent with the trained model we're loading)
llm_config = LLMConfig(
    vocab_size = 4096,
    seq_len = 128,
    dim_emb = 256,
    num_layers = 4,
    num_heads = 8,
    emb_dropout = 0.0,
    ffn_dim_hidden = 4 * 256,
    ffn_bias = False
)

# the trained tokenizer
tokenizer = Tokenizer("data/tinyshakespeare.model")

# the model object
model = LLM(
    vocab_size = tokenizer.vocab_size,
    seq_len = llm_config.seq_len,
    dim_emb = llm_config.dim_emb,
    num_layers = llm_config.num_layers,
    attn_num_heads = llm_config.num_heads,
    emb_dropout = llm_config.emb_dropout,
    ffn_hidden_dim = llm_config.ffn_dim_hidden,
    ffn_bias = llm_config.ffn_bias
)

# the device on which we're running this (CPU vs GPU etc.)
device = get_device()

# move the model to the selected device (CPU or GPU)
model.to(device)

# load the saved model weights
model.load_state_dict(torch.load(
    "data/tinyshakespeare_llm.pt",
    weights_only = True,
    map_location = device
))

# put the model into evaluation mode
model.eval()

LLM(
  (token_embedding): Embedding(4096, 256)
  (emb_dropout): Dropout(p=0.0, inplace=False)
  (transformer): Sequential(
    (0): TransformerBlock(
      (norm_attn): RMSNorm()
      (multihead_attn): MultiHeadAttention(
        (positional_encoding): RotaryPositionalEncoding()
        (proj_qkv): Linear(in_features=256, out_features=768, bias=False)
        (proj_out): Linear(in_features=256, out_features=256, bias=False)
      )
      (norm_ffn): RMSNorm()
      (feed_forward): FeedForward(
        (0): Linear(in_features=256, out_features=1024, bias=False)
        (1): SwiGLU(
          (linear): Linear(in_features=1024, out_features=2048, bias=True)
        )
        (2): Linear(in_features=1024, out_features=256, bias=False)
      )
    )
    (1): TransformerBlock(
      (norm_attn): RMSNorm()
      (multihead_attn): MultiHeadAttention(
        (positional_encoding): RotaryPositionalEncoding()
        (proj_qkv): Linear(in_features=256, out_features=768, bias=False)
        (pro
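Several sizes in the printout above follow directly from the configuration. The sketch below derives them. It assumes, as is standard but not shown in this notebook, that the 8 attention heads split the 256-dimensional embedding evenly, and that SwiGLU's inner `Linear` produces two halves of size `ffn_dim_hidden` (a gate and a value). That assumption matches the `Linear(in_features=1024, out_features=2048)` inside `SwiGLU` being followed by `Linear(in_features=1024, out_features=256)`. `ConfigSketch` and `expected_sizes` are hypothetical stand-ins, not part of `helpers.config`.

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    # Hypothetical stand-in for LLMConfig, with only the fields used here
    vocab_size: int = 4096
    dim_emb: int = 256
    num_heads: int = 8
    ffn_dim_hidden: int = 4 * 256


def expected_sizes(cfg: ConfigSketch) -> dict:
    """Derive the layer sizes that should appear in the printed model."""
    if cfg.dim_emb % cfg.num_heads != 0:
        raise ValueError("dim_emb must be divisible by num_heads")
    return {
        # assumption: heads split the embedding evenly, 256 / 8 = 32
        "head_dim": cfg.dim_emb // cfg.num_heads,
        # Q, K and V stacked into one projection, 3 * 256 = 768
        "proj_qkv_out": 3 * cfg.dim_emb,
        # assumption: SwiGLU's Linear yields two halves, 2 * 1024 = 2048
        "swiglu_linear_out": 2 * cfg.ffn_dim_hidden,
        "token_embedding": (cfg.vocab_size, cfg.dim_emb),
    }


print(expected_sizes(ConfigSketch()))
```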

## Exploring the tokenizer

Extract the list of all tokens, then look up the IDs of two specific tokens.

In [13]:
# extract list of all tokens
tokens = [tokenizer.sp.id_to_piece(i) for i in range(llm_config.vocab_size)]

print(tokenizer.sp.piece_to_id("▁perforce"))
print(tokenizer.sp.piece_to_id("▁basilisk"))

2554
4077


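Print all tokens the tokenizer knows. The "▁" character in front of a token (it looks like an underscore, but it is SentencePiece's word-boundary marker, U+2581) indicates that the token begins a new word, i.e. it was preceded by a space.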

In [14]:
tokens_per_line = 10
for i in range(0, len(tokens), tokens_per_line):
    line_tokens = tokens[i:i+tokens_per_line]
    print(f"[{i:5d}-{min(i+tokens_per_line-1, len(tokens)-1):5d}] {' | '.join(line_tokens)}")

[    0-    9] <pad> | <unk> | <s> | </s> | , | : | s | . | ▁I | ▁the
[   10-   19] ' | ▁to | ▁ | ▁and | ; | ▁of | d | ▁a | ▁you | ▁my
[   20-   29] ? | ▁in | ! | ▁that | ▁And | ▁not | ▁is | ▁me | ▁be | ▁it
[   30-   39] ▁with | ▁your | ▁for | ▁his | ing | ▁he | ▁have | ▁this | ▁him | ▁thou
[   40-   49] t | ed | ▁so | - | ▁as | ▁will | ▁thy | er | ▁but | ▁The
[   50-   59] ▁To | O | ▁all | ▁' | ▁O | ▁her | st | ▁we | ll | ▁do
[   60-   69] ly | IO | ▁shall | ▁thee | ▁by | ▁our | e | ▁are | ▁no | ▁A
[   70-   79] n | ▁That | ▁what | ▁on | y | ▁But | ▁S | ▁good | R | ▁What
[   80-   89] ▁B | KING | ▁from | ▁more | ▁at | ▁For | ▁C | ▁if | ▁sir | ▁now
[   90-   99] ▁lord | l | ▁was | ▁she | ▁love | ▁am | ▁them | ▁would | ▁their | ▁here
[  100-  109] ▁My | ▁they | p | ▁say | ▁us | ▁come | ▁or | ▁king | ▁man | LO
[  110-  119] ▁then | ▁one | ▁know | ▁well | ▁make | r | ▁let | ▁hath | ▁G | ▁As
[  120-  129] ▁may | IUS | ▁You | ▁must | ▁were | ▁F | ▁like | a | ET | ▁than
[  130-  139] ▁there |
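Because "▁" stands for a preceding space, turning pieces back into text is mostly string concatenation. The function below is a simplified, hypothetical sketch of that convention: it ignores control tokens such as `<s>` and `<pad>`, and in practice SentencePiece's own decoding (e.g. `tokenizer.sp.decode` on token IDs) should be used instead.

```python
WORD_BOUNDARY = "\u2581"  # "▁", SentencePiece's word-boundary marker


def pieces_to_text(pieces: list[str]) -> str:
    """Join SentencePiece pieces back into plain text (simplified sketch).

    Each "▁" marks a space before a word, so we concatenate the pieces,
    turn the markers into spaces and drop the single leading space.
    Control tokens such as <s> or <pad> are not handled.
    """
    return "".join(pieces).replace(WORD_BOUNDARY, " ").removeprefix(" ")


# pieces that appear in this notebook's prompt
print(pieces_to_text(["▁I", "▁say", ",", "▁basilisk", ",", "▁perforce", ",", "▁castle", "."]))
# "▁know" and "ing" (both in the listing above) join into one word
print(pieces_to_text(["▁know", "ing"]))
```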

## Exploring model parameters

Structure of the model parameters (in the order returned by `model.parameters()`):

- **0** :: weight matrix for **token embeddings** (`token_embedding`)

<!-- -->

- **1** :: **RMSNorm** parameter vector (`norm_attn`)
- **2** :: Q, K, V matrices (concatenated) for **MultiHeadAttention** (`proj_qkv`)
- **3** :: output projection matrix of **MultiHeadAttention** (`proj_out`)
- **4** :: **RMSNorm** parameter vector (`norm_ffn`)
- **5** :: first weight matrix of the **FeedForward (SwiGLU)** part
- **6** :: **SwiGLU** weight matrices (concatenated)
- **7** :: **SwiGLU** bias vector
- **8** :: final weight matrix of the **FeedForward (SwiGLU)** part

<!-- -->

- **9-16** :: as 1-8, but for the second TransformerBlock
- **17-24** :: as 1-8, but for the third TransformerBlock
- **25-32** :: as 1-8, but for the fourth TransformerBlock

<!-- -->

- **33** :: final **RMSNorm** parameter vector
- **34** :: bias vector of the final **projection_head**

**NOTE**: there is no weight matrix for the final projection head because it is "weight-tied" to the token-embedding weight matrix (0 above).
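The weight tying in the note can be pictured with a small pure-Python sketch. It assumes that the projection head computes each logit as the dot product of the final hidden vector with one row of the token-embedding matrix, plus one entry of the bias vector (parameter 34), so the head only adds the 4096 bias entries as new parameters. `tied_logits` and the toy numbers are hypothetical, for illustration only. In PyTorch, tying usually means two modules share the same `Parameter` object, and `model.parameters()` yields a shared parameter only once, which is why it appears a single time in the list.

```python
def tied_logits(hidden, embedding, bias):
    """Output logits for one position under weight tying (sketch).

    The token-embedding matrix E (vocab_size rows of length dim_emb) is
    reused as the output projection: logit[v] = hidden . E[v] + bias[v].
    Only `bias` is a separate parameter of the head.
    """
    return [
        sum(h * e for h, e in zip(hidden, row)) + b
        for row, b in zip(embedding, bias)
    ]


# Toy sizes: vocabulary of 3 tokens, embedding dimension 2
E = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
b = [0.0, 0.5, -1.0]
h = [2.0, 3.0]
print(tied_logits(h, E, b))  # [2.0, 3.5, 4.0]
```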

In [15]:
# extract all model parameters
parList = list(model.parameters())

# extract tensor shapes for each tensor
parShapes = [list(el.shape) for el in parList]
parShapes

[[4096, 256],
 [256],
 [768, 256],
 [256, 256],
 [256],
 [1024, 256],
 [2048, 1024],
 [2048],
 [256, 1024],
 [256],
 [768, 256],
 [256, 256],
 [256],
 [1024, 256],
 [2048, 1024],
 [2048],
 [256, 1024],
 [256],
 [768, 256],
 [256, 256],
 [256],
 [1024, 256],
 [2048, 1024],
 [2048],
 [256, 1024],
 [256],
 [768, 256],
 [256, 256],
 [256],
 [1024, 256],
 [2048, 1024],
 [2048],
 [256, 1024],
 [256],
 [4096]]
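The shapes above also give the model's size. This sketch copies the shapes from the `parShapes` output and multiplies them out; with the real model, `sum(p.numel() for p in model.parameters())` should give the same total. The `untied` figure is hypothetical: it adds the 4096 × 256 matrix that a separate, untied projection head would need.

```python
from math import prod

# Shapes copied from the parShapes output above
embedding = [[4096, 256]]
block = [[256], [768, 256], [256, 256], [256],
         [1024, 256], [2048, 1024], [2048], [256, 1024]]
final = [[256], [4096]]  # final RMSNorm, projection-head bias

shapes = embedding + 4 * block + final  # four identical TransformerBlocks
assert len(shapes) == 35  # parameter indices 0-34

per_block = sum(prod(s) for s in block)
total = sum(prod(s) for s in shapes)
untied_extra = 4096 * 256  # hypothetical separate output weight matrix

print(f"per block: {per_block:,}")  # per block: 2,886,144
print(f"total: {total:,}")  # total: 12,597,504
print(f"untied: {total + untied_extra:,}")  # untied: 13,646,080
```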

Next, we extract the token embeddings. These are the initial embeddings that are fed into the transformer blocks.

In [None]:
# extract list of all tokens (just as before)
tokens = [tokenizer.sp.id_to_piece(i) for i in range(llm_config.vocab_size)]

# obtain the tensor of token embeddings, moved to the CPU and detached from the autograd graph
embedding_data = parList[0].cpu().detach()
print(embedding_data.shape) # (4096, 256)

# ▁basilisk is token 4077
print(embedding_data[4077, :])


torch.Size([4096, 256])
tensor([-1.3005e-03,  2.9321e-02, -2.7822e-02,  1.5014e-01, -3.4438e-02,
        -4.0284e-02, -1.9866e-01, -8.9023e-02, -9.7680e-02, -1.8535e-01,
        -8.1944e-03,  6.2589e-02,  9.7284e-03,  1.0832e-02,  8.8429e-03,
         2.1591e-01, -2.1713e-01, -4.6794e-02,  7.0577e-02, -1.9047e-01,
        -5.0545e-02, -1.1453e-01, -5.0950e-02, -2.6960e-02, -4.4235e-02,
        -8.7682e-03,  1.0248e-03, -4.1417e-03,  1.3822e-01,  2.9820e-02,
         1.1739e-01,  2.2652e-02,  1.0913e-01, -9.7986e-02,  1.2757e-01,
         2.4810e-01, -1.5029e-01,  1.0952e-01, -7.7267e-02, -8.0533e-02,
         8.1471e-02,  4.8467e-02,  4.0388e-02, -5.1555e-03, -7.1092e-02,
        -7.2734e-02,  1.7126e-01, -3.9200e-02, -1.0690e-01,  9.7338e-02,
         1.1545e-01, -3.5746e-04, -3.3372e-02, -1.2858e-02, -8.7852e-02,
        -1.9490e-01,  3.0188e-01,  1.6837e-01, -2.2130e-02, -9.7071e-02,
         9.2055e-03, -7.6874e-02, -1.5720e-01, -3.6125e-02,  3.0394e-02,
        -1.8694e-01, -6.145
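Selecting a row such as `embedding_data[4077, :]` is all an embedding lookup does. Mathematically, the same row is often written as a one-hot vector multiplied by the embedding matrix. A toy pure-Python check of that equivalence, with a made-up 4 × 3 matrix (the helper names are hypothetical):

```python
def one_hot(index, size):
    """Vector of zeros with a single 1.0 at `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]


def vec_times_matrix(vector, matrix):
    """Row vector times a matrix given as a list of rows."""
    n_cols = len(matrix[0])
    return [sum(v * row[j] for v, row in zip(vector, matrix)) for j in range(n_cols)]


# Toy embedding matrix: 4 tokens, 3 dimensions
E = [[0.1, 0.2, 0.3],
     [0.4, 0.5, 0.6],
     [0.7, 0.8, 0.9],
     [1.0, 1.1, 1.2]]

token_id = 2
print(vec_times_matrix(one_hot(token_id, len(E)), E))  # [0.7, 0.8, 0.9]
print(E[token_id])  # [0.7, 0.8, 0.9]
```

Indexing is far cheaper than the multiplication, which is why embedding layers look rows up directly.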

In [84]:
# create DataFrame with token column plus embedding dimensions
embeddings = pl.DataFrame({
    'token': tokens,
    'token_id': list(range(llm_config.vocab_size)),
    **{f'dim_{i}': embedding_data[:, i] for i in range(embedding_data.shape[1])}
})

# print token embeddings for three chosen tokens
target_tokens = ["▁basilisk", "▁perforce", "▁castle"]
embeddings.filter(pl.col('token').is_in(target_tokens))

token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁castle""",2007,-0.148215,0.197007,0.1249,0.021305,-0.121309,0.14587,0.259737,0.048309,0.286923,0.13779,0.011779,-0.030519,-0.173348,0.349873,-0.143959,0.357129,-0.150084,0.093061,0.325579,-0.10153,0.117874,0.233133,-0.230312,-0.094514,-0.126663,-0.139246,-0.064809,-0.149722,-0.163478,0.043311,-0.005281,-0.149828,0.08299,-0.09244,-0.03638,…,-0.041608,0.124754,-0.0927,0.104195,0.105148,0.34459,0.014003,0.215244,0.154041,0.074032,-0.036821,-0.072911,-0.109559,-0.025216,0.016599,-0.135661,0.043836,-0.062106,-0.025945,0.192274,-0.340097,-0.194262,0.138081,0.184851,-0.016861,0.144102,0.20947,0.222077,-0.167625,0.3088,-0.346945,0.134195,0.197539,-0.249303,-0.180279,0.363111,-0.055227
"""▁perforce""",2554,-0.153841,-0.274629,0.173357,0.049015,0.001886,0.17866,-0.014247,0.352402,0.029701,0.071924,-0.027173,0.056062,-0.053321,-0.08847,0.241883,0.044242,-0.08444,0.012871,-0.233677,0.073736,0.174775,0.263662,0.084643,0.045823,0.130946,0.123333,0.262025,0.040367,-0.129015,0.135108,-0.159329,-0.09433,0.094151,0.192522,-0.342727,…,-0.302726,0.216747,-0.28534,0.068532,-0.051689,0.314963,0.105637,-0.206142,0.071576,-0.046063,0.011026,-0.164473,0.138242,-0.311956,-0.2274,-0.303799,-0.105977,-0.005794,0.110667,-0.116478,0.031044,-0.312319,-0.249584,0.018577,0.117485,-0.008807,-0.276737,0.032046,0.123242,-0.161981,-0.079343,-0.025822,0.029489,-0.266156,-0.239444,-0.154607,0.099395
"""▁basilisk""",4077,-0.001301,0.029321,-0.027822,0.150135,-0.034438,-0.040284,-0.198661,-0.089023,-0.09768,-0.185345,-0.008194,0.062589,0.009728,0.010832,0.008843,0.215906,-0.217127,-0.046794,0.070577,-0.190466,-0.050545,-0.114534,-0.05095,-0.02696,-0.044235,-0.008768,0.001025,-0.004142,0.138218,0.02982,0.117393,0.022652,0.109133,-0.097986,0.127575,…,-0.20958,-0.11634,-0.06938,0.106201,-0.317534,0.200011,0.056857,0.161848,0.162149,0.046978,-0.041519,0.007423,0.218404,0.058436,0.162892,-0.023739,-0.164718,0.31398,0.055259,0.175692,-0.058398,0.110983,-0.086261,0.180361,0.081902,0.020042,-0.105211,0.102362,-0.116001,0.261306,0.042691,-0.112341,-0.005003,-0.023965,-0.142263,0.153943,0.128183


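We can extract the embeddings produced by each layer using hooks. Hooks are callback functions that PyTorch calls when a particular part of the model is executed. A forward hook, registered with `register_forward_hook`, is called right after that module's forward pass and receives the module, its input, and its output, so it lets us capture the intermediate values.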
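To see the callback mechanism in isolation, here is a toy pure-Python imitation of the forward-hook pattern. `ToyModule` and `ToyHandle` are hypothetical classes, not PyTorch code: they only mirror the idea that `register_forward_hook` returns a handle whose `remove()` detaches the hook, and that the hook runs after the forward computation with the module, its input and its output (in real PyTorch the input arrives as a tuple of positional arguments).

```python
class ToyHandle:
    """Handle that can detach one registered hook (toy stand-in)."""

    def __init__(self, registry, hook_id):
        self._registry = registry
        self._hook_id = hook_id

    def remove(self):
        self._registry.pop(self._hook_id, None)


class ToyModule:
    """Toy stand-in for a module that supports forward hooks."""

    def __init__(self, fn):
        self.fn = fn
        self._hooks = {}
        self._next_id = 0

    def register_forward_hook(self, hook):
        hook_id = self._next_id
        self._next_id += 1
        self._hooks[hook_id] = hook
        return ToyHandle(self._hooks, hook_id)

    def __call__(self, x):
        output = self.fn(x)
        # hooks run after the forward computation and see (module, input, output)
        for hook in list(self._hooks.values()):
            hook(self, x, output)
        return output


captured = []
blocks = [ToyModule(lambda x: x + 1), ToyModule(lambda x: x * 10)]
handles = [b.register_forward_hook(lambda m, inp, out: captured.append(out)) for b in blocks]

x = 2
for block in blocks:  # a "Sequential"-style forward pass
    x = block(x)
print(captured)  # [3, 30]

for h in handles:
    h.remove()
blocks[0](5)
print(captured)  # still [3, 30]: removed hooks no longer fire
```

Removing the handles is what keeps repeated runs from stacking up extra hooks, which is why the notebook cell below cleans them up at the end.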

In [85]:
# set up a list to store embeddings from each transformer layer
layer_outputs = []

# hook function that stores the output of a model component
def hook_fn(module, input, output):
    layer_outputs.append(output.detach().clone().cpu())

# register hooks on each transformer block
hooks = []
for block in model.transformer:
    hook = block.register_forward_hook(hook_fn)
    hooks.append(hook) # we need to keep a list of all hooks so we can remove them at the end

# encode the prompt, padded to seq_len
prompt = tokenizer.encode(
    "I say, basilisk, perforce, castle.",
    beg_of_string = True,
    pad_seq = True,
    seq_len = llm_config.seq_len
)
inputs = torch.tensor(prompt, dtype=torch.int32).unsqueeze(0)

# print the input tensor
print(f"The prompt input:\n{inputs}\n")

# run the forward pass: we don't need the output itself, we only call the
# model's `forward()` so that the hooks fire (no_grad skips gradient tracking)
with torch.no_grad():
    out = model(inputs.to(device))

# clean up hooks (so we can run the cell again in the same session and not accumulate hooks)
for hook in hooks:
    hook.remove()

The prompt input:
tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    2,    8,
          103,    4, 4077,    4, 2554,    4, 2007,    7]], dtype=torch.int32)



The tokens for basilisk, perforce, and castle are at positions 122, 124, and 126 in this tensor; the prompt is padded on the left with `<pad>` tokens (ID 0).
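Rather than counting positions by hand, they can be computed. The sketch below assumes, as the printed tensor shows, that `pad_seq = True` pads on the left with `<pad>` (ID 0); `left_pad` and `positions_of` are hypothetical helpers that only mimic this. With the notebook's tensor you could call `positions_of(inputs[0].tolist(), {4077, 2554, 2007})`.

```python
PAD_ID = 0  # id of <pad>, as in the vocabulary listing


def left_pad(ids, seq_len, pad_id=PAD_ID):
    """Pad a list of token ids on the left up to seq_len."""
    if len(ids) > seq_len:
        raise ValueError("sequence is longer than seq_len")
    return [pad_id] * (seq_len - len(ids)) + ids


def positions_of(ids, targets):
    """Indices at which any of the target ids occur."""
    return [i for i, t in enumerate(ids) if t in targets]


# ids of "<s> I say, basilisk, perforce, castle." as printed above
prompt_ids = [2, 8, 103, 4, 4077, 4, 2554, 4, 2007, 7]
padded = left_pad(prompt_ids, 128)
print(positions_of(padded, {4077, 2554, 2007}))  # [122, 124, 126]
```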

In [87]:
token_pos = [122, 124, 126]
inputs[0, token_pos]

tensor([4077, 2554, 2007], dtype=torch.int32)

These are the corresponding outputs of layer 3 (the last transformer block). The raw tensor is hard to read, so we need a bit more work to inspect the values conveniently.

In [88]:
print(layer_outputs[3][0, token_pos, :].shape)
layer_outputs[3][0, token_pos, :]

torch.Size([3, 256])


tensor([[-8.0377e+03, -1.0186e+03, -8.2308e+03, -7.8267e+03,  1.6482e+04,
          7.7205e+03, -5.2366e+03, -7.7123e+03, -5.2055e+03,  1.9698e+03,
          4.1744e+03, -1.9272e+03, -7.3760e+03, -1.5343e+03,  1.9124e+03,
          6.7367e+03,  6.5250e+03,  7.3861e+03, -1.5715e+03, -5.6180e+03,
          8.1098e+03, -3.5322e+03, -1.1414e+04, -9.5040e+02, -7.7052e+03,
         -8.4049e+03,  1.4677e+01,  4.6983e+03, -6.4292e+03,  5.9010e+03,
          1.2783e+04, -9.8368e+02, -1.0912e+04,  1.1995e+04, -1.4941e+03,
          9.8704e+03,  1.5950e+02,  4.4452e+03,  1.5128e+04,  8.5939e+03,
         -2.5339e+02,  1.0938e+04, -1.7914e+03,  1.2693e+02, -5.4707e+03,
         -4.5063e+03, -3.1841e+04,  2.0593e+03,  6.2451e+03, -1.4579e+04,
          5.2292e+03,  3.5610e+03, -1.8493e+03, -5.2443e+03, -1.5878e+03,
          1.5666e+04,  1.3140e+04, -7.2062e+03,  6.6454e+03,  6.1010e+03,
          1.5020e+03,  2.5733e+03,  4.1013e+03, -6.8523e+03,  8.0431e+03,
          7.8662e+03, -7.5032e+03,  3.

In [89]:
# function to create a data frame of the embeddings output by a given layer
# (relies on the global `inputs`, `tokenizer`, and `llm_config` defined above)
def make_embedding_table(layer_outputs, layer_id):
    df = pl.DataFrame({
        'token': [tokenizer.sp.id_to_piece(inputs[0, i].item()) for i in range(llm_config.seq_len)],
        'token_id': inputs[0, :],
        **{f'dim_{i}': layer_outputs[layer_id][0, :, i] for i in range(llm_config.dim_emb)}
    })
    return df

print("All layer 0 embeddings:")
make_embedding_table(layer_outputs, 0)[token_pos]


All layer 0 embeddings:


token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁basilisk""",4077,-8035.944336,-1055.205933,-8346.657227,-8089.351074,16804.189453,7526.438477,-5263.236816,-7796.367676,-4671.29248,1873.161987,3862.245605,-2264.424561,-7110.523438,-1668.164673,2115.224121,6719.562012,6373.024414,7472.193359,-1562.470093,-5582.696777,7996.283203,-3347.405762,-11387.352539,-780.0495,-7864.013672,-8641.566406,0.595935,4922.300293,-6481.258789,6088.95752,12756.319336,-731.766113,-11026.038086,11890.280273,-1748.631958,…,-2826.578857,9422.588867,-1147.916016,-3556.54248,5853.64502,8550.706055,-4122.742676,-8630.850586,-7715.640625,14521.473633,1111.089233,-11310.09082,-7605.333008,-5211.052246,9644.207031,14045.542969,7104.358398,19551.708984,-13986.811523,-4101.196777,-3429.51709,17604.417969,-9874.216797,804.463379,2772.717285,-2172.591064,7563.640625,9846.825195,2059.871094,2958.040771,2316.457764,12316.712891,2480.868652,-17846.857422,-7590.908203,-9057.871094,-2054.790527
"""▁perforce""",2554,1419.814697,-18540.603516,10474.90918,-6559.92041,22791.869141,339.467712,-4053.29541,2610.476562,-1817.029419,-3074.878662,1078.336426,363.514587,-3135.254639,11504.192383,4334.347168,-144.419357,-4891.397949,-402.38028,7083.713867,15008.277344,8393.455078,3342.667725,-2776.926514,3959.003662,-3252.584229,166.990799,10094.742188,5348.339844,-4039.562256,12630.323242,-235.404999,-3814.471191,-6229.195801,5350.229004,-4011.817871,…,3155.580078,-1397.125854,-3653.310059,-666.091064,19611.5,10068.120117,1341.33606,-3348.138184,6643.76416,-3437.553711,7410.625,-4539.281738,-3783.976318,-5824.100098,-2476.200928,18200.789062,4600.745605,16877.867188,-13757.123047,-10268.798828,3932.779785,17395.619141,-4253.066406,-14552.951172,-3141.924316,9704.746094,2379.150635,5618.531738,1021.155762,9126.204102,2765.288818,9872.863281,6443.302246,-3356.700928,-6683.908691,-6151.64502,-9045.299805
"""▁castle""",2007,-4229.874512,-4178.651855,2772.171387,-4495.574219,3830.035889,-1087.234009,308.352997,5522.677734,-1385.413452,-2118.908447,-879.550415,3726.971924,-971.319763,2783.561523,3248.839844,-2082.979004,-694.346619,-3118.477295,1688.149048,4930.899414,-4049.894775,1257.667236,1074.674072,-3727.5,-3821.339844,5466.620605,425.973175,880.930725,-2031.386353,1805.665649,-2595.820068,-208.628403,6525.966309,-4616.744141,-2637.871094,…,2431.286865,-1537.646484,227.205551,-3241.657227,5108.012695,2320.427979,-1254.968628,-580.961792,3789.618896,-6087.088867,594.146057,-4218.633789,-2959.508545,2370.702148,-482.947021,5505.622559,4956.480957,1721.7229,-4821.256836,-1815.167969,619.161011,5542.36084,-5117.79541,-7176.319336,1956.365723,2828.599854,4885.998047,1678.619873,-17.509712,876.456116,3812.211426,2716.855713,1747.479614,3432.223877,4987.053223,1611.177368,8107.540527


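Now we print the embedding tables for the target tokens only, starting once more with the initial token embeddings for reference. Note that the input-embedding table lists the rows in token-ID order (castle, perforce, basilisk), whereas the layer tables follow the order of the tokens in the prompt (basilisk, perforce, castle).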

In [90]:
embeddings.filter(pl.col('token').is_in(target_tokens)) # input embeddings

token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁castle""",2007,-0.148215,0.197007,0.1249,0.021305,-0.121309,0.14587,0.259737,0.048309,0.286923,0.13779,0.011779,-0.030519,-0.173348,0.349873,-0.143959,0.357129,-0.150084,0.093061,0.325579,-0.10153,0.117874,0.233133,-0.230312,-0.094514,-0.126663,-0.139246,-0.064809,-0.149722,-0.163478,0.043311,-0.005281,-0.149828,0.08299,-0.09244,-0.03638,…,-0.041608,0.124754,-0.0927,0.104195,0.105148,0.34459,0.014003,0.215244,0.154041,0.074032,-0.036821,-0.072911,-0.109559,-0.025216,0.016599,-0.135661,0.043836,-0.062106,-0.025945,0.192274,-0.340097,-0.194262,0.138081,0.184851,-0.016861,0.144102,0.20947,0.222077,-0.167625,0.3088,-0.346945,0.134195,0.197539,-0.249303,-0.180279,0.363111,-0.055227
"""▁perforce""",2554,-0.153841,-0.274629,0.173357,0.049015,0.001886,0.17866,-0.014247,0.352402,0.029701,0.071924,-0.027173,0.056062,-0.053321,-0.08847,0.241883,0.044242,-0.08444,0.012871,-0.233677,0.073736,0.174775,0.263662,0.084643,0.045823,0.130946,0.123333,0.262025,0.040367,-0.129015,0.135108,-0.159329,-0.09433,0.094151,0.192522,-0.342727,…,-0.302726,0.216747,-0.28534,0.068532,-0.051689,0.314963,0.105637,-0.206142,0.071576,-0.046063,0.011026,-0.164473,0.138242,-0.311956,-0.2274,-0.303799,-0.105977,-0.005794,0.110667,-0.116478,0.031044,-0.312319,-0.249584,0.018577,0.117485,-0.008807,-0.276737,0.032046,0.123242,-0.161981,-0.079343,-0.025822,0.029489,-0.266156,-0.239444,-0.154607,0.099395
"""▁basilisk""",4077,-0.001301,0.029321,-0.027822,0.150135,-0.034438,-0.040284,-0.198661,-0.089023,-0.09768,-0.185345,-0.008194,0.062589,0.009728,0.010832,0.008843,0.215906,-0.217127,-0.046794,0.070577,-0.190466,-0.050545,-0.114534,-0.05095,-0.02696,-0.044235,-0.008768,0.001025,-0.004142,0.138218,0.02982,0.117393,0.022652,0.109133,-0.097986,0.127575,…,-0.20958,-0.11634,-0.06938,0.106201,-0.317534,0.200011,0.056857,0.161848,0.162149,0.046978,-0.041519,0.007423,0.218404,0.058436,0.162892,-0.023739,-0.164718,0.31398,0.055259,0.175692,-0.058398,0.110983,-0.086261,0.180361,0.081902,0.020042,-0.105211,0.102362,-0.116001,0.261306,0.042691,-0.112341,-0.005003,-0.023965,-0.142263,0.153943,0.128183


In [92]:
make_embedding_table(layer_outputs, 0)[token_pos] # embeddings after layer 0

token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁basilisk""",4077,-8035.944336,-1055.205933,-8346.657227,-8089.351074,16804.189453,7526.438477,-5263.236816,-7796.367676,-4671.29248,1873.161987,3862.245605,-2264.424561,-7110.523438,-1668.164673,2115.224121,6719.562012,6373.024414,7472.193359,-1562.470093,-5582.696777,7996.283203,-3347.405762,-11387.352539,-780.0495,-7864.013672,-8641.566406,0.595935,4922.300293,-6481.258789,6088.95752,12756.319336,-731.766113,-11026.038086,11890.280273,-1748.631958,…,-2826.578857,9422.588867,-1147.916016,-3556.54248,5853.64502,8550.706055,-4122.742676,-8630.850586,-7715.640625,14521.473633,1111.089233,-11310.09082,-7605.333008,-5211.052246,9644.207031,14045.542969,7104.358398,19551.708984,-13986.811523,-4101.196777,-3429.51709,17604.417969,-9874.216797,804.463379,2772.717285,-2172.591064,7563.640625,9846.825195,2059.871094,2958.040771,2316.457764,12316.712891,2480.868652,-17846.857422,-7590.908203,-9057.871094,-2054.790527
"""▁perforce""",2554,1419.814697,-18540.603516,10474.90918,-6559.92041,22791.869141,339.467712,-4053.29541,2610.476562,-1817.029419,-3074.878662,1078.336426,363.514587,-3135.254639,11504.192383,4334.347168,-144.419357,-4891.397949,-402.38028,7083.713867,15008.277344,8393.455078,3342.667725,-2776.926514,3959.003662,-3252.584229,166.990799,10094.742188,5348.339844,-4039.562256,12630.323242,-235.404999,-3814.471191,-6229.195801,5350.229004,-4011.817871,…,3155.580078,-1397.125854,-3653.310059,-666.091064,19611.5,10068.120117,1341.33606,-3348.138184,6643.76416,-3437.553711,7410.625,-4539.281738,-3783.976318,-5824.100098,-2476.200928,18200.789062,4600.745605,16877.867188,-13757.123047,-10268.798828,3932.779785,17395.619141,-4253.066406,-14552.951172,-3141.924316,9704.746094,2379.150635,5618.531738,1021.155762,9126.204102,2765.288818,9872.863281,6443.302246,-3356.700928,-6683.908691,-6151.64502,-9045.299805
"""▁castle""",2007,-4229.874512,-4178.651855,2772.171387,-4495.574219,3830.035889,-1087.234009,308.352997,5522.677734,-1385.413452,-2118.908447,-879.550415,3726.971924,-971.319763,2783.561523,3248.839844,-2082.979004,-694.346619,-3118.477295,1688.149048,4930.899414,-4049.894775,1257.667236,1074.674072,-3727.5,-3821.339844,5466.620605,425.973175,880.930725,-2031.386353,1805.665649,-2595.820068,-208.628403,6525.966309,-4616.744141,-2637.871094,…,2431.286865,-1537.646484,227.205551,-3241.657227,5108.012695,2320.427979,-1254.968628,-580.961792,3789.618896,-6087.088867,594.146057,-4218.633789,-2959.508545,2370.702148,-482.947021,5505.622559,4956.480957,1721.7229,-4821.256836,-1815.167969,619.161011,5542.36084,-5117.79541,-7176.319336,1956.365723,2828.599854,4885.998047,1678.619873,-17.509712,876.456116,3812.211426,2716.855713,1747.479614,3432.223877,4987.053223,1611.177368,8107.540527


In [93]:
make_embedding_table(layer_outputs, 1)[token_pos] # embeddings after layer 1

token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁basilisk""",4077,-8090.689941,-1049.060547,-8250.289062,-7907.695312,16505.449219,7633.939453,-5194.604004,-7644.915039,-5154.192383,1947.614014,4280.773926,-1995.953369,-7377.384766,-1577.178101,1940.553467,6651.288086,6524.078125,7368.772461,-1557.074341,-5777.452637,8033.206543,-3418.062256,-11337.177734,-921.456482,-7698.394531,-8326.713867,-42.090767,4677.319824,-6359.791504,5956.827637,12759.844727,-904.320557,-10917.849609,12072.748047,-1437.070068,…,-2891.894775,9678.472656,-1231.582764,-3392.508057,5940.005371,8350.737305,-4408.295898,-7984.439941,-7547.034668,14666.262695,1479.968018,-11449.905273,-7564.42041,-5171.051758,9790.69043,13819.654297,6985.188477,19385.935547,-13984.083008,-4087.224121,-3402.923828,17494.533203,-9924.357422,806.221436,2818.673828,-2420.263916,7404.340332,9768.691406,2405.508057,2726.175049,2338.13623,12226.777344,2292.121826,-17749.958984,-7777.648926,-8981.472656,-1740.394653
"""▁perforce""",2554,1425.573242,-18593.212891,10675.508789,-6388.139648,22371.78125,260.21524,-4054.873779,2810.560791,-2409.95752,-2948.710693,1676.280518,499.991882,-3391.416504,11691.126953,4176.353027,-140.985489,-4980.914062,-479.275177,7097.535156,14907.964844,8530.62793,3617.163574,-2722.75293,3950.300049,-2901.143799,432.774353,10136.833984,5144.164551,-3906.593262,12241.334961,-186.447571,-3953.822266,-6129.59668,5468.428711,-3907.325928,…,3105.043945,-1107.121948,-3761.325195,-517.605774,19569.96875,9703.245117,918.431946,-2418.979492,6654.178711,-3405.675293,7890.103027,-4628.306641,-3941.250244,-5964.398438,-2351.066406,18032.939453,4296.489258,16741.302734,-13699.591797,-10244.46875,3900.628174,17189.4375,-4400.144531,-14689.807617,-3102.007324,9368.84082,2239.216553,5583.44873,1132.705566,8885.029297,2737.557617,9693.134766,6216.740234,-3480.526855,-7034.604004,-5980.892578,-8753.493164
"""▁castle""",2007,-4264.820312,-4190.809082,2799.582764,-4432.558594,3698.950928,-1125.484253,238.2043,5531.461914,-1561.833496,-2067.72168,-494.128204,3780.778809,-1028.729736,2841.970215,3245.330322,-2123.682129,-785.888367,-3139.942871,1662.483032,4825.373535,-4012.401855,1341.549683,1084.497314,-3688.043457,-3601.885254,5716.886719,483.930237,740.903442,-2012.512085,1746.150146,-2577.167725,-223.469116,6556.008789,-4566.958008,-2605.872803,…,2416.016846,-1431.868408,182.580627,-3100.230957,5112.613281,2148.290527,-1451.765381,-127.295471,3737.481689,-6132.246582,774.484314,-4150.973633,-3014.709717,2281.033691,-466.091125,5413.194824,4804.76709,1652.966309,-4702.697266,-1833.362183,597.596863,5494.562988,-5194.768066,-7275.532715,1983.31897,2755.190674,4794.411133,1581.282349,19.552288,813.71814,3745.587158,2611.446289,1642.186157,3392.751709,4891.776855,1707.41626,8190.223633


In [94]:
make_embedding_table(layer_outputs, 2)[token_pos] # embeddings after layer 2

token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁basilisk""",4077,-8070.290039,-1028.493896,-8214.024414,-7837.346191,16498.070312,7766.802734,-5200.621094,-7658.48584,-5204.258789,1949.707275,4255.023926,-1913.366577,-7404.898438,-1554.703247,1937.522339,6728.300781,6565.558594,7399.799316,-1538.750122,-5716.833008,8065.176758,-3495.283447,-11384.629883,-944.847473,-7757.356445,-8350.643555,-30.165838,4717.378418,-6393.666504,5944.267578,12793.28125,-954.608948,-10913.177734,11986.567383,-1512.369141,…,-2918.434082,9666.93457,-1351.643799,-3342.330811,5927.51416,8357.999023,-4439.360352,-8026.458496,-7590.655762,14696.250977,1495.405884,-11339.348633,-7600.278809,-5167.266113,9863.395508,13789.158203,6941.896484,19384.492188,-13928.657227,-4106.111328,-3384.457275,17407.095703,-9869.713867,773.058716,2799.060303,-2360.07666,7378.669922,9834.470703,2400.656982,2682.101807,2239.953125,12228.46582,2295.095703,-17763.1875,-7744.414551,-9013.397461,-1726.040894
"""▁perforce""",2554,1470.318481,-18555.761719,10724.299805,-6218.543945,22359.865234,214.654907,-4076.789062,2833.704346,-2399.266846,-2967.190186,1692.329224,512.005127,-3351.577148,11733.655273,4197.35791,-196.398499,-4888.873047,-466.692505,7146.74707,14938.092773,8511.557617,3651.864746,-2642.689453,3959.858398,-2954.027344,490.442749,10126.636719,5146.911133,-3981.001953,12179.05957,-149.686111,-3967.373291,-6069.499512,5392.294922,-3945.479248,…,3026.520264,-1128.601562,-3782.1521,-497.264587,19631.681641,9688.837891,989.947266,-2447.145264,6681.534668,-3420.708496,7840.083496,-4557.578613,-3926.08374,-5878.617676,-2452.619141,18015.123047,4354.765625,16699.839844,-13711.489258,-10353.706055,3947.093018,17194.222656,-4330.950195,-14723.707031,-3133.964355,9296.300781,2273.014648,5574.89209,1270.485962,8894.585938,2742.04834,9705.989258,6239.768555,-3394.700684,-6924.143066,-5995.156738,-8625.885742
"""▁castle""",2007,-4207.611816,-4100.273438,2830.047363,-4436.849121,3674.491699,-1079.389771,284.686279,5619.125,-1613.523438,-2045.514648,-502.34967,3744.408936,-1033.588257,2841.481201,3224.072021,-2103.68457,-761.090332,-3098.492188,1676.06543,4839.396973,-3974.372559,1412.941162,1071.443237,-3678.472168,-3711.19165,5674.092773,516.291321,759.066223,-2030.216309,1755.743652,-2559.612549,-246.60408,6513.316406,-4685.797852,-2612.116211,…,2411.862793,-1476.060547,185.602158,-3058.99292,5149.200684,2136.554688,-1429.229248,-148.563095,3731.929199,-6188.157715,792.300659,-4088.7229,-3027.089111,2301.604248,-413.430573,5378.646484,4819.995605,1679.230225,-4575.657715,-1846.616577,573.680054,5497.958496,-5203.865234,-7340.349121,1902.937744,2808.740479,4818.488281,1554.016602,36.054688,794.379333,3722.911377,2729.655029,1734.713623,3442.875488,4963.378906,1684.400269,8189.865234


In [96]:
make_embedding_table(layer_outputs, 3)[token_pos] # embeddings after layer 3

token,token_id,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,dim_10,dim_11,dim_12,dim_13,dim_14,dim_15,dim_16,dim_17,dim_18,dim_19,dim_20,dim_21,dim_22,dim_23,dim_24,dim_25,dim_26,dim_27,dim_28,dim_29,dim_30,dim_31,dim_32,dim_33,dim_34,…,dim_219,dim_220,dim_221,dim_222,dim_223,dim_224,dim_225,dim_226,dim_227,dim_228,dim_229,dim_230,dim_231,dim_232,dim_233,dim_234,dim_235,dim_236,dim_237,dim_238,dim_239,dim_240,dim_241,dim_242,dim_243,dim_244,dim_245,dim_246,dim_247,dim_248,dim_249,dim_250,dim_251,dim_252,dim_253,dim_254,dim_255
str,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""▁basilisk""",4077,-8037.698242,-1018.565674,-8230.804688,-7826.672852,16481.863281,7720.520996,-5236.611816,-7712.308105,-5205.476074,1969.84436,4174.398438,-1927.168945,-7376.044922,-1534.307739,1912.446411,6736.691406,6525.016113,7386.075195,-1571.496704,-5617.972168,8109.799805,-3532.170166,-11413.563477,-950.401001,-7705.186035,-8404.859375,14.677206,4698.258789,-6429.203613,5900.961426,12783.208984,-983.677307,-10911.8125,11994.977539,-1494.094604,…,-2939.928955,9647.911133,-1315.722534,-3323.108887,6025.381836,8368.456055,-4462.436035,-8065.373535,-7519.395508,14691.996094,1567.065308,-11373.348633,-7559.437988,-5198.962891,9872.358398,13797.352539,6867.449707,19418.044922,-13943.268555,-4050.076904,-3407.007324,17395.478516,-9873.360352,701.184082,2797.47168,-2350.14917,7389.231934,9927.969727,2385.317383,2672.173096,2243.108643,12235.37793,2270.031006,-17702.585938,-7690.429199,-9011.980469,-1667.796265
"""▁perforce""",2554,1452.964111,-18490.601562,10738.131836,-6163.532227,22321.408203,211.638947,-4112.433594,2836.394043,-2471.498291,-2950.714844,1656.74585,564.778931,-3339.733398,11762.249023,4263.707031,-226.372955,-4879.063477,-463.374664,7141.452148,14924.407227,8610.113281,3626.73877,-2634.752197,4002.239502,-2980.763184,513.194092,10248.102539,5178.701172,-3997.085449,12148.621094,-166.476822,-3985.383301,-6068.071777,5350.502441,-3959.361816,…,3028.766357,-1108.85437,-3789.194824,-523.559509,19673.955078,9734.977539,1033.122559,-2370.895752,6682.132812,-3409.525391,7892.365234,-4577.979004,-3877.9021,-5914.366699,-2418.50415,18006.673828,4422.172852,16750.623047,-13700.793945,-10307.902344,3950.171143,17207.132812,-4345.130859,-14687.921875,-3170.124023,9241.889648,2270.379883,5511.869141,1227.351318,8854.583008,2820.811035,9760.508789,6226.98584,-3379.863525,-6839.045898,-6002.064941,-8632.138672
"""▁castle""",2007,-4222.954102,-4041.634033,2851.580322,-4384.849121,3637.675049,-1078.365112,248.416504,5626.44873,-1686.790039,-2031.401489,-531.103821,3795.23584,-1012.407288,2870.943115,3294.438232,-2127.572266,-752.377014,-3088.590332,1672.862915,4825.477539,-3877.63916,1394.791748,1081.958008,-3638.52417,-3744.998047,5699.766602,637.737793,800.289429,-2049.944092,1727.35144,-2580.813232,-260.841125,6517.343262,-4726.513672,-2634.07959,…,2423.514404,-1461.275635,175.575302,-3083.255371,5193.063965,2177.743652,-1383.792847,-81.134354,3739.386963,-6170.160156,844.631042,-4100.592773,-2978.484375,2265.428955,-382.112854,5367.930176,4888.741211,1724.867676,-4568.768555,-1803.158813,576.451355,5509.23291,-5213.72168,-7300.484863,1873.695435,2753.959473,4815.462402,1493.372192,-11.452499,752.805969,3798.715332,2780.115479,1714.953247,3451.946777,5049.181641,1675.677612,8182.962891
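Compare these rows with the same tokens in the previous table. The visible values change only slightly between layers: for example, `dim_0` of `▁basilisk` moves from about -8070.3 to -8037.7. Scanning 256 columns by eye is tedious, so one option is to summarise each token's change as a single number. The sketch below assumes the transformer blocks add a comparatively small update to large-magnitude hidden states through residual connections, as pre-norm blocks like the ones printed above typically do. It compares two layers' states for the same tokens using the relative L2 change and the cosine similarity. The helper name `layer_change` is hypothetical, and the demo uses synthetic tensors rather than the notebook's real `layer_outputs`.

```python
import torch
import torch.nn.functional as F


def layer_change(h_prev: torch.Tensor, h_next: torch.Tensor) -> dict:
    """Hypothetical helper: compare the hidden states of the same tokens
    after two consecutive layers.

    h_prev, h_next: tensors of shape (num_tokens, dim_emb).
    Returns the per-token relative L2 change and cosine similarity.
    """
    # size of the update, relative to the size of the incoming state
    diff_norm = torch.linalg.vector_norm(h_next - h_prev, dim=-1)
    rel_change = diff_norm / torch.linalg.vector_norm(h_prev, dim=-1)
    # how much the direction of each token's vector has rotated
    cos_sim = F.cosine_similarity(h_prev, h_next, dim=-1)
    return {"rel_change": rel_change, "cos_sim": cos_sim}


# Synthetic stand-in: 3 tokens with large-magnitude 256-dim states,
# plus a small additive update, mimicking a residual stream.
torch.manual_seed(0)
h_a = torch.randn(3, 256) * 8000.0
h_b = h_a + torch.randn(3, 256) * 50.0

stats = layer_change(h_a, h_b)
print(stats["rel_change"])  # small, on the order of 50 / 8000
print(stats["cos_sim"])     # very close to 1
```

With the real model, you would pass the two layers' hidden states for the selected token positions instead of the synthetic tensors. How to slice them out of `layer_outputs` depends on how that object was built earlier in the notebook. A relative change near 0 and a cosine similarity near 1 would match what the tables suggest: each layer nudges the token vectors rather than replacing them.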
