Skip to content

Commit

Permalink
Maintenance to clip and threshold application arguments (#44)
Browse files Browse the repository at this point in the history
* Refactor spots_clip and spots_threshold to clip and threshold

* Change default clip value to True
  • Loading branch information
elaubsch committed Feb 17, 2023
1 parent a0dfabe commit 6eef203
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
31 changes: 15 additions & 16 deletions deepcell_spots/applications/polaris.py
Expand Up @@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Singleplex FISH analysis application"""

"""Singleplex and multiplex FISH analysis application"""

from __future__ import absolute_import, division, print_function

Expand Down Expand Up @@ -180,16 +181,16 @@ def __init__(self,
self.segmentation_app = None
warnings.warn('No segmentation application instantiated.')

def _predict_spots_image(self, spots_image, spots_threshold, spots_clip):
def _predict_spots_image(self, spots_image, threshold, clip):
"""Iterate through all channels and generate model output (probability maps)
Args:
spots_image (numpy.array): Input image for spot detection with shape
``[batch, x, y, channel]``.
spots_threshold (float): Probability threshold for a pixel to be
threshold (float): Probability threshold for a pixel to be
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to false.
clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to True.
Returns:
numpy.array: Output probability map with shape ``[batch, x, y, channel]``.
Expand All @@ -199,18 +200,16 @@ def _predict_spots_image(self, spots_image, spots_threshold, spots_clip):
for idx_channel in range(spots_image.shape[-1]):
output_image[..., idx_channel] = self.spots_app.predict(
image=spots_image[..., idx_channel:idx_channel+1],
# TODO: threshold is disabled, but must feed a float [0,1] number
threshold=spots_threshold,
clip=spots_clip
clip=clip
)['classification'][..., 1]
return output_image

def predict(self,
spots_image,
segmentation_image=None,
image_mpp=None,
spots_threshold=0.95,
spots_clip=False,
threshold=0.95,
clip=True,
maxpool_extra_pixel_num=0,
decoding_training_kwargs=None):
"""Generates prediction output consisting of a labeled cell segmentation image,
Expand All @@ -228,10 +227,10 @@ def predict(self,
segmentation_image (numpy.array): Input image for cell segmentation with shape
``[batch, x, y, channel]``. Defaults to None.
image_mpp (float): Microns per pixel for ``image``.
spots_threshold (float): Probability threshold for a pixel to be
threshold (float): Probability threshold for a pixel to be
considered as a spot.
spots_clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to false.
clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to True.
maxpool_extra_pixel_num (int): Number of extra pixel for max pooling. Defaults
to 0, means no max pooling. For any number t, there will be a pool with
shape ``[-t, t] x [-t, t]``.
Expand All @@ -247,16 +246,16 @@ def predict(self,
df_intensities (pandas.DataFrame): Columns are channels and rows are spots.
segmentation_result (numpy.array): Segmentation mask with shape ``[batch, x, y, 1]``.
"""
if spots_threshold < 0 or spots_threshold > 1:
if threshold < 0 or threshold > 1:
raise ValueError('Threshold of %s was input. Threshold value must be '
'between 0 and 1.'.format())

output_image = self._predict_spots_image(spots_image, spots_threshold, spots_clip)
output_image = self._predict_spots_image(spots_image, threshold, clip)

clipped_output_image = np.clip(output_image, 0, 1)
max_proj_images = np.max(clipped_output_image, axis=-1)
spots_locations = max_cp_array_to_point_list_max(max_proj_images,
threshold=spots_threshold, min_distance=1)
threshold=threshold, min_distance=1)

spots_intensities = extract_spots_prob_from_coords_maxpool(
clipped_output_image, spots_locations, extra_pixel_num=maxpool_extra_pixel_num)
Expand Down
4 changes: 2 additions & 2 deletions deepcell_spots/applications/polaris_test.py
Expand Up @@ -81,9 +81,9 @@ def test_polaris_app(self):
app = Polaris()
spots_image = np.random.rand(1, 128, 128, 1)
with self.assertRaises(ValueError):
_ = app.predict(spots_image=spots_image, spots_threshold=1.1)
_ = app.predict(spots_image=spots_image, threshold=1.1)
with self.assertRaises(ValueError):
_ = app.predict(spots_image=spots_image, spots_threshold=-1.1)
_ = app.predict(spots_image=spots_image, threshold=-1.1)

# test segmentation app error
app = Polaris(segmentation_type='no segmentation')
Expand Down
5 changes: 3 additions & 2 deletions deepcell_spots/applications/spot_detection.py
Expand Up @@ -240,7 +240,7 @@ def predict(self,
preprocess_kwargs=None,
postprocess_kwargs=None,
threshold=0.95,
clip=False):
clip=True):
"""Generates a list of coordinate spot locations of the input
running prediction with appropriate pre and post processing
functions.
Expand All @@ -260,8 +260,9 @@ def predict(self,
postprocess_kwargs (dict): Keyword arguments to pass to the
post-processing function.
threshold (float): Probability threshold for a pixel to be
considered as a spot.
considered as a spot.
clip (bool): Determines if pixel values will be clipped by percentile.
Defaults to True.
Raises:
ValueError: Input data must match required rank of the application,
Expand Down

0 comments on commit 6eef203

Please sign in to comment.