#### Imports

In [None]:
# torch
import torch
import torchvision

# Detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# Common Libraries
import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt
# from google.colab.patches import cv2_imshow

# Detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances

#### Register and check datasets

In [None]:
# Step from the notebook's folder into the sibling `datasets` folder; the
# dataset and model paths built in later cells are relative to this directory.
datasets_dir = os.path.join(os.getcwd(), '..', 'datasets')
os.chdir(datasets_dir)

In [None]:
from itertools import product

# Register every (split, source) pair as a COCO-format dataset named
# `<source>_<split>`, backed by ./<source>_<split>/data (images) and
# ./<source>_<split>/labels.json (annotations).
for d, ds in product(["train", "val"], ['duke']):
    # Sources named here are never registered.
    if ds == 'overfit':
        continue

    ds_name = f'{ds}_{d}'
    ds_root = os.path.join(os.getcwd(), ds_name)
    ds_path = os.path.join(ds_root, 'data')
    json_path = os.path.join(ds_root, 'labels.json')

    # Forget any earlier registration under this name so the cell can be
    # re-run without colliding with the previous entry.
    if ds_name in DatasetCatalog.list():
        DatasetCatalog.remove(ds_name)
        MetadataCatalog.remove(ds_name)

    register_coco_instances(ds_name, {}, json_path, ds_path)

In [None]:
# Registered names of the splits used for training and validation.
train_dataset, val_dataset = 'duke_train', 'duke_val'
# Metadata handles for each split (train looked up first, then val).
train_metadata, val_metadata = (
    MetadataCatalog.get(train_dataset),
    MetadataCatalog.get(val_dataset),
)

In [None]:
%matplotlib inline
dataset_dicts = DatasetCatalog.get(train_dataset)
tower_metadata = MetadataCatalog.get(train_dataset)

num_examples = 5
for d in random.sample(dataset_dicts, num_examples):
    img = cv2.imread(d["file_name"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=tower_metadata, scale=0.8, instance_mode=1)
    out = visualizer.draw_dataset_dict(d)
    # cv2.imshow(out.get_image()[:, :, ::-1])
    plt.imshow(out.get_image()[:, :, ::])

#### Setup training config

In [None]:
# Shell out to nvidia-smi to confirm a GPU is visible before training
# (IPython `!` shell-escape syntax; only valid inside a notebook).
!nvidia-smi

In [None]:
from detectron2.engine import DefaultTrainer
from datetime import date

cfg = get_cfg()  # Model Config

frcnn = 'faster_rcnn_R_101_FPN_3x.yaml'
current = frcnn
# Model-zoo key shared by the config file and its matching pre-trained weights.
zoo_config = "COCO-Detection/" + current

# From Detectron2 Model Zoo:
# https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml
cfg.merge_from_file(model_zoo.get_config_file(zoo_config))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_config)  # Pre-trained Model Weights

# Set Dataset
cfg.DATASETS.TRAIN = (train_dataset,)  # Training Dataset
# NOTE(review): DefaultTrainer does not define build_evaluator, so listing a
# TEST dataset here does not by itself yield validation metrics; a trainer
# subclass providing an evaluator is needed for that — TODO confirm.
cfg.DATASETS.TEST = (val_dataset,)   # Validation Dataset

cfg.DATALOADER.NUM_WORKERS = 2  # Data-loading worker processes (2 suits Colab)
cfg.SOLVER.IMS_PER_BATCH = 2    # Detectron2 default 16 with 8 GPUs, so 16/8 = 2 for 1 GPU

# Learning Rate
# Derivation used here: RetinaNet's default BASE_LR = 0.01 on 8 GPUs; linear
# scaling to 1 GPU gives 0.01/8 = 0.00125, divided by 10 for gentler
# fine-tuning -> 0.000125.
# NOTE(review): the model loaded above is Faster R-CNN, whose base zoo config
# (Base-RCNN-FPN.yaml) appears to use BASE_LR = 0.02 — verify; the same rule
# would then give 0.00025. Value intentionally left unchanged.
cfg.SOLVER.BASE_LR = 0.000125
cfg.SOLVER.MAX_ITER = 1_100_000
cfg.SOLVER.STEPS = []   # no step-wise learning-rate decay

# Number of object classes in the custom dataset, for R-CNN models
# (a RetinaNet model would set cfg.MODEL.RETINANET.NUM_CLASSES instead).
# See https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2

# RoI (region of interest) heads sample a subset of the proposals produced by
# the RPN (Region Proposal Network) and compute classification and
# box-regression losses on that subset only, since scoring every proposal is
# wasteful. This is the per-image sample size during training; the total per
# minibatch is ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   # same as the default (512)

# Run name: project prefix, today's date, architecture, iteration budget, dataset.
model_name = f'PISA_{date.today()}_frcnn_{cfg.SOLVER.MAX_ITER}_iters_duke'

# Checkpoints and logs for this run go to ../models/<model_name>, relative to
# the current working directory.
output_dir = os.path.join(os.getcwd(), '..', 'models', model_name)
os.makedirs(output_dir, exist_ok=True)
cfg.OUTPUT_DIR = output_dir

# Build the trainer and load initial weights from cfg.MODEL.WEIGHTS;
# resume=False means any checkpoint already in OUTPUT_DIR is not resumed.
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)

In [None]:
# Run the training loop configured above (cfg.SOLVER.MAX_ITER iterations);
# outputs presumably land in cfg.OUTPUT_DIR — confirm against DefaultTrainer hooks.
trainer.train()