Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decoding functionality to Polaris #36

Merged
merged 71 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
a506697
Add latest decoding function as is
xuefei-wang Nov 11, 2022
f9177f8
Remove redundent code in decoding function
xuefei-wang Nov 11, 2022
3d0f10b
Add liscence for decoding_functions file
xuefei-wang Nov 18, 2022
9542980
Fix typo: polaris predict image channel should be 1
xuefei-wang Nov 18, 2022
155372b
Initialize spot_decoding model
xuefei-wang Nov 18, 2022
da06afb
Deprecate old functions in multiplex.py
xuefei-wang Nov 18, 2022
3993292
Add new utils functions for decoding
xuefei-wang Nov 22, 2022
4241534
Refactor decoding functions
xuefei-wang Nov 22, 2022
67ae54c
Implement spot decoding application
xuefei-wang Nov 22, 2022
5c0db27
Add decoding functionality to Polaris app; format output
xuefei-wang Nov 22, 2022
fed160a
Add comment about cell_id=0 case: background
xuefei-wang Nov 22, 2022
bb0ecbd
Make maxpool_extra_picel_num default to 0
xuefei-wang Nov 22, 2022
b365a17
Copy df_barcodes so that it does not modify input data
xuefei-wang Nov 22, 2022
8ec7aea
Update analysis notebooks
xuefei-wang Nov 22, 2022
1d3c759
Specify url for torch packages to use cpu-only ones
xuefei-wang Nov 22, 2022
7382e94
Use a larger fov for multiplex example
xuefei-wang Nov 22, 2022
0aea80a
Update requirements.txt (option: extra index url)
xuefei-wang Nov 22, 2022
55465d4
PEP8 formatted
xuefei-wang Nov 23, 2022
4aa0675
PEP8 formating continued
xuefei-wang Nov 23, 2022
d22e8a0
PEP8 continued
xuefei-wang Nov 23, 2022
39351d5
PEP8 continued again
xuefei-wang Nov 23, 2022
2ed66b4
PEP8 continued again * 2
xuefei-wang Nov 24, 2022
fa860fb
Point users to notebook for formatting df_barcodes
xuefei-wang Nov 24, 2022
abbbd34
PEP8 rm trailing white spaces
xuefei-wang Nov 24, 2022
97cbec6
Initialize tests to decoding related functions
xuefei-wang Nov 24, 2022
f9a7eac
Revert PEP8 too short wrapped lines
xuefei-wang Dec 2, 2022
86d3fc8
Implement tests for decoding
xuefei-wang Dec 2, 2022
f65387a
Fix bug: import missing package name
xuefei-wang Dec 5, 2022
3936ae8
Fix bug: replace range with np.arange
xuefei-wang Dec 5, 2022
a1260cc
Lint test files for decoding
xuefei-wang Dec 6, 2022
153f0eb
Fix typo: extra comma
xuefei-wang Dec 8, 2022
a7fb152
fix test bug: class_probs has #barcodes dim
xuefei-wang Jan 3, 2023
c3fde7b
Fix test bug: type error instead of value error
xuefei-wang Jan 3, 2023
0027cfd
Fix test bug: format file
xuefei-wang Jan 3, 2023
64e7286
Fix test bug: missing one dim
xuefei-wang Jan 3, 2023
b23cf0a
Fix test bug: format (rm blank lines)
xuefei-wang Jan 3, 2023
e7cf81e
Fix test bug: redundant extra dim in index for barcodes
xuefei-wang Jan 3, 2023
ab3ea8e
Fix test bug: missing num_image dim & add more rigorous tests
xuefei-wang Jan 3, 2023
9b1d37e
Fix test bug: redundant extra dim in index for barcodes (cont'd)
xuefei-wang Jan 3, 2023
fbeb984
Fix test bug: formating
xuefei-wang Jan 3, 2023
566a1b8
FIx test bug: not nesting arrays
xuefei-wang Jan 3, 2023
131d7f3
Comment out unpassed test for now
xuefei-wang Jan 3, 2023
28978f9
Fix test bug: numpy.int does not exist due to version upgrade
xuefei-wang Jan 3, 2023
f5974f3
Fix test bug: assert list equal
xuefei-wang Jan 3, 2023
18516aa
Fix test bug: slice_annotated_image bug found, np.int does not exist
xuefei-wang Jan 3, 2023
c3ce153
Fix test bug: typo
xuefei-wang Jan 3, 2023
30b8c34
fix test bug: formating
xuefei-wang Jan 3, 2023
3412bd5
Fix test bug: barcode index starts from 1
xuefei-wang Jan 3, 2023
eedc015
Fix test bug: multiple tests got crosswired
xuefei-wang Jan 3, 2023
8996c66
Comment out test_slice_annotated_image for now
xuefei-wang Jan 3, 2023
22afcb4
Fix test bug: rm ambiguity in test data
xuefei-wang Jan 3, 2023
c28d635
Test pining python version for tests
xuefei-wang Jan 3, 2023
cc140d2
Test pining python version for tests cont'd
xuefei-wang Jan 3, 2023
d849e28
Apply suggestions from code review and revert python version back
xuefei-wang Jan 4, 2023
c2c4a80
Apply suggestions from code review and revert python version back cont'd
xuefei-wang Jan 4, 2023
5fc7775
Clips probabilities before decoding (#39)
elaubsch Feb 2, 2023
2269b68
Merge branch 'master' into decoding
elaubsch Feb 2, 2023
ee282cb
Doctoring formatting
elaubsch Feb 7, 2023
1192f04
Remove code-block
elaubsch Feb 7, 2023
0b6f592
Add code-block
elaubsch Feb 7, 2023
3fedce5
Docstring formatting
elaubsch Feb 7, 2023
b892320
Unpin python 3.10 version
elaubsch Feb 7, 2023
4eca71f
Remove code-block directive from all scripts
elaubsch Feb 7, 2023
497562b
Refactor r and c to rounds and channels
elaubsch Feb 7, 2023
35c3dd6
Fix misspelling of singleplex
elaubsch Feb 7, 2023
7d80d32
Remove empty dictionary as default
elaubsch Feb 7, 2023
727592f
Remove split empty parens
elaubsch Feb 7, 2023
cbe6c8f
Add deprecation warnings
elaubsch Feb 7, 2023
be92fbc
Rewrite slicing in _predict_spots_image for clarity
elaubsch Feb 7, 2023
27f1d96
Refactor r and c to rounds and channels in tests and revert indexing …
elaubsch Feb 7, 2023
c8a4055
Remove dim calculation in reshape_torch_array
elaubsch Feb 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deepcell_spots/applications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@
"""Applications for pre-trained spot detection models"""

from deepcell_spots.applications.spot_detection import SpotDetection
from deepcell_spots.applications.spot_decoding import SpotDecoding
from deepcell_spots.applications.polaris import Polaris
181 changes: 151 additions & 30 deletions deepcell_spots/applications/polaris.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,40 @@
from __future__ import absolute_import, division, print_function

import warnings
import numpy as np
import pandas as pd

from deepcell.applications import CytoplasmSegmentation, NuclearSegmentation
from deepcell.applications import Mesmer
from deepcell_spots.applications import SpotDetection
from deepcell_spots.singleplex import match_spots_to_cells
from deepcell_spots.applications import SpotDetection, SpotDecoding
from deepcell_spots.singleplex import match_spots_to_cells_as_vec_batched
from deepcell_toolbox.processing import histogram_normalization
from deepcell_toolbox.deep_watershed import deep_watershed
from deepcell_spots.postprocessing_utils import max_cp_array_to_point_list_max
from deepcell_spots.multiplex import extract_spots_prob_from_coords_maxpool


def output_to_df(spots_locations_vec, cell_id_list, decoding_result):
"""
Formats model output from lists and arrays to dataframe.

Args:
spots_locations_vec (numpy.array): An array of spots coordinates with
shape ``[num_spots, 2]``.
cell_id_list (numpy.array): An array of assigned cell id for each spot
with shape ``[num_spots,]``.
decoding_result (dict): Keys include: 'probability', 'predicted_id',
'predicted_name'.

Returns:
pandas.DataFrame: A dataframe combines all input information.
"""
df = pd.DataFrame()
df[['x', 'y', 'batch_id']] = spots_locations_vec.astype(np.int32)
df['cell_id'] = cell_id_list
for name, val in decoding_result.items():
df[name] = val
return df


class Polaris(object):
Expand All @@ -44,9 +71,7 @@ class Polaris(object):
The ``predict`` method calls the predict method of each
application.

Example:

.. code-block:: python
Example::

from skimage.io import imread
from deepcell_spots.applications import Polaris
Expand All @@ -59,17 +84,29 @@ class Polaris(object):
spots_im = np.expand_dims(spots_im, axis=[0,-1])
cyto_im = np.expand_dims(cyto_im, axis=[0,-1])

# Create the application
app = Polaris()
####################################################################
# Singleplex case:
app = Polaris(image_type='singleplex')
df_spots, df_intensities, segmentation_result = app.predict(
spots_image=spots_im,
segmentation_image=cyto_im)

# Find the spot locations
result = app.predict(spots_image=spots_im,
####################################################################
# Multiplex case:
rounds = 10
channels = 2
df_barcodes = pd.read_csv('barcodes.csv', index_col=0)
app = Polaris(image_type='singleplex',
decoding_kwargs={'rounds': rounds,
'channels': channels,
'df_barcodes': df_barcodes})
df_spots, df_intensities, segmentation_result = app.predict(
spots_image=spots_im,
segmentation_image=cyto_im)
spots_dict = result[0]['spots_assignment']
labeled_im = result[0]['cell_segmentation']
coords = result[0]['spot_locations']

Args:
image_type (str): The type of the image. Valid values are
'singleplex' and 'multiplex'. Defaults to 'singleplex'.
elaubsch marked this conversation as resolved.
Show resolved Hide resolved
segmentation_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
segmentation_compartment (str): The cellular compartment
Expand All @@ -78,14 +115,52 @@ class Polaris(object):
to 'cytoplasm'.
spots_model (tf.keras.Model): The model to load.
If ``None``, a pre-trained model will be downloaded.
decoding_kwargs (dict): Keyword arguments to pass to the decoding method.
df_barcodes, rounds, channels. Defaults to empty, no decoding is performed.
df_barcodes (pandas.DataFrame): Codebook, one column is gene names ('code_name'),
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
For exmaple, for a (rounds=10, channels=2) codebook, it should look the following
(see `notebooks/Multiplex FISH Analysis.ipynb` for examples)::

Index:
RangeIndex (starting from 1)
Columns:
Name: code_name, dtype: object
Name: r0c0, dtype: int64
Name: r0c1, dtype: int64
Name: r1c0, dtype: int64
Name: r1c1, dtype: int64
...
Name: r9c0, dtype: int64
Name: r9c1, dtype: int64
"""

def __init__(self,
image_type='singleplex',
segmentation_model=None,
segmentation_type='cytoplasm',
spots_model=None):
spots_model=None,
decoding_kwargs=None):

self.spots_app = SpotDetection(model=spots_model)
# Disable postprocessing_fn to return the full images
self.spots_app.postprocessing_fn = None

valid_image_types = ['singleplex', 'multiplex']
if image_type not in valid_image_types:
raise ValueError('Invalid image type supplied: {}. '
'Must be one of {}'.format(image_type,
valid_image_types))

self.image_type = image_type
if self.image_type == 'singleplex':
self.decoding_app = None
elif self.image_type == 'multiplex':
if not decoding_kwargs:
elaubsch marked this conversation as resolved.
Show resolved Hide resolved
self.decoding_app = None
warnings.warn('No spot decoding application instantiated.')
else:
self.decoding_app = SpotDecoding(**decoding_kwargs)

valid_compartments = ['cytoplasm', 'nucleus', 'mesmer', 'no segmentation']
if segmentation_type not in valid_compartments:
Expand All @@ -105,18 +180,45 @@ def __init__(self,
self.segmentation_app = None
warnings.warn('No segmentation application instantiated.')

def _predict_spots_image(self, spots_image, spots_threshold, spots_clip):
"""Iterate through all channels and generate model output (probability maps)

Args:
spots_image (numpy.array): Input image for spot detection with shape
``[batch, x, y, channel]``.
spots_threshold (float): Probability threshold for a pixel to be
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to false.

Returns:
numpy.array: Output probability map with shape ``[batch, x, y, channel]``.
"""

output_image = np.zeros_like(spots_image, dtype=np.float32)
for idx_channel in range(spots_image.shape[-1]):
output_image[..., idx_channel] = self.spots_app.predict(
image=spots_image[..., idx_channel:idx_channel+1],
# TODO: threshold is disabled, but must feed a float [0,1] number
threshold=spots_threshold,
clip=spots_clip
)['classification'][..., 1]
return output_image

def predict(self,
spots_image,
segmentation_image=None,
image_mpp=None,
spots_threshold=0.95,
spots_clip=False):
spots_clip=False,
maxpool_extra_pixel_num=0,
decoding_training_kwargs=None):
"""Generates prediction output consisting of a labeled cell segmentation image,
detected spot locations, and a dictionary of spot locations assigned to labeled
cells of the input.

Input images are required to have 4 dimensions
``[batch, x, y, channel]``. Channel dimension should be 2.
``[batch, x, y, channel]``. Channel dimension should be 1.

Additional empty dimensions can be added using ``np.expand_dims``.

Expand All @@ -130,22 +232,38 @@ def predict(self,
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to false.
maxpool_extra_pixel_num (int): Number of extra pixel for max pooling. Defaults
to 0, means no max pooling. For any number t, there will be a pool with
shape ``[-t, t] x [-t, t]``.
decoding_training_kwargs (dict): Including num_iter, batch_size, thres_prob.
Raises:
ValueError: Threshold value must be between 0 and 1.
ValueError: Segmentation application must be instantiated if segmentation
image is defined.

Returns:
list: List of dictionaries, length equal to batch dimension.
df_spots (pandas.DataFrame): Columns are x, y, batch_id, cell_id, probability,
predicted_id, preicted_name. Cell_id = 0 means background.
df_intensities (pandas.DataFrame): Columns are channels and rows are spots.
segmentation_result (numpy.array): Segmentation mask with shape ``[batch, x, y, 1]``.
"""

if spots_threshold < 0 or spots_threshold > 1:
raise ValueError('Threshold of %s was input. Threshold value must be '
'between 0 and 1.'.format())

spots_result = self.spots_app.predict(spots_image,
threshold=spots_threshold,
clip=spots_clip)
output_image = self._predict_spots_image(spots_image, spots_threshold, spots_clip)

clipped_output_image = np.clip(output_image, 0, 1)
max_proj_images = np.max(clipped_output_image, axis=-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the shape of output_image is expected to be, but it's worth double-checking (and eventually testing) that you are taking the maximum projection along the dimension you expect.

spots_locations = max_cp_array_to_point_list_max(max_proj_images,
threshold=spots_threshold, min_distance=1)

spots_intensities = extract_spots_prob_from_coords_maxpool(
clipped_output_image, spots_locations, extra_pixel_num=maxpool_extra_pixel_num)
spots_intensities_vec = np.concatenate(spots_intensities)
spots_locations_vec = np.concatenate([np.concatenate(
[item, [[idx_batch]] * len(item)], axis=1)
for idx_batch, item in enumerate(spots_locations)])
Comment on lines +261 to +266
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's tough to say without knowing the shapes/types of all the inputs, but this looks like it might be creating a ragged array (i.e. where the shape is not a constant along one of the dimensions) which could have poor performance depending on the types of subsequent operations performed.

If this is the case, it might be worth switching to a dictionary or other mapping-type data structure.


if segmentation_image is not None:
if not self.segmentation_app:
Expand All @@ -154,16 +272,19 @@ def predict(self,
else:
segmentation_result = self.segmentation_app.predict(segmentation_image,
image_mpp=image_mpp)
result = []
for i in range(len(spots_result)):
spots_dict = match_spots_to_cells(segmentation_result[i:i + 1],
spots_result[i])

result.append({'spots_assignment': spots_dict,
'cell_segmentation': segmentation_result[i:i + 1],
'spot_locations': spots_result[i]})
spots_cell_assignments_vec = match_spots_to_cells_as_vec_batched(
segmentation_result, spots_locations)
else:
segmentation_result = None
spots_cell_assignments_vec = None

if self.decoding_app is not None:
decoding_result = self.decoding_app.predict(
spots_intensities_vec, **decoding_training_kwargs)
else:
result = spots_result
decoding_result = {'probability': None,
'predicted_id': None, 'predicted_name': None}

return result
df_spots = output_to_df(spots_locations_vec, spots_cell_assignments_vec, decoding_result)
df_intensities = pd.DataFrame(spots_intensities_vec)
return df_spots, df_intensities, segmentation_result
Loading