ndarray: fast views into existing arrays
CPU loops involving nanobind ndarrays weren't getting properly
vectorized. This commit adds *views*, which provide an efficient
abstraction that enables better code generation.
wjakob committed Sep 11, 2023
1 parent a1ac207 commit 8f602e1
Showing 5 changed files with 302 additions and 45 deletions.
11 changes: 11 additions & 0 deletions docs/api_extra.rst
@@ -565,6 +565,17 @@ section <ndarrays>`.
Return a mutable pointer to the array data. Only enabled when `Scalar` is
not itself ``const``.

.. cpp:function:: template <typename... Extra> auto view()

Returns an nd-array view that is optimized for fast array access on the
CPU. You may optionally specify additional ndarray constraints via the
`Extra` parameter (though a runtime check should first be performed to
ensure that the array possesses these properties).

The returned view provides the operations ``data()``, ``ndim()``,
``shape()``, ``stride()``, and ``operator()`` following the conventions
of the `ndarray` type.
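
   For illustration, a minimal usage sketch inside a bound function (the
   ``scale`` function and its constraints are hypothetical; see the ndarray
   documentation section on fast array views for complete examples):

   .. code-block:: cpp

      void scale(nb::ndarray<float, nb::ndim<1>, nb::c_contig, nb::device::cpu> arg) {
          // The view inherits the scalar type and constraints declared on 'arg'
          auto v = arg.view();
          for (size_t i = 0; i < v.shape(0); ++i)
              v(i) *= 2.f;
      }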

.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)

Return a mutable reference to the element stored at the provided
153 changes: 122 additions & 31 deletions docs/ndarray.rst
@@ -83,13 +83,12 @@ should therefore prevent such undefined behavior.
:cpp:class:`nb::ndarray\<...\> <ndarray>` accepts template arguments to
specify such constraints. For example the function interface below
guarantees that the implementation is only invoked when it is provided with
a ``MxNx3`` array of 8-bit unsigned integers that is furthermore stored
contiguously in CPU memory using a C-style array ordering convention.
a ``MxNx3`` array of 8-bit unsigned integers.

.. code-block:: cpp
m.def("process", [](nb::ndarray<uint8_t, nb::shape<nb::any, nb::any, 3>,
nb::c_contig, nb::device::cpu> data) {
nb::device::cpu> data) {
// Double brightness of the MxNx3 RGB image
for (size_t y = 0; y < data.shape(0); ++y)
for (size_t x = 0; x < data.shape(1); ++x)
@@ -100,15 +99,16 @@ contiguously in CPU memory using a C-style array ordering convention.
The above example also demonstrates the use of
:cpp:func:`nb::ndarray\<...\>::operator() <ndarray::operator()>`, which
provides direct (i.e., high-performance) read/write access to the array
data. Note that this function is only available when the underlying data
type and ndarray rank are specified. It should only be used when the
array storage is reachable via CPU’s virtual memory address space.
provides direct read/write access to the array contents. Note that this
function is only available when the underlying data type and ndarray dimension
are specified via the :cpp:type:`ndarray\<..\> <ndarray>` template parameters.
It should only be used when the array storage is accessible through the CPU's
virtual memory address space.
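
For instance (an illustrative sketch), the first declaration below supports
``operator()``, while the second does not, because it omits the scalar type
and dimension:

.. code-block:: cpp

   nb::ndarray<float, nb::ndim<2>, nb::device::cpu> a; // a(i, j) is available
   nb::ndarray<> b;                                     // b(i, j) will not compile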

.. _ndarray-constraints-1:

Constraint types
~~~~~~~~~~~~~~~~
----------------

The following constraints are available

@@ -153,6 +153,94 @@ count until they go out of scope. It is legal call
when the `GIL <https://wiki.python.org/moin/GlobalInterpreterLock>`__ is not
held.

.. _ndarray-views:

Fast array views
----------------

The following advice applies to performance-sensitive CPU code that reads and
writes arrays using loops that invoke :cpp:func:`nb::ndarray\<...\>::operator()
<ndarray::operator()>`. It does not apply to GPU arrays because they are
usually not accessed in this way.

Consider the following snippet, which fills a 2D array with data:

.. code-block:: cpp

   void fill(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
       for (size_t i = 0; i < arg.shape(0); ++i)
           for (size_t j = 0; j < arg.shape(1); ++j)
               arg(i, j) = /* ... */;
   }

While functional, this code is not perfect. The problem is that to compute the
address of an entry, ``operator()`` accesses the DLPack array descriptor. This
indirection can break certain compiler optimizations.

nanobind provides the method :cpp:func:`ndarray\<...\>::view() <ndarray::view>`
to fix this. It creates a tiny data structure that provides all information
needed to access the array contents, and which can be held within CPU
registers. All relevant compile-time information (:cpp:class:`nb::ndim <ndim>`,
:cpp:class:`nb::shape <shape>`, :cpp:class:`nb::c_contig <c_contig>`,
:cpp:class:`nb::f_contig <f_contig>`) is materialized in this view, which
enables constant propagation, auto-vectorization, and loop unrolling.

An improved version of the example using such a view is shown below:

.. code-block:: cpp

   void fill(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
       auto v = arg.view(); // <-- new!
       for (size_t i = 0; i < v.shape(0); ++i)       // Important: use 'v' rather than 'arg'
           for (size_t j = 0; j < v.shape(1); ++j)   // everywhere inside the loop
               v(i, j) = /* ... */;
   }

Note that the view performs no reference counting. You may not store it in a way
that exceeds the lifetime of the original array.
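
For example, a view must not escape the function that received the array; the
following hypothetical sketch is a bug:

.. code-block:: cpp

   auto make_view(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
       // Bug: the view does not keep the underlying Python array alive, so it
       // dangles once the last reference to the array is released
       return arg.view();
   }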

When using OpenMP to parallelize expensive array operations, add a
``firstprivate(view_1, view_2, ...)`` clause so that each worker thread can
copy the view into its register file:

.. code-block:: cpp

   auto v = array.view();

   #pragma omp parallel for schedule(static) firstprivate(v)
   for (...) { /* parallel loop */ }

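A more complete sketch of this pattern (the ``scale`` function below is
hypothetical) could look as follows:

.. code-block:: cpp

   void scale(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg,
              float factor) {
       auto v = arg.view();

       // Every OpenMP worker thread receives its own copy of the small view object
       #pragma omp parallel for schedule(static) firstprivate(v)
       for (int64_t i = 0; i < (int64_t) v.shape(0); ++i)
           for (size_t j = 0; j < v.shape(1); ++j)
               v(i, j) *= factor;
   }
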
.. _ndarray-runtime-specialization:

Specializing views at runtime
-----------------------------

As mentioned earlier, element access via ``operator()`` only works when both
the array's scalar type and its dimension are specified within the type (i.e.,
when they are known at compile time); the same is also true for array views.
However, it is sometimes useful to expose a single function that can be called
with several different array types.

You may use the :cpp:func:`ndarray\<...\>::view() <ndarray::view>` method to
create *specialized* views if a run-time check determines that it is safe to
do so. For example, the function below accepts contiguous CPU arrays and
performs a loop over a specialized 2D ``float`` view when the array is of
this type.

.. code-block:: cpp

   void fill(nb::ndarray<nb::c_contig, nb::device::cpu> arg) {
       if (arg.dtype() == nb::dtype<float>() && arg.ndim() == 2) {
           auto v = arg.view<float, nb::ndim<2>>(); // <-- new!

           for (size_t i = 0; i < v.shape(0); ++i) {
               for (size_t j = 0; j < v.shape(1); ++j) {
                   v(i, j) = /* ... */;
               }
           }
       } else { /* ... */ }
   }

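If several element types must be supported, the run-time check can simply
dispatch to multiple specialized views. A hypothetical sketch (the
``fill_impl`` helper is illustrative):

.. code-block:: cpp

   template <typename Scalar>
   void fill_impl(nb::ndarray<nb::c_contig, nb::device::cpu> &arg) {
       auto v = arg.view<Scalar, nb::ndim<2>>();
       for (size_t i = 0; i < v.shape(0); ++i)
           for (size_t j = 0; j < v.shape(1); ++j)
               v(i, j) = Scalar(0);
   }

   void fill(nb::ndarray<nb::c_contig, nb::device::cpu> arg) {
       if (arg.ndim() != 2)
           throw std::invalid_argument("fill(): expected a 2D array!");

       if (arg.dtype() == nb::dtype<float>())
           fill_impl<float>(arg);
       else if (arg.dtype() == nb::dtype<double>())
           fill_impl<double>(arg);
       else
           throw std::invalid_argument("fill(): unsupported dtype!");
   }
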
Constraints in type signatures
------------------------------

@@ -364,6 +452,30 @@ interpreted as follows:
- :cpp:enumerator:`rv_policy::move` is unsupported and demoted to
:cpp:enumerator:`rv_policy::copy`.

.. _ndarray_nonstandard_arithmetic:

Nonstandard arithmetic types
----------------------------

Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
you wish to exchange arrays based on such types, you must register a template
specialization of ``nanobind::ndarray_traits`` to inform nanobind about them.

For example, the following snippet makes ``__fp16`` (half-precision type on
``aarch64``) available:

.. code-block:: cpp
namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
};
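
With this trait in place, ``__fp16`` can be used like any other scalar type in
ndarray bindings. A hypothetical sketch (the ``double_it`` binding is purely
illustrative):

.. code-block:: cpp

   m.def("double_it", [](nb::ndarray<__fp16, nb::ndim<1>, nb::c_contig,
                                     nb::device::cpu> arg) {
       for (size_t i = 0; i < arg.shape(0); ++i)
           arg(i) = (__fp16) (float(arg(i)) * 2.f);
   });
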
Limitations
-----------

@@ -383,30 +495,9 @@ internal representations (*dtypes*), including
nanobind's :cpp:class:`nb::ndarray\<...\> <ndarray>` is based on the `DLPack
<https://github.com/dmlc/dlpack>`__ array exchange protocol, which causes it to
be more restrictive. Presently supported dtypes include signed/unsigned
integers, floating point values, and boolean values.
integers, floating point values, and boolean values. Some :ref:`nonstandard
arithmetic types <ndarray_nonstandard_arithmetic>` can be supported as well.

Nanobind can receive and return read-only arrays via the buffer protocol used
to exchange data with NumPy. The DLPack interface currently ignores this
annotation.

Supporting nonstandard arithmetic types
---------------------------------------

Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
you wish to exchange arrays based on such types, you must register a partial
overload of ``nanobind::ndarray_traits`` to inform nanobind about it.

For example, the following snippet makes ``__fp16`` (half-precision type on
``aarch64``) available:

.. code-block:: cpp
namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
};
130 changes: 116 additions & 14 deletions include/nanobind/ndarray.h
@@ -247,6 +247,7 @@ template <typename... Ts> struct ndarray_info {
using shape_type = void;
constexpr static auto name = const_name("ndarray");
constexpr static ndarray_framework framework = ndarray_framework::none;
constexpr static char order = '\0';
};

template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
@@ -259,6 +260,14 @@ template <size_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...>
using shape_type = shape<Is...>;
};

template <typename... Ts> struct ndarray_info<c_contig, Ts...> : ndarray_info<Ts...> {
constexpr static char order = 'C';
};

template <typename... Ts> struct ndarray_info<f_contig, Ts...> : ndarray_info<Ts...> {
constexpr static char order = 'F';
};

template <typename... Ts> struct ndarray_info<numpy, Ts...> : ndarray_info<Ts...> {
constexpr static auto name = const_name("numpy.ndarray");
constexpr static ndarray_framework framework = ndarray_framework::numpy;
@@ -282,6 +291,64 @@ template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...>

NAMESPACE_END(detail)

template <typename Scalar, typename Shape, char Order> struct ndarray_view {
static constexpr size_t Dim = Shape::size;

ndarray_view() = default;
ndarray_view(const ndarray_view &) = default;
ndarray_view(ndarray_view &&) = default;
ndarray_view &operator=(const ndarray_view &) = default;
ndarray_view &operator=(ndarray_view &&) noexcept = default;
~ndarray_view() noexcept = default;

template <typename... Ts> NB_INLINE Scalar &operator()(Ts... indices) const {
static_assert(
sizeof...(Ts) == Dim,
"ndarray_view::operator(): invalid number of arguments");

const int64_t indices_i64[] { (int64_t) indices... };
int64_t offset = 0;
for (size_t i = 0; i < Dim; ++i)
offset += indices_i64[i] * m_strides[i];

return *(m_data + offset);
}

size_t ndim() const { return Dim; }
size_t shape(size_t i) const { return m_shape[i]; }
int64_t stride(size_t i) const { return m_strides[i]; }
Scalar *data() const { return m_data; }

private:
template <typename...> friend class ndarray;

template <size_t... I1, size_t... I2>
ndarray_view(Scalar *data, const int64_t *shape, const int64_t *strides,
std::index_sequence<I1...>, nanobind::shape<I2...>)
: m_data(data) {

/* Initialize shape/strides with compile-time knowledge if
available (to permit vectorization, loop unrolling, etc.) */
((m_shape[I1] = (I2 == any) ? shape[I1] : I2), ...);
((m_strides[I1] = strides[I1]), ...);

if constexpr (Order == 'F') {
m_strides[0] = 1;
for (size_t i = 1; i < Dim; ++i)
m_strides[i] = m_strides[i - 1] * m_shape[i - 1];
} else if constexpr (Order == 'C') {
m_strides[Dim - 1] = 1;
for (Py_ssize_t i = (Py_ssize_t) Dim - 2; i >= 0; --i)
m_strides[i] = m_strides[i + 1] * m_shape[i + 1];
}
}

Scalar *m_data = nullptr;
int64_t m_shape[Dim] { };
int64_t m_strides[Dim] { };
};


template <typename... Args> class ndarray {
public:
template <typename...> friend class ndarray;
@@ -405,24 +472,59 @@ template <typename... Args> class ndarray {
byte_offset(indices...));
}

template <typename... Extra> NB_INLINE auto view() {
using Info2 = typename ndarray<Args..., Extra...>::Info;
using Scalar2 = typename Info2::scalar_type;
using Shape2 = typename Info2::shape_type;

constexpr bool has_scalar = !std::is_same_v<Scalar2, void>,
has_shape = !std::is_same_v<Shape2, void>;

static_assert(has_scalar,
"To use the ndarray::view<..>() method, you must add a scalar type "
"annotation (e.g. 'float') to the template parameters of the parent "
"ndarray, or to the call to .view<..>()");

static_assert(has_shape,
"To use the ndarray::view<..>() method, you must add a shape<..> "
"or ndim<..> annotation to the template parameters of the parent "
"ndarray, or to the call to .view<..>()");

if constexpr (has_scalar && has_shape) {
return ndarray_view<Scalar2, Shape2, Info2::order>(
(Scalar2 *) data(), shape_ptr(), stride_ptr(),
std::make_index_sequence<Shape2::size>(), Shape2());
} else {
return nullptr;
}
}

private:
template <typename... Ts>
NB_INLINE int64_t byte_offset(Ts... indices) const {
static_assert(
!std::is_same_v<Scalar, void>,
"To use nb::ndarray::operator(), you must add a scalar type "
constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
has_shape = !std::is_same_v<typename Info::shape_type, void>;

static_assert(has_scalar,
"To use ndarray::operator(), you must add a scalar type "
"annotation (e.g. 'float') to the ndarray template parameters.");
static_assert(
!std::is_same_v<Scalar, void>,
"To use nb::ndarray::operator(), you must add a nb::shape<> "
"annotation to the ndarray template parameters.");
static_assert(sizeof...(Ts) == Info::shape_type::size,
"nb::ndarray::operator(): invalid number of arguments");
size_t counter = 0;
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);

static_assert(has_shape,
"To use ndarray::operator(), you must add a shape<> or "
"ndim<> annotation to the ndarray template parameters.");

if constexpr (has_scalar && has_shape) {
static_assert(sizeof...(Ts) == Info::shape_type::size,
"ndarray::operator(): invalid number of arguments");

size_t counter = 0;
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
} else {
return 0;
}
}

detail::ndarray_handle *m_handle = nullptr;
