# Training From Scratch

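This notebook walks through training a small GPT model end to end: preparing the data, building the model, running the training loop, inspecting the loss, and saving/loading checkpoints.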


In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from torch.utils.data import DataLoader, random_split

# Make the repository root importable so the `src.*` imports below resolve.
# os.path.abspath('') is the kernel's working directory (normally the notebook's
# folder); the root is taken to be two directory levels above it — adjust if
# the notebook is moved.
notebook_dir = os.path.abspath('')
project_root = os.path.dirname(os.path.dirname(notebook_dir))
# Prepend rather than append so the local `src` package wins over any installed copy.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import our training components
from src.model.gpt import GPTModel
from src.config import GPTConfig
from src.data.dataset import GPTDataset
from src.data.tokenizer import get_tokenizer
from src.training.trainer import GPTTrainer
import tiktoken

## Training GPT From Scratch

This notebook demonstrates how to train a GPT model from scratch using the training utilities provided in this codebase.

### 1. Prepare Data

In [None]:
# Prefer the repository's sample corpus; if it is missing, fall back to a tiny
# inline story so the notebook still runs end to end.
sample_text_path = os.path.join(project_root, "data", "sample_text.txt")
if not os.path.exists(sample_text_path):
    # Use a simple example
    text = """Once upon a time, there was a little girl named Emma. She loved to play in the garden. 
Every morning, Emma would wake up early and run outside. She would pick flowers and watch the butterflies. 
The garden was her favorite place in the whole world. She spent hours there every day, playing and exploring.
The cat sat on the mat. It was a sunny day. The cat was very happy. It purred softly and stretched its paws."""
    print("Using fallback text")
else:
    with open(sample_text_path, "r", encoding="utf-8") as text_file:
        text = text_file.read()
    print(f"Loaded text: {len(text):,} characters")

# Tokenizer selected by name; its n_vocab becomes the model's vocabulary size below.
tokenizer = get_tokenizer("gpt2")
vocab_size = tokenizer.n_vocab
print(f"Vocabulary size: {vocab_size}")

In [None]:
# Create dataset: slide a window of `context_length` tokens over the text,
# advancing `stride` tokens each time, to produce (input, target) pairs.
context_length = 32  # Small for demonstration
stride = context_length // 2  # 50% overlap

dataset = GPTDataset(
    text=text,
    tokenizer=tokenizer,
    maximum_length=context_length,
    stride=stride
)

print(f"Dataset created:")
print(f"  Total sequences: {len(dataset):,}")
print(f"  Context length: {context_length}")
print(f"  Stride: {stride}")

# Split into train/validation (90/10).
# Both splits need at least one sequence: with fewer than 2 sequences the old
# fix-ups produced an empty training split (n == 1) or a negative validation
# size (n == 0). Fail fast with an actionable message instead.
if len(dataset) < 2:
    raise ValueError(
        f"Need at least 2 sequences for a train/validation split, got {len(dataset)}. "
        f"Use a longer text or a smaller context_length/stride."
    )
# Clamp so each split keeps at least one sequence (same sizes as before for n >= 2).
train_size = min(max(int(0.9 * len(dataset)), 1), len(dataset) - 1)
val_size = len(dataset) - train_size

# NOTE(review): windows overlap (stride < context_length), so a random split
# presumably lets validation windows share tokens with neighbouring training
# windows — expect validation loss to be optimistic. Fine for a demo.
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"\nTrain sequences: {len(train_dataset):,}")
print(f"Validation sequences: {len(val_dataset):,}")

In [None]:
# Wrap each split in a DataLoader. Both share batch size and worker settings;
# only the training loader shuffles, so validation batches keep a fixed order.
# num_workers=0 loads data in the notebook's own process.
batch_size = 4
loader_kwargs = dict(batch_size=batch_size, num_workers=0)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Data loaders created:")
print(f"  Batch size: {batch_size}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

### 2. Create Model

In [None]:
# Seed PyTorch's global RNG so weight initialisation (and the later RNG draws
# such as training-loader shuffling and dropout) repeat across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Pick the device first so the model can be placed on it immediately.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a small model for training
config = GPTConfig(
    vocab_size=vocab_size,
    context_length=context_length,
    embedding_dimension=128,
    number_of_heads=4,
    number_of_layers=2,
    dropout_rate=0.1
)

# Move explicitly: later cells run the model on tensors placed on `device`, which
# requires its parameters to live there too. Moving it again later is harmless.
model = GPTModel(config).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model created:")
print(f"  Parameters: {total_params:,}")
print(f"  Model size: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")  # 4 bytes per FP32 param
print(f"  Device: {device}")

### 3. Setup Training

In [None]:
# Create optimizer: AdamW with betas=(0.9, 0.95), as commonly used for GPT training.
learning_rate = 3e-4
weight_decay = 0.1

# Decay only tensors with >= 2 dims (weight matrices / embeddings). 1-D tensors,
# typically biases and LayerNorm scale/shift, are left undecayed, following the
# usual GPT-2 / nanoGPT recipe.
decay_params = [p for p in model.parameters() if p.dim() >= 2]
no_decay_params = [p for p in model.parameters() if p.dim() < 2]
param_groups = [
    group
    for group in (
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    )
    if group["params"]  # don't hand AdamW an empty group
]
optimizer = torch.optim.AdamW(
    param_groups,
    lr=learning_rate,
    betas=(0.9, 0.95)
)

print(f"Optimizer: AdamW")
print(f"  Learning rate: {learning_rate}")
print(f"  Weight decay: {weight_decay} "
      f"({sum(p.numel() for p in decay_params):,} params decayed, "
      f"{sum(p.numel() for p in no_decay_params):,} not decayed)")

# Create trainer
trainer = GPTTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device
)

print(f"\nTrainer created")

### 4. Training Loop

In [None]:
# Train for a few epochs, logging mean loss and perplexity (exp of the loss)
# after each training and validation pass.
num_epochs = 3
train_losses = []
val_losses = []

print(f"Training for {num_epochs} epochs...\n")

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")
    print("-" * 40)

    # --- training pass ---
    train_loss = trainer.train_epoch()
    train_perplexity = torch.exp(torch.tensor(train_loss)).item()
    train_losses.append(train_loss)
    print(f"Train Loss: {train_loss:.4f} | Perplexity: {train_perplexity:.2f}")

    # --- validation pass; trainer.validate() may return None (nothing recorded) ---
    val_loss = trainer.validate()
    if val_loss is None:
        print()
        continue
    val_perplexity = torch.exp(torch.tensor(val_loss)).item()
    val_losses.append(val_loss)
    print(f"Val Loss: {val_loss:.4f} | Perplexity: {val_perplexity:.2f}")
    print()

In [None]:
# Plot per-epoch loss and perplexity (exp of the loss) for train and validation.
if len(train_losses) > 0:
    epochs = list(range(1, len(train_losses) + 1))

    # The training loop skips recording a val loss when trainer.validate()
    # returns None, so val_losses can be shorter than `epochs`; matplotlib raises
    # on mismatched x/y lengths. Only draw validation curves with one value per epoch.
    plot_val = len(val_losses) == len(train_losses)
    if len(val_losses) > 0 and not plot_val:
        print(f"Skipping validation curves: {len(val_losses)} values for {len(train_losses)} epochs")

    train_perplexities = [torch.exp(torch.tensor(loss)).item() for loss in train_losses]
    val_perplexities = (
        [torch.exp(torch.tensor(loss)).item() for loss in val_losses] if plot_val else []
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (axes[0], train_losses, val_losses, 'Loss', 'Training and Validation Loss'),
        (axes[1], train_perplexities, val_perplexities, 'Perplexity', 'Training and Validation Perplexity'),
    )
    for ax, train_values, val_values, metric, title in panels:
        ax.plot(epochs, train_values, marker='o', label=f'Train {metric}', linewidth=2)
        if plot_val:
            ax.plot(epochs, val_values, marker='s', label=f'Val {metric}', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(metric, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

### 5. Understanding the Loss

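The model is trained to predict the next token at each position. The loss is the cross-entropy (`torch.nn.CrossEntropyLoss`) between the model's logits and the target token IDs. The logits are flattened to shape `(batch × seq_len, vocab_size)`, so every position counts as one classification example, and the loss is averaged over all of them. Perplexity is simply $\exp(\text{loss})$.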

In [None]:
# Examine a training batch to understand the loss computation.
# Remember the model's mode and restore it afterwards, so running this cell does
# not leave dropout disabled for any training that follows (unless the trainer
# calls model.train() itself).
was_training = model.training
model.eval()  # disable dropout for a deterministic forward pass
try:
    with torch.no_grad():
        # Get a batch
        input_ids, target_ids = next(iter(train_loader))
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        # Forward pass
        logits = model(input_ids)  # assumed [batch_size, seq_len, vocab_size]

        # Compute loss manually: flatten so every position is one classification
        # example. reshape (unlike view) also works on non-contiguous tensors.
        criterion = torch.nn.CrossEntropyLoss()
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),  # [batch*seq_len, vocab_size]
            target_ids.reshape(-1)  # [batch*seq_len]
        )

        print(f"Batch example:")
        print(f"  Input shape: {input_ids.shape}")
        print(f"  Target shape: {target_ids.shape}")
        print(f"  Logits shape: {logits.shape}")
        print(f"  Loss: {loss.item():.4f}")
        print(f"  Perplexity: {torch.exp(loss).item():.2f}")

        # Show greedy (argmax) predictions for the first sequence
        first_seq_logits = logits[0]  # [seq_len, vocab_size]
        first_seq_preds = torch.argmax(first_seq_logits, dim=-1)
        first_seq_targets = target_ids[0]

        print(f"\nFirst sequence:")
        print(f"  Input tokens: {input_ids[0].tolist()[:10]}...")
        print(f"  Target tokens: {first_seq_targets.tolist()[:10]}...")
        print(f"  Predicted tokens: {first_seq_preds.tolist()[:10]}...")
        print(f"  Accuracy: {(first_seq_preds == first_seq_targets).float().mean().item():.2%}")
finally:
    model.train(was_training)

### 6. Saving and Loading Checkpoints

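Saving checkpoints lets you resume training or reuse the trained weights later. The checkpoint below stores the model and optimizer state, the model config, the epoch count, and the final train/validation losses. It is then reloaded into a fresh model to confirm that the round trip works.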

In [None]:
# Save checkpoint
checkpoint_dir = os.path.join(project_root, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint_path = os.path.join(checkpoint_dir, "notebook_checkpoint.pt")
checkpoint = {
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config.to_dict(),  # assumed to hold only plain Python values — required by weights_only=True below
    # Cast to built-in float: a numpy scalar here would be rejected by torch.load(weights_only=True).
    'train_loss': float(train_losses[-1]) if train_losses else None,
    'val_loss': float(val_losses[-1]) if val_losses else None,
}

torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved to: {checkpoint_path}")

# Load checkpoint to verify (weights_only=True refuses arbitrary pickled objects)
loaded_checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
print(f"\nCheckpoint loaded:")
print(f"  Epoch: {loaded_checkpoint['epoch']}")
# The loss keys are always present (possibly None), so .get(key, default) would
# never fall back; test for None explicitly.
for key, label in (('train_loss', 'Train loss'), ('val_loss', 'Val loss')):
    value = loaded_checkpoint.get(key)
    shown = 'N/A' if value is None else f"{value:.4f}"
    print(f"  {label}: {shown}")

# Create a new model and load weights
new_model = GPTModel(config)
new_model.load_state_dict(loaded_checkpoint['model_state_dict'])

# Verify the round trip: every saved tensor must equal the in-memory model's
# (compare on CPU since the two models may live on different devices).
reference_state = model.state_dict()
for name, tensor in new_model.state_dict().items():
    assert torch.equal(tensor.cpu(), reference_state[name].cpu()), f"Mismatch in {name}"
print(f"\nModel weights loaded successfully!")