In [1]:
from datasets import PACSDataset
from torchvision import transforms, models
from torch.utils.data import DataLoader


# Training a neural network
We are able to load data from our custom datasets; now we need to train a neural network. For that we need the following:

1. A neural network
2. An optimiser
3. A training loop that iterates through samples provided by a dataloader and uses the optimiser to update the neural network's weights

I like to create a module for the neural network, and then a class that couples 2 and 3. Let's start with the dataloaders and then explore these concepts.

## Train and Test dataloaders
Normally, we'll have a training, validation and test set. The validation set is used for hyperparameter tuning. Since we won't do any hyperparameter tuning we can just go ahead and use the test set for evaluation. 

In [2]:
# File lists: one sample path per line, one text file per split.
with open("train_files.txt") as train_listing:
    file_names_train = train_listing.read().splitlines()

with open("test_files.txt") as test_listing:
    file_names_test = test_listing.read().splitlines()

# Preprocessing shared by both splits.
# This is also the natural place to add data augmentation; stick to
# class-preserving transformations (rotations, reflections, ...).
# NOTE(review): Normalize operates on tensors, so PACSDataset presumably
# converts images to tensors before applying `transform` - confirm.
resize = transforms.Resize(size=(224, 224))
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([resize, normalize])

# One dataset per split, both using the same preprocessing pipeline.
train_dataset = PACSDataset(file_names_train, transform)
test_dataset = PACSDataset(file_names_test, transform)

# Both loaders use the same batching and shuffling settings.
loader_options = dict(batch_size=64, shuffle=True)
train_dataloader = DataLoader(train_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

## Neural network modules
We will now create the neural network modules. To illustrate the modularity of PyTorch, we will break the network down into multiple modules.

1. A featurizer: applies convolutions to extract "useful" features
2. A classifier: fully connected net built on top of featurizer
3. A network that combines the two: Allows us to easily swap out parts 1 or 2. 

In [3]:
from torch import nn
from torch.nn import functional as F
import torch


# Classifier
# We will hardcode the layers for now, but it is best to use parameters
class Classifier(nn.Module):
    """Fully connected classification head: three linear layers with dropout.

    The input is flattened to ``(batch, num_features)`` before the first
    layer, so any tensor whose trailing dimensions multiply to
    ``num_features`` is accepted.

    Args:
        num_classes: Number of output logits.
        num_features: Size of the flattened feature vector per sample.
    """

    def __init__(self, num_classes: int = 7, num_features: int = 12544):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(num_features, 4096)
        # Element-wise dropout. The tensor is already flattened to 2-D when
        # dropout is applied; nn.Dropout2d is the channel-wise variant meant
        # for 3-D/4-D inputs and emits a deprecation warning on 2-D input.
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(4096, 1024)
        self.fc3 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        # No activation on the last layer: the output is unnormalised logits.
        x = self.fc3(x)
        return x

# Combination
class CNNClassifier(nn.Module):
    """Pretrained ResNet-18 with frozen weights and a trainable `Classifier` head.

    Args:
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, num_classes: int = 7):
        super().__init__()
        backbone = models.resnet18(pretrained=True)

        # Freeze every pretrained weight; only the new head below is trained.
        backbone.requires_grad_(False)

        # Replace the original final layer with our own head. Parameters of
        # newly constructed modules have requires_grad=True by default.
        head_inputs = backbone.fc.in_features
        backbone.fc = Classifier(num_classes=num_classes, num_features=head_inputs)

        self.net = backbone

    def forward(self, x):
        """Run the backbone (including the replaced head) on `x`."""
        return self.net(x)


In [4]:
# Initialise network
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Module.to() returns the module itself, so construction and placement chain.
model = CNNClassifier(num_classes=7).to(device)
print(model)

CNNClassifier(
  (net): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r

## Training logic
As mentioned before, we can fit the training logic into a class with a network as an attribute. To keep things simple here, we instead write two plain functions: `train_loop` and `test_loop`.

We also need to choose the hyperparameters:

In [5]:
# Step size handed to the optimiser.
learning_rate = 1e-4
# batch_size = 64  (not used from here: the DataLoaders hardcode batch_size=64)
# Number of full passes over the training dataloader.
epochs = 16

In [6]:
def train_loop(dataloader, model, loss_fn, optimizer, device=None):
    """Run one optimisation pass over `dataloader`, updating `model` in place.

    The model is put into training mode first, so layers such as dropout
    behave correctly even if a previous evaluation left it in eval mode.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a ``dataset``
            attribute that supports ``len()`` (used for progress output).
        model: Module to train.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimiser that updates the parameters of `model`.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter (CPU if the model has none).
    """
    if device is None:
        # Follow the model instead of relying on a notebook-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 20 batches.
        if batch % 20 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
# Cross-entropy loss applied to the network's output and the integer labels.
loss_fn = nn.CrossEntropyLoss()

# SGD with momentum; every parameter of the model is handed to the optimiser.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=0.9,
)

In [7]:
# Evaluate every `eval_every` epochs, and always after the final epoch so the
# finished model is reported even when `epochs` is not a multiple of it.
eval_every = 5

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)

    is_last_epoch = (t + 1) == epochs
    if (t + 1) % eval_every == 0 or is_last_epoch:
        # First call: metrics on the training data (still printed under the
        # "Test Error" heading by test_loop). Second call: the test split.
        test_loop(train_dataloader, model, loss_fn)
        test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


loss: 1.979666  [    0/ 6993]
loss: 1.937241  [ 1280/ 6993]
loss: 1.936771  [ 2560/ 6993]
loss: 1.928247  [ 3840/ 6993]
loss: 1.870707  [ 5120/ 6993]
loss: 1.898243  [ 6400/ 6993]
Epoch 2
-------------------------------
loss: 1.917826  [    0/ 6993]
loss: 1.888366  [ 1280/ 6993]
loss: 1.891215  [ 2560/ 6993]
loss: 1.868423  [ 3840/ 6993]
loss: 1.852230  [ 5120/ 6993]
loss: 1.862719  [ 6400/ 6993]
Epoch 3
-------------------------------
loss: 1.865219  [    0/ 6993]
loss: 1.880766  [ 1280/ 6993]
loss: 1.835865  [ 2560/ 6993]
loss: 1.841527  [ 3840/ 6993]
loss: 1.838375  [ 5120/ 6993]
loss: 1.837552  [ 6400/ 6993]
Epoch 4
-------------------------------
loss: 1.807266  [    0/ 6993]
loss: 1.795572  [ 1280/ 6993]
loss: 1.793792  [ 2560/ 6993]
loss: 1.773298  [ 3840/ 6993]
loss: 1.754071  [ 5120/ 6993]
loss: 1.772453  [ 6400/ 6993]
Epoch 5
-------------------------------
loss: 1.779099  [    0/ 6993]
loss: 1.760010  [ 1280/ 6993]
loss: 1.759344  [ 2560/ 6993]
loss: 1.768843  [ 3840/ 6993]
