From fbbf336ac94258a9353682382895bbfaa46f9604 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sun, 1 Mar 2020 21:36:41 -0800 Subject: [PATCH 01/47] deprecate processing API --- protos/process.proto | 78 ---- protos/processing_service.proto | 13 - redis_consumer/grpc_clients.py | 141 ------ redis_consumer/pbs/process_pb2.py | 442 ------------------ redis_consumer/pbs/process_pb2_grpc.py | 3 - redis_consumer/pbs/processing_service_pb2.py | 66 --- .../pbs/processing_service_pb2_grpc.py | 63 --- 7 files changed, 806 deletions(-) delete mode 100644 protos/process.proto delete mode 100644 protos/processing_service.proto delete mode 100644 redis_consumer/pbs/process_pb2.py delete mode 100644 redis_consumer/pbs/process_pb2_grpc.py delete mode 100644 redis_consumer/pbs/processing_service_pb2.py delete mode 100644 redis_consumer/pbs/processing_service_pb2_grpc.py diff --git a/protos/process.proto b/protos/process.proto deleted file mode 100644 index 1c2b0b55..00000000 --- a/protos/process.proto +++ /dev/null @@ -1,78 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "tensor.proto"; -import "function.proto"; - -// ProcessRequest specifies which TensorFlow model to run, as well as -// how inputs are mapped to tensors and how outputs are filtered before -// returning to user. -message ProcessRequest { - // Model Specification. - FunctionSpec function_spec = 1; - - // Input tensors. - // Names of input tensor are alias names. The mapping from aliases to real - // input tensor names is expected to be stored as named generic signature - // under the key "inputs" in the model export. - // Each alias listed in a generic signature named "inputs" should be provided - // exactly once in order to run the processing. - map inputs = 2; - - // Output filter. - // Names specified are alias names. The mapping from aliases to real output - // tensor names is expected to be stored as named generic signature under - // the key "outputs" in the model export. - // Only tensors specified here will be run/fetched and returned, with the - // exception that when none is specified, all tensors specified in the - // named signature will be run/fetched and returned. - repeated string output_filter = 3; -} - -// Response for ProcessRequest on successful run. -message ProcessResponse { - // Output tensors. - map outputs = 1; -} - -message ChunkedProcessRequest { - // Model Specification. - FunctionSpec function_spec = 1; - - // Input tensors. - // Names of input tensor are alias names. The mapping from aliases to real - // input tensor names is expected to be stored as named generic signature - // under the key "inputs" in the model export. - // Each alias listed in a generic signature named "inputs" should be provided - // exactly once in order to run the processing. - map inputs = 2; - - // Output filter. - // Names specified are alias names. The mapping from aliases to real output - // tensor names is expected to be stored as named generic signature under - // the key "outputs" in the model export. - // Only tensors specified here will be run/fetched and returned, with the - // exception that when none is specified, all tensors specified in the - // named signature will be run/fetched and returned. - repeated string output_filter = 3; - - // Shape of chunked array. - repeated int64 shape = 4; - - // Dtype of chunked array. - string dtype = 5; -} - -// Response for ChunkedProcessRequest on successful run. -message ChunkedProcessResponse { - // Output tensors. - map outputs = 1; - - // Shape of chunked array. - repeated int64 shape = 4; - - // Dtype of chunked array. - string dtype = 5; -} diff --git a/protos/processing_service.proto b/protos/processing_service.proto deleted file mode 100644 index a3c29328..00000000 --- a/protos/processing_service.proto +++ /dev/null @@ -1,13 +0,0 @@ -syntax = "proto3"; - -package tensorflow.serving; -option cc_enable_arenas = true; - -import "process.proto"; - -// ProcessingService provides access to data processing functions -service ProcessingService { - // Process -- provides access to a data processing function - rpc Process(ProcessRequest) returns (ProcessResponse); - rpc StreamProcess(stream ChunkedProcessRequest) returns (stream ChunkedProcessResponse); -} diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 82cbf336..f7b9c07b 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -151,147 +151,6 @@ def predict(self, request_data, request_timeout=10): return {} -class ProcessClient(GrpcClient): - """gRPC Client for data-processing API. - - Arguments: - host: string, the hostname and port of the server (`localhost:8080`) - process_type: string, pre or post processing - function_name: string, name of processing function - """ - - def __init__(self, host, process_type, function_name): - super(ProcessClient, self).__init__(host) - self.process_type = process_type - self.function_name = function_name - - def process(self, request_data, request_timeout=10): - self.logger.info('Sending request to %s %s-process data with the ' - 'data-processing API at %s.', self.function_name, - self.process_type, self.host) - - # Create gRPC client and request - channel = self.insecure_channel() - - t = timeit.default_timer() - stub = ProcessingServiceStub(channel) - self.logger.debug('Created DataProcessingProcessingServiceStub in %s ' - 'seconds.', timeit.default_timer() - t) - - t = timeit.default_timer() - request = ProcessRequest() - self.logger.debug('Created DataProcessingRequest object in %s seconds.', - timeit.default_timer() - t) - - # pylint: disable=E1101 - request.function_spec.name = self.function_name - request.function_spec.type = self.process_type - # pylint: enable=E1101 - - t = timeit.default_timer() - for d in request_data: - tensor_proto = make_tensor_proto(d['data'], d['in_tensor_dtype']) - # pylint: disable=E1101 - request.inputs[d['in_tensor_name']].CopyFrom(tensor_proto) - - self.logger.debug('Made tensor protos in %s seconds.', - timeit.default_timer() - t) - - try: - t = timeit.default_timer() - response = stub.Process(request, timeout=request_timeout) - self.logger.debug('gRPC DataProcessingRequest finished in %s ' - 'seconds.', timeit.default_timer() - t) - - t = timeit.default_timer() - response_dict = grpc_response_to_dict(response) - self.logger.debug('gRPC DataProcessingProtobufConversion took %s ' - 'seconds.', timeit.default_timer() - t) - - keys = [k for k in response_dict] - self.logger.debug('Got processing_response with keys: %s', keys) - channel.close() - return response_dict - - except RpcError as err: - self.logger.error('Processing failed due to: %s', err) - channel.close() - raise err - - channel.close() - return {} - - def stream_process(self, request_data, request_timeout=10): - self.logger.info('Sending request to %s %s-process data with the ' - 'data-processing API at %s.', self.function_name, - self.process_type, self.host) - - # Create gRPC client and request - channel = self.insecure_channel() - - t = timeit.default_timer() - stub = ProcessingServiceStub(channel) - self.logger.debug('Created stub in %s seconds.', - timeit.default_timer() - t) - chunk_size = 64 * 1024 # 64 kB is recommended payload size - - def request_iterator(image): - dtype = str(image.dtype) - shape = list(image.shape) - bytearr = image.tobytes() - - self.logger.info('Streaming %s bytes in %s requests', - len(bytearr), chunk_size % len(bytearr)) - - for i in range(0, len(bytearr), chunk_size): - request = ChunkedProcessRequest() - # pylint: disable=E1101 - request.function_spec.name = self.function_name - request.function_spec.type = self.process_type - request.shape[:] = shape - request.dtype = dtype - request.inputs['data'] = bytearr[i: i + chunk_size] - # pylint: enable=E1101 - yield request - - try: - t = timeit.default_timer() - req_iter = request_iterator(request_data[0]['data']) - res_iter = stub.StreamProcess(req_iter, timeout=request_timeout) - - shape = None - dtype = None - processed_bytes = [] - for response in res_iter: - shape = tuple(response.shape) - dtype = str(response.dtype) - processed_bytes.append(response.outputs['data']) - - npbytes = b''.join(processed_bytes) - # Got response stream of %s bytes in %s seconds. - self.logger.info('gRPC DataProcessingStreamRequest of %s bytes ' - 'finished in %s seconds.', len(npbytes), - timeit.default_timer() - t) - - t = timeit.default_timer() - processed_image = np.frombuffer(npbytes, dtype=dtype) - results = processed_image.reshape(shape) - self.logger.info('gRPC DataProcessingStreamConversion from %s bytes' - ' to a numpy array of shape %s in %s seconds.', - len(npbytes), results.shape, - timeit.default_timer() - t) - channel.close() - return {'results': results} - - except RpcError as err: - self.logger.error('Processing failed due to: %s', err) - channel.close() - raise err - - channel.close() - return {} - - class TrackingClient(GrpcClient): """gRPC Client for tensorflow-serving API. diff --git a/redis_consumer/pbs/process_pb2.py b/redis_consumer/pbs/process_pb2.py deleted file mode 100644 index 55ee8eb1..00000000 --- a/redis_consumer/pbs/process_pb2.py +++ /dev/null @@ -1,442 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: process.proto - -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.tensor_pb2 as tensor__pb2 -import redis_consumer.pbs.function_pb2 as function__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='process.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=_b('\370\001\001'), - serialized_pb=_b('\n\rprocess.proto\x12\x12tensorflow.serving\x1a\x0ctensor.proto\x1a\x0e\x66unction.proto\"\xe8\x01\n\x0eProcessRequest\x12\x37\n\rfunction_spec\x18\x01 \x01(\x0b\x32 .tensorflow.serving.FunctionSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.ProcessRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\x9d\x01\n\x0fProcessResponse\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.ProcessResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\xfb\x01\n\x15\x43hunkedProcessRequest\x12\x37\n\rfunction_spec\x18\x01 \x01(\x0b\x32 .tensorflow.serving.FunctionSpec\x12\x45\n\x06inputs\x18\x02 \x03(\x0b\x32\x35.tensorflow.serving.ChunkedProcessRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05\x64type\x18\x05 \x01(\t\x1a-\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\"\xb0\x01\n\x16\x43hunkedProcessResponse\x12H\n\x07outputs\x18\x01 \x03(\x0b\x32\x37.tensorflow.serving.ChunkedProcessResponse.OutputsEntry\x12\r\n\x05shape\x18\x04 \x03(\x03\x12\r\n\x05\x64type\x18\x05 \x01(\t\x1a.\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensor__pb2.DESCRIPTOR,function__pb2.DESCRIPTOR,]) - - - - -_PROCESSREQUEST_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.serving.ProcessRequest.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.ProcessRequest.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.ProcessRequest.InputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=_b('8\001'), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=230, - serialized_end=300, -) - -_PROCESSREQUEST = _descriptor.Descriptor( - name='ProcessRequest', - full_name='tensorflow.serving.ProcessRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='function_spec', full_name='tensorflow.serving.ProcessRequest.function_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.serving.ProcessRequest.inputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_filter', full_name='tensorflow.serving.ProcessRequest.output_filter', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PROCESSREQUEST_INPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=68, - serialized_end=300, -) - - -_PROCESSRESPONSE_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.serving.ProcessResponse.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.ProcessResponse.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.ProcessResponse.OutputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=_b('8\001'), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=389, - serialized_end=460, -) - -_PROCESSRESPONSE = _descriptor.Descriptor( - name='ProcessResponse', - full_name='tensorflow.serving.ProcessResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.serving.ProcessResponse.outputs', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PROCESSRESPONSE_OUTPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=303, - serialized_end=460, -) - - -_CHUNKEDPROCESSREQUEST_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.serving.ChunkedProcessRequest.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.ChunkedProcessRequest.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.ChunkedProcessRequest.InputsEntry.value', index=1, - number=2, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=_b('8\001'), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=669, - serialized_end=714, -) - -_CHUNKEDPROCESSREQUEST = _descriptor.Descriptor( - name='ChunkedProcessRequest', - full_name='tensorflow.serving.ChunkedProcessRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='function_spec', full_name='tensorflow.serving.ChunkedProcessRequest.function_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.serving.ChunkedProcessRequest.inputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_filter', full_name='tensorflow.serving.ChunkedProcessRequest.output_filter', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.serving.ChunkedProcessRequest.shape', index=3, - number=4, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.serving.ChunkedProcessRequest.dtype', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_CHUNKEDPROCESSREQUEST_INPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=463, - serialized_end=714, -) - - -_CHUNKEDPROCESSRESPONSE_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.serving.ChunkedProcessResponse.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.ChunkedProcessResponse.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.ChunkedProcessResponse.OutputsEntry.value', index=1, - number=2, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=_b('8\001'), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=847, - serialized_end=893, -) - -_CHUNKEDPROCESSRESPONSE = _descriptor.Descriptor( - name='ChunkedProcessResponse', - full_name='tensorflow.serving.ChunkedProcessResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.serving.ChunkedProcessResponse.outputs', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.serving.ChunkedProcessResponse.shape', index=1, - number=4, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dtype', full_name='tensorflow.serving.ChunkedProcessResponse.dtype', index=2, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_CHUNKEDPROCESSRESPONSE_OUTPUTSENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=717, - serialized_end=893, -) - -_PROCESSREQUEST_INPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO -_PROCESSREQUEST_INPUTSENTRY.containing_type = _PROCESSREQUEST -_PROCESSREQUEST.fields_by_name['function_spec'].message_type = function__pb2._FUNCTIONSPEC -_PROCESSREQUEST.fields_by_name['inputs'].message_type = _PROCESSREQUEST_INPUTSENTRY -_PROCESSRESPONSE_OUTPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO -_PROCESSRESPONSE_OUTPUTSENTRY.containing_type = _PROCESSRESPONSE -_PROCESSRESPONSE.fields_by_name['outputs'].message_type = _PROCESSRESPONSE_OUTPUTSENTRY -_CHUNKEDPROCESSREQUEST_INPUTSENTRY.containing_type = _CHUNKEDPROCESSREQUEST -_CHUNKEDPROCESSREQUEST.fields_by_name['function_spec'].message_type = function__pb2._FUNCTIONSPEC -_CHUNKEDPROCESSREQUEST.fields_by_name['inputs'].message_type = _CHUNKEDPROCESSREQUEST_INPUTSENTRY -_CHUNKEDPROCESSRESPONSE_OUTPUTSENTRY.containing_type = _CHUNKEDPROCESSRESPONSE -_CHUNKEDPROCESSRESPONSE.fields_by_name['outputs'].message_type = _CHUNKEDPROCESSRESPONSE_OUTPUTSENTRY -DESCRIPTOR.message_types_by_name['ProcessRequest'] = _PROCESSREQUEST -DESCRIPTOR.message_types_by_name['ProcessResponse'] = _PROCESSRESPONSE -DESCRIPTOR.message_types_by_name['ChunkedProcessRequest'] = _CHUNKEDPROCESSREQUEST -DESCRIPTOR.message_types_by_name['ChunkedProcessResponse'] = _CHUNKEDPROCESSRESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -ProcessRequest = _reflection.GeneratedProtocolMessageType('ProcessRequest', (_message.Message,), dict( - - InputsEntry = _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), dict( - DESCRIPTOR = _PROCESSREQUEST_INPUTSENTRY, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ProcessRequest.InputsEntry) - )) - , - DESCRIPTOR = _PROCESSREQUEST, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ProcessRequest) - )) -_sym_db.RegisterMessage(ProcessRequest) -_sym_db.RegisterMessage(ProcessRequest.InputsEntry) - -ProcessResponse = _reflection.GeneratedProtocolMessageType('ProcessResponse', (_message.Message,), dict( - - OutputsEntry = _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), dict( - DESCRIPTOR = _PROCESSRESPONSE_OUTPUTSENTRY, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ProcessResponse.OutputsEntry) - )) - , - DESCRIPTOR = _PROCESSRESPONSE, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ProcessResponse) - )) -_sym_db.RegisterMessage(ProcessResponse) -_sym_db.RegisterMessage(ProcessResponse.OutputsEntry) - -ChunkedProcessRequest = _reflection.GeneratedProtocolMessageType('ChunkedProcessRequest', (_message.Message,), dict( - - InputsEntry = _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), dict( - DESCRIPTOR = _CHUNKEDPROCESSREQUEST_INPUTSENTRY, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ChunkedProcessRequest.InputsEntry) - )) - , - DESCRIPTOR = _CHUNKEDPROCESSREQUEST, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ChunkedProcessRequest) - )) -_sym_db.RegisterMessage(ChunkedProcessRequest) -_sym_db.RegisterMessage(ChunkedProcessRequest.InputsEntry) - -ChunkedProcessResponse = _reflection.GeneratedProtocolMessageType('ChunkedProcessResponse', (_message.Message,), dict( - - OutputsEntry = _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), dict( - DESCRIPTOR = _CHUNKEDPROCESSRESPONSE_OUTPUTSENTRY, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ChunkedProcessResponse.OutputsEntry) - )) - , - DESCRIPTOR = _CHUNKEDPROCESSRESPONSE, - __module__ = 'process_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ChunkedProcessResponse) - )) -_sym_db.RegisterMessage(ChunkedProcessResponse) -_sym_db.RegisterMessage(ChunkedProcessResponse.OutputsEntry) - - -DESCRIPTOR._options = None -_PROCESSREQUEST_INPUTSENTRY._options = None -_PROCESSRESPONSE_OUTPUTSENTRY._options = None -_CHUNKEDPROCESSREQUEST_INPUTSENTRY._options = None -_CHUNKEDPROCESSRESPONSE_OUTPUTSENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/process_pb2_grpc.py b/redis_consumer/pbs/process_pb2_grpc.py deleted file mode 100644 index a8943526..00000000 --- a/redis_consumer/pbs/process_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - diff --git a/redis_consumer/pbs/processing_service_pb2.py b/redis_consumer/pbs/processing_service_pb2.py deleted file mode 100644 index 0d081158..00000000 --- a/redis_consumer/pbs/processing_service_pb2.py +++ /dev/null @@ -1,66 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: processing_service.proto - -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import redis_consumer.pbs.process_pb2 as process__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='processing_service.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_options=_b('\370\001\001'), - serialized_pb=_b('\n\x18processing_service.proto\x12\x12tensorflow.serving\x1a\rprocess.proto2\xd3\x01\n\x11ProcessingService\x12R\n\x07Process\x12\".tensorflow.serving.ProcessRequest\x1a#.tensorflow.serving.ProcessResponse\x12j\n\rStreamProcess\x12).tensorflow.serving.ChunkedProcessRequest\x1a*.tensorflow.serving.ChunkedProcessResponse(\x01\x30\x01\x42\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[process__pb2.DESCRIPTOR,]) - - - -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - - -DESCRIPTOR._options = None - -_PROCESSINGSERVICE = _descriptor.ServiceDescriptor( - name='ProcessingService', - full_name='tensorflow.serving.ProcessingService', - file=DESCRIPTOR, - index=0, - serialized_options=None, - serialized_start=64, - serialized_end=275, - methods=[ - _descriptor.MethodDescriptor( - name='Process', - full_name='tensorflow.serving.ProcessingService.Process', - index=0, - containing_service=None, - input_type=process__pb2._PROCESSREQUEST, - output_type=process__pb2._PROCESSRESPONSE, - serialized_options=None, - ), - _descriptor.MethodDescriptor( - name='StreamProcess', - full_name='tensorflow.serving.ProcessingService.StreamProcess', - index=1, - containing_service=None, - input_type=process__pb2._CHUNKEDPROCESSREQUEST, - output_type=process__pb2._CHUNKEDPROCESSRESPONSE, - serialized_options=None, - ), -]) -_sym_db.RegisterServiceDescriptor(_PROCESSINGSERVICE) - -DESCRIPTOR.services_by_name['ProcessingService'] = _PROCESSINGSERVICE - -# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/processing_service_pb2_grpc.py b/redis_consumer/pbs/processing_service_pb2_grpc.py deleted file mode 100644 index 8c4972ae..00000000 --- a/redis_consumer/pbs/processing_service_pb2_grpc.py +++ /dev/null @@ -1,63 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -import grpc - -import redis_consumer.pbs.process_pb2 as process__pb2 - - -class ProcessingServiceStub(object): - """ProcessingService provides access to data processing functions - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Process = channel.unary_unary( - '/tensorflow.serving.ProcessingService/Process', - request_serializer=process__pb2.ProcessRequest.SerializeToString, - response_deserializer=process__pb2.ProcessResponse.FromString, - ) - self.StreamProcess = channel.stream_stream( - '/tensorflow.serving.ProcessingService/StreamProcess', - request_serializer=process__pb2.ChunkedProcessRequest.SerializeToString, - response_deserializer=process__pb2.ChunkedProcessResponse.FromString, - ) - - -class ProcessingServiceServicer(object): - """ProcessingService provides access to data processing functions - """ - - def Process(self, request, context): - """Process -- provides access to a data processing function - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def StreamProcess(self, request_iterator, context): - # missing associated documentation comment in .proto file - pass - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_ProcessingServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Process': grpc.unary_unary_rpc_method_handler( - servicer.Process, - request_deserializer=process__pb2.ProcessRequest.FromString, - response_serializer=process__pb2.ProcessResponse.SerializeToString, - ), - 'StreamProcess': grpc.stream_stream_rpc_method_handler( - servicer.StreamProcess, - request_deserializer=process__pb2.ChunkedProcessRequest.FromString, - response_serializer=process__pb2.ChunkedProcessResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'tensorflow.serving.ProcessingService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) From 7b327c053a5fb10b874f3d697f619ec0a1611898 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sun, 1 Mar 2020 21:48:03 -0800 Subject: [PATCH 02/47] update protos for 1.15, add model metadata request. --- protos/attr_value.proto | 62 +++++ protos/function.proto | 115 +++++++++- protos/get_model_metadata.proto | 30 +++ protos/graph.proto | 56 +++++ protos/meta_graph.proto | 342 ++++++++++++++++++++++++++++ protos/model.proto | 25 +- protos/node_def.proto | 86 +++++++ protos/op_def.proto | 170 ++++++++++++++ protos/predict.proto | 16 +- protos/prediction_service.proto | 18 ++ protos/resource_handle.proto | 17 +- protos/saved_object_graph.proto | 164 +++++++++++++ protos/saver.proto | 47 ++++ protos/struct.proto | 134 +++++++++++ protos/tensor.proto | 36 ++- protos/tensor_shape.proto | 1 + protos/trackable_object_graph.proto | 59 +++++ protos/types.proto | 20 +- protos/variable.proto | 85 +++++++ protos/versions.proto | 32 +++ 20 files changed, 1483 insertions(+), 32 deletions(-) create mode 100644 protos/attr_value.proto create mode 100644 protos/get_model_metadata.proto create mode 100644 protos/graph.proto create mode 100644 protos/meta_graph.proto create mode 100644 protos/node_def.proto create mode 100644 protos/op_def.proto create mode 100644 protos/saved_object_graph.proto create mode 100644 protos/saver.proto create mode 100644 protos/struct.proto create mode 100644 protos/trackable_object_graph.proto create mode 100644 protos/variable.proto create mode 100644 protos/versions.proto diff --git a/protos/attr_value.proto b/protos/attr_value.proto new file mode 100644 index 00000000..76944f77 --- /dev/null +++ b/protos/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/protos/function.proto b/protos/function.proto index 0c8cc89a..6d107635 100644 --- a/protos/function.proto +++ b/protos/function.proto @@ -1,12 +1,113 @@ syntax = "proto3"; -package tensorflow.serving; +package tensorflow; option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; -// Metadata for an inference request such as the processing name and type -message FunctionSpec { - // Required function name. - string name = 1; - // Required function type. - string type = 2; +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// +// TODO(zhifengc): +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // Attributes for function arguments. These attributes are the same set of + // valid attributes as to _Arg nodes. + message ArgAttrs { + map attr = 1; + } + map arg_attr = 7; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; + + // A mapping from control output names from `signature` to node names in + // `node_def` which should be control outputs of this function. + map control_ret = 6; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. } diff --git a/protos/get_model_metadata.proto b/protos/get_model_metadata.proto new file mode 100644 index 00000000..60ddfd56 --- /dev/null +++ b/protos/get_model_metadata.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package tensorflow.serving; +option cc_enable_arenas = true; + +import "google/protobuf/any.proto"; +import "meta_graph.proto"; +import "model.proto"; + +// Message returned for "signature_def" field. +message SignatureDefMap { + map signature_def = 1; +}; + +message GetModelMetadataRequest { + // Model Specification indicating which model we are querying for metadata. + // If version is not specified, will use the latest (numerical) version. + ModelSpec model_spec = 1; + // Metadata fields to get. Currently supported: "signature_def". + repeated string metadata_field = 2; +} + +message GetModelMetadataResponse { + // Model Specification indicating which model this metadata belongs to. + ModelSpec model_spec = 1; + // Map of metadata field name to metadata field. The options for metadata + // field name are listed in GetModelMetadataRequest. Currently supported: + // "signature_def". + map metadata = 2; +} diff --git a/protos/graph.proto b/protos/graph.proto new file mode 100644 index 00000000..14d9edfa --- /dev/null +++ b/protos/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/protos/meta_graph.proto b/protos/meta_graph.proto new file mode 100644 index 00000000..f1005543 --- /dev/null +++ b/protos/meta_graph.proto @@ -0,0 +1,342 @@ +syntax = "proto3"; + +package tensorflow; + +import "google/protobuf/any.proto"; +import "graph.proto"; +import "op_def.proto"; +import "tensor_shape.proto"; +import "types.proto"; +import "saved_object_graph.proto"; +import "saver.proto"; +import "struct.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "MetaGraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; + +// NOTE: This protocol buffer is evolving, and will go through revisions in the +// coming months. +// +// Protocol buffer containing the following which are necessary to restart +// training, run inference. It can be used to serialize/de-serialize memory +// objects necessary for running computation in a graph when crossing the +// process boundary. It can be used for long term storage of graphs, +// cross-language execution of graphs, etc. +// MetaInfoDef +// GraphDef +// SaverDef +// CollectionDef +// TensorInfo +// SignatureDef +message MetaGraphDef { + // Meta information regarding the graph to be exported. To be used by users + // of this protocol buffer to encode information regarding their meta graph. + message MetaInfoDef { + // User specified Version string. Can be the name of the model and revision, + // steps this model has been trained to, etc. + string meta_graph_version = 1; + + // A copy of the OpDefs used by the producer of this graph_def. + // Descriptions and Ops not used in graph_def are stripped out. + OpList stripped_op_list = 2; + + // A serialized protobuf. Can be the time this meta graph is created, or + // modified, or name of the model. + google.protobuf.Any any_info = 3; + + // User supplied tag(s) on the meta_graph and included graph_def. + // + // MetaGraphDefs should be tagged with their capabilities or use-cases. + // Examples: "train", "serve", "gpu", "tpu", etc. + // These tags enable loaders to access the MetaGraph(s) appropriate for a + // specific use-case or runtime environment. + repeated string tags = 4; + + // The __version__ string of the tensorflow build used to write this graph. + // This will be populated by the framework, which will overwrite any user + // supplied value. + string tensorflow_version = 5; + + // The __git_version__ string of the tensorflow build used to write this + // graph. This will be populated by the framework, which will overwrite any + // user supplied value. + string tensorflow_git_version = 6; + + // A flag to denote whether default-valued attrs have been stripped from + // the nodes in this graph_def. + bool stripped_default_attrs = 7; + + // FunctionDef name to aliases mapping. + map function_aliases = 8; + } + MetaInfoDef meta_info_def = 1; + + // GraphDef. + GraphDef graph_def = 2; + + // SaverDef. + SaverDef saver_def = 3; + + // collection_def: Map from collection name to collections. + // See CollectionDef section for details. + map collection_def = 4; + + // signature_def: Map from user supplied key for a signature to a single + // SignatureDef. + map signature_def = 5; + + // Asset file def to be used with the defined graph. + repeated AssetFileDef asset_file_def = 6; + + // Extra information about the structure of functions and stateful objects. + SavedObjectGraph object_graph_def = 7; +} + +// CollectionDef should cover most collections. +// To add a user-defined collection, do one of the following: +// 1. For simple data types, such as string, int, float: +// tf.add_to_collection("your_collection_name", your_simple_value) +// strings will be stored as bytes_list. +// +// 2. For Protobuf types, there are three ways to add them: +// 1) tf.add_to_collection("your_collection_name", +// your_proto.SerializeToString()) +// +// collection_def { +// key: "user_defined_bytes_collection" +// value { +// bytes_list { +// value: "queue_name: \"test_queue\"\n" +// } +// } +// } +// +// or +// +// 2) tf.add_to_collection("your_collection_name", str(your_proto)) +// +// collection_def { +// key: "user_defined_string_collection" +// value { +// bytes_list { +// value: "\n\ntest_queue" +// } +// } +// } +// +// or +// +// 3) any_buf = any_pb2.Any() +// tf.add_to_collection("your_collection_name", +// any_buf.Pack(your_proto)) +// +// collection_def { +// key: "user_defined_any_collection" +// value { +// any_list { +// value { +// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" +// value: "\n\ntest_queue" +// } +// } +// } +// } +// +// 3. For Python objects, implement to_proto() and from_proto(), and register +// them in the following manner: +// ops.register_proto_function("your_collection_name", +// proto_type, +// to_proto=YourPythonObject.to_proto, +// from_proto=YourPythonObject.from_proto) +// These functions will be invoked to serialize and de-serialize the +// collection. For example, +// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, +// proto_type=variable_pb2.VariableDef, +// to_proto=Variable.to_proto, +// from_proto=Variable.from_proto) +message CollectionDef { + // NodeList is used for collecting nodes in graph. For example + // collection_def { + // key: "summaries" + // value { + // node_list { + // value: "input_producer/ScalarSummary:0" + // value: "shuffle_batch/ScalarSummary:0" + // value: "ImageSummary:0" + // } + // } + message NodeList { + repeated string value = 1; + } + + // BytesList is used for collecting strings and serialized protobufs. For + // example: + // collection_def { + // key: "trainable_variables" + // value { + // bytes_list { + // value: "\n\017conv1/weights:0\022\024conv1/weights/Assign + // \032\024conv1/weights/read:0" + // value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 + // \023conv1/biases/read:0" + // } + // } + // } + message BytesList { + repeated bytes value = 1; + } + + // Int64List is used for collecting int, int64 and long values. + message Int64List { + repeated int64 value = 1 [packed = true]; + } + + // FloatList is used for collecting float values. + message FloatList { + repeated float value = 1 [packed = true]; + } + + // AnyList is used for collecting Any protos. + message AnyList { + repeated google.protobuf.Any value = 1; + } + + oneof kind { + NodeList node_list = 1; + BytesList bytes_list = 2; + Int64List int64_list = 3; + FloatList float_list = 4; + AnyList any_list = 5; + } +} + +// Information about a Tensor necessary for feeding or retrieval. +message TensorInfo { + // For sparse tensors, The COO encoding stores a triple of values, indices, + // and shape. + message CooSparse { + // The shape of the values Tensor is [?]. Its dtype must be the dtype of + // the SparseTensor as a whole, given in the enclosing TensorInfo. + string values_tensor_name = 1; + + // The indices Tensor must have dtype int64 and shape [?, ?]. + string indices_tensor_name = 2; + + // The dynamic logical shape represented by the SparseTensor is recorded in + // the Tensor referenced here. It must have dtype int64 and shape [?]. + string dense_shape_tensor_name = 3; + } + + // Generic encoding for composite tensors. + message CompositeTensor { + // The serialized TypeSpec for the composite tensor. + TypeSpecProto type_spec = 1; + + // A TensorInfo for each flattened component tensor. + repeated TensorInfo components = 2; + } + + oneof encoding { + // For dense `Tensor`s, the name of the tensor in the graph. + string name = 1; + // There are many possible encodings of sparse matrices + // (https://en.wikipedia.org/wiki/Sparse_matrix). Currently, TensorFlow + // uses only the COO encoding. This is supported and documented in the + // SparseTensor Python class. + CooSparse coo_sparse = 4; + // Generic encoding for CompositeTensors. + CompositeTensor composite_tensor = 5; + } + DataType dtype = 2; + // The static shape should be recorded here, to the extent that it can + // be known in advance. In the case of a SparseTensor, this field describes + // the logical shape of the represented tensor (aka dense_shape). + TensorShapeProto tensor_shape = 3; +} + +// SignatureDef defines the signature of a computation supported by a TensorFlow +// graph. +// +// For example, a model with two loss computations, sharing a single input, +// might have the following signature_def map. +// +// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, +// output key, and method_name are identical, and will be used by system(s) that +// implement or rely upon this particular loss method. The output tensor names +// differ, demonstrating how different outputs can exist for the same method. +// +// signature_def { +// key: "loss_A" +// value { +// inputs { +// key: "input" +// value { +// name: "input:0" +// dtype: DT_STRING +// tensor_shape: ... +// } +// } +// outputs { +// key: "loss_output" +// value { +// name: "loss_output_A:0" +// dtype: DT_FLOAT +// tensor_shape: ... +// } +// } +// } +// ... +// method_name: "some/package/compute_loss" +// } +// signature_def { +// key: "loss_B" +// value { +// inputs { +// key: "input" +// value { +// name: "input:0" +// dtype: DT_STRING +// tensor_shape: ... +// } +// } +// outputs { +// key: "loss_output" +// value { +// name: "loss_output_B:0" +// dtype: DT_FLOAT +// tensor_shape: ... +// } +// } +// } +// ... +// method_name: "some/package/compute_loss" +// } +message SignatureDef { + // Named input parameters. + map inputs = 1; + // Named output parameters. + map outputs = 2; + // Extensible method_name information enabling third-party users to mark a + // SignatureDef as supporting a particular method. This enables producers and + // consumers of SignatureDefs, e.g. a model definition library and a serving + // library to have a clear hand-off regarding the semantics of a computation. + // + // Note that multiple SignatureDefs in a single MetaGraphDef may have the same + // method_name. This is commonly used to support multi-headed computation, + // where a single graph computation may return multiple results. + string method_name = 3; +} + +// An asset file def for a single file or a set of sharded files with the same +// name. +message AssetFileDef { + // The tensor to bind the asset filename to. + TensorInfo tensor_info = 1; + // The filename within an assets directory. Note: does not include the path + // prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename + // would be "vocab.txt". + string filename = 2; +} diff --git a/protos/model.proto b/protos/model.proto index 889303ff..56493f68 100644 --- a/protos/model.proto +++ b/protos/model.proto @@ -10,9 +10,24 @@ message ModelSpec { // Required servable name. string name = 1; - // Optional version. If unspecified, will use the latest (numerical) version. - // Typically not needed unless coordinating across multiple models that were - // co-trained and/or have inter-dependencies on the versions used at inference - // time. - google.protobuf.Int64Value version = 2; + // Optional choice of which version of the model to use. + // + // Recommended to be left unset in the common case. Should be specified only + // when there is a strong version consistency requirement. + // + // When left unspecified, the system will serve the best available version. + // This is typically the latest version, though during version transitions, + // notably when serving on a fleet of instances, may be either the previous or + // new version. + oneof version_choice { + // Use this specific version number. + google.protobuf.Int64Value version = 2; + + // Use the version associated with the given label. + string version_label = 4; + } + + // A named signature to evaluate. If unspecified, the default signature will + // be used. + string signature_name = 3; } diff --git a/protos/node_def.proto b/protos/node_def.proto new file mode 100644 index 00000000..1e0da16f --- /dev/null +++ b/protos/node_def.proto @@ -0,0 +1,86 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_>./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // TODO(josh11b): Add some examples here showing best practices. + map attr = 5; + + message ExperimentalDebugInfo { + // Opaque string inserted into error messages created by the runtime. + // + // This is intended to store the list of names of the nodes from the + // original graph that this node was derived. For example if this node, say + // C, was result of a fusion of 2 nodes A and B, then 'original_node' would + // be {A, B}. This information can be used to map errors originating at the + // current node to some top level source code. + repeated string original_node_names = 1; + + // This is intended to store the list of names of the functions from the + // original graph that this node was derived. For example if this node, say + // C, was result of a fusion of node A in function FA and node B in function + // FB, then `original_funcs` would be {FA, FB}. If the node is in the top + // level graph, the `original_func` is empty. This information, with the + // `original_node_names` can be used to map errors originating at the + // current ndoe to some top level source code. + repeated string original_func_names = 2; + }; + + // This stores debug information associated with the node. + ExperimentalDebugInfo experimental_debug_info = 6; +}; diff --git a/protos/op_def.proto b/protos/op_def.proto new file mode 100644 index 00000000..9f5f6bf8 --- /dev/null +++ b/protos/op_def.proto @@ -0,0 +1,170 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9>_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Named control outputs for this operation. Useful only for composite + // operations (i.e. functions) which want to name different control outputs. + repeated string control_output = 20; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + // TODO(josh11b): bool is_optional? + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/protos/predict.proto b/protos/predict.proto index 10a9f63a..f4a83f8c 100644 --- a/protos/predict.proto +++ b/protos/predict.proto @@ -10,21 +10,20 @@ import "model.proto"; // how inputs are mapped to tensors and how outputs are filtered before // returning to user. message PredictRequest { - // Model Specification. + // Model Specification. If version is not specified, will use the latest + // (numerical) version. ModelSpec model_spec = 1; // Input tensors. // Names of input tensor are alias names. The mapping from aliases to real - // input tensor names is expected to be stored as named generic signature - // under the key "inputs" in the model export. - // Each alias listed in a generic signature named "inputs" should be provided - // exactly once in order to run the prediction. + // input tensor names is stored in the SavedModel export as a prediction + // SignatureDef under the 'inputs' field. map inputs = 2; // Output filter. // Names specified are alias names. The mapping from aliases to real output - // tensor names is expected to be stored as named generic signature under - // the key "outputs" in the model export. + // tensor names is stored in the SavedModel export as a prediction + // SignatureDef under the 'outputs' field. // Only tensors specified here will be run/fetched and returned, with the // exception that when none is specified, all tensors specified in the // named signature will be run/fetched and returned. @@ -33,6 +32,9 @@ message PredictRequest { // Response for PredictRequest on successful run. message PredictResponse { + // Effective Model Specification used to process PredictRequest. + ModelSpec model_spec = 2; + // Output tensors. map outputs = 1; } diff --git a/protos/prediction_service.proto b/protos/prediction_service.proto index 5160863e..681796a6 100644 --- a/protos/prediction_service.proto +++ b/protos/prediction_service.proto @@ -3,11 +3,29 @@ syntax = "proto3"; package tensorflow.serving; option cc_enable_arenas = true; +// import "tensorflow_serving/apis/classification.proto"; +// import "tensorflow_serving/apis/inference.proto"; +// import "tensorflow_serving/apis/regression.proto"; +import "get_model_metadata.proto"; import "predict.proto"; +// open source marker; do not remove // PredictionService provides access to machine-learned models loaded by // model_servers. service PredictionService { + // Classify. + // rpc Classify(ClassificationRequest) returns (ClassificationResponse); + + // Regress. + // rpc Regress(RegressionRequest) returns (RegressionResponse); + // Predict -- provides access to loaded TensorFlow model. rpc Predict(PredictRequest) returns (PredictResponse); + + // MultiInference API for multi-headed models. + // rpc MultiInference(MultiInferenceRequest) returns (MultiInferenceResponse); + + // GetModelMetadata - provides access to metadata for loaded models. + rpc GetModelMetadata(GetModelMetadataRequest) + returns (GetModelMetadataResponse); } diff --git a/protos/resource_handle.proto b/protos/resource_handle.proto index f9f19ca5..82194668 100644 --- a/protos/resource_handle.proto +++ b/protos/resource_handle.proto @@ -2,14 +2,18 @@ syntax = "proto3"; package tensorflow; option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandleProto"; +option java_outer_classname = "ResourceHandle"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; + +import "tensor_shape.proto"; +import "types.proto"; // Protocol buffer representing a handle to a tensorflow resource. Handles are // not valid across executions, but can be serialized back and forth from within // a single run. -message ResourceHandle { +message ResourceHandleProto { // Unique name for the device containing the resource. string device = 1; @@ -26,4 +30,13 @@ message ResourceHandle { // For debug-only, the name of the type pointed to by this handle, if // available. string maybe_type_name = 5; + + // Protocol buffer representing a pair of (data type, tensor shape). + message DtypeAndShape { + DataType dtype = 1; + TensorShapeProto shape = 2; + } + + // Data types and shapes for the underlying resource. + repeated DtypeAndShape dtypes_and_shapes = 6; }; diff --git a/protos/saved_object_graph.proto b/protos/saved_object_graph.proto new file mode 100644 index 00000000..0bad3a57 --- /dev/null +++ b/protos/saved_object_graph.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +import "trackable_object_graph.proto"; +import "struct.proto"; +import "tensor_shape.proto"; +import "types.proto"; +import "versions.proto"; +import "variable.proto"; + +option cc_enable_arenas = true; + +package tensorflow; + +// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It +// describes the directed graph of Python objects (or equivalent in other +// languages) that make up a model, with nodes[0] at the root. + +// SavedObjectGraph shares some structure with TrackableObjectGraph, but +// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions +// and type information, while TrackableObjectGraph lives in the checkpoint +// and contains pointers only to variable values. + +message SavedObjectGraph { + // Flattened list of objects in the object graph. + // + // The position of the object in this list indicates its id. + // Nodes[0] is considered the root node. + repeated SavedObject nodes = 1; + + // Information about captures and output structures in concrete functions. + // Referenced from SavedBareConcreteFunction and SavedFunction. + map concrete_functions = 2; +} + +message SavedObject { + // Objects which this object depends on: named edges in the dependency + // graph. + // + // Note: currently only valid if kind == "user_object". + repeated TrackableObjectGraph.TrackableObject.ObjectReference + children = 1; + + // Removed when forking SavedObject from TrackableObjectGraph. + reserved "attributes"; + reserved 2; + + // Slot variables owned by this object. This describes the three-way + // (optimizer, variable, slot variable) relationship; none of the three + // depend on the others directly. + // + // Note: currently only valid if kind == "user_object". + repeated TrackableObjectGraph.TrackableObject.SlotVariableReference + slot_variables = 3; + + oneof kind { + SavedUserObject user_object = 4; + SavedAsset asset = 5; + SavedFunction function = 6; + SavedVariable variable = 7; + SavedBareConcreteFunction bare_concrete_function = 8; + SavedConstant constant = 9; + SavedResource resource = 10; + } +} + +// A SavedUserObject is an object (in the object-oriented language of the +// TensorFlow program) of some user- or framework-defined class other than +// those handled specifically by the other kinds of SavedObjects. +// +// This object cannot be evaluated as a tensor, and therefore cannot be bound +// to an input of a function. +message SavedUserObject { + // Corresponds to a registration of the type to use in the loading program. + string identifier = 1; + // Version information from the producer of this SavedUserObject. + VersionDef version = 2; + // Initialization-related metadata. + string metadata = 3; +} + +// A SavedAsset points to an asset in the MetaGraph. +// +// When bound to a function this object evaluates to a tensor with the absolute +// filename. Users should not depend on a particular part of the filename to +// remain stable (e.g. basename could be changed). +message SavedAsset { + // Index into `MetaGraphDef.asset_file_def[]` that describes the Asset. + // + // Only the field `AssetFileDef.filename` is used. Other fields, such as + // `AssetFileDef.tensor_info`, MUST be ignored. + int32 asset_file_def_index = 1; +} + +// A function with multiple signatures, possibly with non-Tensor arguments. +message SavedFunction { + repeated string concrete_functions = 1; + FunctionSpec function_spec = 2; +} + +// Stores low-level information about a concrete function. Referenced in either +// a SavedFunction or a SavedBareConcreteFunction. +message SavedConcreteFunction { + // Bound inputs to the function. The SavedObjects identified by the node ids + // given here are appended as extra inputs to the caller-supplied inputs. + // The only types of SavedObjects valid here are SavedVariable, SavedResource + // and SavedAsset. + repeated int32 bound_inputs = 2; + // Input in canonicalized form that was received to create this concrete + // function. + StructuredValue canonicalized_input_signature = 3; + // Output that was the return value of this function after replacing all + // Tensors with TensorSpecs. This can be an arbitrary nested function and will + // be used to reconstruct the full structure from pure tensors. + StructuredValue output_signature = 4; +} + +message SavedBareConcreteFunction { + // Identifies a SavedConcreteFunction. + string concrete_function_name = 1; + + // A sequence of unique strings, one per Tensor argument. + repeated string argument_keywords = 2; + // The prefix of `argument_keywords` which may be identified by position. + int64 allowed_positional_arguments = 3; +} + +message SavedConstant { + // An Operation name for a ConstantOp in this SavedObjectGraph's MetaGraph. + string operation = 1; +} + +// Represents a Variable that is initialized by loading the contents from the +// checkpoint. +message SavedVariable { + DataType dtype = 1; + TensorShapeProto shape = 2; + bool trainable = 3; + VariableSynchronization synchronization = 4; + VariableAggregation aggregation = 5; + string name = 6; +} + +// Represents `FunctionSpec` used in `Function`. This represents a +// function that has been wrapped as a TensorFlow `Function`. +message FunctionSpec { + // Full arg spec from inspect.getfullargspec(). + StructuredValue fullargspec = 1; + // Whether this represents a class method. + bool is_method = 2; + // The input signature, if specified. + StructuredValue input_signature = 5; + + reserved 3, 4; +} + +// A SavedResource represents a TF object that holds state during its lifetime. +// An object of this type can have a reference to a: +// create_resource() and an initialize() function. +message SavedResource { + // A device specification indicating a required placement for the resource + // creation function, e.g. "CPU". An empty string allows the user to select a + // device. + string device = 1; +} diff --git a/protos/saver.proto b/protos/saver.proto new file mode 100644 index 00000000..42453861 --- /dev/null +++ b/protos/saver.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "SaverProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.util"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf"; + +// Protocol buffer representing the configuration of a Saver. +message SaverDef { + // The name of the tensor in which to specify the filename when saving or + // restoring a model checkpoint. + string filename_tensor_name = 1; + + // The operation to run when saving a model checkpoint. + string save_tensor_name = 2; + + // The operation to run when restoring a model checkpoint. + string restore_op_name = 3; + + // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. + int32 max_to_keep = 4; + + // Shard the save files, one per device that has Variable nodes. + bool sharded = 5; + + // How often to keep an additional checkpoint. If not specified, only the last + // "max_to_keep" checkpoints are kept; if specified, in addition to keeping + // the last "max_to_keep" checkpoints, an additional checkpoint will be kept + // for every n hours of training. + float keep_checkpoint_every_n_hours = 6; + + // A version number that identifies a different on-disk checkpoint format. + // Usually, each subclass of BaseSaverBuilder works with a particular + // version/format. However, it is possible that the same builder may be + // upgraded to support a newer checkpoint format in the future. + enum CheckpointFormatVersion { + // Internal legacy format. + LEGACY = 0; + // Deprecated format: tf.Saver() which works with tensorflow::table::Table. + V1 = 1; + // Current format: more efficient. + V2 = 2; + } + CheckpointFormatVersion version = 7; +} diff --git a/protos/struct.proto b/protos/struct.proto new file mode 100644 index 00000000..7b590903 --- /dev/null +++ b/protos/struct.proto @@ -0,0 +1,134 @@ +syntax = "proto3"; + +import "tensor_shape.proto"; +import "types.proto"; + +package tensorflow; + +// `StructuredValue` represents a dynamically typed value representing various +// data structures that are inspired by Python data structures typically used in +// TensorFlow functions as inputs and outputs. +// +// For example when saving a Layer there may be a `training` argument. If the +// user passes a boolean True/False, that switches between two concrete +// TensorFlow functions. In order to switch between them in the same way after +// loading the SavedModel, we need to represent "True" and "False". +// +// A more advanced example might be a function which takes a list of +// dictionaries mapping from strings to Tensors. In order to map from +// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]` +// after load to the right saved TensorFlow function, we need to represent the +// nested structure and the strings, recording that we have a trace for anything +// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([], +// tf.float64)}]` as an example. +// +// Likewise functions may return nested structures of Tensors, for example +// returning a dictionary mapping from strings to Tensors. In order for the +// loaded function to return the same structure we need to serialize it. +// +// This is an ergonomic aid for working with loaded SavedModels, not a promise +// to serialize all possible function signatures. For example we do not expect +// to pickle generic Python objects, and ideally we'd stay language-agnostic. +message StructuredValue { + // The kind of value. + oneof kind { + // Represents None. + NoneValue none_value = 1; + + // Represents a double-precision floating-point value (a Python `float`). + double float64_value = 11; + // Represents a signed integer value, limited to 64 bits. + // Larger values from Python's arbitrary-precision integers are unsupported. + sint64 int64_value = 12; + // Represents a string of Unicode characters stored in a Python `str`. + // In Python 3, this is exactly what type `str` is. + // In Python 2, this is the UTF-8 encoding of the characters. + // For strings with ASCII characters only (as often used in TensorFlow code) + // there is effectively no difference between the language versions. + // The obsolescent `unicode` type of Python 2 is not supported here. + string string_value = 13; + // Represents a boolean value. + bool bool_value = 14; + + // Represents a TensorShape. + tensorflow.TensorShapeProto tensor_shape_value = 31; + // Represents an enum value for dtype. + tensorflow.DataType tensor_dtype_value = 32; + // Represents a value for tf.TensorSpec. + TensorSpecProto tensor_spec_value = 33; + // Represents a value for tf.TypeSpec. + TypeSpecProto type_spec_value = 34; + + // Represents a list of `Value`. + ListValue list_value = 51; + // Represents a tuple of `Value`. + TupleValue tuple_value = 52; + // Represents a dict `Value`. + DictValue dict_value = 53; + // Represents Python's namedtuple. + NamedTupleValue named_tuple_value = 54; + } +} + +// Represents None. +message NoneValue {} + +// Represents a Python list. +message ListValue { + repeated StructuredValue values = 1; +} + +// Represents a Python tuple. +message TupleValue { + repeated StructuredValue values = 1; +} + +// Represents a Python dict keyed by `str`. +// The comment on Unicode from Value.string_value applies analogously. +message DictValue { + map fields = 1; +} + +// Represents a (key, value) pair. +message PairValue { + string key = 1; + StructuredValue value = 2; +} + +// Represents Python's namedtuple. +message NamedTupleValue { + string name = 1; + repeated PairValue values = 2; +} + +// A protobuf to tf.TensorSpec. +message TensorSpecProto { + string name = 1; + tensorflow.TensorShapeProto shape = 2; + tensorflow.DataType dtype = 3; +} + +// Represents a tf.TypeSpec +message TypeSpecProto { + enum TypeSpecClass { + UNKNOWN = 0; + SPARSE_TENSOR_SPEC = 1; // tf.SparseTensorSpec + INDEXED_SLICES_SPEC = 2; // tf.IndexedSlicesSpec + RAGGED_TENSOR_SPEC = 3; // tf.RaggedTensorSpec + TENSOR_ARRAY_SPEC = 4; // tf.TensorArraySpec + DATA_DATASET_SPEC = 5; // tf.data.DatasetSpec + DATA_ITERATOR_SPEC = 6; // IteratorSpec from data/ops/iterator_ops.py + OPTIONAL_SPEC = 7; // tf.OptionalSpec + PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py + } + TypeSpecClass type_spec_class = 1; + + // The value returned by TypeSpec._serialize(). + StructuredValue type_state = 2; + + // This is currently redundant with the type_spec_class enum, and is only + // used for error reporting. In particular, if you use an older binary to + // load a newer model, and the model uses a TypeSpecClass that the older + // binary doesn't support, then this lets us display a useful error message. + string type_spec_class_name = 3; +} diff --git a/protos/tensor.proto b/protos/tensor.proto index c41372d3..5d4d66ae 100644 --- a/protos/tensor.proto +++ b/protos/tensor.proto @@ -5,7 +5,7 @@ option cc_enable_arenas = true; option java_outer_classname = "TensorProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; - +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; import "resource_handle.proto"; import "tensor_shape.proto"; import "types.proto"; @@ -28,8 +28,11 @@ message TensorProto { // to represent a constant Tensor with a single value. int32 version_number = 3; - // Serialized content from Tensor::AsProtoTensorContent(). This representation - // can be used for all tensor types. + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. bytes tensor_content = 4; // Type specific representations that make it easy to create tensor protos in @@ -37,8 +40,8 @@ message TensorProto { // be set. The values hold the flattened representation of the tensor in // row major order. - // DT_HALF. Note that since protobuf has no int16 type, we'll have some - // pointless zero padding for each value here. + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. repeated int32 half_val = 13 [packed = true]; // DT_FLOAT. @@ -48,7 +51,7 @@ message TensorProto { repeated double double_val = 6 [packed = true]; // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7; + repeated int32 int_val = 7 [packed = true]; // DT_STRING repeated bytes string_val = 8; @@ -68,5 +71,24 @@ message TensorProto { repeated double dcomplex_val = 12 [packed = true]; // DT_RESOURCE - repeated ResourceHandle resource_handle_val = 14; + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; }; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/protos/tensor_shape.proto b/protos/tensor_shape.proto index 1ec3c532..286156a0 100644 --- a/protos/tensor_shape.proto +++ b/protos/tensor_shape.proto @@ -5,6 +5,7 @@ option cc_enable_arenas = true; option java_outer_classname = "TensorShapeProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; package tensorflow; diff --git a/protos/trackable_object_graph.proto b/protos/trackable_object_graph.proto new file mode 100644 index 00000000..02d852e6 --- /dev/null +++ b/protos/trackable_object_graph.proto @@ -0,0 +1,59 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow; + +// A TensorBundle addition which saves extra information about the objects which +// own variables, allowing for more robust checkpoint loading into modified +// programs. + +message TrackableObjectGraph { + message TrackableObject { + message ObjectReference { + // An index into `TrackableObjectGraph.nodes`, indicating the object + // being referenced. + int32 node_id = 1; + // A user-provided name for the edge. + string local_name = 2; + } + + message SerializedTensor { + // A name for the Tensor. Simple variables have only one + // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may + // be restored on object creation as an optimization. + string name = 1; + // The full name of the variable/tensor, if applicable. Used to allow + // name-based loading of checkpoints which were saved using an + // object-based API. Should match the checkpoint key which would have been + // assigned by tf.train.Saver. + string full_name = 2; + // The generated name of the Tensor in the checkpoint. + string checkpoint_key = 3; + // Whether checkpoints should be considered as matching even without this + // value restored. Used for non-critical values which don't affect the + // TensorFlow graph, such as layer configurations. + bool optional_restore = 4; + } + + message SlotVariableReference { + // An index into `TrackableObjectGraph.nodes`, indicating the + // variable object this slot was created for. + int32 original_variable_node_id = 1; + // The name of the slot (e.g. "m"/"v"). + string slot_name = 2; + // An index into `TrackableObjectGraph.nodes`, indicating the + // `Object` with the value of the slot variable. + int32 slot_variable_node_id = 3; + } + + // Objects which this object depends on. + repeated ObjectReference children = 1; + // Serialized data specific to this object. + repeated SerializedTensor attributes = 2; + // Slot variables owned by this object. + repeated SlotVariableReference slot_variables = 3; + } + + repeated TrackableObject nodes = 1; +} diff --git a/protos/types.proto b/protos/types.proto index b80e2b31..5356f9f9 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -5,7 +5,9 @@ option cc_enable_arenas = true; option java_outer_classname = "TypesProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; +// (== suppress_warning documentation-presence ==) // LINT.IfChange enum DataType { // Not a legal value for DataType. Used to indicate a DataType field @@ -34,9 +36,9 @@ enum DataType { DT_COMPLEX128 = 18; // Double-precision complex DT_HALF = 19; DT_RESOURCE = 20; - - // TODO(josh11b): DT_GENERIC_PROTO = ??; - // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; // Do not use! These are only for parameters. Every enum above // should have a corresponding value below (verified by types_test). @@ -60,5 +62,15 @@ enum DataType { DT_COMPLEX128_REF = 118; DT_HALF_REF = 119; DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; } -// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/protos/variable.proto b/protos/variable.proto new file mode 100644 index 00000000..b2978c75 --- /dev/null +++ b/protos/variable.proto @@ -0,0 +1,85 @@ +syntax = "proto3"; + +package tensorflow; + +option cc_enable_arenas = true; +option java_outer_classname = "VariableProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; + +// Indicates when a distributed variable will be synced. +enum VariableSynchronization { + // `AUTO`: Indicates that the synchronization will be determined by the + // current `DistributionStrategy` (eg. With `MirroredStrategy` this would be + // `ON_WRITE`). + VARIABLE_SYNCHRONIZATION_AUTO = 0; + // `NONE`: Indicates that there will only be one copy of the variable, so + // there is no need to sync. + VARIABLE_SYNCHRONIZATION_NONE = 1; + // `ON_WRITE`: Indicates that the variable will be updated across devices + // every time it is written. + VARIABLE_SYNCHRONIZATION_ON_WRITE = 2; + // `ON_READ`: Indicates that the variable will be aggregated across devices + // when it is read (eg. when checkpointing or when evaluating an op that uses + // the variable). + VARIABLE_SYNCHRONIZATION_ON_READ = 3; +} + +// Indicates how a distributed variable will be aggregated. +enum VariableAggregation { + // `NONE`: This is the default, giving an error if you use a + // variable-update operation with multiple replicas. + VARIABLE_AGGREGATION_NONE = 0; + // `SUM`: Add the updates across replicas. + VARIABLE_AGGREGATION_SUM = 1; + // `MEAN`: Take the arithmetic mean ("average") of the updates across + // replicas. + VARIABLE_AGGREGATION_MEAN = 2; + // `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same + // update, but we only want to perform the update once. Used, e.g., for the + // global step counter. + VARIABLE_AGGREGATION_ONLY_FIRST_REPLICA = 3; +} + +// Protocol buffer representing a Variable. +message VariableDef { + // Name of the variable tensor. + string variable_name = 1; + + // Name of the tensor holding the variable's initial value. + string initial_value_name = 6; + + // Name of the initializer op. + string initializer_name = 2; + + // Name of the snapshot tensor. + string snapshot_name = 3; + + // Support for saving variables as slices of a larger variable. + SaveSliceInfoDef save_slice_info_def = 4; + + // Whether to represent this as a ResourceVariable. + bool is_resource = 5; + + // Whether this variable should be trained. + bool trainable = 7; + + // Indicates when a distributed variable will be synced. + VariableSynchronization synchronization = 8; + + // Indicates how a distributed variable will be aggregated. + VariableAggregation aggregation = 9; +} + +message SaveSliceInfoDef { + // Name of the full variable of which this is a slice. + string full_name = 1; + // Shape of the full variable. + repeated int64 full_shape = 2; + // Offset of this variable into the full variable. + repeated int64 var_offset = 3; + // Shape of this variable. + repeated int64 var_shape = 4; +} diff --git a/protos/versions.proto b/protos/versions.proto new file mode 100644 index 00000000..dd2ec552 --- /dev/null +++ b/protos/versions.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; From 3a3cc8f665246ab8446f5ef817e271ee8be50dc2 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sun, 1 Mar 2020 21:49:21 -0800 Subject: [PATCH 03/47] generated pbs from new proto files. --- redis_consumer/pbs/attr_value_pb2.py | 365 ++++++ redis_consumer/pbs/attr_value_pb2_grpc.py | 3 + redis_consumer/pbs/function_pb2.py | 451 ++++++- redis_consumer/pbs/get_model_metadata_pb2.py | 265 +++++ .../pbs/get_model_metadata_pb2_grpc.py | 3 + redis_consumer/pbs/graph_pb2.py | 98 ++ redis_consumer/pbs/graph_pb2_grpc.py | 3 + redis_consumer/pbs/meta_graph_pb2.py | 1031 +++++++++++++++++ redis_consumer/pbs/meta_graph_pb2_grpc.py | 3 + redis_consumer/pbs/model_pb2.py | 44 +- redis_consumer/pbs/node_def_pb2.py | 202 ++++ redis_consumer/pbs/node_def_pb2_grpc.py | 3 + redis_consumer/pbs/op_def_pb2.py | 404 +++++++ redis_consumer/pbs/op_def_pb2_grpc.py | 3 + redis_consumer/pbs/predict_pb2.py | 67 +- redis_consumer/pbs/prediction_service_pb2.py | 25 +- .../pbs/prediction_service_pb2_grpc.py | 38 +- redis_consumer/pbs/resource_handle_pb2.py | 112 +- redis_consumer/pbs/saved_object_graph_pb2.py | 720 ++++++++++++ .../pbs/saved_object_graph_pb2_grpc.py | 3 + redis_consumer/pbs/saver_pb2.py | 140 +++ redis_consumer/pbs/saver_pb2_grpc.py | 3 + redis_consumer/pbs/struct_pb2.py | 662 +++++++++++ redis_consumer/pbs/struct_pb2_grpc.py | 3 + redis_consumer/pbs/tensor_pb2.py | 122 +- redis_consumer/pbs/tensor_shape_pb2.py | 25 +- .../pbs/trackable_object_graph_pb2.py | 285 +++++ .../pbs/trackable_object_graph_pb2_grpc.py | 3 + redis_consumer/pbs/types_pb2.py | 79 +- redis_consumer/pbs/variable_pb2.py | 261 +++++ redis_consumer/pbs/variable_pb2_grpc.py | 3 + redis_consumer/pbs/versions_pb2.py | 83 ++ redis_consumer/pbs/versions_pb2_grpc.py | 3 + 33 files changed, 5354 insertions(+), 161 deletions(-) create mode 100644 redis_consumer/pbs/attr_value_pb2.py create mode 100644 redis_consumer/pbs/attr_value_pb2_grpc.py create mode 100644 redis_consumer/pbs/get_model_metadata_pb2.py create mode 100644 redis_consumer/pbs/get_model_metadata_pb2_grpc.py create mode 100644 redis_consumer/pbs/graph_pb2.py create mode 100644 redis_consumer/pbs/graph_pb2_grpc.py create mode 100644 redis_consumer/pbs/meta_graph_pb2.py create mode 100644 redis_consumer/pbs/meta_graph_pb2_grpc.py create mode 100644 redis_consumer/pbs/node_def_pb2.py create mode 100644 redis_consumer/pbs/node_def_pb2_grpc.py create mode 100644 redis_consumer/pbs/op_def_pb2.py create mode 100644 redis_consumer/pbs/op_def_pb2_grpc.py create mode 100644 redis_consumer/pbs/saved_object_graph_pb2.py create mode 100644 redis_consumer/pbs/saved_object_graph_pb2_grpc.py create mode 100644 redis_consumer/pbs/saver_pb2.py create mode 100644 redis_consumer/pbs/saver_pb2_grpc.py create mode 100644 redis_consumer/pbs/struct_pb2.py create mode 100644 redis_consumer/pbs/struct_pb2_grpc.py create mode 100644 redis_consumer/pbs/trackable_object_graph_pb2.py create mode 100644 redis_consumer/pbs/trackable_object_graph_pb2_grpc.py create mode 100644 redis_consumer/pbs/variable_pb2.py create mode 100644 redis_consumer/pbs/variable_pb2_grpc.py create mode 100644 redis_consumer/pbs/versions_pb2.py create mode 100644 redis_consumer/pbs/versions_pb2_grpc.py diff --git a/redis_consumer/pbs/attr_value_pb2.py b/redis_consumer/pbs/attr_value_pb2.py new file mode 100644 index 00000000..a059608c --- /dev/null +++ b/redis_consumer/pbs/attr_value_pb2.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: attr_value.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import tensor_pb2 as tensor__pb2 +import tensor_shape_pb2 as tensor__shape__pb2 +import types_pb2 as types__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='attr_value.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\n\030org.tensorflow.frameworkB\017AttrValueProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', + serialized_pb=b'\n\x10\x61ttr_value.proto\x12\ntensorflow\x1a\x0ctensor.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\"\xa6\x04\n\tAttrValue\x12\x0b\n\x01s\x18\x02 \x01(\x0cH\x00\x12\x0b\n\x01i\x18\x03 \x01(\x03H\x00\x12\x0b\n\x01\x66\x18\x04 \x01(\x02H\x00\x12\x0b\n\x01\x62\x18\x05 \x01(\x08H\x00\x12$\n\x04type\x18\x06 \x01(\x0e\x32\x14.tensorflow.DataTypeH\x00\x12-\n\x05shape\x18\x07 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProtoH\x00\x12)\n\x06tensor\x18\x08 \x01(\x0b\x32\x17.tensorflow.TensorProtoH\x00\x12/\n\x04list\x18\x01 \x01(\x0b\x32\x1f.tensorflow.AttrValue.ListValueH\x00\x12(\n\x04\x66unc\x18\n \x01(\x0b\x32\x18.tensorflow.NameAttrListH\x00\x12\x15\n\x0bplaceholder\x18\t \x01(\tH\x00\x1a\xe9\x01\n\tListValue\x12\t\n\x01s\x18\x02 \x03(\x0c\x12\r\n\x01i\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\r\n\x01\x66\x18\x04 \x03(\x02\x42\x02\x10\x01\x12\r\n\x01\x62\x18\x05 \x03(\x08\x42\x02\x10\x01\x12&\n\x04type\x18\x06 \x03(\x0e\x32\x14.tensorflow.DataTypeB\x02\x10\x01\x12+\n\x05shape\x18\x07 \x03(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\'\n\x06tensor\x18\x08 \x03(\x0b\x32\x17.tensorflow.TensorProto\x12&\n\x04\x66unc\x18\t \x03(\x0b\x32\x18.tensorflow.NameAttrListB\x07\n\x05value\"\x92\x01\n\x0cNameAttrList\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x30\n\x04\x61ttr\x18\x02 \x03(\x0b\x32\".tensorflow.NameAttrList.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x42o\n\x18org.tensorflow.frameworkB\x0f\x41ttrValueProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' + , + dependencies=[tensor__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) + + + + +_ATTRVALUE_LISTVALUE = _descriptor.Descriptor( + name='ListValue', + full_name='tensorflow.AttrValue.ListValue', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='s', full_name='tensorflow.AttrValue.ListValue.s', index=0, + number=2, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='i', full_name='tensorflow.AttrValue.ListValue.i', index=1, + number=3, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='f', full_name='tensorflow.AttrValue.ListValue.f', index=2, + number=4, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='b', full_name='tensorflow.AttrValue.ListValue.b', index=3, + number=5, type=8, cpp_type=7, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='tensorflow.AttrValue.ListValue.type', index=4, + number=6, type=14, cpp_type=8, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='shape', full_name='tensorflow.AttrValue.ListValue.shape', index=5, + number=7, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor', full_name='tensorflow.AttrValue.ListValue.tensor', index=6, + number=8, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='func', full_name='tensorflow.AttrValue.ListValue.func', index=7, + number=9, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=388, + serialized_end=621, +) + +_ATTRVALUE = _descriptor.Descriptor( + name='AttrValue', + full_name='tensorflow.AttrValue', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='s', full_name='tensorflow.AttrValue.s', index=0, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='i', full_name='tensorflow.AttrValue.i', index=1, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='f', full_name='tensorflow.AttrValue.f', index=2, + number=4, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='b', full_name='tensorflow.AttrValue.b', index=3, + number=5, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='tensorflow.AttrValue.type', index=4, + number=6, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='shape', full_name='tensorflow.AttrValue.shape', index=5, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor', full_name='tensorflow.AttrValue.tensor', index=6, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='list', full_name='tensorflow.AttrValue.list', index=7, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='func', full_name='tensorflow.AttrValue.func', index=8, + number=10, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='placeholder', full_name='tensorflow.AttrValue.placeholder', index=9, + number=9, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_ATTRVALUE_LISTVALUE, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='value', full_name='tensorflow.AttrValue.value', + index=0, containing_type=None, fields=[]), + ], + serialized_start=80, + serialized_end=630, +) + + +_NAMEATTRLIST_ATTRENTRY = _descriptor.Descriptor( + name='AttrEntry', + full_name='tensorflow.NameAttrList.AttrEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.NameAttrList.AttrEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.NameAttrList.AttrEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=713, + serialized_end=779, +) + +_NAMEATTRLIST = _descriptor.Descriptor( + name='NameAttrList', + full_name='tensorflow.NameAttrList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.NameAttrList.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='attr', full_name='tensorflow.NameAttrList.attr', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_NAMEATTRLIST_ATTRENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=633, + serialized_end=779, +) + +_ATTRVALUE_LISTVALUE.fields_by_name['type'].enum_type = types__pb2._DATATYPE +_ATTRVALUE_LISTVALUE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO +_ATTRVALUE_LISTVALUE.fields_by_name['tensor'].message_type = tensor__pb2._TENSORPROTO +_ATTRVALUE_LISTVALUE.fields_by_name['func'].message_type = _NAMEATTRLIST +_ATTRVALUE_LISTVALUE.containing_type = _ATTRVALUE +_ATTRVALUE.fields_by_name['type'].enum_type = types__pb2._DATATYPE +_ATTRVALUE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO +_ATTRVALUE.fields_by_name['tensor'].message_type = tensor__pb2._TENSORPROTO +_ATTRVALUE.fields_by_name['list'].message_type = _ATTRVALUE_LISTVALUE +_ATTRVALUE.fields_by_name['func'].message_type = _NAMEATTRLIST +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['s']) +_ATTRVALUE.fields_by_name['s'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['i']) +_ATTRVALUE.fields_by_name['i'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['f']) +_ATTRVALUE.fields_by_name['f'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['b']) +_ATTRVALUE.fields_by_name['b'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['type']) +_ATTRVALUE.fields_by_name['type'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['shape']) +_ATTRVALUE.fields_by_name['shape'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['tensor']) +_ATTRVALUE.fields_by_name['tensor'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['list']) +_ATTRVALUE.fields_by_name['list'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['func']) +_ATTRVALUE.fields_by_name['func'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_ATTRVALUE.oneofs_by_name['value'].fields.append( + _ATTRVALUE.fields_by_name['placeholder']) +_ATTRVALUE.fields_by_name['placeholder'].containing_oneof = _ATTRVALUE.oneofs_by_name['value'] +_NAMEATTRLIST_ATTRENTRY.fields_by_name['value'].message_type = _ATTRVALUE +_NAMEATTRLIST_ATTRENTRY.containing_type = _NAMEATTRLIST +_NAMEATTRLIST.fields_by_name['attr'].message_type = _NAMEATTRLIST_ATTRENTRY +DESCRIPTOR.message_types_by_name['AttrValue'] = _ATTRVALUE +DESCRIPTOR.message_types_by_name['NameAttrList'] = _NAMEATTRLIST +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +AttrValue = _reflection.GeneratedProtocolMessageType('AttrValue', (_message.Message,), { + + 'ListValue' : _reflection.GeneratedProtocolMessageType('ListValue', (_message.Message,), { + 'DESCRIPTOR' : _ATTRVALUE_LISTVALUE, + '__module__' : 'attr_value_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.AttrValue.ListValue) + }) + , + 'DESCRIPTOR' : _ATTRVALUE, + '__module__' : 'attr_value_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.AttrValue) + }) +_sym_db.RegisterMessage(AttrValue) +_sym_db.RegisterMessage(AttrValue.ListValue) + +NameAttrList = _reflection.GeneratedProtocolMessageType('NameAttrList', (_message.Message,), { + + 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { + 'DESCRIPTOR' : _NAMEATTRLIST_ATTRENTRY, + '__module__' : 'attr_value_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.NameAttrList.AttrEntry) + }) + , + 'DESCRIPTOR' : _NAMEATTRLIST, + '__module__' : 'attr_value_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.NameAttrList) + }) +_sym_db.RegisterMessage(NameAttrList) +_sym_db.RegisterMessage(NameAttrList.AttrEntry) + + +DESCRIPTOR._options = None +_ATTRVALUE_LISTVALUE.fields_by_name['i']._options = None +_ATTRVALUE_LISTVALUE.fields_by_name['f']._options = None +_ATTRVALUE_LISTVALUE.fields_by_name['b']._options = None +_ATTRVALUE_LISTVALUE.fields_by_name['type']._options = None +_NAMEATTRLIST_ATTRENTRY._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/attr_value_pb2_grpc.py b/redis_consumer/pbs/attr_value_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/attr_value_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/function_pb2.py b/redis_consumer/pbs/function_pb2.py index fb4f1f13..2b6c5658 100644 --- a/redis_consumer/pbs/function_pb2.py +++ b/redis_consumer/pbs/function_pb2.py @@ -1,8 +1,7 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: function.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -12,37 +11,360 @@ _sym_db = _symbol_database.Default() +import attr_value_pb2 as attr__value__pb2 +import node_def_pb2 as node__def__pb2 +import op_def_pb2 as op__def__pb2 DESCRIPTOR = _descriptor.FileDescriptor( name='function.proto', - package='tensorflow.serving', + package='tensorflow', syntax='proto3', - serialized_options=_b('\370\001\001'), - serialized_pb=_b('\n\x0e\x66unction.proto\x12\x12tensorflow.serving\"*\n\x0c\x46unctionSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3') + serialized_options=b'\n\030org.tensorflow.frameworkB\016FunctionProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', + serialized_pb=b'\n\x0e\x66unction.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\x1a\x0enode_def.proto\x1a\x0cop_def.proto\"j\n\x12\x46unctionDefLibrary\x12)\n\x08\x66unction\x18\x01 \x03(\x0b\x32\x17.tensorflow.FunctionDef\x12)\n\x08gradient\x18\x02 \x03(\x0b\x32\x17.tensorflow.GradientDef\"\xb6\x05\n\x0b\x46unctionDef\x12$\n\tsignature\x18\x01 \x01(\x0b\x32\x11.tensorflow.OpDef\x12/\n\x04\x61ttr\x18\x05 \x03(\x0b\x32!.tensorflow.FunctionDef.AttrEntry\x12\x36\n\x08\x61rg_attr\x18\x07 \x03(\x0b\x32$.tensorflow.FunctionDef.ArgAttrEntry\x12%\n\x08node_def\x18\x03 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12-\n\x03ret\x18\x04 \x03(\x0b\x32 .tensorflow.FunctionDef.RetEntry\x12<\n\x0b\x63ontrol_ret\x18\x06 \x03(\x0b\x32\'.tensorflow.FunctionDef.ControlRetEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1a\x88\x01\n\x08\x41rgAttrs\x12\x38\n\x04\x61ttr\x18\x01 \x03(\x0b\x32*.tensorflow.FunctionDef.ArgAttrs.AttrEntry\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1aP\n\x0c\x41rgAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\r\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .tensorflow.FunctionDef.ArgAttrs:\x02\x38\x01\x1a*\n\x08RetEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x31\n\x0f\x43ontrolRetEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01J\x04\x08\x02\x10\x03\";\n\x0bGradientDef\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12\x15\n\rgradient_func\x18\x02 \x01(\tBn\n\x18org.tensorflow.frameworkB\x0e\x46unctionProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' + , + dependencies=[attr__value__pb2.DESCRIPTOR,node__def__pb2.DESCRIPTOR,op__def__pb2.DESCRIPTOR,]) + + + + +_FUNCTIONDEFLIBRARY = _descriptor.Descriptor( + name='FunctionDefLibrary', + full_name='tensorflow.FunctionDefLibrary', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='function', full_name='tensorflow.FunctionDefLibrary.function', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='gradient', full_name='tensorflow.FunctionDefLibrary.gradient', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=78, + serialized_end=184, +) + + +_FUNCTIONDEF_ATTRENTRY = _descriptor.Descriptor( + name='AttrEntry', + full_name='tensorflow.FunctionDef.AttrEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.FunctionDef.AttrEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.FunctionDef.AttrEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=493, + serialized_end=559, +) + +_FUNCTIONDEF_ARGATTRS_ATTRENTRY = _descriptor.Descriptor( + name='AttrEntry', + full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.FunctionDef.ArgAttrs.AttrEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=493, + serialized_end=559, +) + +_FUNCTIONDEF_ARGATTRS = _descriptor.Descriptor( + name='ArgAttrs', + full_name='tensorflow.FunctionDef.ArgAttrs', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='attr', full_name='tensorflow.FunctionDef.ArgAttrs.attr', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_FUNCTIONDEF_ARGATTRS_ATTRENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=562, + serialized_end=698, +) + +_FUNCTIONDEF_ARGATTRENTRY = _descriptor.Descriptor( + name='ArgAttrEntry', + full_name='tensorflow.FunctionDef.ArgAttrEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.FunctionDef.ArgAttrEntry.key', index=0, + number=1, type=13, cpp_type=3, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.FunctionDef.ArgAttrEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=700, + serialized_end=780, ) +_FUNCTIONDEF_RETENTRY = _descriptor.Descriptor( + name='RetEntry', + full_name='tensorflow.FunctionDef.RetEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.FunctionDef.RetEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.FunctionDef.RetEntry.value', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=782, + serialized_end=824, +) + +_FUNCTIONDEF_CONTROLRETENTRY = _descriptor.Descriptor( + name='ControlRetEntry', + full_name='tensorflow.FunctionDef.ControlRetEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.FunctionDef.ControlRetEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.FunctionDef.ControlRetEntry.value', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=826, + serialized_end=875, +) +_FUNCTIONDEF = _descriptor.Descriptor( + name='FunctionDef', + full_name='tensorflow.FunctionDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='signature', full_name='tensorflow.FunctionDef.signature', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='attr', full_name='tensorflow.FunctionDef.attr', index=1, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='arg_attr', full_name='tensorflow.FunctionDef.arg_attr', index=2, + number=7, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='node_def', full_name='tensorflow.FunctionDef.node_def', index=3, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='ret', full_name='tensorflow.FunctionDef.ret', index=4, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='control_ret', full_name='tensorflow.FunctionDef.control_ret', index=5, + number=6, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_FUNCTIONDEF_ATTRENTRY, _FUNCTIONDEF_ARGATTRS, _FUNCTIONDEF_ARGATTRENTRY, _FUNCTIONDEF_RETENTRY, _FUNCTIONDEF_CONTROLRETENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=187, + serialized_end=881, +) -_FUNCTIONSPEC = _descriptor.Descriptor( - name='FunctionSpec', - full_name='tensorflow.serving.FunctionSpec', +_GRADIENTDEF = _descriptor.Descriptor( + name='GradientDef', + full_name='tensorflow.GradientDef', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.serving.FunctionSpec.name', index=0, + name='function_name', full_name='tensorflow.GradientDef.function_name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.serving.FunctionSpec.type', index=1, + name='gradient_func', full_name='tensorflow.GradientDef.gradient_func', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -58,20 +380,107 @@ extension_ranges=[], oneofs=[ ], - serialized_start=38, - serialized_end=80, + serialized_start=883, + serialized_end=942, ) -DESCRIPTOR.message_types_by_name['FunctionSpec'] = _FUNCTIONSPEC +_FUNCTIONDEFLIBRARY.fields_by_name['function'].message_type = _FUNCTIONDEF +_FUNCTIONDEFLIBRARY.fields_by_name['gradient'].message_type = _GRADIENTDEF +_FUNCTIONDEF_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE +_FUNCTIONDEF_ATTRENTRY.containing_type = _FUNCTIONDEF +_FUNCTIONDEF_ARGATTRS_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE +_FUNCTIONDEF_ARGATTRS_ATTRENTRY.containing_type = _FUNCTIONDEF_ARGATTRS +_FUNCTIONDEF_ARGATTRS.fields_by_name['attr'].message_type = _FUNCTIONDEF_ARGATTRS_ATTRENTRY +_FUNCTIONDEF_ARGATTRS.containing_type = _FUNCTIONDEF +_FUNCTIONDEF_ARGATTRENTRY.fields_by_name['value'].message_type = _FUNCTIONDEF_ARGATTRS +_FUNCTIONDEF_ARGATTRENTRY.containing_type = _FUNCTIONDEF +_FUNCTIONDEF_RETENTRY.containing_type = _FUNCTIONDEF +_FUNCTIONDEF_CONTROLRETENTRY.containing_type = _FUNCTIONDEF +_FUNCTIONDEF.fields_by_name['signature'].message_type = op__def__pb2._OPDEF +_FUNCTIONDEF.fields_by_name['attr'].message_type = _FUNCTIONDEF_ATTRENTRY +_FUNCTIONDEF.fields_by_name['arg_attr'].message_type = _FUNCTIONDEF_ARGATTRENTRY +_FUNCTIONDEF.fields_by_name['node_def'].message_type = node__def__pb2._NODEDEF +_FUNCTIONDEF.fields_by_name['ret'].message_type = _FUNCTIONDEF_RETENTRY +_FUNCTIONDEF.fields_by_name['control_ret'].message_type = _FUNCTIONDEF_CONTROLRETENTRY +DESCRIPTOR.message_types_by_name['FunctionDefLibrary'] = _FUNCTIONDEFLIBRARY +DESCRIPTOR.message_types_by_name['FunctionDef'] = _FUNCTIONDEF +DESCRIPTOR.message_types_by_name['GradientDef'] = _GRADIENTDEF _sym_db.RegisterFileDescriptor(DESCRIPTOR) -FunctionSpec = _reflection.GeneratedProtocolMessageType('FunctionSpec', (_message.Message,), dict( - DESCRIPTOR = _FUNCTIONSPEC, - __module__ = 'function_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.FunctionSpec) - )) -_sym_db.RegisterMessage(FunctionSpec) +FunctionDefLibrary = _reflection.GeneratedProtocolMessageType('FunctionDefLibrary', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONDEFLIBRARY, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDefLibrary) + }) +_sym_db.RegisterMessage(FunctionDefLibrary) + +FunctionDef = _reflection.GeneratedProtocolMessageType('FunctionDef', (_message.Message,), { + + 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONDEF_ATTRENTRY, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.AttrEntry) + }) + , + + 'ArgAttrs' : _reflection.GeneratedProtocolMessageType('ArgAttrs', (_message.Message,), { + + 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRS_ATTRENTRY, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrs.AttrEntry) + }) + , + 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRS, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrs) + }) + , + + 'ArgAttrEntry' : _reflection.GeneratedProtocolMessageType('ArgAttrEntry', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONDEF_ARGATTRENTRY, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ArgAttrEntry) + }) + , + + 'RetEntry' : _reflection.GeneratedProtocolMessageType('RetEntry', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONDEF_RETENTRY, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.RetEntry) + }) + , + + 'ControlRetEntry' : _reflection.GeneratedProtocolMessageType('ControlRetEntry', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONDEF_CONTROLRETENTRY, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef.ControlRetEntry) + }) + , + 'DESCRIPTOR' : _FUNCTIONDEF, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionDef) + }) +_sym_db.RegisterMessage(FunctionDef) +_sym_db.RegisterMessage(FunctionDef.AttrEntry) +_sym_db.RegisterMessage(FunctionDef.ArgAttrs) +_sym_db.RegisterMessage(FunctionDef.ArgAttrs.AttrEntry) +_sym_db.RegisterMessage(FunctionDef.ArgAttrEntry) +_sym_db.RegisterMessage(FunctionDef.RetEntry) +_sym_db.RegisterMessage(FunctionDef.ControlRetEntry) + +GradientDef = _reflection.GeneratedProtocolMessageType('GradientDef', (_message.Message,), { + 'DESCRIPTOR' : _GRADIENTDEF, + '__module__' : 'function_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.GradientDef) + }) +_sym_db.RegisterMessage(GradientDef) DESCRIPTOR._options = None +_FUNCTIONDEF_ATTRENTRY._options = None +_FUNCTIONDEF_ARGATTRS_ATTRENTRY._options = None +_FUNCTIONDEF_ARGATTRENTRY._options = None +_FUNCTIONDEF_RETENTRY._options = None +_FUNCTIONDEF_CONTROLRETENTRY._options = None # @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/get_model_metadata_pb2.py b/redis_consumer/pbs/get_model_metadata_pb2.py new file mode 100644 index 00000000..bc5579fb --- /dev/null +++ b/redis_consumer/pbs/get_model_metadata_pb2.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: get_model_metadata.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +import meta_graph_pb2 as meta__graph__pb2 +import model_pb2 as model__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='get_model_metadata.proto', + package='tensorflow.serving', + syntax='proto3', + serialized_options=b'\370\001\001', + serialized_pb=b'\n\x18get_model_metadata.proto\x12\x12tensorflow.serving\x1a\x19google/protobuf/any.proto\x1a\x10meta_graph.proto\x1a\x0bmodel.proto\"\xae\x01\n\x0fSignatureDefMap\x12L\n\rsignature_def\x18\x01 \x03(\x0b\x32\x35.tensorflow.serving.SignatureDefMap.SignatureDefEntry\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"d\n\x17GetModelMetadataRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x16\n\x0emetadata_field\x18\x02 \x03(\t\"\xe2\x01\n\x18GetModelMetadataResponse\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12L\n\x08metadata\x18\x02 \x03(\x0b\x32:.tensorflow.serving.GetModelMetadataResponse.MetadataEntry\x1a\x45\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' + , + dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,meta__graph__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) + + + + +_SIGNATUREDEFMAP_SIGNATUREDEFENTRY = _descriptor.Descriptor( + name='SignatureDefEntry', + full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=204, + serialized_end=281, +) + +_SIGNATUREDEFMAP = _descriptor.Descriptor( + name='SignatureDefMap', + full_name='tensorflow.serving.SignatureDefMap', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='signature_def', full_name='tensorflow.serving.SignatureDefMap.signature_def', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_SIGNATUREDEFMAP_SIGNATUREDEFENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=107, + serialized_end=281, +) + + +_GETMODELMETADATAREQUEST = _descriptor.Descriptor( + name='GetModelMetadataRequest', + full_name='tensorflow.serving.GetModelMetadataRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='model_spec', full_name='tensorflow.serving.GetModelMetadataRequest.model_spec', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='metadata_field', full_name='tensorflow.serving.GetModelMetadataRequest.metadata_field', index=1, + number=2, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=283, + serialized_end=383, +) + + +_GETMODELMETADATARESPONSE_METADATAENTRY = _descriptor.Descriptor( + name='MetadataEntry', + full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=543, + serialized_end=612, +) + +_GETMODELMETADATARESPONSE = _descriptor.Descriptor( + name='GetModelMetadataResponse', + full_name='tensorflow.serving.GetModelMetadataResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='model_spec', full_name='tensorflow.serving.GetModelMetadataResponse.model_spec', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='metadata', full_name='tensorflow.serving.GetModelMetadataResponse.metadata', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_GETMODELMETADATARESPONSE_METADATAENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=386, + serialized_end=612, +) + +_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = meta__graph__pb2._SIGNATUREDEF +_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.containing_type = _SIGNATUREDEFMAP +_SIGNATUREDEFMAP.fields_by_name['signature_def'].message_type = _SIGNATUREDEFMAP_SIGNATUREDEFENTRY +_GETMODELMETADATAREQUEST.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC +_GETMODELMETADATARESPONSE_METADATAENTRY.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY +_GETMODELMETADATARESPONSE_METADATAENTRY.containing_type = _GETMODELMETADATARESPONSE +_GETMODELMETADATARESPONSE.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC +_GETMODELMETADATARESPONSE.fields_by_name['metadata'].message_type = _GETMODELMETADATARESPONSE_METADATAENTRY +DESCRIPTOR.message_types_by_name['SignatureDefMap'] = _SIGNATUREDEFMAP +DESCRIPTOR.message_types_by_name['GetModelMetadataRequest'] = _GETMODELMETADATAREQUEST +DESCRIPTOR.message_types_by_name['GetModelMetadataResponse'] = _GETMODELMETADATARESPONSE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +SignatureDefMap = _reflection.GeneratedProtocolMessageType('SignatureDefMap', (_message.Message,), { + + 'SignatureDefEntry' : _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), { + 'DESCRIPTOR' : _SIGNATUREDEFMAP_SIGNATUREDEFENTRY, + '__module__' : 'get_model_metadata_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap.SignatureDefEntry) + }) + , + 'DESCRIPTOR' : _SIGNATUREDEFMAP, + '__module__' : 'get_model_metadata_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap) + }) +_sym_db.RegisterMessage(SignatureDefMap) +_sym_db.RegisterMessage(SignatureDefMap.SignatureDefEntry) + +GetModelMetadataRequest = _reflection.GeneratedProtocolMessageType('GetModelMetadataRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETMODELMETADATAREQUEST, + '__module__' : 'get_model_metadata_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataRequest) + }) +_sym_db.RegisterMessage(GetModelMetadataRequest) + +GetModelMetadataResponse = _reflection.GeneratedProtocolMessageType('GetModelMetadataResponse', (_message.Message,), { + + 'MetadataEntry' : _reflection.GeneratedProtocolMessageType('MetadataEntry', (_message.Message,), { + 'DESCRIPTOR' : _GETMODELMETADATARESPONSE_METADATAENTRY, + '__module__' : 'get_model_metadata_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse.MetadataEntry) + }) + , + 'DESCRIPTOR' : _GETMODELMETADATARESPONSE, + '__module__' : 'get_model_metadata_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse) + }) +_sym_db.RegisterMessage(GetModelMetadataResponse) +_sym_db.RegisterMessage(GetModelMetadataResponse.MetadataEntry) + + +DESCRIPTOR._options = None +_SIGNATUREDEFMAP_SIGNATUREDEFENTRY._options = None +_GETMODELMETADATARESPONSE_METADATAENTRY._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/get_model_metadata_pb2_grpc.py b/redis_consumer/pbs/get_model_metadata_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/get_model_metadata_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/graph_pb2.py b/redis_consumer/pbs/graph_pb2.py new file mode 100644 index 00000000..fecd4080 --- /dev/null +++ b/redis_consumer/pbs/graph_pb2.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: graph.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import node_def_pb2 as node__def__pb2 +import function_pb2 as function__pb2 +import versions_pb2 as versions__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='graph.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\n\030org.tensorflow.frameworkB\013GraphProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', + serialized_pb=b'\n\x0bgraph.proto\x12\ntensorflow\x1a\x0enode_def.proto\x1a\x0e\x66unction.proto\x1a\x0eversions.proto\"\x9d\x01\n\x08GraphDef\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.tensorflow.NodeDef\x12(\n\x08versions\x18\x04 \x01(\x0b\x32\x16.tensorflow.VersionDef\x12\x13\n\x07version\x18\x03 \x01(\x05\x42\x02\x18\x01\x12/\n\x07library\x18\x02 \x01(\x0b\x32\x1e.tensorflow.FunctionDefLibraryBk\n\x18org.tensorflow.frameworkB\x0bGraphProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' + , + dependencies=[node__def__pb2.DESCRIPTOR,function__pb2.DESCRIPTOR,versions__pb2.DESCRIPTOR,]) + + + + +_GRAPHDEF = _descriptor.Descriptor( + name='GraphDef', + full_name='tensorflow.GraphDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='node', full_name='tensorflow.GraphDef.node', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='versions', full_name='tensorflow.GraphDef.versions', index=1, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='version', full_name='tensorflow.GraphDef.version', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\030\001', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='library', full_name='tensorflow.GraphDef.library', index=3, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=76, + serialized_end=233, +) + +_GRAPHDEF.fields_by_name['node'].message_type = node__def__pb2._NODEDEF +_GRAPHDEF.fields_by_name['versions'].message_type = versions__pb2._VERSIONDEF +_GRAPHDEF.fields_by_name['library'].message_type = function__pb2._FUNCTIONDEFLIBRARY +DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), { + 'DESCRIPTOR' : _GRAPHDEF, + '__module__' : 'graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.GraphDef) + }) +_sym_db.RegisterMessage(GraphDef) + + +DESCRIPTOR._options = None +_GRAPHDEF.fields_by_name['version']._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/graph_pb2_grpc.py b/redis_consumer/pbs/graph_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/graph_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/meta_graph_pb2.py b/redis_consumer/pbs/meta_graph_pb2.py new file mode 100644 index 00000000..de9d402e --- /dev/null +++ b/redis_consumer/pbs/meta_graph_pb2.py @@ -0,0 +1,1031 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: meta_graph.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +import graph_pb2 as graph__pb2 +import op_def_pb2 as op__def__pb2 +import tensor_shape_pb2 as tensor__shape__pb2 +import types_pb2 as types__pb2 +import saved_object_graph_pb2 as saved__object__graph__pb2 +import saver_pb2 as saver__pb2 +import struct_pb2 as struct__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='meta_graph.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\n\030org.tensorflow.frameworkB\017MetaGraphProtosP\001ZHgithub.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto\370\001\001', + serialized_pb=b'\n\x10meta_graph.proto\x12\ntensorflow\x1a\x19google/protobuf/any.proto\x1a\x0bgraph.proto\x1a\x0cop_def.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\x1a\x18saved_object_graph.proto\x1a\x0bsaver.proto\x1a\x0cstruct.proto\"\xa8\x07\n\x0cMetaGraphDef\x12;\n\rmeta_info_def\x18\x01 \x01(\x0b\x32$.tensorflow.MetaGraphDef.MetaInfoDef\x12\'\n\tgraph_def\x18\x02 \x01(\x0b\x32\x14.tensorflow.GraphDef\x12\'\n\tsaver_def\x18\x03 \x01(\x0b\x32\x14.tensorflow.SaverDef\x12\x43\n\x0e\x63ollection_def\x18\x04 \x03(\x0b\x32+.tensorflow.MetaGraphDef.CollectionDefEntry\x12\x41\n\rsignature_def\x18\x05 \x03(\x0b\x32*.tensorflow.MetaGraphDef.SignatureDefEntry\x12\x30\n\x0e\x61sset_file_def\x18\x06 \x03(\x0b\x32\x18.tensorflow.AssetFileDef\x12\x36\n\x10object_graph_def\x18\x07 \x01(\x0b\x32\x1c.tensorflow.SavedObjectGraph\x1a\xf6\x02\n\x0bMetaInfoDef\x12\x1a\n\x12meta_graph_version\x18\x01 \x01(\t\x12,\n\x10stripped_op_list\x18\x02 \x01(\x0b\x32\x12.tensorflow.OpList\x12&\n\x08\x61ny_info\x18\x03 \x01(\x0b\x32\x14.google.protobuf.Any\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x1a\n\x12tensorflow_version\x18\x05 \x01(\t\x12\x1e\n\x16tensorflow_git_version\x18\x06 \x01(\t\x12\x1e\n\x16stripped_default_attrs\x18\x07 \x01(\x08\x12S\n\x10\x66unction_aliases\x18\x08 \x03(\x0b\x32\x39.tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry\x1a\x36\n\x14\x46unctionAliasesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1aO\n\x12\x43ollectionDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.tensorflow.CollectionDef:\x02\x38\x01\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"\xdf\x03\n\rCollectionDef\x12\x37\n\tnode_list\x18\x01 \x01(\x0b\x32\".tensorflow.CollectionDef.NodeListH\x00\x12\x39\n\nbytes_list\x18\x02 \x01(\x0b\x32#.tensorflow.CollectionDef.BytesListH\x00\x12\x39\n\nint64_list\x18\x03 \x01(\x0b\x32#.tensorflow.CollectionDef.Int64ListH\x00\x12\x39\n\nfloat_list\x18\x04 \x01(\x0b\x32#.tensorflow.CollectionDef.FloatListH\x00\x12\x35\n\x08\x61ny_list\x18\x05 \x01(\x0b\x32!.tensorflow.CollectionDef.AnyListH\x00\x1a\x19\n\x08NodeList\x12\r\n\x05value\x18\x01 \x03(\t\x1a\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\x1a\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a.\n\x07\x41nyList\x12#\n\x05value\x18\x01 \x03(\x0b\x32\x14.google.protobuf.AnyB\x06\n\x04kind\"\xd1\x03\n\nTensorInfo\x12\x0e\n\x04name\x18\x01 \x01(\tH\x00\x12\x36\n\ncoo_sparse\x18\x04 \x01(\x0b\x32 .tensorflow.TensorInfo.CooSparseH\x00\x12\x42\n\x10\x63omposite_tensor\x18\x05 \x01(\x0b\x32&.tensorflow.TensorInfo.CompositeTensorH\x00\x12#\n\x05\x64type\x18\x02 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x32\n\x0ctensor_shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x1a\x65\n\tCooSparse\x12\x1a\n\x12values_tensor_name\x18\x01 \x01(\t\x12\x1b\n\x13indices_tensor_name\x18\x02 \x01(\t\x12\x1f\n\x17\x64\x65nse_shape_tensor_name\x18\x03 \x01(\t\x1ak\n\x0f\x43ompositeTensor\x12,\n\ttype_spec\x18\x01 \x01(\x0b\x32\x19.tensorflow.TypeSpecProto\x12*\n\ncomponents\x18\x02 \x03(\x0b\x32\x16.tensorflow.TensorInfoB\n\n\x08\x65ncoding\"\xa0\x02\n\x0cSignatureDef\x12\x34\n\x06inputs\x18\x01 \x03(\x0b\x32$.tensorflow.SignatureDef.InputsEntry\x12\x36\n\x07outputs\x18\x02 \x03(\x0b\x32%.tensorflow.SignatureDef.OutputsEntry\x12\x13\n\x0bmethod_name\x18\x03 \x01(\t\x1a\x45\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.TensorInfo:\x02\x38\x01\x1a\x46\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.TensorInfo:\x02\x38\x01\"M\n\x0c\x41ssetFileDef\x12+\n\x0btensor_info\x18\x01 \x01(\x0b\x32\x16.tensorflow.TensorInfo\x12\x10\n\x08\x66ilename\x18\x02 \x01(\tBz\n\x18org.tensorflow.frameworkB\x0fMetaGraphProtosP\x01ZHgithub.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto\xf8\x01\x01\x62\x06proto3' + , + dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,graph__pb2.DESCRIPTOR,op__def__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,saved__object__graph__pb2.DESCRIPTOR,saver__pb2.DESCRIPTOR,struct__pb2.DESCRIPTOR,]) + + + + +_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY = _descriptor.Descriptor( + name='FunctionAliasesEntry', + full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry.value', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=895, + serialized_end=949, +) + +_METAGRAPHDEF_METAINFODEF = _descriptor.Descriptor( + name='MetaInfoDef', + full_name='tensorflow.MetaGraphDef.MetaInfoDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='meta_graph_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.meta_graph_version', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='stripped_op_list', full_name='tensorflow.MetaGraphDef.MetaInfoDef.stripped_op_list', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='any_info', full_name='tensorflow.MetaGraphDef.MetaInfoDef.any_info', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tags', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tags', index=3, + number=4, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensorflow_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tensorflow_version', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensorflow_git_version', full_name='tensorflow.MetaGraphDef.MetaInfoDef.tensorflow_git_version', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='stripped_default_attrs', full_name='tensorflow.MetaGraphDef.MetaInfoDef.stripped_default_attrs', index=6, + number=7, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='function_aliases', full_name='tensorflow.MetaGraphDef.MetaInfoDef.function_aliases', index=7, + number=8, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=575, + serialized_end=949, +) + +_METAGRAPHDEF_COLLECTIONDEFENTRY = _descriptor.Descriptor( + name='CollectionDefEntry', + full_name='tensorflow.MetaGraphDef.CollectionDefEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.MetaGraphDef.CollectionDefEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.MetaGraphDef.CollectionDefEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=951, + serialized_end=1030, +) + +_METAGRAPHDEF_SIGNATUREDEFENTRY = _descriptor.Descriptor( + name='SignatureDefEntry', + full_name='tensorflow.MetaGraphDef.SignatureDefEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.MetaGraphDef.SignatureDefEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.MetaGraphDef.SignatureDefEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1032, + serialized_end=1109, +) + +_METAGRAPHDEF = _descriptor.Descriptor( + name='MetaGraphDef', + full_name='tensorflow.MetaGraphDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='meta_info_def', full_name='tensorflow.MetaGraphDef.meta_info_def', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='graph_def', full_name='tensorflow.MetaGraphDef.graph_def', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='saver_def', full_name='tensorflow.MetaGraphDef.saver_def', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='collection_def', full_name='tensorflow.MetaGraphDef.collection_def', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='signature_def', full_name='tensorflow.MetaGraphDef.signature_def', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='asset_file_def', full_name='tensorflow.MetaGraphDef.asset_file_def', index=5, + number=6, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='object_graph_def', full_name='tensorflow.MetaGraphDef.object_graph_def', index=6, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_METAGRAPHDEF_METAINFODEF, _METAGRAPHDEF_COLLECTIONDEFENTRY, _METAGRAPHDEF_SIGNATUREDEFENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=173, + serialized_end=1109, +) + + +_COLLECTIONDEF_NODELIST = _descriptor.Descriptor( + name='NodeList', + full_name='tensorflow.CollectionDef.NodeList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.CollectionDef.NodeList.value', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1418, + serialized_end=1443, +) + +_COLLECTIONDEF_BYTESLIST = _descriptor.Descriptor( + name='BytesList', + full_name='tensorflow.CollectionDef.BytesList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.CollectionDef.BytesList.value', index=0, + number=1, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1445, + serialized_end=1471, +) + +_COLLECTIONDEF_INT64LIST = _descriptor.Descriptor( + name='Int64List', + full_name='tensorflow.CollectionDef.Int64List', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.CollectionDef.Int64List.value', index=0, + number=1, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1473, + serialized_end=1503, +) + +_COLLECTIONDEF_FLOATLIST = _descriptor.Descriptor( + name='FloatList', + full_name='tensorflow.CollectionDef.FloatList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.CollectionDef.FloatList.value', index=0, + number=1, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\020\001', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1505, + serialized_end=1535, +) + +_COLLECTIONDEF_ANYLIST = _descriptor.Descriptor( + name='AnyList', + full_name='tensorflow.CollectionDef.AnyList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.CollectionDef.AnyList.value', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1537, + serialized_end=1583, +) + +_COLLECTIONDEF = _descriptor.Descriptor( + name='CollectionDef', + full_name='tensorflow.CollectionDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='node_list', full_name='tensorflow.CollectionDef.node_list', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='bytes_list', full_name='tensorflow.CollectionDef.bytes_list', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='int64_list', full_name='tensorflow.CollectionDef.int64_list', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='float_list', full_name='tensorflow.CollectionDef.float_list', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='any_list', full_name='tensorflow.CollectionDef.any_list', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_COLLECTIONDEF_NODELIST, _COLLECTIONDEF_BYTESLIST, _COLLECTIONDEF_INT64LIST, _COLLECTIONDEF_FLOATLIST, _COLLECTIONDEF_ANYLIST, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='kind', full_name='tensorflow.CollectionDef.kind', + index=0, containing_type=None, fields=[]), + ], + serialized_start=1112, + serialized_end=1591, +) + + +_TENSORINFO_COOSPARSE = _descriptor.Descriptor( + name='CooSparse', + full_name='tensorflow.TensorInfo.CooSparse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='values_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.values_tensor_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='indices_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.indices_tensor_name', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='dense_shape_tensor_name', full_name='tensorflow.TensorInfo.CooSparse.dense_shape_tensor_name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1837, + serialized_end=1938, +) + +_TENSORINFO_COMPOSITETENSOR = _descriptor.Descriptor( + name='CompositeTensor', + full_name='tensorflow.TensorInfo.CompositeTensor', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type_spec', full_name='tensorflow.TensorInfo.CompositeTensor.type_spec', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='components', full_name='tensorflow.TensorInfo.CompositeTensor.components', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1940, + serialized_end=2047, +) + +_TENSORINFO = _descriptor.Descriptor( + name='TensorInfo', + full_name='tensorflow.TensorInfo', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.TensorInfo.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='coo_sparse', full_name='tensorflow.TensorInfo.coo_sparse', index=1, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='composite_tensor', full_name='tensorflow.TensorInfo.composite_tensor', index=2, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='dtype', full_name='tensorflow.TensorInfo.dtype', index=3, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor_shape', full_name='tensorflow.TensorInfo.tensor_shape', index=4, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_TENSORINFO_COOSPARSE, _TENSORINFO_COMPOSITETENSOR, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='encoding', full_name='tensorflow.TensorInfo.encoding', + index=0, containing_type=None, fields=[]), + ], + serialized_start=1594, + serialized_end=2059, +) + + +_SIGNATUREDEF_INPUTSENTRY = _descriptor.Descriptor( + name='InputsEntry', + full_name='tensorflow.SignatureDef.InputsEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.SignatureDef.InputsEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.SignatureDef.InputsEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2209, + serialized_end=2278, +) + +_SIGNATUREDEF_OUTPUTSENTRY = _descriptor.Descriptor( + name='OutputsEntry', + full_name='tensorflow.SignatureDef.OutputsEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.SignatureDef.OutputsEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.SignatureDef.OutputsEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2280, + serialized_end=2350, +) + +_SIGNATUREDEF = _descriptor.Descriptor( + name='SignatureDef', + full_name='tensorflow.SignatureDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='inputs', full_name='tensorflow.SignatureDef.inputs', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='outputs', full_name='tensorflow.SignatureDef.outputs', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='method_name', full_name='tensorflow.SignatureDef.method_name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_SIGNATUREDEF_INPUTSENTRY, _SIGNATUREDEF_OUTPUTSENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2062, + serialized_end=2350, +) + + +_ASSETFILEDEF = _descriptor.Descriptor( + name='AssetFileDef', + full_name='tensorflow.AssetFileDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='tensor_info', full_name='tensorflow.AssetFileDef.tensor_info', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='filename', full_name='tensorflow.AssetFileDef.filename', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2352, + serialized_end=2429, +) + +_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY.containing_type = _METAGRAPHDEF_METAINFODEF +_METAGRAPHDEF_METAINFODEF.fields_by_name['stripped_op_list'].message_type = op__def__pb2._OPLIST +_METAGRAPHDEF_METAINFODEF.fields_by_name['any_info'].message_type = google_dot_protobuf_dot_any__pb2._ANY +_METAGRAPHDEF_METAINFODEF.fields_by_name['function_aliases'].message_type = _METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY +_METAGRAPHDEF_METAINFODEF.containing_type = _METAGRAPHDEF +_METAGRAPHDEF_COLLECTIONDEFENTRY.fields_by_name['value'].message_type = _COLLECTIONDEF +_METAGRAPHDEF_COLLECTIONDEFENTRY.containing_type = _METAGRAPHDEF +_METAGRAPHDEF_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = _SIGNATUREDEF +_METAGRAPHDEF_SIGNATUREDEFENTRY.containing_type = _METAGRAPHDEF +_METAGRAPHDEF.fields_by_name['meta_info_def'].message_type = _METAGRAPHDEF_METAINFODEF +_METAGRAPHDEF.fields_by_name['graph_def'].message_type = graph__pb2._GRAPHDEF +_METAGRAPHDEF.fields_by_name['saver_def'].message_type = saver__pb2._SAVERDEF +_METAGRAPHDEF.fields_by_name['collection_def'].message_type = _METAGRAPHDEF_COLLECTIONDEFENTRY +_METAGRAPHDEF.fields_by_name['signature_def'].message_type = _METAGRAPHDEF_SIGNATUREDEFENTRY +_METAGRAPHDEF.fields_by_name['asset_file_def'].message_type = _ASSETFILEDEF +_METAGRAPHDEF.fields_by_name['object_graph_def'].message_type = saved__object__graph__pb2._SAVEDOBJECTGRAPH +_COLLECTIONDEF_NODELIST.containing_type = _COLLECTIONDEF +_COLLECTIONDEF_BYTESLIST.containing_type = _COLLECTIONDEF +_COLLECTIONDEF_INT64LIST.containing_type = _COLLECTIONDEF +_COLLECTIONDEF_FLOATLIST.containing_type = _COLLECTIONDEF +_COLLECTIONDEF_ANYLIST.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY +_COLLECTIONDEF_ANYLIST.containing_type = _COLLECTIONDEF +_COLLECTIONDEF.fields_by_name['node_list'].message_type = _COLLECTIONDEF_NODELIST +_COLLECTIONDEF.fields_by_name['bytes_list'].message_type = _COLLECTIONDEF_BYTESLIST +_COLLECTIONDEF.fields_by_name['int64_list'].message_type = _COLLECTIONDEF_INT64LIST +_COLLECTIONDEF.fields_by_name['float_list'].message_type = _COLLECTIONDEF_FLOATLIST +_COLLECTIONDEF.fields_by_name['any_list'].message_type = _COLLECTIONDEF_ANYLIST +_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( + _COLLECTIONDEF.fields_by_name['node_list']) +_COLLECTIONDEF.fields_by_name['node_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] +_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( + _COLLECTIONDEF.fields_by_name['bytes_list']) +_COLLECTIONDEF.fields_by_name['bytes_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] +_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( + _COLLECTIONDEF.fields_by_name['int64_list']) +_COLLECTIONDEF.fields_by_name['int64_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] +_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( + _COLLECTIONDEF.fields_by_name['float_list']) +_COLLECTIONDEF.fields_by_name['float_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] +_COLLECTIONDEF.oneofs_by_name['kind'].fields.append( + _COLLECTIONDEF.fields_by_name['any_list']) +_COLLECTIONDEF.fields_by_name['any_list'].containing_oneof = _COLLECTIONDEF.oneofs_by_name['kind'] +_TENSORINFO_COOSPARSE.containing_type = _TENSORINFO +_TENSORINFO_COMPOSITETENSOR.fields_by_name['type_spec'].message_type = struct__pb2._TYPESPECPROTO +_TENSORINFO_COMPOSITETENSOR.fields_by_name['components'].message_type = _TENSORINFO +_TENSORINFO_COMPOSITETENSOR.containing_type = _TENSORINFO +_TENSORINFO.fields_by_name['coo_sparse'].message_type = _TENSORINFO_COOSPARSE +_TENSORINFO.fields_by_name['composite_tensor'].message_type = _TENSORINFO_COMPOSITETENSOR +_TENSORINFO.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE +_TENSORINFO.fields_by_name['tensor_shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO +_TENSORINFO.oneofs_by_name['encoding'].fields.append( + _TENSORINFO.fields_by_name['name']) +_TENSORINFO.fields_by_name['name'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] +_TENSORINFO.oneofs_by_name['encoding'].fields.append( + _TENSORINFO.fields_by_name['coo_sparse']) +_TENSORINFO.fields_by_name['coo_sparse'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] +_TENSORINFO.oneofs_by_name['encoding'].fields.append( + _TENSORINFO.fields_by_name['composite_tensor']) +_TENSORINFO.fields_by_name['composite_tensor'].containing_oneof = _TENSORINFO.oneofs_by_name['encoding'] +_SIGNATUREDEF_INPUTSENTRY.fields_by_name['value'].message_type = _TENSORINFO +_SIGNATUREDEF_INPUTSENTRY.containing_type = _SIGNATUREDEF +_SIGNATUREDEF_OUTPUTSENTRY.fields_by_name['value'].message_type = _TENSORINFO +_SIGNATUREDEF_OUTPUTSENTRY.containing_type = _SIGNATUREDEF +_SIGNATUREDEF.fields_by_name['inputs'].message_type = _SIGNATUREDEF_INPUTSENTRY +_SIGNATUREDEF.fields_by_name['outputs'].message_type = _SIGNATUREDEF_OUTPUTSENTRY +_ASSETFILEDEF.fields_by_name['tensor_info'].message_type = _TENSORINFO +DESCRIPTOR.message_types_by_name['MetaGraphDef'] = _METAGRAPHDEF +DESCRIPTOR.message_types_by_name['CollectionDef'] = _COLLECTIONDEF +DESCRIPTOR.message_types_by_name['TensorInfo'] = _TENSORINFO +DESCRIPTOR.message_types_by_name['SignatureDef'] = _SIGNATUREDEF +DESCRIPTOR.message_types_by_name['AssetFileDef'] = _ASSETFILEDEF +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +MetaGraphDef = _reflection.GeneratedProtocolMessageType('MetaGraphDef', (_message.Message,), { + + 'MetaInfoDef' : _reflection.GeneratedProtocolMessageType('MetaInfoDef', (_message.Message,), { + + 'FunctionAliasesEntry' : _reflection.GeneratedProtocolMessageType('FunctionAliasesEntry', (_message.Message,), { + 'DESCRIPTOR' : _METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.MetaInfoDef.FunctionAliasesEntry) + }) + , + 'DESCRIPTOR' : _METAGRAPHDEF_METAINFODEF, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.MetaInfoDef) + }) + , + + 'CollectionDefEntry' : _reflection.GeneratedProtocolMessageType('CollectionDefEntry', (_message.Message,), { + 'DESCRIPTOR' : _METAGRAPHDEF_COLLECTIONDEFENTRY, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.CollectionDefEntry) + }) + , + + 'SignatureDefEntry' : _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), { + 'DESCRIPTOR' : _METAGRAPHDEF_SIGNATUREDEFENTRY, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef.SignatureDefEntry) + }) + , + 'DESCRIPTOR' : _METAGRAPHDEF, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.MetaGraphDef) + }) +_sym_db.RegisterMessage(MetaGraphDef) +_sym_db.RegisterMessage(MetaGraphDef.MetaInfoDef) +_sym_db.RegisterMessage(MetaGraphDef.MetaInfoDef.FunctionAliasesEntry) +_sym_db.RegisterMessage(MetaGraphDef.CollectionDefEntry) +_sym_db.RegisterMessage(MetaGraphDef.SignatureDefEntry) + +CollectionDef = _reflection.GeneratedProtocolMessageType('CollectionDef', (_message.Message,), { + + 'NodeList' : _reflection.GeneratedProtocolMessageType('NodeList', (_message.Message,), { + 'DESCRIPTOR' : _COLLECTIONDEF_NODELIST, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.NodeList) + }) + , + + 'BytesList' : _reflection.GeneratedProtocolMessageType('BytesList', (_message.Message,), { + 'DESCRIPTOR' : _COLLECTIONDEF_BYTESLIST, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.BytesList) + }) + , + + 'Int64List' : _reflection.GeneratedProtocolMessageType('Int64List', (_message.Message,), { + 'DESCRIPTOR' : _COLLECTIONDEF_INT64LIST, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.Int64List) + }) + , + + 'FloatList' : _reflection.GeneratedProtocolMessageType('FloatList', (_message.Message,), { + 'DESCRIPTOR' : _COLLECTIONDEF_FLOATLIST, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.FloatList) + }) + , + + 'AnyList' : _reflection.GeneratedProtocolMessageType('AnyList', (_message.Message,), { + 'DESCRIPTOR' : _COLLECTIONDEF_ANYLIST, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef.AnyList) + }) + , + 'DESCRIPTOR' : _COLLECTIONDEF, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.CollectionDef) + }) +_sym_db.RegisterMessage(CollectionDef) +_sym_db.RegisterMessage(CollectionDef.NodeList) +_sym_db.RegisterMessage(CollectionDef.BytesList) +_sym_db.RegisterMessage(CollectionDef.Int64List) +_sym_db.RegisterMessage(CollectionDef.FloatList) +_sym_db.RegisterMessage(CollectionDef.AnyList) + +TensorInfo = _reflection.GeneratedProtocolMessageType('TensorInfo', (_message.Message,), { + + 'CooSparse' : _reflection.GeneratedProtocolMessageType('CooSparse', (_message.Message,), { + 'DESCRIPTOR' : _TENSORINFO_COOSPARSE, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo.CooSparse) + }) + , + + 'CompositeTensor' : _reflection.GeneratedProtocolMessageType('CompositeTensor', (_message.Message,), { + 'DESCRIPTOR' : _TENSORINFO_COMPOSITETENSOR, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo.CompositeTensor) + }) + , + 'DESCRIPTOR' : _TENSORINFO, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.TensorInfo) + }) +_sym_db.RegisterMessage(TensorInfo) +_sym_db.RegisterMessage(TensorInfo.CooSparse) +_sym_db.RegisterMessage(TensorInfo.CompositeTensor) + +SignatureDef = _reflection.GeneratedProtocolMessageType('SignatureDef', (_message.Message,), { + + 'InputsEntry' : _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), { + 'DESCRIPTOR' : _SIGNATUREDEF_INPUTSENTRY, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef.InputsEntry) + }) + , + + 'OutputsEntry' : _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), { + 'DESCRIPTOR' : _SIGNATUREDEF_OUTPUTSENTRY, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef.OutputsEntry) + }) + , + 'DESCRIPTOR' : _SIGNATUREDEF, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SignatureDef) + }) +_sym_db.RegisterMessage(SignatureDef) +_sym_db.RegisterMessage(SignatureDef.InputsEntry) +_sym_db.RegisterMessage(SignatureDef.OutputsEntry) + +AssetFileDef = _reflection.GeneratedProtocolMessageType('AssetFileDef', (_message.Message,), { + 'DESCRIPTOR' : _ASSETFILEDEF, + '__module__' : 'meta_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.AssetFileDef) + }) +_sym_db.RegisterMessage(AssetFileDef) + + +DESCRIPTOR._options = None +_METAGRAPHDEF_METAINFODEF_FUNCTIONALIASESENTRY._options = None +_METAGRAPHDEF_COLLECTIONDEFENTRY._options = None +_METAGRAPHDEF_SIGNATUREDEFENTRY._options = None +_COLLECTIONDEF_INT64LIST.fields_by_name['value']._options = None +_COLLECTIONDEF_FLOATLIST.fields_by_name['value']._options = None +_SIGNATUREDEF_INPUTSENTRY._options = None +_SIGNATUREDEF_OUTPUTSENTRY._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/meta_graph_pb2_grpc.py b/redis_consumer/pbs/meta_graph_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/meta_graph_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/model_pb2.py b/redis_consumer/pbs/model_pb2.py index 13ac6337..8bddc4ca 100644 --- a/redis_consumer/pbs/model_pb2.py +++ b/redis_consumer/pbs/model_pb2.py @@ -1,8 +1,7 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: model.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -19,8 +18,8 @@ name='model.proto', package='tensorflow.serving', syntax='proto3', - serialized_options=_b('\370\001\001'), - serialized_pb=_b('\n\x0bmodel.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto\"G\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12,\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64ValueB\x03\xf8\x01\x01\x62\x06proto3') + serialized_options=b'\370\001\001', + serialized_pb=b'\n\x0bmodel.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto\"\x8c\x01\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64ValueH\x00\x12\x17\n\rversion_label\x18\x04 \x01(\tH\x00\x12\x16\n\x0esignature_name\x18\x03 \x01(\tB\x10\n\x0eversion_choiceB\x03\xf8\x01\x01\x62\x06proto3' , dependencies=[google_dot_protobuf_dot_wrappers__pb2.DESCRIPTOR,]) @@ -37,7 +36,7 @@ _descriptor.FieldDescriptor( name='name', full_name='tensorflow.serving.ModelSpec.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -48,6 +47,20 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='version_label', full_name='tensorflow.serving.ModelSpec.version_label', index=2, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='signature_name', full_name='tensorflow.serving.ModelSpec.signature_name', index=3, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -59,20 +72,29 @@ syntax='proto3', extension_ranges=[], oneofs=[ + _descriptor.OneofDescriptor( + name='version_choice', full_name='tensorflow.serving.ModelSpec.version_choice', + index=0, containing_type=None, fields=[]), ], - serialized_start=67, - serialized_end=138, + serialized_start=68, + serialized_end=208, ) _MODELSPEC.fields_by_name['version'].message_type = google_dot_protobuf_dot_wrappers__pb2._INT64VALUE +_MODELSPEC.oneofs_by_name['version_choice'].fields.append( + _MODELSPEC.fields_by_name['version']) +_MODELSPEC.fields_by_name['version'].containing_oneof = _MODELSPEC.oneofs_by_name['version_choice'] +_MODELSPEC.oneofs_by_name['version_choice'].fields.append( + _MODELSPEC.fields_by_name['version_label']) +_MODELSPEC.fields_by_name['version_label'].containing_oneof = _MODELSPEC.oneofs_by_name['version_choice'] DESCRIPTOR.message_types_by_name['ModelSpec'] = _MODELSPEC _sym_db.RegisterFileDescriptor(DESCRIPTOR) -ModelSpec = _reflection.GeneratedProtocolMessageType('ModelSpec', (_message.Message,), dict( - DESCRIPTOR = _MODELSPEC, - __module__ = 'model_pb2' +ModelSpec = _reflection.GeneratedProtocolMessageType('ModelSpec', (_message.Message,), { + 'DESCRIPTOR' : _MODELSPEC, + '__module__' : 'model_pb2' # @@protoc_insertion_point(class_scope:tensorflow.serving.ModelSpec) - )) + }) _sym_db.RegisterMessage(ModelSpec) diff --git a/redis_consumer/pbs/node_def_pb2.py b/redis_consumer/pbs/node_def_pb2.py new file mode 100644 index 00000000..31c0ba7e --- /dev/null +++ b/redis_consumer/pbs/node_def_pb2.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: node_def.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import attr_value_pb2 as attr__value__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='node_def.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\n\030org.tensorflow.frameworkB\tNodeProtoP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', + serialized_pb=b'\n\x0enode_def.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\"\xd2\x02\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12+\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1d.tensorflow.NodeDef.AttrEntry\x12J\n\x17\x65xperimental_debug_info\x18\x06 \x01(\x0b\x32).tensorflow.NodeDef.ExperimentalDebugInfo\x1a\x42\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tensorflow.AttrValue:\x02\x38\x01\x1aQ\n\x15\x45xperimentalDebugInfo\x12\x1b\n\x13original_node_names\x18\x01 \x03(\t\x12\x1b\n\x13original_func_names\x18\x02 \x03(\tBi\n\x18org.tensorflow.frameworkB\tNodeProtoP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' + , + dependencies=[attr__value__pb2.DESCRIPTOR,]) + + + + +_NODEDEF_ATTRENTRY = _descriptor.Descriptor( + name='AttrEntry', + full_name='tensorflow.NodeDef.AttrEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.NodeDef.AttrEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.NodeDef.AttrEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=238, + serialized_end=304, +) + +_NODEDEF_EXPERIMENTALDEBUGINFO = _descriptor.Descriptor( + name='ExperimentalDebugInfo', + full_name='tensorflow.NodeDef.ExperimentalDebugInfo', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='original_node_names', full_name='tensorflow.NodeDef.ExperimentalDebugInfo.original_node_names', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='original_func_names', full_name='tensorflow.NodeDef.ExperimentalDebugInfo.original_func_names', index=1, + number=2, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=306, + serialized_end=387, +) + +_NODEDEF = _descriptor.Descriptor( + name='NodeDef', + full_name='tensorflow.NodeDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.NodeDef.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='op', full_name='tensorflow.NodeDef.op', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='input', full_name='tensorflow.NodeDef.input', index=2, + number=3, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='device', full_name='tensorflow.NodeDef.device', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='attr', full_name='tensorflow.NodeDef.attr', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='experimental_debug_info', full_name='tensorflow.NodeDef.experimental_debug_info', index=5, + number=6, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_NODEDEF_ATTRENTRY, _NODEDEF_EXPERIMENTALDEBUGINFO, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=49, + serialized_end=387, +) + +_NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = attr__value__pb2._ATTRVALUE +_NODEDEF_ATTRENTRY.containing_type = _NODEDEF +_NODEDEF_EXPERIMENTALDEBUGINFO.containing_type = _NODEDEF +_NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY +_NODEDEF.fields_by_name['experimental_debug_info'].message_type = _NODEDEF_EXPERIMENTALDEBUGINFO +DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), { + + 'AttrEntry' : _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), { + 'DESCRIPTOR' : _NODEDEF_ATTRENTRY, + '__module__' : 'node_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.AttrEntry) + }) + , + + 'ExperimentalDebugInfo' : _reflection.GeneratedProtocolMessageType('ExperimentalDebugInfo', (_message.Message,), { + 'DESCRIPTOR' : _NODEDEF_EXPERIMENTALDEBUGINFO, + '__module__' : 'node_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.NodeDef.ExperimentalDebugInfo) + }) + , + 'DESCRIPTOR' : _NODEDEF, + '__module__' : 'node_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.NodeDef) + }) +_sym_db.RegisterMessage(NodeDef) +_sym_db.RegisterMessage(NodeDef.AttrEntry) +_sym_db.RegisterMessage(NodeDef.ExperimentalDebugInfo) + + +DESCRIPTOR._options = None +_NODEDEF_ATTRENTRY._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/node_def_pb2_grpc.py b/redis_consumer/pbs/node_def_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/node_def_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/op_def_pb2.py b/redis_consumer/pbs/op_def_pb2.py new file mode 100644 index 00000000..d4d90bcb --- /dev/null +++ b/redis_consumer/pbs/op_def_pb2.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: op_def.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import attr_value_pb2 as attr__value__pb2 +import types_pb2 as types__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='op_def.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\n\030org.tensorflow.frameworkB\013OpDefProtosP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', + serialized_pb=b'\n\x0cop_def.proto\x12\ntensorflow\x1a\x10\x61ttr_value.proto\x1a\x0btypes.proto\"\xd0\x05\n\x05OpDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\tinput_arg\x18\x02 \x03(\x0b\x32\x18.tensorflow.OpDef.ArgDef\x12,\n\noutput_arg\x18\x03 \x03(\x0b\x32\x18.tensorflow.OpDef.ArgDef\x12\x16\n\x0e\x63ontrol_output\x18\x14 \x03(\t\x12\'\n\x04\x61ttr\x18\x04 \x03(\x0b\x32\x19.tensorflow.OpDef.AttrDef\x12.\n\x0b\x64\x65precation\x18\x08 \x01(\x0b\x32\x19.tensorflow.OpDeprecation\x12\x0f\n\x07summary\x18\x05 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x06 \x01(\t\x12\x16\n\x0eis_commutative\x18\x12 \x01(\x08\x12\x14\n\x0cis_aggregate\x18\x10 \x01(\x08\x12\x13\n\x0bis_stateful\x18\x11 \x01(\x08\x12\"\n\x1a\x61llows_uninitialized_input\x18\x13 \x01(\x08\x1a\x9f\x01\n\x06\x41rgDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\"\n\x04type\x18\x03 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x11\n\ttype_attr\x18\x04 \x01(\t\x12\x13\n\x0bnumber_attr\x18\x05 \x01(\t\x12\x16\n\x0etype_list_attr\x18\x06 \x01(\t\x12\x0e\n\x06is_ref\x18\x10 \x01(\x08\x1a\xbd\x01\n\x07\x41ttrDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12,\n\rdefault_value\x18\x03 \x01(\x0b\x32\x15.tensorflow.AttrValue\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x13\n\x0bhas_minimum\x18\x05 \x01(\x08\x12\x0f\n\x07minimum\x18\x06 \x01(\x03\x12-\n\x0e\x61llowed_values\x18\x07 \x01(\x0b\x32\x15.tensorflow.AttrValue\"5\n\rOpDeprecation\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x13\n\x0b\x65xplanation\x18\x02 \x01(\t\"\'\n\x06OpList\x12\x1d\n\x02op\x18\x01 \x03(\x0b\x32\x11.tensorflow.OpDefBk\n\x18org.tensorflow.frameworkB\x0bOpDefProtosP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' + , + dependencies=[attr__value__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) + + + + +_OPDEF_ARGDEF = _descriptor.Descriptor( + name='ArgDef', + full_name='tensorflow.OpDef.ArgDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.OpDef.ArgDef.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='description', full_name='tensorflow.OpDef.ArgDef.description', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='tensorflow.OpDef.ArgDef.type', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type_attr', full_name='tensorflow.OpDef.ArgDef.type_attr', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='number_attr', full_name='tensorflow.OpDef.ArgDef.number_attr', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type_list_attr', full_name='tensorflow.OpDef.ArgDef.type_list_attr', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_ref', full_name='tensorflow.OpDef.ArgDef.is_ref', index=6, + number=16, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=429, + serialized_end=588, +) + +_OPDEF_ATTRDEF = _descriptor.Descriptor( + name='AttrDef', + full_name='tensorflow.OpDef.AttrDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.OpDef.AttrDef.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='tensorflow.OpDef.AttrDef.type', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='default_value', full_name='tensorflow.OpDef.AttrDef.default_value', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='description', full_name='tensorflow.OpDef.AttrDef.description', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='has_minimum', full_name='tensorflow.OpDef.AttrDef.has_minimum', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='minimum', full_name='tensorflow.OpDef.AttrDef.minimum', index=5, + number=6, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='allowed_values', full_name='tensorflow.OpDef.AttrDef.allowed_values', index=6, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=591, + serialized_end=780, +) + +_OPDEF = _descriptor.Descriptor( + name='OpDef', + full_name='tensorflow.OpDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.OpDef.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='input_arg', full_name='tensorflow.OpDef.input_arg', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='output_arg', full_name='tensorflow.OpDef.output_arg', index=2, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='control_output', full_name='tensorflow.OpDef.control_output', index=3, + number=20, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='attr', full_name='tensorflow.OpDef.attr', index=4, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='deprecation', full_name='tensorflow.OpDef.deprecation', index=5, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='summary', full_name='tensorflow.OpDef.summary', index=6, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='description', full_name='tensorflow.OpDef.description', index=7, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_commutative', full_name='tensorflow.OpDef.is_commutative', index=8, + number=18, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_aggregate', full_name='tensorflow.OpDef.is_aggregate', index=9, + number=16, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_stateful', full_name='tensorflow.OpDef.is_stateful', index=10, + number=17, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='allows_uninitialized_input', full_name='tensorflow.OpDef.allows_uninitialized_input', index=11, + number=19, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_OPDEF_ARGDEF, _OPDEF_ATTRDEF, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=60, + serialized_end=780, +) + + +_OPDEPRECATION = _descriptor.Descriptor( + name='OpDeprecation', + full_name='tensorflow.OpDeprecation', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='version', full_name='tensorflow.OpDeprecation.version', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='explanation', full_name='tensorflow.OpDeprecation.explanation', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=782, + serialized_end=835, +) + + +_OPLIST = _descriptor.Descriptor( + name='OpList', + full_name='tensorflow.OpList', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='op', full_name='tensorflow.OpList.op', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=837, + serialized_end=876, +) + +_OPDEF_ARGDEF.fields_by_name['type'].enum_type = types__pb2._DATATYPE +_OPDEF_ARGDEF.containing_type = _OPDEF +_OPDEF_ATTRDEF.fields_by_name['default_value'].message_type = attr__value__pb2._ATTRVALUE +_OPDEF_ATTRDEF.fields_by_name['allowed_values'].message_type = attr__value__pb2._ATTRVALUE +_OPDEF_ATTRDEF.containing_type = _OPDEF +_OPDEF.fields_by_name['input_arg'].message_type = _OPDEF_ARGDEF +_OPDEF.fields_by_name['output_arg'].message_type = _OPDEF_ARGDEF +_OPDEF.fields_by_name['attr'].message_type = _OPDEF_ATTRDEF +_OPDEF.fields_by_name['deprecation'].message_type = _OPDEPRECATION +_OPLIST.fields_by_name['op'].message_type = _OPDEF +DESCRIPTOR.message_types_by_name['OpDef'] = _OPDEF +DESCRIPTOR.message_types_by_name['OpDeprecation'] = _OPDEPRECATION +DESCRIPTOR.message_types_by_name['OpList'] = _OPLIST +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +OpDef = _reflection.GeneratedProtocolMessageType('OpDef', (_message.Message,), { + + 'ArgDef' : _reflection.GeneratedProtocolMessageType('ArgDef', (_message.Message,), { + 'DESCRIPTOR' : _OPDEF_ARGDEF, + '__module__' : 'op_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.OpDef.ArgDef) + }) + , + + 'AttrDef' : _reflection.GeneratedProtocolMessageType('AttrDef', (_message.Message,), { + 'DESCRIPTOR' : _OPDEF_ATTRDEF, + '__module__' : 'op_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.OpDef.AttrDef) + }) + , + 'DESCRIPTOR' : _OPDEF, + '__module__' : 'op_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.OpDef) + }) +_sym_db.RegisterMessage(OpDef) +_sym_db.RegisterMessage(OpDef.ArgDef) +_sym_db.RegisterMessage(OpDef.AttrDef) + +OpDeprecation = _reflection.GeneratedProtocolMessageType('OpDeprecation', (_message.Message,), { + 'DESCRIPTOR' : _OPDEPRECATION, + '__module__' : 'op_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.OpDeprecation) + }) +_sym_db.RegisterMessage(OpDeprecation) + +OpList = _reflection.GeneratedProtocolMessageType('OpList', (_message.Message,), { + 'DESCRIPTOR' : _OPLIST, + '__module__' : 'op_def_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.OpList) + }) +_sym_db.RegisterMessage(OpList) + + +DESCRIPTOR._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/op_def_pb2_grpc.py b/redis_consumer/pbs/op_def_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/op_def_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/predict_pb2.py b/redis_consumer/pbs/predict_pb2.py index 641b1bc4..9d6f2cef 100644 --- a/redis_consumer/pbs/predict_pb2.py +++ b/redis_consumer/pbs/predict_pb2.py @@ -1,8 +1,7 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: predict.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -12,16 +11,16 @@ _sym_db = _symbol_database.Default() -import redis_consumer.pbs.tensor_pb2 as tensor__pb2 -import redis_consumer.pbs.model_pb2 as model__pb2 +import tensor_pb2 as tensor__pb2 +import model_pb2 as model__pb2 DESCRIPTOR = _descriptor.FileDescriptor( name='predict.proto', package='tensorflow.serving', syntax='proto3', - serialized_options=_b('\370\001\001'), - serialized_pb=_b('\n\rpredict.proto\x12\x12tensorflow.serving\x1a\x0ctensor.proto\x1a\x0bmodel.proto\"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\x9d\x01\n\x0fPredictResponse\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') + serialized_options=b'\370\001\001', + serialized_pb=b'\n\rpredict.proto\x12\x12tensorflow.serving\x1a\x0ctensor.proto\x1a\x0bmodel.proto\"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\xd0\x01\n\x0fPredictResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' , dependencies=[tensor__pb2.DESCRIPTOR,model__pb2.DESCRIPTOR,]) @@ -38,7 +37,7 @@ _descriptor.FieldDescriptor( name='key', full_name='tensorflow.serving.PredictRequest.InputsEntry.key', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -55,7 +54,7 @@ nested_types=[], enum_types=[ ], - serialized_options=_b('8\001'), + serialized_options=b'8\001', is_extendable=False, syntax='proto3', extension_ranges=[], @@ -120,7 +119,7 @@ _descriptor.FieldDescriptor( name='key', full_name='tensorflow.serving.PredictResponse.OutputsEntry.key', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -137,14 +136,14 @@ nested_types=[], enum_types=[ ], - serialized_options=_b('8\001'), + serialized_options=b'8\001', is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], - serialized_start=380, - serialized_end=451, + serialized_start=431, + serialized_end=502, ) _PREDICTRESPONSE = _descriptor.Descriptor( @@ -155,7 +154,14 @@ containing_type=None, fields=[ _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.serving.PredictResponse.outputs', index=0, + name='model_spec', full_name='tensorflow.serving.PredictResponse.model_spec', index=0, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='outputs', full_name='tensorflow.serving.PredictResponse.outputs', index=1, number=1, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, @@ -174,7 +180,7 @@ oneofs=[ ], serialized_start=294, - serialized_end=451, + serialized_end=502, ) _PREDICTREQUEST_INPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO @@ -183,38 +189,39 @@ _PREDICTREQUEST.fields_by_name['inputs'].message_type = _PREDICTREQUEST_INPUTSENTRY _PREDICTRESPONSE_OUTPUTSENTRY.fields_by_name['value'].message_type = tensor__pb2._TENSORPROTO _PREDICTRESPONSE_OUTPUTSENTRY.containing_type = _PREDICTRESPONSE +_PREDICTRESPONSE.fields_by_name['model_spec'].message_type = model__pb2._MODELSPEC _PREDICTRESPONSE.fields_by_name['outputs'].message_type = _PREDICTRESPONSE_OUTPUTSENTRY DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST DESCRIPTOR.message_types_by_name['PredictResponse'] = _PREDICTRESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) -PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), dict( +PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), { - InputsEntry = _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), dict( - DESCRIPTOR = _PREDICTREQUEST_INPUTSENTRY, - __module__ = 'predict_pb2' + 'InputsEntry' : _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTREQUEST_INPUTSENTRY, + '__module__' : 'predict_pb2' # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest.InputsEntry) - )) + }) , - DESCRIPTOR = _PREDICTREQUEST, - __module__ = 'predict_pb2' + 'DESCRIPTOR' : _PREDICTREQUEST, + '__module__' : 'predict_pb2' # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest) - )) + }) _sym_db.RegisterMessage(PredictRequest) _sym_db.RegisterMessage(PredictRequest.InputsEntry) -PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), dict( +PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), { - OutputsEntry = _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), dict( - DESCRIPTOR = _PREDICTRESPONSE_OUTPUTSENTRY, - __module__ = 'predict_pb2' + 'OutputsEntry' : _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTRESPONSE_OUTPUTSENTRY, + '__module__' : 'predict_pb2' # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse.OutputsEntry) - )) + }) , - DESCRIPTOR = _PREDICTRESPONSE, - __module__ = 'predict_pb2' + 'DESCRIPTOR' : _PREDICTRESPONSE, + '__module__' : 'predict_pb2' # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse) - )) + }) _sym_db.RegisterMessage(PredictResponse) _sym_db.RegisterMessage(PredictResponse.OutputsEntry) diff --git a/redis_consumer/pbs/prediction_service_pb2.py b/redis_consumer/pbs/prediction_service_pb2.py index a5569d88..7be5fa12 100644 --- a/redis_consumer/pbs/prediction_service_pb2.py +++ b/redis_consumer/pbs/prediction_service_pb2.py @@ -1,8 +1,7 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: prediction_service.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -12,17 +11,18 @@ _sym_db = _symbol_database.Default() -import redis_consumer.pbs.predict_pb2 as predict__pb2 +import get_model_metadata_pb2 as get__model__metadata__pb2 +import predict_pb2 as predict__pb2 DESCRIPTOR = _descriptor.FileDescriptor( name='prediction_service.proto', package='tensorflow.serving', syntax='proto3', - serialized_options=_b('\370\001\001'), - serialized_pb=_b('\n\x18prediction_service.proto\x12\x12tensorflow.serving\x1a\rpredict.proto2g\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponseB\x03\xf8\x01\x01\x62\x06proto3') + serialized_options=b'\370\001\001', + serialized_pb=b'\n\x18prediction_service.proto\x12\x12tensorflow.serving\x1a\x18get_model_metadata.proto\x1a\rpredict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3' , - dependencies=[predict__pb2.DESCRIPTOR,]) + dependencies=[get__model__metadata__pb2.DESCRIPTOR,predict__pb2.DESCRIPTOR,]) @@ -37,8 +37,8 @@ file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=63, - serialized_end=166, + serialized_start=90, + serialized_end=304, methods=[ _descriptor.MethodDescriptor( name='Predict', @@ -49,6 +49,15 @@ output_type=predict__pb2._PREDICTRESPONSE, serialized_options=None, ), + _descriptor.MethodDescriptor( + name='GetModelMetadata', + full_name='tensorflow.serving.PredictionService.GetModelMetadata', + index=1, + containing_service=None, + input_type=get__model__metadata__pb2._GETMODELMETADATAREQUEST, + output_type=get__model__metadata__pb2._GETMODELMETADATARESPONSE, + serialized_options=None, + ), ]) _sym_db.RegisterServiceDescriptor(_PREDICTIONSERVICE) diff --git a/redis_consumer/pbs/prediction_service_pb2_grpc.py b/redis_consumer/pbs/prediction_service_pb2_grpc.py index 292d71ac..e0235f0a 100644 --- a/redis_consumer/pbs/prediction_service_pb2_grpc.py +++ b/redis_consumer/pbs/prediction_service_pb2_grpc.py @@ -1,12 +1,16 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! import grpc -import redis_consumer.pbs.predict_pb2 as predict__pb2 +import get_model_metadata_pb2 as get__model__metadata__pb2 +import predict_pb2 as predict__pb2 class PredictionServiceStub(object): - """PredictionService provides access to machine-learned models loaded by + """open source marker; do not remove + PredictionService provides access to machine-learned models loaded by model_servers. + Classify. + rpc Classify(ClassificationRequest) returns (ClassificationResponse); """ def __init__(self, channel): @@ -20,15 +24,36 @@ def __init__(self, channel): request_serializer=predict__pb2.PredictRequest.SerializeToString, response_deserializer=predict__pb2.PredictResponse.FromString, ) + self.GetModelMetadata = channel.unary_unary( + '/tensorflow.serving.PredictionService/GetModelMetadata', + request_serializer=get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, + response_deserializer=get__model__metadata__pb2.GetModelMetadataResponse.FromString, + ) class PredictionServiceServicer(object): - """PredictionService provides access to machine-learned models loaded by + """open source marker; do not remove + PredictionService provides access to machine-learned models loaded by model_servers. + Classify. + rpc Classify(ClassificationRequest) returns (ClassificationResponse); """ def Predict(self, request, context): - """Predict -- provides access to loaded TensorFlow model. + """Regress. + rpc Regress(RegressionRequest) returns (RegressionResponse); + + Predict -- provides access to loaded TensorFlow model. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetModelMetadata(self, request, context): + """MultiInference API for multi-headed models. + rpc MultiInference(MultiInferenceRequest) returns (MultiInferenceResponse); + + GetModelMetadata - provides access to metadata for loaded models. """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -42,6 +67,11 @@ def add_PredictionServiceServicer_to_server(servicer, server): request_deserializer=predict__pb2.PredictRequest.FromString, response_serializer=predict__pb2.PredictResponse.SerializeToString, ), + 'GetModelMetadata': grpc.unary_unary_rpc_method_handler( + servicer.GetModelMetadata, + request_deserializer=get__model__metadata__pb2.GetModelMetadataRequest.FromString, + response_serializer=get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'tensorflow.serving.PredictionService', rpc_method_handlers) diff --git a/redis_consumer/pbs/resource_handle_pb2.py b/redis_consumer/pbs/resource_handle_pb2.py index a210be0b..4e5fcccc 100644 --- a/redis_consumer/pbs/resource_handle_pb2.py +++ b/redis_consumer/pbs/resource_handle_pb2.py @@ -1,8 +1,7 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: resource_handle.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -12,65 +11,112 @@ _sym_db = _symbol_database.Default() +import tensor_shape_pb2 as tensor__shape__pb2 +import types_pb2 as types__pb2 DESCRIPTOR = _descriptor.FileDescriptor( name='resource_handle.proto', package='tensorflow', syntax='proto3', - serialized_options=_b('\n\030org.tensorflow.frameworkB\023ResourceHandleProtoP\001\370\001\001'), - serialized_pb=_b('\n\x15resource_handle.proto\x12\ntensorflow\"m\n\x0eResourceHandle\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\tB4\n\x18org.tensorflow.frameworkB\x13ResourceHandleProtoP\x01\xf8\x01\x01\x62\x06proto3') -) + serialized_options=b'\n\030org.tensorflow.frameworkB\016ResourceHandleP\001Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\370\001\001', + serialized_pb=b'\n\x15resource_handle.proto\x12\ntensorflow\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\"\x9f\x02\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\t\x12H\n\x11\x64types_and_shapes\x18\x06 \x03(\x0b\x32-.tensorflow.ResourceHandleProto.DtypeAndShape\x1a\x61\n\rDtypeAndShape\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProtoBn\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01Z=github.com/tensorflow/tensorflow/tensorflow/go/core/framework\xf8\x01\x01\x62\x06proto3' + , + dependencies=[tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,]) + +_RESOURCEHANDLEPROTO_DTYPEANDSHAPE = _descriptor.Descriptor( + name='DtypeAndShape', + full_name='tensorflow.ResourceHandleProto.DtypeAndShape', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='dtype', full_name='tensorflow.ResourceHandleProto.DtypeAndShape.dtype', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='shape', full_name='tensorflow.ResourceHandleProto.DtypeAndShape.shape', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=261, + serialized_end=358, +) -_RESOURCEHANDLE = _descriptor.Descriptor( - name='ResourceHandle', - full_name='tensorflow.ResourceHandle', +_RESOURCEHANDLEPROTO = _descriptor.Descriptor( + name='ResourceHandleProto', + full_name='tensorflow.ResourceHandleProto', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( - name='device', full_name='tensorflow.ResourceHandle.device', index=0, + name='device', full_name='tensorflow.ResourceHandleProto.device', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='container', full_name='tensorflow.ResourceHandle.container', index=1, + name='container', full_name='tensorflow.ResourceHandleProto.container', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.ResourceHandle.name', index=2, + name='name', full_name='tensorflow.ResourceHandleProto.name', index=2, number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='hash_code', full_name='tensorflow.ResourceHandle.hash_code', index=3, + name='hash_code', full_name='tensorflow.ResourceHandleProto.hash_code', index=3, number=4, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( - name='maybe_type_name', full_name='tensorflow.ResourceHandle.maybe_type_name', index=4, + name='maybe_type_name', full_name='tensorflow.ResourceHandleProto.maybe_type_name', index=4, number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='dtypes_and_shapes', full_name='tensorflow.ResourceHandleProto.dtypes_and_shapes', index=5, + number=6, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], - nested_types=[], + nested_types=[_RESOURCEHANDLEPROTO_DTYPEANDSHAPE, ], enum_types=[ ], serialized_options=None, @@ -79,19 +125,31 @@ extension_ranges=[], oneofs=[ ], - serialized_start=37, - serialized_end=146, + serialized_start=71, + serialized_end=358, ) -DESCRIPTOR.message_types_by_name['ResourceHandle'] = _RESOURCEHANDLE +_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE +_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO +_RESOURCEHANDLEPROTO_DTYPEANDSHAPE.containing_type = _RESOURCEHANDLEPROTO +_RESOURCEHANDLEPROTO.fields_by_name['dtypes_and_shapes'].message_type = _RESOURCEHANDLEPROTO_DTYPEANDSHAPE +DESCRIPTOR.message_types_by_name['ResourceHandleProto'] = _RESOURCEHANDLEPROTO _sym_db.RegisterFileDescriptor(DESCRIPTOR) -ResourceHandle = _reflection.GeneratedProtocolMessageType('ResourceHandle', (_message.Message,), dict( - DESCRIPTOR = _RESOURCEHANDLE, - __module__ = 'resource_handle_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandle) - )) -_sym_db.RegisterMessage(ResourceHandle) +ResourceHandleProto = _reflection.GeneratedProtocolMessageType('ResourceHandleProto', (_message.Message,), { + + 'DtypeAndShape' : _reflection.GeneratedProtocolMessageType('DtypeAndShape', (_message.Message,), { + 'DESCRIPTOR' : _RESOURCEHANDLEPROTO_DTYPEANDSHAPE, + '__module__' : 'resource_handle_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandleProto.DtypeAndShape) + }) + , + 'DESCRIPTOR' : _RESOURCEHANDLEPROTO, + '__module__' : 'resource_handle_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.ResourceHandleProto) + }) +_sym_db.RegisterMessage(ResourceHandleProto) +_sym_db.RegisterMessage(ResourceHandleProto.DtypeAndShape) DESCRIPTOR._options = None diff --git a/redis_consumer/pbs/saved_object_graph_pb2.py b/redis_consumer/pbs/saved_object_graph_pb2.py new file mode 100644 index 00000000..4779b05d --- /dev/null +++ b/redis_consumer/pbs/saved_object_graph_pb2.py @@ -0,0 +1,720 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: saved_object_graph.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import trackable_object_graph_pb2 as trackable__object__graph__pb2 +import struct_pb2 as struct__pb2 +import tensor_shape_pb2 as tensor__shape__pb2 +import types_pb2 as types__pb2 +import versions_pb2 as versions__pb2 +import variable_pb2 as variable__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='saved_object_graph.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\370\001\001', + serialized_pb=b'\n\x18saved_object_graph.proto\x12\ntensorflow\x1a\x1ctrackable_object_graph.proto\x1a\x0cstruct.proto\x1a\x12tensor_shape.proto\x1a\x0btypes.proto\x1a\x0eversions.proto\x1a\x0evariable.proto\"\xe8\x01\n\x10SavedObjectGraph\x12&\n\x05nodes\x18\x01 \x03(\x0b\x32\x17.tensorflow.SavedObject\x12O\n\x12\x63oncrete_functions\x18\x02 \x03(\x0b\x32\x33.tensorflow.SavedObjectGraph.ConcreteFunctionsEntry\x1a[\n\x16\x43oncreteFunctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x30\n\x05value\x18\x02 \x01(\x0b\x32!.tensorflow.SavedConcreteFunction:\x02\x38\x01\"\xbd\x04\n\x0bSavedObject\x12R\n\x08\x63hildren\x18\x01 \x03(\x0b\x32@.tensorflow.TrackableObjectGraph.TrackableObject.ObjectReference\x12^\n\x0eslot_variables\x18\x03 \x03(\x0b\x32\x46.tensorflow.TrackableObjectGraph.TrackableObject.SlotVariableReference\x12\x32\n\x0buser_object\x18\x04 \x01(\x0b\x32\x1b.tensorflow.SavedUserObjectH\x00\x12\'\n\x05\x61sset\x18\x05 \x01(\x0b\x32\x16.tensorflow.SavedAssetH\x00\x12-\n\x08\x66unction\x18\x06 \x01(\x0b\x32\x19.tensorflow.SavedFunctionH\x00\x12-\n\x08variable\x18\x07 \x01(\x0b\x32\x19.tensorflow.SavedVariableH\x00\x12G\n\x16\x62\x61re_concrete_function\x18\x08 \x01(\x0b\x32%.tensorflow.SavedBareConcreteFunctionH\x00\x12-\n\x08\x63onstant\x18\t \x01(\x0b\x32\x19.tensorflow.SavedConstantH\x00\x12-\n\x08resource\x18\n \x01(\x0b\x32\x19.tensorflow.SavedResourceH\x00\x42\x06\n\x04kindJ\x04\x08\x02\x10\x03R\nattributes\"`\n\x0fSavedUserObject\x12\x12\n\nidentifier\x18\x01 \x01(\t\x12\'\n\x07version\x18\x02 \x01(\x0b\x32\x16.tensorflow.VersionDef\x12\x10\n\x08metadata\x18\x03 \x01(\t\"*\n\nSavedAsset\x12\x1c\n\x14\x61sset_file_def_index\x18\x01 \x01(\x05\"\\\n\rSavedFunction\x12\x1a\n\x12\x63oncrete_functions\x18\x01 \x03(\t\x12/\n\rfunction_spec\x18\x02 \x01(\x0b\x32\x18.tensorflow.FunctionSpec\"\xa8\x01\n\x15SavedConcreteFunction\x12\x14\n\x0c\x62ound_inputs\x18\x02 \x03(\x05\x12\x42\n\x1d\x63\x61nonicalized_input_signature\x18\x03 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x35\n\x10output_signature\x18\x04 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\"|\n\x19SavedBareConcreteFunction\x12\x1e\n\x16\x63oncrete_function_name\x18\x01 \x01(\t\x12\x19\n\x11\x61rgument_keywords\x18\x02 \x03(\t\x12$\n\x1c\x61llowed_positional_arguments\x18\x03 \x01(\x03\"\"\n\rSavedConstant\x12\x11\n\toperation\x18\x01 \x01(\t\"\xf6\x01\n\rSavedVariable\x12#\n\x05\x64type\x18\x01 \x01(\x0e\x32\x14.tensorflow.DataType\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\x11\n\ttrainable\x18\x03 \x01(\x08\x12<\n\x0fsynchronization\x18\x04 \x01(\x0e\x32#.tensorflow.VariableSynchronization\x12\x34\n\x0b\x61ggregation\x18\x05 \x01(\x0e\x32\x1f.tensorflow.VariableAggregation\x12\x0c\n\x04name\x18\x06 \x01(\t\"\x95\x01\n\x0c\x46unctionSpec\x12\x30\n\x0b\x66ullargspec\x18\x01 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x11\n\tis_method\x18\x02 \x01(\x08\x12\x34\n\x0finput_signature\x18\x05 \x01(\x0b\x32\x1b.tensorflow.StructuredValueJ\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05\"\x1f\n\rSavedResource\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3' + , + dependencies=[trackable__object__graph__pb2.DESCRIPTOR,struct__pb2.DESCRIPTOR,tensor__shape__pb2.DESCRIPTOR,types__pb2.DESCRIPTOR,versions__pb2.DESCRIPTOR,variable__pb2.DESCRIPTOR,]) + + + + +_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY = _descriptor.Descriptor( + name='ConcreteFunctionsEntry', + full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='tensorflow.SavedObjectGraph.ConcreteFunctionsEntry.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=291, + serialized_end=382, +) + +_SAVEDOBJECTGRAPH = _descriptor.Descriptor( + name='SavedObjectGraph', + full_name='tensorflow.SavedObjectGraph', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='nodes', full_name='tensorflow.SavedObjectGraph.nodes', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='concrete_functions', full_name='tensorflow.SavedObjectGraph.concrete_functions', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=150, + serialized_end=382, +) + + +_SAVEDOBJECT = _descriptor.Descriptor( + name='SavedObject', + full_name='tensorflow.SavedObject', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='children', full_name='tensorflow.SavedObject.children', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='slot_variables', full_name='tensorflow.SavedObject.slot_variables', index=1, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='user_object', full_name='tensorflow.SavedObject.user_object', index=2, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='asset', full_name='tensorflow.SavedObject.asset', index=3, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='function', full_name='tensorflow.SavedObject.function', index=4, + number=6, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='variable', full_name='tensorflow.SavedObject.variable', index=5, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='bare_concrete_function', full_name='tensorflow.SavedObject.bare_concrete_function', index=6, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='constant', full_name='tensorflow.SavedObject.constant', index=7, + number=9, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='resource', full_name='tensorflow.SavedObject.resource', index=8, + number=10, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='kind', full_name='tensorflow.SavedObject.kind', + index=0, containing_type=None, fields=[]), + ], + serialized_start=385, + serialized_end=958, +) + + +_SAVEDUSEROBJECT = _descriptor.Descriptor( + name='SavedUserObject', + full_name='tensorflow.SavedUserObject', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='identifier', full_name='tensorflow.SavedUserObject.identifier', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='version', full_name='tensorflow.SavedUserObject.version', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='metadata', full_name='tensorflow.SavedUserObject.metadata', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=960, + serialized_end=1056, +) + + +_SAVEDASSET = _descriptor.Descriptor( + name='SavedAsset', + full_name='tensorflow.SavedAsset', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='asset_file_def_index', full_name='tensorflow.SavedAsset.asset_file_def_index', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1058, + serialized_end=1100, +) + + +_SAVEDFUNCTION = _descriptor.Descriptor( + name='SavedFunction', + full_name='tensorflow.SavedFunction', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='concrete_functions', full_name='tensorflow.SavedFunction.concrete_functions', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='function_spec', full_name='tensorflow.SavedFunction.function_spec', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1102, + serialized_end=1194, +) + + +_SAVEDCONCRETEFUNCTION = _descriptor.Descriptor( + name='SavedConcreteFunction', + full_name='tensorflow.SavedConcreteFunction', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='bound_inputs', full_name='tensorflow.SavedConcreteFunction.bound_inputs', index=0, + number=2, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='canonicalized_input_signature', full_name='tensorflow.SavedConcreteFunction.canonicalized_input_signature', index=1, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='output_signature', full_name='tensorflow.SavedConcreteFunction.output_signature', index=2, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1197, + serialized_end=1365, +) + + +_SAVEDBARECONCRETEFUNCTION = _descriptor.Descriptor( + name='SavedBareConcreteFunction', + full_name='tensorflow.SavedBareConcreteFunction', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='concrete_function_name', full_name='tensorflow.SavedBareConcreteFunction.concrete_function_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='argument_keywords', full_name='tensorflow.SavedBareConcreteFunction.argument_keywords', index=1, + number=2, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='allowed_positional_arguments', full_name='tensorflow.SavedBareConcreteFunction.allowed_positional_arguments', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1367, + serialized_end=1491, +) + + +_SAVEDCONSTANT = _descriptor.Descriptor( + name='SavedConstant', + full_name='tensorflow.SavedConstant', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='operation', full_name='tensorflow.SavedConstant.operation', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1493, + serialized_end=1527, +) + + +_SAVEDVARIABLE = _descriptor.Descriptor( + name='SavedVariable', + full_name='tensorflow.SavedVariable', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='dtype', full_name='tensorflow.SavedVariable.dtype', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='shape', full_name='tensorflow.SavedVariable.shape', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='trainable', full_name='tensorflow.SavedVariable.trainable', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='synchronization', full_name='tensorflow.SavedVariable.synchronization', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='aggregation', full_name='tensorflow.SavedVariable.aggregation', index=4, + number=5, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='tensorflow.SavedVariable.name', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1530, + serialized_end=1776, +) + + +_FUNCTIONSPEC = _descriptor.Descriptor( + name='FunctionSpec', + full_name='tensorflow.FunctionSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='fullargspec', full_name='tensorflow.FunctionSpec.fullargspec', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='is_method', full_name='tensorflow.FunctionSpec.is_method', index=1, + number=2, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='input_signature', full_name='tensorflow.FunctionSpec.input_signature', index=2, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1779, + serialized_end=1928, +) + + +_SAVEDRESOURCE = _descriptor.Descriptor( + name='SavedResource', + full_name='tensorflow.SavedResource', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='device', full_name='tensorflow.SavedResource.device', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1930, + serialized_end=1961, +) + +_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY.fields_by_name['value'].message_type = _SAVEDCONCRETEFUNCTION +_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY.containing_type = _SAVEDOBJECTGRAPH +_SAVEDOBJECTGRAPH.fields_by_name['nodes'].message_type = _SAVEDOBJECT +_SAVEDOBJECTGRAPH.fields_by_name['concrete_functions'].message_type = _SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY +_SAVEDOBJECT.fields_by_name['children'].message_type = trackable__object__graph__pb2._TRACKABLEOBJECTGRAPH_TRACKABLEOBJECT_OBJECTREFERENCE +_SAVEDOBJECT.fields_by_name['slot_variables'].message_type = trackable__object__graph__pb2._TRACKABLEOBJECTGRAPH_TRACKABLEOBJECT_SLOTVARIABLEREFERENCE +_SAVEDOBJECT.fields_by_name['user_object'].message_type = _SAVEDUSEROBJECT +_SAVEDOBJECT.fields_by_name['asset'].message_type = _SAVEDASSET +_SAVEDOBJECT.fields_by_name['function'].message_type = _SAVEDFUNCTION +_SAVEDOBJECT.fields_by_name['variable'].message_type = _SAVEDVARIABLE +_SAVEDOBJECT.fields_by_name['bare_concrete_function'].message_type = _SAVEDBARECONCRETEFUNCTION +_SAVEDOBJECT.fields_by_name['constant'].message_type = _SAVEDCONSTANT +_SAVEDOBJECT.fields_by_name['resource'].message_type = _SAVEDRESOURCE +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['user_object']) +_SAVEDOBJECT.fields_by_name['user_object'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['asset']) +_SAVEDOBJECT.fields_by_name['asset'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['function']) +_SAVEDOBJECT.fields_by_name['function'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['variable']) +_SAVEDOBJECT.fields_by_name['variable'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['bare_concrete_function']) +_SAVEDOBJECT.fields_by_name['bare_concrete_function'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['constant']) +_SAVEDOBJECT.fields_by_name['constant'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDOBJECT.oneofs_by_name['kind'].fields.append( + _SAVEDOBJECT.fields_by_name['resource']) +_SAVEDOBJECT.fields_by_name['resource'].containing_oneof = _SAVEDOBJECT.oneofs_by_name['kind'] +_SAVEDUSEROBJECT.fields_by_name['version'].message_type = versions__pb2._VERSIONDEF +_SAVEDFUNCTION.fields_by_name['function_spec'].message_type = _FUNCTIONSPEC +_SAVEDCONCRETEFUNCTION.fields_by_name['canonicalized_input_signature'].message_type = struct__pb2._STRUCTUREDVALUE +_SAVEDCONCRETEFUNCTION.fields_by_name['output_signature'].message_type = struct__pb2._STRUCTUREDVALUE +_SAVEDVARIABLE.fields_by_name['dtype'].enum_type = types__pb2._DATATYPE +_SAVEDVARIABLE.fields_by_name['shape'].message_type = tensor__shape__pb2._TENSORSHAPEPROTO +_SAVEDVARIABLE.fields_by_name['synchronization'].enum_type = variable__pb2._VARIABLESYNCHRONIZATION +_SAVEDVARIABLE.fields_by_name['aggregation'].enum_type = variable__pb2._VARIABLEAGGREGATION +_FUNCTIONSPEC.fields_by_name['fullargspec'].message_type = struct__pb2._STRUCTUREDVALUE +_FUNCTIONSPEC.fields_by_name['input_signature'].message_type = struct__pb2._STRUCTUREDVALUE +DESCRIPTOR.message_types_by_name['SavedObjectGraph'] = _SAVEDOBJECTGRAPH +DESCRIPTOR.message_types_by_name['SavedObject'] = _SAVEDOBJECT +DESCRIPTOR.message_types_by_name['SavedUserObject'] = _SAVEDUSEROBJECT +DESCRIPTOR.message_types_by_name['SavedAsset'] = _SAVEDASSET +DESCRIPTOR.message_types_by_name['SavedFunction'] = _SAVEDFUNCTION +DESCRIPTOR.message_types_by_name['SavedConcreteFunction'] = _SAVEDCONCRETEFUNCTION +DESCRIPTOR.message_types_by_name['SavedBareConcreteFunction'] = _SAVEDBARECONCRETEFUNCTION +DESCRIPTOR.message_types_by_name['SavedConstant'] = _SAVEDCONSTANT +DESCRIPTOR.message_types_by_name['SavedVariable'] = _SAVEDVARIABLE +DESCRIPTOR.message_types_by_name['FunctionSpec'] = _FUNCTIONSPEC +DESCRIPTOR.message_types_by_name['SavedResource'] = _SAVEDRESOURCE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +SavedObjectGraph = _reflection.GeneratedProtocolMessageType('SavedObjectGraph', (_message.Message,), { + + 'ConcreteFunctionsEntry' : _reflection.GeneratedProtocolMessageType('ConcreteFunctionsEntry', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedObjectGraph.ConcreteFunctionsEntry) + }) + , + 'DESCRIPTOR' : _SAVEDOBJECTGRAPH, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedObjectGraph) + }) +_sym_db.RegisterMessage(SavedObjectGraph) +_sym_db.RegisterMessage(SavedObjectGraph.ConcreteFunctionsEntry) + +SavedObject = _reflection.GeneratedProtocolMessageType('SavedObject', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDOBJECT, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedObject) + }) +_sym_db.RegisterMessage(SavedObject) + +SavedUserObject = _reflection.GeneratedProtocolMessageType('SavedUserObject', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDUSEROBJECT, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedUserObject) + }) +_sym_db.RegisterMessage(SavedUserObject) + +SavedAsset = _reflection.GeneratedProtocolMessageType('SavedAsset', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDASSET, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedAsset) + }) +_sym_db.RegisterMessage(SavedAsset) + +SavedFunction = _reflection.GeneratedProtocolMessageType('SavedFunction', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDFUNCTION, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedFunction) + }) +_sym_db.RegisterMessage(SavedFunction) + +SavedConcreteFunction = _reflection.GeneratedProtocolMessageType('SavedConcreteFunction', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDCONCRETEFUNCTION, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedConcreteFunction) + }) +_sym_db.RegisterMessage(SavedConcreteFunction) + +SavedBareConcreteFunction = _reflection.GeneratedProtocolMessageType('SavedBareConcreteFunction', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDBARECONCRETEFUNCTION, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedBareConcreteFunction) + }) +_sym_db.RegisterMessage(SavedBareConcreteFunction) + +SavedConstant = _reflection.GeneratedProtocolMessageType('SavedConstant', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDCONSTANT, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedConstant) + }) +_sym_db.RegisterMessage(SavedConstant) + +SavedVariable = _reflection.GeneratedProtocolMessageType('SavedVariable', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDVARIABLE, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedVariable) + }) +_sym_db.RegisterMessage(SavedVariable) + +FunctionSpec = _reflection.GeneratedProtocolMessageType('FunctionSpec', (_message.Message,), { + 'DESCRIPTOR' : _FUNCTIONSPEC, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.FunctionSpec) + }) +_sym_db.RegisterMessage(FunctionSpec) + +SavedResource = _reflection.GeneratedProtocolMessageType('SavedResource', (_message.Message,), { + 'DESCRIPTOR' : _SAVEDRESOURCE, + '__module__' : 'saved_object_graph_pb2' + # @@protoc_insertion_point(class_scope:tensorflow.SavedResource) + }) +_sym_db.RegisterMessage(SavedResource) + + +DESCRIPTOR._options = None +_SAVEDOBJECTGRAPH_CONCRETEFUNCTIONSENTRY._options = None +# @@protoc_insertion_point(module_scope) diff --git a/redis_consumer/pbs/saved_object_graph_pb2_grpc.py b/redis_consumer/pbs/saved_object_graph_pb2_grpc.py new file mode 100644 index 00000000..a8943526 --- /dev/null +++ b/redis_consumer/pbs/saved_object_graph_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/redis_consumer/pbs/saver_pb2.py b/redis_consumer/pbs/saver_pb2.py new file mode 100644 index 00000000..bcf329e4 --- /dev/null +++ b/redis_consumer/pbs/saver_pb2.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: saver.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='saver.proto', + package='tensorflow', + syntax='proto3', + serialized_options=b'\n\023org.tensorflow.utilB\013SaverProtosP\001Z Date: Mon, 2 Mar 2020 11:10:12 -0800 Subject: [PATCH 04/47] add function for get_model_metadata to grpc_client and to tfsconsumer --- README.md | 1 + redis_consumer/consumers/base_consumer.py | 46 ++++++++++++++ redis_consumer/grpc_clients.py | 76 ++++++++++++++++++++++- redis_consumer/settings.py | 3 + 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index be4b2176..57e23fa1 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ The consumer is configured using environment variables. Please find a table of a | `REDIS_TIMEOUT` | Timeout for each Redis request, in seconds. | `3` | | `EMPTY_QUEUE_TIMEOUT` | Time to wait after finding an empty queue, in seconds. | `5` | | `EXPIRE_TIME` | Expire Redis items this many seconds after completion. | `3600` | +| `METADATA_EXPIRE_TIME` | Expire cached model metadata after this many seconds. | `30` | | `TF_HOST` | The IP address or hostname of TensorFlow Serving. | `"tf-serving"` | | `TF_PORT` | The port used to connect to TensorFlow Serving. | `8500` | | `TF_TENSOR_NAME` | Name of input tensor for the exported model. | `"image"` | diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 4900105c..c74aa4bd 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -328,6 +328,52 @@ def grpc_image(self, img, model_name, model_version): model_name, model_version, err) raise err + def get_model_metadata(self, model_name, model_version): + """Check Redis for saved model metadata or get from TensorFlow Serving. + + The Consumer prefers to get the model metadata from Redis, + but if the metadata does not exist or is too stale, + a TensorFlow Serving request will be made. + + Args: + model_name (str): The model name to get metadata. + model_version (int): The model version to get metadata. + """ + model = '{}:{}'.format(model_name, model_version) + self.logger.debug('Getting model metadata for model %s.', model) + + fields = ['in_tensor_dtype', 'in_tensor_shape'] + response = self.redis.hmget(model, *fields) + + if response: + self.logger.debug('Got cached metadata for model %s.', model) + return dict(zip(fields, response)) + + # No response! The key was expired. Get from TFS and update it. + start = timeit.default_timer() + client = self._get_predict_client(model_name, model_version) + model_metadata = client.get_model_metadata() + + try: + inputs = model_metadata['metadata']['signature_def']['signature_def'] + inputs = inputs[settings.TF_TENSOR_NAME] + + dtype = inputs['dtype'] + shape = [d['size'] for d in inputs['tensor_shape']['dim']] + + parsed_metadata = dict(zip(fields, [dtype, shape])) + + finished = timeit.default_timer() - start + self.logger.debug('Got model metadata for %s in %s seconds.', + model, finished) + + self.redis.hmset(model, mapping=parsed_metadata) + self.redis.expire(model, settings.METADATA_EXPIRE_TIME) + return parsed_metadata + except (KeyError, IndexError) as err: + self.logger.error('Malformed metadata: %s', model_metadata) + raise err + def process_big_image(self, cuts, img, diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index f7b9c07b..a8aa99a9 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -41,11 +41,10 @@ import numpy as np +from redis_consumer import settings from redis_consumer.pbs.prediction_service_pb2_grpc import PredictionServiceStub -from redis_consumer.pbs.processing_service_pb2_grpc import ProcessingServiceStub from redis_consumer.pbs.predict_pb2 import PredictRequest -from redis_consumer.pbs.process_pb2 import ProcessRequest -from redis_consumer.pbs.process_pb2 import ChunkedProcessRequest +from redis_consumer.pbs.get_model_metadata_pb2 import GetModelMetadataRequest from redis_consumer.utils import grpc_response_to_dict from redis_consumer.utils import make_tensor_proto @@ -150,6 +149,77 @@ def predict(self, request_data, request_timeout=10): channel.close() return {} + def get_model_metadata(self, request_timeout=10): + self.logger.info('Sending GetModelMetadataRequest to %s model %s:%s.', + self.host, self.model_name, self.model_version) + + true_failures, count = 0, 0 + + retrying = True + while retrying: + try: + t = timeit.default_timer() + channel = self.insecure_channel() + + stub = PredictionServiceStub(channel) + + request = GetModelMetadataRequest() + + request.model_spec.name = self.model_name # pylint: disable=E1101 + + if self.model_version > 0: + # pylint: disable=E1101 + request.model_spec.version.value = self.model_version + + predict_response = stub.GetModelMetadata( + request, timeout=request_timeout) + + self.logger.debug('gRPC GetModelMetadataRequest finished in %s ' + 'seconds.', timeit.default_timer() - t) + + t = timeit.default_timer() + predict_response_dict = grpc_response_to_dict(predict_response) + self.logger.debug('gRPC GetModelMetadataProtobufConversion took ' + '%s seconds.', timeit.default_timer() - t) + + channel.close() + return predict_response_dict + + except grpc.RpcError as err: + # pylint: disable=E1101 + channel.close() + if true_failures > settings.MAX_RETRY > 0: + retrying = False + self.logger.error('GetModelMetadataRequest has failed %s ' + 'times due to err %s', count, err) + raise err + + if err.code() in settings.GRPC_RETRY_STATUSES: + count += 1 + is_true_failure = err.code() != grpc.StatusCode.UNAVAILABLE + true_failures += int(is_true_failure) + + self.logger.warning('%sException `%s: %s` during ' + 'PredictClient GetModelMetadataRequest to ' + 'model %s:%s. Waiting %s seconds before ' + 'retrying.', type(err).__name__, + err.code().name, err.details(), + self.model_name, self.model_version, + settings.GRPC_BACKOFF) + + time.sleep(settings.GRPC_BACKOFF) # sleep before retry + retrying = True # Unneccessary but explicit + else: + retrying = False + raise err + except Exception as err: + channel.close() + retrying = False + self.logger.error('Encountered %s during GetModelMetadataRequest' + ' to model %s:%s: %s', type(err).__name__, + self.model_name, self.model_version, err) + raise err + class TrackingClient(GrpcClient): """gRPC Client for tensorflow-serving API. diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 7829140d..c4561523 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -112,6 +112,9 @@ def _strip(x): # Configure expiration time for child keys EXPIRE_TIME = config('EXPIRE_TIME', default=3600, cast=int) +# Configure expiration for cached model metadata +METADATA_EXPIRE_TIME = config('METADATA_EXPIRE_TIME', default=30, cast=int) + # Pre- and Post-processing settings PROCESSING_FUNCTIONS = { 'pre': { From d0daac3e789eb5b6bb1c69020e85a99a087d7b2a Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 11:53:13 -0800 Subject: [PATCH 05/47] convert detect_scale and detect_label to use model metadata --- redis_consumer/consumers/base_consumer.py | 68 +++++++++++++++-------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index c74aa4bd..68a41b2a 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -44,6 +44,8 @@ import numpy as np import pytz +from deepcell_toolbox.utils import tile_image, untile_image + from redis_consumer.grpc_clients import PredictClient from redis_consumer import utils from redis_consumer import settings @@ -434,6 +436,14 @@ def detect_scale(self, image): self.logger.debug('Scale detection disabled. Scale set to 1.') return 1 + model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') + model_metadata = self.get_model_metadata(model_name, model_version) + + model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] + + size_x = model_shape[len(model_shape) - 3] + size_y = model_shape[len(model_shape) - 2] + # Rescale image for compatibility with scale model # TODO Generalize to prevent from breaking on new input data types if image.shape[-1] == 1: @@ -441,27 +451,39 @@ def detect_scale(self, image): else: image = np.expand_dims(image, axis=-1) - # Reshape data to match size of data that model was trained on - # TODO Generalize to support rectangular and other shapes - size = settings.SCALE_RESHAPE_SIZE - if (image.shape[1] >= size) and (image.shape[2] >= size): - image, _ = utils.reshape_matrix(image, image, reshape_size=size) - - model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') + tiles, _ = tile_image( + np.expand_dims(image, axis=0), + model_input_shape=(size_x, size_y), + stride_ratio=0.75) # Loop over each image in the batch dimension for scale prediction # TODO Calculate scale_detect_sample based on batch size # Could be based on fraction or sampling a minimum set number of frames scales = [] - for i in range(0, image.shape[0], settings.SCALE_DETECT_SAMPLE): - scales.append(self.grpc_image(image[i], model_name, model_version)) + for i in range(0, tiles.shape[0], settings.SCALE_DETECT_SAMPLE): + scales.append(self.grpc_image(tiles[i], model_name, model_version)) - self.logger.debug('Scale detection complete in %s seconds', - timeit.default_timer() - start) - return np.mean(scales) + detected_scale = np.mean(scales) + + self.logger.debug('Scale %s detected in %s seconds', + detected_scale, timeit.default_timer() - start) + return detected_scale def detect_label(self, image): start = timeit.default_timer() + + if not settings.LABEL_DETECT_ENABLED: + self.logger.debug('Label detection disabled. Label set to None.') + return None + + model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') + model_metadata = self.get_model_metadata(model_name, model_version) + + model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] + + size_x = model_shape[len(model_shape) - 3] + size_y = model_shape[len(model_shape) - 2] + # Rescale for model compatibility # TODO Generalize to prevent from breaking on new input data types if image.shape[-1] == 1: @@ -469,25 +491,25 @@ def detect_label(self, image): else: image = np.expand_dims(image, axis=-1) - # TODO Generalize to support rectangular and other shapes - size = settings.LABEL_RESHAPE_SIZE - if (image.shape[1] >= size) and (image.shape[2] >= size): - image, _ = utils.reshape_matrix(image, image, reshape_size=size) - - model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') + tiles, _ = tile_image( + np.expand_dims(image, axis=0), + model_input_shape=(size_x, size_y), + stride_ratio=0.75) # Loop over each image in batch labels = [] - for i in range(0, image.shape[0], settings.LABEL_DETECT_SAMPLE): - labels.append(self.grpc_image(image[i], model_name, model_version)) + for i in range(0, tiles.shape[0], settings.LABEL_DETECT_SAMPLE): + labels.append(self.grpc_image(tiles[i], model_name, model_version)) labels = np.array(labels) vote = labels.sum(axis=0) maj = vote.max() - self.logger.debug('Label detection complete %s seconds.', - timeit.default_timer() - start) - return np.where(vote == maj)[-1][0] + detected = np.where(vote == maj)[-1][0] + + self.logger.debug('Label %s detected in %s seconds.', + detected, timeit.default_timer() - start) + return detected class ZipFileConsumer(Consumer): From d6a9236e72f8c6fff658e269b021646e8629c83e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 11:56:32 -0800 Subject: [PATCH 06/47] update PredictRequest logging. --- redis_consumer/grpc_clients.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index a8aa99a9..50f6572a 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -94,20 +94,20 @@ def __init__(self, host, model_name, model_version): self.model_version = model_version def predict(self, request_data, request_timeout=10): - self.logger.info('Sending request to %s model %s:%s.', + self.logger.info('Sending PredictRequest to %s model %s:%s.', self.host, self.model_name, self.model_version) channel = self.insecure_channel() t = timeit.default_timer() stub = PredictionServiceStub(channel) - self.logger.debug('Created TensorFlowServingServiceStub in %s seconds.', + self.logger.debug('Created PredictionServiceStub in %s seconds.', timeit.default_timer() - t) t = timeit.default_timer() request = PredictRequest() - self.logger.debug('Created TensorFlowServingRequest object in %s ' - 'seconds.', timeit.default_timer() - t) + self.logger.debug('Created PredictRequest object in %s seconds.', + timeit.default_timer() - t) request.model_spec.name = self.model_name # pylint: disable=E1101 @@ -127,22 +127,22 @@ def predict(self, request_data, request_timeout=10): try: t = timeit.default_timer() predict_response = stub.Predict(request, timeout=request_timeout) - self.logger.debug('gRPC TensorFlowServingRequest finished in %s ' - 'seconds.', timeit.default_timer() - t) + self.logger.debug('gRPC PredictRequest finished in %s seconds.', + timeit.default_timer() - t) t = timeit.default_timer() predict_response_dict = grpc_response_to_dict(predict_response) - self.logger.debug('gRPC TensorFlowServingProtobufConversion took ' + self.logger.debug('gRPC PredictResponseProtobufConversion took ' '%s seconds.', timeit.default_timer() - t) keys = [k for k in predict_response_dict] - self.logger.info('Got TensorFlowServingResponse with keys: %s ', + self.logger.info('Got PredictResponse with keys: %s ', keys) channel.close() return predict_response_dict except RpcError as err: - self.logger.error('Prediction failed due to: %s', err) + self.logger.error('PredictRequest failed due to: %s', err) channel.close() raise err From 8c012ce7fc9241e8e7b5da9c35f8268fe09ecfd0 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 12:21:51 -0800 Subject: [PATCH 07/47] fix import issue in new pbs files --- redis_consumer/pbs/attr_value_pb2.py | 6 +++--- redis_consumer/pbs/function_pb2.py | 6 +++--- redis_consumer/pbs/get_model_metadata_pb2.py | 4 ++-- redis_consumer/pbs/graph_pb2.py | 6 +++--- redis_consumer/pbs/meta_graph_pb2.py | 14 +++++++------- redis_consumer/pbs/node_def_pb2.py | 2 +- redis_consumer/pbs/op_def_pb2.py | 4 ++-- redis_consumer/pbs/predict_pb2.py | 4 ++-- redis_consumer/pbs/prediction_service_pb2.py | 4 ++-- redis_consumer/pbs/prediction_service_pb2_grpc.py | 4 ++-- redis_consumer/pbs/resource_handle_pb2.py | 4 ++-- redis_consumer/pbs/saved_object_graph_pb2.py | 12 ++++++------ redis_consumer/pbs/struct_pb2.py | 4 ++-- redis_consumer/pbs/tensor_pb2.py | 6 +++--- 14 files changed, 40 insertions(+), 40 deletions(-) diff --git a/redis_consumer/pbs/attr_value_pb2.py b/redis_consumer/pbs/attr_value_pb2.py index a059608c..3262caf7 100644 --- a/redis_consumer/pbs/attr_value_pb2.py +++ b/redis_consumer/pbs/attr_value_pb2.py @@ -11,9 +11,9 @@ _sym_db = _symbol_database.Default() -import tensor_pb2 as tensor__pb2 -import tensor_shape_pb2 as tensor__shape__pb2 -import types_pb2 as types__pb2 +import redis_consumer.pbs.tensor_pb2 as tensor__pb2 +import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/function_pb2.py b/redis_consumer/pbs/function_pb2.py index 2b6c5658..8084de52 100644 --- a/redis_consumer/pbs/function_pb2.py +++ b/redis_consumer/pbs/function_pb2.py @@ -11,9 +11,9 @@ _sym_db = _symbol_database.Default() -import attr_value_pb2 as attr__value__pb2 -import node_def_pb2 as node__def__pb2 -import op_def_pb2 as op__def__pb2 +import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 +import redis_consumer.pbs.node_def_pb2 as node__def__pb2 +import redis_consumer.pbs.op_def_pb2 as op__def__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/get_model_metadata_pb2.py b/redis_consumer/pbs/get_model_metadata_pb2.py index bc5579fb..1695ddee 100644 --- a/redis_consumer/pbs/get_model_metadata_pb2.py +++ b/redis_consumer/pbs/get_model_metadata_pb2.py @@ -12,8 +12,8 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -import meta_graph_pb2 as meta__graph__pb2 -import model_pb2 as model__pb2 +import redis_consumer.pbs.meta_graph_pb2 as meta__graph__pb2 +import redis_consumer.pbs.model_pb2 as model__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/graph_pb2.py b/redis_consumer/pbs/graph_pb2.py index fecd4080..78131df8 100644 --- a/redis_consumer/pbs/graph_pb2.py +++ b/redis_consumer/pbs/graph_pb2.py @@ -11,9 +11,9 @@ _sym_db = _symbol_database.Default() -import node_def_pb2 as node__def__pb2 -import function_pb2 as function__pb2 -import versions_pb2 as versions__pb2 +import redis_consumer.pbs.node_def_pb2 as node__def__pb2 +import redis_consumer.pbs.function_pb2 as function__pb2 +import redis_consumer.pbs.versions_pb2 as versions__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/meta_graph_pb2.py b/redis_consumer/pbs/meta_graph_pb2.py index de9d402e..c8e8eebd 100644 --- a/redis_consumer/pbs/meta_graph_pb2.py +++ b/redis_consumer/pbs/meta_graph_pb2.py @@ -12,13 +12,13 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -import graph_pb2 as graph__pb2 -import op_def_pb2 as op__def__pb2 -import tensor_shape_pb2 as tensor__shape__pb2 -import types_pb2 as types__pb2 -import saved_object_graph_pb2 as saved__object__graph__pb2 -import saver_pb2 as saver__pb2 -import struct_pb2 as struct__pb2 +import redis_consumer.pbs.graph_pb2 as graph__pb2 +import redis_consumer.pbs.op_def_pb2 as op__def__pb2 +import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 +import redis_consumer.pbs.saved_object_graph_pb2 as saved__object__graph__pb2 +import redis_consumer.pbs.saver_pb2 as saver__pb2 +import redis_consumer.pbs.struct_pb2 as struct__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/node_def_pb2.py b/redis_consumer/pbs/node_def_pb2.py index 31c0ba7e..b2864a5b 100644 --- a/redis_consumer/pbs/node_def_pb2.py +++ b/redis_consumer/pbs/node_def_pb2.py @@ -11,7 +11,7 @@ _sym_db = _symbol_database.Default() -import attr_value_pb2 as attr__value__pb2 +import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/op_def_pb2.py b/redis_consumer/pbs/op_def_pb2.py index d4d90bcb..1a150703 100644 --- a/redis_consumer/pbs/op_def_pb2.py +++ b/redis_consumer/pbs/op_def_pb2.py @@ -11,8 +11,8 @@ _sym_db = _symbol_database.Default() -import attr_value_pb2 as attr__value__pb2 -import types_pb2 as types__pb2 +import redis_consumer.pbs.attr_value_pb2 as attr__value__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/predict_pb2.py b/redis_consumer/pbs/predict_pb2.py index 9d6f2cef..8eb45853 100644 --- a/redis_consumer/pbs/predict_pb2.py +++ b/redis_consumer/pbs/predict_pb2.py @@ -11,8 +11,8 @@ _sym_db = _symbol_database.Default() -import tensor_pb2 as tensor__pb2 -import model_pb2 as model__pb2 +import redis_consumer.pbs.tensor_pb2 as tensor__pb2 +import redis_consumer.pbs.model_pb2 as model__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/prediction_service_pb2.py b/redis_consumer/pbs/prediction_service_pb2.py index 7be5fa12..7e88b957 100644 --- a/redis_consumer/pbs/prediction_service_pb2.py +++ b/redis_consumer/pbs/prediction_service_pb2.py @@ -11,8 +11,8 @@ _sym_db = _symbol_database.Default() -import get_model_metadata_pb2 as get__model__metadata__pb2 -import predict_pb2 as predict__pb2 +import redis_consumer.pbs.get_model_metadata_pb2 as get__model__metadata__pb2 +import redis_consumer.pbs.predict_pb2 as predict__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/prediction_service_pb2_grpc.py b/redis_consumer/pbs/prediction_service_pb2_grpc.py index e0235f0a..76a23c73 100644 --- a/redis_consumer/pbs/prediction_service_pb2_grpc.py +++ b/redis_consumer/pbs/prediction_service_pb2_grpc.py @@ -1,8 +1,8 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! import grpc -import get_model_metadata_pb2 as get__model__metadata__pb2 -import predict_pb2 as predict__pb2 +import redis_consumer.pbs.get_model_metadata_pb2 as get__model__metadata__pb2 +import redis_consumer.pbs.predict_pb2 as predict__pb2 class PredictionServiceStub(object): diff --git a/redis_consumer/pbs/resource_handle_pb2.py b/redis_consumer/pbs/resource_handle_pb2.py index 4e5fcccc..e149683e 100644 --- a/redis_consumer/pbs/resource_handle_pb2.py +++ b/redis_consumer/pbs/resource_handle_pb2.py @@ -11,8 +11,8 @@ _sym_db = _symbol_database.Default() -import tensor_shape_pb2 as tensor__shape__pb2 -import types_pb2 as types__pb2 +import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/saved_object_graph_pb2.py b/redis_consumer/pbs/saved_object_graph_pb2.py index 4779b05d..7cd9de9e 100644 --- a/redis_consumer/pbs/saved_object_graph_pb2.py +++ b/redis_consumer/pbs/saved_object_graph_pb2.py @@ -11,12 +11,12 @@ _sym_db = _symbol_database.Default() -import trackable_object_graph_pb2 as trackable__object__graph__pb2 -import struct_pb2 as struct__pb2 -import tensor_shape_pb2 as tensor__shape__pb2 -import types_pb2 as types__pb2 -import versions_pb2 as versions__pb2 -import variable_pb2 as variable__pb2 +import redis_consumer.pbs.trackable_object_graph_pb2 as trackable__object__graph__pb2 +import redis_consumer.pbs.struct_pb2 as struct__pb2 +import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 +import redis_consumer.pbs.versions_pb2 as versions__pb2 +import redis_consumer.pbs.variable_pb2 as variable__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/struct_pb2.py b/redis_consumer/pbs/struct_pb2.py index 3521e1aa..47613a6e 100644 --- a/redis_consumer/pbs/struct_pb2.py +++ b/redis_consumer/pbs/struct_pb2.py @@ -11,8 +11,8 @@ _sym_db = _symbol_database.Default() -import tensor_shape_pb2 as tensor__shape__pb2 -import types_pb2 as types__pb2 +import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 DESCRIPTOR = _descriptor.FileDescriptor( diff --git a/redis_consumer/pbs/tensor_pb2.py b/redis_consumer/pbs/tensor_pb2.py index 17ff0c1a..1318c383 100644 --- a/redis_consumer/pbs/tensor_pb2.py +++ b/redis_consumer/pbs/tensor_pb2.py @@ -11,9 +11,9 @@ _sym_db = _symbol_database.Default() -import resource_handle_pb2 as resource__handle__pb2 -import tensor_shape_pb2 as tensor__shape__pb2 -import types_pb2 as types__pb2 +import redis_consumer.pbs.resource_handle_pb2 as resource__handle__pb2 +import redis_consumer.pbs.tensor_shape_pb2 as tensor__shape__pb2 +import redis_consumer.pbs.types_pb2 as types__pb2 DESCRIPTOR = _descriptor.FileDescriptor( From 3234e5b894c546537c59f10af4d3bd086da0a8d9 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 13:35:53 -0800 Subject: [PATCH 08/47] add test for get_model_metadata --- redis_consumer/consumers/base_consumer.py | 25 +++-- .../consumers/base_consumer_test.py | 98 +++++++++++++++++-- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 68a41b2a..d581f880 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -250,7 +250,10 @@ def _get_predict_client(self, model_name, model_version): timeit.default_timer() - t) return client - def grpc_image(self, img, model_name, model_version): + def grpc_image(self, img, model_name, model_version, + in_tensor_dtype='DT_FLOAT'): + + in_tensor_dtype = str(in_tensor_dtype).upper() true_failures, count = 0, 0 start = timeit.default_timer() self.logger.debug('Segmenting image of shape %s with model %s:%s', @@ -258,15 +261,13 @@ def grpc_image(self, img, model_name, model_version): retrying = True while retrying: try: - floatx = settings.TF_TENSOR_DTYPE - if 'f16' in model_name: - floatx = 'DT_HALF' + if in_tensor_dtype == 'DT_HALF': # TODO: seems like should cast to "half" # but the model rejects the type, wants "int" or "long" img = img.astype('int') req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME, - 'in_tensor_dtype': floatx, + 'in_tensor_dtype': in_tensor_dtype, 'data': np.expand_dims(img, axis=0)}] client = self._get_predict_client(model_name, model_version) @@ -358,10 +359,10 @@ def get_model_metadata(self, model_name, model_version): try: inputs = model_metadata['metadata']['signature_def']['signature_def'] - inputs = inputs[settings.TF_TENSOR_NAME] + inputs = inputs['serving_default']['inputs'][settings.TF_TENSOR_NAME] dtype = inputs['dtype'] - shape = [d['size'] for d in inputs['tensor_shape']['dim']] + shape = ','.join([d['size'] for d in inputs['tensor_shape']['dim']]) parsed_metadata = dict(zip(fields, [dtype, shape])) @@ -369,7 +370,7 @@ def get_model_metadata(self, model_name, model_version): self.logger.debug('Got model metadata for %s in %s seconds.', model, finished) - self.redis.hmset(model, mapping=parsed_metadata) + self.redis.hmset(model, parsed_metadata) self.redis.expire(model, settings.METADATA_EXPIRE_TIME) return parsed_metadata except (KeyError, IndexError) as err: @@ -439,6 +440,7 @@ def detect_scale(self, image): model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') model_metadata = self.get_model_metadata(model_name, model_version) + model_dtype = model_metadata['in_tensor_dtype'] model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] size_x = model_shape[len(model_shape) - 3] @@ -461,7 +463,8 @@ def detect_scale(self, image): # Could be based on fraction or sampling a minimum set number of frames scales = [] for i in range(0, tiles.shape[0], settings.SCALE_DETECT_SAMPLE): - scales.append(self.grpc_image(tiles[i], model_name, model_version)) + scales.append(self.grpc_image(tiles[i], model_name, model_version, + in_tensor_dtype=model_dtype)) detected_scale = np.mean(scales) @@ -479,6 +482,7 @@ def detect_label(self, image): model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') model_metadata = self.get_model_metadata(model_name, model_version) + model_dtype = model_metadata['in_tensor_dtype'] model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] size_x = model_shape[len(model_shape) - 3] @@ -499,7 +503,8 @@ def detect_label(self, image): # Loop over each image in batch labels = [] for i in range(0, tiles.shape[0], settings.LABEL_DETECT_SAMPLE): - labels.append(self.grpc_image(tiles[i], model_name, model_version)) + labels.append(self.grpc_image(tiles[i], model_name, model_version, + in_tensor_dtype=model_dtype)) labels = np.array(labels) vote = labels.sum(axis=0) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index c8a2f44a..880b634b 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -379,11 +379,79 @@ def test_process_big_image(self): res = consumer.process_big_image(cuts, img, field, name, version) np.testing.assert_equal(res, img) + def test_get_model_metadata(self): + # pytest: disable=W0613 + redis_client = DummyRedis([]) + model_shape = (-1, 216, 216, 1) + model_dtype = 'DT_FLOAT' + + def hmget_success(key, *others): + shape = ','.join(str(s) for s in model_shape) + dtype = 'DT_FLOAT' + return dtype, shape + + def hmget_fail(key, *others): + shape = ','.join(str(s) for s in model_shape) + dtype = 'DT_FLOAT' + return None + + def _get_predict_client(model_name, model_version): + return Bunch(get_model_metadata=lambda: { + 'metadata': { + 'signature_def': { + 'signature_def': { + 'serving_default': { + 'inputs': { + settings.TF_TENSOR_NAME: { + 'dtype': model_dtype, + 'tensor_shape': { + 'dim': [ + {'size': str(x)} + for x in model_shape + ] + } + } + } + } + } + } + } + }) + + def _get_bad_predict_client(model_name, model_version): + return Bunch(get_model_metadata=lambda: dict()) + + redis_client.hmget = hmget_success + consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + consumer._get_predict_client = _get_predict_client + metadata = consumer.get_model_metadata('model', 1) + + assert metadata['in_tensor_dtype'] == 'DT_FLOAT' + assert metadata['in_tensor_shape'] == ','.join(str(x) for x in model_shape) + + redis_client.hmget = hmget_fail + consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + consumer._get_predict_client = _get_predict_client + metadata = consumer.get_model_metadata('model', 1) + + assert metadata['in_tensor_dtype'] == 'DT_FLOAT' + assert metadata['in_tensor_shape'] == ','.join(str(x) for x in model_shape) + + with pytest.raises(KeyError): + redis_client.hmget = hmget_fail + consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + consumer._get_predict_client = _get_bad_predict_client + consumer.get_model_metadata('model', 1) + def test_detect_label(self): redis_client = DummyRedis([]) + model_shape = (1, 216, 216, 1) consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') - image = _get_image(settings.LABEL_RESHAPE_SIZE * 2, - settings.LABEL_RESHAPE_SIZE * 2) + consumer.get_model_metadata = lambda x, y: { + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } + image = _get_image(model_shape[1] * 2, model_shape[2] * 2) settings.LABEL_DETECT_MODEL = 'dummymodel:1' @@ -395,16 +463,28 @@ def dummydata(*_, **__): consumer.grpc_image = dummydata + settings.LABEL_DETECT_ENABLED = False + + label = consumer.detect_label(image) + assert label is None + + settings.LABEL_DETECT_ENABLED = True + label = consumer.detect_label(image) assert label in set(list(range(4))) def test_detect_scale(self): redis_client = DummyRedis([]) + model_shape = (1, 216, 216, 1) consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') - big_size = settings.SCALE_RESHAPE_SIZE * np.random.randint(2, 9) + consumer.get_model_metadata = lambda x, y: { + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } + big_size = model_shape[1] * np.random.randint(2, 9) image = _get_image(big_size, big_size) - expected = (settings.SCALE_RESHAPE_SIZE / (big_size)) ** 2 + expected = (model_shape[1] / (big_size)) ** 2 settings.SCALE_DETECT_MODEL = 'dummymodel:1' @@ -421,16 +501,16 @@ def grpc_image(*_, **__): settings.SCALE_DETECT_ENABLED = True - scale = consumer.detect_scale(image) - assert isinstance(scale, (float, int)) - np.testing.assert_almost_equal(scale, expected) - consumer.grpc_image = grpc_image - scale = consumer.detect_scale(np.expand_dims(image, axis=-1)) + scale = consumer.detect_scale(image) assert isinstance(scale, (float, int)) np.testing.assert_almost_equal(scale, expected) + # scale = consumer.detect_scale(np.expand_dims(image, axis=-1)) + # assert isinstance(scale, (float, int)) + # np.testing.assert_almost_equal(scale, expected) + class TestZipFileConsumer(object): From bf58685a083d30397bdb9a006a8a60e1ab7d60ce Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 18:00:01 -0800 Subject: [PATCH 09/47] update image_consumer to query the model_metadata server for model size and dtype --- redis_consumer/consumers/image_consumer.py | 53 +++++++---- .../consumers/image_consumer_test.py | 89 ++++++++++++------- redis_consumer/settings.py | 12 --- 3 files changed, 95 insertions(+), 59 deletions(-) diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index ce9349e9..312430bb 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -223,48 +223,60 @@ def _consume(self, redis_hash): # Grap appropriate model model_name, model_version = utils._pick_model(label) + model_metadata = self.get_model_metadata(model_name, model_version) + + model_dtype = model_metadata['in_tensor_dtype'] + model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] + + size_x = model_shape[len(model_shape) - 3] + size_y = model_shape[len(model_shape) - 2] + pre_funcs = hvals.get('preprocess_function', '').split(',') image = self.preprocess(image, pre_funcs, True) # Send data to the model self.update_key(redis_hash, {'status': 'predicting'}) - model_shape = settings.MODEL_SIZES.get( - '{}:{}'.format(model_name, model_version), max(image.shape)) - - if (image.shape[image.ndim - 3] < model_shape or - image.shape[image.ndim - 2] < model_shape): + if (image.shape[image.ndim - 3] < size_x or + image.shape[image.ndim - 2] < size_y): # tiling not necessary, but image must be padded. pad_width = [] for i in range(image.ndim): if i in {image.ndim - 3, image.ndim - 2}: - diff = model_shape - image.shape[i] + if i == image.ndim - 3: + diff = size_x - image.shape[i] + else: + diff = size_y - image.shape[i] + if diff % 2: pad_width.append((diff // 2, diff // 2 + 1)) else: pad_width.append((diff // 2, diff // 2)) + else: pad_width.append((0, 0)) - padded_img = np.pad(image, pad_width, 'reflect') - image = self.grpc_image(padded_img, model_name, model_version) - for i, j in enumerate(image): + padded_img = np.pad(image, pad_width, 'reflect') + image = self.grpc_image(padded_img, model_name, model_version, + in_tensor_dtype=model_dtype) - self.logger.critical('output %s shape is %s', i, j.shape) + # pad batch_size and frames. + while len(pad_width) < padded_img.ndim: + pad_width.insert(0, (0, 0)) # unpad results - pad_width.insert(0, (0, 0)) # batch size if isinstance(image, list): image = [utils.unpad_image(i, pad_width) for i in image] else: image = utils.unpad_image(image, pad_width) - elif (image.shape[image.ndim - 3] > model_shape or - image.shape[image.ndim - 2] > model_shape): + elif (image.shape[image.ndim - 3] > size_x or + image.shape[image.ndim - 2] > size_y): + # need to tile! tiles, tiles_info = tile_image( np.expand_dims(image, axis=0), - model_input_shape=(model_shape, model_shape), + model_input_shape=(size_x, size_y), stride_ratio=0.75) # max_batch_size is 1 by default. @@ -272,13 +284,22 @@ def _consume(self, redis_hash): results = [] for t in range(tiles.shape[0]): output = self.grpc_image(tiles[t], model_name, model_version) - if not results: + + if not isinstance(output, list): + output = [output] + + if results == []: results = output + else: for i, o in enumerate(output): results[i] = np.vstack((results[i], o)) - image = [untile_image(r, tiles_info) for r in results] + image = [ + untile_image(r, tiles_info, model_input_shape=(size_x, size_y)) + for r in results + ] + image = image[0] if len(image) == 1 else image else: image = self.grpc_image(image, model_name, model_version) diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py index bc3bb493..e02ff134 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/image_consumer_test.py @@ -225,8 +225,7 @@ def test_process(self): settings.PROCESSING_FUNCTIONS = _funcs def test__consume(self): - settings.LABEL_DETECT_ENABLED = False - settings.SCALE_DETECT_ENABLED = False + # pylint: disable=W0613 prefix = 'predict' status = 'new' redis_client = DummyRedis(prefix, status) @@ -234,32 +233,75 @@ def test__consume(self): consumer = consumers.ImageFileConsumer(redis_client, storage, prefix) - def _handle_error(err, rhash): # pylint: disable=W0613 + def _handle_error(err, rhash): raise err - def grpc_image_multi(data, *args, **kwargs): # pylint: disable=W0613 + def grpc_image(data, *args, **kwargs): + data = np.expand_dims(data, axis=0) + return data + + def grpc_image_multi(data, *args, **kwargs): + data = np.expand_dims(data, axis=0) return np.array(tuple(list(data.shape) + [2])) + def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 + data = np.expand_dims(data, axis=0) + return [data, data] + def detect_scale(_): return 1 def detect_label(_): return 0 + def make_model_metadata_of_size(model_shape=(-1, 256, 256, 1)): + + def get_model_metadata(model_name, model_version): + return { + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } + + return get_model_metadata + dummyhash = '{}:test.tiff:{}'.format(prefix, status) - # consumer._handle_error = _handle_error - consumer.grpc_image = grpc_image_multi + model_shapes = [ + (1, 600, 600, 1), # image too small, pad + (1, 300, 300, 1), # image is exactly the right size + (1, 150, 150, 1), # image too big, tile + ] + + consumer._handle_error = _handle_error + consumer.grpc_image = grpc_image consumer.detect_scale = detect_scale - result = consumer._consume(dummyhash) - assert result == consumer.final_status - # test with a finished hash - result = consumer._consume('{}:test.tiff:{}'.format(prefix, 'done')) - assert result == 'done' + consumer.detect_label = detect_label - # test mutli-channel - def grpc_image(data, *args, **kwargs): # pylint: disable=W0613 - return data + # consumer.grpc_image = grpc_image_multi + # consumer.get_model_metadata = make_model_metadata_of_size(model_shapes[0]) + # + # result = consumer._consume(dummyhash) + # assert result == consumer.final_status + # + # # test with a finished hash + # result = consumer._consume('{}:test.tiff:{}'.format(prefix, 'done')) + # assert result == 'done' + + for b in (False, True): + settings.SCALE_DETECT_ENABLED = settings.LABEL_DETECT_ENABLED = b + for model_shape in model_shapes: + for grpc_func in (grpc_image, grpc_image_list): + + consumer.grpc_image = grpc_func + consumer.get_model_metadata = \ + make_model_metadata_of_size(model_shape) + + result = consumer._consume(dummyhash) + assert result == consumer.final_status + # test with a finished hash + result = consumer._consume('{}:test.tiff:{}'.format( + prefix, consumer.final_status)) + assert result == consumer.final_status # test with cuts > 0 redis_client.hgetall = lambda x: { @@ -278,27 +320,11 @@ def grpc_image(data, *args, **kwargs): # pylint: disable=W0613 consumer._handle_error = _handle_error consumer.detect_scale = detect_scale consumer.detect_label = detect_label + consumer.get_model_metadata = make_model_metadata_of_size((1, 300, 300, 1)) consumer.grpc_image = grpc_image result = consumer._consume(dummyhash) assert result == consumer.final_status - # test with multiple outputs from model and cuts == 0 - - def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - return [data, data] - - redis_client = DummyRedis(prefix, status) - consumer = consumers.ImageFileConsumer(redis_client, storage, prefix) - consumer._handle_error = _handle_error - consumer.detect_scale = detect_scale - consumer.detect_label = detect_label - consumer.grpc_image = grpc_image_list - result = consumer._consume(dummyhash) - assert result == consumer.final_status - - settings.LABEL_DETECT_ENABLED = True - settings.SCALE_DETECT_ENABLED = True - # test with model_name and model_version redis_client.hgetall = lambda x: { 'model_name': 'model', @@ -316,6 +342,7 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 consumer._handle_error = _handle_error consumer.detect_scale = detect_scale consumer.detect_label = detect_label + consumer.get_model_metadata = make_model_metadata_of_size((1, 300, 300, 1)) consumer.grpc_image = grpc_image result = consumer._consume(dummyhash) assert result == consumer.final_status diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index c4561523..fd30ed82 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -155,13 +155,11 @@ def _strip(x): SCALE_DETECT_SAMPLE = config('SCALE_DETECT_SAMPLE', default=3, cast=int) # Not supported for tracking. Always detects scale SCALE_DETECT_ENABLED = config('SCALE_DETECT_ENABLED', default=False, cast=bool) -SCALE_RESHAPE_SIZE = config('SCALE_RESHAPE_SIZE', default=216, cast=int) # Type detection settings LABEL_DETECT_MODEL = config('LABEL_DETECT_MODEL', default='LabelDetection:2', cast=str) LABEL_DETECT_SAMPLE = config('LABEL_DETECT_SAMPLE', default=3, cast=int) LABEL_DETECT_ENABLED = config('LABEL_DETECT_ENABLED', default=False, cast=bool) -LABEL_RESHAPE_SIZE = config('LABEL_RESHAPE_SIZE', default=216, cast=int) # Set default models based on label type PHASE_MODEL = config('PHASE_MODEL', default='panoptic_phase:0', cast=str) @@ -178,16 +176,6 @@ def _strip(x): CYTOPLASM_POSTPROCESS = config('CYTOPLASM_POSTPROCESS', default='deep_watershed', cast=str) NUCLEAR_POSTPROCESS = config('NUCLEAR_POSTPROCESS', default='deep_watershed', cast=str) -PHASE_RESHAPE_SIZE = config('PHASE_RESHAPE_SIZE', default=512, cast=int) -CYTOPLASM_RESHAPE_SIZE = config('CYTOPLASM_RESHAPE_SIZE', default=512, cast=int) -NUCLEAR_RESHAPE_SIZE = config('NUCLEAR_RESHAPE_SIZE', default=512, cast=int) - -MODEL_SIZES = { - NUCLEAR_MODEL: NUCLEAR_RESHAPE_SIZE, - PHASE_MODEL: PHASE_RESHAPE_SIZE, - CYTOPLASM_MODEL: CYTOPLASM_RESHAPE_SIZE, -} - POSTPROCESS_CHOICES = { 0: NUCLEAR_POSTPROCESS, 1: PHASE_POSTPROCESS, From 5a5d32999a233eaba60dbf788326c4313f3a0e9f Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 18:00:10 -0800 Subject: [PATCH 10/47] more rigorous testing on unpad_image --- redis_consumer/utils_test.py | 44 +++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/redis_consumer/utils_test.py b/redis_consumer/utils_test.py index 7b4e2dee..c4fa79ca 100644 --- a/redis_consumer/utils_test.py +++ b/redis_consumer/utils_test.py @@ -184,23 +184,41 @@ def test_pad_image(): def test_unpad_image(): # 2D images - h, w = 330, 330 - padded = _get_image(h, w) - pad_width = [(15, 15), (15, 15), (0, 0)] + h, w = 300, 300 - new_h = h - (pad_width[0][0] + pad_width[0][1]) - new_w = w - (pad_width[1][0] + pad_width[1][1]) + sizes = [ + (300, 300), + (101, 101) + ] - unpadded = utils.unpad_image(padded, pad_width) - np.testing.assert_equal(unpadded.shape, (new_h, new_w, 1)) + pads = [ + (10, 10), + (15, 15), + (10, 15) + ] + for pad in pads: + for h, w in sizes: + raw = _get_image(h, w) + pad_width = [pad, pad, (0, 0)] + padded = np.pad(raw, pad_width, mode='reflect') - # 3D images - frames = np.random.randint(low=1, high=6) - imgs = np.vstack([_get_image(h, w)[None, ...] for i in range(frames)]) + unpadded = utils.unpad_image(padded, pad_width) + np.testing.assert_equal(unpadded.shape, (h, w, 1)) + np.testing.assert_equal(unpadded, raw) + + # 3D images + frames = np.random.randint(low=1, high=6) + imgs = np.vstack([_get_image(h, w)[None, ...] + for _ in range(frames)]) + + pad_width = [(0, 0), pad, pad, (0, 0)] + + padded = np.pad(imgs, pad_width, mode='reflect') + + unpadded = utils.unpad_image(padded, pad_width) - pad_width = [(0, 0), (15, 15), (15, 15), (0, 0)] - unpadded = utils.unpad_image(imgs, pad_width) - np.testing.assert_equal(unpadded.shape, (frames, new_h, new_w, 1)) + np.testing.assert_equal(unpadded.shape, imgs.shape) + np.testing.assert_equal(unpadded, imgs) def test_save_numpy_array(): From d4a03ca417ffe9c9dbc98b6743da3ecac137ebb8 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 18:06:35 -0800 Subject: [PATCH 11/47] add TODOs in env vars. --- redis_consumer/settings.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index fd30ed82..922a8454 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -105,6 +105,8 @@ def _strip(x): # Pod Meteadta HOSTNAME = config('HOSTNAME', default='host-unkonwn') +CUTS = config('CUTS', default=0, cast=int) # TODO: deprecated + # Redis queue QUEUE = config('QUEUE', default='predict') SEGMENTATION_QUEUE = config('SEGMENTATION_QUEUE', default='predict') @@ -135,14 +137,13 @@ def _strip(x): TRACKING_SEGMENT_MODEL = config('TRACKING_SEGMENT_MODEL', default='panoptic:3', cast=str) TRACKING_POSTPROCESS_FUNCTION = config('TRACKING_POSTPROCESS_FUNCTION', default='retinanet', cast=str) -CUTS = config('CUTS', default=0, cast=int) TRACKING_MODEL = config('TRACKING_MODEL', default='TrackingModel:0', cast=str) DRIFT_CORRECT_ENABLED = config('DRIFT_CORRECT_ENABLED', default=False, cast=bool) NORMALIZE_TRACKING = config('NORMALIZE_TRACKING', default=True, cast=bool) -# tracking.cell_tracker settings +# tracking.cell_tracker settings TODO: can we extract from model_metadata? MAX_DISTANCE = config('MAX_DISTANCE', default=50, cast=int) TRACK_LENGTH = config('TRACK_LENGTH', default=5, cast=int) DIVISION = config('DIVISION', default=0.9, cast=float) From dc1415a9d5bccf92725612f34eb3457ddb5b463c Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 18:15:50 -0800 Subject: [PATCH 12/47] add metadata_field to GetModelMetadata request --- redis_consumer/grpc_clients.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 50f6572a..ae57863b 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -157,6 +157,7 @@ def get_model_metadata(self, request_timeout=10): retrying = True while retrying: + # pylint: disable=E1101 try: t = timeit.default_timer() channel = self.insecure_channel() @@ -165,10 +166,10 @@ def get_model_metadata(self, request_timeout=10): request = GetModelMetadataRequest() - request.model_spec.name = self.model_name # pylint: disable=E1101 + request.metadata_field = 'signature_def' + request.model_spec.name = self.model_name if self.model_version > 0: - # pylint: disable=E1101 request.model_spec.version.value = self.model_version predict_response = stub.GetModelMetadata( @@ -186,7 +187,6 @@ def get_model_metadata(self, request_timeout=10): return predict_response_dict except grpc.RpcError as err: - # pylint: disable=E1101 channel.close() if true_failures > settings.MAX_RETRY > 0: retrying = False From 6805af01e54d75c86bfe10bbaa036b7131ec5620 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 18:16:04 -0800 Subject: [PATCH 13/47] make sure all responses are truthy to be cached. --- redis_consumer/consumers/base_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index d581f880..b2497988 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -348,7 +348,7 @@ def get_model_metadata(self, model_name, model_version): fields = ['in_tensor_dtype', 'in_tensor_shape'] response = self.redis.hmget(model, *fields) - if response: + if all(response): self.logger.debug('Got cached metadata for model %s.', model) return dict(zip(fields, response)) From e7d327d9861650d74e4710ca3f0cc1e2a02d81c1 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 18:22:40 -0800 Subject: [PATCH 14/47] append to repeated field, not assign. --- redis_consumer/grpc_clients.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index ae57863b..4a4267e8 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -166,7 +166,8 @@ def get_model_metadata(self, request_timeout=10): request = GetModelMetadataRequest() - request.metadata_field = 'signature_def' + request.metadata_field.append('signature_def') + request.model_spec.name = self.model_name if self.model_version > 0: From 7add417f5add39926d8b308a6f8c4cf2231eeafd Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 19:39:49 -0800 Subject: [PATCH 15/47] fix test for all(response) check --- redis_consumer/consumers/base_consumer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index 880b634b..d2ac917a 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -393,7 +393,7 @@ def hmget_success(key, *others): def hmget_fail(key, *others): shape = ','.join(str(s) for s in model_shape) dtype = 'DT_FLOAT' - return None + return [None] * len(others) def _get_predict_client(model_name, model_version): return Bunch(get_model_metadata=lambda: { From 3cb33f263a11468c319c16358823639bf76ca393 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 19:40:19 -0800 Subject: [PATCH 16/47] fix signatureDef index --- redis_consumer/consumers/base_consumer.py | 2 +- redis_consumer/consumers/base_consumer_test.py | 2 +- redis_consumer/grpc_clients.py | 13 +++++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index b2497988..c1101d80 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -358,7 +358,7 @@ def get_model_metadata(self, model_name, model_version): model_metadata = client.get_model_metadata() try: - inputs = model_metadata['metadata']['signature_def']['signature_def'] + inputs = model_metadata['metadata']['signature_def']['signatureDef'] inputs = inputs['serving_default']['inputs'][settings.TF_TENSOR_NAME] dtype = inputs['dtype'] diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index d2ac917a..e81587c9 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -399,7 +399,7 @@ def _get_predict_client(model_name, model_version): return Bunch(get_model_metadata=lambda: { 'metadata': { 'signature_def': { - 'signature_def': { + 'signatureDef': { 'serving_default': { 'inputs': { settings.TF_TENSOR_NAME: { diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 4a4267e8..192aa0a4 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -30,6 +30,7 @@ from __future__ import division from __future__ import print_function +import json import logging import time import timeit @@ -41,6 +42,8 @@ import numpy as np +from google.protobuf.json_format import MessageToJson + from redis_consumer import settings from redis_consumer.pbs.prediction_service_pb2_grpc import PredictionServiceStub from redis_consumer.pbs.predict_pb2 import PredictRequest @@ -173,19 +176,21 @@ def get_model_metadata(self, request_timeout=10): if self.model_version > 0: request.model_spec.version.value = self.model_version - predict_response = stub.GetModelMetadata( - request, timeout=request_timeout) + response = stub.GetModelMetadata(request, timeout=request_timeout) self.logger.debug('gRPC GetModelMetadataRequest finished in %s ' 'seconds.', timeit.default_timer() - t) t = timeit.default_timer() - predict_response_dict = grpc_response_to_dict(predict_response) + + response_dict = json.loads(MessageToJson(response)) + + # signature_def = response.metadata['signature_def'] self.logger.debug('gRPC GetModelMetadataProtobufConversion took ' '%s seconds.', timeit.default_timer() - t) channel.close() - return predict_response_dict + return response_dict except grpc.RpcError as err: channel.close() From b3a75d8a20fd0de8eb7b4220549c0033806632f5 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 19:41:40 -0800 Subject: [PATCH 17/47] tensorShape not tensor_shape --- redis_consumer/consumers/base_consumer.py | 2 +- redis_consumer/consumers/base_consumer_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index c1101d80..a1fdfc0f 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -362,7 +362,7 @@ def get_model_metadata(self, model_name, model_version): inputs = inputs['serving_default']['inputs'][settings.TF_TENSOR_NAME] dtype = inputs['dtype'] - shape = ','.join([d['size'] for d in inputs['tensor_shape']['dim']]) + shape = ','.join([d['size'] for d in inputs['tensorShape']['dim']]) parsed_metadata = dict(zip(fields, [dtype, shape])) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index e81587c9..127b41b9 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -404,7 +404,7 @@ def _get_predict_client(model_name, model_version): 'inputs': { settings.TF_TENSOR_NAME: { 'dtype': model_dtype, - 'tensor_shape': { + 'tensorShape': { 'dim': [ {'size': str(x)} for x in model_shape From f2cd2a3d21327b2e429cb276fd7e0b09d79d729e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Mon, 2 Mar 2020 21:11:06 -0800 Subject: [PATCH 18/47] deprecate process_big_image for predict, which handles too big and too small --- redis_consumer/consumers/base_consumer.py | 207 +++++++++++------- .../consumers/base_consumer_test.py | 64 ++++-- redis_consumer/consumers/image_consumer.py | 80 +------ 3 files changed, 166 insertions(+), 185 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index a1fdfc0f..e423d10b 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -377,58 +377,133 @@ def get_model_metadata(self, model_name, model_version): self.logger.error('Malformed metadata: %s', model_metadata) raise err - def process_big_image(self, - cuts, - img, - field, - model_name, - model_version): - """Slice big image into smaller images for prediction, - then stitches all the smaller images back together. + def _predict_big_image(self, + image, + model_name, + model_version, + model_shape, + model_dtype='DT_FLOAT'): + """Use tile_image to tile image for the model and untile the results. Args: - cuts: number of cuts in x and y to slice smaller images - img: image data as numpy array - field: receptive field size of model, changes padding sizes - model_name: hosted model to send image data - model_version: model version to query + img (numpy.array): image data as numpy. + model_name (str): hosted model to send image data. + model_version (str): model version to query. + model_shape (tuple): shape of the model's expected input. + model_dtype (str): dtype of the model's input array. Returns: - tf_results: single numpy array of predictions on big input image + numpy.array: untiled results from the model. """ - start = timeit.default_timer() - cuts = int(cuts) - field = int(field) - winx, winy = (field - 1) // 2, (field - 1) // 2 - - def iter_cuts(img, cuts, field): - padded_img = utils.pad_image(img, field) - crop_x = img.shape[img.ndim - 3] // cuts - crop_y = img.shape[img.ndim - 2] // cuts - for i in range(cuts): - for j in range(cuts): - a, b = i * crop_x, (i + 1) * crop_x - c, d = j * crop_y, (j + 1) * crop_y - data = padded_img[..., a:b + 2 * winx, c:d + 2 * winy, :] - coord = (a, b, c, d) - yield data, coord - - slcs, coords = zip(*iter_cuts(img, cuts, field)) - reqs = (self.grpc_image(s, model_name, model_version) for s in slcs) - - tf_results = None - for resp, (a, b, c, d) in zip(reqs, coords): - # resp = await asyncio.ensure_future(req) - if tf_results is None: - tf_results = np.zeros(list(img.shape)[:-1] + [resp.shape[-1]]) - self.logger.debug('Initialized output tensor of shape %s', - tf_results.shape) - - tf_results[..., a:b, c:d, :] = resp[..., winx:-winx, winy:-winy, :] - - self.logger.debug('Segmented image into shape %s in %s s', - tf_results.shape, timeit.default_timer() - start) - return tf_results + model_ndim = len(model_shape) + input_shape = (model_shape[model_ndim - 3], model_shape[model_ndim - 2]) + tiles, tiles_info = tile_image( + np.expand_dims(image, axis=0), + model_input_shape=input_shape, + stride_ratio=0.75) + + # max_batch_size is 1 by default. + # dependent on the tf-serving configuration + results = [] + for t in range(tiles.shape[0]): + output = self.grpc_image(tiles[t], model_name, model_version, + in_tensor_dtype=model_dtype) + + if not isinstance(output, list): + output = [output] + + if results == []: + results = output + + else: + for i, o in enumerate(output): + results[i] = np.vstack((results[i], o)) + + image = [ + untile_image(r, tiles_info, model_input_shape=input_shape) + for r in results + ] + image = image[0] if len(image) == 1 else image + return image + + def _predict_small_image(self, + image, + model_name, + model_version, + model_shape, + model_dtype='DT_FLOAT'): + """Pad an image that is too small for the model, and unpad the results. + + Args: + img (numpy.array): The too-small image to be predicted with + model_name and model_version. + model_name (str): hosted model to send image data. + model_version (str): model version to query. + model_shape (tuple): shape of the model's expected input. + model_dtype (str): dtype of the model's input array. + + Returns: + numpy.array: unpadded results from the model. + """ + pad_width = [] + model_ndim = len(model_shape) + for i in range(image.ndim): + if i in {image.ndim - 3, image.ndim - 2}: + if i == image.ndim - 3: + diff = model_shape[model_ndim - 3] - image.shape[i] + else: + diff = model_shape[model_ndim - 2] - image.shape[i] + + if diff % 2: + pad_width.append((diff // 2, diff // 2 + 1)) + else: + pad_width.append((diff // 2, diff // 2)) + else: + pad_width.append((0, 0)) + + padded_img = np.pad(image, pad_width, 'reflect') + image = self.grpc_image(padded_img, model_name, model_version, + in_tensor_dtype=model_dtype) + + # pad batch_size and frames. + while len(pad_width) < padded_img.ndim: + pad_width.insert(0, (0, 0)) + + # unpad results + if isinstance(image, list): + image = [utils.unpad_image(i, pad_width) for i in image] + else: + image = utils.unpad_image(image, pad_width) + return image + + def predict(self, image, model_name, model_version): + model_metadata = self.get_model_metadata(model_name, model_version) + + model_dtype = model_metadata['in_tensor_dtype'] + + model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] + model_ndim = len(model_shape) + + size_x = model_shape[model_ndim - 3] + size_y = model_shape[model_ndim - 2] + + size_x = image.shape[image.ndim - 3] if size_x <= 0 else size_x + size_y = image.shape[image.ndim - 2] if size_y <= 0 else size_y + + if (image.shape[image.ndim - 3] < size_x or + image.shape[image.ndim - 2] < size_y): + # image is too small for the model, pad the image. + self._predict_small_image(image, model_name, model_version, + model_shape, model_dtype) + elif (image.shape[image.ndim - 3] > size_x or + image.shape[image.ndim - 2] > size_y): + # image is too big for the model, multiple images are tiled. + image = self._predict_big_image(image, model_name, model_version, + model_shape, model_dtype) + else: + image = self.grpc_image(image, model_name, model_version, + in_tensor_dtype=model_dtype) + return image def detect_scale(self, image): start = timeit.default_timer() @@ -438,13 +513,6 @@ def detect_scale(self, image): return 1 model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') - model_metadata = self.get_model_metadata(model_name, model_version) - - model_dtype = model_metadata['in_tensor_dtype'] - model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] - - size_x = model_shape[len(model_shape) - 3] - size_y = model_shape[len(model_shape) - 2] # Rescale image for compatibility with scale model # TODO Generalize to prevent from breaking on new input data types @@ -453,18 +521,7 @@ def detect_scale(self, image): else: image = np.expand_dims(image, axis=-1) - tiles, _ = tile_image( - np.expand_dims(image, axis=0), - model_input_shape=(size_x, size_y), - stride_ratio=0.75) - - # Loop over each image in the batch dimension for scale prediction - # TODO Calculate scale_detect_sample based on batch size - # Could be based on fraction or sampling a minimum set number of frames - scales = [] - for i in range(0, tiles.shape[0], settings.SCALE_DETECT_SAMPLE): - scales.append(self.grpc_image(tiles[i], model_name, model_version, - in_tensor_dtype=model_dtype)) + scales = self.predict(image, model_name, model_version) detected_scale = np.mean(scales) @@ -480,13 +537,6 @@ def detect_label(self, image): return None model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') - model_metadata = self.get_model_metadata(model_name, model_version) - - model_dtype = model_metadata['in_tensor_dtype'] - model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] - - size_x = model_shape[len(model_shape) - 3] - size_y = model_shape[len(model_shape) - 2] # Rescale for model compatibility # TODO Generalize to prevent from breaking on new input data types @@ -495,16 +545,7 @@ def detect_label(self, image): else: image = np.expand_dims(image, axis=-1) - tiles, _ = tile_image( - np.expand_dims(image, axis=0), - model_input_shape=(size_x, size_y), - stride_ratio=0.75) - - # Loop over each image in batch - labels = [] - for i in range(0, tiles.shape[0], settings.LABEL_DETECT_SAMPLE): - labels.append(self.grpc_image(tiles[i], model_name, model_version, - in_tensor_dtype=model_dtype)) + labels = self.predict(image, model_name, model_version) labels = np.array(labels) vote = labels.sum(axis=0) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index 127b41b9..cf656f80 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -361,24 +361,6 @@ def _get_predict_client(model_name, model_version): assert img.shape == out.shape[1:] assert img.sum() == out.sum() - def test_process_big_image(self): - name = 'model' - version = 0 - field = 11 - cuts = 2 - - img = np.expand_dims(_get_image(100, 100), axis=-1) - img = np.expand_dims(img, axis=0) - - redis_client = None - storage = None - consumer = consumers.TensorFlowServingConsumer(redis_client, storage, 'predict') - - # image should be chopped into cuts**2 pieces and reassembled - consumer.grpc_image = lambda x, y, z: x - res = consumer.process_big_image(cuts, img, field, name, version) - np.testing.assert_equal(res, img) - def test_get_model_metadata(self): # pytest: disable=W0613 redis_client = DummyRedis([]) @@ -443,6 +425,41 @@ def _get_bad_predict_client(model_name, model_version): consumer._get_predict_client = _get_bad_predict_client consumer.get_model_metadata('model', 1) + def test_predict(self): + redis_client = DummyRedis([]) + consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + + def grpc_image(data, *args, **kwargs): + data = np.expand_dims(data, axis=0) + return data + + def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 + data = np.expand_dims(data, axis=0) + return [data, data] + + model_shape = (1, 128, 128, 1) + + image_shapes = [ + (256, 256, 1), + (128, 128, 1), + (64, 64, 1), + (100, 100, 1), + (300, 300, 1), + ] + + for image_shape in image_shapes: + for grpc_func in (grpc_image, grpc_image_list): + + x = np.random.random(image_shape) + consumer.grpc_image = grpc_func + consumer.get_model_metadata = lambda x, y: { + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } + + y = consumer.predict(x, model_name='modelname', model_version=0) + pass + def test_detect_label(self): redis_client = DummyRedis([]) model_shape = (1, 216, 216, 1) @@ -455,13 +472,13 @@ def test_detect_label(self): settings.LABEL_DETECT_MODEL = 'dummymodel:1' - def dummydata(*_, **__): + def predict(*_, **__): data = np.zeros((3,)) i = np.random.randint(3) data[i] = 1 return data - consumer.grpc_image = dummydata + consumer.predict = predict settings.LABEL_DETECT_ENABLED = False @@ -475,6 +492,7 @@ def dummydata(*_, **__): def test_detect_scale(self): redis_client = DummyRedis([]) + model_shape = (1, 216, 216, 1) consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') consumer.get_model_metadata = lambda x, y: { @@ -488,11 +506,11 @@ def test_detect_scale(self): settings.SCALE_DETECT_MODEL = 'dummymodel:1' - def grpc_image(*_, **__): + def predict(*_, **__): sign = -1 if np.random.randint(1, 5) > 2 else 1 return expected + sign * 1e-8 # small differences get averaged out - consumer.grpc_image = grpc_image + consumer.predict = predict settings.SCALE_DETECT_ENABLED = False @@ -501,7 +519,7 @@ def grpc_image(*_, **__): settings.SCALE_DETECT_ENABLED = True - consumer.grpc_image = grpc_image + consumer.predict = predict scale = consumer.detect_scale(image) assert isinstance(scale, (float, int)) diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index 312430bb..f22f51eb 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -167,9 +167,6 @@ def _consume(self, redis_hash): 'identity_started': self.hostname, }) - cuts = hvals.get('cuts', '0') # TODO: deprecated - field = hvals.get('field_size', '61') # TODO: deprecated - # Overridden with LABEL_DETECT_ENABLED model_name = hvals.get('model_name') model_version = hvals.get('model_version') @@ -179,8 +176,6 @@ def _consume(self, redis_hash): fname = self.storage.download(hvals.get('input_file_name'), tempdir) image = utils.get_image(fname) - streaming = str(cuts).isdigit() and int(cuts) > 0 - # Pre-process data before sending to the model self.update_key(redis_hash, { 'status': 'pre-processing', @@ -223,86 +218,13 @@ def _consume(self, redis_hash): # Grap appropriate model model_name, model_version = utils._pick_model(label) - model_metadata = self.get_model_metadata(model_name, model_version) - - model_dtype = model_metadata['in_tensor_dtype'] - model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] - - size_x = model_shape[len(model_shape) - 3] - size_y = model_shape[len(model_shape) - 2] - pre_funcs = hvals.get('preprocess_function', '').split(',') image = self.preprocess(image, pre_funcs, True) # Send data to the model self.update_key(redis_hash, {'status': 'predicting'}) - if (image.shape[image.ndim - 3] < size_x or - image.shape[image.ndim - 2] < size_y): - # tiling not necessary, but image must be padded. - pad_width = [] - for i in range(image.ndim): - if i in {image.ndim - 3, image.ndim - 2}: - if i == image.ndim - 3: - diff = size_x - image.shape[i] - else: - diff = size_y - image.shape[i] - - if diff % 2: - pad_width.append((diff // 2, diff // 2 + 1)) - else: - pad_width.append((diff // 2, diff // 2)) - - else: - pad_width.append((0, 0)) - - padded_img = np.pad(image, pad_width, 'reflect') - image = self.grpc_image(padded_img, model_name, model_version, - in_tensor_dtype=model_dtype) - - # pad batch_size and frames. - while len(pad_width) < padded_img.ndim: - pad_width.insert(0, (0, 0)) - - # unpad results - if isinstance(image, list): - image = [utils.unpad_image(i, pad_width) for i in image] - else: - image = utils.unpad_image(image, pad_width) - - elif (image.shape[image.ndim - 3] > size_x or - image.shape[image.ndim - 2] > size_y): - - # need to tile! - tiles, tiles_info = tile_image( - np.expand_dims(image, axis=0), - model_input_shape=(size_x, size_y), - stride_ratio=0.75) - - # max_batch_size is 1 by default. - # dependent on the tf-serving configuration - results = [] - for t in range(tiles.shape[0]): - output = self.grpc_image(tiles[t], model_name, model_version) - - if not isinstance(output, list): - output = [output] - - if results == []: - results = output - - else: - for i, o in enumerate(output): - results[i] = np.vstack((results[i], o)) - - image = [ - untile_image(r, tiles_info, model_input_shape=(size_x, size_y)) - for r in results - ] - image = image[0] if len(image) == 1 else image - - else: - image = self.grpc_image(image, model_name, model_version) + image = self.predict(image, model_name, model_version) # Post-process model results self.update_key(redis_hash, {'status': 'post-processing'}) From 13a3d15c2f5d336a7d62a1d33e36d328b3ccc41e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 3 Mar 2020 15:22:42 -0800 Subject: [PATCH 19/47] better list comp --- redis_consumer/consumers/base_consumer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index e423d10b..88a37c49 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -419,10 +419,8 @@ def _predict_big_image(self, for i, o in enumerate(output): results[i] = np.vstack((results[i], o)) - image = [ - untile_image(r, tiles_info, model_input_shape=input_shape) - for r in results - ] + image = [untile_image(r, tiles_info, model_input_shape=input_shape) + for r in results] image = image[0] if len(image) == 1 else image return image From 2cd04e0fc29e66d89d375b5895c665be093f093a Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 08:44:33 -0800 Subject: [PATCH 20/47] pass sample to _predict_big_image, only untile if sample is default. --- redis_consumer/consumers/base_consumer.py | 34 ++++++++++++++++------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 88a37c49..ee5838d6 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -382,30 +382,40 @@ def _predict_big_image(self, model_name, model_version, model_shape, - model_dtype='DT_FLOAT'): + model_dtype='DT_FLOAT', + sample=None): """Use tile_image to tile image for the model and untile the results. Args: - img (numpy.array): image data as numpy. + image (numpy.array): image data as numpy. model_name (str): hosted model to send image data. model_version (str): model version to query. model_shape (tuple): shape of the model's expected input. model_dtype (str): dtype of the model's input array. + sample (int): Only predict every sample'th tile. + If sample is not None, no untiling will be performed, + as the untiling data will be incomplete. Returns: numpy.array: untiled results from the model. """ + is_untile_required = sample is None + sample = 1 if sample is None else sample model_ndim = len(model_shape) input_shape = (model_shape[model_ndim - 3], model_shape[model_ndim - 2]) + tiles, tiles_info = tile_image( np.expand_dims(image, axis=0), model_input_shape=input_shape, stride_ratio=0.75) + self.logger.debug('Tiling image of shape %s into shape %s.', + image.shape, tiles.shape) + # max_batch_size is 1 by default. # dependent on the tf-serving configuration results = [] - for t in range(tiles.shape[0]): + for t in range(0, tiles.shape[0], sample): output = self.grpc_image(tiles[t], model_name, model_version, in_tensor_dtype=model_dtype) @@ -414,13 +424,16 @@ def _predict_big_image(self, if results == []: results = output - else: for i, o in enumerate(output): results[i] = np.vstack((results[i], o)) - image = [untile_image(r, tiles_info, model_input_shape=input_shape) - for r in results] + if not is_untile_required: + image = results + else: + image = [untile_image(r, tiles_info, model_input_shape=input_shape) + for r in results] + image = image[0] if len(image) == 1 else image return image @@ -474,7 +487,7 @@ def _predict_small_image(self, image = utils.unpad_image(image, pad_width) return image - def predict(self, image, model_name, model_version): + def predict(self, image, model_name, model_version, sample=None): model_metadata = self.get_model_metadata(model_name, model_version) model_dtype = model_metadata['in_tensor_dtype'] @@ -491,14 +504,15 @@ def predict(self, image, model_name, model_version): if (image.shape[image.ndim - 3] < size_x or image.shape[image.ndim - 2] < size_y): # image is too small for the model, pad the image. - self._predict_small_image(image, model_name, model_version, - model_shape, model_dtype) + image = self._predict_small_image(image, model_name, model_version, + model_shape, model_dtype) elif (image.shape[image.ndim - 3] > size_x or image.shape[image.ndim - 2] > size_y): # image is too big for the model, multiple images are tiled. image = self._predict_big_image(image, model_name, model_version, - model_shape, model_dtype) + model_shape, model_dtype, sample) else: + # image size is perfect, just send it to the model image = self.grpc_image(image, model_name, model_version, in_tensor_dtype=model_dtype) return image From aeb75ba7639a85150ebfd4dfb9c63e7bf4c54a13 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 08:45:30 -0800 Subject: [PATCH 21/47] log prediction statements about padding, tiling, etc. --- redis_consumer/consumers/base_consumer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index ee5838d6..2a979de5 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -472,6 +472,10 @@ def _predict_small_image(self, else: pad_width.append((0, 0)) + self.logger.info('Padding image from shape %s to shape %s.', + image.shape, tuple([x + y1 + y2 for x, (y1, y2) in + zip(image.shape, pad_width)])) + padded_img = np.pad(image, pad_width, 'reflect') image = self.grpc_image(padded_img, model_name, model_version, in_tensor_dtype=model_dtype) @@ -501,6 +505,11 @@ def predict(self, image, model_name, model_version, sample=None): size_x = image.shape[image.ndim - 3] if size_x <= 0 else size_x size_y = image.shape[image.ndim - 2] if size_y <= 0 else size_y + self.logger.debug('Calling predict on model %s:%s with input shape %s' + ' and dtype %s to segment an image of shape %s.', + model_name, model_version, tuple(model_shape), + model_dtype, image.shape) + if (image.shape[image.ndim - 3] < size_x or image.shape[image.ndim - 2] < size_y): # image is too small for the model, pad the image. From 80447223de66b7029161d1c8bb37d4a4566c7243 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 08:45:56 -0800 Subject: [PATCH 22/47] use sample in detect scale and label, no more image size TODOs --- redis_consumer/consumers/base_consumer.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 2a979de5..f7c01920 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -535,14 +535,8 @@ def detect_scale(self, image): model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') - # Rescale image for compatibility with scale model - # TODO Generalize to prevent from breaking on new input data types - if image.shape[-1] == 1: - image = np.expand_dims(image, axis=0) - else: - image = np.expand_dims(image, axis=-1) - - scales = self.predict(image, model_name, model_version) + scales = self.predict(image, model_name, model_version, + sample=settings.SCALE_DETECT_SAMPLE) detected_scale = np.mean(scales) @@ -559,14 +553,8 @@ def detect_label(self, image): model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') - # Rescale for model compatibility - # TODO Generalize to prevent from breaking on new input data types - if image.shape[-1] == 1: - image = np.expand_dims(image, axis=0) - else: - image = np.expand_dims(image, axis=-1) - - labels = self.predict(image, model_name, model_version) + labels = self.predict(image, model_name, model_version, + sample=settings.SCALE_DETECT_SAMPLE) labels = np.array(labels) vote = labels.sum(axis=0) From 32cd91845be3cf02fc60e30e858d0acb31f7f15f Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:43:38 -0800 Subject: [PATCH 23/47] fix unpadding for list of outputs. --- redis_consumer/consumers/base_consumer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index f7c01920..4768fc3f 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -480,15 +480,18 @@ def _predict_small_image(self, image = self.grpc_image(padded_img, model_name, model_version, in_tensor_dtype=model_dtype) - # pad batch_size and frames. - while len(pad_width) < padded_img.ndim: - pad_width.insert(0, (0, 0)) + image = [image] if not isinstance(image, list) else image + + # pad batch_size and frames for each output. + pad_widths = [pad_width] * len(image) + for i, im in enumerate(image): + while len(pad_widths[i]) < im.ndim: + pad_widths[i].insert(0, (0, 0)) # unpad results - if isinstance(image, list): - image = [utils.unpad_image(i, pad_width) for i in image] - else: - image = utils.unpad_image(image, pad_width) + image = [utils.unpad_image(i, p) for i, p in zip(image, pad_widths)] + image = image[0] if len(image) == 1 else image + return image def predict(self, image, model_name, model_version, sample=None): From 3ad1f06cac59a58cf0fbf0c529a754d31430bc8c Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:45:45 -0800 Subject: [PATCH 24/47] image ndim should be 1 less than model ndim. (missing batch dim in image) --- redis_consumer/consumers/base_consumer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 4768fc3f..534b17bb 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -502,6 +502,12 @@ def predict(self, image, model_name, model_version, sample=None): model_shape = [int(x) for x in model_metadata['in_tensor_shape'].split(',')] model_ndim = len(model_shape) + if model_ndim != image.ndim + 1: + raise ValueError('Image of shape {} is incompatible with model ' + '{}:{} with input shape {}'.format( + image.shape, model_name, model_version, + tuple(model_shape))) + size_x = model_shape[model_ndim - 3] size_y = model_shape[model_ndim - 2] From 91dcf46466d5948a5e6f4c878c7f10ed6aaac209 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:47:18 -0800 Subject: [PATCH 25/47] migrate detect_scale and detect_label to ImageFileConsumer --- redis_consumer/consumers/base_consumer.py | 46 +++---------- .../consumers/base_consumer_test.py | 69 ------------------- redis_consumer/consumers/image_consumer.py | 44 ++++++++++++ .../consumers/image_consumer_test.py | 69 +++++++++++++++++++ redis_consumer/consumers/tracking_consumer.py | 49 +++++++------ 5 files changed, 145 insertions(+), 132 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 534b17bb..445dfdec 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -533,47 +533,17 @@ def predict(self, image, model_name, model_version, sample=None): # image size is perfect, just send it to the model image = self.grpc_image(image, model_name, model_version, in_tensor_dtype=model_dtype) - return image - - def detect_scale(self, image): - start = timeit.default_timer() - - if not settings.SCALE_DETECT_ENABLED: - self.logger.debug('Scale detection disabled. Scale set to 1.') - return 1 - - model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') - - scales = self.predict(image, model_name, model_version, - sample=settings.SCALE_DETECT_SAMPLE) - - detected_scale = np.mean(scales) - - self.logger.debug('Scale %s detected in %s seconds', - detected_scale, timeit.default_timer() - start) - return detected_scale - def detect_label(self, image): - start = timeit.default_timer() - - if not settings.LABEL_DETECT_ENABLED: - self.logger.debug('Label detection disabled. Label set to None.') - return None - - model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') - - labels = self.predict(image, model_name, model_version, - sample=settings.SCALE_DETECT_SAMPLE) - - labels = np.array(labels) - vote = labels.sum(axis=0) - maj = vote.max() + if isinstance(image, list): + output_shapes = [i.shape for i in image] + else: + output_shapes = [image.shape] # cast as list - detected = np.where(vote == maj)[-1][0] + self.logger.debug('Got response from model %s:%s of shape %s in %s ' + 'seconds.', model_name, model_version, output_shapes, + timeit.default_timer() - start) - self.logger.debug('Label %s detected in %s seconds.', - detected, timeit.default_timer() - start) - return detected + return image class ZipFileConsumer(Consumer): diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index cf656f80..a5ea43a3 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -460,75 +460,6 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 y = consumer.predict(x, model_name='modelname', model_version=0) pass - def test_detect_label(self): - redis_client = DummyRedis([]) - model_shape = (1, 216, 216, 1) - consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') - consumer.get_model_metadata = lambda x, y: { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } - image = _get_image(model_shape[1] * 2, model_shape[2] * 2) - - settings.LABEL_DETECT_MODEL = 'dummymodel:1' - - def predict(*_, **__): - data = np.zeros((3,)) - i = np.random.randint(3) - data[i] = 1 - return data - - consumer.predict = predict - - settings.LABEL_DETECT_ENABLED = False - - label = consumer.detect_label(image) - assert label is None - - settings.LABEL_DETECT_ENABLED = True - - label = consumer.detect_label(image) - assert label in set(list(range(4))) - - def test_detect_scale(self): - redis_client = DummyRedis([]) - - model_shape = (1, 216, 216, 1) - consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') - consumer.get_model_metadata = lambda x, y: { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } - big_size = model_shape[1] * np.random.randint(2, 9) - image = _get_image(big_size, big_size) - - expected = (model_shape[1] / (big_size)) ** 2 - - settings.SCALE_DETECT_MODEL = 'dummymodel:1' - - def predict(*_, **__): - sign = -1 if np.random.randint(1, 5) > 2 else 1 - return expected + sign * 1e-8 # small differences get averaged out - - consumer.predict = predict - - settings.SCALE_DETECT_ENABLED = False - - scale = consumer.detect_scale(image) - assert scale == 1 - - settings.SCALE_DETECT_ENABLED = True - - consumer.predict = predict - - scale = consumer.detect_scale(image) - assert isinstance(scale, (float, int)) - np.testing.assert_almost_equal(scale, expected) - - # scale = consumer.detect_scale(np.expand_dims(image, axis=-1)) - # assert isinstance(scale, (float, int)) - # np.testing.assert_almost_equal(scale, expected) - class TestZipFileConsumer(object): diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index f22f51eb..caab2958 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -111,6 +111,50 @@ def process(self, image, key, process_type): return results + def detect_scale(self, image): + start = timeit.default_timer() + + if not settings.SCALE_DETECT_ENABLED: + self.logger.debug('Scale detection disabled. Scale set to 1.') + return 1 + + model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') + + scales = self.predict(image, model_name, model_version, + sample=settings.SCALE_DETECT_SAMPLE) + + detected_scale = np.mean(scales) + + error_rate = .01 # error rate is ~1% for current model. + if abs(detected_scale - 1) < error_rate: + detected_scale = 1 + + self.logger.debug('Scale %s detected in %s seconds', + detected_scale, timeit.default_timer() - start) + return detected_scale + + def detect_label(self, image): + start = timeit.default_timer() + + if not settings.LABEL_DETECT_ENABLED: + self.logger.debug('Label detection disabled. Label set to None.') + return None + + model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') + + labels = self.predict(image, model_name, model_version, + sample=settings.SCALE_DETECT_SAMPLE) + + labels = np.array(labels) + vote = labels.sum(axis=0) + maj = vote.max() + + detected = np.where(vote == maj)[-1][0] + + self.logger.debug('Label %s detected in %s seconds.', + detected, timeit.default_timer() - start) + return detected + def preprocess(self, image, keys, streaming=False): """Wrapper for _process_image but can only call with type="pre". diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py index e02ff134..6c93caae 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/image_consumer_test.py @@ -224,6 +224,75 @@ def test_process(self): settings.PROCESSING_FUNCTIONS = _funcs + def test_detect_label(self): + redis_client = DummyRedis([]) + model_shape = (1, 216, 216, 1) + consumer = consumers.ImageFileConsumer(redis_client, None, 'q') + consumer.get_model_metadata = lambda x, y: { + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } + image = _get_image(model_shape[1] * 2, model_shape[2] * 2) + + settings.LABEL_DETECT_MODEL = 'dummymodel:1' + + def predict(*_, **__): + data = np.zeros((3,)) + i = np.random.randint(3) + data[i] = 1 + return data + + consumer.predict = predict + + settings.LABEL_DETECT_ENABLED = False + + label = consumer.detect_label(image) + assert label is None + + settings.LABEL_DETECT_ENABLED = True + + label = consumer.detect_label(image) + assert label in set(list(range(4))) + + def test_detect_scale(self): + redis_client = DummyRedis([]) + + model_shape = (1, 216, 216, 1) + consumer = consumers.ImageFileConsumer(redis_client, None, 'q') + consumer.get_model_metadata = lambda x, y: { + 'in_tensor_dtype': 'DT_FLOAT', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } + big_size = model_shape[1] * np.random.randint(2, 9) + image = _get_image(big_size, big_size) + + expected = (model_shape[1] / (big_size)) ** 2 + + settings.SCALE_DETECT_MODEL = 'dummymodel:1' + + def predict(*_, **__): + sign = -1 if np.random.randint(1, 5) > 2 else 1 + return expected + sign * 1e-8 # small differences get averaged out + + consumer.predict = predict + + settings.SCALE_DETECT_ENABLED = False + + scale = consumer.detect_scale(image) + assert scale == 1 + + settings.SCALE_DETECT_ENABLED = True + + consumer.predict = predict + + scale = consumer.detect_scale(image) + assert isinstance(scale, (float, int)) + np.testing.assert_almost_equal(scale, expected) + + # scale = consumer.detect_scale(np.expand_dims(image, axis=-1)) + # assert isinstance(scale, (float, int)) + # np.testing.assert_almost_equal(scale, expected) + def test__consume(self): # pylint: disable=W0613 prefix = 'predict' diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index 04307f67..81e95bf5 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -145,27 +145,27 @@ def _load_data(self, redis_hash, subdir, fname): tiff_stack.shape)) # Calculate scale of a subset of raw - scale = hvalues.get('scale', '') - if not scale: - # Detect scale of image - scale = self.detect_scale(tiff_stack) - self.logger.debug('Image scale detected: %s', scale) - self.update_key(redis_hash, {'scale': scale}) - else: - scale = float(scale) - self.logger.debug('Image scale already calculated: %s', scale) + # scale = hvalues.get('scale', '') + # if not scale: + # # Detect scale of image + # scale = self.detect_scale(tiff_stack) + # self.logger.debug('Image scale detected: %s', scale) + # self.update_key(redis_hash, {'scale': scale}) + # else: + # scale = float(scale) + # self.logger.debug('Image scale already calculated: %s', scale) # Pick model and postprocess based on either label or defaults - if settings.LABEL_DETECT_ENABLED: - label = self.detect_label(tiff_stack) # Predict label type - - # Get appropriate model and postprocess function for the label - model_name, model_version = utils._pick_model(label) - postprocess_function = utils._pick_postprocess(label) - else: - label = 99 # Equivalent to none - model_name, model_version = settings.TRACKING_SEGMENT_MODEL.split(':') - postprocess_function = settings.TRACKING_POSTPROCESS_FUNCTION + # if settings.LABEL_DETECT_ENABLED: + # label = self.detect_label(tiff_stack) # Predict label type + # + # # Get appropriate model and postprocess function for the label + # model_name, model_version = utils._pick_model(label) + # postprocess_function = utils._pick_postprocess(label) + # else: + # label = 99 # Equivalent to none + # model_name, model_version = settings.TRACKING_SEGMENT_MODEL.split(':') + # postprocess_function = settings.TRACKING_POSTPROCESS_FUNCTION num_frames = len(tiff_stack) hash_to_frame = {} @@ -192,16 +192,15 @@ def _load_data(self, redis_hash, subdir, fname): 'identity_upload': self.hostname, 'input_file_name': upload_file_name, 'original_name': segment_fname, - 'model_name': model_name, - 'model_version': model_version, - 'postprocess_function': postprocess_function, - 'cuts': settings.CUTS, + # 'model_name': model_name, + # 'model_version': model_version, + # 'postprocess_function': postprocess_function, 'status': 'new', 'created_at': current_timestamp, 'updated_at': current_timestamp, 'url': upload_file_url, - 'scale': scale, - 'label': str(label) + # 'scale': scale, + # 'label': str(label) } self.logger.debug("Setting %s", frame_hvalues) From e830592a3c3dbdb173e22b80a1fdb04c37f7434e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:47:31 -0800 Subject: [PATCH 26/47] missing timer --- redis_consumer/consumers/base_consumer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 445dfdec..05c4fd60 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -495,6 +495,7 @@ def _predict_small_image(self, return image def predict(self, image, model_name, model_version, sample=None): + start = timeit.default_timer() model_metadata = self.get_model_metadata(model_name, model_version) model_dtype = model_metadata['in_tensor_dtype'] From 4eaeca97fc31a883c1c0ed73fbab15d65612b764 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:55:35 -0800 Subject: [PATCH 27/47] lint tests for base_consumer --- .../consumers/base_consumer_test.py | 37 +++++++------------ 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index a5ea43a3..ac8f54da 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -54,6 +54,7 @@ def __init__(self, **kwds): class DummyRedis(object): + # pylint: disable=W0613,R0201 def __init__(self, items=[], prefix='predict', status='new'): self.work_queue = copy.copy(items) self.processing_queue = [] @@ -95,16 +96,15 @@ def lrem(self, name, count, value): def llen(self, queue): if queue.startswith('processing'): return len(self.processing_queue) - else: - return len(self.work_queue) + return len(self.work_queue) def hmget(self, rhash, *args): return [self.hget(rhash, a) for a in args] - def hmset(self, rhash, hvals): # pylint: disable=W0613 + def hmset(self, rhash, hvals): return hvals - def expire(self, name, time): # pylint: disable=W0613 + def expire(self, name, time): return 1 def hget(self, rhash, field): @@ -120,15 +120,13 @@ def hget(self, rhash, field): return 'reason' return False - def hset(self, rhash, status, value): # pylint: disable=W0613 + def hset(self, rhash, status, value): return {status: value} - def hgetall(self, rhash): # pylint: disable=W0613 + def hgetall(self, rhash): return { 'model_name': 'model', 'model_version': '0', - 'field': '61', - 'cuts': '0', 'postprocess_function': '', 'preprocess_function': '', 'file_name': rhash.split(':')[1], @@ -142,6 +140,7 @@ def hgetall(self, rhash): # pylint: disable=W0613 class DummyStorage(object): + # pylint: disable=W0613,R0201 def __init__(self, num=3): self.num = num @@ -159,15 +158,15 @@ def download(self, path, dest): tiff.imsave(os.path.join(dest, path), img) return path - def upload(self, zip_path, subdir=None): # pylint: disable=W0613 + def upload(self, zip_path, subdir=None): return True, True - def get_public_url(self, zip_path): # pylint: disable=W0613 + def get_public_url(self, zip_path): return True class TestConsumer(object): - + # pylint: disable=R0201 def test_get_redis_hash(self): settings.EMPTY_QUEUE_TIMEOUT = 0.01 # don't sleep too long @@ -335,7 +334,7 @@ def test__consume(self): class TestTensorFlowServingConsumer(object): - + # pylint: disable=R0201,W0613 def test__get_predict_client(self): redis_client = DummyRedis([]) consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') @@ -343,7 +342,7 @@ def test__get_predict_client(self): with pytest.raises(ValueError): consumer._get_predict_client('model_name', 'model_version') - client = consumer._get_predict_client('model_name', 1) + consumer._get_predict_client('model_name', 1) def test_grpc_image(self): redis_client = DummyRedis([]) @@ -362,7 +361,6 @@ def _get_predict_client(model_name, model_version): assert img.sum() == out.sum() def test_get_model_metadata(self): - # pytest: disable=W0613 redis_client = DummyRedis([]) model_shape = (-1, 216, 216, 1) model_dtype = 'DT_FLOAT' @@ -373,8 +371,6 @@ def hmget_success(key, *others): return dtype, shape def hmget_fail(key, *others): - shape = ','.join(str(s) for s in model_shape) - dtype = 'DT_FLOAT' return [None] * len(others) def _get_predict_client(model_name, model_version): @@ -457,12 +453,11 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 'in_tensor_shape': ','.join(str(s) for s in model_shape), } - y = consumer.predict(x, model_name='modelname', model_version=0) - pass + consumer.predict(x, model_name='modelname', model_version=0) class TestZipFileConsumer(object): - + # pylint: disable=R0201,W0613 def test_is_valid_hash(self): items = ['item%s' % x for x in range(1, 4)] @@ -544,7 +539,6 @@ def test__parse_failures(self): def test__cleanup(self): N = 3 queue = 'predict' - status = 'waiting' items = ['item%s' % x for x in range(1, N + 1)] redis_client = DummyRedis(items) storage = DummyStorage(num=N) @@ -554,9 +548,6 @@ def test__cleanup(self): done = ['{}:done'.format(c) for c in children[:3]] failed = ['{}:failed'.format(c) for c in children[3:]] - key = '{queue}:{fname}.zip:{status}'.format( - queue=queue, status=status, fname=status) - consumer._cleanup(items[0], children, done, failed) # test non-float values From 1cb9805d41a3327331e70016ec34cce9605e187d Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:55:59 -0800 Subject: [PATCH 28/47] remove cuts, field, and streaming form image_consumer --- redis_consumer/consumers/image_consumer.py | 10 +++---- .../consumers/image_consumer_test.py | 26 ++----------------- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index caab2958..c416ef87 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -155,7 +155,7 @@ def detect_label(self, image): detected, timeit.default_timer() - start) return detected - def preprocess(self, image, keys, streaming=False): + def preprocess(self, image, keys): """Wrapper for _process_image but can only call with type="pre". Args: @@ -169,11 +169,10 @@ def preprocess(self, image, keys, streaming=False): pre = None for key in keys: x = pre if pre else image - # pre = self._process(x, key, 'pre', streaming) pre = self.process(x, key, 'pre') return pre - def postprocess(self, image, keys, streaming=False): + def postprocess(self, image, keys): """Wrapper for _process_image but can only call with type="post". Args: @@ -187,7 +186,6 @@ def postprocess(self, image, keys, streaming=False): post = None for key in keys: x = post if post else image - # post = self._process(x, key, 'post', streaming) post = self.process(x, key, 'post') return post @@ -263,7 +261,7 @@ def _consume(self, redis_hash): model_name, model_version = utils._pick_model(label) pre_funcs = hvals.get('preprocess_function', '').split(',') - image = self.preprocess(image, pre_funcs, True) + image = self.preprocess(image, pre_funcs) # Send data to the model self.update_key(redis_hash, {'status': 'predicting'}) @@ -278,7 +276,7 @@ def _consume(self, redis_hash): else: post_funcs = hvals.get('postprocess_function', '').split(',') - image = self.postprocess(image, post_funcs, True) + image = self.postprocess(image, post_funcs) # Save the post-processed results to a file _ = timeit.default_timer() diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py index 6c93caae..b78a3301 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/image_consumer_test.py @@ -128,8 +128,6 @@ def hgetall(self, rhash): # pylint: disable=W0613 return { 'model_name': 'model', 'model_version': '0', - 'field': '61', - 'cuts': '0', 'postprocess_function': '', 'preprocess_function': '', 'file_name': rhash.split(':')[1], @@ -225,6 +223,7 @@ def test_process(self): settings.PROCESSING_FUNCTIONS = _funcs def test_detect_label(self): + # pylint: disable=W0613 redis_client = DummyRedis([]) model_shape = (1, 216, 216, 1) consumer = consumers.ImageFileConsumer(redis_client, None, 'q') @@ -255,6 +254,7 @@ def predict(*_, **__): assert label in set(list(range(4))) def test_detect_scale(self): + # pylint: disable=W0613 redis_client = DummyRedis([]) model_shape = (1, 216, 216, 1) @@ -372,28 +372,6 @@ def get_model_metadata(model_name, model_version): prefix, consumer.final_status)) assert result == consumer.final_status - # test with cuts > 0 - redis_client.hgetall = lambda x: { - 'model_name': 'model', - 'model_version': '0', - 'field': '61', - 'cuts': '2', - 'postprocess_function': '', - 'preprocess_function': '', - 'file_name': 'test_image.tiff', - 'input_file_name': 'test_image.tiff', - 'output_file_name': 'test_image.tiff' - } - redis_client.hmset = lambda x, y: True - consumer = consumers.ImageFileConsumer(redis_client, storage, prefix) - consumer._handle_error = _handle_error - consumer.detect_scale = detect_scale - consumer.detect_label = detect_label - consumer.get_model_metadata = make_model_metadata_of_size((1, 300, 300, 1)) - consumer.grpc_image = grpc_image - result = consumer._consume(dummyhash) - assert result == consumer.final_status - # test with model_name and model_version redis_client.hgetall = lambda x: { 'model_name': 'model', From fc9942ff74bddf52c35f5e591737adeed0cb1acc Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 10:56:05 -0800 Subject: [PATCH 29/47] remove unused import --- redis_consumer/consumers/image_consumer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index c416ef87..91d07ae7 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -33,8 +33,6 @@ import numpy as np -from deepcell_toolbox.utils import tile_image, untile_image - from redis_consumer.consumers import TensorFlowServingConsumer from redis_consumer import utils from redis_consumer import settings From 932d4415e031d1b85204f3c642063f679d6c1691 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 11:38:23 -0800 Subject: [PATCH 30/47] update default models --- redis_consumer/settings.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 922a8454..acf8d551 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -152,33 +152,25 @@ def _strip(x): NEIGHBORHOOD_SCALE_SIZE = config('NEIGHBORHOOD_SCALE_SIZE', default=30, cast=int) # Scale detection settings -SCALE_DETECT_MODEL = config('SCALE_DETECT_MODEL', default='ScaleDetection:3') +SCALE_DETECT_MODEL = config('SCALE_DETECT_MODEL', default='ScaleDetection:1') SCALE_DETECT_SAMPLE = config('SCALE_DETECT_SAMPLE', default=3, cast=int) # Not supported for tracking. Always detects scale SCALE_DETECT_ENABLED = config('SCALE_DETECT_ENABLED', default=False, cast=bool) # Type detection settings -LABEL_DETECT_MODEL = config('LABEL_DETECT_MODEL', default='LabelDetection:2', cast=str) +LABEL_DETECT_MODEL = config('LABEL_DETECT_MODEL', default='LabelDetection:1', cast=str) LABEL_DETECT_SAMPLE = config('LABEL_DETECT_SAMPLE', default=3, cast=int) LABEL_DETECT_ENABLED = config('LABEL_DETECT_ENABLED', default=False, cast=bool) # Set default models based on label type -PHASE_MODEL = config('PHASE_MODEL', default='panoptic_phase:0', cast=str) -CYTOPLASM_MODEL = config('CYTOPLASM_MODEL', default='panoptic_cytoplasm:0', cast=str) -NUCLEAR_MODEL = config('NUCLEAR_MODEL', default='panoptic:3', cast=str) - MODEL_CHOICES = { - 0: NUCLEAR_MODEL, - 1: PHASE_MODEL, - 2: CYTOPLASM_MODEL + 0: config('NUCLEAR_MODEL', default='NuclearSegmentation:0', cast=str), + 1: config('PHASE_MODEL', default='PhaseCytoSegmentation:0', cast=str), + 2: config('CYTOPLASM_MODEL', default='FluoCytoSegmentation:0', cast=str) } -PHASE_POSTPROCESS = config('PHASE_POSTPROCESS', default='deep_watershed', cast=str) -CYTOPLASM_POSTPROCESS = config('CYTOPLASM_POSTPROCESS', default='deep_watershed', cast=str) -NUCLEAR_POSTPROCESS = config('NUCLEAR_POSTPROCESS', default='deep_watershed', cast=str) - POSTPROCESS_CHOICES = { - 0: NUCLEAR_POSTPROCESS, - 1: PHASE_POSTPROCESS, - 2: CYTOPLASM_POSTPROCESS + 0: config('NUCLEAR_POSTPROCESS', default='deep_watershed', cast=str), + 1: config('PHASE_POSTPROCESS', default='deep_watershed', cast=str), + 2: config('CYTOPLASM_POSTPROCESS', default='deep_watershed', cast=str) } From 3bb214349985f13240fe3d70d183a5b2c9fea706 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 11:39:05 -0800 Subject: [PATCH 31/47] remove TF_TENSOR_DTYPE from env var table --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 57e23fa1..6c54c73c 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,6 @@ The consumer is configured using environment variables. Please find a table of a | `TF_HOST` | The IP address or hostname of TensorFlow Serving. | `"tf-serving"` | | `TF_PORT` | The port used to connect to TensorFlow Serving. | `8500` | | `TF_TENSOR_NAME` | Name of input tensor for the exported model. | `"image"` | -| `TF_TENSOR_DTYPE` | The `dtype` used for the exported model. | `"DT_FLOAT"` | | `GRPC_TIMEOUT` | Timeout for gRPC API requests, in seconds. | `30` | | `GRPC_BACKOFF` | Time to wait before retrying a gRPC API request. | `3` | | `MAX_RETRY` | Maximum number of retries for a failed TensorFlow Serving request. | `5` | From 323a64a423ab122bb8f65e5f91ddca366a7c0844 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 12:08:39 -0800 Subject: [PATCH 32/47] if LABEL_DETECT_ENABLED, use empty model and postprocess, otherwise use env vars. --- redis_consumer/consumers/tracking_consumer.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index 81e95bf5..6e030233 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -156,16 +156,14 @@ def _load_data(self, redis_hash, subdir, fname): # self.logger.debug('Image scale already calculated: %s', scale) # Pick model and postprocess based on either label or defaults - # if settings.LABEL_DETECT_ENABLED: - # label = self.detect_label(tiff_stack) # Predict label type - # - # # Get appropriate model and postprocess function for the label - # model_name, model_version = utils._pick_model(label) - # postprocess_function = utils._pick_postprocess(label) - # else: - # label = 99 # Equivalent to none - # model_name, model_version = settings.TRACKING_SEGMENT_MODEL.split(':') - # postprocess_function = settings.TRACKING_POSTPROCESS_FUNCTION + if settings.LABEL_DETECT_ENABLED: + # model and postprocessing will be determined automatically + # by the ImageFileConsumer + model_name, model_version = '', '' + postprocess_function = '' + else: + model_name, model_version = settings.TRACKING_SEGMENT_MODEL.split(':') + postprocess_function = settings.TRACKING_POSTPROCESS_FUNCTION num_frames = len(tiff_stack) hash_to_frame = {} @@ -192,9 +190,9 @@ def _load_data(self, redis_hash, subdir, fname): 'identity_upload': self.hostname, 'input_file_name': upload_file_name, 'original_name': segment_fname, - # 'model_name': model_name, - # 'model_version': model_version, - # 'postprocess_function': postprocess_function, + 'model_name': model_name, + 'model_version': model_version, + 'postprocess_function': postprocess_function, 'status': 'new', 'created_at': current_timestamp, 'updated_at': current_timestamp, From 9c9070437c711219b71f23c7f23e8f633ee82e87 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 12:21:12 -0800 Subject: [PATCH 33/47] handle scale detection --- redis_consumer/consumers/tracking_consumer.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index 6e030233..d7938c5b 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -145,15 +145,8 @@ def _load_data(self, redis_hash, subdir, fname): tiff_stack.shape)) # Calculate scale of a subset of raw - # scale = hvalues.get('scale', '') - # if not scale: - # # Detect scale of image - # scale = self.detect_scale(tiff_stack) - # self.logger.debug('Image scale detected: %s', scale) - # self.update_key(redis_hash, {'scale': scale}) - # else: - # scale = float(scale) - # self.logger.debug('Image scale already calculated: %s', scale) + scale = hvalues.get('scale', '') + scale = scale if settings.SCALE_DETECT_ENABLED else 1 # Pick model and postprocess based on either label or defaults if settings.LABEL_DETECT_ENABLED: @@ -197,7 +190,7 @@ def _load_data(self, redis_hash, subdir, fname): 'created_at': current_timestamp, 'updated_at': current_timestamp, 'url': upload_file_url, - # 'scale': scale, + 'scale': scale, # 'label': str(label) } From 2f6353e4ac3a3c7d38ab621c75abd1fddae453a9 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 4 Mar 2020 12:37:58 -0800 Subject: [PATCH 34/47] linted --- redis_consumer/consumers/base_consumer.py | 2 +- redis_consumer/consumers/tracking_consumer.py | 38 ++++++++++--------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 05c4fd60..4044afb3 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -199,7 +199,7 @@ def consume(self): 'postprocess_function', ] result = self.redis.hmget(redis_hash, *required_fields) - hvals = {f: v for f, v in zip(required_fields, result)} + hvals = dict(zip(required_fields, result)) self.logger.debug('Consumed key %s (model %s:%s, ' 'preprocessing: %s, postprocessing: %s) ' '(%s retries) in %s seconds.', diff --git a/redis_consumer/consumers/tracking_consumer.py b/redis_consumer/consumers/tracking_consumer.py index d7938c5b..553ca495 100644 --- a/redis_consumer/consumers/tracking_consumer.py +++ b/redis_consumer/consumers/tracking_consumer.py @@ -95,15 +95,16 @@ def _get_tracker(self, redis_hash, hvalues, raw, segmented): raw[frame, :, :, 0] = processing.normalize(raw[frame, :, :, 0]) features = {'appearance', 'distance', 'neighborhood', 'regionprop'} - tracker = tracking.CellTracker(raw, segmented, - tracking_model, - max_distance=settings.MAX_DISTANCE, - track_length=settings.TRACK_LENGTH, - division=settings.DIVISION, - birth=settings.BIRTH, - death=settings.DEATH, - neighborhood_scale_size=settings.NEIGHBORHOOD_SCALE_SIZE, - features=features) + tracker = tracking.CellTracker( + raw, segmented, + tracking_model, + max_distance=settings.MAX_DISTANCE, + track_length=settings.TRACK_LENGTH, + division=settings.DIVISION, + birth=settings.BIRTH, + death=settings.DEATH, + neighborhood_scale_size=settings.NEIGHBORHOOD_SCALE_SIZE, + features=features) self.logger.debug('Created tracker!') return tracker @@ -139,9 +140,9 @@ def _load_data(self, redis_hash, subdir, fname): # remove the last dimensions added by `get_image` tiff_stack = np.squeeze(raw, -1) # TODO: required? check the ndim? if len(tiff_stack.shape) != 3: - raise ValueError("This tiff file has shape {}, which is not 3 " - "dimensions. Tracking can only be done on images " - "with 3 dimensions, (time, width, height)".format( + raise ValueError('This tiff file has shape {}, which is not 3 ' + 'dimensions. Tracking can only be done on images ' + 'with 3 dimensions, (time, width, height)'.format( tiff_stack.shape)) # Calculate scale of a subset of raw @@ -194,7 +195,7 @@ def _load_data(self, redis_hash, subdir, fname): # 'label': str(label) } - self.logger.debug("Setting %s", frame_hvalues) + self.logger.debug('Setting %s', frame_hvalues) # make a hash for this frame segment_hash = '{prefix}:{file}:{hash}'.format( @@ -233,7 +234,7 @@ def _load_data(self, redis_hash, subdir, fname): '\nSegmentation Error: {}'.format( hash_to_frame[segment_hash], reason)) - elif status == self.final_status: + if status == self.final_status: # if it's done, save the frame, as they'll be packed up # later frame_zip = self.storage.download( @@ -245,9 +246,9 @@ def _load_data(self, redis_hash, subdir, fname): if len(frame_files) != 1: raise RuntimeError( - "After unzipping predicted frame, got " - "back multiple files {}. Expected a " - "single file.".format(frame_files)) + 'After unzipping predicted frame, got ' + 'back multiple files {}. Expected a ' + 'single file.'.format(frame_files)) frame_idx = hash_to_frame[segment_hash] frames[frame_idx] = utils.get_image(frame_files[0]) @@ -259,7 +260,8 @@ def _load_data(self, redis_hash, subdir, fname): frames = [frames[i] for i in range(num_frames)] # Cast y to int to avoid issues during fourier transform/drift correction - return {"X": np.expand_dims(tiff_stack, axis=-1), "y": np.array(frames, dtype='uint16')} + return {'X': np.expand_dims(tiff_stack, axis=-1), + 'y': np.array(frames, dtype='uint16')} def _consume(self, redis_hash): hvalues = self.redis.hgetall(redis_hash) From 73f356ef3c37a06361d233004dbbc87ed79c6a44 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sat, 7 Mar 2020 10:33:33 -0800 Subject: [PATCH 35/47] migrate retry login into PredictClient and use context to auto-close channel. --- redis_consumer/consumers/base_consumer.py | 107 ++++-------- redis_consumer/grpc_clients.py | 194 ++++++++++------------ 2 files changed, 124 insertions(+), 177 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 4044afb3..3280ffb8 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -40,7 +40,6 @@ import uuid import zipfile -import grpc import numpy as np import pytz @@ -254,82 +253,42 @@ def grpc_image(self, img, model_name, model_version, in_tensor_dtype='DT_FLOAT'): in_tensor_dtype = str(in_tensor_dtype).upper() - true_failures, count = 0, 0 + start = timeit.default_timer() self.logger.debug('Segmenting image of shape %s with model %s:%s', img.shape, model_name, model_version) - retrying = True - while retrying: - try: - if in_tensor_dtype == 'DT_HALF': - # TODO: seems like should cast to "half" - # but the model rejects the type, wants "int" or "long" - img = img.astype('int') - - req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME, - 'in_tensor_dtype': in_tensor_dtype, - 'data': np.expand_dims(img, axis=0)}] - - client = self._get_predict_client(model_name, model_version) - - prediction = client.predict(req_data, settings.GRPC_TIMEOUT) - results = [prediction[k] for k in sorted(prediction.keys()) - if k.startswith('prediction')] - - if len(results) == 1: - results = results[0] - - retrying = False - - finished = timeit.default_timer() - start - if self._redis_hash is not None: - self.update_key(self._redis_hash, { - 'prediction_time': finished, - 'predict_retries': count, - }) - self.logger.debug('Segmented key %s (model %s:%s, ' - 'preprocessing: %s, postprocessing: %s)' - ' (%s retries) in %s seconds.', - self._redis_hash, model_name, model_version, - self._redis_values.get('preprocess_function'), - self._redis_values.get('postprocess_function'), - count, finished) - return results - except grpc.RpcError as err: - # pylint: disable=E1101 - if true_failures > settings.MAX_RETRY > 0: - retrying = False - raise RuntimeError('Prediction has failed {} times due to ' - 'error {}'.format(count, err)) - if err.code() in settings.GRPC_RETRY_STATUSES: - count += 1 - is_true_failure = err.code() != grpc.StatusCode.UNAVAILABLE - true_failures += int(is_true_failure) - # write update to Redis - temp_status = 'retry-predicting - {} - {}'.format( - count, err.code().name) - if self._redis_hash is not None: - self.update_key(self._redis_hash, { - 'status': temp_status, - 'predict_retries': count, - }) - self.logger.warning('%sException `%s: %s` during ' - 'PredictClient request to model %s:%s.' - ' Waiting %s seconds before retrying.', - type(err).__name__, err.code().name, - err.details(), model_name, - model_version, settings.GRPC_BACKOFF) - time.sleep(settings.GRPC_BACKOFF) # sleep before retry - retrying = True # Unneccessary but explicit - else: - retrying = False - raise err - except Exception as err: - retrying = False - self.logger.error('Encountered %s during tf-serving request to ' - 'model %s:%s: %s', type(err).__name__, - model_name, model_version, err) - raise err + + if in_tensor_dtype == 'DT_HALF': + # TODO: seems like should cast to "half" + # but the model rejects the type, wants "int" or "long" + img = img.astype('int') + + req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME, + 'in_tensor_dtype': in_tensor_dtype, + 'data': np.expand_dims(img, axis=0)}] + + client = self._get_predict_client(model_name, model_version) + + prediction = client.predict(req_data, settings.GRPC_TIMEOUT) + results = [prediction[k] for k in sorted(prediction.keys()) + if k.startswith('prediction')] + + if len(results) == 1: + results = results[0] + + finished = timeit.default_timer() - start + if self._redis_hash is not None: + self.update_key(self._redis_hash, { + 'prediction_time': finished, + }) + self.logger.debug('Segmented key %s (model %s:%s, ' + 'preprocessing: %s, postprocessing: %s)' + ' (%s retries) in %s seconds.', + self._redis_hash, model_name, model_version, + self._redis_values.get('preprocess_function'), + self._redis_values.get('postprocess_function'), + 0, finished) + return results def get_model_metadata(self, model_name, model_version): """Check Redis for saved model metadata or get from TensorFlow Serving. diff --git a/redis_consumer/grpc_clients.py b/redis_consumer/grpc_clients.py index 192aa0a4..51c856be 100644 --- a/redis_consumer/grpc_clients.py +++ b/redis_consumer/grpc_clients.py @@ -62,6 +62,12 @@ class GrpcClient(object): def __init__(self, host): self.logger = logging.getLogger(self.__class__.__name__) self.host = host + self.options = [ + (cygrpc.ChannelArgKey.max_send_message_length, -1), + (cygrpc.ChannelArgKey.max_receive_message_length, -1), + ('grpc.default_compression_algorithm', cygrpc.CompressionAlgorithm.gzip), + ('grpc.grpc.default_compression_level', cygrpc.CompressionLevel.high) + ] def insecure_channel(self): """Create an insecure channel with max message length. @@ -70,13 +76,7 @@ def insecure_channel(self): channel: grpc.insecure channel object """ t = timeit.default_timer() - options = [ - (cygrpc.ChannelArgKey.max_send_message_length, -1), - (cygrpc.ChannelArgKey.max_receive_message_length, -1), - ('grpc.default_compression_algorithm', cygrpc.CompressionAlgorithm.gzip), - ('grpc.grpc.default_compression_level', cygrpc.CompressionLevel.high) - ] - channel = grpc.insecure_channel(target=self.host, options=options) + channel = grpc.insecure_channel(target=self.host, options=self.options) self.logger.debug('Establishing insecure channel took: %s', timeit.default_timer() - t) return channel @@ -96,135 +96,123 @@ def __init__(self, host, model_name, model_version): self.model_name = model_name self.model_version = model_version + self.stub_lookup = { + GetModelMetadataRequest: 'GetModelMetadata', + PredictRequest: 'Predict', + } + + def _retry_grpc(self, request, request_timeout): + request_name = request.__class__.__name__ + self.logger.info('Sending %s to %s model %s:%s.', + request_name, self.host, + self.model_name, self.model_version) + + true_failures, count = 0, 0 + + retrying = True + while retrying: + with self.insecure_channel() as channel: + # pylint: disable=E1101 + try: + t = timeit.default_timer() + + stub = PredictionServiceStub(channel) + + api_endpoint_name = self.stub_lookup.get(request.__class__) + api_call = getattr(stub, api_endpoint_name) + response = api_call(request, timeout=request_timeout) + + self.logger.debug('%s finished in %s seconds.', + request_name, timeit.default_timer() - t) + return response + + except grpc.RpcError as err: + if true_failures > settings.MAX_RETRY > 0: + retrying = False + self.logger.error('%s has failed %s times due to err ' + '%s', request_name, count, err) + raise err + + if err.code() in settings.GRPC_RETRY_STATUSES: + count += 1 + is_true_failure = err.code() != grpc.StatusCode.UNAVAILABLE + true_failures += int(is_true_failure) + + self.logger.warning('%sException `%s: %s` during ' + '%s %s to model %s:%s. Waiting %s ' + 'seconds before retrying.', + type(err).__name__, + err.code().name, err.details(), + self.__class__.__name__, + request_name, + self.model_name, self.model_version, + settings.GRPC_BACKOFF) + + time.sleep(settings.GRPC_BACKOFF) # sleep before retry + retrying = True # Unneccessary but explicit + else: + retrying = False + raise err + except Exception as err: + retrying = False + self.logger.error('Encountered %s during %s to model ' + '%s:%s: %s', type(err).__name__, + request_name, self.model_name, + self.model_version, err) + raise err + def predict(self, request_data, request_timeout=10): self.logger.info('Sending PredictRequest to %s model %s:%s.', self.host, self.model_name, self.model_version) - channel = self.insecure_channel() - - t = timeit.default_timer() - stub = PredictionServiceStub(channel) - self.logger.debug('Created PredictionServiceStub in %s seconds.', - timeit.default_timer() - t) - t = timeit.default_timer() request = PredictRequest() self.logger.debug('Created PredictRequest object in %s seconds.', timeit.default_timer() - t) - request.model_spec.name = self.model_name # pylint: disable=E1101 + # pylint: disable=E1101 + request.model_spec.name = self.model_name if self.model_version > 0: - # pylint: disable=E1101 request.model_spec.version.value = self.model_version t = timeit.default_timer() for d in request_data: tensor_proto = make_tensor_proto(d['data'], d['in_tensor_dtype']) - # pylint: disable=E1101 request.inputs[d['in_tensor_name']].CopyFrom(tensor_proto) self.logger.debug('Made tensor protos in %s seconds.', timeit.default_timer() - t) - try: - t = timeit.default_timer() - predict_response = stub.Predict(request, timeout=request_timeout) - self.logger.debug('gRPC PredictRequest finished in %s seconds.', - timeit.default_timer() - t) - - t = timeit.default_timer() - predict_response_dict = grpc_response_to_dict(predict_response) - self.logger.debug('gRPC PredictResponseProtobufConversion took ' - '%s seconds.', timeit.default_timer() - t) - - keys = [k for k in predict_response_dict] - self.logger.info('Got PredictResponse with keys: %s ', - keys) - channel.close() - return predict_response_dict + response = self._retry_grpc(request, request_timeout) + response_dict = grpc_response_to_dict(response) - except RpcError as err: - self.logger.error('PredictRequest failed due to: %s', err) - channel.close() - raise err + self.logger.info('Got PredictResponse with keys: %s ', + list(response_dict)) - channel.close() - return {} + return response_dict def get_model_metadata(self, request_timeout=10): self.logger.info('Sending GetModelMetadataRequest to %s model %s:%s.', self.host, self.model_name, self.model_version) - true_failures, count = 0, 0 - - retrying = True - while retrying: - # pylint: disable=E1101 - try: - t = timeit.default_timer() - channel = self.insecure_channel() - - stub = PredictionServiceStub(channel) - - request = GetModelMetadataRequest() - - request.metadata_field.append('signature_def') - - request.model_spec.name = self.model_name - - if self.model_version > 0: - request.model_spec.version.value = self.model_version - - response = stub.GetModelMetadata(request, timeout=request_timeout) - - self.logger.debug('gRPC GetModelMetadataRequest finished in %s ' - 'seconds.', timeit.default_timer() - t) - - t = timeit.default_timer() + # pylint: disable=E1101 + request = GetModelMetadataRequest() + request.metadata_field.append('signature_def') + request.model_spec.name = self.model_name + if self.model_version > 0: + request.model_spec.version.value = self.model_version - response_dict = json.loads(MessageToJson(response)) + response = self._retry_grpc(request, request_timeout) - # signature_def = response.metadata['signature_def'] - self.logger.debug('gRPC GetModelMetadataProtobufConversion took ' - '%s seconds.', timeit.default_timer() - t) + t = timeit.default_timer() - channel.close() - return response_dict + response_dict = json.loads(MessageToJson(response)) - except grpc.RpcError as err: - channel.close() - if true_failures > settings.MAX_RETRY > 0: - retrying = False - self.logger.error('GetModelMetadataRequest has failed %s ' - 'times due to err %s', count, err) - raise err + self.logger.debug('gRPC GetModelMetadataProtobufConversion took ' + '%s seconds.', timeit.default_timer() - t) - if err.code() in settings.GRPC_RETRY_STATUSES: - count += 1 - is_true_failure = err.code() != grpc.StatusCode.UNAVAILABLE - true_failures += int(is_true_failure) - - self.logger.warning('%sException `%s: %s` during ' - 'PredictClient GetModelMetadataRequest to ' - 'model %s:%s. Waiting %s seconds before ' - 'retrying.', type(err).__name__, - err.code().name, err.details(), - self.model_name, self.model_version, - settings.GRPC_BACKOFF) - - time.sleep(settings.GRPC_BACKOFF) # sleep before retry - retrying = True # Unneccessary but explicit - else: - retrying = False - raise err - except Exception as err: - channel.close() - retrying = False - self.logger.error('Encountered %s during GetModelMetadataRequest' - ' to model %s:%s: %s', type(err).__name__, - self.model_name, self.model_version, err) - raise err + return response_dict class TrackingClient(GrpcClient): From 1675bcc66acff71251b7d488e49d055ebce3a011 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sat, 7 Mar 2020 13:31:15 -0800 Subject: [PATCH 36/47] first attempt to support batch size --- redis_consumer/consumers/base_consumer.py | 8 ++++++-- redis_consumer/settings.py | 6 +----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 3280ffb8..7a156929 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -359,7 +359,10 @@ def _predict_big_image(self, numpy.array: untiled results from the model. """ is_untile_required = sample is None - sample = 1 if sample is None else sample + + if sample is None: + sample = settings.TF_MAX_BATCH_SIZE + model_ndim = len(model_shape) input_shape = (model_shape[model_ndim - 3], model_shape[model_ndim - 2]) @@ -375,7 +378,8 @@ def _predict_big_image(self, # dependent on the tf-serving configuration results = [] for t in range(0, tiles.shape[0], sample): - output = self.grpc_image(tiles[t], model_name, model_version, + batch = tiles[t:t + sample] if is_untile_required else tiles[t] + output = self.grpc_image(tiles[batch], model_name, model_version, in_tensor_dtype=model_dtype) if not isinstance(output, list): diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index acf8d551..ad43b0fe 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -57,11 +57,7 @@ def _strip(x): TF_HOST = config('TF_HOST', default='tf-serving') TF_PORT = config('TF_PORT', default=8500, cast=int) TF_TENSOR_NAME = config('TF_TENSOR_NAME', default='image') -TF_TENSOR_DTYPE = config('TF_TENSOR_DTYPE', default='DT_FLOAT') - -# data-processing client connection -DP_HOST = config('DP_HOST', default='data-processing') -DP_PORT = config('DP_PORT', default=8080, cast=int) +TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=1) # gRPC API timeout in seconds (scales with `cuts`) GRPC_TIMEOUT = config('GRPC_TIMEOUT', default=30, cast=int) From c5d801020dc6e2860b9ecba8493aa5cc9896261e Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sat, 7 Mar 2020 13:31:29 -0800 Subject: [PATCH 37/47] sleep if the hash had to get put back --- redis_consumer/consumers/base_consumer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 7a156929..65fb5065 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -216,6 +216,7 @@ def consume(self): # this key is not done yet. # remove it from processing and push it back to the work queue. self._put_back_hash(redis_hash) + time.sleep(.5) else: self.logger.debug('Queue `%s` is empty. Waiting for %s seconds.', From c3df985c9f7fe5851c4fd5d860b5388ce1cec8f3 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Sun, 8 Mar 2020 09:47:42 -0700 Subject: [PATCH 38/47] cast batch size as an int and only expand img if no batches exist. --- redis_consumer/consumers/base_consumer.py | 15 +++++++++------ redis_consumer/settings.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 65fb5065..cec62159 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -250,7 +250,7 @@ def _get_predict_client(self, model_name, model_version): timeit.default_timer() - t) return client - def grpc_image(self, img, model_name, model_version, + def grpc_image(self, img, model_name, model_version, model_shape, in_tensor_dtype='DT_FLOAT'): in_tensor_dtype = str(in_tensor_dtype).upper() @@ -259,6 +259,9 @@ def grpc_image(self, img, model_name, model_version, self.logger.debug('Segmenting image of shape %s with model %s:%s', img.shape, model_name, model_version) + if len(model_shape) == img.ndim + 1: + img = np.expand_dims(img, axis=0) + if in_tensor_dtype == 'DT_HALF': # TODO: seems like should cast to "half" # but the model rejects the type, wants "int" or "long" @@ -266,7 +269,7 @@ def grpc_image(self, img, model_name, model_version, req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME, 'in_tensor_dtype': in_tensor_dtype, - 'data': np.expand_dims(img, axis=0)}] + 'data': img}] client = self._get_predict_client(model_name, model_version) @@ -380,8 +383,8 @@ def _predict_big_image(self, results = [] for t in range(0, tiles.shape[0], sample): batch = tiles[t:t + sample] if is_untile_required else tiles[t] - output = self.grpc_image(tiles[batch], model_name, model_version, - in_tensor_dtype=model_dtype) + output = self.grpc_image(batch, model_name, model_version, + model_shape, in_tensor_dtype=model_dtype) if not isinstance(output, list): output = [output] @@ -442,7 +445,7 @@ def _predict_small_image(self, padded_img = np.pad(image, pad_width, 'reflect') image = self.grpc_image(padded_img, model_name, model_version, - in_tensor_dtype=model_dtype) + model_shape, in_tensor_dtype=model_dtype) image = [image] if not isinstance(image, list) else image @@ -497,7 +500,7 @@ def predict(self, image, model_name, model_version, sample=None): else: # image size is perfect, just send it to the model image = self.grpc_image(image, model_name, model_version, - in_tensor_dtype=model_dtype) + model_shape, in_tensor_dtype=model_dtype) if isinstance(image, list): output_shapes = [i.shape for i in image] diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index ad43b0fe..af199a3c 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -57,7 +57,7 @@ def _strip(x): TF_HOST = config('TF_HOST', default='tf-serving') TF_PORT = config('TF_PORT', default=8500, cast=int) TF_TENSOR_NAME = config('TF_TENSOR_NAME', default='image') -TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=1) +TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=1, cast=int) # gRPC API timeout in seconds (scales with `cuts`) GRPC_TIMEOUT = config('GRPC_TIMEOUT', default=30, cast=int) From 3b01513565abfa9e874b950e8d2ce725b9005ec8 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 10 Mar 2020 11:28:56 -0700 Subject: [PATCH 39/47] replace sample with dynamically calculated batch size --- redis_consumer/consumers/base_consumer.py | 30 +++++++++++----------- redis_consumer/consumers/image_consumer.py | 4 +-- redis_consumer/settings.py | 13 +++++----- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index cec62159..a82a081b 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -346,7 +346,7 @@ def _predict_big_image(self, model_version, model_shape, model_dtype='DT_FLOAT', - sample=None): + untile=True): """Use tile_image to tile image for the model and untile the results. Args: @@ -355,21 +355,20 @@ def _predict_big_image(self, model_version (str): model version to query. model_shape (tuple): shape of the model's expected input. model_dtype (str): dtype of the model's input array. - sample (int): Only predict every sample'th tile. - If sample is not None, no untiling will be performed, - as the untiling data will be incomplete. + untile (bool): untiles results back to image shape if True. Returns: numpy.array: untiled results from the model. """ - is_untile_required = sample is None - - if sample is None: - sample = settings.TF_MAX_BATCH_SIZE - model_ndim = len(model_shape) input_shape = (model_shape[model_ndim - 3], model_shape[model_ndim - 2]) + ratio = (model_shape[model_ndim - 3] / settings.TF_MIN_MODEL_SIZE) * \ + (model_shape[model_ndim - 2] / settings.TF_MIN_MODEL_SIZE) * \ + (model_shape[model_ndim - 1]) + + batch_size = int(settings.TF_MAX_BATCH_SIZE // ratio) + tiles, tiles_info = tile_image( np.expand_dims(image, axis=0), model_input_shape=input_shape, @@ -381,8 +380,8 @@ def _predict_big_image(self, # max_batch_size is 1 by default. # dependent on the tf-serving configuration results = [] - for t in range(0, tiles.shape[0], sample): - batch = tiles[t:t + sample] if is_untile_required else tiles[t] + for t in range(0, tiles.shape[0], batch_size): + batch = tiles[t:t + batch_size] output = self.grpc_image(batch, model_name, model_version, model_shape, in_tensor_dtype=model_dtype) @@ -395,7 +394,7 @@ def _predict_big_image(self, for i, o in enumerate(output): results[i] = np.vstack((results[i], o)) - if not is_untile_required: + if not untile: image = results else: image = [untile_image(r, tiles_info, model_input_shape=input_shape) @@ -461,7 +460,7 @@ def _predict_small_image(self, return image - def predict(self, image, model_name, model_version, sample=None): + def predict(self, image, model_name, model_version, untile=True): start = timeit.default_timer() model_metadata = self.get_model_metadata(model_name, model_version) @@ -496,11 +495,12 @@ def predict(self, image, model_name, model_version, sample=None): image.shape[image.ndim - 2] > size_y): # image is too big for the model, multiple images are tiled. image = self._predict_big_image(image, model_name, model_version, - model_shape, model_dtype, sample) + model_shape, model_dtype, + untile=untile) else: # image size is perfect, just send it to the model image = self.grpc_image(image, model_name, model_version, - model_shape, in_tensor_dtype=model_dtype) + model_shape, model_dtype) if isinstance(image, list): output_shapes = [i.shape for i in image] diff --git a/redis_consumer/consumers/image_consumer.py b/redis_consumer/consumers/image_consumer.py index 91d07ae7..64b76d36 100644 --- a/redis_consumer/consumers/image_consumer.py +++ b/redis_consumer/consumers/image_consumer.py @@ -119,7 +119,7 @@ def detect_scale(self, image): model_name, model_version = settings.SCALE_DETECT_MODEL.split(':') scales = self.predict(image, model_name, model_version, - sample=settings.SCALE_DETECT_SAMPLE) + untile=False) detected_scale = np.mean(scales) @@ -141,7 +141,7 @@ def detect_label(self, image): model_name, model_version = settings.LABEL_DETECT_MODEL.split(':') labels = self.predict(image, model_name, model_version, - sample=settings.SCALE_DETECT_SAMPLE) + untile=False) labels = np.array(labels) vote = labels.sum(axis=0) diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index af199a3c..2d5462ad 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -53,13 +53,16 @@ def _strip(x): REDIS_HOST = config('REDIS_HOST', default='redis-master') REDIS_PORT = config('REDIS_PORT', default=6379, cast=int) -# tensorflow-serving client connection +# TensorFlow Serving client connection TF_HOST = config('TF_HOST', default='tf-serving') TF_PORT = config('TF_PORT', default=8500, cast=int) TF_TENSOR_NAME = config('TF_TENSOR_NAME', default='image') -TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=1, cast=int) +# maximum batch allowed by TensorFlow Serving +TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=128, cast=int) +# minimum expected model size, dynamically change batches proportionately. +TF_MIN_MODEL_SIZE = config('TF_MIN_MODEL_SIZE', default=128, cast=int) -# gRPC API timeout in seconds (scales with `cuts`) +# gRPC API timeout in seconds GRPC_TIMEOUT = config('GRPC_TIMEOUT', default=30, cast=int) GRPC_BACKOFF = config('GRPC_BACKOFF', default=3, cast=int) @@ -101,8 +104,6 @@ def _strip(x): # Pod Meteadta HOSTNAME = config('HOSTNAME', default='host-unkonwn') -CUTS = config('CUTS', default=0, cast=int) # TODO: deprecated - # Redis queue QUEUE = config('QUEUE', default='predict') SEGMENTATION_QUEUE = config('SEGMENTATION_QUEUE', default='predict') @@ -149,13 +150,11 @@ def _strip(x): # Scale detection settings SCALE_DETECT_MODEL = config('SCALE_DETECT_MODEL', default='ScaleDetection:1') -SCALE_DETECT_SAMPLE = config('SCALE_DETECT_SAMPLE', default=3, cast=int) # Not supported for tracking. Always detects scale SCALE_DETECT_ENABLED = config('SCALE_DETECT_ENABLED', default=False, cast=bool) # Type detection settings LABEL_DETECT_MODEL = config('LABEL_DETECT_MODEL', default='LabelDetection:1', cast=str) -LABEL_DETECT_SAMPLE = config('LABEL_DETECT_SAMPLE', default=3, cast=int) LABEL_DETECT_ENABLED = config('LABEL_DETECT_ENABLED', default=False, cast=bool) # Set default models based on label type From d30f2d1fc536d86fd5ec393ea279290424565f3c Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 10 Mar 2020 13:11:34 -0700 Subject: [PATCH 40/47] pass stride_ratio as a parameter to _predict_big_image --- redis_consumer/consumers/base_consumer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index a82a081b..da65b588 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -346,7 +346,8 @@ def _predict_big_image(self, model_version, model_shape, model_dtype='DT_FLOAT', - untile=True): + untile=True, + stride_ratio=0.75): """Use tile_image to tile image for the model and untile the results. Args: @@ -356,6 +357,7 @@ def _predict_big_image(self, model_shape (tuple): shape of the model's expected input. model_dtype (str): dtype of the model's input array. untile (bool): untiles results back to image shape if True. + stride_ratio (float): amount to overlap between tiles, (0, 1]. Returns: numpy.array: untiled results from the model. @@ -372,7 +374,7 @@ def _predict_big_image(self, tiles, tiles_info = tile_image( np.expand_dims(image, axis=0), model_input_shape=input_shape, - stride_ratio=0.75) + stride_ratio=stride_ratio) self.logger.debug('Tiling image of shape %s into shape %s.', image.shape, tiles.shape) From c31f7320bf5f4f8a72b057bc62d708d2e7873052 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 10 Mar 2020 13:13:18 -0700 Subject: [PATCH 41/47] fix tests for dynamic batch expansion --- .../consumers/base_consumer_test.py | 26 ++++++++++--------- .../consumers/image_consumer_test.py | 9 +++---- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index ac8f54da..d57c2f18 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -347,6 +347,7 @@ def test__get_predict_client(self): def test_grpc_image(self): redis_client = DummyRedis([]) consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') + model_shape = (-1, 128, 128, 1) def _get_predict_client(model_name, model_version): return Bunch(predict=lambda x, y: { @@ -356,8 +357,9 @@ def _get_predict_client(model_name, model_version): consumer._get_predict_client = _get_predict_client img = np.zeros((1, 32, 32, 3)) - out = consumer.grpc_image(img, 'f16model', 1) - assert img.shape == out.shape[1:] + out = consumer.grpc_image(img, 'model', 1, model_shape, 'DT_HALF') + + assert img.shape == out.shape assert img.sum() == out.sum() def test_get_model_metadata(self): @@ -426,14 +428,12 @@ def test_predict(self): consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q') def grpc_image(data, *args, **kwargs): - data = np.expand_dims(data, axis=0) return data def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - data = np.expand_dims(data, axis=0) return [data, data] - model_shape = (1, 128, 128, 1) + model_shape = (-1, 128, 128, 1) image_shapes = [ (256, 256, 1), @@ -445,15 +445,17 @@ def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 for image_shape in image_shapes: for grpc_func in (grpc_image, grpc_image_list): + for untile in (False, True): - x = np.random.random(image_shape) - consumer.grpc_image = grpc_func - consumer.get_model_metadata = lambda x, y: { - 'in_tensor_dtype': 'DT_FLOAT', - 'in_tensor_shape': ','.join(str(s) for s in model_shape), - } + x = np.random.random(image_shape) + consumer.grpc_image = grpc_func + consumer.get_model_metadata = lambda x, y: { + 'in_tensor_dtype': 'DT_HALF', + 'in_tensor_shape': ','.join(str(s) for s in model_shape), + } - consumer.predict(x, model_name='modelname', model_version=0) + consumer.predict(x, model_name='modelname', model_version=0, + untile=untile) class TestZipFileConsumer(object): diff --git a/redis_consumer/consumers/image_consumer_test.py b/redis_consumer/consumers/image_consumer_test.py index b78a3301..7aabd74e 100644 --- a/redis_consumer/consumers/image_consumer_test.py +++ b/redis_consumer/consumers/image_consumer_test.py @@ -306,15 +306,12 @@ def _handle_error(err, rhash): raise err def grpc_image(data, *args, **kwargs): - data = np.expand_dims(data, axis=0) return data def grpc_image_multi(data, *args, **kwargs): - data = np.expand_dims(data, axis=0) return np.array(tuple(list(data.shape) + [2])) def grpc_image_list(data, *args, **kwargs): # pylint: disable=W0613 - data = np.expand_dims(data, axis=0) return [data, data] def detect_scale(_): @@ -336,9 +333,9 @@ def get_model_metadata(model_name, model_version): dummyhash = '{}:test.tiff:{}'.format(prefix, status) model_shapes = [ - (1, 600, 600, 1), # image too small, pad - (1, 300, 300, 1), # image is exactly the right size - (1, 150, 150, 1), # image too big, tile + (-1, 600, 600, 1), # image too small, pad + (-1, 300, 300, 1), # image is exactly the right size + (-1, 150, 150, 1), # image too big, tile ] consumer._handle_error = _handle_error From c5b76da8e80b6d088c2f24c397c74a8f74b05b9a Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Tue, 10 Mar 2020 13:13:32 -0700 Subject: [PATCH 42/47] check if length is 0, not == [] --- redis_consumer/consumers/base_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index da65b588..95b8fba3 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -390,7 +390,7 @@ def _predict_big_image(self, if not isinstance(output, list): output = [output] - if results == []: + if len(results) == 0: results = output else: for i, o in enumerate(output): From eeaf1a564218877ac1c6a05a9760385ceeae75cc Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Wed, 11 Mar 2020 11:56:57 -0700 Subject: [PATCH 43/47] add GatewayTimeout to list of retry-able storage exceptions. --- redis_consumer/storage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redis_consumer/storage.py b/redis_consumer/storage.py index 85b6038c..24a58785 100644 --- a/redis_consumer/storage.py +++ b/redis_consumer/storage.py @@ -152,6 +152,7 @@ def __init__(self, bucket, download_dir=DOWNLOAD_DIR, backoff=1.5): google_exceptions.TooManyRequests, google_exceptions.InternalServerError, google_exceptions.ServiceUnavailable, + google_exceptions.GatewayTimeout, urllib3.exceptions.MaxRetryError, urllib3.exceptions.NewConnectionError, requests.exceptions.ConnectionError, From 4ff6db45993e1e3a3ebd4c101d9d219fa3b4eed6 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Fri, 13 Mar 2020 10:22:45 -0700 Subject: [PATCH 44/47] clena up purge_processing_queue --- redis_consumer/consumers/base_consumer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index 95b8fba3..aa8294c2 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -143,13 +143,14 @@ def get_current_timestamp(self): def purge_processing_queue(self): """Move all items from the processing queue to the work queue""" - while True: + queue_has_items = True + while queue_has_items: key = self.redis.rpoplpush(self.processing_queue, self.queue) - if key is None: - break - self.logger.debug('Found stranded key `%s` in queue `%s`. ' - 'Moving it back to `%s`.', - key, self.processing_queue, self.queue) + queue_has_items = key is not None + + self.logger.debug('Found stranded key `%s` in queue `%s`. ' + 'Moving it back to `%s`.', + key, self.processing_queue, self.queue) def update_key(self, redis_hash, data=None): """Update the hash with `data` and updated_by & updated_at stamps. From cd6f8481cab526d50bc3821189453413416ba401 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Fri, 13 Mar 2020 10:22:50 -0700 Subject: [PATCH 45/47] add comment --- redis_consumer/consumers/base_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index aa8294c2..0f1b5639 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -219,7 +219,7 @@ def consume(self): self._put_back_hash(redis_hash) time.sleep(.5) - else: + else: # queue is empty self.logger.debug('Queue `%s` is empty. Waiting for %s seconds.', self.queue, settings.EMPTY_QUEUE_TIMEOUT) time.sleep(settings.EMPTY_QUEUE_TIMEOUT) From 4e62c9ec8e2727e2894ec1ce441fa7ed94209872 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 19 Mar 2020 23:10:38 -0700 Subject: [PATCH 46/47] add docs/ and build/ to the dockerignore. --- .dockerignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.dockerignore b/.dockerignore index 22dcd45e..3d12409d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,6 +3,8 @@ output/ download/ logs/ protos/ +docs/ +build/ # Byte-compiled / optimized / DLL files __pycache__/ From f551d3659ad542e60243664970f815d6ba7b2a78 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 19 Mar 2020 23:33:04 -0700 Subject: [PATCH 47/47] don't sleep after putting back the hash. maybe in another PR. --- redis_consumer/consumers/base_consumer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/redis_consumer/consumers/base_consumer.py b/redis_consumer/consumers/base_consumer.py index f33874c4..3011507e 100644 --- a/redis_consumer/consumers/base_consumer.py +++ b/redis_consumer/consumers/base_consumer.py @@ -217,7 +217,6 @@ def consume(self): # this key is not done yet. # remove it from processing and push it back to the work queue. self._put_back_hash(redis_hash) - time.sleep(.5) else: # queue is empty self.logger.debug('Queue `%s` is empty. Waiting for %s seconds.',