
Commit

Merge 07f1c1c into 37101d1
willgraf committed Mar 26, 2020
2 parents 37101d1 + 07f1c1c commit 4784422
Showing 6 changed files with 116 additions and 112 deletions.
22 changes: 13 additions & 9 deletions consume-redis-events.py
@@ -31,6 +31,7 @@
from __future__ import division
from __future__ import print_function

import gc
import logging
import logging.handlers
import sys
@@ -66,14 +67,10 @@ def initialize_logger(debug_mode=True):

def get_consumer(consumer_type, **kwargs):
logging.debug('Getting `%s` consumer with args %s.', consumer_type, kwargs)
ct = str(consumer_type).lower()
if ct == 'image':
return redis_consumer.consumers.ImageFileConsumer(**kwargs)
if ct == 'zip':
return redis_consumer.consumers.ZipFileConsumer(**kwargs)
if ct == 'tracking':
return redis_consumer.consumers.TrackingConsumer(**kwargs)
raise ValueError('Invalid `consumer_type`: "{}"'.format(consumer_type))
consumer_cls = redis_consumer.consumers.CONSUMERS.get(str(consumer_type).lower())
if not consumer_cls:
raise ValueError('Invalid `consumer_type`: "{}"'.format(consumer_type))
return consumer_cls(**kwargs)


if __name__ == '__main__':
@@ -91,17 +88,24 @@ def get_consumer(consumer_type, **kwargs):
consumer_kwargs = {
'redis_client': redis,
'storage_client': storage_client,
'final_status': 'done',
'queue': settings.QUEUE,
'final_status': 'done',
'failed_status': 'failed',
'name': settings.HOSTNAME,
'output_dir': settings.OUTPUT_DIR,
}

_logger.debug('Getting `%s` consumer with args %s.',
settings.CONSUMER_TYPE, consumer_kwargs)

consumer = get_consumer(settings.CONSUMER_TYPE, **consumer_kwargs)

_logger.debug('Got `%s` consumer.', settings.CONSUMER_TYPE)

while True:
try:
consumer.consume()
gc.collect()
except Exception as err: # pylint: disable=broad-except
_logger.critical('Fatal Error: %s: %s\n%s',
type(err).__name__, err, traceback.format_exc())
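Aside, not part of the changeset: a minimal sketch of the worker loop after this change, with `gc.collect()` run after each `consume()` call, presumably to keep per-job memory from accumulating in the long-running process. `StubConsumer` and `run_forever` are invented names for illustration.

```python
# Illustrative only: a stub worker loop mirroring the updated consume loop
# above. StubConsumer stands in for whatever get_consumer() returns.
import gc
import logging
import traceback


class StubConsumer(object):
    def consume(self):
        pass  # a real consumer pops a hash from Redis and processes it


def run_forever(consumer, max_iterations=3):
    logger = logging.getLogger('worker')
    for _ in range(max_iterations):  # the real loop is `while True`
        try:
            consumer.consume()
            gc.collect()  # release per-job allocations between iterations
        except Exception as err:  # pylint: disable=broad-except
            logger.critical('Fatal Error: %s: %s\n%s',
                            type(err).__name__, err, traceback.format_exc())


run_forever(StubConsumer())
```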
9 changes: 9 additions & 0 deletions redis_consumer/consumers/__init__.py
@@ -36,6 +36,15 @@
from redis_consumer.consumers.image_consumer import ImageFileConsumer
from redis_consumer.consumers.tracking_consumer import TrackingConsumer


CONSUMERS = {
'image': ImageFileConsumer,
'zip': ZipFileConsumer,
'tracking': TrackingConsumer,
# TODO: Add future custom Consumer classes here.
}


del absolute_import
del division
del print_function
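Aside, not part of the changeset: the new `CONSUMERS` dict is a small registry keyed by consumer type, which is what the rewritten `get_consumer` in `consume-redis-events.py` looks up. As a hedged sketch of the TODO above, a future custom consumer would only need one new entry; the class name and key below are invented for the example.

```python
# Hypothetical registration of a custom consumer; only the CONSUMERS dict
# itself exists in this commit, the rest is illustrative.
from redis_consumer import consumers


class SegmentationConsumer(consumers.ImageFileConsumer):
    """Placeholder subclass; a real consumer would override _consume()."""


consumers.CONSUMERS['segmentation'] = SegmentationConsumer

# consume-redis-events.py can now resolve the new type by name:
assert consumers.CONSUMERS.get('segmentation') is SegmentationConsumer
```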
32 changes: 16 additions & 16 deletions redis_consumer/consumers/base_consumer.py
@@ -65,18 +65,20 @@ def __init__(self,
storage_client,
queue,
final_status='done',
failed_status='failed'):
self.output_dir = settings.OUTPUT_DIR
self.hostname = settings.HOSTNAME
failed_status='failed',
name=settings.HOSTNAME,
output_dir=settings.OUTPUT_DIR):
self.redis = redis_client
self.storage = storage_client
self.queue = str(queue).lower()
self.name = name
self.output_dir = output_dir
self.final_status = final_status
self.failed_status = failed_status
self.finished_statuses = {final_status, failed_status}
self.logger = logging.getLogger(str(self.__class__.__name__))
self.processing_queue = 'processing-{queue}:{name}'.format(
queue=self.queue, name=self.hostname)
queue=self.queue, name=self.name)

def _put_back_hash(self, redis_hash):
"""Put the hash back into the work queue"""
@@ -147,10 +149,10 @@ def purge_processing_queue(self):
while queue_has_items:
key = self.redis.rpoplpush(self.processing_queue, self.queue)
queue_has_items = key is not None

self.logger.debug('Found stranded key `%s` in queue `%s`. '
'Moving it back to `%s`.',
key, self.processing_queue, self.queue)
if queue_has_items:
self.logger.debug('Found stranded key `%s` in queue `%s`. '
'Moving it back to `%s`.',
key, self.processing_queue, self.queue)

def update_key(self, redis_hash, data=None):
"""Update the hash with `data` and updated_by & updated_at stamps.
@@ -167,7 +169,7 @@ def update_key(self, redis_hash, data=None):
data = {} if data is None else data
data.update({
'updated_at': self.get_current_timestamp(),
'updated_by': self.hostname,
'updated_by': self.name,
})
self.redis.hmset(redis_hash, data)

@@ -231,13 +233,12 @@ def __init__(self,
redis_client,
storage_client,
queue,
final_status='done'):
**kwargs):
# Create some attributes only used during consume()
self._redis_hash = None
self._redis_values = dict()
super(TensorFlowServingConsumer, self).__init__(
redis_client, storage_client,
queue, final_status)
redis_client, storage_client, queue, **kwargs)

def _consume(self, redis_hash):
raise NotImplementedError
@@ -523,13 +524,12 @@ def __init__(self,
redis_client,
storage_client,
queue,
final_status='done'):
**kwargs):
# zip files go in a new queue
zip_queue = '{}-zip'.format(queue)
self.child_queue = queue
super(ZipFileConsumer, self).__init__(
redis_client, storage_client,
zip_queue, final_status)
redis_client, storage_client, zip_queue, **kwargs)

def is_valid_hash(self, redis_hash):
if redis_hash is None:
@@ -563,7 +563,7 @@ def _upload_archived_images(self, hvalues, redis_hash):
new_hvals['input_file_name'] = dest
new_hvals['original_name'] = clean_imfile
new_hvals['status'] = 'new'
new_hvals['identity_upload'] = self.hostname
new_hvals['identity_upload'] = self.name
new_hvals['created_at'] = current_timestamp
new_hvals['updated_at'] = current_timestamp

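Aside, not part of the changeset: with `name` and `output_dir` promoted to constructor keyword arguments (defaulting to `settings.HOSTNAME` and `settings.OUTPUT_DIR`) and the subclasses forwarding `**kwargs`, a consumer can now be configured explicitly at construction time. A minimal sketch, assuming the caller already has Redis and storage clients; the queue name and paths are placeholders.

```python
# Illustrative wiring of a consumer with the new keyword arguments.
from redis_consumer.consumers import ImageFileConsumer


def build_consumer(redis_client, storage_client):
    """Build an ImageFileConsumer with explicit name/output_dir values."""
    return ImageFileConsumer(
        redis_client=redis_client,      # e.g. the client built in consume-redis-events.py
        storage_client=storage_client,  # e.g. the storage client built there as well
        queue='predict',                # placeholder queue name
        final_status='done',
        failed_status='failed',
        name='my-worker-pod',           # previously hard-wired to settings.HOSTNAME
        output_dir='/tmp/output',       # previously hard-wired to settings.OUTPUT_DIR
    )
```

With those values, the base consumer's format string would yield a processing queue of `processing-predict:my-worker-pod`.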
158 changes: 75 additions & 83 deletions redis_consumer/consumers/image_consumer.py
@@ -41,16 +41,6 @@
class ImageFileConsumer(TensorFlowServingConsumer):
"""Consumes image files and uploads the results"""

def __init__(self,
redis_client,
storage_client,
queue,
final_status='done'):
# Create some attributes only used during consume()
super(ImageFileConsumer, self).__init__(
redis_client, storage_client,
queue, final_status)

def is_valid_hash(self, redis_hash):
if redis_hash is None:
return False
@@ -204,82 +194,84 @@ def _consume(self, redis_hash):

self.update_key(redis_hash, {
'status': 'started',
'identity_started': self.hostname,
'identity_started': self.name,
})

# Overridden with LABEL_DETECT_ENABLED
model_name = hvals.get('model_name')
model_version = hvals.get('model_version')

_ = timeit.default_timer()

with utils.get_tempdir() as tempdir:
_ = timeit.default_timer()
fname = self.storage.download(hvals.get('input_file_name'), tempdir)
image = utils.get_image(fname)

# Pre-process data before sending to the model
self.update_key(redis_hash, {
'status': 'pre-processing',
'download_time': timeit.default_timer() - _,
})

# Calculate scale of image and rescale
scale = hvals.get('scale', '')
if not scale:
# Detect scale of image
scale = self.detect_scale(image)
self.logger.debug('Image scale detected: %s', scale)
self.update_key(redis_hash, {'scale': scale})
else:
scale = float(scale)
self.logger.debug('Image scale already calculated: %s', scale)

image = utils.rescale(image, scale)

# Save shape value for postprocessing purposes
# TODO this is a bit janky
self._rawshape = image.shape
label = None
if settings.LABEL_DETECT_ENABLED and model_name and model_version:
self.logger.warning('Label Detection is enabled, but the model'
' %s:%s was specified in Redis.',
model_name, model_version)

elif settings.LABEL_DETECT_ENABLED:
# Detect image label type
label = hvals.get('label', '')
if not label:
label = self.detect_label(image)
self.logger.debug('Image label detected: %s', label)
self.update_key(redis_hash, {'label': str(label)})
else:
label = int(label)
self.logger.debug('Image label already calculated: %s', label)

# Grab appropriate model
model_name, model_version = utils._pick_model(label)

pre_funcs = hvals.get('preprocess_function', '').split(',')
image = self.preprocess(image, pre_funcs)

# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})

image = self.predict(image, model_name, model_version)

# Post-process model results
self.update_key(redis_hash, {'status': 'post-processing'})

if settings.LABEL_DETECT_ENABLED and label is not None:
post_funcs = utils._pick_postprocess(label).split(',')
# Pre-process data before sending to the model
self.update_key(redis_hash, {
'status': 'pre-processing',
'download_time': timeit.default_timer() - _,
})

# Calculate scale of image and rescale
scale = hvals.get('scale', '')
if not scale:
# Detect scale of image
scale = self.detect_scale(image)
self.logger.debug('Image scale detected: %s', scale)
self.update_key(redis_hash, {'scale': scale})
else:
scale = float(scale)
self.logger.debug('Image scale already calculated: %s', scale)

image = utils.rescale(image, scale)

# Save shape value for postprocessing purposes
# TODO this is a bit janky
self._rawshape = image.shape
label = None
if settings.LABEL_DETECT_ENABLED and model_name and model_version:
self.logger.warning('Label Detection is enabled, but the model'
' %s:%s was specified in Redis.',
model_name, model_version)

elif settings.LABEL_DETECT_ENABLED:
# Detect image label type
label = hvals.get('label', '')
if not label:
label = self.detect_label(image)
self.logger.debug('Image label detected: %s', label)
self.update_key(redis_hash, {'label': str(label)})
else:
post_funcs = hvals.get('postprocess_function', '').split(',')
label = int(label)
self.logger.debug('Image label already calculated: %s', label)

# Grab appropriate model
model_name, model_version = utils._pick_model(label)

pre_funcs = hvals.get('preprocess_function', '').split(',')
image = self.preprocess(image, pre_funcs)

image = self.postprocess(image, post_funcs)
# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})

# Save the post-processed results to a file
_ = timeit.default_timer()
self.update_key(redis_hash, {'status': 'saving-results'})
image = self.predict(image, model_name, model_version)

# Post-process model results
self.update_key(redis_hash, {'status': 'post-processing'})

if settings.LABEL_DETECT_ENABLED and label is not None:
post_funcs = utils._pick_postprocess(label).split(',')
else:
post_funcs = hvals.get('postprocess_function', '').split(',')

image = self.postprocess(image, post_funcs)

# Save the post-processed results to a file
_ = timeit.default_timer()
self.update_key(redis_hash, {'status': 'saving-results'})

with utils.get_tempdir() as tempdir:
# Save each result channel as an image file
save_name = hvals.get('original_name', fname)
subdir = os.path.dirname(save_name.replace(tempdir, ''))
@@ -306,15 +298,15 @@ def _consume(self, redis_hash):
subdir = subdir if subdir else None
dest, output_url = self.storage.upload(zip_file, subdir=subdir)

# Update redis with the final results
t = timeit.default_timer() - start
self.update_key(redis_hash, {
'status': self.final_status,
'output_url': output_url,
'upload_time': timeit.default_timer() - _,
'output_file_name': dest,
'total_jobs': 1,
'total_time': t,
'finished_at': self.get_current_timestamp()
})
# Update redis with the final results
t = timeit.default_timer() - start
self.update_key(redis_hash, {
'status': self.final_status,
'output_url': output_url,
'upload_time': timeit.default_timer() - _,
'output_file_name': dest,
'total_jobs': 1,
'total_time': t,
'finished_at': self.get_current_timestamp()
})
return self.final_status
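Aside, not part of the changeset: the main effect of the `_consume` reshuffle is that `utils.get_tempdir()` is now entered twice, once around the download and once around saving and uploading the results, so no temporary directory stays open while the image is preprocessed, sent to the model, and postprocessed. A rough sketch of that scoping pattern, using the standard library's `tempfile.TemporaryDirectory` as a stand-in for `utils.get_tempdir()` and dummy steps in place of the real pipeline:

```python
# Sketch of the narrowed temp-directory scope (tempfile.TemporaryDirectory
# stands in for utils.get_tempdir(); the processing steps are dummies).
import os
import tempfile


def run_job(payload=b'fake-image-bytes'):
    # First short-lived tempdir: only for downloading the input file.
    with tempfile.TemporaryDirectory() as tempdir:
        fname = os.path.join(tempdir, 'input.tif')
        with open(fname, 'wb') as f:
            f.write(payload)
        with open(fname, 'rb') as f:
            image = f.read()  # stand-in for utils.get_image(fname)

    # Preprocess / predict / postprocess happen here with no tempdir open.
    results = image

    # Second short-lived tempdir: only while writing and uploading results.
    with tempfile.TemporaryDirectory() as tempdir:
        outfile = os.path.join(tempdir, 'results.zip')
        with open(outfile, 'wb') as f:
            f.write(results)
        return os.path.basename(outfile)  # stand-in for storage.upload(...)


print(run_job())
```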
2 changes: 1 addition & 1 deletion redis_consumer/consumers/tracking_consumer.py
@@ -181,7 +181,7 @@ def _load_data(self, redis_hash, subdir, fname):
# prepare hvalues for this frame's hash
current_timestamp = self.get_current_timestamp()
frame_hvalues = {
'identity_upload': self.hostname,
'identity_upload': self.name,
'input_file_name': upload_file_name,
'original_name': segment_fname,
'model_name': model_name,
5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,13 +1,12 @@
boto3==1.9.195
google-cloud-storage>=1.16.1
python-decouple==3.1
redis==3.2.1
redis==3.4.1
scikit-image>=0.14.0
keras-preprocessing==1.1.0
grpcio==1.22.0
grpcio==1.27.2
dict-to-protobuf==0.0.3.9
pytz==2019.1
keras_retinanet==0.5.1
opencv-python==4.1.0.25
deepcell-tracking==0.2.5
deepcell-toolbox==0.2.0
