I'm serving the DistilBERT model from Hugging Face for a text classification problem using TensorFlow Serving. Calling the REST API for prediction works fine, but I switched to gRPC to improve prediction speed. With gRPC, however, I get the following error:
ERROR:root:Exception in callback <function _callback at 0x7f88f3052950>: <_Rendezvous of RPC that terminated with:
status = StatusCode.INVALID_ARGUMENT
details = "Input to reshape is a tensor with 262144 values, but the requested shape has 512
[[{{node model/wrapped__distil_bert/tf_distil_bert_model/distilbert/transformer/layer_._0/attention/Reshape_3}}]]"
debug_error_string = "{"created":"@1579314389.082482474","description":"Error received from peer ipv6:[::1]:3502","file":"src/core/lib/surface/call.cc","file_line":1055,"grpc_message":"Input to reshape is a tensor with 262144 values, but the requested shape has 512\n\t [[{{node model/wrapped__distil_bert/tf_distil_bert_model/distilbert/transformer/layer_._0/attention/Reshape_3}}]]","grpc_status":3}"
This is the code I use to call the gRPC API:
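    import grpc
    import numpy as np
    import tensorflow as tf
    from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

    def _callback(result_future):
        # Print the output tensors once the asynchronous call completes.
        print(result_future.result().outputs)

    def do_gRPC_predict(doc):
        # Connect to the TensorFlow Serving gRPC endpoint.
        channel = grpc.insecure_channel("{0}:{1}".format(global_config.tfserving.host, global_config.tfserving.port))
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

        # Select the model and its serving signature.
        req = predict_pb2.PredictRequest()
        req.model_spec.name = global_config.tfserving.model_name
        req.model_spec.signature_name = 'serving_default'

        # Convert the token ids to an int32 TensorProto for the 'input_text' input.
        x = np.array(doc, dtype='int32')
        tensor_proto = tf.make_tensor_proto(x, shape=[512])
        req.inputs['input_text'].CopyFrom(tensor_proto)

        # Send the request asynchronously with a 10.25-second timeout.
        result_future = stub.Predict.future(req, 10.25)
        result_future.add_done_callback(_callback)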
doc is a list of 512 token ids.
It is strange that the same SavedModel works through the REST API but not through gRPC. What could be the problem here?
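One difference worth checking is the rank of the tensor each endpoint receives. The issue does not show the REST request, so the sketch below assumes it used TensorFlow Serving's row ("instances") format, in which each element of `instances` is one example and the server adds the batch dimension. The `doc` value and the `nested_shape` helper are hypothetical stand-ins for illustration; the sketch only compares shapes and does not contact a server.

```python
import json

# Hypothetical stand-in for `doc`: 512 token ids (the real ids are not shown in the issue).
doc = list(range(512))


def nested_shape(value):
    """Return the shape implied by a nested Python list, e.g. [[1, 2]] -> [1, 2]."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return shape


# Assumed REST body in row ("instances") format: one example, so a batch of 1.
rest_body = {"signature_name": "serving_default", "instances": [doc]}
rest_shape = nested_shape(rest_body["instances"])

# Shape passed to tf.make_tensor_proto in the gRPC client above: no batch dimension.
grpc_shape = [512]

# Shape that would mirror the assumed REST request: an explicit batch dimension of 1.
grpc_batched_shape = [1, len(doc)]

print(json.dumps({"rest": rest_shape, "grpc": grpc_shape, "grpc_batched": grpc_batched_shape}))
```

If the REST call was indeed made this way, the gRPC equivalent would be `tf.make_tensor_proto(x, shape=[1, 512])` (or passing `[doc]`), so that both paths feed the model a rank-2 `[batch, sequence]` tensor. This remains a hypothesis to verify against the model's signature, for example with `saved_model_cli show --all`.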