In [None]:
# import libraries
import os

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter
from utils import device, get_num_correct, RunBuilder
from network import Network

In [None]:
# Fetch MNIST into ./data/ (torchvision skips the download when the files are already there)
# and convert every PIL image to a float tensor via ToTensor. Both splits share one config.
mnist_kwargs = dict(root='./data/', download=True, transform=transforms.ToTensor())
train_set = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# The test loader is fixed for the whole search (no shuffling, large batches); the train
# loader is rebuilt inside each run because its batch size is one of the hyper-parameters.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)

In [None]:
# for hyper-parameter search
from collections import OrderedDict

params = OrderedDict(
    lr = [0.01, 0.003, 0.001, 0.0003],
    batch_size = [256, 512]
)

In [None]:
# loss function (categorical cross-entropy); CrossEntropyLoss returns the mean over the batch
criterion = nn.CrossEntropyLoss()

# number of epochs each hyper-parameter combination is trained for
num_epochs = 4

# checkpoints of every run go here; create the directory up front so torch.save
# does not fail on a fresh checkout where it does not exist yet
save_dir = './models/without_validation'
os.makedirs(save_dir, exist_ok=True)

# iterate through the cross product of hyper-parameters defined in params
for run in RunBuilder.get_runs(params):
    network = Network().to(device)  # fresh, untrained network for every run

    # load the train set (reshuffled every epoch); batch size is a searched hyper-parameter
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=run.batch_size, shuffle=True, num_workers=1)
    # specify the optimizer
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    # the comment is appended to the TensorBoard run directory name so each run is identifiable
    # by its hyper-parameters; the `with` block flushes and closes the writer even if training fails
    with SummaryWriter(comment=f'-{run}') as tb:
        for epoch in range(num_epochs):
            # running sums over the epoch. The (mean) batch loss is multiplied by the REAL number of
            # samples in the batch, so the last, smaller batch is not over-counted and dividing by the
            # dataset size below yields the true per-sample mean loss.
            train_loss = 0.0
            train_correct = 0

            ###################
            # train the model #
            ###################
            network.train()  # set the model to train mode
            for images, labels in train_loader:
                # load the batch to the available device (cpu/gpu)
                images, labels = images.to(device), labels.to(device)
                # forward pass: compute predicted outputs for the batch
                preds = network(images)
                loss = criterion(preds, labels)
                # clear gradients from the previous step, back-propagate, update parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * images.size(0)
                train_correct += get_num_correct(preds, labels)

            # log per-sample averages so train and test curves are on the same scale
            # even though the two datasets differ in size
            tb.add_scalar('Train Loss', train_loss / len(train_set), epoch)
            tb.add_scalar('Train Accuracy', train_correct / len(train_set), epoch)

            ##################
            # test the model #
            ##################
            network.eval()  # set the model to evaluation mode
            test_loss = 0.0
            test_correct = 0
            # no gradients are needed for evaluation
            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.to(device), labels.to(device)
                    preds = network(images)
                    loss = criterion(preds, labels)

                    test_loss += loss.item() * images.size(0)
                    test_correct += get_num_correct(preds, labels)

            tb.add_scalar('Test Loss', test_loss / len(test_set), epoch)
            tb.add_scalar('Test Accuracy', test_correct / len(test_set), epoch)

            # plot every parameter's weights and its gradients (from the last training batch of the
            # epoch); useful for spotting vanishing/exploding gradients. Parameters that received
            # no gradient have `.grad is None`, which add_histogram cannot plot, so skip those.
            for name, weight in network.named_parameters():
                tb.add_histogram(name, weight, epoch)
                if weight.grad is not None:
                    tb.add_histogram(f'{name}.grad', weight.grad, epoch)

    # save the trained weights of this run
    torch.save(network.state_dict(), os.path.join(save_dir, f'model-{run}.ckpt'))

__Note:__ this project uses TensorBoard to track training. The per-epoch losses and accuracies, along with the weight and gradient histograms, are written to TensorBoard rather than printed. So if you are wondering why the training cell produces no output while the network is training, look in TensorBoard: _open a terminal, change to the project's repository directory, run `tensorboard --logdir=runs`, and open the URL it prints._