In [15]:
from collections import OrderedDict
from collections import namedtuple
from itertools import product

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

import functools

import pandas as pd
import inspect

from IPython.display import clear_output, display

from contextlib import contextmanager

In [6]:
class RunManager():
    def __init__(self):
        
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None
        #each of the variables can be encapsulated further into a class for just "Epoch"
        
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None
        #same for "run"
        
        self.network = None
        self.loader = None
        self.tb = None
        
    def begin_run(self, run, network, loader):
        self.run_start_time = time.time()
        self.run_params = run
        self.run_count += 1
        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')
        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        self.tb.add_image('images', grid)
        self.tb.add_graph(network, images)
    
    def end_run(self):
        self.tb.close()
        self.epoch_count = 0
    
    def begin_epoch(self):
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        
    def end_epoch(self):
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time
        loss = self.epoch_loss / len(self.loader.dataset)
        accuracy = self.epoch_num_correct / len(self.loader.dataset)
        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)
        for name, weight in self.network.named_parameters():
            self.tb.add_histogram(name, weight, self.epoch_count)
            self.tb.add_histogram(f'{name}.grad', weight.grad, self.epoch_count)
        results = OrderedDict()
        results['run'] = self.run_count
        results['epoch'] = self.epoch_count
        results['loss'] = loss
        results['accuracy'] = accuracy
        results['epoch_duration'] = epoch_duration
        results['run_duration'] = run_duration
        for k,v in self.run_params._asdict().items(): results[k] = v
        self.run_data.append(results)
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')
        #following 2 are specific to jpyter notebook
        clear_output(wait=True)
        display(df)
        
    def track_loss(self, loss):
        self.epoch_loss = loss.item() * self.loader.batch_size
    
    def track_num_correct(self, predictions, labels):
        self.epoch_num_correct = self._get_num_correct(predictions, labels)
    
    @torch.no_grad()
    def _get_num_correct(self, predictions, labels):
        return torch.argmax(predictions, dim=1).eq(labels).sum().item()
    
    def save(self, filename='saved_at_'+str(int(time.time()))):
        pd.DataFrame.from_dict(self.run_data, orient='columns').to_csv(f'{filename}.csv')
    
    @contextmanager
    def run_setup(self, run, network, loader):
        self.begin_run(run, network, loader)
        try:
            yield
        finally:
            self.end_run()
    
    @contextmanager
    def epoch_setup(self):
        self.begin_epoch()
        try:
            yield
        finally:
            self.end_epoch()

In [7]:
class RunBuilder():
    @staticmethod
    def get_runs(params):
        Run = namedtuple('Run', params.keys())
        runs = []
        for v in product(*params.values()):
            runs.append(Run(*v))
        return runs

In [8]:
parameters = OrderedDict( 
    lr = [0.01],
    batch_size = [100, 1000],
    shuffle = [True],
    num_workers = [0, 1, 16]
)

In [9]:
runs = RunBuilder.get_runs(parameters)

In [10]:
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # Layers
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        self.fc1 = torch.nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = torch.nn.Linear(in_features=120, out_features=60)
        self.fc3 = torch.nn.Linear(in_features=60, out_features=10)

        #Operations
        self.maxpool = functools.partial(F.max_pool2d, kernel_size=2, stride=2)
        self.relu = F.relu
        self.softmax = F.softmax

    def forward(self, t):
        t = self.conv1(t)
        t = self.relu(t)
        t = self.maxpool(t)

        t = self.conv2(t)
        t = self.relu(t)
        t = self.maxpool(t)

#         t = t.reshape(-1, 12*4*4)
        t = t.flatten(start_dim=1)
        t = self.fc1(t)
        t = self.relu(t)

        t = self.fc2(t)
        t = self.relu(t)

        t = self.fc3(t)
        return t

In [11]:
training_set = torchvision.datasets.FashionMNIST(
  root = './data/FashionMNIST',
  train=True,
  download=True,
  transform = transforms.Compose([
    transforms.ToTensor()
  ])
)

In [12]:
m = RunManager()
NUM_EPOCHS = 1
for run in RunBuilder.get_runs(parameters):
    network = Network()
    loader = torch.utils.data.DataLoader(training_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)
    m.begin_run(run, network, loader)
    for epoch in range(NUM_EPOCHS):
        m.begin_epoch()
        for batch in loader:
            images, labels = batch

            predictions = network(images)
            loss = F.cross_entropy(predictions, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            m.track_loss(loss)
            m.track_num_correct(predictions, labels)
        m.end_epoch()
    m.end_run()
m.save('results')

Unnamed: 0,run,epoch,loss,accuracy,epoch_duration,run_duration,lr,batch_size,shuffle,num_workers
0,1,1,0.000883,0.001317,16.360908,32.123134,0.01,100,True,0
1,2,1,0.000678,0.001417,10.058712,10.568761,0.01,100,True,1
2,3,1,0.000544,0.00145,9.066026,9.4017,0.01,100,True,16
3,4,1,0.010061,0.012633,12.236217,12.941024,0.01,1000,True,0
4,5,1,0.010654,0.0124,10.192284,11.035374,0.01,1000,True,1
5,6,1,0.011561,0.011633,7.530443,9.002347,0.01,1000,True,16


In [None]:
def get_all_predictions(model, loader):
    all_predictions = torch.tensor([])
    for batch in loader:
        images, labels = batch
        predictions = model(images)
        all_predictions = torch.cat((all_predictions, predictions), 
                                    dim=0)
    return all_predictions

In [None]:
with torch.no_grad():
    prediction_loader = torch.utils.data.DataLoader(training_set, 
                                                batch_size=100)
    train_predictions = get_all_predictions(new_network, 
                                        prediction_loader)

In [None]:
@torch.no_grad()
def get_n_correct_predictionsedictions(model, loader):
    all_predictions = torch.tensor([])
    for batch in loader:
        images, labels = batch
        predictions = model(images)
        all_predictions = torch.cat((all_predictions, predictions), 
                                    dim=0)
    return all_predictions

In [None]:
n_correct_predictions = get_num_correct(train_predictions, training_set.targets)
print('accuracy: ', n_correct_predictions/len(training_set))

In [None]:
x,y,z = 1,2,3

def retrieve_name(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    return [var_name for var_name, var_val in callers_local_vars if var_val is var][0]

print(retrieve_name(y))

In [None]:
def get_num_correct(predictions, labels):
    return torch.argmax(predictions, dim=1).eq(labels).sum().item()

In [None]:
def func_aggregator(list_of_functions, operand):
    for function in list_of_functions:
        operand = function(operand)
    return operand