# Classification Training Exercise

- Classification training pipeline using plain pytorch and a custom model
- as classification of all 4 classes turns out to be rather difficult, we will focus on classifying negative or positive for pneumonia only

In [None]:
from ast import literal_eval
import gc
import os
from pathlib import Path
import numpy as np 
import pandas as pd 
import cv2
from datetime import datetime
import time
import random
import matplotlib.pyplot as plt

# Albumentations
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# torch
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# importing from local package
import sys
sys.path.insert(0, '../')

from src.visualization.images import show_img
from src.visualization.metrics import plot_confusion_matrix

## Setup

In [None]:
def seed_everything(seed):
    """ Seed all random number generators to make runs deterministic

    Seeds Python's `random` module, NumPy and PyTorch (CPU and all CUDA
    devices), sets `PYTHONHASHSEED` and configures cuDNN for deterministic
    behaviour.

    Args:
        seed: integer seed shared by all generators
    """

    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime only affects Python processes
    # started afterwards, not the hashing of the current interpreter
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seeds every visible GPU; silently ignored when CUDA is not available
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN pick convolution algorithms by timing them,
    # which can differ between runs, so it has to be off for reproducibility
    torch.backends.cudnn.benchmark = False
    

# TODO: try training with different values for 
# - learning rate
# - batch size
# - image size
# - number of epochs
# - fold number
    
class Config:
    """ Configuration for the training

    All settings are class attributes, so they are read as `Config.<name>`
    without creating an instance. `time_stamp` (and everything derived from
    it) is evaluated once, when the class body is executed.
    """

    lr: float = 1e-3 # learning rate
    batch_size: int = 4 # size of training batches
    img_size: tuple = (256, 256) # image size for model input as (height, width)
    n_epochs: int = 100 # number of training epochs
    fold_num: int = 0 # chooses fold for training. Available folds: [0, 1, 2]

    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') # current time (for logging)
    exp_name = f"classification_{time_stamp}" # experiment name, unique per time stamp
    log_path = Path("./logs") / exp_name # path of the log folder
    seed: int = 2021 # for random number generators
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # computing device
    num_workers: int = 4 # number of processors used to prepare the batches. Adjust to number of CPUs on your machine
    num_classes: int = 2 # number of classes
    # maps the integer class index to a readable class name
    label_dict = {
        0 : "negative",
        1 : "positive"
    }

# seed all random number generators with the configured seed before any
# data loading or model creation happens
seed_everything(Config.seed) # makes the results reproducible 

## Helper Function

In [None]:
def plot_image_batch(images, gt_labels, image_ids, pred_labels=None):
    """ plots image batch returned from dataloader and optionally predictions

    Expects data to be detached from gradients and moved to CPU.
    At most 8 images are shown in a grid with 2 columns; larger batches are
    randomly subsampled. Nothing is plotted for an empty batch.

    Args:
        images: batch of images; every entry must provide `.numpy()` and is
            squeezed before being shown as a gray scale image
        gt_labels: ground truth labels; every entry must provide `.item()`
            and be a key of `Config.label_dict`
        image_ids: sequence of image ids, used for the plot titles
        pred_labels: optional predicted labels, same layout as `gt_labels`
    """

    n_plot_max = 8
    n_cols = 2

    n_plot = len(images)
    if n_plot == 0:
        # an empty grid cannot be created
        return

    # for visibility, plot only a few images for large batches
    if n_plot > n_plot_max:
        n_plot = n_plot_max
        sample_idx = torch.randperm(len(images))[:n_plot]
        images = images[sample_idx]
        gt_labels = gt_labels[sample_idx]
        image_ids = [image_ids[i.item()] for i in sample_idx]
        if pred_labels is not None:
            pred_labels = pred_labels[sample_idx]

    # round up so that an odd number of images still gets its last row
    n_rows = (n_plot + n_cols - 1) // n_cols
    figsize = (n_cols * 7, n_rows * 7)

    # squeeze=False keeps `ax` two-dimensional even for a single row
    fig, ax = plt.subplots(figsize=figsize, nrows=n_rows, ncols=n_cols, squeeze=False)
    for n in range(n_plot):
        image_id = image_ids[n]
        gt_label = gt_labels[n].item()
        img = np.squeeze(images[n].numpy())

        axis = ax[n // n_cols][n % n_cols]
        axis.imshow(img, cmap='gray')
        axis.axis('off')

        if pred_labels is None:
            title = f"{image_id}: {Config.label_dict[gt_label]}"
        else:
            pred_label = pred_labels[n].item()
            title = f"{image_id}: GT {Config.label_dict[gt_label]}, Pred {Config.label_dict[pred_label]}"

        axis.set_title(title)

    # hide the empty cell left over when the number of images is odd
    for axis in ax.ravel()[n_plot:]:
        axis.axis('off')

    plt.tight_layout()

## Load Data

In [None]:
# root folder of the data set and the folder holding the training images
data_path = Path('../data/siim-covid19-detection-subset')
train_path = data_path / "train"

# annotation data frame
# the listed columns are passed through `literal_eval`, which parses their
# text content into Python objects while the file is read
ann_df = pd.read_csv(
    data_path / "train_annotations.csv",
    converters=dict.fromkeys(("boxes", "labels", "pixel_spacing"), literal_eval),
)

# last expression of the cell: shows the first rows in the notebook
ann_df.head()

## Image Pre-Processing

Image processing serves two purposes:

- it standardizes the images shown to the model
- it augments the data with new images by applying random transformations to the existing images
- to guarantee that the images are still standardized, the standardization is applied after the augmentation

In [None]:
# normalization transforms: resize to the configured (height, width),
# normalize the pixel values and convert the image to a torch tensor
norm_transform_list = [
    A.Resize(height=Config.img_size[0], width=Config.img_size[1], p=1.0),
    A.Normalize(mean=(0,), std=(1,), p=1.0),
    ToTensorV2(p=1.0),
]

aug_transform_list = [
    # TODO: add augmentation transforms
    # Think about which ones are reasonable for the data set
    # Check in the data sanity checks whether they do what is intended
    # Observe their influence on the training process
]

# training: augmentations first, normalization last
train_transforms = A.Compose([*aug_transform_list, *norm_transform_list])
# validation: normalization only
val_transforms = A.Compose(norm_transform_list)

## Dataset & DataLoader

- see [DATASETS & DATALOADERS](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) for an explanation

In [None]:
class ClassificationDataset(Dataset):
    """ Dataset which serves one annotated image per row of a data frame

    Exercise template: `__getitem__` still has to be completed.
    """

    def __init__(self, df, data_path, transforms=None):
        """ stores the annotations, the image folder and the transforms

        Args:
            df: data frame with one row per sample; the columns
                'rel_image_path', 'id' and 'study_label' are read
            data_path: folder the relative image paths refer to
            transforms: optional transforms to apply to the image
        """
        super().__init__()

        self.df = df
        self.data_path = data_path
        self.transforms = transforms

    def __getitem__(self, index: int):
        """ returns image, label and image id of the sample at position `index`

        Raises:
            NotImplementedError: until the exercise is implemented
        """
        # positional access, independent of the index labels of the data frame
        sample = self.df.iloc[index]

        img_path = self.data_path / sample['rel_image_path']
        image_id = sample["id"]
        label = sample["study_label"]

        # TODO implement the remaining logic
        raise NotImplementedError("Replace this error with the actual implementation")
        # Hint: for the start we try to only classify positive and negative cases
        # you can use `label = int(label > 0)` for this purpose
        # NOTE(review): the sanity check cell calls `image.numpy()` and
        # `label.item()`, so both are presumably expected to be torch tensors
        # - confirm when implementing

        return img, label, image_id

    def __len__(self) -> int:
        """ number of samples, i.e. number of rows of the data frame """
        return self.df.shape[0]

It is a common practice to split the data into several folds for training. This data was split into 3 folds.<br>
The idea is that you use the selected fold for validation only and the remaining folds for training. This way you can check during the training whether your model performance generalizes to unseen data.

In [None]:
# the selected fold is held out for validation, all other folds are used for training
train_df = ann_df[ann_df['fold'] != Config.fold_num]
val_df = ann_df[ann_df['fold'] == Config.fold_num]

# both data sets read from the same image folder and differ only in the
# rows they use and in the transforms (augmentation for training only)
train_ds = ClassificationDataset(train_df, train_path, train_transforms)
val_ds = ClassificationDataset(val_df, train_path, val_transforms)

Remark: When using batch size > 1, one should shuffle the data for training. Otherwise the model might learn to deduce the label based on the sequence of the data shown. This will of course break the system during real inference. <br>
For the validation the sequence doesn't matter. So to avoid additional processing, one can skip the shuffling.

In [None]:
# training batches are drawn in random order
train_dl = DataLoader(
        train_ds,
        batch_size = Config.batch_size,
        shuffle = True,
        num_workers = Config.num_workers
    )
# validation batches keep the order of the data set
val_dl = DataLoader(
        val_ds,
        batch_size = Config.batch_size,
        shuffle = False,
        num_workers = Config.num_workers        
    )  

### Dataset Sanity Check

To find errors in the data loading and augmentation implementation, it is a good practice to visualize a few images before showing them to the model

In [None]:
# fetch a single sample directly from the data set (training transforms applied)
image, label, image_id = train_ds[2]

# convert to numpy and drop dimensions of size 1 for plotting
img_np = image.numpy()
img_np = np.squeeze(img_np)
# `.item()` requires the label to be a tensor (or numpy scalar), not a plain int
gt_label = label.item()
    
_ = show_img(img_np, figsize=(12, 12), title=f"{image_id}: {Config.label_dict[gt_label]}")

### Dataloader Sanity Check

In [None]:
# iterator over the training batches; re-running this cell starts a new pass
train_it = iter(train_dl)

In [None]:
# fetch the next training batch; re-run this cell to see further batches
images, labels, image_ids = next(train_it)

plot_image_batch(images, labels, image_ids)

# Model

See [BUILD THE NEURAL NETWORK](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html)

[TORCHVISION.MODELS](https://pytorch.org/vision/stable/models.html#torchvision-models) contains a list of models which can be used for finetuning

In [None]:
class ClassificationBaseline(nn.Module):
    """ Simple baseline classifier built on a pretrained ResNet-18

    The first convolution is replaced to accept single channel (grayscale)
    images and the final fully connected layer is replaced to output
    `cfg.num_classes` logits. Both replaced layers are freshly initialized.
    """

    def __init__(self, cfg):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)

        # input for grayscale images: 1 input channel instead of 3
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # adjust number of output classes
        n_features = backbone.fc.in_features
        backbone.fc = nn.Linear(n_features, cfg.num_classes)

        self.model = backbone

    def forward(self, x):
        """ returns the raw class scores (logits) for the batch `x` """
        return self.model(x)

In [None]:
# TODO develop your own model based on the tutorial
# Note: if you run on a CPU, it is better to use small models (few layers with few features)
# See how the custom model performs compared to the baseline

class ClassificationCustom(nn.Module):
    """ Exercise template for a self-developed classification model

    The layers and the forward pass still have to be implemented.
    """

    def __init__(self, cfg) -> None:
        """ creates the network layers; receives the training configuration `cfg` """
        super().__init__()
        # TODO add your network layers here

    def forward(self, x: Tensor) -> Tensor:
        """ computes the model output for the input batch `x`

        Raises:
            NotImplementedError: until the exercise is implemented
        """
        
        # TODO add inference logic here        
        raise NotImplementedError("Replace this error message after implementation")

In [None]:
def get_model(cfg, checkpoint_path=None):
    """ creates the model and optionally restores trained weights

    Args:
        cfg: configuration; `cfg.device` selects the computing device and the
            object is passed on to the model constructor
        checkpoint_path: optional path to a checkpoint file which stores the
            weights under the key 'model_state_dict'

    Returns:
        the model, moved to `cfg.device`
    """

    model = ClassificationBaseline(cfg)
    #model = ClassificationCustom(cfg)

    # Load the trained weights
    if checkpoint_path is not None:
        # map_location allows restoring a checkpoint saved on a GPU on a
        # machine without one
        checkpoint = torch.load(checkpoint_path, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # release the checkpoint data, the weights are now part of the model
        del checkpoint
        gc.collect()

    # use the configured device instead of unconditionally requiring CUDA
    return model.to(cfg.device)

In [None]:
# create an untrained model (no checkpoint) and make sure it is on the configured device
model = get_model(Config)
# last expression of the cell: `to` returns the model, so the notebook shows it
model.to(Config.device)

# Training

In [None]:
# TODO: develop the classification training logic

class ClassificationTrainer():
    """ Runs the training and validation epochs of a classification model

    Exercise template: `train_epoch` and `validate_epoch` still have to be
    implemented. `fit` drives both, prints the results of every epoch and
    writes them to a tensorboard log below `cfg.log_path`.
    """

    def __init__(self, model, cfg) -> None:
        """ stores the model and the training configuration

        Args:
            model: the model to train
            cfg: configuration; `n_epochs` and `log_path` are read
        """
        self.cfg = cfg
        self.model = model
        # folder the tensorboard writer created in `fit` writes to
        self.log_base = cfg.log_path
        # TODO: all variables you need for the implementation are added here


    def fit(self, train_loader, validation_loader):
        """ trains and validates the model for `cfg.n_epochs` epochs

        Args:
            train_loader: data loader for the training data
            validation_loader: data loader for the validation data
        """
        # imported locally: the file does not import it at the top and the
        # rest of the class stays usable without tensorboard installed
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(str(self.log_base))
        try:
            for e in range(self.cfg.n_epochs):

                # TODO (optional) add more logic to the overall training procedure. E.g.:
                # better logging
                # save best model only
                # any learning rate schedule
                # anything else you find or want to try out

                train_loss, train_acc = self.train_epoch(train_loader)
                print(f'Train. Epoch: {e}, train_loss: {train_loss:.5f}, train_accuracy: {train_acc:.5f}')
                writer.add_scalar('loss/train', train_loss, e)
                writer.add_scalar('accuracy/train', train_acc, e)

                val_loss, val_acc = self.validate_epoch(validation_loader)
                print(f'Val. Epoch: {e}, val_loss: {val_loss:.5f}, val_accuracy: {val_acc:.5f}')
                writer.add_scalar('loss/val', val_loss, e)
                writer.add_scalar('accuracy/val', val_acc, e)
        finally:
            # flush and close the log file even if an epoch fails
            writer.close()

    def train_epoch(self, train_loader):
        """ trains the model for one epoch; expected to return (loss, score)

        Raises:
            NotImplementedError: until the exercise is implemented
        """

        # TODO implement training logic
        # it should compute the training loss
        # score is optional, but useful to compare with the validation score
        # for this data set accuracy is a useful metric
        raise NotImplementedError("Replace this error message after implementation")

        return loss, score

    def validate_epoch(self, val_loader):
        """ validates the model for one epoch; expected to return (loss, score)

        Raises:
            NotImplementedError: until the exercise is implemented
        """

        # TODO implement validation logic
        # it should compute the validation loss and score (e.g. accuracy)
        raise NotImplementedError("Replace this error message after implementation")

        return epoch_loss.avg, score.acc

In [None]:
# train the model on the training folds and validate on the held-out fold
trainer = ClassificationTrainer(model, Config)
trainer.fit(train_dl, val_dl)

# Inference

This is to inspect whether the model learned something reasonable. <br>
Typically the type of errors a model makes gives hints on what should be improved

In [None]:
# placeholder: replace the relative path with the actual checkpoint file
# NOTE: `Config.log_path` contains the time stamp of the current session, so a
# model trained in an earlier session is located in a different log folder
model_path = Config.log_path / 'path/to/trained/model/file'

model = get_model(Config, checkpoint_path=model_path)
# switch to evaluation mode for inference
model.eval()

Iterators allow looping through the data step-by-step<br>
Every time `next` is called on the iterator it fetches new samples from the data set <br>
So simply re-run the inference cell to see some results <br>
Every time the iterator definition cell is executed, the iterator will start from the beginning again

In [None]:
# iterator over the validation batches; re-running this cell starts from the beginning
val_it = iter(val_dl)

In [None]:
# fetch the next validation batch and move the images to the computing device
images, gt_labels, image_ids = next(val_it)
images = images.to(Config.device)

# inference only, so no gradients have to be tracked
with torch.no_grad():
    logits = model(images)
    
# back to the CPU for plotting
images = images.cpu()
# the predicted class is the index of the largest logit along dimension 1
_, preds = torch.max(logits, 1)

pred_labels = preds.cpu()

plot_image_batch(images, gt_labels, image_ids, pred_labels=pred_labels)

### Validation Score

Loop through the whole dataset to compute the validation scores <br>
Here the 'confusion matrix' is used. See [Confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix) <br>
It can be used to calculate the accuracy but gives more insights on the type of errors which occurred.

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# TODO
# understand the confusion matrix
# understand its connection to true positives, false positives, true negatives and false negatives
# understand how to calculate 'precision' and 'recall' from it (or if you prefer 'sensitivity' and 'specificity')

In [None]:
# buffers for the labels of the whole validation set
all_gt_labels = torch.zeros(len(val_ds), dtype=torch.long, device='cpu')
all_pred_labels = torch.zeros(len(val_ds), dtype=torch.long, device='cpu')

# running write position: the last batch can be smaller than the others, so
# the position cannot be derived from batch index * current batch size
offset = 0
for images, gt_labels, image_ids in val_dl:
    n_samples = len(image_ids)

    images = images.to(Config.device)

    # inference only, so no gradients have to be tracked
    with torch.no_grad():
        logits = model(images)

    # the predicted class is the index of the largest logit along dimension 1
    _, preds = torch.max(logits, 1)

    pred_labels = preds.cpu()

    all_gt_labels[offset:offset + n_samples] = gt_labels
    all_pred_labels[offset:offset + n_samples] = pred_labels
    offset += n_samples


all_gt_labels = all_gt_labels.numpy()
all_pred_labels = all_pred_labels.numpy()
# explicit labels keep the matrix aligned with the plotted class names even
# if a class does not occur in the ground truth or in the predictions
conf_mat = confusion_matrix(all_gt_labels, all_pred_labels, labels=list(Config.label_dict.keys()))

_ = plot_confusion_matrix(conf_mat, list(Config.label_dict.values()))

For binary classification in particular, a few other metrics are useful. <br>
Many metrics can be found in the 'sklearn' library: see [scikit-learn Metrics](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

In [None]:
# TODO use the sklearn functions to calculate and print 'precision', 'recall' and 'f1-score'
# Hint: it works very similarly to the confusion matrix
# you can use this to check whether your understanding of the confusion matrix is correct