In [99]:
import os
import sys
import plotly.express as px
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from fancy_einsum import einsum
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import get_corner
import circuitsvis as cv
import math

In [100]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

In [101]:
class Config():
    def __init__(self):
        self.vocab_size = 50257
        self.embedding_dim = 1024
        self.mlp_dim = 4 * self.embedding_dim
        self.num_blocks = 4
        self.num_heads = 8
        self.context_len = 1024
        self.attention_dim = self.embedding_dim // self.num_heads
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16

config = Config()

In [102]:
class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_E= nn.Embedding(config.vocab_size, config.embedding_dim)
        self.W_pos = nn.Embedding(config.context_len, config.embedding_dim)

    def forward(self, tokens):
        tokens = torch.tensor(tokens)
        embeddings = self.W_E(tokens)

        print(f"---------------------------{tokens.shape}")
        positions = self.W_pos(torch.arange(tokens.shape[1]))

        return embeddings + positions

class DeEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_D = nn.Linear(config.embedding_dim, config.vocab_size)

    def forward(self, x):
        embeddings = self.W_D(x)

        return embeddings

In [103]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.W_Q = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.W_K = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.W_V = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        
        self.W_out = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)

        self.num_heads = config.num_heads
        self.attention_dim = config.attention_dim
    
    def forward(self, x):
        Q = einops.rearrange(self.W_Q(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)
        K = einops.rearrange(self.W_K(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)
        V = einops.rearrange(self.W_V(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)

        QK = torch.softmax((Q @ K.transpose(-2, -1))/(torch.sqrt(torch.tensor(self.attention_dim))), dim = -1)

        QKV = einops.rearrange(QK @ V, 'batch head seq dim -> batch seq (head dim)', head=config.num_heads)
        QKV_Out = self.W_out(QKV)

        return QKV_Out

In [104]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer1 = nn.Linear(config.embedding_dim, config.mlp_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(config.mlp_dim, config.embedding_dim)
    
    def forward(self, x):
        x = torch.softmax(self.layer2(self.gelu(self.layer1(x))), dim = -1)

        return x

In [105]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.MLP_Layers = MLP(config)
        self.Attention_Layers = Attention(config)

    def forward(self, x):
        return x + self.MLP_Layers(x + self.Attention_Layers(x))

In [109]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embed = Embedding(config)

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for i in range(config.num_blocks)
        ])

        self.deembed = DeEmbedding(config)

    def forward(self, x):

        x = self.embed(x)
        print(f"after embed: {x}")

        for block in self.blocks:
            x = block(x)
        
        x = self.deembed(x)
        print(f"after deembed: {x}")

        return x

In [114]:
def main():
    text = "The quick brown fox"
    tokens = tokenizer.encode(text)
    print(tokens)

    x = torch.tensor(tokens)
    x = x.unsqueeze(0)

    print(f"---------------{x.shape}")

    model = Transformer(config)

    out = model(x)
    pred_tokens = out.argmax(dim=-1)

    token_list = pred_tokens[0].tolist()
    decoded_text = tokenizer.decode(token_list)

    print(text + decoded_text.split()[-1])

if __name__ == "__main__":
    main()

[464, 2068, 7586, 21831]
---------------torch.Size([1, 4])
---------------------------torch.Size([1, 4])
after embed: tensor([[[ 3.6693,  2.1101,  0.6655,  ..., -1.3162,  0.7272,  2.2976],
         [-0.5886, -1.8766,  0.6572,  ..., -0.9764,  2.0133, -1.8102],
         [ 1.7779,  0.8118, -2.2432,  ..., -1.3450,  1.0963, -2.1877],
         [ 0.0666,  0.6621,  0.1826,  ..., -2.3854,  0.6116,  1.1658]]],
       grad_fn=<AddBackward0>)
after deembed: tensor([[[ 1.1262, -1.5290,  0.6249,  ...,  0.5261, -1.1176, -0.2258],
         [-1.5563,  0.0773,  1.0925,  ..., -0.3508,  0.4950,  0.0473],
         [-0.4413, -0.1252, -0.6700,  ..., -0.2989, -1.1514, -0.5644],
         [ 0.6584,  0.4800, -1.2056,  ..., -0.8050, -0.8617, -0.3772]]],
       grad_fn=<ViewBackward0>)
The quick brown foxrefused


  tokens = torch.tensor(tokens)
