Skip to content

Commit

Permalink
Merge efb92c1 into 75d8f07
Browse files Browse the repository at this point in the history
  • Loading branch information
cdpavelchek committed Jul 10, 2020
2 parents 75d8f07 + efb92c1 commit 59c394b
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 29 deletions.
2 changes: 2 additions & 0 deletions redis_consumer/consumers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@
# Custom Workflow consumers
from redis_consumer.consumers.image_consumer import ImageFileConsumer
from redis_consumer.consumers.tracking_consumer import TrackingConsumer
from redis_consumer.consumers.mibi_consumer import MibiConsumer
# TODO: Import future custom Consumer classes.


CONSUMERS = {
'image': ImageFileConsumer,
'zip': ZipFileConsumer,
'tracking': TrackingConsumer,
'mibi': MibiConsumer,
# TODO: Add future custom Consumer classes here.
}

Expand Down
29 changes: 29 additions & 0 deletions redis_consumer/consumers/base_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,35 @@ def predict(self, image, model_name, model_version, untile=True):

return image

def save_output(self, image, redis_hash, save_name, scale):
with utils.get_tempdir() as tempdir:
# Save each result channel as an image file
subdir = os.path.dirname(save_name.replace(tempdir, ''))
name = os.path.splitext(os.path.basename(save_name))[0]

# Rescale image to original size before sending back to user
if isinstance(image, list):
outpaths = []
for i in image:
outpaths.extend(utils.save_numpy_array(
utils.rescale(i, 1 / scale), name=name,
subdir=subdir, output_dir=tempdir))
else:
outpaths = utils.save_numpy_array(
utils.rescale(image, 1 / scale), name=name,
subdir=subdir, output_dir=tempdir)

# Save each prediction image as zip file
zip_file = utils.zip_files(outpaths, tempdir)

# Upload the zip file to cloud storage bucket
cleaned = zip_file.replace(tempdir, '')
subdir = os.path.dirname(settings._strip(cleaned))
subdir = subdir if subdir else None
dest, output_url = self.storage.upload(zip_file, subdir=subdir)

return dest, output_url


class ZipFileConsumer(Consumer):
"""Consumes zip files and uploads the results"""
Expand Down
28 changes: 2 additions & 26 deletions redis_consumer/consumers/image_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,32 +271,8 @@ def _consume(self, redis_hash):
_ = timeit.default_timer()
self.update_key(redis_hash, {'status': 'saving-results'})

with utils.get_tempdir() as tempdir:
# Save each result channel as an image file
save_name = hvals.get('original_name', fname)
subdir = os.path.dirname(save_name.replace(tempdir, ''))
name = os.path.splitext(os.path.basename(save_name))[0]

# Rescale image to original size before sending back to user
if isinstance(image, list):
outpaths = []
for i in image:
outpaths.extend(utils.save_numpy_array(
utils.rescale(i, 1 / scale), name=name,
subdir=subdir, output_dir=tempdir))
else:
outpaths = utils.save_numpy_array(
utils.rescale(image, 1 / scale), name=name,
subdir=subdir, output_dir=tempdir)

# Save each prediction image as zip file
zip_file = utils.zip_files(outpaths, tempdir)

# Upload the zip file to cloud storage bucket
cleaned = zip_file.replace(tempdir, '')
subdir = os.path.dirname(settings._strip(cleaned))
subdir = subdir if subdir else None
dest, output_url = self.storage.upload(zip_file, subdir=subdir)
save_name = hvals.get('original_name', fname)
dest, output_url = self.save_output(image, redis_hash, save_name, scale)

# Update redis with the final results
t = timeit.default_timer() - start
Expand Down
126 changes: 126 additions & 0 deletions redis_consumer/consumers/mibi_consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ImageFileConsumer class for consuming image segmentation jobs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import timeit

import numpy as np

from redis_consumer.consumers import ImageFileConsumer
from redis_consumer import utils
from redis_consumer import settings


class MibiConsumer(ImageFileConsumer):
"""Consumes image files and uploads the results"""

def _consume(self, redis_hash):
start = timeit.default_timer()
self._redis_hash = redis_hash # workaround for logging.
hvals = self.redis.hgetall(redis_hash)

if hvals.get('status') in self.finished_statuses:
self.logger.warning('Found completed hash `%s` with status %s.',
redis_hash, hvals.get('status'))
return hvals.get('status')

self.logger.debug('Found hash to process `%s` with status `%s`.',
redis_hash, hvals.get('status'))

self.update_key(redis_hash, {
'status': 'started',
'identity_started': self.name,
})

# Get model_name and version
model_name, model_version = settings.MIBI_MODEL.split(':')

_ = timeit.default_timer()

# Load input image
with utils.get_tempdir() as tempdir:
fname = self.storage.download(hvals.get('input_file_name'), tempdir)
# TODO: tiffs expand the last axis, is that a problem here?
image = utils.get_image(fname)

# Pre-process data before sending to the model
self.update_key(redis_hash, {
'status': 'pre-processing',
'download_time': timeit.default_timer() - _,
})

# Calculate scale of image and rescale
scale = hvals.get('scale', '')
if not scale:
# Detect scale of image (Default to 1)
# TODO: implement SCALE_DETECT here for mibi model
# scale = self.detect_scale(image)
# self.logger.debug('Image scale detected: %s', scale)
# self.update_key(redis_hash, {'scale': scale})
self.logger.debug('Scale was not given. Defaults to 1')
scale = 1
else:
scale = float(scale)
self.logger.debug('Image scale already calculated: %s', scale)

# Rescale each channel of the image
image = utils.rescale(image, scale)
image = np.expand_dims(image, axis=0) # add in the batch dim

# Preprocess image
image = self.preprocess(image, ['histogram_normalization'])

# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})
image = self.predict(image, model_name, model_version)

# Post-process model results
self.update_key(redis_hash, {'status': 'post-processing'})
image = self.postprocess(image, ['mibi'])

# Save the post-processed results to a file
_ = timeit.default_timer()
self.update_key(redis_hash, {'status': 'saving-results'})

save_name = hvals.get('original_name', fname)
dest, output_url = self.save_output(image, redis_hash, save_name, scale)

# Update redis with the final results
t = timeit.default_timer() - start
self.update_key(redis_hash, {
'status': self.final_status,
'output_url': output_url,
'upload_time': timeit.default_timer() - _,
'output_file_name': dest,
'total_jobs': 1,
'total_time': t,
'finished_at': self.get_current_timestamp()
})
return self.final_status
127 changes: 127 additions & 0 deletions redis_consumer/consumers/mibi_consumer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for MibiConsumer"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

import pytest

from redis_consumer import consumers
from redis_consumer.testing_utils import redis_client, DummyStorage


class TestMibiConsumer(object):
# pylint: disable=R0201

def test_is_valid_hash(self, mocker, redis_client):
storage = DummyStorage()
mocker.patch.object(redis_client, 'hget', lambda *x: x[0])

consumer = consumers.MibiConsumer(redis_client, storage, 'mibi')
assert consumer.is_valid_hash(None) is False
assert consumer.is_valid_hash('file.ZIp') is False
assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False
assert consumer.is_valid_hash('track:123456789:file.zip') is False
assert consumer.is_valid_hash('predict:123456789:file.zip') is False
assert consumer.is_valid_hash('mibi:1234567890:file.tiff') is True
assert consumer.is_valid_hash('mibi:1234567890:file.png') is True

def test__consume(self, mocker, redis_client):
# pylint: disable=W0613

def make_model_metadata_of_size(model_shape=(-1, 256, 256, 2)):

def get_model_metadata(model_name, model_version):
return [{
'in_tensor_name': 'image',
'in_tensor_dtype': 'DT_FLOAT',
'in_tensor_shape': ','.join(str(s) for s in model_shape),
}]
return get_model_metadata

def make_grpc_image(model_shape=(-1, 256, 256, 2)):
# pylint: disable=E1101
shape = model_shape[1:-1]

def grpc(data, *args, **kwargs):
inner = np.random.random((1,) + shape + (1,))
outer = np.random.random((1,) + shape + (1,))
fgbg = np.random.random((1,) + shape + (2,))
feature = np.random.random((1,) + shape + (3,))
return [inner, outer, fgbg, feature]
return grpc

image_shape = (300, 300, 2)
model_shapes = [
(-1, 600, 600, 2), # image too small, pad
(-1, 300, 300, 2), # image is exactly the right size
(-1, 150, 150, 2), # image too big, tile
]

scales = ['.9', '']

job_data = {
'input_file_name': 'file.tiff',
}

consumer = consumers.MibiConsumer(redis_client, DummyStorage(), 'mibi')

test_hash = 0
# test finished statuses are returned
for status in (consumer.failed_status, consumer.final_status):
test_hash += 1
data = job_data.copy()
data['status'] = status
redis_client.hmset(test_hash, data)
result = consumer._consume(test_hash)
assert result == status
result = redis_client.hget(test_hash, 'status')
assert result == status
test_hash += 1

mocker.patch('redis_consumer.utils.get_image',
lambda x: np.random.random(image_shape))

for model_shape, scale in itertools.product(model_shapes, scales):
metadata = make_model_metadata_of_size(model_shape)
grpc_image = make_grpc_image(model_shape)
mocker.patch.object(consumer, 'get_model_metadata', metadata)
mocker.patch.object(consumer, 'grpc_image', grpc_image)

data = job_data.copy()
data['scale'] = scale

redis_client.hmset(test_hash, data)
result = consumer._consume(test_hash)
assert result == consumer.final_status
result = redis_client.hget(test_hash, 'status')
assert result == consumer.final_status
test_hash += 1
4 changes: 4 additions & 0 deletions redis_consumer/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@

from deepcell_toolbox.deep_watershed import deep_watershed

# import mibi pre- and post-processing functions
from deepcell_toolbox.deep_watershed import deep_watershed_mibi
from deepcell_toolbox.processing import phase_preprocess

from deepcell_toolbox import retinanet_semantic_to_label_image
from deepcell_toolbox import retinanet_to_label_image

Expand Down
10 changes: 7 additions & 3 deletions redis_consumer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,17 @@ def _strip(x):
# Pre- and Post-processing settings
PROCESSING_FUNCTIONS = {
'pre': {
'normalize': processing.normalize
'normalize': processing.normalize,
'histogram_normalization': processing.phase_preprocess,
},
'post': {
'deepcell': processing.pixelwise, # TODO: this is deprecated.
'pixelwise': processing.pixelwise,
'mibi': processing.mibi,
'watershed': processing.watershed,
'retinanet': processing.retinanet_to_label_image,
'retinanet-semantic': processing.retinanet_semantic_to_label_image,
'deep_watershed': processing.deep_watershed,
'mibi': processing.deep_watershed_mibi,
},
}

Expand Down Expand Up @@ -157,11 +158,14 @@ def _strip(x):
LABEL_DETECT_MODEL = config('LABEL_DETECT_MODEL', default='LabelDetection:1', cast=str)
LABEL_DETECT_ENABLED = config('LABEL_DETECT_ENABLED', default=False, cast=bool)

# MIBI model Settings
MIBI_MODEL = config('MIBI_MODEL', default='NewMIBI:0', cast=str)

# Set default models based on label type
MODEL_CHOICES = {
0: config('NUCLEAR_MODEL', default='NuclearSegmentation:0', cast=str),
1: config('PHASE_MODEL', default='PhaseCytoSegmentation:0', cast=str),
2: config('CYTOPLASM_MODEL', default='FluoCytoSegmentation:0', cast=str)
2: config('CYTOPLASM_MODEL', default='FluoCytoSegmentation:0', cast=str),
}

POSTPROCESS_CHOICES = {
Expand Down

0 comments on commit 59c394b

Please sign in to comment.