## Helpful Links for Detectron2
- Guide to custom data training - https://www.analyticsvidhya.com/blog/2021/08/your-guide-to-object-detection-with-detectron2-in-pytorch/  
- Detectron2 configuration documentation - https://detectron2.readthedocs.io/en/latest/modules/config.html 
- GitHub repository - https://github.com/facebookresearch/detectron2 
- Another custom data training guide - https://towardsdatascience.com/train-maskrcnn-on-custom-dataset-with-detectron2-in-4-steps-5887a6aa135d


### Possible help for errors
- https://stackoverflow.com/questions/69002169/json-annotations-error-string-indices-must-be-integers
- https://stackoverflow.com/questions/63012735/typeerror-string-indices-must-be-integers-while-trying-to-train-mask-rcnn-imple


In [1]:
import json
from detectron2.data import MetadataCatalog, DatasetCatalog


def load_data(t="train", data_root="../data/detectron2"):
    """Load the annotation records for one dataset split from its JSON file.

    Args:
        t: Split name, either ``"train"`` or ``"val"``.
        data_root: Directory containing the ``training/`` and ``validation/``
            sub-folders. The default is the location this notebook uses.

    Returns:
        Whatever ``json.load`` yields for the split's file. The result is
        handed to ``DatasetCatalog`` elsewhere in this notebook, so it is
        presumably a list of Detectron2 dataset dicts (verify against the
        JSON files).

    Raises:
        ValueError: If ``t`` is not a known split name. Previously an unknown
            split fell through to ``return val`` and raised an opaque
            ``UnboundLocalError``.
    """
    # Path of each split's annotation file, relative to ``data_root``.
    split_files = {
        "train": "training/training.json",
        "val": "validation/validation.json",
    }
    if t not in split_files:
        raise ValueError(
            f"Unknown dataset split {t!r}; expected one of {sorted(split_files)}"
        )

    with open(f"{data_root}/{split_files[t]}", 'r') as file:
        return json.load(file)

In [2]:
from detectron2.config import get_cfg
from detectron2 import model_zoo
import os


def custom_config(num_classes):
    """Build the Mask R-CNN (R50-FPN, 3x) config used for training in this notebook.

    Starts from the model-zoo config and checkpoint URL for
    ``COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml`` and overrides the
    dataset names, solver settings, ROI-head sizes, device and output folder.

    Args:
        num_classes: Value assigned to ``MODEL.ROI_HEADS.NUM_CLASSES``.

    Returns:
        The populated config object created by ``get_cfg()``.

    Side effects:
        Creates the output directory (``mask_worms``) if it does not exist.
    """
    # The same model-zoo entry supplies both the base config and the weights.
    zoo_entry = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

    config = get_cfg()
    config.merge_from_file(model_zoo.get_config_file(zoo_entry))

    # --- model ---
    config.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_entry)
    config.MODEL.MASK_ON = True
    config.MODEL.DEVICE = 'cpu'
    config.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    # faster, enough for this dataset (default: 512)
    config.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 32

    # --- data: names must match those registered with DatasetCatalog ---
    config.DATASETS.TRAIN = ("train",)
    config.DATASETS.TEST = ("val",)
    config.DATALOADER.NUM_WORKERS = 2

    # --- solver ---
    config.SOLVER.IMS_PER_BATCH = 1
    config.SOLVER.BASE_LR = 0.001
    config.SOLVER.MAX_ITER = 5000

    # --- output ---
    config.OUTPUT_DIR = "mask_worms"
    os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    return config

In [3]:
# Register both splits with Detectron2 and attach the class names.
# DatasetCatalog.register refuses a name that is already registered, so the
# guard keeps this cell safe to re-run in the same kernel.
for d in ["train", "val"]:
    if d not in DatasetCatalog.list():
        # `d=d` freezes the current split name; a plain closure would see only
        # the loop's final value when Detectron2 calls the loader later.
        DatasetCatalog.register(d, lambda d=d: load_data(d))
    MetadataCatalog.get(d).set(thing_classes=["abw", "pbw"])

In [4]:
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.engine import DefaultTrainer
    
metadata = MetadataCatalog.get("train")

# Derive the class count from the registered class names instead of repeating
# it as a literal, so the model head cannot drift out of sync with
# `thing_classes`.
cfg = custom_config(len(metadata.thing_classes))


trainer = DefaultTrainer(cfg)
# resume=False: do not pick up a previous run's checkpoint; the captured log
# below shows training starting from iteration 0.
trainer.resume_or_load(resume=False)
trainer.train()

[32m[11/22 16:40:16 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (3, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (3,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (8, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (8,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (2, 256, 1, 1) in

[32m[11/22 16:40:17 d2.engine.train_loop]: [0mStarting training from iteration 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[11/22 16:40:48 d2.utils.events]: [0m eta: 1:58:09  iter: 19  total_loss: 2.717  loss_cls: 1.004  loss_box_reg: 0.7893  loss_mask: 0.6932  loss_rpn_cls: 0.1186  loss_rpn_loc: 0.03198  time: 1.4876  data_time: 0.0942  lr: 1.9981e-05  
[32m[11/22 16:41:20 d2.utils.events]: [0m eta: 2:08:09  iter: 39  total_loss: 2.441  loss_cls: 0.7954  loss_box_reg: 0.85  loss_mask: 0.6641  loss_rpn_cls: 0.0456  loss_rpn_loc: 0.03639  time: 1.5574  data_time: 0.0007  lr: 3.9961e-05  
[32m[11/22 16:41:52 d2.utils.events]: [0m eta: 2:08:51  iter: 59  total_loss: 2.245  loss_cls: 0.6259  loss_box_reg: 0.8866  loss_mask: 0.6231  loss_rpn_cls: 0.02205  loss_rpn_loc: 0.02335  time: 1.5660  data_time: 0.0006  lr: 5.9941e-05  
[32m[11/22 16:42:26 d2.utils.events]: [0m eta: 2:10:25  iter: 79  total_loss: 2.085  loss_cls: 0.5141  loss_box_reg: 0.8405  loss_mask: 0.5803  loss_rpn_cls: 0.03938  loss_rpn_loc: 0.03102  time: 1.5944  data_time: 0.0007  lr: 7.9921e-05  
[32m[11/22 16:42:57 d2.utils.events]

In [None]:
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer, ColorMode
import matplotlib.pyplot as plt
import cv2
import os


def visualization(metadata, cfg, test_set):
    """Run the trained model on each record and save the rendered predictions.

    One PNG per record is written to ``<cfg.OUTPUT_DIR>/visualization/``,
    named after the record's ``image_id``.

    Args:
        metadata: Metadata object passed to ``Visualizer``.
        cfg: Config object. NOTE: it is modified in place -- ``MODEL.WEIGHTS``
            is pointed at ``<OUTPUT_DIR>/model_final.pth`` and
            ``MODEL.ROI_HEADS.SCORE_THRESH_TEST`` is set to 0.8 -- so the
            caller's config keeps these values afterwards.
        test_set: Iterable of dicts, each with ``"file_name"`` and
            ``"image_id"`` keys.

    Raises:
        FileNotFoundError: If an image listed in ``test_set`` cannot be read.
    """
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
    predictor = DefaultPredictor(cfg)

    # plt.imsave does not create missing folders, so make sure the target exists.
    vis_dir = os.path.join(cfg.OUTPUT_DIR, 'visualization')
    os.makedirs(vis_dir, exist_ok=True)

    for d in test_set:
        im = cv2.imread(d["file_name"])
        # cv2.imread signals an unreadable/missing file by returning None.
        if im is None:
            raise FileNotFoundError(f"Could not read image: {d['file_name']}")

        outputs = predictor(im)

        # OpenCV loads images as BGR; reverse the channel axis for Visualizer.
        v = Visualizer(im[:, :, ::-1],
                       metadata=metadata,
                       scale=0.5,
                       instance_mode=ColorMode.IMAGE_BW
                       )
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

        # get_image() is already RGB, which is what plt.imsave expects. The
        # previous extra channel flip stored the file with red and blue swapped.
        img = out.get_image()
        plt.imsave(os.path.join(vis_dir, str(d["image_id"]) + '.png'), img)