In [None]:
# Plot predicted/true centroids & volumes for one tile 
import random

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
from scipy import ndimage
from skimage.feature import peak_local_max
from skimage.measure import regionprops

# Shared 3-D axis limits used by all four panels of this figure.
x_max, y_max = 64, 64
z_min, z_max = 0, 32

# Which item of the predicted batch to visualise.
index = 0

fig = plt.figure(figsize=(15, 15))

# Predicted cell volumetric centroid coordinates: local maxima of the predicted
# inner-distance map. Columns follow the array axis order of
# inner_distance[index] (presumably z, y, x -- TODO confirm against model output).
coords = peak_local_max(
    np.squeeze(inner_distance[index]),
    exclude_border=False,
    threshold_abs=detection_threshold,
    min_distance=min_distance,
)

ax = fig.add_subplot(221, projection='3d')
ax.set_xlim3d(0, x_max)
ax.set_ylim3d(0, y_max)
ax.set_zlim3d(z_min, z_max)
test_plot = coords
# Column 1 goes on the plot's x-axis, column 2 on y, column 0 on z.
ax.scatter(*test_plot[:, [1, 2, 0]].T, c='b', marker='o', s=5)
ax.set_title('Predicted Centroids')

# Ground-truth cell volumetric centroid coordinates.
# The model was run on X_test[f_mov:l_mov], so batch item `index` pairs with
# y_test[index + f_mov].
gt_mask = y_test[index+f_mov, ..., 0]
# Per-axis voxel spacing for the anisotropic distance transform, in the same
# axis order as gt_mask (presumably z, y, x in physical units -- TODO confirm).
sampling = [0.5, 0.217, 0.217]

# NOTE(review): the EDT is taken over the whole label image, so each voxel's
# value is its distance to background only; boundaries between touching cells
# are not treated as edges, which can pull weighted centroids toward them.
gt_dist = ndimage.distance_transform_edt(gt_mask, sampling=sampling).astype('float32')

# regionprops requires an integer label image; the cast is a no-op (no copy)
# when y_test already holds int64 labels.
gt_props = regionprops(gt_mask.astype(np.int64, copy=False), intensity_image=gt_dist)

# Distance-weighted centroids rounded to voxel indices. reshape keeps the
# (n_cells, ndim) layout even for a tile with no cells, so the column
# indexing below does not fail on an empty result.
gt_coords = np.round(
    np.asarray([cell.weighted_centroid for cell in gt_props], dtype=float)
    .reshape(-1, gt_mask.ndim))

ax = fig.add_subplot(222, projection='3d')
ax.set_xlim3d(0, x_max)
ax.set_ylim3d(0, y_max)
ax.set_zlim3d(z_min, z_max)
test_plot = gt_coords
ax.scatter(test_plot[:, 1], test_plot[:, 2], test_plot[:, 0], c='b', marker='o', s=5)
ax.set_title('Ground-truth Centroids')



# Predicted and ground-truth cell volumes rendered as coloured voxels.

_HEX_DIGITS = '0123456789ABCDEF'
# Seeded so voxel colours are identical on every re-run of this cell.
_color_rng = random.Random(0)


def _label_facecolors(label_volume, rng):
    """Return an object array of per-voxel face colours for ``label_volume``.

    Every distinct non-zero label gets one random 8-digit hex colour
    ('#RRGGBBAA', alpha included); background voxels (label 0) get ``None``.
    Colours are looked up via ``np.unique(..., return_inverse=True)`` so the
    volume is scanned once rather than once per label.
    """
    labels, inverse = np.unique(label_volume, return_inverse=True)
    palette = np.array(
        [None if label == 0
         else '#' + ''.join(rng.choice(_HEX_DIGITS) for _ in range(8))
         for label in labels],
        dtype=object)
    # reshape handles both the flat and the input-shaped `inverse` that
    # different NumPy versions return for axis=None.
    return palette[inverse].reshape(label_volume.shape)


def _plot_label_volume(subplot, label_volume, title):
    """Add a 3-D subplot to ``fig`` and draw ``label_volume`` as voxels.

    Non-zero voxels are filled; axis limits match the centroid panels.
    Returns the new axis.
    """
    axis = fig.add_subplot(subplot, projection='3d')
    axis.set_xlim3d(0, x_max)
    axis.set_ylim3d(0, y_max)
    axis.set_zlim3d(z_min, z_max)
    axis.voxels(label_volume, facecolors=_label_facecolors(label_volume, _color_rng))
    axis.set_title(title)
    return axis


# Predicted cell volumes: take this tile's mask from the batch (the centroid
# panels also select by `index`) and move the leading axis (z, judging by how
# masks is sliced elsewhere) to the end, matching the scatter-plot axis order.
plot_masks = np.moveaxis(np.squeeze(masks[index]), 0, 2)
ax = _plot_label_volume(223, plot_masks, 'Predicted Cell Volumes')

# Ground truth cell volumes, same axis reordering.
plot_masks = np.moveaxis(np.squeeze(gt_mask), 0, 2)
ax = _plot_label_volume(224, plot_masks, 'Ground-truth Cell Volumes')

plt.show()

In [None]:
# Run metrics on above plotted small-volume tile.
# Both arrays get a leading batch axis of length 1. masks[index] is the
# prediction for X_test[index + f_mov] (the model ran on X_test[f_mov:l_mov]),
# so it pairs with y_test[index + f_mov]. Selecting by `index` (instead of
# rolling the whole batch) keeps this correct when the batch holds > 1 item.
y_pred = np.expand_dims(np.squeeze(masks[index]), 0)
y_true = np.expand_dims(np.squeeze(y_test[index+f_mov, 0:frames_per_batch, ...].copy()), 0)

print('Shape of y_pred is {} and shape of y_true is {}'.format(y_pred.shape, y_true.shape))
# Object metrics on mismatched volumes are meaningless; fail loudly instead.
assert y_pred.shape == y_true.shape, 'y_pred and y_true must have the same shape'

print('DeepWatershed - Remove no pixels')
m = Metrics('DeepWatershed - Remove no pixels', seg=False, ndigits=3,
            cutoff1=0.5, cutoff2=0.01, is_3d=True, round_output=True)
m.calc_object_stats(y_true, y_pred)
print('\n')


In [None]:
# make predictions on testing data
from timeit import default_timer

# Run the model on the single batch item X_test[f_mov:l_mov].
f_mov = 24
l_mov = f_mov+1
batch_size = 1

start = default_timer()
test_images = model.predict(X_test[f_mov:l_mov, 0:frames_per_batch, ...], batch_size=batch_size)
# This times only the network forward pass; the watershed post-processing
# (deep_watershed_3D below) is not included. The variable name is kept so
# existing references to it keep working.
watershed_time = default_timer() - start

# test_images[0] is the first model output (used as the inner distance below).
print('Model prediction of shape', test_images[0].shape, 'in', watershed_time, 'seconds.')


import time

from matplotlib import pyplot as plt
import numpy as np

from skimage.feature import peak_local_max
from skimage.morphology import remove_small_objects


# Post-processing parameters for deep_watershed_3D.
min_distance = 15           # minimum allowable distance between two centroid coords
detection_threshold = 0.01  # absolute threshold for minimum peak intensity (TODO - relative threshold)
distance_threshold = 0.1    # outer_distance threshold for cell border
small_objects_threshold = 0

# Turn the model's distance-map outputs into a labelled mask volume.
masks = deep_watershed_3D(
    test_images,
    small_objects_threshold=small_objects_threshold,
    exclude_border=False,
    distance_threshold=distance_threshold,
    detection_threshold=detection_threshold,
    min_distance=min_distance,
)

# The post-processing uses these internally; keep handles for visualisation.
inner_distance, outer_distance = test_images[0], test_images[1]

# One row of five panels for a single z-slice of the tile. The height is
# sized for a single row instead of the previous 20x20 canvas.
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

index = 0
# z-plane to display. Named z_slice so the built-in `slice` is not shadowed
# for the rest of the kernel session.
z_slice = 28
coords = peak_local_max(
    np.squeeze(inner_distance[index]),
    min_distance=min_distance,
    threshold_abs=detection_threshold,
    exclude_border=False)

# raw image with centroids. All detected centroids are drawn regardless of
# their z coordinate, i.e. they are projected onto this slice.
axes[0].imshow(X_test[index+f_mov, z_slice, ..., 0])
axes[0].scatter(coords[..., 2], coords[..., 1],
                color='r', marker='.', s=10)
axes[0].set_title('Raw image + centroids')

# raw image xz plane
#axes[0].imshow(X_test[index+f_mov, :, 150, :, 0])
#axes[0].scatter(coords[..., 2], coords[..., 0], color='r', marker='.', s=10)

axes[1].imshow(inner_distance[index, z_slice, ..., 0], cmap='jet')
axes[1].set_title('Inner distance')
axes[2].imshow(outer_distance[index, z_slice, ..., 0], cmap='jet')
axes[2].set_title('Outer distance')
axes[3].imshow(masks[index, z_slice, ...], cmap='jet')
axes[3].set_title('Predicted masks')
axes[4].imshow(np.squeeze(y_test[index+f_mov, z_slice, ...]), cmap='jet')
axes[4].set_title('Ground truth')

plt.show()


# Play movie of the predicted masks: insert a singleton axis right after the
# batch axis and a trailing channel axis, then animate batch item `index`.
vid_msk = masks[:, np.newaxis, ..., np.newaxis]
HTML(get_js_video(vid_msk[index], batch=0, channel=0, interval=500))
#HTML(get_js_video(y_test, batch=index+f_mov, channel=0, interval=500, vmax=y_test[index+f_mov, ..., 0].max()))