# DL: Training Stardist and Cellpose Models

This notebook and the following notebook, **DL: Benchmark Table**, are part of a small experiment.

The aim is to build a benchmark table for `bioimageloader`'s collection of instance segmentation datasets, using built-in and custom trained models.

In this notebook, we use [StarDist](https://github.com/stardist/stardist) and [Cellpose](https://github.com/mouseLand/cellpose) with `bioimageloader` to train our own models using a combined dataset.

It serves as a demonstration of how `bioimageloader` can be used for model training. The code blocks can easily be modified to fit your own tasks.

## 1. StarDist

This section is adapted from StarDist's 2D training example notebook on GitHub (https://github.com/stardist/stardist/blob/master/examples/2D/2_training.ipynb).

In [1]:
%env CUDA_VISIBLE_DEVICES=0
%env TF_CPP_MIN_LOG_LEVEL=3

#Built-in 
import warnings
import logging
import sys

# Silence warnings to keep the rendered notebook output clean (this may not catch every message)
warnings.filterwarnings('ignore')
logging.getLogger("tensorflow").setLevel(logging.ERROR)


#Bioimageloader and Albumentation
import albumentations as A
from bioimageloader import Config, BatchDataloader, ConcatDataset
from bioimageloader.transforms import SqueezeGrayImageHWC, HWCToCHW
from bioimageloader.collections import (BBBC020, ComputationalPathology, S_BSST265, 
    DSB2018, FRUNet, BBBC039, BBBC006, Cellpose, LIVECell)

#Stardist
from stardist import fill_label_holes, random_label_cmap, calculate_extents
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D
from csbdeep.utils import Path, normalize

#Cellpose imports
import torch
from cellpose import models

#Other imports
#!pip install matplotlib seaborn pandas tqdm numpy
from tqdm.notebook import tqdm
import numpy as np

env: CUDA_VISIBLE_DEVICES=0
env: TF_CPP_MIN_LOG_LEVEL=3


### Loading datasets
First, we load our collection of datasets with instance masks together: DSB2018, ComputationalPathology, BBBC006, BBBC020, BBBC039, S_BSST265, FRUNet, Cellpose and LIVECell. 

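Because the datasets contain very different numbers of images, each one is set to yield `num_samples=80` samples. For the smaller datasets (ComputationalPathology and BBBC020), we add random augmentations with the `albumentations` library so that repeated images are not identical. We also invert the ComputationalPathology and LIVECell images so that they look more like the rest of the collection, which mostly shows bright nuclei on a dark background.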
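For 8-bit images, inverting means replacing each pixel value `v` with `255 - v`, so dark objects on a bright background become bright objects on a dark background. A minimal NumPy sketch, assuming `uint8` input; `invert_uint8` is a hypothetical helper, while the notebook itself uses `A.InvertImg(p=1.0)`:

```python
import numpy as np

def invert_uint8(image: np.ndarray) -> np.ndarray:
    """Invert an 8-bit image: 0 becomes 255 and 255 becomes 0 (hypothetical helper)."""
    if image.dtype != np.uint8:
        raise TypeError("expected a uint8 image")
    # Subtracting from the uint8 maximum cannot underflow for valid uint8 values
    return np.iinfo(np.uint8).max - image

img = np.array([[0, 100], [200, 255]], dtype=np.uint8)
inv = invert_uint8(img)
```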

The datasets are then combined using `ConcatDataset`.
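Conceptually, a concatenated dataset maps a global index to a (dataset, local index) pair using the cumulative lengths of its parts. The class below is a hypothetical, simplified stand-in that illustrates the idea; it is not `bioimageloader`'s implementation:

```python
import bisect
from itertools import accumulate
from typing import Any, List, Sequence

class SimpleConcat:
    """Hypothetical minimal concatenation of indexable datasets."""

    def __init__(self, datasets: List[Sequence[Any]]):
        self.datasets = datasets
        # cumulative_sizes[i] is the total length of datasets[0..i]
        self.cumulative_sizes = list(accumulate(len(d) for d in datasets))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx: int) -> Any:
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # The first dataset whose cumulative size exceeds idx holds the item
        ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - start]

# Nine parts with 80 samples each, as in this notebook, give 720 samples
parts = [[f"d{k}_{i}" for i in range(80)] for k in range(9)]
concat = SimpleConcat(parts)
```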

In [2]:
#Transformations
transforms = A.Compose([
    A.Resize(256, 256),
    SqueezeGrayImageHWC()
])

transforms_compath = A.Compose([
    A.Resize(256, 256),
    A.InvertImg(p=1.0),
    A.OneOf([
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ], p=0.66),
    A.OneOf([
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(p=0.5, limit=80),
    ], p=0.66),
    SqueezeGrayImageHWC()
])

transforms_livecell = A.Compose([
    A.Resize(512, 512),
    A.RandomCrop(256,256),
    A.InvertImg(p=1.0),
    SqueezeGrayImageHWC() 
])

transforms_020 = A.Compose([
    A.Resize(512, 512),
    A.RandomCrop(256,256),
    A.OneOf([
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ], p=0.66),
    A.OneOf([
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(p=0.5, limit=80),
    ], p=0.66),
    SqueezeGrayImageHWC() 
])

bbbc020 = BBBC020('./Data/bbbc/020', grayscale=True, image_ch=["nuclei"], transforms=transforms_020, num_samples=80)
comp = ComputationalPathology('./Data/ComputationalPathology', grayscale=True, transforms=transforms_compath, num_samples=80)
dsb2018 = DSB2018('./Data/data-science-bowl-2018', grayscale=True, training=True, transforms=transforms, num_samples=80)
sbss = S_BSST265('./Data/BioStudies', transforms=transforms, num_samples=80)
frunet = FRUNet('./Data/FRU_processing', transforms=transforms, num_samples=80)
bbbc006 = BBBC006('./Data/bbbc/006', grayscale=True, transforms=transforms, num_samples=80)
bbbc039 = BBBC039('./Data/bbbc/039', transforms=transforms, num_samples=80)
cellpose = Cellpose('./Data/cellpose', grayscale=True, transforms=transforms, num_samples=80)
livecell = LIVECell('./Data/LIVECell', transforms=transforms_livecell, save_tif=False, training=True, num_samples=80)

dset = ConcatDataset([bbbc020, comp, dsb2018, sbss, frunet, bbbc006, bbbc039, cellpose, livecell])

### Normalization and train/val split
We follow StarDist's example notebook for normalization and the train/validation split.
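`csbdeep.utils.normalize(x, 1, 99.8, ...)` rescales each image so that its 1st percentile maps to about 0 and its 99.8th percentile to about 1. A minimal NumPy sketch of that percentile normalization, assuming no clipping (values outside the percentile range end up below 0 or above 1); `percentile_normalize` is a hypothetical name, not the csbdeep function itself:

```python
import numpy as np

def percentile_normalize(x, pmin=1.0, pmax=99.8, axis=None, eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th percentile to 1, without clipping."""
    mi = np.percentile(x, pmin, axis=axis, keepdims=True)
    ma = np.percentile(x, pmax, axis=axis, keepdims=True)
    # eps guards against division by zero for constant images
    return ((x - mi) / (ma - mi + eps)).astype(np.float32)

img = np.arange(100, dtype=np.float32).reshape(10, 10)
full_range = percentile_normalize(img, pmin=0, pmax=100, axis=(0, 1))
robust = percentile_normalize(img, axis=(0, 1))  # 1 / 99.8 percentiles, as in the notebook
```

Using percentiles instead of the raw minimum and maximum makes the scaling robust to a few very bright or very dark outlier pixels.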

In [3]:
X = list()
Y = list()
for d in tqdm(dset):
    X.append(d["image"])
    Y.append(d["mask"])

In [4]:
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]


assert len(dset) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(dset))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] 
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

- training:       612
- validation:     108
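The split keeps `round(0.15 * 720) = 108` images for validation and the remaining 612 for training, matching the printed counts. The same logic can be sketched with only the standard library; note that `random.Random(42)` produces a different permutation than `np.random.RandomState(42)`, so the exact indices will differ from the notebook's (`train_val_split` is a hypothetical helper):

```python
import random

def train_val_split(n_items: int, val_fraction: float = 0.15, seed: int = 42):
    """Return (train_indices, val_indices) from a seeded random permutation."""
    if n_items < 2:
        raise ValueError("not enough training data")
    ind = list(range(n_items))
    random.Random(seed).shuffle(ind)
    # Keep at least one validation item, as in the notebook
    n_val = max(1, int(round(val_fraction * n_items)))
    return ind[:-n_val], ind[-n_val:]

ind_train, ind_val = train_val_split(720)
```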


### Initialize a stardist model

In [5]:
# Default parameters from StarDist's example notebook
n_rays = 32
grid = (2, 2)
conf = Config2D(
    n_rays       = n_rays,
    grid         = grid,
    use_gpu      = True,  # GPU (OpenCL) data generation; requires `gputools`
    n_channel_in = 1,
)

#Specify the name and directory of the model
model = StarDist2D(conf, name='stardist_model_1', basedir='stardist_models')


median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")

Using default values: prob_thresh=0.5, nms_thresh=0.4.
median object size:      [16. 13.]
network field of view :  [94 93]
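StarDist describes each object as a star-convex polygon: from the object's center, `n_rays` rays are cast at evenly spaced angles, and the network predicts the distance to the boundary along each ray. With `n_rays = 32`, the angles are `2πk/32` for `k = 0, ..., 31`. A minimal sketch of turning such distances back into polygon vertices; `rays_to_polygon` is a hypothetical helper, and StarDist's own angle and coordinate conventions may differ:

```python
import math
from typing import List, Sequence, Tuple

def rays_to_polygon(center: Tuple[float, float],
                    dists: Sequence[float]) -> List[Tuple[float, float]]:
    """Convert per-ray boundary distances into (y, x) polygon vertices."""
    cy, cx = center
    n_rays = len(dists)
    vertices = []
    for k, r in enumerate(dists):
        phi = 2 * math.pi * k / n_rays  # evenly spaced ray angles
        vertices.append((cy + r * math.sin(phi), cx + r * math.cos(phi)))
    return vertices

# A round object about 14 px across, similar to the median object size printed above
poly = rays_to_polygon((50.0, 50.0), [7.0] * 32)
```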


### Training and optimization

In [7]:
# Train the model with the default parameters, then optimize its probability and NMS thresholds on the validation set.
# epochs=1 is for demonstration only; a usable model needs many more epochs.
model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), epochs=1)
model.optimize_thresholds(X_val, Y_val)


Loading network weights from 'weights_best.h5'.


NMS threshold = 0.3:  75%|████▌ | 15/20 [00:56<00:18,  3.78s/it, 0.199 -> 0.020]
NMS threshold = 0.4:  75%|████▌ | 15/20 [01:20<00:26,  5.37s/it, 0.199 -> 0.016]
NMS threshold = 0.5:  75%|████▌ | 15/20 [01:24<00:28,  5.62s/it, 0.199 -> 0.015]


Using optimized values: prob_thresh=0.198454, nms_thresh=0.3.
Saving to 'thresholds.json'.


{'prob': 0.19845443151963815, 'nms': 0.3}
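Conceptually, `optimize_thresholds` looks for the `prob_thresh` and `nms_thresh` pair that gives the best instance-matching score on the validation set; the progress bars above show it trying `nms_thresh` values of 0.3, 0.4 and 0.5. The sketch below replaces that with a plain exhaustive grid search over a hypothetical `score_fn`; as far as we know, StarDist itself optimizes `prob_thresh` numerically for each `nms_thresh` rather than using a fixed grid:

```python
from itertools import product
from typing import Callable, Iterable, Tuple

def grid_search_thresholds(
    score_fn: Callable[[float, float], float],
    prob_threshs: Iterable[float],
    nms_threshs: Iterable[float],
) -> Tuple[float, float, float]:
    """Return (best_prob, best_nms, best_score) over all threshold pairs."""
    best = None
    for prob, nms in product(prob_threshs, nms_threshs):
        score = score_fn(prob, nms)  # e.g. a matching score on validation data
        if best is None or score > best[2]:
            best = (prob, nms, score)
    if best is None:
        raise ValueError("empty threshold grid")
    return best

# Hypothetical score surface that peaks at prob=0.2, nms=0.3
def toy_score(prob: float, nms: float) -> float:
    return -((prob - 0.2) ** 2) - ((nms - 0.3) ** 2)

best_prob, best_nms, best_score = grid_search_thresholds(
    toy_score, [0.1, 0.2, 0.3, 0.5], [0.3, 0.4, 0.5]
)
```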

## 2. Cellpose

### Loading datasets


We follow the same procedure as for StarDist, except that here we use `HWCToCHW` (explained in the previous notebook) instead of `SqueezeGrayImageHWC`.
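`HWCToCHW` moves the channel axis of each image from last to first, giving the channel-first layout this notebook feeds to Cellpose. In NumPy terms this is a single `np.moveaxis` call; the helper below is a hypothetical illustration, not `bioimageloader`'s implementation:

```python
import numpy as np

def hwc_to_chw(image: np.ndarray) -> np.ndarray:
    """Move the channel axis of an (H, W, C) image to the front: (C, H, W)."""
    if image.ndim != 3:
        raise ValueError(f"expected an (H, W, C) array, got shape {image.shape}")
    return np.moveaxis(image, -1, 0)

rgb = np.zeros((256, 256, 3), dtype=np.uint8)
rgb[10, 20, 2] = 9
chw = hwc_to_chw(rgb)
```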

In [None]:
#Transformations
transforms = A.Compose([
    A.Resize(256, 256),
    HWCToCHW() 
])

transforms_compath = A.Compose([
    A.Resize(256, 256),
    A.InvertImg(p=1.0),
    A.OneOf([
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ], p=0.66),
    A.OneOf([
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(p=0.5, limit=80),
    ], p=0.66),
    HWCToCHW()
])

transforms_livecell = A.Compose([
    A.Resize(512, 512),
    A.RandomCrop(256,256),
    A.InvertImg(p=1.0),
    HWCToCHW()
])

transforms_020 = A.Compose([
    A.Resize(512, 512),
    A.RandomCrop(256,256),
    A.OneOf([
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ], p=0.66),
    A.OneOf([
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(p=0.5, limit=80),
    ], p=0.66),
    HWCToCHW()
])

bbbc020 = BBBC020('./Data/bbbc/020', grayscale=True, image_ch=["nuclei"], transforms=transforms_020, num_samples=80)
comp = ComputationalPathology('./Data/ComputationalPathology', grayscale=True, transforms=transforms_compath, num_samples=80)
dsb2018 = DSB2018('./Data/data-science-bowl-2018', grayscale=True, training=True, transforms=transforms, num_samples=80)
sbss = S_BSST265('./Data/BioStudies', transforms=transforms, num_samples=80)
frunet = FRUNet('./Data/FRU_processing', transforms=transforms, num_samples=80)
bbbc006 = BBBC006('./Data/bbbc/006', grayscale=True, transforms=transforms, num_samples=80)
bbbc039 = BBBC039('./Data/bbbc/039', transforms=transforms, num_samples=80)
cellpose = Cellpose('./Data/cellpose', grayscale=True, transforms=transforms, num_samples=80)
livecell = LIVECell('./Data/LIVECell', transforms=transforms_livecell, save_tif=False, training=True, num_samples=80)

dset = ConcatDataset([bbbc020, comp, dsb2018, sbss, frunet, bbbc006, bbbc039, cellpose, livecell])

In [None]:
X = list()
Y = list()
for d in tqdm(dset):
    X.append(d["image"])
    Y.append(d["mask"])

### Train/val split
For Cellpose, we do not need to normalize the data beforehand; normalization is done during training by passing `normalize=True`.

In [None]:
assert len(dset) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(dset))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] 
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

### Initialize model and train

In [None]:
# n_epochs=1 is for demonstration only.
# diam_mean=15 is close to the median object size measured in the StarDist section.
model = models.CellposeModel(pretrained_model=None, diam_mean=15, gpu=True)
model.train(
    train_data=X_trn, train_labels=Y_trn, train_files=None,
    test_data=X_val, test_labels=Y_val, test_files=None,
    normalize=True,
    channels=[0, 0],  # grayscale input, no separate nuclear channel
    save_path='cellpose_models', save_every=1,
    learning_rate=0.01, n_epochs=1, momentum=0.9,
    weight_decay=0.00001, batch_size=32,
)
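The `learning_rate`, `momentum` and `weight_decay` arguments are hyperparameters of stochastic gradient descent (SGD) with momentum. A minimal scalar sketch of one update step in the form used by PyTorch's `torch.optim.SGD` (weight decay added to the gradient, then accumulated in a momentum buffer); it assumes Cellpose passes these values to that optimizer, which may not hold for every Cellpose version, and `sgd_momentum_step` is a hypothetical helper:

```python
def sgd_momentum_step(w, grad, buf, lr=0.01, momentum=0.9, weight_decay=1e-5):
    """One SGD step with momentum and L2 weight decay; returns (new_w, new_buf)."""
    d = grad + weight_decay * w    # weight decay acts like an L2 penalty on w
    new_buf = momentum * buf + d   # momentum buffer accumulates past gradients
    new_w = w - lr * new_buf
    return new_w, new_buf

# Three steps with a constant gradient: the buffer grows, so later steps are larger
w, buf = 1.0, 0.0
history = []
for grad in [0.5, 0.5, 0.5]:
    w, buf = sgd_momentum_step(w, grad, buf)
    history.append(w)
```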