Skip to content

Commit

Permalink
Create CellTracking application object, various application bugfixes. (
Browse files Browse the repository at this point in the history
…#444)

* Refactor CellTrackingModel into a CellTracking Application.

* Change whitespace for better readability.

* Use ImageNet weights if `use_pretrained_weights` is `True`.

* `normalize` is required pre-processing for `NuclearSegmentation` application.

* Update notebook with `CellTracking` application instead of `CellTrackingModel`.
  • Loading branch information
willgraf committed Oct 27, 2020
1 parent 40e6f7b commit 83226ab
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 165 deletions.
2 changes: 1 addition & 1 deletion deepcell/applications/__init__.py
Expand Up @@ -33,7 +33,7 @@
from deepcell.applications.cytoplasm_segmentation import CytoplasmSegmentation
from deepcell.applications.multiplex_segmentation import MultiplexSegmentation
from deepcell.applications.nuclear_segmentation import NuclearSegmentation
from deepcell.applications.cell_tracking import CellTrackingModel
from deepcell.applications.cell_tracking import CellTracking
from deepcell.applications.label_detection import LabelDetectionModel
from deepcell.applications.scale_detection import ScaleDetectionModel

Expand Down
122 changes: 96 additions & 26 deletions deepcell/applications/cell_tracking.py
Expand Up @@ -31,6 +31,10 @@

from tensorflow.python.keras.utils.data_utils import get_file

import deepcell_tracking
from deepcell_toolbox.processing import normalize

from deepcell.applications import Application
from deepcell import model_zoo


Expand All @@ -39,36 +43,102 @@
'epoch_80split_9tl.h5')


def CellTrackingModel(input_shape=(32, 32, 1),
neighborhood_scale_size=30,
use_pretrained_weights=True):
"""Creates an instance of a siamese_model used for cell tracking.
Detects whether to input cells are the same cell, different cells, or
daughter cells. This can be used along with a cost matrix to track full
cell lineages across many frames.
class CellTracking(Application):
"""Loads a `deepcell.model_zoo.siamese_model` model for object tracking
with pretrained weights using a simple `predict` interface.
Args:
input_shape (tuple): a 3-length tuple of the input data shape.
neighborhood_scale_size (int): size of resized neighborhood images
use_pretrained_weights (bool): whether to load pre-trained weights.
use_pretrained_weights (bool, optional): Loads pretrained weights. Defaults to True.
model_image_shape (tuple, optional): Shape of input data expected by model.
Defaults to `(32, 32, 1)`
neighborhood_scale_size (int):
birth (float): Cost of new cell in linear assignment matrix. Defaults to `0.99`.
death (float): Cost of cell death in linear assignment matrix. Defaults to `0.99`.
division (float): Cost of cell division in linear assignment matrix. Defaults to `0.9`.
"""
features = {'appearance', 'distance', 'neighborhood', 'regionprop'}

model = model_zoo.siamese_model(
input_shape=input_shape,
reg=1e-5,
init='he_normal',
neighborhood_scale_size=neighborhood_scale_size,
features=features)
#: Metadata for the dataset used to train the model
dataset_metadata = {
'name': 'tracked_nuclear_train_large',
'other': 'Pooled tracked nuclear data from HEK293, HeLa-S3, NIH-3T3, and RAW264.7 cells.'
}

#: Metadata for the model and training process
model_metadata = {
'batch_size': 128,
'lr': 1e-2,
'lr_decay': 0.99,
'training_seed': 1,
'n_epochs': 10,
'training_steps_per_epoch': 5536,
'validation_steps_per_epoch': 1384,
'features': {'appearance', 'distance', 'neighborhood', 'regionprop'},
'min_track_length': 9,
'neighborhood_scale_size': 30,
'crop_dim': 32,
}

def __init__(self,
use_pretrained_weights=True,
model_image_shape=(32, 32, 1),
neighborhood_scale_size=30,
birth=0.99,
death=0.99,
division=0.9):
self.features = {'appearance', 'distance', 'neighborhood', 'regionprop'}
self.birth = birth
self.death = death
self.division = division

model = model_zoo.siamese_model(
input_shape=model_image_shape,
reg=1e-5,
init='he_normal',
neighborhood_scale_size=neighborhood_scale_size,
features=self.features)

if use_pretrained_weights:
weights_path = get_file(
'CellTrackingModel.h5',
WEIGHTS_PATH,
cache_subdir='models',
file_hash='3349b363fdad0266a1845ba785e057a6')

model.load_weights(weights_path)
else:
weights_path = None

super(CellTracking, self).__init__(
model,
model_image_shape=model_image_shape,
model_mpp=0.65,
preprocessing_fn=None,
postprocessing_fn=None,
dataset_metadata=self.dataset_metadata,
model_metadata=self.model_metadata)

def predict(self, image, labels, **kwargs):
"""Using both raw image data and segmentation masks,
track objects across all frames.
Args:
image (np.array): Raw image data.
labels (np.array): Labels for image data, integer masks.
Returns:
dict: Tracked labels and lineage information.
"""
image_norm = normalize(image)

cell_tracker = deepcell_tracking.CellTracker(
image_norm, labels, self.model,
birth=self.birth, death=self.death,
division=self.division)

if use_pretrained_weights:
weights_path = get_file(
'CellTrackingModel.h5',
WEIGHTS_PATH,
cache_subdir='models',
file_hash='3349b363fdad0266a1845ba785e057a6')
cell_tracker.track_cells()

model.load_weights(weights_path)
return cell_tracker._track_review_dict()

return model
def track(self, image, labels, **kwargs):
"""Wrapper around predict() for convenience."""
return self.predict(image, labels, **kwargs)
79 changes: 79 additions & 0 deletions deepcell/applications/cell_tracking_test.py
@@ -0,0 +1,79 @@
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CellTracking Application"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
import numpy as np
import skimage as sk

from deepcell.applications import CellTracking


def _get_dummy_tracking_data(length=128, frames=3,
data_format='channels_last'):
"""Borrowed from deepcell-tracking: https://bit.ly/37MFuNQ"""
if data_format == 'channels_last':
channel_axis = -1
else:
channel_axis = 0

x, y = [], []
while len(x) < frames:
_x = sk.data.binary_blobs(length=length, n_dim=2)
_y = sk.measure.label(_x)
if len(np.unique(_y)) > 3:
x.append(_x)
y.append(_y)

x = np.stack(x, axis=0) # expand to 3D
y = np.stack(y, axis=0) # expand to 3D

x = np.expand_dims(x, axis=channel_axis)
y = np.expand_dims(y, axis=channel_axis)

return x.astype('float32'), y.astype('int32')


class TestCellTracking(test.TestCase):

def test_cell_tracking_app(self):
with self.cached_session():
# test instantiation
app = CellTracking(use_pretrained_weights=False)

# test output shape
shape = app.model.output_shape
self.assertIsInstance(shape, tuple)
self.assertEqual(shape[-1], 3)

# test predict
x, y = _get_dummy_tracking_data(128, frames=3)
tracked = app.predict(x, y)
self.assertEqual(tracked['X'].shape, tracked['y_tracked'].shape)
44 changes: 25 additions & 19 deletions deepcell/applications/cytoplasm_segmentation.py
Expand Up @@ -116,6 +116,7 @@ def __init__(self,
location=True,
include_top=True,
lite=True,
use_imagenet=use_pretrained_weights,
interpolation='bilinear')

if use_pretrained_weights:
Expand All @@ -130,13 +131,14 @@ def __init__(self,
else:
weights_path = None

super(CytoplasmSegmentation, self).__init__(model,
model_image_shape=model_image_shape,
model_mpp=0.65,
preprocessing_fn=phase_preprocess,
postprocessing_fn=deep_watershed,
dataset_metadata=self.dataset_metadata,
model_metadata=self.model_metadata)
super(CytoplasmSegmentation, self).__init__(
model,
model_image_shape=model_image_shape,
model_mpp=0.65,
preprocessing_fn=phase_preprocess,
postprocessing_fn=deep_watershed,
dataset_metadata=self.dataset_metadata,
model_metadata=self.model_metadata)

def predict(self,
image,
Expand All @@ -147,29 +149,33 @@ def predict(self,
"""Generates a labeled image of the input running prediction with
appropriate pre and post processing functions.
Input images are required to have 4 dimensions `[batch, x, y, channel]`. Additional
empty dimensions can be added using `np.expand_dims`
Input images are required to have 4 dimensions `[batch, x, y, channel]`.
Additional empty dimensions can be added using `np.expand_dims`
Args:
image (np.array): Input image with shape `[batch, x, y, channel]`
batch_size (int, optional): Number of images to predict on per batch. Defaults to 4.
image_mpp (float, optional): Microns per pixel for the input image. Defaults to None.
batch_size (int, optional): Number of images to predict on per batch.
Defaults to 4.
image_mpp (float, optional): Microns per pixel for the input image.
Defaults to None.
preprocess_kwargs (dict, optional): Kwargs to pass to preprocessing function.
Defaults to {}.
postprocess_kwargs (dict, optional): Kwargs to pass to postprocessing function.
Defaults to {}.
Raises:
ValueError: Input data must match required rank of the application, calculated as
one dimension more (batch dimension) than expected by the model
ValueError: Input data must match required rank of the application,
calculated as one dimension more (batch dimension) than expected
by the model.
ValueError: Input data must match required number of channels of application
ValueError: Input data must match required number of channels of application.
Returns:
np.array: Labeled image
"""
return self._predict_segmentation(image,
batch_size=batch_size,
image_mpp=image_mpp,
preprocess_kwargs=preprocess_kwargs,
postprocess_kwargs=postprocess_kwargs)
return self._predict_segmentation(
image,
batch_size=batch_size,
image_mpp=image_mpp,
preprocess_kwargs=preprocess_kwargs,
postprocess_kwargs=postprocess_kwargs)

0 comments on commit 83226ab

Please sign in to comment.