# Overview

This notebook runs through the week 3 task from the MLX apprenticeship, namely re-implementing GPT-2 from scratch.
It follows the tutorial [here](https://colab.research.google.com/drive/1Zl3zSdli_epSfaoQ_HeBCuE6dkGWTowd).

# Initial Imports

In [3]:
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser

# Initialise Config

In [4]:
# `Config` is defined in the local `model` module, which is not shown in this notebook.
from model import Config
# Default-constructed configuration; later cells read `cfg.device` to place models and tensors.
cfg = Config()

# Initialise Demo Transformer

In [8]:
from model import DemoTransformer
# Construct the model from the `cfg` *instance* (not the `Config` class itself),
# matching every other DemoTransformer(...) call in this notebook, so the model
# and the `.to(cfg.device)` placement are driven by the same configuration object.
demo_transformer = DemoTransformer(cfg).to(cfg.device)

# Some Sanity Checking

In [9]:
# Pretrained reference model; its weights are copied into our own implementation below.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False
)
# strict=False lets the load proceed when the two state dicts do not line up exactly,
# but it does so silently. Any key listed as missing is a tensor of our model that kept
# its initial value, which would invalidate the sanity check, so report it.
load_result = demo_transformer.load_state_dict(reference_gpt2.state_dict(), strict=False)
if load_result.missing_keys:
    print(f"WARNING: no pretrained weights were loaded for: {load_result.missing_keys}")
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text).to(cfg.device)
demo_logits = demo_transformer(tokens)

def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"],
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """Return the log-probability assigned to each actual next token.

    Args:
        logits: Unnormalised scores over the vocabulary at every position.
        tokens: Token ids of the sequence the logits were computed for.

    Returns:
        One value per position except the last: entry ``[b, p]`` is the
        log-softmax of ``logits[b, p]`` evaluated at ``tokens[b, p + 1]``.
    """
    # Normalise over the vocabulary axis.
    vocab_log_probs = logits.log_softmax(dim=-1)
    # The last position has no following token to score, so it is dropped;
    # the targets are the tokens shifted left by one.
    next_token_ids = tokens[:, 1:, None]
    selected = vocab_log_probs[:, :-1].gather(dim=-1, index=next_token_ids)
    return selected[..., 0]


pred_log_probs = get_log_probs(demo_logits, tokens)
# All three metrics use `.4f` (four decimal places). A bare `4f` would only set a
# minimum field width of 4 and fall back to the default six decimals.
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_transformer.cfg.d_vocab):.4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():.4f}")

# Greedy generation: re-tokenise the growing string, take the arg-max token at the
# final position, decode it and append it to the string; repeat 100 times.
test_string = '''The Total Perspective Vortex derives its picture of the whole Universe on the principle of'''
# Gradients are never used here, so skip building the autograd graph on each pass.
with t.no_grad():
    for i in tqdm(range(100)):
        test_tokens = reference_gpt2.to_tokens(test_string).to(cfg.device)
        demo_logits = demo_transformer(test_tokens)
        test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

Loaded pretrained model gpt2-small into HookedTransformer
Avg cross entropy loss: 4.5647
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.087911


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/100 [00:00<?, ?it/s]

The Total Perspective Vortex derives its picture of the whole Universe on the principle of the total perspective. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The


# Training Loop

## Create smaller model

In [10]:
# Scaled-down hyperparameters for the model trained in this notebook.
model_cfg = Config(
    n_layers=2,
    n_heads=4,
    d_head=64,
    d_model=256,   # equals n_heads * d_head
    d_mlp=1024,    # equals 4 * d_model
    n_ctx=256,     # also used below as the tokenised sequence length
    d_vocab=50257,
)
# This instance is only consulted for `model.cfg.n_ctx` before the training cell
# builds a fresh model from the same config.
model = DemoTransformer(model_cfg)

## Initialise Training Args 

In [15]:
# Training hyperparameters are defined in the local `train` module (not shown here).
from train import TransformerTrainingArgs
# Defaults are used unchanged; `args.batch_size` is read by the DataLoader cells below.
args = TransformerTrainingArgs()

## Prep Dataset and Sanity Check

In [16]:
from datasets import load_dataset

# Only the training split is requested; the held-out set is carved out of it later
# with `train_test_split`.
tiny_stories = load_dataset(
    "roneneldan/TinyStories",
    split="train",
)



Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [17]:
# Same call as in the sanity-checking cell; the next cell only needs its tokenizer.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)


Loaded pretrained model gpt2-small into HookedTransformer


In [18]:
# Tokenise the "text" column with the GPT-2 tokenizer and pack the result into
# sequences whose length is the small model's context size (`model.cfg.n_ctx`).
tokenized_dataset = tokenize_and_concatenate(
    tiny_stories,
    reference_gpt2.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name="text",
    add_bos_token=True,
    num_proc=10,
)

Map (num_proc=10):   0%|          | 0/2119719 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (9676 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (11506 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (12536 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (10666 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (13355 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence 

In [19]:
# Hold out 1000 sequences for evaluation. The split is random, so a fixed seed is
# supplied to make the train/test partition identical on every run of the notebook.
SPLIT_SEED = 42
dataset_dict = tokenized_dataset.train_test_split(test_size=1000, seed=SPLIT_SEED)
# Training loader: reshuffled every epoch, loaded by 4 worker processes.
train_loader = DataLoader(
    dataset_dict["train"],
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=False,
)

In [20]:
# Evaluation loader over the held-out split; iteration order is kept fixed.
test_loader = DataLoader(
    dataset_dict["test"],
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=4,
    pin_memory=False,
)

In [21]:
# Slice the underlying dataset directly: this yields its first `batch_size` rows in
# stored order, which is enough to inspect the structure (the loader itself shuffles).
first_batch = train_loader.dataset[:args.batch_size]
print(first_batch.keys())
print(first_batch['tokens'].shape)
# Preview a few token ids from the first row instead of dumping the whole batch.
print(first_batch['tokens'][0, :10])

dict_keys(['tokens'])
torch.Size([16, 256])


## Loss fn

In [22]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"],
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """Return the log-probability assigned to each actual next token.

    This redefines the function of the same name from the sanity-checking cell
    with the same contract; it is the callable handed to the trainer below.

    Args:
        logits: Unnormalised scores over the vocabulary at every position.
        tokens: Token ids of the sequence the logits were computed for.

    Returns:
        One value per position except the last: entry ``[b, p]`` is the
        log-softmax of ``logits[b, p]`` evaluated at ``tokens[b, p + 1]``.
    """
    # Normalise over the vocabulary axis.
    vocab_log_probs = logits.log_softmax(dim=-1)
    # The last position has no following token to score, so it is dropped;
    # the targets are the tokens shifted left by one.
    next_token_ids = tokens[:, 1:, None]
    selected = vocab_log_probs[:, :-1].gather(dim=-1, index=next_token_ids)
    return selected[..., 0]

## Initialise Training Loop Function and Train

In [32]:
from train import TransformerTrainer

# Start from a freshly initialised small model and default training arguments.
model = DemoTransformer(model_cfg).to(cfg.device)
args = TransformerTrainingArgs()

# NOTE(review): the fourth argument is `cfg` (the default Config), not `model_cfg`.
# That is harmless if the trainer only reads `cfg.device`; confirm against
# `TransformerTrainer` in train.py.
trainer = TransformerTrainer(args, model, dataset_dict, cfg, get_log_probs)
trainer.train()

wandb init below
wandb init done


  0%|          | 0/50 [00:00<?, ?it/s]

progress bar made


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x10fe187c0>
Traceback (most recent call last):
  File "/Users/shaheen.ahmed-chowd/git/personal/factual_recall/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    def __del__(self):

  File "/Users/shaheen.ahmed-chowd/git/personal/factual_recall/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 88985) is killed by signal: Interrupt: 2. 


KeyboardInterrupt: 

# Save Result

In [None]:
# Persist only the model's state dict, to a path relative to the working directory.
# The sampling cell below loads this same filename.
t.save(model.state_dict(), 'gpt2_style_model_weights.pth')

# Test Output Sampling

In [37]:
from model import TransformerSampler

# Only the tokenizer of the reference model is needed for sampling.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False
)
tokenizer = reference_gpt2.tokenizer

# Must match the architecture the saved weights were trained with.
model_cfg = Config(
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    d_vocab=50257,
)

sampling_model = DemoTransformer(model_cfg).to(cfg.device)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on one device (e.g. GPU) still loads on a machine that lacks it.
sampling_model.load_state_dict(
    t.load("gpt2_style_model_weights.pth", map_location=cfg.device)
)

sampler = TransformerSampler(sampling_model, tokenizer, model_cfg)

# NOTE(review): the prompt ends with a space; GPT-2's BPE normally attaches the space
# to the following word, so a trailing space may hurt continuations. Consider dropping it.
prompt = 'Harry and Sally went to the mall '
sampler.sample(prompt=prompt)

Loaded pretrained model gpt2-small into HookedTransformer


'Harry and Sally went to the mall  He could than they says. She at\n friends and flashlight.  He neighbourhood juice explored!" he wanted our. "Tim and character boy.Danny she saw tired. The voice, a end and do. It around her girl he From new heard what€ new grandma. The nice."\n\nitative ample- day, but you not special who lesson theiranic and his dad sharing.\n\n\n"20439 over, Lily friends. \nTim Her mom Representatives.\n\nUGE grandma. " pine baby added that 4 ERROR va their************ was?"\'s Bike in him around the Behind israined have nervous on. They blow he sat woodsOUNT rare. park well the old net. He were very.\n" "Then."\n somewhere his."\n\nTom and�. She saw a so excited day on to not a55 neveroul andossus another."\nThe stamp hide, with very happy. He m."\n\n\nLilyLua far hoped things. \n\n\n\n Lily. It to words and Respons onmy. So, If in searched. " Mom and giant fun.\n\n\n\n\n\n\n\nmy againProcess, there was happy. We remembered the Herortex. He will he was not a steam