In [1]:
import numpy as np

def batch_normalization(X: np.ndarray, gamma: np.ndarray, beta: np.ndarray, epsilon: float = 1e-5) -> np.ndarray:
    """Apply batch normalization (training-mode forward pass) per channel.

    Statistics are computed over the batch axis (0) and every axis after the
    channel axis (1), so each channel is normalized independently. For the
    usual 4-D ``(B, C, H, W)`` input this reduces over axes ``(0, 2, 3)``;
    ``(B, C)`` and ``(B, C, L)`` style inputs are handled the same way.

    Parameters
    ----------
    X : np.ndarray
        Input of shape ``(B, C, ...)`` with at least two dimensions; axis 1
        is treated as the channel axis.
    gamma : np.ndarray
        Scale applied after normalization. Must broadcast against ``X``,
        e.g. shape ``(1, C, 1, 1)`` for a 4-D input.
    beta : np.ndarray
        Shift applied after scaling. Same broadcasting requirement as
        ``gamma``.
    epsilon : float, optional
        Added to the variance before the square root so that channels with
        zero variance do not cause a division by zero. Defaults to ``1e-5``.

    Returns
    -------
    np.ndarray
        ``gamma * (X - mean) / sqrt(var + epsilon) + beta``.

    Raises
    ------
    ValueError
        If ``X`` has fewer than two dimensions (no channel axis).
    """
    X = np.asarray(X)
    if X.ndim < 2:
        raise ValueError(
            f"X must have at least 2 dimensions (batch, channel, ...), got shape {X.shape}"
        )

    # Reduce over batch + all trailing (spatial) dims, i.e. every axis except
    # the channel axis. For a 4-D input this is exactly (0, 2, 3).
    reduce_axes = (0,) + tuple(range(2, X.ndim))

    # keepdims=True leaves singleton axes in place so the statistics
    # broadcast directly against X.
    mean = np.mean(X, axis=reduce_axes, keepdims=True)
    var = np.var(X, axis=reduce_axes, keepdims=True)  # population variance (ddof=0)

    # Normalize
    X_norm = (X - mean) / np.sqrt(var + epsilon)

    # Scale & shift
    return gamma * X_norm + beta

    
# Usage example: normalize a small random NCHW batch with identity scale/shift.
np.random.seed(42)  # fixed seed so the printed values are reproducible
B, C, H, W = 2, 2, 2, 2
X = np.random.randn(B, C, H, W)

# Per-channel parameters shaped (1, C, 1, 1) so they broadcast over the
# batch and spatial axes.
gamma = np.ones((1, C, 1, 1))   # unit scale
beta = np.zeros((1, C, 1, 1))   # zero shift

print(batch_normalization(X, gamma, beta))

[[[[ 0.42859934 -0.51776438]
   [ 0.65360963  1.95820707]]

  [[ 0.02353721  0.02355215]
   [ 1.67355207  0.93490043]]]


 [[[-1.01139563  0.49692747]
   [-1.00236882 -1.00581468]]

  [[ 0.45676349 -1.50433085]
   [-1.33293647 -0.27503802]]]]
