# **KDDM 2 SARTORIUS CELL INSTANCE SEGMENTATION CELLPOSE TRAINING NOTEBOOK**

# INSTALL ALL DEPENDENCIES
* **Install Cellpose and required dependencies**

In [None]:
# Install Cellpose 0.7.2 and its helper packages from wheels bundled in the kddm2 dataset.
# --no-deps stops pip from resolving/downloading dependencies itself (presumably because the
# Kaggle kernel runs without internet access). numpy is pinned to 1.20.3; the compiled wheels
# target CPython 3.7 (cp37).
!pip install ../input/kddm2/cellpose/fastremap-1.12.2-cp37-cp37m-manylinux2010_x86_64.whl --no-deps
!pip install ../input/kddm2/cellpose/natsort-8.0.1-py3-none-any.whl --no-deps
!pip install ../input/kddm2/cellpose/pytorch_ranger-0.1.1-py3-none-any.whl --no-deps
!pip install ../input/kddm2/cellpose/torch_optimizer-0.3.0-py3-none-any.whl --no-deps
!pip install ../input/kddm2/cellpose/numpy-1.20.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl --no-deps
!pip install ../input/kddm2/cellpose/cellpose-0.7.2-py3-none-any.whl --no-deps
!pip install ../input/kddm2/cellpose/edt-2.1.1-cp37-cp37m-manylinux2014_x86_64.whl --no-deps

# Show the available GPU; the training command below runs with --use_gpu.
!nvidia-smi

# **PREPARE TRAINING DATA**
* **Generate masks and images**

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm

import shutil
import random
import torch


#define path to input image folder, annotation csv and output folders for images and masks
SARTORIUS_CLASSES = ['astro', 'cort', 'shsy5y']
annotation_file = "/kaggle/input/sartorius-cell-instance-segmentation/train.csv"
img_dir = "/kaggle/input/sartorius-cell-instance-segmentation/train"
train_dir = "/kaggle/working/train"
val_dir = "/kaggle/working/val"

!rm -rf $train_dir
!rm -rf $val_dir

os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

#random train/val split
percentage_in_train = 0.95
frames = os.listdir(img_dir)
num_train = int(len(frames) * percentage_in_train)
random.seed(0)
train_frames = set( random.sample(frames, num_train) )


with open(annotation_file, "r") as f:
    anns = f.read().split("\n")
anns = anns[1:-1]
print("annotations", len(anns))

#create dict with annotation masks per img and convert rle to "start to end" pixel masks
anns_per_img_id = {}
for a in anns:
    values = a.split(",")
    id = values[0]
    cell_type = values[4]
    rle_mask = values[1].split(" ")

    starts = list(map(lambda x: int(x) - 1, rle_mask[0::2]))
    lengths = list(map(int, rle_mask[1::2]))
    ends = [x + y for x, y in zip(starts, lengths)]

    if id not in anns_per_img_id.keys():
        anns_per_img_id[id] = {}
        anns_per_img_id[id]["start_end"] = []

    anns_per_img_id[id]["start_end"].append( [starts, ends] )
    anns_per_img_id[id]["cell_type"] = cell_type
num_imgs_keys = len(anns_per_img_id.keys())
print("anns_per_img_id", num_imgs_keys)


#generate image masks for one cell type
selected_cell_type = 'cort'
img_mask = np.empty((704*520), dtype=np.uint16)
img_mask_2d = np.empty((520, 704), dtype=np.uint16)
for i in tqdm( range(0, num_imgs_keys) ):
    k = list(anns_per_img_id.keys())[i]
    anns = anns_per_img_id[k]
    masks = anns["start_end"]
    id = anns["cell_type"]

    if id != selected_cell_type:
        continue
    
    img_mask.fill(0)
    pixel_value_per_instance = 0
    for m in masks:
        pixel_value_per_instance = pixel_value_per_instance + 1
        for start, end in zip(m[0], m[1]):
            img_mask[start:end] = pixel_value_per_instance
    img_mask_2d = img_mask.reshape((520, 704))
    #print(np.max(img_mask_2d))
    
    orig_fname = "{:s}.png".format(k)
    orig_file = cv2.imread(os.path.join(img_dir, orig_fname))
    out_dir = train_dir if orig_fname in train_frames else val_dir
    #shutil.copy( os.path.join(img_dir, orig_fname), os.path.join(out_dir, fname_img) )
    cv2.imwrite( os.path.join(out_dir, "{:s}.tif".format(k)), orig_file[:, :, 0] )
    cv2.imwrite( os.path.join(out_dir, "{:s}_masks.tif".format(k)), img_mask_2d )
    flip = random.randint(0, 2)
    if flip == 2:
        continue
    cv2.imwrite( os.path.join(out_dir, "{:s}flip.tif".format(k)), cv2.flip(orig_file[:, :, 0], flip) )
    cv2.imwrite( os.path.join(out_dir, "{:s}flip_masks.tif".format(k)), cv2.flip(img_mask_2d, flip) )
    
print( "train imgs:", len(os.listdir(train_dir)) // 2 )
print( "val imgs:", len(os.listdir(val_dir)) // 2 )
!mkdir /kaggle/working/train/models/

# **TRAINING**


* **Load pretrained models**

In [None]:
import os
import shutil

cellpose_cache_path = os.path.join("/", "root", ".cellpose")
model_folder = os.path.join(cellpose_cache_path, "models")
!mkdir -p $model_folder
shutil.copy("/kaggle/input/kddm2/cellpose/cyto2torch_1", model_folder)

!ls /root/.cellpose/models/
!ls /kaggle/input/kddm2/cellpose
!ls /kaggle/working

* **Patch the installed Cellpose sources: log to the notebook's stdout and save epoch-tagged checkpoints**

In [None]:
# Location of the installed Cellpose package in the Kaggle Python 3.7 environment.
CELLPOSE_PKG_DIR = "/opt/conda/lib/python3.7/site-packages/cellpose"


def patch_source_lines(path, replacements):
    """Overwrite whole lines of a text file in place.

    Args:
        path (str): File to patch.
        replacements (dict[int, str]): 0-based line index -> new line text (without newline).
            NOTE(review): the indices are hard-coded, presumably for the cellpose 0.7.2 wheel
            installed above; re-check them if the Cellpose version changes.
    """
    with open(path, "r") as src:
        lines = src.read().split("\n")
    for index, text in replacements.items():
        lines[index] = text
    with open(path, "w") as dst:
        dst.write("\n".join(lines))


patch_source_lines(f"{CELLPOSE_PKG_DIR}/core.py", {
    # Send core_logger output to stdout so it shows in the notebook.
    35: "core_logger.addHandler(logging.StreamHandler(stream=sys.stdout))",
    # Save on the last epoch and every `save_every` epochs ...
    965: "                if iepoch==self.n_epochs-1 or iepoch%save_every==0:",
    # ... under a file name suffixed with the epoch number, so checkpoints do not overwrite each other.
    972: '                        file_name = "{}_{}_{}_{}".format(self.net_type, file_label, d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))',
})
patch_source_lines(f"{CELLPOSE_PKG_DIR}/models.py", {
    11: "models_logger.addHandler(logging.StreamHandler(stream=sys.stdout))",
})
patch_source_lines(f"{CELLPOSE_PKG_DIR}/__main__.py", {
    24: "logger.addHandler(logging.StreamHandler(stream=sys.stdout))",
})

* **Execute training**

In [None]:
# Fine-tune Cellpose on the selected cell type's images in train_dir, validating on val_dir.
# NOTE(review): confirm that "cyto2_torch" resolves to the cyto2torch_1 weights copied into /root/.cellpose/models.
ptm = "cyto2_torch" #set pretrained model
# --save_every 10: with the core.py patch applied above, an epoch-tagged checkpoint is written when
# iepoch % 10 == 0 and on the last epoch.
# NOTE(review): `--chan 0` is passed twice; the second was probably meant to be `--chan2 0` -- confirm
# against the cellpose 0.7.2 CLI options.
!python -m cellpose --train --dir $train_dir --test_dir $val_dir --pretrained_model $ptm --diameter 16 --n_epochs 200 --save_every 10 --learning_rate 0.001 --flow_threshold 0.3 --mask_threshold -0.3 --verbose --use_gpu --chan 0 --chan 0 --batch_size 8

* **Check cellpose output**

In [None]:
# Optional: restore a previously trained checkpoint instead of retraining (uncomment both lines).
#!mkdir /kaggle/working/train && mkdir /kaggle/working/train/models/
#!cp ../input/cp-sartorius/cellpose_residual_on_style_on_concatenation_off_train_2021_12_11_21_36_25.480280 /kaggle/working/train/models/
!date
# List the pretrained-model cache and train/models/, where training checkpoints are expected to be written.
#!ls -l /kaggle/working/models/
!ls -l /root/.cellpose/models/
!ls -lh /kaggle/working/train/models/

# **INFERENCE**

* **Functions to compute mAP**

In [None]:
def compute_iou(labels, y_pred):
    """
    Computes the IoU between every ground-truth instance and every predicted instance.

    Label values may be arbitrary (non-consecutive) integers; 0 is treated as background and
    excluded whether or not it actually occurs in the mask.

    Args:
        labels (np array): Ground-truth instance label image.
        y_pred (np array): Predicted instance label image, same shape as ``labels``.
    Returns:
        np array: IoU matrix of size true_objects x pred_objects (background excluded),
        rows/columns ordered by ascending label value.
    """
    labels = np.asarray(labels).ravel()
    y_pred = np.asarray(y_pred).ravel()
    # Map label values to dense 0..K-1 indices so counting is exact for any label values
    # (equal-width histogram bins merge labels when ids are not consecutive).
    true_ids, true_idx = np.unique(labels, return_inverse=True)
    pred_ids, pred_idx = np.unique(y_pred, return_inverse=True)
    n_true, n_pred = len(true_ids), len(pred_ids)

    # Intersection: pixel count for every (true, pred) label pair.
    intersection = np.bincount(
        true_idx * n_pred + pred_idx, minlength=n_true * n_pred
    ).reshape(n_true, n_pred).astype(np.float64)
    # Areas (needed for the union); every id from np.unique has area >= 1, so union > 0.
    area_true = np.bincount(true_idx, minlength=n_true)[:, None]
    area_pred = np.bincount(pred_idx, minlength=n_pred)[None, :]
    union = area_true + area_pred - intersection
    iou = intersection / union
    # Exclude background by value rather than by position.
    return iou[true_ids != 0][:, pred_ids != 0]

def precision_at(threshold, iou):
    """
    Counts matches between ground-truth and predicted objects at one IoU threshold.

    A ground-truth object (row) is a true positive if any prediction exceeds the threshold,
    otherwise a false negative; a prediction (column) with no such match is a false positive.

    Args:
        threshold (float): IoU threshold; a pair matches when its IoU is strictly greater.
        iou (np array [n_truths x n_preds]): IoU matrix.

    Returns:
        int: Number of true positives,
        int: Number of false positives,
        int: Number of false negatives.
    """
    hit = iou > threshold
    hits_per_truth = hit.sum(axis=1)
    hits_per_pred = hit.sum(axis=0)
    tp = (hits_per_truth >= 1).sum()
    fp = (hits_per_pred == 0).sum()
    fn = (hits_per_truth == 0).sum()
    return tp, fp, fn

def iou_map(truths, preds, verbose=0):
    """
    Computes the competition-style mAP over IoU thresholds 0.50, 0.55, ..., 0.95.
    Masks contain the segmented pixels where each object has one value associated,
    and 0 is the background.

    The mAP is pooled: for each threshold, TP/FP/FN are summed over all images, the score
    TP / (TP + FP + FN) is computed, and the scores are averaged over thresholds.

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions (paired with ``truths`` by position).
        verbose (int, optional): Whether to print infos. Defaults to 0.

    Returns:
        float: mAP (NaN if there is nothing to score).
        np array [n_thresholds x n_images]: per-image score TP / (TP + FP + FN) for every
        threshold; NaN where an image has neither ground-truth nor predicted objects.
    """
    thresholds = np.linspace(0.5, 0.95, 10)
    ious = [compute_iou(truth, pred) for truth, pred in zip(truths, preds)]

    if verbose:
        if ious:
            print(ious[0].shape)
        print("Thresh\tTP\tFP\tFN\tPrec.")

    prec = []
    # One row per threshold, one column per image (pre-filled with NaN for 0/0 cases).
    iou_per_img = np.full((len(thresholds), len(ious)), np.nan)
    for i, t in enumerate(thresholds):
        tps, fps, fns = 0, 0, 0
        for j, iou in enumerate(ious):
            tp, fp, fn = precision_at(t, iou)
            tps += tp
            fps += fp
            fns += fn
            denom = tp + fp + fn
            if denom > 0:
                iou_per_img[i, j] = tp / denom

        total = tps + fps + fns
        p = tps / total if total > 0 else np.nan
        prec.append(p)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tps, fps, fns, p))

    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))

    return np.mean(prec), iou_per_img

* **The inference script is written to a file and run as a separate process because of numpy version conflicts in the Kaggle notebook kernel**

In [None]:
%%writefile run.py
from cellpose import models, io, plot
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import sys
from tqdm import tqdm

def rle_encode(img):
    """Run-length encode a binary mask as a "start length start length ..." string.

    Pixels are taken in row-major (flatten) order and starts are 1-based.
    """
    # Zero-pad both ends so every run has a rising and a falling edge.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Pairs are (start, end); turn each end into a run length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

if __name__ == "__main__":
    # Positional CLI args: <pretrained_model_path> <diameter:int> <flow_threshold:float> <mask_threshold:float>
    if len(sys.argv) < 5:
        sys.exit("usage: python run.py <pretrained_model_path> <diameter> <flow_threshold> <mask_threshold>")
    pretrained_model_path = sys.argv[1]
    dm = int(sys.argv[2])
    ft = float(sys.argv[3])
    mt = float(sys.argv[4])

    print("Testing model: ", pretrained_model_path)
    print("DM: {:d}   FT: {:.2f}   MT: {:.2f}".format(dm, ft, mt))

    test_dir = Path('/kaggle/working/val')
    mask_out_dir = Path('/kaggle/working/masks')
    mask_out_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate only the original validation images: skip ground-truth masks, flow outputs
    # and flipped augmentation copies. Sorted for a deterministic processing order.
    test_files = sorted(
        fname for fname in test_dir.iterdir()
        if "_mask" not in fname.stem and "_flow" not in fname.stem and "flip" not in fname.stem
    )

    model = models.CellposeModel(gpu=True, pretrained_model=pretrained_model_path)

    ids, masks = [], []
    for fn in tqdm(test_files):
        img = io.imread(str(fn))
        preds, flows, _ = model.eval(img, diameter=dm, channels=[0,0], augment=True, resample=True, flow_threshold=ft, mask_threshold=mt, omni=False)
        out_path = str(mask_out_dir / "{:s}.tif".format(fn.stem))
        # cv2.imwrite reports failure (e.g. unsupported dtype) only through its return value.
        if not cv2.imwrite(out_path, preds):
            raise RuntimeError("could not write predicted mask to {}".format(out_path))
        # One submission row per predicted instance label 1..max.
        for label in range(1, preds.max() + 1):
            ids.append(fn.stem)
            masks.append(rle_encode(preds == label))

    pd.DataFrame({'id': ids, 'predicted': masks}).to_csv('submission.csv', index=False)

* **Evaluate models with different hyperparameters**

In [None]:
models = ["/kaggle/input/kddm2/cellpose/cort9"]

dms = [16, 18]
fts = [0.3]
mts = [-0.3]

for model in models:
    for dm in dms:
        for ft in fts:
            for mt in mts:
                !rm -rf /kaggle/working/masks

                !mkdir /kaggle/working/masks
                !python run.py $model $dm $ft $mt
                pred_files = os.listdir("/kaggle/working/masks")
                y_pred = []
                masks = []
                for f in pred_files:
                    pred = cv2.imread("/kaggle/working/masks/{:s}.tif".format(f[:-4]), -1)
                    m = cv2.imread("/kaggle/working/val/{:s}_masks.tif".format(f[:-4]), -1)
                    y_pred.append(pred)
                    masks.append(m)

                __, iou_per_img = iou_map(masks, y_pred, verbose=1)
            