From 3bea0aab2fe267299c8a71b3f500498ae177f088 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Tue, 19 Jan 2021 17:41:54 -0800
Subject: [PATCH 01/73] Move is_valid_hash definition to
 TensorFlowServingConsumer

---
 redis_consumer/consumers/base_consumer.py       | 12 +++++++++---
 redis_consumer/consumers/base_consumer_test.py  | 14 ++++++++++++++
 redis_consumer/consumers/image_consumer.py      |  7 -------
 redis_consumer/consumers/image_consumer_test.py | 13 -------------
 redis_consumer/consumers/multiplex_consumer.py  |  7 -------
 .../consumers/multiplex_consumer_test.py        | 13 -------------
 6 files changed, 23 insertions(+), 43 deletions(-)

diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py
index d6260080..feb536b6 100644
--- a/redis_consumer/consumers/base_consumer.py
+++ b/redis_consumer/consumers/base_consumer.py
@@ -148,7 +148,7 @@ def _handle_error(self, err, redis_hash):
 
     def is_valid_hash(self, redis_hash):  # pylint: disable=unused-argument
         """Returns True if the consumer should work on the item"""
-        return True
+        return redis_hash is not None
 
     def get_current_timestamp(self):
         """Helper function, returns ISO formatted UTC timestamp"""
@@ -252,9 +252,15 @@ def __init__(self,
         self._redis_values = dict()
         super(TensorFlowServingConsumer, self).__init__(
             redis_client, storage_client, queue, **kwargs)
+
+    def is_valid_hash(self, redis_hash):
+        """Don't run on zip files"""
+        if redis_hash is None:
+            return False
+
+        fname = str(self.redis.hget(redis_hash, 'input_file_name'))
+        return not fname.lower().endswith('.zip')
 
-    def _consume(self, redis_hash):
-        raise NotImplementedError
 
     def _get_predict_client(self, model_name, model_version):
        """Returns the TensorFlow Serving gRPC client.
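With the definition consolidated here, every TensorFlow-Serving-backed consumer inherits the same zip-file filter, while the plain Consumer base class now only rejects a missing hash. A minimal standalone sketch of the consolidated check, assuming (as the new test below does) a Redis client whose hget returns the last ':'-delimited token of the key as the stored filename; FakeRedis is a hypothetical stand-in, not part of the patch:

    class FakeRedis(object):
        # Hypothetical stand-in for redis.StrictRedis, used only in this sketch.
        def hget(self, rhash, field):
            # Mimic the mocked hget in the tests: treat the final
            # ':'-delimited segment of the hash as the filename.
            return rhash.split(':')[-1]

    def is_valid_hash(redis, redis_hash):
        # Mirrors TensorFlowServingConsumer.is_valid_hash: skip zip archives.
        if redis_hash is None:
            return False
        fname = str(redis.hget(redis_hash, 'input_file_name'))
        return not fname.lower().endswith('.zip')

    redis = FakeRedis()
    assert is_valid_hash(redis, None) is False
    assert is_valid_hash(redis, 'predict:1234567890:file.ZIp') is False
    assert is_valid_hash(redis, 'predict:1234567890:file.png') is True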
diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py
index a88589c3..b3a338af 100644
--- a/redis_consumer/consumers/base_consumer_test.py
+++ b/redis_consumer/consumers/base_consumer_test.py
@@ -208,6 +208,20 @@ def test__consume(self):
 class TestTensorFlowServingConsumer(object):
     # pylint: disable=R0201,W0613,W0621
 
+    def test_is_valid_hash(self, mocker, redis_client):
+        storage = DummyStorage()
+        mocker.patch.object(redis_client, 'hget', lambda x, y: x.split(':')[-1])
+
+        consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'predict')
+
+        assert consumer.is_valid_hash(None) is False
+        assert consumer.is_valid_hash('file.ZIp') is False
+        assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False
+        assert consumer.is_valid_hash('track:123456789:file.zip') is False
+        assert consumer.is_valid_hash('predict:123456789:file.zip') is False
+        assert consumer.is_valid_hash('predict:1234567890:file.tiff') is True
+        assert consumer.is_valid_hash('predict:1234567890:file.png') is True
+
     def test__get_predict_client(self, redis_client):
         stg = DummyStorage()
         consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q')

diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py
index fde67cf3..5661e6f9 100644
--- a/redis_consumer/consumers/image_consumer.py
+++ b/redis_consumer/consumers/image_consumer.py
@@ -41,13 +41,6 @@
 class ImageFileConsumer(TensorFlowServingConsumer):
     """Consumes image files and uploads the results"""
 
-    def is_valid_hash(self, redis_hash):
-        if redis_hash is None:
-            return False
-
-        fname = str(self.redis.hget(redis_hash, 'input_file_name'))
-        return not fname.lower().endswith('.zip')
-
     def detect_scale(self, image):
         """Send the image to the SCALE_DETECT_MODEL to detect the relative
         scale difference from the image to the model's training data.
diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py
index 6676fcd2..019681ea 100644
--- a/redis_consumer/consumers/image_consumer_test.py
+++ b/redis_consumer/consumers/image_consumer_test.py
@@ -42,19 +42,6 @@
 class TestImageFileConsumer(object):
     # pylint: disable=R0201,W0621
 
-    def test_is_valid_hash(self, mocker, redis_client):
-        storage = DummyStorage()
-        mocker.patch.object(redis_client, 'hget', lambda x, y: x.split(':')[-1])
-
-        consumer = consumers.ImageFileConsumer(redis_client, storage, 'predict')
-
-        assert consumer.is_valid_hash(None) is False
-        assert consumer.is_valid_hash('file.ZIp') is False
-        assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False
-        assert consumer.is_valid_hash('track:123456789:file.zip') is False
-        assert consumer.is_valid_hash('predict:123456789:file.zip') is False
-        assert consumer.is_valid_hash('predict:1234567890:file.tiff') is True
-        assert consumer.is_valid_hash('predict:1234567890:file.png') is True
 
     def test_detect_label(self, mocker, redis_client):
         # pylint: disable=W0613

diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py
index d00fcf44..e289086b 100644
--- a/redis_consumer/consumers/multiplex_consumer.py
+++ b/redis_consumer/consumers/multiplex_consumer.py
@@ -41,13 +41,6 @@
 class MultiplexConsumer(TensorFlowServingConsumer):
     """Consumes image files and uploads the results"""
 
-    def is_valid_hash(self, redis_hash):
-        if redis_hash is None:
-            return False
-
-        fname = str(self.redis.hget(redis_hash, 'input_file_name'))
-        return not fname.lower().endswith('.zip')
-
     def _consume(self, redis_hash):
         start = timeit.default_timer()
         self._redis_hash = redis_hash  # workaround for logging.
diff --git a/redis_consumer/consumers/multiplex_consumer_test.py b/redis_consumer/consumers/multiplex_consumer_test.py
index d23310bb..ee810616 100644
--- a/redis_consumer/consumers/multiplex_consumer_test.py
+++ b/redis_consumer/consumers/multiplex_consumer_test.py
@@ -42,19 +42,6 @@
 class TestMultiplexConsumer(object):
     # pylint: disable=R0201
 
-    def test_is_valid_hash(self, mocker, redis_client):
-        storage = DummyStorage()
-        mocker.patch.object(redis_client, 'hget', lambda *x: x[0])
-
-        consumer = consumers.MultiplexConsumer(redis_client, storage, 'multiplex')
-        assert consumer.is_valid_hash(None) is False
-        assert consumer.is_valid_hash('file.ZIp') is False
-        assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False
-        assert consumer.is_valid_hash('track:123456789:file.zip') is False
-        assert consumer.is_valid_hash('predict:123456789:file.zip') is False
-        assert consumer.is_valid_hash('multiplex:1234567890:file.tiff') is True
-        assert consumer.is_valid_hash('multiplex:1234567890:file.png') is True
-
     def test__consume(self, mocker, redis_client):
         # pylint: disable=W0613

From 5046aa8356a00e2f06da30590e3725f5fc1793aa Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Tue, 19 Jan 2021 17:46:43 -0800
Subject: [PATCH 02/73] Add download_image and validate_model_input helper
 functions

---
 redis_consumer/consumers/base_consumer.py  | 29 +++++++++++++++++++
 redis_consumer/consumers/image_consumer.py |  8 +++--
 .../consumers/multiplex_consumer.py        | 26 ++---------------
 3 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py
index feb536b6..5bfcada6 100644
--- a/redis_consumer/consumers/base_consumer.py
+++ b/redis_consumer/consumers/base_consumer.py
@@ -261,6 +261,35 @@ def is_valid_hash(self, redis_hash):
         fname = str(self.redis.hget(redis_hash, 'input_file_name'))
         return not fname.lower().endswith('.zip')
 
+    def download_image(self, image_path):
+        """Download file from bucket and load it as an image"""
+        with utils.get_tempdir() as tempdir:
+            fname = self.storage.download(image_path, tempdir)
+            image = utils.get_image(fname)
+        return image
+
+    def validate_model_input(self, image, model_name, model_version):
+        """Validate that the input image meets the workflow requirements."""
+        model_metadata = self.get_model_metadata(model_name, model_version)
+        shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')]
+
+        rank = len(shape) - 1  # ignoring batch dimension
+        channels = shape[-1]
+
+        errtext = (f'Invalid image shape: {image.shape}. '
+                   f'The {self.queue} job expects images of shape '
+                   f'[height, width, {channels}]')
+
+        if len(image.shape) != rank:
+            raise ValueError(errtext)
+
+        if image.shape[0] == channels:
+            image = np.rollaxis(image, 0, rank)
+
+        if image.shape[rank - 1] != channels:
+            raise ValueError(errtext)
+
+        return image
 
     def _get_predict_client(self, model_name, model_version):
         """Returns the TensorFlow Serving gRPC client.
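The new validate_model_input helper replaces the per-consumer shape checks with a single rank-and-channels test derived from the serving model's metadata. A rough standalone sketch of that logic; the free-standing function and the '-1,128,128,2' metadata string ([batch, height, width, channels]) are illustrative, not the patch's API:

    import numpy as np

    def validate_model_input(image, in_tensor_shape='-1,128,128,2'):
        shape = [int(x) for x in in_tensor_shape.split(',')]
        rank = len(shape) - 1  # drop the batch dimension
        channels = shape[-1]
        if len(image.shape) != rank:
            raise ValueError(f'expected a rank {rank} image, got {image.shape}')
        if image.shape[0] == channels:
            # channels-first input: roll the channel axis to the end
            image = np.rollaxis(image, 0, rank)
        if image.shape[rank - 1] != channels:
            raise ValueError(f'expected {channels} channels, got {image.shape}')
        return image

    print(validate_model_input(np.zeros((2, 64, 64))).shape)   # (64, 64, 2)
    print(validate_model_input(np.zeros((64, 64, 2))).shape)   # (64, 64, 2)

One caveat: a channels-last image whose height happens to equal the channel count would also be rolled, so the channel-axis detection is a heuristic rather than a guarantee.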
diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py
index 5661e6f9..590f5684 100644
--- a/redis_consumer/consumers/image_consumer.py
+++ b/redis_consumer/consumers/image_consumer.py
@@ -129,9 +129,11 @@ def _consume(self, redis_hash):
 
         _ = timeit.default_timer()
 
-        with utils.get_tempdir() as tempdir:
-            fname = self.storage.download(hvals.get('input_file_name'), tempdir)
-            image = utils.get_image(fname)
+        # Load input image
+        image = self.download_image(hvals.get('input_file_name'))
+
+        # Validate input image
+        image = self.validate_model_input(image, model_name, model_version)
 
         # Pre-process data before sending to the model
         self.update_key(redis_hash, {

diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py
index e289086b..c595308e 100644
--- a/redis_consumer/consumers/multiplex_consumer.py
+++ b/redis_consumer/consumers/multiplex_consumer.py
@@ -65,33 +65,13 @@ def _consume(self, redis_hash):
         _ = timeit.default_timer()
 
         # Load input image
-        with utils.get_tempdir() as tempdir:
-            fname = self.storage.download(hvals.get('input_file_name'), tempdir)
-            # TODO: tiffs expand the last axis, is that a problem here?
-            image = utils.get_image(fname)
+        image = self.download_image(hvals.get('input_file_name'))
 
         # squeeze extra dimension that is added by get_image
         image = np.squeeze(image)
 
-        # validate correct shape of image
-        if len(image.shape) > 3:
-            raise ValueError('Invalid image shape. An image of shape {} was supplied, but the '
-                             'multiplex model expects of images of shape'
-                             '[height, widths, 2]'.format(image.shape))
-        elif len(image.shape) < 3:
-            # TODO: Once we can pass warning messages to user, we can treat this as nuclear image
-            raise ValueError('Invalid image shape. An image of shape {} was supplied, but the '
-                             'multiplex model expects images of shape'
-                             '[height, width, 2]'.format(image.shape))
-        else:
-            if image.shape[0] == 2:
-                image = np.rollaxis(image, 0, 3)
-            elif image.shape[2] == 2:
-                pass
-            else:
-                raise ValueError('Invalid image shape. An image of shape {} was supplied, '
-                                 'but the multiplex model expects images of shape'
-                                 '[height, widths, 2]'.format(image.shape))
+        # Validate input image
+        image = self.validate_model_input(image, model_name, model_version)
 
         # Pre-process data before sending to the model
         self.update_key(redis_hash, {

From a46eac618266e2c7eb9691b3b3cb70f851c22591 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 20 Jan 2021 18:30:05 -0800
Subject: [PATCH 03/73] Remove protos and use tensorflow-serving-api and
 tensorflow protos.
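This patch drops the vendored protos/ sources and the generated redis_consumer/pbs/ modules in favor of the protos that ship with the tensorflow-serving-api package (a requirements-no-deps.txt file installed via pip --no-deps, which skips a package's own pinned dependencies). The equivalent modules come from tensorflow_serving.apis; a minimal gRPC predict-request sketch, where the host, port, and model name are illustrative:

    import grpc
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc

    # Connect to a TensorFlow Serving gRPC endpoint (address is illustrative).
    channel = grpc.insecure_channel('tf-serving:8500')
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'multiplex_model'  # hypothetical model name
    request.model_spec.signature_name = 'serving_default'
    # request.inputs['image'].CopyFrom(tf.make_tensor_proto(image)) would
    # attach the input tensor; tf.make_tensor_proto ships with tensorflow.
    # Calling stub.Predict(request, timeout) then requires a reachable server.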
---
 .github/workflows/tests.yaml | 1 +
 Dockerfile | 6 +-
 protos/attr_value.proto | 62 -
 protos/function.proto | 113 --
 protos/get_model_metadata.proto | 30 -
 protos/graph.proto | 56 -
 protos/meta_graph.proto | 342 ------
 protos/model.proto | 33 -
 protos/node_def.proto | 86 --
 protos/op_def.proto | 170 ---
 protos/predict.proto | 40 -
 protos/prediction_service.proto | 31 -
 protos/resource_handle.proto | 42 -
 protos/saved_object_graph.proto | 164 ---
 protos/saver.proto | 47 -
 protos/struct.proto | 134 ---
 protos/tensor.proto | 94 --
 protos/tensor_shape.proto | 46 -
 protos/trackable_object_graph.proto | 59 -
 protos/types.proto | 76 --
 protos/variable.proto | 85 --
 protos/versions.proto | 32 -
 redis_consumer/__init__.py | 1 -
 redis_consumer/pbs/__init__.py | 0
 redis_consumer/pbs/attr_value_pb2.py | 365 ------
 redis_consumer/pbs/attr_value_pb2_grpc.py | 3 -
 redis_consumer/pbs/function_pb2.py | 486 --------
 redis_consumer/pbs/function_pb2_grpc.py | 3 -
 redis_consumer/pbs/get_model_metadata_pb2.py | 265 -----
 .../pbs/get_model_metadata_pb2_grpc.py | 3 -
 redis_consumer/pbs/graph_pb2.py | 98 --
 redis_consumer/pbs/graph_pb2_grpc.py | 3 -
 redis_consumer/pbs/meta_graph_pb2.py | 1031 -----------------
 redis_consumer/pbs/meta_graph_pb2_grpc.py | 3 -
 redis_consumer/pbs/model_pb2.py | 102 --
 redis_consumer/pbs/model_pb2_grpc.py | 3 -
 redis_consumer/pbs/node_def_pb2.py | 202 ----
 redis_consumer/pbs/node_def_pb2_grpc.py | 3 -
 redis_consumer/pbs/op_def_pb2.py | 404 -------
 redis_consumer/pbs/op_def_pb2_grpc.py | 3 -
 redis_consumer/pbs/predict_pb2.py | 232 ----
 redis_consumer/pbs/predict_pb2_grpc.py | 3 -
 redis_consumer/pbs/prediction_service_pb2.py | 66 --
 .../pbs/prediction_service_pb2_grpc.py | 78 --
 redis_consumer/pbs/resource_handle_pb2.py | 156 ---
 .../pbs/resource_handle_pb2_grpc.py | 3 -
 redis_consumer/pbs/saved_object_graph_pb2.py | 720 ------------
 .../pbs/saved_object_graph_pb2_grpc.py | 3 -
 redis_consumer/pbs/saver_pb2.py | 140 ---
 redis_consumer/pbs/saver_pb2_grpc.py | 3 -
 redis_consumer/pbs/struct_pb2.py | 662 -----------
 redis_consumer/pbs/struct_pb2_grpc.py | 3 -
 redis_consumer/pbs/tensor_pb2.py | 253 ----
 redis_consumer/pbs/tensor_pb2_grpc.py | 3 -
 redis_consumer/pbs/tensor_shape_pb2.py | 123 --
 redis_consumer/pbs/tensor_shape_pb2_grpc.py | 3 -
 .../pbs/trackable_object_graph_pb2.py | 285 -----
 .../pbs/trackable_object_graph_pb2_grpc.py | 3 -
 redis_consumer/pbs/types_pb2.py | 282 -----
 redis_consumer/pbs/types_pb2_grpc.py | 3 -
 redis_consumer/pbs/variable_pb2.py | 261 -----
 redis_consumer/pbs/variable_pb2_grpc.py | 3 -
 redis_consumer/pbs/versions_pb2.py | 83 --
 redis_consumer/pbs/versions_pb2_grpc.py | 3 -
 requirements-no-deps.txt | 3 +
 requirements.txt | 21 +-
 66 files changed, 22 insertions(+), 8103 deletions(-)
 delete mode 100644 protos/attr_value.proto
 delete mode 100644 protos/function.proto
 delete mode 100644 protos/get_model_metadata.proto
 delete mode 100644 protos/graph.proto
 delete mode 100644 protos/meta_graph.proto
 delete mode 100644 protos/model.proto
 delete mode 100644 protos/node_def.proto
 delete mode 100644 protos/op_def.proto
 delete mode 100644 protos/predict.proto
 delete mode 100644 protos/prediction_service.proto
 delete mode 100644 protos/resource_handle.proto
 delete mode 100644 protos/saved_object_graph.proto
 delete mode 100644 protos/saver.proto
 delete mode 100644 protos/struct.proto
 delete mode 100644 protos/tensor.proto
 delete mode 100644 protos/tensor_shape.proto
 delete mode 100644 protos/trackable_object_graph.proto
 delete mode 100644 protos/types.proto
 delete mode 100644 protos/variable.proto
 delete mode 100644 protos/versions.proto
 delete mode 100644 redis_consumer/pbs/__init__.py
 delete mode 100644 redis_consumer/pbs/attr_value_pb2.py
 delete mode 100644 redis_consumer/pbs/attr_value_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/function_pb2.py
 delete mode 100644 redis_consumer/pbs/function_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/get_model_metadata_pb2.py
 delete mode 100644 redis_consumer/pbs/get_model_metadata_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/graph_pb2.py
 delete mode 100644 redis_consumer/pbs/graph_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/meta_graph_pb2.py
 delete mode 100644 redis_consumer/pbs/meta_graph_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/model_pb2.py
 delete mode 100644 redis_consumer/pbs/model_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/node_def_pb2.py
 delete mode 100644 redis_consumer/pbs/node_def_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/op_def_pb2.py
 delete mode 100644 redis_consumer/pbs/op_def_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/predict_pb2.py
 delete mode 100644 redis_consumer/pbs/predict_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/prediction_service_pb2.py
 delete mode 100644 redis_consumer/pbs/prediction_service_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/resource_handle_pb2.py
 delete mode 100644 redis_consumer/pbs/resource_handle_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/saved_object_graph_pb2.py
 delete mode 100644 redis_consumer/pbs/saved_object_graph_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/saver_pb2.py
 delete mode 100644 redis_consumer/pbs/saver_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/struct_pb2.py
 delete mode 100644 redis_consumer/pbs/struct_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/tensor_pb2.py
 delete mode 100644 redis_consumer/pbs/tensor_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/tensor_shape_pb2.py
 delete mode 100644 redis_consumer/pbs/tensor_shape_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/trackable_object_graph_pb2.py
 delete mode 100644 redis_consumer/pbs/trackable_object_graph_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/types_pb2.py
 delete mode 100644 redis_consumer/pbs/types_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/variable_pb2.py
 delete mode 100644 redis_consumer/pbs/variable_pb2_grpc.py
 delete mode 100644 redis_consumer/pbs/versions_pb2.py
 delete mode 100644 redis_consumer/pbs/versions_pb2_grpc.py
 create mode 100644 requirements-no-deps.txt

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 021ccd53..b6718075 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -37,6 +37,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install -r requirements.txt
+          pip install --no-deps -r requirements-no-deps.txt
           pip install -r requirements-test.txt
 
       - name: Run PyTest and Coveralls

diff --git a/Dockerfile b/Dockerfile
index ed3cdc3d..b32a145f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -23,14 +23,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-FROM python:3.6
+FROM python:3.7
 
 WORKDIR /usr/src/app
 
-COPY requirements.txt .
+COPY requirements.txt requirements-no-deps.txt ./
 
 RUN pip install --no-cache-dir -r requirements.txt
 
+RUN pip install --no-cache-dir --no-deps -r requirements-no-deps.txt
+
 COPY . .
CMD ["/bin/sh", "-c", "python consume-redis-events.py"] diff --git a/protos/attr_value.proto b/protos/attr_value.proto deleted file mode 100644 index 76944f77..00000000 --- a/protos/attr_value.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NameAttrList { - string name = 1; - map attr = 2; -} diff --git a/protos/function.proto b/protos/function.proto deleted file mode 100644 index 6d107635..00000000 --- a/protos/function.proto +++ /dev/null @@ -1,113 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// -// TODO(zhifengc): -// * device spec, etc. 
-message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map attr = 5; - - // Attributes for function arguments. These attributes are the same set of - // valid attributes as to _Arg nodes. - message ArgAttrs { - map attr = 1; - } - map arg_attr = 7; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. - reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. - // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map ret = 4; - - // A mapping from control output names from `signature` to node names in - // `node_def` which should be control outputs of this function. - map control_ret = 6; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e. if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). 
dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/protos/get_model_metadata.proto b/protos/get_model_metadata.proto deleted file mode 100644 index 60ddfd56..00000000 --- a/protos/get_model_metadata.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "google/protobuf/any.proto"; -import "meta_graph.proto"; -import "model.proto"; - -// Message returned for "signature_def" field. -message SignatureDefMap { - map signature_def = 1; -}; - -message GetModelMetadataRequest { - // Model Specification indicating which model we are querying for metadata. - // If version is not specified, will use the latest (numerical) version. - ModelSpec model_spec = 1; - // Metadata fields to get. Currently supported: "signature_def". - repeated string metadata_field = 2; -} - -message GetModelMetadataResponse { - // Model Specification indicating which model this metadata belongs to. - ModelSpec model_spec = 1; - // Map of metadata field name to metadata field. The options for metadata - // field name are listed in GetModelMetadataRequest. Currently supported: - // "signature_def". - map metadata = 2; -} diff --git a/protos/graph.proto b/protos/graph.proto deleted file mode 100644 index 14d9edfa..00000000 --- a/protos/graph.proto +++ /dev/null @@ -1,56 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. 
The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/protos/meta_graph.proto b/protos/meta_graph.proto deleted file mode 100644 index f1005543..00000000 --- a/protos/meta_graph.proto +++ /dev/null @@ -1,342 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -import "google/protobuf/any.proto"; -import "graph.proto"; -import "op_def.proto"; -import "tensor_shape.proto"; -import "types.proto"; -import "saved_object_graph.proto"; -import "saver.proto"; -import "struct.proto"; - -option cc_enable_arenas = true; -option java_outer_classname = "MetaGraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; - -// NOTE: This protocol buffer is evolving, and will go through revisions in the -// coming months. -// -// Protocol buffer containing the following which are necessary to restart -// training, run inference. It can be used to serialize/de-serialize memory -// objects necessary for running computation in a graph when crossing the -// process boundary. It can be used for long term storage of graphs, -// cross-language execution of graphs, etc. -// MetaInfoDef -// GraphDef -// SaverDef -// CollectionDef -// TensorInfo -// SignatureDef -message MetaGraphDef { - // Meta information regarding the graph to be exported. To be used by users - // of this protocol buffer to encode information regarding their meta graph. - message MetaInfoDef { - // User specified Version string. Can be the name of the model and revision, - // steps this model has been trained to, etc. - string meta_graph_version = 1; - - // A copy of the OpDefs used by the producer of this graph_def. - // Descriptions and Ops not used in graph_def are stripped out. - OpList stripped_op_list = 2; - - // A serialized protobuf. Can be the time this meta graph is created, or - // modified, or name of the model. - google.protobuf.Any any_info = 3; - - // User supplied tag(s) on the meta_graph and included graph_def. - // - // MetaGraphDefs should be tagged with their capabilities or use-cases. - // Examples: "train", "serve", "gpu", "tpu", etc. - // These tags enable loaders to access the MetaGraph(s) appropriate for a - // specific use-case or runtime environment. - repeated string tags = 4; - - // The __version__ string of the tensorflow build used to write this graph. - // This will be populated by the framework, which will overwrite any user - // supplied value. - string tensorflow_version = 5; - - // The __git_version__ string of the tensorflow build used to write this - // graph. This will be populated by the framework, which will overwrite any - // user supplied value. - string tensorflow_git_version = 6; - - // A flag to denote whether default-valued attrs have been stripped from - // the nodes in this graph_def. - bool stripped_default_attrs = 7; - - // FunctionDef name to aliases mapping. - map function_aliases = 8; - } - MetaInfoDef meta_info_def = 1; - - // GraphDef. - GraphDef graph_def = 2; - - // SaverDef. - SaverDef saver_def = 3; - - // collection_def: Map from collection name to collections. - // See CollectionDef section for details. - map collection_def = 4; - - // signature_def: Map from user supplied key for a signature to a single - // SignatureDef. 
- map signature_def = 5; - - // Asset file def to be used with the defined graph. - repeated AssetFileDef asset_file_def = 6; - - // Extra information about the structure of functions and stateful objects. - SavedObjectGraph object_graph_def = 7; -} - -// CollectionDef should cover most collections. -// To add a user-defined collection, do one of the following: -// 1. For simple data types, such as string, int, float: -// tf.add_to_collection("your_collection_name", your_simple_value) -// strings will be stored as bytes_list. -// -// 2. For Protobuf types, there are three ways to add them: -// 1) tf.add_to_collection("your_collection_name", -// your_proto.SerializeToString()) -// -// collection_def { -// key: "user_defined_bytes_collection" -// value { -// bytes_list { -// value: "queue_name: \"test_queue\"\n" -// } -// } -// } -// -// or -// -// 2) tf.add_to_collection("your_collection_name", str(your_proto)) -// -// collection_def { -// key: "user_defined_string_collection" -// value { -// bytes_list { -// value: "\n\ntest_queue" -// } -// } -// } -// -// or -// -// 3) any_buf = any_pb2.Any() -// tf.add_to_collection("your_collection_name", -// any_buf.Pack(your_proto)) -// -// collection_def { -// key: "user_defined_any_collection" -// value { -// any_list { -// value { -// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" -// value: "\n\ntest_queue" -// } -// } -// } -// } -// -// 3. For Python objects, implement to_proto() and from_proto(), and register -// them in the following manner: -// ops.register_proto_function("your_collection_name", -// proto_type, -// to_proto=YourPythonObject.to_proto, -// from_proto=YourPythonObject.from_proto) -// These functions will be invoked to serialize and de-serialize the -// collection. For example, -// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, -// proto_type=variable_pb2.VariableDef, -// to_proto=Variable.to_proto, -// from_proto=Variable.from_proto) -message CollectionDef { - // NodeList is used for collecting nodes in graph. For example - // collection_def { - // key: "summaries" - // value { - // node_list { - // value: "input_producer/ScalarSummary:0" - // value: "shuffle_batch/ScalarSummary:0" - // value: "ImageSummary:0" - // } - // } - message NodeList { - repeated string value = 1; - } - - // BytesList is used for collecting strings and serialized protobufs. For - // example: - // collection_def { - // key: "trainable_variables" - // value { - // bytes_list { - // value: "\n\017conv1/weights:0\022\024conv1/weights/Assign - // \032\024conv1/weights/read:0" - // value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 - // \023conv1/biases/read:0" - // } - // } - // } - message BytesList { - repeated bytes value = 1; - } - - // Int64List is used for collecting int, int64 and long values. - message Int64List { - repeated int64 value = 1 [packed = true]; - } - - // FloatList is used for collecting float values. - message FloatList { - repeated float value = 1 [packed = true]; - } - - // AnyList is used for collecting Any protos. - message AnyList { - repeated google.protobuf.Any value = 1; - } - - oneof kind { - NodeList node_list = 1; - BytesList bytes_list = 2; - Int64List int64_list = 3; - FloatList float_list = 4; - AnyList any_list = 5; - } -} - -// Information about a Tensor necessary for feeding or retrieval. -message TensorInfo { - // For sparse tensors, The COO encoding stores a triple of values, indices, - // and shape. - message CooSparse { - // The shape of the values Tensor is [?]. 
Its dtype must be the dtype of - // the SparseTensor as a whole, given in the enclosing TensorInfo. - string values_tensor_name = 1; - - // The indices Tensor must have dtype int64 and shape [?, ?]. - string indices_tensor_name = 2; - - // The dynamic logical shape represented by the SparseTensor is recorded in - // the Tensor referenced here. It must have dtype int64 and shape [?]. - string dense_shape_tensor_name = 3; - } - - // Generic encoding for composite tensors. - message CompositeTensor { - // The serialized TypeSpec for the composite tensor. - TypeSpecProto type_spec = 1; - - // A TensorInfo for each flattened component tensor. - repeated TensorInfo components = 2; - } - - oneof encoding { - // For dense `Tensor`s, the name of the tensor in the graph. - string name = 1; - // There are many possible encodings of sparse matrices - // (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow - // uses only the COO encoding. This is supported and documented in the - // SparseTensor Python class. - CooSparse coo_sparse = 4; - // Generic encoding for CompositeTensors. - CompositeTensor composite_tensor = 5; - } - DataType dtype = 2; - // The static shape should be recorded here, to the extent that it can - // be known in advance. In the case of a SparseTensor, this field describes - // the logical shape of the represented tensor (aka dense_shape). - TensorShapeProto tensor_shape = 3; -} - -// SignatureDef defines the signature of a computation supported by a TensorFlow -// graph. -// -// For example, a model with two loss computations, sharing a single input, -// might have the following signature_def map. -// -// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, -// output key, and method_name are identical, and will be used by system(s) that -// implement or rely upon this particular loss method. The output tensor names -// differ, demonstrating how different outputs can exist for the same method. -// -// signature_def { -// key: "loss_A" -// value { -// inputs { -// key: "input" -// value { -// name: "input:0" -// dtype: DT_STRING -// tensor_shape: ... -// } -// } -// outputs { -// key: "loss_output" -// value { -// name: "loss_output_A:0" -// dtype: DT_FLOAT -// tensor_shape: ... -// } -// } -// } -// ... -// method_name: "some/package/compute_loss" -// } -// signature_def { -// key: "loss_B" -// value { -// inputs { -// key: "input" -// value { -// name: "input:0" -// dtype: DT_STRING -// tensor_shape: ... -// } -// } -// outputs { -// key: "loss_output" -// value { -// name: "loss_output_B:0" -// dtype: DT_FLOAT -// tensor_shape: ... -// } -// } -// } -// ... -// method_name: "some/package/compute_loss" -// } -message SignatureDef { - // Named input parameters. - map inputs = 1; - // Named output parameters. - map outputs = 2; - // Extensible method_name information enabling third-party users to mark a - // SignatureDef as supporting a particular method. This enables producers and - // consumers of SignatureDefs, e.g. a model definition library and a serving - // library to have a clear hand-off regarding the semantics of a computation. - // - // Note that multiple SignatureDefs in a single MetaGraphDef may have the same - // method_name. This is commonly used to support multi-headed computation, - // where a single graph computation may return multiple results. - string method_name = 3; -} - -// An asset file def for a single file or a set of sharded files with the same -// name. 
-message AssetFileDef { - // The tensor to bind the asset filename to. - TensorInfo tensor_info = 1; - // The filename within an assets directory. Note: does not include the path - // prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename - // would be "vocab.txt". - string filename = 2; -} diff --git a/protos/model.proto b/protos/model.proto deleted file mode 100644 index 56493f68..00000000 --- a/protos/model.proto +++ /dev/null @@ -1,33 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "google/protobuf/wrappers.proto"; - -// Metadata for an inference request such as the model name and version. -message ModelSpec { - // Required servable name. - string name = 1; - - // Optional choice of which version of the model to use. - // - // Recommended to be left unset in the common case. Should be specified only - // when there is a strong version consistency requirement. - // - // When left unspecified, the system will serve the best available version. - // This is typically the latest version, though during version transitions, - // notably when serving on a fleet of instances, may be either the previous or - // new version. - oneof version_choice { - // Use this specific version number. - google.protobuf.Int64Value version = 2; - - // Use the version associated with the given label. - string version_label = 4; - } - - // A named signature to evaluate. If unspecified, the default signature will - // be used. - string signature_name = 3; -} diff --git a/protos/node_def.proto b/protos/node_def.proto deleted file mode 100644 index 1e0da16f..00000000 --- a/protos/node_def.proto +++ /dev/null @@ -1,86 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. - // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. 
- string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // TODO(josh11b): Add some examples here showing best practices. - map attr = 5; - - message ExperimentalDebugInfo { - // Opaque string inserted into error messages created by the runtime. - // - // This is intended to store the list of names of the nodes from the - // original graph that this node was derived. For example if this node, say - // C, was result of a fusion of 2 nodes A and B, then 'original_node' would - // be {A, B}. This information can be used to map errors originating at the - // current node to some top level source code. - repeated string original_node_names = 1; - - // This is intended to store the list of names of the functions from the - // original graph that this node was derived. For example if this node, say - // C, was result of a fusion of node A in function FA and node B in function - // FB, then `original_funcs` would be {FA, FB}. If the node is in the top - // level graph, the `original_func` is empty. This information, with the - // `original_node_names` can be used to map errors originating at the - // current ndoe to some top level source code. - repeated string original_func_names = 2; - }; - - // This stores debug information associated with the node. - ExperimentalDebugInfo experimental_debug_info = 6; -}; diff --git a/protos/op_def.proto b/protos/op_def.proto deleted file mode 100644 index 9f5f6bf8..00000000 --- a/protos/op_def.proto +++ /dev/null @@ -1,170 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of a OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9>_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. 
- // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)". - DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Named control outputs for this operation. Useful only for composite - // operations (i.e. functions) which want to name different control outputs. - repeated string control_output = 20; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - // TODO(josh11b): bool is_optional? - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - // TODO(josh11b): Implement that optimization. 
- bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/protos/predict.proto b/protos/predict.proto deleted file mode 100644 index f4a83f8c..00000000 --- a/protos/predict.proto +++ /dev/null @@ -1,40 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "tensor.proto"; -import "model.proto"; - -// PredictRequest specifies which TensorFlow model to run, as well as -// how inputs are mapped to tensors and how outputs are filtered before -// returning to user. -message PredictRequest { - // Model Specification. If version is not specified, will use the latest - // (numerical) version. - ModelSpec model_spec = 1; - - // Input tensors. - // Names of input tensor are alias names. The mapping from aliases to real - // input tensor names is stored in the SavedModel export as a prediction - // SignatureDef under the 'inputs' field. - map inputs = 2; - - // Output filter. - // Names specified are alias names. The mapping from aliases to real output - // tensor names is stored in the SavedModel export as a prediction - // SignatureDef under the 'outputs' field. - // Only tensors specified here will be run/fetched and returned, with the - // exception that when none is specified, all tensors specified in the - // named signature will be run/fetched and returned. - repeated string output_filter = 3; -} - -// Response for PredictRequest on successful run. -message PredictResponse { - // Effective Model Specification used to process PredictRequest. - ModelSpec model_spec = 2; - - // Output tensors. 
- map outputs = 1; -} diff --git a/protos/prediction_service.proto b/protos/prediction_service.proto deleted file mode 100644 index 681796a6..00000000 --- a/protos/prediction_service.proto +++ /dev/null @@ -1,31 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -// import "tensorflow_serving/apis/classification.proto"; -// import "tensorflow_serving/apis/inference.proto"; -// import "tensorflow_serving/apis/regression.proto"; -import "get_model_metadata.proto"; -import "predict.proto"; - -// open source marker; do not remove -// PredictionService provides access to machine-learned models loaded by -// model_servers. -service PredictionService { - // Classify. - // rpc Classify(ClassificationRequest) returns (ClassificationResponse); - - // Regress. - // rpc Regress(RegressionRequest) returns (RegressionResponse); - - // Predict -- provides access to loaded TensorFlow model. - rpc Predict(PredictRequest) returns (PredictResponse); - - // MultiInference API for multi-headed models. - // rpc MultiInference(MultiInferenceRequest) returns (MultiInferenceResponse); - - // GetModelMetadata - provides access to metadata for loaded models. - rpc GetModelMetadata(GetModelMetadataRequest) - returns (GetModelMetadataResponse); -} diff --git a/protos/resource_handle.proto b/protos/resource_handle.proto deleted file mode 100644 index 82194668..00000000 --- a/protos/resource_handle.proto +++ /dev/null @@ -1,42 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. - string maybe_type_name = 5; - - // Protocol buffer representing a pair of (data type, tensor shape). - message DtypeAndShape { - DataType dtype = 1; - TensorShapeProto shape = 2; - } - - // Data types and shapes for the underlying resource. - repeated DtypeAndShape dtypes_and_shapes = 6; -}; diff --git a/protos/saved_object_graph.proto b/protos/saved_object_graph.proto deleted file mode 100644 index 0bad3a57..00000000 --- a/protos/saved_object_graph.proto +++ /dev/null @@ -1,164 +0,0 @@ -syntax = "proto3"; - -import "trackable_object_graph.proto"; -import "struct.proto"; -import "tensor_shape.proto"; -import "types.proto"; -import "versions.proto"; -import "variable.proto"; - -option cc_enable_arenas = true; - -package tensorflow; - -// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It -// describes the directed graph of Python objects (or equivalent in other -// languages) that make up a model, with nodes[0] at the root. 
- -// SavedObjectGraph shares some structure with TrackableObjectGraph, but -// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions -// and type information, while TrackableObjectGraph lives in the checkpoint -// and contains pointers only to variable values. - -message SavedObjectGraph { - // Flattened list of objects in the object graph. - // - // The position of the object in this list indicates its id. - // Nodes[0] is considered the root node. - repeated SavedObject nodes = 1; - - // Information about captures and output structures in concrete functions. - // Referenced from SavedBareConcreteFunction and SavedFunction. - map concrete_functions = 2; -} - -message SavedObject { - // Objects which this object depends on: named edges in the dependency - // graph. - // - // Note: currently only valid if kind == "user_object". - repeated TrackableObjectGraph.TrackableObject.ObjectReference - children = 1; - - // Removed when forking SavedObject from TrackableObjectGraph. - reserved "attributes"; - reserved 2; - - // Slot variables owned by this object. This describes the three-way - // (optimizer, variable, slot variable) relationship; none of the three - // depend on the others directly. - // - // Note: currently only valid if kind == "user_object". - repeated TrackableObjectGraph.TrackableObject.SlotVariableReference - slot_variables = 3; - - oneof kind { - SavedUserObject user_object = 4; - SavedAsset asset = 5; - SavedFunction function = 6; - SavedVariable variable = 7; - SavedBareConcreteFunction bare_concrete_function = 8; - SavedConstant constant = 9; - SavedResource resource = 10; - } -} - -// A SavedUserObject is an object (in the object-oriented language of the -// TensorFlow program) of some user- or framework-defined class other than -// those handled specifically by the other kinds of SavedObjects. -// -// This object cannot be evaluated as a tensor, and therefore cannot be bound -// to an input of a function. -message SavedUserObject { - // Corresponds to a registration of the type to use in the loading program. - string identifier = 1; - // Version information from the producer of this SavedUserObject. - VersionDef version = 2; - // Initialization-related metadata. - string metadata = 3; -} - -// A SavedAsset points to an asset in the MetaGraph. -// -// When bound to a function this object evaluates to a tensor with the absolute -// filename. Users should not depend on a particular part of the filename to -// remain stable (e.g. basename could be changed). -message SavedAsset { - // Index into `MetaGraphDef.asset_file_def[]` that describes the Asset. - // - // Only the field `AssetFileDef.filename` is used. Other fields, such as - // `AssetFileDef.tensor_info`, MUST be ignored. - int32 asset_file_def_index = 1; -} - -// A function with multiple signatures, possibly with non-Tensor arguments. -message SavedFunction { - repeated string concrete_functions = 1; - FunctionSpec function_spec = 2; -} - -// Stores low-level information about a concrete function. Referenced in either -// a SavedFunction or a SavedBareConcreteFunction. -message SavedConcreteFunction { - // Bound inputs to the function. The SavedObjects identified by the node ids - // given here are appended as extra inputs to the caller-supplied inputs. - // The only types of SavedObjects valid here are SavedVariable, SavedResource - // and SavedAsset. - repeated int32 bound_inputs = 2; - // Input in canonicalized form that was received to create this concrete - // function. 
-  StructuredValue canonicalized_input_signature = 3;
-  // Output that was the return value of this function after replacing all
-  // Tensors with TensorSpecs. This can be an arbitrary nested function and will
-  // be used to reconstruct the full structure from pure tensors.
-  StructuredValue output_signature = 4;
-}
-
-message SavedBareConcreteFunction {
-  // Identifies a SavedConcreteFunction.
-  string concrete_function_name = 1;
-
-  // A sequence of unique strings, one per Tensor argument.
-  repeated string argument_keywords = 2;
-  // The prefix of `argument_keywords` which may be identified by position.
-  int64 allowed_positional_arguments = 3;
-}
-
-message SavedConstant {
-  // An Operation name for a ConstantOp in this SavedObjectGraph's MetaGraph.
-  string operation = 1;
-}
-
-// Represents a Variable that is initialized by loading the contents from the
-// checkpoint.
-message SavedVariable {
-  DataType dtype = 1;
-  TensorShapeProto shape = 2;
-  bool trainable = 3;
-  VariableSynchronization synchronization = 4;
-  VariableAggregation aggregation = 5;
-  string name = 6;
-}
-
-// Represents `FunctionSpec` used in `Function`. This represents a
-// function that has been wrapped as a TensorFlow `Function`.
-message FunctionSpec {
-  // Full arg spec from inspect.getfullargspec().
-  StructuredValue fullargspec = 1;
-  // Whether this represents a class method.
-  bool is_method = 2;
-  // The input signature, if specified.
-  StructuredValue input_signature = 5;
-
-  reserved 3, 4;
-}
-
-// A SavedResource represents a TF object that holds state during its lifetime.
-// An object of this type can have a reference to a:
-// create_resource() and an initialize() function.
-message SavedResource {
-  // A device specification indicating a required placement for the resource
-  // creation function, e.g. "CPU". An empty string allows the user to select a
-  // device.
-  string device = 1;
-}
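[Editor's note: the object graph these messages describe can still be inspected through TensorFlow's public proto bindings after the vendored copies are removed. A minimal sketch, assuming a TF 2.x SavedModel at the hypothetical path ./example_model:]

```python
# Read the SavedObjectGraph out of a SavedModel's saved_model.pb using
# TF's public bindings; the path is an assumption for illustration.
from tensorflow.core.protobuf import saved_model_pb2

with open('./example_model/saved_model.pb', 'rb') as f:
    saved_model = saved_model_pb2.SavedModel.FromString(f.read())

object_graph = saved_model.meta_graphs[0].object_graph_def
root = object_graph.nodes[0]  # nodes[0] is the root, per the comment above
for ref in root.children:
    # Each child edge names a node; the node's oneof tells you its kind.
    kind = object_graph.nodes[ref.node_id].WhichOneof('kind')
    print(ref.local_name, kind)
```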
diff --git a/protos/saver.proto b/protos/saver.proto
deleted file mode 100644
index 42453861..00000000
--- a/protos/saver.proto
+++ /dev/null
@@ -1,47 +0,0 @@
-syntax = "proto3";
-
-package tensorflow;
-option cc_enable_arenas = true;
-option java_outer_classname = "SaverProtos";
-option java_multiple_files = true;
-option java_package = "org.tensorflow.util";
-option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
-
-// Protocol buffer representing the configuration of a Saver.
-message SaverDef {
-  // The name of the tensor in which to specify the filename when saving or
-  // restoring a model checkpoint.
-  string filename_tensor_name = 1;
-
-  // The operation to run when saving a model checkpoint.
-  string save_tensor_name = 2;
-
-  // The operation to run when restoring a model checkpoint.
-  string restore_op_name = 3;
-
-  // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
-  int32 max_to_keep = 4;
-
-  // Shard the save files, one per device that has Variable nodes.
-  bool sharded = 5;
-
-  // How often to keep an additional checkpoint. If not specified, only the last
-  // "max_to_keep" checkpoints are kept; if specified, in addition to keeping
-  // the last "max_to_keep" checkpoints, an additional checkpoint will be kept
-  // for every n hours of training.
-  float keep_checkpoint_every_n_hours = 6;
-
-  // A version number that identifies a different on-disk checkpoint format.
-  // Usually, each subclass of BaseSaverBuilder works with a particular
-  // version/format. However, it is possible that the same builder may be
-  // upgraded to support a newer checkpoint format in the future.
-  enum CheckpointFormatVersion {
-    // Internal legacy format.
-    LEGACY = 0;
-    // Deprecated format: tf.Saver() which works with tensorflow::table::Table.
-    V1 = 1;
-    // Current format: more efficient.
-    V2 = 2;
-  }
-  CheckpointFormatVersion version = 7;
-}
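[Editor's note: a SaverDef like the one above is what TensorFlow's v1 Saver emits for its own configuration. A minimal sketch; the variable shape and the max_to_keep / keep_checkpoint_every_n_hours values are arbitrary assumptions:]

```python
# Build a v1 Saver and dump its SaverDef to see the fields defined above.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    v = tf.compat.v1.get_variable('v', shape=[2], dtype=tf.float32)
    saver = tf.compat.v1.train.Saver(max_to_keep=5,
                                     keep_checkpoint_every_n_hours=2)
    saver_def = saver.as_saver_def()  # returns a SaverDef proto

print(saver_def.max_to_keep)                    # 5
print(saver_def.keep_checkpoint_every_n_hours)  # 2.0
print(saver_def.version)                        # 2 == V2, the current format
```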
diff --git a/protos/struct.proto b/protos/struct.proto
deleted file mode 100644
index 7b590903..00000000
--- a/protos/struct.proto
+++ /dev/null
@@ -1,134 +0,0 @@
-syntax = "proto3";
-
-import "tensor_shape.proto";
-import "types.proto";
-
-package tensorflow;
-
-// `StructuredValue` represents a dynamically typed value representing various
-// data structures that are inspired by Python data structures typically used in
-// TensorFlow functions as inputs and outputs.
-//
-// For example when saving a Layer there may be a `training` argument. If the
-// user passes a boolean True/False, that switches between two concrete
-// TensorFlow functions. In order to switch between them in the same way after
-// loading the SavedModel, we need to represent "True" and "False".
-//
-// A more advanced example might be a function which takes a list of
-// dictionaries mapping from strings to Tensors. In order to map from
-// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]`
-// after load to the right saved TensorFlow function, we need to represent the
-// nested structure and the strings, recording that we have a trace for anything
-// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([],
-// tf.float64)}]` as an example.
-//
-// Likewise functions may return nested structures of Tensors, for example
-// returning a dictionary mapping from strings to Tensors. In order for the
-// loaded function to return the same structure we need to serialize it.
-//
-// This is an ergonomic aid for working with loaded SavedModels, not a promise
-// to serialize all possible function signatures. For example we do not expect
-// to pickle generic Python objects, and ideally we'd stay language-agnostic.
-message StructuredValue {
-  // The kind of value.
-  oneof kind {
-    // Represents None.
-    NoneValue none_value = 1;
-
-    // Represents a double-precision floating-point value (a Python `float`).
-    double float64_value = 11;
-    // Represents a signed integer value, limited to 64 bits.
-    // Larger values from Python's arbitrary-precision integers are unsupported.
-    sint64 int64_value = 12;
-    // Represents a string of Unicode characters stored in a Python `str`.
-    // In Python 3, this is exactly what type `str` is.
-    // In Python 2, this is the UTF-8 encoding of the characters.
-    // For strings with ASCII characters only (as often used in TensorFlow code)
-    // there is effectively no difference between the language versions.
-    // The obsolescent `unicode` type of Python 2 is not supported here.
-    string string_value = 13;
-    // Represents a boolean value.
-    bool bool_value = 14;
-
-    // Represents a TensorShape.
-    tensorflow.TensorShapeProto tensor_shape_value = 31;
-    // Represents an enum value for dtype.
-    tensorflow.DataType tensor_dtype_value = 32;
-    // Represents a value for tf.TensorSpec.
-    TensorSpecProto tensor_spec_value = 33;
-    // Represents a value for tf.TypeSpec.
-    TypeSpecProto type_spec_value = 34;
-
-    // Represents a list of `Value`.
-    ListValue list_value = 51;
-    // Represents a tuple of `Value`.
-    TupleValue tuple_value = 52;
-    // Represents a dict `Value`.
-    DictValue dict_value = 53;
-    // Represents Python's namedtuple.
-    NamedTupleValue named_tuple_value = 54;
-  }
-}
-
-// Represents None.
-message NoneValue {}
-
-// Represents a Python list.
-message ListValue {
-  repeated StructuredValue values = 1;
-}
-
-// Represents a Python tuple.
-message TupleValue {
-  repeated StructuredValue values = 1;
-}
-
-// Represents a Python dict keyed by `str`.
-// The comment on Unicode from Value.string_value applies analogously.
-message DictValue {
-  map<string, StructuredValue> fields = 1;
-}
-
-// Represents a (key, value) pair.
-message PairValue {
-  string key = 1;
-  StructuredValue value = 2;
-}
-
-// Represents Python's namedtuple.
-message NamedTupleValue {
-  string name = 1;
-  repeated PairValue values = 2;
-}
-
-// A protobuf to tf.TensorSpec.
-message TensorSpecProto {
-  string name = 1;
-  tensorflow.TensorShapeProto shape = 2;
-  tensorflow.DataType dtype = 3;
-}
-
-// Represents a tf.TypeSpec
-message TypeSpecProto {
-  enum TypeSpecClass {
-    UNKNOWN = 0;
-    SPARSE_TENSOR_SPEC = 1;   // tf.SparseTensorSpec
-    INDEXED_SLICES_SPEC = 2;  // tf.IndexedSlicesSpec
-    RAGGED_TENSOR_SPEC = 3;   // tf.RaggedTensorSpec
-    TENSOR_ARRAY_SPEC = 4;    // tf.TensorArraySpec
-    DATA_DATASET_SPEC = 5;    // tf.data.DatasetSpec
-    DATA_ITERATOR_SPEC = 6;   // IteratorSpec from data/ops/iterator_ops.py
-    OPTIONAL_SPEC = 7;        // tf.OptionalSpec
-    PER_REPLICA_SPEC = 8;     // PerReplicaSpec from distribute/values.py
-  }
-  TypeSpecClass type_spec_class = 1;
-
-  // The value returned by TypeSpec._serialize().
-  StructuredValue type_state = 2;
-
-  // This is currently redundant with the type_spec_class enum, and is only
-  // used for error reporting. In particular, if you use an older binary to
-  // load a newer model, and the model uses a TypeSpecClass that the older
-  // binary doesn't support, then this lets us display a useful error message.
-  string type_spec_class_name = 3;
-}
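[Editor's note: a short sketch of the nesting these messages encode, using TF's public struct_pb2 bindings. The dict keys and numbers are arbitrary examples; availability of tensorflow.core.protobuf.struct_pb2 as a public import is an assumption about the installed TF version:]

```python
# Hand-build the StructuredValue for the Python value
# {'flip': True, 'scales': [1.0, 2.0]} using the messages defined above.
from tensorflow.core.protobuf import struct_pb2

value = struct_pb2.StructuredValue()
value.dict_value.fields['flip'].bool_value = True        # DictValue entry
scales = value.dict_value.fields['scales'].list_value    # nested ListValue
scales.values.add().float64_value = 1.0
scales.values.add().float64_value = 2.0

# Round-trips like any proto:
restored = struct_pb2.StructuredValue.FromString(value.SerializeToString())
```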
diff --git a/protos/tensor.proto b/protos/tensor.proto
deleted file mode 100644
index 5d4d66ae..00000000
--- a/protos/tensor.proto
+++ /dev/null
@@ -1,94 +0,0 @@
-syntax = "proto3";
-
-package tensorflow;
-option cc_enable_arenas = true;
-option java_outer_classname = "TensorProtos";
-option java_multiple_files = true;
-option java_package = "org.tensorflow.framework";
-option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
-import "resource_handle.proto";
-import "tensor_shape.proto";
-import "types.proto";
-
-// Protocol buffer representing a tensor.
-message TensorProto {
-  DataType dtype = 1;
-
-  // Shape of the tensor. TODO(touts): sort out the 0-rank issues.
-  TensorShapeProto tensor_shape = 2;
-
-  // Only one of the representations below is set, one of "tensor_contents" and
-  // the "xxx_val" attributes. We are not using oneof because as oneofs cannot
-  // contain repeated fields it would require another extra set of messages.
-
-  // Version number.
-  //
-  // In version 0, if the "repeated xxx" representations contain only one
-  // element, that element is repeated to fill the shape. This makes it easy
-  // to represent a constant Tensor with a single value.
-  int32 version_number = 3;
-
-  // Serialized raw tensor content from either Tensor::AsProtoTensorContent or
-  // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
-  // can be used for all tensor types. The purpose of this representation is to
-  // reduce serialization overhead during RPC call by avoiding serialization of
-  // many repeated small items.
-  bytes tensor_content = 4;
-
-  // Type specific representations that make it easy to create tensor protos in
-  // all languages. Only the representation corresponding to "dtype" can
-  // be set. The values hold the flattened representation of the tensor in
-  // row major order.
-
-  // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
-  // have some pointless zero padding for each value here.
-  repeated int32 half_val = 13 [packed = true];
-
-  // DT_FLOAT.
-  repeated float float_val = 5 [packed = true];
-
-  // DT_DOUBLE.
-  repeated double double_val = 6 [packed = true];
-
-  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
-  repeated int32 int_val = 7 [packed = true];
-
-  // DT_STRING
-  repeated bytes string_val = 8;
-
-  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
-  // and imaginary parts of i-th single precision complex.
-  repeated float scomplex_val = 9 [packed = true];
-
-  // DT_INT64
-  repeated int64 int64_val = 10 [packed = true];
-
-  // DT_BOOL
-  repeated bool bool_val = 11 [packed = true];
-
-  // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
-  // and imaginary parts of i-th double precision complex.
-  repeated double dcomplex_val = 12 [packed = true];
-
-  // DT_RESOURCE
-  repeated ResourceHandleProto resource_handle_val = 14;
-
-  // DT_VARIANT
-  repeated VariantTensorDataProto variant_val = 15;
-
-  // DT_UINT32
-  repeated uint32 uint32_val = 16 [packed = true];
-
-  // DT_UINT64
-  repeated uint64 uint64_val = 17 [packed = true];
-};
-
-// Protocol buffer representing the serialization format of DT_VARIANT tensors.
-message VariantTensorDataProto {
-  // Name of the type of objects being serialized.
-  string type_name = 1;
-  // Portions of the object that are not Tensors.
-  bytes metadata = 2;
-  // Tensors contained within objects being serialized.
-  repeated TensorProto tensors = 3;
-}
diff --git a/protos/tensor_shape.proto b/protos/tensor_shape.proto
deleted file mode 100644
index 286156a0..00000000
--- a/protos/tensor_shape.proto
+++ /dev/null
@@ -1,46 +0,0 @@
-// Protocol buffer representing the shape of tensors.
-
-syntax = "proto3";
-option cc_enable_arenas = true;
-option java_outer_classname = "TensorShapeProtos";
-option java_multiple_files = true;
-option java_package = "org.tensorflow.framework";
-option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework";
-
-package tensorflow;
-
-// Dimensions of a tensor.
-message TensorShapeProto {
-  // One dimension of the tensor.
-  message Dim {
-    // Size of the tensor in that dimension.
-    // This value must be >= -1, but values of -1 are reserved for "unknown"
-    // shapes (values of -1 mean "unknown" dimension). Certain wrappers
-    // that work with TensorShapeProto may fail at runtime when deserializing
-    // a TensorShapeProto containing a dim value of -1.
-    int64 size = 1;
-
-    // Optional name of the tensor dimension.
-    string name = 2;
-  };
-
-  // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
-  // for a 30 x 40 2D tensor. If an entry has size -1, this
-  // corresponds to a dimension of unknown size. The names are
-  // optional.
-  //
-  // The order of entries in "dim" matters: It indicates the layout of the
-  // values in the tensor in-memory representation.
-  //
-  // The first entry in "dim" is the outermost dimension used to layout the
-  // values, the last entry is the innermost dimension. This matches the
-  // in-memory layout of RowMajor Eigen tensors.
-  //
-  // If "dim.size()" > 0, "unknown_rank" must be false.
- repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/protos/trackable_object_graph.proto b/protos/trackable_object_graph.proto deleted file mode 100644 index 02d852e6..00000000 --- a/protos/trackable_object_graph.proto +++ /dev/null @@ -1,59 +0,0 @@ -syntax = "proto3"; - -option cc_enable_arenas = true; - -package tensorflow; - -// A TensorBundle addition which saves extra information about the objects which -// own variables, allowing for more robust checkpoint loading into modified -// programs. - -message TrackableObjectGraph { - message TrackableObject { - message ObjectReference { - // An index into `TrackableObjectGraph.nodes`, indicating the object - // being referenced. - int32 node_id = 1; - // A user-provided name for the edge. - string local_name = 2; - } - - message SerializedTensor { - // A name for the Tensor. Simple variables have only one - // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may - // be restored on object creation as an optimization. - string name = 1; - // The full name of the variable/tensor, if applicable. Used to allow - // name-based loading of checkpoints which were saved using an - // object-based API. Should match the checkpoint key which would have been - // assigned by tf.train.Saver. - string full_name = 2; - // The generated name of the Tensor in the checkpoint. - string checkpoint_key = 3; - // Whether checkpoints should be considered as matching even without this - // value restored. Used for non-critical values which don't affect the - // TensorFlow graph, such as layer configurations. - bool optional_restore = 4; - } - - message SlotVariableReference { - // An index into `TrackableObjectGraph.nodes`, indicating the - // variable object this slot was created for. - int32 original_variable_node_id = 1; - // The name of the slot (e.g. "m"/"v"). - string slot_name = 2; - // An index into `TrackableObjectGraph.nodes`, indicating the - // `Object` with the value of the slot variable. - int32 slot_variable_node_id = 3; - } - - // Objects which this object depends on. - repeated ObjectReference children = 1; - // Serialized data specific to this object. - repeated SerializedTensor attributes = 2; - // Slot variables owned by this object. - repeated SlotVariableReference slot_variables = 3; - } - - repeated TrackableObject nodes = 1; -} diff --git a/protos/types.proto b/protos/types.proto deleted file mode 100644 index 5356f9f9..00000000 --- a/protos/types.proto +++ /dev/null @@ -1,76 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -// (== suppress_warning documentation-presence ==) -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. 
- DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/protos/variable.proto b/protos/variable.proto deleted file mode 100644 index b2978c75..00000000 --- a/protos/variable.proto +++ /dev/null @@ -1,85 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -option cc_enable_arenas = true; -option java_outer_classname = "VariableProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -// Indicates when a distributed variable will be synced. -enum VariableSynchronization { - // `AUTO`: Indicates that the synchronization will be determined by the - // current `DistributionStrategy` (eg. With `MirroredStrategy` this would be - // `ON_WRITE`). - VARIABLE_SYNCHRONIZATION_AUTO = 0; - // `NONE`: Indicates that there will only be one copy of the variable, so - // there is no need to sync. - VARIABLE_SYNCHRONIZATION_NONE = 1; - // `ON_WRITE`: Indicates that the variable will be updated across devices - // every time it is written. - VARIABLE_SYNCHRONIZATION_ON_WRITE = 2; - // `ON_READ`: Indicates that the variable will be aggregated across devices - // when it is read (eg. when checkpointing or when evaluating an op that uses - // the variable). - VARIABLE_SYNCHRONIZATION_ON_READ = 3; -} - -// Indicates how a distributed variable will be aggregated. -enum VariableAggregation { - // `NONE`: This is the default, giving an error if you use a - // variable-update operation with multiple replicas. - VARIABLE_AGGREGATION_NONE = 0; - // `SUM`: Add the updates across replicas. - VARIABLE_AGGREGATION_SUM = 1; - // `MEAN`: Take the arithmetic mean ("average") of the updates across - // replicas. 
- VARIABLE_AGGREGATION_MEAN = 2; - // `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same - // update, but we only want to perform the update once. Used, e.g., for the - // global step counter. - VARIABLE_AGGREGATION_ONLY_FIRST_REPLICA = 3; -} - -// Protocol buffer representing a Variable. -message VariableDef { - // Name of the variable tensor. - string variable_name = 1; - - // Name of the tensor holding the variable's initial value. - string initial_value_name = 6; - - // Name of the initializer op. - string initializer_name = 2; - - // Name of the snapshot tensor. - string snapshot_name = 3; - - // Support for saving variables as slices of a larger variable. - SaveSliceInfoDef save_slice_info_def = 4; - - // Whether to represent this as a ResourceVariable. - bool is_resource = 5; - - // Whether this variable should be trained. - bool trainable = 7; - - // Indicates when a distributed variable will be synced. - VariableSynchronization synchronization = 8; - - // Indicates how a distributed variable will be aggregated. - VariableAggregation aggregation = 9; -} - -message SaveSliceInfoDef { - // Name of the full variable of which this is a slice. - string full_name = 1; - // Shape of the full variable. - repeated int64 full_shape = 2; - // Offset of this variable into the full variable. - repeated int64 var_offset = 3; - // Shape of this variable. - repeated int64 var_shape = 4; -} diff --git a/protos/versions.proto b/protos/versions.proto deleted file mode 100644 index dd2ec552..00000000 --- a/protos/versions.proto +++ /dev/null @@ -1,32 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). - repeated int32 bad_consumers = 3; -}; diff --git a/redis_consumer/__init__.py b/redis_consumer/__init__.py index 40aaeee7..2192c61c 100644 --- a/redis_consumer/__init__.py +++ b/redis_consumer/__init__.py @@ -29,7 +29,6 @@ from redis_consumer import consumers from redis_consumer import grpc_clients -from redis_consumer import pbs from redis_consumer import redis from redis_consumer import settings from redis_consumer import storage diff --git a/redis_consumer/pbs/__init__.py b/redis_consumer/pbs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/redis_consumer/pbs/attr_value_pb2.py b/redis_consumer/pbs/attr_value_pb2.py deleted file mode 100644 index 3262caf7..00000000 --- a/redis_consumer/pbs/attr_value_pb2.py +++ /dev/null @@ -1,365 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: attr_value.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_pb2 as tensor__pb2 -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='attr_value.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\017AttrValueProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x10\x61ttr_value.proto\x12\ntensorflow\x1a\x0ctensor.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\"\xa6\x04\n\tAttrValue\x12\x0b\n\x01s\x18\x02 \x01(\x0cH\x00\x12\x0b\n\x01i\x18\x03 \x01(\x03H\x00\x12\x0b\n\x01\x66\x18\x04 \x01(\x02H\x00\x12\x0b\n\x01\x62\x18\x05 \x01(\x08H\x00\x12$\n\x04type\x18\x06 \x01(\x0e\x32\x14.tensorflow.DataTypeH\x00\x12-\n\x05shape\x18\x07 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProtoH\x00\x12)\n\x06tensor\x18\x08 \x01(\x0b\x32\x17.tensorflow.TensorProtoH\x00\x12/\n\x04list\x18\x01 \x01(\x0b\x32\x1f.tensorflow.AttrValue.ListValueH\x00\x12(\n\x04\x66unc\x18\n \x01(\x0b\x32\x18.tensorflow.NameAttrListH\x00\x12\x15\n\x0bplaceholder\x18\t \x01(\tH\x00\x1a\xe9\x01\n\tListValue\x12\t\n\x01s\x18\x02 \x03(\x0c\x12\r\n\x01i\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\r\n\x01\x66\x18\x04 \x03(\x02\x42\x02\x10\x01\x12\r\n\x01\x62\x18\x05 \x03(\x08\x42\x02\x10\x01\x12&\n\x04type\x18\x06 \x03(\x0e\x32\x14.tensorflow.DataTypeB\x02\x10\x01\x12+\n\x05shape\x18\x07 \x03(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\'\n\x06tensor\x18\x08 \x03(\x0b\x32\x17.tensorflow.TensorProto\x12&\n\x04\x66unc\x18\t \x03(\x0b\x32\x18.tensorflow.NameAttrListB\x07\n\x05value\"\x92\x01\n\x0cNameAttrList\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x30\n\x04\x61ttr\x18\x02 \x03(\x0b\x32\".tensorflow.NameAttrList.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x42o\n\x18org.tensorflow.frameworkB\x0f\x41ttrValueProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[tensor__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) - - - - -_ATTRVALUE_LISTVALUE = _descriptor.Descriptor( - name='ListValue', - full_name='tensorflow.AttrValue.ListValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='s', full_name='tensorflow.AttrValue.ListValue.s', index=0, - number=2, type=12, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='i', full_name='tensorflow.AttrValue.ListValue.i', index=1, - number=3, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='f', full_name='tensorflow.AttrValue.ListValue.f', index=2, - number=4, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], 
- message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='b', full_name='tensorflow.AttrValue.ListValue.b', index=3, - number=5, type=8, cpp_type=7, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.AttrValue.ListValue.type', index=4, - number=6, type=14, cpp_type=8, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.AttrValue.ListValue.shape', index=5, - number=7, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor', full_name='tensorflow.AttrValue.ListValue.tensor', index=6, - number=8, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='func', full_name='tensorflow.AttrValue.ListValue.func', index=7, - number=9, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=388, - serialized_end=621, -) - -_ATTRVALUE = _descriptor.Descriptor( - name='AttrValue', - full_name='tensorflow.AttrValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='s', full_name='tensorflow.AttrValue.s', index=0, - number=2, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=b"", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='i', full_name='tensorflow.AttrValue.i', index=1, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='f', full_name='tensorflow.AttrValue.f', index=2, - number=4, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='b', full_name='tensorflow.AttrValue.b', index=3, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - 
_descriptor.FieldDescriptor( - name='type', full_name='tensorflow.AttrValue.type', index=4, - number=6, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.AttrValue.shape', index=5, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor', full_name='tensorflow.AttrValue.tensor', index=6, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='list', full_name='tensorflow.AttrValue.list', index=7, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='func', full_name='tensorflow.AttrValue.func', index=8, - number=10, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='placeholder', full_name='tensorflow.AttrValue.placeholder', index=9, - number=9, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_ATTRVALUE_LISTVALUE, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='value', full_name='tensorflow.AttrValue.value', - index=0, containing_type=None, fields=[]), - ], - serialized_start=80, - serialized_end=630, -) - - -_NAMEATTRLIST_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.NameAttrList.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.NameAttrList.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.NameAttrList.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=713, - serialized_end=779, -) - -_NAMEATTRLIST = 
_descriptor.Descriptor( - name='NameAttrList', - full_name='tensorflow.NameAttrList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.NameAttrList.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.NameAttrList.attr', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_NAMEATTRLIST_ATTRENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=633, - serialized_end=779, -) - -_ATTRVALUE_LISTVALUE.fields_by_name['type'].enum_type = types__pb2._DATATYPE -_ATTRVALUE_LISTVALUE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_ATTRVALUE_LISTVALUE.fields_by_name['tensor'].message_type = tensor__pb2._TENSORPROTO -_ATTRVALUE_LISTVALUE.fields_by_name['func'].message_type = _NAMEATTRLIST -_ATTRVALUE_LISTVALUE.containing_type = _ATTRVALUE -_ATTRVALUE.fields_by_name['type'].enum_type = types__pb2._DATATYPE -_ATTRVALUE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_ATTRVALUE.fields_by_name['tensor'].message_type = tensor__pb2._TENSORPROTO -_ATTRVALUE.fields_by_name['list'].message_type = _ATTRVALUE_LISTVALUE -_ATTRVALUE.fields_by_name['func'].message_type = _NAMEATTRLIST -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['s']) -_ATTRVALUE.fields_by_name['s'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['i']) -_ATTRVALUE.fields_by_name['i'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['f']) -_ATTRVALUE.fields_by_name['f'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['b']) -_ATTRVALUE.fields_by_name['b'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['type']) -_ATTRVALUE.fields_by_name['type'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['shape']) -_ATTRVALUE.fields_by_name['shape'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['tensor']) -_ATTRVALUE.fields_by_name['tensor'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['list']) -_ATTRVALUE.fields_by_name['list'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['func']) -_ATTRVALUE.fields_by_name['func'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['placeholder']) 
-_ATTRVALUE.fields_by_name['placeholder'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_NAMEATTRLIST_ATTRENTRY.fields_by_name['value'].message_type = _ATTRVALUE -_NAMEATTRLIST_ATTRENTRY.containing_type = _NAMEATTRLIST -_NAMEATTRLIST.fields_by_name['attr'].message_type = _NAMEATTRLIST_ATTRENTRY -DESCRIPTOR.message_types_by_name['AttrValue'] = _ATTRVALUE -DESCRIPTOR.message_types_by_name['NameAttrList'] = _NAMEATTRLIST -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -AttrValue = _reflection.GeneratedProtocolMessageType('AttrValue', (_message.Message,), { - - 'ListValue' : _reflection.GeneratedProtocolMessageType('ListValue', (_message.Message,), { - 'DESCRIPTOR' : _ATTRVALUE_LISTVALUE, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.AttrValue.ListValue) - }) - , - 'DESCRIPTOR' : _ATTRVALUE, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.AttrValue) - }) -_sym_db.RegisterMessage(AttrValue) -_sym_db.RegisterMessage(AttrValue.ListValue) - -NameAttrList = _reflection.GeneratedProtocolMessageType('NameAttrList', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _NAMEATTRLIST_ATTRENTRY, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NameAttrList.AttrEntry) - }) - , - 'DESCRIPTOR' : _NAMEATTRLIST, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NameAttrList) - }) -_sym_db.RegisterMessage(NameAttrList) -_sym_db.RegisterMessage(NameAttrList.AttrEntry) - - -DESCRIPTOR._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['i']._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['f']._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['b']._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['type']._options = None -_NAMEATTRLIST_ATTRENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/attr_value_pb2_grpc.py b/redis_consumer/pbs/attr_value_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/attr_value_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/function_pb2.py b/redis_consumer/pbs/function_pb2.py deleted file mode 100644 index 8084de52..00000000 --- a/redis_consumer/pbs/function_pb2.py +++ /dev/null @@ -1,486 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: function.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 -import redis_consumer.pbs.node_def_pb2 as node__def__pb2 -import redis_consumer.pbs.op_def_pb2 as op__def__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='function.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\016FunctionProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0e\x66unction.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\x1a\x0enode_def.proto\x1a\x0cop_def.proto\"j\n\x12\x46unctionDefLibrary\x12)\n\x08\x66unction\x18\x01 \x03(\x0b\x32\x17.tensorflow.FunctionDef\x12)\n\x08gradient\x18\x02 \x03(\x0b\x32\x17.tensorflow.GradientDef\"\xb6\x05\n\x0b\x46unctionDef\x12$\n\tsignature\x18\x01 \x01(\x0b\x32\x11.tensorflow.OpDef\x12/\n\x04\x61ttr\x18\x05 \x03(\x0b\x32!.tensorflow.FunctionDef.AttrEntry\x12\x36\n\x08\x61rg_attr\x18\x07 \x03(\x0b\x32$.tensorflow.FunctionDef.ArgAttrEntry\x12%\n\x08node_def\x18\x03 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12-\n\x03ret\x18\x04 \x03(\x0b\x32 .tensorflow.FunctionDef.RetEntry\x12<\n\x0b\x63ontrol_ret\x18\x06 \x03(\x0b\x32\'.tensorflow.FunctionDef.ControlRetEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1a\x88\x01\n\x08\x41rgAttrs\x12\x38\n\x04\x61ttr\x18\x01 \x03(\x0b\x32*.tensorflow.FunctionDef.ArgAttrs.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1aP\n\x0c\x41rgAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\r\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .tensorflow.FunctionDef.ArgAttrs:\x02\x38\x01\x1a*\n\x08RetEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x31\n\x0f\x43ontrolRetEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01J\x04\x08\x02\x10\x03\";\n\x0bGradientDef\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12\x15\n\rgradient_func\x18\x02 \x01(\tBn\n\x18org.tensorflow.frameworkB\x0e\x46unctionProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[attr__value__pb2.DESCRIPTOR,node__def__pb2.DESCRIPTOR,op__def__pb2.DESCRIPTOR,]) - - - - -_FUNCTIONDEFLIBRARY = _descriptor.Descriptor( - name='FunctionDefLibrary', - full_name='tensorflow.FunctionDefLibrary', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='function', full_name='tensorflow.FunctionDefLibrary.function', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='gradient', full_name='tensorflow.FunctionDefLibrary.gradient', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - 
extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=78, - serialized_end=184, -) - - -_FUNCTIONDEF_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.FunctionDef.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=493, - serialized_end=559, -) - -_FUNCTIONDEF_ARGATTRS_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=493, - serialized_end=559, -) - -_FUNCTIONDEF_ARGATTRS = _descriptor.Descriptor( - name='ArgAttrs', - full_name='tensorflow.FunctionDef.ArgAttrs', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.FunctionDef.ArgAttrs.attr', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_FUNCTIONDEF_ARGATTRS_ATTRENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=562, - serialized_end=698, -) - -_FUNCTIONDEF_ARGATTRENTRY = _descriptor.Descriptor( - name='ArgAttrEntry', - full_name='tensorflow.FunctionDef.ArgAttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', 
full_name='tensorflow.FunctionDef.ArgAttrEntry.key', index=0, - number=1, type=13, cpp_type=3, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.ArgAttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=700, - serialized_end=780, -) - -_FUNCTIONDEF_RETENTRY = _descriptor.Descriptor( - name='RetEntry', - full_name='tensorflow.FunctionDef.RetEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.RetEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.RetEntry.value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=782, - serialized_end=824, -) - -_FUNCTIONDEF_CONTROLRETENTRY = _descriptor.Descriptor( - name='ControlRetEntry', - full_name='tensorflow.FunctionDef.ControlRetEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.ControlRetEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.ControlRetEntry.value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=826, - serialized_end=875, -) - -_FUNCTIONDEF = _descriptor.Descriptor( - name='FunctionDef', - full_name='tensorflow.FunctionDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='signature', full_name='tensorflow.FunctionDef.signature', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, 
enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.FunctionDef.attr', index=1, - number=5, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='arg_attr', full_name='tensorflow.FunctionDef.arg_attr', index=2, - number=7, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='node_def', full_name='tensorflow.FunctionDef.node_def', index=3, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='ret', full_name='tensorflow.FunctionDef.ret', index=4, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='control_ret', full_name='tensorflow.FunctionDef.control_ret', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_FUNCTIONDEF_ATTRENTRY, _FUNCTIONDEF_ARGATTRS, _FUNCTIONDEF_ARGATTRENTRY, _FUNCTIONDEF_RETENTRY, _FUNCTIONDEF_CONTROLRETENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=187, - serialized_end=881, -) - - -_GRADIENTDEF = _descriptor.Descriptor( - name='GradientDef', - full_name='tensorflow.GradientDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='function_name', full_name='tensorflow.GradientDef.function_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='gradient_func', full_name='tensorflow.GradientDef.gradient_func', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=883, - serialized_end=942, -) - -_FUNCTIONDEFLIBRARY.fields_by_name['function'].message_type = _FUNCTIONDEF -_FUNCTIONDEFLIBRARY.fields_by_name['gradient'].message_type = _GRADIENTDEF -_FUNCTIONDEF_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE 
-_FUNCTIONDEF_ATTRENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_ARGATTRS_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE -_FUNCTIONDEF_ARGATTRS_ATTRENTRY.containing_type = _FUNCTIONDEF_ARGATTRS -_FUNCTIONDEF_ARGATTRS.fields_by_name['attr'].message_type = _FUNCTIONDEF_ARGATTRS_ATTRENTRY -_FUNCTIONDEF_ARGATTRS.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_ARGATTRENTRY.fields_by_name['value'].message_type = _FUNCTIONDEF_ARGATTRS -_FUNCTIONDEF_ARGATTRENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_RETENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_CONTROLRETENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF.fields_by_name['signature'].message_type = op__def__pb2._OPDEF -_FUNCTIONDEF.fields_by_name['attr'].message_type = _FUNCTIONDEF_ATTRENTRY -_FUNCTIONDEF.fields_by_name['arg_attr'].message_type = _FUNCTIONDEF_ARGATTRENTRY -_FUNCTIONDEF.fields_by_name['node_def'].message_type = node__def__pb2._NODEDEF -_FUNCTIONDEF.fields_by_name['ret'].message_type = _FUNCTIONDEF_RETENTRY -_FUNCTIONDEF.fields_by_name['control_ret'].message_type = _FUNCTIONDEF_CONTROLRETENTRY -DESCRIPTOR.message_types_by_name['FunctionDefLibrary'] = _FUNCTIONDEFLIBRARY -DESCRIPTOR.message_types_by_name['FunctionDef'] = _FUNCTIONDEF -DESCRIPTOR.message_types_by_name['GradientDef'] = _GRADIENTDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -FunctionDefLibrary = _reflection.GeneratedProtocolMessageType('FunctionDefLibrary', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEFLIBRARY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDefLibrary) - }) -_sym_db.RegisterMessage(FunctionDefLibrary) - -FunctionDef = _reflection.GeneratedProtocolMessageType('FunctionDef', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_ATTRENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.AttrEntry) - }) - , - - 'ArgAttrs' : _reflection.GeneratedProtocolMessageType('ArgAttrs', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRS_ATTRENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrs.AttrEntry) - }) - , - 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRS, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrs) - }) - , - - 'ArgAttrEntry' : _reflection.GeneratedProtocolMessageType('ArgAttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrEntry) - }) - , - - 'RetEntry' : _reflection.GeneratedProtocolMessageType('RetEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_RETENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.RetEntry) - }) - , - - 'ControlRetEntry' : _reflection.GeneratedProtocolMessageType('ControlRetEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_CONTROLRETENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ControlRetEntry) - }) - , - 'DESCRIPTOR' : _FUNCTIONDEF, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef) - }) -_sym_db.RegisterMessage(FunctionDef) -_sym_db.RegisterMessage(FunctionDef.AttrEntry) 
-_sym_db.RegisterMessage(FunctionDef.ArgAttrs) -_sym_db.RegisterMessage(FunctionDef.ArgAttrs.AttrEntry) -_sym_db.RegisterMessage(FunctionDef.ArgAttrEntry) -_sym_db.RegisterMessage(FunctionDef.RetEntry) -_sym_db.RegisterMessage(FunctionDef.ControlRetEntry) - -GradientDef = _reflection.GeneratedProtocolMessageType('GradientDef', (_message.Message,), { - 'DESCRIPTOR' : _GRADIENTDEF, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.GradientDef) - }) -_sym_db.RegisterMessage(GradientDef) - - -DESCRIPTOR._options = None -_FUNCTIONDEF_ATTRENTRY._options = None -_FUNCTIONDEF_ARGATTRS_ATTRENTRY._options = None -_FUNCTIONDEF_ARGATTRENTRY._options = None -_FUNCTIONDEF_RETENTRY._options = None -_FUNCTIONDEF_CONTROLRETENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/function_pb2_grpc.py b/redis_consumer/pbs/function_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/function_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/get_model_metadata_pb2.py b/redis_consumer/pbs/get_model_metadata_pb2.py deleted file mode 100644 index 1695ddee..00000000 --- a/redis_consumer/pbs/get_model_metadata_pb2.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: get_model_metadata.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -import redis_consumer.pbs.meta_graph_pb2 as meta__graph__pb2 -import redis_consumer.pbs.model_pb2 as model__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='get_model_metadata.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x18get_model_metadata.proto\x12\x12tensorflow.serving\x1a\x19google/protobuf/any.proto\x1a\x10meta_graph.proto\x1a\x0bmodel.proto\"\xae\x01\n\x0fSignatureDefMap\x12L\n\rsignature_def\x18\x01 \x03(\x0b\x32\x35.tensorflow.serving.SignatureDefMap.SignatureDefEntry\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"d\n\x17GetModelMetadataRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x16\n\x0emetadata_field\x18\x02 \x03(\t\"\xe2\x01\n\x18GetModelMetadataResponse\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12L\n\x08metadata\x18\x02 \x03(\x0b\x32:.tensorflow.serving.GetModelMetadataResponse.MetadataEntry\x1a\x45\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,meta__graph__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) - - - - -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY = _descriptor.Descriptor( - name='SignatureDefEntry', - full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', 
full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=204, - serialized_end=281, -) - -_SIGNATUREDEFMAP = _descriptor.Descriptor( - name='SignatureDefMap', - full_name='tensorflow.serving.SignatureDefMap', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='signature_def', full_name='tensorflow.serving.SignatureDefMap.signature_def', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SIGNATUREDEFMAP_SIGNATUREDEFENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=107, - serialized_end=281, -) - - -_GETMODELMETADATAREQUEST = _descriptor.Descriptor( - name='GetModelMetadataRequest', - full_name='tensorflow.serving.GetModelMetadataRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.GetModelMetadataRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata_field', full_name='tensorflow.serving.GetModelMetadataRequest.metadata_field', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=283, - serialized_end=383, -) - - -_GETMODELMETADATARESPONSE_METADATAENTRY = _descriptor.Descriptor( - name='MetadataEntry', - full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - 
_descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=543, - serialized_end=612, -) - -_GETMODELMETADATARESPONSE = _descriptor.Descriptor( - name='GetModelMetadataResponse', - full_name='tensorflow.serving.GetModelMetadataResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.GetModelMetadataResponse.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata', full_name='tensorflow.serving.GetModelMetadataResponse.metadata', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_GETMODELMETADATARESPONSE_METADATAENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=386, - serialized_end=612, -) - -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = meta__graph__pb2._SIGNATUREDEF -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.containing_type = _SIGNATUREDEFMAP -_SIGNATUREDEFMAP.fields_by_name['signature_def'].message_type = _SIGNATUREDEFMAP_SIGNATUREDEFENTRY -_GETMODELMETADATAREQUEST.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_GETMODELMETADATARESPONSE_METADATAENTRY.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_GETMODELMETADATARESPONSE_METADATAENTRY.containing_type = _GETMODELMETADATARESPONSE -_GETMODELMETADATARESPONSE.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_GETMODELMETADATARESPONSE.fields_by_name['metadata'].message_type = _GETMODELMETADATARESPONSE_METADATAENTRY -DESCRIPTOR.message_types_by_name['SignatureDefMap'] = _SIGNATUREDEFMAP -DESCRIPTOR.message_types_by_name['GetModelMetadataRequest'] = _GETMODELMETADATAREQUEST -DESCRIPTOR.message_types_by_name['GetModelMetadataResponse'] = _GETMODELMETADATARESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -SignatureDefMap = _reflection.GeneratedProtocolMessageType('SignatureDefMap', (_message.Message,), { - - 'SignatureDefEntry' : _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), { - 'DESCRIPTOR' : _SIGNATUREDEFMAP_SIGNATUREDEFENTRY, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap.SignatureDefEntry) - }) - , - 'DESCRIPTOR' : _SIGNATUREDEFMAP, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap) - }) -_sym_db.RegisterMessage(SignatureDefMap) -_sym_db.RegisterMessage(SignatureDefMap.SignatureDefEntry) - 
-GetModelMetadataRequest = _reflection.GeneratedProtocolMessageType('GetModelMetadataRequest', (_message.Message,), { - 'DESCRIPTOR' : _GETMODELMETADATAREQUEST, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataRequest) - }) -_sym_db.RegisterMessage(GetModelMetadataRequest) - -GetModelMetadataResponse = _reflection.GeneratedProtocolMessageType('GetModelMetadataResponse', (_message.Message,), { - - 'MetadataEntry' : _reflection.GeneratedProtocolMessageType('MetadataEntry', (_message.Message,), { - 'DESCRIPTOR' : _GETMODELMETADATARESPONSE_METADATAENTRY, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse.MetadataEntry) - }) - , - 'DESCRIPTOR' : _GETMODELMETADATARESPONSE, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse) - }) -_sym_db.RegisterMessage(GetModelMetadataResponse) -_sym_db.RegisterMessage(GetModelMetadataResponse.MetadataEntry) - - -DESCRIPTOR._options = None -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY._options = None -_GETMODELMETADATARESPONSE_METADATAENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/get_model_metadata_pb2_grpc.py b/redis_consumer/pbs/get_model_metadata_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/get_model_metadata_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/graph_pb2.py b/redis_consumer/pbs/graph_pb2.py deleted file mode 100644 index 78131df8..00000000 --- a/redis_consumer/pbs/graph_pb2.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: graph.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.node_def_pb2 as node__def__pb2 -import redis_consumer.pbs.function_pb2 as function__pb2 -import redis_consumer.pbs.versions_pb2 as versions__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='graph.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\013GraphProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0bgraph.proto\x12\ntensorflow\x1a\x0enode_def.proto\x1a\x0e\x66unction.proto\x1a\x0eversions.proto\"\x9d\x01\n\x08GraphDef\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12(\n\x08versions\x18\x04 \x01(\x0b\x32\x16.tensorflow.VersionDef\x12\x13\n\x07version\x18\x03 \x01(\x05\x42\x02\x18\x01\x12/\n\x07library\x18\x02 \x01(\x0b\x32\x1e.tensorflow.FunctionDefLibraryBk\n\x18org.tensorflow.frameworkB\x0bGraphProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[node__def__pb2.DESCRIPTOR,function__pb2.DESCRIPTOR,versions__pb2.DESCRIPTOR,]) - - - - -_GRAPHDEF = _descriptor.Descriptor( - name='GraphDef', - full_name='tensorflow.GraphDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='node', full_name='tensorflow.GraphDef.node', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='versions', full_name='tensorflow.GraphDef.versions', index=1, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.GraphDef.version', index=2, - number=3, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\030\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='library', full_name='tensorflow.GraphDef.library', index=3, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=76, - serialized_end=233, -) - -_GRAPHDEF.fields_by_name['node'].message_type = node__def__pb2._NODEDEF -_GRAPHDEF.fields_by_name['versions'].message_type = versions__pb2._VERSIONDEF -_GRAPHDEF.fields_by_name['library'].message_type = function__pb2._FUNCTIONDEFLIBRARY -DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', 
(_message.Message,), { - 'DESCRIPTOR' : _GRAPHDEF, - '__module__' : 'graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.GraphDef) - }) -_sym_db.RegisterMessage(GraphDef) - - -DESCRIPTOR._options = None -_GRAPHDEF.fields_by_name['version']._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/graph_pb2_grpc.py b/redis_consumer/pbs/graph_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/graph_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/meta_graph_pb2.py b/redis_consumer/pbs/meta_graph_pb2.py deleted file mode 100644 index c8e8eebd..00000000 --- a/redis_consumer/pbs/meta_graph_pb2.py +++ /dev/null @@ -1,1031 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: meta_graph.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -import redis_consumer.pbs.graph_pb2 as graph__pb2 -import redis_consumer.pbs.op_def_pb2 as op__def__pb2 -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 -import redis_consumer.pbs.saved_object_graph_pb2 as saved__object__graph__pb2 -import redis_consumer.pbs.saver_pb2 as saver__pb2 -import redis_consumer.pbs.struct_pb2 as struct__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='meta_graph.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\017MetaGraphProtosP\001ZHgithub.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto\370\001\001', - serialized_pb=b'\n\x10meta_graph.proto\x12\ntensorflow\x1a\x19google/protobuf/any.proto\x1a\x0bgraph.proto\x1a\x0cop_def.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\x1a\x18saved_object_graph.proto\x1a\x0bsaver.proto\x1a\x0cstruct.proto\"\xa8\x07\n\x0cMetaGraphDef\x12;\n\rmeta_info_def\x18\x01 \x01(\x0b\x32$.tensorflow.MetaGraphDef.MetaInfoDef\x12\'\n\tgraph_def\x18\x02 \x01(\x0b\x32\x14.tensorflow.GraphDef\x12\'\n\tsaver_def\x18\x03 \x01(\x0b\x32\x14.tensorflow.SaverDef\x12\x43\n\x0e\x63ollection_def\x18\x04 \x03(\x0b\x32+.tensorflow.MetaGraphDef.CollectionDefEntry\x12\x41\n\rsignature_def\x18\x05 \x03(\x0b\x32*.tensorflow.MetaGraphDef.SignatureDefEntry\x12\x30\n\x0e\x61sset_file_def\x18\x06 \x03(\x0b\x32\x18.tensorflow.AssetFileDef\x12\x36\n\x10object_graph_def\x18\x07 \x01(\x0b\x32\x1c.tensorflow.SavedObjectGraph\x1a\xf6\x02\n\x0bMetaInfoDef\x12\x1a\n\x12meta_graph_version\x18\x01 \x01(\t\x12,\n\x10stripped_op_list\x18\x02 \x01(\x0b\x32\x12.tensorflow.OpList\x12&\n\x08\x61ny_info\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x1a\n\x12tensorflow_version\x18\x05 \x01(\t\x12\x1e\n\x16tensorflow_git_version\x18\x06 \x01(\t\x12\x1e\n\x16stripped_default_attrs\x18\x07 \x01(\x08\x12S\n\x10\x66unction_aliases\x18\x08 \x03(\x0b\x32\x39.tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry\x1a\x36\n\x14\x46unctionAliasesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1aO\n\x12\x43ollectionDefEntry\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.tensorflow.CollectionDef:\x02\x38\x01\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"\xdf\x03\n\rCollectionDef\x12\x37\n\tnode_list\x18\x01 \x01(\x0b\x32\".tensorflow.CollectionDef.NodeListH\x00\x12\x39\n\nbytes_list\x18\x02 \x01(\x0b\x32#.tensorflow.CollectionDef.BytesListH\x00\x12\x39\n\nint64_list\x18\x03 \x01(\x0b\x32#.tensorflow.CollectionDef.Int64ListH\x00\x12\x39\n\nfloat_list\x18\x04 \x01(\x0b\x32#.tensorflow.CollectionDef.FloatListH\x00\x12\x35\n\x08\x61ny_list\x18\x05 \x01(\x0b\x32!.tensorflow.CollectionDef.AnyListH\x00\x1a\x19\n\x08NodeList\x12\r\n\x05value\x18\x01 \x03(\t\x1a\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\x1a\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a.\n\x07\x41nyList\x12#\n\x05value\x18\x01 \x03(\x0b\x32\x14.google.protobuf.AnyB\x06\n\x04kind\"\xd1\x03\n\nTensorInfo\x12\x0e\n\x04name\x18\x01 \x01(\tH\x00\x12\x36\n\ncoo_sparse\x18\x04 \x01(\x0b\x32 .tensorflow.TensorInfo.CooSparseH\x00\x12\x42\n\x10\x63omposite_tensor\x18\x05 \x01(\x0b\x32&.tensorflow.TensorInfo.CompositeTensorH\x00\x12#\n\x05\x64type\x18\x02 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x32\n\x0ctensor_shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x1a\x65\n\tCooSparse\x12\x1a\n\x12values_tensor_name\x18\x01 \x01(\t\x12\x1b\n\x13indices_tensor_name\x18\x02 \x01(\t\x12\x1f\n\x17\x64\x65nse_shape_tensor_name\x18\x03 \x01(\t\x1ak\n\x0f\x43ompositeTensor\x12,\n\ttype_spec\x18\x01 \x01(\x0b\x32\x19.tensorflow.TypeSpecProto\x12*\n\ncomponents\x18\x02 \x03(\x0b\x32\x16.tensorflow.TensorInfoB\n\n\x08\x65ncoding\"\xa0\x02\n\x0cSignatureDef\x12\x34\n\x06inputs\x18\x01 \x03(\x0b\x32$.tensorflow.SignatureDef.InputsEntry\x12\x36\n\x07outputs\x18\x02 \x03(\x0b\x32%.tensorflow.SignatureDef.OutputsEntry\x12\x13\n\x0bmethod_name\x18\x03 \x01(\t\x1a\x45\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.TensorInfo:\x02\x38\x01\x1a\x46\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.TensorInfo:\x02\x38\x01\"M\n\x0c\x41ssetFileDef\x12+\n\x0btensor_info\x18\x01 \x01(\x0b\x32\x16.tensorflow.TensorInfo\x12\x10\n\x08\x66ilename\x18\x02 \x01(\tBz\n\x18org.tensorflow.frameworkB\x0fMetaGraphProtosP\x01ZHgithub.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto\xf8\x01\x01\x62\x06proto3' - , - dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,graph__pb2.DESCRIPTOR,op__def__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,saved__object__graph__pb2.DESCRIPTOR,saver__pb2.DESCRIPTOR,struct__pb2.DESCRIPTOR,]) - - - - -_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY = _descriptor.Descriptor( - name='FunctionAliasesEntry', - full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', 
full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry.value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=895, - serialized_end=949, -) - -_METAGRAPHDEF_METAINFODEF = _descriptor.Descriptor( - name='MetaInfoDef', - full_name='tensorflow.MetaGraphDef.MetaInfoDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='meta_graph_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.meta_graph_version', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='stripped_op_list', full_name='tensorflow.MetaGraphDef.MetaInfoDef.stripped_op_list', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='any_info', full_name='tensorflow.MetaGraphDef.MetaInfoDef.any_info', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tags', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tags', index=3, - number=4, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensorflow_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tensorflow_version', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensorflow_git_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tensorflow_git_version', index=5, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='stripped_default_attrs', full_name='tensorflow.MetaGraphDef.MetaInfoDef.stripped_default_attrs', index=6, - number=7, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='function_aliases', full_name='tensorflow.MetaGraphDef.MetaInfoDef.function_aliases', index=7, - number=8, type=11, cpp_type=10, 
label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=575, - serialized_end=949, -) - -_METAGRAPHDEF_COLLECTIONDEFENTRY = _descriptor.Descriptor( - name='CollectionDefEntry', - full_name='tensorflow.MetaGraphDef.CollectionDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.MetaGraphDef.CollectionDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.MetaGraphDef.CollectionDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=951, - serialized_end=1030, -) - -_METAGRAPHDEF_SIGNATUREDEFENTRY = _descriptor.Descriptor( - name='SignatureDefEntry', - full_name='tensorflow.MetaGraphDef.SignatureDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.MetaGraphDef.SignatureDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.MetaGraphDef.SignatureDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1032, - serialized_end=1109, -) - -_METAGRAPHDEF = _descriptor.Descriptor( - name='MetaGraphDef', - full_name='tensorflow.MetaGraphDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='meta_info_def', full_name='tensorflow.MetaGraphDef.meta_info_def', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='graph_def', full_name='tensorflow.MetaGraphDef.graph_def', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, 
enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='saver_def', full_name='tensorflow.MetaGraphDef.saver_def', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='collection_def', full_name='tensorflow.MetaGraphDef.collection_def', index=3, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='signature_def', full_name='tensorflow.MetaGraphDef.signature_def', index=4, - number=5, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='asset_file_def', full_name='tensorflow.MetaGraphDef.asset_file_def', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='object_graph_def', full_name='tensorflow.MetaGraphDef.object_graph_def', index=6, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_METAGRAPHDEF_METAINFODEF, _METAGRAPHDEF_COLLECTIONDEFENTRY, _METAGRAPHDEF_SIGNATUREDEFENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=173, - serialized_end=1109, -) - - -_COLLECTIONDEF_NODELIST = _descriptor.Descriptor( - name='NodeList', - full_name='tensorflow.CollectionDef.NodeList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.NodeList.value', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1418, - serialized_end=1443, -) - -_COLLECTIONDEF_BYTESLIST = _descriptor.Descriptor( - name='BytesList', - full_name='tensorflow.CollectionDef.BytesList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.BytesList.value', index=0, - number=1, type=12, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - 
enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1445, - serialized_end=1471, -) - -_COLLECTIONDEF_INT64LIST = _descriptor.Descriptor( - name='Int64List', - full_name='tensorflow.CollectionDef.Int64List', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.Int64List.value', index=0, - number=1, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1473, - serialized_end=1503, -) - -_COLLECTIONDEF_FLOATLIST = _descriptor.Descriptor( - name='FloatList', - full_name='tensorflow.CollectionDef.FloatList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.FloatList.value', index=0, - number=1, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1505, - serialized_end=1535, -) - -_COLLECTIONDEF_ANYLIST = _descriptor.Descriptor( - name='AnyList', - full_name='tensorflow.CollectionDef.AnyList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.AnyList.value', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1537, - serialized_end=1583, -) - -_COLLECTIONDEF = _descriptor.Descriptor( - name='CollectionDef', - full_name='tensorflow.CollectionDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='node_list', full_name='tensorflow.CollectionDef.node_list', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bytes_list', full_name='tensorflow.CollectionDef.bytes_list', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='int64_list', full_name='tensorflow.CollectionDef.int64_list', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - 
message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='float_list', full_name='tensorflow.CollectionDef.float_list', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='any_list', full_name='tensorflow.CollectionDef.any_list', index=4, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_COLLECTIONDEF_NODELIST, _COLLECTIONDEF_BYTESLIST, _COLLECTIONDEF_INT64LIST, _COLLECTIONDEF_FLOATLIST, _COLLECTIONDEF_ANYLIST, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.CollectionDef.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1112, - serialized_end=1591, -) - - -_TENSORINFO_COOSPARSE = _descriptor.Descriptor( - name='CooSparse', - full_name='tensorflow.TensorInfo.CooSparse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.values_tensor_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='indices_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.indices_tensor_name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dense_shape_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.dense_shape_tensor_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1837, - serialized_end=1938, -) - -_TENSORINFO_COMPOSITETENSOR = _descriptor.Descriptor( - name='CompositeTensor', - full_name='tensorflow.TensorInfo.CompositeTensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type_spec', full_name='tensorflow.TensorInfo.CompositeTensor.type_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='components', 
full_name='tensorflow.TensorInfo.CompositeTensor.components', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1940, - serialized_end=2047, -) - -_TENSORINFO = _descriptor.Descriptor( - name='TensorInfo', - full_name='tensorflow.TensorInfo', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.TensorInfo.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='coo_sparse', full_name='tensorflow.TensorInfo.coo_sparse', index=1, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='composite_tensor', full_name='tensorflow.TensorInfo.composite_tensor', index=2, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.TensorInfo.dtype', index=3, - number=2, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor_shape', full_name='tensorflow.TensorInfo.tensor_shape', index=4, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_TENSORINFO_COOSPARSE, _TENSORINFO_COMPOSITETENSOR, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='encoding', full_name='tensorflow.TensorInfo.encoding', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1594, - serialized_end=2059, -) - - -_SIGNATUREDEF_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.SignatureDef.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.SignatureDef.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.SignatureDef.InputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, 
- has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2209, - serialized_end=2278, -) - -_SIGNATUREDEF_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.SignatureDef.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.SignatureDef.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.SignatureDef.OutputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2280, - serialized_end=2350, -) - -_SIGNATUREDEF = _descriptor.Descriptor( - name='SignatureDef', - full_name='tensorflow.SignatureDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.SignatureDef.inputs', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.SignatureDef.outputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='method_name', full_name='tensorflow.SignatureDef.method_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SIGNATUREDEF_INPUTSENTRY, _SIGNATUREDEF_OUTPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2062, - serialized_end=2350, -) - - -_ASSETFILEDEF = _descriptor.Descriptor( - name='AssetFileDef', - full_name='tensorflow.AssetFileDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tensor_info', full_name='tensorflow.AssetFileDef.tensor_info', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - 
serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='filename', full_name='tensorflow.AssetFileDef.filename', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2352, - serialized_end=2429, -) - -_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY.containing_type = _METAGRAPHDEF_METAINFODEF -_METAGRAPHDEF_METAINFODEF.fields_by_name['stripped_op_list'].message_type = op__def__pb2._OPLIST -_METAGRAPHDEF_METAINFODEF.fields_by_name['any_info'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_METAGRAPHDEF_METAINFODEF.fields_by_name['function_aliases'].message_type = _METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY -_METAGRAPHDEF_METAINFODEF.containing_type = _METAGRAPHDEF -_METAGRAPHDEF_COLLECTIONDEFENTRY.fields_by_name['value'].message_type = _COLLECTIONDEF -_METAGRAPHDEF_COLLECTIONDEFENTRY.containing_type = _METAGRAPHDEF -_METAGRAPHDEF_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = _SIGNATUREDEF -_METAGRAPHDEF_SIGNATUREDEFENTRY.containing_type = _METAGRAPHDEF -_METAGRAPHDEF.fields_by_name['meta_info_def'].message_type = _METAGRAPHDEF_METAINFODEF -_METAGRAPHDEF.fields_by_name['graph_def'].message_type = graph__pb2._GRAPHDEF -_METAGRAPHDEF.fields_by_name['saver_def'].message_type = saver__pb2._SAVERDEF -_METAGRAPHDEF.fields_by_name['collection_def'].message_type = _METAGRAPHDEF_COLLECTIONDEFENTRY -_METAGRAPHDEF.fields_by_name['signature_def'].message_type = _METAGRAPHDEF_SIGNATUREDEFENTRY -_METAGRAPHDEF.fields_by_name['asset_file_def'].message_type = _ASSETFILEDEF -_METAGRAPHDEF.fields_by_name['object_graph_def'].message_type = saved__object__graph__pb2._SAVEDOBJECTGRAPH -_COLLECTIONDEF_NODELIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_BYTESLIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_INT64LIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_FLOATLIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_ANYLIST.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_COLLECTIONDEF_ANYLIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF.fields_by_name['node_list'].message_type = _COLLECTIONDEF_NODELIST -_COLLECTIONDEF.fields_by_name['bytes_list'].message_type = _COLLECTIONDEF_BYTESLIST -_COLLECTIONDEF.fields_by_name['int64_list'].message_type = _COLLECTIONDEF_INT64LIST -_COLLECTIONDEF.fields_by_name['float_list'].message_type = _COLLECTIONDEF_FLOATLIST -_COLLECTIONDEF.fields_by_name['any_list'].message_type = _COLLECTIONDEF_ANYLIST -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['node_list']) -_COLLECTIONDEF.fields_by_name['node_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['bytes_list']) -_COLLECTIONDEF.fields_by_name['bytes_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['int64_list']) -_COLLECTIONDEF.fields_by_name['int64_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['float_list']) 
-_COLLECTIONDEF.fields_by_name['float_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['any_list']) -_COLLECTIONDEF.fields_by_name['any_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_TENSORINFO_COOSPARSE.containing_type = _TENSORINFO -_TENSORINFO_COMPOSITETENSOR.fields_by_name['type_spec'].message_type = struct__pb2._TYPESPECPROTO -_TENSORINFO_COMPOSITETENSOR.fields_by_name['components'].message_type = _TENSORINFO -_TENSORINFO_COMPOSITETENSOR.containing_type = _TENSORINFO -_TENSORINFO.fields_by_name['coo_sparse'].message_type = _TENSORINFO_COOSPARSE -_TENSORINFO.fields_by_name['composite_tensor'].message_type = _TENSORINFO_COMPOSITETENSOR -_TENSORINFO.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE -_TENSORINFO.fields_by_name['tensor_shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_TENSORINFO.oneofs_by_name['encoding'].fields.append( - _TENSORINFO.fields_by_name['name']) -_TENSORINFO.fields_by_name['name'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] -_TENSORINFO.oneofs_by_name['encoding'].fields.append( - _TENSORINFO.fields_by_name['coo_sparse']) -_TENSORINFO.fields_by_name['coo_sparse'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] -_TENSORINFO.oneofs_by_name['encoding'].fields.append( - _TENSORINFO.fields_by_name['composite_tensor']) -_TENSORINFO.fields_by_name['composite_tensor'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] -_SIGNATUREDEF_INPUTSENTRY.fields_by_name['value'].message_type = _TENSORINFO -_SIGNATUREDEF_INPUTSENTRY.containing_type = _SIGNATUREDEF -_SIGNATUREDEF_OUTPUTSENTRY.fields_by_name['value'].message_type = _TENSORINFO -_SIGNATUREDEF_OUTPUTSENTRY.containing_type = _SIGNATUREDEF -_SIGNATUREDEF.fields_by_name['inputs'].message_type = _SIGNATUREDEF_INPUTSENTRY -_SIGNATUREDEF.fields_by_name['outputs'].message_type = _SIGNATUREDEF_OUTPUTSENTRY -_ASSETFILEDEF.fields_by_name['tensor_info'].message_type = _TENSORINFO -DESCRIPTOR.message_types_by_name['MetaGraphDef'] = _METAGRAPHDEF -DESCRIPTOR.message_types_by_name['CollectionDef'] = _COLLECTIONDEF -DESCRIPTOR.message_types_by_name['TensorInfo'] = _TENSORINFO -DESCRIPTOR.message_types_by_name['SignatureDef'] = _SIGNATUREDEF -DESCRIPTOR.message_types_by_name['AssetFileDef'] = _ASSETFILEDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -MetaGraphDef = _reflection.GeneratedProtocolMessageType('MetaGraphDef', (_message.Message,), { - - 'MetaInfoDef' : _reflection.GeneratedProtocolMessageType('MetaInfoDef', (_message.Message,), { - - 'FunctionAliasesEntry' : _reflection.GeneratedProtocolMessageType('FunctionAliasesEntry', (_message.Message,), { - 'DESCRIPTOR' : _METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry) - }) - , - 'DESCRIPTOR' : _METAGRAPHDEF_METAINFODEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.MetaInfoDef) - }) - , - - 'CollectionDefEntry' : _reflection.GeneratedProtocolMessageType('CollectionDefEntry', (_message.Message,), { - 'DESCRIPTOR' : _METAGRAPHDEF_COLLECTIONDEFENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.CollectionDefEntry) - }) - , - - 'SignatureDefEntry' : _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), { - 'DESCRIPTOR' : _METAGRAPHDEF_SIGNATUREDEFENTRY, - 
'__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.SignatureDefEntry) - }) - , - 'DESCRIPTOR' : _METAGRAPHDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef) - }) -_sym_db.RegisterMessage(MetaGraphDef) -_sym_db.RegisterMessage(MetaGraphDef.MetaInfoDef) -_sym_db.RegisterMessage(MetaGraphDef.MetaInfoDef.FunctionAliasesEntry) -_sym_db.RegisterMessage(MetaGraphDef.CollectionDefEntry) -_sym_db.RegisterMessage(MetaGraphDef.SignatureDefEntry) - -CollectionDef = _reflection.GeneratedProtocolMessageType('CollectionDef', (_message.Message,), { - - 'NodeList' : _reflection.GeneratedProtocolMessageType('NodeList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_NODELIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.NodeList) - }) - , - - 'BytesList' : _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_BYTESLIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.BytesList) - }) - , - - 'Int64List' : _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_INT64LIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.Int64List) - }) - , - - 'FloatList' : _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_FLOATLIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.FloatList) - }) - , - - 'AnyList' : _reflection.GeneratedProtocolMessageType('AnyList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_ANYLIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.AnyList) - }) - , - 'DESCRIPTOR' : _COLLECTIONDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef) - }) -_sym_db.RegisterMessage(CollectionDef) -_sym_db.RegisterMessage(CollectionDef.NodeList) -_sym_db.RegisterMessage(CollectionDef.BytesList) -_sym_db.RegisterMessage(CollectionDef.Int64List) -_sym_db.RegisterMessage(CollectionDef.FloatList) -_sym_db.RegisterMessage(CollectionDef.AnyList) - -TensorInfo = _reflection.GeneratedProtocolMessageType('TensorInfo', (_message.Message,), { - - 'CooSparse' : _reflection.GeneratedProtocolMessageType('CooSparse', (_message.Message,), { - 'DESCRIPTOR' : _TENSORINFO_COOSPARSE, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo.CooSparse) - }) - , - - 'CompositeTensor' : _reflection.GeneratedProtocolMessageType('CompositeTensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSORINFO_COMPOSITETENSOR, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo.CompositeTensor) - }) - , - 'DESCRIPTOR' : _TENSORINFO, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo) - }) -_sym_db.RegisterMessage(TensorInfo) -_sym_db.RegisterMessage(TensorInfo.CooSparse) -_sym_db.RegisterMessage(TensorInfo.CompositeTensor) - -SignatureDef = _reflection.GeneratedProtocolMessageType('SignatureDef', (_message.Message,), { - - 'InputsEntry' : _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _SIGNATUREDEF_INPUTSENTRY, - '__module__' : 'meta_graph_pb2' - # 
@@protoc_insertion_point(class_scope:tensorflow.SignatureDef.InputsEntry) - }) - , - - 'OutputsEntry' : _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _SIGNATUREDEF_OUTPUTSENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef.OutputsEntry) - }) - , - 'DESCRIPTOR' : _SIGNATUREDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef) - }) -_sym_db.RegisterMessage(SignatureDef) -_sym_db.RegisterMessage(SignatureDef.InputsEntry) -_sym_db.RegisterMessage(SignatureDef.OutputsEntry) - -AssetFileDef = _reflection.GeneratedProtocolMessageType('AssetFileDef', (_message.Message,), { - 'DESCRIPTOR' : _ASSETFILEDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.AssetFileDef) - }) -_sym_db.RegisterMessage(AssetFileDef) - - -DESCRIPTOR._options = None -_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY._options = None -_METAGRAPHDEF_COLLECTIONDEFENTRY._options = None -_METAGRAPHDEF_SIGNATUREDEFENTRY._options = None -_COLLECTIONDEF_INT64LIST.fields_by_name['value']._options = None -_COLLECTIONDEF_FLOATLIST.fields_by_name['value']._options = None -_SIGNATUREDEF_INPUTSENTRY._options = None -_SIGNATUREDEF_OUTPUTSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/meta_graph_pb2_grpc.py b/redis_consumer/pbs/meta_graph_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/meta_graph_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/model_pb2.py b/redis_consumer/pbs/model_pb2.py deleted file mode 100644 index 8bddc4ca..00000000 --- a/redis_consumer/pbs/model_pb2.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: model.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='model.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x0bmodel.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto\"\x8c\x01\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64ValueH\x00\x12\x17\n\rversion_label\x18\x04 \x01(\tH\x00\x12\x16\n\x0esignature_name\x18\x03 \x01(\tB\x10\n\x0eversion_choiceB\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[google_dot_protobuf_dot_wrappers__pb2.DESCRIPTOR,]) - - - - -_MODELSPEC = _descriptor.Descriptor( - name='ModelSpec', - full_name='tensorflow.serving.ModelSpec', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.serving.ModelSpec.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.serving.ModelSpec.version', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version_label', full_name='tensorflow.serving.ModelSpec.version_label', index=2, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='signature_name', full_name='tensorflow.serving.ModelSpec.signature_name', index=3, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='version_choice', full_name='tensorflow.serving.ModelSpec.version_choice', - index=0, containing_type=None, fields=[]), - ], - serialized_start=68, - serialized_end=208, -) - -_MODELSPEC.fields_by_name['version'].message_type = google_dot_protobuf_dot_wrappers__pb2._INT64VALUE -_MODELSPEC.oneofs_by_name['version_choice'].fields.append( - _MODELSPEC.fields_by_name['version']) -_MODELSPEC.fields_by_name['version'].containing_oneof = _MODELSPEC.oneofs_by_name['version_choice'] -_MODELSPEC.oneofs_by_name['version_choice'].fields.append( - _MODELSPEC.fields_by_name['version_label']) -_MODELSPEC.fields_by_name['version_label'].containing_oneof = _MODELSPEC.oneofs_by_name['version_choice'] 
-DESCRIPTOR.message_types_by_name['ModelSpec'] = _MODELSPEC -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -ModelSpec = _reflection.GeneratedProtocolMessageType('ModelSpec', (_message.Message,), { - 'DESCRIPTOR' : _MODELSPEC, - '__module__' : 'model_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ModelSpec) - }) -_sym_db.RegisterMessage(ModelSpec) - - -DESCRIPTOR._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/model_pb2_grpc.py b/redis_consumer/pbs/model_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/model_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/node_def_pb2.py b/redis_consumer/pbs/node_def_pb2.py deleted file mode 100644 index b2864a5b..00000000 --- a/redis_consumer/pbs/node_def_pb2.py +++ /dev/null @@ -1,202 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: node_def.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='node_def.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\tNodeProtoP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0enode_def.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\"\xd2\x02\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12+\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1d.tensorflow.NodeDef.AttrEntry\x12J\n\x17\x65xperimental_debug_info\x18\x06 \x01(\x0b\x32).tensorflow.NodeDef.ExperimentalDebugInfo\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1aQ\n\x15\x45xperimentalDebugInfo\x12\x1b\n\x13original_node_names\x18\x01 \x03(\t\x12\x1b\n\x13original_func_names\x18\x02 \x03(\tBi\n\x18org.tensorflow.frameworkB\tNodeProtoP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[attr__value__pb2.DESCRIPTOR,]) - - - - -_NODEDEF_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.NodeDef.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.NodeDef.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.NodeDef.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - 
is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=238, - serialized_end=304, -) - -_NODEDEF_EXPERIMENTALDEBUGINFO = _descriptor.Descriptor( - name='ExperimentalDebugInfo', - full_name='tensorflow.NodeDef.ExperimentalDebugInfo', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='original_node_names', full_name='tensorflow.NodeDef.ExperimentalDebugInfo.original_node_names', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='original_func_names', full_name='tensorflow.NodeDef.ExperimentalDebugInfo.original_func_names', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=306, - serialized_end=387, -) - -_NODEDEF = _descriptor.Descriptor( - name='NodeDef', - full_name='tensorflow.NodeDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.NodeDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='op', full_name='tensorflow.NodeDef.op', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input', full_name='tensorflow.NodeDef.input', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.NodeDef.device', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.NodeDef.attr', index=4, - number=5, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='experimental_debug_info', full_name='tensorflow.NodeDef.experimental_debug_info', index=5, - number=6, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - 
extensions=[ - ], - nested_types=[_NODEDEF_ATTRENTRY, _NODEDEF_EXPERIMENTALDEBUGINFO, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=49, - serialized_end=387, -) - -_NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE -_NODEDEF_ATTRENTRY.containing_type = _NODEDEF -_NODEDEF_EXPERIMENTALDEBUGINFO.containing_type = _NODEDEF -_NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY -_NODEDEF.fields_by_name['experimental_debug_info'].message_type = _NODEDEF_EXPERIMENTALDEBUGINFO -DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _NODEDEF_ATTRENTRY, - '__module__' : 'node_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.AttrEntry) - }) - , - - 'ExperimentalDebugInfo' : _reflection.GeneratedProtocolMessageType('ExperimentalDebugInfo', (_message.Message,), { - 'DESCRIPTOR' : _NODEDEF_EXPERIMENTALDEBUGINFO, - '__module__' : 'node_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.ExperimentalDebugInfo) - }) - , - 'DESCRIPTOR' : _NODEDEF, - '__module__' : 'node_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NodeDef) - }) -_sym_db.RegisterMessage(NodeDef) -_sym_db.RegisterMessage(NodeDef.AttrEntry) -_sym_db.RegisterMessage(NodeDef.ExperimentalDebugInfo) - - -DESCRIPTOR._options = None -_NODEDEF_ATTRENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/node_def_pb2_grpc.py b/redis_consumer/pbs/node_def_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/node_def_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/op_def_pb2.py b/redis_consumer/pbs/op_def_pb2.py deleted file mode 100644 index 1a150703..00000000 --- a/redis_consumer/pbs/op_def_pb2.py +++ /dev/null @@ -1,404 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: op_def.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='op_def.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\013OpDefProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0cop_def.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\x1a\x0btypes.proto\"\xd0\x05\n\x05OpDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\tinput_arg\x18\x02 \x03(\x0b\x32\x18.tensorflow.OpDef.ArgDef\x12,\n\noutput_arg\x18\x03 \x03(\x0b\x32\x18.tensorflow.OpDef.ArgDef\x12\x16\n\x0e\x63ontrol_output\x18\x14 \x03(\t\x12\'\n\x04\x61ttr\x18\x04 \x03(\x0b\x32\x19.tensorflow.OpDef.AttrDef\x12.\n\x0b\x64\x65precation\x18\x08 \x01(\x0b\x32\x19.tensorflow.OpDeprecation\x12\x0f\n\x07summary\x18\x05 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x06 \x01(\t\x12\x16\n\x0eis_commutative\x18\x12 \x01(\x08\x12\x14\n\x0cis_aggregate\x18\x10 \x01(\x08\x12\x13\n\x0bis_stateful\x18\x11 \x01(\x08\x12\"\n\x1a\x61llows_uninitialized_input\x18\x13 \x01(\x08\x1a\x9f\x01\n\x06\x41rgDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\"\n\x04type\x18\x03 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x11\n\ttype_attr\x18\x04 \x01(\t\x12\x13\n\x0bnumber_attr\x18\x05 \x01(\t\x12\x16\n\x0etype_list_attr\x18\x06 \x01(\t\x12\x0e\n\x06is_ref\x18\x10 \x01(\x08\x1a\xbd\x01\n\x07\x41ttrDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12,\n\rdefault_value\x18\x03 \x01(\x0b\x32\x15.tensorflow.AttrValue\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x13\n\x0bhas_minimum\x18\x05 \x01(\x08\x12\x0f\n\x07minimum\x18\x06 \x01(\x03\x12-\n\x0e\x61llowed_values\x18\x07 \x01(\x0b\x32\x15.tensorflow.AttrValue\"5\n\rOpDeprecation\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x13\n\x0b\x65xplanation\x18\x02 \x01(\t\"\'\n\x06OpList\x12\x1d\n\x02op\x18\x01 \x03(\x0b\x32\x11.tensorflow.OpDefBk\n\x18org.tensorflow.frameworkB\x0bOpDefProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[attr__value__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) - - - - -_OPDEF_ARGDEF = _descriptor.Descriptor( - name='ArgDef', - full_name='tensorflow.OpDef.ArgDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.OpDef.ArgDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='description', full_name='tensorflow.OpDef.ArgDef.description', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.OpDef.ArgDef.type', index=2, - 
number=3, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type_attr', full_name='tensorflow.OpDef.ArgDef.type_attr', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='number_attr', full_name='tensorflow.OpDef.ArgDef.number_attr', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type_list_attr', full_name='tensorflow.OpDef.ArgDef.type_list_attr', index=5, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_ref', full_name='tensorflow.OpDef.ArgDef.is_ref', index=6, - number=16, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=429, - serialized_end=588, -) - -_OPDEF_ATTRDEF = _descriptor.Descriptor( - name='AttrDef', - full_name='tensorflow.OpDef.AttrDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.OpDef.AttrDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.OpDef.AttrDef.type', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='default_value', full_name='tensorflow.OpDef.AttrDef.default_value', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='description', full_name='tensorflow.OpDef.AttrDef.description', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='has_minimum', 
full_name='tensorflow.OpDef.AttrDef.has_minimum', index=4, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='minimum', full_name='tensorflow.OpDef.AttrDef.minimum', index=5, - number=6, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='allowed_values', full_name='tensorflow.OpDef.AttrDef.allowed_values', index=6, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=591, - serialized_end=780, -) - -_OPDEF = _descriptor.Descriptor( - name='OpDef', - full_name='tensorflow.OpDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.OpDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input_arg', full_name='tensorflow.OpDef.input_arg', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_arg', full_name='tensorflow.OpDef.output_arg', index=2, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='control_output', full_name='tensorflow.OpDef.control_output', index=3, - number=20, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.OpDef.attr', index=4, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='deprecation', full_name='tensorflow.OpDef.deprecation', index=5, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='summary', full_name='tensorflow.OpDef.summary', index=6, - number=5, type=9, cpp_type=9, label=1, - 
has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='description', full_name='tensorflow.OpDef.description', index=7, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_commutative', full_name='tensorflow.OpDef.is_commutative', index=8, - number=18, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_aggregate', full_name='tensorflow.OpDef.is_aggregate', index=9, - number=16, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_stateful', full_name='tensorflow.OpDef.is_stateful', index=10, - number=17, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='allows_uninitialized_input', full_name='tensorflow.OpDef.allows_uninitialized_input', index=11, - number=19, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_OPDEF_ARGDEF, _OPDEF_ATTRDEF, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=60, - serialized_end=780, -) - - -_OPDEPRECATION = _descriptor.Descriptor( - name='OpDeprecation', - full_name='tensorflow.OpDeprecation', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.OpDeprecation.version', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='explanation', full_name='tensorflow.OpDeprecation.explanation', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=782, - serialized_end=835, -) - - -_OPLIST = _descriptor.Descriptor( - name='OpList', - full_name='tensorflow.OpList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='op', 
full_name='tensorflow.OpList.op', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=837, - serialized_end=876, -) - -_OPDEF_ARGDEF.fields_by_name['type'].enum_type = types__pb2._DATATYPE -_OPDEF_ARGDEF.containing_type = _OPDEF -_OPDEF_ATTRDEF.fields_by_name['default_value'].message_type = attr__value__pb2._ATTRVALUE -_OPDEF_ATTRDEF.fields_by_name['allowed_values'].message_type = attr__value__pb2._ATTRVALUE -_OPDEF_ATTRDEF.containing_type = _OPDEF -_OPDEF.fields_by_name['input_arg'].message_type = _OPDEF_ARGDEF -_OPDEF.fields_by_name['output_arg'].message_type = _OPDEF_ARGDEF -_OPDEF.fields_by_name['attr'].message_type = _OPDEF_ATTRDEF -_OPDEF.fields_by_name['deprecation'].message_type = _OPDEPRECATION -_OPLIST.fields_by_name['op'].message_type = _OPDEF -DESCRIPTOR.message_types_by_name['OpDef'] = _OPDEF -DESCRIPTOR.message_types_by_name['OpDeprecation'] = _OPDEPRECATION -DESCRIPTOR.message_types_by_name['OpList'] = _OPLIST -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -OpDef = _reflection.GeneratedProtocolMessageType('OpDef', (_message.Message,), { - - 'ArgDef' : _reflection.GeneratedProtocolMessageType('ArgDef', (_message.Message,), { - 'DESCRIPTOR' : _OPDEF_ARGDEF, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDef.ArgDef) - }) - , - - 'AttrDef' : _reflection.GeneratedProtocolMessageType('AttrDef', (_message.Message,), { - 'DESCRIPTOR' : _OPDEF_ATTRDEF, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDef.AttrDef) - }) - , - 'DESCRIPTOR' : _OPDEF, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDef) - }) -_sym_db.RegisterMessage(OpDef) -_sym_db.RegisterMessage(OpDef.ArgDef) -_sym_db.RegisterMessage(OpDef.AttrDef) - -OpDeprecation = _reflection.GeneratedProtocolMessageType('OpDeprecation', (_message.Message,), { - 'DESCRIPTOR' : _OPDEPRECATION, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDeprecation) - }) -_sym_db.RegisterMessage(OpDeprecation) - -OpList = _reflection.GeneratedProtocolMessageType('OpList', (_message.Message,), { - 'DESCRIPTOR' : _OPLIST, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpList) - }) -_sym_db.RegisterMessage(OpList) - - -DESCRIPTOR._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/op_def_pb2_grpc.py b/redis_consumer/pbs/op_def_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/op_def_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/predict_pb2.py b/redis_consumer/pbs/predict_pb2.py deleted file mode 100644 index 8eb45853..00000000 --- a/redis_consumer/pbs/predict_pb2.py +++ /dev/null @@ -1,232 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: predict.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_pb2 as tensor__pb2 -import redis_consumer.pbs.model_pb2 as model__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='predict.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\rpredict.proto\x12\x12tensorflow.serving\x1a\x0ctensor.proto\x1a\x0bmodel.proto\"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\xd0\x01\n\x0fPredictResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[tensor__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) - - - - -_PREDICTREQUEST_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.serving.PredictRequest.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.PredictRequest.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.PredictRequest.InputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=221, - serialized_end=291, -) - -_PREDICTREQUEST = _descriptor.Descriptor( - name='PredictRequest', - full_name='tensorflow.serving.PredictRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.PredictRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.serving.PredictRequest.inputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, 
file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_filter', full_name='tensorflow.serving.PredictRequest.output_filter', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PREDICTREQUEST_INPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=65, - serialized_end=291, -) - - -_PREDICTRESPONSE_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.serving.PredictResponse.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.PredictResponse.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.PredictResponse.OutputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=431, - serialized_end=502, -) - -_PREDICTRESPONSE = _descriptor.Descriptor( - name='PredictResponse', - full_name='tensorflow.serving.PredictResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.PredictResponse.model_spec', index=0, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.serving.PredictResponse.outputs', index=1, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PREDICTRESPONSE_OUTPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=294, - serialized_end=502, -) - -_PREDICTREQUEST_INPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO -_PREDICTREQUEST_INPUTSENTRY.containing_type = _PREDICTREQUEST -_PREDICTREQUEST.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_PREDICTREQUEST.fields_by_name['inputs'].message_type = _PREDICTREQUEST_INPUTSENTRY -_PREDICTRESPONSE_OUTPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO -_PREDICTRESPONSE_OUTPUTSENTRY.containing_type = _PREDICTRESPONSE -_PREDICTRESPONSE.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC 
-_PREDICTRESPONSE.fields_by_name['outputs'].message_type = _PREDICTRESPONSE_OUTPUTSENTRY -DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST -DESCRIPTOR.message_types_by_name['PredictResponse'] = _PREDICTRESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { - - 'InputsEntry' : _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTREQUEST_INPUTSENTRY, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest.InputsEntry) - }) - , - 'DESCRIPTOR' : _PREDICTREQUEST, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest) - }) -_sym_db.RegisterMessage(PredictRequest) -_sym_db.RegisterMessage(PredictRequest.InputsEntry) - -PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), { - - 'OutputsEntry' : _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTRESPONSE_OUTPUTSENTRY, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse.OutputsEntry) - }) - , - 'DESCRIPTOR' : _PREDICTRESPONSE, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse) - }) -_sym_db.RegisterMessage(PredictResponse) -_sym_db.RegisterMessage(PredictResponse.OutputsEntry) - - -DESCRIPTOR._options = None -_PREDICTREQUEST_INPUTSENTRY._options = None -_PREDICTRESPONSE_OUTPUTSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/predict_pb2_grpc.py b/redis_consumer/pbs/predict_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/predict_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/prediction_service_pb2.py b/redis_consumer/pbs/prediction_service_pb2.py deleted file mode 100644 index 7e88b957..00000000 --- a/redis_consumer/pbs/prediction_service_pb2.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: prediction_service.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.get_model_metadata_pb2 as get__model__metadata__pb2 -import redis_consumer.pbs.predict_pb2 as predict__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='prediction_service.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x18prediction_service.proto\x12\x12tensorflow.serving\x1a\x18get_model_metadata.proto\x1a\rpredict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[get__model__metadata__pb2.DESCRIPTOR,predict__pb2.DESCRIPTOR,]) - - - -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - - -DESCRIPTOR._options = None - -_PREDICTIONSERVICE = _descriptor.ServiceDescriptor( - name='PredictionService', - full_name='tensorflow.serving.PredictionService', - file=DESCRIPTOR, - index=0, - serialized_options=None, - serialized_start=90, - serialized_end=304, - methods=[ - _descriptor.MethodDescriptor( - name='Predict', - full_name='tensorflow.serving.PredictionService.Predict', - index=0, - containing_service=None, - input_type=predict__pb2._PREDICTREQUEST, - output_type=predict__pb2._PREDICTRESPONSE, - serialized_options=None, - ), - _descriptor.MethodDescriptor( - name='GetModelMetadata', - full_name='tensorflow.serving.PredictionService.GetModelMetadata', - index=1, - containing_service=None, - input_type=get__model__metadata__pb2._GETMODELMETADATAREQUEST, - output_type=get__model__metadata__pb2._GETMODELMETADATARESPONSE, - serialized_options=None, - ), -]) -_sym_db.RegisterServiceDescriptor(_PREDICTIONSERVICE) - -DESCRIPTOR.services_by_name['PredictionService'] = _PREDICTIONSERVICE - -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/prediction_service_pb2_grpc.py b/redis_consumer/pbs/prediction_service_pb2_grpc.py deleted file mode 100644 index 76a23c73..00000000 --- a/redis_consumer/pbs/prediction_service_pb2_grpc.py +++ /dev/null @@ -1,78 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - -import redis_consumer.pbs.get_model_metadata_pb2 as get__model__metadata__pb2 -import redis_consumer.pbs.predict_pb2 as predict__pb2 - - -class PredictionServiceStub(object): - """open source marker; do not remove - PredictionService provides access to machine-learned models loaded by - model_servers. - Classify. - rpc Classify(ClassificationRequest) returns (ClassificationResponse); - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Predict = channel.unary_unary( - '/tensorflow.serving.PredictionService/Predict', - request_serializer=predict__pb2.PredictRequest.SerializeToString, - response_deserializer=predict__pb2.PredictResponse.FromString, - ) - self.GetModelMetadata = channel.unary_unary( - '/tensorflow.serving.PredictionService/GetModelMetadata', - request_serializer=get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, - response_deserializer=get__model__metadata__pb2.GetModelMetadataResponse.FromString, - ) - - -class PredictionServiceServicer(object): - """open source marker; do not remove - PredictionService provides access to machine-learned models loaded by - model_servers. - Classify. - rpc Classify(ClassificationRequest) returns (ClassificationResponse); - """ - - def Predict(self, request, context): - """Regress. - rpc Regress(RegressionRequest) returns (RegressionResponse); - - Predict -- provides access to loaded TensorFlow model. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetModelMetadata(self, request, context): - """MultiInference API for multi-headed models. - rpc MultiInference(MultiInferenceRequest) returns (MultiInferenceResponse); - - GetModelMetadata - provides access to metadata for loaded models. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_PredictionServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Predict': grpc.unary_unary_rpc_method_handler( - servicer.Predict, - request_deserializer=predict__pb2.PredictRequest.FromString, - response_serializer=predict__pb2.PredictResponse.SerializeToString, - ), - 'GetModelMetadata': grpc.unary_unary_rpc_method_handler( - servicer.GetModelMetadata, - request_deserializer=get__model__metadata__pb2.GetModelMetadataRequest.FromString, - response_serializer=get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'tensorflow.serving.PredictionService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) diff --git a/redis_consumer/pbs/resource_handle_pb2.py b/redis_consumer/pbs/resource_handle_pb2.py deleted file mode 100644 index e149683e..00000000 --- a/redis_consumer/pbs/resource_handle_pb2.py +++ /dev/null @@ -1,156 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: resource_handle.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='resource_handle.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\016ResourceHandleP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x15resource_handle.proto\x12\ntensorflow\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\"\x9f\x02\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\t\x12H\n\x11\x64types_and_shapes\x18\x06 \x03(\x0b\x32-.tensorflow.ResourceHandleProto.DtypeAndShape\x1a\x61\n\rDtypeAndShape\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProtoBn\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) - - - - -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE = _descriptor.Descriptor( - name='DtypeAndShape', - full_name='tensorflow.ResourceHandleProto.DtypeAndShape', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.ResourceHandleProto.DtypeAndShape.dtype', index=0, - number=1, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.ResourceHandleProto.DtypeAndShape.shape', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=261, - serialized_end=358, -) - -_RESOURCEHANDLEPROTO = _descriptor.Descriptor( - name='ResourceHandleProto', - full_name='tensorflow.ResourceHandleProto', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.ResourceHandleProto.device', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='container', full_name='tensorflow.ResourceHandleProto.container', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - 
is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.ResourceHandleProto.name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='hash_code', full_name='tensorflow.ResourceHandleProto.hash_code', index=3, - number=4, type=4, cpp_type=4, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='maybe_type_name', full_name='tensorflow.ResourceHandleProto.maybe_type_name', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dtypes_and_shapes', full_name='tensorflow.ResourceHandleProto.dtypes_and_shapes', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_RESOURCEHANDLEPROTO_DTYPEANDSHAPE, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=71, - serialized_end=358, -) - -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.containing_type = _RESOURCEHANDLEPROTO -_RESOURCEHANDLEPROTO.fields_by_name['dtypes_and_shapes'].message_type = _RESOURCEHANDLEPROTO_DTYPEANDSHAPE -DESCRIPTOR.message_types_by_name['ResourceHandleProto'] = _RESOURCEHANDLEPROTO -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -ResourceHandleProto = _reflection.GeneratedProtocolMessageType('ResourceHandleProto', (_message.Message,), { - - 'DtypeAndShape' : _reflection.GeneratedProtocolMessageType('DtypeAndShape', (_message.Message,), { - 'DESCRIPTOR' : _RESOURCEHANDLEPROTO_DTYPEANDSHAPE, - '__module__' : 'resource_handle_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandleProto.DtypeAndShape) - }) - , - 'DESCRIPTOR' : _RESOURCEHANDLEPROTO, - '__module__' : 'resource_handle_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandleProto) - }) -_sym_db.RegisterMessage(ResourceHandleProto) -_sym_db.RegisterMessage(ResourceHandleProto.DtypeAndShape) - - -DESCRIPTOR._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/resource_handle_pb2_grpc.py b/redis_consumer/pbs/resource_handle_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/resource_handle_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-import grpc - diff --git a/redis_consumer/pbs/saved_object_graph_pb2.py b/redis_consumer/pbs/saved_object_graph_pb2.py deleted file mode 100644 index 7cd9de9e..00000000 --- a/redis_consumer/pbs/saved_object_graph_pb2.py +++ /dev/null @@ -1,720 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: saved_object_graph.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.trackable_object_graph_pb2 as trackable__object__graph__pb2 -import redis_consumer.pbs.struct_pb2 as struct__pb2 -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 -import redis_consumer.pbs.versions_pb2 as versions__pb2 -import redis_consumer.pbs.variable_pb2 as variable__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='saved_object_graph.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x18saved_object_graph.proto\x12\ntensorflow\x1a\x1ctrackable_object_graph.proto\x1a\x0cstruct.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\x1a\x0eversions.proto\x1a\x0evariable.proto\"\xe8\x01\n\x10SavedObjectGraph\x12&\n\x05nodes\x18\x01 \x03(\x0b\x32\x17.tensorflow.SavedObject\x12O\n\x12\x63oncrete_functions\x18\x02 \x03(\x0b\x32\x33.tensorflow.SavedObjectGraph.ConcreteFunctionsEntry\x1a[\n\x16\x43oncreteFunctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x30\n\x05value\x18\x02 \x01(\x0b\x32!.tensorflow.SavedConcreteFunction:\x02\x38\x01\"\xbd\x04\n\x0bSavedObject\x12R\n\x08\x63hildren\x18\x01 \x03(\x0b\x32@.tensorflow.TrackableObjectGraph.TrackableObject.ObjectReference\x12^\n\x0eslot_variables\x18\x03 \x03(\x0b\x32\x46.tensorflow.TrackableObjectGraph.TrackableObject.SlotVariableReference\x12\x32\n\x0buser_object\x18\x04 \x01(\x0b\x32\x1b.tensorflow.SavedUserObjectH\x00\x12\'\n\x05\x61sset\x18\x05 \x01(\x0b\x32\x16.tensorflow.SavedAssetH\x00\x12-\n\x08\x66unction\x18\x06 \x01(\x0b\x32\x19.tensorflow.SavedFunctionH\x00\x12-\n\x08variable\x18\x07 \x01(\x0b\x32\x19.tensorflow.SavedVariableH\x00\x12G\n\x16\x62\x61re_concrete_function\x18\x08 \x01(\x0b\x32%.tensorflow.SavedBareConcreteFunctionH\x00\x12-\n\x08\x63onstant\x18\t \x01(\x0b\x32\x19.tensorflow.SavedConstantH\x00\x12-\n\x08resource\x18\n \x01(\x0b\x32\x19.tensorflow.SavedResourceH\x00\x42\x06\n\x04kindJ\x04\x08\x02\x10\x03R\nattributes\"`\n\x0fSavedUserObject\x12\x12\n\nidentifier\x18\x01 \x01(\t\x12\'\n\x07version\x18\x02 \x01(\x0b\x32\x16.tensorflow.VersionDef\x12\x10\n\x08metadata\x18\x03 \x01(\t\"*\n\nSavedAsset\x12\x1c\n\x14\x61sset_file_def_index\x18\x01 \x01(\x05\"\\\n\rSavedFunction\x12\x1a\n\x12\x63oncrete_functions\x18\x01 \x03(\t\x12/\n\rfunction_spec\x18\x02 \x01(\x0b\x32\x18.tensorflow.FunctionSpec\"\xa8\x01\n\x15SavedConcreteFunction\x12\x14\n\x0c\x62ound_inputs\x18\x02 \x03(\x05\x12\x42\n\x1d\x63\x61nonicalized_input_signature\x18\x03 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x35\n\x10output_signature\x18\x04 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\"|\n\x19SavedBareConcreteFunction\x12\x1e\n\x16\x63oncrete_function_name\x18\x01 \x01(\t\x12\x19\n\x11\x61rgument_keywords\x18\x02 \x03(\t\x12$\n\x1c\x61llowed_positional_arguments\x18\x03 
\x01(\x03\"\"\n\rSavedConstant\x12\x11\n\toperation\x18\x01 \x01(\t\"\xf6\x01\n\rSavedVariable\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\x11\n\ttrainable\x18\x03 \x01(\x08\x12<\n\x0fsynchronization\x18\x04 \x01(\x0e\x32#.tensorflow.VariableSynchronization\x12\x34\n\x0b\x61ggregation\x18\x05 \x01(\x0e\x32\x1f.tensorflow.VariableAggregation\x12\x0c\n\x04name\x18\x06 \x01(\t\"\x95\x01\n\x0c\x46unctionSpec\x12\x30\n\x0b\x66ullargspec\x18\x01 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x11\n\tis_method\x18\x02 \x01(\x08\x12\x34\n\x0finput_signature\x18\x05 \x01(\x0b\x32\x1b.tensorflow.StructuredValueJ\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05\"\x1f\n\rSavedResource\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[trackable__object__graph__pb2.DESCRIPTOR,struct__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,versions__pb2.DESCRIPTOR,variable__pb2.DESCRIPTOR,]) - - - - -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY = _descriptor.Descriptor( - name='ConcreteFunctionsEntry', - full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=291, - serialized_end=382, -) - -_SAVEDOBJECTGRAPH = _descriptor.Descriptor( - name='SavedObjectGraph', - full_name='tensorflow.SavedObjectGraph', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='nodes', full_name='tensorflow.SavedObjectGraph.nodes', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='concrete_functions', full_name='tensorflow.SavedObjectGraph.concrete_functions', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=150, - serialized_end=382, -) - - -_SAVEDOBJECT = _descriptor.Descriptor( - name='SavedObject', - full_name='tensorflow.SavedObject', - filename=None, - file=DESCRIPTOR, - 
containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='children', full_name='tensorflow.SavedObject.children', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='slot_variables', full_name='tensorflow.SavedObject.slot_variables', index=1, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='user_object', full_name='tensorflow.SavedObject.user_object', index=2, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='asset', full_name='tensorflow.SavedObject.asset', index=3, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='function', full_name='tensorflow.SavedObject.function', index=4, - number=6, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='variable', full_name='tensorflow.SavedObject.variable', index=5, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bare_concrete_function', full_name='tensorflow.SavedObject.bare_concrete_function', index=6, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='constant', full_name='tensorflow.SavedObject.constant', index=7, - number=9, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='resource', full_name='tensorflow.SavedObject.resource', index=8, - number=10, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.SavedObject.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=385, - serialized_end=958, -) - - -_SAVEDUSEROBJECT = _descriptor.Descriptor( - 
name='SavedUserObject', - full_name='tensorflow.SavedUserObject', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='identifier', full_name='tensorflow.SavedUserObject.identifier', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.SavedUserObject.version', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata', full_name='tensorflow.SavedUserObject.metadata', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=960, - serialized_end=1056, -) - - -_SAVEDASSET = _descriptor.Descriptor( - name='SavedAsset', - full_name='tensorflow.SavedAsset', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='asset_file_def_index', full_name='tensorflow.SavedAsset.asset_file_def_index', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1058, - serialized_end=1100, -) - - -_SAVEDFUNCTION = _descriptor.Descriptor( - name='SavedFunction', - full_name='tensorflow.SavedFunction', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='concrete_functions', full_name='tensorflow.SavedFunction.concrete_functions', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='function_spec', full_name='tensorflow.SavedFunction.function_spec', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1102, - serialized_end=1194, -) - - -_SAVEDCONCRETEFUNCTION = _descriptor.Descriptor( - name='SavedConcreteFunction', - full_name='tensorflow.SavedConcreteFunction', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - 
name='bound_inputs', full_name='tensorflow.SavedConcreteFunction.bound_inputs', index=0, - number=2, type=5, cpp_type=1, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='canonicalized_input_signature', full_name='tensorflow.SavedConcreteFunction.canonicalized_input_signature', index=1, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_signature', full_name='tensorflow.SavedConcreteFunction.output_signature', index=2, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1197, - serialized_end=1365, -) - - -_SAVEDBARECONCRETEFUNCTION = _descriptor.Descriptor( - name='SavedBareConcreteFunction', - full_name='tensorflow.SavedBareConcreteFunction', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='concrete_function_name', full_name='tensorflow.SavedBareConcreteFunction.concrete_function_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='argument_keywords', full_name='tensorflow.SavedBareConcreteFunction.argument_keywords', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='allowed_positional_arguments', full_name='tensorflow.SavedBareConcreteFunction.allowed_positional_arguments', index=2, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1367, - serialized_end=1491, -) - - -_SAVEDCONSTANT = _descriptor.Descriptor( - name='SavedConstant', - full_name='tensorflow.SavedConstant', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='operation', full_name='tensorflow.SavedConstant.operation', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - 
is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1493, - serialized_end=1527, -) - - -_SAVEDVARIABLE = _descriptor.Descriptor( - name='SavedVariable', - full_name='tensorflow.SavedVariable', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.SavedVariable.dtype', index=0, - number=1, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.SavedVariable.shape', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='trainable', full_name='tensorflow.SavedVariable.trainable', index=2, - number=3, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='synchronization', full_name='tensorflow.SavedVariable.synchronization', index=3, - number=4, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='aggregation', full_name='tensorflow.SavedVariable.aggregation', index=4, - number=5, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.SavedVariable.name', index=5, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1530, - serialized_end=1776, -) - - -_FUNCTIONSPEC = _descriptor.Descriptor( - name='FunctionSpec', - full_name='tensorflow.FunctionSpec', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='fullargspec', full_name='tensorflow.FunctionSpec.fullargspec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_method', full_name='tensorflow.FunctionSpec.is_method', index=1, - number=2, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input_signature', 
full_name='tensorflow.FunctionSpec.input_signature', index=2, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1779, - serialized_end=1928, -) - - -_SAVEDRESOURCE = _descriptor.Descriptor( - name='SavedResource', - full_name='tensorflow.SavedResource', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.SavedResource.device', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1930, - serialized_end=1961, -) - -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY.fields_by_name['value'].message_type = _SAVEDCONCRETEFUNCTION -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY.containing_type = _SAVEDOBJECTGRAPH -_SAVEDOBJECTGRAPH.fields_by_name['nodes'].message_type = _SAVEDOBJECT -_SAVEDOBJECTGRAPH.fields_by_name['concrete_functions'].message_type = _SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY -_SAVEDOBJECT.fields_by_name['children'].message_type = trackable__object__graph__pb2._TRACKABLEOBJECTGRAPH_TRACKABLEOBJECT_OBJECTREFERENCE -_SAVEDOBJECT.fields_by_name['slot_variables'].message_type = trackable__object__graph__pb2._TRACKABLEOBJECTGRAPH_TRACKABLEOBJECT_SLOTVARIABLEREFERENCE -_SAVEDOBJECT.fields_by_name['user_object'].message_type = _SAVEDUSEROBJECT -_SAVEDOBJECT.fields_by_name['asset'].message_type = _SAVEDASSET -_SAVEDOBJECT.fields_by_name['function'].message_type = _SAVEDFUNCTION -_SAVEDOBJECT.fields_by_name['variable'].message_type = _SAVEDVARIABLE -_SAVEDOBJECT.fields_by_name['bare_concrete_function'].message_type = _SAVEDBARECONCRETEFUNCTION -_SAVEDOBJECT.fields_by_name['constant'].message_type = _SAVEDCONSTANT -_SAVEDOBJECT.fields_by_name['resource'].message_type = _SAVEDRESOURCE -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['user_object']) -_SAVEDOBJECT.fields_by_name['user_object'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['asset']) -_SAVEDOBJECT.fields_by_name['asset'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['function']) -_SAVEDOBJECT.fields_by_name['function'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['variable']) -_SAVEDOBJECT.fields_by_name['variable'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['bare_concrete_function']) -_SAVEDOBJECT.fields_by_name['bare_concrete_function'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['constant']) 
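The wiring above and below reproduces proto3 oneof semantics for SavedObject.kind: exactly one member field may be set at a time. A minimal sketch of that behavior, using the equivalent message shipped with tensorflow (module path assumes a TF 2.x install):

from tensorflow.core.protobuf.saved_object_graph_pb2 import SavedObject

obj = SavedObject()
obj.user_object.identifier = 'trackable'
assert obj.WhichOneof('kind') == 'user_object'

# Setting a different member of the oneof clears user_object automatically.
obj.constant.operation = 'Const'
assert obj.WhichOneof('kind') == 'constant'
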
-_SAVEDOBJECT.fields_by_name['constant'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['resource']) -_SAVEDOBJECT.fields_by_name['resource'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDUSEROBJECT.fields_by_name['version'].message_type = versions__pb2._VERSIONDEF -_SAVEDFUNCTION.fields_by_name['function_spec'].message_type = _FUNCTIONSPEC -_SAVEDCONCRETEFUNCTION.fields_by_name['canonicalized_input_signature'].message_type = struct__pb2._STRUCTUREDVALUE -_SAVEDCONCRETEFUNCTION.fields_by_name['output_signature'].message_type = struct__pb2._STRUCTUREDVALUE -_SAVEDVARIABLE.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE -_SAVEDVARIABLE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_SAVEDVARIABLE.fields_by_name['synchronization'].enum_type = variable__pb2._VARIABLESYNCHRONIZATION -_SAVEDVARIABLE.fields_by_name['aggregation'].enum_type = variable__pb2._VARIABLEAGGREGATION -_FUNCTIONSPEC.fields_by_name['fullargspec'].message_type = struct__pb2._STRUCTUREDVALUE -_FUNCTIONSPEC.fields_by_name['input_signature'].message_type = struct__pb2._STRUCTUREDVALUE -DESCRIPTOR.message_types_by_name['SavedObjectGraph'] = _SAVEDOBJECTGRAPH -DESCRIPTOR.message_types_by_name['SavedObject'] = _SAVEDOBJECT -DESCRIPTOR.message_types_by_name['SavedUserObject'] = _SAVEDUSEROBJECT -DESCRIPTOR.message_types_by_name['SavedAsset'] = _SAVEDASSET -DESCRIPTOR.message_types_by_name['SavedFunction'] = _SAVEDFUNCTION -DESCRIPTOR.message_types_by_name['SavedConcreteFunction'] = _SAVEDCONCRETEFUNCTION -DESCRIPTOR.message_types_by_name['SavedBareConcreteFunction'] = _SAVEDBARECONCRETEFUNCTION -DESCRIPTOR.message_types_by_name['SavedConstant'] = _SAVEDCONSTANT -DESCRIPTOR.message_types_by_name['SavedVariable'] = _SAVEDVARIABLE -DESCRIPTOR.message_types_by_name['FunctionSpec'] = _FUNCTIONSPEC -DESCRIPTOR.message_types_by_name['SavedResource'] = _SAVEDRESOURCE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -SavedObjectGraph = _reflection.GeneratedProtocolMessageType('SavedObjectGraph', (_message.Message,), { - - 'ConcreteFunctionsEntry' : _reflection.GeneratedProtocolMessageType('ConcreteFunctionsEntry', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedObjectGraph.ConcreteFunctionsEntry) - }) - , - 'DESCRIPTOR' : _SAVEDOBJECTGRAPH, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedObjectGraph) - }) -_sym_db.RegisterMessage(SavedObjectGraph) -_sym_db.RegisterMessage(SavedObjectGraph.ConcreteFunctionsEntry) - -SavedObject = _reflection.GeneratedProtocolMessageType('SavedObject', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDOBJECT, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedObject) - }) -_sym_db.RegisterMessage(SavedObject) - -SavedUserObject = _reflection.GeneratedProtocolMessageType('SavedUserObject', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDUSEROBJECT, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedUserObject) - }) -_sym_db.RegisterMessage(SavedUserObject) - -SavedAsset = _reflection.GeneratedProtocolMessageType('SavedAsset', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDASSET, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedAsset) - }) 
-_sym_db.RegisterMessage(SavedAsset) - -SavedFunction = _reflection.GeneratedProtocolMessageType('SavedFunction', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDFUNCTION, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedFunction) - }) -_sym_db.RegisterMessage(SavedFunction) - -SavedConcreteFunction = _reflection.GeneratedProtocolMessageType('SavedConcreteFunction', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDCONCRETEFUNCTION, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedConcreteFunction) - }) -_sym_db.RegisterMessage(SavedConcreteFunction) - -SavedBareConcreteFunction = _reflection.GeneratedProtocolMessageType('SavedBareConcreteFunction', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDBARECONCRETEFUNCTION, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedBareConcreteFunction) - }) -_sym_db.RegisterMessage(SavedBareConcreteFunction) - -SavedConstant = _reflection.GeneratedProtocolMessageType('SavedConstant', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDCONSTANT, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedConstant) - }) -_sym_db.RegisterMessage(SavedConstant) - -SavedVariable = _reflection.GeneratedProtocolMessageType('SavedVariable', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDVARIABLE, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedVariable) - }) -_sym_db.RegisterMessage(SavedVariable) - -FunctionSpec = _reflection.GeneratedProtocolMessageType('FunctionSpec', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONSPEC, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionSpec) - }) -_sym_db.RegisterMessage(FunctionSpec) - -SavedResource = _reflection.GeneratedProtocolMessageType('SavedResource', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDRESOURCE, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedResource) - }) -_sym_db.RegisterMessage(SavedResource) - - -DESCRIPTOR._options = None -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/saved_object_graph_pb2_grpc.py b/redis_consumer/pbs/saved_object_graph_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/saved_object_graph_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/saver_pb2.py b/redis_consumer/pbs/saver_pb2.py deleted file mode 100644 index bcf329e4..00000000 --- a/redis_consumer/pbs/saver_pb2.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
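The vendored redis_consumer.pbs modules deleted in this patch all have canonical equivalents in the tensorflow and tensorflow-serving-api wheels, which is what the later patches import instead. A sketch of the one-for-one replacements, with module paths as published by TF 2.x:

# Replaces redis_consumer.pbs.tensor_pb2 and types_pb2
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework.types_pb2 import DataType

# Replaces redis_consumer.pbs.saved_object_graph_pb2 and saver_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.core.protobuf import saver_pb2

# Replaces redis_consumer.pbs.prediction_service_pb2_grpc
from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub
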
-# source: saver.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='saver.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\023org.tensorflow.utilB\013SaverProtosP\001Z=0.8.3 +deepcell-tracking>=0.2.6 +deepcell-toolbox>=0.8.2 +tensorflow-cpu +scikit-image>=0.14.0,<0.17.0 +numpy>=1.16.4 + +# tensorflow-serving-apis and gRPC dependencies +grpcio>=1.0,<2 +dict-to-protobuf==0.0.3.9 +protobuf>=3.6.0 + +# misc storage and redis clients boto3==1.9.195 google-cloud-storage>=1.16.1 python-decouple==3.1 redis==3.4.1 -scikit-image>=0.14.0,<0.17.0 -numpy>=1.16.4 -keras-preprocessing==1.1.0 -grpcio==1.27.2 -dict-to-protobuf==0.0.3.9 pytz==2019.1 -deepcell-tracking==0.2.6 -deepcell-toolbox>=0.8.2 From 316758e9e87fdb62a12d71d403fdbf29b298b5f9 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:00:30 -0800 Subject: [PATCH 04/73] Move all grpc functions into grpc_clients and use tensorflow pbs. --- redis_consumer/grpc_clients.py | 93 ++++++++++++++++++++--- redis_consumer/grpc_clients_test.py | 80 ++++++++++++++++++++ redis_consumer/testing_utils.py | 12 ++- redis_consumer/utils.py | 87 +-------------------- redis_consumer/utils_test.py | 113 +++++++++------------------- 5 files changed, 209 insertions(+), 176 deletions(-) create mode 100644 redis_consumer/grpc_clients_test.py diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index b6324ba3..2f6c3787 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -34,24 +34,99 @@ import logging import time import timeit +import six +import dict_to_protobuf +from google.protobuf.json_format import MessageToJson import grpc -from grpc import RpcError import grpc.beta.implementations from grpc._cython import cygrpc - import numpy as np - -from google.protobuf.json_format import MessageToJson +from tensorflow.core.framework.types_pb2 import DESCRIPTOR +from tensorflow.core.framework.tensor_pb2 import TensorProto +from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub +from tensorflow_serving.apis.predict_pb2 import PredictRequest +from tensorflow_serving.apis.get_model_metadata_pb2 import GetModelMetadataRequest from redis_consumer import settings -from redis_consumer.pbs.prediction_service_pb2_grpc import PredictionServiceStub -from redis_consumer.pbs.predict_pb2 import PredictRequest -from redis_consumer.pbs.get_model_metadata_pb2 import GetModelMetadataRequest -from redis_consumer.utils import grpc_response_to_dict -from redis_consumer.utils import make_tensor_proto +logger = logging.getLogger('redis_consumer.grpc_clients') + + +dtype_to_number = { + i.name: i.number for i in DESCRIPTOR.enum_types_by_name['DataType'].values +} + +# TODO: build this dynamically +number_to_dtype_value = { + 1: 'float_val', + 2: 'double_val', + 3: 'int_val', + 4: 'int_val', + 5: 'int_val', + 6: 'int_val', + 7: 'string_val', + 8: 'scomplex_val', + 9: 'int64_val', + 10: 'bool_val', + 18: 'dcomplex_val', + 19: 'half_val', + 20: 'resource_handle_val' +} + + +def grpc_response_to_dict(grpc_response): + # TODO: 'unicode' object has no attribute 'ListFields' + # response_dict = 
dict_to_protobuf.protobuf_to_dict(grpc_response) + # return response_dict + grpc_response_dict = dict() + + for k in grpc_response.outputs: + shape = [x.size for x in grpc_response.outputs[k].tensor_shape.dim] + + dtype_constant = grpc_response.outputs[k].dtype + + if dtype_constant not in number_to_dtype_value: + grpc_response_dict[k] = 'value not found' + logger.error('Tensor output data type not supported. ' + 'Returning empty dict.') + + dt = number_to_dtype_value[dtype_constant] + if shape == [1]: + grpc_response_dict[k] = eval( + 'grpc_response.outputs[k].' + dt)[0] + else: + grpc_response_dict[k] = np.array( + eval('grpc_response.outputs[k].' + dt)).reshape(shape) + + return grpc_response_dict + + +def make_tensor_proto(data, dtype): + tensor_proto = TensorProto() + + if isinstance(dtype, six.string_types): + dtype = dtype_to_number[dtype] + + dim = [{'size': 1}] + values = [data] + + if hasattr(data, 'shape'): + dim = [{'size': dim} for dim in data.shape] + values = list(data.reshape(-1)) + + tensor_proto_dict = { + 'dtype': dtype, + 'tensor_shape': { + 'dim': dim + }, + number_to_dtype_value[dtype]: values + } + dict_to_protobuf.dict_to_protobuf(tensor_proto_dict, tensor_proto) + + return tensor_proto + class GrpcClient(object): """Abstract class for all gRPC clients. diff --git a/redis_consumer/grpc_clients_test.py b/redis_consumer/grpc_clients_test.py new file mode 100644 index 00000000..2f454a08 --- /dev/null +++ b/redis_consumer/grpc_clients_test.py @@ -0,0 +1,80 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
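Taken together, make_tensor_proto and grpc_response_to_dict form a round trip between numpy arrays and the TensorProto wire format, which the new test module below exercises. A minimal sketch of that round trip (shape and output key are arbitrary):

import numpy as np
from tensorflow_serving.apis.predict_pb2 import PredictResponse

from redis_consumer.grpc_clients import grpc_response_to_dict, make_tensor_proto

payload = np.random.random((4, 4)).astype('float32')
proto = make_tensor_proto(payload, 'DT_FLOAT')   # dtype name mapped via dtype_to_number

response = PredictResponse()
response.outputs['prediction'].CopyFrom(proto)

decoded = grpc_response_to_dict(response)        # reshaped back from tensor_shape.dim
np.testing.assert_allclose(decoded['prediction'], payload)
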
+# ============================================================================ +"""Tests for gRPC Clients""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest + +import numpy as np +from tensorflow.core.framework import types_pb2 +from tensorflow.core.framework.tensor_pb2 import TensorProto +from tensorflow_serving.apis.predict_pb2 import PredictResponse + +from redis_consumer.testing_utils import _get_image + +from redis_consumer import grpc_clients + + +def test_make_tensor_proto(): + # test with numpy array + data = _get_image(300, 300, 1) + proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + assert isinstance(proto, (TensorProto,)) + # test with value + data = 10.0 + proto = grpc_clients.make_tensor_proto(data, types_pb2.DT_FLOAT) + assert isinstance(proto, (TensorProto,)) + + +def test_grpc_response_to_dict(): + # pylint: disable=E1101 + # test valid response + data = _get_image(300, 300, 1) + tensor_proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + response = PredictResponse() + response.outputs['prediction'].CopyFrom(tensor_proto) + response_dict = grpc_clients.grpc_response_to_dict(response) + assert isinstance(response_dict, (dict,)) + np.testing.assert_allclose(response_dict['prediction'], data) + # test scalar input + data = 3 + tensor_proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + response = PredictResponse() + response.outputs['prediction'].CopyFrom(tensor_proto) + response_dict = grpc_clients.grpc_response_to_dict(response) + assert isinstance(response_dict, (dict,)) + np.testing.assert_allclose(response_dict['prediction'], data) + # test bad dtype + # logs an error, but should throw a KeyError as well. + data = _get_image(300, 300, 1) + tensor_proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + response = PredictResponse() + response.outputs['prediction'].CopyFrom(tensor_proto) + response.outputs['prediction'].dtype = 32 + with pytest.raises(KeyError): + response_dict = grpc_clients.grpc_response_to_dict(response) \ No newline at end of file diff --git a/redis_consumer/testing_utils.py b/redis_consumer/testing_utils.py index 2f89646d..99105b17 100644 --- a/redis_consumer/testing_utils.py +++ b/redis_consumer/testing_utils.py @@ -48,10 +48,14 @@ def redis_client(): yield client -def _get_image(img_h=300, img_w=300): - bias = np.random.rand(img_w, img_h) * 64 - variance = np.random.rand(img_w, img_h) * (255 - 64) - img = np.random.rand(img_w, img_h) * variance + bias +def _get_image(img_h=300, img_w=300, channels=None): + shape = [img_w, img_h] + if channels: + shape.append(channels) + shape = tuple(shape) + bias = np.random.rand(*shape) * 64 + variance = np.random.rand(*shape) * (255 - 64) + img = np.random.rand(*shape) * variance + bias return img diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index 01192bf2..dd6ca2b8 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -39,99 +39,18 @@ import tarfile import tempfile import zipfile -import six - -import skimage -from skimage.external import tifffile import numpy as np -import keras_preprocessing.image -import dict_to_protobuf import PIL +from skimage.external import tifffile +from tensorflow.keras.preprocessing.image import img_to_array -from redis_consumer.pbs.types_pb2 import DESCRIPTOR -from redis_consumer.pbs.tensor_pb2 import TensorProto -from redis_consumer.pbs.tensor_shape_pb2 import TensorShapeProto from redis_consumer import settings logger = 
logging.getLogger('redis_consumer.utils') -dtype_to_number = { - i.name: i.number for i in DESCRIPTOR.enum_types_by_name['DataType'].values -} - -# TODO: build this dynamically -number_to_dtype_value = { - 1: 'float_val', - 2: 'double_val', - 3: 'int_val', - 4: 'int_val', - 5: 'int_val', - 6: 'int_val', - 7: 'string_val', - 8: 'scomplex_val', - 9: 'int64_val', - 10: 'bool_val', - 18: 'dcomplex_val', - 19: 'half_val', - 20: 'resource_handle_val' -} - - -def grpc_response_to_dict(grpc_response): - # TODO: 'unicode' object has no attribute 'ListFields' - # response_dict = dict_to_protobuf.protobuf_to_dict(grpc_response) - # return response_dict - grpc_response_dict = dict() - - for k in grpc_response.outputs: - shape = [x.size for x in grpc_response.outputs[k].tensor_shape.dim] - - dtype_constant = grpc_response.outputs[k].dtype - - if dtype_constant not in number_to_dtype_value: - grpc_response_dict[k] = 'value not found' - logger.error('Tensor output data type not supported. ' - 'Returning empty dict.') - - dt = number_to_dtype_value[dtype_constant] - if shape == [1]: - grpc_response_dict[k] = eval( - 'grpc_response.outputs[k].' + dt)[0] - else: - grpc_response_dict[k] = np.array( - eval('grpc_response.outputs[k].' + dt)).reshape(shape) - - return grpc_response_dict - - -def make_tensor_proto(data, dtype): - tensor_proto = TensorProto() - - if isinstance(dtype, six.string_types): - dtype = dtype_to_number[dtype] - - dim = [{'size': 1}] - values = [data] - - if hasattr(data, 'shape'): - dim = [{'size': dim} for dim in data.shape] - values = list(data.reshape(-1)) - - tensor_proto_dict = { - 'dtype': dtype, - 'tensor_shape': { - 'dim': dim - }, - number_to_dtype_value[dtype]: values - } - dict_to_protobuf.dict_to_protobuf(tensor_proto_dict, tensor_proto) - - return tensor_proto - - # Workaround for python2 not supporting `with tempfile.TemporaryDirectory() as` # These are unnecessary if not supporting python2 @contextlib.contextmanager @@ -209,7 +128,7 @@ def get_image(filepath): # tiff files should not have a channel dim img = np.expand_dims(img, axis=-1) else: - img = keras_preprocessing.image.img_to_array(PIL.Image.open(filepath)) + img = img_to_array(PIL.Image.open(filepath)) logger.debug('Loaded %s into numpy array with shape %s', filepath, img.shape) diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index d66402ec..b747e54b 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -34,24 +34,18 @@ import tempfile import zipfile +import pytest + import numpy as np -from keras_preprocessing.image import array_to_img +from tensorflow.keras.preprocessing.image import array_to_img from skimage.external import tifffile as tiff -from redis_consumer.pbs.predict_pb2 import PredictResponse -from redis_consumer.pbs import types_pb2 -from redis_consumer.pbs.tensor_pb2 import TensorProto +from redis_consumer.testing_utils import _get_image + from redis_consumer import utils from redis_consumer import settings -def _get_image(img_h=300, img_w=300, channels=1): - bias = np.random.rand(img_w, img_h, channels) * 64 - variance = np.random.rand(img_w, img_h, channels) * (255 - 64) - img = np.random.rand(img_w, img_h, channels) * variance + bias - return img - - def _write_image(filepath, img_w=300, img_h=300): imarray = _get_image(img_h, img_w) if filepath.lower().endswith('tif') or filepath.lower().endswith('tiff'): @@ -77,77 +71,38 @@ def _write_trks(filepath, X_mean=10, y_mean=5, trks.add(tracked_file.name, 'tracked.npy') -def test_make_tensor_proto(): - # 
test with numpy array - data = _get_image(300, 300, 1) - proto = utils.make_tensor_proto(data, 'DT_FLOAT') - assert isinstance(proto, (TensorProto,)) - # test with value - data = 10.0 - proto = utils.make_tensor_proto(data, types_pb2.DT_FLOAT) - assert isinstance(proto, (TensorProto,)) - - -def test_grpc_response_to_dict(): - # pylint: disable=E1101 - # test valid response - data = _get_image(300, 300, 1) - tensor_proto = utils.make_tensor_proto(data, 'DT_FLOAT') - response = PredictResponse() - response.outputs['prediction'].CopyFrom(tensor_proto) - response_dict = utils.grpc_response_to_dict(response) - assert isinstance(response_dict, (dict,)) - np.testing.assert_allclose(response_dict['prediction'], data) - # test scalar input - data = 3 - tensor_proto = utils.make_tensor_proto(data, 'DT_FLOAT') - response = PredictResponse() - response.outputs['prediction'].CopyFrom(tensor_proto) - response_dict = utils.grpc_response_to_dict(response) - assert isinstance(response_dict, (dict,)) - np.testing.assert_allclose(response_dict['prediction'], data) - # test bad dtype - # logs an error, but should throw a KeyError as well. - data = _get_image(300, 300, 1) - tensor_proto = utils.make_tensor_proto(data, 'DT_FLOAT') - response = PredictResponse() - response.outputs['prediction'].CopyFrom(tensor_proto) - response.outputs['prediction'].dtype = 32 - with pytest.raises(KeyError): - response_dict = utils.grpc_response_to_dict(response) - - -def test_iter_image_archive(): - with utils.get_tempdir() as tempdir: - zip_path = os.path.join(tempdir, 'test.zip') - archive = zipfile.ZipFile(zip_path, 'w') - num_files = 3 - for n in range(num_files): - path = os.path.join(tempdir, '{}.tif'.format(n)) - _write_image(path, 30, 30) - archive.write(path) - archive.close() +def test_iter_image_archive(tmpdir): + tmpdir = str(tmpdir) + zip_path = os.path.join(tmpdir, 'test.zip') + archive = zipfile.ZipFile(zip_path, 'w') + num_files = 3 + for n in range(num_files): + path = os.path.join(tmpdir, '{}.tif'.format(n)) + _write_image(path, 30, 30) + archive.write(path) + archive.close() - unzipped = [z for z in utils.iter_image_archive(zip_path, tempdir)] - assert len(unzipped) == num_files + unzipped = [z for z in utils.iter_image_archive(zip_path, tmpdir)] + assert len(unzipped) == num_files -def test_get_image_files_from_dir(): - with utils.get_tempdir() as tempdir: - zip_path = os.path.join(tempdir, 'test.zip') - archive = zipfile.ZipFile(zip_path, 'w') - num_files = 3 - for n in range(num_files): - path = os.path.join(tempdir, '{}.tif'.format(n)) - _write_image(path, 30, 30) - archive.write(path) - archive.close() - - imfiles = list(utils.get_image_files_from_dir(path, None)) - assert len(imfiles) == 1 - - imfiles = list(utils.get_image_files_from_dir(zip_path, tempdir)) - assert len(imfiles) == num_files +def test_get_image_files_from_dir(tmpdir): + tmpdir = str(tmpdir) + + zip_path = os.path.join(tmpdir, 'test.zip') + archive = zipfile.ZipFile(zip_path, 'w') + num_files = 3 + for n in range(num_files): + path = os.path.join(tmpdir, '{}.tif'.format(n)) + _write_image(path, 30, 30) + archive.write(path) + archive.close() + + imfiles = list(utils.get_image_files_from_dir(path, None)) + assert len(imfiles) == 1 + + imfiles = list(utils.get_image_files_from_dir(zip_path, tmpdir)) + assert len(imfiles) == num_files def test_get_image(tmpdir): From 1be19bb57a800ddc581f0620b7943f32196bfa4e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:07:46 -0800 Subject: 
[PATCH 05/73] Add gRPC Model wrapper and use it and applications in the consumers. --- redis_consumer/consumers/base_consumer.py | 431 ++---------------- .../consumers/base_consumer_test.py | 54 --- redis_consumer/consumers/image_consumer.py | 113 ++--- .../consumers/multiplex_consumer.py | 60 ++- redis_consumer/grpc_clients.py | 82 ++++ redis_consumer/settings.py | 36 +- 6 files changed, 192 insertions(+), 584 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 5bfcada6..11f87126 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -42,9 +42,9 @@ import numpy as np import pytz -from deepcell_toolbox.utils import tile_image, untile_image, resize +from deepcell.applications import ScaleDetection -from redis_consumer.grpc_clients import PredictClient +from redis_consumer.grpc_clients import PredictClient, GrpcModelWrapper from redis_consumer import utils from redis_consumer import settings @@ -308,61 +308,6 @@ def _get_predict_client(self, model_name, model_version): timeit.default_timer() - t) return client - def grpc_image(self, img, model_name, model_version, model_shape, - in_tensor_name='image', in_tensor_dtype='DT_FLOAT'): - """Use the TensorFlow Serving gRPC API for model inference on an image. - - Args: - img (numpy.array): The image to send to the model - model_name (str): The name of the model - model_version (int): The version of the model - model_shape (tuple): The shape of input data for the model - in_tensor_name (str): The name of the input tensor for the request - in_tensor_dtype (str): The dtype of the input data - - Returns: - numpy.array: The results of model inference. - """ - in_tensor_dtype = str(in_tensor_dtype).upper() - - start = timeit.default_timer() - self.logger.debug('Segmenting image of shape %s with model %s:%s', - img.shape, model_name, model_version) - - if len(model_shape) == img.ndim + 1: - img = np.expand_dims(img, axis=0) - - if in_tensor_dtype == 'DT_HALF': - # TODO: seems like should cast to "half" - # but the model rejects the type, wants "int" or "long" - img = img.astype('int') - - req_data = [{'in_tensor_name': in_tensor_name, - 'in_tensor_dtype': in_tensor_dtype, - 'data': img}] - - client = self._get_predict_client(model_name, model_version) - - prediction = client.predict(req_data, settings.GRPC_TIMEOUT) - results = [prediction[k] for k in sorted(prediction.keys())] - - if len(results) == 1: - results = results[0] - - finished = timeit.default_timer() - start - if self._redis_hash is not None: - self.update_key(self._redis_hash, { - 'prediction_time': finished, - }) - self.logger.debug('Segmented key %s (model %s:%s, ' - 'preprocessing: %s, postprocessing: %s)' - ' (%s retries) in %s seconds.', - self._redis_hash, model_name, model_version, - self._redis_values.get('preprocess_function'), - self._redis_values.get('postprocess_function'), - 0, finished) - return results - def parse_model_metadata(self, metadata): """Parse the metadata response and return list of input metadata. @@ -424,11 +369,42 @@ def get_model_metadata(self, model_name, model_version): self.logger.error('Malformed metadata: %s', model_metadata) raise err - def detect_scale(self, image): # pylint: disable=unused-argument - """Stub for scale detection""" - self.logger.debug('Scale was not given. 
Defaults to 1') - scale = 1 - return scale + def get_grpc_app(self, model, application_cls): + """ + Create an application from deepcell.applications + with a gRPC model wrapper as a model + """ + model_name, model_version = model.split(':') + model_metadata = self.get_model_metadata(model_name, model_version) + client = self._get_predict_client(model_name, model_version) + model_wrapper = GrpcModelWrapper(client, model_metadata) + return application_cls(model_wrapper) + + def detect_scale(self, image): + """Send the image to the SCALE_DETECT_MODEL to detect the relative + scale difference from the image to the model's training data. + + Args: + image (numpy.array): The image data. + + Returns: + scale (float): The detected scale, used to rescale data. + """ + start = timeit.default_timer() + + app = self.get_grpc_app(settings.SCALE_DETECT_MODEL, ScaleDetection) + + if not settings.SCALE_DETECT_ENABLED: + self.logger.debug('Scale detection disabled.') + return app.model_mpp + + batch_size = app.model.get_batch_size() + detected_scale = app.predict(image, batch_size=batch_size) + + self.logger.debug('Scale %s detected in %s seconds', + detected_scale, timeit.default_timer() - start) + + return app.model_mpp * detected_scale def get_image_scale(self, scale, image, redis_hash): """Calculate scale of image and rescale""" @@ -447,325 +423,7 @@ def get_image_scale(self, scale, image, redis_hash): settings.MAX_SCALE)) return scale - def _predict_big_image(self, - image, - model_name, - model_version, - model_shape, - model_input_name='image', - model_dtype='DT_FLOAT', - untile=True, - stride_ratio=0.75): - """Use tile_image to tile image for the model and untile the results. - - Args: - image (numpy.array): image data as numpy. - model_name (str): hosted model to send image data. - model_version (str): model version to query. - model_shape (tuple): shape of the model's expected input. - model_dtype (str): dtype of the model's input array. - model_input_name (str): name of the model's input array. - untile (bool): untiles results back to image shape if True. - stride_ratio (float): amount to overlap between tiles, (0, 1]. - - Returns: - numpy.array: untiled results from the model. - """ - model_ndim = len(model_shape) - input_shape = (model_shape[model_ndim - 3], model_shape[model_ndim - 2]) - - ratio = (model_shape[model_ndim - 3] / settings.TF_MIN_MODEL_SIZE) * \ - (model_shape[model_ndim - 2] / settings.TF_MIN_MODEL_SIZE) * \ - (model_shape[model_ndim - 1]) - - batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) - - tiles, tiles_info = tile_image( - np.expand_dims(image, axis=0), - model_input_shape=input_shape, - stride_ratio=stride_ratio) - - self.logger.debug('Tiling image of shape %s into shape %s.', - image.shape, tiles.shape) - - # max_batch_size is 1 by default. 
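The new get_grpc_app helper is what lets consumers drive deepcell.applications classes over gRPC instead of reimplementing inference, as detect_scale above now does. A hypothetical subclass showing the intended call pattern; the 'NuclearSegmentation:0' model string is an assumption, and the wrapper's get_batch_size method is inferred from its use in this patch:

from deepcell.applications import NuclearSegmentation

from redis_consumer.consumers import TensorFlowServingConsumer


class ExampleConsumer(TensorFlowServingConsumer):
    """Hypothetical consumer illustrating the get_grpc_app() pattern."""

    def segment(self, image):
        # 'NuclearSegmentation:0' is a hypothetical '<model>:<version>' string.
        app = self.get_grpc_app('NuclearSegmentation:0', NuclearSegmentation)
        # The GrpcModelWrapper reports the batch size the server will accept.
        batch_size = app.model.get_batch_size()
        return app.predict(image, batch_size=batch_size)
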
- # dependent on the tf-serving configuration - results = [] - for t in range(0, tiles.shape[0], batch_size): - batch = tiles[t:t + batch_size] - output = self.grpc_image( - batch, model_name, model_version, model_shape, - in_tensor_name=model_input_name, - in_tensor_dtype=model_dtype) - - if not isinstance(output, list): - output = [output] - - if len(results) == 0: - results = output - else: - for i, o in enumerate(output): - results[i] = np.vstack((results[i], o)) - - if not untile: - image = results - else: - image = [untile_image(r, tiles_info, model_input_shape=input_shape) - for r in results] - - image = image[0] if len(image) == 1 else image - return image - - def _predict_small_image(self, - image, - model_name, - model_version, - model_shape, - model_input_name='image', - model_dtype='DT_FLOAT'): - """Pad an image that is too small for the model, and unpad the results. - - Args: - img (numpy.array): The too-small image to be predicted with - model_name and model_version. - model_name (str): hosted model to send image data. - model_version (str): model version to query. - model_shape (tuple): shape of the model's expected input. - model_input_name (str): name of the model's input array. - model_dtype (str): dtype of the model's input array. - - Returns: - numpy.array: unpadded results from the model. - """ - pad_width = [] - model_ndim = len(model_shape) - for i in range(image.ndim): - if i in {image.ndim - 3, image.ndim - 2}: - if i == image.ndim - 3: - diff = model_shape[model_ndim - 3] - image.shape[i] - else: - diff = model_shape[model_ndim - 2] - image.shape[i] - - if diff > 0: - if diff % 2: - pad_width.append((diff // 2, diff // 2 + 1)) - else: - pad_width.append((diff // 2, diff // 2)) - else: - pad_width.append((0, 0)) - else: - pad_width.append((0, 0)) - - self.logger.info('Padding image from shape %s to shape %s.', - image.shape, tuple([x + y1 + y2 for x, (y1, y2) in - zip(image.shape, pad_width)])) - - padded_img = np.pad(image, pad_width, 'reflect') - image = self.grpc_image(padded_img, model_name, model_version, - model_shape, in_tensor_name=model_input_name, - in_tensor_dtype=model_dtype) - - image = [image] if not isinstance(image, list) else image - - # pad batch_size and frames for each output. - pad_widths = [pad_width] * len(image) - for i, im in enumerate(image): - while len(pad_widths[i]) < im.ndim: - pad_widths[i].insert(0, (0, 0)) - - # unpad results - image = [utils.unpad_image(i, p) for i, p in zip(image, pad_widths)] - image = image[0] if len(image) == 1 else image - - return image - - def predict(self, image, model_name, model_version, untile=True): - """Performs model inference on the image data. - - Args: - image (numpy.array): the image data - model_name (str): hosted model to send image data. - model_version (int): model version to query. - untile (bool): Whether to untile the tiled inference results. This - should be True when the model output is the same shape as the - input, and False otherwise. - """ - start = timeit.default_timer() - model_metadata = self.get_model_metadata(model_name, model_version) - - # TODO: generalize for more than a single input. 
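The tiling logic deleted in _predict_big_image is taken over by the deepcell.applications pipeline, but the underlying helpers remain deepcell_toolbox.utils.tile_image and untile_image. A sketch of the round trip the old code relied on, with arbitrary sizes:

import numpy as np
from deepcell_toolbox.utils import tile_image, untile_image

image = np.random.random((1, 512, 512, 1))   # (batch, x, y, channels)
tiles, tiles_info = tile_image(image, model_input_shape=(128, 128),
                               stride_ratio=0.75)
# Tiles are stacked along the batch axis; tiles_info records their placement
# so overlapping regions can be blended back together on untiling.
restored = untile_image(tiles, tiles_info, model_input_shape=(128, 128))
assert restored.shape == image.shape
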
- if len(model_metadata) > 1: - raise ValueError('Model {}:{} has {} required inputs but was only ' - 'given {} inputs.'.format( - model_name, model_version, - len(model_metadata), len(image))) - model_metadata = model_metadata[0] - - model_input_name = model_metadata['in_tensor_name'] - model_dtype = model_metadata['in_tensor_dtype'] - - model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] - model_ndim = len(model_shape) - - image = image[0] if image.shape[0] == 1 else image - - if model_ndim != image.ndim + 1: - raise ValueError('Image of shape {} is incompatible with model ' - '{}:{} with input shape {}'.format( - image.shape, model_name, model_version, - tuple(model_shape))) - - if (image.shape[-2] > settings.MAX_IMAGE_WIDTH or - image.shape[-3] > settings.MAX_IMAGE_HEIGHT): - raise ValueError( - 'The image is too large! Rescaled images have a maximum size ' - 'of ({}, {}) but found size {}.'.format( - settings.MAX_IMAGE_WIDTH, - settings.MAX_IMAGE_HEIGHT, - image.shape)) - - size_x = model_shape[model_ndim - 3] - size_y = model_shape[model_ndim - 2] - - size_x = image.shape[image.ndim - 3] if size_x <= 0 else size_x - size_y = image.shape[image.ndim - 2] if size_y <= 0 else size_y - - self.logger.debug('Calling predict on model %s:%s with input shape %s' - ' and dtype %s to segment an image of shape %s.', - model_name, model_version, tuple(model_shape), - model_dtype, image.shape) - - if (image.shape[image.ndim - 3] < size_x or - image.shape[image.ndim - 2] < size_y): - # image is too small for the model, pad the image. - image = self._predict_small_image(image, model_name, model_version, - model_shape, model_input_name, - model_dtype) - elif (image.shape[image.ndim - 3] > size_x or - image.shape[image.ndim - 2] > size_y): - # image is too big for the model, multiple images are tiled. - image = self._predict_big_image(image, model_name, model_version, - model_shape, model_input_name, - model_dtype, untile=untile) - else: - # image size is perfect, just send it to the model - image = self.grpc_image(image, model_name, model_version, - model_shape, model_input_name, model_dtype) - - if isinstance(image, list): - output_shapes = [i.shape for i in image] - else: - output_shapes = [image.shape] # cast as list - - self.logger.debug('Got response from model %s:%s of shape %s in %s ' - 'seconds.', model_name, model_version, output_shapes, - timeit.default_timer() - start) - - return image - - def _get_processing_function(self, process_type, function_name): - """Based on the function category and name, return the function. - - Args: - process_type (str): "pre" or "post" processing - function_name (str): Name processing function, must exist in - settings.PROCESSING_FUNCTIONS. - - Returns: - function: the selected pre- or post-processing function. - """ - clean = lambda x: str(x).lower() - # first, verify the route parameters - name = clean(function_name) - cat = clean(process_type) - if cat not in settings.PROCESSING_FUNCTIONS: - raise ValueError('Processing functions are either "pre" or "post" ' - 'processing. Got %s.' % cat) - - if name not in settings.PROCESSING_FUNCTIONS[cat]: - raise ValueError('"%s" is not a valid %s-processing function' - % (name, cat)) - return settings.PROCESSING_FUNCTIONS[cat][name] - - def process(self, image, key, process_type): - """Apply the pre- or post-processing function to the image data. - - Args: - image (numpy.array): The image data to process. - key (str): The name of the function to use. 
- process_type (str): "pre" or "post" processing. - - Returns: - numpy.array: The processed image data. - """ - start = timeit.default_timer() - if not key: - return image - - f = self._get_processing_function(process_type, key) - - if key == 'retinanet-semantic': - # image[:-1] is targeted at a two semantic head panoptic model - # TODO This may need to be modified and generalized in the future - results = f(image[:-1]) - elif key == 'retinanet': - results = f(image, self._rawshape[0], self._rawshape[1]) - else: - results = f(image) - - if not isinstance(results, list) and results.shape[0] == 1: - results = np.squeeze(results, axis=0) - - finished = timeit.default_timer() - start - - self.update_key(self._redis_hash, { - '{}process_time'.format(process_type): finished - }) - - self.logger.debug('%s-processed key %s (model %s:%s, preprocessing: %s,' - ' postprocessing: %s) in %s seconds.', - process_type.capitalize(), self._redis_hash, - self._redis_values.get('model_name'), - self._redis_values.get('model_version'), - self._redis_values.get('preprocess_function'), - self._redis_values.get('postprocess_function'), - finished) - - return results - - def preprocess(self, image, keys): - """Wrapper for _process_image but can only call with type="pre". - - Args: - image (numpy.array): image data - keys (list): list of function names to apply to the image - - Returns: - numpy.array: pre-processed image data - """ - pre = None - for key in keys: - x = pre if pre else image - pre = self.process(x, key, 'pre') - return pre - - def postprocess(self, image, keys): - """Wrapper for _process_image but can only call with type="post". - - Args: - image (numpy.array): image data - keys (list): list of function names to apply to the image - - Returns: - numpy.array: post-processed image data - """ - post = None - for key in keys: - x = post if post else image - post = self.process(x, key, 'post') - return post - - def save_output(self, image, redis_hash, save_name, output_shape=None): + def save_output(self, image, save_name): with utils.get_tempdir() as tempdir: # Save each result channel as an image file subdir = os.path.dirname(save_name.replace(tempdir, '')) @@ -774,19 +432,8 @@ def save_output(self, image, redis_hash, save_name, output_shape=None): if not isinstance(image, list): image = [image] - # Rescale image to original size before sending back to user outpaths = [] - added_batch = False for i, im in enumerate(image): - if output_shape: - if im.shape[0] != 1: - added_batch = True - im = np.expand_dims(im, axis=0) # add batch for resize - self.logger.info('Resizing image of shape %s to %s', im.shape, output_shape) - im = resize(im, output_shape, labeled_image=True) - if added_batch: - im = im[0] - outpaths.extend(utils.save_numpy_array( im, name='{}_{}'.format(name, i), diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index b3a338af..98247cb5 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -395,60 +395,6 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 x = np.random.random((300, 300, 1)) consumer.predict(x, model_name='modelname', model_version=0) - def test__get_processing_function(self, mocker, redis_client): - mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', { - 'valid': { - 'valid': lambda x: True - } - }) - - storage = DummyStorage() - consumer = consumers.ImageFileConsumer(redis_client, storage, 'q') - - x = 
consumer._get_processing_function('VaLiD', 'vAlId') - y = consumer._get_processing_function('vAlId', 'VaLiD') - assert x == y - - with pytest.raises(ValueError): - consumer._get_processing_function('invalid', 'valid') - - with pytest.raises(ValueError): - consumer._get_processing_function('valid', 'invalid') - - def test_process(self, mocker, redis_client): - # TODO: better test coverage - storage = DummyStorage() - queue = 'q' - img = np.random.random((1, 32, 32, 1)) - - mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', { - 'valid': { - 'valid': lambda x: x, - 'retinanet': lambda *x: x[0], - 'retinanet-semantic': lambda x: x, - } - }) - - consumer = consumers.ImageFileConsumer(redis_client, storage, queue) - - mocker.patch.object(consumer, '_redis_hash', 'a hash') - - output = consumer.process(img, '', '') - np.testing.assert_equal(img, output) - - # image is returned but channel squeezed out - output = consumer.process(img, 'valid', 'valid') - np.testing.assert_equal(img[0], output) - - img = np.random.random((2, 32, 32, 1)) - output = consumer.process(img, 'retinanet-semantic', 'valid') - np.testing.assert_equal(img[0], output) - - consumer._rawshape = (21, 21) - img = np.random.random((1, 32, 32, 1)) - output = consumer.process(img, 'retinanet', 'valid') - np.testing.assert_equal(img[0], output) - def test_get_image_scale(self, mocker, redis_client): stg = DummyStorage() consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index 590f5684..2281e1fd 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -28,11 +28,12 @@ from __future__ import division from __future__ import print_function -import os import timeit import numpy as np +from deepcell.applications import LabelDetection, NuclearSegmentation + from redis_consumer.consumers import TensorFlowServingConsumer from redis_consumer import utils from redis_consumer import settings @@ -41,37 +42,6 @@ class ImageFileConsumer(TensorFlowServingConsumer): """Consumes image files and uploads the results""" - def detect_scale(self, image): - """Send the image to the SCALE_DETECT_MODEL to detect the relative - scale difference from the image to the model's training data. - - Args: - image (numpy.array): The image data. - - Returns: - scale (float): The detected scale, used to rescale data. - """ - start = timeit.default_timer() - - if not settings.SCALE_DETECT_ENABLED: - self.logger.debug('Scale detection disabled. Scale set to 1.') - return 1 - - model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') - - scales = self.predict(image, model_name, model_version, - untile=False) - - detected_scale = np.mean(scales) - - error_rate = .01 # error rate is ~1% for current model. - if abs(detected_scale - 1) < error_rate: - detected_scale = 1 - - self.logger.debug('Scale %s detected in %s seconds', - detected_scale, timeit.default_timer() - start) - return detected_scale - def detect_label(self, image): """Send the image to the LABEL_DETECT_MODEL to detect the type of image data. The model output is mapped with settings.MODEL_CHOICES. @@ -84,31 +54,23 @@ def detect_label(self, image): """ start = timeit.default_timer() + app = self.get_grpc_app(settings.LABEL_DETECT_MODEL, LabelDetection) + if not settings.LABEL_DETECT_ENABLED: self.logger.debug('Label detection disabled. 
Label set to None.') return None - model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') - - labels = self.predict(image, model_name, model_version, - untile=False) + batch_size = app.model.get_batch_size() + detected_label = app.predict(image, batch_size=batch_size) - labels = np.array(labels) - vote = labels.sum(axis=0) - maj = vote.max() + self.logger.debug('Label %s detected in %s seconds', + detected_label, timeit.default_timer() - start) - detected = np.where(vote == maj)[-1][0] - - self.logger.debug('Label %s detected in %s seconds.', - detected, timeit.default_timer() - start) - return detected + return detected_label def _consume(self, redis_hash): start = timeit.default_timer() hvals = self.redis.hgetall(redis_hash) - # hold on to the redis hash/values for logging purposes - self._redis_hash = redis_hash - self._redis_values = hvals if hvals.get('status') in self.finished_statuses: self.logger.warning('Found completed hash `%s` with status %s.', @@ -130,10 +92,9 @@ def _consume(self, redis_hash): _ = timeit.default_timer() # Load input image - image = self.download_image(hvals.get('input_file_name')) - - # Validate input image - image = self.validate_model_input(image, model_name, model_version) + fname = hvals.get('input_file_name') + image = self.download_image(fname) + image = np.expand_dims(image, axis=0) # add a batch dimension # Pre-process data before sending to the model self.update_key(redis_hash, { @@ -145,13 +106,6 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) - original_shape = image.shape - - image = utils.rescale(image, scale) - - # Save shape value for postprocessing purposes - # TODO this is a big janky - self._rawshape = image.shape label = None if settings.LABEL_DETECT_ENABLED and model_name and model_version: self.logger.warning('Label Detection is enabled, but the model' @@ -169,56 +123,43 @@ def _consume(self, redis_hash): label = int(label) self.logger.debug('Image label already calculated: %s', label) + label = 0 # TODO: remove this! hotfix for bad label detection. 
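+ # NOTE: `label` indexes both settings.MODEL_CHOICES (the served model
+ # to query) and the new settings.APPLICATION_CHOICES (the deepcell
+ # Application that wraps it), so the hotfix above forces the nuclear
+ # workflow for every image until label detection is fixed.
+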
+ # Grab appropriate model
 model_name, model_version = utils._pick_model(label)

- if settings.LABEL_DETECT_ENABLED and label is not None:
- pre_funcs = utils._pick_preprocess(label).split(',')
- else:
- pre_funcs = hvals.get('preprocess_function', '').split(',')
+ # Validate input image
+ # TODO: batch dimension wonkiness
+ image = image[0] # remove batch dimension
+ image = self.validate_model_input(image, model_name, model_version)
+ image = np.expand_dims(image, axis=0) # add batch dim back

- image = np.expand_dims(image, axis=0) # add in the batch dim
- image = self.preprocess(image, pre_funcs)
+ app_cls = settings.APPLICATION_CHOICES.get(label, NuclearSegmentation)

 # Send data to the model
 self.update_key(redis_hash, {'status': 'predicting'})
- image = self.predict(image, model_name, model_version)
+ app = self.get_grpc_app(f'{model_name}:{model_version}', app_cls)

- # Post-process model results
- self.update_key(redis_hash, {'status': 'post-processing'})
-
- if settings.LABEL_DETECT_ENABLED and label is not None:
- post_funcs = utils._pick_postprocess(label).split(',')
- else:
- post_funcs = hvals.get('postprocess_function', '').split(',')
-
- image = self.postprocess(image, post_funcs)
+ results = app.predict(image, image_mpp=scale,
+ batch_size=app.model.get_batch_size())

 # Save the post-processed results to a file
 _ = timeit.default_timer()
 self.update_key(redis_hash, {'status': 'saving-results'})
 save_name = hvals.get('original_name', fname)
-
- if isinstance(image, list):
- for i, img in enumerate(image):
- if img.shape[-1] != 1:
- image[i] = np.expand_dims(img, axis=-1)
- elif image.shape[-1] != 1:
- image = np.expand_dims(image, axis=-1)
-
- dest, output_url = self.save_output(
- image, redis_hash, save_name, original_shape[:-1])
+ dest, output_url = self.save_output(results, save_name)

 # Update redis with the final results
- t = timeit.default_timer() - start
+ end = timeit.default_timer()
 self.update_key(redis_hash, {
 'status': self.final_status,
 'output_url': output_url,
- 'upload_time': timeit.default_timer() - _,
+ 'upload_time': end - _,
 'output_file_name': dest,
 'total_jobs': 1,
- 'total_time': t,
+ 'total_time': end - start,
 'finished_at': self.get_current_timestamp()
 })
 return self.final_status
diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py
index c595308e..8ba6f124 100644
--- a/redis_consumer/consumers/multiplex_consumer.py
+++ b/redis_consumer/consumers/multiplex_consumer.py
@@ -32,18 +32,44 @@

 import numpy as np

+from deepcell.applications import ScaleDetection, MultiplexSegmentation
+
 from redis_consumer.consumers import TensorFlowServingConsumer
-from redis_consumer import utils
 from redis_consumer import settings
-from redis_consumer import processing


 class MultiplexConsumer(TensorFlowServingConsumer):
 """Consumes image files and uploads the results"""

+ def detect_scale(self, image):
+ """Send the image to the SCALE_DETECT_MODEL to detect the relative
+ scale difference from the image to the model's training data.
+
+ Args:
+ image (numpy.array): The image data.
+
+ Returns:
+ scale (float): The detected scale, used to rescale data.
+ """
+ start = timeit.default_timer()
+
+ app = self.get_grpc_app(settings.SCALE_DETECT_MODEL, ScaleDetection)
+
+ if not settings.SCALE_DETECT_ENABLED:
+ self.logger.debug('Scale detection disabled.')
+ return app.model_mpp
+
+ # TODO: What to do with multi-channel data? 
+ # detected_scale = app.predict(image[..., 0]) + detected_scale = 1 + + self.logger.debug('Scale %s detected in %s seconds', + detected_scale, timeit.default_timer() - start) + + return app.model_mpp * detected_scale + def _consume(self, redis_hash): start = timeit.default_timer() - self._redis_hash = redis_hash # workaround for logging. hvals = self.redis.hgetall(redis_hash) if hvals.get('status') in self.finished_statuses: @@ -65,7 +91,8 @@ def _consume(self, redis_hash): _ = timeit.default_timer() # Load input image - image = self.download_image(hvals.get('input_file_name')) + fname = hvals.get('input_file_name') + image = self.download_image(fname) # squeeze extra dimension that is added by get_image image = np.squeeze(image) @@ -83,41 +110,32 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) - original_shape = image.shape - # Rescale each channel of the image - image = utils.rescale(image, scale) image = np.expand_dims(image, axis=0) # add in the batch dim - # Preprocess image - image = self.preprocess(image, ['multiplex_preprocess']) - # Send data to the model - self.update_key(redis_hash, {'status': 'predicting'}) - image = self.predict(image, model_name, model_version) + app = self.get_grpc_app( + f'{model_name}:{model_version}', MultiplexSegmentation) - # Post-process model results - self.update_key(redis_hash, {'status': 'post-processing'}) - image = processing.format_output_multiplex(image) - image = self.postprocess(image, ['multiplex_postprocess_consumer']) + results = app.predict(image, image_mpp=scale, + batch_size=settings.TF_MAX_BATCH_SIZE) # Save the post-processed results to a file _ = timeit.default_timer() self.update_key(redis_hash, {'status': 'saving-results'}) save_name = hvals.get('original_name', fname) - dest, output_url = self.save_output( - image, redis_hash, save_name, original_shape[:-1]) + dest, output_url = self.save_output(results, save_name) # Update redis with the final results - t = timeit.default_timer() - start + end = timeit.default_timer() self.update_key(redis_hash, { 'status': self.final_status, 'output_url': output_url, - 'upload_time': timeit.default_timer() - _, + 'upload_time': end - _, 'output_file_name': dest, 'total_jobs': 1, - 'total_time': t, + 'total_time': end - start, 'finished_at': self.get_current_timestamp() }) return self.final_status diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 2f6c3787..a944c88b 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -343,3 +343,85 @@ def progress(self, progress): # clamp to an integer between 0 and 100 progress = min(100, max(0, round(progress))) self.progress_callback(self.redis_hash, progress) + + +class GrpcModelWrapper(object): + """A wrapper class that mocks a Keras model using a gRPC client. + + https://github.com/vanvalenlab/deepcell-tf/blob/master/deepcell/applications + """ + + def __init__(self, client, model_metadata): + self._client = client + + if len(model_metadata) > 1: + # TODO: how to handle this? + raise NotImplementedError('Multiple input tensors are not supported.') + + self._metadata = model_metadata[0] + + self._in_tensor_name = self._metadata['in_tensor_name'] + self._in_tensor_dtype = str(self._metadata['in_tensor_dtype']).upper() + + shape = [int(x) for x in self._metadata['in_tensor_shape'].split(',')] + self.input_shape = shape + + def send_grpc(self, img): + """Use the TensorFlow Serving gRPC API for model inference on an image. 
+ + Args: + img (numpy.array): The image to send to the model + + Returns: + numpy.array: The results of model inference. + """ + start = timeit.default_timer() + if self._in_tensor_dtype == 'DT_HALF': + # TODO: seems like should cast to "half" + # but the model rejects the type, wants "int" or "long" + img = img.astype('int') + + req_data = [{'in_tensor_name': self._in_tensor_name, + 'in_tensor_dtype': self._in_tensor_dtype, + 'data': img}] + + prediction = self._client.predict(req_data, settings.GRPC_TIMEOUT) + results = [prediction[k] for k in sorted(prediction.keys())] + + self._client.logger.debug('Got prediction results of shape %s in %s s.', + [r.shape for r in results], + timeit.default_timer() - start) + + if len(results) == 1: + results = results[0] + + return results + + def get_batch_size(self): + """Calculate the best batch size based on TF_MAX_BATCH_SIZE and + TF_MIN_MODEL_SIZE + """ + rank = len(self.input_shape) + ratio = (self.input_shape[rank - 3] / settings.TF_MIN_MODEL_SIZE) * \ + (self.input_shape[rank - 2] / settings.TF_MIN_MODEL_SIZE) * \ + (self.input_shape[rank - 1]) + + batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) + return batch_size + + def predict(self, tiles, batch_size): + results = [] + + for t in range(0, tiles.shape[0], batch_size): + output = self.send_grpc(tiles[t:t + batch_size]) + + if not isinstance(output, list): + output = [output] + + if len(results) == 0: + results = output + else: + for i, o in enumerate(output): + results[i] = np.vstack((results[i], o)) + + return results diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 73778fce..928cd926 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -33,7 +33,7 @@ import grpc from decouple import config -from redis_consumer import processing +import deepcell # remove leading/trailing '/'s from cloud bucket folder names @@ -117,26 +117,6 @@ def _strip(x): # Configure expiration for cached model metadata METADATA_EXPIRE_TIME = config('METADATA_EXPIRE_TIME', default=30, cast=int) -# Pre- and Post-processing settings -PROCESSING_FUNCTIONS = { - 'pre': { - 'normalize': processing.normalize, - 'histogram_normalization': processing.phase_preprocess, - 'multiplex_preprocess': processing.multiplex_preprocess, - 'none': lambda x: x - }, - 'post': { - 'deepcell': processing.pixelwise, # TODO: this is deprecated. 
- 'pixelwise': processing.pixelwise, - 'watershed': processing.watershed, - 'retinanet': processing.retinanet_to_label_image, - 'retinanet-semantic': processing.retinanet_semantic_to_label_image, - 'deep_watershed': processing.deep_watershed, - 'multiplex_postprocess_consumer': processing.multiplex_postprocess_consumer, - 'none': lambda x: x - }, -} - # Tracking settings TRACKING_SEGMENT_MODEL = config('TRACKING_SEGMENT_MODEL', default='panoptic:3', cast=str) TRACKING_POSTPROCESS_FUNCTION = config('TRACKING_POSTPROCESS_FUNCTION', @@ -176,14 +156,8 @@ def _strip(x): 2: config('CYTOPLASM_MODEL', default='FluoCytoSegmentation:0', cast=str) } -PREPROCESS_CHOICES = { - 0: config('NUCLEAR_PREPROCESS', default='normalize', cast=str), - 1: config('PHASE_PREPROCESS', default='histogram_normalization', cast=str), - 2: config('CYTOPLASM_PREPROCESS', default='histogram_normalization', cast=str) -} - -POSTPROCESS_CHOICES = { - 0: config('NUCLEAR_POSTPROCESS', default='deep_watershed', cast=str), - 1: config('PHASE_POSTPROCESS', default='deep_watershed', cast=str), - 2: config('CYTOPLASM_POSTPROCESS', default='deep_watershed', cast=str) +APPLICATION_CHOICES = { + 0: deepcell.applications.NuclearSegmentation, + 1: deepcell.applications.CytoplasmSegmentation, + 2: deepcell.applications.CytoplasmSegmentation } From 451865d1bb88d2f3b211ce47c5aa6105e4d2575a Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:08:53 -0800 Subject: [PATCH 06/73] Remove now unused code. --- redis_consumer/consumers/tracking_consumer.py | 4 +- redis_consumer/processing.py | 91 -------- redis_consumer/processing_test.py | 75 ------- redis_consumer/utils.py | 149 ------------- redis_consumer/utils_test.py | 204 +++--------------- 5 files changed, 29 insertions(+), 494 deletions(-) delete mode 100644 redis_consumer/processing.py delete mode 100644 redis_consumer/processing_test.py diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index a1b7999a..59fc14d9 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -36,13 +36,13 @@ from skimage.external import tifffile import numpy as np +from deepcell_toolbox.processing import normalize from redis_consumer.grpc_clients import TrackingClient from redis_consumer.consumers import TensorFlowServingConsumer from redis_consumer import utils from redis_consumer import tracking from redis_consumer import settings -from redis_consumer import processing class TrackingConsumer(TensorFlowServingConsumer): @@ -92,7 +92,7 @@ def _get_tracker(self, redis_hash, hvalues, raw, segmented): # If not, the data must be normalized before being tracked. if settings.NORMALIZE_TRACKING: for frame in range(raw.shape[0]): - raw[frame, ..., 0] = processing.normalize(raw[frame, ..., 0]) + raw[frame, ..., 0] = normalize(raw[frame, ..., 0]) features = {'appearance', 'distance', 'neighborhood', 'regionprop'} tracker = tracking.CellTracker( diff --git a/redis_consumer/processing.py b/redis_consumer/processing.py deleted file mode 100644 index 208a962d..00000000 --- a/redis_consumer/processing.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2016-2020 The Van Valen Lab at the California Institute of -# Technology (Caltech), with support from the Paul Allen Family Foundation, -# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. -# All rights reserved. 
-# -# Licensed under a modified Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE -# -# The Work provided may be used for non-commercial academic purposes only. -# For any other use of the Work, including commercial use, please contact: -# vanvalenlab@gmail.com -# -# Neither the name of Caltech nor the names of its contributors may be used -# to endorse or promote products derived from this software without specific -# prior written permission. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""DEPRECATED. Please use the "deepell_toolbox" package instead. - -Functions for pre- and post-processing image data -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=W0611 - -from deepcell_toolbox import normalize -from deepcell_toolbox import mibi -from deepcell_toolbox import watershed -from deepcell_toolbox import pixelwise -from deepcell_toolbox import correct_drift - -from deepcell_toolbox.deep_watershed import deep_watershed - -# import mibi pre- and post-processing functions -from deepcell_toolbox.processing import phase_preprocess -from deepcell_toolbox.multiplex_utils import format_output_multiplex -from deepcell_toolbox.multiplex_utils import multiplex_preprocess -from deepcell_toolbox.multiplex_utils import multiplex_postprocess - -from deepcell_toolbox import retinanet_semantic_to_label_image -from deepcell_toolbox import retinanet_to_label_image - -del absolute_import -del division -del print_function - - -def multiplex_postprocess_consumer(model_output, compartment='whole-cell', - whole_cell_kwargs=None, - nuclear_kwargs=None): - """Wrapper function to control post-processing params - - Args: - model_output (dict): output to be post-processed - compartment (str): which cellular compartments to generate predictions for. 
- must be one of 'whole_cell', 'nuclear', 'both' - whole_cell_kwargs (dict): Optional list of post-processing kwargs for whole-cell prediction - nuclear_kwargs (dict): Optional list of post-processing kwargs for nuclear prediction - - Returns: - numpy.ndarray: labeled image - """ - - if whole_cell_kwargs is None: - whole_cell_kwargs = {'maxima_threshold': 0.1, 'maxima_model_smooth': 0, - 'interior_threshold': 0.3, 'interior_model_smooth': 2, - 'small_objects_threshold': 15, - 'fill_holes_threshold': 15, - 'radius': 2} - if nuclear_kwargs is None: - nuclear_kwargs = {'maxima_threshold': 0.1, 'maxima_model_smooth': 0, - 'interior_threshold': 0.6, 'interior_model_smooth': 0, - 'small_objects_threshold': 15, - 'fill_holes_threshold': 15, - 'radius': 2} - - label_images = multiplex_postprocess(model_output=model_output, compartment=compartment, - whole_cell_kwargs=whole_cell_kwargs, - nuclear_kwargs=nuclear_kwargs) - - return label_images diff --git a/redis_consumer/processing_test.py b/redis_consumer/processing_test.py deleted file mode 100644 index 36637523..00000000 --- a/redis_consumer/processing_test.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2016-2020 The Van Valen Lab at the California Institute of -# Technology (Caltech), with support from the Paul Allen Family Foundation, -# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. -# All rights reserved. -# -# Licensed under a modified Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE -# -# The Work provided may be used for non-commercial academic purposes only. -# For any other use of the Work, including commercial use, please contact: -# vanvalenlab@gmail.com -# -# Neither the name of Caltech nor the names of its contributors may be used -# to endorse or promote products derived from this software without specific -# prior written permission. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import numpy as np - -from redis_consumer.processing import multiplex_postprocess_consumer - - -# return input dicts to make sure they were passed appropriately -def mocked_postprocessing(model_output, compartment, whole_cell_kwargs, nuclear_kwargs): - return whole_cell_kwargs, nuclear_kwargs - - -def test_multiplex_postprocess_consumer(mocker): - mocker.patch('redis_consumer.processing.multiplex_postprocess', mocked_postprocessing) - - model_output = np.zeros((1, 40, 40, 4)) - compartment = 'both' - - defalt_cell_dict = {'maxima_threshold': 0.1, 'maxima_model_smooth': 0, - 'interior_threshold': 0.3, 'interior_model_smooth': 2, - 'small_objects_threshold': 15, - 'fill_holes_threshold': 15, - 'radius': 2} - - default_nuc_dict = {'maxima_threshold': 0.1, 'maxima_model_smooth': 0, - 'interior_threshold': 0.6, 'interior_model_smooth': 0, - 'small_objects_threshold': 15, - 'fill_holes_threshold': 15, - 'radius': 2} - - cell_dict, nuc_dict = multiplex_postprocess_consumer(model_output=model_output, - compartment=compartment, - whole_cell_kwargs=None, - nuclear_kwargs=None) - - assert defalt_cell_dict == cell_dict - assert default_nuc_dict == nuc_dict - - modified_cell_dict = {'maxima_threshold': 0.4, 'maxima_model_smooth': 4, - 'small_objects_threshold': 2, - 'radius': 0} - - modified_nuc_dict = {'maxima_threshold': 0.43, 'maxima_model_smooth': 41, - 'small_objects_threshold': 20, - 'radius': 4} - - cell_dict, nuc_dict = multiplex_postprocess_consumer(model_output=model_output, - compartment=compartment, - whole_cell_kwargs=modified_cell_dict, - nuclear_kwargs=modified_nuc_dict) - - assert modified_cell_dict == cell_dict - assert modified_nuc_dict == nuc_dict diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index dd6ca2b8..86a99ba6 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -135,47 +135,6 @@ def get_image(filepath): return img.astype('float32') -def pad_image(image, field): - """Pad each the input image for proper dimensions when stitiching. - - Args: - image: np.array of image data - field: receptive field size of model - - Returns: - image data padded in the x and y axes - """ - window = (field - 1) // 2 - # Pad images by the field size in the x and y axes - pad_width = [] - for i in range(len(image.shape)): - if i == image.ndim - 3: - pad_width.append((window, window)) - elif i == image.ndim - 2: - pad_width.append((window, window)) - else: - pad_width.append((0, 0)) - - return np.pad(image, pad_width, mode='reflect') - - -def unpad_image(x, pad_width): - """Unpad image padded with the pad_width. - - Args: - image (numpy.array): Image to unpad. - pad_width (list): List of pads used to pad the image with np.pad. - - Returns: - numpy.array: The unpadded image. - """ - slices = [] - for c in pad_width: - e = None if c[1] == 0 else -c[1] - slices.append(slice(c[0], e)) - return x[tuple(slices)] - - def save_numpy_array(arr, name='', subdir='', output_dir=None): """Split tensor into channels and save each as a tiff file. @@ -296,96 +255,6 @@ def zip_files(files, dest=None, prefix=None): return filepath -def reshape_matrix(X, y, reshape_size=256, is_channels_first=False): - """ - Reshape matrix of dimension 4 to have x and y of size reshape_size. - Adds overlapping slices to batches. - E.g. 
reshape_size of 256 yields (1, 1024, 1024, 1) -> (16, 256, 256, 1) - - Args: - X: raw 4D image tensor - y: label mask of 4D image data - reshape_size: size of the square output tensor - is_channels_first: default False for channel dimension last - - Returns: - reshaped `X` and `y` tensors in shape (`reshape_size`, `reshape_size`) - """ - if X.ndim != 4: - raise ValueError('reshape_matrix expects X dim to be 4, got', X.ndim) - elif y.ndim != 4: - raise ValueError('reshape_matrix expects y dim to be 4, got', y.ndim) - - image_size_x, _ = X.shape[2:] if is_channels_first else X.shape[1:3] - rep_number = np.int(np.ceil(np.float(image_size_x) / np.float(reshape_size))) - new_batch_size = X.shape[0] * (rep_number) ** 2 - - if is_channels_first: - new_X_shape = (new_batch_size, X.shape[1], reshape_size, reshape_size) - new_y_shape = (new_batch_size, y.shape[1], reshape_size, reshape_size) - else: - new_X_shape = (new_batch_size, reshape_size, reshape_size, X.shape[3]) - new_y_shape = (new_batch_size, reshape_size, reshape_size, y.shape[3]) - - new_X = np.zeros(new_X_shape, dtype='float32') - new_y = np.zeros(new_y_shape, dtype='int32') - - counter = 0 - for b in range(X.shape[0]): - for i in range(rep_number): - for j in range(rep_number): - if i != rep_number - 1: - x_start, x_end = i * reshape_size, (i + 1) * reshape_size - else: - x_start, x_end = -reshape_size, X.shape[2 if is_channels_first else 1] - - if j != rep_number - 1: - y_start, y_end = j * reshape_size, (j + 1) * reshape_size - else: - y_start, y_end = -reshape_size, y.shape[3 if is_channels_first else 2] - - if is_channels_first: - new_X[counter] = X[b, :, x_start:x_end, y_start:y_end] - new_y[counter] = y[b, :, x_start:x_end, y_start:y_end] - else: - new_X[counter] = X[b, x_start:x_end, y_start:y_end, :] - new_y[counter] = y[b, x_start:x_end, y_start:y_end, :] - - counter += 1 - - print('Reshaped feature data from {} to {}'.format(y.shape, new_y.shape)) - print('Reshaped training data from {} to {}'.format(X.shape, new_X.shape)) - return new_X, new_y - - -def rescale(image, scale, channel_axis=-1): - multichannel = False - add_channel = False - if scale == 1: - return image # no rescale necessary, short-circuit - - if len(image.shape) != 2: # we have a channel axis - try: - image = np.squeeze(image, axis=channel_axis) - add_channel = True - except: # pylint: disable=bare-except - multichannel = True # channel axis is not 1 - - rescaled_img = skimage.transform.rescale( - image, scale, - mode='edge', - anti_aliasing=False, - anti_aliasing_sigma=None, - multichannel=multichannel, - preserve_range=True, - order=0 - ) - if add_channel: - rescaled_img = np.expand_dims(rescaled_img, axis=channel_axis) - logger.debug('Rescaled image from %s to %s', image.shape, rescaled_img.shape) - return rescaled_img - - def _pick_model(label): model = settings.MODEL_CHOICES.get(label) if model is None: @@ -393,21 +262,3 @@ def _pick_model(label): raise ValueError('Label type {} is not supported'.format(label)) return model.split(':') - - -def _pick_preprocess(label): - func = settings.PREPROCESS_CHOICES.get(label) - if func is None: - logger.error('Label type %s is not supported', label) - raise ValueError('Label type {} is not supported'.format(label)) - - return func - - -def _pick_postprocess(label): - func = settings.POSTPROCESS_CHOICES.get(label) - if func is None: - logger.error('Label type %s is not supported', label) - raise ValueError('Label type {} is not supported'.format(label)) - - return func diff --git a/redis_consumer/utils_test.py 
b/redis_consumer/utils_test.py index b747e54b..2f0cdb3c 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -29,7 +29,6 @@ from __future__ import print_function import os -import pytest import tarfile import tempfile import zipfile @@ -111,6 +110,7 @@ def test_get_image(tmpdir): test_img_path = os.path.join(tmpdir, 'phase.tif') _write_image(test_img_path, 300, 300) test_img = utils.get_image(test_img_path) + print(test_img.shape) np.testing.assert_equal(test_img.shape, (300, 300, 1)) # test png files test_img_path = os.path.join(tmpdir, 'feature_0.png') @@ -120,92 +120,36 @@ def test_get_image(tmpdir): np.testing.assert_equal(test_img.shape, (400, 400, 1)) -def test_pad_image(): - # 2D images - h, w = 300, 300 - img = _get_image(h, w) - field_size = 61 - padded = utils.pad_image(img, field_size) - - new_h, new_w = h + (field_size - 1), w + (field_size - 1) - np.testing.assert_equal(padded.shape, (new_h, new_w, 1)) - - # 3D images - frames = np.random.randint(low=1, high=6) - imgs = np.vstack([_get_image(h, w)[None, ...] for i in range(frames)]) - padded = utils.pad_image(imgs, field_size) - np.testing.assert_equal(padded.shape, (frames, new_h, new_w, 1)) - - -def test_unpad_image(): - # 2D images - h, w = 300, 300 - - sizes = [ - (300, 300), - (101, 101) - ] - - pads = [ - (10, 10), - (15, 15), - (10, 15) - ] - for pad in pads: - for h, w in sizes: - raw = _get_image(h, w) - pad_width = [pad, pad, (0, 0)] - padded = np.pad(raw, pad_width, mode='reflect') - - unpadded = utils.unpad_image(padded, pad_width) - np.testing.assert_equal(unpadded.shape, (h, w, 1)) - np.testing.assert_equal(unpadded, raw) - - # 3D images - frames = np.random.randint(low=1, high=6) - imgs = np.vstack([_get_image(h, w)[None, ...] - for _ in range(frames)]) - - pad_width = [(0, 0), pad, pad, (0, 0)] - - padded = np.pad(imgs, pad_width, mode='reflect') - - unpadded = utils.unpad_image(padded, pad_width) - - np.testing.assert_equal(unpadded.shape, imgs.shape) - np.testing.assert_equal(unpadded, imgs) - - -def test_save_numpy_array(): +def test_save_numpy_array(tmpdir): + tmpdir = str(tmpdir) h, w = 30, 30 c = np.random.randint(low=1, high=4) z = np.random.randint(low=1, high=6) - with utils.get_tempdir() as tempdir: - # 2D images without channel axis - img = _get_image(h, w, 1) - img = np.squeeze(img) - files = utils.save_numpy_array(img, 'name', '/a/b/', tempdir) - assert len(files) == 1 - for f in files: - assert os.path.isfile(f) - assert f.startswith(os.path.join(tempdir, 'a', 'b')) - - # 2D images - img = _get_image(h, w, c) - files = utils.save_numpy_array(img, 'name', '/a/b/', tempdir) - assert len(files) == c - for f in files: - assert os.path.isfile(f) - assert f.startswith(os.path.join(tempdir, 'a', 'b')) - - # 3D images - imgs = np.vstack([_get_image(h, w, c)[None, ...] 
for i in range(z)]) - files = utils.save_numpy_array(imgs, 'name', '/a/b/', tempdir) - assert len(files) == c - for f in files: - assert os.path.isfile(f) - assert f.startswith(os.path.join(tempdir, 'a', 'b')) + # 2D images without channel axis + img = _get_image(h, w, 1) + img = np.squeeze(img) + files = utils.save_numpy_array(img, 'name', '/a/b/', tmpdir) + assert len(files) == 1 + for f in files: + assert os.path.isfile(f) + assert f.startswith(os.path.join(tmpdir, 'a', 'b')) + + # 2D images + img = _get_image(h, w, c) + files = utils.save_numpy_array(img, 'name', '/a/b/', tmpdir) + assert len(files) == c + for f in files: + assert os.path.isfile(f) + assert f.startswith(os.path.join(tmpdir, 'a', 'b')) + + # 3D images + imgs = np.vstack([_get_image(h, w, c)[None, ...] for i in range(z)]) + files = utils.save_numpy_array(imgs, 'name', '/a/b/', tmpdir) + assert len(files) == c + for f in files: + assert os.path.isfile(f) + assert f.startswith(os.path.join(tmpdir, 'a', 'b')) # Bad path will not fail, but will log error img = _get_image(h, w, c) @@ -279,82 +223,6 @@ def test_zip_files(tmpdir): zip_path = utils.zip_files(paths, bad_dest, prefix) -def test_reshape_matrix(): - # K.set_image_data_format('channels_last') - X = np.zeros((1, 16, 16, 3)) - y = np.zeros((1, 16, 16, 1)) - new_size = 4 - - # test resize to smaller image, divisible - new_X, new_y = utils.reshape_matrix(X, y, new_size) - new_batch = np.ceil(16 / new_size) ** 2 - assert new_X.shape == (new_batch, new_size, new_size, 3) - assert new_y.shape == (new_batch, new_size, new_size, 1) - - # test reshape with non-divisible values. - new_size = 5 - new_batch = np.ceil(16 / new_size) ** 2 - new_X, new_y = utils.reshape_matrix(X, y, new_size) - assert new_X.shape == (new_batch, new_size, new_size, 3) - assert new_y.shape == (new_batch, new_size, new_size, 1) - - # test reshape to bigger size - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(X, y, 32) - - # test wrong dimensions - bigger = np.zeros((1, 16, 16, 3, 1)) - smaller = np.zeros((1, 16, 16)) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(smaller, y, new_size) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(bigger, y, new_size) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(X, smaller, new_size) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(X, bigger, new_size) - - # channels_first - # K.set_image_data_format('channels_first') - X = np.zeros((1, 3, 16, 16)) - y = np.zeros((1, 1, 16, 16)) - new_size = 4 - - # test resize to smaller image, divisible - new_X, new_y = utils.reshape_matrix(X, y, new_size, True) - new_batch = np.ceil(16 / new_size) ** 2 - assert new_X.shape == (new_batch, 3, new_size, new_size) - assert new_y.shape == (new_batch, 1, new_size, new_size) - - # test reshape with non-divisible values. 
- new_size = 5 - new_batch = np.ceil(16 / new_size) ** 2 - new_X, new_y = utils.reshape_matrix(X, y, new_size, True) - assert new_X.shape == (new_batch, 3, new_size, new_size) - assert new_y.shape == (new_batch, 1, new_size, new_size) - - -def test_rescale(): - scales = [.5, 2] - shapes = [(4, 4, 5), (4, 4, 1), (4, 4)] - for scale in scales: - for shape in shapes: - image = np.random.random(shape) - rescaled = utils.rescale(image, 1) - np.testing.assert_array_equal(rescaled, image) - - rescaled = utils.rescale(image, scale) - expected_shape = (int(np.ceil(shape[0] * scale)), - int(np.ceil(shape[1] * scale))) - - if len(shape) > 2: - expected_shape = tuple(list(expected_shape) + [int(shape[2])]) - assert rescaled.shape == expected_shape - # scale it back - rescaled = utils.rescale(rescaled, 1 / scale) - assert rescaled.shape == shape - - def test__pick_model(mocker): mocker.patch.object(settings, 'MODEL_CHOICES', {0: 'dummymodel:0'}) res = utils._pick_model(0) @@ -364,21 +232,3 @@ def test__pick_model(mocker): with pytest.raises(ValueError): utils._pick_model(-1) - - -def test__pick_preprocess(mocker): - mocker.patch.object(settings, 'PREPROCESS_CHOICES', {0: 'pre'}) - res = utils._pick_preprocess(0) - assert res == 'pre' - - with pytest.raises(ValueError): - utils._pick_preprocess(-1) - - -def test__pick_postprocess(mocker): - mocker.patch.object(settings, 'POSTPROCESS_CHOICES', {0: 'post'}) - res = utils._pick_postprocess(0) - assert res == 'post' - - with pytest.raises(ValueError): - utils._pick_postprocess(-1) From 798cc24583fabd3440f29b471ad03dcb33cbf429 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:10:27 -0800 Subject: [PATCH 07/73] Remove tempdir utils function from storage_test.py --- redis_consumer/storage_test.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/redis_consumer/storage_test.py b/redis_consumer/storage_test.py index fd381210..ebde4f66 100644 --- a/redis_consumer/storage_test.py +++ b/redis_consumer/storage_test.py @@ -137,18 +137,19 @@ def test_get_backoff(self): backoff = client.get_backoff(attempts=5) assert backoff == max_backoff - def test_get_download_path(self, mocker): + def test_get_download_path(self, mocker, tmpdir): + tmpdir = str(tmpdir) mocker.patch('redis_consumer.storage.Storage.get_storage_client', lambda *x: True) - with utils.get_tempdir() as tempdir: - bucket = 'test-bucket' - stg = storage.Storage(bucket, tempdir) - filekey = 'upload_dir/key/to.zip' - path = stg.get_download_path(filekey, tempdir) - path2 = stg.get_download_path(filekey) - assert path == path2 - assert str(path).startswith(tempdir) - assert str(path).endswith(filekey.replace('upload_dir/', '')) + + bucket = 'test-bucket' + stg = storage.Storage(bucket, tmpdir) + filekey = 'upload_dir/key/to.zip' + path = stg.get_download_path(filekey, tmpdir) + path2 = stg.get_download_path(filekey) + assert path == path2 + assert str(path).startswith(tmpdir) + assert str(path).endswith(filekey.replace('upload_dir/', '')) class TestGoogleStorage(object): From 65edc9fddc0aea4f113a8a6f7a64cc663b439447 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:16:31 -0800 Subject: [PATCH 08/73] Replace contextlib tempdir workaround with tempfile.TemporaryDirectory --- README.md | 3 +-- redis_consumer/consumers/base_consumer.py | 7 ++--- redis_consumer/consumers/tracking_consumer.py | 9 ++++--- redis_consumer/utils.py | 26 
------------------- 4 files changed, 10 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index c89a726f..753bcac7 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,6 @@ Consumers consume Redis events. Each type of Redis event is put into a queue (e. Consumers call the `_consume` method to consume each item it finds in the queue. This method must be implemented for every consumer. - The quickest way to get a custom consumer up and running is to: 1. Add a new file for the consumer: `redis_consumer/consumers/my_new_consumer.py` @@ -45,7 +44,7 @@ def _consume(self, redis_hash): # and parsed in settings.py. model_name, model_version = 'CustomModel:1'.split(':') - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: # download the image file fname = self.storage.download(input_file_name, tempdir) # load image file as data diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 11f87126..8b697b7d 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -33,6 +33,7 @@ import logging import os import sys +import tempfile import time import timeit import urllib @@ -424,7 +425,7 @@ def get_image_scale(self, scale, image, redis_hash): return scale def save_output(self, image, save_name): - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: # Save each result channel as an image file subdir = os.path.dirname(save_name.replace(tempdir, '')) name = os.path.splitext(os.path.basename(save_name))[0] @@ -476,7 +477,7 @@ def _upload_archived_images(self, hvalues, redis_hash): """Extract all image files and upload them to storage and redis""" all_hashes = set() archive_uuid = uuid.uuid4().hex - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: fname = self.storage.download(hvalues.get('input_file_name'), tempdir) image_files = utils.get_image_files_from_dir(fname, tempdir) for i, imfile in enumerate(image_files): @@ -552,7 +553,7 @@ def _get_output_file_name(self, key): def _upload_finished_children(self, finished_children, redis_hash): # saved_files = set() - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: filename = '{}.zip'.format(uuid.uuid4().hex) zip_path = os.path.join(tempdir, filename) diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index 59fc14d9..cf2a2540 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -30,6 +30,7 @@ import json import os +import tempfile import time import timeit import uuid @@ -169,7 +170,7 @@ def _load_data(self, redis_hash, subdir, fname): uid = uuid.uuid4().hex for i, img in enumerate(tiff_stack): - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: # Save and upload the frame. segment_fname = '{}-{}-tracking-frame-{}.tif'.format( uid, hvalues.get('original_name'), i) @@ -231,7 +232,7 @@ def _load_data(self, redis_hash, subdir, fname): if status == self.final_status: # Segmentation is finished, save and load the frame. 
- with utils.get_tempdir() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
 out = self.redis.hget(segment_hash, 'output_file_name')
 frame_zip = self.storage.download(out, tempdir)
 frame_files = list(utils.iter_image_archive(
@@ -274,7 +275,7 @@ def _consume(self, redis_hash):
 'identity_started': self.name,
 })

- with utils.get_tempdir() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
 fname = self.storage.download(hvalues.get('input_file_name'),
 tempdir)
 data = self._load_data(redis_hash, tempdir, fname)
@@ -303,7 +304,7 @@ def _consume(self, redis_hash):
 tracked_data = tracker.postprocess()
 self.update_key(redis_hash, {'status': 'saving-results'})

- with utils.get_tempdir() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
 # Save lineage data to JSON file
 lineage_file = os.path.join(tempdir, 'lineage.json')
 with open(lineage_file, 'w') as fp:
diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py
index 86a99ba6..bde7de46 100644
--- a/redis_consumer/utils.py
+++ b/redis_consumer/utils.py
@@ -32,12 +32,9 @@
 import os
 import time
 import timeit
-import contextlib
 import hashlib
 import logging
-import shutil
 import tarfile
-import tempfile
 import zipfile

 import numpy as np
@@ -51,29 +48,6 @@
 logger = logging.getLogger('redis_consumer.utils')


-# Workaround for python2 not supporting `with tempfile.TemporaryDirectory() as`
-# These are unnecessary if not supporting python2
-@contextlib.contextmanager
-def cd(newdir, cleanup=lambda: True):
- prevdir = os.getcwd()
- os.chdir(os.path.expanduser(newdir))
- try:
- yield
- finally:
- os.chdir(prevdir)
- cleanup()
-
-
-@contextlib.contextmanager
-def get_tempdir():
- dirpath = tempfile.mkdtemp()
-
- def cleanup():
- return shutil.rmtree(dirpath)
- with cd(dirpath, cleanup):
- yield dirpath
-
-
 def iter_image_archive(zip_path, destination):
 """Extract all files in archive and yield the paths of all images.

From 6083e88e927dd0f860f8971cd6c38b0492334bcc Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Wed, 20 Jan 2021 19:16:54 -0800
Subject: [PATCH 09/73] Change container base image back to python:3.6
---
 Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index b32a145f..6ed5b706 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -23,7 +23,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License. 
# ============================================================================ -FROM python:3.7 +FROM python:3.6 WORKDIR /usr/src/app From 8f5c80e58f7e9ccfa087e553571207d97f0eea6b Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:19:29 -0800 Subject: [PATCH 10/73] missed a utils.get_tempdir --- redis_consumer/consumers/base_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 8b697b7d..df5179c4 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -264,7 +264,7 @@ def is_valid_hash(self, redis_hash): def download_image(self, image_path): """Download file from bucket and load it as an image""" - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: fname = self.storage.download(image_path, tempdir) image = utils.get_image(fname) return image From f01b9982f9afc98ce48e77b4d800294541e2c633 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:22:41 -0800 Subject: [PATCH 11/73] Small logging improvements + metadata bugfix --- redis_consumer/consumers/base_consumer.py | 22 ++++------------------ redis_consumer/grpc_clients.py | 1 + 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index df5179c4..34f18469 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -207,22 +207,8 @@ def consume(self): status = self.failed_status if status == self.final_status: - required_fields = [ - 'model_name', - 'model_version', - 'preprocess_function', - 'postprocess_function', - ] - result = self.redis.hmget(redis_hash, *required_fields) - hvals = dict(zip(required_fields, result)) - self.logger.debug('Consumed key %s (model %s:%s, ' - 'preprocessing: %s, postprocessing: %s) ' - '(%s retries) in %s seconds.', - redis_hash, hvals.get('model_name'), - hvals.get('model_version'), - hvals.get('preprocess_function'), - hvals.get('postprocess_function'), - 0, timeit.default_timer() - start) + self.logger.debug('Consumed key `%s` in %s seconds.', + redis_hash, timeit.default_timer() - start) if status in self.finished_statuses: # this key is done. remove the key from the processing queue. 
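
For context on the metadata bugfix in the hunks below: `get_model_metadata` returns a list with one dict per model input tensor, so the shape string has to be read from the first entry rather than from the list itself. A minimal sketch with illustrative values (the dict layout follows `make_model_metadata_of_size` in testing_utils.py; the shape string itself is made up):

    # One metadata dict per input tensor; this model has a single input.
    model_metadata = [{
        'in_tensor_name': 'image',
        'in_tensor_dtype': 'DT_FLOAT',
        'in_tensor_shape': '-1,128,128,1',  # illustrative shape string
    }]

    # Indexing [0] is the fix; model_metadata['in_tensor_shape'] would
    # raise TypeError because model_metadata is a list, not a dict.
    shape = [int(x) for x in model_metadata[0]['in_tensor_shape'].split(',')]
    rank = len(shape) - 1  # ignore the batch dimension -> 3
    channels = shape[-1]   # -> 1
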
@@ -253,7 +239,7 @@ def __init__(self, self._redis_values = dict() super(TensorFlowServingConsumer, self).__init__( redis_client, storage_client, queue, **kwargs) - + def is_valid_hash(self, redis_hash): """Don't run on zip files""" if redis_hash is None: @@ -272,7 +258,7 @@ def download_image(self, image_path): def validate_model_input(self, image, model_name, model_version): """Validate that the input image meets the workflow requirements.""" model_metadata = self.get_model_metadata(model_name, model_version) - shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] + shape = [int(x) for x in model_metadata[0]['in_tensor_shape'].split(',')] rank = len(shape) - 1 # ignoring batch dimension channels = shape[-1] diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index a944c88b..dffbfa79 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -168,6 +168,7 @@ class PredictClient(GrpcClient): def __init__(self, host, model_name, model_version): super(PredictClient, self).__init__(host) + self.logger = logging.getLogger(f'{model_name}:{model_version}:gRPC') self.model_name = model_name self.model_version = model_version From 0de2f2d13593a12dfc2e97fbcf7993b76e835cb9 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 20 Jan 2021 19:32:12 -0800 Subject: [PATCH 12/73] remove some hacky attributes. --- redis_consumer/consumers/base_consumer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 34f18469..1674f1ef 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -234,9 +234,6 @@ def __init__(self, storage_client, queue, **kwargs): - # Create some attributes only used during consume() - self._redis_hash = None - self._redis_values = dict() super(TensorFlowServingConsumer, self).__init__( redis_client, storage_client, queue, **kwargs) From 5e8512da6c917c9a7908c5a7ce9c9b28ebb07b74 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 10:36:23 -0800 Subject: [PATCH 13/73] Add tests for the GrpcModelWrapper --- redis_consumer/grpc_clients_test.py | 82 +++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/redis_consumer/grpc_clients_test.py b/redis_consumer/grpc_clients_test.py index 2f454a08..755a82d8 100644 --- a/redis_consumer/grpc_clients_test.py +++ b/redis_consumer/grpc_clients_test.py @@ -28,6 +28,8 @@ from __future__ import division from __future__ import print_function +import logging + import pytest import numpy as np @@ -35,9 +37,21 @@ from tensorflow.core.framework.tensor_pb2 import TensorProto from tensorflow_serving.apis.predict_pb2 import PredictResponse -from redis_consumer.testing_utils import _get_image +from redis_consumer.testing_utils import _get_image, make_model_metadata_of_size + +from redis_consumer import grpc_clients, settings + -from redis_consumer import grpc_clients +class DummyPredictClient(object): + # pylint: disable=unused-argument + def __init__(self, host, model_name, model_version): + self.logger = logging.getLogger(self.__class__.__name__) + + def predict(self, request_data, request_timeout=10): + retval = {} + for i, d in enumerate(request_data): + retval[f'prediction{i}'] = d.get('data') + return retval def test_make_tensor_proto(): @@ -76,5 +90,67 @@ def test_grpc_response_to_dict(): response = 
PredictResponse() response.outputs['prediction'].CopyFrom(tensor_proto) response.outputs['prediction'].dtype = 32 + with pytest.raises(KeyError): - response_dict = grpc_clients.grpc_response_to_dict(response) \ No newline at end of file + response_dict = grpc_clients.grpc_response_to_dict(response) + + +class TestGrpcModelWrapper(object): + shape = (1, 300, 300, 1) + name = 'test-model' + version = '0' + + def _get_metadata(self): + metadata_fn = make_model_metadata_of_size(self.shape) + return metadata_fn(self.name, self.version) + + def test_init(self): + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + assert wrapper.input_shape == self.shape + + multi_metadata = [metadata, metadata] + with pytest.raises(NotImplementedError): + wrapper = grpc_clients.GrpcModelWrapper(None, multi_metadata) + + def test_get_batch_size(self, mocker): + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + + for m in (.5, 1, 2): + mocker.patch.object(settings, 'TF_MIN_MODEL_SIZE', self.shape[1] * m) + batch_size = wrapper.get_batch_size() + assert batch_size == settings.TF_MAX_BATCH_SIZE * m * m + + def test_send_grpc(self, mocker): + client = DummyPredictClient(1, 2, 3) + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(client, metadata) + + input_data = np.ones(self.shape) + result = wrapper.send_grpc(input_data) + assert isinstance(result, list) + assert len(result) == 1 + np.testing.assert_array_equal(result[0], input_data) + + input_data = np.ones(self.shape) + mocker.patch.object(wrapper, '_in_tensor_dtype', 'DT_HALF') + result = wrapper.send_grpc(input_data) + assert isinstance(result, list) + assert len(result) == 1 + np.testing.assert_array_equal(result[0], input_data) + + def test_predict(self, mocker): + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + + def mock_send_grpc(img): + return [img] + + mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc) + + batch_size = 2 + input_data = np.ones((batch_size * 2, 30, 30, 1)) + + results = wrapper.predict(input_data, batch_size=batch_size) + np.testing.assert_array_equal(input_data, results) From 4fb468611b249b2f1942d738e69857aca2c7048e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 10:36:36 -0800 Subject: [PATCH 14/73] Simplify some of the prediction handling in the wrapper --- redis_consumer/grpc_clients.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index dffbfa79..12a8e8a2 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -365,7 +365,7 @@ def __init__(self, client, model_metadata): self._in_tensor_dtype = str(self._metadata['in_tensor_dtype']).upper() shape = [int(x) for x in self._metadata['in_tensor_shape'].split(',')] - self.input_shape = shape + self.input_shape = tuple(shape) def send_grpc(self, img): """Use the TensorFlow Serving gRPC API for model inference on an image. 
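
As background for these wrapper changes, here is a worked example of the batch-size arithmetic that `test_get_batch_size` above exercises. The settings values are assumptions for illustration, not the project defaults:

    input_shape = (-1, 256, 256, 1)  # parsed from the model metadata
    TF_MAX_BATCH_SIZE = 64   # assumed value
    TF_MIN_MODEL_SIZE = 128  # assumed value

    # The ratio shrinks the maximum batch size as the model input grows:
    # a 256x256x1 input is 4x the minimum area, so it gets 1/4 the batch.
    ratio = (input_shape[-3] / TF_MIN_MODEL_SIZE) * \
            (input_shape[-2] / TF_MIN_MODEL_SIZE) * input_shape[-1]
    batch_size = int(TF_MAX_BATCH_SIZE // ratio)  # 64 // 4.0 -> 16
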
@@ -393,9 +393,6 @@ def send_grpc(self, img): [r.shape for r in results], timeit.default_timer() - start) - if len(results) == 1: - results = results[0] - return results def get_batch_size(self): @@ -416,13 +413,10 @@ def predict(self, tiles, batch_size): for t in range(0, tiles.shape[0], batch_size): output = self.send_grpc(tiles[t:t + batch_size]) - if not isinstance(output, list): - output = [output] - if len(results) == 0: results = output else: for i, o in enumerate(output): results[i] = np.vstack((results[i], o)) - return results + return results[0] if len(results) == 1 else results From 41aa3fb97db93e529940f956a9a7b8dff552bf2d Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 10:36:54 -0800 Subject: [PATCH 15/73] DummyStorage should return the full path to the file. --- redis_consumer/testing_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/redis_consumer/testing_utils.py b/redis_consumer/testing_utils.py index 99105b17..5fb816a1 100644 --- a/redis_consumer/testing_utils.py +++ b/redis_consumer/testing_utils.py @@ -76,12 +76,14 @@ def download(self, path, dest): img = _get_image() base, ext = os.path.splitext(path) _path = '{}{}{}'.format(base, i, ext) - tiff.imsave(os.path.join(dest, _path), img) - paths.append(_path) + outpath = os.path.join(dest, _path) + tiff.imsave(outpath, img) + paths.append(outpath) return utils.zip_files(paths, dest) img = _get_image() - tiff.imsave(os.path.join(dest, path), img) - return path + outpath = os.path.join(dest, path) + tiff.imsave(outpath, img) + return outpath def upload(self, zip_path, subdir=None): return 'zip_path.zip', 'blob.public_url' From 06820fd91d18b7d0fd71f0752164a25d70e155c6 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 10:37:34 -0800 Subject: [PATCH 16/73] get correct_drift from deepcell_toolbox. 
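This imports correct_drift directly from deepcell_toolbox.processing. A
minimal usage sketch (assuming only the call signature used in the diff
below, where correct_drift takes the raw movie X and the label movie y and
returns both, aligned frame to frame; the shapes here are placeholders):

    import numpy as np
    from deepcell_toolbox.processing import correct_drift

    # hypothetical 5-frame movie with matching label masks
    X = np.random.random((5, 64, 64))
    y = np.random.randint(0, 3, size=(5, 64, 64))

    # remove stage drift so the tracker sees aligned frames
    X, y = correct_drift(X, y)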
--- redis_consumer/consumers/tracking_consumer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index cf2a2540..167d207f 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -37,7 +37,9 @@ from skimage.external import tifffile import numpy as np + from deepcell_toolbox.processing import normalize +from deepcell_toolbox.processing import correct_drift from redis_consumer.grpc_clients import TrackingClient from redis_consumer.consumers import TensorFlowServingConsumer @@ -287,7 +289,7 @@ def _consume(self, redis_hash): # Correct for drift if enabled if settings.DRIFT_CORRECT_ENABLED: t = timeit.default_timer() - data['X'], data['y'] = processing.correct_drift(data['X'], data['y']) + data['X'], data['y'] = correct_drift(data['X'], data['y']) self.logger.debug('Drift correction complete in %s seconds.', timeit.default_timer() - t) From 22c40fa86d51a08320b9c9192dc3599fe78c76fa Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 13:40:30 -0800 Subject: [PATCH 17/73] PEP8 --- redis_consumer/grpc_clients.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 12a8e8a2..ff25c726 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -127,6 +127,7 @@ def make_tensor_proto(data, dtype): return tensor_proto + class GrpcClient(object): """Abstract class for all gRPC clients. From d0f28010a3c3543a9783e85c49a677dd00412778 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:31:02 -0800 Subject: [PATCH 18/73] Edit make_model_metadata_of_size to support multiple metadata inputs --- redis_consumer/testing_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/redis_consumer/testing_utils.py b/redis_consumer/testing_utils.py index 5fb816a1..688ecfdd 100644 --- a/redis_consumer/testing_utils.py +++ b/redis_consumer/testing_utils.py @@ -94,10 +94,16 @@ def get_public_url(self, zip_path): def make_model_metadata_of_size(model_shape=(-1, 256, 256, 2)): + model_shape = model_shape if isinstance(model_shape, list) else [model_shape] + def get_model_metadata(model_name, model_version): # pylint: disable=unused-argument - return [{ - 'in_tensor_name': 'image', - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - }] + output = [] + + for ms in model_shape: + output.append({ + 'in_tensor_name': 'image', + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in ms), + }) + return output return get_model_metadata From bf53b6d102a692bbfbb3e4bd47935e3c436909d6 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:31:41 -0800 Subject: [PATCH 19/73] validate_model_input makes no assumptions about length of metadata --- redis_consumer/consumers/base_consumer.py | 35 +++++++++++++++-------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 1674f1ef..e8019b97 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -255,25 +255,36 @@ def download_image(self, image_path): def validate_model_input(self, image, model_name, 
model_version):
         """Validate that the input image meets the workflow requirements."""
         model_metadata = self.get_model_metadata(model_name, model_version)
-        shape = [int(x) for x in model_metadata[0]['in_tensor_shape'].split(',')]
+        parse_shape = lambda x: tuple(int(y) for y in x.split(','))
+        shapes = [parse_shape(x['in_tensor_shape']) for x in model_metadata]

-        rank = len(shape) - 1  # ignoring batch dimension
-        channels = shape[-1]
+        # cast image as a list to match the list of shapes.
+        image = [image] if not isinstance(image, list) else image

-        errtext = (f'Invalid image shape: {image.shape}. '
-                   f'The {self.queue} job expects images of shape '
-                   f'[height, widths, {channels}]')
+        errtext = (f'Invalid image shape: {[s.shape for s in image]}. '
+                   f'The {self.queue} job expects images of shape {shapes}')

-        if len(image.shape) != rank:
+        if len(image) != len(shapes):
             raise ValueError(errtext)

-        if image.shape[0] == channels:
-            image = np.rollaxis(image, 0, rank)
+        validated = []

-        if image.shape[rank - 1] != channels:
-            raise ValueError(errtext)
+        for img, shape in zip(image, shapes):
+            rank = len(shape) - 1  # ignoring batch dimension
+            channels = shape[-1]

-        return image
+            if len(img.shape) != rank:
+                raise ValueError(errtext)
+
+            if img.shape[0] == channels:
+                img = np.rollaxis(img, 0, rank)
+
+            if img.shape[rank - 1] != channels:
+                raise ValueError(errtext)
+
+            validated.append(img)
+
+        return validated[0] if len(validated) == 1 else validated

     def _get_predict_client(self, model_name, model_version):
         """Returns the TensorFlow Serving gRPC client.

From 6b5aac5e88599bc216544438146171d3d0f05bf8 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Thu, 21 Jan 2021 14:33:50 -0800
Subject: [PATCH 20/73] Create get_image_label which returns 0 if label
 detection is off. deprecate model_name and model_version being pulled from
 job hash.

---
 redis_consumer/consumers/image_consumer.py | 54 ++++++++++------------
 1 file changed, 25 insertions(+), 29 deletions(-)

diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py
index 2281e1fd..ee858cbd 100644
--- a/redis_consumer/consumers/image_consumer.py
+++ b/redis_consumer/consumers/image_consumer.py
@@ -57,8 +57,8 @@ def detect_label(self, image):
         app = self.get_grpc_app(settings.LABEL_DETECT_MODEL, LabelDetection)

         if not settings.LABEL_DETECT_ENABLED:
-            self.logger.debug('Label detection disabled. Label set to None.')
-            return None
+            self.logger.debug('Label detection disabled. Label set to 0.')
+            return 0

         # Use NuclearSegmentation as default model
         batch_size = app.model.get_batch_size()
         detected_label = app.predict(image, batch_size=batch_size)
@@ -68,6 +68,21 @@ def detect_label(self, image):

         return detected_label

+    def get_image_label(self, label, image, redis_hash):
+        """Calculate label of image."""
+        if not label:
+            # Detect label of image (defaults to 0 if detection is disabled)
+            label = self.detect_label(image)
+            self.logger.debug('Image label detected: %s', label)
+            self.update_key(redis_hash, {'label': label})
+        else:
+            label = int(label)
+            self.logger.debug('Image label already calculated: %s', label)
+            if label not in settings.APPLICATION_CHOICES:
+                raise ValueError('Label type {} is not supported'.format(label))
+
+        return label
+
     def _consume(self, redis_hash):
         start = timeit.default_timer()
         hvals = self.redis.hgetall(redis_hash)
@@ -85,10 +100,6 @@ def _consume(self, redis_hash):
             'identity_started': self.name,
         })

-        # Overridden with LABEL_DETECT_ENABLED
-        model_name = hvals.get('model_name')
-        model_version = hvals.get('model_version')
-
         _ = timeit.default_timer()

         # Load input image
@@ -106,27 +117,14 @@ def _consume(self, redis_hash):
         scale = hvals.get('scale', '')
         scale = self.get_image_scale(scale, image, redis_hash)

-        label = None
-        if settings.LABEL_DETECT_ENABLED and model_name and model_version:
-            self.logger.warning('Label Detection is enabled, but the model'
-                                ' %s:%s was specified in Redis.',
-                                model_name, model_version)
+        label = hvals.get('label', '')
+        label = self.get_image_label(label, image, redis_hash)

-        elif settings.LABEL_DETECT_ENABLED:
-            # Detect image label type
-            label = hvals.get('label', '')
-            if not label:
-                label = self.detect_label(image)
-                self.logger.debug('Image label detected: %s', label)
-                self.update_key(redis_hash, {'label': str(label)})
-            else:
-                label = int(label)
-                self.logger.debug('Image label already calculated: %s', label)
+        # Grab appropriate model and application class
+        model = settings.MODEL_CHOICES[label]
+        app_cls = settings.APPLICATION_CHOICES[label]

-            label = 0  # TODO: remove this! hotfix for bad label detection.
-
-        # Grap appropriate model
-        model_name, model_version = utils._pick_model(label)
+        model_name, model_version = model.split(':')

         # Validate input image
         # TODO: batch dimension wonkiness
         image = self.validate_model_input(image, model_name, model_version)
         image = np.expand_dims(image, axis=0)  # add batch dim back

-        app_cls = settings.APPLICATION_CHOICES.get(label, NuclearSegmentation)
-
         # Send data to the model
         self.update_key(redis_hash, {'status': 'predicting'})

-        app = self.get_grpc_app(f'{model_name}:{model_version}', app_cls)
+        app = self.get_grpc_app(model, app_cls)

-        results = app.predict(image, image_mpp=scale,
+        results = app.predict(image, image_mpp=scale * app.model_mpp,
                               batch_size=app.model.get_batch_size())

         # Save the post-processed results to a file

From 994715d50712b67e9878b57420a51a8886388b4e Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Thu, 21 Jan 2021 14:35:00 -0800
Subject: [PATCH 21/73] Update scale detection to not do any mpp work.
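detect_scale now returns just the detected scale factor (and 1 when
detection is disabled) instead of folding in the model resolution. A rough
sketch of the caller-side pattern this implies, using names from the diffs
below:

    scale = consumer.detect_scale(image)  # bare scale factor, no mpp applied
    results = app.predict(image,
                          image_mpp=scale * app.model_mpp,
                          batch_size=app.model.get_batch_size())

In short: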
Multiply model_mpp by scale when calling predict() --- redis_consumer/consumers/base_consumer.py | 5 +++-- redis_consumer/consumers/multiplex_consumer.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index e8019b97..1a4d8e0c 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -391,7 +391,7 @@ def detect_scale(self, image): if not settings.SCALE_DETECT_ENABLED: self.logger.debug('Scale detection disabled.') - return app.model_mpp + return 1 # app.model_mpp batch_size = app.model.get_batch_size() detected_scale = app.predict(image, batch_size=batch_size) @@ -399,7 +399,8 @@ def detect_scale(self, image): self.logger.debug('Scale %s detected in %s seconds', detected_scale, timeit.default_timer() - start) - return app.model_mpp * detected_scale + # detected_scale = detected_scale * app.model_mpp + return detected_scale def get_image_scale(self, scale, image, redis_hash): """Calculate scale of image and rescale""" diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index 8ba6f124..e8ae7ebd 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -57,7 +57,7 @@ def detect_scale(self, image): if not settings.SCALE_DETECT_ENABLED: self.logger.debug('Scale detection disabled.') - return app.model_mpp + return 1 # TODO: What to do with multi-channel data? # detected_scale = app.predict(image[..., 0]) @@ -66,7 +66,7 @@ def detect_scale(self, image): self.logger.debug('Scale %s detected in %s seconds', detected_scale, timeit.default_timer() - start) - return app.model_mpp * detected_scale + return detected_scale def _consume(self, redis_hash): start = timeit.default_timer() @@ -117,7 +117,7 @@ def _consume(self, redis_hash): app = self.get_grpc_app( f'{model_name}:{model_version}', MultiplexSegmentation) - results = app.predict(image, image_mpp=scale, + results = app.predict(image, image_mpp=scale * app.model_mpp, batch_size=settings.TF_MAX_BATCH_SIZE) # Save the post-processed results to a file From 14a716b4db5de30cc3b14442808eab7eb5b23fab Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:35:25 -0800 Subject: [PATCH 22/73] Move input validation after scale detection --- redis_consumer/consumers/multiplex_consumer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index e8ae7ebd..249d04de 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -97,9 +97,6 @@ def _consume(self, redis_hash): # squeeze extra dimension that is added by get_image image = np.squeeze(image) - # Validate input image - image = self.validate_model_input(image, model_name, model_version) - # Pre-process data before sending to the model self.update_key(redis_hash, { 'status': 'pre-processing', @@ -110,12 +107,14 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) - # Rescale each channel of the image image = np.expand_dims(image, axis=0) # add in the batch dim + # Validate input image + image = self.validate_model_input(image, model_name, model_version) + # Send data to the model - app = self.get_grpc_app( - 
f'{model_name}:{model_version}', MultiplexSegmentation) + app = self.get_grpc_app(settings.MULTIPLEX_MODEL, + MultiplexSegmentation) results = app.predict(image, image_mpp=scale * app.model_mpp, batch_size=settings.TF_MAX_BATCH_SIZE) From 642bdaae767ad42854cee134167cdd0843b50a90 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:35:50 -0800 Subject: [PATCH 23/73] Add tests for new functions --- .../consumers/base_consumer_test.py | 229 ++++++++++-------- 1 file changed, 131 insertions(+), 98 deletions(-) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index 98247cb5..dd06189b 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -28,8 +28,9 @@ from __future__ import division from __future__ import print_function -import itertools import json +import os +import random import time import numpy as np @@ -37,24 +38,29 @@ import pytest from redis_consumer import consumers +from redis_consumer.grpc_clients import GrpcModelWrapper from redis_consumer import settings -from redis_consumer.testing_utils import Bunch, DummyStorage, redis_client +from redis_consumer.testing_utils import _get_image +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client # pylint: disable=unused-import +from redis_consumer.testing_utils import make_model_metadata_of_size class TestConsumer(object): # pylint: disable=R0201,W0621 def test_get_redis_hash(self, mocker, redis_client): mocker.patch.object(settings, 'EMPTY_QUEUE_TIMEOUT', 0.01) - queue_name = 'q' - consumer = consumers.Consumer(redis_client, None, queue_name) + queue = 'q' + consumer = consumers.Consumer(redis_client, None, queue) # test emtpy queue assert consumer.get_redis_hash() is None # test that invalid items are not processed and are removed. 
item = 'item to process' - redis_client.lpush(queue_name, item) + redis_client.lpush(queue, item) # is_valid_hash returns True by default assert consumer.get_redis_hash() == item assert redis_client.llen(consumer.processing_queue) == 1 @@ -64,7 +70,7 @@ def test_get_redis_hash(self, mocker, redis_client): # test invalid hash is failed and removed from queue mocker.patch.object(consumer, 'is_valid_hash', return_value=False) - redis_client.lpush(queue_name, 'invalid') + redis_client.lpush(queue, 'invalid') assert consumer.get_redis_hash() is None # invalid hash, returns None # invalid has was removed from the processing queue assert redis_client.llen(consumer.processing_queue) == 0 @@ -76,9 +82,9 @@ def test_get_redis_hash(self, mocker, redis_client): assert consumer.get_redis_hash() is None def test_purge_processing_queue(self, redis_client): - queue_name = 'q' + queue = 'q' keys = ['abc', 'def', 'xyz'] - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) # set keys in processing queue for key in keys: redis_client.lpush(consumer.processing_queue, key) @@ -116,16 +122,16 @@ def test_handle_error(self, redis_client): assert redis_values.get('status') == 'failed' def test__put_back_hash(self, redis_client): - queue_name = 'q' + queue = 'q' # test emtpy queue - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) consumer._put_back_hash('DNE') # should be None, shows warning # put back the proper item item = 'redis_hash1' redis_client.lpush(consumer.processing_queue, item) - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) consumer._put_back_hash(item) assert redis_client.llen(consumer.processing_queue) == 0 assert redis_client.llen(consumer.queue) == 1 @@ -134,7 +140,7 @@ def test__put_back_hash(self, redis_client): # put back the wrong item other = 'otherhash' redis_client.lpush(consumer.processing_queue, other, item) - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) consumer._put_back_hash(item) assert redis_client.llen(consumer.processing_queue) == 0 assert redis_client.llen(consumer.queue) == 2 @@ -143,12 +149,12 @@ def test__put_back_hash(self, redis_client): def test_consume(self, mocker, redis_client): mocker.patch.object(settings, 'EMPTY_QUEUE_TIMEOUT', 0) - queue_name = 'q' + queue = 'q' keys = [str(x) for x in range(1, 10)] err = OSError('thrown on purpose') i = 0 - consumer = consumers.Consumer(redis_client, DummyStorage(), queue_name) + consumer = consumers.Consumer(redis_client, DummyStorage(), queue) def throw_error(*_, **__): raise err @@ -208,6 +214,7 @@ def test__consume(self): class TestTensorFlowServingConsumer(object): # pylint: disable=R0201,W0613,W0621 + def test_is_valid_hash(self, mocker, redis_client): storage = DummyStorage() mocker.patch.object(redis_client, 'hget', lambda x, y: x.split(':')[-1]) @@ -222,41 +229,86 @@ def test_is_valid_hash(self, mocker, redis_client): assert consumer.is_valid_hash('predict:1234567890:file.tiff') is True assert consumer.is_valid_hash('predict:1234567890:file.png') is True - def test__get_predict_client(self, redis_client): - stg = DummyStorage() - consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') - - with pytest.raises(ValueError): - consumer._get_predict_client('model_name', 'model_version') + def test_download_image(self, 
redis_client): + storage = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q') - consumer._get_predict_client('model_name', 1) + image = consumer.download_image('test.tif') + assert isinstance(image, np.ndarray) + assert not os.path.exists('test.tif') - def test_grpc_image(self, mocker, redis_client): + def test_validate_model_input(self, mocker, redis_client): storage = DummyStorage() - queue = 'q' + consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q') - consumer = consumers.TensorFlowServingConsumer( - redis_client, storage, queue) + model_input_shape = (-1, 32, 32, 1) - model_shape = (-1, 128, 128, 1) + mocked_metadata = make_model_metadata_of_size(model_input_shape) + mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata) - def _get_predict_client(model_name, model_version): - return Bunch(predict=lambda x, y: { - 'prediction': x[0]['data'] - }) + # test valid channels last shapes + valid_input_shapes = [ + (32, 32, 1), # exact same shape + (64, 64, 1), # bigger + (32, 32, 1), # smaller + (33, 31, 1), # mixed + ] + for shape in valid_input_shapes: + # check channels last + img = np.ones(shape) + valid_img = consumer.validate_model_input(img, 'model', '1') + np.testing.assert_array_equal(img, valid_img) + + # should also work for channels first + img = np.rollaxis(img, -1, 0) + valid_img = consumer.validate_model_input(img, 'model', '1') + expected_img = np.rollaxis(img, 0, img.ndim) + np.testing.assert_array_equal(expected_img, valid_img) + + # test invalid shapes + invalid_input_shapes = [ + (32, 1), # rank too small + (32, 32, 32, 1), # rank too large + (32, 32, 2), # wrong channels + (16, 64, 2), # wrong channels with mixed shape + ] + for shape in invalid_input_shapes: + img = np.ones(shape) + with pytest.raises(ValueError): + consumer.validate_model_input(img, 'model', '1') + + # test multiple inputs/metadata + count = 3 + model_input_shape = [(-1, 32, 32, 1)] * count + mocked_metadata = make_model_metadata_of_size(model_input_shape) + mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata) + image = [np.ones(s) for s in valid_input_shapes[:count]] + valid_img = consumer.validate_model_input(image, 'model', '1') + # each image should be validated + for i, j in zip(image, valid_img): + np.testing.assert_array_equal(i, j) + + # metadata and image counts do not match + with pytest.raises(ValueError): + image = [np.ones(s) for s in valid_input_shapes[:count]] + consumer.validate_model_input(img, 'model', '1') - mocker.patch.object(consumer, '_get_predict_client', _get_predict_client) + # correct number of inputs, but one invalid entry + with pytest.raises(ValueError): + image = [np.ones(s) for s in valid_input_shapes[:count]] + # set a random entry to be invalid + i = random.randint(0, count) + image[i] = np.ones(random.choice(invalid_input_shapes)) + consumer.validate_model_input(image, 'model', '1') - img = np.zeros((1, 32, 32, 3)) - out = consumer.grpc_image(img, 'model', 1, model_shape, 'i', 'DT_HALF') - assert img.shape == out.shape - assert img.sum() == out.sum() + def test__get_predict_client(self, redis_client): + stg = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') + + with pytest.raises(ValueError): + consumer._get_predict_client('model_name', 'model_version') - img = np.zeros((32, 32, 3)) - consumer._redis_hash = 'not none' - out = consumer.grpc_image(img, 'model', 1, model_shape, 'i', 'DT_HALF') - assert (1,) + img.shape == out.shape - assert 
img.sum() == out.sum() + consumer._get_predict_client('model_name', 1) def test_get_model_metadata(self, mocker, redis_client): model_shape = (-1, 216, 216, 1) @@ -338,67 +390,10 @@ def _get_bad_predict_client(model_name, model_version): _get_bad_predict_client) consumer.get_model_metadata('model', 1) - def test_predict(self, mocker, redis_client): - model_shape = (-1, 128, 128, 1) - stg = DummyStorage() - consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') - - mocker.patch.object(settings, 'TF_MAX_BATCH_SIZE', 2) - mocker.patch.object(consumer, 'get_model_metadata', lambda x, y: [{ - 'in_tensor_name': 'image', - 'in_tensor_dtype': 'DT_HALF', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - }]) - - def grpc_image(data, *args, **kwargs): - return data - - def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - return [data, data] - - image_shapes = [ - (256, 256, 1), - (128, 128, 1), - (64, 64, 1), - (100, 100, 1), - (300, 300, 1), - (257, 301, 1), - (65, 127, 1), - (127, 129, 1), - ] - grpc_funcs = (grpc_image, grpc_image_list) - untiles = (False, True) - prod = itertools.product(image_shapes, grpc_funcs, untiles) - - for image_shape, grpc_func, untile in prod: - x = np.random.random(image_shape) - mocker.patch.object(consumer, 'grpc_image', grpc_func) - - consumer.predict(x, model_name='modelname', model_version=0, - untile=untile) - - # test image larger than max dimensions - with pytest.raises(ValueError): - mocker.patch.object(settings, 'MAX_IMAGE_WIDTH', 300) - mocker.patch.object(settings, 'MAX_IMAGE_HEIGHT', 300) - x = np.random.random((301, 301, 1)) - consumer.predict(x, model_name='modelname', model_version=0) - - # test mismatch of input data and model shape - with pytest.raises(ValueError): - x = np.random.random((5,)) - consumer.predict(x, model_name='modelname', model_version=0) - - # test multiple model metadata inputs are not supported - with pytest.raises(ValueError): - mocker.patch.object(consumer, 'get_model_metadata', grpc_image_list) - x = np.random.random((300, 300, 1)) - consumer.predict(x, model_name='modelname', model_version=0) - def test_get_image_scale(self, mocker, redis_client): stg = DummyStorage() consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') - image = np.random.random((256, 256, 1)) + image = _get_image(256, 256, 1) # test no scale provided expected = 2 @@ -421,6 +416,44 @@ def test_get_image_scale(self, mocker, redis_client): scale = settings.MIN_SCALE - 0.1 consumer.get_image_scale(scale, image, 'some hash') + def test_get_grpc_app(self, mocker, redis_client): + stg = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') + + get_metadata = make_model_metadata_of_size() + get_mock_client = lambda *x: Bunch(predict=lambda *x: None) + mocker.patch.object(consumer, 'get_model_metadata', get_metadata) + mocker.patch.object(consumer, '_get_predict_client', get_mock_client) + + app = consumer.get_grpc_app('model:0', lambda x: x) + + assert isinstance(app, GrpcModelWrapper) + + def test_detect_scale(self, mocker, redis_client): + # pylint: disable=W0613 + shape = (1, 256, 256, 1) + consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + + image = _get_image(shape[1] * 2, shape[2] * 2, shape[3]) + + expected_scale = random.uniform(0.5, 1.5) + # model_mpp = random.uniform(0.5, 1.5) + + mock_app = Bunch( + predict=lambda *x, **y: expected_scale, + # model_mpp=model_mpp, + model=Bunch(get_batch_size=lambda *x: 1)) + + mocker.patch.object(consumer, 
'get_grpc_app', lambda *x: mock_app) + + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', False) + scale = consumer.detect_scale(image) + assert scale == 1 # model_mpp + + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) + scale = consumer.detect_scale(image) + assert scale == expected_scale # * model_mpp + class TestZipFileConsumer(object): # pylint: disable=R0201,W0613,W0621 From 3b3dc625cbc29496a88f614e4aa94b7a920a7cda Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:36:13 -0800 Subject: [PATCH 24/73] Vastly simplify tests for imagefile and multiplex consumers. --- .../consumers/image_consumer_test.py | 185 +++++++----------- .../consumers/multiplex_consumer_test.py | 120 ++++-------- 2 files changed, 110 insertions(+), 195 deletions(-) diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py index 019681ea..5d460dcb 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/image_consumer_test.py @@ -28,7 +28,7 @@ from __future__ import division from __future__ import print_function -import itertools +import random import numpy as np @@ -36,8 +36,11 @@ from redis_consumer import consumers from redis_consumer import settings -from redis_consumer.testing_utils import DummyStorage, redis_client -from redis_consumer.testing_utils import _get_image, make_model_metadata_of_size + +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client +from redis_consumer.testing_utils import _get_image class TestImageFileConsumer(object): @@ -45,118 +48,62 @@ class TestImageFileConsumer(object): def test_detect_label(self, mocker, redis_client): # pylint: disable=W0613 - model_shape = (1, 216, 216, 1) - consumer = consumers.ImageFileConsumer(redis_client, None, 'q') + shape = (1, 256, 256, 1) + queue = 'q' + consumer = consumers.ImageFileConsumer(redis_client, None, queue) - def dummy_metadata(*_, **__): - return { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } + expected_label = random.randint(1, 9) - image = _get_image(model_shape[1] * 2, model_shape[2] * 2) + mock_app = Bunch( + predict=lambda *x, **y: expected_label, + model=Bunch(get_batch_size=lambda *x: 1)) - def predict(*_, **__): - data = np.zeros((3,)) - i = np.random.randint(3) - data[i] = 1 - return data + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) - mocker.patch.object(consumer, 'predict', predict) - mocker.patch.object(consumer, 'get_model_metadata', dummy_metadata) - mocker.patch.object(settings, 'LABEL_DETECT_MODEL', 'dummymodel:1') + image = _get_image(shape[1] * 2, shape[2] * 2, shape[3]) mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', False) label = consumer.detect_label(image) - assert label is None + assert label == 0 mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', True) label = consumer.detect_label(image) - assert label in set(list(range(4))) - - def test_detect_scale(self, mocker, redis_client): - # pylint: disable=W0613 - # TODO: test rescale is < 1% of the original - model_shape = (1, 216, 216, 1) - consumer = consumers.ImageFileConsumer(redis_client, None, 'q') + assert label == expected_label - def dummy_metadata(*_, **__): - return { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } - - big_size = model_shape[1] * 
np.random.randint(2, 9) - image = _get_image(big_size, big_size) + def test_get_image_label(self, mocker, redis_client): + queue = 'q' + stg = DummyStorage() + consumer = consumers.ImageFileConsumer(redis_client, stg, queue) + image = _get_image(256, 256, 1) + # test no label provided expected = 1 - - def predict(diff=1e-8): - def _predict(*_, **__): - sign = -1 if np.random.randint(1, 5) > 2 else 1 - return expected + sign * diff - return _predict - - mocker.patch.object(consumer, 'get_model_metadata', dummy_metadata) - mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', False) - mocker.patch.object(settings, 'SCALE_DETECT_MODEL', 'dummymodel:1') - scale = consumer.detect_scale(image) - assert scale == 1 - - mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) - mocker.patch.object(consumer, 'predict', predict(1e-8)) - scale = consumer.detect_scale(image) - # very small changes within error range: - assert scale == 1 - - mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) - mocker.patch.object(consumer, 'predict', predict(1e-1)) - scale = consumer.detect_scale(image) - assert isinstance(scale, float) - np.testing.assert_almost_equal(scale, expected, 1e-1) - - def test__consume(self, mocker, redis_client): - # pylint: disable=W0613 - prefix = 'predict' - status = 'new' + mocker.patch.object(consumer, 'detect_label', lambda *x: expected) + label = consumer.get_image_label(None, image, 'some hash') + assert label == expected + + # test label provided + expected = 2 + label = consumer.get_image_label(expected, image, 'some hash') + assert label == expected + + # test label provided is invalid + with pytest.raises(ValueError): + label = -1 + consumer.get_image_label(label, image, 'some hash') + + # test label provided is bad type + with pytest.raises(ValueError): + label = 'badval' + consumer.get_image_label(label, image, 'some hash') + + def test__consume_finished_status(self, redis_client): + queue = 'q' storage = DummyStorage() - consumer = consumers.ImageFileConsumer(redis_client, storage, prefix) - - def grpc_image(data, *args, **kwargs): - return data - - def grpc_image_multi(data, *args, **kwargs): - return np.array(tuple(list(data.shape) + [2])) - - def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - return [data, data] - - mocker.patch.object(consumer, 'detect_label', lambda x: 1) - mocker.patch.object(consumer, 'detect_scale', lambda x: 1) - mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', True) - - grpc_funcs = (grpc_image, grpc_image_list) - model_shapes = [ - (-1, 600, 600, 1), # image too small, pad - (-1, 300, 300, 1), # image is exactly the right size - (-1, 150, 150, 1), # image too big, tile - (-1, 150, 600, 1), # image has one size too small, one size too big - (-1, 600, 150, 1), # image has one size too small, one size too big - ] + consumer = consumers.ImageFileConsumer(redis_client, storage, queue) empty_data = {'input_file_name': 'file.tiff'} - full_data = { - 'input_file_name': 'file.tiff', - 'model_version': '0', - 'model_name': 'model', - 'label': '1', - 'scale': '1', - } - label_no_model_data = full_data.copy() - label_no_model_data['model_name'] = '' - - datasets = [empty_data, full_data, label_no_model_data] test_hash = 0 # test finished statuses are returned @@ -171,17 +118,35 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 assert result == status test_hash += 1 - prod = itertools.product(model_shapes, grpc_funcs, datasets) + def test__consume(self, mocker, redis_client): + # pylint: disable=W0613 + 
queue = 'predict' + storage = DummyStorage() - for model_shape, grpc_func, data in prod: - metadata = make_model_metadata_of_size(model_shape) - mocker.patch.object(consumer, 'grpc_image', grpc_func) - mocker.patch.object(consumer, 'get_model_metadata', metadata) - mocker.patch.object(consumer, 'process', lambda *x: x[0]) + consumer = consumers.ImageFileConsumer(redis_client, storage, queue) - redis_client.hmset(test_hash, data) - result = consumer._consume(test_hash) - assert result == consumer.final_status - result = redis_client.hget(test_hash, 'status') - assert result == consumer.final_status - test_hash += 1 + empty_data = {'input_file_name': 'file.tiff'} + + output_shape = (1, 32, 32, 1) + + mock_app = Bunch( + predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape), + model_mpp=1, + model=Bunch( + get_batch_size=lambda *x: 1, + input_shape=(1, 32, 32, 1) + ) + ) + + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + mocker.patch.object(consumer, 'get_image_scale', lambda *x: 1) + mocker.patch.object(consumer, 'get_image_label', lambda *x: 1) + mocker.patch.object(consumer, 'validate_model_input', lambda *x: True) + + test_hash = 'some hash' + + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status diff --git a/redis_consumer/consumers/multiplex_consumer_test.py b/redis_consumer/consumers/multiplex_consumer_test.py index ee810616..3bc5a3ee 100644 --- a/redis_consumer/consumers/multiplex_consumer_test.py +++ b/redis_consumer/consumers/multiplex_consumer_test.py @@ -28,66 +28,32 @@ from __future__ import division from __future__ import print_function -import itertools - import numpy as np import pytest from redis_consumer import consumers -from redis_consumer.testing_utils import redis_client, DummyStorage -from redis_consumer.testing_utils import make_model_metadata_of_size +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client class TestMultiplexConsumer(object): - # pylint: disable=R0201 - - def test__consume(self, mocker, redis_client): - # pylint: disable=W0613 - - def make_grpc_image(model_shape=(-1, 256, 256, 2)): - # pylint: disable=E1101 - shape = model_shape[1:-1] - - def grpc(data, *args, **kwargs): - inner_shape = tuple([1] + list(shape) + [1]) - feature_shape = tuple([1] + list(shape) + [3]) - - inner = np.random.random(inner_shape) - feature = np.random.random(feature_shape) - - inner2 = np.random.random(inner_shape) - feature2 = np.random.random(feature_shape) - return [inner, feature, inner2, feature2] - - return grpc - - image_shapes = [ - (2, 300, 300), # channels first - (300, 300, 2), # channels last - ] + # pylint: disable=R0201,W0621 - model_shapes = [ - (-1, 600, 600, 2), # image too small, pad - (-1, 300, 300, 2), # image is exactly the right size - (-1, 150, 150, 2), # image too big, tile - (-1, 150, 600, 2), # image has one size too small, one size too big - (-1, 600, 150, 2), # image has one size too small, one size too big - ] + def test__consume_finished_status(self, redis_client): + queue = 'q' + storage = DummyStorage() - scales = ['.9', '1.1', ''] + consumer = consumers.MultiplexConsumer(redis_client, storage, queue) - job_data = { - 'input_file_name': 'file.tiff', - } - - consumer = consumers.MultiplexConsumer(redis_client, DummyStorage(), 'multiplex') + empty_data 
= {'input_file_name': 'file.tiff'} test_hash = 0 # test finished statuses are returned for status in (consumer.failed_status, consumer.final_status): test_hash += 1 - data = job_data.copy() + data = empty_data.copy() data['status'] = status redis_client.hmset(test_hash, data) result = consumer._consume(test_hash) @@ -96,50 +62,34 @@ def grpc(data, *args, **kwargs): assert result == status test_hash += 1 - prod = itertools.product(model_shapes, scales, image_shapes) + def test__consume(self, mocker, redis_client): + # pylint: disable=W0613 + queue = 'multiplex' + storage = DummyStorage() - for model_shape, scale, image_shape in prod: - mocker.patch('redis_consumer.utils.get_image', - lambda x: np.random.random(list(image_shape) + [1])) + consumer = consumers.MultiplexConsumer(redis_client, storage, queue) - metadata = make_model_metadata_of_size(model_shape) - grpc_image = make_grpc_image(model_shape) - mocker.patch.object(consumer, 'get_model_metadata', metadata) - mocker.patch.object(consumer, 'grpc_image', grpc_image) - mocker.patch.object(consumer, 'postprocess', - lambda *x: np.random.randint(0, 5, size=(300, 300, 1))) + empty_data = {'input_file_name': 'file.tiff'} - data = job_data.copy() - data['scale'] = scale + output_shape = (1, 256, 256, 2) - redis_client.hmset(test_hash, data) - result = consumer._consume(test_hash) - assert result == consumer.final_status - result = redis_client.hget(test_hash, 'status') - assert result == consumer.final_status - test_hash += 1 + mock_app = Bunch( + predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape), + model_mpp=1, + model=Bunch( + get_batch_size=lambda *x: 1, + input_shape=(1, 32, 32, 1) + ) + ) - model_shape = (-1, 150, 150, 2) - invalid_image_shapes = [ - (150, 150), - (150,), - (150, 150, 1), - (1, 150, 150), - (3, 150, 150), - (1, 1, 150, 150) - ] - - for image_shape in invalid_image_shapes: - mocker.patch('redis_consumer.utils.get_image', - lambda x: np.random.random(list(image_shape) + [1])) - metadata = make_model_metadata_of_size(model_shape) - grpc_image = make_grpc_image(model_shape) - mocker.patch.object(consumer, 'get_model_metadata', metadata) - mocker.patch.object(consumer, 'grpc_image', grpc_image) - - data = job_data.copy() - data['scale'] = '1' + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + mocker.patch.object(consumer, 'get_image_scale', lambda *x: 1) + mocker.patch.object(consumer, 'validate_model_input', lambda *x: x[0]) - redis_client.hmset(test_hash, data) - with pytest.raises(ValueError, match='Invalid image shape'): - _ = consumer._consume(test_hash) + test_hash = 'some hash' + + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status From ccdb96e029e880f1230b3d1883fe15cb8b400e43 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:38:02 -0800 Subject: [PATCH 25/73] remove _pick_model as it is simply a dict lookup --- redis_consumer/utils.py | 9 --------- redis_consumer/utils_test.py | 11 ----------- 2 files changed, 20 deletions(-) diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index bde7de46..03eda5dc 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -227,12 +227,3 @@ def zip_files(files, dest=None, prefix=None): logger.debug('Zipped %s files into %s in %s seconds.', len(files), filepath, timeit.default_timer() - start) return 
filepath - - -def _pick_model(label): - model = settings.MODEL_CHOICES.get(label) - if model is None: - logger.error('Label type %s is not supported', label) - raise ValueError('Label type {} is not supported'.format(label)) - - return model.split(':') diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index 2f0cdb3c..f7a46861 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -221,14 +221,3 @@ def test_zip_files(tmpdir): with pytest.raises(Exception): bad_dest = os.path.join(tmpdir, 'does', 'not', 'exist') zip_path = utils.zip_files(paths, bad_dest, prefix) - - -def test__pick_model(mocker): - mocker.patch.object(settings, 'MODEL_CHOICES', {0: 'dummymodel:0'}) - res = utils._pick_model(0) - assert len(res) == 2 - assert res[0] == 'dummymodel' - assert res[1] == '0' - - with pytest.raises(ValueError): - utils._pick_model(-1) From 9ed261e7e57eb927b26960d30ff65962c946a4f9 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:38:18 -0800 Subject: [PATCH 26/73] PEP8 --- redis_consumer/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index 03eda5dc..7495e3a1 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -42,8 +42,6 @@ from skimage.external import tifffile from tensorflow.keras.preprocessing.image import img_to_array -from redis_consumer import settings - logger = logging.getLogger('redis_consumer.utils') From d1d4b4a8fbc13113fac12f94cec69c038aa01544 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:44:49 -0800 Subject: [PATCH 27/73] no need to account for python2. --- redis_consumer/consumers/base_consumer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 1a4d8e0c..207a21a0 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -608,13 +608,7 @@ def _parse_failures(self, failed_children): len(failed_hashes), json.dumps(failed_hashes, indent=4)) - # check python2 vs python3 - if hasattr(urllib, 'parse'): - url_encode = urllib.parse.urlencode # pylint: disable=E1101 - else: - url_encode = urllib.urlencode # pylint: disable=E1101 - - return url_encode(failed_hashes) + return urllib.parse.urlencode(failed_hashes) def _cleanup(self, redis_hash, children, done, failed): start = timeit.default_timer() From e09cc755412b9e67b98e347c4d7463faed29e49b Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:45:15 -0800 Subject: [PATCH 28/73] Drop support for python2 --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b6718075..7b398fb8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,7 +12,7 @@ jobs: strategy: matrix: - python-version: [2.7, 3.5, 3.6, 3.7, 3.8] + python-version: [3.5, 3.6, 3.7, 3.8] steps: - uses: actions/checkout@v2 From 7fec70384937f3a1e11da4c38e6f691c5df0a5c8 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:45:39 -0800 Subject: [PATCH 29/73] Add channel dimension to _write_image test helper. 
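The helper now builds a (height, width, 1) array via _get_image and keys the
TIFF branch off the parsed extension instead of endswith. A small sketch of
the extension check (standard library only; the filename is a placeholder;
note that splitext keeps the leading dot):

    import os

    _, ext = os.path.splitext('cells.TIFF'.lower())
    assert ext == '.tiff'
    assert ext in {'.tif', '.tiff'}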
--- redis_consumer/utils_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index f7a46861..1fc964f5 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -42,12 +42,12 @@ from redis_consumer.testing_utils import _get_image from redis_consumer import utils -from redis_consumer import settings def _write_image(filepath, img_w=300, img_h=300): - imarray = _get_image(img_h, img_w) - if filepath.lower().endswith('tif') or filepath.lower().endswith('tiff'): + imarray = _get_image(img_h, img_w, 1) + _, ext = os.path.splitext(filepath.lower()) + if ext in {'.tif', '.tiff'}: tiff.imsave(filepath, imarray[..., 0]) else: img = array_to_img(imarray, scale=False, data_format='channels_last') From b1c341cf95bb1a0679ae92141071e2e259fef502 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 14:46:08 -0800 Subject: [PATCH 30/73] Add stub for testing scale detection with Multiplex consumer. --- .../consumers/multiplex_consumer_test.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/redis_consumer/consumers/multiplex_consumer_test.py b/redis_consumer/consumers/multiplex_consumer_test.py index 3bc5a3ee..f7d082e5 100644 --- a/redis_consumer/consumers/multiplex_consumer_test.py +++ b/redis_consumer/consumers/multiplex_consumer_test.py @@ -33,6 +33,8 @@ import pytest from redis_consumer import consumers +from redis_consumer import settings +from redis_consumer.testing_utils import _get_image from redis_consumer.testing_utils import Bunch from redis_consumer.testing_utils import DummyStorage from redis_consumer.testing_utils import redis_client @@ -41,6 +43,31 @@ class TestMultiplexConsumer(object): # pylint: disable=R0201,W0621 + def test_detect_scale(self, mocker, redis_client): + # pylint: disable=W0613 + shape = (1, 256, 256, 1) + consumer = consumers.MultiplexConsumer(redis_client, None, 'q') + + image = _get_image(shape[1] * 2, shape[2] * 2, shape[3]) + + expected_scale = 1 # random.uniform(0.5, 1.5) + # model_mpp = random.uniform(0.5, 1.5) + + mock_app = Bunch( + predict=lambda *x, **y: expected_scale, + # model_mpp=model_mpp, + model=Bunch(get_batch_size=lambda *x: 1)) + + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', False) + scale = consumer.detect_scale(image) + assert scale == 1 # model_mpp + + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) + scale = consumer.detect_scale(image) + assert scale == expected_scale # * model_mpp + def test__consume_finished_status(self, redis_client): queue = 'q' storage = DummyStorage() From 943081823ba5077ccc50d6dfa59df42df9c1649c Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 16:27:20 -0800 Subject: [PATCH 31/73] Cast detected label as int() --- redis_consumer/consumers/image_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index ee858cbd..1478535b 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -66,7 +66,7 @@ def detect_label(self, image): self.logger.debug('Label %s detected in %s seconds', detected_label, timeit.default_timer() - start) - return detected_label + return int(detected_label) def get_image_label(self, label, image, redis_hash): 
"""Calculate label of image.""" From 3543f875388f27adf8266c2a753c2fc4e6c833b3 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 17:47:15 -0800 Subject: [PATCH 32/73] Add .github folder to .gitignore. --- .dockerignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.dockerignore b/.dockerignore index 3d12409d..b8fb0c1d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,6 +5,7 @@ logs/ protos/ docs/ build/ +.github/ # Byte-compiled / optimized / DLL files __pycache__/ From bd09598b19b30dd6276a61b726bbb6f7d1afc300 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 18:05:31 -0800 Subject: [PATCH 33/73] Change base image to slim-buster tensorflow unfortunately is not compatible with alpine through pip. --- Dockerfile | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6ed5b706..6662de8e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,15 +23,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -FROM python:3.6 +FROM python:3.6-slim-buster WORKDIR /usr/src/app -COPY requirements.txt requirements-no-deps.txt ./ +RUN apt-get update && apt-get install -y \ + build-essential && \ + rm -rf /var/lib/apt/lists/* -RUN pip install --no-cache-dir -r requirements.txt +COPY requirements.txt requirements-no-deps.txt ./ -RUN pip install --no-cache-dir --no-deps -r requirements-no-deps.txt +RUN pip install --no-cache-dir -r requirements.txt && \ + pip install --no-cache-dir --no-deps -r requirements-no-deps.txt COPY . . From 62751f7714659f491dd876cef3c3c935688e002d Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 18:21:34 -0800 Subject: [PATCH 34/73] Use old-style format strings instead of f"". --- redis_consumer/consumers/base_consumer.py | 5 +++-- redis_consumer/grpc_clients.py | 4 +++- redis_consumer/grpc_clients_test.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 207a21a0..3c7b3e12 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -261,8 +261,9 @@ def validate_model_input(self, image, model_name, model_version): # cast as image to match with the list of shapes. image = [image] if not isinstance(image, list) else image - errtext = (f'Invalid image shape: {[s.shape for s in image]}. ' - f'The {self.queue} job expects images of shape {shapes}') + errtext = ('Invalid image shape: {}. 
The {} job expects ' + 'images of shape{}').format( + [s.shape for s in image], self.queue, shapes) if len(image) != len(shapes): raise ValueError(errtext) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index ff25c726..d7adcc59 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -169,7 +169,9 @@ class PredictClient(GrpcClient): def __init__(self, host, model_name, model_version): super(PredictClient, self).__init__(host) - self.logger = logging.getLogger(f'{model_name}:{model_version}:gRPC') + self.logger = logging.getLogger('{}:{}:gRPC'.format( + model_name, model_version + )) self.model_name = model_name self.model_version = model_version diff --git a/redis_consumer/grpc_clients_test.py b/redis_consumer/grpc_clients_test.py index 755a82d8..d5e6a078 100644 --- a/redis_consumer/grpc_clients_test.py +++ b/redis_consumer/grpc_clients_test.py @@ -50,7 +50,7 @@ def __init__(self, host, model_name, model_version): def predict(self, request_data, request_timeout=10): retval = {} for i, d in enumerate(request_data): - retval[f'prediction{i}'] = d.get('data') + retval['prediction{}'.format(i)] = d.get('data') return retval From 5786e86c429fce39aceed7f6589823705cd5b5b4 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 19:48:55 -0800 Subject: [PATCH 35/73] Drop support for python 3.5 --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7b398fb8..4c69c164 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,7 +12,7 @@ jobs: strategy: matrix: - python-version: [3.5, 3.6, 3.7, 3.8] + python-version: [3.6, 3.7, 3.8] steps: - uses: actions/checkout@v2 From 27b47322c378d98d4a10824fd14278192395ee29 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 20:00:30 -0800 Subject: [PATCH 36/73] Add python 3.5 back and pin tracking to 0.2.7 --- .github/workflows/tests.yaml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4c69c164..7b398fb8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,7 +12,7 @@ jobs: strategy: matrix: - python-version: [3.6, 3.7, 3.8] + python-version: [3.5, 3.6, 3.7, 3.8] steps: - uses: actions/checkout@v2 diff --git a/requirements.txt b/requirements.txt index 38c5429a..afca247b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # deepcell packages deepcell-cpu>=0.8.3 -deepcell-tracking>=0.2.6 +deepcell-tracking==0.2.7 deepcell-toolbox>=0.8.2 tensorflow-cpu scikit-image>=0.14.0,<0.17.0 From b389502277c811ba206739fcad105c1577088ebd Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 21 Jan 2021 20:09:01 -0800 Subject: [PATCH 37/73] fix randint maxvalue --- redis_consumer/consumers/base_consumer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index dd06189b..59af2d95 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -297,7 +297,7 @@ def test_validate_model_input(self, mocker, redis_client): with pytest.raises(ValueError): image = [np.ones(s) for s in valid_input_shapes[:count]] # set a 
random entry to be invalid - i = random.randint(0, count) + i = random.randint(0, count - 1) image[i] = np.ones(random.choice(invalid_input_shapes)) consumer.validate_model_input(image, 'model', '1') From f50ed07ebdef8058914f398c15bf4debf3d91368 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Fri, 22 Jan 2021 10:11:54 -0800 Subject: [PATCH 38/73] Refactor ImageFileConsumer to SegmentationConsumer. --- redis_consumer/consumers/__init__.py | 5 +++-- .../{image_consumer.py => segmentation_consumer.py} | 7 +++---- ..._consumer_test.py => segmentation_consumer_test.py} | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) rename redis_consumer/consumers/{image_consumer.py => segmentation_consumer.py} (96%) rename redis_consumer/consumers/{image_consumer_test.py => segmentation_consumer_test.py} (93%) diff --git a/redis_consumer/consumers/__init__.py b/redis_consumer/consumers/__init__.py index 82d80aa4..ef981644 100644 --- a/redis_consumer/consumers/__init__.py +++ b/redis_consumer/consumers/__init__.py @@ -33,14 +33,15 @@ from redis_consumer.consumers.base_consumer import ZipFileConsumer # Custom Workflow consumers -from redis_consumer.consumers.image_consumer import ImageFileConsumer +from redis_consumer.consumers.segmentation_consumer import SegmentationConsumer from redis_consumer.consumers.tracking_consumer import TrackingConsumer from redis_consumer.consumers.multiplex_consumer import MultiplexConsumer # TODO: Import future custom Consumer classes. CONSUMERS = { - 'image': ImageFileConsumer, + 'image': SegmentationConsumer, # deprecated, use "segmentation" instead. + 'segmentation': SegmentationConsumer, 'zip': ZipFileConsumer, 'tracking': TrackingConsumer, 'multiplex': MultiplexConsumer, diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/segmentation_consumer.py similarity index 96% rename from redis_consumer/consumers/image_consumer.py rename to redis_consumer/consumers/segmentation_consumer.py index 1478535b..c17af28e 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -23,7 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -"""ImageFileConsumer class for consuming image segmentation jobs.""" +"""SegmentationConsumer class for consuming image segmentation jobs.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -32,14 +32,13 @@ import numpy as np -from deepcell.applications import LabelDetection, NuclearSegmentation +from deepcell.applications import LabelDetection from redis_consumer.consumers import TensorFlowServingConsumer -from redis_consumer import utils from redis_consumer import settings -class ImageFileConsumer(TensorFlowServingConsumer): +class SegmentationConsumer(TensorFlowServingConsumer): """Consumes image files and uploads the results""" def detect_label(self, image): diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/segmentation_consumer_test.py similarity index 93% rename from redis_consumer/consumers/image_consumer_test.py rename to redis_consumer/consumers/segmentation_consumer_test.py index 5d460dcb..cacba189 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/segmentation_consumer_test.py @@ -43,14 +43,14 @@ from redis_consumer.testing_utils import _get_image -class TestImageFileConsumer(object): +class TestSegmentationConsumer(object): # pylint: disable=R0201,W0621 def test_detect_label(self, mocker, redis_client): # pylint: disable=W0613 shape = (1, 256, 256, 1) queue = 'q' - consumer = consumers.ImageFileConsumer(redis_client, None, queue) + consumer = consumers.SegmentationConsumer(redis_client, None, queue) expected_label = random.randint(1, 9) @@ -73,7 +73,7 @@ def test_detect_label(self, mocker, redis_client): def test_get_image_label(self, mocker, redis_client): queue = 'q' stg = DummyStorage() - consumer = consumers.ImageFileConsumer(redis_client, stg, queue) + consumer = consumers.SegmentationConsumer(redis_client, stg, queue) image = _get_image(256, 256, 1) # test no label provided @@ -101,7 +101,7 @@ def test__consume_finished_status(self, redis_client): queue = 'q' storage = DummyStorage() - consumer = consumers.ImageFileConsumer(redis_client, storage, queue) + consumer = consumers.SegmentationConsumer(redis_client, storage, queue) empty_data = {'input_file_name': 'file.tiff'} @@ -123,7 +123,7 @@ def test__consume(self, mocker, redis_client): queue = 'predict' storage = DummyStorage() - consumer = consumers.ImageFileConsumer(redis_client, storage, queue) + consumer = consumers.SegmentationConsumer(redis_client, storage, queue) empty_data = {'input_file_name': 'file.tiff'} From 957887ca1a7bf6f073314c23d20b097390f39099 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 14:06:28 -0800 Subject: [PATCH 39/73] Update minimum numpy and skimage versions to latest py27 support. 
--- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index afca247b..da2aaba4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ deepcell-cpu>=0.8.3 deepcell-tracking==0.2.7 deepcell-toolbox>=0.8.2 tensorflow-cpu -scikit-image>=0.14.0,<0.17.0 -numpy>=1.16.4 +scikit-image>=0.14.5,<0.17.0 +numpy>=1.16.6 # tensorflow-serving-apis and gRPC dependencies grpcio>=1.0,<2 From c5a3e84e12219db81c7f0fe43a6b6d0b955b8b75 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 14:06:55 -0800 Subject: [PATCH 40/73] Loosen boto3, gcs, and decouple requirements to current major version. --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index da2aaba4..0cb29092 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,8 +12,8 @@ dict-to-protobuf==0.0.3.9 protobuf>=3.6.0 # misc storage and redis clients -boto3==1.9.195 -google-cloud-storage>=1.16.1 -python-decouple==3.1 +boto3>=1.9.195,<2 +google-cloud-storage>=1.16.1,<2 +python-decouple>=3.1,<4 redis==3.4.1 pytz==2019.1 From b8ec8dd1be693ad091f2ef2874c27307f67b8fb5 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 15:42:19 -0800 Subject: [PATCH 41/73] Bump toolbox version to 0.8.4 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0cb29092..f8e74f8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # deepcell packages deepcell-cpu>=0.8.3 deepcell-tracking==0.2.7 -deepcell-toolbox>=0.8.2 +deepcell-toolbox>=0.8.4 tensorflow-cpu scikit-image>=0.14.5,<0.17.0 numpy>=1.16.6 From 182a5c02b852eafebb3efe453689673321bfbdef Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 16:01:13 -0800 Subject: [PATCH 42/73] missing library libglib2.0-0 required for opencv --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 6662de8e..9f81941c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,7 @@ FROM python:3.6-slim-buster WORKDIR /usr/src/app RUN apt-get update && apt-get install -y \ - build-essential && \ + build-essential libglib2.0-0 && \ rm -rf /var/lib/apt/lists/* COPY requirements.txt requirements-no-deps.txt ./ From d7961980cc5509331d67bdcc7a6284956f5ac8b4 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 20:08:33 -0800 Subject: [PATCH 43/73] Clean up gRPC client logging. 
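The client logger name also flips from '{model}:{version}:gRPC' to 'gRPC:{model}:{version}', so every serving client shares a common 'gRPC' prefix that can be filtered as a group. A minimal sketch of the resulting naming scheme; the model name, version, and host below are placeholder values, not part of this patch:

```python
import logging

# Placeholder model name/version; PredictClient.__init__ builds this string.
logger = logging.getLogger('gRPC:{}:{}'.format('NuclearSegmentation', 2))
logger.info('Sending PredictRequest to %s.', 'tf-serving:8500')
```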
--- redis_consumer/grpc_clients.py | 36 +++++++--------------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index d7adcc59..6a44ad97 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -169,7 +169,7 @@ class PredictClient(GrpcClient): def __init__(self, host, model_name, model_version): super(PredictClient, self).__init__(host) - self.logger = logging.getLogger('{}:{}:gRPC'.format( + self.logger = logging.getLogger('gRPC:{}:{}'.format( model_name, model_version )) self.model_name = model_name @@ -182,9 +182,7 @@ def __init__(self, host, model_name, model_version): def _retry_grpc(self, request, request_timeout): request_name = request.__class__.__name__ - self.logger.info('Sending %s to %s model %s:%s.', - request_name, self.host, - self.model_name, self.model_version) + self.logger.info('Sending %s to %s.', request_name, self.host) true_failures, count = 0, 0 @@ -201,8 +199,9 @@ def _retry_grpc(self, request, request_timeout): api_call = getattr(stub, api_endpoint_name) response = api_call(request, timeout=request_timeout) - self.logger.debug('%s finished in %s seconds.', - request_name, timeit.default_timer() - t) + self.logger.debug('%s finished in %s seconds (%s retries).', + request_name, timeit.default_timer() - t, + true_failures) return response except grpc.RpcError as err: @@ -241,38 +240,24 @@ def _retry_grpc(self, request, request_timeout): raise err def predict(self, request_data, request_timeout=10): - self.logger.info('Sending PredictRequest to %s model %s:%s.', - self.host, self.model_name, self.model_version) - - t = timeit.default_timer() - request = PredictRequest() - self.logger.debug('Created PredictRequest object in %s seconds.', - timeit.default_timer() - t) - # pylint: disable=E1101 + request = PredictRequest() request.model_spec.name = self.model_name request.model_spec.version.value = self.model_version - t = timeit.default_timer() for d in request_data: tensor_proto = make_tensor_proto(d['data'], d['in_tensor_dtype']) request.inputs[d['in_tensor_name']].CopyFrom(tensor_proto) - self.logger.debug('Made tensor protos in %s seconds.', - timeit.default_timer() - t) - response = self._retry_grpc(request, request_timeout) response_dict = grpc_response_to_dict(response) - self.logger.info('Got PredictResponse with keys: %s ', + self.logger.info('Got PredictResponse with keys: %s.', list(response_dict)) return response_dict def get_model_metadata(self, request_timeout=10): - self.logger.info('Sending GetModelMetadataRequest to %s model %s:%s.', - self.host, self.model_name, self.model_version) - # pylint: disable=E1101 request = GetModelMetadataRequest() request.metadata_field.append('signature_def') @@ -280,14 +265,7 @@ def get_model_metadata(self, request_timeout=10): request.model_spec.version.value = self.model_version response = self._retry_grpc(request, request_timeout) - - t = timeit.default_timer() - response_dict = json.loads(MessageToJson(response)) - - self.logger.debug('gRPC GetModelMetadataProtobufConversion took ' - '%s seconds.', timeit.default_timer() - t) - return response_dict From 212827179ee097c6515ad3e9db47b1f6603af118 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 20:09:02 -0800 Subject: [PATCH 44/73] Update log level to info to match begin log statement. 
--- redis_consumer/consumers/base_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 98c9af40..4288bfe0 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -355,7 +355,7 @@ def get_model_metadata(self, model_name, model_version): inputs = inputs['serving_default']['inputs'] finished = timeit.default_timer() - start - self.logger.debug('Got model metadata for %s in %s seconds.', + self.logger.info('Got model metadata for %s in %s seconds.', model, finished) self.redis.hset(model, 'metadata', json.dumps(inputs)) From 3b419c31648302633c1c2d42074767ba2dfe58bb Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 20:10:13 -0800 Subject: [PATCH 45/73] Fix image label detection logs. --- redis_consumer/consumers/segmentation_consumer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index c17af28e..b01a906c 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -72,11 +72,11 @@ def get_image_label(self, label, image, redis_hash): if not label: # Detect scale of image (Default to 1) label = self.detect_label(image) - self.logger.debug('Image scale detected: %s', label) + self.logger.debug('Image label detected: %s.', label) self.update_key(redis_hash, {'label': label}) else: label = int(label) - self.logger.debug('Image label already calculated %s', label) + self.logger.debug('Image label already calculated: %s.', label) if label not in settings.APPLICATION_CHOICES: raise ValueError('Label type {} is not supported'.format(label)) From 41cd73342dfaa21e21aee7cd3840bc7eeedeaef7 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 20:11:16 -0800 Subject: [PATCH 46/73] Add summary logs after app.predict --- redis_consumer/consumers/segmentation_consumer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index b01a906c..62c52f59 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -136,9 +136,21 @@ def _consume(self, redis_hash): app = self.get_grpc_app(model, app_cls) + t = timeit.default_timer() + results = app.predict(image, image_mpp=scale * app.model_mpp, batch_size=app.model.get_batch_size()) + # log output + pre = app.preprocessing_fn + post = app.postprocessing_fn + self.logger.info('%s finished with %s pre-processing ' + 'and %s post-processing in %s seconds.', + app.__class__.__name__, + pre.__name__ if pre is not None else 'None', + post.__name__ if post is not None else 'None', + timeit.default_timer() - t) + # Save the post-processed results to a file _ = timeit.default_timer() self.update_key(redis_hash, {'status': 'saving-results'}) From c9b826e6a9d29c07705815b5cb41b85d32314427 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 22:02:35 -0800 Subject: [PATCH 47/73] PEP8 --- redis_consumer/consumers/base_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 
4288bfe0..a9cc8e5f 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -356,7 +356,7 @@ def get_model_metadata(self, model_name, model_version): finished = timeit.default_timer() - start self.logger.info('Got model metadata for %s in %s seconds.', - model, finished) + model, finished) self.redis.hset(model, 'metadata', json.dumps(inputs)) self.redis.expire(model, settings.METADATA_EXPIRE_TIME) From 4ffa9edac3a21881f919e8cd944b9db4c5296331 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 22:07:09 -0800 Subject: [PATCH 48/73] Undo logging due to test errors. --- redis_consumer/consumers/segmentation_consumer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index 62c52f59..b01a906c 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -136,21 +136,9 @@ def _consume(self, redis_hash): app = self.get_grpc_app(model, app_cls) - t = timeit.default_timer() - results = app.predict(image, image_mpp=scale * app.model_mpp, batch_size=app.model.get_batch_size()) - # log output - pre = app.preprocessing_fn - post = app.postprocessing_fn - self.logger.info('%s finished with %s pre-processing ' - 'and %s post-processing in %s seconds.', - app.__class__.__name__, - pre.__name__ if pre is not None else 'None', - post.__name__ if post is not None else 'None', - timeit.default_timer() - t) - # Save the post-processed results to a file _ = timeit.default_timer() self.update_key(redis_hash, {'status': 'saving-results'}) From a09e7efef43d1356a7e3f02deb32aa9f8addacf6 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 22:23:19 -0800 Subject: [PATCH 49/73] Clean up logging in zip_files. 
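The two separate debug messages collapse into a single entry carrying both the file count and the elapsed time. A rough, self-contained sketch of the consolidated pattern, with a stand-in logger and file list:

```python
import logging
import timeit
import zipfile

logger = logging.getLogger('redis_consumer.utils')

def zip_files_sketch(files, filepath):
    # time the whole write and report it in one debug line
    start = timeit.default_timer()
    with zipfile.ZipFile(filepath, 'w', allowZip64=True) as zf:
        for f in files:
            zf.write(f)
    logger.debug('Saved %s files to %s in %s seconds.',
                 len(files), filepath, timeit.default_timer() - start)
```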
--- redis_consumer/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index 7495e3a1..2cfd2366 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -218,10 +218,9 @@ def zip_files(files, dest=None, prefix=None): name = f.replace(dest, '') name = name[1:] if name.startswith(os.path.sep) else name zf.write(f, arcname=name) - logger.debug('Saved %s files to %s', len(files), filepath) + logger.debug('Saved %s files to %s in %s seconds.', + len(files), filepath, timeit.default_timer() - start) except Exception as err: logger.error('Failed to write zipfile: %s', err) raise err - logger.debug('Zipped %s files into %s in %s seconds.', - len(files), filepath, timeit.default_timer() - start) return filepath From d39ce23b0cad9477de5521764b9e3cb96cfd7652 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 22:24:39 -0800 Subject: [PATCH 50/73] No need to enumerate --- redis_consumer/consumers/base_consumer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index a9cc8e5f..c2fc24d4 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -430,9 +430,9 @@ def save_output(self, image, save_name): image = [image] outpaths = [] - for i, im in enumerate(image): + for img in image: outpaths.extend(utils.save_numpy_array( - im, + img, name=str(name), subdir=subdir, output_dir=tempdir)) From 5cb78f94d9a39f19d90a81357ae6808de087841a Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 22:33:42 -0800 Subject: [PATCH 51/73] Convert to lambda function --- redis_consumer/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index 2cfd2366..bafd476f 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -57,9 +57,7 @@ def iter_image_archive(zip_path, destination): Iterator of all image paths in extracted archive """ archive = zipfile.ZipFile(zip_path, 'r', allowZip64=True) - - def is_valid(x): - return os.path.splitext(x)[1] and '__MACOSX' not in x + is_valid = lambda x: os.path.splitext(x)[1] and '__MACOSX' not in x for info in archive.infolist(): extracted = archive.extract(info, path=destination) if os.path.isfile(extracted): From a208d02b85744e0985b1b5c9bb70f5540023407c Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 23:25:34 -0800 Subject: [PATCH 52/73] Replace DEBUG with LOG_LEVEL --- consume-redis-events.py | 29 +++++++++++++++-------------- redis_consumer/settings.py | 6 +----- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/consume-redis-events.py b/consume-redis-events.py index 73c44b90..e7c63435 100644 --- a/consume-redis-events.py +++ b/consume-redis-events.py @@ -37,32 +37,33 @@ import sys import traceback +import decouple + import redis_consumer from redis_consumer import settings -def initialize_logger(debug_mode=True): +def initialize_logger(log_level='DEBUG'): + log_level = str(log_level).upper() logger = logging.getLogger() logger.setLevel(logging.DEBUG) - formatter = logging.Formatter('[%(asctime)s]:[%(levelname)s]:[%(name)s]: %(message)s') + formatter = logging.Formatter('[%(levelname)s]:[%(name)s]: %(message)s') console = logging.StreamHandler(stream=sys.stdout) 
console.setFormatter(formatter) - fh = logging.handlers.RotatingFileHandler( - filename='redis-consumer.log', - maxBytes=10000000, - backupCount=10) - fh.setFormatter(formatter) - - if debug_mode: - console.setLevel(logging.DEBUG) - else: + if log_level == 'CRITICAL': + console.setLevel(logging.CRITICAL) + elif log_level == 'ERROR': + console.setLevel(logging.ERROR) + elif log_level == 'WARN': + console.setLevel(logging.WARN) + elif log_level == 'INFO': console.setLevel(logging.INFO) - fh.setLevel(logging.DEBUG) + else: + console.setLevel(logging.DEBUG) logger.addHandler(console) - logger.addHandler(fh) def get_consumer(consumer_type, **kwargs): @@ -74,7 +75,7 @@ def get_consumer(consumer_type, **kwargs): if __name__ == '__main__': - initialize_logger(settings.DEBUG) + initialize_logger(decouple.config('LOG_LEVEL', default='DEBUG')) _logger = logging.getLogger(__file__) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 928cd926..99968c18 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -41,9 +41,6 @@ def _strip(x): return '/'.join(y for y in x.split('/') if y) -# Debug Mode -DEBUG = config('DEBUG', cast=bool, default=False) - # Consumer settings INTERVAL = config('INTERVAL', default=10, cast=int) CONSUMER_TYPE = config('CONSUMER_TYPE', default='image') @@ -87,9 +84,8 @@ def _strip(x): ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) DOWNLOAD_DIR = os.path.join(ROOT_DIR, 'download') OUTPUT_DIR = os.path.join(ROOT_DIR, 'output') -LOG_DIR = os.path.join(ROOT_DIR, 'logs') -for d in (DOWNLOAD_DIR, OUTPUT_DIR, LOG_DIR): +for d in (DOWNLOAD_DIR, OUTPUT_DIR): try: os.mkdir(d) except OSError: From b25cdd8757125caadd4c70c488878747cca845a5 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 23:26:07 -0800 Subject: [PATCH 53/73] Move GRPC_RETRY_STATUSES into class attribute. 
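Moving the retryable codes onto the client (shown in the diff below) means a specialized client can tune them per instance without touching settings. A minimal sketch, assuming a hypothetical subclass name:

```python
import grpc

from redis_consumer.grpc_clients import PredictClient

class PatientPredictClient(PredictClient):
    """Hypothetical client that also retries aborted calls."""

    def __init__(self, host, model_name, model_version):
        super(PatientPredictClient, self).__init__(host, model_name, model_version)
        # retry_status_codes is now a plain set on the instance
        self.retry_status_codes.add(grpc.StatusCode.ABORTED)
```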
--- redis_consumer/grpc_clients.py | 9 ++++++++- redis_consumer/settings.py | 8 -------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 6a44ad97..196e136f 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -180,6 +180,13 @@ def __init__(self, host, model_name, model_version): PredictRequest: 'Predict', } + # Retry-able gRPC status codes + self.retry_status_codes = { + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.UNAVAILABLE + } + def _retry_grpc(self, request, request_timeout): request_name = request.__class__.__name__ self.logger.info('Sending %s to %s.', request_name, self.host) @@ -211,7 +218,7 @@ def _retry_grpc(self, request, request_timeout): '%s', request_name, count, err) raise err - if err.code() in settings.GRPC_RETRY_STATUSES: + if err.code() in self.retry_status_codes: count += 1 is_true_failure = err.code() != grpc.StatusCode.UNAVAILABLE true_failures += int(is_true_failure) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 99968c18..556670e9 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -30,7 +30,6 @@ import os -import grpc from decouple import config import deepcell @@ -64,13 +63,6 @@ def _strip(x): GRPC_TIMEOUT = config('GRPC_TIMEOUT', default=30, cast=int) GRPC_BACKOFF = config('GRPC_BACKOFF', default=3, cast=int) -# Retry-able gRPC status codes -GRPC_RETRY_STATUSES = { - grpc.StatusCode.DEADLINE_EXCEEDED, - grpc.StatusCode.RESOURCE_EXHAUSTED, - grpc.StatusCode.UNAVAILABLE -} - # timeout/backoff wait time in seconds REDIS_TIMEOUT = config('REDIS_TIMEOUT', default=3, cast=int) EMPTY_QUEUE_TIMEOUT = config('EMPTY_QUEUE_TIMEOUT', default=5, cast=int) From e7dc7bba09b199f2fd6aa92f94bd28c683fe11b2 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 23:49:15 -0800 Subject: [PATCH 54/73] Remove unused OUTPUT_DIR settings value. 
--- consume-redis-events.py | 1 - redis_consumer/consumers/base_consumer.py | 4 +--- redis_consumer/settings.py | 3 +-- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/consume-redis-events.py b/consume-redis-events.py index e7c63435..fa6845cb 100644 --- a/consume-redis-events.py +++ b/consume-redis-events.py @@ -93,7 +93,6 @@ def get_consumer(consumer_type, **kwargs): 'final_status': 'done', 'failed_status': 'failed', 'name': settings.HOSTNAME, - 'output_dir': settings.OUTPUT_DIR, } _logger.debug('Getting `%s` consumer with args %s.', diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index c2fc24d4..b84c9d7e 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -66,13 +66,11 @@ def __init__(self, queue, final_status='done', failed_status='failed', - name=settings.HOSTNAME, - output_dir=settings.OUTPUT_DIR): + name=settings.HOSTNAME): self.redis = redis_client self.storage = storage_client self.queue = str(queue).lower() self.name = name - self.output_dir = output_dir self.final_status = final_status self.failed_status = failed_status self.finished_statuses = {final_status, failed_status} diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 556670e9..da814657 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -75,9 +75,8 @@ def _strip(x): # Application directories ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) DOWNLOAD_DIR = os.path.join(ROOT_DIR, 'download') -OUTPUT_DIR = os.path.join(ROOT_DIR, 'output') -for d in (DOWNLOAD_DIR, OUTPUT_DIR): +for d in (DOWNLOAD_DIR,): try: os.mkdir(d) except OSError: From d256d6351b5bcc6cc3d07e0b5cc6449499c0c099 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 25 Jan 2021 23:52:07 -0800 Subject: [PATCH 55/73] Replace both bucket env vars and CLOUD_PROVIDER with STORAGE_BUCKET --- .env.example | 11 ++--------- README.md | 3 +-- consume-redis-events.py | 2 +- redis_consumer/settings.py | 8 ++------ redis_consumer/storage.py | 30 ++++++++++++++++-------------- 5 files changed, 22 insertions(+), 32 deletions(-) diff --git a/.env.example b/.env.example index 01185e24..3202d9b1 100644 --- a/.env.example +++ b/.env.example @@ -9,17 +9,10 @@ REDIS_HOST= TF_PORT= TF_HOST= -# Cloud selection -CLOUD_PROVIDER= - -# DEBUG Logging -DEBUG= +# Storage bucket +STORAGE_BUCKET= # AWS Credentials AWS_REGION= -AWS_S3_BUCKET= AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= - -# Google variables -GKE_BUCKET= diff --git a/README.md b/README.md index 753bcac7..0b09dc5c 100644 --- a/README.md +++ b/README.md @@ -89,8 +89,7 @@ The consumer is configured using environment variables. Please find a table of a | :--- | :--- | :--- | | `QUEUE` | **REQUIRED**: The Redis job queue to check for items to consume. | `"predict"` | | `CONSUMER_TYPE` | **REQUIRED**: The type of consumer to run, used in `consume-redis-events.py`. | `"image"` | -| `CLOUD_PROVIDER` | **REQUIRED**: The cloud provider, one of `"aws"` and `"gke"`. | `"gke"` | -| `GCLOUD_STORAGE_BUCKET` | **REQUIRED**: The name of the storage bucket used to download and upload files. | `"default-bucket"` | +| `STORAGE_BUCKET` | **REQUIRED**: The name of the storage bucket used to download and upload files. | `"s3://default-bucket"` | | `INTERVAL` | How frequently the consumer checks the Redis queue for items, in seconds. | `5` | | `REDIS_HOST` | The IP address or hostname of Redis. 
| `"redis-master"` |
| `REDIS_PORT` | The port used to connect to Redis. | `6379` |

diff --git a/consume-redis-events.py b/consume-redis-events.py
index fa6845cb..25d2d347 100644
--- a/consume-redis-events.py
+++ b/consume-redis-events.py
@@ -84,7 +84,7 @@ def get_consumer(consumer_type, **kwargs):
         port=settings.REDIS_PORT,
         backoff=settings.REDIS_TIMEOUT)

-    storage_client = redis_consumer.storage.get_client(settings.CLOUD_PROVIDER)
+    storage_client = redis_consumer.storage.get_client(settings.STORAGE_BUCKET)

     consumer_kwargs = {
         'redis_client': redis,
diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py
index da814657..cde9e8d4 100644
--- a/redis_consumer/settings.py
+++ b/redis_consumer/settings.py
@@ -69,9 +69,6 @@ def _strip(x):
 DO_NOTHING_TIMEOUT = config('DO_NOTHING_TIMEOUT', default=0.5, cast=float)
 STORAGE_MAX_BACKOFF = config('STORAGE_MAX_BACKOFF', default=60, cast=float)

-# Cloud storage
-CLOUD_PROVIDER = config('CLOUD_PROVIDER', cast=str, default='gke').lower()
-
 # Application directories
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 DOWNLOAD_DIR = os.path.join(ROOT_DIR, 'download')
@@ -84,12 +81,11 @@ def _strip(x):

 # AWS Credentials
 AWS_REGION = config('AWS_REGION', default='us-east-1')
-AWS_S3_BUCKET = config('AWS_S3_BUCKET', default='default-bucket')
 AWS_ACCESS_KEY_ID = config('AWS_ACCESS_KEY_ID', default='specify_me')
 AWS_SECRET_ACCESS_KEY = config('AWS_SECRET_ACCESS_KEY', default='specify_me')

-# Google Credentials
-GCLOUD_STORAGE_BUCKET = config('GKE_BUCKET', default='default-bucket')
+# Cloud Storage Bucket
+STORAGE_BUCKET = config('STORAGE_BUCKET', default='s3://default-bucket')

 # Pod Metadata
 HOSTNAME = config('HOSTNAME', default='host-unknown')
diff --git a/redis_consumer/storage.py b/redis_consumer/storage.py
index fe093c78..2e66b384 100644
--- a/redis_consumer/storage.py
+++ b/redis_consumer/storage.py
@@ -51,23 +51,25 @@ class StorageException(Exception):
     pass


-def get_client(cloud_provider):
-    """Returns the Storage Client appropriate for the cloud provider
-    # Arguments:
-        cloud_provider: Indicates which cloud platform (AWS vs GKE)
-    # Returns:
-        storage_client: Client for interacting with the cloud.
+def get_client(bucket):
+    """Get the Storage Client appropriate for the bucket.
+
+    Args:
+        bucket (str): Bucket including the protocol, e.g. ``s3://bucket-name``.
+
+    Returns:
+        ~Storage: Client for interacting with the cloud.
     """
-    cloud_provider = str(cloud_provider).lower()
+    protocol, bucket_name = str(bucket).lower().split('://', 1)
+
     logger = logging.getLogger('storage.get_client')
-    if cloud_provider == 'aws':
-        storage_client = S3Storage(settings.AWS_S3_BUCKET)
-    elif cloud_provider == 'gke':
-        storage_client = GoogleStorage(settings.GCLOUD_STORAGE_BUCKET)
+    if protocol == 's3':
+        storage_client = S3Storage(bucket_name)
+    elif protocol == 'gs':
+        storage_client = GoogleStorage(bucket_name)
     else:
-        errmsg = 'Bad value for CLOUD_PROVIDER: %s'
-        logger.error(errmsg, cloud_provider)
-        raise ValueError(errmsg % cloud_provider)
+        errmsg = 'Unknown STORAGE_BUCKET protocol: %s'
+        logger.error(errmsg, protocol)
+        raise ValueError(errmsg % protocol)

     return storage_client

From 1b8789e3663d04eae527e24bd5e855faa525bd9a Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Tue, 26 Jan 2021 00:04:33 -0800
Subject: [PATCH 56/73] Set batch_size automatically by default.
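With this default, consumers can drop the explicit batch_size argument and let the wrapper derive one from the model metadata. A sketch of the calling pattern; the consumer instance, input image, scale, and model string are placeholders, not values from this patch:

```python
from deepcell.applications import NuclearSegmentation

# `consumer`, `image`, and `scale` are stand-ins for illustration.
app = consumer.get_grpc_app('NuclearSegmentation:2', NuclearSegmentation)

# No explicit batch_size: GrpcModelWrapper.predict falls back to
# get_batch_size(), which scales settings.TF_MAX_BATCH_SIZE by the
# ratio of the model's input size to the tile size.
results = app.predict(image, image_mpp=scale * app.model_mpp)
```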
--- redis_consumer/consumers/multiplex_consumer.py | 3 +-- redis_consumer/consumers/segmentation_consumer.py | 3 +-- redis_consumer/grpc_clients.py | 5 ++++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index 249d04de..2cbd6df8 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -116,8 +116,7 @@ def _consume(self, redis_hash): app = self.get_grpc_app(settings.MULTIPLEX_MODEL, MultiplexSegmentation) - results = app.predict(image, image_mpp=scale * app.model_mpp, - batch_size=settings.TF_MAX_BATCH_SIZE) + results = app.predict(image, image_mpp=scale * app.model_mpp) # Save the post-processed results to a file _ = timeit.default_timer() diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index b01a906c..daa12723 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -136,8 +136,7 @@ def _consume(self, redis_hash): app = self.get_grpc_app(model, app_cls) - results = app.predict(image, image_mpp=scale * app.model_mpp, - batch_size=app.model.get_batch_size()) + results = app.predict(image, image_mpp=scale * app.model_mpp) # Save the post-processed results to a file _ = timeit.default_timer() diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 196e136f..7f31a11d 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -395,9 +395,12 @@ def get_batch_size(self): batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) return batch_size - def predict(self, tiles, batch_size): + def predict(self, tiles, batch_size=None): results = [] + if batch_size is None: + batch_size = self.get_batch_size() + for t in range(0, tiles.shape[0], batch_size): output = self.send_grpc(tiles[t:t + batch_size]) From 3c8357121eae35034cf2e4675ab4492914b80402 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 00:05:12 -0800 Subject: [PATCH 57/73] Update README example with new style. --- README.md | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 0b09dc5c..a3305a69 100644 --- a/README.md +++ b/README.md @@ -37,28 +37,19 @@ def _consume(self, redis_hash): redis_hash, hvals.get('status')) return hvals.get('status') - # the data to process with the model, required. - input_file_name = hvals.get('input_file_name') + # Load input image + fname = hvals.get('input_file_name') + image = self.download_image(fname) # the model can be passed in as an environment variable, # and parsed in settings.py. - model_name, model_version = 'CustomModel:1'.split(':') + model = 'NuclearSegmentation:1' - with tempfile.TemporaryDirectory() as tempdir: - # download the image file - fname = self.storage.download(input_file_name, tempdir) - # load image file as data - image = utils.get_image(fname) + # Use a custom Application from deepcell.applications + app = self.get_grpc_app(model, deepcell.applications.NuclearSegmentation) - # pre- and post-processing can be used with the BaseConsumer.process, - # which uses pre-defined functions in settings.PROCESSING_FUNCTIONS. 
- image = self.preprocess(image, 'normalize') - - # send the data to the model - results = self.predict(image, model_name, model_version) - - # post-process model results - image = self.postprocess(image, 'deep_watershed') + # Run the predictions on the image + results = app.predict(image) # save the results as an image file and upload it to the bucket save_name = hvals.get('original_name', fname) From 7167a423abdd563b9e37a68e34a17594ac37c5a0 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 00:12:27 -0800 Subject: [PATCH 58/73] Move DOWNLOAD_DIR creation into storage client. --- redis_consumer/settings.py | 14 ++++---------- redis_consumer/storage.py | 6 ++++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index cde9e8d4..353b44d4 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -40,6 +40,10 @@ def _strip(x): return '/'.join(y for y in x.split('/') if y) +# Application directories +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +DOWNLOAD_DIR = os.path.join(ROOT_DIR, 'download') + # Consumer settings INTERVAL = config('INTERVAL', default=10, cast=int) CONSUMER_TYPE = config('CONSUMER_TYPE', default='image') @@ -69,16 +73,6 @@ def _strip(x): DO_NOTHING_TIMEOUT = config('DO_NOTHING_TIMEOUT', default=0.5, cast=float) STORAGE_MAX_BACKOFF = config('STORAGE_MAX_BACKOFF', default=60, cast=float) -# Application directories -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -DOWNLOAD_DIR = os.path.join(ROOT_DIR, 'download') - -for d in (DOWNLOAD_DIR,): - try: - os.mkdir(d) - except OSError: - pass - # AWS Credentials AWS_REGION = config('AWS_REGION', default='us-east-1') AWS_ACCESS_KEY_ID = config('AWS_ACCESS_KEY_ID', default='specify_me') diff --git a/redis_consumer/storage.py b/redis_consumer/storage.py index 2e66b384..82d7026f 100644 --- a/redis_consumer/storage.py +++ b/redis_consumer/storage.py @@ -91,6 +91,12 @@ def __init__(self, bucket, self.logger = logging.getLogger(str(self.__class__.__name__)) self.max_backoff = max_backoff + # try to write the download dir in case it does not exist. + try: + os.mkdir(self.download_dir) + except OSError: + pass + def get_backoff(self, attempts): """Get backoff time based on previous number of attempts""" milis = random.randint(1, 1000) / 1000 From f91351dd8e3cd4e52c8af0984e10ec7a24003347 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 00:14:56 -0800 Subject: [PATCH 59/73] Remove _strip unused helper function. 
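The helper is only gone for one commit: the next patch reinstates it as a proper utility, utils.strip_bucket_path, with tests. For reference, its expected behavior as defined there:

```python
from redis_consumer import utils

# leading and trailing separators are dropped; interior structure is kept
assert utils.strip_bucket_path('/path/to/file') == 'path/to/file'
assert utils.strip_bucket_path('path/to/file/') == 'path/to/file'
assert utils.strip_bucket_path('/path/to/file/') == 'path/to/file'
```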
--- redis_consumer/settings.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 353b44d4..97dc1cc5 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -35,11 +35,6 @@ import deepcell -# remove leading/trailing '/'s from cloud bucket folder names -def _strip(x): - return '/'.join(y for y in x.split('/') if y) - - # Application directories ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) DOWNLOAD_DIR = os.path.join(ROOT_DIR, 'download') From 3ccf2284f5a245c47ab35ddc438d9ca0d00cf0b3 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 10:35:21 -0800 Subject: [PATCH 60/73] Move _strip into utils as a real function --- redis_consumer/consumers/base_consumer.py | 5 +++-- redis_consumer/utils.py | 5 +++++ redis_consumer/utils_test.py | 9 +++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index b84c9d7e..83f065c2 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -419,6 +419,7 @@ def get_image_scale(self, scale, image, redis_hash): return scale def save_output(self, image, save_name): + """Save output images into a zip file and upload it.""" with tempfile.TemporaryDirectory() as tempdir: # Save each result channel as an image file subdir = os.path.dirname(save_name.replace(tempdir, '')) @@ -439,7 +440,7 @@ def save_output(self, image, save_name): # Upload the zip file to cloud storage bucket cleaned = zip_file.replace(tempdir, '') - subdir = os.path.dirname(settings._strip(cleaned)) + subdir = os.path.dirname(utils.strip_bucket_path(cleaned)) subdir = subdir if subdir else None dest, output_url = self.storage.upload(zip_file, subdir=subdir) @@ -476,7 +477,7 @@ def _upload_archived_images(self, hvalues, redis_hash): image_files = utils.get_image_files_from_dir(fname, tempdir) for i, imfile in enumerate(image_files): - clean_imfile = settings._strip(imfile.replace(tempdir, '')) + clean_imfile = utils.strip_bucket_path(imfile.replace(tempdir, '')) # Save each result channel as an image file subdir = os.path.join(archive_uuid, os.path.dirname(clean_imfile)) dest, _ = self.storage.upload(imfile, subdir=subdir) diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index bafd476f..c24b0cf2 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -46,6 +46,11 @@ logger = logging.getLogger('redis_consumer.utils') +def strip_bucket_path(path): + """Remove leading/trailing '/'s from cloud bucket folder names""" + return '/'.join(y for y in path.split('/') if y) + + def iter_image_archive(zip_path, destination): """Extract all files in archive and yield the paths of all images. 
diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index 1fc964f5..59a80d31 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -70,6 +70,15 @@ def _write_trks(filepath, X_mean=10, y_mean=5, trks.add(tracked_file.name, 'tracked.npy') +def test_strip_bucket_path(): + path = 'path/to/file' + # leading, trailing, and both format strings + format_strs = ['/{}', '{}/', '/{}/'] + for fmtstr in format_strs: + stripped = utils.strip_bucket_path(fmtstr.format(path)) + assert path == stripped + + def test_iter_image_archive(tmpdir): tmpdir = str(tmpdir) zip_path = os.path.join(tmpdir, 'test.zip') From d65d98e27cb9dc04136ab0436002b2e902bdea82 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 10:35:39 -0800 Subject: [PATCH 61/73] Update storage.get_client to expect bucket protocols. --- redis_consumer/storage.py | 6 +++++- redis_consumer/storage_test.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/redis_consumer/storage.py b/redis_consumer/storage.py index 82d7026f..c432abde 100644 --- a/redis_consumer/storage.py +++ b/redis_consumer/storage.py @@ -60,7 +60,11 @@ def get_client(bucket): Returns: ~Storage: Client for interacting with the cloud. """ - protocol, bucket_name = str(bucket).lower().split('://', 1) + try: + protocol, bucket_name = str(bucket).lower().split('://', 1) + except ValueError: + raise ValueError('Invalid storage bucket name: {}'.format(bucket)) + logger = logging.getLogger('storage.get_client') if protocol == 's3': storage_client = S3Storage(bucket_name) diff --git a/redis_consumer/storage_test.py b/redis_consumer/storage_test.py index ebde4f66..11ef0cda 100644 --- a/redis_consumer/storage_test.py +++ b/redis_consumer/storage_test.py @@ -109,18 +109,20 @@ def upload_file(self, path, bucket, dest, **_): def test_get_client(): - aws = storage.get_client('aws') - AWS = storage.get_client('AWS') + aws = storage.get_client('s3://bucket') + AWS = storage.get_client('S3://anotherbucket') assert isinstance(aws, type(AWS)) # TODO: set GCLOUD env vars to test this # with pytest.raises(OSError): - gke = storage.get_client('gke') - GKE = storage.get_client('GKE') + gke = storage.get_client('gs://bucket') + GKE = storage.get_client('GS://anotherbucket') assert isinstance(gke, type(GKE)) - with pytest.raises(ValueError): - _ = storage.get_client('bad_value') + bad_values = ['s3', 'gs', 's3:/badval', 'gs//badval'] + for bad_value in bad_values: + with pytest.raises(ValueError): + _ = storage.get_client(bad_value) class TestStorage(object): From 99680aac85a9a59bf9072133722f6c7fa3cca1f5 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 10:43:18 -0800 Subject: [PATCH 62/73] Update validate_model_input to expect batch dimension --- redis_consumer/consumers/base_consumer.py | 6 +++--- .../consumers/base_consumer_test.py | 19 ++++++++++--------- .../consumers/multiplex_consumer.py | 3 +-- .../consumers/segmentation_consumer.py | 3 --- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 83f065c2..532d2945 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -269,14 +269,14 @@ def validate_model_input(self, image, model_name, model_version): validated = [] for img, shape in zip(image, shapes): - rank = len(shape) - 1 # ignoring batch dimension 
+ rank = len(shape) # expects a batch dimension channels = shape[-1] if len(img.shape) != rank: raise ValueError(errtext) - if img.shape[0] == channels: - img = np.rollaxis(img, 0, rank) + if img.shape[1] == channels: + img = np.rollaxis(img, 1, rank) if img.shape[rank - 1] != channels: raise ValueError(errtext) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index 59af2d95..7a8c6476 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -248,10 +248,10 @@ def test_validate_model_input(self, mocker, redis_client): # test valid channels last shapes valid_input_shapes = [ - (32, 32, 1), # exact same shape - (64, 64, 1), # bigger - (32, 32, 1), # smaller - (33, 31, 1), # mixed + (1, 32, 32, 1), # exact same shape + (1, 64, 64, 1), # bigger + (1, 32, 32, 1), # smaller + (1, 33, 31, 1), # mixed ] for shape in valid_input_shapes: # check channels last @@ -260,17 +260,18 @@ def test_validate_model_input(self, mocker, redis_client): np.testing.assert_array_equal(img, valid_img) # should also work for channels first - img = np.rollaxis(img, -1, 0) + img = np.rollaxis(img, -1, 1) valid_img = consumer.validate_model_input(img, 'model', '1') expected_img = np.rollaxis(img, 0, img.ndim) np.testing.assert_array_equal(expected_img, valid_img) # test invalid shapes invalid_input_shapes = [ - (32, 1), # rank too small - (32, 32, 32, 1), # rank too large - (32, 32, 2), # wrong channels - (16, 64, 2), # wrong channels with mixed shape + (32, 32, 1), # no batch dimension + (1, 32, 1), # rank too small + (1, 32, 32, 32, 1), # rank too large + (1, 32, 32, 2), # wrong channels + (1, 16, 64, 2), # wrong channels with mixed shape ] for shape in invalid_input_shapes: img = np.ones(shape) diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index 2cbd6df8..c806df6a 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -93,6 +93,7 @@ def _consume(self, redis_hash): # Load input image fname = hvals.get('input_file_name') image = self.download_image(fname) + image = np.expand_dims(image, axis=0) # add in the batch dim # squeeze extra dimension that is added by get_image image = np.squeeze(image) @@ -107,8 +108,6 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) - image = np.expand_dims(image, axis=0) # add in the batch dim - # Validate input image image = self.validate_model_input(image, model_name, model_version) diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index daa12723..cdd7ac36 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -126,10 +126,7 @@ def _consume(self, redis_hash): model_name, model_version = model.split(':') # Validate input image - # TODO: batch dimension wonkiness - image = image[0] # remove batch dimension image = self.validate_model_input(image, model_name, model_version) - image = np.expand_dims(image, axis=0) # add batch dim back # Send data to the model self.update_key(redis_hash, {'status': 'predicting'}) From f568d390caaa598e101f1b66cea7eac2f7f4b835 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 11:10:53 -0800 Subject: [PATCH 63/73] 100% test coverage for storage.py --- 
redis_consumer/storage_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/storage_test.py b/redis_consumer/storage_test.py index 11ef0cda..c1ccd635 100644 --- a/redis_consumer/storage_test.py +++ b/redis_consumer/storage_test.py @@ -119,7 +119,7 @@ def test_get_client(): GKE = storage.get_client('GS://anotherbucket') assert isinstance(gke, type(GKE)) - bad_values = ['s3', 'gs', 's3:/badval', 'gs//badval'] + bad_values = ['s3', 'gs', 's3:/badval', 'gs//badval', 'other://bucket'] for bad_value in bad_values: with pytest.raises(ValueError): _ = storage.get_client(bad_value) From 2aecf8e6bff71ad310a8fd7a7080ec1e3d5faff2 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 11:16:30 -0800 Subject: [PATCH 64/73] fix storage tests, the random integer MAY be the maximum of randint. --- redis_consumer/storage_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis_consumer/storage_test.py b/redis_consumer/storage_test.py index c1ccd635..75979dd7 100644 --- a/redis_consumer/storage_test.py +++ b/redis_consumer/storage_test.py @@ -131,10 +131,10 @@ def test_get_backoff(self): max_backoff = 30 client = storage.Storage('bucket', max_backoff=max_backoff) backoff = client.get_backoff(attempts=0) - assert 1 < backoff < 2 + assert 1 < backoff <= 2 backoff = client.get_backoff(attempts=3) - assert 8 < backoff < 9 + assert 8 < backoff <= 9 backoff = client.get_backoff(attempts=5) assert backoff == max_backoff From 3138cbbd46743adff690a8724e5abf27fbcbb604 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 11:23:03 -0800 Subject: [PATCH 65/73] test no batch size. --- redis_consumer/grpc_clients_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/redis_consumer/grpc_clients_test.py b/redis_consumer/grpc_clients_test.py index d5e6a078..4baecbf8 100644 --- a/redis_consumer/grpc_clients_test.py +++ b/redis_consumer/grpc_clients_test.py @@ -154,3 +154,7 @@ def mock_send_grpc(img): results = wrapper.predict(input_data, batch_size=batch_size) np.testing.assert_array_equal(input_data, results) + + # no batch size + results = wrapper.predict(input_data) + np.testing.assert_array_equal(input_data, results) From 946a4e91ec5b6fa1cc5ae013cdacc4b5028df371 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 26 Jan 2021 11:55:06 -0800 Subject: [PATCH 66/73] add the batch dimension after the squeeze. 
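Ordering matters here because np.squeeze removes every size-1 axis, so a batch dimension added before the squeeze is immediately squeezed back out. A quick illustration:

```python
import numpy as np

image = np.zeros((1, 32, 32, 1))  # (batch, height, width, channel)

# Wrong order: the new batch axis is removed again by the squeeze.
wrong = np.squeeze(np.expand_dims(image, axis=0))   # shape (32, 32)

# Fixed order: squeeze first, then restore the batch axis.
fixed = np.expand_dims(np.squeeze(image), axis=0)   # shape (1, 32, 32)
```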
--- redis_consumer/consumers/multiplex_consumer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index c806df6a..f545e53d 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -93,10 +93,11 @@ def _consume(self, redis_hash): # Load input image fname = hvals.get('input_file_name') image = self.download_image(fname) - image = np.expand_dims(image, axis=0) # add in the batch dim # squeeze extra dimension that is added by get_image image = np.squeeze(image) + # add in the batch dim + image = np.expand_dims(image, axis=0) # Pre-process data before sending to the model self.update_key(redis_hash, { From a1c886a2f25d49695f8cc30eeff63f6ced5d565c Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 2 Feb 2021 12:28:15 -0800 Subject: [PATCH 67/73] Update toolbox and tracking to latest releases. --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index e08f012f..2f73454b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # deepcell packages deepcell-cpu>=0.8.3 -deepcell-tracking==0.3.0 -deepcell-toolbox>=0.8.4 +deepcell-tracking==0.3.1 +deepcell-toolbox>=0.8.5 tensorflow-cpu scikit-image>=0.14.5,<0.17.0 numpy>=1.16.6 From 062fd49463da8531f20e972f887957afb5506d68 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 2 Feb 2021 16:32:17 -0800 Subject: [PATCH 68/73] Specify batch_size as None to auto-generate --- redis_consumer/consumers/multiplex_consumer.py | 3 ++- redis_consumer/consumers/segmentation_consumer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index f545e53d..142e9777 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -116,7 +116,8 @@ def _consume(self, redis_hash): app = self.get_grpc_app(settings.MULTIPLEX_MODEL, MultiplexSegmentation) - results = app.predict(image, image_mpp=scale * app.model_mpp) + results = app.predict(image, batch_size=None, + image_mpp=scale * app.model_mpp) # Save the post-processed results to a file _ = timeit.default_timer() diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index cdd7ac36..1a3364e1 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -133,7 +133,8 @@ def _consume(self, redis_hash): app = self.get_grpc_app(model, app_cls) - results = app.predict(image, image_mpp=scale * app.model_mpp) + results = app.predict(image, batch_size=None, + image_mpp=scale * app.model_mpp) # Save the post-processed results to a file _ = timeit.default_timer() From f43bd4cb7d04557e84e4dd2ae9e223d52d307d55 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 2 Feb 2021 16:57:08 -0800 Subject: [PATCH 69/73] temporary workaround for nuclear segmentation --- redis_consumer/consumers/segmentation_consumer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index 1a3364e1..6a15e522 100644 --- 
a/redis_consumer/consumers/segmentation_consumer.py
+++ b/redis_consumer/consumers/segmentation_consumer.py
@@ -33,6 +33,7 @@
 import numpy as np

 from deepcell.applications import LabelDetection
+from deepcell_toolbox.processing import normalize

 from redis_consumer.consumers import TensorFlowServingConsumer
 from redis_consumer import settings
@@ -133,8 +134,12 @@ def _consume(self, redis_hash):

         app = self.get_grpc_app(model, app_cls)

+        # Temporary patch
+        app.preprocessing_fn = normalize
+
         results = app.predict(image, batch_size=None,
-                              image_mpp=scale * app.model_mpp)
+                              image_mpp=scale * app.model_mpp,
+                              preprocess_kwargs={})

         # Save the post-processed results to a file
         _ = timeit.default_timer()

From e932ecc54567503e5be76501da8367f3607c7320 Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Fri, 5 Feb 2021 11:51:35 -0800
Subject: [PATCH 70/73] Update deepcell and deepcell-toolbox versions

---
 requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2f73454b..4c8ed73b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 # deepcell packages
-deepcell-cpu>=0.8.3
+deepcell-cpu>=0.8.6
 deepcell-tracking==0.3.1
-deepcell-toolbox>=0.8.5
+deepcell-toolbox>=0.8.6
 tensorflow-cpu
 scikit-image>=0.14.5,<0.17.0
 numpy>=1.16.6

From 72b3327e730e8a7b59f041e8b33a16dacc93b9cc Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Fri, 5 Feb 2021 11:52:21 -0800
Subject: [PATCH 71/73] Remove preprocessing workaround for segmentation consumer.

---
 redis_consumer/consumers/segmentation_consumer.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py
index 6a15e522..a35728c4 100644
--- a/redis_consumer/consumers/segmentation_consumer.py
+++ b/redis_consumer/consumers/segmentation_consumer.py
@@ -134,12 +134,8 @@ def _consume(self, redis_hash):

         app = self.get_grpc_app(model, app_cls)

-        # Temporary patch
-        app.preprocessing_fn = normalize
-
         results = app.predict(image, batch_size=None,
-                              image_mpp=scale * app.model_mpp,
-                              preprocess_kwargs={})
+                              image_mpp=scale * app.model_mpp)

         # Save the post-processed results to a file
         _ = timeit.default_timer()

From f6339b07d889c6de4bd52ae9be81e9992221757e Mon Sep 17 00:00:00 2001
From: William Graf <7930703+willgraf@users.noreply.github.com>
Date: Fri, 5 Feb 2021 11:59:35 -0800
Subject: [PATCH 72/73] Add TODO in predict()

---
 redis_consumer/grpc_clients.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py
index 2aa8b78d..64b0b93a 100644
--- a/redis_consumer/grpc_clients.py
+++ b/redis_consumer/grpc_clients.py
@@ -414,6 +414,7 @@ def get_batch_size(self):
         return batch_size

     def predict(self, tiles, batch_size=None):
+        # TODO: Can the result size be known beforehand via model metadata?
         results = []

         if batch_size is None:

From b1ba2a98ca70ffac1c0bc61eb374497c33bde646 Mon Sep 17 00:00:00 2001
From: willgraf <7930703+willgraf@users.noreply.github.com>
Date: Mon, 8 Feb 2021 10:28:31 -0800
Subject: [PATCH 73/73] Update TrackingConsumer to use GrpcModelWrapper (#158)

* Update TrackingConsumer to use deepcell.applications

* GrpcModelWrapper handles multiple input tensors and dictionaries of inputs.

* Convert hvalues to hvals for better consistency.

* Update GrpcModelWrapper tests for multiple inputs and dictionary inputs.
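One enabling change in the diff below is that get_grpc_app now forwards keyword arguments to the application constructor (return application_cls(model_wrapper, **kwargs)), so tracking hyperparameters pass straight through. Roughly, as used inside the new TrackingConsumer._consume, where `self` and `data` are the method's own context rather than standalone values:

```python
from deepcell.applications import CellTracking

from redis_consumer import settings

# Sketch of the new call: the extra kwargs are handed by get_grpc_app
# to the CellTracking constructor along with the gRPC model wrapper.
app = self.get_grpc_app(settings.TRACKING_MODEL, CellTracking,
                        birth=settings.BIRTH,
                        death=settings.DEATH,
                        division=settings.DIVISION,
                        track_length=settings.TRACK_LENGTH)
results = app.predict(data['X'], data['y'])
```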
---
 redis_consumer/__init__.py                    |   1 -
 redis_consumer/consumers/base_consumer.py     |   4 +-
 redis_consumer/consumers/tracking_consumer.py | 120 +++-----------
 .../consumers/tracking_consumer_test.py       |  67 +++-----
 redis_consumer/grpc_clients.py                | 151 ++++++------------
 redis_consumer/grpc_clients_test.py           |  48 ++++--
 redis_consumer/settings.py                    |  11 +-
 redis_consumer/tracking.py                    |  59 -------
 redis_consumer/tracking_test.py               | 106 ------------
 9 files changed, 139 insertions(+), 428 deletions(-)
 delete mode 100644 redis_consumer/tracking.py
 delete mode 100644 redis_consumer/tracking_test.py

diff --git a/redis_consumer/__init__.py b/redis_consumer/__init__.py
index 2192c61c..0ee2ddaa 100644
--- a/redis_consumer/__init__.py
+++ b/redis_consumer/__init__.py
@@ -32,7 +32,6 @@
 from redis_consumer import redis
 from redis_consumer import settings
 from redis_consumer import storage
-from redis_consumer import tracking
 from redis_consumer import utils
 
 del absolute_import
diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py
index 532d2945..a4a01316 100644
--- a/redis_consumer/consumers/base_consumer.py
+++ b/redis_consumer/consumers/base_consumer.py
@@ -363,7 +363,7 @@ def get_model_metadata(self, model_name, model_version):
             self.logger.error('Malformed metadata: %s', model_metadata)
             raise err
 
-    def get_grpc_app(self, model, application_cls):
+    def get_grpc_app(self, model, application_cls, **kwargs):
         """
         Create an application from deepcell.applications
         with a gRPC model wrapper as a model
@@ -372,7 +372,7 @@ def get_grpc_app(self, model, application_cls):
         model_metadata = self.get_model_metadata(model_name, model_version)
         client = self._get_predict_client(model_name, model_version)
         model_wrapper = GrpcModelWrapper(client, model_metadata)
-        return application_cls(model_wrapper)
+        return application_cls(model_wrapper, **kwargs)
 
     def detect_scale(self, image):
         """Send the image to the SCALE_DETECT_MODEL to detect the relative
diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py
index 9f88e0ec..a3257956 100644
--- a/redis_consumer/consumers/tracking_consumer.py
+++ b/redis_consumer/consumers/tracking_consumer.py
@@ -38,13 +38,12 @@
 from skimage.external import tifffile
 import numpy as np
 
-from deepcell_toolbox.processing import normalize
 from deepcell_toolbox.processing import correct_drift
 
-from redis_consumer.grpc_clients import TrackingClient
+from deepcell.applications import CellTracking
+
 from redis_consumer.consumers import TensorFlowServingConsumer
 from redis_consumer import utils
-from redis_consumer import tracking
 from redis_consumer import settings
 
@@ -68,58 +67,6 @@ def is_valid_hash(self, redis_hash):
 
         return valid_file
 
-    def _get_model(self, redis_hash, hvalues):
-        hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT)
-
-        # Pick model based on redis or default setting
-        model = hvalues.get('model_name', '')
-        version = hvalues.get('model_version', '')
-        if not model or not version:
-            model, version = settings.TRACKING_MODEL.split(':')
-
-        t = timeit.default_timer()
-        model = TrackingClient(host=hostname,
-                               redis_hash=redis_hash,
-                               model_name=model,
-                               model_version=int(version),
-                               progress_callback=self._update_progress)
-
-        self.logger.debug('Created the TrackingClient in %s seconds.',
-                          timeit.default_timer() - t)
-        return model
-
-    def _get_tracker(self, redis_hash, hvalues, raw, segmented):
-        self.logger.debug('Creating tracker...')
-        t = timeit.default_timer()
-        tracking_model = self._get_model(redis_hash, hvalues)
-
-        # Some tracking models do not have an ImageNormalization Layer.
-        # If not, the data must be normalized before being tracked.
-        if settings.NORMALIZE_TRACKING:
-            for frame in range(raw.shape[0]):
-                raw[frame, ..., 0] = normalize(raw[frame, ..., 0])
-
-        features = {'appearance', 'distance', 'neighborhood', 'regionprop'}
-        tracker = tracking.CellTracker(
-            raw, segmented,
-            tracking_model,
-            max_distance=settings.MAX_DISTANCE,
-            track_length=settings.TRACK_LENGTH,
-            division=settings.DIVISION,
-            birth=settings.BIRTH,
-            death=settings.DEATH,
-            neighborhood_scale_size=settings.NEIGHBORHOOD_SCALE_SIZE,
-            features=features)
-
-        self.logger.debug('Created Tracker in %s seconds.',
-                          timeit.default_timer() - t)
-        return tracker
-
-    def _update_progress(self, redis_hash, progress):
-        self.update_key(redis_hash, {
-            'progress': progress,
-        })
-
     def _load_data(self, redis_hash, subdir, fname):
         """
         Given the upload location `input_file_name`, and the downloaded
@@ -151,20 +98,6 @@ def _load_data(self, redis_hash, subdir, fname):
                             'with 3 dimensions, (time, width, height)'.format(
                                 tiff_stack.shape))
 
-        # Calculate scale of a subset of raw
-        scale = hvalues.get('scale', '')
-        scale = scale if settings.SCALE_DETECT_ENABLED else 1
-
-        # Pick model and postprocess based on either label or defaults
-        if settings.LABEL_DETECT_ENABLED:
-            # model and postprocessing will be determined automatically
-            # by the ImageFileConsumer
-            model_name, model_version = '', ''
-            postprocess_function = ''
-        else:
-            model_name, model_version = settings.TRACKING_SEGMENT_MODEL.split(':')
-            postprocess_function = settings.TRACKING_POSTPROCESS_FUNCTION
-
         num_frames = len(tiff_stack)
         hash_to_frame = {}
         remaining_hashes = set()
@@ -190,15 +123,10 @@
                 'identity_upload': self.name,
                 'input_file_name': upload_file_name,
                 'original_name': segment_fname,
-                'model_name': model_name,
-                'model_version': model_version,
-                'postprocess_function': postprocess_function,
                 'status': 'new',
                 'created_at': current_timestamp,
                 'updated_at': current_timestamp,
-                'url': upload_file_url,
-                'scale': scale,
-                # 'label': str(label)
+                'url': upload_file_url
             }
 
             # make a hash for this frame
@@ -259,19 +187,21 @@
         labels = [frames[i] for i in range(num_frames)]
 
         # Cast y to int to avoid issues during fourier transform/drift correction
-        return {'X': np.expand_dims(tiff_stack, axis=-1),
-                'y': np.array(labels, dtype='uint16')}
+        y = np.array(labels, dtype='uint16')
+        # TODO: Why is there an extra dimension?
+        # Not a problem in tests, only with application based results.
+        # Issue with batch dimension from outputs?
+        y = y[:, 0] if y.shape[1] == 1 else y
+        return {'X': np.expand_dims(tiff_stack, axis=-1), 'y': y}
 
     def _consume(self, redis_hash):
         start = timeit.default_timer()
-        hvalues = self.redis.hgetall(redis_hash)
-        self.logger.debug('Found `%s:*` hash to process "%s": %s',
-                          self.queue, redis_hash, json.dumps(hvalues, indent=4))
+        hvals = self.redis.hgetall(redis_hash)
 
-        if hvalues.get('status') in self.finished_statuses:
+        if hvals.get('status') in self.finished_statuses:
             self.logger.warning('Found completed hash `%s` with status %s.',
-                                redis_hash, hvalues.get('status'))
-            return hvalues.get('status')
+                                redis_hash, hvals.get('status'))
+            return hvals.get('status')
 
         # Set status and initial progress
         self.update_key(redis_hash, {
@@ -281,7 +211,7 @@ def _consume(self, redis_hash):
         })
 
         with tempfile.TemporaryDirectory() as tempdir:
-            fname = self.storage.download(hvalues.get('input_file_name'),
+            fname = self.storage.download(hvals.get('input_file_name'),
                                           tempdir)
             data = self._load_data(redis_hash, tempdir, fname)
 
@@ -296,32 +226,32 @@ def _consume(self, redis_hash):
            self.logger.debug('Drift correction complete in %s seconds.',
                              timeit.default_timer() - t)
 
-        # TODO: Add support for rescaling in the tracker
-        tracker = self._get_tracker(redis_hash, hvalues, data['X'], data['y'])
+        # Send data to the model
+        app = self.get_grpc_app(settings.TRACKING_MODEL, CellTracking,
+                                birth=settings.BIRTH,
+                                death=settings.DEATH,
+                                division=settings.DIVISION,
+                                track_length=settings.TRACK_LENGTH)
 
-        self.logger.debug('Trying to track...')
+        self.logger.debug('Tracking...')
        self.update_key(redis_hash, {'status': 'predicting'})
-        tracker.track_cells()
+        results = app.predict(data['X'], data['y'])
        self.logger.debug('Tracking done!')
 
-        self.update_key(redis_hash, {'status': 'post-processing'})
-
-        # Post-process and save the output file
-        tracked_data = tracker.postprocess()
-
        self.update_key(redis_hash, {'status': 'saving-results'})
 
        with tempfile.TemporaryDirectory() as tempdir:
            # Save lineage data to JSON file
            lineage_file = os.path.join(tempdir, 'lineage.json')
            with open(lineage_file, 'w') as fp:
-                json.dump(tracked_data['tracks'], fp)
+                json.dump(results['tracks'], fp)
 
-            save_name = hvalues.get('original_name', fname)
+            save_name = hvals.get('original_name', fname)
            subdir = os.path.dirname(save_name.replace(tempdir, ''))
            name = os.path.splitext(os.path.basename(save_name))[0]
 
            # Save tracked data as tiff stack
            outpaths = utils.save_numpy_array(
-                tracked_data['y_tracked'], name=name,
+                results['y_tracked'], name=name,
                subdir=subdir, output_dir=tempdir)
 
            outpaths.append(lineage_file)
diff --git a/redis_consumer/consumers/tracking_consumer_test.py b/redis_consumer/consumers/tracking_consumer_test.py
index 19bc03ff..b462536b 100644
--- a/redis_consumer/consumers/tracking_consumer_test.py
+++ b/redis_consumer/consumers/tracking_consumer_test.py
@@ -37,31 +37,12 @@
 import numpy as np
 from skimage.external import tifffile
 
-import redis_consumer
 from redis_consumer import consumers
 from redis_consumer import settings
-from redis_consumer.testing_utils import DummyStorage, redis_client, _get_image
-
-
-class DummyTracker(object):
-    # pylint: disable=R0201,W0613
-    def __init__(self, *_, **__):
-        pass
-
-    def _track_cells(self):
-        return None
-
-    def track_cells(self):
-        return None
-
-    def dump(self, *_, **__):
-        return None
-
-    def postprocess(self, *_, **__):
-        return {
-            'y_tracked': np.zeros((32, 32, 1)),
-            'tracks': []
-        }
+from redis_consumer.testing_utils import Bunch
+from redis_consumer.testing_utils import DummyStorage
+from redis_consumer.testing_utils import redis_client
+from redis_consumer.testing_utils import _get_image
 
 
 class TestTrackingConsumer(object):
@@ -84,16 +65,6 @@ def test_is_valid_hash(self, mocker, redis_client):
         assert consumer.is_valid_hash('track:1234567890:file.trk') is True
         assert consumer.is_valid_hash('track:1234567890:file.trks') is True
 
-    def test__update_progress(self, redis_client):
-        queue = 'track'
-        storage = DummyStorage()
-        consumer = consumers.TrackingConsumer(redis_client, storage, queue)
-
-        redis_hash = 'a job hash'
-        progress = random.randint(0, 99)
-        consumer._update_progress(redis_hash, progress)
-        assert int(redis_client.hget(redis_hash, 'progress')) == progress
-
     def test__load_data(self, tmpdir, mocker, redis_client):
         queue = 'track'
         storage = DummyStorage()
@@ -163,28 +134,30 @@ def write_child_tiff(*_, **__):
                             lambda *x: range(1, 3))
         consumer._load_data(key, tmpdir, fname)
 
-    def test__get_tracker(self, mocker, redis_client):
-        queue = 'track'
-        storage = DummyStorage()
-
-        shape = (5, 21, 21, 1)
-        raw = np.random.random(shape)
-        segmented = np.random.randint(1, 10, size=shape)
-
-        mocker.patch.object(settings, 'NORMALIZE_TRACKING', True)
-        consumer = consumers.TrackingConsumer(redis_client, storage, queue)
-        tracker = consumer._get_tracker('item1', {}, raw, segmented)
-        assert isinstance(tracker, redis_consumer.tracking.CellTracker)
-
     def test__consume(self, mocker, redis_client):
         queue = 'track'
         storage = DummyStorage()
         test_hash = 0
 
+        dummy_results = {
+            'y_tracked': np.zeros((32, 32, 1)),
+            'tracks': []
+        }
+
+        mock_app = Bunch(
+            predict=lambda *x, **y: dummy_results,
+            track=lambda *x, **y: dummy_results,
+            model_mpp=1,
+            model=Bunch(
+                get_batch_size=lambda *x: 1,
+                input_shape=(1, 32, 32, 1)
+            )
+        )
+
         consumer = consumers.TrackingConsumer(redis_client, storage, queue)
 
-        mocker.patch.object(consumer, '_get_tracker', DummyTracker)
         mocker.patch.object(settings, 'DRIFT_CORRECT_ENABLED', True)
+        mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **y: mock_app)
 
         frames = 3
         dummy_data = {
diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py
index 64b0b93a..45d5f95b 100644
--- a/redis_consumer/grpc_clients.py
+++ b/redis_consumer/grpc_clients.py
@@ -276,82 +276,6 @@ def get_model_metadata(self, request_timeout=10):
 
         return response_dict
 
-
-class TrackingClient(PredictClient):
-    """gRPC Client for tensorflow-serving API.
-
-    Arguments:
-        host: string, the hostname and port of the server (`localhost:8080`)
-        model_name: string, name of model served by tensorflow-serving
-        model_version: integer, version of the named model
-    """
-
-    def __init__(self, host, model_name, model_version,
-                 redis_hash, progress_callback):
-        self.redis_hash = redis_hash
-        self.progress_callback = progress_callback
-        super(TrackingClient, self).__init__(host, model_name, model_version)
-
-    def predict(self, data, request_timeout=10):
-        t = timeit.default_timer()
-        self.logger.info('Tracking data with %s model %s:%s.',
-                         self.host, self.model_name, self.model_version)
-
-        # TODO: features should be retrieved from model metadata
-        features = {'appearance', 'distance', 'neighborhood', 'regionprop'}
-        features = sorted(features)
-        # Grab a random value from the data dict and select batch dim (num cell comparisons)
-        batch_size = next(iter(data.values())).shape[0]
-        self.logger.info('batch size: %s', batch_size)
-        results = []
-        # from 0 to Num Cells (comparisons) in increments of TF batch size
-        for b in range(0, batch_size, settings.TF_MAX_BATCH_SIZE):
-            request_data = []
-            for f in features:
-                input_name1 = '{}_input1'.format(f)
-                d1 = {
-                    'in_tensor_name': input_name1,
-                    'in_tensor_dtype': 'DT_FLOAT',
-                    'data': data[input_name1][b:b + settings.TF_MAX_BATCH_SIZE]
-                }
-                request_data.append(d1)
-
-                input_name2 = '{}_input2'.format(f)
-                d2 = {
-                    'in_tensor_name': input_name2,
-                    'in_tensor_dtype': 'DT_FLOAT',
-                    'data': data[input_name2][b:b + settings.TF_MAX_BATCH_SIZE]
-                }
-                request_data.append(d2)
-
-            response_dict = super(TrackingClient, self).predict(
-                request_data, request_timeout)
-
-            output = [response_dict[k] for k in sorted(response_dict.keys())]
-            if len(output) == 1:
-                output = output[0]
-
-            if len(results) == 0:
-                results = output
-            else:
-                results = np.vstack((results, output))
-
-        self.logger.info('Tracked %s input pairs in %s seconds.',
-                         batch_size, timeit.default_timer() - t)
-
-        return np.array(results)
-
-    def progress(self, progress):
-        """Update the internal state regarding progress
-
-        Arguments:
-            progress: float, the progress in the interval [0, 1]
-        """
-        progress *= 100
-        # clamp to an integer between 0 and 100
-        progress = min(100, max(0, round(progress)))
-        self.progress_callback(self.redis_hash, progress)
-
-
 class GrpcModelWrapper(object):
     """A wrapper class that mocks a Keras model using a gRPC client.
 
@@ -360,18 +284,15 @@ class GrpcModelWrapper(object):
 
     def __init__(self, client, model_metadata):
         self._client = client
+        self._metadata = model_metadata
 
-        if len(model_metadata) > 1:
-            # TODO: how to handle this?
-            raise NotImplementedError('Multiple input tensors are not supported.')
-
-        self._metadata = model_metadata[0]
-
-        self._in_tensor_name = self._metadata['in_tensor_name']
-        self._in_tensor_dtype = str(self._metadata['in_tensor_dtype']).upper()
-
-        shape = [int(x) for x in self._metadata['in_tensor_shape'].split(',')]
-        self.input_shape = tuple(shape)
+        shapes = [
+            tuple([int(x) for x in m['in_tensor_shape'].split(',')])
+            for m in self._metadata
+        ]
+        if len(shapes) == 1:
+            shapes = shapes[0]
+        self.input_shape = shapes
 
     def send_grpc(self, img):
         """Use the TensorFlow Serving gRPC API for model inference on an image.
@@ -383,14 +304,30 @@ def send_grpc(self, img):
             numpy.array: The results of model inference.
""" start = timeit.default_timer() - if self._in_tensor_dtype == 'DT_HALF': - # TODO: seems like should cast to "half" - # but the model rejects the type, wants "int" or "long" - img = img.astype('int') - req_data = [{'in_tensor_name': self._in_tensor_name, - 'in_tensor_dtype': self._in_tensor_dtype, - 'data': img}] + # cast input as list + if not isinstance(img, list): + img = [img] + + if len(self._metadata) != len(img): + raise ValueError('Expected {} model inputs but got {}.'.format( + len(self._metadata), len(img))) + + req_data = [] + + for i, m in enumerate(self._metadata): + data = img[i] + + if m['in_tensor_dtype'] == 'DT_HALF': + # seems like should cast to "half" + # but the model rejects the type, wants "int" or "long" + data = data.astype('int') + + req_data.append({ + 'in_tensor_name': m['in_tensor_name'], + 'in_tensor_dtype': m['in_tensor_dtype'], + 'data': data + }) prediction = self._client.predict(req_data, settings.GRPC_TIMEOUT) results = [prediction[k] for k in sorted(prediction.keys())] @@ -405,10 +342,16 @@ def get_batch_size(self): """Calculate the best batch size based on TF_MAX_BATCH_SIZE and TF_MIN_MODEL_SIZE """ - rank = len(self.input_shape) - ratio = (self.input_shape[rank - 3] / settings.TF_MIN_MODEL_SIZE) * \ - (self.input_shape[rank - 2] / settings.TF_MIN_MODEL_SIZE) * \ - (self.input_shape[rank - 1]) + input_shape = self.input_shape + if not isinstance(input_shape, list): + input_shape = [input_shape] + + ratio = 1 + for shape in input_shape: + rank = len(shape) + ratio *= (shape[rank - 3] / settings.TF_MIN_MODEL_SIZE) * \ + (shape[rank - 2] / settings.TF_MIN_MODEL_SIZE) * \ + (shape[rank - 1]) batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) return batch_size @@ -417,11 +360,19 @@ def predict(self, tiles, batch_size=None): # TODO: Can the result size be known beforehand via model metadata? 
        results = []
 
+        if isinstance(tiles, dict):
+            tiles = [tiles[m['in_tensor_name']] for m in self._metadata]
+
+        if not isinstance(tiles, list):
+            tiles = [tiles]
+
        if batch_size is None:
            batch_size = self.get_batch_size()
 
-        for t in range(0, tiles.shape[0], batch_size):
-            output = self.send_grpc(tiles[t:t + batch_size])
+        for t in range(0, tiles[0].shape[0], batch_size):
+            inputs = [tile[t:t + batch_size] for tile in tiles]
+            inputs = inputs[0] if len(inputs) == 1 else inputs
+            output = self.send_grpc(inputs)
 
            if len(results) == 0:
                results = output
diff --git a/redis_consumer/grpc_clients_test.py b/redis_consumer/grpc_clients_test.py
index 4baecbf8..281fce5d 100644
--- a/redis_consumer/grpc_clients_test.py
+++ b/redis_consumer/grpc_clients_test.py
@@ -109,9 +109,11 @@ def test_init(self):
        wrapper = grpc_clients.GrpcModelWrapper(None, metadata)
        assert wrapper.input_shape == self.shape
 
-        multi_metadata = [metadata, metadata]
-        with pytest.raises(NotImplementedError):
-            wrapper = grpc_clients.GrpcModelWrapper(None, multi_metadata)
+        metadata += metadata
+        wrapper = grpc_clients.GrpcModelWrapper(None, metadata)
+        assert isinstance(wrapper.input_shape, list)
+        for s in wrapper.input_shape:
+            assert s == self.shape
 
    def test_get_batch_size(self, mocker):
        metadata = self._get_metadata()
@@ -122,9 +124,10 @@ def test_get_batch_size(self, mocker):
        batch_size = wrapper.get_batch_size()
        assert batch_size == settings.TF_MAX_BATCH_SIZE * m * m
 
-    def test_send_grpc(self, mocker):
+    def test_send_grpc(self):
        client = DummyPredictClient(1, 2, 3)
        metadata = self._get_metadata()
+        metadata[0]['in_tensor_dtype'] = 'DT_HALF'
        wrapper = grpc_clients.GrpcModelWrapper(client, metadata)
 
        input_data = np.ones(self.shape)
@@ -133,19 +136,25 @@ def test_send_grpc(self):
        assert len(result) == 1
        np.testing.assert_array_equal(result[0], input_data)
 
-        input_data = np.ones(self.shape)
-        mocker.patch.object(wrapper, '_in_tensor_dtype', 'DT_HALF')
+        # test multiple inputs
+        metadata = self._get_metadata() + self._get_metadata()
+        input_data = [np.ones(self.shape)] * 2
+        wrapper = grpc_clients.GrpcModelWrapper(client, metadata)
        result = wrapper.send_grpc(input_data)
        assert isinstance(result, list)
-        assert len(result) == 1
-        np.testing.assert_array_equal(result[0], input_data)
+        assert len(result) == 2
+        np.testing.assert_array_equal(result, input_data)
+
+        # test inputs don't match metadata
+        with pytest.raises(ValueError):
+            wrapper.send_grpc(np.ones(self.shape))
 
    def test_predict(self, mocker):
        metadata = self._get_metadata()
        wrapper = grpc_clients.GrpcModelWrapper(None, metadata)
 
        def mock_send_grpc(img):
-            return [img]
+            return img if isinstance(img, list) else [img]
 
        mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc)
 
@@ -158,3 +167,24 @@ def mock_send_grpc(img):
        # no batch size
        results = wrapper.predict(input_data)
        np.testing.assert_array_equal(input_data, results)
+
+        # multiple inputs
+        metadata = self._get_metadata() * 2
+        wrapper = grpc_clients.GrpcModelWrapper(None, metadata)
+        mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc)
+        input_data = [np.ones((batch_size * 2, 30, 30, 1))] * 2
+        results = wrapper.predict(input_data, batch_size=batch_size)
+        np.testing.assert_array_equal(input_data, results)
+
+        # dictionary input
+        metadata = self._get_metadata()
+        wrapper = grpc_clients.GrpcModelWrapper(None, metadata)
+        mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc)
+        input_data = {
+            m['in_tensor_name']: np.ones((batch_size * 2, 30, 30, 1))
+            for m in metadata
+        }
+        results = wrapper.predict(input_data, batch_size=batch_size)
+        for m in metadata:
+            np.testing.assert_array_equal(
+                input_data[m['in_tensor_name']], results)
diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py
index 8d032fe0..6b81fcb9 100644
--- a/redis_consumer/settings.py
+++ b/redis_consumer/settings.py
@@ -90,14 +90,8 @@
 METADATA_EXPIRE_TIME = config('METADATA_EXPIRE_TIME', default=30, cast=int)
 
 # Tracking settings
-TRACKING_SEGMENT_MODEL = config('TRACKING_SEGMENT_MODEL', default='panoptic:3', cast=str)
-TRACKING_POSTPROCESS_FUNCTION = config('TRACKING_POSTPROCESS_FUNCTION',
-                                       default='retinanet', cast=str)
-
 TRACKING_MODEL = config('TRACKING_MODEL', default='TrackingModel:0', cast=str)
-
 DRIFT_CORRECT_ENABLED = config('DRIFT_CORRECT_ENABLED', default=False, cast=bool)
-NORMALIZE_TRACKING = config('NORMALIZE_TRACKING', default=True, cast=bool)
 
 # tracking.cell_tracker settings TODO: can we extract from model_metadata?
 MAX_DISTANCE = config('MAX_DISTANCE', default=50, cast=int)
@@ -107,12 +101,11 @@
 DEATH = config('DEATH', default=0.99, cast=float)
 NEIGHBORHOOD_SCALE_SIZE = config('NEIGHBORHOOD_SCALE_SIZE', default=30, cast=int)
 
-MAX_SCALE = config('MAX_SCALE', default=3, cast=float)
-MIN_SCALE = config('MIN_SCALE', default=1 / MAX_SCALE, cast=float)
-
 # Scale detection settings
 SCALE_DETECT_MODEL = config('SCALE_DETECT_MODEL', default='ScaleDetection:1')
 SCALE_DETECT_ENABLED = config('SCALE_DETECT_ENABLED', default=False, cast=bool)
+MAX_SCALE = config('MAX_SCALE', default=3, cast=float)
+MIN_SCALE = config('MIN_SCALE', default=1 / MAX_SCALE, cast=float)
 
 # Type detection settings
 LABEL_DETECT_MODEL = config('LABEL_DETECT_MODEL', default='LabelDetection:1', cast=str)
diff --git a/redis_consumer/tracking.py b/redis_consumer/tracking.py
deleted file mode 100644
index 55aac681..00000000
--- a/redis_consumer/tracking.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2016-2020 The Van Valen Lab at the California Institute of
-# Technology (Caltech), with support from the Paul Allen Family Foundation,
-# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
-# All rights reserved.
-#
-# Licensed under a modified Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE
-#
-# The Work provided may be used for non-commercial academic purposes only.
-# For any other use of the Work, including commercial use, please contact:
-# vanvalenlab@gmail.com
-#
-# Neither the name of Caltech nor the names of its contributors may be used
-# to endorse or promote products derived from this software without specific
-# prior written permission.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Override the deepcell_tracking.CellTracker class to update progress"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import logging
-import timeit
-
-from deepcell_tracking import CellTracker as _CellTracker
-
-
-class CellTracker(_CellTracker):
-    """Override the original cell_tracker class to call model.progress()"""
-
-    def __init__(self, *args, **kwargs):
-        self.logger = logging.getLogger(str(self.__class__.__name__))
-        super(CellTracker, self).__init__(*args, **kwargs)
-
-    def track_cells(self):
-        """Tracks all of the cells in every frame.
-        """
-        start = timeit.default_timer()
-        self._initialize_tracks()
-
-        for frame in range(1, self.x.shape[self.time_axis]):
-            self._track_frame(frame)
-
-            # The only difference between the original and this
-            # is calling model.progress after every frame.
-            self.model.progress(frame / self.x.shape[0])
-
-        self.logger.info('Tracked all %s frames in %s s.',
-                         self.x.shape[self.time_axis],
-                         timeit.default_timer() - start)
diff --git a/redis_consumer/tracking_test.py b/redis_consumer/tracking_test.py
deleted file mode 100644
index fdd746f2..00000000
--- a/redis_consumer/tracking_test.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2016-2020 The Van Valen Lab at the California Institute of
-# Technology (Caltech), with support from the Paul Allen Family Foundation,
-# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
-# All rights reserved.
-#
-# Licensed under a modified Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE
-#
-# The Work provided may be used for non-commercial academic purposes only.
-# For any other use of the Work, including commercial use, please contact:
-# vanvalenlab@gmail.com
-#
-# Neither the name of Caltech nor the names of its contributors may be used
-# to endorse or promote products derived from this software without specific
-# prior written permission.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tracking.py"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import skimage as sk
-
-import pytest
-
-from redis_consumer import tracking
-
-
-def _get_dummy_tracking_data(length=128, frames=3,
-                             data_format='channels_last'):
-    if data_format == 'channels_last':
-        channel_axis = -1
-    else:
-        channel_axis = 0
-
-    x, y = [], []
-    while len(x) < frames:
-        _x = sk.data.binary_blobs(length=length, n_dim=2)
-        _y = sk.measure.label(_x)
-        if len(np.unique(_y)) > 2:
-            x.append(_x)
-            y.append(_y)
-
-    x = np.stack(x, axis=0)  # expand to 3D
-    y = np.stack(y, axis=0)  # expand to 3D
-
-    x = np.expand_dims(x, axis=channel_axis)
-    y = np.expand_dims(y, axis=channel_axis)
-
-    return x.astype('float32'), y.astype('int32')
-
-
-class DummyModel(object):  # pylint: disable=useless-object-inheritance
-
-    def predict(self, data):
-        # Grab a random value from the data dict and select batch dim
-        batches = 0 if not data else next(iter(data.values())).shape[0]
-
-        return np.random.random((batches, 3))
-
-    def progress(self, n):
-        return n
-
-
-class TestTracking(object):
-
-    def test__track_cells(self):
-        length = 128
-        frames = 5
-        track_length = 2
-
-        features = ['appearance', 'neighborhood', 'regionprop', 'distance']
-
-        # TODO: Fix for channels_first
-        for data_format in ('channels_last',):  # 'channels_first'):
-
-            x, y = _get_dummy_tracking_data(
-                length, frames=frames, data_format=data_format)
-
-            tracker = tracking.CellTracker(
-                x, y,
-                model=DummyModel(),
-                track_length=track_length,
-                data_format=data_format,
-                features=features)
-
-            tracker.track_cells()
-
-            # test tracker.dataframe
-            df = tracker.dataframe(cell_type='test-value')
-            assert 'cell_type' in df.columns
-
-            # test incorrect values in tracker.dataframe
-            with pytest.raises(ValueError):
-                tracker.dataframe(bad_value=-1)
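A closing note on the batching scheme that the new GrpcModelWrapper.predict() introduces: every input is sliced with the same window along the batch axis, so multi-input models receive aligned batches on each gRPC call. The following is a standalone sketch under stated assumptions, not patch code; the `send` callable is a hypothetical stand-in for send_grpc.

import numpy as np

def batched_predict(tiles, send, batch_size):
    # tiles: list of arrays sharing a batch dimension.
    results = []
    for t in range(0, tiles[0].shape[0], batch_size):
        # Slice every input with the same window to keep inputs aligned.
        inputs = [tile[t:t + batch_size] for tile in tiles]
        outputs = send(inputs)
        # Stack each output with its counterpart from earlier batches.
        results = outputs if not results else [
            np.vstack((r, o)) for r, o in zip(results, outputs)]
    return results

# Two aligned inputs of 10 samples, sent 4 at a time:
a = np.ones((10, 8, 8, 1))
b = np.zeros((10, 8, 8, 1))
out = batched_predict([a, b], send=lambda xs: xs, batch_size=4)
print([o.shape for o in out])  # [(10, 8, 8, 1), (10, 8, 8, 1)]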