## Distributed Watershed

Implementation of a distributed watershed function using Dask, inspired by Juan Nunez-Iglesias's work [here](https://github.com/dask/dask-image/pull/99). The overall method relies on a two-pass watershed model. The first pass generates marker information and shares it across chunk boundaries. The second pass then propagates that information.

This implementation differs from v1. Whereas v1 shares the first-pass watershed's marker information directly, this version shares that information through a specially labeled chunk-boundary basin.

### Base Imports

In [ ]:
import dask
import dask.array as da
import numpy as np

### Helper Functions

In [ ]:
def compute_mem_mb(shape, itemsize=8):
    '''Return the memory footprint, in MB (2**20 bytes), of an array of ``shape``.

    Parameters
    ----------
    shape : iterable of int
        Array shape; an empty shape is treated as a single element.
    itemsize : int, optional
        Bytes per element. The default of 8 matches float64/int64 arrays;
        pass e.g. 1 for boolean arrays.
    '''
    from functools import reduce
    from operator import mul
    count = reduce(mul, shape, 1)
    return count * itemsize / (1024 ** 2)

def display(image):
    '''Render ``image`` in the notebook using scikit-image's ``imshow``/``show``.

    NOTE(review): this notebook also passes dask arrays here; presumably the
    plotting backend materialises them via NumPy conversion — confirm.
    '''
    import skimage.io as skio
    skio.imshow(image)
    skio.show()

def frac_to_zscore(frac):
    '''Return the standard-normal z-score whose lower-tail probability is ``frac``.

    This is the percent-point function (inverse CDF) of N(0, 1): ``frac=0.5``
    maps to 0 and values of ``frac`` closer to 1 map to larger z-scores.
    '''
    from scipy import stats
    inverse_cdf = stats.norm.ppf
    return inverse_cdf(frac)

def mask_overlap(chunk, depth, label):
    '''Return an array like ``chunk`` whose overlap border is set to ``label``.

    The result has the shape and dtype of ``chunk``. Every element within
    ``depth`` elements of either end of any axis is ``label``, and every other
    element is 0. The values stored in ``chunk`` are not read.

    Parameters
    ----------
    chunk : numpy.ndarray
        Block (including its overlap region) whose geometry is copied.
    depth : int, sequence of int, or dict
        Border width. An int applies to every axis. A sequence gives one width
        per axis. A dict maps axis -> width, and missing axes get 0. These
        mirror the depth forms accepted by ``dask.array.overlap.overlap``.
        A width <= 0 leaves that axis unmarked.
    label : scalar
        Value written into the border.

    Raises
    ------
    ValueError
        If a sequence ``depth`` does not have exactly one entry per axis.
    '''
    if isinstance(depth, dict):
        depths = [depth.get(axis, 0) for axis in range(chunk.ndim)]
    elif np.ndim(depth) == 0:
        depths = [depth] * chunk.ndim
    else:
        depths = list(depth)
        if len(depths) != chunk.ndim:
            raise ValueError(
                "depth has {} entries but chunk has {} dimensions".format(
                    len(depths), chunk.ndim))

    overlap = np.zeros_like(chunk)
    for axis, width in enumerate(depths):
        # Skip non-positive widths explicitly: a trailing slice ``[-0:]``
        # would select the entire axis rather than nothing.
        if width <= 0:
            continue
        leading = [slice(None)] * chunk.ndim
        trailing = [slice(None)] * chunk.ndim
        leading[axis] = slice(None, width)
        trailing[axis] = slice(-width, None)
        overlap[tuple(leading)] = label
        overlap[tuple(trailing)] = label
    return overlap

def build_full_markers(labels, depth, label):
    '''Overlay the positive entries of ``labels`` on a ``label``-valued border.

    The border is the ``depth``-wide overlap region produced by
    ``mask_overlap``. Positive entries of ``labels`` are kept everywhere,
    including inside the border. Every other entry becomes ``label`` inside
    the border and 0 elsewhere. The result has the dtype of ``labels``.
    '''
    border = mask_overlap(labels, depth, label)
    return np.where(labels > 0, labels, border)

def create_random_salt_image(fraction_salt, size):
    '''Return a boolean image of shape ``size`` with about ``fraction_salt`` True pixels.

    Each pixel is an independent standard-normal draw from NumPy's global
    RNG. A pixel is True when its draw exceeds the z-score whose upper-tail
    probability is ``fraction_salt``.
    '''
    threshold = frac_to_zscore(1 - fraction_salt)
    draws = np.random.normal(0.0, 1.0, size)
    return draws > threshold

def remove_label(chunk, label):
    '''Return a copy of ``chunk`` with every ``label`` element reset to 0.

    The input is left untouched and the result keeps the dtype of ``chunk``.
    '''
    hits = chunk == label
    return np.where(hits, np.zeros_like(chunk), chunk)

def compose(fp_chunk, sp_chunk, mask_chunk):
    '''Merge first- and second-pass labels for one block.

    Returns a copy of ``fp_chunk`` in which the positions selected by
    ``mask_chunk`` take their values from ``sp_chunk``. Neither input is
    modified.
    '''
    merged = fp_chunk.copy()
    selection = mask_chunk
    merged[selection] = sp_chunk[selection]
    return merged

### Image Geometry Definitions

In [ ]:
# Square test image: size_len pixels along each of the ndim axes.
ndim = 2
size_len = 200
size = [size_len for _ in range(ndim)]
mem_mb = compute_mem_mb(size)
print("2D array size (MB): {:.2f}".format(mem_mb))

In [ ]:
# Requested chunk edge length along each axis.
chunk_len = 400
# Dask clips chunk sizes to the array extent, so report the largest chunk the
# array will actually have rather than the requested (possibly larger) size.
# Passing this clipped list to da.from_array yields the same chunking as the
# unclipped request.
chunks = [min(chunk_len, axis_len) for axis_len in size]
chunk_mem_mb = compute_mem_mb(chunks)
print("Chunk size (MB): {:.2f}".format(chunk_mem_mb))
# NOTE(review): with chunk_len >= size_len the whole image is one chunk, so no
# internal chunk boundaries are exercised by the distributed passes — confirm
# this is intended.

### Random Seed

In [ ]:
# Seed NumPy's global RNG, which create_random_salt_image draws from, so the
# salt image (and everything derived from it) is reproducible.
seed = 1
np.random.seed(seed)

### Create Random "Salting" Image

In [ ]:
# Expected fraction of pixels that become salt (marker) pixels.
fraction_salt = 1e-2
salt = create_random_salt_image(fraction_salt, size)
# Optional experiment mentioned in the Validation notes: flip the salt image
# along axis 0 before processing.
#salt = np.flip(salt, axis=0)
display(salt)

### Determine EDT

In [ ]:
# Import from the public scipy.ndimage namespace; scipy.ndimage.morphology is
# a deprecated alias.
from scipy.ndimage import distance_transform_edt

# For each non-salt pixel, the Euclidean distance to the nearest salt pixel.
edt = distance_transform_edt(~salt)
max_edt = edt.max()
# Normalise to [0, 1]; max_edt is reused below to express h in pixel units.
edt = edt / max_edt
display(edt)

### Filter EDT Using H-Max/H-Dome

In [ ]:
from skimage.morphology import reconstruction

# h-dome style filtering: reconstruct (edt - h) by dilation underneath edt,
# which flattens maxima shallower than h. Because edt was divided by max_edt,
# h = 1 / max_edt corresponds to one pixel of the unnormalised distance.
h = 1 / max_edt
h_seed = edt - h
hmax = reconstruction(
    h_seed,
    edt,
    method='dilation',
)
display(hmax)

### Prepare Marker Image

In [ ]:
from skimage.measure import label

# Give each connected group of salt pixels its own integer marker id.
ws_markers = label(salt)

### Watershed Transform

In [ ]:
# watershed moved to skimage.segmentation in scikit-image 0.17, and the
# skimage.morphology alias was removed in 0.19; fall back only for old installs.
try:
    from skimage.segmentation import watershed
except ImportError:  # scikit-image < 0.17
    from skimage.morphology import watershed

# Single-array reference result that the distributed version is validated against.
ws = watershed(hmax, markers=ws_markers)
display(ws)

### Prepare Dask Client

In [ ]:
from dask.distributed import Client

# With no address given, Client() starts a local cluster; the scheduler's
# services include the diagnostic dashboard, whose port is looked up here.
c = Client()
port = c.scheduler_info()['services']['dashboard']
print(
    "Type `http://localhost:{port}` into the URL bar of your favorite browser to watch the following code in action on your machine in real time."
    .format(port=port)
)

### Create Dask Arrays

Here we assume that future users will have access to distributed versions of h-max filtering, the EDT, and connected-component labeling.

In [ ]:
# Overlap depth (in pixels) shared between neighbouring chunks.
depth = 1
hmax_da = da.from_array(hmax, chunks=chunks)
hmax_op = da.overlap.overlap(hmax_da, depth=depth, boundary='nearest')

# One past the largest real marker id; this value tags the chunk-boundary basin.
boundary_label = ws_markers.max() + 1
ws_markers_da = da.from_array(ws_markers, chunks=chunks)
# Outer image edges are padded with the constant boundary_label.
ws_markers_op = da.overlap.overlap(ws_markers_da, depth=depth, boundary=boundary_label)
# Paint each block's overlap border with boundary_label, keeping real (>0) markers.
ws_markers_op = ws_markers_op.map_blocks(lambda x: build_full_markers(x, depth, boundary_label), dtype=ws_markers_op.dtype)
display(ws_markers_op)

# fp = first pass
# watershed returns labels with the dtype of its markers, not of the input
# image, so declare the markers' dtype for the resulting dask array.
ws_fp = hmax_op.map_blocks(lambda x, y: watershed(x, markers=y), ws_markers_op, dtype=ws_markers_op.dtype)
display(ws_fp)

### Propagate Boundary Basin

In [ ]:
# Second-pass markers: the first-pass basins with the boundary basin cleared.
# Trimming then re-overlapping refreshes each block's overlap region with its
# neighbours' first-pass labels.
ws_markers_sp = da.overlap.overlap(
    da.overlap.trim_overlap(
        ws_fp.map_blocks(lambda x: remove_label(x, boundary_label)),
        depth=depth,
    ),
    depth=depth,
    boundary='nearest',
)
display(ws_markers_sp)

# Only pixels the first pass assigned to the boundary basin are re-flooded.
ws_mask = ws_fp == boundary_label
display(ws_mask)

ws_sp = hmax_op.map_blocks(
    lambda image, markers, mask: watershed(image, markers=markers, mask=mask),
    ws_markers_sp,
    ws_mask,
    dtype=ws_fp.dtype,
)
display(ws_sp)

### Compose First and Second Passes Using Mask

In [ ]:
# Keep first-pass labels except inside the boundary basin, which takes the
# second-pass labels; then trim the overlap to restore the original chunking.
ws_final = da.overlap.trim_overlap(
    ws_fp.map_blocks(compose, ws_sp, ws_mask, dtype=ws_fp.dtype),
    depth=depth,
)
display(ws_final)

### Validation

Note that the two methods assign some pixels to different basins. Pay particular attention to the error-free strip along the bottom quarter of the image: it seems to persist across seed changes, and even when the salt image is flipped along the vertical axis.

In [ ]:
# Boolean map that is True wherever the distributed result disagrees with the
# single-array watershed.
error = ws_final.compute() != ws
display(error)

error_count = error.sum()
print("Error count: {:d}".format(error_count))
print("Error fraction: {:.3%}".format(error_count / error.size))