In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter
from utils import device, get_num_correct, RunBuilder
from network import Network

In [None]:
# MNIST train and test splits, stored under ./data/ (download=True fetches them
# when needed); every image goes through transforms.ToTensor().
# The generator yields the train split first, then the test split.
train_set, test_set = (
    torchvision.datasets.MNIST(
        root='./data/',
        train=is_train,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train in (True, False)
)

# Evaluation loader: fixed batches of 1000, default (unshuffled) order.
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1000)

In [None]:
from collections import OrderedDict

# Hyperparameter grid handed to RunBuilder.get_runs by the training cell.
# Key order is lr first, then batch_size.
params = OrderedDict([
    ('lr', [0.01, 0.003, 0.001, 0.0003]),
    ('batch_size', [256, 512]),
])

In [None]:
# Hyperparameter sweep: for every run produced by RunBuilder.get_runs(params),
# a fresh Network is trained for 4 epochs with Adam, evaluated on the test set
# after each epoch, logged to its own TensorBoard run, and finally checkpointed.
# NOTE(review): RunBuilder, Network, device and get_num_correct come from
# project modules not shown here; each run is assumed to expose `.lr` and
# `.batch_size`, and get_num_correct presumably returns the number of correct
# predictions in the batch -- confirm against utils.py.
criterion = nn.CrossEntropyLoss()

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before spending time on training.
os.makedirs('./models', exist_ok=True)

for run in RunBuilder.get_runs(params):
    network = Network().to(device)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=run.batch_size, shuffle=True, num_workers=1)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    comment = f'-{run}'
    tb = SummaryWriter(comment=comment)

    # try/finally guarantees the writer is flushed and closed for every run,
    # even if training raises part-way through.
    try:
        for epoch in range(4):

            # ---- training pass ----
            train_loss = 0
            train_correct = 0

            network.train()
            for batch in train_loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                preds = network(images)
                loss = criterion(preds, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # criterion returns the batch *mean*; weight it by the real
                # number of samples in this batch. The last batch is usually
                # shorter than run.batch_size, so using the nominal size
                # would overstate the accumulated loss.
                train_loss += loss.item() * labels.size(0)
                train_correct += get_num_correct(preds, labels)

            # 'Train Loss' is the loss summed over all training samples.
            tb.add_scalar('Train Loss', train_loss, epoch)
            tb.add_scalar('Train Accuracy', train_correct / len(train_set), epoch)

            # ---- evaluation pass ----
            test_loss = 0
            test_correct = 0

            network.eval()
            with torch.no_grad():
                for batch in test_loader:
                    images, labels = batch[0].to(device), batch[1].to(device)
                    preds = network(images)
                    loss = criterion(preds, labels)

                    # Use the actual batch length instead of a hard-coded
                    # copy of the test loader's batch size.
                    test_loss += loss.item() * labels.size(0)
                    test_correct += get_num_correct(preds, labels)

            # 'Test Loss' is the loss summed over all test samples.
            tb.add_scalar('Test Loss', test_loss, epoch)
            tb.add_scalar('Test Accuracy', test_correct / len(test_set), epoch)

            # ---- parameter / gradient histograms ----
            for name, weight in network.named_parameters():
                tb.add_histogram(name, weight, epoch)
                # .grad is None for parameters that received no gradient
                # (e.g. frozen or unused ones); skip those.
                if weight.grad is not None:
                    tb.add_histogram(f'{name}.grad', weight.grad, epoch)
    finally:
        tb.close()

    torch.save(network.state_dict(), f'./models/model-{run}.ckpt')