# **Weighted Cluster-DIoU-NMS**

### May improve the leaderboard (LB) score by about +0.2

# **Previous Notebooks**

https://www.kaggle.com/vexxingbanana/sartorius-coco-dataset-notebook

https://www.kaggle.com/vexxingbanana/sartorius-mmdetection-training

https://www.kaggle.com/vexxingbanana/mmdetection-neuron-inference

# **References**

https://www.kaggle.com/awsaf49/sartorius-mmdetection-infer

https://github.com/Zzh-tju/CIoU

# **Install MMDetection**

In [None]:
# Install PyTorch 1.7.0 (cu110, cp37 wheels) together with torchvision 0.8.1 and
# torchaudio 0.7.0 from local wheel files; --no-deps keeps pip from resolving or
# installing any dependencies on its own.
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torch-1.7.0+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torchvision-0.8.1+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torchaudio-0.7.0-cp37-cp37m-linux_x86_64.whl' --no-deps

In [None]:
# Install the supporting packages (addict, yapf, terminal, terminaltables, mmcv-full,
# pycocotools, mmpycocotools) from the local mmdetectionv2140 input directory;
# --no-deps keeps pip from resolving or installing dependencies on its own.
!pip install '/kaggle/input/mmdetectionv2140/addict-2.4.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/yapf-0.31.0-py2.py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/terminal-0.4.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/terminaltables-3.1.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/mmcv_full-1_3_8-cu110-torch1_7_0/mmcv_full-1.3.8-cp37-cp37m-manylinux1_x86_64.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/pycocotools-2.0.2/pycocotools-2.0.2' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/mmpycocotools-12.0.3/mmpycocotools-12.0.3' --no-deps

# Remove any mmdetection directory left over in the current working directory.
!rm -rf mmdetection

# Copy the edited MMDetection sources into /kaggle/working, rename the copy to
# "mmdetection", and make it the current directory.
!cp -r ../input/edited-mmdetection /kaggle/working/
!mv /kaggle/working/edited-mmdetection /kaggle/working/mmdetection
%cd /kaggle/working/mmdetection

In [None]:
%%writefile /kaggle/working/mmdetection/mmdet/core/post_processing/bbox_nms.py

import torch
from mmcv.ops.nms import batched_nms

from mmdet.core.bbox.iou_calculators import bbox_overlaps

def intersect(box_a, box_b):
    """Return the pairwise intersection areas of two batched box sets.

    Args:
        box_a (Tensor): boxes of shape (n, A, 4) in (x1, y1, x2, y2) order.
        box_b (Tensor): boxes of shape (n, B, 4) in (x1, y1, x2, y2) order.

    Returns:
        Tensor: shape (n, A, B); entry [k, i, j] is the overlap area of
            ``box_a[k, i]`` and ``box_b[k, j]`` (0 when they do not overlap).
    """
    pair_shape = (box_a.size(0), box_a.size(1), box_b.size(1), 2)

    # Bottom-right corner of each overlap: element-wise minimum of the two
    # bottom-right corners.
    lower_right = torch.min(
        box_a[:, :, 2:].unsqueeze(2).expand(*pair_shape),
        box_b[:, :, 2:].unsqueeze(1).expand(*pair_shape))
    # Top-left corner of each overlap: element-wise maximum of the two
    # top-left corners.
    upper_left = torch.max(
        box_a[:, :, :2].unsqueeze(2).expand(*pair_shape),
        box_b[:, :, :2].unsqueeze(1).expand(*pair_shape))

    # Negative extents mean "no overlap", so clamp them to zero.
    overlap_wh = torch.clamp(lower_right - upper_left, min=0)
    return overlap_wh[:, :, :, 0] * overlap_wh[:, :, :, 1]

def diou(box_a, box_b, beta=1.0, iscrowd:bool=False):
    """Compute the pairwise Distance-IoU matrix between two sets of boxes.

    For every pair the result is ``IoU - D ** beta`` where ``D`` is the squared
    distance between the box centres divided by the squared diagonal of the
    smallest box enclosing both. With ``iscrowd=True`` the result is instead
    ``intersection / area(box_a)`` and the distance penalty is not applied.

    Args:
        box_a (Tensor): boxes (x1, y1, x2, y2) of shape (A, 4) or (n, A, 4).
        box_b (Tensor): boxes (x1, y1, x2, y2) of shape (B, 4) or (n, B, 4).
        beta (float): exponent applied to the centre-distance penalty.
        iscrowd (bool): switch to the intersection-over-area-of-a variant.

    Returns:
        Tensor: shape (A, B) for 2-D inputs, (n, A, B) for batched inputs.
    """
    batched = box_a.dim() != 2
    if not batched:
        # Add a leading batch axis so both input layouts share one code path.
        box_a = box_a[None, ...]
        box_b = box_b[None, ...]

    inter = intersect(box_a, box_b)

    def per_a(values):
        # Broadcast a per-box_a quantity along the box_b axis.
        return values.unsqueeze(2).expand_as(inter)

    def per_b(values):
        # Broadcast a per-box_b quantity along the box_a axis.
        return values.unsqueeze(1).expand_as(inter)

    a_left, a_top, a_right, a_bottom = (box_a[:, :, k] for k in range(4))
    b_left, b_top, b_right, b_bottom = (box_b[:, :, k] for k in range(4))

    area_a = per_a((a_right - a_left) * (a_bottom - a_top))
    area_b = per_b((b_right - b_left) * (b_bottom - b_top))
    union = area_a + area_b - inter

    # Box centres.
    centre_ax = per_a((a_right + a_left) / 2)
    centre_ay = per_a((a_bottom + a_top) / 2)
    centre_bx = per_b((b_right + b_left) / 2)
    centre_by = per_b((b_bottom + b_top) / 2)

    # Smallest box enclosing each pair.
    enclose_right = torch.max(per_a(a_right), per_b(b_right))
    enclose_left = torch.min(per_a(a_left), per_b(b_left))
    enclose_top = torch.min(per_a(a_top), per_b(b_top))
    enclose_bottom = torch.max(per_a(a_bottom), per_b(b_bottom))

    centre_dist_sq = (centre_bx - centre_ax)**2 + (centre_by - centre_ay)**2
    # 1e-7 keeps the division finite when the enclosing box is degenerate.
    diagonal_sq = ((enclose_right - enclose_left)**2
                   + (enclose_bottom - enclose_top)**2 + 1e-7)
    D = centre_dist_sq / diagonal_sq

    if iscrowd:
        out = inter / area_a
    else:
        out = inter / union - D ** beta

    if batched:
        return out
    return out.squeeze(0)

def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None,
                   return_inds=False):
    """Weighted Cluster-DIoU-NMS for multi-class bboxes.

    Boxes scoring above ``score_thr`` are sorted by score, suppressed with an
    iterative (cluster) NMS on the pairwise DIoU matrix, and every box is
    replaced by a score-weighted average of the boxes it overlaps strongly.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class + 1), where the last column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        nms_cfg (dict): NMS configuration; only ``nms_cfg['iou_threshold']``
            (the DIoU suppression threshold) is read.
        max_num (int, optional): if there are more than max_num bboxes after
            NMS, only top max_num will be kept. Default to -1.
        score_factors (Tensor, optional): The factors multiplied to scores
            before applying NMS. Default to None.
        return_inds (bool, optional): Whether return the indices of kept
            bboxes. Default to False.

    Returns:
        tuple: (dets, labels, indices (optional)), tensors of shape (k, 5),
            (k), and (k). Dets are boxes with scores. Labels are 0-based.

    Raises:
        RuntimeError: during ONNX export when no box is left to process.
    """
    iou_thr = nms_cfg['iou_threshold']
    num_classes = multi_scores.size(1) - 1
    # exclude background category
    if multi_bboxes.shape[1] > 4:
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
    else:
        bboxes = multi_bboxes[:, None].expand(
            multi_scores.size(0), num_classes, 4)

    scores = multi_scores[:, :-1]

    # Create the labels on the same device as the scores rather than forcing
    # CUDA, so the function also works on CPU and on non-default GPUs.
    labels = torch.arange(
        num_classes, dtype=torch.long, device=scores.device)
    labels = labels.view(1, -1).expand_as(scores)

    bboxes = bboxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    if not torch.onnx.is_in_onnx_export():
        # NonZero not supported  in TensorRT
        # remove low scoring boxes
        valid_mask = scores > score_thr
    # multiply score_factor after threshold to preserve more bboxes, improve
    # mAP by 1% for YOLOv3
    if score_factors is not None:
        # expand the shape to match original shape of score
        score_factors = score_factors.view(-1, 1).expand(
            multi_scores.size(0), num_classes)
        score_factors = score_factors.reshape(-1)
        scores = scores * score_factors

    if not torch.onnx.is_in_onnx_export():
        # NonZero not supported  in TensorRT
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    else:
        # TensorRT NMS plugin has invalid output filled with -1
        # add dummy data to make detection output correct.
        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0)
        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
        # No score filtering happens on this path, so every (dummy-padded)
        # entry simply keeps its own position as index. Without this `inds`
        # would be undefined further down.
        inds = torch.arange(scores.size(0), device=scores.device)

    if bboxes.numel() == 0:
        if torch.onnx.is_in_onnx_export():
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        dets = torch.cat([bboxes, scores[:, None]], -1)
        if return_inds:
            return dets, labels, inds
        else:
            return dets, labels

    # Weighted Cluster-DIoU-NMS
    scores, idx = scores.sort(0, descending=True)
    bboxes = bboxes[idx]
    labels = labels[idx]
    inds = inds[idx]
    # Shift every class by a multiple of 4000 so boxes of different classes
    # never overlap (assumes box coordinates stay below 4000).
    box = bboxes + labels.unsqueeze(1).expand_as(bboxes) * 4000

    diou_matrix = diou(box, box, 0.8)    # DIoU matrix
    # Strict upper triangle: entry [i, j] relates box j to a higher-scored
    # box i.
    iou = diou_matrix.triu(diagonal=1)
    B = iou
    for _ in range(999):
        A = B
        # Largest DIoU of each box with any higher-scored, still active box.
        maxA = A.max(dim=0)[0]
        # Zero the rows of boxes currently marked as suppressed so that they
        # no longer suppress other boxes; repeat until nothing changes.
        E = (maxA <= iou_thr).float().unsqueeze(1).expand_as(A)
        B = iou.mul(E)
        if A.equal(B):
            break
    keep = maxA <= iou_thr

    # Weighted box merging: only overlaps with DIoU > 0.7 contribute
    # noticeably; each contribution is scaled by the neighbour's score.
    B = torch.triu(diou_matrix).mul(E)
    merge_mask = (B > 0.7).float()
    closeness = torch.exp(-(1 - B * merge_mask) ** 2 / 0.025)
    weights = closeness * scores.reshape((1, len(scores)))
    bboxes = torch.mm(weights, bboxes).float() / weights.sum(1, keepdim=True)

    # Always drop the suppressed boxes. Previously this only happened when
    # max_num > 0, so the default max_num=-1 returned every box.
    scores = scores[keep]
    labels = labels[keep]
    bboxes = bboxes[keep]
    inds = inds[keep]

    # Only keep the top max_num highest scores across all classes
    if max_num > 0:
        scores = scores[:max_num]
        labels = labels[:max_num]
        bboxes = bboxes[:max_num]
        # Keep the returned indices aligned with dets/labels.
        inds = inds[:max_num]
    dets = torch.cat([bboxes, scores[:, None]], dim=1)

    if return_inds:
        return dets, labels, inds
    else:
        return dets, labels

In [None]:
# Install the package found in the current directory in editable (development) mode.
!pip install -e .

# **Import Libraries**

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import sklearn
import torchvision
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import cupy as cp
import gc
import pandas as pd
import os
import matplotlib.pyplot as plt
import PIL
import json
from PIL import Image, ImageEnhance
import albumentations as A
import mmdet
import mmcv
from albumentations.pytorch import ToTensorV2
import seaborn as sns
import glob
from pathlib import Path
import pycocotools
from pycocotools import mask
import numpy.random
import random
import cv2
import re
import shutil
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
from mmdet.apis import inference_detector, init_detector, show_result_pyplot, set_random_seed

In [None]:
%cd ..

# **Helper Functions**

In [None]:
# Image width and height in pixels.
# NOTE(review): not referenced by any cell visible here — presumably the
# dataset image size; confirm before relying on it.
IMG_WIDTH = 704
IMG_HEIGHT = 520

In [None]:
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length encoded mask.

    mask_rle: run-length string formatted as "start length start length ..."
              with 1-based start positions into the flattened mask
    shape: (height, width) of the array to return
    Returns a uint8 numpy array of that shape, 1 - mask, 0 - background
    '''
    tokens = mask_rle.split()
    # Even tokens are 1-based starts (shifted to 0-based), odd tokens are lengths.
    begins = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(begins, begins + run_lengths):
        flat[begin:stop] = 1
    return flat.reshape(shape)

def rle_encode(img):
    '''
    Run-length encode a mask.

    img: numpy array, 1 - mask, 0 - background
    Returns the run lengths as a string formatted "start length start length ..."
    with 1-based start positions into the flattened array
    '''
    # Pad with a zero on both sides so runs touching the borders are closed.
    padded = np.concatenate([[0], img.flatten(), [0]])
    # 1-based positions where the value changes: alternating run starts / ends.
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every run end into a run length.
    changes[1::2] -= changes[::2]
    return ' '.join(map(str, changes))

In [None]:
def rle_encoding(x):
    """Run-length encode the pixels of ``x`` that are equal to 1.

    Returns a string "start length start length ..." with 1-based start
    positions into the flattened array; empty when no pixel equals 1.
    """
    positions = np.where(x.flatten() == 1)[0]
    encoded = []
    last = -2
    for pos in positions:
        if pos > last + 1:
            # A gap since the previous pixel: open a new run (1-based start, length 0).
            encoded.append(pos + 1)
            encoded.append(0)
        # Extend the run that is currently open.
        encoded[-1] += 1
        last = pos
    return ' '.join(str(value) for value in encoded)

In [None]:
def get_mask_from_result(result):
    """Map a True/False mask to a CuPy array of 1/0 values with the same shape.

    Values other than True/False in ``result`` raise ``KeyError``.
    """
    bool_to_int = {True : 1, False : 0}
    # `inverse` gives, for every element of `result`, its position in
    # `unique_values`.
    unique_values, inverse = np.unique(result, return_inverse = True)
    lookup = cp.array([bool_to_int[value] for value in unique_values])
    return lookup[inverse].reshape(result.shape)

In [None]:
def does_overlap(mask, other_masks):
    """Return True if ``mask`` shares at least one set pixel with any of ``other_masks``."""
    return any(
        np.sum(np.logical_and(mask, candidate)) > 0
        for candidate in other_masks
    )


def remove_overlapping_pixels(mask, other_masks):
    """Clear, in place, every pixel of ``mask`` that is also set in one of ``other_masks``.

    Prints "Overlap detected" for each mask that overlaps and returns the
    (modified) ``mask`` object itself.
    """
    for existing in other_masks:
        shared = np.logical_and(mask, existing)
        if np.sum(shared) > 0:
            print("Overlap detected")
            mask[shared] = 0
    return mask

In [None]:
def get_img_and_mask(img_path, annotation, width, height):
    """ Capture the relevant image array as well as the image mask

    Args:
        img_path (str): path of the image file to read
        annotation (iterable of str): one run-length encoded mask per instance
        width (int): mask width in pixels
        height (int): mask height in pixels

    Returns:
        tuple: (first channel of the image after reversing the channel order,
            instance mask of shape (height, width) where 0 is background and
            the k-th annotation is labelled k, counting from 1)

    Raises:
        FileNotFoundError: if the image cannot be read
    """
    img_mask = np.zeros((height, width), dtype=np.uint8)
    # Labels start at 1: with a 0-based index the first instance would be
    # written as 0 and become indistinguishable from the background.
    for label, annot in enumerate(annotation, start=1):
        img_mask = np.where(rle_decode(annot, (height, width))!=0, label, img_mask)

    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = img[..., ::-1]
    return img[..., 0], img_mask

def plot_img_and_mask(img, mask, invert_img=True, boost_contrast=True):
    """ Function to take an image and the corresponding mask and plot

    Draws one figure with three panels side by side: the (optionally inverted
    and contrast-boosted) image, the instance mask, and a blend of both.

    Args:
        img (np.arr): 1 channel np arr representing the image of cellular structures
        mask (np.arr): 1 channel np arr representing the instance masks (incrementing by one)
        invert_img (bool, optional): Whether or not to invert the base image
        boost_contrast (bool, optional): Whether or not to boost contrast of the base image
        
    Returns:
        None; Plots the two arrays and overlays them to create a merged image
    """
    plt.figure(figsize=(20,10))
    
    # Panel 1: the base image.
    plt.subplot(1,3,1)
    # Repeat the single channel three times to get a 3-channel image.
    _img = np.tile(np.expand_dims(img, axis=-1), 3)
    
    # Flip black-->white ... white-->black
    if invert_img:
        _img = _img.max()-_img
        
    # Enhance the contrast by a factor of 16 via PIL.
    # NOTE(review): presumably requires a uint8 image — confirm against callers.
    if boost_contrast:
        _img = np.asarray(ImageEnhance.Contrast(Image.fromarray(_img)).enhance(16))
        
    plt.imshow(_img)
    plt.axis(False)
    plt.title("Cell Image", fontweight="bold")
    
    # Panel 2: the instance mask with a rainbow colour map.
    plt.subplot(1,3,2)
    # 3-channel copy of the mask with the labels in the first channel only;
    # it is used for the overlay below, not for this panel.
    _mask = np.zeros_like(_img)
    _mask[..., 0] = mask
    plt.imshow(mask, cmap='rainbow')
    plt.axis(False)
    plt.title("Instance Segmentation Mask", fontweight="bold")
    
    # Panel 3: 75% image + 25% mask; clipping to [0, 1] and scaling by 255
    # turns every non-zero label into full intensity in the first channel.
    merged = cv2.addWeighted(_img, 0.75, np.clip(_mask, 0, 1)*255, 0.25, 0.0,)
    plt.subplot(1,3,3)
    plt.imshow(merged)
    plt.axis(False)
    plt.title("Cell Image w/ Instance Segmentation Mask Overlay", fontweight="bold")
    
    plt.tight_layout()
    plt.show()

# **Model**

In [None]:
from mmcv import Config
# Load the base config from the patched mmdetection checkout. The r50 Cascade
# Mask R-CNN config is active; the x101_64x4d variant below is kept as a
# commented-out alternative.
# cfg = Config.fromfile('/kaggle/working/mmdetection/configs/cascade_rcnn/cascade_mask_rcnn_x101_64x4d_fpn_20e_coco.py')
cfg = Config.fromfile('/kaggle/working/mmdetection/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_20e_coco.py')

In [None]:
# Set the number of classes to 3 on every bbox head and on the mask head.
for head in cfg.model.roi_head.bbox_head:
    head.num_classes = 3
    
cfg.model.roi_head.mask_head.num_classes=3

# Test-time pipeline: a single scale of (1333, 800) with aspect ratio kept and
# flip augmentation switched off (flip=False).
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[128, 128, 128],
                std=[11.58, 11.58, 11.58],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

# Earlier alternative that patched only the Normalize step of the existing
# test pipeline; superseded by the full pipeline assignment below.
# cfg.data.test.pipeline[1].transforms[2] = dict(
#                                             type='Normalize',
#                                             mean=[128, 128, 128],
#                                             std=[11.58, 11.58, 11.58],
#                                             to_rgb=True)

cfg.data.test.pipeline = cfg.test_pipeline

# Allow up to 300 detections per image; this value reaches the NMS as max_num
# — NOTE(review): presumably via the bbox head's test config; confirm.
cfg.model.test_cfg.rcnn.max_per_img = 300

# cfg.load_from = '../input/cascade-mask-rcnn-mmdet/cascade_mask_rcnn_x101_64x4d_fpn_20e_coco_20200512_161033-bdb5126a.pth'

cfg.work_dir = '/kaggle/working/model_output'

cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 2

# Same normalisation statistics as the Normalize step of the test pipeline.
cfg.img_norm_cfg = dict(  
    mean=[128, 128, 128],  
    std=[11.58, 11.58, 11.58],  
    to_rgb=True)

# Fix the random seed and use a single GPU id (0).
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# fp16 settings with a loss scale of 512.
# NOTE(review): a loss scale is a training concept — confirm this entry is
# wanted for an inference-only notebook.
cfg.fp16 = dict(loss_scale=512.0)
# `meta` is not referenced again by any cell visible here.
meta = dict()
meta['config'] = cfg.pretty_text

print(f'Config:\n{cfg.pretty_text}')

# **Inference**

In [None]:
# Minimum confidence a detection needs to be kept, keyed by class index
# (the position of the class in the detector's per-class result list).
confidence_thresholds = {0: 0.25, 1: 0.55, 2: 0.35}

In [None]:
# Parallel accumulators filled by the inference loop: one RLE string per
# predicted mask and the id of the image it belongs to.
segms = []
files = []

In [None]:
# Build the detector from the config and the fine-tuned checkpoint.
model = init_detector(cfg, '../input/mmdetection-neuron-training/finetune_output/epoch_3.pth')

for file in sorted(os.listdir('../input/sartorius-cell-instance-segmentation/test')):
    img = mmcv.imread('../input/sartorius-cell-instance-segmentation/test/' + file)
    result = inference_detector(model, img)
    show_result_pyplot(model, img, result)

    # Collect the confident masks of this image, clearing from every new mask
    # the pixels already claimed by an earlier one.
    previous_masks = []
    for class_idx, class_dets in enumerate(result[0]):
        # Skip classes without any detection.
        if class_dets.shape == (0, 5):
            continue
        class_segms = result[1][class_idx]
        for det, segm in zip(class_dets, class_segms):
            # Column 4 of a detection row is its confidence.
            confidence = det[4]
            if confidence >= confidence_thresholds[class_idx]:
                instance_mask = get_mask_from_result(segm)
                instance_mask = remove_overlapping_pixels(instance_mask, previous_masks)
                previous_masks.append(instance_mask)

    # One submission row per mask: its RLE string and the file name up to the
    # first dot as image id.
    encoded_masks = [rle_encoding(kept_mask) for kept_mask in previous_masks]
    segms.extend(encoded_masks)
    files.extend(str(file.split('.')[0]) for _ in encoded_masks)

In [None]:
# Positions of masks whose RLE string is empty (no pixel left in the mask).
indexes = [position for position, encoded in enumerate(segms) if encoded == '']

In [None]:
# Remove the empty masks from both lists, highest position first so the
# remaining positions stay valid while deleting.
for position in sorted(indexes, reverse=True):
    segms.pop(position)
    files.pop(position)

In [None]:
# Wrap both lists as named Series; the names become the submission columns.
files = pd.Series(files, name='id')
preds = pd.Series(segms, name='predicted')

In [None]:
preds

In [None]:
# Put the two Series side by side as the columns 'id' and 'predicted'.
submission_df = pd.concat([files, preds], axis=1)

In [None]:
# Write the submission file to the current directory without the index column.
submission_df.to_csv('submission.csv', index=False)

In [None]:
submission_df

In [None]:
# Delete the mmdetection checkout from the working directory.
# NOTE(review): presumably to keep it out of the notebook output — confirm.
shutil.rmtree('/kaggle/working/mmdetection')