In [None]:
from datasets import PACSDataset
from torchvision import transforms, models
from torch.utils.data import DataLoader


# Using a pretrained network
We have created and trained a network from scratch. Now we will look at using a pretrained network. Pretrained networks are usually trained on ImageNet and typically provide good general-purpose features. There are two main ways of using one:

1. Fixed feature extractor: We use a pretrained network, freeze its weights and chop off the output layer. A new layer with randomly initialised weights is appended at the end. We train only the new layer.
2. Fine-tuning: Instead of freezing the pretrained net, we also update its weights. The process is as follows.
    * Start with a fixed feature extractor as above. 
    * Once the new layer is trained, unfreeze the entire net and train with a smaller learning rate. 
    

## Train and Test dataloaders
Normally, we'll have a training, validation and test set. The validation set is used for hyperparameter tuning. Since we won't do any hyperparameter tuning we can just go ahead and use the test set for evaluation. 

In [None]:
# Files and transforms
def _read_lines(path):
    """Return the lines of the text file at `path`, without line endings."""
    with open(path) as f:
        return f.read().splitlines()


# File lists handed to PACSDataset (one entry per line of each text file).
file_names_train = _read_lines("train_files.txt")
file_names_test = _read_lines("test_files.txt")

# transforms
# A great point to add data augmentations - you want to do class-preserving transformations
# Rotations, Reflections etc
# If augmentations are added, build a separate transform for the training set
# so that the test images are not augmented as well.
# The mean/std values are the ImageNet channel statistics that torchvision's
# pretrained models were trained with.
# NOTE(review): Normalize operates on float tensors, so this assumes
# PACSDataset already yields tensors - confirm, or add a ToTensor step.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


train_dataset = PACSDataset(file_names_train, transform)
test_dataset = PACSDataset(file_names_test, transform)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Shuffling brings nothing at evaluation time; a fixed order keeps runs comparable.
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

## Neural network modules
We will create neural network modules.

1. A classifier: fully connected net built on top of featurizer
2. A CNNClassifier which uses a pretrained net with the above classifier

In [None]:
from torch import nn
from torch.nn import functional as F
import torch


# Classifier
# Layer sizes are constructor parameters; the defaults give the architecture
# num_features -> 4096 -> 1024 -> num_classes.
class Classifier(nn.Module):
    """Fully connected classification head with dropout.

    Args:
        num_classes: number of outputs (one logit per class).
        num_features: number of input features per sample after flattening.
        hidden_features: widths of the two hidden layers, as a pair.
        dropout_p: dropout probability, applied to the flattened input and to
            the first hidden activation.
    """

    def __init__(self, num_classes: int = 7, num_features=12544,
                 hidden_features=(4096, 1024), dropout_p: float = 0.5):
        super(Classifier, self).__init__()
        hidden1, hidden2 = hidden_features
        self.fc1 = nn.Linear(num_features, hidden1)
        # Element-wise dropout: forward() feeds this layer flattened 2-D
        # (batch, features) tensors. nn.Dropout2d is meant for 3-D/4-D feature
        # maps and is deprecated on 2-D input.
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input `x`."""
        # Collapse everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        # No activation on the last layer: raw logits are returned.
        x = self.fc3(x)
        return x

# 
class CNNClassifier(nn.Module):
    """Pretrained ResNet-18 used as a fixed feature extractor with a new head.

    All pretrained weights are frozen; the network's final `fc` layer is
    replaced by a freshly initialised `Classifier`, which is the only part
    that receives gradients.

    Args:
        num_classes: number of output classes of the new head.
    """

    def __init__(self, num_classes: int = 7):
        super(CNNClassifier, self).__init__()
        # introduce resnet and freeze parameters
        try:
            # torchvision >= 0.13 selects pretrained weights through `weights=`.
            self.net = models.resnet18(weights="DEFAULT")
        except TypeError:
            # Older torchvision releases only know the legacy boolean flag.
            self.net = models.resnet18(pretrained=True)
        for param in self.net.parameters():
            param.requires_grad = False

        # Add the new classifier
        # Parameters of newly constructed modules have requires_grad=True by default
        num_ftrs = self.net.fc.in_features
        self.net.fc = Classifier(num_classes=num_classes, num_features=num_ftrs)

    def forward(self, x):
        """Run `x` through the backbone and the new head; returns the head's output."""
        return self.net(x)


In [None]:
# Initialise network
# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# nn.Module.to() moves the parameters and returns the module itself.
model = CNNClassifier(num_classes=7).to(device)
print(model)

## Training logic
As mentioned before, we can fit the training logic into a class with a network as an attribute.

We also need to set the hyperparameters: the learning rate and the number of epochs.

In [None]:
# Training hyperparameters
epochs = 16           # number of passes over the training set
learning_rate = 1e-4  # step size handed to the optimizer
# batch_size = 64  (currently fixed where the DataLoaders are created)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, device=None):
    """Train `model` for one pass over `dataloader`.

    Args:
        dataloader: iterable of (inputs, targets) batches that exposes
            `.dataset` (only its length is used, for progress output).
        model: the nn.Module being optimised.
        loss_fn: callable `loss_fn(pred, target)` returning a scalar tensor.
        optimizer: optimizer that updates the model's parameters.
        device: device each batch is moved to; defaults to the device of the
            model's first parameter.

    Prints the batch loss and the number of samples processed so far every
    20 batches. Returns nothing.
    """
    if device is None:
        # Follow the model rather than relying on a notebook-level global.
        device = next(model.parameters()).device

    # Ensure training behaviour (e.g. active dropout) even if the model was
    # left in eval mode by an earlier evaluation.
    model.train()

    size = len(dataloader.dataset)
    seen = 0  # samples processed so far, including the current batch
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if batch % 20 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
# Classification loss applied to the model's raw outputs.
loss_fn = nn.CrossEntropyLoss()
# SGD with momentum over the model's parameters.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=0.9,
)

In [None]:
eval_every = 5  # report train/test metrics every this many epochs
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    # Also evaluate after the last epoch, so the final model is always
    # reported even when `epochs` is not a multiple of `eval_every`.
    is_last_epoch = (t + 1) == epochs
    if (t + 1) % eval_every == 0 or is_last_epoch:
        test_loop(train_dataloader, model, loss_fn)
        test_loop(test_dataloader, model, loss_fn)
print("Done!")