# CS 3890 - Fall 2021
### Rong Wang
The MIDOG challenge is a challenge to detect mitotic figures in medical images so that detection is robust no matter what scanner is used. This is the challenge of domain generalization in microscopy images. <br>
This is my attempt to adapt the MIDOG team's baseline domain-adversarial model to run in a Google Colab notebook. <br>
Currently, this code does not successfully distinguish 'hard negatives' and 'mitotic figures' due to bugs within the code. <br>
Much of the code is from the MIDOG team's baseline domain adversarial model, which can be found at https://github.com/DeepPathology/MIDOG. 
Other parts of the code are referenced from the MIDOG team's MIDOG_ObjectDetecton_101.ipynb. <br>
The MIDOG challenge can be found here: https://imi.thi.de/midog/. 

@article{marzahl2020deep,
  title={Deep learning-based quantification of pulmonary hemosiderophages in cytology slides},
  author={Marzahl, Christian and Aubreville, Marc and Bertram, Christof A and Stayt, Jason and Jasensky, Anne-Katherine and Bartenschlager, Florian and Fragoso-Garcia, Marco and Barton, Ann K and Elsemann, Svenja and Jabari, Samir and Jens, Krauth and Prathmesh, Madhu and Jörn, Voigt and Jenny, Hill and Robert, Klopfleisch and Andreas, Maier },
  journal={Scientific Reports},
  volume={10},
  number={1},
  pages={1--10},
  year={2020},
  publisher={Nature Publishing Group}
}

In [None]:
#@title Import some python packages { vertical-output: true, display-mode: "form" }

%reload_ext autoreload
%autoreload 2
%matplotlib inline

!pip install -U plotly

import json
from pathlib import Path
import plotly
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm import tqdm
import pandas as pd
import random
import cv2

In [None]:
from google.colab import drive
# Mount Google Drive at /drive so the dataset folder below is reachable.
drive.mount('/drive')

folder = "MyDrive/Colab Notebooks/MIDOG/Github_Data" #@param {type:"string"}
# Dataset location inside the mounted drive.
midog_folder = Path("/drive") / Path(folder)

print(midog_folder)
#@markdown Your output should contain **MIDOG.sqlite** and **MIDOG.json**:
# Non-recursive listing of every entry in the folder whose name contains a dot.
print(list(midog_folder.glob("*.*")))

In [None]:
# Handle whole-slide images
!apt-get install python3-openslide
from openslide import open_slide

In [None]:
# Install the object detection library
!pip install -U object-detection-fastai

from object_detection_fastai.helper.wsi_loader import *
from object_detection_fastai.loss.RetinaNetFocalLoss import RetinaNetFocalLoss
from object_detection_fastai.models.RetinaNet import RetinaNet
from object_detection_fastai.callbacks.callbacks import BBMetrics, PascalVOCMetricByDistance, PascalVOCMetric, PascalVOCMetricByDistance

In [None]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.models.unet import _get_sfs_idxs

# Gradient Reverse Layer
class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates and scales the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scaling factor around for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.alpha
        # One entry per forward() input: `alpha` receives no gradient.
        return reversed_grad, None

# export
class LateralUpsampleMerge(nn.Module):
    """Adds a 1x1-projected lateral feature map (read from `hook.stored`) to the 2x-upsampled input."""

    def __init__(self, ch, ch_lat, hook):
        super().__init__()
        self.hook = hook
        # 1x1 convolution mapping the lateral features from `ch_lat` to `ch` channels.
        self.conv_lat = conv2d(ch_lat, ch, ks=1, bias=True)

    def forward(self, x):
        lateral = self.conv_lat(self.hook.stored)
        upsampled = F.interpolate(x, scale_factor=2)
        return lateral + upsampled

class Discriminator(nn.Module):
    """Domain classifier that sits behind a gradient-reversal layer.

    Three unpadded 3x3 conv stages reduce the channels from `size` to `size // 4`,
    the result is average-pooled to 1x1 and mapped to `n_domains` logits by a
    bias-free linear layer. `alpha` is the factor handed to `GradReverse`.
    """

    def __init__(self, size, n_domains, alpha=1.0):
        super().__init__()
        self.alpha = alpha

        def conv_stage(n_in, n_out):
            # conv -> batch norm -> ReLU -> dropout
            return [
                nn.Conv2d(n_in, n_out, kernel_size=(3, 3), bias=False),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(),
            ]

        self.reducer = nn.Sequential(
            *conv_stage(size, size),
            *conv_stage(size, size // 2),
            *conv_stage(size // 2, size // 4),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.reducer2 = nn.Linear(size // 4, n_domains, bias=False)

        # Re-initialise every convolution with small Gaussian weights.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, x):
        reversed_features = GradReverse.apply(x, self.alpha)
        pooled = self.reducer(reversed_features)
        return self.reducer2(torch.flatten(pooled, 1))

class RetinaNetDA(nn.Module):
    """RetinaNet (https://arxiv.org/abs/1708.02002) with gradient-reversal domain discriminators.

    `forward` returns the list `[class_predictions, box_predictions, d5_logits, sizes]`:
    `d5_logits` is the output of the discriminator `d5` applied to the encoder output and
    `sizes` holds `[height, width]` of every pyramid level that was kept.
    The discriminators `d3` and `d4` are constructed, but their calls in `forward` are
    commented out, so only `d5` contributes to the output.
    """

    def __init__(self, encoder: nn.Module, n_classes, n_domains, final_bias:float=0.,  n_conv:float=4,
                 chs=256, n_anchors=9, flatten=True, sizes=None, imsize = (512, 512)):
        super().__init__()
        self.n_classes, self.flatten = n_classes, flatten
        # Optional whitelist of feature-map widths; see the filter in `forward`.
        self.sizes = sizes
        # Run a dummy input of `imsize` through the encoder to record the hooked output shapes.
        sfs_szs, x, hooks = self._model_sizes(encoder, size=imsize)
        # NOTE(review): `sfs_idxs` is computed but not used anywhere else in this class.
        sfs_idxs = _get_sfs_idxs(sfs_szs)
        self.encoder = encoder
        # Only referenced by the commented-out d3/d4 calls in `forward`.
        self.outputs = hook_outputs(self.encoder[-2:-4:-1])
        # Heads on top of the encoder output; channel count taken from the last recorded shape.
        self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)
        self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)
        self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))
        # Lateral merges read the second- and third-to-last hooked activations.
        self.merges = nn.ModuleList([LateralUpsampleMerge(chs, szs[1], hook)
                                     for szs, hook in zip(sfs_szs[-2:-4:-1], hooks[-2:-4:-1])])
        self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])
        # Classification head (biased by `final_bias`) and box head with 4 outputs per anchor.
        self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs, n_conv=n_conv)
        self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs, n_conv=n_conv)
        self.n_domains = n_domains
        # One discriminator per channel count of the last three recorded shapes.
        self.d3 = Discriminator(sfs_szs[-3][1], n_domains)
        self.d4 = Discriminator(sfs_szs[-2][1], n_domains)
        self.d5 = Discriminator(sfs_szs[-1][1], n_domains)

    def _head_subnet(self, n_classes, n_anchors, final_bias=0., n_conv=4, chs=256):
        "Build `n_conv` conv+ReLU layers followed by a conv producing `n_classes * n_anchors` channels."
        layers = [self._conv2d_relu(chs, chs, bias=True) for _ in range(n_conv)]
        layers += [conv2d(chs, n_classes * n_anchors, bias=True)]
        # The final layer starts with zero weights and a constant bias of `final_bias`.
        layers[-1].bias.data.zero_().add_(final_bias)
        layers[-1].weight.data.fill_(0)
        return nn.Sequential(*layers)

    def _apply_transpose(self, func, p_states, n_classes):
        "Apply `func` to every pyramid level and move the channel axis last, grouped by `n_classes`."
        if not self.flatten:
            # Keep one tensor per level, retaining the spatial layout.
            sizes = [[p.size(0), p.size(2), p.size(3)] for p in p_states]
            return [func(p).permute(0, 2, 3, 1).view(*sz, -1, n_classes) for p, sz in zip(p_states, sizes)]
        else:
            # Flatten each level to (batch, -1, n_classes) and concatenate all levels along dim 1.
            return torch.cat(
                [func(p).permute(0, 2, 3, 1).contiguous().view(p.size(0), -1, n_classes) for p in p_states], 1)

    def _model_sizes(self, m: nn.Module, size:tuple=(256,256), full:bool=True) -> Tuple[Sizes,Tensor,Hooks]:
        "Passes a dummy input through the model to get the various sizes"
        hooks = hook_outputs(m)
        ch_in = in_channels(m)
        x = torch.zeros(1,ch_in,*size)
        # Note: this switches `m` to eval mode.
        x = m.eval()(x)
        res = [o.stored.shape for o in hooks]
        if not full: hooks.remove()
        # Operator precedence makes this a 3-tuple in both cases: (res, x, hooks if full else res).
        return res,x,hooks if full else res

    def _conv2d_relu(self, ni:int, nf:int, ks:int=3, stride:int=1,
                    padding:int=None, bn:bool=False, bias=True) -> nn.Sequential:
        "Create a `conv2d` layer with `nn.ReLU` activation and optional(`bn`) `nn.BatchNorm2d`"
        layers = [conv2d(ni, nf, ks=ks, stride=stride, padding=padding, bias=bias), nn.ReLU()]
        if bn: layers.append(nn.BatchNorm2d(nf))
        return nn.Sequential(*layers)

    def forward(self, x):
        c5 = self.encoder(x)
        # Start with the two levels derived directly from the encoder output, then add a third.
        p_states = [self.c5top5(c5.clone()), self.c5top6(c5)]
        p_states.append(self.p6top7(p_states[-1]))
        # Each merge prepends a level built from the current first level and a hooked encoder activation.
        for merge in self.merges:
            p_states = [merge(p_states[0])] + p_states
        # Smooth the first three levels only.
        for i, smooth in enumerate(self.smoothers[:3]):
            p_states[i] = smooth(p_states[i])
        # Keep only levels whose last dimension is listed in `self.sizes`.
        if self.sizes is not None:
            p_states = [p_state for p_state in p_states if p_state.size()[-1] in self.sizes]
        #d3 = self.d3(self.outputs.stored[1])
        #d4 = self.d4(self.outputs.stored[0])
        # Domain logits from the encoder output (gradient reversal happens inside the discriminator).
        d5 = self.d5(c5)

        return [self._apply_transpose(self.classifier, p_states, self.n_classes),
                self._apply_transpose(self.box_regressor, p_states, 4),
                #d3,
                #d4,
                d5,
                [[p.size(2), p.size(3)] for p in p_states]]

In [None]:
from torch.autograd import Variable
from object_detection_fastai.helper.object_detection_helper import *

class RetinaNetFocalLossDA(nn.Module):
    """Detection loss (box regression + focal classification) plus a domain-classification term.

    The model output is expected to be `[clas_preds, bbox_preds, domain_logits, sizes]`.
    Samples whose domain target equals `unlabeled_domain` are skipped by the detection
    terms and only contribute to the domain term.

    After every call `self.metrics` maps the names in `self.metric_names` to the values
    computed for that batch.

    Args:
        anchors: anchor boxes that predictions and targets are matched against.
        gamma: focusing exponent of the focal loss.
        alpha: class-balancing factor of the focal loss.
        pad_idx: label value used for padding in the class targets.
        reg_loss: regression loss applied to the matched boxes.
        domain_weight: factor applied to the domain term in the returned loss.
        n_domains: number of domains (stored; not used by the loss itself).
        unlabeled_domain: domain id whose samples are excluded from the detection terms.
    """

    # The `reg_loss` annotation is quoted so it is not evaluated at definition time.
    def __init__(self, anchors: Collection[float], gamma: float = 2., alpha: float = 0.25, pad_idx: int = 0,
                 reg_loss: "LossFunction" = F.smooth_l1_loss, domain_weight: float=0.001, n_domains = 2,
                 unlabeled_domain: int = 3):
        super().__init__()
        self.gamma, self.alpha, self.pad_idx, self.reg_loss = gamma, alpha, pad_idx, reg_loss
        self.anchors = anchors
        self.metric_names = ['BBloss', 'focal_loss', 'domain_loss', 'total', 'acc']
        self.domain_weight = domain_weight
        self.n_domains = n_domains
        self.unlabeled_domain = unlabeled_domain

    def _unpad(self, bbox_tgt, clas_tgt):
        "Drop the leading padding entries and convert the boxes with `tlbr2cthw`."
        i = torch.min(torch.nonzero(clas_tgt - self.pad_idx)) if sum(clas_tgt)>0 else 0
        return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:] - 1 + self.pad_idx

    def _focal_loss(self, clas_pred, clas_tgt):
        "Summed sigmoid focal loss of `clas_pred` against the encoded `clas_tgt`."
        encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
        ps = torch.sigmoid(clas_pred)
        # Per-element weight: probability mass assigned to the wrong outcome.
        weights = Variable(encoded_tgt * (1 - ps) + (1 - encoded_tgt) * ps)
        alphas = (1 - encoded_tgt) * self.alpha + encoded_tgt * (1 - self.alpha)
        weights.pow_(self.gamma).mul_(alphas)
        clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')
        return clas_loss

    def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
        "Return `(box_loss, focal_loss)` for a single image."
        bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
        matches = match_anchors(self.anchors, bbox_tgt)
        # Non-negative entries of `matches` index the target assigned to an anchor.
        bbox_mask = matches >= 0
        if bbox_mask.sum() != 0:
            bbox_pred = bbox_pred[bbox_mask]
            bbox_tgt = bbox_tgt[matches[bbox_mask]]
            bb_loss = self.reg_loss(bbox_pred, bbox_to_activ(bbox_tgt, self.anchors[bbox_mask]))
        else:
            bb_loss = 0.
        # Shift matches and labels by one so index 0 selects the extra zero entry prepended below.
        matches.add_(1)
        clas_tgt = clas_tgt + 1
        clas_mask = matches >= 0
        clas_pred = clas_pred[clas_mask]
        clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
        clas_tgt = clas_tgt[matches[clas_mask]]
        # Normalise the focal loss by the number of matched anchors (at least 1).
        return bb_loss, self._focal_loss(clas_pred, clas_tgt) / torch.clamp(bbox_mask.sum(), min=1.)

    def domain_focal_loss(self, clas_pred, clas_tgt, alpha=0.25, gamma=2.):
        "Focal-style modulation of the (batch-mean) cross entropy between domain logits and targets."
        # NOTE(review): the modulation is applied to the mean cross entropy, not per sample - confirm intent.
        ce_loss = F.cross_entropy(clas_pred, clas_tgt)
        pt = torch.exp(-ce_loss)
        focal_loss = torch.mul(alpha,((1 - pt) ** gamma * ce_loss))
        return focal_loss

    def forward(self, output, bbox_tgts, clas_tgts, domain_tgts):
        clas_preds, bbox_preds, domain, sizes = output
        # Compare against the device the anchors are actually moved to.
        if self.anchors.device != clas_preds.device:
            self.anchors = self.anchors.to(clas_preds.device)
        bb_loss = torch.tensor(0, dtype=torch.float32).to(clas_preds.device)
        focal_loss = torch.tensor(0, dtype=torch.float32).to(clas_preds.device)
        d_loss = self.domain_focal_loss(domain, domain_tgts)
        acc = torch.true_divide((torch.argmax(domain, dim=1) == domain_tgts).sum(), domain_tgts.size(0))
        for cp, bp, ct, bt, dt in zip(clas_preds, bbox_preds, clas_tgts, bbox_tgts, domain_tgts):
            if dt != self.unlabeled_domain:
                bb, focal = self._one_loss(cp, bp, ct, bt)
                bb_loss += bb
                focal_loss += focal

        # A batch without any labelled sample would otherwise divide 0 by 0 and return NaN.
        n_labelled = max(clas_tgts[domain_tgts != self.unlabeled_domain].size(0), 1)
        domain_term = self.domain_weight * (d_loss / domain_tgts.size(0))

        # NOTE(review): the reported 'total' subtracts the domain term and divides by the full
        # batch size, while the returned loss adds it and divides by the labelled count - confirm.
        total_loss = (bb_loss + focal_loss) / clas_tgts.size(0) - domain_term
        self.metrics = dict(zip(self.metric_names, [bb_loss / n_labelled, focal_loss / n_labelled,
                                                    d_loss / domain_tgts.size(0), total_loss,
                                                    acc]))
        return (bb_loss + focal_loss) / n_labelled + domain_term

In [None]:
from fastai.vision import *
from fastai.callbacks import TrackerCallback

class UpdateAlphaCallback(TrackerCallback):
    """Sets the gradient-reversal factor of the model's `d5` discriminator at the start of each epoch."""

    def __init__(self, learn:Learner, max_epochs):
        super().__init__(learn)
        self.max_epochs = max_epochs

    def on_epoch_begin(self,epoch, **kwargs:Any):
        # Fraction of the planned epochs reached once this epoch is done.
        progress = (epoch + 1) / self.max_epochs
        # Sigmoid-shaped ramp: 0 at progress 0, approaching 1 as progress grows.
        self.learn.model.d5.alpha = 2. / (1. + np.exp(-10 * progress)) - 1

In [None]:
from object_detection_fastai.callbacks.callbacks import *
from object_detection_fastai.helper.wsi_loader import SlideObjectItemList,SlideContainer


class DomainAdaptationItem(ItemBase):
    """Pairs an image-bbox item with the domain value stored in its `sample_kwargs["domain"]`."""

    def __init__(self, imagebbox):
        self.imagebbox = imagebbox
        self.scanner_id = imagebbox.sample_kwargs["domain"]
        self._refresh_views()

    def _refresh_views(self):
        "Rebuild `obj` and `data` from the current `imagebbox`."
        self.obj = (self.imagebbox, self.scanner_id)
        self.data = [self.imagebbox.data]

    def apply_tfms(self, tfms, **kwargs):
        # Transform the wrapped item; the scanner id is unaffected.
        self.imagebbox = self.imagebbox.apply_tfms(tfms, **kwargs)
        self._refresh_views()
        return self

class SlideObjectCategoryListDA(ObjectCategoryList):
    """Object-category label list whose items also carry a domain value taken from `self.x.items[i].y[-1]`."""

    def get(self, i, x: int = 0, y: int = 0):
        """Return the labels of item `i` for the patch whose origin is `(x, y)`.

        Boxes are shifted by the patch origin, filtered, clipped and wrapped in a
        `DomainAdaptationItem`.
        """
        h, w = self.x.items[i].shape
        bboxes, labels = self.items[i]

        # Normalise the boxes to a numpy array whatever their nesting.
        bboxes = np.array([box for box in bboxes]) if len(np.array(bboxes).shape) == 1 else np.array(bboxes)
        labels = np.array(labels)

        if len(labels) > 0:
            # Move the boxes into patch coordinates.
            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - x
            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - y

            # Half extents, so the expressions below evaluate the box centres.
            bb_widths = (bboxes[:, 2] - bboxes[:, 0]) / 2
            bb_heights = (bboxes[:, 3] - bboxes[:, 1]) / 2

            # Keep boxes whose centre lies inside the patch.
            ids = ((bboxes[:, 0] + bb_widths) > 0) \
                  & ((bboxes[:, 1] + bb_heights) > 0) \
                  & ((bboxes[:, 2] - bb_widths) < w) \
                  & ((bboxes[:, 3] - bb_heights) < h)

            bboxes = bboxes[ids]
            bboxes = np.clip(bboxes, 0, max(h, w))
            # Swap the column pairs: (0, 1, 2, 3) -> (1, 0, 3, 2).
            bboxes = bboxes[:, [1, 0, 3, 2]]

            labels = labels[ids]

        # Placeholder box with label 0 when nothing is left.
        if len(labels) == 0:
            labels = np.array([0])
            bboxes = np.array([[0, 0, 1, 1]])

        image_bbox = ImageBBox.create(h, w, bboxes, labels, classes=self.classes, pad_idx=self.pad_idx)
        # The last element of the container's `y` is read as the domain.
        # NOTE(review): if `y` is only [bboxes, labels] this is the label list - verify against the caller.
        image_bbox.sample_kwargs = {"domain":self.x.items[i].y[-1]}
        return DomainAdaptationItem(image_bbox)

    def reconstruct(self, t, x):
        "Rebuild an `ImageBBox` from a `(bboxes, labels, domain)` target; returns None if only padding is present."
        (bboxes, labels, domain) = t
        if len((labels - self.pad_idx).nonzero()) == 0: return
        # Skip everything before the first non-padding label.
        i = (labels - self.pad_idx).nonzero().min()
        bboxes,labels = bboxes[i:],labels[i:]
        return ImageBBox.create(*x.size, bboxes, labels=labels, classes=self.classes, scale=False)

class ObjectItemListSlideDA(SlideObjectItemList):
    """Slide item list that opens a patch at `(x, y)` as a float image."""

    def open(self, fn: SlideContainer,  x: int=0, y: int=0):
        # Divide the patch values by 255 before converting to a float32 tensor.
        scaled_patch = fn.get_patch(x, y) / 255.
        patch_tensor = pil2tensor(scaled_patch, np.float32)
        return Image(patch_tensor)

class PascalVOCMetricByDistanceDA(PascalVOCMetric):
    """Pascal-VOC metric for targets that carry a third (domain) element.

    Predictions are thinned with `non_max_suppression_by_distance` using `radius`.
    NOTE(review): `self.imageCounter` is read and incremented in `on_batch_end` but not
    set in `__init__`; presumably the parent class initialises it - verify.
    """

    def __init__(self, anchors, size, metric_names: list, detect_thresh: float=0.3, nms_thresh: float=0.5
                 , radius: float=25, images_per_batch: int=-1):
        self.ap = 'AP'
        self.anchors = anchors
        self.size = size
        self.detect_thresh = detect_thresh
        self.nms_thresh = nms_thresh
        self.radius = radius

        self.images_per_batch = images_per_batch
        self.metric_names_original = metric_names
        # Reported names are the class names prefixed with "AP-".
        self.metric_names = ["{}-{}".format(self.ap, i) for i in metric_names]

        self.evaluator = Evaluator()
        self.boundingBoxes = BoundingBoxes()


    def on_batch_end(self, last_output, last_target, **kwargs):
        "Collect ground-truth and detected boxes of the batch into `self.boundingBoxes`."
        # The third target element (domain ids) is ignored; only the first two outputs are used.
        bbox_gt_batch, class_gt_batch, _ = last_target
        class_pred_batch, bbox_pred_batch = last_output[:2]

        # A non-positive `images_per_batch` means "use every image of the batch".
        self.images_per_batch = self.images_per_batch if self.images_per_batch > 0 else class_pred_batch.shape[0]
        for bbox_gt, class_gt, clas_pred, bbox_pred in \
                list(zip(bbox_gt_batch, class_gt_batch, class_pred_batch, bbox_pred_batch))[: self.images_per_batch]:

            bbox_pred, scores, preds = process_output(clas_pred, bbox_pred, self.anchors, self.detect_thresh)
            # Images without any detection are skipped entirely (their ground truth is not recorded).
            if bbox_pred is None:# or len(preds) > 3 * len(bbox_gt):
                continue

            #image = np.zeros((512, 512, 3), np.uint8)
            t_sz = torch.Tensor([(self.size, self.size)])[None].cpu()
            bbox_pred = to_np(rescale_boxes(bbox_pred.cpu(), t_sz))
            # change from center to top left
            bbox_pred[:, :2] = bbox_pred[:, :2] - bbox_pred[:, 2:] / 2


            # Corner-format copy (first two columns plus extent) used only for suppression.
            temp_boxes = np.copy(bbox_pred)
            temp_boxes[:, 2] = temp_boxes[:, 0] + temp_boxes[:, 2]
            temp_boxes[:, 3] = temp_boxes[:, 1] + temp_boxes[:, 3]


            to_keep = non_max_suppression_by_distance(temp_boxes, to_np(scores), self.radius, return_ids=True)
            bbox_pred, preds, scores = bbox_pred[to_keep], preds[to_keep].cpu(), scores[to_keep].cpu()

            # Drop padded targets (class 0).
            bbox_gt = bbox_gt[np.nonzero(class_gt)].squeeze(dim=1).cpu()
            class_gt = class_gt[class_gt > 0]
            # change gt from x,y,x2,y2 -> x,y,w,h
            bbox_gt[:, 2:] = bbox_gt[:, 2:] - bbox_gt[:, :2]

            bbox_gt = to_np(rescale_boxes(bbox_gt, t_sz))


            # Class ids become zero-based indices into `metric_names_original`.
            class_gt = to_np(class_gt) - 1
            preds = to_np(preds)
            scores = to_np(scores)

            for box, cla in zip(bbox_gt, class_gt):
                temp = BoundingBox(imageName=str(self.imageCounter), classId=self.metric_names_original[cla], x=box[0], y=box[1],
                               w=box[2], h=box[3], typeCoordinates=CoordinatesType.Absolute,
                               bbType=BBType.GroundTruth, format=BBFormat.XYWH, imgSize=(self.size,self.size))

                self.boundingBoxes.addBoundingBox(temp)

            # to reduce math complexity take maximal three times the number of gt boxes
            num_boxes = len(bbox_gt) * 3
            for box, cla, scor in list(zip(bbox_pred, preds, scores))[:num_boxes]:
                temp = BoundingBox(imageName=str(self.imageCounter), classId=self.metric_names_original[cla], x=box[0], y=box[1],
                                   w=box[2], h=box[3], typeCoordinates=CoordinatesType.Absolute, classConfidence=scor,
                                   bbType=BBType.Detected, format=BBFormat.XYWH, imgSize=(self.size, self.size))

                self.boundingBoxes.addBoundingBox(temp)

            #image = self.boundingBoxes.drawAllBoundingBoxes(image, str(self.imageCounter))
            self.imageCounter += 1

def bb_pad_collate_da(samples: "BatchSamples", pad_idx: int = 0) -> Tuple[FloatTensor, Tuple[FloatTensor, LongTensor, LongTensor]]:
    """Collate labelled-bbox samples into a batch, left-padding boxes and labels with `pad_idx`.

    Returns `(images, (bboxes, labels, scanner_ids))`. `scanner_ids[i]` is the
    `scanner_id` attribute of the i-th target when that attribute is an integer;
    otherwise it stays at `pad_idx`, so targets without a usable domain id do not crash.
    """
    if isinstance(samples[0][1], int): return data_collate(samples)
    max_len = max(len(s[1].data[0][1]) for s in samples)
    bboxes = torch.zeros(len(samples), max_len, 4)
    labels = torch.zeros(len(samples), max_len).long() + pad_idx
    scanner_ids = torch.zeros(len(samples)).long() + pad_idx
    imgs = []
    for i, s in enumerate(samples):
        imgs.append(s[0].data[None])
        # Propagate the domain id; without it every domain target is identical.
        scanner_id = getattr(s[1], "scanner_id", None)
        if isinstance(scanner_id, (int, np.integer)) and not isinstance(scanner_id, bool):
            scanner_ids[i] = int(scanner_id)
        bbs, lbls = s[1].data[0]
        if not (bbs.nelement() == 0):
            # Right-align the real entries so the padding sits at the front.
            bboxes[i, -len(lbls):] = bbs
            labels[i, -len(lbls):] = torch.as_tensor(lbls)
    return torch.cat(imgs, 0), (bboxes, labels, scanner_ids)

In [None]:
import os
import sys
# sys.path.append(os.path.abspath('../../SlideRunner'))
# from SlideRunner.dataAccess.database import Database
from object_detection_fastai.helper.wsi_loader import SlideContainer
import numpy as np
from random import *
import json

def sample_function(y, classes, size, level_dimensions, level):
    """Pick the top-left corner of a training patch.

    Args:
        y: annotations of the slide; `y[0]` holds the boxes and `y[1]` the labels.
        classes: class names to choose from when seeding the patch on an annotation.
        size: `(width, height)` of the patch.
        level_dimensions: `(width, height)` of the slide for every resolution level.
        level: resolution level the patch is taken from.

    Returns:
        `(xmin, ymin)` as ints, clamped so the patch lies inside the slide.

    In about half of the calls the patch is seeded on a random annotation of a
    randomly chosen class; otherwise (or when the slide has no annotation of that
    class) the position is uniformly random.
    """
    # Local import: does not depend on which star-import bound `randint`/`random` last.
    from random import randint

    width, height = level_dimensions[level]
    max_x, max_y = width - size[0], height - size[1]

    if len(y[0]) > 0 and randint(0, 5) < 3:
        class_id = np.random.choice(classes, 1)[0]
        candidates = np.array(y[0])[np.array(y[1]) == class_id]
        # The chosen class may be absent on this slide; fall through to a random patch.
        if len(candidates) > 0:
            xmin, ymin, _, _ = candidates[randint(0, len(candidates) - 1)]
            # Shift so the annotation ends up somewhere inside the patch.
            xmin -= randint(0, size[0])
            ymin -= randint(0, size[1])
            xmin, ymin = max(0, xmin), max(0, ymin)
            return int(min(xmin, max_x)), int(min(ymin, max_y))

    return randint(0, max_x), randint(0, max_y)


def load_images(slide_folder, annotation_file, res_level, patch_size, scanner_id, categories):
    """Create one `SlideContainer` per annotated slide in `slide_folder`.

    Args:
        slide_folder: directory containing the slide files.
        annotation_file: path (str or Path) of a `.json` or `.sqlite` annotation file.
        res_level: resolution level passed to every container.
        patch_size: width and height of the patches.
        scanner_id: domain id stored as the last element of every container's `y`.
        categories: category ids to keep (JSON annotations only).

    Returns:
        List of containers. Files in `slide_folder` that are not listed in the JSON
        are skipped. An unsupported annotation format prints a message and returns
        an empty list.
    """
    container = []
    anno_dict = {1: "mitotic figure", 2: "impostor"}
    # `str()` so that Path objects are accepted as well as plain strings.
    file_format = str(annotation_file).split(".")[-1]

    if file_format == "json":
        # Parse the annotation file once instead of once per slide.
        with open(annotation_file) as f:
            data = json.load(f)

        image_ids = {}
        for entry in data["images"]:
            # Keep the first id when a file name is listed more than once.
            image_ids.setdefault(entry["file_name"], entry["id"])

        annotations_by_image = {}
        for anno in data["annotations"]:
            if anno["category_id"] in categories:
                annotations_by_image.setdefault(anno["image_id"], []).append(anno)

        for image in os.listdir(slide_folder):
            if image not in image_ids:
                # E.g. the annotation file itself or other non-slide files in the folder.
                continue
            annotations = annotations_by_image.get(image_ids[image], [])
            bboxes = [a["bbox"] for a in annotations]
            labels = [anno_dict[a["category_id"]] for a in annotations]
            container.append(SlideContainer(os.path.join(slide_folder, image), y=[bboxes, labels, scanner_id],
                                            level=res_level, width=patch_size, height=patch_size,
                                            sample_func=sample_function))
    elif file_format == "sqlite":
        # NOTE(review): `Database` comes from SlideRunner, whose import is commented out in this
        # notebook; this branch raises NameError until that import is restored.
        for image in os.listdir(slide_folder):
            DB = Database().open(annotation_file)
            slideid = DB.findSlideWithFilename(image, '')
            DB.loadIntoMemory(slideid)
            active = [DB.annotations[anno] for anno in DB.annotations.keys()
                      if DB.annotations[anno].deleted == 0]
            bboxes = [anno.coordinates.flatten() for anno in active]
            labels = [anno.agreedClass for anno in active]
            # Carry the scanner id like the JSON branch so `y[-1]` is the domain in both cases.
            container.append(SlideContainer(os.path.join(slide_folder, image), y=[bboxes, labels, scanner_id],
                                            level=res_level, width=patch_size, height=patch_size))
    else:
        print("Please provide valid annotation format")
    return container

In [None]:
def sample_function(y, classes, size, level_dimensions, level):
    """Pick the top-left corner of a training patch, seeded on an annotation when possible.

    Args:
        y: annotations of the slide; `y[0]` holds the boxes and `y[1]` the labels.
        classes: class names to choose the seed annotation from.
        size: `(width, height)` of the patch.
        level_dimensions: `(width, height)` of the slide for every resolution level.
        level: resolution level the patch is taken from.

    Returns:
        `(xmin, ymin)` as ints, clamped so the patch lies inside the slide. If the slide
        has no annotations, or none of the randomly chosen class, the position is
        uniformly random.
    """
    # Local import: `from random import *` earlier in the notebook rebinds the name
    # `random` to a function, so `random.randint` is not reliable here.
    from random import randint

    width, height = level_dimensions[level]
    max_x, max_y = width - size[0], height - size[1]

    if len(y[0]) > 0:
        class_id = np.random.choice(classes, 1)[0]  # select a random class
        # Annotations of the selected class.
        candidates = np.array(y[0])[np.array(y[1]) == class_id]
        if len(candidates) > 0:
            # Randomly select one of them as the seed for the training patch.
            xmin, ymin, _, _ = candidates[randint(0, len(candidates) - 1)]

            # Random offset so the annotation is not always in the patch centre.
            # Integer bounds: randint rejects float arguments on Python 3.12+.
            half_w, half_h = size[0] // 2, size[1] // 2
            xmin += randint(-half_w, half_w)
            ymin += randint(-half_h, half_h)
            xmin, ymin = max(0, int(xmin - size[0] / 2)), max(0, int(ymin - size[1] / 2))
            return min(xmin, max_x), min(ymin, max_y)

    return randint(0, max_x), randint(0, max_y)

In [None]:
# Image-id ranges per scanner (the upper bound of each range is exclusive).
hamamatsu_rx_ids, hamamatsu_360_ids, aperio_ids, leica_ids = (
    list(range(first, stop))
    for first, stop in ((0, 51), (51, 101), (101, 151), (151, 201))
)

# Folder that holds the slide images.
image_folder = midog_folder / "images"

In [None]:
# Build one DataFrame row per annotation: file, image id, image size, box, category, scanner.
annotation_file = midog_folder / "MIDOG.json"
with open(annotation_file) as f:
    data = json.load(f)

#categories = {cat["id"]: cat["name"] for cat in data["categories"]}
categories = {1: 'mitotic figure', 2: 'hard negative'}

# Index the annotations by image id once instead of rescanning all of them for every image.
annotations_by_image = {}
for annotation in data["annotations"]:
    annotations_by_image.setdefault(annotation["image_id"], []).append(annotation)

rows = []
for row in data["images"]:
    file_name = row["file_name"]
    image_id = row["id"]
    width = row["width"]
    height = row["height"]

    # Default scanner; overridden when the image id falls into one of the other ranges.
    scanner = "Hamamatsu XR"
    if image_id in hamamatsu_360_ids:
        scanner = "Hamamatsu S360"
    if image_id in aperio_ids:
        scanner = "Aperio CS"
    if image_id in leica_ids:
        scanner = "Leica GT450"

    for annotation in annotations_by_image.get(image_id, []):
        box = annotation["bbox"]
        cat = categories[annotation["category_id"]]

        rows.append([file_name, image_id, width, height, box, cat, scanner])

df = pd.DataFrame(rows, columns=["file_name", "image_id", "width", "height", "box", "cat", "scanner"])
df.head()

In [None]:
#@title Select a training and validation scanner { run: "auto", display-mode: "form" }

def create_wsi_container(annotations_df: pd.DataFrame, image_dir=None):
    """Create one `SlideContainer` per image listed in `annotations_df`.

    Args:
        annotations_df: one row per annotation with the columns `file_name`, `box`,
            `cat` and (optionally) `scanner`.
        image_dir: folder containing the slide images; defaults to the notebook's
            original Google Drive location.

    Returns:
        List of containers in order of first appearance of each file name. Each
        container's `y` is `[bboxes, labels, scanner_id]`; the scanner id comes last
        because the label list reads `y[-1]` as the domain of the sample.

    Reads the notebook globals `res_level`, `patch_size` and `sample_function`.
    """
    if image_dir is None:
        image_dir = '/drive/MyDrive/Colab Notebooks/MIDOG/MIDOG_Challenge/images/'
    image_dir = Path(image_dir)

    # Domain ids in the order the scanners are used elsewhere in this notebook.
    scanner_ids = {"Hamamatsu XR": 0, "Hamamatsu S360": 1, "Aperio CS": 2, "Leica GT450": 3}
    has_scanner = "scanner" in annotations_df.columns

    container = []

    # Group once instead of filtering the whole frame for every image;
    # sort=False keeps the order of first appearance.
    for image_name, image_annos in tqdm(annotations_df.groupby("file_name", sort=False)):

        bboxes = list(image_annos["box"])
        labels = list(image_annos["cat"])
        # Unknown or missing scanner names fall back to domain 0.
        scanner_id = scanner_ids.get(image_annos["scanner"].iloc[0], 0) if has_scanner else 0

        container.append(SlideContainer(str(image_dir / image_name), y=[bboxes, labels, scanner_id],
                                        level=res_level, width=patch_size, height=patch_size,
                                        sample_func=sample_function))

    return container

#@markdown Options can also be combined like:  Hamamatsu XR, Hamamatsu S360
train_scanner = "Hamamatsu XR" #@param ["Hamamatsu XR", "Hamamatsu S360", "Aperio CS"]  {allow-input: true}
val_scanner = "Hamamatsu S360" #@param ["Hamamatsu XR", "Hamamatsu S360", "Aperio CS"]  {allow-input: true}

patch_size = 256 #@param [256, 512, 1024]
res_level = 0

# Strip the names: the suggested "A, B" form has a space after the comma, which
# would otherwise never match the scanner column.
train_scanner_names = [name.strip() for name in train_scanner.split(",")]
train_annos = df[df["scanner"].isin(train_scanner_names)]
train_container = create_wsi_container(train_annos)

val_scanner_names = [name.strip() for name in val_scanner.split(",")]
val_annos = df[df["scanner"].isin(val_scanner_names)]
valid_container = create_wsi_container(val_annos)

f"Created: {len(train_container)} training WSI container and {len(valid_container)} validation WSI container"

In [None]:
# Training

from object_detection_fastai.helper.wsi_loader import *

def get_y_func(x):
    "Label function: return the `y` attribute stored on the item."
    targets = x.y
    return targets

# Train the domain-adversarial RetinaNet on patches sampled from the slide containers
# created earlier in the notebook, then export the learner.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    slide_folder = midog_folder
    model_dir = Path("models")

    # NOTE(review): the slide containers were created with the `patch_size` that was set when
    # they were built; reassigning it here only affects the transform size, model and metric.
    patch_size = 512
    batch_size = 12 # param
    res_level = 0
    bs = 12
    domain_weight = 1
    lr = 1e-4
    train_samples_per_scanner = 1500
    val_samples_per_scanner = 500
    scales = [0.2, 0.4, 0.6, 0.8, 1.0]
    ratios = [1]
    sizes = [(64, 64), (32, 32), (16, 16)]
    num_epochs = 200

    train_scanners = [["A","B","C","D"]]
    valid_scanners = [["A","B","C","D"]]
    annotation_json ='E:/Slides/MIDOG/MIDOG.json'


    tfms = get_transforms(do_flip=True,
                          flip_vert=True,
                          max_lighting=0.5,
                          max_zoom=2,
                          max_warp=0.2,
                          p_affine=0.5,
                          p_lighting=0.5,
                          )


    for t_scrs, v_scrs in zip (train_scanners, valid_scanners):
        learner_name = 'DA_RetinaNet'

        # Sample slide containers with replacement; each sampled entry yields one patch.
        train_images = list(np.random.choice(train_container, train_samples_per_scanner))
        valid_images = list(np.random.choice(valid_container, val_samples_per_scanner))

        train = ObjectItemListSlide(train_images)
        valid = ObjectItemListSlide(valid_images)
        item_list = ItemLists(".", train, valid)
        lls = item_list.label_from_func(get_y_func, label_cls=SlideObjectCategoryListDA)
        lls = lls.transform(tfms, tfm_y=True, size=patch_size)
        data = lls.databunch(bs=bs, collate_fn=bb_pad_collate_da, num_workers=0).normalize()

        data.train_dl = data.train_dl.new(shuffle=False) #set shuffle to false so that batch always contains all 4 scanners
        data.valid_dl = data.valid_dl.new(shuffle=False)
        anchors = create_anchors(sizes=sizes, ratios=ratios, scales=scales)
        crit = RetinaNetFocalLossDA(anchors, domain_weight=domain_weight, n_domains=len(t_scrs))
        encoder = create_body(models.resnet18, True, -2)
        # Careful: Number of anchors has to be adapted to scales
        model = RetinaNetDA(encoder, n_classes=data.train_ds.c, n_domains=len(t_scrs), n_anchors=len(scales) * len(ratios),
                            sizes=[size[0] for size in sizes], chs=128, final_bias=-4., n_conv=3, imsize = (patch_size, patch_size))
        voc = PascalVOCMetricByDistanceDA(anchors, patch_size,[str(i) for i in data.train_ds.y.classes[1:]])

        learn = Learner(data, model, loss_func=crit,
                        callback_fns=[BBMetrics, ShowGraph],
                        metrics=[voc]
                      )

        learn.path = Path(os.getcwd())
        cyc_len = 1
        max_learning_rate = 1e-3
        # The alpha schedule spans the epochs actually trained (`cyc_len`), not `num_epochs`.
        alpha_up = UpdateAlphaCallback(learn, cyc_len)

        # Register the callback; otherwise the gradient-reversal factor is never updated.
        learn.fit_one_cycle(cyc_len, max_learning_rate, callbacks=[alpha_up])
        learn.export('{}.pkl'.format(learner_name))
        print("Saved model as {}".format(learner_name))

In [None]:
#@title Take a look a the results { run: "auto", vertical-output: true, display-mode: "form" }

# Thresholds and image count forwarded to the side-by-side result viewer below.
detect_thresh = 0.5 #@param {type:"slider", min:0.1, max:1, step:0.1}
nms_thresh = 0.2 #@param {type:"slider", min:0.1, max:1, step:0.1}
image_count=15 #@param {type:"integer"}

# Uses the `learn` and `anchors` objects left behind by the training cell.
show_results_side_by_side(learn, anchors, detect_thresh=detect_thresh, nms_thresh=nms_thresh, image_count=image_count)