In [1]:
import torch

## Why you need a good init

To understand why initialization is important in a neural net, we'll focus on its basic operation: matrix multiplication. So let's take a vector `x` and a matrix `a` initialized randomly, then multiply them 100 times (as if we had 100 layers).

In [2]:
x = torch.randn(512)
a = torch.randn(512,512)

In [3]:
for i in range(100): x = a @ x

In [4]:
x.mean(),x.std()

(tensor(nan), tensor(nan))

The problem here is activation explosion: the activations grow at every multiplication until they overflow and become nan. We can even ask the loop to break when that first happens:

In [5]:
x = torch.randn(512)
a = torch.randn(512,512)

In [6]:
for i in range(100): 
    x = a @ x
    if x.std() != x.std(): break  # nan is the only value that is not equal to itself

In [7]:
i

28

It only takes 29 multiplications: the loop breaks at `i = 28`, and `i` starts at 0! On the other hand, if you initialize your weights with a scale that is too low, then you'll get another problem:

In [8]:
x = torch.randn(512)
a = torch.randn(512,512) * 0.01

for i in range(100): x = a @ x

x.mean(),x.std()

(tensor(0.), tensor(0.))

Here, every activation vanished to 0. So to avoid both problems, people have come up with several strategies to initialize their weight matrices, such as:

* use a standard deviation that will make sure x and Ax have roughly the same scale
* use an orthogonal matrix to initialize the weight (orthogonal matrices have the special property that they preserve the L2 norm, so x and Ax would have the same sum of squares in that case)
* use [spectral normalization](https://arxiv.org/pdf/1802.05957.pdf) on the matrix A (the spectral norm of A is the smallest number M such that `torch.norm(A@x) <= M*torch.norm(x)` for every x, so dividing A by this M ensures you don't overflow; you can still vanish with this)
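
The second and third strategies in the list above can be checked numerically. The following is only a minimal sketch: it assumes `torch.nn.init.orthogonal_` to build the orthogonal matrix and computes the spectral norm exactly with `torch.linalg.matrix_norm(..., ord=2)`, which is a stand-in for illustration and not necessarily how the linked paper estimates it. The names `q`, `a_sn`, `norm_ratio` and `sn_ratio` are hypothetical.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable
n = 512
x = torch.randn(n)

# Orthogonal init: the L2 norm of q @ x should match the L2 norm of x.
q = torch.empty(n, n)
torch.nn.init.orthogonal_(q)
norm_ratio = (torch.norm(q @ x) / torch.norm(x)).item()

# Spectral normalization: divide a by its largest singular value, so that
# torch.norm(a_sn @ x) can never exceed torch.norm(x).
a = torch.randn(n, n)
spectral_norm = torch.linalg.matrix_norm(a, ord=2)
a_sn = a / spectral_norm
sn_ratio = (torch.norm(a_sn @ x) / torch.norm(x)).item()

print(f"orthogonal:          |q @ x| / |x|    = {norm_ratio:.4f}")
print(f"spectral-normalized: |a_sn @ x| / |x| = {sn_ratio:.4f}")
```

The first ratio should come out at 1 up to floating-point error. The second is bounded by 1 but is typically smaller for a random `x`, which is why repeated multiplications can still vanish with this approach.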

## The magic number

Here we will focus on the first one, which is the Xavier initialization. It tells us that we should use a scale equal to `1/math.sqrt(n_in)` where `n_in` is the number of inputs of our matrix.
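
As a rough sketch of where that scale comes from, write $y = a x$ and assume (these assumptions are not stated above, but they match the `torch.randn` setup) that the entries of $a$ and $x$ are independent with zero mean, that each $x_j$ has unit variance, and that each $a_{ij}$ has variance $\sigma_a^2$:

```latex
y_i = \sum_{j=1}^{n_{in}} a_{ij}\, x_j,
\qquad
\operatorname{Var}(y_i)
  = \sum_{j=1}^{n_{in}} \operatorname{Var}(a_{ij})\, \mathbb{E}\!\left[x_j^2\right]
  = n_{in}\, \sigma_a^2 .
```

Under those assumptions, asking for $\operatorname{Var}(y_i) = 1$ gives $\sigma_a = 1/\sqrt{n_{in}}$, which is the scale used in the next cells.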

In [9]:
import math

In [10]:
x = torch.randn(512)
a = torch.randn(512,512) / math.sqrt(512)

In [11]:
for i in range(100): x = a @ x

In [12]:
x.mean(),x.std()

(tensor(-0.1086), tensor(4.8214))
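
To compare the three scales used so far, here is a small sketch that measures how much a single multiplication changes the standard deviation. The helper `std_ratio` and the number of trials are hypothetical choices made for illustration, not part of the original notebook.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable
n_in = 512

def std_ratio(scale, n_trials=20):
    """Average of std(a @ x) / std(x) over a few random draws of a and x."""
    ratios = []
    for _ in range(n_trials):
        x = torch.randn(n_in)
        a = torch.randn(n_in, n_in) * scale
        ratios.append(((a @ x).std() / x.std()).item())
    return sum(ratios) / len(ratios)

ratio_unit = std_ratio(1.0)                     # scale used in the exploding example
ratio_small = std_ratio(0.01)                   # scale used in the vanishing example
ratio_xavier = std_ratio(1 / math.sqrt(n_in))   # Xavier scale

print(f"scale 1:           {ratio_unit:.3f}")
print(f"scale 0.01:        {ratio_small:.3f}")
print(f"scale 1/sqrt(512): {ratio_xavier:.3f}")
```

With a scale of 1 the ratio should be close to `math.sqrt(512)`, about 22.6 per layer, which is enough to exceed the float32 range (about 3.4e38) in under 30 multiplications, consistent with the loop above breaking at `i = 28`. With a scale of 0.01 the ratio is about 0.23 per layer, so the activations shrink towards 0. With the Xavier scale the ratio stays close to 1.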