Skip to content

Commit

Permalink
Added a bind_vector API
Browse files Browse the repository at this point in the history
Following the recent addition of the `nb::bind_map<..>` API, this commit
adds a similar `nb::bind_vector<..>` function for creating Python
bindings of `std::vector<..>`.
  • Loading branch information
wjakob committed Jan 26, 2023
1 parent cf40722 commit f2df8a9
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 4 deletions.
2 changes: 1 addition & 1 deletion include/nanobind/stl/bind_map.h
Expand Up @@ -46,7 +46,7 @@ class_<Map> bind_map(handle scope, const char *name, Args &&...args) {
.def(init<>(),
"Default constructor")

.def("__len__", &Map::size)
.def("__len__", [](const Map &m) { return m.size(); })

.def("__bool__",
[](const Map &m) { return !m.empty(); },
Expand Down
211 changes: 211 additions & 0 deletions include/nanobind/stl/bind_vector.h
@@ -0,0 +1,211 @@
/*
nanobind/stl/bind_vector.h: Automatic creation of bindings for vector-style containers
Copyright (c) 2022 Wenzel Jakob
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include <nanobind/nanobind.h>
#include <nanobind/operators.h>
#include <nanobind/make_iterator.h>
#include <nanobind/stl/detail/traits.h>
#include <vector>
#include <algorithm>

NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

inline size_t wrap(Py_ssize_t i, size_t n) {
if (i < 0)
i += (Py_ssize_t) n;

if (i < 0 || (size_t) i >= n)
throw index_error();

return (size_t) i;
}

template <> struct iterator_access<typename std::vector<bool>::iterator> {
using result_type = bool;
result_type operator()(typename std::vector<bool>::iterator &it) const { return *it; }
};

template <typename Value> struct iterable_type_id {
static constexpr auto Name = const_name("Iterable[") +
make_caster<Value>::Name +
const_name("]");
};

NAMESPACE_END(detail)


template <typename Vector, typename... Args>
class_<Vector> bind_vector(handle scope, const char *name, Args &&...args) {
using ValueRef = typename detail::iterator_access<typename Vector::iterator>::result_type;
using Value = std::decay_t<ValueRef>;

auto cl = class_<Vector>(scope, name, std::forward<Args>(args)...)
.def(init<>(), "Default constructor")

.def("__len__", [](const Vector &v) { return v.size(); })

.def("__bool__",
[](const Vector &v) { return !v.empty(); },
"Check whether the vector is nonempty")

.def("__iter__",
[](Vector &v) {
return make_iterator(type<Vector>(), "Iterator",
v.begin(), v.end());
}, keep_alive<0, 1>())

.def("__getitem__",
[](Vector &v, Py_ssize_t i) -> ValueRef {
return v[detail::wrap(i, v.size())];
},
rv_policy::reference_internal)

.def("clear", [](Vector &v) { v.clear(); },
"Remove all items from list.");

if constexpr (detail::is_copy_constructible_v<Value>) {
cl.def(init<const Vector &>(),
"Copy constructor");

cl.def("__init__", [](Vector *v, typed<iterable, detail::iterable_type_id<Value>> &seq) {
new (v) Vector();
v->reserve(len_hint(seq.value));
for (handle h : seq.value)
v->push_back(cast<Value>(h));
}, "Construct from an iterable object");

implicitly_convertible<iterable, Vector>();

cl.def("append",
[](Vector &v, const Value &value) { v.push_back(value); },
"Append `arg` to the end of the list.")

.def("insert",
[](Vector &v, Py_ssize_t i, const Value &x) {
if (i < 0)
i += (Py_ssize_t) v.size();
if (i < 0 || (size_t) i > v.size())
throw index_error();
v.insert(v.begin() + i, x);
},
"Insert object `arg1` before index `arg0`.")

.def("pop",
[](Vector &v, Py_ssize_t i) {
size_t index = detail::wrap(i, v.size());
Value result = std::move(v[index]);
v.erase(v.begin() + index);
return result;
},
arg("index") = -1,
"Remove and return item at `index` (default last).")

.def("extend",
[](Vector &v, const Vector &src) {
v.insert(v.end(), src.begin(), src.end());
},
"Extend `self` by appending elements from `arg`.")

.def("__setitem__",
[](Vector &v, Py_ssize_t i, const Value &value) {
v[detail::wrap(i, v.size())] = value;
})

.def("__delitem__",
[](Vector &v, Py_ssize_t i) {
v.erase(v.begin() + detail::wrap(i, v.size()));
})

.def("__getitem__",
[](const Vector &v, const slice &slice) -> Vector * {
auto [start, stop, step, length] = slice.compute(v.size());
auto *seq = new Vector();
seq->reserve(length);

for (size_t i = 0; i < length; ++i) {
seq->push_back(v[start]);
start += step;
}

return seq;
})

.def("__setitem__",
[](Vector &v, const slice &slice, const Vector &value) {
auto [start, stop, step, length] = slice.compute(v.size());

if (length != value.size())
throw index_error(
"The left and right hand side of the slice "
"assignment have mismatched sizes!");

for (size_t i = 0; i < length; ++i) {
v[start] = value[i];
start += step;
}
})

.def("__delitem__",
[](Vector &v, const slice &slice) {
auto [start, stop, step, length] = slice.compute(v.size());
if (length == 0)
return;

stop = start + (length - 1) * step;
if (start > stop) {
std::swap(start, stop);
step = -step;
}

if (step == 1) {
v.erase(v.begin() + start, v.begin() + stop + 1);
} else {
for (size_t i = 0; i < length; ++i) {
v.erase(v.begin() + stop);
stop -= step;
}
}
});
}

if constexpr (detail::is_equality_comparable_v<Value>) {
cl.def(self == self)
.def(self != self)

.def("__contains__",
[](const Vector &v, const Value &x) {
return std::find(v.begin(), v.end(), x) != v.end();
})

.def("__contains__", // fallback for incompatible types
[](const Vector &, handle) { return false; })

.def("count",
[](const Vector &v, const Value &x) {
return std::count(v.begin(), v.end(), x);
}, "Return number of occurrences of `arg`.")

.def("remove",
[](Vector &v, const Value &x) {
auto p = std::find(v.begin(), v.end(), x);
if (p != v.end())
v.erase(p);
else
throw value_error();
},
"Remove first occurrence of `arg`.");
}

return cl;
}

NAMESPACE_END(NB_NAMESPACE)
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Expand Up @@ -14,6 +14,7 @@ nanobind_add_module(test_classes_ext test_classes.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_holders_ext test_holders.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_stl_ext test_stl.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_bind_map_ext test_stl_bind_map.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_bind_vector_ext test_stl_bind_vector.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_enum_ext test_enum.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_tensor_ext test_tensor.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS})
Expand Down
12 changes: 9 additions & 3 deletions tests/test_stl_bind_map.py
@@ -1,4 +1,5 @@
import pytest
import sys

import test_bind_map_ext as t

Expand Down Expand Up @@ -85,10 +86,15 @@ def test_map_string_double():
assert type(values).__qualname__ == 'MapStringDouble.ValueView'
assert type(items).__qualname__ == 'MapStringDouble.ItemView'

if sys.version_info < (3, 9):
d = "Dict"
else:
d = "dict"

assert t.MapStringDouble.__init__.__doc__ == \
"""__init__(self) -> None
__init__(self, arg: test_bind_map_ext.MapStringDouble, /) -> None
__init__(self, arg: dict[str, float], /) -> None
__init__(self, arg: %s[str, float], /) -> None
Overloaded function.
Expand All @@ -100,9 +106,9 @@ def test_map_string_double():
Copy constructor
3. ``__init__(self, arg: dict[str, float], /) -> None``
3. ``__init__(self, arg: %s[str, float], /) -> None``
Construct from a dictionary"""
Construct from a dictionary""" % (d, d)


def test_map_string_double_const():
Expand Down
44 changes: 44 additions & 0 deletions tests/test_stl_bind_vector.cpp
@@ -0,0 +1,44 @@
#include <vector>

#include <nanobind/stl/bind_vector.h>

namespace nb = nanobind;

NB_MODULE(test_bind_vector_ext, m) {
nb::bind_vector<std::vector<unsigned int>>(m, "VectorInt");
nb::bind_vector<std::vector<bool>>(m, "VectorBool");

struct El {
explicit El(int v) : a(v) {}
int a;
};

// test_vector_custom
nb::class_<El>(m, "El").def(nb::init<int>())
.def_readwrite("a", &El::a);
nb::bind_vector<std::vector<El>>(m, "VectorEl");
nb::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl");

struct E_nc {
explicit E_nc(int i) : value{i} {}
E_nc(const E_nc &) = delete;
E_nc &operator=(const E_nc &) = delete;
E_nc(E_nc &&) = default;
E_nc &operator=(E_nc &&) = default;

int value;
};

// test_noncopyable_containers
nb::class_<E_nc>(m, "ENC")
.def(nb::init<int>())
.def_readwrite("value", &E_nc::value);

nb::bind_vector<std::vector<E_nc>>(m, "VectorENC");
m.def("get_vnc", [](int n) {
std::vector<E_nc> result;
for (int i = 1; i <= n; i++)
result.emplace_back(i);
return result;
});
}

0 comments on commit f2df8a9

Please sign in to comment.