# Normalization

## Batch Normalization

**Batch normalization** (BN) is essentially classical feature normalization:
it re-scales each individual feature to zero mean and unit variance.
The difference lies in where the statistics come from: classical normalization
typically uses statistics from the entire dataset, while BN computes them per
batch.
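
To make the contrast concrete, here is a minimal NumPy sketch (not part of the original notebook). The toy dataset, the batch size of 32 and the helper name `standardize` are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical toy dataset: 1000 samples with 4 features each.
data = rng.normal(loc=5.0, scale=2.0, size=(1000, 4))


def standardize(x, mean, var, eps=1e-4):
    """Shift and scale each feature with the given statistics."""
    return (x - mean) / np.sqrt(var + eps)


# Classical normalization: statistics computed once over the whole dataset.
dataset_mean, dataset_var = data.mean(axis=0), data.var(axis=0)

# Batch normalization (training mode): statistics of the current batch only.
batch = data[:32]
batch_mean, batch_var = batch.mean(axis=0), batch.var(axis=0)

classical = standardize(batch, dataset_mean, dataset_var)
batch_normed = standardize(batch, batch_mean, batch_var)

# Within the batch, only the batch-normalized output has zero mean
# (up to rounding); the classically normalized batch is merely close to it.
print(np.abs(batch_normed.mean(axis=0)).max())
print(np.abs(classical.mean(axis=0)).max())
```

The same batch therefore maps to slightly different values depending on which statistics are used, which is the whole difference between the two schemes.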

Note that what counts as a "feature dimension" depends on the context.
In the CV domain it usually means the channel dimension
(each channel's feature map, a.k.a. depth),
while for NLP models both the sequence length and the embedding dimension
can be treated as "feature dimensions".

For image processing, suppose the input tensor has shape `(B, H, W, C)`.
When computing the mean and variance, batch normalization aggregates across
batch (`B`), height (`H`) and width (`W`), so the mean and variance tensors
each have shape `(1, 1, 1, C)`.
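
The reduction described above can be sketched in plain NumPy (an illustrative stand-in, not the notebook's TensorFlow code). The sizes `B, H, W, C = 2, 5, 7, 3` and the value `eps = 1e-4` are assumptions chosen for the example.

```python
import numpy as np

B, H, W, C = 2, 5, 7, 3
x = np.arange(B * H * W * C, dtype=np.float64).reshape(B, H, W, C)

# Reduce over batch, height and width; keep one statistic per channel.
mean = x.mean(axis=(0, 1, 2), keepdims=True)
var = x.var(axis=(0, 1, 2), keepdims=True)
print(mean.shape, var.shape)  # (1, 1, 1, 3) (1, 1, 1, 3)

# Broadcasting applies the per-channel statistics to every pixel of every image.
eps = 1e-4
x_norm = (x - mean) / np.sqrt(var + eps)

# Each channel of the result has (approximately) zero mean and unit variance.
print(x_norm.mean(axis=(0, 1, 2)))
print(x_norm.var(axis=(0, 1, 2)))
```

Keeping the reduced axes with `keepdims=True` is what lets the `(1, 1, 1, C)` statistics broadcast against the full `(B, H, W, C)` tensor.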

For a language model tensor of shape
`(batch_size, sequence_length, embedding_size)`, on the other hand, the mean
and variance tensors have shape `(1, sequence_length, embedding_size)`:
each position in the `sequence_length x embedding_size` grid has its own
mean and variance computed across the batch.
(Batch normalization is nevertheless not common practice for language models.)
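
As a rough NumPy sketch of the sequence case (the sizes, the random input and `eps = 1e-4` are assumptions for illustration only), reducing over the batch axis alone gives one statistic per grid position:

```python
import numpy as np

batch_size, sequence_length, embedding_size = 4, 6, 8
rng = np.random.default_rng(0)
tokens = rng.normal(size=(batch_size, sequence_length, embedding_size))

# Reduce over the batch axis only: one mean/variance per (position, embedding)
# cell of the sequence_length x embedding_size grid.
mean = tokens.mean(axis=0, keepdims=True)
var = tokens.var(axis=0, keepdims=True)
print(mean.shape, var.shape)  # (1, 6, 8) (1, 6, 8)

eps = 1e-4
normed = (tokens - mean) / np.sqrt(var + eps)

# Across the batch, every grid cell now has (approximately) zero mean.
print(np.abs(normed.mean(axis=0)).max())
```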

In [1]:
"""Normalization functions implemented with `tf.nn` APIs."""
import tensorflow as tf


def batch_normalization(x: tf.Tensor, eps: float = 1e-4) -> tf.Tensor:
    # Aggregate over every dimension except the channel one (NHWC layout),
    # i.e. over batch, height and width. `torch.nn.BatchNorm2d` reduces over
    # the same dimensions, but expects NCHW input.
    mean, variance = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)
    # Learnable offset (beta) and scale (gamma), one per channel. They are
    # created at their initial values here, so they leave the result unchanged.
    offset = tf.Variable(tf.zeros(x.shape[-1:]))
    scale = tf.Variable(tf.ones(x.shape[-1:]))
    x_norm = scale * (x - mean) / tf.sqrt(variance + eps) + offset
    return x_norm


def layer_normalization(x: tf.Tensor, eps: float = 1e-4) -> tf.Tensor:
    # Aggregate over every dimension except the batch one,
    # i.e. over height, width and channel (NHWC layout).
    # This matches `torch.nn.LayerNorm` when its `normalized_shape` covers
    # all non-batch dimensions.
    mean, variance = tf.nn.moments(x, axes=[1, 2, 3], keepdims=True)
    # Learnable offset and scale, one per normalized element. They are
    # created at their initial values here, so they leave the result unchanged.
    offset = tf.Variable(tf.zeros(x.shape[1:]))
    scale = tf.Variable(tf.ones(x.shape[1:]))
    x_norm = scale * (x - mean) / tf.sqrt(variance + eps) + offset
    return x_norm


def layer_normalization_tfsim(x: tf.Tensor, eps: float = 1e-4) -> tf.Tensor:
    # Only aggregate over the last (channel) dimension, which is the default
    # `axis=-1` of `tf.keras.layers.LayerNormalization`.
    mean, variance = tf.nn.moments(x, axes=[-1], keepdims=True)
    # Learnable offset and scale, one per channel, at their initial values.
    offset = tf.Variable(tf.zeros(x.shape[-1:]))
    scale = tf.Variable(tf.ones(x.shape[-1:]))
    x_norm = scale * (x - mean) / tf.sqrt(variance + eps) + offset
    return x_norm

In [2]:
N, H, W, C = 2, 5, 7, 3
x = tf.reshape(tf.range(N * H * W * C, dtype=tf.float32), [N, H, W, C])



### Batch Normalization 1

Results from `batch_normalization`. Apart from the NHWC layout and the `eps` value, this is the same behaviour as `torch.nn.BatchNorm2d(C)` in training mode.

In [3]:
print(batch_normalization(x))

tf.Tensor(
[[[[-1.7074814  -1.7074814  -1.7074814 ]
   [-1.6579893  -1.6579893  -1.6579893 ]
   [-1.608497   -1.608497   -1.608497  ]
   [-1.5590048  -1.5590048  -1.5590048 ]
   [-1.5095125  -1.5095125  -1.5095125 ]
   [-1.4600203  -1.4600203  -1.4600203 ]
   [-1.4105282  -1.4105282  -1.4105282 ]]

  [[-1.361036   -1.361036   -1.361036  ]
   [-1.3115437  -1.3115437  -1.3115437 ]
   [-1.2620515  -1.2620515  -1.2620515 ]
   [-1.2125593  -1.2125593  -1.2125593 ]
   [-1.1630671  -1.1630671  -1.1630671 ]
   [-1.1135749  -1.1135749  -1.1135749 ]
   [-1.0640826  -1.0640826  -1.0640826 ]]

  [[-1.0145904  -1.0145904  -1.0145904 ]
   [-0.9650982  -0.9650982  -0.9650982 ]
   [-0.91560596 -0.91560596 -0.91560596]
   [-0.8661138  -0.8661138  -0.8661138 ]
   [-0.81662154 -0.81662154 -0.81662154]
   [-0.76712936 -0.76712936 -0.76712936]
   [-0.7176371  -0.7176371  -0.7176371 ]]

  [[-0.66814494 -0.66814494 -0.66814494]
   [-0.6186527  -0.6186527  -0.6186527 ]
   [-0.56916046 -0.56916046 -0.56916046]

### Batch Normalization - Keras

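Result of `tf.keras.layers.BatchNormalization()` in training mode. The tiny differences from the `batch_normalization` output are consistent with Keras's default `epsilon` of `1e-3`, compared with the `1e-4` used by `batch_normalization`.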

In [4]:
print(tf.keras.layers.BatchNormalization()(x, training=True))

tf.Tensor(
[[[[-1.7074813  -1.7074813  -1.7074813 ]
   [-1.657989   -1.657989   -1.657989  ]
   [-1.6084968  -1.6084968  -1.6084968 ]
   [-1.5590047  -1.5590047  -1.5590047 ]
   [-1.5095124  -1.5095124  -1.5095124 ]
   [-1.4600202  -1.4600202  -1.4600202 ]
   [-1.410528   -1.410528   -1.410528  ]]

  [[-1.3610358  -1.3610358  -1.3610358 ]
   [-1.3115436  -1.3115436  -1.3115436 ]
   [-1.2620513  -1.2620513  -1.2620513 ]
   [-1.2125591  -1.2125591  -1.2125591 ]
   [-1.163067   -1.163067   -1.163067  ]
   [-1.1135747  -1.1135747  -1.1135747 ]
   [-1.0640825  -1.0640825  -1.0640825 ]]

  [[-1.0145903  -1.0145903  -1.0145903 ]
   [-0.9650981  -0.9650981  -0.9650981 ]
   [-0.9156059  -0.9156059  -0.9156059 ]
   [-0.86611366 -0.86611366 -0.86611366]
   [-0.8166215  -0.8166215  -0.8166215 ]
   [-0.76712924 -0.76712924 -0.76712924]
   [-0.71763706 -0.71763706 -0.71763706]]

  [[-0.6681448  -0.6681448  -0.6681448 ]
   [-0.61865264 -0.61865264 -0.61865264]
   [-0.5691604  -0.5691604  -0.5691604 ]

### Layer Normalization 1

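Results from `layer_normalization`, which normalizes over all non-batch dimensions. This is the same behaviour as `torch.nn.LayerNorm((C, H, W))` applied to an NCHW tensor; for the NHWC tensor used here the equivalent shape would be `(H, W, C)`.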

In [5]:
print(layer_normalization(x))

tf.Tensor(
[[[[-1.7156329e+00 -1.6826400e+00 -1.6496470e+00]
   [-1.6166540e+00 -1.5836611e+00 -1.5506682e+00]
   [-1.5176753e+00 -1.4846823e+00 -1.4516894e+00]
   [-1.4186964e+00 -1.3857036e+00 -1.3527106e+00]
   [-1.3197176e+00 -1.2867247e+00 -1.2537317e+00]
   [-1.2207388e+00 -1.1877459e+00 -1.1547530e+00]
   [-1.1217600e+00 -1.0887671e+00 -1.0557741e+00]]

  [[-1.0227811e+00 -9.8978823e-01 -9.5679533e-01]
   [-9.2380238e-01 -8.9080942e-01 -8.5781652e-01]
   [-8.2482356e-01 -7.9183060e-01 -7.5883770e-01]
   [-7.2584474e-01 -6.9285184e-01 -6.5985888e-01]
   [-6.2686592e-01 -5.9387302e-01 -5.6088006e-01]
   [-5.2788711e-01 -4.9489418e-01 -4.6190125e-01]
   [-4.2890832e-01 -3.9591539e-01 -3.6292243e-01]]

  [[-3.2992950e-01 -2.9693657e-01 -2.6394361e-01]
   [-2.3095068e-01 -1.9795775e-01 -1.6496481e-01]
   [-1.3197188e-01 -9.8978937e-02 -6.5986000e-02]
   [-3.2993063e-02 -1.2585807e-07  3.2992810e-02]
   [ 6.5985747e-02  9.8978683e-02  1.3197163e-01]
   [ 1.6496456e-01  1.9795750e-01  

### Layer Normalization 2

Results from `layer_normalization_tfsim`, same behaviour as `tf.keras.layers.LayerNormalization()`.

In [6]:
print(layer_normalization_tfsim(x))

tf.Tensor(
[[[[-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]]

  [[-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]]

  [[-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]]

  [[-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]
   [-1.224653  0.        1.224653]]
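
The repeated `±1.224653` values can be checked by hand. Since `x` holds consecutive integers, every group of `C = 3` channel values has the form `(a, a + 1, a + 2)`, whose variance is `2/3` regardless of `a`, so every row normalizes to the same triplet. The NumPy sketch below (the helper name `normalized_triplet` is made up for this illustration) reproduces the value for `eps = 1e-4` and shows how a larger `eps` of `1e-3`, the default `epsilon` of `tf.keras.layers.LayerNormalization`, shrinks it slightly.

```python
import numpy as np


def normalized_triplet(eps):
    """Normalize three consecutive values by their own mean and variance."""
    v = np.array([0.0, 1.0, 2.0])
    return (v - v.mean()) / np.sqrt(v.var() + eps)


print(normalized_triplet(1e-4))  # approx. [-1.224653, 0, 1.224653]
print(normalized_triplet(1e-3))  # approx. [-1.223827, 0, 1.223827]
```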

 

### Layer Normalization - Keras

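Result of `tf.keras.layers.LayerNormalization()`. The values differ slightly from those of `layer_normalization_tfsim` because Keras defaults to `epsilon=1e-3` rather than the `1e-4` used there.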

In [7]:
print(tf.keras.layers.LayerNormalization()(x, training=True))

tf.Tensor(
[[[[-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]]

  [[-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]]

  [[-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]]

  [[-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238274]
   [-1.2238274  0.         1.2238