diff --git a/include/nanobind/stl/bind_map.h b/include/nanobind/stl/bind_map.h index 1f79d7b5..bef5bc56 100644 --- a/include/nanobind/stl/bind_map.h +++ b/include/nanobind/stl/bind_map.h @@ -46,7 +46,7 @@ class_ bind_map(handle scope, const char *name, Args &&...args) { .def(init<>(), "Default constructor") - .def("__len__", &Map::size) + .def("__len__", [](const Map &m) { return m.size(); }) .def("__bool__", [](const Map &m) { return !m.empty(); }, diff --git a/include/nanobind/stl/bind_vector.h b/include/nanobind/stl/bind_vector.h new file mode 100644 index 00000000..1b53213d --- /dev/null +++ b/include/nanobind/stl/bind_vector.h @@ -0,0 +1,211 @@ +/* + nanobind/stl/bind_vector.h: Automatic creation of bindings for vector-style containers + + Copyright (c) 2022 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +NAMESPACE_BEGIN(NB_NAMESPACE) +NAMESPACE_BEGIN(detail) + +inline size_t wrap(Py_ssize_t i, size_t n) { + if (i < 0) + i += (Py_ssize_t) n; + + if (i < 0 || (size_t) i >= n) + throw index_error(); + + return (size_t) i; +} + +template <> struct iterator_access::iterator> { + using result_type = bool; + result_type operator()(typename std::vector::iterator &it) const { return *it; } +}; + +template struct iterable_type_id { + static constexpr auto Name = const_name("Iterable[") + + make_caster::Name + + const_name("]"); +}; + +NAMESPACE_END(detail) + + +template +class_ bind_vector(handle scope, const char *name, Args &&...args) { + using ValueRef = typename detail::iterator_access::result_type; + using Value = std::decay_t; + + auto cl = class_(scope, name, std::forward(args)...) + .def(init<>(), "Default constructor") + + .def("__len__", [](const Vector &v) { return v.size(); }) + + .def("__bool__", + [](const Vector &v) { return !v.empty(); }, + "Check whether the vector is nonempty") + + .def("__iter__", + [](Vector &v) { + return make_iterator(type(), "Iterator", + v.begin(), v.end()); + }, keep_alive<0, 1>()) + + .def("__getitem__", + [](Vector &v, Py_ssize_t i) -> ValueRef { + return v[detail::wrap(i, v.size())]; + }, + rv_policy::reference_internal) + + .def("clear", [](Vector &v) { v.clear(); }, + "Remove all items from list."); + + if constexpr (detail::is_copy_constructible_v) { + cl.def(init(), + "Copy constructor"); + + cl.def("__init__", [](Vector *v, typed> &seq) { + new (v) Vector(); + v->reserve(len_hint(seq.value)); + for (handle h : seq.value) + v->push_back(cast(h)); + }, "Construct from an iterable object"); + + implicitly_convertible(); + + cl.def("append", + [](Vector &v, const Value &value) { v.push_back(value); }, + "Append `arg` to the end of the list.") + + .def("insert", + [](Vector &v, Py_ssize_t i, const Value &x) { + if (i < 0) + i += (Py_ssize_t) v.size(); + if (i < 0 || (size_t) i > v.size()) + throw index_error(); + v.insert(v.begin() + i, x); + }, + "Insert object `arg1` before index `arg0`.") + + .def("pop", + [](Vector &v, Py_ssize_t i) { + size_t index = detail::wrap(i, v.size()); + Value result = std::move(v[index]); + v.erase(v.begin() + index); + return result; + }, + arg("index") = -1, + "Remove and return item at `index` (default last).") + + .def("extend", + [](Vector &v, const Vector &src) { + v.insert(v.end(), src.begin(), src.end()); + }, + "Extend `self` by appending elements from `arg`.") + + .def("__setitem__", + [](Vector &v, Py_ssize_t i, const Value &value) { + v[detail::wrap(i, v.size())] = value; + }) + + .def("__delitem__", + [](Vector &v, Py_ssize_t i) { + v.erase(v.begin() + detail::wrap(i, v.size())); + }) + + .def("__getitem__", + [](const Vector &v, const slice &slice) -> Vector * { + auto [start, stop, step, length] = slice.compute(v.size()); + auto *seq = new Vector(); + seq->reserve(length); + + for (size_t i = 0; i < length; ++i) { + seq->push_back(v[start]); + start += step; + } + + return seq; + }) + + .def("__setitem__", + [](Vector &v, const slice &slice, const Vector &value) { + auto [start, stop, step, length] = slice.compute(v.size()); + + if (length != value.size()) + throw index_error( + "The left and right hand side of the slice " + "assignment have mismatched sizes!"); + + for (size_t i = 0; i < length; ++i) { + v[start] = value[i]; + start += step; + } + }) + + .def("__delitem__", + [](Vector &v, const slice &slice) { + auto [start, stop, step, length] = slice.compute(v.size()); + if (length == 0) + return; + + stop = start + (length - 1) * step; + if (start > stop) { + std::swap(start, stop); + step = -step; + } + + if (step == 1) { + v.erase(v.begin() + start, v.begin() + stop + 1); + } else { + for (size_t i = 0; i < length; ++i) { + v.erase(v.begin() + stop); + stop -= step; + } + } + }); + } + + if constexpr (detail::is_equality_comparable_v) { + cl.def(self == self) + .def(self != self) + + .def("__contains__", + [](const Vector &v, const Value &x) { + return std::find(v.begin(), v.end(), x) != v.end(); + }) + + .def("__contains__", // fallback for incompatible types + [](const Vector &, handle) { return false; }) + + .def("count", + [](const Vector &v, const Value &x) { + return std::count(v.begin(), v.end(), x); + }, "Return number of occurrences of `arg`.") + + .def("remove", + [](Vector &v, const Value &x) { + auto p = std::find(v.begin(), v.end(), x); + if (p != v.end()) + v.erase(p); + else + throw value_error(); + }, + "Remove first occurrence of `arg`."); + } + + return cl; +} + +NAMESPACE_END(NB_NAMESPACE) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b544bd62..9ea5d721 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,6 +14,7 @@ nanobind_add_module(test_classes_ext test_classes.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_holders_ext test_holders.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_stl_ext test_stl.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_bind_map_ext test_stl_bind_map.cpp ${NB_EXTRA_ARGS}) +nanobind_add_module(test_bind_vector_ext test_stl_bind_vector.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_enum_ext test_enum.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_tensor_ext test_tensor.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS}) diff --git a/tests/test_stl_bind_map.py b/tests/test_stl_bind_map.py index 1d0de874..b0800f1a 100644 --- a/tests/test_stl_bind_map.py +++ b/tests/test_stl_bind_map.py @@ -1,4 +1,5 @@ import pytest +import sys import test_bind_map_ext as t @@ -85,10 +86,15 @@ def test_map_string_double(): assert type(values).__qualname__ == 'MapStringDouble.ValueView' assert type(items).__qualname__ == 'MapStringDouble.ItemView' + if sys.version_info < (3, 9): + d = "Dict" + else: + d = "dict" + assert t.MapStringDouble.__init__.__doc__ == \ """__init__(self) -> None __init__(self, arg: test_bind_map_ext.MapStringDouble, /) -> None -__init__(self, arg: dict[str, float], /) -> None +__init__(self, arg: %s[str, float], /) -> None Overloaded function. @@ -100,9 +106,9 @@ def test_map_string_double(): Copy constructor -3. ``__init__(self, arg: dict[str, float], /) -> None`` +3. ``__init__(self, arg: %s[str, float], /) -> None`` -Construct from a dictionary""" +Construct from a dictionary""" % (d, d) def test_map_string_double_const(): diff --git a/tests/test_stl_bind_vector.cpp b/tests/test_stl_bind_vector.cpp new file mode 100644 index 00000000..bf7b1722 --- /dev/null +++ b/tests/test_stl_bind_vector.cpp @@ -0,0 +1,44 @@ +#include + +#include + +namespace nb = nanobind; + +NB_MODULE(test_bind_vector_ext, m) { + nb::bind_vector>(m, "VectorInt"); + nb::bind_vector>(m, "VectorBool"); + + struct El { + explicit El(int v) : a(v) {} + int a; + }; + + // test_vector_custom + nb::class_(m, "El").def(nb::init()) + .def_readwrite("a", &El::a); + nb::bind_vector>(m, "VectorEl"); + nb::bind_vector>>(m, "VectorVectorEl"); + + struct E_nc { + explicit E_nc(int i) : value{i} {} + E_nc(const E_nc &) = delete; + E_nc &operator=(const E_nc &) = delete; + E_nc(E_nc &&) = default; + E_nc &operator=(E_nc &&) = default; + + int value; + }; + + // test_noncopyable_containers + nb::class_(m, "ENC") + .def(nb::init()) + .def_readwrite("value", &E_nc::value); + + nb::bind_vector>(m, "VectorENC"); + m.def("get_vnc", [](int n) { + std::vector result; + for (int i = 1; i <= n; i++) + result.emplace_back(i); + return result; + }); +} diff --git a/tests/test_stl_bind_vector.py b/tests/test_stl_bind_vector.py new file mode 100644 index 00000000..1079ff93 --- /dev/null +++ b/tests/test_stl_bind_vector.py @@ -0,0 +1,139 @@ +import pytest + +import test_bind_vector_ext as t + +def test01_vector_int(): + v_int = t.VectorInt([0, 0]) + assert len(v_int) == 2 + assert bool(v_int) is True + + # test construction from a generator + v_int1 = t.VectorInt(x for x in range(5)) + assert t.VectorInt(v_int1) == t.VectorInt([0, 1, 2, 3, 4]) + + v_int2 = t.VectorInt([0, 0]) + assert v_int == v_int2 + v_int2[1] = 1 + assert v_int != v_int2 + + v_int2.append(2) + v_int2.insert(0, 1) + v_int2.insert(0, 2) + v_int2.insert(0, 3) + v_int2.insert(6, 3) + with pytest.raises(IndexError): + v_int2.insert(8, 4) + + v_int.append(99) + v_int2[2:-2] = v_int + assert v_int2 == t.VectorInt([3, 2, 0, 0, 99, 2, 3]) + del v_int2[1:3] + assert v_int2 == t.VectorInt([3, 0, 99, 2, 3]) + del v_int2[0] + assert v_int2 == t.VectorInt([0, 99, 2, 3]) + + v_int2.extend(t.VectorInt([4, 5])) + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5]) + + v_int2.extend([6, 7]) + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5, 6, 7]) + + # test error handling, and that the vector is unchanged + with pytest.warns(RuntimeWarning, match="implicit conversion from type 'list' to type 'test_bind_vector_ext.VectorInt' failed"): + with pytest.raises(TypeError): + v_int2.extend([8, "a"]) + + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5, 6, 7]) + + # test extending from a generator + v_int2.extend(x for x in range(5)) + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4]) + + # Test count feature + assert v_int2.count(2) == 2 + assert v_int2.count(5) == 1 + assert v_int2.count(8) == 0 + assert 2 in v_int2 + assert 5 in v_int2 + assert 8 not in v_int2 + + # test negative indexing + assert v_int2[-1] == 4 + + # insert with negative index + v_int2.insert(-1, 88) + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88, 4]) + + # delete negative index + del v_int2[-1] + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88]) + + assert v_int2.pop() == 88 + assert v_int2 == t.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]) + + assert v_int2.pop(1) == 99 + assert v_int2 == t.VectorInt([0, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]) + + v_int2.clear() + assert len(v_int2) == 0 + + +def test02_vector_bool(): + vv_c = t.VectorBool() + for i in range(9): + vv_c.append(i % 2 == 0) + for i in range(9): + assert vv_c[i] == (i % 2 == 0) + assert vv_c.count(True) == 5 + assert vv_c.count(False) == 4 + + +def test03_vector_custom(): + v_a = t.VectorEl() + v_a.append(t.El(1)) + v_a.append(t.El(2)) + assert len(v_a) == 2 and v_a[0].a == 1 and v_a[1].a == 2 + + vv_a = t.VectorVectorEl() + vv_a.append(v_a) + v_b = vv_a[0] + assert len(v_b) == 2 and v_b[0].a == 1 and v_b[1].a == 2 + + +def test04_vector_noncopyable(): + vnc = t.get_vnc(5) + for i in range(0, 5): + assert vnc[i].value == i + 1 + + for i, j in enumerate(vnc, start=1): + assert j.value == i + +def test05_vector_slicing(): + l1 = list(range(100)) + l2 = t.VectorInt(l1) + + def check_same(s): + assert l1[s] == l2[s] + + def check_del(s): + l1c = type(l1)(l1) + l2c = type(l2)(l2) + del l1c[s] + del l2c[s] + l2c = list(l2c) + print(repr(l1c)) + print(repr(l2c)) + assert l1c == l2c + + check_same(slice(1, 13, 4)) + check_same(slice(1, 14, 4)) + check_same(slice(10, 2000, 1)) + check_same(slice(200, 10, 1)) + check_same(slice(200, 10, -1)) + check_same(slice(200, 10, -3)) + check_del(slice(1, 13, 4)) + check_del(slice(1, 14, 4)) + check_del(slice(10, 2000, 1)) + check_del(slice(200, 10, 1)) + check_del(slice(200, 10, -1)) + check_del(slice(200, 10, -3))