Skip to content

Commit

Permalink
[pjrt] Switch GetReadyFuture to PjRtFuture<> to signal completion event
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623347808
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 12, 2024
1 parent 42e7261 commit 92b57df
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 81 deletions.
5 changes: 2 additions & 3 deletions third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Expand Up @@ -1826,9 +1826,8 @@ PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_Buffer_ReadyEvent_Args", PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE,
args->struct_size));
xla::PjRtFuture<absl::Status> wrapped_promise =
args->buffer->buffer->GetReadyFuture();
args->event = new PJRT_Event{std::move(wrapped_promise)};
xla::PjRtFuture<> wrapped_promise = args->buffer->buffer->GetReadyFuture();
args->event = new PJRT_Event{std::move(wrapped_promise).ToStatusFuture()};
return nullptr;
}

Expand Down
22 changes: 10 additions & 12 deletions third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc
Expand Up @@ -533,12 +533,12 @@ AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) {
std::move(dst_definition_events));
}

PjRtFuture<Status> AbstractTfrtCpuBuffer::GetReadyFuture() {
PjRtFuture<> AbstractTfrtCpuBuffer::GetReadyFuture() {
tsl::AsyncValueRef<CpuEvent> definition_event;
{
absl::MutexLock lock(&mu_);
if (!tracked_device_buffer_) {
return PjRtFuture<Status>(InvalidArgument(
return PjRtFuture<>(InvalidArgument(
"GetReadyFuture() called on deleted or donated buffer"));
}
definition_event = tracked_device_buffer_->definition_event();
Expand All @@ -547,29 +547,27 @@ PjRtFuture<Status> AbstractTfrtCpuBuffer::GetReadyFuture() {

if (definition_event.IsAvailable()) {
if (definition_event.IsError()) {
return PjRtFuture<Status>(
return PjRtFuture<>(
FailedPrecondition("Buffer Definition Event: %s",
definition_event.GetError().message()));
}
return PjRtFuture<Status>(OkStatus());
return PjRtFuture<>(OkStatus());
} else {
tsl::AsyncValueRef<Status> status_event =
tsl::MakeUnconstructedAsyncValueRef<Status>();

PjRtFuture<>::Promise promise = PjRtFuture<>::CreatePromise();
definition_event.AndThen(
[definition_event = definition_event.AsPtr(), status_event]() {
[definition_event = definition_event.AsPtr(), promise]() mutable {
if (definition_event.IsError()) {
status_event.emplace(
promise.SetError(
FailedPrecondition("Buffer Definition Event: %s",
definition_event.GetError().message()));
} else {
status_event.emplace(OkStatus());
promise.Set();
}
});

std::string message = absl::StrCat(buffer_name(), "::Await");
return PjRtFuture<Status>(
std::move(status_event),
return PjRtFuture<>(
std::move(promise),
/*on_block_start=*/
[message]() {
absl::string_view message_view(message);
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h
Expand Up @@ -140,7 +140,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer {
}
}

PjRtFuture<Status> GetReadyFuture() override;
PjRtFuture<> GetReadyFuture() override;

bool IsOnCpu() const override { return true; }

Expand Down
16 changes: 10 additions & 6 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.cc
Expand Up @@ -2032,7 +2032,11 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() {
args.user_arg = new std::function<void(PJRT_Error*)>(
[promise = readiness_promise_, api](PJRT_Error* error) -> void {
Status status = ::pjrt::PjrtErrorToStatus(error, api);
promise->Set(status);
if (status.ok()) {
promise->Set();
} else {
promise->SetError(status);
}
::pjrt::MakeErrorDeleter(api)(error);
});
args.callback = [](PJRT_Error* error, void* callback_ptr) {
Expand All @@ -2046,17 +2050,17 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() {
std::unique_ptr<PJRT_Error, ::pjrt::PJRT_ErrorDeleter> error{
api->PJRT_Event_OnReady(&args), ::pjrt::MakeErrorDeleter(api)};
if (error != nullptr) {
readiness_promise_->Set(::pjrt::PjrtErrorToStatus(error.get(), api));
readiness_promise_->SetError(::pjrt::PjrtErrorToStatus(error.get(), api));
}
}

PjRtFuture<Status> PjRtCApiBuffer::GetReadyFuture() {
PjRtFuture<> PjRtCApiBuffer::GetReadyFuture() {
if (readiness_promise_ == nullptr) {
readiness_promise_ = std::make_shared<PjRtFuture<Status>::Promise>(
PjRtFuture<Status>::CreatePromise());
readiness_promise_ =
std::make_shared<PjRtFuture<>::Promise>(PjRtFuture<>::CreatePromise());
MakePromiseTrackEvent();
}
return PjRtFuture<Status>{*readiness_promise_};
return PjRtFuture<>{*readiness_promise_};
}

StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.h
Expand Up @@ -531,7 +531,7 @@ class PjRtCApiBuffer : public PjRtBuffer {
LOG(ERROR) << "PJRT C API does not support CopyToRemoteDeviceScattered";
}

PjRtFuture<Status> GetReadyFuture() override;
PjRtFuture<> GetReadyFuture() override;

bool IsOnCpu() const override;

Expand All @@ -555,7 +555,7 @@ class PjRtCApiBuffer : public PjRtBuffer {
// This is a shared_ptr to keep the underlying future alive even if
// `readiness_promise` is destroyed before `readiness_event`, and the callback
// we set on `readiness_event` modifies `readiness_promise_`.
std::shared_ptr<PjRtFuture<Status>::Promise> readiness_promise_;
std::shared_ptr<PjRtFuture<>::Promise> readiness_promise_;
// Set and cached the first time layout() is called.
mutable std::optional<PjRtXlaLayout> layout_;
// Set and cached the first time is_dynamic_dimension() is called.
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/pjrt_client.h
Expand Up @@ -1385,7 +1385,7 @@ class PjRtBuffer {
// the buffer has been deleted or donated then the returned future will stay
// valid (will not transition to error as a consequence of buffer deletion)
// even if the buffer is subsequently donated or deleted.
virtual PjRtFuture<Status> GetReadyFuture() = 0;
virtual PjRtFuture<> GetReadyFuture() = 0;

// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
Expand Down
106 changes: 58 additions & 48 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
Expand Up @@ -1895,75 +1895,85 @@ void PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered(
}
}

PjRtFuture<Status> PjRtStreamExecutorBuffer::GetReadyFuture() {
PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() {
std::shared_ptr<TrackedDeviceBuffer> device_buffer;
PjRtFuture<Status>::Promise definition_promise;
PjRtFuture<>::Promise definition_promise;
{
absl::MutexLock lock(&mu_);
if (device_buffer_ == nullptr) {
return PjRtFuture<Status>(InvalidArgument(
return PjRtFuture<>(InvalidArgument(
"GetReadyFuture() called on deleted or donated buffer"));
}
if (!definition_promise_) {
device_buffer = device_buffer_;
definition_promise_ = PjRtFuture<Status>::CreatePromise();
definition_promise_ = PjRtFuture<>::CreatePromise();
}
definition_promise = definition_promise_;
}

if (device_buffer) {
LocalDeviceState* local_device_state = device_->local_device_state();
auto async_wait_for_events =
[device_buffer, local_device_state = std::move(local_device_state),
definition_promise]() mutable {
std::unique_ptr<se::Stream> stream;
Status defined_status =
device_buffer->definition_events()[0]->GetDefinedStatus();
if (!defined_status.ok()) {
definition_promise.Set(defined_status);
return;
}
for (auto& event : device_buffer->definition_events()) {
if (!event->IsComplete()) {
if (stream == nullptr) {
stream = local_device_state->BorrowStreamFromPool();
}
event->WaitForEventOnStream(stream.get());
}
auto async_wait_for_events = [device_buffer,
local_device_state =
std::move(local_device_state),
definition_promise]() mutable {
std::unique_ptr<se::Stream> stream;
Status defined_status =
device_buffer->definition_events()[0]->GetDefinedStatus();
if (!defined_status.ok()) {
definition_promise.SetError(defined_status);
return;
}
for (auto& event : device_buffer->definition_events()) {
if (!event->IsComplete()) {
if (stream == nullptr) {
stream = local_device_state->BorrowStreamFromPool();
}
event->WaitForEventOnStream(stream.get());
}
}

if (stream != nullptr) {
auto* stream_ptr = stream.release();
// We already borrowed a stream from the pool so we can safely do
// the callback directly on that stream instead of bouncing through
// local_device_state->ThenExecuteCallback. The direct callback
// saves significant time.
auto status = stream_ptr->DoHostCallback(
[definition_promise, stream_ptr, local_device_state,
event_with_status =
device_buffer->definition_events()[0]]() mutable {
local_device_state->ReturnStreamToPool(
std::unique_ptr<se::Stream>(stream_ptr));
definition_promise.Set(event_with_status->GetDefinedStatus());
});
if (!status.ok()) {
definition_promise.Set(status);
return;
}
} else {
// All events are already complete; set the `definition_promise`
// with the status of the buffer's first definition event which may
// have error status to propagate.
definition_promise.Set(
device_buffer->definition_events()[0]->GetDefinedStatus());
}
};
if (stream != nullptr) {
auto* stream_ptr = stream.release();
// We already borrowed a stream from the pool so we can safely do
// the callback directly on that stream instead of bouncing through
// local_device_state->ThenExecuteCallback. The direct callback
// saves significant time.
auto status = stream_ptr->DoHostCallback(
[definition_promise, stream_ptr, local_device_state,
event_with_status =
device_buffer->definition_events()[0]]() mutable {
local_device_state->ReturnStreamToPool(
std::unique_ptr<se::Stream>(stream_ptr));
auto status = event_with_status->GetDefinedStatus();
if (status.ok()) {
definition_promise.Set();
} else {
definition_promise.SetError(status);
}
});
if (!status.ok()) {
definition_promise.SetError(status);
return;
}
} else {
// All events are already complete; set the `definition_promise`
// with the status of the buffer's first definition event which may
// have error status to propagate.
auto status = device_buffer->definition_events()[0]->GetDefinedStatus();
if (status.ok()) {
definition_promise.Set();
} else {
definition_promise.SetError(status);
}
}
};
device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks(
absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events),
std::move(async_wait_for_events));
}

return PjRtFuture<Status>(
return PjRtFuture<>(
std::move(definition_promise),
/*on_block_start=*/
[]() {
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.h
Expand Up @@ -723,7 +723,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
std::vector<RemoteSendCallback> callbacks,
const ScatterDetails& scatter_details) override;

PjRtFuture<Status> GetReadyFuture() override;
PjRtFuture<> GetReadyFuture() override;

bool IsOnCpu() const override;

Expand Down Expand Up @@ -803,7 +803,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ ABSL_GUARDED_BY(mu_);
// Count of holds on the buffer.
std::array<int, ScopedHold::Type::kMaxValue> holds_ ABSL_GUARDED_BY(mu_);
PjRtFuture<Status>::Promise definition_promise_ ABSL_GUARDED_BY(mu_);
PjRtFuture<>::Promise definition_promise_ ABSL_GUARDED_BY(mu_);
};

// Wraps one or more XLA LocalExecutables (one per partition, as specified by
Expand Down
6 changes: 2 additions & 4 deletions third_party/xla/xla/pjrt/tf_pjrt_client.h
Expand Up @@ -99,9 +99,7 @@ class TfPjRtBuffer : public PjRtBuffer {
std::move(serialized_descriptors), std::move(callbacks),
scatter_details);
}
PjRtFuture<Status> GetReadyFuture() override {
return wrapped_->GetReadyFuture();
}
PjRtFuture<> GetReadyFuture() override { return wrapped_->GetReadyFuture(); }
bool IsOnCpu() const override { return wrapped_->IsOnCpu(); }

// Not thread-safe. The caller should promises to have some external
Expand Down Expand Up @@ -225,7 +223,7 @@ class TfPjRtClient : public PjRtClient {
PjRtLocalDeviceId local_device_id) const override {
if (wrapped_ == nullptr) {
return tsl::errors::Internal(
"Wrapped PJRT client in TfPjRtClient is already destoryed.");
"Wrapped PJRT client in TfPjRtClient is already destroyed.");
}
return wrapped_->LookupAddressableDevice(local_device_id);
}
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc
Expand Up @@ -524,12 +524,12 @@ absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Reshard(
Future<Status> PjRtArray::GetReadyFuture() const {
DCHECK(this);
if (pjrt_buffers_.size() == 1) {
return pjrt_buffers_.front()->GetReadyFuture();
return pjrt_buffers_.front()->GetReadyFuture().ToStatusFuture();
}
std::vector<Future<Status>> futures;
futures.reserve(pjrt_buffers_.size());
for (auto& buf : pjrt_buffers_) {
futures.push_back(buf->GetReadyFuture());
futures.push_back(buf->GetReadyFuture().ToStatusFuture());
}
return JoinFutures(absl::MakeSpan(futures));
}
Expand Down

0 comments on commit 92b57df

Please sign in to comment.