In [9]:
# Python imports
import os
import math

# PyTorch imports
import torch
import torchvision

# Third-party imports
import matplotlib.pyplot as plt
from tabulate import tabulate

# 0. Constants

In [10]:
# Training hyperparameters.
BATCH_SIZE = 10
LEARNING_RATE = 0.1  # presumably the optimizer step size for a later training cell — not used in the cells above
EPOCHS = 10
# Image dimensions; WIDTH * HEIGHT is the flattened input size of every model.
WIDTH, HEIGHT = 28, 28
SAVED_FILENAME = 'MNIST-1-2'  # presumably the base name for saved results/checkpoints — TODO confirm against later cells
# Hidden-layer widths compared in this experiment; one model is built per entry.
LAYER_SIZES = [16, 32, 64, 128, 256]

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All. The trailing ';' suppresses the Generator repr.
SEED = 0
torch.manual_seed(SEED);

# 1. Dataset

In [11]:
# Load the MNIST training and test splits into ./data (download=True fetches them
# only if they are not already present). Both splits share one ToTensor transform,
# which converts each image into a float tensor.
to_tensor = torchvision.transforms.ToTensor()
dataset_training = torchvision.datasets.MNIST('./data', download=True, train=True, transform=to_tensor)
dataset_test = torchvision.datasets.MNIST('./data', download=True, train=False, transform=to_tensor)

# Create data loaders. Only the training loader is shuffled: reshuffling each epoch
# helps SGD, whereas evaluation metrics do not depend on sample order, so a fixed
# test order keeps evaluation passes deterministic.
dataloader_training = torch.utils.data.DataLoader(dataset_training, batch_size=BATCH_SIZE, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

# 2. Models

In [12]:
# Create one model per hidden-layer width in LAYER_SIZES; everything else is shared.
# Architecture: Flatten -> Linear(WIDTH*HEIGHT, width) -> ReLU -> Linear(width, 10) -> Softmax.
# The ReLU is what makes the hidden width matter. Without a non-linearity the two
# Linear layers compose into a single linear map, and because every width here is
# >= 10 (the output size), all models would represent exactly the same function class.
# ReLU has no parameters, so the parameter counts computed below are unaffected.
# NOTE(review): Softmax makes the models output probabilities. If the training cell
# uses torch.nn.CrossEntropyLoss (which applies log-softmax itself), this layer should
# be removed — confirm against the loss used later.
NUM_CLASSES = 10  # one output per digit class
models = [
    torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(WIDTH * HEIGHT, layer_size),
        torch.nn.ReLU(),
        torch.nn.Linear(layer_size, NUM_CLASSES),
        torch.nn.Softmax(dim=1),
    )
    for layer_size in LAYER_SIZES
]

## 2.1. Statistics of the Models

In [13]:
# Count every parameter (weights and biases) of each model. Tensor.numel() is the
# product of a tensor's dimensions, so summing it over model.parameters() gives the
# total number of scalar parameters — the same value the manual product loop computed.
nums_parameters = [sum(params.numel() for params in model.parameters()) for model in models]

headers = ['Layer Size', 'Number of Parameters']

# One table row per model: (hidden-layer width, parameter count).
rows = [[layer_size, num_parameters] for layer_size, num_parameters in zip(LAYER_SIZES, nums_parameters)]

# tabulate returns a string, so print() is needed to render it as plain text.
print(tabulate(rows, headers=headers, tablefmt='github'))

|   Layer Size |   Number of Parameters |
|--------------|------------------------|
|           16 |                  12730 |
|           32 |                  25450 |
|           64 |                  50890 |
|          128 |                 101770 |
|          256 |                 203530 |
