# Initialization

## Packages

In [None]:
%%time
# Torch 1.7
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torch-1.7.0+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps

# MMCV & MMDET Requirements
!pip install ../input/mmdetection-v217/mmdetection/addict-2.4.0-py3-none-any.whl --no-deps
!pip install ../input/mmdetection-v217/mmdetection/yapf-0.31.0-py2.py3-none-any.whl --no-deps
!pip install ../input/mmdetection-v217/mmdetection/terminal-0.4.0-py3-none-any.whl --no-deps
!pip install ../input/mmdetection-v217/mmdetection/terminaltables-3.1.0-py3-none-any.whl --no-deps
!pip install ../input/mmdetection-v217/mmdetection/pycocotools-2.0.2/pycocotools-2.0.2 --no-deps
!pip install ../input/mmdetection-v217/mmdetection/mmpycocotools-12.0.3/mmpycocotools-12.0.3 --no-deps

# MMCV
!pip install ../input/detection-packages/mmcv_full-1.3.16-cp37-cp37m-manylinux1_x86_64.whl --no-deps
# MMDET
!pip install ../input/detection-packages/mmdetection-2.17.0/mmdetection-2.17.0 --no-deps
# !pip install ../input/detection-packages/mmdetection-master/mmdetection-master --no-deps

In [None]:
!cd /kaggle/working/

## Imports

In [None]:
import os
import gc
import cv2
import json
import glob
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

warnings.simplefilter("ignore", UserWarning)

### Params

In [None]:
# Folder the competition data is read from.
DATA_PATH = "/kaggle/input/sartorius-cell-instance-segmentation/"
# Cell type names.
# NOTE(review): the position in this list is presumably the class index used by the models - confirm.
CELL_TYPES = ["shsy5y", "astro", "cort"]
# NOTE(review): presumably the (height, width) of the original images - confirm.
ORIG_SIZE = (520, 704)

# Functions

## Utils

### Torch

In [None]:
import os
import torch
import random
import numpy as np


def seed_everything(seed):
    """
    Makes runs repeatable by seeding every random number generator in use.

    Seeds Python's `random`, the `PYTHONHASHSEED` environment variable, NumPy and
    PyTorch (`torch.manual_seed` and `torch.cuda.manual_seed`), then puts cuDNN
    in deterministic, non-benchmark mode.

    Args:
        seed (int): Seed value.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Remaining generators all take the seed as their single argument.
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_model_weights(model, filename, verbose=1, cp_folder=""):
    """
    Loads the weights of a PyTorch model from a checkpoint file.

    The checkpoint is always deserialized onto the CPU, which avoids cpu/gpu
    incompatibilities; `load_state_dict` then copies the values into the model's
    own parameters.

    Args:
        model (torch model): Model to load the weights to.
        filename (str): Name of the checkpoint.
        verbose (int, optional): Whether to display infos. Defaults to 1.
        cp_folder (str, optional): Folder to load from. Defaults to "".

    Returns:
        torch model: Model with loaded weights.

    Raises:
        RuntimeError: If the checkpoint keys do not match the model (strict loading).
    """
    path = os.path.join(cp_folder, filename)

    if verbose:
        print(f"\n -> Loading weights from {os.path.join(cp_folder,filename)}\n")

    # NOTE: torch.load unpickles the file - only use it on trusted checkpoints.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model

### Plots

In [None]:
import cv2
import skimage
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from scipy import ndimage
from mmdet.core import BitmapMasks
from matplotlib.patches import Rectangle

# from utils.metrics import compute_iou


# Reference colours as (R, G, B) tuples with channels scaled to [0, 1].
GREEN = (56 / 255, 200 / 255, 100 / 255)
BLUE = (32 / 255, 50 / 255, 155 / 255)
RED = (238 / 255, 97 / 255, 55 / 255)


def get_random_color():
    """
    Samples a random colour, rejecting near-grey ones.

    Returns:
        tuple [3]: Channel values in [0, 1) whose max and min differ by at least 0.2.
    """
    while True:
        candidate = tuple(np.random.random(size=3))
        if not (np.max(candidate) - np.min(candidate) < 0.2):
            return candidate


def _label_instances(binary_masks):
    """Turns a stack of binary masks [N x H x W] into an int stack where mask i holds the value i + 1."""
    ids = np.arange(1, len(binary_masks) + 1).reshape(-1, 1, 1)
    # astype makes a copy, so the caller's array is never modified, and int
    # avoids bool casting errors / uint8 overflow when multiplying by the id.
    return binary_masks.astype(int) * ids


def plot_sample(img, mask=None, boxes=None, width=1, plotly=False):
    """
    Draws the contours of instance masks (and optionally boxes) over an image.

    The result is shown with matplotlib unless `plotly` is set, in which case a
    plotly figure is returned and boxes are not drawn.

    Args:
        img (numpy array [H x W] or [H x W x 3]): Image. Divided by 255 if its max is above 1.
        mask (BitmapMasks or numpy array [N x H x W] or [H x W], optional): Masks.
            A 3D stack whose max is 1 is treated as binary and instance i is given the id i + 1;
            otherwise values are used as instance ids. The input is left unmodified. Defaults to None.
        boxes (list of [x0, y0, x1, y1], optional): Boxes to draw (matplotlib only). Defaults to None.
        width (int, optional): Contour width. Defaults to 1.
        plotly (bool, optional): Whether to return a plotly figure instead of using matplotlib.

    Returns:
        plotly figure if `plotly` is True, else None.
    """
    if boxes is None:
        boxes = []

    if img.max() > 1:
        img = (img / 255).astype(float)

    if len(img.shape) == 2:
        img = np.stack([img, img, img], -1)

    img_ = img.copy()

    colors = []

    if isinstance(mask, BitmapMasks):
        mask = _label_instances(mask.masks)

    if mask is not None:
        if len(mask.shape) == 3:
            if mask.size == 0:
                # No instance: nothing to reduce over, use an empty label map.
                mask = np.zeros(mask.shape[1:], dtype=int)
            else:
                if mask.max() == 1:
                    mask = _label_instances(mask)
                mask = mask.max(0)

        for i in range(1, int(np.max(mask)) + 1):
            m = ((mask == i) * 255).astype(np.uint8)
            color = get_random_color()
            colors.append(color)

            contours, _ = cv2.findContours(m, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
            cv2.polylines(img_, contours, True, color, width)

    if plotly:
        return px.imshow(img_)

    plt.imshow(img_)

    # Add boxes, reusing the contour colour of the same index when there is one
    for i, box in enumerate(boxes):
        color = colors[i] if i < len(colors) else get_random_color()
        rect = Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            linewidth=1, edgecolor=color, facecolor='none', alpha=0.5
        )
        plt.gca().add_patch(rect)


### Config

In [None]:
class Config:
    """
    Attribute container built from a dictionary, e.g. a config reloaded from a saved json.
    """
    def __init__(self, dic):
        # Expose every entry of the dictionary as an attribute of the same name.
        for name in dic:
            setattr(self, name, dic[name])

### RLE

In [None]:
def rle_encoding(x):
    """
    Run-length encodes the pixels equal to 1 in a mask.

    Args:
        x (np array): Mask, flattened with `flatten()` before encoding.

    Returns:
        str: Space separated "start length" pairs, starts being 1-indexed.
    """
    pixels = np.where(x.flatten() == 1)[0]

    runs = []  # [start, length] pairs
    last = -2
    for pos in pixels:
        if pos > last + 1:  # gap since the previous pixel: a new run begins
            runs.append([pos + 1, 0])
        runs[-1][1] += 1
        last = pos

    return ' '.join(str(value) for run in runs for value in run)


def rle_decode(mask_rle, shape):
    """
    Rebuilds a binary mask from its run-length encoding.

    Args:
        mask_rle (str): Space separated "start length" pairs, starts being 1-indexed.
        shape (tuple [2]): Mask size (height, width).

    Returns:
        np array [shape]: uint8 mask with 1 on the encoded pixels.
    """
    tokens = mask_rle.split()
    starts = np.asarray(tokens[::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)

    starts = starts - 1  # to 0-indexed positions
    ends = starts + lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, ends):
        flat[begin:stop] = 1

    return flat.reshape(shape)

## Data

### Transforms

In [None]:
import mmcv
from mmcv.utils import build_from_cfg
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose


def define_pipelines(config_file, multi_image=False):
    """
    Reads the pipelines of the `data` section of an mmcv config file.

    Args:
        config_file (str): Path to the config file.
        multi_image (bool, optional): If True, the raw pipeline configs are returned
            instead of built pipelines. Defaults to False.

    Returns:
        dict: For each key of the `data` section, either a `Compose` of the built
            transforms or the pipeline config as found in the file.
    """
    data_cfg = mmcv.Config.fromfile(config_file).data

    if multi_image:
        return {split: data_cfg[split].pipeline for split in data_cfg}

    return {
        split: Compose(
            [build_from_cfg(step, PIPELINES, None) for step in data_cfg[split].pipeline]
        )
        for split in data_cfg
    }

### Dataset

In [None]:
import cv2
import pycocotools
import numpy as np
from torch.utils.data import Dataset
from mmdet.core import BitmapMasks


# Placeholder entries merged into the results dict of every sample by
# SartoriusInferenceDataset.__getitem__ before the transforms are applied
# ('bbox_fields' and 'mask_fields' are removed again there).
RESULTS_PH = {
    'scale_factor': np.ones(4, dtype=np.float32),  # if no resizing in augs
    "pad_shape": (0, 0),
    "img_norm_cfg": None,
    "flip_direction": None,
    "flip": None,
    'img_fields': ["img"],
    'bbox_fields': ["gt_bboxes"],
    'mask_fields': ["gt_masks"]
}


class SartoriusInferenceDataset(Dataset):
    """
    Image-only dataset for inference.

    Each item is the output of `transforms` applied to a results dict built
    from the image found at `df["img_path"]`.
    """
    def __init__(self, df, transforms, precompute_masks=True):
        """
        Constructor.

        Args:
            df (pandas dataframe): Metadata. Must contain an "img_path" column.
            transforms (callable): Called on the results dict of each image.
            precompute_masks (bool, optional): Unused, kept for interface compatibility. Defaults to True.
        """

        self.df = df
        self.transforms = transforms

        self.img_paths = df["img_path"].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """
        Loads and transforms one image.

        Args:
            idx (int): Index of the sample.

        Returns:
            Output of `self.transforms` for the sample.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        path = self.img_paths[idx]

        image = cv2.imread(path)
        if image is None:  # cv2.imread does not raise on missing / unreadable files
            raise FileNotFoundError(f"Could not read image: {path}")

        results = {
            "img": image,
            "img_shape": image.shape[:2],
            "ori_shape": image.shape[:2],
            "filename": path,
            'ori_filename': path,
        }

        # Copy mutable placeholder values (array, lists) so that in-place changes
        # made by the transforms cannot leak into other samples through RESULTS_PH.
        for key, value in RESULTS_PH.items():
            if key in ('bbox_fields', 'mask_fields'):  # no annotations at inference
                continue
            results[key] = value.copy() if hasattr(value, "copy") else value

        return self.transforms(results)


### Loaders

In [None]:
from functools import partial
from mmcv.parallel import collate
from torch.utils.data import DataLoader


def define_loaders(
    train_dataset=None, val_dataset=None,  batch_size=32, val_bs=32, num_workers=0
):
    """
    Builds data loaders.

    Only the validation / inference loader is built: `train_dataset` and
    `batch_size` are accepted for interface compatibility but no train loader
    is created, so the first returned value is always None.

    Args:
        train_dataset (Dataset, optional): Unused. Defaults to None.
        val_dataset (Dataset, optional): Dataset to validate / infer with. Defaults to None.
        batch_size (int, optional): Training batch size, unused. Defaults to 32.
        val_bs (int, optional): Validation batch size. Defaults to 32.
        num_workers (int, optional): Number of loader workers. Defaults to 0.

    Returns:
       None: Train loader placeholder.
       DataLoader or None: Val loader, None if `val_dataset` is None.
    """
    train_loader, val_loader = None, None

    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=val_bs,
            shuffle=False,
            # collate has to group samples by the size of the batches this
            # loader actually produces, i.e. val_bs and not the train batch size.
            collate_fn=partial(collate, samples_per_gpu=val_bs),
            num_workers=num_workers,
            pin_memory=True,
        )

    return train_loader, val_loader


## Model

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

import numpy as np
import torch.nn as nn
from mmdet.models.builder import BACKBONES

# Per-channel statistics stored as `mean` / `std` on EfficientNet backbones
# whose name does not contain "efficientnetv2".
# NOTE(review): presumably dataset normalisation statistics - confirm origin and channel order.
MEAN = np.array([0.66437738, 0.50478148, 0.70114894])
STD = np.array([0.15825711, 0.24371008, 0.13832686])


@BACKBONES.register_module()
class EfficientNet(nn.Module):
    """
    timm EfficientNet registered as an mmdet backbone, returning the outputs of selected blocks.
    """
    def __init__(self, name, blocks_idx, pretrained=True):
        """
        Constructor.

        Args:
            name (str): Model name as specified in timm.
            blocks_idx (list of ints): Blocks to output features at.
            pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
        """
        super().__init__()

        build_fn = getattr(timm.models, name)
        self.effnet = build_fn(pretrained=pretrained, drop_path_rate=0.2)

        self.block_idx = blocks_idx

        # Output channels of the last layer of each selected block
        self.nb_fts = [
            self.effnet.blocks[idx][-1].conv_pwl.out_channels for idx in self.block_idx
        ]
        self.nb_ft = self.nb_fts[-1]

        is_v2 = "efficientnetv2" in name
        self.mean = np.array([0.5, 0.5, 0.5]) if is_v2 else MEAN
        self.std = np.array([0.5, 0.5, 0.5]) if is_v2 else STD

        self.name = name

    def forward(self, x):
        """
        Extract features for an EfficientNet model.

        Args:
            x (torch tensor [BS x 3 x H x W]): Input image.

        Returns:
            list of torch tensors: Outputs of the blocks listed in `block_idx`, in network order.
        """
        net = self.effnet
        feat = net.act1(net.bn1(net.conv_stem(x)))

        outputs = []
        for idx, block in enumerate(net.blocks):
            feat = block(feat)
            if idx in self.block_idx:
                outputs.append(feat)

        return outputs


In [None]:
import sys
import mmcv
import torch
import logging

from mmcv.parallel import MMDataParallel
from mmdet.models.builder import build_detector


def define_model(config_file, encoder="resnet50", pretrained_livecell=False, verbose=1):
    """
    Builds a detector from its config file and the backbone config stored next to it.

    The backbone definition is read from "config_backbones.py", located in the same
    folder as `config_file`. No weights are loaded here (backbone pretraining is
    disabled when the backbone config has a "pretrained" entry); load a checkpoint
    afterwards, e.g. with `load_model_weights`.

    Args:
        config_file (str): Path to the model config.
        encoder (str, optional): Key of the backbone in the backbones config. Defaults to "resnet50".
        pretrained_livecell (bool, optional): Unused, kept for interface compatibility. Defaults to False.
        verbose (int, optional): Unused, kept for interface compatibility. Defaults to 1.

    Returns:
        MMDataParallel: Wrapped detector.
    """
    # Configs
    cfg = mmcv.Config.fromfile(config_file)

    # dirname also works when config_file has no folder part ("cfg.py"),
    # where splitting on "/" would give "cfg.py/config_backbones.py".
    config_backbone_file = os.path.join(os.path.dirname(config_file), "config_backbones.py")
    cfg_backbones = mmcv.Config.fromfile(config_backbone_file)

    if 'pretrained' in cfg_backbones.backbones[encoder].keys():
        cfg_backbones.backbones[encoder]['pretrained'] = False

    cfg.model.backbone = cfg_backbones.backbones[encoder]

    if encoder in cfg_backbones.out_channels.keys():  # update neck channels
        cfg.model.neck.in_channels = cfg_backbones.out_channels[encoder]

    # Build model
    model = build_detector(cfg.model)
    model.test_cfg = cfg["model"]["test_cfg"]
    model.train_cfg = cfg["model"]["train_cfg"]

    # Reduce stride
    if "resnet" in encoder or "resnext" in encoder:
        model.backbone.conv1.stride = (1, 1)
    elif "efficientnet" in encoder:
        model.backbone.effnet.conv_stem.stride = (1, 1)

    model = MMDataParallel(model)

    return model

## Ensemble Model

### Merging

In [None]:
import torch
from mmcv.ops import nms

from mmdet.core.bbox import bbox_mapping_back


def merge_aug_proposals(aug_proposals, img_metas, cfg):
    """Merge augmented proposals (multiscale, flip, etc.)

    Proposals are mapped back to the original image, merged with NMS, then
    filtered by score and capped in number.

    Args:
        aug_proposals (list[Tensor]): proposals from different testing
            schemes, shape (n, 5). Note that they are not rescaled to the
            original image size.

        img_metas (list[dict]): list of image info dict where each dict has:
            'img_shape', 'scale_factor', 'flip' and 'flip_direction'.

        cfg (dict): rpn test config, with `nms.iou_threshold`, `score_thr`
            and `max_per_img`.

    Returns:
        Tensor: shape (n, 5), boxes at the original image scale followed by
            their score, sorted by decreasing score.
    """
    # Undo the test-time augmentation of each set of proposals
    restored = []
    for props, meta in zip(aug_proposals, img_metas):
        mapped = props.clone()
        mapped[:, :4] = bbox_mapping_back(
            mapped[:, :4],
            meta["img_shape"],
            meta["scale_factor"],
            meta["flip"],
            meta["flip_direction"],
        )
        restored.append(mapped)

    # Merge everything with NMS
    stacked = torch.cat(restored, dim=0)
    dets, _ = nms(
        stacked[:, :4].contiguous(),
        stacked[:, 4].contiguous(),
        cfg.nms.iou_threshold,
    )

    # Rank by score, drop low scores, keep the best ones
    ranked_scores, ranking = dets[:, 4].sort(0, descending=True)
    kept = ranking[ranked_scores > cfg.score_thr][:cfg.max_per_img]

    return dets[kept, :]


def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas):
    """
    Merge augmented detection bboxes and scores.
    Boxes are mapped back to the original image, then averaged over augmentations.

    Args:
        aug_bboxes (list[Tensor]): shape (n, 4*#class)
        aug_scores (list[Tensor] or None): shape (n, #class)
        img_metas (list): One entry per augmentation; the first element of each entry
            is a dict with 'img_shape', 'scale_factor', 'flip' and 'flip_direction'.

    Returns:
        tuple: (bboxes, scores), or only bboxes if aug_scores is None.
    """
    # Undo the test-time augmentation of each set of boxes
    restored = [
        bbox_mapping_back(
            boxes,
            meta[0]["img_shape"],
            meta[0]["scale_factor"],
            meta[0]["flip"],
            meta[0]["flip_direction"],
        )
        for boxes, meta in zip(aug_bboxes, img_metas)
    ]

    # Average over augmentations
    merged_boxes = torch.stack(restored).mean(dim=0)

    if aug_scores is not None:
        return merged_boxes, torch.stack(aug_scores).mean(dim=0)
    return merged_boxes


def single_class_boxes_nms(merged_bboxes, merged_scores, iou_threshold=0.5):
    """
    Applies NMS to the boxes, each one scored with its most confident class.

    Args:
        merged_bboxes (Tensor): Boxes, one row per candidate.
        merged_scores (Tensor): Class scores, one row per candidate.
        iou_threshold (float, int or indexable, optional): NMS IoU threshold. If not a number,
            it is indexed with the most frequent predicted class. Defaults to 0.5.

    Returns:
        tuple: (boxes returned by nms, labels of the kept candidates)
    """
    # Most confident class of each candidate
    confidences, classes = torch.max(merged_scores, 1)

    # Most frequent class selects the threshold when one is given per class
    majority_class = torch.mode(classes, 0).values.item()
    if isinstance(iou_threshold, (float, int)):
        nms_threshold = iou_threshold
    else:
        nms_threshold = iou_threshold[majority_class]

    kept_boxes, kept_inds = nms(
        merged_bboxes.contiguous(), confidences.contiguous(), nms_threshold
    )

    return kept_boxes, classes[kept_inds]


### Custom mask functions

In [None]:
import copy
import torch
import numpy as np
from warnings import warn

from mmcv.ops import batched_nms
from mmdet.models.roi_heads.mask_heads.fcn_mask_head import (
    BYTES_PER_FLOAT,
    GPU_MEM_LIMIT,
    _do_paste_mask,
)


def get_seg_masks(
    mask_head,
    mask_pred,
    det_bboxes,
    det_labels,
    rcnn_test_cfg,
    ori_shape,
    scale_factor,
    rescale,
    return_per_class=True,
):
    """
    Modified version of mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
    to add the return_per_class argument.

    Pastes the per-detection mask predictions into full-size image masks.

    Args:
        mask_head: Mask head, its `num_classes` and `class_agnostic` attributes are used.
        mask_pred (Tensor or array-like): Mask predictions, one per detection. A Tensor is passed
            through a sigmoid, anything else is converted to a tensor and used as is.
        det_bboxes (Tensor): Detections, the first 4 columns are used as boxes.
        det_labels (Tensor): Class index of each detection.
        rcnn_test_cfg: Test config, its `mask_thr_binary` attribute is used.
        ori_shape (tuple): Original image shape, (height, width) first.
        scale_factor (Tensor, ndarray [4] or float): Scale factor of the image.
        rescale (bool): If True, boxes are divided by `scale_factor` and masks have the size of
            `ori_shape`; otherwise masks have the size of `ori_shape` multiplied by the scale factor.
        return_per_class (bool, optional): Whether to return masks grouped per class. Defaults to True.

    Returns:
        list of lists of numpy arrays if `return_per_class` (one list of masks per class),
        otherwise a Tensor [N x img_h x img_w] of pasted masks, boolean when
        `mask_thr_binary` >= 0 and uint8 (values scaled by 255) otherwise.
    """
    if isinstance(mask_pred, torch.Tensor):
        mask_pred = mask_pred.sigmoid()
    else:
        # In AugTest, has been activated before
        mask_pred = det_bboxes.new_tensor(mask_pred)

    device = mask_pred.device
    cls_segms = [
        [] for _ in range(mask_head.num_classes)
    ]  # BG is not included in num_classes
    bboxes = det_bboxes[:, :4]
    labels = det_labels

    # In most cases, scale_factor should have been
    # converted to Tensor when rescale the bbox

    if not isinstance(scale_factor, torch.Tensor):
        if isinstance(scale_factor, float):
            scale_factor = np.array([scale_factor] * 4)
            warn(
                "Scale_factor should be a Tensor or ndarray "
                "with shape (4,), float would be deprecated. "
            )
        assert isinstance(scale_factor, np.ndarray)
        scale_factor = torch.Tensor(scale_factor)

    if rescale:
        # Masks are produced at the original image size
        img_h, img_w = ori_shape[:2]
        bboxes = bboxes / scale_factor.to(bboxes.device)
    else:
        # Masks are produced at the scaled image size
        w_scale, h_scale = scale_factor[0], scale_factor[1]
        img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32)
        img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32)

    N = len(mask_pred)
    if device.type == "cpu":
        # One mask per chunk on cpu
        num_chunks = N
    else:
        # GPU benefits from parallelism for larger chunks,
        # the number of chunks is bounded by the memory limit
        num_chunks = int(
            np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT)
        )
        assert num_chunks <= N, "Default GPU_MEM_LIMIT is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    threshold = rcnn_test_cfg.mask_thr_binary
    # Output buffer: boolean masks when thresholding, uint8 otherwise
    im_mask = torch.zeros(
        N,
        img_h,
        img_w,
        device=device,
        dtype=torch.bool if threshold >= 0 else torch.uint8,
    )

    if not mask_head.class_agnostic:
        # Keep only the channel of the predicted class of each detection
        mask_pred = mask_pred[range(N), labels][:, None]

    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            mask_pred[inds], bboxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        im_mask[(inds,) + spatial_inds] = masks_chunk

    # Group masks per class (done even when the per-class result is not returned)
    for i in range(N):
        cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())

    if return_per_class:
        return cls_segms
    else:
        return im_mask


def get_rpn_boxes_single(
    rpn_head,
    cls_scores,
    bbox_preds,
    mlvl_anchors,
    img_shape,
    scale_factor,
    cfg,
):
    """
    Modified from mmdet/models/dense_heads/rpn_head.py

    Transform outputs for a single batch item into bbox predictions.
    Compared to a plain top-k selection, detections are also filtered with `cfg.score_thr`.

        Args:
        rpn_head: RPN head, its `use_sigmoid_cls` attribute and `bbox_coder` are used.
        cls_scores (list[Tensor]): Box scores of all scale level
            each item has shape (num_anchors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas of all
            scale level, each item has shape (num_anchors * 4, H, W).
        mlvl_anchors (list[Tensor]): Anchors of all scale level
            each item has shape (num_total_anchors, 4).
        img_shape (tuple[int]): Shape of the input image,
            (height, width, 3).
        scale_factor (ndarray): Scale factor of the image arrange as
            (w_scale, h_scale, w_scale, h_scale). Not used in this function.
        cfg (mmcv.Config): Test / postprocessing configuration, with `nms_pre`,
            `min_bbox_size`, `nms`, `score_thr` and `max_per_img`. It is deep-copied.
    Returns:
        Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
            are bounding box positions (tl_x, tl_y, br_x, br_y) and the
            5-th column is a score between 0 and 1.
    """
    cfg = copy.deepcopy(cfg)
    # bboxes from different level should be independent during NMS,
    # level_ids are used as labels for batched NMS to separate them
    level_ids = []
    mlvl_scores = []
    mlvl_bbox_preds = []
    mlvl_valid_anchors = []
    for idx in range(len(cls_scores)):
        rpn_cls_score = cls_scores[idx]
        rpn_bbox_pred = bbox_preds[idx]
        assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
        rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
        if rpn_head.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.reshape(-1)
            scores = rpn_cls_score.sigmoid()
        else:
            rpn_cls_score = rpn_cls_score.reshape(-1, 2)
            # We set FG labels to [0, num_class-1] and BG label to
            # num_class in RPN head since mmdet v2.5, which is unified to
            # be consistent with other head since mmdet v2.0. In mmdet v2.0
            # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
            scores = rpn_cls_score.softmax(dim=1)[:, 0]
        rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
        anchors = mlvl_anchors[idx]
        if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
            # Keep the nms_pre best candidates of the level
            # sort is faster than topk
            # _, topk_inds = scores.topk(cfg.nms_pre)
            ranked_scores, rank_inds = scores.sort(descending=True)
            topk_inds = rank_inds[: cfg.nms_pre]
            scores = ranked_scores[: cfg.nms_pre]
            rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
            anchors = anchors[topk_inds, :]
        mlvl_scores.append(scores)
        mlvl_bbox_preds.append(rpn_bbox_pred)
        mlvl_valid_anchors.append(anchors)
        level_ids.append(scores.new_full((scores.size(0),), idx, dtype=torch.long))

    # Decode the candidates of all levels at once
    scores = torch.cat(mlvl_scores)
    anchors = torch.cat(mlvl_valid_anchors)
    rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
    proposals = rpn_head.bbox_coder.decode(anchors, rpn_bbox_pred, max_shape=img_shape)
    ids = torch.cat(level_ids)

    if cfg.min_bbox_size >= 0:
        # Remove boxes whose width or height is not above the minimum size
        w = proposals[:, 2] - proposals[:, 0]
        h = proposals[:, 3] - proposals[:, 1]
        valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
        if not valid_mask.all():
            proposals = proposals[valid_mask]
            scores = scores[valid_mask]
            ids = ids[valid_mask]
    if proposals.numel() > 0:
        dets, keep = batched_nms(proposals, scores, ids, cfg.nms)
    else:
        return proposals.new_zeros(0, 5)

    # Drop detections that are not above the score threshold
    dets = dets[dets[:, 4] > cfg.score_thr]
    # Cap the number of detections
    dets = dets[:cfg.max_per_img]

    return dets


def get_rpn_boxes(rpn_head, cls_scores, bbox_preds, img_metas, cfg):
    """
    Computes the RPN proposals of every image of a batch.
    Modified from mmdet/models/dense_heads/rpn_head.py

    Args:
        rpn_head: RPN head, provides the anchor generator and is forwarded to `get_rpn_boxes_single`.
        cls_scores (list of Tensors): Classification outputs per level, indexed by image on the first axis.
        bbox_preds (list of Tensors): Box regression outputs per level, indexed by image on the first axis.
        img_metas (list of dicts): Metadata per image, with "img_shape" and "scale_factor".
        cfg: Postprocessing config, forwarded to `get_rpn_boxes_single`.

    Returns:
        list: Proposals of each image, as returned by `get_rpn_boxes_single`.
    """
    assert len(cls_scores) == len(bbox_preds)

    # Anchors only depend on the feature map sizes: shared by all images
    target_device = cls_scores[0].device
    level_sizes = [level.shape[-2:] for level in cls_scores]
    anchors_per_level = rpn_head.anchor_generator.grid_anchors(level_sizes, device=target_device)

    return [
        get_rpn_boxes_single(
            rpn_head,
            [level[img_id].detach() for level in cls_scores],
            [level[img_id].detach() for level in bbox_preds],
            anchors_per_level,
            meta["img_shape"],
            meta["scale_factor"],
            cfg,
        )
        for img_id, meta in enumerate(img_metas)
    ]


### Wrappers

In [None]:
import torch
import torch.functional as F


def get_wrappers(names):
    """
    Selects the ensemble wrapper class matching each model name.

    Patterns are tested from most to least specific, since a name such as
    "cascade_mask_rcnn" also contains "rcnn".

    Args:
        names (list of str): Model names.

    Returns:
        list of classes: Wrapper of each model, in the same order as `names`.

    Raises:
        NotImplementedError: If a name matches none of "htc", "cascade" or "rcnn".
    """
    wrappers = []
    for name in names:
        if "htc" in name:
            wrappers.append(HTCEnsemble)
        elif "cascade" in name:
            wrappers.append(CascadeEnsemble)
        elif "rcnn" in name:
            wrappers.append(RCNNEnsemble)
        else:
            raise NotImplementedError(f"No ensemble wrapper for model '{name}'")

    return wrappers


class RCNNEnsemble:
    """
    Box and mask prediction helpers for models with a single-stage RoI head.
    """
    @staticmethod
    def get_boxes(model, x, rois, img_shape, scale_factor, img_meta, num_classes):
        """
        Predicts one box and the class scores of each RoI.

        Args:
            model: Detector, its `roi_head` is used.
            x: Features fed to the RoI head.
            rois (Tensor): RoIs to refine.
            img_shape: Image shape forwarded to the bbox head.
            scale_factor: Scale factor forwarded to the bbox head.
            img_meta: Unused, kept so that all wrappers share the same signature.
            num_classes (int): Number of leading classes to keep.

        Returns:
            Tensor [n x 4]: Box of the most confident kept class of each RoI.
            Tensor [n x num_classes]: Scores of the kept classes.
        """
        bbox_results = model.roi_head._bbox_forward(x, rois)
        bboxes, scores = model.roi_head.bbox_head.get_bboxes(
            rois,
            bbox_results["cls_score"],
            bbox_results["bbox_pred"],
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None,
        )

        # Keep only desired classes
        scores = scores[:, :num_classes]

        n = bboxes.size(0)
        if n == 0:  # no RoI: nothing to select from
            return bboxes.new_zeros((0, 4)), scores

        # Keep box corresponding to most confident class
        _, det_labels = torch.max(scores, 1)
        rows = torch.arange(n, device=bboxes.device)
        bboxes = bboxes.view(n, -1, 4)[rows, det_labels]

        return bboxes, scores

    @staticmethod
    def get_masks(model, x, mask_rois, num_classes):
        """
        Predicts the mask probabilities of each RoI.

        Args:
            model: Detector, its `roi_head` is used.
            x: Features fed to the RoI head.
            mask_rois (Tensor): RoIs to predict masks for.
            num_classes (int): Number of leading classes to keep.

        Returns:
            numpy array: Sigmoid of the mask predictions, restricted to the kept classes.
        """
        masks = model.roi_head._mask_forward(x, mask_rois)["mask_pred"]
        masks = masks.sigmoid().cpu().numpy()[:, :num_classes]

        return masks


class CascadeEnsemble:
    """
    Box and mask prediction helpers for models with a cascade RoI head.
    """
    @staticmethod
    def get_boxes(model, x, rois, img_shape, scale_factor, img_meta, num_classes):
        """
        Runs every bbox stage, refining the RoIs between stages, and decodes the
        boxes of the last stage with the class scores averaged over stages.

        Adapted from
        https://github.com/open-mmlab/mmdetection/blob/bde7b4b7eea9dd6ee91a486c6996b2d68662366d/mmdet/models/roi_heads/test_mixins.py#L139

        Returns:
            tuple: (bboxes, scores restricted to the first `num_classes` classes)
        """
        roi_head = model.roi_head

        stage_scores = []
        for stage in range(roi_head.num_stages):
            bbox_results = roi_head._bbox_forward(stage, x, rois)
            stage_scores.append(bbox_results["cls_score"])

            if stage >= roi_head.num_stages - 1:
                continue  # the RoIs of the last stage are not refined further

            stage_head = roi_head.bbox_head[stage]
            cls_score = bbox_results["cls_score"]
            if stage_head.custom_activation:
                cls_score = stage_head.loss_cls.get_activation(cls_score)
            # Last column is left out when picking the class to regress with
            bbox_label = cls_score[:, :-1].argmax(dim=1)
            rois = stage_head.regress_by_class(
                rois, bbox_label, bbox_results["bbox_pred"], img_meta[0]
            )

        mean_cls_score = sum(stage_scores) / float(len(stage_scores))
        bboxes, scores = roi_head.bbox_head[-1].get_bboxes(
            rois,
            mean_cls_score,
            bbox_results["bbox_pred"],
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None,
        )

        return bboxes, scores[:, :num_classes]

    @staticmethod
    def get_masks(model, x, mask_rois, num_classes):
        """
        Averages over stages the mask probabilities of the first `num_classes` classes.

        Returns:
            numpy array: Averaged mask probabilities.
        """
        per_stage = [
            model.roi_head._mask_forward(stage, x, mask_rois)['mask_pred'].sigmoid()[:, :num_classes]
            for stage in range(model.roi_head.num_stages)
        ]
        return torch.stack(per_stage).mean(0).cpu().numpy()


class HTCEnsemble:
    """
    Box and mask prediction helpers for models with a HTC RoI head.
    """
    @staticmethod
    def get_boxes(model, x, rois, img_shape, scale_factor, img_meta, num_classes):
        """
        Runs every bbox stage with the semantic features, refining the RoIs between
        stages, and decodes the boxes of the last stage with the class scores
        averaged over stages.

        Adapted from
        https://github.com/open-mmlab/mmdetection/blob/a7a16afbf2a4bdb4d023094da73d325cb864838b/mmdet/models/roi_heads/htc_roi_head.py#L505

        Returns:
            tuple: (bboxes, scores restricted to the first `num_classes` classes)
        """
        semantic = model.roi_head.semantic_head(x)[1]

        ms_scores = []
        for i in range(model.roi_head.num_stages):
            bbox_head = model.roi_head.bbox_head[i]
            bbox_results = model.roi_head._bbox_forward(
                i, x, rois, semantic_feat=semantic
            )
            ms_scores.append(bbox_results["cls_score"])

            if i < model.roi_head.num_stages - 1:
                bbox_label = bbox_results["cls_score"].argmax(dim=1)
                rois = bbox_head.regress_by_class(
                    rois, bbox_label, bbox_results["bbox_pred"], img_meta[0]
                )

        cls_score = sum(ms_scores) / float(len(ms_scores))
        bboxes, scores = model.roi_head.bbox_head[-1].get_bboxes(
            rois,
            cls_score,
            bbox_results["bbox_pred"],
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None,
        )

        scores = scores[:, :num_classes]

        return bboxes, scores

    @staticmethod
    def get_masks(model, x, mask_rois, num_classes):
        """
        Averages over stages the mask probabilities of the first `num_classes` classes,
        computed from RoI features to which the semantic RoI features are added.

        Adapted from
        https://github.com/open-mmlab/mmdetection/blob/a7a16afbf2a4bdb4d023094da73d325cb864838b/mmdet/models/roi_heads/htc_roi_head.py#L592

        Returns:
            numpy array: Averaged mask probabilities.
        """
        mask_feats = model.roi_head.mask_roi_extractor[-1](
            x[: len(model.roi_head.mask_roi_extractor[-1].featmap_strides)], mask_rois
        )

        # Semantic feats
        semantic = model.roi_head.semantic_head(x)[1]
        mask_semantic_feat = model.roi_head.semantic_roi_extractor([semantic], mask_rois)
        if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
            # adaptive_avg_pool2d lives in torch.nn.functional, not in torch.functional
            mask_semantic_feat = torch.nn.functional.adaptive_avg_pool2d(
                mask_semantic_feat, mask_feats.shape[-2:]
            )
        mask_feats += mask_semantic_feat

        last_feat = None
        masks = []
        for i in range(model.roi_head.num_stages):
            mask_head = model.roi_head.mask_head[i]
            if model.roi_head.mask_info_flow:
                mask_pred, last_feat = mask_head(mask_feats, last_feat)
            else:
                mask_pred = mask_head(mask_feats)
            masks.append(mask_pred.sigmoid()[:, :num_classes])

        masks = torch.stack(masks)
        masks = masks.mean(0).cpu().numpy()

        return masks


### Ensemble Model

In [None]:
import mmcv
import torch
import numpy as np

from torch import nn
from mmdet.core import bbox_mapping
from mmdet.core import bbox2roi, merge_aug_masks
from mmdet.models.detectors import BaseDetector


class EnsembleModel(BaseDetector):
    """
    Wrapper to ensemble models.

    Proposals come from the RPN outputs averaged over all models (first view only).
    Boxes and masks are then predicted by every model on the selected views and merged.
    """
    def __init__(
        self,
        models,
        config,
        names=None,
    ):
        """
        Constructor.

        Args:
            models (list of mmdet MMDataParallel): Models to ensemble.
            config (dict): Ensemble config.
            names (list, optional): Model names. Defaults to None (no names).
        """
        super().__init__()
        self.models = nn.ModuleList([model.module for model in models])
        self.config = config
        # Fresh list per instance rather than a shared mutable default argument
        self.names = [] if names is None else names

        self.wrappers = get_wrappers(self.names)
        self.get_configs()

    def get_configs(self):
        """
        Creates the rpn and rcnn configs from the config dict.
        One config of each kind is built per cell type index (0, 1, 2),
        matching the cell types assigned in `get_proposals`.
        """
        self.rpn_cfgs, self.rcnn_cfgs = [], []

        for i in range(3):
            rpn_cfg = mmcv.Config(
                dict(
                    score_thr=self.config['rpn_score_threshold'][i],
                    nms_pre=self.config['rpn_nms_pre'][i],
                    max_per_img=self.config['rpn_max_per_img'][i],
                    nms=dict(type="nms", iou_threshold=self.config['rpn_iou_threshold'][i]),
                    min_bbox_size=0,
                )
            )
            rcnn_cfg = mmcv.Config(
                dict(
                    score_thr=self.config['rcnn_score_threshold'][i],
                    nms=dict(type="nms", iou_threshold=self.config['rcnn_iou_threshold'][i]),
                    mask_thr_binary=-1,
                )
            )
            self.rpn_cfgs.append(rpn_cfg)
            self.rcnn_cfgs.append(rcnn_cfg)

    def extract_feat(self, img, img_metas, **kwargs):
        """
        Extract features function. Not used but required by MMDet.

        Args:
            imgs (list of torch tensors [n x C x H x W]): Input image.
            img_metas (list of dicts [n]): List of MMDet image metadata.
        """
        pass

    def simple_test(self, img, img_metas, **kwargs):
        """
        Single image test function. Not used but required by MMDet.

        Args:
            imgs (list of torch tensors [1 x C x H x W]): Input image.
            img_metas (list of dicts [1]): List of MMDet image metadata.
        """
        pass

    def forward(self, img, img_metas, **kwargs):
        """
        Forward function. Simply delegates to `aug_test`.

        Args:
            imgs (list of torch tensors [n_tta x C x H x W]): Input image.
            img_metas (list of dicts [n_tta]): List of MMDet image metadata.

        Returns:
            Whatever `aug_test` returns for these inputs and keyword arguments.
        """
        return self.aug_test(img, img_metas, **kwargs)

    def get_proposals(self, imgs, img_metas):
        """
        Gets proposals, doesn't use TTA.
        The RPN outputs of all models are averaged per level. The number of confident
        anchors per level is also used to guess the cell type of the image.

        Args:
            imgs (list of torch tensors [n_tta x C x H x W]): Input image.
            img_metas (list of dicts [n_tta]): List of MMDet image metadata.

        Returns:
            list of torch tensors [1 x 5]: Proposals.
            int: Cell type.
        """
        aug_bboxes, aug_scores = [], []
        for model in self.models:
            # Only the first view is used (no tta), so only its features are computed
            # instead of running the backbone on every view and discarding the rest.
            x = model.extract_feat(imgs[0])
            cls_score, bbox_pred = model.rpn_head(x)

            aug_bboxes.append(bbox_pred)
            aug_scores.append(cls_score)

        merged_bboxes, merged_scores = [], []
        level_counts = []

        for lvl in range(len(aug_bboxes[0])):
            merged_bboxes_lvl = torch.stack([bboxes[lvl] for bboxes in aug_bboxes]).mean(dim=0)
            merged_scores_lvl = torch.stack([scores[lvl] for scores in aug_scores]).mean(dim=0)

            merged_bboxes.append(merged_bboxes_lvl)
            merged_scores.append(merged_scores_lvl)

            # Number of locations of this level whose best score exceeds 0.7
            rpn_scores_lvl, rpn_labels_lvl = torch.max(
                merged_scores_lvl.sigmoid().flatten(start_dim=2)[0], 0
            )
            level_counts.append(rpn_labels_lvl[rpn_scores_lvl > 0.7].size(0))

        # Heuristic cell type classification from the per-level counts
        if np.sum(level_counts[-2:]) > 10:  # astro
            cell_type = 1
        elif np.sum(level_counts) < 4500 and level_counts[1] < 750:  # cort
            cell_type = 2
        else:  # shsy5y
            cell_type = 0

        proposal_list = get_rpn_boxes(
            self.models[0].rpn_head,
            merged_scores,
            merged_bboxes,
            img_metas[0],
            self.rpn_cfgs[cell_type]
        )

        return proposal_list, cell_type

    def get_proposals_tta(self, imgs, img_metas):
        """
        TODO
        This doesn't work yet.

        Args:
            imgs (list of torch tensors [n_tta x C x H x W]): Input images.
            img_metas (list of dicts [n_tta]): List of MMDet image metadata.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def get_bboxes(self, imgs, img_metas, proposal_list, rcnn_cfg):
        """
        Gets rcnn boxes. Adapted from :
        https://github.com/open-mmlab/mmdetection/blob/bde7b4b7eea9dd6ee91a486c6996b2d68662366d/mmdet/models/roi_heads/test_mixins.py#L139
        All TTAs are used.

        Args:
            imgs (list of torch tensors [n_tta x C x H x W]): Input images.
            img_metas (list of dicts [n_tta]): List of MMDet image metadata.
            proposal_list ([1 x N]): Proposals.
            rcnn_cfg (mmcv Config): Provides `nms.iou_threshold` and `score_thr`.

        Returns:
            torch tensor [m x 6]: Kept boxes, confidences & labels.
            torch tensor: Merged boxes concatenated with the merged scores, before filtering.
        """
        aug_bboxes, aug_scores, aug_img_metas = [], [], []

        for wrapper, model in zip(self.wrappers, self.models):
            for x, img_meta in zip(model.extract_feats(imgs), img_metas):
                img_shape = img_meta[0]["img_shape"]
                scale_factor = img_meta[0]["scale_factor"]
                flip = img_meta[0]["flip"]
                flip_direction = img_meta[0]["flip_direction"]

                # Map the proposals to the coordinates of the current view
                proposals = bbox_mapping(
                    proposal_list[0][:, :4],
                    img_shape,
                    scale_factor,
                    flip,
                    flip_direction,
                )
                rois = bbox2roi([proposals])

                bboxes, scores = wrapper.get_boxes(
                    model, x, rois, img_shape, scale_factor, img_meta, self.config['num_classes']
                )

                aug_bboxes.append(bboxes)
                aug_scores.append(scores)
                aug_img_metas.append(img_meta)

        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, aug_img_metas
        )

        if self.config['bbox_nms']:
            det_bboxes, det_labels = single_class_boxes_nms(
                merged_bboxes,
                merged_scores,
                iou_threshold=rcnn_cfg.nms.iou_threshold,
            )
            det_bboxes = torch.cat([det_bboxes, det_labels.unsqueeze(-1)], -1)

        else:
            # No NMS : keep every box with its best class, sorted by decreasing score
            det_scores, det_labels = torch.max(merged_scores, 1)
            det_bboxes = torch.cat(
                [merged_bboxes, det_scores.unsqueeze(1), det_labels.unsqueeze(1)], 1
            )

            _, order = det_scores.sort(0, descending=True)
            det_bboxes = det_bboxes[order]

        det_bboxes = det_bboxes[det_bboxes[:, 4] > rcnn_cfg.score_thr]

        return det_bboxes, torch.cat([merged_bboxes, merged_scores], 1)

    def get_masks(self, imgs, img_metas, det_bboxes, det_labels):
        """
        Gets rcnn masks. Adapted from :
        https://github.com/open-mmlab/mmdetection/blob/bde7b4b7eea9dd6ee91a486c6996b2d68662366d/mmdet/models/roi_heads/test_mixins.py#L282

        Only the views whose flip direction is listed in config['ttas_masks'] are used.

        Args:
            imgs (list of torch tensors [n_tta x C x H x W]): Input images.
            img_metas (list of dicts [n_tta]): List of MMDet image metadata.
            det_bboxes (torch tensor [m x 5]): Boxes & confidences.
            det_labels (torch tensor [m]): Labels.

        Returns:
            torch tensor [m x H x W]: Masks.
            list: Augmented masks before merging.
        """
        aug_masks, aug_img_metas = [], []

        for wrapper, model in zip(self.wrappers, self.models):
            for x, img_meta in zip(model.extract_feats(imgs), img_metas):
                img_shape = img_meta[0]["img_shape"]
                scale_factor = img_meta[0]["scale_factor"]
                flip = img_meta[0]["flip"]
                flip_direction = img_meta[0]["flip_direction"]

                if flip_direction not in self.config['ttas_masks']:
                    continue

                _bboxes = bbox_mapping(
                    det_bboxes[:, :4], img_shape, scale_factor, flip, flip_direction
                )
                mask_rois = bbox2roi([_bboxes])

                # Seems to help : shift the flipped rois by delta along the flipped axes
                if self.config['delta']:
                    if flip_direction in ['vertical', 'diagonal']:
                        mask_rois[:, 2] = torch.clamp(mask_rois[:, 2] - self.config['delta'], 0, img_shape[0])
                        mask_rois[:, 4] = torch.clamp(mask_rois[:, 4] - self.config['delta'], 0, img_shape[0])
                    if flip_direction in ['horizontal', 'diagonal']:
                        mask_rois[:, 1] = torch.clamp(mask_rois[:, 1] - self.config['delta'], 0, img_shape[1])
                        mask_rois[:, 3] = torch.clamp(mask_rois[:, 3] - self.config['delta'], 0, img_shape[1])

                masks = wrapper.get_masks(model, x, mask_rois, self.config['num_classes'])

                aug_masks.append(masks)
                aug_img_metas.append(img_meta)

        merged_masks = merge_aug_masks(aug_masks, aug_img_metas, None)

        # Multi-stage heads keep a list of mask heads, the last one is used for pasting
        mask_head = (
            self.models[0].roi_head.mask_head if "rcnn" in self.names[0]
            else self.models[0].roi_head.mask_head[-1]
        )

        masks = get_seg_masks(
            mask_head,
            merged_masks,
            det_bboxes,
            det_labels,
            self.rcnn_cfgs[0],
            img_metas[0][0]["ori_shape"],
            scale_factor=det_bboxes.new_ones(4),
            rescale=False,
            return_per_class=False,
        )

        return masks, aug_masks

    def aug_test(self, imgs, img_metas, return_everything=False, **kwargs):
        """
        Augmented test function. Adapted from :
        https://github.com/open-mmlab/mmdetection/blob/bde7b4b7eea9dd6ee91a486c6996b2d68662366d/mmdet/models/roi_heads/standard_roi_head.py#L268

        Args:
            imgs (list of torch tensors [n_tta x C x H x W]): Input images.
            img_metas (list of dicts [n_tta]): List of MMDet image metadata.
            return_everything (bool, optional): Whether to return more stuff. Defaults to False.

        Returns:
            tuple: (boxes, masks). Boxes are a [m x 6] tensor of boxes, confidences & labels.
                Masks are None when no box is kept.
            tuple, only if return_everything: (proposals, merged boxes & scores before
                filtering, kept boxes, augmented masks before merging, masks).
        """
        proposal_list, cell_type = self.get_proposals(imgs, img_metas)

        bboxes, aug_bboxes = self.get_bboxes(
            imgs, img_metas, proposal_list, self.rcnn_cfgs[cell_type]
        )

        assert self.models[0].with_mask

        if bboxes.shape[0] == 0:
            # Nothing detected : keep the same return structure as the regular path
            if return_everything:
                all_stuff = (proposal_list, aug_bboxes, bboxes, [], None)
                return (bboxes, None), all_stuff
            return bboxes, None

        masks, aug_masks = self.get_masks(
            imgs, img_metas, bboxes[:, :5], bboxes[:, 5].long()
        )

        if return_everything:
            all_stuff = (proposal_list, aug_bboxes, bboxes, aug_masks, masks)
            return (bboxes, masks), all_stuff

        return (bboxes, masks)


## Post-processing

### Overlapping

In [None]:
import pycocotools


def remove_overlap_naive(masks, ious=None):
    """
    Removes overlaps between masks by giving contested pixels to the first mask.
    Masks are processed in index order : each mask loses the pixels already
    covered by an earlier overlapping mask.

    Args:
        masks (np array [n x H x W]): Binary masks, expected to be ordered by priority.
        ious (np array [n x n], optional): Pairwise mask ious. Computed with
            pycocotools if None. Not modified. Defaults to None.

    Returns:
        np array [n x H x W]: Masks without overlap. The input array itself is
            returned when there is nothing to fix, a new array otherwise.
    """
    if len(masks) == 0:
        return masks

    if ious is None:
        # Import the submodule explicitly, `import pycocotools` alone does not load it
        from pycocotools import mask as mask_utils

        rles = [mask_utils.encode(np.asarray(m, order='F')) for m in masks]
        ious = mask_utils.iou(rles, rles, [0] * len(rles))

    # Work on a copy so that the caller's matrix keeps its diagonal
    ious = np.array(ious, dtype=float)
    np.fill_diagonal(ious, 0)

    to_process = np.where(ious.sum(0) > 0)[0]

    if not len(to_process):
        return masks

    masks = np.array(masks)  # copy, the input is left untouched

    # Running union of the overlapping masks seen so far. It equals the union of
    # their original versions, since a mask only loses pixels already in the union.
    claimed = np.zeros(masks.shape[1:], dtype=bool)
    for i in to_process:
        masks[i][claimed] = 0
        claimed |= masks[i].astype(bool)

    return masks

### NMS

In [None]:
def mask_nms(masks, boxes, threshold=0.5):
    """
    NMS with masks.
    Removes more masks than the tweaking fct.
    Masks are sorted by decreasing confidence, then each remaining mask is kept and
    suppresses every mask whose iou with it is above the threshold. A mask that has
    no iou above the threshold, not even with itself, is dropped.

    Args:
        masks (np array [n x H x W]): Binary masks.
        boxes (np array [n x 5+]): Boxes, column 4 is the confidence.
        threshold (float, optional): Iou threshold. Defaults to 0.5.

    Returns:
        np array [m x H x W]: Kept masks, sorted by decreasing confidence.
        np array [m x 5+]: Kept boxes, in the same order.
        list of ints [m]: Kept indices, relative to the confidence-sorted inputs.
    """
    if len(masks) == 0:
        return masks, boxes, []

    # Import the submodule explicitly, `import pycocotools` alone does not load it
    from pycocotools import mask as mask_utils

    order = np.argsort(boxes[:, 4])[::-1]
    masks = masks[order]
    boxes = boxes[order]

    rle_pred = [mask_utils.encode(np.asarray(m, order='F')) for m in masks]
    ious = mask_utils.iou(rle_pred, rle_pred, [0] * len(rle_pred))

    # Boolean bookkeeping instead of filtering a python list against an array at each step
    suppressed = np.zeros(len(masks), dtype=bool)
    picks = []

    for idx in range(len(masks)):
        if suppressed[idx]:
            continue

        overlapping = ious[idx] > threshold
        if overlapping.any():
            picks.append(idx)
            suppressed |= overlapping

    masks = masks[picks]
    boxes = boxes[picks]
    return masks, boxes, picks

### Remove Small Masks

In [None]:
def remove_small_masks(masks, boxes, min_size=0):
    """
    Removes the smallest mask(s) if at least one mask is not bigger than min_size.

    NOTE(review): when some masks are too small, only the masks having the smallest
    size are removed, not every mask under min_size - confirm this is intended.

    Args:
        masks (np array [n x H x W]): Binary masks.
        boxes (np array [n x ...]): Boxes associated to the masks.
        min_size (int, optional): Minimum mask area. 0 disables the filter. Defaults to 0.

    Returns:
        np array [m x H x W]: Kept masks.
        np array [m x ...]: Kept boxes.
    """
    # Nothing to filter. Reducing an empty array with min() would raise.
    if min_size == 0 or len(masks) == 0:
        return masks, boxes

    sizes = masks.sum(-1).sum(-1)

    if (sizes > min_size).all():
        return masks, boxes

    to_keep = sizes > sizes.min()

    return masks[to_keep], boxes[to_keep]

### Corrupt

In [None]:
import cv2

def degrade_mask(mask):
    """
    Degrades a mask by filling each of its contours as a convex polygon.

    Args:
        mask (np array [H x W]): Binary mask.

    Returns:
        np array [H x W]: Mask with filled contours, with the dtype of the input.
        np uint8 array [H x W]: Contours of the input mask, drawn with value 255.
    """
    height, width = mask.shape[0], mask.shape[1]

    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )

    # Contours drawn with a thickness of 1 on a blank 3-channel canvas, first channel kept
    outline = cv2.drawContours(
        np.zeros((height, width, 3), dtype=np.uint8), contours, -1, (255, 255, 255), 1
    )[:, :, 0]

    # Each contour filled as if it were a convex polygon
    filled = np.zeros((height, width, 3), dtype=np.uint8)
    for contour in contours:
        filled = cv2.fillConvexPoly(filled, points=contour, color=(1, 1, 1))

    return filled[:, :, 0].astype(mask.dtype), outline

### Overall fct

In [None]:
def _select_for_cell(value, cell):
    """Returns value itself if it is a scalar, its entry for the cell type otherwise."""
    return value if isinstance(value, (float, int)) else value[cell]


def process_masks(boxes, masks, thresholds_mask, thresholds_nms, thresholds_conf, min_sizes, remove_overlap=True, corrupt=False):
    """
    Post-processes the predictions of an image.
    Masks are binarized and sorted by decreasing confidence. They then go through
    confidence filtering, mask NMS, small masks removal, degradation (astro only)
    and overlap removal.

    Args:
        boxes (np array [n x 6]): Boxes. Column 4 is the confidence, column 5 the label.
        masks (np array [n x H x W]): Masks, binarized against threshold * 255.
        thresholds_mask (float or list of floats): Binarization threshold(s).
        thresholds_nms (float or list of floats): Mask NMS iou threshold(s). NMS is skipped if <= 0.
        thresholds_conf (float or list of floats): Confidence threshold(s).
        min_sizes (int or list of ints): Minimum mask size(s). 0 disables the filter.
        remove_overlap (bool, optional): Whether to remove overlaps. Defaults to True.
        corrupt (bool, optional): Whether to degrade astro masks. Defaults to False.

    Returns:
        np array [m x H x W]: Processed masks.
        np array [m x 6]: Associated boxes.
        int: Most frequent label in the boxes, -1 if there is no box at all.
    """
    # No detection : the cell type cannot be voted, inputs are returned untouched
    if len(boxes) == 0:
        return masks, boxes, -1

    # Cell type
    cell = np.argmax(np.bincount(boxes[:, 5].astype(int)))

    # Thresholds
    thresh_mask = _select_for_cell(thresholds_mask, cell)
    thresh_nms = _select_for_cell(thresholds_nms, cell)
    thresh_conf = _select_for_cell(thresholds_conf, cell)
    min_size = _select_for_cell(min_sizes, cell)

    # Binarize
    masks = masks > (thresh_mask * 255)

    # Sort by decreasing conf
    order = np.argsort(boxes[:, 4])[::-1]
    masks = masks[order]
    boxes = boxes[order]

    # Remove low confidence. Boxes are sorted, so the kept ones form a prefix.
    last = int(np.sum(boxes[:, 4] >= thresh_conf))
    masks = masks[:last]
    boxes = boxes[:last]

    # Nothing left, the remaining steps do not handle empty inputs
    if len(boxes) == 0:
        return masks, boxes, cell

    # NMS
    if thresh_nms > 0:
        masks, boxes, _ = mask_nms(masks, boxes, thresh_nms)

    # Remove small masks
    if min_size and len(masks):
        masks, boxes = remove_small_masks(masks, boxes, min_size=min_size)

    # Corrupt
    if corrupt and cell == 1 and len(masks):  # astro
        masks = np.array([degrade_mask(mask)[0] for mask in masks])

    # Remove overlap
    if remove_overlap and len(masks):
        masks = remove_overlap_naive(masks)

    return masks, boxes, cell

## Predict

In [None]:
import torch


def predict_and_process(dataset, model, thresholds_mask, thresholds_nms, thresholds_conf, min_sizes, device="cuda"):
    """
    Performs inference on a dataset, one image at a time, and post-processes the predictions.

    Args:
        dataset (InferenceDataset): Inference dataset.
        model (torch model): Segmentation model, returns (boxes, masks) for a batch.
        thresholds_mask (float or list of floats): Mask binarization threshold(s).
        thresholds_nms (float or list of floats): Mask NMS threshold(s).
        thresholds_conf (float or list of floats): Confidence threshold(s).
        min_sizes (int or list of ints): Minimum mask size(s).
        device (str, optional): Not used by this function. Defaults to "cuda".

    Returns:
        list of lists: Rle encodings of the masks of each image, empty if nothing is detected.
        list of ints: Cell type of each image, -1 if nothing is detected.
    """
    loader = define_loaders(None, dataset, val_bs=1, num_workers=0)[1]

    rles, cell_types = [], []

    model.eval()
    with torch.no_grad():
        for batch in loader:
            boxes, masks = model(**batch, return_loss=False, rescale=True)

            # The model returns no masks when no box is kept : nothing to process
            if masks is None or len(boxes) == 0:
                rles.append([])
                cell_types.append(-1)
                continue

            boxes = boxes.cpu().numpy()
            masks = masks.cpu().numpy()

            masks, boxes, cell_type = process_masks(
                boxes, masks, thresholds_mask, thresholds_nms, thresholds_conf, min_sizes, remove_overlap=True, corrupt=True
            )

            rles.append([rle_encoding(mask) for mask in masks])
            cell_types.append(cell_type)

    return rles, cell_types


## Inference

In [None]:
def inference(df, configs, weights, ensemble_config, thresholds_mask, thresholds_nms, thresholds_conf, min_sizes):
    """
    Builds the ensemble from the given experiments and predicts the images of df.

    Args:
        df (pandas dataframe): Images to predict.
        configs (list of Configs): Config of each experiment.
        weights (list of lists of str): Paths to the weights of each experiment.
        ensemble_config (dict): Ensemble config. "use_tta" selects the tta pipeline.
        thresholds_mask (float or list of floats): Mask binarization threshold(s).
        thresholds_nms (float or list of floats): Mask NMS threshold(s).
        thresholds_conf (float or list of floats): Confidence threshold(s).
        min_sizes (int or list of ints): Minimum mask size(s).

    Returns:
        list of lists: Rle encodings of the masks of each image.
        list of ints: Cell type of each image.
    """
    # The data config of the first experiment defines the pipelines
    pipelines = define_pipelines(configs[0].data_config)

    # One model per weight file, named after it
    models, names = [], []
    for cfg, fold_weights in zip(configs, weights):
        for weight_path in fold_weights:
            network = define_model(cfg.model_config, encoder=cfg.encoder, verbose=0)
            models.append(load_model_weights(network, weight_path))
            names.append(weight_path.split('/')[-1])

    pipeline_name = 'test_tta' if ensemble_config["use_tta"] else 'test'
    dataset = SartoriusInferenceDataset(df, transforms=pipelines[pipeline_name])

    ensemble = MMDataParallel(EnsembleModel(models, ensemble_config, names=names))

    # The device is read from the last experiment config seen in the loop above
    return predict_and_process(
        dataset,
        ensemble,
        thresholds_mask,
        thresholds_nms,
        thresholds_conf,
        min_sizes,
        device=cfg.device,
    )


## Merge rles

In [None]:
def merge_rles(df_cort, df_astro, df_shsy5y):
    """
    Merges the predictions of the three stages, picking for each image the
    encodings of the stage associated to its cell type.

    Args:
        df_cort (pandas dataframe): Has columns "id", "cell_type" and "rle_cort".
        df_astro (pandas dataframe): Has columns "id" and "rle_astro".
        df_shsy5y (pandas dataframe): Has columns "id" and "rle_shsy5y".

    Returns:
        pandas dataframe: df_cort with the "rle_astro" and "rle_shsy5y" columns merged on "id".
        list of lists: Selected rle encodings for each row.
    """
    astro_preds = df_astro[['id', 'rle_astro']]
    shsy5y_preds = df_shsy5y[['id', 'rle_shsy5y']]

    df = (
        df_cort.copy()
        .merge(astro_preds, on="id", how="left")
        .merge(shsy5y_preds, on="id", how="left")
    )

    # Cell type -> column holding its predictions, anything else falls back to cort
    columns_per_cell = {0: 'rle_shsy5y', 1: 'rle_astro'}

    rles = []
    for i in range(len(df)):
        column = columns_per_cell.get(df["cell_type"][i], 'rle_cort')
        rle = df[column][i]

        assert isinstance(rle, list)
        rles.append(rle)

    return df, rles

# Main

## Data

In [None]:
# One row per test image, taken from the sample submission
df = pd.read_csv(DATA_PATH + "sample_submission.csv")
# Path of each test image, built from its id
df['img_path'] = DATA_PATH + "test/" + df['id'] + ".png"

In [None]:
# Completion flags of the three inference stages (cort & classification, astro, shsy5y).
# Each stage sets its flag when done, later cells assert on the previous ones.
SUCCESS = [False, False, False]

In [None]:
# Config of the ensemble model.
# Lists of 3 values are indexed by cell type : 0 = shsy5y, 1 = astro, 2 = cort.
ENSEMBLE_CONFIG = {
    # Whether inference() uses the 'test_tta' pipeline instead of 'test'
    "use_tta": True,
    # NOTE(review): not read by the code of this section - confirm it is still used
    "use_tta_masks": True,
    # Number of class channels kept in the box scores and masks
    "num_classes": 3,

    # RPN proposals settings, per cell type
    "rpn_nms_pre": [5000, 2000, 1000],
    "rpn_iou_threshold": [0.7, 0.75, 0.6],
    "rpn_score_threshold": [0.9, 0.9, 0.95],
    "rpn_max_per_img": [None, None, None],

    # Whether to apply NMS on the merged boxes, and RCNN settings per cell type
    "bbox_nms": True,
    "rcnn_iou_threshold": [0.7, 0.9, 0.6],
    "rcnn_score_threshold": [0.2, 0.25, 0.5],
    
    # Flip directions of the views used for the masks (None is the non-flipped view)
    "ttas_masks": [None, "horizontal", "vertical", "diagonal"],
    # Shift subtracted from the mask rois coordinates along the flipped axes
    "delta": 0.5,
}

# Minimum mask sizes per cell type, 0 disables the filter
# # Best LB :
# MIN_SIZES = [0, 0, 0]
# Best CV :
MIN_SIZES = [0, 150, 75]

## Cort & Classification

In [None]:
# Best lb :
LOG_PATH = "../input/sartorius-cps-ens11/"

# Experiments ensembled for the cort & classification stage
EXP_FOLDERS = [
    LOG_PATH + "2021-12-11_2/",  # 1. Cascade b5 - 0.3121
    LOG_PATH + "2021-12-11_4/",  # 2. Cascade rx101 - 0.3141
    LOG_PATH + "2021-12-12_0/",  # 3. Cascade r50 - 0.3125
    LOG_PATH + "seb_mrcnn_resnext101_lossdecay/", # 11. mrcnn r101 0.3131
    LOG_PATH + "seb_mrcnn_r50_lossdecay/", # 12. mrcnn r50 0.3125
    LOG_PATH + "2021-12-15_0/",  # 14. Cascade b6 - 0.3121
]

THRESHOLDS_MASK = 0.45
THRESHOLDS_NMS = [0.1, 0.1, 0.15]
THRESHOLDS_CONF = [0.3, 0.4, 0.65]
ENSEMBLE_CONFIG['rcnn_score_threshold'] = THRESHOLDS_CONF

# Set explicitly, as in the astro and shsy5y stages, so that re-running this stage
# after a later one does not silently reuse the mask TTA settings left by that stage.
ENSEMBLE_CONFIG["ttas_masks"] = [None, "horizontal", "vertical", "diagonal"]
ENSEMBLE_CONFIG["delta"] = 0.5

In [None]:
# # Best CV :
# LOG_PATH = "../input/sartorius-cps-ens11/"
# LOG_PATH_2 = "../input/sartorius-cps-last/"

# EXP_FOLDERS = [
#     LOG_PATH + "2021-12-11_2/",  # 1. Cascade b5 - 0.3121
#     LOG_PATH + "2021-12-11_4/",  # 2. Cascade rx101 - 0.3141
#     LOG_PATH + "2021-12-15_0/",  # 14. Cascade b6 - 0.3121
#     LOG_PATH_2 + "seb_mrcnn_b5/", # 19. mrcnn b5 - 0.3086
#     LOG_PATH_2 + "2021-12-22_6/",  #  21. htc b4 - 0.3083
#     LOG_PATH_2 + "seb_mrcnn_rx101_decay_bn_flip_aug/",  # 24. mrcnn rx101 - 0.3141
# ]

# THRESHOLDS_MASK = 0.45
# THRESHOLDS_NMS = [0.1, 0.1, 0.15]
# THRESHOLDS_CONF = [0.3, 0.4, 0.65]  # [0.3, 0.4, 0.65]

In [None]:
# Config and weight paths of each experiment of the current stage
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Context manager so that the file handle is closed right after parsing
    with open(exp_folder + "config.json", 'r') as f:
        config = Config(json.load(f))

    # The model and data configs are looked up inside the experiment folder
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    configs.append(config)

    weights.append(sorted(glob.glob(exp_folder + "*.pt")))

In [None]:
%%time
# First stage : predicts every test image, which also gives the cell type of each image
rles_cort, cell_types = inference(df, configs, weights, ENSEMBLE_CONFIG, THRESHOLDS_MASK, THRESHOLDS_NMS, THRESHOLDS_CONF, MIN_SIZES)

df['rle_cort'] = rles_cort
df['cell_type'] = cell_types
# Lets the next stages check that this one completed
SUCCESS[0] = True

In [None]:
# Collect garbage first, so that the memory of the tensors it frees can also be
# released from the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

## Astro

In [None]:
LOG_PATH = "../input/sartorius-cps-ens11/"

# Experiments ensembled for the astro stage
EXP_FOLDERS = [
    LOG_PATH + "2021-12-11_2/",  # 1. Cascade b5 - 0.3121
    LOG_PATH + "2021-12-11_4/",  # 2. Cascade rx101 - 0.3141
    LOG_PATH + "2021-12-12_0/",  # 3. Cascade r50 - 0.3125
    LOG_PATH + "seb_mrcnn_resnext101_lossdecay/", # 11. mrcnn r101 0.3131
    LOG_PATH + "2021-12-15_1/",  # 15. htc r50 - 0.3121
    LOG_PATH + "seb_mrcnn_r101_64x4/",  # 22. mrcnn rx101_64x4 - 0.3127
]

# The same value is used for every cell type index in this stage
THRESHOLDS_MASK = 0.45
THRESHOLDS_NMS = [0.05, 0.05, 0.05]
THRESHOLDS_CONF = [0.45, 0.45, 0.45]
ENSEMBLE_CONFIG['rcnn_score_threshold'] = THRESHOLDS_CONF

# Mask TTA settings : views used for the masks and shift of the flipped rois
ENSEMBLE_CONFIG["ttas_masks"] = [None, "horizontal", "vertical", "diagonal"]
ENSEMBLE_CONFIG["delta"] = 0.5

In [None]:
# Config and weight paths of each experiment of the current stage
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Context manager so that the file handle is closed right after parsing
    with open(exp_folder + "config.json", 'r') as f:
        config = Config(json.load(f))

    # The model and data configs are looked up inside the experiment folder
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    configs.append(config)

    weights.append(sorted(glob.glob(exp_folder + "*.pt")))

In [None]:
%%time
# The cell types predicted by the first stage are required
assert SUCCESS[0]

# Images classified as astro (cell type 1) are predicted again with the astro ensemble
df_astro = df[df['cell_type'] == 1].reset_index()

rles_astro, _ = inference(df_astro, configs, weights, ENSEMBLE_CONFIG, THRESHOLDS_MASK, THRESHOLDS_NMS, THRESHOLDS_CONF, MIN_SIZES)

df_astro['rle_astro'] = rles_astro
SUCCESS[1] = True

In [None]:
# Collect garbage first, so that the memory of the tensors it frees can also be
# released from the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

## Shsy5y

In [None]:
LOG_PATH = "../input/sartorius-cps-ens11/"
LOG_PATH_2 = "../input/sartorius-cps-ens10/"

# Experiments ensembled for the shsy5y stage
EXP_FOLDERS = [
    LOG_PATH + "2021-12-12_0/",  # 3. Cascade r50 - 0.3125
    LOG_PATH_2 + "seb_mrcnn_resnet50_new_splits/", # 8. maskrcnn r50 - 0.3118
    LOG_PATH + "2021-12-15_1/",  # 15. htc r50 - 0.3121
    LOG_PATH_2 + "2021-12-20_1/",  #  16. Cascade rx101_64x4 - 0.3130
    LOG_PATH_2 + "2021-12-22_2/",  #  20. cascade b6 192 crops - 0.3118
    LOG_PATH + "seb_mrcnn_r101_64x4/",  # 22. mrcnn rx101_64x4 - 0.3127
]

# The same value is used for every cell type index in this stage
THRESHOLDS_MASK = 0.45
THRESHOLDS_NMS = [0.1, 0.1, 0.1]
THRESHOLDS_CONF = [0.35, 0.35, 0.35]
ENSEMBLE_CONFIG['rcnn_score_threshold'] = THRESHOLDS_CONF


# Mask TTA settings : only the non-flipped and horizontally flipped views, no roi shift
# Best CV :
ENSEMBLE_CONFIG["ttas_masks"] = [None, "horizontal"]  #, "vertical", "diagonal"]
ENSEMBLE_CONFIG["delta"] = 0.
# # Best LB ? :
# ENSEMBLE_CONFIG["ttas_masks"] = [None, "horizontal", "vertical", "diagonal"]
# ENSEMBLE_CONFIG["delta"] = 0.5

In [None]:
# Config and weight paths of each experiment of the current stage
configs, weights = [], []

for exp_folder in EXP_FOLDERS:
    # Context manager so that the file handle is closed right after parsing
    with open(exp_folder + "config.json", 'r') as f:
        config = Config(json.load(f))

    # The model and data configs are looked up inside the experiment folder
    config.model_config = exp_folder + config.model_config.split('/')[-1]
    config.data_config = exp_folder + config.data_config.split('/')[-1]
    configs.append(config)

    weights.append(sorted(glob.glob(exp_folder + "*.pt")))

In [None]:
%%time
# The astro stage must have completed
assert SUCCESS[1]

# Images classified as shsy5y (cell type 0) are predicted again with the shsy5y ensemble
df_shsy5y = df[df['cell_type'] == 0].reset_index()

rles_shsy5y, _ = inference(df_shsy5y, configs, weights, ENSEMBLE_CONFIG, THRESHOLDS_MASK, THRESHOLDS_NMS, THRESHOLDS_CONF, MIN_SIZES)

df_shsy5y['rle_shsy5y'] = rles_shsy5y
SUCCESS[2] = True

In [None]:
# Collect garbage first, so that the memory of the tensors it frees can also be
# released from the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

## Sub + viz

In [None]:
# The three stages must have completed
assert all(SUCCESS)

# The merged frame gets its own name : merging into `df` itself would make this cell
# fail when re-run, as `df` would already contain the merged columns.
df_merged, rles = merge_rles(df, df_astro, df_shsy5y)

submission = []
pipelines = define_pipelines(configs[0].data_config)
dataset = SartoriusInferenceDataset(df_merged, transforms=pipelines['test_viz'], precompute_masks=False)

for idx, (rle, img_id) in enumerate(zip(rles, df_merged['id'].values)):
    # Visual check of the first images, skipped when there is no mask to show
    if idx < 3 and len(rle):
        img = dataset[idx]['img'][0].numpy().transpose(1, 2, 0)
        img = (img - img.min()) / (img.max() - img.min())
        img = img[:ORIG_SIZE[0], :ORIG_SIZE[1]]

        masks = np.array([rle_decode(enc, ORIG_SIZE) for enc in rle])

        # Masks must not overlap
        assert masks.sum(0).max() <= 1

        plt.figure(figsize=(15, 15))
        plot_sample(img, masks.astype(int))
        plt.axis(False)
        plt.show()

    for enc in rle:
        submission.append((img_id, enc))

    # An image without prediction still gets a row, with an empty encoding
    if not len(rle):  # Empty
        submission.append((img_id, ""))

df_sub = pd.DataFrame(submission, columns=['id', 'predicted'])
df_sub.to_csv("submission.csv", index=False)
df_sub.head()

Done !