# Mathematical morphology: the distance transform


### 1. Overview & learning objectives
In this notebook we will develop a method to identify one point per cell. These points will later serve as seeds for region-growing cell segmentation algorithms.

With this notebook we will:

1. Learn about the distance transform as a morphological method to measure distances in an image.

1. Use the local maxima of the distance transform to identify one point per object to be segmented.

1. Overlay labeled and grayscale images.

### 2. The distance transform
In previous lectures, we have used a combination of filtering, local thresholding, and mathematical morphology to generate a mask of cell edges. Our ultimate goal is to find one point per cell that we will *grow* using a segmentation algorithm. Run this code to reproduce what we did so far:


In [None]:
import matplotlib.pyplot as plt
import numpy
from skimage import exposure, filters, io, morphology

# Read image from disk.
animage = io.imread('cells.tif')

# Gaussian smoothing to facilitate edge detection.
animage_smooth = filters.gaussian(animage, sigma=2, preserve_range=True)

# Contrast stretch.
animage_rescaled = exposure.rescale_intensity(
    animage_smooth, out_range=numpy.uint8)

# Local threshold.
amask = animage_rescaled >= filters.threshold_local(
    animage_rescaled, 33, method='gaussian')

# Morphological closing to fill small gaps in the mask.
amask_closed = morphology.binary_closing(amask, morphology.disk(3))

# Display image and mask.
fig, axs = plt.subplots(1, 2)
axs[0].imshow(animage, cmap='Greys_r')
axs[1].imshow(amask_closed, cmap='binary_r')
plt.show()


Now that we have a good representation of interfacial pixels, let's find pixels **away** from interfaces. The **distance transform** of a binary image is an image in which the value of each pixel is the distance from that pixel to the nearest background pixel. One way to calculate the distance transform is to erode the image repeatedly until there are no "on" pixels left: the distance transform value of a pixel is the number of the erosion step in which the pixel disappeared.
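The erosion-based definition can be sketched on a toy image. The helper name `erosion_distance_transform` and the toy mask are illustrative assumptions, not part of the lecture code. The 3x3 structuring element corresponds to N8 connectivity, so the result should match the chessboard distance computed by `scipy.ndimage.distance_transform_cdt`.

```python
import numpy
import scipy.ndimage as ndimage


def erosion_distance_transform(mask):
    """Distance transform by repeated erosion (illustrative helper)."""
    remaining = numpy.asarray(mask, dtype=bool).copy()
    dt_erosion = numpy.zeros(remaining.shape, dtype=int)

    # 3x3 block of ones: erosion with N8 connectivity.
    structelem = numpy.ones((3, 3), dtype=bool)

    step = 0
    while remaining.any():
        step += 1
        eroded = ndimage.binary_erosion(remaining, structure=structelem)
        # Pixels that disappear in this erosion get the step number.
        dt_erosion[remaining & ~eroded] = step
        remaining = eroded

    return dt_erosion


# Toy mask: a 5x5 square of "on" pixels surrounded by background.
toy_mask = numpy.zeros((7, 7), dtype=bool)
toy_mask[1:6, 1:6] = True

toy_dt = erosion_distance_transform(toy_mask)
toy_cdt = ndimage.distance_transform_cdt(toy_mask, metric='chessboard')

print(toy_dt)
print(numpy.array_equal(toy_dt, toy_cdt))
```

Note that `binary_erosion` treats pixels outside the image as background by default, which is why the toy mask keeps a background border: objects touching the image edge could otherwise give different values in the two approaches.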

We can use the distance transform to identify pixels far from cell interfaces. We will use a function in the scipy.ndimage module to calculate the distance transform, **distance_transform_edt**. scikit-image can also return a distance transform (skimage.morphology.medial_axis with return_distance=True), but that function is designed to compute the medial axis of the objects, and using it here is overkill.
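Before applying the function to the mask, it helps to check which pixels it measures from. The toy array below is an assumption made for illustration: it shows that `distance_transform_edt` assigns to every nonzero pixel its Euclidean distance to the nearest zero pixel, and leaves zero pixels at 0.

```python
import numpy
import scipy.ndimage as ndimage

# Toy image: all "on" except for a single background pixel in the center.
toy = numpy.ones((5, 5), dtype=bool)
toy[2, 2] = False

toy_edt = ndimage.distance_transform_edt(toy)

# The background pixel is 0, its horizontal neighbour is 1 pixel away,
# and its diagonal neighbour is sqrt(2) pixels away.
print(numpy.round(toy_edt, 2))
```

In other words, the pixels whose distance we want must be nonzero in the input, and the pixels we measure the distance to must be zero.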

Use distance_transform_edt to calculate the distance transform of the mask above. Remember, we are trying to find the distance from any given pixel to the cell interfaces ...

In [None]:
# DELETE THIS CODE.

import scipy.ndimage as ndimage

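# distance_transform_cdt computes integer distances. With the chessboard metric
# (instead of taxicab), it is equivalent to a distance transform implemented
# through sequential erosions with N8 connectivity.
# distance_transform_edt computes exact Euclidean distances, so a diagonal
# neighbour is sqrt(2) (about 1.41) pixels away rather than 1.
# Distances are measured from nonzero pixels to the nearest zero pixel, so we
# invert the mask to make the interfaces zero.
dt = ndimage.distance_transform_edt(numpy.invert(amask_closed))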

fig, axs = plt.subplots(1, 2)
axs[0].imshow(animage, cmap='Greys_r')
axs[1].imshow(dt, cmap='Greys_r')
plt.show()


### 3. Finding and displaying seeds
Our goal is to obtain one and only one seed point per cell, and the distance transform offers several ways to get there. To visualize seeds, let's first define a function, **plot_seeds**, that plots a list of points over an image. The function also returns an image with the seeds.

In [None]:
def plot_seeds(theimage, theseed_coords):
    """
        plot_seeds: plots a set of points on an image.
        
        input:
            theimage: ndarray representing the image.
            theseed_coords: ndarray with two columns and as many rows as points to be displayed, 
                            containing the [y, x] coordinates of each point.     
            
        output:
            seed_image: ndarray representing a labeled image with one object per seed.
    """

    # Create empty seed image (integer labels, 0 is background).
    seed_image = numpy.zeros(theimage.shape, dtype=int)

    # For each seed ...
    for label, seed_yx in enumerate(theseed_coords):
        # ... set the corresponding pixel to a unique label (coordinates are [y, x]).
        seed_image[seed_yx[0], seed_yx[1]] = label + 1

    # Display the image.
    plt.imshow(theimage, cmap='gray')
    
    # Create a structuring element to dilate (aka grow) seeds a bit for display.
    structelem = morphology.disk(3)
    
    # And overlay the seeds.
    plt.imshow(morphology.dilation(seed_image, structelem), cmap='jet', alpha=0.50)
    plt.show()
    
    return seed_image

To obtain one point per cell, we will extract the local maxima of the distance transform. There are other options: for example, one could threshold the distance transform to keep only the pixels that are **at least** a certain distance away from cell interfaces (you could think about how to implement that solution, perhaps using an adaptive threshold).
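As a starting point for the thresholding alternative, here is a minimal sketch on a toy mask. The mask, the cutoff `min_distance_to_edge`, and the use of centroids as seeds are all assumptions made for illustration. A fixed cutoff is used instead of an adaptive threshold.

```python
import numpy
import scipy.ndimage as ndimage

# Toy mask with two square "cells" separated by background.
toy_mask = numpy.zeros((7, 15), dtype=bool)
toy_mask[1:6, 1:6] = True
toy_mask[1:6, 9:14] = True

toy_dt = ndimage.distance_transform_edt(toy_mask)

# Hypothetical cutoff: keep pixels at least 2 pixels away from the background.
min_distance_to_edge = 2
cores = toy_dt >= min_distance_to_edge

# Each connected core becomes one candidate seed region.
core_labels, n_cores = ndimage.label(cores)

# Reduce each core to a single point: its centroid, rounded to a pixel.
centroids = ndimage.center_of_mass(
    cores, core_labels, numpy.arange(1, n_cores + 1))
toy_seeds = numpy.rint(centroids).astype(int)

print(n_cores)
print(toy_seeds)
```

A single cutoff rarely suits cells of different sizes: small cells may lose their core entirely, and the core of a large, irregular cell may split into several pieces. The centroid of a non-convex core can also fall outside the core.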

In scikit-image, the skimage.feature module includes the function **peak_local_max**, which can be applied to the distance transform (in combination with the mask that we generated) to obtain one maximum per cell.
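To see what peak_local_max returns before using it on real data, here is a small sketch on a toy mask with two square objects (the mask is an assumption for illustration). Passing a labeled image through `labels` together with `num_peaks_per_label=1` restricts the search to at most one maximum per labeled region.

```python
import numpy
import scipy.ndimage as ndimage
from skimage import feature, measure

# Toy mask with two square "cells" separated by background.
toy_mask = numpy.zeros((7, 15), dtype=bool)
toy_mask[1:6, 1:6] = True
toy_mask[1:6, 9:14] = True

toy_dt = ndimage.distance_transform_edt(toy_mask)

# One integer label per connected object.
toy_labels = measure.label(toy_mask)

# Rows of [y, x] coordinates, at most one per label.
toy_peaks = feature.peak_local_max(
    toy_dt, labels=toy_labels, num_peaks_per_label=1, exclude_border=False)

print(toy_peaks)
```

The returned array has one row per peak with [y, x] coordinates, which is the format that plot_seeds expects.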

Using the documentation of **peak_local_max** and **plot_seeds**, extract and visualize a set of seeds for our image. What are the critical parameters to obtain an accurate number of seeds? What are the limitations of peak_local_max?

In [None]:
# DELETE THIS CODE
import skimage.feature as feature

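# Label each region enclosed by interfaces, then keep a single maximum per region.
cell_labels = morphology.label(numpy.invert(amask_closed))
coords_maxima = feature.peak_local_max(
    dt, labels=cell_labels, num_peaks_per_label=1, exclude_border=False)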

seed_image = plot_seeds(animage, coords_maxima)