# NOTE

**If you have not looked at the [regular example notebooks](../2D), please do so first.**  
The notebooks in this folder provide further details about the inner workings of StarDist and might be useful if you want to apply it in a slightly different context.

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from stardist import fill_label_holes, random_label_cmap
from stardist.models import Config2D, StarDist2D, StarDistData2D

np.random.seed(42)
lbl_cmap = random_label_cmap()

# Data

We assume that the data has already been downloaded via notebook [1_data.ipynb](1_data.ipynb).  
In general, training data (for input `X` with associated labels `Y`) can be provided via lists of numpy arrays, where each image can have a different size. Alternatively, a single numpy array can be used if all images have the same size.  
Input images can either be two-dimensional (single-channel) or three-dimensional (multi-channel) arrays, where the channel axis comes last. Label images need to be integer-valued.
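To make the expected layout concrete, here is a small sketch with synthetic arrays. It is an illustration only: `check_pair_demo` is a hypothetical helper that encodes the rules stated above (2D or channel-last 3D images, matching spatial size, integer labels). It is not a StarDist function. All names end in `_demo` so that running the cell does not overwrite `X`, `Y` or `n_channel` in this notebook.

```python
import numpy as np

def check_pair_demo(x, y):
    """Hypothetical helper: check one (image, label) pair against the format described above."""
    assert x.ndim in (2, 3), "image must be 2D (single-channel) or 3D (channel axis last)"
    assert y.ndim == 2, "label image must be 2D"
    assert x.shape[:2] == y.shape, "image and label must have the same spatial size"
    assert np.issubdtype(y.dtype, np.integer), "label image must be integer-valued"

rng_demo = np.random.default_rng(0)

# Lists allow a different size for every image
X_demo = [rng_demo.random((128, 96), dtype=np.float32),
          rng_demo.random((200, 160), dtype=np.float32)]
Y_demo = [np.zeros((128, 96), dtype=np.uint16),
          np.zeros((200, 160), dtype=np.uint16)]
Y_demo[0][10:30, 10:30] = 1  # each object instance gets its own integer id
Y_demo[0][50:70, 40:60] = 2  # (0 is background)
for x, y in zip(X_demo, Y_demo):
    check_pair_demo(x, y)

# Equally sized images can instead be stacked into one array (first axis = image index)
X_stacked_demo = np.stack([rng_demo.random((64, 64)) for _ in range(3)])

# A multi-channel image keeps its channel axis last; n_channel is derived as in the code below
x_rgb_demo = rng_demo.random((64, 64, 3))
check_pair_demo(x_rgb_demo, np.zeros((64, 64), dtype=np.int32))
n_channel_demo = 1 if x_rgb_demo.ndim == 2 else x_rgb_demo.shape[-1]
```

For the RGB example, `n_channel_demo` evaluates to 3, while a plain 2D image would give 1.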

In [None]:
X = sorted(glob('data/dsb2018/train/images/*.tif'))
Y = sorted(glob('data/dsb2018/train/masks/*.tif'))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))

In [None]:
X = list(map(imread,X))
Y = list(map(imread,Y))
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]

Normalize images and fill small label holes.

In [None]:
axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]
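To show what the choice of `axis_norm` changes, the sketch below re-implements percentile normalization in plain numpy. It is a sketch that assumes `csbdeep.utils.normalize` maps the `pmin`-th percentile to 0 and the `pmax`-th percentile to 1 without clipping. `percentile_normalize_demo` is a hypothetical name, and the library function may differ in details such as dtype handling. Computing the percentiles over `axis=(0,1)` gives each channel its own statistics. Pooling over `(0,1,2)` lets a bright channel dominate a dim one.

```python
import numpy as np

def percentile_normalize_demo(x, pmin=1, pmax=99.8, axis=None, eps=1e-20):
    """Sketch of percentile normalization: pmin-th percentile -> 0, pmax-th -> 1 (no clipping)."""
    mi = np.percentile(x, pmin, axis=axis, keepdims=True)
    ma = np.percentile(x, pmax, axis=axis, keepdims=True)
    return ((x - mi) / (ma - mi + eps)).astype(np.float32)

rng_demo = np.random.default_rng(0)
# Two-channel demo image whose channels have very different intensity ranges
x_demo = np.stack([rng_demo.random((64, 64)),           # channel 0: values in [0, 1)
                   100 * rng_demo.random((64, 64))],    # channel 1: values in [0, 100)
                  axis=-1)

x_indep_demo = percentile_normalize_demo(x_demo, axis=(0, 1))     # per-channel statistics
x_joint_demo = percentile_normalize_demo(x_demo, axis=(0, 1, 2))  # statistics pooled over channels
```

With independent normalization, both channels end up spanning roughly [0, 1]. With joint normalization, channel 0 is squeezed to values below about 0.01 because the pooled percentiles are dominated by channel 1.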

Split into train and validation datasets.

In [None]:
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] 
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))
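The split arithmetic above is easier to check when it is written as a function of the number of images alone. `train_val_split_demo` is a hypothetical helper that repeats the same steps as the cell above (seeded `RandomState` permutation, 15% held out, and at least one validation image). With 20 images, 3 are used for validation. With only 2 images, the `max(1, ...)` guard still reserves one image for validation even though `round(0.3)` is 0.

```python
import numpy as np

def train_val_split_demo(n, val_fraction=0.15, seed=42):
    """Hypothetical helper: same split logic as the cell above, on indices only."""
    assert n > 1, "need at least two images to hold one out for validation"
    ind = np.random.RandomState(seed).permutation(n)
    n_val = max(1, int(round(val_fraction * n)))  # never fewer than one validation image
    return ind[:-n_val], ind[-n_val:]

ind_train_demo, ind_val_demo = train_val_split_demo(20)  # 17 training, 3 validation
small_train_demo, small_val_demo = train_val_split_demo(2)  # 1 training, 1 validation
```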

Training data consists of pairs of input images and label instance images.

In [None]:
i = min(9, len(X)-1)
img, lbl = X[i], Y[i]
assert img.ndim in (2,3)
img = img if img.ndim==2 else img[...,:3]
plt.figure(figsize=(16,10))
plt.subplot(121); plt.imshow(img,cmap='gray');   plt.axis('off'); plt.title('Raw image')
plt.subplot(122); plt.imshow(lbl,cmap=lbl_cmap); plt.axis('off'); plt.title('GT labels')
None;

# Data for training StarDist

From the label instance image, all necessary data for training `StarDist2D` can be computed via `StarDistData2D`.  
Note that this is only for illustration, since this computation happens automatically when calling `StarDist2D.train` (see below).
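The "Object probability" target plotted below is, to our understanding, derived from a normalized Euclidean distance transform of each object. The target is 0 on the background and grows towards each object's center. The following brute-force sketch, `object_probability_demo`, is hypothetical and not StarDist's implementation. It computes, for every object pixel, the distance to the nearest in-image pixel that does not belong to the same object, and divides by that object's maximum. It uses O(N·M) memory per object and ignores how the real code treats the image border, so it is only meant for tiny examples.

```python
import numpy as np

def object_probability_demo(lbl):
    """Brute-force, per-object normalized Euclidean distance to the nearest pixel
    outside the object (0 on background, 1 at each object's most central pixel)."""
    coords = np.indices(lbl.shape).reshape(lbl.ndim, -1).T  # (N, 2) pixel coordinates
    flat = lbl.ravel()
    out = np.zeros(flat.shape, dtype=np.float32)
    for value in np.unique(flat):
        if value == 0:
            continue  # background keeps probability 0
        inside = flat == value
        outside_pts = coords[~inside]
        if len(outside_pts) == 0:
            continue  # object fills the whole image; nothing to measure against
        # distances from every object pixel to every non-object pixel, then the minimum
        diff = coords[inside][:, None, :] - outside_pts[None, :, :]
        d = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)
        out[inside] = d / d.max()
    return out.reshape(lbl.shape)

lbl_demo = np.zeros((5, 5), dtype=np.int32)
lbl_demo[1:4, 1:4] = 1  # a single 3x3 object
prob_demo = object_probability_demo(lbl_demo)
```

For the 3x3 square, the center pixel gets 1.0, the pixels on the object's rim get 0.5, and the background stays 0.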

## Without shape completion

With `shape_completion=False` (corresponding to `train_shape_completion=False` in `Config2D` below), the trained `StarDist2D` model will *not* predict completed shapes for partially visible cells at the image boundary. This is the default behavior.
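StarDist describes each object pixel by its distances to the object boundary along `n_rays` equally spaced directions. The plots below show the first of these rays. The sketch `star_distances_demo` is a hypothetical, simplified stand-in, and StarDist's own implementation differs in details such as the boundary correction. It marches along each ray in unit steps until it leaves the object or the image, and then reports the distance as halfway between the last inside sample and the first outside sample. The angle convention (ray 0 along +axis 0) is an assumption. Because a ray also stops at the image border, a cell cut off by the border only gets distances to its visible part. This corresponds to what a model trained without shape completion learns.

```python
import numpy as np

def star_distances_demo(lbl, i, j, n_rays=8):
    """Approximate distances from pixel (i, j) to the boundary of its object along n_rays rays."""
    value = lbl[i, j]
    dists = np.zeros(n_rays, dtype=np.float32)
    if value == 0:
        return dists  # background pixels have no object, hence all-zero distances
    h, w = lbl.shape
    for k in range(n_rays):
        phi = 2 * np.pi * k / n_rays
        dy, dx = np.cos(phi), np.sin(phi)  # assumed convention: ray 0 points along +axis 0
        t = 0
        while True:
            t += 1
            ii, jj = int(round(i + t * dy)), int(round(j + t * dx))
            # stop when leaving the image or entering a different label
            if not (0 <= ii < h and 0 <= jj < w) or lbl[ii, jj] != value:
                dists[k] = t - 0.5
                break
    return dists

lbl_square_demo = np.zeros((7, 7), dtype=np.int32)
lbl_square_demo[1:6, 1:6] = 1        # 5x5 object fully inside the image
lbl_edge_demo = np.zeros((5, 5), dtype=np.int32)
lbl_edge_demo[:, :3] = 1             # object cut off by the left image border
d_edge_demo = star_distances_demo(lbl_edge_demo, 2, 0, n_rays=4)
```

At the center of the 5x5 square, all four rays have length 2.5. For the border pixel `(2, 0)` of the cut-off object, the ray pointing left (index 3) ends after half a pixel, so the sketch only describes the visible part of the cell.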

In [None]:
np.random.seed(42)
data = StarDistData2D(X,Y,batch_size=1,n_rays=32,patch_size=(256,256),shape_completion=False)

In [None]:
(img,dist_mask), (prob,dist) = data[0]

fig, ax = plt.subplots(2,2, figsize=(12,12))
for a,d,cm,s in zip(ax.flat, [img,prob,dist_mask,dist], ['gray','magma','bone','viridis'],
                    ['Input image','Object probability','Distance mask','Distance (0°)']):
    a.imshow(d[0,...,0],cmap=cm)
    a.set_title(s)
plt.tight_layout()
None;

## With shape completion

With `shape_completion=True` (corresponding to `train_shape_completion=True` in `Config2D` below), the trained `StarDist2D` model will predict completed shapes for partially visible cells at the image boundary. For this to work, the image needs to be cropped. The crop is controlled by the `Config2D` parameter `train_completion_crop` (default 32), which should be chosen based on the size of the objects. Furthermore, it may be a good idea to increase `train_batch_size` to offset the reduced number of pixels per training patch due to cropping.
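The effect of cropping can be illustrated with a single ray. To our understanding, shape completion computes the distances on the uncropped patch and only then discards a border of `train_completion_crop` pixels. Rays of objects cut by the crop can therefore extend beyond it, whereas distances computed after cropping stop at the crop border. In the sketch below, `dist_along_axis1_demo` is a hypothetical single-ray simplification, and `crop_demo = 2` stands in for `train_completion_crop`. The last lines show the batch-size arithmetic, assuming the crop removes 32 pixels from *each* side of a 256×256 patch.

```python
import numpy as np

def dist_along_axis1_demo(lbl):
    """Simplified single ray: distance from each foreground pixel to the edge of
    its object along +axis 1 (StarDist uses n_rays directions instead)."""
    h, w = lbl.shape
    out = np.zeros(lbl.shape, dtype=np.float32)
    for i in range(h):
        for j in range(w):
            if lbl[i, j] == 0:
                continue
            t = 1
            # step right while still inside the image and the same object
            while j + t < w and lbl[i, j + t] == lbl[i, j]:
                t += 1
            out[i, j] = t - 0.5
    return out

crop_demo = 2  # stand-in for train_completion_crop (pixels removed per side)
lbl_demo = np.zeros((8, 8), dtype=np.int32)
lbl_demo[2:6, 3:7] = 1  # object extends past the right edge of the crop window
window = (slice(crop_demo, -crop_demo), slice(crop_demo, -crop_demo))

completed_demo = dist_along_axis1_demo(lbl_demo)[window]  # distances first, then crop
truncated_demo = dist_along_axis1_demo(lbl_demo[window])  # crop first: cut at the border

# Fraction of pixels kept per 256x256 patch after cropping 32 px from each side
kept_fraction_demo = (256 - 2 * 32) ** 2 / 256 ** 2  # 0.5625
equiv_batch_demo = 4 / kept_fraction_demo            # about 7.1 patches
```

At the right edge of the crop window, the completed distance is 1.5 while the truncated one is 0.5. Under the stated assumption, roughly 56% of each patch survives the crop. This suggests why the configuration below uses `train_batch_size=7` instead of 4 to see about the same number of pixels per batch.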

In [None]:
np.random.seed(42)
data = StarDistData2D(X,Y,batch_size=1,n_rays=32,patch_size=(256,256),shape_completion=True)

In [None]:
(img,dist_mask), (prob,dist) = data[0]

fig, ax = plt.subplots(2,2, figsize=(12,12))
for a,d,cm,s in zip(ax.flat, [img,prob,dist_mask,dist], ['gray','magma','bone','viridis'],
                    ['Input image','Object probability','Distance mask','Distance (0°)']):
    a.imshow(d[0,...,0],cmap=cm)
    a.set_title(s)
plt.tight_layout()
None;

# Training

A `StarDist2D` model is specified via a `Config2D` object.

In [None]:
print(Config2D.__doc__)

You can monitor the progress during training with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) by starting it from the current working directory:

    $ tensorboard --logdir=.

Then connect to [http://localhost:6006/](http://localhost:6006/) with your browser.


## Without shape completion

In [None]:
conf = Config2D(n_channel_in=n_channel, train_batch_size=4, train_shape_completion=False)
print(conf)
vars(conf)

In [None]:
model = StarDist2D(conf, name='stardist_no_shape_completion', basedir='models')

In [None]:
%%capture train_log
model.train(X_trn,Y_trn,validation_data=(X_val,Y_val))

In [None]:
# show train log
# train_log()

## With shape completion

In [None]:
conf = Config2D(n_channel_in=n_channel, train_batch_size=7, train_shape_completion=True)
print(conf)
vars(conf)

In [None]:
model = StarDist2D(conf, name='stardist_shape_completion', basedir='models')

In [None]:
%%capture train_log
model.train(X_trn,Y_trn,validation_data=(X_val,Y_val))

In [None]:
# show train log
# train_log()