# 07 - Supervised Fine-Tuning

This notebook fine-tunes the pretrained ShrayGPT model from notebook 06 into an instruction-following assistant using supervised fine-tuning (SFT). We will:

1. Extend the `r50k_base` tokenizer with conversation-specific special tokens.
2. Build an instruction dataset by blending Alpaca, Dolly v2, and UltraChat samples.
3. Tokenize the conversations and pack them into fixed-length 1,024-token training chunks (shorter than the model's 4,096-token context window).
4. Fine-tune the base checkpoint with PyTorch Lightning and sample the resulting instruct model.

In [None]:
import random
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import lightning as L
import torch
import torch.nn as nn
from datasets import interleave_datasets, load_dataset
from torch.utils.data import DataLoader, Dataset
import tiktoken

from src.shraygpt import ShrayGPT

torch.set_num_threads(1)

SPECIAL_TOKENS = ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"]

# Extend the r50k_base (GPT-2-style BPE) tokenizer with the chat-template tokens used below.
# tiktoken encodings cannot be mutated in place (there is no `merge_special_tokens`), so build a
# new Encoding that reuses the base regex and merge ranks. The new specials get the ids
# immediately after the existing vocabulary (base.n_vocab, base.n_vocab + 1, ...).
_base_encoding = tiktoken.get_encoding("r50k_base")
tokenizer = tiktoken.Encoding(
    name="r50k_base_chat",
    pat_str=_base_encoding._pat_str,
    mergeable_ranks=_base_encoding._mergeable_ranks,
    special_tokens={
        **_base_encoding._special_tokens,
        **{tok: _base_encoding.n_vocab + offset for offset, tok in enumerate(SPECIAL_TOKENS)},
    },
)
# Only the chat markers may be turned into special ids when encoding formatted conversations.
ALLOWED_SPECIAL = set(SPECIAL_TOKENS)
SPECIAL_TOKEN_IDS: Dict[str, int] = {
    tok: tokenizer.encode(tok, allowed_special={tok})[0] for tok in SPECIAL_TOKENS
}

print(f"Tokenizer vocab size after adding specials: {tokenizer.n_vocab}")

In [None]:
BASE_CHECKPOINT_PATH = "checkpoints/shraygpt-pretrain.ckpt"

# Constructor kwargs used by load_base_model when the checkpoint above is missing.
# vocab_size already counts the chat special tokens, so a freshly built model needs no resize.
BASE_CONFIG = {
    "vocab_size": tokenizer.n_vocab,
    "block_size": 4096,
    "d_model": 1024,
    "n_head": 16,
    "d_head": 64,
    "n_layers": 16,
    "num_experts": 4,
    "num_experts_per_tok": 2,
    "dropout": 0.0,
}


def load_base_model() -> ShrayGPT:
    '''Return the pretrained ShrayGPT from notebook 06, or a fresh one if no checkpoint exists.

    The checkpoint at BASE_CHECKPOINT_PATH is loaded with ``map_location="cpu"``. If the file is
    missing (FileNotFoundError), a new model is built from BASE_CONFIG instead.
    '''
    try:
        pretrained = ShrayGPT.load_from_checkpoint(BASE_CHECKPOINT_PATH, map_location="cpu")
    except FileNotFoundError:
        print("Checkpoint not found. Initializing a fresh model with base configuration.")
        return ShrayGPT(**BASE_CONFIG)

    print(f"Loaded checkpoint from {BASE_CHECKPOINT_PATH}")
    return pretrained


def resize_token_embeddings(model: "ShrayGPT", new_vocab_size: int) -> None:
    '''Grow ``model.tok_emb`` and ``model.head`` in place to ``new_vocab_size`` rows.

    Rows for existing token ids are copied unchanged. Rows for the new ids are drawn from
    N(0, 0.02). A head bias, if present, is zero-extended. Shrinking is not supported: when
    ``new_vocab_size`` is not larger than the current vocabulary, the model is left untouched.

    The rebuilt layers keep the original parameter dtype and device. If the head weight was
    tied to (the same Parameter as) the embedding weight, it stays tied. When the model has a
    ``vocab_size`` hyperparameter, it is updated too. Lightning's ``load_from_checkpoint``
    re-instantiates the module from its saved hyperparameters, so checkpoints written after
    resizing need the enlarged value for their weight shapes to match.

    Args:
        model: module exposing ``tok_emb`` (an Embedding) and ``head`` (a Linear to vocab logits).
        new_vocab_size: target number of rows (token ids) for both layers.
    '''
    old_emb = model.tok_emb
    old_head = model.head
    old_vocab_size = old_emb.weight.size(0)
    if new_vocab_size <= old_vocab_size:
        return

    d_model = old_emb.weight.size(1)
    # Build the replacements with the same device AND dtype (e.g. a bf16 model stays bf16).
    factory = {"device": old_emb.weight.device, "dtype": old_emb.weight.dtype}
    is_tied = old_head.weight is old_emb.weight
    has_bias = getattr(old_head, "bias", None) is not None

    new_tok_emb = nn.Embedding(new_vocab_size, d_model, **factory)
    new_head = nn.Linear(d_model, new_vocab_size, bias=has_bias, **factory)

    with torch.no_grad():
        new_tok_emb.weight[:old_vocab_size] = old_emb.weight
        nn.init.normal_(new_tok_emb.weight[old_vocab_size:], mean=0.0, std=0.02)

        if is_tied:
            # Re-share a single Parameter so the head keeps following the embedding.
            new_head.weight = new_tok_emb.weight
        else:
            new_head.weight[:old_vocab_size] = old_head.weight
            nn.init.normal_(new_head.weight[old_vocab_size:], mean=0.0, std=0.02)

        if has_bias:
            new_head.bias[:old_vocab_size] = old_head.bias
            new_head.bias[old_vocab_size:].zero_()

    model.tok_emb = new_tok_emb
    model.head = new_head

    # Keep the saved hyperparameters consistent with the new weight shapes.
    hparams = getattr(model, "hparams", None)
    if hparams is not None and "vocab_size" in hparams:
        hparams["vocab_size"] = new_vocab_size

    print(f"Resized embeddings from {old_vocab_size} to {new_vocab_size}")


model = load_base_model()
resize_token_embeddings(model, tokenizer.n_vocab)

# Instruct tuning uses lower learning rates than pretraining. These values are written onto the
# model's hparams (presumably read by ShrayGPT's optimizer setup — not visible here).
SFT_HPARAMS = {
    "learning_rate_adamw": 1e-4,
    "learning_rate_muon": 5e-4,
    "aux_loss_weight": 5e-4,
}
for hparam_name, hparam_value in SFT_HPARAMS.items():
    setattr(model.hparams, hparam_name, hparam_value)

# Keep manual optimization: Lightning will not call optimizer.step() on the model's behalf.
model.automatic_optimization = False

In [None]:
# Chat-template markers; each must be one of the special tokens registered on the tokenizer.
SYSTEM_TOKEN, USER_TOKEN, ASSISTANT_TOKEN, END_TOKEN = "<|system|>", "<|user|>", "<|assistant|>", "<|end|>"

# Role name -> marker that opens that role's turn (each turn is closed with END_TOKEN).
ROLE_TOKENS = dict(system=SYSTEM_TOKEN, user=USER_TOKEN, assistant=ASSISTANT_TOKEN)

DEFAULT_SYSTEM_PROMPT = "You are ShrayGPT, a helpful and concise AI assistant."


def format_conversation(messages: Sequence[Tuple[str, str]], default_system: str = DEFAULT_SYSTEM_PROMPT) -> Optional[str]:
    '''Convert a list of (role, content) pairs into a single conversation string.

    Each kept turn is rendered as ``<role token><stripped content><|end|>``, and the turns are
    concatenated in their original order. Turns whose role is not in ROLE_TOKENS, or whose
    content is blank, are dropped.

    If no non-blank system turn survives, the stripped ``default_system`` is prepended as the
    system turn, unless it is blank itself.

    Returns:
        The formatted conversation, or None when nothing is left to render.
    '''
    turns: List[Tuple[str, str]] = [
        (ROLE_TOKENS[role], content.strip())
        for role, content in messages
        if role in ROLE_TOKENS and content.strip()
    ]

    # Decide on the default from the turns that will actually be emitted. A blank system
    # message is dropped above, so it must not suppress the default system prompt.
    has_system = any(token == SYSTEM_TOKEN for token, _ in turns)
    default_text = default_system.strip()
    if not has_system and default_text:
        turns.insert(0, (SYSTEM_TOKEN, default_text))

    if not turns:
        return None
    return "".join(f"{token}{text}{END_TOKEN}" for token, text in turns)


def preprocess_alpaca(example: Dict[str, str]) -> Dict[str, Optional[str]]:
    '''Turn one Alpaca row into ``{"text": <formatted conversation>}``.

    The optional ``input`` field is appended to the instruction after a blank line. Rows with an
    empty instruction or an empty reference output map to ``{"text": None}``, so the downstream
    ``filter`` drops them instead of training on a prompt with no assistant answer.
    '''
    # `or ""` tolerates None values in the text fields.
    instruction = (example["instruction"] or "").strip()
    input_text = (example.get("input") or "").strip()
    response = (example["output"] or "").strip()

    if not instruction or not response:
        return {"text": None}

    user_prompt = f"{instruction}\n\n{input_text}" if input_text else instruction

    conversation = [
        ("system", DEFAULT_SYSTEM_PROMPT),
        ("user", user_prompt),
        ("assistant", response),
    ]
    return {"text": format_conversation(conversation)}


def preprocess_dolly(example: Dict[str, str]) -> Dict[str, Optional[str]]:
    '''Turn one Dolly v2 row into ``{"text": <formatted conversation>}``.

    A non-empty ``context`` is appended to the instruction under a ``Context:`` header. Rows
    with an empty instruction or an empty response map to ``{"text": None}``, so the downstream
    ``filter`` drops them instead of training on a prompt with no assistant answer.
    '''
    # `or ""` tolerates None values in the text fields.
    instruction = (example["instruction"] or "").strip()
    context = (example.get("context") or "").strip()
    response = (example["response"] or "").strip()

    if not instruction or not response:
        return {"text": None}

    user_prompt = f"{instruction}\n\nContext:\n{context}" if context else instruction

    conversation = [
        ("system", DEFAULT_SYSTEM_PROMPT),
        ("user", user_prompt),
        ("assistant", response),
    ]
    return {"text": format_conversation(conversation)}


def preprocess_ultrachat(example: Dict[str, List[Dict[str, str]]]) -> Dict[str, Optional[str]]:
    '''Turn one UltraChat row into ``{"text": <formatted conversation>}``.

    Messages with an unknown role or blank content are dropped. A conversation is kept only if
    something remains and the last remaining turn is the assistant's; otherwise
    ``{"text": None}`` is returned for the downstream filter to remove.
    '''
    candidates = [
        (message.get("role"), message.get("content", "").strip())
        for message in example["messages"]
    ]
    turns = [(role, text) for role, text in candidates if role in ROLE_TOKENS and text]

    ends_with_assistant = bool(turns) and turns[-1][0] == "assistant"
    if not ends_with_assistant:
        return {"text": None}
    return {"text": format_conversation(turns)}


alpaca = load_dataset("yahma/alpaca-cleaned", split="train")
dolly = load_dataset("databricks/databricks-dolly-15k", split="train")
ultrachat = load_dataset("HuggingFaceH4/ultrachat_200k", "messages", split="train_sft")


def _to_sft_text(raw, preprocess):
    '''Map ``raw`` to a single ``text`` column with ``preprocess``, then drop rows it rejected (None).'''
    mapped = raw.map(preprocess, remove_columns=raw.column_names)
    return mapped.filter(lambda ex: ex["text"] is not None)


alpaca_sft = _to_sft_text(alpaca, preprocess_alpaca)
dolly_sft = _to_sft_text(dolly, preprocess_dolly)
ultrachat_sft = _to_sft_text(ultrachat, preprocess_ultrachat)

# Mix the sources 40/30/30 (seeded), then hold out 2% of the mixture for validation.
combined = interleave_datasets(
    [alpaca_sft, dolly_sft, ultrachat_sft],
    probabilities=[0.4, 0.3, 0.3],
    seed=42,
)
train_val = combined.train_test_split(test_size=0.02, seed=42)
train_texts = train_val["train"].shuffle(seed=42)
val_texts = train_val["test"].shuffle(seed=1234)

# Cap both splits so the run stays short.
MAX_TRAIN, MAX_VAL = 20000, 512
if train_texts.num_rows > MAX_TRAIN:
    train_texts = train_texts.select(range(MAX_TRAIN))
if val_texts.num_rows > MAX_VAL:
    val_texts = val_texts.select(range(MAX_VAL))

for label, subset in (("Training", train_texts), ("Validation", val_texts)):
    print(f"{label} examples: {subset.num_rows}")

In [None]:
# Packed window length for SFT; deliberately shorter than the 4096-token pretraining context.
BLOCK_SIZE = 1_024
# Token id of the end-of-turn marker.
END_ID: int = SPECIAL_TOKEN_IDS[END_TOKEN]


def tokenize_texts(texts: Iterable[str]) -> List[List[int]]:
    '''Encode formatted conversations into token-id lists, each ending with END_ID.

    Chat markers in ALLOWED_SPECIAL become their special token ids. Empty or None texts, and
    texts that encode to nothing, are skipped.
    '''
    sequences: List[List[int]] = []
    for text in texts:
        if not text:
            continue
        # disallowed_special=() makes any *other* special-token string (e.g. "<|endoftext|>")
        # encode as ordinary text. The default ("all") would raise ValueError for such a text
        # and abort the whole tokenization pass.
        ids = tokenizer.encode(text, allowed_special=ALLOWED_SPECIAL, disallowed_special=())
        if not ids:
            continue
        if ids[-1] != END_ID:
            ids.append(END_ID)
        sequences.append(ids)
    return sequences


def pack_sequences(sequences: Sequence[Sequence[int]], block_size: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    '''Concatenate ``sequences`` into one token stream and cut it into next-token training pairs.

    Each pair is ``(x, y)`` with ``x = stream[i : i + block_size]`` and
    ``y = stream[i + 1 : i + block_size + 1]``, for i = 0, block_size, 2 * block_size, ...
    Consecutive windows therefore share exactly one boundary token. Tokens left over after the
    last full window are discarded.
    '''
    windows: List[Tuple[torch.Tensor, torch.Tensor]] = []
    stream: List[int] = []
    span = block_size + 1
    for seq in sequences:
        stream.extend(seq)
        start = 0
        while len(stream) - start >= span:
            window = stream[start:start + span]
            windows.append(
                (torch.tensor(window[:-1], dtype=torch.long), torch.tensor(window[1:], dtype=torch.long))
            )
            start += block_size
        # Keep only the unconsumed tail (which starts at the last window's final token).
        del stream[:start]
    return windows


# Tokenize each split, then pack its token stream into (input, target) windows of BLOCK_SIZE.
train_sequences, val_sequences = (tokenize_texts(split["text"]) for split in (train_texts, val_texts))
train_pairs, val_pairs = (pack_sequences(seqs, BLOCK_SIZE) for seqs in (train_sequences, val_sequences))

# Uses the global (unseeded) `random` state; the training DataLoader below also shuffles.
random.shuffle(train_pairs)


class PackedPairDataset(Dataset):
    '''Map-style Dataset over pre-built ``(input_ids, target_ids)`` tensor pairs.'''

    def __init__(self, pairs: Sequence[Tuple[torch.Tensor, torch.Tensor]]):
        # Copy into a private list so later mutation of the caller's sequence has no effect.
        self.pairs = [*pairs]

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        pair = self.pairs[idx]
        return pair


train_dataset = PackedPairDataset(train_pairs)
val_dataset = PackedPairDataset(val_pairs)

BATCH_SIZE = 2
# Pinned host memory only speeds up host-to-GPU copies. On CPU-only runs (which this notebook
# supports; see the precision selection below) PyTorch just warns and ignores it.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)

print(f"{len(train_dataset)} training chunks, {len(val_dataset)} validation chunks")

In [None]:
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint


class GenerateAfterValidation(L.Callback):
    '''Print sample chat completions at the end of validation, at most every ``every_n_steps`` steps.

    Samples are produced on the global-zero rank only. Nothing is produced during Lightning's
    sanity-check validation, or before ``every_n_steps`` optimizer steps have elapsed.
    '''

    def __init__(self, prompts: Sequence[str], tokenizer: tiktoken.Encoding, every_n_steps: int = 200):
        super().__init__()
        self.prompts = list(prompts)
        self.tokenizer = tokenizer
        self.every_n_steps = every_n_steps
        # global_step at which samples were last printed (0 = never).
        self._last_generated_step = 0

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking or not trainer.is_global_zero:
            return
        step = trainer.global_step
        # Validation is scheduled independently of global_step, so it rarely lands exactly on a
        # multiple of every_n_steps. Sample whenever at least that many steps have passed since
        # the last sample instead.
        if step - self._last_generated_step < self.every_n_steps:
            return
        self._last_generated_step = step
        for prompt in self.prompts:
            response = chat(model=pl_module, tokenizer=self.tokenizer, prompt=prompt)
            pl_module.print(f"Prompt: {prompt}\nResponse: {response}\n")


def chat(model: ShrayGPT, tokenizer: tiktoken.Encoding, prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT,
         max_new_tokens: int = 256, temperature: float = 0.7, top_k: Optional[int] = 50) -> str:
    '''Generate one assistant reply to ``prompt`` using the SFT chat template.

    The context is ``<|system|>{system}<|end|><|user|>{prompt}<|end|><|assistant|>``, with both
    texts stripped as in format_conversation. It is left-truncated to ``model.block_size`` tokens.
    Sampling runs in eval mode under ``torch.no_grad``, and the model's previous train/eval mode
    is restored afterwards.

    Args:
        model: ShrayGPT to sample from (uses ``generate_nocache``, ``device`` and ``block_size``).
        tokenizer: the chat-extended tiktoken encoding that knows the role/end special tokens.
        prompt: user message. Special-token text inside it is encoded as plain characters.
        system_prompt: system message placed before the user turn.
        max_new_tokens, temperature, top_k: forwarded to ``model.generate_nocache``.

    Returns:
        The decoded text generated before the first ``<|end|>`` token, whitespace-stripped.
    '''
    end_id = tokenizer.encode_single_token(END_TOKEN)

    def encode_turn(role_token: str, text: str) -> List[int]:
        # encode_ordinary never produces special tokens. User text therefore cannot inject
        # template markers, and strings like "<|endoftext|>" cannot make tiktoken raise.
        # The token ids match encoding the whole template with specials allowed.
        return [tokenizer.encode_single_token(role_token), *tokenizer.encode_ordinary(text.strip()), end_id]

    prompt_tokens = (
        encode_turn(SYSTEM_TOKEN, system_prompt)
        + encode_turn(USER_TOKEN, prompt)
        + [tokenizer.encode_single_token(ASSISTANT_TOKEN)]
    )
    prompt_tokens = prompt_tokens[-model.block_size:]
    context = torch.tensor(prompt_tokens, dtype=torch.long, device=model.device).unsqueeze(0)

    # The module may still be in train mode (e.g. right after trainer.fit), so sample in eval
    # mode and restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.generate_nocache(context, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
    finally:
        model.train(was_training)

    new_tokens = generated[0].tolist()[len(prompt_tokens):]
    # Stop at the first end-of-turn token id.
    if end_id in new_tokens:
        new_tokens = new_tokens[: new_tokens.index(end_id)]
    return tokenizer.decode(new_tokens).strip()


checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="shraygpt-sft-{epoch:02d}-{step:05d}-{val_loss:.3f}",
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    save_last=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
generate_cb = GenerateAfterValidation(
    prompts=[
        "Summarize the Water Cycle in two sentences.",
        "Write a short Python function that checks if a string is a palindrome.",
    ],
    tokenizer=tokenizer,
    every_n_steps=200,
)

# The model runs with automatic_optimization = False (set earlier in this notebook). In that
# mode Lightning rejects Trainer(gradient_clip_val=...) and accumulate_grad_batches != 1 with a
# MisconfigurationException before training starts, so neither flag is passed here.
# NOTE(review): if clipping at 1.0 and accumulating 4 batches is still wanted, both must happen
# inside ShrayGPT.training_step (e.g. via self.clip_gradients(...)). Confirm what it does.
#
# Validate every 200 training batches instead of only at epoch end: max_steps can stop the run
# before the first epoch finishes, and both val_loss checkpointing and sample generation run off
# validation. The interval is capped at the epoch length, which Lightning requires.
trainer = L.Trainer(
    max_steps=2000,
    accelerator="auto",
    devices="auto",
    precision="bf16-mixed" if torch.cuda.is_available() else "32-true",
    log_every_n_steps=10,
    callbacks=[checkpoint_cb, lr_monitor, generate_cb],
    limit_val_batches=50,
    val_check_interval=max(1, min(200, len(train_loader))),
)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Once training has finished, sample one reply from the fine-tuned model.
example_prompt = "Explain how to write a unit test in Python using pytest."
reply = chat(model, tokenizer, example_prompt)
print(reply)