# PyTorch Internals - Hands-On Exploration

This notebook offers a hands-on, interactive exploration of PyTorch's internal architecture.

## Topics
1. Tensor Data Structure
2. Storage and Views
3. Autograd Graph
4. Dispatcher Behavior
5. Profiling Internals
6. Memory Management

In [None]:
import torch
import torch.nn as nn
import gc

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Tensor Data Structure (TensorImpl)

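Every tensor is backed by a C++ `TensorImpl` object that holds its metadata, plus a reference to a `Storage` that owns the actual bytes:
- sizes (shape)
- strides (how many elements to skip in storage per step along each dimension)
- storage_offset (where the tensor's first element sits in the storage)
- dtype, device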
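For a freshly allocated tensor, the strides follow from the shape alone. The sketch below uses two hypothetical helpers, `contiguous_strides` and `storage_index`, to recompute what `TensorImpl` stores and where a given element lands in storage. It assumes the default row-major layout of newly created tensors.

```python
import torch

def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a tensor of `shape`."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        # Size-0 dims are treated as size 1 so strides never collapse to 0
        step *= max(size, 1)
    return tuple(reversed(strides))

def storage_index(index, strides, storage_offset=0):
    """Flat storage position of `index`: offset + sum(i_k * stride_k)."""
    return storage_offset + sum(i * s for i, s in zip(index, strides))

t = torch.arange(24).reshape(2, 3, 4)
print(contiguous_strides(t.shape), t.stride())          # (12, 4, 1) (12, 4, 1)

pos = storage_index((1, 2, 3), t.stride(), t.storage_offset())
print(pos, t.flatten()[pos].item(), t[1, 2, 3].item())  # 23 23 23
```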

In [None]:
# Create a tensor and inspect its internals
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

print("Tensor:")
print(x)
print(f"\nShape: {x.shape}")
print(f"Strides: {x.stride()}")
print(f"Storage offset: {x.storage_offset()}")
print(f"Total elements: {x.numel()}")
print(f"Is contiguous: {x.is_contiguous()}")
print(f"\nDtype: {x.dtype}")
print(f"Device: {x.device}")
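`is_contiguous()` is not a stored flag you set by hand; it is derived from sizes and strides. The hypothetical helper below sketches a row-major check that is assumed to mirror PyTorch's logic: walk the dimensions from innermost outwards, skip size-1 dimensions, and require each stride to equal the product of the sizes to its right. Treat it as an approximation for exploration, not the library's code.

```python
import torch

def is_contiguous(shape, strides):
    """Sketch of a row-major contiguity check computed from metadata only."""
    if any(size == 0 for size in shape):  # empty tensors count as contiguous
        return True
    expected = 1
    # Walk dimensions from innermost to outermost
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:                     # a size-1 dim's stride never matters
            continue
        if stride != expected:
            return False
        expected *= size
    return True

samples = [
    torch.empty(2, 3),           # freshly allocated
    torch.empty(2, 3).t(),       # transposed: strides swapped
    torch.empty(3, 1).t(),       # shape (1, 3): the size-1 dim is ignored
    torch.empty(4, 6)[:, ::2],   # every other column: inner stride 2
]
for t in samples:
    print(tuple(t.shape), t.stride(), is_contiguous(t.shape, t.stride()), t.is_contiguous())
```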

In [None]:
# Understanding strides
# Strides tell you how many elements to skip in memory to move one step in each dimension

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(f"Original tensor: shape={x.shape}, strides={x.stride()}")
print(f"To move one row: skip {x.stride()[0]} elements")
print(f"To move one column: skip {x.stride()[1]} elements")

# Index calculation: tensor[i, j] = storage[offset + i*stride[0] + j*stride[1]]
i, j = 1, 2
offset = x.storage_offset() + i * x.stride()[0] + j * x.stride()[1]
print(f"\ntensor[{i}, {j}] = {x[i, j].item()}")
print(f"Calculated offset: {offset}")
print(f"storage[{offset}] = {x.storage()[offset]}")

## 2. Storage and Views

A view is a new tensor with its own sizes, strides and storage_offset that shares the underlying storage of its base. Modifying one therefore affects the other!
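Because a view is only metadata over shared storage, one can be built by hand with `Tensor.as_strided`, which takes explicit sizes, strides and a storage offset. This sketch reconstructs `x.t()` and `x[1, :]` that way, assuming `x` has the default row-major layout. `as_strided` is a low-level op, used here only to make the metadata explicit.

```python
import torch

x = torch.arange(6).reshape(2, 3)   # strides (3, 1), offset 0

# x.t(): same storage, sizes and strides swapped
manual_t = x.as_strided(size=(3, 2), stride=(1, 3), storage_offset=0)

# x[1, :]: skip one full row (3 elements) via the storage offset
manual_row = x.as_strided(size=(3,), stride=(1,), storage_offset=3)

print(torch.equal(manual_t, x.t()))      # True
print(torch.equal(manual_row, x[1, :]))  # True

# A write through the hand-built view is visible through x
manual_row[0] = 100
print(x[1, 0].item())                    # 100
```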

In [None]:
# Views share storage
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x[1, :]  # View of second row

print(f"x:\n{x}")
print(f"\ny = x[1, :] = {y}")
print(f"\ny.storage_offset() = {y.storage_offset()}")
# untyped_storage() avoids the TypedStorage deprecation warning in PyTorch 2.x
print(f"Same storage? {x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()}")

# Modify y
y[0] = 100
print(f"\nAfter y[0] = 100:")
print(f"y = {y}")
print(f"x =\n{x}")
print("x was also modified because y is a view!")

In [None]:
# Transpose changes strides, not data
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_t = x.t()

print(f"x shape: {x.shape}, strides: {x.stride()}")
print(f"x.t() shape: {x_t.shape}, strides: {x_t.stride()}")
print(f"\nx.t().is_contiguous(): {x_t.is_contiguous()}")
print(f"Same storage: {x.untyped_storage().data_ptr() == x_t.untyped_storage().data_ptr()}")

# .contiguous() on a non-contiguous tensor copies the data into new row-major storage
x_t_contig = x_t.contiguous()
print(f"\nx.t().contiguous().is_contiguous(): {x_t_contig.is_contiguous()}")
print(f"Different storage: {x.untyped_storage().data_ptr() != x_t_contig.untyped_storage().data_ptr()}")

## 3. Autograd Graph

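PyTorch builds the computation graph dynamically during the forward pass: every operation on a tensor that requires grad records a `grad_fn` node that points back to the nodes of its inputs.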
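Following `next_functions[0][0]` only walks the first input of each node. A small recursive walk visits every branch; `graph_nodes` below is a hypothetical helper. Leaf tensors show up as `AccumulateGrad` nodes, which is where gradients are written into `.grad`, and inputs that need no gradient appear as `None` and are skipped.

```python
import torch

def graph_nodes(grad_fn, depth=0):
    """Depth-first walk of an autograd graph, yielding (depth, node class name)."""
    if grad_fn is None:        # input that needs no gradient
        return
    yield depth, type(grad_fn).__name__
    for next_fn, _ in grad_fn.next_functions:
        yield from graph_nodes(next_fn, depth + 1)

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
loss = ((x * y) + 1) ** 2

for depth, name in graph_nodes(loss.grad_fn):
    print("  " * depth + name)
```

For this graph the walk should print `PowBackward0` at the root, then `AddBackward0`, `MulBackward0`, and one `AccumulateGrad` per leaf input.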

In [None]:
# Visualizing the autograd graph
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)

# Build computation graph
z = x * y      # grad_fn: MulBackward0
w = z + 1      # grad_fn: AddBackward0
loss = w ** 2  # grad_fn: PowBackward0

print("Computation: loss = ((x * y) + 1)^2")
print(f"x = {x.item()}, y = {y.item()}")
print(f"loss = {loss.item()}")

print(f"\n--- Autograd Graph ---")
print(f"loss.grad_fn = {loss.grad_fn}")
print(f"  └── {loss.grad_fn.next_functions[0][0]}")
print(f"      └── {loss.grad_fn.next_functions[0][0].next_functions[0][0]}")

In [None]:
# Backward pass
loss.backward()

print("After backward():")
print(f"x.grad = {x.grad.item()}")
print(f"y.grad = {y.grad.item()}")

# Manual calculation:
# loss = ((x*y) + 1)^2
# dloss/dx = 2*((x*y)+1) * y = 2*(6+1)*3 = 42
# dloss/dy = 2*((x*y)+1) * x = 2*(6+1)*2 = 28
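expected_dx = 2 * (x.item() * y.item() + 1) * y.item()  # 2*7*3 = 42
expected_dy = 2 * (x.item() * y.item() + 1) * x.item()  # 2*7*2 = 28
print("\nManual verification:")
print(f"dloss/dx = 2*((x*y)+1)*y = {expected_dx}, matches autograd: {expected_dx == x.grad.item()}")
print(f"dloss/dy = 2*((x*y)+1)*x = {expected_dy}, matches autograd: {expected_dy == y.grad.item()}")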

In [None]:
# Custom autograd function
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Test it
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
y = MyReLU.apply(x)
y.sum().backward()

print(f"x = {x.detach()}")
print(f"MyReLU(x) = {y.detach()}")
print(f"x.grad = {x.grad}")
print("Gradient is 0 where x < 0, 1 where x >= 0")
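For a custom `Function`, `torch.autograd.gradcheck` compares the hand-written `backward` against finite-difference gradients. This sketch repeats `MyReLU` so it runs on its own. It uses float64 inputs, as gradcheck expects, and keeps them away from 0 because the ReLU kink makes finite differences unreliable there.

```python
import torch
from torch.autograd import gradcheck

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# float64 inputs, none of them at the kink x = 0
x = torch.tensor([-2.0, -0.5, 0.5, 3.0], dtype=torch.float64, requires_grad=True)
print(gradcheck(MyReLU.apply, (x,), eps=1e-6, atol=1e-4))  # True
```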

## 4. Dispatcher Behavior

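The dispatcher routes every ATen operator call to a kernel based on the dispatch keys of its inputs (autograd first, then a backend such as CPU or CUDA). Subclassing `TorchDispatchMode` and defining `__torch_dispatch__` intercepts every ATen op that reaches the Python dispatch key, which sits below autograd, so you see the primitive ops that actually execute. Note that `torch.utils._python_dispatch` is a private module and may change between releases.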
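Conceptually, each op call carries a set of dispatch keys; the dispatcher runs the kernel for the highest-priority key, and that kernel typically does its part and then redispatches with its own key excluded. The pure-Python toy below sketches that loop. The key names `Autograd` and `CPU` echo real ones, but `Logging`, the priority list and all kernels are hypothetical stand-ins, not PyTorch code.

```python
# Toy model of priority-ordered dispatch with redispatch; NOT PyTorch's code.
# "Autograd" and "CPU" echo real dispatch keys; "Logging", the ordering
# and all kernels are hypothetical.

KEY_PRIORITY = ["Autograd", "Logging", "CPU"]  # highest priority first
trace = []

def dispatch(op, keys, *args):
    """Run the kernel registered for the highest-priority key in `keys`."""
    for key in KEY_PRIORITY:
        if key in keys:
            kernel = KERNELS[(op, key)]
            # Pass the remaining keys so the kernel can redispatch below itself
            return kernel(keys - {key}, *args)
    raise RuntimeError(f"no kernel for {op!r} with keys {sorted(keys)}")

def add_autograd(keys, a, b):
    trace.append("Autograd: record AddBackward node")
    return dispatch("add", keys, a, b)

def add_logging(keys, a, b):
    trace.append("Logging: saw add")
    return dispatch("add", keys, a, b)

def add_cpu(keys, a, b):
    trace.append("CPU: compute")
    return [x + y for x, y in zip(a, b)]

KERNELS = {
    ("add", "Autograd"): add_autograd,
    ("add", "Logging"): add_logging,
    ("add", "CPU"): add_cpu,
}

result = dispatch("add", {"Autograd", "Logging", "CPU"}, [1, 2], [10, 20])
print(result)  # [11, 22]
for step in trace:
    print(step)
```

The trace lists the kernels in priority order, Autograd first and CPU last, and each layer only sees the keys below it.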

In [None]:
from torch.utils._python_dispatch import TorchDispatchMode

class OperationLogger(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.operations = []
    
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.operations.append({
            'name': str(func),
            'input_shapes': [a.shape if hasattr(a, 'shape') else type(a).__name__ for a in args]
        })
        return func(*args, **kwargs)

# Log operations in a forward pass
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

x = torch.randn(4, 10)

with OperationLogger() as logger:
    y = model(x)

print(f"Operations in forward pass:")
for i, op in enumerate(logger.operations[:10]):  # First 10
    print(f"  {i+1}. {op['name']}")
print(f"  ... ({len(logger.operations)} total operations)")
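Because the Python dispatch key sits below autograd, a dispatch mode also sees the ops that the autograd engine runs during `backward()`. This variant of the logger, a hypothetical `OpCounter`, tallies op overloads with `collections.Counter` over a forward and backward pass of a single `nn.Linear`. The exact op list depends on the PyTorch version and build.

```python
import collections

import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

class OpCounter(TorchDispatchMode):
    """Hypothetical dispatch mode that counts ATen op overloads by name."""

    def __init__(self):
        super().__init__()
        self.counts = collections.Counter()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.counts[str(func)] += 1
        return func(*args, **(kwargs or {}))

model = nn.Linear(4, 2)
x = torch.randn(3, 4)

with OpCounter() as counter:
    # Both the forward ops and the ops run by backward() pass through the mode
    model(x).sum().backward()

for name, n in counter.counts.most_common():
    print(f"{n:3d}  {name}")
```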

## 5. Profiling Internals

In [None]:
# Profile operations
model = nn.Sequential(
    nn.Linear(512, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512)
)

if torch.cuda.is_available():
    model = model.cuda()
    x = torch.randn(32, 512, device='cuda')
else:
    x = torch.randn(32, 512)

# Profile with PyTorch profiler
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ] if torch.cuda.is_available() else [torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,
) as prof:
    for _ in range(10):
        y = model(x)
        # x does not require grad, but the model parameters do,
        # so backward() still runs and its ops appear in the profile
        y.sum().backward()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
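Profiler tables are easier to read when the regions of interest are labeled. `torch.profiler.record_function` opens a named range that shows up as its own row next to the ATen ops it contains. This CPU-only sketch labels the forward and backward passes of a small model whose sizes are arbitrary.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
x = torch.randn(16, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("forward_pass"):
        y = model(x)
    with record_function("backward_pass"):
        y.sum().backward()

# The labeled ranges appear as rows alongside the aten:: ops they contain
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```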

## 6. Memory Management

In [None]:
if torch.cuda.is_available():
    # Reset memory stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    print("Initial state:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB")
    print(f"  Cached: {torch.cuda.memory_reserved() / 1e6:.1f} MB")
    
    # Allocate tensors
    tensors = [torch.randn(1000, 1000, device='cuda') for _ in range(5)]
    
    print("\nAfter allocating 5 tensors:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB")
    print(f"  Cached: {torch.cuda.memory_reserved() / 1e6:.1f} MB")
    
    # Delete tensors
    del tensors
    gc.collect()
    torch.cuda.synchronize()
    
    print("\nAfter deleting (memory cached, not freed):")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB")
    print(f"  Cached: {torch.cuda.memory_reserved() / 1e6:.1f} MB")
    
    # Empty cache
    torch.cuda.empty_cache()
    
    print("\nAfter empty_cache():")
    print(f"  Allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB")
    print(f"  Cached: {torch.cuda.memory_reserved() / 1e6:.1f} MB")
else:
    print("CUDA not available")
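The gap between `memory_allocated()` and `memory_reserved()` comes from the caching allocator: freed blocks are kept for reuse instead of being returned to the driver. The pure-Python toy below is a hypothetical model of that idea only. PyTorch's real allocator also splits and merges blocks, keeps separate pools for small and large requests, and is stream-aware, none of which is modeled here.

```python
from collections import defaultdict

class ToyCachingAllocator:
    """Toy caching allocator (hypothetical, far simpler than PyTorch's).

    Freed blocks go onto a per-size free list instead of back to the device,
    so a later request of the same rounded size skips the device malloc.
    """

    def __init__(self, round_to=512):
        self.round_to = round_to
        self.free_blocks = defaultdict(int)  # rounded size -> number of cached blocks
        self.allocated = 0                   # bytes currently handed out to "tensors"
        self.reserved = 0                    # bytes obtained from the "device"
        self.device_mallocs = 0              # how often the slow path was taken

    def _round(self, nbytes):
        # Round up to a multiple of round_to so similar sizes share free lists
        return -(-nbytes // self.round_to) * self.round_to

    def malloc(self, nbytes):
        size = self._round(nbytes)
        if self.free_blocks[size] > 0:       # fast path: reuse a cached block
            self.free_blocks[size] -= 1
        else:                                # slow path: ask the device for memory
            self.device_mallocs += 1
            self.reserved += size
        self.allocated += size
        return size

    def free(self, size):
        # The block stays reserved; it is only marked as reusable
        self.allocated -= size
        self.free_blocks[size] += 1

    def empty_cache(self):
        # Hand every cached (unused) block back to the device
        for size, count in self.free_blocks.items():
            self.reserved -= size * count
        self.free_blocks.clear()


alloc = ToyCachingAllocator()
nbytes = 1000 * 1000 * 4                     # one 1000x1000 float32 tensor

blocks = [alloc.malloc(nbytes) for _ in range(5)]
for b in blocks:
    alloc.free(b)
print(alloc.allocated, alloc.reserved)       # 0 20001280: freed, but still cached

blocks = [alloc.malloc(nbytes) for _ in range(5)]
print(alloc.device_mallocs)                  # 5: the second round reused cached blocks

for b in blocks:
    alloc.free(b)
alloc.empty_cache()
print(alloc.reserved)                        # 0
```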

In [None]:
# Gradient checkpointing for memory efficiency
from torch.utils.checkpoint import checkpoint

class HeavyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
    
    def forward(self, x):
        return torch.relu(self.linear(x))

dim = 512
num_layers = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'

layers = nn.ModuleList([HeavyBlock(dim) for _ in range(num_layers)]).to(device)

def forward_normal(x, layers):
    for layer in layers:
        x = layer(x)
    return x

def forward_checkpoint(x, layers):
    for layer in layers:
        x = checkpoint(layer, x, use_reentrant=False)
    return x

# Each run below makes its own leaf copy of x that requires grad
x = torch.randn(32, dim, device=device)

if torch.cuda.is_available():
    # Normal forward
    torch.cuda.reset_peak_memory_stats()
    y = forward_normal(x.clone().requires_grad_(True), layers)
    y.sum().backward()
    mem_normal = torch.cuda.max_memory_allocated() / 1e6
    
    # Checkpoint forward
    torch.cuda.reset_peak_memory_stats()
    y = forward_checkpoint(x.clone().requires_grad_(True), layers)
    y.sum().backward()
    mem_checkpoint = torch.cuda.max_memory_allocated() / 1e6
    
    print(f"Peak memory (normal): {mem_normal:.1f} MB")
    print(f"Peak memory (checkpoint): {mem_checkpoint:.1f} MB")
    print(f"Savings: {(1 - mem_checkpoint/mem_normal)*100:.1f}%")
else:
    print("CUDA not available for memory comparison")
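Checkpointing saves memory by discarding a block's intermediate activations during the forward pass and recomputing them during backward. Counting forward calls makes the recomputation visible; `CountingBlock` below is a hypothetical module written for this purpose. Each checkpointed block still stores its own input, so with small batches, where parameters dominate memory, the measured savings can be modest.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingBlock(nn.Module):
    """Hypothetical block that counts how many times forward() runs."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return torch.relu(self.linear(x))

block = CountingBlock(16)
x = torch.randn(4, 16, requires_grad=True)

# Plain: activations are saved during forward, so forward runs once
block(x).sum().backward()
calls_plain = block.calls

# Checkpointed: activations are dropped and forward is re-run during backward
block.calls = 0
checkpoint(block, x, use_reentrant=False).sum().backward()
calls_ckpt = block.calls

print(calls_plain, calls_ckpt)  # 1 2
```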

## Summary

Key PyTorch internals explored:

1. **TensorImpl** - Core data structure with sizes, strides, storage
2. **Storage** - Actual memory, shared by views
3. **Autograd** - Dynamic computation graph with grad_fn nodes
4. **Dispatcher** - Routes each operator call to a kernel based on its dispatch keys
5. **Profiler** - Understand performance characteristics
6. **Memory** - Caching allocator, checkpointing

In [None]:
print("""
╔══════════════════════════════════════════════════════════════════╗
║                  PYTORCH INTERNALS SUMMARY                       ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                  ║
║  TENSOR = TensorImpl + Storage                                   ║
║    • sizes, strides define logical view                          ║
║    • storage holds actual data                                   ║
║    • views share storage                                         ║
║                                                                  ║
║  AUTOGRAD = Dynamic computation graph                            ║
║    • grad_fn links to backward function                          ║
║    • gradients accumulate in leaf tensors                        ║
║    • custom functions via torch.autograd.Function                ║
║                                                                  ║
║  DISPATCHER = Multi-level routing                                ║
║    • DispatchKey identifies functionality                        ║
║    • Priority determines execution order                         ║
║    • __torch_dispatch__ for Python hooks                         ║
║                                                                  ║
║  MEMORY = Caching allocator                                      ║
║    • Avoids cudaMalloc overhead                                  ║
║    • empty_cache() returns to system                             ║
║    • Checkpointing trades compute for memory                     ║
║                                                                  ║
╚══════════════════════════════════════════════════════════════════╝
""")