Skip to content

Commit

Permalink
Merge 861ce97 into fe4025d
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed Apr 24, 2020
2 parents fe4025d + 861ce97 commit 7408332
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 44 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ The consumer is configured using environment variables. Please find a table of a
| `METADATA_EXPIRE_TIME` | Expire cached model metadata after this many seconds. | `30` |
| `TF_HOST` | The IP address or hostname of TensorFlow Serving. | `"tf-serving"` |
| `TF_PORT` | The port used to connect to TensorFlow Serving. | `8500` |
| `TF_TENSOR_NAME` | Name of input tensor for the exported model. | `"image"` |
| `GRPC_TIMEOUT` | Timeout for gRPC API requests, in seconds. | `30` |
| `GRPC_BACKOFF` | Time to wait before retrying a gRPC API request. | `3` |
| `MAX_RETRY` | Maximum number of retries for a failed TensorFlow Serving request. | `5` |
Expand Down
73 changes: 52 additions & 21 deletions redis_consumer/consumers/base_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _get_predict_client(self, model_name, model_version):
return client

def grpc_image(self, img, model_name, model_version, model_shape,
in_tensor_dtype='DT_FLOAT'):
in_tensor_name='image', in_tensor_dtype='DT_FLOAT'):

in_tensor_dtype = str(in_tensor_dtype).upper()

Expand All @@ -268,7 +268,7 @@ def grpc_image(self, img, model_name, model_version, model_shape,
# but the model rejects the type, wants "int" or "long"
img = img.astype('int')

req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME,
req_data = [{'in_tensor_name': in_tensor_name,
'in_tensor_dtype': in_tensor_dtype,
'data': img}]

Expand All @@ -295,6 +295,27 @@ def grpc_image(self, img, model_name, model_version, model_shape,
0, finished)
return results

def parse_model_metadata(self, metadata):
    """Parse the metadata response and return list of input metadata.

    Args:
        metadata (dict): model metadata response. Maps each input tensor
            name to a dict holding a ``dtype`` entry and a ``tensorShape``
            entry whose ``dim`` list has one ``{'size': ...}`` dict per
            dimension.

    Returns:
        list: List of metadata objects for each defined input. Each object
            is a dict with the keys ``in_tensor_name``, ``in_tensor_dtype``
            and ``in_tensor_shape`` (the dimension sizes joined by commas).

    Raises:
        KeyError: if an input entry is missing ``tensorShape``, ``dim``,
            ``size`` or ``dtype``.
    """
    # TODO: handle multiple inputs in a general way.
    all_metadata = []
    for name, spec in metadata.items():
        # Coerce every size with str() so that integer sizes are joined
        # the same way as string sizes instead of raising a TypeError.
        shape = ','.join(
            str(dim['size']) for dim in spec['tensorShape']['dim'])
        all_metadata.append({
            'in_tensor_name': name,
            'in_tensor_dtype': spec['dtype'],
            'in_tensor_shape': shape,
        })
    return all_metadata

def get_model_metadata(self, model_name, model_version):
"""Check Redis for saved model metadata or get from TensorFlow Serving.
Expand All @@ -309,12 +330,11 @@ def get_model_metadata(self, model_name, model_version):
model = '{}:{}'.format(model_name, model_version)
self.logger.debug('Getting model metadata for model %s.', model)

fields = ['in_tensor_dtype', 'in_tensor_shape']
response = self.redis.hmget(model, *fields)
response = self.redis.hget(model, 'metadata')

if all(response):
if response:
self.logger.debug('Got cached metadata for model %s.', model)
return dict(zip(fields, response))
return self.parse_model_metadata(json.loads(response))

# No response! The key was expired. Get from TFS and update it.
start = timeit.default_timer()
Expand All @@ -323,20 +343,15 @@ def get_model_metadata(self, model_name, model_version):

try:
inputs = model_metadata['metadata']['signature_def']['signatureDef']
inputs = inputs['serving_default']['inputs'][settings.TF_TENSOR_NAME]

dtype = inputs['dtype']
shape = ','.join([d['size'] for d in inputs['tensorShape']['dim']])

parsed_metadata = dict(zip(fields, [dtype, shape]))
inputs = inputs['serving_default']['inputs']

finished = timeit.default_timer() - start
self.logger.debug('Got model metadata for %s in %s seconds.',
model, finished)

self.redis.hmset(model, parsed_metadata)
self.redis.hset(model, 'metadata', json.dumps(inputs))
self.redis.expire(model, settings.METADATA_EXPIRE_TIME)
return parsed_metadata
return self.parse_model_metadata(inputs)
except (KeyError, IndexError) as err:
self.logger.error('Malformed metadata: %s', model_metadata)
raise err
Expand All @@ -346,6 +361,7 @@ def _predict_big_image(self,
model_name,
model_version,
model_shape,
model_input_name='image',
model_dtype='DT_FLOAT',
untile=True,
stride_ratio=0.75):
Expand All @@ -357,6 +373,7 @@ def _predict_big_image(self,
model_version (str): model version to query.
model_shape (tuple): shape of the model's expected input.
model_dtype (str): dtype of the model's input array.
model_input_name (str): name of the model's input array.
untile (bool): untiles results back to image shape if True.
stride_ratio (float): amount to overlap between tiles, (0, 1].
Expand Down Expand Up @@ -385,8 +402,10 @@ def _predict_big_image(self,
results = []
for t in range(0, tiles.shape[0], batch_size):
batch = tiles[t:t + batch_size]
output = self.grpc_image(batch, model_name, model_version,
model_shape, in_tensor_dtype=model_dtype)
output = self.grpc_image(
batch, model_name, model_version, model_shape,
in_tensor_name=model_input_name,
in_tensor_dtype=model_dtype)

if not isinstance(output, list):
output = [output]
Expand All @@ -411,6 +430,7 @@ def _predict_small_image(self,
model_name,
model_version,
model_shape,
model_input_name='image',
model_dtype='DT_FLOAT'):
"""Pad an image that is too small for the model, and unpad the results.
Expand All @@ -420,6 +440,7 @@ def _predict_small_image(self,
model_name (str): hosted model to send image data.
model_version (str): model version to query.
model_shape (tuple): shape of the model's expected input.
model_input_name (str): name of the model's input array.
model_dtype (str): dtype of the model's input array.
Returns:
Expand Down Expand Up @@ -447,7 +468,8 @@ def _predict_small_image(self,

padded_img = np.pad(image, pad_width, 'reflect')
image = self.grpc_image(padded_img, model_name, model_version,
model_shape, in_tensor_dtype=model_dtype)
model_shape, in_tensor_name=model_input_name,
in_tensor_dtype=model_dtype)

image = [image] if not isinstance(image, list) else image

Expand All @@ -467,6 +489,14 @@ def predict(self, image, model_name, model_version, untile=True):
start = timeit.default_timer()
model_metadata = self.get_model_metadata(model_name, model_version)

# TODO: generalize for more than a single input.
if len(model_metadata) > 1:
raise ValueError('Model %s:%s has %s required inputs but was only '
'given %s inputs.', model_name, model_version,
len(model_metadata), len(image))
model_metadata = model_metadata[0]

model_input_name = model_metadata['in_tensor_name']
model_dtype = model_metadata['in_tensor_dtype']

model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')]
Expand All @@ -493,17 +523,18 @@ def predict(self, image, model_name, model_version, untile=True):
image.shape[image.ndim - 2] < size_y):
# image is too small for the model, pad the image.
image = self._predict_small_image(image, model_name, model_version,
model_shape, model_dtype)
model_shape, model_input_name,
model_dtype)
elif (image.shape[image.ndim - 3] > size_x or
image.shape[image.ndim - 2] > size_y):
# image is too big for the model, multiple images are tiled.
image = self._predict_big_image(image, model_name, model_version,
model_shape, model_dtype,
untile=untile)
model_shape, model_input_name,
model_dtype, untile=untile)
else:
# image size is perfect, just send it to the model
image = self.grpc_image(image, model_name, model_version,
model_shape, model_dtype)
model_shape, model_input_name, model_dtype)

if isinstance(image, list):
output_shapes = [i.shape for i in image]
Expand Down
95 changes: 76 additions & 19 deletions redis_consumer/consumers/base_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import os

import numpy as np
from skimage.external import tifffile as tiff
Expand Down Expand Up @@ -357,23 +358,40 @@ def _get_predict_client(model_name, model_version):
consumer._get_predict_client = _get_predict_client

img = np.zeros((1, 32, 32, 3))
out = consumer.grpc_image(img, 'model', 1, model_shape, 'DT_HALF')
out = consumer.grpc_image(img, 'model', 1, model_shape, 'i', 'DT_HALF')

assert img.shape == out.shape
assert img.sum() == out.sum()

img = np.zeros((32, 32, 3))
consumer._redis_hash = 'not none'
out = consumer.grpc_image(img, 'model', 1, model_shape, 'i', 'DT_HALF')

assert (1,) + img.shape == out.shape
assert img.sum() == out.sum()

def test_get_model_metadata(self):
redis_client = DummyRedis([])
model_shape = (-1, 216, 216, 1)
model_dtype = 'DT_FLOAT'
model_input_name = 'input_1'

def hget_success(key, *others):
metadata = {
model_input_name: {
'dtype': model_dtype,
'tensorShape': {
'dim': [
{'size': str(x)}
for x in model_shape
]
}
}
}
return json.dumps(metadata)

def hmget_success(key, *others):
shape = ','.join(str(s) for s in model_shape)
dtype = 'DT_FLOAT'
return dtype, shape

def hmget_fail(key, *others):
return [None] * len(others)
def hget_fail(key, *others):
return None

def _get_predict_client(model_name, model_version):
return Bunch(get_model_metadata=lambda: {
Expand All @@ -382,7 +400,39 @@ def _get_predict_client(model_name, model_version):
'signatureDef': {
'serving_default': {
'inputs': {
settings.TF_TENSOR_NAME: {
model_input_name: {
'dtype': model_dtype,
'tensorShape': {
'dim': [
{'size': str(x)}
for x in model_shape
]
}
}
}
}
}
}
}
})

def _get_predict_client_multi(model_name, model_version):
return Bunch(get_model_metadata=lambda: {
'metadata': {
'signature_def': {
'signatureDef': {
'serving_default': {
'inputs': {
model_input_name: {
'dtype': model_dtype,
'tensorShape': {
'dim': [
{'size': str(x)}
for x in model_shape
]
}
},
'{}_2'.format(model_input_name): {
'dtype': model_dtype,
'tensorShape': {
'dim': [
Expand All @@ -401,24 +451,30 @@ def _get_predict_client(model_name, model_version):
def _get_bad_predict_client(model_name, model_version):
return Bunch(get_model_metadata=lambda: dict())

redis_client.hmget = hmget_success
# test cached input
redis_client.hget = hget_success
consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q')
consumer._get_predict_client = _get_predict_client
metadata = consumer.get_model_metadata('model', 1)

assert metadata['in_tensor_dtype'] == 'DT_FLOAT'
assert metadata['in_tensor_shape'] == ','.join(str(x) for x in model_shape)
for m in metadata:
assert m['in_tensor_dtype'] == model_dtype
assert m['in_tensor_name'] == model_input_name
assert m['in_tensor_shape'] == ','.join(str(x) for x in model_shape)

redis_client.hmget = hmget_fail
# test stale cache
redis_client.hget = hget_fail
consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q')
consumer._get_predict_client = _get_predict_client
metadata = consumer.get_model_metadata('model', 1)

assert metadata['in_tensor_dtype'] == 'DT_FLOAT'
assert metadata['in_tensor_shape'] == ','.join(str(x) for x in model_shape)
for m in metadata:
assert m['in_tensor_dtype'] == model_dtype
assert m['in_tensor_name'] == model_input_name
assert m['in_tensor_shape'] == ','.join(str(x) for x in model_shape)

with pytest.raises(KeyError):
redis_client.hmget = hmget_fail
redis_client.hget = hget_fail
consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q')
consumer._get_predict_client = _get_bad_predict_client
consumer.get_model_metadata('model', 1)
Expand Down Expand Up @@ -449,10 +505,11 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613

x = np.random.random(image_shape)
consumer.grpc_image = grpc_func
consumer.get_model_metadata = lambda x, y: {
consumer.get_model_metadata = lambda x, y: [{
'in_tensor_name': 'image',
'in_tensor_dtype': 'DT_HALF',
'in_tensor_shape': ','.join(str(s) for s in model_shape),
}
}]

consumer.predict(x, model_name='modelname', model_version=0,
untile=untile)
Expand Down
5 changes: 3 additions & 2 deletions redis_consumer/consumers/image_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,11 @@ def detect_label(_):
def make_model_metadata_of_size(model_shape=(-1, 256, 256, 1)):

def get_model_metadata(model_name, model_version):
return {
return [{
'in_tensor_name': 'image',
'in_tensor_dtype': 'DT_FLOAT',
'in_tensor_shape': ','.join(str(s) for s in model_shape),
}
}]

return get_model_metadata

Expand Down
1 change: 0 additions & 1 deletion redis_consumer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _strip(x):
# TensorFlow Serving client connection
TF_HOST = config('TF_HOST', default='tf-serving')
TF_PORT = config('TF_PORT', default=8500, cast=int)
TF_TENSOR_NAME = config('TF_TENSOR_NAME', default='image')
# maximum batch allowed by TensorFlow Serving
TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=128, cast=int)
# minimum expected model size, dynamically change batches proportionately.
Expand Down

0 comments on commit 7408332

Please sign in to comment.