
Fix the lifetime of InferPayload #241

Merged (12 commits, May 17, 2023)
26 changes: 14 additions & 12 deletions src/infer_payload.cc
@@ -33,28 +33,30 @@ InferPayload::InferPayload(
     std::function<void(std::unique_ptr<InferResponse>)> callback)
     : is_decoupled_(is_decoupled), is_promise_set_(false), callback_(callback)
 {
-  prev_promise_.reset(new std::promise<std::unique_ptr<InferResponse>>());
-}
-
-InferPayload::~InferPayload()
-{
-  prev_promise_.reset();
+  promise_.reset(new std::promise<std::unique_ptr<InferResponse>>());
 }
 
 void
-InferPayload::SetValueForPrevPromise(
-    std::unique_ptr<InferResponse> infer_response)
+InferPayload::SetValue(std::unique_ptr<InferResponse> infer_response)
 {
-  prev_promise_->set_value(std::move(infer_response));
-  prev_promise_.reset();
-  is_promise_set_ = true;
+  {
+    // Only set value to the promise with the first response. Call the callback
+    // function to send decoupled response to the stub.
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (!is_promise_set_) {
+      is_promise_set_ = true;
+      promise_->set_value(std::move(infer_response));
+      return;
+    }
+  }
+  Callback(std::move(infer_response));
 }
 
 void
 InferPayload::SetFuture(
     std::future<std::unique_ptr<InferResponse>>& response_future)
 {
-  response_future = prev_promise_->get_future();
+  response_future = promise_->get_future();
 }
 
 bool
12 changes: 8 additions & 4 deletions src/infer_payload.h
@@ -43,14 +43,17 @@ struct ResponseAllocatorUserp {
   PreferredMemory preferred_memory;
 };
 
-class InferPayload {
+class InferPayload : public std::enable_shared_from_this<InferPayload> {
  public:
   InferPayload(
       const bool is_decouple,
       std::function<void(std::unique_ptr<InferResponse>)> callback);
-  ~InferPayload();
 
-  void SetValueForPrevPromise(std::unique_ptr<InferResponse> infer_response);
+  /// GetPtr should be only called when the InferPayload object is constructed
+  /// using a shared pointer. Calling this function in any other circumstance
+  /// is undefined behaviour until C++17.
+  std::shared_ptr<InferPayload> GetPtr() { return shared_from_this(); }
+  void SetValue(std::unique_ptr<InferResponse> infer_response);
   void SetFuture(std::future<std::unique_ptr<InferResponse>>& response_future);
   bool IsDecoupled();
   bool IsPromiseSet();
@@ -60,8 +63,9 @@ class InferPayload {
   std::shared_ptr<ResponseAllocatorUserp> ResponseAllocUserp();
 
  private:
-  std::unique_ptr<std::promise<std::unique_ptr<InferResponse>>> prev_promise_;
+  std::unique_ptr<std::promise<std::unique_ptr<InferResponse>>> promise_;
   bool is_decoupled_;
+  std::mutex mutex_;
   bool is_promise_set_;
   std::function<void(std::unique_ptr<InferResponse>)> callback_;
   std::shared_ptr<ResponseAllocatorUserp> response_alloc_userp_;
20 changes: 8 additions & 12 deletions src/request_executor.cc
@@ -77,7 +77,8 @@ void
 InferResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto p = reinterpret_cast<InferPayload*>(userp);
+  auto linfer_payload = reinterpret_cast<InferPayload*>(userp);
+  std::shared_ptr<InferPayload> infer_payload = linfer_payload->GetPtr();
Comment on lines +80 to +81

@rmccorm4 (Contributor) commented on May 16, 2023:

Is it valid to call raw_ptr->GetPtr() here?

I think it is, because under the hood we know that userp was infer_payload.get(), where infer_payload was a shared_ptr. But it might be a little sketchy and prone to future error if the caller changes the type from shared_ptr to something like unique_ptr or a raw pointer in the future. If this is indeed correct and we don't have a clean alternative, I think we should document the constraints with a comment or two here.

For reference, from cppreference on std::enable_shared_from_this:

> It is permitted to call shared_from_this only on a previously shared object, i.e. on an object managed by std::shared_ptr. Otherwise the behavior is undefined (until C++17), or std::bad_weak_ptr is thrown by the shared_ptr constructor from a default-constructed weak_this (since C++17).
>
> enable_shared_from_this provides the safe alternative to an expression like std::shared_ptr(this), which is likely to result in this being destructed more than once by multiple owners that are unaware of each other.
Contributor commented:

Also, can you summarize in words the current issue and what this PR is solving?

The issue appears to be some lifetime issues passing the payload back and forth, and it looks like we're trying to solve this by passing a shared_ptr through (C++ <-> C API <-> C++) to automatically manage the lifetime on both sides.

Contributor Author commented:

Updated the PR description based on my understanding. @Tabrizian Please make any changes if needed.

Contributor commented:

> Also, can you summarize in words the current issue and what this PR is solving?
>
> The issue appears to be some lifetime issues passing the payload back and forth, and it looks like we're trying to solve this by passing a shared_ptr through (C++ <-> C API <-> C++) to automatically manage the lifetime on both sides.

Given the current design of the infer_payload being created as a shared pointer by python_be and then passed to request_executor, I believe this is the cleanest approach. The basic problem was that the lifetime was managed by the shared pointer, but the callback wasn't using a shared pointer and thus never increased the reference count.

That being said, I think it would be cleaner if we could refactor so that request_executor manages the creation and lifetime of infer_payload (at first glance, python_be doesn't seem to need to hang on to a reference, just the future). The creation and destruction would ideally be handled entirely inside request_executor.

Contributor commented:

> Is it valid to call raw_ptr->GetPtr() here?
>
> I think it is, because under the hood we know that userp was infer_payload.get(), where infer_payload was a shared_ptr, but it might be a little sketchy and prone to future error if the caller changes the type in the future. If this is indeed correct and we don't have a clean alternative, I think we should document the constraints with a comment or two here.

@Tabrizian also had an alternative that involved dynamically allocating a shared pointer, which would be passed instead of the raw infer_payload. That method also maintains the lifetime of the object. In that case we would have a pointer to a shared pointer pointing to an infer_payload. For me that was more confusing, but I'm open to reconsidering.

Member commented:

I'm inclined to use enable_shared_from_this since it is cleaner. Added a comment to warn the API user about this caveat.

Member commented:

Thanks @krishung5 for updating the description. Looks accurate to me!

   std::unique_ptr<InferResponse> infer_response;
   std::vector<std::shared_ptr<PbTensor>> output_tensors;
   std::shared_ptr<PbError> pb_error;
@@ -146,7 +147,7 @@ InferResponseComplete(
     output_tensors.clear();
   }
 
-  if (!p->IsDecoupled()) {
+  if (!infer_payload->IsDecoupled()) {
     infer_response = std::make_unique<InferResponse>(
         output_tensors, pb_error, true /* is_last_response */);
   } else {
@@ -167,7 +168,8 @@ InferResponseComplete(
         TRITONSERVER_InferenceResponseDelete(response),
         "Failed to release BLS inference response.");
   } else if (
-      p->IsDecoupled() && (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) {
+      (infer_payload)->IsDecoupled() &&
+      (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) {
     // An empty response may be the last response for decoupled models.
     infer_response = std::make_unique<InferResponse>(
         output_tensors, pb_error, true /* is_last_response */, userp /* id */);
@@ -177,13 +179,7 @@ InferResponseComplete(
         output_tensors, pb_error, true /* is_last_response */, userp /* id */);
   }
 
-  // Only set value to the promise with the first response. Call the callback
-  // function to send decoupled response to the stub.
-  if (p->IsPromiseSet()) {
-    p->Callback(std::move(infer_response));
-  } else {
-    p->SetValueForPrevPromise(std::move(infer_response));
-  }
+  infer_payload->SetValue(std::move(infer_response));
 }
 
 TRITONSERVER_Error*
@@ -339,8 +335,8 @@ RequestExecutor::Infer(
         std::string("Model ") + model_name +
             " is using the decoupled. The current BLS request call doesn't "
             "support models using the decoupled transaction policy. Please use "
-            "stream API 'stream_exec()' or 'async_stream_exec()' for decoupled "
-            "models.");
+            "the 'decoupled=True' argument to the 'exec' or 'async_exec' calls "
+            "for decoupled models.");
   }
 
   // Inference