In [1]:
!pip install -q flax

[?25l[K     |█▉                              | 10 kB 23.8 MB/s eta 0:00:01[K     |███▋                            | 20 kB 24.2 MB/s eta 0:00:01[K     |█████▍                          | 30 kB 12.5 MB/s eta 0:00:01[K     |███████▏                        | 40 kB 10.4 MB/s eta 0:00:01[K     |█████████                       | 51 kB 4.5 MB/s eta 0:00:01[K     |██████████▊                     | 61 kB 5.4 MB/s eta 0:00:01[K     |████████████▌                   | 71 kB 6.0 MB/s eta 0:00:01[K     |██████████████▎                 | 81 kB 4.3 MB/s eta 0:00:01[K     |████████████████                | 92 kB 4.8 MB/s eta 0:00:01[K     |█████████████████▉              | 102 kB 5.3 MB/s eta 0:00:01[K     |███████████████████▋            | 112 kB 5.3 MB/s eta 0:00:01[K     |█████████████████████▍          | 122 kB 5.3 MB/s eta 0:00:01[K     |███████████████████████▏        | 133 kB 5.3 MB/s eta 0:00:01[K     |█████████████████████████       | 143 kB 5.3 MB/s eta 0:00:01[K 

In [2]:
import numpy as np

import jax
import jax.numpy as jnp
from flax import linen as nn

In [3]:
# Hand-rolled NumPy reference implementations of batch norm and layer norm.
# X has shape (2, 3): batch size 2 along axis 0, feature size 3 along axis 1.
# The legacy global seed is kept so the printed values below stay reproducible.
np.random.seed(42)
X = np.random.normal(size=(2, 3))

# Batch norm: standardize each feature (column) with statistics taken across the
# batch axis (axis 0). With only 2 samples, every column becomes exactly +/-1.
print("batch norm")
mu_batch = np.mean(X, axis=0)
sigma_batch = np.std(X, axis=0)
XBN = (X - mu_batch) / sigma_batch
print(XBN)

# Layer norm: standardize each sample (row) with statistics taken across the
# feature axis (axis 1). keepdims=True yields (2, 1) statistics that broadcast per row.
# No epsilon is added to the variance here, so a constant row/column would divide by zero.
print("layer norm")
mu_layer = np.mean(X, axis=1, keepdims=True)
sigma_layer = np.std(X, axis=1, keepdims=True)
XLN = (X - mu_layer) / sigma_layer
print(XLN)

batch norm
[[-1.  1.  1.]
 [ 1. -1. -1.]]
layer norm
[[ 0.47376014 -1.39085732  0.91709718]
 [ 1.41421356 -0.70711669 -0.70709687]]


In [4]:
# Check the NumPy reference results (XBN, XLN) from the previous cell against
# Flax's built-in normalization layers.
X = jnp.asarray(X, dtype=jnp.float32)

rng = jax.random.PRNGKey(42)
bn_rng, ln_rng = jax.random.split(rng)

print("batch norm")
# use_running_average=False normalizes with the statistics of the current batch,
# matching the hand-rolled version. Flax adds epsilon to the variance while the
# NumPy reference does not; for the middle column (small spread) that pulls the
# result to ~0.9998 instead of 1, hence the absolute tolerance below.
bn = nn.BatchNorm(use_running_average=False, epsilon=1e-6)
bn_params = bn.init(bn_rng, X)
# Normalizing with batch statistics updates the "batch_stats" collection, so it must
# be marked mutable; apply then returns (output, updated_state) and the state is dropped.
XBN_t, _ = bn.apply(bn_params, X, mutable=["batch_stats"])
print(XBN_t)
# rtol=1e-5 reproduces np.allclose's default, so the tolerance is unchanged, but a
# failure now reports the mismatching elements instead of a bare AssertionError.
np.testing.assert_allclose(np.asarray(XBN_t), XBN, rtol=1e-5, atol=1e-3)

print("layer norm")
# epsilon is spelled out (it matches LayerNorm's default) so both layers are
# compared against the epsilon-free NumPy reference under the same offset.
ln = nn.LayerNorm(epsilon=1e-6)
ln_params = ln.init(ln_rng, X)
XLN_t = ln.apply(ln_params, X)
print(XLN_t)
np.testing.assert_allclose(np.asarray(XLN_t), XLN, rtol=1e-5, atol=1e-3)



batch norm
[[-0.99999815  0.99978346  0.99999744]
 [ 0.99999815 -0.9997831  -0.9999975 ]]
layer norm
[[ 0.473758   -1.3908514   0.9170933 ]
 [ 1.4142125  -0.70711625 -0.7070964 ]]
