Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameterize exception type caught by failure_callback_resource_adaptor #898

Merged
59 changes: 31 additions & 28 deletions include/rmm/mr/device/failure_callback_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,39 @@

#include <cstddef>
#include <functional>
#include <utility>

namespace rmm::mr {

/**
* @brief Callback function type used by failure_callback_resource_adaptor
*
* The resource adaptor calls this function when a memory allocation throws a
* `std::bad_alloc` exception. The function decides whether the resource adaptor
* should try to allocate the memory again or re-throw the `std::bad_alloc`
* exception.
* The resource adaptor calls this function when a memory allocation throws a specified exception
* type. The function decides whether the resource adaptor should try to allocate the memory again
* or re-throw the exception.
*
* The callback function signature is:
* `bool failure_callback_t(std::size_t bytes, void* callback_arg)`
* `bool failure_callback_t(std::size_t bytes, void* callback_arg)`
harrism marked this conversation as resolved.
Show resolved Hide resolved
*
* The callback function will be passed two parameters: `bytes` is the size of the
* failed memory allocation, and `arg` is the extra argument passed to the constructor
* of the `failure_callback_resource_adaptor`. The callback function returns a Boolean
* where true means to retry the memory allocation and false means to throw a
* `rmm::bad_alloc` exception.
* The callback function is passed two parameters: `bytes` is the size of the failed memory
harrism marked this conversation as resolved.
Show resolved Hide resolved
* allocation and `arg` is the extra argument passed to the constructor of the
* `failure_callback_resource_adaptor`. The callback function returns a Boolean where true means to
* retry the memory allocation and false means to re-throw the exception.
*/
using failure_callback_t = std::function<bool(std::size_t, void*)>;

/**
* @brief A device memory resource that calls a callback function when allocations
* throws `std::bad_alloc`.
* throw a specified exception type.
*
* An instance of this resource must be constructed with an existing, upstream
* resource in order to satisfy allocation requests.
*
* The callback function takes an allocation size and a callback argument and returns
* a bool representing whether to retry the allocation (true) or throw `std::bad_alloc`
* (false).
* The callback function takes an allocation size and a callback argument and returns
harrism marked this conversation as resolved.
Show resolved Hide resolved
* a bool representing whether to retry the allocation (true) or re-throw the caught exception (false).
harrism marked this conversation as resolved.
Show resolved Hide resolved
*
* When implementing a callback function for allocation retry, care must be taken to
* avoid an infinite loop. In the following example, we make sure to only retry the allocation
* once:
* When implementing a callback function for allocation retry, care must be taken to avoid an
* infinite loop. The following example makes sure to only retry the allocation once:
*
* @code{c++}
* using failure_callback_adaptor =
Expand All @@ -67,9 +64,8 @@ using failure_callback_t = std::function<bool(std::size_t, void*)>;
* if (!retried) {
* retried = true;
* return true; // First time we request an allocation retry
* } else {
* return false; // Second time we let the adaptor throw std::bad_alloc
* }
* return false; // Second time we let the adaptor throw std::bad_alloc
* }
*
* int main()
Expand All @@ -83,10 +79,13 @@ using failure_callback_t = std::function<bool(std::size_t, void*)>;
* @endcode
*
* @tparam Upstream The type of the upstream resource used for allocation/deallocation.
* @tparam ExceptionType The type of exception that this adaptor should respond to
*/
template <typename Upstream>
template <typename Upstream, typename ExceptionType = rmm::out_of_memory>
class failure_callback_resource_adaptor final : public device_memory_resource {
public:
using exception_type = ExceptionType; ///< The type of exception this object catches/throws

/**
* @brief Construct a new `failure_callback_resource_adaptor` using `upstream` to satisfy
* allocation requests.
Expand All @@ -100,7 +99,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
failure_callback_resource_adaptor(Upstream* upstream,
failure_callback_t callback,
void* callback_arg)
: upstream_{upstream}, callback_{callback}, callback_arg_{callback_arg}
: upstream_{upstream}, callback_{std::move(callback)}, callback_arg_{callback_arg}
{
RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer.");
}
Expand All @@ -126,14 +125,17 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @return true The upstream resource supports streams
* @return false The upstream resource does not support streams.
*/
bool supports_streams() const noexcept override { return upstream_->supports_streams(); }
[[nodiscard]] bool supports_streams() const noexcept override
{
return upstream_->supports_streams();
}

/**
* @brief Query whether the resource supports the get_mem_info API.
*
* @return bool true if the upstream resource supports get_mem_info, false otherwise.
*/
bool supports_get_mem_info() const noexcept override
[[nodiscard]] bool supports_get_mem_info() const noexcept override
{
return upstream_->supports_get_mem_info();
}
Expand All @@ -143,7 +145,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @brief Allocates memory of size at least `bytes` using the upstream
* resource.
*
* @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled
* @throws `exception_type` if the requested allocation could not be fulfilled
* by the upstream resource.
*
* @param bytes The size, in bytes, of the allocation
Expand All @@ -152,13 +154,13 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
*/
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
void* ret;
void* ret{};

while (true) {
try {
ret = upstream_->allocate(bytes, stream);
break;
} catch (std::bad_alloc const& e) {
} catch (exception_type const& e) {
if (!callback_(bytes, callback_arg_)) { throw; }
}
}
Expand Down Expand Up @@ -188,7 +190,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @return true If the two resources are equivalent
* @return false If the two resources are not equal
*/
bool do_is_equal(device_memory_resource const& other) const noexcept override
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
Expand All @@ -204,7 +206,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @param stream Stream on which to get the mem info.
* @return std::pair containing free_size and total_size of memory
*/
std::pair<std::size_t, std::size_t> do_get_mem_info(cuda_stream_view stream) const override
[[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
cuda_stream_view stream) const override
{
return upstream_->get_mem_info(stream);
}
Expand Down
77 changes: 71 additions & 6 deletions tests/mr/device/failure_callback_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

#include "../../byte_literals.hpp"
#include "rmm/cuda_stream_view.hpp"
#include "rmm/mr/device/device_memory_resource.hpp"

#include <cstddef>
#include <rmm/detail/error.hpp>
Expand All @@ -26,29 +28,92 @@
namespace rmm::test {
namespace {

template <typename ExceptionType = rmm::bad_alloc>
using failure_callback_adaptor =
rmm::mr::failure_callback_resource_adaptor<rmm::mr::device_memory_resource>;
rmm::mr::failure_callback_resource_adaptor<rmm::mr::device_memory_resource, ExceptionType>;

bool failure_handler(std::size_t bytes, void* arg)
bool failure_handler(std::size_t /*bytes*/, void* arg)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
bool& retried = *reinterpret_cast<bool*>(arg);
if (!retried) {
retried = true;
return true; // First time we request an allocation retry
} else {
return false; // Second time we let the adaptor throw std::bad_alloc
}
return false; // Second time we let the adaptor throw std::bad_alloc
}

TEST(FailureCallbackTest, RetryAllocationOnce)
{
bool retried{false};
failure_callback_adaptor mr{rmm::mr::get_current_device_resource(), failure_handler, &retried};
rmm::mr::set_current_device_resource(&mr);
failure_callback_adaptor<> mr{rmm::mr::get_current_device_resource(), failure_handler, &retried};
EXPECT_EQ(retried, false);
EXPECT_THROW(mr.allocate(512_GiB), std::bad_alloc);
EXPECT_EQ(retried, true);
}

/**
 * @brief Test-only memory resource whose every allocation attempt throws `ExceptionType`.
 *
 * Used to drive `failure_callback_resource_adaptor` with a deterministic failure instead of
 * relying on a real out-of-memory condition.
 *
 * @tparam ExceptionType Exception type thrown by `do_allocate`. It must be constructible from a
 * string literal, because it is constructed as `ExceptionType{"foo"}`.
 */
template <typename ExceptionType>
class always_throw_memory_resource final : public mr::device_memory_resource {
 private:
  /**
   * @brief Never allocates; unconditionally throws `ExceptionType{"foo"}`.
   *
   * Parameter names are commented out because the arguments are intentionally unused.
   */
  void* do_allocate(std::size_t /*bytes*/, cuda_stream_view /*stream*/) override
  {
    throw ExceptionType{"foo"};
  }

  /// No-op: `do_allocate` never hands out memory, so there is nothing to release.
  void do_deallocate(void* /*ptr*/, std::size_t /*bytes*/, cuda_stream_view /*stream*/) override {}

  /// Always reports `{0, 0}`.
  [[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
    cuda_stream_view /*stream*/) const override
  {
    return {0, 0};
  }

  /// This resource reports that it does not support streams.
  [[nodiscard]] bool supports_streams() const noexcept override { return false; }

  /// This resource reports that it does not support `get_mem_info`.
  [[nodiscard]] bool supports_get_mem_info() const noexcept override { return false; }
};

// Checks that the adaptor invokes its callback only for the exception type it was instantiated
// with, and that any other exception type escapes without the callback being consulted.
TEST(FailureCallbackTest, DifferentExceptionTypes)
{
  always_throw_memory_resource<rmm::bad_alloc> throws_bad_alloc;
  always_throw_memory_resource<rmm::out_of_memory> throws_oom;

  // Baseline: unwrapped, each upstream throws exactly the exception it was configured with.
  EXPECT_THROW(throws_bad_alloc.allocate(1_MiB), rmm::bad_alloc);
  EXPECT_THROW(throws_oom.allocate(1_MiB), rmm::out_of_memory);

  // Case 1: adaptor watching for bad_alloc, upstream throwing bad_alloc.
  // Expect a single retry followed by the bad_alloc being re-thrown.
  {
    bool did_retry{false};
    failure_callback_adaptor<rmm::bad_alloc> adaptor{
      &throws_bad_alloc, failure_handler, &did_retry};

    EXPECT_FALSE(did_retry);
    EXPECT_THROW(adaptor.allocate(1_MiB), rmm::bad_alloc);
    EXPECT_TRUE(did_retry);
  }

  // Case 2: adaptor watching for out_of_memory, upstream throwing out_of_memory.
  // Expect a single retry followed by the out_of_memory being re-thrown.
  {
    bool did_retry{false};
    failure_callback_adaptor<rmm::out_of_memory> adaptor{&throws_oom, failure_handler, &did_retry};

    EXPECT_FALSE(did_retry);
    EXPECT_THROW(adaptor.allocate(1_MiB), rmm::out_of_memory);
    EXPECT_TRUE(did_retry);
  }

  // Case 3: adaptor watching for out_of_memory, upstream throwing bad_alloc.
  // The bad_alloc must not be caught by the adaptor, so the callback never runs.
  {
    bool did_retry{false};
    failure_callback_adaptor<rmm::out_of_memory> adaptor{
      &throws_bad_alloc, failure_handler, &did_retry};

    EXPECT_FALSE(did_retry);
    EXPECT_THROW(adaptor.allocate(1_MiB), rmm::bad_alloc);  // bad_alloc passes straight through
    EXPECT_FALSE(did_retry);                                 // no callback, hence no retry
  }
}

} // namespace
} // namespace rmm::test