This notebook is part of the `deepcell-tf` documentation: https://deepcell.readthedocs.io/.

# Nuclear segmentation and tracking

In [18]:
import copy
import os

import imageio
import matplotlib as mpl
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np

from deepcell.applications import NuclearSegmentation, CellTracking
from deepcell_tracking.trk_io import load_trks

In [15]:
def shuffle_colors(ymax, cmap):
    """Utility function to generate a colormap for a labeled image"""
    cmap = mpl.colormaps[cmap].resampled(ymax)
    nmap = cmap(range(ymax))
    np.random.shuffle(nmap)
    cmap = ListedColormap(nmap)
    cmap.set_bad('black')
    return cmap

## Prepare nuclear data

Sample tracking data can be downloaded from https://datasets.deepcell.org/. Please adjust the path below to where your data is stored locally.

In [2]:
data = load_trks('/notebooks/val.trks') # Change this path
data['X'].shape

(27, 71, 584, 600, 1)

In [3]:
x = data['X'][0]
y = data['y'][0]

# Determine position of zero padding for removal
# Calculate position of padding based on first frame
# Assume that padding is in blocks on the edges of image
good_cols = np.where(x[0].any(axis=0))[0]  # columns containing any signal
good_rows = np.where(x[0].any(axis=1))[0]  # rows containing any signal
slc = (
    slice(None),
    slice(good_rows[0], good_rows[-1] + 1),
    slice(good_cols[0], good_cols[-1] + 1),
    slice(None)
)
x = x[slc]

# Determine which frames are zero padding
# Assume that blank frames only occur at the end of the movie
frames = np.sum(y, axis=(1, 2))
good_frames = np.where(frames)[0]  # indices of frames that are not blank
x = x[:len(good_frames)]
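
The padding logic above can be checked on a small synthetic array. The following is a minimal sketch, assuming (as the cell above does) that padding forms zero blocks on the image edges and that blank frames come only at the end of the movie. The `movie` array and its dimensions are made up for illustration and are not part of the sample data.

```python
import numpy as np

# Hypothetical toy movie: 3 frames, 6 x 7 pixels, 1 channel, zero padded
movie = np.zeros((3, 6, 7, 1))
movie[:2, 1:4, 2:6, 0] = 1.0  # signal in rows 1-3 and columns 2-5 of the first two frames

# Locate rows and columns that contain signal in the first frame
first = movie[0, ..., 0]
rows = np.where(first.any(axis=1))[0]
cols = np.where(first.any(axis=0))[0]
cropped = movie[:, rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

# Count non-blank frames and keep only the leading ones
frame_sums = cropped.sum(axis=(1, 2, 3))
n_good = int(np.count_nonzero(frame_sums))
cropped = cropped[:n_good]

print(cropped.shape)
```

The same slicing could be applied to the label array if the cropped annotations were needed later.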

In [4]:
def plot(im):
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(im, 'Greys_r')
    plt.axis('off')
    plt.title('Raw Image Data')

    fig.canvas.draw()  # draw the canvas, cache the renderer
    # Copy the rendered RGBA buffer and drop alpha (tostring_rgb is deprecated in recent Matplotlib)
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()

    plt.close(fig)

    return image

imageio.mimsave('raw.gif', [plot(x[i, ..., 0]) for i in range(x.shape[0])])

### View .GIF of raw cells

![Raw Gif](./raw.gif)

## Nuclear Segmentation

### Initialize nuclear model

The application will download pretrained weights for nuclear segmentation. For more information about application objects, please see our [documentation](https://deepcell.readthedocs.io/en/master/API/deepcell.applications.html).

In [5]:
app = NuclearSegmentation()

2023-06-26 23:42:33.286298: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-26 23:42:33.951331: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10415 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:06:00.0, compute capability: 6.1




### Use the application to generate labeled images

Typically, neural networks perform best on test data that is similar to the training data. In the realm of biological imaging, the most common difference between datasets is the resolution of the data measured in microns per pixel. The training resolution of the model can be identified using `app.model_mpp`.

In [6]:
print('Training Resolution:', app.model_mpp, 'microns per pixel')

Training Resolution: 0.65 microns per pixel


The resolution of the input data can be specified in `app.predict` using the `image_mpp` option. The `Application` will rescale the input data to match the training resolution and then rescale to the original size before returning the labeled image.

In [7]:
y_pred = app.predict(x, image_mpp=0.65)

print(y_pred.shape)

2023-06-26 23:43:38.584782: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100


(42, 540, 540, 1)
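
To make the rescaling step concrete, here is a sketch of the arithmetic only. It assumes the scale factor is the ratio `image_mpp / model_mpp`, which is our reading of the description above rather than a copy of the library's implementation. `rescaled_shape` is a hypothetical helper and is not part of `deepcell`.

```python
def rescaled_shape(shape, image_mpp, model_mpp):
    """Return the (rows, cols) an image would have at the model resolution.

    Assumption: the scale factor is image_mpp / model_mpp, so coarser input
    (larger microns per pixel) is upsampled and finer input is downsampled.
    """
    scale = image_mpp / model_mpp
    return tuple(int(round(s * scale)) for s in shape)


# Input already at the training resolution: no change in size
same = rescaled_shape((540, 540), image_mpp=0.65, model_mpp=0.65)

# Input twice as coarse as the training data: upsampled by a factor of 2
coarse = rescaled_shape((540, 540), image_mpp=1.3, model_mpp=0.65)

print(same, coarse)
```

In the cell above, `image_mpp=0.65` matches the training resolution, so under this assumption no resizing is needed.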


### Save labeled images as a gif to visualize

In [19]:
ymax = np.max(y_pred)
cmap = shuffle_colors(ymax, 'tab20')

def plot(x, y):
    yy = copy.deepcopy(y)
    yy = np.ma.masked_equal(yy, 0)
    
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(x, cmap='Greys_r')
    ax[0].axis('off')
    ax[0].set_title('Raw')
    ax[1].imshow(yy, cmap=cmap, vmax=ymax)
    ax[1].set_title('Segmented')
    ax[1].axis('off')

    fig.canvas.draw()  # draw the canvas, cache the renderer
    # Copy the rendered RGBA buffer and drop alpha (tostring_rgb is deprecated in recent Matplotlib)
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)

    return image

imageio.mimsave(
    './labeled.gif',
    [plot(x[i,...,0], y_pred[i,...,0])
     for i in range(y_pred.shape[0])]
)

### View .GIF of segmented cells

The `NuclearSegmentation` application produces a label mask for each frame, in which every detected nucleus has its own integer label.

![Segmented GIF](./labeled.gif)
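
A label mask gives each segmented nucleus its own positive integer and uses 0 for background, so the number of objects in a frame equals the number of distinct nonzero labels. The sketch below illustrates this on a small made-up array shaped like `y_pred`. `count_objects` is a hypothetical helper, not a `deepcell` function.

```python
import numpy as np


def count_objects(labels):
    """Return the number of distinct nonzero labels in each frame."""
    return [int(np.count_nonzero(np.unique(frame))) for frame in labels]


# Hypothetical toy label movie: 2 frames, 4 x 4 pixels, 1 channel
toy = np.zeros((2, 4, 4, 1), dtype=int)
toy[0, 0:2, 0:2, 0] = 1  # first object in frame 0
toy[0, 2:4, 2:4, 0] = 2  # second object in frame 0
toy[1, 1:3, 1:3, 0] = 5  # single object in frame 1

counts = count_objects(toy)
print(counts)
```

Applied to `y_pred`, the same helper would give a per-frame object count that can be compared against the GIF.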

## Cell Tracking

The `NuclearSegmentation` application worked well, but the label assigned to a given cell is not preserved across frames. To resolve this problem, we can use the `CellTracking` application. It uses a separate tracking model to compare cells across frames and determine which cells are the same, as well as whether a cell has divided into daughter cells.
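
To illustrate the problem, the sketch below relabels one frame so that each object takes the label of the object it overlaps most in the previous frame. This is a deliberately simple overlap heuristic used as a stand-in for illustration. It is not the learned model that `CellTracking` uses, it does not handle divisions, and `relabel_by_overlap` is a hypothetical helper.

```python
import numpy as np


def relabel_by_overlap(prev, curr):
    """Relabel `curr` so each object inherits the best-overlapping label in `prev`."""
    out = np.zeros_like(curr)
    next_label = int(prev.max()) + 1  # label to use for objects with no overlap
    for lab in np.unique(curr):
        if lab == 0:
            continue  # skip background
        mask = curr == lab
        overlap = prev[mask]
        overlap = overlap[overlap > 0]
        if overlap.size:
            out[mask] = np.bincount(overlap).argmax()
        else:
            out[mask] = next_label
            next_label += 1
    return out


# Two toy frames in which the segmentation labels are swapped between frames
prev = np.zeros((5, 8), dtype=int)
prev[1:4, 0:3] = 1
prev[1:4, 5:8] = 2

curr = np.zeros((5, 8), dtype=int)
curr[1:4, 1:4] = 2  # same cell as label 1 in `prev`, shifted by one pixel
curr[1:4, 4:7] = 1  # same cell as label 2 in `prev`, shifted by one pixel

matched = relabel_by_overlap(prev, curr)
print(np.unique(matched[1:4, 1:4]), np.unique(matched[1:4, 4:7]))
```

A heuristic like this breaks down when cells move far between frames or divide, which is why a dedicated tracking model is useful.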

### Initialize CellTracking application

Create an instance of `deepcell.applications.CellTracking`.

In [9]:
tracker = CellTracking()

Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingNE-7.tar.gz




Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingInf-7.tar.gz




### Track the cells

In [12]:
tracked_data = tracker.track(x, y_pred)
y_tracked = tracked_data['y_tracked']









### Visualize tracking results

In [20]:
ymax = np.max(y_tracked)
cmap = shuffle_colors(ymax, 'tab20')

def plot(x, y):
    yy = copy.deepcopy(y)
    yy = np.ma.masked_equal(yy, 0)
    
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(x, cmap='Greys_r')
    ax[0].axis('off')
    ax[0].set_title('Raw')
    ax[1].imshow(yy, cmap=cmap, vmax=ymax)
    ax[1].set_title('Tracked')
    ax[1].axis('off')

    fig.canvas.draw()  # draw the canvas, cache the renderer
    # Copy the rendered RGBA buffer and drop alpha (tostring_rgb is deprecated in recent Matplotlib)
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)

    return image

imageio.mimsave(
    './tracks.gif',
    [plot(x[i,...,0], y_tracked[i,...,0])
     for i in range(y_tracked.shape[0])]
)

### View .GIF of tracked cells

Now that `tracker.track` has finished, not only do the annotations preserve each cell's label across frames, but the lineage information has also been saved in `CellTracker.tracks`.

![Tracked Cells GIF](./tracks.gif)
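
One way to confirm that labels persist across frames is to list, for each label, the frames in which it appears. The sketch below does this on a small made-up array shaped like `y_tracked`. `label_lifetimes` is a hypothetical helper that only inspects the label images and does not read the tracker's lineage information.

```python
import numpy as np


def label_lifetimes(y):
    """Map each nonzero label to the list of frame indices where it appears."""
    lifetimes = {}
    for t, frame in enumerate(y):
        for lab in np.unique(frame):
            if lab != 0:
                lifetimes.setdefault(int(lab), []).append(t)
    return lifetimes


# Hypothetical toy tracked movie: 3 frames, 4 x 4 pixels, 1 channel
toy = np.zeros((3, 4, 4, 1), dtype=int)
toy[:, 0:2, 0:2, 0] = 1   # label 1 is present in every frame
toy[1:, 2:4, 2:4, 0] = 2  # label 2 first appears in frame 1

lifetimes = label_lifetimes(toy)
print(lifetimes)
```

With tracked output, a label that spans many consecutive frames suggests a continuous track, while labels that appear only briefly may be worth inspecting in the GIF.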