diff --git a/Dockerfile b/Dockerfile index 4aee9de3..6c3f1743 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,6 +31,8 @@ RUN apt-get update && apt-get install -y \ build-essential libglib2.0-0 && \ rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y git + COPY requirements.txt requirements-no-deps.txt ./ RUN pip install --no-cache-dir -r requirements.txt && \ diff --git a/redis_consumer/consumers/__init__.py b/redis_consumer/consumers/__init__.py index 387fe28d..de3ed386 100644 --- a/redis_consumer/consumers/__init__.py +++ b/redis_consumer/consumers/__init__.py @@ -36,6 +36,7 @@ from redis_consumer.consumers.segmentation_consumer import SegmentationConsumer from redis_consumer.consumers.caliban_consumer import CalibanConsumer from redis_consumer.consumers.mesmer_consumer import MesmerConsumer +from redis_consumer.consumers.polaris_consumer import PolarisConsumer # TODO: Import future custom Consumer classes. @@ -47,6 +48,7 @@ 'multiplex': MesmerConsumer, # deprecated, use "mesmer" instead. 'mesmer': MesmerConsumer, 'caliban': CalibanConsumer, + 'polaris': PolarisConsumer, # TODO: Add future custom Consumer classes here. } diff --git a/redis_consumer/consumers/polaris_consumer.py b/redis_consumer/consumers/polaris_consumer.py new file mode 100644 index 00000000..8f70673d --- /dev/null +++ b/redis_consumer/consumers/polaris_consumer.py @@ -0,0 +1,190 @@ +# Copyright 2016-2022 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. 
class PolarisConsumer(TensorFlowServingConsumer):
    """Consume spot-detection jobs and upload the plotted/serialized results."""

    # Spellings accepted for boolean job parameters stored as text in Redis.
    _TRUE_STRINGS = frozenset(('1', 'true', 't', 'yes', 'y', 'on'))
    _FALSE_STRINGS = frozenset(('0', 'false', 'f', 'no', 'n', 'off'))

    @classmethod
    def _parse_clip_flag(cls, value):
        """Convert a job parameter into a real ``bool``.

        Args:
            value: A ``bool`` (e.g. the settings default) or its textual
                form as read back from the Redis hash.

        Returns:
            bool: The parsed flag.

        Raises:
            ValueError: If ``value`` is a string that is not a recognised
                boolean spelling.
        """
        if isinstance(value, str):
            text = value.strip().lower()
            if text in cls._TRUE_STRINGS:
                return True
            if text in cls._FALSE_STRINGS:
                return False
            raise ValueError('Invalid boolean value for clip: {!r}'.format(value))
        return bool(value)

    @staticmethod
    def _add_batch_and_channel_dims(image):
        """Reshape a 2D or 3D image into a 4D array with a trailing channel axis.

        Singleton axes are squeezed first. A 2D result gains a leading and a
        trailing axis. A 3D result is treated as a stack: when its last axis
        is shorter than its middle axis, the last axis is moved to the front
        before the trailing channel axis is appended.

        Args:
            image: Array-like image data.

        Returns:
            numpy.ndarray: The 4D image.

        Raises:
            ValueError: If the squeezed image is neither 2D nor 3D.
        """
        image = np.squeeze(image)
        if image.ndim == 2:
            # Indexing with newaxis instead of a multi-axis expand_dims keeps
            # this working on NumPy releases older than 1.18.
            return image[np.newaxis, ..., np.newaxis]
        if image.ndim == 3:
            # check if batch first or last
            if image.shape[2] < image.shape[1]:
                image = np.rollaxis(image, 2, 0)
            return image[..., np.newaxis]
        raise ValueError('Image with {} shape was uploaded, but Polaris only '
                         'supports multi-batch or multi-channel images.'.format(
                             np.shape(image)))

    def save_output(self, coords, image, save_name):
        """Plot and serialize the detected spots, zip them, and upload the zip.

        For every entry of ``coords`` two files are written: a ``.tif``
        rendering of ``image[i]`` with the spots drawn as red circles, and a
        ``.npy`` file holding ``coords[i]``.

        Args:
            coords: Sequence of 2D arrays, one per image; column 0 is plotted
                on the y-axis and column 1 on the x-axis.
            image: Indexable collection of images aligned with ``coords``.
            save_name: Original file name; its stem prefixes the output files.

        Returns:
            tuple: ``(dest, output_url)`` as returned by ``self.storage.upload``.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            subdir = os.path.dirname(save_name.replace(tempdir, ''))
            # A leading separator would make os.path.join discard ``tempdir``
            # and write outside of the temporary directory.
            subdir = subdir.lstrip(os.path.sep)
            name = os.path.splitext(os.path.basename(save_name))[0]

            # The nested directory does not exist yet inside a fresh tempdir.
            output_dir = os.path.join(tempdir, subdir)
            os.makedirs(output_dir, exist_ok=True)

            plt.ioff()

            outpaths = []
            for i, spots in enumerate(coords):
                prefix = '{}_{}'.format(name, i) if name else '{}'.format(i)
                img_path = os.path.join(output_dir, '{}.tif'.format(prefix))
                coords_path = os.path.join(output_dir, '{}.npy'.format(prefix))

                frame = np.asarray(image[i])
                if frame.ndim == 3 and frame.shape[-1] == 1:
                    # Some matplotlib releases reject (rows, cols, 1) input.
                    frame = frame[..., 0]

                # Save image with plotted spot locations. The figure is closed
                # explicitly: pyplot keeps every open figure alive, which
                # leaks memory in a long-running consumer.
                fig = plt.figure()
                try:
                    axes = fig.add_subplot(1, 1, 1)
                    axes.imshow(frame, cmap='gray')
                    axes.scatter(spots[:, 1], spots[:, 0],
                                 edgecolors='r', facecolors='None')
                    axes.set_xticks([])
                    axes.set_yticks([])
                    fig.savefig(img_path)
                finally:
                    plt.close(fig)

                # Save coordinates
                np.save(coords_path, spots)

                outpaths.extend([img_path, coords_path])

            # Bundle every output file into a single zip archive
            zip_file = utils.zip_files(outpaths, tempdir)

            # Upload the zip file to cloud storage bucket
            cleaned = zip_file.replace(tempdir, '')
            upload_subdir = os.path.dirname(utils.strip_bucket_path(cleaned))
            upload_subdir = upload_subdir if upload_subdir else None
            dest, output_url = self.storage.upload(zip_file, subdir=upload_subdir)

            return dest, output_url

    def _consume(self, redis_hash):
        """Process one spot-detection job described by ``redis_hash``.

        Args:
            redis_hash: Key of the Redis hash holding the job fields
                (``input_file_name`` and optionally ``channels``,
                ``threshold``, ``clip``, ``original_name``).

        Returns:
            The job status: the stored status if the job was already
            finished, otherwise ``self.final_status``.

        Raises:
            ValueError: If the image is neither 2D nor 3D after squeezing, or
                if ``threshold``/``clip`` cannot be parsed.
        """
        start = timeit.default_timer()
        hvals = self.redis.hgetall(redis_hash)
        status = hvals.get('status')

        if status in self.finished_statuses:
            self.logger.warning('Found completed hash `%s` with status %s.',
                                redis_hash, status)
            return status

        self.logger.debug('Found hash to process `%s` with status `%s`.',
                          redis_hash, status)

        self.update_key(redis_hash, {
            'status': 'started',
            'identity_started': self.name,
        })

        # Get model_name and version
        model_name, model_version = settings.POLARIS_MODEL.split(':')

        download_start = timeit.default_timer()

        # Load input image
        fname = hvals.get('input_file_name')
        image = self._add_batch_and_channel_dims(self.download_image(fname))

        # Pre-process data before sending to the model
        self.update_key(redis_hash, {
            'status': 'pre-processing',
            'download_time': timeit.default_timer() - download_start,
        })

        # detect dimension order and add to redis
        dim_order = self.detect_dimension_order(image, model_name, model_version)
        self.update_key(redis_hash, {
            'dim_order': ','.join(dim_order)
        })

        # Validate input image
        if hvals.get('channels'):
            channels = [int(c) for c in hvals.get('channels').split(',')]
        else:
            channels = None

        image = self.validate_model_input(image, model_name, model_version,
                                          channels=channels)

        # Send data to the model
        self.update_key(redis_hash, {'status': 'predicting'})

        app = self.get_grpc_app(settings.POLARIS_MODEL, SpotDetection)

        # with new batching update in deepcell.applications,
        # app.predict() cannot handle a batch_size of None.
        batch_size = app.model.get_batch_size()

        # Job fields are assumed to come back from Redis as text (the
        # ``channels`` field is parsed the same way), so cast them before
        # they reach the model; otherwise 'False' would be truthy.
        threshold = hvals.get('threshold')
        if threshold in (None, ''):
            threshold = settings.POLARIS_THRESHOLD
        threshold = float(threshold)

        clip = hvals.get('clip')
        if clip in (None, ''):
            clip = settings.POLARIS_CLIP
        clip = self._parse_clip_flag(clip)

        results = app.predict(image, batch_size=batch_size, threshold=threshold,
                              clip=clip)

        # Save the post-processed results to a file
        upload_start = timeit.default_timer()
        self.update_key(redis_hash, {'status': 'saving-results'})

        save_name = hvals.get('original_name', fname)
        dest, output_url = self.save_output(results, image, save_name)

        # Update redis with the final results
        end = timeit.default_timer()
        self.update_key(redis_hash, {
            'status': self.final_status,
            'output_url': output_url,
            'upload_time': end - upload_start,
            'output_file_name': dest,
            'total_jobs': 1,
            'total_time': end - start,
            'finished_at': self.get_current_timestamp()
        })
        return self.final_status
class TestPolarisConsumer(object):
    """Tests for ``PolarisConsumer._consume``."""
    # pylint: disable=R0201,W0621

    def test__consume_finished_status(self, redis_client):
        """A hash that is already finished is returned with its status unchanged."""
        consumer = consumers.PolarisConsumer(redis_client, DummyStorage(), 'q')

        base_fields = {'input_file_name': 'file.tiff'}
        finished = (consumer.failed_status, consumer.final_status)

        # test finished statuses are returned
        for test_hash, status in zip((1, 3), finished):
            fields = dict(base_fields, status=status)
            redis_client.hmset(test_hash, fields)

            assert consumer._consume(test_hash) == status
            assert redis_client.hget(test_hash, 'status') == status

    def test__consume(self, mocker, redis_client):
        """A fresh hash is processed end to end and marked with the final status."""
        # pylint: disable=W0613
        consumer = consumers.PolarisConsumer(
            redis_client, DummyStorage(), 'multiplex')

        output_shape = (1, 256, 256, 2)

        def fake_predict(*_args, **_kwargs):
            return np.random.randint(1, 5, size=output_shape)

        fake_model = Bunch(
            get_batch_size=lambda *x: 1,
            input_shape=(1, 32, 32, 1),
        )
        mock_app = Bunch(predict=fake_predict, model_mpp=1, model=fake_model)

        # Replace every collaborator that would otherwise reach a model server.
        replacements = {
            'get_grpc_app': lambda *x, **_: mock_app,
            'get_image_scale': lambda *x, **_: 1,
            'validate_model_input': lambda *x, **_: x[0],
            'detect_dimension_order': lambda *x, **_: 'YXC',
        }
        for attribute, replacement in replacements.items():
            mocker.patch.object(consumer, attribute, replacement)

        test_hash = 'some hash'
        redis_client.hmset(test_hash, {'input_file_name': 'file.tiff'})

        assert consumer._consume(test_hash) == consumer.final_status
        assert redis_client.hget(test_hash, 'status') == consumer.final_status