diff --git a/include/rmm/mr/device/failure_callback_resource_adaptor.hpp b/include/rmm/mr/device/failure_callback_resource_adaptor.hpp index d3f8e5b7a..b4e497868 100644 --- a/include/rmm/mr/device/failure_callback_resource_adaptor.hpp +++ b/include/rmm/mr/device/failure_callback_resource_adaptor.hpp @@ -20,42 +20,40 @@ #include #include +#include namespace rmm::mr { /** * @brief Callback function type used by failure_callback_resource_adaptor * - * The resource adaptor calls this function when a memory allocation throws a - * `std::bad_alloc` exception. The function decides whether the resource adaptor - * should try to allocate the memory again or re-throw the `std::bad_alloc` - * exception. + * The resource adaptor calls this function when a memory allocation throws a specified exception + * type. The function decides whether the resource adaptor should try to allocate the memory again + * or re-throw the exception. * * The callback function signature is: * `bool failure_callback_t(std::size_t bytes, void* callback_arg)` * - * The callback function will be passed two parameters: `bytes` is the size of the - * failed memory allocation, and `arg` is the extra argument passed to the constructor - * of the `failure_callback_resource_adaptor`. The callback function returns a Boolean - * where true means to retry the memory allocation and false means to throw a - * `rmm::bad_alloc` exception. + * The callback function is passed two parameters: `bytes` is the size of the failed memory + * allocation and `arg` is the extra argument passed to the constructor of the + * `failure_callback_resource_adaptor`. The callback function returns a Boolean where true means to + * retry the memory allocation and false means to re-throw the exception. */ using failure_callback_t = std::function; /** * @brief A device memory resource that calls a callback function when allocations - * throws `std::bad_alloc`. + * throw a specified exception type. * * An instance of this resource must be constructed with an existing, upstream * resource in order to satisfy allocation requests. * * The callback function takes an allocation size and a callback argument and returns - * a bool representing whether to retry the allocation (true) or throw `std::bad_alloc` + * a bool representing whether to retry the allocation (true) or re-throw the caught exception * (false). * - * When implementing a callback function for allocation retry, care must be taken to - * avoid an infinite loop. In the following example, we make sure to only retry the allocation - * once: + * When implementing a callback function for allocation retry, care must be taken to avoid an + * infinite loop. The following example makes sure to only retry the allocation once: * * @code{c++} * using failure_callback_adaptor = @@ -67,9 +65,8 @@ using failure_callback_t = std::function; * if (!retried) { * retried = true; * return true; // First time we request an allocation retry - * } else { - * return false; // Second time we let the adaptor throw std::bad_alloc * } + * return false; // Second time we let the adaptor throw std::bad_alloc * } * * int main() @@ -83,10 +80,13 @@ using failure_callback_t = std::function; * @endcode * * @tparam Upstream The type of the upstream resource used for allocation/deallocation. + * @tparam ExceptionType The type of exception that this adaptor should respond to */ -template +template class failure_callback_resource_adaptor final : public device_memory_resource { public: + using exception_type = ExceptionType; ///< The type of exception this object catches/throws + /** * @brief Construct a new `failure_callback_resource_adaptor` using `upstream` to satisfy * allocation requests. @@ -100,7 +100,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource { failure_callback_resource_adaptor(Upstream* upstream, failure_callback_t callback, void* callback_arg) - : upstream_{upstream}, callback_{callback}, callback_arg_{callback_arg} + : upstream_{upstream}, callback_{std::move(callback)}, callback_arg_{callback_arg} { RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer."); } @@ -126,14 +126,17 @@ class failure_callback_resource_adaptor final : public device_memory_resource { * @return true The upstream resource supports streams * @return false The upstream resource does not support streams. */ - bool supports_streams() const noexcept override { return upstream_->supports_streams(); } + [[nodiscard]] bool supports_streams() const noexcept override + { + return upstream_->supports_streams(); + } /** * @brief Query whether the resource supports the get_mem_info API. * * @return bool true if the upstream resource supports get_mem_info, false otherwise. */ - bool supports_get_mem_info() const noexcept override + [[nodiscard]] bool supports_get_mem_info() const noexcept override { return upstream_->supports_get_mem_info(); } @@ -143,7 +146,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource { * @brief Allocates memory of size at least `bytes` using the upstream * resource. * - * @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled + * @throws `exception_type` if the requested allocation could not be fulfilled * by the upstream resource. * * @param bytes The size, in bytes, of the allocation @@ -152,13 +155,13 @@ class failure_callback_resource_adaptor final : public device_memory_resource { */ void* do_allocate(std::size_t bytes, cuda_stream_view stream) override { - void* ret; + void* ret{}; while (true) { try { ret = upstream_->allocate(bytes, stream); break; - } catch (std::bad_alloc const& e) { + } catch (exception_type const& e) { if (!callback_(bytes, callback_arg_)) { throw; } } } @@ -188,7 +191,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource { * @return true If the two resources are equivalent * @return false If the two resources are not equal */ - bool do_is_equal(device_memory_resource const& other) const noexcept override + [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override { if (this == &other) { return true; } auto cast = dynamic_cast const*>(&other); @@ -204,7 +207,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource { * @param stream Stream on which to get the mem info. * @return std::pair contaiing free_size and total_size of memory */ - std::pair do_get_mem_info(cuda_stream_view stream) const override + [[nodiscard]] std::pair do_get_mem_info( + cuda_stream_view stream) const override { return upstream_->get_mem_info(stream); } diff --git a/tests/mr/device/failure_callback_mr_tests.cpp b/tests/mr/device/failure_callback_mr_tests.cpp index ef4553c1c..bb5484c69 100644 --- a/tests/mr/device/failure_callback_mr_tests.cpp +++ b/tests/mr/device/failure_callback_mr_tests.cpp @@ -15,6 +15,8 @@ */ #include "../../byte_literals.hpp" +#include "rmm/cuda_stream_view.hpp" +#include "rmm/mr/device/device_memory_resource.hpp" #include #include @@ -26,29 +28,92 @@ namespace rmm::test { namespace { +template using failure_callback_adaptor = - rmm::mr::failure_callback_resource_adaptor; + rmm::mr::failure_callback_resource_adaptor; -bool failure_handler(std::size_t bytes, void* arg) +bool failure_handler(std::size_t /*bytes*/, void* arg) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) bool& retried = *reinterpret_cast(arg); if (!retried) { retried = true; return true; // First time we request an allocation retry - } else { - return false; // Second time we let the adaptor throw std::bad_alloc } + return false; // Second time we let the adaptor throw std::bad_alloc } TEST(FailureCallbackTest, RetryAllocationOnce) { bool retried{false}; - failure_callback_adaptor mr{rmm::mr::get_current_device_resource(), failure_handler, &retried}; - rmm::mr::set_current_device_resource(&mr); + failure_callback_adaptor<> mr{rmm::mr::get_current_device_resource(), failure_handler, &retried}; EXPECT_EQ(retried, false); EXPECT_THROW(mr.allocate(512_GiB), std::bad_alloc); EXPECT_EQ(retried, true); } +template +class always_throw_memory_resource final : public mr::device_memory_resource { + private: + void* do_allocate(std::size_t bytes, cuda_stream_view stream) override + { + throw ExceptionType{"foo"}; + } + void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override{}; + [[nodiscard]] std::pair do_get_mem_info( + cuda_stream_view stream) const override + { + return {0, 0}; + } + + [[nodiscard]] bool supports_streams() const noexcept override { return false; } + [[nodiscard]] bool supports_get_mem_info() const noexcept override { return false; } +}; + +TEST(FailureCallbackTest, DifferentExceptionTypes) +{ + always_throw_memory_resource bad_alloc_mr; + always_throw_memory_resource oom_mr; + + EXPECT_THROW(bad_alloc_mr.allocate(1_MiB), rmm::bad_alloc); + EXPECT_THROW(oom_mr.allocate(1_MiB), rmm::out_of_memory); + + // Wrap a bad_alloc-catching callback adaptor around an MR that always throws bad_alloc: + // Should retry once and then re-throw bad_alloc + { + bool retried{false}; + failure_callback_adaptor bad_alloc_callback_mr{ + &bad_alloc_mr, failure_handler, &retried}; + + EXPECT_EQ(retried, false); + EXPECT_THROW(bad_alloc_callback_mr.allocate(1_MiB), rmm::bad_alloc); + EXPECT_EQ(retried, true); + } + + // Wrap a out_of_memory-catching callback adaptor around an MR that always throws out_of_memory: + // Should retry once and then re-throw out_of_memory + { + bool retried{false}; + + failure_callback_adaptor oom_callback_mr{ + &oom_mr, failure_handler, &retried}; + EXPECT_EQ(retried, false); + EXPECT_THROW(oom_callback_mr.allocate(1_MiB), rmm::out_of_memory); + EXPECT_EQ(retried, true); + } + + // Wrap a out_of_memory-catching callback adaptor around an MR that always throws bad_alloc: + // Should not catch the bad_alloc exception + { + bool retried{false}; + + failure_callback_adaptor oom_callback_mr{ + &bad_alloc_mr, failure_handler, &retried}; + EXPECT_EQ(retried, false); + EXPECT_THROW(oom_callback_mr.allocate(1_MiB), rmm::bad_alloc); // bad_alloc passes through + EXPECT_EQ(retried, false); // Does not catch / retry on anything except OOM + } +} + } // namespace } // namespace rmm::test