In [None]:
!pip install /kaggle/input/mmdetectionv2140/addict-2.4.0-py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/yapf-0.31.0-py2.py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/terminal-0.4.0-py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/terminaltables-3.1.0-py3-none-any.whl > /dev/null
!pip install /kaggle/input/mmdetectionv2140/pycocotools-2.0.2/pycocotools-2.0.2 > /dev/null
!pip install /kaggle/input/mmdetectionv2140/mmpycocotools-12.0.3/mmpycocotools-12.0.3 > /dev/nullnull
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')


!pip install ../input/mmdetection/mmcv_full-1.3.10-cp37-cp37m-linux_x86_64.whl > /dev/null
!pip install ../input/mmdetection/mmdet-2.15.0-py3-none-any.whl > /dev/null

In [None]:
import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm
from itertools import product
import pycocotools.mask as mutils
from pycocotools.coco import COCO

from mmdet.apis.inference import inference_detector, init_detector

def mask2rle(msk):
    """Encode a mask as a space-separated run-length string.

    The mask is flattened in row-major order and pixels are numbered from 1.
    A zero is padded on both ends, and every position where the value changes
    marks a run boundary. The result alternates ``start length`` for each run.
    For a binary mask these are the runs of non-zero pixels. An all-zero mask
    encodes to the empty string.
    """
    edge = np.array([0])
    padded = np.concatenate((edge, msk.flatten(), edge))
    # 1-based indices where the value differs from the previous pixel.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Boundaries come in (run start, run end) pairs; turn each end into a length.
    pairs = boundaries.reshape(-1, 2)
    pairs[:, 1] -= pairs[:, 0]
    return " ".join(map(str, pairs.ravel()))

def rle2mask(rle, shape = [520, 704]):
    """Decode a run-length string into a uint8 mask of the given shape.

    ``rle`` holds space-separated, 1-based ``start length`` pairs over the
    row-major flattened image. Covered pixels are set to 1 and all others
    stay 0. An empty string yields an all-zero mask.
    """
    tokens = rle.split()
    first_pixels = np.asarray(tokens[0::2], dtype = int) - 1   # to 0-based
    run_lengths = np.asarray(tokens[1::2], dtype = int)
    last_pixels = first_pixels + run_lengths                   # exclusive ends

    flat = np.zeros(shape[0] * shape[1], dtype = np.uint8)
    for begin, end in zip(first_pixels, last_pixels):
        flat[begin:end] = 1
    return flat.reshape(shape)



# Test images to predict on. glob() returns files in arbitrary,
# filesystem-dependent order; sort so runs (and submission row order)
# are reproducible.
img_files = sorted(glob.glob("../input/sartorius-cell-instance-segmentation/test/*.*"))

# mmdet config and checkpoint of the detector, read from an attached dataset
# (an HTC-style model, judging only by the file name — confirm).
small_config = "../input/sartorius-submission/htcr2101_1x_4sc_d2_800_all.py"
small_ckpt = "../input/sartorius-submission/htcr2101_1x_4sc_d2_800_all.pth"

# Test-time overrides applied on top of the config file.
small_model = init_detector(small_config, small_ckpt, cfg_options = {
    "model.test_cfg.rpn.nms_pre": 1000,
    # NOTE(review): "weighted_cluster_nms" / "iou_method" must be provided by
    # the installed mmcv/mmdet build; they are not defined in this notebook.
    "model.test_cfg.rcnn.nms.type": "weighted_cluster_nms",
    "model.test_cfg.rcnn.nms.iou_method": "diou",
    "model.test_cfg.rcnn.nms.iou_threshold": 0.45
})

# Per-class settings, indexed by the class id selected for each image in the
# inference cell: minimum detection score, and minimum mask area in pixels.
THRESHOLDS_small = [0.4, 0.45, 0.7]
MIN_PIXELS = [80, 150, 60]
small_ids = [0,1,2]
print(len(img_files))

In [None]:
def _tile_window(i, j, H, W):
    """Return the (row, col) slices of tile (i, j) of the 4x4 sliding grid.

    Each tile spans 2/5 of the image height and width. Tiles are offset by
    1/5, so neighbouring tiles overlap by half a tile. Uses integer division,
    exactly as the crop arithmetic did before being factored out.
    """
    return (slice(i * H // 5, (i + 2) * H // 5),
            slice(j * W // 5, (j + 2) * W // 5))


sub = []
for img_file in tqdm(img_files):
    img_id = os.path.basename(img_file).split(".")[0]
    img = cv2.imread(img_file)
    H, W = img.shape[:2]
    annotations = []

    # 1) Run the detector on every tile of the overlapping 4x4 grid.
    dets = []
    for i, j in product(range(4), range(4)):
        small_det = inference_detector(small_model, img[_tile_window(i, j, H, W)])
        dets.append([i, j, small_det])

    # 2) Keep only the class with the most raw (unthresholded) boxes summed
    #    over all tiles. Every other class is ignored for this image
    #    (presumably because an image contains a single cell type).
    cnt_per_class = np.array([[len(_) for _ in det] for i, j, (det, _) in dets]).sum(0).tolist()
    class_id = cnt_per_class.index(max(cnt_per_class))

    # 3) Filter each tile's detections and RLE-encode the surviving masks.
    for i, j, small_det in dets:
        small_box, small_seg = small_det

        # Unwrap one extra level of list nesting around the per-class mask
        # lists when present (NOTE(review): the layout presumably depends on
        # the mmdet head in use — confirm).
        if isinstance(small_seg, list) and isinstance(small_seg[0], list) and len(small_seg[0]) != 0 and isinstance(small_seg[0][0], list):
            small_seg = small_seg[0]

        small_box = small_box[class_id]
        small_seg = small_seg[class_id]

        # Score threshold plus an ownership test. On sides where the tile is
        # not at the image border, a box is kept only if its centre lies in
        # the tile's middle band [H/10, 3H/10] x [W/10, 3W/10] (tile
        # coordinates). Tiles are offset by H/5 (W/5), so these bands roughly
        # partition the image, and overlapping tiles normally do not report
        # the same object twice.
        centre_y = small_box[:, [1, 3]].mean(1)
        centre_x = small_box[:, [0, 2]].mean(1)
        valid = (small_box[:, -1] > THRESHOLDS_small[class_id]) & \
            ~((i <= 2) & (centre_y > 3 * H / 10)) & \
            ~((i >= 1) & (centre_y < H / 10)) & \
            ~((j <= 2) & (centre_x > 3 * W / 10)) & \
            ~((j >= 1) & (centre_x < W / 10))
        small_box = small_box[valid]
        small_seg = [seg for seg, keep in zip(small_seg, valid) if keep]

        for box, seg in zip(small_box, small_seg):
            if seg.sum() < MIN_PIXELS[class_id]:
                continue
            x1, y1, x2, y2, s = [float(v) for v in box]
            # pycocotools' encode() expects a Fortran-ordered uint8 array;
            # cast explicitly because detector masks may come back as bool.
            rle = mutils.encode(np.asfortranarray(seg, dtype = np.uint8))
            annotations.append([[x1, y1, x2 - x1, y2 - y1], rle, s, i, j])

    # 4) Paint masks into a label image, highest score first. A mask is
    #    skipped if, after removing pixels already claimed by earlier masks,
    #    any of these holds:
    #    - it is smaller than MIN_PIXELS;
    #    - it splits into several connected pieces;
    #    - it has lost more than 20% of its area.
    annotations = sorted(annotations, key = lambda ann: ann[2], reverse = True)
    masks = np.zeros((H, W), dtype = np.uint)
    mask_idx = 1

    for _bbox, rle, _score, i, j in annotations:
        mask = mutils.decode(rle)
        window = masks[_tile_window(i, j, H, W)]  # a view: writes land in `masks`

        assign = (mask != 0) & (window == 0)
        assign_area = assign.sum()
        if assign_area < MIN_PIXELS[class_id]:
            continue
        # connectedComponents counts the background as label 0, so a value
        # above 2 means the unclaimed pixels form several disjoint pieces.
        num_connected, _labels = cv2.connectedComponents(assign.astype(np.uint8))
        if num_connected > 2:
            continue
        overlap_ratio = 1 - assign_area / mask.sum()
        if overlap_ratio > 0.2:
            continue
        window[assign] = mask_idx
        mask_idx += 1

    # 5) One RLE row per painted instance. An image with no surviving instance
    #    gets a single "0 1" placeholder row so its id still appears
    #    (NOTE(review): confirm the placeholder format the scorer accepts).
    if mask_idx > 1:
        for idx in range(1, mask_idx):
            sub.append([img_id, mask2rle((masks == idx).astype(np.uint8))])
    else:
        # Each row must be a single [id, rle] list.
        sub.append([img_id, "0 1"])

sub_df = pd.DataFrame(sub, columns = ['id', 'predicted'])
sub_df.to_csv('submission.csv', index = False)
sub_df.head()

In [None]:
# Best-effort visual sanity check of the first three predicted images.
# Top row: CLAHE-enhanced input. Bottom row: predicted instances overlaid in
# random colours. Any failure is reported, but it never stops the notebook.
try:
    import matplotlib.pyplot as plt
    import albumentations as A

    rng = np.random.default_rng(0)  # fixed seed -> reproducible overlay colours
    # Exact id -> path lookup; a substring test could match the wrong file
    # when one image id is contained in another.
    id_to_file = {os.path.basename(f).split(".")[0]: f for f in img_files}

    fig, ax = plt.subplots(2, 3, figsize = (15, 10))

    for i, img_id in enumerate(sub_df.id.unique()[:3]):
        # cv2 loads BGR; matplotlib (and albumentations' CLAHE) expect RGB.
        img = cv2.cvtColor(cv2.imread(id_to_file[img_id]), cv2.COLOR_BGR2RGB)
        ax[0][i].imshow(A.CLAHE(p = 1)(image = img)["image"])
        ax[0][i].set_title(img_id)
        ax[0][i].axis("off")
        for rle in sub_df.loc[sub_df.id == img_id, "predicted"]:
            mask = rle2mask(rle, img.shape[:2])
            img[mask != 0] = img[mask != 0] // 2 + rng.integers(0, 256, 3) // 2
        ax[1][i].imshow(img)
        ax[1][i].axis("off")

    # plt.show() takes no figure argument; it displays all open figures.
    plt.show()
    plt.close(fig)
except Exception as exc:  # visualisation must never break the submission run
    print(f"Visualisation skipped: {exc!r}")