In [2]:
%pip install einops
%pip install jaxtyping
%pip install git+https://github.com/samarth-bhargav/TransformerLens.git

import einops
from dataclasses import dataclass
# from transformer_lens import HookedTransformer
# from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from jaxtyping import Float, Int
from pathlib import Path
from typing import Tuple, List, Optional, Dict, Callable
from torch.cuda.amp import GradScaler, autocast

from tqdm import tqdm

import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens.utils import gelu_new, tokenize_and_concatenate

import sys
sys.path.append("/content")
import plot

device = t.device("cuda" if t.cuda.is_available() else "cpu")

Collecting git+https://github.com/samarth-bhargav/TransformerLens.git
  Cloning https://github.com/samarth-bhargav/TransformerLens.git to /tmp/pip-req-build-4mp0ur34
  Running command git clone --filter=blob:none --quiet https://github.com/samarth-bhargav/TransformerLens.git /tmp/pip-req-build-4mp0ur34
  Resolved https://github.com/samarth-bhargav/TransformerLens.git to commit c43b244d815ae69c468a3acbcb9e52b6a2db8f6d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Using cached beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens==0.0.0)
  Using cached better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Using cached datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting fancy-ei

RuntimeError: THPDtypeType.tp_dict == nullptr INTERNAL ASSERT FAILED at "../torch/csrc/Dtype.cpp":176, please report a bug to PyTorch. 

In [2]:
cfg = HookedTransformerConfig(
    d_model = 128,
    n_layers = 2,
    n_heads = 12,
    d_head = 64,
    d_mlp = 3072,
    n_ctx = 32,
    d_vocab = 40,
    act_fn = 'gelu_new'
)

cfg.init_range = 0.02
cfg.layer_norm_eps = 1e-5

In [3]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b

In [4]:
class Embed(nn.Module):
    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, vecs: Float[Tensor, "batch position d_vocab"]) -> Float[Tensor, "batch position d_model"]:
        # SOLUTION
        return einops.einsum(vecs, self.W_E, "batch position d_vocab, d_vocab d_model -> batch position d_model")

transformer_lens.components.embed.Embed = Embed

In [5]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, vecs: Float[Tensor, "batch position d_vocab"]) -> Float[Tensor, "batch position d_model"]:
        # SOLUTION
        batch, seq_len, _ = vecs.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)

In [6]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        # Calculate query, key and value vectors
        q = einops.einsum(
            normalized_resid_pre, self.W_Q,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + self.b_Q
        k = einops.einsum(
            normalized_resid_pre, self.W_K,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + self.b_K
        v = einops.einsum(
            normalized_resid_pre, self.W_V,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + self.b_V

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q, k,
            "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K",
        )
        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_pattern = attn_scores_masked.softmax(-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v, attn_pattern,
            "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head",
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = einops.einsum(
            z, self.W_O,
            "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model",
        ) + self.b_O

        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        # SOLUTION
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

In [7]:
class MLP(nn.Module):
    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        pre = einops.einsum(
            normalized_resid_mid, self.W_in,
            "batch position d_model, d_model d_mlp -> batch position d_mlp",
        ) + self.b_in
        post = gelu_new(pre)
        mlp_out = einops.einsum(
            post, self.W_out,
            "batch position d_mlp, d_mlp d_model -> batch position d_model",
        ) + self.b_out
        return mlp_out

In [8]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        # SOLUTION
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post

In [9]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        # SOLUTION
        return einops.einsum(
            normalized_resid_final, self.W_U,
            "batch posn d_model, d_model d_vocab -> batch posn d_vocab",
        ) + self.b_U
        # Or, could just do `normalized_resid_final @ self.W_U + self.b_U`

In [10]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Float[Tensor, "batch position d_vocab"]) -> Float[Tensor, "batch position d_vocab"]:
        # SOLUTION
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        prediction = self.unembed(self.ln_final(residual))
        return prediction

In [11]:
@dataclass
class TransformerTrainingArgs():
	batch_size = 16
	lr = 1e-4
	weight_decay = 1e-2
	wandb_project: Optional[str] = "day1-demotransformer"
	wandb_name: Optional[str] = None


args = TransformerTrainingArgs()

In [12]:
def rand_range(low, high, shape):
  return t.rand(shape) * (high - low) + low

def generate_linear_recurrences(batch_size, vector_dim=40, compl=1, length=30, param_bds = (-3, 3), return_type="both"):

    assert compl < length
    assert param_bds[0] <= param_bds[1]

    params = rand_range(1, 1, (batch_size, compl)).to(device)
    consts = rand_range(param_bds[0], param_bds[1], (batch_size, vector_dim)).to(device)

    recurrences = t.empty((batch_size, length, vector_dim)).to(device)

    recurrences[:, :compl] = rand_range(-3, 3, (batch_size, 1, vector_dim))

    for j in range(compl, length):
        recurrences[:, j] = consts + einops.einsum(params, recurrences[:, j-compl:j], "batch compl, batch compl vector_dim -> batch vector_dim")

    #find max norm in each batch and divide each batch by that max norm
    max_norms, _ = t.max(t.norm(recurrences, dim=2, keepdim=True), dim=1, keepdim=True)
    caps = (t.rand((batch_size, 1, 1)) * (2 - 1) + 1).to(device)
    recurrences = recurrences / (max_norms / caps)

    max_norms = max_norms.squeeze(1)

    if return_type == 'seq':
        return recurrences
    elif return_type == 'both':
        return recurrences, (params, consts / (max_norms / caps))

    assert False

generate_linear_recurrences(1)

(tensor([[[-1.3898e-02, -8.4231e-04, -9.0464e-03,  ..., -2.1358e-03,
           -1.0883e-02, -1.4151e-02],
          [ 2.5793e-04, -4.3577e-03, -1.1750e-02,  ...,  3.9989e-03,
           -2.2661e-02, -7.6783e-03],
          [ 1.4413e-02, -7.8730e-03, -1.4453e-02,  ...,  1.0134e-02,
           -3.4439e-02, -1.2055e-03],
          ...,
          [ 3.6830e-01, -9.5757e-02, -8.2036e-02,  ...,  1.6350e-01,
           -3.2889e-01,  1.6061e-01],
          [ 3.8246e-01, -9.9272e-02, -8.4740e-02,  ...,  1.6963e-01,
           -3.4067e-01,  1.6709e-01],
          [ 3.9661e-01, -1.0279e-01, -8.7443e-02,  ...,  1.7577e-01,
           -3.5244e-01,  1.7356e-01]]]),
 (tensor([[1.]]),
  tensor([[[ 0.0142, -0.0035, -0.0027,  0.0039,  0.0131,  0.0121,  0.0053,
            -0.0135,  0.0112, -0.0007, -0.0072,  0.0139, -0.0005,  0.0107,
             0.0046, -0.0039, -0.0056, -0.0103, -0.0106, -0.0035,  0.0124,
            -0.0104,  0.0127, -0.0007,  0.0047, -0.0063,  0.0071, -0.0036,
             0.0038,  

In [15]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.loss = nn.MSELoss()
        self.step = 0


    def training_step(self, batch: Float[Tensor, "batch seq d_vocab"]) -> Float[Tensor, ""]:
        '''
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        '''
        # SOLUTION
        pred = self.model(batch)
        loss = self.loss(pred[:,:-1,:], batch[:,1:,:])
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        return loss


    def train(self, steps=100_000):
        '''
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialization, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        '''
        # Initialize Weights & Biases logging if needed
        # wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)

        loss_history = []
        with tqdm(total=steps, desc="Training Progress", ascii=True) as pbar:
            for i in range(steps):
                # Simulating the generation of batches with varying lengths
                # num_terms = t.randint(4, 10, (1,)).item() if i > steps / 5 else int((10 - 4) * (i / (steps / 5)) + 4)
                num_terms = t.randint(4, 30, (1,)).item()
                batch = generate_linear_recurrences(self.args.batch_size, length=num_terms, return_type='seq')
                loss = self.training_step(batch)

                # Update tqdm progress bar with loss information
                pbar.update(1)
                pbar.set_postfix_str(f"Step: {i+1}, Loss: {loss.item():.7f}")
                if (i % 500 == 0):
                    loss_history.append(loss.item())

                # Optionally log metrics to Weights & Biases
                # wandb.log({"loss": loss.item()}, step=i)
        return loss_history

    # Clean up Weights & Biases session after training is complete
    # wandb.finish()



In [16]:
model = DemoTransformer(cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
loss_history = trainer.train()

Training Progress:   0%|          | 376/100000 [00:22<1:37:58, 16.95it/s, Step: 376, Loss: 0.0014355]


KeyboardInterrupt: 

In [None]:
loss_history[-1]

8.04710725788027e-06

In [None]:
vecs = generate_linear_recurrences(1, length=30, return_type='seq')
vecs = vecs.to(device)
print(vecs[:, 1:, :])
print(model(vecs)[:, :-1, :])
print(trainer.loss(model(vecs)[:, :-1, :], vecs[:, 1:, :]))

tensor([[[-1.2819e-02,  1.1896e-02, -9.3517e-03,  ...,  1.2094e-02,
           1.6456e-02,  1.1728e-04],
         [-2.1908e-02,  1.2884e-02, -1.4077e-02,  ...,  1.4355e-02,
           2.5902e-02, -1.0502e-02],
         [-3.0997e-02,  1.3873e-02, -1.8802e-02,  ...,  1.6616e-02,
           3.5348e-02, -2.1122e-02],
         ...,
         [-2.4913e-01,  3.7596e-02, -1.3221e-01,  ...,  7.0884e-02,
           2.6206e-01, -2.7599e-01],
         [-2.5822e-01,  3.8585e-02, -1.3693e-01,  ...,  7.3145e-02,
           2.7151e-01, -2.8661e-01],
         [-2.6731e-01,  3.9573e-02, -1.4166e-01,  ...,  7.5406e-02,
           2.8095e-01, -2.9723e-01]]], device='cuda:0')
tensor([[[-0.0056,  0.0110, -0.0052,  ...,  0.0115,  0.0067,  0.0101],
         [-0.0226,  0.0127, -0.0155,  ...,  0.0156,  0.0264, -0.0135],
         [-0.0321,  0.0134, -0.0199,  ...,  0.0190,  0.0358, -0.0231],
         ...,
         [-0.2475,  0.0350, -0.1288,  ...,  0.0727,  0.2608, -0.2825],
         [-0.2565,  0.0360, -0.1334,  .

In [None]:
hooked_model = HookedTransformer(cfg)
print(hooked_model.state_dict().keys() - model.state_dict().keys())
hooked_model.load_and_process_state_dict(state_dict=model.state_dict(),
                                         fold_ln=False,
                                         center_writing_weights=False,
                                         center_unembed=False,
                                         fold_value_biases=False,
                                         refactor_factored_attn_matrices=False)

#quick check lol
t.testing.assert_close(hooked_model.embed.W_E, model.embed.W_E)

{'blocks.0.attn.mask', 'blocks.3.attn.mask', 'blocks.2.attn.mask', 'blocks.1.attn.mask'}


In [None]:
vecs = generate_linear_recurrences(1, length=30, return_type='seq')
print(vecs[:, 1:, :])
result = hooked_model(vecs)
print(result[:, :-1, :])
print(trainer.loss(result[:, :-1, :], vecs[:, 1:, :]))

t.testing.assert_close(result[:, :-1, :], model(vecs)[:, :-1, :])
t.testing.assert_close(vecs[:, 1:, :], result[:, :-1, :])

tensor([[[-9.4164e-03,  3.1504e-04,  2.8005e-02,  ..., -7.6020e-03,
          -1.5232e-02,  1.1235e-02],
         [-3.0781e-03, -1.2558e-02,  4.0473e-02,  ..., -5.5979e-05,
          -2.1835e-02,  1.8238e-02],
         [ 3.2603e-03, -2.5432e-02,  5.2940e-02,  ...,  7.4901e-03,
          -2.8438e-02,  2.5242e-02],
         ...,
         [ 1.5538e-01, -3.3439e-01,  3.5216e-01,  ...,  1.8860e-01,
          -1.8691e-01,  1.9333e-01],
         [ 1.6172e-01, -3.4727e-01,  3.6462e-01,  ...,  1.9614e-01,
          -1.9351e-01,  2.0034e-01],
         [ 1.6806e-01, -3.6014e-01,  3.7709e-01,  ...,  2.0369e-01,
          -2.0012e-01,  2.0734e-01]]], device='cuda:0')
tensor([[[-0.0177,  0.0133,  0.0149,  ..., -0.0132, -0.0085,  0.0032],
         [-0.0050, -0.0109,  0.0379,  ...,  0.0011, -0.0202,  0.0139],
         [ 0.0015, -0.0248,  0.0511,  ...,  0.0104, -0.0276,  0.0224],
         ...,
         [ 0.1527, -0.3300,  0.3469,  ...,  0.1932, -0.1869,  0.1824],
         [ 0.1587, -0.3426,  0.3589,  .

AssertionError: Tensor-likes are not close!

Mismatched elements: 1156 / 1160 (99.7%)
Greatest absolute difference: 0.01600748300552368 at index (0, 0, 12) (up to 1e-05 allowed)
Greatest relative difference: 40.226444244384766 at index (0, 1, 24) (up to 1.3e-06 allowed)

In [None]:
logits, cache = hooked_model.run_with_cache(vecs)

In [None]:
plot.imshow(
    cache['pattern', 0].squeeze(dim=0),
    x=np.arange(30),
    y=np.arange(30),
    facet_col=0, # This argument tells plotly which dimension to split into separate plots
    facet_labels=[f"Head {i}" for i in range(6)], # Subtitles of separate plots
    title="Attention Patterns in Layer 0",
    xaxis = "Key",
    yaxis = "Query",
    width=5000
)

In [None]:
t.save(hooked_model.state_dict(), "arithseq.pth")

In [None]:
with open('loss_history_arithseq.txt', 'w') as f:
    f.write(str(loss_history))
    f.close()

In [None]:
plot.imshow(
    cache['pattern', 1].squeeze(dim=0),
    x=np.arange(4),
    y=np.arange(4),
    facet_col=0, # This argument tells plotly which dimension to split into separate plots
    facet_labels=[f"Head {i}" for i in range(12)], # Subtitles of separate plots
    title="Attention Patterns in Layer 0",
    xaxis = "Key",
    yaxis = "Query",
    width=5000
)