# Module 15: Batch Normalization

**Stabilizing Deep Networks**

---

## Objectives

By the end of this notebook, you will:
- Understand internal covariate shift
- Master batch normalization math
- Know how to use BatchNorm in PyTorch
- Understand train vs eval mode behavior

---

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)

<torch._C.Generator at 0x7a0105c35f30>

---

# Part 1: The Problem

---

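**Internal Covariate Shift**: Every training step updates the parameters of all layers, so the distribution of each layer's inputs (the previous layer's outputs) keeps changing. Each layer must continually adapt to a moving target, which can make training slow and unstable. This is the motivation given in the original BatchNorm paper. Whether reducing this shift fully explains BatchNorm's benefits is still debated.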
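The toy sketch below trains a tiny two-layer network with plain SGD and records the mean and standard deviation of the second layer's inputs at each step. The random data, layer sizes, and learning rate are arbitrary assumptions and do not come from this notebook. Because every update to `layer1` changes its outputs, the distribution that `layer2` receives keeps drifting:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data (arbitrary): only the drift of layer2's inputs matters here
x = torch.randn(256, 10)
y = torch.randn(256, 1)

layer1 = nn.Linear(10, 32)
layer2 = nn.Linear(32, 1)
opt = torch.optim.SGD(list(layer1.parameters()) + list(layer2.parameters()), lr=0.1)

stats = []  # (mean, std) of layer2's inputs at each step
for step in range(20):
    h = torch.relu(layer1(x))  # h is what layer2 receives as input
    stats.append((h.mean().item(), h.std().item()))
    if step % 5 == 0:
        print(f"step {step:2d}: layer2 input mean={stats[-1][0]:.3f}, std={stats[-1][1]:.3f}")

    loss = nn.functional.mse_loss(layer2(h), y)
    opt.zero_grad()
    loss.backward()
    opt.step()  # changing layer1 shifts the distribution of h for the next step
```

The exact numbers depend on the seed. The point is only that `layer2` never sees a fixed input distribution while `layer1` is still learning.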

---

# Part 2: Batch Normalization

---

## 2.1 The Algorithm

For each feature, compute the mean and variance over the mini-batch, normalize every value with them, then apply a learnable scale and shift:

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$

$$y_i = \gamma \hat{x}_i + \beta$$

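Where:
- $\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i$ and $\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$ = mean and (biased) variance over the $m$ examples in the mini-batch, computed separately for each feature
- $\gamma, \beta$ = learnable per-feature scale and shift, which let the network undo the normalization if that helps
- $\epsilon$ = small constant (e.g. $10^{-5}$) that prevents division by zero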
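To make the formula concrete, here is a small worked example for one feature with a batch of four values. The values and the choices $\gamma = 2$, $\beta = 1$ are arbitrary. The example computes each step by hand, then checks the result against `nn.BatchNorm1d`, which in training mode should apply the same formula:

```python
import torch
import torch.nn as nn

# One feature, batch of m = 4 values
x = torch.tensor([[1.0], [2.0], [3.0], [6.0]])
eps = 1e-5
gamma, beta = 2.0, 1.0  # arbitrary example values

mu = x.mean(dim=0)                        # (1 + 2 + 3 + 6) / 4 = 3.0
var = ((x - mu) ** 2).mean(dim=0)         # (4 + 1 + 0 + 9) / 4 = 3.5 (divide by m)
x_hat = (x - mu) / torch.sqrt(var + eps)  # roughly [-1.069, -0.535, 0.000, 1.604]
y = gamma * x_hat + beta

# Cross-check against PyTorch's layer in training mode with the same gamma/beta
bn = nn.BatchNorm1d(1, eps=eps)
with torch.no_grad():
    bn.weight.fill_(gamma)
    bn.bias.fill_(beta)
    y_torch = bn(x)

print(mu.item(), var.item())                  # 3.0 3.5
print(torch.allclose(y, y_torch, atol=1e-6))  # True
```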

In [2]:
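# BatchNorm from scratch (input shape: [batch_size, num_features])
class BatchNorm1D:
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # Learnable per-feature scale (gamma) and shift (beta)
        self.gamma = torch.ones(num_features, requires_grad=True)
        self.beta = torch.zeros(num_features, requires_grad=True)

        # Running statistics for inference (updated, not learned)
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)

        self.eps = eps
        self.momentum = momentum
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, x):
        if self.training:
            # Per-feature batch statistics (biased variance, as in the formula)
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)

            # Update running statistics (exponential moving average).
            # no_grad keeps them out of the autograd graph; like PyTorch,
            # the running variance stores the unbiased estimate.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * x.var(dim=0, unbiased=True)
        else:
            # Use running statistics
            mean = self.running_mean
            var = self.running_var

        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift
        return self.gamma * x_norm + self.beta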

# Test
bn = BatchNorm1D(5)
x = torch.randn(32, 5) * 10 + 5  # Mean ~5, Std ~10
out = bn(x)

print(f"Input:  mean={x.mean():.2f}, std={x.std():.2f}")
print(f"Output: mean={out.mean():.2f}, std={out.std():.2f}")

Input:  mean=6.00, std=9.88
Output: mean=0.00, std=1.00


## 2.2 PyTorch BatchNorm

In [3]:
# PyTorch BatchNorm
bn = nn.BatchNorm1d(num_features=5)

print("Learnable parameters:")
print(f"  gamma (weight): {bn.weight.shape}")
print(f"  beta (bias): {bn.bias.shape}")

print("\nRunning statistics (buffers):")
print(f"  running_mean: {bn.running_mean.shape}")
print(f"  running_var: {bn.running_var.shape}")

Learnable parameters:
  gamma (weight): torch.Size([5])
  beta (bias): torch.Size([5])

Running statistics (buffers):
  running_mean: torch.Size([5])
  running_var: torch.Size([5])
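The buffers are updated as an exponential moving average on every training-mode forward pass. A quick check of the update rule follows. It assumes PyTorch's documented behavior that the running variance stores the *unbiased* batch variance, while the forward pass itself normalizes with the biased one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(5)  # momentum=0.1; running_mean starts at 0, running_var at 1
x = torch.randn(32, 5) * 10 + 5

bn.train()
with torch.no_grad():
    bn(x)  # one training-mode forward pass updates the buffers

m = bn.momentum
expected_mean = (1 - m) * torch.zeros(5) + m * x.mean(dim=0)
expected_var = (1 - m) * torch.ones(5) + m * x.var(dim=0, unbiased=True)

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-4))    # True
print(bn.num_batches_tracked)                                     # tensor(1)
```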


In [4]:
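# Train vs Eval mode
bn = nn.BatchNorm1d(5)
x = torch.randn(32, 5) * 10 + 5

# Training mode: normalizes with this batch's statistics
# (and updates running_mean/running_var as a side effect)
bn.train()
out_train = bn(x)
print(f"Train mode: {out_train.mean():.4f}")

# Eval mode: normalizes with the running statistics. After a single update
# (momentum=0.1) they are still close to their initial values (mean 0, var 1),
# so the output is not zero-mean here
bn.eval()
out_eval = bn(x)
print(f"Eval mode: {out_eval.mean():.4f}")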

Train mode: 0.0000
Eval mode: 1.3306
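The eval-mode mean above is far from 0 because the running statistics have seen only one batch. The sketch below (the 200 batches and the seed are arbitrary choices) runs many training-mode passes on data from the same distribution. The running estimates should then approach the true mean (about 5) and variance (about 100), and eval-mode outputs should become roughly standardized:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(5)  # default momentum=0.1

# Many training-mode passes on batches from the same distribution (mean ~5, std ~10).
# no_grad only skips autograd; the running statistics are still updated.
bn.train()
with torch.no_grad():
    for _ in range(200):
        bn(torch.randn(32, 5) * 10 + 5)

print("running_mean:", bn.running_mean)  # each entry should be near 5
print("running_var: ", bn.running_var)   # each entry should be near 100

# Eval mode on a fresh batch: now roughly zero-mean, unit-std
bn.eval()
with torch.no_grad():
    out = bn(torch.randn(32, 5) * 10 + 5)
print(f"Eval mode: mean={out.mean():.4f}, std={out.std():.4f}")
```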


## 2.3 Using in a Network

In [5]:
class MLPWithBN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
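        # BatchNorm BEFORE activation (most common).
        # The Linear bias is redundant here: BN subtracts the batch mean and
        # beta supplies its own shift, so bias=False is often used in practice.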
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.fc3(x)  # No BN on output
        return x

model = MLPWithBN(784, 256, 10)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Parameters: 270,346
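A common pitfall when deploying such a model is that, in training mode, BatchNorm cannot normalize a batch containing a single example. The sketch below rebuilds the same layer stack with `nn.Sequential`. This is an equivalent, more compact form, not the notebook's class. It then shows how `train()` and `eval()` differ for one input:

```python
import torch
import torch.nn as nn

# Same layer stack as MLPWithBN(784, 256, 10), written with nn.Sequential
model = nn.Sequential(
    nn.Linear(784, 256), nn.BatchNorm1d(256), nn.ReLU(),
    nn.Linear(256, 256), nn.BatchNorm1d(256), nn.ReLU(),
    nn.Linear(256, 10),
)
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,}")  # Parameters: 270,346

single = torch.randn(1, 784)  # a batch containing one example

# Training mode: batch statistics need more than one value per feature,
# so BatchNorm1d raises a ValueError
model.train()
raised = False
try:
    model(single)
except ValueError as err:
    raised = True
    print("train mode:", err)

# Eval mode: running statistics are used, so one example is fine
model.eval()
with torch.no_grad():
    out = model(single)
print("eval mode output shape:", out.shape)  # torch.Size([1, 10])
```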


---

# Part 3: Benefits

---

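1. **Faster training**: tolerates higher learning rates, so models often converge in fewer steps
2. **Less sensitivity to initialization**: rescaling the weights of the layer before a BN layer by a positive factor leaves the BN output essentially unchanged
3. **Mild regularization**: each example is normalized with statistics from the other examples in its mini-batch, which adds a little noise
4. **Enables deeper networks**: keeps activations in a stable range from layer to layer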
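One way to see the reduced sensitivity to initialization: when the weights and bias of the layer before a BatchNorm are multiplied by a positive constant, the BatchNorm output is unchanged apart from the tiny effect of $\epsilon$, because the constant cancels between numerator and denominator. Here is a small check, where the layer sizes and the factor of 100 are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 20)

fc = nn.Linear(20, 8)
bn = nn.BatchNorm1d(8)  # training mode by default, so it normalizes with batch statistics

with torch.no_grad():
    pre_normal = fc(x)
    out_normal = bn(pre_normal)

    # Simulate a badly scaled initialization: multiply weights and bias by 100
    fc.weight.mul_(100.0)
    fc.bias.mul_(100.0)
    pre_scaled = fc(x)
    out_scaled = bn(pre_scaled)

print(f"pre-BN std: {pre_normal.std():.3f} vs {pre_scaled.std():.3f}")  # about 100x larger
print(torch.allclose(out_normal, out_scaled, atol=1e-3))  # True
```

Without BatchNorm, the next layer would receive inputs 100 times larger. With it, the next layer sees the same values.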

---

# Key Points

---

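- **Training** (`model.train()`): normalizes with the current mini-batch mean/variance and updates the running estimates
- **Inference** (`model.eval()`): normalizes with the running mean/variance, so an example's output does not depend on the rest of its batch
- **Always call `model.eval()`** before inference, and `model.train()` before resuming training!
- Place BatchNorm before the activation (most common)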

---

## Next Module: [16 - Hyperparameter Tuning](../16_hyperparameters/16_hyperparameters.ipynb)