# Neural Network Training with DataJoint

In this session, we are going to look at how we can use DataJoint to train neural networks, exploring various **hyperparameters** for the training.

As always, we start by importing the essential scientific Python packages.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datajoint as dj
import time
from tqdm import tqdm

We also import the PyTorch package `torch`, along with its companion package `torchvision`, which provides tools for downloading and handling popular machine learning datasets.

In [None]:
import torch # import the PyTorch package
import torchvision # import trochvision package
from torchvision import transforms # get torchvision's transforms subpackage

As an example, we'll work with the MNIST handwritten digit dataset.

In [None]:
# create a composite transform that first converts images to tensors and then normalize the images
image_transform = transforms.Compose([
    transforms.ToTensor(), # converts images into Tensors
    transforms.Normalize([0.1307], [0.3081])
])

# apply the transforms at the time of dataset loading
train_set = torchvision.datasets.MNIST('./data', train=True, download=True,
                                          transform=image_transform)
test_set = torchvision.datasets.MNIST('./data', train=True, download=True,
                                          transform=image_transform)

This returns Torchvision's special **dataset** object that can be used to represent **supervised datasets** consisting of both inputs (i.e. images) and targets (i.e. digit labels).

In [None]:
# number of (image, label) samples in the training set
len(train_set)

In [None]:
# indexing a torchvision dataset returns one (transformed image, label) pair
image, label = train_set[100]

In [None]:
# display the selected sample; squeeze() drops the singleton channel dimension so imshow gets a 2-D image
plt.imshow(image.squeeze(), cmap='gray')
plt.title('Digit: {}'.format(label))

In [None]:
# lay out the first 25 training samples on a 5 x 5 grid, each titled with its label
fig, axs = plt.subplots(5, 5, figsize=(6, 6), dpi=150)

for sample_idx, ax in zip(range(axs.size), axs.flat):
    image, label = train_set[sample_idx]
    ax.imshow(image.squeeze(), cmap='gray')  # drop the channel dimension for imshow
    ax.set_title('Digit: {}'.format(label))
    ax.axis('off')  # hide ticks to keep the grid clean

fig.tight_layout()

In [None]:
# grab the sample at index 3 of `test_set` as an (image, label) pair for display below
image, label = test_set[3]

In [None]:
# display the selected test sample with its true label as the title
plt.imshow(image.squeeze(), cmap='gray')
plt.title('Digit: {}'.format(label))

# Building a network for classification

In PyTorch, you define a new neural network by defining a **new class that inherits from nn.Module** as follows:

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Network(nn.Module):
    """Two-layer fully connected classifier that outputs log-probabilities.

    Args:
        hidden_size: width of the single hidden layer (a hyperparameter).
        input_size: number of input features after flattening; defaults to
            784 = 28 x 28 pixels.
        n_classes: number of output classes; defaults to 10 digits.
    """

    def __init__(self, hidden_size=50, input_size=784, n_classes=10):
        super().__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        """Return log-probabilities of shape N x n_classes for a batch ``x``."""
        x = x.view(-1, self.input_size)  # flatten, e.g. N x 1 x 28 x 28 -> N x 784
        x = F.relu(self.fc1(x))  # first fully connected layer followed by ReLU
        x = self.fc2(x)  # second (output) fully connected layer *without* ReLU
        # log_softmax normalizes each row into probabilities summing to one, then takes the log
        return F.log_softmax(x, dim=1)

This network has **one hyperparameter — the size of the hidden layer** (defaulting to 50 here).

In [None]:
# mini-batch size -- this is another hyperparameter!
batch_size = 64

# the training data is reshuffled every epoch; the test data is read in a fixed order
training_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

## Training the network

In [None]:
# train a freshly initialized network with plain stochastic gradient descent
net = Network()
net.train()  # switch the network into training mode

# the learning rate is another hyperparameter!
optimizer = torch.optim.SGD(net.parameters(), lr=0.005)

start = time.time()
for epoch_idx in range(3):  # the number of epochs is yet another hyperparameter!
    for batch_idx, (data, target) in enumerate(training_loader):
        optimizer.zero_grad()  # PyTorch accumulates gradients, so clear the previous step's
        output = net(data)  # forward pass: log-probabilities for the batch
        loss = F.nll_loss(output, target)  # negative log-likelihood of the true labels
        loss.backward()  # back-propagation fills in the parameter gradients
        optimizer.step()  # apply one gradient descent update

        # log progress on every 100th mini-batch
        if not batch_idx % 100:
            print('Epoch {} Loss: {:.6f}'.format(epoch_idx, loss.item()))

duration = time.time() - start
print('Training completed in {:.2f} seconds'.format(duration))

## Evaluating the network

In [None]:
# measure the trained network's average loss and accuracy on the test loader
net.eval()  # switch the network into evaluation mode
test_loss, correct = 0, 0

# gradients are not needed for evaluation; disabling them saves time and memory
with torch.no_grad():
    for data, target in tqdm(test_loader):
        output = net(data)
        # accumulate the summed (not averaged) batch loss; it is averaged over all samples below
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        # predicted digit = index of the largest log-probability
        pred = output.argmax(dim=1, keepdim=True)
        correct += (pred == target.view_as(pred)).sum().item()

# convert running totals into per-sample averages
test_loss /= len(test_loader.dataset)
accuracy = correct / len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset), 100. * accuracy))

Let's look at some of the network's actual guesses on randomly chosen test images.

In [None]:
# show 25 random test images, each titled with the network's guess and the true label
net.eval()  # put the network into evaluation mode before predicting
fig, axs = plt.subplots(5, 5, figsize=(6, 6), dpi=150)
image_order = np.random.permutation(len(test_set))
with torch.no_grad():  # inference only -- skip building the autograd graph
    for i, ax in zip(image_order, axs.ravel()):
        image, label = test_set[i]
        # the argmax of the log-probabilities equals the argmax of the probabilities;
        # .item() turns the 0-d tensor into a plain int so the title reads "Guess: 7",
        # not "Guess: tensor(7)"
        digit = torch.argmax(net(image)).item()
        ax.imshow(image.squeeze(), cmap='gray')
        ax.set_title('Guess: {}\nActual: {}'.format(digit, label))
        ax.axis('off')

fig.tight_layout()

## Using DataJoint to coordinate the training

Now let's see how we can use DataJoint to streamline this process.

In [None]:
import datajoint as dj

In [None]:
# create a new schema to house tables for network training
schema = dj.schema('network')

In [None]:
@schema
class NetworkConfig(dj.Lookup):
    """Candidate hidden-layer sizes; each row describes one network architecture."""
    definition = """
    hidden_size: int    # size of hidden layer
    """
    # a list, unlike a zip iterator, can be read more than once; a zip object would be
    # exhausted after DataJoint first consumes the contents
    contents = [(50,)]
    
@schema
class TrainingConfig(dj.Lookup):
    """Numbered sets of optimizer settings (learning rate, batch size, epochs)."""
    definition = """
    train_config_id: int   # unique id for a training config
    ---
    learning_rate: float   # learning rate for SGD
    batch_size: int        # batch_size for training
    n_epochs: int          # number of epochs to train
    """
    # default entry, spelled out attribute by attribute
    contents = [
        dict(train_config_id=0, learning_rate=0.005, batch_size=64, n_epochs=3),
    ]
    

@schema
class TrainedNetwork(dj.Computed):
    """One trained network per (NetworkConfig, TrainingConfig) pair, with its test metrics."""
    definition = """
    -> NetworkConfig
    -> TrainingConfig
    ---
    train_duration: float   # duration of training in seconds
    test_loss: float        # loss on the test set
    test_acc: float         # accuracy on the test set
    """

    def make(self, key):
        """Train and evaluate a network for one configuration pair and insert the result.

        Args:
            key: primary-key dict selecting one NetworkConfig row and one
                TrainingConfig row (supplied by ``populate``).
        """
        # Get configurations! fetch1 may return numpy scalars, so cast them to plain
        # Python types before handing them to PyTorch.
        hidden_size = int((NetworkConfig & key).fetch1('hidden_size'))
        learning_rate, batch_size, n_epochs = (TrainingConfig & key).fetch1(
            'learning_rate', 'batch_size', 'n_epochs')
        learning_rate, batch_size, n_epochs = float(learning_rate), int(batch_size), int(n_epochs)

        training_loader, test_loader = self._mnist_loaders(batch_size)

        # instantiate the network and configure the optimizer
        net = Network(hidden_size=hidden_size)
        optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

        start = time.time()
        self._fit_network(net, optimizer, training_loader, n_epochs)
        duration = time.time() - start
        print('Training completed in {:.2f} seconds'.format(duration))

        test_loss, accuracy = self._score_network(net, test_loader)

        # build a new row rather than mutating the caller's key dict
        self.insert1(dict(key, train_duration=duration, test_loss=test_loss, test_acc=accuracy))

    @staticmethod
    def _mnist_loaders(batch_size):
        """Return (training_loader, test_loader) over normalized MNIST."""
        image_transform = transforms.Compose([
            transforms.ToTensor(),  # converts images into Tensors
            transforms.Normalize([0.1307], [0.3081])
        ])
        # train=True is the training split; train=False is the held-out test split
        train_set = torchvision.datasets.MNIST('./data', train=True, download=True,
                                               transform=image_transform)
        test_set = torchvision.datasets.MNIST('./data', train=False, download=True,
                                              transform=image_transform)
        training_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)  # by default shuffle is False
        return training_loader, test_loader

    @staticmethod
    def _fit_network(net, optimizer, training_loader, n_epochs):
        """Train ``net`` in place for ``n_epochs``, printing the loss every 100 batches."""
        net.train()  # puts the network into the training mode
        for epoch_idx in range(n_epochs):
            for batch_idx, (data, target) in enumerate(training_loader):
                optimizer.zero_grad()  # reset the gradient before the next gradient step
                output = net(data)
                loss = F.nll_loss(output, target)
                loss.backward()  # back propagation computes parameter gradients
                optimizer.step()  # gradient descent step on the parameters
                if batch_idx % 100 == 0:
                    print('Epoch {} Loss: {:.6f}'.format(epoch_idx, loss.item()))

    @staticmethod
    def _score_network(net, test_loader):
        """Return (average NLL loss, accuracy) of ``net`` over ``test_loader``."""
        net.eval()  # put the network into evaluation mode
        test_loss = 0
        correct = 0
        # no gradients are needed during testing -- saves time and memory
        with torch.no_grad():
            for data, target in test_loader:
                output = net(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # summed batch loss
                pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        n_samples = len(test_loader.dataset)
        test_loss /= n_samples
        accuracy = correct / n_samples

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, n_samples,
            100. * accuracy))
        return test_loss, accuracy

In [None]:
# draw the dependency diagram of the tables declared in this schema
dj.Diagram(schema)

In [None]:
# preview the rows currently stored in TrainingConfig
TrainingConfig()

In [None]:
# preview the rows currently stored in NetworkConfig
NetworkConfig()

In [None]:
# call TrainedNetwork.make for every upstream key combination that has no entry yet
TrainedNetwork.populate()

In [None]:
# preview the training results recorded so far
TrainedNetwork()

## Try more hyperparameter values

Now let's add a few more entries into the config tables and try out different combinations of hyperparameters for network training.

In [None]:
# add a larger hidden layer to try; skip_duplicates makes this cell safe to re-run, and it
# cannot hide a conflicting row because the whole row is the primary key
NetworkConfig().insert1((200, ), skip_duplicates=True)

In [None]:
# confirm the new hidden size was added
NetworkConfig()

In [None]:
# register a second training configuration: same learning rate and batch size, but 5 epochs
TrainingConfig.insert1(dict(train_config_id=1, learning_rate=0.005, batch_size=64, n_epochs=5))

In [None]:
# confirm the new training configuration was added
TrainingConfig()

In [None]:
# train the not-yet-computed configuration combinations, showing a progress bar
TrainedNetwork.populate(display_progress=True)

In [None]:
# inspect results for all configuration combinations
TrainedNetwork()

In [None]:
# join results with TrainingConfig to see each run's hyperparameters next to its metrics
TrainedNetwork * TrainingConfig