Skip to content

Commit

Permalink
Make Stream inherit from StreamInterface, and all the concrete Stream…
Browse files Browse the repository at this point in the history
…Interfaces inherit from Stream.

This is the first step in eliminating StreamInterface as a separate class, and making Stream an abstract base class.

PiperOrigin-RevId: 636592651
  • Loading branch information
klucke authored and tensorflower-gardener committed May 23, 2024
1 parent 1c42da4 commit b28c56f
Show file tree
Hide file tree
Showing 19 changed files with 74 additions and 88 deletions.
5 changes: 2 additions & 3 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,8 @@ class CStreamExecutor : public StreamExecutor {
absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
std::optional<std::variant<StreamPriority, int>> priority =
std::nullopt) override {
auto c_stream = std::make_unique<CStream>(&device_, stream_executor_);
TF_RETURN_IF_ERROR(c_stream->Create());
auto stream = std::make_unique<Stream>(this, std::move(c_stream));
auto stream = std::make_unique<CStream>(&device_, stream_executor_, this);
TF_RETURN_IF_ERROR(stream->Create());
return std::move(stream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,18 @@ class CPlatform : public Platform {
stream_executor::ExecutorCache executor_cache_;
};

class CStream : public StreamInterface {
class CStream : public Stream {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
CStream(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutor* executor)
: Stream(executor),
device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
~CStream() override {
parent()->BlockHostUntilDone(this).IgnoreError();
Destroy();
}

absl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
Expand Down
4 changes: 1 addition & 3 deletions third_party/xla/xla/backends/interpreter/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ class XlaInterpreterExecutor : public StreamExecutor {
absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
std::optional<std::variant<StreamPriority, int>> priority =
std::nullopt) override {
auto stream =
std::make_unique<Stream>(this, std::make_unique<host::HostStream>());
return std::move(stream);
return std::make_unique<host::HostStream>(this);
}

private:
Expand Down
5 changes: 2 additions & 3 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -947,9 +947,8 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
bool init_worked = gpu_stream->Init();
if (init_worked) {
auto platform_specific_stream = gpu_stream->platform_specific_stream();
auto stream = std::make_unique<Stream>(this, std::move(gpu_stream));
alive_gpu_streams_[platform_specific_stream] = stream.get();
return std::move(stream);
alive_gpu_streams_[platform_specific_stream] = gpu_stream.get();
return std::move(gpu_stream);
} else {
return absl::InvalidArgumentError("Failed to initialize gpu stream");
}
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ gpu_only_cc_library(
name = "gpu_stream_header",
hdrs = ["gpu_stream.h"],
deps = [
":gpu_executor_header",
":gpu_types_header",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_interface",
"@com_google_absl//absl/log:check",
],
)
Expand All @@ -310,7 +310,6 @@ gpu_only_cc_library(
":gpu_executor_header",
":gpu_types_header",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_interface",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ bool GpuStream::IsIdle() const {

GpuStream* AsGpuStream(Stream* stream) {
DCHECK(stream != nullptr);
return static_cast<GpuStream*>(stream->implementation());
return static_cast<GpuStream*>(stream);
}

GpuStreamHandle AsGpuStreamValue(Stream* stream) {
Expand Down
14 changes: 9 additions & 5 deletions third_party/xla/xla/stream_executor/gpu/gpu_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ limitations under the License.
#include <variant>

#include "absl/log/check.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/stream_executor/stream.h"

namespace stream_executor {
namespace gpu {
Expand All @@ -35,15 +36,18 @@ class GpuExecutor;
// StreamInterface.
//
// Thread-safe post-initialization.
class GpuStream : public StreamInterface {
class GpuStream : public Stream {
public:
explicit GpuStream(GpuExecutor* parent)
: parent_(parent), gpu_stream_(nullptr), completed_event_(nullptr) {}
: Stream(parent),
parent_(parent),
gpu_stream_(nullptr),
completed_event_(nullptr) {}

// Note: teardown is handled by a parent's call to DeallocateStream.
~GpuStream() override = default;
~GpuStream() override { BlockHostUntilDone().IgnoreError(); }

void* platform_specific_stream() override { return gpu_stream_; }
void* platform_specific_stream() const override { return gpu_stream_; }

// Explicitly initialize the CUDA resources associated with this stream.
bool Init();
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/host/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ cc_library(
"host_stream.h",
],
deps = [
"//xla/stream_executor:stream_interface",
"//xla/stream_executor:stream_executor_headers",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log:check",
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/host/host_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ HostExecutor::CreateDeviceDescription(int device_ordinal) {

absl::StatusOr<std::unique_ptr<Stream>> HostExecutor::CreateStream(
std::optional<std::variant<StreamPriority, int>> priority) {
return std::make_unique<Stream>(this, std::make_unique<HostStream>());
return std::make_unique<HostStream>(this);
}

} // namespace host
Expand Down
7 changes: 5 additions & 2 deletions third_party/xla/xla/stream_executor/host/host_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "tsl/platform/denormal.h"
#include "tsl/platform/env.h"
#include "tsl/platform/setround.h"

namespace stream_executor {
namespace host {

HostStream::HostStream()
: thread_(tsl::Env::Default()->StartThread({}, "host_executor",
HostStream::HostStream(StreamExecutor* executor)
: Stream(executor),
thread_(tsl::Env::Default()->StartThread({}, "host_executor",
[this]() { WorkLoop(); })) {}

HostStream::~HostStream() {
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/stream_executor/host/host_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ limitations under the License.
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/stream_interface.h"
#include "xla/stream_executor/stream.h"
#include "tsl/platform/env.h"
#include "tsl/platform/thread_annotations.h"

namespace stream_executor {
namespace host {

class HostStream : public StreamInterface {
class HostStream : public Stream {
public:
HostStream();
explicit HostStream(StreamExecutor* executor);
~HostStream() override;

// Enqueue a task that reports a status when finished. Tasks that fail do not
Expand Down
5 changes: 2 additions & 3 deletions third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -853,9 +853,8 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
bool init_worked = gpu_stream->Init();
if (init_worked) {
auto platform_specific_stream = gpu_stream->platform_specific_stream();
auto stream = std::make_unique<Stream>(this, std::move(gpu_stream));
alive_gpu_streams_[platform_specific_stream] = stream.get();
return std::move(stream);
alive_gpu_streams_[platform_specific_stream] = gpu_stream.get();
return std::move(gpu_stream);
} else {
return absl::InvalidArgumentError("Failed to initialize GPU stream");
}
Expand Down
26 changes: 5 additions & 21 deletions third_party/xla/xla/stream_executor/stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,32 +42,16 @@ limitations under the License.

namespace stream_executor {

Stream::Stream(StreamExecutor *parent,
std::unique_ptr<StreamInterface> implementation)
: parent_(parent),
implementation_(std::move(implementation)),
status_(absl::OkStatus()) {}

Stream::~Stream() {
// Ensure the stream is completed.
auto status = BlockHostUntilDone();
if (!status.ok()) {
LOG(WARNING) << "Error blocking host until done in stream destructor: "
<< status;
}

if (implementation_ != nullptr) {
parent_->DeallocateStream(this);
}
Stream::Stream(StreamExecutor *parent)
: parent_(parent), status_(absl::OkStatus()) {
CHECK_NE(parent, nullptr);
}

std::variant<StreamPriority, int> Stream::priority() const {
return implementation_->priority();
}
Stream::~Stream() { parent_->DeallocateStream(this); }

Stream::PlatformSpecificHandle Stream::platform_specific_handle() const {
PlatformSpecificHandle handle;
handle.stream = implementation_->platform_specific_stream();
handle.stream = platform_specific_stream();
return handle;
}

Expand Down
13 changes: 3 additions & 10 deletions third_party/xla/xla/stream_executor/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class StreamExecutor;
// !ok(), it will never be ok().
//
// Thread-safe post-initialization.
class Stream {
class Stream : public StreamInterface {
public:
// Platform specific handle to the underlying resources behind a stream
// implementation (e.g. it gives access to CUstream for CUDA platform).
Expand All @@ -79,8 +79,7 @@ class Stream {
// Instantiate a stream tied to parent as a platform executor. Work
// entrained onto this stream will be launched/managed on that
// StreamExecutor's platform.
explicit Stream(StreamExecutor *parent,
std::unique_ptr<StreamInterface> implementation);
explicit Stream(StreamExecutor *parent);

// Deallocates any stream resources that the parent StreamExecutor has
// bestowed
Expand Down Expand Up @@ -238,7 +237,7 @@ class Stream {

// Returns the (opaque) platform-specific backing object. Ownership is not
// transferred to the caller.
StreamInterface *implementation() { return implementation_.get(); }
StreamInterface *implementation() { return this; }

// Entrains onto the stream a callback to the host (from the device).
// Behaves as DoHostCallbackWithStatus below, but the callback should
Expand Down Expand Up @@ -274,8 +273,6 @@ class Stream {
return parent()->GetDeviceDescription().rocm_compute_capability();
}

std::variant<StreamPriority, int> priority() const;

private:
bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
absl::ReaderMutexLock lock(&mu_);
Expand All @@ -294,10 +291,6 @@ class Stream {
// The StreamExecutor that supports the operation of this stream.
StreamExecutor *parent_;

// The platform-dependent implementation that the StreamExecutor interface
// delegates to.
std::unique_ptr<StreamInterface> implementation_;

// mutex that guards the allocation / error state flags.
// Mutable so that it can be obtained via const reader lock.
mutable absl::Mutex mu_;
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/stream_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class StreamInterface {
// if it exists, or nullptr otherwise. This is available via Stream public API
// as Stream::PlatformSpecificHandle, and should not be accessed directly
// outside of a StreamExecutor package.
virtual void* platform_specific_stream() { return nullptr; }
virtual void* platform_specific_stream() const { return nullptr; }

private:
StreamInterface(const StreamInterface&) = delete;
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/stream_executor/tpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ cc_library(
":tpu_stream_interface",
":tpu_topology_external",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor:stream_interface",
"//xla/stream_executor/platform",
Expand Down Expand Up @@ -262,6 +263,7 @@ cc_library(
":tpu_executor_c_api_hdrs",
":tpu_topology_external",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor:stream_interface",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -310,6 +312,7 @@ cc_library(
":tpu_topology_external",
"//xla:status",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor:stream_interface",
"//xla/tsl/c:tsl_status_internal",
Expand Down

0 comments on commit b28c56f

Please sign in to comment.