# Fine-tune Mask R-CNN with MMDetection

This notebook was developed and run with the following package versions:

- `mmcv-full==1.5.0`
- `mmdet==2.24.1`
- `pycocotools==2.0.4`
- `torch==1.11.0`
- `torchvision==0.12.0+cu113`

In [1]:
!git clone https://github.com/open-mmlab/mmdetection.git

Cloning into 'mmdetection'...
remote: Enumerating objects: 24460, done.[K
remote: Counting objects: 100% (22/22), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 24460 (delta 3), reused 11 (delta 2), pack-reused 24438[K
Receiving objects: 100% (24460/24460), 37.56 MiB | 36.18 MiB/s, done.
Resolving deltas: 100% (17113/17113), done.


## Step 1. Define the dataset

In [15]:
%%writefile mmdetection/mmdet/datasets/mineapple.py
from pathlib import Path

import mmcv
import numpy as np
import torch
from mmdet.core.mask.utils import mask2bbox, encode_mask_results
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset
from mmdet.datasets.pipelines import Compose
from PIL import Image

@DATASETS.register_module()
class MineAppleDataset(CocoDataset):
    """Apple ("fruit") instance-segmentation dataset read from image/mask PNGs.

    Expected layout under ``data_root``::

        detection/train/images/*.png    detection/test/images/*.png
        detection/train/masks/*.png     detection/test/masks/*.png

    Each mask PNG has the same file name as its image. It stores one integer
    id per pixel: 0 is background, and every other value is one fruit instance.

    ``CocoDataset.__init__`` is intentionally not called, because there is no
    COCO annotation file. The attributes the inherited mmdet methods use are
    set up in ``__init__`` instead.

    NOTE(review): inherited COCO-specific methods (e.g. evaluation) may still
    expect ``self.coco`` / ``self.cat_ids``. TODO confirm before validating
    with this class.
    """

    CLASSES = ("fruit",)
    PALETTE = [(220, 20, 60)]

    def __init__(
        self,
        data_root,
        pipeline,
        classes=None,
        test_mode=False,
        filter_empty_gt=True,
        file_client_args=dict(backend='disk'),
    ):
        """Index every image of the ``train`` or ``test`` split.

        Args:
            data_root: Directory containing the ``detection`` folder.
            pipeline: mmdet pipeline config list, wrapped in ``Compose``.
            classes: Optional override of ``CLASSES`` (see ``get_classes``).
            test_mode: Read the ``test`` split when True, else ``train``.
            filter_empty_gt: Stored only; this class does not filter images.
            file_client_args: Keyword arguments for ``mmcv.FileClient``.

        Raises:
            RuntimeError: If the split's ``images`` folder has no ``*.png``.
        """
        self.data_root = Path(data_root)
        split_dir = self.data_root / "detection" / ("test" if test_mode else "train")
        self.img_prefix = split_dir / "images"
        self.seg_prefix = split_dir / "masks"
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.file_client = mmcv.FileClient(**file_client_args)
        self.CLASSES = self.get_classes(classes)

        # No precomputed proposals are used. The skipped base-class __init__
        # would normally set these; presumably the inherited pipeline helpers
        # read them, so they must exist. TODO confirm.
        self.proposal_file = None
        self.proposals = None

        # Find all images; masks are looked up by the same file name later.
        image_paths = sorted(self.img_prefix.glob("*.png"))
        if not image_paths:
            raise RuntimeError("No images")
        self.data_infos = []
        for image_path in image_paths:
            # Only the size is needed. The context manager closes the file
            # instead of leaving one handle open per image.
            with Image.open(image_path) as img:
                width, height = img.size
            self.data_infos.append(
                {"width": width, "height": height, "filename": image_path.name}
            )

        # Images are not filtered (by size or missing annotations). For
        # training, only the group flags used by the sampler are set.
        if not test_mode:
            self._set_group_flag()

        # processing pipeline
        self.pipeline = Compose(pipeline)

    def get_ann_info(self, index):
        """Load the instance annotations of image ``index`` from its mask PNG.

        Args:
            index: Position of the image in ``self.data_infos``.

        Returns:
            dict with
              - ``bboxes``: float32 array (N, 4) of ``(x1, y1, x2, y2)`` boxes,
              - ``labels``: int64 array (N,), all 0 (the single ``fruit`` class),
              - ``masks``: uint8 array (N, H, W), one binary mask per instance,
              - ``bboxes_ignore``: empty float32 array (0, 4),
              - ``seg_map``: the mask file name.
            Instances whose box has zero width or height (a single row or
            column of pixels) are dropped.
        """
        info = self.data_infos[index]
        seg_map = info["filename"]
        with Image.open(self.seg_prefix / seg_map) as mask_img:
            int_mask = np.array(mask_img)
        # The mask must line up with its image (previously a hard-coded
        # (1280, 720) check, which only held for one image size).
        expected_shape = (info["height"], info["width"])
        assert int_mask.shape == expected_shape, (int_mask.shape, expected_shape)

        # Convert the integer-labelled 2D mask into a 3D boolean array with one
        # slice per object id.
        # Source: https://pytorch.org/vision/stable/auto_examples/plot_repurposing_annotations.html
        object_ids = np.unique(int_mask)
        # Drop the background id explicitly. Slicing off the first id would
        # discard a real instance in an image that has no 0-valued pixel.
        object_ids = object_ids[object_ids != 0]
        instance_masks = int_mask == object_ids[:, None, None]

        # Masks to boxes; every instance mask has at least one pixel.
        good_masks = []
        boxes = []
        for mask in instance_masks:
            ys, xs = np.nonzero(mask)
            box = (xs.min(), ys.min(), xs.max(), ys.max())
            if box[2] <= box[0] or box[3] <= box[1]:
                continue
            good_masks.append(mask)
            boxes.append(box)

        if boxes:
            bboxes = np.array(boxes, dtype=np.float32)
            masks = np.array(good_masks, dtype=np.uint8)
        else:
            bboxes = np.zeros((0, 4), dtype=np.float32)
            masks = np.zeros((0, *int_mask.shape), dtype=np.uint8)
        # Assume single label
        labels = np.zeros((masks.shape[0],), dtype=np.int64)
        # Assume iscrowd is always false, so nothing is ignored.
        bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
        assert bboxes.shape[0] == masks.shape[0] == labels.shape[0], (
            bboxes.shape, masks.shape, labels.shape)
        assert masks.shape[1:] == expected_shape, masks.shape
        return {
            "bboxes": bboxes,
            "labels": labels,
            "masks": masks,
            "bboxes_ignore": bboxes_ignore,
            "seg_map": seg_map,
        }

Overwriting mmdetection/mmdet/datasets/mineapple.py


In [16]:
%%writefile mmdetection/mmdet/datasets/pipelines/mineapple.py
from mmdet.core.mask.structures import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.loading import LoadAnnotations


@PIPELINES.register_module()
class LoadMineappleAnnotations(LoadAnnotations):
    """``LoadAnnotations`` variant for ``MineAppleDataset``.

    The dataset already provides instance masks as a dense ``(N, H, W)`` array
    in ``ann_info['masks']``. Mask loading therefore only wraps that array in
    ``BitmapMasks``.

    TODO: consider disabling ``with_mask`` and building ``BitmapMasks`` in the
    dataset instead.
    """

    def _load_masks(self, results):
        """Wrap the dataset's mask array in ``BitmapMasks`` and register it.

        Args:
            results: Pipeline dict holding ``img_info`` and ``ann_info``.

        Returns:
            The same dict, with ``gt_masks`` set and listed in ``mask_fields``.
        """
        img_info = results['img_info']
        results['gt_masks'] = BitmapMasks(
            results['ann_info']['masks'],
            img_info['height'],
            img_info['width'],
        )
        results['mask_fields'].append('gt_masks')
        return results

Overwriting mmdetection/mmdet/datasets/pipelines/mineapple.py


## Step 2. Fine-tune Mask R-CNN

In [None]:
# Train with mmdetection's own CLI script. PYTHONPATH points at the local clone
# so that the config's `custom_imports` can be imported (presumably the
# MineApple dataset/pipeline modules written above; TODO confirm in the config).
# TODO: `pip install -e mmdetection` may be a cleaner alternative.
!PYTHONPATH=mmdetection python mmdetection/tools/train.py fruit_detection.py

## Build a COCO API object from the dataset

In [28]:
import sys

# Put the local mmdetection clone first on the import path. The guard adds it
# only once, so re-running the cell is harmless.
if not any(entry == "mmdetection" for entry in sys.path):
    sys.path.insert(0, "mmdetection")

In [29]:
# Path to the mmdetection config for the fruit-detection experiment.
config = "fruit_detection.py"

In [30]:
from mmcv import Config
from mmdet.datasets import build_dataset, build_dataloader
from mmdet.models import build_detector

# Load the experiment config. Its `custom_imports` presumably register the
# MineApple dataset/pipeline; TODO confirm.
cfg = Config.fromfile(config)
# Build the train split described in the config.
dataset = build_dataset(cfg.data.train)

In [31]:
from pycocotools.coco import COCO

In [32]:
from mmdet.datasets.api_wrappers.coco_api import COCO

In [33]:
# Empty COCO API object; its ground-truth index is filled from the dataset below.
coco = COCO()

In [34]:
# Short alias for the train dataset built above.
ds = dataset

In [82]:
# Annotations of the first image (index 0), as returned by get_ann_info.
ann_info = ds.get_ann_info(0)

In [85]:
# Which fields the annotation dict provides.
ann_info.keys()

dict_keys(['bboxes', 'labels', 'masks', 'bboxes_ignore', 'seg_map'])

In [90]:
# Show only the shape, not the whole mask array. ann_info['masks'] is
# (num_instances, H, W); the transpose swaps the spatial axes to (N, W, H).
ann_info['masks'].transpose(0, 2, 1).shape

(95, 720, 1280)

In [87]:
import torch
from mmdet.utils import get_device

# This cell runs in the notebook kernel, where the %%writefile cells above
# imported nothing, so import numpy and the COCO RLE helpers explicitly.
import numpy as np
from pycocotools import mask as coco_mask

device = get_device()

# Build COCO-format ground truth for every image of `ds` and load it into
# `coco`. Annotation IDs need to start at 1, not 0 (see torchvision issue #1530).
ann_id = 1
coco_gt = {"images": [], "categories": [], "annotations": []}
categories = set()
for img_idx in range(len(ds)):
    data_info = ds.data_infos[img_idx]
    image_id = img_idx
    coco_gt["images"].append(
        {"id": image_id, "height": data_info["height"], "width": data_info["width"]}
    )

    ann_info = ds.get_ann_info(img_idx)
    labels = ann_info["labels"].tolist()
    # mmdet boxes are (x1, y1, x2, y2); COCO expects (x, y, width, height).
    bboxes = ann_info["bboxes"].copy()
    bboxes[:, 2:] -= bboxes[:, :2]
    areas = bboxes[:, 2:].prod(axis=-1).tolist()  # box area: width * height
    bboxes = bboxes.tolist()
    # coco_mask.encode takes a Fortran-ordered uint8 (H, W, N) array and returns
    # one RLE per instance; ann_info["masks"] is (N, H, W).
    masks = np.asfortranarray(ann_info["masks"].transpose(1, 2, 0))
    rles = coco_mask.encode(masks) if labels else []
    for label, bbox, area, rle in zip(labels, bboxes, areas, rles):
        coco_gt["annotations"].append(
            {
                "id": ann_id,
                "image_id": image_id,
                "category_id": label,
                "bbox": bbox,
                "area": area,
                "iscrowd": 0,  # the dataset assumes no crowd regions
                "segmentation": rle,
            }
        )
        categories.add(label)
        ann_id += 1
coco_gt["categories"] = [{"id": i} for i in sorted(categories)]
coco.dataset = coco_gt
coco.createIndex()

creating index...
index created!


In [None]:
def convert_to_coco_api(ds):
    """Build an in-memory COCO ground-truth object from a torchvision-style dataset.

    ``ds[i]`` must return ``(img, targets)``:
      - ``img`` is a tensor whose last two dimensions are ``(H, W)``;
      - ``targets`` is a dict of tensors with ``image_id``, ``boxes``
        (x1, y1, x2, y2), ``labels``, ``area``, ``iscrowd``, and optionally
        ``masks`` (N, H, W) and ``keypoints`` (presumably (N, K, 3)).

    NOTE(review): mmdet datasets such as ``MineAppleDataset`` presumably do not
    follow this ``(img, targets)`` contract, so this helper does not apply to
    them as-is.

    Args:
        ds: Indexable dataset supporting ``len()``.

    Returns:
        A ``COCO`` object with its index created.
    """
    # Imported locally so the function does not depend on another cell.
    from pycocotools import mask as coco_mask

    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        # Indexing the dataset loads the whole image just to read its size
        # and targets. TODO: find a cheaper way to get the targets.
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict["id"] = image_id
        img_dict["height"] = img.shape[-2]
        img_dict["width"] = img.shape[-1]
        dataset["images"].append(img_dict)
        # COCO boxes are (x, y, width, height).
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask: the transpose-copy-
            # transpose gives each masks[i] column-major memory
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                # Every third value is a visibility flag; count visible points.
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds

## Step 3. Inference

In [None]:
import sys

# Make `import mmdet` resolve to the local clone. Skip if the entry is already
# present, so re-running the cell does not stack duplicates.
if sys.path.count("mmdetection") == 0:
    sys.path.insert(0, "mmdetection")

In [None]:
# mmdetection config used for inference in this section.
config = "fruit_detection.py"

In [None]:
from mmcv import Config
from mmdet.datasets import build_dataset, build_dataloader
from mmdet.models import build_detector

cfg = Config.fromfile(config)
# NOTE(review): this builds the *train* split; inference would presumably use
# the test split instead (e.g. cfg.data.test). TODO confirm.
dataset = build_dataset(cfg.data.train)
# Detector built from the config only; no checkpoint is loaded here.
model = build_detector(cfg.model)

In [26]:
# Number of images in the dataset built above.
len(dataset)

670

In [None]:
from mmdet.apis import init_detector, inference_detector

# NOTE(review): this rebuilds the same detector as the previous cell, and the
# init_detector / inference_detector imports above are not used yet.
# Presumably inference from a trained checkpoint is still TODO.
model = build_detector(cfg.model)

## Rough draft: train from Python code (instead of `tools/train.py`)

In [58]:
from mmcv import Config
from mmdet.datasets import build_dataset, build_dataloader
from mmdet.models import build_detector

In [65]:
import sys

# Insert (not append) so the local clone shadows any installed mmdet. This
# matches the path setup used in the other sections of this notebook.
if "mmdetection" not in sys.path:
    sys.path.insert(0, "mmdetection")

['/home/jupyter/ai-playground/object-detection',
 '/opt/conda/lib/python37.zip',
 '/opt/conda/lib/python3.7',
 '/opt/conda/lib/python3.7/lib-dynload',
 '',
 '/opt/conda/lib/python3.7/site-packages',
 '/opt/conda/lib/python3.7/site-packages/IPython/extensions',
 '/home/jupyter/.ipython',
 'mmdetection']

In [74]:
# Import the custom dataset class explicitly: the %%writefile cell above only
# writes the module to disk and does not import it into this kernel.
from mmdet.datasets.mineapple import MineAppleDataset

config = "fruit_detection.py"

cfg = Config.fromfile(config)

# Construct the dataset directly from the train config. The registry `type`
# key is dropped because it is not a MineAppleDataset.__init__ argument.
args = cfg.data.train.copy()
args.pop("type")
dataset = MineAppleDataset(**args)

In [60]:
# Build the detector from the config, then run its configured weight
# initialisation. mmdet's tools/train.py calls init_weights() right after
# building the model.
model = build_detector(cfg.model)
model.init_weights()

In [4]:
# Single GPU training
cfg.gpu_ids = range(1)

In [5]:
import torch.distributed as dist
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.utils import (collect_env, get_device, get_root_logger,
                         setup_multi_processes, update_data_root)

deterministic, seed = False, 0

cfg.device = get_device()

# Resolve the seed for this device, record it in the config, and apply it
# through set_random_seed.
cfg.seed = seed = init_random_seed(seed, device=cfg.device)
set_random_seed(cfg.seed, deterministic=deterministic)

In [6]:
distributed = False
runner_type = cfg.runner['type'] if 'runner' in cfg else 'EpochBasedRunner'

# Default loader settings (`num_gpus` is ignored when distributed).
train_dataloader_default_args = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    num_gpus=len(cfg["gpu_ids"]),
    dist=distributed,
    seed=cfg["seed"],
    runner_type=runner_type,
    persistent_workers=False,
)

# Anything under `data.train_dataloader` in the config overrides the defaults.
train_loader_cfg = dict(train_dataloader_default_args)
train_loader_cfg.update(cfg.data.get('train_dataloader', {}))
data_loader = build_dataloader(dataset, **train_loader_cfg)

In [7]:
import os
import time

from mmdet.apis.train import train_detector

In [8]:
# Validation runs during training unless this is set to True.
no_validate = False

# Run identifier such as "20220601_153000"; strftime uses local time by default.
timestamp = time.strftime('%Y%m%d_%H%M%S')

In [9]:
# One work directory per config, named after the config file without its
# directory or extension (fruit_detection.py -> ./work_dirs/fruit_detection).
cfg.work_dir = os.path.join(
    './work_dirs',
    os.path.basename(os.path.splitext(config)[0]),
)

In [None]:
# Run metadata (modelled on mmdet's tools/train.py).
env_info_dict = collect_env()
env_info = '\n'.join(f'{k}: {v}' for k, v in env_info_dict.items())
meta = {
    "env_info": env_info,
    "config": cfg.pretty_text,
    "seed": cfg.seed,
    "exp_name": os.path.basename(config),
}

# Record the class names in saved checkpoints and on the model (as
# tools/train.py does), so inference can map label indices back to "fruit".
if cfg.get('checkpoint_config') is not None:
    cfg.checkpoint_config.meta = dict(CLASSES=dataset.CLASSES)
model.CLASSES = dataset.CLASSES

# NOTE(review): validation presumably goes through CocoDataset's evaluation,
# which may need a COCO index that MineAppleDataset does not build. TODO confirm.
train_detector(
    model,
    [dataset],
    cfg,
    distributed=distributed,
    validate=not no_validate,
    timestamp=timestamp,
    meta=meta,
)