In [None]:
import warnings
# Install a filter that ignores every warning of category UserWarning.
warnings.simplefilter(action='ignore', category=UserWarning)

### Global Imports

In [None]:
import os
import datetime
import errno
import argparse

import numpy as np
import tensorflow as tf

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizers import SGD, Adam
from tensorflow.python.keras.models import Sequential, Model

from deepcell import get_data
#from deepcell import make_training_data
#from deepcell import rate_scheduler
#from deepcell.model_zoo import siamese_model
#from deepcell.training import train_model_siamese_daughter
#from deepcell.image_generators import MovieDataGenerator

### Data Generators

In [None]:
# Import statements for image_generators.py
from keras_preprocessing.image import Iterator, ImageDataGenerator

from sklearn.model_selection import train_test_split
from scipy.stats import special_ortho_group
from skimage.measure import label, regionprops
from skimage.transform import resize


# Custom Siamese Generators
class SiameseDataGenerator(ImageDataGenerator):
    """`ImageDataGenerator` whose `flow` produces Siamese tracking batches.

    The generator itself is handed to the `SiameseIterator` it creates, which
    uses it to transform the image data it samples.
    """

    def flow(self,
             train_dict,
             crop_dim=32,
             min_track_length=5,
             features=None,
             sync_transform=True,
             batch_size=32,
             shuffle=True,
             seed=None,
             data_format=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png'):
        """Create a `SiameseIterator` over `train_dict`.

        Every keyword argument is forwarded unchanged to `SiameseIterator`;
        this generator instance is passed as its `image_data_generator`.

        Returns:
            The newly constructed `SiameseIterator`.
        """
        iterator_options = dict(
            crop_dim=crop_dim,
            min_track_length=min_track_length,
            features=features,
            sync_transform=sync_transform,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            data_format=data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
        )
        return SiameseIterator(train_dict, self, **iterator_options)


class SiameseIterator(Iterator):
    def __init__(self,
                 train_dict,
                 image_data_generator,
                 crop_dim=14,
                 min_track_length=5,
                 batch_size=32,
                 occupancy_grid_size=10,
                 occupancy_window=100,
                 features=None,
                 sync_transform=True,
                 shuffle=False,
                 seed=None,
                 squeeze=False,
                 data_format=None,
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png'):
        """Prepare tracks and pre-computed features for Siamese batch sampling.

        Args:
            train_dict: dict holding the movies under 'X' (rank-5 array), the
                integer cell label masks under 'y' and, optionally, the
                per-batch daughter lookup under 'daughters'.
            image_data_generator: generator whose `apply_transform`,
                `random_transform` and `standardize` methods are applied to
                the sampled features.
            crop_dim: side length every cell crop is resized to.
            min_track_length: number of frames in the first element of each
                sampled pair.
            batch_size: forwarded to the base `Iterator`.
            occupancy_grid_size: the neighborhood image is resized to
                `2 * occupancy_grid_size + 1` pixels per side.
            occupancy_window: half-width, in pixels, of the window cut out
                around each cell centroid for the neighborhood image.
            features: iterable of feature names to generate ("appearance",
                "distance", "neighborhood", "perimeter"). Required.
            sync_transform: if True, one random rotation/flip is drawn per
                sample and passed to every feature of that sample.
            shuffle: forwarded to the base `Iterator`.
            seed: forwarded to the base `Iterator`.
            squeeze: if True, the time axis is squeezed out of every batch
                array before it is returned.
            data_format: 'channels_first' or 'channels_last'; defaults to the
                Keras backend setting.
            save_to_dir: stored on the instance.
            save_prefix: stored on the instance.
            save_format: stored on the instance.

        Raises:
            ValueError: if `data_format` is not recognised, `features` is
                None, or 'X' is not a rank-5 array.
        """
        if data_format is None:
            data_format = K.image_data_format()

        # Axis positions refer to the full 5D (batch, ...) arrays.
        if data_format == 'channels_first':
            self.channel_axis = 1
            self.row_axis = 3
            self.col_axis = 4
            self.time_axis = 2
        elif data_format == 'channels_last':
            self.channel_axis = 4
            self.row_axis = 2
            self.col_axis = 3
            self.time_axis = 1
        else:
            # Fail here instead of with an AttributeError on the first axis lookup.
            raise ValueError("SiameseIterator: Unknown data_format '{}'".format(data_format))

        if features is None:
            raise ValueError("SiameseIterator: No features specified.")

        self.x = np.asarray(train_dict['X'], dtype=K.floatx())
        self.y = np.array(train_dict['y'], dtype='int32')

        if self.x.ndim != 5:
            raise ValueError('Input data in `SiameseIterator` '
                             'should have rank 5. You passed an array '
                             'with shape', self.x.shape)

        self.crop_dim = crop_dim
        self.min_track_length = min_track_length
        # Sorted so the order of the arrays in each batch is deterministic.
        self.features = sorted(features)
        self.sync_transform = sync_transform
        # Built-in int: the `np.int` alias no longer exists in recent NumPy.
        self.occupancy_grid_size = int(occupancy_grid_size)
        self.occupancy_window = int(occupancy_window)
        self.image_data_generator = image_data_generator
        self.squeeze = squeeze
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format

        self.daughters = train_dict['daughters'] if 'daughters' in train_dict else None

        # Order matters: filter batches, then index tracks, then cache features.
        self._remove_bad_images()
        self._create_track_ids()
        self._create_features()

        super(SiameseIterator, self).__init__(
            len(self.track_ids), batch_size, shuffle, seed)

    def _remove_bad_images(self):
        """Drop every batch whose label mask holds fewer than two cells.

        A batch is kept when its mask contains more than two distinct values
        (at least two cell labels plus the background). `self.x`, `self.y`
        and, when present, `self.daughters` are filtered together so that a
        batch index keeps referring to the same movie in all three.
        """
        good_batches = [
            batch for batch in range(self.x.shape[0])
            if len(np.unique(self.y[batch])) > 2
        ]
        keep = np.asarray(good_batches, dtype=np.intp)

        self.x = np.asarray(self.x[keep], dtype=K.floatx())
        self.y = np.asarray(self.y[keep], dtype=np.int32)

        # The daughter lookup is indexed by batch as well; without filtering it
        # the surviving batches would read the daughters of a different movie.
        daughters = getattr(self, 'daughters', None)
        if daughters is not None:
            if isinstance(daughters, np.ndarray):
                self.daughters = daughters[keep]
            else:
                self.daughters = [daughters[batch] for batch in good_batches]

    def _create_track_ids(self):
        """Index every usable cell track and build the lookups used for sampling.

        A track is one cell label inside one batch. Labels are only unique
        within a batch, so each kept (batch, label) pair receives its own
        integer track id. Cells present in three or fewer frames are not
        tracked and are erased from `self.y` (set to 0).

        Sets:
            self.track_ids: dict mapping a track id to a dict with the keys
                'batch', 'label', 'frames' (indices of the frames containing
                the cell), 'daughters' (daughter labels, or `[]`) and
                'different' (dict mapping each of those frames to the other
                non-zero labels found in it).
            self.reverse_track_ids: dict `batch -> {label: track id}`.
            self.tracks_with_divisions: list of track ids that have daughters.
        """
        # A cell (or daughter) must appear in more than this many frames.
        frame_threshold = 3
        # Row/column axes of a single batch (the batch axis has been removed).
        spatial_axes = (self.row_axis - 1, self.col_axis - 1)

        track_counter = 0
        track_ids = {}
        for batch in range(self.y.shape[0]):
            y_batch = self.y[batch]
            num_cells = np.amax(y_batch)
            for cell in range(1, num_cells + 1):
                # Count the pixels the cell occupies in each frame
                y_true = np.sum(y_batch == cell, axis=spatial_axes)
                # Indices of the frames in which the cell is present
                y_index = np.where(y_true > 0)[0]

                if y_index.size <= frame_threshold:
                    # `y_batch` is a view into `self.y`, so this erases the
                    # short-lived cell from the stored masks.
                    y_batch[y_batch == cell] = 0
                    continue

                daughters = []
                if self.daughters is not None:
                    try:
                        daughter_ids = self.daughters[batch][cell]
                    except (KeyError, IndexError):
                        # No daughter entry recorded for this cell.
                        daughter_ids = []

                    if len(daughter_ids) > 0:
                        # Keep the division only if every daughter track is
                        # itself long enough to be tracked.
                        keep_daughters = all(
                            len(np.where(np.sum(y_batch == did, axis=spatial_axes) > 0)[0])
                            > frame_threshold
                            for did in daughter_ids
                        )
                        if keep_daughters:
                            daughters = daughter_ids

                track_ids[track_counter] = {
                    'batch': batch,
                    'label': cell,
                    'frames': y_index,
                    'daughters': daughters
                }
                track_counter += 1

        # For every tracked frame, record which other cells are present in it
        for track in track_ids.keys():
            batch = track_ids[track]['batch']
            label = track_ids[track]['label']
            different = {}
            for frame in track_ids[track]['frames']:
                present = np.unique(self.y[batch][frame])
                different[frame] = present[(present != 0) & (present != label)]
            track_ids[track]['different'] = different

        # Lookup from (batch, label) back to the track id
        reverse_track_ids = {batch: {} for batch in range(self.y.shape[0])}
        for track in track_ids.keys():
            batch = track_ids[track]['batch']
            label = track_ids[track]['label']
            reverse_track_ids[batch][label] = track

        # Save dictionaries
        self.track_ids = track_ids
        self.reverse_track_ids = reverse_track_ids

        # Identify which tracks have divisions
        self.tracks_with_divisions = [
            track for track in self.track_ids.keys()
            if len(self.track_ids[track]['daughters']) > 0
        ]

    def _get_features(self, X, y, frames, labels):
        """Extract the per-frame features of one cell from a single batch.

        Args:
            X: image data of one batch (the batch axis already removed).
            y: label mask of the same batch.
            frames: frame indices to process.
            labels: cell label to look for in each frame, aligned with `frames`.

        Returns:
            A list `[appearances, centroids, occupancy_grids, perimeters]`:
            the bounding-box crops resized to `crop_dim`, the region centroids,
            the image window around each centroid resized to
            `2 * occupancy_grid_size + 1` per side, and the region perimeters.
        """
        channel_axis = self.channel_axis - 1
        channels_first = self.data_format == 'channels_first'
        num_channels = X.shape[channel_axis]
        grid_dim = 2 * self.occupancy_grid_size + 1
        window = self.occupancy_window

        def _resize_keeping_range(image, output_shape):
            # Scale by the largest magnitude before resizing and restore it
            # afterwards; an all-zero image is returned as zeros instead of
            # being divided by zero (which would fill the feature with NaN).
            max_value = np.amax([np.amax(image), np.absolute(np.amin(image))])
            if max_value == 0:
                return np.zeros(output_shape, dtype=K.floatx())
            resized = resize(image / max_value, output_shape, mode='constant')
            return resized * max_value

        if channels_first:
            appearance_shape = (num_channels, len(frames), self.crop_dim, self.crop_dim)
        else:
            appearance_shape = (len(frames), self.crop_dim, self.crop_dim, num_channels)
        occupancy_grid_shape = (len(frames), grid_dim, grid_dim, 1)

        # Initialize storage for the features
        appearances = np.zeros(appearance_shape, dtype=K.floatx())
        centroids = []
        perimeters = []
        occupancy_grids = np.zeros(occupancy_grid_shape, dtype=K.floatx())

        # Pad both spatial axes so a window near the border stays in bounds.
        # NOTE(review): this padding assumes a (rows, cols, channels) frame,
        # i.e. channels_last - confirm before relying on channels_first.
        spatial_pad = ((window, window), (window, window), (0, 0))

        for counter, (frame, cell_label) in enumerate(zip(frames, labels)):
            X_frame = X[:, frame] if channels_first else X[frame]
            y_frame = y[:, frame] if channels_first else y[frame]

            # Region properties of the cell in this frame
            props = regionprops(np.int32(y_frame == cell_label))
            minr, minc, maxr, maxc = props[0].bbox
            centroids.append(props[0].centroid)
            perimeters.append(np.array([props[0].perimeter]))

            # Crop the bounding box and resize it to the common crop size
            if channels_first:
                appearance = np.copy(X[:, frame, minr:maxr, minc:maxc])
                resize_shape = (num_channels, self.crop_dim, self.crop_dim)
                appearances[:, counter] = _resize_keeping_range(appearance, resize_shape)
            else:
                appearance = np.copy(X[frame, minr:maxr, minc:maxc, :])
                resize_shape = (self.crop_dim, self.crop_dim, num_channels)
                appearances[counter] = _resize_keeping_range(appearance, resize_shape)

            # Cut the window centred on the cell out of the padded frame
            X_padded = np.pad(X_frame, spatial_pad, mode='constant', constant_values=0)
            y_padded = np.pad(y_frame, spatial_pad, mode='constant', constant_values=0)
            props = regionprops(np.int32(y_padded == cell_label))
            center_x, center_y = props[0].centroid
            # Built-in int: the `np.int` alias no longer exists in recent NumPy.
            center_x, center_y = int(center_x), int(center_y)
            X_reduced = X_padded[center_x - window:center_x + window,
                                 center_y - window:center_y + window, :]

            # The resized image window is stored as the "occupancy grid"
            resize_shape = (grid_dim, grid_dim, num_channels)
            occupancy_grids[counter, :, :, :] = _resize_keeping_range(X_reduced, resize_shape)

        return [appearances, centroids, occupancy_grids, perimeters]

    def _create_features(self):
        """Pre-compute and cache the features of every track.

        For each track the appearance crops, centroids, occupancy grids and
        perimeters returned by `_get_features` are written into arrays indexed
        by (track, frame); frames in which the cell is absent stay zero. The
        arrays are stored as `self.all_appearances`, `self.all_centroids`,
        `self.all_occupancy_grids` and `self.all_perimeters`.
        """
        num_tracks = len(self.track_ids.keys())
        num_frames = self.x.shape[self.time_axis]
        num_channels = self.x.shape[self.channel_axis]
        crop = self.crop_dim
        grid_dim = 2 * self.occupancy_grid_size + 1

        # Allocate the caches
        if self.data_format == 'channels_first':
            all_appearances_shape = (num_tracks, num_channels, num_frames, crop, crop)
        if self.data_format == 'channels_last':
            all_appearances_shape = (num_tracks, num_frames, crop, crop, num_channels)
        all_appearances = np.zeros(all_appearances_shape, dtype=K.floatx())
        all_centroids = np.zeros((num_tracks, num_frames, 2), dtype=K.floatx())
        all_occupancy_grids = np.zeros((num_tracks, num_frames, grid_dim, grid_dim, 1),
                                       dtype=K.floatx())
        all_perimeters = np.zeros((num_tracks, num_frames, 1), dtype=K.floatx())

        for track, info in self.track_ids.items():
            batch = info['batch']
            frames = info['frames']
            # The same label is looked up in every frame of the track
            labels = [info['label']] * len(frames)

            appearance, centroid, occupancy_grid, perimeter = self._get_features(
                self.x[batch], self.y[batch], frames, labels)

            frame_index = np.array(frames)
            if self.data_format == 'channels_first':
                all_appearances[track, :, frame_index, :, :] = appearance
            if self.data_format == 'channels_last':
                all_appearances[track, frame_index, :, :, :] = appearance

            all_centroids[track, frame_index, :] = centroid
            all_occupancy_grids[track, frame_index, :, :] = occupancy_grid
            all_perimeters[track, frame_index, :] = perimeter

        self.all_appearances = all_appearances
        self.all_centroids = all_centroids
        self.all_occupancy_grids = all_occupancy_grids
        self.all_perimeters = all_perimeters

    def _fetch_appearances(self, track, frames):
        """Return the cached appearance crops of `track` for `frames`.

        The crops are read from `self.all_appearances`; the position of the
        frame index depends on `self.data_format`.
        """
        # TO DO: Check to make sure the frames are acceptable
        frame_index = np.array(frames)
        if self.data_format == 'channels_first':
            appearances = self.all_appearances[track, :, frame_index, :, :]
        if self.data_format == 'channels_last':
            appearances = self.all_appearances[track, frame_index, :, :, :]
        return appearances

    def _fetch_centroids(self, track, frames):
        """Return the cached centroids of `track` for `frames`.

        The values are read from `self.all_centroids`.
        """
        # TO DO: Check to make sure the frames are acceptable
        frame_index = np.array(frames)
        return self.all_centroids[track, frame_index, :]
    
    def _fetch_occupancy_grids(self, track, frames):
        """Return the cached occupancy grids of `track` for `frames`.

        The values are read from `self.all_occupancy_grids`.
        """
        # TO DO: Check to make sure the frames are acceptable
        frame_index = np.array(frames)
        return self.all_occupancy_grids[track, frame_index, :, :, :]
    
    def _fetch_perimeters(self, track, frames):
        """Return the cached perimeters of `track` for `frames`.

        The values are read from `self.all_perimeters`.
        """
        # TO DO: Check to make sure the frames are acceptable
        frame_index = np.array(frames)
        return self.all_perimeters[track, frame_index]

    def _fetch_frames(self, track, division=False):
        """Pick the `min_track_length` frames that form a track's sequence.

        Args:
            track: id of the track to sample from.
            division: if True, the sequence ends at the last usable tracked
                frame of the cell; otherwise the end frame is drawn at random
                from the frames that still have a later tracked frame.

        Returns:
            A list of `min_track_length` frame numbers. When the track is too
            short, its first frame is repeated at the front to fill the list.
        """
        track_id = self.track_ids[track]
        tracked_frames = list(track_id['frames'])

        # We need to have at least one future frame to pick from, so if
        # the last frame of the movie is a tracked frame, remove it
        last_frame = self.x.shape[self.time_axis] - 1
        if last_frame in tracked_frames:
            tracked_frames.remove(last_frame)

        # Work with positions in the tracked_frames list - sometimes frames
        # are skipped
        tracked_frames_index = np.arange(len(tracked_frames))

        # Check if there are enough frames
        enough_frames = len(tracked_frames_index) > self.min_track_length + 1

        if division:
            # End on the last tracked frame. (An index of -1 would make both
            # slices below stop at 0 and silently return no frames.)
            index = len(tracked_frames) - 1
        else:
            # Exclude the last position so a later frame is always available
            # for the comparison
            if enough_frames:
                acceptable_indices = tracked_frames_index[self.min_track_length - 1:-1]
            else:
                acceptable_indices = tracked_frames_index[:-1]
            index = np.random.choice(acceptable_indices)

        # Select the frames. If there aren't enough frames, repeat the first
        # frame the necessary number of times
        if enough_frames:
            frames = tracked_frames[index + 1 - self.min_track_length:index + 1]
        else:
            frames_temp = tracked_frames[0:index + 1]
            missing_frames = self.min_track_length - len(frames_temp)
            frames = [tracked_frames[0]] * missing_frames + frames_temp

        return frames
    
    def _compute_appearances(self, track_1, frames_1, track_2, frames_2, transform):
        """Fetch and augment the appearance crops of a sample pair.

        Args:
            track_1: track id of the sequence.
            frames_1: frames of the sequence.
            track_2: track id of the candidate cell.
            frames_2: frames of the candidate (only the first one is transformed).
            transform: transform parameters applied to every crop, or None to
                draw an independent random transform for each crop.

        Returns:
            Tuple `(new_appearance_1, new_appearance_2)` of transformed and
            standardized crops with the same shapes as the fetched crops.
        """
        appearance_1 = self._fetch_appearances(track_1, frames_1)
        appearance_2 = self._fetch_appearances(track_2, frames_2)

        new_appearance_1 = np.zeros(appearance_1.shape, dtype=K.floatx())
        new_appearance_2 = np.zeros(appearance_2.shape, dtype=K.floatx())

        generator = self.image_data_generator

        def _augment(image):
            # Shared transform when one is given, otherwise a fresh random
            # one; the result is always standardized.
            if transform is not None:
                image = generator.apply_transform(image, transform)
            else:
                # The transformed image must be kept: dropping this return
                # value left the crop unset (or stale) for channels_last data.
                image = generator.random_transform(image)
            return generator.standardize(image)

        for frame in range(appearance_1.shape[self.time_axis - 1]):
            if self.data_format == 'channels_first':
                new_appearance_1[:, frame, :, :] = _augment(appearance_1[:, frame, :, :])
            if self.data_format == 'channels_last':
                new_appearance_1[frame] = _augment(appearance_1[frame])

        # The candidate consists of a single frame
        if self.data_format == 'channels_first':
            new_appearance_2[:, 0, :, :] = _augment(appearance_2[:, 0, :, :])
        if self.data_format == 'channels_last':
            new_appearance_2[0] = _augment(appearance_2[0])

        return new_appearance_1, new_appearance_2

    def _compute_distances(self, track_1, frames_1, track_2, frames_2, transform):
        """Compute frame-to-frame centroid displacements for a sample pair.

        The centroids of the sequence and of the candidate are stacked and
        differenced along the frame axis, with a zero displacement prepended
        for the first frame. `transform` is accepted for signature parity with
        the other feature methods and is not used.

        Returns:
            Tuple `(distance_1, distance_2)`: every displacement row except the
            last one, and the last row (the step onto the candidate).
        """
        sequence_centroids = self._fetch_centroids(track_1, frames_1)
        candidate_centroids = self._fetch_centroids(track_2, frames_2)

        stacked = np.concatenate([sequence_centroids, candidate_centroids], axis=0)
        first_step = np.zeros((1, 2), dtype=K.floatx())
        steps = np.concatenate([first_step, np.diff(stacked, axis=0)], axis=0)

        return steps[0:-1, :], steps[-1, :]

    def _compute_perimeters(self, track_1, frames_1, track_2, frames_2, transform):
        """Return the cached perimeters of the sequence and of the candidate.

        `transform` is accepted for signature parity with the other feature
        methods and is not used.
        """
        return (self._fetch_perimeters(track_1, frames_1),
                self._fetch_perimeters(track_2, frames_2))
    
    def _compute_occupancy_grids(self, track_1, frames_1, track_2, frames_2, transform):
        """Fetch and augment the occupancy grids of a sample pair.

        Args:
            track_1: track id of the sequence.
            frames_1: frames of the sequence.
            track_2: track id of the candidate cell.
            frames_2: frames of the candidate.
            transform: transform parameters applied to every grid, or None to
                draw an independent random transform for each grid.

        Returns:
            Tuple `(occupancy_grid_1, occupancy_grid_2)`: every transformed
            grid except the last one, and the last grid.
        """
        occupancy_grid_1 = self._fetch_occupancy_grids(track_1, frames_1)
        occupancy_grid_2 = self._fetch_occupancy_grids(track_2, frames_2)

        occupancy_grids = np.concatenate([occupancy_grid_1, occupancy_grid_2], axis=0)

        # The stacked grids are indexed by frame on axis 0, so iterate over
        # that axis directly.
        for frame in range(occupancy_grids.shape[0]):
            og_temp = occupancy_grids[frame]
            if transform is not None:
                og_temp = self.image_data_generator.apply_transform(og_temp, transform)
            else:
                og_temp = self.image_data_generator.random_transform(og_temp)
            occupancy_grids[frame] = og_temp

        occupancy_grid_1 = occupancy_grids[0:-1, :, :, :]
        occupancy_grid_2 = occupancy_grids[-1, :, :, :]

        return occupancy_grid_1, occupancy_grid_2
    
    def _compute_feature_shape(self, feature, index_array):
        """Return the batch array shapes for one feature.

        Args:
            feature: one of "appearance", "distance", "neighborhood" or
                "perimeter".
            index_array: the sample indices of the batch; only its length is used.

        Returns:
            Tuple `(shape_1, shape_2)`: the shape of the sequence array and the
            shape of the candidate array.

        Raises:
            ValueError: if `feature` is not one of the known names.
        """
        if feature not in ("appearance", "distance", "neighborhood", "perimeter"):
            raise ValueError("_compute_feature_shape: Unknown feature '{}'".format(feature))

        batch_len = len(index_array)
        sequence_len = self.min_track_length

        if feature == "appearance":
            crop = self.crop_dim
            channels = self.x.shape[self.channel_axis]
            if self.data_format == 'channels_first':
                # NOTE(review): the candidate shape has no time axis here,
                # unlike the channels_last case - confirm this is intended.
                return ((batch_len, channels, sequence_len, crop, crop),
                        (batch_len, channels, crop, crop))
            return ((batch_len, sequence_len, crop, crop, channels),
                    (batch_len, 1, crop, crop, channels))

        if feature == "distance":
            return (batch_len, sequence_len, 2), (batch_len, 1, 2)

        if feature == "neighborhood":
            grid_dim = 2 * self.occupancy_grid_size + 1
            return ((batch_len, sequence_len, grid_dim, grid_dim, 1),
                    (batch_len, 1, grid_dim, grid_dim, 1))

        # Only "perimeter" is left
        return (batch_len, sequence_len, 1), (batch_len, 1, 1)
    
    def _compute_feature(self, feature, *args, **kwargs):
        """Dispatch to the `_compute_*` method that matches `feature`.

        All remaining arguments are forwarded unchanged and the result of the
        selected method is returned.

        Raises:
            ValueError: if `feature` is not one of the known names.
        """
        dispatch = (
            ("appearance", self._compute_appearances),
            ("distance", self._compute_distances),
            ("neighborhood", self._compute_occupancy_grids),
            ("perimeter", self._compute_perimeters),
        )
        for name, handler in dispatch:
            if feature == name:
                return handler(*args, **kwargs)
        raise ValueError("_compute_feature: Unknown feature '{}'".format(feature))

    def _get_batches_of_transformed_samples(self, index_array):
        """Assemble one batch of (sequence, candidate) comparisons.

        For every index a class is drawn - different (0), same (1) or
        division (2) - and a pair is built: a sequence of `min_track_length`
        frames of one cell, and a single frame of a candidate cell.

        Args:
            index_array: track ids selected for this batch.

        Returns:
            Tuple `(batch_list, batch_y)`. `batch_list` holds, for each feature
            in `self.features` (in order), the sequence array followed by the
            candidate array. `batch_y` is a one-hot int32 array of shape
            `(len(index_array), 3)` over the three classes.
        """
        # Zeroed (sequence, candidate) array pair for each feature
        batch_features = []
        for feature in self.features:
            shape_1, shape_2 = self._compute_feature_shape(feature, index_array)
            batch_features.append([np.zeros(shape_1, dtype=K.floatx()),
                                   np.zeros(shape_2, dtype=K.floatx())])

        batch_y = np.zeros((len(index_array), 3), dtype=np.int32)

        for i, j in enumerate(index_array):
            # Identify which track is going to be used for the sequence
            track_id = self.track_ids[j]
            batch = track_id['batch']
            label_1 = track_id['label']

            # Draw the class: different (0), same (1), division (2)
            type_cell = np.random.choice([0, 1, 2], p=[1/3, 1/3, 1/3])

            # A division sample needs a track that divides
            if type_cell == 2 and len(track_id['daughters']) == 0:
                if len(self.tracks_with_divisions) > 0:
                    # Switch to a randomly chosen track that has daughters
                    j = np.random.choice(self.tracks_with_divisions)
                    track_id = self.track_ids[j]
                    batch = track_id['batch']
                    label_1 = track_id['label']
                else:
                    # No track divides at all: fall back to a "same" sample
                    # instead of sampling from an empty list.
                    type_cell = 1
            division = type_cell == 2

            if division:
                # The sequence must end on the parent's final tracked frame.
                # It is built here rather than via `_fetch_frames` so that
                # guarantee does not depend on how that method treats its
                # `division` flag. Short tracks are padded with their first frame.
                parent_frames = list(track_id['frames'])
                start = max(len(parent_frames) - self.min_track_length, 0)
                frames_1 = parent_frames[start:]
                missing_frames = self.min_track_length - len(frames_1)
                frames_1 = [parent_frames[0]] * missing_frames + frames_1

                # The candidate is a daughter in the first frame it appears in
                label_2 = int(np.random.choice(track_id['daughters']))
                daughter_track = self.reverse_track_ids[batch][label_2]
                frames_2 = [np.amin(self.track_ids[daughter_track]['frames'])]
            else:
                frames_1 = self._fetch_frames(j, division=False)

                # The candidate frame is the next frame the cell appears in
                last_frame_1 = np.amax(frames_1)
                frame_2 = np.amin([x for x in track_id['frames'] if x > last_frame_1])
                frames_2 = [frame_2]

                if type_cell == 0:
                    different_cells = track_id['different'][frame_2]
                    if len(different_cells) == 0:
                        # No other cell in that frame: only "same" is possible
                        type_cell = 1
                    else:
                        label_2 = np.random.choice(different_cells)

                if type_cell == 1:
                    label_2 = label_1

            track_1 = j
            track_2 = self.reverse_track_ids[batch][label_2]

            # Compute the requested features and store them in the batch arrays
            if self.sync_transform:
                # One random angle and flip pair shared by all features
                transform = {"theta": 360 * np.random.random(),
                             "flip_horizontal": np.random.random() < 0.5,
                             "flip_vertical": np.random.random() < 0.5}
            else:
                transform = None

            for feature_i, feature in enumerate(self.features):
                feature_1, feature_2 = self._compute_feature(feature,
                                                             track_1, frames_1,
                                                             track_2, frames_2,
                                                             transform=transform)
                batch_features[feature_i][0][i] = feature_1
                batch_features[feature_i][1][i] = feature_2

            batch_y[i, type_cell] = 1

        # Prepare the final batch list
        batch_list = []
        for feature_i, feature in enumerate(self.features):
            batch_feature_1, batch_feature_2 = batch_features[feature_i]
            # Remove singleton dimensions (if min_track_length is 1)
            if self.squeeze:
                squeeze_axis = self.time_axis if feature == "appearance" else 1
                batch_feature_1 = np.squeeze(batch_feature_1, axis=squeeze_axis)
                batch_feature_2 = np.squeeze(batch_feature_2, axis=squeeze_axis)

            batch_list.append(batch_feature_1)
            batch_list.append(batch_feature_2)

        return batch_list, batch_y

    def __next__(self):
        """Return the next batch as `(batch_list, batch_y)`.

        Only the step that advances the index generator runs while holding
        `self.lock`; building the batch happens outside the lock.
        """
        with self.lock:
            batch_indices = next(self.index_generator)
        return self._get_batches_of_transformed_samples(batch_indices)


### Data Generator Tests

In [None]:
# Load the training and validation dictionaries in 'siamese_daughters' mode.
train_dict, val_dict = get_data('/data/npz_data/cells/HeLa/S3/movie/nuclear_movie_hela0-7_same.npz',
                                mode='siamese_daughters')

print('X_train shape:', train_dict['X'].shape)
print('daughter shape: ', train_dict['daughters'].shape)
# This prints one entry of the daughter lookup, not a shape, so label it as such.
print('first daughter entry: ', train_dict['daughters'][0][0])

In [None]:
# Augmentation settings: rotations of up to 180 degrees; shearing and both
# flips are switched off.
image_data_generator = SiameseDataGenerator(
    rotation_range=180,
    shear_range=0,
    horizontal_flip=False,
    vertical_flip=False)

# The iterator sorts the feature names, so each batch holds the appearance
# arrays first, then distance, then neighborhood (two arrays per feature).
test_iterator = SiameseIterator(
    train_dict,
    image_data_generator,
    crop_dim=32,
    min_track_length=5,
    occupancy_grid_size=40,
    features={"appearance", "distance", "neighborhood"})

In [None]:
import matplotlib.pyplot as plt
from skimage.io import imshow

# Show two raw frames from the training data side by side
# (indices [1, 0] and [0, 1] of the first two axes, channel 0).
img_1 = train_dict['X'][1,0,:,:,0]
img_2 = train_dict['X'][0,1,:,:,0]

fig, ax = plt.subplots(1, 2, figsize=(12,12))
for panel, picture in zip(ax, (img_1, img_2)):
    panel.imshow(picture, interpolation='none', cmap='gray')
plt.show()

In [None]:
# Pull one batch: a list of feature arrays and the matching labels.
lst, y = next(test_iterator)

In [None]:
# Number of arrays in the batch, followed by the shape of the first six
# (two arrays are appended per requested feature).
print(len(lst))
for position in range(6):
    print(lst[position].shape)

In [None]:
# Compare one frame from each of the first two arrays in the batch
# (sample 4 in both cases).
img_1 = lst[0][4,4,:,:,0]
img_2 = lst[1][4,0,:,:,0]
fig, ax = plt.subplots(1, 2, figsize=(12,12))
ax[0].imshow(img_1, interpolation='none', cmap='gray')
ax[1].imshow(img_2, interpolation='none', cmap='gray')
plt.show()

# Plot the first frame of lst[4] for every sample whose third label column
# is set (presumably the "daughter" class -- confirm against the iterator).
# Iterate over the actual batch length instead of a hard-coded 32 so the
# cell also works when the batch size differs.
for i in range(y.shape[0]):
    if y[i,2] == 1:
        plt.imshow(lst[4][i,0,:,:,0])
        plt.show()

In [None]:
# Distance (centroid) values for sample 3, plus a spot check of lst[4]
print(lst[2][3,:])
plt.imshow(lst[4][18,4,:,:,0])
print(np.round(lst[4][0,4,10,10,0], decimals=2))

# Take sample 5 of lst[4] twice: once as the transform input and once as
# the reference that is plotted untouched.
img = lst[4][5,:,:,:,:]
img_old = lst[4][5,:,:,:,:]

# Apply an independent random transform to every frame of the stack.
img_new = np.zeros(img.shape)
datagen = ImageDataGenerator(rotation_range=30)
for frame, img_frame in enumerate(img):
    print(img_frame.shape)
    img_new[frame] = datagen.random_transform(img_frame)

# First frame before and after the transform
for stack in (img_old, img_new):
    plt.imshow(stack[0,:,:,0])
    plt.show()

In [None]:
# Check the labels: dump the label array `y` of the batch fetched above
# with next(test_iterator).
print(y)

### Model Zoo

In [None]:
#Import statements for model_zoo.py
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.activations import softmax
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Conv2D, Conv3D, ConvLSTM2D, LSTM
from tensorflow.python.keras.layers import Add, Input, Concatenate, Lambda, InputLayer
from tensorflow.python.keras.layers import MaxPool2D, MaxPool3D, AvgPool2D, UpSampling2D
from tensorflow.python.keras.layers import Flatten, Dense, Dropout, Reshape
from tensorflow.python.keras.layers import Activation, Softmax
from tensorflow.python.keras.layers import BatchNormalization
from tensorflow.python.keras.regularizers import l2

from deepcell.layers import Resize
from deepcell.layers import DilatedMaxPool2D, DilatedMaxPool3D
from deepcell.layers import TensorProd2D, TensorProd3D
from deepcell.layers import Location, Location3D
from deepcell.layers import ImageNormalization2D, ImageNormalization3D

#Siamese Model
def siamese_model(
    input_shape=None,
    track_length=1,
    features=None,
    occupancy_grid_size=10,
    reg=1e-5, init='he_normal',
    softmax=True,
    norm_method='std',
    filter_size=61):
    """Build the siamese tracking model over a set of named features.

    For every feature (processed in sorted name order) two ``Input`` tensors
    are created. Features with a convolutional extractor ("appearance",
    "neighborhood") pass both inputs through the same extractor. The first
    input is then reduced with ``LSTM(64)`` and the second with a ``Reshape``;
    the pair is concatenated and fed to Dense(128) -> BatchNorm -> ReLU.
    The per-feature results are concatenated and followed by two more
    Dense(128) -> BatchNorm -> ReLU blocks and a ``Dense(3, softmax)`` layer.

    Args:
        input_shape: Shape of one appearance input without the time axis
            (the training script passes ``(32, 32, 1)``). A ``None`` axis is
            inserted according to ``K.image_data_format()``. Required.
        track_length: Currently unused by this function.
        features: Iterable of feature names; supported names are
            "appearance", "distance", "neighborhood" and "perimeter".
        occupancy_grid_size: The neighborhood input is a square of side
            ``2 * occupancy_grid_size + 1``.
        reg: L2 regularization factor for the convolution kernels.
        init: Kernel initializer for the convolution layers.
        softmax: Currently unused; the final layer always applies softmax.
        norm_method: Currently unused by this function.
        filter_size: Currently unused by this function.

    Returns:
        An uncompiled ``Model``. Its inputs are ordered by sorted feature
        name, two inputs per feature.

    Raises:
        ValueError: If ``features`` is None or contains an unknown name.
    """

    def compute_input_shape(feature):
        """Return the per-sample input shape (leading time axis) of a feature."""
        if feature == "appearance":
            return input_shape
        elif feature == "distance":
            return (None, 2)
        elif feature == "neighborhood":
            return (None, 2 * occupancy_grid_size + 1, 2 * occupancy_grid_size + 1, 1)
        elif feature == "perimeter":
            return (None, 1)
        else:
            raise ValueError("siamese_model.compute_input_shape: Unknown feature '{}'".format(feature))

    def compute_reshape(feature):
        """Return the target shape used to flatten the second input of a feature."""
        if feature == "appearance":
            return (64,)
        elif feature == "distance":
            return (2,)
        elif feature == "neighborhood":
            return (64,)
        elif feature == "perimeter":
            return (1,)
        else:
            raise ValueError("siamese_model.compute_reshape: Unknown feature '{}'".format(feature))

    def build_conv_extractor(shape, n_layers):
        """Stack n_layers of Conv3D/BatchNorm/ReLU/MaxPool3D, then reshape to (-1, 64)."""
        extractor = Sequential()
        extractor.add(InputLayer(input_shape=shape))
        for _ in range(n_layers):
            # (1, 3, 3) kernels and (1, 2, 2) pooling leave the first
            # (time) axis of the input untouched.
            extractor.add(Conv3D(64, (1, 3, 3),
                                 kernel_initializer=init,
                                 padding='same',
                                 kernel_regularizer=l2(reg)))
            extractor.add(BatchNormalization(axis=channel_axis))
            extractor.add(Activation('relu'))
            extractor.add(MaxPool3D(pool_size=(1, 2, 2)))

        extractor.add(Reshape((-1, 64)))
        return extractor

    def compute_feature_extractor(feature, shape):
        """Return the shared extractor for a feature, or None if it has none."""
        if feature == "appearance":
            # This should not stay: channels_first/last should be used to dictate size.
            # NOTE(review): at this point input_shape already contains the
            # inserted None axis. Index 1 is the row size for channels_last,
            # but for channels_first it is the None axis itself -- confirm
            # and pick the axis from the data format.
            # Builtin int replaces np.int, which was removed in NumPy 1.24.
            n_layers = int(np.floor(np.log2(input_shape[1])))
            return build_conv_extractor(shape, n_layers)
        elif feature == "distance":
            return None
        elif feature == "neighborhood":
            n_layers = int(np.floor(np.log2(2 * occupancy_grid_size + 1)))
            return build_conv_extractor(shape, n_layers)
        elif feature == "perimeter":
            return None
        else:
            raise ValueError("siamese_model.compute_feature_extractor: Unknown feature '{}'".format(feature))

    if features is None:
        raise ValueError("siamese_model: No features specified.")

    # Insert the variable-length axis next to the channel axis convention.
    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        input_shape = (input_shape[0], None, *input_shape[1:])
    else:
        channel_axis = -1
        input_shape = (None, *input_shape)

    # Sorting fixes the order of the model inputs regardless of how the
    # caller passed the features (e.g. as a set).
    features = sorted(features)

    inputs = []
    outputs = []
    for feature in features:
        in_shape = compute_input_shape(feature)
        re_shape = compute_reshape(feature)
        feature_extractor = compute_feature_extractor(feature, in_shape)

        layer_1 = Input(shape=in_shape)
        layer_2 = Input(shape=in_shape)

        inputs.extend([layer_1, layer_2])

        # apply feature_extractor if it exists
        if feature_extractor is not None:
            layer_1 = feature_extractor(layer_1)
            layer_2 = feature_extractor(layer_2)

        # LSTM on 'left' side of network since that side takes in stacks of features
        layer_1 = LSTM(64)(layer_1)
        layer_2 = Reshape(re_shape)(layer_2)

        outputs.append([layer_1, layer_2])

    dense_merged = []
    for layer_1, layer_2 in outputs:
        merge = Concatenate(axis=channel_axis)([layer_1, layer_2])
        dense_merge = Dense(128)(merge)
        bn_merge = BatchNormalization(axis=channel_axis)(dense_merge)
        dense_relu = Activation('relu')(bn_merge)
        dense_merged.append(dense_relu)

    # Concatenate outputs from both instances
    merged_outputs = Concatenate(axis=channel_axis)(dense_merged)

    # Add dense layers
    dense1 = Dense(128)(merged_outputs)
    bn1 = BatchNormalization(axis=channel_axis)(dense1)
    relu1 = Activation('relu')(bn1)
    dense2 = Dense(128)(relu1)
    bn2 = BatchNormalization(axis=channel_axis)(dense2)
    relu2 = Activation('relu')(bn2)
    dense3 = Dense(3, activation='softmax')(relu2)

    # Instantiate model
    final_layer = dense3
    model = Model(inputs=inputs, outputs=final_layer)

    return model

### Training the Model

In [None]:
# Import Statements for training.py
from skimage.external import tifffile as tiff
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import to_categorical as keras_to_categorical
from tensorflow.python.keras.callbacks import ModelCheckpoint, LearningRateScheduler

from deepcell import losses
from deepcell.losses import weighted_categorical_crossentropy
from deepcell.utils.io_utils import get_images_from_directory
from deepcell.utils.train_utils import rate_scheduler
from deepcell.utils.transform_utils import to_categorical
from deepcell.settings import CHANNELS_FIRST

# from deepcell import image_generators as generators

# Training function

def train_model_siamese_daughter(model=None,
                                 dataset=None,
                                 optimizer=None,
                                 min_track_length=1,
                                 features=None,
                                 expt='',
                                 it=0, batch_size=1, n_epoch=100,
                                 direc_save='/data/models', direc_data='/data/npz_data',
                                 focal=False,
                                 gamma=0.5,
                                 lr_sched=rate_scheduler(lr=0.01, decay=0.95),
                                 rotation_range=0, flip=True, shear=0, class_weight=None):
    """Compile and train a siamese model on a ``.npz`` dataset.

    The dataset ``<direc_data>/<dataset>.npz`` is loaded with
    ``get_data(..., mode='siamese_daughters')``, the model is compiled with a
    weighted categorical cross-entropy (or weighted focal loss), and trained
    with ``fit_generator`` on batches from ``SiameseDataGenerator``.

    Args:
        model: Keras model to train. Required.
        dataset: Dataset file name without the ``.npz`` extension. Required.
        optimizer: Optimizer passed to ``model.compile``.
        min_track_length: Forwarded to the data generators.
        features: Iterable of feature names, forwarded to the data
            generators; the first letter of each name (sorted) becomes part
            of the output file names. Required.
        expt: Experiment tag used in the output file names.
        it: Iteration tag used in the output file names.
        batch_size: Batch size for both training and validation generators.
        n_epoch: Number of training epochs.
        direc_save: Directory receiving the ``.h5`` and ``.npz`` outputs.
        direc_data: Directory containing the dataset file.
        focal: If True use ``losses.weighted_focal_loss``, otherwise
            ``losses.weighted_categorical_crossentropy``.
        gamma: Gamma parameter of the focal loss.
        lr_sched: Schedule function passed to ``LearningRateScheduler``.
        rotation_range: Rotation range for training-time augmentation.
        flip: Enables horizontal and vertical flips during training.
        shear: Shear range for training-time augmentation.
        class_weight: Accepted for compatibility; not used by this function.

    Returns:
        The trained model (the same object passed in as ``model``).

    Raises:
        ValueError: If ``model``, ``dataset`` or ``features`` is None.
    """
    # Fail early with a clear message instead of a late TypeError or
    # AttributeError after the dataset has already been loaded.
    if model is None:
        raise ValueError("train_model_siamese_daughter: No model specified.")
    if dataset is None:
        raise ValueError("train_model_siamese_daughter: No dataset specified.")
    if features is None:
        raise ValueError("train_model_siamese_daughter: No features specified.")

    is_channels_first = K.image_data_format() == 'channels_first'
    training_data_file_name = os.path.join(direc_data, dataset + '.npz')
    todays_date = datetime.datetime.now().strftime('%Y-%m-%d')

    # e.g. "[a,d,n,p]_n_epoch=5": first letter of every feature, sorted.
    features_fmt = '[' + ','.join(f[0] for f in sorted(features)) + ']'
    features_fmt += '_n_epoch={}'.format(n_epoch)

    file_name_save = os.path.join(direc_save, '{}_{}_{}_{}_{}.h5'.format(todays_date, dataset, features_fmt, expt, it))
    file_name_save_loss = os.path.join(direc_save, '{}_{}_{}_{}_{}.npz'.format(todays_date, dataset, features_fmt, expt, it))

    print("saving model at:", file_name_save)
    print("saving loss at:", file_name_save_loss)

    train_dict, val_dict = get_data(training_data_file_name, mode='siamese_daughters')
    #train_dict, val_dict = get_data(training_data_file_name, mode='siamese_daughters', test_size=.2)

    # the data, shuffled and split between train and test sets
    print('X_train shape:', train_dict['X'].shape)
    print('y_train shape:', train_dict['y'].shape)
    print('X_test shape:', val_dict['X'].shape)
    print('y_test shape:', val_dict['y'].shape)
    print('Output Shape:', model.layers[-1].output_shape)

    n_classes = model.layers[-1].output_shape[1 if is_channels_first else -1]

    def loss_function(y_true, y_pred):
        """Weighted focal loss or weighted categorical cross-entropy, per ``focal``."""
        if focal:
            return losses.weighted_focal_loss(y_true, y_pred,
                                              gamma=gamma,
                                              n_classes=n_classes,
                                              from_logits=False)
        else:
            return losses.weighted_categorical_crossentropy(y_true, y_pred,
                                                            n_classes=n_classes,
                                                            from_logits=False)

    model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])

    print('Using real-time data augmentation.')

    # this will do preprocessing and realtime data augmentation
    datagen = SiameseDataGenerator(
        rotation_range=rotation_range,  # randomly rotate images by 0 to rotation_range degrees
        shear_range=shear,  # randomly shear images in the range (radians , -shear_range to shear_range)
        horizontal_flip=flip,  # randomly flip images
        vertical_flip=flip)  # randomly flip images

    # Validation data is served without any augmentation.
    datagen_val = SiameseDataGenerator(
        rotation_range=0,  # randomly rotate images by 0 to rotation_range degrees
        shear_range=0,  # randomly shear images in the range (radians , -shear_range to shear_range)
        horizontal_flip=0,  # randomly flip images
        vertical_flip=0)  # randomly flip images

    def count_pairs(y):
        """
        Compute number of training samples needed to (statistically speaking)
        observe all cell pairs.
        Assume that the number of images is encoded in the second dimension.
        Assume that y values are a cell-uniquely-labeled mask.
        Assume that a cell is paired with one of its other frames 50% of the time
        and a frame from another cell 50% of the time.
        """
        # TODO: channels_first axes
        total_pairs = 0
        for image_set in range(y.shape[0]):
            # The maximum label in a frame is taken as its cell count.
            cells_per_image = [int(y[image_set, image, :, :, :].max())
                               for image in range(y.shape[1])]

            # Since there are many more possible non-self pairings than there are self pairings,
            # we want to estimate the number of possible non-self pairings and then multiply
            # that number by two, since the odds of getting a non-self pairing are 50%, to
            # find out how many pairs we would need to sample to (statistically speaking)
            # observe all possible cell-frame pairs.
            # We're going to assume that the average cell is present in every frame. This will
            # lead to an underestimate of the number of possible non-self pairings, but it's
            # unclear how significant the underestimate is.
            average_cells_per_frame = int(sum(cells_per_image) / len(cells_per_image))
            non_self_cellframes = (average_cells_per_frame - 1) * len(cells_per_image)
            non_self_pairings = non_self_cellframes * max(cells_per_image)
            cell_pairings = non_self_pairings * 2
            total_pairs = total_pairs + cell_pairings
        return total_pairs

    # This shouldn't remain long term.
    total_train_pairs = count_pairs(train_dict['y'])
    total_test_pairs = count_pairs(val_dict['y'])

    # The pair estimate can be smaller than batch_size (or even negative for
    # nearly empty movies); always run at least one step per epoch.
    steps_per_epoch = max(1, total_train_pairs // batch_size)
    validation_steps = max(1, total_test_pairs // batch_size)

    print("total_train_pairs:", total_train_pairs)
    print("total_test_pairs:", total_test_pairs)
    print("batch size: ", batch_size)
    print("validation_steps: ", validation_steps)

    # fit the model on the batches generated by datagen.flow()
    loss_history = model.fit_generator(
        datagen.flow(train_dict,
                     batch_size=batch_size,
                     min_track_length=min_track_length,
                     features=features),
        steps_per_epoch=steps_per_epoch,
        epochs=n_epoch,
        validation_data=datagen_val.flow(val_dict,
                                         batch_size=batch_size,
                                         min_track_length=min_track_length,
                                         features=features),
        validation_steps=validation_steps,
        callbacks=[
            ModelCheckpoint(file_name_save, monitor='val_loss', verbose=1, save_best_only=True, mode='auto'),
            LearningRateScheduler(lr_sched)
        ])

    # NOTE(review): this writes the final-epoch weights to the same path the
    # ModelCheckpoint callback used for the best model, replacing it.
    # Confirm whether the best checkpoint should be kept instead.
    model.save_weights(file_name_save)
    np.savez(file_name_save_loss, loss_history=loss_history.history)

    return model

### Training script

In [None]:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.keras.models import Sequential, Model

from deepcell import rate_scheduler

# `train_model_siamese_daughter` and `siamese_model` are defined in earlier
# cells of this notebook. Importing the packaged versions here would rebind
# both names and train with the library code instead of the notebook code,
# so these imports stay disabled.
#from deepcell.training import train_model_siamese_daughter
#from deepcell.model_zoo import siamese_model

direc_data = '/data/npz_data/cells/HeLa/S3/movie/'
dataset = 'nuclear_movie_hela0-7_same'

#direc_data = '/data/npz_data/cells/3T3/NIH/'
#dataset = 'nuclear_movie_3t3_set1_same'

training_data = np.load('{}{}.npz'.format(direc_data, dataset))

# Single source of truth for the feature set, so the model inputs and the
# batches produced during training cannot drift apart.
tracking_features = {"appearance", "distance", "neighborhood", "perimeter"}

optimizer = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
lr_sched = rate_scheduler(lr=0.001, decay=0.99)
in_shape = (32, 32, 1)
model = siamese_model(input_shape=in_shape, features=tracking_features)

tracking_model = train_model_siamese_daughter(model=model,
                                              dataset=dataset,
                                              optimizer=optimizer,
                                              expt='transform_sync_no_diff',
                                              it=0,
                                              batch_size=128,
                                              min_track_length=6,
                                              features=tracking_features,
                                              n_epoch=5,
                                              direc_save='/data/models/cells/HeLa/S3',
                                              direc_data=direc_data,
                                              lr_sched=lr_sched,
                                              rotation_range=180,
                                              flip=True,
                                              shear=0,
                                              class_weight=None)

In [None]:
# Reload the training history that train_model_siamese_daughter stored
# with np.savez under the key "loss_history".
filename = "/data/models/cells/HeLa/S3/2018-10-12_nuclear_movie_hela0-7_same_[a,d,n,p]_n_epoch=5_transform_sync_no_diff_0.npz"
# The stored history is a Python dict, which NumPy saves as a pickled
# object array; NumPy >= 1.16.3 refuses to read it unless allow_pickle=True.
# Unpickling can execute arbitrary code, so only load files produced by
# your own training runs this way.
loss_history = np.load(filename, allow_pickle=True)
loss_history["loss_history"]