In [25]:
# Python imports
import os

# PyTorch imports
import torch
import torchvision

# Third-party imports
import matplotlib.pyplot as plt

# Own imports
import utils

# 0. Constants

In [26]:
WIDTH, HEIGHT = 28, 28
CATEGORIES = 10
BATCH_SIZE = 10
LEARNING_RATE = 0.01
EPOCHS = 30
KERNEL_SIZES = [5, 5]
CHANNELS = [1, 6, 16]
FC_LAYER_SIZES = [120, 84]
SAVED_FILENAME = 'MNIST-CNN-LeNet'

# 1. Dataset

In [27]:
# Load datasets.
dataset_training = torchvision.datasets.MNIST('./data', download=True, train=True, transform=torchvision.transforms.ToTensor())
dataset_test = torchvision.datasets.MNIST('./data', download=True, train=False, transform=torchvision.transforms.ToTensor())

# Create data loaders.
dataloader_training = torch.utils.data.DataLoader(dataset_training, batch_size=BATCH_SIZE, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

# 2. Model

In [28]:
num_inputs_to_fc_layer = int((((WIDTH - 4) / 2 - 4) / 2) * (((HEIGHT - 4) / 2 - 4) / 2)) * CHANNELS[2]

model = torch.nn.Sequential(
    torch.nn.Conv2d(CHANNELS[0], CHANNELS[1], KERNEL_SIZES[0]),
    torch.nn.AvgPool2d(2),
    torch.nn.Conv2d(CHANNELS[1], CHANNELS[2], KERNEL_SIZES[1]),
    torch.nn.AvgPool2d(2),
    torch.nn.Flatten(),
    torch.nn.Linear(num_inputs_to_fc_layer, FC_LAYER_SIZES[0]),
    torch.nn.Linear(FC_LAYER_SIZES[0], FC_LAYER_SIZES[1]),
    torch.nn.Linear(FC_LAYER_SIZES[1], CATEGORIES),
    torch.nn.Softmax(dim=1),
)

print(model)

Sequential(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (4): Flatten(start_dim=1, end_dim=-1)
  (5): Linear(in_features=256, out_features=120, bias=True)
  (6): Linear(in_features=120, out_features=84, bias=True)
  (7): Linear(in_features=84, out_features=10, bias=True)
  (8): Softmax(dim=1)
)


## 2.1. Statistics of the Model

In [29]:
num_parameters_total = 0
for params in model.parameters():
    num_parameters = 1
    size = params.size()
    for dim_size in size:
        num_parameters *= dim_size
    num_parameters_total += num_parameters

print(f'Total number of parameters: {num_parameters_total}')

Total number of parameters: 44426
