In [1]:
import detectron2
import numpy as np
import os, json, cv2, random
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

In [None]:
# Prepare the dataset as Detectron2 dataset dicts (COCO-style boxes and RLE masks)
import pycocotools
from detectron2.structures import BoxMode

def get_data_dicts(img_dir):
    """Build Detectron2 dataset dicts for the instance-segmentation samples in ``img_dir``.

    Each ``<stem>_instances.png`` file in ``img_dir`` defines one sample. The sample
    also needs ``<stem>_color.jpg`` (the input image) and ``<stem>_class_labels.png``
    (per-pixel class labels) in the same directory. In the instance image, pixel
    value 0 means "no instance"; every other value is one object.

    Args:
        img_dir: Directory containing the image triplets.

    Returns:
        list[dict]: One record per usable sample with keys ``file_name``,
        ``image_id`` (index of the stem in sorted order), ``height``, ``width``
        and ``annotations``. Each annotation has an ``XYXY_ABS`` ``bbox`` (x1/y1
        exclusive), an RLE ``segmentation`` and ``category_id`` = class label - 1.

    Samples with missing/unreadable files or mismatched instance/label image
    shapes are skipped with a warning. Instances whose class label is 0 are
    also skipped with a warning.
    """
    import warnings

    # `import pycocotools` alone does not load the `mask` submodule; import it
    # explicitly instead of relying on another library having imported it.
    from pycocotools import mask as mask_util

    suffix = '_instances.png'
    stems = [name[:-len(suffix)]
             for name in sorted(os.listdir(img_dir)) if name.endswith(suffix)]

    dataset_dicts = []
    for idx, stem in enumerate(stems):
        filename = os.path.join(img_dir, stem + '_color.jpg')
        color_img = cv2.imread(filename)
        instance_img = cv2.imread(os.path.join(img_dir, stem + suffix), -1)
        label_img = cv2.imread(os.path.join(img_dir, stem + '_class_labels.png'), -1)

        # cv2.imread returns None (it does not raise) for missing or unreadable files.
        if color_img is None or instance_img is None or label_img is None:
            warnings.warn(f"Skipping sample {stem!r}: missing or unreadable image file(s).")
            continue
        # Masks are built from 2-D instance ids and used to index the label image,
        # so both must be single-channel images of the same size.
        if instance_img.ndim != 2 or label_img.shape != instance_img.shape:
            warnings.warn(f"Skipping sample {stem!r}: instance image {instance_img.shape} "
                          f"and class-label image {label_img.shape} must be matching 2-D arrays.")
            continue

        height, width = color_img.shape[:2]

        objs = []
        for inst_id in np.unique(instance_img):
            if inst_id == 0:  # 0 marks pixels that belong to no instance
                continue
            mask = instance_img == inst_id
            ys, xs = np.nonzero(mask)

            # Assumes every pixel of an instance carries the same class label; the
            # first mask pixel (row-major order) is used. Convert to int before
            # subtracting: on an unsigned image dtype, 0 - 1 would wrap to 255/65535.
            category_id = int(label_img[ys[0], xs[0]]) - 1
            if category_id < 0:
                warnings.warn(f"Skipping instance {int(inst_id)} in sample {stem!r}: class label is 0.")
                continue

            objs.append({
                # +1 makes x1/y1 exclusive so the box spans the mask's full pixel
                # extent (a one-pixel object gets a 1x1 box, not a 0x0 one).
                "bbox": [int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1],
                "bbox_mode": BoxMode.XYXY_ABS,
                # pycocotools expects a Fortran-ordered uint8 mask.
                "segmentation": mask_util.encode(np.asfortranarray(mask, dtype=np.uint8)),
                "category_id": category_id,
            })

        dataset_dicts.append({
            "file_name": filename,
            "image_id": idx,
            "height": height,
            "width": width,
            "annotations": objs,
        })
    return dataset_dicts

In [3]:
# Drop any earlier registrations so the registration cell below can be re-run
# in the same kernel without an "already registered" error.
MetadataCatalog.clear()
DatasetCatalog.clear()

In [4]:
# Register a zero-argument loader for the training directory under the name "bench".
d = "bench_RV_train"


def _load_bench_dicts(img_dir=d):
    """Return the dataset dicts for the directory bound when this loader was defined."""
    return get_data_dicts(img_dir)


DatasetCatalog.register("bench", _load_bench_dicts)

In [5]:
# check that the dataset is registered (uncomment to load it)
# DatasetCatalog.get("bench")

In [6]:
# set parameters

from detectron2.engine import DefaultTrainer

# Start from the model-zoo Mask R-CNN R50-FPN 3x config, then apply overrides.
# merge_from_file overwrites every key the YAML sets, so all manual overrides
# are applied after it.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.OUTPUT_DIR = "output_cocockpt_RV"
# get_data_dicts produces RLE-encoded masks, which Detectron2 only accepts with
# the 'bitmask' mask format (the default 'polygon' expects polygon lists).
cfg.INPUT.MASK_FORMAT = 'bitmask'

# Data: train on the registered "bench" split; no evaluation split.
cfg.DATASETS.TRAIN = ("bench",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2

# Fine-tune from the matching COCO checkpoint.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

# Solver: no LR decay steps (STEPS is empty); checkpoint every 1000 iterations.
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.0001
cfg.SOLVER.MAX_ITER = 5000
cfg.SOLVER.CHECKPOINT_PERIOD = 1000
cfg.SOLVER.STEPS = []

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
# A single foreground class: since category_id is (class label - 1), every
# instance's class label must be 1.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

In [7]:
# start training

# Create the output directory first; the trainer writes checkpoints and logs
# under cfg.OUTPUT_DIR.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# resume=False: load cfg.MODEL.WEIGHTS (the COCO checkpoint) and start from
# iteration 0 instead of resuming from a checkpoint in cfg.OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()

[32m[02/19 01:30:16 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

model_final_f10217.pkl: 178MB [01:07, 2.64MB/s]                              
Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due t

[32m[02/19 01:33:07 d2.engine.train_loop]: [0mStarting training from iteration 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[02/19 01:33:14 d2.utils.events]: [0m eta: 0:15:44  iter: 19  total_loss: 1.624  loss_cls: 0.6561  loss_box_reg: 0.2449  loss_mask: 0.6859  loss_rpn_cls: 0.0307  loss_rpn_loc: 0.01489    time: 0.1879  last_time: 0.1668  data_time: 0.1312  last_data_time: 0.0019   lr: 1.9981e-06  max_mem: 3344M
[32m[02/19 01:33:19 d2.utils.events]: [0m eta: 0:15:52  iter: 39  total_loss: 1.623  loss_cls: 0.6377  loss_box_reg: 0.2414  loss_mask: 0.6866  loss_rpn_cls: 0.02864  loss_rpn_loc: 0.01563    time: 0.1867  last_time: 0.1912  data_time: 0.0017  last_data_time: 0.0017   lr: 3.9961e-06  max_mem: 3351M
[32m[02/19 01:33:23 d2.utils.events]: [0m eta: 0:15:46  iter: 59  total_loss: 1.522  loss_cls: 0.5631  loss_box_reg: 0.2322  loss_mask: 0.6743  loss_rpn_cls: 0.01218  loss_rpn_loc: 0.007948    time: 0.1851  last_time: 0.1919  data_time: 0.0016  last_data_time: 0.0018   lr: 5.9941e-06  max_mem: 3351M
[32m[02/19 01:33:26 d2.utils.events]: [0m eta: 0:15:23  iter: 79  total_loss: 1.443  loss_cl