<a href="https://colab.research.google.com/github/swilcock0/artec25/blob/main/Tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [How to train Detectron2 with Custom COCO Datasets](https://www.dlology.com/blog/how-to-train-detectron2-with-custom-coco-datasets/) | DLology

<img src="https://dl.fbaipublicfiles.com/detectron2/Detectron2-Logo-Horz.png" width="500">

This notebook will help you get started with this framework by training an instance segmentation model on a custom dataset in COCO format.

# Install detectron2

In [None]:
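!pip install -U torch torchvision
!pip install git+https://github.com/facebookresearch/fvcore.git
!pip install "kagglehub[pandas-datasets]"
import torch, torchvision
# Only the last expression of a cell is displayed, so print the versions explicitly
print(torch.__version__, torchvision.__version__)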

Collecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.

In [None]:
!git clone https://github.com/facebookresearch/detectron2 detectron2_repo
!pip install -e detectron2_repo

Cloning into 'detectron2_repo'...
remote: Enumerating objects: 15837, done.
remote: Counting objects: 100% (56/56), done.
remote: Compressing objects: 100% (44/44), done.
remote: Total 15837 (delta 30), reused 12 (delta 12), pack-reused 15781 (from 2)
Receiving objects: 100% (15837/15837), 6.40 MiB | 11.54 MiB/s, done.
Resolving deltas: 100% (11537/11537), done.
Obtaining file:///content/detectron2_repo
  Preparing metadata (setup.py) ... done
Collecting fvcore<0.1.6,>=0.1.5 (from detectron2==0.6)
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.2/50.2 kB 2.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting iopath<0.1.10,>=0.1.7 (from detectron2==0.6)
  Downloading iopath-0.1.9-py3-none-any.whl.metadata (370 bytes)
Collecting omegaconf<2.4,>=2.1 (from detectron2==0.6)
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)

In [None]:
# You may need to restart your runtime prior to this, to let your installation take effect
# Some basic setup
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import matplotlib.pyplot as plt
import numpy as np
import cv2
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

# Train on a custom COCO dataset

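In this section, we show how to train an existing detectron2 model on a custom dataset in COCO format.

The original tutorial uses [the fruits nuts segmentation dataset](https://github.com/Tony607/mmdetection_instance_segmentation_demo), which has only 3 classes: date, fig, and hazelnut. This notebook instead uses the urban street tree segmentation dataset downloaded from Kaggle below, with COCO-format annotations in `train.json`.
We'll train a segmentation model from an existing model pre-trained on the COCO dataset, available in detectron2's model zoo.

Note that the COCO dataset does not have the "date", "fig" and "hazelnut" categories. More generally, because `NUM_CLASSES` differs from COCO's 80 classes, the class-specific box and mask predictor layers cannot be loaded from the checkpoint and are trained from scratch.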
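For reference, a COCO instance-segmentation file such as `train.json` is a single JSON object with `images`, `annotations` and `categories` lists; each annotation points at an image via `image_id` and at a class via `category_id`, and `segmentation` may be a list of polygons (as below) or RLE. The sketch below builds a tiny made-up example and summarizes it with the standard library. `summarize_coco` is a hypothetical helper and the `"tree"` category name is an assumption, not read from the real file.

```python
import json
from collections import Counter

# A tiny, made-up COCO-style file; the "tree" category is hypothetical, not read from train.json.
toy = {
    "images": [{"id": 1, "file_name": "0001.jpg", "height": 480, "width": 640}],
    "annotations": [{
        "id": 10, "image_id": 1, "category_id": 1,
        "bbox": [100, 50, 120, 300],  # [x, y, width, height] in pixels
        "segmentation": [[100, 50, 220, 50, 220, 350, 100, 350]],  # one polygon: x1, y1, x2, y2, ...
        "area": 36000, "iscrowd": 0,
    }],
    "categories": [{"id": 1, "name": "tree"}],
}


def summarize_coco(coco):
    """Return (num_images, num_annotations, {category_name: instance_count})."""
    names = {c["id"]: c["name"] for c in coco["categories"]}
    counts = Counter(names[a["category_id"]] for a in coco["annotations"])
    return len(coco["images"]), len(coco["annotations"]), dict(counts)


coco = json.loads(json.dumps(toy))  # round-trip through JSON text, as json.load would return it
print(summarize_coco(coco))  # (1, 1, {'tree': 1})
```

Once `train.json` has been downloaded, the same helper could be pointed at it with `summarize_coco(json.load(open("train.json")))` to see how many categories `NUM_CLASSES` has to cover.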

In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("erickendric/tree-dataset-of-urban-street-segmentation-tree")
!wget https://samwilcock.xyz/Files/train.json
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/erickendric/tree-dataset-of-urban-street-segmentation-tree?dataset_version_number=1...
100%|██████████| 7.95G/7.95G [01:17<00:00, 110MB/s]
Extracting files...

--2025-02-10 16:04:23--  https://samwilcock.xyz/Files/train.json
Resolving samwilcock.xyz (samwilcock.xyz)... 185.199.111.153, 185.199.109.153, 185.199.110.153, ...
Connecting to samwilcock.xyz (samwilcock.xyz)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.samwilcock.xyz/Files/train.json [following]
--2025-02-10 16:04:23--  https://www.samwilcock.xyz/Files/train.json
Resolving www.samwilcock.xyz (www.samwilcock.xyz)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to www.samwilcock.xyz (www.samwilcock.xyz)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 54008700 (52M) [application/json]
Saving to: ‘train.json’


2025-02-10 16:04:25 (291 MB/s) - ‘train.json’ saved [54008700/54008700]

Path to dataset files: /root/.cache/kagglehub/datasets/erickendric/tree-dataset-of-urban-street-segmentation-tree/versions/1


Register the tree dataset to detectron2, following the [detectron2 custom dataset tutorial](https://detectron2.readthedocs.io/tutorials/datasets.html).


In [None]:
from detectron2.data.datasets import register_coco_instances

# Get the actual path to images directory
image_dir = path + "/tree/VOCdevkit/VOC2012/JPEGImages"
register_coco_instances("treeMe", {}, "train.json", image_dir)
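`register_coco_instances` does not touch the image files; detectron2 joins `image_dir` with each `file_name` from the JSON only when the dataset is loaded, so a wrong directory shows up later as `cv2.imread` returning `None`. A quick pre-check, sketched below with the standard library, can catch that early. `missing_images` is a hypothetical helper, and the demo uses a throwaway temporary directory rather than the real dataset.

```python
import json
import os
import tempfile


def missing_images(json_path, image_dir):
    """List COCO `file_name` entries that do not exist under image_dir."""
    with open(json_path) as f:
        coco = json.load(f)
    return [im["file_name"] for im in coco["images"]
            if not os.path.isfile(os.path.join(image_dir, im["file_name"]))]


# Self-contained demo: one image exists on disk, one does not.
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "a.jpg"), "wb").close()
    json_path = os.path.join(tmp, "toy.json")
    with open(json_path, "w") as f:
        json.dump({"images": [{"file_name": "a.jpg"}, {"file_name": "b.jpg"}]}, f)
    demo_missing = missing_images(json_path, tmp)

print(demo_missing)  # ['b.jpg']
```

In this notebook the call would be `missing_images("train.json", image_dir)`; an empty list means every annotated image was found.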

In [None]:
tree_metadata = MetadataCatalog.get("treeMe")
dataset_dicts = DatasetCatalog.get("treeMe")

Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.

[02/10 16:04:25 d2.data.datasets.coco]: Loaded 3168 images in COCO format from train.json
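The warning above is informational. As far as I can tell from detectron2's COCO loader, the category ids are always sorted and mapped to contiguous indices `0..K-1` (stored as `thing_dataset_id_to_contiguous_id` in the metadata, with the names in `thing_classes`); the warning only means the raw ids were not already `1..K`. A minimal stdlib sketch of that remapping, with hypothetical category ids:

```python
def contiguous_id_map(categories):
    """Map raw COCO category ids to contiguous indices 0..K-1, in sorted-id order."""
    cat_ids = sorted(c["id"] for c in categories)
    return {cat_id: idx for idx, cat_id in enumerate(cat_ids)}


# Hypothetical categories whose ids are not 1..K
cats = [{"id": 5, "name": "b"}, {"id": 2, "name": "a"}]
id_map = contiguous_id_map(cats)
print(id_map)       # {2: 0, 5: 1}
print(len(id_map))  # K, the value NUM_CLASSES should equal
```

In the notebook, `len(tree_metadata.thing_classes)` gives the same K for `train.json`.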


To verify the data loading is correct, let's visualize the annotations of randomly selected samples in the training set:



In [None]:
import random

for d in random.sample(dataset_dicts, 3):
    # Print the file name to check if it is valid and accessible
    print("File name:", d["file_name"])

    img = cv2.imread(d["file_name"])

    # Check if the image was loaded successfully
    if img is None:
        print(f"Error: Could not load image {d['file_name']}. Please check the file path and image format.")
        continue  # Skip to the next image

    visualizer = Visualizer(img[:, :, ::-1], metadata=tree_metadata, scale=0.5)
    vis = visualizer.draw_dataset_dict(d)
    cv2_imshow(vis.get_image()[:, :, ::-1])
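Besides eyeballing a few samples, a class histogram over all of `dataset_dicts` helps spot imbalance or empty categories. After loading, each record holds an `annotations` list whose `category_id` is already the contiguous index into `thing_classes`. The sketch below uses toy records and placeholder class names (both hypothetical); `instances_per_class` is a hypothetical helper.

```python
from collections import Counter


def instances_per_class(dataset_dicts, thing_classes):
    """Count annotated instances per class name in detectron2-format dataset dicts."""
    counts = Counter()
    for record in dataset_dicts:
        for ann in record.get("annotations", []):
            counts[thing_classes[ann["category_id"]]] += 1  # category_id is a contiguous index
    return dict(counts)


# Toy records with placeholder class names (hypothetical)
toy_dicts = [
    {"file_name": "a.jpg", "annotations": [{"category_id": 0}, {"category_id": 1}]},
    {"file_name": "b.jpg", "annotations": [{"category_id": 0}]},
]
print(instances_per_class(toy_dicts, ["class_a", "class_b"]))  # {'class_a': 2, 'class_b': 1}
```

On the real data this would be `instances_per_class(dataset_dicts, tree_metadata.thing_classes)`.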

Now, let's fine-tune a COCO-pretrained R50-FPN Mask R-CNN model on the tree dataset. With the settings below (8 images per batch, 300 iterations), the training log reports roughly 5 seconds per iteration, for an initial ETA of about 25 minutes.
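It is worth checking how much data 300 iterations actually cover. In detectron2 one iteration processes `IMS_PER_BATCH` images, so a back-of-the-envelope calculation with the image count from the loader log and the per-iteration time from the training log (about 5 s, read off the `time:` column and assumed constant) gives:

```python
# Numbers taken from this notebook's config and logs
num_images = 3168     # "Loaded 3168 images in COCO format from train.json"
ims_per_batch = 8     # cfg.SOLVER.IMS_PER_BATCH (images per iteration)
max_iter = 300        # cfg.SOLVER.MAX_ITER
sec_per_iter = 5.0    # rough reading of the "time:" column in the training log

images_seen = max_iter * ims_per_batch
epochs = images_seen / num_images
minutes = max_iter * sec_per_iter / 60
print(f"{images_seen} images, {epochs:.2f} epochs, about {minutes:.0f} min")
# 2400 images, 0.76 epochs, about 25 min
```

So this schedule covers only about three quarters of the training images once; a larger `MAX_ITER` would be needed to train for several epochs.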


In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2 import model_zoo
import detectron2.data.transforms as T

import os

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("treeMe",)
cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 8
cfg.SOLVER.MAX_ITER = 300
cfg.SOLVER.BASE_LR = 0.001
cfg.INPUT.MASK_FORMAT = "bitmask"
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
cfg.SOLVER.WARMUP_ITERS = int(0.2*cfg.SOLVER.MAX_ITER)
# cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256   # faster than the default of 512 sampled RoIs per image
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # must equal the number of categories in train.json, i.e. len(tree_metadata.thing_classes)


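from detectron2.data import DatasetMapper, build_detection_train_loader

# Define a sequence of augmentations. DefaultTrainer does not read a `trainer.aug`
# attribute, so the list is passed to a DatasetMapper in a custom build_train_loader.
augs = [
    T.RandomBrightness(0.9, 1.1),
    T.RandomFlip(prob=0.5),
    T.RandomCrop("absolute", (400, 640)),
    T.RandomRotation(angle=[-30, 30]),
    T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    ),
]


class AugTrainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        # DatasetMapper applies `augs` to images and annotations; other settings
        # (e.g. INPUT.MASK_FORMAT = "bitmask") are still read from cfg.
        # recompute_boxes=True re-derives tight boxes from the masks after crop/rotation.
        mapper = DatasetMapper(cfg, is_train=True, augmentations=augs, recompute_boxes=True)
        return build_detection_train_loader(cfg, mapper=mapper)


os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = AugTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()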

[02/10 16:55:06 d2.engine.defaults]: Model:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res

roi_heads.box_predictor.bbox_pred.{bias, weight}
roi_heads.box_predictor.cls_score.{bias, weight}
roi_heads.mask_head.predictor.{bias, weight}


[02/10 16:55:07 d2.engine.train_loop]: Starting training from iteration 0


  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
  torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])


[02/10 16:56:56 d2.utils.events]:  eta: 0:25:09  iter: 19  total_loss: 1.639  loss_cls: 0.7704  loss_box_reg: 0.06149  loss_mask: 0.6835  loss_rpn_cls: 0.05359  loss_rpn_loc: 0.01614    time: 5.4369  last_time: 6.4442  data_time: 3.2515  last_data_time: 4.2795   lr: 0.00028711  max_mem: 11886M
[02/10 16:58:29 d2.utils.events]:  eta: 0:20:32  iter: 39  total_loss: 0.7276  loss_cls: 0.09917  loss_box_reg: 0.08918  loss_mask: 0.4915  loss_rpn_cls: 0.01519  loss_rpn_loc: 0.01257    time: 5.0377  last_time: 3.8337  data_time: 2.5201  last_data_time: 1.6384   lr: 0.00058828  max_mem: 11886M
[02/10 17:00:06 d2.utils.events]:  eta: 0:18:47  iter: 59  total_loss: 0.4122  loss_cls: 0.05201  loss_box_reg: 0.0938  loss_mask: 0.2425  loss_rpn_cls: 0.0099  loss_rpn_loc: 0.01509    time: 4.9688  last_time: 3.8632  data_time: 2.6694  last_data_time: 1.6614   lr: 0.00088945  max_mem: 11886M
[02/10 17:01:45 d2.utils.events]:  eta: 0:17:22  iter: 79  total_loss: 0.3745  loss_cls: 0.03899  loss_box_reg: 0

KeyboardInterrupt: 
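The `lr` column in the log above can be reproduced from the config. My reading of detectron2's `WarmupCosineLR` (built on fvcore's parameter schedulers) is: a linear ramp from `WARMUP_FACTOR * BASE_LR` up to the cosine curve's value at `WARMUP_ITERS`, then `0.5 * BASE_LR * (1 + cos(pi * iter / MAX_ITER))` decaying to zero. This notebook does not set `WARMUP_FACTOR`, so the default of 0.001 is assumed, as is `BASE_LR_END = 0`. The stdlib sketch below is a reconstruction, not detectron2's code, but it matches the logged values at iterations 19, 39 and 59.

```python
import math


def warmup_cosine_lr(it, base_lr=0.001, max_iter=300, warmup_iters=60, warmup_factor=0.001):
    """LR used at iteration `it` (0-based): linear warmup, then cosine decay to 0."""
    def cosine(where):  # multiplier goes from 1 at where=0 to 0 at where=1
        return 0.5 * (1.0 + math.cos(math.pi * where))

    where = it / max_iter
    warmup_len = warmup_iters / max_iter
    if where < warmup_len:
        # ramp linearly from warmup_factor to the cosine value where warmup ends
        t = where / warmup_len
        start, end = warmup_factor * cosine(0.0), cosine(warmup_len)
        return base_lr * (start + t * (end - start))
    return base_lr * cosine(where)


# Compare with the lr column at iter 19/39/59: 0.00028711, 0.00058828, 0.00088945
for it in (19, 39, 59):
    print(it, f"{warmup_cosine_lr(it):.8f}")
```

Because the warmup ends on the cosine curve rather than at `BASE_LR`, the peak learning rate here is about 0.0009, not 0.001.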

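Now, we perform inference with the trained model on images from the tree dataset. Since only `treeMe` is registered, these are training images, so the results say little about how well the model generalizes. The cell below loads `model_final.pth`, which `DefaultTrainer` writes to `cfg.OUTPUT_DIR` only when training reaches `MAX_ITER`; if training was interrupted, that file may be missing or left over from an earlier run. First, let's create a predictor using the model we just trained: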



In [None]:
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set the testing threshold for this model
cfg.DATASETS.TEST = ("treeMe", )
predictor = DefaultPredictor(cfg)

[02/10 16:49:56 d2.checkpoint.detection_checkpoint]: [DetectionCheckpointer] Loading from ./output/model_final.pth ...


Then, we randomly select several samples to visualize the prediction results.

In [None]:
from detectron2.utils.visualizer import ColorMode

for d in random.sample(dataset_dicts, 3):
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1],
                   metadata=tree_metadata,
                   scale=0.3,
                   instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels
    )
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2_imshow(v.get_image()[:, :, ::-1])

In [None]:
tree_metadata

## Benchmark inference speed

In [None]:
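import time
import torch

# `im` is the last image from the visualization loop above, which also served as a warm-up.
times = []
for i in range(20):
    start_time = time.perf_counter()
    outputs = predictor(im)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    delta = time.perf_counter() - start_time
    times.append(delta)
mean_delta = np.array(times).mean()
fps = 1 / mean_delta
print("Average(sec):{:.2f},fps:{:.2f}".format(mean_delta, fps))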

Average(sec):0.25,fps:3.97
