In [None]:
# Download and extract the gas-meter dataset, skipping everything if ./MeterDataset already exists.
# `wget -nc` does not re-download an existing zip; `unzip gas_meter_reading` resolves to the downloaded .zip.
# NOTE(review): --no-check-certificate only matters if the plain-http URL redirects to https — confirm it is needed.
! [ -d "MeterDataset" ] && echo "skipping" || (wget -nc --no-check-certificate "http://artelab.dista.uninsubria.it/downloads/datasets/automatic_meter_reading/gas_meter_reading/gas_meter_reading.zip" && unzip gas_meter_reading -d .)

In [None]:
import os

# Enumerate GPUs by PCI bus id and expose only device 0 to this kernel.
# These variables should be set before any library initialises CUDA.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [None]:
import pathlib
from sap_computer_vision.datasets import pascal_voc_style as pvs

# Pascal-VOC style layout: images in JPEGImages/, XML annotations in Annotations/.
dataset_folder = pathlib.Path('MeterDataset/Rough-Digit-Classification/')

# Split the annotated images 80/10/10 and register the splits under the 'dataset'
# prefix (the config cell later refers to 'dataset_train' and 'dataset_validation').
# d maps split name -> records; c presumably holds the class names (used for NUM_CLASSES) — confirm.
d, c = pvs.split_and_register(
    'dataset',
    img_dir=dataset_folder / 'JPEGImages',
    xml_dir=dataset_folder / 'Annotations',
    splits=dict(train=0.8, test=0.1, validation=0.1),
)

In [None]:
# Report how many examples ended up in each split.
for split_name, split_records in d.items():
    print(f'{split_name} {len(split_records)} examples')

In [None]:
import shutil

from sap_computer_vision import setup_loggers, get_cfg, get_config_file
import numpy as np

# Every run starts from an empty output folder: results of a previous run with
# the same name are deleted without confirmation.
out_dir = 'object_detecion_model_modified_loader'
if pathlib.Path(out_dir).exists():
    shutil.rmtree(out_dir)
    # To protect earlier results instead of deleting them, replace the rmtree above with:
    #raise RuntimeError('Result folder already exists. Please delete the folder or change the name of the output')

setup_loggers(out_dir)
# Later merges override keys set by earlier ones, so the Faster R-CNN model config is applied last.
cfg = get_cfg()
cfg.merge_from_file(get_config_file('Base-EarlyStopping'))
cfg.merge_from_file(get_config_file('Base-Evaluation'))
cfg.merge_from_file(get_config_file('COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml'))

# Short schedule. GAMMA = sqrt(0.1), so two consecutive LR steps together multiply the LR by 0.1.
cfg.SOLVER.MAX_ITER = 100
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.GAMMA = float(np.sqrt(0.1))
cfg.SOLVER.EARLY_STOPPING.ENABLED = False

# At least 50 warm-up iterations, i.e. half of the run for MAX_ITER = 100.
cfg.SOLVER.WARMUP_ITERS = max(int(0.01 * cfg.SOLVER.MAX_ITER), 50)
# LR milestones as fractions of MAX_ITER. The values are floats
# (5.0, 25.0, 37.5, 50.0, 75.0, 90.0 for MAX_ITER = 100).
# NOTE(review): confirm non-integer iteration milestones are intended.
cfg.SOLVER.STEPS = [cfg.SOLVER.MAX_ITER * p for p in (0.05, 0.25, 0.375, 0.5, 0.75, 0.9)]
# Switch on each listed augmentation, but only if this config defines it.
for aug in ['RANDOM_LIGHTING', 'RANDOM_BRIGHTNESS', 'RANDOM_SATURATION', 'RANDOM_CONTRAST', 'CROP', 'RANDOM_ROTATION']:
    if cfg.INPUT.get(aug, None) is not None:
        cfg.INPUT[aug].ENABLED = True

# No flipping — presumably because mirrored digits are not valid digits.
cfg.INPUT.RANDOM_FLIP = "none"


# Assumes `c` from split_and_register holds one entry per class — confirm.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(c)


cfg.OUTPUT_DIR = out_dir
cfg.DATASETS.TRAIN = ('dataset_train', )
cfg.DATASETS.TEST = ('dataset_validation', )

# Adjust to hardware (batch sizes and number of data-loader workers)
cfg.SOLVER.IMS_PER_BATCH = 8  # training images per iteration
cfg.SOLVER.IMS_PER_BATCH_EVAL = 12
cfg.DATALOADER.NUM_WORKERS = 10

# NOTE(review): larger than MAX_ITER (100), so no periodic evaluation is triggered
# during this run — confirm this is intended.
cfg.TEST.EVAL_PERIOD = 250


# Add new Option for the custom Dataloader:
# read by ObjectDetectionTrainerPatchedImages.build_train_loader as the
# per-annotation digit replacement probability during training.
from detectron2.config import CfgNode
cfg.METER_READING = CfgNode({})
cfg.METER_READING.DIGIT_REPLACE_PROB_TRAIN = 0.8

In [None]:
# Re-asserts that early stopping is disabled (already set in the configuration cell);
# this must run before the config is written to used_config.yaml.
cfg.SOLVER.EARLY_STOPPING.ENABLED = False

In [None]:
# Make sure the output folder exists (the config cell may have just deleted it),
# then store the exact configuration used for this run next to its results.
out_dir = pathlib.Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
config_path = out_dir / 'used_config.yaml'
with open(config_path, 'w') as config_file:
    config_file.write(cfg.dump())

In [None]:
from sap_computer_vision.engine import ObjectDetectionTrainer

In [None]:
import copy
import logging
import numpy as np
from typing import List, Optional, Union, Dict
import torch
from PIL import Image

from detectron2.data.build import get_detection_dataset_dicts
from detectron2.config import configurable

from sap_computer_vision.data.data_build import DatasetMapperWithAdditionalAugmentaions
from detectron2.data.build import build_detection_train_loader
from detectron2.data import detection_utils as utils

from detectron2.data import transforms as T
from detectron2.structures.boxes import pairwise_iou, BoxMode


def cut_and_paste(img_src, img_dest, anno_src, anno_dest, image_format):
    """Paste the digit boxed by ``anno_src`` in ``img_src`` over the box of ``anno_dest`` in ``img_dest``.

    The source crop is resized to the destination box, so it covers that box exactly.
    On success, ``anno_dest['category_id']`` is overwritten with the source category.
    Both ``img_dest`` and ``anno_dest`` are modified in place.

    Args:
        img_src: PIL image the digit is cut from.
        img_dest: PIL image that receives the digit.
        anno_src: annotation dict with ``bbox``, ``bbox_mode`` and ``category_id`` of the source digit.
        anno_dest: annotation dict (``bbox``, ``bbox_mode``) of the region being replaced.
        image_format: unused; kept so existing call sites keep working.

    Returns:
        ``img_dest`` (the same object). It is left untouched, as is ``anno_dest``,
        when either box is empty after rounding to whole pixels.
    """
    src_box = [int(v) for v in BoxMode.convert(anno_src['bbox'], anno_src['bbox_mode'], BoxMode.XYXY_ABS)]
    x0, y0, x1, y1 = [int(v) for v in BoxMode.convert(anno_dest['bbox'], anno_dest['bbox_mode'], BoxMode.XYXY_ABS)]
    # Take the target size from the *integer* destination box. Truncating the float
    # width/height separately can be one pixel off, and Image.paste with a 4-tuple
    # box raises when the pasted image does not match the box size.
    width, height = x1 - x0, y1 - y0
    if width <= 0 or height <= 0 or src_box[2] <= src_box[0] or src_box[3] <= src_box[1]:
        # Degenerate box: nothing sensible to paste (resize to 0 pixels would fail).
        return img_dest
    region = img_src.crop(src_box).resize((width, height))
    img_dest.paste(region, (x0, y0, x1, y1))
    # Relabel only after the pixels were actually replaced.
    anno_dest['category_id'] = anno_src['category_id']
    return img_dest


class DatasetMapperPatchwork(DatasetMapperWithAdditionalAugmentaions):
    """Dataset mapper that randomly swaps annotated digits for digits cut from other images.

    Each annotation of an image is, with probability ``digit_replace_prob``, overwritten
    by a crop of a random annotation taken from a random record of ``dataset_dicts``
    (resized to the destination box). The annotation's ``category_id`` is changed to
    the pasted digit's class. The patched image then runs through the regular
    augmentation and annotation-transformation pipeline.
    """

    @configurable
    def __init__(
        self,
        *args,
        dataset_dicts: Optional[List[Dict]] = None,
        digit_replace_prob: float = 0.2,
        **kwargs
    ):
        """
        Args:
            *args, **kwargs: forwarded to the parent mapper's undecorated ``__init__``.
            dataset_dicts: records that patch sources are drawn from. ``None`` (the default)
                means an empty pool, in which case no annotation is ever replaced.
            digit_replace_prob: per-annotation replacement probability.
        """
        # ``None`` instead of a mutable ``{}`` default: a default object would be shared
        # by every instance (and ``{}`` was the wrong type for a list of records).
        self.dataset_dicts = dataset_dicts if dataset_dicts is not None else []
        self.digit_replace_prob = digit_replace_prob
        # Bypass the parent's @configurable wrapper: the arguments here are already in
        # explicit keyword form (as produced by ``from_config``).
        super().__init__.__wrapped__(self, *args, **kwargs)

    @classmethod
    def from_config(cls,
                    cfg,
                    is_train: bool = True,
                    digit_replace_prob: float = 0.2,
                    additional_augs_orignal_image=None,
                    additional_augs_resized_image=None):
        """Translate ``cfg`` into ``__init__`` keyword arguments.

        Adds two entries to the parent's arguments:
        - the patch-source pool: the records of ``cfg.DATASETS.TRAIN`` when ``is_train``,
          otherwise the records of ``cfg.DATASETS.TEST``;
        - ``digit_replace_prob``.
        """
        ret = DatasetMapperWithAdditionalAugmentaions.from_config(
            cfg,
            is_train,
            additional_augs_orignal_image=additional_augs_orignal_image,
            additional_augs_resized_image=additional_augs_resized_image
        )
        source_datasets = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
        ret['dataset_dicts'] = get_detection_dataset_dicts(source_datasets)
        ret['digit_replace_prob'] = digit_replace_prob
        return ret

    def load_and_patch_image(self, dataset_dict, image_format):
        """Load the image of ``dataset_dict`` and randomly replace its annotated digits.

        Args:
            dataset_dict: dataset record with ``file_name`` and optionally ``annotations``;
                it is deep-copied and never mutated.
            image_format: target format passed to ``utils.convert_PIL_to_numpy``.

        Returns:
            ``(image, new_dataset_dict)``: the (possibly patched) image as a numpy array,
            and a copy of the record whose replaced annotations carry the pasted ``category_id``.
        """
        new_dataset_dict = copy.deepcopy(dataset_dict)
        pool = self.dataset_dicts
        # Nothing can be pasted from an empty pool; records without annotations have nothing to replace.
        annotations = new_dataset_dict.get('annotations', []) if len(pool) > 0 else []
        # NOTE(review): unlike detectron2's read_image, no EXIF orientation is applied
        # here — confirm the dataset contains no EXIF-rotated images.
        with Image.open(new_dataset_dict["file_name"]) as image:
            for anno in annotations:
                if np.random.uniform() > self.digit_replace_prob:
                    continue
                # randint(len(x)) draws the same index as np.random.choice(x), without
                # converting the whole pool into an object array on every draw.
                source = pool[np.random.randint(len(pool))]
                source_annos = source.get('annotations') or []
                if not source_annos:
                    continue
                new_anno = source_annos[np.random.randint(len(source_annos))]
                # Close every source file once its digit has been cut out.
                with Image.open(source["file_name"]) as source_image:
                    image = cut_and_paste(source_image, image, new_anno, anno, image_format)
            img_dest = utils.convert_PIL_to_numpy(image, format=image_format)
        return img_dest, new_dataset_dict

    def __call__(self, dataset_dict):
        """Map one dataset record to the model's input format, with digit patching.

        Follows detectron2's ``DatasetMapper.__call__``, but obtains the image from
        ``load_and_patch_image``. Code copied from:
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py
        """
        image, dataset_dict = self.load_and_patch_image(dataset_dict, self.image_format)
        utils.check_image_size(dataset_dict, image)
        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]
        # HWC numpy image -> CHW tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

        return dataset_dict

    
class ObjectDetectionTrainerPatchedImages(ObjectDetectionTrainer):
    """``ObjectDetectionTrainer`` whose training loader uses ``DatasetMapperPatchwork``."""

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader with random digit patching.

        The per-annotation replacement probability comes from
        ``cfg.METER_READING.DIGIT_REPLACE_PROB_TRAIN``.
        """
        # The former `additional_augs_*` locals were never passed to the mapper and
        # have been removed. The mapper's from_config defaults them to None.
        mapper = DatasetMapperPatchwork(
            cfg,
            is_train=True,
            digit_replace_prob=cfg.METER_READING.DIGIT_REPLACE_PROB_TRAIN,
        )
        return build_detection_train_loader(cfg, mapper=mapper)


In [None]:
from matplotlib import pyplot as plt

# Sanity check: push one random training record through the patching mapper and show the result.
mapper = DatasetMapperPatchwork(cfg, is_train=True, digit_replace_prob=0.4)
random_example = np.random.choice(mapper.dataset_dicts)
mapped_example = mapper(random_example)

# The mapper yields a CHW tensor in cfg.INPUT.FORMAT; matplotlib wants HWC RGB.
hwc_image = mapped_example['image'].numpy().transpose(1, 2, 0)
rgb_numpy_img = utils.convert_image_to_rgb(hwc_image, cfg.INPUT.FORMAT)
plt.imshow(rgb_numpy_img)
# Class ids of the (possibly replaced) digits, for comparison with the picture.
print(mapped_example['instances'].gt_classes.tolist())

In [None]:
# Trainer whose training loader applies random digit patching (see DatasetMapperPatchwork).
# resume=False: start from the configured weights instead of a checkpoint in cfg.OUTPUT_DIR
# (detectron2 DefaultTrainer semantics — assumes ObjectDetectionTrainer does not change them).
trainer = ObjectDetectionTrainerPatchedImages(cfg)
trainer.resume_or_load(resume=False)

In [None]:
# Start training with the configuration built above.
trainer.train()

In [None]:
from detectron2.data import MetadataCatalog

# Loader over the datasets configured in cfg.DATASETS.TEST (the validation split in this notebook).
test_dataset_names = cfg.DATASETS.TEST
dl = trainer.build_test_loader(cfg, test_dataset_names)
# Metadata (e.g. class names) of the first test dataset, used by the Visualizer later.
metadata = MetadataCatalog.get(test_dataset_names[0])

In [None]:
from sap_computer_vision.utils.object_detection import torchvision_nms_on_model_output

iou_threshold = 0.5


results = []
# Switch to inference mode once and disable autograd for the whole pass;
# both were previously re-applied on every batch.
trainer.model.eval()
with torch.no_grad():
    for batch in dl:
        results_batch = trainer.model(batch)
        # Class-agnostic-or-not behaviour and output device are defined by this helper;
        # device='cpu' presumably returns CPU instances — confirm.
        results_batch = torchvision_nms_on_model_output(results_batch, iou_threshold, device='cpu')
        for record in batch:
            # Drop the decoded image tensor so `results` stays small.
            record.pop('image', None)
        # Merge each input record (file name, size, ...) with its predictions.
        results.extend([{**record, **output} for record, output in zip(batch, results_batch)])

In [None]:
from detectron2.data.detection_utils import read_image
from detectron2.utils.visualizer import Visualizer

# Pick one evaluated image at random and draw its NMS-filtered predictions on it.
# randint(n) draws the same index as choice(range(n)) without building a list.
result_idx = np.random.randint(len(results))
# Visualizer expects an RGB array; request it explicitly so grayscale or other
# non-RGB files are converted instead of being passed through in their native mode.
vis = Visualizer(img_rgb=read_image(results[result_idx]['file_name'], format='RGB'), metadata=metadata)
vis.draw_instance_predictions(results[result_idx]['instances']).fig