# PyTorch Modules, Parameters, Initialization

This notebook covers PyTorch's core building blocks: modules, parameters, and weight-initialization strategies, plus a few advanced module patterns.

## Table of Contents
1. [Basic Module Structure](#basic-module-structure)
2. [Parameter Counting Formulas](#parameter-counting-formulas)
3. [Initialization Strategies](#initialization-strategies)
4. [Advanced Module Patterns](#advanced-module-patterns)

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed both random number generators so repeated runs draw the same numbers
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

## Basic Module Structure

In [None]:
class SimpleNet(nn.Module):
    """Two-layer fully connected network with ReLU and dropout.

    Computes ``fc2(dropout(relu(fc1(x))))``.

    Args:
        input_size: Number of input features (``in_features`` of ``fc1``).
        hidden_size: Width of the hidden layer.
        output_size: Number of output features (``out_features`` of ``fc2``).
        dropout: Probability handed to ``nn.Dropout``. Defaults to 0.1,
            the rate that used to be hard-coded.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super().__init__()  # Always call parent constructor
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        x = F.relu(self.fc1(x))
        # nn.Dropout only drops activations in training mode; after
        # model.eval() it passes its input through unchanged
        x = self.dropout(x)
        return self.fc2(x)

# Instantiate the network and show its layer layout
model = SimpleNet(10, 20, 5)
print("Model structure:", model, sep="\n")

# One line per registered parameter: name, shape and element count
param_lines = [
    f"{name}: {param.shape} ({param.numel()} parameters)"
    for name, param in model.named_parameters()
]
print("\nModel parameters:")
print("\n".join(param_lines))

# The state dict is the mapping used for saving/loading
state_dict = model.state_dict()
print(f"\nState dict keys: {list(state_dict)}")

# Push a random batch of 3 samples through the model
x = torch.randn(3, 10)
output = model(x)
print(f"\nInput shape: {x.shape}, Output shape: {output.shape}")

## Parameter Counting Formulas

In [None]:
def count_parameters(layer):
    """Count the parameters of a layer from its configuration.

    ``nn.Linear``, ``nn.Embedding`` and ``nn.LSTM`` are counted with
    closed-form formulas; every other module falls back to summing
    ``numel()`` over ``layer.parameters()``.

    Args:
        layer: The module to inspect.

    Returns:
        int: Total number of parameters held by ``layer``.
    """
    if isinstance(layer, nn.Linear):
        # Linear(in_features, out_features): in*out weights + out biases
        weight_params = layer.in_features * layer.out_features
        bias_params = layer.out_features if layer.bias is not None else 0
        return weight_params + bias_params

    if isinstance(layer, nn.Embedding):
        # Embedding(num_embeddings, embedding_dim): num*dim
        return layer.num_embeddings * layer.embedding_dim

    if isinstance(layer, nn.LSTM):
        hidden_size = layer.hidden_size
        num_directions = 2 if layer.bidirectional else 1
        # proj_size is absent on old PyTorch releases; 0 means "no projection"
        proj_size = getattr(layer, 'proj_size', 0)
        # Size of the recurrent state that is fed back and passed upwards
        recurrent_size = proj_size if proj_size > 0 else hidden_size
        # Two bias vectors (input and hidden) of 4*hidden_size each, if enabled
        bias_params = 2 * 4 * hidden_size if layer.bias else 0

        total = 0
        for layer_idx in range(layer.num_layers):
            if layer_idx == 0:
                layer_input = layer.input_size
            else:
                # Later layers consume the concatenated outputs of every
                # direction of the layer below
                layer_input = recurrent_size * num_directions
            per_direction = (
                4 * hidden_size * layer_input        # input-to-hidden weights
                + 4 * hidden_size * recurrent_size   # hidden-to-hidden weights
                + bias_params
                + proj_size * hidden_size            # projection weights, if any
            )
            total += per_direction * num_directions
        return total

    return sum(p.numel() for p in layer.parameters())

# Check each closed-form count against what PyTorch itself reports
print("=== Parameter Counting Examples ===")

# Fully connected layer
linear = nn.Linear(100, 50)
calculated = count_parameters(linear)
actual = sum(p.numel() for p in linear.parameters())
print(
    "Linear(100, 50):",
    f"  Calculated: {calculated} (100*50 + 50 = {100*50 + 50})",
    f"  Actual: {actual}",
    f"  Match: {calculated == actual}",
    sep="\n",
)

# Embedding table
embedding = nn.Embedding(1000, 128)
calculated = count_parameters(embedding)
actual = sum(p.numel() for p in embedding.parameters())
print(
    "\nEmbedding(1000, 128):",
    f"  Calculated: {calculated} (1000*128 = {1000*128})",
    f"  Actual: {actual}",
    f"  Match: {calculated == actual}",
    sep="\n",
)

# Single-layer LSTM
lstm = nn.LSTM(64, 32, num_layers=1, batch_first=True)
calculated = count_parameters(lstm)
actual = sum(p.numel() for p in lstm.parameters())
print(
    "\nLSTM(64, 32, layers=1):",
    f"  Calculated: {calculated}",
    f"  Actual: {actual}",
    f"  Match: {calculated == actual}",
    sep="\n",
)

# Two stacked LSTM layers
lstm_multi = nn.LSTM(64, 32, num_layers=2, batch_first=True)
calculated_multi = count_parameters(lstm_multi)
actual_multi = sum(p.numel() for p in lstm_multi.parameters())
print(
    "\nLSTM(64, 32, layers=2):",
    f"  Calculated: {calculated_multi}",
    f"  Actual: {actual_multi}",
    f"  Match: {calculated_multi == actual_multi}",
    sep="\n",
)

In [None]:
# Parameter analysis for complete model
def analyze_model_parameters(model):
    """Print a per-parameter breakdown of ``model`` followed by totals.

    For every entry of ``model.named_parameters()`` the name, shape,
    element count and trainable/frozen status are printed. The totals
    section reports all, trainable and non-trainable parameter counts and
    an estimate of the memory held by the parameters.

    Args:
        model: Module whose parameters are listed.

    Returns:
        None. Everything is written to stdout.
    """
    total_params = 0
    trainable_params = 0
    total_bytes = 0

    print("Parameter breakdown:")
    print("-" * 50)

    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        # element_size() is the byte width of this tensor's dtype, so the
        # estimate stays right for half/double parameters, not just float32
        total_bytes += param_count * param.element_size()

        if param.requires_grad:
            trainable_params += param_count
            status = "Trainable"
        else:
            status = "Frozen"

        print(f"{name:20s}: {str(param.shape):15s} {param_count:8,d} ({status})")

    print("-" * 50)
    print(f"{'Total parameters:':<20s} {total_params:>15,d}")
    print(f"{'Trainable parameters:':<20s} {trainable_params:>15,d}")
    print(f"{'Non-trainable:':<20s} {total_params - trainable_params:>15,d}")

    # Memory estimate (rough): parameter storage only, no gradients/buffers
    param_memory_mb = total_bytes / (1024**2)
    print(f"{'Estimated memory:':<20s} {param_memory_mb:>12.1f} MB")

# Print the parameter breakdown for `model`, the SimpleNet built earlier
analyze_model_parameters(model)

# Separator line between this report and the next one
print("\n" + "="*60)

# Create a more complex model for comparison
class ComplexModel(nn.Module):
    """Embedding, 2-layer LSTM, multi-head self-attention and a classifier.

    The classifier is applied to the mean of the attention output over
    dimension 1.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=5000, embedding_dim=128)
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=256,
            num_heads=8,
            batch_first=True,
        )
        self.classifier = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Return the classifier output for the index tensor ``x``."""
        embedded = self.embedding(x)
        lstm_out, _state = self.lstm(embedded)
        # Self-attention: the LSTM output is query, key and value at once
        attended, _weights = self.attention(lstm_out, lstm_out, lstm_out)
        pooled = attended.mean(dim=1)  # Global average pooling
        return self.classifier(pooled)

# Build the larger model and print the same parameter breakdown for it
complex_model = ComplexModel()
print("Complex model analysis:")
analyze_model_parameters(complex_model)

## Initialization Strategies

In [None]:
# Different initialization strategies
def xavier_init(m):
    """Xavier (Glorot) uniform initialization - good for tanh/sigmoid activations.

    Intended for ``model.apply``: Linear and Embedding weights are
    re-drawn with ``xavier_uniform_``, Linear biases are zeroed, and every
    other module is left untouched.
    """
    if not isinstance(m, (nn.Linear, nn.Embedding)):
        return

    nn.init.xavier_uniform_(m.weight)

    # Only Linear layers carry a bias, and even then it is optional
    if isinstance(m, nn.Linear) and m.bias is not None:
        nn.init.zeros_(m.bias)

def kaiming_init(m):
    """Kaiming (He) initialization - good for ReLU activations.

    Intended for ``model.apply``: Linear weights get ``kaiming_uniform_``
    with the ReLU gain and a zero bias; Embedding weights are drawn from a
    normal distribution with std 0.1. Other modules are left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        if m.bias is None:
            return
        nn.init.zeros_(m.bias)
        return

    if isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, std=0.1)

def normal_init(m):
    """Simple normal initialization.

    Intended for ``model.apply``: Linear weights are drawn from N(0, 0.01)
    and their biases zeroed; Embedding weights are drawn from N(0, 0.1).
    Other modules are left untouched.
    """
    is_linear = isinstance(m, nn.Linear)
    if is_linear:
        weight_std = 0.01
    elif isinstance(m, nn.Embedding):
        weight_std = 0.1
    else:
        return

    nn.init.normal_(m.weight, mean=0, std=weight_std)

    if is_linear and m.bias is not None:
        nn.init.zeros_(m.bias)

# Compare initialization effects
def compare_initializations():
    """Print weight and output statistics for three initialization schemes.

    Three ``SimpleNet(100, 50, 10)`` networks are built, initialized with
    Xavier, Kaiming and plain-normal schemes respectively, and evaluated on
    one shared random batch. Results are printed; nothing is returned.
    """
    strategies = [
        ('Xavier (Glorot)', xavier_init),
        ('Kaiming (He)', kaiming_init),
        ('Normal (0.01)', normal_init),
    ]

    # Build all three networks first, then initialize them one by one
    networks = [SimpleNet(100, 50, 10) for _ in strategies]
    for net, (_, init_fn) in zip(networks, strategies):
        net.apply(init_fn)

    # One random batch of 32 samples, shared by every network
    batch = torch.randn(32, 100)

    print("Initialization comparison:")
    print("=" * 60)

    for (label, _), net in zip(strategies, networks):
        net.eval()  # Set to eval mode for consistent comparison

        with torch.no_grad():
            result = net(batch)

        print(f"\n{label}:")
        # Weight statistics of both linear layers
        for layer_label, weight in (("FC1", net.fc1.weight), ("FC2", net.fc2.weight)):
            w_mean = weight.mean().item()
            w_std = weight.std().item()
            print(f"  {layer_label} weight stats: mean={w_mean:.6f}, std={w_std:.6f}")
        print(f"  Output stats: mean={result.mean().item():.6f}, std={result.std().item():.6f}")
        print(f"  Output range: [{result.min().item():.6f}, {result.max().item():.6f}]")

# Print the side-by-side statistics for the three initialization schemes
compare_initializations()

# Demonstrate the impact of bad initialization
print("\n" + "=" * 60)
print("Effect of bad initialization:")

def bad_init(m):
    """Intentionally bad initialization.

    Fills every Linear weight with the constant 10.0 and every Linear bias
    with 0.0; modules of any other type are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.constant_(m.weight, 10.0)  # Too large!
    has_bias = m.bias is not None
    if has_bias:
        nn.init.constant_(m.bias, 0.0)

model_bad = SimpleNet(10, 20, 5)
model_bad.apply(bad_init)
# Evaluate in eval mode, as compare_initializations() does, so the reported
# statistics reflect the weights alone and not random dropout masks
model_bad.eval()

x = torch.randn(5, 10)
with torch.no_grad():
    output_bad = model_bad(x)

print("Bad initialization output stats:")
print(f"  Mean: {output_bad.mean().item():.2f}")
print(f"  Std: {output_bad.std().item():.2f}")
print(f"  Range: [{output_bad.min().item():.2f}, {output_bad.max().item():.2f}]")
print(f"  Contains NaN: {torch.isnan(output_bad).any().item()}")
print("\nNote: Large values can lead to vanishing/exploding gradients!")

In [None]:
# Visualize weight distributions after initialization
def visualize_weight_distributions():
    """Plot histograms of Linear weights under three initialization schemes.

    Three ``nn.Linear(100, 100)`` layers are initialized with Xavier
    uniform, Kaiming uniform (ReLU gain) and normal(std=0.01). Each weight
    matrix is drawn as a density histogram with its mean marked, and a
    short summary is printed afterwards. Nothing is returned.
    """
    # Create models with different initializations
    model_xavier = nn.Linear(100, 100)
    model_kaiming = nn.Linear(100, 100)
    model_normal = nn.Linear(100, 100)

    nn.init.xavier_uniform_(model_xavier.weight)
    nn.init.kaiming_uniform_(model_kaiming.weight, nonlinearity='relu')
    nn.init.normal_(model_normal.weight, std=0.01)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Flatten each weight matrix to a 1-D NumPy array for plt.hist
    weights = {
        'Xavier': model_xavier.weight.detach().numpy().flatten(),
        'Kaiming': model_kaiming.weight.detach().numpy().flatten(),
        'Normal (0.01)': model_normal.weight.detach().numpy().flatten()
    }

    for ax, (name, weight) in zip(axes, weights.items()):
        ax.hist(weight, bins=50, alpha=0.7, density=True)
        ax.set_xlabel('Weight Value')
        ax.set_ylabel('Density')
        ax.grid(True, alpha=0.3)

        # Add statistics: mark the mean and put mean/std in the title
        mean, std = weight.mean(), weight.std()
        ax.axvline(mean, color='red', linestyle='--', label=f'Mean: {mean:.4f}')
        ax.legend()
        ax.set_title(f'{name}\nMean: {mean:.4f}, Std: {std:.4f}')

    plt.tight_layout()
    plt.show()

    print("Key insights:")
    print("• Xavier: Balanced for tanh/sigmoid activations")
    print("• Kaiming: Wider distribution for ReLU activations")
    print("• Normal (0.01): Very narrow, might cause vanishing gradients")

# Draw the three histograms and print the summary
visualize_weight_distributions()

## Advanced Module Patterns

In [None]:
# Custom module with learnable parameters
class CustomLinear(nn.Module):
    """Custom linear layer to demonstrate parameter creation.

    Computes ``x @ weight.T + bias`` with a ``weight`` of shape
    ``(out_features, in_features)`` and an optional ``bias`` of shape
    ``(out_features,)``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: When False, no bias parameter is created. Defaults to True.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Allocate uninitialized storage; reset_parameters() fills it in,
        # so drawing random values here would be wasted work
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Register as None so it doesn't appear in parameters()
            self.register_parameter('bias', None)

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight (Kaiming uniform) and the bias (uniform)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # For a (out_features, in_features) weight the fan-in is simply
            # in_features, so no private nn.init helper is needed
            fan_in = self.in_features
            # Guard against a zero-width input, which would divide by zero
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the linear transformation to ``x``."""
        # Manual implementation of linear transformation
        output = torch.matmul(x, self.weight.t())
        if self.bias is not None:
            output = output + self.bias
        return output

    def extra_repr(self):
        """Return the argument summary shown inside ``repr(module)``."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

# Put the hand-written layer next to the built-in one
import math

custom_linear = CustomLinear(10, 5)
builtin_linear = nn.Linear(10, 5)

print("Custom linear layer:", custom_linear, sep="\n")
custom_param_count = sum(p.numel() for p in custom_linear.parameters())
print(f"Parameters: {custom_param_count}")

# Feed the same random batch through both layers; each layer holds its own
# random weights, so only the output shapes are reported
x = torch.randn(3, 10)
output_custom, output_builtin = custom_linear(x), builtin_linear(x)

print(f"\nOutput shapes - Custom: {output_custom.shape}, Built-in: {output_builtin.shape}")
print("Custom and built-in linear layers work equivalently!")

In [None]:
# Module with submodules and parameter sharing
class ModularNet(nn.Module):
    """Demonstrate modular architecture and parameter sharing.

    The input is projected to ``hidden_size``, passed through
    ``num_layers`` residual stages and projected back to ``input_size``.
    Every stage applies the same ``shared_transform`` followed by its own
    Linear layer and LayerNorm.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Maps the input into the hidden width
        self.input_proj = nn.Linear(input_size, hidden_size)

        # A single Linear reused by every stage (parameter sharing)
        self.shared_transform = nn.Linear(hidden_size, hidden_size)

        # One private Linear per stage
        self.layer_transforms = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )

        # One LayerNorm per stage
        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden_size) for _ in range(num_layers)
        )

        # Maps the hidden width back to the input width
        self.output_proj = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """Run ``x`` through the projection, all stages and the output layer."""
        hidden = F.relu(self.input_proj(x))

        for idx in range(self.num_layers):
            transform = self.layer_transforms[idx]
            norm = self.layer_norms[idx]

            # Shared weights first, then this stage's own weights
            branch = transform(self.shared_transform(hidden))

            # Residual connection, layer norm, then the non-linearity
            hidden = F.relu(norm(branch + hidden))

        return self.output_proj(hidden)

# Build the modular network and show its layout
modular_net = ModularNet(input_size=64, hidden_size=128, num_layers=4)

print("Modular Network Architecture:", modular_net, sep="\n")

# List every parameter with its shape and size while summing the total
print("\nParameter analysis:")
total_params = 0
for name, param in modular_net.named_parameters():
    n_elements = param.numel()
    print(f"{name:30s}: {str(param.shape):20s} {n_elements:>8,d}")
    total_params += n_elements

print(f"\nTotal parameters: {total_params:,}")

# shared_transform is stored once even though every stage applies it
shared = modular_net.shared_transform
shared_count = shared.weight.numel() + shared.bias.numel()
print(f"\nShared parameters (used {modular_net.num_layers} times): {shared_count:,}")

# Run a random batch of 5 samples through the network
x = torch.randn(5, 64)
output = modular_net(x)
print(f"\nForward pass: {x.shape} -> {output.shape}")

In [None]:
# Hooks for monitoring activations and gradients
class MonitoredNet(nn.Module):
    """Network with built-in monitoring capabilities.

    A three-layer MLP whose Linear layers carry forward and backward
    hooks. After a forward pass ``activations`` maps ``'fc1'``/``'fc2'``/
    ``'fc3'`` to the detached output of that layer (before any ReLU);
    after a backward pass ``gradients`` maps the same names to the
    detached gradient with respect to that output.

    Args:
        input_size: Number of input features.
        hidden_size: Width of both hidden layers.
        output_size: Number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

        # Storage for activations and gradients, keyed by layer name
        self.activations = {}
        self.gradients = {}

        # Handles returned by the register_* calls, kept for remove_hooks()
        self._hook_handles = []

        # Register hooks
        self._register_hooks()

    def _register_hooks(self):
        """Attach one forward and one backward hook to each Linear layer."""
        def save_activation(name):
            def hook(module, inputs, output):
                self.activations[name] = output.detach()
            return hook

        def save_gradient(name):
            def hook(module, grad_input, grad_output):
                if grad_output[0] is not None:
                    self.gradients[name] = grad_output[0].detach()
            return hook

        for name in ('fc1', 'fc2', 'fc3'):
            layer = getattr(self, name)
            self._hook_handles.append(
                layer.register_forward_hook(save_activation(name))
            )
            # register_full_backward_hook replaces the deprecated
            # register_backward_hook
            self._hook_handles.append(
                layer.register_full_backward_hook(save_gradient(name))
            )

    def remove_hooks(self):
        """Detach every hook registered by this module."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def get_activation_stats(self):
        """Get statistics about activations.

        Returns:
            dict: layer name -> dict with ``mean``, ``std``, ``min``,
            ``max`` and ``zeros`` (fraction of entries equal to 0).
        """
        stats = {}
        for name, activation in self.activations.items():
            stats[name] = {
                'mean': activation.mean().item(),
                'std': activation.std().item(),
                'min': activation.min().item(),
                'max': activation.max().item(),
                # Fraction of exact zeros. The hooks capture the Linear
                # outputs *before* ReLU, so this is not the post-ReLU sparsity
                'zeros': (activation == 0).float().mean().item()
            }
        return stats

    def get_gradient_stats(self):
        """Get statistics about gradients.

        Returns:
            dict: layer name -> dict with ``mean``, ``std`` and ``norm``
            of the recorded gradient.
        """
        stats = {}
        for name, gradient in self.gradients.items():
            stats[name] = {
                'mean': gradient.mean().item(),
                'std': gradient.std().item(),
                'norm': gradient.norm().item()
            }
        return stats

# Set up the monitored network with an optimizer and a loss
monitored_net = MonitoredNet(20, 50, 5)
optimizer = optim.SGD(monitored_net.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Random inputs and regression targets for a batch of 10
x = torch.randn(10, 20)
y_true = torch.randn(10, 5)

# Forward pass: fills monitored_net.activations via the forward hooks
y_pred = monitored_net(x)
loss = criterion(y_pred, y_true)

# Backward pass: fills monitored_net.gradients via the backward hooks
optimizer.zero_grad()
loss.backward()

# Report what the hooks recorded
print("Activation Statistics:")
print("-" * 50)
activation_stats = monitored_net.get_activation_stats()
for layer, stats in activation_stats.items():
    print(
        f"{layer}:",
        f"  Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}",
        f"  Range: [{stats['min']:.4f}, {stats['max']:.4f}]",
        f"  Sparsity (zeros): {stats['zeros']:.2%}",
        sep="\n",
    )

print("\nGradient Statistics:")
print("-" * 50)
gradient_stats = monitored_net.get_gradient_stats()
for layer, stats in gradient_stats.items():
    print(
        f"{layer}:",
        f"  Mean: {stats['mean']:.6f}, Std: {stats['std']:.6f}",
        f"  Norm: {stats['norm']:.6f}",
        sep="\n",
    )

# Closing summary (bullet and emoji characters restored from mis-decoded UTF-8)
print("\n🎉 Module exploration completed!")
print("\nKey takeaways:")
print("• Use nn.Module as base class for all models")
print("• Parameters are automatically tracked when using nn.Parameter")
print("• Proper initialization is crucial for training success")
print("• ModuleList and ModuleDict help organize complex architectures")
print("• Hooks enable monitoring and debugging of model internals")