Skip to content

Commit

Permalink
Modify PredictStreamed to return a response or an error.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 611657273
  • Loading branch information
kenfranko authored and tensorflow-copybara committed Mar 1, 2024
1 parent 2c7a489 commit 5b5d30f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
5 changes: 3 additions & 2 deletions tensorflow_serving/servables/tensorflow/mock_servable.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <memory>

#include <gmock/gmock.h>
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/test_util/test_util.h"

namespace tensorflow {
namespace serving {
Expand Down Expand Up @@ -63,7 +63,8 @@ class MockServable : public Servable {
MOCK_METHOD(absl::StatusOr<std::unique_ptr<PredictStreamedContext>>,
PredictStreamed,
(const tensorflow::serving::Servable::RunOptions& run_options,
absl::AnyInvocable<void(tensorflow::serving::PredictResponse)>
absl::AnyInvocable<
void(absl::StatusOr<tensorflow::serving::PredictResponse>)>
response_callback),
(final));
MOCK_METHOD(absl::Status, MultiInference,
Expand Down
23 changes: 13 additions & 10 deletions tensorflow_serving/servables/tensorflow/servable.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/servables/tensorflow/run_options.h"
#include "tensorflow_serving/servables/tensorflow/google/run_options.h"

namespace tensorflow {
namespace serving {
Expand Down Expand Up @@ -109,15 +109,17 @@ class Servable {
// alive until the context object is deleted.
//
// `response_callback` is called for each streamed output, zero or more times,
// when the streamed output becomes available. The callback invocation must be
// serialized by the implementation, so that `response_callback` does not have
// to be thread-safe, but blocking inside the callback may cause the next
// callback invocation to be delayed. The implementation must guarantee that
// the callback is never called after the `PredictStreamed` method returns.
// when the streamed output becomes available. If an error is returned for any
// response, subsequent responses and requests will be ignored and the error
// will be returned. The callback invocation must be serialized by the
// implementation, so that `response_callback` does not have to be
// thread-safe, but blocking inside the callback may cause the next callback
// invocation to be delayed. The implementation must guarantee that the
// callback is never called after the `PredictStreamed` method returns.
virtual absl::StatusOr<std::unique_ptr<PredictStreamedContext>>
PredictStreamed(
const RunOptions& run_options,
absl::AnyInvocable<void(PredictResponse)> response_callback) = 0;
PredictStreamed(const RunOptions& run_options,
absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
response_callback) = 0;

virtual absl::Status MultiInference(const RunOptions& run_options,
const MultiInferenceRequest& request,
Expand Down Expand Up @@ -165,7 +167,8 @@ class EmptyServable : public Servable {

absl::StatusOr<std::unique_ptr<PredictStreamedContext>> PredictStreamed(
const RunOptions& run_options,
absl::AnyInvocable<void(PredictResponse)> response_callback) {
absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
response_callback) {
return error_;
}

Expand Down
5 changes: 4 additions & 1 deletion tensorflow_serving/servables/tensorflow/tfrt_servable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ limitations under the License.
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/tracing.h" // NOLINT
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool_options.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
Expand Down Expand Up @@ -126,7 +128,8 @@ absl::Status TfrtSavedModelServable::Predict(const RunOptions& run_options,
absl::StatusOr<std::unique_ptr<PredictStreamedContext>>
TfrtSavedModelServable::PredictStreamed(
const RunOptions& run_options,
absl::AnyInvocable<void(PredictResponse)> response_callback) {
absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
response_callback) {
return std::make_unique<SingleRequestPredictStreamedContext>(
[this, run_options, response_callback = std::move(response_callback)](
const PredictRequest& request) mutable -> absl::Status {
Expand Down
4 changes: 3 additions & 1 deletion tensorflow_serving/servables/tensorflow/tfrt_servable.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include "tensorflow_serving/apis/classification.pb.h"
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
Expand Down Expand Up @@ -61,7 +62,8 @@ class TfrtSavedModelServable : public Servable {

absl::StatusOr<std::unique_ptr<PredictStreamedContext>> PredictStreamed(
const RunOptions& run_options,
absl::AnyInvocable<void(PredictResponse)> response_callback) override;
absl::AnyInvocable<void(absl::StatusOr<PredictResponse>)>
response_callback) override;

absl::Status MultiInference(const RunOptions& run_options,
const MultiInferenceRequest& request,
Expand Down

0 comments on commit 5b5d30f

Please sign in to comment.