This notebook is part of the `deepcell-tf` documentation: https://deepcell.readthedocs.io/.

# Training a segmentation model

`deepcell-tf` leverages [Jupyter Notebooks](https://jupyter.org) in order to train models. Example notebooks are available for most model architectures in the [notebooks folder](https://github.com/vanvalenlab/deepcell-tf/tree/master/notebooks). Most notebooks are structured similarly to this example and thus this notebook serves as a core reference for the deepcell approach to model training.

In [1]:
import os
import errno
import math
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from IPython.display import HTML
from tensorflow.python.keras import backend as K

import deepcell

from deepcell_toolbox.utils import resize, tile_image, untile_image
from deepcell_toolbox.deep_watershed import deep_watershed
from deepcell.utils.plot_utils import get_js_video
from deepcell.utils.io_utils import *



Instructions for updating:
If using Keras pass *_constraint arguments to layers.




In [20]:
def _get_3D_images_from_directory(data_location, channel_names, image_size=(49,1024,1024), dtype='float32'):
    """Load every multi-channel 3D image stack found in a directory.

    For each entry of ``channel_names``, ``nikon_getfiles`` lists the matching
    files in ``data_location``. The i-th file of every channel is read with
    ``get_image`` and written into its own channel slot of a zero-filled
    volume of size ``image_size``. Stacks smaller than ``image_size`` therefore
    end up zero-padded at the far end of each axis.

    Args:
        data_location (str): folder containing image files
        channel_names (str[]): list of wildcards to select filenames
        image_size (tuple): size ``(frames, x, y)`` of the padded volume that
            every stack is placed into. A stack that exceeds this size along
            any axis cannot be assigned into the volume (NumPy raises a
            ``ValueError`` for the shape mismatch).
        dtype (str): dtype of the returned array.

    Returns:
        numpy.array: all stacks stacked along a new leading axis and then
            passed through ``np.squeeze``. Because every length-1 axis is
            squeezed out, a directory holding a single stack yields an array
            WITHOUT the leading stack axis (callers must re-add it).
    """
    data_format = K.image_data_format()

    # One list of filenames per channel; the first channel's list defines how
    # many stacks are read, so the other channels need at least as many files.
    img_list_channels = [nikon_getfiles(data_location, channel)
                         for channel in channel_names]
    n_channels = len(channel_names)

    # The padded volume shape is the same for every stack, so build it once
    # (and without allocating a throw-away array just to read its shape).
    n_frames, n_x, n_y = image_size[0], image_size[1], image_size[2]
    if data_format == 'channels_first':
        shape = (1, n_channels, n_frames, n_x, n_y)
    else:
        shape = (1, n_frames, n_x, n_y, n_channels)

    all_images = []

    for stack_iteration in range(len(img_list_channels[0])):
        all_channels = np.zeros(shape, dtype=K.floatx())

        for j in range(n_channels):
            img_path = os.path.join(data_location, img_list_channels[j][stack_iteration])
            channel_img = get_image(img_path)

            # Images in this dataset have different dimensions along all 3
            # axes, so only the region the image actually covers is filled.
            f_dim, x_dim, y_dim = channel_img.shape[:3]

            if data_format == 'channels_first':
                all_channels[0, j, :f_dim, :x_dim, :y_dim] = channel_img
            else:
                all_channels[0, :f_dim, :x_dim, :y_dim, j] = channel_img

        all_images.append(all_channels)

    # Honour the requested dtype (previously hard-coded to 'float32', which is
    # still the default) and drop all singleton axes.
    all_images = np.squeeze(np.asarray(all_images, dtype=dtype))

    return all_images

In [21]:
# Get data from sets 8-32
path_to_data = '/deepcell_data/data/cells/MouseBrain/generic/set8-32/raw_stacks/'
channel_names = ['DAPI', 'Nissl']
raw_img_arr = _get_3D_images_from_directory(path_to_data, channel_names)

# Get data from sets 1-7
super_path = '/deepcell_data/data/cells/MouseBrain/generic/'

# Collect every stack first and concatenate ONCE at the end: concatenating
# inside the loop re-copies the whole (multi-GB) array on every iteration.
# Order is unchanged: sets 8-32 first, then sets 1-7.
stacks = [raw_img_arr]
for imset in range(1, 8):
    im_path = os.path.join(super_path, 'set{}'.format(imset))
    im_arr = _get_3D_images_from_directory(im_path, channel_names)
    # The loader squeezes away the leading axis when a directory holds a
    # single stack, so restore it before joining along axis 0.
    im_arr = np.expand_dims(im_arr, 0)
    stacks.append(im_arr)

full_im_array = np.concatenate(stacks, 0)

# Drop the list so the per-set arrays are not kept alive by the notebook
# namespace after they have been copied into full_im_array.
del stacks

In [22]:
# Keep only index 0 of the last axis (the first entry of channel_names when
# the data format is channels_last) and report the resulting shape.
# NOTE(review): X_train is not referenced by any later cell visible here.
X_train = full_im_array[..., 0]
print(X_train.shape)

(32, 49, 1024, 1024)


In [23]:
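# Plan for the rest of this notebook:
#
# 1. Define model paths and parameters
# 2. Initialize model
# 3. For each batch (image stack), for each frame:
#      - Preprocess? (may not be needed)
#      - Tile
#      - Predict (on just DAPI)
#      - Untile
#      - Postprocess
# 4. Stitch together
# 5. View results
# 6. Save as npz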

In [24]:
# Define model paths and parameters
filename = 'mousebrain.npz'
DATA_DIR = os.path.expanduser(os.path.join('~', '.keras', 'datasets'))

# DATA_FILE should be a npz file, preferably from `make_training_data`
# NOTE(review): in the cells visible here DATA_FILE is never loaded; it is
# only used to derive PREFIX below.
DATA_FILE = os.path.join(DATA_DIR, filename)

# confirm the data file is available
# (check disabled: the images are read from raw stacks instead of this npz)
#assert os.path.isfile(DATA_FILE)

# If the data file is in a subdirectory, mirror it in MODEL_DIR and LOG_DIR
# (`filename` has no directory component, so PREFIX resolves to '.')
PREFIX = os.path.relpath(os.path.dirname(DATA_FILE), DATA_DIR)

ROOT_DIR = '/data'  # TODO: Change this! Usually a mounted volume
MODEL_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'models', PREFIX))
LOG_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'logs', PREFIX))

# NOTE(review): norm_method, fill_mode and cval are not referenced by any
# later cell visible here; the model cell passes norm_method='std' directly,
# which disagrees with the value below -- confirm which one the saved
# weights were trained with.
norm_method = 'whole_image'  # data normalization

fill_mode = 'reflect'
cval = 0

# Weights file to load and the tile shape (rows, cols, channels) the model
# is built with; the same input_shape is used for tiling and untiling.
model_name = 'mb2D_1x256x256_reflect'
model_path = os.path.join(MODEL_DIR, '{}.h5'.format(model_name))
input_shape = (256, 256, 1)


In [25]:
# Sanity check on the loaded data: with the channels_last data format the
# axes are (stack, frame, x, y, channel).
full_im_array.shape

(32, 49, 1024, 1024, 2)

In [26]:
# Initialize model and load weights

# NOTE(review): this import sits mid-notebook; the other imports are in the
# first cell.
from deepcell.model_zoo.panopticnet import PanopticNet

# NOTE(review): norm_method is hard-coded to 'std' here although the config
# cell defines norm_method = 'whole_image'. Confirm which normalization the
# saved weights expect before changing either.
prediction_model = PanopticNet(
    backbone='resnet50',
    input_shape=input_shape,
    norm_method='std',
    num_semantic_heads=2,
    num_semantic_classes=[1, 1], # inner distance, outer distance
    location=True,  # should always be true
    include_top=True)

# NOTE(review): with by_name=True Keras matches weights to layers by layer
# name; presumably layers without a matching name keep their initial
# weights without an error -- verify the architecture above matches the
# checkpoint.
prediction_model.load_weights(model_path, by_name=True)

In [27]:
# Index into the last axis of full_im_array that is fed to the model.
# TODO: confirm the channel order -- the loader fills channels in the order
# of channel_names (['DAPI', 'Nissl']), which would make 0 = DAPI.
channel = 0  # 0 =DAPI, 1 = Nissl? might be reversed
batch_size = 4  # passed to prediction_model.predict for the tiles of one frame
out_chan = 2  # NOTE(review): not referenced by any later cell visible here

# Earlier approach (disabled): one mask array covering every stack at once.
#mask_arr = np.expand_dims(np.zeros(full_im_array.shape[:4], dtype='float32'), -1)

#mask_arr.shape

In [28]:
# Confirm the dtype produced by the loader before frames are passed to the model.
full_im_array.dtype

dtype('float32')

In [31]:
# For each batch, for each frame:
    # Preprocess?, Tile, Predict (on just DAPI), Untile, Postprocess


stride_ratio = 0.75
SAVE_DIR = '/deepcell_data/data/cells/MouseBrain/auto_annotated/new_untile_stride_75v2'    #/channel0_new_untile_stride075.npz'

# Number of leading frames dropped from every saved movie and mask stack.
# NOTE(review): the reason for skipping the first two frames is not recorded
# in this notebook -- confirm.
start_frame = 2

# Create the output folder up front: otherwise np.savez only fails after a
# whole stack has already been predicted.
os.makedirs(SAVE_DIR, exist_ok=True)

for batch_num in range(full_im_array.shape[0]):
    # Data to predict on: the single selected channel of this stack
    batch = full_im_array[batch_num, ..., channel]

    # Output: every channel of this stack is saved next to the masks
    mov = full_im_array[batch_num, ...]

    # Mask array: one uint16 label image per frame, with a trailing axis of
    # length 1
    mask_arr = np.expand_dims(np.zeros(mov.shape[:3], dtype='uint16'), -1)

    for frame_num in range(full_im_array.shape[1]):
        frame = np.expand_dims(np.expand_dims(batch[frame_num, ...], 0), -1) # shape (1, 1024, 1024, 1)

        # No preprocessing I think

        # Tile frame into model-sized patches
        tiles, tiles_info = tile_image(frame, model_input_shape=input_shape, stride_ratio=stride_ratio)

        # Predict on frame (one output per semantic head)
        output_tiles = prediction_model.predict(tiles, batch_size=batch_size)

        # Untile each head's output back to the full frame
        output_images = [untile_image(o, tiles_info, model_input_shape=input_shape) for o in output_tiles]

        # Run deep_watershed_2D
        masks = deep_watershed(
            output_images,
            min_distance=10,
            detection_threshold=0.1,
            distance_threshold=0.01,
            exclude_border=False,
            small_objects_threshold=0)

        # Move the leading axis to the end so the mask matches mask_arr's
        # per-frame layout
        masks = np.rollaxis(masks, 0, 3)

        mask_arr[frame_num, :, :, :] = masks

    save_path = os.path.join(SAVE_DIR, 'mov_{}.npz'.format(batch_num))

    # Drop the leading frames from both the images and the masks. mask_arr is
    # already uint16, so no further cast is needed.
    mov = mov[start_frame:, ...]
    mask_arr = mask_arr[start_frame:, ...]

    print('batch num is: ', batch_num, ', mov shape is: ', mov.shape, 'mask_arr shape is: ', mask_arr.shape)
    np.savez(save_path, X=mov, y=mask_arr)

batch num is:  0 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  1 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  2 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  3 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  4 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  5 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  6 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  7 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  8 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  9 , mov shape is:  (47, 1024, 1024, 2) mask_arr shape is:  (47, 1024, 1024, 1)
batch num is:  10 , mov shape is:  (47, 1024, 1024, 2) mask_

In [32]:
# get_js_video is called with batch=0 below, so give the mask stack a
# leading axis of length 1 first. mask_arr is whatever the most recently
# run cell left behind.
vid = mask_arr[np.newaxis, ...]

HTML(get_js_video(vid, batch=0, channel=0, interval=800))

In [26]:
# NOTE(review): this cell repeats the video cell directly above it (it carries
# an older execution count); it renders the same mask_arr again.
vid = mask_arr[np.newaxis, ...]

HTML(get_js_video(vid, batch=0, channel=0, interval=800))

In [13]:
# `watershed` lives in skimage.segmentation in current scikit-image and is no
# longer importable from skimage.morphology; fall back to the old location
# for older installs so the name stays available either way.
try:
    from skimage.segmentation import watershed
except ImportError:
    from skimage.morphology import watershed
from skimage.morphology import remove_small_objects

# Remove labelled objects smaller than 100 pixels from every frame.
# The last three axes of mask_arr are one frame (x, y, 1); iterate over all
# leading axes so this works both for a single movie (frames, x, y, 1), as
# left behind by the prediction loop, and for a batch of movies
# (batch, frames, x, y, 1). Indexing [batch, frame] on the 4-D layout would
# filter single image rows instead of whole frames.
# The loop variable is deliberately not called `batch`, which is an image
# array elsewhere in this notebook.
for lead_idx in np.ndindex(*mask_arr.shape[:-3]):
    mask_arr[lead_idx] = remove_small_objects(mask_arr[lead_idx].astype(int), min_size=100)

In [12]:
# Save all images and masks into a single npz.
# NOTE(review): this cell carries an older execution count than the
# prediction loop. If the notebook is run top to bottom, mask_arr holds only
# the last predicted stack (frames, x, y, 1) while full_im_array holds every
# stack, so X and y would not line up -- confirm before re-running. This also
# overwrites the save_path set by the prediction loop.
save_path = '/deepcell_data/data/cells/MouseBrain/auto_annotated/channel0_new_untile_stride075.npz'
np.savez(save_path, X=full_im_array, y=mask_arr)
# important to use kwargs X and y when saving the npz


# TODO: resave with remove_small_objects, then re-predict with channel1