# In [2]:
import numpy as np

def layer_normalization(X: np.ndarray, gamma: np.ndarray, beta: np.ndarray, epsilon: float = 1e-5) -> np.ndarray:
    """
    Apply layer normalization along the last axis of ``X``.

    Each feature vector (a slice along the last axis) is shifted to zero
    mean and scaled to approximately unit variance, then passed through
    the affine transform ``gamma * x_hat + beta``.

    Args:
        X: Input array, documented as (batch_size, seq_len, feature_dim).
        gamma: Scale, documented as (1, 1, feature_dim); broadcast against
            the normalized input.
        beta: Shift, same shape convention as ``gamma``.
        epsilon: Constant added to the variance before the square root so
            that a zero-variance feature vector does not divide by zero.

    Returns:
        The normalized, scaled and shifted array.
    """
    # Per-vector statistics; keepdims=True keeps a trailing axis of size 1
    # so they broadcast against X. np.var uses its default ddof=0.
    feature_mean = np.mean(X, axis=-1, keepdims=True)
    feature_var = np.var(X, axis=-1, keepdims=True)

    centered = X - feature_mean
    stddev = np.sqrt(feature_var + epsilon)

    # Normalize first, then apply the affine scale and shift.
    return gamma * (centered / stddev) + beta
    
# Usage example: normalize a small seeded random batch using an identity
# scale (gamma = 1) and zero shift (beta = 0), then print the result.
np.random.seed(42)  # fixed seed so the printed values are reproducible
X = np.random.randn(2, 2, 3)
gamma = np.ones((1, 1, 3))
beta = np.zeros((1, 1, 3))

print(layer_normalization(X, gamma, beta))

# Output:
# [[[ 0.47373971 -1.39079736  0.91705765]
#   [ 1.41420326 -0.70711154 -0.70709172]]
#
#  [[ 1.13192477  0.16823009 -1.30015486]
#   [ 1.4141794  -0.70465482 -0.70952458]]]
