In [72]:
import tensorflow as tf
class LayerNorm(tf.keras.layers.Layer):
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def build(self, batch_input_shape):
        # The first shape would be the batch size
        self.units = tf.constant(100.)
        self.kernel = self.add_weight(name="kernel", shape=batch_input_shape[-1:], initializer="ones")
        self.bias = self.add_weight(name="bias", shape=batch_input_shape[-1:], initializer="zeros")
        super().build(batch_input_shape)
        
    def call(self, inputs):
        self.mean, self.variance = tf.nn.moments(inputs, axes=-1,keepdims=True)
        self.epsilon = 3e-4
        self.std_dev = tf.math.sqrt(self.variance)
        # Matrix multiplication MUST be between 2-dimensional matrices
        return (((inputs - self.mean) * (self.kernel))/(self.std_dev + self.epsilon)) + self.bias

In [74]:
# Building the model
import numpy as np
layerNormalization = LayerNorm()
example_batch = np.arange(32*24*24).reshape((32,24,24)).astype(np.float32)
layerNormalization.build(example_batch.shape)

In [75]:
# Sanity Check
layerNormalization(example_batch)

<tf.Tensor: shape=(32, 24, 24), dtype=float32, numpy=
array([[[-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        ...,
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529]],

       [[-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3723392,
          1.5167961,  1.6612529],
        ...,
        [-1.6612529, -1.5167961, -1.3723392, ...,  1.3

In [76]:
# Looking at the mean squared error between the LayerNormalization and the custom one we just made
keras_layer_norm = tf.keras.layers.LayerNormalization()
tf.reduce_mean(tf.keras.losses.mean_absolute_error(keras_layer_norm(example_batch), model(example_batch)))
# Awesome! A very small difference, so they are basically the same thing

<tf.Tensor: shape=(), dtype=float32, numpy=2.8462464e-05>