In [None]:
!pip install /kaggle/input/mmdetectionv2140/addict-2.4.0-py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/yapf-0.31.0-py2.py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/terminal-0.4.0-py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/terminaltables-3.1.0-py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/pycocotools-2.0.2/pycocotools-2.0.2 > /dev/null
!pip install /kaggle/input/mmdetectionv2140/mmpycocotools-12.0.3/mmpycocotools-12.0.3 > /dev/nullnull
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')


!pip install ../input/mmdetection/mmcv_full-1.3.10-cp37-cp37m-linux_x86_64.whl > /dev/null
!pip install ../input/mmdetection/mmdet-2.15.0-py3-none-any.whl > /dev/null

In [None]:
import os
import cv2
import glob
import json
import copy
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
from itertools import product
import pycocotools.mask as mutils
from pycocotools.coco import COCO

import torch
from mmcv.ops import nms
from mmdet.core import bbox_mapping_back, merge_aug_proposals, multiclass_nms, bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes, merge_aug_masks
from mmdet.apis.inference import init_detector, replace_ImageToTensor, Compose, collate, scatter

def mask2rle(msk):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order. The result is a space-separated
    string of ``start length`` pairs with 1-indexed starts. An all-zero mask
    gives an empty string.
    """
    # Zero-pad both ends so every run of ones has a rising and a falling edge.
    padded = np.concatenate([[0], msk.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    run_starts, run_ends = edges[0::2], edges[1::2]
    tokens = []
    for start, length in zip(run_starts, run_ends - run_starts):
        tokens.append(str(start))
        tokens.append(str(length))
    return ' '.join(tokens)

def rle2mask(rle, shape = [520, 704]):
    """Decode a ``start length ...`` run-length string into a uint8 mask.

    Starts are 1-indexed into the row-major flattened mask of size
    ``shape[0] * shape[1]``. The flat array is reshaped to ``shape`` before it
    is returned.
    """
    tokens = rle.split()
    begins = np.asarray(tokens[0::2], dtype = int) - 1
    stops = begins + np.asarray(tokens[1::2], dtype = int)
    flat = np.zeros(shape[0] * shape[1], dtype = np.uint8)
    for begin, stop in zip(begins, stops):
        flat[begin:stop] = 1
    return flat.reshape(shape)



# Ensemble members: configs[k] is loaded with checkpoint ckpts[k].
configs = [
    "../input/sartorius-submission/r2htc101_1x_hvflip_rot90_tiny_d2_800_all.py",
    "../input/sartorius-submission/htcx101_1x_d2_800_all.py",
    "../input/sartorius-submission/htcr2101_1x_d2_800_f0.py",
]

ckpts = [
    "../input/sartorius-submission/r2htc101_1x_hvflip_rot90_tiny_d2_800_all.pth",
    "../input/sartorius-submission/htcx101_1x_d2_800_all.pth",
    "../input/sartorius-submission/htcr2101_1x_d2_800_f0.pth",
]

# Per-model factors applied to mask probabilities in the ensemble. They are
# rescaled to sum to the number of models, so equal weights become all ones.
model_weights = [
    1,
    1,
    1
]
model_weights = np.array(model_weights) / sum(model_weights) * len(model_weights)

# Test-time overrides shared by every model (dotted keys for merge_from_dict).
TEST_CFG_OVERRIDES = {
    # multi-scale + horizontal/vertical flip test-time augmentation
    "data.test.pipeline.1.img_scale": [(1333,1333),(1024,1024)],
    "data.test.pipeline.1.flip": True,
    "data.test.pipeline.1.flip_direction": ['horizontal','vertical'],
    # Previously spelled "moedl.test_cfg...". merge_from_dict stored that under a
    # new, unused top-level key, so this threshold never reached the model.
    "model.test_cfg.rcnn.score_thr": 0.3,
    "model.test_cfg.rcnn.nms.type": "weighted_cluster_nms",
    "model.test_cfg.rcnn.nms.iou_method": "diou",
    "model.test_cfg.rcnn.nms.iou_threshold": 0.45
}
# One independent copy per model, so merging into one config cannot affect another.
cfg_options = [copy.deepcopy(TEST_CFG_OVERRIDES) for _ in configs]


models = [init_detector(config, ckpt, cfg_options = c) for config, ckpt, c in zip(configs, ckpts, cfg_options)]

# Post-processing parameters, indexed by the class chosen for each image.
THRESHOLDS_small = [0.4, 0.45, 0.7]  # minimum detection score
MIN_PIXELS = [80, 150, 60]           # minimum (unclaimed) mask area in pixels
MAX_OVERLAP = [0.2, 0.2, 0.2]        # max fraction of a mask already claimed by earlier masks
small_ids = [0,1,2]                  # NOTE(review): not referenced elsewhere in this notebook

img_files = glob.glob("../input/sartorius-cell-instance-segmentation/test/*.*")
print(len(img_files))

In [None]:
def extract_feats(models, imgs):
    """Run the test pipeline once, then extract features with every model.

    Preprocessing and test-time augmentations come from ``models[0].cfg``.
    All models therefore receive the same augmented views
    (NOTE(review): assumes the configs share a test pipeline).

    Args:
        models (list[nn.Module]): The loaded detectors. ``models[0]`` provides
            the config and the device.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
           Either image files or loaded images.

    Returns:
        tuple: ``(feats, imgs, img_metas)``.
            ``feats[k]`` is ``models[k].extract_feats(imgs)``.
            ``imgs`` and ``img_metas`` are the collated pipeline outputs
            (scattered to the GPU when the model is on CUDA).
    """
    model = models[0]
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # in-memory arrays: use the loader that accepts an already-decoded image
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # RoIPool is not imported at the top of the notebook. Import it here so
        # this check runs instead of raising NameError.
        from mmcv.ops import RoIPool
        for m in models:
            for module in m.modules():
                assert not isinstance(
                    module, RoIPool
                ), 'CPU inference with RoIPool is not supported currently.'

    # forward every model on the same preprocessed views
    res = []
    for m in models:
        with torch.no_grad():
            res.append(m.extract_feats(data["img"]))
    return res, data["img"], data["img_metas"]

def rpn_head_aug_test(models, res, img_metas):
    """Collect RPN proposals from every model on every augmented view.

    Args:
        models (list): Detectors. ``models[0].rpn_head.test_cfg`` is returned
            for the later merge step.
        res (list): ``res[k][j]`` are the features of view ``j`` from
            ``models[k]``, as returned by ``extract_feats``.
        img_metas (list[list[dict]]): ``img_metas[j][i]`` is the meta of
            sample ``i`` in view ``j``.

    Returns:
        tuple: ``(aug_proposals, aug_img_metas, cfg)``. For every sample ``i``,
        ``aug_proposals[i]`` and ``aug_img_metas[i]`` are parallel lists with
        one entry per (model, view) pair, in model-major order. They are meant
        to be passed to ``merge_aug_proposals`` together.
    """
    samples_per_gpu = len(img_metas[0])
    aug_proposals = [[] for _ in range(samples_per_gpu)]
    aug_img_metas = [[] for _ in range(samples_per_gpu)]
    cfg = models[0].rpn_head.test_cfg

    for model, feats in zip(models, res):
        for x, img_meta in zip(feats, img_metas):
            proposal_list = model.rpn_head.simple_test_rpn(x, img_meta)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)
                # Record the view's meta next to every proposal set. Previously the
                # metas were collected once per model into extra outer entries, so
                # aug_img_metas[i] only covered the first model's views. The zip()
                # in merge_aug_proposals then silently dropped every other model's
                # proposals.
                aug_img_metas[i].append(img_meta[i])
    # Proposals are mapped back to the original image size and merged by the caller.
    return aug_proposals, aug_img_metas, cfg
    
def merge_aug_proposals(aug_proposals, img_metas, transposed, cfg):
    """Map proposals from all augmented views back to the original image and merge them with NMS.

    This is a local variant of mmdet's ``merge_aug_proposals``; this definition
    shadows the one imported at the top. It adds a per-view ``transposed`` flag.

    Args:
        aug_proposals (list[Tensor]): Proposals per view. The box is in the
            first four columns and the score in the last, in that view's frame.
        img_metas (list[dict]): Meta of the view each proposal set came from.
        transposed (list[bool]): Whether the view was computed on the
            transposed image. Those proposals get x/y swapped after being mapped
            back.
        cfg (dict): RPN test config (``nms`` and ``max_per_img``, or the
            deprecated ``nms_thr`` / ``max_num``).

    Returns:
        Tensor: At most ``cfg.max_per_img`` merged proposals (box + score),
        highest score first.
    """
    # Neither name is imported at the top of the notebook; the deprecation
    # branch below would otherwise raise NameError.
    import warnings
    from mmcv import ConfigDict

    cfg = copy.deepcopy(cfg)

    # deprecate arguments warning
    if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
        warnings.warn(
            'In rpn_proposal or test_cfg, '
            'nms_thr has been moved to a dict named nms as '
            'iou_threshold, max_num has been renamed as max_per_img, '
            'name of original arguments and the way to specify '
            'iou_threshold of NMS will be deprecated.')
    if 'nms' not in cfg:
        cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
    if 'max_num' in cfg:
        if 'max_per_img' in cfg:
            assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \
                f'max_per_img at the same time, but get {cfg.max_num} ' \
                f'and {cfg.max_per_img} respectively' \
                f'Please delete max_num which will be deprecated.'
        else:
            cfg.max_per_img = cfg.max_num
    if 'nms_thr' in cfg:
        assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
            f'iou_threshold in nms and ' \
            f'nms_thr at the same time, but get ' \
            f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
            f' respectively. Please delete the nms_thr ' \
            f'which will be deprecated.'

    recovered_proposals = []
    for proposals, img_info, t in zip(aug_proposals, img_metas, transposed):
        img_shape = img_info['img_shape']
        scale_factor = img_info['scale_factor']
        flip = img_info['flip']
        flip_direction = img_info['flip_direction']
        _proposals = proposals.clone()
        # undo resize/flip of this view
        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
                                              scale_factor, flip,
                                              flip_direction)
        if t:
            # the view was the transposed image: (x1, y1, x2, y2) -> (y1, x1, y2, x2)
            _proposals[:,:4] = _proposals[:,:4][:,[1,0,3,2]]
        recovered_proposals.append(_proposals)
    aug_proposals = torch.cat(recovered_proposals, dim=0)
    # only the IoU threshold of cfg.nms is used; the NMS op is always mmcv's nms
    merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
                              aug_proposals[:, -1].contiguous(),
                              cfg.nms.iou_threshold)
    scores = merged_proposals[:, 4]
    _, order = scores.sort(0, descending=True)
    num = min(cfg.max_per_img, merged_proposals.shape[0])
    order = order[:num]
    merged_proposals = merged_proposals[order, :]
    return merged_proposals

def roi_head_aug_test_bbox(models, res, proposal_list, img_metas):
    """Score and refine the merged proposals with every model's cascade on every view.

    Args:
        models (list): Detectors. ``models[0]`` supplies the R-CNN test config
            and the number of classes.
        res (list): ``res[k][j]`` are the features of view ``j`` from
            ``models[k]``.
        proposal_list (list[Tensor]): Merged proposals in original-image
            coordinates. Only ``proposal_list[0]`` (one image) is used.
        img_metas (list[list[dict]]): Per-view metas. Only the first sample of
            each view is read.

    Returns:
        tuple: ``(merged_bboxes, merged_scores, rcnn_test_cfg, num_classes)``.
        Boxes and scores are averaged by ``merge_aug_bboxes`` over all
        (model, view) pairs and mapped back to the original image frame.
    """
    rcnn_test_cfg = models[0].roi_head.test_cfg
    aug_bboxes = []
    aug_scores = []
    # Meta of the view each aug_bboxes entry was predicted on. merge_aug_bboxes
    # zips boxes with metas, so the two lists must stay parallel. Passing the
    # per-view ``img_metas`` (as before) cut the box average down to the first
    # model, while the scores were still averaged over all models.
    aug_img_metas = []
    for model, img_feats in zip(models, res):
        for x, img_meta in zip(img_feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            flip_direction = img_meta[0]['flip_direction']

            # project the shared proposals into this view's resized/flipped frame
            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip, flip_direction)
            rois = bbox2roi([proposals])
            aug_img_metas.append(img_meta)

            if rois.shape[0] == 0:
                # There is no proposal in the single image
                aug_bboxes.append(rois.new_zeros(0, 4))
                aug_scores.append(rois.new_zeros(0, 1))
                continue

            # "ms" in variable names means multi-stage
            ms_scores = []
            for i in range(model.roi_head.num_stages):
                bbox_head = model.roi_head.bbox_head[i]
                bbox_results = model.roi_head._bbox_forward(
                    i, x, rois, semantic_feat=None)
                ms_scores.append(bbox_results['cls_score'])

                if i < model.roi_head.num_stages - 1:
                    # refine the RoIs with this stage's regression before the next stage
                    bbox_label = bbox_results['cls_score'].argmax(dim=1)
                    rois = bbox_head.regress_by_class(
                        rois, bbox_label, bbox_results['bbox_pred'],
                        img_meta[0])

            # class scores are averaged over stages; box deltas come from the last stage
            cls_score = sum(ms_scores) / float(len(ms_scores))
            bboxes, scores = model.roi_head.bbox_head[-1].get_bboxes(
                rois,
                cls_score,
                bbox_results['bbox_pred'],
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)

    # after merging, bboxes will be rescaled to the original image size
    merged_bboxes, merged_scores = merge_aug_bboxes(
        aug_bboxes, aug_scores, aug_img_metas, rcnn_test_cfg)
    # NMS and bbox2result are applied by the caller
    return merged_bboxes, merged_scores, rcnn_test_cfg, models[0].roi_head.bbox_head[-1].num_classes

def roi_head_aug_test_segm(models, model_weights, res, img_metas, det_bboxes, det_labels):
    """Predict mask probabilities for the final detections with every model, view and stage, then merge them.

    Args:
        models (list): Detectors. ``models[0]`` supplies the test config and the
            ``get_seg_masks`` function that is returned.
        model_weights: Per-model factors multiplied into that model's sigmoid
            mask probabilities before merging.
        res (list): ``res[k][j]`` are the features of view ``j`` from
            ``models[k]``.
        img_metas (list[list[dict]]): Per-view metas. Only the first sample of
            each view is read.
        det_bboxes (Tensor): Final detections in original-image coordinates.
            Only the first four columns (the box) are used.
        det_labels (Tensor): Not used here.

    Returns:
        tuple: ``(merged_masks, ori_shape, get_seg_masks, rcnn_test_cfg)``.
        ``merged_masks`` is ``None`` when there are no detections. Otherwise it
        is the ``merge_aug_masks`` result over all (model, view, stage)
        predictions. The caller pastes the masks into the image with
        ``get_seg_masks``.
    """
    rcnn_test_cfg = models[0].roi_head.test_cfg
    if det_bboxes.shape[0] == 0:
        return None, img_metas[0][0]['ori_shape'], models[0].roi_head.mask_head[-1].get_seg_masks, rcnn_test_cfg
        # (the caller builds the empty segm_result in this case)
    else:
        aug_masks = []
        # one meta per entry of aug_masks, so merge_aug_masks can pair them
        aug_img_metas = []
        for model, img_feats, model_weight in zip(models, res, model_weights):
            for x, img_meta in zip(img_feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                flip_direction = img_meta[0]['flip_direction']
                # map the final boxes into this view's resized/flipped frame
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, scale_factor, flip, flip_direction)
                mask_rois = bbox2roi([_bboxes])
                # one RoI feature tensor, from the last stage's extractor, is fed to every stage's mask head
                mask_feats = model.roi_head.mask_roi_extractor[-1](
                    x[:len(model.roi_head.mask_roi_extractor[-1].featmap_strides)],
                    mask_rois)
                last_feat = None
                for i in range(model.roi_head.num_stages):
                    mask_head = model.roi_head.mask_head[i]
                    if model.roi_head.mask_info_flow:
                        # each stage also receives the previous stage's mask feature
                        mask_pred, last_feat = mask_head(
                            mask_feats, last_feat)
                    else:
                        mask_pred = mask_head(mask_feats)
                    # keep every stage's prediction; all of them are merged below
                    aug_masks.append(mask_pred.sigmoid().cpu().numpy() * model_weight)
                    aug_img_metas.append(img_meta)
        # mmdet's merge_aug_masks presumably un-flips each prediction using its meta and averages them
        merged_masks = merge_aug_masks(aug_masks, aug_img_metas, models[0].roi_head.test_cfg)
        return merged_masks, img_metas[0][0]['ori_shape'], models[0].roi_head.mask_head[-1].get_seg_masks, rcnn_test_cfg

def inference_detectors(models, model_weights, img_cut):
    """Ensemble inference (all models x all test-time views) on one image.

    Args:
        models: Detectors. ``models[0]`` provides the pipeline and test configs.
        model_weights: Per-model factors applied to the mask probabilities.
        img_cut: Image array to run on.

    Returns:
        ``(bbox_result, segm_result)`` in mmdet's per-class list layout. When no
        box survives NMS, ``segm_result`` is wrapped in one extra list level.
    """
    feats, _imgs, metas = extract_feats(models, img_cut)

    # RPN: proposals from every (model, view) pair, merged per image.
    per_img_props, per_img_metas, rpn_cfg = rpn_head_aug_test(models, feats, metas)
    proposal_list = []
    for props, prop_metas in zip(per_img_props, per_img_metas):
        not_transposed = [False] * len(props)
        proposal_list.append(merge_aug_proposals(props, prop_metas, not_transposed, rpn_cfg))

    # R-CNN: averaged boxes/scores, then class-wise NMS.
    bboxes, scores, rcnn_cfg, n_classes = roi_head_aug_test_bbox(models, feats, proposal_list, metas)
    det_bboxes, det_labels = multiclass_nms(
        torch.cat([bboxes]),
        torch.cat([scores]),
        rcnn_cfg.score_thr,
        rcnn_cfg.nms,
        rcnn_cfg.max_per_img)
    bbox_result = bbox2result(det_bboxes, det_labels, n_classes)

    # Masks for the surviving detections.
    merged_masks, ori_shape, get_seg_masks, mask_cfg = roi_head_aug_test_segm(
        models, model_weights, feats, metas, det_bboxes, det_labels)
    if merged_masks is not None:
        segm_result = get_seg_masks(merged_masks,
                                    det_bboxes,
                                    det_labels,
                                    mask_cfg,
                                    ori_shape,
                                    scale_factor=1.0,
                                    rescale=False)
    else:
        segm_result = [[[] for _ in range(models[0].roi_head.mask_head[-1].num_classes)]]
    return bbox_result, segm_result

def _swap_box_xy(boxes):
    """Swap x and y inside every (x1, y1, x2, y2) group of an (N, 4 * k) array or tensor.

    Covers class-agnostic (k == 1) and class-specific (k == num_classes) box
    layouts. The old ``[:, [1, 0, 3, 2]]`` indexing only handled k == 1.
    """
    n, width = boxes.shape[0], boxes.shape[1]
    return boxes.reshape(n, width // 4, 4)[..., [1, 0, 3, 2]].reshape(n, width)


def inference_detectors_transpose(models, model_weights, img_cut):
    """Like ``inference_detectors``, but also runs the ensemble on the transposed image and fuses both.

    Proposals from both orientations are merged in original coordinates. Each
    orientation then refines them separately. Boxes from both go through one
    joint NMS, and the mask probabilities of the two orientations are averaged.
    NOTE(review): the submission loop in this notebook calls
    ``inference_detectors`` instead.

    Returns:
        ``(bbox_result, segm_result)`` in the same layout as ``inference_detectors``.
    """
    # contiguous copy of the H<->W transposed image before it enters the pipeline
    img_t = np.ascontiguousarray(img_cut.transpose(1, 0, 2))
    feats, _, metas = extract_feats(models, img_cut)
    feats_t, _, metas_t = extract_feats(models, img_t)

    props, prop_metas, rpn_cfg = rpn_head_aug_test(models, feats, metas)
    props_t, prop_metas_t, _ = rpn_head_aug_test(models, feats_t, metas_t)
    proposal_list = [
        merge_aug_proposals(p + p_t, m + m_t, [False] * len(p) + [True] * len(p_t), rpn_cfg)
        for p, p_t, m, m_t in zip(props, props_t, prop_metas, prop_metas_t)
    ]
    # merged proposals are in original coordinates; the transposed pass needs them swapped
    proposal_list_t = [p[:, [1, 0, 3, 2, 4]] for p in proposal_list]

    bboxes, scores, rcnn_cfg, num_classes = roi_head_aug_test_bbox(models, feats, proposal_list, metas)
    bboxes_t, scores_t, _, _ = roi_head_aug_test_bbox(models, feats_t, proposal_list_t, metas_t)
    det_bboxes, det_labels = multiclass_nms(torch.cat([bboxes, _swap_box_xy(bboxes_t)]),
                                            torch.cat([scores, scores_t]),
                                            rcnn_cfg.score_thr, rcnn_cfg.nms, rcnn_cfg.max_per_img)
    bbox_result = bbox2result(det_bboxes, det_labels, num_classes)

    masks, ori_shape, get_seg_masks, mask_cfg = roi_head_aug_test_segm(
        models, model_weights, feats, metas, det_bboxes, det_labels)
    masks_t, _, _, _ = roi_head_aug_test_segm(
        models, model_weights, feats_t, metas_t, det_bboxes[:, [1, 0, 3, 2, 4]], det_labels)

    # Both calls receive the same detections, so they return None together.
    # Average whatever is available.
    mask_probs = []
    if masks is not None:
        mask_probs.append(masks)
    if masks_t is not None:
        # back from the transposed frame: swap the two spatial axes (assumed (N, C, h, w))
        mask_probs.append(masks_t.transpose(0, 1, 3, 2))
    if not mask_probs:
        segm_result = [[[] for _ in range(models[0].roi_head.mask_head[-1].num_classes)]]
    else:
        segm_result = get_seg_masks(sum(mask_probs) / len(mask_probs),
                                    det_bboxes,
                                    det_labels,
                                    mask_cfg,
                                    ori_shape,
                                    scale_factor=1.0,
                                    rescale=False)
    return bbox_result, segm_result

In [None]:
def _tile_window(i, j, H, W):
    """Row/column slices of tile (i, j): a 2/5 x 2/5 window placed with a 1/5 stride."""
    return (slice(i * H // 5, (i + 2) * H // 5),
            slice(j * W // 5, (j + 2) * W // 5))


def _central_boxes(small_box, i, j, H, W, score_thr):
    """Boolean mask over rows of ``small_box`` (x1, y1, x2, y2, score in tile coordinates).

    A box is kept when its score exceeds ``score_thr`` and its centre lies in
    the band [H/10, 3H/10] (and [W/10, 3W/10]) of the tile. Tiles on the image
    border (index 0 or 3) extend that band to the tile edge.
    """
    cy = small_box[:, [1, 3]].mean(1)
    cx = small_box[:, [0, 2]].mean(1)
    return (small_box[:, -1] > score_thr) & \
        ~((i <= 2) & (cy > 3 * H / 10)) & \
        ~((i >= 1) & (cy < H / 10)) & \
        ~((j <= 2) & (cx > 3 * W / 10)) & \
        ~((j >= 1) & (cx < W / 10))


sub = []
for img_file in tqdm(img_files):
    img_id = os.path.basename(img_file).split(".")[0]
    img = cv2.imread(img_file)
    H, W = img.shape[:2]
    annotations = []

    # 1) Ensemble inference on a 4x4 grid of overlapping tiles.
    dets = []
    for i, j in product(range(4), range(4)):
        rows, cols = _tile_window(i, j, H, W)
        with torch.no_grad():
            small_det = inference_detectors(models, model_weights, img[rows, cols])
        dets.append([i, j, small_det])

    # 2) Use one class for the whole image: the one with most detections over all tiles.
    cnt_per_class = np.array([[len(cls_boxes) for cls_boxes in box_res]
                              for _i, _j, (box_res, _segs) in dets]).sum(0).tolist()
    class_id = cnt_per_class.index(max(cnt_per_class))

    # 3) Keep confident boxes centred in each tile's own band with large enough masks.
    for i, j, small_det in dets:
        small_box, small_seg = small_det

        # inference_detectors wraps the empty segm result in one extra list level
        if isinstance(small_seg, list) and isinstance(small_seg[0], list) and len(small_seg[0]) != 0 and isinstance(small_seg[0][0], list):
            small_seg = small_seg[0]

        small_box = small_box[class_id]
        small_seg = small_seg[class_id]

        valid = _central_boxes(small_box, i, j, H, W, THRESHOLDS_small[class_id])
        small_box = small_box[valid]
        small_seg = [seg for seg, keep in zip(small_seg, valid) if keep]

        for box, seg in zip(small_box, small_seg):
            if seg.sum() < MIN_PIXELS[class_id]:
                continue
            x1, y1, x2, y2, s = [float(v) for v in box]
            rle = mutils.encode(np.asfortranarray(seg))
            annotations.append([[x1, y1, x2 - x1, y2 - y1], rle, s, i, j])

    # 4) Paint masks by decreasing score; a pixel belongs to the first mask that claims it.
    annotations = sorted(annotations, key = lambda ann: -ann[2])
    masks = np.zeros((H, W), dtype = np.uint); mask_idx = 1

    for bbox, rle, s, i, j in annotations:
        mask = mutils.decode(rle)
        rows, cols = _tile_window(i, j, H, W)
        tile_masks = masks[rows, cols]  # view: writes go straight into `masks`

        assign = (mask != 0) & (tile_masks == 0)
        assign_area = assign.sum()
        if assign_area < MIN_PIXELS[class_id]:
            continue
        # reject masks whose unclaimed part is split (background + more than one component)
        num_connected, _ = cv2.connectedComponents(assign.astype(np.uint8))
        if num_connected > 2:
            continue
        overlap_ratio = 1 - assign_area / mask.sum()
        if overlap_ratio > MAX_OVERLAP[class_id]:
            continue
        tile_masks[assign] = mask_idx
        mask_idx += 1

    # 5) One RLE row per kept instance; "0 1" placeholder when nothing was kept.
    if mask_idx > 1:
        for idx in range(1, mask_idx):
            sub.append([img_id, mask2rle((masks == idx).astype(np.uint8))])
    else:
        sub.append([img_id, "0 1"])

sub_df = pd.DataFrame(sub, columns = ['id', 'predicted'])
sub_df.to_csv('submission.csv', index = False)
# last expression of the cell, so the preview is actually displayed
sub_df.head()

In [None]:
# Best-effort preview of the first three submitted images. The top row shows the
# CLAHE-enhanced input; the bottom row overlays each predicted instance in a
# random colour. Failures are reported but never stop the notebook.
try:
    import matplotlib.pyplot as plt
    import albumentations as A

    fig, ax = plt.subplots(2, 3, figsize = (30, 20))

    for i, img_id in enumerate(sub_df.id.unique()[:3]):
        # exact id match (a substring test could pick the wrong file)
        img_path = next(p for p in img_files if os.path.basename(p).split(".")[0] == img_id)
        img = cv2.imread(img_path)
        ax[0][i].imshow(A.CLAHE(p = 1)(image = img)["image"])
        ax[0][i].axis("off")
        for rle in sub_df.loc[sub_df.id == img_id, "predicted"]:
            mask = rle2mask(rle, img.shape[:2])
            img[mask != 0] = img[mask != 0] // 2 + np.random.randint(0, 256, 3) // 2
        ax[1][i].imshow(img)
        ax[1][i].axis("off")

    # plt.show() takes no figure argument; the figure is closed explicitly below
    plt.show()
    plt.close(fig)
except Exception as exc:  # preview only: report instead of silently swallowing
    print(f"Preview skipped: {type(exc).__name__}: {exc}")