From 7e2f75ea7ea50faa8857c0dc7b234edbada95626 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 8 Oct 2025 10:09:36 -0700 Subject: [PATCH 1/2] Move cuda/runtime/shims/utils to cuda/runtime for better usability. This diff moves `cuda/runtime/shims/utils.h` to `cuda/runtime/utils.h` so that the utility functions can be shared by all runtime infrastructure. Differential Revision: [D84169267](https://our.internmc.facebook.com/intern/diff/D84169267/) ghstack-source-id: 314867735 Pull Request resolved: https://github.com/pytorch/executorch/pull/14900 --- backends/cuda/runtime/TARGETS | 2 +- backends/cuda/runtime/shims/memory.cpp | 2 +- .../runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp | 2 +- backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp | 2 +- .../shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp | 2 +- .../shims/tests/test_aoti_torch_delete_tensor_object.cpp | 2 +- .../cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp | 2 +- backends/cuda/runtime/{shims => }/utils.h | 0 8 files changed, 7 insertions(+), 7 deletions(-) rename backends/cuda/runtime/{shims => }/utils.h (100%) diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 1aa38760e5a..29fba0e706a 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -11,7 +11,7 @@ runtime.cxx_library( headers = [ "shims/memory.h", "shims/tensor_attribute.h", - "shims/utils.h", + "utils.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 2b32d820301..cbaca68576e 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -10,7 +10,7 @@ #include #include #include -#include <executorch/backends/cuda/runtime/shims/utils.h> +#include <executorch/backends/cuda/runtime/utils.h> #include #include #include // For posix_memalign diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp index ef00ecff656..e18bf142b5c 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp @@ -10,7 +10,7 @@ #include #include #include -#include <executorch/backends/cuda/runtime/shims/utils.h> +#include <executorch/backends/cuda/runtime/utils.h> #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp index 7579eaef039..9fca0f92cf8 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp @@ -10,7 +10,7 @@ #include #include #include -#include <executorch/backends/cuda/runtime/shims/utils.h> +#include <executorch/backends/cuda/runtime/utils.h> #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp index 2cb12719782..d9b785a5a78 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp @@ -10,7 +10,7 @@ #include #include #include -#include <executorch/backends/cuda/runtime/shims/utils.h> +#include <executorch/backends/cuda/runtime/utils.h> #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp index eceb141e9ca..10c8d8c1a31 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp +++
b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp @@ -10,7 +10,7 @@ #include #include #include -#include <executorch/backends/cuda/runtime/shims/utils.h> +#include <executorch/backends/cuda/runtime/utils.h> #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp index 8e6998f457c..da65129f18a 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp @@ -10,7 +10,7 @@ #include #include #include -#include <executorch/backends/cuda/runtime/shims/utils.h> +#include <executorch/backends/cuda/runtime/utils.h> #include #include #include diff --git a/backends/cuda/runtime/shims/utils.h b/backends/cuda/runtime/utils.h similarity index 100% rename from backends/cuda/runtime/shims/utils.h rename to backends/cuda/runtime/utils.h From 0a7ac54a42949f624b9973bb2dcdb9a384459382 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 8 Oct 2025 10:09:40 -0700 Subject: [PATCH 2/2] Introduce CUDAGuard and CUDAStreamGuard ### Introduce CUDAGuard and CUDAStreamGuard This diff introduces `CUDAGuard` and `CUDAStreamGuard` in the ExecuTorch CUDA runtime. These classes provide a convenient way to manage CUDA device and stream selection. #### Changes * Added the `CUDAGuard` and `CUDAStreamGuard` classes in `fbcode/executorch/backends/cuda/runtime/guard.h`. * Implemented `CUDAGuard` and `CUDAStreamGuard` in `fbcode/executorch/backends/cuda/runtime/guard.cpp`. * Added unit tests for `CUDAGuard` and `CUDAStreamGuard` in `fbcode/executorch/backends/cuda/runtime/tests/test_cuda_guard.cpp` and `fbcode/executorch/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp`. * Updated the `TARGETS` and `CMakeLists.txt` files to include the new files. #### Purpose The `CUDAGuard` class provides a way to select a CUDA device and ensures that the original device is restored when the guard goes out of scope. The `CUDAStreamGuard` class provides a way to select the current CUDA stream for a device and ensures that the original stream (and device) is restored when the guard goes out of scope. #### Usage They will be further used and controlled by their shim layer functions.
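A minimal usage sketch (illustrative only, not part of this diff): `my_stream` stands for any already-created `cudaStream_t`, and the include path simply follows this diff's file layout.

```cpp
#include <executorch/backends/cuda/runtime/guard.h>

using executorch::backends::cuda::CUDAStreamGuard;
using executorch::backends::cuda::getCurrentCUDAStream;

void launch_on_device_zero(cudaStream_t my_stream) {
  // Make device 0 the current device and my_stream its current stream for
  // the duration of this scope.
  auto guard = CUDAStreamGuard::create(my_stream, /*device_index=*/0);
  if (!guard.ok()) {
    return; // e.g. invalid device index
  }
  // Work submitted here can query getCurrentCUDAStream(0), which now
  // returns my_stream.
}  // The guard's destructor restores the previous stream and device.
```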
Differential Revision: [D84126481](https://our.internmc.facebook.com/intern/diff/D84126481/) ghstack-source-id: 314867936 Pull Request resolved: https://github.com/pytorch/executorch/pull/14901 --- backends/cuda/CMakeLists.txt | 2 +- backends/cuda/runtime/TARGETS | 2 + backends/cuda/runtime/guard.cpp | 151 ++++++++++ backends/cuda/runtime/guard.h | 195 +++++++++++++ backends/cuda/runtime/tests/TARGETS | 6 + backends/cuda/runtime/tests/targets.bzl | 27 ++ .../cuda/runtime/tests/test_cuda_guard.cpp | 113 ++++++++ .../runtime/tests/test_cuda_stream_guard.cpp | 264 ++++++++++++++++++ 8 files changed, 759 insertions(+), 1 deletion(-) create mode 100644 backends/cuda/runtime/guard.cpp create mode 100644 backends/cuda/runtime/guard.h create mode 100644 backends/cuda/runtime/tests/TARGETS create mode 100644 backends/cuda/runtime/tests/targets.bzl create mode 100644 backends/cuda/runtime/tests/test_cuda_guard.cpp create mode 100644 backends/cuda/runtime/tests/test_cuda_stream_guard.cpp diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 90588218c02..7a53478773d 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -36,7 +36,7 @@ find_package_torch() # CUDA-specific AOTI functionality set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp - runtime/shims/tensor_attribute.cpp + runtime/shims/tensor_attribute.cpp runtime/guard.cpp ) add_library(aoti_cuda STATIC ${_aoti_cuda_sources}) target_include_directories( diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 29fba0e706a..c4b778eccc5 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -5,10 +5,12 @@ oncall("executorch") runtime.cxx_library( name = "runtime_shims", srcs = [ + "guard.cpp", "shims/memory.cpp", "shims/tensor_attribute.cpp", ], headers = [ + "guard.h", "shims/memory.h", "shims/tensor_attribute.h", "utils.h", diff --git a/backends/cuda/runtime/guard.cpp b/backends/cuda/runtime/guard.cpp new file mode 100644 index 00000000000..885efc7670d --- /dev/null +++ b/backends/cuda/runtime/guard.cpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +namespace { +// Thread-local stream storage (private to this file) +thread_local std::unordered_map<DeviceIndex, cudaStream_t> current_streams_; +} // namespace + +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) { + if (device_index == -1) { + // Get current device if not specified + int current_device; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&current_device)); + device_index = current_device; + } + + current_streams_[device_index] = stream; + return Error::Ok; +} + +Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index) { + if (device_index == -1) { + int current_device; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&current_device)); + device_index = current_device; + } + + auto it = current_streams_.find(device_index); + if (it != current_streams_.end()) { + return it->second; + } + + cudaStream_t stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream)); + setCurrentCUDAStream(stream, device_index); + return stream; +} + +CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept + : original_device_index_(other.original_device_index_), + current_device_index_(other.current_device_index_) { + // Mark the moved-from object as "already restored" so its destructor doesn't + // try to restore the device + other.original_device_index_ = other.current_device_index_; +} + +CUDAGuard::~CUDAGuard() { + if (original_device_index_ != current_device_index_) { + cudaError_t err = cudaSetDevice(original_device_index_); + if (err != cudaSuccess) { + ET_LOG( + Error, + "~CUDAGuard: Failed to restore device to %d: %s", + original_device_index_, + cudaGetErrorString(err)); + } + } +} + +Error CUDAGuard::set_index(DeviceIndex device_index) { + int orig_index = -1; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&orig_index)); + + original_device_index_ = orig_index; + current_device_index_ = device_index; + + if (current_device_index_ != original_device_index_) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_)); + } + + return Error::Ok; +} + +Result<CUDAGuard> CUDAGuard::create(DeviceIndex device_index) { + CUDAGuard guard; // Default-constructed; set_index() selects the device below. + ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index)); + return guard; +} + +CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept + : device_guard_(std::move(other.device_guard_)), + original_stream_(other.original_stream_), + current_stream_(other.current_stream_), + device_index_(other.device_index_) { + // Mark the moved-from object as "already restored" so its destructor doesn't + // try to restore the stream + other.original_stream_ = other.current_stream_; +} + +CUDAStreamGuard::~CUDAStreamGuard() { + // Restore the original stream unless this object was moved-from. + // After a move, original_stream_ == current_stream_, which indicates + // the moved-from object should not restore. + // Note: nullptr is a valid stream value (represents the default stream), + // so we must restore even if original_stream_ is nullptr. + if (original_stream_ != current_stream_) { + Error err = setCurrentCUDAStream(original_stream_, device_index_); + if (err != Error::Ok) { + ET_LOG( + Error, + "~CUDAStreamGuard: Failed to restore stream for device %d", + device_index_); + } + } +} + +Error CUDAStreamGuard::set_stream( + cudaStream_t stream, + DeviceIndex device_index) { + auto result = getCurrentCUDAStream(device_index); + if (!result.ok()) { + ET_LOG(Error, "Failed to get current stream for device %d", device_index); + return result.error(); + } + + original_stream_ = result.get(); + current_stream_ = stream; + device_index_ = device_index; + + ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index)); + + return Error::Ok; +} + +Result<CUDAStreamGuard> CUDAStreamGuard::create( + cudaStream_t stream, + DeviceIndex device_index) { + auto guard_result = CUDAGuard::create(device_index); + ET_CHECK_OK_OR_RETURN_ERROR(guard_result.error()); + + CUDAStreamGuard stream_guard(std::move(guard_result.get())); + ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.set_stream(stream, device_index)); + + return stream_guard; +} + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/guard.h b/backends/cuda/runtime/guard.h new file mode 100644 index 00000000000..4e5a18a4c0f --- /dev/null +++ b/backends/cuda/runtime/guard.h @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +using executorch::runtime::Error; +using executorch::runtime::Result; + +// Type alias for device index +using DeviceIndex = int32_t; + +/** + * Set the current CUDA stream for the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index (-1 to use current device) + * @return Error code indicating success or failure + */ +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream for the specified device. + * If no stream has been set, creates a new stream and sets it as current. + * + * @param device_index The device index (-1 to use current device) + * @return Result containing the current stream on success, or an error code on + * failure + */ +Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index = -1); + +/** + * RAII guard that sets the current CUDA device and restores it on destruction. + * This ensures that the device is properly restored even if an exception + * occurs. + * + */ +class CUDAGuard { + private: + /** + * Private constructor - use create() factory method instead. + */ + explicit CUDAGuard() + : original_device_index_(-1), current_device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAGuard. + * + * @param device_index The device index to set as current + * @return Result containing the guard on success, or an error code on failure + */ + static Result<CUDAGuard> create(DeviceIndex device_index); + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move constructor and assignment + CUDAGuard(CUDAGuard&& other) noexcept; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + /** + * Destructor that restores the original device if necessary.
+ */ + ~CUDAGuard(); + + /** + * Sets the CUDA device to the given device index. + * + * @param device_index The device index to set as current + * @return Error code indicating success or failure + */ + Error set_index(DeviceIndex device_index); + + /** + * Get the original device index before the guard was created. + * + * @return The original device index + */ + DeviceIndex original_device() const { + return original_device_index_; + } + + /** + * Get the current device index. + * + * @return The current device index + */ + DeviceIndex current_device() const { + return current_device_index_; + } + + private: + /// The original device before this guard was created + DeviceIndex original_device_index_; + /// The current device managed by this guard + DeviceIndex current_device_index_; +}; + +/** + * RAII guard that sets the current CUDA device and stream, restoring both on + * destruction. This is useful for temporarily switching to a different device + * and stream. + * + */ +class CUDAStreamGuard { + private: + // Private constructor that takes a CUDAGuard + explicit CUDAStreamGuard(CUDAGuard&& guard) + : device_guard_(std::move(guard)), + original_stream_(nullptr), + current_stream_(nullptr), + device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAStreamGuard. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Result containing the guard on success, or an error code on failure + */ + static Result create( + cudaStream_t stream, + DeviceIndex device_index); + + // Copy is not allowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + // Move constructor and assignment + CUDAStreamGuard(CUDAStreamGuard&& other) noexcept; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) noexcept = delete; + + /** + * Destructor that restores the original stream and device. + */ + ~CUDAStreamGuard(); + + /** + * Sets the CUDA stream to the given stream on the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Error code indicating success or failure + */ + Error set_stream(cudaStream_t stream, DeviceIndex device_index); + + /** + * Get the current guarded stream. + * + * @return The current stream + */ + cudaStream_t stream() const { + return current_stream_; + } + + /** + * Get the device index being guarded. 
+ * + * @return The device index + */ + DeviceIndex device_index() const { + return device_index_; + } + + private: + /// The device guard that handles device switching + CUDAGuard device_guard_; + /// The original stream that was current before this guard + cudaStream_t original_stream_ = nullptr; + /// The current stream being guarded + cudaStream_t current_stream_ = nullptr; + /// The device index for this stream guard + DeviceIndex device_index_; +}; + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/tests/TARGETS b/backends/cuda/runtime/tests/TARGETS new file mode 100644 index 00000000000..9ff3e83a8bd --- /dev/null +++ b/backends/cuda/runtime/tests/TARGETS @@ -0,0 +1,6 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/cuda/runtime/tests/targets.bzl b/backends/cuda/runtime/tests/targets.bzl new file mode 100644 index 00000000000..37e8d876526 --- /dev/null +++ b/backends/cuda/runtime/tests/targets.bzl @@ -0,0 +1,27 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") + +def cuda_runtime_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/cuda/runtime:runtime_shims", + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + cuda_runtime_cpp_unittest("cuda_guard") + cuda_runtime_cpp_unittest("cuda_stream_guard") diff --git a/backends/cuda/runtime/tests/test_cuda_guard.cpp b/backends/cuda/runtime/tests/test_cuda_guard.cpp new file mode 100644 index 00000000000..a364ae98484 --- /dev/null +++ b/backends/cuda/runtime/tests/test_cuda_guard.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. 
These tests should be added in the future when +// multi-GPU test environments are available. + +class CUDAGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available or no CUDA devices found"; + } + device_count_ = device_count; + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + + void TearDown() override { + if (device_count_ > 0) { + ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess); + } + } + + int device_count_ = 0; + int original_device_ = 0; +}; + +TEST_F(CUDAGuardTest, BasicDeviceSwitching) { + int current_device; + ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess); + + { + auto guard_result = CUDAGuard::create(0); + ASSERT_TRUE(guard_result.ok()); + CUDAGuard guard = std::move(guard_result.get()); + + int device_after_guard; + ASSERT_EQ(cudaGetDevice(&device_after_guard), cudaSuccess); + EXPECT_EQ(device_after_guard, 0); + EXPECT_EQ(guard.current_device(), 0); + EXPECT_EQ(guard.original_device(), current_device); + } + + int device_after_destruction; + ASSERT_EQ(cudaGetDevice(&device_after_destruction), cudaSuccess); + EXPECT_EQ(device_after_destruction, current_device); +} + +TEST_F(CUDAGuardTest, SameDeviceNoSwitching) { + ASSERT_EQ(cudaSetDevice(0), cudaSuccess); + + { + auto guard_result = CUDAGuard::create(0); + ASSERT_TRUE(guard_result.ok()); + CUDAGuard guard = std::move(guard_result.get()); + + int current_device; + ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess); + EXPECT_EQ(current_device, 0); + EXPECT_EQ(guard.current_device(), 0); + EXPECT_EQ(guard.original_device(), 0); + } + + int final_device; + ASSERT_EQ(cudaGetDevice(&final_device), cudaSuccess); + EXPECT_EQ(final_device, 0); +} + +TEST_F(CUDAGuardTest, InvalidDeviceIndex) { + auto guard_result = CUDAGuard::create(999); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAGuardTest, NegativeDeviceIndex) { + auto guard_result = CUDAGuard::create(-2); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAGuardTest, CopyConstructorDeleted) { + static_assert( + !std::is_copy_constructible_v<CUDAGuard>, + "CUDAGuard should not be copy constructible"); +} + +TEST_F(CUDAGuardTest, CopyAssignmentDeleted) { + static_assert( + !std::is_copy_assignable_v<CUDAGuard>, + "CUDAGuard should not be copy assignable"); +} + +TEST_F(CUDAGuardTest, MoveAssignmentDeleted) { + static_assert( + !std::is_move_assignable_v<CUDAGuard>, + "CUDAGuard should not be move assignable"); +} diff --git a/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp b/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp new file mode 100644 index 00000000000..68a050a69be --- /dev/null +++ b/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations.
These tests should be added in the future when +// multi-GPU test environments are available. + +class CUDAStreamGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available or no CUDA devices found"; + } + device_count_ = device_count; + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + + ASSERT_EQ(cudaStreamCreate(&test_stream1_), cudaSuccess); + ASSERT_EQ(cudaStreamCreate(&test_stream2_), cudaSuccess); + } + + void TearDown() override { + if (test_stream1_) { + ASSERT_EQ(cudaStreamDestroy(test_stream1_), cudaSuccess); + } + if (test_stream2_) { + ASSERT_EQ(cudaStreamDestroy(test_stream2_), cudaSuccess); + } + + if (device_count_ > 0) { + ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess); + } + } + + int device_count_ = 0; + int original_device_ = 0; + cudaStream_t test_stream1_ = nullptr; + cudaStream_t test_stream2_ = nullptr; +}; + +TEST_F(CUDAStreamGuardTest, BasicStreamSwitching) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + EXPECT_EQ(guard.device_index(), 0); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); + + int current_device; + ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess); + EXPECT_EQ(current_device, 0); +} + +TEST_F(CUDAStreamGuardTest, StreamSwitchingOnSameDevice) { + Error err = setCurrentCUDAStream(test_stream1_, 0); + ASSERT_EQ(err, Error::Ok); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); + + { + auto guard_result = CUDAStreamGuard::create(test_stream2_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + auto new_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(new_stream_result.ok()); + EXPECT_EQ(new_stream_result.get(), test_stream2_); + EXPECT_EQ(guard.stream(), test_stream2_); + } + + auto restored_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(restored_stream_result.ok()); + EXPECT_EQ(restored_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, NestedStreamGuards) { + cudaStream_t initial_stream; + ASSERT_EQ(cudaStreamCreate(&initial_stream), cudaSuccess); + + Error err = setCurrentCUDAStream(initial_stream, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), test_stream1_); + + { + auto guard2_result = CUDAStreamGuard::create(test_stream2_, 0); + ASSERT_TRUE(guard2_result.ok()); + CUDAStreamGuard guard2 = std::move(guard2_result.get()); + + auto stream_result2 = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result2.ok()); + EXPECT_EQ(stream_result2.get(), test_stream2_); + } + + auto stream_result3 = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result3.ok()); + EXPECT_EQ(stream_result3.get(), test_stream1_); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); +
EXPECT_EQ(final_stream_result.get(), initial_stream); + + ASSERT_EQ(cudaStreamDestroy(initial_stream), cudaSuccess); +} + +TEST_F(CUDAStreamGuardTest, SameStreamNoChange) { + Error err = setCurrentCUDAStream(test_stream1_, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), test_stream1_); + EXPECT_EQ(guard.stream(), test_stream1_); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, StreamAccessor) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + EXPECT_EQ(guard.device_index(), 0); +} + +TEST_F(CUDAStreamGuardTest, SetStreamMethod) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + + Error err = guard.set_stream(test_stream2_, 0); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(guard.stream(), test_stream2_); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream2_); +} + +TEST_F(CUDAStreamGuardTest, MoveConstructor) { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + EXPECT_EQ(guard1.stream(), test_stream1_); + EXPECT_EQ(guard1.device_index(), 0); + + CUDAStreamGuard guard2 = std::move(guard1); + + EXPECT_EQ(guard2.stream(), test_stream1_); + EXPECT_EQ(guard2.device_index(), 0); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, MoveConstructorRestoresOnlyOnce) { + cudaStream_t initial_stream; + ASSERT_EQ(cudaStreamCreate(&initial_stream), cudaSuccess); + + Error err = setCurrentCUDAStream(initial_stream, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + { CUDAStreamGuard guard2 = std::move(guard1); } + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), initial_stream); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), initial_stream); + + ASSERT_EQ(cudaStreamDestroy(initial_stream), cudaSuccess); +} + +TEST_F(CUDAStreamGuardTest, InvalidDeviceIndex) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 999); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAStreamGuardTest, NegativeDeviceIndex) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, -2); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAStreamGuardTest, CopyConstructorDeleted) { + static_assert( + !std::is_copy_constructible_v, + "CUDAStreamGuard should not be copy constructible"); +} + +TEST_F(CUDAStreamGuardTest, CopyAssignmentDeleted) { + static_assert( + 
!std::is_copy_assignable_v<CUDAStreamGuard>, + "CUDAStreamGuard should not be copy assignable"); +} + +TEST_F(CUDAStreamGuardTest, MoveAssignmentDeleted) { + static_assert( + !std::is_move_assignable_v<CUDAStreamGuard>, + "CUDAStreamGuard should not be move assignable"); +} + +TEST_F(CUDAStreamGuardTest, NullStreamPointer) { + auto guard_result = CUDAStreamGuard::create(nullptr, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), nullptr); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); +}