diff --git a/.dockerignore b/.dockerignore index 3d12409d..b8fb0c1d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,6 +5,7 @@ logs/ protos/ docs/ build/ +.github/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/.env.example b/.env.example index 01185e24..3202d9b1 100644 --- a/.env.example +++ b/.env.example @@ -9,17 +9,10 @@ REDIS_HOST= TF_PORT= TF_HOST= -# Cloud selection -CLOUD_PROVIDER= - -# DEBUG Logging -DEBUG= +# Storage bucket +STORAGE_BUCKET= # AWS Credentials AWS_REGION= -AWS_S3_BUCKET= AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= - -# Google variables -GKE_BUCKET= diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 384cd99f..6c40af12 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,7 +12,7 @@ jobs: strategy: matrix: - python-version: [2.7, 3.5, 3.6, 3.7, 3.8] + python-version: [3.5, 3.6, 3.7, 3.8] steps: - uses: actions/checkout@v2 @@ -37,6 +37,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install --no-deps -r requirements-no-deps.txt pip install -r requirements-test.txt - name: PyTest diff --git a/Dockerfile b/Dockerfile index ed3cdc3d..9f81941c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,13 +23,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -FROM python:3.6 +FROM python:3.6-slim-buster WORKDIR /usr/src/app -COPY requirements.txt . +RUN apt-get update && apt-get install -y \ + build-essential libglib2.0-0 && \ + rm -rf /var/lib/apt/lists/* -RUN pip install --no-cache-dir -r requirements.txt +COPY requirements.txt requirements-no-deps.txt ./ + +RUN pip install --no-cache-dir -r requirements.txt && \ + pip install --no-cache-dir --no-deps -r requirements-no-deps.txt COPY . . diff --git a/README.md b/README.md index c89a726f..a3305a69 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,6 @@ Consumers consume Redis events. Each type of Redis event is put into a queue (e. Consumers call the `_consume` method to consume each item it finds in the queue. This method must be implemented for every consumer. - The quickest way to get a custom consumer up and running is to: 1. Add a new file for the consumer: `redis_consumer/consumers/my_new_consumer.py` @@ -38,28 +37,19 @@ def _consume(self, redis_hash): redis_hash, hvals.get('status')) return hvals.get('status') - # the data to process with the model, required. - input_file_name = hvals.get('input_file_name') + # Load input image + fname = hvals.get('input_file_name') + image = self.download_image(fname) # the model can be passed in as an environment variable, # and parsed in settings.py. - model_name, model_version = 'CustomModel:1'.split(':') - - with utils.get_tempdir() as tempdir: - # download the image file - fname = self.storage.download(input_file_name, tempdir) - # load image file as data - image = utils.get_image(fname) - - # pre- and post-processing can be used with the BaseConsumer.process, - # which uses pre-defined functions in settings.PROCESSING_FUNCTIONS. - image = self.preprocess(image, 'normalize') + model = 'NuclearSegmentation:1' - # send the data to the model - results = self.predict(image, model_name, model_version) + # Use a custom Application from deepcell.applications + app = self.get_grpc_app(model, deepcell.applications.NuclearSegmentation) - # post-process model results - image = self.postprocess(image, 'deep_watershed') + # Run the predictions on the image + results = app.predict(image) # save the results as an image file and upload it to the bucket save_name = hvals.get('original_name', fname) @@ -90,8 +80,7 @@ The consumer is configured using environment variables. Please find a table of a | :--- | :--- | :--- | | `QUEUE` | **REQUIRED**: The Redis job queue to check for items to consume. | `"predict"` | | `CONSUMER_TYPE` | **REQUIRED**: The type of consumer to run, used in `consume-redis-events.py`. | `"image"` | -| `CLOUD_PROVIDER` | **REQUIRED**: The cloud provider, one of `"aws"` and `"gke"`. | `"gke"` | -| `GCLOUD_STORAGE_BUCKET` | **REQUIRED**: The name of the storage bucket used to download and upload files. | `"default-bucket"` | +| `STORAGE_BUCKET` | **REQUIRED**: The name of the storage bucket used to download and upload files. | `"s3://default-bucket"` | | `INTERVAL` | How frequently the consumer checks the Redis queue for items, in seconds. | `5` | | `REDIS_HOST` | The IP address or hostname of Redis. | `"redis-master"` | | `REDIS_PORT` | The port used to connect to Redis. | `6379` | diff --git a/consume-redis-events.py b/consume-redis-events.py index 73c44b90..25d2d347 100644 --- a/consume-redis-events.py +++ b/consume-redis-events.py @@ -37,32 +37,33 @@ import sys import traceback +import decouple + import redis_consumer from redis_consumer import settings -def initialize_logger(debug_mode=True): +def initialize_logger(log_level='DEBUG'): + log_level = str(log_level).upper() logger = logging.getLogger() logger.setLevel(logging.DEBUG) - formatter = logging.Formatter('[%(asctime)s]:[%(levelname)s]:[%(name)s]: %(message)s') + formatter = logging.Formatter('[%(levelname)s]:[%(name)s]: %(message)s') console = logging.StreamHandler(stream=sys.stdout) console.setFormatter(formatter) - fh = logging.handlers.RotatingFileHandler( - filename='redis-consumer.log', - maxBytes=10000000, - backupCount=10) - fh.setFormatter(formatter) - - if debug_mode: - console.setLevel(logging.DEBUG) - else: + if log_level == 'CRITICAL': + console.setLevel(logging.CRITICAL) + elif log_level == 'ERROR': + console.setLevel(logging.ERROR) + elif log_level == 'WARN': + console.setLevel(logging.WARN) + elif log_level == 'INFO': console.setLevel(logging.INFO) - fh.setLevel(logging.DEBUG) + else: + console.setLevel(logging.DEBUG) logger.addHandler(console) - logger.addHandler(fh) def get_consumer(consumer_type, **kwargs): @@ -74,7 +75,7 @@ def get_consumer(consumer_type, **kwargs): if __name__ == '__main__': - initialize_logger(settings.DEBUG) + initialize_logger(decouple.config('LOG_LEVEL', default='DEBUG')) _logger = logging.getLogger(__file__) @@ -83,7 +84,7 @@ def get_consumer(consumer_type, **kwargs): port=settings.REDIS_PORT, backoff=settings.REDIS_TIMEOUT) - storage_client = redis_consumer.storage.get_client(settings.CLOUD_PROVIDER) + storage_client = redis_consumer.storage.get_client(settings.STORAGE_BUCKET) consumer_kwargs = { 'redis_client': redis, @@ -92,7 +93,6 @@ def get_consumer(consumer_type, **kwargs): 'final_status': 'done', 'failed_status': 'failed', 'name': settings.HOSTNAME, - 'output_dir': settings.OUTPUT_DIR, } _logger.debug('Getting `%s` consumer with args %s.', diff --git a/protos/attr_value.proto b/protos/attr_value.proto deleted file mode 100644 index 76944f77..00000000 --- a/protos/attr_value.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NameAttrList { - string name = 1; - map attr = 2; -} diff --git a/protos/function.proto b/protos/function.proto deleted file mode 100644 index 6d107635..00000000 --- a/protos/function.proto +++ /dev/null @@ -1,113 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// -// TODO(zhifengc): -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map attr = 5; - - // Attributes for function arguments. These attributes are the same set of - // valid attributes as to _Arg nodes. - message ArgAttrs { - map attr = 1; - } - map arg_attr = 7; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. - reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. - // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map ret = 4; - - // A mapping from control output names from `signature` to node names in - // `node_def` which should be control outputs of this function. - map control_ret = 6; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e. if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/protos/get_model_metadata.proto b/protos/get_model_metadata.proto deleted file mode 100644 index 60ddfd56..00000000 --- a/protos/get_model_metadata.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "google/protobuf/any.proto"; -import "meta_graph.proto"; -import "model.proto"; - -// Message returned for "signature_def" field. -message SignatureDefMap { - map signature_def = 1; -}; - -message GetModelMetadataRequest { - // Model Specification indicating which model we are querying for metadata. - // If version is not specified, will use the latest (numerical) version. - ModelSpec model_spec = 1; - // Metadata fields to get. Currently supported: "signature_def". - repeated string metadata_field = 2; -} - -message GetModelMetadataResponse { - // Model Specification indicating which model this metadata belongs to. - ModelSpec model_spec = 1; - // Map of metadata field name to metadata field. The options for metadata - // field name are listed in GetModelMetadataRequest. Currently supported: - // "signature_def". - map metadata = 2; -} diff --git a/protos/graph.proto b/protos/graph.proto deleted file mode 100644 index 14d9edfa..00000000 --- a/protos/graph.proto +++ /dev/null @@ -1,56 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/protos/meta_graph.proto b/protos/meta_graph.proto deleted file mode 100644 index f1005543..00000000 --- a/protos/meta_graph.proto +++ /dev/null @@ -1,342 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -import "google/protobuf/any.proto"; -import "graph.proto"; -import "op_def.proto"; -import "tensor_shape.proto"; -import "types.proto"; -import "saved_object_graph.proto"; -import "saver.proto"; -import "struct.proto"; - -option cc_enable_arenas = true; -option java_outer_classname = "MetaGraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; - -// NOTE: This protocol buffer is evolving, and will go through revisions in the -// coming months. -// -// Protocol buffer containing the following which are necessary to restart -// training, run inference. It can be used to serialize/de-serialize memory -// objects necessary for running computation in a graph when crossing the -// process boundary. It can be used for long term storage of graphs, -// cross-language execution of graphs, etc. -// MetaInfoDef -// GraphDef -// SaverDef -// CollectionDef -// TensorInfo -// SignatureDef -message MetaGraphDef { - // Meta information regarding the graph to be exported. To be used by users - // of this protocol buffer to encode information regarding their meta graph. - message MetaInfoDef { - // User specified Version string. Can be the name of the model and revision, - // steps this model has been trained to, etc. - string meta_graph_version = 1; - - // A copy of the OpDefs used by the producer of this graph_def. - // Descriptions and Ops not used in graph_def are stripped out. - OpList stripped_op_list = 2; - - // A serialized protobuf. Can be the time this meta graph is created, or - // modified, or name of the model. - google.protobuf.Any any_info = 3; - - // User supplied tag(s) on the meta_graph and included graph_def. - // - // MetaGraphDefs should be tagged with their capabilities or use-cases. - // Examples: "train", "serve", "gpu", "tpu", etc. - // These tags enable loaders to access the MetaGraph(s) appropriate for a - // specific use-case or runtime environment. - repeated string tags = 4; - - // The __version__ string of the tensorflow build used to write this graph. - // This will be populated by the framework, which will overwrite any user - // supplied value. - string tensorflow_version = 5; - - // The __git_version__ string of the tensorflow build used to write this - // graph. This will be populated by the framework, which will overwrite any - // user supplied value. - string tensorflow_git_version = 6; - - // A flag to denote whether default-valued attrs have been stripped from - // the nodes in this graph_def. - bool stripped_default_attrs = 7; - - // FunctionDef name to aliases mapping. - map function_aliases = 8; - } - MetaInfoDef meta_info_def = 1; - - // GraphDef. - GraphDef graph_def = 2; - - // SaverDef. - SaverDef saver_def = 3; - - // collection_def: Map from collection name to collections. - // See CollectionDef section for details. - map collection_def = 4; - - // signature_def: Map from user supplied key for a signature to a single - // SignatureDef. - map signature_def = 5; - - // Asset file def to be used with the defined graph. - repeated AssetFileDef asset_file_def = 6; - - // Extra information about the structure of functions and stateful objects. - SavedObjectGraph object_graph_def = 7; -} - -// CollectionDef should cover most collections. -// To add a user-defined collection, do one of the following: -// 1. For simple data types, such as string, int, float: -// tf.add_to_collection("your_collection_name", your_simple_value) -// strings will be stored as bytes_list. -// -// 2. For Protobuf types, there are three ways to add them: -// 1) tf.add_to_collection("your_collection_name", -// your_proto.SerializeToString()) -// -// collection_def { -// key: "user_defined_bytes_collection" -// value { -// bytes_list { -// value: "queue_name: \"test_queue\"\n" -// } -// } -// } -// -// or -// -// 2) tf.add_to_collection("your_collection_name", str(your_proto)) -// -// collection_def { -// key: "user_defined_string_collection" -// value { -// bytes_list { -// value: "\n\ntest_queue" -// } -// } -// } -// -// or -// -// 3) any_buf = any_pb2.Any() -// tf.add_to_collection("your_collection_name", -// any_buf.Pack(your_proto)) -// -// collection_def { -// key: "user_defined_any_collection" -// value { -// any_list { -// value { -// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" -// value: "\n\ntest_queue" -// } -// } -// } -// } -// -// 3. For Python objects, implement to_proto() and from_proto(), and register -// them in the following manner: -// ops.register_proto_function("your_collection_name", -// proto_type, -// to_proto=YourPythonObject.to_proto, -// from_proto=YourPythonObject.from_proto) -// These functions will be invoked to serialize and de-serialize the -// collection. For example, -// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, -// proto_type=variable_pb2.VariableDef, -// to_proto=Variable.to_proto, -// from_proto=Variable.from_proto) -message CollectionDef { - // NodeList is used for collecting nodes in graph. For example - // collection_def { - // key: "summaries" - // value { - // node_list { - // value: "input_producer/ScalarSummary:0" - // value: "shuffle_batch/ScalarSummary:0" - // value: "ImageSummary:0" - // } - // } - message NodeList { - repeated string value = 1; - } - - // BytesList is used for collecting strings and serialized protobufs. For - // example: - // collection_def { - // key: "trainable_variables" - // value { - // bytes_list { - // value: "\n\017conv1/weights:0\022\024conv1/weights/Assign - // \032\024conv1/weights/read:0" - // value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 - // \023conv1/biases/read:0" - // } - // } - // } - message BytesList { - repeated bytes value = 1; - } - - // Int64List is used for collecting int, int64 and long values. - message Int64List { - repeated int64 value = 1 [packed = true]; - } - - // FloatList is used for collecting float values. - message FloatList { - repeated float value = 1 [packed = true]; - } - - // AnyList is used for collecting Any protos. - message AnyList { - repeated google.protobuf.Any value = 1; - } - - oneof kind { - NodeList node_list = 1; - BytesList bytes_list = 2; - Int64List int64_list = 3; - FloatList float_list = 4; - AnyList any_list = 5; - } -} - -// Information about a Tensor necessary for feeding or retrieval. -message TensorInfo { - // For sparse tensors, The COO encoding stores a triple of values, indices, - // and shape. - message CooSparse { - // The shape of the values Tensor is [?]. Its dtype must be the dtype of - // the SparseTensor as a whole, given in the enclosing TensorInfo. - string values_tensor_name = 1; - - // The indices Tensor must have dtype int64 and shape [?, ?]. - string indices_tensor_name = 2; - - // The dynamic logical shape represented by the SparseTensor is recorded in - // the Tensor referenced here. It must have dtype int64 and shape [?]. - string dense_shape_tensor_name = 3; - } - - // Generic encoding for composite tensors. - message CompositeTensor { - // The serialized TypeSpec for the composite tensor. - TypeSpecProto type_spec = 1; - - // A TensorInfo for each flattened component tensor. - repeated TensorInfo components = 2; - } - - oneof encoding { - // For dense `Tensor`s, the name of the tensor in the graph. - string name = 1; - // There are many possible encodings of sparse matrices - // (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow - // uses only the COO encoding. This is supported and documented in the - // SparseTensor Python class. - CooSparse coo_sparse = 4; - // Generic encoding for CompositeTensors. - CompositeTensor composite_tensor = 5; - } - DataType dtype = 2; - // The static shape should be recorded here, to the extent that it can - // be known in advance. In the case of a SparseTensor, this field describes - // the logical shape of the represented tensor (aka dense_shape). - TensorShapeProto tensor_shape = 3; -} - -// SignatureDef defines the signature of a computation supported by a TensorFlow -// graph. -// -// For example, a model with two loss computations, sharing a single input, -// might have the following signature_def map. -// -// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, -// output key, and method_name are identical, and will be used by system(s) that -// implement or rely upon this particular loss method. The output tensor names -// differ, demonstrating how different outputs can exist for the same method. -// -// signature_def { -// key: "loss_A" -// value { -// inputs { -// key: "input" -// value { -// name: "input:0" -// dtype: DT_STRING -// tensor_shape: ... -// } -// } -// outputs { -// key: "loss_output" -// value { -// name: "loss_output_A:0" -// dtype: DT_FLOAT -// tensor_shape: ... -// } -// } -// } -// ... -// method_name: "some/package/compute_loss" -// } -// signature_def { -// key: "loss_B" -// value { -// inputs { -// key: "input" -// value { -// name: "input:0" -// dtype: DT_STRING -// tensor_shape: ... -// } -// } -// outputs { -// key: "loss_output" -// value { -// name: "loss_output_B:0" -// dtype: DT_FLOAT -// tensor_shape: ... -// } -// } -// } -// ... -// method_name: "some/package/compute_loss" -// } -message SignatureDef { - // Named input parameters. - map inputs = 1; - // Named output parameters. - map outputs = 2; - // Extensible method_name information enabling third-party users to mark a - // SignatureDef as supporting a particular method. This enables producers and - // consumers of SignatureDefs, e.g. a model definition library and a serving - // library to have a clear hand-off regarding the semantics of a computation. - // - // Note that multiple SignatureDefs in a single MetaGraphDef may have the same - // method_name. This is commonly used to support multi-headed computation, - // where a single graph computation may return multiple results. - string method_name = 3; -} - -// An asset file def for a single file or a set of sharded files with the same -// name. -message AssetFileDef { - // The tensor to bind the asset filename to. - TensorInfo tensor_info = 1; - // The filename within an assets directory. Note: does not include the path - // prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename - // would be "vocab.txt". - string filename = 2; -} diff --git a/protos/model.proto b/protos/model.proto deleted file mode 100644 index 56493f68..00000000 --- a/protos/model.proto +++ /dev/null @@ -1,33 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "google/protobuf/wrappers.proto"; - -// Metadata for an inference request such as the model name and version. -message ModelSpec { - // Required servable name. - string name = 1; - - // Optional choice of which version of the model to use. - // - // Recommended to be left unset in the common case. Should be specified only - // when there is a strong version consistency requirement. - // - // When left unspecified, the system will serve the best available version. - // This is typically the latest version, though during version transitions, - // notably when serving on a fleet of instances, may be either the previous or - // new version. - oneof version_choice { - // Use this specific version number. - google.protobuf.Int64Value version = 2; - - // Use the version associated with the given label. - string version_label = 4; - } - - // A named signature to evaluate. If unspecified, the default signature will - // be used. - string signature_name = 3; -} diff --git a/protos/node_def.proto b/protos/node_def.proto deleted file mode 100644 index 1e0da16f..00000000 --- a/protos/node_def.proto +++ /dev/null @@ -1,86 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. - // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // TODO(josh11b): Add some examples here showing best practices. - map attr = 5; - - message ExperimentalDebugInfo { - // Opaque string inserted into error messages created by the runtime. - // - // This is intended to store the list of names of the nodes from the - // original graph that this node was derived. For example if this node, say - // C, was result of a fusion of 2 nodes A and B, then 'original_node' would - // be {A, B}. This information can be used to map errors originating at the - // current node to some top level source code. - repeated string original_node_names = 1; - - // This is intended to store the list of names of the functions from the - // original graph that this node was derived. For example if this node, say - // C, was result of a fusion of node A in function FA and node B in function - // FB, then `original_funcs` would be {FA, FB}. If the node is in the top - // level graph, the `original_func` is empty. This information, with the - // `original_node_names` can be used to map errors originating at the - // current ndoe to some top level source code. - repeated string original_func_names = 2; - }; - - // This stores debug information associated with the node. - ExperimentalDebugInfo experimental_debug_info = 6; -}; diff --git a/protos/op_def.proto b/protos/op_def.proto deleted file mode 100644 index 9f5f6bf8..00000000 --- a/protos/op_def.proto +++ /dev/null @@ -1,170 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of a OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9>_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)". - DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Named control outputs for this operation. Useful only for composite - // operations (i.e. functions) which want to name different control outputs. - repeated string control_output = 20; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - // TODO(josh11b): bool is_optional? - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - // TODO(josh11b): Implement that optimization. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/protos/predict.proto b/protos/predict.proto deleted file mode 100644 index f4a83f8c..00000000 --- a/protos/predict.proto +++ /dev/null @@ -1,40 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "tensor.proto"; -import "model.proto"; - -// PredictRequest specifies which TensorFlow model to run, as well as -// how inputs are mapped to tensors and how outputs are filtered before -// returning to user. -message PredictRequest { - // Model Specification. If version is not specified, will use the latest - // (numerical) version. - ModelSpec model_spec = 1; - - // Input tensors. - // Names of input tensor are alias names. The mapping from aliases to real - // input tensor names is stored in the SavedModel export as a prediction - // SignatureDef under the 'inputs' field. - map inputs = 2; - - // Output filter. - // Names specified are alias names. The mapping from aliases to real output - // tensor names is stored in the SavedModel export as a prediction - // SignatureDef under the 'outputs' field. - // Only tensors specified here will be run/fetched and returned, with the - // exception that when none is specified, all tensors specified in the - // named signature will be run/fetched and returned. - repeated string output_filter = 3; -} - -// Response for PredictRequest on successful run. -message PredictResponse { - // Effective Model Specification used to process PredictRequest. - ModelSpec model_spec = 2; - - // Output tensors. - map outputs = 1; -} diff --git a/protos/prediction_service.proto b/protos/prediction_service.proto deleted file mode 100644 index 681796a6..00000000 --- a/protos/prediction_service.proto +++ /dev/null @@ -1,31 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -// import "tensorflow_serving/apis/classification.proto"; -// import "tensorflow_serving/apis/inference.proto"; -// import "tensorflow_serving/apis/regression.proto"; -import "get_model_metadata.proto"; -import "predict.proto"; - -// open source marker; do not remove -// PredictionService provides access to machine-learned models loaded by -// model_servers. -service PredictionService { - // Classify. - // rpc Classify(ClassificationRequest) returns (ClassificationResponse); - - // Regress. - // rpc Regress(RegressionRequest) returns (RegressionResponse); - - // Predict -- provides access to loaded TensorFlow model. - rpc Predict(PredictRequest) returns (PredictResponse); - - // MultiInference API for multi-headed models. - // rpc MultiInference(MultiInferenceRequest) returns (MultiInferenceResponse); - - // GetModelMetadata - provides access to metadata for loaded models. - rpc GetModelMetadata(GetModelMetadataRequest) - returns (GetModelMetadataResponse); -} diff --git a/protos/resource_handle.proto b/protos/resource_handle.proto deleted file mode 100644 index 82194668..00000000 --- a/protos/resource_handle.proto +++ /dev/null @@ -1,42 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. - string maybe_type_name = 5; - - // Protocol buffer representing a pair of (data type, tensor shape). - message DtypeAndShape { - DataType dtype = 1; - TensorShapeProto shape = 2; - } - - // Data types and shapes for the underlying resource. - repeated DtypeAndShape dtypes_and_shapes = 6; -}; diff --git a/protos/saved_object_graph.proto b/protos/saved_object_graph.proto deleted file mode 100644 index 0bad3a57..00000000 --- a/protos/saved_object_graph.proto +++ /dev/null @@ -1,164 +0,0 @@ -syntax = "proto3"; - -import "trackable_object_graph.proto"; -import "struct.proto"; -import "tensor_shape.proto"; -import "types.proto"; -import "versions.proto"; -import "variable.proto"; - -option cc_enable_arenas = true; - -package tensorflow; - -// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It -// describes the directed graph of Python objects (or equivalent in other -// languages) that make up a model, with nodes[0] at the root. - -// SavedObjectGraph shares some structure with TrackableObjectGraph, but -// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions -// and type information, while TrackableObjectGraph lives in the checkpoint -// and contains pointers only to variable values. - -message SavedObjectGraph { - // Flattened list of objects in the object graph. - // - // The position of the object in this list indicates its id. - // Nodes[0] is considered the root node. - repeated SavedObject nodes = 1; - - // Information about captures and output structures in concrete functions. - // Referenced from SavedBareConcreteFunction and SavedFunction. - map concrete_functions = 2; -} - -message SavedObject { - // Objects which this object depends on: named edges in the dependency - // graph. - // - // Note: currently only valid if kind == "user_object". - repeated TrackableObjectGraph.TrackableObject.ObjectReference - children = 1; - - // Removed when forking SavedObject from TrackableObjectGraph. - reserved "attributes"; - reserved 2; - - // Slot variables owned by this object. This describes the three-way - // (optimizer, variable, slot variable) relationship; none of the three - // depend on the others directly. - // - // Note: currently only valid if kind == "user_object". - repeated TrackableObjectGraph.TrackableObject.SlotVariableReference - slot_variables = 3; - - oneof kind { - SavedUserObject user_object = 4; - SavedAsset asset = 5; - SavedFunction function = 6; - SavedVariable variable = 7; - SavedBareConcreteFunction bare_concrete_function = 8; - SavedConstant constant = 9; - SavedResource resource = 10; - } -} - -// A SavedUserObject is an object (in the object-oriented language of the -// TensorFlow program) of some user- or framework-defined class other than -// those handled specifically by the other kinds of SavedObjects. -// -// This object cannot be evaluated as a tensor, and therefore cannot be bound -// to an input of a function. -message SavedUserObject { - // Corresponds to a registration of the type to use in the loading program. - string identifier = 1; - // Version information from the producer of this SavedUserObject. - VersionDef version = 2; - // Initialization-related metadata. - string metadata = 3; -} - -// A SavedAsset points to an asset in the MetaGraph. -// -// When bound to a function this object evaluates to a tensor with the absolute -// filename. Users should not depend on a particular part of the filename to -// remain stable (e.g. basename could be changed). -message SavedAsset { - // Index into `MetaGraphDef.asset_file_def[]` that describes the Asset. - // - // Only the field `AssetFileDef.filename` is used. Other fields, such as - // `AssetFileDef.tensor_info`, MUST be ignored. - int32 asset_file_def_index = 1; -} - -// A function with multiple signatures, possibly with non-Tensor arguments. -message SavedFunction { - repeated string concrete_functions = 1; - FunctionSpec function_spec = 2; -} - -// Stores low-level information about a concrete function. Referenced in either -// a SavedFunction or a SavedBareConcreteFunction. -message SavedConcreteFunction { - // Bound inputs to the function. The SavedObjects identified by the node ids - // given here are appended as extra inputs to the caller-supplied inputs. - // The only types of SavedObjects valid here are SavedVariable, SavedResource - // and SavedAsset. - repeated int32 bound_inputs = 2; - // Input in canonicalized form that was received to create this concrete - // function. - StructuredValue canonicalized_input_signature = 3; - // Output that was the return value of this function after replacing all - // Tensors with TensorSpecs. This can be an arbitrary nested function and will - // be used to reconstruct the full structure from pure tensors. - StructuredValue output_signature = 4; -} - -message SavedBareConcreteFunction { - // Identifies a SavedConcreteFunction. - string concrete_function_name = 1; - - // A sequence of unique strings, one per Tensor argument. - repeated string argument_keywords = 2; - // The prefix of `argument_keywords` which may be identified by position. - int64 allowed_positional_arguments = 3; -} - -message SavedConstant { - // An Operation name for a ConstantOp in this SavedObjectGraph's MetaGraph. - string operation = 1; -} - -// Represents a Variable that is initialized by loading the contents from the -// checkpoint. -message SavedVariable { - DataType dtype = 1; - TensorShapeProto shape = 2; - bool trainable = 3; - VariableSynchronization synchronization = 4; - VariableAggregation aggregation = 5; - string name = 6; -} - -// Represents `FunctionSpec` used in `Function`. This represents a -// function that has been wrapped as a TensorFlow `Function`. -message FunctionSpec { - // Full arg spec from inspect.getfullargspec(). - StructuredValue fullargspec = 1; - // Whether this represents a class method. - bool is_method = 2; - // The input signature, if specified. - StructuredValue input_signature = 5; - - reserved 3, 4; -} - -// A SavedResource represents a TF object that holds state during its lifetime. -// An object of this type can have a reference to a: -// create_resource() and an initialize() function. -message SavedResource { - // A device specification indicating a required placement for the resource - // creation function, e.g. "CPU". An empty string allows the user to select a - // device. - string device = 1; -} diff --git a/protos/saver.proto b/protos/saver.proto deleted file mode 100644 index 42453861..00000000 --- a/protos/saver.proto +++ /dev/null @@ -1,47 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "SaverProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.util"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf"; - -// Protocol buffer representing the configuration of a Saver. -message SaverDef { - // The name of the tensor in which to specify the filename when saving or - // restoring a model checkpoint. - string filename_tensor_name = 1; - - // The operation to run when saving a model checkpoint. - string save_tensor_name = 2; - - // The operation to run when restoring a model checkpoint. - string restore_op_name = 3; - - // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. - int32 max_to_keep = 4; - - // Shard the save files, one per device that has Variable nodes. - bool sharded = 5; - - // How often to keep an additional checkpoint. If not specified, only the last - // "max_to_keep" checkpoints are kept; if specified, in addition to keeping - // the last "max_to_keep" checkpoints, an additional checkpoint will be kept - // for every n hours of training. - float keep_checkpoint_every_n_hours = 6; - - // A version number that identifies a different on-disk checkpoint format. - // Usually, each subclass of BaseSaverBuilder works with a particular - // version/format. However, it is possible that the same builder may be - // upgraded to support a newer checkpoint format in the future. - enum CheckpointFormatVersion { - // Internal legacy format. - LEGACY = 0; - // Deprecated format: tf.Saver() which works with tensorflow::table::Table. - V1 = 1; - // Current format: more efficient. - V2 = 2; - } - CheckpointFormatVersion version = 7; -} diff --git a/protos/struct.proto b/protos/struct.proto deleted file mode 100644 index 7b590903..00000000 --- a/protos/struct.proto +++ /dev/null @@ -1,134 +0,0 @@ -syntax = "proto3"; - -import "tensor_shape.proto"; -import "types.proto"; - -package tensorflow; - -// `StructuredValue` represents a dynamically typed value representing various -// data structures that are inspired by Python data structures typically used in -// TensorFlow functions as inputs and outputs. -// -// For example when saving a Layer there may be a `training` argument. If the -// user passes a boolean True/False, that switches between two concrete -// TensorFlow functions. In order to switch between them in the same way after -// loading the SavedModel, we need to represent "True" and "False". -// -// A more advanced example might be a function which takes a list of -// dictionaries mapping from strings to Tensors. In order to map from -// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]` -// after load to the right saved TensorFlow function, we need to represent the -// nested structure and the strings, recording that we have a trace for anything -// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([], -// tf.float64)}]` as an example. -// -// Likewise functions may return nested structures of Tensors, for example -// returning a dictionary mapping from strings to Tensors. In order for the -// loaded function to return the same structure we need to serialize it. -// -// This is an ergonomic aid for working with loaded SavedModels, not a promise -// to serialize all possible function signatures. For example we do not expect -// to pickle generic Python objects, and ideally we'd stay language-agnostic. -message StructuredValue { - // The kind of value. - oneof kind { - // Represents None. - NoneValue none_value = 1; - - // Represents a double-precision floating-point value (a Python `float`). - double float64_value = 11; - // Represents a signed integer value, limited to 64 bits. - // Larger values from Python's arbitrary-precision integers are unsupported. - sint64 int64_value = 12; - // Represents a string of Unicode characters stored in a Python `str`. - // In Python 3, this is exactly what type `str` is. - // In Python 2, this is the UTF-8 encoding of the characters. - // For strings with ASCII characters only (as often used in TensorFlow code) - // there is effectively no difference between the language versions. - // The obsolescent `unicode` type of Python 2 is not supported here. - string string_value = 13; - // Represents a boolean value. - bool bool_value = 14; - - // Represents a TensorShape. - tensorflow.TensorShapeProto tensor_shape_value = 31; - // Represents an enum value for dtype. - tensorflow.DataType tensor_dtype_value = 32; - // Represents a value for tf.TensorSpec. - TensorSpecProto tensor_spec_value = 33; - // Represents a value for tf.TypeSpec. - TypeSpecProto type_spec_value = 34; - - // Represents a list of `Value`. - ListValue list_value = 51; - // Represents a tuple of `Value`. - TupleValue tuple_value = 52; - // Represents a dict `Value`. - DictValue dict_value = 53; - // Represents Python's namedtuple. - NamedTupleValue named_tuple_value = 54; - } -} - -// Represents None. -message NoneValue {} - -// Represents a Python list. -message ListValue { - repeated StructuredValue values = 1; -} - -// Represents a Python tuple. -message TupleValue { - repeated StructuredValue values = 1; -} - -// Represents a Python dict keyed by `str`. -// The comment on Unicode from Value.string_value applies analogously. -message DictValue { - map fields = 1; -} - -// Represents a (key, value) pair. -message PairValue { - string key = 1; - StructuredValue value = 2; -} - -// Represents Python's namedtuple. -message NamedTupleValue { - string name = 1; - repeated PairValue values = 2; -} - -// A protobuf to tf.TensorSpec. -message TensorSpecProto { - string name = 1; - tensorflow.TensorShapeProto shape = 2; - tensorflow.DataType dtype = 3; -} - -// Represents a tf.TypeSpec -message TypeSpecProto { - enum TypeSpecClass { - UNKNOWN = 0; - SPARSE_TENSOR_SPEC = 1; // tf.SparseTensorSpec - INDEXED_SLICES_SPEC = 2; // tf.IndexedSlicesSpec - RAGGED_TENSOR_SPEC = 3; // tf.RaggedTensorSpec - TENSOR_ARRAY_SPEC = 4; // tf.TensorArraySpec - DATA_DATASET_SPEC = 5; // tf.data.DatasetSpec - DATA_ITERATOR_SPEC = 6; // IteratorSpec from data/ops/iterator_ops.py - OPTIONAL_SPEC = 7; // tf.OptionalSpec - PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py - } - TypeSpecClass type_spec_class = 1; - - // The value returned by TypeSpec._serialize(). - StructuredValue type_state = 2; - - // This is currently redundant with the type_spec_class enum, and is only - // used for error reporting. In particular, if you use an older binary to - // load a newer model, and the model uses a TypeSpecClass that the older - // binary doesn't support, then this lets us display a useful error message. - string type_spec_class_name = 3; -} diff --git a/protos/tensor.proto b/protos/tensor.proto deleted file mode 100644 index 5d4d66ae..00000000 --- a/protos/tensor.proto +++ /dev/null @@ -1,94 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; -import "resource_handle.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. TODO(touts): sort out the 0-rank issues. - TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. - bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll - // have some pointless zero padding for each value here. - repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. - repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandleProto resource_handle_val = 14; - - // DT_VARIANT - repeated VariantTensorDataProto variant_val = 15; - - // DT_UINT32 - repeated uint32 uint32_val = 16 [packed = true]; - - // DT_UINT64 - repeated uint64 uint64_val = 17 [packed = true]; -}; - -// Protocol buffer representing the serialization format of DT_VARIANT tensors. -message VariantTensorDataProto { - // Name of the type of objects being serialized. - string type_name = 1; - // Portions of the object that are not Tensors. - bytes metadata = 2; - // Tensors contained within objects being serialized. - repeated TensorProto tensors = 3; -} diff --git a/protos/tensor_shape.proto b/protos/tensor_shape.proto deleted file mode 100644 index 286156a0..00000000 --- a/protos/tensor_shape.proto +++ /dev/null @@ -1,46 +0,0 @@ -// Protocol buffer representing the shape of tensors. - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -package tensorflow; - -// Dimensions of a tensor. -message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. - // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/protos/trackable_object_graph.proto b/protos/trackable_object_graph.proto deleted file mode 100644 index 02d852e6..00000000 --- a/protos/trackable_object_graph.proto +++ /dev/null @@ -1,59 +0,0 @@ -syntax = "proto3"; - -option cc_enable_arenas = true; - -package tensorflow; - -// A TensorBundle addition which saves extra information about the objects which -// own variables, allowing for more robust checkpoint loading into modified -// programs. - -message TrackableObjectGraph { - message TrackableObject { - message ObjectReference { - // An index into `TrackableObjectGraph.nodes`, indicating the object - // being referenced. - int32 node_id = 1; - // A user-provided name for the edge. - string local_name = 2; - } - - message SerializedTensor { - // A name for the Tensor. Simple variables have only one - // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may - // be restored on object creation as an optimization. - string name = 1; - // The full name of the variable/tensor, if applicable. Used to allow - // name-based loading of checkpoints which were saved using an - // object-based API. Should match the checkpoint key which would have been - // assigned by tf.train.Saver. - string full_name = 2; - // The generated name of the Tensor in the checkpoint. - string checkpoint_key = 3; - // Whether checkpoints should be considered as matching even without this - // value restored. Used for non-critical values which don't affect the - // TensorFlow graph, such as layer configurations. - bool optional_restore = 4; - } - - message SlotVariableReference { - // An index into `TrackableObjectGraph.nodes`, indicating the - // variable object this slot was created for. - int32 original_variable_node_id = 1; - // The name of the slot (e.g. "m"/"v"). - string slot_name = 2; - // An index into `TrackableObjectGraph.nodes`, indicating the - // `Object` with the value of the slot variable. - int32 slot_variable_node_id = 3; - } - - // Objects which this object depends on. - repeated ObjectReference children = 1; - // Serialized data specific to this object. - repeated SerializedTensor attributes = 2; - // Slot variables owned by this object. - repeated SlotVariableReference slot_variables = 3; - } - - repeated TrackableObject nodes = 1; -} diff --git a/protos/types.proto b/protos/types.proto deleted file mode 100644 index 5356f9f9..00000000 --- a/protos/types.proto +++ /dev/null @@ -1,76 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -// (== suppress_warning documentation-presence ==) -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. - DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/protos/variable.proto b/protos/variable.proto deleted file mode 100644 index b2978c75..00000000 --- a/protos/variable.proto +++ /dev/null @@ -1,85 +0,0 @@ -syntax = "proto3"; - -package tensorflow; - -option cc_enable_arenas = true; -option java_outer_classname = "VariableProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -// Indicates when a distributed variable will be synced. -enum VariableSynchronization { - // `AUTO`: Indicates that the synchronization will be determined by the - // current `DistributionStrategy` (eg. With `MirroredStrategy` this would be - // `ON_WRITE`). - VARIABLE_SYNCHRONIZATION_AUTO = 0; - // `NONE`: Indicates that there will only be one copy of the variable, so - // there is no need to sync. - VARIABLE_SYNCHRONIZATION_NONE = 1; - // `ON_WRITE`: Indicates that the variable will be updated across devices - // every time it is written. - VARIABLE_SYNCHRONIZATION_ON_WRITE = 2; - // `ON_READ`: Indicates that the variable will be aggregated across devices - // when it is read (eg. when checkpointing or when evaluating an op that uses - // the variable). - VARIABLE_SYNCHRONIZATION_ON_READ = 3; -} - -// Indicates how a distributed variable will be aggregated. -enum VariableAggregation { - // `NONE`: This is the default, giving an error if you use a - // variable-update operation with multiple replicas. - VARIABLE_AGGREGATION_NONE = 0; - // `SUM`: Add the updates across replicas. - VARIABLE_AGGREGATION_SUM = 1; - // `MEAN`: Take the arithmetic mean ("average") of the updates across - // replicas. - VARIABLE_AGGREGATION_MEAN = 2; - // `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same - // update, but we only want to perform the update once. Used, e.g., for the - // global step counter. - VARIABLE_AGGREGATION_ONLY_FIRST_REPLICA = 3; -} - -// Protocol buffer representing a Variable. -message VariableDef { - // Name of the variable tensor. - string variable_name = 1; - - // Name of the tensor holding the variable's initial value. - string initial_value_name = 6; - - // Name of the initializer op. - string initializer_name = 2; - - // Name of the snapshot tensor. - string snapshot_name = 3; - - // Support for saving variables as slices of a larger variable. - SaveSliceInfoDef save_slice_info_def = 4; - - // Whether to represent this as a ResourceVariable. - bool is_resource = 5; - - // Whether this variable should be trained. - bool trainable = 7; - - // Indicates when a distributed variable will be synced. - VariableSynchronization synchronization = 8; - - // Indicates how a distributed variable will be aggregated. - VariableAggregation aggregation = 9; -} - -message SaveSliceInfoDef { - // Name of the full variable of which this is a slice. - string full_name = 1; - // Shape of the full variable. - repeated int64 full_shape = 2; - // Offset of this variable into the full variable. - repeated int64 var_offset = 3; - // Shape of this variable. - repeated int64 var_shape = 4; -} diff --git a/protos/versions.proto b/protos/versions.proto deleted file mode 100644 index dd2ec552..00000000 --- a/protos/versions.proto +++ /dev/null @@ -1,32 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). - repeated int32 bad_consumers = 3; -}; diff --git a/redis_consumer/__init__.py b/redis_consumer/__init__.py index 40aaeee7..0ee2ddaa 100644 --- a/redis_consumer/__init__.py +++ b/redis_consumer/__init__.py @@ -29,11 +29,9 @@ from redis_consumer import consumers from redis_consumer import grpc_clients -from redis_consumer import pbs from redis_consumer import redis from redis_consumer import settings from redis_consumer import storage -from redis_consumer import tracking from redis_consumer import utils del absolute_import diff --git a/redis_consumer/consumers/__init__.py b/redis_consumer/consumers/__init__.py index 82d80aa4..ef981644 100644 --- a/redis_consumer/consumers/__init__.py +++ b/redis_consumer/consumers/__init__.py @@ -33,14 +33,15 @@ from redis_consumer.consumers.base_consumer import ZipFileConsumer # Custom Workflow consumers -from redis_consumer.consumers.image_consumer import ImageFileConsumer +from redis_consumer.consumers.segmentation_consumer import SegmentationConsumer from redis_consumer.consumers.tracking_consumer import TrackingConsumer from redis_consumer.consumers.multiplex_consumer import MultiplexConsumer # TODO: Import future custom Consumer classes. CONSUMERS = { - 'image': ImageFileConsumer, + 'image': SegmentationConsumer, # deprecated, use "segmentation" instead. + 'segmentation': SegmentationConsumer, 'zip': ZipFileConsumer, 'tracking': TrackingConsumer, 'multiplex': MultiplexConsumer, diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index e0878a70..a4a01316 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -33,6 +33,7 @@ import logging import os import sys +import tempfile import time import timeit import urllib @@ -42,9 +43,9 @@ import numpy as np import pytz -from deepcell_toolbox.utils import tile_image, untile_image, resize +from deepcell.applications import ScaleDetection -from redis_consumer.grpc_clients import PredictClient +from redis_consumer.grpc_clients import PredictClient, GrpcModelWrapper from redis_consumer import utils from redis_consumer import settings @@ -65,13 +66,11 @@ def __init__(self, queue, final_status='done', failed_status='failed', - name=settings.HOSTNAME, - output_dir=settings.OUTPUT_DIR): + name=settings.HOSTNAME): self.redis = redis_client self.storage = storage_client self.queue = str(queue).lower() self.name = name - self.output_dir = output_dir self.final_status = final_status self.failed_status = failed_status self.finished_statuses = {final_status, failed_status} @@ -148,7 +147,7 @@ def _handle_error(self, err, redis_hash): def is_valid_hash(self, redis_hash): # pylint: disable=unused-argument """Returns True if the consumer should work on the item""" - return True + return redis_hash is not None def get_current_timestamp(self): """Helper function, returns ISO formatted UTC timestamp""" @@ -206,22 +205,8 @@ def consume(self): status = self.failed_status if status == self.final_status: - required_fields = [ - 'model_name', - 'model_version', - 'preprocess_function', - 'postprocess_function', - ] - result = self.redis.hmget(redis_hash, *required_fields) - hvals = dict(zip(required_fields, result)) - self.logger.debug('Consumed key %s (model %s:%s, ' - 'preprocessing: %s, postprocessing: %s) ' - '(%s retries) in %s seconds.', - redis_hash, hvals.get('model_name'), - hvals.get('model_version'), - hvals.get('preprocess_function'), - hvals.get('postprocess_function'), - 0, timeit.default_timer() - start) + self.logger.debug('Consumed key `%s` in %s seconds.', + redis_hash, timeit.default_timer() - start) if status in self.finished_statuses: # this key is done. remove the key from the processing queue. @@ -247,14 +232,58 @@ def __init__(self, storage_client, queue, **kwargs): - # Create some attributes only used during consume() - self._redis_hash = None - self._redis_values = dict() super(TensorFlowServingConsumer, self).__init__( redis_client, storage_client, queue, **kwargs) - def _consume(self, redis_hash): - raise NotImplementedError + def is_valid_hash(self, redis_hash): + """Don't run on zip files""" + if redis_hash is None: + return False + + fname = str(self.redis.hget(redis_hash, 'input_file_name')) + return not fname.lower().endswith('.zip') + + def download_image(self, image_path): + """Download file from bucket and load it as an image""" + with tempfile.TemporaryDirectory() as tempdir: + fname = self.storage.download(image_path, tempdir) + image = utils.get_image(fname) + return image + + def validate_model_input(self, image, model_name, model_version): + """Validate that the input image meets the workflow requirements.""" + model_metadata = self.get_model_metadata(model_name, model_version) + parse_shape = lambda x: tuple(int(y) for y in x.split(',')) + shapes = [parse_shape(x['in_tensor_shape']) for x in model_metadata] + + # cast as image to match with the list of shapes. + image = [image] if not isinstance(image, list) else image + + errtext = ('Invalid image shape: {}. The {} job expects ' + 'images of shape{}').format( + [s.shape for s in image], self.queue, shapes) + + if len(image) != len(shapes): + raise ValueError(errtext) + + validated = [] + + for img, shape in zip(image, shapes): + rank = len(shape) # expects a batch dimension + channels = shape[-1] + + if len(img.shape) != rank: + raise ValueError(errtext) + + if img.shape[1] == channels: + img = np.rollaxis(img, 1, rank) + + if img.shape[rank - 1] != channels: + raise ValueError(errtext) + + validated.append(img) + + return validated[0] if len(validated) == 1 else validated def _get_predict_client(self, model_name, model_version): """Returns the TensorFlow Serving gRPC client. @@ -273,61 +302,6 @@ def _get_predict_client(self, model_name, model_version): timeit.default_timer() - t) return client - def grpc_image(self, img, model_name, model_version, model_shape, - in_tensor_name='image', in_tensor_dtype='DT_FLOAT'): - """Use the TensorFlow Serving gRPC API for model inference on an image. - - Args: - img (numpy.array): The image to send to the model - model_name (str): The name of the model - model_version (int): The version of the model - model_shape (tuple): The shape of input data for the model - in_tensor_name (str): The name of the input tensor for the request - in_tensor_dtype (str): The dtype of the input data - - Returns: - numpy.array: The results of model inference. - """ - in_tensor_dtype = str(in_tensor_dtype).upper() - - start = timeit.default_timer() - self.logger.debug('Segmenting image of shape %s with model %s:%s', - img.shape, model_name, model_version) - - if len(model_shape) == img.ndim + 1: - img = np.expand_dims(img, axis=0) - - if in_tensor_dtype == 'DT_HALF': - # TODO: seems like should cast to "half" - # but the model rejects the type, wants "int" or "long" - img = img.astype('int') - - req_data = [{'in_tensor_name': in_tensor_name, - 'in_tensor_dtype': in_tensor_dtype, - 'data': img}] - - client = self._get_predict_client(model_name, model_version) - - prediction = client.predict(req_data, settings.GRPC_TIMEOUT) - results = [prediction[k] for k in sorted(prediction.keys())] - - if len(results) == 1: - results = results[0] - - finished = timeit.default_timer() - start - if self._redis_hash is not None: - self.update_key(self._redis_hash, { - 'prediction_time': finished, - }) - self.logger.debug('Segmented key %s (model %s:%s, ' - 'preprocessing: %s, postprocessing: %s)' - ' (%s retries) in %s seconds.', - self._redis_hash, model_name, model_version, - self._redis_values.get('preprocess_function'), - self._redis_values.get('postprocess_function'), - 0, finished) - return results - def parse_model_metadata(self, metadata): """Parse the metadata response and return list of input metadata. @@ -379,8 +353,8 @@ def get_model_metadata(self, model_name, model_version): inputs = inputs['serving_default']['inputs'] finished = timeit.default_timer() - start - self.logger.debug('Got model metadata for %s in %s seconds.', - model, finished) + self.logger.info('Got model metadata for %s in %s seconds.', + model, finished) self.redis.hset(model, 'metadata', json.dumps(inputs)) self.redis.expire(model, settings.METADATA_EXPIRE_TIME) @@ -389,11 +363,43 @@ def get_model_metadata(self, model_name, model_version): self.logger.error('Malformed metadata: %s', model_metadata) raise err - def detect_scale(self, image): # pylint: disable=unused-argument - """Stub for scale detection""" - self.logger.debug('Scale was not given. Defaults to 1') - scale = 1 - return scale + def get_grpc_app(self, model, application_cls, **kwargs): + """ + Create an application from deepcell.applications + with a gRPC model wrapper as a model + """ + model_name, model_version = model.split(':') + model_metadata = self.get_model_metadata(model_name, model_version) + client = self._get_predict_client(model_name, model_version) + model_wrapper = GrpcModelWrapper(client, model_metadata) + return application_cls(model_wrapper, **kwargs) + + def detect_scale(self, image): + """Send the image to the SCALE_DETECT_MODEL to detect the relative + scale difference from the image to the model's training data. + + Args: + image (numpy.array): The image data. + + Returns: + scale (float): The detected scale, used to rescale data. + """ + start = timeit.default_timer() + + app = self.get_grpc_app(settings.SCALE_DETECT_MODEL, ScaleDetection) + + if not settings.SCALE_DETECT_ENABLED: + self.logger.debug('Scale detection disabled.') + return 1 # app.model_mpp + + batch_size = app.model.get_batch_size() + detected_scale = app.predict(image, batch_size=batch_size) + + self.logger.debug('Scale %s detected in %s seconds', + detected_scale, timeit.default_timer() - start) + + # detected_scale = detected_scale * app.model_mpp + return detected_scale def get_image_scale(self, scale, image, redis_hash): """Calculate scale of image and rescale""" @@ -412,326 +418,9 @@ def get_image_scale(self, scale, image, redis_hash): settings.MAX_SCALE)) return scale - def _predict_big_image(self, - image, - model_name, - model_version, - model_shape, - model_input_name='image', - model_dtype='DT_FLOAT', - untile=True, - stride_ratio=0.75): - """Use tile_image to tile image for the model and untile the results. - - Args: - image (numpy.array): image data as numpy. - model_name (str): hosted model to send image data. - model_version (str): model version to query. - model_shape (tuple): shape of the model's expected input. - model_dtype (str): dtype of the model's input array. - model_input_name (str): name of the model's input array. - untile (bool): untiles results back to image shape if True. - stride_ratio (float): amount to overlap between tiles, (0, 1]. - - Returns: - numpy.array: untiled results from the model. - """ - model_ndim = len(model_shape) - input_shape = (model_shape[model_ndim - 3], model_shape[model_ndim - 2]) - - ratio = (model_shape[model_ndim - 3] / settings.TF_MIN_MODEL_SIZE) * \ - (model_shape[model_ndim - 2] / settings.TF_MIN_MODEL_SIZE) * \ - (model_shape[model_ndim - 1]) - - batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) - - tiles, tiles_info = tile_image( - np.expand_dims(image, axis=0), - model_input_shape=input_shape, - stride_ratio=stride_ratio) - - self.logger.debug('Tiling image of shape %s into shape %s.', - image.shape, tiles.shape) - - # max_batch_size is 1 by default. - # dependent on the tf-serving configuration - results = [] - for t in range(0, tiles.shape[0], batch_size): - batch = tiles[t:t + batch_size] - output = self.grpc_image( - batch, model_name, model_version, model_shape, - in_tensor_name=model_input_name, - in_tensor_dtype=model_dtype) - - if not isinstance(output, list): - output = [output] - - if len(results) == 0: - results = output - else: - for i, o in enumerate(output): - results[i] = np.vstack((results[i], o)) - - if not untile: - image = results - else: - image = [untile_image(r, tiles_info, model_input_shape=input_shape) - for r in results] - - image = image[0] if len(image) == 1 else image - return image - - def _predict_small_image(self, - image, - model_name, - model_version, - model_shape, - model_input_name='image', - model_dtype='DT_FLOAT'): - """Pad an image that is too small for the model, and unpad the results. - - Args: - img (numpy.array): The too-small image to be predicted with - model_name and model_version. - model_name (str): hosted model to send image data. - model_version (str): model version to query. - model_shape (tuple): shape of the model's expected input. - model_input_name (str): name of the model's input array. - model_dtype (str): dtype of the model's input array. - - Returns: - numpy.array: unpadded results from the model. - """ - pad_width = [] - model_ndim = len(model_shape) - for i in range(image.ndim): - if i in {image.ndim - 3, image.ndim - 2}: - if i == image.ndim - 3: - diff = model_shape[model_ndim - 3] - image.shape[i] - else: - diff = model_shape[model_ndim - 2] - image.shape[i] - - if diff > 0: - if diff % 2: - pad_width.append((diff // 2, diff // 2 + 1)) - else: - pad_width.append((diff // 2, diff // 2)) - else: - pad_width.append((0, 0)) - else: - pad_width.append((0, 0)) - - self.logger.info('Padding image from shape %s to shape %s.', - image.shape, tuple([x + y1 + y2 for x, (y1, y2) in - zip(image.shape, pad_width)])) - - padded_img = np.pad(image, pad_width, 'reflect') - image = self.grpc_image(padded_img, model_name, model_version, - model_shape, in_tensor_name=model_input_name, - in_tensor_dtype=model_dtype) - - image = [image] if not isinstance(image, list) else image - - # pad batch_size and frames for each output. - pad_widths = [pad_width] * len(image) - for i, im in enumerate(image): - while len(pad_widths[i]) < im.ndim: - pad_widths[i].insert(0, (0, 0)) - - # unpad results - image = [utils.unpad_image(i, p) for i, p in zip(image, pad_widths)] - image = image[0] if len(image) == 1 else image - - return image - - def predict(self, image, model_name, model_version, untile=True): - """Performs model inference on the image data. - - Args: - image (numpy.array): the image data - model_name (str): hosted model to send image data. - model_version (int): model version to query. - untile (bool): Whether to untile the tiled inference results. This - should be True when the model output is the same shape as the - input, and False otherwise. - """ - start = timeit.default_timer() - model_metadata = self.get_model_metadata(model_name, model_version) - - # TODO: generalize for more than a single input. - if len(model_metadata) > 1: - raise ValueError('Model {}:{} has {} required inputs but was only ' - 'given {} inputs.'.format( - model_name, model_version, - len(model_metadata), len(image))) - model_metadata = model_metadata[0] - - model_input_name = model_metadata['in_tensor_name'] - model_dtype = model_metadata['in_tensor_dtype'] - - model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] - model_ndim = len(model_shape) - - image = image[0] if image.shape[0] == 1 else image - - if model_ndim != image.ndim + 1: - raise ValueError('Image of shape {} is incompatible with model ' - '{}:{} with input shape {}'.format( - image.shape, model_name, model_version, - tuple(model_shape))) - - if (image.shape[-2] > settings.MAX_IMAGE_WIDTH or - image.shape[-3] > settings.MAX_IMAGE_HEIGHT): - raise ValueError( - 'The image is too large! Rescaled images have a maximum size ' - 'of ({}, {}) but found size {}.'.format( - settings.MAX_IMAGE_WIDTH, - settings.MAX_IMAGE_HEIGHT, - image.shape)) - - size_x = model_shape[model_ndim - 3] - size_y = model_shape[model_ndim - 2] - - size_x = image.shape[image.ndim - 3] if size_x <= 0 else size_x - size_y = image.shape[image.ndim - 2] if size_y <= 0 else size_y - - self.logger.debug('Calling predict on model %s:%s with input shape %s' - ' and dtype %s to segment an image of shape %s.', - model_name, model_version, tuple(model_shape), - model_dtype, image.shape) - - if (image.shape[image.ndim - 3] < size_x or - image.shape[image.ndim - 2] < size_y): - # image is too small for the model, pad the image. - image = self._predict_small_image(image, model_name, model_version, - model_shape, model_input_name, - model_dtype) - elif (image.shape[image.ndim - 3] > size_x or - image.shape[image.ndim - 2] > size_y): - # image is too big for the model, multiple images are tiled. - image = self._predict_big_image(image, model_name, model_version, - model_shape, model_input_name, - model_dtype, untile=untile) - else: - # image size is perfect, just send it to the model - image = self.grpc_image(image, model_name, model_version, - model_shape, model_input_name, model_dtype) - - if isinstance(image, list): - output_shapes = [i.shape for i in image] - else: - output_shapes = [image.shape] # cast as list - - self.logger.debug('Got response from model %s:%s of shape %s in %s ' - 'seconds.', model_name, model_version, output_shapes, - timeit.default_timer() - start) - - return image - - def _get_processing_function(self, process_type, function_name): - """Based on the function category and name, return the function. - - Args: - process_type (str): "pre" or "post" processing - function_name (str): Name processing function, must exist in - settings.PROCESSING_FUNCTIONS. - - Returns: - function: the selected pre- or post-processing function. - """ - clean = lambda x: str(x).lower() - # first, verify the route parameters - name = clean(function_name) - cat = clean(process_type) - if cat not in settings.PROCESSING_FUNCTIONS: - raise ValueError('Processing functions are either "pre" or "post" ' - 'processing. Got %s.' % cat) - - if name not in settings.PROCESSING_FUNCTIONS[cat]: - raise ValueError('"%s" is not a valid %s-processing function' - % (name, cat)) - return settings.PROCESSING_FUNCTIONS[cat][name] - - def process(self, image, key, process_type): - """Apply the pre- or post-processing function to the image data. - - Args: - image (numpy.array): The image data to process. - key (str): The name of the function to use. - process_type (str): "pre" or "post" processing. - - Returns: - numpy.array: The processed image data. - """ - start = timeit.default_timer() - if not key: - return image - - f = self._get_processing_function(process_type, key) - - if key == 'retinanet-semantic': - # image[:-1] is targeted at a two semantic head panoptic model - # TODO This may need to be modified and generalized in the future - results = f(image[:-1]) - elif key == 'retinanet': - results = f(image, self._rawshape[0], self._rawshape[1]) - else: - results = f(image) - - if not isinstance(results, list) and results.shape[0] == 1: - results = np.squeeze(results, axis=0) - - finished = timeit.default_timer() - start - - self.update_key(self._redis_hash, { - '{}process_time'.format(process_type): finished - }) - - self.logger.debug('%s-processed key %s (model %s:%s, preprocessing: %s,' - ' postprocessing: %s) in %s seconds.', - process_type.capitalize(), self._redis_hash, - self._redis_values.get('model_name'), - self._redis_values.get('model_version'), - self._redis_values.get('preprocess_function'), - self._redis_values.get('postprocess_function'), - finished) - - return results - - def preprocess(self, image, keys): - """Wrapper for _process_image but can only call with type="pre". - - Args: - image (numpy.array): image data - keys (list): list of function names to apply to the image - - Returns: - numpy.array: pre-processed image data - """ - pre = None - for key in keys: - x = pre if pre else image - pre = self.process(x, key, 'pre') - return pre - - def postprocess(self, image, keys): - """Wrapper for _process_image but can only call with type="post". - - Args: - image (numpy.array): image data - keys (list): list of function names to apply to the image - - Returns: - numpy.array: post-processed image data - """ - post = None - for key in keys: - x = post if post else image - post = self.process(x, key, 'post') - return post - - def save_output(self, image, redis_hash, save_name, output_shape=None): - with utils.get_tempdir() as tempdir: + def save_output(self, image, save_name): + """Save output images into a zip file and upload it.""" + with tempfile.TemporaryDirectory() as tempdir: # Save each result channel as an image file subdir = os.path.dirname(save_name.replace(tempdir, '')) name = os.path.splitext(os.path.basename(save_name))[0] @@ -739,21 +428,10 @@ def save_output(self, image, redis_hash, save_name, output_shape=None): if not isinstance(image, list): image = [image] - # Rescale image to original size before sending back to user outpaths = [] - added_batch = False - for i, im in enumerate(image): - if output_shape: - if im.shape[0] != 1: - added_batch = True - im = np.expand_dims(im, axis=0) # add batch for resize - self.logger.info('Resizing image of shape %s to %s', im.shape, output_shape) - im = resize(im, output_shape, labeled_image=True) - if added_batch: - im = im[0] - + for img in image: outpaths.extend(utils.save_numpy_array( - im, + img, name=str(name), subdir=subdir, output_dir=tempdir)) @@ -762,7 +440,7 @@ def save_output(self, image, redis_hash, save_name, output_shape=None): # Upload the zip file to cloud storage bucket cleaned = zip_file.replace(tempdir, '') - subdir = os.path.dirname(settings._strip(cleaned)) + subdir = os.path.dirname(utils.strip_bucket_path(cleaned)) subdir = subdir if subdir else None dest, output_url = self.storage.upload(zip_file, subdir=subdir) @@ -794,12 +472,12 @@ def _upload_archived_images(self, hvalues, redis_hash): """Extract all image files and upload them to storage and redis""" all_hashes = set() archive_uuid = uuid.uuid4().hex - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: fname = self.storage.download(hvalues.get('input_file_name'), tempdir) image_files = utils.get_image_files_from_dir(fname, tempdir) for i, imfile in enumerate(image_files): - clean_imfile = settings._strip(imfile.replace(tempdir, '')) + clean_imfile = utils.strip_bucket_path(imfile.replace(tempdir, '')) # Save each result channel as an image file subdir = os.path.join(archive_uuid, os.path.dirname(clean_imfile)) dest, _ = self.storage.upload(imfile, subdir=subdir) @@ -870,7 +548,7 @@ def _get_output_file_name(self, key): def _upload_finished_children(self, finished_children, redis_hash): # saved_files = set() - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: filename = '{}.zip'.format(uuid.uuid4().hex) zip_path = os.path.join(tempdir, filename) @@ -930,13 +608,7 @@ def _parse_failures(self, failed_children): len(failed_hashes), json.dumps(failed_hashes, indent=4)) - # check python2 vs python3 - if hasattr(urllib, 'parse'): - url_encode = urllib.parse.urlencode # pylint: disable=E1101 - else: - url_encode = urllib.urlencode # pylint: disable=E1101 - - return url_encode(failed_hashes) + return urllib.parse.urlencode(failed_hashes) def _cleanup(self, redis_hash, children, done, failed): start = timeit.default_timer() diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index a88589c3..7a8c6476 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -28,8 +28,9 @@ from __future__ import division from __future__ import print_function -import itertools import json +import os +import random import time import numpy as np @@ -37,24 +38,29 @@ import pytest from redis_consumer import consumers +from redis_consumer.grpc_clients import GrpcModelWrapper from redis_consumer import settings -from redis_consumer.testing_utils import Bunch, DummyStorage, redis_client +from redis_consumer.testing_utils import _get_image +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client # pylint: disable=unused-import +from redis_consumer.testing_utils import make_model_metadata_of_size class TestConsumer(object): # pylint: disable=R0201,W0621 def test_get_redis_hash(self, mocker, redis_client): mocker.patch.object(settings, 'EMPTY_QUEUE_TIMEOUT', 0.01) - queue_name = 'q' - consumer = consumers.Consumer(redis_client, None, queue_name) + queue = 'q' + consumer = consumers.Consumer(redis_client, None, queue) # test emtpy queue assert consumer.get_redis_hash() is None # test that invalid items are not processed and are removed. item = 'item to process' - redis_client.lpush(queue_name, item) + redis_client.lpush(queue, item) # is_valid_hash returns True by default assert consumer.get_redis_hash() == item assert redis_client.llen(consumer.processing_queue) == 1 @@ -64,7 +70,7 @@ def test_get_redis_hash(self, mocker, redis_client): # test invalid hash is failed and removed from queue mocker.patch.object(consumer, 'is_valid_hash', return_value=False) - redis_client.lpush(queue_name, 'invalid') + redis_client.lpush(queue, 'invalid') assert consumer.get_redis_hash() is None # invalid hash, returns None # invalid has was removed from the processing queue assert redis_client.llen(consumer.processing_queue) == 0 @@ -76,9 +82,9 @@ def test_get_redis_hash(self, mocker, redis_client): assert consumer.get_redis_hash() is None def test_purge_processing_queue(self, redis_client): - queue_name = 'q' + queue = 'q' keys = ['abc', 'def', 'xyz'] - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) # set keys in processing queue for key in keys: redis_client.lpush(consumer.processing_queue, key) @@ -116,16 +122,16 @@ def test_handle_error(self, redis_client): assert redis_values.get('status') == 'failed' def test__put_back_hash(self, redis_client): - queue_name = 'q' + queue = 'q' # test emtpy queue - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) consumer._put_back_hash('DNE') # should be None, shows warning # put back the proper item item = 'redis_hash1' redis_client.lpush(consumer.processing_queue, item) - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) consumer._put_back_hash(item) assert redis_client.llen(consumer.processing_queue) == 0 assert redis_client.llen(consumer.queue) == 1 @@ -134,7 +140,7 @@ def test__put_back_hash(self, redis_client): # put back the wrong item other = 'otherhash' redis_client.lpush(consumer.processing_queue, other, item) - consumer = consumers.Consumer(redis_client, None, queue_name) + consumer = consumers.Consumer(redis_client, None, queue) consumer._put_back_hash(item) assert redis_client.llen(consumer.processing_queue) == 0 assert redis_client.llen(consumer.queue) == 2 @@ -143,12 +149,12 @@ def test__put_back_hash(self, redis_client): def test_consume(self, mocker, redis_client): mocker.patch.object(settings, 'EMPTY_QUEUE_TIMEOUT', 0) - queue_name = 'q' + queue = 'q' keys = [str(x) for x in range(1, 10)] err = OSError('thrown on purpose') i = 0 - consumer = consumers.Consumer(redis_client, DummyStorage(), queue_name) + consumer = consumers.Consumer(redis_client, DummyStorage(), queue) def throw_error(*_, **__): raise err @@ -208,41 +214,102 @@ def test__consume(self): class TestTensorFlowServingConsumer(object): # pylint: disable=R0201,W0613,W0621 - def test__get_predict_client(self, redis_client): - stg = DummyStorage() - consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') - with pytest.raises(ValueError): - consumer._get_predict_client('model_name', 'model_version') + def test_is_valid_hash(self, mocker, redis_client): + storage = DummyStorage() + mocker.patch.object(redis_client, 'hget', lambda x, y: x.split(':')[-1]) - consumer._get_predict_client('model_name', 1) + consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'predict') - def test_grpc_image(self, mocker, redis_client): + assert consumer.is_valid_hash(None) is False + assert consumer.is_valid_hash('file.ZIp') is False + assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False + assert consumer.is_valid_hash('track:123456789:file.zip') is False + assert consumer.is_valid_hash('predict:123456789:file.zip') is False + assert consumer.is_valid_hash('predict:1234567890:file.tiff') is True + assert consumer.is_valid_hash('predict:1234567890:file.png') is True + + def test_download_image(self, redis_client): storage = DummyStorage() - queue = 'q' + consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q') - consumer = consumers.TensorFlowServingConsumer( - redis_client, storage, queue) + image = consumer.download_image('test.tif') + assert isinstance(image, np.ndarray) + assert not os.path.exists('test.tif') - model_shape = (-1, 128, 128, 1) + def test_validate_model_input(self, mocker, redis_client): + storage = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'q') - def _get_predict_client(model_name, model_version): - return Bunch(predict=lambda x, y: { - 'prediction': x[0]['data'] - }) + model_input_shape = (-1, 32, 32, 1) - mocker.patch.object(consumer, '_get_predict_client', _get_predict_client) + mocked_metadata = make_model_metadata_of_size(model_input_shape) + mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata) - img = np.zeros((1, 32, 32, 3)) - out = consumer.grpc_image(img, 'model', 1, model_shape, 'i', 'DT_HALF') - assert img.shape == out.shape - assert img.sum() == out.sum() + # test valid channels last shapes + valid_input_shapes = [ + (1, 32, 32, 1), # exact same shape + (1, 64, 64, 1), # bigger + (1, 32, 32, 1), # smaller + (1, 33, 31, 1), # mixed + ] + for shape in valid_input_shapes: + # check channels last + img = np.ones(shape) + valid_img = consumer.validate_model_input(img, 'model', '1') + np.testing.assert_array_equal(img, valid_img) + + # should also work for channels first + img = np.rollaxis(img, -1, 1) + valid_img = consumer.validate_model_input(img, 'model', '1') + expected_img = np.rollaxis(img, 0, img.ndim) + np.testing.assert_array_equal(expected_img, valid_img) + + # test invalid shapes + invalid_input_shapes = [ + (32, 32, 1), # no batch dimension + (1, 32, 1), # rank too small + (1, 32, 32, 32, 1), # rank too large + (1, 32, 32, 2), # wrong channels + (1, 16, 64, 2), # wrong channels with mixed shape + ] + for shape in invalid_input_shapes: + img = np.ones(shape) + with pytest.raises(ValueError): + consumer.validate_model_input(img, 'model', '1') + + # test multiple inputs/metadata + count = 3 + model_input_shape = [(-1, 32, 32, 1)] * count + mocked_metadata = make_model_metadata_of_size(model_input_shape) + mocker.patch.object(consumer, 'get_model_metadata', mocked_metadata) + image = [np.ones(s) for s in valid_input_shapes[:count]] + valid_img = consumer.validate_model_input(image, 'model', '1') + # each image should be validated + for i, j in zip(image, valid_img): + np.testing.assert_array_equal(i, j) + + # metadata and image counts do not match + with pytest.raises(ValueError): + image = [np.ones(s) for s in valid_input_shapes[:count]] + consumer.validate_model_input(img, 'model', '1') - img = np.zeros((32, 32, 3)) - consumer._redis_hash = 'not none' - out = consumer.grpc_image(img, 'model', 1, model_shape, 'i', 'DT_HALF') - assert (1,) + img.shape == out.shape - assert img.sum() == out.sum() + # correct number of inputs, but one invalid entry + with pytest.raises(ValueError): + image = [np.ones(s) for s in valid_input_shapes[:count]] + # set a random entry to be invalid + i = random.randint(0, count - 1) + image[i] = np.ones(random.choice(invalid_input_shapes)) + consumer.validate_model_input(image, 'model', '1') + + def test__get_predict_client(self, redis_client): + stg = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') + + with pytest.raises(ValueError): + consumer._get_predict_client('model_name', 'model_version') + + consumer._get_predict_client('model_name', 1) def test_get_model_metadata(self, mocker, redis_client): model_shape = (-1, 216, 216, 1) @@ -324,121 +391,10 @@ def _get_bad_predict_client(model_name, model_version): _get_bad_predict_client) consumer.get_model_metadata('model', 1) - def test_predict(self, mocker, redis_client): - model_shape = (-1, 128, 128, 1) - stg = DummyStorage() - consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') - - mocker.patch.object(settings, 'TF_MAX_BATCH_SIZE', 2) - mocker.patch.object(consumer, 'get_model_metadata', lambda x, y: [{ - 'in_tensor_name': 'image', - 'in_tensor_dtype': 'DT_HALF', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - }]) - - def grpc_image(data, *args, **kwargs): - return data - - def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - return [data, data] - - image_shapes = [ - (256, 256, 1), - (128, 128, 1), - (64, 64, 1), - (100, 100, 1), - (300, 300, 1), - (257, 301, 1), - (65, 127, 1), - (127, 129, 1), - ] - grpc_funcs = (grpc_image, grpc_image_list) - untiles = (False, True) - prod = itertools.product(image_shapes, grpc_funcs, untiles) - - for image_shape, grpc_func, untile in prod: - x = np.random.random(image_shape) - mocker.patch.object(consumer, 'grpc_image', grpc_func) - - consumer.predict(x, model_name='modelname', model_version=0, - untile=untile) - - # test image larger than max dimensions - with pytest.raises(ValueError): - mocker.patch.object(settings, 'MAX_IMAGE_WIDTH', 300) - mocker.patch.object(settings, 'MAX_IMAGE_HEIGHT', 300) - x = np.random.random((301, 301, 1)) - consumer.predict(x, model_name='modelname', model_version=0) - - # test mismatch of input data and model shape - with pytest.raises(ValueError): - x = np.random.random((5,)) - consumer.predict(x, model_name='modelname', model_version=0) - - # test multiple model metadata inputs are not supported - with pytest.raises(ValueError): - mocker.patch.object(consumer, 'get_model_metadata', grpc_image_list) - x = np.random.random((300, 300, 1)) - consumer.predict(x, model_name='modelname', model_version=0) - - def test__get_processing_function(self, mocker, redis_client): - mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', { - 'valid': { - 'valid': lambda x: True - } - }) - - storage = DummyStorage() - consumer = consumers.ImageFileConsumer(redis_client, storage, 'q') - - x = consumer._get_processing_function('VaLiD', 'vAlId') - y = consumer._get_processing_function('vAlId', 'VaLiD') - assert x == y - - with pytest.raises(ValueError): - consumer._get_processing_function('invalid', 'valid') - - with pytest.raises(ValueError): - consumer._get_processing_function('valid', 'invalid') - - def test_process(self, mocker, redis_client): - # TODO: better test coverage - storage = DummyStorage() - queue = 'q' - img = np.random.random((1, 32, 32, 1)) - - mocker.patch.object(settings, 'PROCESSING_FUNCTIONS', { - 'valid': { - 'valid': lambda x: x, - 'retinanet': lambda *x: x[0], - 'retinanet-semantic': lambda x: x, - } - }) - - consumer = consumers.ImageFileConsumer(redis_client, storage, queue) - - mocker.patch.object(consumer, '_redis_hash', 'a hash') - - output = consumer.process(img, '', '') - np.testing.assert_equal(img, output) - - # image is returned but channel squeezed out - output = consumer.process(img, 'valid', 'valid') - np.testing.assert_equal(img[0], output) - - img = np.random.random((2, 32, 32, 1)) - output = consumer.process(img, 'retinanet-semantic', 'valid') - np.testing.assert_equal(img[0], output) - - consumer._rawshape = (21, 21) - img = np.random.random((1, 32, 32, 1)) - output = consumer.process(img, 'retinanet', 'valid') - np.testing.assert_equal(img[0], output) - def test_get_image_scale(self, mocker, redis_client): stg = DummyStorage() consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') - image = np.random.random((256, 256, 1)) + image = _get_image(256, 256, 1) # test no scale provided expected = 2 @@ -461,6 +417,44 @@ def test_get_image_scale(self, mocker, redis_client): scale = settings.MIN_SCALE - 0.1 consumer.get_image_scale(scale, image, 'some hash') + def test_get_grpc_app(self, mocker, redis_client): + stg = DummyStorage() + consumer = consumers.TensorFlowServingConsumer(redis_client, stg, 'q') + + get_metadata = make_model_metadata_of_size() + get_mock_client = lambda *x: Bunch(predict=lambda *x: None) + mocker.patch.object(consumer, 'get_model_metadata', get_metadata) + mocker.patch.object(consumer, '_get_predict_client', get_mock_client) + + app = consumer.get_grpc_app('model:0', lambda x: x) + + assert isinstance(app, GrpcModelWrapper) + + def test_detect_scale(self, mocker, redis_client): + # pylint: disable=W0613 + shape = (1, 256, 256, 1) + consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + + image = _get_image(shape[1] * 2, shape[2] * 2, shape[3]) + + expected_scale = random.uniform(0.5, 1.5) + # model_mpp = random.uniform(0.5, 1.5) + + mock_app = Bunch( + predict=lambda *x, **y: expected_scale, + # model_mpp=model_mpp, + model=Bunch(get_batch_size=lambda *x: 1)) + + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', False) + scale = consumer.detect_scale(image) + assert scale == 1 # model_mpp + + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) + scale = consumer.detect_scale(image) + assert scale == expected_scale # * model_mpp + class TestZipFileConsumer(object): # pylint: disable=R0201,W0613,W0621 diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py deleted file mode 100644 index fde67cf3..00000000 --- a/redis_consumer/consumers/image_consumer.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright 2016-2020 The Van Valen Lab at the California Institute of -# Technology (Caltech), with support from the Paul Allen Family Foundation, -# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. -# All rights reserved. -# -# Licensed under a modified Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE -# -# The Work provided may be used for non-commercial academic purposes only. -# For any other use of the Work, including commercial use, please contact: -# vanvalenlab@gmail.com -# -# Neither the name of Caltech nor the names of its contributors may be used -# to endorse or promote products derived from this software without specific -# prior written permission. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""ImageFileConsumer class for consuming image segmentation jobs.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import timeit - -import numpy as np - -from redis_consumer.consumers import TensorFlowServingConsumer -from redis_consumer import utils -from redis_consumer import settings - - -class ImageFileConsumer(TensorFlowServingConsumer): - """Consumes image files and uploads the results""" - - def is_valid_hash(self, redis_hash): - if redis_hash is None: - return False - - fname = str(self.redis.hget(redis_hash, 'input_file_name')) - return not fname.lower().endswith('.zip') - - def detect_scale(self, image): - """Send the image to the SCALE_DETECT_MODEL to detect the relative - scale difference from the image to the model's training data. - - Args: - image (numpy.array): The image data. - - Returns: - scale (float): The detected scale, used to rescale data. - """ - start = timeit.default_timer() - - if not settings.SCALE_DETECT_ENABLED: - self.logger.debug('Scale detection disabled. Scale set to 1.') - return 1 - - model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') - - scales = self.predict(image, model_name, model_version, - untile=False) - - detected_scale = np.mean(scales) - - error_rate = .01 # error rate is ~1% for current model. - if abs(detected_scale - 1) < error_rate: - detected_scale = 1 - - self.logger.debug('Scale %s detected in %s seconds', - detected_scale, timeit.default_timer() - start) - return detected_scale - - def detect_label(self, image): - """Send the image to the LABEL_DETECT_MODEL to detect the type of image - data. The model output is mapped with settings.MODEL_CHOICES. - - Args: - image (numpy.array): The image data. - - Returns: - label (int): The detected label. - """ - start = timeit.default_timer() - - if not settings.LABEL_DETECT_ENABLED: - self.logger.debug('Label detection disabled. Label set to None.') - return None - - model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') - - labels = self.predict(image, model_name, model_version, - untile=False) - - labels = np.array(labels) - vote = labels.sum(axis=0) - maj = vote.max() - - detected = np.where(vote == maj)[-1][0] - - self.logger.debug('Label %s detected in %s seconds.', - detected, timeit.default_timer() - start) - return detected - - def _consume(self, redis_hash): - start = timeit.default_timer() - hvals = self.redis.hgetall(redis_hash) - # hold on to the redis hash/values for logging purposes - self._redis_hash = redis_hash - self._redis_values = hvals - - if hvals.get('status') in self.finished_statuses: - self.logger.warning('Found completed hash `%s` with status %s.', - redis_hash, hvals.get('status')) - return hvals.get('status') - - self.logger.debug('Found hash to process `%s` with status `%s`.', - redis_hash, hvals.get('status')) - - self.update_key(redis_hash, { - 'status': 'started', - 'identity_started': self.name, - }) - - # Overridden with LABEL_DETECT_ENABLED - model_name = hvals.get('model_name') - model_version = hvals.get('model_version') - - _ = timeit.default_timer() - - with utils.get_tempdir() as tempdir: - fname = self.storage.download(hvals.get('input_file_name'), tempdir) - image = utils.get_image(fname) - - # Pre-process data before sending to the model - self.update_key(redis_hash, { - 'status': 'pre-processing', - 'download_time': timeit.default_timer() - _, - }) - - # Calculate scale of image and rescale - scale = hvals.get('scale', '') - scale = self.get_image_scale(scale, image, redis_hash) - - original_shape = image.shape - - image = utils.rescale(image, scale) - - # Save shape value for postprocessing purposes - # TODO this is a big janky - self._rawshape = image.shape - label = None - if settings.LABEL_DETECT_ENABLED and model_name and model_version: - self.logger.warning('Label Detection is enabled, but the model' - ' %s:%s was specified in Redis.', - model_name, model_version) - - elif settings.LABEL_DETECT_ENABLED: - # Detect image label type - label = hvals.get('label', '') - if not label: - label = self.detect_label(image) - self.logger.debug('Image label detected: %s', label) - self.update_key(redis_hash, {'label': str(label)}) - else: - label = int(label) - self.logger.debug('Image label already calculated: %s', label) - - # Grap appropriate model - model_name, model_version = utils._pick_model(label) - - if settings.LABEL_DETECT_ENABLED and label is not None: - pre_funcs = utils._pick_preprocess(label).split(',') - else: - pre_funcs = hvals.get('preprocess_function', '').split(',') - - image = np.expand_dims(image, axis=0) # add in the batch dim - image = self.preprocess(image, pre_funcs) - - # Send data to the model - self.update_key(redis_hash, {'status': 'predicting'}) - - image = self.predict(image, model_name, model_version) - - # Post-process model results - self.update_key(redis_hash, {'status': 'post-processing'}) - - if settings.LABEL_DETECT_ENABLED and label is not None: - post_funcs = utils._pick_postprocess(label).split(',') - else: - post_funcs = hvals.get('postprocess_function', '').split(',') - - image = self.postprocess(image, post_funcs) - - # Save the post-processed results to a file - _ = timeit.default_timer() - self.update_key(redis_hash, {'status': 'saving-results'}) - - save_name = hvals.get('original_name', fname) - - if isinstance(image, list): - for i, img in enumerate(image): - if img.shape[-1] != 1: - image[i] = np.expand_dims(img, axis=-1) - elif image.shape[-1] != 1: - image = np.expand_dims(image, axis=-1) - dest, output_url = self.save_output( - image, redis_hash, save_name, original_shape[:-1]) - - # Update redis with the final results - t = timeit.default_timer() - start - self.update_key(redis_hash, { - 'status': self.final_status, - 'output_url': output_url, - 'upload_time': timeit.default_timer() - _, - 'output_file_name': dest, - 'total_jobs': 1, - 'total_time': t, - 'finished_at': self.get_current_timestamp() - }) - return self.final_status diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py deleted file mode 100644 index 6676fcd2..00000000 --- a/redis_consumer/consumers/image_consumer_test.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2016-2020 The Van Valen Lab at the California Institute of -# Technology (Caltech), with support from the Paul Allen Family Foundation, -# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. -# All rights reserved. -# -# Licensed under a modified Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE -# -# The Work provided may be used for non-commercial academic purposes only. -# For any other use of the Work, including commercial use, please contact: -# vanvalenlab@gmail.com -# -# Neither the name of Caltech nor the names of its contributors may be used -# to endorse or promote products derived from this software without specific -# prior written permission. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Tests for post-processing functions""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -import numpy as np - -import pytest - -from redis_consumer import consumers -from redis_consumer import settings -from redis_consumer.testing_utils import DummyStorage, redis_client -from redis_consumer.testing_utils import _get_image, make_model_metadata_of_size - - -class TestImageFileConsumer(object): - # pylint: disable=R0201,W0621 - def test_is_valid_hash(self, mocker, redis_client): - storage = DummyStorage() - mocker.patch.object(redis_client, 'hget', lambda x, y: x.split(':')[-1]) - - consumer = consumers.ImageFileConsumer(redis_client, storage, 'predict') - - assert consumer.is_valid_hash(None) is False - assert consumer.is_valid_hash('file.ZIp') is False - assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False - assert consumer.is_valid_hash('track:123456789:file.zip') is False - assert consumer.is_valid_hash('predict:123456789:file.zip') is False - assert consumer.is_valid_hash('predict:1234567890:file.tiff') is True - assert consumer.is_valid_hash('predict:1234567890:file.png') is True - - def test_detect_label(self, mocker, redis_client): - # pylint: disable=W0613 - model_shape = (1, 216, 216, 1) - consumer = consumers.ImageFileConsumer(redis_client, None, 'q') - - def dummy_metadata(*_, **__): - return { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } - - image = _get_image(model_shape[1] * 2, model_shape[2] * 2) - - def predict(*_, **__): - data = np.zeros((3,)) - i = np.random.randint(3) - data[i] = 1 - return data - - mocker.patch.object(consumer, 'predict', predict) - mocker.patch.object(consumer, 'get_model_metadata', dummy_metadata) - mocker.patch.object(settings, 'LABEL_DETECT_MODEL', 'dummymodel:1') - - mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', False) - label = consumer.detect_label(image) - assert label is None - - mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', True) - label = consumer.detect_label(image) - assert label in set(list(range(4))) - - def test_detect_scale(self, mocker, redis_client): - # pylint: disable=W0613 - # TODO: test rescale is < 1% of the original - model_shape = (1, 216, 216, 1) - consumer = consumers.ImageFileConsumer(redis_client, None, 'q') - - def dummy_metadata(*_, **__): - return { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } - - big_size = model_shape[1] * np.random.randint(2, 9) - image = _get_image(big_size, big_size) - - expected = 1 - - def predict(diff=1e-8): - def _predict(*_, **__): - sign = -1 if np.random.randint(1, 5) > 2 else 1 - return expected + sign * diff - return _predict - - mocker.patch.object(consumer, 'get_model_metadata', dummy_metadata) - mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', False) - mocker.patch.object(settings, 'SCALE_DETECT_MODEL', 'dummymodel:1') - scale = consumer.detect_scale(image) - assert scale == 1 - - mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) - mocker.patch.object(consumer, 'predict', predict(1e-8)) - scale = consumer.detect_scale(image) - # very small changes within error range: - assert scale == 1 - - mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) - mocker.patch.object(consumer, 'predict', predict(1e-1)) - scale = consumer.detect_scale(image) - assert isinstance(scale, float) - np.testing.assert_almost_equal(scale, expected, 1e-1) - - def test__consume(self, mocker, redis_client): - # pylint: disable=W0613 - prefix = 'predict' - status = 'new' - storage = DummyStorage() - - consumer = consumers.ImageFileConsumer(redis_client, storage, prefix) - - def grpc_image(data, *args, **kwargs): - return data - - def grpc_image_multi(data, *args, **kwargs): - return np.array(tuple(list(data.shape) + [2])) - - def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - return [data, data] - - mocker.patch.object(consumer, 'detect_label', lambda x: 1) - mocker.patch.object(consumer, 'detect_scale', lambda x: 1) - mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', True) - - grpc_funcs = (grpc_image, grpc_image_list) - model_shapes = [ - (-1, 600, 600, 1), # image too small, pad - (-1, 300, 300, 1), # image is exactly the right size - (-1, 150, 150, 1), # image too big, tile - (-1, 150, 600, 1), # image has one size too small, one size too big - (-1, 600, 150, 1), # image has one size too small, one size too big - ] - - empty_data = {'input_file_name': 'file.tiff'} - full_data = { - 'input_file_name': 'file.tiff', - 'model_version': '0', - 'model_name': 'model', - 'label': '1', - 'scale': '1', - } - label_no_model_data = full_data.copy() - label_no_model_data['model_name'] = '' - - datasets = [empty_data, full_data, label_no_model_data] - - test_hash = 0 - # test finished statuses are returned - for status in (consumer.failed_status, consumer.final_status): - test_hash += 1 - data = empty_data.copy() - data['status'] = status - redis_client.hmset(test_hash, data) - result = consumer._consume(test_hash) - assert result == status - result = redis_client.hget(test_hash, 'status') - assert result == status - test_hash += 1 - - prod = itertools.product(model_shapes, grpc_funcs, datasets) - - for model_shape, grpc_func, data in prod: - metadata = make_model_metadata_of_size(model_shape) - mocker.patch.object(consumer, 'grpc_image', grpc_func) - mocker.patch.object(consumer, 'get_model_metadata', metadata) - mocker.patch.object(consumer, 'process', lambda *x: x[0]) - - redis_client.hmset(test_hash, data) - result = consumer._consume(test_hash) - assert result == consumer.final_status - result = redis_client.hget(test_hash, 'status') - assert result == consumer.final_status - test_hash += 1 diff --git a/redis_consumer/consumers/multiplex_consumer.py b/redis_consumer/consumers/multiplex_consumer.py index d00fcf44..142e9777 100644 --- a/redis_consumer/consumers/multiplex_consumer.py +++ b/redis_consumer/consumers/multiplex_consumer.py @@ -32,25 +32,44 @@ import numpy as np +from deepcell.applications import ScaleDetection, MultiplexSegmentation + from redis_consumer.consumers import TensorFlowServingConsumer -from redis_consumer import utils from redis_consumer import settings -from redis_consumer import processing class MultiplexConsumer(TensorFlowServingConsumer): """Consumes image files and uploads the results""" - def is_valid_hash(self, redis_hash): - if redis_hash is None: - return False + def detect_scale(self, image): + """Send the image to the SCALE_DETECT_MODEL to detect the relative + scale difference from the image to the model's training data. + + Args: + image (numpy.array): The image data. + + Returns: + scale (float): The detected scale, used to rescale data. + """ + start = timeit.default_timer() + + app = self.get_grpc_app(settings.SCALE_DETECT_MODEL, ScaleDetection) + + if not settings.SCALE_DETECT_ENABLED: + self.logger.debug('Scale detection disabled.') + return 1 - fname = str(self.redis.hget(redis_hash, 'input_file_name')) - return not fname.lower().endswith('.zip') + # TODO: What to do with multi-channel data? + # detected_scale = app.predict(image[..., 0]) + detected_scale = 1 + + self.logger.debug('Scale %s detected in %s seconds', + detected_scale, timeit.default_timer() - start) + + return detected_scale def _consume(self, redis_hash): start = timeit.default_timer() - self._redis_hash = redis_hash # workaround for logging. hvals = self.redis.hgetall(redis_hash) if hvals.get('status') in self.finished_statuses: @@ -72,33 +91,13 @@ def _consume(self, redis_hash): _ = timeit.default_timer() # Load input image - with utils.get_tempdir() as tempdir: - fname = self.storage.download(hvals.get('input_file_name'), tempdir) - # TODO: tiffs expand the last axis, is that a problem here? - image = utils.get_image(fname) + fname = hvals.get('input_file_name') + image = self.download_image(fname) # squeeze extra dimension that is added by get_image image = np.squeeze(image) - - # validate correct shape of image - if len(image.shape) > 3: - raise ValueError('Invalid image shape. An image of shape {} was supplied, but the ' - 'multiplex model expects of images of shape' - '[height, widths, 2]'.format(image.shape)) - elif len(image.shape) < 3: - # TODO: Once we can pass warning messages to user, we can treat this as nuclear image - raise ValueError('Invalid image shape. An image of shape {} was supplied, but the ' - 'multiplex model expects images of shape' - '[height, width, 2]'.format(image.shape)) - else: - if image.shape[0] == 2: - image = np.rollaxis(image, 0, 3) - elif image.shape[2] == 2: - pass - else: - raise ValueError('Invalid image shape. An image of shape {} was supplied, ' - 'but the multiplex model expects images of shape' - '[height, widths, 2]'.format(image.shape)) + # add in the batch dim + image = np.expand_dims(image, axis=0) # Pre-process data before sending to the model self.update_key(redis_hash, { @@ -110,41 +109,32 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) - original_shape = image.shape - - # Rescale each channel of the image - image = utils.rescale(image, scale) - image = np.expand_dims(image, axis=0) # add in the batch dim - - # Preprocess image - image = self.preprocess(image, ['multiplex_preprocess']) + # Validate input image + image = self.validate_model_input(image, model_name, model_version) # Send data to the model - self.update_key(redis_hash, {'status': 'predicting'}) - image = self.predict(image, model_name, model_version) + app = self.get_grpc_app(settings.MULTIPLEX_MODEL, + MultiplexSegmentation) - # Post-process model results - self.update_key(redis_hash, {'status': 'post-processing'}) - image = processing.format_output_multiplex(image) - image = self.postprocess(image, ['multiplex_postprocess_consumer']) + results = app.predict(image, batch_size=None, + image_mpp=scale * app.model_mpp) # Save the post-processed results to a file _ = timeit.default_timer() self.update_key(redis_hash, {'status': 'saving-results'}) save_name = hvals.get('original_name', fname) - dest, output_url = self.save_output( - image, redis_hash, save_name, original_shape[:-1]) + dest, output_url = self.save_output(results, save_name) # Update redis with the final results - t = timeit.default_timer() - start + end = timeit.default_timer() self.update_key(redis_hash, { 'status': self.final_status, 'output_url': output_url, - 'upload_time': timeit.default_timer() - _, + 'upload_time': end - _, 'output_file_name': dest, 'total_jobs': 1, - 'total_time': t, + 'total_time': end - start, 'finished_at': self.get_current_timestamp() }) return self.final_status diff --git a/redis_consumer/consumers/multiplex_consumer_test.py b/redis_consumer/consumers/multiplex_consumer_test.py index d23310bb..f7d082e5 100644 --- a/redis_consumer/consumers/multiplex_consumer_test.py +++ b/redis_consumer/consumers/multiplex_consumer_test.py @@ -28,79 +28,59 @@ from __future__ import division from __future__ import print_function -import itertools - import numpy as np import pytest from redis_consumer import consumers -from redis_consumer.testing_utils import redis_client, DummyStorage -from redis_consumer.testing_utils import make_model_metadata_of_size +from redis_consumer import settings +from redis_consumer.testing_utils import _get_image +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client class TestMultiplexConsumer(object): - # pylint: disable=R0201 - - def test_is_valid_hash(self, mocker, redis_client): - storage = DummyStorage() - mocker.patch.object(redis_client, 'hget', lambda *x: x[0]) - - consumer = consumers.MultiplexConsumer(redis_client, storage, 'multiplex') - assert consumer.is_valid_hash(None) is False - assert consumer.is_valid_hash('file.ZIp') is False - assert consumer.is_valid_hash('predict:1234567890:file.ZIp') is False - assert consumer.is_valid_hash('track:123456789:file.zip') is False - assert consumer.is_valid_hash('predict:123456789:file.zip') is False - assert consumer.is_valid_hash('multiplex:1234567890:file.tiff') is True - assert consumer.is_valid_hash('multiplex:1234567890:file.png') is True + # pylint: disable=R0201,W0621 - def test__consume(self, mocker, redis_client): + def test_detect_scale(self, mocker, redis_client): # pylint: disable=W0613 + shape = (1, 256, 256, 1) + consumer = consumers.MultiplexConsumer(redis_client, None, 'q') - def make_grpc_image(model_shape=(-1, 256, 256, 2)): - # pylint: disable=E1101 - shape = model_shape[1:-1] - - def grpc(data, *args, **kwargs): - inner_shape = tuple([1] + list(shape) + [1]) - feature_shape = tuple([1] + list(shape) + [3]) + image = _get_image(shape[1] * 2, shape[2] * 2, shape[3]) - inner = np.random.random(inner_shape) - feature = np.random.random(feature_shape) + expected_scale = 1 # random.uniform(0.5, 1.5) + # model_mpp = random.uniform(0.5, 1.5) - inner2 = np.random.random(inner_shape) - feature2 = np.random.random(feature_shape) - return [inner, feature, inner2, feature2] + mock_app = Bunch( + predict=lambda *x, **y: expected_scale, + # model_mpp=model_mpp, + model=Bunch(get_batch_size=lambda *x: 1)) - return grpc + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) - image_shapes = [ - (2, 300, 300), # channels first - (300, 300, 2), # channels last - ] + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', False) + scale = consumer.detect_scale(image) + assert scale == 1 # model_mpp - model_shapes = [ - (-1, 600, 600, 2), # image too small, pad - (-1, 300, 300, 2), # image is exactly the right size - (-1, 150, 150, 2), # image too big, tile - (-1, 150, 600, 2), # image has one size too small, one size too big - (-1, 600, 150, 2), # image has one size too small, one size too big - ] + mocker.patch.object(settings, 'SCALE_DETECT_ENABLED', True) + scale = consumer.detect_scale(image) + assert scale == expected_scale # * model_mpp - scales = ['.9', '1.1', ''] + def test__consume_finished_status(self, redis_client): + queue = 'q' + storage = DummyStorage() - job_data = { - 'input_file_name': 'file.tiff', - } + consumer = consumers.MultiplexConsumer(redis_client, storage, queue) - consumer = consumers.MultiplexConsumer(redis_client, DummyStorage(), 'multiplex') + empty_data = {'input_file_name': 'file.tiff'} test_hash = 0 # test finished statuses are returned for status in (consumer.failed_status, consumer.final_status): test_hash += 1 - data = job_data.copy() + data = empty_data.copy() data['status'] = status redis_client.hmset(test_hash, data) result = consumer._consume(test_hash) @@ -109,50 +89,34 @@ def grpc(data, *args, **kwargs): assert result == status test_hash += 1 - prod = itertools.product(model_shapes, scales, image_shapes) + def test__consume(self, mocker, redis_client): + # pylint: disable=W0613 + queue = 'multiplex' + storage = DummyStorage() + + consumer = consumers.MultiplexConsumer(redis_client, storage, queue) - for model_shape, scale, image_shape in prod: - mocker.patch('redis_consumer.utils.get_image', - lambda x: np.random.random(list(image_shape) + [1])) + empty_data = {'input_file_name': 'file.tiff'} - metadata = make_model_metadata_of_size(model_shape) - grpc_image = make_grpc_image(model_shape) - mocker.patch.object(consumer, 'get_model_metadata', metadata) - mocker.patch.object(consumer, 'grpc_image', grpc_image) - mocker.patch.object(consumer, 'postprocess', - lambda *x: np.random.randint(0, 5, size=(300, 300, 1))) + output_shape = (1, 256, 256, 2) - data = job_data.copy() - data['scale'] = scale + mock_app = Bunch( + predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape), + model_mpp=1, + model=Bunch( + get_batch_size=lambda *x: 1, + input_shape=(1, 32, 32, 1) + ) + ) - redis_client.hmset(test_hash, data) - result = consumer._consume(test_hash) - assert result == consumer.final_status - result = redis_client.hget(test_hash, 'status') - assert result == consumer.final_status - test_hash += 1 + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + mocker.patch.object(consumer, 'get_image_scale', lambda *x: 1) + mocker.patch.object(consumer, 'validate_model_input', lambda *x: x[0]) - model_shape = (-1, 150, 150, 2) - invalid_image_shapes = [ - (150, 150), - (150,), - (150, 150, 1), - (1, 150, 150), - (3, 150, 150), - (1, 1, 150, 150) - ] - - for image_shape in invalid_image_shapes: - mocker.patch('redis_consumer.utils.get_image', - lambda x: np.random.random(list(image_shape) + [1])) - metadata = make_model_metadata_of_size(model_shape) - grpc_image = make_grpc_image(model_shape) - mocker.patch.object(consumer, 'get_model_metadata', metadata) - mocker.patch.object(consumer, 'grpc_image', grpc_image) - - data = job_data.copy() - data['scale'] = '1' + test_hash = 'some hash' - redis_client.hmset(test_hash, data) - with pytest.raises(ValueError, match='Invalid image shape'): - _ = consumer._consume(test_hash) + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py new file mode 100644 index 00000000..a35728c4 --- /dev/null +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -0,0 +1,158 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SegmentationConsumer class for consuming image segmentation jobs.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import timeit + +import numpy as np + +from deepcell.applications import LabelDetection +from deepcell_toolbox.processing import normalize + +from redis_consumer.consumers import TensorFlowServingConsumer +from redis_consumer import settings + + +class SegmentationConsumer(TensorFlowServingConsumer): + """Consumes image files and uploads the results""" + + def detect_label(self, image): + """Send the image to the LABEL_DETECT_MODEL to detect the type of image + data. The model output is mapped with settings.MODEL_CHOICES. + + Args: + image (numpy.array): The image data. + + Returns: + label (int): The detected label. + """ + start = timeit.default_timer() + + app = self.get_grpc_app(settings.LABEL_DETECT_MODEL, LabelDetection) + + if not settings.LABEL_DETECT_ENABLED: + self.logger.debug('Label detection disabled. Label set to 0.') + return 0 # Use NuclearSegmentation as default model + + batch_size = app.model.get_batch_size() + detected_label = app.predict(image, batch_size=batch_size) + + self.logger.debug('Label %s detected in %s seconds', + detected_label, timeit.default_timer() - start) + + return int(detected_label) + + def get_image_label(self, label, image, redis_hash): + """Calculate label of image.""" + if not label: + # Detect scale of image (Default to 1) + label = self.detect_label(image) + self.logger.debug('Image label detected: %s.', label) + self.update_key(redis_hash, {'label': label}) + else: + label = int(label) + self.logger.debug('Image label already calculated: %s.', label) + if label not in settings.APPLICATION_CHOICES: + raise ValueError('Label type {} is not supported'.format(label)) + + return label + + def _consume(self, redis_hash): + start = timeit.default_timer() + hvals = self.redis.hgetall(redis_hash) + + if hvals.get('status') in self.finished_statuses: + self.logger.warning('Found completed hash `%s` with status %s.', + redis_hash, hvals.get('status')) + return hvals.get('status') + + self.logger.debug('Found hash to process `%s` with status `%s`.', + redis_hash, hvals.get('status')) + + self.update_key(redis_hash, { + 'status': 'started', + 'identity_started': self.name, + }) + + _ = timeit.default_timer() + + # Load input image + fname = hvals.get('input_file_name') + image = self.download_image(fname) + image = np.expand_dims(image, axis=0) # add a batch dimension + + # Pre-process data before sending to the model + self.update_key(redis_hash, { + 'status': 'pre-processing', + 'download_time': timeit.default_timer() - _, + }) + + # Calculate scale of image and rescale + scale = hvals.get('scale', '') + scale = self.get_image_scale(scale, image, redis_hash) + + label = hvals.get('label', '') + label = self.get_image_label(label, image, redis_hash) + + # Grap appropriate model and application class + model = settings.MODEL_CHOICES[label] + app_cls = settings.APPLICATION_CHOICES[label] + + model_name, model_version = model.split(':') + + # Validate input image + image = self.validate_model_input(image, model_name, model_version) + + # Send data to the model + self.update_key(redis_hash, {'status': 'predicting'}) + + app = self.get_grpc_app(model, app_cls) + + results = app.predict(image, batch_size=None, + image_mpp=scale * app.model_mpp) + + # Save the post-processed results to a file + _ = timeit.default_timer() + self.update_key(redis_hash, {'status': 'saving-results'}) + + save_name = hvals.get('original_name', fname) + dest, output_url = self.save_output(results, save_name) + + # Update redis with the final results + end = timeit.default_timer() + self.update_key(redis_hash, { + 'status': self.final_status, + 'output_url': output_url, + 'upload_time': end - _, + 'output_file_name': dest, + 'total_jobs': 1, + 'total_time': end - start, + 'finished_at': self.get_current_timestamp() + }) + return self.final_status diff --git a/redis_consumer/consumers/segmentation_consumer_test.py b/redis_consumer/consumers/segmentation_consumer_test.py new file mode 100644 index 00000000..cacba189 --- /dev/null +++ b/redis_consumer/consumers/segmentation_consumer_test.py @@ -0,0 +1,152 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for post-processing functions""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np + +import pytest + +from redis_consumer import consumers +from redis_consumer import settings + +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client +from redis_consumer.testing_utils import _get_image + + +class TestSegmentationConsumer(object): + # pylint: disable=R0201,W0621 + + def test_detect_label(self, mocker, redis_client): + # pylint: disable=W0613 + shape = (1, 256, 256, 1) + queue = 'q' + consumer = consumers.SegmentationConsumer(redis_client, None, queue) + + expected_label = random.randint(1, 9) + + mock_app = Bunch( + predict=lambda *x, **y: expected_label, + model=Bunch(get_batch_size=lambda *x: 1)) + + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + + image = _get_image(shape[1] * 2, shape[2] * 2, shape[3]) + + mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', False) + label = consumer.detect_label(image) + assert label == 0 + + mocker.patch.object(settings, 'LABEL_DETECT_ENABLED', True) + label = consumer.detect_label(image) + assert label == expected_label + + def test_get_image_label(self, mocker, redis_client): + queue = 'q' + stg = DummyStorage() + consumer = consumers.SegmentationConsumer(redis_client, stg, queue) + image = _get_image(256, 256, 1) + + # test no label provided + expected = 1 + mocker.patch.object(consumer, 'detect_label', lambda *x: expected) + label = consumer.get_image_label(None, image, 'some hash') + assert label == expected + + # test label provided + expected = 2 + label = consumer.get_image_label(expected, image, 'some hash') + assert label == expected + + # test label provided is invalid + with pytest.raises(ValueError): + label = -1 + consumer.get_image_label(label, image, 'some hash') + + # test label provided is bad type + with pytest.raises(ValueError): + label = 'badval' + consumer.get_image_label(label, image, 'some hash') + + def test__consume_finished_status(self, redis_client): + queue = 'q' + storage = DummyStorage() + + consumer = consumers.SegmentationConsumer(redis_client, storage, queue) + + empty_data = {'input_file_name': 'file.tiff'} + + test_hash = 0 + # test finished statuses are returned + for status in (consumer.failed_status, consumer.final_status): + test_hash += 1 + data = empty_data.copy() + data['status'] = status + redis_client.hmset(test_hash, data) + result = consumer._consume(test_hash) + assert result == status + result = redis_client.hget(test_hash, 'status') + assert result == status + test_hash += 1 + + def test__consume(self, mocker, redis_client): + # pylint: disable=W0613 + queue = 'predict' + storage = DummyStorage() + + consumer = consumers.SegmentationConsumer(redis_client, storage, queue) + + empty_data = {'input_file_name': 'file.tiff'} + + output_shape = (1, 32, 32, 1) + + mock_app = Bunch( + predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape), + model_mpp=1, + model=Bunch( + get_batch_size=lambda *x: 1, + input_shape=(1, 32, 32, 1) + ) + ) + + mocker.patch.object(consumer, 'get_grpc_app', lambda *x: mock_app) + mocker.patch.object(consumer, 'get_image_scale', lambda *x: 1) + mocker.patch.object(consumer, 'get_image_label', lambda *x: 1) + mocker.patch.object(consumer, 'validate_model_input', lambda *x: True) + + test_hash = 'some hash' + + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index c014431b..a3257956 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -30,6 +30,7 @@ import json import os +import tempfile import time import timeit import uuid @@ -37,12 +38,13 @@ from skimage.external import tifffile import numpy as np -from redis_consumer.grpc_clients import TrackingClient +from deepcell_toolbox.processing import correct_drift + +from deepcell.applications import CellTracking + from redis_consumer.consumers import TensorFlowServingConsumer from redis_consumer import utils -from redis_consumer import tracking from redis_consumer import settings -from redis_consumer import processing class TrackingConsumer(TensorFlowServingConsumer): @@ -65,58 +67,6 @@ def is_valid_hash(self, redis_hash): return valid_file - def _get_model(self, redis_hash, hvalues): - hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT) - - # Pick model based on redis or default setting - model = hvalues.get('model_name', '') - version = hvalues.get('model_version', '') - if not model or not version: - model, version = settings.TRACKING_MODEL.split(':') - - t = timeit.default_timer() - model = TrackingClient(host=hostname, - redis_hash=redis_hash, - model_name=model, - model_version=int(version), - progress_callback=self._update_progress) - - self.logger.debug('Created the TrackingClient in %s seconds.', - timeit.default_timer() - t) - return model - - def _get_tracker(self, redis_hash, hvalues, raw, segmented): - self.logger.debug('Creating tracker...') - t = timeit.default_timer() - tracking_model = self._get_model(redis_hash, hvalues) - - # Some tracking models do not have an ImageNormalization Layer. - # If not, the data must be normalized before being tracked. - if settings.NORMALIZE_TRACKING: - for frame in range(raw.shape[0]): - raw[frame, ..., 0] = processing.normalize(raw[frame, ..., 0]) - - features = {'appearance', 'distance', 'neighborhood', 'regionprop'} - tracker = tracking.CellTracker( - raw, segmented, - tracking_model, - max_distance=settings.MAX_DISTANCE, - track_length=settings.TRACK_LENGTH, - division=settings.DIVISION, - birth=settings.BIRTH, - death=settings.DEATH, - neighborhood_scale_size=settings.NEIGHBORHOOD_SCALE_SIZE, - features=features) - - self.logger.debug('Created Tracker in %s seconds.', - timeit.default_timer() - t) - return tracker - - def _update_progress(self, redis_hash, progress): - self.update_key(redis_hash, { - 'progress': progress, - }) - def _load_data(self, redis_hash, subdir, fname): """ Given the upload location `input_file_name`, and the downloaded @@ -148,20 +98,6 @@ def _load_data(self, redis_hash, subdir, fname): 'with 3 dimensions, (time, width, height)'.format( tiff_stack.shape)) - # Calculate scale of a subset of raw - scale = hvalues.get('scale', '') - scale = scale if settings.SCALE_DETECT_ENABLED else 1 - - # Pick model and postprocess based on either label or defaults - if settings.LABEL_DETECT_ENABLED: - # model and postprocessing will be determined automatically - # by the ImageFileConsumer - model_name, model_version = '', '' - postprocess_function = '' - else: - model_name, model_version = settings.TRACKING_SEGMENT_MODEL.split(':') - postprocess_function = settings.TRACKING_POSTPROCESS_FUNCTION - num_frames = len(tiff_stack) hash_to_frame = {} remaining_hashes = set() @@ -172,7 +108,7 @@ def _load_data(self, redis_hash, subdir, fname): uid = uuid.uuid4().hex for i, img in enumerate(tiff_stack): - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: # Save and upload the frame. segment_fname = '{}-{}-tracking-frame-{}.tif'.format( uid, hvalues.get('original_name'), i) @@ -187,15 +123,10 @@ def _load_data(self, redis_hash, subdir, fname): 'identity_upload': self.name, 'input_file_name': upload_file_name, 'original_name': segment_fname, - 'model_name': model_name, - 'model_version': model_version, - 'postprocess_function': postprocess_function, 'status': 'new', 'created_at': current_timestamp, 'updated_at': current_timestamp, - 'url': upload_file_url, - 'scale': scale, - # 'label': str(label) + 'url': upload_file_url } # make a hash for this frame @@ -234,7 +165,7 @@ def _load_data(self, redis_hash, subdir, fname): if status == self.final_status: # Segmentation is finished, save and load the frame. - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: out = self.redis.hget(segment_hash, 'output_file_name') frame_zip = self.storage.download(out, tempdir) frame_files = list(utils.iter_image_archive( @@ -256,19 +187,21 @@ def _load_data(self, redis_hash, subdir, fname): labels = [frames[i] for i in range(num_frames)] # Cast y to int to avoid issues during fourier transform/drift correction - return {'X': np.expand_dims(tiff_stack, axis=-1), - 'y': np.array(labels, dtype='uint16')} + y = np.array(labels, dtype='uint16') + # TODO: Why is there an extra dimension? + # Not a problem in tests, only with application based results. + # Issue with batch dimension from outputs? + y = y[:, 0] if y.shape[1] == 1 else y + return {'X': np.expand_dims(tiff_stack, axis=-1), 'y': y} def _consume(self, redis_hash): start = timeit.default_timer() - hvalues = self.redis.hgetall(redis_hash) - self.logger.debug('Found `%s:*` hash to process "%s": %s', - self.queue, redis_hash, json.dumps(hvalues, indent=4)) + hvals = self.redis.hgetall(redis_hash) - if hvalues.get('status') in self.finished_statuses: + if hvals.get('status') in self.finished_statuses: self.logger.warning('Found completed hash `%s` with status %s.', - redis_hash, hvalues.get('status')) - return hvalues.get('status') + redis_hash, hvals.get('status')) + return hvals.get('status') # Set status and initial progress self.update_key(redis_hash, { @@ -277,8 +210,8 @@ def _consume(self, redis_hash): 'identity_started': self.name, }) - with utils.get_tempdir() as tempdir: - fname = self.storage.download(hvalues.get('input_file_name'), + with tempfile.TemporaryDirectory() as tempdir: + fname = self.storage.download(hvals.get('input_file_name'), tempdir) data = self._load_data(redis_hash, tempdir, fname) @@ -289,36 +222,36 @@ def _consume(self, redis_hash): # Correct for drift if enabled if settings.DRIFT_CORRECT_ENABLED: t = timeit.default_timer() - data['X'], data['y'] = processing.correct_drift(data['X'], data['y']) + data['X'], data['y'] = correct_drift(data['X'], data['y']) self.logger.debug('Drift correction complete in %s seconds.', timeit.default_timer() - t) - # TODO: Add support for rescaling in the tracker - tracker = self._get_tracker(redis_hash, hvalues, data['X'], data['y']) + # Send data to the model + app = self.get_grpc_app(settings.TRACKING_MODEL, CellTracking, + birth=settings.BIRTH, + death=settings.DEATH, + division=settings.DIVISION, + track_length=settings.TRACK_LENGTH) - self.logger.debug('Trying to track...') + self.logger.debug('Tracking...') self.update_key(redis_hash, {'status': 'predicting'}) - tracker.track_cells() + results = app.predict(data['X'], data['y']) self.logger.debug('Tracking done!') - self.update_key(redis_hash, {'status': 'post-processing'}) - # Post-process and save the output file - tracked_data = tracker.postprocess() - self.update_key(redis_hash, {'status': 'saving-results'}) - with utils.get_tempdir() as tempdir: + with tempfile.TemporaryDirectory() as tempdir: # Save lineage data to JSON file lineage_file = os.path.join(tempdir, 'lineage.json') with open(lineage_file, 'w') as fp: - json.dump(tracked_data['tracks'], fp) + json.dump(results['tracks'], fp) - save_name = hvalues.get('original_name', fname) + save_name = hvals.get('original_name', fname) subdir = os.path.dirname(save_name.replace(tempdir, '')) name = os.path.splitext(os.path.basename(save_name))[0] # Save tracked data as tiff stack outpaths = utils.save_numpy_array( - tracked_data['y_tracked'], name=name, + results['y_tracked'], name=name, subdir=subdir, output_dir=tempdir) outpaths.append(lineage_file) diff --git a/redis_consumer/consumers/tracking_consumer_test.py b/redis_consumer/consumers/tracking_consumer_test.py index 19bc03ff..b462536b 100644 --- a/redis_consumer/consumers/tracking_consumer_test.py +++ b/redis_consumer/consumers/tracking_consumer_test.py @@ -37,31 +37,12 @@ import numpy as np from skimage.external import tifffile -import redis_consumer from redis_consumer import consumers from redis_consumer import settings -from redis_consumer.testing_utils import DummyStorage, redis_client, _get_image - - -class DummyTracker(object): - # pylint: disable=R0201,W0613 - def __init__(self, *_, **__): - pass - - def _track_cells(self): - return None - - def track_cells(self): - return None - - def dump(self, *_, **__): - return None - - def postprocess(self, *_, **__): - return { - 'y_tracked': np.zeros((32, 32, 1)), - 'tracks': [] - } +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client +from redis_consumer.testing_utils import _get_image class TestTrackingConsumer(object): @@ -84,16 +65,6 @@ def test_is_valid_hash(self, mocker, redis_client): assert consumer.is_valid_hash('track:1234567890:file.trk') is True assert consumer.is_valid_hash('track:1234567890:file.trks') is True - def test__update_progress(self, redis_client): - queue = 'track' - storage = DummyStorage() - consumer = consumers.TrackingConsumer(redis_client, storage, queue) - - redis_hash = 'a job hash' - progress = random.randint(0, 99) - consumer._update_progress(redis_hash, progress) - assert int(redis_client.hget(redis_hash, 'progress')) == progress - def test__load_data(self, tmpdir, mocker, redis_client): queue = 'track' storage = DummyStorage() @@ -163,28 +134,30 @@ def write_child_tiff(*_, **__): lambda *x: range(1, 3)) consumer._load_data(key, tmpdir, fname) - def test__get_tracker(self, mocker, redis_client): - queue = 'track' - storage = DummyStorage() - - shape = (5, 21, 21, 1) - raw = np.random.random(shape) - segmented = np.random.randint(1, 10, size=shape) - - mocker.patch.object(settings, 'NORMALIZE_TRACKING', True) - consumer = consumers.TrackingConsumer(redis_client, storage, queue) - tracker = consumer._get_tracker('item1', {}, raw, segmented) - assert isinstance(tracker, redis_consumer.tracking.CellTracker) - def test__consume(self, mocker, redis_client): queue = 'track' storage = DummyStorage() test_hash = 0 + dummy_results = { + 'y_tracked': np.zeros((32, 32, 1)), + 'tracks': [] + } + + mock_app = Bunch( + predict=lambda *x, **y: dummy_results, + track=lambda *x, **y: dummy_results, + model_mpp=1, + model=Bunch( + get_batch_size=lambda *x: 1, + input_shape=(1, 32, 32, 1) + ) + ) + consumer = consumers.TrackingConsumer(redis_client, storage, queue) - mocker.patch.object(consumer, '_get_tracker', DummyTracker) mocker.patch.object(settings, 'DRIFT_CORRECT_ENABLED', True) + mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **y: mock_app) frames = 3 dummy_data = { diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index bd41704b..45d5f95b 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -34,22 +34,98 @@ import logging import time import timeit +import six +import dict_to_protobuf +from google.protobuf.json_format import MessageToJson import grpc -from grpc import RpcError import grpc.beta.implementations from grpc._cython import cygrpc - import numpy as np - -from google.protobuf.json_format import MessageToJson +from tensorflow.core.framework.types_pb2 import DESCRIPTOR +from tensorflow.core.framework.tensor_pb2 import TensorProto +from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub +from tensorflow_serving.apis.predict_pb2 import PredictRequest +from tensorflow_serving.apis.get_model_metadata_pb2 import GetModelMetadataRequest from redis_consumer import settings -from redis_consumer.pbs.prediction_service_pb2_grpc import PredictionServiceStub -from redis_consumer.pbs.predict_pb2 import PredictRequest -from redis_consumer.pbs.get_model_metadata_pb2 import GetModelMetadataRequest -from redis_consumer.utils import grpc_response_to_dict -from redis_consumer.utils import make_tensor_proto + + +logger = logging.getLogger('redis_consumer.grpc_clients') + + +dtype_to_number = { + i.name: i.number for i in DESCRIPTOR.enum_types_by_name['DataType'].values +} + +# TODO: build this dynamically +number_to_dtype_value = { + 1: 'float_val', + 2: 'double_val', + 3: 'int_val', + 4: 'int_val', + 5: 'int_val', + 6: 'int_val', + 7: 'string_val', + 8: 'scomplex_val', + 9: 'int64_val', + 10: 'bool_val', + 18: 'dcomplex_val', + 19: 'half_val', + 20: 'resource_handle_val' +} + + +def grpc_response_to_dict(grpc_response): + # TODO: 'unicode' object has no attribute 'ListFields' + # response_dict = dict_to_protobuf.protobuf_to_dict(grpc_response) + # return response_dict + grpc_response_dict = dict() + + for k in grpc_response.outputs: + shape = [x.size for x in grpc_response.outputs[k].tensor_shape.dim] + + dtype_constant = grpc_response.outputs[k].dtype + + if dtype_constant not in number_to_dtype_value: + grpc_response_dict[k] = 'value not found' + logger.error('Tensor output data type not supported. ' + 'Returning empty dict.') + + dt = number_to_dtype_value[dtype_constant] + if shape == [1]: + grpc_response_dict[k] = eval( + 'grpc_response.outputs[k].' + dt)[0] + else: + grpc_response_dict[k] = np.array( + eval('grpc_response.outputs[k].' + dt)).reshape(shape) + + return grpc_response_dict + + +def make_tensor_proto(data, dtype): + tensor_proto = TensorProto() + + if isinstance(dtype, six.string_types): + dtype = dtype_to_number[dtype] + + dim = [{'size': 1}] + values = [data] + + if hasattr(data, 'shape'): + dim = [{'size': dim} for dim in data.shape] + values = list(data.reshape(-1)) + + tensor_proto_dict = { + 'dtype': dtype, + 'tensor_shape': { + 'dim': dim + }, + number_to_dtype_value[dtype]: values + } + dict_to_protobuf.dict_to_protobuf(tensor_proto_dict, tensor_proto) + + return tensor_proto class GrpcClient(object): @@ -93,6 +169,9 @@ class PredictClient(GrpcClient): def __init__(self, host, model_name, model_version): super(PredictClient, self).__init__(host) + self.logger = logging.getLogger('gRPC:{}:{}'.format( + model_name, model_version + )) self.model_name = model_name self.model_version = model_version @@ -101,11 +180,16 @@ def __init__(self, host, model_name, model_version): PredictRequest: 'Predict', } + # Retry-able gRPC status codes + self.retry_status_codes = { + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.UNAVAILABLE + } + def _retry_grpc(self, request, request_timeout): request_name = request.__class__.__name__ - self.logger.info('Sending %s to %s model %s:%s.', - request_name, self.host, - self.model_name, self.model_version) + self.logger.info('Sending %s to %s.', request_name, self.host) true_failures, count = 0, 0 @@ -122,8 +206,9 @@ def _retry_grpc(self, request, request_timeout): api_call = getattr(stub, api_endpoint_name) response = api_call(request, timeout=request_timeout) - self.logger.debug('%s finished in %s seconds.', - request_name, timeit.default_timer() - t) + self.logger.debug('%s finished in %s seconds (%s retries).', + request_name, timeit.default_timer() - t, + true_failures) return response except grpc.RpcError as err: @@ -133,7 +218,7 @@ def _retry_grpc(self, request, request_timeout): '%s', request_name, count, err) raise err - if err.code() in settings.GRPC_RETRY_STATUSES: + if err.code() in self.retry_status_codes: count += 1 is_true_failure = err.code() != grpc.StatusCode.UNAVAILABLE true_failures += int(is_true_failure) @@ -162,38 +247,24 @@ def _retry_grpc(self, request, request_timeout): raise err def predict(self, request_data, request_timeout=10): - self.logger.info('Sending PredictRequest to %s model %s:%s.', - self.host, self.model_name, self.model_version) - - t = timeit.default_timer() - request = PredictRequest() - self.logger.debug('Created PredictRequest object in %s seconds.', - timeit.default_timer() - t) - # pylint: disable=E1101 + request = PredictRequest() request.model_spec.name = self.model_name request.model_spec.version.value = self.model_version - t = timeit.default_timer() for d in request_data: tensor_proto = make_tensor_proto(d['data'], d['in_tensor_dtype']) request.inputs[d['in_tensor_name']].CopyFrom(tensor_proto) - self.logger.debug('Made tensor protos in %s seconds.', - timeit.default_timer() - t) - response = self._retry_grpc(request, request_timeout) response_dict = grpc_response_to_dict(response) - self.logger.info('Got PredictResponse with keys: %s ', + self.logger.info('Got PredictResponse with keys: %s.', list(response_dict)) return response_dict def get_model_metadata(self, request_timeout=10): - self.logger.info('Sending GetModelMetadataRequest to %s model %s:%s.', - self.host, self.model_name, self.model_version) - # pylint: disable=E1101 request = GetModelMetadataRequest() request.metadata_field.append('signature_def') @@ -201,88 +272,112 @@ def get_model_metadata(self, request_timeout=10): request.model_spec.version.value = self.model_version response = self._retry_grpc(request, request_timeout) + response_dict = json.loads(MessageToJson(response)) + return response_dict - t = timeit.default_timer() - response_dict = json.loads(MessageToJson(response)) +class GrpcModelWrapper(object): + """A wrapper class that mocks a Keras model using a gRPC client. - self.logger.debug('gRPC GetModelMetadataProtobufConversion took ' - '%s seconds.', timeit.default_timer() - t) + https://github.com/vanvalenlab/deepcell-tf/blob/master/deepcell/applications + """ - return response_dict + def __init__(self, client, model_metadata): + self._client = client + self._metadata = model_metadata + shapes = [ + tuple([int(x) for x in m['in_tensor_shape'].split(',')]) + for m in self._metadata + ] + if len(shapes) == 1: + shapes = shapes[0] + self.input_shape = shapes -class TrackingClient(PredictClient): - """gRPC Client for tensorflow-serving API. + def send_grpc(self, img): + """Use the TensorFlow Serving gRPC API for model inference on an image. - Arguments: - host: string, the hostname and port of the server (`localhost:8080`) - model_name: string, name of model served by tensorflow-serving - model_version: integer, version of the named model - """ + Args: + img (numpy.array): The image to send to the model - def __init__(self, host, model_name, model_version, - redis_hash, progress_callback): - self.redis_hash = redis_hash - self.progress_callback = progress_callback - super(TrackingClient, self).__init__(host, model_name, model_version) + Returns: + numpy.array: The results of model inference. + """ + start = timeit.default_timer() - def predict(self, data, request_timeout=10): - t = timeit.default_timer() - self.logger.info('Tracking data with %s model %s:%s.', - self.host, self.model_name, self.model_version) - - # TODO: features should be retrieved from model metadata - features = {'appearance', 'distance', 'neighborhood', 'regionprop'} - features = sorted(features) - # Grab a random value from the data dict and select batch dim (num cell comparisons) - batch_size = next(iter(data.values())).shape[0] - self.logger.info('batch size: %s', batch_size) - results = [] - # from 0 to Num Cells (comparisons) in increments of TF batch size - for b in range(0, batch_size, settings.TF_MAX_BATCH_SIZE): - request_data = [] - for f in features: - input_name1 = '{}_input1'.format(f) - d1 = { - 'in_tensor_name': input_name1, - 'in_tensor_dtype': 'DT_FLOAT', - 'data': data[input_name1][b:b + settings.TF_MAX_BATCH_SIZE] - } - request_data.append(d1) - - input_name2 = '{}_input2'.format(f) - d2 = { - 'in_tensor_name': input_name2, - 'in_tensor_dtype': 'DT_FLOAT', - 'data': data[input_name2][b:b + settings.TF_MAX_BATCH_SIZE] - } - request_data.append(d2) - - response_dict = super(TrackingClient, self).predict( - request_data, request_timeout) - - output = [response_dict[k] for k in sorted(response_dict.keys())] - if len(output) == 1: - output = output[0] + # cast input as list + if not isinstance(img, list): + img = [img] - if len(results) == 0: - results = output - else: - results = np.vstack((results, output)) + if len(self._metadata) != len(img): + raise ValueError('Expected {} model inputs but got {}.'.format( + len(self._metadata), len(img))) + + req_data = [] + + for i, m in enumerate(self._metadata): + data = img[i] - self.logger.info('Tracked %s input pairs in %s seconds.', - batch_size, timeit.default_timer() - t) + if m['in_tensor_dtype'] == 'DT_HALF': + # seems like should cast to "half" + # but the model rejects the type, wants "int" or "long" + data = data.astype('int') - return np.array(results) + req_data.append({ + 'in_tensor_name': m['in_tensor_name'], + 'in_tensor_dtype': m['in_tensor_dtype'], + 'data': data + }) - def progress(self, progress): - """Update the internal state regarding progress + prediction = self._client.predict(req_data, settings.GRPC_TIMEOUT) + results = [prediction[k] for k in sorted(prediction.keys())] - Arguments: - progress: float, the progress in the interval [0, 1] + self._client.logger.debug('Got prediction results of shape %s in %s s.', + [r.shape for r in results], + timeit.default_timer() - start) + + return results + + def get_batch_size(self): + """Calculate the best batch size based on TF_MAX_BATCH_SIZE and + TF_MIN_MODEL_SIZE """ - progress *= 100 - # clamp to an integer between 0 and 100 - progress = min(100, max(0, round(progress))) - self.progress_callback(self.redis_hash, progress) + input_shape = self.input_shape + if not isinstance(input_shape, list): + input_shape = [input_shape] + + ratio = 1 + for shape in input_shape: + rank = len(shape) + ratio *= (shape[rank - 3] / settings.TF_MIN_MODEL_SIZE) * \ + (shape[rank - 2] / settings.TF_MIN_MODEL_SIZE) * \ + (shape[rank - 1]) + + batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) + return batch_size + + def predict(self, tiles, batch_size=None): + # TODO: Can the result size be known beforehand via model metadata? + results = [] + + if isinstance(tiles, dict): + tiles = [tiles[m['in_tensor_name']] for m in self._metadata] + + if not isinstance(tiles, list): + tiles = [tiles] + + if batch_size is None: + batch_size = self.get_batch_size() + + for t in range(0, tiles[0].shape[0], batch_size): + inputs = [tile[t:t + batch_size] for tile in tiles] + inputs = inputs[0] if len(inputs) == 1 else inputs + output = self.send_grpc(inputs) + + if len(results) == 0: + results = output + else: + for i, o in enumerate(output): + results[i] = np.vstack((results[i], o)) + + return results[0] if len(results) == 1 else results diff --git a/redis_consumer/grpc_clients_test.py b/redis_consumer/grpc_clients_test.py new file mode 100644 index 00000000..281fce5d --- /dev/null +++ b/redis_consumer/grpc_clients_test.py @@ -0,0 +1,190 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for gRPC Clients""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import pytest + +import numpy as np +from tensorflow.core.framework import types_pb2 +from tensorflow.core.framework.tensor_pb2 import TensorProto +from tensorflow_serving.apis.predict_pb2 import PredictResponse + +from redis_consumer.testing_utils import _get_image, make_model_metadata_of_size + +from redis_consumer import grpc_clients, settings + + +class DummyPredictClient(object): + # pylint: disable=unused-argument + def __init__(self, host, model_name, model_version): + self.logger = logging.getLogger(self.__class__.__name__) + + def predict(self, request_data, request_timeout=10): + retval = {} + for i, d in enumerate(request_data): + retval['prediction{}'.format(i)] = d.get('data') + return retval + + +def test_make_tensor_proto(): + # test with numpy array + data = _get_image(300, 300, 1) + proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + assert isinstance(proto, (TensorProto,)) + # test with value + data = 10.0 + proto = grpc_clients.make_tensor_proto(data, types_pb2.DT_FLOAT) + assert isinstance(proto, (TensorProto,)) + + +def test_grpc_response_to_dict(): + # pylint: disable=E1101 + # test valid response + data = _get_image(300, 300, 1) + tensor_proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + response = PredictResponse() + response.outputs['prediction'].CopyFrom(tensor_proto) + response_dict = grpc_clients.grpc_response_to_dict(response) + assert isinstance(response_dict, (dict,)) + np.testing.assert_allclose(response_dict['prediction'], data) + # test scalar input + data = 3 + tensor_proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + response = PredictResponse() + response.outputs['prediction'].CopyFrom(tensor_proto) + response_dict = grpc_clients.grpc_response_to_dict(response) + assert isinstance(response_dict, (dict,)) + np.testing.assert_allclose(response_dict['prediction'], data) + # test bad dtype + # logs an error, but should throw a KeyError as well. + data = _get_image(300, 300, 1) + tensor_proto = grpc_clients.make_tensor_proto(data, 'DT_FLOAT') + response = PredictResponse() + response.outputs['prediction'].CopyFrom(tensor_proto) + response.outputs['prediction'].dtype = 32 + + with pytest.raises(KeyError): + response_dict = grpc_clients.grpc_response_to_dict(response) + + +class TestGrpcModelWrapper(object): + shape = (1, 300, 300, 1) + name = 'test-model' + version = '0' + + def _get_metadata(self): + metadata_fn = make_model_metadata_of_size(self.shape) + return metadata_fn(self.name, self.version) + + def test_init(self): + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + assert wrapper.input_shape == self.shape + + metadata += metadata + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + assert isinstance(wrapper.input_shape, list) + for s in wrapper.input_shape: + assert s == self.shape + + def test_get_batch_size(self, mocker): + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + + for m in (.5, 1, 2): + mocker.patch.object(settings, 'TF_MIN_MODEL_SIZE', self.shape[1] * m) + batch_size = wrapper.get_batch_size() + assert batch_size == settings.TF_MAX_BATCH_SIZE * m * m + + def test_send_grpc(self): + client = DummyPredictClient(1, 2, 3) + metadata = self._get_metadata() + metadata[0]['in_tensor_dtype'] = 'DT_HALF' + wrapper = grpc_clients.GrpcModelWrapper(client, metadata) + + input_data = np.ones(self.shape) + result = wrapper.send_grpc(input_data) + assert isinstance(result, list) + assert len(result) == 1 + np.testing.assert_array_equal(result[0], input_data) + + # test multiple inputs + metadata = self._get_metadata() + self._get_metadata() + input_data = [np.ones(self.shape)] * 2 + wrapper = grpc_clients.GrpcModelWrapper(client, metadata) + result = wrapper.send_grpc(input_data) + assert isinstance(result, list) + assert len(result) == 2 + np.testing.assert_array_equal(result, input_data) + + # test inputs don't match metadata + with pytest.raises(ValueError): + wrapper.send_grpc(np.ones(self.shape)) + + def test_predict(self, mocker): + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + + def mock_send_grpc(img): + return img if isinstance(img, list) else [img] + + mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc) + + batch_size = 2 + input_data = np.ones((batch_size * 2, 30, 30, 1)) + + results = wrapper.predict(input_data, batch_size=batch_size) + np.testing.assert_array_equal(input_data, results) + + # no batch size + results = wrapper.predict(input_data) + np.testing.assert_array_equal(input_data, results) + + # multiple inputs + metadata = self._get_metadata() * 2 + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc) + input_data = [np.ones((batch_size * 2, 30, 30, 1))] * 2 + results = wrapper.predict(input_data, batch_size=batch_size) + np.testing.assert_array_equal(input_data, results) + + # dictionary input + metadata = self._get_metadata() + wrapper = grpc_clients.GrpcModelWrapper(None, metadata) + mocker.patch.object(wrapper, 'send_grpc', mock_send_grpc) + input_data = { + m['in_tensor_name']: np.ones((batch_size * 2, 30, 30, 1)) + for m in metadata + } + results = wrapper.predict(input_data, batch_size=batch_size) + for m in metadata: + np.testing.assert_array_equal( + input_data[m['in_tensor_name']], results) diff --git a/redis_consumer/pbs/__init__.py b/redis_consumer/pbs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/redis_consumer/pbs/attr_value_pb2.py b/redis_consumer/pbs/attr_value_pb2.py deleted file mode 100644 index 3262caf7..00000000 --- a/redis_consumer/pbs/attr_value_pb2.py +++ /dev/null @@ -1,365 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: attr_value.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_pb2 as tensor__pb2 -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='attr_value.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\017AttrValueProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x10\x61ttr_value.proto\x12\ntensorflow\x1a\x0ctensor.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\"\xa6\x04\n\tAttrValue\x12\x0b\n\x01s\x18\x02 \x01(\x0cH\x00\x12\x0b\n\x01i\x18\x03 \x01(\x03H\x00\x12\x0b\n\x01\x66\x18\x04 \x01(\x02H\x00\x12\x0b\n\x01\x62\x18\x05 \x01(\x08H\x00\x12$\n\x04type\x18\x06 \x01(\x0e\x32\x14.tensorflow.DataTypeH\x00\x12-\n\x05shape\x18\x07 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProtoH\x00\x12)\n\x06tensor\x18\x08 \x01(\x0b\x32\x17.tensorflow.TensorProtoH\x00\x12/\n\x04list\x18\x01 \x01(\x0b\x32\x1f.tensorflow.AttrValue.ListValueH\x00\x12(\n\x04\x66unc\x18\n \x01(\x0b\x32\x18.tensorflow.NameAttrListH\x00\x12\x15\n\x0bplaceholder\x18\t \x01(\tH\x00\x1a\xe9\x01\n\tListValue\x12\t\n\x01s\x18\x02 \x03(\x0c\x12\r\n\x01i\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\r\n\x01\x66\x18\x04 \x03(\x02\x42\x02\x10\x01\x12\r\n\x01\x62\x18\x05 \x03(\x08\x42\x02\x10\x01\x12&\n\x04type\x18\x06 \x03(\x0e\x32\x14.tensorflow.DataTypeB\x02\x10\x01\x12+\n\x05shape\x18\x07 \x03(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\'\n\x06tensor\x18\x08 \x03(\x0b\x32\x17.tensorflow.TensorProto\x12&\n\x04\x66unc\x18\t \x03(\x0b\x32\x18.tensorflow.NameAttrListB\x07\n\x05value\"\x92\x01\n\x0cNameAttrList\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x30\n\x04\x61ttr\x18\x02 \x03(\x0b\x32\".tensorflow.NameAttrList.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x42o\n\x18org.tensorflow.frameworkB\x0f\x41ttrValueProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[tensor__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) - - - - -_ATTRVALUE_LISTVALUE = _descriptor.Descriptor( - name='ListValue', - full_name='tensorflow.AttrValue.ListValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='s', full_name='tensorflow.AttrValue.ListValue.s', index=0, - number=2, type=12, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='i', full_name='tensorflow.AttrValue.ListValue.i', index=1, - number=3, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='f', full_name='tensorflow.AttrValue.ListValue.f', index=2, - number=4, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='b', full_name='tensorflow.AttrValue.ListValue.b', index=3, - number=5, type=8, cpp_type=7, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.AttrValue.ListValue.type', index=4, - number=6, type=14, cpp_type=8, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.AttrValue.ListValue.shape', index=5, - number=7, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor', full_name='tensorflow.AttrValue.ListValue.tensor', index=6, - number=8, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='func', full_name='tensorflow.AttrValue.ListValue.func', index=7, - number=9, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=388, - serialized_end=621, -) - -_ATTRVALUE = _descriptor.Descriptor( - name='AttrValue', - full_name='tensorflow.AttrValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='s', full_name='tensorflow.AttrValue.s', index=0, - number=2, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=b"", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='i', full_name='tensorflow.AttrValue.i', index=1, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='f', full_name='tensorflow.AttrValue.f', index=2, - number=4, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='b', full_name='tensorflow.AttrValue.b', index=3, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.AttrValue.type', index=4, - number=6, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.AttrValue.shape', index=5, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor', full_name='tensorflow.AttrValue.tensor', index=6, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='list', full_name='tensorflow.AttrValue.list', index=7, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='func', full_name='tensorflow.AttrValue.func', index=8, - number=10, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='placeholder', full_name='tensorflow.AttrValue.placeholder', index=9, - number=9, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_ATTRVALUE_LISTVALUE, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='value', full_name='tensorflow.AttrValue.value', - index=0, containing_type=None, fields=[]), - ], - serialized_start=80, - serialized_end=630, -) - - -_NAMEATTRLIST_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.NameAttrList.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.NameAttrList.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.NameAttrList.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=713, - serialized_end=779, -) - -_NAMEATTRLIST = _descriptor.Descriptor( - name='NameAttrList', - full_name='tensorflow.NameAttrList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.NameAttrList.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.NameAttrList.attr', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_NAMEATTRLIST_ATTRENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=633, - serialized_end=779, -) - -_ATTRVALUE_LISTVALUE.fields_by_name['type'].enum_type = types__pb2._DATATYPE -_ATTRVALUE_LISTVALUE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_ATTRVALUE_LISTVALUE.fields_by_name['tensor'].message_type = tensor__pb2._TENSORPROTO -_ATTRVALUE_LISTVALUE.fields_by_name['func'].message_type = _NAMEATTRLIST -_ATTRVALUE_LISTVALUE.containing_type = _ATTRVALUE -_ATTRVALUE.fields_by_name['type'].enum_type = types__pb2._DATATYPE -_ATTRVALUE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_ATTRVALUE.fields_by_name['tensor'].message_type = tensor__pb2._TENSORPROTO -_ATTRVALUE.fields_by_name['list'].message_type = _ATTRVALUE_LISTVALUE -_ATTRVALUE.fields_by_name['func'].message_type = _NAMEATTRLIST -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['s']) -_ATTRVALUE.fields_by_name['s'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['i']) -_ATTRVALUE.fields_by_name['i'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['f']) -_ATTRVALUE.fields_by_name['f'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['b']) -_ATTRVALUE.fields_by_name['b'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['type']) -_ATTRVALUE.fields_by_name['type'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['shape']) -_ATTRVALUE.fields_by_name['shape'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['tensor']) -_ATTRVALUE.fields_by_name['tensor'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['list']) -_ATTRVALUE.fields_by_name['list'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['func']) -_ATTRVALUE.fields_by_name['func'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_ATTRVALUE.oneofs_by_name['value'].fields.append( - _ATTRVALUE.fields_by_name['placeholder']) -_ATTRVALUE.fields_by_name['placeholder'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] -_NAMEATTRLIST_ATTRENTRY.fields_by_name['value'].message_type = _ATTRVALUE -_NAMEATTRLIST_ATTRENTRY.containing_type = _NAMEATTRLIST -_NAMEATTRLIST.fields_by_name['attr'].message_type = _NAMEATTRLIST_ATTRENTRY -DESCRIPTOR.message_types_by_name['AttrValue'] = _ATTRVALUE -DESCRIPTOR.message_types_by_name['NameAttrList'] = _NAMEATTRLIST -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -AttrValue = _reflection.GeneratedProtocolMessageType('AttrValue', (_message.Message,), { - - 'ListValue' : _reflection.GeneratedProtocolMessageType('ListValue', (_message.Message,), { - 'DESCRIPTOR' : _ATTRVALUE_LISTVALUE, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.AttrValue.ListValue) - }) - , - 'DESCRIPTOR' : _ATTRVALUE, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.AttrValue) - }) -_sym_db.RegisterMessage(AttrValue) -_sym_db.RegisterMessage(AttrValue.ListValue) - -NameAttrList = _reflection.GeneratedProtocolMessageType('NameAttrList', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _NAMEATTRLIST_ATTRENTRY, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NameAttrList.AttrEntry) - }) - , - 'DESCRIPTOR' : _NAMEATTRLIST, - '__module__' : 'attr_value_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NameAttrList) - }) -_sym_db.RegisterMessage(NameAttrList) -_sym_db.RegisterMessage(NameAttrList.AttrEntry) - - -DESCRIPTOR._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['i']._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['f']._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['b']._options = None -_ATTRVALUE_LISTVALUE.fields_by_name['type']._options = None -_NAMEATTRLIST_ATTRENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/attr_value_pb2_grpc.py b/redis_consumer/pbs/attr_value_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/attr_value_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/function_pb2.py b/redis_consumer/pbs/function_pb2.py deleted file mode 100644 index 8084de52..00000000 --- a/redis_consumer/pbs/function_pb2.py +++ /dev/null @@ -1,486 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: function.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 -import redis_consumer.pbs.node_def_pb2 as node__def__pb2 -import redis_consumer.pbs.op_def_pb2 as op__def__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='function.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\016FunctionProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0e\x66unction.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\x1a\x0enode_def.proto\x1a\x0cop_def.proto\"j\n\x12\x46unctionDefLibrary\x12)\n\x08\x66unction\x18\x01 \x03(\x0b\x32\x17.tensorflow.FunctionDef\x12)\n\x08gradient\x18\x02 \x03(\x0b\x32\x17.tensorflow.GradientDef\"\xb6\x05\n\x0b\x46unctionDef\x12$\n\tsignature\x18\x01 \x01(\x0b\x32\x11.tensorflow.OpDef\x12/\n\x04\x61ttr\x18\x05 \x03(\x0b\x32!.tensorflow.FunctionDef.AttrEntry\x12\x36\n\x08\x61rg_attr\x18\x07 \x03(\x0b\x32$.tensorflow.FunctionDef.ArgAttrEntry\x12%\n\x08node_def\x18\x03 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12-\n\x03ret\x18\x04 \x03(\x0b\x32 .tensorflow.FunctionDef.RetEntry\x12<\n\x0b\x63ontrol_ret\x18\x06 \x03(\x0b\x32\'.tensorflow.FunctionDef.ControlRetEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1a\x88\x01\n\x08\x41rgAttrs\x12\x38\n\x04\x61ttr\x18\x01 \x03(\x0b\x32*.tensorflow.FunctionDef.ArgAttrs.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1aP\n\x0c\x41rgAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\r\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .tensorflow.FunctionDef.ArgAttrs:\x02\x38\x01\x1a*\n\x08RetEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x31\n\x0f\x43ontrolRetEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01J\x04\x08\x02\x10\x03\";\n\x0bGradientDef\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12\x15\n\rgradient_func\x18\x02 \x01(\tBn\n\x18org.tensorflow.frameworkB\x0e\x46unctionProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[attr__value__pb2.DESCRIPTOR,node__def__pb2.DESCRIPTOR,op__def__pb2.DESCRIPTOR,]) - - - - -_FUNCTIONDEFLIBRARY = _descriptor.Descriptor( - name='FunctionDefLibrary', - full_name='tensorflow.FunctionDefLibrary', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='function', full_name='tensorflow.FunctionDefLibrary.function', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='gradient', full_name='tensorflow.FunctionDefLibrary.gradient', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=78, - serialized_end=184, -) - - -_FUNCTIONDEF_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.FunctionDef.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=493, - serialized_end=559, -) - -_FUNCTIONDEF_ARGATTRS_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=493, - serialized_end=559, -) - -_FUNCTIONDEF_ARGATTRS = _descriptor.Descriptor( - name='ArgAttrs', - full_name='tensorflow.FunctionDef.ArgAttrs', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.FunctionDef.ArgAttrs.attr', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_FUNCTIONDEF_ARGATTRS_ATTRENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=562, - serialized_end=698, -) - -_FUNCTIONDEF_ARGATTRENTRY = _descriptor.Descriptor( - name='ArgAttrEntry', - full_name='tensorflow.FunctionDef.ArgAttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.ArgAttrEntry.key', index=0, - number=1, type=13, cpp_type=3, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.ArgAttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=700, - serialized_end=780, -) - -_FUNCTIONDEF_RETENTRY = _descriptor.Descriptor( - name='RetEntry', - full_name='tensorflow.FunctionDef.RetEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.RetEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.RetEntry.value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=782, - serialized_end=824, -) - -_FUNCTIONDEF_CONTROLRETENTRY = _descriptor.Descriptor( - name='ControlRetEntry', - full_name='tensorflow.FunctionDef.ControlRetEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.FunctionDef.ControlRetEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.FunctionDef.ControlRetEntry.value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=826, - serialized_end=875, -) - -_FUNCTIONDEF = _descriptor.Descriptor( - name='FunctionDef', - full_name='tensorflow.FunctionDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='signature', full_name='tensorflow.FunctionDef.signature', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.FunctionDef.attr', index=1, - number=5, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='arg_attr', full_name='tensorflow.FunctionDef.arg_attr', index=2, - number=7, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='node_def', full_name='tensorflow.FunctionDef.node_def', index=3, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='ret', full_name='tensorflow.FunctionDef.ret', index=4, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='control_ret', full_name='tensorflow.FunctionDef.control_ret', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_FUNCTIONDEF_ATTRENTRY, _FUNCTIONDEF_ARGATTRS, _FUNCTIONDEF_ARGATTRENTRY, _FUNCTIONDEF_RETENTRY, _FUNCTIONDEF_CONTROLRETENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=187, - serialized_end=881, -) - - -_GRADIENTDEF = _descriptor.Descriptor( - name='GradientDef', - full_name='tensorflow.GradientDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='function_name', full_name='tensorflow.GradientDef.function_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='gradient_func', full_name='tensorflow.GradientDef.gradient_func', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=883, - serialized_end=942, -) - -_FUNCTIONDEFLIBRARY.fields_by_name['function'].message_type = _FUNCTIONDEF -_FUNCTIONDEFLIBRARY.fields_by_name['gradient'].message_type = _GRADIENTDEF -_FUNCTIONDEF_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE -_FUNCTIONDEF_ATTRENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_ARGATTRS_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE -_FUNCTIONDEF_ARGATTRS_ATTRENTRY.containing_type = _FUNCTIONDEF_ARGATTRS -_FUNCTIONDEF_ARGATTRS.fields_by_name['attr'].message_type = _FUNCTIONDEF_ARGATTRS_ATTRENTRY -_FUNCTIONDEF_ARGATTRS.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_ARGATTRENTRY.fields_by_name['value'].message_type = _FUNCTIONDEF_ARGATTRS -_FUNCTIONDEF_ARGATTRENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_RETENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF_CONTROLRETENTRY.containing_type = _FUNCTIONDEF -_FUNCTIONDEF.fields_by_name['signature'].message_type = op__def__pb2._OPDEF -_FUNCTIONDEF.fields_by_name['attr'].message_type = _FUNCTIONDEF_ATTRENTRY -_FUNCTIONDEF.fields_by_name['arg_attr'].message_type = _FUNCTIONDEF_ARGATTRENTRY -_FUNCTIONDEF.fields_by_name['node_def'].message_type = node__def__pb2._NODEDEF -_FUNCTIONDEF.fields_by_name['ret'].message_type = _FUNCTIONDEF_RETENTRY -_FUNCTIONDEF.fields_by_name['control_ret'].message_type = _FUNCTIONDEF_CONTROLRETENTRY -DESCRIPTOR.message_types_by_name['FunctionDefLibrary'] = _FUNCTIONDEFLIBRARY -DESCRIPTOR.message_types_by_name['FunctionDef'] = _FUNCTIONDEF -DESCRIPTOR.message_types_by_name['GradientDef'] = _GRADIENTDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -FunctionDefLibrary = _reflection.GeneratedProtocolMessageType('FunctionDefLibrary', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEFLIBRARY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDefLibrary) - }) -_sym_db.RegisterMessage(FunctionDefLibrary) - -FunctionDef = _reflection.GeneratedProtocolMessageType('FunctionDef', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_ATTRENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.AttrEntry) - }) - , - - 'ArgAttrs' : _reflection.GeneratedProtocolMessageType('ArgAttrs', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRS_ATTRENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrs.AttrEntry) - }) - , - 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRS, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrs) - }) - , - - 'ArgAttrEntry' : _reflection.GeneratedProtocolMessageType('ArgAttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrEntry) - }) - , - - 'RetEntry' : _reflection.GeneratedProtocolMessageType('RetEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_RETENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.RetEntry) - }) - , - - 'ControlRetEntry' : _reflection.GeneratedProtocolMessageType('ControlRetEntry', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONDEF_CONTROLRETENTRY, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ControlRetEntry) - }) - , - 'DESCRIPTOR' : _FUNCTIONDEF, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef) - }) -_sym_db.RegisterMessage(FunctionDef) -_sym_db.RegisterMessage(FunctionDef.AttrEntry) -_sym_db.RegisterMessage(FunctionDef.ArgAttrs) -_sym_db.RegisterMessage(FunctionDef.ArgAttrs.AttrEntry) -_sym_db.RegisterMessage(FunctionDef.ArgAttrEntry) -_sym_db.RegisterMessage(FunctionDef.RetEntry) -_sym_db.RegisterMessage(FunctionDef.ControlRetEntry) - -GradientDef = _reflection.GeneratedProtocolMessageType('GradientDef', (_message.Message,), { - 'DESCRIPTOR' : _GRADIENTDEF, - '__module__' : 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.GradientDef) - }) -_sym_db.RegisterMessage(GradientDef) - - -DESCRIPTOR._options = None -_FUNCTIONDEF_ATTRENTRY._options = None -_FUNCTIONDEF_ARGATTRS_ATTRENTRY._options = None -_FUNCTIONDEF_ARGATTRENTRY._options = None -_FUNCTIONDEF_RETENTRY._options = None -_FUNCTIONDEF_CONTROLRETENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/function_pb2_grpc.py b/redis_consumer/pbs/function_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/function_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/get_model_metadata_pb2.py b/redis_consumer/pbs/get_model_metadata_pb2.py deleted file mode 100644 index 1695ddee..00000000 --- a/redis_consumer/pbs/get_model_metadata_pb2.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: get_model_metadata.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -import redis_consumer.pbs.meta_graph_pb2 as meta__graph__pb2 -import redis_consumer.pbs.model_pb2 as model__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='get_model_metadata.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x18get_model_metadata.proto\x12\x12tensorflow.serving\x1a\x19google/protobuf/any.proto\x1a\x10meta_graph.proto\x1a\x0bmodel.proto\"\xae\x01\n\x0fSignatureDefMap\x12L\n\rsignature_def\x18\x01 \x03(\x0b\x32\x35.tensorflow.serving.SignatureDefMap.SignatureDefEntry\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"d\n\x17GetModelMetadataRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x16\n\x0emetadata_field\x18\x02 \x03(\t\"\xe2\x01\n\x18GetModelMetadataResponse\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12L\n\x08metadata\x18\x02 \x03(\x0b\x32:.tensorflow.serving.GetModelMetadataResponse.MetadataEntry\x1a\x45\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,meta__graph__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) - - - - -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY = _descriptor.Descriptor( - name='SignatureDefEntry', - full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=204, - serialized_end=281, -) - -_SIGNATUREDEFMAP = _descriptor.Descriptor( - name='SignatureDefMap', - full_name='tensorflow.serving.SignatureDefMap', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='signature_def', full_name='tensorflow.serving.SignatureDefMap.signature_def', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SIGNATUREDEFMAP_SIGNATUREDEFENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=107, - serialized_end=281, -) - - -_GETMODELMETADATAREQUEST = _descriptor.Descriptor( - name='GetModelMetadataRequest', - full_name='tensorflow.serving.GetModelMetadataRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.GetModelMetadataRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata_field', full_name='tensorflow.serving.GetModelMetadataRequest.metadata_field', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=283, - serialized_end=383, -) - - -_GETMODELMETADATARESPONSE_METADATAENTRY = _descriptor.Descriptor( - name='MetadataEntry', - full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=543, - serialized_end=612, -) - -_GETMODELMETADATARESPONSE = _descriptor.Descriptor( - name='GetModelMetadataResponse', - full_name='tensorflow.serving.GetModelMetadataResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.GetModelMetadataResponse.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata', full_name='tensorflow.serving.GetModelMetadataResponse.metadata', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_GETMODELMETADATARESPONSE_METADATAENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=386, - serialized_end=612, -) - -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = meta__graph__pb2._SIGNATUREDEF -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.containing_type = _SIGNATUREDEFMAP -_SIGNATUREDEFMAP.fields_by_name['signature_def'].message_type = _SIGNATUREDEFMAP_SIGNATUREDEFENTRY -_GETMODELMETADATAREQUEST.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_GETMODELMETADATARESPONSE_METADATAENTRY.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_GETMODELMETADATARESPONSE_METADATAENTRY.containing_type = _GETMODELMETADATARESPONSE -_GETMODELMETADATARESPONSE.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_GETMODELMETADATARESPONSE.fields_by_name['metadata'].message_type = _GETMODELMETADATARESPONSE_METADATAENTRY -DESCRIPTOR.message_types_by_name['SignatureDefMap'] = _SIGNATUREDEFMAP -DESCRIPTOR.message_types_by_name['GetModelMetadataRequest'] = _GETMODELMETADATAREQUEST -DESCRIPTOR.message_types_by_name['GetModelMetadataResponse'] = _GETMODELMETADATARESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -SignatureDefMap = _reflection.GeneratedProtocolMessageType('SignatureDefMap', (_message.Message,), { - - 'SignatureDefEntry' : _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), { - 'DESCRIPTOR' : _SIGNATUREDEFMAP_SIGNATUREDEFENTRY, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap.SignatureDefEntry) - }) - , - 'DESCRIPTOR' : _SIGNATUREDEFMAP, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap) - }) -_sym_db.RegisterMessage(SignatureDefMap) -_sym_db.RegisterMessage(SignatureDefMap.SignatureDefEntry) - -GetModelMetadataRequest = _reflection.GeneratedProtocolMessageType('GetModelMetadataRequest', (_message.Message,), { - 'DESCRIPTOR' : _GETMODELMETADATAREQUEST, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataRequest) - }) -_sym_db.RegisterMessage(GetModelMetadataRequest) - -GetModelMetadataResponse = _reflection.GeneratedProtocolMessageType('GetModelMetadataResponse', (_message.Message,), { - - 'MetadataEntry' : _reflection.GeneratedProtocolMessageType('MetadataEntry', (_message.Message,), { - 'DESCRIPTOR' : _GETMODELMETADATARESPONSE_METADATAENTRY, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse.MetadataEntry) - }) - , - 'DESCRIPTOR' : _GETMODELMETADATARESPONSE, - '__module__' : 'get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse) - }) -_sym_db.RegisterMessage(GetModelMetadataResponse) -_sym_db.RegisterMessage(GetModelMetadataResponse.MetadataEntry) - - -DESCRIPTOR._options = None -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY._options = None -_GETMODELMETADATARESPONSE_METADATAENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/get_model_metadata_pb2_grpc.py b/redis_consumer/pbs/get_model_metadata_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/get_model_metadata_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/graph_pb2.py b/redis_consumer/pbs/graph_pb2.py deleted file mode 100644 index 78131df8..00000000 --- a/redis_consumer/pbs/graph_pb2.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: graph.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.node_def_pb2 as node__def__pb2 -import redis_consumer.pbs.function_pb2 as function__pb2 -import redis_consumer.pbs.versions_pb2 as versions__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='graph.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\013GraphProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0bgraph.proto\x12\ntensorflow\x1a\x0enode_def.proto\x1a\x0e\x66unction.proto\x1a\x0eversions.proto\"\x9d\x01\n\x08GraphDef\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12(\n\x08versions\x18\x04 \x01(\x0b\x32\x16.tensorflow.VersionDef\x12\x13\n\x07version\x18\x03 \x01(\x05\x42\x02\x18\x01\x12/\n\x07library\x18\x02 \x01(\x0b\x32\x1e.tensorflow.FunctionDefLibraryBk\n\x18org.tensorflow.frameworkB\x0bGraphProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[node__def__pb2.DESCRIPTOR,function__pb2.DESCRIPTOR,versions__pb2.DESCRIPTOR,]) - - - - -_GRAPHDEF = _descriptor.Descriptor( - name='GraphDef', - full_name='tensorflow.GraphDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='node', full_name='tensorflow.GraphDef.node', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='versions', full_name='tensorflow.GraphDef.versions', index=1, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.GraphDef.version', index=2, - number=3, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\030\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='library', full_name='tensorflow.GraphDef.library', index=3, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=76, - serialized_end=233, -) - -_GRAPHDEF.fields_by_name['node'].message_type = node__def__pb2._NODEDEF -_GRAPHDEF.fields_by_name['versions'].message_type = versions__pb2._VERSIONDEF -_GRAPHDEF.fields_by_name['library'].message_type = function__pb2._FUNCTIONDEFLIBRARY -DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), { - 'DESCRIPTOR' : _GRAPHDEF, - '__module__' : 'graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.GraphDef) - }) -_sym_db.RegisterMessage(GraphDef) - - -DESCRIPTOR._options = None -_GRAPHDEF.fields_by_name['version']._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/graph_pb2_grpc.py b/redis_consumer/pbs/graph_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/graph_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/meta_graph_pb2.py b/redis_consumer/pbs/meta_graph_pb2.py deleted file mode 100644 index c8e8eebd..00000000 --- a/redis_consumer/pbs/meta_graph_pb2.py +++ /dev/null @@ -1,1031 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: meta_graph.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -import redis_consumer.pbs.graph_pb2 as graph__pb2 -import redis_consumer.pbs.op_def_pb2 as op__def__pb2 -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 -import redis_consumer.pbs.saved_object_graph_pb2 as saved__object__graph__pb2 -import redis_consumer.pbs.saver_pb2 as saver__pb2 -import redis_consumer.pbs.struct_pb2 as struct__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='meta_graph.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\017MetaGraphProtosP\001ZHgithub.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto\370\001\001', - serialized_pb=b'\n\x10meta_graph.proto\x12\ntensorflow\x1a\x19google/protobuf/any.proto\x1a\x0bgraph.proto\x1a\x0cop_def.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\x1a\x18saved_object_graph.proto\x1a\x0bsaver.proto\x1a\x0cstruct.proto\"\xa8\x07\n\x0cMetaGraphDef\x12;\n\rmeta_info_def\x18\x01 \x01(\x0b\x32$.tensorflow.MetaGraphDef.MetaInfoDef\x12\'\n\tgraph_def\x18\x02 \x01(\x0b\x32\x14.tensorflow.GraphDef\x12\'\n\tsaver_def\x18\x03 \x01(\x0b\x32\x14.tensorflow.SaverDef\x12\x43\n\x0e\x63ollection_def\x18\x04 \x03(\x0b\x32+.tensorflow.MetaGraphDef.CollectionDefEntry\x12\x41\n\rsignature_def\x18\x05 \x03(\x0b\x32*.tensorflow.MetaGraphDef.SignatureDefEntry\x12\x30\n\x0e\x61sset_file_def\x18\x06 \x03(\x0b\x32\x18.tensorflow.AssetFileDef\x12\x36\n\x10object_graph_def\x18\x07 \x01(\x0b\x32\x1c.tensorflow.SavedObjectGraph\x1a\xf6\x02\n\x0bMetaInfoDef\x12\x1a\n\x12meta_graph_version\x18\x01 \x01(\t\x12,\n\x10stripped_op_list\x18\x02 \x01(\x0b\x32\x12.tensorflow.OpList\x12&\n\x08\x61ny_info\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x1a\n\x12tensorflow_version\x18\x05 \x01(\t\x12\x1e\n\x16tensorflow_git_version\x18\x06 \x01(\t\x12\x1e\n\x16stripped_default_attrs\x18\x07 \x01(\x08\x12S\n\x10\x66unction_aliases\x18\x08 \x03(\x0b\x32\x39.tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry\x1a\x36\n\x14\x46unctionAliasesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1aO\n\x12\x43ollectionDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.tensorflow.CollectionDef:\x02\x38\x01\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"\xdf\x03\n\rCollectionDef\x12\x37\n\tnode_list\x18\x01 \x01(\x0b\x32\".tensorflow.CollectionDef.NodeListH\x00\x12\x39\n\nbytes_list\x18\x02 \x01(\x0b\x32#.tensorflow.CollectionDef.BytesListH\x00\x12\x39\n\nint64_list\x18\x03 \x01(\x0b\x32#.tensorflow.CollectionDef.Int64ListH\x00\x12\x39\n\nfloat_list\x18\x04 \x01(\x0b\x32#.tensorflow.CollectionDef.FloatListH\x00\x12\x35\n\x08\x61ny_list\x18\x05 \x01(\x0b\x32!.tensorflow.CollectionDef.AnyListH\x00\x1a\x19\n\x08NodeList\x12\r\n\x05value\x18\x01 \x03(\t\x1a\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\x1a\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a.\n\x07\x41nyList\x12#\n\x05value\x18\x01 \x03(\x0b\x32\x14.google.protobuf.AnyB\x06\n\x04kind\"\xd1\x03\n\nTensorInfo\x12\x0e\n\x04name\x18\x01 \x01(\tH\x00\x12\x36\n\ncoo_sparse\x18\x04 \x01(\x0b\x32 .tensorflow.TensorInfo.CooSparseH\x00\x12\x42\n\x10\x63omposite_tensor\x18\x05 \x01(\x0b\x32&.tensorflow.TensorInfo.CompositeTensorH\x00\x12#\n\x05\x64type\x18\x02 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x32\n\x0ctensor_shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x1a\x65\n\tCooSparse\x12\x1a\n\x12values_tensor_name\x18\x01 \x01(\t\x12\x1b\n\x13indices_tensor_name\x18\x02 \x01(\t\x12\x1f\n\x17\x64\x65nse_shape_tensor_name\x18\x03 \x01(\t\x1ak\n\x0f\x43ompositeTensor\x12,\n\ttype_spec\x18\x01 \x01(\x0b\x32\x19.tensorflow.TypeSpecProto\x12*\n\ncomponents\x18\x02 \x03(\x0b\x32\x16.tensorflow.TensorInfoB\n\n\x08\x65ncoding\"\xa0\x02\n\x0cSignatureDef\x12\x34\n\x06inputs\x18\x01 \x03(\x0b\x32$.tensorflow.SignatureDef.InputsEntry\x12\x36\n\x07outputs\x18\x02 \x03(\x0b\x32%.tensorflow.SignatureDef.OutputsEntry\x12\x13\n\x0bmethod_name\x18\x03 \x01(\t\x1a\x45\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.TensorInfo:\x02\x38\x01\x1a\x46\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.TensorInfo:\x02\x38\x01\"M\n\x0c\x41ssetFileDef\x12+\n\x0btensor_info\x18\x01 \x01(\x0b\x32\x16.tensorflow.TensorInfo\x12\x10\n\x08\x66ilename\x18\x02 \x01(\tBz\n\x18org.tensorflow.frameworkB\x0fMetaGraphProtosP\x01ZHgithub.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto\xf8\x01\x01\x62\x06proto3' - , - dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,graph__pb2.DESCRIPTOR,op__def__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,saved__object__graph__pb2.DESCRIPTOR,saver__pb2.DESCRIPTOR,struct__pb2.DESCRIPTOR,]) - - - - -_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY = _descriptor.Descriptor( - name='FunctionAliasesEntry', - full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry.value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=895, - serialized_end=949, -) - -_METAGRAPHDEF_METAINFODEF = _descriptor.Descriptor( - name='MetaInfoDef', - full_name='tensorflow.MetaGraphDef.MetaInfoDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='meta_graph_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.meta_graph_version', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='stripped_op_list', full_name='tensorflow.MetaGraphDef.MetaInfoDef.stripped_op_list', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='any_info', full_name='tensorflow.MetaGraphDef.MetaInfoDef.any_info', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tags', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tags', index=3, - number=4, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensorflow_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tensorflow_version', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensorflow_git_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tensorflow_git_version', index=5, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='stripped_default_attrs', full_name='tensorflow.MetaGraphDef.MetaInfoDef.stripped_default_attrs', index=6, - number=7, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='function_aliases', full_name='tensorflow.MetaGraphDef.MetaInfoDef.function_aliases', index=7, - number=8, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=575, - serialized_end=949, -) - -_METAGRAPHDEF_COLLECTIONDEFENTRY = _descriptor.Descriptor( - name='CollectionDefEntry', - full_name='tensorflow.MetaGraphDef.CollectionDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.MetaGraphDef.CollectionDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.MetaGraphDef.CollectionDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=951, - serialized_end=1030, -) - -_METAGRAPHDEF_SIGNATUREDEFENTRY = _descriptor.Descriptor( - name='SignatureDefEntry', - full_name='tensorflow.MetaGraphDef.SignatureDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.MetaGraphDef.SignatureDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.MetaGraphDef.SignatureDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1032, - serialized_end=1109, -) - -_METAGRAPHDEF = _descriptor.Descriptor( - name='MetaGraphDef', - full_name='tensorflow.MetaGraphDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='meta_info_def', full_name='tensorflow.MetaGraphDef.meta_info_def', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='graph_def', full_name='tensorflow.MetaGraphDef.graph_def', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='saver_def', full_name='tensorflow.MetaGraphDef.saver_def', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='collection_def', full_name='tensorflow.MetaGraphDef.collection_def', index=3, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='signature_def', full_name='tensorflow.MetaGraphDef.signature_def', index=4, - number=5, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='asset_file_def', full_name='tensorflow.MetaGraphDef.asset_file_def', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='object_graph_def', full_name='tensorflow.MetaGraphDef.object_graph_def', index=6, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_METAGRAPHDEF_METAINFODEF, _METAGRAPHDEF_COLLECTIONDEFENTRY, _METAGRAPHDEF_SIGNATUREDEFENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=173, - serialized_end=1109, -) - - -_COLLECTIONDEF_NODELIST = _descriptor.Descriptor( - name='NodeList', - full_name='tensorflow.CollectionDef.NodeList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.NodeList.value', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1418, - serialized_end=1443, -) - -_COLLECTIONDEF_BYTESLIST = _descriptor.Descriptor( - name='BytesList', - full_name='tensorflow.CollectionDef.BytesList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.BytesList.value', index=0, - number=1, type=12, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1445, - serialized_end=1471, -) - -_COLLECTIONDEF_INT64LIST = _descriptor.Descriptor( - name='Int64List', - full_name='tensorflow.CollectionDef.Int64List', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.Int64List.value', index=0, - number=1, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1473, - serialized_end=1503, -) - -_COLLECTIONDEF_FLOATLIST = _descriptor.Descriptor( - name='FloatList', - full_name='tensorflow.CollectionDef.FloatList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.FloatList.value', index=0, - number=1, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\020\001', file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1505, - serialized_end=1535, -) - -_COLLECTIONDEF_ANYLIST = _descriptor.Descriptor( - name='AnyList', - full_name='tensorflow.CollectionDef.AnyList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.CollectionDef.AnyList.value', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1537, - serialized_end=1583, -) - -_COLLECTIONDEF = _descriptor.Descriptor( - name='CollectionDef', - full_name='tensorflow.CollectionDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='node_list', full_name='tensorflow.CollectionDef.node_list', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bytes_list', full_name='tensorflow.CollectionDef.bytes_list', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='int64_list', full_name='tensorflow.CollectionDef.int64_list', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='float_list', full_name='tensorflow.CollectionDef.float_list', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='any_list', full_name='tensorflow.CollectionDef.any_list', index=4, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_COLLECTIONDEF_NODELIST, _COLLECTIONDEF_BYTESLIST, _COLLECTIONDEF_INT64LIST, _COLLECTIONDEF_FLOATLIST, _COLLECTIONDEF_ANYLIST, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.CollectionDef.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1112, - serialized_end=1591, -) - - -_TENSORINFO_COOSPARSE = _descriptor.Descriptor( - name='CooSparse', - full_name='tensorflow.TensorInfo.CooSparse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.values_tensor_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='indices_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.indices_tensor_name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dense_shape_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.dense_shape_tensor_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1837, - serialized_end=1938, -) - -_TENSORINFO_COMPOSITETENSOR = _descriptor.Descriptor( - name='CompositeTensor', - full_name='tensorflow.TensorInfo.CompositeTensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type_spec', full_name='tensorflow.TensorInfo.CompositeTensor.type_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='components', full_name='tensorflow.TensorInfo.CompositeTensor.components', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1940, - serialized_end=2047, -) - -_TENSORINFO = _descriptor.Descriptor( - name='TensorInfo', - full_name='tensorflow.TensorInfo', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.TensorInfo.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='coo_sparse', full_name='tensorflow.TensorInfo.coo_sparse', index=1, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='composite_tensor', full_name='tensorflow.TensorInfo.composite_tensor', index=2, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.TensorInfo.dtype', index=3, - number=2, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor_shape', full_name='tensorflow.TensorInfo.tensor_shape', index=4, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_TENSORINFO_COOSPARSE, _TENSORINFO_COMPOSITETENSOR, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='encoding', full_name='tensorflow.TensorInfo.encoding', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1594, - serialized_end=2059, -) - - -_SIGNATUREDEF_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.SignatureDef.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.SignatureDef.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.SignatureDef.InputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2209, - serialized_end=2278, -) - -_SIGNATUREDEF_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.SignatureDef.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.SignatureDef.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.SignatureDef.OutputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2280, - serialized_end=2350, -) - -_SIGNATUREDEF = _descriptor.Descriptor( - name='SignatureDef', - full_name='tensorflow.SignatureDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.SignatureDef.inputs', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.SignatureDef.outputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='method_name', full_name='tensorflow.SignatureDef.method_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SIGNATUREDEF_INPUTSENTRY, _SIGNATUREDEF_OUTPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2062, - serialized_end=2350, -) - - -_ASSETFILEDEF = _descriptor.Descriptor( - name='AssetFileDef', - full_name='tensorflow.AssetFileDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tensor_info', full_name='tensorflow.AssetFileDef.tensor_info', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='filename', full_name='tensorflow.AssetFileDef.filename', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2352, - serialized_end=2429, -) - -_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY.containing_type = _METAGRAPHDEF_METAINFODEF -_METAGRAPHDEF_METAINFODEF.fields_by_name['stripped_op_list'].message_type = op__def__pb2._OPLIST -_METAGRAPHDEF_METAINFODEF.fields_by_name['any_info'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_METAGRAPHDEF_METAINFODEF.fields_by_name['function_aliases'].message_type = _METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY -_METAGRAPHDEF_METAINFODEF.containing_type = _METAGRAPHDEF -_METAGRAPHDEF_COLLECTIONDEFENTRY.fields_by_name['value'].message_type = _COLLECTIONDEF -_METAGRAPHDEF_COLLECTIONDEFENTRY.containing_type = _METAGRAPHDEF -_METAGRAPHDEF_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = _SIGNATUREDEF -_METAGRAPHDEF_SIGNATUREDEFENTRY.containing_type = _METAGRAPHDEF -_METAGRAPHDEF.fields_by_name['meta_info_def'].message_type = _METAGRAPHDEF_METAINFODEF -_METAGRAPHDEF.fields_by_name['graph_def'].message_type = graph__pb2._GRAPHDEF -_METAGRAPHDEF.fields_by_name['saver_def'].message_type = saver__pb2._SAVERDEF -_METAGRAPHDEF.fields_by_name['collection_def'].message_type = _METAGRAPHDEF_COLLECTIONDEFENTRY -_METAGRAPHDEF.fields_by_name['signature_def'].message_type = _METAGRAPHDEF_SIGNATUREDEFENTRY -_METAGRAPHDEF.fields_by_name['asset_file_def'].message_type = _ASSETFILEDEF -_METAGRAPHDEF.fields_by_name['object_graph_def'].message_type = saved__object__graph__pb2._SAVEDOBJECTGRAPH -_COLLECTIONDEF_NODELIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_BYTESLIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_INT64LIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_FLOATLIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF_ANYLIST.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_COLLECTIONDEF_ANYLIST.containing_type = _COLLECTIONDEF -_COLLECTIONDEF.fields_by_name['node_list'].message_type = _COLLECTIONDEF_NODELIST -_COLLECTIONDEF.fields_by_name['bytes_list'].message_type = _COLLECTIONDEF_BYTESLIST -_COLLECTIONDEF.fields_by_name['int64_list'].message_type = _COLLECTIONDEF_INT64LIST -_COLLECTIONDEF.fields_by_name['float_list'].message_type = _COLLECTIONDEF_FLOATLIST -_COLLECTIONDEF.fields_by_name['any_list'].message_type = _COLLECTIONDEF_ANYLIST -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['node_list']) -_COLLECTIONDEF.fields_by_name['node_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['bytes_list']) -_COLLECTIONDEF.fields_by_name['bytes_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['int64_list']) -_COLLECTIONDEF.fields_by_name['int64_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['float_list']) -_COLLECTIONDEF.fields_by_name['float_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( - _COLLECTIONDEF.fields_by_name['any_list']) -_COLLECTIONDEF.fields_by_name['any_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] -_TENSORINFO_COOSPARSE.containing_type = _TENSORINFO -_TENSORINFO_COMPOSITETENSOR.fields_by_name['type_spec'].message_type = struct__pb2._TYPESPECPROTO -_TENSORINFO_COMPOSITETENSOR.fields_by_name['components'].message_type = _TENSORINFO -_TENSORINFO_COMPOSITETENSOR.containing_type = _TENSORINFO -_TENSORINFO.fields_by_name['coo_sparse'].message_type = _TENSORINFO_COOSPARSE -_TENSORINFO.fields_by_name['composite_tensor'].message_type = _TENSORINFO_COMPOSITETENSOR -_TENSORINFO.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE -_TENSORINFO.fields_by_name['tensor_shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_TENSORINFO.oneofs_by_name['encoding'].fields.append( - _TENSORINFO.fields_by_name['name']) -_TENSORINFO.fields_by_name['name'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] -_TENSORINFO.oneofs_by_name['encoding'].fields.append( - _TENSORINFO.fields_by_name['coo_sparse']) -_TENSORINFO.fields_by_name['coo_sparse'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] -_TENSORINFO.oneofs_by_name['encoding'].fields.append( - _TENSORINFO.fields_by_name['composite_tensor']) -_TENSORINFO.fields_by_name['composite_tensor'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] -_SIGNATUREDEF_INPUTSENTRY.fields_by_name['value'].message_type = _TENSORINFO -_SIGNATUREDEF_INPUTSENTRY.containing_type = _SIGNATUREDEF -_SIGNATUREDEF_OUTPUTSENTRY.fields_by_name['value'].message_type = _TENSORINFO -_SIGNATUREDEF_OUTPUTSENTRY.containing_type = _SIGNATUREDEF -_SIGNATUREDEF.fields_by_name['inputs'].message_type = _SIGNATUREDEF_INPUTSENTRY -_SIGNATUREDEF.fields_by_name['outputs'].message_type = _SIGNATUREDEF_OUTPUTSENTRY -_ASSETFILEDEF.fields_by_name['tensor_info'].message_type = _TENSORINFO -DESCRIPTOR.message_types_by_name['MetaGraphDef'] = _METAGRAPHDEF -DESCRIPTOR.message_types_by_name['CollectionDef'] = _COLLECTIONDEF -DESCRIPTOR.message_types_by_name['TensorInfo'] = _TENSORINFO -DESCRIPTOR.message_types_by_name['SignatureDef'] = _SIGNATUREDEF -DESCRIPTOR.message_types_by_name['AssetFileDef'] = _ASSETFILEDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -MetaGraphDef = _reflection.GeneratedProtocolMessageType('MetaGraphDef', (_message.Message,), { - - 'MetaInfoDef' : _reflection.GeneratedProtocolMessageType('MetaInfoDef', (_message.Message,), { - - 'FunctionAliasesEntry' : _reflection.GeneratedProtocolMessageType('FunctionAliasesEntry', (_message.Message,), { - 'DESCRIPTOR' : _METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry) - }) - , - 'DESCRIPTOR' : _METAGRAPHDEF_METAINFODEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.MetaInfoDef) - }) - , - - 'CollectionDefEntry' : _reflection.GeneratedProtocolMessageType('CollectionDefEntry', (_message.Message,), { - 'DESCRIPTOR' : _METAGRAPHDEF_COLLECTIONDEFENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.CollectionDefEntry) - }) - , - - 'SignatureDefEntry' : _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), { - 'DESCRIPTOR' : _METAGRAPHDEF_SIGNATUREDEFENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.SignatureDefEntry) - }) - , - 'DESCRIPTOR' : _METAGRAPHDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef) - }) -_sym_db.RegisterMessage(MetaGraphDef) -_sym_db.RegisterMessage(MetaGraphDef.MetaInfoDef) -_sym_db.RegisterMessage(MetaGraphDef.MetaInfoDef.FunctionAliasesEntry) -_sym_db.RegisterMessage(MetaGraphDef.CollectionDefEntry) -_sym_db.RegisterMessage(MetaGraphDef.SignatureDefEntry) - -CollectionDef = _reflection.GeneratedProtocolMessageType('CollectionDef', (_message.Message,), { - - 'NodeList' : _reflection.GeneratedProtocolMessageType('NodeList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_NODELIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.NodeList) - }) - , - - 'BytesList' : _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_BYTESLIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.BytesList) - }) - , - - 'Int64List' : _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_INT64LIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.Int64List) - }) - , - - 'FloatList' : _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_FLOATLIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.FloatList) - }) - , - - 'AnyList' : _reflection.GeneratedProtocolMessageType('AnyList', (_message.Message,), { - 'DESCRIPTOR' : _COLLECTIONDEF_ANYLIST, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.AnyList) - }) - , - 'DESCRIPTOR' : _COLLECTIONDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef) - }) -_sym_db.RegisterMessage(CollectionDef) -_sym_db.RegisterMessage(CollectionDef.NodeList) -_sym_db.RegisterMessage(CollectionDef.BytesList) -_sym_db.RegisterMessage(CollectionDef.Int64List) -_sym_db.RegisterMessage(CollectionDef.FloatList) -_sym_db.RegisterMessage(CollectionDef.AnyList) - -TensorInfo = _reflection.GeneratedProtocolMessageType('TensorInfo', (_message.Message,), { - - 'CooSparse' : _reflection.GeneratedProtocolMessageType('CooSparse', (_message.Message,), { - 'DESCRIPTOR' : _TENSORINFO_COOSPARSE, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo.CooSparse) - }) - , - - 'CompositeTensor' : _reflection.GeneratedProtocolMessageType('CompositeTensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSORINFO_COMPOSITETENSOR, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo.CompositeTensor) - }) - , - 'DESCRIPTOR' : _TENSORINFO, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo) - }) -_sym_db.RegisterMessage(TensorInfo) -_sym_db.RegisterMessage(TensorInfo.CooSparse) -_sym_db.RegisterMessage(TensorInfo.CompositeTensor) - -SignatureDef = _reflection.GeneratedProtocolMessageType('SignatureDef', (_message.Message,), { - - 'InputsEntry' : _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _SIGNATUREDEF_INPUTSENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef.InputsEntry) - }) - , - - 'OutputsEntry' : _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _SIGNATUREDEF_OUTPUTSENTRY, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef.OutputsEntry) - }) - , - 'DESCRIPTOR' : _SIGNATUREDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef) - }) -_sym_db.RegisterMessage(SignatureDef) -_sym_db.RegisterMessage(SignatureDef.InputsEntry) -_sym_db.RegisterMessage(SignatureDef.OutputsEntry) - -AssetFileDef = _reflection.GeneratedProtocolMessageType('AssetFileDef', (_message.Message,), { - 'DESCRIPTOR' : _ASSETFILEDEF, - '__module__' : 'meta_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.AssetFileDef) - }) -_sym_db.RegisterMessage(AssetFileDef) - - -DESCRIPTOR._options = None -_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY._options = None -_METAGRAPHDEF_COLLECTIONDEFENTRY._options = None -_METAGRAPHDEF_SIGNATUREDEFENTRY._options = None -_COLLECTIONDEF_INT64LIST.fields_by_name['value']._options = None -_COLLECTIONDEF_FLOATLIST.fields_by_name['value']._options = None -_SIGNATUREDEF_INPUTSENTRY._options = None -_SIGNATUREDEF_OUTPUTSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/meta_graph_pb2_grpc.py b/redis_consumer/pbs/meta_graph_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/meta_graph_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/model_pb2.py b/redis_consumer/pbs/model_pb2.py deleted file mode 100644 index 8bddc4ca..00000000 --- a/redis_consumer/pbs/model_pb2.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: model.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='model.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x0bmodel.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto\"\x8c\x01\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64ValueH\x00\x12\x17\n\rversion_label\x18\x04 \x01(\tH\x00\x12\x16\n\x0esignature_name\x18\x03 \x01(\tB\x10\n\x0eversion_choiceB\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[google_dot_protobuf_dot_wrappers__pb2.DESCRIPTOR,]) - - - - -_MODELSPEC = _descriptor.Descriptor( - name='ModelSpec', - full_name='tensorflow.serving.ModelSpec', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.serving.ModelSpec.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.serving.ModelSpec.version', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version_label', full_name='tensorflow.serving.ModelSpec.version_label', index=2, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='signature_name', full_name='tensorflow.serving.ModelSpec.signature_name', index=3, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='version_choice', full_name='tensorflow.serving.ModelSpec.version_choice', - index=0, containing_type=None, fields=[]), - ], - serialized_start=68, - serialized_end=208, -) - -_MODELSPEC.fields_by_name['version'].message_type = google_dot_protobuf_dot_wrappers__pb2._INT64VALUE -_MODELSPEC.oneofs_by_name['version_choice'].fields.append( - _MODELSPEC.fields_by_name['version']) -_MODELSPEC.fields_by_name['version'].containing_oneof = _MODELSPEC.oneofs_by_name['version_choice'] -_MODELSPEC.oneofs_by_name['version_choice'].fields.append( - _MODELSPEC.fields_by_name['version_label']) -_MODELSPEC.fields_by_name['version_label'].containing_oneof = _MODELSPEC.oneofs_by_name['version_choice'] -DESCRIPTOR.message_types_by_name['ModelSpec'] = _MODELSPEC -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -ModelSpec = _reflection.GeneratedProtocolMessageType('ModelSpec', (_message.Message,), { - 'DESCRIPTOR' : _MODELSPEC, - '__module__' : 'model_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ModelSpec) - }) -_sym_db.RegisterMessage(ModelSpec) - - -DESCRIPTOR._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/model_pb2_grpc.py b/redis_consumer/pbs/model_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/model_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/node_def_pb2.py b/redis_consumer/pbs/node_def_pb2.py deleted file mode 100644 index b2864a5b..00000000 --- a/redis_consumer/pbs/node_def_pb2.py +++ /dev/null @@ -1,202 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: node_def.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='node_def.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\tNodeProtoP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0enode_def.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\"\xd2\x02\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12+\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1d.tensorflow.NodeDef.AttrEntry\x12J\n\x17\x65xperimental_debug_info\x18\x06 \x01(\x0b\x32).tensorflow.NodeDef.ExperimentalDebugInfo\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1aQ\n\x15\x45xperimentalDebugInfo\x12\x1b\n\x13original_node_names\x18\x01 \x03(\t\x12\x1b\n\x13original_func_names\x18\x02 \x03(\tBi\n\x18org.tensorflow.frameworkB\tNodeProtoP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[attr__value__pb2.DESCRIPTOR,]) - - - - -_NODEDEF_ATTRENTRY = _descriptor.Descriptor( - name='AttrEntry', - full_name='tensorflow.NodeDef.AttrEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.NodeDef.AttrEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.NodeDef.AttrEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=238, - serialized_end=304, -) - -_NODEDEF_EXPERIMENTALDEBUGINFO = _descriptor.Descriptor( - name='ExperimentalDebugInfo', - full_name='tensorflow.NodeDef.ExperimentalDebugInfo', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='original_node_names', full_name='tensorflow.NodeDef.ExperimentalDebugInfo.original_node_names', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='original_func_names', full_name='tensorflow.NodeDef.ExperimentalDebugInfo.original_func_names', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=306, - serialized_end=387, -) - -_NODEDEF = _descriptor.Descriptor( - name='NodeDef', - full_name='tensorflow.NodeDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.NodeDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='op', full_name='tensorflow.NodeDef.op', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input', full_name='tensorflow.NodeDef.input', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.NodeDef.device', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.NodeDef.attr', index=4, - number=5, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='experimental_debug_info', full_name='tensorflow.NodeDef.experimental_debug_info', index=5, - number=6, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_NODEDEF_ATTRENTRY, _NODEDEF_EXPERIMENTALDEBUGINFO, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=49, - serialized_end=387, -) - -_NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE -_NODEDEF_ATTRENTRY.containing_type = _NODEDEF -_NODEDEF_EXPERIMENTALDEBUGINFO.containing_type = _NODEDEF -_NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY -_NODEDEF.fields_by_name['experimental_debug_info'].message_type = _NODEDEF_EXPERIMENTALDEBUGINFO -DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), { - - 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { - 'DESCRIPTOR' : _NODEDEF_ATTRENTRY, - '__module__' : 'node_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.AttrEntry) - }) - , - - 'ExperimentalDebugInfo' : _reflection.GeneratedProtocolMessageType('ExperimentalDebugInfo', (_message.Message,), { - 'DESCRIPTOR' : _NODEDEF_EXPERIMENTALDEBUGINFO, - '__module__' : 'node_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.ExperimentalDebugInfo) - }) - , - 'DESCRIPTOR' : _NODEDEF, - '__module__' : 'node_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.NodeDef) - }) -_sym_db.RegisterMessage(NodeDef) -_sym_db.RegisterMessage(NodeDef.AttrEntry) -_sym_db.RegisterMessage(NodeDef.ExperimentalDebugInfo) - - -DESCRIPTOR._options = None -_NODEDEF_ATTRENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/node_def_pb2_grpc.py b/redis_consumer/pbs/node_def_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/node_def_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/op_def_pb2.py b/redis_consumer/pbs/op_def_pb2.py deleted file mode 100644 index 1a150703..00000000 --- a/redis_consumer/pbs/op_def_pb2.py +++ /dev/null @@ -1,404 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: op_def.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='op_def.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\013OpDefProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x0cop_def.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\x1a\x0btypes.proto\"\xd0\x05\n\x05OpDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\tinput_arg\x18\x02 \x03(\x0b\x32\x18.tensorflow.OpDef.ArgDef\x12,\n\noutput_arg\x18\x03 \x03(\x0b\x32\x18.tensorflow.OpDef.ArgDef\x12\x16\n\x0e\x63ontrol_output\x18\x14 \x03(\t\x12\'\n\x04\x61ttr\x18\x04 \x03(\x0b\x32\x19.tensorflow.OpDef.AttrDef\x12.\n\x0b\x64\x65precation\x18\x08 \x01(\x0b\x32\x19.tensorflow.OpDeprecation\x12\x0f\n\x07summary\x18\x05 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x06 \x01(\t\x12\x16\n\x0eis_commutative\x18\x12 \x01(\x08\x12\x14\n\x0cis_aggregate\x18\x10 \x01(\x08\x12\x13\n\x0bis_stateful\x18\x11 \x01(\x08\x12\"\n\x1a\x61llows_uninitialized_input\x18\x13 \x01(\x08\x1a\x9f\x01\n\x06\x41rgDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\"\n\x04type\x18\x03 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x11\n\ttype_attr\x18\x04 \x01(\t\x12\x13\n\x0bnumber_attr\x18\x05 \x01(\t\x12\x16\n\x0etype_list_attr\x18\x06 \x01(\t\x12\x0e\n\x06is_ref\x18\x10 \x01(\x08\x1a\xbd\x01\n\x07\x41ttrDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12,\n\rdefault_value\x18\x03 \x01(\x0b\x32\x15.tensorflow.AttrValue\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x13\n\x0bhas_minimum\x18\x05 \x01(\x08\x12\x0f\n\x07minimum\x18\x06 \x01(\x03\x12-\n\x0e\x61llowed_values\x18\x07 \x01(\x0b\x32\x15.tensorflow.AttrValue\"5\n\rOpDeprecation\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x13\n\x0b\x65xplanation\x18\x02 \x01(\t\"\'\n\x06OpList\x12\x1d\n\x02op\x18\x01 \x03(\x0b\x32\x11.tensorflow.OpDefBk\n\x18org.tensorflow.frameworkB\x0bOpDefProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[attr__value__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) - - - - -_OPDEF_ARGDEF = _descriptor.Descriptor( - name='ArgDef', - full_name='tensorflow.OpDef.ArgDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.OpDef.ArgDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='description', full_name='tensorflow.OpDef.ArgDef.description', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.OpDef.ArgDef.type', index=2, - number=3, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type_attr', full_name='tensorflow.OpDef.ArgDef.type_attr', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='number_attr', full_name='tensorflow.OpDef.ArgDef.number_attr', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type_list_attr', full_name='tensorflow.OpDef.ArgDef.type_list_attr', index=5, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_ref', full_name='tensorflow.OpDef.ArgDef.is_ref', index=6, - number=16, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=429, - serialized_end=588, -) - -_OPDEF_ATTRDEF = _descriptor.Descriptor( - name='AttrDef', - full_name='tensorflow.OpDef.AttrDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.OpDef.AttrDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.OpDef.AttrDef.type', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='default_value', full_name='tensorflow.OpDef.AttrDef.default_value', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='description', full_name='tensorflow.OpDef.AttrDef.description', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='has_minimum', full_name='tensorflow.OpDef.AttrDef.has_minimum', index=4, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='minimum', full_name='tensorflow.OpDef.AttrDef.minimum', index=5, - number=6, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='allowed_values', full_name='tensorflow.OpDef.AttrDef.allowed_values', index=6, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=591, - serialized_end=780, -) - -_OPDEF = _descriptor.Descriptor( - name='OpDef', - full_name='tensorflow.OpDef', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.OpDef.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input_arg', full_name='tensorflow.OpDef.input_arg', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_arg', full_name='tensorflow.OpDef.output_arg', index=2, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='control_output', full_name='tensorflow.OpDef.control_output', index=3, - number=20, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='attr', full_name='tensorflow.OpDef.attr', index=4, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='deprecation', full_name='tensorflow.OpDef.deprecation', index=5, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='summary', full_name='tensorflow.OpDef.summary', index=6, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='description', full_name='tensorflow.OpDef.description', index=7, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_commutative', full_name='tensorflow.OpDef.is_commutative', index=8, - number=18, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_aggregate', full_name='tensorflow.OpDef.is_aggregate', index=9, - number=16, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_stateful', full_name='tensorflow.OpDef.is_stateful', index=10, - number=17, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='allows_uninitialized_input', full_name='tensorflow.OpDef.allows_uninitialized_input', index=11, - number=19, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_OPDEF_ARGDEF, _OPDEF_ATTRDEF, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=60, - serialized_end=780, -) - - -_OPDEPRECATION = _descriptor.Descriptor( - name='OpDeprecation', - full_name='tensorflow.OpDeprecation', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.OpDeprecation.version', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='explanation', full_name='tensorflow.OpDeprecation.explanation', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=782, - serialized_end=835, -) - - -_OPLIST = _descriptor.Descriptor( - name='OpList', - full_name='tensorflow.OpList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='op', full_name='tensorflow.OpList.op', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=837, - serialized_end=876, -) - -_OPDEF_ARGDEF.fields_by_name['type'].enum_type = types__pb2._DATATYPE -_OPDEF_ARGDEF.containing_type = _OPDEF -_OPDEF_ATTRDEF.fields_by_name['default_value'].message_type = attr__value__pb2._ATTRVALUE -_OPDEF_ATTRDEF.fields_by_name['allowed_values'].message_type = attr__value__pb2._ATTRVALUE -_OPDEF_ATTRDEF.containing_type = _OPDEF -_OPDEF.fields_by_name['input_arg'].message_type = _OPDEF_ARGDEF -_OPDEF.fields_by_name['output_arg'].message_type = _OPDEF_ARGDEF -_OPDEF.fields_by_name['attr'].message_type = _OPDEF_ATTRDEF -_OPDEF.fields_by_name['deprecation'].message_type = _OPDEPRECATION -_OPLIST.fields_by_name['op'].message_type = _OPDEF -DESCRIPTOR.message_types_by_name['OpDef'] = _OPDEF -DESCRIPTOR.message_types_by_name['OpDeprecation'] = _OPDEPRECATION -DESCRIPTOR.message_types_by_name['OpList'] = _OPLIST -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -OpDef = _reflection.GeneratedProtocolMessageType('OpDef', (_message.Message,), { - - 'ArgDef' : _reflection.GeneratedProtocolMessageType('ArgDef', (_message.Message,), { - 'DESCRIPTOR' : _OPDEF_ARGDEF, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDef.ArgDef) - }) - , - - 'AttrDef' : _reflection.GeneratedProtocolMessageType('AttrDef', (_message.Message,), { - 'DESCRIPTOR' : _OPDEF_ATTRDEF, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDef.AttrDef) - }) - , - 'DESCRIPTOR' : _OPDEF, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDef) - }) -_sym_db.RegisterMessage(OpDef) -_sym_db.RegisterMessage(OpDef.ArgDef) -_sym_db.RegisterMessage(OpDef.AttrDef) - -OpDeprecation = _reflection.GeneratedProtocolMessageType('OpDeprecation', (_message.Message,), { - 'DESCRIPTOR' : _OPDEPRECATION, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpDeprecation) - }) -_sym_db.RegisterMessage(OpDeprecation) - -OpList = _reflection.GeneratedProtocolMessageType('OpList', (_message.Message,), { - 'DESCRIPTOR' : _OPLIST, - '__module__' : 'op_def_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.OpList) - }) -_sym_db.RegisterMessage(OpList) - - -DESCRIPTOR._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/op_def_pb2_grpc.py b/redis_consumer/pbs/op_def_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/op_def_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/predict_pb2.py b/redis_consumer/pbs/predict_pb2.py deleted file mode 100644 index 8eb45853..00000000 --- a/redis_consumer/pbs/predict_pb2.py +++ /dev/null @@ -1,232 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: predict.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_pb2 as tensor__pb2 -import redis_consumer.pbs.model_pb2 as model__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='predict.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\rpredict.proto\x12\x12tensorflow.serving\x1a\x0ctensor.proto\x1a\x0bmodel.proto\"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\xd0\x01\n\x0fPredictResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[tensor__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) - - - - -_PREDICTREQUEST_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.serving.PredictRequest.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.PredictRequest.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.PredictRequest.InputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=221, - serialized_end=291, -) - -_PREDICTREQUEST = _descriptor.Descriptor( - name='PredictRequest', - full_name='tensorflow.serving.PredictRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.PredictRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.serving.PredictRequest.inputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_filter', full_name='tensorflow.serving.PredictRequest.output_filter', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PREDICTREQUEST_INPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=65, - serialized_end=291, -) - - -_PREDICTRESPONSE_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.serving.PredictResponse.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.PredictResponse.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.PredictResponse.OutputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=431, - serialized_end=502, -) - -_PREDICTRESPONSE = _descriptor.Descriptor( - name='PredictResponse', - full_name='tensorflow.serving.PredictResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.PredictResponse.model_spec', index=0, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.serving.PredictResponse.outputs', index=1, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PREDICTRESPONSE_OUTPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=294, - serialized_end=502, -) - -_PREDICTREQUEST_INPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO -_PREDICTREQUEST_INPUTSENTRY.containing_type = _PREDICTREQUEST -_PREDICTREQUEST.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_PREDICTREQUEST.fields_by_name['inputs'].message_type = _PREDICTREQUEST_INPUTSENTRY -_PREDICTRESPONSE_OUTPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO -_PREDICTRESPONSE_OUTPUTSENTRY.containing_type = _PREDICTRESPONSE -_PREDICTRESPONSE.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC -_PREDICTRESPONSE.fields_by_name['outputs'].message_type = _PREDICTRESPONSE_OUTPUTSENTRY -DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST -DESCRIPTOR.message_types_by_name['PredictResponse'] = _PREDICTRESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { - - 'InputsEntry' : _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTREQUEST_INPUTSENTRY, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest.InputsEntry) - }) - , - 'DESCRIPTOR' : _PREDICTREQUEST, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest) - }) -_sym_db.RegisterMessage(PredictRequest) -_sym_db.RegisterMessage(PredictRequest.InputsEntry) - -PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), { - - 'OutputsEntry' : _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), { - 'DESCRIPTOR' : _PREDICTRESPONSE_OUTPUTSENTRY, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse.OutputsEntry) - }) - , - 'DESCRIPTOR' : _PREDICTRESPONSE, - '__module__' : 'predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse) - }) -_sym_db.RegisterMessage(PredictResponse) -_sym_db.RegisterMessage(PredictResponse.OutputsEntry) - - -DESCRIPTOR._options = None -_PREDICTREQUEST_INPUTSENTRY._options = None -_PREDICTRESPONSE_OUTPUTSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/predict_pb2_grpc.py b/redis_consumer/pbs/predict_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/predict_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/prediction_service_pb2.py b/redis_consumer/pbs/prediction_service_pb2.py deleted file mode 100644 index 7e88b957..00000000 --- a/redis_consumer/pbs/prediction_service_pb2.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: prediction_service.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.get_model_metadata_pb2 as get__model__metadata__pb2 -import redis_consumer.pbs.predict_pb2 as predict__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='prediction_service.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x18prediction_service.proto\x12\x12tensorflow.serving\x1a\x18get_model_metadata.proto\x1a\rpredict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[get__model__metadata__pb2.DESCRIPTOR,predict__pb2.DESCRIPTOR,]) - - - -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - - -DESCRIPTOR._options = None - -_PREDICTIONSERVICE = _descriptor.ServiceDescriptor( - name='PredictionService', - full_name='tensorflow.serving.PredictionService', - file=DESCRIPTOR, - index=0, - serialized_options=None, - serialized_start=90, - serialized_end=304, - methods=[ - _descriptor.MethodDescriptor( - name='Predict', - full_name='tensorflow.serving.PredictionService.Predict', - index=0, - containing_service=None, - input_type=predict__pb2._PREDICTREQUEST, - output_type=predict__pb2._PREDICTRESPONSE, - serialized_options=None, - ), - _descriptor.MethodDescriptor( - name='GetModelMetadata', - full_name='tensorflow.serving.PredictionService.GetModelMetadata', - index=1, - containing_service=None, - input_type=get__model__metadata__pb2._GETMODELMETADATAREQUEST, - output_type=get__model__metadata__pb2._GETMODELMETADATARESPONSE, - serialized_options=None, - ), -]) -_sym_db.RegisterServiceDescriptor(_PREDICTIONSERVICE) - -DESCRIPTOR.services_by_name['PredictionService'] = _PREDICTIONSERVICE - -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/prediction_service_pb2_grpc.py b/redis_consumer/pbs/prediction_service_pb2_grpc.py deleted file mode 100644 index 76a23c73..00000000 --- a/redis_consumer/pbs/prediction_service_pb2_grpc.py +++ /dev/null @@ -1,78 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - -import redis_consumer.pbs.get_model_metadata_pb2 as get__model__metadata__pb2 -import redis_consumer.pbs.predict_pb2 as predict__pb2 - - -class PredictionServiceStub(object): - """open source marker; do not remove - PredictionService provides access to machine-learned models loaded by - model_servers. - Classify. - rpc Classify(ClassificationRequest) returns (ClassificationResponse); - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Predict = channel.unary_unary( - '/tensorflow.serving.PredictionService/Predict', - request_serializer=predict__pb2.PredictRequest.SerializeToString, - response_deserializer=predict__pb2.PredictResponse.FromString, - ) - self.GetModelMetadata = channel.unary_unary( - '/tensorflow.serving.PredictionService/GetModelMetadata', - request_serializer=get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, - response_deserializer=get__model__metadata__pb2.GetModelMetadataResponse.FromString, - ) - - -class PredictionServiceServicer(object): - """open source marker; do not remove - PredictionService provides access to machine-learned models loaded by - model_servers. - Classify. - rpc Classify(ClassificationRequest) returns (ClassificationResponse); - """ - - def Predict(self, request, context): - """Regress. - rpc Regress(RegressionRequest) returns (RegressionResponse); - - Predict -- provides access to loaded TensorFlow model. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetModelMetadata(self, request, context): - """MultiInference API for multi-headed models. - rpc MultiInference(MultiInferenceRequest) returns (MultiInferenceResponse); - - GetModelMetadata - provides access to metadata for loaded models. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_PredictionServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Predict': grpc.unary_unary_rpc_method_handler( - servicer.Predict, - request_deserializer=predict__pb2.PredictRequest.FromString, - response_serializer=predict__pb2.PredictResponse.SerializeToString, - ), - 'GetModelMetadata': grpc.unary_unary_rpc_method_handler( - servicer.GetModelMetadata, - request_deserializer=get__model__metadata__pb2.GetModelMetadataRequest.FromString, - response_serializer=get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'tensorflow.serving.PredictionService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) diff --git a/redis_consumer/pbs/resource_handle_pb2.py b/redis_consumer/pbs/resource_handle_pb2.py deleted file mode 100644 index e149683e..00000000 --- a/redis_consumer/pbs/resource_handle_pb2.py +++ /dev/null @@ -1,156 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: resource_handle.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='resource_handle.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\030org.tensorflow.frameworkB\016ResourceHandleP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', - serialized_pb=b'\n\x15resource_handle.proto\x12\ntensorflow\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\"\x9f\x02\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\t\x12H\n\x11\x64types_and_shapes\x18\x06 \x03(\x0b\x32-.tensorflow.ResourceHandleProto.DtypeAndShape\x1a\x61\n\rDtypeAndShape\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProtoBn\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' - , - dependencies=[tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) - - - - -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE = _descriptor.Descriptor( - name='DtypeAndShape', - full_name='tensorflow.ResourceHandleProto.DtypeAndShape', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.ResourceHandleProto.DtypeAndShape.dtype', index=0, - number=1, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.ResourceHandleProto.DtypeAndShape.shape', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=261, - serialized_end=358, -) - -_RESOURCEHANDLEPROTO = _descriptor.Descriptor( - name='ResourceHandleProto', - full_name='tensorflow.ResourceHandleProto', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.ResourceHandleProto.device', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='container', full_name='tensorflow.ResourceHandleProto.container', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.ResourceHandleProto.name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='hash_code', full_name='tensorflow.ResourceHandleProto.hash_code', index=3, - number=4, type=4, cpp_type=4, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='maybe_type_name', full_name='tensorflow.ResourceHandleProto.maybe_type_name', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dtypes_and_shapes', full_name='tensorflow.ResourceHandleProto.dtypes_and_shapes', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_RESOURCEHANDLEPROTO_DTYPEANDSHAPE, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=71, - serialized_end=358, -) - -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.containing_type = _RESOURCEHANDLEPROTO -_RESOURCEHANDLEPROTO.fields_by_name['dtypes_and_shapes'].message_type = _RESOURCEHANDLEPROTO_DTYPEANDSHAPE -DESCRIPTOR.message_types_by_name['ResourceHandleProto'] = _RESOURCEHANDLEPROTO -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -ResourceHandleProto = _reflection.GeneratedProtocolMessageType('ResourceHandleProto', (_message.Message,), { - - 'DtypeAndShape' : _reflection.GeneratedProtocolMessageType('DtypeAndShape', (_message.Message,), { - 'DESCRIPTOR' : _RESOURCEHANDLEPROTO_DTYPEANDSHAPE, - '__module__' : 'resource_handle_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandleProto.DtypeAndShape) - }) - , - 'DESCRIPTOR' : _RESOURCEHANDLEPROTO, - '__module__' : 'resource_handle_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandleProto) - }) -_sym_db.RegisterMessage(ResourceHandleProto) -_sym_db.RegisterMessage(ResourceHandleProto.DtypeAndShape) - - -DESCRIPTOR._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/resource_handle_pb2_grpc.py b/redis_consumer/pbs/resource_handle_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/resource_handle_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/saved_object_graph_pb2.py b/redis_consumer/pbs/saved_object_graph_pb2.py deleted file mode 100644 index 7cd9de9e..00000000 --- a/redis_consumer/pbs/saved_object_graph_pb2.py +++ /dev/null @@ -1,720 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: saved_object_graph.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.trackable_object_graph_pb2 as trackable__object__graph__pb2 -import redis_consumer.pbs.struct_pb2 as struct__pb2 -import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 -import redis_consumer.pbs.types_pb2 as types__pb2 -import redis_consumer.pbs.versions_pb2 as versions__pb2 -import redis_consumer.pbs.variable_pb2 as variable__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='saved_object_graph.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\370\001\001', - serialized_pb=b'\n\x18saved_object_graph.proto\x12\ntensorflow\x1a\x1ctrackable_object_graph.proto\x1a\x0cstruct.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\x1a\x0eversions.proto\x1a\x0evariable.proto\"\xe8\x01\n\x10SavedObjectGraph\x12&\n\x05nodes\x18\x01 \x03(\x0b\x32\x17.tensorflow.SavedObject\x12O\n\x12\x63oncrete_functions\x18\x02 \x03(\x0b\x32\x33.tensorflow.SavedObjectGraph.ConcreteFunctionsEntry\x1a[\n\x16\x43oncreteFunctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x30\n\x05value\x18\x02 \x01(\x0b\x32!.tensorflow.SavedConcreteFunction:\x02\x38\x01\"\xbd\x04\n\x0bSavedObject\x12R\n\x08\x63hildren\x18\x01 \x03(\x0b\x32@.tensorflow.TrackableObjectGraph.TrackableObject.ObjectReference\x12^\n\x0eslot_variables\x18\x03 \x03(\x0b\x32\x46.tensorflow.TrackableObjectGraph.TrackableObject.SlotVariableReference\x12\x32\n\x0buser_object\x18\x04 \x01(\x0b\x32\x1b.tensorflow.SavedUserObjectH\x00\x12\'\n\x05\x61sset\x18\x05 \x01(\x0b\x32\x16.tensorflow.SavedAssetH\x00\x12-\n\x08\x66unction\x18\x06 \x01(\x0b\x32\x19.tensorflow.SavedFunctionH\x00\x12-\n\x08variable\x18\x07 \x01(\x0b\x32\x19.tensorflow.SavedVariableH\x00\x12G\n\x16\x62\x61re_concrete_function\x18\x08 \x01(\x0b\x32%.tensorflow.SavedBareConcreteFunctionH\x00\x12-\n\x08\x63onstant\x18\t \x01(\x0b\x32\x19.tensorflow.SavedConstantH\x00\x12-\n\x08resource\x18\n \x01(\x0b\x32\x19.tensorflow.SavedResourceH\x00\x42\x06\n\x04kindJ\x04\x08\x02\x10\x03R\nattributes\"`\n\x0fSavedUserObject\x12\x12\n\nidentifier\x18\x01 \x01(\t\x12\'\n\x07version\x18\x02 \x01(\x0b\x32\x16.tensorflow.VersionDef\x12\x10\n\x08metadata\x18\x03 \x01(\t\"*\n\nSavedAsset\x12\x1c\n\x14\x61sset_file_def_index\x18\x01 \x01(\x05\"\\\n\rSavedFunction\x12\x1a\n\x12\x63oncrete_functions\x18\x01 \x03(\t\x12/\n\rfunction_spec\x18\x02 \x01(\x0b\x32\x18.tensorflow.FunctionSpec\"\xa8\x01\n\x15SavedConcreteFunction\x12\x14\n\x0c\x62ound_inputs\x18\x02 \x03(\x05\x12\x42\n\x1d\x63\x61nonicalized_input_signature\x18\x03 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x35\n\x10output_signature\x18\x04 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\"|\n\x19SavedBareConcreteFunction\x12\x1e\n\x16\x63oncrete_function_name\x18\x01 \x01(\t\x12\x19\n\x11\x61rgument_keywords\x18\x02 \x03(\t\x12$\n\x1c\x61llowed_positional_arguments\x18\x03 \x01(\x03\"\"\n\rSavedConstant\x12\x11\n\toperation\x18\x01 \x01(\t\"\xf6\x01\n\rSavedVariable\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\x11\n\ttrainable\x18\x03 \x01(\x08\x12<\n\x0fsynchronization\x18\x04 \x01(\x0e\x32#.tensorflow.VariableSynchronization\x12\x34\n\x0b\x61ggregation\x18\x05 \x01(\x0e\x32\x1f.tensorflow.VariableAggregation\x12\x0c\n\x04name\x18\x06 \x01(\t\"\x95\x01\n\x0c\x46unctionSpec\x12\x30\n\x0b\x66ullargspec\x18\x01 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x11\n\tis_method\x18\x02 \x01(\x08\x12\x34\n\x0finput_signature\x18\x05 \x01(\x0b\x32\x1b.tensorflow.StructuredValueJ\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05\"\x1f\n\rSavedResource\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3' - , - dependencies=[trackable__object__graph__pb2.DESCRIPTOR,struct__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,versions__pb2.DESCRIPTOR,variable__pb2.DESCRIPTOR,]) - - - - -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY = _descriptor.Descriptor( - name='ConcreteFunctionsEntry', - full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=291, - serialized_end=382, -) - -_SAVEDOBJECTGRAPH = _descriptor.Descriptor( - name='SavedObjectGraph', - full_name='tensorflow.SavedObjectGraph', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='nodes', full_name='tensorflow.SavedObjectGraph.nodes', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='concrete_functions', full_name='tensorflow.SavedObjectGraph.concrete_functions', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=150, - serialized_end=382, -) - - -_SAVEDOBJECT = _descriptor.Descriptor( - name='SavedObject', - full_name='tensorflow.SavedObject', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='children', full_name='tensorflow.SavedObject.children', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='slot_variables', full_name='tensorflow.SavedObject.slot_variables', index=1, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='user_object', full_name='tensorflow.SavedObject.user_object', index=2, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='asset', full_name='tensorflow.SavedObject.asset', index=3, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='function', full_name='tensorflow.SavedObject.function', index=4, - number=6, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='variable', full_name='tensorflow.SavedObject.variable', index=5, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bare_concrete_function', full_name='tensorflow.SavedObject.bare_concrete_function', index=6, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='constant', full_name='tensorflow.SavedObject.constant', index=7, - number=9, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='resource', full_name='tensorflow.SavedObject.resource', index=8, - number=10, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.SavedObject.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=385, - serialized_end=958, -) - - -_SAVEDUSEROBJECT = _descriptor.Descriptor( - name='SavedUserObject', - full_name='tensorflow.SavedUserObject', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='identifier', full_name='tensorflow.SavedUserObject.identifier', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.SavedUserObject.version', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata', full_name='tensorflow.SavedUserObject.metadata', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=960, - serialized_end=1056, -) - - -_SAVEDASSET = _descriptor.Descriptor( - name='SavedAsset', - full_name='tensorflow.SavedAsset', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='asset_file_def_index', full_name='tensorflow.SavedAsset.asset_file_def_index', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1058, - serialized_end=1100, -) - - -_SAVEDFUNCTION = _descriptor.Descriptor( - name='SavedFunction', - full_name='tensorflow.SavedFunction', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='concrete_functions', full_name='tensorflow.SavedFunction.concrete_functions', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='function_spec', full_name='tensorflow.SavedFunction.function_spec', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1102, - serialized_end=1194, -) - - -_SAVEDCONCRETEFUNCTION = _descriptor.Descriptor( - name='SavedConcreteFunction', - full_name='tensorflow.SavedConcreteFunction', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='bound_inputs', full_name='tensorflow.SavedConcreteFunction.bound_inputs', index=0, - number=2, type=5, cpp_type=1, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='canonicalized_input_signature', full_name='tensorflow.SavedConcreteFunction.canonicalized_input_signature', index=1, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_signature', full_name='tensorflow.SavedConcreteFunction.output_signature', index=2, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1197, - serialized_end=1365, -) - - -_SAVEDBARECONCRETEFUNCTION = _descriptor.Descriptor( - name='SavedBareConcreteFunction', - full_name='tensorflow.SavedBareConcreteFunction', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='concrete_function_name', full_name='tensorflow.SavedBareConcreteFunction.concrete_function_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='argument_keywords', full_name='tensorflow.SavedBareConcreteFunction.argument_keywords', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='allowed_positional_arguments', full_name='tensorflow.SavedBareConcreteFunction.allowed_positional_arguments', index=2, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1367, - serialized_end=1491, -) - - -_SAVEDCONSTANT = _descriptor.Descriptor( - name='SavedConstant', - full_name='tensorflow.SavedConstant', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='operation', full_name='tensorflow.SavedConstant.operation', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1493, - serialized_end=1527, -) - - -_SAVEDVARIABLE = _descriptor.Descriptor( - name='SavedVariable', - full_name='tensorflow.SavedVariable', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.SavedVariable.dtype', index=0, - number=1, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.SavedVariable.shape', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='trainable', full_name='tensorflow.SavedVariable.trainable', index=2, - number=3, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='synchronization', full_name='tensorflow.SavedVariable.synchronization', index=3, - number=4, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='aggregation', full_name='tensorflow.SavedVariable.aggregation', index=4, - number=5, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.SavedVariable.name', index=5, - number=6, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1530, - serialized_end=1776, -) - - -_FUNCTIONSPEC = _descriptor.Descriptor( - name='FunctionSpec', - full_name='tensorflow.FunctionSpec', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='fullargspec', full_name='tensorflow.FunctionSpec.fullargspec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_method', full_name='tensorflow.FunctionSpec.is_method', index=1, - number=2, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input_signature', full_name='tensorflow.FunctionSpec.input_signature', index=2, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1779, - serialized_end=1928, -) - - -_SAVEDRESOURCE = _descriptor.Descriptor( - name='SavedResource', - full_name='tensorflow.SavedResource', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.SavedResource.device', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1930, - serialized_end=1961, -) - -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY.fields_by_name['value'].message_type = _SAVEDCONCRETEFUNCTION -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY.containing_type = _SAVEDOBJECTGRAPH -_SAVEDOBJECTGRAPH.fields_by_name['nodes'].message_type = _SAVEDOBJECT -_SAVEDOBJECTGRAPH.fields_by_name['concrete_functions'].message_type = _SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY -_SAVEDOBJECT.fields_by_name['children'].message_type = trackable__object__graph__pb2._TRACKABLEOBJECTGRAPH_TRACKABLEOBJECT_OBJECTREFERENCE -_SAVEDOBJECT.fields_by_name['slot_variables'].message_type = trackable__object__graph__pb2._TRACKABLEOBJECTGRAPH_TRACKABLEOBJECT_SLOTVARIABLEREFERENCE -_SAVEDOBJECT.fields_by_name['user_object'].message_type = _SAVEDUSEROBJECT -_SAVEDOBJECT.fields_by_name['asset'].message_type = _SAVEDASSET -_SAVEDOBJECT.fields_by_name['function'].message_type = _SAVEDFUNCTION -_SAVEDOBJECT.fields_by_name['variable'].message_type = _SAVEDVARIABLE -_SAVEDOBJECT.fields_by_name['bare_concrete_function'].message_type = _SAVEDBARECONCRETEFUNCTION -_SAVEDOBJECT.fields_by_name['constant'].message_type = _SAVEDCONSTANT -_SAVEDOBJECT.fields_by_name['resource'].message_type = _SAVEDRESOURCE -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['user_object']) -_SAVEDOBJECT.fields_by_name['user_object'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['asset']) -_SAVEDOBJECT.fields_by_name['asset'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['function']) -_SAVEDOBJECT.fields_by_name['function'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['variable']) -_SAVEDOBJECT.fields_by_name['variable'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['bare_concrete_function']) -_SAVEDOBJECT.fields_by_name['bare_concrete_function'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['constant']) -_SAVEDOBJECT.fields_by_name['constant'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( - _SAVEDOBJECT.fields_by_name['resource']) -_SAVEDOBJECT.fields_by_name['resource'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] -_SAVEDUSEROBJECT.fields_by_name['version'].message_type = versions__pb2._VERSIONDEF -_SAVEDFUNCTION.fields_by_name['function_spec'].message_type = _FUNCTIONSPEC -_SAVEDCONCRETEFUNCTION.fields_by_name['canonicalized_input_signature'].message_type = struct__pb2._STRUCTUREDVALUE -_SAVEDCONCRETEFUNCTION.fields_by_name['output_signature'].message_type = struct__pb2._STRUCTUREDVALUE -_SAVEDVARIABLE.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE -_SAVEDVARIABLE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO -_SAVEDVARIABLE.fields_by_name['synchronization'].enum_type = variable__pb2._VARIABLESYNCHRONIZATION -_SAVEDVARIABLE.fields_by_name['aggregation'].enum_type = variable__pb2._VARIABLEAGGREGATION -_FUNCTIONSPEC.fields_by_name['fullargspec'].message_type = struct__pb2._STRUCTUREDVALUE -_FUNCTIONSPEC.fields_by_name['input_signature'].message_type = struct__pb2._STRUCTUREDVALUE -DESCRIPTOR.message_types_by_name['SavedObjectGraph'] = _SAVEDOBJECTGRAPH -DESCRIPTOR.message_types_by_name['SavedObject'] = _SAVEDOBJECT -DESCRIPTOR.message_types_by_name['SavedUserObject'] = _SAVEDUSEROBJECT -DESCRIPTOR.message_types_by_name['SavedAsset'] = _SAVEDASSET -DESCRIPTOR.message_types_by_name['SavedFunction'] = _SAVEDFUNCTION -DESCRIPTOR.message_types_by_name['SavedConcreteFunction'] = _SAVEDCONCRETEFUNCTION -DESCRIPTOR.message_types_by_name['SavedBareConcreteFunction'] = _SAVEDBARECONCRETEFUNCTION -DESCRIPTOR.message_types_by_name['SavedConstant'] = _SAVEDCONSTANT -DESCRIPTOR.message_types_by_name['SavedVariable'] = _SAVEDVARIABLE -DESCRIPTOR.message_types_by_name['FunctionSpec'] = _FUNCTIONSPEC -DESCRIPTOR.message_types_by_name['SavedResource'] = _SAVEDRESOURCE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -SavedObjectGraph = _reflection.GeneratedProtocolMessageType('SavedObjectGraph', (_message.Message,), { - - 'ConcreteFunctionsEntry' : _reflection.GeneratedProtocolMessageType('ConcreteFunctionsEntry', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedObjectGraph.ConcreteFunctionsEntry) - }) - , - 'DESCRIPTOR' : _SAVEDOBJECTGRAPH, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedObjectGraph) - }) -_sym_db.RegisterMessage(SavedObjectGraph) -_sym_db.RegisterMessage(SavedObjectGraph.ConcreteFunctionsEntry) - -SavedObject = _reflection.GeneratedProtocolMessageType('SavedObject', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDOBJECT, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedObject) - }) -_sym_db.RegisterMessage(SavedObject) - -SavedUserObject = _reflection.GeneratedProtocolMessageType('SavedUserObject', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDUSEROBJECT, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedUserObject) - }) -_sym_db.RegisterMessage(SavedUserObject) - -SavedAsset = _reflection.GeneratedProtocolMessageType('SavedAsset', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDASSET, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedAsset) - }) -_sym_db.RegisterMessage(SavedAsset) - -SavedFunction = _reflection.GeneratedProtocolMessageType('SavedFunction', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDFUNCTION, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedFunction) - }) -_sym_db.RegisterMessage(SavedFunction) - -SavedConcreteFunction = _reflection.GeneratedProtocolMessageType('SavedConcreteFunction', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDCONCRETEFUNCTION, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedConcreteFunction) - }) -_sym_db.RegisterMessage(SavedConcreteFunction) - -SavedBareConcreteFunction = _reflection.GeneratedProtocolMessageType('SavedBareConcreteFunction', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDBARECONCRETEFUNCTION, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedBareConcreteFunction) - }) -_sym_db.RegisterMessage(SavedBareConcreteFunction) - -SavedConstant = _reflection.GeneratedProtocolMessageType('SavedConstant', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDCONSTANT, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedConstant) - }) -_sym_db.RegisterMessage(SavedConstant) - -SavedVariable = _reflection.GeneratedProtocolMessageType('SavedVariable', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDVARIABLE, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedVariable) - }) -_sym_db.RegisterMessage(SavedVariable) - -FunctionSpec = _reflection.GeneratedProtocolMessageType('FunctionSpec', (_message.Message,), { - 'DESCRIPTOR' : _FUNCTIONSPEC, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.FunctionSpec) - }) -_sym_db.RegisterMessage(FunctionSpec) - -SavedResource = _reflection.GeneratedProtocolMessageType('SavedResource', (_message.Message,), { - 'DESCRIPTOR' : _SAVEDRESOURCE, - '__module__' : 'saved_object_graph_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.SavedResource) - }) -_sym_db.RegisterMessage(SavedResource) - - -DESCRIPTOR._options = None -_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/saved_object_graph_pb2_grpc.py b/redis_consumer/pbs/saved_object_graph_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/saved_object_graph_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/saver_pb2.py b/redis_consumer/pbs/saver_pb2.py deleted file mode 100644 index bcf329e4..00000000 --- a/redis_consumer/pbs/saver_pb2.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: saver.proto - -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='saver.proto', - package='tensorflow', - syntax='proto3', - serialized_options=b'\n\023org.tensorflow.utilB\013SaverProtosP\001Z 2: - x.append(_x) - y.append(_y) - - x = np.stack(x, axis=0) # expand to 3D - y = np.stack(y, axis=0) # expand to 3D - - x = np.expand_dims(x, axis=channel_axis) - y = np.expand_dims(y, axis=channel_axis) - - return x.astype('float32'), y.astype('int32') - - -class DummyModel(object): # pylint: disable=useless-object-inheritance - - def predict(self, data): - # Grab a random value from the data dict and select batch dim - batches = 0 if not data else next(iter(data.values())).shape[0] - - return np.random.random((batches, 3)) - - def progress(self, n): - return n - - -class TestTracking(object): - - def test__track_cells(self): - length = 128 - frames = 5 - track_length = 2 - - features = ['appearance', 'neighborhood', 'regionprop', 'distance'] - - # TODO: Fix for channels_first - for data_format in ('channels_last',): # 'channels_first'): - - x, y = _get_dummy_tracking_data( - length, frames=frames, data_format=data_format) - - tracker = tracking.CellTracker( - x, y, - model=DummyModel(), - track_length=track_length, - data_format=data_format, - features=features) - - tracker.track_cells() - - # test tracker.dataframe - df = tracker.dataframe(cell_type='test-value') - assert 'cell_type' in df.columns - - # test incorrect values in tracker.dataframe - with pytest.raises(ValueError): - tracker.dataframe(bad_value=-1) diff --git a/redis_consumer/utils.py b/redis_consumer/utils.py index 01192bf2..c24b0cf2 100644 --- a/redis_consumer/utils.py +++ b/redis_consumer/utils.py @@ -32,127 +32,23 @@ import os import time import timeit -import contextlib import hashlib import logging -import shutil import tarfile -import tempfile import zipfile -import six - -import skimage -from skimage.external import tifffile import numpy as np -import keras_preprocessing.image -import dict_to_protobuf import PIL - -from redis_consumer.pbs.types_pb2 import DESCRIPTOR -from redis_consumer.pbs.tensor_pb2 import TensorProto -from redis_consumer.pbs.tensor_shape_pb2 import TensorShapeProto -from redis_consumer import settings +from skimage.external import tifffile +from tensorflow.keras.preprocessing.image import img_to_array logger = logging.getLogger('redis_consumer.utils') -dtype_to_number = { - i.name: i.number for i in DESCRIPTOR.enum_types_by_name['DataType'].values -} - -# TODO: build this dynamically -number_to_dtype_value = { - 1: 'float_val', - 2: 'double_val', - 3: 'int_val', - 4: 'int_val', - 5: 'int_val', - 6: 'int_val', - 7: 'string_val', - 8: 'scomplex_val', - 9: 'int64_val', - 10: 'bool_val', - 18: 'dcomplex_val', - 19: 'half_val', - 20: 'resource_handle_val' -} - - -def grpc_response_to_dict(grpc_response): - # TODO: 'unicode' object has no attribute 'ListFields' - # response_dict = dict_to_protobuf.protobuf_to_dict(grpc_response) - # return response_dict - grpc_response_dict = dict() - - for k in grpc_response.outputs: - shape = [x.size for x in grpc_response.outputs[k].tensor_shape.dim] - - dtype_constant = grpc_response.outputs[k].dtype - - if dtype_constant not in number_to_dtype_value: - grpc_response_dict[k] = 'value not found' - logger.error('Tensor output data type not supported. ' - 'Returning empty dict.') - - dt = number_to_dtype_value[dtype_constant] - if shape == [1]: - grpc_response_dict[k] = eval( - 'grpc_response.outputs[k].' + dt)[0] - else: - grpc_response_dict[k] = np.array( - eval('grpc_response.outputs[k].' + dt)).reshape(shape) - - return grpc_response_dict - - -def make_tensor_proto(data, dtype): - tensor_proto = TensorProto() - - if isinstance(dtype, six.string_types): - dtype = dtype_to_number[dtype] - - dim = [{'size': 1}] - values = [data] - - if hasattr(data, 'shape'): - dim = [{'size': dim} for dim in data.shape] - values = list(data.reshape(-1)) - - tensor_proto_dict = { - 'dtype': dtype, - 'tensor_shape': { - 'dim': dim - }, - number_to_dtype_value[dtype]: values - } - dict_to_protobuf.dict_to_protobuf(tensor_proto_dict, tensor_proto) - - return tensor_proto - - -# Workaround for python2 not supporting `with tempfile.TemporaryDirectory() as` -# These are unnecessary if not supporting python2 -@contextlib.contextmanager -def cd(newdir, cleanup=lambda: True): - prevdir = os.getcwd() - os.chdir(os.path.expanduser(newdir)) - try: - yield - finally: - os.chdir(prevdir) - cleanup() - - -@contextlib.contextmanager -def get_tempdir(): - dirpath = tempfile.mkdtemp() - - def cleanup(): - return shutil.rmtree(dirpath) - with cd(dirpath, cleanup): - yield dirpath +def strip_bucket_path(path): + """Remove leading/trailing '/'s from cloud bucket folder names""" + return '/'.join(y for y in path.split('/') if y) def iter_image_archive(zip_path, destination): @@ -166,9 +62,7 @@ def iter_image_archive(zip_path, destination): Iterator of all image paths in extracted archive """ archive = zipfile.ZipFile(zip_path, 'r', allowZip64=True) - - def is_valid(x): - return os.path.splitext(x)[1] and '__MACOSX' not in x + is_valid = lambda x: os.path.splitext(x)[1] and '__MACOSX' not in x for info in archive.infolist(): extracted = archive.extract(info, path=destination) if os.path.isfile(extracted): @@ -209,54 +103,13 @@ def get_image(filepath): # tiff files should not have a channel dim img = np.expand_dims(img, axis=-1) else: - img = keras_preprocessing.image.img_to_array(PIL.Image.open(filepath)) + img = img_to_array(PIL.Image.open(filepath)) logger.debug('Loaded %s into numpy array with shape %s', filepath, img.shape) return img.astype('float32') -def pad_image(image, field): - """Pad each the input image for proper dimensions when stitiching. - - Args: - image: np.array of image data - field: receptive field size of model - - Returns: - image data padded in the x and y axes - """ - window = (field - 1) // 2 - # Pad images by the field size in the x and y axes - pad_width = [] - for i in range(len(image.shape)): - if i == image.ndim - 3: - pad_width.append((window, window)) - elif i == image.ndim - 2: - pad_width.append((window, window)) - else: - pad_width.append((0, 0)) - - return np.pad(image, pad_width, mode='reflect') - - -def unpad_image(x, pad_width): - """Unpad image padded with the pad_width. - - Args: - image (numpy.array): Image to unpad. - pad_width (list): List of pads used to pad the image with np.pad. - - Returns: - numpy.array: The unpadded image. - """ - slices = [] - for c in pad_width: - e = None if c[1] == 0 else -c[1] - slices.append(slice(c[0], e)) - return x[tuple(slices)] - - def save_numpy_array(arr, name='', subdir='', output_dir=None): """Split tensor into channels and save each as a tiff file. @@ -368,127 +221,9 @@ def zip_files(files, dest=None, prefix=None): name = f.replace(dest, '') name = name[1:] if name.startswith(os.path.sep) else name zf.write(f, arcname=name) - logger.debug('Saved %s files to %s', len(files), filepath) + logger.debug('Saved %s files to %s in %s seconds.', + len(files), filepath, timeit.default_timer() - start) except Exception as err: logger.error('Failed to write zipfile: %s', err) raise err - logger.debug('Zipped %s files into %s in %s seconds.', - len(files), filepath, timeit.default_timer() - start) return filepath - - -def reshape_matrix(X, y, reshape_size=256, is_channels_first=False): - """ - Reshape matrix of dimension 4 to have x and y of size reshape_size. - Adds overlapping slices to batches. - E.g. reshape_size of 256 yields (1, 1024, 1024, 1) -> (16, 256, 256, 1) - - Args: - X: raw 4D image tensor - y: label mask of 4D image data - reshape_size: size of the square output tensor - is_channels_first: default False for channel dimension last - - Returns: - reshaped `X` and `y` tensors in shape (`reshape_size`, `reshape_size`) - """ - if X.ndim != 4: - raise ValueError('reshape_matrix expects X dim to be 4, got', X.ndim) - elif y.ndim != 4: - raise ValueError('reshape_matrix expects y dim to be 4, got', y.ndim) - - image_size_x, _ = X.shape[2:] if is_channels_first else X.shape[1:3] - rep_number = np.int(np.ceil(np.float(image_size_x) / np.float(reshape_size))) - new_batch_size = X.shape[0] * (rep_number) ** 2 - - if is_channels_first: - new_X_shape = (new_batch_size, X.shape[1], reshape_size, reshape_size) - new_y_shape = (new_batch_size, y.shape[1], reshape_size, reshape_size) - else: - new_X_shape = (new_batch_size, reshape_size, reshape_size, X.shape[3]) - new_y_shape = (new_batch_size, reshape_size, reshape_size, y.shape[3]) - - new_X = np.zeros(new_X_shape, dtype='float32') - new_y = np.zeros(new_y_shape, dtype='int32') - - counter = 0 - for b in range(X.shape[0]): - for i in range(rep_number): - for j in range(rep_number): - if i != rep_number - 1: - x_start, x_end = i * reshape_size, (i + 1) * reshape_size - else: - x_start, x_end = -reshape_size, X.shape[2 if is_channels_first else 1] - - if j != rep_number - 1: - y_start, y_end = j * reshape_size, (j + 1) * reshape_size - else: - y_start, y_end = -reshape_size, y.shape[3 if is_channels_first else 2] - - if is_channels_first: - new_X[counter] = X[b, :, x_start:x_end, y_start:y_end] - new_y[counter] = y[b, :, x_start:x_end, y_start:y_end] - else: - new_X[counter] = X[b, x_start:x_end, y_start:y_end, :] - new_y[counter] = y[b, x_start:x_end, y_start:y_end, :] - - counter += 1 - - print('Reshaped feature data from {} to {}'.format(y.shape, new_y.shape)) - print('Reshaped training data from {} to {}'.format(X.shape, new_X.shape)) - return new_X, new_y - - -def rescale(image, scale, channel_axis=-1): - multichannel = False - add_channel = False - if scale == 1: - return image # no rescale necessary, short-circuit - - if len(image.shape) != 2: # we have a channel axis - try: - image = np.squeeze(image, axis=channel_axis) - add_channel = True - except: # pylint: disable=bare-except - multichannel = True # channel axis is not 1 - - rescaled_img = skimage.transform.rescale( - image, scale, - mode='edge', - anti_aliasing=False, - anti_aliasing_sigma=None, - multichannel=multichannel, - preserve_range=True, - order=0 - ) - if add_channel: - rescaled_img = np.expand_dims(rescaled_img, axis=channel_axis) - logger.debug('Rescaled image from %s to %s', image.shape, rescaled_img.shape) - return rescaled_img - - -def _pick_model(label): - model = settings.MODEL_CHOICES.get(label) - if model is None: - logger.error('Label type %s is not supported', label) - raise ValueError('Label type {} is not supported'.format(label)) - - return model.split(':') - - -def _pick_preprocess(label): - func = settings.PREPROCESS_CHOICES.get(label) - if func is None: - logger.error('Label type %s is not supported', label) - raise ValueError('Label type {} is not supported'.format(label)) - - return func - - -def _pick_postprocess(label): - func = settings.POSTPROCESS_CHOICES.get(label) - if func is None: - logger.error('Label type %s is not supported', label) - raise ValueError('Label type {} is not supported'.format(label)) - - return func diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index d66402ec..59a80d31 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -29,32 +29,25 @@ from __future__ import print_function import os -import pytest import tarfile import tempfile import zipfile +import pytest + import numpy as np -from keras_preprocessing.image import array_to_img +from tensorflow.keras.preprocessing.image import array_to_img from skimage.external import tifffile as tiff -from redis_consumer.pbs.predict_pb2 import PredictResponse -from redis_consumer.pbs import types_pb2 -from redis_consumer.pbs.tensor_pb2 import TensorProto -from redis_consumer import utils -from redis_consumer import settings - +from redis_consumer.testing_utils import _get_image -def _get_image(img_h=300, img_w=300, channels=1): - bias = np.random.rand(img_w, img_h, channels) * 64 - variance = np.random.rand(img_w, img_h, channels) * (255 - 64) - img = np.random.rand(img_w, img_h, channels) * variance + bias - return img +from redis_consumer import utils def _write_image(filepath, img_w=300, img_h=300): - imarray = _get_image(img_h, img_w) - if filepath.lower().endswith('tif') or filepath.lower().endswith('tiff'): + imarray = _get_image(img_h, img_w, 1) + _, ext = os.path.splitext(filepath.lower()) + if ext in {'.tif', '.tiff'}: tiff.imsave(filepath, imarray[..., 0]) else: img = array_to_img(imarray, scale=False, data_format='channels_last') @@ -77,77 +70,47 @@ def _write_trks(filepath, X_mean=10, y_mean=5, trks.add(tracked_file.name, 'tracked.npy') -def test_make_tensor_proto(): - # test with numpy array - data = _get_image(300, 300, 1) - proto = utils.make_tensor_proto(data, 'DT_FLOAT') - assert isinstance(proto, (TensorProto,)) - # test with value - data = 10.0 - proto = utils.make_tensor_proto(data, types_pb2.DT_FLOAT) - assert isinstance(proto, (TensorProto,)) - - -def test_grpc_response_to_dict(): - # pylint: disable=E1101 - # test valid response - data = _get_image(300, 300, 1) - tensor_proto = utils.make_tensor_proto(data, 'DT_FLOAT') - response = PredictResponse() - response.outputs['prediction'].CopyFrom(tensor_proto) - response_dict = utils.grpc_response_to_dict(response) - assert isinstance(response_dict, (dict,)) - np.testing.assert_allclose(response_dict['prediction'], data) - # test scalar input - data = 3 - tensor_proto = utils.make_tensor_proto(data, 'DT_FLOAT') - response = PredictResponse() - response.outputs['prediction'].CopyFrom(tensor_proto) - response_dict = utils.grpc_response_to_dict(response) - assert isinstance(response_dict, (dict,)) - np.testing.assert_allclose(response_dict['prediction'], data) - # test bad dtype - # logs an error, but should throw a KeyError as well. - data = _get_image(300, 300, 1) - tensor_proto = utils.make_tensor_proto(data, 'DT_FLOAT') - response = PredictResponse() - response.outputs['prediction'].CopyFrom(tensor_proto) - response.outputs['prediction'].dtype = 32 - with pytest.raises(KeyError): - response_dict = utils.grpc_response_to_dict(response) - - -def test_iter_image_archive(): - with utils.get_tempdir() as tempdir: - zip_path = os.path.join(tempdir, 'test.zip') - archive = zipfile.ZipFile(zip_path, 'w') - num_files = 3 - for n in range(num_files): - path = os.path.join(tempdir, '{}.tif'.format(n)) - _write_image(path, 30, 30) - archive.write(path) - archive.close() - - unzipped = [z for z in utils.iter_image_archive(zip_path, tempdir)] - assert len(unzipped) == num_files - - -def test_get_image_files_from_dir(): - with utils.get_tempdir() as tempdir: - zip_path = os.path.join(tempdir, 'test.zip') - archive = zipfile.ZipFile(zip_path, 'w') - num_files = 3 - for n in range(num_files): - path = os.path.join(tempdir, '{}.tif'.format(n)) - _write_image(path, 30, 30) - archive.write(path) - archive.close() - - imfiles = list(utils.get_image_files_from_dir(path, None)) - assert len(imfiles) == 1 - - imfiles = list(utils.get_image_files_from_dir(zip_path, tempdir)) - assert len(imfiles) == num_files +def test_strip_bucket_path(): + path = 'path/to/file' + # leading, trailing, and both format strings + format_strs = ['/{}', '{}/', '/{}/'] + for fmtstr in format_strs: + stripped = utils.strip_bucket_path(fmtstr.format(path)) + assert path == stripped + + +def test_iter_image_archive(tmpdir): + tmpdir = str(tmpdir) + zip_path = os.path.join(tmpdir, 'test.zip') + archive = zipfile.ZipFile(zip_path, 'w') + num_files = 3 + for n in range(num_files): + path = os.path.join(tmpdir, '{}.tif'.format(n)) + _write_image(path, 30, 30) + archive.write(path) + archive.close() + + unzipped = [z for z in utils.iter_image_archive(zip_path, tmpdir)] + assert len(unzipped) == num_files + + +def test_get_image_files_from_dir(tmpdir): + tmpdir = str(tmpdir) + + zip_path = os.path.join(tmpdir, 'test.zip') + archive = zipfile.ZipFile(zip_path, 'w') + num_files = 3 + for n in range(num_files): + path = os.path.join(tmpdir, '{}.tif'.format(n)) + _write_image(path, 30, 30) + archive.write(path) + archive.close() + + imfiles = list(utils.get_image_files_from_dir(path, None)) + assert len(imfiles) == 1 + + imfiles = list(utils.get_image_files_from_dir(zip_path, tmpdir)) + assert len(imfiles) == num_files def test_get_image(tmpdir): @@ -156,6 +119,7 @@ def test_get_image(tmpdir): test_img_path = os.path.join(tmpdir, 'phase.tif') _write_image(test_img_path, 300, 300) test_img = utils.get_image(test_img_path) + print(test_img.shape) np.testing.assert_equal(test_img.shape, (300, 300, 1)) # test png files test_img_path = os.path.join(tmpdir, 'feature_0.png') @@ -165,92 +129,36 @@ def test_get_image(tmpdir): np.testing.assert_equal(test_img.shape, (400, 400, 1)) -def test_pad_image(): - # 2D images - h, w = 300, 300 - img = _get_image(h, w) - field_size = 61 - padded = utils.pad_image(img, field_size) - - new_h, new_w = h + (field_size - 1), w + (field_size - 1) - np.testing.assert_equal(padded.shape, (new_h, new_w, 1)) - - # 3D images - frames = np.random.randint(low=1, high=6) - imgs = np.vstack([_get_image(h, w)[None, ...] for i in range(frames)]) - padded = utils.pad_image(imgs, field_size) - np.testing.assert_equal(padded.shape, (frames, new_h, new_w, 1)) - - -def test_unpad_image(): - # 2D images - h, w = 300, 300 - - sizes = [ - (300, 300), - (101, 101) - ] - - pads = [ - (10, 10), - (15, 15), - (10, 15) - ] - for pad in pads: - for h, w in sizes: - raw = _get_image(h, w) - pad_width = [pad, pad, (0, 0)] - padded = np.pad(raw, pad_width, mode='reflect') - - unpadded = utils.unpad_image(padded, pad_width) - np.testing.assert_equal(unpadded.shape, (h, w, 1)) - np.testing.assert_equal(unpadded, raw) - - # 3D images - frames = np.random.randint(low=1, high=6) - imgs = np.vstack([_get_image(h, w)[None, ...] - for _ in range(frames)]) - - pad_width = [(0, 0), pad, pad, (0, 0)] - - padded = np.pad(imgs, pad_width, mode='reflect') - - unpadded = utils.unpad_image(padded, pad_width) - - np.testing.assert_equal(unpadded.shape, imgs.shape) - np.testing.assert_equal(unpadded, imgs) - - -def test_save_numpy_array(): +def test_save_numpy_array(tmpdir): + tmpdir = str(tmpdir) h, w = 30, 30 c = np.random.randint(low=1, high=4) z = np.random.randint(low=1, high=6) - with utils.get_tempdir() as tempdir: - # 2D images without channel axis - img = _get_image(h, w, 1) - img = np.squeeze(img) - files = utils.save_numpy_array(img, 'name', '/a/b/', tempdir) - assert len(files) == 1 - for f in files: - assert os.path.isfile(f) - assert f.startswith(os.path.join(tempdir, 'a', 'b')) - - # 2D images - img = _get_image(h, w, c) - files = utils.save_numpy_array(img, 'name', '/a/b/', tempdir) - assert len(files) == c - for f in files: - assert os.path.isfile(f) - assert f.startswith(os.path.join(tempdir, 'a', 'b')) - - # 3D images - imgs = np.vstack([_get_image(h, w, c)[None, ...] for i in range(z)]) - files = utils.save_numpy_array(imgs, 'name', '/a/b/', tempdir) - assert len(files) == c - for f in files: - assert os.path.isfile(f) - assert f.startswith(os.path.join(tempdir, 'a', 'b')) + # 2D images without channel axis + img = _get_image(h, w, 1) + img = np.squeeze(img) + files = utils.save_numpy_array(img, 'name', '/a/b/', tmpdir) + assert len(files) == 1 + for f in files: + assert os.path.isfile(f) + assert f.startswith(os.path.join(tmpdir, 'a', 'b')) + + # 2D images + img = _get_image(h, w, c) + files = utils.save_numpy_array(img, 'name', '/a/b/', tmpdir) + assert len(files) == c + for f in files: + assert os.path.isfile(f) + assert f.startswith(os.path.join(tmpdir, 'a', 'b')) + + # 3D images + imgs = np.vstack([_get_image(h, w, c)[None, ...] for i in range(z)]) + files = utils.save_numpy_array(imgs, 'name', '/a/b/', tmpdir) + assert len(files) == c + for f in files: + assert os.path.isfile(f) + assert f.startswith(os.path.join(tmpdir, 'a', 'b')) # Bad path will not fail, but will log error img = _get_image(h, w, c) @@ -322,108 +230,3 @@ def test_zip_files(tmpdir): with pytest.raises(Exception): bad_dest = os.path.join(tmpdir, 'does', 'not', 'exist') zip_path = utils.zip_files(paths, bad_dest, prefix) - - -def test_reshape_matrix(): - # K.set_image_data_format('channels_last') - X = np.zeros((1, 16, 16, 3)) - y = np.zeros((1, 16, 16, 1)) - new_size = 4 - - # test resize to smaller image, divisible - new_X, new_y = utils.reshape_matrix(X, y, new_size) - new_batch = np.ceil(16 / new_size) ** 2 - assert new_X.shape == (new_batch, new_size, new_size, 3) - assert new_y.shape == (new_batch, new_size, new_size, 1) - - # test reshape with non-divisible values. - new_size = 5 - new_batch = np.ceil(16 / new_size) ** 2 - new_X, new_y = utils.reshape_matrix(X, y, new_size) - assert new_X.shape == (new_batch, new_size, new_size, 3) - assert new_y.shape == (new_batch, new_size, new_size, 1) - - # test reshape to bigger size - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(X, y, 32) - - # test wrong dimensions - bigger = np.zeros((1, 16, 16, 3, 1)) - smaller = np.zeros((1, 16, 16)) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(smaller, y, new_size) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(bigger, y, new_size) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(X, smaller, new_size) - with pytest.raises(ValueError): - new_X, new_y = utils.reshape_matrix(X, bigger, new_size) - - # channels_first - # K.set_image_data_format('channels_first') - X = np.zeros((1, 3, 16, 16)) - y = np.zeros((1, 1, 16, 16)) - new_size = 4 - - # test resize to smaller image, divisible - new_X, new_y = utils.reshape_matrix(X, y, new_size, True) - new_batch = np.ceil(16 / new_size) ** 2 - assert new_X.shape == (new_batch, 3, new_size, new_size) - assert new_y.shape == (new_batch, 1, new_size, new_size) - - # test reshape with non-divisible values. - new_size = 5 - new_batch = np.ceil(16 / new_size) ** 2 - new_X, new_y = utils.reshape_matrix(X, y, new_size, True) - assert new_X.shape == (new_batch, 3, new_size, new_size) - assert new_y.shape == (new_batch, 1, new_size, new_size) - - -def test_rescale(): - scales = [.5, 2] - shapes = [(4, 4, 5), (4, 4, 1), (4, 4)] - for scale in scales: - for shape in shapes: - image = np.random.random(shape) - rescaled = utils.rescale(image, 1) - np.testing.assert_array_equal(rescaled, image) - - rescaled = utils.rescale(image, scale) - expected_shape = (int(np.ceil(shape[0] * scale)), - int(np.ceil(shape[1] * scale))) - - if len(shape) > 2: - expected_shape = tuple(list(expected_shape) + [int(shape[2])]) - assert rescaled.shape == expected_shape - # scale it back - rescaled = utils.rescale(rescaled, 1 / scale) - assert rescaled.shape == shape - - -def test__pick_model(mocker): - mocker.patch.object(settings, 'MODEL_CHOICES', {0: 'dummymodel:0'}) - res = utils._pick_model(0) - assert len(res) == 2 - assert res[0] == 'dummymodel' - assert res[1] == '0' - - with pytest.raises(ValueError): - utils._pick_model(-1) - - -def test__pick_preprocess(mocker): - mocker.patch.object(settings, 'PREPROCESS_CHOICES', {0: 'pre'}) - res = utils._pick_preprocess(0) - assert res == 'pre' - - with pytest.raises(ValueError): - utils._pick_preprocess(-1) - - -def test__pick_postprocess(mocker): - mocker.patch.object(settings, 'POSTPROCESS_CHOICES', {0: 'post'}) - res = utils._pick_postprocess(0) - assert res == 'post' - - with pytest.raises(ValueError): - utils._pick_postprocess(-1) diff --git a/requirements-no-deps.txt b/requirements-no-deps.txt new file mode 100644 index 00000000..a5fb2630 --- /dev/null +++ b/requirements-no-deps.txt @@ -0,0 +1,3 @@ +# tensorflow-serving-api installs tensorflow so install +# with --no-deps to prevent overwriting tensorflow-cpu +tensorflow-serving-api==2.3.0 diff --git a/requirements.txt b/requirements.txt index dd66fc75..4c8ed73b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,19 @@ -boto3==1.9.195 -google-cloud-storage>=1.16.1 -python-decouple==3.1 -redis==3.4.1 -scikit-image>=0.14.0,<0.17.0 -numpy>=1.16.4 -keras-preprocessing==1.1.0 -grpcio==1.27.2 +# deepcell packages +deepcell-cpu>=0.8.6 +deepcell-tracking==0.3.1 +deepcell-toolbox>=0.8.6 +tensorflow-cpu +scikit-image>=0.14.5,<0.17.0 +numpy>=1.16.6 + +# tensorflow-serving-apis and gRPC dependencies +grpcio>=1.0,<2 dict-to-protobuf==0.0.3.9 +protobuf>=3.6.0 + +# misc storage and redis clients +boto3>=1.9.195,<2 +google-cloud-storage>=1.16.1,<2 +python-decouple>=3.1,<4 +redis==3.4.1 pytz==2019.1 -deepcell-tracking==0.3.0 -deepcell-toolbox>=0.8.2