Skip to content

Commit

Permalink
Parameterize exception type caught by failure_callback_resource_adapt…
Browse files Browse the repository at this point in the history
…or (#898)

#892 added `failure_callback_resource_adaptor` which provides the ability to respond to memory allocation failures. However, it was hard-coded to catch (and rethrow) `std::bad_alloc` exceptions.  This PR makes the type of exception the adaptor catches a template parameter, to provide greater flexibility. The default exception type is now `rmm::out_of_memory` since we expect this to be the common use case.

Also a few changes to fix clang-tidy warnings.

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Rong Ou (https://github.com/rongou)
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Jake Hemstad (https://github.com/jrhemstad)

URL: #898
  • Loading branch information
harrism committed Nov 9, 2021
1 parent 76ae622 commit 728a117
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 31 deletions.
54 changes: 29 additions & 25 deletions include/rmm/mr/device/failure_callback_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,40 @@

#include <cstddef>
#include <functional>
#include <utility>

namespace rmm::mr {

/**
* @brief Callback function type used by failure_callback_resource_adaptor
*
* The resource adaptor calls this function when a memory allocation throws a
* `std::bad_alloc` exception. The function decides whether the resource adaptor
* should try to allocate the memory again or re-throw the `std::bad_alloc`
* exception.
* The resource adaptor calls this function when a memory allocation throws a specified exception
* type. The function decides whether the resource adaptor should try to allocate the memory again
* or re-throw the exception.
*
* The callback function signature is:
* `bool failure_callback_t(std::size_t bytes, void* callback_arg)`
*
* The callback function will be passed two parameters: `bytes` is the size of the
* failed memory allocation, and `arg` is the extra argument passed to the constructor
* of the `failure_callback_resource_adaptor`. The callback function returns a Boolean
* where true means to retry the memory allocation and false means to throw a
* `rmm::bad_alloc` exception.
* The callback function is passed two parameters: `bytes` is the size of the failed memory
* allocation and `arg` is the extra argument passed to the constructor of the
* `failure_callback_resource_adaptor`. The callback function returns a Boolean where true means to
* retry the memory allocation and false means to re-throw the exception.
*/
using failure_callback_t = std::function<bool(std::size_t, void*)>;

/**
* @brief A device memory resource that calls a callback function when allocations
* throws `std::bad_alloc`.
* throw a specified exception type.
*
* An instance of this resource must be constructed with an existing, upstream
* resource in order to satisfy allocation requests.
*
* The callback function takes an allocation size and a callback argument and returns
* a bool representing whether to retry the allocation (true) or throw `std::bad_alloc`
* a bool representing whether to retry the allocation (true) or re-throw the caught exception
* (false).
*
* When implementing a callback function for allocation retry, care must be taken to
* avoid an infinite loop. In the following example, we make sure to only retry the allocation
* once:
* When implementing a callback function for allocation retry, care must be taken to avoid an
* infinite loop. The following example makes sure to only retry the allocation once:
*
* @code{c++}
* using failure_callback_adaptor =
Expand All @@ -67,9 +65,8 @@ using failure_callback_t = std::function<bool(std::size_t, void*)>;
* if (!retried) {
* retried = true;
* return true; // First time we request an allocation retry
* } else {
* return false; // Second time we let the adaptor throw std::bad_alloc
* }
* return false; // Second time we let the adaptor re-throw the caught exception
* }
*
* int main()
Expand All @@ -83,10 +80,13 @@ using failure_callback_t = std::function<bool(std::size_t, void*)>;
* @endcode
*
* @tparam Upstream The type of the upstream resource used for allocation/deallocation.
* @tparam ExceptionType The type of exception that this adaptor should respond to
*/
template <typename Upstream>
template <typename Upstream, typename ExceptionType = rmm::out_of_memory>
class failure_callback_resource_adaptor final : public device_memory_resource {
public:
using exception_type = ExceptionType; ///< The type of exception this object catches/throws

/**
* @brief Construct a new `failure_callback_resource_adaptor` using `upstream` to satisfy
* allocation requests.
Expand All @@ -100,7 +100,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
failure_callback_resource_adaptor(Upstream* upstream,
failure_callback_t callback,
void* callback_arg)
: upstream_{upstream}, callback_{callback}, callback_arg_{callback_arg}
: upstream_{upstream}, callback_{std::move(callback)}, callback_arg_{callback_arg}
{
RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer.");
}
Expand All @@ -126,14 +126,17 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @return true The upstream resource supports streams
* @return false The upstream resource does not support streams.
*/
bool supports_streams() const noexcept override { return upstream_->supports_streams(); }
[[nodiscard]] bool supports_streams() const noexcept override
{
return upstream_->supports_streams();
}

/**
* @brief Query whether the resource supports the get_mem_info API.
*
* @return bool true if the upstream resource supports get_mem_info, false otherwise.
*/
bool supports_get_mem_info() const noexcept override
[[nodiscard]] bool supports_get_mem_info() const noexcept override
{
return upstream_->supports_get_mem_info();
}
Expand All @@ -143,7 +146,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @brief Allocates memory of size at least `bytes` using the upstream
* resource.
*
* @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled
* @throws `exception_type` if the requested allocation could not be fulfilled
* by the upstream resource.
*
* @param bytes The size, in bytes, of the allocation
Expand All @@ -152,13 +155,13 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
*/
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
void* ret;
void* ret{};

while (true) {
try {
ret = upstream_->allocate(bytes, stream);
break;
} catch (std::bad_alloc const& e) {
} catch (exception_type const& e) {
if (!callback_(bytes, callback_arg_)) { throw; }
}
}
Expand Down Expand Up @@ -188,7 +191,7 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @return true If the two resources are equivalent
* @return false If the two resources are not equal
*/
bool do_is_equal(device_memory_resource const& other) const noexcept override
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
Expand All @@ -204,7 +207,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
* @param stream Stream on which to get the mem info.
* @return std::pair containing free_size and total_size of memory
*/
std::pair<std::size_t, std::size_t> do_get_mem_info(cuda_stream_view stream) const override
[[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
cuda_stream_view stream) const override
{
return upstream_->get_mem_info(stream);
}
Expand Down
77 changes: 71 additions & 6 deletions tests/mr/device/failure_callback_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

#include "../../byte_literals.hpp"
#include "rmm/cuda_stream_view.hpp"
#include "rmm/mr/device/device_memory_resource.hpp"

#include <cstddef>
#include <rmm/detail/error.hpp>
Expand All @@ -26,29 +28,92 @@
namespace rmm::test {
namespace {

template <typename ExceptionType = rmm::bad_alloc>
using failure_callback_adaptor =
rmm::mr::failure_callback_resource_adaptor<rmm::mr::device_memory_resource>;
rmm::mr::failure_callback_resource_adaptor<rmm::mr::device_memory_resource, ExceptionType>;

bool failure_handler(std::size_t bytes, void* arg)
bool failure_handler(std::size_t /*bytes*/, void* arg)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
bool& retried = *reinterpret_cast<bool*>(arg);
if (!retried) {
retried = true;
return true; // First time we request an allocation retry
} else {
return false; // Second time we let the adaptor throw std::bad_alloc
}
return false; // Second time we let the adaptor throw std::bad_alloc
}

TEST(FailureCallbackTest, RetryAllocationOnce)
{
bool retried{false};
failure_callback_adaptor mr{rmm::mr::get_current_device_resource(), failure_handler, &retried};
rmm::mr::set_current_device_resource(&mr);
failure_callback_adaptor<> mr{rmm::mr::get_current_device_resource(), failure_handler, &retried};
EXPECT_EQ(retried, false);
EXPECT_THROW(mr.allocate(512_GiB), std::bad_alloc);
EXPECT_EQ(retried, true);
}

/**
 * @brief Test-only device memory resource whose allocation always fails.
 *
 * `do_allocate` unconditionally throws `ExceptionType`, which lets tests exercise how a wrapping
 * resource adaptor reacts to a specific exception type without needing a real out-of-memory
 * condition.
 *
 * @tparam ExceptionType The exception type thrown by `do_allocate`. It must be constructible
 * from a string literal.
 */
template <typename ExceptionType>
class always_throw_memory_resource final : public mr::device_memory_resource {
 private:
  /**
   * @brief Always throws `ExceptionType`; never returns a pointer.
   *
   * The parameters are intentionally unnamed because they are ignored.
   *
   * @throws ExceptionType on every call.
   */
  void* do_allocate(std::size_t /*bytes*/, cuda_stream_view /*stream*/) override
  {
    throw ExceptionType{"foo"};
  }

  /// Intentionally a no-op: `do_allocate` never hands out memory.
  void do_deallocate(void* /*ptr*/, std::size_t /*bytes*/, cuda_stream_view /*stream*/) override {}

  /// Reports `{0, 0}` as the (free, total) memory pair.
  [[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
    cuda_stream_view /*stream*/) const override
  {
    return {0, 0};
  }

  /// This stub reports no stream support.
  [[nodiscard]] bool supports_streams() const noexcept override { return false; }

  /// This stub reports no `get_mem_info` support.
  [[nodiscard]] bool supports_get_mem_info() const noexcept override { return false; }
};

// Verifies that a failure_callback_adaptor only invokes its callback (and retries) for the
// exception type it was instantiated with, and lets any other exception type pass through
// without calling the callback.
TEST(FailureCallbackTest, DifferentExceptionTypes)
{
always_throw_memory_resource<rmm::bad_alloc> bad_alloc_mr;
always_throw_memory_resource<rmm::out_of_memory> oom_mr;

// Sanity check: used directly, each stub resource throws its own exception type.
EXPECT_THROW(bad_alloc_mr.allocate(1_MiB), rmm::bad_alloc);
EXPECT_THROW(oom_mr.allocate(1_MiB), rmm::out_of_memory);

// Wrap a bad_alloc-catching callback adaptor around an MR that always throws bad_alloc:
// Should retry once and then re-throw bad_alloc
{
bool retried{false};
failure_callback_adaptor<rmm::bad_alloc> bad_alloc_callback_mr{
&bad_alloc_mr, failure_handler, &retried};

EXPECT_EQ(retried, false);
EXPECT_THROW(bad_alloc_callback_mr.allocate(1_MiB), rmm::bad_alloc);
EXPECT_EQ(retried, true);
}

// Wrap an out_of_memory-catching callback adaptor around an MR that always throws out_of_memory:
// Should retry once and then re-throw out_of_memory
{
bool retried{false};

failure_callback_adaptor<rmm::out_of_memory> oom_callback_mr{
&oom_mr, failure_handler, &retried};
EXPECT_EQ(retried, false);
EXPECT_THROW(oom_callback_mr.allocate(1_MiB), rmm::out_of_memory);
EXPECT_EQ(retried, true);
}

// Wrap an out_of_memory-catching callback adaptor around an MR that always throws bad_alloc:
// Should not catch the bad_alloc exception
{
bool retried{false};

failure_callback_adaptor<rmm::out_of_memory> oom_callback_mr{
&bad_alloc_mr, failure_handler, &retried};
EXPECT_EQ(retried, false);
EXPECT_THROW(oom_callback_mr.allocate(1_MiB), rmm::bad_alloc); // bad_alloc passes through
EXPECT_EQ(retried, false); // Does not catch / retry on anything except OOM
}
}

} // namespace
} // namespace rmm::test

0 comments on commit 728a117

Please sign in to comment.