In [5]:
from torch import nn
import torch

In [18]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5), bias=True, padding=2)
        self.relu1 = nn.ReLU()
        self.pooling2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv3 = nn.Conv2d(6, 16, kernel_size=(5, 5), bias=True)
        self.relu3 = nn.ReLU()
        self.pooling4 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv5 = nn.Conv2d(16, 120, kernel_size=(5, 5), bias=True)
        
        self.dense6 = nn.Linear(120, 84)
        self.relu6 = nn.ReLU()
        self.dense7 = nn.Linear(84, 10)
        self.sig8 = nn.LogSoftmax(dim=-1)
        self._init_params()
        
    def forward(self, _input):
        x = self.conv1(_input)
        x = self.relu1(x)
        x = self.pooling2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.pooling4(x)
        x = self.conv5(x)
        
        x = x.view(_input.size(0), -1)
        
        x = self.dense6(x)
        x = self.relu6(x)
        x = self.dense7(x)
        x = self.sig8(x)
        return x
    
    def _init_params(self):
        for module in self.named_modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
        

In [19]:
_input = torch.randn((3, 1, 28, 28))
model = LeNet()
output = model(_input)

In [20]:
output

tensor([[-2.2312, -2.3996, -2.1710, -2.2208, -2.3450, -2.4331, -2.4218, -2.3445,
         -2.3107, -2.1910],
        [-2.2362, -2.3797, -2.2014, -2.1526, -2.3577, -2.4393, -2.3988, -2.3505,
         -2.3341, -2.2180],
        [-2.2602, -2.4045, -2.1986, -2.1572, -2.3690, -2.4040, -2.4005, -2.3279,
         -2.3203, -2.2219]], grad_fn=<LogSoftmaxBackward>)

In [17]:
output.shape

torch.Size([3, 10])

In [None]:
def _init_params(self):
        for module in self.named_modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)