In [None]:
# import libraries
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from network import Network
from utils import device, get_num_correct, RunBuilder

In [None]:
# Pre-processing pipeline applied to every image:
#   1. ToTensor   -> converts the PIL image to a float tensor in [0, 1]
#   2. Normalize  -> with mean 0.5 and std 0.5, rescales it to [-1, 1]
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, normalize])

# Arguments shared by both FashionMNIST splits (downloaded into ./data/ if missing).
dataset_kwargs = dict(root='./data/', download=True, transform=transform)

# Build the train and test splits with the transform attached.
train_set = torchvision.datasets.FashionMNIST(train=True, **dataset_kwargs)
test_set = torchvision.datasets.FashionMNIST(train=False, **dataset_kwargs)

# The test loader is independent of the hyper-parameters, so it is built once here.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)

In [None]:
# Hyper-parameter grid: every combination of the values below becomes one training run.
from collections import OrderedDict

params = OrderedDict([
    ('lr', [0.01, 0.003]),        # learning rates to try
    ('batch_size', [256, 512]),   # training batch sizes to try
])

In [None]:
criterion = nn.CrossEntropyLoss()  # specify loss function (categorical cross-entropy)

# Create the checkpoint directory up front so that a run cannot fail on
# torch.save after all of its epochs have already been trained.
os.makedirs('./models', exist_ok=True)

# iterate through the cross product of hyper-parameters defined in params
for run in RunBuilder.get_runs(params):
    network = Network().to(device)  # initialize the NN
    # load the train set
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=run.batch_size, shuffle=True, num_workers=1)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)  # specify the optimizer

    comment = f'-{run}'  # will be used for naming the runs based on each run's hyper-parameters

    # Use the writer as a context manager so its event file is flushed and
    # closed when the run ends (even on error) instead of leaking one open
    # writer per hyper-parameter combination.
    with SummaryWriter(comment=comment) as tb:
        for epoch in range(20):
            train_loss, train_correct = 0, 0  # will be used to track the running loss and correct
            ###################
            # train the model #
            ###################
            network.train()  # set model to train mode
            for batch in train_loader:
                images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device
                preds = network(images)  # forward pass
                loss = criterion(preds, labels)  # calculate loss
                optimizer.zero_grad()  # clear accumulated gradients from the previous pass
                loss.backward()  # backward pass
                optimizer.step()  # perform a single optimization step

                # criterion returns the mean loss over the batch, so weight it by the
                # actual number of samples: the last batch of an epoch may hold fewer
                # than run.batch_size samples.
                train_loss += loss.item() * labels.size(0)  # update the running loss
                train_correct += get_num_correct(preds, labels)  # update the running num correct

            # add train loss and train accuracy for the current epoch to tensorboard
            # NOTE: the loss is logged as a sum over the whole split, the accuracy as a fraction.
            tb.add_scalar('Train Loss', train_loss, epoch)
            tb.add_scalar('Train Accuracy', train_correct / len(train_set), epoch)

            ##################
            # test the model #
            ##################
            network.eval()  # set the model to evaluation mode
            test_loss, test_correct = 0, 0  # will be used to track the running loss and correct
            with torch.no_grad():  # turn off grad tracking, as we don't need gradients for test
                for batch in test_loader:
                    images, labels = batch[0].to(device), batch[1].to(device)  # load the batch to the available device
                    preds = network(images)  # forward pass
                    loss = criterion(preds, labels)  # calculate the loss

                    # weight by the real batch length rather than a hard-coded loader batch size
                    test_loss += loss.item() * labels.size(0)  # update the running loss
                    test_correct += get_num_correct(preds, labels)  # update the running num correct

            # add test loss and test accuracy for the current epoch to tensorboard
            tb.add_scalar('Test Loss', test_loss, epoch)
            tb.add_scalar('Test Accuracy', test_correct / len(test_set), epoch)

            # iterate over the parameters and their grads and plot their histograms to tensorboard
            # (will be helpful for checking if the model is having the vanishing gradient problem)
            for name, weight in network.named_parameters():
                tb.add_histogram(name, weight, epoch)
                # a parameter that received no gradient has grad None, which cannot be plotted
                if weight.grad is not None:
                    tb.add_histogram(f'{name}.grad', weight.grad, epoch)

    # save the model
    torch.save(network.state_dict(), f'./models/model-{run}.ckpt')


__Note:__ this project uses TensorBoard as its evaluation utility for plotting running losses, accuracies, histograms, etc. If you are wondering why there is no output while the network is training, open TensorBoard instead (_open a terminal, change to the project's repository directory and run `tensorboard --logdir=runs`_).