**Weight Initialization**

Let's continue our exploration of PyTorch operations with weight initialization. We'll stick with our simple four-layer CNN, which we'll train on MNIST. The network is defined in the `Net` class from `pytorch/examples/mnist/main.py`. To run it, we need to import `torch` along with its `nn` and `nn.functional` modules:

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Now we can go ahead and define our `Net` class:

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

Before passing data through our network, we need to initialize our weights and biases. How is this done? Let's examine `nn.Conv2d`. This is defined in `nn/modules/conv.py`. The `Conv2d` class inherits from `_ConvNd`. Here, we see that parameters are initialized using the `reset_parameters` method, which uses `init.kaiming_uniform_` to initialize weights and `init.uniform_` to initialize biases.

As for `nn.Linear`, this is defined in `nn/modules/linear.py`. We see what appears to be the exact same `reset_parameters` method. This uses Kaiming uniform initialization for weights with `a=math.sqrt(5)`, and uniform initialization for biases with lower and upper bounds of `-1 / math.sqrt(fan_in)` and `1 / math.sqrt(fan_in)`, where `fan_in` is computed by the `_calculate_fan_in_and_fan_out` function in `init`.
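As a rough check of the description above, the following sketch recomputes both bounds for a layer shaped like `fc1`. It assumes the defaults described in this section (Kaiming uniform with `a=math.sqrt(5)`), which may differ between PyTorch versions, and it relies on the private helper `_calculate_fan_in_and_fan_out`, so treat it as illustrative rather than as a stable API.

```python
import math

import torch
import torch.nn as nn

fc1 = nn.Linear(4 * 4 * 50, 500)  # same shape as fc1 in Net

# Private helper named in the text; not part of the public API.
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(fc1.weight)

# Kaiming uniform bound: gain * sqrt(3 / fan_in), with gain = sqrt(2 / (1 + a^2)).
gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
weight_bound = gain * math.sqrt(3.0 / fan_in)
bias_bound = 1 / math.sqrt(fan_in)

print("fan_in, fan_out:", fan_in, fan_out)
print("weight bound:", weight_bound, "bias bound:", bias_bound)
print("largest |weight|:", fc1.weight.abs().max().item())
print("largest |bias|:", fc1.bias.abs().max().item())
```

With `a = sqrt(5)` the two bounds coincide at `1 / sqrt(fan_in)`, so under these assumptions the default weights and biases are drawn from the same range.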

But let's say we want to initialize our weights manually. How do we do this? Let's first consider our `conv1` layer: 

In [3]:
conv1 = nn.Conv2d(1, 20, 5, 1)

Want to see something cool? Calling `nn.Conv2d` already initialized our weights and biases. Let's examine our weights:

In [4]:
conv1.weight

Parameter containing:
tensor([[[[ 0.0669,  0.1622,  0.1756, -0.0422,  0.1894],
          [-0.0522,  0.0733,  0.0707,  0.0096, -0.0305],
          [-0.0752, -0.1516,  0.0352, -0.0037,  0.1162],
          [-0.0378,  0.1643,  0.0995,  0.0105, -0.0450],
          [-0.1489,  0.0070, -0.1296,  0.0568, -0.0617]]],


        [[[ 0.0550,  0.1872, -0.1869, -0.0354,  0.1116],
          [-0.1696,  0.0501, -0.1834,  0.0462, -0.0881],
          [-0.1783, -0.1509, -0.1834, -0.0580, -0.1496],
          [ 0.1374, -0.0101,  0.0193,  0.1987, -0.1191],
          [ 0.0914, -0.0171, -0.1371, -0.0735,  0.1027]]],


        [[[-0.0061,  0.0697,  0.1853,  0.0829,  0.0110],
          [ 0.1133,  0.1665, -0.1011, -0.1990,  0.0659],
          [ 0.0577, -0.0833, -0.0402,  0.1506, -0.1789],
          [-0.0279, -0.1408,  0.0101,  0.0709, -0.1537],
          [-0.1510, -0.0629,  0.0632, -0.0332, -0.0309]]],


        [[[-0.1224, -0.0527,  0.0092, -0.1067,  0.1646],
          [-0.1782,  0.1241, -0.0065, -0.0518, -0.0702

We see that our initialization creates a weight tensor of 20 5x5 filters. If we call `conv1.weight.shape`, we see these three dimensions, along with a fourth dimension of size 1. This is the number of input channels (MNIST images are grayscale), not a batch size:

In [5]:
conv1.weight.shape

torch.Size([20, 1, 5, 5])
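To make the layout of this shape concrete, here is a small sketch that unpacks it and then passes a batch through the layer. The batch of 8 random 28x28 images is an assumption standing in for real MNIST data.

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 20, 5, 1)

# Conv2d weights are laid out as (out_channels, in_channels, kernel_h, kernel_w).
out_channels, in_channels, kernel_h, kernel_w = conv1.weight.shape
print(out_channels, in_channels, kernel_h, kernel_w)  # 20 1 5 5

# The batch dimension belongs to the input, not to the weights.
batch = torch.randn(8, 1, 28, 28)  # random stand-in for 8 MNIST images
out = conv1(batch)
print(out.shape)           # torch.Size([8, 20, 24, 24])
print(conv1.weight.shape)  # unchanged: torch.Size([20, 1, 5, 5])
```

The weight tensor keeps the same shape whatever batch size we feed in.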

Our bias vector adds a constant value to each of the 20 channels:

In [6]:
conv1.bias

Parameter containing:
tensor([ 0.1917,  0.0543,  0.0215, -0.1588, -0.1407,  0.0339,  0.1019, -0.1864,
        -0.1028,  0.1651,  0.1361,  0.0833, -0.1999,  0.1216,  0.0081,  0.0904,
        -0.1882, -0.0160, -0.0861, -0.1152], requires_grad=True)

The shape of this vector is just its length. There are no other dimensions:

In [7]:
conv1.bias.shape

torch.Size([20])

If we train our network with default `kaiming_uniform` initialization, we get the following average loss and accuracy on `examples/mnist/main.py`. We'll just do five epochs. It's a relatively simple classification task, so we get fairly quick convergence:

`Epoch     Average Loss           Accuracy
   1          .1017                .9669
   2          .0614                .9828
   3          .0562                .9809
   4          .0409                .9864
   5          .0384                .9873`

What if we want to change the initialization of our two convolutional and two linear layers? We can use a function from `torch.nn.init`. For example, what if we were to initialize everything with zeros?

In [8]:
torch.nn.init.constant_(conv1.weight, 0)


Parameter containing:
tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]

We see that this changes all of our weights to zeros. What if we were to do this to all of the weights and biases in each of our layers? We could define our layers in a list:

`layers = [self.conv1, self.conv2, self.fc1, self.fc2]`

and then loop over this list:

`for layer in layers:
     torch.nn.init.constant_(layer.weight, 0)
     torch.nn.init.constant_(layer.bias, 0)`
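To make the loop above concrete, here is a minimal sketch that applies it to a `Net` instance from outside the class (so `self` becomes `model`) and inspects the gradients after one backward pass. The random input and the labels are assumptions standing in for a real MNIST batch.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):  # same architecture as the Net class above
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = F.relu(self.fc1(x.view(-1, 4 * 4 * 50)))
        return F.log_softmax(self.fc2(x), dim=1)


model = Net()
layers = [model.conv1, model.conv2, model.fc1, model.fc2]
for layer in layers:
    nn.init.constant_(layer.weight, 0)
    nn.init.constant_(layer.bias, 0)

# Random tensors standing in for one MNIST batch (an assumption for illustration).
x = torch.randn(4, 1, 28, 28)
target = torch.tensor([0, 1, 2, 3])

loss = F.nll_loss(model(x), target)
loss.backward()

# Total absolute gradient per parameter tensor.
grad_sums = {name: p.grad.abs().sum().item() for name, p in model.named_parameters()}
print("loss:", loss.item(), "ln(10):", math.log(10))
for name, value in grad_sums.items():
    print(name, value)
```

In this sketch the loss starts at ln(10) ≈ 2.3026 (a uniform guess over 10 classes), and only `fc2.bias` receives a nonzero gradient: every other gradient is multiplied by a zero weight or a zero activation on its way back through the network.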

As you might expect, our network doesn't learn anything with zero initialization. Our average loss is constant at 2.3010, and our accuracy is .1135. Our network is built to classify one of 10 classes, so this isn't much better than randomly guessing. What if we changed our initialization from zeros to ones?

`for layer in layers:
     torch.nn.init.constant_(layer.weight, 1)
     torch.nn.init.constant_(layer.bias, 1)`

This doesn't help matters: we're still stuck with a network that isn't learning. There are dedicated functions, `ones_` and `zeros_`, that do constant initialization with these values, but neither helps us here. We run into a different problem with `normal_` initialization, where our loss quickly goes to `nan`.
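One plausible contributing factor is symmetry: when every weight in a layer holds the same value, every output unit computes the same thing. The sketch below illustrates this for ones initialization. It runs in double precision (an assumption made only to keep rounding noise small, since the activations become very large) and uses random inputs in place of MNIST images.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):  # same architecture as the Net class above
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = F.relu(self.fc1(x.view(-1, 4 * 4 * 50)))
        return F.log_softmax(self.fc2(x), dim=1)


model = Net().double()
for param in model.parameters():
    nn.init.ones_(param)  # weights and biases all set to one

# Random pixel values in [0, 1) standing in for MNIST images (an assumption).
x = torch.rand(4, 1, 28, 28, dtype=torch.float64)
with torch.no_grad():
    log_probs = model(x)

# Largest gap between the most and least likely class within any single image.
spread = (log_probs.max(dim=1).values - log_probs.min(dim=1).values).max().item()
print("log-probabilities for first image:", log_probs[0])
print("max spread across classes:", spread)
print("ln(1/10):", -math.log(10))
```

Every class ends up with a log-probability of about ln(1/10), so the initial prediction is uniform regardless of the input image.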

What if our biases are interfering with our training process? We could try initializing our biases to zero instead. Let's see if we can get our network to train with weights initialized to one and biases initialized to zero:

`for layer in layers:
     torch.nn.init.ones_(layer.weight)
     torch.nn.init.zeros_(layer.bias)`

This doesn't appear to work either, nor does it work with weights uniformly sampled from the default range (0, 1), nor with weights sampled from the normal distribution (mean: 0, standard deviation: 1.0). At least for our simple convolutional network, we need something better. PyTorch gives us a few options that will work for our `nn.Conv2d` layers: `xavier_uniform_`, `xavier_normal_`, `kaiming_uniform_` (the one we previously used to initialize weights), `kaiming_normal_`, and `orthogonal_`. We'll ignore `sparse_`, which requires us to set a fraction of our weights to zero.

Let's start with the definition of `xavier_uniform_`. This samples our weights uniformly from $(-a, a)$, where the bound $a$ is given by the following formula, optionally multiplied by a `gain` (set by default to 1.0):

$a = \sqrt{\frac{6}{fan\_in + fan\_out}}$

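where `fan_in` is the number of inputs and `fan_out` is the number of outputs. For a convolutional layer, these are the number of input and output channels, each multiplied by the number of elements in the kernel. To examine this, we'll use `xavier_uniform_` to initialize our `conv1` weights: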

In [9]:
torch.nn.init.xavier_uniform_(conv1.weight)


Parameter containing:
tensor([[[[-0.0942, -0.0483, -0.0113,  0.0327,  0.0735],
          [ 0.0142, -0.0276, -0.0245,  0.0523,  0.0204],
          [ 0.1066,  0.0465, -0.0645, -0.0937,  0.1053],
          [ 0.0195, -0.0727,  0.0543,  0.0214, -0.0776],
          [ 0.0332,  0.0985,  0.0767, -0.0131, -0.0736]]],


        [[[ 0.0087, -0.0006, -0.0852, -0.0240, -0.0907],
          [ 0.0076,  0.0365,  0.0887, -0.0775,  0.0775],
          [-0.0799,  0.0880, -0.0429, -0.0929,  0.0177],
          [-0.0411, -0.0939,  0.0384,  0.0185,  0.0767],
          [ 0.0339, -0.0310,  0.1029,  0.0458,  0.0971]]],


        [[[-0.0427,  0.0389,  0.0045, -0.0134,  0.0908],
          [-0.0892,  0.0697,  0.0492,  0.0887,  0.0481],
          [ 0.1003,  0.1027, -0.0305,  0.0905, -0.0963],
          [-0.0206, -0.0073, -0.0381,  0.0650, -0.0942],
          [-0.1059,  0.0646, -0.0283, -0.1051, -0.0370]]],


        [[[-0.0484, -0.1037,  0.0298,  0.0819, -0.0948],
          [ 0.0238,  0.0971,  0.0405,  0.0370,  0.0016
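As a sanity check on the formula for $a$, the following sketch computes `fan_in`, `fan_out`, and the resulting bound for `conv1` by hand, then confirms that the sampled weights stay inside it. The variable names are illustrative, and the default `gain` of 1.0 is assumed.

```python
import math

import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 20, 5, 1)
out_ch, in_ch, k_h, k_w = conv1.weight.shape

# For a conv layer, each fan is a channel count times the kernel area.
fan_in = in_ch * k_h * k_w    # 1 * 5 * 5 = 25
fan_out = out_ch * k_h * k_w  # 20 * 5 * 5 = 500

a = math.sqrt(6.0 / (fan_in + fan_out))  # about 0.1069 with gain = 1.0

nn.init.xavier_uniform_(conv1.weight)
max_abs = conv1.weight.abs().max().item()

print("fan_in:", fan_in, "fan_out:", fan_out)
print("bound a:", a)
print("largest |weight|:", max_abs)
```

A bound of roughly 0.107 agrees with the output above, where no weight exceeds about 0.107 in magnitude.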

Note that the mean value of our weights is now close to zero, unlike with `ones_` or `uniform_` initialization (`zeros_` and `normal_` also give means at or near zero). For completeness, the comparison below also includes the Kaiming and orthogonal initializers:

In [10]:
print("Mean initialization")
torch.nn.init.zeros_(conv1.weight)
print ("Zeros: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.ones_(conv1.weight)
print ("Ones: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.uniform_(conv1.weight)
print ("Uniform: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.normal_(conv1.weight)
print ("Normal: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.xavier_uniform_(conv1.weight)
print ("Xavier uniform: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.xavier_normal_(conv1.weight)
print ("Xavier normal: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.kaiming_uniform_(conv1.weight)
print ("Kaiming uniform: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.kaiming_normal_(conv1.weight)
print ("Kaiming normal: " + str(torch.mean(conv1.weight).item()))
torch.nn.init.orthogonal_(conv1.weight)
print ("Orthogonal: " + str(torch.mean(conv1.weight).item()))

Mean initialization
Zeros: 0.0
Ones: 1.0
Uniform: 0.49642595648765564
Normal: 0.007640771102160215
Xavier uniform: -0.002580041531473398
Xavier normal: 0.001514861243776977
Kaiming uniform: -0.011375376023352146
Kaiming normal: -0.00296521233394742
Orthogonal: 0.006012782920151949
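Because most of these initializers are centered on zero, their means look alike, and the spread of the weights is often the more telling statistic. The sketch below repeats the comparison with the standard deviation instead. The fixed seed and the dictionary of initializers are choices made for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so repeated runs print the same numbers
conv1 = nn.Conv2d(1, 20, 5, 1)

initializers = {
    "Uniform": nn.init.uniform_,
    "Normal": nn.init.normal_,
    "Xavier uniform": nn.init.xavier_uniform_,
    "Xavier normal": nn.init.xavier_normal_,
    "Kaiming uniform": nn.init.kaiming_uniform_,
    "Kaiming normal": nn.init.kaiming_normal_,
    "Orthogonal": nn.init.orthogonal_,
}

stds = {}
for name, init_fn in initializers.items():
    init_fn(conv1.weight)  # re-initialize the same tensor in place
    stds[name] = conv1.weight.detach().std().item()
    print(f"{name}: std = {stds[name]:.4f}")
```

For this layer, where `fan_in` (25) is much smaller than `fan_out` (500), the Kaiming initializers should show a noticeably larger standard deviation than the Xavier ones.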


Let's record average loss and accuracy for our Xavier uniform initialization. You can compare this with the "initial" Kaiming uniform initialization, but we'll do a separate run with zero-initialized biases so that we get a stricter comparison across different forms of weight initialization:

`Epoch     Average Loss           Accuracy
   1          .0745                .9781
   2          .0587                .9817
   3          .0434                .9859
   4          .0446                .9852
   5          .0346                .9884`

This is actually better than what we got with Kaiming uniform. Let's compare now with `xavier_normal_`. Here we sample from a normal distribution with a mean of 0 and a variance of $\sigma^{2}$, where

$\sigma = \sqrt{\frac{2}{fan\_in + fan\_out}}$

Let's compare average loss:

`Epoch     Xavier uniform       Xavier normal
   1          .0745                .0921
   2          .0587                .0702
   3          .0434                .0478
   4          .0446                .0414
   5          .0346                .0423`

and accuracy:

`Epoch     Xavier uniform       Xavier normal
   1          .9781                .9717
   2          .9817                .9773
   3          .9859                .9848
   4          .9852                .9859
   5          .9884                .9884`
   
Our loss is higher with `xavier_normal_` in four of the five epochs, but we end up with the same accuracy after five epochs of training.

Like Xavier uniform, Kaiming uniform samples uniformly between a lower and upper bound $(-bound, bound)$. This time we use the following formula to compute our $bound$:

$bound = \sqrt{\frac{6}{(1 + a^{2}) \cdot fan\_in}}$

We set $a$ to the negative slope of the rectifier (e.g. for a leaky ReLU activation function). Since we're using ReLU, our negative values are set to zero, and thus we'll just use the default value of 0 for $a$. We end up with the following:

$bound = \sqrt{\frac{6}{fan\_in}}$

Since $fan\_in < fan\_in + fan\_out$, this bound is always larger than the Xavier uniform bound with the default `gain`, so Kaiming uniform weights are spread over a wider range (both distributions are centered on zero). Note that the layer defaults described earlier use `a=math.sqrt(5)` instead, which gives a smaller bound of $\sqrt{\frac{1}{fan\_in}}$. We also have the option of using $fan\_out$ instead of $fan\_in$ to compute our $bound$ by passing `mode='fan_out'`.
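To put numbers on this, the sketch below evaluates the general formula $\sqrt{6 / ((1 + a^{2}) \cdot fan)}$ for `conv1` with $a = 0$, using either `fan_in` or `fan_out`, and compares the results with the Xavier uniform bound. The helper `kaiming_bound` is a hypothetical name introduced here, not a PyTorch function.

```python
import math

import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 20, 5, 1)
fan_in = 1 * 5 * 5    # input channels times kernel area
fan_out = 20 * 5 * 5  # output channels times kernel area


def kaiming_bound(fan, a=0.0):
    """Hypothetical helper: bound of the Kaiming uniform distribution."""
    return math.sqrt(6.0 / ((1 + a ** 2) * fan))


xavier_bound = math.sqrt(6.0 / (fan_in + fan_out))  # about 0.107
bound_in = kaiming_bound(fan_in)                     # about 0.490
bound_out = kaiming_bound(fan_out)                   # about 0.110

nn.init.kaiming_uniform_(conv1.weight, a=0, mode="fan_in")
max_in = conv1.weight.abs().max().item()

nn.init.kaiming_uniform_(conv1.weight, a=0, mode="fan_out")
max_out = conv1.weight.abs().max().item()

print("Xavier bound:", xavier_bound)
print("Kaiming bound (fan_in):", bound_in, "largest |weight|:", max_in)
print("Kaiming bound (fan_out):", bound_out, "largest |weight|:", max_out)
```

For this layer the `fan_in` bound is several times wider than the Xavier bound, while the `fan_out` bound is close to it because `fan_out` dominates the sum `fan_in + fan_out`.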

Let's add Kaiming uniform to our average loss and accuracy tables:

Average loss:

`Epoch   Xavier uniform    Xavier normal    Kaiming uniform
   1         .0745             .0921            .0615
   2         .0587             .0702            .0501
   3         .0434             .0478            .0416
   4         .0446             .0414            .0441
   5         .0346             .0423            .0360`
   
Accuracy:

`Epoch   Xavier uniform    Xavier normal    Kaiming uniform
   1         .9781             .9717            .9807
   2         .9817             .9773            .9839
   3         .9859             .9848            .9861
   4         .9852             .9859            .9860
   5         .9884             .9884            .9883`

Now let's add Kaiming normal to our loss and accuracy tables:

Average loss:

`Epoch  Xavier uni  Xavier norm   Kaiming uni   Kaiming norm
   1      .0745        .0921        .0615         .0824
   2      .0587        .0702        .0501         .0654
   3      .0434        .0478        .0416         .0489
   4      .0446        .0414        .0441         .0416
   5      .0346        .0423        .0360         .0452`
   
Accuracy:

`Epoch  Xavier uni   Xavier norm  Kaiming uni  Kaiming norm
   1      .9781        .9717        .9807         .9733
   2      .9817        .9773        .9839         .9771
   3      .9859        .9848        .9861         .9840
   4      .9852        .9859        .9860         .9851
   5      .9884        .9884        .9883         .9852`

Finally, `orthogonal_` initializes the tensor with an orthogonal matrix. See Saxe et al. 2014: https://arxiv.org/pdf/1312.6120.pdf. This yields the following:

Average loss:

`Epoch Xavier uni  Xavier norm  Kaiming uni  Kaiming norm  Orthogonal
   1      .0745      .0921        .0615         .0824        .0855
   2      .0587      .0702        .0501         .0654        .0613
   3      .0434      .0478        .0416         .0489        .0499
   4      .0446      .0414        .0441         .0416        .0354
   5      .0346      .0423        .0360         .0452        .0448`
   
Accuracy:

`Epoch Xavier uni  Xavier norm  Kaiming uni  Kaiming norm  Orthogonal
   1      .9781      .9717        .9807         .9733        .9737
   2      .9817      .9773        .9839         .9771        .9805
   3      .9859      .9848        .9861         .9840        .9838
   4      .9852      .9859        .9860         .9851        .9885
   5      .9884      .9884        .9883         .9852        .9853`

In terms of accuracy, there isn't too much variation between these functions, at least for this simple image classification task.
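For anyone who wants to run comparisons like these, one possible way to swap initializers is `Module.apply`, which calls a function on every submodule. The sketch below is an assumption about how such a run could be set up, not the exact code behind the tables above. The helper `make_init_fn` and the small stand-in model are hypothetical.

```python
import torch
import torch.nn as nn


def make_init_fn(weight_init):
    """Hypothetical helper: build a function for Module.apply."""
    def init_fn(module):
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weight_init(module.weight)   # initializer under comparison
            nn.init.zeros_(module.bias)  # zero biases for a like-for-like comparison
    return init_fn


# Hypothetical stand-in model; the same call works on a Net instance.
model = nn.Sequential(
    nn.Conv2d(1, 20, 5, 1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(20 * 24 * 24, 10),
)

for weight_init in (nn.init.xavier_uniform_, nn.init.kaiming_normal_, nn.init.orthogonal_):
    model.apply(make_init_fn(weight_init))
    print(weight_init.__name__, model[0].weight.detach().std().item())
```

Keeping the bias initialization fixed means that any difference between runs comes from the weight initializer alone.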

Out of curiosity, how does PyTorch initialize other layers? This is done in `nn.modules`, typically in a `reset_parameters` method that each `Module`'s constructor calls.

Attention layers initialize weights with Xavier uniform, `in_proj_bias` and `out_proj_bias` with zeros, and `bias_k` and `bias_v` with Xavier normal.

Batch, group, instance, and layer normalization layers initialize weights with ones and biases with zeros.

Convolutional, linear, and bilinear layers initialize weights with Kaiming uniform and biases uniformly, with `bound` defined as `1 / math.sqrt(fan_in)`.

RNN layers initialize weights uniformly, with `bound` defined as `1 / math.sqrt(self.hidden_size)`.

Sparse layers initialize weights normally.

Transformer layers initialize weights with Xavier uniform.
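A few of these defaults are easy to spot-check by constructing the layers and inspecting their parameters. The sketch below does so for batch normalization, layer normalization, embedding, and RNN layers. The layer sizes are arbitrary choices for illustration, and the defaults themselves may vary between PyTorch versions.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so repeated runs print the same numbers

bn = nn.BatchNorm2d(20)
ln = nn.LayerNorm(500)
emb = nn.Embedding(1000, 64)
rnn = nn.RNN(input_size=10, hidden_size=16)

# Normalization layers: weights of one, biases of zero.
print("BatchNorm weight range:", bn.weight.min().item(), bn.weight.max().item())
print("BatchNorm bias range:", bn.bias.min().item(), bn.bias.max().item())
print("LayerNorm weight range:", ln.weight.min().item(), ln.weight.max().item())

# Embedding (a sparse layer): weights drawn from a standard normal distribution.
emb_mean = emb.weight.detach().mean().item()
emb_std = emb.weight.detach().std().item()
print("Embedding mean and std:", emb_mean, emb_std)

# RNN: all parameters uniform within 1 / sqrt(hidden_size).
rnn_bound = 1 / math.sqrt(rnn.hidden_size)
rnn_max = max(p.detach().abs().max().item() for p in rnn.parameters())
print("RNN bound:", rnn_bound, "largest |parameter|:", rnn_max)
```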