In [1]:
from datasets import PACSDataset
from torchvision import transforms
from torch.utils.data import DataLoader


# Training a neural network
We are able to load data from our custom datasets, now we need to train a neural network. We need the following:

1. A neural network
2. An optimiser
3. A training loop that iterates through samples provided by a dataloader and uses the optimiser to update the neural networks weights


## Train and Test dataloaders
Normally, we'll have a training, validation and test set. The validation set is used for hyperparameter tuning. Since we won't do any hyperparameter tuning we can just go ahead and use the test set for evaluation. 

In [2]:
# Files and transforms
with open("train_files.txt") as f:
    file_names_train = f.read().splitlines()
    
with open("test_files.txt") as f:
    file_names_test = f.read().splitlines()

# transforms
# A great point to add data augmentations - you want to do class-preserving transformations
# Rotations, Reflections etc
transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        ])

# Instantiate datasets
train_dataset = PACSDataset(file_names_train, transform)
test_dataset = PACSDataset(file_names_test, transform)

# Instantiate dataloaders
train_dataloader = DataLoader(train_dataset, batch_size = 64, shuffle = True)
test_dataloader = DataLoader(test_dataset, batch_size = 64, shuffle = True)

## Neural network modules
We will create a neural network modules. To illustrate the modularity of pytorch we will break it down into multiple modules.

1. A featurizer: applies convolutions to extract "useful" features
2. A classifier: fully connected net built on top of featurizer
3. A network that combines the two: Allows us to easily swap out parts 1 or 2 if we accept them the `__init__` function. 

In [3]:
from torch import nn
from torch.nn import functional as F
import torch


# Featurizer
class Featurizer(nn.Module):
    def __init__(self):
        super(Featurizer, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.dropout = nn.Dropout2d(0.5)
        self.conv2 = nn.Conv2d(16, 4, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    # how we want computations to run in the forward pass
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        return x

# Classifier
# We will hardcode the layers for now, but it is best to use parameters
class Classifier(nn.Module):
    def __init__(self, num_classes: int = 7):
        super(Classifier, self).__init__()
        # make a point here about the 12544
        self.fc1 = nn.Linear(12544, 4096)
        self.dropout = nn.Dropout2d(0.5)

        self.fc2 = nn.Linear(4096, 1024)
        self.fc3 = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Combination
class CNNClassifier(nn.Module):
    def __init__(self, num_classes: int = 7):
        super(CNNClassifier, self).__init__()
        self.featurizer = Featurizer()
        self.classifier = Classifier(num_classes)
        
    def forward(self, x):
        return self.classifier(self.featurizer(x))

In [4]:
# Initialise network and move to GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CNNClassifier(num_classes=7)
model.to(device)
print(model)

CNNClassifier(
  (featurizer): Featurizer(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout2d(p=0.5, inplace=False)
    (conv2): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Classifier(
    (fc1): Linear(in_features=12544, out_features=4096, bias=True)
    (dropout): Dropout2d(p=0.5, inplace=False)
    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    (fc3): Linear(in_features=1024, out_features=7, bias=True)
  )
)


## Training logic
We need to set up the training steps to 
1. Compute predictions
2. Compute loss functions
3. zero gradients, do backwards pass, take a step with optimiser
4. it can be helpful to indicate training progress
5. We don't show it here, but a good exercise is to plot samples and the model label predictions in the same way we did previously. Here you would need to obtain a batch from the data loader, compute predictions and then perform the plot. The process is more akin to the `test_loop` below than the training loop since we don't compute gradients.



### Aside
Please check out Andrej Karpathy's [recipe for training neural networks](http://karpathy.github.io/2019/04/25/recipe/). He explains a really good approach to training neural networks. Pay attention to the advice on overfitting. An easy way to make a model overfit (and check model correctness) is to use a smaller training set (of maybe 2-5 samples). You can do this by creating a Dataset and passing a smaller list of files:
* `training_dataset = PACSDataset(file_names_train[:5])`

By doing list slicing `file_names_train[:5]`, we are asking for a list containing the first five elements in `file_names_train`

In [5]:
learning_rate = 1e-4
epochs = 16

In [6]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Move to device
        X, y = X.to(device), y.to(device)
        
        # compute predictions
        pred = model(X)
        
        # evaluate loss
        loss = loss_fn(pred, y)

        # Zero gradients
        optimizer.zero_grad()
        
        # Back propagation
        loss.backward()
        
        # Gradient descent
        optimizer.step()

        # training progress
        if batch % 20 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # turn of gradients - we don't need the extra memory footprint
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# create loss_fn
loss_fn = nn.CrossEntropyLoss()

# create optimiser
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate, momentum=0.9)

In [7]:
# loop over epochs
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    if (t+1) % 5 == 0:
        test_loop(train_dataloader, model, loss_fn)
        test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


loss: 1.948940  [    0/ 6993]
loss: 1.957068  [ 1280/ 6993]
loss: 1.915716  [ 2560/ 6993]
loss: 1.928620  [ 3840/ 6993]
loss: 1.934139  [ 5120/ 6993]
loss: 1.902182  [ 6400/ 6993]
Epoch 2
-------------------------------
loss: 1.923584  [    0/ 6993]
loss: 1.876461  [ 1280/ 6993]
loss: 1.907546  [ 2560/ 6993]
loss: 1.936019  [ 3840/ 6993]
loss: 1.949657  [ 5120/ 6993]
loss: 1.906812  [ 6400/ 6993]
Epoch 3
-------------------------------
loss: 1.900024  [    0/ 6993]
loss: 1.887194  [ 1280/ 6993]
loss: 1.957663  [ 2560/ 6993]
loss: 1.898209  [ 3840/ 6993]
loss: 1.881095  [ 5120/ 6993]
loss: 1.948425  [ 6400/ 6993]
Epoch 4
-------------------------------
loss: 1.899917  [    0/ 6993]
loss: 1.928322  [ 1280/ 6993]
loss: 1.912875  [ 2560/ 6993]
loss: 1.852667  [ 3840/ 6993]
loss: 1.838704  [ 5120/ 6993]
loss: 1.906680  [ 6400/ 6993]
Epoch 5
-------------------------------
loss: 1.919132  [    0/ 6993]
loss: 1.940585  [ 1280/ 6993]
loss: 1.949100  [ 2560/ 6993]
loss: 1.924797  [ 3840/ 6993]


KeyboardInterrupt: 