In [1]:
import tensorflow as tf
import numpy as np

tf.enable_eager_execution()

  from ._conv import register_converters as _register_converters


In [2]:
# \had{a_i^l} = g_i^l (a_i^l - \mu_i^l), \mu_i^l = sum(a_i^l), \sigma_i^l = \sqrt{sum(a_i^l - \mu_i^l)^2}

In [9]:
np.random.seed(1234)

epsilon = 1e-6
x = np.random.normal(scale=1e-4, size=(2, 10, 20))
filters = x.shape[-1]
scale = np.ones((filters,))
bais = np.zeros((filters,))

In [18]:
mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True)
norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
norm_x * scale + bais

norm_x[0][0]

<tf.Tensor: id=125, shape=(20,), dtype=float64, numpy=
array([ 0.04799724, -0.11729337,  0.14357481, -0.02996317, -0.07052361,
        0.08933233,  0.08659064, -0.06216515,  0.00268393, -0.22186298,
        0.11546931,  0.09975072,  0.09591062, -0.19984657, -0.03209346,
        0.00133389,  0.04143679,  0.02986717,  0.13248372, -0.15268285])>

In [19]:
x[0][0]

array([ 4.71435164e-05, -1.19097569e-04,  1.43270697e-04, -3.12651896e-05,
       -7.20588733e-05,  8.87162940e-05,  8.59588414e-05, -6.36523504e-05,
        1.56963721e-06, -2.24268495e-04,  1.15003572e-04,  9.91946022e-05,
        9.53324128e-05, -2.02125482e-04, -3.34077366e-05,  2.11836468e-07,
        4.05453412e-05,  2.89091941e-05,  1.32115819e-04, -1.54690555e-04])

In [33]:
class LayerNorm(tf.keras.layers.Wrapper):
    def __init__(self, layer, epsilon, **kwargs):
        self._epsilon = epsilon
        super(LayerNorm, self).__init__(layer, **kwargs)
    def build(self, input_shape):
        filters = input_shape[-1]
        self._scale = self.add_weight(
            'scale',
            [filters],
            initializer='ones')
        self._bias = self.add_weight(
            'bias',
            [filters],
            initializer='zeros')
        self.layer.build(input_shape)
        super(LayerNorm, self).build()
    def call(self, x):
        output = self.layer.call(x)
        print('out', output)
        mean = tf.reduce_mean(output, axis=[-1], keep_dims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True)
        norm_x = (output - mean) * tf.rsqrt(variance + epsilon)
        return norm_x * scale + bais
    def compute_output_shape(self, input_shape):
        return input_shape

In [36]:
inputs = tf.keras.layers.Input(shape=(10, 20))
x = LayerNorm(tf.keras.layers.Dense(20), 1e-6)(inputs)

model = tf.keras.models.Model(inputs=inputs, outputs=x)
model.compile(
    optimizer=tf.train.GradientDescentOptimizer(0.001),
    loss='mse')

In [37]:
np.random.seed(1234)
x_ = np.random.normal(scale=1e-4, size=(2, 10, 20))
model.predict(x_)

out tf.Tensor(
[[[-7.68330210e-05 -6.66574269e-05  1.00668149e-05 -3.09149109e-05
    5.77735736e-05 -4.58065260e-05 -1.13583301e-05 -1.07797088e-04
    1.12874695e-04 -1.43084486e-04  2.16209664e-04 -1.02537684e-04
   -1.32764835e-04  5.67980896e-05 -2.35810730e-04  1.02158680e-04
   -3.48872818e-05 -1.06429761e-04  5.54457365e-05  4.25622056e-05]
  [ 1.06397143e-04  6.62057428e-05  8.06373064e-05  4.73461114e-07
    6.59351936e-05  4.08822052e-05 -4.15121031e-06  9.95049559e-05
   -8.34117964e-05 -1.63705026e-05 -1.25948558e-04  2.32387785e-04
    1.13908012e-04  6.93124239e-05  1.73229695e-04 -9.41087346e-05
    1.36146235e-04  5.28758392e-05  9.07920694e-05 -1.67731661e-04]
  [-3.58545258e-05 -2.06587269e-04 -5.33008133e-05  1.14652808e-04
    8.66736173e-06 -1.20608784e-04 -2.77581257e-05  1.98356414e-04
   -2.03340111e-04  6.57756318e-05  6.19568818e-05  3.18791899e-05
    2.04294018e-04 -7.60200346e-05  7.45687867e-05 -5.34609026e-05
   -2.21530991e-05 -1.18065080e-04  1.9446997

array([[[-0.05445839, -0.04434316,  0.03192592, -0.00881269,
          0.07934966, -0.02361596,  0.01062788, -0.08523876,
          0.13412389, -0.12031682,  0.23684582, -0.08001056,
         -0.11005839,  0.07837996, -0.21249296,  0.12347145,
         -0.01276149, -0.08387955,  0.07703563,  0.06422853],
        [ 0.06429698,  0.02426241,  0.03863766, -0.04121336,
          0.02399292, -0.00096231, -0.04581999,  0.05743168,
         -0.12477127, -0.0579916 , -0.16714205,  0.18979597,
          0.07177854,  0.02735697,  0.13086873, -0.13542648,
          0.09392998,  0.01098453,  0.0487528 , -0.20876211],
        [-0.02359269, -0.19374341, -0.0409795 ,  0.12640156,
          0.02077742, -0.10805802, -0.01552389,  0.20981982,
         -0.19050732,  0.07769101,  0.07388528,  0.04391012,
          0.2157372 , -0.06362128,  0.08645419, -0.04113905,
         -0.00993797, -0.10552299,  0.01407768, -0.07612817],
        [ 0.04009439,  0.10482626, -0.10607146, -0.09912214,
          0.16535704,

In [26]:
x_

array([[[ 4.71435164e-05, -1.19097569e-04,  1.43270697e-04,
         -3.12651896e-05, -7.20588733e-05,  8.87162940e-05,
          8.59588414e-05, -6.36523504e-05,  1.56963721e-06,
         -2.24268495e-04,  1.15003572e-04,  9.91946022e-05,
          9.53324128e-05, -2.02125482e-04, -3.34077366e-05,
          2.11836468e-07,  4.05453412e-05,  2.89091941e-05,
          1.32115819e-04, -1.54690555e-04],
        [-2.02646325e-05, -6.55969344e-05,  1.93421376e-05,
          5.53438911e-05,  1.31815155e-04, -4.69305285e-05,
          6.75554085e-05, -1.81702723e-04, -1.83108540e-05,
          1.05896919e-04, -3.97840228e-05,  3.37437654e-05,
          1.04757857e-04,  1.04593826e-04,  8.63717292e-05,
         -1.22091575e-05,  1.24712954e-05, -3.22794806e-05,
          8.41674713e-05,  2.39096052e-04],
        [ 7.61995878e-06, -5.66445930e-05,  3.61419367e-06,
         -2.07497760e-04,  2.47792200e-05, -8.97156784e-05,
         -1.36794833e-05,  1.82891913e-06,  7.55413982e-05,
          2.