In [2]:
%pip install transformer_lens
%pip install einops
%pip install fancy_einsum

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [30]:
import einops
from fancy_einsum import einsum
from transformer_lens import EasyTransformer
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
import math
from transformer_lens.utils import get_corner, gelu_new
import tqdm.auto as tqdm


In [4]:
# load GPT-2 small model for reference (80M params)
reference_model = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
reference_model.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [34]:
print(reference_model.to_str_tokens("seraph"))
print(reference_model.to_str_tokens("i love coding"))

['<|endoftext|>', 'ser', 'aph']
['<|endoftext|>', 'i', ' love', ' coding']


In [35]:
print(reference_model.to_tokens("seraph")) # text to tokens
print(reference_model.to_tokens("seraph", prepend_bos=False)) # removes initial bos token
print(reference_model.to_tokens("i love coding", prepend_bos=False))

tensor([[50256,  2655,  6570]], device='mps:0')
tensor([[2655, 6570]], device='mps:0')
tensor([[   72,  1842, 19617]], device='mps:0')


In [36]:
reference_text = "i love coding"
tokens = reference_model.to_tokens(reference_text)
print(tokens)
print(tokens.shape) # batch size, sequence length
print(reference_model.to_str_tokens(tokens))

tensor([[50256,    72,  1842, 19617]], device='mps:0')
torch.Size([1, 4])
['<|endoftext|>', 'i', ' love', ' coding']


In [9]:
logits, cache = reference_model.run_with_cache(tokens)
print(logits.shape) # batch size, sequence length, vocab size

torch.Size([1, 8, 50257])


In [38]:
probs = logits.softmax(dim=-1)
print(probs.shape) # batch size, sequence length, vocab size


torch.Size([1, 8, 50257])


In [37]:
# after each token, this is what the model thinks is next
list(zip(reference_model.to_str_tokens(reference_text), reference_model.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'), ('i', 'perial'), (' love', ' be'), (' coding', ' a')]

In [12]:
# get last token logits and get highest probability token
next_token = logits[0, -1].argmax(dim=-1)
print(next_token)

tensor(290, device='mps:0')


In [39]:
next_tokens = torch.cat([tokens, torch.tensor(next_token, dtype=torch.int64)[None, None]], dim=-1)
print("new input (tokens): ", next_tokens)
print("new input (shape): ", next_tokens.shape)
print("new input (decoded): ", reference_model.tokenizer.decode(next_tokens[0]))

new_logits = reference_model(next_tokens)
print(new_logits.shape) # batch size, sequence length, vocab size



  next_tokens = torch.cat([tokens, torch.tensor(next_token, dtype=torch.int64)[None, None]], dim=-1)


new input (tokens):  tensor([[50256,    72,  1842, 19617,   290]], device='mps:0')
new input (shape):  torch.Size([1, 5])
new input (decoded):  <|endoftext|>i love coding and
torch.Size([1, 5, 50257])


---

In [14]:
@dataclass
class Config:
    layer_norm_eps: float = 1e-5
    init_range: float = 0.02
    n_ctx: int = 1024
    debug: bool = True
    d_model: int = 768
    d_vocab: int = 50257
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(layer_norm_eps=1e-05, init_range=0.02, n_ctx=1024, debug=True, d_model=768, d_vocab=50257, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


##### LayerNorm
- make mean 0
- normalize to have variance 1
- scale with learned weights
- translate with learned bias

In [15]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.w = nn.Parameter(torch.ones(cfg.d_model)) # learnable weights -> by default, don't want to change the norm so set to 1
        self.b = nn.Parameter(torch.zeros(cfg.d_model)) # learnable bias -> by default, don't want to change the mean so set to 0

    def forward(self, residual): # takes in a residual vector
        if cfg.debug: print("residual shape: ", residual.shape) # batch, position, d_model
        # Calculate mean across d_model dimension and keep dimensions for broadcasting
        mean = einops.reduce(residual, "batch position d_model -> batch position 1", "mean")
        residual = residual - mean
        
        # calculate the variance (mean of squared values) and square root it; add epsilon (layer_norm_eps) to avoid division by zero
        scale = (einops.reduce(residual.pow(2), "batch position d_model -> batch position 1", "mean") + cfg.layer_norm_eps).sqrt()
        normalized = residual / scale

        # scale and translate
        if cfg.debug: print("normalized shape: ", normalized.shape) # batch, position, d_model
        return self.w * normalized + self.b

##### Embedding

In [19]:
class Embedding(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.W_E = nn.Parameter(torch.empty(cfg.d_vocab, cfg.d_model)) # create empty matrix with shape (d_vocab, d_model)
        nn.init.normal_(self.W_E, std=cfg.init_range) # initialize with normal distribution
    
    def forward(self, tokens): # look up embeddings for each token in the vocabulary
        if cfg.debug: print("tokens shape: ", tokens.shape) # batch, position
        embed = self.W_E[tokens, :] # batch, position, d_model

        if cfg.debug: print("embed shape: ", embed.shape) # batch, position, d_model
        return embed

random_int_test(Embedding, [2, 4])
load_gpt2_test(Embedding, reference_model.embed, tokens)

input shape: torch.Size([2, 4])
tokens shape:  torch.Size([2, 4])
embed shape:  torch.Size([2, 4, 768])
output shape: torch.Size([2, 4, 768])

input shape: torch.Size([1, 8])
tokens shape:  torch.Size([1, 8])
embed shape:  torch.Size([1, 8, 768])
output shape: torch.Size([1, 8, 768])
reference output shape: torch.Size([1, 8, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [-0.0711, -0.0972,  0.0120,  ...,  0.1341, -0.0079, -0.2067],
         [-0.1168, -0.0044,  0.1835,  ..., -0.0532,  0.0435,  0.1913],
         ...,
         [ 0.0432, -0.1860,  0.0228,  ...,  0.0687,  0.0797, -0.0717],
         [ 0.1507, -0.0069,  0.0035,  ..., -0.1548, -0.0378, -0.0776],
         [ 0.1094, -0.1714,  0.1885,  ...,  0.2270,  0.0782,  0.0903]]],
       device='mps:0', grad_fn=<IndexBackward0>)

##### Positional Embedding

In [20]:
class PosEmbedding(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty(cfg.n_ctx, cfg.d_model)) # create empty matrix with shape (n_ctx, d_model)
        nn.init.normal_(self.W_pos, std=cfg.init_range) # initialize with normal distribution
    
    def forward(self, tokens): # look up table for positional encodings
        if cfg.debug: print("tokens shape: ", tokens.shape) # batch, position
        pos_embed = self.W_pos[:tokens.size(1), :] # position, d_model
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))

        if cfg.debug: print("pos_embed shape: ", pos_embed.shape) # batch, position, d_model
        return pos_embed

random_int_test(PosEmbedding, [2, 4])
load_gpt2_test(PosEmbedding, reference_model.pos_embed, tokens)

input shape: torch.Size([2, 4])
tokens shape:  torch.Size([2, 4])
pos_embed shape:  torch.Size([2, 4, 768])
output shape: torch.Size([2, 4, 768])

input shape: torch.Size([1, 8])
tokens shape:  torch.Size([1, 8])
pos_embed shape:  torch.Size([1, 8, 768])
output shape: torch.Size([1, 8, 768])
reference output shape: torch.Size([1, 8, 768])
100.00% of the values are correct


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         ...,
         [ 9.6023e-03, -3.3885e-02,  1.3123e-01,  ...,  5.8940e-03,
           7.1222e-03, -7.4742e-03],
         [ 2.6788e-03, -2.0530e-02,  1.1961e-01,  ...,  2.4907e-03,
           3.7071e-03, -2.5584e-03],
         [ 2.5308e-03, -3.1787e-03,  1.1741e-01,  ...,  2.0096e-03,
           4.4180e-03, -6.8326e-03]]], device='mps:0',
       grad_fn=<ExpandBackward0>)

##### Attention
- produce attention pattern -> for each destination token, probability distribution over previous tokens
    - linear map from input -> query, key
    - dot product every *pair* of queries and keys to get attention scores
    - scale and mask to make it causal
    - softmax row-wise to get a probability distribution
- move information from source tokens to destination tokens (apply a linear map)
    - linear map from input -> value
    - mix along key position with attention pattern to get a mixed value
    - map to output

In [21]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.W_Q = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        nn.init.normal_(self.W_Q, std=cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))

        self.W_K = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        nn.init.normal_(self.W_K, std=cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))

        self.W_V = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        nn.init.normal_(self.W_V, std=cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))

        self.W_O = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        nn.init.normal_(self.W_O, std=cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros(cfg.d_model))

        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        if cfg.debug: print("normalized_resid_pre shape: ", normalized_resid_pre.shape) # batch, position, d_model

        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K

        # attention scores: dot product of q and k -> bilinear form of inputs
        attention_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        attention_scores = attention_scores / math.sqrt(self.cfg.d_head)

        # apply a mask to the attention scores
        attention_scores = self.apply_causal_mask(attention_scores)
        attention = attention_scores.softmax(dim=-1) # [batch, n_heads, query_pos, key_pos]
        
        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V

        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", attention, v)
        attention_output = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        
        return attention_output

    def apply_causal_mask(self, attention_scores):
        # make a upper triangular matrix of ones and set to true (booleans)
        mask = torch.triu(torch.ones(attention_scores.size(-2), attention_scores.size(-1), device=attention_scores.device), diagonal=1).bool()
        attention_scores.masked_fill_(mask, self.IGNORE) # edit in place with IGNORE

        return attention_scores
    
random_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_model.blocks[0].attn, cache["blocks.0.ln1.hook_normalized"])


input shape:  torch.Size([2, 4, 768])
normalized_resid_pre shape:  torch.Size([2, 4, 768])
output shape:  torch.Size([2, 4, 768])

input shape: torch.Size([1, 8, 768])
normalized_resid_pre shape:  torch.Size([1, 8, 768])
output shape: torch.Size([1, 8, 768])
reference output shape: torch.Size([1, 8, 768])
100.00% of the values are correct


tensor([[[ 0.7966,  0.0170,  0.0348,  ...,  0.0331, -0.0231,  0.1810],
         [-0.2554,  0.0237, -0.5880,  ...,  0.0272, -0.0249,  0.1380],
         [ 0.8850, -0.3348, -0.3931,  ...,  0.0183, -0.0218,  0.0931],
         ...,
         [ 0.8398,  0.1124, -0.1826,  ...,  0.0096, -0.0141,  0.0318],
         [ 0.3103,  0.4264, -1.1696,  ...,  0.0325,  0.0137,  0.0291],
         [-2.0696,  0.2373, -0.3783,  ...,  0.0313, -0.0261,  0.0652]]],
       device='mps:0', grad_fn=<AddBackward0>)

##### MLP

In [22]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.W_in = nn.Parameter(torch.empty(cfg.d_model, cfg.d_mlp))
        nn.init.normal_(self.W_in, std=cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros(cfg.d_mlp))

        self.W_out = nn.Parameter(torch.empty(cfg.d_mlp, cfg.d_model))
        nn.init.normal_(self.W_out, std=cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, normalized_resid_mid):
        if cfg.debug: print("normalized_resid_mid shape: ", normalized_resid_mid.shape) # batch, position, d_model

        # preactivation
        pre = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        post = gelu_new(pre)
        
        mlp_out = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", post, self.W_out) + self.b_out
        return mlp_out

random_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_model.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

input shape:  torch.Size([2, 4, 768])
normalized_resid_mid shape:  torch.Size([2, 4, 768])
output shape:  torch.Size([2, 4, 768])

input shape: torch.Size([1, 8, 768])
normalized_resid_mid shape:  torch.Size([1, 8, 768])
output shape: torch.Size([1, 8, 768])
reference output shape: torch.Size([1, 8, 768])
100.00% of the values are correct


tensor([[[-0.4380,  0.3624,  0.5117,  ...,  1.7227,  1.5761,  0.0368],
         [-0.8610, -1.0628, -1.0772,  ...,  0.8934, -0.5213,  1.6064],
         [-1.2453, -0.9915,  0.1291,  ..., -0.4437, -0.6902, -1.8190],
         ...,
         [ 0.3088,  1.7319, -0.0375,  ..., -0.5291, -0.4050,  1.6018],
         [-0.6507,  1.8183, -0.6009,  ...,  0.1165,  0.5283, -0.0727],
         [-1.3528,  0.7457, -0.6612,  ...,  0.1451,  0.1886,  2.9218]]],
       device='mps:0', grad_fn=<AddBackward0>)

##### Transformer Block

In [23]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        normalized_resid_pre = self.ln1(resid_pre)
        attention_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attention_out
        
        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post

random_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_model.blocks[0], cache["resid_pre", 0])

input shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_pre shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_mid shape:  torch.Size([2, 4, 768])
output shape:  torch.Size([2, 4, 768])

input shape: torch.Size([1, 8, 768])
residual shape:  torch.Size([1, 8, 768])
normalized shape:  torch.Size([1, 8, 768])
normalized_resid_pre shape:  torch.Size([1, 8, 768])
residual shape:  torch.Size([1, 8, 768])
normalized shape:  torch.Size([1, 8, 768])
normalized_resid_mid shape:  torch.Size([1, 8, 768])
output shape: torch.Size([1, 8, 768])
reference output shape: torch.Size([1, 8, 768])
100.00% of the values are correct


tensor([[[ 3.9112e-01,  1.5426e-01,  6.0045e-01,  ...,  1.7198e+00,
           1.7365e+00,  3.9297e-01],
         [-1.1635e+00, -1.1901e+00, -1.7481e+00,  ...,  1.0889e+00,
          -5.4396e-01,  1.5375e+00],
         [-4.7287e-01, -1.4154e+00, -2.5976e-02,  ..., -4.5882e-01,
          -6.4922e-01, -1.5560e+00],
         ...,
         [ 1.2015e+00,  1.6244e+00, -6.6016e-02,  ..., -4.4496e-01,
          -3.3224e-01,  1.5545e+00],
         [-1.8694e-01,  2.2173e+00, -1.6475e+00,  ..., -3.2857e-03,
           5.0783e-01, -1.2378e-01],
         [-3.3104e+00,  8.0838e-01, -7.3359e-01,  ...,  4.0542e-01,
           2.4506e-01,  3.0705e+00]]], device='mps:0', grad_fn=<AddBackward0>)

##### Unembedding

In [24]:
class Unembedding(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.W_U = nn.Parameter(torch.empty(cfg.d_model, cfg.d_vocab))
        nn.init.normal_(self.W_U, std=cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab), requires_grad=False)

    def forward(self, normalized_resid_post):
        if cfg.debug: print("normalized_resid_post shape: ", normalized_resid_post.shape) # batch, position, d_model

        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_post, self.W_U) + self.b_U
        return logits
        

##### Full Transformer

In [25]:
class Transformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.embed = Embedding(cfg)
        self.pos_embed = PosEmbedding(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembedding(cfg)

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens) # [batch, position, d_model]
        pos_embed = self.pos_embed(tokens) # [batch, position, d_model]

        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)

        return logits

random_int_test(Transformer, [2, 4])
load_gpt2_test(Transformer, reference_model, tokens)

            

input shape: torch.Size([2, 4])
tokens shape:  torch.Size([2, 4])
embed shape:  torch.Size([2, 4, 768])
tokens shape:  torch.Size([2, 4])
pos_embed shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_pre shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_mid shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_pre shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_mid shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_pre shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
normalized_resid_mid shape:  torch.Size([2, 4, 768]

tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0877,  -54.3452,
           -42.3645],
         [ -70.0298,  -67.5770,  -69.7691,  ...,  -78.2607,  -74.7699,
           -68.8778],
         [-131.2627, -131.6137, -135.4736,  ..., -141.5668, -139.9645,
          -131.9789],
         ...,
         [ -81.9859,  -81.4428,  -84.2107,  ...,  -88.7520,  -90.9199,
           -81.9112],
         [ -93.7111,  -91.2787,  -94.5029,  ..., -105.2199, -104.9473,
           -92.7546],
         [ -91.5967,  -92.8907,  -93.9723,  ..., -102.4186, -102.8982,
           -91.5232]]], device='mps:0', grad_fn=<AddBackward0>)

In [26]:
nanogpt = Transformer(Config(debug=False))
nanogpt.load_state_dict(reference_model.state_dict(), strict=False)
nanogpt = nanogpt.to(next(reference_model.parameters()).device)

In [27]:
test_input = "Y Combinator (YC) initiated operations with concurrent programs in Cambridge, Massachusetts, and Mountain View, California. However, operational complexities arising from managing two programs prompted a consolidation in January 2009, resulting in the closing of the Cambridge program and the centralization of activities in Silicon Valley."
test_tokens = reference_model.to_tokens(test_input)
test_logits = nanogpt(test_tokens)

tokens shape:  torch.Size([1, 58])
embed shape:  torch.Size([1, 58, 768])
tokens shape:  torch.Size([1, 58])
pos_embed shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_pre shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_mid shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_pre shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_mid shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_pre shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_mid shape:  torch.Size([1, 58, 768])
residual

In [28]:
def lm_cross_entropy_loss(logits, tokens):
    # measure next token loss
    # logits: [batch, seq_len, vocab_size]
    # tokens: [batch, seq_len]
    # for each position i, we predict token i+1
    # so we use logits at position i to predict token at position i+1
    
    log_probs = logits.log_softmax(dim=-1)
    
    # Ensure sequence lengths match
    seq_len = min(logits.shape[1], tokens.shape[1])
    
    # logits[:, :-1] predicts tokens at positions 1 to seq_len-1
    # tokens[:, 1:] are the actual tokens at positions 1 to seq_len-1
    pred_log_probs = log_probs[:, :seq_len-1].gather(
        dim=-1, 
        index=tokens[:, 1:seq_len].unsqueeze(-1)).squeeze(-1)

    return -pred_log_probs.mean()

lm_cross_entropy_loss(test_logits, test_tokens)

tensor(3.5866, device='mps:0', grad_fn=<NegBackward0>)

In [32]:
test_string = "Y Combinator (YC) initiated operations with concurrent programs in Cambridge, Massachusetts, and Mountain View, California. However, operational complexities arising from managing two programs prompted a consolidation in January 2009, resulting in the closing of the Cambridge program and the centralization of activities in Silicon Valley."
for i in tqdm.tqdm(range(100)):
    test_tokens = reference_model.to_tokens(test_string)
    demo_logits = nanogpt(test_tokens)
    test_string += reference_model.tokenizer.decode(demo_logits[-1, -1].argmax())
print(test_string)


  0%|          | 0/100 [00:00<?, ?it/s]

tokens shape:  torch.Size([1, 58])
embed shape:  torch.Size([1, 58, 768])
tokens shape:  torch.Size([1, 58])
pos_embed shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_pre shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_mid shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_pre shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_mid shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_pre shape:  torch.Size([1, 58, 768])
residual shape:  torch.Size([1, 58, 768])
normalized shape:  torch.Size([1, 58, 768])
normalized_resid_mid shape:  torch.Size([1, 58, 768])
residual

  4%|▍         | 4/100 [00:00<00:08, 11.43it/s]

residual shape:  torch.Size([1, 59, 768])
normalized shape:  torch.Size([1, 59, 768])
normalized_resid_mid shape:  torch.Size([1, 59, 768])
residual shape:  torch.Size([1, 59, 768])
normalized shape:  torch.Size([1, 59, 768])
normalized_resid_pre shape:  torch.Size([1, 59, 768])
residual shape:  torch.Size([1, 59, 768])
normalized shape:  torch.Size([1, 59, 768])
normalized_resid_mid shape:  torch.Size([1, 59, 768])
residual shape:  torch.Size([1, 59, 768])
normalized shape:  torch.Size([1, 59, 768])
normalized_resid_pre shape:  torch.Size([1, 59, 768])
residual shape:  torch.Size([1, 59, 768])
normalized shape:  torch.Size([1, 59, 768])
normalized_resid_mid shape:  torch.Size([1, 59, 768])
residual shape:  torch.Size([1, 59, 768])
normalized shape:  torch.Size([1, 59, 768])
normalized_resid_post shape:  torch.Size([1, 59, 768])
tokens shape:  torch.Size([1, 59])
embed shape:  torch.Size([1, 59, 768])
tokens shape:  torch.Size([1, 59])
pos_embed shape:  torch.Size([1, 59, 768])
residua

  6%|▌         | 6/100 [00:00<00:08, 10.66it/s]

residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_mid shape:  torch.Size([1, 61, 768])
residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_pre shape:  torch.Size([1, 61, 768])
residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_mid shape:  torch.Size([1, 61, 768])
residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_pre shape:  torch.Size([1, 61, 768])
residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_mid shape:  torch.Size([1, 61, 768])
residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_pre shape:  torch.Size([1, 61, 768])
residual shape:  torch.Size([1, 61, 768])
normalized shape:  torch.Size([1, 61, 768])
normalized_resid_mid shape:  torch.Size([1, 61, 768])
residual shape:  tor

  8%|▊         | 8/100 [00:00<00:11,  8.22it/s]

residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_pre shape:  torch.Size([1, 64, 768])
residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_mid shape:  torch.Size([1, 64, 768])
residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_pre shape:  torch.Size([1, 64, 768])
residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_mid shape:  torch.Size([1, 64, 768])
residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_pre shape:  torch.Size([1, 64, 768])
residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_mid shape:  torch.Size([1, 64, 768])
residual shape:  torch.Size([1, 64, 768])
normalized shape:  torch.Size([1, 64, 768])
normalized_resid_pre shape:  torch.Size([1, 64, 768])
residual shape:  tor

  9%|▉         | 9/100 [00:01<00:14,  6.26it/s]

residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_mid shape:  torch.Size([1, 65, 768])
residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_pre shape:  torch.Size([1, 65, 768])
residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_mid shape:  torch.Size([1, 65, 768])
residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_pre shape:  torch.Size([1, 65, 768])
residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_mid shape:  torch.Size([1, 65, 768])
residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_pre shape:  torch.Size([1, 65, 768])
residual shape:  torch.Size([1, 65, 768])
normalized shape:  torch.Size([1, 65, 768])
normalized_resid_mid shape:  torch.Size([1, 65, 768])
residual shape:  tor

 11%|█         | 11/100 [00:01<00:11,  7.88it/s]

residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_mid shape:  torch.Size([1, 66, 768])
residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_pre shape:  torch.Size([1, 66, 768])
residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_mid shape:  torch.Size([1, 66, 768])
residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_pre shape:  torch.Size([1, 66, 768])
residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_mid shape:  torch.Size([1, 66, 768])
residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_pre shape:  torch.Size([1, 66, 768])
residual shape:  torch.Size([1, 66, 768])
normalized shape:  torch.Size([1, 66, 768])
normalized_resid_mid shape:  torch.Size([1, 66, 768])
residual shape:  tor

 13%|█▎        | 13/100 [00:01<00:09,  8.93it/s]

residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_mid shape:  torch.Size([1, 69, 768])
residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_pre shape:  torch.Size([1, 69, 768])
residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_mid shape:  torch.Size([1, 69, 768])
residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_pre shape:  torch.Size([1, 69, 768])
residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_mid shape:  torch.Size([1, 69, 768])
residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_pre shape:  torch.Size([1, 69, 768])
residual shape:  torch.Size([1, 69, 768])
normalized shape:  torch.Size([1, 69, 768])
normalized_resid_mid shape:  torch.Size([1, 69, 768])
residual shape:  tor

 15%|█▌        | 15/100 [00:01<00:08,  9.76it/s]

residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_mid shape:  torch.Size([1, 71, 768])
residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_pre shape:  torch.Size([1, 71, 768])
residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_mid shape:  torch.Size([1, 71, 768])
residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_pre shape:  torch.Size([1, 71, 768])
residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_mid shape:  torch.Size([1, 71, 768])
residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_pre shape:  torch.Size([1, 71, 768])
residual shape:  torch.Size([1, 71, 768])
normalized shape:  torch.Size([1, 71, 768])
normalized_resid_mid shape:  torch.Size([1, 71, 768])
residual shape:  tor

 17%|█▋        | 17/100 [00:01<00:08, 10.10it/s]

tokens shape:  torch.Size([1, 74])
embed shape:  torch.Size([1, 74, 768])
tokens shape:  torch.Size([1, 74])
pos_embed shape:  torch.Size([1, 74, 768])
residual shape:  torch.Size([1, 74, 768])
normalized shape:  torch.Size([1, 74, 768])
normalized_resid_pre shape:  torch.Size([1, 74, 768])
residual shape:  torch.Size([1, 74, 768])
normalized shape:  torch.Size([1, 74, 768])
normalized_resid_mid shape:  torch.Size([1, 74, 768])
residual shape:  torch.Size([1, 74, 768])
normalized shape:  torch.Size([1, 74, 768])
normalized_resid_pre shape:  torch.Size([1, 74, 768])
residual shape:  torch.Size([1, 74, 768])
normalized shape:  torch.Size([1, 74, 768])
normalized_resid_mid shape:  torch.Size([1, 74, 768])
residual shape:  torch.Size([1, 74, 768])
normalized shape:  torch.Size([1, 74, 768])
normalized_resid_pre shape:  torch.Size([1, 74, 768])
residual shape:  torch.Size([1, 74, 768])
normalized shape:  torch.Size([1, 74, 768])
normalized_resid_mid shape:  torch.Size([1, 74, 768])
residual

 21%|██        | 21/100 [00:02<00:07,  9.93it/s]

tokens shape:  torch.Size([1, 76])
embed shape:  torch.Size([1, 76, 768])
tokens shape:  torch.Size([1, 76])
pos_embed shape:  torch.Size([1, 76, 768])
residual shape:  torch.Size([1, 76, 768])
normalized shape:  torch.Size([1, 76, 768])
normalized_resid_pre shape:  torch.Size([1, 76, 768])
residual shape:  torch.Size([1, 76, 768])
normalized shape:  torch.Size([1, 76, 768])
normalized_resid_mid shape:  torch.Size([1, 76, 768])
residual shape:  torch.Size([1, 76, 768])
normalized shape:  torch.Size([1, 76, 768])
normalized_resid_pre shape:  torch.Size([1, 76, 768])
residual shape:  torch.Size([1, 76, 768])
normalized shape:  torch.Size([1, 76, 768])
normalized_resid_mid shape:  torch.Size([1, 76, 768])
residual shape:  torch.Size([1, 76, 768])
normalized shape:  torch.Size([1, 76, 768])
normalized_resid_pre shape:  torch.Size([1, 76, 768])
residual shape:  torch.Size([1, 76, 768])
normalized shape:  torch.Size([1, 76, 768])
normalized_resid_mid shape:  torch.Size([1, 76, 768])
residual

 23%|██▎       | 23/100 [00:02<00:09,  8.54it/s]

tokens shape:  torch.Size([1, 79])
embed shape:  torch.Size([1, 79, 768])
tokens shape:  torch.Size([1, 79])
pos_embed shape:  torch.Size([1, 79, 768])
residual shape:  torch.Size([1, 79, 768])
normalized shape:  torch.Size([1, 79, 768])
normalized_resid_pre shape:  torch.Size([1, 79, 768])
residual shape:  torch.Size([1, 79, 768])
normalized shape:  torch.Size([1, 79, 768])
normalized_resid_mid shape:  torch.Size([1, 79, 768])
residual shape:  torch.Size([1, 79, 768])
normalized shape:  torch.Size([1, 79, 768])
normalized_resid_pre shape:  torch.Size([1, 79, 768])
residual shape:  torch.Size([1, 79, 768])
normalized shape:  torch.Size([1, 79, 768])
normalized_resid_mid shape:  torch.Size([1, 79, 768])
residual shape:  torch.Size([1, 79, 768])
normalized shape:  torch.Size([1, 79, 768])
normalized_resid_pre shape:  torch.Size([1, 79, 768])
residual shape:  torch.Size([1, 79, 768])
normalized shape:  torch.Size([1, 79, 768])
normalized_resid_mid shape:  torch.Size([1, 79, 768])
residual

 24%|██▍       | 24/100 [00:02<00:10,  7.25it/s]

residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_mid shape:  torch.Size([1, 80, 768])
residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_pre shape:  torch.Size([1, 80, 768])
residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_mid shape:  torch.Size([1, 80, 768])
residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_pre shape:  torch.Size([1, 80, 768])
residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_mid shape:  torch.Size([1, 80, 768])
residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_pre shape:  torch.Size([1, 80, 768])
residual shape:  torch.Size([1, 80, 768])
normalized shape:  torch.Size([1, 80, 768])
normalized_resid_mid shape:  torch.Size([1, 80, 768])
residual shape:  tor

 25%|██▌       | 25/100 [00:03<00:11,  6.77it/s]

residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_mid shape:  torch.Size([1, 81, 768])
residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_pre shape:  torch.Size([1, 81, 768])
residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_mid shape:  torch.Size([1, 81, 768])
residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_pre shape:  torch.Size([1, 81, 768])
residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_mid shape:  torch.Size([1, 81, 768])
residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_pre shape:  torch.Size([1, 81, 768])
residual shape:  torch.Size([1, 81, 768])
normalized shape:  torch.Size([1, 81, 768])
normalized_resid_mid shape:  torch.Size([1, 81, 768])
residual shape:  tor

 28%|██▊       | 28/100 [00:03<00:09,  7.85it/s]

residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_mid shape:  torch.Size([1, 83, 768])
residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_pre shape:  torch.Size([1, 83, 768])
residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_mid shape:  torch.Size([1, 83, 768])
residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_pre shape:  torch.Size([1, 83, 768])
residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_mid shape:  torch.Size([1, 83, 768])
residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_pre shape:  torch.Size([1, 83, 768])
residual shape:  torch.Size([1, 83, 768])
normalized shape:  torch.Size([1, 83, 768])
normalized_resid_mid shape:  torch.Size([1, 83, 768])
residual shape:  tor

 30%|███       | 30/100 [00:03<00:08,  8.74it/s]

residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_mid shape:  torch.Size([1, 85, 768])
residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_pre shape:  torch.Size([1, 85, 768])
residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_mid shape:  torch.Size([1, 85, 768])
residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_pre shape:  torch.Size([1, 85, 768])
residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_mid shape:  torch.Size([1, 85, 768])
residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_pre shape:  torch.Size([1, 85, 768])
residual shape:  torch.Size([1, 85, 768])
normalized shape:  torch.Size([1, 85, 768])
normalized_resid_mid shape:  torch.Size([1, 85, 768])
residual shape:  tor

 31%|███       | 31/100 [00:03<00:08,  7.71it/s]

residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_mid shape:  torch.Size([1, 87, 768])
residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_pre shape:  torch.Size([1, 87, 768])
residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_mid shape:  torch.Size([1, 87, 768])
residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_pre shape:  torch.Size([1, 87, 768])
residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_mid shape:  torch.Size([1, 87, 768])
residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_pre shape:  torch.Size([1, 87, 768])
residual shape:  torch.Size([1, 87, 768])
normalized shape:  torch.Size([1, 87, 768])
normalized_resid_mid shape:  torch.Size([1, 87, 768])
residual shape:  tor

 33%|███▎      | 33/100 [00:03<00:08,  8.37it/s]

residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_mid shape:  torch.Size([1, 88, 768])
residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_pre shape:  torch.Size([1, 88, 768])
residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_mid shape:  torch.Size([1, 88, 768])
residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_pre shape:  torch.Size([1, 88, 768])
residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_mid shape:  torch.Size([1, 88, 768])
residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_pre shape:  torch.Size([1, 88, 768])
residual shape:  torch.Size([1, 88, 768])
normalized shape:  torch.Size([1, 88, 768])
normalized_resid_mid shape:  torch.Size([1, 88, 768])
residual shape:  tor

 35%|███▌      | 35/100 [00:04<00:06,  9.57it/s]

normalized shape:  torch.Size([1, 90, 768])
normalized_resid_mid shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  torch.Size([1, 90, 768])
normalized_resid_pre shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  torch.Size([1, 90, 768])
normalized_resid_mid shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  torch.Size([1, 90, 768])
normalized_resid_pre shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  torch.Size([1, 90, 768])
normalized_resid_mid shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  torch.Size([1, 90, 768])
normalized_resid_pre shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  torch.Size([1, 90, 768])
normalized_resid_mid shape:  torch.Size([1, 90, 768])
residual shape:  torch.Size([1, 90, 768])
normalized shape:  t

 37%|███▋      | 37/100 [00:04<00:05, 10.75it/s]

normalized shape:  torch.Size([1, 93, 768])
normalized_resid_pre shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  torch.Size([1, 93, 768])
normalized_resid_mid shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  torch.Size([1, 93, 768])
normalized_resid_pre shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  torch.Size([1, 93, 768])
normalized_resid_mid shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  torch.Size([1, 93, 768])
normalized_resid_pre shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  torch.Size([1, 93, 768])
normalized_resid_mid shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  torch.Size([1, 93, 768])
normalized_resid_pre shape:  torch.Size([1, 93, 768])
residual shape:  torch.Size([1, 93, 768])
normalized shape:  t

 39%|███▉      | 39/100 [00:04<00:05, 11.61it/s]

tokens shape:  torch.Size([1, 96])
embed shape:  torch.Size([1, 96, 768])
tokens shape:  torch.Size([1, 96])
pos_embed shape:  torch.Size([1, 96, 768])
residual shape:  torch.Size([1, 96, 768])
normalized shape:  torch.Size([1, 96, 768])
normalized_resid_pre shape:  torch.Size([1, 96, 768])
residual shape:  torch.Size([1, 96, 768])
normalized shape:  torch.Size([1, 96, 768])
normalized_resid_mid shape:  torch.Size([1, 96, 768])
residual shape:  torch.Size([1, 96, 768])
normalized shape:  torch.Size([1, 96, 768])
normalized_resid_pre shape:  torch.Size([1, 96, 768])
residual shape:  torch.Size([1, 96, 768])
normalized shape:  torch.Size([1, 96, 768])
normalized_resid_mid shape:  torch.Size([1, 96, 768])
residual shape:  torch.Size([1, 96, 768])
normalized shape:  torch.Size([1, 96, 768])
normalized_resid_pre shape:  torch.Size([1, 96, 768])
residual shape:  torch.Size([1, 96, 768])
normalized shape:  torch.Size([1, 96, 768])
normalized_resid_mid shape:  torch.Size([1, 96, 768])
residual

 41%|████      | 41/100 [00:04<00:07,  8.29it/s]

residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_mid shape:  torch.Size([1, 97, 768])
residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_pre shape:  torch.Size([1, 97, 768])
residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_mid shape:  torch.Size([1, 97, 768])
residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_pre shape:  torch.Size([1, 97, 768])
residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_mid shape:  torch.Size([1, 97, 768])
residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_pre shape:  torch.Size([1, 97, 768])
residual shape:  torch.Size([1, 97, 768])
normalized shape:  torch.Size([1, 97, 768])
normalized_resid_mid shape:  torch.Size([1, 97, 768])
residual shape:  tor

 43%|████▎     | 43/100 [00:04<00:06,  9.08it/s]

tokens shape:  torch.Size([1, 99])
embed shape:  torch.Size([1, 99, 768])
tokens shape:  torch.Size([1, 99])
pos_embed shape:  torch.Size([1, 99, 768])
residual shape:  torch.Size([1, 99, 768])
normalized shape:  torch.Size([1, 99, 768])
normalized_resid_pre shape:  torch.Size([1, 99, 768])
residual shape:  torch.Size([1, 99, 768])
normalized shape:  torch.Size([1, 99, 768])
normalized_resid_mid shape:  torch.Size([1, 99, 768])
residual shape:  torch.Size([1, 99, 768])
normalized shape:  torch.Size([1, 99, 768])
normalized_resid_pre shape:  torch.Size([1, 99, 768])
residual shape:  torch.Size([1, 99, 768])
normalized shape:  torch.Size([1, 99, 768])
normalized_resid_mid shape:  torch.Size([1, 99, 768])
residual shape:  torch.Size([1, 99, 768])
normalized shape:  torch.Size([1, 99, 768])
normalized_resid_pre shape:  torch.Size([1, 99, 768])
residual shape:  torch.Size([1, 99, 768])
normalized shape:  torch.Size([1, 99, 768])
normalized_resid_mid shape:  torch.Size([1, 99, 768])
residual

 45%|████▌     | 45/100 [00:05<00:05,  9.78it/s]

residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_mid shape:  torch.Size([1, 101, 768])
residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_pre shape:  torch.Size([1, 101, 768])
residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_mid shape:  torch.Size([1, 101, 768])
residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_pre shape:  torch.Size([1, 101, 768])
residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_mid shape:  torch.Size([1, 101, 768])
residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_pre shape:  torch.Size([1, 101, 768])
residual shape:  torch.Size([1, 101, 768])
normalized shape:  torch.Size([1, 101, 768])
normalized_resid_mid shape:  torch.Size([1, 101, 768])

 49%|████▉     | 49/100 [00:05<00:05,  9.08it/s]

tokens shape:  torch.Size([1, 104])
embed shape:  torch.Size([1, 104, 768])
tokens shape:  torch.Size([1, 104])
pos_embed shape:  torch.Size([1, 104, 768])
residual shape:  torch.Size([1, 104, 768])
normalized shape:  torch.Size([1, 104, 768])
normalized_resid_pre shape:  torch.Size([1, 104, 768])
residual shape:  torch.Size([1, 104, 768])
normalized shape:  torch.Size([1, 104, 768])
normalized_resid_mid shape:  torch.Size([1, 104, 768])
residual shape:  torch.Size([1, 104, 768])
normalized shape:  torch.Size([1, 104, 768])
normalized_resid_pre shape:  torch.Size([1, 104, 768])
residual shape:  torch.Size([1, 104, 768])
normalized shape:  torch.Size([1, 104, 768])
normalized_resid_mid shape:  torch.Size([1, 104, 768])
residual shape:  torch.Size([1, 104, 768])
normalized shape:  torch.Size([1, 104, 768])
normalized_resid_pre shape:  torch.Size([1, 104, 768])
residual shape:  torch.Size([1, 104, 768])
normalized shape:  torch.Size([1, 104, 768])
normalized_resid_mid shape:  torch.Size([

 51%|█████     | 51/100 [00:05<00:05,  8.59it/s]

normalized shape:  torch.Size([1, 107, 768])
normalized_resid_pre shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])
normalized shape:  torch.Size([1, 107, 768])
normalized_resid_mid shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])
normalized shape:  torch.Size([1, 107, 768])
normalized_resid_pre shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])
normalized shape:  torch.Size([1, 107, 768])
normalized_resid_mid shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])
normalized shape:  torch.Size([1, 107, 768])
normalized_resid_pre shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])
normalized shape:  torch.Size([1, 107, 768])
normalized_resid_mid shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])
normalized shape:  torch.Size([1, 107, 768])
normalized_resid_pre shape:  torch.Size([1, 107, 768])
residual shape:  torch.Size([1, 107, 768])

 52%|█████▏    | 52/100 [00:05<00:05,  8.78it/s]

residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_mid shape:  torch.Size([1, 108, 768])
residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_pre shape:  torch.Size([1, 108, 768])
residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_mid shape:  torch.Size([1, 108, 768])
residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_pre shape:  torch.Size([1, 108, 768])
residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_mid shape:  torch.Size([1, 108, 768])
residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_pre shape:  torch.Size([1, 108, 768])
residual shape:  torch.Size([1, 108, 768])
normalized shape:  torch.Size([1, 108, 768])
normalized_resid_mid shape:  torch.Size([1, 108, 768])

 54%|█████▍    | 54/100 [00:06<00:04,  9.50it/s]

residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_mid shape:  torch.Size([1, 110, 768])
residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_pre shape:  torch.Size([1, 110, 768])
residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_mid shape:  torch.Size([1, 110, 768])
residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_pre shape:  torch.Size([1, 110, 768])
residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_mid shape:  torch.Size([1, 110, 768])
residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_pre shape:  torch.Size([1, 110, 768])
residual shape:  torch.Size([1, 110, 768])
normalized shape:  torch.Size([1, 110, 768])
normalized_resid_mid shape:  torch.Size([1, 110, 768])

 56%|█████▌    | 56/100 [00:06<00:05,  7.56it/s]

residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_mid shape:  torch.Size([1, 112, 768])
residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_pre shape:  torch.Size([1, 112, 768])
residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_mid shape:  torch.Size([1, 112, 768])
residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_pre shape:  torch.Size([1, 112, 768])
residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_mid shape:  torch.Size([1, 112, 768])
residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_pre shape:  torch.Size([1, 112, 768])
residual shape:  torch.Size([1, 112, 768])
normalized shape:  torch.Size([1, 112, 768])
normalized_resid_mid shape:  torch.Size([1, 112, 768])

 57%|█████▋    | 57/100 [00:06<00:06,  6.47it/s]

residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_mid shape:  torch.Size([1, 113, 768])
residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_pre shape:  torch.Size([1, 113, 768])
residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_mid shape:  torch.Size([1, 113, 768])
residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_pre shape:  torch.Size([1, 113, 768])
residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_mid shape:  torch.Size([1, 113, 768])
residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_pre shape:  torch.Size([1, 113, 768])
residual shape:  torch.Size([1, 113, 768])
normalized shape:  torch.Size([1, 113, 768])
normalized_resid_mid shape:  torch.Size([1, 113, 768])

 59%|█████▉    | 59/100 [00:06<00:05,  7.46it/s]

residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_mid shape:  torch.Size([1, 114, 768])
residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_pre shape:  torch.Size([1, 114, 768])
residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_mid shape:  torch.Size([1, 114, 768])
residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_pre shape:  torch.Size([1, 114, 768])
residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_mid shape:  torch.Size([1, 114, 768])
residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_pre shape:  torch.Size([1, 114, 768])
residual shape:  torch.Size([1, 114, 768])
normalized shape:  torch.Size([1, 114, 768])
normalized_resid_mid shape:  torch.Size([1, 114, 768])

 61%|██████    | 61/100 [00:07<00:04,  8.03it/s]

residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_mid shape:  torch.Size([1, 116, 768])
residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_pre shape:  torch.Size([1, 116, 768])
residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_mid shape:  torch.Size([1, 116, 768])
residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_pre shape:  torch.Size([1, 116, 768])
residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_mid shape:  torch.Size([1, 116, 768])
residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_pre shape:  torch.Size([1, 116, 768])
residual shape:  torch.Size([1, 116, 768])
normalized shape:  torch.Size([1, 116, 768])
normalized_resid_mid shape:  torch.Size([1, 116, 768])

 63%|██████▎   | 63/100 [00:07<00:04,  8.90it/s]

residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_mid shape:  torch.Size([1, 118, 768])
residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_pre shape:  torch.Size([1, 118, 768])
residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_mid shape:  torch.Size([1, 118, 768])
residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_pre shape:  torch.Size([1, 118, 768])
residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_mid shape:  torch.Size([1, 118, 768])
residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_pre shape:  torch.Size([1, 118, 768])
residual shape:  torch.Size([1, 118, 768])
normalized shape:  torch.Size([1, 118, 768])
normalized_resid_mid shape:  torch.Size([1, 118, 768])

 65%|██████▌   | 65/100 [00:07<00:03,  9.62it/s]

residual shape:  torch.Size([1, 120, 768])
normalized shape:  torch.Size([1, 120, 768])
normalized_resid_mid shape:  torch.Size([1, 120, 768])
residual shape:  torch.Size([1, 120, 768])
normalized shape:  torch.Size([1, 120, 768])
normalized_resid_pre shape:  torch.Size([1, 120, 768])
residual shape:  torch.Size([1, 120, 768])
normalized shape:  torch.Size([1, 120, 768])
normalized_resid_mid shape:  torch.Size([1, 120, 768])
residual shape:  torch.Size([1, 120, 768])
normalized shape:  torch.Size([1, 120, 768])
normalized_resid_pre shape:  torch.Size([1, 120, 768])
residual shape:  torch.Size([1, 120, 768])
normalized shape:  torch.Size([1, 120, 768])
normalized_resid_mid shape:  torch.Size([1, 120, 768])
residual shape:  torch.Size([1, 120, 768])
normalized shape:  torch.Size([1, 120, 768])
normalized_resid_post shape:  torch.Size([1, 120, 768])
tokens shape:  torch.Size([1, 121])
embed shape:  torch.Size([1, 121, 768])
tokens shape:  torch.Size([1, 121])
pos_embed shape:  torch.Size(

 67%|██████▋   | 67/100 [00:07<00:03, 10.07it/s]

tokens shape:  torch.Size([1, 123])
embed shape:  torch.Size([1, 123, 768])
tokens shape:  torch.Size([1, 123])
pos_embed shape:  torch.Size([1, 123, 768])
residual shape:  torch.Size([1, 123, 768])
normalized shape:  torch.Size([1, 123, 768])
normalized_resid_pre shape:  torch.Size([1, 123, 768])
residual shape:  torch.Size([1, 123, 768])
normalized shape:  torch.Size([1, 123, 768])
normalized_resid_mid shape:  torch.Size([1, 123, 768])
residual shape:  torch.Size([1, 123, 768])
normalized shape:  torch.Size([1, 123, 768])
normalized_resid_pre shape:  torch.Size([1, 123, 768])
residual shape:  torch.Size([1, 123, 768])
normalized shape:  torch.Size([1, 123, 768])
normalized_resid_mid shape:  torch.Size([1, 123, 768])
residual shape:  torch.Size([1, 123, 768])
normalized shape:  torch.Size([1, 123, 768])
normalized_resid_pre shape:  torch.Size([1, 123, 768])
residual shape:  torch.Size([1, 123, 768])
normalized shape:  torch.Size([1, 123, 768])
normalized_resid_mid shape:  torch.Size([

 69%|██████▉   | 69/100 [00:07<00:02, 10.57it/s]

residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_mid shape:  torch.Size([1, 125, 768])
residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_pre shape:  torch.Size([1, 125, 768])
residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_mid shape:  torch.Size([1, 125, 768])
residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_pre shape:  torch.Size([1, 125, 768])
residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_mid shape:  torch.Size([1, 125, 768])
residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_pre shape:  torch.Size([1, 125, 768])
residual shape:  torch.Size([1, 125, 768])
normalized shape:  torch.Size([1, 125, 768])
normalized_resid_mid shape:  torch.Size([1, 125, 768])

 71%|███████   | 71/100 [00:08<00:04,  5.91it/s]

tokens shape:  torch.Size([1, 127])
embed shape:  torch.Size([1, 127, 768])
tokens shape:  torch.Size([1, 127])
pos_embed shape:  torch.Size([1, 127, 768])
residual shape:  torch.Size([1, 127, 768])
normalized shape:  torch.Size([1, 127, 768])
normalized_resid_pre shape:  torch.Size([1, 127, 768])
residual shape:  torch.Size([1, 127, 768])
normalized shape:  torch.Size([1, 127, 768])
normalized_resid_mid shape:  torch.Size([1, 127, 768])
residual shape:  torch.Size([1, 127, 768])
normalized shape:  torch.Size([1, 127, 768])
normalized_resid_pre shape:  torch.Size([1, 127, 768])
residual shape:  torch.Size([1, 127, 768])
normalized shape:  torch.Size([1, 127, 768])
normalized_resid_mid shape:  torch.Size([1, 127, 768])
residual shape:  torch.Size([1, 127, 768])
normalized shape:  torch.Size([1, 127, 768])
normalized_resid_pre shape:  torch.Size([1, 127, 768])
residual shape:  torch.Size([1, 127, 768])
normalized shape:  torch.Size([1, 127, 768])
normalized_resid_mid shape:  torch.Size([

 72%|███████▏  | 72/100 [00:08<00:04,  5.86it/s]

residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_pre shape:  torch.Size([1, 128, 768])
residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_mid shape:  torch.Size([1, 128, 768])
residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_pre shape:  torch.Size([1, 128, 768])
residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_mid shape:  torch.Size([1, 128, 768])
residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_pre shape:  torch.Size([1, 128, 768])
residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_mid shape:  torch.Size([1, 128, 768])
residual shape:  torch.Size([1, 128, 768])
normalized shape:  torch.Size([1, 128, 768])
normalized_resid_pre shape:  torch.Size([1, 128, 768])

 74%|███████▍  | 74/100 [00:09<00:04,  6.29it/s]

residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_mid shape:  torch.Size([1, 129, 768])
residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_pre shape:  torch.Size([1, 129, 768])
residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_mid shape:  torch.Size([1, 129, 768])
residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_pre shape:  torch.Size([1, 129, 768])
residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_mid shape:  torch.Size([1, 129, 768])
residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_pre shape:  torch.Size([1, 129, 768])
residual shape:  torch.Size([1, 129, 768])
normalized shape:  torch.Size([1, 129, 768])
normalized_resid_mid shape:  torch.Size([1, 129, 768])

 76%|███████▌  | 76/100 [00:09<00:03,  7.61it/s]

residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_pre shape:  torch.Size([1, 131, 768])
residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_mid shape:  torch.Size([1, 131, 768])
residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_pre shape:  torch.Size([1, 131, 768])
residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_mid shape:  torch.Size([1, 131, 768])
residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_pre shape:  torch.Size([1, 131, 768])
residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_mid shape:  torch.Size([1, 131, 768])
residual shape:  torch.Size([1, 131, 768])
normalized shape:  torch.Size([1, 131, 768])
normalized_resid_pre shape:  torch.Size([1, 131, 768])

 78%|███████▊  | 78/100 [00:09<00:02,  8.66it/s]

residual shape:  torch.Size([1, 133, 768])
normalized shape:  torch.Size([1, 133, 768])
normalized_resid_mid shape:  torch.Size([1, 133, 768])
residual shape:  torch.Size([1, 133, 768])
normalized shape:  torch.Size([1, 133, 768])
normalized_resid_pre shape:  torch.Size([1, 133, 768])
residual shape:  torch.Size([1, 133, 768])
normalized shape:  torch.Size([1, 133, 768])
normalized_resid_mid shape:  torch.Size([1, 133, 768])
residual shape:  torch.Size([1, 133, 768])
normalized shape:  torch.Size([1, 133, 768])
normalized_resid_pre shape:  torch.Size([1, 133, 768])
residual shape:  torch.Size([1, 133, 768])
normalized shape:  torch.Size([1, 133, 768])
normalized_resid_mid shape:  torch.Size([1, 133, 768])
residual shape:  torch.Size([1, 133, 768])
normalized shape:  torch.Size([1, 133, 768])
normalized_resid_post shape:  torch.Size([1, 133, 768])
tokens shape:  torch.Size([1, 134])
embed shape:  torch.Size([1, 134, 768])
tokens shape:  torch.Size([1, 134])
pos_embed shape:  torch.Size(

 80%|████████  | 80/100 [00:09<00:02,  8.84it/s]

tokens shape:  torch.Size([1, 136])
embed shape:  torch.Size([1, 136, 768])
tokens shape:  torch.Size([1, 136])
pos_embed shape:  torch.Size([1, 136, 768])
residual shape:  torch.Size([1, 136, 768])
normalized shape:  torch.Size([1, 136, 768])
normalized_resid_pre shape:  torch.Size([1, 136, 768])
residual shape:  torch.Size([1, 136, 768])
normalized shape:  torch.Size([1, 136, 768])
normalized_resid_mid shape:  torch.Size([1, 136, 768])
residual shape:  torch.Size([1, 136, 768])
normalized shape:  torch.Size([1, 136, 768])
normalized_resid_pre shape:  torch.Size([1, 136, 768])
residual shape:  torch.Size([1, 136, 768])
normalized shape:  torch.Size([1, 136, 768])
normalized_resid_mid shape:  torch.Size([1, 136, 768])
residual shape:  torch.Size([1, 136, 768])
normalized shape:  torch.Size([1, 136, 768])
normalized_resid_pre shape:  torch.Size([1, 136, 768])
residual shape:  torch.Size([1, 136, 768])
normalized shape:  torch.Size([1, 136, 768])
normalized_resid_mid shape:  torch.Size([

 82%|████████▏ | 82/100 [00:09<00:01,  9.24it/s]

tokens shape:  torch.Size([1, 138])
embed shape:  torch.Size([1, 138, 768])
tokens shape:  torch.Size([1, 138])
pos_embed shape:  torch.Size([1, 138, 768])
residual shape:  torch.Size([1, 138, 768])
normalized shape:  torch.Size([1, 138, 768])
normalized_resid_pre shape:  torch.Size([1, 138, 768])
residual shape:  torch.Size([1, 138, 768])
normalized shape:  torch.Size([1, 138, 768])
normalized_resid_mid shape:  torch.Size([1, 138, 768])
residual shape:  torch.Size([1, 138, 768])
normalized shape:  torch.Size([1, 138, 768])
normalized_resid_pre shape:  torch.Size([1, 138, 768])
residual shape:  torch.Size([1, 138, 768])
normalized shape:  torch.Size([1, 138, 768])
normalized_resid_mid shape:  torch.Size([1, 138, 768])
residual shape:  torch.Size([1, 138, 768])
normalized shape:  torch.Size([1, 138, 768])
normalized_resid_pre shape:  torch.Size([1, 138, 768])
residual shape:  torch.Size([1, 138, 768])
normalized shape:  torch.Size([1, 138, 768])
normalized_resid_mid shape:  torch.Size([

 84%|████████▍ | 84/100 [00:10<00:01,  9.49it/s]

normalized shape:  torch.Size([1, 140, 768])
normalized_resid_pre shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])
normalized shape:  torch.Size([1, 140, 768])
normalized_resid_mid shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])
normalized shape:  torch.Size([1, 140, 768])
normalized_resid_pre shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])
normalized shape:  torch.Size([1, 140, 768])
normalized_resid_mid shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])
normalized shape:  torch.Size([1, 140, 768])
normalized_resid_pre shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])
normalized shape:  torch.Size([1, 140, 768])
normalized_resid_mid shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])
normalized shape:  torch.Size([1, 140, 768])
normalized_resid_pre shape:  torch.Size([1, 140, 768])
residual shape:  torch.Size([1, 140, 768])

 86%|████████▌ | 86/100 [00:10<00:01,  9.93it/s]

residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_mid shape:  torch.Size([1, 142, 768])
residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_pre shape:  torch.Size([1, 142, 768])
residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_mid shape:  torch.Size([1, 142, 768])
residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_pre shape:  torch.Size([1, 142, 768])
residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_mid shape:  torch.Size([1, 142, 768])
residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_pre shape:  torch.Size([1, 142, 768])
residual shape:  torch.Size([1, 142, 768])
normalized shape:  torch.Size([1, 142, 768])
normalized_resid_mid shape:  torch.Size([1, 142, 768])

 88%|████████▊ | 88/100 [00:10<00:01, 10.51it/s]

residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_mid shape:  torch.Size([1, 144, 768])
residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_pre shape:  torch.Size([1, 144, 768])
residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_mid shape:  torch.Size([1, 144, 768])
residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_pre shape:  torch.Size([1, 144, 768])
residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_mid shape:  torch.Size([1, 144, 768])
residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_pre shape:  torch.Size([1, 144, 768])
residual shape:  torch.Size([1, 144, 768])
normalized shape:  torch.Size([1, 144, 768])
normalized_resid_mid shape:  torch.Size([1, 144, 768])

 90%|█████████ | 90/100 [00:10<00:01,  7.08it/s]

tokens shape:  torch.Size([1, 147])
embed shape:  torch.Size([1, 147, 768])
tokens shape:  torch.Size([1, 147])
pos_embed shape:  torch.Size([1, 147, 768])
residual shape:  torch.Size([1, 147, 768])
normalized shape:  torch.Size([1, 147, 768])
normalized_resid_pre shape:  torch.Size([1, 147, 768])
residual shape:  torch.Size([1, 147, 768])
normalized shape:  torch.Size([1, 147, 768])
normalized_resid_mid shape:  torch.Size([1, 147, 768])
residual shape:  torch.Size([1, 147, 768])
normalized shape:  torch.Size([1, 147, 768])
normalized_resid_pre shape:  torch.Size([1, 147, 768])
residual shape:  torch.Size([1, 147, 768])
normalized shape:  torch.Size([1, 147, 768])
normalized_resid_mid shape:  torch.Size([1, 147, 768])
residual shape:  torch.Size([1, 147, 768])
normalized shape:  torch.Size([1, 147, 768])
normalized_resid_pre shape:  torch.Size([1, 147, 768])
residual shape:  torch.Size([1, 147, 768])
normalized shape:  torch.Size([1, 147, 768])
normalized_resid_mid shape:  torch.Size([

 92%|█████████▏| 92/100 [00:11<00:01,  6.77it/s]

residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_mid shape:  torch.Size([1, 148, 768])
residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_pre shape:  torch.Size([1, 148, 768])
residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_mid shape:  torch.Size([1, 148, 768])
residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_pre shape:  torch.Size([1, 148, 768])
residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_mid shape:  torch.Size([1, 148, 768])
residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_pre shape:  torch.Size([1, 148, 768])
residual shape:  torch.Size([1, 148, 768])
normalized shape:  torch.Size([1, 148, 768])
normalized_resid_mid shape:  torch.Size([1, 148, 768])

 93%|█████████▎| 93/100 [00:11<00:01,  6.58it/s]

tokens shape:  torch.Size([1, 150])
embed shape:  torch.Size([1, 150, 768])
tokens shape:  torch.Size([1, 150])
pos_embed shape:  torch.Size([1, 150, 768])
residual shape:  torch.Size([1, 150, 768])
normalized shape:  torch.Size([1, 150, 768])
normalized_resid_pre shape:  torch.Size([1, 150, 768])
residual shape:  torch.Size([1, 150, 768])
normalized shape:  torch.Size([1, 150, 768])
normalized_resid_mid shape:  torch.Size([1, 150, 768])
residual shape:  torch.Size([1, 150, 768])
normalized shape:  torch.Size([1, 150, 768])
normalized_resid_pre shape:  torch.Size([1, 150, 768])
residual shape:  torch.Size([1, 150, 768])
normalized shape:  torch.Size([1, 150, 768])
normalized_resid_mid shape:  torch.Size([1, 150, 768])
residual shape:  torch.Size([1, 150, 768])
normalized shape:  torch.Size([1, 150, 768])
normalized_resid_pre shape:  torch.Size([1, 150, 768])
residual shape:  torch.Size([1, 150, 768])
normalized shape:  torch.Size([1, 150, 768])
normalized_resid_mid shape:  torch.Size([

 95%|█████████▌| 95/100 [00:11<00:00,  6.43it/s]

tokens shape:  torch.Size([1, 151])
embed shape:  torch.Size([1, 151, 768])
tokens shape:  torch.Size([1, 151])
pos_embed shape:  torch.Size([1, 151, 768])
residual shape:  torch.Size([1, 151, 768])
normalized shape:  torch.Size([1, 151, 768])
normalized_resid_pre shape:  torch.Size([1, 151, 768])
residual shape:  torch.Size([1, 151, 768])
normalized shape:  torch.Size([1, 151, 768])
normalized_resid_mid shape:  torch.Size([1, 151, 768])
residual shape:  torch.Size([1, 151, 768])
normalized shape:  torch.Size([1, 151, 768])
normalized_resid_pre shape:  torch.Size([1, 151, 768])
residual shape:  torch.Size([1, 151, 768])
normalized shape:  torch.Size([1, 151, 768])
normalized_resid_mid shape:  torch.Size([1, 151, 768])
residual shape:  torch.Size([1, 151, 768])
normalized shape:  torch.Size([1, 151, 768])
normalized_resid_pre shape:  torch.Size([1, 151, 768])
residual shape:  torch.Size([1, 151, 768])
normalized shape:  torch.Size([1, 151, 768])
normalized_resid_mid shape:  torch.Size([

 96%|█████████▌| 96/100 [00:11<00:00,  7.01it/s]

tokens shape:  torch.Size([1, 153])
embed shape:  torch.Size([1, 153, 768])
tokens shape:  torch.Size([1, 153])
pos_embed shape:  torch.Size([1, 153, 768])
residual shape:  torch.Size([1, 153, 768])
normalized shape:  torch.Size([1, 153, 768])
normalized_resid_pre shape:  torch.Size([1, 153, 768])
residual shape:  torch.Size([1, 153, 768])
normalized shape:  torch.Size([1, 153, 768])
normalized_resid_mid shape:  torch.Size([1, 153, 768])
residual shape:  torch.Size([1, 153, 768])
normalized shape:  torch.Size([1, 153, 768])
normalized_resid_pre shape:  torch.Size([1, 153, 768])
residual shape:  torch.Size([1, 153, 768])
normalized shape:  torch.Size([1, 153, 768])
normalized_resid_mid shape:  torch.Size([1, 153, 768])
residual shape:  torch.Size([1, 153, 768])
normalized shape:  torch.Size([1, 153, 768])
normalized_resid_pre shape:  torch.Size([1, 153, 768])
residual shape:  torch.Size([1, 153, 768])
normalized shape:  torch.Size([1, 153, 768])
normalized_resid_mid shape:  torch.Size([

 98%|█████████▊| 98/100 [00:12<00:00,  6.63it/s]

tokens shape:  torch.Size([1, 154])
embed shape:  torch.Size([1, 154, 768])
tokens shape:  torch.Size([1, 154])
pos_embed shape:  torch.Size([1, 154, 768])
residual shape:  torch.Size([1, 154, 768])
normalized shape:  torch.Size([1, 154, 768])
normalized_resid_pre shape:  torch.Size([1, 154, 768])
residual shape:  torch.Size([1, 154, 768])
normalized shape:  torch.Size([1, 154, 768])
normalized_resid_mid shape:  torch.Size([1, 154, 768])
residual shape:  torch.Size([1, 154, 768])
normalized shape:  torch.Size([1, 154, 768])
normalized_resid_pre shape:  torch.Size([1, 154, 768])
residual shape:  torch.Size([1, 154, 768])
normalized shape:  torch.Size([1, 154, 768])
normalized_resid_mid shape:  torch.Size([1, 154, 768])
residual shape:  torch.Size([1, 154, 768])
normalized shape:  torch.Size([1, 154, 768])
normalized_resid_pre shape:  torch.Size([1, 154, 768])
residual shape:  torch.Size([1, 154, 768])
normalized shape:  torch.Size([1, 154, 768])
normalized_resid_mid shape:  torch.Size([

100%|██████████| 100/100 [00:12<00:00,  8.10it/s]

tokens shape:  torch.Size([1, 156])
embed shape:  torch.Size([1, 156, 768])
tokens shape:  torch.Size([1, 156])
pos_embed shape:  torch.Size([1, 156, 768])
residual shape:  torch.Size([1, 156, 768])
normalized shape:  torch.Size([1, 156, 768])
normalized_resid_pre shape:  torch.Size([1, 156, 768])
residual shape:  torch.Size([1, 156, 768])
normalized shape:  torch.Size([1, 156, 768])
normalized_resid_mid shape:  torch.Size([1, 156, 768])
residual shape:  torch.Size([1, 156, 768])
normalized shape:  torch.Size([1, 156, 768])
normalized_resid_pre shape:  torch.Size([1, 156, 768])
residual shape:  torch.Size([1, 156, 768])
normalized shape:  torch.Size([1, 156, 768])
normalized_resid_mid shape:  torch.Size([1, 156, 768])
residual shape:  torch.Size([1, 156, 768])
normalized shape:  torch.Size([1, 156, 768])
normalized_resid_pre shape:  torch.Size([1, 156, 768])
residual shape:  torch.Size([1, 156, 768])
normalized shape:  torch.Size([1, 156, 768])
normalized_resid_mid shape:  torch.Size([




##### Tests

In [17]:
def random_float_test(class_name, shape):
    cfg = Config(debug=True)
    layer = class_name(cfg)

    random_input = torch.randn(shape) # standard gaussian
    print("input shape: ", random_input.shape)

    output = layer(random_input)
    print("output shape: ", output.shape)

    print()
    return output

def random_int_test(class_name, shape):
    cfg = Config(debug=True)
    layer = class_name(cfg)

    random_input = torch.randint(100, 1000, shape)
    print("input shape:", random_input.shape)

    output = layer(random_input)
    print("output shape:", output.shape)
    print()

    return output

def load_gpt2_test(class_name, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = class_name(cfg)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    
    # allow inputs of strings or tensors
    if isinstance(input_name, str): 
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    # move layer to the same device as the input
    device = reference_input.device
    layer = layer.to(device)
    print("input shape:", reference_input.shape)
    output = layer(reference_input)
    print("output shape:", output.shape)

    # gpt2 attention needs 3 inputs
    if class_name.__name__ == "Attention":
        reference_output = gpt2_layer(reference_input, reference_input, reference_input)
    else:
        reference_output = gpt2_layer(reference_input)
    print("reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

In [18]:
random_float_test(LayerNorm, [2, 4, 768])

input shape:  torch.Size([2, 4, 768])
residual shape:  torch.Size([2, 4, 768])
normalized shape:  torch.Size([2, 4, 768])
output shape:  torch.Size([2, 4, 768])



tensor([[[-0.4505,  0.0815, -0.5587,  ...,  1.7149, -0.7023,  0.4755],
         [-1.4051,  0.3624, -0.8785,  ..., -1.1124, -0.3760, -0.8861],
         [ 0.3491,  0.6039,  0.4444,  ...,  0.8410, -1.2368,  1.1708],
         [ 0.1494,  0.1321,  0.0041,  ..., -1.3641, -0.4140,  1.0903]],

        [[-0.8565,  0.7507,  1.0751,  ...,  0.0194,  0.6733,  0.2951],
         [ 0.4180,  0.7619,  1.8514,  ...,  0.9386,  0.4737,  3.2410],
         [-0.6813, -0.8014,  0.6049,  ...,  0.2056,  0.5626, -1.0239],
         [ 0.1895,  0.4521,  0.0161,  ...,  0.1490,  0.6581, -1.5012]]],
       grad_fn=<AddBackward0>)