Skip to content

Commit

Permalink
Merge d6a8938 into 33e3955
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed May 10, 2021
2 parents 33e3955 + d6a8938 commit cde27ee
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 0 deletions.
25 changes: 25 additions & 0 deletions redis_consumer/consumers/base_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,31 @@ def download_image(self, image_path):
image = utils.get_image(fname)
return image

def detect_dimension_order(self, image, model_name, model_version):
"""Detect the dimension ordering of the input image from metadata"""
# TODO: there is overlap with ``validate_model_input``.
# Should we combine the logic? metadata gets cached...
model_metadata = self.get_model_metadata(model_name, model_version)
parse_shape = lambda x: tuple(int(y) for y in x.split(','))
shapes = [parse_shape(x['in_tensor_shape']) for x in model_metadata]
# cast as image to match with the list of shapes.
image = [image] if not isinstance(image, list) else image

dimension_order = []

default_order = 'ZYX'

for img, shape in zip(image, shapes):
rank = len(shape) # expects a batch dimension
# subtract 2 to remove the batch and channel axes
order = default_order[-(rank - 2):]
# detect channel axis
channel_axis = img.shape[1:].index(min(img.shape[1:])) + 1
# append/prepend C to dimension order string
fmt_order = 'C{}' if channel_axis != rank - 1 else '{}C'
dimension_order.append(fmt_order.format(order))
return dimension_order

def validate_model_input(self, image, model_name, model_version, channels=None):
"""Validate that the input image meets the workflow requirements."""
model_metadata = self.get_model_metadata(model_name, model_version)
Expand Down
37 changes: 37 additions & 0 deletions redis_consumer/consumers/base_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,43 @@ def test_download_image(self, redis_client):
assert isinstance(image, np.ndarray)
assert not os.path.exists('test.tif')

def test_detect_dimension_order(self, mocker, redis_client):
storage = DummyStorage()
consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q')

model_input_shape = (-1, 32, 32, 1)

mocked_metadata = make_model_metadata_of_size(model_input_shape)
mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata)

input_pairs = [
((1, 32, 32, 1), 'YXC'), # channels last
((1, 1, 32, 32), 'CYX'), # channels first
]

for shape, expected_dim_order in input_pairs:
# check channels last
img = np.ones(shape)
dim_order = consumer.detect_dimension_order(img, 'model', '1')
np.testing.assert_array_equal(dim_order, expected_dim_order)

# Test again for 3D models
model_input_shape = (-1, 5, 32, 32, 1)

mocked_metadata = make_model_metadata_of_size(model_input_shape)
mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata)

input_pairs = [
((1, 10, 32, 32, 1), 'ZYXC'), # channels last
((1, 1, 10, 32, 32), 'CZYX'), # channels first
]

for shape, expected_dim_order in input_pairs:
# check channels last
img = np.ones(shape)
dim_order = consumer.detect_dimension_order(img, 'model', '1')
np.testing.assert_array_equal(dim_order, expected_dim_order)

def test_validate_model_input(self, mocker, redis_client):
storage = DummyStorage()
consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q')
Expand Down
6 changes: 6 additions & 0 deletions redis_consumer/consumers/mesmer_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def _consume(self, redis_hash):
scale = hvals.get('scale', '')
scale = self.get_image_scale(scale, image, redis_hash)

# detect dimension order and add to redis
dim_order = self.detect_dimension_order(image, model_name, model_version)
self.update_key(redis_hash, {
'dim_order': ','.join(dim_order)
})

# Validate input image
if hvals.get('channels'):
channels = [int(c) for c in hvals.get('channels').split(',')]
Expand Down
1 change: 1 addition & 0 deletions redis_consumer/consumers/mesmer_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def test__consume(self, mocker, redis_client):
mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app)
mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1)
mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: x[0])
mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC')

test_hash = 'some hash'

Expand Down
1 change: 1 addition & 0 deletions redis_consumer/consumers/segmentation_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test__consume(self, mocker, redis_client):
mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1)
mocker.patch.object(consumer, 'get_image_label', lambda *x, **_: 1)
mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: True)
mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC')

test_hash = 'some hash'

Expand Down

0 comments on commit cde27ee

Please sign in to comment.