In [15]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch

from torchvision import datasets, transforms

In [4]:
import os
import time

from dataclasses import dataclass
from typing import Iterable, Tuple

In [5]:
import matplotlib.pyplot as plt  # one of the best graphics library for python

In [8]:
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    A two-stage convolutional feature extractor (``_body``) is followed by a
    three-layer fully connected classifier (``_head``) that emits 10 scores
    per sample. The shape comments below assume a single-channel 32x32 input,
    which is what makes the flattened feature size equal to 16 * 5 * 5.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: (5x5 conv -> ReLU -> 2x2 max-pool), applied twice.
        self._body = nn.Sequential(
            nn.Conv2d(1, 6, 5),         # (1, 32, 32) -> (6, 28, 28)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),            # (6, 28, 28) -> (6, 14, 14)
            nn.Conv2d(6, 16, 5),        # (6, 14, 14) -> (16, 10, 10)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),            # (16, 10, 10) -> (16, 5, 5)
        )

        # Classifier: 400 -> 120 -> 84 -> 10, with ReLU between the layers.
        self._head = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return the classifier output for the batch ``x``."""
        features = self._body(x)
        # Keep the batch dimension and collapse everything else into one axis.
        flattened = features.view(features.size(0), -1)
        return self._head(flattened)

In [9]:
# Build the network once and print its layer structure as a quick sanity check.
lenet5_model = LeNet5()
print(lenet5_model)

LeNet5(
  (_body): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (_head): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)


In [12]:
@dataclass
class SystemConfiguration:
    """System-level settings intended to make training reproducible.

    Attributes:
        seed: Seed number used to set the state of the random number generators.
        cudnn_benchmark_enabled: Whether to enable the cuDNN benchmark mode
            (for the sake of performance).
        cudnn_deterministic: Whether to make cuDNN deterministic
            (reproducible training).
    """

    seed: int = 42
    # NOTE(review): benchmark mode and deterministic mode are both on by
    # default; confirm this combination gives the reproducibility expected.
    cudnn_benchmark_enabled: bool = True
    cudnn_deterministic: bool = True

In [11]:
@dataclass
class TrainingConfiguration:
    """Hyper-parameters and runtime options of the training process.

    Attributes:
        batch_size: Amount of data passed through the network at each
            forward-backward iteration.
        epochs_count: Number of times the whole dataset is passed through
            the network.
        learning_rate: Determines the speed of the network's weight updates.
        log_interval: How many batches to wait between logging training status.
        test_interval: How many epochs to wait before another test; 1 gives a
            validation loss at every epoch.
        data_root: Folder in which the MNIST data is saved (default: "data").
        num_workers: Number of concurrent processes used to prepare data.
        device: Device to use for training.
    """

    batch_size: int = 32
    epochs_count: int = 20
    learning_rate: float = 0.01
    log_interval: int = 100
    test_interval: int = 1
    data_root: str = "data"
    num_workers: int = 10
    device: str = "cuda"

In [13]:
def setup_system(system_config: SystemConfiguration) -> None:
    """Seed PyTorch and apply the cuDNN flags from ``system_config``.

    Args:
        system_config: Supplies ``seed``, ``cudnn_benchmark_enabled`` and
            ``cudnn_deterministic``.

    The cuDNN flags are only touched when CUDA is available.
    """
    torch.manual_seed(system_config.seed)
    if torch.cuda.is_available():
        # The flag lives on the ``cudnn`` backend module and is called
        # ``benchmark``; assigning to ``torch.backends.cudnn_benchmark_enable``
        # only created an unrelated attribute and left cuDNN unchanged.
        torch.backends.cudnn.benchmark = system_config.cudnn_benchmark_enabled
        torch.backends.cudnn.deterministic = system_config.cudnn_deterministic

In [17]:
def train(
    train_config: TrainingConfiguration,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: torch.utils.data.DataLoader,
    epoch_idx: int
) -> Tuple[float, float]:
    """Run one training epoch and report its mean loss and accuracy.

    Args:
        train_config: Supplies ``device`` (where each batch is moved) and
            ``log_interval`` (how often, in batches, progress is printed).
        model: Network to optimise; it is switched to training mode.
        optimizer: Optimizer stepped once per batch.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epoch_idx: Epoch number, used only in the progress message.

    Returns:
        ``(epoch_loss, epoch_acc)``: the unweighted means of the per-batch
        cross-entropy loss and per-batch accuracy. Both are NaN when the
        loader yields no batches.
    """
    model.train()

    # Plain lists: appending is O(1), and no NumPy import is required.
    batch_losses = []
    batch_accuracies = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(train_config.device)
        target = target.to(train_config.device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        # Softmax is monotonic, so the arg-max of the raw outputs is already
        # the predicted class; compare on the device the batch lives on.
        pred = output.argmax(dim=1)
        correct = pred.eq(target).sum().item()
        acc = correct / len(data)
        batch_accuracies.append(acc)

        if batch_idx % train_config.log_interval == 0 and batch_idx > 0:
            print(
                'Train Epoch: {} [{}/{}] Loss: {:.6f} Acc: {:.4f}'.format(
                    epoch_idx, batch_idx * len(data), len(train_loader.dataset), loss_value, acc
                )
            )

    if not batch_losses:
        # Mirrors the NaN that a mean over an empty array would produce.
        return float('nan'), float('nan')

    epoch_loss = sum(batch_losses) / len(batch_losses)
    epoch_acc = sum(batch_accuracies) / len(batch_accuracies)
    return epoch_loss, epoch_acc
    

In [18]:
def validate(
    train_config: TrainingConfiguration,
    model: nn.Module,
    test_loader: torch.utils.data.DataLoader
) -> Tuple[float, float]:
    """Evaluate ``model`` on ``test_loader`` and print a summary line.

    Args:
        train_config: Supplies ``device``, where each batch is moved.
        model: Network to evaluate; it is switched to evaluation mode.
        test_loader: Loader yielding ``(data, target)`` batches; its
            ``dataset`` length is used as the number of samples.

    Returns:
        ``(test_loss, accuracy)``: the mean of the per-batch cross-entropy
        losses and the fraction of correct predictions in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If the loader has no batches or an empty dataset.
    """
    model.eval()

    test_loss = 0.0
    count_correct_predictions = 0

    # No gradients are needed for evaluation, so skip building autograd graphs.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(train_config.device)
            target = target.to(train_config.device)

            output = model(data)
            test_loss += F.cross_entropy(output, target).item()

            # Arg-max of the raw outputs equals arg-max of their softmax.
            pred = output.argmax(dim=1)
            # ``.item()`` keeps the counter a plain int, so it formats as a
            # number and the returned accuracy is a float rather than a tensor.
            count_correct_predictions += pred.eq(target).sum().item()

    test_loss = test_loss / len(test_loader)
    accuracy = 100. * count_correct_predictions / len(test_loader.dataset)

    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, count_correct_predictions, len(test_loader.dataset), accuracy
        )
    )
    return test_loss, accuracy / 100.0