diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 65c579454b14de..53a2d060aca019 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -1826,9 +1826,8 @@ PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Buffer_ReadyEvent_Args", PJRT_Buffer_ReadyEvent_Args_STRUCT_SIZE, args->struct_size)); - xla::PjRtFuture wrapped_promise = - args->buffer->buffer->GetReadyFuture(); - args->event = new PJRT_Event{std::move(wrapped_promise)}; + xla::PjRtFuture<> wrapped_promise = args->buffer->buffer->GetReadyFuture(); + args->event = new PJRT_Event{std::move(wrapped_promise).ToStatusFuture()}; return nullptr; } diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc index 9ac6270dded067..1c073f3f86a935 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.cc @@ -533,12 +533,12 @@ AbstractTfrtCpuBuffer::CopyToDeviceHelper(AsyncWorkRunner* async_work_runner) { std::move(dst_definition_events)); } -PjRtFuture AbstractTfrtCpuBuffer::GetReadyFuture() { +PjRtFuture<> AbstractTfrtCpuBuffer::GetReadyFuture() { tsl::AsyncValueRef definition_event; { absl::MutexLock lock(&mu_); if (!tracked_device_buffer_) { - return PjRtFuture(InvalidArgument( + return PjRtFuture<>(InvalidArgument( "GetReadyFuture() called on deleted or donated buffer")); } definition_event = tracked_device_buffer_->definition_event(); @@ -547,29 +547,27 @@ PjRtFuture AbstractTfrtCpuBuffer::GetReadyFuture() { if (definition_event.IsAvailable()) { if (definition_event.IsError()) { - return PjRtFuture( + return PjRtFuture<>( FailedPrecondition("Buffer Definition Event: %s", definition_event.GetError().message())); } - return PjRtFuture(OkStatus()); + return PjRtFuture<>(OkStatus()); } else { - tsl::AsyncValueRef status_event = - tsl::MakeUnconstructedAsyncValueRef(); - + PjRtFuture<>::Promise promise = PjRtFuture<>::CreatePromise(); definition_event.AndThen( - [definition_event = definition_event.AsPtr(), status_event]() { + [definition_event = definition_event.AsPtr(), promise]() mutable { if (definition_event.IsError()) { - status_event.emplace( + promise.SetError( FailedPrecondition("Buffer Definition Event: %s", definition_event.GetError().message())); } else { - status_event.emplace(OkStatus()); + promise.Set(); } }); std::string message = absl::StrCat(buffer_name(), "::Await"); - return PjRtFuture( - std::move(status_event), + return PjRtFuture<>( + std::move(promise), /*on_block_start=*/ [message]() { absl::string_view message_view(message); diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index b7a15a8fb95af6..113b3fa32f00cd 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -140,7 +140,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { } } - PjRtFuture GetReadyFuture() override; + PjRtFuture<> GetReadyFuture() override; bool IsOnCpu() const override { return true; } diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index cd0f4bcb1424e0..8e8b65619ca0b8 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -2032,7 +2032,11 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() { args.user_arg = new std::function( [promise = readiness_promise_, api](PJRT_Error* error) -> void { Status status = ::pjrt::PjrtErrorToStatus(error, api); - promise->Set(status); + if (status.ok()) { + promise->Set(); + } else { + promise->SetError(status); + } ::pjrt::MakeErrorDeleter(api)(error); }); args.callback = [](PJRT_Error* error, void* callback_ptr) { @@ -2046,17 +2050,17 @@ void PjRtCApiBuffer::MakePromiseTrackEvent() { std::unique_ptr error{ api->PJRT_Event_OnReady(&args), ::pjrt::MakeErrorDeleter(api)}; if (error != nullptr) { - readiness_promise_->Set(::pjrt::PjrtErrorToStatus(error.get(), api)); + readiness_promise_->SetError(::pjrt::PjrtErrorToStatus(error.get(), api)); } } -PjRtFuture PjRtCApiBuffer::GetReadyFuture() { +PjRtFuture<> PjRtCApiBuffer::GetReadyFuture() { if (readiness_promise_ == nullptr) { - readiness_promise_ = std::make_shared::Promise>( - PjRtFuture::CreatePromise()); + readiness_promise_ = + std::make_shared::Promise>(PjRtFuture<>::CreatePromise()); MakePromiseTrackEvent(); } - return PjRtFuture{*readiness_promise_}; + return PjRtFuture<>{*readiness_promise_}; } StatusOr> diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 97b5a616e7f77d..9c117775d9e749 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -531,7 +531,7 @@ class PjRtCApiBuffer : public PjRtBuffer { LOG(ERROR) << "PJRT C API does not support CopyToRemoteDeviceScattered"; } - PjRtFuture GetReadyFuture() override; + PjRtFuture<> GetReadyFuture() override; bool IsOnCpu() const override; @@ -555,7 +555,7 @@ class PjRtCApiBuffer : public PjRtBuffer { // This is a shared_ptr to keep the underlying future alive even if // `readiness_promise` is destroyed before `readiness_event`, and the callback // we set on `readiness_event` modifies `readiness_promise_`. - std::shared_ptr::Promise> readiness_promise_; + std::shared_ptr::Promise> readiness_promise_; // Set and cached the first time layout() is called. mutable std::optional layout_; // Set and cached the first time is_dynamic_dimension() is called. diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index d50e465fcd8af9..551b8ce83a1a93 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -1385,7 +1385,7 @@ class PjRtBuffer { // the buffer has been deleted or donated then the returned future will stay // valid (will not transition to error as a consequence of buffer deletion) // even if the buffer is subsequently donated or deleted. - virtual PjRtFuture GetReadyFuture() = 0; + virtual PjRtFuture<> GetReadyFuture() = 0; // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index cb2c1fa0fa6874..6d88a51ffd6a8a 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -1895,75 +1895,85 @@ void PjRtStreamExecutorBuffer::CopyToRemoteDeviceScattered( } } -PjRtFuture PjRtStreamExecutorBuffer::GetReadyFuture() { +PjRtFuture<> PjRtStreamExecutorBuffer::GetReadyFuture() { std::shared_ptr device_buffer; - PjRtFuture::Promise definition_promise; + PjRtFuture<>::Promise definition_promise; { absl::MutexLock lock(&mu_); if (device_buffer_ == nullptr) { - return PjRtFuture(InvalidArgument( + return PjRtFuture<>(InvalidArgument( "GetReadyFuture() called on deleted or donated buffer")); } if (!definition_promise_) { device_buffer = device_buffer_; - definition_promise_ = PjRtFuture::CreatePromise(); + definition_promise_ = PjRtFuture<>::CreatePromise(); } definition_promise = definition_promise_; } if (device_buffer) { LocalDeviceState* local_device_state = device_->local_device_state(); - auto async_wait_for_events = - [device_buffer, local_device_state = std::move(local_device_state), - definition_promise]() mutable { - std::unique_ptr stream; - Status defined_status = - device_buffer->definition_events()[0]->GetDefinedStatus(); - if (!defined_status.ok()) { - definition_promise.Set(defined_status); - return; - } - for (auto& event : device_buffer->definition_events()) { - if (!event->IsComplete()) { - if (stream == nullptr) { - stream = local_device_state->BorrowStreamFromPool(); - } - event->WaitForEventOnStream(stream.get()); - } + auto async_wait_for_events = [device_buffer, + local_device_state = + std::move(local_device_state), + definition_promise]() mutable { + std::unique_ptr stream; + Status defined_status = + device_buffer->definition_events()[0]->GetDefinedStatus(); + if (!defined_status.ok()) { + definition_promise.SetError(defined_status); + return; + } + for (auto& event : device_buffer->definition_events()) { + if (!event->IsComplete()) { + if (stream == nullptr) { + stream = local_device_state->BorrowStreamFromPool(); } + event->WaitForEventOnStream(stream.get()); + } + } - if (stream != nullptr) { - auto* stream_ptr = stream.release(); - // We already borrowed a stream from the pool so we can safely do - // the callback directly on that stream instead of bouncing through - // local_device_state->ThenExecuteCallback. The direct callback - // saves significant time. - auto status = stream_ptr->DoHostCallback( - [definition_promise, stream_ptr, local_device_state, - event_with_status = - device_buffer->definition_events()[0]]() mutable { - local_device_state->ReturnStreamToPool( - std::unique_ptr(stream_ptr)); - definition_promise.Set(event_with_status->GetDefinedStatus()); - }); - if (!status.ok()) { - definition_promise.Set(status); - return; - } - } else { - // All events are already complete; set the `definition_promise` - // with the status of the buffer's first definition event which may - // have error status to propagate. - definition_promise.Set( - device_buffer->definition_events()[0]->GetDefinedStatus()); - } - }; + if (stream != nullptr) { + auto* stream_ptr = stream.release(); + // We already borrowed a stream from the pool so we can safely do + // the callback directly on that stream instead of bouncing through + // local_device_state->ThenExecuteCallback. The direct callback + // saves significant time. + auto status = stream_ptr->DoHostCallback( + [definition_promise, stream_ptr, local_device_state, + event_with_status = + device_buffer->definition_events()[0]]() mutable { + local_device_state->ReturnStreamToPool( + std::unique_ptr(stream_ptr)); + auto status = event_with_status->GetDefinedStatus(); + if (status.ok()) { + definition_promise.Set(); + } else { + definition_promise.SetError(status); + } + }); + if (!status.ok()) { + definition_promise.SetError(status); + return; + } + } else { + // All events are already complete; set the `definition_promise` + // with the status of the buffer's first definition event which may + // have error status to propagate. + auto status = device_buffer->definition_events()[0]->GetDefinedStatus(); + if (status.ok()) { + definition_promise.Set(); + } else { + definition_promise.SetError(status); + } + } + }; device_buffer->definition_events()[0]->ExecuteOrAddToFutureTasks( absl::StrFormat("async_wait_for_events_%p", &async_wait_for_events), std::move(async_wait_for_events)); } - return PjRtFuture( + return PjRtFuture<>( std::move(definition_promise), /*on_block_start=*/ []() { diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 3544b5d6f28140..57b93b150d2cdd 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -723,7 +723,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { std::vector callbacks, const ScatterDetails& scatter_details) override; - PjRtFuture GetReadyFuture() override; + PjRtFuture<> GetReadyFuture() override; bool IsOnCpu() const override; @@ -803,7 +803,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer { std::shared_ptr device_buffer_ ABSL_GUARDED_BY(mu_); // Count of holds on the buffer. std::array holds_ ABSL_GUARDED_BY(mu_); - PjRtFuture::Promise definition_promise_ ABSL_GUARDED_BY(mu_); + PjRtFuture<>::Promise definition_promise_ ABSL_GUARDED_BY(mu_); }; // Wraps one or more XLA LocalExecutables (one per partition, as specified by diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index 867cbecb8b39cb..671d079359b52e 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -99,9 +99,7 @@ class TfPjRtBuffer : public PjRtBuffer { std::move(serialized_descriptors), std::move(callbacks), scatter_details); } - PjRtFuture GetReadyFuture() override { - return wrapped_->GetReadyFuture(); - } + PjRtFuture<> GetReadyFuture() override { return wrapped_->GetReadyFuture(); } bool IsOnCpu() const override { return wrapped_->IsOnCpu(); } // Not thread-safe. The caller should promises to have some external @@ -225,7 +223,7 @@ class TfPjRtClient : public PjRtClient { PjRtLocalDeviceId local_device_id) const override { if (wrapped_ == nullptr) { return tsl::errors::Internal( - "Wrapped PJRT client in TfPjRtClient is already destoryed."); + "Wrapped PJRT client in TfPjRtClient is already destroyed."); } return wrapped_->LookupAddressableDevice(local_device_id); } diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 0510940c349ce2..68bc459684762b 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -524,12 +524,12 @@ absl::StatusOr> PjRtArray::Reshard( Future PjRtArray::GetReadyFuture() const { DCHECK(this); if (pjrt_buffers_.size() == 1) { - return pjrt_buffers_.front()->GetReadyFuture(); + return pjrt_buffers_.front()->GetReadyFuture().ToStatusFuture(); } std::vector> futures; futures.reserve(pjrt_buffers_.size()); for (auto& buf : pjrt_buffers_) { - futures.push_back(buf->GetReadyFuture()); + futures.push_back(buf->GetReadyFuture().ToStatusFuture()); } return JoinFutures(absl::MakeSpan(futures)); }