From b28c56fc1d5447b953509c11e945aba28a82b526 Mon Sep 17 00:00:00 2001
From: Kyle Lucke
Date: Thu, 23 May 2024 10:18:46 -0700
Subject: [PATCH] Make Stream inherit from StreamInterface, and all the
 concrete StreamInterfaces inherit from Stream.

This is the first step in eliminating StreamInterface as a separate class,
and making Stream an abstract base class.

PiperOrigin-RevId: 636592651
---
 .../stream_executor/stream_executor.cc        |  5 +--
 .../stream_executor_internal.h                | 13 +++++--
 .../xla/xla/backends/interpreter/executor.h   |  4 +-
 .../xla/stream_executor/cuda/cuda_executor.cc |  5 +--
 third_party/xla/xla/stream_executor/gpu/BUILD |  3 +-
 .../xla/xla/stream_executor/gpu/gpu_stream.cc |  2 +-
 .../xla/xla/stream_executor/gpu/gpu_stream.h  | 14 ++++---
 .../xla/xla/stream_executor/host/BUILD        |  2 +-
 .../xla/stream_executor/host/host_executor.cc |  2 +-
 .../xla/stream_executor/host/host_stream.cc   |  7 +++-
 .../xla/stream_executor/host/host_stream.h    |  6 +--
 .../xla/stream_executor/rocm/rocm_executor.cc |  5 +--
 third_party/xla/xla/stream_executor/stream.cc | 26 +++----------
 third_party/xla/xla/stream_executor/stream.h  | 13 ++-----
 .../xla/stream_executor/stream_interface.h    |  2 +-
 third_party/xla/xla/stream_executor/tpu/BUILD |  3 ++
 .../xla/stream_executor/tpu/tpu_executor.cc   | 39 ++++++++-----------
 .../xla/xla/stream_executor/tpu/tpu_stream.h  |  6 ++-
 .../tpu/tpu_stream_interface.h                |  5 ++-
 19 files changed, 74 insertions(+), 88 deletions(-)

diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc
index 027d2bd60caa85..7822647d0487d3 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -529,9 +529,8 @@ class CStreamExecutor : public StreamExecutor {
   absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
       std::optional<std::variant<StreamPriority, int>> priority =
           std::nullopt) override {
-    auto c_stream = std::make_unique<CStream>(&device_, stream_executor_);
-    TF_RETURN_IF_ERROR(c_stream->Create());
-    auto stream = std::make_unique<Stream>(this, std::move(c_stream));
+    auto stream = std::make_unique<CStream>(&device_, stream_executor_, this);
+    TF_RETURN_IF_ERROR(stream->Create());
     return std::move(stream);
   }
 
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
index 4fbc9d3090dc68..fa3b4321e27634 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
@@ -97,13 +97,18 @@ class CPlatform : public Platform {
   stream_executor::ExecutorCache executor_cache_;
 };
 
-class CStream : public StreamInterface {
+class CStream : public Stream {
  public:
-  CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
-      : device_(device),
+  CStream(SP_Device* device, SP_StreamExecutor* stream_executor,
+          StreamExecutor* executor)
+      : Stream(executor),
+        device_(device),
         stream_executor_(stream_executor),
         stream_handle_(nullptr) {}
-  ~CStream() override { Destroy(); }
+  ~CStream() override {
+    parent()->BlockHostUntilDone(this).IgnoreError();
+    Destroy();
+  }
 
   absl::Status Create() {
     tensorflow::TF_StatusPtr c_status(TF_NewStatus());
diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h
index 45cc64256377be..15d0029cf8ba50 100644
--- a/third_party/xla/xla/backends/interpreter/executor.h
+++ b/third_party/xla/xla/backends/interpreter/executor.h
@@ -152,9 +152,7 @@ class XlaInterpreterExecutor : public StreamExecutor {
   absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
       std::optional<std::variant<StreamPriority, int>> priority =
           std::nullopt) override {
-    auto stream =
-        std::make_unique<Stream>(this, std::make_unique<host::HostStream>());
-    return std::move(stream);
+    return std::make_unique<host::HostStream>(this);
   }
 
  private:
diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
index 082f6e8e493d53..bec611675c73f2 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -947,9 +947,8 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
   bool init_worked = gpu_stream->Init();
   if (init_worked) {
     auto platform_specific_stream = gpu_stream->platform_specific_stream();
-    auto stream = std::make_unique<Stream>(this, std::move(gpu_stream));
-    alive_gpu_streams_[platform_specific_stream] = stream.get();
-    return std::move(stream);
+    alive_gpu_streams_[platform_specific_stream] = gpu_stream.get();
+    return std::move(gpu_stream);
   } else {
     return absl::InvalidArgumentError("Failed to initialize gpu stream");
   }
diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD
index db228a20fe7a27..53bc2b3a2fe398 100644
--- a/third_party/xla/xla/stream_executor/gpu/BUILD
+++ b/third_party/xla/xla/stream_executor/gpu/BUILD
@@ -294,9 +294,9 @@ gpu_only_cc_library(
     name = "gpu_stream_header",
     hdrs = ["gpu_stream.h"],
     deps = [
+        ":gpu_executor_header",
         ":gpu_types_header",
         "//xla/stream_executor",
-        "//xla/stream_executor:stream_executor_interface",
         "@com_google_absl//absl/log:check",
     ],
 )
@@ -310,7 +310,6 @@ gpu_only_cc_library(
         ":gpu_executor_header",
         ":gpu_types_header",
         "//xla/stream_executor",
-        "//xla/stream_executor:stream_executor_interface",
        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
index a04f5410dbd613..a9e5108df70ce8 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc
@@ -64,7 +64,7 @@ bool GpuStream::IsIdle() const {
 
 GpuStream* AsGpuStream(Stream* stream) {
   DCHECK(stream != nullptr);
-  return static_cast<GpuStream*>(stream->implementation());
+  return static_cast<GpuStream*>(stream);
 }
 
 GpuStreamHandle AsGpuStreamValue(Stream* stream) {
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
index a06839b5c1c79b..262d37344cd8f1 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h
@@ -22,9 +22,10 @@ limitations under the License.
 #include
 
 #include "absl/log/check.h"
+#include "xla/stream_executor/gpu/gpu_executor.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
 #include "xla/stream_executor/platform.h"
-#include "xla/stream_executor/stream_executor_interface.h"
+#include "xla/stream_executor/stream.h"
 
 namespace stream_executor {
 namespace gpu {
@@ -35,15 +36,18 @@ class GpuExecutor;
 // StreamInterface.
 //
 // Thread-safe post-initialization.
-class GpuStream : public StreamInterface {
+class GpuStream : public Stream {
  public:
   explicit GpuStream(GpuExecutor* parent)
-      : parent_(parent), gpu_stream_(nullptr), completed_event_(nullptr) {}
+      : Stream(parent),
+        parent_(parent),
+        gpu_stream_(nullptr),
+        completed_event_(nullptr) {}
 
   // Note: teardown is handled by a parent's call to DeallocateStream.
-  ~GpuStream() override = default;
+  ~GpuStream() override { BlockHostUntilDone().IgnoreError(); }
 
-  void* platform_specific_stream() override { return gpu_stream_; }
+  void* platform_specific_stream() const override { return gpu_stream_; }
 
   // Explicitly initialize the CUDA resources associated with this stream.
   bool Init();
diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD
index c4614f4ca30b65..1d437965758bad 100644
--- a/third_party/xla/xla/stream_executor/host/BUILD
+++ b/third_party/xla/xla/stream_executor/host/BUILD
@@ -71,7 +71,7 @@ cc_library(
         "host_stream.h",
     ],
     deps = [
-        "//xla/stream_executor:stream_interface",
+        "//xla/stream_executor:stream_executor_headers",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/log:check",
diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc
index fd4162243dba5d..9c00dad836470e 100644
--- a/third_party/xla/xla/stream_executor/host/host_executor.cc
+++ b/third_party/xla/xla/stream_executor/host/host_executor.cc
@@ -292,7 +292,7 @@ HostExecutor::CreateDeviceDescription(int device_ordinal) {
 
 absl::StatusOr<std::unique_ptr<Stream>> HostExecutor::CreateStream(
     std::optional<std::variant<StreamPriority, int>> priority) {
-  return std::make_unique<Stream>(this, std::make_unique<HostStream>());
+  return std::make_unique<HostStream>(this);
 }
 
 }  // namespace host
diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc
index dfe091680db92b..654cf4856e084c 100644
--- a/third_party/xla/xla/stream_executor/host/host_stream.cc
+++ b/third_party/xla/xla/stream_executor/host/host_stream.cc
@@ -27,6 +27,8 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor_pimpl.h"
 #include "tsl/platform/denormal.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/setround.h"
@@ -34,8 +36,9 @@ limitations under the License.
 namespace stream_executor {
 namespace host {
 
-HostStream::HostStream()
-    : thread_(tsl::Env::Default()->StartThread({}, "host_executor",
+HostStream::HostStream(StreamExecutor* executor)
+    : Stream(executor),
+      thread_(tsl::Env::Default()->StartThread({}, "host_executor",
                                                [this]() { WorkLoop(); })) {}
 
 HostStream::~HostStream() {
diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h
index 2cabde179ca2b9..04282c9ce957fa 100644
--- a/third_party/xla/xla/stream_executor/host/host_stream.h
+++ b/third_party/xla/xla/stream_executor/host/host_stream.h
@@ -26,16 +26,16 @@ limitations under the License.
 #include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
 #include "absl/synchronization/mutex.h"
-#include "xla/stream_executor/stream_interface.h"
+#include "xla/stream_executor/stream.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/thread_annotations.h"
 
 namespace stream_executor {
 namespace host {
 
-class HostStream : public StreamInterface {
+class HostStream : public Stream {
  public:
-  HostStream();
+  explicit HostStream(StreamExecutor* executor);
   ~HostStream() override;
 
   // Enqueue a task that reports a status when finished. Tasks that fail do not
diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
index d26eaa4fff5252..a6783d5e08f2da 100644
--- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
+++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
@@ -853,9 +853,8 @@ absl::StatusOr<std::unique_ptr<Stream>> GpuExecutor::CreateStream(
   bool init_worked = gpu_stream->Init();
   if (init_worked) {
     auto platform_specific_stream = gpu_stream->platform_specific_stream();
-    auto stream = std::make_unique<Stream>(this, std::move(gpu_stream));
-    alive_gpu_streams_[platform_specific_stream] = stream.get();
-    return std::move(stream);
+    alive_gpu_streams_[platform_specific_stream] = gpu_stream.get();
+    return std::move(gpu_stream);
   } else {
     return absl::InvalidArgumentError("Failed to initialize GPU stream");
   }
diff --git a/third_party/xla/xla/stream_executor/stream.cc b/third_party/xla/xla/stream_executor/stream.cc
index 0c611032332b06..e470ceb985c1c4 100644
--- a/third_party/xla/xla/stream_executor/stream.cc
+++ b/third_party/xla/xla/stream_executor/stream.cc
@@ -42,32 +42,16 @@ limitations under the License.
 
 namespace stream_executor {
 
-Stream::Stream(StreamExecutor *parent,
-               std::unique_ptr<StreamInterface> implementation)
-    : parent_(parent),
-      implementation_(std::move(implementation)),
-      status_(absl::OkStatus()) {}
-
-Stream::~Stream() {
-  // Ensure the stream is completed.
-  auto status = BlockHostUntilDone();
-  if (!status.ok()) {
-    LOG(WARNING) << "Error blocking host until done in stream destructor: "
-                 << status;
-  }
-
-  if (implementation_ != nullptr) {
-    parent_->DeallocateStream(this);
-  }
+Stream::Stream(StreamExecutor *parent)
+    : parent_(parent), status_(absl::OkStatus()) {
+  CHECK_NE(parent, nullptr);
 }
 
-std::variant<StreamPriority, int> Stream::priority() const {
-  return implementation_->priority();
-}
+Stream::~Stream() { parent_->DeallocateStream(this); }
 
 Stream::PlatformSpecificHandle Stream::platform_specific_handle() const {
   PlatformSpecificHandle handle;
-  handle.stream = implementation_->platform_specific_stream();
+  handle.stream = platform_specific_stream();
   return handle;
 }
 
diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h
index 317846ea3bdfef..fcb302f914c95d 100644
--- a/third_party/xla/xla/stream_executor/stream.h
+++ b/third_party/xla/xla/stream_executor/stream.h
@@ -68,7 +68,7 @@ class StreamExecutor;
 // !ok(), it will never be ok().
 //
 // Thread-safe post-initialization.
-class Stream {
+class Stream : public StreamInterface {
  public:
   // Platform specific handle to the underlying resources behind a stream
   // implementation (e.g. it gives access to CUstream for CUDA platform).
@@ -79,8 +79,7 @@ class Stream {
   // Instantiate a stream tied to parent as a platform executor. Work
   // entrained onto this stream will be launched/managed on that
   // StreamExecutor's platform.
-  explicit Stream(StreamExecutor *parent,
-                  std::unique_ptr<StreamInterface> implementation);
+  explicit Stream(StreamExecutor *parent);
 
   // Deallocates any stream resources that the parent StreamExecutor has
   // bestowed
@@ -238,7 +237,7 @@ class Stream {
 
   // Returns the (opaque) platform-specific backing object. Ownership is not
   // transferred to the caller.
-  StreamInterface *implementation() { return implementation_.get(); }
+  StreamInterface *implementation() { return this; }
 
   // Entrains onto the stream a callback to the host (from the device).
   // Behaves as DoHostCallbackWithStatus below, but the callback should
@@ -274,8 +273,6 @@ class Stream {
     return parent()->GetDeviceDescription().rocm_compute_capability();
   }
 
-  std::variant<StreamPriority, int> priority() const;
-
  private:
   bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
     absl::ReaderMutexLock lock(&mu_);
@@ -294,10 +291,6 @@ class Stream {
   // The StreamExecutor that supports the operation of this stream.
   StreamExecutor *parent_;
 
-  // The platform-dependent implementation that the StreamExecutor interface
-  // delegates to.
-  std::unique_ptr<StreamInterface> implementation_;
-
   // mutex that guards the allocation / error state flags.
   // Mutable so that it can be obtained via const reader lock.
   mutable absl::Mutex mu_;
diff --git a/third_party/xla/xla/stream_executor/stream_interface.h b/third_party/xla/xla/stream_executor/stream_interface.h
index ae8e5dd2660c79..dcfd56201bfc50 100644
--- a/third_party/xla/xla/stream_executor/stream_interface.h
+++ b/third_party/xla/xla/stream_executor/stream_interface.h
@@ -46,7 +46,7 @@ class StreamInterface {
   // if it exists, or nullptr otherwise. This is available via Stream public API
   // as Stream::PlatformSpecificHandle, and should not be accessed directly
   // outside of a StreamExecutor package.
-  virtual void* platform_specific_stream() { return nullptr; }
+  virtual void* platform_specific_stream() const { return nullptr; }
 
  private:
   StreamInterface(const StreamInterface&) = delete;
diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD
index 019a7bdec96a83..1cefc0a5771ed9 100644
--- a/third_party/xla/xla/stream_executor/tpu/BUILD
+++ b/third_party/xla/xla/stream_executor/tpu/BUILD
@@ -201,6 +201,7 @@ cc_library(
         ":tpu_stream_interface",
         ":tpu_topology_external",
         "//xla/stream_executor",
+        "//xla/stream_executor:stream_executor_headers",
         "//xla/stream_executor:stream_executor_interface",
         "//xla/stream_executor:stream_interface",
         "//xla/stream_executor/platform",
@@ -262,6 +263,7 @@ cc_library(
         ":tpu_executor_c_api_hdrs",
         ":tpu_topology_external",
         "//xla/stream_executor",
+        "//xla/stream_executor:stream_executor_headers",
         "//xla/stream_executor:stream_executor_interface",
         "//xla/stream_executor:stream_interface",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -310,6 +312,7 @@ cc_library(
         ":tpu_topology_external",
         "//xla:status",
         "//xla/stream_executor",
+        "//xla/stream_executor:stream_executor_headers",
         "//xla/stream_executor:stream_executor_interface",
         "//xla/stream_executor:stream_interface",
         "//xla/tsl/c:tsl_status_internal",
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc
index 88d30655e39edb..e24dca6f5f9462 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc
@@ -61,14 +61,14 @@ bool TpuExecutor::SynchronizeAllActivity() {
 absl::Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
   StatusHelper status;
   ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
-      executor_, get_stream(stream->implementation()), status.c_status);
+      executor_, get_stream(stream), status.c_status);
   return status.status();
 }
 
 absl::Status TpuExecutor::GetStatus(Stream* stream) {
   StatusHelper status;
-  ExecutorApiFn()->TpuExecutor_GetStatusFn(
-      executor_, get_stream(stream->implementation()), status.c_status);
+  ExecutorApiFn()->TpuExecutor_GetStatusFn(executor_, get_stream(stream),
+                                           status.c_status);
   return status.status();
 }
 
@@ -79,10 +79,10 @@ tensorflow::tpu::TpuCoreLocationExternal TpuExecutor::GetCoreLocationExternal()
 }
 
 void TpuExecutor::DeallocateStream(Stream* stream) {
-  ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
-      executor_, get_stream(stream->implementation()));
+  ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(executor_,
+                                                  get_stream(stream));
   tpu_platform().mutex().Lock();
-  stream_map().erase(stream->implementation());
+  stream_map().erase(stream);
   tpu_platform().mutex().Unlock();
 }
 
@@ -96,9 +96,8 @@ absl::Status TpuExecutor::RecordEvent(Stream* stream,
                                       ::stream_executor::Event* event) {
   StatusHelper status;
   auto se_event = tpu_platform().LookupEvent(event);
-  ExecutorApiFn()->TpuExecutor_RecordEventFn(
-      executor_, get_stream(stream->implementation()), se_event,
-      status.c_status);
+  ExecutorApiFn()->TpuExecutor_RecordEventFn(executor_, get_stream(stream),
+                                             se_event, status.c_status);
   return status.status();
 }
 
@@ -106,20 +105,18 @@ absl::Status TpuExecutor::WaitForEvent(Stream* stream,
                                        ::stream_executor::Event* event) {
   StatusHelper status;
   auto se_event = tpu_platform().LookupEvent(event);
-  ExecutorApiFn()->TpuExecutor_WaitForEventFn(
-      executor_, get_stream(stream->implementation()), se_event,
-      status.c_status);
+  ExecutorApiFn()->TpuExecutor_WaitForEventFn(executor_, get_stream(stream),
+                                              se_event, status.c_status);
   return status.status();
 }
 
 absl::StatusOr<std::unique_ptr<Stream>> TpuExecutor::CreateStream(
     std::optional<std::variant<StreamPriority, int>> priority) {
   SE_Stream* tpu_stream = ExecutorApiFn()->TpuStream_NewFn(executor_);
-  auto ptr = std::make_unique<TpuStream>(tpu_stream);
+  auto stream = std::make_unique<TpuStream>(tpu_stream, this);
   tpu_platform().mutex().Lock();
-  stream_map()[ptr.get()] = tpu_stream;
+  stream_map()[stream.get()] = tpu_stream;
   tpu_platform().mutex().Unlock();
-  auto stream = std::make_unique<Stream>(this, std::move(ptr));
   return std::move(stream);
 }
 
@@ -212,8 +209,7 @@ absl::Status TpuExecutor::Memcpy(
   StatusHelper status;
   SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
   ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
-      executor_, get_stream(stream->implementation()), host_dst, &se_base, size,
-      status.c_status);
+      executor_, get_stream(stream), host_dst, &se_base, size, status.c_status);
   return status.status();
 }
 
@@ -223,8 +219,7 @@ absl::Status TpuExecutor::Memcpy(
   StatusHelper status;
   SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
   ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
-      executor_, get_stream(stream->implementation()), &se_base, host_src, size,
-      status.c_status);
+      executor_, get_stream(stream), &se_base, host_src, size, status.c_status);
   return status.status();
 }
 
@@ -264,8 +259,7 @@ absl::Status TpuExecutor::EnqueueCompactionOnStreamForHbm(
     Stream* compaction_stream) {
   StatusHelper status;
   ExecutorApiFn()->TpuExecutor_EnqueueCompactionOnStreamForHbmFn(
-      executor_, get_stream(compaction_stream->implementation()),
-      status.c_status);
+      executor_, get_stream(compaction_stream), status.c_status);
   return status.status();
 }
 
@@ -286,8 +280,7 @@ bool TpuExecutor::HostCallback(Stream* stream,
                                absl::AnyInvocable<absl::Status() &&> callback) {
   HostCallbackContext* ctx = new HostCallbackContext{std::move(callback)};
   return ExecutorApiFn()->TpuExecutor_HostCallbackFn(
-      executor_, get_stream(stream->implementation()), &HostCallbackTrampoline,
-      ctx);
+      executor_, get_stream(stream), &HostCallbackTrampoline, ctx);
 }
 
 absl::StatusOr>
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h
index 456ef878d67653..7eaf2f99b46173 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/status/status.h"
 #include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/stream_executor_pimpl.h"
 #include "xla/stream_executor/tpu/c_api_conversions.h"
 #include "xla/stream_executor/tpu/c_api_decl.h"
 #include "xla/stream_executor/tpu/status_helper.h"
@@ -32,8 +33,11 @@ namespace tpu {
 
 class TpuStream : public tensorflow::tpu::TpuStreamInterface {
  public:
-  explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
+  explicit TpuStream(SE_Stream* stream,
+                     stream_executor::StreamExecutor* executor)
+      : TpuStreamInterface(executor), stream_(stream) {}
   ~TpuStream() override {
+    BlockHostUntilDone().IgnoreError();
     stream_executor::tpu::ExecutorApiFn()->TpuStream_FreeFn(stream_);
   }
 
diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_stream_interface.h b/third_party/xla/xla/stream_executor/tpu/tpu_stream_interface.h
index 4cdacedffde5f2..c492f5a2c26005 100644
--- a/third_party/xla/xla/stream_executor/tpu/tpu_stream_interface.h
+++ b/third_party/xla/xla/stream_executor/tpu/tpu_stream_interface.h
@@ -18,14 +18,17 @@ limitations under the License.
 
 #include "absl/status/status.h"
 #include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/stream_executor/stream_executor_interface.h"
 
 namespace tensorflow {
 namespace tpu {
 
-class TpuStreamInterface : public stream_executor::StreamInterface {
+class TpuStreamInterface : public stream_executor::Stream {
  public:
+  explicit TpuStreamInterface(stream_executor::StreamExecutor* executor)
+      : Stream(executor) {}
   virtual bool IsSameSharedMemoryLocation(TpuStreamInterface* other) = 0;
   virtual absl::Status EnqueueOnTpuDeviceSendRecvLocal(
       stream_executor::DeviceMemoryBase send_buffer,
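Illustrative sketch (not part of the patch; simplified, hypothetical class and member names): after this change, Stream derives from StreamInterface, and each concrete backend stream (CStream, GpuStream, HostStream, TpuStream) derives from Stream directly instead of being wrapped in a Stream that owns a separate StreamInterface "implementation" object.

// Hedged sketch only -- simplified names, not the actual XLA classes.
class StreamExecutor;  // Owns streams; calls DeallocateStream() on teardown.

class StreamInterface {
 public:
  virtual ~StreamInterface() = default;
  // Backends may expose a native handle (e.g. a CUstream) through this.
  virtual void* platform_specific_stream() const { return nullptr; }
};

// Stream now sits in the hierarchy itself instead of holding a
// std::unique_ptr<StreamInterface> "implementation" member.
class Stream : public StreamInterface {
 public:
  explicit Stream(StreamExecutor* parent) : parent_(parent) {}
  StreamExecutor* parent() const { return parent_; }

 private:
  StreamExecutor* parent_ = nullptr;
};

// A concrete backend stream derives from Stream directly and passes its
// executor up to the Stream constructor, mirroring what the patch does for
// GpuStream, HostStream, TpuStream, and CStream.
class BackendStream : public Stream {
 public:
  explicit BackendStream(StreamExecutor* parent) : Stream(parent) {}
  void* platform_specific_stream() const override { return native_handle_; }

 private:
  void* native_handle_ = nullptr;
};

In the patch the teardown responsibility also moves into the concrete destructors, which call BlockHostUntilDone (ignoring errors) before ~Stream() asks the parent executor to deallocate the stream.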