# Object Detection Training

In [None]:
# import sys
# print(f"Install in: {sys.executable}")
# !{sys.executable} -m pip install scikit-learn

In [None]:
# To debug external functions
%load_ext autoreload
%autoreload 2

In [None]:
from ast import literal_eval
import gc
import os
from pathlib import Path
import numpy as np 
import pandas as pd 
import cv2
from datetime import datetime
import time
import random
from tqdm import tqdm_notebook as tqdm # progress bar
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

# Albumentations
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# torch
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler

In [None]:
# importing from src.local package
import sys
sys.path.insert(0, '../')

from src.data.datasets import DetectionDataset
from src.models.detection import DetectionBaseline
from src.train.utils import boxes_xyxy_rel_to_abs, boxes_xyxy_abs_to_rel
from src.train.DetectionTrainer import DetectionTrainer
from src.train.metrics import calculate_mAP
from src.visualization.images import show_img, show_img_with_boxes
from src.visualization.metrics import plot_confusion_matrix
from src.visualization.utils import cxcywh_to_xyxy

## Setup

In [None]:
def seed_everything(seed):
    """ Seed all random number generators to make runs reproducible

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic behaviour.

    Args:
        seed: integer seed shared by all generators
    """

    random.seed(seed)
    # only picked up by Python interpreters started after this point
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seeds all GPUs, not only the current one; silently ignored without CUDA
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode auto-tunes the convolution algorithm at runtime, which can
    # pick different kernels between runs -> must be off for reproducibility
    torch.backends.cudnn.benchmark = False
    
    
class Config:
    """ Configuration for the training

    All settings are class attributes; in this notebook the class object
    itself is handed to the model, the trainer and the metric code.
    ``time_stamp`` (and therefore ``exp_name`` and ``log_path``) is evaluated
    once, when the class body is executed.
    """

    # --- experiment bookkeeping ---
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    exp_name = f"ssd_{time_stamp}"
    log_path = Path("./logs") / "detection" / exp_name
    fold_num: int = 0  # rows with this 'fold' value form the validation split
    seed: int = 2021

    # --- model ---
    num_classes: int = 2
    freeze_backbone = False  # don't train backbone weights if true
    aspect_ratios: list = [1., 2., 0.5]  # anchor box aspect ratios per cell
    img_size: tuple = (768,) * 2  # (height, width); implementation expects height = width

    # --- training ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # computing device
    num_workers: int = 8  # number of processors used to prepare the batches
    batch_size: int = 16
    n_epochs: int = 200
    lr: float = 1e-4
    use_scheduler = False

    # class id -> human readable name
    label_dict = {
        0: "negative",
        1: "positive"
    }
    # Alternative four-class labelling (not active):
    #     label_dict = {
    #         0 : "negative",
    #         1 : "typical",
    #         2 : "indeterminate",
    #         3 : "atypical"
    #     }

# Seed every RNG with the configured seed; in notebook order this runs before
# any dataset, dataloader or model is created
seed_everything(Config.seed)

## Helper Function

In [None]:
def plot_image_batch(images, gt_targets, image_ids, pred_targets=None):
    """ Plot an image batch returned from the dataloader, optionally with predictions

    Expects data to be detached from gradients and moved to CPU.
    For visibility at most 8 images are drawn; larger batches are randomly
    sub-sampled. Images are laid out in a two-column grid.

    Args:
        images: sequence (or batched tensor) of image tensors
        gt_targets: per-image dicts holding 'boxes' and 'labels' tensors
        image_ids: per-image identifiers, used as subplot titles
        pred_targets: optional per-image dicts holding 'boxes' and 'scores'
            tensors; when given, predictions are drawn with their score
    """

    n_plot_max = 8
    n_cols = 2

    n_plot = len(images)
    if n_plot == 0:
        # nothing to draw; a grid with zero rows cannot be created
        return

    # for visibility, plot only a few images for large batches
    if n_plot > n_plot_max:
        n_plot = n_plot_max
        sample_idx = torch.randperm(len(images))[:n_plot].tolist()
        images = [images[i] for i in sample_idx]
        gt_targets = [gt_targets[i] for i in sample_idx]
        image_ids = [image_ids[i] for i in sample_idx]
        if pred_targets is not None:
            pred_targets = [pred_targets[i] for i in sample_idx]

    # round up so that an odd number of images still gets a row of its own
    n_rows = (n_plot + n_cols - 1) // n_cols
    figsize = (n_cols * 7, n_rows * 7)

    # squeeze=False always yields a 2-D axes array, also for a single row
    fig, axes = plt.subplots(figsize=figsize, nrows=n_rows, ncols=n_cols, squeeze=False)
    axes_flat = axes.ravel()

    for n in range(n_plot):
        img = np.squeeze(images[n].numpy())

        gt_target = gt_targets[n]
        boxes = gt_target['boxes'].numpy().astype(np.int32)
        labels = gt_target['labels'].numpy().astype(np.int32)
        # only annotations with a label > 0 are drawn
        gt_anns = [(l, b) for l, b in zip(labels, boxes) if l > 0]

        pred_anns = None
        if pred_targets is not None:
            pred_target = pred_targets[n]
            boxes = pred_target['boxes'].numpy().astype(np.int32)
            scores = pred_target['scores'].numpy()
            # predictions are annotated with their score instead of a label
            pred_anns = [(f"{s:.2f}", b) for s, b in zip(scores, boxes) if s > 0]

        show_img_with_boxes(img, gt_anno=gt_anns, pred_anno=pred_anns, ax=axes_flat[n],
                            title=image_ids[n], box_format='xyxy')

    # blank out the empty cell left over by an odd number of images
    for unused_ax in axes_flat[n_plot:]:
        unused_ax.axis('off')

    plt.tight_layout()

## Load Data

In [None]:
# Root folder of the data set and the folder holding the training images
data_path = Path('../data/siim-covid19-detection-subset')
train_path = data_path.joinpath("train")

# Annotation data frame. The listed columns are parsed with literal_eval
# while the CSV is read.
ann_df = pd.read_csv(
    data_path.joinpath("train_annotations.csv"),
    converters=dict.fromkeys(("boxes", "labels", "pixel_spacing"), literal_eval),
)

ann_df.head()

In [None]:
# TODO: handle cases like this one, where the image is positive but has no
# bounding box showing the opacity
ann_df[ann_df["id"] == '0bd6cd815ba9_image']

## Image Pre-Processing

In [None]:
# Bounding box format handed to albumentations; class ids travel in 'labels'
bbox_params = {'format': 'pascal_voc', 'label_fields': ['labels']}

# Augmentations, applied to training images only
aug_transform_list = [
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(
        brightness_limit=0.2,
        contrast_limit=0.2,
        p=0.5,
    ),
    A.CropAndPad(
        percent=(0.0, 0.02),
        pad_mode=cv2.BORDER_CONSTANT,
        pad_cval=0,
        keep_size=True,
        sample_independently=True,
        p=0.5,
    ),
]

# Resize to the configured size, normalize and convert to a tensor;
# shared by the training and the validation pipeline
norm_transform_list = [
    A.Resize(height=Config.img_size[0], width=Config.img_size[1], p=1.0),
    A.Normalize(mean=(0,), std=(1,), p=1.0),
    ToTensorV2(p=1.0),
]

# Training: augmentations first, then normalization. Validation: normalization only.
train_transforms = A.Compose([*aug_transform_list, *norm_transform_list], bbox_params=bbox_params)
val_transforms = A.Compose(norm_transform_list, bbox_params=bbox_params)

## Dataset & DataLoader

In [None]:
def collate_fn(batch):
    """ Regroup a list of samples into one tuple per sample field (nothing is stacked) """
    per_field = zip(*batch)
    return tuple(per_field)


# The configured fold is held out for validation, all other folds are used for training
is_val_fold = ann_df['fold'] == Config.fold_num
train_df = ann_df[~is_val_fold]
val_df = ann_df[is_val_fold]

# Both splits read their images from the same folder; only the transforms differ
train_ds = DetectionDataset(train_df, train_path, train_transforms)
val_ds = DetectionDataset(val_df, train_path, val_transforms)

# Settings shared by both dataloaders
loader_kwargs = {
    'batch_size': Config.batch_size,
    'num_workers': Config.num_workers,
    'collate_fn': collate_fn,
}

# Only the training data is shuffled
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

### Dataset Sanity Check

In [None]:
# Pull one sample straight from the training dataset and draw its ground truth
image, target, image_id = train_ds[307]

img_np = np.squeeze(image.numpy())
boxes = target['boxes'].numpy().astype(np.int32)
labels = target['labels'].numpy().astype(np.int32)
# only annotations with a label > 0 are drawn
gt_anns = [(label, box) for label, box in zip(labels, boxes) if label > 0]

_ = show_img_with_boxes(img_np, gt_anno=gt_anns, figsize=(12, 12), title=image_id, box_format='xyxy')

### Dataloader Sanity Check

In [None]:
# Iterator over the training dataloader, so single batches can be pulled with next()
train_it = iter(train_dl)

In [None]:
# Fetch the next training batch and plot it with its ground-truth boxes
images, targets, image_ids = next(train_it)
plot_image_batch(images, targets, image_ids)

## Model

In [None]:
def get_model(cfg, checkpoint_path=None):
    """ Build the detection model and optionally restore trained weights

    Args:
        cfg: configuration handed to DetectionBaseline; its ``device``
            attribute, if present, selects where the model is placed
        checkpoint_path: optional path to a checkpoint file that contains a
            'model_state_dict' entry

    Returns:
        the model, moved to the target device
    """
    # fall back to CUDA-if-available when the config carries no device
    device = getattr(cfg, 'device', None)
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = DetectionBaseline(cfg)

    # Load the trained weights
    if checkpoint_path is not None:
        # map_location allows a checkpoint saved on a GPU to be restored on a
        # machine without CUDA
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # the checkpoint is no longer needed once the weights are copied
        del checkpoint
        gc.collect()

    return model.to(device)

In [None]:
# Build a model without loading a checkpoint and make sure it is on Config.device
model = get_model(Config)
model.to(Config.device)

### Anchors

In [None]:
# Background image on which a few anchor boxes are drawn
image, target, image_id = train_ds[0]
img_np = np.squeeze(image.numpy())

# Convert the model's anchors from cxcywh to xyxy, clip them to [0, 1] and
# scale them to the image size
anchors = model.anchors
anchors_np = cxcywh_to_xyxy(anchors.cpu().numpy())
anchors_np = boxes_xyxy_rel_to_abs(np.clip(anchors_np, 0.0, 1.0), img_np.shape)

print("Number of anchors:", anchors.shape[0])
print("Number of cells in last feature layer:", (Config.img_size[0] // 32)**2)

# Plot only a handful of anchors, evenly spaced over the anchor list
# (all of them would be too messy)
n_anchor_plot = 8
anchor_plot_idx = np.round(np.linspace(0, len(anchors_np) - 1, n_anchor_plot)).astype(int)
anchor_anns = list(enumerate(anchors_np[anchor_plot_idx]))

_ = show_img_with_boxes(img_np, gt_anno=anchor_anns, figsize=(12, 12), title=image_id, box_format='xyxy')

## Training

In [None]:
# Train the model on the training dataloader; the validation dataloader is
# passed to fit() as well
trainer = DetectionTrainer(model, Config)
trainer.fit(train_dl, val_dl)

## Inference

In [None]:
# Checkpoint inside this run's log folder
# (presumably written by DetectionTrainer - verify the file name against the trainer)
model_path = Config.log_path / 'last-checkpoint.pt'
# Alternative: checkpoint of an earlier run
#model_path = Path("./logs/") / "detection" / "ssd_2021-07-09_21-22-36" / 'best-checkpoint-092epoch.pt'

# Rebuild the model with the stored weights and switch it to evaluation mode
model = get_model(Config, checkpoint_path=model_path)
model.eval()

In [None]:
# Iterator over the validation dataloader, so single batches can be pulled with next()
val_it = iter(val_dl)

pred_threshold = 0.5 # passed as min_score to model.detect_objects
nms_threshold = 0.2 # max overlap allowed for predictions

In [None]:
# Take the next validation batch and move it to the computing device
images, gt_targets, image_ids = next(val_it)
images = torch.stack(images).to(Config.device)

with torch.no_grad():
    pred_locs, pred_scores = model(images)
    det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
        pred_locs,
        pred_scores,
        min_score=pred_threshold,
        max_overlap=nms_threshold,
        top_k=100,
    )
    # clip the detected boxes to [0, 1] and scale them to each image's size
    det_boxes_batch = [
        boxes_xyxy_rel_to_abs(torch.clip(rel_boxes, 0, 1), img.shape[1:])
        for rel_boxes, img in zip(det_boxes_batch, images)
    ]

# plotting expects CPU data
images = images.cpu()

# one prediction dict per image, in the layout plot_image_batch reads
pred_targets = [
    {
        'boxes': img_boxes.cpu(),
        'labels': img_labels.cpu(),
        'scores': img_scores.cpu(),
    }
    for img_boxes, img_labels, img_scores in zip(det_boxes_batch, det_labels_batch, det_scores_batch)
]

plot_image_batch(images, gt_targets, image_ids, pred_targets=pred_targets)

### Validation Score

In [None]:
# Detections and ground truth, collected over the whole validation set
det_boxes, det_labels, det_scores = [], [], []
true_boxes, true_labels = [], []

# make sure the model is in evaluation mode even if this cell is run on its own
model.eval()

for images, targets, image_ids in val_dl:
    images = torch.stack(images).to(Config.device)

    # ground-truth boxes are converted from absolute to relative coordinates
    boxes = [
        boxes_xyxy_abs_to_rel(t['boxes'].to(torch.float).to(Config.device), img.shape[1:])
        for t, img in zip(targets, images)
    ]
    labels = [t['labels'].to(Config.device) for t in targets]

    # forward pass and box decoding both run without gradient tracking,
    # as in the single-batch inference cell
    with torch.no_grad():
        pred_locs, pred_scores = model(images)
        det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
            pred_locs,
            pred_scores,
            min_score=0.01,
            max_overlap=0.45,
            top_k=200,
        )

    det_boxes.extend(det_boxes_batch)
    det_labels.extend(det_labels_batch)
    det_scores.extend(det_scores_batch)
    true_boxes.extend(boxes)
    true_labels.extend(labels)

APs, mAP = calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, Config)

print("Average precision per class:")
for k, v in APs.items():
    print(f"{k} \t {v:.5f}")
print(f"Mean average precision: {mAP:.5f}")