# Classification Training

- Classification training pipeline using plain PyTorch and a custom model
- As classifying all 4 classes turns out to be rather difficult, we focus only on classifying whether an image is negative or positive for pneumonia

In [None]:
# import sys
# print(f"Install in: {sys.executable}")
# !{sys.executable} -m pip install scikit-learn

In [None]:
# To debug external functions
%load_ext autoreload
%autoreload 2

In [None]:
from ast import literal_eval
import gc
import os
from pathlib import Path
import numpy as np 
import pandas as pd 
import cv2
from datetime import datetime
import time
import random
from tqdm import tqdm_notebook as tqdm # progress bar
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

# Albumentations
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# torch
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler

# importing from local package
import sys
sys.path.insert(0, '../')

from src.data.datasets import ClassificationDataset
from src.models.classification import ClassificationBaseline, ClassificationCustom
from src.train.ClassificationTrainer import ClassificationTrainer
from src.visualization.images import show_img
from src.visualization.metrics import plot_confusion_matrix

## Setup

In [None]:
def seed_everything(seed):
    """ Seed random number generators to make runs deterministic.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    sets ``PYTHONHASHSEED`` and configures cuDNN for deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # only affects Python processes started after this call; the hash seed of
    # the running interpreter is fixed at interpreter startup
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode auto-tunes convolution algorithms and may select different
    # (non-deterministic) kernels between runs, defeating the flag above
    torch.backends.cudnn.benchmark = False
    
    
class Config:
    """ Configuration for the training

    All settings are class attributes; the class itself is passed around as
    the configuration object.
    """

    # evaluated once, when the class body runs
    time_stamp: str = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    exp_name: str = f"classification_{time_stamp}"
    log_path: Path = Path("./logs") / exp_name
    fold_num: int = 0  # fold held out for validation
    seed: int = 2021
    label_dict: dict = {
        0: "negative",
        1: "positive",
    }
    # Alternative 4-class setup (harder to learn than the binary one):
    # label_dict = {
    #     0: "negative",
    #     1: "typical",
    #     2: "indeterminate",
    #     3: "atypical",
    # }
    # derived from label_dict so the two can never disagree
    num_classes: int = len(label_dict)
    img_size: tuple = (768, 768)  # (height, width) input image size; alternative: (768, 896)
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # computing device
    num_workers: int = 8  # number of worker processes used to prepare the batches
    batch_size: int = 16
    n_epochs: int = 100
    lr: float = 5e-4
    use_scheduler: bool = False

# seed all RNGs before any data loading or model creation happens
seed_everything(seed=Config.seed)

## Helper Function

In [None]:
def plot_image_batch(images, gt_labels, image_ids, pred_labels=None):
    """ plots image batch returned from dataloader and optionally predictions

    Expects data to be detached from gradients and moved to CPU.

    Args:
        images: batch of image tensors; each image is passed through
            ``np.squeeze`` and shown in grayscale (presumably shaped
            (N, 1, H, W) — confirm against ClassificationDataset).
        gt_labels: tensor of ground-truth class indices, one per image.
        image_ids: sequence of image identifiers, one per image.
        pred_labels: optional tensor of predicted class indices; when given,
            each title shows both ground truth and prediction.

    At most 8 images are plotted; larger batches are randomly subsampled.
    Class indices are mapped to names via ``Config.label_dict``.
    """
    # for visibility, plot only a few images for large batches
    n_plot_max = 8
    if len(images) > n_plot_max:
        sample_idx = torch.randperm(len(images))[:n_plot_max]
        images = images[sample_idx]
        gt_labels = gt_labels[sample_idx]
        image_ids = [image_ids[i.item()] for i in sample_idx]
        if pred_labels is not None:
            pred_labels = pred_labels[sample_idx]

    n_plot = len(images)
    if n_plot == 0:
        return  # nothing to show; plt.subplots cannot create zero rows

    n_cols = 2
    # round up so an odd-sized batch (e.g. a short last batch) still fits
    n_rows = (n_plot + n_cols - 1) // n_cols
    # squeeze=False keeps `ax` 2-D even when there is only a single row
    _, ax = plt.subplots(
        figsize=(n_cols * 7, n_rows * 7), nrows=n_rows, ncols=n_cols, squeeze=False
    )

    for n in range(n_rows * n_cols):
        row, col = divmod(n, n_cols)
        ax[row][col].axis('off')
        if n >= n_plot:
            continue  # unused trailing slot of an odd-sized batch

        img_id = image_ids[n]
        gt_label = gt_labels[n].item()
        img = np.squeeze(images[n].numpy())
        ax[row][col].imshow(img, cmap='gray')

        if pred_labels is None:
            title = f"{img_id}: {Config.label_dict[gt_label]}"
        else:
            pred_label = pred_labels[n].item()
            title = f"{img_id}: GT {Config.label_dict[gt_label]}, Pred {Config.label_dict[pred_label]}"

        ax[row][col].set_title(title)

    plt.tight_layout()

## Load Data

In [None]:
data_path = Path('../data/siim-covid19-detection-subset')
train_path = data_path / "train"

# annotation data frame; these columns hold Python literals serialized as
# text in the CSV, which literal_eval parses back into Python objects
ann_df = pd.read_csv(
    data_path / "train_annotations.csv",
    converters=dict.fromkeys(("boxes", "labels", "pixel_spacing"), literal_eval),
)

ann_df.head()

## Image Pre-Processing

In [None]:
# training-only augmentations: random horizontal flip plus mild
# brightness/contrast jitter, each with probability 0.5
aug_transform_list = [
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
]

# normalization transforms, always applied (train and val): resize to
# Config.img_size = (height, width), normalize with mean 0 / std 1 and
# convert to a torch tensor
norm_transform_list = [
    A.Resize(height=Config.img_size[0], width=Config.img_size[1], p=1.0),
    A.Normalize(mean=(0,), std=(1,), p=1.0),
    ToTensorV2(p=1.0),
]

# augmentations run first, followed by the shared normalization steps
train_transforms = A.Compose([*aug_transform_list, *norm_transform_list])
val_transforms = A.Compose(norm_transform_list)

## Dataset & DataLoader

In [None]:
# fold `Config.fold_num` is held out for validation, all other folds train
train_df = ann_df.loc[ann_df['fold'].ne(Config.fold_num)]
val_df = ann_df.loc[ann_df['fold'].eq(Config.fold_num)]

train_ds = ClassificationDataset(train_df, train_path, train_transforms)
val_ds = ClassificationDataset(val_df, train_path, val_transforms)

# only the training loader shuffles; the validation order stays fixed
train_dl = DataLoader(train_ds, batch_size=Config.batch_size,
                      shuffle=True, num_workers=Config.num_workers)
val_dl = DataLoader(val_ds, batch_size=Config.batch_size,
                    shuffle=False, num_workers=Config.num_workers)

### Dataset Sanity Check

In [None]:
# fetch one training sample (train transforms applied) and show it with its label
image, label, image_id = train_ds[2]

gt_label = label.item()
# drop singleton axes (presumably the channel axis) so it can be shown as 2-D
img_np = np.squeeze(image.numpy())

_ = show_img(img_np, figsize=(12, 12), title=f"{image_id}: {Config.label_dict[gt_label]}")

### Dataloader Sanity Check

In [None]:
# persistent iterator over the training loader: each run of the next cell
# calls next(train_it) and therefore shows a new batch
train_it = iter(train_dl)

In [None]:
images, labels, image_ids = next(train_it)

# visual check of one training batch (train transforms applied)
plot_image_batch(images, gt_labels=labels, image_ids=image_ids)

# Model

In [None]:
def get_model(cfg, checkpoint_path=None):
    """ Build the classification model and optionally restore trained weights.

    Args:
        cfg: configuration passed to the model constructor; its ``device``
            attribute (if present) selects where the returned model lives.
        checkpoint_path: optional path to a checkpoint file that contains a
            ``'model_state_dict'`` entry.

    Returns:
        The model on ``cfg.device`` (or on CUDA if available / CPU otherwise
        when ``cfg`` has no ``device``).
    """
    model = ClassificationBaseline(cfg)
    # model = ClassificationCustom(cfg)

    device = getattr(cfg, 'device', None)
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the trained weights
    if checkpoint_path is not None:
        # map to CPU first: GPU-saved checkpoints then also load on CPU-only
        # machines, and no temporary copy of the weights occupies GPU memory
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

        del checkpoint
        gc.collect()

    # honour the configured device instead of hard-coding .cuda()
    return model.to(device)

In [None]:
# build the model without loading a checkpoint and place it on the configured device
model = get_model(cfg=Config)
model.to(device=Config.device)

# Training

In [None]:
# train the model; ClassificationTrainer receives the whole Config class
# (presumably for lr, n_epochs, use_scheduler, log_path, device — confirm in
# src/train/ClassificationTrainer) and presumably validates on val_dl
trainer = ClassificationTrainer(model, Config)
trainer.fit(train_dl, val_dl)

# Inference

In [None]:
# checkpoint of this run (presumably written by ClassificationTrainer);
# to evaluate an earlier run, point model_path at its checkpoint file instead
model_path = Config.log_path.joinpath('last-checkpoint.pt')

model = get_model(cfg=Config, checkpoint_path=model_path)
model.eval()

In [None]:
# persistent iterator over the validation loader: each run of the next cell
# calls next(val_it) and therefore evaluates a new batch
val_it = iter(val_dl)

In [None]:
images, gt_labels, image_ids = next(val_it)

# forward pass on the compute device without building an autograd graph
images = images.to(Config.device)
with torch.no_grad():
    logits = model(images)

# predicted class = index of the highest logit; move results back to the CPU for plotting
_, preds = logits.max(dim=1)
images, pred_labels = images.cpu(), preds.cpu()

plot_image_batch(images, gt_labels, image_ids, pred_labels=pred_labels)

### Validation Score

In [None]:
# collect ground-truth and predicted labels for the whole validation set
all_gt_labels = torch.zeros(len(val_ds), dtype=torch.long, device='cpu')
all_pred_labels = torch.zeros(len(val_ds), dtype=torch.long, device='cpu')

# running write position: the last batch may be smaller than batch_size, so
# the start index cannot be computed as batch_index * current_batch_size
offset = 0
with torch.no_grad():
    for images, gt_labels, image_ids in val_dl:
        n_samples = len(gt_labels)

        logits = model(images.to(Config.device))
        _, preds = torch.max(logits, 1)
        pred_labels = preds.cpu()

        all_gt_labels[offset:offset + n_samples] = gt_labels
        all_pred_labels[offset:offset + n_samples] = pred_labels
        offset += n_samples

if offset != len(val_ds):
    raise RuntimeError(f"collected {offset} predictions for {len(val_ds)} validation samples")

all_gt_labels = all_gt_labels.numpy()
all_pred_labels = all_pred_labels.numpy()
# pass every class index so the matrix always lines up with the class names,
# even if some class never occurs in the ground truth or the predictions
conf_mat = confusion_matrix(all_gt_labels, all_pred_labels, labels=list(Config.label_dict))

_ = plot_confusion_matrix(conf_mat, list(Config.label_dict.values()))

In [None]:
# for binary classification in particular a few other metrics are useful
from sklearn.metrics import precision_score, recall_score, f1_score

# average='binary' scores the positive class (label 1) and only applies to
# 2 classes; with more classes (e.g. the 4-class label_dict kept as an
# alternative in Config) use the unweighted mean over classes instead
average = 'binary' if len(Config.label_dict) == 2 else 'macro'

# zero_division=0: report 0 instead of warning when a class is never predicted
precision = precision_score(all_gt_labels, all_pred_labels, average=average, zero_division=0) * 100
recall = recall_score(all_gt_labels, all_pred_labels, average=average, zero_division=0) * 100
f1 = f1_score(all_gt_labels, all_pred_labels, average=average, zero_division=0) * 100

print(f"Precision: {precision:.2f}%")
print(f"Recall: {recall:.2f}%")
print(f"F1 Score: {f1:.2f}%")