From e5592accbb63f0ada62553db293d9d55afec4774 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 22 May 2024 10:09:15 -0700 Subject: [PATCH] Remove StreamExecutorInterface::WaitForEventOnExternalStream and replace with a proper virtual method on Event. PiperOrigin-RevId: 636207445 --- .../xla/xla/stream_executor/cuda/cuda_executor.cc | 12 ------------ third_party/xla/xla/stream_executor/event.cc | 4 ---- third_party/xla/xla/stream_executor/event.h | 6 ++++-- third_party/xla/xla/stream_executor/gpu/BUILD | 1 + .../xla/xla/stream_executor/gpu/gpu_event.cc | 14 ++++++++++++++ .../xla/xla/stream_executor/gpu/gpu_event.h | 4 ++++ .../xla/xla/stream_executor/gpu/gpu_executor.h | 3 --- .../xla/xla/stream_executor/mock_stream_executor.h | 2 -- .../xla/xla/stream_executor/rocm/rocm_executor.cc | 12 ------------ .../stream_executor/stream_executor_interface.h | 8 -------- 10 files changed, 23 insertions(+), 43 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 85d54dd84e1dd8..c589bfb0bc5e49 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -772,18 +772,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { } } -absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - if (GpuDriver::WaitStreamOnEvent(context_, - absl::bit_cast(stream), - AsGpuEvent(event)->gpu_event())) { - return absl::OkStatus(); - } else { - return absl::InternalError( - "error waiting for CUDA event on external stream"); - } -} - Event::Status GpuExecutor::PollForEventStatus(Event* event) { return AsGpuEvent(event)->PollForStatus(); } diff --git a/third_party/xla/xla/stream_executor/event.cc b/third_party/xla/xla/stream_executor/event.cc index 3de5e9045d40fb..3b2131f28995a0 100644 --- a/third_party/xla/xla/stream_executor/event.cc +++ b/third_party/xla/xla/stream_executor/event.cc @@ -32,8 +32,4 @@ Event::Status Event::PollForStatus() { return stream_exec_->PollForEventStatus(this); } -absl::Status Event::WaitForEventOnExternalStream(std::intptr_t stream) { - return stream_exec_->WaitForEventOnExternalStream(stream, this); -} - } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/event.h b/third_party/xla/xla/stream_executor/event.h index f3fee06dde4f1b..cf014e55220208 100644 --- a/third_party/xla/xla/stream_executor/event.h +++ b/third_party/xla/xla/stream_executor/event.h @@ -41,7 +41,7 @@ class Event { kComplete, }; - Event(StreamExecutorInterface* stream_exec); + explicit Event(StreamExecutorInterface* stream_exec); // Releases any resources held by the Event object. virtual ~Event() = default; @@ -51,7 +51,9 @@ class Event { // Blocks `stream` on this event. `stream` is a raw platform-specific // stream (e.g. GpuStreamHandle). - absl::Status WaitForEventOnExternalStream(std::intptr_t stream); + virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream) { + return absl::UnimplementedError("Not supported for this Event."); + } Event(Event&&) = default; Event& operator=(Event&&) = default; diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index eb5babd09bc664..d43863f01115dd 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -197,6 +197,7 @@ gpu_only_cc_library( ":gpu_stream", ":gpu_types_header", "//xla/stream_executor:stream_executor_headers", + "@com_google_absl//absl/base", "@com_google_absl//absl/status", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc index e1d078122212c2..4cd66783ea382c 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc @@ -15,7 +15,11 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" +#include + +#include "absl/base/casts.h" #include "absl/status/status.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" @@ -45,5 +49,15 @@ absl::Status GpuEvent::Record(GpuStream* stream) { GpuEventHandle GpuEvent::gpu_event() { return gpu_event_; } +absl::Status GpuEvent::WaitForEventOnExternalStream(std::intptr_t stream) { + if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), + absl::bit_cast(stream), + gpu_event_)) { + return absl::OkStatus(); + } else { + return absl::InternalError("Error waiting for event on external stream"); + } +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.h b/third_party/xla/xla/stream_executor/gpu/gpu_event.h index 6574e50c426424..5ab851dfb60205 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_event.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ +#include + #include "absl/status/status.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_stream.h" @@ -47,6 +49,8 @@ class GpuEvent : public Event { // The underlying CUDA event element. GpuEventHandle gpu_event(); + absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; + private: // The Executor used to which this object and GpuEventHandle are bound. GpuExecutor* parent_; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index 65d85dc0a7ac85..5024220a1ab45d 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -237,9 +237,6 @@ class GpuExecutor : public StreamExecutor { absl::Status WaitForEvent(Stream* stream, Event* event) override; - absl::Status WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) override; - Event::Status PollForEventStatus(Event* event) override; absl::Status BlockHostUntilDone(Stream* stream) override; diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 8e11d3998748a5..ad2de24f6292b2 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -136,8 +136,6 @@ class MockStreamExecutor : public StreamExecutorInterface { (override)); MOCK_METHOD(absl::Status, WaitForEvent, (Stream * stream, Event* event), (override)); - MOCK_METHOD(absl::Status, WaitForEventOnExternalStream, - (std::intptr_t stream, Event* event), (override)); MOCK_METHOD(Event::Status, PollForEventStatus, (Event * event), (override)); MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); MOCK_METHOD(bool, CreateStreamDependency, (Stream * dependent, Stream* other), diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index 1003f659e231a9..60433afcbf9e3e 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -676,18 +676,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) { } } -absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - if (GpuDriver::WaitStreamOnEvent(context_, - absl::bit_cast(stream), - AsGpuEvent(event)->gpu_event())) { - return absl::OkStatus(); - } else { - return absl::InternalError( - "error waiting for ROCM event on external stream"); - } -} - Event::Status GpuExecutor::PollForEventStatus(Event* event) { return AsGpuEvent(event)->PollForStatus(); } diff --git a/third_party/xla/xla/stream_executor/stream_executor_interface.h b/third_party/xla/xla/stream_executor/stream_executor_interface.h index 982a32ecd8006f..a6a5eb4c0e313d 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_interface.h +++ b/third_party/xla/xla/stream_executor/stream_executor_interface.h @@ -269,14 +269,6 @@ class StreamExecutorInterface { // Waits for the specified event at the end of the specified stream. virtual absl::Status WaitForEvent(Stream* stream, Event* event) = 0; - // Waits for the specified event at the end of the raw platform-specific - // stream. - virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream, - Event* event) { - return absl::UnimplementedError( - "WaitForEventOnExternalStream not supported on this executor."); - } - // Requests the current status of the event from the underlying platform. virtual Event::Status PollForEventStatus(Event* event) = 0;