# **Previous Notebooks**

https://www.kaggle.com/vexxingbanana/sartorius-coco-dataset-notebook

https://www.kaggle.com/vexxingbanana/sartorius-mmdetection-training

# **References**

https://www.kaggle.com/dschettler8845/sartorius-segmentation-eda-and-baseline

https://www.kaggle.com/ihelon/cell-segmentation-run-length-decoding

https://www.kaggle.com/stainsby/fast-tested-rle

https://www.kaggle.com/paulorzp/run-length-encode-and-decode

https://www.kaggle.com/awsaf49/sartorius-mmdetection-infer

https://www.kaggle.com/awsaf49/sartorius-mmdetection-train

https://www.kaggle.com/evancofsky/sartorius-torch-lightning-mask-r-cnn/notebook

# **Install MMDetection**

In [None]:
# Install PyTorch 1.7.0 (CUDA 11.0 build, CPython 3.7), torchvision 0.8.1 and
# torchaudio 0.7.0 from wheel files stored under /kaggle/input.
# --no-deps stops pip from resolving or fetching their dependencies.
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torch-1.7.0+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torchvision-0.8.1+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torchaudio-0.7.0-cp37-cp37m-linux_x86_64.whl' --no-deps

In [None]:
# MMDetection's supporting packages, installed from local wheels / source
# trees under /kaggle/input (again without dependency resolution).
!pip install '/kaggle/input/mmdetectionv2140/addict-2.4.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/yapf-0.31.0-py2.py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/terminal-0.4.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/terminaltables-3.1.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/mmcv_full-1_3_8-cu110-torch1_7_0/mmcv_full-1.3.8-cp37-cp37m-manylinux1_x86_64.whl' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/pycocotools-2.0.2/pycocotools-2.0.2' --no-deps
!pip install '/kaggle/input/mmdetectionv2140/mmpycocotools-12.0.3/mmpycocotools-12.0.3' --no-deps

# Remove any `mmdetection` directory left in the current directory.
!rm -rf mmdetection

# Copy the MMDetection 2.14.0 sources to /kaggle/working, rename the copy to
# `mmdetection`, and install it in editable mode from inside that directory.
!cp -r /kaggle/input/mmdetectionv2140/mmdetection-2.14.0 /kaggle/working/
!mv /kaggle/working/mmdetection-2.14.0 /kaggle/working/mmdetection
%cd /kaggle/working/mmdetection
!pip install -e .

# **Import Libraries**

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import sklearn
import torchvision
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import cupy as cp
import gc
import pandas as pd
import os
import matplotlib.pyplot as plt
import PIL
import json
from PIL import Image, ImageEnhance
import albumentations as A
import mmdet
import mmcv
from albumentations.pytorch import ToTensorV2
import seaborn as sns
import glob
from pathlib import Path
import pycocotools
from pycocotools import mask
import numpy.random
import random
import cv2
import re
import shutil
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
from mmdet.apis import inference_detector, init_detector, show_result_pyplot, set_random_seed

In [None]:
%cd ..

# **Helper Functions**

In [None]:
# Image size in pixels. The mask docstrings in this notebook describe arrays
# of shape (520, 704), i.e. (IMG_HEIGHT, IMG_WIDTH).
IMG_WIDTH = 704
IMG_HEIGHT = 520

In [None]:
import cupy as cp
import gc

def one_hot(y, num_classes, dtype=cp.uint8): # GPU
    """One-hot encode integer labels on the GPU.

    Args:
        y: array-like of non-negative integer class indices; it is copied
            into a CuPy integer array.
        num_classes: length of the one-hot axis. When falsy (``None`` or 0)
            it is inferred as ``max(y) + 1``.
        dtype: dtype of the returned array.

    Returns:
        CuPy array of shape ``y.shape + (num_classes,)``. If ``y`` has more
        than one axis and its last axis has length 1, that axis is dropped
        before the one-hot axis is appended.
    """
    labels = cp.array(y, dtype='int')
    input_shape = labels.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    labels = labels.ravel()
    if not num_classes:
        # cp.max returns a 0-d device array; array shapes need a host int.
        num_classes = int(cp.max(labels)) + 1
    n = labels.shape[0]
    categorical = cp.zeros((n, num_classes), dtype=dtype)
    # One row per label: set the column selected by that label.
    categorical[cp.arange(n), labels] = 1
    output_shape = input_shape + (num_classes,)
    return cp.reshape(categorical, output_shape)

def fix_overlap(msk): # GPU
    """Resolve pixel overlaps between the instance channels of a mask stack.

    Args:
        msk: multi-channel mask, each channel is one cell instance,
            shape:(520,704,None)
    Returns:
        CuPy multi-channel mask in which every pixel is set in at most one
        channel, shape:(520,704,None). Channels that end up empty are
        dropped, so the channel count may shrink.
    """
    stack = cp.array(msk)
    # Prepend an all-zero channel at index 0 that stands for background.
    stack = cp.pad(stack, [[0, 0], [0, 0], [1, 0]])
    n_channels = stack.shape[-1]
    # argmax keeps a single channel index per pixel, which removes overlap.
    label_map = cp.argmax(stack, axis=-1)
    # Back to one channel per instance; some instances may have lost all pixels.
    separated = one_hot(label_map, num_classes=n_channels)
    foreground = separated[..., 1:]  # drop the background channel
    non_empty = cp.any(foreground, axis=(0, 1))
    return foreground[..., non_empty]

def check_overlap(msk):
    """Report whether any pixel is set in more than one channel.

    Args:
        msk: array with one instance per entry of the last axis; it must
            provide ``astype`` (e.g. a CuPy array).

    Returns:
        0-d CuPy boolean array, true when at least one pixel is non-zero in
        two or more channels.
    """
    # Builtin `bool` replaces the deprecated `cp.bool` alias.
    binary = msk.astype(bool).astype(cp.uint8)  # binary mask
    # Without overlap, at most one channel contributes to each pixel's sum.
    return cp.any(cp.sum(binary, axis=-1) > 1)

In [None]:
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length string into a binary mask.

    mask_rle: space-separated "start length" pairs; starts are 1-based
              positions in the flattened (row-major) mask
    shape: (height, width) of the array to return
    Returns a uint8 numpy array, 1 - mask, 0 - background
    '''
    tokens = mask_rle.split()
    # Even positions hold starts (shifted to 0-based), odd positions lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape)

def rle_encode(img):
    '''
    Encode a binary mask as a run-length string.

    img: numpy array, 1 - mask, 0 - background
    Returns space-separated "start length" pairs with 1-based starts over
    the flattened mask ('' when the mask is empty)
    '''
    # Zero padding on both ends makes every run begin and end on a value change.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate run start / run end; turn each end into a length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

In [None]:
def rle_encoding(x):
    """Run-length encode the pixels of ``x`` that equal 1.

    Returns space-separated "start length" pairs with 1-based starts over
    the flattened array; the string is empty when no pixel equals 1.
    """
    foreground = np.where(x.flatten() == 1)[0]
    runs = []
    last = -2  # sentinel: the first foreground pixel always opens a run
    for pos in foreground:
        if pos > last + 1:
            # Gap since the previous pixel: open a new run at this position.
            runs.append(pos + 1)
            runs.append(0)
        runs[-1] += 1
        last = pos
    return ' '.join(str(value) for value in runs)

In [None]:
def get_mask_from_result(result):
    """Convert a boolean mask into a CuPy array of 0/1 integers.

    The distinct values found in ``result`` are mapped through a
    True -> 1 / False -> 0 table, and the output keeps ``result``'s shape.
    """
    bool_to_int = {True: 1, False: 0}
    distinct, inverse = np.unique(result, return_inverse=True)
    lookup = cp.array([bool_to_int[value] for value in distinct])
    # `inverse` indexes into `distinct`, so this rebuilds the array element-wise.
    return lookup[inverse].reshape(result.shape)

In [None]:
def does_overlap(mask, other_masks):
    """Return True if ``mask`` shares at least one set pixel with any mask in ``other_masks``."""
    return any(
        np.sum(np.logical_and(mask, candidate)) > 0
        for candidate in other_masks
    )


def remove_overlapping_pixels(mask, other_masks):
    """Zero out, in place, the pixels of ``mask`` that are also set in any of ``other_masks``.

    Prints "Overlap detected" once for every other mask that overlaps, and
    returns the same (mutated) ``mask`` object.
    """
    for existing in other_masks:
        shared = np.logical_and(mask, existing)
        if np.sum(shared) > 0:
            print("Overlap detected")
            mask[shared] = 0
    return mask

In [None]:
def get_img_and_mask(img_path, annotation, width, height):
    """Load an image and build its instance-label mask.

    Args:
        img_path: path of the image file, read with ``cv2.imread``.
        annotation: iterable of RLE strings, one per instance.
        width: mask width in pixels.
        height: mask height in pixels.

    Returns:
        Tuple ``(img, img_mask)``: ``img`` is channel 0 of the image after
        the channel order is reversed, and ``img_mask`` is a
        ``(height, width)`` int32 array holding 0 for background and
        1, 2, 3, ... for the instances in annotation order. Where
        instances overlap, the later one wins.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
    """
    # int32 rather than uint8: an image may contain more than 255 instances.
    img_mask = np.zeros((height, width), dtype=np.int32)
    # Labels start at 1 so the first instance is not confused with background (0).
    for label, annot in enumerate(annotation, start=1):
        img_mask[rle_decode(annot, (height, width)) != 0] = label
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = img[..., ::-1]
    return img[..., 0], img_mask

def plot_img_and_mask(img, mask, invert_img=True, boost_contrast=True):
    """ Plot an image, its instance mask, and an overlay of the two in one row

    Args:
        img (np.arr): 1 channel np arr representing the image of cellular structures
        mask (np.arr): 1 channel np arr representing the instance masks (incrementing by one)
        invert_img (bool, optional): Whether or not to invert the base image
        boost_contrast (bool, optional): Whether or not to boost contrast of the base image

    Returns:
        None; Plots the two arrays and overlays them to create a merged image
    """
    plt.figure(figsize=(20,10))

    # Panel 1: the base image, replicated into 3 identical channels.
    plt.subplot(1,3,1)
    _img = np.tile(np.expand_dims(img, axis=-1), 3)

    # Flip black-->white ... white-->black
    if invert_img:
        _img = _img.max()-_img

    if boost_contrast:
        # PIL contrast enhancement with factor 16, converted back to an array.
        _img = np.asarray(ImageEnhance.Contrast(Image.fromarray(_img)).enhance(16))

    plt.imshow(_img)
    plt.axis(False)
    plt.title("Cell Image", fontweight="bold")

    # Panel 2: the instance mask itself, colour-mapped by label value.
    plt.subplot(1,3,2)
    _mask = np.zeros_like(_img)
    _mask[..., 0] = mask  # mask fills channel 0 only; the other channels stay zero
    plt.imshow(mask, cmap='rainbow')
    plt.axis(False)
    plt.title("Instance Segmentation Mask", fontweight="bold")

    # Panel 3: 75% image blended with 25% of the mask channel binarised to 0/255.
    merged = cv2.addWeighted(_img, 0.75, np.clip(_mask, 0, 1)*255, 0.25, 0.0,)
    plt.subplot(1,3,3)
    plt.imshow(merged)
    plt.axis(False)
    plt.title("Cell Image w/ Instance Segmentation Mask Overlay", fontweight="bold")

    plt.tight_layout()
    plt.show()

# **Model**

In [None]:
from mmcv import Config
# Start from the Cascade Mask R-CNN config shipped with the copied MMDetection
# sources (file name: x101_64x4d backbone, FPN, 20e COCO schedule).
cfg = Config.fromfile('/kaggle/working/mmdetection/configs/cascade_rcnn/cascade_mask_rcnn_x101_64x4d_fpn_20e_coco.py')

In [None]:
# Dataset type, class-name file and root directory.
cfg.dataset_type = 'CocoDataset'
cfg.classes = '/kaggle/working/labels.txt'
cfg.data_root = '/kaggle/working'

# Every bbox head listed in the ROI head predicts 3 classes.
for head in cfg.model.roi_head.bbox_head:
    head.num_classes = 3
    
# for head in cfg.model.roi_head.mask_head:
#     head.num_classes = 3
    
# cfg.model.roi_head.mask_head.semantic_head.num_classes=3
# The mask head is configured as a single object here, not iterated like the bbox heads.
cfg.model.roi_head.mask_head.num_classes=3

# The test split points at the same annotation file as the val split below.
cfg.data.test.type = 'CocoDataset'
cfg.data.test.classes = 'labels.txt'
cfg.data.test.data_root = '/kaggle/working'
cfg.data.test.ann_file = '../input/k/vexxingbanana/sartorius-coco-dataset-notebook/val_dataset.json'
cfg.data.test.img_prefix = ''

cfg.data.train.type = 'CocoDataset'
cfg.data.train.data_root = '/kaggle/working'
cfg.data.train.ann_file = '../input/k/vexxingbanana/sartorius-coco-dataset-notebook/train_dataset.json'
cfg.data.train.img_prefix = ''
cfg.data.train.classes = 'labels.txt'

cfg.data.val.type = 'CocoDataset'
cfg.data.val.data_root = '/kaggle/working'
cfg.data.val.ann_file = '../input/k/vexxingbanana/sartorius-coco-dataset-notebook/val_dataset.json'
cfg.data.val.img_prefix = ''
cfg.data.val.classes = 'labels.txt'

# Albumentations-style augmentation specs: random shift/scale/rotate, random
# brightness/contrast, and one of Gaussian or median blur.
# NOTE(review): in this notebook the only reference to this list is inside the
# commented-out 'Albu' step of cfg.train_pipeline, so it is currently unused.
albu_train_transforms = [
    dict(type='ShiftScaleRotate', shift_limit=0.0625,
         scale_limit=0.15, rotate_limit=15, p=0.4),
    dict(type='RandomBrightnessContrast', brightness_limit=0.2,
         contrast_limit=0.2, p=0.5),
#     dict(type='IAAAffine', shear=(-10.0, 10.0), p=0.4),
#     dict(type='CLAHE', p=0.5),
    dict(
        type="OneOf",
        transforms=[
            dict(type="GaussianBlur", p=1.0, blur_limit=7),
            dict(type="MedianBlur", p=1.0, blur_limit=7),
        ],
        p=0.4,
    ),
]

# Training pipeline: load image and bbox/mask annotations, resize to the
# (1333, 800) scale, random horizontal/vertical flip per RandomFlip defaults
# with ratio 0.5, normalize, pad to a multiple of 32 and collect tensors.
# The multi-scale Resize variants and the Albu step are left commented out.
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
#     dict(type='Resize', img_scale=[(440, 596), (480, 650), (520, 704), (580, 785), (620, 839)], multiscale_mode='value', keep_ratio=True),
#     dict(type='Resize', img_scale=[(880, 1192), (960, 130), (1040, 1408), (1160, 1570), (1240, 1678)], multiscale_mode='value', keep_ratio=True),
    dict(type='Resize', img_scale=(1333, 800)),
    

    dict(type='RandomFlip', flip_ratio=0.5),

#     dict(
#         type='Albu',
#         transforms=albu_train_transforms,
#         bbox_params=dict(
#         type='BboxParams',
#         format='pascal_voc',
#         label_fields=['gt_labels'],
#         min_visibility=0.0,
#         filter_lost_elements=True),
#         keymap=dict(img='image', gt_bboxes='bboxes', gt_masks='masks'),
#         update_pad_shape=False,
#         skip_img_without_anno=True),
    # Same mean/std on all three channels.
    dict(
        type='Normalize',
        mean=[128, 128, 128],
        std=[11.58, 11.58, 11.58],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'), 
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_masks', 'gt_labels'])
]

# Validation pipeline: single scale (1333, 800), no flip augmentation, and the
# same Normalize mean/std values as cfg.train_pipeline.
cfg.val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
#         img_scale=[(880, 1192), (960, 130), (1040, 1408), (1160, 1570), (1240, 1678)],
        img_scale = (1333, 800),
#         img_scale = (520, 704),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[128, 128, 128],
                std=[11.58, 11.58, 11.58],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]


# Test pipeline: same steps and values as cfg.val_pipeline.
# NOTE(review): the line that would assign this to cfg.data.test.pipeline is
# commented out further down, so cfg.data.test keeps whatever pipeline the
# base config file defines - confirm that is intended.
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
#         img_scale = (520, 704),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[128, 128, 128],
                std=[11.58, 11.58, 11.58],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

# Attach the custom pipelines to the train and val splits.
# NOTE(review): the test split is NOT overridden (next line is commented out),
# so cfg.data.test.pipeline still comes from the base config, whose Normalize
# values may differ from the mean=128 / std=11.58 used here - confirm intended.
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.val.pipeline = cfg.val_pipeline
# cfg.data.test.pipeline = cfg.test_pipeline

# Checkpoint to initialise the weights from.
# cfg.load_from = '../input/htc-checkpoint-resnext101/htc_x101_64x4d_fpn_dconv_c3-c5_mstrain_400_1400_16x1_20e_coco_20200312-946fd751.pth'
cfg.load_from = '../input/cascade-mask-rcnn-mmdet/cascade_mask_rcnn_x101_64x4d_fpn_20e_coco_20200512_161033-bdb5126a.pth'

# Output directory for logs and checkpoints.
cfg.work_dir = '/kaggle/working/model_output'

# Learning rate 0.02 / 8 - presumably scaled down for a single GPU; TODO confirm.
cfg.optimizer.lr = 0.02 / 8
# Cosine annealing stepped per iteration (by_epoch=False) after a 125-iteration
# linear warmup.
cfg.lr_config = dict(
    policy='CosineAnnealing', 
    by_epoch=False,
    warmup='linear', 
    warmup_iters=125, 
    warmup_ratio=0.001,
    min_lr=1e-07)

cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 2

# Evaluate the segmentation metric and save a checkpoint every epoch.
cfg.evaluation.metric = 'segm'
cfg.evaluation.interval = 1

cfg.checkpoint_config.interval = 1
cfg.runner.max_epochs = 12
cfg.log_config.interval = 100

# cfg.model.rpn_head.anchor_generator.base_sizes = [4, 9, 17, 31, 64]
# cfg.model.rpn_head.anchor_generator.strides = [4, 8, 16, 32, 64]


# Fixed seed 0, one GPU id (0), and an fp16 section with loss scale 512.
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
cfg.fp16 = dict(loss_scale=512.0)
# Keep the rendered config text in a metadata dict.
meta = dict()
meta['config'] = cfg.pretty_text



print(f'Config:\n{cfg.pretty_text}')

# **Inference**

In [None]:
# Minimum bbox confidence score for keeping a prediction, keyed by the class
# index used to enumerate result[0] in the inference loop.
confidence_thresholds = {0: 0.25, 1: 0.55, 2: 0.35}

In [None]:
# Parallel lists filled by the inference loop: one RLE string per kept mask
# (segms) and the id of the image that mask belongs to (files).
segms = []
files = []

In [None]:
# model = init_detector(cfg, '../input/mmdetection-neuron-training/model_output/epoch_5.pth')
# for file in sorted(os.listdir('../input/sartorius-cell-instance-segmentation/test')):
#     img = mmcv.imread('../input/sartorius-cell-instance-segmentation/test/' + file)
#     result = inference_detector(model, img)
#     show_result_pyplot(model, img, result)
#     previous_masks = []
#     for i, bboxes in enumerate(result[0]):
#         if bboxes.shape != (0,5):
#             segmentations = result[1][i]
#             for bbox, segm in zip(bboxes, segmentations):
#                 box = bbox[:4]
#                 confidence = bbox[-1]
#                 if confidence > confidence_thresholds[i]:
#                     mask = get_mask_from_result(segm)
# #                     mask = remove_overlapping_pixels(mask, previous_masks)
#                     previous_masks.append(cp.array(mask))
# #     plt.imshow(previous_masks)
#     masks = np.stack(previous_masks, axis=-1)
#     masks = fix_overlap(masks)
#     for mk in masks:
#         rle_mask = rle_encoding(mk)
#         segms.append(rle_mask)
#         files.append(str(file.split('.')[0]))

In [None]:
# for i, mask in enumerate(previous_masks):
#     temp_prev = []
#     for j in range(len(previous_masks)):
#         if j != i:
#             previous_masks[j]
#     does_overlap(mask, temp_prev)

In [None]:
# Load the trained detector once, then predict every test image in sorted name order.
model = init_detector(cfg, '../input/sartorious-cascade-rcnn/best_segm_mAP_epoch_6.pth')
for file in sorted(os.listdir('../input/sartorius-cell-instance-segmentation/test')):
    img = mmcv.imread('../input/sartorius-cell-instance-segmentation/test/' + file)
    result = inference_detector(model, img)
    show_result_pyplot(model, img, result)

    # Masks kept so far for this image; each new mask loses the pixels
    # that an earlier kept mask already claims.
    previous_masks = []
    for class_idx, class_boxes in enumerate(result[0]):
        if class_boxes.shape == (0, 5):
            continue  # no detections for this class
        class_segms = result[1][class_idx]
        for scored_box, segm_mask in zip(class_boxes, class_segms):
            score = scored_box[4]
            if score >= confidence_thresholds[class_idx]:
                mask = remove_overlapping_pixels(
                    get_mask_from_result(segm_mask), previous_masks)
                previous_masks.append(mask)

    # One submission row per kept mask, all tagged with this image's id.
    image_id = str(file.split('.')[0])
    for kept_mask in previous_masks:
        segms.append(rle_encoding(kept_mask))
        files.append(image_id)

In [None]:
# Positions of masks whose RLE string is empty (no pixels survived).
indexes = [position for position, encoded in enumerate(segms) if encoded == '']

In [None]:
# Delete from the highest position down so earlier deletions do not shift
# the positions still waiting to be removed.
for position in sorted(indexes)[::-1]:
    segms.pop(position)
    files.pop(position)

In [None]:
# Wrap the parallel lists as named Series; the names 'id' and 'predicted'
# become the column labels once the two are concatenated.
files = pd.Series(files, name='id')
preds = pd.Series(segms, name='predicted')

In [None]:
preds

In [None]:
# Put the two Series side by side as columns: one row per predicted mask.
submission_df = pd.concat([files, preds], axis=1)

In [None]:
# Write the submission to the current directory, without the index column.
submission_df.to_csv('submission.csv', index=False)

In [None]:
submission_df

In [None]:
# lines = []
# for f in submission_df.itertuples():
#     lines.append('../input/sartorius-cell-instance-segmentation/test/' + f[1] + '.png')
# lins = pd.Series(lines, name='img_path')
# check_df = pd.concat([submission_df, lins], axis=1)

In [None]:
# tmp_df = check_df.drop_duplicates(subset=["id"]).reset_index(drop=True)
# tmp_df["predicted"] = check_df.groupby("id")["predicted"].agg(list).reset_index(drop=True)
# check_df = tmp_df.copy()

In [None]:
# check_df

In [None]:
# for f in check_df.itertuples():
#     im, mk = get_img_and_mask(f[3], f[2], IMG_WIDTH, IMG_HEIGHT)
#     plot_img_and_mask(im, mk)

In [None]:
# Delete the MMDetection source copy from the working directory
# (presumably so it is not kept with the notebook output - TODO confirm).
shutil.rmtree('/kaggle/working/mmdetection')