In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import style

import torch 
from torch import nn

from torchsummary import summary

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
class lenet(nn.Module):
    """LeNet-5-style CNN: two conv/avg-pool stages followed by three fully connected layers.

    Each 5x5 convolution uses padding 2 and stride 1, so it preserves the spatial size.
    Each 2x2 / stride-2 average pool halves the size (rounding down). With the default
    28x28 input, the classifier therefore sees a 16 x 7 x 7 feature map.

    Args:
        in_channels: number of channels in the input images (1 by default).
        input_shape: ``(height, width)`` of the input images; used to size ``fc1``.
            Both dimensions must be at least 4 so that both pooling stages are valid.
        num_classes: number of logits produced by the final layer.

    Raises:
        ValueError: if ``input_shape`` is smaller than 4x4.
    """

    def __init__(self, in_channels=1, input_shape=(28, 28), num_classes=10):
        super().__init__()
        height, width = input_shape
        if height < 4 or width < 4:
            raise ValueError(f"input_shape must be at least (4, 4), got {tuple(input_shape)}")
        # The padded convs keep H x W and the two floor-halving pools give H // 4 x W // 4,
        # with 16 channels coming out of conv2.
        self.flat_features = 16 * (height // 4) * (width // 4)

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=(5, 5),
                               padding=(2, 2), stride=(1, 1))
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5),
                               padding=(2, 2), stride=(1, 1))
        # A single parameter-free pooling module is reused after both conv stages.
        self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.fc1 = nn.Linear(in_features=self.flat_features, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape ``(batch, num_classes)``.

        Args:
            x: image batch of shape ``(batch, in_channels, height, width)``, where
                ``(height, width)`` matches the ``input_shape`` given at construction.
        """
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        # Flatten everything except the batch dimension.
        x = x.flatten(start_dim=1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)

In [4]:
# Build the network with its default configuration (in_channels=1, input_shape=(28, 28))
# and move its parameters to the selected device.
model = lenet().to(device)
# Pass `device` explicitly so torchsummary builds its dummy input on the same device as the
# model instead of relying on its own default. The (C, H, W) size must match the
# in_channels/input_shape the model was built with.
summary(model, (1, 28, 28), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             156
         AvgPool2d-2            [-1, 6, 14, 14]               0
            Conv2d-3           [-1, 16, 14, 14]           2,416
         AvgPool2d-4             [-1, 16, 7, 7]               0
            Linear-5                  [-1, 120]          94,200
            Linear-6                   [-1, 84]          10,164
            Linear-7                   [-1, 10]             850
Total params: 107,786
Trainable params: 107,786
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.41
Estimated Total Size (MB): 0.49
----------------------------------------------------------------
