Skip to content

Commit

Permalink
Merge 6e683b7 into 6372e3f
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed Jul 16, 2020
2 parents 6372e3f + 6e683b7 commit eda1e87
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 202 deletions.
120 changes: 59 additions & 61 deletions README.md
Expand Up @@ -13,73 +13,71 @@ This repository is part of the [DeepCell Kiosk](https://github.com/vanvalenlab/k

Custom consumers can be used to implement custom model pipelines. This documentation is a continuation of a [tutorial](https://deepcell-kiosk.readthedocs.io/en/master/CUSTOM-JOB.html) on building a custom job pipeline.

Consumers consume Redis events. Each type of Redis event is put into a separate queue (e.g. `predict`, `track`), and each consumer type will pop items to consume off that queue.
Consumers consume Redis events. Each type of Redis event is put into a queue (e.g. `predict`, `track`), and each queue has a specific consumer type will pop items off the queue.
Consumers call the `_consume` method to consume each item it finds in the queue.
This method must be implemented for every consumer.

Each Redis event should have the following fields:

- `model_name` - The name of the model that will be retrieved by TensorFlow Serving from `gs://<bucket-name>/models`
- `model_version` - The version number of the model in TensorFlow Serving
- `input_file_name` - The path to the data file in a cloud bucket.
The quickest way to get a custom consumer up and running is to:

If the consumer will send data to a TensorFlow Serving model, it should inherit from `redis_consumer.consumers.TensorFlowServingConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which has methods `_get_predict_client()` and `grpc_image()` which can send data to the specific model. The new consumer must also implement the `_consume()` method which performs the bulk of the work. The `_consume()` method will fetch data from Redis, download data file from the bucket, process the data with a model, and upload the results to the bucket again. See below for a basic implementation of `_consume()`:
1. Add a new file for the consumer: `redis_consumer/consumers/my_new_consumer.py`
2. Create a new class, inheriting from `TensorFlowServingConsumer` ([docs](https://deepcell-kiosk.readthedocs.io/projects/kiosk-redis-consumer/en/master/redis_consumer.consumers.html)), which uses the `preprocess`, `predict`, and `postprocess` methods to easily process data with the model.
3. Implement the `_consume` method, which should download the data, run inference on the data, save and upload the results, and finish the job by updating the Redis fields.
4. Import the new consumer in <tt><a href="https://github.com/vanvalenlab/kiosk-redis-consumer/blob/master/redis_consumer/consumers/__init__.py">redis_consumer/consumers/\_\_init\_\_.py</a></tt> and add it to the `CONSUMERS` dictionary with a correponding queue type (`queue_name`). The script <tt><a href="https://github.com/vanvalenlab/kiosk-redis-consumer/blob/master/consume-redis-events.py">consume-redis-events.py</a></tt> will load the consumer class based on the `CONSUMER_TYPE`.

```python
def _consume(self, redis_hash):
# get all redis data for the given hash
hvals = self.redis.hgetall(redis_hash)

with utils.get_tempdir() as tempdir:
# download the image file
fname = self.storage.download(hvals.get('input_file_name'), tempdir)

# load image file as data
image = utils.get_image(fname)

# preprocess data if necessary

# send the data to the model
results = self.grpc_image(image,
hvals.get('model_name'),
hvals.get('model_version'))

# postprocess results if necessary

# save the results as an image
outpaths = utils.save_numpy_array(results, name=name,
subdir=subdir, output_dir=tempdir)

# zip up the file
zip_file = utils.zip_files(outpaths, tempdir)

# upload the zip file to the cloud bucket
dest, output_url = self.storage.upload(zip_file)

# save the results to the redis hash
self.update_key(redis_hash, {
'status': self.final_status,
'output_url': output_url,
'output_file_name': dest
})

# return the final status
return self.final_status
```

Finally, the new consumer needs to be imported into the <tt><a href="https://github.com/vanvalenlab/kiosk-redis-consumer/blob/master/redis_consumer/consumers/__init__.py">redis_consumer/consumers/\_\_init\_\_.py</a></tt> and added to the `CONSUMERS` dictionary with a correponding queue type (`queue_name`). The script <tt><a href="https://github.com/vanvalenlab/kiosk-redis-consumer/blob/master/consume-redis-events.py">consume-redis-events.py</a></tt> will load the consumer class based on the `CONSUMER_TYPE`.
See below for a basic implementation of `_consume()` making use of the methods inherited from `ImageFileConsumer`:

```python
# Custom Workflow consumers
from redis_consumer.consumers.image_consumer import ImageFileConsumer
from redis_consumer.consumers.tracking_consumer import TrackingConsumer
# TODO: Import future custom Consumer classes.


CONSUMERS = {
'image': ImageFileConsumer,
'zip': ZipFileConsumer,
'tracking': TrackingConsumer,
# TODO: Add future custom Consumer classes here.
}
def _consume(self, redis_hash):
# get all redis data for the given hash
hvals = self.redis.hgetall(redis_hash)

# only work on unfinished jobs
if hvals.get('status') in self.finished_statuses:
self.logger.warning('Found completed hash `%s` with status %s.',
redis_hash, hvals.get('status'))
return hvals.get('status')

# the data to process with the model, required.
input_file_name = hvals.get('input_file_name')

# the model can be passed in as an environment variable,
# and parsed in settings.py.
model_name, model_version = 'CustomModel:1'.split(':')

with utils.get_tempdir() as tempdir:
# download the image file
fname = self.storage.download(input_file_name, tempdir)
# load image file as data
image = utils.get_image(fname)

# pre- and post-processing can be used with the BaseConsumer.process,
# which uses pre-defined functions in settings.PROCESSING_FUNCTIONS.
image = self.preprocess(image, 'normalize')

# send the data to the model
results = self.predict(image, model_name, model_version)

# post-process model results
image = self.postprocess(image, 'deep_watershed')

# save the results as an image file and upload it to the bucket
save_name = hvals.get('original_name', fname)
dest, output_url = self.save_output(image, redis_hash, save_name)

# save the results to the redis hash
self.update_key(redis_hash, {
'status': self.final_status,
'output_url': output_url,
'upload_time': timeit.default_timer() - _,
'output_file_name': dest,
'total_jobs': 1,
'total_time': timeit.default_timer() - start,
'finished_at': self.get_current_timestamp()
})

# return the final status
return self.final_status
```

For guidance on how to complete the deployment of a custom consumer, please return to [Tutorial: Custom Job](https://deepcell-kiosk.readthedocs.io/en/master/CUSTOM-JOB.html).
Expand Down
147 changes: 144 additions & 3 deletions redis_consumer/consumers/base_consumer.py
Expand Up @@ -95,6 +95,13 @@ def _put_back_hash(self, redis_hash):
pass # success

def get_redis_hash(self):
"""Pop off an item from the Job queue.
If a Job hash is invalid it will be failed and removed from the queue.
Returns:
str: A valid Redish Job hash, or None if one cannot be found.
"""
while True:

if self.redis.llen(self.queue) == 0:
Expand Down Expand Up @@ -162,9 +169,9 @@ def update_key(self, redis_hash, data=None):
"""Update the hash with `data` and updated_by & updated_at stamps.
Args:
redis_hash: string, the hash that will be updated
status: string, the new status value
data: dict, optional data to include in the hmset call
redis_hash (str): The hash that will be updated
status (str): The new status value
data (dict): Optional data to include in the hmset call
"""
if data is not None and not isinstance(data, dict):
raise ValueError('`data` must be a dictionary, got {}.'.format(
Expand All @@ -178,6 +185,7 @@ def update_key(self, redis_hash, data=None):
self.redis.hmset(redis_hash, data)

def _consume(self, redis_hash):
"""Consume the Redis Job. All Consumers must implement this function"""
raise NotImplementedError

def consume(self):
Expand Down Expand Up @@ -249,6 +257,15 @@ def _consume(self, redis_hash):
raise NotImplementedError

def _get_predict_client(self, model_name, model_version):
"""Returns the TensorFlow Serving gRPC client.
Args:
model_name (str): The name of the model
model_version (int): The version of the model
Returns:
redis_consumer.grpc_clients.PredictClient: the gRPC client.
"""
t = timeit.default_timer()
hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT)
client = PredictClient(hostname, model_name, int(model_version))
Expand All @@ -258,7 +275,19 @@ def _get_predict_client(self, model_name, model_version):

def grpc_image(self, img, model_name, model_version, model_shape,
in_tensor_name='image', in_tensor_dtype='DT_FLOAT'):
"""Use the TensorFlow Serving gRPC API for model inference on an image.
Args:
img (numpy.array): The image to send to the model
model_name (str): The name of the model
model_version (int): The version of the model
model_shape (tuple): The shape of input data for the model
in_tensor_name (str): The name of the input tensor for the request
in_tensor_dtype (str): The dtype of the input data
Returns:
numpy.array: The results of model inference.
"""
in_tensor_dtype = str(in_tensor_dtype).upper()

start = timeit.default_timer()
Expand Down Expand Up @@ -491,6 +520,16 @@ def _predict_small_image(self,
return image

def predict(self, image, model_name, model_version, untile=True):
"""Performs model inference on the image data.
Args:
image (numpy.array): the image data
model_name (str): hosted model to send image data.
model_version (int): model version to query.
untile (bool): Whether to untile the tiled inference results. This
should be True when the model output is the same shape as the
input, and False otherwise.
"""
start = timeit.default_timer()
model_metadata = self.get_model_metadata(model_name, model_version)

Expand Down Expand Up @@ -553,6 +592,108 @@ def predict(self, image, model_name, model_version, untile=True):

return image

def _get_processing_function(self, process_type, function_name):
"""Based on the function category and name, return the function.
Args:
process_type (str): "pre" or "post" processing
function_name (str): Name processing function, must exist in
settings.PROCESSING_FUNCTIONS.
Returns:
function: the selected pre- or post-processing function.
"""
clean = lambda x: str(x).lower()
# first, verify the route parameters
name = clean(function_name)
cat = clean(process_type)
if cat not in settings.PROCESSING_FUNCTIONS:
raise ValueError('Processing functions are either "pre" or "post" '
'processing. Got %s.' % cat)

if name not in settings.PROCESSING_FUNCTIONS[cat]:
raise ValueError('"%s" is not a valid %s-processing function'
% (name, cat))
return settings.PROCESSING_FUNCTIONS[cat][name]

def process(self, image, key, process_type):
"""Apply the pre- or post-processing function to the image data.
Args:
image (numpy.array): The image data to process.
key (str): The name of the function to use.
process_type (str): "pre" or "post" processing.
Returns:
numpy.array: The processed image data.
"""
start = timeit.default_timer()
if not key:
return image

f = self._get_processing_function(process_type, key)

if key == 'retinanet-semantic':
# image[:-1] is targeted at a two semantic head panoptic model
# TODO This may need to be modified and generalized in the future
results = f(image[:-1])
elif key == 'retinanet':
results = f(image, self._rawshape[0], self._rawshape[1])
else:
results = f(image)

if results.shape[0] == 1:
results = np.squeeze(results, axis=0)

finished = timeit.default_timer() - start

self.update_key(self._redis_hash, {
'{}process_time'.format(process_type): finished
})

self.logger.debug('%s-processed key %s (model %s:%s, preprocessing: %s,'
' postprocessing: %s) in %s seconds.',
process_type.capitalize(), self._redis_hash,
self._redis_values.get('model_name'),
self._redis_values.get('model_version'),
self._redis_values.get('preprocess_function'),
self._redis_values.get('postprocess_function'),
finished)

return results

def preprocess(self, image, keys):
"""Wrapper for _process_image but can only call with type="pre".
Args:
image (numpy.array): image data
keys (list): list of function names to apply to the image
Returns:
numpy.array: pre-processed image data
"""
pre = None
for key in keys:
x = pre if pre else image
pre = self.process(x, key, 'pre')
return pre

def postprocess(self, image, keys):
"""Wrapper for _process_image but can only call with type="post".
Args:
image (numpy.array): image data
keys (list): list of function names to apply to the image
Returns:
numpy.array: post-processed image data
"""
post = None
for key in keys:
x = post if post else image
post = self.process(x, key, 'post')
return post

def save_output(self, image, redis_hash, save_name, scale=1):
with utils.get_tempdir() as tempdir:
# Save each result channel as an image file
Expand Down
54 changes: 54 additions & 0 deletions redis_consumer/consumers/base_consumer_test.py
Expand Up @@ -373,6 +373,60 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613
x = np.random.random((300, 300, 1))
consumer.predict(x, model_name='modelname', model_version=0)

def test__get_processing_function(self, mocker, redis_client):
mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', {
'valid': {
'valid': lambda x: True
}
})

storage = DummyStorage()
consumer = consumers.ImageFileConsumer(redis_client, storage, 'q')

x = consumer._get_processing_function('VaLiD', 'vAlId')
y = consumer._get_processing_function('vAlId', 'VaLiD')
assert x == y

with pytest.raises(ValueError):
consumer._get_processing_function('invalid', 'valid')

with pytest.raises(ValueError):
consumer._get_processing_function('valid', 'invalid')

def test_process(self, mocker, redis_client):
# TODO: better test coverage
storage = DummyStorage()
queue = 'q'
img = np.random.random((1, 32, 32, 1))

mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', {
'valid': {
'valid': lambda x: x,
'retinanet': lambda *x: x[0],
'retinanet-semantic': lambda x: x,
}
})

consumer = consumers.ImageFileConsumer(redis_client, storage, queue)

mocker.patch.object(consumer, '_redis_hash', 'a hash')

output = consumer.process(img, '', '')
np.testing.assert_equal(img, output)

# image is returned but channel squeezed out
output = consumer.process(img, 'valid', 'valid')
np.testing.assert_equal(img[0], output)

img = np.random.random((2, 32, 32, 1))
output = consumer.process(img, 'retinanet-semantic', 'valid')
np.testing.assert_equal(img[0], output)

consumer._rawshape = (21, 21)
img = np.random.random((1, 32, 32, 1))
output = consumer.process(img, 'retinanet', 'valid')
np.testing.assert_equal(img[0], output)


class TestZipFileConsumer(object):
# pylint: disable=R0201,W0613,W0621
Expand Down

0 comments on commit eda1e87

Please sign in to comment.