# mmdetection for segmentation [training]

In [114]:
from itertools import groupby
import shutil
import zipfile
from pathlib import Path
import subprocess
from pycocotools import mask as mutils
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
import pickle
import cv2
from multiprocessing import Pool
import matplotlib.pyplot as plt
import holoviews as hv

In [115]:
# Load the bokeh plotting backend for HoloViews.
hv.extension('bokeh')

In [96]:
# Experiment configuration shared by the cells and helpers below.
exp_name = 'v3'
conf_name = 'mask_rcnn_s101_fpn_syncbn-backbone+head_mstrain_1x_coco'
# Directory holding one `<image_id>.npz` cell-mask file per image.
cell_mask_dir = Path('../data/hpa-mask/hpa_cell_mask/')
ROOT = Path('../data')
train_or_test = 'train'
# Output directory for the stacked JPEG images written by mk_ann.
img_dir = f'mmdet_{exp_name}_{train_or_test}'
# Create the directory with pathlib instead of a `!mkdir -p` shell escape so
# the cell also runs outside IPython and on systems without `mkdir -p`.
Path(img_dir).mkdir(parents=True, exist_ok=True)
df = pd.read_csv(ROOT / 'train.csv')
# When True, df is later restricted to the image ids found under ROOT/train.
debug = True

In [88]:
# Extract every zip archive in the training directory next to itself.
for fn in (ROOT/'train').glob('*.zip'):
    print(fn)
    with zipfile.ZipFile(fn, mode='r') as archive:
        archive.extractall(path=fn.parent)

    # Remove the archive only if an entry named after it (the file name
    # without ".zip") now exists. A direct existence check avoids building a
    # full listing of the training directory for every archive.
    if (fn.parent/fn.stem).exists():
        fn.unlink()

In [102]:
# Unique image ids: the part of each entry name under ROOT/train that
# precedes the first underscore.
debug_image_ids = list({
    path.stem.split('_')[0] for path in (ROOT/'train').iterdir()
})

In [103]:
if debug:
    # Keep only the rows whose ID is in debug_image_ids, in that list's order.
    df_by_id = df.set_index('ID')
    df = df_by_id.loc[debug_image_ids].reset_index()

In [104]:
# Display the annotation table (already reduced to the debug subset when debug is set).
df

Unnamed: 0,ID,Label
0,a34d8680-bb99-11e8-b2b9-ac1f6b6435d0,18
1,000a6c98-bb9b-11e8-b2b9-ac1f6b6435d0,7|1|2|0
2,000a9596-bbc4-11e8-b2bc-ac1f6b6435d0,5
3,001838f8-bbca-11e8-b2bc-ac1f6b6435d0,12
4,000c99ba-bba4-11e8-b2b9-ac1f6b6435d0,1


## helper funcs

In [10]:
# convert segmentation mask image to run length encoding
# Cells whose brightest green pixel is below this value are skipped.
MAX_GREEN = 64
def get_rles_from_mask(image_id, class_id):
    """Encode every labelled region of an image's cell mask as an RLE dict.

    The mask is loaded from ``{cell_mask_dir}/{image_id}.npz`` (array
    ``arr_0``). Each distinct non-zero value in it is treated as one region.
    Unless ``class_id`` is the string ``'18'``, a region is dropped when the
    maximum of the green channel inside it is below ``MAX_GREEN``.

    Args:
        image_id: Image identifier used to locate the mask and green image.
        class_id: Label string of the image; ``'18'`` disables the green filter.

    Returns:
        A tuple ``(rle_list, mask.shape[0], mask.shape[1])`` where ``rle_list``
        holds one ``coco_rle_encode`` result per kept region.
    """
    mask = np.load(f'{cell_mask_dir}/{image_id}.npz')['arr_0']
    apply_green_filter = class_id != '18'
    if apply_green_filter:
        green_img = read_img(image_id, 'green')

    encoded_regions = []
    for region_value in np.unique(mask):
        # 0 is skipped: it is not treated as a region.
        if region_value == 0:
            continue
        region = mask == region_value
        # Zero out everything outside the region before taking the maximum.
        if apply_green_filter and (green_img * region).max() < MAX_GREEN:
            continue
        encoded_regions.append(coco_rle_encode(region))
    return encoded_regions, mask.shape[0], mask.shape[1]

def coco_rle_encode(mask):
    """Run-length encode a binary mask in column-major order.

    Args:
        mask: Array whose elements compare equal to 0 or 1.

    Returns:
        A dict with ``'counts'`` (run lengths of alternating values, always
        starting with the run of zeros) and ``'size'`` (``list(mask.shape)``).
    """
    run_lengths = []
    column_major = mask.ravel(order='F')
    for position, (pixel, run) in enumerate(groupby(column_major)):
        # The counts start with a zero-run, so a mask whose first pixel is
        # set gets an explicit empty zero-run in front.
        if position == 0 and pixel == 1:
            run_lengths.append(0)
        run_lengths.append(sum(1 for _ in run))
    return {'counts': run_lengths, 'size': list(mask.shape)}

# mmdet custom dataset generator
def mk_mmdet_custom_data(image_id, class_id):
    """Build the annotation record for one image.

    Args:
        image_id: Image identifier; the record's filename is ``image_id + '.jpg'``.
        class_id: Label string forwarded to ``get_rles_from_mask``.

    Returns:
        A dict with ``'filename'``, ``'width'``, ``'height'`` and ``'ann'``.
        ``'ann'`` is an empty dict when no region was encoded; otherwise it
        holds ``'bboxes'`` (float32), ``'labels'`` (zeros, one per box) and
        ``'masks'`` (the RLE objects produced by ``mutils.frPyObjects``).
    """
    rles, height, width = get_rles_from_mask(image_id, class_id)
    record = {
        'filename': image_id+'.jpg',
        'width': width,
        'height': height,
    }
    if not rles:
        record['ann'] = {}
        return record

    rles = mutils.frPyObjects(rles, height, width)
    decoded = mutils.decode(rles)
    reencoded = mutils.encode(np.asfortranarray(decoded.astype(np.uint8)))
    bboxes = mutils.toBbox(reencoded)
    # Add the first two columns onto the last two
    # (pycocotools boxes are x, y, w, h, so this yields x1, y1, x2, y2).
    bboxes[:, 2:] += bboxes[:, :2]
    record['ann'] = {
        'bboxes': np.array(bboxes, dtype=np.float32),
        'labels': np.zeros(len(bboxes)), # dummy data.(will be replaced later)
        'masks': rles
    }
    return record

# print utility from public notebook
def print_masked_img(image_id, mask):
    """Plot the image, the mask, and the mask overlaid on the image in one row.

    Args:
        image_id: Image identifier passed to ``load_RGBY_image``.
        mask: Array drawn with ``plt.imshow``.
    """
    img = load_RGBY_image(image_id, train_or_test)

    # Each panel: title and the (array, imshow kwargs) layers drawn in order.
    panels = (
        ('Image', ((img, {}),)),
        ('Mask', ((mask, {}),)),
        ('Image + Mask', ((img, {}), (mask, {'alpha': 0.6}))),
    )

    plt.figure(figsize=(15, 15))
    for slot, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        for layer, imshow_kwargs in layers:
            plt.imshow(layer, **imshow_kwargs)
        plt.title(title)
        plt.axis('off')
    plt.show()
    
# image loader, using rgb only here
def load_RGBY_image(image_id, train_or_test='train', image_size=None):
    """Load the red, green and blue channel images and stack them.

    Args:
        image_id: Image identifier passed to ``read_img``.
        train_or_test: Sub-directory name passed to ``read_img``.
        image_size: Optional resize target passed to ``read_img``.

    Returns:
        An array with the three channels on the last axis, in the order
        red, green, blue. The yellow channel is not loaded.
    """
    channels = [
        read_img(image_id, color, train_or_test, image_size)
        for color in ("red", "green", "blue")
    ]
    # Move the channel axis from first to last.
    return np.array(channels).transpose(1, 2, 0)

# single-channel image reader
def read_img(image_id, color, train_or_test='train', image_size=None):
    """Read one colour channel of an image from disk.

    The file read is ``{ROOT}/{train_or_test}/{image_id}_{color}.png``.

    Args:
        image_id: Image identifier (file name prefix).
        color: Channel name (file name suffix), e.g. ``'green'``.
        train_or_test: Sub-directory of ``ROOT`` to read from.
        image_size: If given, the image is resized to
            ``(image_size, image_size)``.

    Returns:
        The image array. If any value exceeds 255, the array is divided by
        255, clipped to ``[0, 255]`` and returned as ``uint8``; otherwise it
        is returned as read (and resized, if requested).

    Raises:
        AssertionError: If the file does not exist.
        OSError: If the file exists but could not be decoded.
    """
    filename = f'{ROOT}/{train_or_test}/{image_id}_{color}.png'
    assert os.path.exists(filename), f'not found {filename}'
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports an unreadable file by returning None.
        raise OSError(f'could not decode {filename}')
    if image_size is not None:
        img = cv2.resize(img, (image_size, image_size))
    if img.max() > 255:
        # Values of 65280 and above give a quotient >= 256, which would wrap
        # around when cast to uint8; clip so they saturate at 255 instead.
        img = np.clip(img / 255, 0, 255).astype('uint8')
    return img

# annotation helper, meant to be called from multiple processes
def mk_ann(idx):
    """Create the annotation record and the stacked JPEG for row ``idx`` of ``df``.

    Args:
        idx: Positional row index into the global ``df``.

    Returns:
        A tuple ``(anno, idx, image_id)`` where ``anno`` is the result of
        ``mk_mmdet_custom_data`` for that row.
    """
    row = df.iloc[idx]
    image_id, class_id = row.ID, row.Label
    anno = mk_mmdet_custom_data(image_id, class_id)
    # Write the stacked channels to `{img_dir}/{image_id}.jpg`.
    stacked = load_RGBY_image(image_id, train_or_test)
    cv2.imwrite(f'{img_dir}/{image_id}.jpg', stacked)
    return anno, idx, image_id

In [106]:
# Use the first image of the table for the manual checks below.
image_id = df.ID.iloc[0]
# Hard-coded label string. NOTE(review): not derived from df — confirm it matches image_id.
class_id = '5'

In [119]:
# Render the yellow channel of the selected image as a HoloViews image.
hv.Image(read_img(image_id, color='yellow'))