# Segmentation with StarDist - Training
This notebook is adapted from the 2D example https://github.com/mpicbg-csbd/stardist/tree/master/examples/2D of the StarDist GitHub implementation.

Now we come to training. Here we will use two GPUs; the second one is used by gputools.

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
%load_ext autoreload
%autoreload 2

# Set the environment variables right away, before TensorFlow initializes CUDA,
# so that only the next two GPUs (IDs 1 and 2 in PCI bus order) are visible.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1,2
# the second of these GPUs is to be used by gputools
%env gputools_id_device=1

import sys
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D

np.random.seed(42)
lbl_cmap = random_label_cmap()

tf.config.list_physical_devices()

In [None]:
X_glob = sorted(glob('/extdata/readonly/f-prak-v15/e-coli-swarming/train/input/*.tif'))
Y_glob = sorted(glob('/extdata/readonly/f-prak-v15/e-coli-swarming/train/labels/*.tif'))
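# Label filenames contain one extra character right before the ".tif" extension;
# removing it yields the name of the matching input image.
def labelname(name):
    return name[:-5]+name[-4:]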
assert all(Path(x).name == labelname(Path(y).name) for x,y in zip(X_glob, Y_glob))

In [None]:
X = list(map(imread, X_glob))
Y = list(map(imread, Y_glob))
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]

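## Data preparation
Neural networks work best when the individual input values lie roughly in the range from -1 to 1. Our images, however, are 16-bit data, so they are converted to floating point and normalized per image: the 1st percentile of each image is mapped to 0 and the 99.8th percentile to 1. Because the intensity range is set by percentiles rather than by the minimum and maximum, a few hot or dead pixels do not distort the scaling. They are not removed, though: since csbdeep's `normalize` does not clip by default, they simply end up outside [0, 1]. The label images are not normalized; `fill_label_holes` only fills holes inside the labeled objects.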
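A minimal NumPy sketch of this percentile normalization, assuming it boils down to the affine map (x - p_low) / (p_high - p_low) without clipping. The helper `percentile_normalize` is hypothetical (csbdeep's `normalize` additionally supports an `axis` argument), and the synthetic 16-bit image with one hot and one dead pixel is made up for illustration.

```python
import numpy as np

def percentile_normalize(img, pmin=1.0, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th percentile to 1 (no clipping)."""
    img = np.asarray(img, dtype=np.float32)
    # cast the percentiles to float32 so the result stays float32
    lo, hi = np.percentile(img, [pmin, pmax]).astype(np.float32)
    return (img - lo) / (hi - lo + eps)

# Synthetic 16-bit image with one simulated hot and one simulated dead pixel
rng = np.random.default_rng(0)
img16 = rng.integers(1000, 5000, size=(64, 64)).astype(np.uint16)
img16[0, 0] = 65535  # hot pixel
img16[0, 1] = 0      # dead pixel

x = percentile_normalize(img16)
print(x.dtype, x[0, 0], x[0, 1])
```

The outliers do not influence the percentiles much, so the bulk of the pixels lands in [0, 1] while the hot pixel ends up far above 1 and the dead pixel below 0.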

In [None]:
axis_norm = (0,1)   # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()

X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]

This time we do not have to create the validation data by hand: a random 15% of the images (at least one) is held out for validation.

In [None]:
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] 
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

The training data consists of pairs of input images and label images with the object instances.

In [None]:
i = min(9, len(X)-1)
img, lbl = X[i], Y[i]
assert img.ndim in (2,3)
img = img if (img.ndim==2 or img.shape[-1]==3) else img[...,0]
plt.figure(figsize=(16,10))
plt.subplot(121); plt.imshow(img,cmap='gray');   plt.axis('off'); plt.title('Raw image')
plt.subplot(122); plt.imshow(lbl,cmap=lbl_cmap); plt.axis('off'); plt.title('GT labels')
None;

# Configuration

A `StarDist2D` model is specified via a `Config2D` object.

In [None]:
print(Config2D.__doc__)

In [None]:
# We saw earlier that the polygon approximation works well from 64 rays upward.
n_rays = 64

# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = True and gputools_available()

# Predict on subsampled grid for increased efficiency and larger field of view
grid = (2,2)

conf = Config2D(
    n_rays       = n_rays,
    grid         = grid,
    use_gpu      = use_gpu,
    n_channel_in = n_channel,
    train_tensorboard = False,
    train_batch_size = 8,
)
print(conf)
vars(conf)
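To get a feeling for why the number of rays matters, the following NumPy sketch builds a star-convex polygon from radial distances along `n_rays` equally spaced directions and compares its area with that of the circle it approximates. The helpers and the vertex convention (y, x) = center + d·(sin φ, cos φ) are assumptions for illustration; also, a circle is much easier to approximate than the rod-shaped E. coli cells in this data set.

```python
import numpy as np

def polygon_from_rays(center, dists):
    """Vertices (y, x) of a star-convex polygon from distances along equally spaced rays."""
    n = len(dists)
    phis = np.linspace(0, 2 * np.pi, n, endpoint=False)  # ray angles
    cy, cx = center
    return np.stack([cy + dists * np.sin(phis), cx + dists * np.cos(phis)], axis=1)

def polygon_area(verts):
    """Polygon area via the shoelace formula."""
    y, x = verts[:, 0], verts[:, 1]
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))

r = 10.0  # radius of the toy circular object
for n in (8, 16, 32, 64):
    verts = polygon_from_rays((0.0, 0.0), np.full(n, r))
    print(n, polygon_area(verts) / (np.pi * r ** 2))
```

For a circle the polygon is a regular n-gon with area (n/2)·r²·sin(2π/n), so the area ratio rises from about 0.90 with 8 rays to above 0.998 with 64 rays.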

**Note:** The trained `StarDist2D` model will *not* predict completed shapes for partially visible objects at the image boundary if `train_shape_completion=False` (which is the default option).

## Model name
All data belonging to the model is saved in `models/mystardist-1`. For later use, the model has to be loaded from there again.

In [None]:
model = StarDist2D(conf, name='mystardist-1', basedir='models')

Check if the neural network has a large enough field of view to see up to the boundary of most objects.

In [None]:
median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")
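The check above relies on `calculate_extents` and the private `model._axes_tile_overlap`. To illustrate the idea only, here is a NumPy sketch that measures the bounding-box size of every object and takes the per-axis median. It is an assumption that this is close to what `calculate_extents(Y, np.median)` reports, not its actual implementation; `median_object_extent`, the toy label image and the field-of-view value are hypothetical.

```python
import numpy as np

def median_object_extent(label_imgs):
    """Median bounding-box size per axis over all objects (label 0 = background)."""
    extents = []
    for lbl in label_imgs:
        for obj_id in np.unique(lbl):
            if obj_id == 0:
                continue
            coords = np.argwhere(lbl == obj_id)  # (n_pixels, ndim) pixel indices
            extents.append(coords.max(axis=0) - coords.min(axis=0) + 1)
    return np.median(np.asarray(extents), axis=0)

# Toy label image with three rectangular objects of size 3x5, 5x7 and 7x9
lbl = np.zeros((32, 32), dtype=np.uint16)
lbl[1:4, 1:6] = 1
lbl[10:15, 10:17] = 2
lbl[20:27, 20:29] = 3

median_size = median_object_extent([lbl])
fov = np.array([6, 6])  # hypothetical field of view (Y, X) in pixels
if np.any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")
```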

# Training

You can define a function/callable that applies augmentation to each batch of the data generator.  
Here we use an `augmenter` that applies random 90° rotations and flips (via axis permutations and reflections) as well as random intensity changes, which are typically sensible for (2D) microscopy images:

In [None]:
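def random_fliprot(img, mask):
    # Permute and flip only the spatial axes (those of the mask);
    # a trailing channel axis of the image, if present, stays in place.
    assert img.ndim >= mask.ndim
    axes = tuple(range(mask.ndim))
    perm = tuple(np.random.permutation(axes))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(perm)
    for ax in axes:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask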

def random_intensity_change(img):
    img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)
    return img


def augmenter(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    return x, y

You can disable augmentation by setting `augmenter = None`.

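We recommend monitoring the training progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard). Note that the configuration above sets `train_tensorboard = False`, so no TensorBoard logs are written; set it to `True` to enable them. You can then start TensorBoard in the shell from the current working directory like this: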

    $ tensorboard --logdir=.

Then connect to [http://localhost:6006/](http://localhost:6006/) with your browser.


In [None]:
model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augmenter)

# Threshold optimization

While the default values for the probability and non-maximum suppression (NMS) thresholds already yield good results in many cases, we still recommend adapting the thresholds to your data. The optimized threshold values are saved to disk and automatically loaded with the model.

In [None]:
model.optimize_thresholds(X_val, Y_val)

# Evaluation and Detection Performance

Besides the losses and metrics during training, we can also quantitatively evaluate the actual detection/segmentation performance on the validation data by considering objects in the ground truth to be correctly matched if there are predicted objects with overlap (here [intersection over union (IoU)](https://en.wikipedia.org/wiki/Jaccard_index)) beyond a chosen IoU threshold $\tau$.

The corresponding matching statistics (average overlap, accuracy, recall, precision, etc.) are typically of greater practical relevance than the losses/metrics computed during training (but harder to formulate as a loss function). 
The value of $\tau$ can range from 0 (even slightly overlapping objects count as correctly predicted) to 1 (only pixel-perfectly overlapping objects count); which $\tau$ to use depends on how precise the segmentation must be for the application at hand.

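Please see the Wikipedia page on [Sensitivity and specificity](https://en.wikipedia.org/wiki/Sensitivity_and_specificity) for definitions of the abbreviations used in the evaluation below. Since true negatives are not defined for object detection, StarDist's `accuracy` is computed as tp / (tp + fp + fn). Note that `mean_true_score` is the summed overlap (IoU) of all true positives (tp), i.e. objects correctly predicted in terms of the chosen overlap threshold, divided by the number of ground-truth objects; `mean_matched_score` averages the same sum over the true positives only.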
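To make these statistics concrete, here is a small self-contained sketch on a toy example. It computes the IoU between every ground-truth and every predicted object and then matches pairs greedily in order of decreasing IoU. This greedy matching is a simplification: StarDist's matching uses an optimal one-to-one assignment, but for $\tau > 0.5$ each object can exceed the threshold with at most one partner, so both approaches agree there. The helper names are hypothetical, and the formulas for `accuracy` and `mean_true_score` follow our understanding of StarDist's definitions; check `stardist.matching` for your installed version.

```python
import numpy as np

def iou_matrix(y_true, y_pred):
    """IoU between every ground-truth object (rows) and predicted object (columns)."""
    t_ids = [i for i in np.unique(y_true) if i != 0]  # 0 = background
    p_ids = [i for i in np.unique(y_pred) if i != 0]
    iou = np.zeros((len(t_ids), len(p_ids)))
    for r, t in enumerate(t_ids):
        t_mask = y_true == t
        for c, p in enumerate(p_ids):
            p_mask = y_pred == p
            inter = np.logical_and(t_mask, p_mask).sum()
            if inter > 0:
                iou[r, c] = inter / np.logical_or(t_mask, p_mask).sum()
    return iou

def safe_div(a, b):
    return a / b if b > 0 else 0.0

def match_stats(y_true, y_pred, tau=0.5):
    """Greedy one-to-one matching at IoU threshold tau plus summary statistics."""
    iou = iou_matrix(y_true, y_pred)
    n_true, n_pred = iou.shape
    used_t, used_p, scores = set(), set(), []
    if iou.size:
        # Visit candidate pairs from highest to lowest IoU.
        order = np.argsort(iou, axis=None)[::-1]
        for r, c in zip(*np.unravel_index(order, iou.shape)):
            s = iou[r, c]
            if s <= 0 or s < tau:
                break  # all remaining pairs have an even lower IoU
            if r in used_t or c in used_p:
                continue
            used_t.add(r)
            used_p.add(c)
            scores.append(s)
    tp = len(scores)
    fp, fn = n_pred - tp, n_true - tp
    return dict(
        tp=tp, fp=fp, fn=fn,
        precision=safe_div(tp, tp + fp),
        recall=safe_div(tp, tp + fn),
        accuracy=safe_div(tp, tp + fp + fn),
        f1=safe_div(2 * tp, 2 * tp + fp + fn),
        mean_true_score=safe_div(sum(scores), n_true),
        mean_matched_score=safe_div(sum(scores), tp),
    )

# Toy example: GT object 1 is predicted with IoU 12/20 = 0.6, GT object 2 is
# missed, and the prediction also contains one spurious object.
y_true = np.zeros((10, 10), dtype=np.int32)
y_true[1:5, 1:5] = 1
y_true[6:9, 6:9] = 2
y_pred = np.zeros_like(y_true)
y_pred[1:5, 2:6] = 1
y_pred[6:8, 0:2] = 2

for tau in (0.5, 0.7):
    print(tau, match_stats(y_true, y_pred, tau))
```

At $\tau = 0.5$ the single overlapping pair counts as a true positive, while at $\tau = 0.7$ it no longer does, which illustrates why the statistics are plotted as a function of $\tau$ further below.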

First predict the labels for all validation images:

In [None]:
Y_val_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]
              for x in tqdm(X_val)]

Choose several IoU thresholds $\tau$ that might be of interest and for each compute matching statistics for the validation data.

In [None]:
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = [matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False) for t in tqdm(taus)]

Example: Print all available matching statistics for $\tau=0.5$

In [None]:
stats[taus.index(0.5)]

Plot the matching statistics and the number of true/false positives/negatives as a function of the IoU threshold $\tau$. 

In [None]:
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

for m in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score'):
    ax1.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax1.set_xlabel(r'IoU threshold $\tau$')
ax1.set_ylabel('Metric value')
ax1.grid()
ax1.legend()

for m in ('fp', 'tp', 'fn'):
    ax2.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax2.set_xlabel(r'IoU threshold $\tau$')
ax2.set_ylabel('Number #')
ax2.grid()
ax2.legend();

In [None]:
ind = 0
lbl_cmap = random_label_cmap()
#lbl_cmap = "magma"
input_val = X_val[ind]
gt_val = Y_val[ind]
pred_val = Y_val_pred[ind]
slx = slice(None)
sly = slice(None)
# Let's look at the results.
plt.figure(figsize=(16,8))
plt.subplot(1,3,1)
plt.imshow(input_val[sly, slx], cmap="magma")
plt.title('Input');
plt.subplot(1, 3, 2)
plt.imshow(pred_val[sly, slx], cmap=lbl_cmap)
plt.title('Prediction');
plt.subplot(1, 3, 3)
plt.imshow(gt_val[sly, slx], cmap=lbl_cmap)
plt.title('Labels');