# Transformer Training Tutorial

This notebook walks you through training a Transformer model step by step.
Each step includes validation to verify correctness.

## Overview
1. Setup and imports
2. Prepare sample data
3. Create tokenizer and vocabulary
4. Create dataset and dataloader
5. Build the Transformer model
6. Setup optimizer, scheduler, and loss function
7. Training loop
8. Save checkpoint

## Step 1: Setup and Imports

First, let's import all necessary modules and set up the environment.

In [1]:
import sys
sys.path.insert(0, '..')  # Add parent directory to path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Import our Transformer implementation
from src import Transformer
from src.tokenizer import SimpleTokenizer, pad_sequences, PAD_ID, BOS_ID, EOS_ID
from src.data import TranslationDataset, TranslationCollator, create_translation_dataloader
from src.scheduler import TransformerScheduler, get_lr_at_step
from src.label_smoothing import LabelSmoothingLoss
from src.trainer import Trainer, TrainerConfig

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

Using device: cuda
PyTorch version: 2.8.0+cu128


## Step 2: Prepare Sample Data

We'll use a small English-to-German translation dataset for demonstration.
In production, you would use WMT14 or similar datasets.

In [2]:
# Sample parallel sentences (English -> German)
src_sentences = [
    "The cat sat on the mat",
    "Hello world",
    "How are you today",
    "I love machine learning",
    "The weather is nice",
    "Good morning everyone",
    "This is a test",
    "The dog runs fast",
    "She reads a book",
    "We learn together",
    "The sun is shining",
    "Birds fly in the sky",
    "I drink coffee",
    "He plays guitar",
    "They study hard",
    "The flower is beautiful",
]

tgt_sentences = [
    "Die Katze saß auf der Matte",
    "Hallo Welt",
    "Wie geht es dir heute",
    "Ich liebe maschinelles Lernen",
    "Das Wetter ist schön",
    "Guten Morgen allerseits",
    "Das ist ein Test",
    "Der Hund läuft schnell",
    "Sie liest ein Buch",
    "Wir lernen zusammen",
    "Die Sonne scheint",
    "Vögel fliegen am Himmel",
    "Ich trinke Kaffee",
    "Er spielt Gitarre",
    "Sie lernen fleißig",
    "Die Blume ist wunderschön",
]

# Validate: show the first 3 sentence pairs
print("=" * 60)
print("VALIDATION: First 3 sentence pairs")
print("=" * 60)
for i in range(3):
    print(f"\nPair {i+1}:")
    print(f"  Source (EN): {src_sentences[i]}")
    print(f"  Target (DE): {tgt_sentences[i]}")

print(f"\n✓ Total sentence pairs: {len(src_sentences)}")

VALIDATION: First 3 sentence pairs

Pair 1:
  Source (EN): The cat sat on the mat
  Target (DE): Die Katze saß auf der Matte

Pair 2:
  Source (EN): Hello world
  Target (DE): Hallo Welt

Pair 3:
  Source (EN): How are you today
  Target (DE): Wie geht es dir heute

✓ Total sentence pairs: 16


## Step 3: Create Tokenizer and Build Vocabulary

We use a word-level `SimpleTokenizer` for this demo. In production, use a BPE tokenizer.

In [3]:
# Create tokenizer and build vocabulary from all sentences
tokenizer = SimpleTokenizer()
all_sentences = src_sentences + tgt_sentences
tokenizer.build_vocab(all_sentences)

# Validate tokenizer
print("=" * 60)
print("VALIDATION: Tokenizer")
print("=" * 60)
print(f"\nVocabulary size: {tokenizer.vocab_size}")
print(f"\nSpecial tokens:")
print(f"  PAD_ID = {tokenizer.pad_id}")
print(f"  UNK_ID = {tokenizer.unk_id}")
print(f"  BOS_ID = {tokenizer.bos_id}")
print(f"  EOS_ID = {tokenizer.eos_id}")

# Test encoding/decoding on first 3 sentences
print(f"\nEncoding first 3 sentences:")
for i in range(3):
    tokens = tokenizer.encode(src_sentences[i], add_bos=True, add_eos=True)
    decoded = tokenizer.decode(tokens)
    print(f"\n  Sentence {i+1}: '{src_sentences[i]}'")
    print(f"  Token IDs:   {tokens}")
    print(f"  Decoded:     '{decoded}'")

print(f"\n✓ Tokenizer is working correctly!")

VALIDATION: Tokenizer

Vocabulary size: 104

Special tokens:
  PAD_ID = 0
  UNK_ID = 1
  BOS_ID = 2
  EOS_ID = 3

Encoding first 3 sentences:

  Sentence 1: 'The cat sat on the mat'
  Token IDs:   [2, 4, 52, 85, 81, 15, 78, 3]
  Decoded:     'The cat sat on the mat'

  Sentence 2: 'Hello world'
  Token IDs:   [2, 26, 100, 3]
  Decoded:     'Hello world'

  Sentence 3: 'How are you today'
  Token IDs:   [2, 28, 48, 102, 96, 3]
  Decoded:     'How are you today'

✓ Tokenizer is working correctly!
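
`SimpleTokenizer` is imported from `src.tokenizer` and its source is not shown in this notebook. As a rough illustration of what a word-level tokenizer of this kind does, here is a minimal sketch. `ToyTokenizer` is a hypothetical stand-in, not the project's class: it assumes whitespace splitting and reuses the special-token IDs printed above, and the IDs it assigns to ordinary words will not match the output above.

```python
# Hypothetical stand-in for illustration; not the project's SimpleTokenizer.
PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
SPECIAL_TOKENS = ["<pad>", "<unk>", "<bos>", "<eos>"]


class ToyTokenizer:
    """Word-level tokenizer: one ID per whitespace-separated word."""

    def __init__(self):
        self.token_to_id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
        self.id_to_token = list(SPECIAL_TOKENS)

    def build_vocab(self, sentences):
        # Assign IDs in sorted order so the vocabulary is deterministic.
        words = sorted({w for s in sentences for w in s.split()})
        for word in words:
            if word not in self.token_to_id:
                self.token_to_id[word] = len(self.id_to_token)
                self.id_to_token.append(word)

    @property
    def vocab_size(self):
        return len(self.id_to_token)

    def encode(self, text, add_bos=False, add_eos=False):
        # Words missing from the vocabulary map to UNK_ID.
        ids = [self.token_to_id.get(w, UNK_ID) for w in text.split()]
        if add_bos:
            ids = [BOS_ID] + ids
        if add_eos:
            ids = ids + [EOS_ID]
        return ids

    def decode(self, ids):
        # Drop PAD/BOS/EOS so the round trip returns the original text.
        skip = {PAD_ID, BOS_ID, EOS_ID}
        return " ".join(self.id_to_token[i] for i in ids if i not in skip)


toy = ToyTokenizer()
toy.build_vocab(["Hello world", "Hallo Welt"])
ids = toy.encode("Hello world", add_bos=True, add_eos=True)
print(ids)
print(toy.decode(ids))
print(toy.encode("Hello moon"))  # "moon" is out of vocabulary -> UNK_ID
```

A word-level vocabulary like this cannot represent unseen words, which is one reason subword schemes such as BPE are preferred for real training runs.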


## Step 4: Create Dataset and DataLoader

Wrap the sentence pairs in a PyTorch `Dataset` and `DataLoader` so they can be batched and padded to a common length.

In [4]:
# Create dataset
dataset = TranslationDataset(
    src_data=src_sentences,
    tgt_data=tgt_sentences,
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    add_bos=True,
    add_eos=True,
)

# Create dataloader with collator for padding
collator = TranslationCollator(pad_id=tokenizer.pad_id)
train_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator,
)

# Validate: Show first batch
print("=" * 60)
print("VALIDATION: DataLoader")
print("=" * 60)
print(f"\nDataset size: {len(dataset)}")
print(f"Batch size: {train_loader.batch_size}")
print(f"Number of batches: {len(train_loader)}")

# Get one batch and inspect
sample_batch = next(iter(train_loader))
print(f"\nFirst batch contents:")
print(f"  src shape: {sample_batch['src'].shape}")
print(f"  tgt shape: {sample_batch['tgt'].shape}")
print(f"  src_mask shape: {sample_batch['src_mask'].shape}")

# Show first 3 samples from batch
print(f"\nFirst 3 samples in batch:")
for i in range(min(3, sample_batch['src'].size(0))):
    src_ids = sample_batch['src'][i].tolist()
    tgt_ids = sample_batch['tgt'][i].tolist()
    print(f"\n  Sample {i+1}:")
    print(f"    src IDs: {src_ids}")
    print(f"    tgt IDs: {tgt_ids}")
    print(f"    src decoded: '{tokenizer.decode(src_ids)}'")
    print(f"    tgt decoded: '{tokenizer.decode(tgt_ids)}'")

print(f"\n✓ DataLoader is working correctly!")

VALIDATION: DataLoader

Dataset size: 16
Batch size: 4
Number of batches: 4

First batch contents:
  src shape: torch.Size([4, 7])
  tgt shape: torch.Size([4, 7])
  src_mask shape: torch.Size([4, 7])

First 3 samples in batch:

  Sample 1:
    src IDs: [2, 28, 48, 102, 96, 3, 0]
    tgt IDs: [2, 44, 65, 58, 55, 68, 3]
    src decoded: 'How are you today'
    tgt decoded: 'Wie geht es dir heute'

  Sample 2:
    src IDs: [2, 16, 64, 69, 15, 91, 3]
    tgt IDs: [2, 40, 62, 47, 27, 3, 0]
    src decoded: 'Birds fly in the sky'
    tgt decoded: 'Vögel fliegen am Himmel'

  Sample 3:
    src IDs: [2, 38, 93, 67, 3, 0, 0]
    tgt IDs: [2, 11, 14, 61, 3, 0, 0]
    src decoded: 'They study hard'
    tgt decoded: 'Sie lernen fleißig'

✓ DataLoader is working correctly!
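
`TranslationCollator` is imported from `src.data` and its source is not shown here. The sketch below illustrates the padding step that a collator of this kind performs, using plain lists. `pad_batch` is a hypothetical helper name, and the mask convention (True for real tokens, False for padding) is an assumption that fits the shapes printed above but is not confirmed by the notebook.

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad token-ID lists to the longest length in the batch.

    Returns (padded, mask), where mask[i][j] is True for real tokens
    and False for padding. This mask convention is an assumption.
    """
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    mask = [[tok != pad_id for tok in row] for row in padded]
    return padded, mask


# Source IDs of the three samples shown above, with the padding stripped.
batch = [
    [2, 28, 48, 102, 96, 3],
    [2, 16, 64, 69, 15, 91, 3],
    [2, 38, 93, 67, 3],
]
padded, mask = pad_batch(batch)
for row in padded:
    print(row)
```

Padding only to the longest sequence in each batch, rather than to a fixed global length, keeps the tensors as small as the batch allows.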


## Step 5: Build the Transformer Model

Create the Transformer model with hyperparameters scaled down from the paper's base configuration so the demo trains quickly.

In [5]:
# Model hyperparameters (smaller for demo)
vocab_size = tokenizer.vocab_size
d_model = 128      # Model dimension (paper uses 512)
n_heads = 4        # Attention heads (paper uses 8)
n_layers = 2       # Encoder/decoder layers (paper uses 6)
d_ff = 256         # FFN dimension (paper uses 2048)
dropout = 0.1
max_seq_len = 100

# Create model
model = Transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    n_encoder_layers=n_layers,
    n_decoder_layers=n_layers,
    d_ff=d_ff,
    dropout=dropout,
    max_seq_len=max_seq_len,
    pad_idx=tokenizer.pad_id,
)

model = model.to(device)

# Validate model
print("=" * 60)
print("VALIDATION: Model Architecture")
print("=" * 60)
print(f"\nModel hyperparameters:")
print(f"  vocab_size: {vocab_size}")
print(f"  d_model: {d_model}")
print(f"  n_heads: {n_heads}")
print(f"  n_layers: {n_layers}")
print(f"  d_ff: {d_ff}")
print(f"  dropout: {dropout}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nParameter count:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

# Test forward pass
print(f"\nTesting forward pass...")
model.eval()
with torch.no_grad():
    src = sample_batch['src'].to(device)
    tgt = sample_batch['tgt'].to(device)
    
    # Use tgt[:, :-1] as decoder input (teacher forcing)
    tgt_input = tgt[:, :-1]
    logits = model(src, tgt_input)
    
    print(f"  Input src shape: {src.shape}")
    print(f"  Input tgt shape: {tgt_input.shape}")
    print(f"  Output logits shape: {logits.shape}")
    print(f"  Expected shape: (batch, tgt_len-1, vocab_size) = ({src.size(0)}, {tgt_input.size(1)}, {vocab_size})")

print(f"\n✓ Model is working correctly!")

VALIDATION: Model Architecture

Model hyperparameters:
  vocab_size: 104
  d_model: 128
  n_heads: 4
  n_layers: 2
  d_ff: 256
  dropout: 0.1

Parameter count:
  Total: 700,008
  Trainable: 700,008

Testing forward pass...
  Input src shape: torch.Size([4, 7])
  Input tgt shape: torch.Size([4, 6])
  Output logits shape: torch.Size([4, 6, 104])
  Expected shape: (batch, tgt_len-1, vocab_size) = (4, 6, 104)

✓ Model is working correctly!
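
The forward pass above feeds `tgt[:, :-1]` to the decoder, which is why the logits have length 6 rather than 7. The sketch below shows the same shift on a single sequence using plain lists. `shift_for_teacher_forcing` is a hypothetical helper name introduced only for this illustration.

```python
def shift_for_teacher_forcing(tgt_ids):
    """Split one target sequence into decoder input and prediction target."""
    decoder_input = tgt_ids[:-1]   # everything except the last token
    target = tgt_ids[1:]           # everything except the first token (BOS)
    return decoder_input, target


# "Wie geht es dir heute" as encoded in the batch shown earlier.
tgt_ids = [2, 44, 65, 58, 55, 68, 3]
decoder_input, target = shift_for_teacher_forcing(tgt_ids)

# At position t the decoder has seen decoder_input[: t + 1] and must predict target[t].
for t, (seen, wanted) in enumerate(zip(decoder_input, target)):
    print(f"t={t}: last input token {seen} -> predict {wanted}")
```

For padded rows the shifted target still contains `PAD_ID` entries. The `padding_idx` argument passed to the loss later in the notebook suggests that those positions are excluded from the loss.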


## Step 6: Setup Optimizer, Scheduler, and Loss Function

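Configure the training components as described in the original Transformer paper ("Attention Is All You Need"): Adam, a warmup learning-rate schedule, and a label-smoothed loss.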
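
Label smoothing is probably the least self-explanatory of these training components. `LabelSmoothingLoss` is imported from `src.label_smoothing` and its source is not shown in this notebook, so the following is only a minimal sketch of the general idea. It assumes the variant that mixes the one-hot target with a uniform distribution over the whole vocabulary; the project's class may distribute the mass differently, for example by excluding the padding token. Both function names are hypothetical.

```python
import math


def smoothed_targets(target_id, vocab_size, smoothing=0.1):
    """Mix the one-hot target with a uniform distribution over the vocabulary."""
    uniform = smoothing / vocab_size
    dist = [uniform] * vocab_size
    dist[target_id] += 1.0 - smoothing
    return dist


def smoothed_cross_entropy(logits, target_id, smoothing=0.1):
    """Cross-entropy between the smoothed targets and softmax(logits)."""
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_z = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    log_probs = [x - log_z for x in logits]
    targets = smoothed_targets(target_id, len(logits), smoothing)
    return -sum(q * lp for q, lp in zip(targets, log_probs))


logits = [2.0, 0.5, -1.0, 0.0]
plain = smoothed_cross_entropy(logits, target_id=0, smoothing=0.0)
smooth = smoothed_cross_entropy(logits, target_id=0, smoothing=0.1)
print(f"plain CE:    {plain:.4f}")
print(f"smoothed CE: {smooth:.4f}")
```

Because some target mass is moved onto the other classes, a confident correct prediction is penalized slightly, which discourages the model from becoming overconfident.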

In [6]:
# Optimizer (Adam with paper's beta values)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1.0,  # Will be overridden by scheduler
    betas=(0.9, 0.98),
    eps=1e-9,
)

# Learning rate scheduler (warmup + inverse sqrt decay)
scheduler = TransformerScheduler(
    optimizer,
    d_model=d_model,
    warmup_steps=100,  # Smaller for demo (paper uses 4000)
)

# Loss function with label smoothing
criterion = LabelSmoothingLoss(
    smoothing=0.1,
    padding_idx=tokenizer.pad_id,
)

# Validate training components
print("=" * 60)
print("VALIDATION: Training Components")
print("=" * 60)

print(f"\nOptimizer: Adam")
print(f"  betas: (0.9, 0.98)")
print(f"  eps: 1e-9")

print(f"\nScheduler: TransformerScheduler")
print(f"  d_model: {d_model}")
print(f"  warmup_steps: 100")

# Show learning rate at different steps
print(f"\nLearning rate schedule (first 5 steps):")
for step in [1, 25, 50, 100, 200]:
    lr = get_lr_at_step(step, d_model=d_model, warmup_steps=100)
    print(f"  Step {step:4d}: lr = {lr:.6f}")

print(f"\nLoss function: LabelSmoothingLoss")
print(f"  smoothing: 0.1")
print(f"  padding_idx: {tokenizer.pad_id}")

# Test loss computation
model.eval()
with torch.no_grad():
    src = sample_batch['src'].to(device)
    tgt = sample_batch['tgt'].to(device)
    tgt_input = tgt[:, :-1]
    tgt_output = tgt[:, 1:]
    
    logits = model(src, tgt_input)
    
    # Flatten for loss
    loss = criterion(
        logits.contiguous().view(-1, vocab_size),
        tgt_output.contiguous().view(-1)
    )
    print(f"\nTest loss computation:")
    print(f"  Loss value: {loss.item():.4f}")

print(f"\n✓ Training components are configured correctly!")

VALIDATION: Training Components

Optimizer: Adam
  betas: (0.9, 0.98)
  eps: 1e-9

Scheduler: TransformerScheduler
  d_model: 128
  warmup_steps: 100

Learning rate schedule (first 5 steps):
  Step    1: lr = 0.000088
  Step   25: lr = 0.002210
  Step   50: lr = 0.004419
  Step  100: lr = 0.008839
  Step  200: lr = 0.006250

Loss function: LabelSmoothingLoss
  smoothing: 0.1
  padding_idx: 0

Test loss computation:
  Loss value: 5.1321

✓ Training components are configured correctly!
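
The learning rates printed above are consistent with the warmup schedule from the original Transformer paper: the rate grows linearly for `warmup_steps` steps and then decays with the inverse square root of the step number. `get_lr_at_step` is imported from `src.scheduler` and its source is not shown, so the function below is a sketch that reproduces the five printed values and is not necessarily the project's implementation. The name `transformer_lr` is hypothetical.

```python
def transformer_lr(step, d_model, warmup_steps):
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), for step >= 1."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in [1, 25, 50, 100, 200]:
    lr = transformer_lr(step, d_model=128, warmup_steps=100)
    print(f"Step {step:4d}: lr = {lr:.6f}")
```

The two arguments of `min` are equal at `step == warmup_steps`, so the peak learning rate is `(d_model * warmup_steps) ** -0.5`, about 0.008839 for this configuration.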


## Step 7: Training Loop

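Run the training loop and watch the loss fall. With 4 batches per epoch, 20 epochs give only 80 optimizer steps, so the learning rate is still inside its 100-step warmup and keeps rising for the whole run.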

In [7]:
# Training configuration
num_epochs = 20

print("=" * 60)
print("TRAINING")
print("=" * 60)
print(f"\nStarting training for {num_epochs} epochs...")
print(f"Batches per epoch: {len(train_loader)}")

# Training loop
model.train()
global_step = 0
training_losses = []

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    
    for batch_idx, batch in enumerate(train_loader):
        # Move to device
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)
        
        # Teacher forcing: decoder input is tgt[:, :-1], prediction target is tgt[:, 1:]
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        logits = model(src, tgt_input)
        
        # Compute loss
        loss = criterion(
            logits.contiguous().view(-1, vocab_size),
            tgt_output.contiguous().view(-1)
        )
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Optimizer step
        optimizer.step()
        scheduler.step()
        
        # Track loss
        epoch_loss += loss.item()
        num_batches += 1
        global_step += 1
        training_losses.append(loss.item())
    
    # Print epoch summary
    avg_loss = epoch_loss / num_batches
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.4f} | LR: {current_lr:.6f}")

print(f"\n✓ Training completed!")
print(f"  Final loss: {training_losses[-1]:.4f}")
print(f"  Total steps: {global_step}")

TRAINING

Starting training for 20 epochs...
Batches per epoch: 4
Epoch   1/20 | Loss: 4.9465 | LR: 0.000442
Epoch   2/20 | Loss: 4.4643 | LR: 0.000795
Epoch   3/20 | Loss: 4.1146 | LR: 0.001149
Epoch   4/20 | Loss: 3.7766 | LR: 0.001503
Epoch   5/20 | Loss: 3.2914 | LR: 0.001856
Epoch   6/20 | Loss: 2.7534 | LR: 0.002210
Epoch   7/20 | Loss: 2.2702 | LR: 0.002563
Epoch   8/20 | Loss: 1.9482 | LR: 0.002917
Epoch   9/20 | Loss: 1.6900 | LR: 0.003270
Epoch  10/20 | Loss: 1.8712 | LR: 0.003624
Epoch  11/20 | Loss: 1.7007 | LR: 0.003977
Epoch  12/20 | Loss: 1.5249 | LR: 0.004331
Epoch  13/20 | Loss: 1.3764 | LR: 0.004685
Epoch  14/20 | Loss: 1.3637 | LR: 0.005038
Epoch  15/20 | Loss: 1.2947 | LR: 0.005392
Epoch  16/20 | Loss: 1.6257 | LR: 0.005745
Epoch  17/20 | Loss: 1.5812 | LR: 0.006099
Epoch  18/20 | Loss: 1.6971 | LR: 0.006452
Epoch  19/20 | Loss: 1.6568 | LR: 0.006806
Epoch  20/20 | Loss: 1.7142 | LR: 0.007159

✓ Training completed!
  Final loss: 1.9447
  Total steps: 80
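
The loop above clips gradients to a global L2 norm of 1.0 before each optimizer step. The plain-Python sketch below illustrates the idea: the norm is computed jointly over all parameters, and every gradient is scaled by the same factor. `clip_by_global_norm` is a hypothetical helper written for this illustration; it is not PyTorch's code, and the small `eps` in the denominator is an assumption made to avoid division by zero.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm.

    `grads` is a list of flat gradient lists, one per parameter tensor.
    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Never scale up: the coefficient is capped at 1.0.
    coef = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * coef for g in grad] for grad in grads]
    return clipped, total_norm


grads = [[3.0, 4.0], [12.0]]  # joint norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(f"norm before: {norm_before:.4f}")
print(f"norm after:  {norm_after:.4f}")
```

Scaling every gradient by the same factor preserves the direction of the update and only limits its size, which helps while the learning rate is still ramping up.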


In [8]:
# Validate: Plot training loss
print("=" * 60)
print("VALIDATION: Training Progress")
print("=" * 60)

# Show loss trend (first, middle, last).
# Note: list indices are 0-based, so training_losses[n // 2] is the loss at step n // 2 + 1.
print(f"\nLoss at different points:")
print(f"  Start (step 1):     {training_losses[0]:.4f}")
print(f"  Middle (step {len(training_losses)//2}):   {training_losses[len(training_losses)//2]:.4f}")
print(f"  End (step {len(training_losses)}):    {training_losses[-1]:.4f}")

# Check if loss decreased
loss_decreased = training_losses[-1] < training_losses[0]
print(f"\nLoss decreased during training: {'✓ Yes' if loss_decreased else '✗ No'}")

# Simple ASCII plot: samples about 10 points (stride = len // 10, i.e. every 8th step for 80 steps)
print(f"\nLoss curve (every 10 steps):")
max_loss = max(training_losses)
min_loss = min(training_losses)
for i in range(0, len(training_losses), max(1, len(training_losses)//10)):
    bar_len = int(30 * (training_losses[i] - min_loss) / (max_loss - min_loss + 1e-6))
    print(f"  Step {i+1:3d}: {'█' * bar_len} {training_losses[i]:.4f}")

VALIDATION: Training Progress

Loss at different points:
  Start (step 1):     4.9212
  Middle (step 40):   1.5346
  End (step 80):    1.9447

Loss decreased during training: ✓ Yes

Loss curve (every 10 steps):
  Step   1: ████████████████████████████ 4.9212
  Step   9: ██████████████████████ 4.0814
  Step  17: █████████████████ 3.3957
  Step  25: █████████ 2.3216
  Step  33: ███ 1.4699
  Step  41: ███ 1.5346
  Step  49: ██ 1.3758
  Step  57: ██ 1.3373
  Step  65:  1.1149
  Step  73: █ 1.2524


## Step 8: Save Checkpoint

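Save the model weights together with the optimizer state, scheduler state, and model configuration so the model can be reloaded or training resumed later.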

In [9]:
import os

# Create checkpoints directory
os.makedirs('../checkpoints', exist_ok=True)

# Save checkpoint
checkpoint_path = '../checkpoints/demo_model.pt'
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'global_step': global_step,
    'final_loss': training_losses[-1],
    'config': {
        'vocab_size': vocab_size,
        'd_model': d_model,
        'n_heads': n_heads,
        'n_layers': n_layers,
        'd_ff': d_ff,
        'dropout': dropout,
    }
}

torch.save(checkpoint, checkpoint_path)

# Validate checkpoint
print("=" * 60)
print("VALIDATION: Checkpoint Saved")
print("=" * 60)
print(f"\nCheckpoint saved to: {checkpoint_path}")
print(f"File size: {os.path.getsize(checkpoint_path) / 1024:.1f} KB")

# Verify we can load it.
# weights_only=False unpickles arbitrary Python objects, so only use it on checkpoints you trust.
loaded = torch.load(checkpoint_path, weights_only=False)
print(f"\nCheckpoint contents:")
for key in loaded:
    if key == 'config':
        print(f"  {key}: {loaded[key]}")
    else:
        print(f"  {key}: [present]")

print(f"\n✓ Checkpoint saved and verified!")

VALIDATION: Checkpoint Saved

Checkpoint saved to: ../checkpoints/demo_model.pt
File size: 8337.3 KB

Checkpoint contents:
  model_state_dict: [present]
  optimizer_state_dict: [present]
  scheduler_state_dict: [present]
  global_step: [present]
  final_loss: [present]
  config: {'vocab_size': 104, 'd_model': 128, 'n_heads': 4, 'n_layers': 2, 'd_ff': 256, 'dropout': 0.1}

✓ Checkpoint saved and verified!


## Summary

In this tutorial, you learned how to:

1. **Prepare data** - Create parallel sentence pairs for translation
2. **Tokenize** - Build vocabulary and convert text to token IDs
3. **Create DataLoader** - Batch and pad sequences efficiently
4. **Build model** - Construct a Transformer with proper hyperparameters
5. **Setup training** - Configure optimizer, scheduler, and loss function
6. **Train** - Run the training loop with gradient clipping
7. **Save** - Store the trained model checkpoint

Each step was validated to ensure correctness. The training loss should decrease over epochs, indicating the model is learning.

For production training:
- Use a larger d_model (512) and more layers (6)
- Use a BPE tokenizer with a larger vocabulary (~37K)
- Train on WMT14 or a similar large dataset
- Use dynamic batching based on max_tokens
- Train for 100K+ steps with warmup_steps=4000