Fixed multiple grpc outputs bug and improved consumer testing (#48)
* remove bug from bad merge

* add helper function for _get_predict_client

* update ImageFileConsumer._consume tests for multiple outputs

* fixed bug in saving multiple outputs (?)

* add simple test for grpc_image success

* update ZipFileConsumer tests

* update redis backoff time to be GRPC_BACKOFF.
willgraf committed Aug 2, 2019
1 parent 3a502f1 commit a741332
Showing 2 changed files with 103 additions and 22 deletions.
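The core fix is in ImageFileConsumer._consume: when the model returns several output tensors, grpc_image hands back a list of arrays, and each array has to be saved on its own instead of passing the whole list to utils.save_numpy_array. Below is a minimal, self-contained sketch of that pattern; save_array is a hypothetical stand-in for utils.save_numpy_array, and the per-output naming is illustrative rather than the repository's actual scheme.

import os
import tempfile

import numpy as np


def save_array(arr, name, output_dir):
    # Hypothetical stand-in for utils.save_numpy_array: write one array
    # to disk and return the list of paths it produced.
    path = os.path.join(output_dir, '{}.npy'.format(name))
    np.save(path, arr)
    return [path]


def save_outputs(image, name, output_dir):
    # `image` is a single array, or a list of arrays when the model has
    # multiple outputs; each output gets saved individually.
    if isinstance(image, list):
        outpaths = []
        for n, out in enumerate(image):
            outpaths.extend(save_array(out, '{}_{}'.format(name, n), output_dir))
    else:
        outpaths = save_array(image, name, output_dir)
    return outpaths


# usage: two model outputs produce two saved files
tempdir = tempfile.mkdtemp()
paths = save_outputs([np.zeros((32, 32)), np.ones((32, 32))], 'test_image', tempdir)
assert len(paths) == 2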
21 changes: 13 additions & 8 deletions redis_consumer/consumers.py
@@ -472,6 +472,14 @@ def iter_cuts(img, cuts, field):
tf_results.shape, timeit.default_timer() - start)
return tf_results

def _get_predict_client(self, model_name, model_version):
t = timeit.default_timer()
hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT)
client = PredictClient(hostname, model_name, int(model_version))
self.logger.debug('Created the PredictClient in %s seconds.',
timeit.default_timer() - t)
return client

def grpc_image(self, img, model_name, model_version):
count = 0
start = timeit.default_timer()
@@ -486,14 +494,12 @@ def grpc_image(self, img, model_name, model_version):
# TODO: seems like should cast to "half"
# but the model rejects the type, wants "int" or "long"
img = img.astype('int')
- hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT)

req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME,
'in_tensor_dtype': floatx,
'data': np.expand_dims(img, axis=0)}]
- t = timeit.default_timer()
- client = PredictClient(hostname, model_name, int(model_version))
- self.logger.debug('Created the PredictClient in %s seconds.',
-                   timeit.default_timer() - t)

+ client = self._get_predict_client(model_name, model_version)

prediction = client.predict(req_data, settings.GRPC_TIMEOUT)
results = [prediction[k] for k in sorted(prediction.keys())
@@ -503,7 +509,6 @@ def grpc_image(self, img, model_name, model_version):
results = results[0]

retrying = False
- results = prediction['prediction']

finished = timeit.default_timer() - start
self.update_key(self._redis_hash, {
@@ -612,7 +617,7 @@ def _consume(self, redis_hash):
outpaths = []
for i in image:
outpaths.extend(utils.save_numpy_array(
- image, name=name, subdir=subdir, output_dir=tempdir))
+ i, name=name, subdir=subdir, output_dir=tempdir))
else:
outpaths = utils.save_numpy_array(
image, name=name, subdir=subdir, output_dir=tempdir)
@@ -728,7 +733,7 @@ def _get_output_file_name(self, key):
' no output_file_name', key, ttl)

self.redis._update_masters_and_slaves()
- time.sleep(3)
+ time.sleep(settings.GRPC_BACKOFF)
else:
break
else:
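The last hunk above swaps the hard-coded time.sleep(3) in ZipFileConsumer._get_output_file_name for settings.GRPC_BACKOFF, making the wait between Redis polls configurable; the new test__get_output_file_name below sets it to zero so the test runs instantly. A rough sketch of that kind of retry-with-backoff loop follows (the general pattern only, not the repository's implementation):

import time


def poll_until_set(fetch, backoff, max_attempts=10):
    # Call fetch() until it returns a truthy value, sleeping `backoff`
    # seconds between attempts; give up after max_attempts tries.
    for _ in range(max_attempts):
        value = fetch()
        if value:
            return value
        time.sleep(backoff)
    raise ValueError('no value after %s attempts' % max_attempts)


# usage: with backoff=0 (as the new test sets GRPC_BACKOFF), polling is instant
result = poll_until_set(lambda: 'output.zip', backoff=0)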
104 changes: 90 additions & 14 deletions redis_consumer/consumers_test.py
@@ -51,6 +51,11 @@ def _get_image(img_h=300, img_w=300):
return img


class Bunch(object):
def __init__(self, **kwds):
self.__dict__.update(kwds)


class DummyRedis(object):
def __init__(self, items=[], prefix='predict', status='new'):
self.work_queue = copy.copy(items)
@@ -374,6 +379,31 @@ def test_process(self):
output = consumer.process(img, 'valid', 'valid')
assert img.shape[1:] == output.shape

def test__get_predict_client(self):
redis_client = DummyRedis([])
consumer = consumers.ImageFileConsumer(redis_client, None, 'q')

with pytest.raises(ValueError):
consumer._get_predict_client('model_name', 'model_version')

client = consumer._get_predict_client('model_name', 1)

def test_grpc_image(self):
redis_client = DummyRedis([])
consumer = consumers.ImageFileConsumer(redis_client, None, 'q')

def _get_predict_client(model_name, model_version):
return Bunch(predict=lambda x, y: {
'prediction': x[0]['data']
})

consumer._get_predict_client = _get_predict_client

img = np.zeros((1, 32, 32, 3))
out = consumer.grpc_image(img, 'f16model', 1)
assert img.shape == out.shape[1:]
assert img.sum() == out.sum()

def test_process_big_image(self):
name = 'model'
version = 0
@@ -393,11 +423,11 @@ def test_process_big_image(self):
np.testing.assert_equal(res, img)

def test__consume(self):
- prefix = 'prefix'
+ prefix = 'predict'
status = 'new'
redis_client = DummyRedis(prefix, status)
storage = DummyStorage()
- consumer = consumers.ImageFileConsumer(redis_client, storage, 'predict')
+ consumer = consumers.ImageFileConsumer(redis_client, storage, prefix)

def _handle_error(err, rhash): # pylint: disable=W0613
raise err
@@ -416,7 +446,7 @@ def grpc_image(data, *args, **kwargs):  # pylint: disable=W0613
return data

# test with cuts > 0
- redis.hgetall = lambda x: {
+ redis_client.hgetall = lambda x: {
'model_name': 'model',
'model_version': '0',
'field': '61',
@@ -427,12 +457,23 @@ def grpc_image(data, *args, **kwargs):  # pylint: disable=W0613
'input_file_name': 'test_image.tiff',
'output_file_name': 'test_image.tiff'
}
- redis.hmset = lambda x, y: True
- consumer = consumers.ImageFileConsumer(redis, storage, 'predict')
+ redis_client.hmset = lambda x, y: True
+ consumer = consumers.ImageFileConsumer(redis_client, storage, prefix)
consumer._handle_error = _handle_error
consumer.grpc_image = grpc_image
consumer._consume(dummyhash)

# test with multiple outputs from model and cuts == 0

def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613
return [data, data]

redis_client = DummyRedis(prefix, status)
consumer = consumers.ImageFileConsumer(redis_client, storage, prefix)
consumer._handle_error = _handle_error
consumer.grpc_image = grpc_image_list
consumer._consume(dummyhash)


class TestZipFileConsumer(object):

@@ -459,7 +500,8 @@ def test__upload_archived_images(self):
storage = DummyStorage(num=N)
consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')
hsh = consumer._upload_archived_images(
- {'input_file_name': 'test.zip'}, 'predict:redis_hash:f.zip')
+ {'input_file_name': 'test.zip', 'children': ''},
+ 'predict:redis_hash:f.zip')
assert len(hsh) == N

def test__upload_finished_children(self):
@@ -473,6 +515,30 @@ def test__upload_finished_children(self):
finished_children, 'predict:redis_hash:f.zip')
assert path and url

def test__get_output_file_name(self):
settings.GRPC_BACKOFF = 0
redis_client = DummyRedis([])
redis_client.ttl = lambda x: -1 # key is missing
redis_client._update_masters_and_slaves = lambda: True

redis_client._redis_master = Bunch(hget=lambda x, y: None)
consumer = consumers.ZipFileConsumer(redis_client, None, 'predict')

with pytest.raises(ValueError):
redis_client.ttl = lambda x: -2 # key is missing
consumer = consumers.ZipFileConsumer(redis_client, None, 'predict')
consumer._get_output_file_name('randomkey')

with pytest.raises(ValueError):
redis_client.ttl = lambda x: 1 # key is expired
consumer = consumers.ZipFileConsumer(redis_client, None, 'predict')
consumer._get_output_file_name('randomkey')

with pytest.raises(ValueError):
redis_client.ttl = lambda x: -1 # key not expired
consumer = consumers.ZipFileConsumer(redis_client, None, 'predict')
consumer._get_output_file_name('randomkey')

def test__parse_failures(self):
N = 3
items = ['item%s' % x for x in range(1, N + 1)]
@@ -489,6 +555,24 @@ def test__parse_failures(self):
parsed = consumer._parse_failures(failed_children)
assert 'item1=reason' in parsed and 'item2=reason' in parsed

def test__cleanup(self):
N = 3
prefix = 'predict'
status = 'waiting'
items = ['item%s' % x for x in range(1, N + 1)]
redis_client = DummyRedis(items)
storage = DummyStorage(num=N)
consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')

children = list('abcdef')
done = ['{}:done'.format(c) for c in children[:3]]
failed = ['{}:failed'.format(c) for c in children[3:]]

key = '{queue}:{fname}.zip:{status}'.format(
queue='prefix', status=status, fname=status)

consumer._cleanup(items[0], children, done, failed)

def test__consume(self):
N = 3
prefix = 'predict'
@@ -511,14 +595,6 @@ def test__consume(self):
queue=prefix, status=status, fname=status)
consumer._consume(dummyhash)

- # test `status` = "cleanup"
- status = 'cleanup'
- consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')
- consumer._upload_finished_children = lambda x, y, z: (x, y)
- dummyhash = '{queue}:{fname}.zip:{status}'.format(
-     queue=prefix, status=status, fname=status)
- consumer._consume(dummyhash)

# test `status` = "done"
status = 'done'
consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')
