In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
in_dim = 1
num_class = 10
fully_connect= [84]

In [4]:
class Lenet(nn.Module):
    def __init__(self, in_dim, num_class):
        super(Lenet, self).__init__()
        self.in_dim = in_dim
        self.num_class = num_class
        
        self.conv1 = nn.Conv2d(in_dim, 6, kernel_size = 5)
        self.pool1 = nn.AvgPool2d(kernel_size =2, stride = 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size = 5)
        self.pool2 = nn.AvgPool2d(kernel_size =2, stride = 2)

    
        
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = self.pool1(out)
        
        out = self.conv2(out)
        out = F.relu(out)
        out = self.pool2(out)
        
        out = torch.flatten(out, 1)
        
        out = nn.Linear(out.shape[1],fully_connect[0])
        out = nn.Linear(fully_connect[0],num_class)
        out = nn.Softmax(out)
        return out
        
        
        
        
        

In [5]:
model = Lenet(in_dim, num_class).to(device)

In [6]:
model

Lenet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
)

In [7]:
import torchsummary
torchsummary.summary(model,(1,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             156
         AvgPool2d-2            [-1, 6, 14, 14]               0
            Conv2d-3           [-1, 16, 10, 10]           2,416
         AvgPool2d-4             [-1, 16, 5, 5]               0
Total params: 2,572
Trainable params: 2,572
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.01
Estimated Total Size (MB): 0.07
----------------------------------------------------------------
