In [1]:
import torch
import einops
from components import MLP, gelu, Attention, LayerNorm, TransformerBlock
from transformer_lens import HookedTransformer
import plotly.express as px

In [2]:
x = torch.linspace(-5, 5, 100)
fig = px.line(x=x, y=gelu(x))
fig.show()

In [3]:
tl_model = HookedTransformer.from_pretrained('gelu-2l', device='cpu')

Loaded pretrained model gelu-2l into HookedTransformer


In [4]:
tokenizer = tl_model.tokenizer

In [5]:
tokenizer.vocab_size

48262

In [20]:
class Transformer(torch.nn.Module):
    def __init__(self, d_model, n_layers, tokenizer, max_pos=1024):
        super().__init__()
        self.blocks = [TransformerBlock(d_model)]*n_layers
        self.W_enc = torch.nn.Parameter(
            torch.randn(tokenizer.vocab_size, d_model)
        )
        self.W_dec = torch.nn.Parameter(
            torch.randn(d_model, tokenizer.vocab_size)
        )

        pos_matrix = torch.zeros(max_pos, d_model)
        for d in range(0, d_model, 2):
            period = 1/(10000**(d/d_model))
            i = d//2
            j = d//2 +1
            pos_matrix[:, i] = torch.sin(torch.linspace(0, max_pos-1, max_pos)* period)
            pos_matrix[:, j] = torch.cos(torch.linspace(0, max_pos-1, max_pos)* period)

        self.W_pos = torch.nn.Parameter(
            pos_matrix
        )
    
    def forward(self, tokens):
        x = self.W_enc[tokens]
        print(x.shape)
        for block in self.blocks:
            x = block(x)
        
        return x @ self.W_dec


In [21]:
tokens = tl_model.to_tokens(["hello there", "goodbye friend"])
print(tokens)

tensor([[    1, 24684,   626,     2],
        [    1, 11976, 16560,  3224]])


In [22]:
my_model = Transformer(d_model=512, n_layers=2, tokenizer=tokenizer)


In [23]:
my_model(tokens)

torch.Size([2, 4, 512])


tensor([[[  319.7776,  -205.6019, -1118.2776,  ...,  -234.0956,
            233.4509,   -43.0072],
         [ -439.0345,   -81.4608, -1008.5258,  ...,   465.5396,
            172.4650,   324.3416],
         [  109.3476,  -203.4124,  -619.1780,  ...,    18.4812,
            -35.2011,   123.5763],
         [ -114.0959,   -64.2677,  -724.6772,  ...,   243.4855,
             97.0348,   267.9216]],

        [[  319.7776,  -205.6019, -1118.2776,  ...,  -234.0956,
            233.4509,   -43.0072],
         [  102.6522,  -481.0211,  -947.5989,  ...,   -52.4026,
            -68.5954,    96.0222],
         [ -131.6534,  -899.0856,  -526.6103,  ...,     5.9019,
             76.8638,  -226.3185],
         [ -125.3566,  -932.0342,  -529.3827,  ...,    -9.0342,
             39.4042,  -247.2658]]], grad_fn=<UnsafeViewBackward0>)

In [27]:
px.imshow(my_model.W_pos.detach(), color_continuous_scale="RdBu", color_continuous_midpoint=0)