Sartorius cell instance segmentation.

This notebook does the following:
- Install detectron2
- Preprocess the training data
- Fine-tune a Mask R-CNN model
- Generate predictions on the test dataset
- Write the submission file

v2 

# Install detectron2 

In [None]:
!pip install ../input/detectron-05/whls/pycocotools-2.0.2/dist/pycocotools-2.0.2.tar --no-index --find-links ../input/detectron-05/whls 
!pip install ../input/detectron-05/whls/fvcore-0.1.5.post20211019/fvcore-0.1.5.post20211019 --no-index --find-links ../input/detectron-05/whls 
!pip install ../input/detectron-05/whls/antlr4-python3-runtime-4.8/antlr4-python3-runtime-4.8 --no-index --find-links ../input/detectron-05/whls 
!pip install ../input/detectron-05/whls/detectron2-0.5/detectron2 --no-index --find-links ../input/detectron-05/whls 

# Load training data

In [None]:
import torch
import pandas as pd
import numpy as np
import glob
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt

# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
#from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
import pycocotools
import skimage.measure
from fastcore.all import *


# Root directory of the competition data; train.csv is read from it here.
DATA_PATH = '../input/sartorius-cell-instance-segmentation'
# Annotation table; the code below reads its 'id', 'annotation' and 'cell_type' columns.
train_info = pd.read_csv(DATA_PATH + '/train.csv')

In [None]:
# Define the training label vocabularies.

# Per-cell-type scheme: one category id per distinct value of the
# 'cell_type' column, numbered in the order `unique()` returns them.
cat_list = list(train_info['cell_type'].unique())
cell_type_to_cat = {name: idx for idx, name in enumerate(cat_list)}
cat_to_cell_type = {idx: name for idx, name in enumerate(cat_list)}
cat_list_three = cat_list
print(cell_type_to_cat)
print(cat_list_three)

# Single-label scheme: every cell type maps to category 0, named "cell".
cat_one_type = dict.fromkeys(cell_type_to_cat, 0)
cat_list_one = ["cell"]
print(cat_one_type)
print(cat_list_one)

ONE_TYPE = False  # True: only has "cell" label. False: one label per cell type
# Cell types whose annotations are kept for training; may be any subset of
# the names in cat_list_three (e.g. 'cort', 'shsy5y', 'astro').
INCLUDE_CELL_TYPE = cat_list_three

# Pick the mapping / class-name list that the rest of the notebook uses.
cat_dict, cat_list = (
    (cat_one_type, cat_list_one) if ONE_TYPE else (cell_type_to_cat, cat_list_three)
)


In [None]:
def mask_to_bbox(mask):
    """Return the tight bounding box of the non-zero pixels of a 2-D mask.

    The result is ``(col_min, row_min, col_max, row_max)``, i.e. x/y order
    with inclusive pixel indices. Indexing the empty index arrays of an
    all-zero mask raises IndexError.
    """
    occupied_rows = np.nonzero(mask.sum(axis=1))[0]
    occupied_cols = np.nonzero(mask.sum(axis=0))[0]
    return (occupied_cols[0], occupied_rows[0], occupied_cols[-1], occupied_rows[-1])


def rle_to_polygon(mask_rle, shape):
    '''
    Convert one run-length-encoded annotation into a box and polygon outlines.

    mask_rle: run-length string of "start length" pairs; starts are 1-based
        positions in the flattened image (the flat buffer is reshaped to
        `shape` in numpy's default row-major order)
    shape: (height, width) of the image the annotation belongs to
    Returns (bbox, polygon, contours):
        bbox: the value mask_to_bbox returns for the decoded mask
        polygon: one flat coordinate list per contour, with the two columns
            of each contour point swapped relative to skimage's output
        contours: the arrays returned by skimage.measure.find_contours
    '''
    # Decode the RLE string into a binary mask.
    tokens = mask_rle.split()
    run_starts = [int(tok) - 1 for tok in tokens[0::2]]
    run_lengths = [int(tok) for tok in tokens[1::2]]
    flat = np.zeros((shape[0] * shape[1]), dtype=np.uint8)
    for begin, length in zip(run_starts, run_lengths):
        flat[begin : begin + length] = 1
    bitmask = flat.reshape(shape)

    # Bounding box of the decoded mask.
    bbox = mask_to_bbox(bitmask)

    # Trace the mask outline(s) at the 0/1 boundary and flatten each one.
    contours = skimage.measure.find_contours(bitmask, 0.5)
    polygon = [np.flip(outline, axis=1).ravel().tolist() for outline in contours]
    return bbox, polygon, contours


from detectron2.structures import BoxMode
def get_cell_dicts(img_dir, val_split = 0.02):
    """Build detectron2-style dataset records for the training images.

    Args:
        img_dir: directory containing ``train.csv`` and a ``train/`` folder
            of ``.png`` images.
        val_split: probability with which each image is assigned to the
            validation list instead of the training list. The draw uses the
            global ``random`` state, so the split changes between runs
            unless the caller seeds it.

    Returns:
        ``(train_dataset, val_dataset)``: two lists of record dicts with the
        keys ``file_name``, ``image_id``, ``height``, ``width`` and
        ``annotations``. Each annotation holds ``bbox``, ``bbox_mode``
        (``BoxMode.XYXY_ABS``), ``segmentation`` (polygon lists) and
        ``category_id`` (looked up in the module-level ``cat_dict``).

    Annotations whose cell type is not in ``INCLUDE_CELL_TYPE`` are skipped,
    as are annotations that produce a polygon with fewer than 6 coordinates
    (a message is printed for each such polygon).
    """
    train_info = pd.read_csv(img_dir + '/train.csv')

    train_dataset = []
    val_dataset = []

    for filename in tqdm(sorted(glob.glob(img_dir + '/train/*.png'))):
        height, width = cv2.imread(filename).shape[:2]

        record = {
            "file_name": filename,
            # splitext removes exactly the extension. str.strip('.png') would
            # also eat any leading/trailing '.', 'p', 'n' or 'g' of the id.
            "image_id": os.path.splitext(os.path.basename(filename))[0],
            "height": height,
            "width": width,
        }

        rows = train_info[train_info["id"] == record["image_id"]]
        annotations = rows["annotation"].tolist()
        cell_types = rows["cell_type"].tolist()

        objs = []
        for annot, ct in zip(annotations, cell_types):
            if ct not in INCLUDE_CELL_TYPE:
                continue

            # Decode against this image's real size instead of a hard-coded
            # 520x704, so the mask always matches the record's height/width.
            bbox, polygon, _ = rle_to_polygon(annot, shape=(height, width))

            # A polygon needs at least 3 points (6 coordinates).
            too_short = [p for p in polygon if len(p) < 6]
            for p in too_short:
                print(record["image_id"] + ': polygon list too short%s. Annotation skipped.'%str(p))
            if too_short:
                continue

            objs.append({
                "bbox": bbox,
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": polygon,
                "category_id": cat_dict[ct],
            })
        record["annotations"] = objs

        if random.random() <= val_split:
            val_dataset.append(record)
        else:
            train_dataset.append(record)

    return train_dataset, val_dataset

# Build the train/validation record lists once; they are registered with
# detectron2's DatasetCatalog further down.
train_dataset, val_dataset = get_cell_dicts(DATA_PATH)

In [None]:
# Register both splits under the names used by cfg.DATASETS later on. The
# lambdas look up the module-level lists when called, and each split gets the
# active class-name list as its `thing_classes` metadata.
DatasetCatalog.register("cell_train", lambda : train_dataset)
MetadataCatalog.get("cell_train").set(thing_classes=cat_list)
DatasetCatalog.register("cell_val", lambda : val_dataset)
MetadataCatalog.get("cell_val").set(thing_classes=cat_list)


# Metadata handle passed to the Visualizer calls below.
cell_metadata = MetadataCatalog.get("cell_train")

# Visualize training image and mask

In [None]:
# Show one random training image twice: once with its ground-truth
# annotations drawn on top, once unmodified.
for d in random.sample(train_dataset, 1):
    print(d['image_id'])
    img = cv2.imread(d["file_name"])  # OpenCV loads images in BGR order
    visualizer = Visualizer(img[:, :, ::-1], metadata=cell_metadata, scale=2)
    out = visualizer.draw_dataset_dict(d)
    plt.figure(figsize=(12, 16))
    # The Visualizer was given RGB and get_image() returns RGB, which is what
    # plt.imshow expects. Reversing the channels again would swap red and
    # blue in the drawn annotations.
    plt.imshow(out.get_image())
    plt.title('labeled')
    plt.figure(figsize=(12, 16))
    # Convert BGR -> RGB for matplotlib.
    plt.imshow(img[:, :, ::-1])
    plt.title('original')

# Train

In [None]:
# Train Mask R-CNN: configure detectron2, save the config, and run training.
from detectron2.engine import DefaultTrainer
# !rm ./output/*
cfg = get_cfg()
# Start from the model-zoo Mask R-CNN R50-FPN instance-segmentation config.
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("cell_train",)
cfg.DATASETS.TEST = ("cell_val",)
cfg.DATALOADER.NUM_WORKERS = 2
# NOTE(review): the file name suggests a Faster R-CNN R-101 checkpoint while
# the config above is Mask R-CNN R-50 -- confirm the checkpoint matches the
# architecture being trained.
cfg.MODEL.WEIGHTS = '../input/detectron-retinanet/FR-CNN_101.pkl'  # initial weights, read from a local input dataset
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.001  # pick a good LR
cfg.SOLVER.MAX_ITER = 12000
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   
# NOTE: NUM_CLASSES is the number of classes; a few popular unofficial
# tutorials incorrectly use num_classes+1 here.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(cat_list) 
cfg.OUTPUT_DIR = './output'
#print(cfg)
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
# Persist the exact config next to the weights; the test section reloads it
# from this file.
with open(cfg.OUTPUT_DIR+"/config.yaml", "w") as file:
    file.write(cfg.dump())
trainer = DefaultTrainer(cfg) 
# resume=False: do not resume from a previous run in OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()

# Validation

In [None]:
# Setup: build a predictor from the training config, pointed at the final
# weights in the output directory.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.4   # set a custom testing threshold
predictor = DefaultPredictor(cfg)

val_metadata = MetadataCatalog.get("cell_val")
# Bare expression: shown as the cell's output when run in a notebook.
val_metadata

In [None]:
# Visualize prediction and truth for one random validation image.
from detectron2.utils.visualizer import ColorMode
for d in random.sample(val_dataset, 1):
    im = cv2.imread(d["file_name"])  # BGR, as loaded by OpenCV
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format

    # Predicted instances.
    v = Visualizer(im[:, :, ::-1],
                   metadata=cell_metadata,
                   scale=1,
#                   instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels. This option is only available for segmentation models
    )
    op = outputs["instances"].to("cpu")
    out = v.draw_instance_predictions(op)
    plt.figure(figsize=(12,16))
    # get_image() is already RGB, the order plt.imshow expects; flipping the
    # channels again would swap red and blue in the overlay.
    plt.imshow(out.get_image())

    # Ground truth. Only one figure is opened here: the original opened a
    # second one immediately before it, which rendered as an empty plot.
    print(train_info['cell_type'][d['image_id']==train_info['id']].unique())
    visualizer = Visualizer(im[:, :, ::-1], metadata=val_metadata, scale=.5)
    out = visualizer.draw_dataset_dict(d)
    plt.figure(figsize=(12, 16))
    plt.imshow(out.get_image())

In [None]:
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
# Evaluate the predictor's model on the registered "cell_val" split; the
# evaluator is given cfg.OUTPUT_DIR as its output directory.
evaluator = COCOEvaluator("cell_val", output_dir=cfg.OUTPUT_DIR)
val_loader = build_detection_test_loader(cfg, "cell_val")
# Print whatever inference_on_dataset returns for this evaluator.
print(inference_on_dataset(predictor.model, val_loader, evaluator))
# another equivalent way to evaluate the model is to use `trainer.test`
# another equivalent way to evaluate the model is to use `trainer.test`

# Process and predict on Test dataset

In [None]:
def get_cell_test_dicts(img_dir):
    """Build annotation-free dataset records for the test images.

    Args:
        img_dir: directory containing a ``test/`` folder of ``.png`` images.

    Returns:
        A list with one dict per image (sorted by path) holding
        ``file_name``, ``image_id`` (file name without extension),
        ``height`` and ``width``.
    """
    test_dataset = []
    for filename in tqdm(sorted(glob.glob(img_dir + '/test/*.png'))):
        height, width = cv2.imread(filename).shape[:2]
        test_dataset.append({
            "file_name": filename,
            # splitext removes exactly the extension. str.strip('.png') would
            # also eat any leading/trailing '.', 'p', 'n' or 'g' of the id.
            "image_id": os.path.splitext(os.path.basename(filename))[0],
            "height": height,
            "width": width,
        })
    return test_dataset


# One record per test image (no annotations).
test_dataset = get_cell_test_dicts(DATA_PATH)

In [None]:
# Register the test split and rebuild the predictor from the saved config.
MODEL_DIR = cfg.OUTPUT_DIR

print(len(test_dataset))

# Drop any earlier 'cell_test' registration so this cell can be re-run.
# Each catalog is handled on its own, so a missing entry in one does not
# skip the other, and only the "not registered" KeyError is ignored.
for catalog in (DatasetCatalog, MetadataCatalog):
    try:
        catalog.remove('cell_test')
    except KeyError:
        pass

# Register the *test* records. The original registered train_dataset here,
# so anything loading "cell_test" would have received the training images.
DatasetCatalog.register("cell_test", lambda : test_dataset)
MetadataCatalog.get("cell_test").set(thing_classes=cat_list)
test_metadata = MetadataCatalog.get("cell_test")

# Reload the config that was written next to the trained weights.
cfg = get_cfg()
cfg.merge_from_file(MODEL_DIR + "/config.yaml")
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.4   # set a custom testing threshold
predictor = DefaultPredictor(cfg)

In [None]:
# Visualize prediction results
from detectron2.utils.visualizer import ColorMode
for d in random.sample(test_dataset, 1):    
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    v = Visualizer(im[:, :, ::-1],
                   metadata=test_metadata, 
                   scale=1, 
#                   instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels. This option is only available for segmentation models
    )
    
    op = outputs["instances"].to("cpu")
    out = v.draw_instance_predictions(op)
    plt.figure(figsize=(12,16))
    plt.imshow(out.get_image()[:, :, ::-1])

## Write submission file

In [None]:
# Parallel accumulators for the submission: one image id and one RLE string
# per predicted instance.
ids, masks=[],[]
dataDir=Path('../input/sartorius-cell-instance-segmentation')
# Paths of the test images; their `.stem` is used as the submission id below.
test_names = (dataDir/'test').ls()
# From https://www.kaggle.com/stainsby/fast-tested-rle
def rle_decode(mask_rle, shape=(520, 704)):
    '''
    Decode a run-length string into a binary mask.

    mask_rle: run-length string of "start length" pairs (starts are 1-based
        positions in the flattened image)
    shape: (height, width) of the array to return
    Returns a uint8 numpy array of that shape: 1 - mask, 0 - background
    '''
    values = np.asarray(mask_rle.split(), dtype=int)
    begins = values[0::2] - 1          # convert to 0-based offsets
    stops = begins + values[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(begins, stops):
        flat[begin:stop] = 1
    return flat.reshape(shape)

def rle_encode(img):
    '''
    Encode a binary mask as a run-length string.

    img: numpy array, 1 - mask, 0 - background
    Returns space-separated "start length" pairs, with 1-based starts into
    the flattened array; an all-background mask gives an empty string
    '''
    flat = img.flatten()
    # Pad with a zero on each side so runs touching either end still produce
    # both a rising and a falling transition.
    padded = np.concatenate([[0], flat, [0]])
    transitions = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd entries are run ends: turn each into a length by subtracting its start.
    transitions[1::2] -= transitions[::2]
    return ' '.join(map(str, transitions))

def get_masks(fn, predictor):
    """Predict instance masks for one image and return them RLE-encoded.

    Args:
        fn: path of the image file (converted with ``str`` before reading).
        predictor: callable that takes the image array from ``cv2.imread``
            and returns a dict whose ``'instances'`` entry exposes
            ``pred_masks``.

    Returns:
        A list of run-length strings, one per kept mask. Masks are made
        mutually exclusive: pixels already claimed by an earlier mask (in
        the predictor's output order) are removed from later ones. A mask
        left with no pixels after that is dropped, so the result never
        contains an empty RLE string.
    """
    im = cv2.imread(str(fn))
    outputs = predictor(im)
    pred_masks = outputs['instances'].pred_masks.cpu().numpy()
    res = []
    used = np.zeros(im.shape[:2], dtype=int)  # 1 where a pixel is already assigned
    for mask in pred_masks:
        mask = mask * (1 - used)
        if not mask.any():
            # Entirely covered by earlier masks: encoding it would give ''
            # and a blank 'predicted' cell in the submission.
            continue
        used += mask
        res.append(rle_encode(mask))
    return res


# Run inference on every test image and collect one (id, RLE) row per
# predicted instance.
for fn in test_names:
    encoded_masks = get_masks(fn, predictor)
    ids.extend([fn.stem] * len(encoded_masks))
    masks.extend(encoded_masks)

# Write the submission file, then read it back to preview the first rows.
submission = pd.DataFrame({'id':ids, 'predicted':masks})
submission.to_csv('submission.csv', index=False)
pd.read_csv('submission.csv').head()