# Postprocessing Ilastik Segmentation - 3 Classes
This notebook post-processes a three-class Ilastik segmentation:

- Class/Label 1: Cell centers
- Class/Label 2: Background
- Class/Label 3: Cell-edges

Cell edges should be marked by a line about 1 pixel wide that traces the edge of the cell.  
Cell centers should label the cell interior.

We will use the cell-center class to find markers (center points) for the watershed.  
The total cell mask is taken from the sum of classes 1 and 3, i.e. P(pixel=cell) = P(pixel=edge) + P(pixel=cell center).
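To make the channel arithmetic concrete, here is a minimal NumPy sketch on random, made-up data. It assumes the Ilastik export has axes (t, c, y, x), that channel index = label number - 1 (as the `fg_idx`/`edge_idx` settings further down imply), and that the three class probabilities of each pixel sum to one.

```python
import numpy as np

# Hypothetical Ilastik-style probability array with axes (t, c, y, x):
# channel 0 = cell center (Label 1), 1 = background (Label 2), 2 = edge (Label 3).
rng = np.random.default_rng(0)
raw = rng.random((2, 3, 4, 4))
prob = raw / raw.sum(axis=1, keepdims=True)  # per-pixel class probabilities sum to 1

fg_idx, edge_idx = 0, 2
p_cell = prob[:, fg_idx] + prob[:, edge_idx]  # P(cell) = P(center) + P(edge)

# With three exhaustive classes this equals 1 - P(background).
print(np.allclose(p_cell, 1 - prob[:, 1]))  # True
print(p_cell.shape)  # (2, 4, 4)
```

Summing the two cell classes, rather than computing 1 - P(background), keeps the code independent of the background channel index.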

This is what the training labels should look like:
![training labels](Ilastik_3ClassSeg_TrainingPoints.png)

This is what the Ilastik prediction should look like:
![ilastik output](Ilastik_3ClassSeg_Output.png)

```python
#next two lines make sure that Matplotlib plots are shown properly in Jupyter Notebook
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

#next line starts the Qt event loop that the Napari viewer needs
%gui qt

#main data analysis packages
import numpy as np
import pandas as pd

#image processing packages
from scipy import ndimage as ndi
import skimage.segmentation as segmentation
import skimage.filters as filters
from skimage.measure import label, regionprops, regionprops_table
from skimage import morphology
from skimage.future import graph  #in scikit-image >= 0.20 the RAG tools live in skimage.graph instead

#data plotting packages
import matplotlib
import matplotlib.pyplot as plt
#set default figure size
matplotlib.rc("figure", figsize=(10,5))
import seaborn as sns

#image viewer
import napari
from napari.utils.notebook_display import nbscreenshot

#out-of-memory computation
from dask_image.imread import imread
import dask.array as da

#path handling
import pathlib

#file handling
import h5py

#instead of dask_image.imread.imread() you can also use tifffile.imread() to read images directly into memory
#import tifffile
```

```python
#initiate a cache for Dask to speed up repeated computation (important when working with Napari)
#dask.cache requires the cachey package
from dask.cache import Cache
cache = Cache(2e9)  # leverage two gigabytes of memory
cache.register()    # turn cache on globally
```

---

## Import and visualize Ilastik output
### Load images

```python
#set the path to the folder that contains the project data
root = pathlib.Path(pathlib.Path.home(), 'Andreas', 'Training')

#raw images
im_name0 = 'ph_training.tif' #name of the phase-contrast image
im_name1 = 'gfp_training.tif' #name of the fluorescence (GFP) image

#Ilastik output
segment_name = 'training_pr_Probabilities.h5' #name of the exported probability maps

#Ilastik settings (channel index = Ilastik label number - 1)
fg_idx = 0 #index of the cell-center class (Label 1)
edge_idx = 2 #index of the cell-edge class (Label 3)

im0_stack = imread(root / im_name0) #load image with dask-image for out-of-memory processing
im1_stack = imread(root / im_name1) #load image with dask-image for out-of-memory processing
```

To visualize the data we will use [Napari](https://napari.org), which allows interactive image visualization. 
You can use the slider at the bottom of the window to scroll through time. 

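```python
#set up a Napari viewer with a 2-channel image
viewer = napari.view_image(im0_stack, name="phase", colormap="gray")
viewer.add_image(im1_stack, name="gfp", colormap="red", opacity=0.5)
#napari.run() returns immediately when %gui qt is active; it is needed when running as a script
napari.run()
```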

### Load and visualize Ilastik output
Next we load the data exported by Ilastik.

```python
seg_path = root / segment_name #path to Ilastik output
seg_data = h5py.File(seg_path, 'r') #open read-only; keep the file open, Dask reads from it lazily

#we again use Dask to load the data out of memory
#technical comment: we set the chunk size equal to the size of a single frame (all 3 classes),
#assuming the Ilastik export uses (t, c, y, x) axis order
chunk_size = (1, 3, *im0_stack.shape[-2:])
seg_prob = da.from_array(seg_data['exported_data'], chunks=chunk_size)
```
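The later steps apply per-frame functions with `da.map_blocks`, which is why each chunk holds exactly one frame. A small sketch on an in-memory stand-in array (shape and values are made up; the (t, c, y, x) order is the same assumption the indexing in this notebook makes):

```python
import numpy as np
import dask.array as da

# Stand-in for seg_data['exported_data']: 5 frames, 3 classes, 64 x 64 pixels,
# stored in (t, c, y, x) order (assumed export layout).
exported = np.zeros((5, 3, 64, 64), dtype=np.float32)

chunk_size = (1, 3, *exported.shape[-2:])  # one full frame (all classes) per chunk
seg_prob = da.from_array(exported, chunks=chunk_size)
print(seg_prob.numblocks)  # (5, 1, 1, 1): one block per time point

# Selecting one class keeps one frame per block, so per-frame functions can be mapped.
p_edge = seg_prob[:, 2, :, :]
print(p_edge.chunks[0])  # (1, 1, 1, 1, 1)
```

If the export uses a different axis order (it can be configured in Ilastik's export settings), both the chunk tuple and the `[:, idx, :, :]` indexing need to change accordingly.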

---

## Convert Ilastik probability map into segmentation
Ilastik assigns to each pixel the probability that it belongs to each class. We need to convert this into an instance segmentation, where each cell is assigned a unique label. This requires a number of steps:
1. Pre-process probability map using filters
2. Convert probability map to semantic segmentation (binary image where cells=1 and background=0) using thresholding
3. Clean up semantic segmentation using morphological operations
4. Convert semantic segmentation into instance segmentation (integer image where each cell has its own label)
5. Post-process instance segmentation by separating merged cells using the watershed algorithm
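The five steps can be tried end to end on a single synthetic frame. This sketch uses plain NumPy arrays instead of Dask and two made-up, overlapping disks as "cells"; the radii and thresholds are illustrative assumptions, not tuned values.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage import filters, measure, morphology, segmentation
from skimage.draw import disk

# Synthetic single-frame probability maps: two touching round cells.
shape = (60, 80)
p_cell = np.zeros(shape)
p_center = np.zeros(shape)
for center in [(30, 28), (30, 52)]:
    rr, cc = disk(center, 14, shape=shape)
    p_cell[rr, cc] = 0.9
    rr, cc = disk(center, 5, shape=shape)
    p_center[rr, cc] = 0.9

# 1. Smooth the probability maps
p_cell = filters.gaussian(p_cell, sigma=1)
p_center = filters.gaussian(p_center, sigma=1)

# 2. Threshold into binary (semantic) masks
mask = p_cell > 0.5
markers = p_center > 0.4

# 3. Morphological clean-up (hole / object areas in pixels)
mask = morphology.remove_small_holes(mask, 40)
mask = morphology.remove_small_objects(mask, 50)

# 4. Instance labels: connected components of the mask and of the markers
n_components = measure.label(mask).max()
marker_labels = measure.label(markers)

# 5. Watershed on the inverted distance transform splits the merged cells
dist = ndi.distance_transform_edt(mask)
cells = segmentation.watershed(-dist, markers=marker_labels, mask=mask)

print(n_components, cells.max())  # 1 2
```

Thresholding alone yields one connected component for the two touching cells; seeding the watershed with the two center markers splits it into two labels.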

---

### 1. Pre-process probability map using filters

As a first step, probability maps are often smoothed with a Gaussian blur filter of small size (sigma of ~1 pixel) to ensure that they are locally smooth. We will use scikit-image's [`filters.gaussian`](https://scikit-image.org/docs/stable/api/skimage.filters.html) to do this. 

```python
#keep the edge probability unsmoothed
p_edge = seg_prob[:,edge_idx,:,:]

sigma = 1 #standard deviation of the Gaussian kernel (in pixels)

#smooth total cell probability: P(cell) = P(center) + P(edge)
#channel_axis=0 prevents smoothing along the first (time) axis
p_cell = seg_prob[:,fg_idx,:,:] + seg_prob[:,edge_idx,:,:]
p_cell = da.map_blocks(filters.gaussian, p_cell, sigma, channel_axis=0)

#smooth cell center probability
p_center = da.map_blocks(filters.gaussian, seg_prob[:,fg_idx,:,:], sigma, channel_axis=0)

#add probability layers to Napari viewer
prop_layer = viewer.add_image(p_center, name='p_center',colormap='gray')
prop_layer = viewer.add_image(p_edge, name='p_edge',colormap='gray')
prop_layer = viewer.add_image(p_cell, name='p_cell',colormap='gray')
```

----
### 2. Convert probability map to semantic segmentation using thresholding
Here we test which thresholds work well for the cell centers and for the full cell masks.

```python
#create array with all threshold values to try, here we use 0,0.01,0.02,...,1
thresholds_to_try = np.linspace(0,1,101)

#combine the thresholded 3D stacks (t, y, x) into a single 4D stack (threshold, t, y, x)
marker_stack = da.stack([p_center > t for t in thresholds_to_try], axis=0)
mask_stack = da.stack([p_cell > t for t in thresholds_to_try], axis=0)

#add to viewer; the new threshold axis appears as an extra slider
mask_layer_int = viewer.add_image(mask_stack, name='cell masks',colormap='gray')
marker_layer_int = viewer.add_image(marker_stack, name='cell markers',colormap='gray')
```
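The stacked arrays gain a leading threshold axis. The sketch below reproduces the shapes on a random toy map and adds an optional numeric aid: the fraction of foreground pixels per threshold (the name `fg_fraction` is ours), which can help narrow down the range worth inspecting in Napari.

```python
import numpy as np

# Toy probability map (t, y, x) with values in [0, 1); made up for illustration.
rng = np.random.default_rng(1)
p_center = rng.random((4, 32, 32))

thresholds_to_try = np.linspace(0, 1, 101)
marker_stack = np.stack([p_center > t for t in thresholds_to_try], axis=0)
print(marker_stack.shape)  # (101, 4, 32, 32)

# Fraction of pixels above each threshold; it can only decrease as the threshold grows.
fg_fraction = marker_stack.mean(axis=(1, 2, 3))
```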


```python
#enter thresholds chosen by inspecting the threshold stacks in Napari (or computed automatically)
tr_marker = 0.4
tr_mask = 0.5

#threshold cell markers (center points) and cell masks
markers = p_center > tr_marker
mask = p_cell > tr_mask
```
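One automatic option is Otsu's method via `skimage.filters.threshold_otsu`. The sketch applies it to a synthetic, clearly bimodal map; this notebook did not test whether Otsu works well on real Ilastik probabilities, so treat it as a starting point to compare against the Napari inspection.

```python
import numpy as np
from skimage import filters

# Toy bimodal "probability map": background around 0.1, cells (30% of pixels) around 0.85.
rng = np.random.default_rng(2)
p_cell = np.where(rng.random((64, 64)) < 0.3, 0.85, 0.1)
p_cell = np.clip(p_cell + rng.normal(0, 0.05, p_cell.shape), 0, 1)

# Otsu picks the threshold that maximizes the between-class variance of the histogram.
tr_mask = filters.threshold_otsu(p_cell)
mask = p_cell > tr_mask
print(round(float(tr_mask), 2), round(float(mask.mean()), 2))
```

On a Dask stack, `threshold_otsu` needs an in-memory array, for example one representative frame: `filters.threshold_otsu(p_cell[0].compute())`.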

----
### 3. Clean up semantic segmentation using morphological operations

```python
max_hole_size = 40 # maximum area of holes that will be filled (in pixels)
min_cell_size = 50 # minimum area of objects to keep (in pixels)

#clean up cell masks: fill small holes, then remove small objects
mask = da.map_blocks(morphology.remove_small_holes, mask, max_hole_size)
mask = da.map_blocks(morphology.remove_small_objects, mask, min_cell_size)

#add mask to Napari
mask_layer_clean = viewer.add_image(mask, name='mask cleaned',colormap='gray')
```
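To see what the two clean-up functions do with these parameter values, here is a toy mask (made up for illustration) with a 4-pixel hole inside a cell and an isolated 9-pixel speck:

```python
import numpy as np
from skimage import morphology

mask = np.zeros((30, 30), dtype=bool)
mask[5:17, 5:17] = True     # a 12 x 12 "cell" (144 pixels)
mask[10:12, 10:12] = False  # hole of 4 pixels
mask[24:27, 24:27] = True   # isolated object of 9 pixels

max_hole_size = 40  # fill holes up to this area (pixels)
min_cell_size = 50  # drop objects smaller than this area (pixels)

cleaned = morphology.remove_small_holes(mask, max_hole_size)
cleaned = morphology.remove_small_objects(cleaned, min_cell_size)

print(cleaned[10:12, 10:12].all(), cleaned[24:27, 24:27].any())  # True False
print(cleaned.sum())  # 144
```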

### 4. Convert semantic segmentation into instance segmentation

```python
#convert binary markers into labeled markers (one connected region = one label, per frame)
marker_labels = da.map_blocks(label, markers)

#add markers to Napari
nap_marker_labels = viewer.add_labels(marker_labels, name='marker_labels')
```
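Because `label` is mapped over blocks that each contain one frame, label numbers restart in every frame: label 1 in frame 0 and label 1 in frame 1 are not necessarily the same cell. A small sketch on two synthetic frames:

```python
import numpy as np
import dask.array as da
from skimage.measure import label

# Two frames with two separate blobs each; one frame per Dask block.
markers_np = np.zeros((2, 10, 10), dtype=bool)
markers_np[:, 1:3, 1:3] = True
markers_np[:, 6:8, 6:8] = True
markers = da.from_array(markers_np, chunks=(1, 10, 10))

# label runs independently on each block, so numbering restarts in every frame.
marker_labels = da.map_blocks(label, markers).compute()
print([int(frame.max()) for frame in marker_labels])  # [2, 2]
```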

---
### 5. Watershed algorithm

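```python
#Dask's map_blocks only splits arrays passed as positional arguments, so this wrapper
#passes markers and mask on to skimage's watershed as keyword arguments.
#The distance map is negated so that cell centers (far from the mask edge) become the basins.
def watershed(dist, markers, mask):
    return segmentation.watershed(-dist, markers=markers, mask=mask)
```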

```python
#calculate the distance of each mask pixel to the edge of the mask
dist_transform = da.map_blocks(ndi.distance_transform_edt, mask)
nap_dist = viewer.add_image(dist_transform, name='dist_trans')

#run watershed, one frame per block
watershed_labels = da.map_blocks(watershed, dist_transform, marker_labels, mask, chunks=(1,*mask.shape[-2:]))

#add to Napari
nap_watershed_labels = viewer.add_labels(watershed_labels, name='after watershed')
```
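The wrapper and the `map_blocks` calls can be checked on a toy stack in which each frame contains two overlapping disks (made-up geometry) and one seed pixel per disk. This sketch passes `dtype` for the distance transform explicitly and otherwise mirrors the cell above:

```python
import numpy as np
import dask.array as da
from scipy import ndimage as ndi
from skimage import segmentation
from skimage.draw import disk

def watershed(dist, markers, mask):
    # Same wrapper as above: flood from the markers over -dist, restricted to mask.
    return segmentation.watershed(-dist, markers=markers, mask=mask)

# Two frames, each with two overlapping disks that form a single blob.
shape = (50, 70)
mask_np = np.zeros((2, *shape), dtype=bool)
markers_np = np.zeros((2, *shape), dtype=np.int32)
for t in range(2):
    for i, center in enumerate([(25, 24), (25, 46)], start=1):
        rr, cc = disk(center, 13, shape=shape)
        mask_np[t, rr, cc] = True
        markers_np[t, center[0], center[1]] = i  # one seed pixel per cell

chunks = (1, *shape)
mask = da.from_array(mask_np, chunks=chunks)
marker_labels = da.from_array(markers_np, chunks=chunks)

dist_transform = da.map_blocks(ndi.distance_transform_edt, mask, dtype=np.float64)
watershed_labels = da.map_blocks(watershed, dist_transform, marker_labels, mask,
                                 chunks=chunks)

result = watershed_labels.compute()
print([np.unique(frame).tolist() for frame in result])  # [[0, 1, 2], [0, 1, 2]]
```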

---
### SKIP: RAG Recombine Cells

[Region Adjacency Graphs](https://scikit-image.org/docs/stable/api/skimage.future.graph.html) could potentially be used to correct the over-segmentation that can result from the watershed.
The parameters below need optimization; if the watershed worked well, this step can be skipped and cells can be merged manually.
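A small sketch of the RAG idea on a toy over-segmentation (one "cell" split into labels 1 and 2), using `rag_boundary` and `cut_threshold` with the 0.4 threshold from the cell below. `rag_boundary` weights each graph edge by the mean edge probability along the shared boundary; `cut_threshold` removes edges with weight at or above the threshold and merges the regions that remain connected. The import fallback reflects our assumption about version differences (the RAG tools moved from `skimage.future.graph` to `skimage.graph` in scikit-image 0.20).

```python
import numpy as np

try:  # scikit-image >= 0.20
    from skimage.graph import rag_boundary, cut_threshold
except ImportError:  # older releases keep the RAG tools in skimage.future.graph
    from skimage.future.graph import rag_boundary, cut_threshold

# Toy over-segmentation: one cell split into labels 1 and 2 along column 10.
labels = np.ones((20, 20), dtype=np.int64)
labels[:, 10:] = 2

# Edge probability maps: weak vs. strong edge evidence along the shared boundary.
weak_edge = np.full(labels.shape, 0.1)
strong_edge = weak_edge.copy()
strong_edge[:, 9:11] = 0.9

n_regions = {}
for name, edge_map in [("weak", weak_edge), ("strong", strong_edge)]:
    rag = rag_boundary(labels, edge_map, connectivity=2)
    merged = cut_threshold(labels, rag, 0.4, in_place=False)
    n_regions[name] = len(np.unique(merged))

print(n_regions)  # {'weak': 1, 'strong': 2}
```

`cut_threshold` renumbers the regions, so label values before and after merging do not correspond.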

```python
#load into memory now
# watershed_labels = watershed_labels.compute()
# p_edge = p_edge.compute()
```

```python
# labels_merged = np.empty_like(watershed_labels)

# for idx, (lab, edg) in enumerate(zip(watershed_labels, p_edge)):
#     rag = graph.rag_boundary(lab, edg, connectivity=2)
#     labels_merged[idx,:,:] = graph.cut_threshold(lab, rag, 0.4, in_place=False) #used 0.88 for large set

#add to Napari
#nap_merged_labels = viewer.add_labels(labels_merged, name='after merging')
```

---
## Store segmentation data

```python
#convert to 16-bit integers to save space (sufficient while no frame has more than 32767 labels)
watershed_labels = watershed_labels.astype('int16')
outname = root / 'processed_labels.hdf5'

#store as hdf5; the with-block closes the file when done
with h5py.File(outname, 'w') as h5f:
    h5f.create_dataset('raw_labels', data=watershed_labels)
```
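`h5py` converts the Dask array into a single in-memory NumPy array before writing it. For long movies, `dask.array.to_hdf5` writes block by block instead; the sketch below does that for a toy label stack in a temporary folder and reads the file back to check dtype and shape (the data and file location are made up).

```python
import pathlib
import tempfile

import dask.array as da
import h5py
import numpy as np

# Hypothetical label stack standing in for watershed_labels (3 frames, 32 x 32).
labels = da.from_array(np.arange(3 * 32 * 32).reshape(3, 32, 32) % 7, chunks=(1, 32, 32))
labels16 = labels.astype("int16")  # astype returns a new array, so assign the result

with tempfile.TemporaryDirectory() as tmp:
    outname = pathlib.Path(tmp) / "processed_labels.hdf5"
    # to_hdf5 computes and writes the stack block by block.
    da.to_hdf5(outname, "/raw_labels", labels16)

    with h5py.File(outname, "r") as f:
        stored = f["raw_labels"][...]

print(stored.dtype, stored.shape)  # int16 (3, 32, 32)
```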

---
## Manual Correction

See the notebook `manual_correct_segment`.