# PyTorch Debugging & Common Gotchas: Avoiding Training Disasters

## 🎯 Introduction

Welcome to your PyTorch debugging survival guide! This notebook will arm you with the knowledge to identify, diagnose, and fix the most common PyTorch mistakes that can derail your training. Every deep learning practitioner faces these issues - the difference is knowing how to solve them quickly.

### 🧠 What You'll Learn to Debug

This essential guide covers:
- **Shape mismatches**: The #1 source of PyTorch errors and how to fix them
- **Training vs eval modes**: Why your model behaves differently during inference
- **Memory issues**: GPU out-of-memory errors and optimization strategies  
- **Gradient problems**: Vanishing, exploding, and missing gradients
- **Common training failures**: Why your loss isn't decreasing and how to fix it

### 🎓 Prerequisites

- Experience with PyTorch tensors, modules, and basic training loops
- Familiarity with common neural network architectures
- Basic understanding of backpropagation and gradient descent

### 🚀 Why Debugging Skills Matter

Effective debugging enables:
- **Faster development**: Spend time building, not hunting bugs
- **Reliable training**: Catch issues before they waste compute time
- **Better models**: Identify and fix performance bottlenecks
- **Confidence**: Know your models are working as intended
- **Scalability**: Debug issues that only appear at scale

---

## 📚 Table of Contents

1. **[Shape Debugging Mastery](#shape-debugging-mastery)** - Conquering tensor dimension mismatches
2. **[Autograd Gotchas](#autograd-gotchas)** - Zeroing gradients, in-place operations, and gradient accumulation
3. **[Data Type and Shape Issues](#data-type-and-shape-issues)** - Target dtypes, squeeze, broadcasting, and batch_first
4. **[Memory and Device Problems](#memory-and-device-problems)** - Device mismatches, memory leaks, and slow tensor code
5. **[Debugging Techniques](#debugging-techniques)** - Shape printing, gradient checks, NaN/Inf detection, and hooks

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import warnings

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## Shape Debugging Mastery

### 🔍 The #1 Source of PyTorch Errors

Shape mismatches cause more PyTorch frustration than any other issue. Let's learn to diagnose and fix them systematically, plus develop techniques to prevent them in the first place!

In [None]:
# =============================================================================
# SHAPE DEBUGGING SURVIVAL GUIDE
# =============================================================================

print("🔍 Shape Debugging Mastery")
print("=" * 50)

# Common shape error scenarios and how to debug them
print("Scenario 1: Matrix Multiplication Dimension Mismatch")
print("-" * 30)

try:
    # This will fail - inner dimensions don't match
    A = torch.randn(3, 4)  # [3, 4]
    B = torch.randn(2, 5)  # [2, 5] - can't multiply with A
    
    print(f"Tensor A shape: {A.shape}")
    print(f"Tensor B shape: {B.shape}")
    print("Attempting A @ B...")
    
    result = A @ B  # This will raise an error
    
except RuntimeError as e:
    print(f"❌ Error: {e}")
    print(f"\n🔧 Debugging approach:")
    print(f"1. Check dimensions: A is {A.shape}, B is {B.shape}")
    print(f"2. For A @ B, need A[..., n] and B[n, ...]")
    print(f"3. A has last dim {A.shape[-1]}, B has first dim {B.shape[0]}")
    print(f"4. {A.shape[-1]} ≠ {B.shape[0]}, so multiplication fails")
    
    print(f"\n✅ Fix: Give B a first dimension that matches A's last dimension")
    B_fixed = torch.randn(4, 5)  # [4, 5] - now compatible with A
    result = A @ B_fixed  # [3, 4] @ [4, 5] = [3, 5]
    print(f"Fixed: A {A.shape} @ B_fixed {B_fixed.shape} = result {result.shape}")

print(f"\nScenario 2: Broadcasting Confusion")
print("-" * 30)

# Broadcasting can be tricky to debug
A = torch.randn(3, 1, 4)    # [3, 1, 4]
B = torch.randn(2, 4)       # [2, 4]

print(f"Tensor A shape: {A.shape}")
print(f"Tensor B shape: {B.shape}")

try:
    result = A + B
    print(f"✅ Success: A + B = {result.shape}")
    print(f"Broadcasting rule: dimensions aligned from right:")
    print(f"  A: [3, 1, 4]")
    print(f"  B:    [2, 4]")
    print(f"  →:  [3, 2, 4] (B gains a leading dim of size 3, A's dim 1 expands from 1 to 2)")
    
except RuntimeError as e:
    print(f"❌ Error: {e}")

# Show a case that fails
print(f"\nBroadcasting failure example:")
A = torch.randn(3, 4)       # [3, 4] 
B = torch.randn(3, 5)       # [3, 5]

print(f"Tensor A shape: {A.shape}")
print(f"Tensor B shape: {B.shape}")

try:
    result = A + B
except RuntimeError as e:
    print(f"❌ Error: {e}")
    print(f"🔧 Problem: Last dimensions must match or be 1 for broadcasting")
    print(f"  A last dim: {A.shape[-1]}")
    print(f"  B last dim: {B.shape[-1]}")
    print(f"  Neither is 1, and {A.shape[-1]} ≠ {B.shape[-1]}")

print(f"\nScenario 3: Neural Network Layer Mismatches")
print("-" * 30)

# Common neural network shape errors
class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)   # Expects input: [..., 10]
        self.fc2 = nn.Linear(5, 1)    # Expects input: [..., 5]
    
    def forward(self, x):
        print(f"  Input shape: {x.shape}")
        x = self.fc1(x)
        print(f"  After fc1: {x.shape}")
        x = self.fc2(x)
        print(f"  After fc2: {x.shape}")
        return x

model = DebugModel()

# Correct usage
print("✅ Correct usage:")
correct_input = torch.randn(3, 10)  # Batch size 3, 10 features
output = model(correct_input)

# Common mistake: wrong feature dimension
print(f"\n❌ Wrong feature dimension:")
try:
    wrong_input = torch.randn(3, 8)  # Wrong: 8 features instead of 10
    print(f"Input shape: {wrong_input.shape}")
    output = model(wrong_input)
except RuntimeError as e:
    print(f"Error: {e}")
    print(f"🔧 Fix: Ensure input features match layer expectation")
    print(f"  Model expects: [..., 10]")
    print(f"  You provided: [..., 8]")

print(f"\n🛠️ Essential Shape Debugging Tools")
print("=" * 50)

def debug_shapes(*tensors, names=None):
    """Utility function to debug tensor shapes."""
    if names is None:
        names = [f"tensor_{i}" for i in range(len(tensors))]
    
    print("Shape Debug Report:")
    print("-" * 20)
    for name, tensor in zip(names, tensors):
        print(f"{name:12}: {tensor.shape} | dtype: {tensor.dtype} | device: {tensor.device}")
    
    # Check for common operations
    if len(tensors) == 2:
        A, B = tensors
        print(f"\nCompatibility Check:")
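        # Matmul contracts A's last dim with B's second-to-last dim (B's only dim if B is 1-D);
        # any leading batch dims must also broadcast, which this quick check does not verify
        if A.ndim >= 1 and B.ndim >= 1:
            contract_dim = B.shape[-2] if B.ndim >= 2 else B.shape[0]
            print(f"A @ B possible: {A.shape[-1] == contract_dim}")
        else:
            print("A @ B possible: False")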
        
        # Broadcasting check
        try:
            broadcast_shape = torch.broadcast_shapes(A.shape, B.shape)
            print(f"A + B broadcasts to: {broadcast_shape}")
        except RuntimeError:
            print(f"A + B: Broadcasting not possible")

# Demonstrate the debugging utility
print("Using debug utility:")
A = torch.randn(2, 3, 4)
B = torch.randn(4, 5)
debug_shapes(A, B, names=['A', 'B'])

print(f"\n💡 Pro Shape Debugging Tips")
print("=" * 50)
print("1. **Always print shapes**: Add print(f'Shape: {tensor.shape}') everywhere")
print("2. **Use .shape, not .size()**: .shape is clearer and more pythonic")
print("3. **Check device and dtype**: Mismatches cause subtle errors")
print("4. **Use named dimensions**: Comment what each dimension represents")
print("5. **Test with small tensors**: Debug with tiny shapes first")
print("6. **Use torch.einsum()**: More explicit than @ for complex operations")

print(f"\n🎯 Shape Pattern Recognition")
print("=" * 50)
print("Common PyTorch shape patterns:")
print("• Batch dimension first: [batch_size, ...]")
print("• Images: [batch, channels, height, width]")
print("• Sequences: [batch, seq_len, features]")
print("• Linear layers: [..., input_features] → [..., output_features]")
print("• Conv2d: [batch, in_channels, H, W] → [batch, out_channels, H', W']")
print("• Attention: [batch, seq_len, d_model] for all Q, K, V")
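
The tips above rely on printing shapes by hand. A small assertion helper can turn the same checks into errors that fail early. The sketch below is one possible way to do that; `assert_shape` is a hypothetical helper written for this example, not a PyTorch API, and `None` is used here as a wildcard for "any size".

```python
import torch

def assert_shape(tensor, expected, name="tensor"):
    """Raise a readable error if `tensor` does not match `expected`.

    `expected` is a tuple; use None for a dimension that may be any size.
    (Hypothetical helper for illustration, not a PyTorch API.)
    """
    actual = tuple(tensor.shape)
    ok = len(actual) == len(expected) and all(
        e is None or e == a for e, a in zip(expected, actual)
    )
    if not ok:
        raise ValueError(f"{name}: expected shape {expected}, got {actual}")
    return tensor

images = torch.randn(8, 3, 32, 32)              # [batch, channels, height, width]
assert_shape(images, (None, 3, 32, 32), "images")

# A flattened tensor no longer matches the image layout, so the check fails loudly
try:
    assert_shape(images.flatten(1), (None, 3, 32, 32), "flattened")
except ValueError as e:
    message = str(e)
    print(message)
```

Placing a call like this at the top of `forward()` documents the expected layout and catches mismatches before they reach a layer with a less helpful error message.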

In [None]:
# Demonstrate batch norm issues
print("\n=== Batch Norm Mode Issues ===")

# Create model with batch norm
model_bn = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

# ❌ GOTCHA: Using batch norm with very small batches
print("\nBatch norm with different batch sizes:")

model_bn.train()

# Large batch - works fine
large_batch = torch.randn(32, 10)
out_large = model_bn(large_batch)
print(f"Large batch (32): mean={out_large.mean().item():.4f}, std={out_large.std().item():.4f}")

# Small batch - can be unstable
small_batch = torch.randn(2, 10)  # Very small batch
out_small = model_bn(small_batch)
print(f"Small batch (2):  mean={out_small.mean().item():.4f}, std={out_small.std().item():.4f}")

# Single sample - will error!
try:
    single_sample = torch.randn(1, 10)
    out_single = model_bn(single_sample)
except Exception as e:
    print(f"Single sample error: {e}")

print("\n💡 Solution: Use LayerNorm for small batches, or call model.eval() so BatchNorm uses its running statistics for single-sample inference")

# Alternative: Use LayerNorm instead
model_ln = nn.Sequential(
    nn.Linear(10, 20),
    nn.LayerNorm(20),  # LayerNorm works with any batch size
    nn.ReLU(),
    nn.Linear(20, 1)
)

single_sample = torch.randn(1, 10)
out_ln = model_ln(single_sample)
print(f"LayerNorm with single sample: {out_ln.item():.4f}")
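
The single-sample failure is specific to training mode. As a minimal sketch (the model below is a fresh copy built only for this example), switching to `eval()` makes `BatchNorm1d` use its stored running statistics, so a batch of one works and the output is deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.BatchNorm1d(20), nn.ReLU(), nn.Linear(20, 1))
single = torch.randn(1, 10)

# Train mode: BatchNorm needs more than one value per channel to compute batch statistics
model.train()
try:
    model(single)
    train_failed = False
except (ValueError, RuntimeError) as e:
    train_failed = True
    print(f"train() with batch size 1: {e}")

# Eval mode: BatchNorm uses its stored running_mean / running_var instead
model.eval()
with torch.no_grad():
    out_a = model(single)
    out_b = model(single)

print(f"eval() output shape: {tuple(out_a.shape)}")
print(f"eval() is deterministic: {torch.equal(out_a, out_b)}")
```

Remember to call `model.train()` again before resuming training; the mode is a persistent flag on the module, not a per-call option.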

## Autograd Gotchas

In [None]:
print("=== Autograd Gotchas ===")

# ❌ GOTCHA 1: Forgetting to zero gradients
print("\n1. Forgetting to zero gradients")
print("-" * 40)

model = nn.Linear(5, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

x = torch.randn(10, 5)
y_true = torch.randn(10, 1)

print("Training without zero_grad():")
for epoch in range(3):
    y_pred = model(x)
    loss = criterion(y_pred, y_true)
    loss.backward()  # Gradients accumulate!
    
    # Check gradient magnitude
    grad_norm = model.weight.grad.norm().item()
    print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Grad norm = {grad_norm:.4f}")
    
    optimizer.step()  # ❌ No zero_grad()!

# ✅ CORRECT: Always zero gradients
print("\nTraining WITH zero_grad():")
model = nn.Linear(5, 1)  # Reset model
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(3):
    optimizer.zero_grad()  # ✅ Clear gradients first
    
    y_pred = model(x)
    loss = criterion(y_pred, y_true)
    loss.backward()
    
    grad_norm = model.weight.grad.norm().item()
    print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Grad norm = {grad_norm:.4f}")
    
    optimizer.step()

print("\n📝 Without zero_grad(), each backward() adds to the stored .grad, so the reported norm is a running sum, not the current step's gradient")
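
Because the two loops above start from different random weights, their gradient norms are not directly comparable. The accumulation itself is easier to see on a tiny hand-checkable example. This sketch uses a two-element parameter and the loss `sum(w**2)`, both chosen only for illustration:

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

def loss_fn(w):
    return (w ** 2).sum()        # d(loss)/dw = 2 * w

loss_fn(w).backward()
first = w.grad.clone()           # tensor([2., 4.])

loss_fn(w).backward()            # no zeroing: the new gradient is added to the old one
accumulated = w.grad.clone()     # tensor([4., 8.])

w.grad = None                    # same effect as optimizer.zero_grad(set_to_none=True)
loss_fn(w).backward()
fresh = w.grad.clone()           # tensor([2., 4.])

print(f"first: {first}, accumulated: {accumulated}, after reset: {fresh}")
```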

In [None]:
# ❌ GOTCHA 2: In-place operations breaking autograd
print("\n2. In-place operations breaking autograd")
print("-" * 40)

# Writing in place to a leaf tensor that requires grad raises immediately
try:
    x = torch.randn(5, requires_grad=True)
    y = x * 2
    x[0] = 0  # ❌ In-place modification of a leaf: RuntimeError is raised here
    loss = y.sum()
    loss.backward()  # Never reached
except RuntimeError as e:
    print(f"In-place operation error: {e}")

# ✅ CORRECT: Use non-in-place operations
x = torch.randn(5, requires_grad=True)
y = x * 2

# Option 1: Clone and modify (shown for comparison; the loss below uses Option 2)
x_modified = x.clone()
x_modified[0] = 0
y_modified = x_modified * 2

# Option 2: Create mask for modification
mask = torch.ones_like(x)
mask[0] = 0
y_masked = y * mask

loss = y_masked.sum()
loss.backward()
print(f"Gradients computed successfully: {x.grad is not None}")
print(f"Gradient: {x.grad}")

print("\n💡 Avoid in-place operations on leaf tensors that require grad and on tensors autograd has saved for backward")
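
In-place writes can also fail later, at `backward()`, when they touch an intermediate tensor that autograd saved for the backward pass. The sketch below uses `exp()` as the example because its backward pass reuses its own output; the tensor values are arbitrary.

```python
import torch

x = torch.ones(3, requires_grad=True)

# exp() saves its output for the backward pass; modifying that output in place invalidates it
y = x.exp()
y.add_(1)
try:
    y.sum().backward()
    inplace_failed = False
except RuntimeError as e:
    inplace_failed = True
    print(f"backward() failed: {str(e)[:80]}...")

# Out-of-place version: y is left untouched, so backward works
y = x.exp()
z = y + 1
z.sum().backward()
print(f"x.grad = {x.grad}")      # d/dx (exp(x) + 1) = exp(x)
```

When the source of such an error is hard to find, wrapping the forward and backward pass in `torch.autograd.set_detect_anomaly(True)` makes the error message point at the forward operation involved, at the cost of slower execution.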

In [None]:
# ❌ GOTCHA 3: Gradient accumulation confusion
print("\n3. Gradient accumulation behavior")
print("-" * 40)

# Sometimes gradient accumulation is intentional (for large effective batch sizes)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Simulate gradient accumulation over 4 mini-batches
mini_batches = [
    (torch.randn(2, 3), torch.randn(2, 1)),
    (torch.randn(2, 3), torch.randn(2, 1)),
    (torch.randn(2, 3), torch.randn(2, 1)),
    (torch.randn(2, 3), torch.randn(2, 1))
]

print("Intentional gradient accumulation (effective batch size = 8):")
accumulation_steps = 4
total_loss = 0

for i, (x_batch, y_batch) in enumerate(mini_batches):
    y_pred = model(x_batch)
    loss = F.mse_loss(y_pred, y_batch)
    
    # Scale loss by accumulation steps (important!)
    loss = loss / accumulation_steps
    loss.backward()
    
    total_loss += loss.item() * accumulation_steps
    
    print(f"  Mini-batch {i}: Loss = {loss.item() * accumulation_steps:.4f}")

# Update after accumulating gradients from all mini-batches
optimizer.step()
optimizer.zero_grad()

print(f"Total accumulated loss: {total_loss:.4f}")
print("\n💡 Scale each mini-batch loss by 1/accumulation_steps so the accumulated gradient matches one large-batch step")
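
The scaling rule can be checked numerically. This sketch assumes equal-sized mini-batches and a mean-reduced loss (the default for `F.mse_loss`); under those assumptions, the accumulated gradient should equal the gradient of a single pass over the full batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
accumulation_steps = 4

# Reference: one backward pass over the full batch of 8
model.zero_grad()
F.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: four mini-batches of 2, each loss divided by accumulation_steps
model.zero_grad()
for x_mb, y_mb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (F.mse_loss(model(x_mb), y_mb) / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

print(f"Gradients match: {torch.allclose(full_grad, accum_grad, atol=1e-5)}")
```

If the mini-batches have different sizes, dividing by `accumulation_steps` is only an approximation; weighting each loss by its share of the total sample count would be needed for an exact match.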

## Data Type and Shape Issues

In [None]:
print("=== Data Type and Shape Issues ===")

# ❌ GOTCHA 1: Wrong target dtype for CrossEntropyLoss
print("\n1. Wrong target dtype for CrossEntropyLoss")
print("-" * 50)

logits = torch.randn(4, 5)  # 4 samples, 5 classes
criterion = nn.CrossEntropyLoss()

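# ❌ Wrong: class indices stored as floats (float targets are only valid as per-class probabilities with the same shape as the logits)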
try:
    targets_wrong = torch.tensor([0., 1., 2., 3.])  # Float
    loss = criterion(logits, targets_wrong)
except Exception as e:
    print(f"Float targets error: {type(e).__name__}: {e}")

# ✅ Correct: Long tensor targets
targets_correct = torch.tensor([0, 1, 2, 3])  # Long
loss = criterion(logits, targets_correct)
print(f"Correct loss with Long targets: {loss.item():.4f}")

print(f"\nLogits dtype: {logits.dtype}")
print(f"Correct targets dtype: {targets_correct.dtype}")

# ❌ GOTCHA 2: Shape mismatches
print("\n2. Common shape mismatches")
print("-" * 30)

# Problem: Unexpected dimension removal
x = torch.randn(1, 10)  # Shape: [1, 10]
print(f"Original shape: {x.shape}")

# ❌ Dangerous: squeeze() without arguments
x_squeezed = x.squeeze()  # Removes ALL dimensions of size 1, including the batch dim
print(f"After squeeze(): {x_squeezed.shape}")

# nn.Linear accepts the 1-D input, so nothing fails here: the batch dim is silently lost
model = nn.Linear(10, 5)
output = model(x_squeezed)
print(f"Output without batch dim: {output.shape}")  # [5] instead of [1, 5]

# The error only surfaces later, when code indexes the missing batch dimension
try:
    first_feature = output[:, 0]
except IndexError as e:
    print(f"Shape error: {e}")

# ✅ Better: Be specific about which dimension to squeeze
x_extra = torch.randn(1, 10, 1)  # Trailing singleton dim to remove
x_safe = x_extra.squeeze(-1)     # Removes only the last dim; batch dim is kept
print(f"Safe squeeze result: {x_safe.shape}")

# Or keep original shape
output = model(x)  # Works fine with [1, 10]
print(f"Model output shape: {output.shape}")
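
Labels often arrive as floats, for example after loading from NumPy or CSV. The sketch below shows the two target forms that cross-entropy accepts: integer class indices, and per-class probabilities with the same shape as the logits. The probability form is assumed to be available (it was added in PyTorch 1.10); on older versions only the index form works.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)                      # [batch, num_classes]
class_idx = torch.tensor([0, 1, 2, 3])          # int64 class indices, shape [batch]

# Float labels loaded from disk need an explicit cast to long
float_labels = torch.tensor([0., 1., 2., 3.])
loss_idx = F.cross_entropy(logits, float_labels.long())

# Float targets are accepted only as per-class probabilities, shape [batch, num_classes]
probs = F.one_hot(class_idx, num_classes=5).float()
loss_prob = F.cross_entropy(logits, probs)

print(f"index targets: {loss_idx.item():.4f}, one-hot probability targets: {loss_prob.item():.4f}")
```

With one-hot probabilities the two losses agree, which makes this a quick sanity check when converting between label formats.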

In [None]:
# ❌ GOTCHA 3: Broadcasting surprises
print("\n3. Unexpected broadcasting")
print("-" * 30)

# This might not do what you expect
a = torch.randn(3, 1)  # Shape: [3, 1]
b = torch.randn(4)     # Shape: [4]

c = a + b  # Broadcasting: [3, 1] + [4] -> [3, 4]
print(f"a.shape: {a.shape}, b.shape: {b.shape}")
print(f"Result shape: {c.shape}")
print("This creates a 3x4 tensor - might not be intended!")

# ✅ Be explicit about your intentions
print("\nBetter: Be explicit about dimensions")
a = torch.randn(3, 1)
b = torch.randn(1, 4)  # Make the broadcast intention clear
c = a + b  # Now clearly [3, 1] + [1, 4] -> [3, 4]
print(f"a.shape: {a.shape}, b.shape: {b.shape} -> result: {c.shape}")

# ❌ GOTCHA 4: Dimension confusion with batch_first
print("\n4. batch_first confusion")
print("-" * 30)

# Default layout (batch_first=False) - easy to mix up with other layers
rnn_old = nn.LSTM(input_size=10, hidden_size=20, batch_first=False)
x_old_style = torch.randn(15, 32, 10)  # [seq_len, batch, features]
output_old, _ = rnn_old(x_old_style)
print(f"Default layout - input: {x_old_style.shape}, output: {output_old.shape}")
print("Confusing: sequence length comes first!")

# ✅ batch_first=True - matches the [batch, ...] layout of other layers
rnn_new = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
x_new_style = torch.randn(32, 15, 10)  # [batch, seq_len, features]
output_new, (h_n, c_n) = rnn_new(x_new_style)
print(f"batch_first=True - input: {x_new_style.shape}, output: {output_new.shape}")
print(f"Hidden state is still [num_layers, batch, hidden]: {h_n.shape}")

print("\n💡 Set batch_first explicitly and keep one convention throughout the model")
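
A common place where silent broadcasting hurts is the loss computation: a model that outputs `[batch, 1]` compared against labels of shape `[batch]`. The sketch below computes the squared error by hand, with made-up tensors, to show how the shapes expand.

```python
import torch

torch.manual_seed(0)
pred = torch.randn(4, 1)      # typical output of nn.Linear(..., 1): [batch, 1]
target = torch.randn(4)       # typical label tensor: [batch]

# [4, 1] - [4] broadcasts to [4, 4]: every prediction is compared with every target
wrong_diff = pred - target
wrong_mse = (wrong_diff ** 2).mean()

# Align the shapes first so the subtraction is element-wise
right_diff = pred.squeeze(-1) - target
right_mse = (right_diff ** 2).mean()

print(f"broadcast diff shape: {tuple(wrong_diff.shape)}, aligned diff shape: {tuple(right_diff.shape)}")
print(f"broadcast MSE: {wrong_mse.item():.4f}, aligned MSE: {right_mse.item():.4f}")
```

The broadcast version still produces a scalar loss and trains without errors, which is why this bug tends to show up as a model that plateaus rather than as a crash.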

## Memory and Device Problems

In [None]:
print("=== Memory and Device Problems ===")

# ❌ GOTCHA 1: Model and data on different devices
print("\n1. Device mismatches")
print("-" * 25)

model = nn.Linear(10, 5)
x = torch.randn(4, 10)

print(f"Model device: {next(model.parameters()).device}")
print(f"Data device: {x.device}")

# If you have CUDA available, demonstrate the error
if torch.cuda.is_available():
    model_gpu = model.cuda()
    try:
        output = model_gpu(x)  # x is still on CPU!
    except RuntimeError as e:
        print(f"Device mismatch error: {e}")
    
    # ✅ Solution: Move both to same device
    x_gpu = x.cuda()
    output = model_gpu(x_gpu)
    print(f"Success! Output shape: {output.shape}, device: {output.device}")
else:
    print("No CUDA available - device mismatch demo skipped")

# Better pattern: Use device variable
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
x = x.to(device)
output = model(x)
print(f"Using device variable - output device: {output.device}")

# ❌ GOTCHA 2: Memory leaks in training loops
print("\n2. Memory leak patterns")
print("-" * 25)

# This can cause memory leaks
def bad_training_loop():
    model = nn.Linear(100, 10)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    losses = []  # ❌ Storing loss tensors
    
    for epoch in range(5):
        x = torch.randn(32, 100)
        y = torch.randint(0, 10, (32,))
        
        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()
        
        losses.append(loss)  # ❌ Keeps computation graph in memory!
    
    return losses

# ✅ Correct: Extract scalar values
def good_training_loop():
    model = nn.Linear(100, 10)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    losses = []  # ✅ Store scalar values only
    
    for epoch in range(5):
        x = torch.randn(32, 100)
        y = torch.randint(0, 10, (32,))
        
        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())  # ✅ Extract scalar value
    
    return losses

# Compare what each loop stores (types only; this does not measure memory)
bad_losses = bad_training_loop()
good_losses = good_training_loop()

print(f"Bad pattern - storing tensors: {type(bad_losses[0])}")
print(f"Good pattern - storing scalars: {type(good_losses[0])}")
print("\n💡 Store loss.item() (or loss.detach()) for logging, never the graph-attached loss tensor")
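
The difference between the two loops comes down to whether the stored object still references the autograd graph. A minimal sketch of how to inspect that (model and data are placeholders for this example):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(100, 10)
x, y = torch.randn(32, 100), torch.randint(0, 10, (32,))
loss = F.cross_entropy(model(x), y)

print(f"loss.grad_fn: {loss.grad_fn}")                     # linked to the autograd graph
print(f"loss.detach().grad_fn: {loss.detach().grad_fn}")   # None: graph reference dropped
print(f"type(loss.item()): {type(loss.item()).__name__}")  # plain Python float

# For per-step logging on GPU, detaching avoids the host sync that .item() forces;
# convert to a Python float once at the end of the epoch instead.
step_losses = [loss.detach()]
epoch_mean = torch.stack(step_losses).mean().item()
print(f"epoch mean loss: {epoch_mean:.4f}")
```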

In [None]:
# ❌ GOTCHA 3: Inefficient tensor operations
print("\n3. Inefficient tensor operations")
print("-" * 35)

# ❌ Inefficient: Python loops for tensor operations
def slow_sum(tensor):
    result = 0
    for i in range(tensor.size(0)):
        for j in range(tensor.size(1)):
            result += tensor[i, j].item()  # ❌ Very slow!
    return result

# ✅ Efficient: Use vectorized operations
def fast_sum(tensor):
    return tensor.sum().item()  # ✅ Fast!

# Compare performance
import time

test_tensor = torch.randn(100, 100)

start = time.perf_counter()  # Monotonic, high-resolution clock for short intervals
slow_result = slow_sum(test_tensor)
slow_time = time.perf_counter() - start

start = time.perf_counter()
fast_result = fast_sum(test_tensor)
fast_time = time.perf_counter() - start

print(f"Slow method: {slow_time:.4f}s, result: {slow_result:.4f}")
print(f"Fast method: {fast_time:.6f}s, result: {fast_result:.4f}")
print(f"Speedup: {slow_time / max(fast_time, 1e-9):.1f}x faster")  # Guard against a zero reading

# ❌ GOTCHA 4: Creating tensors in loops
print("\n4. Tensor creation in loops")
print("-" * 30)

# ❌ Slow: Creating tensors in loop
def slow_batch_creation():
    batch = []
    for i in range(32):
        sample = torch.randn(10)  # ❌ Creating tensor in loop
        batch.append(sample)
    return torch.stack(batch)

# ✅ Fast: Create all at once
def fast_batch_creation():
    return torch.randn(32, 10)  # ✅ Create entire batch at once

# Time comparison
start = time.perf_counter()
slow_batch = slow_batch_creation()
slow_batch_time = time.perf_counter() - start

start = time.perf_counter()
fast_batch = fast_batch_creation()
fast_batch_time = time.perf_counter() - start

print(f"Slow batch creation: {slow_batch_time:.4f}s")
print(f"Fast batch creation: {fast_batch_time:.6f}s")
# Both batches are independent random draws, so compare shapes rather than values
print(f"Same shape: {slow_batch.shape == fast_batch.shape}")

print("\n💡 Avoid Python loops for tensor operations - use vectorization!")
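
Single-shot timings like the ones above are noisy, and on a GPU they can be misleading because CUDA kernels run asynchronously. The sketch below shows one way to get steadier numbers; `time_fn` is a hypothetical helper written for this example, and taking the best of several runs is a design choice, not the only reasonable one.

```python
import time
import torch

def time_fn(fn, *args, repeats=10):
    """Return the best wall-clock time (seconds) of `fn(*args)` over `repeats` runs.

    Hypothetical helper for illustration. CUDA kernels launch asynchronously, so
    the device is synchronized before each clock reading when CUDA is available.
    """
    best = float("inf")
    for _ in range(repeats):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        best = min(best, time.perf_counter() - start)
    return best

t = torch.randn(100, 100)
loop_time = time_fn(lambda a: sum(row.sum().item() for row in a), t)
vector_time = time_fn(lambda a: a.sum().item(), t)
print(f"row loop: {loop_time:.6f}s, vectorized: {vector_time:.6f}s")
```

For more careful measurements, `torch.utils.benchmark.Timer` handles warm-up and synchronization for you.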

## Debugging Techniques

In [None]:
print("=== Debugging Techniques ===")

# 1. Shape debugging
print("\n1. Shape Debugging")
print("-" * 20)

def debug_shapes_example():
    """Example of how to debug shape issues"""
    
    batch_size, seq_len, d_model = 4, 10, 64
    
    # Input
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"Input shape: {x.shape}")
    
    # Linear layer
    linear = nn.Linear(d_model, d_model * 2)
    x = linear(x)
    print(f"After linear: {x.shape}")
    
    # Reshape for multi-head attention
    num_heads = 8
    head_dim = (d_model * 2) // num_heads
    x = x.view(batch_size, seq_len, num_heads, head_dim)
    print(f"After reshape: {x.shape}")
    
    # Transpose for attention
    x = x.transpose(1, 2)  # [batch, heads, seq, head_dim]
    print(f"After transpose: {x.shape}")
    
    return x

result = debug_shapes_example()
print("\n💡 Always print shapes when debugging complex transformations")

# 2. Gradient checking
print("\n2. Gradient Checking")
print("-" * 20)

def check_gradients(model):
    """Utility to check gradient flow"""
    print("\nGradient check:")
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            print(f"  {name}: grad_norm = {grad_norm:.6f}")
            
            # Check for problematic gradients (non-finite values first, since NaN fails every comparison)
            if not torch.isfinite(param.grad).all():
                print(f"    ❌ NaN or Inf gradients in {name}")
            elif grad_norm == 0:
                print(f"    ⚠️  Zero gradients in {name}")
            elif grad_norm > 10:
                print(f"    ⚠️  Large gradients in {name} (possible explosion)")
        else:
            print(f"  {name}: No gradients")

# Test gradient checking
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

x = torch.randn(5, 10)
y = torch.randn(5, 1)

output = model(x)
loss = F.mse_loss(output, y)
loss.backward()

check_gradients(model)

# 3. NaN/Inf detection
print("\n3. NaN/Inf Detection")
print("-" * 20)

def check_tensor_health(tensor, name="tensor"):
    """Check if tensor contains NaN or Inf values"""
    has_nan = torch.isnan(tensor).any().item()
    has_inf = torch.isinf(tensor).any().item()
    
    if has_nan:
        print(f"❌ {name} contains NaN values!")
        return False
    if has_inf:
        print(f"❌ {name} contains Inf values!")
        return False
    
    print(f"✅ {name} is healthy")
    return True

# Test with problematic tensor
good_tensor = torch.randn(5, 5)
bad_tensor = torch.tensor([[1.0, 2.0, float('nan')], [4.0, float('inf'), 6.0]])

check_tensor_health(good_tensor, "good_tensor")
check_tensor_health(bad_tensor, "bad_tensor")
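
Once a gradient check flags large norms, a common mitigation is gradient clipping. The sketch below applies `torch.nn.utils.clip_grad_norm_` after `backward()`; the exaggerated targets and the `max_norm` of 1.0 are arbitrary choices made to trigger clipping in this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
x = torch.randn(5, 10)
y = 1000.0 * torch.randn(5, 1)        # exaggerated targets to force large gradients

F.mse_loss(model(x), y).backward()

max_norm = 1.0
# clip_grad_norm_ rescales gradients in place and returns the norm measured before clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))

print(f"total grad norm before clipping: {norm_before.item():.2f}")
print(f"total grad norm after clipping:  {norm_after.item():.4f}")
```

In a training loop the call goes between `loss.backward()` and `optimizer.step()`. Logging the returned pre-clip norm is a cheap way to monitor gradient scale over time.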

In [None]:
# 4. Model surgery - examining intermediate outputs
print("\n4. Model Surgery with Hooks")
print("-" * 30)

class DebuggableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 1)
        
        # Storage for intermediate activations
        self.activations = {}
        
        # Register hooks for debugging
        self.fc1.register_forward_hook(self.save_activation('fc1'))
        self.fc2.register_forward_hook(self.save_activation('fc2'))
        self.fc3.register_forward_hook(self.save_activation('fc3'))
    
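    def save_activation(self, name):
        # Hooks on nn.Linear capture the layer output BEFORE F.relu runs in forward(),
        # so the sparsity reported by print_activation_stats() will be close to 0
        def hook(module, input, output):
            self.activations[name] = output.detach()
        return hook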
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def print_activation_stats(self):
        print("Activation statistics:")
        for name, activation in self.activations.items():
            mean = activation.mean().item()
            std = activation.std().item()
            sparsity = (activation == 0).float().mean().item()
            print(f"  {name}: mean={mean:.4f}, std={std:.4f}, sparsity={sparsity:.2%}")

# Test the debuggable model
debug_model = DebuggableModel()
x = torch.randn(8, 10)
output = debug_model(x)

debug_model.print_activation_stats()

# 5. Quick debugging utilities
print("\n5. Quick Debugging Utilities")
print("-" * 30)

def quick_stats(tensor, name="tensor"):
    """Print quick statistics about a tensor"""
    print(f"{name}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}")
    print(f"  min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")
    print(f"  mean={tensor.mean().item():.4f}, std={tensor.std().item():.4f}")
    
    # Check for common issues
    if tensor.requires_grad:
        print(f"  requires_grad=True")
    if torch.isnan(tensor).any():
        print(f"  ⚠️ Contains NaN values")
    if torch.isinf(tensor).any():
        print(f"  ⚠️ Contains Inf values")

# Test the utility
test_tensor = torch.randn(3, 4, requires_grad=True)
quick_stats(test_tensor, "test_tensor")

# 6. Model parameter overview
def model_overview(model):
    """Print comprehensive model overview"""
    print("\nModel Overview:")
    print("=" * 50)
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: {total_params * 4 / 1024**2:.2f} MB (assuming float32)")
    
    print("\nLayer details:")
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # Leaf modules only
            params = sum(p.numel() for p in module.parameters())
            print(f"  {name}: {module.__class__.__name__}, {params:,} params")

model_overview(debug_model)

print("\n🎉 Debugging toolkit complete!")
print("\nKey debugging strategies:")
print("• Always print shapes when developing")
print("• Use hooks to inspect intermediate outputs")
print("• Check for NaN/Inf values regularly")
print("• Monitor gradient norms during training")
print("• Create debugging utilities for common checks")
print("• Use tensor.item() to avoid memory leaks")
print("• Be explicit about device placement")
print("• Understand broadcasting behavior")
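
Hooks registered in `__init__` stay attached for the life of the model. For ad hoc debugging it is often cleaner to attach a hook, inspect, and remove it again. The sketch below assumes a small `nn.Sequential` model built for this example and hooks the `nn.ReLU` module, so the captured tensor is the post-activation output and the sparsity figure is meaningful.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
captured = {}

def save_output(name):
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

# Hook the ReLU module so the captured tensor is the post-activation output
handle = model[1].register_forward_hook(save_output("relu"))

model(torch.randn(8, 10))
sparsity = (captured["relu"] == 0).float().mean().item()
print(f"relu output: shape={tuple(captured['relu'].shape)}, sparsity={sparsity:.2%}")

handle.remove()               # detach the hook once debugging is done
captured.clear()
model(torch.randn(8, 10))
print(f"hook still active after remove(): {'relu' in captured}")
```

Note that this only works when the activation is an `nn.Module`; a functional call such as `F.relu` inside `forward()` has no module to hook.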