## Swin Transformer Faster R-CNN: training and evaluation code
This code is based on https://github.com/xiaohu2015/SwinT_detectron2<br>

I split the data by video: videos 0 and 1 are used for training and video 2 for validation.<br>
So far, I have only been able to train the model for 5 epochs, which achieved:<br>
**bbox/AP = 15.2**<br>
**bbox/AP50 = 35.7**<br>
**bbox/AP75 = 8.0**<br>
on the validation set.

Even so, I'd like to share my work. I will keep working on it to improve the results further.

## Install requirements
Install Detectron2, timm, and the Detectron2 implementation of Swin Transformer.

In [None]:
"""cpu"""
# !python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.9/index.html
"""gpu"""    
!pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
!pip install detectron2==0.5 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html

!pip install timm
!git clone https://github.com/emiz6413/SwinT_detectron2.git swin
!curl -OL https://github.com/xiaohu2015/SwinT_detectron2/releases/download/v1.1/faster_rcnn_swint_T.pth  # pretrained
!curl -OL https://github.com/emiz6413/SwinT_detectron2/releases/download/v1.3/model_0021209.pth  # trained

In [None]:
import os
from ast import literal_eval
from collections import OrderedDict
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image

from detectron2.data import (DatasetCatalog,
                             MetadataCatalog,
                             build_detection_test_loader
                            )
from detectron2.data import transforms as T
from detectron2.data.datasets.coco import convert_to_coco_json
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.structures import BoxMode
from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger
from detectron2.engine import DefaultTrainer, default_setup, hooks
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, inference_on_dataset

from swin.swint import add_swint_config

# Set up detectron2's logging. No output file is passed, so nothing is written to
# disk. The returned `logger` is also used by Trainer.test_with_TTA.
logger = setup_logger()

## 1. Load data

In [None]:
data_root = Path('/kaggle/input/tensorflow-great-barrier-reef/')
# The `annotations` column holds Python-literal strings; parse them into lists of
# box dicts while reading, instead of in a second pass over the column.
train_df = pd.read_csv(
    data_root / 'train.csv',
    converters={'annotations': literal_eval},
)

## 2.1 Define a dataset function

In [None]:
def gbl_dataset(df, img_root):
    """Build detectron2 dataset dicts for every row of ``df``.

    Args:
        df: DataFrame with an ``image_id`` column of the form ``"<video>-<frame>"``
            and an ``annotations`` column holding (already parsed) lists of dicts
            with ``x``, ``y``, ``width`` and ``height`` keys.
        img_root: ``Path`` under which images live as
            ``video_<video>/<frame>.jpg``.

    Returns:
        A list with one dict per row: ``file_name``, ``width``, ``height``,
        ``image_id`` (the DataFrame index label) and ``annotations``. Each
        annotation is an absolute XYWH box with ``category_id`` 0 (the only class).
        Rows without boxes are kept, with an empty ``annotations`` list.
    """
    dataset = []
    for i, row in df.iterrows():
        file_name = str(img_root/"video_{}/{}.jpg".format(*row['image_id'].split('-')))
        # Image.open only parses the header, so reading .size is cheap. The `with`
        # block closes the file deterministically instead of leaving it to the
        # Image object's finalizer.
        with Image.open(file_name) as img:
            width, height = img.size
        annotations = [dict(bbox=[bbox['x'], bbox['y'], bbox['width'], bbox['height']],
                            bbox_mode=BoxMode.XYWH_ABS,
                            category_id=0)
                       for bbox in row['annotations']]

        dataset.append(
            dict(file_name=file_name,
                 width=width,
                 height=height,
                 image_id=i,
                 annotations=annotations)
        )
    return dataset

def gbl_dataset_wrapper(df, img_root):
    """Return a zero-argument loader suitable for ``DatasetCatalog.register``.

    ``df`` and ``img_root`` are bound now. The call to ``gbl_dataset`` (which opens
    every image to read its size) runs only when the returned callable is invoked.
    """
    return lambda: gbl_dataset(df, img_root)

In [None]:
# Split by video so that validation frames never come from a training video:
# videos 0 and 1 train, video 2 validates.
_img_root = data_root / 'train_images'
_train_df = train_df.query("video_id != 2")
_val_df = train_df.query("video_id == 2")
train_ds = gbl_dataset_wrapper(_train_df, _img_root)
val_ds = gbl_dataset_wrapper(_val_df, _img_root)

for _ds_name, _ds_loader in (("gbl_train_dataset", train_ds),
                             ("gbl_val_dataset", val_ds)):
    # Drop any earlier registration so the cell can be re-run safely.
    DatasetCatalog.pop(_ds_name, None)
    DatasetCatalog.register(_ds_name, _ds_loader)
    MetadataCatalog.get(_ds_name).thing_classes = ["starfish"]

# Convert the validation split to COCO format and dump it for the COCO evaluator.
# allow_cached=False forces it to be regenerated.
convert_to_coco_json('gbl_val_dataset', output_file='./output/inference/gbl_val_dataset_coco_format.json', allow_cached=False)

## 2.2 Check the Dataset
Visualize a training sample to verify that the annotations are loaded correctly.

In [None]:
# Metadata (class names) used when drawing boxes below.
metadata = MetadataCatalog.get('gbl_train_dataset')
# Materialize the registered training records. This calls the loader registered
# above, which opens every training image to read its size.
gbl_ds = DatasetCatalog.get("gbl_train_dataset")

In [None]:
# Stop at the first training record that actually contains a box, so the overlay
# shows something.
for data in gbl_ds:
    if data['annotations']:
        break

# cv2 loads BGR; the Visualizer expects RGB.
im = cv2.cvtColor(cv2.imread(data['file_name']), cv2.COLOR_BGR2RGB)
drawn = Visualizer(im, metadata=metadata, scale=0.5).draw_dataset_dict(data)
im = Image.fromarray(drawn.get_image())
im

## 3.1 Define a custom Trainer that evaluates on the custom dataset

In [None]:
class Trainer(DefaultTrainer):
    """DefaultTrainer that scores the datasets in ``cfg.DATASETS.TEST`` with COCO bbox AP.

    ``DefaultTrainer.test`` calls ``build_evaluator`` for each test dataset when no
    evaluators are passed explicitly.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a bbox-only ``COCOEvaluator`` for ``dataset_name``.

        Args:
            cfg: detectron2 config; ``cfg.OUTPUT_DIR`` provides the default folder.
            dataset_name: a name registered in ``DatasetCatalog``.
            output_folder: where predictions and metrics are written; defaults to
                ``<cfg.OUTPUT_DIR>/inference``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name=dataset_name,
                             tasks=["bbox"],  # Faster R-CNN predicts boxes only
                             distributed=True,
                             output_dir=output_folder)

    @classmethod
    def build_tta_model(cls, cfg, model):
        """Wrap ``model`` for test-time augmentation (only some R-CNN models support it)."""
        return GeneralizedRCNNWithTTA(cfg, model)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """Evaluate ``model`` with test-time augmentation on ``cfg.DATASETS.TEST``.

        Returns:
            The result mapping from ``cls.test``, with ``"_TTA"`` appended to every
            key so it can sit next to the non-TTA results.
        """
        logger.info("Running inference with test-time augmentation ...")
        # A classmethod has no `self`; dispatch through `cls` so subclasses can
        # override build_tta_model.
        model = cls.build_tta_model(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res

## 3.2 Set a config

In [None]:
# detectron2 counts SOLVER.MAX_ITER / SOLVER.CHECKPOINT_PERIOD / TEST.EVAL_PERIOD in
# iterations, and each iteration consumes SOLVER.IMS_PER_BATCH images.
# TRAIN_STEPS is the number of training images that have at least one box (images
# without boxes are presumably dropped by the default DATALOADER.FILTER_EMPTY_ANNOTATIONS).
# NOTE(review): with IMS_PER_BATCH = 4, TRAIN_STEPS iterations show each annotated
# image about 4 times, so one "TRAIN_STEPS period" is not one epoch -- confirm the
# intended schedule before changing these values.
TRAIN_STEPS = 4242
cfg = get_cfg()
# Presumably registers the SwinT-specific keys the YAML below refers to.
add_swint_config(cfg)
cfg.merge_from_file('swin/configs/SwinT/faster_rcnn_swint_T_FPN_3x_.yaml')
cfg.DATASETS.TRAIN = ("gbl_train_dataset",)
cfg.DATASETS.TEST = ("gbl_val_dataset",)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # the single class "starfish"
# MODEL.WEIGHTS is a path string in detectron2's config schema; "" means "no
# checkpoint", so resume_or_load below starts from randomly initialized weights.
# Pretrained weights are copied in separately before training (section 3.3).
cfg.MODEL.WEIGHTS = ""
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.MAX_ITER = TRAIN_STEPS * 10
cfg.SOLVER.STEPS = []  # no LR-decay milestones: the learning rate is never decayed
cfg.SOLVER.CHECKPOINT_PERIOD = TRAIN_STEPS
cfg.TEST.EVAL_PERIOD = TRAIN_STEPS  # periodic evaluation on DATASETS.TEST
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg)
trainer.resume_or_load(resume=False)

## 3.3 Train

I could only train for 5 epochs because of the runtime quota.

In [None]:
"""
# load pretrained weights
other_weights = torch.load('faster_rcnn_swint_T.pth')['model']
self_weight = trainer.model.state_dict()
for name, param in self_weight.items():
    if name in other_weights:
        if other_weights[name].shape == param.shape:
            self_weight[name] = other_weights[name]
        else:
            print(f"size mismatch at {name}")
    else:
        print(f"layer {name} does not exist")
trainer.model.load_state_dict(self_weight)
trainer.train()
"""

## 4.1 Evaluate
Evaluate the trained checkpoint on the validation set.

In [None]:
# Load the fine-tuned checkpoint downloaded in the install cell and score it on
# cfg.DATASETS.TEST with the COCO evaluator from Trainer.build_evaluator.
# map_location="cpu" makes the load work on any device. load_state_dict then copies
# the tensors onto wherever the model lives, without a second copy on the GPU.
trainer.model.load_state_dict(
    torch.load('model_0021209.pth', map_location="cpu")['model']
)
trainer.test(cfg, trainer.model)

## 4.2 Visualize a few prediction examples

In [None]:
# In eval mode the R-CNN returns detections rather than training losses.
trainer.model.eval()
# Class names for the overlay; train and val register the same thing_classes.
metadata = MetadataCatalog.get('gbl_train_dataset')
# Validation records (list of dicts). Note this rebinds `val_ds` from the
# registered loader to its output.
val_ds = DatasetCatalog.get("gbl_val_dataset")

In [None]:
# Apply the same test-time resize as the evaluation loader (this mirrors detectron2's
# DefaultPredictor). Without it, the model would see images at their native size,
# unlike the images scored in section 4.1.
test_resize = T.ResizeShortestEdge(
    [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)

for _ in range(5):
    idx = np.random.randint(0, len(val_ds))
    data = val_ds[idx]
    print(data['file_name'])
    im = cv2.imread(data['file_name'])  # BGR, (h, w, c)
    h, w = im.shape[:2]
    # The model expects channels in cfg.INPUT.FORMAT order; cv2 loads BGR.
    model_im = im[:, :, ::-1] if cfg.INPUT.FORMAT == "RGB" else im
    model_im = test_resize.get_transform(model_im).apply_image(model_im)
    im_tensor = torch.as_tensor(model_im.astype("float32").transpose(2, 0, 1))  # h, w, c -> c, h, w
    with torch.no_grad():
        # No .cuda(): GeneralizedRCNN moves inputs to its own device, so this also
        # works on CPU. "height"/"width" make it rescale boxes to the original image.
        pred = trainer.model([{"image": im_tensor, "width": w, "height": h}])
    v = Visualizer(im[:, :, ::-1],  # the Visualizer expects RGB
                   metadata=metadata,
                   scale=0.5,
                   # IMAGE_BW greys out areas not covered by predicted masks. This
                   # detector predicts boxes only -- check the overlay looks as intended.
                   instance_mode=ColorMode.IMAGE_BW
    )
    out = v.draw_instance_predictions(pred[0]["instances"].to("cpu"))
    plt.figure()
    plt.imshow(out.get_image())