diff --git a/deepcell/applications/__init__.py b/deepcell/applications/__init__.py index 971f67e92..4f6959ae3 100644 --- a/deepcell/applications/__init__.py +++ b/deepcell/applications/__init__.py @@ -33,7 +33,7 @@ from deepcell.applications.cytoplasm_segmentation import CytoplasmSegmentation from deepcell.applications.multiplex_segmentation import MultiplexSegmentation from deepcell.applications.nuclear_segmentation import NuclearSegmentation -from deepcell.applications.cell_tracking import CellTrackingModel +from deepcell.applications.cell_tracking import CellTracking from deepcell.applications.label_detection import LabelDetectionModel from deepcell.applications.scale_detection import ScaleDetectionModel diff --git a/deepcell/applications/cell_tracking.py b/deepcell/applications/cell_tracking.py index 37ee63643..6cfffe7e1 100644 --- a/deepcell/applications/cell_tracking.py +++ b/deepcell/applications/cell_tracking.py @@ -31,6 +31,10 @@ from tensorflow.python.keras.utils.data_utils import get_file +import deepcell_tracking +from deepcell_toolbox.processing import normalize + +from deepcell.applications import Application from deepcell import model_zoo @@ -39,36 +43,102 @@ 'epoch_80split_9tl.h5') -def CellTrackingModel(input_shape=(32, 32, 1), - neighborhood_scale_size=30, - use_pretrained_weights=True): - """Creates an instance of a siamese_model used for cell tracking. - - Detects whether to input cells are the same cell, different cells, or - daughter cells. This can be used along with a cost matrix to track full - cell lineages across many frames. +class CellTracking(Application): + """Loads a `deepcell.model_zoo.siamese_model` model for object tracking + with pretrained weights using a simple `predict` interface. Args: - input_shape (tuple): a 3-length tuple of the input data shape. - neighborhood_scale_size (int): size of resized neighborhood images - use_pretrained_weights (bool): whether to load pre-trained weights. + use_pretrained_weights (bool, optional): Loads pretrained weights. Defaults to True. + model_image_shape (tuple, optional): Shape of input data expected by model. + Defaults to `(32, 32, 1)` + neighborhood_scale_size (int): + birth (float): Cost of new cell in linear assignment matrix. Defaults to `0.99`. + death (float): Cost of cell death in linear assignment matrix. Defaults to `0.99`. + division (float): Cost of cell division in linear assignment matrix. Defaults to `0.9`. """ - features = {'appearance', 'distance', 'neighborhood', 'regionprop'} - model = model_zoo.siamese_model( - input_shape=input_shape, - reg=1e-5, - init='he_normal', - neighborhood_scale_size=neighborhood_scale_size, - features=features) + #: Metadata for the dataset used to train the model + dataset_metadata = { + 'name': 'tracked_nuclear_train_large', + 'other': 'Pooled tracked nuclear data from HEK293, HeLa-S3, NIH-3T3, and RAW264.7 cells.' + } + + #: Metadata for the model and training process + model_metadata = { + 'batch_size': 128, + 'lr': 1e-2, + 'lr_decay': 0.99, + 'training_seed': 1, + 'n_epochs': 10, + 'training_steps_per_epoch': 5536, + 'validation_steps_per_epoch': 1384, + 'features': {'appearance', 'distance', 'neighborhood', 'regionprop'}, + 'min_track_length': 9, + 'neighborhood_scale_size': 30, + 'crop_dim': 32, + } + + def __init__(self, + use_pretrained_weights=True, + model_image_shape=(32, 32, 1), + neighborhood_scale_size=30, + birth=0.99, + death=0.99, + division=0.9): + self.features = {'appearance', 'distance', 'neighborhood', 'regionprop'} + self.birth = birth + self.death = death + self.division = division + + model = model_zoo.siamese_model( + input_shape=model_image_shape, + reg=1e-5, + init='he_normal', + neighborhood_scale_size=neighborhood_scale_size, + features=self.features) + + if use_pretrained_weights: + weights_path = get_file( + 'CellTrackingModel.h5', + WEIGHTS_PATH, + cache_subdir='models', + file_hash='3349b363fdad0266a1845ba785e057a6') + + model.load_weights(weights_path) + else: + weights_path = None + + super(CellTracking, self).__init__( + model, + model_image_shape=model_image_shape, + model_mpp=0.65, + preprocessing_fn=None, + postprocessing_fn=None, + dataset_metadata=self.dataset_metadata, + model_metadata=self.model_metadata) + + def predict(self, image, labels, **kwargs): + """Using both raw image data and segmentation masks, + track objects across all frames. + + Args: + image (np.array): Raw image data. + labels (np.array): Labels for image data, integer masks. + + Returns: + dict: Tracked labels and lineage information. + """ + image_norm = normalize(image) + + cell_tracker = deepcell_tracking.CellTracker( + image_norm, labels, self.model, + birth=self.birth, death=self.death, + division=self.division) - if use_pretrained_weights: - weights_path = get_file( - 'CellTrackingModel.h5', - WEIGHTS_PATH, - cache_subdir='models', - file_hash='3349b363fdad0266a1845ba785e057a6') + cell_tracker.track_cells() - model.load_weights(weights_path) + return cell_tracker._track_review_dict() - return model + def track(self, image, labels, **kwargs): + """Wrapper around predict() for convenience.""" + return self.predict(image, labels, **kwargs) diff --git a/deepcell/applications/cell_tracking_test.py b/deepcell/applications/cell_tracking_test.py new file mode 100644 index 000000000..07bb9e69f --- /dev/null +++ b/deepcell/applications/cell_tracking_test.py @@ -0,0 +1,79 @@ +# Copyright 2016-2019 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CellTracking Application""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test +import numpy as np +import skimage as sk + +from deepcell.applications import CellTracking + + +def _get_dummy_tracking_data(length=128, frames=3, + data_format='channels_last'): + """Borrowed from deepcell-tracking: https://bit.ly/37MFuNQ""" + if data_format == 'channels_last': + channel_axis = -1 + else: + channel_axis = 0 + + x, y = [], [] + while len(x) < frames: + _x = sk.data.binary_blobs(length=length, n_dim=2) + _y = sk.measure.label(_x) + if len(np.unique(_y)) > 3: + x.append(_x) + y.append(_y) + + x = np.stack(x, axis=0) # expand to 3D + y = np.stack(y, axis=0) # expand to 3D + + x = np.expand_dims(x, axis=channel_axis) + y = np.expand_dims(y, axis=channel_axis) + + return x.astype('float32'), y.astype('int32') + + +class TestCellTracking(test.TestCase): + + def test_cell_tracking_app(self): + with self.cached_session(): + # test instantiation + app = CellTracking(use_pretrained_weights=False) + + # test output shape + shape = app.model.output_shape + self.assertIsInstance(shape, tuple) + self.assertEqual(shape[-1], 3) + + # test predict + x, y = _get_dummy_tracking_data(128, frames=3) + tracked = app.predict(x, y) + self.assertEqual(tracked['X'].shape, tracked['y_tracked'].shape) diff --git a/deepcell/applications/cytoplasm_segmentation.py b/deepcell/applications/cytoplasm_segmentation.py index 740cad09e..81643cd19 100644 --- a/deepcell/applications/cytoplasm_segmentation.py +++ b/deepcell/applications/cytoplasm_segmentation.py @@ -116,6 +116,7 @@ def __init__(self, location=True, include_top=True, lite=True, + use_imagenet=use_pretrained_weights, interpolation='bilinear') if use_pretrained_weights: @@ -130,13 +131,14 @@ def __init__(self, else: weights_path = None - super(CytoplasmSegmentation, self).__init__(model, - model_image_shape=model_image_shape, - model_mpp=0.65, - preprocessing_fn=phase_preprocess, - postprocessing_fn=deep_watershed, - dataset_metadata=self.dataset_metadata, - model_metadata=self.model_metadata) + super(CytoplasmSegmentation, self).__init__( + model, + model_image_shape=model_image_shape, + model_mpp=0.65, + preprocessing_fn=phase_preprocess, + postprocessing_fn=deep_watershed, + dataset_metadata=self.dataset_metadata, + model_metadata=self.model_metadata) def predict(self, image, @@ -147,29 +149,33 @@ def predict(self, """Generates a labeled image of the input running prediction with appropriate pre and post processing functions. - Input images are required to have 4 dimensions `[batch, x, y, channel]`. Additional - empty dimensions can be added using `np.expand_dims` + Input images are required to have 4 dimensions `[batch, x, y, channel]`. + Additional empty dimensions can be added using `np.expand_dims` Args: image (np.array): Input image with shape `[batch, x, y, channel]` - batch_size (int, optional): Number of images to predict on per batch. Defaults to 4. - image_mpp (float, optional): Microns per pixel for the input image. Defaults to None. + batch_size (int, optional): Number of images to predict on per batch. + Defaults to 4. + image_mpp (float, optional): Microns per pixel for the input image. + Defaults to None. preprocess_kwargs (dict, optional): Kwargs to pass to preprocessing function. Defaults to {}. postprocess_kwargs (dict, optional): Kwargs to pass to postprocessing function. Defaults to {}. Raises: - ValueError: Input data must match required rank of the application, calculated as - one dimension more (batch dimension) than expected by the model + ValueError: Input data must match required rank of the application, + calculated as one dimension more (batch dimension) than expected + by the model. - ValueError: Input data must match required number of channels of application + ValueError: Input data must match required number of channels of application. Returns: np.array: Labeled image """ - return self._predict_segmentation(image, - batch_size=batch_size, - image_mpp=image_mpp, - preprocess_kwargs=preprocess_kwargs, - postprocess_kwargs=postprocess_kwargs) + return self._predict_segmentation( + image, + batch_size=batch_size, + image_mpp=image_mpp, + preprocess_kwargs=preprocess_kwargs, + postprocess_kwargs=postprocess_kwargs) diff --git a/deepcell/applications/multiplex_segmentation.py b/deepcell/applications/multiplex_segmentation.py index b6de116e4..fcd4da0fc 100644 --- a/deepcell/applications/multiplex_segmentation.py +++ b/deepcell/applications/multiplex_segmentation.py @@ -85,7 +85,8 @@ class MultiplexSegmentation(Application): .. nboutput:: Args: - use_pretrained_weights (bool, optional): Loads pretrained weights. Defaults to True. + use_pretrained_weights (bool, optional): Loads pretrained weights. + Defaults to True. model_image_shape (tuple, optional): Shape of input data expected by model. Defaults to `(256, 256, 2)` """ @@ -137,14 +138,15 @@ def __init__(self, else: weights_path = None - super(MultiplexSegmentation, self).__init__(model, - model_image_shape=model_image_shape, - model_mpp=0.5, - preprocessing_fn=multiplex_preprocess, - postprocessing_fn=multiplex_postprocess, - format_model_output_fn=format_output_multiplex, - dataset_metadata=self.dataset_metadata, - model_metadata=self.model_metadata) + super(MultiplexSegmentation, self).__init__( + model, + model_image_shape=model_image_shape, + model_mpp=0.5, + preprocessing_fn=multiplex_preprocess, + postprocessing_fn=multiplex_postprocess, + format_model_output_fn=format_output_multiplex, + dataset_metadata=self.dataset_metadata, + model_metadata=self.model_metadata) def predict(self, image, @@ -157,13 +159,15 @@ def predict(self, """Generates a labeled image of the input running prediction with appropriate pre and post processing functions. - Input images are required to have 4 dimensions `[batch, x, y, channel]`. Additional - empty dimensions can be added using `np.expand_dims` + Input images are required to have 4 dimensions `[batch, x, y, channel]`. + Additional empty dimensions can be added using `np.expand_dims` Args: image (np.array): Input image with shape `[batch, x, y, channel]` - batch_size (int, optional): Number of images to predict on per batch. Defaults to 4. - image_mpp (float, optional): Microns per pixel for the input image. Defaults to None. + batch_size (int, optional): Number of images to predict on per batch. + Defaults to 4. + image_mpp (float, optional): Microns per pixel for the input image. + Defaults to None. preprocess_kwargs (dict, optional): Kwargs to pass to preprocessing function. Defaults to {}. compartment (string): Specify type of segmentation to predict. Must be one of @@ -185,23 +189,33 @@ def predict(self, """ if postprocess_kwargs_whole_cell is None: - postprocess_kwargs_whole_cell = {'maxima_threshold': 0.1, 'maxima_model_smooth': 0, - 'interior_threshold': 0.3, 'interior_model_smooth': 2, - 'small_objects_threshold': 15, - 'fill_holes_threshold': 15, - 'radius': 2} + postprocess_kwargs_whole_cell = { + 'maxima_threshold': 0.1, + 'maxima_model_smooth': 0, + 'interior_threshold': 0.3, + 'interior_model_smooth': 2, + 'small_objects_threshold': 15, + 'fill_holes_threshold': 15, + 'radius': 2 + } if postprocess_kwargs_nuclear is None: - postprocess_kwargs_nuclear = {'maxima_threshold': 0.1, 'maxima_model_smooth': 0, - 'interior_threshold': 0.3, 'interior_model_smooth': 2, - 'small_objects_threshold': 15, - 'fill_holes_threshold': 15, - 'radius': 2} + postprocess_kwargs_nuclear = { + 'maxima_threshold': 0.1, + 'maxima_model_smooth': 0, + 'interior_threshold': 0.3, + 'interior_model_smooth': 2, + 'small_objects_threshold': 15, + 'fill_holes_threshold': 15, + 'radius': 2 + } # create dict to hold all of the post-processing kwargs - postprocess_kwargs = {'whole_cell_kwargs': postprocess_kwargs_whole_cell, - 'nuclear_kwargs': postprocess_kwargs_nuclear, - 'compartment': compartment} + postprocess_kwargs = { + 'whole_cell_kwargs': postprocess_kwargs_whole_cell, + 'nuclear_kwargs': postprocess_kwargs_nuclear, + 'compartment': compartment + } return self._predict_segmentation(image, batch_size=batch_size, diff --git a/deepcell/applications/multiplex_segmentation_test.py b/deepcell/applications/multiplex_segmentation_test.py index da5432fdd..90aefe88a 100644 --- a/deepcell/applications/multiplex_segmentation_test.py +++ b/deepcell/applications/multiplex_segmentation_test.py @@ -66,6 +66,5 @@ def test_multiplex_app(self): # test predict with both cell and nuclear compartments x = np.random.rand(1, 500, 500, 2) y = app.predict(x, compartment='both') - print("x shape is {}, y shape is {}".format(x.shape, y.shape)) self.assertEqual(x.shape[:-1], y.shape[:-1]) self.assertEqual(y.shape[-1], 2) diff --git a/deepcell/applications/nuclear_segmentation.py b/deepcell/applications/nuclear_segmentation.py index df67d51ea..6713c10cf 100644 --- a/deepcell/applications/nuclear_segmentation.py +++ b/deepcell/applications/nuclear_segmentation.py @@ -33,6 +33,7 @@ from tensorflow.python.keras.utils.data_utils import get_file +from deepcell_toolbox.processing import normalize from deepcell_toolbox.deep_watershed import deep_watershed from deepcell.applications import Application @@ -116,6 +117,7 @@ def __init__(self, location=True, include_top=True, lite=True, + use_imagenet=use_pretrained_weights, interpolation='bilinear') if use_pretrained_weights: @@ -130,13 +132,14 @@ def __init__(self, else: weights_path = None - super(NuclearSegmentation, self).__init__(model, - model_image_shape=model_image_shape, - model_mpp=0.65, - preprocessing_fn=None, - postprocessing_fn=deep_watershed, - dataset_metadata=self.dataset_metadata, - model_metadata=self.model_metadata) + super(NuclearSegmentation, self).__init__( + model, + model_image_shape=model_image_shape, + model_mpp=0.65, + preprocessing_fn=normalize, + postprocessing_fn=deep_watershed, + dataset_metadata=self.dataset_metadata, + model_metadata=self.model_metadata) def predict(self, image, @@ -147,29 +150,33 @@ def predict(self, """Generates a labeled image of the input running prediction with appropriate pre and post processing functions. - Input images are required to have 4 dimensions `[batch, x, y, channel]`. Additional - empty dimensions can be added using `np.expand_dims` + Input images are required to have 4 dimensions `[batch, x, y, channel]`. + Additional empty dimensions can be added using `np.expand_dims` Args: image (np.array): Input image with shape `[batch, x, y, channel]` - batch_size (int, optional): Number of images to predict on per batch. Defaults to 4. - image_mpp (float, optional): Microns per pixel for the input image. Defaults to None. + batch_size (int, optional): Number of images to predict on per batch. + Defaults to 4. + image_mpp (float, optional): Microns per pixel for the input image. + Defaults to None. preprocess_kwargs (dict, optional): Kwargs to pass to preprocessing function. Defaults to {}. postprocess_kwargs (dict, optional): Kwargs to pass to postprocessing function. Defaults to {}. Raises: - ValueError: Input data must match required rank of the application, calculated as - one dimension more (batch dimension) than expected by the model + ValueError: Input data must match required rank of the application, + calculated as one dimension more (batch dimension) than expected + by the model. - ValueError: Input data must match required number of channels of application + ValueError: Input data must match required number of channels. Returns: np.array: Labeled image """ - return self._predict_segmentation(image, - batch_size=batch_size, - image_mpp=image_mpp, - preprocess_kwargs=preprocess_kwargs, - postprocess_kwargs=postprocess_kwargs) + return self._predict_segmentation( + image, + batch_size=batch_size, + image_mpp=image_mpp, + preprocess_kwargs=preprocess_kwargs, + postprocess_kwargs=postprocess_kwargs) diff --git a/notebooks/applications/Nuclear-Application.ipynb b/notebooks/applications/Nuclear-Application.ipynb index a3c49d338..500e6c314 100644 --- a/notebooks/applications/Nuclear-Application.ipynb +++ b/notebooks/applications/Nuclear-Application.ipynb @@ -55,11 +55,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "If using Keras pass *_constraint arguments to layers.\n", "Downloading data from https://deepcell-data.s3.amazonaws.com/tracked/HeLa_S3.trks\n", - "6370648064/6370641920 [==============================] - 225s 0us/step\n", + "6370648064/6370641920 [==============================] - 144s 0us/step\n", "X_train shape: (144, 40, 216, 256, 1)\n", "X_test shape: (36, 40, 216, 256, 1)\n" ] @@ -78,23 +75,6 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(144, 40, 216, 256, 1)\n" - ] - } - ], - "source": [ - "print(X_train.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, "outputs": [], "source": [ "x = X_train[56] # chosen batch with divisions" @@ -102,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -147,17 +127,27 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "W1026 23:19:46.645289 139723356313408 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Call initializer instance with the dtype argument instead of passing it to the constructor\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://github.com/keras-team/keras-applications/releases/download/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5\n", - "94773248/94765736 [==============================] - 5s 0us/step\n", + "94773248/94765736 [==============================] - 6s 0us/step\n", "Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/model-weights/nuclear_0_82800_resnet50_watershed_named_076bb10d832089b6a77faed1e63ad375.h5\n", - "101310464/101306776 [==============================] - 30s 0us/step\n" + "101310464/101306776 [==============================] - 5s 0us/step\n" ] } ], @@ -180,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -206,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -218,7 +208,7 @@ } ], "source": [ - "y_pred = app.predict(x, image_mpp=0.65)\n", + "y_pred = app.predict(x, image_mpp=.75)\n", "\n", "print(y_pred.shape)" ] @@ -234,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -282,41 +272,14 @@ "\n", "The `NuclearSegmentation` worked well, but the cell labels of the same cell are not preserved across frames. To resolve this problem, we can use the `CellTracker`! This object will use another `CellTrackingModel` to compare all cells and determine which cells are the same across frames, as well as if a cell split into daughter cells.\n", "\n", - "### Normalize raw data to prepare for tracking\n", + "### Initalize CellTracking application\n", "\n", - "The `CellTracker` expects input image data to be zero-mean and unit-variance." + "Create an instance of `deepcell.applications.CellTracking`." ] }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "from deepcell_toolbox.processing import normalize\n", - "\n", - "x = x.astype('float32')\n", - "x_norm = np.empty(x.shape)\n", - "\n", - "for frame in range(x.shape[0]):\n", - " normalized = normalize(x[frame, ..., 0])\n", - " x_norm[frame] = np.expand_dims(normalized, axis=-1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "raw_mimetype": "text/restructuredtext" - }, - "source": [ - "### Initalize tracking model\n", - "\n", - "Create an instance of `deepcell.applications.CellTrackingModel` and pass the model to the `CellTracker`." - ] - }, - { - "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -329,14 +292,9 @@ } ], "source": [ - "from deepcell.applications import CellTrackingModel\n", - "from deepcell_tracking import CellTracker\n", - "\n", - "tracking_model = CellTrackingModel()\n", + "from deepcell.applications import CellTracking\n", "\n", - "cell_tracker = CellTracker(\n", - " x_norm, y_pred, tracking_model,\n", - " birth=0.99, death=0.99, division=0.9)" + "tracker = CellTracking()" ] }, { @@ -348,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -363,7 +321,7 @@ } ], "source": [ - "cell_tracker.track_cells()" + "tracked_data = tracker.track(np.copy(x), y_pred)" ] }, { @@ -377,20 +335,19 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Convert tracking results to a dictionary\n", - "data = cell_tracker._track_review_dict()\n", "\n", - "X = data['X'] # raw X data\n", - "y = data['y_tracked'] # tracked y data" + "X = tracked_data['X'] # raw X data\n", + "y = tracked_data['y_tracked'] # tracked y data" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -424,6 +381,13 @@ "\n", "![Tracked Cells GIF](./tracks.gif)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/applications/labeled.gif b/notebooks/applications/labeled.gif index 1506cebf1..67f768283 100644 Binary files a/notebooks/applications/labeled.gif and b/notebooks/applications/labeled.gif differ diff --git a/notebooks/applications/tracks.gif b/notebooks/applications/tracks.gif index 2cb8de8de..1b4e9abf8 100644 Binary files a/notebooks/applications/tracks.gif and b/notebooks/applications/tracks.gif differ