# 8. Weight Initialization

How you start matters.
If weights are too small, signals vanish.
If weights are too large, signals explode.
Let's see how to initialize them "just right".

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

## 1. The Problem: Vanishing & Exploding Gradients

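Let's simulate the forward pass of a deep network by repeatedly multiplying a vector by random matrices (one matrix per layer, no bias). The same repeated scaling that shrinks or blows up activations on the way forward also shrinks or blows up gradients on the way back.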

In [None]:
def simulate_deep_network(init_std, activation_fn=None, n_layers=10, width=512, generator=None):
    """Push a random vector through a stack of random square linear layers.

    Each layer draws a fresh weight matrix of shape (width, width) with entries
    ``N(0, init_std**2)``, computes ``x = W @ x`` (no bias) and then, if
    ``activation_fn`` is given, applies it element-wise.

    Without an activation, each layer scales the activation std by roughly
    ``init_std * sqrt(width)``. That factor is what makes the signal vanish or
    explode over many layers.

    Args:
        init_std: standard deviation of every weight entry.
        activation_fn: optional callable applied after each matrix multiply
            (e.g. ``torch.tanh``, ``torch.relu``); ``None`` means purely linear.
        n_layers: number of layers to simulate (default 10).
        width: size of the input vector and of every square weight matrix
            (default 512).
        generator: optional ``torch.Generator`` for reproducible draws; ``None``
            uses torch's global RNG, exactly as before.

    Returns:
        A list of ``n_layers + 1`` tensors of shape ``(width,)``: the input,
        followed by the output of each layer.
    """
    x = torch.randn(width, generator=generator)
    activations = [x]

    for _ in range(n_layers):
        W = torch.randn(width, width, generator=generator) * init_std
        x = W @ x
        # Test against None explicitly: some callables (e.g. an empty
        # nn.Sequential) define __len__ and would be falsy.
        if activation_fn is not None:
            x = activation_fn(x)
        activations.append(x)

    return activations

def plot_activations(activations, title, value_range=(-3, 3), bins=30):
    """Draw one histogram per layer, side by side, for a list of activation tensors.

    Panel 0 is the network input; panel ``i`` is the output of layer ``i``.
    Every panel shares the same x-range so the narrowing/widening of the
    distribution is comparable across layers.

    ``plt.hist`` silently drops values outside ``value_range``. An exploding
    signal would therefore show up as an empty panel, and with the axes hidden
    there would be no hint of its scale. To make the scale visible, each panel
    is titled with the layer's actual standard deviation.

    Args:
        activations: sequence of tensors (1-D in this notebook), e.g. the output
            of ``simulate_deep_network``.
        title: figure-level title.
        value_range: ``(low, high)`` histogram range shared by all panels.
        bins: number of histogram bins per panel.
    """
    n_panels = len(activations)
    plt.figure(figsize=(10, 4))
    for i, act in enumerate(activations):
        # detach/cpu so this also works for tensors that require grad or live on a GPU
        values = act.detach().cpu().numpy()
        plt.subplot(1, n_panels, i + 1)
        plt.hist(values, bins=bins, range=value_range)
        plt.title(f"L{i}\nstd={values.std():.2g}", fontsize=7)
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

# Two failure modes, with no activation function at all. Each layer multiplies
# the activation std by roughly init_std * sqrt(512):
#   1. Small Weights (std=0.01): factor ~0.23 per layer -> signal vanishes
#   2. Large Weights (std=1.0):  factor ~22.6 per layer -> signal explodes
for weight_std, caption in (
    (0.01, "Small Weights: Vanishing Signal (Everything becomes 0)"),
    (1.0, "Large Weights: Exploding Signal (Values get huge)"),
):
    acts = simulate_deep_network(weight_std)
    plot_activations(acts, caption)

## 2. Xavier (Glorot) Initialization

Designed for **Sigmoid** and **Tanh** activations.
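Goal: keep the variance of activations (forward pass) and of gradients (backward pass) roughly constant across layers. The forward pass alone would call for $\mathrm{Var}(W) = 1/n_{in}$ and the backward pass for $1/n_{out}$. Xavier uses their harmonic mean as a compromise: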

$W \sim N(0, \frac{2}{n_{in} + n_{out}})$

In [None]:
# Xavier (Glorot) Initialization: Var(W) = 2 / (fan_in + fan_out)
n_in = 512   # fan-in of every layer in the simulation
n_out = 512  # fan-out (square layers, so equal to fan-in)
xavier_std = np.sqrt(2 / (n_in + n_out))

# Tanh network initialized with Xavier
acts = simulate_deep_network(xavier_std, activation_fn=torch.tanh)
plot_activations(acts, title="Xavier Init with Tanh: Stable!")

## 3. Kaiming (He) Initialization

Designed for **ReLU** activations.
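ReLU zeroes out every negative input. For zero-mean, symmetric pre-activations that is about half of them, which halves the signal's second moment at every layer. We therefore double the weight variance to compensate.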

$W \sim N(0, \frac{2}{n_{in}})$

In [None]:
# Kaiming (He) Initialization: Var(W) = 2 / fan_in
kaiming_std = np.sqrt(2 / n_in)

# Same ReLU network, two weight scales: Kaiming first, then
# "what if we used Xavier with ReLU?" for comparison.
relu_runs = (
    (kaiming_std, "Kaiming Init with ReLU: Stable!"),
    (xavier_std, "Xavier Init with ReLU: Signal Vanishes!"),
)
for std, caption in relu_runs:
    acts = simulate_deep_network(std, torch.relu)
    plot_activations(acts, caption)

## 4. How to use in PyTorch

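PyTorch layers come with a reasonable default initialization (for `nn.Linear`, weights are drawn from $U(-1/\sqrt{n_{in}}, 1/\sqrt{n_{in}})$), but you can override it in place with the functions in `torch.nn.init`.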

In [None]:
# Pick ONE initializer per weight tensor. Every nn.init.*_ function overwrites
# the tensor in place, so calling two in a row on the same weight keeps only
# the last one.

# ReLU network -> Kaiming (He) normal: std = sqrt(2 / fan_in)
layer = nn.Linear(512, 512)
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
nn.init.zeros_(layer.bias)  # replace the default random bias with zeros

# Tanh/sigmoid network -> Xavier (Glorot) uniform: bound = sqrt(6 / (fan_in + fan_out))
# (PyTorch also offers gain=nn.init.calculate_gain('tanh') to scale this for tanh.)
tanh_layer = nn.Linear(512, 512)
nn.init.xavier_uniform_(tanh_layer.weight)
nn.init.zeros_(tanh_layer.bias)

print("Weights initialized!")