![logo](https://raw.githubusercontent.com/facebookresearch/detectron2/main/.github/Detectron2-Logo-Horz.svg)

# Intro
In this notebook we will train an instance segmentation model based on Detectron2 and CenterMask2. 

References:
  * [CenterMask2](https://github.com/youngwanLEE/centermask2)
  * [LIVECell](https://github.com/sartorius-research/LIVECell)
  * [Create COCO annotations for Sartorius dataset](https://www.kaggle.com/mistag/sartorius-create-coco-annotations)
  * [Offline Detectron2 installation files](https://www.kaggle.com/mistag/detectron2-download-code-for-offline-installation)
  
Start by installing Detectron2:

In [None]:
!pip install --no-index \
../input/detectron2-download-code-for-offline-install-ii/detectron2/detectron2-0.6-cp37-cp37m-linux_x86_64.whl \
--find-links=../input/detectron2-download-code-for-offline-install-ii/detectron2
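
Before continuing, it can help to confirm that the offline install actually made the package importable. The snippet below is a minimal sketch using only the standard library; `is_installed` is a hypothetical helper name, not part of Detectron2.

```python
import importlib.util


def is_installed(module_name):
    """Return True if the top-level module can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


# Report the packages this notebook relies on.
for name in ("detectron2", "torch"):
    print(name, "found" if is_installed(name) else "missing")
```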

Then fetch the [CenterMask2](https://github.com/youngwanLEE/centermask2) project:

In [None]:
!git clone https://github.com/youngwanLEE/centermask2.git

# Pretrained model
Several pretrained models are available. Below we pick a [lightweight](https://github.com/youngwanLEE/centermask2#centermask-lite) one that fits in GPU memory.

In [None]:
%cd centermask2
!wget 'https://dl.dropbox.com/s/dret2ap7djty7mp/centermask2-lite-V-19-eSE-FPN-ms-4x.pth'

Create a configuration file. Most importantly, define the dataset names to use for training and testing. The datasets are registered later through the Detectron2 [DatasetCatalog](https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html) API.

In [None]:
%%writefile configs/centermask/test.yaml
DATALOADER:
  ASPECT_RATIO_GROUPING: true
  FILTER_EMPTY_ANNOTATIONS: true
  NUM_WORKERS: 2
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_fcos_vovnet_fpn_backbone"
    FREEZE_AT: 0
  VOVNET:
    OUT_FEATURES: ["stage3", "stage4", "stage5"]
  FPN:
    IN_FEATURES: ["stage3", "stage4", "stage5"]
  PROPOSAL_GENERATOR:
    NAME: "FCOS"  
  FCOS:
    POST_NMS_TOPK_TEST: 600 # Max number of detections per image
    POST_NMS_TOPK_TRAIN: 600
  # PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
  MASK_ON: True
  MASKIOU_ON: True
  ROI_HEADS:
    NAME: "CenterROIHeads"
    IN_FEATURES: ["p3", "p4", "p5"]
  ROI_MASK_HEAD:
    NAME: "SpatialAttentionMaskHead"
    ASSIGN_CRITERION: "ratio"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
DATASETS:
  TRAIN: ("train",) # match with DatasetCatalog.register() call!
  TEST: ("test",)
SOLVER:
  CHECKPOINT_PERIOD: 5000
  IMS_PER_BATCH: 8
  BASE_LR: 0.01  # Note that RetinaNet uses a different default learning rate
  STEPS: (60000, 80000)
  MAX_ITER: 10000 # 16000 = ~9h
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 640)
TEST:
  AUG:
    ENABLED: False
  DETECTIONS_PER_IMAGE: 500
  EVAL_PERIOD: 600 
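
Note that the `SOLVER` block sets `STEPS: (60000, 80000)` but `MAX_ITER: 10000`, so neither learning-rate step is reached in this run. The sketch below illustrates this with a plain multi-step schedule. It assumes a decay factor of 0.1 per step (Detectron2's default `SOLVER.GAMMA`) and ignores warmup; `multistep_lr` is a hypothetical helper, not a Detectron2 function.

```python
from bisect import bisect_right


def multistep_lr(iteration, base_lr=0.01, steps=(60000, 80000), gamma=0.1):
    """Learning rate of a plain multi-step schedule (warmup ignored)."""
    # Count how many milestones have been reached at this iteration.
    passed = bisect_right(steps, iteration)
    return base_lr * gamma ** passed


MAX_ITER = 10000
# The rate stays at BASE_LR for the whole run because no milestone is reached.
for it in (0, 5000, MAX_ITER - 1):
    print(it, multistep_lr(it))
```

To get a decay within the run, the milestones would have to lie below `MAX_ITER`.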

# Training
The code below is adapted from [train_net.py](https://github.com/youngwanLEE/centermask2/blob/master/train_net.py), which ships with the CenterMask2 repository.

In [None]:
# import various libraries
import logging
import os, re
from collections import OrderedDict
import torch
import numpy as np

import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog
from detectron2.data.datasets import register_coco_instances, load_coco_json
from detectron2.data import DatasetCatalog
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    # CityscapesInstanceEvaluator,
    # CityscapesSemSegEvaluator,
    # COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from centermask.evaluation import (
    COCOEvaluator,
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator
)
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import Visualizer, ColorMode
from centermask.config import get_cfg

Define trainer class:

In [None]:
def get_train_set():
    return load_coco_json('/kaggle/input/sartorius-create-coco-annotations/train_fold_0.json', '')

def get_test_set():
    return load_coco_json('/kaggle/input/sartorius-create-coco-annotations/test_fold_0.json', '')


class Trainer(DefaultTrainer):
    """
    Same as `DefaultTrainer`, except that `build_evaluator` always returns
    a COCO evaluator and `test_with_TTA` adds test-time augmentation.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create a COCO evaluator for the given dataset.
        Both datasets in this notebook use COCO-style annotations, so the
        per-dataset if-else logic of the original script is not needed.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=output_folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res

Define configuration:

In [None]:
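from argparse import Namespace

args = Namespace(num_gpus=1, opts=['MODEL.WEIGHTS', 'centermask2-lite-V-19-eSE-FPN-ms-4x.pth'], resume=False)
cfg = get_cfg()
cfg.merge_from_file("configs/centermask/test.yaml")
cfg.merge_from_list(args.opts)  # point MODEL.WEIGHTS at the pretrained checkpoint before freezing
cfg.freeze()
default_setup(cfg, args)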
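
Overrides such as `MODEL.WEIGHTS` are written as a flat list of alternating keys and values, where a dotted key addresses a nested config entry. The toy function below mimics that idea on a plain nested dictionary. It is only an illustration: `apply_overrides` is a hypothetical helper, and the real config object (a yacs `CfgNode`) also validates types.

```python
def apply_overrides(config, overrides):
    """Apply ['A.B', value, 'C.D', value, ...] to a nested dict in place."""
    if len(overrides) % 2 != 0:
        raise ValueError("overrides must contain key/value pairs")
    for dotted_key, value in zip(overrides[0::2], overrides[1::2]):
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            node = node[name]  # raises KeyError if the section does not exist
        if leaf not in node:
            raise KeyError(f"unknown key: {dotted_key}")
        node[leaf] = value
    return config


# Toy config with made-up defaults.
config = {"MODEL": {"WEIGHTS": "", "MASK_ON": True}, "SOLVER": {"BASE_LR": 0.01}}
apply_overrides(config, ["MODEL.WEIGHTS", "centermask2-lite-V-19-eSE-FPN-ms-4x.pth"])
print(config["MODEL"]["WEIGHTS"])
```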

Register dataset:

In [None]:
# Dataset registration
CLASSES = ["background", "shsy5y", "astro", "cort"]
DatasetCatalog.register("train", get_train_set)
MetadataCatalog.get("train").thing_classes = CLASSES
MetadataCatalog.get("train").evaluator_type = "coco"
DatasetCatalog.register("test", get_test_set)
MetadataCatalog.get("test").thing_classes = CLASSES
MetadataCatalog.get("test").evaluator_type = "coco"
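
`load_coco_json` reads a COCO-style annotation file. As a rough sketch of that layout, the snippet below builds a tiny annotation dictionary and checks that every annotation points at a known image and category. All file names, ids, sizes, and coordinates are made up for illustration; the real files come from the annotation notebook linked in the intro, and `check_references` is a hypothetical helper.

```python
# Hypothetical, hand-written COCO-style annotations (not taken from the dataset).
coco = {
    "images": [{"id": 1, "file_name": "cell_0001.png", "width": 704, "height": 520}],
    "categories": [
        {"id": 1, "name": "shsy5y"},
        {"id": 2, "name": "astro"},
        {"id": 3, "name": "cort"},
    ],
    "annotations": [
        {
            "id": 10,
            "image_id": 1,
            "category_id": 3,
            "bbox": [100.0, 50.0, 20.0, 30.0],  # [x, y, width, height]
            "segmentation": [[100, 50, 120, 50, 120, 80, 100, 80]],  # polygon as x, y pairs
            "area": 600.0,
            "iscrowd": 0,
        }
    ],
}


def check_references(data):
    """Return ids of annotations that reference a missing image or category."""
    image_ids = {img["id"] for img in data["images"]}
    category_ids = {cat["id"] for cat in data["categories"]}
    return [
        ann["id"]
        for ann in data["annotations"]
        if ann["image_id"] not in image_ids or ann["category_id"] not in category_ids
    ]


print(check_references(coco))  # ids of inconsistent annotations, if any
```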

Visualize a few samples as a sanity check:

In [None]:
import cv2
import matplotlib.pyplot as plt

test_ds = DatasetCatalog.get('test')
meta_ds = MetadataCatalog.get("test")

In [None]:
sample = test_ds[9]

img = cv2.imread(sample["file_name"])
visualizer = Visualizer(img[:, :, ::-1], metadata=meta_ds)
out = visualizer.draw_dataset_dict(sample)
plt.figure(figsize = (20,15))
plt.imshow(out.get_image());  # get_image() already returns RGB, which matplotlib expects
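
OpenCV loads images in BGR channel order, while `Visualizer` and matplotlib work with RGB, which is why the channel axis is reversed with `[:, :, ::-1]`. The pure-Python sketch below shows the same reversal on a tiny image stored as nested lists; `flip_channels` is a hypothetical helper used only for illustration.

```python
def flip_channels(image):
    """Reverse the channel order of every pixel (BGR -> RGB or RGB -> BGR)."""
    return [[pixel[::-1] for pixel in row] for row in image]


# A 1x2 "image": one pure-blue and one pure-red pixel in BGR order.
bgr = [[(255, 0, 0), (0, 0, 255)]]
rgb = flip_channels(bgr)
print(rgb)
```

Applying the flip twice returns the original order, so an extra reversal before plotting swaps red and blue again.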

# Train

In [None]:
trainer = Trainer(cfg)
trainer.resume_or_load(resume=False)
if cfg.TEST.AUG.ENABLED:
    trainer.register_hooks(
        [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
    )
trainer.train()
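
To get a feel for what this run does, the sketch below works out when evaluations and checkpoints happen for the settings in `test.yaml`. It assumes that periodic hooks also fire once at the final iteration; `periodic_points` is a hypothetical helper, not a Detectron2 function.

```python
def periodic_points(max_iter, period):
    """Iterations at which a periodic hook fires, always including the final iteration."""
    points = list(range(period, max_iter + 1, period))
    if not points or points[-1] != max_iter:
        points.append(max_iter)
    return points


# Values copied from the configuration file above.
MAX_ITER, EVAL_PERIOD, CHECKPOINT_PERIOD, IMS_PER_BATCH = 10000, 600, 5000, 8

eval_points = periodic_points(MAX_ITER, EVAL_PERIOD)
checkpoint_points = periodic_points(MAX_ITER, CHECKPOINT_PERIOD)
images_seen = MAX_ITER * IMS_PER_BATCH  # images processed over the whole run

print(len(eval_points), eval_points[-2:])
print(checkpoint_points)
print(images_seen)
```

Dividing `images_seen` by the number of training images gives the approximate number of epochs.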

Check learning curves. Detectron2 has an event storage module for this, but everything we need is already in the log file:

In [None]:
# read log file
with open('./output/log.txt', 'r') as f:
    log = f.read()
# extract iteration number and total training loss from each progress line
pairs = re.findall(r'iter: (\d+)  total_loss: ([.0-9]+)', log)
it = [int(i) for i, _ in pairs]
loss = [float(v) for _, v in pairs]
plt.figure(figsize = (20,10))
plt.plot(it, loss)
plt.xlabel('Iteration')
plt.ylabel('Total loss')
plt.title('Training loss');
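
To see what the regular expression picks up, here is a small self-contained sketch that parses two log lines. The lines are invented to resemble Detectron2's training log (fields separated by two spaces); the real `log.txt` contains more fields per line, and `parse_losses` is a hypothetical helper.

```python
import re

# Hypothetical log excerpt; timestamps and values are made up.
sample_log = (
    "[12/01 10:00:00 d2.utils.events]:  eta: 1:00:00  iter: 19  total_loss: 2.345  lr: 0.0002\n"
    "[12/01 10:00:10 d2.utils.events]:  eta: 0:59:50  iter: 39  total_loss: 1.987  lr: 0.0004\n"
)


def parse_losses(log_text):
    """Return (iterations, losses) found in the log text."""
    pairs = re.findall(r"iter: (\d+)  total_loss: ([0-9.]+)", log_text)
    iterations = [int(i) for i, _ in pairs]
    losses = [float(v) for _, v in pairs]
    return iterations, losses


iterations, losses = parse_losses(sample_log)
print(iterations, losses)
```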

In [None]:
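# Then we do evaluation scores
raw = re.findall('copypaste: [.,0-9]*', log)
boxes, segs = [], []
idx = 0
for s in raw:
    if len(s) > 20:  # skip the task and header rows, which match with an empty tail
        # slice off the prefix (str.strip would treat its argument as a set of characters)
        nums = [float(i) for i in s[len('copypaste: '):].split(',')]
        if idx == 0:
            boxes.append(nums)
        else:
            segs.append(nums)
        idx = (idx + 1) % 2
boxes, segs = np.asarray(boxes), np.asarray(segs)
plt.figure(figsize = (20,10))
# one evaluation every TEST.EVAL_PERIOD iterations, plus a final one at MAX_ITER
x = np.minimum((np.arange(len(segs)) + 1) * cfg.TEST.EVAL_PERIOD, cfg.SOLVER.MAX_ITER)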
plt.plot(x, segs[:,0], label='Mask AP', color='tab:blue')
plt.plot(x, segs[:,1], label='Mask AP50', color='tab:orange')
plt.plot(x, segs[:,2], label='Mask AP75', color='tab:red')
plt.plot(x, boxes[:,0], label='Box AP', linestyle='--', color='tab:blue')
plt.plot(x, boxes[:,1], label='Box AP50', linestyle='--', color='tab:orange')
plt.plot(x, boxes[:,2], label='Box AP75', linestyle='--', color='tab:red')
plt.legend()
plt.xlabel('Iteration')
plt.ylabel('Score')
plt.title('COCO Evaluation scores');
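
The parsing above relies on the numeric rows alternating between box and mask results. As an alternative sketch, the snippet below reads the `Task:` row that precedes each result and groups the numbers by task name. The log excerpt is invented to resemble Detectron2's `copypaste` output, the numbers are made up, and `parse_eval` is a hypothetical helper.

```python
import re

# Hypothetical log excerpt with one evaluation round; the numbers are made up.
sample_log = """\
copypaste: Task: bbox
copypaste: AP,AP50,AP75,APs,APm,APl
copypaste: 20.1,45.2,15.3,18.0,25.5,30.0
copypaste: Task: segm
copypaste: AP,AP50,AP75,APs,APm,APl
copypaste: 18.4,43.0,12.9,16.2,24.1,28.7
"""


def parse_eval(log_text):
    """Group the numeric 'copypaste' rows by the task announced before them."""
    results = {}
    task = None
    for line in log_text.splitlines():
        match = re.search(r"copypaste: (.*)", line)
        if not match:
            continue
        payload = match.group(1)
        if payload.startswith("Task: "):
            task = payload[len("Task: "):]
        elif task is not None and re.fullmatch(r"[0-9.,]+", payload):
            # Header rows contain letters and are skipped by the fullmatch test.
            results.setdefault(task, []).append([float(v) for v in payload.split(",")])
    return results


scores = parse_eval(sample_log)
print(scores["segm"][0][0])
```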

## Model check

Check a few predictions. The code is adapted from [Positive score with Detectron 2/3 - Training](https://www.kaggle.com/slawekbiel/positive-score-with-detectron-2-3-training).

In [None]:
import random

cfg.defrost()
cfg.MODEL.WEIGHTS = './output/model_final.pth'  # path to the model we just trained
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set a custom testing threshold
predictor = DefaultPredictor(cfg)
dataset_dicts = DatasetCatalog.get('test')
outs = []
for d in random.sample(dataset_dicts, 3):    
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    v = Visualizer(im[:, :, ::-1],
                   metadata=MetadataCatalog.get('test'),
                   instance_mode=ColorMode.IMAGE_BW  # draw unsegmented pixels in grayscale (segmentation models only)
    )
    out_pred = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    visualizer = Visualizer(im[:, :, ::-1], metadata=MetadataCatalog.get('test'))
    out_target = visualizer.draw_dataset_dict(d)
    outs.append(out_pred)
    outs.append(out_target)
_,axs = plt.subplots(len(outs)//2,2,figsize=(40,45))
for ax, out in zip(axs.reshape(-1), outs):
    ax.imshow(out.get_image())  # get_image() already returns RGB

# Inference
Time to [make predictions with the trained model](https://www.kaggle.com/mistag/pred-sartorius-detectron2-centermask2). See also the tutorial on how to [run inference with Detectron2 models](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5).