Skip to content

Commit

Permalink
Make Event::PollForStatus a virtual method, and override it where nec…
Browse files Browse the repository at this point in the history
…essary.

This enables Event to be a completely virtual base class, and removes it from circular dependencies.

PiperOrigin-RevId: 636235458
  • Loading branch information
klucke authored and tensorflower-gardener committed May 22, 2024
1 parent 74a9135 commit ac74514
Show file tree
Hide file tree
Showing 24 changed files with 73 additions and 120 deletions.
22 changes: 1 addition & 21 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,6 @@ absl::Status ValidateSEPlatformRegistrationParams(
}
#undef TF_VALIDATE_NOT_NULL

// Translates an SE_EventStatus value into the matching Event::Status
// enumerator. Any value other than SE_EVENT_ERROR, SE_EVENT_PENDING or
// SE_EVENT_COMPLETE is reported as Event::Status::kUnknown.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
  if (s == SE_EVENT_ERROR) {
    return Event::Status::kError;
  }
  if (s == SE_EVENT_PENDING) {
    return Event::Status::kPending;
  }
  if (s == SE_EVENT_COMPLETE) {
    return Event::Status::kComplete;
  }
  // Fallback for every status value not handled above.
  return Event::Status::kUnknown;
}

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
Expand Down Expand Up @@ -422,12 +408,6 @@ class CStreamExecutor : public StreamExecutor {
absl::Status s = StatusFromTF_Status(c_status.get());
return s;
}
// Returns the current status of `event`. The event is downcast to CEvent
// (no runtime type check), its handle is passed to
// stream_executor_->get_event_status together with the device, and the
// resulting SE_EventStatus is converted to Event::Status.
Event::Status PollForEventStatus(Event* event) override {
  const SP_Event handle = static_cast<CEvent*>(event)->Handle();
  return SEEventStatusToEventStatus(
      stream_executor_->get_event_status(&device_, handle));
}
// Destroys the CStream that backs `stream`. The stream's implementation
// pointer is downcast to CStream without a runtime type check.
void DeallocateStream(Stream* stream) override {
  CStream* const c_stream = static_cast<CStream*>(stream->implementation());
  c_stream->Destroy();
}
Expand Down Expand Up @@ -541,7 +521,7 @@ class CStreamExecutor : public StreamExecutor {
}

absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_, this);
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_);
TF_RETURN_IF_ERROR(c_event->Create());
return std::move(c_event);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,28 @@ class CStream : public StreamInterface {

class CEvent : public Event {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutorInterface* executor_interface)
: Event(executor_interface),
device_(device),
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }

// Queries stream_executor_->get_event_status for this event's handle on
// device_ and maps the returned SE_EventStatus to Event::Status. Values
// other than error/pending/complete are reported as kUnknown.
Event::Status PollForStatus() override {
  const SE_EventStatus raw_status =
      stream_executor_->get_event_status(device_, event_handle_);

  if (raw_status == SE_EVENT_ERROR) {
    return Event::Status::kError;
  }
  if (raw_status == SE_EVENT_PENDING) {
    return Event::Status::kPending;
  }
  if (raw_status == SE_EVENT_COMPLETE) {
    return Event::Status::kComplete;
  }
  return Event::Status::kUnknown;
}

absl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
Expand Down
6 changes: 1 addition & 5 deletions third_party/xla/xla/backends/interpreter/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ class XlaInterpreterExecutor : public StreamExecutor {
return absl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"};
}

// Always reports Event::Status::kError; `event` is not inspected.
Event::Status PollForEventStatus(Event *event) override {
return Event::Status::kError;
}

void DeallocateStream(Stream *stream) override {}
bool CreateStreamDependency(Stream *dependent, Stream *other) override;

Expand All @@ -150,7 +146,7 @@ class XlaInterpreterExecutor : public StreamExecutor {
return true;
}
absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
return std::make_unique<Event>(this);
return std::make_unique<Event>();
}

absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,6 @@ transitive_hdrs(
cc_library(
name = "stream_executor_pimpl",
srcs = [
"event.cc",
"stream.cc",
"stream_executor_pimpl.cc",
],
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/stream_executor/cuda/cuda_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_event.h"

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
Expand All @@ -24,9 +26,9 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

Event::Status GpuEvent::PollForStatus() {
Event::Status CudaEvent::PollForStatus() {
absl::StatusOr<CUresult> status =
QueryEvent(parent_->gpu_context(), gpu_event_);
QueryEvent(parent()->gpu_context(), gpu_event());
if (!status.ok()) {
LOG(ERROR) << "Error polling for event status: "
<< status.status().message();
Expand Down
16 changes: 11 additions & 5 deletions third_party/xla/xla/stream_executor/cuda/cuda_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
#define XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_

#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"

namespace stream_executor {
namespace cuda {
namespace stream_executor::gpu {

using CUDAEvent = gpu::GpuEvent;
// This class implements Event::PollForStatus for CUDA devices.
class CudaEvent : public GpuEvent {
public:
explicit CudaEvent(GpuExecutor *executor) : GpuEvent(executor) {}

} // namespace cuda
} // namespace stream_executor
Event::Status PollForStatus() override;
};

} // namespace stream_executor::gpu

#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
7 changes: 2 additions & 5 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_diagnostics.h"
#include "xla/stream_executor/cuda/cuda_driver.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/gpu/gpu_collectives.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
Expand Down Expand Up @@ -772,10 +773,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) {
}
}

// Converts `event` with AsGpuEvent() and returns the result of polling
// that GPU event for its status.
Event::Status GpuExecutor::PollForEventStatus(Event* event) {
  auto* gpu_event = AsGpuEvent(event);
  return gpu_event->PollForStatus();
}

void GpuExecutor::DeallocateStream(Stream* stream) {
{
absl::MutexLock lock(&mu_);
Expand Down Expand Up @@ -931,7 +928,7 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device,
}

absl::StatusOr<std::unique_ptr<Event>> GpuExecutor::CreateEvent() {
auto gpu_event = std::make_unique<GpuEvent>(this);
auto gpu_event = std::make_unique<CudaEvent>(this);
TF_RETURN_IF_ERROR(gpu_event->Init());
return std::move(gpu_event);
}
Expand Down
17 changes: 1 addition & 16 deletions third_party/xla/xla/stream_executor/event.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ limitations under the License.

namespace stream_executor {

class StreamExecutorInterface;

// The Event class, when supported by a platform, enables low-overhead status
// reporting for a Stream. An Event is inserted at a location in a stream via
// the Stream::RecordEvent() API. From then on, the Event's status can be
Expand All @@ -41,30 +39,17 @@ class Event {
kComplete,
};

explicit Event(StreamExecutorInterface* stream_exec);

// Releases any resources held by the Event object.
virtual ~Event() = default;

// Returns the current Status for the event.
Status PollForStatus();
virtual Status PollForStatus() { return Status::kError; }

// Blocks `stream` on this event. `stream` is a raw platform-specific
// stream (e.g. GpuStreamHandle).
//
// This default implementation does not use `stream`; it returns an
// Unimplemented error with the message "Not supported for this Event.".
// Being virtual, it can be overridden by derived event types.
virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream) {
return absl::UnimplementedError("Not supported for this Event.");
}

Event(Event&&) = default;
Event& operator=(Event&&) = default;

private:
// Pointer to the StreamExecutorInterface interface used to create this
// object. Not owned.
StreamExecutorInterface* stream_exec_;

Event(const Event&) = delete;
void operator=(const Event&) = delete;
};

} // namespace stream_executor
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/gpu/gpu_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace stream_executor {
namespace gpu {

GpuEvent::GpuEvent(GpuExecutor* parent)
: Event(parent), parent_(parent), gpu_event_(nullptr) {}
: parent_(parent), gpu_event_(nullptr) {}

// Calls Destroy() on destruction; the status it returns is explicitly
// discarded via IgnoreError().
GpuEvent::~GpuEvent() { Destroy().IgnoreError(); }

Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/stream_executor/gpu/gpu_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ class GpuEvent : public Event {
// Inserts the event at the current position into the specified stream.
absl::Status Record(GpuStream* stream);

// Polls the CUDA platform for the event's current status.
Event::Status PollForStatus();

// The underlying CUDA event element.
GpuEventHandle gpu_event();

absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override;

protected:
GpuExecutor* parent() const { return parent_; }

private:
// The Executor used to which this object and GpuEventHandle are bound.
GpuExecutor* parent_;
Expand Down
2 changes: 0 additions & 2 deletions third_party/xla/xla/stream_executor/gpu/gpu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ class GpuExecutor : public StreamExecutor {

absl::Status WaitForEvent(Stream* stream, Event* event) override;

Event::Status PollForEventStatus(Event* event) override;

absl::Status BlockHostUntilDone(Stream* stream) override;

absl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override;
Expand Down
17 changes: 7 additions & 10 deletions third_party/xla/xla/stream_executor/host/host_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,15 @@ bool HostExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {

class HostEvent : public Event {
public:
HostEvent(StreamExecutorInterface* executor)
: Event(executor),
notification_(std::make_shared<absl::Notification>()) {}
HostEvent() : notification_(std::make_shared<absl::Notification>()) {}

std::shared_ptr<absl::Notification>& notification() { return notification_; }

// Reports kComplete once notification_ has been notified, and kPending
// otherwise.
Status PollForStatus() override {
  if (notification_->HasBeenNotified()) {
    return Event::Status::kComplete;
  }
  return Event::Status::kPending;
}

private:
// We use a std::shared_ptr here because the client may delete the HostEvent
// object while there are still RecordEvent and WaitForEvent callbacks pending
Expand All @@ -237,7 +240,7 @@ class HostEvent : public Event {
};

absl::StatusOr<std::unique_ptr<Event>> HostExecutor::CreateEvent() {
return std::make_unique<HostEvent>(this);
return std::make_unique<HostEvent>();
}

static HostEvent* AsHostEvent(Event* event) {
Expand All @@ -263,12 +266,6 @@ absl::Status HostExecutor::WaitForEvent(Stream* stream, Event* event) {
return absl::OkStatus();
}

// Converts `event` with AsHostEvent() and inspects its notification:
// returns kComplete if it has been notified, kPending otherwise.
Event::Status HostExecutor::PollForEventStatus(Event* event) {
  HostEvent* const host_event = AsHostEvent(event);
  if (host_event->notification()->HasBeenNotified()) {
    return Event::Status::kComplete;
  }
  return Event::Status::kPending;
}

// Converts `stream` with AsHostStream() and returns the status produced
// by its BlockUntilDone() call.
absl::Status HostExecutor::BlockHostUntilDone(Stream* stream) {
  auto* host_stream = AsHostStream(stream);
  return host_stream->BlockUntilDone();
}
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/host/host_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ class HostExecutor : public StreamExecutor {

absl::Status RecordEvent(Stream* stream, Event* event) override;
absl::Status WaitForEvent(Stream* stream, Event* event) override;
Event::Status PollForEventStatus(Event* event) override;

void DeallocateStream(Stream* stream) override;
bool CreateStreamDependency(Stream* dependent, Stream* other) override;
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/mock_stream_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ class MockStreamExecutor : public StreamExecutorInterface {
(override));
MOCK_METHOD(absl::Status, WaitForEvent, (Stream * stream, Event* event),
(override));
MOCK_METHOD(Event::Status, PollForEventStatus, (Event * event), (override));
MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override));
MOCK_METHOD(bool, CreateStreamDependency, (Stream * dependent, Stream* other),
(override));
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ cc_library(
cc_library(
name = "rocm_event",
srcs = if_rocm_is_configured(["rocm_event.cc"]),
hdrs = if_rocm_is_configured(["rocm_event.h"]),
deps = if_rocm_is_configured([
# keep sorted
":rocm_driver",
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/stream_executor/rocm/rocm_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/rocm/rocm_event.h"

#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
Expand All @@ -21,9 +23,9 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

Event::Status GpuEvent::PollForStatus() {
Event::Status RocmEvent::PollForStatus() {
absl::StatusOr<hipError_t> status =
QueryEvent(parent_->gpu_context(), gpu_event_);
QueryEvent(parent()->gpu_context(), gpu_event());
if (!status.ok()) {
LOG(ERROR) << "Error polling for event status: "
<< status.status().message();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2015 The OpenXLA Authors.
/* Copyright 2018 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/event.h"
#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_
#define XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_

#include <cstdint>
#include <memory>
#include <utility>
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "xla/stream_executor/stream_executor_interface.h"
namespace stream_executor::gpu {

namespace stream_executor {
// This class implements Event::PollForStatus for ROCm devices.
class RocmEvent : public GpuEvent {
public:
explicit RocmEvent(GpuExecutor *executor) : GpuEvent(executor) {}

Event::Event(StreamExecutorInterface* stream_exec)
: stream_exec_(stream_exec) {}
Event::Status PollForStatus() override;
};
} // namespace stream_executor::gpu

Event::Status Event::PollForStatus() {
return stream_exec_->PollForEventStatus(this);
}

} // namespace stream_executor
#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_
Loading

0 comments on commit ac74514

Please sign in to comment.