<a href="https://colab.research.google.com/github/ramenwang/deep_learning_py/blob/master/weight_initialization/Weight_Initialization_Study.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Study of Weight Initialization in Deep Learning

Weight initialization is critically important when training deep neural networks. Its aim is to prevent the layer activations from exploding or vanishing during the forward pass. If they do, the gradients computed during backpropagation become either huge or vanishingly small, and neither can train the model effectively. The network then takes a very long time to converge, if it converges at all.

```python
import tensorflow as tf
import numpy as np

print(tf.__version__)
```

```
2.1.0
```


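Make sure the TensorFlow version is 2.0 or later. Otherwise, reinstall TensorFlow with pip (the `-y` flag skips the uninstall confirmation prompt, which a notebook cell cannot answer):

```
!pip uninstall -y tensorflow
!pip install tensorflow
```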
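If you would rather check the version programmatically than read the printed string, a minimal sketch like the hypothetical helper `is_tf2_or_later` below compares only the major version number. It assumes version strings of the form `major.minor.patch`, which is what `tf.__version__` returned above.

```python
def is_tf2_or_later(version: str) -> bool:
    """Return True if a version string such as '2.1.0' has a major version of 2 or higher."""
    major = int(version.split(".")[0])
    return major >= 2

# In the notebook you would pass tf.__version__ (assumes `import tensorflow as tf`):
# assert is_tf2_or_later(tf.__version__), "TensorFlow 2.0+ is required"
print(is_tf2_or_later("2.1.0"))   # True
print(is_tf2_or_later("1.15.2"))  # False
```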

```python
# example of exploding activations: every layer's weights are drawn from N(0, 1)
x = tf.random.normal(shape=[1, 512], mean=0.0, stddev=1.0)

for i in range(100):
    a = tf.random.normal(shape=[512, 512], mean=0.0, stddev=1.0)
    x = tf.matmul(x, a)
    m, std = tf.math.reduce_mean(x), tf.math.reduce_std(x)
    print((i, m.numpy(), std.numpy()))
    # stop once the standard deviation overflows float32
    if tf.math.is_inf(std):
        break
```

```
(0, 0.1536276, 23.241276)
(1, -3.957141, 514.9028)
(2, 537.3411, 11976.59)
(3, -20230.898, 254805.45)
(4, 65098.062, 5869629.0)
(5, -2271739.5, 130253656.0)
(6, -107079810.0, 2731474400.0)
(7, -886669600.0, 59931824000.0)
(8, -119571590000.0, 1310310500000.0)
(9, 296592300000.0, 30777180000000.0)
(10, 48731390000000.0, 683598540000000.0)
(11, -910940000000000.0, 1.6155766e+16)
(12, 1.5143387e+16, 3.7082878e+17)
(13, 2.3044699e+17, inf)
```


```python
# example of vanishing activations: the input and every layer's weights are scaled by 0.01
x = tf.random.normal(shape=[1, 512], mean=0.0, stddev=1.0) * 0.01

for i in range(100):
    a = tf.random.normal(shape=[512, 512], mean=0.0, stddev=1.0) * 0.01
    x = tf.matmul(x, a)
    m, std = tf.math.reduce_mean(x), tf.math.reduce_std(x)
    print((i, m.numpy(), std.numpy()))
    # stop once the standard deviation underflows to exactly zero
    if std.numpy() == 0.0:
        break
```

```
(0, -3.3166183e-05, 0.0023059484)
(1, -3.309784e-06, 0.00052285)
(2, -6.949214e-06, 0.000121139485)
(3, 1.7582497e-06, 2.7506605e-05)
(4, -2.2310957e-08, 6.228041e-06)
(5, -7.912334e-09, 1.4250608e-06)
(6, 3.4584917e-09, 3.1632982e-07)
(7, -1.9835178e-09, 7.191048e-08)
(8, 5.7213845e-10, 1.620314e-08)
(9, -2.0295728e-10, 3.6839596e-09)
(10, -3.701056e-11, 8.229536e-10)
(11, -2.2165885e-13, 1.8297262e-10)
(12, 1.5651595e-12, 4.1222532e-11)
(13, 1.00011145e-13, 1.0203084e-11)
(14, -4.61738e-16, 2.4104236e-12)
(15, 2.2248845e-15, 5.59719e-13)
(16, -1.3669335e-15, 1.2418145e-13)
(17, -1.3335556e-15, 2.7264911e-14)
(18, -9.82633e-17, 6.2265477e-15)
(19, -3.7630325e-17, 1.4283219e-15)
(20, -8.6034485e-18, 3.2353628e-16)
(21, -5.4255707e-19, 6.9445604e-17)
(22, 1.4292004e-18, 1.5762635e-17)
(23, -1.285622e-19, 3.757705e-18)
(24, 1.670319e-20, 8.640193e-19)
(25, -3.0233726e-21, 1.9432933e-19)
(26, -2.047575e-21, 4.4072907e-20)
(27, 6.930417e-22, 1.0045552e-20)
(28, -7.608546e-23, 2.1840375e-21
```

To understand how activations explode or vanish, let's first look at the matrix multiplication a single layer performs in the forward pass. Suppose a layer receives an input in_i and has a weight matrix w_i:

```python
in_i = tf.random.normal([1, 625], mean=0.0, stddev=1.0)
w_i = tf.random.normal([625, 10000], mean=0.0, stddev=1.0)
out_i = tf.matmul(in_i, w_i)
m, std = tf.math.reduce_mean(out_i), tf.math.reduce_std(out_i)
print(f'The mean is {m}, and the standard deviation is {std}')
```

```
The mean is 0.34199059009552, and the standard deviation is 26.245882034301758
```


The mean is close to 0, and the standard deviation is close to 25, which is the square root of 625. This is not a coincidence. Each element of out_i is the sum of 625 products, each pairing one element of in_i with the corresponding element in a column of w_i. Because every element is drawn independently from the standard normal distribution, each product has mean 0 and variance 1, so the sum has an expected mean of 0 and an expected variance of 1 × 625 = 625.
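The same argument can be written as a short derivation. It assumes, as in the cell above, that all entries of in_i and w_i are independent with mean 0. Write $x$ for in_i, $W$ for w_i, $n$ for the input size, and $\sigma_x^2$, $\sigma_w^2$ for the variances of their entries:

$$
\mathrm{out}_j = \sum_{k=1}^{n} x_k W_{kj}, \qquad
\mathbb{E}[\mathrm{out}_j] = \sum_{k=1}^{n} \mathbb{E}[x_k]\,\mathbb{E}[W_{kj}] = 0,
$$

$$
\mathrm{Var}(\mathrm{out}_j) = \sum_{k=1}^{n} \mathrm{Var}(x_k W_{kj}) = \sum_{k=1}^{n} \mathbb{E}[x_k^2]\,\mathbb{E}[W_{kj}^2] = n\,\sigma_x^2\,\sigma_w^2 .
$$

With $n = 625$ and $\sigma_x = \sigma_w = 1$, this gives a standard deviation of $\sqrt{625} = 25$. Each stacked layer therefore multiplies the standard deviation by roughly $\sqrt{n}\,\sigma_w$. For the 512-wide experiments above, that is $\sqrt{512} \approx 22.6$ when $\sigma_w = 1$ (exploding) and $0.01\sqrt{512} \approx 0.226$ when $\sigma_w = 0.01$ (vanishing), which matches the step-to-step ratios in their printed outputs.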

To counter this, we can initialize the weights from a normal distribution whose standard deviation is the square root of 1 / input_size, where input_size is the number of inputs to the layer. Here input_size is 625, so the standard deviation is sqrt(1 / 625) = 1 / 25.

```python
in_i = tf.random.normal([1, 625], mean=0.0, stddev=1.0)
# weights scaled so that stddev = sqrt(1 / 625) = 1 / 25
w_i = tf.random.normal([625, 10000], mean=0.0, stddev=1 / 25)
out_i = tf.matmul(in_i, w_i)
m, std = tf.math.reduce_mean(out_i), tf.math.reduce_std(out_i)
print(f'The mean is {m}, and the standard deviation is {std}')
```

```
The mean is 0.017422931268811226, and the standard deviation is 0.9813411235809326
```
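To check that the same rule also tames the deep experiments, here is a minimal NumPy sketch (not part of the original notebook) that reruns the 100-layer, 512-wide linear stack with three weight standard deviations: 1, 0.01, and 1/sqrt(512). The helper name `final_std` is hypothetical. It works in float64 rather than TensorFlow's float32, so the σ = 1 run does not overflow to inf within 100 layers, and the input is kept at unit scale in every run.

```python
import numpy as np

def final_std(weight_std, n=512, depth=100, seed=0):
    """Push one N(0, 1) input of width n through `depth` bias-free linear layers
    whose weights are drawn from N(0, weight_std**2); return the output's std."""
    rng = np.random.default_rng(seed)  # same seed per run for a like-for-like comparison
    x = rng.standard_normal((1, n))
    for _ in range(depth):
        w = rng.normal(0.0, weight_std, size=(n, n))
        x = x @ w
    return float(x.std())

n = 512
stable = final_std(1.0 / np.sqrt(n), n=n)  # stddev = sqrt(1 / input_size)
exploding = final_std(1.0, n=n)            # stddev = 1, as in the first experiment
vanishing = final_std(0.01, n=n)           # stddev = 0.01, as in the second experiment
print(stable, exploding, vanishing)
```

With $\sigma_w = 1/\sqrt{n}$, the output standard deviation should stay on the order of 1 after 100 layers, while the other two runs end up many orders of magnitude above or below it, in line with the per-layer factor $\sqrt{n}\,\sigma_w$.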
