diff --git a/consume-redis-events.py b/consume-redis-events.py
index a92fa566..2e1b313b 100644
--- a/consume-redis-events.py
+++ b/consume-redis-events.py
@@ -32,13 +32,32 @@ from __future__ import print_function
 
 import sys
+import signal
 import traceback
 
 import logging
 import logging.handlers
 
 import redis_consumer
 from redis_consumer import settings
-from redis_consumer import storage
+
+
+class GracefulDeath:
+    """Catch signals to allow graceful shutdown.
+
+    Adapted from: https://stackoverflow.com/questions/18499497
+    """
+
+    def __init__(self):
+        self.signum = None
+        self.kill_now = False
+        self.logger = logging.getLogger(str(self.__class__.__name__))
+        signal.signal(signal.SIGINT, self.handle_signal)
+        signal.signal(signal.SIGTERM, self.handle_signal)
+
+    def handle_signal(self, signum, frame):  # pylint: disable=unused-argument
+        self.signum = signum
+        self.kill_now = True
+        self.logger.debug('Received signal `%s` and frame `%s`', signum, frame)
 
 
 def initialize_logger(debug_mode=True):
@@ -76,6 +95,7 @@ def get_consumer(consumer_type, **kwargs):
 
 if __name__ == '__main__':
     initialize_logger(settings.DEBUG)
+    sighandler = GracefulDeath()
 
     _logger = logging.getLogger(__file__)
 
@@ -98,8 +118,12 @@ def get_consumer(consumer_type, **kwargs):
     while True:
         try:
             consumer.consume()
+            if sighandler.kill_now:
+                break
         except Exception as err:  # pylint: disable=broad-except
             _logger.critical('Fatal Error: %s: %s\n%s',
                              type(err).__name__, err, traceback.format_exc())
 
             sys.exit(1)
+
+    _logger.info('Gracefully exited after signal number %s', sighandler.signum)
diff --git a/redis_consumer/consumers.py b/redis_consumer/consumers.py
index 6969efcb..12d8c6db 100644
--- a/redis_consumer/consumers.py
+++ b/redis_consumer/consumers.py
@@ -34,9 +34,9 @@
 import uuid
 import urllib
 import timeit
-import datetime
 import logging
 import zipfile
+import datetime
 
 import pytz
 import grpc
@@ -52,9 +52,10 @@ class Consumer(object):
     """Base class for all redis event consumer classes.
 
     Args:
-        redis_client: Client class to communicate with redis
-        storage_client: Client to communicate with cloud storage buckets.
-        final_status: Update the status of redis event with this value.
+        redis_client: obj, Client class to communicate with redis
+        storage_client: obj, Client to communicate with cloud storage buckets.
+        queue: str, Name of queue to pop off work items.
+        final_status: str, Update the status of redis event with this value.
     """
 
     def __init__(self,
@@ -67,9 +68,26 @@ def __init__(self,
         self.redis = redis_client
         self.storage = storage_client
         self.queue = str(queue).lower()
-        self.processing_queue = 'processing-{}'.format(self.queue)
         self.final_status = final_status
         self.logger = logging.getLogger(str(self.__class__.__name__))
+        self.processing_queue = 'processing-{queue}:{name}'.format(
+            queue=self.queue, name=self.hostname)
+
+    def _put_back_hash(self, redis_hash):
+        """Put the hash back into the work queue"""
+        queue_size = self.redis.llen(self.processing_queue)
+        if queue_size == 1:
+            key = self.redis.rpoplpush(self.processing_queue, self.queue)
+            if key != redis_hash:
+                self.logger.warning('`RPOPLPUSH %s %s` popped key %s but '
+                                    'expected key to be %s',
+                                    self.processing_queue, self.queue,
+                                    key, redis_hash)
+        else:
+            self.logger.warning('Expected `%s` would have 1 item, but has %s. '
+                                'restarting the key the old way')
+            self.redis.lrem(self.processing_queue, 1, redis_hash)
+            self.redis.lpush(self.queue, redis_hash)
 
     def get_redis_hash(self):
         while True:
@@ -79,14 +97,15 @@
             if redis_hash is None:
                 return redis_hash
 
+            self.update_key(redis_hash)  # update timestamp that it was touched
+
             # if hash is found and valid, return the hash
             if self.is_valid_hash(redis_hash):
                 return redis_hash
 
             # this invalid hash should not be processed by this consumer.
             # remove it from processing, and push it back to the work queue.
-            self.redis.lrem(self.processing_queue, 1, redis_hash)
-            self.redis.lpush(self.queue, redis_hash)
+            self._put_back_hash(redis_hash)
 
     def _handle_error(self, err, redis_hash):
         """Update redis with failure information, and log errors.
@@ -96,7 +115,8 @@ def _handle_error(self, err, redis_hash):
             redis_hash: string, the hash that will be updated to failure.
         """
         # Update redis with failed status
-        self.update_status(redis_hash, 'failed', {
+        self.update_key(redis_hash, {
+            'status': 'failed',
             'reason': '{}: {}'.format(type(err).__name__, err),
         })
         self.logger.error('Failed to process redis key %s due to %s: %s',
@@ -110,8 +130,8 @@ def get_current_timestamp(self):
         """Helper function, returns ISO formatted UTC timestamp"""
         return datetime.datetime.now(pytz.UTC).isoformat()
 
-    def update_status(self, redis_hash, status, data=None):
-        """Update the status of a the given hash.
+    def update_key(self, redis_hash, data=None):
+        """Update the hash with `data` and updated_by & updated_at stamps.
 
         Args:
             redis_hash: string, the hash that will be updated
@@ -124,8 +144,8 @@ def update_key(self, redis_hash, data=None):
         data = {} if data is None else data
 
         data.update({
-            'status': status,
-            'updated_at': self.get_current_timestamp()
+            'updated_at': self.get_current_timestamp(),
+            'updated_by': self.hostname,
        })
 
         self.redis.hmset(redis_hash, data)
@@ -133,22 +153,19 @@ def _consume(self, redis_hash):
         raise NotImplementedError
 
     def consume(self):
-        """Consume all redis events every `interval` seconds.
-
-        Args:
-            status: string, only consume hashes where `status` == status.
-            prefix: string, only consume hashes that start with `prefix`.
-
-        Returns:
-            nothing: this is the consumer main process
-        """
+        """Find a redis key and process it"""
         start = timeit.default_timer()
         redis_hash = self.get_redis_hash()
 
         if redis_hash is not None:  # popped something off the queue
             try:
                 self._consume(redis_hash)
-                hvals = self.redis.hgetall(redis_hash)
+            except Exception as err:  # pylint: disable=broad-except
+                # log the error and update redis with details
+                self._handle_error(err, redis_hash)
+
+            hvals = self.redis.hgetall(redis_hash)
+            if hvals.get('status') == self.final_status:
                 self.logger.debug('Consumed key %s (model %s:%s, '
                                   'preprocessing: %s, postprocessing: %s) '
                                   '(%s retries) in %s seconds.',
@@ -157,12 +174,19 @@ def consume(self):
                                   hvals.get('preprocess_function'),
                                   hvals.get('postprocess_function'),
                                   0, timeit.default_timer() - start)
-            except Exception as err:  # pylint: disable=broad-except
-                # log the error and update redis with details
-                self._handle_error(err, redis_hash)
-            # remove the key from the processing queue
-            self.redis.lrem(self.processing_queue, 1, redis_hash)
+            if hvals.get('status') in {self.final_status, 'failed'}:
+                # this key is done. remove the key from the processing queue.
+                self.redis.lrem(self.processing_queue, 1, redis_hash)
+            else:
+                # this key is not done yet.
+                # remove it from processing and push it back to the work queue.
+                self._put_back_hash(redis_hash)
+
+        else:
+            self.logger.debug('Queue `%s` is empty. Waiting for %s seconds.',
+                              self.queue, settings.EMPTY_QUEUE_TIMEOUT)
+            time.sleep(settings.EMPTY_QUEUE_TIMEOUT)
 
 
 class ImageFileConsumer(Consumer):
@@ -171,8 +195,8 @@ class ImageFileConsumer(Consumer):
     def is_valid_hash(self, redis_hash):
         if redis_hash is None:
             return False
-        fname = str(self.redis.hget(redis_hash, 'input_file_name'))
-        is_valid = not fname.lower().endswith('.zip')
+        fname = self.redis.hget(redis_hash, 'input_file_name')
+        is_valid = not str(fname).lower().endswith('.zip')
         return is_valid
 
     def _process(self, image, key, process_type, timeout=30, streaming=False):
@@ -249,7 +273,8 @@ def _process(self, image, key, process_type, timeout=30, streaming=False):
                 count += 1
                 temp_status = 'retry-processing - {} - {}'.format(
                     count, err.code().name)
-                self.update_status(self._redis_hash, temp_status, {
+                self.update_key(self._redis_hash, {
+                    'status': temp_status,
                     'process_retries': count,
                 })
                 sleeptime = np.random.randint(1, 20)
@@ -407,7 +432,8 @@ def grpc_image(self, img, model_name, model_version, timeout=30, backoff=3):
                 # write update to Redis
                 temp_status = 'retry-predicting - {} - {}'.format(
                     count, err.code().name)
-                self.update_status(self._redis_hash, temp_status, {
+                self.update_key(self._redis_hash, {
+                    'status': temp_status,
                     'predict_retries': count,
                 })
                 self.logger.warning('%sException `%s: %s` during '
@@ -435,10 +461,11 @@ def _consume(self, redis_hash):
         # hold on to the redis hash/values for logging purposes
         self._redis_hash = redis_hash
         self._redis_values = hvals
-        self.logger.debug('Found hash to process "%s": %s',
-                          redis_hash, json.dumps(hvals, indent=4))
+        self.logger.debug('Found hash to process `%s` with status `%s`.',
+                          redis_hash, hvals.get('status'))
 
-        self.update_status(redis_hash, 'started', {
+        self.update_key(redis_hash, {
+            'status': 'started',
             'identity_started': self.hostname,
         })
 
@@ -457,13 +484,13 @@ def _consume(self, redis_hash):
             timeout = timeout if not streaming else timeout * int(cuts)
 
             # Pre-process data before sending to the model
-            self.update_status(redis_hash, 'pre-processing')
+            self.update_key(redis_hash, {'status': 'pre-processing'})
 
             pre_funcs = hvals.get('preprocess_function', '').split(',')
             image = self.preprocess(image, pre_funcs, timeout, True)
 
             # Send data to the model
-            self.update_status(redis_hash, 'predicting')
+            self.update_key(redis_hash, {'status': 'predicting'})
 
             if streaming:
                 image = self.process_big_image(
@@ -473,13 +500,13 @@ def _consume(self, redis_hash):
                     image, model_name, model_version, timeout)
 
             # Post-process model results
-            self.update_status(redis_hash, 'post-processing')
+            self.update_key(redis_hash, {'status': 'post-processing'})
 
             post_funcs = hvals.get('postprocess_function', '').split(',')
             image = self.postprocess(image, post_funcs, timeout, True)
 
             # Save the post-processed results to a file
-            self.update_status(redis_hash, 'saving-results')
+            self.update_key(redis_hash, {'status': 'saving-results'})
 
             # Save each result channel as an image file
             save_name = hvals.get('original_name', fname)
@@ -499,7 +526,8 @@ def _consume(self, redis_hash):
             dest, output_url = self.storage.upload(zip_file, subdir=subdir)
 
             # Update redis with the final results
-            self.update_status(redis_hash, self.final_status, {
+            self.update_key(redis_hash, {
+                'status': self.final_status,
                 'output_url': output_url,
                 'output_file_name': dest,
                 'finished_at': self.get_current_timestamp(),
@@ -512,8 +540,8 @@ class ZipFileConsumer(Consumer):
     def is_valid_hash(self, redis_hash):
         if redis_hash is None:
             return False
-        fname = str(self.redis.hget(redis_hash, 'input_file_name'))
-        is_valid = fname.lower().endswith('.zip')
+        fname = self.redis.hget(redis_hash, 'input_file_name')
+        is_valid = str(fname).lower().endswith('.zip')
         return is_valid
 
     def _upload_archived_images(self, hvalues):
@@ -522,13 +550,13 @@ def _upload_archived_images(self, hvalues):
         with utils.get_tempdir() as tempdir:
             fname = self.storage.download(hvalues.get('input_file_name'), tempdir)
             image_files = utils.get_image_files_from_dir(fname, tempdir)
-            for imfile in image_files:
+            for i, imfile in enumerate(image_files):
                 clean_imfile = settings._strip(imfile.replace(tempdir, ''))
                 # Save each result channel as an image file
                 subdir = os.path.dirname(clean_imfile)
                 dest, _ = self.storage.upload(imfile, subdir=subdir)
 
-                new_hash = '{prefix}_{file}_{hash}'.format(
+                new_hash = '{prefix}:{file}:{hash}'.format(
                     prefix=settings.HASH_PREFIX,
                     file=clean_imfile,
                     hash=uuid.uuid4().hex)
@@ -543,92 +571,151 @@ def _upload_archived_images(self, hvalues):
                 new_hvals['created_at'] = current_timestamp
                 new_hvals['updated_at'] = current_timestamp
 
+                # remove unnecessary/confusing keys (maybe from getting restarted)
+                bad_keys = [
+                    'children',
+                    'children:done',
+                    'children:failed',
+                    'identity_started',
+                ]
+                for k in bad_keys:
+                    if k in new_hvals:
+                        del new_hvals[k]
+
                 self.redis.hmset(new_hash, new_hvals)
                 self.redis.lpush(self.queue, new_hash)
-                self.logger.debug('Added new hash `%s`: %s',
-                                  new_hash, json.dumps(new_hvals, indent=4))
+                self.logger.debug('Added new hash %s of %s: `%s`',
+                                  i + 1, len(image_files), new_hash)
                 all_hashes.add(new_hash)
         return all_hashes
 
-    def _consume(self, redis_hash):
-        start = timeit.default_timer()
-        hvals = self.redis.hgetall(redis_hash)
-        self.logger.debug('Found hash to process `%s`: %s',
-                          redis_hash, json.dumps(hvals, indent=4))
-
-        self.update_status(redis_hash, 'started', {
-            'identity_started': self.hostname,
-        })
+    def _upload_finished_children(self, finished_children, expire_time=3600):
+        saved_files = set()
+        with utils.get_tempdir() as tempdir:
+            # process each successfully completed key
+            for key in finished_children:
+                if not key:
+                    continue
+                fname = self.redis.hget(key, 'output_file_name')
+                local_fname = self.storage.download(fname, tempdir)
+
+                self.logger.info('Saved file: %s', local_fname)
+
+                if zipfile.is_zipfile(local_fname):
+                    image_files = utils.get_image_files_from_dir(
+                        local_fname, tempdir)
+                else:
+                    image_files = [local_fname]
 
-        all_hashes = self._upload_archived_images(hvals)
-        self.logger.info('Uploaded %s hashes. Waiting for ImageConsumers.',
-                         len(all_hashes))
+                for imfile in image_files:
+                    saved_files.add(imfile)
 
-        # Now all images have been uploaded with new redis hashes
-        # Wait for these to be processed by an ImageFileConsumer
-        self.update_status(redis_hash, 'waiting')
-
-        with utils.get_tempdir() as tempdir:
-            finished_hashes = set()
-            failed_hashes = dict()
-            saved_files = set()
-
-            expire_time = 60 * 10  # ten minutes
-
-            # ping redis until all the sets are finished
-            while all_hashes.symmetric_difference(finished_hashes):
-                for h in all_hashes:
-                    if h in finished_hashes:
-                        continue
-
-                    status = self.redis.hget(h, 'status')
-
-                    if status == 'failed':
-                        reason = self.redis.hget(h, 'reason')
-                        # one of the hashes failed to process
-                        self.logger.error('Failed to process hash `%s`: %s',
-                                          h, reason)
-                        failed_hashes[h] = reason
-                        finished_hashes.add(h)
-                        self.redis.expire(h, expire_time)
-
-                    elif status == self.final_status:
-                        # one of our hashes is done!
-                        fname = self.redis.hget(h, 'output_file_name')
-                        local_fname = self.storage.download(fname, tempdir)
-                        self.logger.info('Saved file: %s', local_fname)
-                        if zipfile.is_zipfile(local_fname):
-                            image_files = utils.get_image_files_from_dir(
-                                local_fname, tempdir)
-                        else:
-                            image_files = [local_fname]
-
-                        for imfile in image_files:
-                            saved_files.add(imfile)
-                        finished_hashes.add(h)
-                        self.redis.expire(h, expire_time)
-
-            if failed_hashes:
-                self.logger.warning('Failed to process hashes: %s',
-                                    json.dumps(failed_hashes, indent=4))
+                self.redis.expire(key, expire_time)
+            # zip up all saved results
 
             zip_file = utils.zip_files(saved_files, tempdir)
 
             # Upload the zip file to cloud storage bucket
-            uploaded_file_path, output_url = self.storage.upload(zip_file)
-            self.logger.debug('Uploaded output to: `%s`', output_url)
+            path, url = self.storage.upload(zip_file)
+            self.logger.debug('Uploaded output to: `%s`', url)
+            return path, url
+
+    def _parse_failures(self, failed_children, expire_time=3600):
+        failed_hashes = {}
+        for key in failed_children:
+            if not key:
+                continue
+            reason = self.redis.hget(key, 'reason')
+            # one of the hashes failed to process
+            self.logger.error('Failed to process hash `%s`: %s',
+                              key, reason)
+            failed_hashes[key] = reason
+            self.redis.expire(key, expire_time)
+
+        if failed_hashes:
+            self.logger.warning('Failed to process hashes: %s',
+                                json.dumps(failed_hashes, indent=4))
+
+        # check python2 vs python3
+        if hasattr(urllib, 'parse'):
+            url_encode = urllib.parse.urlencode  # pylint: disable=E1101
+        else:
+            url_encode = urllib.urlencode  # pylint: disable=E1101
+
+        return url_encode(failed_hashes)
+
+    def _consume(self, redis_hash):
+        start = timeit.default_timer()
+        hvals = self.redis.hgetall(redis_hash)
+        self.logger.debug('Found hash to process `%s` with status `%s`.',
+                          redis_hash, hvals.get('status'))
+
+        key_separator = ','  # char to separate child keys in Redis
+        expire_time = 60 * 10  # expire finished child keys in ten minutes
+
+        # update without changing status, just to refresh timestamp
+        self.update_key(redis_hash, {'status': hvals.get('status')})
+
+        if hvals.get('status') == 'new':
+            # download the zip file, upload the contents, and enter into Redis
+            all_hashes = self._upload_archived_images(hvals)
+            self.logger.info('Uploaded %s child keys for key `%s`. Waiting for'
+                             ' ImageConsumers.', len(all_hashes), redis_hash)
+
+            # Now all images have been uploaded with new redis hashes
+            # Update Redis with child keys and put item back in queue
+            self.update_key(redis_hash, {
+                'status': 'waiting',
+                'children': key_separator.join(all_hashes)
+            })
+
+        elif hvals.get('status') == 'waiting':
+            # this key was previously processed by a ZipConsumer
+            # check to see which child keys have been processed
+            children = set(hvals.get('children', '').split(key_separator))
+            done = set(hvals.get('children:done', '').split(key_separator))
+            failed = set(hvals.get('children:failed', '').split(key_separator))
+
+            # get keys that have not yet reached a completed status
+            remaining_children = children - done - failed
+            for child in remaining_children:
+                status = self.redis.hget(child, 'status')
+                if status == 'failed':
+                    failed.add(child)
+                elif status == self.final_status:
+                    done.add(child)
+
+            remaining_children = children - done - failed
+
+            self.logger.info('Key `%s` has %s children waiting for processing',
+                             redis_hash, len(remaining_children))
+
+            # if there are no remaining children, update status to cleanup
+            self.update_key(redis_hash, {
+                'status': 'cleanup' if not remaining_children else 'waiting',
+                'children:done': key_separator.join(d for d in done if d),
+                'children:failed': key_separator.join(f for f in failed if f),
+            })
+
+        elif hvals.get('status') == 'cleanup':
+            # clean up children with status `done` and `failed`
+            children = set(hvals.get('children', '').split(key_separator))
+            done = set(hvals.get('children:done', '').split(key_separator))
+            failed = set(hvals.get('children:failed', '').split(key_separator))
+
+            output_file_name, output_url = self._upload_finished_children(
+                done, expire_time)
 
-            # check python2 vs python3
-            url = urllib.parse.urlencode if hasattr(urllib, 'parse') else urllib.urlencode
+            failures = self._parse_failures(failed, expire_time)
 
             # Update redis with the results
-            self.update_status(redis_hash, self.final_status, {
-                'identity_output': self.hostname,
+            self.update_key(redis_hash, {
+                'status': self.final_status,
                 'finished_at': self.get_current_timestamp(),
                 'output_url': output_url,
-                'failures': url(failed_hashes),
-                'output_file_name': uploaded_file_path
+                'failures': failures,
+                'output_file_name': output_file_name
             })
 
             self.logger.info('Processed all %s images of zipfile `%s` in %s',
-                             len(all_hashes), hvals['input_file_name'],
+                             len(children), hvals.get('input_file_name'),
                              timeit.default_timer() - start)
diff --git a/redis_consumer/consumers_test.py b/redis_consumer/consumers_test.py
index 55a87030..751bc460 100644
--- a/redis_consumer/consumers_test.py
+++ b/redis_consumer/consumers_test.py
@@ -30,6 +30,8 @@
 import os
 import copy
+import math
+import random
 
 import redis
 import numpy as np
 
@@ -39,6 +41,7 @@
 
 from redis_consumer import consumers
 from redis_consumer import utils
+from redis_consumer import settings
 
 
 def _get_image(img_h=300, img_w=300):
@@ -55,12 +58,12 @@ def __init__(self, items=[], prefix='predict', status='new'):
         self.prefix = '/'.join(x for x in prefix.split('/') if x)
         self.status = status
         self.keys = [
-            '{}_{}_{}'.format(self.prefix, self.status, 'x.tiff'),
-            '{}_{}_{}'.format(self.prefix, 'other', 'x.zip'),
-            '{}_{}_{}'.format('other', self.status, 'x.TIFF'),
-            '{}_{}_{}'.format(self.prefix, self.status, 'x.ZIP'),
-            '{}_{}_{}'.format(self.prefix, 'other', 'x.tiff'),
-            '{}_{}_{}'.format('other', self.status, 'x.zip'),
+            '{}:{}:{}'.format(self.prefix, 'x.tiff', self.status),
+            '{}:{}:{}'.format(self.prefix, 'x.zip', 'other'),
+            '{}:{}:{}'.format('other', 'x.TIFF', self.status),
+            '{}:{}:{}'.format(self.prefix, 'x.ZIP', self.status),
+            '{}:{}:{}'.format(self.prefix, 'x.tiff', 'other'),
+            '{}:{}:{}'.format('other', 'x.zip', self.status),
         ]
 
     def rpoplpush(self, src, dst):
@@ -73,10 +76,17 @@ def rpoplpush(self, src, dst):
 
     def lpush(self, name, *values):
         self.work_queue = list(values) + self.work_queue
-        return len(values)
+        return len(self.work_queue)
 
     def lrem(self, name, count, value):
         self.processing_queue.remove(value)
+        return count
+
+    def llen(self, queue):
+        if queue.startswith('processing'):
+            return len(self.processing_queue)
+        else:
+            return len(self.work_queue)
 
     def scan_iter(self, match=None, count=None):
         if match:
@@ -85,7 +95,7 @@ def scan_iter(self, match=None, count=None):
 
     def expected_keys(self, suffix=None):
         for k in self.keys:
-            v = k.split('_')
+            v = k.split(':')
             if v[0] == self.prefix:
                 if v[1] == self.status:
                     if suffix:
@@ -102,13 +112,15 @@ def expire(self, name, time):  # pylint: disable=W0613
 
     def hget(self, rhash, field):
         if field == 'status':
-            return rhash.split('_')[1]
+            return rhash.split(':')[-1]
         elif field == 'file_name':
-            return rhash.split('_')[-1]
+            return rhash.split(':')[1]
         elif field == 'input_file_name':
-            return rhash.split('_')[-1]
+            return rhash.split(':')[1]
         elif field == 'output_file_name':
-            return rhash.split('_')[-1]
+            return rhash.split(':')[1]
+        elif field == 'reason':
+            return 'reason'
         return False
 
     def hset(self, rhash, status, value):  # pylint: disable=W0613
@@ -122,9 +134,13 @@ def hgetall(self, rhash):  # pylint: disable=W0613
             'cuts': '0',
             'postprocess_function': '',
             'preprocess_function': '',
-            'file_name': rhash.split('_')[-1],
-            'input_file_name': rhash.split('_')[-1],
-            'output_file_name': rhash.split('_')[-1]
+            'file_name': rhash.split(':')[1],
+            'input_file_name': rhash.split(':')[1],
+            'output_file_name': rhash.split(':')[1],
+            'status': rhash.split(':')[-1],
+            'children': 'predict:1.tiff:done,predict:2.tiff:failed,predict:3.tiff:new',
+            'children:done': 'predict:4.tiff:done,predict:5.tiff:done',
+            'children:failed': 'predict:6.tiff:failed,predict:7.tiff:failed',
         }
 
 
@@ -173,10 +189,10 @@ def test_get_redis_hash(self):
         rhash = consumer.get_redis_hash()
 
         assert rhash == items[0]
-        assert redis_client.work_queue == items[1:]
-        assert redis_client.processing_queue == items[0:1]
+        # assert redis_client.work_queue == items[1:]
+        # assert redis_client.processing_queue == items[0:1]
 
-    def test_update_status(self):
+    def test_update_key(self):
         global _redis_values
         _redis_values = None
 
@@ -187,7 +203,8 @@ def hmset(self, _, hvals):
 
         consumer = consumers.Consumer(_DummyRedis(), None, 'q')
         status = 'updated_status'
-        consumer.update_status('redis-hash', status, {
+        consumer.update_key('redis-hash', {
+            'status': status,
             'new_field': True
         })
         assert isinstance(_redis_values, dict)
@@ -196,7 +213,7 @@ def hmset(self, _, hvals):
         assert _redis_values.get('new_field') is True
 
         with pytest.raises(ValueError):
-            consumer.update_status('redis-hash', status, 'data')
+            consumer.update_key('redis-hash', 'data')
 
     def test_handle_error(self):
         global _redis_values
@@ -215,7 +232,7 @@ def hmset(self, _, hvals):
         assert _redis_values.get('status') == 'failed'
 
     def test_consume(self):
-        items = ['item%s' % x for x in range(1, 4)]
+        items = ['{}:{}:{}.tiff'.format('predict', 'new', x) for x in range(1, 4)]
         N = 1  # using a queue, only one key is processed per consume()
         consumer = consumers.Consumer(DummyRedis(items), DummyStorage(), 'q')
 
@@ -237,10 +254,34 @@ def F(*_):
         consumer.consume()
         assert _processed == N + 1
 
+        # empty redis queue
+        consumer.get_redis_hash = lambda: None
+        settings.EMPTY_QUEUE_TIMEOUT = 0.1  # don't sleep too long
+        consumer.consume()
+
+        # failed and done statuses call lrem
+        def lrem(key, count, value):
+            global _processed
+            _processed = True
+
+        _processed = False
+        redis_client = DummyRedis(items)
+        redis_client.lrem = lrem
+        consumer = consumers.Consumer(redis_client, DummyStorage(), 'q')
+        consumer.get_redis_hash = lambda: 'predict:f.tiff:failed'
+        consumer.consume()
+        assert _processed is True
+
+        _processed = False
+        consumer.get_redis_hash = lambda: 'predict:f.tiff:{status}'.format(
+            status=consumer.final_status)
+        consumer.consume()
+        assert _processed is True
+
     def test__consume(self):
         with np.testing.assert_raises(NotImplementedError):
             consumer = consumers.Consumer(None, None, 'q')
-            consumer._consume('hash')
+            consumer._consume('predict:new:hash.tiff')
 
 
 class TestImageFileConsumer(object):
@@ -288,7 +329,7 @@ def _handle_error(err, rhash):  # pylint: disable=W0613
         def grpc_image_multi(data, *args, **kwargs):  # pylint: disable=W0613
             return np.array(tuple(list(data.shape) + [2]))
 
-        dummyhash = '{}_test.tiff'.format(prefix)
+        dummyhash = '{}:test.tiff:{}'.format(prefix, status)
 
         # consumer._handle_error = _handle_error
         consumer.grpc_image = grpc_image_multi
@@ -340,41 +381,72 @@ def test__upload_archived_images(self):
         hsh = consumer._upload_archived_images({'input_file_name': 'test.zip'})
         assert len(hsh) == N
 
+    def test__upload_finished_children(self):
+        finished_children = ['predict:1.tiff', 'predict:2.zip', '']
+        N = 3
+        items = ['item%s' % x for x in range(1, N + 1)]
+        redis_client = DummyRedis(items)
+        storage = DummyStorage(num=N)
+        consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
+        path, url = consumer._upload_finished_children(finished_children, 0)
+        assert path and url
+
+    def test__parse_failures(self):
+        N = 3
+        items = ['item%s' % x for x in range(1, N + 1)]
+        redis_client = DummyRedis(items)
+        storage = DummyStorage(num=N)
+        consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
+
+        # no failures
+        failed_children = ''
+        parsed = consumer._parse_failures(failed_children)
+        assert parsed == ''
+
+        failed_children = ['item1', 'item2', '']
+        parsed = consumer._parse_failures(failed_children)
+        assert 'item1=reason' in parsed and 'item2=reason' in parsed
+
     def test__consume(self):
         N = 3
         prefix = 'predict'
         items = ['item%s' % x for x in range(1, 4)]
-        _redis = DummyRedis(items)
         redis_client = DummyRedis(items)
         storage = DummyStorage(num=N)
 
-        # test `status` = "done"
-        hget = lambda h, k: 'done' if k == 'status' else _redis.hget(h, k)
-        redis_client.hget = hget
+        # test `status` = "new"
+        status = 'new'
         consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
-        dummyhash = '{}_test.zip'.format(prefix)
+        consumer._upload_archived_images = lambda x: items
+        dummyhash = '{queue}:{fname}.zip:{status}'.format(
+            queue=prefix, status=status, fname=status)
         consumer._consume(dummyhash)
 
-        # test `status` = "failed"
-        hget = lambda h, k: 'failed' if k == 'status' else _redis.hget(h, k)
-        redis_client.hget = hget
+        # test `status` = "waiting"
+        status = 'waiting'
         consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
-        dummyhash = '{}_test.zip'.format(prefix)
+        dummyhash = '{queue}:{fname}.zip:{status}'.format(
+            queue=prefix, status=status, fname=status)
         consumer._consume(dummyhash)
 
-        # test mixed `status` = "waiting" and "done"
-        global counter
-        counter = 0
+        # test `status` = "cleanup"
+        status = 'cleanup'
+        consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
+        consumer._upload_finished_children = lambda x, y: (x, y)
+        dummyhash = '{queue}:{fname}.zip:{status}'.format(
+            queue=prefix, status=status, fname=status)
+        consumer._consume(dummyhash)
 
-        def hget_wait(h, k):
-            if k == 'status':
-                global counter
-                status = 'waiting' if counter % 2 == 0 else 'done'
-                counter += 1
-                return status
-            return _redis.hget(h, k)
+        # test `status` = "done"
+        status = 'done'
+        consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
+        dummyhash = '{queue}:{fname}.zip:{status}'.format(
+            queue=prefix, status=status, fname=status)
+        consumer._consume(dummyhash)
 
-        redis_client.hget = hget_wait
+        # test `status` = "failed"
+        status = 'failed'
         consumer = consumers.ZipFileConsumer(redis_client, storage, 'q')
-        dummyhash = '{}_test.zip'.format(prefix)
+        dummyhash = '{queue}:{fname}.zip:{status}'.format(
+            queue=prefix, status=status, fname=status)
         consumer._consume(dummyhash)
diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py
index 0a1743f7..a41b934c 100644
--- a/redis_consumer/settings.py
+++ b/redis_consumer/settings.py
@@ -62,7 +62,9 @@
 # gRPC API timeout in seconds (scales with `cuts`)
 GRPC_TIMEOUT = config('GRPC_TIMEOUT', default=30, cast=int)
 
+# timeout/backoff wait time in seconds
 REDIS_TIMEOUT = config('REDIS_TIMEOUT', default=3, cast=int)
+EMPTY_QUEUE_TIMEOUT = config('EMPTY_QUEUE_TIMEOUT', default=5, cast=int)
 
 # Status of hashes marked for prediction
 STATUS = config('STATUS', default='new')
diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py
index e1d58a4b..2a2a6639 100644
--- a/redis_consumer/utils.py
+++ b/redis_consumer/utils.py
@@ -160,7 +160,7 @@ def iter_image_archive(zip_path, destination):
     Returns:
         Iterator of all image paths in extracted archive
     """
-    archive = zipfile.ZipFile(zip_path, 'r')
+    archive = zipfile.ZipFile(zip_path, 'r', allowZip64=True)
     is_valid = lambda x: os.path.splitext(x)[1] and '__MACOSX' not in x
     for info in archive.infolist():
        extracted = archive.extract(info, path=destination)
@@ -298,7 +298,7 @@ def zip_files(files, dest=None, prefix=None):
     try:
         logger.debug('Saving %s files to %s', len(files), filepath)
 
-        with zipfile.ZipFile(filepath, 'w') as zip_file:
+        with zipfile.ZipFile(filepath, 'w', allowZip64=True) as zip_file:
             for f in files:  # writing each file one by one
                 name = f.replace(dest, '')
                 name = name[1:] if name.startswith(os.path.sep) else name