## Fine-tuning GPT-2 on Your Data

This notebook demonstrates how to fine-tune the smallest GPT-2 model (124M parameters) on your custom text data. Unlike training from scratch, fine-tuning starts with a pre-trained model and adapts it to your needs.

Important: Run every cell in this notebook, in order, by clicking the "Play" (run) button next to each cell before moving on to the next one.
If you restart the kernel, you will need to run all the cells again from the top.


## Prepare Your Data
Place your text data in a plain-text file in the same directory as this notebook. The code below reads `alice.txt`; to use your own file, change `input_file` in the next cell. The text should be clean and representative of what you want the model to learn.
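What "clean" means depends on your source. As one hedged example, the sketch below applies a light-touch normalization you might run before saving the file: it converts Windows line endings to `\n`, strips trailing spaces, and collapses runs of blank lines. `clean_text` is a hypothetical helper, not part of this notebook; adapt the rules to your own data.

```python
import re

def clean_text(raw: str) -> str:
    """Hypothetical light-touch cleanup; adjust the rules to your own data."""
    # Normalize Windows (CRLF) and old Mac (CR) line endings to LF
    text = raw.replace("\r\n", "\n").replace("\r", "\n")
    # Strip trailing whitespace from every line
    text = "\n".join(line.rstrip() for line in text.split("\n"))
    # Collapse three or more consecutive newlines into one blank line
    text = re.sub(r"\n{3,}", "\n\n", text)
    # Trim the ends and finish with a single newline
    return text.strip() + "\n"

print(repr(clean_text("Alice  \r\n\r\n\r\n\r\nwas here.\r\n")))  # 'Alice\n\nwas here.\n'
```

To apply it, you could read your raw file, pass its contents through `clean_text`, and write the result to the file named in `input_file`.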


In [32]:
import os
import sys
import time
import math
import torch
import numpy as np
import tiktoken
from contextlib import nullcontext
from model import GPT, GPTConfig  # model.py (from nanoGPT) must be in the same directory

# Load and prepare your data
input_file = 'alice.txt'
if not os.path.exists(input_file):
    raise FileNotFoundError(f"Please ensure {input_file} exists in the current directory")

with open(input_file, 'r', encoding='utf-8') as f:
    text = f.read()

print(f"Loaded {len(text)} characters from {input_file}")

# Create an output directory for checkpoints, named after the input file
dataset = os.path.splitext(os.path.basename(input_file))[0]
out_dir = f'out-{dataset}-finetune'
os.makedirs(out_dir, exist_ok=True)

# Split the text 90/10 into train/validation, then tokenize with the GPT-2 BPE encoding
n = len(text)
train_data = text[:int(n*0.9)]
val_data = text[int(n*0.9):]
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
vocab_size = enc.n_vocab
print(f"Train: {len(train_ids)} tokens, Val: {len(val_ids)} tokens, Vocab: {vocab_size} tokens")

# Save data as binary files
train_file = f'{dataset}_train.bin'
val_file = f'{dataset}_val.bin'
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(train_file)
val_ids.tofile(val_file)

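# Training settings
batch_size = 1                     # sequences per micro-batch
gradient_accumulation_steps = 32   # micro-batches per optimizer step
block_size = 1024                  # context length in tokens (GPT-2's maximum)
learning_rate = 3e-5               # small LR, since we start from pretrained weights
max_iters = 20                     # number of optimizer steps
eval_interval = 5                  # evaluate (and possibly checkpoint) every 5 steps
eval_iters = 40                    # batches averaged per loss estimate
decay_lr = False                   # if True, decay the LR toward half its initial value
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'float16' if torch.cuda.is_available() else 'float32'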



Loaded 148043 characters from alice.txt
Train: 38141 tokens, Val: 4189 tokens, Vocab: 50257 tokens
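Two constraints in the cell above are easy to miss. Token IDs are stored as `np.uint16`, which holds values up to 65,535, so all 50,257 GPT-2 token IDs fit. And `get_batch` (defined in the next cell) draws start positions with `torch.randint(len(data) - block_size, ...)`, so each split must contain more than `block_size` tokens. With a 10% validation split and `block_size = 1024`, the input therefore needs roughly ten times `block_size` tokens in total; the split is by characters, so the exact figure varies. The hypothetical helper below checks both conditions.

```python
import numpy as np

def check_token_splits(train_ids, val_ids, vocab_size=50257, block_size=1024):
    """Hypothetical sanity check mirroring the assumptions made by get_batch."""
    # uint16 can only store token IDs in [0, 65535]
    if vocab_size - 1 > np.iinfo(np.uint16).max:
        raise ValueError(f"vocab_size {vocab_size} does not fit in uint16")
    for name, ids in (("train", train_ids), ("val", val_ids)):
        # get_batch samples start indices from [0, len(ids) - block_size)
        if len(ids) <= block_size:
            raise ValueError(
                f"{name} split has {len(ids)} tokens; need more than {block_size}"
            )
    return True

# Dummy arrays with the sizes printed above (only the lengths matter here)
print(check_token_splits(np.zeros(38141, np.uint16), np.zeros(4189, np.uint16)))  # True
```

In the notebook you would call it as `check_token_splits(train_ids, val_ids, vocab_size, block_size)` right after tokenizing.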


## Define utility functions, load the model, and train
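`get_batch`, defined in the cell below, picks `batch_size` random start positions and returns two windows of `block_size` tokens: `x` starts at position `i` and `y` starts at `i + 1`, so every target is simply the next token after the corresponding input. The NumPy sketch below reproduces that slicing on a toy array, without torch or a memmap, to make the one-token shift visible; `toy_get_batch` and the sizes are made up for illustration.

```python
import numpy as np

def toy_get_batch(data, block_size, batch_size, rng):
    """Toy NumPy version of get_batch's slicing (no torch, no memmap)."""
    # Valid starts leave room for a block_size window plus one extra target token
    ix = rng.integers(0, len(data) - block_size, size=batch_size)
    x = np.stack([data[i:i + block_size] for i in ix]).astype(np.int64)
    y = np.stack([data[i + 1:i + 1 + block_size] for i in ix]).astype(np.int64)
    return x, y

rng = np.random.default_rng(0)
data = np.arange(100, dtype=np.uint16)   # stand-in for token IDs
x, y = toy_get_batch(data, block_size=8, batch_size=2, rng=rng)

# Each target row is its input row shifted left by one position
assert (y[:, :-1] == x[:, 1:]).all()
print(x[0], y[0], sep="\n")
```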

In [33]:
# Define utility functions
def get_batch(split):
    data = np.memmap(train_file if split == 'train' else val_file, dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# Load and configure the model
print(f"Initializing model from 'gpt2'...")
model = GPT.from_pretrained('gpt2', dict(dropout=0.0))
model = model.to(device)
print(f"Model has {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")

# Optimizer setup
optimizer = model.configure_optimizers(
    weight_decay=1e-2, learning_rate=learning_rate, 
    betas=(0.9, 0.95), device_type=device
)

# Context manager for mixed precision
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type=device, dtype=ptdtype) if device != 'cpu' else nullcontext()
scaler = torch.amp.GradScaler('cuda', enabled=(dtype == 'float16' and device == 'cuda'))

# Training loop
model.train()
best_val_loss = float('inf')
print(f"Starting training for {max_iters} iterations...")
start_time = time.time()

for iter_num in range(max_iters):
    # Learning rate scheduling
    if decay_lr:
        lr = learning_rate * (0.5 ** (iter_num / max_iters))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Forward and backward passes with gradient accumulation
    optimizer.zero_grad(set_to_none=True)
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch('train')
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps
        
        scaler.scale(loss).backward()
    
    # Gradient clipping and optimizer step
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    
    # Evaluation and checkpoint saving
    if (iter_num + 1) % eval_interval == 0:
        losses = estimate_loss()
        print(f"iter {iter_num+1}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        
        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': dict(vars(model.config)),  # plain dict, rebuildable via GPTConfig(**model_args)
                'iter_num': iter_num,
                'best_val_loss': best_val_loss
            }
            print(f"Saving checkpoint to {out_dir}")
            torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

elapsed = time.time() - start_time
print(f"\nTraining completed in {elapsed:.2f} seconds!")
print(f"Best validation loss: {best_val_loss:.4f}")


Initializing model from 'gpt2'...
loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters: 123.65M
Model has 124.44M parameters
num decayed parameter tensors: 50, with 124,318,464 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
Starting training for 20 iterations...

iter 5: train loss 2.8029, val loss 2.8485
Saving checkpoint to out-alice-finetune
iter 10: train loss 2.6709, val loss 2.8459
Saving checkpoint to out-alice-finetune
iter 15: train loss 2.5099, val loss 2.7683
Saving checkpoint to out-alice-finetune
iter 20: train loss 2.4738, val loss 2.7775

Training completed in 37.10 seconds!
Best validation loss: 2.7683
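It helps to translate these settings into tokens. Each optimizer step accumulates gradients over `gradient_accumulation_steps` micro-batches of `batch_size` sequences of `block_size` tokens, so one iteration processes 1 × 32 × 1024 = 32,768 tokens, and 20 iterations process 655,360 tokens. Against 38,141 training tokens, that is roughly 17 passes over the training split; because windows are sampled at random, this is an average, not an exact epoch count. With that much repetition on a small dataset, the train loss falling to 2.47 while the val loss ticks up from 2.7683 to 2.7775 may be an early sign of overfitting. The arithmetic, with values copied from the cells above:

```python
# Settings and token count copied from the cells above
batch_size = 1
gradient_accumulation_steps = 32
block_size = 1024
max_iters = 20
train_tokens = 38141  # from the "Train: ... tokens" line

tokens_per_iter = batch_size * gradient_accumulation_steps * block_size
total_tokens = tokens_per_iter * max_iters
approx_passes = total_tokens / train_tokens

print(f"tokens per iteration: {tokens_per_iter:,}")              # 32,768
print(f"total tokens processed: {total_tokens:,}")               # 655,360
print(f"approx. passes over train split: {approx_passes:.1f}")   # 17.2
```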


## Generate some text
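The `generate_text` function in this section calls `model.generate` with `temperature=0.8`. In nanoGPT's `model.py`, which this notebook appears to use, the logits for the next token are divided by the temperature before the softmax, so values below 1.0 sharpen the distribution toward the most likely tokens and values above 1.0 flatten it. The NumPy sketch below shows that effect on made-up logits; it illustrates the idea and is not the library code.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Convert raw logits to probabilities after dividing by a temperature."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()            # subtract the max for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()

logits = [2.0, 1.0, 0.1]              # made-up scores for three candidate tokens
for t in (0.5, 0.8, 1.0, 1.5):
    p = softmax_with_temperature(logits, t)
    print(f"T={t}: " + ", ".join(f"{v:.3f}" for v in p))

# Draw one token index from the tempered distribution (one sampling step)
rng = np.random.default_rng(0)
next_id = rng.choice(len(logits), p=softmax_with_temperature(logits, 0.8))
```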

In [34]:
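# Generate text with the fine-tuned model. This uses the weights from the last
# iteration, which may differ from the best checkpoint saved above.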
def generate_text(prompt, max_tokens=200):
    model.eval()
    tokens = enc.encode_ordinary(prompt)
    x = torch.tensor([tokens], dtype=torch.long, device=device)
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=max_tokens, temperature=0.8)
    return enc.decode(y[0].tolist())

# Test with a few prompts
test_prompts = ["Alice was", "The Queen", "Down the rabbit"]
for prompt in test_prompts:
    print(f"\nPrompt: '{prompt}'")
    print("-" * 50)
    generated = generate_text(prompt)
    print(generated)
    print("=" * 50)


Prompt: 'Alice was'
--------------------------------------------------
Alice was in the doorway when she heard the door open. She thought she heard someone banging on the window, but nothing happened.
"I can't tell you what happened," said the doorkeeper, as she opened it. "What did you do? Did you go hunting?"
"I wasn't hunting."
"Yes, you were. Was that all you wanted?"
She thought she heard the door close behind her, and she looked, as if she had opened it.
"It was only a little after midnight, there's nothing you can tell me about it."
"What did I do wrong?"
She looked down at the ten-dollar bill and thought she saw a guard with a question mark above it, but she could see nothing but guards.
"I haven't seen him since nightfall, and I was wondering what was going on. I came upon a little schoolhouse in the middle of the street, and the doors only opened when you was in the

Prompt: 'The Queen'
--------------------------------------------------
The Queen must, for the greater part, 