In [418]:
import os
import errno
import pandas as pd
import numpy as np
import deepcell
import math
import scipy.signal


from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from deepcell.utils.io_utils import *
from deepcell.utils.plot_utils import get_js_video
from IPython.display import HTML
from tensorflow.python.keras import backend as K
from itertools import product

try:
    from deepcell_toolbox import utils
except:
    import utils

In [2]:
def _get_3D_images_from_directory(data_location, channel_names, image_size=(50,512,512), dtype='float32'):
    """Read all images from directory with channel_name in the filename

    Args:
        data_location (str): folder containing image files
        channel_names (str[]): list of wildcards to select filenames

    Returns:
        numpy.array: numpy array of each image in the directory
    """
    data_format = K.image_data_format()
    img_list_channels = []
    for channel in channel_names:
        img_list_channels.append(nikon_getfiles(data_location, channel))

    #img_temp = np.asarray(get_image(os.path.join(data_location, img_list_channels[0][0])))
    img_temp = np.zeros(image_size, dtype)

    n_channels = len(channel_names)
    all_images = []

    for stack_iteration in range(len(img_list_channels[0])):

        if data_format == 'channels_first':
            shape = (1, n_channels, img_temp.shape[0], img_temp.shape[1], img_temp.shape[2])
        else:
            shape = (1, img_temp.shape[0], img_temp.shape[1], img_temp.shape[2], n_channels)

        all_channels = np.zeros(shape, dtype=K.floatx())

        for j in range(n_channels):
            img_path = os.path.join(data_location, img_list_channels[j][stack_iteration])
            channel_img = get_image(img_path)

            # Images in this dataset have different dimensions along all 3 axes
            # 
            f_dim = channel_img.shape[0] ##
            x_dim = channel_img.shape[1] ##
            y_dim = channel_img.shape[2] ##
            
            if data_format == 'channels_first':
                all_channels[0, j, :f_dim, :x_dim, :y_dim] = channel_img
            else:
                all_channels[0, :f_dim, :x_dim, :y_dim, j] = channel_img

        all_images.append(all_channels)

    return all_images

In [3]:
# NOTE: absolute, machine-specific data location.
path_to_data = '/images/kevin-data/RajLab_Organoid3'
channel_names = ['dapi', 'gfp', 'nuclei']

raw_img_list = _get_3D_images_from_directory(path_to_data, channel_names)
print('number of z_stacks in data is: ', len(raw_img_list))
print('shape of each z_stack is: ', raw_img_list[0].shape)

# Join the per-stack (1, z, x, y, c) arrays along the leading axis. Unlike
# np.squeeze, this never drops a genuine size-1 axis (a single channel or a
# single stack), so the result stays rank 5 as tile_image_3D requires.
raw_img_array = np.concatenate(raw_img_list, axis=0).astype('float32', copy=False)
print('final data shape is: ', raw_img_array.shape)

number of z_stacks in data is:  18
shape of each z_stack is:  (1, 50, 512, 512, 3)
final data shape is:  (18, 50, 512, 512, 3)


In [538]:
def _tile_axis_layout(image_size, tile_size, stride_ratio):
    """Return (stride, rep_number, pad) describing how one axis is tiled."""
    # Stride is rounded up to an even number and never exceeds the tile.
    stride = min(int(np.ceil(stride_ratio * tile_size / 2.0) * 2), tile_size)
    rep_number = max(int(np.ceil((image_size - tile_size) / stride + 1)), 1)

    # A lone tile overlaps nothing: give it a full-tile stride so its recorded
    # overlaps are (0, 0) and untiling does not taper its edges.
    if rep_number == 1:
        stride = tile_size

    # Total padding needed so rep_number tiles at this stride cover the axis.
    overlap = tile_size + stride * (rep_number - 1) - image_size
    pad = (int(np.ceil(overlap / 2)), int(np.floor(overlap / 2)))
    return stride, rep_number, pad


def _tile_axis_bounds(index, rep_number, stride, tile_size, padded_size):
    """Return (start, end, overlap) of tile ``index`` along one padded axis."""
    if index != rep_number - 1:
        start = index * stride
    else:
        # The last tile is flush with the end of the padded axis.
        start = padded_size - tile_size
    end = start + tile_size

    # (left, right) number of samples shared with neighbouring tiles.
    if index == 0:
        overlap = (0, tile_size - stride)
    elif index == rep_number - 2:
        overlap = (tile_size - stride, tile_size - padded_size + end)
    elif index == rep_number - 1:
        overlap = ((index - 1) * stride + tile_size - start, 0)
    else:
        overlap = (tile_size - stride, tile_size - stride)
    return start, end, overlap


def tile_image_3D(image, model_input_shape=(10, 256, 256), stride_ratio=0.5):
    """
    Tile a large 3D image into many overlapping tiles of size "model_input_shape".

    The image is zero-padded on each of the z, x and y axes so that tiles laid
    out at a fixed stride cover it exactly. Tiles are emitted in
    batch -> x -> y -> z order (z varies fastest).

    Args:
        image (numpy.array): The 3D image to tile, must be rank 5
            (batch, z, x, y, channel).
        model_input_shape (tuple): The (z, x, y) input size of the model.
        stride_ratio (float): The stride expressed as a fraction of the tile
            size. Must be positive. Strides are rounded up to an even number
            and capped at the tile size; an axis that fits in a single tile
            uses a stride equal to the tile size.

    Returns:
        tuple(numpy.array, dict): A tuple consisting of an array of tiled
            images and a dictionary of tiling details (for use in un-tiling).

    Raises:
        ValueError: image is not rank 5, or stride_ratio is not positive.
    """
    if image.ndim != 5:
        raise ValueError('Expected image of 5, got {}'.format(
            image.ndim))
    if stride_ratio <= 0:
        raise ValueError('stride_ratio must be positive, got {}'.format(stride_ratio))

    image_size_z, image_size_x, image_size_y = image.shape[1:4]
    tile_size_z = model_input_shape[0]
    tile_size_x = model_input_shape[1]
    tile_size_y = model_input_shape[2]

    stride_z, rep_number_z, pad_z = _tile_axis_layout(image_size_z, tile_size_z, stride_ratio)
    stride_x, rep_number_x, pad_x = _tile_axis_layout(image_size_x, tile_size_x, stride_ratio)
    stride_y, rep_number_y, pad_y = _tile_axis_layout(image_size_y, tile_size_y, stride_ratio)

    pad_null = (0, 0)
    padding = (pad_null, pad_z, pad_x, pad_y, pad_null)
    image = np.pad(image, padding, 'constant', constant_values=0)
    padded_z, padded_x, padded_y = image.shape[1:4]

    new_batch_size = image.shape[0] * rep_number_z * rep_number_x * rep_number_y
    tiles_shape = (new_batch_size, tile_size_z, tile_size_x, tile_size_y, image.shape[4])
    tiles = np.zeros(tiles_shape, dtype=image.dtype)

    batches = []
    z_starts = []
    z_ends = []
    x_starts = []
    x_ends = []
    y_starts = []
    y_ends = []
    overlaps_z = []
    overlaps_x = []
    overlaps_y = []

    counter = 0
    for b in range(image.shape[0]):
        for i in range(rep_number_x):
            x_start, x_end, overlap_x = _tile_axis_bounds(
                i, rep_number_x, stride_x, tile_size_x, padded_x)
            for j in range(rep_number_y):
                y_start, y_end, overlap_y = _tile_axis_bounds(
                    j, rep_number_y, stride_y, tile_size_y, padded_y)
                for k in range(rep_number_z):
                    z_start, z_end, overlap_z = _tile_axis_bounds(
                        k, rep_number_z, stride_z, tile_size_z, padded_z)

                    tiles[counter] = image[b, z_start:z_end, x_start:x_end, y_start:y_end, :]
                    batches.append(b)
                    x_starts.append(x_start)
                    x_ends.append(x_end)
                    y_starts.append(y_start)
                    y_ends.append(y_end)
                    z_starts.append(z_start)
                    z_ends.append(z_end)
                    overlaps_x.append(overlap_x)
                    overlaps_y.append(overlap_y)
                    overlaps_z.append(overlap_z)
                    counter += 1

    tiles_info = {}
    tiles_info['batches'] = batches
    tiles_info['x_starts'] = x_starts
    tiles_info['x_ends'] = x_ends
    tiles_info['y_starts'] = y_starts
    tiles_info['y_ends'] = y_ends
    tiles_info['z_starts'] = z_starts
    tiles_info['z_ends'] = z_ends
    tiles_info['overlaps_x'] = overlaps_x
    tiles_info['overlaps_y'] = overlaps_y
    tiles_info['overlaps_z'] = overlaps_z
    tiles_info['stride_x'] = stride_x
    tiles_info['stride_y'] = stride_y
    tiles_info['stride_z'] = stride_z
    tiles_info['tile_size_x'] = tile_size_x
    tiles_info['tile_size_y'] = tile_size_y
    tiles_info['tile_size_z'] = tile_size_z
    tiles_info['stride_ratio'] = stride_ratio
    # Shape of the *padded* image; untile_image_3D crops the padding back off.
    tiles_info['image_shape'] = image.shape
    tiles_info['dtype'] = image.dtype
    tiles_info['pad_x'] = pad_x
    tiles_info['pad_y'] = pad_y
    tiles_info['pad_z'] = pad_z

    return tiles, tiles_info

In [539]:
def test_tile_image_3D():
    """Check tile_image_3D's tile shape, dtype and tile count over a grid of inputs."""
    shapes = [
        (3, 5, 21, 21, 1),
        (1, 10, 21, 31, 2),
        (1, 15, 31, 21, 1),
    ]
    model_input_shapes = [(4, 3, 4), (3, 5, 5), (3, 7, 7), (5, 12, 15)]

    stride_ratios = [0.25, 0.33, 0.5, 0.66, 0.75, 0.8, 1]

    dtypes = ['int32', 'float32', 'uint16', 'float16']

    # Fixed seed so a failure is reproducible on re-run.
    rng = np.random.RandomState(0)

    ceil = lambda x: int(np.ceil(x))
    round_to_even = lambda x: int(np.ceil(x / 2.0) * 2)

    prod = product(shapes, model_input_shapes, stride_ratios, dtypes)

    for shape, input_shape, stride_ratio, dtype in prod:
        big_image = (rng.random_sample(shape) * 100).astype(dtype)
        # Exercise the notebook's own tile_image_3D; utils may not provide it yet.
        tiles, tiles_info = tile_image_3D(
            big_image, input_shape,
            stride_ratio=stride_ratio)

        assert tiles.shape[1:] == input_shape + (shape[-1],)
        assert tiles.dtype == big_image.dtype

        image_size_z, image_size_x, image_size_y = big_image.shape[1:4]
        tile_size_z, tile_size_x, tile_size_y = input_shape

        # Strides are rounded up to an even number and capped at the tile size.
        stride_z = min(round_to_even(stride_ratio * tile_size_z), tile_size_z)
        stride_x = min(round_to_even(stride_ratio * tile_size_x), tile_size_x)
        stride_y = min(round_to_even(stride_ratio * tile_size_y), tile_size_y)

        # At least one tile per axis, even when the tile exceeds the image.
        rep_number_z = max(ceil((image_size_z - tile_size_z) / stride_z + 1), 1)
        rep_number_x = max(ceil((image_size_x - tile_size_x) / stride_x + 1), 1)
        rep_number_y = max(ceil((image_size_y - tile_size_y) / stride_y + 1), 1)

        expected_batches = big_image.shape[0] * rep_number_x * rep_number_y * rep_number_z

        assert tiles.shape[0] == expected_batches

    # test bad input shape: a rank-3 image must be rejected by the 3D tiler
    bad_image = rng.random_sample((21, 21, 1)) * 100
    np.testing.assert_raises(ValueError, tile_image_3D, bad_image, (5, 5, 5), stride_ratio=0.75)

In [540]:
# Run the tiling checks; any failed assertion raises an AssertionError.
test_tile_image_3D()

In [575]:
def spline_window(window_size, overlap_left, overlap_right, power=3):
    """
    Build a 1D blending window that tapers only over the given overlaps.

    The taper is a piecewise spline with exponent ``power``. With power=2 it is
    the squared spline plotted here:
    https://www.wolframalpha.com/input/?i=y%3Dx**2,+y%3D-(x-2)**2+%2B2,+y%3D(x-4)**2,+from+y+%3D+0+to+2
    The first ``overlap_left`` samples are the left half of a spline window
    of length ``2 * overlap_left``. The last ``overlap_right`` samples are the
    right half of one of length ``2 * overlap_right``. Every other sample is 1.

    Args:
        window_size (int): length of the returned window.
        overlap_left (int): number of tapered samples at the start (0 = none).
        overlap_right (int): number of tapered samples at the end (0 = none).
        power (int): exponent of the spline segments.

    Returns:
        numpy.array: float window of shape ``(window_size,)``.
    """

    def _spline_window(w_size):
        intersection = int(w_size / 4)
        # scipy.signal.triang was removed from the scipy.signal namespace in
        # SciPy 1.13; scipy.signal.windows.triang is the supported location.
        triangle = scipy.signal.windows.triang(w_size)

        # Outer segments: used only in the first and last quarter.
        wind_outer = (abs(2 * triangle) ** power) / 2
        # Explicit end index: with intersection == 0, a ``-intersection``
        # bound would be ``-0`` == 0 and select the wrong span.
        wind_outer[intersection:w_size - intersection] = 0

        # Inner segment: used only in the middle half.
        wind_inner = 1 - (abs(2 * (triangle - 1)) ** power) / 2
        wind_inner[:intersection] = 0
        wind_inner[w_size - intersection:] = 0

        wind = wind_inner + wind_outer
        return wind / np.amax(wind)

    window = np.ones((window_size,))

    # Left overlap: rising half of a window twice as long as the overlap.
    if overlap_left > 0:
        window[0:overlap_left] = _spline_window(2 * overlap_left)[0:overlap_left]

    # Right overlap: falling half of a window twice as long as the overlap.
    if overlap_right > 0:
        window[-overlap_right:] = _spline_window(2 * overlap_right)[overlap_right:]

    return window


def window_3D(window_size, overlap_z=(5, 5), overlap_x=(32, 32), overlap_y=(32, 32), power=3):
    """
    Build a separable 3D blending window for a single tile.

    One 1D spline window is made per axis (z, x, y), each tapering over that
    axis's (left, right) overlap. The three are combined as an outer product.

    Args:
        window_size (tuple): (z, x, y) size of the tile.
        overlap_z (tuple): (left, right) overlap along z.
        overlap_x (tuple): (left, right) overlap along x.
        overlap_y (tuple): (left, right) overlap along y.
        power (int): spline exponent forwarded to spline_window.

    Returns:
        numpy.array: window of shape (z, x, y, 1), which broadcasts against a
        (z, x, y, channel) tile.
    """
    profile_z = spline_window(window_size[0], overlap_z[0], overlap_z[1], power=power)
    profile_x = spline_window(window_size[1], overlap_x[0], overlap_x[1], power=power)
    profile_y = spline_window(window_size[2], overlap_y[0], overlap_y[1], power=power)

    # Lay each 1D profile along its own axis; broadcasting forms the product.
    along_z = profile_z.reshape(-1, 1, 1, 1)
    along_x = profile_x.reshape(1, -1, 1, 1)
    along_y = profile_y.reshape(1, 1, -1, 1)

    return along_z * along_x * along_y


def untile_image_3D(tiles, tiles_info, power=3, force=False, **kwargs):
    """Reassemble tiles produced by ``tile_image_3D`` into full 3D images.

    If the x and y tile sizes are at least 32 and smaller than the padded
    image, and ``stride_ratio >= 0.5``, overlapping tiles are blended with a
    spline window from ``window_3D``. Otherwise each tile is written over the
    output, so later tiles overwrite earlier ones. The padding added by
    ``tile_image_3D`` is then cropped off.

    Args:
        tiles (numpy.array): rank-5 tiles ``(n_tiles, z, x, y, channel)``.
        tiles_info (dict): tiling details returned by ``tile_image_3D``.
        power (int): spline exponent used for the blending windows.
        force (bool): if True, drop the minimum tile size and stride ratio
            requirements. Tiles must still be smaller than the padded image
            in x and y for blending to be used.
        **kwargs: ignored; accepted for call compatibility.

    Returns:
        numpy.array: the untiled images, with the same dtype as ``tiles``.
        Integer results are rounded to the nearest value and clipped to the
        dtype's range rather than truncated or wrapped.
    """
    # Define minimally acceptable tile_size and stride_ratios for spline interpolation
    min_tile_size = 32
    min_stride_ratio = 0.5

    if force:
        min_tile_size = 0
        min_stride_ratio = 0

    stride_ratio = tiles_info['stride_ratio']
    image_shape = tiles_info['image_shape']
    batches = tiles_info['batches']

    x_starts = tiles_info['x_starts']
    x_ends = tiles_info['x_ends']
    y_starts = tiles_info['y_starts']
    y_ends = tiles_info['y_ends']
    z_starts = tiles_info['z_starts']
    z_ends = tiles_info['z_ends']

    overlaps_x = tiles_info['overlaps_x']
    overlaps_y = tiles_info['overlaps_y']
    overlaps_z = tiles_info['overlaps_z']

    tile_size_x = tiles_info['tile_size_x']
    tile_size_y = tiles_info['tile_size_y']
    tile_size_z = tiles_info['tile_size_z']
    pad_x = tiles_info['pad_x']
    pad_y = tiles_info['pad_y']
    pad_z = tiles_info['pad_z']

    # Padded image shape, with the channel count taken from the tiles.
    image_shape = [image_shape[0], image_shape[1], image_shape[2], image_shape[3], tiles.shape[-1]]
    window_size = (tile_size_z, tile_size_x, tile_size_y)
    # np.float was removed in NumPy 1.24; accumulate in float64 explicitly.
    image = np.zeros(image_shape, dtype=np.float64)

    # A tile size or stride ratio that is too small gives inconsistent results,
    # so in those cases interpolation is skipped and the raw tiles are written.
    # The condition is the same for every tile, so evaluate it once.
    use_spline = (min_tile_size <= tile_size_x < image_shape[2] and
                  min_tile_size <= tile_size_y < image_shape[3] and
                  min_stride_ratio <= stride_ratio)

    # Many tiles share the same overlaps; build each distinct window once.
    windows = {}

    tile_data_zip = zip(tiles, batches, x_starts, x_ends, y_starts,
                        y_ends, z_starts, z_ends, overlaps_x, overlaps_y, overlaps_z)

    for (tile, batch, x_start, x_end, y_start, y_end, z_start,
         z_end, overlap_x, overlap_y, overlap_z) in tile_data_zip:

        if use_spline:
            key = (tuple(overlap_z), tuple(overlap_x), tuple(overlap_y))
            if key not in windows:
                windows[key] = window_3D(window_size, overlap_z=overlap_z, overlap_x=overlap_x,
                                         overlap_y=overlap_y, power=power)
            image[batch, z_start:z_end, x_start:x_end, y_start:y_end, :] += tile * windows[key]
        else:
            image[batch, z_start:z_end, x_start:x_end, y_start:y_end, :] = tile

    # Blended sums may differ slightly from the original values. For integer
    # dtypes, round instead of truncating (99.9999 -> 100, not 99) and clip so
    # that values just above the dtype maximum cannot wrap around.
    if np.issubdtype(tiles.dtype, np.integer):
        limits = np.iinfo(tiles.dtype)
        image = np.clip(np.rint(image), limits.min, limits.max)
    image = image.astype(tiles.dtype)

    # Crop off the padding added by tile_image_3D.
    x_start = pad_x[0]
    y_start = pad_y[0]
    z_start = pad_z[0]
    x_end = image_shape[2] - pad_x[1]
    y_end = image_shape[3] - pad_y[1]
    z_end = image_shape[1] - pad_z[1]

    image = image[:, z_start:z_end, x_start:x_end, y_start:y_end, :]

    return image

In [576]:
def test_untile_image_3D():
    """Round-trip random and constant images through tile_image_3D / untile_image_3D."""
    shapes = [
        (1, 30, 60, 51, 2),
        (2, 20, 90, 30, 1)
    ]

    rand_rel_diff_thresh = 2e-2
    model_input_shapes = [(4, 60, 70), (30, 20, 30), (70, 40, 50)]

    stride_ratios = [0.33, 0.5, 0.51, 0.66, 1]
    dtypes = ['int32', 'float32', 'uint16', 'float16']
    power = 3

    # Fixed seed so a failure is reproducible on re-run.
    rng = np.random.RandomState(0)

    # Materialize the grid: a bare itertools.product is an iterator that is
    # exhausted by the first loop, which silently skipped the second loop.
    prod = list(product(shapes, model_input_shapes, stride_ratios, dtypes))

    # Test that randomly generated arrays are unchanged within a moderate tolerance
    for shape, input_shape, stride_ratio, dtype in prod:

        big_image = (rng.random_sample(shape) * 100).astype(dtype)
        # TODO: switch to utils.tile_image_3D / utils.untile_image_3D once upstreamed
        tiles, tiles_info = tile_image_3D(big_image, model_input_shape=input_shape,
                                          stride_ratio=stride_ratio)

        untiled_image = untile_image_3D(tiles, tiles_info, power=power)

        assert untiled_image.dtype == dtype
        assert untiled_image.shape == shape

        np.testing.assert_allclose(big_image, untiled_image, rand_rel_diff_thresh)

    # Test that constant arrays are unchanged by tile/untile
    for shape, input_shape, stride_ratio, dtype in prod:
        for x in [0, 1, rng.randint(2, 99)]:
            # np.full builds the array directly (ndarray.fill returns None).
            big_image = np.full(shape, x, dtype=dtype)
            tiles, tiles_info = tile_image_3D(big_image,
                                              model_input_shape=input_shape,
                                              stride_ratio=stride_ratio)
            untiled_image = untile_image_3D(tiles, tiles_info, power=power)
            assert untiled_image.dtype == dtype
            assert untiled_image.shape == shape
            if np.issubdtype(big_image.dtype, np.integer):
                np.testing.assert_equal(big_image, untiled_image)
            else:
                # Blended float values can differ slightly from the input.
                np.testing.assert_allclose(big_image, untiled_image, rand_rel_diff_thresh)

    # A stride_ratio of 0 cannot tile the image and must be rejected
    # (ValueError when validated, ZeroDivisionError otherwise).
    np.testing.assert_raises((ValueError, ZeroDivisionError), tile_image_3D,
                             np.zeros((1, 4, 4, 4, 1), dtype='int32'),
                             model_input_shape=(2, 2, 2), stride_ratio=0)

In [577]:
# Run the tile/untile round-trip checks; any failed assertion raises an AssertionError.
test_untile_image_3D()

In [537]:

# Scratch round-trip check on a single configuration.
shape = (1, 60, 50, 60, 1)
input_shape = (50, 50, 50)
dtype = 'int32'
stride_ratio = 0.6
power = 3

# Fixed seed so reruns reproduce the same image.
rng = np.random.RandomState(0)
big_image = (rng.random_sample(shape) * 100).astype(dtype)

# TODO: switch to utils.tile_image_3D / utils.untile_image_3D once upstreamed
tiles, tiles_info = tile_image_3D(big_image, model_input_shape=input_shape,
                                  stride_ratio=stride_ratio)

untiled_image = untile_image_3D(tiles, tiles_info, power=power, force=True)

np.testing.assert_equal(big_image, untiled_image)



rep_number_z is:  2 , rep_number x is:  1
stride_z is:  30 , stride_x is:  30
tile_size_z is:  50 , tile_size_x is:  50


In [510]:
print(tiles.shape)

# Subtract in float64 so unsigned dtypes (e.g. 'uint16') cannot wrap around.
diff = big_image.astype(np.float64) - untiled_image.astype(np.float64)

print(diff.max())
print(diff.min())

X_slice = np.expand_dims(diff, axis=1)

HTML(get_js_video(X_slice[0], batch=0, channel=0))

(4, 40, 40, 40, 1)
0
0


In [453]:
tiles, tile_info = tile_image_3D(raw_img_array, model_input_shape=(50, 128, 128))

print('raw_img_array.shape is: ', raw_img_array.shape)
print('tiles.shape is: ', tiles.shape)

# Preview tile 3, channel 2 with get_js_video.
X_slice = np.expand_dims(tiles, axis=1)
HTML(get_js_video(X_slice[3], batch=0, channel=2))

raw_img_array.shape is:  (18, 50, 512, 512, 3)
tiles.shape is:  (882, 50, 128, 128, 3)
raw_img_array.shape is:  (18, 50, 512, 512, 3)
tiles.shape is:  (882, 50, 128, 128, 3)


In [454]:
untiled = untile_image_3D(tiles, tile_info)

# Shapes of the input, the tiles and the reassembled output.
for label, arr in (('raw_img_array', raw_img_array), ('tiles', tiles), ('untiled', untiled)):
    print(label + '.shape is: ', arr.shape)

# Element-wise reconstruction error.
diff = np.subtract(untiled, raw_img_array)
print('diff shape is: ', diff.shape)
print('diff max is: ', diff.max())
print('diff min is: ', diff.min())

# Value range and dtype at each stage of the round trip.
for label, arr in (('input', raw_img_array), ('tiles', tiles), ('untiled', untiled)):
    print(label + ' max is: ', arr.max())
    print(label + ' min is: ', arr.min())
    print(label + ' dtype is: ', arr.dtype)

np.testing.assert_allclose(raw_img_array, untiled, rtol=1e-1)



raw_img_array.shape is:  (18, 50, 512, 512, 3)
tiles.shape is:  (882, 50, 128, 128, 3)
untiled.shape is:  (18, 50, 512, 512, 3)
diff shape is:  (18, 50, 512, 512, 3)
diff max is:  16.0
diff min is:  -19099.16
input max is:  65535.0
input min is:  0.0
input dtype is:  float32
tiles max is:  65535.0
tiles min is:  0.0
tiles dtype is:  float32
untiled max is:  65551.0
untiled min is:  0.0
untiled dtype is:  float32


AssertionError: 
Not equal to tolerance rtol=0.1, atol=0

Mismatch: 3.26%
Max absolute difference: 19099.16
Max relative difference: 45.040005
 x: array([[[[[ 66., 136.,   0.],
          [158.,  93.,   0.],
          [ 59., 106.,   0.],...
 y: array([[[[[6.600000e+01, 1.360000e+02, 0.000000e+00],
          [1.580000e+02, 9.300000e+01, 0.000000e+00],
          [5.900000e+01, 1.060000e+02, 0.000000e+00],...

In [455]:
# Insert a singleton axis after the batch axis (same as np.expand_dims(..., axis=1)).
X_slice = raw_img_array[:, np.newaxis]
HTML(get_js_video(X_slice[3], batch=0, channel=2))

In [456]:


# Insert a singleton axis after the batch axis (same as np.expand_dims(..., axis=1)).
X_slice = untiled[:, np.newaxis]
HTML(get_js_video(X_slice[3], batch=0, channel=2))