From 8f602e187b0634e1df13ba370352cf092e9042c0 Mon Sep 17 00:00:00 2001
From: Wenzel Jakob <wenzel.jakob@epfl.ch>
Date: Tue, 12 Sep 2023 00:12:55 +0200
Subject: [PATCH] ``ndarray``: fast views into existing arrays

CPU loops involving nanobind ndarrays weren't getting properly
vectorized. This commit adds *views*, which provide an efficient
abstraction that enables better code generation.
---
 docs/api_extra.rst         |  11 +++
 docs/ndarray.rst           | 153 +++++++++++++++++++++++++++++--------
 include/nanobind/ndarray.h | 130 +++++++++++++++++++++++++++----
 tests/test_ndarray.cpp     |  30 ++++++++
 tests/test_ndarray.py      |  23 ++++++
 5 files changed, 302 insertions(+), 45 deletions(-)

diff --git a/docs/api_extra.rst b/docs/api_extra.rst
index 44945094..339c7361 100644
--- a/docs/api_extra.rst
+++ b/docs/api_extra.rst
@@ -565,6 +565,17 @@
       Return a mutable pointer to the array data. Only enabled when `Scalar`
       is not itself ``const``.
 
+   .. cpp:function:: template <typename... Extra> auto view()
+
+      Returns an nd-array view that is optimized for fast array access on the
+      CPU. You may optionally specify additional ndarray constraints via the
+      `Extra` parameter (though a runtime check should first be performed to
+      ensure that the array possesses these properties).
+
+      The returned view provides the operations ``data()``, ``ndim()``,
+      ``shape()``, ``stride()``, and ``operator()`` following the conventions
+      of the `ndarray` type.
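+
+      For example, a bound function with compile-time scalar type and
+      dimension information might use the view as follows (a minimal sketch):
+
+      .. code-block:: cpp
+
+         void process(nb::ndarray<float, nb::ndim<1>, nb::device::cpu> arg) {
+             auto v = arg.view(); // scalar type and dimension known at compile time
+             for (size_t i = 0; i < v.shape(0); ++i)
+                 v(i) += 1.f;
+         }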
 
    .. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)
 
       Return a mutable reference to the element stored at the provided
       indices.
diff --git a/docs/ndarray.rst b/docs/ndarray.rst
index bb394de6..92ae3deb 100644
--- a/docs/ndarray.rst
+++ b/docs/ndarray.rst
@@ -83,13 +83,12 @@
 should therefore prevent such undefined behavior.
 
 :cpp:class:`nb::ndarray\<...\> <ndarray>` accepts template arguments to
 specify such constraints. For example the function interface below guarantees
 that the implementation is only invoked when it is provided with
-a ``MxNx3`` array of 8-bit unsigned integers that is furthermore stored
-contiguously in CPU memory using a C-style array ordering convention.
+a ``MxNx3`` array of 8-bit unsigned integers.
 
 .. code-block:: cpp
 
    m.def("process", [](nb::ndarray<uint8_t, nb::shape<nb::any, nb::any, 3>,
-                                   nb::c_contig, nb::device::cpu> data) {
+                                   nb::device::cpu> data) {
        // Double brightness of the MxNx3 RGB image
        for (size_t y = 0; y < data.shape(0); ++y)
            for (size_t x = 0; x < data.shape(1); ++x)
@@ -100,15 +99,16 @@ contiguously in CPU memory using a C-style array ordering convention.
 
 The above example also demonstrates the use of
 :cpp:func:`nb::ndarray\<...\>::operator() <ndarray::operator()>`, which
-provides direct (i.e., high-performance) read/write access to the array
-data. Note that this function is only available when the underlying data
-type and ndarray rank are specified. It should only be used when the
-array storage is reachable via CPU’s virtual memory address space.
+provides direct read/write access to the array contents. Note that this
+function is only available when the underlying data type and ndarray dimension
+are specified via the :cpp:type:`ndarray\<..\> <ndarray>` template parameters.
+It should only be used when the array storage is accessible through the CPU's
+virtual memory address space.
 
 .. _ndarray-constraints-1:
 
 Constraint types
-~~~~~~~~~~~~~~~~
+----------------
 
 The following constraints are available
@@ -153,6 +153,94 @@
 count until they go out of scope. It is legal to call them when the
 `GIL <https://wikipedia.org/wiki/Global_interpreter_lock>`__ is not held.
 
+.. _ndarray-views:
+
+Fast array views
+----------------
+
+The following advice applies to performance-sensitive CPU code that reads and
+writes arrays using loops that invoke :cpp:func:`nb::ndarray\<...\>::operator()
+<ndarray::operator()>`. It does not apply to GPU arrays because they are
+usually not accessed in this way.
+
+Consider the following snippet, which fills a 2D array with data:
+
+.. code-block:: cpp
+
+   void fill(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
+       for (size_t i = 0; i < arg.shape(0); ++i)
+           for (size_t j = 0; j < arg.shape(1); ++j)
+               arg(i, j) = /* ... */;
+   }
+
+While functional, this code is not perfect. The problem is that to compute the
+address of an entry, ``operator()`` accesses the DLPack array descriptor. This
+indirection can break certain compiler optimizations.
+
+nanobind provides the method :cpp:func:`ndarray\<...\>::view() <ndarray::view>`
+to fix this. It creates a tiny data structure that provides all information
+needed to access the array contents, and which can be held within CPU
+registers. All relevant compile-time information (:cpp:class:`nb::ndim <ndim>`,
+:cpp:class:`nb::shape <shape>`, :cpp:class:`nb::c_contig <c_contig>`,
+:cpp:class:`nb::f_contig <f_contig>`) is materialized in this view, which
+enables constant propagation, auto-vectorization, and loop unrolling.
+
+An improved version of the example using such a view is shown below:
+
+.. code-block:: cpp
+
+   void fill(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
+       auto v = arg.view(); // <-- new!
+
+       for (size_t i = 0; i < v.shape(0); ++i) // Important: use 'v' instead of 'arg' everywhere in the loop
+           for (size_t j = 0; j < v.shape(1); ++j)
+               v(i, j) = /* ... */;
+   }
+
+Note that the view performs no reference counting. You may not store it in a
+way that exceeds the lifetime of the original array.
+
+When using OpenMP to parallelize expensive array operations, pass a
+``firstprivate(view_1, view_2, ...)`` clause so that each worker thread can
+copy the view into its register file:
+
+.. code-block:: cpp
+
+   auto v = array.view();
+   #pragma omp parallel for schedule(static) firstprivate(v)
+   for (...) { /* parallel loop */ }
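+
+Put together, a parallel fill might look as follows (a sketch under the same
+assumptions as the example above; the name ``fill_parallel`` is ours):
+
+.. code-block:: cpp
+
+   void fill_parallel(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
+       auto v = arg.view();
+
+       // Each thread receives its own copy of the small view object.
+       // A signed loop variable keeps older OpenMP versions happy.
+       #pragma omp parallel for schedule(static) firstprivate(v)
+       for (int64_t i = 0; i < (int64_t) v.shape(0); ++i)
+           for (size_t j = 0; j < v.shape(1); ++j)
+               v(i, j) = /* ... */;
+   }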
+
+.. _ndarray-runtime-specialization:
+
+Specializing views at runtime
+-----------------------------
+
+As mentioned earlier, element access via ``operator()`` only works when both
+the array's scalar type and its dimension are specified within the type (i.e.,
+when they are known at compile time); the same is also true for array views.
+However, it is sometimes useful for a function to accept several different
+array types.
+
+You may use the :cpp:func:`ndarray\<...\>::view() <ndarray::view>` method to
+create *specialized* views if a run-time check determines that it is safe to
+do so. For example, the function below accepts contiguous CPU arrays and
+performs a loop over a specialized 2D ``float`` view when the array is of
+this type.
+
+.. code-block:: cpp
+
+   void fill(nb::ndarray<nb::c_contig, nb::device::cpu> arg) {
+       if (arg.dtype() == nb::dtype<float>() && arg.ndim() == 2) {
+           auto v = arg.view<float, nb::ndim<2>>(); // <-- new!
+
+           for (size_t i = 0; i < v.shape(0); ++i) {
+               for (size_t j = 0; j < v.shape(1); ++j) {
+                   v(i, j) = /* ... */;
+               }
+           }
+       } else { /* ... */ }
+   }
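+
+The same pattern extends to several scalar types by factoring the loop into a
+small template (a sketch; the helper name ``fill_impl`` is ours):
+
+.. code-block:: cpp
+
+   template <typename T>
+   void fill_impl(nb::ndarray<nb::c_contig, nb::device::cpu> &arg) {
+       auto v = arg.view<T, nb::ndim<2>>();
+       for (size_t i = 0; i < v.shape(0); ++i)
+           for (size_t j = 0; j < v.shape(1); ++j)
+               v(i, j) = /* ... */;
+   }
+
+   void fill(nb::ndarray<nb::c_contig, nb::device::cpu> arg) {
+       if (arg.ndim() == 2 && arg.dtype() == nb::dtype<float>())
+           fill_impl<float>(arg);
+       else if (arg.ndim() == 2 && arg.dtype() == nb::dtype<double>())
+           fill_impl<double>(arg);
+       else { /* ... */ }
+   }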
 
 Constraints in type signatures
 ------------------------------
 
@@ -364,6 +452,30 @@ interpreted as follows:
 - :cpp:enumerator:`rv_policy::move` is unsupported and demoted to
   :cpp:enumerator:`rv_policy::copy`.
 
+.. _ndarray_nonstandard_arithmetic:
+
+Nonstandard arithmetic types
+----------------------------
+
+Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
+``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
+you wish to exchange arrays based on such types, you must register a template
+specialization of ``nanobind::ndarray_traits`` to inform nanobind about them.
+
+For example, the following snippet makes ``__fp16`` (the half-precision type
+on ``aarch64``) available:
+
+.. code-block:: cpp
+
+   namespace nanobind {
+       template <> struct ndarray_traits<__fp16> {
+           static constexpr bool is_float = true;
+           static constexpr bool is_bool = false;
+           static constexpr bool is_int = false;
+           static constexpr bool is_signed = true;
+       };
+   }
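+
+Once the trait is registered, such arrays can be received and returned like
+any other array type. As an illustration, the following sketch (modeled on
+this patch's test suite; the binding name ``ret_half`` is ours) returns a
+half-precision array to Python:
+
+.. code-block:: cpp
+
+   #if defined(__aarch64__)
+   m.def("ret_half", []() {
+       __fp16 *data = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
+
+       // Free the data once the array is destroyed on the Python side
+       nb::capsule owner(data, [](void *p) noexcept { delete[] (__fp16 *) p; });
+
+       return nb::ndarray<nb::numpy, __fp16, nb::shape<2, 4>>(data, { 2, 4 }, owner);
+   });
+   #endif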
'float') to the ndarray template parameters."); - static_assert( - !std::is_same_v, - "To use nb::ndarray::operator(), you must add a nb::shape<> " - "annotation to the ndarray template parameters."); - static_assert(sizeof...(Ts) == Info::shape_type::size, - "nb::ndarray::operator(): invalid number of arguments"); - size_t counter = 0; - int64_t index = 0; - ((index += int64_t(indices) * m_dltensor.strides[counter++]), ...); - - return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type); + + static_assert(has_shape, + "To use ndarray::operator(), you must add a shape<> or " + "ndim<> annotation to the ndarray template parameters."); + + if constexpr (has_scalar && has_shape) { + static_assert(sizeof...(Ts) == Info::shape_type::size, + "ndarray::operator(): invalid number of arguments"); + + size_t counter = 0; + int64_t index = 0; + ((index += int64_t(indices) * m_dltensor.strides[counter++]), ...); + + return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type); + } else { + return 0; + } } detail::ndarray_handle *m_handle = nullptr; diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp index 34112e12..092f6401 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -231,6 +231,36 @@ NB_MODULE(test_ndarray_ext, m) { .def("f2_ri", &Cls::f2, nb::rv_policy::reference_internal) .def("f3_ri", &Cls::f3, nb::rv_policy::reference_internal); + m.def("fill_view_1", [](nb::ndarray<> x) { + if (x.ndim() == 2 && x.dtype() == nb::dtype()) { + auto v = x.view>(); + for (size_t i = 0; i < v.shape(0); i++) + for (size_t j = 0; j < v.shape(1); j++) + v(i, j) *= 2; + } + }, "x"_a.noconvert()); + + m.def("fill_view_2", [](nb::ndarray, nb::device::cpu> x) { + auto v = x.view(); + for (size_t i = 0; i < v.shape(0); ++i) + for (size_t j = 0; j < v.shape(1); ++j) + v(i, j) = (float) (i * 10 + j); + }, "x"_a.noconvert()); + + m.def("fill_view_3", [](nb::ndarray, nb::c_contig, nb::device::cpu> x) { + auto v = x.view(); + for (size_t i = 0; i < v.shape(0); ++i) + for (size_t j = 0; j < v.shape(1); ++j) + v(i, j) = (float) (i * 10 + j); + }, "x"_a.noconvert()); + + m.def("fill_view_4", [](nb::ndarray, nb::f_contig, nb::device::cpu> x) { + auto v = x.view(); + for (size_t i = 0; i < v.shape(0); ++i) + for (size_t j = 0; j < v.shape(1); ++j) + v(i, j) = (float) (i * 10 + j); + }, "x"_a.noconvert()); + #if defined(__aarch64__) m.def("ret_numpy_half", []() { __fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 488e67a1..2554787e 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -562,6 +562,29 @@ def test30_force_contig_pytorch(): assert b is not a assert torch.all(b == a) +@needs_numpy +def test31_view(): + # 1 + x1 = np.array([[1,2],[3,4]], dtype=np.float32) + x2 = np.array([[1,2],[3,4]], dtype=np.float64) + assert np.allclose(x1, x2) + t.fill_view_1(x1) + assert np.allclose(x1, x2*2) + t.fill_view_1(x2) + assert np.allclose(x1, x2*2) + + #2 + x1 = np.zeros((3, 4), dtype=np.float32, order='C') + x2 = np.zeros((3, 4), dtype=np.float32, order='F') + t.fill_view_2(x1) + t.fill_view_2(x2) + x3 = np.zeros((3, 4), dtype=np.float32, order='C') + t.fill_view_3(x3) + x4 = np.zeros((3, 4), dtype=np.float32, order='F') + t.fill_view_4(x4) + + assert np.all(x1 == x2) and np.all(x2 == x3) and np.all(x3 == x4) + @needs_numpy def test32_half(): if not hasattr(t, 'ret_numpy_half'):