In [None]:
import torch

In [None]:
class LeNet5(torch.nn.Module):
    """LeNet-5-style convolutional network for 28x28 images.

    The original LeNet-5 took 32x32 inputs. Here ``conv1`` uses ``padding=2``
    so a 28x28 input yields the same 28x28 first feature map. Spatial pipeline::

        28x28 --conv1(k=5, pad=2)--> 28x28 --pool(2)--> 14x14
              --conv2(k=5)-------> 10x10 --pool(2)-->  5x5
        => 16 channels * 5 * 5 = 400 features -> 120 -> 84 -> num_classes

    A sigmoid follows every conv / hidden linear layer. The final layer
    returns raw scores with no activation.

    Args:
        in_channels: Number of input image channels (default 1, grayscale).
        num_classes: Number of output scores produced by ``fc3`` (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Layers are created in the same order as before so that, under a fixed
        # seed, parameter initialisation and state_dict key order are unchanged.
        # Convolution; padding=2 keeps a 28x28 input at 28x28 (see class docstring).
        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1, padding=2, bias=True)
        # Max-pooling, halves H and W
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=2)
        # Convolution without padding: 14x14 -> 10x10
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, bias=True)
        # Max-pooling, halves H and W: 10x10 -> 5x5
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=2)
        # Fully connected classifier head: 16*5*5 (= 400) -> 120 -> 84 -> num_classes
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores.

        Args:
            x: Image batch of shape (N, in_channels, 28, 28). The 28x28 size is
               what the 400-feature ``fc1`` input requires.

        Returns:
            Tensor of shape (N, num_classes) holding un-normalised scores.

        Raises:
            RuntimeError: if the spatial size does not produce 400 features
                per sample (raised by ``fc1``).
        """
        # convolve, then sigmoid non-linearity, then 2x2 max-pooling
        x = self.max_pool_1(torch.sigmoid(self.conv1(x)))
        # convolve, then sigmoid non-linearity, then 2x2 max-pooling
        x = self.max_pool_2(torch.sigmoid(self.conv2(x)))
        # Flatten only the trailing (C, H, W) dims so the batch dim is kept.
        # `view(-1, 400)` would instead infer the batch size and silently mix
        # features of different samples when the input is not 28x28.
        # With flatten, such inputs fail loudly in fc1.
        x = torch.flatten(x, start_dim=-3)
        # FC-1 and FC-2, each followed by sigmoid non-linearity
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        # FC-3: raw scores, no activation
        return self.fc3(x)

In [None]:
# Seed PyTorch's global RNG so the randomly initialised weights (and the
# torch.rand input drawn in a later cell) are identical on every
# Restart & Run All.
torch.manual_seed(0)
# create an instance of the network
net = LeNet5()

In [None]:
# Dummy batch holding one single-channel 28x28 image, values uniform in [0, 1).
X = torch.rand(1, 1, 28, 28)

In [None]:
# Forward pass: expect one score per class (10) for every sample in the batch.
y = net(X)
assert y.shape == (X.shape[0], 10), f"unexpected output shape {tuple(y.shape)}"
y.shape