From 55a80e5f3cca6de34629748c4db48563271cbcba Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 22 May 2024 12:51:23 -0700 Subject: [PATCH] Move StreamExecutorMemoryAllocator to its own header & implementation file. This helps eliminate some circular dependencies. PiperOrigin-RevId: 636262525 --- third_party/xla/xla/service/BUILD | 4 + third_party/xla/xla/service/backend.cc | 1 + third_party/xla/xla/service/backend.h | 1 + .../service/generic_transfer_manager_test.cc | 1 + third_party/xla/xla/service/gpu/BUILD | 7 +- .../xla/xla/service/gpu/autotuner_util.h | 1 + third_party/xla/xla/service/gpu/fusions/BUILD | 1 + .../xla/xla/service/gpu/fusions/cudnn_test.cc | 1 + .../xla/service/gpu/gemm_fusion_autotuner.cc | 1 + third_party/xla/xla/service/gpu/runtime/BUILD | 9 +- .../runtime/address_computation_thunk_test.cc | 1 + .../gpu/runtime/command_buffer_cmd_test.cc | 1 + .../gpu/runtime/command_buffer_thunk_test.cc | 1 + third_party/xla/xla/service/gpu/tests/BUILD | 2 + .../service/gpu/tests/gemm_rewrite_test.cc | 1 + .../gpu/tests/gpu_too_many_blocks_test.cc | 1 + third_party/xla/xla/service/hlo_runner.cc | 1 + .../xla/xla/service/shaped_buffer_test.cc | 1 + third_party/xla/xla/stream_executor/BUILD | 24 +++++- .../stream_executor/device_memory_allocator.h | 55 +----------- third_party/xla/xla/stream_executor/gpu/BUILD | 1 + .../gpu/redzone_allocator_test.cc | 1 + ...cc => stream_executor_memory_allocator.cc} | 8 +- .../stream_executor_memory_allocator.h | 83 +++++++++++++++++++ third_party/xla/xla/tests/BUILD | 11 +++ .../xla/xla/tests/buffer_donation_test.cc | 1 + .../xla/xla/tests/cpu_gpu_fusion_test.cc | 1 + .../xla/xla/tests/dot_operation_test.cc | 1 + third_party/xla/xla/tests/dynamic_ops_test.cc | 1 + third_party/xla/xla/tests/hlo_test_base.cc | 1 + .../xla/tests/local_client_execute_test.cc | 1 + .../xla/xla/tests/local_client_test_base.cc | 1 + .../xla/xla/tests/local_client_test_base.h | 1 + third_party/xla/xla/tests/while_test.cc | 1 + third_party/xla/xla/tools/BUILD | 1 + third_party/xla/xla/tools/xla_compile_lib.cc | 1 + 36 files changed, 164 insertions(+), 66 deletions(-) rename third_party/xla/xla/stream_executor/{device_memory_allocator.cc => stream_executor_memory_allocator.cc} (97%) create mode 100644 third_party/xla/xla/stream_executor/stream_executor_memory_allocator.h diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b9a1aa734ea3cc..abf68cb2b1359a 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1108,6 +1108,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_interface", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -1457,6 +1458,7 @@ xla_cc_test( "//xla:test", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_interface", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:xla_internal_test_main", "@local_tsl//tsl/platform:test_benchmark", ], @@ -4095,6 +4097,7 @@ xla_cc_test( "//xla:types", "//xla/stream_executor", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "//xla/tests:literal_test_util", @@ -5753,6 +5756,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/stream_executor", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:blocking_counter", diff --git a/third_party/xla/xla/service/backend.cc b/third_party/xla/xla/service/backend.cc index b3999ae2364001..82285fb781588a 100644 --- a/third_party/xla/xla/service/backend.cc +++ b/third_party/xla/xla/service/backend.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_interface.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/service/backend.h b/third_party/xla/xla/service/backend.h index 0c919517d98bb7..89a91dddf32de6 100644 --- a/third_party/xla/xla/service/backend.h +++ b/third_party/xla/xla/service/backend.h @@ -33,6 +33,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" namespace Eigen { struct ThreadPoolDevice; diff --git a/third_party/xla/xla/service/generic_transfer_manager_test.cc b/third_party/xla/xla/service/generic_transfer_manager_test.cc index d0235816488e65..41ea92d46a0385 100644 --- a/third_party/xla/xla/service/generic_transfer_manager_test.cc +++ b/third_party/xla/xla/service/generic_transfer_manager_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/literal_test_util.h" #include "xla/types.h" #include "tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index fd226aebb54890..69212de9446054 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -767,7 +767,10 @@ cc_library( "@local_tsl//tsl/profiler/lib:scoped_annotation", "//xla/tsl/util/proto:proto_utils", "//xla/service/gpu:hlo_traversal", - ]) + ["@local_tsl//tsl/platform:path"], + ]) + [ + "//xla/stream_executor:stream_executor_memory_allocator", + "@local_tsl//tsl/platform:path", + ], ) xla_test( @@ -1610,7 +1613,7 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", - ]), + ]) + ["//xla/stream_executor:stream_executor_memory_allocator"], ) # We need a separate target, as runtime executable cannot depend on compilation diff --git a/third_party/xla/xla/service/gpu/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuner_util.h index d9ab1989ddaefc..4a57e31d29cfd4 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuner_util.h @@ -39,6 +39,7 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/xla.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index e600de03c9d3a2..af20239b5c67fe 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -706,6 +706,7 @@ xla_test( "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:stream_executor_headers", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:filecheck", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index cc380db0fffbc9..df80683e797fcd 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/tests/filecheck.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index 1644433f48e779..b69ec7de93401f 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -81,6 +81,7 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tools/hlo_decomposer.h" #include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 88765d1177cb41..15c20f4bf21aae 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -148,6 +148,7 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ ":command_buffer_cmd", + ":thunk", "//xla:status", "//xla:types", "//xla/service:buffer_assignment", @@ -155,10 +156,10 @@ xla_test( "//xla/service:platform_util", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", @@ -356,6 +357,7 @@ xla_test( ":address_computation_thunk", ":custom_call_thunk", ":gemm_thunk", + ":thunk", "//xla:shape_util", "//xla:types", "//xla/ffi", @@ -366,10 +368,10 @@ xla_test( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/algorithm:container", @@ -454,6 +456,7 @@ xla_test( deps = [ ":command_buffer_cmd", ":command_buffer_thunk", + ":thunk", "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", @@ -463,10 +466,10 @@ xla_test( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu/runtime:thunk", "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index bd51204ab2f480..6da63493cd5a15 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc index f96f4fa5d9e246..c6109c6ece5bc5 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/types.h" // IWYU pragma: keep #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index e8ab612924474b..b8394af8e3b5bf 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 8359fa18436004..3d4e2152ca1b67 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -184,6 +184,7 @@ xla_test( "//xla/service/gpu:gpu_executable", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "@com_google_absl//absl/container:flat_hash_map", @@ -230,6 +231,7 @@ xla_cc_test( ":gpu_codegen_test", "//xla/hlo/ir:hlo", "//xla/service:executable", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc index bebb21dbadd9c8..67a7b99d451c12 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/test.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" diff --git a/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc index 9832fb8895a13b..bf346ed724cf8b 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/hlo_runner.cc b/third_party/xla/xla/service/hlo_runner.cc index 564256b42b33fc..c51e189fe4784a 100644 --- a/third_party/xla/xla/service/hlo_runner.cc +++ b/third_party/xla/xla/service/hlo_runner.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/shaped_buffer_test.cc b/third_party/xla/xla/service/shaped_buffer_test.cc index cc7fcccef460cf..c13eb86f72168a 100644 --- a/third_party/xla/xla/service/shaped_buffer_test.cc +++ b/third_party/xla/xla/service/shaped_buffer_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor_interface.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 972aab2d39eace..6df97ba9d49c16 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -248,12 +248,10 @@ cc_library( cc_library( name = "device_memory_allocator", - srcs = ["device_memory_allocator.cc"], hdrs = ["device_memory_allocator.h"], deps = [ ":device_memory", ":platform", - ":stream_executor_headers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -269,6 +267,28 @@ cc_library( ], ) +cc_library( + name = "stream_executor_memory_allocator", + srcs = ["stream_executor_memory_allocator.cc"], + hdrs = ["stream_executor_memory_allocator.h"], + deps = [ + ":device_memory", + ":device_memory_allocator", + ":platform", + ":stream_executor_headers", + ":stream_executor_interface", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "host_memory_allocation", srcs = ["host_memory_allocation.cc"], diff --git a/third_party/xla/xla/stream_executor/device_memory_allocator.h b/third_party/xla/xla/stream_executor/device_memory_allocator.h index bab6812458036b..9520fa9b917012 100644 --- a/third_party/xla/xla/stream_executor/device_memory_allocator.h +++ b/third_party/xla/xla/stream_executor/device_memory_allocator.h @@ -18,19 +18,12 @@ limitations under the License. #include #include -#include -#include -#include -#include "absl/base/thread_annotations.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor_interface.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" @@ -79,8 +72,7 @@ class ScopedDeviceMemory { device_ordinal_(other.device_ordinal_), allocator_(other.allocator_) {} - // Releases the memory that was provided in the constructor, through the - // "parent" StreamExecutor. + // Releases the memory that was provided in the constructor. ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); } // Moves ownership of the memory from other to this object. @@ -223,51 +215,6 @@ class DeviceMemoryAllocator { const Platform *platform_; }; -// Default memory allocator for a platform which uses -// StreamExecutor::Allocate/Deallocate. -class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { - public: - // Create an allocator supporting a single device, corresponding to the passed - // executor. - explicit StreamExecutorMemoryAllocator(StreamExecutorInterface *executor); - - // Create an allocator supporting multiple stream executors. - // - // Precondition: all stream_executors have different device ordinals. - StreamExecutorMemoryAllocator( - const Platform *platform, - absl::Span stream_executors); - - absl::StatusOr Allocate(int device_ordinal, uint64_t size, - bool retry_on_failure, - int64_t memory_space) override; - - // Pull in two-arg overload that sets retry_on_failure to true. - using DeviceMemoryAllocator::Allocate; - - absl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override; - - bool AllowsAsynchronousDeallocation() const override; - - // Gets-or-creates a stream for a given `device_ordinal` from an appropriate - // stream executor. - absl::StatusOr GetStream(int device_ordinal) override; - - // Gets the stream executor for given device ordinal. - absl::StatusOr GetStreamExecutor( - int device_ordinal) const; - - private: - // Available stream executors. Each stream executor has a different device - // ordinal. - std::vector stream_executors_; - - absl::Mutex mutex_; - - // Cache of streams for GetStream. - std::map> streams_ ABSL_GUARDED_BY(mutex_); -}; - template absl::Status ScopedDeviceMemory::Free() { if (!wrapped_.is_null()) { diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index d43863f01115dd..db228a20fe7a27 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -566,6 +566,7 @@ xla_cc_test( ":redzone_allocator", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc index 6c5ee154d15026..1ab7dea3030050 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/stream_executor/device_memory_allocator.cc b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc similarity index 97% rename from third_party/xla/xla/stream_executor/device_memory_allocator.cc rename to third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc index d659322c263fae..c9b1494e1ed4f0 100644 --- a/third_party/xla/xla/stream_executor/device_memory_allocator.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include -#include #include -#include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_interface.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.h b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.h new file mode 100644 index 00000000000000..077b7161b66321 --- /dev/null +++ b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_MEMORY_ALLOCATOR_H_ +#define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_MEMORY_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor_interface.h" + +namespace stream_executor { + +// Default memory allocator for a platform which uses +// StreamExecutor::Allocate/Deallocate. +class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { + public: + // Create an allocator supporting a single device, corresponding to the + // passed executor. + explicit StreamExecutorMemoryAllocator(StreamExecutorInterface *executor); + + // Create an allocator supporting multiple stream executors. + // + // Precondition: all stream_executors have different device ordinals. + StreamExecutorMemoryAllocator( + const Platform *platform, + absl::Span stream_executors); + + absl::StatusOr Allocate(int device_ordinal, uint64_t size, + bool retry_on_failure, + int64_t memory_space) override; + + // Pull in two-arg overload that sets retry_on_failure to true. + using DeviceMemoryAllocator::Allocate; + + absl::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override; + + bool AllowsAsynchronousDeallocation() const override; + + // Gets-or-creates a stream for a given `device_ordinal` from an appropriate + // stream executor. + absl::StatusOr GetStream(int device_ordinal) override; + + // Gets the stream executor for given device ordinal. + absl::StatusOr GetStreamExecutor( + int device_ordinal) const; + + private: + // Available stream executors. Each stream executor has a different device + // ordinal. + std::vector stream_executors_; + + absl::Mutex mutex_; + + // Cache of streams for GetStream. + std::map> streams_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_MEMORY_ALLOCATOR_H_ diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 1f59b9c1b2ca8b..0ca5d12b25a262 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -220,6 +220,7 @@ cc_library( "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", "//xla/stream_executor", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", @@ -373,6 +374,7 @@ cc_library( "//xla/service:transfer_manager", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", @@ -414,6 +416,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service:backend", "//xla/service:executable", + "//xla/stream_executor:stream_executor_memory_allocator", "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -538,6 +541,7 @@ xla_test( "//xla/client:xla_computation", "//xla/client/lib:arithmetic", "//xla/service:platform_util", + "//xla/stream_executor:stream_executor_memory_allocator", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", @@ -922,6 +926,7 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", "//xla/service:hlo_parser", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -959,6 +964,7 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", "//xla/service:hlo_parser", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -1001,6 +1007,7 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", "//xla/service:hlo_parser", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -1077,6 +1084,7 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", "//xla/service:hlo_parser", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -1533,6 +1541,7 @@ xla_test( "//xla/service:transfer_manager", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ], @@ -2475,6 +2484,7 @@ xla_test( "//xla/client:xla_builder", "//xla/hlo/ir:hlo", "//xla/service:platform_util", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:logging", @@ -2572,6 +2582,7 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "@local_tsl//tsl/platform:env", diff --git a/third_party/xla/xla/tests/buffer_donation_test.cc b/third_party/xla/xla/tests/buffer_donation_test.cc index 184161921433f4..44ff367ad6aa79 100644 --- a/third_party/xla/xla/tests/buffer_donation_test.cc +++ b/third_party/xla/xla/tests/buffer_donation_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/service/backend.h" #include "xla/service/executable.h" #include "xla/status_macros.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/verified_hlo_module.h" diff --git a/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc b/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc index 9524d65d41d486..34cd15db068096 100644 --- a/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc +++ b/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index f272d8da4e00d3..546d53e456e5a7 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/reference_util.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/dynamic_ops_test.cc b/third_party/xla/xla/tests/dynamic_ops_test.cc index 58e4a1bb0731c3..0de3c8f6dbec37 100644 --- a/third_party/xla/xla/tests/dynamic_ops_test.cc +++ b/third_party/xla/xla/tests/dynamic_ops_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/service/transfer_manager.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index 12f8b6fde817b1..6f19458e56c598 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/pjrt_client_registry.h" diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index af10fcf3620f8b..53fce7717be064 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/test_helpers.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/local_client_test_base.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.cc b/third_party/xla/xla/tests/local_client_test_base.cc index 60d0310fa7d0b5..aa8df18711c4f5 100644 --- a/third_party/xla/xla/tests/local_client_test_base.cc +++ b/third_party/xla/xla/tests/local_client_test_base.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/statusor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/test_helpers.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.h b/third_party/xla/xla/tests/local_client_test_base.h index 02504d0ab0b7bd..2ade8663703027 100644 --- a/third_party/xla/xla/tests/local_client_test_base.h +++ b/third_party/xla/xla/tests/local_client_test_base.h @@ -34,6 +34,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/manifest_checking_test.h" #include "xla/tests/verified_hlo_module.h" diff --git a/third_party/xla/xla/tests/while_test.cc b/third_party/xla/xla/tests/while_test.cc index 2406e7869925fd..dd9db003037340 100644 --- a/third_party/xla/xla/tests/while_test.cc +++ b/third_party/xla/xla/tests/while_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/statusor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 0b31a60c0cae19..72833194a5283d 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -743,6 +743,7 @@ tsl_gpu_library( "//xla/service/cpu:cpu_executable", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc index 2c7e2b36b50f46..40681060a9fcb4 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.cc +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/status.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tools/hlo_module_loader.h" #include "xla/util.h" #include "tsl/platform/env.h"