## Model Training

This notebook contains all of the code used to train the model for my project. The model is built with the PyTorch library and starts from the pre-trained Faster R-CNN model provided in the TorchVision library. The structure of the model can be changed to accommodate a different set of output classes, and the model can then be fine-tuned on my training data. The documentation for the Faster R-CNN implementation can be found [here](https://pytorch.org/vision/stable/models/faster_rcnn.html).

The primary components of this model are a backbone, a region proposal network (RPN) and a Fast R-CNN classifier. This environment used the stable PyTorch 1.13.1 release with CUDA 11.7.

- Test model
- Dataset and loader
- Alter structure
- Train
- Evaluate

### Imports

The following imports are needed to run the code in this notebook.

In [1]:
import os
from pathlib import Path
import json
from tqdm import tqdm
import numpy as np
import math
import sys
from matplotlib import pyplot as plt
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms as T
import torchvision.transforms.functional as f
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.utils.tensorboard import SummaryWriter

import utils

In [2]:
writer = SummaryWriter('D:/DS/DS4/Project/tensorboard_logs/midog')

Ensure that CUDA is set up correctly and that we have access to the GPU.

In [3]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce GTX 1050 Ti'

### Model inspection

Firstly, I will load the default model, inspect the architecture and ensure that it works as expected. This can be done by using the model to make a prediction on an image from the COCO dataset (which it was trained on). In this way I can better understand how the model should be used and ensure it has the pre-trained weights.

In [None]:
# Load the pre-trained weights into the model and set it to evaluation mode so we can run inference.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
model.eval()

We have loaded the weights into the model object and set the model to evaluation mode. Now we can use print on the model to output all of the information regarding the layers and their connections. We can see several different components.

+ Transform - the transformations to be applied to an image before it can be passed through the model.
+ Backbone (with FPN) - This part of the model is responsible for feature extraction and consists of sequential bottleneck blocks that create compressed feature representations. The feature pyramid network (FPN) provides context at different scales by outputting feature maps at multiple levels.
+ Region proposal network (RPN) - Network that proposes regions likely to contain an object of interest; these are used during the final stage of detection.
+ Fast R-CNN predictor - Takes the proposed regions and the feature maps and returns the detections.

This is the basic overview of our model that can be seen in the below output. This can be finetuned to our specific task with some minor alterations.

In [None]:
print(model)

Now we can take a sample image from the COCO dataset, which the default weights for the Faster R-CNN model were trained on. Using this we can see how the model works and how well it works.

In [None]:
# Read in a COCO image
img = Image.open("D:/DS/DS4/Project/COCO/tennis racket/COCO_test2014_000000000057.jpg").convert("RGB")

In [None]:
# Check the image
img

In [None]:
# Convert the image to a tensor and add an empty dimension so we have [number images x channels x height x width]
tensor_img = T.ToTensor()(img)
tensor_img = tensor_img[None, :]
print(tensor_img.shape)

In [None]:
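# Pass our example image through the model to get the predictions.
# Unpacking two values relies on the altered torchvision files described in the training section;
# the stock model returns only the list of predictions in evaluation mode.
_, predictions = model(tensor_img)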

In [None]:
boxes = predictions[0]["boxes"] # Pull the boxes out of the prediction
labels = ["person", "sports ball", "tennis racket", "clock", "sports ball"] # These labels were manually entered from the COCO dataset class list
example = tensor_img[0,:] # Take just the first image
example = T.ConvertImageDtype(torch.uint8)(example) # Convert it to uint8

In [None]:
img_with_boxes = draw_bounding_boxes(example, boxes, labels) # Draw the bounding boxes onto the image

In [None]:
plt.imshow(img_with_boxes.permute(1, 2, 0)) # We need to change the ordering as the channels should be the last dimensions for pyplot.

We can see that the model does a good job of finding the objects that belong to the COCO dataset classes, although it is likely that the weights were trained on this very image. This is still a good sanity check that the model weights are meaningful, and they will hopefully be a good starting point for my own model.

### Dataset and Dataloader

Now that we have inspected the model and have a better understanding of its architecture, how to use it, and its input and output types, we are ready to start fine-tuning it. The first step is to load our data by creating a PyTorch Dataset and then using a DataLoader to read it in batches.

The Dataset class below implements the functions needed to read our MIDOG data as a Dataset.

In [4]:
def transforms(img, target):
    return T.ToTensor()(img), target

In [5]:
class MidogDataset(Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(Path(root, "data"))))
        
        with open(os.path.join(root, "training.json")) as t:   
            training_data = json.load(t)
        
        self.data = training_data["images"]

    def __getitem__(self, idx):
        # load the image for this tile
        img_path = Path(self.root, "data", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        
        image_id, tile_id = self.imgs[idx].split(".")[0].split("_")

        tile_info = self.get_tile_info(image_id, tile_id)

        # get bounding box coordinates for this tile
        boxes = []
        for anno in tile_info["annotations"]:
            left, bottom, right, top = anno["bounding_box"].values()
            
            boxes.append([left, bottom, right, top])
        
        num_objs = len(boxes)

        # the torchvision detection models document boxes as a FloatTensor[N, 4]
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
    

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["tile_id"] = tile_id
    

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

    def get_tile_info(self, image_id, tile_id):
        
        image_info = next((image for image in self.data if image["image_id"] == int(image_id)), None)
        tile_info = next((tile for tile in image_info["tiles"] if tile["tile_id"] == int(tile_id)), None)
        
        return tile_info
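
`get_tile_info` scans the whole `images` list on every `__getitem__` call. If that lookup ever becomes a bottleneck, one option is to build a dictionary index once and reuse it. The sketch below is only an illustration: the JSON layout is inferred from the keys used in the class above, while `build_tile_index`, `parse_tile_name`, the `.png` extension, the sample record and the `bounding_box` key names are all hypothetical.

```python
def build_tile_index(images):
    """Map (image_id, tile_id) to its tile record for constant-time lookups."""
    index = {}
    for image in images:
        for tile in image["tiles"]:
            index[(image["image_id"], tile["tile_id"])] = tile
    return index


def parse_tile_name(filename):
    """Split a name such as '12_3.png' into integer image and tile ids."""
    image_id, tile_id = filename.split(".")[0].split("_")
    return int(image_id), int(tile_id)


# Hypothetical record shaped like the structure MidogDataset reads.
sample_images = [
    {
        "image_id": 12,
        "tiles": [
            {
                "tile_id": 3,
                "annotations": [
                    {"bounding_box": {"x1": 10, "y1": 20, "x2": 60, "y2": 70}}
                ],
            },
        ],
    },
]

tile_index = build_tile_index(sample_images)
tile = tile_index[parse_tile_name("12_3.png")]

# Same extraction as in __getitem__: relies on the order of the dict values.
boxes = [list(anno["bounding_box"].values()) for anno in tile["annotations"]]
print(boxes)  # [[10, 20, 60, 70]]
```

The index costs one pass over the annotations at construction time, after which each lookup no longer depends on the number of images.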

In [6]:
midog = MidogDataset("D:/DS/DS4/Project/Training_mitotic_figures", transforms)

In [None]:
data_loader = torch.utils.data.DataLoader(midog, batch_size=2, shuffle=False, collate_fn=utils.collate_fn)
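
The `utils.collate_fn` used above is not shown in this notebook. Later cells index the batch with `images[0,:]` and call `images.to(device)`, which suggests that it stacks the images into one tensor and leaves the targets as a list. A minimal sketch of such a function, assuming all tiles share the same size, could look like the following; `stacking_collate_fn` and the dummy batch are hypothetical and may differ from the real helper.

```python
import torch


def stacking_collate_fn(batch):
    """Stack equally sized image tensors and keep the targets as a list of dicts."""
    images, targets = zip(*batch)
    return torch.stack(images, dim=0), list(targets)


# Dummy batch: two 512 x 512 tiles with a different number of boxes each.
batch = [
    (torch.zeros(3, 512, 512), {"boxes": torch.tensor([[10.0, 20.0, 60.0, 70.0]])}),
    (torch.zeros(3, 512, 512), {"boxes": torch.zeros((0, 4))}),
]

images, targets = stacking_collate_fn(batch)
print(images.shape, len(targets))
```

The targets stay in a list because each tile can hold a different number of boxes, so they cannot be stacked into a single tensor.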

In [None]:
images, target = next(iter(data_loader))

Using this batch of images we can compute predictions and see if the model can detect any objects. The classes will be meaningless as they are from another task.

In [None]:
_, predictions = model(images)

In [None]:
boxes = predictions[0]["boxes"] # Pull the boxes out of the prediction
final_boxes = torch.cat([boxes, target[0]["boxes"]])
labels = (["pred"] * len(boxes)) + (["gt"] * len(target[0]["boxes"])) # One label per predicted and ground-truth box
example = images[0,:] # Take just the first image
example = T.ConvertImageDtype(torch.uint8)(example) # Convert it to uint8

In [None]:
img_with_boxes = draw_bounding_boxes(example, final_boxes, labels) # Draw the bounding boxes onto the image

In [None]:
plt.imshow(img_with_boxes.permute(1, 2, 0)) # We need to change the ordering as the channels should be the last dimensions for pyplot.

Unfortunately, the model with default weights does not seem to make any meaningful predictions. This is to be expected, as the task is in a different domain. Once we have trained the model on our data we should see much more reasonable predictions.

### Model Alteration

Now we need to make some alterations to our model before we train it. The number of output classes has to change, so we create a new predictor that takes the feature maps and predicts the class and bounding box for an object. This can be done as seen below.

In [7]:
# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (mitotic figure) + background

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

Now our model is ready to be fine-tuned.

### Data sanity check

Before training I will do a quick sanity check of the data to ensure it is all as it should be. It turns out that there was a bug in my bounding box generation process: it did not account for the tile difference when a bounding box coordinate was negative, i.e. when the box fell on the padding to the left of or under the original tile. This has been fixed and now my sanity checks pass as expected. Preserving my sanity.

In [None]:
# use our dataset and defined transformations
dataset = MidogDataset("D:/DS/DS4/Project/Training_mitotic_figures", transforms)

# define a data loader over the full dataset
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=False, collate_fn=utils.collate_fn)

for images, targets in data_loader:
    try:
        # Check that no values are outside of the boundaries
        assert max(t["boxes"].numpy().max() for t in targets) < 512
        assert min(t["boxes"].numpy().min() for t in targets) >= 0

        boxes = np.vstack([t["boxes"].numpy() for t in targets])

        # Check that x1 < x2
        assert np.all(boxes[:, 0] < boxes[:, 2])

        # Check that y1 < y2
        assert np.all(boxes[:, 1] < boxes[:, 3])
    
    except AssertionError:
        print(targets)
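
The same checks can also be written without tensors, which makes it easier to see which individual boxes fail. This is a small sketch rather than the code used in the project; `find_invalid_boxes` and the example boxes are hypothetical, and the tile size of 512 is taken from the assertion above.

```python
def find_invalid_boxes(boxes, tile_size=512):
    """Return the indices of [x1, y1, x2, y2] boxes that fail the sanity checks."""
    bad = []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        # Every coordinate must lie inside the tile.
        inside = all(0 <= v < tile_size for v in (x1, y1, x2, y2))
        # The box must have positive width and height.
        ordered = x1 < x2 and y1 < y2
        if not (inside and ordered):
            bad.append(i)
    return bad


example_boxes = [
    [10, 20, 60, 70],    # valid
    [-5, 20, 45, 70],    # negative x1: the box lies on the padding
    [100, 90, 80, 140],  # x1 > x2
]
print(find_invalid_boxes(example_boxes))  # [1, 2]
```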

### Model training

The process of training our model consists of multiple steps.

Firstly we create our dataset and split it into a train and test set. We can make use of the torch random seed to ensure that the random permutation we get for the indices stays the same every time this is run. There are 7549 training images, some of which may contain more than one example. Let us hold out ~10% of the data for testing, which is 750 tiles (375 batches of 2 images).
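
The split sizes can be checked with a few lines of arithmetic. The numbers are the ones quoted above; the only assumption is that the data loaders keep the last, smaller batch, which is the `DataLoader` default.

```python
import math

num_tiles = 7549   # dataset size stated above
num_test = 750     # tiles held out for testing
batch_size = 2

num_train = num_tiles - num_test
train_batches = math.ceil(num_train / batch_size)
test_batches = math.ceil(num_test / batch_size)

print(num_train, train_batches, test_batches)  # 6799 3400 375
print(f"{num_test / num_tiles:.1%}")           # 9.9%
```

The 3400 training batches agree with the `[   0/3400]` counter in the training log further down.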

Now that we have our data, we need to consider the loss. The model returns a classification loss and a regression loss for both the RPN and the R-CNN head. These can be summed to get the total loss, which is used to compute the optimization steps.
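
As a small worked example, the four loss terms from the first logged iteration of the training run further down can be summed by hand. `loss_objectness` and `loss_rpn_box_reg` come from the RPN, while `loss_classifier` and `loss_box_reg` come from the R-CNN head.

```python
# Values copied from the first line of the training log in this notebook.
loss_dict = {
    "loss_classifier": 0.6656,
    "loss_box_reg": 0.0023,
    "loss_objectness": 0.1966,
    "loss_rpn_box_reg": 0.0171,
}

# Same reduction as in the training loop: a plain sum over the dictionary values.
total = sum(loss_dict.values())
print(f"{total:.4f}")  # 0.8816
```

The result matches the `loss: 0.8816` reported on that log line.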

We will make use of the Adam optimizer, which combines momentum with root mean square propagation (RMSProp) to converge faster.
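
To make the two ingredients concrete, here is a sketch of a single Adam update for one scalar parameter. It is an illustration rather than the PyTorch implementation: `adam_step` is a hypothetical helper, the beta and epsilon values are the usual defaults, and weight decay is left out (in `torch.optim.Adam` it is added to the gradient as `weight_decay * param`).

```python
import math


def adam_step(param, grad, m, v, t, lr=0.005, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter at step t (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * grad         # momentum: running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2    # RMSProp: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


param, m, v = 1.0, 0.0, 0.0
param, m, v = adam_step(param, grad=2.0, m=m, v=v, t=1)
print(round(param, 6))  # 0.995
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly the learning rate regardless of the gradient's magnitude.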

This covers our main training loop that updates our weights. After every epoch we would like to run the model on our test set and look at some metrics. We will make use of a sample test set to determine the metrics we wish to use. These will come from the torchmetrics library and will be commonly used metrics for detection tasks. In addition, it would be nice to look at the test loss. 
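
Detection metrics such as mean average precision decide whether a prediction matches a ground-truth box by their intersection over union (IoU). The sketch below shows that calculation for two boxes in `[x1, y1, x2, y2]` format; `box_iou` and the example boxes are hypothetical and are not part of the torchmetrics API.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    # Corners of the overlapping rectangle.
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


prediction = [10, 10, 60, 60]
ground_truth = [20, 20, 70, 70]
print(round(box_iou(prediction, ground_truth), 3))  # 0.471
```

With an IoU threshold of 0.5 this prediction would not count as a match, even though the boxes overlap substantially.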

To find this we need to alter the code in the torchvision library, as the pre-trained model does not return the loss in evaluation mode. The files that have been altered are available in the project repo. The changes involved going through the Faster R-CNN component classes and ensuring that they return the loss in evaluation mode.

With this, we are ready to train our model.

In [8]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    total_loss = 0
    for iteration, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        images, targets = data
        images = images.to(device)
        targets = [{k: v.to(device) for k, v in t.items() if k in ["boxes", "labels"]} for t in targets]
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()
        total_loss += loss_value
        
        del images, targets, loss_dict
        

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    
    writer.add_scalar('training loss', total_loss / len(data_loader), epoch)

    return metric_logger, total_loss / len(data_loader)
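
The warm-up in the first epoch scales the learning rate linearly from `start_factor` up to 1 over `warmup_iters` scheduler steps. The sketch below writes that closed form out directly, assuming the base learning rate of 0.005 used in `train` and 1000 warm-up iterations (the minimum of 1000 and the 3400 training batches minus one); `warmup_lr` is a hypothetical helper, not a PyTorch function.

```python
def warmup_lr(base_lr, step, start_factor=1.0 / 1000, total_iters=1000):
    """Learning rate after `step` scheduler steps of linear warm-up."""
    progress = min(step, total_iters) / total_iters
    return base_lr * (start_factor + (1.0 - start_factor) * progress)


base_lr = 0.005
for step in (0, 1, 1000):
    print(step, f"{warmup_lr(base_lr, step):.6f}")
# 0 0.000005
# 1 0.000010
# 1000 0.005000
```

The value after one step is consistent with the `lr: 0.000010` on the first line of the training log, provided the logger prints after the loop body has run, as the torchvision reference `MetricLogger` does.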

In [9]:
def evaluate(model, data_loader, data_loader_test, epoch, device):
    
    model.eval()

    metric = MeanAveragePrecision()
    metric.to(device)
    
    total_loss = 0
    with torch.no_grad():
        for images, targets in tqdm(data_loader_test):
            images = images.to(device)

            targets = [{k: v.to(device) for k, v in t.items() if k in ["boxes", "labels"]} for t in targets]

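            # Returning both losses and predictions in evaluation mode relies on the altered torchvision files.
            losses, predictions = model(images, targets)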
            losses_summed = sum(loss for loss in losses.values()).item()
            total_loss += losses_summed

            metric.update(predictions, targets)

            del predictions, images, targets, losses
    
    test_loss = total_loss / len(data_loader_test)
    writer.add_scalar('test loss', test_loss, epoch)

    torch.cuda.empty_cache()

    test_metrics = metric.compute()
    
    return test_loss, test_metrics

In [14]:
from datetime import datetime

def train(model, num_epochs, checkpoint=None):
    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
    # use our dataset and defined transformations
    dataset = MidogDataset("D:/DS/DS4/Project/Training_mitotic_figures", transforms)

    # split the dataset in train and test set
    torch.manual_seed(42)
    indices = torch.randperm(len(dataset)).tolist()
    dataset_train = torch.utils.data.Subset(dataset, indices[:-750])
    dataset_test = torch.utils.data.Subset(dataset, indices[-750:])

    torch.manual_seed(42)
    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=2, shuffle=False, collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=2, shuffle=False,
        collate_fn=utils.collate_fn)

    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])

    # move model to the right device
    model.to(device)

    # construct an optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.005, weight_decay=0.0005)
    
    if checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    # resume from the epoch after the checkpoint, otherwise start at 0
    start_epoch = checkpoint['epoch'] + 1 if checkpoint else 0

    for epoch in range(start_epoch, start_epoch + num_epochs):
        # train for one epoch, printing every 100 iterations
        _, training_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=100)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset
        test_loss, test_metrics = evaluate(model, data_loader, data_loader_test, epoch, device=device)
        
        with open("D:/DS/DS4/Project/model_saves/metrics.txt", "a") as m:
            m.write(f"Epoch {epoch}: Training loss: {training_loss} Test loss: {test_loss}\n" + str(test_metrics) + "\n\n")
        
        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        path_to_save = Path("D:/DS/DS4/Project/model_saves") / f"{timestamp}_{epoch}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path_to_save)

In [None]:
# Initial training
train(model, 10)

Epoch: [0]  [   0/3400]  eta: 2:33:30  lr: 0.000010  loss: 0.8816 (0.8816)  loss_classifier: 0.6656 (0.6656)  loss_box_reg: 0.0023 (0.0023)  loss_objectness: 0.1966 (0.1966)  loss_rpn_box_reg: 0.0171 (0.0171)  time: 2.7090  data: 0.0468  max mem: 2771


In [None]:
# Training from checkpoint
checkpoint = torch.load("D:/DS/DS4/Project/model_saves/2023_02_27_16_13_14_0.pth")
train(model, 10, checkpoint)

In [None]:
checkpoint = torch.load("D:/DS/DS4/Project/model_saves/2023_02_27_17_44_05_1.pth")

In [None]:
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

In [None]:
# use our dataset and defined transformations
dataset = MidogDataset("D:/DS/DS4/Project/Training_mitotic_figures", transforms)

# split the dataset in train and test set
torch.manual_seed(42)
indices = torch.randperm(len(dataset)).tolist()
dataset_test = torch.utils.data.Subset(dataset, indices[-750:])

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False,
    collate_fn=utils.collate_fn)

In [None]:
iter_data = iter(data_loader_test)

In [None]:
images, targets = next(iter_data)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
images = images.to(device)
targets = [{k: v.to(device) for k, v in t.items() if k in ["boxes", "labels"]} for t in targets]

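with torch.no_grad():
    output = model(images)

# The altered torchvision returns (losses, detections) in evaluation mode,
# whereas the stock library returns only the list of detections.
predictions = output[1] if isinstance(output, tuple) else output

boxes = predictions[0]["boxes"] # Pull the boxes out of the prediction
final_boxes = torch.cat([boxes, targets[0]["boxes"]])
labels = (["pred"] * len(boxes)) + (["gt"] * len(targets[0]["boxes"]))
example = images[0,:] # Take just the first image
example = T.ConvertImageDtype(torch.uint8)(example) # Convert it to uint8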

img_with_boxes = draw_bounding_boxes(example, final_boxes, labels) # Draw the bounding boxes onto the image

plt.imshow(img_with_boxes.permute(1, 2, 0)) # We need to change the ordering as the channels should be the last dimensions for pyplot.