Singleplex Polaris app #15

Merged 32 commits on Feb 1, 2022
Commits (32)
430c322
Refactor Polaris app to SpotDetection
elaubsch Jan 26, 2022
a66b513
Add clip flag
elaubsch Jan 26, 2022
f7749cf
Add SpotDetection import
elaubsch Jan 26, 2022
9317e2f
Implement singleplex application
elaubsch Jan 27, 2022
ea19c89
PEP8
elaubsch Jan 27, 2022
2f19a70
Update docstring
elaubsch Jan 27, 2022
bd0ccae
Update docstring
elaubsch Jan 27, 2022
dcf0999
Write docstring
elaubsch Jan 27, 2022
485298a
PEP8
elaubsch Jan 27, 2022
d90997f
Add tests
elaubsch Jan 27, 2022
bf9bbc5
Update numpy requirement
elaubsch Jan 27, 2022
73c2746
PEP8
elaubsch Jan 27, 2022
a443cce
Update whitespace
elaubsch Jan 27, 2022
b731f6f
Remove comment
elaubsch Jan 27, 2022
ecf90b0
Remove comment
elaubsch Jan 27, 2022
6f103d0
Update deepcell_spots/applications/polaris.py
elaubsch Jan 27, 2022
4ad3da1
Update deepcell_spots/applications/polaris.py
elaubsch Jan 27, 2022
694262b
Triple quotes to single quotes
elaubsch Jan 27, 2022
6c00c0a
Rename _predict_segmentation to _predict
elaubsch Jan 27, 2022
760e4ca
Merge branch 'singleplex-app' of https://github.com/vanvalenlab/deepc…
elaubsch Jan 27, 2022
3a874e7
Remove model
elaubsch Feb 1, 2022
ef713f0
Rename spots variables
elaubsch Feb 1, 2022
3079d66
Update segmentation model
elaubsch Feb 1, 2022
d7552c6
Add nuclear segmentation
elaubsch Feb 1, 2022
c5b51f8
Test errors
elaubsch Feb 1, 2022
b11a8fc
Update example notebooks
elaubsch Feb 1, 2022
5081b20
Change no seg result
elaubsch Feb 1, 2022
5f75333
Update docstring
elaubsch Feb 1, 2022
dabebce
Edit docstrings and warnings
elaubsch Feb 1, 2022
7e5e75f
Merge branch 'singleplex-app' of https://github.com/vanvalenlab/deepc…
elaubsch Feb 1, 2022
1380461
Fix bug in tests
elaubsch Feb 1, 2022
55561a9
Update notebook outputs
elaubsch Feb 1, 2022
1 change: 1 addition & 0 deletions deepcell_spots/applications/__init__.py
@@ -24,4 +24,5 @@
# limitations under the License.
# ==============================================================================

from deepcell_spots.applications.spot_detection import SpotDetection
from deepcell_spots.applications.polaris import Polaris
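This addition re-exports Polaris next to SpotDetection from the applications subpackage. A minimal import sketch, assuming the package layout shown in this diff:

# Both applications are importable from deepcell_spots.applications
# once this __init__.py change is in place.
from deepcell_spots.applications import Polaris, SpotDetection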
285 changes: 100 additions & 185 deletions deepcell_spots/applications/polaris.py
@@ -23,227 +23,142 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Spot detection application"""
"""Singleplex FISH analysis application"""

from __future__ import absolute_import, division, print_function

import os
import timeit

import numpy as np
import tensorflow as tf
from deepcell.applications import Application

from deepcell_spots.dotnet_losses import DotNetLosses
from deepcell_spots.postprocessing_utils import y_annotations_to_point_list_max
from deepcell_spots.preprocessing_utils import min_max_normalize
from deepcell.applications import CytoplasmSegmentation
from deepcell.applications import NuclearSegmentation
from deepcell_spots.applications import SpotDetection
from deepcell_spots.singleplex import match_spots_to_cells
from deepcell_toolbox.processing import histogram_normalization
from deepcell_toolbox.deep_watershed import deep_watershed


MODEL_PATH = ('https://deepcell-data.s3-us-west-1.amazonaws.com/'
'saved-models/SpotDetection-3.tar.gz')


class Polaris(Application):
"""Loads a :mod:`deepcell.model_zoo.featurenet.FeatureNet` model
for fluorescent spot detection with pretrained weights.
The ``predict`` method handles prep and post processing steps
to return a labeled image.
class Polaris(object):
"""Loads spot detection and cell segmentation applications
from deepcell_spots and deepcell_tf, respectively.
The ``predict`` method calls the predict method of each
application.
Example:
.. code-block:: python
from skimage.io import imread
from deepcell_spots.applications import Polaris
# Load the image
im = imread('spots_image.png')
# Load the images
spots_im = imread('spots_image.png')
cyto_im = imread('cyto_image.png')
# Expand image dimensions to rank 4
im = np.expand_dims(im, axis=-1)
im = np.expand_dims(im, axis=0)
spots_im = np.expand_dims(spots_im, axis=[0,-1])
cyto_im = np.expand_dims(cyto_im, axis=[0,-1])
# Concatenate images
im = np.concatenate((cyto_im, spots_im), axis=-1)
# Create the application
app = Polaris()
# create the labeled image
labeled_image = app.predict(im)
# Find the spot locations
result = app.predict(im)
spots_dict = result[0]['spots_assignment']
labeled_im = result[0]['cell_segmentation']
coords = result[0]['spot_locations']
Args:
model (tf.keras.Model): The model to load. If ``None``,
a pre-trained model will be downloaded.
segmentation_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
spots_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
"""

#: Metadata for the dataset used to train the model
dataset_metadata = {
'name': 'general_train', # update
'other': """Pooled FISH data including MERFISH data
and SunTag viral RNA data""" # update
}

#: Metadata for the model and training process
model_metadata = {
'batch_size': 1,
'lr': 0.01,
'lr_decay': 0.99,
'training_seed': 0,
'n_epochs': 10,
'training_steps_per_epoch': 552
}

def __init__(self, model=None):

if model is None:
archive_path = tf.keras.utils.get_file(
'SpotDetection.tgz', MODEL_PATH,
file_hash='2b9a46087b25e9aab20a2c9f67f4f559',
extract=True, cache_subdir='models'
)
model_path = os.path.splitext(archive_path)[0]
model = tf.keras.models.load_model(
model_path, custom_objects={
'regression_loss': DotNetLosses.regression_loss,
'classification_loss': DotNetLosses.classification_loss
}
)

super(Polaris, self).__init__(
model,
model_image_shape=model.input_shape[1:],
model_mpp=0.65,
preprocessing_fn=min_max_normalize,
postprocessing_fn=y_annotations_to_point_list_max,
dataset_metadata=self.dataset_metadata,
model_metadata=self.model_metadata)

def _postprocess(self, image, **kwargs):
"""Applies postprocessing function to image if one has been defined.
Differs from parent class in that it returns a set of coordinate spot
locations, so handling of dimensions differs.

Otherwise returns unmodified image.
Args:
image (numpy.array or list): Input to postprocessing function
either an ``numpy.array`` or list of ``numpy.arrays``.
Returns:
numpy.array: labeled image
"""
if self.postprocessing_fn is not None:
t = timeit.default_timer()
self.logger.debug('Post-processing results with %s and kwargs: %s',
self.postprocessing_fn.__name__, kwargs)

image = self.postprocessing_fn(image, **kwargs)

self.logger.debug('Post-processed results with %s in %s s',
self.postprocessing_fn.__name__,
timeit.default_timer() - t)

elif isinstance(image, list) and len(image) == 1:
image = image[0]

return image

def _predict_segmentation(self,
image,
batch_size=4,
image_mpp=None,
pad_mode='constant',
preprocess_kwargs={},
postprocess_kwargs={}):
"""Generates a list of coordinate spot locations of the input running
prediction with appropriate pre and post processing functions.
This differs from parent Application class which returns a labeled image.
Input images are required to have 4 dimensions
``[batch, x, y, channel]``. Additional empty dimensions can be added
using ``np.expand_dims``.
Args:
image (numpy.array): Input image with shape
``[batch, x, y, channel]``.
batch_size (int): Number of images to predict on per batch.
image_mpp (float): Microns per pixel for ``image``.
pad_mode (str): The padding mode, one of "constant" or "reflect".
preprocess_kwargs (dict): Keyword arguments to pass to the
pre-processing function.
postprocess_kwargs (dict): Keyword arguments to pass to the
post-processing function.
Raises:
ValueError: Input data must match required rank, calculated as one
dimension more (batch dimension) than expected by the model.
ValueError: Input data must match required number of channels.
Returns:
numpy.array: Coordinate spot locations
"""
# Check input size of image
if len(image.shape) != self.required_rank:
raise ValueError('Input data must have {} dimensions. '
'Input data only has {} dimensions'.format(
self.required_rank, len(image.shape)))

if image.shape[-1] != self.required_channels:
raise ValueError('Input data must have {} channels. '
'Input data only has {} channels'.format(
self.required_channels, image.shape[-1]))

# Resize image, returns unmodified if appropriate
resized_image = self._resize_input(image, image_mpp)
def __init__(self,
segmentation_model=None,
segmentation_compartment='cytoplasm',
spots_model=None):

# Generate model outputs
output_images = self._run_model(
image=resized_image, batch_size=batch_size,
pad_mode=pad_mode, preprocess_kwargs=preprocess_kwargs
)
self.spots_app = SpotDetection(model=spots_model)

# Resize output_images back to original resolution if necessary
label_image = self._resize_output(output_images, image.shape)
valid_compartments = ['cytoplasm', 'nucleus', 'None']
if segmentation_compartment not in valid_compartments:
raise ValueError('Invalid compartment supplied: {}. '
'Must be one of {}'.format(segmentation_compartment,
valid_compartments))

# Postprocess predictions to create label image
predicted_spots = self._postprocess(label_image, **postprocess_kwargs)

return predicted_spots
if segmentation_compartment == 'cytoplasm':
self.segmentation_app = CytoplasmSegmentation(model=segmentation_model)
self.segmentation_app.preprocessing_fn = histogram_normalization
self.segmentation_app.postprocessing_fn = deep_watershed
elif segmentation_compartment == 'nucleus':
self.segmentation_app = NuclearSegmentation(model=segmentation_model)
else:
self.segmentation_app = None

def predict(self,
image,
batch_size=4,
spots_image,
segmentation_image=None,
image_mpp=None,
pad_mode='reflect',
preprocess_kwargs=None,
postprocess_kwargs=None,
threshold=0.9):
"""Generates a list of coordinate spot locations of the input
running prediction with appropriate pre and post processing
functions.
spots_threshold=0.95,
spots_clip=False):
"""Generates prediction output consisting of a labeled cell segmentation image,
detected spot locations, and a dictionary of spot locations assigned to labeled
cells of the input.

Input images are required to have 4 dimensions
``[batch, x, y, channel]``. Channel dimension should be 2.

Additional empty dimensions can be added using ``np.expand_dims``.

Args:
image (numpy.array): Input image with shape
``[batch, x, y, channel]``.
batch_size (int): Number of images to predict on per batch.
image_mpp (float): Microns per pixel for ``image``.
pad_mode (str): The padding mode, one of "constant" or "reflect".
preprocess_kwargs (dict): Keyword arguments to pass to the
pre-processing function.
postprocess_kwargs (dict): Keyword arguments to pass to the
post-processing function.
cytoplasm_channel (int): Value should be 0 or 1 depending on the channel
containing the images for cell segmentation. Defaults to 0.
spots_channel (int): Value should be 0 or 1 depending on the channel
containing the images for spot detection. Defaults to 1.
threshold (float): Probability threshold for a pixel to be
considered as a spot.
clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to ``False``.
Raises:
ValueError: Input data must match required rank of the application,
calculated as one dimension more (batch dimension) than expected by the model.
ValueError: Input data must match required number of channels.
ValueError: Threshold value must be between 0 and 1.
ValueError: Segmentation application must be instantiated if segmentation
image is defined.

Returns:
numpy.array: Coordinate locations of detected spots.
list: List of dictionaries, length equal to batch dimension.
"""

if threshold < 0 or threshold > 1:
raise ValueError("""Enter a probability threshold value between
0 and 1.""")

if preprocess_kwargs is None:
preprocess_kwargs = {}

if postprocess_kwargs is None:
postprocess_kwargs = {
'threshold': threshold,
'min_distance': 1}

return self._predict_segmentation(
image,
batch_size=batch_size,
image_mpp=image_mpp,
pad_mode=pad_mode,
preprocess_kwargs=preprocess_kwargs,
postprocess_kwargs=postprocess_kwargs)
if spots_threshold < 0 or spots_threshold > 1:
raise ValueError('Threshold of {} was input. Threshold value must be '
'between 0 and 1.'.format(spots_threshold))

spots_result = self.spots_app.predict(spots_image,
threshold=spots_threshold,
clip=spots_clip)

if segmentation_image is not None:
if not self.segmentation_app:
raise ValueError('Segmentation application must be instantiated if '
'segmentation image is defined.')
else:
segmentation_result = self.segmentation_app.predict(segmentation_image,
image_mpp=image_mpp)
result = []
for i in range(len(spots_result)):
spots_dict = match_spots_to_cells(segmentation_result[i:i + 1],
spots_result[i])

result.append({'spots_assignment': spots_dict,
'cell_segmentation': segmentation_result[i:i + 1],
'spot_locations': spots_result[i]})

else:
result = spots_result

return result
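
Taken together, this pull request replaces the single-model Polaris application with one that composes a SpotDetection application and an optional cell segmentation application. Below is an illustrative end-to-end usage sketch based on the class docstring and the final predict signature in this diff; the file names are placeholders taken from the docstring example, and Polaris() downloads pretrained models when none are supplied.

import numpy as np
from skimage.io import imread

from deepcell_spots.applications import Polaris

# Load a spots image and a matching cytoplasm image (placeholder file names).
spots_im = imread('spots_image.png')
cyto_im = imread('cyto_image.png')

# Expand both images to rank 4: [batch, x, y, channel].
spots_im = np.expand_dims(spots_im, axis=[0, -1])
cyto_im = np.expand_dims(cyto_im, axis=[0, -1])

# Create the application; 'cytoplasm' is the default compartment,
# 'nucleus' and 'None' are the other accepted values.
app = Polaris(segmentation_compartment='cytoplasm')

# Run spot detection and cell segmentation, then unpack the per-batch
# dictionaries described in the class docstring.
result = app.predict(spots_image=spots_im,
                     segmentation_image=cyto_im,
                     spots_threshold=0.95,
                     spots_clip=False)

spots_dict = result[0]['spots_assignment']
labeled_im = result[0]['cell_segmentation']
coords = result[0]['spot_locations']

If segmentation_image is omitted, predict skips cell assignment and returns only the SpotDetection output, matching the final else branch above.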