# Detectron2: train Mask R-CNN on COCO-format data

In [None]:
# Show the GPU(s) visible to this runtime before installing anything.
!nvidia-smi

In [None]:
# Run-mode switches.
#   DEBUG  - short run (4 epochs) and silenced UserWarnings.
#   KAGGLE - read data/models from ../input.
#   COLAB  - mount Google Drive and pip-install dependencies.
DEBUG, KAGGLE, COLAB = False, False, True

## Install

In [None]:
import torch
TORCH_VER = '.'.join(torch.__version__.split('.')[:2])
CUDA_VER = torch.__version__.split('+')[-1]
print('torch:', TORCH_VER, '| cuda:', CUDA_VER)
if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    !pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VER}/torch{TORCH_VER}/index.html
    !pip install wandb
else:
    print(
        'install with command:\n'
        'pip install detectron2 -f'
        f'https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VER}/torch{TORCH_VER}/index.html'
    )

In [None]:
import os
import cv2
import json
import time
import random
import wandb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pycocotools.mask as mask_util
import detectron2
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer, hooks, BestCheckpointer
from detectron2.config import get_cfg, CfgNode
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog, build_detection_test_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.logger import setup_logger
from detectron2.evaluation.evaluator import DatasetEvaluator
import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader
from detectron2.evaluation import inference_on_dataset
from detectron2.checkpoint import DetectionCheckpointer
import warnings
# Debug runs hide UserWarnings to keep the output short.
if DEBUG:
    warnings.filterwarnings('ignore', category=UserWarning)
# Expose only GPU 0 to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
_cuda_ok = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _cuda_ok else 'cpu')
print('GPU is available' if _cuda_ok else 'CPU is used')
# Initialise detectron2's logger.
setup_logger()

## Config

In [None]:
VER = 'ver0'
# Paths: Google Drive on Colab, the working directory elsewhere; Kaggle reads from ../input.
WORK_DIR = '/content/drive/MyDrive/sartorius' if COLAB else '.'
DATA_PATH = '../input/sartorius-cell-instance-segmentation' if KAGGLE else f'{WORK_DIR}/data'
MDLS_PATH = f'../input/sartorius-models-{VER}' if KAGGLE else f'{WORK_DIR}/models_{VER}'
# Experiment settings; a copy is written to MDLS_PATH/config.json below.
CONFIG = {
    'ver': VER,
    'fold': 4,  # fold held out for validation (selects the *_f{folds}_f{fold} annotation files)
    'folds': 5,
    'batch_size': 4,
    'workers': 4 if COLAB else 8,  # dataloader workers
    #'chk_point': 'COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml',
    'chk_point': 'COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml',
    #'chk_point': 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml',
    #'start_weights': None,
    #'start_weights': '/content/drive/MyDrive/sartorius/models_vclbdtrn7lc/model_best.pth',
    # Initial weights; when falsy the model-zoo checkpoint of 'chk_point' is used.
    'start_weights': '/content/drive/MyDrive/sartorius/models_vclbdtrn7lc/model_0006503.pth',
    'chk_point_period': False,  # True: save a checkpoint every epoch
    'eval': True,  # True: evaluate on the validation set every epoch
    'lr': 1e-3, # 1e-3 = default
    'epochs': 4 if DEBUG else 40, # 40000 = default
    'warmup_factor': 1e-6,
    'warm_up_ep': 5,  # warm-up length in epochs
    'gamma': .9,  # LR multiplier applied at each decay step
    'gamma_step': .1,  # fraction of MAX_ITER between decay steps
    'score_th': .5,  # test-time ROI score threshold
    'batch_size_per_img': 128, # number of regions per image used to train RPN
    'lc': False,  # True: register the LIVECell annotations instead of the competition folds
    'ext': False,  # optional suffix selecting alternative train/val annotation files
    'clr_aug': True,  # brightness/contrast augmentation
    'classes': 3,
    'freeze_at': None,  # None keeps the model-zoo FREEZE_AT
    'ups': 'ups1',  # optional suffix for the training annotation file only
    'seed': 2021
}
# makedirs also creates missing parent folders and is a no-op if MDLS_PATH exists.
os.makedirs(MDLS_PATH, exist_ok=True)
with open(f'{MDLS_PATH}/config.json', 'w') as file:
    json.dump(CONFIG, file)

def seed_all(seed=0):
    """Seed every RNG the notebook uses and return a dedicated NumPy RandomState.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and the
    current CUDA device), switches cuDNN to deterministic kernels, and exports
    ``PYTHONHASHSEED``.

    Args:
        seed: integer seed applied to every generator.

    Returns:
        A new ``np.random.RandomState`` seeded with ``seed``.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch CPU generator and the current CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Reproducible cuDNN kernels instead of autotuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # NOTE: only child interpreters see this; the running process' hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    return np.random.RandomState(seed)

# The wandb API key is read from a file one level above WORK_DIR.
with open(f'{WORK_DIR}/../.wandb', 'r') as file:
    # strip() so a trailing newline in the key file is not sent as part of the key.
    api_key = file.read().strip()
wandb.login(key=api_key)
wandb.init(
    project='sartorius',
    sync_tensorboard=True,
    name=f'Sartorius detectron2 {VER}',
    # Record the experiment settings with the run. Assigning `wandb.config = CONFIG`
    # only rebinds the module attribute to a plain dict and is not synced.
    config=CONFIG,
    settings=wandb.Settings(
        start_method="thread",
        console="auto"
    )
)

random_state = seed_all(CONFIG['seed'])
start_time = time.time()

## Data description

In [None]:
# Training annotations: one row per annotated instance (counted per image id below).
df = pd.read_csv(filepath_or_buffer=f'{DATA_PATH}/train.csv')

In [None]:
# Instances per image (with the image's cell type), then percentile stats per cell type.
df_instances = df.groupby(['id']).agg({
    'annotation': 'count',
    'cell_type': 'first',
})
instance_stats = df_instances.groupby('cell_type')[['annotation']].describe(
    percentiles=[.1, .25, .75, .8, .85, .9, .95, .99]
)
# Drop the outer 'annotation' column level and the columns we do not want to show.
df_instances = (
    instance_stats
    .astype(int)
    .droplevel(0, axis=1)
    .drop(['count', '50%', 'std'], axis=1)
)
display(df_instances)

In [None]:
# Mask area per instance: sum of every second token after the first in the
# annotation string (presumably the run lengths of an RLE mask).
df['n_pixels'] = df.annotation.apply(
    lambda rle: np.sum([int(run) for run in rle.split()[1::2]])
)
pixel_stats = df.groupby('cell_type')[['n_pixels']].describe(
    percentiles=[.01, .02, .05, .1, .9, .95, .98, .99]
)
# Flatten the column index and hide count / median / std.
df_pixels = (
    pixel_stats
    .astype(int)
    .droplevel(0, axis=1)
    .drop(['count', '50%', 'std'], axis=1)
)
display(df_pixels)

## Load data

In [None]:
cfg = get_cfg()
if not CONFIG['lc']:
    # Competition data: per-fold COCO json files, masks loaded as bitmasks.
    cfg.INPUT.MASK_FORMAT = 'bitmask'
    if CONFIG['ext']:
        file_tr = f'{DATA_PATH}/coco_annotations_train_f{CONFIG["folds"]}_f{CONFIG["fold"]}_{CONFIG["ext"]}.json'
        file_vl = f'{DATA_PATH}/coco_annotations_val_f{CONFIG["folds"]}_f{CONFIG["fold"]}_{CONFIG["ext"]}.json'
    else:
        # 'ups' only swaps the training file; validation always uses the plain fold file.
        file_tr = (
            f'{DATA_PATH}/coco_annotations_train_f{CONFIG["folds"]}_f{CONFIG["fold"]}_{CONFIG["ups"]}.json'
            if CONFIG['ups']
            else f'{DATA_PATH}/coco_annotations_train_f{CONFIG["folds"]}_f{CONFIG["fold"]}.json'
        )
        file_vl = f'{DATA_PATH}/coco_annotations_val_f{CONFIG["folds"]}_f{CONFIG["fold"]}.json'
    for ds_name, ann_file in (('sartorius_train', file_tr), ('sartorius_val', file_vl)):
        register_coco_instances(ds_name, {}, ann_file, DATA_PATH)
else:
    # LIVECell data: separate annotation files, images in their own folder.
    for ds_name, ann_file in (
        ('sartorius_train', f'{DATA_PATH}/data2/livecell_annotations_train.json'),
        ('sartorius_val', f'{DATA_PATH}/data2/livecell_annotations_val.json'),
    ):
        register_coco_instances(
            ds_name,
            {},
            ann_file,
            f'{DATA_PATH}/data2/LIVECell_dataset_2021/images'
        )
metadata = MetadataCatalog.get('sartorius_train')
ds_train = DatasetCatalog.get('sartorius_train')

## Demo sample data

In [None]:
# Sanity check: draw the ground-truth annotations of the first training record.
demo_dict = ds_train[0]
img = cv2.imread(demo_dict['file_name'])
# cv2 loads BGR; Visualizer takes RGB, hence the channel flip.
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata)
out = visualizer.draw_dataset_dict(demo_dict)
plt.figure(figsize=(12, 8))
# get_image() already returns RGB, which plt.imshow expects. Flipping it again
# would swap the red and blue channels of the overlay.
plt.imshow(out.get_image())
plt.axis('off')
plt.show()

## Evaluator 

In [None]:
from detectron2.structures import polygons_to_bitmask

def polygon_to_rle(polygon, shape=(520, 704)):
    """Rasterise a single polygon and COCO-RLE-encode the resulting mask.

    Args:
        polygon: coordinate sequence accepted by detectron2's ``polygons_to_bitmask``.
        shape: ``(height, width)`` of the output mask.

    Returns:
        The pycocotools RLE of the rasterised mask.
    """
    height, width = shape[0], shape[1]
    # NOTE(review): the +0.25 coordinate shift is undocumented here; presumably a
    # pixel-centre alignment tweak. Confirm before changing it.
    shifted = np.asarray(polygon) + 0.25
    bitmask = polygons_to_bitmask([shifted], height, width)
    return mask_util.encode(np.asfortranarray(bitmask))

def precision_at(threshold, iou):
    """Count matches between predictions and targets at one IoU threshold.

    Args:
        threshold: IoU a (prediction, target) pair must exceed to count as a match.
        iou: 2-D IoU matrix with predictions along axis 0 and targets along
            axis 1 (the layout of ``mask_util.iou(preds, targs, ...)``).

    Returns:
        ``(tp, fp, fn)``:
        - tp: predictions matching exactly one target.
        - fp: predictions matching no target (extra objects).
        - fn: targets matched by no prediction (missed objects).
    """
    matches = iou > threshold
    matches_per_pred = np.sum(matches, axis=1)
    matches_per_targ = np.sum(matches, axis=0)
    true_pos = matches_per_pred == 1   # correct objects
    false_pos = matches_per_pred == 0  # extra objects: unmatched predictions
    false_neg = matches_per_targ == 0  # missed objects: unmatched targets
    return np.sum(true_pos), np.sum(false_pos), np.sum(false_neg)

def score(pred, targ):
    """Mean precision of one image's predicted masks against its targets.

    Precision ``tp / (tp + fp + fn)`` is computed at the IoU thresholds from
    ``np.arange(.5, 1, .05)`` and averaged.

    Args:
        pred: detectron2 output dict whose ``'instances'`` carry ``pred_masks``.
        targ: list of annotation dicts with a ``'segmentation'`` entry
            (polygons when ``CONFIG['lc']`` is set, presumably RLE otherwise).

    Returns:
        The mean precision over all thresholds.
    """
    masks = pred['instances'].pred_masks.cpu().numpy()
    pred_rles = [mask_util.encode(np.asarray(m, order='F')) for m in masks]
    targ_rles = [ann['segmentation'] for ann in targ]
    if CONFIG['lc']:
        # LIVECell targets are polygons: rasterise the first polygon of each instance.
        targ_rles = [polygon_to_rle(seg[0]) for seg in targ_rles]
    ious = mask_util.iou(pred_rles, targ_rles, [0] * len(targ_rles))

    def _precision(t):
        tp, fp, fn = precision_at(t, ious)
        return tp / (tp + fp + fn)

    return np.mean([_precision(t) for t in np.arange(.5, 1, .05)])

class MAPIOUEvaluator(DatasetEvaluator):
    """detectron2 evaluator reporting the mean per-image ``score`` as ``'MaP IoU'``.

    Images for which the model returns no instances contribute 0.
    """

    def __init__(self, dataset_name):
        # Ground-truth annotations keyed by image_id, looked up in process().
        self.annotations_cache = {
            record['image_id']: record['annotations']
            for record in DatasetCatalog.get(dataset_name)
        }

    def reset(self):
        self.scores = []

    def process(self, inputs, outputs):
        for sample, prediction in zip(inputs, outputs):
            if len(prediction['instances']):
                gt = self.annotations_cache[sample['image_id']]
                self.scores.append(score(prediction, gt))
            else:
                # No detections at all: the image scores 0.
                self.scores.append(0)

    def evaluate(self):
        return {'MaP IoU': np.mean(self.scores)}

class MAPIOUCEvaluator(DatasetEvaluator):
    """Per-class variant of the MaP IoU evaluator.

    For each image and each class id ``c`` in ``range(classes)``, all predicted
    masks are scored against the targets of class ``c``. The score is 0 when the
    image has no such targets or no predictions. ``evaluate`` returns the
    per-class means as a list under ``'MaP IoU'``.

    NOTE(review): predictions are not filtered by class, so detections of other
    classes count as extra objects for class ``c``.
    """

    def __init__(self, dataset_name, classes=CONFIG['classes']):
        self.annotations_cache = {
            record['image_id']: record['annotations']
            for record in DatasetCatalog.get(dataset_name)
        }
        self.classes = classes

    def reset(self):
        self.scores = []

    def _class_score(self, prediction, gt, class_id):
        # Score against the targets of one class; 0 when the image has none of them.
        gt_c = [ann for ann in gt if ann['category_id'] == class_id]
        return score(prediction, gt_c) if gt_c else 0

    def process(self, inputs, outputs):
        for sample, prediction in zip(inputs, outputs):
            if not len(prediction['instances']):
                self.scores.append([0] * self.classes)
                continue
            gt = self.annotations_cache[sample['image_id']]
            self.scores.append([
                self._class_score(prediction, gt, class_id)
                for class_id in range(self.classes)
            ])

    def evaluate(self):
        return {'MaP IoU': np.mean(self.scores, axis=0).tolist()}

class Trainer(DefaultTrainer):
    """DefaultTrainer with the MaP IoU evaluator, custom augmentations and best-model checkpointing."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        # output_folder is accepted for API compatibility but unused.
        return MAPIOUEvaluator(dataset_name)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader with a custom augmentation pipeline.

        The base DatasetMapper training augmentations are
        ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800),
        max_size=1333, sample_style='choice') followed by RandomFlip().

        This pipeline instead:
        - adds 520 to the short-edge choices;
        - flips both horizontally and vertically (p=0.5 each);
        - applies RandomCrop('relative_range', (.75, .75)) with p=0.25;
        - when CONFIG['clr_aug'] is set, adds brightness and contrast jitter
          in [0.9, 1.1] with p=0.25 each.
        """
        augs = [
            T.ResizeShortestEdge(
                short_edge_length=(520, 640, 672, 704, 736, 768, 800),
                max_size=1333,
                sample_style='choice'
            ),
            T.RandomFlip(prob=.5, horizontal=True, vertical=False),
            T.RandomFlip(prob=.5, horizontal=False, vertical=True),
            T.RandomApply(
                T.RandomCrop(crop_type='relative_range', crop_size=(.75, .75)),
                prob=.25
            )
        ]
        if CONFIG['clr_aug']:
            augs.extend([
                T.RandomApply(
                    T.RandomBrightness(.9, 1.1),
                    prob=.25
                ),
                T.RandomApply(
                    T.RandomContrast(.9, 1.1),
                    prob=.25
                )
            ])
        # Explicit `augmentations` replace the ones DatasetMapper would build from cfg.
        mapper = DatasetMapper(
            cfg,
            is_train=True,
            augmentations=augs
        )
        return build_detection_train_loader(cfg, mapper=mapper)

    def build_hooks(self):
        """Return the default hooks plus a BestCheckpointer that keeps the max-'MaP IoU' model."""
        cfg = self.cfg.clone()
        # Named hook_list so it does not shadow the imported detectron2 `hooks` module.
        hook_list = super().build_hooks()
        # Inserted before the last default hook. Presumably that places it after the
        # evaluation hook that stores 'MaP IoU'; confirm against DefaultTrainer.build_hooks.
        hook_list.insert(
            -1,
            BestCheckpointer(
                cfg.TEST.EVAL_PERIOD,
                DetectionCheckpointer(self.model, cfg.OUTPUT_DIR),
                'MaP IoU',
                'max'
            )
        )
        return hook_list

## Train

In [None]:
# One "epoch" = one pass over the training set at CONFIG['batch_size'] images per iteration.
iters_per_epoch = len(DatasetCatalog.get('sartorius_train')) // CONFIG['batch_size']

cfg.merge_from_file(model_zoo.get_config_file(CONFIG['chk_point']))
cfg.DATASETS.TRAIN = ('sartorius_train', )
cfg.DATASETS.TEST = ('sartorius_val', )
cfg.DATALOADER.NUM_WORKERS = CONFIG['workers']
# Start from explicit weights when configured, otherwise from the model-zoo checkpoint.
if CONFIG['start_weights']:
    cfg.MODEL.WEIGHTS = CONFIG['start_weights']
else:
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CONFIG['chk_point'])

cfg.SOLVER.IMS_PER_BATCH = CONFIG['batch_size']
cfg.SOLVER.MAX_ITER = CONFIG['epochs'] * iters_per_epoch
# A period of MAX_ITER + 1 is never reached, i.e. no periodic checkpoints during training.
cfg.SOLVER.CHECKPOINT_PERIOD = iters_per_epoch if CONFIG['chk_point_period'] else (cfg.SOLVER.MAX_ITER + 1)
cfg.SOLVER.BASE_LR = CONFIG['lr']
cfg.SOLVER.WARMUP_FACTOR = CONFIG['warmup_factor']
cfg.SOLVER.WARMUP_ITERS = iters_per_epoch * CONFIG['warm_up_ep']
cfg.SOLVER.WARMUP_METHOD = 'linear'
cfg.SOLVER.GAMMA = CONFIG['gamma']
# LR decay steps every gamma_step fraction of MAX_ITER, shifted by the warm-up length.
# NOTE(review): the shift can push the last steps beyond MAX_ITER, where they never fire.
steps = [
    iters_per_epoch * CONFIG['warm_up_ep'] + int(cfg.SOLVER.MAX_ITER * frac)
    for frac in np.arange(CONFIG['gamma_step'], 1, CONFIG['gamma_step'])
]
print('gamma steps:', steps)
cfg.SOLVER.STEPS = steps
cfg.SOLVER.AMP.ENABLED = True

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = CONFIG['batch_size_per_img']
cfg.MODEL.ROI_HEADS.NUM_CLASSES = CONFIG['classes']
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = CONFIG['score_th']
# `is not None`: freeze_at=0 (freeze nothing) is a valid but falsy setting.
if CONFIG['freeze_at'] is not None:
    cfg.MODEL.BACKBONE.FREEZE_AT = CONFIG['freeze_at']

if CONFIG['eval']:
    cfg.TEST.EVAL_PERIOD = iters_per_epoch

cfg.OUTPUT_DIR = MDLS_PATH

In [None]:
n_train = len(DatasetCatalog.get('sartorius_train'))
n_val = len(DatasetCatalog.get('sartorius_val'))
print('train:', n_train, '| val:', n_val)
# resume=False: load cfg.MODEL.WEIGHTS instead of resuming from OUTPUT_DIR.
trainer = Trainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

elapsed_time = time.time() - start_time
minutes, seconds = divmod(elapsed_time, 60)
print(f'time elapsed: {minutes:.0f} min {seconds:.0f} sec')

## Best model

In [None]:
# metrics.json holds one JSON object per line; keep only the evaluation records.
with open(f'{MDLS_PATH}/metrics.json', 'r') as file:
    metrics = [json.loads(line) for line in file if 'MaP IoU' in line]
df = pd.DataFrame(metrics)

In [None]:
# Validation MaP IoU per evaluation; every second tick is labelled with its iteration.
map_iou = df['MaP IoU']
tick_pos = list(range(len(df)))[::2]
plt.figure(figsize=(16, 4))
plt.plot(map_iou)
plt.xlabel('iters')
plt.xticks(tick_pos, df['iteration'][::2], rotation=90)
plt.title(f'MaP IoU, max={np.max(map_iou):.4f}')
plt.show()

In [None]:
if CONFIG['chk_point_period']:
    # Only numbered periodic checkpoints (model_<iter>.pth) line up one-to-one with the
    # evaluation rows in df. model_best.pth / model_final.pth would shift the argmax
    # index, and deleting model_best.pth breaks later cells that may load it.
    model_files = sorted(
        x for x in os.listdir(MDLS_PATH)
        if x.startswith('model_') and x.endswith('.pth')
        and x[len('model_'):-len('.pth')].isdigit()
    )
    best_model_file = model_files[np.argmax(df['MaP IoU'])]
    print('best model file:', best_model_file)

    # Remove the other periodic checkpoints to save space.
    for file_name in model_files:
        if file_name != best_model_file:
            os.remove(f'{MDLS_PATH}/{file_name}')
else:
    # BestCheckpointer already wrote the best model during training.
    best_model_file = 'model_best.pth'
best_iter = df.loc[np.argmax(df['MaP IoU']), 'iteration']

In [None]:
# Record which checkpoint won, its score and the iteration it was evaluated at.
with open(f'{MDLS_PATH}/best_model.json', 'w') as file:
    summary = dict(
        file=best_model_file,
        score=np.max(df["MaP IoU"]),
        best_iter=int(best_iter),
    )
    json.dump(summary, file)

## Best thresholds

In [None]:
val_loader = build_detection_test_loader(cfg, 'sartorius_val')  # consumed by inference_on_dataset below
dataset_dicts = DatasetCatalog.get('sartorius_val')  # raw records, sampled by the inference demo

In [None]:
# Sweep the test-time ROI score threshold on the validation fold and collect the
# per-class MaP IoU for each value.
ths_scores = []
# Loop-invariant inference settings are applied once. best_model_file is the
# checkpoint chosen in the "Best model" cell (model_best.pth unless
# CONFIG['chk_point_period'] selected a periodic checkpoint), so it is not overridden here.
cfg.merge_from_file(model_zoo.get_config_file(CONFIG['chk_point']))
cfg.INPUT.MASK_FORMAT = 'bitmask'
cfg.MODEL.ROI_HEADS.NUM_CLASSES = CONFIG['classes']
cfg.MODEL.WEIGHTS = f'{MDLS_PATH}/{best_model_file}'
cfg.TEST.DETECTIONS_PER_IMAGE = 1000
for th in np.arange(.05, 1, .05):
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(th)
    model = DefaultPredictor(cfg)
    infer_result = inference_on_dataset(
        model.model,
        val_loader,
        MAPIOUCEvaluator('sartorius_val')
    )
    print(th, '->', infer_result)
    # Row layout: [threshold, score_class_0, score_class_1, ...]
    th_scores = [th]
    th_scores.extend(infer_result['MaP IoU'])
    ths_scores.append(th_scores)

In [None]:
# One curve per class of MaP IoU vs. score threshold; each class keeps the
# threshold (rounded to 2 decimals) where its curve peaks.
all_ths = [f'{row[0]:.2f}' for row in ths_scores]
ths = []
plt.figure(figsize=(16, 4))
for cls in range(CONFIG['classes']):
    cls_scores = [row[cls + 1] for row in ths_scores]
    plt.plot(cls_scores)
    best_idx = np.argmax(cls_scores)
    ths.append(float(all_ths[best_idx]))
plt.xlabel('ths')
plt.xticks(list(range(len(ths_scores)))[::2], all_ths[::2], rotation=90)
plt.title(f'MaP IoU, best ths are {ths}')
plt.show()

In [None]:
# Persist the per-class best thresholds (one float per class).
ths_path = f'{MDLS_PATH}/ths.json'
with open(ths_path, 'w') as file:
    json.dump(ths, file)

## Inference demo

In [None]:
# Qualitative check: ground truth vs. predictions (score threshold 0.5) on a few
# random validation images.
cfg.MODEL.WEIGHTS = f'{MDLS_PATH}/{best_model_file}'
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .5
predictor = DefaultPredictor(cfg)
preds = []
gts = []
# Cap the sample size so random.sample cannot raise on a small validation set.
for ds_dict in random.sample(dataset_dicts, min(3, len(dataset_dicts))):
    img = cv2.imread(ds_dict['file_name'])  # OpenCV loads BGR
    outputs = predictor(img)
    v = Visualizer(
        img[:, :, ::-1],  # Visualizer takes RGB
        metadata=MetadataCatalog.get('sartorius_train'),
        instance_mode=ColorMode.IMAGE_BW
    )
    pred = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    visualizer = Visualizer(
        img[:, :, ::-1],
        metadata=MetadataCatalog.get('sartorius_train')
    )
    target = visualizer.draw_dataset_dict(ds_dict)
    preds.append(pred)
    gts.append(target)

# squeeze=False keeps `axs` 2-D, so `ax[0]` / `ax[1]` also work for a single row.
fig, axs = plt.subplots(len(preds), 2, figsize=(16, 6 * len(preds)), squeeze=False)
for i, ax in enumerate(axs):
    # get_image() is already RGB, which plt.imshow expects; no channel flip.
    ax[0].imshow(gts[i].get_image())
    ax[0].set_title('ground truth')
    ax[0].axis('off')
    ax[1].imshow(preds[i].get_image())
    ax[1].set_title('preds')
    ax[1].axis('off')