Merge c8a4055 into a0a884d

xuefei-wang committed Feb 7, 2023
2 parents a0a884d + c8a4055 commit 65090bc
Showing 18 changed files with 2,403 additions and 1,251 deletions.
1 change: 1 addition & 0 deletions deepcell_spots/applications/__init__.py
@@ -27,4 +27,5 @@
"""Applications for pre-trained spot detection models"""

from deepcell_spots.applications.spot_detection import SpotDetection
from deepcell_spots.applications.spot_decoding import SpotDecoding
from deepcell_spots.applications.polaris import Polaris
181 changes: 151 additions & 30 deletions deepcell_spots/applications/polaris.py
@@ -28,13 +28,40 @@
from __future__ import absolute_import, division, print_function

import warnings
import numpy as np
import pandas as pd

from deepcell.applications import CytoplasmSegmentation, NuclearSegmentation
from deepcell.applications import Mesmer
from deepcell_spots.applications import SpotDetection, SpotDecoding
from deepcell_spots.singleplex import match_spots_to_cells_as_vec_batched
from deepcell_toolbox.processing import histogram_normalization
from deepcell_toolbox.deep_watershed import deep_watershed
from deepcell_spots.postprocessing_utils import max_cp_array_to_point_list_max
from deepcell_spots.multiplex import extract_spots_prob_from_coords_maxpool


def output_to_df(spots_locations_vec, cell_id_list, decoding_result):
"""
    Formats model output from lists and arrays to a dataframe.
Args:
        spots_locations_vec (numpy.array): An array of spot coordinates and batch
            indices with shape ``[num_spots, 3]``; the columns are ``x``, ``y``,
            and ``batch_id``.
        cell_id_list (numpy.array): An array of the assigned cell ID for each spot
            with shape ``[num_spots,]``.
decoding_result (dict): Keys include: 'probability', 'predicted_id',
'predicted_name'.
Returns:
        pandas.DataFrame: A dataframe combining all input information.
"""
df = pd.DataFrame()
df[['x', 'y', 'batch_id']] = spots_locations_vec.astype(np.int32)
df['cell_id'] = cell_id_list
for name, val in decoding_result.items():
df[name] = val
return df
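

# A minimal sketch (not part of the original module) of how ``output_to_df`` is used;
# the coordinates, cell IDs, and decoding values below are hypothetical placeholders.
def _example_output_to_df():
    spots_locations_vec = np.array([[10, 12, 0],
                                    [33, 7, 0],
                                    [5, 41, 1]])  # columns: x, y, batch_id
    cell_id_list = np.array([1, 0, 2])  # 0 marks spots that fall on background
    decoding_result = {'probability': [0.98, 0.41, 0.87],
                       'predicted_id': [3, 0, 7],
                       'predicted_name': ['GeneA', 'Background', 'GeneB']}
    # Yields a dataframe with columns x, y, batch_id, cell_id, probability,
    # predicted_id, predicted_name
    return output_to_df(spots_locations_vec, cell_id_list, decoding_result)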


class Polaris(object):
@@ -44,9 +71,7 @@ class Polaris(object):
The ``predict`` method calls the predict method of each
application.
Example::
from skimage.io import imread
from deepcell_spots.applications import Polaris
@@ -59,17 +84,29 @@ class Polaris(object):
spots_im = np.expand_dims(spots_im, axis=[0,-1])
cyto_im = np.expand_dims(cyto_im, axis=[0,-1])
####################################################################
# Singleplex case:
app = Polaris(image_type='singleplex')
df_spots, df_intensities, segmentation_result = app.predict(
spots_image=spots_im,
segmentation_image=cyto_im)
####################################################################
# Multiplex case:
rounds = 10
channels = 2
df_barcodes = pd.read_csv('barcodes.csv', index_col=0)
        app = Polaris(image_type='multiplex',
decoding_kwargs={'rounds': rounds,
'channels': channels,
'df_barcodes': df_barcodes})
df_spots, df_intensities, segmentation_result = app.predict(
spots_image=spots_im,
segmentation_image=cyto_im)
Args:
image_type (str): The type of the image. Valid values are
'singleplex' and 'multiplex'. Defaults to 'singleplex'.
segmentation_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
        segmentation_type (str): The cellular compartment
@@ -78,14 +115,52 @@ class Polaris(object):
to 'cytoplasm'.
spots_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
        decoding_kwargs (dict): Keyword arguments to pass to the spot decoding
            application, such as ``df_barcodes``, ``rounds``, and ``channels``.
            Defaults to ``None``, in which case no decoding is performed.
df_barcodes (pandas.DataFrame): Codebook, one column is gene names ('code_name'),
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
            For example, for a (rounds=10, channels=2) codebook, it should look like the following
(see `notebooks/Multiplex FISH Analysis.ipynb` for examples)::
Index:
RangeIndex (starting from 1)
Columns:
Name: code_name, dtype: object
Name: r0c0, dtype: int64
Name: r0c1, dtype: int64
Name: r1c0, dtype: int64
Name: r1c1, dtype: int64
...
Name: r9c0, dtype: int64
Name: r9c1, dtype: int64
"""

def __init__(self,
image_type='singleplex',
segmentation_model=None,
segmentation_type='cytoplasm',
spots_model=None,
decoding_kwargs=None):

self.spots_app = SpotDetection(model=spots_model)
# Disable postprocessing_fn to return the full images
self.spots_app.postprocessing_fn = None

valid_image_types = ['singleplex', 'multiplex']
if image_type not in valid_image_types:
raise ValueError('Invalid image type supplied: {}. '
'Must be one of {}'.format(image_type,
valid_image_types))

self.image_type = image_type
if self.image_type == 'singleplex':
self.decoding_app = None
elif self.image_type == 'multiplex':
if not decoding_kwargs:
self.decoding_app = None
warnings.warn('No spot decoding application instantiated.')
else:
self.decoding_app = SpotDecoding(**decoding_kwargs)

valid_compartments = ['cytoplasm', 'nucleus', 'mesmer', 'no segmentation']
if segmentation_type not in valid_compartments:
@@ -105,18 +180,45 @@ def __init__(self,
self.segmentation_app = None
warnings.warn('No segmentation application instantiated.')

def _predict_spots_image(self, spots_image, spots_threshold, spots_clip):
"""Iterate through all channels and generate model output (probability maps)
Args:
spots_image (numpy.array): Input image for spot detection with shape
``[batch, x, y, channel]``.
spots_threshold (float): Probability threshold for a pixel to be
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to false.
Returns:
numpy.array: Output probability map with shape ``[batch, x, y, channel]``.
"""

output_image = np.zeros_like(spots_image, dtype=np.float32)
for idx_channel in range(spots_image.shape[-1]):
output_image[..., idx_channel] = self.spots_app.predict(
image=spots_image[..., idx_channel:idx_channel+1],
                # TODO: threshold is disabled, but a float in [0, 1] must still be provided
threshold=spots_threshold,
clip=spots_clip
)['classification'][..., 1]
return output_image

def predict(self,
spots_image,
segmentation_image=None,
image_mpp=None,
spots_threshold=0.95,
spots_clip=False,
maxpool_extra_pixel_num=0,
decoding_training_kwargs=None):
"""Generates prediction output consisting of a labeled cell segmentation image,
detected spot locations, and a dictionary of spot locations assigned to labeled
cells of the input.
Input images are required to have 4 dimensions
``[batch, x, y, channel]``. Channel dimension should be 1.
Additional empty dimensions can be added using ``np.expand_dims``.
@@ -130,22 +232,38 @@ def predict(self,
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to false.
            maxpool_extra_pixel_num (int): Number of extra pixels for max pooling.
                Defaults to 0, which means no max pooling. For any number t, the
                pooling window has shape ``[-t, t] x [-t, t]``.
            decoding_training_kwargs (dict): Keyword arguments for training the
                decoding model, including ``num_iter``, ``batch_size``, and
                ``thres_prob``.
Raises:
ValueError: Threshold value must be between 0 and 1.
ValueError: Segmentation application must be instantiated if segmentation
image is defined.
Returns:
            df_spots (pandas.DataFrame): Columns are x, y, batch_id, cell_id, probability,
                predicted_id, predicted_name. A cell_id of 0 means background.
df_intensities (pandas.DataFrame): Columns are channels and rows are spots.
segmentation_result (numpy.array): Segmentation mask with shape ``[batch, x, y, 1]``.
"""

if spots_threshold < 0 or spots_threshold > 1:
            raise ValueError('Threshold of {} was input. Threshold value must be '
                             'between 0 and 1.'.format(spots_threshold))

output_image = self._predict_spots_image(spots_image, spots_threshold, spots_clip)

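        # Clip probabilities to [0, 1], max-project across channels, and convert
        # local maxima above the threshold into per-batch lists of spot coordinates.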
clipped_output_image = np.clip(output_image, 0, 1)
max_proj_images = np.max(clipped_output_image, axis=-1)
spots_locations = max_cp_array_to_point_list_max(max_proj_images,
threshold=spots_threshold, min_distance=1)

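        # Read out each spot's per-channel detection probability from the clipped
        # probability maps, max pooling over a small window when
        # maxpool_extra_pixel_num > 0.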
spots_intensities = extract_spots_prob_from_coords_maxpool(
clipped_output_image, spots_locations, extra_pixel_num=maxpool_extra_pixel_num)
spots_intensities_vec = np.concatenate(spots_intensities)
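        # Append the batch index to each (x, y) coordinate and flatten the per-batch
        # lists into a single [num_spots, 3] array of (x, y, batch_id) rows.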
spots_locations_vec = np.concatenate([np.concatenate(
[item, [[idx_batch]] * len(item)], axis=1)
for idx_batch, item in enumerate(spots_locations)])

if segmentation_image is not None:
if not self.segmentation_app:
@@ -154,16 +272,19 @@ def predict(self,
else:
segmentation_result = self.segmentation_app.predict(segmentation_image,
image_mpp=image_mpp)
spots_cell_assignments_vec = match_spots_to_cells_as_vec_batched(
segmentation_result, spots_locations)
else:
segmentation_result = None
spots_cell_assignments_vec = None

        if self.decoding_app is not None:
            # Fall back to an empty dict so ** expansion works when no training
            # kwargs are supplied
            decoding_result = self.decoding_app.predict(
                spots_intensities_vec, **(decoding_training_kwargs or {}))
        else:
decoding_result = {'probability': None,
'predicted_id': None, 'predicted_name': None}

df_spots = output_to_df(spots_locations_vec, spots_cell_assignments_vec, decoding_result)
df_intensities = pd.DataFrame(spots_intensities_vec)
return df_spots, df_intensities, segmentation_result
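

# A minimal sketch (not part of the original module) showing one way to build the
# ``df_barcodes`` codebook described in the ``Polaris`` docstring and to run the
# multiplex workflow end to end. File names, barcode values, and the training
# keyword values are hypothetical placeholders.
def _example_multiplex_workflow():
    from skimage.io import imread

    rounds, channels = 2, 2
    # Codebook: a 'code_name' column plus one binary column per (round, channel),
    # with the index starting at 1
    df_barcodes = pd.DataFrame(
        {'code_name': ['GeneA', 'GeneB'],
         'r0c0': [1, 0], 'r0c1': [0, 1],
         'r1c0': [0, 1], 'r1c1': [1, 0]},
        index=[1, 2])

    app = Polaris(image_type='multiplex',
                  decoding_kwargs={'rounds': rounds,
                                   'channels': channels,
                                   'df_barcodes': df_barcodes})

    # The spots image is assumed here to be [batch, x, y, rounds * channels];
    # the segmentation image is a single-channel image of the same spatial size.
    spots_im = np.expand_dims(imread('spots_image.tif'), axis=0)
    cyto_im = np.expand_dims(imread('cyto_image.tif'), axis=[0, -1])

    df_spots, df_intensities, segmentation_result = app.predict(
        spots_image=spots_im,
        segmentation_image=cyto_im,
        decoding_training_kwargs={'num_iter': 500, 'batch_size': 1000})
    return df_spots, df_intensities, segmentation_result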
