Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Event::PollForStatus a virtual method, and override it where necessary. #68349

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,6 @@ absl::Status ValidateSEPlatformRegistrationParams(
}
#undef TF_VALIDATE_NOT_NULL

// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
case SE_EVENT_ERROR:
return Event::Status::kError;
case SE_EVENT_PENDING:
return Event::Status::kPending;
case SE_EVENT_COMPLETE:
return Event::Status::kComplete;
default:
return Event::Status::kUnknown;
}
}

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
Expand Down Expand Up @@ -422,12 +408,6 @@ class CStreamExecutor : public StreamExecutor {
absl::Status s = StatusFromTF_Status(c_status.get());
return s;
}
Event::Status PollForEventStatus(Event* event) override {
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
SE_EventStatus event_status =
stream_executor_->get_event_status(&device_, event_handle);
return SEEventStatusToEventStatus(event_status);
}
void DeallocateStream(Stream* stream) override {
static_cast<CStream*>(stream->implementation())->Destroy();
}
Expand Down Expand Up @@ -541,7 +521,7 @@ class CStreamExecutor : public StreamExecutor {
}

absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_, this);
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_);
TF_RETURN_IF_ERROR(c_event->Create());
return std::move(c_event);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,28 @@ class CStream : public StreamInterface {

class CEvent : public Event {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutorInterface* executor_interface)
: Event(executor_interface),
device_(device),
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }

Event::Status PollForStatus() override {
SE_EventStatus event_status =
stream_executor_->get_event_status(device_, event_handle_);

switch (event_status) {
case SE_EVENT_ERROR:
return Event::Status::kError;
case SE_EVENT_PENDING:
return Event::Status::kPending;
case SE_EVENT_COMPLETE:
return Event::Status::kComplete;
default:
return Event::Status::kUnknown;
}
}

absl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
Expand Down
6 changes: 1 addition & 5 deletions third_party/xla/xla/backends/interpreter/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ class XlaInterpreterExecutor : public StreamExecutor {
return absl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"};
}

Event::Status PollForEventStatus(Event *event) override {
return Event::Status::kError;
}

void DeallocateStream(Stream *stream) override {}
bool CreateStreamDependency(Stream *dependent, Stream *other) override;

Expand All @@ -150,7 +146,7 @@ class XlaInterpreterExecutor : public StreamExecutor {
return true;
}
absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
return std::make_unique<Event>(this);
return std::make_unique<Event>();
}

absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,6 @@ transitive_hdrs(
cc_library(
name = "stream_executor_pimpl",
srcs = [
"event.cc",
"stream.cc",
"stream_executor_pimpl.cc",
],
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/stream_executor/cuda/cuda_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_event.h"

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
Expand All @@ -24,9 +26,9 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

Event::Status GpuEvent::PollForStatus() {
Event::Status CudaEvent::PollForStatus() {
absl::StatusOr<CUresult> status =
QueryEvent(parent_->gpu_context(), gpu_event_);
QueryEvent(parent()->gpu_context(), gpu_event());
if (!status.ok()) {
LOG(ERROR) << "Error polling for event status: "
<< status.status().message();
Expand Down
16 changes: 11 additions & 5 deletions third_party/xla/xla/stream_executor/cuda/cuda_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
#define XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_

#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"

namespace stream_executor {
namespace cuda {
namespace stream_executor::gpu {

using CUDAEvent = gpu::GpuEvent;
// This class implements Event::PollForStatus for CUDA devices.
class CudaEvent : public GpuEvent {
public:
explicit CudaEvent(GpuExecutor *executor) : GpuEvent(executor) {}

} // namespace cuda
} // namespace stream_executor
Event::Status PollForStatus() override;
};

} // namespace stream_executor::gpu

#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
7 changes: 2 additions & 5 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_diagnostics.h"
#include "xla/stream_executor/cuda/cuda_driver.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/gpu/gpu_collectives.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
Expand Down Expand Up @@ -772,10 +773,6 @@ absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) {
}
}

Event::Status GpuExecutor::PollForEventStatus(Event* event) {
return AsGpuEvent(event)->PollForStatus();
}

void GpuExecutor::DeallocateStream(Stream* stream) {
{
absl::MutexLock lock(&mu_);
Expand Down Expand Up @@ -931,7 +928,7 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device,
}

absl::StatusOr<std::unique_ptr<Event>> GpuExecutor::CreateEvent() {
auto gpu_event = std::make_unique<GpuEvent>(this);
auto gpu_event = std::make_unique<CudaEvent>(this);
TF_RETURN_IF_ERROR(gpu_event->Init());
return std::move(gpu_event);
}
Expand Down
17 changes: 1 addition & 16 deletions third_party/xla/xla/stream_executor/event.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ limitations under the License.

namespace stream_executor {

class StreamExecutorInterface;

// The Event class, when supported by a platform, enables low-overhead status
// reporting for a Stream. An Event is inserted at a location in a stream via
// the Stream::RecordEvent() API. From then on, the Event's status can be
Expand All @@ -41,30 +39,17 @@ class Event {
kComplete,
};

explicit Event(StreamExecutorInterface* stream_exec);

// Releases any resources held by the Event object.
virtual ~Event() = default;

// Returns the current Status for the event.
Status PollForStatus();
virtual Status PollForStatus() { return Status::kError; }

// Blocks `stream` on this event. `stream` is a raw platform-specific
// stream (e.g. GpuStreamHandle).
virtual absl::Status WaitForEventOnExternalStream(std::intptr_t stream) {
return absl::UnimplementedError("Not supported for this Event.");
}

Event(Event&&) = default;
Event& operator=(Event&&) = default;

private:
// Pointer to the StreamExecutorInterface interface used to create this
// object. Not owned.
StreamExecutorInterface* stream_exec_;

Event(const Event&) = delete;
void operator=(const Event&) = delete;
};

} // namespace stream_executor
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/gpu/gpu_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace stream_executor {
namespace gpu {

GpuEvent::GpuEvent(GpuExecutor* parent)
: Event(parent), parent_(parent), gpu_event_(nullptr) {}
: parent_(parent), gpu_event_(nullptr) {}

GpuEvent::~GpuEvent() { Destroy().IgnoreError(); }

Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/stream_executor/gpu/gpu_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ class GpuEvent : public Event {
// Inserts the event at the current position into the specified stream.
absl::Status Record(GpuStream* stream);

// Polls the CUDA platform for the event's current status.
Event::Status PollForStatus();

// The underlying CUDA event element.
GpuEventHandle gpu_event();

absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override;

protected:
GpuExecutor* parent() const { return parent_; }

private:
// The Executor used to which this object and GpuEventHandle are bound.
GpuExecutor* parent_;
Expand Down
2 changes: 0 additions & 2 deletions third_party/xla/xla/stream_executor/gpu/gpu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ class GpuExecutor : public StreamExecutor {

absl::Status WaitForEvent(Stream* stream, Event* event) override;

Event::Status PollForEventStatus(Event* event) override;

absl::Status BlockHostUntilDone(Stream* stream) override;

absl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override;
Expand Down
17 changes: 7 additions & 10 deletions third_party/xla/xla/stream_executor/host/host_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,15 @@ bool HostExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {

class HostEvent : public Event {
public:
HostEvent(StreamExecutorInterface* executor)
: Event(executor),
notification_(std::make_shared<absl::Notification>()) {}
HostEvent() : notification_(std::make_shared<absl::Notification>()) {}

std::shared_ptr<absl::Notification>& notification() { return notification_; }

Status PollForStatus() override {
return notification_->HasBeenNotified() ? Event::Status::kComplete
: Event::Status::kPending;
}

private:
// We use a std::shared_ptr here because the client may delete the HostEvent
// object while there are still RecordEvent and WaitForEvent callbacks pending
Expand All @@ -237,7 +240,7 @@ class HostEvent : public Event {
};

absl::StatusOr<std::unique_ptr<Event>> HostExecutor::CreateEvent() {
return std::make_unique<HostEvent>(this);
return std::make_unique<HostEvent>();
}

static HostEvent* AsHostEvent(Event* event) {
Expand All @@ -263,12 +266,6 @@ absl::Status HostExecutor::WaitForEvent(Stream* stream, Event* event) {
return absl::OkStatus();
}

Event::Status HostExecutor::PollForEventStatus(Event* event) {
absl::Notification& notification = *AsHostEvent(event)->notification();
return notification.HasBeenNotified() ? Event::Status::kComplete
: Event::Status::kPending;
}

absl::Status HostExecutor::BlockHostUntilDone(Stream* stream) {
return AsHostStream(stream)->BlockUntilDone();
}
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/host/host_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ class HostExecutor : public StreamExecutor {

absl::Status RecordEvent(Stream* stream, Event* event) override;
absl::Status WaitForEvent(Stream* stream, Event* event) override;
Event::Status PollForEventStatus(Event* event) override;

void DeallocateStream(Stream* stream) override;
bool CreateStreamDependency(Stream* dependent, Stream* other) override;
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/mock_stream_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ class MockStreamExecutor : public StreamExecutorInterface {
(override));
MOCK_METHOD(absl::Status, WaitForEvent, (Stream * stream, Event* event),
(override));
MOCK_METHOD(Event::Status, PollForEventStatus, (Event * event), (override));
MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override));
MOCK_METHOD(bool, CreateStreamDependency, (Stream * dependent, Stream* other),
(override));
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ cc_library(
cc_library(
name = "rocm_event",
srcs = if_rocm_is_configured(["rocm_event.cc"]),
hdrs = if_rocm_is_configured(["rocm_event.h"]),
deps = if_rocm_is_configured([
# keep sorted
":rocm_driver",
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla/xla/stream_executor/rocm/rocm_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/rocm/rocm_event.h"

#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
Expand All @@ -21,9 +23,9 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

Event::Status GpuEvent::PollForStatus() {
Event::Status RocmEvent::PollForStatus() {
absl::StatusOr<hipError_t> status =
QueryEvent(parent_->gpu_context(), gpu_event_);
QueryEvent(parent()->gpu_context(), gpu_event());
if (!status.ok()) {
LOG(ERROR) << "Error polling for event status: "
<< status.status().message();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2015 The OpenXLA Authors.
/* Copyright 2018 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/event.h"
#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_
#define XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_

#include <cstdint>
#include <memory>
#include <utility>
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "xla/stream_executor/stream_executor_interface.h"
namespace stream_executor::gpu {

namespace stream_executor {
// This class implements Event::PollForStatus for ROCm devices.
class RocmEvent : public GpuEvent {
public:
explicit RocmEvent(GpuExecutor *executor) : GpuEvent(executor) {}

Event::Event(StreamExecutorInterface* stream_exec)
: stream_exec_(stream_exec) {}
Event::Status PollForStatus() override;
};
} // namespace stream_executor::gpu

Event::Status Event::PollForStatus() {
return stream_exec_->PollForEventStatus(this);
}

} // namespace stream_executor
#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_
Loading
Loading