Skip to content

Commit

Permalink
Adding Polaris consumer (#181)
Browse files Browse the repository at this point in the history
* Initial commit of Polaris consumer

* Override save_output

* Add Polaris settings to settings.py

* Add git installation to Dockerfile

* Add tests

* Bug fixes in save_output

* Update consumer output

* Fix bugs in image dimensions and app dependency

* Update spot detection model

* Pin new deepcell_spots commit

* Move threshold and clip variables to settings.py

* Add error for multi-batch, multi-channel images

Co-authored-by: Morgan Schwartz <msschwartz21@gmail.com>
  • Loading branch information
elaubsch and msschwartz21 committed Jan 27, 2022
1 parent c135f80 commit cfb3cc2
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ RUN apt-get update && apt-get install -y \
build-essential libglib2.0-0 && \
rm -rf /var/lib/apt/lists/*

# git is required so pip can install requirements that are pinned to a git
# commit. Remove the apt package lists afterwards so they do not bloat the layer.
RUN apt-get update && apt-get install -y git && \
    rm -rf /var/lib/apt/lists/*

COPY requirements.txt requirements-no-deps.txt ./

RUN pip install --no-cache-dir -r requirements.txt && \
Expand Down
2 changes: 2 additions & 0 deletions redis_consumer/consumers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from redis_consumer.consumers.segmentation_consumer import SegmentationConsumer
from redis_consumer.consumers.caliban_consumer import CalibanConsumer
from redis_consumer.consumers.mesmer_consumer import MesmerConsumer
from redis_consumer.consumers.polaris_consumer import PolarisConsumer
# TODO: Import future custom Consumer classes.


Expand All @@ -47,6 +48,7 @@
'multiplex': MesmerConsumer, # deprecated, use "mesmer" instead.
'mesmer': MesmerConsumer,
'caliban': CalibanConsumer,
'polaris': PolarisConsumer,
# TODO: Add future custom Consumer classes here.
}

Expand Down
190 changes: 190 additions & 0 deletions redis_consumer/consumers/polaris_consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright 2016-2022 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PolarisConsumer class for consuming SpotDetection jobs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile
import timeit

import matplotlib.pyplot as plt
import numpy as np

from deepcell_spots.applications import SpotDetection

from redis_consumer.consumers import TensorFlowServingConsumer
from redis_consumer import settings
from redis_consumer import utils


class PolarisConsumer(TensorFlowServingConsumer):
    """Consumes image files, runs spot detection and uploads the results."""

    def save_output(self, coords, image, save_name):
        """Save the spot detection output in a zip file and upload it.

        For every entry ``i`` of ``coords`` two files are written: ``image[i]``
        with the predicted spot locations plotted on top of it (``.tif``) and
        the raw coordinates (``.npy``).

        Args:
            coords: Sequence of 2D arrays, one per image. Column 1 is plotted
                on the x axis and column 0 on the y axis.
            image: Indexable collection of images; ``image[i]`` is drawn
                underneath ``coords[i]``.
            save_name (str): Name of the uploaded input file. Its directory
                part and base name (without extension) are reused for the
                output files.

        Returns:
            tuple: ``(dest, output_url)`` as returned by
            ``self.storage.upload``.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            subdir = os.path.dirname(save_name.replace(tempdir, ''))
            # A leading separator would make os.path.join() discard tempdir
            # and write outside of the temporary directory.
            subdir = subdir.lstrip(os.path.sep)
            name = os.path.splitext(os.path.basename(save_name))[0]

            # The sub-directory does not exist yet inside the fresh tempdir.
            output_dir = os.path.join(tempdir, subdir)
            os.makedirs(output_dir, exist_ok=True)

            # Figures are only written to disk, never shown.
            plt.ioff()

            outpaths = []
            for i, spots in enumerate(coords):
                prefix = '{}_{}'.format(name, i) if name else '{}'.format(i)

                # Save image with plotted spot locations
                img_path = os.path.join(output_dir, '{}.tif'.format(prefix))
                fig = plt.figure()
                try:
                    plt.imshow(image[i], cmap='gray')
                    plt.scatter(spots[:, 1], spots[:, 0],
                                edgecolors='r', facecolors='None')
                    plt.xticks([])
                    plt.yticks([])
                    plt.savefig(img_path)
                finally:
                    # pyplot keeps every figure alive until it is closed;
                    # a long-running consumer would otherwise leak memory.
                    plt.close(fig)

                # Save coordinates
                coords_path = os.path.join(output_dir, '{}.npy'.format(prefix))
                np.save(coords_path, spots)

                outpaths.extend([img_path, coords_path])

            # Bundle all output files into a single zip file
            zip_file = utils.zip_files(outpaths, tempdir)

            # Upload the zip file to cloud storage bucket
            cleaned = zip_file.replace(tempdir, '')
            upload_subdir = os.path.dirname(utils.strip_bucket_path(cleaned))
            upload_subdir = upload_subdir if upload_subdir else None
            dest, output_url = self.storage.upload(zip_file, subdir=upload_subdir)

        return dest, output_url

    def _consume(self, redis_hash):
        """Process a single spot detection job stored under ``redis_hash``.

        Downloads the input image, adds batch/channel dimensions, runs the
        ``SpotDetection`` application, uploads the results and records the
        progress and timings in the Redis hash.

        Args:
            redis_hash: Key of the Redis hash describing the job.

        Returns:
            The status of the job: the already stored status if the job was
            finished before, otherwise ``self.final_status``.

        Raises:
            ValueError: If the squeezed input image is neither 2D nor 3D.
        """
        start = timeit.default_timer()
        hvals = self.redis.hgetall(redis_hash)

        if hvals.get('status') in self.finished_statuses:
            self.logger.warning('Found completed hash `%s` with status %s.',
                                redis_hash, hvals.get('status'))
            return hvals.get('status')

        self.logger.debug('Found hash to process `%s` with status `%s`.',
                          redis_hash, hvals.get('status'))

        self.update_key(redis_hash, {
            'status': 'started',
            'identity_started': self.name,
        })

        # Get model_name and version
        model_name, model_version = settings.POLARIS_MODEL.split(':')

        download_start = timeit.default_timer()

        # Load input image
        fname = hvals.get('input_file_name')
        image = self.download_image(fname)

        # squeeze extra dimension that is added by get_image
        image = np.squeeze(image)
        if image.ndim == 2:
            # add in the batch and channel dims
            image = np.expand_dims(image, axis=[0, -1])
        elif image.ndim == 3:
            # check if batch first or last
            if np.shape(image)[2] < np.shape(image)[1]:
                image = np.rollaxis(image, 2, 0)
            # add in the channel dim
            image = np.expand_dims(image, axis=[-1])
        else:
            raise ValueError('Image with {} shape was uploaded, but Polaris only '
                             'supports multi-batch or multi-channel images.'.format(
                                 np.shape(image)))

        # Pre-process data before sending to the model
        self.update_key(redis_hash, {
            'status': 'pre-processing',
            'download_time': timeit.default_timer() - download_start,
        })

        # detect dimension order and add to redis
        dim_order = self.detect_dimension_order(image, model_name, model_version)
        self.update_key(redis_hash, {
            'dim_order': ','.join(dim_order)
        })

        # Validate input image
        if hvals.get('channels'):
            channels = [int(c) for c in hvals.get('channels').split(',')]
        else:
            channels = None

        image = self.validate_model_input(image, model_name, model_version,
                                          channels=channels)

        # Send data to the model
        self.update_key(redis_hash, {'status': 'predicting'})

        app = self.get_grpc_app(settings.POLARIS_MODEL, SpotDetection)

        # with new batching update in deepcell.applications,
        # app.predict() cannot handle a batch_size of None.
        batch_size = app.model.get_batch_size()

        # Values read from the Redis hash are strings, so they have to be
        # converted before they are handed to the model: a string threshold is
        # not a number and the string 'False' would be truthy.
        raw_threshold = hvals.get('threshold')
        if raw_threshold in (None, ''):
            threshold = settings.POLARIS_THRESHOLD
        else:
            threshold = float(raw_threshold)

        clip = hvals.get('clip', settings.POLARIS_CLIP)
        if isinstance(clip, str):
            clip = clip.strip().lower() in ('1', 'true', 'yes', 'on', 'y', 't')

        results = app.predict(image, batch_size=batch_size, threshold=threshold,
                              clip=clip)

        # Save the post-processed results to a file
        upload_start = timeit.default_timer()
        self.update_key(redis_hash, {'status': 'saving-results'})

        save_name = hvals.get('original_name', fname)
        dest, output_url = self.save_output(results, image, save_name)

        # Update redis with the final results
        end = timeit.default_timer()
        self.update_key(redis_hash, {
            'status': self.final_status,
            'output_url': output_url,
            'upload_time': end - upload_start,
            'output_file_name': dest,
            'total_jobs': 1,
            'total_time': end - start,
            'finished_at': self.get_current_timestamp()
        })
        return self.final_status
98 changes: 98 additions & 0 deletions redis_consumer/consumers/polaris_consumer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2016-2022 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for PolarisConsumer"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import pytest

from redis_consumer import consumers
from redis_consumer import settings
from redis_consumer.testing_utils import _get_image
from redis_consumer.testing_utils import Bunch
from redis_consumer.testing_utils import DummyStorage
from redis_consumer.testing_utils import redis_client


class TestPolarisConsumer(object):
    """Unit tests for ``PolarisConsumer._consume``."""
    # pylint: disable=R0201,W0621

    def test__consume_finished_status(self, redis_client):
        """A hash that already has a finished status is returned unchanged."""
        consumer = consumers.PolarisConsumer(redis_client, DummyStorage(), 'q')

        base_fields = {'input_file_name': 'file.tiff'}

        # test finished statuses are returned
        key = 1
        for status in (consumer.failed_status, consumer.final_status):
            redis_client.hmset(key, dict(base_fields, status=status))

            assert consumer._consume(key) == status
            assert redis_client.hget(key, 'status') == status

            # every status gets its own hash; keys advance in steps of two
            key += 2

    def test__consume(self, mocker, redis_client):
        """A new job runs through the whole pipeline and ends as final."""
        # pylint: disable=W0613
        consumer = consumers.PolarisConsumer(
            redis_client, DummyStorage(), 'multiplex')

        def fake_predict(*_args, **_kwargs):
            return np.random.randint(1, 5, size=(1, 256, 256, 2))

        def fake_batch_size(*_args):
            return 1

        fake_model = Bunch(
            get_batch_size=fake_batch_size,
            input_shape=(1, 32, 32, 1),
        )
        mock_app = Bunch(predict=fake_predict, model_mpp=1, model=fake_model)

        # Replace everything that would need a model server with simple stubs.
        stubs = {
            'get_grpc_app': lambda *x, **_: mock_app,
            'get_image_scale': lambda *x, **_: 1,
            'validate_model_input': lambda *x, **_: x[0],
            'detect_dimension_order': lambda *x, **_: 'YXC',
        }
        for attribute, stub in stubs.items():
            mocker.patch.object(consumer, attribute, stub)

        job_key = 'some hash'
        redis_client.hmset(job_key, {'input_file_name': 'file.tiff'})

        assert consumer._consume(job_key) == consumer.final_status
        assert redis_client.hget(job_key, 'status') == consumer.final_status
5 changes: 5 additions & 0 deletions redis_consumer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@
MESMER_MODEL = config('MESMER_MODEL', default=MULTIPLEX_MODEL, cast=str)
MESMER_COMPARTMENT = config('MESMER_COMPARTMENT', default='whole-cell')

# Polaris model Settings
# Model used for spot detection, written as "<model_name>:<model_version>"
# (the Polaris consumer splits this value on ":").
POLARIS_MODEL = config('POLARIS_MODEL', default='SpotDetection:3', cast=str)
# Default ``threshold`` passed to the spot detection app's predict() when the
# job itself does not provide one.
POLARIS_THRESHOLD = config('POLARIS_THRESHOLD', default=0.95, cast=float)
# Default ``clip`` flag passed to predict() when the job does not provide one.
POLARIS_CLIP = config('POLARIS_CLIP', default=False, cast=bool)

# Set default models based on label type
MODEL_CHOICES = {
0: config('NUCLEAR_MODEL', default='NuclearSegmentation:0', cast=str),
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ deepcell-tracking~=0.5.2
tensorflow-cpu~=2.5.2
tifffile>=2020.9.3
numpy>=1.16.6
matplotlib>=2.1.1

# Pinned to a specific commit. Fetched over HTTPS because GitHub no longer
# serves the unauthenticated git:// protocol.
git+https://github.com/vanvalenlab/deepcell-spots@f7749cf77d67a4bfd3a56a66b6488cb0feaffecf

# tensorflow-serving-apis and gRPC dependencies
grpcio>=1.0,<2
Expand Down

0 comments on commit cfb3cc2

Please sign in to comment.