From 02c1cd35b2af06a9932ba6f236a1f372c09c6b3c Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Mon, 13 Jul 2020 13:46:04 -0700
Subject: [PATCH 01/12] update code example with pre- and post-processing
examples.
---
README.md | 89 +++++++++++++++++++++++++++++++------------------------
1 file changed, 50 insertions(+), 39 deletions(-)
diff --git a/README.md b/README.md
index 85c99ec9..304aab49 100644
--- a/README.md
+++ b/README.md
@@ -24,45 +24,56 @@ Each Redis event should have the following fields:
If the consumer will send data to a TensorFlow Serving model, it should inherit from `redis_consumer.consumers.TensorFlowServingConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which has methods `_get_predict_client()` and `grpc_image()` which can send data to the specific model. The new consumer must also implement the `_consume()` method which performs the bulk of the work. The `_consume()` method will fetch data from Redis, download data file from the bucket, process the data with a model, and upload the results to the bucket again. See below for a basic implementation of `_consume()`:
```python
- def _consume(self, redis_hash):
- # get all redis data for the given hash
- hvals = self.redis.hgetall(redis_hash)
-
- with utils.get_tempdir() as tempdir:
- # download the image file
- fname = self.storage.download(hvals.get('input_file_name'), tempdir)
-
- # load image file as data
- image = utils.get_image(fname)
-
- # preprocess data if necessary
-
- # send the data to the model
- results = self.grpc_image(image,
- hvals.get('model_name'),
- hvals.get('model_version'))
-
- # postprocess results if necessary
-
- # save the results as an image
- outpaths = utils.save_numpy_array(results, name=name,
- subdir=subdir, output_dir=tempdir)
-
- # zip up the file
- zip_file = utils.zip_files(outpaths, tempdir)
-
- # upload the zip file to the cloud bucket
- dest, output_url = self.storage.upload(zip_file)
-
- # save the results to the redis hash
- self.update_key(redis_hash, {
- 'status': self.final_status,
- 'output_url': output_url,
- 'output_file_name': dest
- })
-
- # return the final status
- return self.final_status
+def _consume(self, redis_hash):
+ # get all redis data for the given hash
+ hvals = self.redis.hgetall(redis_hash)
+
+ # only work on unfinished jobs
+ if hvals.get('status') in self.finished_statuses:
+ self.logger.warning('Found completed hash `%s` with status %s.',
+ redis_hash, hvals.get('status'))
+ return hvals.get('status')
+
+ # the data to process with the model, required.
+ input_file_name = hvals.get('input_file_name')
+
+ # TODO: model information can be saved in redis or defined in the consumer.
+ model_name = hvals.get('model_name')
+ model_version = hvals.get('model_version')
+
+ with utils.get_tempdir() as tempdir:
+ # download the image file
+ fname = self.storage.download(input_file_name, tempdir)
+ # load image file as data
+ image = utils.get_image(fname)
+
+ # TODO: pre- and post-processing can be used with the BaseConsumer.process,
+ # which uses pre-defined functions in settings.PROCESSING_FUNCTIONS.
+ image = self.preprocess(image, 'normalize')
+
+ # send the data to the model
+ results = self.predict(image, model_name, model_version)
+
+ # TODO: post-process the model results into a label image.
+ image = self.postprocess(image, 'deep_watershed')
+
+ # save the results as an image file and upload it to the bucket
+ save_name = hvals.get('original_name', fname)
+ dest, output_url = self.save_output(image, redis_hash, save_name, scale)
+
+ # save the results to the redis hash
+ self.update_key(redis_hash, {
+ 'status': self.final_status,
+ 'output_url': output_url,
+ 'upload_time': timeit.default_timer() - _,
+ 'output_file_name': dest,
+ 'total_jobs': 1,
+ 'total_time': timeit.default_timer() - start,
+ 'finished_at': self.get_current_timestamp()
+ })
+
+ # return the final status
+ return self.final_status
```
Finally, the new consumer needs to be imported into the redis_consumer/consumers/\_\_init\_\_.py and added to the `CONSUMERS` dictionary with a correponding queue type (`queue_name`). The script consume-redis-events.py will load the consumer class based on the `CONSUMER_TYPE`.
From 83879df9482674ff872884db5c146341f8b4793c Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Mon, 13 Jul 2020 14:42:47 -0700
Subject: [PATCH 02/12] clarify header paragraph
---
README.md | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/README.md b/README.md
index 304aab49..6327e19e 100644
--- a/README.md
+++ b/README.md
@@ -14,14 +14,11 @@ This repository is part of the [DeepCell Kiosk](https://github.com/vanvalenlab/k
Custom consumers can be used to implement custom model pipelines. This documentation is a continuation of a [tutorial](https://deepcell-kiosk.readthedocs.io/en/master/CUSTOM-JOB.html) on building a custom job pipeline.
Consumers consume Redis events. Each type of Redis event is put into a separate queue (e.g. `predict`, `track`), and each consumer type will pop items to consume off that queue.
+Consumers call the `_consume` method to consume each item it finds in the queue.
+This method must be implemented for every consumer.
-Each Redis event should have the following fields:
-
-- `model_name` - The name of the model that will be retrieved by TensorFlow Serving from `gs:///models`
-- `model_version` - The version number of the model in TensorFlow Serving
-- `input_file_name` - The path to the data file in a cloud bucket.
-
-If the consumer will send data to a TensorFlow Serving model, it should inherit from `redis_consumer.consumers.TensorFlowServingConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which has methods `_get_predict_client()` and `grpc_image()` which can send data to the specific model. The new consumer must also implement the `_consume()` method which performs the bulk of the work. The `_consume()` method will fetch data from Redis, download data file from the bucket, process the data with a model, and upload the results to the bucket again. See below for a basic implementation of `_consume()`:
+The quickest way to get a custom consumer up and running is to inherit from `redis_consumer.consumers.ImageFileConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which uses the `preprocess`, `predict`, and `postprocess` methods to easily process data with the model.
+See below for a basic implementation of `_consume()` making use of the methods inherited from `ImageFileConsumer`:
```python
def _consume(self, redis_hash):
From efaa5f48a64603391a5628d3fc2c583b8a247b89 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Mon, 13 Jul 2020 14:48:48 -0700
Subject: [PATCH 03/12] scale should not be required.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 6327e19e..6fb90139 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@ def _consume(self, redis_hash):
# save the results as an image file and upload it to the bucket
save_name = hvals.get('original_name', fname)
- dest, output_url = self.save_output(image, redis_hash, save_name, scale)
+ dest, output_url = self.save_output(image, redis_hash, save_name)
# save the results to the redis hash
self.update_key(redis_hash, {
From e09137271ca98daacb0f24d76a2ef0012078c554 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Mon, 13 Jul 2020 14:49:10 -0700
Subject: [PATCH 04/12] model_name and model_version should be passed in
through environment, not through redis.
---
README.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 6fb90139..9a64b78c 100644
--- a/README.md
+++ b/README.md
@@ -34,9 +34,9 @@ def _consume(self, redis_hash):
# the data to process with the model, required.
input_file_name = hvals.get('input_file_name')
- # TODO: model information can be saved in redis or defined in the consumer.
- model_name = hvals.get('model_name')
- model_version = hvals.get('model_version')
+ # TODO: the model can be passed in as an environment variable,
+ # and parsed in settings.py.
+ model_name, model_version = 'CustomModel:1'
with utils.get_tempdir() as tempdir:
# download the image file
From 4bbdac1c7bd1b6b2f8e244b143057a430e130379 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Tue, 14 Jul 2020 15:25:02 -0700
Subject: [PATCH 05/12] Improve Consumer docstrings.
---
redis_consumer/consumers/base_consumer.py | 45 ++++++++++++++++--
redis_consumer/consumers/image_consumer.py | 53 ++++++++++++++++++----
2 files changed, 86 insertions(+), 12 deletions(-)
diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py
index ffac8dae..220ae797 100644
--- a/redis_consumer/consumers/base_consumer.py
+++ b/redis_consumer/consumers/base_consumer.py
@@ -95,6 +95,13 @@ def _put_back_hash(self, redis_hash):
pass # success
def get_redis_hash(self):
+ """Pop off an item from the Job queue.
+
+ If a Job hash is invalid it will be failed and removed from the queue.
+
+ Returns:
+ str: A valid Redis Job hash, or None if one cannot be found.
+ """
while True:
if self.redis.llen(self.queue) == 0:
@@ -162,9 +169,9 @@ def update_key(self, redis_hash, data=None):
"""Update the hash with `data` and updated_by & updated_at stamps.
Args:
- redis_hash: string, the hash that will be updated
- status: string, the new status value
- data: dict, optional data to include in the hmset call
+ redis_hash (str): The hash that will be updated
+ status (str): The new status value
+ data (dict): Optional data to include in the hmset call
"""
if data is not None and not isinstance(data, dict):
raise ValueError('`data` must be a dictionary, got {}.'.format(
@@ -178,6 +185,7 @@ def update_key(self, redis_hash, data=None):
self.redis.hmset(redis_hash, data)
def _consume(self, redis_hash):
+ """Consume the Redis Job. All Consumers must implement this function"""
raise NotImplementedError
def consume(self):
@@ -249,6 +257,15 @@ def _consume(self, redis_hash):
raise NotImplementedError
def _get_predict_client(self, model_name, model_version):
+ """Returns the TensorFlow Serving gRPC client.
+
+ Args:
+ model_name (str): The name of the model
+ model_version (int): The version of the model
+
+ Returns:
+ redis_consumer.grpc_clients.PredictClient: the gRPC client.
+ """
t = timeit.default_timer()
hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT)
client = PredictClient(hostname, model_name, int(model_version))
@@ -258,7 +275,19 @@ def _get_predict_client(self, model_name, model_version):
def grpc_image(self, img, model_name, model_version, model_shape,
in_tensor_name='image', in_tensor_dtype='DT_FLOAT'):
+ """Use the TensorFlow Serving gRPC API for model inference on an image.
+ Args:
+ img (numpy.array): The image to send to the model
+ model_name (str): The name of the model
+ model_version (int): The version of the model
+ model_shape (tuple): The shape of input data for the model
+ in_tensor_name (str): The name of the input tensor for the request
+ in_tensor_dtype (str): The dtype of the input data
+
+ Returns:
+ numpy.array: The results of model inference.
+ """
in_tensor_dtype = str(in_tensor_dtype).upper()
start = timeit.default_timer()
@@ -491,6 +520,16 @@ def _predict_small_image(self,
return image
def predict(self, image, model_name, model_version, untile=True):
+ """Performs model inference on the image data.
+
+ Args:
+ image (numpy.array): the image data
+ model_name (str): hosted model to send image data.
+ model_version (int): model version to query.
+ untile (bool): Whether to untile the tiled inference results. This
+ should be True when the model output is the same shape as the
+ input, and False otherwise.
+ """
start = timeit.default_timer()
model_metadata = self.get_model_metadata(model_name, model_version)
diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py
index 70be1821..5eed415b 100644
--- a/redis_consumer/consumers/image_consumer.py
+++ b/redis_consumer/consumers/image_consumer.py
@@ -49,7 +49,16 @@ def is_valid_hash(self, redis_hash):
return not fname.lower().endswith('.zip')
def _get_processing_function(self, process_type, function_name):
- """Based on the function category and name, return the function"""
+ """Based on the function category and name, return the function.
+
+ Args:
+ process_type (str): "pre" or "post" processing
+ function_name (str): Name processing function, must exist in
+ settings.PROCESSING_FUNCTIONS.
+
+ Returns:
+ function: the selected pre- or post-processing function.
+ """
clean = lambda x: str(x).lower()
# first, verify the route parameters
name = clean(function_name)
@@ -64,6 +73,16 @@ def _get_processing_function(self, process_type, function_name):
return settings.PROCESSING_FUNCTIONS[cat][name]
def process(self, image, key, process_type):
+ """Apply the pre- or post-processing function to the image data.
+
+ Args:
+ image (numpy.array): The image data to process.
+ key (str): The name of the function to use.
+ process_type (str): "pre" or "post" processing.
+
+ Returns:
+ numpy.array: The processed image data.
+ """
start = timeit.default_timer()
if not key:
return image
@@ -100,6 +119,15 @@ def process(self, image, key, process_type):
return results
def detect_scale(self, image):
+ """Send the image to the SCALE_DETECT_MODEL to detect the relative
+ scale difference from the image to the model's training data.
+
+ Args:
+ image (numpy.array): The image data.
+
+ Returns:
+ scale (float): The detected scale, used to rescale data.
+ """
start = timeit.default_timer()
if not settings.SCALE_DETECT_ENABLED:
@@ -122,6 +150,15 @@ def detect_scale(self, image):
return detected_scale
def detect_label(self, image):
+ """Send the image to the LABEL_DETECT_MODEL to detect the type of image
+ data. The model output is mapped with settings.MODEL_CHOICES.
+
+ Args:
+ image (numpy.array): The image data.
+
+ Returns:
+ label (int): The detected label.
+ """
start = timeit.default_timer()
if not settings.LABEL_DETECT_ENABLED:
@@ -147,12 +184,11 @@ def preprocess(self, image, keys):
"""Wrapper for _process_image but can only call with type="pre".
Args:
- image: numpy array of image data
- keys: list of function names to apply to the image
- streaming: boolean. if True, streams data in multiple requests
+ image (numpy.array): image data
+ keys (list): list of function names to apply to the image
Returns:
- pre-processed image data
+ numpy.array: pre-processed image data
"""
pre = None
for key in keys:
@@ -164,12 +200,11 @@ def postprocess(self, image, keys):
"""Wrapper for _process_image but can only call with type="post".
Args:
- image: numpy array of image data
- keys: list of function names to apply to the image
- streaming: boolean. if True, streams data in multiple requests
+ image (numpy.array): image data
+ keys (list): list of function names to apply to the image
Returns:
- post-processed image data
+ numpy.array: post-processed image data
"""
post = None
for key in keys:
From 4f729c65347e42af6feb3d1fd352ac51f0019b43 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 15 Jul 2020 12:34:47 -0700
Subject: [PATCH 06/12] split(':') the model_name and model_version.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 9a64b78c..1e5bd688 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ def _consume(self, redis_hash):
# TODO: the model can be passed in as an environment variable,
# and parsed in settings.py.
- model_name, model_version = 'CustomModel:1'
+ model_name, model_version = 'CustomModel:1'.split(':')
with utils.get_tempdir() as tempdir:
# download the image file
From d54668d05c09f6c9d0ce443035e2560aa5dbdf30 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 15 Jul 2020 12:34:58 -0700
Subject: [PATCH 07/12] remove the TODOs.
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 1e5bd688..644829ed 100644
--- a/README.md
+++ b/README.md
@@ -44,14 +44,14 @@ def _consume(self, redis_hash):
# load image file as data
image = utils.get_image(fname)
- # TODO: pre- and post-processing can be used with the BaseConsumer.process,
+ # pre- and post-processing can be used with the BaseConsumer.process,
# which uses pre-defined functions in settings.PROCESSING_FUNCTIONS.
image = self.preprocess(image, 'normalize')
# send the data to the model
results = self.predict(image, model_name, model_version)
- # TODO: post-process the model results into a label image.
+ # post-process model results
image = self.postprocess(image, 'deep_watershed')
# save the results as an image file and upload it to the bucket
From 57e7a4353ccb869475b35694f750b1b6597d3a34 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 15 Jul 2020 12:35:45 -0700
Subject: [PATCH 08/12] clarifications.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 644829ed..cb7db64d 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ This repository is part of the [DeepCell Kiosk](https://github.com/vanvalenlab/k
Custom consumers can be used to implement custom model pipelines. This documentation is a continuation of a [tutorial](https://deepcell-kiosk.readthedocs.io/en/master/CUSTOM-JOB.html) on building a custom job pipeline.
-Consumers consume Redis events. Each type of Redis event is put into a separate queue (e.g. `predict`, `track`), and each consumer type will pop items to consume off that queue.
+Consumers consume Redis events. Each type of Redis event is put into a queue (e.g. `predict`, `track`), and each queue has a specific consumer type that will pop items off the queue.
Consumers call the `_consume` method to consume each item it finds in the queue.
This method must be implemented for every consumer.
From 8ef698afe4507ab06688c6558aa2699cd295314b Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 15 Jul 2020 13:08:27 -0700
Subject: [PATCH 09/12] move process functions into TensorFlowServingConsumer.
---
redis_consumer/consumers/base_consumer.py | 102 ++++++++++++++++++
.../consumers/base_consumer_test.py | 54 ++++++++++
redis_consumer/consumers/image_consumer.py | 102 ------------------
.../consumers/image_consumer_test.py | 54 ----------
4 files changed, 156 insertions(+), 156 deletions(-)
diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py
index 220ae797..49dba321 100644
--- a/redis_consumer/consumers/base_consumer.py
+++ b/redis_consumer/consumers/base_consumer.py
@@ -592,6 +592,108 @@ def predict(self, image, model_name, model_version, untile=True):
return image
+ def _get_processing_function(self, process_type, function_name):
+ """Based on the function category and name, return the function.
+
+ Args:
+ process_type (str): "pre" or "post" processing
+ function_name (str): Name of the processing function, must exist in
+ settings.PROCESSING_FUNCTIONS.
+
+ Returns:
+ function: the selected pre- or post-processing function.
+ """
+ clean = lambda x: str(x).lower()
+ # first, verify the route parameters
+ name = clean(function_name)
+ cat = clean(process_type)
+ if cat not in settings.PROCESSING_FUNCTIONS:
+ raise ValueError('Processing functions are either "pre" or "post" '
+ 'processing. Got %s.' % cat)
+
+ if name not in settings.PROCESSING_FUNCTIONS[cat]:
+ raise ValueError('"%s" is not a valid %s-processing function'
+ % (name, cat))
+ return settings.PROCESSING_FUNCTIONS[cat][name]
+
+ def process(self, image, key, process_type):
+ """Apply the pre- or post-processing function to the image data.
+
+ Args:
+ image (numpy.array): The image data to process.
+ key (str): The name of the function to use.
+ process_type (str): "pre" or "post" processing.
+
+ Returns:
+ numpy.array: The processed image data.
+ """
+ start = timeit.default_timer()
+ if not key:
+ return image
+
+ f = self._get_processing_function(process_type, key)
+
+ if key == 'retinanet-semantic':
+ # image[:-1] is targeted at a two semantic head panoptic model
+ # TODO This may need to be modified and generalized in the future
+ results = f(image[:-1])
+ elif key == 'retinanet':
+ results = f(image, self._rawshape[0], self._rawshape[1])
+ else:
+ results = f(image)
+
+ if results.shape[0] == 1:
+ results = np.squeeze(results, axis=0)
+
+ finished = timeit.default_timer() - start
+
+ self.update_key(self._redis_hash, {
+ '{}process_time'.format(process_type): finished
+ })
+
+ self.logger.debug('%s-processed key %s (model %s:%s, preprocessing: %s,'
+ ' postprocessing: %s) in %s seconds.',
+ process_type.capitalize(), self._redis_hash,
+ self._redis_values.get('model_name'),
+ self._redis_values.get('model_version'),
+ self._redis_values.get('preprocess_function'),
+ self._redis_values.get('postprocess_function'),
+ finished)
+
+ return results
+
+ def preprocess(self, image, keys):
+ """Wrapper for `process` that always uses process_type="pre".
+
+ Args:
+ image (numpy.array): image data
+ keys (list): list of function names to apply to the image
+
+ Returns:
+ numpy.array: pre-processed image data
+ """
+ pre = None
+ for key in keys:
+ x = pre if pre else image
+ pre = self.process(x, key, 'pre')
+ return pre
+
+ def postprocess(self, image, keys):
+ """Wrapper for `process` that always uses process_type="post".
+
+ Args:
+ image (numpy.array): image data
+ keys (list): list of function names to apply to the image
+
+ Returns:
+ numpy.array: post-processed image data
+ """
+ post = None
+ for key in keys:
+ x = post if post else image
+ post = self.process(x, key, 'post')
+ return post
+
class ZipFileConsumer(Consumer):
"""Consumes zip files and uploads the results"""
diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py
index 3d8b2b8d..a5ba06e0 100644
--- a/redis_consumer/consumers/base_consumer_test.py
+++ b/redis_consumer/consumers/base_consumer_test.py
@@ -373,6 +373,60 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613
x = np.random.random((300, 300, 1))
consumer.predict(x, model_name='modelname', model_version=0)
+ def test__get_processing_function(self, mocker, redis_client):
+ mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', {
+ 'valid': {
+ 'valid': lambda x: True
+ }
+ })
+
+ storage = DummyStorage()
+ consumer = consumers.ImageFileConsumer(redis_client, storage, 'q')
+
+ x = consumer._get_processing_function('VaLiD', 'vAlId')
+ y = consumer._get_processing_function('vAlId', 'VaLiD')
+ assert x == y
+
+ with pytest.raises(ValueError):
+ consumer._get_processing_function('invalid', 'valid')
+
+ with pytest.raises(ValueError):
+ consumer._get_processing_function('valid', 'invalid')
+
+ def test_process(self, mocker, redis_client):
+ # TODO: better test coverage
+ storage = DummyStorage()
+ queue = 'q'
+ img = np.random.random((1, 32, 32, 1))
+
+ mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', {
+ 'valid': {
+ 'valid': lambda x: x,
+ 'retinanet': lambda *x: x[0],
+ 'retinanet-semantic': lambda x: x,
+ }
+ })
+
+ consumer = consumers.ImageFileConsumer(redis_client, storage, queue)
+
+ mocker.patch.object(consumer, '_redis_hash', 'a hash')
+
+ output = consumer.process(img, '', '')
+ np.testing.assert_equal(img, output)
+
+ # image is returned but the leading size-1 axis is squeezed out
+ output = consumer.process(img, 'valid', 'valid')
+ np.testing.assert_equal(img[0], output)
+
+ img = np.random.random((2, 32, 32, 1))
+ output = consumer.process(img, 'retinanet-semantic', 'valid')
+ np.testing.assert_equal(img[0], output)
+
+ consumer._rawshape = (21, 21)
+ img = np.random.random((1, 32, 32, 1))
+ output = consumer.process(img, 'retinanet', 'valid')
+ np.testing.assert_equal(img[0], output)
+
class TestZipFileConsumer(object):
# pylint: disable=R0201,W0613,W0621
diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py
index 5eed415b..056a27c1 100644
--- a/redis_consumer/consumers/image_consumer.py
+++ b/redis_consumer/consumers/image_consumer.py
@@ -48,76 +48,6 @@ def is_valid_hash(self, redis_hash):
fname = str(self.redis.hget(redis_hash, 'input_file_name'))
return not fname.lower().endswith('.zip')
- def _get_processing_function(self, process_type, function_name):
- """Based on the function category and name, return the function.
-
- Args:
- process_type (str): "pre" or "post" processing
- function_name (str): Name processing function, must exist in
- settings.PROCESSING_FUNCTIONS.
-
- Returns:
- function: the selected pre- or post-processing function.
- """
- clean = lambda x: str(x).lower()
- # first, verify the route parameters
- name = clean(function_name)
- cat = clean(process_type)
- if cat not in settings.PROCESSING_FUNCTIONS:
- raise ValueError('Processing functions are either "pre" or "post" '
- 'processing. Got %s.' % cat)
-
- if name not in settings.PROCESSING_FUNCTIONS[cat]:
- raise ValueError('"%s" is not a valid %s-processing function'
- % (name, cat))
- return settings.PROCESSING_FUNCTIONS[cat][name]
-
- def process(self, image, key, process_type):
- """Apply the pre- or post-processing function to the image data.
-
- Args:
- image (numpy.array): The image data to process.
- key (str): The name of the function to use.
- process_type (str): "pre" or "post" processing.
-
- Returns:
- numpy.array: The processed image data.
- """
- start = timeit.default_timer()
- if not key:
- return image
-
- f = self._get_processing_function(process_type, key)
-
- if key == 'retinanet-semantic':
- # image[:-1] is targeted at a two semantic head panoptic model
- # TODO This may need to be modified and generalized in the future
- results = f(image[:-1])
- elif key == 'retinanet':
- results = f(image, self._rawshape[0], self._rawshape[1])
- else:
- results = f(image)
-
- if results.shape[0] == 1:
- results = np.squeeze(results, axis=0)
-
- finished = timeit.default_timer() - start
-
- self.update_key(self._redis_hash, {
- '{}process_time'.format(process_type): finished
- })
-
- self.logger.debug('%s-processed key %s (model %s:%s, preprocessing: %s,'
- ' postprocessing: %s) in %s seconds.',
- process_type.capitalize(), self._redis_hash,
- self._redis_values.get('model_name'),
- self._redis_values.get('model_version'),
- self._redis_values.get('preprocess_function'),
- self._redis_values.get('postprocess_function'),
- finished)
-
- return results
-
def detect_scale(self, image):
"""Send the image to the SCALE_DETECT_MODEL to detect the relative
scale difference from the image to the model's training data.
@@ -180,38 +110,6 @@ def detect_label(self, image):
detected, timeit.default_timer() - start)
return detected
- def preprocess(self, image, keys):
- """Wrapper for _process_image but can only call with type="pre".
-
- Args:
- image (numpy.array): image data
- keys (list): list of function names to apply to the image
-
- Returns:
- numpy.array: pre-processed image data
- """
- pre = None
- for key in keys:
- x = pre if pre else image
- pre = self.process(x, key, 'pre')
- return pre
-
- def postprocess(self, image, keys):
- """Wrapper for _process_image but can only call with type="post".
-
- Args:
- image (numpy.array): image data
- keys (list): list of function names to apply to the image
-
- Returns:
- numpy.array: post-processed image data
- """
- post = None
- for key in keys:
- x = post if post else image
- post = self.process(x, key, 'post')
- return post
-
def _consume(self, redis_hash):
start = timeit.default_timer()
hvals = self.redis.hgetall(redis_hash)
diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py
index 4ec5b075..3cdda144 100644
--- a/redis_consumer/consumers/image_consumer_test.py
+++ b/redis_consumer/consumers/image_consumer_test.py
@@ -55,60 +55,6 @@ def test_is_valid_hash(self, mocker, redis_client):
assert consumer.is_valid_hash('predict:1234567890:file.tiff') is True
assert consumer.is_valid_hash('predict:1234567890:file.png') is True
- def test__get_processing_function(self, mocker, redis_client):
- mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', {
- 'valid': {
- 'valid': lambda x: True
- }
- })
-
- storage = DummyStorage()
- consumer = consumers.ImageFileConsumer(redis_client, storage, 'q')
-
- x = consumer._get_processing_function('VaLiD', 'vAlId')
- y = consumer._get_processing_function('vAlId', 'VaLiD')
- assert x == y
-
- with pytest.raises(ValueError):
- consumer._get_processing_function('invalid', 'valid')
-
- with pytest.raises(ValueError):
- consumer._get_processing_function('valid', 'invalid')
-
- def test_process(self, mocker, redis_client):
- # TODO: better test coverage
- storage = DummyStorage()
- queue = 'q'
- img = np.random.random((1, 32, 32, 1))
-
- mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', {
- 'valid': {
- 'valid': lambda x: x,
- 'retinanet': lambda *x: x[0],
- 'retinanet-semantic': lambda x: x,
- }
- })
-
- consumer = consumers.ImageFileConsumer(redis_client, storage, queue)
-
- mocker.patch.object(consumer, '_redis_hash', 'a hash')
-
- output = consumer.process(img, '', '')
- np.testing.assert_equal(img, output)
-
- # image is returned but channel squeezed out
- output = consumer.process(img, 'valid', 'valid')
- np.testing.assert_equal(img[0], output)
-
- img = np.random.random((2, 32, 32, 1))
- output = consumer.process(img, 'retinanet-semantic', 'valid')
- np.testing.assert_equal(img[0], output)
-
- consumer._rawshape = (21, 21)
- img = np.random.random((1, 32, 32, 1))
- output = consumer.process(img, 'retinanet', 'valid')
- np.testing.assert_equal(img[0], output)
-
def test_detect_label(self, mocker, redis_client):
# pylint: disable=W0613
model_shape = (1, 216, 216, 1)
From 3c924376cad448b024ee0b10a4d02ce28fb42e9f Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 15 Jul 2020 13:10:09 -0700
Subject: [PATCH 10/12] Restructure into ordered list of instructions.
---
README.md | 28 +++++++++-------------------
1 file changed, 9 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index cb7db64d..e6345f3c 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,14 @@ Consumers consume Redis events. Each type of Redis event is put into a queue (e.
Consumers call the `_consume` method to consume each item it finds in the queue.
This method must be implemented for every consumer.
-The quickest way to get a custom consumer up and running is to inherit from `redis_consumer.consumers.ImageFileConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which uses the `preprocess`, `predict`, and `postprocess` methods to easily process data with the model.
+
+The quickest way to get a custom consumer up and running is to:
+
+1. Add a new file for the consumer: `redis_consumer/consumers/my_new_consumer.py`
+2. Create a new class, inheriting from `TensorFlowServingConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which uses the `preprocess`, `predict`, and `postprocess` methods to easily process data with the model.
+3. Implement the `_consume` method, which should download the data, run inference on the data, save and upload the results, and finish the job by updating the Redis fields.
+4. Import the new consumer in redis_consumer/consumers/\_\_init\_\_.py and add it to the `CONSUMERS` dictionary with a corresponding queue type (`queue_name`). The script consume-redis-events.py will load the consumer class based on the `CONSUMER_TYPE`.
+
See below for a basic implementation of `_consume()` making use of the methods inherited from `ImageFileConsumer`:
```python
@@ -34,7 +41,7 @@ def _consume(self, redis_hash):
# the data to process with the model, required.
input_file_name = hvals.get('input_file_name')
- # TODO: the model can be passed in as an environment variable,
+ # the model can be passed in as an environment variable,
# and parsed in settings.py.
model_name, model_version = 'CustomModel:1'.split(':')
@@ -73,23 +80,6 @@ def _consume(self, redis_hash):
return self.final_status
```
-Finally, the new consumer needs to be imported into the redis_consumer/consumers/\_\_init\_\_.py and added to the `CONSUMERS` dictionary with a correponding queue type (`queue_name`). The script consume-redis-events.py will load the consumer class based on the `CONSUMER_TYPE`.
-
-```python
-# Custom Workflow consumers
-from redis_consumer.consumers.image_consumer import ImageFileConsumer
-from redis_consumer.consumers.tracking_consumer import TrackingConsumer
-# TODO: Import future custom Consumer classes.
-
-
-CONSUMERS = {
- 'image': ImageFileConsumer,
- 'zip': ZipFileConsumer,
- 'tracking': TrackingConsumer,
- # TODO: Add future custom Consumer classes here.
-}
-```
-
For guidance on how to complete the deployment of a custom consumer, please return to [Tutorial: Custom Job](https://deepcell-kiosk.readthedocs.io/en/master/CUSTOM-JOB.html).
## Configuration
From ed70d6859c878b28cdb70f24b03e273b81f50e74 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Thu, 16 Jul 2020 16:10:19 -0700
Subject: [PATCH 11/12] change the MultiplexConsumer base.
---
redis_consumer/consumers/multiplex_consumer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py
index eaa969bd..34aad479 100644
--- a/redis_consumer/consumers/multiplex_consumer.py
+++ b/redis_consumer/consumers/multiplex_consumer.py
@@ -33,13 +33,13 @@
import numpy as np
-from redis_consumer.consumers import ImageFileConsumer
+from redis_consumer.consumers import TensorFlowServingConsumer
from redis_consumer import utils
from redis_consumer import settings
from redis_consumer import processing
-class MultiplexConsumer(ImageFileConsumer):
+class MultiplexConsumer(TensorFlowServingConsumer):
"""Consumes image files and uploads the results"""
def _consume(self, redis_hash):
From 6e683b716cdc7a937a15383cf828d2b89f5d210c Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Thu, 16 Jul 2020 16:17:02 -0700
Subject: [PATCH 12/12] undo base change.
---
redis_consumer/consumers/multiplex_consumer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py
index 34aad479..eaa969bd 100644
--- a/redis_consumer/consumers/multiplex_consumer.py
+++ b/redis_consumer/consumers/multiplex_consumer.py
@@ -33,13 +33,13 @@
import numpy as np
-from redis_consumer.consumers import TensorFlowServingConsumer
+from redis_consumer.consumers import ImageFileConsumer
from redis_consumer import utils
from redis_consumer import settings
from redis_consumer import processing
-class MultiplexConsumer(TensorFlowServingConsumer):
+class MultiplexConsumer(ImageFileConsumer):
"""Consumes image files and uploads the results"""
def _consume(self, redis_hash):