Skip to content

Commit

Permalink
std::array type caster
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Oct 16, 2022
1 parent 080beb9 commit be34b16
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 2 deletions.
6 changes: 5 additions & 1 deletion docs/changelog.rst
Expand Up @@ -8,14 +8,18 @@ current version is still in the prototype range (*0.x.y*), there are no (formal)
guarantees of API or ABI stability. That said, I will do my best to minimize
inconvenience whenever possible.

Version 0.0.8 (TBD)
----------------------------

* Caster for ``std::array<..>``.

Version 0.0.7 (Oct 14, 2022)
----------------------------

* Fixed a regression involving function docstrings in ``pydoc``. (commit
`384f4a
<https://github.com/wjakob/nanobind/commit/384f4ada1f3f08486fb03427227878ddbbcaad43>`_).


Version 0.0.6 (Oct 14, 2022)
----------------------------

Expand Down
13 changes: 13 additions & 0 deletions include/nanobind/stl/array.h
@@ -0,0 +1,13 @@
#pragma once

#include "detail/nb_array.h"
#include <array>

NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, Size> { };

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)
63 changes: 63 additions & 0 deletions include/nanobind/stl/detail/nb_array.h
@@ -0,0 +1,63 @@
#pragma once

#include <nanobind/nanobind.h>

NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Value_, typename Entry, size_t Size> struct array_caster {
NB_TYPE_CASTER(Value_, const_name("Sequence[") + make_caster<Entry>::Name +
const_name("]"));

using Caster = make_caster<Entry>;

bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
PyObject *temp;

/* Will initialize 'temp' (NULL in the case of a failure.) */
PyObject **o = seq_get_with_size(src.ptr(), Size, &temp);

Caster caster;
bool success = o != nullptr;

if (success) {
for (size_t i = 0; i < Size; ++i) {
if (!caster.from_python(o[i], flags, cleanup)) {
success = false;
break;
}

value[i] = ((Caster &&) caster).operator cast_t<Entry &&>();
}

Py_XDECREF(temp);
}

return success;
}

template <typename T>
static handle from_cpp(T &&src, rv_policy policy, cleanup_list *cleanup) {
object list = steal(PyList_New(Size));

if (list.is_valid()) {
Py_ssize_t index = 0;

for (auto &value : src) {
handle h =
Caster::from_cpp(forward_like<T>(value), policy, cleanup);

NB_LIST_SET_ITEM(list.ptr(), index++, h.ptr());
if (!h.is_valid())
return handle();
}
} else {
PyErr_Clear();
}

return list.release();
}
};

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)
6 changes: 5 additions & 1 deletion include/nanobind/stl/detail/nb_list.h
Expand Up @@ -34,6 +34,7 @@ template <typename Value_, typename Entry> struct list_caster {
success = false;
break;
}

value.push_back(((Caster &&) caster).operator cast_t<Entry &&>());
}

Expand All @@ -45,7 +46,8 @@ template <typename Value_, typename Entry> struct list_caster {
template <typename T>
static handle from_cpp(T &&src, rv_policy policy, cleanup_list *cleanup) {
object list = steal(PyList_New(src.size()));
if (list) {

if (list.is_valid()) {
Py_ssize_t index = 0;

for (auto &value : src) {
Expand All @@ -56,6 +58,8 @@ template <typename Value_, typename Entry> struct list_caster {
if (!h.is_valid())
return handle();
}
} else {
PyErr_Clear();
}

return list.release();
Expand Down
12 changes: 12 additions & 0 deletions tests/test_stl.cpp
Expand Up @@ -8,6 +8,7 @@
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/array.h>

NB_MAKE_OPAQUE(NB_TYPE(std::vector<float, std::allocator<float>>))

Expand Down Expand Up @@ -233,6 +234,7 @@ NB_MODULE(test_stl_ext, m) {
x.emplace(std::string(1, 'a' + i), i);
return x;
});

m.def("map_return_copyable_value", [](){
std::map<std::string, Copyable> x;
for (int i = 0; i < 10; ++i) {
Expand All @@ -241,6 +243,7 @@ NB_MODULE(test_stl_ext, m) {
}
return x;
});

m.def("map_movable_in_value", [](std::map<std::string, Movable> x) {
if (x.size() != 10) fail();
for (int i = 0; i < 10; ++i) {
Expand All @@ -249,6 +252,7 @@ NB_MODULE(test_stl_ext, m) {
if (x[key].value != i) fail();
}
}, nb::arg("x"));

m.def("map_copyable_in_value", [](std::map<std::string, Copyable> x) {
if (x.size() != 10) fail();
for (int i = 0; i < 10; ++i) {
Expand All @@ -257,6 +261,7 @@ NB_MODULE(test_stl_ext, m) {
if (x[key].value != i) fail();
}
}, nb::arg("x"));

m.def("map_movable_in_lvalue_ref", [](std::map<std::string, Movable> &x) {
if (x.size() != 10) fail();
for (int i = 0; i < 10; ++i) {
Expand All @@ -265,6 +270,7 @@ NB_MODULE(test_stl_ext, m) {
if (x[key].value != i) fail();
}
}, nb::arg("x"));

m.def("map_movable_in_rvalue_ref", [](std::map<std::string, Movable> &&x) {
if (x.size() != 10) fail();
for (int i = 0; i < 10; ++i) {
Expand All @@ -273,6 +279,7 @@ NB_MODULE(test_stl_ext, m) {
if (x[key].value != i) fail();
}
}, nb::arg("x"));

m.def("map_movable_in_ptr", [](std::map<std::string, Movable *> x) {
if (x.size() != 10) fail();
for (int i = 0; i < 10; ++i) {
Expand All @@ -281,11 +288,16 @@ NB_MODULE(test_stl_ext, m) {
if (x[key]->value != i) fail();
}
}, nb::arg("x"));

m.def("map_return_readonly_value", [](){
StructWithReadonlyMap x;
for (int i = 0; i < 10; ++i) {
x.map.insert({std::string(1, 'a' + i), i});
}
return x;
});

// test56
m.def("array_out", [](){ return std::array<int, 3>{1, 2, 3}; });
m.def("array_in", [](std::array<int, 3> x) { return x[0] + x[1] + x[2]; });
}
9 changes: 9 additions & 0 deletions tests/test_stl.py
Expand Up @@ -519,3 +519,12 @@ def test55_map_return_readonly_value(clean):
assert t.map_return_readonly_value.__doc__ == (
"map_return_readonly_value() -> test_stl_ext.StructWithReadonlyMap"
)

def test56_array(clean):
o = t.array_out()
assert isinstance(o, list) and o == [1, 2, 3]
assert t.array_in([1, 2, 3]) == 6
assert t.array_in((1, 2, 3)) == 6
with pytest.raises(TypeError) as excinfo:
assert t.array_in((1, 2, 3, 4)) == 6
assert 'incompatible function arguments' in str(excinfo.value)

0 comments on commit be34b16

Please sign in to comment.