## Imports

In [1]:
import torch as t
from fancy_einsum import einsum
import einops
from dataclasses import dataclass
from pathlib import Path
import json
from tqdm import tqdm
import plotly.express as px
import numpy as np
import pandas as pd

t.manual_seed(0)

<torch._C.Generator at 0x7f8e309876b0>

## Model

In [2]:
device = 'cuda:0'

In [3]:
@dataclass
class Config:
    d_model: int = 16
    layer_norm_eps: float = 1e-5
    d_vocab: int = 1024
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 8
    d_mlp: int = 64
    n_heads: int = 2
    n_layers: int = 1

cfg = Config()
print(cfg)

Config(d_model=16, layer_norm_eps=1e-05, d_vocab=1024, init_range=0.02, n_ctx=1024, d_head=8, d_mlp=64, n_heads=2, n_layers=1)


In [4]:
class LayerNorm(t.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = t.nn.Parameter(t.ones(cfg.d_model).to(device))
        self.b = t.nn.Parameter(t.zeros(cfg.d_model).to(device))

    def forward(self, residual):
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()
        residual = (residual - residual.mean(dim=-1, keepdim=True)) / residual_std
        return residual * self.w + self.b

In [5]:
class Embed(t.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = t.nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)).to(device))
        t.nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        return self.W_E[tokens]

In [6]:
class PosEmbed(t.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = t.nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)).to(device))
        t.nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        batch, seq_len = tokens.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_m -> b seq d_m", b=batch)

In [7]:
class Attention(t.nn.Module):

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = t.nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)).to(device))
        self.W_K = t.nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)).to(device))
        self.W_V = t.nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)).to(device))
        self.W_O = t.nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)).to(device))
        self.b_Q = t.nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)).to(device))
        self.b_K = t.nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)).to(device))
        self.b_V = t.nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)).to(device))
        self.b_O = t.nn.Parameter(t.zeros((cfg.d_model)).to(device))
        t.nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        t.nn.init.normal_(self.W_K, std=self.cfg.init_range)
        t.nn.init.normal_(self.W_V, std=self.cfg.init_range)
        t.nn.init.normal_(self.W_O, std=self.cfg.init_range)

    def forward(self, normalized_resid_pre):

        q = einsum("b p d_m, h d_m d_h -> b p h d_h", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("b p d_m, h d_m d_h -> b p h d_h", normalized_resid_pre, self.W_K) + self.b_K
        v = einsum("b p d_m, h d_m d_h -> b p h d_h", normalized_resid_pre, self.W_V) + self.b_V

        attn_scores = einsum("b p_Q h d_h, b p_K h d_h -> b h p_Q p_K", q, k)

        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_pattern = attn_scores_masked.softmax(-1)

        z = einsum("b p_K h d_h, b h p_Q p_K -> b p_Q h d_h", v, attn_pattern)

        attn_out = einsum("b p_Q h d_h, h d_h d_m -> b p_Q d_m", z, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores):

        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        attn_scores.masked_fill_(mask, -1e5)
        return attn_scores

In [8]:
class MLP(t.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = t.nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)).to(device))
        self.W_out = t.nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)).to(device))
        self.b_in = t.nn.Parameter(t.zeros((cfg.d_mlp)).to(device))
        self.b_out = t.nn.Parameter(t.zeros((cfg.d_model)).to(device))
        t.nn.init.normal_(self.W_in, std=self.cfg.init_range)
        t.nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(self, normalized_resid_mid):

        pre = einsum("b p d_m, d_m d_mlp -> b p d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        post = t.nn.GELU()(pre)
        mlp_out = einsum("b p d_mlp, d_mlp d_m -> b p d_m", post, self.W_out) + self.b_out
        return mlp_out

In [9]:
class TransformerBlock(t.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        resid_post = resid_mid + self.mlp(self.ln2(resid_mid))
        return resid_post

In [10]:
class Unembed(t.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = t.nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)).to(device))
        t.nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = t.nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False).to(device))

    def forward(self, normalized_resid_final):
        return einsum("b p d_m, d_m d_v -> b p d_v", normalized_resid_final, self.W_U) + self.b_U

In [11]:
class Transformer(t.nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = t.nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits

## Memorised strings training

In [12]:
d_vocab = 128

model_cfg = Config(
    d_model=32,
    n_heads=4,
    d_head=8,
    d_mlp=128,
    n_layers=1,
    n_ctx=128,
    d_vocab=d_vocab
)

model = Transformer(model_cfg)

In [36]:
@dataclass
class TransformerTrainingArgs():
	batch_size = 512
	epochs = 10
	num_steps = 25000
	lr = 1e-3
	beta1 = 0.9
	beta2 = 0.999
	weight_decay = 1e-2

args = TransformerTrainingArgs()

In [14]:
num_mem_seqs = 128
mem_seq = t.randint(1, d_vocab, (num_mem_seqs, model.cfg.n_ctx)).to(device)

In [15]:
SAVE_DIR = Path("/root/expl_research/mem-str-models")

In [37]:
class Trainer:
    def __init__(self, args, model):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0
        self.weighting = t.ones(args.batch_size).to(device)
        self.weighting[:num_mem_seqs] = (-t.linspace(0,2,num_mem_seqs)).exp()
        self.loss_history = []
    def eval(self):
        tokens = t.randint(1, d_vocab, (args.batch_size, model.cfg.n_ctx)).to(device)
        starting_pos = t.randint(1, model.cfg.n_ctx, (num_mem_seqs,))
        for i in range(num_mem_seqs):
            tokens[i, starting_pos[i]:] = mem_seq[i, :-starting_pos[i]]
        tokens[:,0] = 0
        logits = model(tokens)
        log_probs = logits.log_softmax(dim=-1)
        correct_log_probs = -log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
        loss = (correct_log_probs * self.weighting[:, None]).mean()
        return starting_pos, correct_log_probs, loss

    def train(self):
        try:
            for i in tqdm(range(self.args.num_steps)):
                _, _, loss = self.eval()
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if i % 100 == 0:
                    self.loss_history.append(loss.item())
        finally:
            t.save(model.state_dict(), SAVE_DIR/("1l_learned_embed.pt"))

In [38]:
device = 'cuda:0'
trainer = Trainer(args, model)

In [48]:
trainer.train()

100%|██████████| 25000/25000 [04:28<00:00, 92.94it/s] 


In [49]:
px.line(trainer.loss_history)

## Trying to evaluate


In [45]:
np.log(model.cfg.d_vocab)


4.852030263919617

In [20]:
'''
model.load_state_dict(t.load(SAVE_DIR/"1l_learned_embed.pt"))
model.eval()
'''

'\nmodel.load_state_dict(t.load(SAVE_DIR/"1l_learned_embed.pt"))\nmodel.eval()\n'

In [110]:
starting_pos, log_probs, loss = trainer.eval()
print(starting_pos, loss.item())

tensor([ 80,  78,  89,  47,  66, 122,  42,  40,  64,  66, 119,  19,  87,  13,
         31,  57,  35,  93, 123,  62,  52,  48,  30,  54,  53, 101,   8,  21,
         53,  61, 117,  51,   4,  59,   1,  80,  59,  10,  19, 118,   2,  35,
         90,   4,  56, 107,  48, 115,  21,  41, 123,  23,  24,   3,  74,  70,
        119,  95,  47,  49,  37, 126,  90,   1,  15,  83,  71,  83,  31,  48,
         56,  43,  25,  48,  42,  39,  32,  84,  89,   4, 109,  96, 106,  86,
        107,  56,  85,  50,  52,  11,  17,  36,  14, 101,  91,  65,  92,  91,
         42, 105,  91,  64,  44,  61,  73,  47,  52,  87,  35, 117,  54, 107,
         90,   3,  96,  36, 120,  64, 118, 110,  18,  71,  66,  55,  30,  20,
         93,  26]) 5.291360855102539


In [111]:
px.line(log_probs.T.cpu().detach())

In [112]:
log_probs[1].mean(-1)

tensor(3.1649, device='cuda:0', grad_fn=<MeanBackward1>)

In [123]:
tokens = t.randint(1, d_vocab, (args.batch_size, model.cfg.n_ctx)).to(device)
starting_pos = 2
tokens[:num_mem_seqs, starting_pos:] = mem_seq[:, :-starting_pos]
tokens[:,0] = 0
logits = model(tokens)
log_probs = logits.log_softmax(dim=-1)
correct_log_probs = -log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
px.line(correct_log_probs[:num_mem_seqs].mean(-1).cpu().detach())