Merge 1380461 into e45bf0c

elaubsch committed Feb 1, 2022
2 parents e45bf0c + 1380461 commit a0f398a
Showing 12 changed files with 618 additions and 366 deletions.
1 change: 1 addition & 0 deletions deepcell_spots/applications/__init__.py
@@ -24,4 +24,5 @@
# limitations under the License.
# ==============================================================================

from deepcell_spots.applications.spot_detection import SpotDetection
from deepcell_spots.applications.polaris import Polaris
299 changes: 107 additions & 192 deletions deepcell_spots/applications/polaris.py
@@ -23,227 +23,142 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Spot detection application"""
"""Singleplex FISH analysis application"""

from __future__ import absolute_import, division, print_function

import os
import timeit
import warnings

import numpy as np
import tensorflow as tf
from deepcell.applications import Application

from deepcell_spots.dotnet_losses import DotNetLosses
from deepcell_spots.postprocessing_utils import y_annotations_to_point_list_max
from deepcell_spots.preprocessing_utils import min_max_normalize
from deepcell.applications import CytoplasmSegmentation
from deepcell.applications import NuclearSegmentation
from deepcell_spots.applications import SpotDetection
from deepcell_spots.singleplex import match_spots_to_cells
from deepcell_toolbox.processing import histogram_normalization
from deepcell_toolbox.deep_watershed import deep_watershed
from tensorflow.python.platform.tf_logging import warning


MODEL_PATH = ('https://deepcell-data.s3-us-west-1.amazonaws.com/'
'saved-models/SpotDetection-3.tar.gz')


class Polaris(Application):
"""Loads a :mod:`deepcell.model_zoo.featurenet.FeatureNet` model
for fluorescent spot detection with pretrained weights.
The ``predict`` method handles prep and post processing steps
to return a labeled image.
class Polaris(object):
"""Loads spot detection and cell segmentation applications
from deepcell_spots and deepcell_tf, respectively.
The ``predict`` method calls the predict method of each
application.
Example:
.. code-block:: python
from skimage.io import imread
from deepcell_spots.applications import Polaris
# Load the image
im = imread('spots_image.png')
# Load the images
spots_im = imread('spots_image.png')
cyto_im = imread('cyto_image.png')
# Expand image dimensions to rank 4
im = np.expand_dims(im, axis=-1)
im = np.expand_dims(im, axis=0)
spots_im = np.expand_dims(spots_im, axis=[0,-1])
cyto_im = np.expand_dims(cyto_im, axis=[0,-1])
# Create the application
app = Polaris()
# Create the labeled image
labeled_image = app.predict(im)
# Find the spot locations
result = app.predict(spots_image=spots_im,
segmentation_image=cyto_im)
spots_dict = result[0]['spots_assignment']
labeled_im = result[0]['cell_segmentation']
coords = result[0]['spot_locations']
Args:
model (tf.keras.Model): The model to load. If ``None``,
a pre-trained model will be downloaded.
segmentation_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
segmentation_compartment (str): The cellular compartment
for generating segmentation predictions. Valid values
are 'cytoplasm', 'nucleus', 'no segmentation'. Defaults
to 'cytoplasm'.
spots_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
"""

#: Metadata for the dataset used to train the model
dataset_metadata = {
'name': 'general_train', # update
'other': """Pooled FISH data including MERFISH data
and SunTag viral RNA data""" # update
}

#: Metadata for the model and training process
model_metadata = {
'batch_size': 1,
'lr': 0.01,
'lr_decay': 0.99,
'training_seed': 0,
'n_epochs': 10,
'training_steps_per_epoch': 552
}

def __init__(self, model=None):

if model is None:
archive_path = tf.keras.utils.get_file(
'SpotDetection.tgz', MODEL_PATH,
file_hash='2b9a46087b25e9aab20a2c9f67f4f559',
extract=True, cache_subdir='models'
)
model_path = os.path.splitext(archive_path)[0]
model = tf.keras.models.load_model(
model_path, custom_objects={
'regression_loss': DotNetLosses.regression_loss,
'classification_loss': DotNetLosses.classification_loss
}
)

super(Polaris, self).__init__(
model,
model_image_shape=model.input_shape[1:],
model_mpp=0.65,
preprocessing_fn=min_max_normalize,
postprocessing_fn=y_annotations_to_point_list_max,
dataset_metadata=self.dataset_metadata,
model_metadata=self.model_metadata)

def _postprocess(self, image, **kwargs):
"""Applies postprocessing function to image if one has been defined.
Differs from parent class in that it returns a set of coordinate spot
locations, so handling of dimensions differs.
Otherwise returns unmodified image.
Args:
image (numpy.array or list): Input to postprocessing function
either an ``numpy.array`` or list of ``numpy.arrays``.
Returns:
numpy.array: labeled image
"""
if self.postprocessing_fn is not None:
t = timeit.default_timer()
self.logger.debug('Post-processing results with %s and kwargs: %s',
self.postprocessing_fn.__name__, kwargs)

image = self.postprocessing_fn(image, **kwargs)

self.logger.debug('Post-processed results with %s in %s s',
self.postprocessing_fn.__name__,
timeit.default_timer() - t)

elif isinstance(image, list) and len(image) == 1:
image = image[0]

return image

def _predict_segmentation(self,
image,
batch_size=4,
image_mpp=None,
pad_mode='constant',
preprocess_kwargs={},
postprocess_kwargs={}):
"""Generates a list of coordinate spot locations of the input running
prediction with appropriate pre and post processing functions.
This differs from parent Application class which returns a labeled image.
Input images are required to have 4 dimensions
``[batch, x, y, channel]``. Additional empty dimensions can be added
using ``np.expand_dims``.
Args:
image (numpy.array): Input image with shape
``[batch, x, y, channel]``.
batch_size (int): Number of images to predict on per batch.
image_mpp (float): Microns per pixel for ``image``.
pad_mode (str): The padding mode, one of "constant" or "reflect".
preprocess_kwargs (dict): Keyword arguments to pass to the
pre-processing function.
postprocess_kwargs (dict): Keyword arguments to pass to the
post-processing function.
Raises:
ValueError: Input data must match required rank, calculated as one
dimension more (batch dimension) than expected by the model.
ValueError: Input data must match required number of channels.
Returns:
numpy.array: Coordinate spot locations
"""
# Check input size of image
if len(image.shape) != self.required_rank:
raise ValueError('Input data must have {} dimensions. '
'Input data only has {} dimensions'.format(
self.required_rank, len(image.shape)))

if image.shape[-1] != self.required_channels:
raise ValueError('Input data must have {} channels. '
'Input data only has {} channels'.format(
self.required_channels, image.shape[-1]))

# Resize image, returns unmodified if appropriate
resized_image = self._resize_input(image, image_mpp)

# Generate model outputs
output_images = self._run_model(
image=resized_image, batch_size=batch_size,
pad_mode=pad_mode, preprocess_kwargs=preprocess_kwargs
)

# Resize output_images back to original resolution if necessary
label_image = self._resize_output(output_images, image.shape)

# Postprocess predictions to create label image
predicted_spots = self._postprocess(label_image, **postprocess_kwargs)

return predicted_spots
def __init__(self,
segmentation_model=None,
segmentation_compartment='cytoplasm',
spots_model=None):

self.spots_app = SpotDetection(model=spots_model)

valid_compartments = ['cytoplasm', 'nucleus', 'no segmentation']
if segmentation_compartment not in valid_compartments:
raise ValueError('Invalid compartment supplied: {}. '
'Must be one of {}'.format(segmentation_compartment,
valid_compartments))

if segmentation_compartment == 'cytoplasm':
self.segmentation_app = CytoplasmSegmentation(model=segmentation_model)
self.segmentation_app.preprocessing_fn = histogram_normalization
self.segmentation_app.postprocessing_fn = deep_watershed
elif segmentation_compartment == 'nucleus':
self.segmentation_app = NuclearSegmentation(model=segmentation_model)
else:
self.segmentation_app = None
warnings.warn('No segmentation application instantiated.')
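# Note: with segmentation_compartment='no segmentation', spot detection still
# runs, but no segmentation application is created, so predict() below raises
# a ValueError if a segmentation_image is supplied.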

def predict(self,
image,
batch_size=4,
spots_image,
segmentation_image=None,
image_mpp=None,
pad_mode='reflect',
preprocess_kwargs=None,
postprocess_kwargs=None,
threshold=0.9):
"""Generates a list of coordinate spot locations of the input
running prediction with appropriate pre and post processing
functions.
spots_threshold=0.95,
spots_clip=False):
"""Generates prediction output consisting of a labeled cell segmentation image,
detected spot locations, and a dictionary of spot locations assigned to labeled
cells of the input.
Input images are required to have 4 dimensions
``[batch, x, y, channel]``.
``[batch, x, y, channel]``. Channel dimension should be 2.
Additional empty dimensions can be added using ``np.expand_dims``.
Args:
image (numpy.array): Input image with shape
spots_image (numpy.array): Input image for spot detection with shape
``[batch, x, y, channel]``.
batch_size (int): Number of images to predict on per batch.
segmentation_image (numpy.array): Input image for cell segmentation with shape
``[batch, x, y, channel]``. Defaults to None.
image_mpp (float): Microns per pixel for ``image``.
pad_mode (str): The padding mode, one of "constant" or "reflect".
preprocess_kwargs (dict): Keyword arguments to pass to the
pre-processing function.
postprocess_kwargs (dict): Keyword arguments to pass to the
post-processing function.
threshold (float): Probability threshold for a pixel to be
considered as a spot.
spots_threshold (float): Probability threshold for a pixel to be
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to False.
Raises:
ValueError: Input data must match required rank of the application,
calculated as one dimension more (batch dimension) than
expected by the model.
ValueError: Input data must match required number of channels.
ValueError: Threshold value must be between 0 and 1.
ValueError: Segmentation application must be instantiated if segmentation
image is defined.
Returns:
numpy.array: Coordinate locations of detected spots.
list: List of dictionaries, length equal to batch dimension.
"""

if threshold < 0 or threshold > 1:
raise ValueError("""Enter a probability threshold value between
0 and 1.""")

if preprocess_kwargs is None:
preprocess_kwargs = {}

if postprocess_kwargs is None:
postprocess_kwargs = {
'threshold': threshold,
'min_distance': 1}

return self._predict_segmentation(
image,
batch_size=batch_size,
image_mpp=image_mpp,
pad_mode=pad_mode,
preprocess_kwargs=preprocess_kwargs,
postprocess_kwargs=postprocess_kwargs)
if spots_threshold < 0 or spots_threshold > 1:
raise ValueError('Threshold of {} was input. Threshold value must be '
'between 0 and 1.'.format(spots_threshold))

spots_result = self.spots_app.predict(spots_image,
threshold=spots_threshold,
clip=spots_clip)

if segmentation_image is not None:
if not self.segmentation_app:
raise ValueError('Segmentation application must be instantiated if '
'segmentation image is defined.')
else:
segmentation_result = self.segmentation_app.predict(segmentation_image,
image_mpp=image_mpp)
result = []
for i in range(len(spots_result)):
spots_dict = match_spots_to_cells(segmentation_result[i:i + 1],
spots_result[i])

result.append({'spots_assignment': spots_dict,
'cell_segmentation': segmentation_result[i:i + 1],
'spot_locations': spots_result[i]})

else:
result = spots_result

return result
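
A minimal usage sketch of the new Polaris API (not part of the diff above): file
names are illustrative, and it assumes pretrained weights download as documented
in the class docstring.

import numpy as np
from skimage.io import imread

from deepcell_spots.applications import Polaris

# Load single-channel spot and cytoplasm images, then expand to rank 4
spots_im = np.expand_dims(imread('spots_image.png'), axis=[0, -1])
cyto_im = np.expand_dims(imread('cyto_image.png'), axis=[0, -1])

app = Polaris(segmentation_compartment='cytoplasm')
result = app.predict(spots_image=spots_im,
                     segmentation_image=cyto_im,
                     spots_threshold=0.95)

# One dictionary per batch image, with keys matching the class docstring
for batch in result:
    coords = batch['spot_locations']         # detected spot coordinates
    labels = batch['cell_segmentation']      # labeled cell segmentation image
    assignments = batch['spots_assignment']  # spots matched to labeled cells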
