## Setup

In [None]:
import os
import sys
from pathlib import Path

import circuitsvis as cv
import einops
import numpy as np
import torch as t
from eindex import eindex
from IPython.display import display
from plotly.subplots import make_subplots
from torch import Tensor
from transformer_lens import utils

# t.set_grad_enabled(False)
# Make the exercises root (two levels above the working directory) importable, so the
# `monthly_algorithmic_problems` and `plotly_utils` imports below can be resolved.
exercises_path = Path(os.path.abspath("")).parent.parent
if str(exercises_path) not in sys.path:
    sys.path.append(str(exercises_path))

# Directory for this problem; the trained model checkpoint is written here later on.
section_dir = exercises_path / "monthly_algorithmic_problems/november24_trigrams"
assert section_dir.exists()

from monthly_algorithmic_problems.november24_trigrams.dataset import BigramDataset
from monthly_algorithmic_problems.november24_trigrams.model import create_model
from monthly_algorithmic_problems.november24_trigrams.training import TrainArgs, Trainer
from plotly_utils import bar, hist, imshow, line

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")


## Dataset

Each sequence has tokens uniformly generated at random, except if the first 2 tokens of a particular trigram appear, in which case the next token is uniquely determined as the third token of the trigram. You can list all the trigrams with `dataset.trigrams`.


In [None]:
# Tiny demo dataset, small enough that its contents can be read by eye.
dataset = BigramDataset(
    size=10,
    d_vocab=10,
    seq_len=10,
    trigram_prob=0.1,
    device=device,
    seed=42,
)

# The special trigrams first, then the generated token sequences.
print(dataset.trigrams)
print(dataset.toks)


## Training

Link to WandB run [here](https://wandb.ai/callum-mcdougall/alg-challenge-trigrams-nov24/runs/c7jjsofv?nw=nwusercallummcdougall). There are 5 metrics:

- `train_loss`, which is the average cross entropy loss on the training set
- `train_loss_as_frac`, which is the loss scaled so that 1 is the loss you get when uniformly guessing over all tokens in the vocab, and 0 is the lowest possible loss (where the model has a uniform distribution everywhere except for the trigrams, where it has probability 1 on the correct token)
- `trigram_*`, which are three metrics computed on the trigram dataset (the dataset consisting only of the special trigrams, i.e. the sequences `(a, b, c)` where `c` always directly follows `ab`). These metrics are only computed on the last token (i.e. the 3rd one) in each sequence. We have:
    - `trigram_n_correct` = number of trigrams that were correctly predicted
    - `trigram_frac_correct` = fraction of total trigrams that were correctly predicted
    - `trigram_avg_correct_prob` = average probability assigned to the correct trigram token

Note that `trigram_frac_correct` is higher than `trigram_avg_correct_prob`, because some trigrams are predicted with slightly higher than uniform probability but still far from certainty. Also note that neither of these values has hit 100%, indicating that the model has learned most but not all of the trigrams. You can investigate these results for yourself when you inspect the model below!

In [None]:
# All hyperparameters for the run, grouped by what they control.
args = TrainArgs(
    # --- Model architecture ---
    d_model=32,
    d_head=24,
    n_layers=1,
    n_heads=1,
    d_mlp=20,
    normalization_type=None,
    # --- Dataset ---
    d_vocab=75,
    seq_len=50,
    trigram_prob=0.05,  # this is the probability that any randomly selected token is in a trigram
    n_trigrams=None,  # n_trigrams is determined by the trigram_prob
    # --- Training ---
    trainset_size=100_000,
    valset_size=5_000,
    epochs=100,
    batch_size=512,
    lr_start=1e-3,
    lr_end=1e-4,
    weight_decay=1e-3,
    # --- Misc. ---
    seed=40,
    device=device,
    use_wandb=True,
)

# Train from scratch; `train()` hands back the trained model.
model = Trainer(args).train()

# Persist only the weights (state dict) next to the problem's other files.
filename = section_dir / "trigram_model.pt"
t.save(obj=model.state_dict(), f=filename)


In [None]:
# Check we can load in the model.
# The checkpoint path is rebuilt here instead of being taken from the training cell, so this
# cell also runs on a fresh kernel when the (slow, WandB-logging) training cell is skipped.
filename = section_dir / "trigram_model.pt"

# These architecture values mirror the ones passed to `TrainArgs` in the training cell; the
# saved state dict is loaded into a model created with the same configuration.
model = create_model(
    d_vocab=75,
    seq_len=50,
    d_model=32,
    d_head=24,
    n_layers=1,
    n_heads=1,
    d_mlp=20,
    normalization_type=None,
    seed=40,
    device=device,
)

# `map_location` remaps the stored tensors onto the current device, so a checkpoint that was
# written on a GPU machine can still be loaded on a CPU-only one.
# `weights_only=True` restricts unpickling to tensors / plain containers.
model.load_state_dict(t.load(filename, weights_only=True, map_location=device))


## Testing

In [None]:
# Define some basic stuff

BIGRAM_PROB = 0.05  # passed as `trigram_prob`; same value the model was trained with
BATCH_SIZE = 2500

# Fresh evaluation dataset whose vocab size and sequence length are read off the model config.
dataset = BigramDataset(
    size=BATCH_SIZE,
    d_vocab=model.cfg.d_vocab,
    seq_len=model.cfg.n_ctx,
    trigram_prob=BIGRAM_PROB,
    device=device,
    seed=40,
)

logits, cache = model.run_with_cache(dataset.toks)

# Drop the final position: it has no following token to be scored against.
logprobs = logits[:, :-1].log_softmax(-1)
# `logprobs` is already normalised, so exponentiating recovers the probabilities directly
# (no need for a second softmax).
probs = logprobs.exp()

# Position i is scored against the token at position i + 1.
targets = dataset.toks[:, 1:]
logprobs_correct = eindex(logprobs, targets, "batch seq [batch seq]")
probs_correct = eindex(probs, targets, "batch seq [batch seq]")

print(f"Average cross entropy loss: {-logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {probs_correct.mean().item():.3f}")
print(f"Median probability on correct label: {probs_correct.median().item():.3f}")
print(f"Min probability on correct label: {probs_correct.min().item():.3f}")

# Only the first 50 sequences are plotted, to keep the heatmap readable.
imshow(probs_correct[:50], width=600, height=600, title="Sample model probabilities")

# Observation: they're mostly 1/d_vocab except for the trigrams which are 1, which is what we expect
