# WMT14 English-German Training

Production training notebook for the base Transformer model on the WMT14 English-German translation task.

## Configuration
- **Model**: Base Transformer (d_model=512, n_heads=8, n_layers=6, d_ff=2048)
- **Dataset**: WMT14 English-German (~4.5M sentence pairs)
- **Tokenizer**: BPE with shared 37K vocabulary
- **Training**: 100K steps, warmup 4000 steps, label smoothing 0.1

## Requirements
```bash
pip install datasets sentencepiece sacrebleu
```

## 1. Setup and Configuration

In [1]:
import sys
sys.path.insert(0, '..')

import os
import json
import time
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Our implementation
from src import Transformer
from src.tokenizer import Tokenizer, PAD_ID, BOS_ID, EOS_ID
from src.data import (
    TranslationDataset,
    TranslationCollator,
    create_dynamic_dataloader,
    DynamicBatchSampler,
)
from src.scheduler import TransformerScheduler
from src.label_smoothing import LabelSmoothingLoss

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Device: cuda
GPU: NVIDIA GeForce RTX 5090
Memory: 33.7 GB


In [2]:
# ============================================================
# CONFIGURATION - Modify these settings as needed
# ============================================================

CONFIG = {
    # Model (Base Transformer from the paper)
    "d_model": 512,
    "n_heads": 8,
    "n_layers": 6,
    "d_ff": 2048,
    "dropout": 0.1,
    "max_seq_len": 512,
    
    # Tokenizer
    "vocab_size": 37000,  # Shared EN-DE vocabulary
    
    # Training
    "max_steps": 100000,
    "warmup_steps": 4000,
    "label_smoothing": 0.1,
    "max_tokens_per_batch": 4096,  # Tokens per batch (dynamic batching)
    "gradient_accumulation_steps": 4,  # Effective batch ~16K tokens
    "max_grad_norm": 1.0,
    
    # Optimizer (Adam with paper settings)
    "adam_betas": (0.9, 0.98),
    "adam_eps": 1e-9,
    
    # Logging and checkpointing
    "log_steps": 100,
    "eval_steps": 2000,
    "save_steps": 5000,
    "checkpoint_dir": "../checkpoints/wmt14_base",
    
    # Data
    "max_train_samples": None,  # Set to int for debugging (e.g., 10000)
    "max_val_samples": 3000,    # Validation subset for speed
    "num_workers": 4,
}

# Create checkpoint directory
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

Configuration:
  d_model: 512
  n_heads: 8
  n_layers: 6
  d_ff: 2048
  dropout: 0.1
  max_seq_len: 512
  vocab_size: 37000
  max_steps: 100000
  warmup_steps: 4000
  label_smoothing: 0.1
  max_tokens_per_batch: 4096
  gradient_accumulation_steps: 4
  max_grad_norm: 1.0
  adam_betas: (0.9, 0.98)
  adam_eps: 1e-09
  log_steps: 100
  eval_steps: 2000
  save_steps: 5000
  checkpoint_dir: ../checkpoints/wmt14_base
  max_train_samples: None
  max_val_samples: 3000
  num_workers: 4


## 2. Load WMT14 Dataset

In [3]:
from datasets import load_dataset

print("Loading WMT14 English-German dataset...")
print("This may take a while on first run (downloading ~1.7GB)")

# Load dataset
wmt14 = load_dataset("wmt14", "de-en")

print(f"\nDataset splits:")
for split, data in wmt14.items():
    print(f"  {split}: {len(data):,} examples")

Loading WMT14 English-German dataset...
This may take a while on first run (downloading ~1.7GB)


Dataset splits:
  train: 4,508,785 examples
  validation: 3,000 examples
  test: 3,003 examples


In [4]:
# Extract sentences
def extract_sentences(dataset, max_samples=None):
    """Extract EN and DE sentences from WMT dataset."""
    en_sentences = []
    de_sentences = []
    
    for i, example in enumerate(dataset):
        if max_samples is not None and i >= max_samples:
            break
        translation = example["translation"]
        en_sentences.append(translation["en"])
        de_sentences.append(translation["de"])
    
    return en_sentences, de_sentences

# Extract training data
print("Extracting training sentences...")
train_en, train_de = extract_sentences(
    wmt14["train"], 
    max_samples=CONFIG["max_train_samples"]
)
print(f"Training: {len(train_en):,} sentence pairs")

# Extract validation data
print("Extracting validation sentences...")
val_en, val_de = extract_sentences(
    wmt14["validation"],
    max_samples=CONFIG["max_val_samples"]
)
print(f"Validation: {len(val_en):,} sentence pairs")

# Show examples, truncating long sentences to 80 characters
def preview(text, width=80):
    return text[:width] + "..." if len(text) > width else text

print("\nExample sentence pairs:")
for i in range(3):
    print(f"\n  [{i+1}] EN: {preview(train_en[i])}")
    print(f"      DE: {preview(train_de[i])}")

Extracting training sentences...
Training: 4,508,785 sentence pairs
Extracting validation sentences...
Validation: 3,000 sentence pairs

Example sentence pairs:

  [1] EN: Resumption of the session
      DE: Wiederaufnahme der Sitzungsperiode

  [2] EN: I declare resumed the session of the European Parliament adjourned on Friday 17 ...
      DE: Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des E...

  [3] EN: Although, as you will have seen, the dreaded 'millennium bug' failed to material...
      DE: Wie Sie feststellen konnten, ist der gefürchtete "Millenium-Bug " nicht eingetre...


## 3. Train BPE Tokenizer

In [6]:
import tempfile

tokenizer_path = Path(CONFIG["checkpoint_dir"]) / "tokenizer.model"

if tokenizer_path.exists():
    print(f"Loading existing tokenizer from {tokenizer_path}")
    tokenizer = Tokenizer(model_path=str(tokenizer_path))
else:
    print("Training BPE tokenizer on combined EN+DE data...")
    print(f"Target vocabulary size: {CONFIG['vocab_size']}")

    # Train BPE on a subset: the first 200K pairs (400K sentences) train far faster
    # than all 4.5M pairs and should still be enough for a 37K vocabulary.
    # Caveat: the first rows appear to be Europarl (see the examples above), so a
    # random sample would represent the other corpora better.
    tokenizer_train_size = min(200000, len(train_en))
    print(f"Using the first {tokenizer_train_size:,} sentence pairs for BPE training")

    # Write one sentence per line (EN and DE interleaved) to a temp file
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') as f:
        for i in range(tokenizer_train_size):
            f.write(train_en[i].strip() + "\n")
            f.write(train_de[i].strip() + "\n")
        temp_path = f.name
    
    print(f"Training on {tokenizer_train_size * 2:,} sentences...")
    
    # Train tokenizer
    tokenizer = Tokenizer.train(
        input_files=temp_path,
        model_prefix=str(tokenizer_path.with_suffix("")),  # ".../tokenizer"
        vocab_size=CONFIG["vocab_size"],
        model_type="bpe",
        character_coverage=1.0,
        num_threads=8,
    )
    
    # Cleanup
    os.remove(temp_path)
    print(f"Tokenizer saved to {tokenizer_path}")

print(f"\nTokenizer vocabulary size: {tokenizer.vocab_size}")
print(f"Special tokens: PAD={tokenizer.pad_id}, BOS={tokenizer.bos_id}, EOS={tokenizer.eos_id}")

Loading existing tokenizer from ../checkpoints/wmt14_base/tokenizer.model

Tokenizer vocabulary size: 37000
Special tokens: PAD=0, BOS=2, EOS=3
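The tokenizer cell above trains on the *first* 200K pairs, and the examples printed in section 2 suggest those rows come from a single corpus (Europarl). A minimal sketch of drawing a reproducible random subset instead; `sample_tokenizer_corpus` is a hypothetical helper, not part of `src`:

```python
import random


def sample_tokenizer_corpus(src_sents, tgt_sents, n_pairs, seed=0):
    """Return sentences from a random subset of parallel pairs.

    Indices are sampled uniformly instead of taking the first n_pairs rows,
    so every sub-corpus of the training set can contribute. Output lines
    alternate source, target, source, target, ...
    """
    if len(src_sents) != len(tgt_sents):
        raise ValueError("source and target lists must be parallel")
    rng = random.Random(seed)  # fixed seed -> reproducible tokenizer corpus
    n_pairs = min(n_pairs, len(src_sents))
    indices = rng.sample(range(len(src_sents)), n_pairs)
    lines = []
    for i in indices:
        lines.append(src_sents[i].strip())
        lines.append(tgt_sents[i].strip())
    return lines


# Toy usage with hypothetical data
en = [f"en sentence {i}" for i in range(1000)]
de = [f"de Satz {i}" for i in range(1000)]
corpus = sample_tokenizer_corpus(en, de, n_pairs=200)
print(len(corpus))  # 400
```

Writing `corpus` to the temporary file would replace the `range(tokenizer_train_size)` loop; the existing `tokenizer.model` would have to be deleted for the cell to retrain.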


In [7]:
# Test tokenizer
test_sentences = [
    "The Transformer architecture is based on self-attention.",
    "Die Transformer-Architektur basiert auf Self-Attention.",
]

print("Tokenizer test:")
for sent in test_sentences:
    ids = tokenizer.encode(sent, add_bos=True, add_eos=True)
    pieces = tokenizer.encode_as_pieces(sent)
    decoded = tokenizer.decode(ids)
    print(f"\n  Input: {sent}")
    print(f"  Pieces: {pieces[:10]}{'...' if len(pieces) > 10 else ''}")
    print(f"  IDs: {ids[:10]}{'...' if len(ids) > 10 else ''} (len={len(ids)})")
    print(f"  Decoded: {decoded}")

Tokenizer test:

  Input: The Transformer architecture is based on self-attention.
  Pieces: ['▁The', '▁Trans', 'for', 'mer', '▁architecture', '▁is', '▁based', '▁on', '▁self', '-']...
  IDs: [2, 251, 6473, 231, 501, 25882, 64, 2742, 128, 5374]... (len=15)
  Decoded: The Transformer architecture is based on self-attention.

  Input: Die Transformer-Architektur basiert auf Self-Attention.
  Pieces: ['▁Die', '▁Trans', 'for', 'mer', '-', 'Ar', 'ch', 'itektur', '▁basiert', '▁auf']...
  IDs: [2, 331, 6473, 231, 501, 36786, 19182, 9, 25292, 13264]... (len=19)
  Decoded: Die Transformer-Architektur basiert auf Self-Attention.


## 4. Create Datasets and DataLoaders

In [8]:
print("Creating training dataset...")
train_dataset = TranslationDataset(
    src_data=train_en,
    tgt_data=train_de,
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    max_length=CONFIG["max_seq_len"],
    add_bos=True,
    add_eos=True,
)
print(f"Training dataset: {len(train_dataset):,} examples")

print("Creating validation dataset...")
val_dataset = TranslationDataset(
    src_data=val_en,
    tgt_data=val_de,
    src_tokenizer=tokenizer,
    tgt_tokenizer=tokenizer,
    max_length=CONFIG["max_seq_len"],
    add_bos=True,
    add_eos=True,
)
print(f"Validation dataset: {len(val_dataset):,} examples")

Creating training dataset...
Training dataset: 4,508,785 examples
Creating validation dataset...
Validation dataset: 3,000 examples


In [9]:
# Create dataloaders with dynamic batching
print(f"Creating dataloaders with max_tokens={CONFIG['max_tokens_per_batch']}...")

train_loader = create_dynamic_dataloader(
    dataset=train_dataset,
    max_tokens=CONFIG["max_tokens_per_batch"],
    max_sentences=128,
    shuffle=True,
    num_workers=CONFIG["num_workers"],
    pad_id=tokenizer.pad_id,
)

val_loader = create_dynamic_dataloader(
    dataset=val_dataset,
    max_tokens=CONFIG["max_tokens_per_batch"],
    max_sentences=128,
    shuffle=False,
    num_workers=CONFIG["num_workers"],
    pad_id=tokenizer.pad_id,
)

print(f"Training batches: {len(train_loader):,}")
print(f"Validation batches: {len(val_loader):,}")

# Check first batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch:")
print(f"  src shape: {sample_batch['src'].shape}")
print(f"  tgt shape: {sample_batch['tgt'].shape}")
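# Counts padded positions on both sides. max_tokens appears to bound each side
# (128 x 24 = 3,072 <= 4,096), so this total can exceed max_tokens.
print(f"  tokens in batch: {sample_batch['src'].numel() + sample_batch['tgt'].numel()}")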

Creating dataloaders with max_tokens=4096...
Training batches: 45,718
Validation batches: 29

Sample batch:
  src shape: torch.Size([128, 24])
  tgt shape: torch.Size([128, 24])
  tokens in batch: 6144
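`create_dynamic_dataloader` comes from this repo's `src.data` and its exact batching rule is not shown here. The sample batch (128 sentences padded to 24 tokens) is consistent with a rule that caps padded tokens per side at `max_tokens` and sentences at `max_sentences`. A minimal sketch of such a token-budget batcher, assuming that rule and taking one length per example (e.g. the longer of source and target); `token_budget_batches` is hypothetical:

```python
def token_budget_batches(lengths, max_tokens, max_sentences):
    """Group example indices into batches under a padded-token budget.

    Examples are sorted by length so each batch pads to a similar size.
    A batch is closed when adding the next example would push
    (batch size x longest length) above max_tokens, or the batch
    would exceed max_sentences.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, longest = [], [], 0
    for idx in order:
        new_longest = max(longest, lengths[idx])
        too_many_tokens = new_longest * (len(current) + 1) > max_tokens
        too_many_sents = len(current) + 1 > max_sentences
        if current and (too_many_tokens or too_many_sents):
            batches.append(current)
            current, new_longest = [], lengths[idx]
        current.append(idx)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


lengths = [24] * 300 + [60] * 100
batches = token_budget_batches(lengths, max_tokens=4096, max_sentences=128)
print([len(b) for b in batches])  # [128, 128, 68, 68, 8]
```

Short sentences hit the sentence cap (128) before the token cap, which matches the 128 × 24 sample batch above. A training sampler would typically also shuffle the list of batches each epoch so that length order does not leak into training.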


## 5. Build Model

In [10]:
print("Building Transformer model...")
print(f"  d_model: {CONFIG['d_model']}")
print(f"  n_heads: {CONFIG['n_heads']}")
print(f"  n_layers: {CONFIG['n_layers']}")
print(f"  d_ff: {CONFIG['d_ff']}")
print(f"  vocab_size: {tokenizer.vocab_size}")

model = Transformer(
    src_vocab_size=tokenizer.vocab_size,
    tgt_vocab_size=tokenizer.vocab_size,
    d_model=CONFIG["d_model"],
    n_heads=CONFIG["n_heads"],
    n_encoder_layers=CONFIG["n_layers"],
    n_decoder_layers=CONFIG["n_layers"],
    d_ff=CONFIG["d_ff"],
    dropout=CONFIG["dropout"],
    max_seq_len=CONFIG["max_seq_len"],
    pad_idx=tokenizer.pad_id,
    share_embeddings=True,  # Share embeddings between encoder and decoder
)

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")
print(f"  Size: ~{total_params * 4 / 1e6:.1f} MB (fp32)")

Building Transformer model...
  d_model: 512
  n_heads: 8
  n_layers: 6
  d_ff: 2048
  vocab_size: 37000

Model parameters:
  Total: 82,028,680
  Trainable: 82,028,680
  Size: ~328.1 MB (fp32)
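The 82,028,680 parameters reported here are well above the ~65M quoted for the base model in the paper, which shares one matrix between both embeddings and the pre-softmax projection. A back-of-the-envelope count suggests the gap is one untied 37,000 × 512 output projection. The layer layout below is an assumption, not read from `src/`:

```python
def transformer_param_count(vocab, d_model, d_ff, n_enc, n_dec,
                            tie_output=False, final_norms=True):
    """Count parameters of an encoder-decoder Transformer.

    Assumptions (hypothetical): each attention block has biased Q, K, V and
    output projections; the feed-forward block is two biased linear layers;
    each LayerNorm has a weight and a bias; one embedding matrix is shared by
    encoder and decoder; the output projection has no bias; positional
    encodings are sinusoidal (no parameters).
    """
    attn = 4 * (d_model * d_model + d_model)
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    norm = 2 * d_model
    enc_layer = attn + ffn + 2 * norm      # self-attn + FFN
    dec_layer = 2 * attn + ffn + 3 * norm  # self-attn + cross-attn + FFN
    total = vocab * d_model                # shared src/tgt embedding
    total += n_enc * enc_layer + n_dec * dec_layer
    if final_norms:
        total += 2 * norm                  # final encoder/decoder LayerNorms
    if not tie_output:
        total += vocab * d_model           # separate pre-softmax projection
    return total


untied = transformer_param_count(37000, 512, 2048, 6, 6, tie_output=False)
tied = transformer_param_count(37000, 512, 2048, 6, 6, tie_output=True)
print(f"untied output projection: {untied:,}")  # 82,028,544
print(f"tied output projection:   {tied:,}")    # 63,084,544
```

Under these assumptions the untied count lands within 136 of the reported total (the remainder depends on implementation details not shown here), while tying the output projection brings it to about 63M, close to the paper's figure.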


In [11]:
# Test forward pass
print("Testing forward pass...")
model.eval()
with torch.no_grad():
    src = sample_batch['src'].to(device)
    tgt = sample_batch['tgt'].to(device)
    tgt_input = tgt[:, :-1]
    
    logits = model(src, tgt_input)
    print(f"  Input src: {src.shape}")
    print(f"  Input tgt: {tgt_input.shape}")
    print(f"  Output logits: {logits.shape}")
    print(f"  Forward pass successful!")

Testing forward pass...
  Input src: torch.Size([128, 24])
  Input tgt: torch.Size([128, 23])
  Output logits: torch.Size([128, 23, 37000])
  Forward pass successful!
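The forward pass feeds `tgt[:, :-1]` to the decoder, and `train_step` in the next section scores its logits against `tgt[:, 1:]`. A plain-Python illustration of that one-position shift on a single padded row, using the special-token IDs printed by the tokenizer cell and the first three piece IDs of the German test sentence:

```python
PAD, BOS, EOS = 0, 2, 3  # IDs as printed by the tokenizer cell above

# One padded target row: BOS, three word pieces, EOS, padding
tgt = [BOS, 331, 6473, 231, EOS, PAD, PAD]

tgt_input = tgt[:-1]   # what the decoder sees at each position
tgt_output = tgt[1:]   # what it must predict at each position

for pos, (seen, target) in enumerate(zip(tgt_input, tgt_output)):
    note = "  (ignored: padding)" if target == PAD else ""
    print(f"pos {pos}: input={seen:5d} -> predict {target:5d}{note}")
```

Positions whose target is padding should contribute nothing to the loss, which is presumably what `padding_idx=tokenizer.pad_id` in `LabelSmoothingLoss` is for.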


## 6. Setup Training Components

In [12]:
# Optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1.0,  # Will be controlled by scheduler
    betas=CONFIG["adam_betas"],
    eps=CONFIG["adam_eps"],
)

# Learning rate scheduler
scheduler = TransformerScheduler(
    optimizer,
    d_model=CONFIG["d_model"],
    warmup_steps=CONFIG["warmup_steps"],
)

# Loss function with label smoothing
criterion = LabelSmoothingLoss(
    smoothing=CONFIG["label_smoothing"],
    padding_idx=tokenizer.pad_id,
)

print("Training components:")
print(f"  Optimizer: Adam (betas={CONFIG['adam_betas']}, eps={CONFIG['adam_eps']})")
print(f"  Scheduler: Transformer LR (warmup={CONFIG['warmup_steps']} steps)")
print(f"  Loss: Label smoothing (eps={CONFIG['label_smoothing']})")
print(f"  Gradient accumulation: {CONFIG['gradient_accumulation_steps']} steps")
print(f"  Max gradient norm: {CONFIG['max_grad_norm']}")

Training components:
  Optimizer: Adam (betas=(0.9, 0.98), eps=1e-09)
  Scheduler: Transformer LR (warmup=4000 steps)
  Loss: Label smoothing (eps=0.1)
  Gradient accumulation: 4 steps
  Max gradient norm: 1.0
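`TransformerScheduler` lives in `src.scheduler` and its code is not shown, but with Adam's base `lr=1.0` the logged learning rates agree, to within about one step, with the paper's warmup-then-inverse-square-root schedule: the log shows 1.75e-04 at step 1000, and 1.76e-05 at step 100, where the formula gives 1.75e-05 (and 1.76e-05 at step 101). A standalone sketch, assuming that formula:

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Linear warmup followed by inverse-square-root decay (Vaswani et al., 2017).

    lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    """
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for s in (100, 1000, 4000, 20000, 100000):
    print(f"step {s:6d}: lr = {transformer_lr(s):.2e}")
```

The two branches meet at `step == warmup_steps`, so the peak learning rate is `(d_model * warmup_steps) ** -0.5`, about 7.0e-04 for this configuration.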


## 7. Training Loop

In [13]:
def train_step(model, batch, criterion, device):
    """Single training step."""
    src = batch['src'].to(device)
    tgt = batch['tgt'].to(device)
    
    # Teacher forcing: input is tgt[:-1], target is tgt[1:]
    tgt_input = tgt[:, :-1]
    tgt_output = tgt[:, 1:]
    
    # Forward pass
    logits = model(src, tgt_input)
    
    # Compute loss
    loss = criterion(
        logits.contiguous().view(-1, logits.size(-1)),
        tgt_output.contiguous().view(-1)
    )
    
    return loss


@torch.no_grad()
def evaluate(model, val_loader, criterion, device, max_batches=None):
    """Evaluate model on validation set."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    
    for i, batch in enumerate(val_loader):
        if max_batches is not None and i >= max_batches:
            break
            
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        logits = model(src, tgt_input)
        loss = criterion(
            logits.contiguous().view(-1, logits.size(-1)),
            tgt_output.contiguous().view(-1)
        )
        
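        # Weight each batch by its non-padding target tokens (assumes criterion
        # returns a per-token mean) to get a token-level average over the set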
        non_pad = (tgt_output != tokenizer.pad_id).sum().item()
        total_loss += loss.item() * non_pad
        total_tokens += non_pad
    
    model.train()
    return total_loss / total_tokens if total_tokens > 0 else 0.0


def save_checkpoint(model, optimizer, scheduler, step, loss, path):
    """Save training checkpoint."""
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': CONFIG,
    }
    torch.save(checkpoint, path)
    return path

In [14]:
# Training state
global_step = 0
best_val_loss = float('inf')
training_history = []

# Check for existing checkpoint to resume
resume_path = Path(CONFIG["checkpoint_dir"]) / "latest_checkpoint.pt"
if resume_path.exists():
    print(f"Resuming from {resume_path}")
    checkpoint = torch.load(resume_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
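    global_step = checkpoint['step']
    # best_val_loss and training_history are not stored in the checkpoint,
    # so they restart from scratch after resuming (the first evaluation will
    # overwrite best_model.pt).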
    print(f"Resumed from step {global_step}")
else:
    print("Starting fresh training")

Starting fresh training


In [15]:
# Main training loop
print("="*70)
print(f"Starting training for {CONFIG['max_steps']:,} steps")
print(f"Gradient accumulation: {CONFIG['gradient_accumulation_steps']} steps")
print(f"Effective batch size: ~{CONFIG['max_tokens_per_batch'] * CONFIG['gradient_accumulation_steps']:,} tokens")
print("="*70)

model.train()
optimizer.zero_grad()

accumulation_loss = 0.0
accumulation_steps = 0
start_time = time.time()
log_start_time = time.time()

epoch = 0
while global_step < CONFIG["max_steps"]:
    epoch += 1
    
    for batch in train_loader:
        if global_step >= CONFIG["max_steps"]:
            break
        
        # Forward and backward
        loss = train_step(model, batch, criterion, device)
        loss = loss / CONFIG["gradient_accumulation_steps"]
        loss.backward()
        
        accumulation_loss += loss.item()
        accumulation_steps += 1
        
        # Update weights after accumulation
        if accumulation_steps >= CONFIG["gradient_accumulation_steps"]:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["max_grad_norm"])
            
            # Optimizer step
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            global_step += 1
            
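            # Logging: accumulation_loss is the mean micro-batch loss of this one
            # optimizer step, not an average over the last log_steps steps, so the
            # logged values are noisy.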
            if global_step % CONFIG["log_steps"] == 0:
                elapsed = time.time() - log_start_time
                steps_per_sec = CONFIG["log_steps"] / elapsed
                current_lr = scheduler.get_last_lr()[0]
                
                print(f"Step {global_step:6d} | "
                      f"Loss: {accumulation_loss:.4f} | "
                      f"LR: {current_lr:.2e} | "
                      f"Speed: {steps_per_sec:.1f} steps/s")
                
                training_history.append({
                    'step': global_step,
                    'loss': accumulation_loss,
                    'lr': current_lr,
                })
                
                log_start_time = time.time()
            
            # Evaluation
            if global_step % CONFIG["eval_steps"] == 0:
                val_loss = evaluate(model, val_loader, criterion, device)
                print(f"  >> Validation loss: {val_loss:.4f}")
                
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_path = Path(CONFIG["checkpoint_dir"]) / "best_model.pt"
                    save_checkpoint(model, optimizer, scheduler, global_step, val_loss, best_path)
                    print(f"  >> New best model saved!")
            
            # Save checkpoint
            if global_step % CONFIG["save_steps"] == 0:
                ckpt_path = Path(CONFIG["checkpoint_dir"]) / f"checkpoint_step_{global_step}.pt"
                save_checkpoint(model, optimizer, scheduler, global_step, accumulation_loss, ckpt_path)
                
                # Also save as latest
                save_checkpoint(model, optimizer, scheduler, global_step, accumulation_loss, resume_path)
                print(f"  >> Checkpoint saved: {ckpt_path.name}")
            
            # Reset accumulation
            accumulation_loss = 0.0
            accumulation_steps = 0
    
    print(f"\n--- Epoch {epoch} completed ---\n")

# Final save
total_time = time.time() - start_time
print("="*70)
print(f"Training completed!")
print(f"Total steps: {global_step:,}")
print(f"Total time: {total_time/3600:.2f} hours")
print(f"Best validation loss: {best_val_loss:.4f}")
print("="*70)

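# Save final model. accumulation_loss was reset after the last optimizer step,
# so record the last logged loss instead.
final_path = Path(CONFIG["checkpoint_dir"]) / "final_model.pt"
final_loss = training_history[-1]['loss'] if training_history else None
save_checkpoint(model, optimizer, scheduler, global_step, final_loss, final_path)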
print(f"Final model saved to {final_path}")

Starting training for 100,000 steps
Gradient accumulation: 4 steps
Effective batch size: ~16,384 tokens
Step    100 | Loss: 9.9215 | LR: 1.76e-05 | Speed: 5.4 steps/s
Step    200 | Loss: 8.8597 | LR: 3.51e-05 | Speed: 5.7 steps/s
Step    300 | Loss: 8.0831 | LR: 5.26e-05 | Speed: 5.7 steps/s
Step    400 | Loss: 7.7754 | LR: 7.01e-05 | Speed: 5.6 steps/s
Step    500 | Loss: 7.9104 | LR: 8.75e-05 | Speed: 5.6 steps/s
Step    600 | Loss: 7.3925 | LR: 1.05e-04 | Speed: 5.6 steps/s
Step    700 | Loss: 7.1640 | LR: 1.22e-04 | Speed: 5.6 steps/s
Step    800 | Loss: 7.3377 | LR: 1.40e-04 | Speed: 5.6 steps/s
Step    900 | Loss: 7.1376 | LR: 1.57e-04 | Speed: 5.6 steps/s
Step   1000 | Loss: 6.5987 | LR: 1.75e-04 | Speed: 5.6 steps/s
Step   1100 | Loss: 6.6162 | LR: 1.92e-04 | Speed: 5.7 steps/s
Step   1200 | Loss: 6.4070 | LR: 2.10e-04 | Speed: 5.6 steps/s
Step   1300 | Loss: 6.4417 | LR: 2.27e-04 | Speed: 5.6 steps/s
Step   1400 | Loss: 6.7531 | LR: 2.45e-04 | Speed: 5.6 steps/s
Step   1500 | 
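Dividing each micro-batch loss by `gradient_accumulation_steps` gives every micro-batch the same weight, even though dynamic batches carry different numbers of non-padding target tokens. This assumes `LabelSmoothingLoss` returns a per-token mean, the same assumption `evaluate` makes when it re-weights by token count. The numbers below are made up purely to show how that differs from a token-weighted average:

```python
# Hypothetical micro-batches: (summed loss over non-pad tokens, non-pad token count)
micro_batches = [(4000.0, 1000), (900.0, 300), (2500.0, 500), (1600.0, 400)]
k = len(micro_batches)

# What the loop above optimizes: the mean of per-batch means
equal_weight = sum(total / n for total, n in micro_batches) / k

# Token-weighted alternative: one mean over all tokens in the accumulation window
token_weighted = sum(t for t, _ in micro_batches) / sum(n for _, n in micro_batches)

print(f"equal-weight mean:   {equal_weight:.3f}")
print(f"token-weighted mean: {token_weighted:.3f}")
```

Training on the token-weighted objective would require the criterion to return a summed loss and the accumulated gradients to be divided by the window's total token count before `optimizer.step()`; whether that changes results for this setup has not been tested here.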

## 8. Training Visualization

In [16]:
# Plot training loss
if training_history:
    steps = [h['step'] for h in training_history]
    losses = [h['loss'] for h in training_history]
    lrs = [h['lr'] for h in training_history]
    
    print("Training Summary:")
    print(f"  Initial loss: {losses[0]:.4f}")
    print(f"  Final loss: {losses[-1]:.4f}")
    print(f"  Min loss: {min(losses):.4f} (step {steps[losses.index(min(losses))]})")
    print(f"  Best val loss: {best_val_loss:.4f}")
    
    # Simple ASCII visualization
    print("\nLoss curve (sampled):")
    sample_indices = range(0, len(losses), max(1, len(losses)//20))
    max_loss = max(losses[i] for i in sample_indices)
    min_loss = min(losses[i] for i in sample_indices)
    
    for i in sample_indices:
        normalized = (losses[i] - min_loss) / (max_loss - min_loss + 1e-8)
        bar = '█' * int(normalized * 40)
        print(f"  Step {steps[i]:6d}: {bar} {losses[i]:.4f}")

Training Summary:
  Initial loss: 9.9215
  Final loss: 4.1571
  Min loss: 2.9179 (step 45500)
  Best val loss: 4.4300

Loss curve (sampled):
  Step    100: ███████████████████████████████████████ 9.9215
  Step   5100: ███████████ 5.2867
  Step  10100: ███████ 4.7566
  Step  15100: ██ 3.9555
  Step  20100: ██████ 4.5271
  Step  25100: ██████ 4.5078
  Step  30100: █ 3.7335
  Step  35100: ██████ 4.5115
  Step  40100: ███████ 4.6572
  Step  45100: ████ 4.2267
  Step  50100: ██ 3.9495
  Step  55100: █ 3.7280
  Step  60100:  3.6657
  Step  65100: ██ 3.9706
  Step  70100: ██ 3.8984
  Step  75100: █ 3.8188
  Step  80100: ████ 4.2655
  Step  85100:  3.6069
  Step  90100: █ 3.7764
  Step  95100:  3.5189


## 9. Quick Translation Test

In [17]:
# Test translation
model.eval()

test_sentences = [
    "The weather is nice today.",
    "I love machine learning.",
    "The European Union is an economic and political union.",
]

print("Translation test:")
print("="*70)

for sent in test_sentences:
    # Encode
    src_ids = tokenizer.encode(sent, add_bos=True, add_eos=True)
    src_tensor = torch.tensor([src_ids], device=device)
    
    # Generate
    with torch.no_grad():
        output = model.generate(
            src=src_tensor,
            max_len=100,
            start_token=tokenizer.bos_id,
            end_token=tokenizer.eos_id,
        )
    
    # Decode
    translation = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
    
    print(f"\nEN: {sent}")
    print(f"DE: {translation}")

Translation test:

EN: The weather is nice today.
DE: Heute ist das Wetter schön.

EN: I love machine learning.
DE: Ich liebe die Arbeit.

EN: The European Union is an economic and political union.
DE: Die Europäische Union ist eine wirtschaftliche und politische Union.
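`model.generate` belongs to this repo's `Transformer`, and its decoding strategy is not shown; since beam search is listed as a next step, it is presumably greedy. A minimal greedy loop, written against a hypothetical `next_token_logprobs(src, prefix)` callable rather than the real model:

```python
from typing import Callable, List, Sequence


def greedy_decode(
    next_token_logprobs: Callable[[Sequence[int], List[int]], Sequence[float]],
    src: Sequence[int],
    bos_id: int,
    eos_id: int,
    max_len: int = 100,
) -> List[int]:
    """Append the highest-scoring token until EOS or max_len new tokens."""
    out = [bos_id]
    for _ in range(max_len):
        scores = next_token_logprobs(src, out)
        best = max(range(len(scores)), key=scores.__getitem__)
        out.append(best)
        if best == eos_id:
            break
    return out


# Hypothetical toy scorer: copies the source token at the current position, then emits EOS
def toy_scorer(src, prefix):
    vocab_size = 10
    pos = len(prefix) - 1
    target = src[pos] if pos < len(src) else 3  # 3 = EOS
    return [0.0 if t == target else -5.0 for t in range(vocab_size)]


print(greedy_decode(toy_scorer, src=[5, 6, 7], bos_id=2, eos_id=3))  # [2, 5, 6, 7, 3]
```

Greedy search commits to the single best token at every step; beam search keeps several candidate prefixes alive, which is why it usually improves BLEU.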


## Summary

This notebook trained a base Transformer model on WMT14 English-German:

1. **Dataset**: WMT14 EN-DE (~4.5M sentence pairs)
2. **Tokenizer**: BPE with 37K shared vocabulary
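3. **Model**: Base Transformer (82M parameters)
   - d_model=512, n_heads=8, n_layers=6, d_ff=2048
   - The paper's ~65M figure also ties the output projection to the shared embeddings, which this model appears not to do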
4. **Training**: Dynamic batching, gradient accumulation, label smoothing

### Checkpoints saved:
- `checkpoints/wmt14_base/best_model.pt` - Best validation loss
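- `checkpoints/wmt14_base/final_model.pt` - Final model
- `checkpoints/wmt14_base/checkpoint_step_{N}.pt` and `latest_checkpoint.pt` - Saved every 5,000 steps; `latest_checkpoint.pt` is used for resuming
- `checkpoints/wmt14_base/tokenizer.model` - BPE tokenizer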

### Next steps:
- Use `04_wmt14_inference.ipynb` to run inference and evaluate BLEU scores
- Implement beam search for better translation quality
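- Train longer or with larger batches: the paper's base model used about 25K source and 25K target tokens per batch for 100K steps, versus ~16K tokens per step here (`max_tokens_per_batch` × `gradient_accumulation_steps`)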