# Character-level nanoGPT

This is an experiment in training a tiny transformer on character-level data. It's based on a port of nanoGPT; [see the model docs for details](src/experiment/model/README.md).

Initially, this experiment was run in a Kaggle notebook. Porting it to run in Modal made the notebook hard to work with: the port itself didn't add much complexity, but a few things needed to be refactored to run well remotely. So now most of the code lives in modules under [`src/experiment`](src/experiment), and this notebook just ties it together in a way that makes it easy to play with.

In [None]:
import logging

from utils.logging import concise_logging


def configure_logging():
    concise_logging()
    logging.getLogger('experiment').setLevel(logging.INFO)
    logging.getLogger('subline').setLevel(logging.INFO)
    logging.getLogger('utils').setLevel(logging.INFO)
    logging.getLogger('mini').setLevel(logging.INFO)


# Configure local logging. This also needs to be done in remote functions.
configure_logging()

In [None]:
from typing import Any
from torch.nn import CrossEntropyLoss

from experiment.config import (
    AMPConfig,
    DataConfig,
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainingConfig,
)

config = TrainingConfig(
    model=ModelConfig(
        vocab_size=64,  # dummy value, will be set after loading the dataset
        block_size=512,
        n_embd=32,
        n_head=8,
        n_head_dim=8,
        n_ff=128,
        n_layer=15,
        dropout=0.2,
    ),
    tokenizer=TokenizerConfig(
        vocabulary=[],  # dummy value, will be set after loading the dataset
    ),
    data=DataConfig(
        batch_size=16,
        oversample=1,
        train_split=0.8,
    ),
    optimizer=OptimizerConfig(
        weight_decay=1e-1,
        learning_rate=0,
        betas=(0.9, 0.95),
    ),
    scheduler=SchedulerConfig(
        epochs=100,
        warmup_epochs=1,
        min_lr_factor=0.1,
        # decay_strategy='cosine',
    ),
    amp=AMPConfig(
        enabled=True,
    ),
)

criterion = CrossEntropyLoss()

# Default parameters for the @app.function decorator
resource_limits: dict[str, Any] = dict(buffer_containers=0, max_containers=1)

Prepare for remote execution. Technically we don't need to pre-build the image, but doing so makes the output of later cells cleaner.

In [None]:
import modal
from experiment.compute.app import data_dir
from utils.requirements import freeze
from mini.experiment import Experiment

exp = Experiment('nanogpt-rope')
exp.image = (
    modal.Image.debian_slim()
    .pip_install(*freeze(all=True, local=False))
    .add_local_python_source('experiment', 'utils', 'mini')
)  # fmt: skip
exp.volumes[data_dir.as_posix()] = volume = modal.Volume.from_name(
    'nanogpt-rope-data', create_if_missing=True
)  # fmt: skip


@exp.run_thither(**resource_limits)
async def prebuild():
    """Forces the image to be built."""
    pass


async with exp():
    await prebuild()

# Data

We'll grab a small dataset. It's just one big block of text from which we take random substrings. These may overlap, but we aim to take roughly the entire corpus on each epoch.

Of note: the "labels" $y$ are the same as the input values $x$, shifted by one, since we want to predict each next token.

```python
x = self.data[idx:idx + self.block_size]
y = self.data[idx + 1:idx + self.block_size + 1]
```
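
To make the shift concrete, here is a minimal, self-contained sketch that uses a plain Python list in place of the tensor held by the real dataset. The names `sample_pair` and `samples_per_epoch` are hypothetical and not part of `experiment`; the actual dataset class may pick offsets and epoch sizes differently.

```python
import random


def sample_pair(data: list[int], block_size: int, rng: random.Random) -> tuple[list[int], list[int]]:
    """Pick a random window; y is x shifted one position to the right."""
    # The last valid start leaves room for the extra label token at the end.
    idx = rng.randrange(len(data) - block_size)
    x = data[idx : idx + block_size]
    y = data[idx + 1 : idx + block_size + 1]
    return x, y


def samples_per_epoch(n_tokens: int, block_size: int) -> int:
    """Number of windows that together add up to roughly one pass over the corpus."""
    return max(1, n_tokens // block_size)


# Toy corpus: character codes stand in for token IDs (an assumption for this sketch).
data = [ord(c) for c in 'We seek him here, we seek him there']
x, y = sample_pair(data, block_size=8, rng=random.Random(0))
```

Every element of `y` is the token that follows the corresponding element of `x`, so `x[1:] == y[:-1]` holds for any window.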

In [None]:
from experiment.config import DatasetMetadata
from utils.param_types import validate_call


@validate_call
def download_the_scarlet_pimpernel() -> tuple[str, DatasetMetadata]:
    import ftfy
    import requests

    url = 'https://www.gutenberg.org/cache/epub/60/pg60.txt'
    response = requests.get(url)
    response.raise_for_status()
    text = response.text.replace('\r\n', '\n')
    text = text[text.find('\nCHAPTER I.') : text.rfind('*** END OF THE PROJECT GUTENBERG EBOOK')].strip()
    # Normalize text to avoid weird quotation marks etc.
    text, explanation = ftfy.fix_and_explain(text)
    metadata = DatasetMetadata(
        title='The Scarlet Pimpernel',
        url=url,
        fixes=explanation or [],
        total_chars=len(text),
    )
    return text, metadata

In [None]:
from experiment.utils import align


@exp.run_thither(**resource_limits)
async def prepare_data():
    from experiment.compute.data_pipelines import save_data
    from experiment.data.preparation import tokenize_data

    configure_logging()

    sources = [
        download_the_scarlet_pimpernel(),
    ]
    data, metadata = tokenize_data(sources)
    save_data(data, metadata)
    volume.commit()
    return metadata


async with exp():
    input_metadata = await prepare_data()

config.tokenizer = input_metadata.tokenizer_config.model_copy()
config.model.vocab_size = align(config.tokenizer.vocab_size, 64)
input_metadata.model_dump(exclude={'tokenizer_config'})

# Training

In [None]:
from contextlib import asynccontextmanager
from mini.hither import Callback, run_hither
from mini.utils import coerce_to_async
from utils.lr_finder.types import LRFinderConfig, LRFinderSeries, Progress
from utils.param_types import validate_call


@exp.run_thither(gpu='L4', **resource_limits)
async def find_learning_rate(plot: Callback[LRFinderConfig | LRFinderSeries], prog: Callback[Progress]):
    import torch
    from experiment.compute.data_pipelines import load_data
    from experiment.data.dataloader import get_dataloader
    from experiment.model.gpt import GPT
    from experiment.training.optimizer import configure_optimizer
    from utils.lr_finder.lr_finder import lr_finder_search
    from utils.torch.mixed_precision import AMPContext
    from utils.torch.types import get_device

    configure_logging()

    model: torch.nn.Module = GPT(config.model)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = configure_optimizer(model, config.optimizer)
    data, _ = load_data()

    if torch.cuda.is_available():
        data = data.cuda()
        model = model.cuda()
        criterion = criterion.cuda()

    train_loader, _ = get_dataloader(data, model_config=config.model, data_config=config.data)
    amp_context = AMPContext(use_amp=config.amp.enabled, device_type=get_device(model), dtype=config.amp.dtype)

    for event in lr_finder_search(model, criterion, optimizer, train_loader, amp_context=amp_context):
        match event:
            case LRFinderConfig() | LRFinderSeries() as update:
                plot(update)
            case Progress() as progress:
                prog(progress)
            case float() as suggested_lr:
                return suggested_lr
    raise ValueError('No suggested learning rate found.')


@asynccontextmanager
async def progress():
    from utils.nb import displayer

    total_steps = 1

    display = displayer()

    async def _progress(event: Progress):
        nonlocal total_steps
        if event.total_steps:
            total_steps = event.total_steps
        suffix = f' - {event.info}' if event.info else ''
        if event.step:
            fraction = event.step / total_steps
            display(f'Progress: {fraction:.0%}{suffix}')

    yield _progress


@asynccontextmanager
async def plotter():
    from utils.lr_finder.vis import lr_finder_plot

    yield coerce_to_async(lr_finder_plot())


async with exp(), run_hither(progress) as prog, run_hither(plotter) as plot:
    suggested_lr = await find_learning_rate(plot, prog)

print(f'Suggested Learning Rate: {suggested_lr:.2e}')
config.optimizer.learning_rate = suggested_lr

In [None]:
from experiment.compute.model import save_checkpoint
from experiment.compute.training import TrainingEvent, train_model
from utils.time import duration


@exp.run_thither(gpu='L4', timeout=int(duration('20 min')), **resource_limits)
async def train(prog_cb: Callback[TrainingEvent]):
    configure_logging()

    for event in train_model(config):
        match event:
            case 'checkpoint', (model, cfg, metrics):
                save_checkpoint(model, cfg, metrics)
                volume.commit()
            case _:
                prog_cb(event)


@asynccontextmanager
async def progress():
    from tqdm.auto import tqdm

    with tqdm(total=0, desc='Epoch') as pb_epoch, tqdm(total=0, desc='Step', leave=False) as pb_step:

        async def progress_callback(event: TrainingEvent):
            match event:
                case 'epochs', total:
                    pb_epoch.total = total
                    pb_epoch.refresh()
                case 'steps-per-epoch', total:
                    pb_step.total = total
                    pb_step.refresh()
                case ('train-step', n) | ('val-step', n):
                    pb_step.update(n)
                case 'epoch-end', metrics:
                    pb_epoch.set_postfix(metrics.model_dump())
                    pb_epoch.update(1)
                    pb_step.reset()
                    # plot(metrics.val_loss)

        yield progress_callback


async with run_hither(progress) as prog_cb, exp():
    await train(prog_cb)

# Generate continuations

In [None]:
from pydantic import NonNegativeFloat, PositiveInt
from experiment.model.gpt import Generation


@exp.run_thither(gpu='T4', timeout=int(duration('1 min')), **resource_limits)
@validate_call
async def generate(
    prompts: list[str],
    max_new_tokens: PositiveInt,
    temperature: NonNegativeFloat,
) -> tuple[list[list[str]], Generation]:
    from typing import cast
    import torch

    from experiment.compute.model import load_checkpoint
    from experiment.data.tokenizer import CharTokenizer

    configure_logging()

    print('Loading model from checkpoint')
    model, config, _ = load_checkpoint()
    model.eval()
    tokenizer = CharTokenizer(config.tokenizer)
    context = torch.tensor(tokenizer.encode(prompts), dtype=torch.long)
    if torch.cuda.is_available():
        context = context.cuda()
        model = model.cuda()

    print(f'Generating {max_new_tokens} tokens with temperature {temperature}')
    output = model.generate(context, max_new_tokens=max_new_tokens, temperature=temperature)
    # output.tokens is 2D (one row per prompt), so len() counts sequences, not tokens.
    print(f'Generated {len(output.tokens)} sequence(s)')

    toks = cast(list[list[int]], output.tokens.tolist())
    return tokenizer.decode_each(toks), output

In [None]:
prompts = ["Odd's fish m'dear,"]

async with exp():
    continuations, metadata = await generate(prompts=prompts, max_new_tokens=500, temperature=0.5)

for sequence in continuations:
    print(''.join(sequence[:40]), '(etc.)')

### Token metrics: Perplexity and entropy

Let's visualize the generation along with some metrics:
* **Entropy** is how diffuse the probability distribution over the next token is, i.e. the spread of probabilities before sampling. It can be thought of as how uncertain the model is about what to say next. It usually isn't calculated for prompt tokens.
* **Perplexity** is how unlikely the actual next token is: the exponential of that token's cross-entropy loss. It can be thought of as how surprised the model is by the presence of the token at this point in the sequence. It is usually calculated for prompt tokens, but can also be calculated for continuation tokens.

```python
# For entropy (uncertainty about next token):
probs = F.softmax(next_token_logits, dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)

# For perplexity (surprise about actual token):
token_loss = F.cross_entropy(next_token_logits, idx_next.view(-1), reduction='none')
perplexities[curr_len] = torch.exp(token_loss)
```

Notably, _the entropy of continuation tokens is unaffected by temperature_, whereas the perplexity _is_ affected (because it's calculated after sampling).

We should expect them to be correlated, because when the model is very certain (low entropy), it's more likely to sample a high-probability token (low perplexity), and vice versa. But they can diverge in interesting ways: you could have high entropy but still sample a high-probability token by chance (high entropy, low perplexity), or you could have low entropy but sample an unlikely token due to temperature (low entropy, high perplexity).
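
These cases can be checked by hand on a toy distribution. The following is a plain-Python sketch for illustration only, not the model's code: `softmax`, `entropy`, and `perplexity` are hypothetical helpers, and the logits are made up.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: list[float]) -> float:
    """Shannon entropy in nats: how spread out the distribution is."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def perplexity(probs: list[float], token: int) -> float:
    """exp(surprisal) of one token: how unexpected that particular choice was."""
    return math.exp(-math.log(probs[token]))


confident = softmax([8.0, 0.0, 0.0, 0.0])  # strongly prefers token 0
unsure = softmax([0.0, 0.0, 0.0, 0.0])  # no preference at all

# Low entropy, low perplexity: the likely token was sampled.
print(entropy(confident), perplexity(confident, 0))
# Low entropy, high perplexity: an unlikely token was sampled anyway.
print(entropy(confident), perplexity(confident, 1))
# High entropy: every token is equally unsurprising, whichever one is sampled.
print(entropy(unsure), perplexity(unsure, 0))
```

For a uniform distribution over $V$ tokens, the entropy is $\ln V$ and the per-token perplexity is $V$, which gives a useful upper reference point when reading the plots.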

In [None]:
from experiment.model.gpt import SingleGeneration


def annotate_tokens(tokens: list[str], metadata: SingleGeneration):
    from subline.series import Series
    from subline.subline import Subline
    from IPython.display import SVG, display

    viz = Subline(chars_per_line=80)
    svg = viz.plot(
        tokens,
        [
            # EntropySeries(metadata.entropy, label='Entropy', vocab_size=metadata.vocab_size),
            # EntropySeries(metadata.surprisal, label='Surprisal', vocab_size=metadata.vocab_size),
            Series(metadata.surprise_surprise, label='S₂'),
            Series(-metadata.surprise_surprise, label='-S₂', dasharray='1'),
        ],
    )
    display(SVG(svg))


annotate_tokens(continuations[0], metadata[0])

# TODO: Is the quality worse now, using fp16? It *seems* worse.

# Future research

### Temperature as a resource

It's interesting to note that the first character of each word is high entropy and high perplexity, and subsequent characters are lower. And where the model makes spelling mistakes, it often had low entropy but then high perplexity, which suggests that it knows what it wants to write, but the sampling mechanism is messing it up.

I'd like to see whether lowering the temperature after the first letter of each word improves the output. How would that generalise to languages that don't use spaces? Perhaps the temperature could be a resource that gets used (by picking unlikely tokens) and gradually replenished (by picking likely ones). Yeah, that could be really cool!
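
One way the budget idea might look, as an untested sketch rather than anything implemented in this repo: a hypothetical `TemperatureBudget` that spends budget in proportion to the surprisal of each sampled token and regains a fixed amount per step. The class, the `sample_next` helper, and the `recovery` and `cost` values are all assumptions chosen for illustration. It doesn't depend on word boundaries, so it might sidestep the question about languages without spaces.

```python
import math
import random
from dataclasses import dataclass


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


@dataclass
class TemperatureBudget:
    """Hypothetical controller: surprising picks spend budget, expected picks refill it."""

    min_temperature: float = 0.1
    max_temperature: float = 1.0
    recovery: float = 0.1  # budget regained per step (assumed value)
    cost: float = 0.2  # budget spent per nat of surprisal (assumed value)
    budget: float = 1.0  # fraction of the budget remaining, in [0, 1]

    @property
    def temperature(self) -> float:
        span = self.max_temperature - self.min_temperature
        return self.min_temperature + span * self.budget

    def update(self, prob_of_choice: float) -> None:
        surprisal = -math.log(max(prob_of_choice, 1e-12))
        change = self.recovery - self.cost * surprisal
        self.budget = min(1.0, max(0.0, self.budget + change))


def sample_next(logits: list[float], state: TemperatureBudget, rng: random.Random) -> int:
    probs = softmax(logits, state.temperature)
    token = rng.choices(range(len(logits)), weights=probs)[0]
    # Charge the budget using the unscaled distribution, so the cost reflects
    # how unlikely the model itself considered the token.
    state.update(softmax(logits)[token])
    return token


state = TemperatureBudget()
rng = random.Random(0)
logits = [2.0, 1.0, 0.0, -1.0]  # made-up logits, reused at every step
tokens = [sample_next(logits, state, rng) for _ in range(20)]
```

With these assumed constants, a token needs a probability above roughly 0.6 (surprisal below `recovery / cost` = 0.5 nats) for the budget to grow; anything less likely drains it and pushes the next step's temperature down.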

# References

Karpathy, A. (2022). nanoGPT [Computer software]. GitHub. https://github.com/karpathy/nanoGPT

Sanderson, G. (2024a). Visualizing attention, a transformer's heart. 3Blue1Brown. https://www.3blue1brown.com/lessons/attention

Sanderson, G. (2024b). How might LLMs store facts. 3Blue1Brown. https://www.3blue1brown.com/lessons/mlp

# Software Licenses

The code in this notebook is derived from nanoGPT (Karpathy, 2022), which is licensed under the MIT License, Copyright (c) 2022 Andrej Karpathy.