# Instance Segmentation Fix_Overlap Before/After

This notebook builds on the following notebooks:<br/>
https://www.kaggle.com/arunamenon/cell-instance-segmentation-unet-eda<br/>
https://www.kaggle.com/karan23258/cell-instance-segmentation-unetfromscratch<br/>
https://www.kaggle.com/evangelou/sartorius-unet-pytorch-from-scratch<br/>
https://www.kaggle.com/awsaf49/sartorius-fix-overlap

# Import packages

In [None]:
import numpy as np
import pandas as pd
import imageio
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from tqdm import tqdm
import random
import tensorflow as tf

# Load input files

In [None]:
# Training annotations; the columns used below are id, cell_type and
# annotation (a run-length-encoded mask string).
train_data = pd.read_csv("../input/sartorius-cell-instance-segmentation/train.csv")

In [None]:
# (rows, columns) of the annotation table, then a preview of its first rows.
print(train_data.shape)
train_data.head(n=5)

In [None]:
cell_types = train_data['cell_type'].unique().tolist()
print(cell_types)

In [None]:
# First image id of each cell type; these 3 ids are used as
# representatives of the train data.
for cell_type in ('shsy5y', 'astro', 'cort'):
    ids_of_type = train_data.loc[train_data['cell_type'] == cell_type, 'id']
    print(ids_of_type.tolist()[0])

## rle_decode: decode mask_rle strings (the label data) into mask image data

In [None]:
# Reference: https://www.kaggle.com/ihelon/cell-segmentation-run-length-decoding

def rle_decode(mask_rle, shape, color=1):
    """
    Decode a run-length-encoded mask string into a numpy array.

    Args:
        mask_rle: space-separated "start length" pairs; starts are 1-based
            indices into the row-major flattened image.
        shape: (height, width) or (height, width, channels) of the array to
            return.
        color: value written to every pixel covered by a run; a scalar, or
            one value per channel (e.g. an RGB triple for 3 channels).
    Returns:
        float32 numpy array of ``shape``; pixels outside all runs are 0.
    """
    s = mask_rle.split()

    starts = [int(x) - 1 for x in s[0::2]]  # 1-based -> 0-based
    lengths = [int(x) for x in s[1::2]]
    ends = [start + length for start, length in zip(starts, lengths)]

    # Flatten only the two spatial dimensions so run indices address pixels;
    # any trailing channel dimension is filled as a whole by `color`. This
    # also makes a plain 2-D (height, width) shape work.
    img = np.zeros((shape[0] * shape[1],) + tuple(shape[2:]), dtype=np.float32)

    for start, end in zip(starts, ends):
        img[start:end] = color

    return img.reshape(shape)


## plot_masks : show original and mask images

In [None]:
def plot_masks(image_id, colors=True):
    """
    Show one training image, the image with its annotation mask overlaid,
    and the mask on its own, side by side.

    Args:
        image_id: value of the ``id`` column in ``train_data``; also the file
            stem of the PNG under
            ``../input/sartorius-cell-instance-segmentation/train/``.
        colors: if True each annotation is painted with a random RGB colour;
            otherwise a single-channel binary mask is shown.

    Raises:
        FileNotFoundError: if the image file cannot be read.
    """
    # Mask construction is identical to ret_mask_img (decode every annotation,
    # sum, clip to [0, 1]), so reuse it instead of duplicating the loop.
    mask = ret_mask_img(image_id, colors=colors)

    path = f"../input/sartorius-cell-instance-segmentation/train/{image_id}.png"
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None;
        # without this check cvtColor fails with an opaque OpenCV error.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

    plt.figure(figsize=(18, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Input image')
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(image)
    plt.imshow(mask, alpha=0.1)  # faint overlay keeps the cells visible
    plt.title('Input image with mask')
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(mask)
    plt.title('Only mask')
    plt.axis("off")

    plt.show()

### ret_mask_img creates mask image data

In [None]:
def ret_mask_img(image_id, colors=True):
    """
    Build the combined mask image for one training image.

    Every annotation of ``train_data`` whose ``id`` equals ``image_id`` is
    run-length decoded and added onto a 520x704 canvas, which is then
    clipped to [0, 1].

    Args:
        image_id: value of the ``id`` column to select.
        colors: if True, return a (520, 704, 3) array in which each
            annotation is painted with a random RGB colour; otherwise a
            (520, 704, 1) binary mask.

    Returns:
        numpy array with values in [0, 1].
    """
    channels = 3 if colors else 1
    canvas_shape = (520, 704, channels)
    rles = train_data.loc[train_data["id"] == image_id, "annotation"].tolist()

    canvas = np.zeros(canvas_shape)
    for rle in rles:
        if colors:
            canvas += rle_decode(rle, shape=canvas_shape, color=np.random.rand(3))
        else:
            canvas += rle_decode(rle, shape=canvas_shape)

    return canvas.clip(0, 1)

### show input and mask image

In [None]:
sample_ids = ['0030fd0e6378', '0140b3c8f445', '01ae5a43a2ab']

for sample_id in sample_ids:
    celltype = train_data[train_data['id'] == sample_id]['cell_type'].tolist()[0]
    # plot_masks loads the PNG itself, so no separate image read is needed
    # here (reading it again would only cost I/O; the result was unused).
    print('ID:', sample_id, ', CellType:', celltype)
    plot_masks(sample_id, colors=True)

# Label before fix_overlap

In [None]:
# All annotation RLE strings of each sample, joined into one string per
# sample. The pieces must be separated by a space: plain concatenation glues
# the last length of one string onto the first start of the next
# (e.g. "1 3" + "6 2" -> "1 36 2"), producing a corrupt RLE. The
# comprehension also avoids rebinding the imported skimage.morphology.label
# at the top level.
LABEL2 = [
    ' '.join(train_data[train_data["id"] == sample_id]["annotation"].tolist())
    for sample_id in sample_ids
]

print(LABEL2[0])

In [None]:
# One mask per sample image. NOTE: rle_decode is called with its default
# color=1 on a 3-channel shape, so all three channels of each entry are
# identical copies of the union of the image's annotations. They are NOT one
# channel per cell instance, so check_overlap/fix_overlap below operate on
# three identical "instances".
MASK = []
for sample_id in sample_ids:
    labels = train_data[train_data["id"] == sample_id]["annotation"].tolist()
    mask = np.zeros((520, 704, 3))
    # Loop variable is `rle`, not `label`, so this top-level cell does not
    # rebind skimage.morphology.label imported at the top of the notebook.
    for rle in labels:
        mask += rle_decode(rle, shape=(520, 704, 3))

    MASK += [mask.clip(0, 1)]

print(MASK[0].shape)
print()
print(MASK[0])

In [None]:
# MASK[0] is a clipped sum of 0/1 decodes, so only the values 0 and 1 occur.
fig, ax = plt.subplots(figsize=(12, 4))
flat_values = MASK[0].flatten()
sns.histplot(flat_values, ax=ax, bins=20, color='C1', label='MASK[0].flatten()')
ax.legend()
ax.grid()

### check_overlap 

In [None]:
def check_overlap(msk):
    """
    Report whether any pixel is set in more than one channel.

    Args:
        msk: array of shape (..., channels); any non-zero value counts as set.
    Returns:
        numpy bool: True if at least one pixel has two or more set channels.
    """
    # `np.bool` was a deprecated alias of the builtin `bool` and was removed
    # in NumPy 1.24 (AttributeError); the builtin is the correct dtype spec.
    is_set = np.asarray(msk).astype(bool)
    return np.any(is_set.sum(axis=-1) > 1)

In [None]:
# These report True for any image with annotated pixels, but NOT because
# cell instances overlap. Each MASK entry was built with rle_decode's default
# color=1 written into all three channels, so every foreground pixel is set
# in three identical channels, and check_overlap counts that as an overlap.
for union_mask in MASK[:3]:
    print(check_overlap(union_mask))

In [None]:
def fix_overlap(msk):
    """
    Assign every pixel to at most one instance channel.

    Args:
        msk: array-like multi-channel mask of shape (H, W, N); each channel is
            one cell instance. A pixel set in several channels goes to the
            channel with the largest value; ties go to the lowest channel
            index (np.argmax returns the first maximum).
    Returns:
        float32 array of shape (H, W, K), K <= N: a non-overlapping one-hot
        mask from which channels left with no pixels have been dropped.
    """
    msk = np.array(msk)
    # Prepend an all-zero "background" channel so argmax yields 0 for pixels
    # that belong to no instance.
    msk = np.pad(msk, [[0, 0], [0, 0], [1, 0]])
    ins_len = msk.shape[-1]
    winner = np.argmax(msk, axis=-1)
    # One-hot encode the winning index for channels 1..ins_len-1 (background
    # excluded). Equivalent to to_categorical(winner, ins_len)[..., 1:] but
    # needs only numpy, not TensorFlow/Keras.
    onehot = (winner[..., None] == np.arange(1, ins_len)).astype(np.float32)
    # Drop instances that lost all of their pixels to other channels.
    return onehot[..., np.any(onehot, axis=(0, 1))]

In [None]:
# Resolve overlaps in each of the three sample masks.
MASK0b, MASK1b, MASK2b = (fix_overlap(MASK[idx]) for idx in (0, 1, 2))

In [None]:
# fix_overlap returns a one-hot float mask, so only the values 0 and 1 occur.
fig, ax = plt.subplots(figsize=(12, 4))
flat_values = MASK0b.flatten()
sns.histplot(flat_values, ax=ax, bins=20, color='C1', label='MASK0b.flatten()')
ax.legend()
ax.grid()

In [None]:
# After fix_overlap every pixel is set in at most one channel, so each check
# is expected to print False. (For these masks the three identical input
# channels also collapse into a single channel.)
for fixed in (MASK0b, MASK1b, MASK2b):
    print(check_overlap(fixed))

In [None]:
# Channel count before vs. after fix_overlap.
for arr in (MASK[0], MASK0b):
    print(arr.shape)

In [None]:
# Render the original union mask, then its fix_overlap result.
for img, caption in ((MASK[0], 'original mask'), (MASK0b, 'fix_overlapped mask')):
    plt.imshow(img)
    plt.title(caption)
    plt.axis("off")
    plt.show()

## show mask image (before fix_overlapping)

In [None]:
for sample_id in sample_ids:
    celltype = train_data[train_data['id'] == sample_id]['cell_type'].tolist()[0]
    # Only the mask is plotted here, so the PNG itself is not read (a read
    # whose result was never used would only cost I/O).
    print('ID:', sample_id, ', CellType:', celltype)
    mask = ret_mask_img(sample_id, colors=True)
    print(mask.shape)

    plt.imshow(mask)
    plt.title('mask before fix_overlapping')
    plt.axis("off")
    plt.show()

## show mask image (after fix_overlapping)

In [None]:
# Display each fixed mask together with its shape.
for fixed_mask in (MASK0b, MASK1b, MASK2b):
    print(fixed_mask.shape)
    plt.imshow(fixed_mask)
    plt.title('mask after fix_overlapping')
    plt.axis("off")
    plt.show()

# rle_encoding: encode mask image data as RLE strings (the label data)

In [None]:
def rle_encoding(x):
    """
    Run-length encode the elements of ``x`` that are exactly equal to 1.

    The array is flattened in row-major (C) order and encoded as
    space-separated "start length" pairs with 1-based starts. This is the
    format rle_decode parses. For a multi-channel (H, W, C) array the
    flattening interleaves channels, so pass a single-channel mask.

    Returns:
        str such as "1 3 6 2"; an empty string when no element equals 1.
    """
    fg = (np.asarray(x).flatten() == 1).astype(np.int8)
    # Pad with 0 at both ends so every run has a rising and a falling edge;
    # non-zero differences mark those edges (vectorized instead of a Python
    # loop over every foreground pixel).
    edges = np.flatnonzero(np.diff(np.concatenate(([0], fg, [0]))))
    run_starts = edges[0::2]
    run_lengths = edges[1::2] - run_starts
    pairs = np.column_stack((run_starts + 1, run_lengths)).ravel()
    return ' '.join(map(str, pairs))

# Label after fix_overlap

In [None]:
# RLE string for each fixed mask (row-major flattening, 1-based starts).
FIXED_LABEL = [rle_encoding(m) for m in (MASK0b, MASK1b, MASK2b)]
print(FIXED_LABEL[0])