### Global Imports

In [None]:
import os
import datetime
import errno
import argparse

import numpy as np
import tensorflow as tf

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizers import SGD, Adam
from tensorflow.python.keras.models import Sequential, Model

from deepcell import get_data
from deepcell import make_training_data
from deepcell import rate_scheduler
from deepcell.model_zoo import siamese_model
from deepcell.training import train_model_siamese_daughter

### Training the Model

In [None]:
from deepcell import rate_scheduler
from deepcell.model_zoo import siamese_model
from deepcell.training import train_model_siamese_daughter

# Directory holding the training archive and the archive's base name
# (the '.npz' suffix is appended when the path is built below).
direc_data = '/data/npz_data/cells/HeLa/S3/movie/'
dataset = 'nuclear_movie_hela0-7_same'

# Alternative dataset, kept for quick switching:
#direc_data = '/data/npz_data/cells/3T3/NIH/'
#dataset = 'nuclear_movie_3t3_set1_same'

# NOTE(review): `training_data` is not referenced again in this cell; the
# trainer below receives `direc_data`/`dataset` instead. Confirm whether this
# load is still needed beyond failing early when the file is missing.
training_data = np.load('{}{}.npz'.format(direc_data, dataset))

# NOTE(review): the optimizer is created with lr=0.01 while the schedule starts
# at lr=0.001 -- confirm which of the two is meant to take effect.
optimizer = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
lr_sched = rate_scheduler(lr=0.001, decay=0.99)

# Network input shape and the set of feature names passed to both the model
# constructor and the training routine.
in_shape = (32, 32, 1)
features = {"appearance", "distance", "neighborhood", "perimeter"}

model = siamese_model(input_shape=in_shape, features=features)

# Train the siamese model; the returned value is kept as `tracking_model`.
tracking_model = train_model_siamese_daughter(
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    expt='transform_sync',
    it=0,
    batch_size=128,
    min_track_length=6,
    features=features,
    n_epoch=5,
    direc_save='/data/models/cells/HeLa/S3',
    direc_data=direc_data,
    lr_sched=lr_sched,
    rotation_range=180,
    flip=True,
    shear=0,
    class_weight=None,
)

### Data Review

In [None]:
# Open the training archive for manual inspection.
data = np.load('/data/npz_data/cells/HeLa/S3/movie/nuclear_movie_hela0-7_same.npz')
#data = np.load('/data/npz_data/cells/HeLa/S3/movie/nuclear_movie_HeLa_0_same.npz')

# A bare `data.keys()` in the middle of a cell is evaluated and discarded
# (only the last expression of a cell is displayed), so print the keys
# explicitly to actually see which arrays the archive contains.
print('Keys:', list(data.keys()))

# Pull the 'X' and 'y' entries out of the archive; indexing with an empty
# tuple leaves an N-d array's contents unchanged.
data_readable_X, data_readable_y = data['X'][()], data['y'][()]
print('X Shape:', data_readable_X.shape)
print('y Shape:', data_readable_y.shape)

In [None]:
import matplotlib.pyplot as plt
from skimage.io import imshow

# Take one 2-D slice from each array using the same indices.
# NOTE(review): assumes the axes are (batch, frame, row, col, channel) -- confirm.
img_raw = data_readable_X[0, 0, :, :, 0]
img_ann = data_readable_y[0, 0, :, :, 0]

# Draw both slices side by side, each with its own title.
fig, ax = plt.subplots(1, 2, figsize=(12, 12))
panels = (
    (img_raw, 'Contrast (or Raw) Images'),
    (img_ann, 'Annotated Images'),
)
for axis, (frame, title) in zip(ax, panels):
    axis.imshow(frame, interpolation='none', cmap='gray')
    axis.set_title(title)
plt.show()

In [None]:
# Read the archive through deepcell's loader, which returns a pair that is
# unpacked here into `train_dict` and `val_dict`.
train_dict, val_dict = get_data(
    '/data/npz_data/cells/HeLa/S3/movie/nuclear_movie_hela0-7_same.npz',
    mode='siamese_daughters',
)

In [None]:
import matplotlib.pyplot as plt
from skimage.io import imshow

# Compare 2 images: the same slice taken from the 'X' and 'y' training arrays.
img_1 = train_dict['X'][0, 1, :, :, 0]
img_2 = train_dict['y'][0, 1, :, :, 0]

# One subplot per image, identical rendering options for both.
fig, ax = plt.subplots(1, 2, figsize=(12, 12))
for axis, frame in zip(ax, (img_1, img_2)):
    axis.imshow(frame, interpolation='none', cmap='gray')
plt.show()

### Tracking Test

In [None]:
# Import the tracking function
from deepcell.tracking import cell_tracker

# Load up data to test.
# This reads the same archive, with the same mode, as the Data Review section.
train_dict, val_dict = get_data(
    '/data/npz_data/cells/HeLa/S3/movie/nuclear_movie_hela0-7_same.npz',
    mode='siamese_daughters',
)

In [None]:
from deepcell.model_zoo import siamese_model

# Import the tracking model: where the saved weights live and how the
# network has to be configured to accept them.
MODEL_DIR = '/data/models'
PREFIX = 'cells/HeLa/S3'
in_shape = (32, 32, 1)
features = {"perimeter", "appearance", "neighborhood", "distance"}

# Full path of the weights file to restore.
# NOTE(review): the file name is hard-coded (date, epoch count, trailing
# index) -- confirm it matches what the training cell actually wrote.
siamese_weights_file = os.path.join(
    MODEL_DIR,
    PREFIX,
    '2018-10-10_nuclear_movie_hela0-7_same_[a,d,n,p]_n_epoch=5_transform_sync_1.h5',
)

# Re-instantiate the architecture, then load the stored weights into it.
tracking_model = siamese_model(input_shape=in_shape, features=features)
tracking_model.load_weights(siamese_weights_file)

In [None]:
# Index (first axis of the training arrays) of the movie to track.
batch = 2

# Build the tracker from that movie's 'X' and 'y' entries plus the restored
# siamese model, then run the tracking pass.
trial = cell_tracker(
    train_dict['X'][batch],
    train_dict['y'][batch],
    tracking_model,
    max_distance=1000,
    track_length=5,
    division=0.5,
    birth=0.9,
    death=0.9,
    features=features,
)
trial._track_cells()

### Visualizing the Result

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

def get_js_video(images, batch=0, channel=0):
    """Render one movie from an image stack as an interactive JS animation.

    Args:
        images: array indexed as ``images[batch, frame, row, col, channel]``;
            the second axis is stepped through as the animation frames.
        batch: index along the first axis selecting which movie to show.
        channel: index along the last axis selecting which channel to show.

    Returns:
        An ``IPython.display.HTML`` object wrapping the animation's
        JavaScript/HTML player.
    """
    fig = plt.figure()
    ims = []
    for i in range(images.shape[1]):
        im = plt.imshow(images[batch, i, :, :, channel], animated=True, cmap='jet', vmin=0, vmax=15)
        ims.append([im])

    # Build the animation once, after every frame has been collected.
    # Constructing it inside the loop created (and threw away) a new
    # animation object for every single frame.
    ani = animation.ArtistAnimation(fig, ims, interval=75, repeat_delay=1000)
    html = ani.to_jshtml()

    # The player markup is already rendered; close the figure so the notebook
    # does not additionally display a static copy and the figure is released.
    plt.close(fig)
    return HTML(html)

In [None]:
# Animate the 'y' array of the training data for the movie selected by `batch`.
get_js_video(train_dict['y'], batch=batch)

In [None]:
# Animate the tracker output. `get_js_video` indexes its input with five
# indices, so a leading axis of length 1 is added and the default batch=0
# selects it.
get_js_video(np.expand_dims(trial.y_tracked, axis=0))

### Save the Raw and Tracked Output

In [None]:
# Channel (last axis) written to disk for both the raw and tracked images.
channel = 0

# Export every available frame instead of a hard-coded 40: a fixed count
# raises IndexError on shorter movies and silently truncates longer ones.
n_frames = min(train_dict['X'].shape[1], trial.y_tracked.shape[0])

# One raw/tracked PNG pair per frame, written to the current working directory.
for i in range(n_frames):
    name_raw = 'test_raw_{}_.png'.format(i)
    name_tracked = 'test_tracked_{}_.png'.format(i)
    plt.imsave(name_raw, train_dict['X'][batch, i, :, :, channel], cmap='gray')
    plt.imsave(name_tracked, trial.y_tracked[i, :, :, channel], cmap='jet', vmin=0, vmax=15)


In [None]:
# Show the dimensions of the tracker's output array.
print(trial.y_tracked.shape)

In [None]:
# Print one list per track attribute, each in the mapping's iteration order,
# followed by the track IDs themselves so the lists can be lined up.
for field in ('parent', 'daughters', 'label', 'capped'):
    print([track[field] for track in trial.tracks.values()])
print(list(trial.tracks.keys()))

# Frames covered by a few tracks of interest. These IDs are specific to one
# tracking run, so guard the lookup instead of raising KeyError when a run
# produces a different set of tracks.
for track_id in (6, 8, 9):
    if track_id in trial.tracks:
        print(trial.tracks[track_id]['frames'])
    else:
        print('track {} not found'.format(track_id))

In [None]:
from sklearn.metrics import confusion_matrix
from IPython.display import clear_output
import time

# SiameseDataGenerator is used below but was not imported anywhere in the
# notebook, so this cell failed with NameError on a fresh kernel.
# NOTE(review): module path assumed from the deepcell package layout -- confirm.
from deepcell.image_generators import SiameseDataGenerator

datagen_val = SiameseDataGenerator(
    rotation_range=180,  # randomly rotate images by 0 to rotation_range degrees
    shear_range=0,  # shear range; 0 here
    horizontal_flip=0,  # horizontal flip option; set to 0 here
    vertical_flip=0)  # vertical flip option; set to 0 here

batch_size=128
min_track_length=5
# Batches to evaluate the model on (no shuffling).
test = datagen_val.flow(train_dict, batch_size=batch_size, min_track_length=min_track_length, shuffle=False)

# Running confusion matrix over three classes (rows: truth, columns: prediction).
class_labels = [0, 1, 2]
cm = np.zeros((3,3))
N_div = 0.  # running count of samples whose true class is 2
N_tot = 0.  # running count of all samples seen

# NOTE(review): if `test` is an endless iterator this loop runs until it is
# interrupted manually -- confirm and bound the number of steps if needed.
for batch_x, batch_y in test:
    clear_output(wait=True)

    y_pred = tracking_model.predict(batch_x)

    truth = np.argmax(batch_y, axis=-1)
    pred = np.argmax(y_pred, axis=-1)

    # Pin the label set so every batch yields a 3x3 matrix. Without `labels`
    # the shape depends on which classes happen to occur in the batch: a 2x2
    # result from classes {0, 2} or {1, 2} would be added to the wrong cells,
    # and a 1x1 result would broadcast over the whole matrix.
    cm += confusion_matrix(truth, pred, labels=class_labels)

    N_div += np.sum(truth == 2)
    N_tot += truth.shape[0]

    # Row-normalise; a class with no samples yet gives 0/0, which is left as
    # nan (undefined) without emitting a RuntimeWarning on every batch.
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_normalized = cm / cm.sum(axis=1)[:, np.newaxis]
    print(np.round(cm_normalized, decimals=2), N_div/N_tot)

    
# y_true = test.classes
# y_pred = tracking_model.predict_generator(test)
# print(y_pred)
# Y_test = np.argmax(y_test, axis=1) # Convert one-hot to index
# y_pred = model.predict_classes(x_test)
# print(classification_report(y_true, np.argmax(y_pred, axis=-1)))