In [1]:
import jax
import jax.numpy as jnp
from functools import partial

# Normalization Techniques in Deep Learning  

Normalization techniques are used in deep learning to stabilize training, speed up convergence, and improve generalization. Below are some common examples.

## Batch Normalization  

Batch Normalization (BN) normalizes the input across the **batch dimension**. That is, for each feature (channel), it computes the mean and standard deviation **over the batch** and normalizes accordingly.  

For an input tensor $ x $ with shape $ (N, C, ...) $ where $ N $ is the batch dimension and $ C $ represents different features (e.g., channels in an image), BN computes the statistics as:  

$$
\mu_B = \frac{1}{N} \sum_{i=1}^{N} x_i
$$

$$
\sigma_B^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu_B)^2
$$

Each input is then normalized as:

$$
\hat{x}_i = \frac{x_i - \mu_B}{\sigma_B}
$$

where $ \sigma_B $ is the standard deviation:

$$
\sigma_B = \sqrt{\sigma_B^2 + \epsilon}
$$

Here, $ \epsilon $ is a small constant added for numerical stability.  

In [None]:
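class BatchNorm:

    def __init__(self, batch_dim=0, eps=1e-5):
        self.batch_dim = batch_dim
        self.eps = eps
        # Bind the configuration so the layer can be called as `bn.forward(x)`
        self.forward = partial(self.forward, batch_dim=batch_dim, eps=eps)

    @staticmethod
    def forward(x, batch_dim, eps=1e-5):
        # Per-feature statistics computed over the batch axis
        mean = x.mean(batch_dim, keepdims=True)
        var = x.var(batch_dim, keepdims=True)
        # eps keeps the division stable when a feature has (near-)zero variance
        return (x - mean) / jnp.sqrt(var + eps)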
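
As a quick sanity check of the formulas above, the following sketch normalizes a random batch and confirms that every feature ends up with approximately zero mean and unit standard deviation across the batch. The function name `batch_norm_sketch`, the `eps` value, and the input shape are assumptions chosen for illustration, not part of the class above.

```python
import jax
import jax.numpy as jnp

def batch_norm_sketch(x, batch_dim=0, eps=1e-5):
    # Statistics are computed over the batch axis only, one pair per feature.
    mean = x.mean(batch_dim, keepdims=True)
    var = x.var(batch_dim, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

# Hypothetical input: 64 samples with 3 features, deliberately shifted and scaled.
x = 5.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = batch_norm_sketch(x)

print(y.mean(axis=0))  # close to 0 for every feature
print(y.std(axis=0))   # close to 1 for every feature
```

The output keeps the input shape; only the per-feature location and scale change.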

## Layer Normalization  

Layer Normalization (LN) normalizes each input independently, computing the statistics **within each observation** across the feature dimensions. Unlike BN, LN does not depend on the batch size and is useful for tasks where batch statistics are unstable, such as reinforcement learning or recurrent neural networks.  

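For an input tensor $ x $ of shape $ (N, C, ...) $, LN computes the mean and variance **over the feature dimensions** (not the batch dimension). Writing $ x_j $ for the $ j $-th feature of a single sample, and $ C $ for the total number of features in that sample: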

$$
\mu_L = \frac{1}{C} \sum_{j=1}^{C} x_j
$$

$$
\sigma_L^2 = \frac{1}{C} \sum_{j=1}^{C} (x_j - \mu_L)^2
$$

Each feature is normalized as:

$$
\hat{x}_j = \frac{x_j - \mu_L}{\sigma_L}
$$

where $ \sigma_L $ is the standard deviation:

$$
\sigma_L = \sqrt{\sigma_L^2 + \epsilon}
$$

In [3]:
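@partial(jax.jit, static_argnames="batch_dim")
def layer_norm(x, batch_dim=0, eps=1e-5):
    """
    Normalise each observation to zero mean and unit variance, computing
    statistics within each observation, over all non-batch dimensions.
    """
    # batch_dim must be static: the reduction axes are built in plain Python,
    # which fails on a traced value when batch_dim is passed explicitly
    over_dims = tuple(i for i in range(x.ndim) if i != batch_dim)
    mean = x.mean(over_dims, keepdims=True)
    var = x.var(over_dims, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)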
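
To illustrate the claim that LN does not depend on the batch, the sketch below normalizes the same sample once on its own and once inside a larger batch, and does the same comparison for batch normalization. The helper names `layer_norm_sketch` and `batch_norm_sketch` and the input shape are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def layer_norm_sketch(x, batch_dim=0, eps=1e-5):
    # Reduce over every axis except the batch axis.
    over_dims = tuple(i for i in range(x.ndim) if i != batch_dim)
    mean = x.mean(over_dims, keepdims=True)
    var = x.var(over_dims, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def batch_norm_sketch(x, batch_dim=0, eps=1e-5):
    # Reduce over the batch axis only.
    mean = x.mean(batch_dim, keepdims=True)
    var = x.var(batch_dim, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))

# Normalize the first sample alone, then as part of the full batch.
ln_alone = layer_norm_sketch(x[:1])[0]
ln_in_batch = layer_norm_sketch(x)[0]

# Batch norm on the first sample, using half of the batch versus all of it.
bn_half = batch_norm_sketch(x[:4])[0]
bn_full = batch_norm_sketch(x)[0]

print(jnp.allclose(ln_alone, ln_in_batch, atol=1e-5))  # LN ignores the other samples
print(jnp.allclose(bn_half, bn_full, atol=1e-5))       # BN output depends on the batch
```

The LN result for a sample is the same regardless of which other samples accompany it, whereas the BN result shifts whenever the batch composition changes.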

## Instance Normalization  

**Instance Normalization (IN)** is similar to Layer Normalization, but it computes statistics **separately for each feature map and each sample**. It is particularly useful for style transfer and generative models.  

For an input $ x $ with shape $ (N, C, H, W) $, IN computes the mean and variance for each instance **per channel**:

$$
\mu_I = \frac{1}{H W} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{h, w}
$$

$$
\sigma_I^2 = \frac{1}{H W} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{h, w} - \mu_I)^2
$$

$$
\hat{x} = \frac{x - \mu_I}{\sigma_I}
$$

In [None]:
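@partial(jax.jit, static_argnames="feature_dim")
def instance_norm(x, feature_dim=1, eps=1e-5):
    """
    Normalizes each feature map independently for each sample.
    Used in style transfer and generative models. Assumes the batch axis is 0.
    """
    # feature_dim is static so the reduction axes can be built in plain Python;
    # reduce over every axis except the batch axis (0) and the channel axis
    spatial_dims = tuple(i for i in range(x.ndim) if i != feature_dim and i != 0)
    mean = x.mean(spatial_dims, keepdims=True)
    var = x.var(spatial_dims, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)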
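
The per-sample, per-channel behaviour can be checked with a small sketch. Here every feature map is given a different offset, and instance normalization removes each one separately. The function name `instance_norm_sketch`, the fixed `(N, C, H, W)` layout, and the offsets are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def instance_norm_sketch(x, eps=1e-5):
    # Assumes layout (N, C, H, W): reduce over the spatial axes only.
    mean = x.mean(axis=(2, 3), keepdims=True)  # shape (N, C, 1, 1)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 8, 8))
# Give each of the 2 * 3 feature maps its own offset.
offsets = jnp.arange(6.0).reshape(2, 3, 1, 1)
x = x + offsets

y = instance_norm_sketch(x)
print(x.mean(axis=(2, 3)))  # roughly the offsets 0..5
print(y.mean(axis=(2, 3)))  # close to 0 for every (sample, channel) pair
```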

## Group Normalization  

**Group Normalization (GN)** divides the channels into groups and normalizes **within each group**. Unlike BatchNorm, it does not depend on batch size, making it useful for small-batch training.  

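For an input $ x $ with shape $ (N, C, H, W) $, GN splits the $ C $ channels into $ G $ groups of $ C / G $ channels each. For each sample and each group, let $ S $ be the set of all positions $ (c, h, w) $ whose channel $ c $ belongs to that group, so that $ |S| = (C / G) \cdot H \cdot W $. The statistics are then:

$$
\mu_G = \frac{1}{|S|} \sum_{(c, h, w) \in S} x_{c, h, w}
$$

$$
\sigma_G^2 = \frac{1}{|S|} \sum_{(c, h, w) \in S} (x_{c, h, w} - \mu_G)^2
$$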

$$
\hat{x} = \frac{x - \mu_G}{\sigma_G}
$$

In [None]:
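@partial(jax.jit, static_argnames=("num_groups", "feature_dim"))
def group_norm(x, num_groups=8, feature_dim=1, eps=1e-5):
    """
    Normalizes within groups of feature maps, per sample.
    Suitable for small batch training. Assumes the batch axis is 0.
    """
    x = jnp.moveaxis(x, feature_dim, 1)  # Put channels on axis 1
    N, C = x.shape[:2]
    G = min(num_groups, C)  # Ensure we don't have more groups than channels
    if C % G != 0:
        raise ValueError("The number of channels must be divisible by the number of groups")
    grouped = x.reshape(N, G, C // G, *x.shape[2:])  # Reshape to groups
    axes = tuple(range(2, grouped.ndim))  # Channels within a group plus spatial dims
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)
    x_norm = ((grouped - mean) / jnp.sqrt(var + eps)).reshape(x.shape)  # Reshape back
    return jnp.moveaxis(x_norm, 1, feature_dim)  # Restore the original layout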
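
Group normalization sits between the two previous techniques: with a single group it reduces over the same elements as layer normalization, and with one group per channel it reduces over the same elements as instance normalization. The sketch below checks both limits numerically. The helpers `normalize` and `group_norm_sketch` and the fixed `(N, C, H, W)` layout are assumptions for this example.

```python
import jax
import jax.numpy as jnp

def normalize(x, axes, eps=1e-5):
    # Shared helper: zero mean and unit variance over the given axes.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def group_norm_sketch(x, num_groups, eps=1e-5):
    # Assumes layout (N, C, H, W) with C divisible by num_groups.
    N, C, H, W = x.shape
    grouped = x.reshape(N, num_groups, C // num_groups, H, W)
    return normalize(grouped, (2, 3, 4), eps).reshape(N, C, H, W)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 5, 5))

layer_like = normalize(x, (1, 2, 3))   # statistics per sample
instance_like = normalize(x, (2, 3))   # statistics per sample and channel

print(jnp.allclose(group_norm_sketch(x, 1), layer_like, atol=1e-4))     # G = 1
print(jnp.allclose(group_norm_sketch(x, 4), instance_like, atol=1e-4))  # G = C
```

Intermediate values of `num_groups` trade off between these two extremes.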

## Weight Standardization  

Instead of normalizing activations, **Weight Standardization** normalizes the weights of the convolutional layers. This is useful for stabilizing training in architectures like ResNets.  

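For a convolutional weight tensor of shape $ (C_{out}, C_{in}, k_h, k_w) $, each output filter $ W $ is standardized separately:

$$
\hat{W} = \frac{W - \mu_W}{\sigma_W}
$$

where the statistics are taken over the $ M = C_{in} \cdot k_h \cdot k_w $ entries $ W_k $ of that filter:

$$
\mu_W = \frac{1}{M} \sum_{k=1}^{M} W_k, \quad \sigma_W = \sqrt{\frac{1}{M} \sum_{k=1}^{M} (W_k - \mu_W)^2}
$$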

In [None]:
@partial(jax.jit, static_argnames="axis")
def weight_standardization(W, axis=(1, 2, 3), eps=1e-5):
    """
    Normalizes convolutional weights across spatial and feature dimensions.
    Used in ResNets and other architectures.
    """
    # axis is static: a traced tuple cannot be used as a reduction axis
    mean = W.mean(axis=axis, keepdims=True)
    std = W.std(axis=axis, keepdims=True)
    return (W - mean) / (std + eps)
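
A short check makes the per-filter behaviour concrete. The sketch below standardizes a random convolution kernel and inspects the statistics of each output filter. The function name `weight_standardization_sketch`, the `(out_channels, in_channels, kH, kW)` layout, and the kernel shape are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def weight_standardization_sketch(W, eps=1e-5):
    # Assumes conv weights laid out as (out_channels, in_channels, kH, kW).
    mean = W.mean(axis=(1, 2, 3), keepdims=True)
    std = W.std(axis=(1, 2, 3), keepdims=True)
    return (W - mean) / (std + eps)

# Hypothetical kernel: 16 output filters, 8 input channels, 3x3 window.
W = 0.05 + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (16, 8, 3, 3))
W_hat = weight_standardization_sketch(W)

print(W_hat.mean(axis=(1, 2, 3)))  # close to 0 for each output filter
print(W_hat.std(axis=(1, 2, 3)))   # close to 1 for each output filter
```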

## Spectral Normalization  

**Spectral Normalization** ensures that the weight matrix has a bounded spectral norm, helping stabilize GAN training.   

The spectral norm of a matrix $ W $ is:

$$
\| W \|_{\sigma} = \max_{\| v \|=1} \| W v \|
$$

Spectral Normalization rescales $ W $ as:

$$
\hat{W} = \frac{W}{\| W \|_{\sigma}}
$$

where $ \| W \|_{\sigma} $ is estimated using the **power iteration method**.  

In [None]:
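@partial(jax.jit, static_argnames="num_iters")
def spectral_norm(W, num_iters=1):
    """
    Normalizes a 2D weight matrix by its spectral norm, estimated with
    power iteration. Used in stabilizing GAN training.
    """
    # u lives in the output space of W, so it needs W.shape[0] entries
    u = jax.random.normal(jax.random.PRNGKey(0), (W.shape[0],))  # Initialize u
    for _ in range(num_iters):  # Unrolled at trace time; requires num_iters >= 1
        v = jnp.dot(W.T, u)
        v = v / jnp.linalg.norm(v)
        u = jnp.dot(W, v)
        u = u / jnp.linalg.norm(u)
    sigma = jnp.dot(u, jnp.dot(W, v))  # Estimate of the largest singular value
    return W / sigma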
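
To see how good the power iteration estimate is, the sketch below compares it with the exact largest singular value for a small non-square matrix. The helper `power_iteration_sigma`, the matrix shape, the seeds, and the iteration counts are assumptions for this example; the exact value comes from `jnp.linalg.norm` with `ord=2`.

```python
import jax
import jax.numpy as jnp

def power_iteration_sigma(W, num_iters, key):
    # u lives in the output space of W, so it has W.shape[0] entries.
    u = jax.random.normal(key, (W.shape[0],))
    for _ in range(num_iters):
        v = W.T @ u
        v = v / jnp.linalg.norm(v)
        u = W @ v
        u = u / jnp.linalg.norm(u)
    return u @ (W @ v)

W = jax.random.normal(jax.random.PRNGKey(1), (6, 4))
sigma_exact = jnp.linalg.norm(W, ord=2)  # largest singular value of W

# The estimate never exceeds the true value and tightens with more iterations.
for k in (1, 5, 200):
    print(k, power_iteration_sigma(W, k, jax.random.PRNGKey(0)), sigma_exact)

W_hat = W / power_iteration_sigma(W, 200, jax.random.PRNGKey(0))
print(jnp.linalg.norm(W_hat, ord=2))  # approaches 1 as the estimate converges
```

With a single iteration the estimate can be noticeably low, which is why practical implementations usually keep `u` between training steps rather than re-initializing it on every call.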

## Summary  

| Normalization Type  | Normalization Axis | Best Use Cases |
|---------------------|--------------------|---------------|
| **Batch Normalization** | Over the batch dimension (per feature) | CNNs, large batch sizes |
| **Layer Normalization** | Over feature dimensions (per sample) | Transformers, RNNs, NLP |
| **Instance Normalization** | Over spatial dimensions (per channel, per sample) | Style transfer, generative models |
| **Group Normalization** | Over grouped channels (per sample) | CNNs, small batch sizes |
| **Weight Standardization** | On layer weights | ResNets, stable CNN training |
| **Spectral Normalization** | Constrains weight spectral norm | GANs, stabilizing adversarial training |
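
The "Normalization Axis" column can be made concrete by looking at the shape of the statistics each activation normalizer computes. The sketch below does this for a hypothetical `(N, C, H, W)` tensor; the shape, the group count, and the dictionary `reduction_axes` are assumptions for illustration, and the batch normalization entry follows the batch-only reduction described in this document.

```python
import jax.numpy as jnp

x = jnp.zeros((4, 6, 8, 8))  # hypothetical (N, C, H, W) activations

# Axes each activation normalizer averages over, following the table above.
reduction_axes = {
    "batch": (0,),
    "layer": (1, 2, 3),
    "instance": (2, 3),
}

for name, axes in reduction_axes.items():
    stats_shape = x.mean(axis=axes, keepdims=True).shape
    print(name, stats_shape)
# batch    -> (1, 6, 8, 8): one statistic per feature, shared across samples
# layer    -> (4, 1, 1, 1): one statistic per sample
# instance -> (4, 6, 1, 1): one statistic per sample and channel

# Group normalization: split C = 6 into G = 2 groups, then reduce within each group.
G = 2
grouped = x.reshape(4, G, 6 // G, 8, 8)
group_stats_shape = grouped.mean(axis=(2, 3, 4), keepdims=True).shape
print("group", group_stats_shape)  # (4, 2, 1, 1, 1): one statistic per sample and group
```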