# Train an LLM on the Tiny Shakespeare dataset.


In [None]:
import os
from pathlib import Path
from importlib import reload
import requests
from rgi.rgizero import common

# Directory holding this notebook's dataset files (the raw text is stored
# under it as "raw_text.txt").
# NOTE(review): presumably common.data_dir() also creates the directory — confirm.
DATA_DIR = common.data_dir("shakespeare-char")

In [None]:
# Download the raw training text once and cache it under DATA_DIR.
raw_text_path = DATA_DIR / "raw_text.txt"
if not raw_text_path.exists():
    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    # Fetch and validate the response BEFORE touching the file: otherwise a
    # failed request leaves an empty file behind and every later run skips
    # the download because the path already exists.
    response = requests.get(data_url, timeout=30)
    # Fail loudly on HTTP errors instead of caching an error page as data.
    response.raise_for_status()
    # Write as UTF-8 explicitly because the file is read back as UTF-8 later.
    with open(raw_text_path, "w", encoding="utf-8") as f:
        f.write(response.text)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

def split_dataset(dataset: Dataset, train_split: float = 0.9) -> tuple[Dataset, Dataset]:
    """Split dataset into contiguous train and validation subsets.

    The first ``int(train_split * len(dataset))`` items become the training
    set and the remaining items become the validation set. Item order is
    preserved; nothing is shuffled.

    Args:
        dataset: Any sized, indexable dataset.
        train_split: Fraction of the items assigned to the training set;
            must lie in the closed interval [0, 1].

    Returns:
        A ``(train_dataset, validation_dataset)`` pair of ``Subset`` views
        over ``dataset``.

    Raises:
        ValueError: If ``train_split`` is outside [0, 1] (including NaN).
    """
    # Out-of-range fractions would otherwise produce indices past the end of
    # the dataset (or negative, wrapping ones) that only fail on access.
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be between 0 and 1, got {train_split!r}")

    total = len(dataset)
    train_size = int(train_split * total)
    # range objects are valid index sequences for Subset and avoid
    # materialising one Python int per dataset item.
    train_dataset = torch.utils.data.Subset(dataset, range(train_size))
    validation_dataset = torch.utils.data.Subset(dataset, range(train_size, total))
    return train_dataset, validation_dataset

In [None]:
from rgi.rgizero.data.text_dataset import SimpleTextDataset

# Sequence length handed to SimpleTextDataset (also used as the model's
# n_max_context when the model is built).
BLOCK_SIZE = 128
BATCH_SIZE = 16
NUM_DATALOADER_WORKERS = 0
# Prefer CUDA, but fall back to CPU so the notebook still runs on machines
# without a usable GPU instead of failing at the first device transfer.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Read the cached corpus; the encoding is explicit so the result does not
# depend on the platform's default.
with open(raw_text_path, 'r', encoding='utf-8') as f:
    raw_text = f.read()

text_dataset = SimpleTextDataset(raw_text, BLOCK_SIZE, device=DEVICE)
vocab_size = text_dataset.vocab_size
# Contiguous 90/10 split: the leading part trains, the trailing part validates.
train_dataset, val_dataset = split_dataset(text_dataset, 0.9)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_DATALOADER_WORKERS, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_DATALOADER_WORKERS, shuffle=False)

In [None]:
from rgi.rgizero.train import Trainer, TrainConfig
from rgi.rgizero.models.transformer import TransformerConfig
from rgi.rgizero.models.token_transformer import TokenTransformer


# Model size presets. Exactly one should be active; the other is kept as a
# commented-out alternative.
#   femto: TransformerConfig(n_max_context=BLOCK_SIZE, n_embd=128, n_layer=4, n_head=4, dropout=0.2)
model_config = TransformerConfig(  # baby
    n_max_context=BLOCK_SIZE,
    n_embd=384,
    n_layer=6,
    n_head=6,
    dropout=0.2,
)

# Build the token-level transformer for the dataset's vocabulary, then move
# it to the selected device.
model = TokenTransformer(model_config, vocab_size)
model.to(DEVICE)

In [None]:
import copy

# TrainConfig with overrides based on train_shakespeare_char.py
base_shakespeare_train_config = TrainConfig(
    model_name="shakespeare-gpt",
    model_version="v1",

    eval_interval=250,  # keep frequent because we'll overfit
    eval_iters=200,
    log_interval=10,  # don't print too too often

    # we expect to overfit on this small dataset, so only save when val improves
    always_save_checkpoint=False,

    gradient_accumulation_steps=1,
    batch_size=64,

    learning_rate=1e-3,  # with baby networks can afford to go a bit higher
    max_iters=5000,
    lr_decay_iters=5000,  # make equal to max_iters usually
    min_lr=1e-4,  # learning_rate / 10 usually
    beta2=0.99,  # make a bit bigger because number of tokens per iter is small

    warmup_iters=100,  # not super necessary potentially
)

# Derived presets are built with copy.replace(), the public entry point for
# the __replace__ protocol, instead of invoking the dunder method directly.
# Each call returns a new config with the given fields overridden.
baby_shakespeare_train_config = copy.replace(
    base_shakespeare_train_config,
    max_iters=5000,
    eval_interval=50,
    eval_iters=10,
    log_interval=10,
    batch_size=16,
)

femto_shakespeare_train_config = copy.replace(
    baby_shakespeare_train_config,
    # femto GPT model, much smaller than baby model.
    warmup_iters=10,
    max_iters=100,
    compile=False,
)

In [None]:
# Create trainer
# Wires the model, the selected training config and both data loaders into a
# Trainer, then runs training either plainly or under one of two profilers.
# NOTE(review): the femto training preset is selected here while the model
# cell builds the "baby" TransformerConfig — confirm this pairing is intended.
train_config = femto_shakespeare_train_config
trainer = Trainer(
    model=model,
    train_config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    device=DEVICE
)

# Profiling mode for the training run below:
#   None    -> run trainer.train() without profiling
#   "prun"  -> run it under IPython's %prun magic
#   "torch" -> run it under torch.profiler
PROFILER = None # "prun"  # None, "prun", "torch"

if PROFILER is None:
    trainer.train()
elif PROFILER == "prun":
    # IPython-only syntax. Per the %prun docs: -r returns the stats object
    # (kept in _pstats), -l 30 limits the printed report, and -s cumulative
    # sorts by cumulative time.
    _pstats = %prun -r -l 30 -s cumulative trainer.train()
    # _pstats.print_stats('_nn')
elif PROFILER == "torch":
    from torch.profiler import profile, record_function, ProfilerActivity
    # Record both CPU and CUDA activity, with input shapes, memory usage and
    # stack traces enabled.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        # The "training_step" label wraps the entire trainer.train() call,
        # not an individual optimisation step.
        with record_function("training_step"):
            trainer.train()
    
    # NOTE(review): the slowness is presumably due to aggregating events from
    # the whole training run with stacks enabled — unconfirmed.
    print("\nPreparing profiler output ...")  # why is this so slow?
    # Show the ten entries with the largest total CUDA time.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# Sample from the model
# Switch to evaluation mode, then generate 50 new tokens from a prompt
# consisting of a single newline and print the decoded text.
model.eval()
context = "\n"
encoded_context = text_dataset.encode(context)
with torch.no_grad():
    # Wrap the encoded prompt in a list to form a batch of one sequence.
    x = torch.tensor([encoded_context], dtype=torch.long, device=DEVICE)
    y = model.generate(x, max_new_tokens=50)
# Decode the first (and only) sequence in the generated batch.
print(text_dataset.decode(y[0].tolist()))