# Inference Demo Tutorial

## Introduction

This tutorial walks through inference workflows for the Mini Transformer model, from greedy text generation to sampling strategies, batching, and beam search. The model here is randomly initialized rather than loaded from a checkpoint, so the generated text is not meaningful; the focus is on the mechanics of each technique.

### What You'll Learn
- Setting up inference environments
- Configuring and initializing the model
- Implementing text generation
- Using different sampling strategies
- Optimizing inference performance
- Evaluating model outputs

In [None]:
# Import required libraries
import torch
import torch.nn.functional as F
import sys
from pathlib import Path
import time
import numpy as np

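# Add project root to path (assumes the notebook runs from a subdirectory of the project root).
# Path('.').parent is still '.', so resolve to an absolute path before taking the parent.
sys.path.append(str(Path('.').resolve().parent))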

# Import our model implementations
from src.model.mini_transformer import MiniTransformer, MiniTransformerConfig
from src.inference.run_inference import TextGenerator

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")

## 1. Inference Environment Setup

Before inference, we need to set up the appropriate environment and configure CUDA optimizations for better performance.

In [None]:
def setup_inference_environment():
    """Setup inference environment with CUDA optimizations"""
    # Enable cuDNN benchmarking for better performance
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    
    # Enable TensorFloat-32 on Ampere or newer GPUs (faster matmuls at slightly reduced precision)
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(f"Inference environment setup completed")
    print(f"Device: {device}")
    print(f"cuDNN enabled: {torch.backends.cudnn.enabled}")
    print(f"cuDNN benchmark: {torch.backends.cudnn.benchmark}")
    if torch.cuda.is_available():
        print(f"TensorFloat-32 enabled: {torch.backends.cuda.matmul.allow_tf32}")
    
    return device

# Setup environment
device = setup_inference_environment()

## 2. Model Configuration and Initialization

Let's configure and initialize our Mini Transformer model for inference.

In [None]:
# Create model configuration
config = MiniTransformerConfig(
    vocab_size=10000,
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=4,
    intermediate_size=512,
    max_position_embeddings=128,
    dropout_prob=0.1,
    use_cuda=torch.cuda.is_available(),
    use_cudnn=True
)

# Create model
model = MiniTransformer(config)
model.to(device)

# Set model to evaluation mode
model.eval()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())

print("Model configuration:")
print(f"  Hidden size: {config.hidden_size}")
print(f"  Attention heads: {config.num_attention_heads}")
print(f"  Hidden layers: {config.num_hidden_layers}")
print("\nModel created successfully")
print(f"  Device: {device}")
print(f"  Total parameters: {total_params:,}")
print(f"  Model in evaluation mode: {not model.training}")

## 3. Text Generation Basics

Let's start with basic text generation using greedy decoding.

In [None]:
def greedy_decode(model, input_ids, max_length=50):
    """Generate text using greedy decoding"""
    model.eval()
    generated = input_ids.clone()
    
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(generated)
            logits = outputs["logits"]
            
            # Get the most likely next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
    
    return generated

# Create sample input
prompt = "The future of artificial intelligence"
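# Toy character-level "tokenizer": ord() is stable across runs, unlike hash(),
# which is salted per process for strings
input_tokens = [ord(c) % config.vocab_size for c in prompt]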
input_ids = torch.tensor([input_tokens], dtype=torch.long, device=device)

print(f"Input prompt: {prompt}")
print(f"Input tokens: {input_tokens[:10]}...")
print(f"Input shape: {input_ids.shape}")

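# Generate text
start_time = time.time()
generated_ids = greedy_decode(model, input_ids, max_length=30)
if device.type == "cuda":
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the timer
generation_time = time.time() - start_time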

print(f"\nGeneration completed in {generation_time:.4f}s")
print(f"Generated sequence length: {generated_ids.shape[1]}")

# Convert back to text (simplified representation)
generated_tokens = generated_ids[0].cpu().tolist()
generated_text = ''.join([chr(token % 128) for token in generated_tokens[len(input_tokens):]])
print(f"Generated text: {prompt + generated_text}")

## 4. Using the TextGenerator Class

Let's use the TextGenerator class from our inference module for more advanced generation.

In [None]:
# Create TextGenerator instance
generator = TextGenerator(model, device=device)

print(f"TextGenerator created successfully")
print(f"Model device: {generator.device}")
# Identity check: a compiled wrapper would be a different object from the original model
print(f"Model compiled: {hasattr(torch, 'compile') and generator.model is not model}")

## 5. Sampling Strategies

Let's explore different sampling strategies for more diverse text generation.
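
The three knobs used in the cells below can be sketched directly on a logits vector. This is a minimal illustration, not the `TextGenerator` implementation, whose internals are not shown here. `filter_logits` is a hypothetical helper name, and the sketch assumes that `top_k=0` and `top_p=1.0` mean "disabled", matching how the calls below use them.

```python
import torch
import torch.nn.functional as F

def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Hypothetical helper: apply temperature, then top-k, then top-p to 1-D logits."""
    logits = logits / temperature  # <1 sharpens the distribution, >1 flattens it
    if top_k > 0:
        # Keep only the k largest logits (ties with the k-th value are kept too)
        kth_value = torch.topk(logits, min(top_k, logits.numel())).values[-1]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass that comes *before* each token in sorted order
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # Drop a token once the mass before it already exceeds top_p,
        # so the most likely token is always kept
        remove_sorted = mass_before > top_p
        remove = torch.zeros_like(remove_sorted)
        remove[sorted_idx] = remove_sorted  # map back to vocabulary order
        logits = logits.masked_fill(remove, float("-inf"))
    return logits

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
probs = F.softmax(filter_logits(logits, top_k=2), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample from the filtered distribution
```

Filtered-out tokens receive a logit of negative infinity, so they get exactly zero probability after the softmax and can never be sampled.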

In [None]:
# Temperature sampling
prompt = "Machine learning is"

print(f"Prompt: {prompt}")
print("\nTemperature Sampling:")

temperatures = [0.5, 1.0, 1.5]
for temp in temperatures:
    start_time = time.time()
    generated_text = generator.generate_text(
        prompt, 
        max_length=30, 
        temperature=temp, 
        do_sample=True
    )
    generation_time = time.time() - start_time
    
    print(f"  Temperature {temp}: {generated_text[len(prompt):50]}... ({generation_time:.3f}s)")

In [None]:
# Top-k and Top-p sampling
prompt = "The advancement of"

print(f"\nPrompt: {prompt}")
print("\nTop-k and Top-p Sampling:")

# Top-k sampling
start_time = time.time()
generated_k = generator.generate_text(
    prompt,
    max_length=30,
    temperature=1.0,
    do_sample=True,
    top_k=50,
    top_p=1.0
)
time_k = time.time() - start_time

print(f"  Top-k (k=50): {generated_k[len(prompt):50]}... ({time_k:.3f}s)")

# Top-p (nucleus) sampling
start_time = time.time()
generated_p = generator.generate_text(
    prompt,
    max_length=30,
    temperature=1.0,
    do_sample=True,
    top_k=0,  # Disable top-k
    top_p=0.9
)
time_p = time.time() - start_time

print(f"  Top-p (p=0.9): {generated_p[len(prompt):50]}... ({time_p:.3f}s)")

# Combined top-k and top-p
start_time = time.time()
generated_kp = generator.generate_text(
    prompt,
    max_length=30,
    temperature=1.0,
    do_sample=True,
    top_k=50,
    top_p=0.9
)
time_kp = time.time() - start_time

print(f"  Top-k + Top-p: {generated_kp[len(prompt):50]}... ({time_kp:.3f}s)")

## 6. Next Token Predictions

Let's explore how to get predictions for the next token in a sequence.
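
Conceptually, a next-token prediction is a softmax over the logits at the last position followed by a top-k selection. The sketch below shows that step on toy logits. `top_next_tokens` is a hypothetical helper, not the `get_next_token_predictions` method itself, and it assumes logits of shape `(1, seq_len, vocab_size)`; the returned keys mirror the ones the cell below reads.

```python
import torch
import torch.nn.functional as F

def top_next_tokens(logits, top_k=5):
    """Hypothetical helper: top-k next-token candidates from (1, seq_len, vocab) logits."""
    probs = F.softmax(logits[0, -1, :], dim=-1)  # distribution at the last position
    top_probs, top_ids = torch.topk(probs, top_k)  # sorted from most to least likely
    return {"tokens": top_ids.tolist(), "probabilities": top_probs.tolist()}

# Toy logits standing in for a model output: batch of 1, 3 positions, vocabulary of 6
toy_logits = torch.tensor([[[0.0] * 6, [0.0] * 6, [0.1, 3.0, 0.2, 2.0, 0.0, 1.0]]])
preds = top_next_tokens(toy_logits, top_k=3)
```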

In [None]:
def show_next_token_predictions(generator, text, top_k=5):
    """Show next token predictions with probabilities"""
    predictions = generator.get_next_token_predictions(text, top_k=top_k)
    
    print(f"Text: \"{text}\"")
    print("Next token predictions:")
    
    for i, (token, prob) in enumerate(zip(predictions["tokens"], predictions["probabilities"])):
        # Convert token to character (simplified)
        char = chr(token % 128)
        print(f"  {i+1:2d}. Token {token:5d} ('{char}') - Probability: {prob:.4f}")

# Show predictions
show_next_token_predictions(generator, "Artificial intelligence", top_k=10)

## 7. Performance Benchmarking

Let's benchmark the inference performance of our model.
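
CUDA kernels are launched asynchronously, so a wall-clock timer can stop before the GPU has finished unless the code synchronizes first. A minimal timing sketch under that assumption follows; `time_call` is a hypothetical helper and is not part of this project's code.

```python
import time
import torch

def time_call(fn, iterations=10, warmup=2):
    """Hypothetical helper: mean wall-clock seconds per call of fn()."""
    for _ in range(warmup):
        fn()  # warm-up calls absorb one-time costs such as kernel selection
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the timed GPU work to finish
    return (time.perf_counter() - start) / iterations

x = torch.randn(64, 64)
seconds = time_call(lambda: x @ x)
print(f"{seconds * 1e6:.1f} microseconds per matmul")
```

Functions that return Python strings or call `.item()` or `.cpu()` already force a synchronization, so the explicit calls matter most when the timed function returns tensors that stay on the GPU.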

In [None]:
def benchmark_inference(model, generator, device, iterations=100):
    """Benchmark inference performance"""
    prompt = "The quick brown fox jumps over the lazy dog. "
    
    # Warmup
    for _ in range(10):
        _ = generator.generate_text(prompt, max_length=10, do_sample=False)
    
    # Benchmark greedy decoding
    start_time = time.time()
    for _ in range(iterations):
        _ = generator.generate_text(prompt, max_length=20, do_sample=False)
    greedy_time = time.time() - start_time
    
    # Benchmark sampling
    start_time = time.time()
    for _ in range(iterations):
        _ = generator.generate_text(prompt, max_length=20, do_sample=True, temperature=1.0)
    sampling_time = time.time() - start_time
    
    # Benchmark next token predictions
    start_time = time.time()
    for _ in range(iterations):
        _ = generator.get_next_token_predictions(prompt, top_k=10)
    prediction_time = time.time() - start_time
    
    avg_greedy = greedy_time / iterations
    avg_sampling = sampling_time / iterations
    avg_prediction = prediction_time / iterations
    
    return avg_greedy, avg_sampling, avg_prediction

# Run benchmark
greedy_time, sampling_time, prediction_time = benchmark_inference(model, generator, device, iterations=50)

print("Inference Performance Benchmark:")
print(f"  Greedy decoding: {greedy_time*1000:.2f} ms per generation")
print(f"  Sampling: {sampling_time*1000:.2f} ms per generation")
print(f"  Next token prediction: {prediction_time*1000:.2f} ms per prediction")

if torch.cuda.is_available():
    print(f"  Device: {torch.cuda.get_device_name()}")
    print(f"  Memory allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB")

## 8. Batch Inference

Let's explore batch inference for processing multiple prompts simultaneously.
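
When prompts of different lengths share a batch, a decoding loop that reads the logits at the last position needs every prompt to end at that position, which left padding provides. The sketch below builds a left-padded batch together with an attention mask. `left_pad` is a hypothetical helper, and whether `MiniTransformer` accepts an attention mask is an assumption that is not confirmed by this tutorial.

```python
import torch

def left_pad(sequences, pad_id=0):
    """Hypothetical helper: left-pad token lists; return (input_ids, attention_mask)."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for row, seq in enumerate(sequences):
        start = max_len - len(seq)  # real tokens occupy the right-hand side
        input_ids[row, start:] = torch.tensor(seq, dtype=torch.long)
        attention_mask[row, start:] = 1  # 1 marks real tokens, 0 marks padding
    return input_ids, attention_mask

ids, mask = left_pad([[5, 6, 7], [8]])
```

A model that supports masking can use `attention_mask` to ignore the pad positions; without a mask, pad tokens are attended to like any other token and can influence the output.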

In [None]:
def batch_generate(model, prompts, max_length=20):
    """Generate text for multiple prompts in batch"""
    model.eval()
    
    # Tokenize all prompts (ord() is stable across runs, unlike the per-process salted hash())
    batch_input_ids = []
    prompt_lengths = []
    
    for prompt in prompts:
        input_tokens = [ord(c) % 10000 for c in prompt]  # 10000 matches config.vocab_size
        batch_input_ids.append(torch.tensor(input_tokens, dtype=torch.long))
        prompt_lengths.append(len(input_tokens))
    
    # Left-pad to the same length so every prompt ends at the last position,
    # which is where the next-token logits are read
    max_prompt_length = max(prompt_lengths)
    padded_input_ids = []
    
    for input_ids in batch_input_ids:
        padding = torch.zeros(max_prompt_length - len(input_ids), dtype=torch.long)
        padded_ids = torch.cat([padding, input_ids])
        padded_input_ids.append(padded_ids)
    
    # Stack into batch on the model's device
    # Note: no attention mask is passed, so pad tokens are still attended to
    batch_input = torch.stack(padded_input_ids).to(next(model.parameters()).device)
    
    # Generate
    with torch.no_grad():
        generated = batch_input.clone()
        
        for _ in range(max_length):
            outputs = model(generated)
            logits = outputs["logits"]
            
            # Get next tokens for all sequences
            next_tokens = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat([generated, next_tokens], dim=1)
    
    # Convert back to text
    results = []
    for i, prompt in enumerate(prompts):
        # Generated tokens start right after the padded prompt
        new_tokens = generated[i, max_prompt_length:].cpu().tolist()
        generated_text = ''.join([chr(token % 128) for token in new_tokens])
        results.append(prompt + generated_text)
    
    return results

# Test batch inference
prompts = [
    "The future of AI",
    "Machine learning",
    "Deep learning models",
    "Natural language processing"
]

print("Batch Inference Example:")
print(f"Processing {len(prompts)} prompts simultaneously")

start_time = time.time()
batch_results = batch_generate(model, prompts, max_length=15)
batch_time = time.time() - start_time

print(f"\nBatch processing completed in {batch_time:.4f}s")
for i, result in enumerate(batch_results):
    print(f"  {i+1}. \"{result[:50]}...\"")

# Compare with sequential processing
start_time = time.time()
sequential_results = []
for prompt in prompts:
    result = generator.generate_text(prompt, max_length=15, do_sample=False)
    sequential_results.append(result)
sequential_time = time.time() - start_time

print(f"\nSequential processing completed in {sequential_time:.4f}s")
print(f"Speedup: {sequential_time/batch_time:.2f}x")

## 9. Memory Optimization Techniques

Let's explore memory optimization techniques for efficient inference.
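
The main memory saving at inference time comes from not recording the autograd graph, so intermediate activations do not have to be kept for a backward pass. The sketch below shows the difference on a small stand-in layer; it is an illustration rather than part of the demonstration cell, and `torch.inference_mode` is included as a stricter alternative that assumes PyTorch 1.9 or newer.

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)  # stand-in for a model
x = torch.randn(2, 8)

y_train = layer(x)  # graph recorded: activations are kept for a backward pass

with torch.no_grad():
    y_eval = layer(x)  # no graph: intermediates can be freed immediately

with torch.inference_mode():
    y_fast = layer(x)  # like no_grad, and also skips autograd bookkeeping on the result

print(y_train.requires_grad, y_eval.requires_grad, y_fast.is_inference())
```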

In [None]:
def demonstrate_memory_optimization():
    """Demonstrate memory optimization techniques"""
    if not torch.cuda.is_available():
        print("Memory optimization demonstration requires CUDA availability")
        return
    
    print("Memory Optimization Techniques:")
    
    # Clear cache
    torch.cuda.empty_cache()
    initial_memory = torch.cuda.memory_allocated()
    
    print(f"  Initial memory: {initial_memory / 1e6:.1f} MB")
    
    # Create a large tensor to demonstrate memory usage
    large_tensor = torch.randn(1000, 1000, device='cuda')
    after_allocation = torch.cuda.memory_allocated()
    
    print(f"  After large tensor: {after_allocation / 1e6:.1f} MB")
    
    # Delete tensor
    del large_tensor
    torch.cuda.empty_cache()
    after_cleanup = torch.cuda.memory_allocated()
    
    print(f"  After cleanup: {after_cleanup / 1e6:.1f} MB")
    
    # Demonstrate torch.no_grad context using peak memory, since activations
    # are freed again once backward() has finished
    print("\n  Gradient computation memory impact (peak):")
    
    input_ids = torch.randint(0, 10000, (4, 32), device='cuda')
    
    # With gradients
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    before_grad = torch.cuda.memory_allocated()
    
    outputs = model(input_ids)
    loss = outputs["logits"].sum()
    loss.backward()
    
    with_grad_peak = torch.cuda.max_memory_allocated()
    print(f"    With gradients: {(with_grad_peak - before_grad) / 1e6:.1f} MB")
    
    # Free gradients and outputs so they do not distort the next measurement
    model.zero_grad(set_to_none=True)
    del outputs, loss
    
    # Without gradients
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    before_no_grad = torch.cuda.memory_allocated()
    
    with torch.no_grad():
        outputs = model(input_ids)
    
    without_grad_peak = torch.cuda.max_memory_allocated()
    print(f"    Without gradients: {(without_grad_peak - before_no_grad) / 1e6:.1f} MB")
    
    # Cleanup
    del outputs
    torch.cuda.empty_cache()

demonstrate_memory_optimization()

## 10. Model Compilation for Performance

Let's explore model compilation for improved inference performance.

In [None]:
def demonstrate_model_compilation():
    """Demonstrate model compilation for performance improvement"""
    if not hasattr(torch, 'compile'):
        print("Model compilation requires PyTorch 2.0+")
        return
    
    print("Model Compilation:")
    
    # Create a new model for compilation demo
    config_small = MiniTransformerConfig(
        vocab_size=1000,
        hidden_size=128,
        num_attention_heads=2,
        num_hidden_layers=2,
        intermediate_size=256,
        max_position_embeddings=64
    )
    
    model_small = MiniTransformer(config_small).to(device)
    model_small.eval()  # disable dropout so both runs do the same work
    
    # Create sample input
    input_ids = torch.randint(0, 1000, (2, 16), device=device)
    
    def time_forward(module, iterations=20):
        """Average seconds per forward pass, synchronizing around the timed loop"""
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_time = time.time()
        with torch.no_grad():
            for _ in range(iterations):
                _ = module(input_ids)
        if device.type == "cuda":
            torch.cuda.synchronize()
        return (time.time() - start_time) / iterations
    
    # Warm up, then benchmark the uncompiled model
    time_forward(model_small, iterations=5)
    uncompiled_time = time_forward(model_small)
    
    print(f"  Uncompiled model: {uncompiled_time*1000:.2f} ms per forward pass")
    
    # Compile model
    try:
        compiled_model = torch.compile(model_small)
        
        # Warm up the compiled model (the first calls trigger compilation)
        time_forward(compiled_model, iterations=5)
        
        # Benchmark compiled model
        compiled_time = time_forward(compiled_model)
        
        print(f"  Compiled model: {compiled_time*1000:.2f} ms per forward pass")
        print(f"  Speedup: {uncompiled_time/compiled_time:.2f}x")
        
    except Exception as e:
        print(f"  Compilation failed: {e}")

if torch.cuda.is_available() and hasattr(torch, 'compile'):
    demonstrate_model_compilation()
else:
    print("Model compilation demonstration requires CUDA and PyTorch 2.0+")

## 11. Advanced Generation Techniques

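Let's compare greedy decoding with a simple beam search, which keeps the `num_beams` highest-scoring partial sequences at each step instead of only the single best one.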

In [None]:
def beam_search_decode(model, input_ids, max_length=20, num_beams=3):
    """Simple beam search implementation (assumes batch size 1)"""
    model.eval()
    
    # Initialize beams
    beams = [(input_ids.clone(), 0.0)]  # (sequence, cumulative_log_prob)
    
    with torch.no_grad():
        for _ in range(max_length):
            candidates = []
            
            for sequence, score in beams:
                outputs = model(sequence)
                logits = outputs["logits"]
                
                # Get the top log-probabilities for this beam
                log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
                top_log_probs, top_indices = torch.topk(log_probs, num_beams, dim=-1)
                
                # Add candidates
                for i in range(num_beams):
                    new_sequence = torch.cat([sequence, top_indices[:, i:i+1]], dim=1)
                    new_score = score + top_log_probs[:, i].item()
                    candidates.append((new_sequence, new_score))
            
            # Select top beams
            candidates.sort(key=lambda x: x[1], reverse=True)
            beams = candidates[:num_beams]
    
    # Return the best sequence
    return beams[0][0]

# Test beam search
prompt = "The impact of"
input_tokens = [hash(c) % config.vocab_size for c in prompt]
input_ids = torch.tensor([input_tokens], dtype=torch.long, device=device)

print(f"Advanced Generation Techniques:")
print(f"Prompt: \"{prompt}\"")

# Greedy decoding
start_time = time.time()
greedy_result = generator.generate_text(prompt, max_length=20, do_sample=False)
greedy_time = time.time() - start_time

print(f"\nGreedy decoding: \"{greedy_result[len(prompt):50]}...\" ({greedy_time:.3f}s)")

# Beam search
start_time = time.time()
beam_result_ids = beam_search_decode(model, input_ids, max_length=20, num_beams=3)
beam_time = time.time() - start_time

# Convert beam result to text
beam_tokens = beam_result_ids[0].cpu().tolist()
beam_text = ''.join([chr(token % 128) for token in beam_tokens[len(input_tokens):]])
print(f"Beam search (k=3): \"{prompt + beam_text[:50]}...\" ({beam_time:.3f}s)")

## 12. Model Analysis

Let's analyze some properties of our model during inference.
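
A useful reference point for the entropy values reported below is the maximum possible entropy, which a uniform distribution over the vocabulary reaches at ln(vocab_size) nats. The sketch below checks this on synthetic logits. `token_entropy` is a hypothetical helper that uses `log_softmax` for numerical stability; it is not the function used in the analysis cell.

```python
import math
import torch
import torch.nn.functional as F

def token_entropy(logits):
    """Entropy in nats of the softmax distribution over the last dimension."""
    log_probs = F.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)

vocab_size = 10000  # matches the vocab_size used in this tutorial's config
uniform = token_entropy(torch.zeros(vocab_size))        # all tokens equally likely
peaked = token_entropy(torch.tensor([50.0, 0.0, 0.0]))  # one token dominates
print(f"uniform: {uniform.item():.4f} nats, ln(V): {math.log(vocab_size):.4f} nats")
print(f"peaked: {peaked.item():.6f} nats")
```

An average entropy close to ln(vocab_size) suggests the model is near-uniform over tokens, while values near zero indicate a confident, sharply peaked prediction.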

In [None]:
def analyze_model_outputs(model, input_ids):
    """Analyze model outputs during inference"""
    model.eval()
    
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs["logits"]
        hidden_states = outputs["hidden_states"]
    
    print(f"Model Output Analysis:")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Hidden states shape: {hidden_states.shape}")
    
    # Analyze logits
    print(f"\nLogits Analysis:")
    print(f"  Mean: {logits.mean().item():.4f}")
    print(f"  Std: {logits.std().item():.4f}")
    print(f"  Min: {logits.min().item():.4f}")
    print(f"  Max: {logits.max().item():.4f}")
    
    # Analyze entropy (measure of uncertainty)
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)
    
    print(f"\nUncertainty Analysis:")
    print(f"  Average entropy: {entropy.mean().item():.4f}")
    print(f"  Min entropy: {entropy.min().item():.4f}")
    print(f"  Max entropy: {entropy.max().item():.4f}")
    
    # Analyze hidden states
    print(f"\nHidden States Analysis:")
    print(f"  Mean activation: {hidden_states.mean().item():.4f}")
    print(f"  Std activation: {hidden_states.std().item():.4f}")
    
    return outputs

# Analyze model outputs
sample_input = torch.randint(0, config.vocab_size, (1, 16), device=device)
outputs = analyze_model_outputs(model, sample_input)

## Summary

In this tutorial, we've demonstrated comprehensive inference workflows for the Mini Transformer:

- **Environment Setup**: Configuring CUDA optimizations for better performance
- **Model Configuration**: Setting up model hyperparameters for inference
- **Basic Generation**: Implementing greedy decoding for deterministic output
- **Sampling Strategies**: Using temperature, top-k, and top-p sampling for diverse outputs
- **Performance Benchmarking**: Measuring per-call latency for greedy decoding, sampling, and next-token prediction
- **Batch Inference**: Processing multiple prompts simultaneously for efficiency
- **Memory Optimization**: Techniques for efficient memory usage
- **Model Compilation**: Using PyTorch 2.0+ compilation for performance gains
- **Advanced Techniques**: Beam search compared against greedy decoding
- **Model Analysis**: Understanding model behavior during inference

This inference demo provides a foundation for understanding how to deploy Transformer models for text generation tasks. The techniques demonstrated here can be scaled up for larger models and adapted for different architectures. For production deployment, consider using the FastAPI serving script, which provides additional optimizations for serving models.