Skip to content

Commit

Permalink
Eliminate StreamExecutor::WaitForEvent in favor of making derived cla…
Browse files Browse the repository at this point in the history
…sses of Stream implement WaitFor(Event) method.

PiperOrigin-RevId: 642303158
  • Loading branch information
klucke authored and tensorflower-gardener committed Jun 11, 2024
1 parent c7515ca commit 95d33fd
Show file tree
Hide file tree
Showing 22 changed files with 190 additions and 134 deletions.
9 changes: 0 additions & 9 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,6 @@ class CStreamExecutor : public StreamExecutorCommon {
SP_Stream stream_handle = static_cast<CStream*>(stream)->Handle();
return static_cast<CEvent*>(event)->Record(stream_handle);
}
absl::Status WaitForEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle = static_cast<CStream*>(stream)->Handle();
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
c_status.get());
absl::Status s = StatusFromTF_Status(c_status.get());
return s;
}
void DeallocateStream(Stream* stream) override {
static_cast<CStream*>(stream)->Destroy();
}
Expand Down
101 changes: 57 additions & 44 deletions tensorflow/c/experimental/stream_executor/stream_executor_internal.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "xla/stream_executor/stream.h"
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -96,50 +97,6 @@ class CPlatform : public Platform {
stream_executor::ExecutorCache executor_cache_;
};

class CStream : public StreamCommon {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutor* executor)
: StreamCommon(executor),
device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override {
parent()->BlockHostUntilDone(this).IgnoreError();
parent()->DeallocateStream(this);
Destroy();
}

absl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
absl::Status s = tensorflow::StatusFromTF_Status(c_status.get());
return s;
}

void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
absl::Status RefreshStatus() override {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->get_stream_status(device_, stream_handle_,
c_status.get());
absl::Status status = tensorflow::StatusFromTF_Status(c_status.get());
CheckStatus(status);
return status;
}

SP_Stream Handle() { return stream_handle_; }

private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};

class CEvent : public Event {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
Expand Down Expand Up @@ -192,5 +149,61 @@ class CEvent : public Event {
SP_Event event_handle_;
};

class CStream : public StreamCommon {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutor* executor)
: StreamCommon(executor),
device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override {
parent()->BlockHostUntilDone(this).IgnoreError();
parent()->DeallocateStream(this);
Destroy();
}

absl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
absl::Status s = tensorflow::StatusFromTF_Status(c_status.get());
return s;
}

void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
absl::Status RefreshStatus() override {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->get_stream_status(device_, stream_handle_,
c_status.get());
absl::Status status = tensorflow::StatusFromTF_Status(c_status.get());
CheckStatus(status);
return status;
}

absl::Status WaitFor(Stream* stream) override {
return StreamCommon::WaitFor(stream);
}
absl::Status WaitFor(Event* event) override {
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->wait_for_event(device_, stream_handle_, event_handle,
c_status.get());
absl::Status s = tensorflow::StatusFromTF_Status(c_status.get());
return s;
}

SP_Stream Handle() { return stream_handle_; }

private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};

} // namespace stream_executor
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
19 changes: 14 additions & 5 deletions third_party/xla/xla/backends/interpreter/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ limitations under the License.
namespace stream_executor {
namespace interpreter {

// A HostStream that is used for the interpreter.
class InterpreterStream : public host::HostStream {
public:
explicit InterpreterStream(StreamExecutor *executor)
: host::HostStream(executor) {}
absl::Status WaitFor(Stream *stream) override {
return host::HostStream::WaitFor(stream);
}
absl::Status WaitFor(Event *event) override {
return absl::UnimplementedError("Not implemented.");
}
};

class XlaInterpreterExecutor : public StreamExecutorCommon {
public:
XlaInterpreterExecutor(int device_ordinal, Platform *platform)
Expand Down Expand Up @@ -117,10 +130,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon {
return absl::Status{absl::StatusCode::kUnimplemented, "RecordEvent"};
}

absl::Status WaitForEvent(Stream *stream, Event *event) override {
return absl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"};
}

void DeallocateStream(Stream *stream) override {}
bool CreateStreamDependency(Stream *dependent, Stream *other) override;

Expand Down Expand Up @@ -150,7 +159,7 @@ class XlaInterpreterExecutor : public StreamExecutorCommon {
absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
std::optional<std::variant<StreamPriority, int>> priority =
std::nullopt) override {
return std::make_unique<host::HostStream>(this);
return std::make_unique<InterpreterStream>(this);
}

private:
Expand Down
10 changes: 0 additions & 10 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -762,16 +762,6 @@ absl::Status GpuExecutor::RecordEvent(Stream* stream, Event* event) {
return AsGpuEvent(event)->Record(AsGpuStream(stream));
}

absl::Status GpuExecutor::WaitForEvent(Stream* stream, Event* event) {
if (GpuDriver::WaitStreamOnEvent(context_, AsGpuStream(stream)->gpu_stream(),
AsGpuEvent(event)->gpu_event())) {
return absl::OkStatus();
} else {
return absl::InternalError(absl::StrFormat(
"error recording waiting for CUDA event on stream %p", stream));
}
}

void GpuExecutor::DeallocateStream(Stream* stream) {
{
absl::MutexLock lock(&mu_);
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,17 @@ gpu_only_cc_library(
hdrs = ["gpu_stream.h"],
deps = [
":gpu_driver_header",
":gpu_event_header",
":gpu_executor_header",
":gpu_types_header",
"//xla/stream_executor:event",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_common",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
)

Expand Down
2 changes: 0 additions & 2 deletions third_party/xla/xla/stream_executor/gpu/gpu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,6 @@ class GpuExecutor : public StreamExecutorCommon {

absl::Status RecordEvent(Stream* stream, Event* event) override;

absl::Status WaitForEvent(Stream* stream, Event* event) override;

absl::Status BlockHostUntilDone(Stream* stream) override;

absl::Status EnablePeerAccessTo(StreamExecutor* other) override;
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/platform.h"
Expand Down Expand Up @@ -52,6 +55,17 @@ Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const {
return handle;
}

absl::Status GpuStream::WaitFor(Event* event) {
if (GpuDriver::WaitStreamOnEvent(
parent_->gpu_context(), gpu_stream(),
static_cast<GpuEvent*>(event)->gpu_event())) {
return absl::OkStatus();
} else {
return absl::InternalError(absl::StrFormat(
"error recording waiting for event on stream %p", this));
}
}

void GpuStream::Destroy() {
if (completed_event_ != nullptr) {
absl::Status status =
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/stream_executor/gpu/gpu_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class GpuStream : public StreamCommon {
GpuStreamHandle cuda_stream() const { return gpu_stream(); }

GpuExecutor* parent() const { return parent_; }
absl::Status WaitFor(Stream* stream) override {
return StreamCommon::WaitFor(stream);
}
absl::Status WaitFor(Event* event) override;

private:
GpuExecutor* parent_; // Executor that spawned this stream.
Expand Down
12 changes: 12 additions & 0 deletions third_party/xla/xla/stream_executor/host/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ cc_library(
alwayslink = True, # Registers itself with the PlatformManager.
)

cc_library(
name = "host_event",
hdrs = ["host_event.h"],
deps = [
"//xla/stream_executor:event",
"@com_google_absl//absl/synchronization",
],
)

cc_library(
name = "host_stream",
srcs = [
Expand All @@ -71,6 +80,8 @@ cc_library(
"host_stream.h",
],
deps = [
":host_event",
"//xla/stream_executor:event",
"//xla/stream_executor:stream_common",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:any_invocable",
Expand Down Expand Up @@ -148,6 +159,7 @@ cc_library(
"host_executor.h",
],
deps = [
":host_event",
":host_kernel",
":host_stream",
"//xla/stream_executor",
Expand Down
47 changes: 47 additions & 0 deletions third_party/xla/xla/stream_executor/host/host_event.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EVENT_H_
#define XLA_STREAM_EXECUTOR_HOST_HOST_EVENT_H_

#include <memory>

#include "absl/synchronization/notification.h"
#include "xla/stream_executor/event.h"

namespace stream_executor {

// This class is a host-side implementation of the Event interface. It is
// intended to be used with the HostStream implementation.
class HostEvent : public Event {
public:
HostEvent() : notification_(std::make_shared<absl::Notification>()) {}

std::shared_ptr<absl::Notification>& notification() { return notification_; }

Status PollForStatus() override {
return notification_->HasBeenNotified() ? Event::Status::kComplete
: Event::Status::kPending;
}

private:
// We use a std::shared_ptr here because the client may delete the HostEvent
// object while there are still RecordEvent and WaitForEvent callbacks pending
// on a stream.
std::shared_ptr<absl::Notification> notification_;
};
} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_HOST_HOST_EVENT_H_
27 changes: 1 addition & 26 deletions third_party/xla/xla/stream_executor/host/host_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/host/host_event.h"
#include "xla/stream_executor/host/host_kernel.h"
#include "xla/stream_executor/host/host_stream.h"
#include "xla/stream_executor/kernel_spec.h"
Expand Down Expand Up @@ -237,24 +238,6 @@ bool HostExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
return true;
}

class HostEvent : public Event {
public:
HostEvent() : notification_(std::make_shared<absl::Notification>()) {}

std::shared_ptr<absl::Notification>& notification() { return notification_; }

Status PollForStatus() override {
return notification_->HasBeenNotified() ? Event::Status::kComplete
: Event::Status::kPending;
}

private:
// We use a std::shared_ptr here because the client may delete the HostEvent
// object while there are still RecordEvent and WaitForEvent callbacks pending
// on a stream.
std::shared_ptr<absl::Notification> notification_;
};

absl::StatusOr<std::unique_ptr<Event>> HostExecutor::CreateEvent() {
return std::make_unique<HostEvent>();
}
Expand All @@ -274,14 +257,6 @@ absl::Status HostExecutor::RecordEvent(Stream* stream, Event* event) {
return absl::OkStatus();
}

absl::Status HostExecutor::WaitForEvent(Stream* stream, Event* event) {
std::shared_ptr<absl::Notification> notification =
AsHostEvent(event)->notification();
AsHostStream(stream)->EnqueueTask(
[notification]() { notification->WaitForNotification(); });
return absl::OkStatus();
}

absl::Status HostExecutor::BlockHostUntilDone(Stream* stream) {
return AsHostStream(stream)->BlockUntilDone();
}
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/stream_executor/host/host_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ class HostExecutor : public StreamExecutorCommon {
absl::AnyInvocable<absl::Status() &&> callback) override;

absl::Status RecordEvent(Stream* stream, Event* event) override;
absl::Status WaitForEvent(Stream* stream, Event* event) override;

void DeallocateStream(Stream* stream) override;
bool CreateStreamDependency(Stream* dependent, Stream* other) override;
Expand Down
Loading

0 comments on commit 95d33fd

Please sign in to comment.