## Library import statements

In [16]:
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor, Normalize, \
                                   RandomRotation, InterpolationMode
from torch.utils.tensorboard import SummaryWriter

## Import project-local modules

In [17]:
import network
from train_val import train_loop, validation_loop, update_graphs

## Training hyper-parameter settings 

In [18]:
epochs = 10             # number of passes over the training set
learning_rate = 1e-3    # step size passed to the Adam optimizer
weight_decay = 0.001    # weight_decay passed to the Adam optimizer
mbatch_size = 32        # mini-batch size used by both data loaders
mbatch_group = -1       # NOTE(review): not referenced by any visible cell — TODO confirm where it is used
num_workers = 8         # number of DataLoader worker processes
seed = 42               # global RNG seed (weight init, shuffling, augmentation)

# Seed torch's global RNG before the model, the data loaders and the random
# augmentations draw any random numbers, so Restart & Run All reproduces the
# same run. DataLoader workers derive their seeds from this RNG as well.
# NOTE(review): some CUDA/cuDNN kernels may still be nondeterministic.
torch.manual_seed(seed)

## Create target device

In [19]:
# Pick the best available accelerator: CUDA first, then Apple's Metal (MPS)
# backend, and fall back to the CPU when neither is present.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device.")

Using cuda device.


## Load data

In [20]:
# Evaluation pipeline: convert the image to a float tensor, then normalise it.
# NOTE(review): (0.1307,) / (0.3081,) look like the usual MNIST mean/std — confirm.
test_transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
])

# Training pipeline: a random rotation in [-20, 20] degrees with bilinear
# interpolation (data augmentation), followed by the evaluation pipeline.
train_transform = Compose([
    RandomRotation([-20, 20], interpolation=InterpolationMode.BILINEAR),
    test_transform,
])

# MNIST train/test splits; download=True fetches them into ./data if missing.
train_set = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_set = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=mbatch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=mbatch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Number of target classes, taken from the dataset's class list.
num_classes = len(train_set.classes)

## Create network

In [21]:
# Instantiate the project's network and move its parameters to the chosen device
# (nn.Module.to returns the module itself, so the call can be chained).
net = network.Net().to(device)

## Display network architecture

In [22]:
# torchinfo is only needed for this optional architecture printout.
from torchinfo import summary

# Trace a single 1x1x28x28 input (batch, channel, height, width) through the
# network and report per-layer input/output shapes and parameter counts.
summary(
    net,
    input_size=(1, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params"],
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Net                                      [1, 1, 28, 28]            [1, 10]                   --
├─Conv2d: 1-1                            [1, 1, 28, 28]            [1, 6, 24, 24]            156
├─Conv2d: 1-2                            [1, 6, 12, 12]            [1, 16, 8, 8]             2,416
├─Linear: 1-3                            [1, 256]                  [1, 120]                  30,840
├─Linear: 1-4                            [1, 120]                  [1, 84]                   10,164
├─Linear: 1-5                            [1, 84]                   [1, 10]                   850
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
Total mult-adds (M): 0.29
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 0.18
Estimated Total Size (MB): 0.22

## Create loss function and optimizer

In [23]:
# Multi-class classification loss.
# NOTE(review): assumes network.Net returns raw logits (no softmax) — confirm.
criterion = nn.CrossEntropyLoss()

# Adam; its weight_decay adds an L2 term to the gradients (not AdamW-style
# decoupled decay).
optimizer = optim.Adam(
    params=net.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

## Create TensorBoard writer

In [24]:
# One log directory per notebook execution, named after the start time.
# Year-first ordering makes runs sort chronologically in TensorBoard's run
# list, and dashes replace colons because ':' is not allowed in Windows paths.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logdir_name = "./runs/mnist_sandbox_{}".format(timestamp)
summary_writer = SummaryWriter(logdir_name)

## Launch TensorBoard

In [25]:
# Load the TensorBoard notebook extension. If it is already loaded (e.g. the
# cell is re-run), IPython only prints an "already loaded" notice.
%load_ext tensorboard
# Embed TensorBoard for this run's log directory; IPython expands
# $logdir_name from the Python variable of the same name.
%tensorboard --logdir $logdir_name

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## Training and validation loop

In [26]:
# Train for `epochs` epochs. After each epoch, run validation_loop on both the
# training and the test loader and hand the results to update_graphs together
# with the TensorBoard writer. Only time spent in train_loop is accumulated.
training_time = 0.0

try:
    for epoch in range(epochs):
        print(f"EPOCH {epoch+1:4d}", 70*"-", flush=True)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        tic = time.perf_counter()
        train_loop(train_loader, net, criterion, optimizer, device)
        # CUDA kernels run asynchronously; wait for queued work so the timing
        # covers the actual GPU computation, not just the kernel launches.
        if device == "cuda":
            torch.cuda.synchronize()
        training_time += time.perf_counter() - tic

        # NOTE(review): train_loader applies the RandomRotation augmentation,
        # so these "train" metrics are measured on augmented images.
        train_res = validation_loop(train_loader, net, criterion, num_classes, device)
        test_res = validation_loop(test_loader, net, criterion, num_classes, device)
        update_graphs(summary_writer, epoch, train_res, test_res)
finally:
    # Close (and flush) the writer even if training is interrupted.
    summary_writer.close()

print(f"Finished training for {epochs} epochs in {training_time:.2f} seconds.")

EPOCH    1 ----------------------------------------------------------------------
EPOCH    2 ----------------------------------------------------------------------
EPOCH    3 ----------------------------------------------------------------------
EPOCH    4 ----------------------------------------------------------------------
EPOCH    5 ----------------------------------------------------------------------
EPOCH    6 ----------------------------------------------------------------------
EPOCH    7 ----------------------------------------------------------------------
EPOCH    8 ----------------------------------------------------------------------
EPOCH    9 ----------------------------------------------------------------------
EPOCH   10 ----------------------------------------------------------------------
Finished training for 10 epochs in 47.01894927024841 seconds.
