# Module 1.3: nn.Module Architecture

`nn.Module` is the base class for all neural network components in PyTorch. Understanding its internals is essential for:
- Building custom layers and architectures
- Debugging model behavior
- Properly managing model state (training vs evaluation)
- Implementing advanced patterns (hooks, custom initialization)

## Learning Objectives
- Understand how `nn.Module` tracks parameters and submodules
- Master parameter registration and access patterns
- Use forward and backward hooks effectively
- Properly handle train/eval modes and their implications
- Implement custom layers with learnable parameters
- Apply proper weight initialization strategies

---

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Report the installed PyTorch build so readers can compare with their own.
torch_version = torch.__version__
print(f"PyTorch version: {torch_version}")

---
## 1. The nn.Module Basics

Every neural network in PyTorch inherits from `nn.Module`. Let's understand what it provides.

In [None]:
# A minimal module: no parameters, only a forward computation.
class SimpleModule(nn.Module):
    """Stateless module that doubles its input."""

    def __init__(self):
        # Must run first so nn.Module can set up its internal bookkeeping.
        super().__init__()

    def forward(self, x):
        """Return ``x`` multiplied by two."""
        doubled = x * 2
        return doubled


x = torch.tensor([1.0, 2.0, 3.0])
model = SimpleModule()
# Calling the module itself dispatches to forward() through __call__.
y = model(x)
print(f"Output: {y}")

In [None]:
# Why call the module (model(x)) instead of model.forward(x)?
# model(x) goes through nn.Module.__call__, which wraps forward() and also:
# 1. Runs forward pre-hooks before forward() and forward hooks after it
# 2. Sets up the bookkeeping that full backward hooks rely on
# Calling forward() directly bypasses this wrapper, so registered hooks are
# silently skipped.

# Prefer model(x) over model.forward(x) in user code
print("Always use model(x), not model.forward(x)")

### 1.1 What nn.Module Tracks

In [None]:
# Submodules, parameters and buffers each live in their own internal
# dictionary on the module; plain attributes are stored like on any object.
class DemoModule(nn.Module):
    """Holds one attribute of each kind to show what nn.Module registers."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)                      # goes into _modules
        self.my_param = nn.Parameter(torch.randn(3))        # goes into _parameters
        self.register_buffer('my_buffer', torch.zeros(4))   # goes into _buffers
        self.some_value = 42                                # ordinary attribute

    def forward(self, x):
        """Apply the linear submodule to ``x``."""
        projected = self.linear(x)
        return projected


model = DemoModule()

print("Tracked by nn.Module:")
print(f"  _modules: {list(model._modules)}")
print(f"  _parameters: {list(model._parameters)}")
print(f"  _buffers: {list(model._buffers)}")
print(f"\nRegular attribute (not tracked): some_value = {model.some_value}")

---
## 2. Parameters

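Parameters are tensors that a module registers so that optimizers, `state_dict()`, and `.to(device)` can find them. They require gradients by default, although individual parameters can be frozen (see Section 2.3).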

### 2.1 nn.Parameter

In [None]:
# nn.Parameter subclasses Tensor; a module registers it when it is assigned
# as one of the module's attributes.
initial_values = torch.randn(3, 4)
param = nn.Parameter(initial_values)

is_tensor = isinstance(param, torch.Tensor)
print(f"Is a tensor: {is_tensor}")
print(f"requires_grad: {param.requires_grad}")  # defaults to True
print(f"Shape: {param.shape}")

In [None]:
# Assigning an nn.Parameter to an attribute is all it takes to register it
class MyLayer(nn.Module):
    """Hand-written affine layer computing ``x @ weight.T + bias``."""

    def __init__(self, in_features, out_features):
        super().__init__()
        weight_shape = (out_features, in_features)
        # Both assignments register the tensors as parameters of this module
        self.weight = nn.Parameter(torch.randn(weight_shape))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Project ``x`` with the weight matrix, then add the bias."""
        projected = x @ self.weight.T
        return projected + self.bias


layer = MyLayer(10, 5)
print("Registered parameters:")
for name, param in layer.named_parameters():
    shape = param.shape
    print(f"  {name}: {shape}")

In [None]:
# A plain tensor attribute is invisible to nn.Module's parameter registry
class BrokenLayer(nn.Module):
    """Counter-example: stores its weight as a plain tensor, not a Parameter."""

    def __init__(self):
        super().__init__()
        # requires_grad=True alone does not make this a registered parameter
        plain_tensor = torch.randn(5, 10, requires_grad=True)
        self.weight = plain_tensor

    def forward(self, x):
        """Multiply ``x`` by the transposed (unregistered) weight."""
        return torch.matmul(x, self.weight.T)


broken = BrokenLayer()
print(f"Number of parameters: {len(list(broken.parameters()))}")
print("The weight won't be optimized or saved!")

### 2.2 Accessing Parameters

In [None]:
# A three-layer perceptron used by the parameter-access examples
class MLP(nn.Module):
    """Fully connected network 10 -> 20 -> 15 -> 5 with ReLU between layers."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 15)
        self.layer3 = nn.Linear(15, 5)

    def forward(self, x):
        """Run ``x`` through both hidden layers and the linear output layer."""
        hidden = x
        for hidden_layer in (self.layer1, self.layer2):
            hidden = F.relu(hidden_layer(hidden))
        # No activation on the final layer
        return self.layer3(hidden)


model = MLP()

In [None]:
# parameters() walks the whole module tree and yields every parameter
print("All parameters:")
total_params = 0
for i, param in enumerate(model.parameters()):
    print(f"  {i}: {param.shape}")
    # Accumulate the element count while we are iterating anyway
    total_params += param.numel()

print(f"\nTotal: {total_params} parameters")

In [None]:
# named_parameters() yields (dotted name, parameter) pairs
print("Named parameters:")
for name, param in model.named_parameters():
    shape = param.shape
    print(f"  {name}: {shape}")

In [None]:
# Attribute access reaches a specific parameter directly
first_layer = model.layer1
print(f"layer1.weight shape: {first_layer.weight.shape}")
print(f"layer1.bias shape: {first_layer.bias.shape}")

# get_parameter() takes the dotted name, which is handy in generic code
param = model.get_parameter('layer2.weight')
print(f"\nget_parameter('layer2.weight'): {param.shape}")

In [None]:
# Filtering parameters (e.g., for different learning rates)
def get_layer_groups(model, early_layer_names=('layer1', 'layer2')):
    """Split a model's parameters into "early" and "late" groups.

    A parameter is "early" when any dot-separated segment of its qualified
    name (as produced by ``model.named_parameters()``) equals one of
    ``early_layer_names``. Whole segments are compared, so ``layer1`` does not
    accidentally capture ``layer10``.

    Args:
        model: module whose parameters are grouped.
        early_layer_names: submodule names treated as early layers; a single
            string is accepted as a one-element collection.

    Returns:
        Tuple ``(early_layers, late_layers)`` of parameter lists, each in
        ``named_parameters()`` order.
    """
    if isinstance(early_layer_names, str):
        # Avoid iterating a lone string character by character
        early_layer_names = (early_layer_names,)
    early_names = frozenset(early_layer_names)

    early_layers = []
    late_layers = []
    for name, param in model.named_parameters():
        is_early = not early_names.isdisjoint(name.split('.'))
        (early_layers if is_early else late_layers).append(param)

    return early_layers, late_layers

early, late = get_layer_groups(model)
# Report the size of each group
for label, group in (("Early", early), ("Late", late)):
    print(f"{label} layers: {len(group)} tensors")

# Use with optimizer for differential learning rates:
# optimizer = torch.optim.Adam([
#     {'params': early, 'lr': 1e-4},
#     {'params': late, 'lr': 1e-3}
# ])

### 2.3 Freezing Parameters

In [None]:
# A parameter is frozen by turning off its requires_grad flag
model = MLP()

# Freeze every parameter of layer1 (weight and bias)
for frozen in model.layer1.parameters():
    frozen.requires_grad = False

print("After freezing layer1:")
for name, param in model.named_parameters():
    flag = param.requires_grad
    print(f"  {name}: requires_grad = {flag}")

In [None]:
# Hand the optimizer only the parameters that still require gradients.
# The generator is consumed by the optimizer constructor.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

# A list is reusable and can be counted
trainable = []
for candidate in model.parameters():
    if candidate.requires_grad:
        trainable.append(candidate)
print(f"Trainable parameters: {len(trainable)}")

---
## 3. Submodules

Modules can contain other modules, creating a tree structure.

In [None]:
# Modules assigned as attributes become registered submodules
class Block(nn.Module):
    """Residual block: linear -> ReLU, add the input back, then LayerNorm."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``norm(relu(linear(x)) + x)``."""
        activated = F.relu(self.linear(x))
        residual_sum = activated + x  # skip connection
        return self.norm(residual_sum)


class Network(nn.Module):
    """Embedding layer, a stack of ``Block`` modules, and an output head."""

    def __init__(self, dim, num_blocks):
        super().__init__()
        self.embed = nn.Linear(10, dim)
        # ModuleList (not a plain list) so every block is registered
        self.blocks = nn.ModuleList(Block(dim) for _ in range(num_blocks))
        self.head = nn.Linear(dim, 5)

    def forward(self, x):
        """Embed ``x``, run it through each block in order, then the head."""
        hidden = self.embed(x)
        for stage in self.blocks:
            hidden = stage(hidden)
        return self.head(hidden)


model = Network(dim=64, num_blocks=3)
print(model)

In [None]:
# Iterating over submodules
# named_modules() is recursive: it yields the root module first (under an
# empty name) and then every descendant with its dotted path.
print("All modules, recursively (named_modules()):")
for name, module in model.named_modules():
    print(f"  {name or '(root)'}: {module.__class__.__name__}")

In [None]:
# children() stops at the first level; modules() descends the whole tree
print("children() - direct children only:")
for name, child in model.named_children():
    kind = type(child).__name__
    print(f"  {name}: {kind}")

print("\nmodules() - all modules recursively (shown above)")

### 3.1 ModuleList vs Python List

In [None]:
# nn.ModuleList registers the layers it holds
class GoodModel(nn.Module):
    """Three stacked 10x10 linear layers held in an ``nn.ModuleList``."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(3):
            self.layers.append(nn.Linear(10, 10))

    def forward(self, x):
        """Apply the layers one after another."""
        out = x
        for step in self.layers:
            out = step(out)
        return out


# A plain Python list hides its layers from nn.Module
class BadModel(nn.Module):
    """Same layers as ``GoodModel`` but stored in an unregistered list."""

    def __init__(self):
        super().__init__()
        self.layers = []  # plain list: nothing in it gets registered
        for _ in range(3):
            self.layers.append(nn.Linear(10, 10))

    def forward(self, x):
        """Apply the layers one after another."""
        out = x
        for step in self.layers:
            out = step(out)
        return out


good = GoodModel()
bad = BadModel()

for label, net in (("GoodModel", good), ("BadModel", bad)):
    print(f"{label} parameters: {sum(p.numel() for p in net.parameters())}")
print("\nBadModel layers won't be saved, loaded to GPU, or optimized!")

In [None]:
# nn.ModuleDict registers submodules under string keys
class MultiHeadModel(nn.Module):
    """Shared backbone followed by one of several named output heads."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 20)
        head_layers = {
            'classification': nn.Linear(20, 10),
            'regression': nn.Linear(20, 1),
            'embedding': nn.Linear(20, 64)
        }
        self.heads = nn.ModuleDict(head_layers)

    def forward(self, x, head_name):
        """Run the backbone with ReLU, then the head stored under ``head_name``."""
        shared = F.relu(self.backbone(x))
        selected_head = self.heads[head_name]
        return selected_head(shared)


model = MultiHeadModel()
x = torch.randn(5, 10)
classification_out = model(x, 'classification')
print(f"Classification output: {classification_out.shape}")
regression_out = model(x, 'regression')
print(f"Regression output: {regression_out.shape}")

---
## 4. Buffers

Buffers are tensors that should be part of the module state but don't require gradients (not optimized).

In [None]:
# Use cases for buffers:
# - Running statistics (BatchNorm)
# - Position encodings
# - Masks that should move with the model

class BatchNormLike(nn.Module):
    """Simplified batch normalization that keeps running statistics in buffers.

    Statistics are reduced over dim 0 of the input. In training mode the batch
    is normalized with its own mean and biased variance, and the running
    buffers are updated with an exponential moving average. In eval mode the
    stored running statistics are used instead.

    Args:
        num_features: length of the per-feature parameter and buffer vectors.
        eps: constant added to the variance before the square root.
        momentum: weight of the current batch in the running-average update.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum

        # Learnable parameters
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        # Buffers - saved with model, move with .to(device), but not optimized
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0))

    def forward(self, x):
        """Normalize ``x`` per feature and apply the learnable scale and shift."""
        if self.training:
            mean = x.mean(dim=0)
            # Biased variance for normalization: it stays finite (zero) for a
            # batch of one, where the unbiased estimate would be NaN.
            var = x.var(dim=0, unbiased=False)
            self._update_running_stats(mean, var, batch_size=x.shape[0])
        else:
            mean = self.running_mean
            var = self.running_var

        x_norm = (x - mean) / (var + self.eps).sqrt()
        return self.weight * x_norm + self.bias

    def _update_running_stats(self, mean, var, batch_size):
        """Fold the batch statistics into the running buffers, outside autograd."""
        # no_grad + in-place updates keep the buffers free of any autograd
        # history; otherwise every batch would stay chained to the previous
        # ones whenever the input requires grad.
        with torch.no_grad():
            if batch_size > 1:
                # Track the unbiased estimate, as nn.BatchNorm1d does
                var = var * batch_size / (batch_size - 1)
            self.running_mean.mul_(1 - self.momentum).add_(mean, alpha=self.momentum)
            self.running_var.mul_(1 - self.momentum).add_(var, alpha=self.momentum)
            self.num_batches_tracked.add_(1)

layer = BatchNormLike(10)
print("Parameters (optimized):")
for name, param in layer.named_parameters():
    print(f"  {name}: {param.shape}")

print("\nBuffers (not optimized):")
for name, buf in layer.named_buffers():
    print(f"  {name}: {buf.shape}")

In [None]:
# Buffers follow the module when it changes device
print(f"Before: running_mean device = {layer.running_mean.device}")

cuda_ready = torch.cuda.is_available()
if cuda_ready:
    layer = layer.cuda()
    print(f"After .cuda(): running_mean device = {layer.running_mean.device}")
    # Move back so the rest of the notebook runs on the CPU
    layer = layer.cpu()

# Persistent buffers show up in the state_dict next to the parameters
state = layer.state_dict()
print(f"\nState dict keys: {list(state)}")

In [None]:
# A buffer can be excluded from state_dict with persistent=False
class LayerWithBuffers(nn.Module):
    """Registers one persistent and one non-persistent buffer."""

    def __init__(self):
        super().__init__()
        buffer_specs = (
            ('saved_buffer', True),    # persistent: included in state_dict
            ('temp_buffer', False),    # non-persistent: left out of state_dict
        )
        for buffer_name, keep in buffer_specs:
            self.register_buffer(buffer_name, torch.zeros(5), persistent=keep)


layer = LayerWithBuffers()
saved_keys = list(layer.state_dict())
print(f"state_dict keys: {saved_keys}")
print("Note: temp_buffer is not in state_dict")

---
## 5. Train vs Eval Mode

Modules can behave differently during training vs inference.

In [None]:
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),
    nn.Dropout(0.5),
    nn.Linear(20, 5)
)

print(f"Default training mode: {model.training}")

# Switch to evaluation mode, then back to training mode
for mode_name, switch_mode in (("eval", model.eval), ("train", model.train)):
    switch_mode()
    print(f"After {mode_name}(): {model.training}")

In [None]:
# The mode switch reaches every submodule, not only the container
model.eval()
for name, module in model.named_modules():
    if not name:
        continue  # the root module has an empty name; skip it
    print(f"{name}: training = {module.training}")

### 5.1 Modules Affected by train/eval

In [None]:
# Dropout only drops values while the module is in training mode
x = torch.ones(10)
dropout = nn.Dropout(p=0.5)

dropout.train()
train_output = dropout(x)
print(f"Training mode: {train_output}")

dropout.eval()
eval_output = dropout(x)  # identical to the input
print(f"Eval mode: {eval_output}")

In [None]:
# BatchNorm: uses batch stats (train) vs running stats (eval)
bn = nn.BatchNorm1d(3)

# Fake training to build running stats
bn.train()
for _ in range(100):
    x = torch.randn(32, 3) * 2 + 5  # Mean ~5, Std ~2
    _ = bn(x)

print(f"Running mean: {bn.running_mean}")
print(f"Running var: {bn.running_var}")

# In eval mode, uses these stats instead of batch stats.
# The check uses a large batch: averaging over a single sample (three values)
# is far too noisy to show that the output is centred at 0 with unit spread.
bn.eval()
with torch.no_grad():
    x_test = torch.randn(1000, 3) * 2 + 5
    y = bn(x_test)
print(f"\nEval output mean (should be ~0): {y.mean():.4f}")
print(f"Eval output std (should be ~1): {y.std():.4f}")

In [None]:
# Common pattern: combine eval() with no_grad()
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(20, 5)
)

def inference(model, x):
    """Run ``model`` on ``x`` in eval mode without tracking gradients.

    The train/eval flag of the model and of every submodule is restored to
    what it was before the call, even if the forward pass raises. A model that
    was already in eval mode is therefore not switched into training mode.

    Args:
        model: module to evaluate.
        x: input passed to ``model``.

    Returns:
        The model output, computed under ``torch.no_grad()``.
    """
    # Remember each module's own flag so mixed setups (e.g. a frozen
    # submodule kept in eval mode) are restored exactly.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()  # Disable dropout, use running stats
    try:
        with torch.no_grad():  # Don't compute gradients
            return model(x)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

x = torch.randn(5, 10)
predictions = inference(model, x)
print(f"Predictions shape: {predictions.shape}")

---
## 6. Hooks

Hooks let you execute custom code during forward or backward passes without modifying the module.

### 6.1 Forward Hooks

In [None]:
# A forward hook is called as hook(module, input, output); returning None
# leaves the output unchanged, returning a value replaces it.

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

# Filled in by the hooks during the forward pass
activations = {}

def save_activation(name):
    """Build a forward hook that stores the detached output under ``name``."""
    def hook(module, input, output):
        snapshot = output.detach()
        activations[name] = snapshot
    return hook

# Attach one recording hook per layer of the Sequential
for position, label in enumerate(('linear1', 'relu', 'linear2')):
    model[position].register_forward_hook(save_activation(label))

# The hooks fire while this forward pass runs
x = torch.randn(5, 10)
output = model(x)

print("Captured activations:")
for name, act in activations.items():
    shape = act.shape
    print(f"  {name}: {shape}")

In [None]:
# A forward hook can replace the module output by returning a tensor
def add_noise_hook(module, input, output):
    """Return ``output`` plus Gaussian noise scaled by 0.1."""
    perturbation = 0.1 * torch.randn_like(output)
    return output + perturbation  # the returned tensor replaces the output

model = nn.Linear(10, 5)
handle = model.register_forward_hook(add_noise_hook)

x = torch.randn(3, 10)
y1, y2 = model(x), model(x)
outputs_differ = not torch.allclose(y1, y2)
print(f"Outputs differ due to noise: {outputs_differ}")

# Detach the hook again through the handle returned at registration
handle.remove()
y3, y4 = model(x), model(x)
outputs_same = torch.allclose(y3, y4)
print(f"After removing hook, outputs same: {outputs_same}")

### 6.2 Backward Hooks

In [None]:
# Full backward hook signature: hook(module, grad_input, grad_output) -> tuple or None

gradients = {}

def _detach_all(grads):
    """Detach every gradient in ``grads``, keeping ``None`` placeholders."""
    return [g.detach() if g is not None else None for g in grads]

def save_gradients(name):
    """Build a full backward hook that records gradients under ``name``.

    The recorded entry is a dict with keys ``'input'`` and ``'output'``, each
    a list of detached tensors. ``None`` entries are preserved in both lists,
    so the hook does not fail when a gradient slot is empty.
    """
    def hook(module, grad_input, grad_output):
        gradients[name] = {
            'input': _detach_all(grad_input),
            'output': _detach_all(grad_output)
        }
    return hook

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

# Register backward hooks
model[0].register_full_backward_hook(save_gradients('linear1'))
model[2].register_full_backward_hook(save_gradients('linear2'))

# Forward and backward
x = torch.randn(5, 10)
output = model(x)
loss = output.sum()
loss.backward()

print("Captured gradients:")
for name, grads in gradients.items():
    print(f"  {name}:")
    print(f"    grad_output: {[None if g is None else g.shape for g in grads['output']]}")

In [None]:
# Use case: Gradient clipping per layer
#
# A module-level backward hook only sees gradients with respect to the
# module's inputs and outputs. Parameter gradients (e.g. model.weight.grad)
# do not pass through it; to rescale those, register a hook on the parameter
# tensor itself.

def _clip_to_norm(grad, max_norm):
    """Return ``grad`` rescaled to norm ``max_norm`` if it is larger; pass ``None`` through."""
    if grad is None:
        return None
    norm = grad.norm()
    if norm > max_norm:
        return grad * max_norm / norm
    return grad

def clip_gradient_hook(max_norm):
    """Build a full backward hook that clips each input gradient to ``max_norm``.

    The returned hook has the signature ``hook(module, grad_input, grad_output)``
    and returns a tuple that replaces ``grad_input``. ``None`` entries (inputs
    that do not require grad) are passed through unchanged.
    """
    def hook(module, grad_input, grad_output):
        return tuple(_clip_to_norm(g, max_norm) for g in grad_input)
    return hook

model = nn.Linear(10, 5)
# Clips the gradient flowing back to the layer input
model.register_full_backward_hook(clip_gradient_hook(max_norm=1.0))
# Clips the weight gradient before it is stored in model.weight.grad
model.weight.register_hook(lambda grad: _clip_to_norm(grad, 1.0))

# Large input -> large weight gradient. The input must require grad,
# otherwise grad_input is None and the module hook has nothing to clip.
x = (torch.randn(3, 10) * 100).requires_grad_()
y = model(x)
y.sum().backward()

# Both gradients should have norm <= 1.0
print(f"Input gradient norm: {x.grad.norm():.4f}")
print(f"Weight gradient norm: {model.weight.grad.norm():.4f}")

### 6.3 Forward Pre-Hook

In [None]:
# A forward pre-hook is called as hook(module, input) before forward();
# returning a value replaces the input.

def normalize_input_hook(module, input):
    """Standardize the first positional input over all of its elements."""
    first_arg = input[0]  # positional inputs arrive packed in a tuple
    centered = first_arg - first_arg.mean()
    spread = first_arg.std() + 1e-5
    # Hand back a one-element tuple as the new positional inputs
    return (centered / spread,)

model = nn.Linear(10, 5)
model.register_forward_pre_hook(normalize_input_hook)

# Deliberately badly scaled input; the pre-hook rescales it before forward()
x = torch.randn(3, 10) * 100 + 50
y = model(x)

print(f"Original input mean: {x.mean():.2f}, std: {x.std():.2f}")
print("Input was normalized by pre-hook before linear layer")

---
## 7. Weight Initialization

Proper initialization is crucial for training deep networks.

In [None]:
# Built-in initializers, applied one after another to the same weight
linear = nn.Linear(100, 50)

relu_options = {'mode': 'fan_in', 'nonlinearity': 'relu'}
schemes = (
    # Xavier/Glorot initialization (good for tanh/sigmoid)
    ("Xavier uniform", nn.init.xavier_uniform_, {}),
    ("Xavier normal", nn.init.xavier_normal_, {}),
    # Kaiming/He initialization (good for ReLU)
    ("Kaiming uniform", nn.init.kaiming_uniform_, relu_options),
    ("Kaiming normal", nn.init.kaiming_normal_, relu_options),
)

for label, initializer, options in schemes:
    initializer(linear.weight, **options)
    print(f"{label} std: {linear.weight.std():.4f}")

In [None]:
# More initializers; each call overwrites the tensor in place
linear = nn.Linear(100, 50)
layer_bias = linear.bias
layer_weight = linear.weight

nn.init.zeros_(layer_bias)                           # every entry 0
nn.init.ones_(layer_weight)                          # every entry 1
nn.init.constant_(layer_bias, 0.1)                   # every entry 0.1
nn.init.normal_(layer_weight, mean=0, std=0.02)      # Gaussian
nn.init.uniform_(layer_weight, a=-0.1, b=0.1)        # uniform on [-0.1, 0.1]
nn.init.orthogonal_(layer_weight)                    # orthogonal matrix

print("Various initializations applied")

In [None]:
# Initialize entire network with apply()
def init_weights(module):
    """Initialize one module in place; intended for ``model.apply(init_weights)``.

    - ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU)
      and a zero bias when a bias exists.
    - ``nn.BatchNorm1d``: weight set to ones and bias to zeros when the layer
      has affine parameters (they are ``None`` with ``affine=False``).
    - Any other module is left untouched.
    """
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm1d):
        # Guard against affine=False, where both tensors are None
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

# Apply to all submodules
model.apply(init_weights)

print("Initialization applied to all layers:")
print(f"  Linear weight std: {model[0].weight.std():.4f}")
print(f"  Linear bias: {model[0].bias[:3]}")
print(f"  BatchNorm weight: {model[1].weight[:3]}")

---
## 8. Saving and Loading

Understanding state_dict is essential for model persistence.

In [None]:
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),
    nn.Linear(20, 5)
)

# The state_dict maps names to tensors for every parameter and persistent buffer
state = model.state_dict()
print("State dict keys:")
for key, tensor in state.items():
    print(f"  {key}: {tensor.shape}")

In [None]:
# Save and load
import tempfile
import os

# Work inside a temporary directory: it is removed when the block exits,
# even if saving or loading raises, so no checkpoint file is left behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'model.pt')

    # Save state dict
    torch.save(model.state_dict(), path)
    print(f"Saved to {path}")

    # Create new model and load
    model2 = nn.Sequential(
        nn.Linear(10, 20),
        nn.BatchNorm1d(20),
        nn.Linear(20, 5)
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs and avoids running arbitrary pickled code.
    model2.load_state_dict(torch.load(path, weights_only=True))
    print("Loaded successfully")

# Verify
print(f"Weights match: {torch.allclose(model[0].weight, model2[0].weight)}")

In [None]:
# Partial loading with strict=False
# Useful for transfer learning

# Original model
model1 = nn.Sequential(
    nn.Linear(10, 20),
    nn.Linear(20, 5)
)

# New model with extra layer
model2 = nn.Sequential(
    nn.Linear(10, 20),
    nn.Linear(20, 15),  # Different!
    nn.Linear(15, 5)    # Extra layer
)

# strict=False only tolerates missing and unexpected keys. A key that exists
# in both models with a different shape (here 1.weight and 1.bias) still
# raises a RuntimeError, so keep only the entries whose name AND shape match.
source_state = model1.state_dict()
target_state = model2.state_dict()
compatible = {
    key: tensor
    for key, tensor in source_state.items()
    if key in target_state and tensor.shape == target_state[key].shape
}
skipped = sorted(set(source_state) - set(compatible))

missing, unexpected = model2.load_state_dict(compatible, strict=False)
print(f"Skipped source keys (no matching name/shape): {skipped}")
print(f"Missing keys: {missing}")
print(f"Unexpected keys: {unexpected}")

---
## Exercises

### Exercise 1: Custom Linear Layer

Implement a custom linear layer from scratch with proper parameter registration.

In [None]:
class CustomLinear(nn.Module):
    """
    Exercise skeleton for a linear layer: output = input @ weight.T + bias

    Both methods are intentionally left as ``pass`` for the reader to fill in;
    until then the layer has no parameters and ``forward`` returns ``None``.
    
    Args:
        in_features: size of input
        out_features: size of output
        bias: whether to include bias term
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # YOUR CODE HERE:
        # 1. Create weight as nn.Parameter with shape (out_features, in_features)
        # 2. Create bias as nn.Parameter with shape (out_features,) if bias=True, else None
        # 3. Initialize weights using Kaiming uniform
        # 4. Initialize bias: zeros is fine; the reference solution in this
        #    notebook instead draws it from uniform(-1/sqrt(in_features), 1/sqrt(in_features))
        pass
    
    def forward(self, x):
        # YOUR CODE HERE
        # Compute x @ weight.T and add the bias only when one was created
        pass

# Test
# layer = CustomLinear(10, 5)
# x = torch.randn(3, 10)
# y = layer(x)
# print(f"Output shape: {y.shape}")  # Should be (3, 5)
# print(f"Parameters: {list(layer.named_parameters())}")

### Exercise 2: Feature Extraction with Hooks

Create a feature extractor that captures intermediate activations from a pretrained-style model.

In [None]:
class FeatureExtractor:
    """
    Extract features from specified layers of a model.

    Exercise skeleton: hook registration in ``__init__`` is left to the reader,
    so until it is implemented ``get_features()`` returns an empty dict.
    
    Usage:
        model = create_model()
        extractor = FeatureExtractor(model, ['layer1', 'layer2.relu'])
        output = model(x)
        features = extractor.get_features()  # Dict of layer_name -> activation
    """
    def __init__(self, model, layer_names):
        self.model = model
        self.layer_names = layer_names
        self.features = {}   # layer name -> captured activation
        self.handles = []    # hook handles, used by remove_hooks()
        
        # YOUR CODE HERE:
        # 1. For each layer_name, get the corresponding module
        # 2. Register a forward hook that saves the output
        # 3. Store the handle so we can remove it later
        pass
    
    def get_features(self):
        """Return the dict of captured activations (the live dict, not a copy)."""
        return self.features
    
    def remove_hooks(self):
        """Call ``remove()`` on every stored hook handle."""
        for handle in self.handles:
            handle.remove()

# Test
# model = nn.Sequential(
#     nn.Linear(10, 20),
#     nn.ReLU(),
#     nn.Linear(20, 5)
# )
# extractor = FeatureExtractor(model, ['0', '1'])
# output = model(torch.randn(3, 10))
# features = extractor.get_features()
# print(f"Extracted features: {list(features.keys())}")
# extractor.remove_hooks()

### Exercise 3: Custom Layer with Learned Normalization

Implement a layer that learns to normalize its input differently for each feature.

In [None]:
class LearnedNorm(nn.Module):
    """
    A layer that learns per-feature normalization.

    Exercise skeleton: parameter creation and ``forward`` are left to the
    reader; as written, ``forward`` returns ``None``.
    
    For each feature i:
        output[..., i] = (input[..., i] - learned_mean[i]) / (learned_std[i] + eps)
    
    Both mean and std are learnable parameters.
    
    Args:
        num_features: number of features
        eps: small constant for numerical stability
    """
    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # YOUR CODE HERE:
        # 1. Create learned_mean as Parameter, initialized to zeros
        # 2. Create learned_std as Parameter, initialized to ones
        #    (both of shape (num_features,), one entry per feature)
        pass
    
    def forward(self, x):
        # YOUR CODE HERE
        # Normalize: (x - mean) / (std + eps)
        # The parameters line up with the last dimension of x, per the formula above
        pass

# Test
# layer = LearnedNorm(10)
# x = torch.randn(5, 10) * 2 + 3  # Offset distribution
# y = layer(x)
# print(f"Output mean (before training): {y.mean():.4f}")
# print(f"Learnable parameters: {[n for n, p in layer.named_parameters()]}")

---
## Solutions

In [None]:
# Exercise 1 Solution
class CustomLinearSolution(nn.Module):
    """Linear layer ``x @ weight.T (+ bias)`` with registered parameters.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: when False, ``bias`` is registered as ``None`` and not added.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if not bias:
            # Keeps the attribute present (as None) in the parameter registry
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features))

        self._reset_parameters()

    def _reset_parameters(self):
        """Fill the uninitialized storage: Kaiming-uniform weight, uniform bias."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        # Bias range shrinks with the number of inputs feeding each output
        limit = 1 / math.sqrt(self.in_features)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Project ``x`` with the weight and add the bias when there is one."""
        projected = x @ self.weight.T
        return projected if self.bias is None else projected + self.bias


print("Exercise 1 Solution:")
layer = CustomLinearSolution(10, 5)
x = torch.randn(3, 10)
y = layer(x)
print(f"Output shape: {y.shape}")
parameter_names = [n for n, p in layer.named_parameters()]
print(f"Parameters: {parameter_names}")

In [None]:
# Exercise 2 Solution
class FeatureExtractorSolution:
    """Capture the outputs of selected submodules with forward hooks.

    Args:
        model: module whose submodules are observed.
        layer_names: dotted submodule paths as used by
            ``model.named_modules()`` (e.g. ``'0'`` or ``'blocks.1.linear'``).

    Raises:
        AttributeError: if a name does not resolve to a submodule. In that
            case no hook is left registered on the model.

    After each forward pass of ``model``, ``get_features()`` maps every layer
    name to the detached output of that layer.
    """

    def __init__(self, model, layer_names):
        self.model = model
        self.layer_names = layer_names
        self.features = {}
        self.handles = []

        # Resolve every name before registering anything, so a bad name fails
        # without leaving stray hooks behind. get_submodule() follows the
        # registered child names, which also covers digit-named children of
        # containers that are not indexable by integer (e.g. nn.ModuleDict).
        targets = [(name, model.get_submodule(name)) for name in layer_names]

        for name, module in targets:
            handle = module.register_forward_hook(self._make_hook(name))
            self.handles.append(handle)

    def _make_hook(self, name):
        """Build a forward hook that stores the detached output under ``name``."""
        def hook(module, input, output):
            self.features[name] = output.detach()
        return hook

    def get_features(self):
        """Return the dict of captured activations (the live dict, not a copy)."""
        return self.features

    def remove_hooks(self):
        """Detach all hooks from the model; safe to call more than once."""
        for handle in self.handles:
            handle.remove()
        # Drop the spent handles so the list reflects what is still attached
        self.handles.clear()

print("\nExercise 2 Solution:")
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)
extractor = FeatureExtractorSolution(model, ['0', '1'])
output = model(torch.randn(3, 10))
features = extractor.get_features()
print(f"Extracted features: {list(features.keys())}")
for name, feat in features.items():
    print(f"  {name}: {feat.shape}")
extractor.remove_hooks()

In [None]:
# Exercise 3 Solution
class LearnedNormSolution(nn.Module):
    """Per-feature normalization with a learnable mean and std.

    Computes ``(x - learned_mean) / (learned_std + eps)``; both parameter
    vectors have ``num_features`` entries and start at zeros and ones.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.learned_mean = nn.Parameter(torch.zeros(num_features))
        self.learned_std = nn.Parameter(torch.ones(num_features))

    def forward(self, x):
        """Shift ``x`` by the learned mean and divide by the learned std plus eps."""
        shifted = x - self.learned_mean
        scale = self.learned_std + self.eps
        return shifted / scale


print("\nExercise 3 Solution:")
layer = LearnedNormSolution(10)
x = torch.randn(5, 10) * 2 + 3
y = layer(x)
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {y.mean():.4f}, std: {y.std():.4f}")
learnable_names = [n for n, p in layer.named_parameters()]
print(f"Learnable parameters: {learnable_names}")

# The layer can learn to shift/scale appropriately during training
# to normalize inputs to a better range

---
## Summary

Key takeaways from this notebook:

1. **nn.Module** is the base class for all neural network components
2. **Parameters** are registered automatically when assigned as `nn.Parameter`
3. **Submodules** are registered automatically when assigned as attributes; collections of them must use `nn.ModuleList`/`nn.ModuleDict`, not Python lists/dicts
4. **Buffers** are for non-learnable state that should be saved and moved with the model
5. **train()/eval()** affects Dropout, BatchNorm, and other mode-dependent layers
6. **Hooks** enable inspection and modification of forward/backward passes
7. **Proper initialization** is crucial for training deep networks
8. **state_dict** enables saving, loading, and transfer learning

---
*Next: Module 2.1 - The Training Loop Deconstructed*