Example code showing how to initialize the weights of a simple CNN.
This is usually not needed, as the default initialization is typically good,
but sometimes it can be useful to initialize weights in a specific way.
This approach should generalize to other network types; just make
sure to specify and adapt the module types you wish to modify.

Video explanation: https://youtu.be/xWQ-p_o0Uik
Got any questions? Leave a comment on YouTube :)

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
*    2020-04-10 Initial coding
*    2022-12-16 Updated with more detailed comments, and checked code still functions as intended.


In [1]:
# Imports
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.nn.functional as F  # All functions that don't have any parameters


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CNN(nn.Module):
    """Small two-convolution CNN used to demonstrate custom weight initialization.

    Layout: conv(3x3) -> ReLU -> 2x2 max-pool -> conv(3x3) -> ReLU ->
    2x2 max-pool -> flatten -> linear.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits produced by ``fc1``.

    Note:
        ``fc1`` is hard-coded to ``16 * 7 * 7`` input features. The 3x3 convs
        (stride 1, padding 1) keep the spatial size and each pool halves it,
        so inputs must reduce to 7x7 after two poolings (e.g. 28x28 images);
        other sizes will fail with a shape mismatch at ``fc1``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # A single pooling layer is reused after both convolutions.
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)
        # Overwrite PyTorch's default initialization with our own scheme.
        self.initialize_weights()

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(batch, num_classes)``."""
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)

        return x

    def initialize_weights(self):
        """Re-initialize the parameters of every supported submodule in place.

        - ``nn.Conv2d`` / ``nn.Linear``: Kaiming (He) uniform weights, zero bias.
        - ``nn.BatchNorm2d``: weight set to 1, bias set to 0.

        Bias and affine parameters are only touched when they exist. This means
        layers built with ``bias=False`` or ``affine=False`` are also handled.
        Modules of any other type keep their default initialization.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # "relu" yields the same gain (sqrt(2)) as the default
                # leaky_relu with a=0, but states the intent explicitly.
                nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when the layer is built with affine=False.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


In [3]:
if __name__ == "__main__":
    # Build a demo model (3-channel input, 10 classes) and dump every
    # learnable tensor so the custom initialization can be inspected.
    demo_model = CNN(in_channels=3, num_classes=10)
    learnable = demo_model.parameters()

    for tensor in learnable:
        print(tensor)

Parameter containing:
tensor([[[[-0.3707, -0.3078,  0.4500],
          [ 0.0221, -0.2626, -0.4608],
          [ 0.3678, -0.2048, -0.4456]],

         [[ 0.3499, -0.0067, -0.0389],
          [-0.1485, -0.1661, -0.4206],
          [-0.3091,  0.4023, -0.0767]],

         [[ 0.1896,  0.4051, -0.0359],
          [ 0.1109, -0.0946, -0.1071],
          [ 0.1285,  0.0014,  0.2967]]],


        [[[ 0.2722, -0.2402, -0.0798],
          [-0.3548,  0.1059,  0.0581],
          [-0.2836, -0.1344, -0.0522]],

         [[-0.2061,  0.1999, -0.0642],
          [-0.2320,  0.0063,  0.2366],
          [-0.1226,  0.1589, -0.3114]],

         [[-0.3519,  0.0646,  0.4294],
          [-0.1450,  0.1241,  0.2485],
          [-0.2274, -0.3978,  0.0334]]],


        [[[ 0.0414,  0.0404, -0.1024],
          [ 0.0851,  0.2997, -0.4588],
          [-0.2625,  0.2864, -0.2734]],

         [[-0.2386,  0.3928,  0.0985],
          [-0.0172,  0.3967,  0.1748],
          [-0.2097,  0.3347, -0.2104]],

         [[ 0.4281, -0