In [1]:
%load_ext autoreload
%autoreload 2

import os
import wandb
import errno
import deepcell

import tensorflow as tf
import numpy as np

from wandb.keras import WandbCallback

from deepcell.data.tracking import prepare_dataset
from deepcell.data.tracking import Track, concat_tracks
from deepcell.model_zoo.tracking import GNNTrackingModel
from deepcell.utils.tracking_utils import trks_stats, load_trks

from tensorflow_addons.optimizers import RectifiedAdam as RAdam

from wandb.keras import WandbCallback

In [2]:
# Report how many GPUs deepcell/TensorFlow can see before any training starts.
from deepcell import train_utils

num_gpus = train_utils.count_gpus()
print('Training on {} GPUs'.format(num_gpus))

Training on 1 GPUs


In [3]:
# tf.config.experimental.get_memory_info('GPU:0')

In [4]:
# Set up working directories.
ROOT_DIR = '/data'  # TODO: Change this! Usually a mounted volume

MODEL_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'models'))  # checkpoints and saved models
LOG_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'logs'))
DATA_DIR = os.path.expanduser(os.path.join('~', '.keras', 'datasets'))
OUTPUT_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'nuc_tracking'))

# Create the output directories if they are missing. exist_ok=True makes this
# idempotent (safe to re-run) and race-free without inspecting errno by hand.
# Unlike the manual EEXIST check, it still raises if one of these paths exists
# as a regular file. Otherwise that situation would only surface later as a
# confusing error when saving models.
for d in (MODEL_DIR, LOG_DIR, OUTPUT_DIR):
    os.makedirs(d, exist_ok=True)

In [5]:
# Load the full tracked dataset (.trks file). Later cells take per-experiment
# subsets of it via dataset_indicies.
filename = 'train.trks'
path = os.path.join('../trk_data/', filename)
trks_data = load_trks(path)

In [None]:
# all_tracks = Track(tracked_data=trks_data)
# track_info = concat_tracks([all_tracks])

In [None]:
# Per-experiment index subsets. dataset_indicies[i] is used to index the first
# axis of trks_data['X'] / trks_data['y'] and the matching lineages, and its
# length is reported as ds_size.
# Note: despite its name, `dataset_sizes` holds the path to the .npy file.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. Only use it with a trusted, version-controlled file.
dataset_sizes = os.path.abspath(
    os.path.join(ROOT_DIR, 'dataset_idxs_dvc.npy')
)
dataset_indicies = np.load(dataset_sizes, allow_pickle=True).tolist()

In [None]:
"""
Functions for metrics
"""
def filter_and_flatten(y_true, y_pred):
    n_classes = tf.shape(y_true)[-1]
    new_shape = [-1, n_classes]
    y_true = tf.reshape(y_true, new_shape)
    y_pred = tf.reshape(y_pred, new_shape)

    # Mask out the padded cells
    y_true_reduced = tf.reduce_sum(y_true, axis=-1)
    good_loc = tf.where(y_true_reduced == 1)[:, 0]

    y_true = tf.gather(y_true, good_loc, axis=0)
    y_pred = tf.gather(y_pred, good_loc, axis=0)
    return y_true, y_pred


class Recall(tf.keras.metrics.Recall):
    """Keras ``Recall`` that ignores padded cells.

    ``y_true`` / ``y_pred`` are passed through ``filter_and_flatten`` before
    the parent class updates its counts. Padding rows therefore never count
    as positives or negatives.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true, y_pred = filter_and_flatten(y_true, y_pred)
        # Zero-argument super() is bound to this class object itself. The
        # explicit super(Recall, self) form looks up the global name `Recall`
        # at call time. Once the class is redefined (this notebook defines it
        # twice, and cells get re-run), metrics built from the older class
        # would raise "super(type, obj): obj must be an instance or subtype
        # of type".
        # NOTE(review): sample_weight is forwarded unfiltered. If a caller ever
        # supplied one, it would likely no longer line up with the kept rows.
        super().update_state(y_true, y_pred, sample_weight)


class Precision(tf.keras.metrics.Precision):
    """Keras ``Precision`` that ignores padded cells.

    ``y_true`` / ``y_pred`` are passed through ``filter_and_flatten`` before
    the parent class updates its counts. Padding rows therefore never count
    as positives or negatives.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true, y_pred = filter_and_flatten(y_true, y_pred)
        # Zero-argument super() is bound to this class object itself. The
        # explicit super(Precision, self) form looks up the global name
        # `Precision` at call time. Once the class is redefined (this notebook
        # defines it twice, and cells get re-run), metrics built from the
        # older class would raise "super(type, obj): obj must be an instance
        # or subtype of type".
        # NOTE(review): sample_weight is forwarded unfiltered. If a caller ever
        # supplied one, it would likely no longer line up with the kept rows.
        super().update_state(y_true, y_pred, sample_weight)


def loss_function(y_true, y_pred):
    """Weighted categorical cross-entropy over non-padded cells only.

    Padded rows are removed with ``filter_and_flatten``. The remaining rows
    are scored with ``deepcell.losses.weighted_categorical_crossentropy`` along
    the last axis, with the class count taken from the filtered labels.
    """
    kept_true, kept_pred = filter_and_flatten(y_true, y_pred)
    num_classes = tf.shape(kept_true)[-1]
    return deepcell.losses.weighted_categorical_crossentropy(
        kept_true, kept_pred, n_classes=num_classes, axis=-1)

# Optimizer: Rectified Adam, gradients clipped to a norm of 0.001.
optimizer = RAdam(learning_rate=1e-3, clipnorm=0.001)

# Loss, keyed by the model output it applies to.
losses = {'temporal_adj_matrices': loss_function}

# Per-class recall and precision. Class ids 0/1/2 correspond to the
# 'same' / 'different' / 'daughter' labels used in the metric names.
metrics = [
    metric_cls(class_id=class_id, name=metric_name)
    for metric_cls, class_id, metric_name in (
        (Recall, 0, 'same_recall'),
        (Recall, 1, 'different_recall'),
        (Recall, 2, 'daughter_recall'),
        (Precision, 0, 'same_precision'),
        (Precision, 1, 'different_precision'),
        (Precision, 2, 'daughter_precision'),
    )
]

In [None]:
# Training hyperparameters (kept small here for a quick run).
seed = 1              # random seed for the train/val/test split
batch_size = 4
track_length = 8      # only train on 8 frames at once
val_size = .20        # fraction of data held out for validation
test_size = .1        # fraction of data held out as a test set
n_epochs = 1          # number of training epochs

# Reduced step counts; a full run used steps_per_epoch=1000, validation_steps=200.
steps_per_epoch = 10
validation_steps = 2

translation_range = 512  # originally noted as X_train.shape[-2]

n_layers = 1  # number of graph convolutions

# Build the training subset for one experiment (i = 0) out of trks_data.
i = 0
new_data = {
    'lineages': list(np.array(trks_data['lineages'])[dataset_indicies[i]]),
    'X': trks_data['X'][dataset_indicies[i], ...],
    'y': trks_data['y'][dataset_indicies[i], ...],
}

ds_size = len(dataset_indicies[i])

print()
print('data idx', i, 'size', ds_size)
print()

all_tracks = Track(tracked_data=new_data)
track_info = concat_tracks([all_tracks])

# Maximum number of cells in any frame (third axis of the appearances array).
max_cells = track_info['appearances'].shape[2]

train_data, val_data, test_data = prepare_dataset(
    track_info,
    rotation_range=180,
    translation_range=translation_range,
    seed=seed,
    val_size=val_size,
    test_size=test_size,
    batch_size=batch_size,
    track_length=track_length,
)


data idx 0 size 22



100%|██████████| 22/22 [00:54<00:00,  2.46s/it]
100%|██████████| 22/22 [00:50<00:00,  2.30s/it]


In [9]:
# Single-model training run on the subset built by the data-preparation cell.
# Relies on max_cells, n_layers, train_data, val_data, ds_size, optimizer,
# losses and metrics already being defined.
model_name = 'graph_tracking_model_seed{}'.format(seed)
model_path = os.path.join(MODEL_DIR, model_name)

graph_layer = 'se2t'  # type of graph convolution layer: 'se2t', 'gcn', 'se2c' or 'gcs'

tm = GNNTrackingModel(max_cells=max_cells, n_layers=n_layers, graph_layer=graph_layer)

# Compile model
tm.training_model.compile(optimizer=optimizer, loss=losses, metrics=metrics)

layer = graph_layer
# ds_size is deliberately NOT overridden here. It must remain the size of the
# subset that train_data was actually built from; otherwise the saved-model
# paths below carry the wrong data size in their name.

# To log this run to W&B: call wandb.init(...) first and append WandbCallback()
# to train_callbacks.

# Train the model
train_callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        model_path, monitor='val_loss',
        save_best_only=True, verbose=1,
        save_weights_only=False),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, verbose=1,
        patience=3, min_lr=1e-7)]

loss_history = tm.training_model.fit(
    train_data,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_data,
    validation_steps=validation_steps,
    epochs=n_epochs,
    verbose=1,
    callbacks=train_callbacks)

# Save models for prediction
inf_path = os.path.join(MODEL_DIR, f'TrackingModelInf_{layer}_datasize_{ds_size}')
ne_path = os.path.join(MODEL_DIR, f'TrackingModelNE_{layer}_datasize_{ds_size}')

tm.inference_model.save(inf_path)
tm.neighborhood_encoder.save(ne_path)

print()
print('finished this shit')
print()



Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [5]:
# Hyperparameters for the full sweep over data subsets and graph layers.
seed = 1              # random seed for the train/val/test split
batch_size = 4
track_length = 8      # only train on 8 frames at once
val_size = .20        # fraction of data held out for validation
test_size = .1        # fraction of data held out as a test set
n_epochs = 12         # number of training epochs

steps_per_epoch = 1000
validation_steps = 200

n_layers = 1  # number of graph convolutions
# (the graph convolution type itself is chosen per run in the training loop)

translation_range = 512  # originally noted as X_train.shape[-2]

# Base path for the ModelCheckpoint callback.
model_name = 'graph_tracking_model_seed{}'.format(seed)
model_path = os.path.join(MODEL_DIR, model_name)

In [6]:
def filter_and_flatten(y_true, y_pred):
    """Flatten per-cell class scores and drop padded entries.

    Both tensors are reshaped to ``(-1, n_classes)``, where ``n_classes`` is
    the size of the last axis of ``y_true``. Only the rows whose ``y_true``
    entries sum to exactly 1 (i.e. rows carrying a one-hot label) are kept;
    any other row is treated as a padded cell and discarded.

    NOTE(review): this cell redefines a function of the same name from the
    earlier metrics cell. The later definition silently shadows the earlier
    one, so keep the two identical or remove one of them.

    Returns:
        Tuple ``(y_true, y_pred)`` restricted to the same kept rows.
    """
    num_classes = tf.shape(y_true)[-1]
    flat_true = tf.reshape(y_true, [-1, num_classes])
    flat_pred = tf.reshape(y_pred, [-1, num_classes])

    # Mask out the padded cells: keep rows whose labels sum to 1.
    is_real_cell = tf.equal(tf.reduce_sum(flat_true, axis=-1), 1)
    keep_idx = tf.reshape(tf.where(is_real_cell), [-1])

    return (tf.gather(flat_true, keep_idx, axis=0),
            tf.gather(flat_pred, keep_idx, axis=0))


class Recall(tf.keras.metrics.Recall):
    """Keras ``Recall`` that ignores padded cells.

    ``y_true`` / ``y_pred`` are passed through ``filter_and_flatten`` before
    the parent class updates its counts. Padding rows therefore never count
    as positives or negatives.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true, y_pred = filter_and_flatten(y_true, y_pred)
        # Zero-argument super() is bound to this class object itself. The
        # explicit super(Recall, self) form looks up the global name `Recall`
        # at call time. Once the class is redefined (this notebook defines it
        # twice, and cells get re-run), metrics built from the older class
        # would raise "super(type, obj): obj must be an instance or subtype
        # of type".
        # NOTE(review): sample_weight is forwarded unfiltered. If a caller ever
        # supplied one, it would likely no longer line up with the kept rows.
        super().update_state(y_true, y_pred, sample_weight)


class Precision(tf.keras.metrics.Precision):
    """Keras ``Precision`` that ignores padded cells.

    ``y_true`` / ``y_pred`` are passed through ``filter_and_flatten`` before
    the parent class updates its counts. Padding rows therefore never count
    as positives or negatives.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true, y_pred = filter_and_flatten(y_true, y_pred)
        # Zero-argument super() is bound to this class object itself. The
        # explicit super(Precision, self) form looks up the global name
        # `Precision` at call time. Once the class is redefined (this notebook
        # defines it twice, and cells get re-run), metrics built from the
        # older class would raise "super(type, obj): obj must be an instance
        # or subtype of type".
        # NOTE(review): sample_weight is forwarded unfiltered. If a caller ever
        # supplied one, it would likely no longer line up with the kept rows.
        super().update_state(y_true, y_pred, sample_weight)


def loss_function(y_true, y_pred):
    """Weighted categorical cross-entropy over non-padded cells only.

    Padded rows are removed with ``filter_and_flatten``. The remaining rows
    are scored with ``deepcell.losses.weighted_categorical_crossentropy`` along
    the last axis, with the class count taken from the filtered labels.
    """
    kept_true, kept_pred = filter_and_flatten(y_true, y_pred)
    num_classes = tf.shape(kept_true)[-1]
    return deepcell.losses.weighted_categorical_crossentropy(
        kept_true, kept_pred, n_classes=num_classes, axis=-1)

In [7]:
from tensorflow_addons.optimizers import RectifiedAdam as RAdam


# Optimizer: Rectified Adam, gradients clipped to a norm of 0.001.
optimizer = RAdam(learning_rate=1e-3, clipnorm=0.001)

# Loss, keyed by the model output it applies to.
losses = {'temporal_adj_matrices': loss_function}

# Per-class recall and precision. Class ids 0/1/2 correspond to the
# 'same' / 'different' / 'daughter' labels used in the metric names.
metrics = [
    metric_cls(class_id=class_id, name=metric_name)
    for metric_cls, class_id, metric_name in (
        (Recall, 0, 'same_recall'),
        (Recall, 1, 'different_recall'),
        (Recall, 2, 'daughter_recall'),
        (Precision, 0, 'same_precision'),
        (Precision, 1, 'different_precision'),
        (Precision, 2, 'daughter_precision'),
    )
]

In [10]:
# Sweep: for every dataset subset, train one model per graph-convolution type,
# log it to Weights & Biases and save the inference / neighborhood-encoder models.
graph_layers = ['se2t', 'gcn', 'se2c', 'gcs']  # loop-invariant, so hoisted

for i in range(len(dataset_indicies)):

    new_data = {}
    new_data['lineages'] = list(np.array(trks_data['lineages'])[dataset_indicies[i]])
    new_data['X'] = trks_data['X'][dataset_indicies[i], ...]
    new_data['y'] = trks_data['y'][dataset_indicies[i], ...]

    ds_size = len(dataset_indicies[i])

    print()
    print('data idx', i, 'size', ds_size)
    print()

    all_tracks = Track(tracked_data=new_data)
    track_info = concat_tracks([all_tracks])

    # Maximum number of cells in any frame (third axis of the appearances array).
    max_cells = track_info['appearances'].shape[2]

    train_data, val_data, test_data = prepare_dataset(
        track_info,
        rotation_range=180,
        translation_range=translation_range,
        seed=seed,
        val_size=val_size,
        test_size=test_size,
        batch_size=batch_size,
        track_length=track_length)

    for layer in graph_layers:
        print()
        print(layer)
        print()

        # Record run settings as W&B config, using JSON-serializable names.
        # Previously the metric/loss Python objects themselves were passed
        # to wandb.log.
        run = wandb.init(
            project='cell_tracking',
            reinit=True,
            name=layer + f'_datasize_{ds_size}',
            config={
                'graph_layer': layer,
                'ds_size': ds_size,
                'n_layers': n_layers,
                'batch_size': batch_size,
                'track_length': track_length,
                'n_epochs': n_epochs,
                'metrics': [m.name for m in metrics],
                'losses': {k: getattr(v, '__name__', repr(v))
                           for k, v in losses.items()},
            })
        try:
            tm = GNNTrackingModel(max_cells=max_cells, n_layers=n_layers, graph_layer=layer)

            # Use a fresh optimizer for every model. Sharing one instance carries
            # over the previous model's iteration count (RAdam's rectification
            # depends on it) and any learning rate already lowered by
            # ReduceLROnPlateau, so the graph layers would not be compared on
            # equal footing. Keep these settings in sync with the optimizer cell.
            model_optimizer = RAdam(learning_rate=1e-3, clipnorm=0.001)
            tm.training_model.compile(optimizer=model_optimizer, loss=losses, metrics=metrics)

            # One checkpoint per (layer, data size). With the single shared
            # model_path, each run's best checkpoint overwrote the previous run's.
            checkpoint_path = f'{model_path}_{layer}_datasize_{ds_size}'

            # Train the model
            train_callbacks = [
                tf.keras.callbacks.ModelCheckpoint(
                    checkpoint_path, monitor='val_loss',
                    save_best_only=True, verbose=1,
                    save_weights_only=False),
                tf.keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss', factor=0.5, verbose=1,
                    patience=3, min_lr=1e-7),
                WandbCallback(),
            ]

            loss_history = tm.training_model.fit(
                train_data,
                steps_per_epoch=steps_per_epoch,
                validation_data=val_data,
                validation_steps=validation_steps,
                epochs=n_epochs,
                verbose=1,
                callbacks=train_callbacks)

            # Save models for prediction
            inf_path = os.path.join(MODEL_DIR, f'TrackingModelInf_{layer}_datasize_{ds_size}')
            ne_path = os.path.join(MODEL_DIR, f'TrackingModelNE_{layer}_datasize_{ds_size}')

            tm.inference_model.save(inf_path)
            tm.neighborhood_encoder.save(ne_path)
        finally:
            # Close the W&B run even if building/training/saving raised.
            run.finish()


data idx 0 size 22



100%|██████████| 22/22 [00:46<00:00,  2.10s/it]
100%|██████████| 22/22 [00:41<00:00,  1.89s/it]



se2t



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Epoch 1/12
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [None]:
# Continue training the most recently built `tm`, with W&B logging.
from wandb.keras import WandbCallback

# Verify GPU count
from deepcell import train_utils
num_gpus = train_utils.count_gpus()
print('Training on {} GPUs'.format(num_gpus))

# Label the run with the graph layer `tm` was actually built with. After the
# sweep loop, `tm` is the sweep's last model, while `graph_layer` can still hold
# the value from the single-model cell. Naming the run from `graph_layer`
# alone can therefore mislabel it.
# NOTE(review): assumes GNNTrackingModel keeps its `graph_layer` argument as an
# attribute; otherwise this falls back to the global `graph_layer`.
layer = getattr(tm, 'graph_layer', None)
if layer is None:
    layer = graph_layer

run = wandb.init(
    project='cell_tracking',
    reinit=True,
    name=layer + f'_datasize_{ds_size}',
    config={
        'graph_layer': layer,
        'ds_size': ds_size,
        'metrics': [m.name for m in metrics],
        'losses': {k: getattr(v, '__name__', repr(v)) for k, v in losses.items()},
    })

# One checkpoint per (layer, data size), so other runs' checkpoints are not overwritten.
checkpoint_path = f'{model_path}_{layer}_datasize_{ds_size}'

try:
    # Train the model
    train_callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, monitor='val_loss',
            save_best_only=True, verbose=1,
            save_weights_only=False),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.5, verbose=1,
            patience=3, min_lr=1e-7),
        WandbCallback(),
    ]

    loss_history = tm.training_model.fit(
        train_data,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_data,
        validation_steps=validation_steps,
        epochs=n_epochs,
        verbose=1,
        callbacks=train_callbacks)
finally:
    # Close the W&B run even if training raised.
    run.finish()