# Testing StarDist network for instance segmentation

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
import skimage as sk
from tifffile import imread
from csbdeep.utils import Path, normalize
import napari
import os
from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

np.random.seed(42)
lbl_cmap = random_label_cmap()

In [5]:
PNGs = sorted(glob('E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/**/instTiles/*.png',recursive=True))

In [6]:
len(PNGs)

12051

In [8]:
#Converted all pngs to tiffs
for y in PNGs:
    path = os.path.dirname(y)
    name = os.path.basename(y)
    img = sk.io.imread(y)
    sk.io.imsave(os.path.join(path,name[:-4]+'.tif'),img,check_contrast=False)

In [2]:
X = sorted(glob('E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/**/imgTiles/*.tif',recursive=True))
Y = sorted(glob('E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/**/instTiles/*.tif',recursive=True))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y)) #added indexing here because the file type is different between the images and the masks, using indexing to match the name of the file without file type

In [None]:
#Y = sorted(glob('E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/2025_01_28__0346-Scene-1-ScanRegion0/instTiles/*.png'))

In [None]:
viewer = napari.view_image(Y_test)

In [3]:
X = list(map(imread,X))
Y = list(map(imread,Y))

In [4]:
X = [np.expand_dims(x, axis=0) for x in X]

In [7]:
Y = [np.expand_dims(y, axis=(0,-1)) for y in Y]

In [14]:
X[0].dtype

dtype('uint8')

In [12]:
n_channel = 3

In [15]:
Y_labels = []
for img in Y:
    array = np.array(img)
    array = (array * 255).astype(np.uint8)
    Y_labels.append(array)

In [17]:
Y_labels[0].dtype

dtype('uint8')

In [None]:
axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y_labels)]

Normalizing image channels independently.


 10%|███████▋                                                                     | 1198/12051 [01:34<13:43, 13.18it/s]

In [None]:
len(Y)

In [6]:
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y_labels[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y_labels[i] for i in ind_train] 
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

number of images: 12051
- training:       10243
- validation:     1808


In [15]:
# 32 is a good default choice (see 1_data.ipynb)
n_rays = 32

# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = True and gputools_available()

# Predict on subsampled grid for increased efficiency and larger field of view
grid = (4,4)

conf = Config2D (
    n_rays       = n_rays,
    grid         = grid,
    use_gpu      = use_gpu,
    n_channel_in = n_channel,
    train_patch_size = (512,512),
    train_steps_per_epoch = 100,
    train_epochs = 400,
)
print(conf)
vars(conf)

Config2D(n_dim=2, axes='YXC', n_channel_in=3, n_channel_out=33, train_checkpoint='weights_best.h5', train_checkpoint_last='weights_last.h5', train_checkpoint_epoch='weights_now.h5', n_rays=32, grid=(4, 4), backbone='unet', n_classes=None, unet_n_depth=3, unet_kernel_size=(3, 3), unet_n_filter_base=32, unet_n_conv_per_depth=2, unet_pool=(2, 2), unet_activation='relu', unet_last_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_prefix='', net_conv_after_unet=128, net_input_shape=(None, None, 3), net_mask_shape=(None, None, 1), train_shape_completion=False, train_completion_crop=32, train_patch_size=(512, 512), train_background_reg=0.0001, train_foreground_only=0.9, train_sample_cache=True, train_dist_loss='mae', train_loss_weights=(1, 0.2), train_class_weights=(1, 1), train_epochs=400, train_steps_per_epoch=100, train_learning_rate=0.0003, train_batch_size=4, train_n_val_patches=None, train_tensorboard=True, train_reduce_lr={'factor': 0.5, 'patience': 40, 'min_delta': 0}, 

{'n_dim': 2,
 'axes': 'YXC',
 'n_channel_in': 3,
 'n_channel_out': 33,
 'train_checkpoint': 'weights_best.h5',
 'train_checkpoint_last': 'weights_last.h5',
 'train_checkpoint_epoch': 'weights_now.h5',
 'n_rays': 32,
 'grid': (4, 4),
 'backbone': 'unet',
 'n_classes': None,
 'unet_n_depth': 3,
 'unet_kernel_size': (3, 3),
 'unet_n_filter_base': 32,
 'unet_n_conv_per_depth': 2,
 'unet_pool': (2, 2),
 'unet_activation': 'relu',
 'unet_last_activation': 'relu',
 'unet_batch_norm': False,
 'unet_dropout': 0.0,
 'unet_prefix': '',
 'net_conv_after_unet': 128,
 'net_input_shape': (None, None, 3),
 'net_mask_shape': (None, None, 1),
 'train_shape_completion': False,
 'train_completion_crop': 32,
 'train_patch_size': (512, 512),
 'train_background_reg': 0.0001,
 'train_foreground_only': 0.9,
 'train_sample_cache': True,
 'train_dist_loss': 'mae',
 'train_loss_weights': (1, 0.2),
 'train_class_weights': (1, 1),
 'train_epochs': 400,
 'train_steps_per_epoch': 100,
 'train_learning_rate': 0.0003,


In [16]:
model = StarDist2D(conf, name='BloodCells', basedir='models')

Using default values: prob_thresh=0.5, nms_thresh=0.4.


In [17]:
median_size = calculate_extents(list(Y_labels), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")

ValueError: all input arrays must have the same shape

In [None]:
def random_fliprot(img, mask): 
    assert img.ndim >= mask.ndim
    axes = tuple(range(mask.ndim))
    perm = tuple(np.random.permutation(axes))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) 
    mask = mask.transpose(perm) 
    for ax in axes: 
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask 

def random_intensity_change(img):
    img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)
    return img


def augmenter(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    # add some gaussian noise
    sig = 0.02*np.random.uniform(0,1)
    x = x + sig*np.random.normal(0,1,x.shape)
    return x, y

In [None]:
model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)

In [None]:
model.optimize_thresholds(X_val, Y_val)