From 7e1cee132451b3b6b670ae3b55208a64ec97cd3f Mon Sep 17 00:00:00 2001 From: willgraf <7930703+willgraf@users.noreply.github.com> Date: Tue, 1 Jun 2021 17:27:07 -0700 Subject: [PATCH] Add new `detect_dimension_order` to detect the dim ordering string. (#172) * Add new detect_dimension_order method. * Call detect_dimension_order and save to redis in both mesmer and segmentation. Saved as Comma-separated string as "dim_order". * add dim_order to Redis for Segmentation consumers too. --- redis_consumer/consumers/base_consumer.py | 25 +++++++++++++ .../consumers/base_consumer_test.py | 37 +++++++++++++++++++ redis_consumer/consumers/mesmer_consumer.py | 6 +++ .../consumers/mesmer_consumer_test.py | 1 + .../consumers/segmentation_consumer.py | 6 +++ .../consumers/segmentation_consumer_test.py | 1 + 6 files changed, 76 insertions(+) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index adb3bcbd..28692481 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -250,6 +250,31 @@ def download_image(self, image_path): image = utils.get_image(fname) return image + def detect_dimension_order(self, image, model_name, model_version): + """Detect the dimension ordering of the input image from metadata""" + # TODO: there is overlap with ``validate_model_input``. + # Should we combine the logic? metadata gets cached... + model_metadata = self.get_model_metadata(model_name, model_version) + parse_shape = lambda x: tuple(int(y) for y in x.split(',')) + shapes = [parse_shape(x['in_tensor_shape']) for x in model_metadata] + # cast as image to match with the list of shapes. + image = [image] if not isinstance(image, list) else image + + dimension_order = [] + + default_order = 'ZYX' + + for img, shape in zip(image, shapes): + rank = len(shape) # expects a batch dimension + # subtract 2 to remove the batch and channel axes + order = default_order[-(rank - 2):] + # detect channel axis + channel_axis = img.shape[1:].index(min(img.shape[1:])) + 1 + # append/prepend C to dimension order string + fmt_order = 'C{}' if channel_axis != rank - 1 else '{}C' + dimension_order.append(fmt_order.format(order)) + return dimension_order + def validate_model_input(self, image, model_name, model_version, channels=None): """Validate that the input image meets the workflow requirements.""" model_metadata = self.get_model_metadata(model_name, model_version) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index 05e0fc99..1c6eb555 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -237,6 +237,43 @@ def test_download_image(self, redis_client): assert isinstance(image, np.ndarray) assert not os.path.exists('test.tif') + def test_detect_dimension_order(self, mocker, redis_client): + storage = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q') + + model_input_shape = (-1, 32, 32, 1) + + mocked_metadata = make_model_metadata_of_size(model_input_shape) + mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata) + + input_pairs = [ + ((1, 32, 32, 1), 'YXC'), # channels last + ((1, 1, 32, 32), 'CYX'), # channels first + ] + + for shape, expected_dim_order in input_pairs: + # check channels last + img = np.ones(shape) + dim_order = consumer.detect_dimension_order(img, 'model', '1') + np.testing.assert_array_equal(dim_order, expected_dim_order) + + # Test again for 3D models + model_input_shape = (-1, 5, 32, 32, 1) + + mocked_metadata = make_model_metadata_of_size(model_input_shape) + mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata) + + input_pairs = [ + ((1, 10, 32, 32, 1), 'ZYXC'), # channels last + ((1, 1, 10, 32, 32), 'CZYX'), # channels first + ] + + for shape, expected_dim_order in input_pairs: + # check channels last + img = np.ones(shape) + dim_order = consumer.detect_dimension_order(img, 'model', '1') + np.testing.assert_array_equal(dim_order, expected_dim_order) + def test_validate_model_input(self, mocker, redis_client): storage = DummyStorage() consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q') diff --git a/redis_consumer/consumers/mesmer_consumer.py b/redis_consumer/consumers/mesmer_consumer.py index 4112545a..144d086d 100644 --- a/redis_consumer/consumers/mesmer_consumer.py +++ b/redis_consumer/consumers/mesmer_consumer.py @@ -109,6 +109,12 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) + # detect dimension order and add to redis + dim_order = self.detect_dimension_order(image, model_name, model_version) + self.update_key(redis_hash, { + 'dim_order': ','.join(dim_order) + }) + # Validate input image if hvals.get('channels'): channels = [int(c) for c in hvals.get('channels').split(',')] diff --git a/redis_consumer/consumers/mesmer_consumer_test.py b/redis_consumer/consumers/mesmer_consumer_test.py index 97c5955c..b8ba8c07 100644 --- a/redis_consumer/consumers/mesmer_consumer_test.py +++ b/redis_consumer/consumers/mesmer_consumer_test.py @@ -112,6 +112,7 @@ def test__consume(self, mocker, redis_client): mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app) mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1) mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: x[0]) + mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC') test_hash = 'some hash' diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index 7807d9ff..e77112da 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -132,6 +132,12 @@ def _consume(self, redis_hash): model_name, model_version = model.split(':') + # detect dimension order and add to redis + dim_order = self.detect_dimension_order(image, model_name, model_version) + self.update_key(redis_hash, { + 'dim_order': ','.join(dim_order) + }) + # Validate input image image = self.validate_model_input(image, model_name, model_version, channels=channels) diff --git a/redis_consumer/consumers/segmentation_consumer_test.py b/redis_consumer/consumers/segmentation_consumer_test.py index 65346983..a3f41d98 100644 --- a/redis_consumer/consumers/segmentation_consumer_test.py +++ b/redis_consumer/consumers/segmentation_consumer_test.py @@ -142,6 +142,7 @@ def test__consume(self, mocker, redis_client): mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1) mocker.patch.object(consumer, 'get_image_label', lambda *x, **_: 1) mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: True) + mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC') test_hash = 'some hash'