Skip to content

Commit

Permalink
ndarray: support for custom arithmetic types
Browse files Browse the repository at this point in the history
This commit makes it possible to exchange tensors based on arithmetic
types like ``__fp16``, ``__int128``, etc., that are nonstandard in the
sense that they aren't part of core C++.

Note that this commit does *not* add support for custom dtypes--this is
more limited in scope.
  • Loading branch information
wjakob committed Sep 11, 2023
1 parent 1350a5e commit 49eab28
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 9 deletions.
22 changes: 22 additions & 0 deletions docs/ndarray.rst
Expand Up @@ -388,3 +388,25 @@ integers, floating point values, and boolean values.
Nanobind can receive and return read-only arrays via the buffer protocol used
to exchange data with NumPy. The DLPack interface currently ignores this
annotation.

Supporting nonstandard arithmetic types
---------------------------------------

Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
you wish to exchange arrays based on such types, you must register a partial
overload of ``nanobind::ndarray_traits`` to inform nanobind about it.

For example, the following snippet makes ``__fp16`` (half-precision type on
``aarch64``) available:

.. code-block:: cpp
namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
};
27 changes: 18 additions & 9 deletions include/nanobind/ndarray.h
Expand Up @@ -80,10 +80,19 @@ struct pytorch { };
struct jax { };
struct ro { };

template <typename T> struct ndarray_traits {
static constexpr bool is_float = std::is_floating_point_v<T>;
static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
static constexpr bool is_int = std::is_integral_v<T> && !is_bool;
static constexpr bool is_signed = std::is_signed_v<T>;
};

NAMESPACE_BEGIN(detail)

template<typename T> constexpr bool is_ndarray_scalar_v =
std::is_floating_point_v<T> || std::is_integral_v<T>;
template <typename T>
constexpr bool is_ndarray_scalar_v =
ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool;

template <typename> struct ndim_shape;
template <size_t... S> struct ndim_shape<std::index_sequence<S...>> {
Expand All @@ -102,9 +111,9 @@ template <typename T> constexpr dlpack::dtype dtype() {

dlpack::dtype result;

if constexpr (std::is_floating_point_v<T>)
if constexpr (ndarray_traits<T>::is_float)
result.code = (uint8_t) dlpack::dtype_code::Float;
else if constexpr (std::is_signed_v<T>)
else if constexpr (ndarray_traits<T>::is_signed)
result.code = (uint8_t) dlpack::dtype_code::Int;
else if constexpr (std::is_same_v<std::remove_cv_t<T>, bool>)
result.code = (uint8_t) dlpack::dtype_code::Bool;
Expand Down Expand Up @@ -139,7 +148,7 @@ template <typename T, typename = int> struct ndarray_arg {
static void apply(ndarray_req &) { }
};

template <typename T> struct ndarray_arg<T, enable_if_t<std::is_floating_point_v<T>>> {
template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_float>> {
static constexpr size_t size = 0;

static constexpr auto name =
Expand All @@ -154,7 +163,7 @@ template <typename T> struct ndarray_arg<T, enable_if_t<std::is_floating_point_v
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<std::is_integral_v<T> && !std::is_same_v<std::remove_cv_t<T>, bool>>> {
template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_int>> {
static constexpr size_t size = 0;

static constexpr auto name =
Expand All @@ -170,7 +179,7 @@ template <typename T> struct ndarray_arg<T, enable_if_t<std::is_integral_v<T> &&
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<std::is_same_v<std::remove_cv_t<T>, bool>>> {
template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_bool>> {
static constexpr size_t size = 0;

static constexpr auto name =
Expand Down Expand Up @@ -242,8 +251,8 @@ template <typename... Ts> struct ndarray_info {

template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
using scalar_type =
std::conditional_t<std::is_scalar_v<T>, T,
typename ndarray_info<Ts...>::scalar_type>;
std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool, T, typename ndarray_info<Ts...>::scalar_type>;
};

template <size_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...> : ndarray_info<Ts...> {
Expand Down
27 changes: 27 additions & 0 deletions tests/test_ndarray.cpp
Expand Up @@ -9,6 +9,18 @@ using namespace nb::literals;

int destruct_count = 0;
static float f_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };
static int i_global[] { 1, 2, 3, 4, 5, 6, 7, 8 };

#if defined(__aarch64__)
namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
};
#endif

NB_MODULE(test_ndarray_ext, m) {
m.def("get_shape", [](const nb::ndarray<nb::ro> &t) {
Expand Down Expand Up @@ -218,4 +230,19 @@ NB_MODULE(test_ndarray_ext, m) {
.def("f1_ri", &Cls::f1, nb::rv_policy::reference_internal)
.def("f2_ri", &Cls::f2, nb::rv_policy::reference_internal)
.def("f3_ri", &Cls::f3, nb::rv_policy::reference_internal);

#if defined(__aarch64__)
m.def("ret_numpy_half", []() {
__fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
size_t shape[2] = { 2, 4 };

nb::capsule deleter(f, [](void *data) noexcept {
destruct_count++;
delete[] (__fp16*) data;
});

return nb::ndarray<nb::numpy, __fp16, nb::shape<2, 4>>(f, 2, shape,
deleter);
});
#endif
}
10 changes: 10 additions & 0 deletions tests/test_ndarray.py
Expand Up @@ -561,3 +561,13 @@ def test30_force_contig_pytorch():
b = t.make_contig(a)
assert b is not a
assert torch.all(b == a)

@needs_numpy
def test32_half():
if not hasattr(t, 'ret_numpy_half'):
pytest.skip('half precision test is missing')
x = t.ret_numpy_half()
assert x.dtype == np.float16
assert x.shape == (2, 4)
assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])

0 comments on commit 49eab28

Please sign in to comment.