Skip to content

Commit

Permalink
Fix the lifetime of InferPayload (#241)
Browse files Browse the repository at this point in the history
* Add mutex for InferPayload to make sure it's thread-safe during callback

* Remove reset for the promise

* Address comment

* Remove destructor

* Fix lifetime of infer payload

* Make sure the mutex is unlocked before promise.set_value

* Revert "Make sure the mutex is unlocked before promise.set_value"

This reverts commit 2eb5c32.

* fix leak

* Serialize all the responses in decoupled BLS

* use enable_shared_from_this

* Add a warning about using "GetPtr"

* Remove the callback from mutex lock

---------

Co-authored-by: Iman Tabrizian <itabrizian@nvidia.com>
  • Loading branch information
krishung5 and Tabrizian committed May 17, 2023
1 parent 6c4b817 commit 84f3bb7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 28 deletions.
26 changes: 14 additions & 12 deletions src/infer_payload.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,30 @@ InferPayload::InferPayload(
std::function<void(std::unique_ptr<InferResponse>)> callback)
: is_decoupled_(is_decoupled), is_promise_set_(false), callback_(callback)
{
prev_promise_.reset(new std::promise<std::unique_ptr<InferResponse>>());
}

InferPayload::~InferPayload()
{
prev_promise_.reset();
promise_.reset(new std::promise<std::unique_ptr<InferResponse>>());
}

void
InferPayload::SetValueForPrevPromise(
std::unique_ptr<InferResponse> infer_response)
InferPayload::SetValue(std::unique_ptr<InferResponse> infer_response)
{
prev_promise_->set_value(std::move(infer_response));
prev_promise_.reset();
is_promise_set_ = true;
{
// Only set value to the promise with the first response. Call the callback
// function to send decoupled response to the stub.
std::lock_guard<std::mutex> lock(mutex_);
if (!is_promise_set_) {
is_promise_set_ = true;
promise_->set_value(std::move(infer_response));
return;
}
}
Callback(std::move(infer_response));
}

void
InferPayload::SetFuture(
std::future<std::unique_ptr<InferResponse>>& response_future)
{
response_future = prev_promise_->get_future();
response_future = promise_->get_future();
}

bool
Expand Down
12 changes: 8 additions & 4 deletions src/infer_payload.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ struct ResponseAllocatorUserp {
PreferredMemory preferred_memory;
};

class InferPayload {
class InferPayload : public std::enable_shared_from_this<InferPayload> {
public:
InferPayload(
const bool is_decoupled,
std::function<void(std::unique_ptr<InferResponse>)> callback);
~InferPayload();

void SetValueForPrevPromise(std::unique_ptr<InferResponse> infer_response);
/// GetPtr should only be called when the InferPayload object is already
/// managed by a std::shared_ptr. Calling this function in any other
/// circumstance is undefined behaviour before C++17 (and throws
/// std::bad_weak_ptr from C++17 onward).
std::shared_ptr<InferPayload> GetPtr() { return shared_from_this(); }
void SetValue(std::unique_ptr<InferResponse> infer_response);
void SetFuture(std::future<std::unique_ptr<InferResponse>>& response_future);
bool IsDecoupled();
bool IsPromiseSet();
Expand All @@ -60,8 +63,9 @@ class InferPayload {
std::shared_ptr<ResponseAllocatorUserp> ResponseAllocUserp();

private:
std::unique_ptr<std::promise<std::unique_ptr<InferResponse>>> prev_promise_;
std::unique_ptr<std::promise<std::unique_ptr<InferResponse>>> promise_;
bool is_decoupled_;
std::mutex mutex_;
bool is_promise_set_;
std::function<void(std::unique_ptr<InferResponse>)> callback_;
std::shared_ptr<ResponseAllocatorUserp> response_alloc_userp_;
Expand Down
20 changes: 8 additions & 12 deletions src/request_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ void
InferResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
{
auto p = reinterpret_cast<InferPayload*>(userp);
auto linfer_payload = reinterpret_cast<InferPayload*>(userp);
std::shared_ptr<InferPayload> infer_payload = linfer_payload->GetPtr();
std::unique_ptr<InferResponse> infer_response;
std::vector<std::shared_ptr<PbTensor>> output_tensors;
std::shared_ptr<PbError> pb_error;
Expand Down Expand Up @@ -146,7 +147,7 @@ InferResponseComplete(
output_tensors.clear();
}

if (!p->IsDecoupled()) {
if (!infer_payload->IsDecoupled()) {
infer_response = std::make_unique<InferResponse>(
output_tensors, pb_error, true /* is_last_response */);
} else {
Expand All @@ -167,7 +168,8 @@ InferResponseComplete(
TRITONSERVER_InferenceResponseDelete(response),
"Failed to release BLS inference response.");
} else if (
p->IsDecoupled() && (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) {
(infer_payload)->IsDecoupled() &&
(flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) {
// An empty response may be the last response for decoupled models.
infer_response = std::make_unique<InferResponse>(
output_tensors, pb_error, true /* is_last_response */, userp /* id */);
Expand All @@ -177,13 +179,7 @@ InferResponseComplete(
output_tensors, pb_error, true /* is_last_response */, userp /* id */);
}

// Only set value to the promise with the first response. Call the callback
// function to send decoupled response to the stub.
if (p->IsPromiseSet()) {
p->Callback(std::move(infer_response));
} else {
p->SetValueForPrevPromise(std::move(infer_response));
}
infer_payload->SetValue(std::move(infer_response));
}

TRITONSERVER_Error*
Expand Down Expand Up @@ -339,8 +335,8 @@ RequestExecutor::Infer(
std::string("Model ") + model_name +
" is using the decoupled transaction policy. The current BLS request call "
"support models using the decoupled transaction policy. Please use "
"stream API 'stream_exec()' or 'async_stream_exec() for decoupled "
"models.'");
"'decoupled=True' argument to the 'exec' or 'async_exec' calls for "
"decoupled models.");
}

// Inference
Expand Down

0 comments on commit 84f3bb7

Please sign in to comment.