In [1]:
## Install MMDetection from the local input directory ##
# Copy the mmdetection source tree next to the working directory; the training
# cell later runs its tools from ../mmdetection.
!rsync -a ../input/human-protein-atlas-mmdetection/mmdetection ../
# Install mmdet 2.8.0 and the supporting packages from the source trees and
# wheels stored under ../input/human-protein-atlas-mmdetection/src.
!pip install ../input/human-protein-atlas-mmdetection/src/mmdet-2.8.0/mmdet-2.8.0/
!pip install ../input/human-protein-atlas-mmdetection/src/mmpycocotools-12.0.3/mmpycocotools-12.0.3/
!pip install ../input/human-protein-atlas-mmdetection/src/addict-2.4.0-py3-none-any.whl
!pip install ../input/human-protein-atlas-mmdetection/src/yapf-0.30.0-py2.py3-none-any.whl
# NOTE: this wheel is tagged cp37 / manylinux1_x86_64, i.e. CPython 3.7 on x86_64 Linux.
!pip install ../input/human-protein-atlas-mmdetection/src/mmcv_full-1.2.6-cp37-cp37m-manylinux1_x86_64.whl

In [2]:
## Import Library ##
from itertools import groupby
from pycocotools import mask as mutils
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
import pickle
import cv2
from multiprocessing import Pool
import matplotlib.pyplot as plt

In [3]:
# Paths and run mode shared by the cells below.
cell_mask_dir = '../input/hpa-mask/hpa_cell_mask'  # holds one <image_id>.npz cell mask per image
ROOT = '../input/hpa-single-cell-image-classification/'  # holds train.csv and the per-channel PNGs
train_or_test = 'train'  # sub-directory of ROOT that channel images are read from
img_dir = f'../work/mmdet_{train_or_test}'  # destination of the merged JPEG images
# Create the output directory from Python rather than through the shell: the
# path is not subject to shell word-splitting and a failure raises immediately
# instead of only printing a message.
os.makedirs(img_dir, exist_ok=True)

In [4]:
## Util function ##
# Threshold on the green channel used by get_rles_from_mask: for images whose
# label is not '18', a cell is dropped when the maximum of (green image * cell
# mask) is below this value.
MAX_GREEN = 64
def get_rles_from_mask(image_id, class_id):
    """Encode each cell of an image's label mask as an uncompressed RLE.

    The mask is read from ``{cell_mask_dir}/{image_id}.npz`` (array ``arr_0``)
    and every distinct non-zero value in it is treated as one cell.  Unless
    ``class_id`` equals the string ``'18'``, a cell is skipped when the maximum
    of the green image multiplied by the cell's binary mask is below
    ``MAX_GREEN``.

    Args:
        image_id: Image identifier used to locate the mask and channel files.
        class_id: Label string of the image; ``'18'`` disables the green filter.

    Returns:
        tuple: ``(rle_list, height, width)`` - one ``coco_rle_encode`` result
        per kept cell, followed by the first two dimensions of the mask.
    """
    label_map = np.load(f'{cell_mask_dir}/{image_id}.npz')['arr_0']
    filter_by_green = class_id != '18'
    if filter_by_green:
        green_img = read_img(image_id, 'green')

    kept_rles = []
    for cell_value in np.unique(label_map):
        if cell_value == 0:
            # value 0 is never encoded (presumably background - confirm with the mask producer)
            continue
        cell_pixels = label_map == cell_value
        if filter_by_green and (green_img * cell_pixels).max() < MAX_GREEN:
            continue
        kept_rles.append(coco_rle_encode(cell_pixels))

    height, width = label_map.shape[0], label_map.shape[1]
    return kept_rles, height, width

def coco_rle_encode(mask):
    """Run-length encode a binary mask in column-major order.

    Args:
        mask: Array of 0/1 (or bool) values.

    Returns:
        dict: ``{'counts': [...], 'size': [...]}`` where ``counts`` lists the
        lengths of consecutive runs over the column-major flattening of
        ``mask`` and ``size`` is ``list(mask.shape)``.  The first count always
        refers to a run of zeros, so a leading ``0`` is emitted when the mask
        starts with a one.
    """
    flat_pixels = mask.ravel(order='F')
    run_lengths = []
    for position, (pixel, run) in enumerate(groupby(flat_pixels)):
        if position == 0 and pixel == 1:
            # mask starts with foreground: record an empty background run first
            run_lengths.append(0)
        run_lengths.append(sum(1 for _ in run))
    return {'counts': run_lengths, 'size': list(mask.shape)}

def mk_mmdet_custom_data(image_id, class_id):
    """Build one annotation record (filename, size, ann) for an image.

    Args:
        image_id: Image identifier; the record's filename is ``image_id + '.jpg'``.
        class_id: Label string of the image, forwarded to ``get_rles_from_mask``.

    Returns:
        dict: Keys ``filename``, ``width``, ``height`` and ``ann``.  ``ann`` is
        an empty dict when no cell was encoded; otherwise it holds ``bboxes``
        (float32 array), ``labels`` (zeros, placeholder) and ``masks`` (the
        RLEs returned by ``mutils.frPyObjects``).
    """
    cell_rles, height, width = get_rles_from_mask(image_id, class_id)
    record = {
        'filename': image_id+'.jpg',
        'width': width,
        'height': height,
        'ann': {}
    }
    if len(cell_rles) == 0:
        return record

    encoded = mutils.frPyObjects(cell_rles, height, width)
    boxes = mutils.toBbox(encoded)
    # toBbox yields [x, y, w, h] rows per the pycocotools mask API; adding the
    # origin to the extent turns each row into [x1, y1, x2, y2]
    boxes[:, 2:] += boxes[:, :2]
    record['ann'] = {
        'bboxes': np.array(boxes, dtype=np.float32),
        'labels': np.zeros(len(boxes)),  # dummy data.(will be replaced later)
        'masks': encoded
    }
    return record

def print_masked_img(image_id, mask):
    """Display the image, the mask, and the mask overlaid on the image.

    Args:
        image_id: Image identifier; channels are loaded with the module-level
            ``train_or_test`` setting.
        mask: Array drawn with ``plt.imshow`` in the second and third panels.
    """
    img = load_RGBY_image(image_id, train_or_test)

    # One entry per panel: (title, [(array, imshow keyword arguments), ...]).
    panels = [
        ('Image', [(img, {})]),
        ('Mask', [(mask, {})]),
        ('Image + Mask', [(img, {}), (mask, {'alpha': 0.6})]),
    ]

    plt.figure(figsize=(15, 15))
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for layer, imshow_kwargs in layers:
            plt.imshow(layer, **imshow_kwargs)
        plt.title(title)
        plt.axis('off')

    plt.show()
    
    
def load_RGBY_image(image_id, train_or_test='train', image_size=None):
    """Load the red, green and blue channel images and stack them channel-last.

    Args:
        image_id: Image identifier passed to ``read_img``.
        train_or_test: Sub-directory name passed to ``read_img``.
        image_size: Optional square size passed to ``read_img``.

    Returns:
        numpy.ndarray: Array whose last axis holds the channels in the order
        red, green, blue.  The yellow channel is not loaded.
    """
    channel_names = ("red", "green", "blue")
    channels = [
        read_img(image_id, channel_name, train_or_test, image_size)
        for channel_name in channel_names
    ]
    # (channel, row, col) -> (row, col, channel)
    return np.array(channels).transpose(1, 2, 0)
 
def read_img(image_id, color, train_or_test='train', image_size=None):
    """Read one colour-channel PNG of an image.

    Args:
        image_id: Image identifier (file stem without the colour suffix).
        color: Channel name used in the file name, e.g. ``'green'``.
        train_or_test: Sub-directory of ``ROOT`` to read from.
        image_size: If given, the image is resized to
            ``(image_size, image_size)``.

    Returns:
        numpy.ndarray: The image as stored on disk, or a ``uint8`` version of
        it when any value exceeds 255.

    Raises:
        AssertionError: If the file does not exist or cannot be decoded.
    """
    filename = f'{ROOT}/{train_or_test}/{image_id}_{color}.png'
    assert os.path.exists(filename), f'not found {filename}'
    img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
    # imread reports an unreadable file by returning None rather than raising
    assert img is not None, f'failed to read {filename}'
    if image_size is not None:
        img = cv2.resize(img, (image_size, image_size))
    if img.max() > 255:
        # Scale by 255 as before, but saturate at 255: without the clip, values
        # above 255 * 256 - 1 wrap around when cast to uint8 (65535 / 255 = 257
        # would become 1), turning the brightest pixels nearly black.
        img = np.clip(img / 255, 0, 255).astype('uint8')
    return img

def mk_ann(idx):
    """Create the annotation record and the merged JPEG for one row of ``df``.

    Args:
        idx: Positional row index into the module-level ``df``.

    Returns:
        tuple: ``(anno, idx, image_id)`` where ``anno`` is the record built by
        ``mk_mmdet_custom_data``.

    Raises:
        IOError: If the merged image could not be written to ``img_dir``.
    """
    row = df.iloc[idx]
    image_id = row.ID
    class_id = row.Label
    anno = mk_mmdet_custom_data(image_id, class_id)
    img = load_RGBY_image(image_id, train_or_test)
    out_path = f'{img_dir}/{image_id}.jpg'
    # NOTE(review): the array is stacked red/green/blue while cv2.imwrite treats
    # the channels as blue/green/red - confirm the training pipeline expects this.
    # imwrite signals failure through its return value, so check it; otherwise
    # the annotation would reference an image file that does not exist.
    if not cv2.imwrite(out_path, img):
        raise IOError(f'failed to write {out_path}')
    return anno, idx, image_id

In [5]:
## Load Dataset ##
# train.csv provides the ID and Label columns that the following cells read.
train_csv_path = os.path.join(ROOT, 'train.csv')
df = pd.read_csv(train_csv_path)
print(df.shape)
df.head(3)

In [6]:
## Plot Train Dataset ##
# Visual sanity check on the first row of df: load its cell mask and show it
# next to the image.  cell_mask_dir comes from the configuration cell; it is
# deliberately not redefined here so that a change made there is not silently
# overridden for every cell that runs after this one.
image_id = df.iloc[0].ID
cell_mask = np.load(f'{cell_mask_dir}/{image_id}.npz')['arr_0']
print_masked_img(image_id, cell_mask)

In [7]:
## Custom Dataset ##

MAX_THRE = 4  # number of worker processes
len_df = len(df)
annos = []
# Run mk_ann for every row of df in parallel; imap yields results in input
# order.  The context manager shuts the worker processes down once the loop
# finishes (or raises) instead of leaving them alive for the rest of the
# kernel session.
with Pool(processes=MAX_THRE) as p:
    for anno, idx, image_id in p.imap(mk_ann, range(len_df)):
        # records whose 'ann' is empty (no cell kept) are dropped
        if len(anno['ann']) > 0:
            annos.append(anno)

In [8]:
# Replace the placeholder labels: every box/mask of an image is repeated once
# per image-level label (labels are separated by '|'), then the first 1% of the
# records become the validation set and the rest the training set.
# NOTE: the dicts stored in annos are modified in place, so annos itself ends
# up holding the expanded annotations and re-running this cell expands them again.
lbl_cnt_dict = df.set_index('ID').to_dict()['Label']
trn_annos = []
val_annos = []
val_len = int(len(annos)*0.01)
for idx, ann in enumerate(annos):
    image_key = ann['filename'].replace('.jpg','').replace('.png','')
    class_ids = [int(token) for token in lbl_cnt_dict[image_key].split('|')]
    fields = ann['ann']
    cell_boxes = fields['bboxes']
    cell_masks = fields['masks']
    n_cells = len(cell_boxes)
    n_labels = len(class_ids)
    # labels are laid out block-wise: all cells with the first label, then all
    # cells with the second label, and so on
    fields['labels'] = np.concatenate(
        [np.full(n_cells, class_id) for class_id in class_ids]
    )
    if n_labels > 1:
        fields['bboxes'] = np.concatenate([cell_boxes] * n_labels)
        fields['masks'] = cell_masks * n_labels
    target_split = val_annos if idx < val_len else trn_annos
    target_split.append(ann)

In [9]:
# Persist the full annotation list and the train/validation splits.
pickle_targets = [
    ('../work/mmdet_full.pkl', annos),
    ('../work/mmdet_trn.pkl', trn_annos),
    ('../work/mmdet_val.pkl', val_annos),
]
for target_path, payload in pickle_targets:
    with open(target_path, 'wb') as f:
        pickle.dump(payload, f)

# training

In [10]:
## Training ##
config = "configs/human_protein_atlas/mask_rcnn_x101_fpn.py"
additional_conf = '--no-validate --cfg-options'
additional_conf += f' work_dir=../working/work_dir'
additional_conf += f' optimizer.lr=0.0025'

cmd = f'bash -x tools/dist_train.sh {config} 1 {additional_conf}'
!cd ../mmdetection; {cmd}