Skip to content

Commit

Permalink
Device, Host, Managed Accessor Types for mdspan (#776)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #776
  • Loading branch information
divyegala committed Sep 1, 2022
1 parent 57df37d commit c2e7e90
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 59 deletions.
130 changes: 90 additions & 40 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,15 @@ template <typename T>
inline constexpr bool is_mdspan_v = is_mdspan_t<T>::value;
} // namespace detail

template <typename...>
struct is_mdspan : std::true_type {
};
template <typename T1>
struct is_mdspan<T1> : detail::is_mdspan_t<T1> {
};
template <typename T1, typename... Tn>
struct is_mdspan<T1, Tn...>
: std::conditional_t<detail::is_mdspan_v<T1>, is_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if variadic template types Tn are either
* raft::host_mdspan/raft::device_mdspan or their derived types
*/
template <typename... Tn>
inline constexpr bool is_mdspan_v = is_mdspan<Tn...>::value;
inline constexpr bool is_mdspan_v = std::conjunction_v<detail::is_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_mdspan = std::enable_if_t<is_mdspan_v<Tn...>>;

/**
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location.
Expand All @@ -160,69 +152,83 @@ template <typename ElementType,
using host_mdspan =
mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;

template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using managed_mdspan =
mdspan<ElementType, Extents, LayoutPolicy, detail::managed_accessor<AccessorPolicy>>;

namespace detail {
template <typename T, bool B>
struct is_device_mdspan : std::false_type {
};
template <typename T>
struct is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
struct is_device_mdspan<T, true> : std::bool_constant<T::accessor_type::is_device_accessible> {
};

/**
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_device_mdspan_v = is_device_mdspan<T, is_mdspan_v<T>>::value;
using is_device_mdspan_t = is_device_mdspan<T, is_mdspan_v<T>>;

template <typename T, bool B>
struct is_host_mdspan : std::false_type {
};
template <typename T>
struct is_host_mdspan<T, true> : T::accessor_type::is_host_type {
struct is_host_mdspan<T, true> : std::bool_constant<T::accessor_type::is_host_accessible> {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_host_mdspan_v = is_host_mdspan<T, is_mdspan_v<T>>::value;
} // namespace detail
using is_host_mdspan_t = is_host_mdspan<T, is_mdspan_v<T>>;

template <typename...>
struct is_device_mdspan : std::true_type {
};
template <typename T1>
struct is_device_mdspan<T1> : detail::is_device_mdspan<T1, detail::is_mdspan_v<T1>> {
template <typename T, bool B>
struct is_managed_mdspan : std::false_type {
};
template <typename T1, typename... Tn>
struct is_device_mdspan<T1, Tn...>
: std::conditional_t<detail::is_device_mdspan_v<T1>, is_device_mdspan<Tn...>, std::false_type> {
template <typename T>
struct is_managed_mdspan<T, true> : std::bool_constant<T::accessor_type::is_managed_accessible> {
};

/**
* @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type
*/
template <typename T>
using is_managed_mdspan_t = is_managed_mdspan<T, is_mdspan_v<T>>;
} // namespace detail

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_device_mdspan_v = is_device_mdspan<Tn...>::value;
inline constexpr bool is_device_mdspan_v = std::conjunction_v<detail::is_device_mdspan_t<Tn>...>;

template <typename...>
struct is_host_mdspan : std::true_type {
};
template <typename T1>
struct is_host_mdspan<T1> : detail::is_host_mdspan<T1, detail::is_mdspan_v<T1>> {
};
template <typename T1, typename... Tn>
struct is_host_mdspan<T1, Tn...>
: std::conditional_t<detail::is_host_mdspan_v<T1>, is_host_mdspan<Tn...>, std::false_type> {
};
template <typename... Tn>
using enable_if_device_mdspan = std::enable_if_t<is_device_mdspan_v<Tn...>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_host_mdspan_v = is_host_mdspan<Tn...>::value;
inline constexpr bool is_host_mdspan_v = std::conjunction_v<detail::is_host_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_host_mdspan = std::enable_if_t<is_host_mdspan_v<Tn...>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_managed_mdspan_v = std::conjunction_v<detail::is_managed_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_managed_mdspan = std::enable_if_t<is_managed_mdspan_v<Tn...>>;

/**
* @brief Interface to implement an owning multi-dimensional array
Expand Down Expand Up @@ -348,7 +354,7 @@ class mdarray
typename container_policy_type::const_accessor_policy,
typename container_policy_type::accessor_policy>>
using view_type_impl =
std::conditional_t<container_policy_type::is_host_type::value,
std::conditional_t<container_policy_type::is_host_accessible,
host_mdspan<E, extents_type, layout_type, ViewAccessorPolicy>,
device_mdspan<E, extents_type, layout_type, ViewAccessorPolicy>>;

Expand Down Expand Up @@ -672,6 +678,50 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous>
using device_matrix_view = device_mdspan<ElementType, matrix_extent<IndexType>, LayoutPolicy>;

/**
* @brief Create a raft::mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @tparam is_host_accessible whether the data is accessible on host
* @tparam is_device_accessible whether the data is accessible on device
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
bool is_host_accessible = false,
bool is_device_accessible = true,
size_t... Extents>
auto make_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
using accessor_type = detail::accessor_mixin<std::experimental::default_accessor<ElementType>,
is_host_accessible,
is_device_accessible>;

return mdspan<ElementType, decltype(exts), LayoutPolicy, accessor_type>{ptr, exts};
}

/**
* @brief Create a raft::managed_mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::managed_mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_managed_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
return make_mdspan<ElementType, IndexType, LayoutPolicy, true, true>(ptr, exts);
}

/**
* @brief Create a 0-dim (scalar) mdspan instance for host value.
*
Expand Down Expand Up @@ -983,7 +1033,7 @@ auto make_device_vector(raft::handle_t const& handle, IndexType n)
* @return raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on AccessoryPolicy
*/
template <typename mdspan_type, std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
template <typename mdspan_type, typename = enable_if_mdspan<mdspan_type>>
auto flatten(mdspan_type mds)
{
RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous.");
Expand Down Expand Up @@ -1024,7 +1074,7 @@ auto flatten(const array_interface_type& mda)
template <typename mdspan_type,
typename IndexType = std::uint32_t,
size_t... Extents,
std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
typename = enable_if_mdspan<mdspan_type>>
auto reshape(mdspan_type mds, extents<IndexType, Extents...> new_shape)
{
RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous.");
Expand Down
18 changes: 13 additions & 5 deletions cpp/include/raft/detail/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,29 @@ class host_vector_policy {
/**
* @brief A mixin to distinguish host and device memory.
*/
template <typename AccessorPolicy, bool is_host>
template <typename AccessorPolicy, bool is_host, bool is_device>
struct accessor_mixin : public AccessorPolicy {
using accessor_type = AccessorPolicy;
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>;
using accessor_type = AccessorPolicy;
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>;
using is_device_type = std::conditional_t<is_device, std::true_type, std::false_type>;
using is_managed_type = std::conditional_t<is_device && is_host, std::true_type, std::false_type>;
static constexpr bool is_host_accessible = is_host;
static constexpr bool is_device_accessible = is_device;
static constexpr bool is_managed_accessible = is_device && is_host;
// make sure the explicit ctor can fall through
using AccessorPolicy::AccessorPolicy;
using offset_policy = accessor_mixin;
accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT
};

template <typename AccessorPolicy>
using host_accessor = accessor_mixin<AccessorPolicy, true>;
using host_accessor = accessor_mixin<AccessorPolicy, true, false>;

template <typename AccessorPolicy>
using device_accessor = accessor_mixin<AccessorPolicy, false>;
using device_accessor = accessor_mixin<AccessorPolicy, false, true>;

template <typename AccessorPolicy>
using managed_accessor = accessor_mixin<AccessorPolicy, true, true>;

namespace stdex = std::experimental;

Expand Down
12 changes: 12 additions & 0 deletions cpp/test/mdarray.cu
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,18 @@ void test_factory_methods()
auto view = make_host_scalar_view(h_scalar.data_handle());
ASSERT_EQ(view(0), 17.0);
}

// managed
{
raft::handle_t handle{};
auto mda = make_device_vector<int>(handle, 10);

auto mdv = make_managed_mdspan(mda.data_handle(), raft::vector_extent<int>{10});

static_assert(decltype(mdv)::accessor_type::is_managed_accessible, "Not managed mdspan");

ASSERT_EQ(mdv.size(), 10);
}
}
} // anonymous namespace

Expand Down
16 changes: 2 additions & 14 deletions cpp/test/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ void test_template_asserts()

// Checking if types are host_mdspan
static_assert(!is_host_mdspan_v<device_matrix_view<float>>,
"device_matrix_view type not a host_mdspan");
"device_matrix_view type is a host_mdspan");
static_assert(is_host_mdspan_v<host_matrix_view<float>>,
"host_matrix_view type is a host_mdspan");
"host_matrix_view type is not a host_mdspan");

// checking variadics
static_assert(!is_mdspan_v<three_d_mdspan, std::vector<int>>, "variadics mdspans");
Expand Down Expand Up @@ -171,12 +171,6 @@ void test_reshape()
three_d_mdarray mda{layout, policy};

auto flat_view = reshape(mda, raft::extents<int, dynamic_extent>{27});
// this confirms aliasing works as intended
static_assert(std::is_same_v<decltype(flat_view),
host_vector_view<typename decltype(flat_view)::element_type,
typename decltype(flat_view)::index_type,
typename decltype(flat_view)::layout_type>>,
"types not the same");

ASSERT_EQ(flat_view.extents().rank(), 1);
ASSERT_EQ(flat_view.size(), mda.size());
Expand All @@ -195,12 +189,6 @@ void test_reshape()
four_d_mdarray mda{layout, policy};

auto matrix = reshape(mda, raft::extents<int, dynamic_extent, dynamic_extent>{4, 4});
// this confirms aliasing works as intended
static_assert(std::is_same_v<decltype(matrix),
device_matrix_view<typename decltype(matrix)::element_type,
typename decltype(matrix)::index_type,
typename decltype(matrix)::layout_type>>,
"types not the same");

ASSERT_EQ(matrix.extents().rank(), 2);
ASSERT_EQ(matrix.extent(0), 4);
Expand Down

0 comments on commit c2e7e90

Please sign in to comment.