# Imports

In [None]:
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser

# Overview

Re-implement a GPT-2-style (decoder-only) transformer from scratch. Brain dump of what will be required:
1. Download training data 
2. Clean, tokenise, build vocab
3. Code to cast to token embedding space
4. Code to add positional embeddings 
5. Code for Transformer block:  
    a. Key matrix  
    b. Query matrix  
    c. Value matrix   
    d. QK matrix   
    e. Post-processing on QK matrix (attention scores)  
    f. Multiply attention scores with Value matrix  
    g. Multiply by Output matrix, casting back to embedding space, to add to residual stream  
    h. MLP block   
    i. Unembedding transformation   
    j. Cast to logits  
6. LayerNorm  
7. Initialise parameters sensibly (Xavier)
8. Implement loss function (cross-entropy loss?)
9. Choose an optimizer
10. Code the training loop (forward pass, loss, backward pass, update weights)
11. Create evaluation metrics for model
12. Implement text generation procedures (top-k?)
13. Allow saving and loading of the model 
14. Hyperparameter tuning 
15. Optimisation (mixed-precision floats? GPUs?)
16. API code

Extras:
- dataclass as a config
    

# Plan
- EOD Mon - training data, tokenisation, transformer block done 
- EOD Tue - parameter initialisation, loss function, optimizer, training loop 
- EOD Thu - eval metrics, text generation procedures, save and load model, hyperparam tuning
- EOD Fri - mixed-precision floats, GPUs, API code. 

# Download Training Data

TinyStories is a dataset of short, simple stories that has been proposed as training data on which even a single-layer GPT model works well. Let's do the same and create a small GPT model that works.

In [None]:
# Import the Hugging Face `load_dataset` helper (it is called again further down).
from datasets import load_dataset

# Called with only the dataset name (no `split`), so every split is loaded; the
# commented-out cell below indexes the result by 'train' and 'validation'.
tiny_stories = load_dataset('roneneldan/TinyStories')

In [None]:
# # Splice out training vs test
# tiny_stories_train = tiny_stories['train']
# tiny_stories_test = tiny_stories['validation']

# Clean, tokenise, build vocab

## Examine data

In [None]:
# print(tiny_stories_train.features)
# print(tiny_stories_train.info.description)
# print(tiny_stories_train.info.features)  
# print(tiny_stories_train.info.splits)   
# print (type(tiny_stories_train))

In [None]:
# example = tiny_stories_train[6]
# print(example)

In [None]:
# example = tiny_stories_train[0:6]
# for i in example:
#     print (len(example['text']))

The data is an Arrow-backed dataset: indexing a single row returns a dict with the key 'text', while slicing returns that same key mapped to a list whose length grows with the number of rows selected.

## Tokenise via SentencePiece

In [None]:
# import csv
# # Pre-process input for SentencePiece
# with open('tinystories_for_sentencepiece.csv', 'w', newline='', encoding='utf-8') as csvfile:
#     writer = csv.writer(csvfile)

#     # 1. Iterate over the rows of the dataframe
#     for i in tiny_stories_train['text']:
#         writer.writerow([i])

In [None]:
# import sentencepiece as spm

# # Define parameters for training
# train_args = {
#     'input': 'tinystories_for_sentencepiece.csv',             # Input file
#     'model_prefix': 'mymodel',        # Prefix for the output model files (.model and .vocab)
#     'vocab_size': 4000,              # Size of the vocabulary
#     'character_coverage': 0.9995,     # Character coverage to be considered for the model. Good defaults are: 0.9995 for languages with rich character sets like Japanese or Chinese and 1.0 for other languages with small character sets
#     'model_type': 'unigram',          # Model type can be 'unigram' (default), 'bpe', 'char', or 'word'
#     # Add other parameters as needed.
# }

# # Train the model
# spm.SentencePieceTrainer.Train(' '.join([f'--{k}={v}' for k, v in train_args.items()]))

# print("Model trained and saved as mymodel.model and mymodel.vocab!")

## Vocab

In [None]:
import sentencepiece as spm

# Load the SentencePiece model written by the (commented-out) training cell above.
sp = spm.SentencePieceProcessor()
sp.load('mymodel.model')

# Number of pieces in the learned vocabulary.
vocab_size = sp.get_piece_size()
print(vocab_size)

# Build piece -> score in id order, then list every entry.
vocab = {}
for piece_id in range(vocab_size):
    vocab[sp.id_to_piece(piece_id)] = sp.get_score(piece_id)

for token, score in vocab.items():
    print(f'{token}: {score}')

# Code to cast token to embedding space

## Config Dataclass 

In [None]:
from dataclasses import dataclass
@dataclass
class Config:
    """Model hyperparameters shared by every module in this notebook.

    NOTE(review): the defaults appear to mirror GPT-2 small (a "gpt2-small"
    reference model is loaded later) -- presumably intentional, confirm.
    """
    d_model: int = 768  # width of the residual stream / embedding vectors
#     debug: bool = True
    layer_norm_eps: float = 1e-5  # added to the variance inside LayerNorm's sqrt
    d_vocab: int = 50257  # number of token ids (rows of W_E, columns of W_U)
    init_range: float = 0.02  # std of the normal init applied to weight matrices
    n_ctx: int = 1024  # maximum sequence length (number of rows in W_pos)
    d_head: int = 64  # per-head query/key/value dimension
    d_mlp: int = 3072  # hidden width of the MLP
    n_heads: int = 12  # attention heads per block
    n_layers: int = 12  # number of transformer blocks

cfg = Config()
print(cfg)

## Device

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")


## Embedding Module

In [None]:
import torch as t
class Embed(nn.Module):
    """Token embedding lookup: map each token id to a learned d_model vector."""

    def __init__(self, cfg:Config):
        super().__init__()
        self.cfg = cfg
        # Embedding matrix: one row per token id, drawn from N(0, init_range**2).
        self.W_E = nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model))
        nn.init.normal_(self.W_E, std=cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        """Return the rows of W_E selected by `tokens` (shape tokens.shape + (d_model,))."""
        embedded = self.W_E[tokens]
        return embedded

## Positional Embedding Module

In [None]:
class PosEmbed(nn.Module):
    """Learned absolute positional embedding: one d_model vector per position < n_ctx."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Row p holds the embedding for position p.
        self.W_pos = nn.Parameter(t.empty(cfg.n_ctx, cfg.d_model))
        nn.init.normal_(self.W_pos, std=cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        """Return position embeddings shaped like the token embeddings; only tokens.shape is used."""
        n_batch, n_pos = tokens.shape
        position_rows = self.W_pos[:n_pos]
        # Broadcast the same position rows to every sequence in the batch.
        return einops.repeat(position_rows, "pos d_model -> batch pos d_model", batch=n_batch)

# Transformer Block

## Attention Module

In [None]:
class Attention(nn.Module):
    """Multi-head causal self-attention (GPT-2 style, with biases on Q, K, V and O)."""

    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Per-head projections: d_model -> d_head for Q/K/V, d_head -> d_model for O.
        self.W_Q = nn.Parameter(t.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_K = nn.Parameter(t.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_V = nn.Parameter(t.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_O = nn.Parameter(t.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        # Fill value for masked (future) scores. Registered as a buffer without a
        # hard-coded device so that `.to(device)` moves it together with the weights.
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        """Attend over the (already layer-normed) residual stream and project back to d_model."""
        keys = einops.einsum(
            normalized_resid_pre,
            self.W_K,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head",
        ) + self.b_K
        queries = einops.einsum(
            normalized_resid_pre,
            self.W_Q,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head",
        ) + self.b_Q
        values = einops.einsum(
            normalized_resid_pre,
            self.W_V,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head",
        ) + self.b_V

        attn_scores = einops.einsum(
            queries,
            keys,
            "batch seq_len_Q n_heads d_head, batch seq_len_K n_heads d_head -> batch n_heads seq_len_Q seq_len_K",
        )
        # Scale first, then mask, so masked entries end up exactly at IGNORE.
        attn_scores = self.apply_causal_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_pattern = attn_scores.softmax(-1)

        # Weighted sum of value vectors for every query position.
        z = einops.einsum(
            values,
            attn_pattern,
            "batch seq_len_K n_heads d_head, batch n_heads seq_len_Q seq_len_K -> batch seq_len_Q n_heads d_head",
        )

        # Project each head back to d_model and sum over heads.
        attn_out = einops.einsum(
            z,
            self.W_O,
            "batch seq_len_Q n_heads d_head, n_heads d_head d_model -> batch seq_len_Q d_model",
        ) + self.b_O
        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Return a copy of `attn_scores` where every key position later than its query
        position is set to IGNORE. The input tensor is not modified.
        '''
        n_query, n_key = attn_scores.shape[-2:]
        # Strict upper triangle (diagonal=1) marks keys that lie in each query's future.
        future = t.ones(n_query, n_key, device=attn_scores.device).triu(diagonal=1).bool()
        # `masked_fill` is out-of-place, so its result must be returned; discarding it
        # (as before) left the scores unmasked and let tokens attend to the future.
        return attn_scores.masked_fill(future, self.IGNORE)
        
    

        


## MLP Module

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward block: d_model -> d_mlp -> gelu_new -> d_model."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty(cfg.d_model, cfg.d_mlp))
        self.W_out = nn.Parameter(t.empty(cfg.d_mlp, cfg.d_model))
        self.b_in = nn.Parameter(t.zeros(cfg.d_mlp))
        self.b_out = nn.Parameter(t.zeros(cfg.d_model))
        # Initialise the two weight matrices in the same order as before (W_in, then W_out).
        for weight in (self.W_in, self.W_out):
            nn.init.normal_(weight, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        """Apply the two-layer MLP independently at every position."""
        hidden_pre = einops.einsum(
            normalized_resid_mid,
            self.W_in,
            "batch seq_len d_model, d_model d_mlp -> batch seq_len d_mlp",
        )
        hidden_pre = hidden_pre + self.b_in
        hidden_post = gelu_new(hidden_pre)
        projected = einops.einsum(
            hidden_post,
            self.W_out,
            "batch seq_len d_mlp, d_mlp d_model -> batch seq_len d_model",
        )
        return projected + self.b_out

## LayerNorm Module

In [None]:
class LayerNorm(nn.Module):
    """LayerNorm over the last (d_model) axis with learned scale `w` and shift `b`."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        """Normalise each position to zero mean / unit variance, then scale and shift."""
        # Population variance (unbiased=False); eps is added inside the square root.
        variance = residual.var(dim=-1, keepdim=True, unbiased=False)
        scale = (variance + self.cfg.layer_norm_eps).sqrt()
        centered = residual - residual.mean(dim=-1, keepdim=True)
        normalized = centered / scale
        return normalized * self.w + self.b

## Assemble Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: attention, then MLP, each added back onto the residual stream."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Registration order (ln1, attn, ln2, mlp) is kept so state_dict keys and init RNG match.
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        """Return the residual stream after this block's attention and MLP updates."""
        attn_out = self.attn(self.ln1(resid_pre))
        resid_mid = attn_out + resid_pre
        mlp_out = self.mlp(self.ln2(resid_mid))
        return mlp_out + resid_mid

## Unembedding Module

In [None]:
class Unembed(nn.Module):
    """Project the final residual stream onto logits over the vocabulary."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty(cfg.d_model, cfg.d_vocab))
        # Zero bias with requires_grad=False: backprop computes no gradient for it.
        self.b_U = nn.Parameter(t.zeros(cfg.d_vocab), requires_grad=False)
        nn.init.normal_(self.W_U, std=self.cfg.init_range)

    def forward(
        self, resid_stream: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        """Return unnormalised next-token logits for every position."""
        logits = einops.einsum(
            resid_stream,
            self.W_U,
            "batch seq_len d_model, d_model d_vocab -> batch seq_len d_vocab",
        )
        return logits + self.b_U

# Full Transformer

In [None]:
class DemoTransformer(nn.Module):
    """Decoder-only GPT-2-style model: embed + pos-embed -> n_layers blocks -> LayerNorm -> unembed."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(
        self, tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        """Return next-token logits for every position of `tokens`.

        `tokens` holds integer token ids (they index W_E), not floats.

        Raises:
            ValueError: if the sequence is longer than cfg.n_ctx, since there are
                only n_ctx learned positional embeddings.
        """
        seq_len = tokens.shape[-1]
        if seq_len > self.cfg.n_ctx:
            raise ValueError(f"sequence length {seq_len} exceeds n_ctx={self.cfg.n_ctx}")
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits
        

In [None]:
# Default-sized model; pass a Config *instance* (previously the Config class itself was passed).
demo_transformer = DemoTransformer(Config()).to(device)

# Training Loop

## Create Model

In [None]:
# Small model for training. n_heads * d_head == d_model (4 * 64 == 256).
# d_vocab is presumably chosen to match the GPT-2 tokenizer used to tokenise the data below.
model_cfg = Config(
    n_layers=2,
    n_heads=4,
    d_head=64,
    d_model=256,
    d_mlp=1024,
    n_ctx=256,
    d_vocab=50257,
)
model = DemoTransformer(model_cfg)

## Create hyperparams class

In [None]:
# @dataclass
# class TransformerTrainingArgs():
#     batch_size = 16
#     epochs = 10
#     max_steps_per_epoch = 200
#     lr = 1e-3
#     weight_decay = 1e-2
#     wandb_project: Optional[str] = "day2-demotransformer"
#     wandb_name: Optional[str] = 'shaheen-ahmed'

@dataclass
class TransformerTrainingArgs:
    """Hyperparameters for TransformerTrainer.

    Every setting is an annotated dataclass field, so it can be overridden at
    construction time (e.g. ``TransformerTrainingArgs(epochs=1)``) and appears in
    ``dataclasses.fields``/``asdict``/``repr``. Unannotated class attributes are not
    dataclass fields, which previously made all but the wandb settings fixed.
    """
    # The wandb fields stay first so the original positional order still works.
    wandb_project: Optional[str] = "day2-demotransformer"
    wandb_name: Optional[str] = 'shaheen-ahmed'
    batch_size: int = 16
    epochs: int = 5
    max_steps_per_epoch: int = 100
    lr: float = 1e-3
    weight_decay: float = 1e-2

args = TransformerTrainingArgs()

## Prepare Data

In [None]:
# dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")
# Reload only the 'train' split; this rebinds `tiny_stories` from the earlier download cell.
tiny_stories = load_dataset('roneneldan/TinyStories',split='train')

In [None]:
# Pretrained GPT-2 small via TransformerLens. In this notebook only its tokenizer is
# used (for tokenize_and_concatenate and for sampling).
reference_gpt2 = HookedTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)


In [None]:
# Tokenise every story with GPT-2's tokenizer. Judging by the name and `max_length`,
# the token stream is concatenated and chunked into rows of model.cfg.n_ctx tokens --
# confirm against the TransformerLens docs. Later cells read the result's 'tokens' column.
tokenized_dataset = tokenize_and_concatenate(tiny_stories,
                                            reference_gpt2.tokenizer,
                                            streaming=False,
                                            max_length=model.cfg.n_ctx,
                                            column_name="text", 
                                            add_bos_token=True,
                                            num_proc=10)

In [None]:
# Hold out 1000 rows for evaluation; everything else is training data.
dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(
    dataset_dict["train"],
    shuffle=True,
    batch_size=args.batch_size,
    num_workers=4,
    pin_memory=False,
)

In [None]:
# Evaluation loader: fixed order (no shuffling) over the held-out split.
test_loader = DataLoader(
    dataset_dict["test"],
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=4,
    pin_memory=False,
)

In [None]:
# Peek at the first `batch_size` rows by slicing the training dataset directly
# (this indexes `train_loader.dataset`; it does not iterate the DataLoader).
first_batch = train_loader.dataset[:args.batch_size]
print(first_batch.keys())
print(first_batch['tokens'].shape)

In [None]:
# Dump the full contents of that slice.
print(first_batch)

## Loss Function

In [None]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"],
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """Return the log-probability the model gave to each actual next token.

    The logits at position i are scored against ``tokens[:, i + 1]``. The last
    position has no next token, so the result has one fewer position than the inputs.
    """
    predictions = logits.log_softmax(dim=-1)[:, :-1]
    targets = tokens[:, 1:, None]
    return t.gather(predictions, -1, targets)[..., 0]

## Actual Training Loop

In [None]:
class TransformerTrainer:
    """Train a DemoTransformer with AdamW and measure next-token accuracy after each epoch.

    Batches come from the module-level ``dataset_dict`` (see ``train_loader`` and
    ``test_loader``); wandb logging is currently commented out.
    """

    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0  # number of optimizer steps taken so far

    def _device(self) -> t.device:
        """Return the device of the model's parameters, so batches follow the model."""
        return next(self.model.parameters()).device

    def training_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """Do one AdamW update on `batch` and return the mean next-token cross-entropy.

        The returned scalar is detached, so keeping it (e.g. for the progress bar)
        does not keep this step's autograd graph alive.
        """
        tokens = batch['tokens'].to(self._device())
        logits = self.model(tokens)
        # Cross-entropy = negative mean log-probability of each true next token.
        loss = -get_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        # wandb.log({"train_loss": loss}, step=self.step)
        return loss.detach()

    @t.inference_mode()
    def validation_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]) -> Tensor:
        """Return a flat bool tensor: does the argmax prediction equal the true next token?

        Runs under inference mode so no autograd graph is built during evaluation.
        """
        tokens = batch["tokens"].to(self._device())
        # Position i predicts token i + 1, so drop the logits of the last position.
        logits: Tensor = self.model(tokens)[:, :-1]
        predicted_tokens = logits.argmax(dim=-1)
        return (predicted_tokens == tokens[:, 1:]).flatten()

    def train(self):
        """Run `epochs` epochs of at most `max_steps_per_epoch` steps, evaluating after each epoch."""
        print('wandb init below')
        # wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        print('wandb init done')

        accuracy = np.nan  # displayed as "nan" until the first evaluation finishes

        with tqdm(total=self.args.max_steps_per_epoch * self.args.epochs) as progress_bar:
            print('progress bar made')
            for epoch in range(self.args.epochs):
                for i, batch in enumerate(self.train_loader()):
                    # Check the budget *before* stepping: checking after the step
                    # ran max_steps_per_epoch + 1 steps and overran the progress bar.
                    if i >= self.args.max_steps_per_epoch:
                        break
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.2f}")

                correct_predictions = t.concat([self.validation_step(batch) for batch in self.test_loader()])
                accuracy = correct_predictions.float().mean().item()
                # wandb.log({"accuracy": accuracy}, step=self.step)

        # wandb.finish()

    def train_loader(self) -> DataLoader:
        """Build a fresh shuffled loader over the module-level training split."""
        return DataLoader(dataset_dict["train"], batch_size=self.args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def test_loader(self) -> DataLoader:
        """Build a fresh, unshuffled loader over the module-level held-out split."""
        return DataLoader(dataset_dict["test"], batch_size=self.args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    

In [None]:
# Build a fresh model on `device` and train it with the default hyperparameters.
args = TransformerTrainingArgs()
model = DemoTransformer(model_cfg).to(device)
trainer = TransformerTrainer(args=args, model=model)
trainer.train()

In [None]:
# Save the *result* of state_dict() (the parameter/buffer tensors). Passing the bound
# method `model.state_dict` serialised the method object instead, which
# load_state_dict cannot consume.
t.save(model.state_dict(), 'gpt2_style_model_weights.pth')

# Sampling

In [None]:
# Rebuild the model with the SAME hyperparameters used for training (see the
# model_cfg cell under "Create Model"). A default Config() is much larger
# (d_model=768, 12 layers, n_ctx=1024), so its parameter shapes would not match
# the saved weights and load_state_dict would fail.
model_cfg = Config(
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    d_vocab=50257,
)
sampling_model = DemoTransformer(model_cfg).to(device)
# map_location lets weights saved on one device (e.g. a GPU) load on the current one.
sampling_model.load_state_dict(t.load('gpt2_style_model_weights.pth', map_location=device))

In [None]:
# GPT-2's tokenizer from the TransformerLens reference model -- the same one used to
# tokenise the training data, so token ids line up with the trained model.
tokenizer = reference_gpt2.tokenizer

class TransformerSampler:
    """Autoregressively generate text from a DemoTransformer."""

    def __init__(self, model: DemoTransformer, tokenizer: GPT2TokenizerFast):
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = tokenizer

    # NOTE: `sample` was previously dedented to module level (a free function taking
    # `self`) and called a non-existent `sample_next_token`; both now live on the class.
    @t.inference_mode()
    def sample(self, prompt, max_tokens_generated=100, verbose=False, **kwargs):
        """Extend `prompt` by up to `max_tokens_generated` tokens and return the decoded text.

        Extra keyword arguments (e.g. `temperature`, `top_k`) are forwarded to
        `sample_next_token`. Generation stops early once the tokenizer's end-of-text
        token (if it defines one) is produced.
        """
        self.model.eval()
        model_device = next(self.model.parameters()).device
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(model_device)[0]
        eos_token_id = getattr(self.tokenizer, "eos_token_id", None)

        for _ in range(max_tokens_generated):
            # Feed at most n_ctx tokens: the model only has n_ctx positional embeddings.
            logits = self.model(input_ids[None, -self.cfg.n_ctx:])
            # Only the last position's logits predict the next token.
            logits = logits[0, -1]
            next_token_id = self.sample_next_token(input_ids, logits, **kwargs)
            next_token = t.tensor([next_token_id], device=model_device)
            input_ids = t.cat([input_ids, next_token], dim=-1)
            if verbose:
                print(self.tokenizer.decode(input_ids), end="\r")
            if eos_token_id is not None and next_token_id == eos_token_id:
                break

        return self.tokenizer.decode(input_ids)

    @staticmethod
    def sample_next_token(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> int:
        """Choose the next token id from the 1-D `logits` of the last position.

        temperature == 0 gives greedy argmax decoding. Otherwise the method samples
        from softmax(logits / temperature), restricted to the `top_k` highest logits
        when `top_k` is given. `input_ids` is accepted to match the call in `sample`
        but is currently unused.

        Raises:
            ValueError: if temperature < 0 or top_k <= 0.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        if temperature == 0:
            return int(logits.argmax().item())
        scaled = logits / temperature
        if top_k is not None:
            if top_k <= 0:
                raise ValueError(f"top_k must be positive, got {top_k}")
            top_logits, top_ids = scaled.topk(min(top_k, scaled.shape[-1]))
            choice = t.multinomial(top_logits.softmax(dim=-1), num_samples=1)
            return int(top_ids[choice].item())
        return int(t.multinomial(scaled.softmax(dim=-1), num_samples=1).item())