Skip to content

Commit

Permalink
Split stream_executor_internal.h into separate files defining the thr…
Browse files Browse the repository at this point in the history
…ee Interface classes, and move them out of the internal namespace in preparation for removing the PIMPL pattern.

PiperOrigin-RevId: 624216399
  • Loading branch information
klucke authored and tensorflower-gardener committed Apr 12, 2024
1 parent 0852a09 commit 49d4f54
Show file tree
Hide file tree
Showing 48 changed files with 434 additions and 442 deletions.
4 changes: 3 additions & 1 deletion tensorflow/c/experimental/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ cc_library(
"//tensorflow/c:tf_status_helper",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/stream_executor",
"@local_xla//xla/stream_executor:stream_executor_internal",
"@local_xla//xla/stream_executor:event_interface",
"@local_xla//xla/stream_executor:stream_executor_interface",
"@local_xla//xla/stream_executor:stream_interface",
],
)

Expand Down
20 changes: 5 additions & 15 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ void HostCallbackTrampoline(void* ctx, TF_Status* status) {
delete host_ctx;
}

class CStreamExecutor : public internal::StreamExecutorInterface {
class CStreamExecutor : public StreamExecutorInterface {
public:
explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
SP_StreamExecutor* stream_executor,
Expand Down Expand Up @@ -254,9 +254,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
stream_executor_->host_memory_deallocate(&device_, mem);
}

bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
bool HostMemoryUnregister(void* mem) override { return false; }

void* UnifiedMemoryAllocate(uint64 size) override {
CHECK(stream_executor_->unified_memory_allocate);
return stream_executor_->unified_memory_allocate(&device_, size);
Expand Down Expand Up @@ -311,11 +308,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
return tsl::errors::Unimplemented(
"SynchronousMemZero is not supported by pluggable device.");
}
absl::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
uint64 size) override {
return tsl::errors::Unimplemented(
"SynchronousMemSet is not supported by pluggable device.");
}
absl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
const void* host_src, uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
Expand Down Expand Up @@ -577,14 +569,12 @@ class CStreamExecutor : public internal::StreamExecutorInterface {

// Each call creates a new instance of the platform-specific implementation of
// the corresponding interface type.
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override {
return std::unique_ptr<internal::EventInterface>(
std::unique_ptr<EventInterface> CreateEventImplementation() override {
return std::unique_ptr<EventInterface>(
new CEvent(&device_, stream_executor_));
}
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
return std::unique_ptr<internal::StreamInterface>(
std::unique_ptr<StreamInterface> GetStreamImplementation() override {
return std::unique_ptr<StreamInterface>(
new CStream(&device_, stream_executor_));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ limitations under the License.

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "xla/stream_executor/event_interface.h"
#include "xla/stream_executor/executor_cache.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/stream_executor/stream_interface.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {
Expand Down Expand Up @@ -95,7 +97,7 @@ class CPlatform : public Platform {
stream_executor::ExecutorCache executor_cache_;
};

class CStream : public internal::StreamInterface {
class CStream : public StreamInterface {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
Expand Down Expand Up @@ -125,7 +127,7 @@ class CStream : public internal::StreamInterface {
SP_Stream stream_handle_;
};

class CEvent : public internal::EventInterface {
class CEvent : public EventInterface {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/backends/interpreter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ cc_library(
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/host:host_stream",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
Expand Down
17 changes: 4 additions & 13 deletions third_party/xla/xla/backends/interpreter/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ limitations under the License.
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/xla_data.pb.h"

namespace stream_executor {
namespace interpreter {

class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
class XlaInterpreterExecutor : public StreamExecutorInterface {
public:
explicit XlaInterpreterExecutor(int device_ordinal)
: device_ordinal_(device_ordinal) {}
Expand Down Expand Up @@ -73,8 +73,6 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
void HostMemoryDeallocate(void *mem) override {
delete[] static_cast<char *>(mem);
}
bool HostMemoryRegister(void *mem, uint64_t size) override { return true; }
bool HostMemoryUnregister(void *mem) override { return true; }

absl::Status Memcpy(Stream *stream, void *host_dst,
const DeviceMemoryBase &dev_src, uint64_t size) override;
Expand Down Expand Up @@ -106,11 +104,6 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
return absl::InternalError("Interpreter can not memzero");
}

absl::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
uint64_t size) override {
return absl::InternalError("Interpreter can not memset");
}

absl::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst,
const void *host_src, uint64_t size) override;
absl::Status SynchronousMemcpy(void *host_dst,
Expand Down Expand Up @@ -169,13 +162,11 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
return true;
}

std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override {
std::unique_ptr<EventInterface> CreateEventImplementation() override {
return nullptr;
}

std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
std::unique_ptr<StreamInterface> GetStreamImplementation() override {
return std::make_unique<host::HostStream>();
}

Expand Down
37 changes: 28 additions & 9 deletions third_party/xla/xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ cc_library(
hdrs = ["host_memory_allocation.h"],
deps = [
":memory_allocation",
":stream_executor_internal", # TODO(b/323534971): Remove dependency on Interface.
":stream_executor_interface",
],
)

Expand Down Expand Up @@ -359,11 +359,13 @@ exports_files(["lazy_op_runner.h"])
# implementation in static build configuration), or a header only `stream_executor_headers`.

cc_library(
name = "stream_executor_internal",
hdrs = ["stream_executor_internal.h"],
visibility = internal_visibility([":internal"]),
name = "stream_executor_interface",
hdrs = [
"stream_executor_interface.h",
],
deps = [
":stream_executor_headers",
":stream_interface",
"//xla/stream_executor/platform",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
Expand All @@ -382,7 +384,7 @@ cc_library(
cc_library(
name = "stream_executor_headers",
hdrs = [
"stream_executor_internal.h", # TODO(301020144): Remove internal header
"stream_executor_interface.h",
"stream_executor_pimpl.h", # TODO(301020144): Remove internal header
":stream_executor_api_headers",
":stream_executor_plugin_headers",
Expand All @@ -391,6 +393,7 @@ cc_library(
deps = STREAM_EXECUTOR_DEPENDENCIES + if_static([
"@com_google_protobuf//:protobuf", # indirectly-used by dnn.h
]) + [
":stream_interface",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@local_tsl//tsl/platform:thread_annotations",
Expand All @@ -416,6 +419,21 @@ cc_library(
],
)

cc_library(
name = "event_interface",
hdrs = ["event_interface.h"],
deps = [
],
)

cc_library(
name = "stream_interface",
hdrs = ["stream_interface.h"],
deps = [
":platform",
],
)

#===--------------------------------------------------------------------------------------------===#
# StreamExecutor private implementation (has private visibility)
#===--------------------------------------------------------------------------------------------===#
Expand All @@ -440,7 +458,7 @@ cc_library(
visibility = ["//visibility:private"],
deps = [
":stream_executor_headers",
":stream_executor_internal",
":stream_executor_interface",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand All @@ -457,8 +475,9 @@ cc_library(
hdrs = ["event.h"],
visibility = ["//visibility:private"],
deps = [
":event_interface",
":stream_executor_headers",
":stream_executor_internal",
":stream_executor_interface",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
],
Expand Down Expand Up @@ -504,7 +523,7 @@ cc_library(
":kernel_spec",
":platform",
":stream_executor_headers",
":stream_executor_internal",
":stream_executor_interface",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/meta:type_traits",
Expand Down Expand Up @@ -564,7 +583,7 @@ cc_library(
":kernel_spec",
":platform",
":stream_executor_headers",
":stream_executor_internal",
":stream_executor_interface",
"//xla/tsl/util:env_var",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:any_invocable",
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

Expand Down
10 changes: 5 additions & 5 deletions third_party/xla/xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ cuda_only_cc_library(
":cuda_runtime",
"//xla/stream_executor",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/platform",
Expand Down Expand Up @@ -223,7 +223,7 @@ cuda_only_cc_library(
deps = [
":cuda_driver",
"//xla/stream_executor",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/gpu:gpu_activation",
"//xla/stream_executor/platform",
"@local_config_cuda//cuda:cuda_headers",
Expand Down Expand Up @@ -336,7 +336,7 @@ cuda_only_cc_library(
"//xla/stream_executor",
"//xla/stream_executor:fft",
"//xla/stream_executor:plugin_registry",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_helpers_header",
"//xla/stream_executor/gpu:gpu_stream_header",
Expand Down Expand Up @@ -387,7 +387,7 @@ cuda_only_cc_library(
"//xla/stream_executor",
"//xla/stream_executor:dnn",
"//xla/stream_executor:plugin_registry",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/gpu:gpu_activation_header",
"//xla/stream_executor/gpu:gpu_diagnostics_header",
"//xla/stream_executor/gpu:gpu_driver_header",
Expand Down Expand Up @@ -620,7 +620,7 @@ cuda_only_cc_library(
":cuda_runtime", # buildcleaner: keep
"//xla/stream_executor",
"//xla/stream_executor:plugin_registry",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor:stream_executor_interface",
"//xla/stream_executor/gpu:gpu_collectives_header",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:gpu_diagnostics_header",
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/cuda/cuda_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ void initialize_cublas() {
absl::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
kCudaPlatformId, "cuBLAS",
[](::stream_executor::internal::StreamExecutorInterface *parent)
[](::stream_executor::StreamExecutorInterface *parent)
-> blas::BlasSupport * {
gpu::GpuExecutor *cuda_executor =
dynamic_cast<gpu::GpuExecutor *>(parent);
Expand Down
12 changes: 6 additions & 6 deletions third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ limitations under the License.
#include "xla/stream_executor/scratch_allocator.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "xla/stream_executor/stream_executor_interface.h"
#include "xla/tsl/util/env_var.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
Expand Down Expand Up @@ -6776,7 +6776,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner {
DeviceMemoryBase output_data) const override {
auto algo = MakeAlgorithmDesc();

if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
if (static_cast<StreamExecutorInterface*>(parent_) !=
stream->parent()->implementation()) {
return tsl::errors::Internal(
"CudnnLegacyConvRunner cached across multiple StreamExecutors.");
Expand Down Expand Up @@ -7185,7 +7185,7 @@ class CudnnExecutionPlanRunner<void(Args...)>
absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
DeviceMemoryBase scratch_memory,
Args... inputs) const override {
if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
if (static_cast<StreamExecutorInterface*>(parent_) !=
stream->parent()->implementation()) {
return tsl::errors::Internal(
"CudnnExecutionPlanRunner cached across multiple StreamExecutors.");
Expand Down Expand Up @@ -7391,7 +7391,7 @@ class CudnnGraphRunner<void(Args...)> : public dnn::OpRunner<void(Args...)> {
absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
DeviceMemoryBase scratch_memory,
Args... inputs) const override {
if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
if (static_cast<StreamExecutorInterface*>(parent_) !=
stream->parent()->implementation()) {
return tsl::errors::Internal(
"CudnnExecutionPlanRunner cached across multiple StreamExecutors.");
Expand Down Expand Up @@ -7861,7 +7861,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner {
DeviceMemoryBase side_input_data,
DeviceMemoryBase bias_data,
DeviceMemoryBase output_data) const override {
if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
if (static_cast<StreamExecutorInterface*>(parent_) !=
stream->parent()->implementation()) {
return tsl::errors::Internal(
"CudnnLegacyFusedConvRunner cached across multiple "
Expand Down Expand Up @@ -9797,7 +9797,7 @@ void initialize_cudnn() {
absl::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
cuda::kCudaPlatformId, "cuDNN",
[](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
[](StreamExecutorInterface* parent) -> dnn::DnnSupport* {
gpu::GpuExecutor* cuda_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (cuda_executor == nullptr) {
Expand Down

0 comments on commit 49d4f54

Please sign in to comment.