In [None]:
# Raj lab OG pic (png) reader

def _get_3D_images_from_directory(data_location, channel_names, image_size=(50,512,512), dtype='float32'):
    """Read all images from directory with channel_name in the filename.

    Files are listed per channel with ``nikon_getfiles(data_location, channel)``
    and the i-th file of every channel is combined into one multi-channel
    z-stack. Images in this dataset have different dimensions along all three
    axes, so each one is copied into the low-index corner of a zero-filled
    volume of size ``image_size``.

    Args:
        data_location (str): folder containing image files
        channel_names (str[]): list of wildcards to select filenames. The
            order of this list is the order of the output channels, and the
            number of z-stacks is taken from the first channel's file list.
        image_size (tuple): ``(frames, x, y)`` size of the zero-padded volume.
            Every image must fit inside it.
        dtype (str): unused; the output arrays always use ``K.floatx()``.
            Kept only for backwards compatibility.

    Returns:
        list: one numpy array per z-stack, shaped
        ``(1, n_channels, frames, x, y)`` when the Keras data format is
        ``channels_first`` and ``(1, frames, x, y, n_channels)`` otherwise.
        An empty list is returned when ``channel_names`` is empty.

    Raises:
        ValueError: if a channel has fewer files than the first channel, or
            if an image is larger than ``image_size`` along any axis.
    """
    data_format = K.image_data_format()
    n_channels = len(channel_names)

    img_list_channels = [nikon_getfiles(data_location, channel) for channel in channel_names]
    n_stacks = len(img_list_channels[0]) if img_list_channels else 0

    # Fail early with a readable message instead of an IndexError mid-way
    # through reading the stacks.
    for channel, files in zip(channel_names, img_list_channels):
        if len(files) < n_stacks:
            raise ValueError(
                'Channel {!r} has {} files but channel {!r} has {}; every '
                'channel needs one file per z-stack.'.format(
                    channel, len(files), channel_names[0], n_stacks))

    # The padded volume shape is identical for every stack, so compute it once
    # directly from image_size (no throwaway array allocation needed).
    frames, rows, cols = image_size
    if data_format == 'channels_first':
        shape = (1, n_channels, frames, rows, cols)
    else:
        shape = (1, frames, rows, cols, n_channels)
    float_dtype = K.floatx()

    all_images = []
    for stack_iteration in range(n_stacks):
        all_channels = np.zeros(shape, dtype=float_dtype)

        for j in range(n_channels):
            img_path = os.path.join(data_location, img_list_channels[j][stack_iteration])
            channel_img = get_image(img_path)

            # Images in this dataset have different dimensions along all 3 axes:
            # copy each one into the corner of the padded volume, rest stays 0.
            f_dim, x_dim, y_dim = channel_img.shape[:3]
            if f_dim > frames or x_dim > rows or y_dim > cols:
                raise ValueError(
                    'Image {} has shape {}, which does not fit in image_size '
                    '{}.'.format(img_path, channel_img.shape, tuple(image_size)))

            if data_format == 'channels_first':
                all_channels[0, j, :f_dim, :x_dim, :y_dim] = channel_img
            else:
                all_channels[0, :f_dim, :x_dim, :y_dim, j] = channel_img

        all_images.append(all_channels)

    return all_images

# Retrieve data and format into a numpy array of dims (batch, z, x, y, channels)
# (that layout assumes the Keras data format is channels_last).
# Note: channels 1 and 2 are input channels, channel 3 is nuclear annotations

path_to_data = '/deepcell_data/users/geneva/raj_organoid_npzs/corrected/'
channel_names = ['dapi', 'gfp', 'nuclei']

raw_img_list = _get_3D_images_from_directory(path_to_data, channel_names)
assert raw_img_list, 'No z-stacks found in {}'.format(path_to_data)
print('number of z_stacks in data is: ', len(raw_img_list))
print('shape of each z_stack is: ', raw_img_list[0].shape)

# Every element already has a leading axis of size 1, so concatenating along
# axis 0 yields (batch, ...) directly. Unlike np.squeeze, this keeps the batch
# axis when there is only one stack and never drops other length-1 axes.
raw_img_array = np.concatenate(raw_img_list, axis=0).astype('float32', copy=False)
print('final data shape is: ', raw_img_array.shape)

In [None]:
# Remove tiles in X_train and y_train that do not contain at least one cell.
# A tile is kept when any value of its label mask y_train[i] is > 0.
# X_train / y_train are only read (never reassigned), so re-running is safe.
assert len(X_train) == len(y_train), 'X_train and y_train must have the same number of tiles'

y_train_arr = np.asarray(y_train)

# Reduce over every axis except the leading tile axis, so this works for 2D
# or 3D tiles, with or without a channel axis.
has_cell = np.any(y_train_arr > 0, axis=tuple(range(1, y_train_arr.ndim)))

X_tiles = np.asarray(X_train)[has_cell]
y_tiles = y_train_arr[has_cell]

print('After removing empty tiles, shape of X_tiles is {} and shape of y_tiles is {}'.format(X_tiles.shape, y_tiles.shape))

In [None]:
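# Rescale images along the z-axis - important for calculating min_distance between centroids.
# NOTE: this cell is currently disabled (all code below is commented out). Before enabling it,
# note that the scale tuple is defined as `scales` while the rescale calls pass `scale=scale`.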
#from skimage.transform import rescale

#scales = (1.0, 2.3, 1.0, 1.0, 1.0)

#for semantic_head in range(len(output_images)):
#    output_images[semantic_head] = rescale(output_images[semantic_head], scale=scale)

#y_assess = rescale(y_assess, scale=scale)

In [None]:
# Plotting individual scaled cell volumes
# Each labeled cell is cropped to its bounding box (plus `padding` voxels on
# every side) and drawn as a 3D voxel plot, two cells per row.

padding = 5
num_cols = 2

# Move the first axis to position 2 (same as np.rollaxis(a, 0, 3)).
# Presumably the squeezed masks are (z, y, x), giving (y, x, z) here -- TODO confirm.
plot_masks = np.moveaxis(np.squeeze(masks), 0, 2)

# One regionprops pass over the whole label image gives one entry per
# non-zero label, so background is skipped even if no voxel is 0.
cell_props = regionprops(plot_masks)
num_cells = len(cell_props)
num_rows = int(np.ceil(num_cells / num_cols))

fig = plt.figure(figsize=(15, 30))

for plot_idx, prop in enumerate(cell_props, start=1):
    y_lo, x_lo, z_lo, y_hi, x_hi, z_hi = prop.bbox

    # Padded crop around this cell; only the lower bounds need clamping
    # because slicing past the end of an array is safe.
    crop = (slice(max(0, y_lo - padding), y_hi + padding),
            slice(max(0, x_lo - padding), x_hi + padding),
            slice(max(0, z_lo - padding), z_hi + padding))

    # Fill only this cell's voxels; neighbours inside the padded box stay hidden.
    cell_voxels = plot_masks[crop] == prop.label

    # (nrows, ncols, index) works for any number of subplots, unlike the packed
    # 3-digit integer form, which requires every value to be < 10.
    ax = fig.add_subplot(num_rows, num_cols, plot_idx, projection='3d')
    ax.voxels(cell_voxels, facecolors=color_dict[prop.label])

plt.show()
    

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import scipy.ndimage as nd

from skimage.feature import peak_local_max, corner_peaks
from skimage.measure import label
from skimage.morphology import watershed, remove_small_objects
from skimage.segmentation import relabel_sequential

from deepcell_toolbox.utils import erode_edges

def deep_watershed_3D(outputs,
                      min_distance=10,
                      detection_threshold=0.1,
                      distance_threshold=0.01,
                      exclude_border=False,
                      small_objects_threshold=0):
    """Postprocessing function for 3D deep watershed models.

    Finds peaks of the inner distance prediction to use as cell centroids,
    which seed a marker-based watershed of the outer distance prediction.

    Args:
        outputs (list): DeepWatershed model output. A list of
            [inner_distance, outer_distance, fgbg]. Only channel 0 of the
            first two entries is used. Each entry is assumed to be shaped
            (batch, z, x, y, channels) -- TODO confirm.

            - inner_distance: Prediction for the inner distance transform.
            - outer_distance: Prediction for the outer distance transform.
            - fgbg: Prediction for the foreground/background transform
              (currently unused).

        min_distance (int): Minimum allowable distance between two cell
            centroids, passed to ``peak_local_max``.
        detection_threshold (float): Minimum inner distance value for a peak
            to be accepted as a centroid (``threshold_abs``).
        distance_threshold (float): Only voxels whose outer distance is
            greater than this value are included in the watershed mask.
        exclude_border (bool): Whether to exclude centroid detections
            at the border (passed to ``peak_local_max``).
        small_objects_threshold (int): Removes objects smaller than this size.

    Returns:
        numpy.array: Uniquely labeled masks, one per batch entry, stacked
        along a new leading axis. Labels in each mask are made sequential
        with ``relabel_sequential``. An empty batch yields an array with a
        leading axis of length 0.
    """
    # NOTE(review): newer scikit-image releases only provide `watershed` from
    # skimage.segmentation; confirm the skimage.morphology import matches the
    # installed version.
    inner_distance_batch = outputs[0][:, ..., 0]
    outer_distance_batch = outputs[1][:, ..., 0]

    label_images = []
    for batch in range(inner_distance_batch.shape[0]):
        inner_distance = inner_distance_batch[batch]
        outer_distance = outer_distance_batch[batch]

        coords = peak_local_max(inner_distance,
                                min_distance=min_distance,
                                threshold_abs=detection_threshold,
                                exclude_border=exclude_border)

        # Mark every peak, then label connected peaks so that adjacent peaks
        # (e.g. a flat-topped maximum) merge into a single marker. A bool
        # array is used because `label` is documented for int/bool input.
        markers = np.zeros(inner_distance.shape, dtype=bool)
        coords = np.asarray(coords)
        if len(coords):
            # tuple(coords.T) indexes one coordinate column per axis, for any ndim.
            markers[tuple(coords.T)] = True
        markers = label(markers)

        label_image = watershed(-outer_distance,
                                markers,
                                mask=outer_distance > distance_threshold)
        # NOTE(review): presumably trims a 1-pixel edge from each label;
        # see deepcell_toolbox.utils.erode_edges.
        label_image = erode_edges(label_image, 1)

        # Remove small objects
        label_image = remove_small_objects(label_image, min_size=small_objects_threshold)

        # Relabel the label image so labels are consecutive
        label_image, _, _ = relabel_sequential(label_image)

        label_images.append(label_image)

    if not label_images:
        # np.stack cannot stack an empty list; return an empty batch instead.
        return np.zeros((0,) + tuple(inner_distance_batch.shape[1:]), dtype=np.int32)

    return np.stack(label_images, axis=0)

                                                                  