# Transformer Inference Tutorial

This notebook shows how to use a trained Transformer model for translation.
Each step includes validation to verify correctness.

## Overview
1. Setup and imports
2. Load the trained model
3. Prepare input text
4. Run inference (greedy decoding)
5. Use the built-in `generate()` method
6. Batch inference
7. Examine model internals

## Step 1: Setup and Imports

In [None]:
import sys
sys.path.insert(0, '..')  # Add parent directory to path

import torch
import torch.nn.functional as F

# Import our Transformer implementation
from src import Transformer
from src.tokenizer import SimpleTokenizer, pad_sequences
from src.attention import create_causal_mask, create_padding_mask

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's random number generator. When no checkpoint is available the
# notebook falls back to a randomly initialised model, so without a fixed seed
# every "Restart & Run All" would print different weights and translations.
SEED = 42
torch.manual_seed(SEED)

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## Step 2: Load the Trained Model

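Load the checkpoint saved by the training tutorial. If no checkpoint file is found, a freshly initialized (untrained) model is created instead so that the rest of the notebook can still run.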

In [None]:
import os

# Path where the checkpoint is looked up
checkpoint_path = '../checkpoints/demo_model.pt'

if not os.path.exists(checkpoint_path):
    # Nothing on disk: fall back to a small demo configuration and remember
    # (via checkpoint = None) that there are no trained weights to restore.
    print("No checkpoint found. Creating fresh model for demonstration...")
    checkpoint = None
    config = {
        'vocab_size': 100,
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 2,
        'd_ff': 256,
        'dropout': 0.1,
    }
else:
    print("Loading saved checkpoint...")
    # NOTE(security): weights_only=False allows torch.load to unpickle
    # arbitrary Python objects. Only open checkpoint files you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint['config']

# Validate: print every configuration entry that will be used below
print("\n" + "=" * 60, "VALIDATION: Model Configuration", "=" * 60, sep="\n")
for name, setting in config.items():
    print(f"  {name}: {setting}")

In [None]:
# Rebuild the tokenizer from the same sentences so the vocabulary matches.
# In practice, save and load the tokenizer together with the model instead.
parallel_corpus = [
    ("The cat sat on the mat", "Die Katze saß auf der Matte"),
    ("Hello world", "Hallo Welt"),
    ("How are you today", "Wie geht es dir heute"),
    ("I love machine learning", "Ich liebe maschinelles Lernen"),
    ("The weather is nice", "Das Wetter ist schön"),
    ("Good morning everyone", "Guten Morgen allerseits"),
    ("This is a test", "Das ist ein Test"),
    ("The dog runs fast", "Der Hund läuft schnell"),
]

# Split the pairs back into the two parallel lists (order is preserved)
src_sentences = [source for source, _ in parallel_corpus]
tgt_sentences = [target for _, target in parallel_corpus]

# One shared vocabulary: all source sentences first, then all target sentences
tokenizer = SimpleTokenizer()
tokenizer.build_vocab([*src_sentences, *tgt_sentences])

# Update vocab_size to match tokenizer
config['vocab_size'] = tokenizer.vocab_size

print("=" * 60, "VALIDATION: Tokenizer Rebuilt", "=" * 60, sep="\n")
print(f"\nVocabulary size: {tokenizer.vocab_size}")
print(f"BOS_ID: {tokenizer.bos_id}")
print(f"EOS_ID: {tokenizer.eos_id}")
print(f"PAD_ID: {tokenizer.pad_id}")

In [None]:
# Create model: source and target share one vocabulary size, and the encoder
# and decoder use the same number of layers.
vocab_size = config['vocab_size']
n_layers = config['n_layers']
model = Transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    d_model=config['d_model'],
    n_heads=config['n_heads'],
    n_encoder_layers=n_layers,
    n_decoder_layers=n_layers,
    d_ff=config['d_ff'],
    dropout=config['dropout'],
    pad_idx=tokenizer.pad_id,
)

# Restore trained weights when a checkpoint was loaded. A failure is reported
# and the freshly initialised weights are kept, so the notebook keeps running.
if checkpoint is not None:
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as load_error:
        print(f"Could not load weights (vocab size mismatch?): {load_error}")
        print("Using fresh model weights.")
    else:
        print("Model weights loaded successfully!")

model = model.to(device)
model.eval()  # Set to evaluation mode

total_params = sum(param.numel() for param in model.parameters())

print("\n" + "=" * 60, "VALIDATION: Model Ready", "=" * 60, sep="\n")
print(f"\nModel loaded and set to eval mode")
print(f"Total parameters: {total_params:,}")
print(f"\n✓ Model is ready for inference!")

## Step 3: Prepare Input Text

Tokenize the input sentence for translation.

In [None]:
# Input sentence to translate
input_text = "The cat sat on the mat"

# Tokenize (requesting BOS and EOS markers) and wrap the ID list in an outer
# list so the tensor gets a leading batch dimension of size 1.
src_ids = tokenizer.encode(input_text, add_bos=True, add_eos=True)
src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)

# Validate input
print("=" * 60, "VALIDATION: Input Preparation", "=" * 60, sep="\n")
print(f"\nInput text: '{input_text}'")
print(f"\nToken IDs: {src_ids}")
print(f"Tensor shape: {src_tensor.shape}")

# Decode each ID on its own to show which token sits at which position
print(f"\nToken breakdown:")
for position, tid in enumerate(src_ids):
    piece = tokenizer.decode([tid])
    print(f"  Position {position}: ID={tid:3d} -> '{piece}'")

print(f"\n✓ Input prepared correctly!")

## Step 4: Run Inference (Greedy Decoding)

Generate a translation using greedy decoding (selecting the most probable token at each step).

In [None]:
def greedy_decode(model, src, tokenizer, max_len=50, device='cpu'):
    """
    Greedy decoding: select the most probable token at each step.

    The model is put into eval mode and the whole decoding loop runs under
    ``torch.no_grad()``, so no autograd graph is built even when the caller
    has not disabled gradients.

    Args:
        model: Trained Transformer model exposing ``encode``, ``decode``,
            ``_create_tgt_mask`` and ``output_projection``
        src: Source token tensor (1, src_len). Only a single sequence is
            supported because the chosen token is read with ``.item()``.
        tokenizer: Tokenizer with bos_id, eos_id
        max_len: Maximum number of tokens to generate (BOS is not counted)
        device: Device to run on

    Returns:
        Tuple ``(generated, step_info)``:
            generated: list of token IDs, starting with BOS and ending with
                EOS if EOS was produced within ``max_len`` steps.
            step_info: one dict per step with keys ``'step'`` (1-based),
                ``'selected'`` (chosen token ID), ``'selected_prob'`` and
                ``'top3'`` (list of ``(token_id, prob)`` pairs, most probable
                first; shorter than 3 when the vocabulary has fewer entries).
    """
    model.eval()

    # Start with BOS token
    generated = [tokenizer.bos_id]
    step_info = []

    with torch.no_grad():
        # Encode source once; the result is reused at every decoding step
        src = src.to(device)
        memory = model.encode(src)

        for step in range(max_len):
            # Re-feed everything generated so far as the decoder input
            tgt = torch.tensor([generated], dtype=torch.long, device=device)

            # Decode
            tgt_mask = model._create_tgt_mask(tgt)
            decoder_output = model.decode(tgt, memory, tgt_mask)

            # Only the last position predicts the next token
            logits = model.output_projection(decoder_output[:, -1, :])
            probs = F.softmax(logits, dim=-1)

            # Greedy: select most probable
            next_token = probs.argmax(dim=-1).item()

            # Top candidates for analysis. topk raises if k exceeds the
            # vocabulary size, so clamp k for very small vocabularies.
            top_k = min(3, probs.size(-1))
            top_probs, top_ids = probs.topk(top_k, dim=-1)

            step_info.append({
                'step': step + 1,
                'selected': next_token,
                'selected_prob': probs[0, next_token].item(),
                'top3': list(zip(top_ids[0].tolist(), top_probs[0].tolist())),
            })

            # Append token
            generated.append(next_token)

            # Stop if EOS
            if next_token == tokenizer.eos_id:
                break

    return generated, step_info

# Run inference (no gradients are needed for decoding)
with torch.no_grad():
    decode_result = greedy_decode(
        model, src_tensor, tokenizer, max_len=20, device=device
    )
generated_ids, step_info = decode_result

# Turn the generated IDs back into text, skipping special tokens
output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

print("=" * 60, "VALIDATION: Greedy Decoding Results", "=" * 60, sep="\n")
print(f"\nInput:  '{input_text}'")
print(f"Output: '{output_text}'")
print(f"\nGenerated IDs: {generated_ids}")

In [None]:
# Walk through the per-step records collected by greedy_decode
print("=" * 60, "VALIDATION: Step-by-Step Decoding", "=" * 60, sep="\n")

print(f"\nDecoding steps (first 5):")
for record in step_info[:5]:
    chosen_id = record['selected']
    chosen_str = tokenizer.decode([chosen_id])
    print(f"\n  Step {record['step']}:")
    print(f"    Selected: ID={chosen_id} ('{chosen_str}') prob={record['selected_prob']:.4f}")
    print(f"    Top 3 candidates:")
    # Each candidate is a (token_id, probability) pair
    for cand_id, cand_prob in record['top3']:
        cand_str = tokenizer.decode([cand_id])
        print(f"      ID={cand_id:3d} ('{cand_str}') prob={cand_prob:.4f}")

print(f"\n✓ Decoding completed in {len(step_info)} steps!")

## Step 5: Using the Built-in Generate Method

The Transformer class has a built-in generate() method.

In [None]:
# Let the model's own generate() drive the decoding loop
with torch.no_grad():
    generated = model.generate(
        src=src_tensor, max_len=20,
        start_token=tokenizer.bos_id, end_token=tokenizer.eos_id,
    )

# Row 0 is the only sequence in this batch; convert it to a plain list
output_ids = generated[0].tolist()
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)

print("=" * 60, "VALIDATION: Built-in Generate Method", "=" * 60, sep="\n")
print(f"\nInput:  '{input_text}'")
print(f"Output: '{output_text}'")
print(f"\nGenerated shape: {generated.shape}")
print(f"Output IDs: {output_ids}")
print(f"\n✓ Built-in generate() works correctly!")

## Step 6: Batch Inference

Translate multiple sentences at once for efficiency.

In [None]:
# Multiple inputs
test_inputs = [
    "Hello world",
    "Good morning everyone",
    "The dog runs fast",
]

# Tokenize every sentence (with BOS/EOS requested); lengths may differ
src_batch = [
    tokenizer.encode(sentence, add_bos=True, add_eos=True)
    for sentence in test_inputs
]

# Validate: show every tokenized input next to its text
print("=" * 60, "VALIDATION: Batch Input Preparation", "=" * 60, sep="\n")
print(f"\nNumber of inputs: {len(test_inputs)}")
for number, (text, ids) in enumerate(zip(test_inputs, src_batch), start=1):
    print(f"\n  Input {number}: '{text}'")
    print(f"  Token IDs: {ids}")
    print(f"  Length: {len(ids)}")

In [None]:
# Pad the batch with the tokenizer's PAD id, then move it to the device
src_padded = pad_sequences(src_batch, padding_value=tokenizer.pad_id)
src_tensor = src_padded.to(device)

print("\nAfter padding:")
print(f"  Tensor shape: {src_tensor.shape}")
print(f"  Padded tensor:")
# One printed line per sequence in the batch
for padded_row in src_tensor.tolist():
    print(f"    {padded_row}")

In [None]:
# Generate for the whole padded batch in a single call
with torch.no_grad():
    batch_generated = model.generate(
        src=src_tensor, max_len=20,
        start_token=tokenizer.bos_id, end_token=tokenizer.eos_id,
    )

# Decode all outputs
print("=" * 60, "VALIDATION: Batch Translation Results", "=" * 60, sep="\n")
print(f"\nGenerated shape: {batch_generated.shape}")

print(f"\nTranslations:")
for row, source_text in enumerate(test_inputs):
    # Row `row` of the output belongs to the input at the same index
    output_ids = batch_generated[row].tolist()
    output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(f"\n  Input {row+1}:  '{source_text}'")
    print(f"  Output {row+1}: '{output_text}'")
    print(f"  IDs: {output_ids}")

print(f"\n✓ Batch inference completed successfully!")

## Step 7: Examining Model Internals

Let's look at the token embeddings, the encoder output, and the output probability distribution at the first decoding step.

In [None]:
# Short example sentence whose representations will be inspected
single_input = "The cat sat"
src_ids = tokenizer.encode(single_input, add_bos=True, add_eos=True)
src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)

with torch.no_grad():
    # Output of the encoder for this sentence
    memory = model.encode(src_tensor)

    # Output of the source embedding module alone, for comparison
    embeddings = model.src_embedding(src_tensor)

print("=" * 60, "VALIDATION: Model Internals", "=" * 60, sep="\n")

print(f"\nInput: '{single_input}'")
print(f"Token IDs: {src_ids}")

# Summary statistics of the embedding output
print(f"\nEmbedding output:")
print(f"  Shape: {embeddings.shape}")
print(f"  Mean: {embeddings.mean().item():.4f}")
print(f"  Std: {embeddings.std().item():.4f}")

# Same statistics for the encoder output
print(f"\nEncoder output (memory):")
print(f"  Shape: {memory.shape}")
print(f"  Mean: {memory.mean().item():.4f}")
print(f"  Std: {memory.std().item():.4f}")

# First 8 values for the first token of the first sequence, before and after
print(f"\nFirst token representation (first 8 dims):")
print(f"  Embedding: {embeddings[0][0][:8].tolist()}")
print(f"  After encoder: {memory[0][0][:8].tolist()}")

In [None]:
# Distribution over the vocabulary for the first generated token: feed only
# BOS to the decoder together with the encoder output from the previous cell.
with torch.no_grad():
    tgt_start = torch.tensor([[tokenizer.bos_id]], device=device)
    tgt_mask = model._create_tgt_mask(tgt_start)
    decoder_output = model.decode(tgt_start, memory, tgt_mask)
    # Project the last decoder position onto the vocabulary
    logits = model.output_projection(decoder_output[:, -1, :])
    probs = F.softmax(logits, dim=-1)

print("=" * 60, "VALIDATION: Output Distribution (First Step)", "=" * 60, sep="\n")

print(f"\nLogits shape: {logits.shape}")
print(f"Probability distribution:")
print(f"  Min prob: {probs.min().item():.6f}")
print(f"  Max prob: {probs.max().item():.6f}")
print(f"  Sum: {probs.sum().item():.4f} (should be 1.0)")

# Show top 5 predictions, most probable first
top_probs, top_ids = probs.topk(5, dim=-1)
print(f"\nTop 5 predictions for first output token:")
for rank, (tid, prob) in enumerate(zip(top_ids[0].tolist(), top_probs[0].tolist()), start=1):
    token = tokenizer.decode([tid])
    print(f"  {rank}. ID={tid:3d} ('{token}') prob={prob:.4f}")

print(f"\n✓ Model internals examination complete!")

## Summary

In this tutorial, you learned how to:

1. **Load a trained model** - Restore from checkpoint
2. **Prepare input** - Tokenize text for the model
3. **Run greedy decoding** - Generate translations step by step
4. **Use built-in generate()** - Simplified inference API
5. **Batch inference** - Translate multiple sentences efficiently
6. **Examine internals** - Understand model representations

Each step was validated to ensure correctness.

### Notes for Production:
- Use beam search instead of greedy decoding for better quality
- Implement length normalization for beam search
- Add temperature/top-k/top-p sampling for diversity
- Cache encoder outputs for repeated decoding