## Define a RunBuilder, RunManager and use those classes in the training loop

In [19]:
#Imports for our classes
from collections import OrderedDict
from collections import namedtuple
from itertools import product
import time
#import torch #normally gets imported anyway... let's see!
import pandas as pd
from IPython.display import clear_output

In [5]:
from __future__ import print_function
from collections import OrderedDict
import numpy as np


import torch
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter


#We can import our classes like this now:
from run_classes import RunBuilder, RunManager

To understand how these classes work, see the TensorBoard Jupyter notebook.

In [3]:
class RunBuilder():
    """Expand a grid of hyper-parameter values into individual runs."""

    @staticmethod
    def get_runs(params):
        """Return one ``Run`` namedtuple per combination of parameter values.

        ``params`` maps each parameter name to an iterable of candidate
        values. The fields of ``Run`` are the keys of ``params`` and the
        result follows the order produced by ``itertools.product``.
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combination) for combination in product(*params.values())]

In [4]:
class RunManager():
    """Collect timing, loss and accuracy for a series of training runs.

    The expected call sequence per run is ``begin_run``, then for every
    epoch ``begin_epoch`` / ``track_loss`` + ``track_num_correct`` per
    batch / ``end_epoch``, and finally ``end_run``. Per-epoch results are
    logged to TensorBoard and accumulated in ``run_data``.
    """

    def __init__(self):

        #can further be refactored by extracting this
        self.network = None
        self.loader = None
        self.tb = None

        # Per-epoch accumulators; reset at the start of every epoch.
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # Bookkeeping that spans runs; run_data holds one row per epoch.
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

    def begin_run(self, run, network, loader):
        """Start timing a run and open a TensorBoard writer for it.

        ``run`` is stored as the run's parameters (``end_epoch`` calls its
        ``_asdict()``), ``network`` and ``loader`` are kept for later
        logging. The first batch of ``loader`` is written to TensorBoard
        as an image grid and used to trace the network graph.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment = f' Run {self.run_count} with {run}')

        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)

        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images)

    def end_run(self):
        """Close the TensorBoard writer and reset the epoch counter."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start timing an epoch and reset the per-epoch accumulators."""
        self.epoch_start_time = time.time()

        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Log the finished epoch to TensorBoard and show the results table."""
        # In a notebook `display` is an injected global; importing it keeps
        # this method usable when the class lives in a plain module.
        from IPython.display import display

        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # Both metrics are averaged over the size of the whole dataset.
        num_samples = len(self.loader.dataset)
        loss = self.epoch_loss / num_samples
        accuracy = self.epoch_num_correct / num_samples

        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # Parameters that never received a gradient have grad == None,
            # which add_histogram cannot log.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        results = OrderedDict()
        results['run'] = self.run_count
        results['epoch'] = self.epoch_count
        results['loss'] = loss
        results['accuracy'] = accuracy
        results['epoch duration'] = epoch_duration
        results['run duration'] = run_duration

        # Append the run's hyper-parameters as additional columns.
        results.update(self.run_params._asdict())
        self.run_data.append(results)
        table = pd.DataFrame(self.run_data)

        # Replace the previously shown table instead of stacking outputs.
        clear_output(wait = True)
        display(table)

    def track_loss(self, loss, batch_size = None):
        """Add one batch's loss to the epoch total.

        ``loss`` is scaled by the number of samples in the batch. When
        ``batch_size`` is omitted the loader's configured ``batch_size``
        is used, which overcounts a smaller final batch; pass the real
        batch length to avoid that.
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Add the number of correct predictions in a batch to the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Count rows of ``preds`` whose argmax along dim 1 equals the label."""
        return preds.argmax(dim = 1).eq(labels).sum().item()

    def save(self, fileName):
        """Write the collected results to ``<fileName>.csv`` and ``<fileName>.json``."""
        # Imported here because the file's import cells do not provide json.
        import json

        pd.DataFrame(self.run_data).to_csv(f'{fileName}.csv')

        # The JSON copy gets its own file; writing it to the .csv path
        # would overwrite the table that was just saved.
        with open(f'{fileName}.json', 'w', encoding = 'utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii = False, indent = 4)

### Import Network, TrainData

In [11]:
# FashionMNIST training split; images are only converted to tensors.
train_set = torchvision.datasets.FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [12]:
# Same training split, additionally normalized after the tensor conversion.
train_set_normal = torchvision.datasets.FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Normalization is normally done per feature, i.e. per color channel
        # (most often 3, in this case 1). The mean and std are calculated in
        # the cell below; the objective is to move the mean to 0 and the
        # std to 1.
        transforms.Normalize(mean=[0.2859], std=[0.3530]),
    ]),
)

In [8]:
# Calculate the mean and std for the Normalize transform: load the whole
# training set as a single batch and reduce over all of its values.
loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=len(train_set),
    num_workers=1,
)
data = next(iter(loader))
data[0].mean(), data[0].std()

(tensor(0.2859), tensor(0.3530))

In [13]:
# Lookup table that lets a run pick its dataset by name.
trainsets = {
    'not_normal': train_set,
    'normal': train_set_normal,
}

In [3]:
class Network(nn.Module):
    """CNN with two convolution layers, two hidden linear layers and 10 outputs."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        # Table and formula to calculate the changes of image sizes:
        # https://deeplizard.com/learn/video/cin4YcGBh3Q
        # For 28x28 inputs each sample arrives at the first linear layer as
        # (12, 4, 4); once flattened that is 12*4*4 input features.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Return the raw class scores (no softmax applied) for batch ``t``."""
        # Two rounds of convolution -> ReLU -> 2x2 max pooling.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        # Flatten each sample before the fully connected layers.
        flat = t.reshape(-1, 12 * 4 * 4)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))

        # Normally a softmax would follow, but it is implicitly included in
        # the cross entropy loss.
        return self.out(hidden)

### Actual Training Loop now

In [6]:
# Hyper-parameter grid: one run is created per combination of these values.
params = OrderedDict([
    ('lr', [.01]),
    ('batch_size', [100, 1000]),
    ('shuffle', [True]),
    ('num_workers', [0, 1, 2, 4]),  # 1 seems to give the best results
    ('epochs', [1]),
    ('train_sets', ['not_normal', 'normal']),
])

rm = RunManager()
for run in RunBuilder.get_runs(params):

    # Every run starts from a fresh network, loader and optimizer.
    network = Network()
    loader = torch.utils.data.DataLoader(
        trainsets[run.train_sets],
        batch_size=run.batch_size,
        shuffle=run.shuffle,
        num_workers=run.num_workers,
    )
    optimizer = torch.optim.Adam(network.parameters(), lr=run.lr)

    rm.begin_run(run, network, loader)
    for epoch in range(run.epochs):

        rm.begin_epoch()
        for batch in loader:

            images, labels = batch
            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            # Clear old gradients, backpropagate, then update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            rm.track_loss(loss)
            rm.track_num_correct(preds, labels)

        rm.end_epoch()
    rm.end_run()
# rm.save('results')  # uncomment to create a file with the results

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle,num_workers,epochs
0,1,1,0.607315,0.7646,13.668535,15.201364,0.01,100,True,0,1
1,2,1,0.585218,0.7802,10.700406,11.1722,0.01,100,True,1,1
2,3,1,0.572106,0.781383,10.679205,10.842438,0.01,100,True,2,1
3,4,1,0.573679,0.780917,11.166654,11.349237,0.01,100,True,4,1
4,5,1,1.003109,0.615567,15.95243,16.732734,0.01,1000,True,0,1
5,6,1,1.003832,0.616867,12.58924,13.464185,0.01,1000,True,1,1
6,7,1,0.947843,0.637217,12.450007,13.337946,0.01,1000,True,2,1
7,8,1,0.983083,0.625083,12.482737,13.385255,0.01,1000,True,4,1
