Skip to content

Commit

Permalink
Remove StreamExecutorInterface::WaitForEventOnExternalStream and repl…
Browse files Browse the repository at this point in the history
…ace with a proper virtual method on Event.

PiperOrigin-RevId: 635600740
  • Loading branch information
klucke authored and tensorflower-gardener committed May 22, 2024
1 parent 6dd6123 commit 3ccf550
Show file tree
Hide file tree
Showing 26 changed files with 80 additions and 172 deletions.
1 change: 0 additions & 1 deletion tensorflow/c/experimental/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ cc_library(
"//tensorflow/c:tf_status_helper",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/stream_executor",
"@local_xla//xla/stream_executor:event_interface",
"@local_xla//xla/stream_executor:stream_executor_interface",
"@local_xla//xla/stream_executor:stream_interface",
],
Expand Down
15 changes: 6 additions & 9 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,22 +410,20 @@ class CStreamExecutor : public StreamExecutor {
absl::Status RecordEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
return static_cast<CEvent*>(event)->Record(stream_handle);
}
absl::Status WaitForEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
c_status.get());
absl::Status s = StatusFromTF_Status(c_status.get());
return s;
}
Event::Status PollForEventStatus(Event* event) override {
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
SE_EventStatus event_status =
stream_executor_->get_event_status(&device_, event_handle);
return SEEventStatusToEventStatus(event_status);
Expand All @@ -449,8 +447,7 @@ class CStreamExecutor : public StreamExecutor {
}
absl::Status BlockHostForEvent(Stream* stream, Event* event) {
OwnedTFStatus c_status(TF_NewStatus());
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
stream_executor_->block_host_for_event(&device_, event_handle,
c_status.get());
return StatusFromTF_Status(c_status.get());
Expand Down Expand Up @@ -544,9 +541,9 @@ class CStreamExecutor : public StreamExecutor {
}

absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_);
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_, this);
TF_RETURN_IF_ERROR(c_event->Create());
return std::make_unique<Event>(this, std::move(c_event));
return std::move(c_event);
}

absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "xla/stream_executor/event_interface.h"
#include "xla/stream_executor/executor_cache.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
Expand Down Expand Up @@ -128,10 +127,12 @@ class CStream : public StreamInterface {
SP_Stream stream_handle_;
};

class CEvent : public EventInterface {
class CEvent : public Event {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutorInterface* executor_interface)
: Event(executor_interface),
device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/backends/interpreter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ cc_library(
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla/stream_executor",
"//xla/stream_executor:event_interface",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/host:host_stream",
"@com_google_absl//absl/functional:any_invocable",
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/backends/interpreter/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_interface.h"
#include "xla/stream_executor/host/host_stream.h"
#include "xla/stream_executor/host_memory_allocation.h"
#include "xla/stream_executor/kernel.h"
Expand Down Expand Up @@ -151,7 +150,7 @@ class XlaInterpreterExecutor : public StreamExecutor {
return true;
}
absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
return std::make_unique<Event>(this, nullptr);
return std::make_unique<Event>(this);
}

absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
Expand Down
8 changes: 4 additions & 4 deletions third_party/xla/xla/service/gpu/runtime/copy_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ absl::Status CopyThunk::AsyncEvents::Emplace(se::StreamExecutor* executor,
std::unique_ptr<se::Event> event) {
Key key = {executor, instr};
absl::MutexLock lock(&mutex_);
VLOG(3) << "Emplace event " << event->implementation();
VLOG(3) << "Emplace event " << event.get();
if (auto [it, inserted] = events_.try_emplace(key, std::move(event));
inserted) {
return absl::OkStatus();
Expand All @@ -97,7 +97,7 @@ absl::StatusOr<std::unique_ptr<se::Event>> CopyThunk::AsyncEvents::Extract(
Key key = {executor, instr};
absl::MutexLock lock(&mutex_);
if (auto event = events_.extract(key)) {
VLOG(3) << "Extract event " << event.mapped()->implementation();
VLOG(3) << "Extract event " << event.mapped().get();
return std::move(event.mapped());
}
return absl::InternalError("Async copy event was not found!");
Expand Down Expand Up @@ -136,7 +136,7 @@ absl::Status DeviceToHostCopyThunk::ExecuteOnStream(
TF_ASSIGN_OR_RETURN(auto event, executor->CreateEvent());
// Record memcpy operation completion.
TF_RETURN_IF_ERROR(stream->RecordEvent(event.get()));
VLOG(3) << "Emplace events: " << event->implementation()
VLOG(3) << "Emplace events: " << event.get()
<< " for instr: " << instr_->ToString();
return async_events_->Emplace(executor, instr_, std::move(event));
}
Expand Down Expand Up @@ -174,7 +174,7 @@ absl::Status HostToDeviceCopyThunk::ExecuteOnStream(
TF_ASSIGN_OR_RETURN(auto event, executor->CreateEvent());
// Record memcpy operation completion.
TF_RETURN_IF_ERROR(stream->RecordEvent(event.get()));
VLOG(3) << "Emplace events: " << event->implementation()
VLOG(3) << "Emplace events: " << event.get()
<< " for instr: " << instr_->ToString();
return async_events_->Emplace(executor, instr_, std::move(event));
}
Expand Down
9 changes: 0 additions & 9 deletions third_party/xla/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ cc_library(
deps = [
":device_description",
":device_memory",
":event_interface",
":kernel_spec",
":module_spec",
":stream_executor_headers",
Expand Down Expand Up @@ -487,13 +486,6 @@ cc_library(
],
)

cc_library(
name = "event_interface",
hdrs = ["event_interface.h"],
deps = [
],
)

cc_library(
name = "stream_interface",
hdrs = ["stream_interface.h"],
Expand Down Expand Up @@ -677,7 +669,6 @@ cc_library(
":blas", # build_cleaner: keep
":command_buffer", # build_cleaner: keep
":dnn", # build_cleaner: keep
":event_interface",
":fft",
":host_memory_allocation", # build_cleaner: keep
":kernel_spec",
Expand Down
16 changes: 2 additions & 14 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ namespace gpu {

static GpuEvent* AsGpuEvent(Event* event) {
DCHECK(event != nullptr);
return static_cast<GpuEvent*>(event->implementation());
return static_cast<GpuEvent*>(event);
}

// Given const GPU memory, returns a libcuda device pointer datatype, suitable
Expand Down Expand Up @@ -772,18 +772,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) {
}
}

absl::Status GpuExecutor::WaitForEventOnExternalStream(std::intptr_t stream,
Event* event) {
if (GpuDriver::WaitStreamOnEvent(context_,
absl::bit_cast<GpuStreamHandle>(stream),
AsGpuEvent(event)->gpu_event())) {
return absl::OkStatus();
} else {
return absl::InternalError(
"error waiting for CUDA event on external stream");
}
}

Event::Status GpuExecutor::PollForEventStatus(Event* event) {
return AsGpuEvent(event)->PollForStatus();
}
Expand Down Expand Up @@ -945,7 +933,7 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device,
absl::StatusOr<std::unique_ptr<Event>> GpuExecutor::CreateEvent() {
auto gpu_event = std::make_unique<GpuEvent>(this);
TF_RETURN_IF_ERROR(gpu_event->Init());
return std::make_unique<Event>(this, std::move(gpu_event));
return std::move(gpu_event);
}

absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
Expand Down
15 changes: 2 additions & 13 deletions third_party/xla/xla/stream_executor/event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,15 @@ limitations under the License.

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "xla/stream_executor/event_interface.h"
#include "xla/stream_executor/stream_executor_interface.h"

namespace stream_executor {

Event::Event(StreamExecutorInterface* stream_exec,
std::unique_ptr<EventInterface> implementation)
: stream_exec_(stream_exec), implementation_(std::move(implementation)) {}

Event::~Event() = default;

Event::Event(Event&&) = default;
Event& Event::operator=(Event&&) = default;
Event::Event(StreamExecutorInterface* stream_exec)
: stream_exec_(stream_exec) {}

Event::Status Event::PollForStatus() {
return stream_exec_->PollForEventStatus(this);
}

absl::Status Event::WaitForEventOnExternalStream(std::intptr_t stream) {
return stream_exec_->WaitForEventOnExternalStream(stream, this);
}

} // namespace stream_executor
22 changes: 7 additions & 15 deletions third_party/xla/xla/stream_executor/event.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ limitations under the License.
#define XLA_STREAM_EXECUTOR_EVENT_H_

#include <cstdint>
#include <memory>

#include "absl/status/status.h"

namespace stream_executor {

class EventInterface;
class StreamExecutorInterface;

// The Event class, when supported by a platform, enables low-overhead status
Expand All @@ -43,34 +41,28 @@ class Event {
kComplete,
};

Event(StreamExecutorInterface* stream_exec,
std::unique_ptr<EventInterface> implementation);
explicit Event(StreamExecutorInterface* stream_exec);

// Releases any resources held by the Event object.
~Event();
virtual ~Event() = default;

// Returns the current Status for the event.
Status PollForStatus();

// Blocks `stream` on this event. `stream` is a raw platform-specific
// stream (e.g. GpuStreamHandle).
absl::Status WaitForEventOnExternalStream(std::intptr_t stream);
virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream) {
return absl::UnimplementedError("Not supported for this Event.");
}

// Returns a pointer to the underlying platform-specific implementation.
EventInterface* implementation() { return implementation_.get(); }

Event(Event&&);
Event& operator=(Event&&);
Event(Event&&) = default;
Event& operator=(Event&&) = default;

private:
// Pointer to the StreamExecutorInterface interface used to create this
// object. Not owned.
StreamExecutorInterface* stream_exec_;

// Pointer to the platform-specific EventInterface implementation underlying
// the object. Owned.
std::unique_ptr<EventInterface> implementation_;

Event(const Event&) = delete;
void operator=(const Event&) = delete;
};
Expand Down
34 changes: 0 additions & 34 deletions third_party/xla/xla/stream_executor/event_interface.h

This file was deleted.

7 changes: 3 additions & 4 deletions third_party/xla/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ gpu_only_cc_library(
deps = [
":gpu_stream_header",
":gpu_types_header",
"//xla/stream_executor:event_interface",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor:stream_executor_headers",
"@com_google_absl//absl/status",
],
)
Expand All @@ -197,8 +196,8 @@ gpu_only_cc_library(
":gpu_executor_header",
":gpu_stream",
":gpu_types_header",
"//xla/stream_executor:event_interface",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor:stream_executor_headers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/status",
],
)
Expand Down
16 changes: 15 additions & 1 deletion third_party/xla/xla/stream_executor/gpu/gpu_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ limitations under the License.

#include "xla/stream_executor/gpu/gpu_event.h"

#include <cstdint>

#include "absl/base/casts.h"
#include "absl/status/status.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
Expand All @@ -25,7 +29,7 @@ namespace stream_executor {
namespace gpu {

GpuEvent::GpuEvent(GpuExecutor* parent)
: parent_(parent), gpu_event_(nullptr) {}
: Event(parent), parent_(parent), gpu_event_(nullptr) {}

GpuEvent::~GpuEvent() { Destroy().IgnoreError(); }

Expand All @@ -45,5 +49,15 @@ absl::Status GpuEvent::Record(GpuStream* stream) {

GpuEventHandle GpuEvent::gpu_event() { return gpu_event_; }

absl::Status GpuEvent::WaitForEventOnExternalStream(std::intptr_t stream) {
if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(),
absl::bit_cast<GpuStreamHandle>(stream),
gpu_event_)) {
return absl::OkStatus();
} else {
return absl::InternalError("Error waiting for event on external stream");
}
}

} // namespace gpu
} // namespace stream_executor
Loading

0 comments on commit 3ccf550

Please sign in to comment.