# Testing StarDist network for instance segmentation

In [1]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
import skimage as sk
from tifffile import imread
from csbdeep.utils import Path, normalize
import napari
import os
from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

# Seed every RNG this notebook draws from so re-runs are reproducible:
# numpy's global RNG (used by the augmentation functions) and Python's `random`
# module (used later to pick an example image to inspect in napari).
import random
np.random.seed(42)
random.seed(42)
lbl_cmap = random_label_cmap()

Use the following two cells to convert the mask PNGs exported from QuPath to TIFFs, if needed.

In [None]:
# Collect every mask PNG exported from QuPath that sits in an `instTiles` folder
# anywhere below `tiles/` (recursive search), in sorted order.
PNGs = sorted(
    glob(
        'E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/**/instTiles/*.png',
        recursive=True,
    )
)

In [None]:
# Convert every PNG in `PNGs` into a TIFF written next to it
# (same folder, same file stem, '.tif' extension).
for png_path in PNGs:
    folder, filename = os.path.split(png_path)
    tif_path = os.path.join(folder, filename[:-4] + '.tif')
    sk.io.imsave(tif_path, sk.io.imread(png_path), check_contrast=False)

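Copy all images and masks as TIFFs into a single pair of folders (`All_Imgs/Imgs` and `All_Imgs/Masks`). Note: this cell expects the images `X` and masks `Y` to already be loaded in the kernel; no earlier cell in this notebook defines them.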

In [None]:
# Copy every image/mask pair into the flat All_Imgs/Imgs and All_Imgs/Masks folders
# under identical file names, so later cells can pair them by sorted file order.
# NOTE(review): X (images) and Y (masks) are not defined by any earlier cell of this
# notebook; they must already exist in the kernel — TODO make their source explicit.
# NOTE: names are 'Img_0' + index without fixed-width padding, so lexicographic
# sorting (used when reading them back) does not follow numeric index order.
img_out_dir = 'E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/All_Imgs/Imgs'
mask_out_dir = 'E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/All_Imgs/Masks'

# Every image needs exactly one mask: otherwise Y[i] would raise part-way through
# (leaving a partial export) or surplus masks would be silently skipped.
assert len(X) == len(Y), f"{len(X)} images but {len(Y)} masks"

# imsave does not create missing folders.
os.makedirs(img_out_dir, exist_ok=True)
os.makedirs(mask_out_dir, exist_ok=True)

for i, (img, mask) in enumerate(zip(X, Y)):
    fname = 'Img_0' + str(i) + '.tif'
    sk.io.imsave(os.path.join(img_out_dir, fname), img, check_contrast=False)
    sk.io.imsave(os.path.join(mask_out_dir, fname), mask, check_contrast=False)
print('Done!')

Read in the images and masks

In [None]:
# List the flattened image and mask TIFFs. The two folders are expected to use
# identical file names, so after sorting Xfiles[k] and Yfiles[k] are the same tile.
Xfiles = sorted(glob('E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/All_Imgs/Imgs/*.tif'))
Yfiles = sorted(glob('E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/All_Imgs/Masks/*.tif'))

# Fail fast if the folders are out of sync instead of training on mismatched pairs.
# Stems (names without extension) are compared so the check still holds if images
# and masks ever use different file types.
assert len(Xfiles) == len(Yfiles), f"{len(Xfiles)} images but {len(Yfiles)} masks"
assert all(Path(x).stem == Path(y).stem for x, y in zip(Xfiles, Yfiles)), \
    "image and mask file names do not match"

In [None]:
# Load every image and mask into memory, preserving the sorted file order.
X = [imread(f) for f in Xfiles]
Y = [imread(f) for f in Yfiles]

In [None]:
# Keep every 100th image, starting with the first, to shrink the training set.
X = X[::100]

In [None]:
# Keep every 100th mask, starting with the first (same stride as the images).
Y = Y[::100]

In [None]:
# Sanity check: X and Y are subsampled in separate cells, so confirm that images
# and masks still pair up one-to-one before continuing.
assert len(X) == len(Y), f"{len(X)} images but {len(Y)} masks"
len(Y)

In [None]:
# A 2-D image is single-channel; otherwise the channel axis is assumed to be last.
if X[0].ndim == 2:
    n_channel = 1
else:
    n_channel = X[0].shape[-1]

In [None]:
# Report the detected number of input channels (1 for 2-D images).
print(f"{n_channel}")

Normalize the images and fill possible holes in the labels, then split the data into training and validation sets.

In [None]:
# Normalize over the spatial axes (0, 1) only, i.e. each channel separately.
# Set axis_norm = (0, 1, 2) instead to normalize all channels jointly.
axis_norm = (0, 1)
if n_channel > 1:
    joint = axis_norm is None or 2 in axis_norm
    print("Normalizing image channels %s." % ('jointly' if joint else 'independently'))
    sys.stdout.flush()

# Percentile-based normalization (low=1, high=99.8) of each image, then
# fill_label_holes on each label image.
X = [normalize(img, 1, 99.8, axis=axis_norm) for img in tqdm(X)]
Y = [fill_label_holes(lbl) for lbl in tqdm(Y)]

In [None]:
assert len(X) > 1, "not enough training data"

# Reproducible shuffle, then hold out ~15% of the images (at least one) for validation.
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
n_train = len(ind) - n_val
ind_train, ind_val = ind[:n_train], ind[n_train:]

X_val = [X[k] for k in ind_val]
Y_val = [Y[k] for k in ind_val]
X_trn = [X[k] for k in ind_train]
Y_trn = [Y[k] for k in ind_train]

summary = (
    ('number of images: %3d', len(X)),
    ('- training:       %3d', len(X_trn)),
    ('- validation:     %3d', len(X_val)),
)
for fmt, count in summary:
    print(fmt % count)

Set up hyperparameters for the StarDist model

In [None]:
# Number of rays per object; 32 is a good default choice (see 1_data.ipynb).
n_rays = 32

# Use the OpenCL-based (gputools) data generator during training when gputools is available.
use_gpu = gputools_available()

# Predict on a 4x4-subsampled grid for efficiency and a larger field of view.
grid = (4, 4)

conf = Config2D(
    n_rays=n_rays,
    grid=grid,
    use_gpu=use_gpu,
    n_channel_in=n_channel,
    train_patch_size=(512, 512),
    train_steps_per_epoch=100,
    train_epochs=400,
)

In [None]:
# Create a new StarDist2D model from `conf`, named 'BloodCells' under the 'models'
# base directory (the prediction section reloads it by the same name and basedir).
model = StarDist2D(conf, name='BloodCells', basedir='models')


Make sure that the network's field of view is larger than the median object size

In [None]:
# Compare the median object extent in the ground-truth labels with the network's
# field of view along each axis, and warn if any median extent exceeds the FOV.
median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")

Define the augmentations to use

In [None]:
def random_fliprot(img, mask):
    """Apply one random axis permutation and random flips to an image and its mask.

    The first ``mask.ndim`` axes of ``img`` are treated as the spatial axes shared
    with ``mask``; any trailing axes of ``img`` (e.g. channels) keep their position.

    Parameters
    ----------
    img : ndarray
        Input image; must satisfy ``img.ndim >= mask.ndim``.
    mask : ndarray
        Label image matching the leading axes of ``img``.

    Returns
    -------
    (ndarray, ndarray)
        The transformed image and mask (numpy views produced by transpose/flip).
    """
    assert img.ndim >= mask.ndim
    n_spatial = mask.ndim
    spatial_axes = tuple(range(n_spatial))
    order = tuple(np.random.permutation(spatial_axes))
    trailing_axes = tuple(range(n_spatial, img.ndim))
    img, mask = img.transpose(order + trailing_axes), mask.transpose(order)
    for axis in spatial_axes:
        # One draw per spatial axis; flip both arrays together when it exceeds 0.5.
        if not np.random.rand() > 0.5:
            continue
        img, mask = np.flip(img, axis=axis), np.flip(mask, axis=axis)
    return img, mask

def random_intensity_change(img):
    """Randomly rescale and shift image intensities.

    ``img`` is multiplied by a gain drawn from ``np.random.uniform(0.6, 2)`` and then
    an offset drawn from ``np.random.uniform(-0.2, 0.2)`` is added. The gain is
    drawn first. The result is a new array; ``img`` is not modified in place.
    """
    gain = np.random.uniform(0.6, 2)
    offset = np.random.uniform(-0.2, 0.2)
    return img * gain + offset


def augmenter(x, y):
    """Augment a single input/label image pair.

    In order: the same random flip/axis permutation is applied to both arrays
    (``random_fliprot``), the image gets a random gain/offset
    (``random_intensity_change``), and Gaussian noise with a standard deviation
    of ``0.02 * np.random.uniform(0, 1)`` is added to the image.

    Parameters
    ----------
    x : ndarray
        Input image.
    y : ndarray
        Ground-truth label image corresponding to ``x``.

    Returns
    -------
    (ndarray, ndarray)
        The augmented image and the correspondingly flipped/permuted label image.
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    noise_sigma = 0.02 * np.random.uniform(0, 1)
    noise = noise_sigma * np.random.normal(0, 1, x.shape)
    return x + noise, y

Train the model

To watch training progress live, activate the StarDist environment, open a terminal in the directory that contains the `models` folder, and run `tensorboard --logdir=.`. Instead of changing into that directory, you can also pass its path to `--logdir`.

In [None]:
# Train on the training split while monitoring the held-out validation split;
# `augmenter` is passed as the per-sample augmentation callback. Epoch count and
# steps per epoch come from `conf` (train_epochs / train_steps_per_epoch).
model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)

In [None]:
# Tune the detection thresholds on the validation split. Presumably these are the
# values the reloaded model later reads from 'thresholds.json' — confirm.
model.optimize_thresholds(X_val, Y_val)

In [None]:
# Predict instance labels for each validation image and keep only element 0 of
# what predict_instances returns (the label image). n_tiles comes from the
# model's private `_guess_n_tiles` helper.
Y_val_pred = []
for img in tqdm(X_val):
    tiles = model._guess_n_tiles(img)
    prediction = model.predict_instances(img, n_tiles=tiles, show_tile_progress=False)
    Y_val_pred.append(prediction[0])

In [None]:
import random
# Pick the index of one validation example at random for inspection in napari.
i = random.randrange(len(Y_val_pred))

In [None]:
# Show the validation image with its ground-truth and predicted instance masks.
# The masks are integer label images, so they go into Labels layers (a distinct
# colour per object); Image layers would render the label IDs as a grayscale ramp.
viewer = napari.view_image(X_val[i], name='img')
viewer.add_labels(Y_val[i], name='GT')
viewer.add_labels(Y_val_pred[i], name='Pred')

Prediction

In [2]:
# Re-list the image tiles from the same folder used for training. The subset taken
# below starts at offset 7, presumably so it does not overlap the training subset.
test_img_dir = 'E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/tiles/All_Imgs/Imgs'
Xfiles = sorted(glob(test_img_dir + '/*.tif'))

In [3]:
# Load every listed image into memory, in sorted file order.
X_test = [imread(path) for path in Xfiles]

In [4]:
# Keep every 100th image starting at index 7. Training kept indices 0, 100, 200, ...
# of the same sorted folder, so this offset selects tiles the model was not trained on.
X_test = X_test[7::100]

In [5]:
# Reload the trained 'BloodCells' model from the 'models' base directory. Passing
# None as the config loads the saved model; the output below shows the best
# weights and the tuned thresholds being read from disk.
model = StarDist2D(None, name='BloodCells', basedir='models')

Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.464147, nms_thresh=0.3.


In [7]:
# Derive the channel count from the data, as in the training section, instead of
# hard-coding it, and make sure it matches what the loaded model expects.
n_channel = 1 if X_test[0].ndim == 2 else X_test[0].shape[-1]
assert n_channel == model.config.n_channel_in, \
    f"images have {n_channel} channel(s) but the model expects {model.config.n_channel_in}"

axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

# Same percentile normalization (1, 99.8) as applied to the training images.
X_test = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X_test)]

Normalizing image channels independently.


100%|███████████████████████████████████████████████████████████████████████████████████████| 121/121 [00:10<00:00, 11.62it/s]


In [8]:
# Predict instance labels for each test image and keep only element 0 of what
# predict_instances returns (the label image). n_tiles comes from the model's
# private `_guess_n_tiles` helper.
Y_test = []
for img in tqdm(X_test):
    tiles = model._guess_n_tiles(img)
    prediction = model.predict_instances(img, n_tiles=tiles, show_tile_progress=False)
    Y_test.append(prediction[0])

  0%|                                                                                                 | 0/121 [00:00<?, ?it/s]functional.py (238): The structure of `inputs` doesn't match the expected structure.
Expected: ['input']
Received: inputs=Tensor(shape=(1, 1024, 1024, 3))
100%|███████████████████████████████████████████████████████████████████████████████████████| 121/121 [00:53<00:00,  2.24it/s]


In [None]:
import random
# Pick the index of one test example at random for inspection in napari.
i = random.randrange(len(Y_test))

In [None]:
# Show one test image with its predicted instance mask. This section loads no
# ground-truth masks, so only the prediction is overlaid. The mask is an integer
# label image, so it goes into a Labels layer (a distinct colour per object).
viewer = napari.view_image(X_test[i], name='img')
viewer.add_labels(Y_test[i], name='Pred')

In [10]:
# Save each predicted mask together with the image it was predicted from.
# NOTE: X_test holds the percentile-normalized images, not the raw tiles.
# File names keep the existing 'mask_0<i>' / 'img_0<i>' scheme (not zero-padded).
save_path = 'E:/Grainger_Lab/Amber/OIC-74_Zebrafish_RBC_Classification/StarDist_Test_Results'
os.makedirs(save_path, exist_ok=True)  # imsave does not create missing folders

assert len(X_test) == len(Y_test), f"{len(X_test)} images but {len(Y_test)} predictions"
for i, (img, mask) in enumerate(zip(X_test, Y_test)):
    sk.io.imsave(os.path.join(save_path, 'mask_0' + str(i) + '.tiff'), mask, check_contrast=False)
    sk.io.imsave(os.path.join(save_path, 'img_0' + str(i) + '.tiff'), img, check_contrast=False)