# 9. Normalization Methods

Training deep networks is hard in part because the distribution of inputs to each layer shifts as the parameters of the earlier layers are updated.
The original BatchNorm paper calls this **Internal Covariate Shift**.
Normalization counters it by rescaling layer inputs to mean 0 and variance 1, followed by a learnable scale and shift.

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

## 1. Batch Normalization (BatchNorm)

Normalizes each feature **across the batch**.

For a batch of $N$ samples with $D$ features:
1. Calculate mean $\mu_j$ and variance $\sigma_j^2$ for each feature $j$ (across $N$ samples).
2. Normalize: $\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}$
3. Scale and Shift: $y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j$

$\gamma$ and $\beta$ are learnable parameters!

In [None]:
class CustomBatchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
        # Running stats (not trained, but updated)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # 1. Calculate batch stats
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            
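            # 2. Update running stats (Exponential Moving Average).
            # no_grad keeps the buffers out of the autograd graph.
            # Note: nn.BatchNorm1d stores the unbiased variance here; the biased one is kept for simplicity.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)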
        else:
            # Use running stats during inference
            mean = self.running_mean
            var = self.running_var
            
        # 3. Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        
        # 4. Scale and Shift
        out = self.gamma * x_norm + self.beta
        return out

# Test it
batch_size = 3
features = 5
x = torch.randn(batch_size, features) * 10 + 5 # Drawn with mean 5, var 100 (a batch of 3 will only roughly match)

bn = CustomBatchNorm1d(features)
out = bn(x)

print(f"Input Mean: {x.mean(dim=0)}")
print(f"Output Mean: {out.mean(dim=0)} (Should be close to 0)")
print(f"Output Var: {out.var(dim=0, unbiased=False)} (Should be close to 1)")
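
The `forward` method above uses batch statistics in training mode and running statistics in inference mode. As a rough cross-check, the sketch below runs the built-in `nn.BatchNorm1d` through one training step and compares its buffers with the exponential moving average computed by hand. The tensor shapes, the seed, and the names `x_demo`, `bn_ref`, `expected_mean`, and `expected_var` are illustrative choices, not part of the cell above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

x_demo = torch.randn(8, 4) * 10 + 5  # 8 samples, 4 features
bn_ref = nn.BatchNorm1d(4, momentum=0.1)

# Training mode: normalize with the batch statistics and update the running ones.
bn_ref.train()
out_train = bn_ref(x_demo)

# Hand-computed running stats after one step (buffers start at mean=0, var=1).
# nn.BatchNorm1d stores the unbiased variance in running_var.
expected_mean = 0.1 * x_demo.mean(dim=0)
expected_var = 0.9 * torch.ones(4) + 0.1 * x_demo.var(dim=0, unbiased=True)

# Eval mode: normalize with the running statistics, which are left unchanged.
bn_ref.eval()
out_eval = bn_ref(x_demo)

print("running_mean matches EMA:", torch.allclose(bn_ref.running_mean, expected_mean, atol=1e-5))
print("running_var matches EMA: ", torch.allclose(bn_ref.running_var, expected_var, atol=1e-4))
print("train and eval outputs equal:", torch.allclose(out_train, out_eval))
```

After a single step the running statistics are still far from the batch statistics, so the two modes give different outputs for the same input. This is why forgetting `model.eval()` at inference time is a common source of confusing results.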

## 2. Layer Normalization (LayerNorm)

Normalizes each sample **across its features**.
The statistics never involve other samples, so the result is independent of batch size. Commonly used in RNNs and Transformers.

For each sample $i$:
1. Calculate mean $\mu_i$ and variance $\sigma_i^2$ across all $D$ features.
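2. Normalize using these sample-specific stats: $\hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}$
3. Scale and Shift, as in BatchNorm: $y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j$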

In [None]:
ln = nn.LayerNorm(features)
out_ln = ln(x)

# Check stats per sample (dim=1)
print(f"LayerNorm Output Mean (per sample): {out_ln.mean(dim=1)}")
print(f"LayerNorm Output Var (per sample): {out_ln.var(dim=1, unbiased=False)}")
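
The LayerNorm steps can also be written out by hand. The sketch below computes per-sample statistics over the feature dimension and compares the result with a freshly constructed `nn.LayerNorm`, whose scale and shift start at 1 and 0. The names `x_demo`, `manual_ln`, and `ref_ln` and the seed are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

x_demo = torch.randn(3, 5) * 10 + 5  # 3 samples, 5 features
eps = 1e-5  # same default as nn.LayerNorm

# Per-sample statistics: reduce over the feature dimension (dim=1).
mu = x_demo.mean(dim=1, keepdim=True)
var = x_demo.var(dim=1, unbiased=False, keepdim=True)
manual_ln = (x_demo - mu) / torch.sqrt(var + eps)

# A fresh nn.LayerNorm has weight=1 and bias=0, so it should match the manual version.
ref_ln = nn.LayerNorm(5, eps=eps)(x_demo)

print("max abs difference:", (manual_ln - ref_ln).abs().max().item())
```

The only difference from the BatchNorm computation is the reduction axis: `dim=1` (features) instead of `dim=0` (batch). No running statistics are needed, so the layer behaves the same in training and inference.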

## 3. When to use what?

- **BatchNorm**: Default for CNNs and Feed-Forward Networks. Batch statistics get noisy when the batch is small, so as a rule of thumb it needs a decent batch size (roughly 32 or more).
- **LayerNorm**: Default for Transformers (BERT, GPT) and RNNs. Works with batch size 1.
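
The batch-size difference shows up most clearly in the extreme case of a single sample. In the sketch below, `nn.BatchNorm1d` in training mode rejects a batch of one, while `nn.LayerNorm` handles it because its statistics come from the features of that one sample. The variable names (`single`, `bn_single`, `ln_single`, `bn_failed`) are illustrative.

```python
import torch
import torch.nn as nn

single = torch.randn(1, 5)  # a batch with one sample, 5 features

bn_single = nn.BatchNorm1d(5)
bn_single.train()
try:
    bn_single(single)
    bn_failed = False
except ValueError as err:
    # One sample gives no usable batch variance, so PyTorch raises in training mode.
    bn_failed = True
    print("BatchNorm (train mode):", err)

ln_single = nn.LayerNorm(5)
out_single = ln_single(single)  # statistics come from the 5 features of the sample
print("LayerNorm output shape:", tuple(out_single.shape))
```

In eval mode BatchNorm would accept the same input, since it then relies on the running statistics rather than the current batch.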