Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Following the recent addition of the `nb::bind_map<..>` API, this commit adds a similar `nb::bind_vector<..>` function for creating Python bindings of `std::vector<..>`.
- Loading branch information
Showing
6 changed files
with
405 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
/* | ||
nanobind/stl/bind_vector.h: Automatic creation of bindings for vector-style containers | ||
Copyright (c) 2022 Wenzel Jakob | ||
All rights reserved. Use of this source code is governed by a | ||
BSD-style license that can be found in the LICENSE file. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <nanobind/nanobind.h> | ||
#include <nanobind/operators.h> | ||
#include <nanobind/make_iterator.h> | ||
#include <nanobind/stl/detail/traits.h> | ||
#include <vector> | ||
#include <algorithm> | ||
|
||
NAMESPACE_BEGIN(NB_NAMESPACE) | ||
NAMESPACE_BEGIN(detail) | ||
|
||
inline size_t wrap(Py_ssize_t i, size_t n) { | ||
if (i < 0) | ||
i += (Py_ssize_t) n; | ||
|
||
if (i < 0 || (size_t) i >= n) | ||
throw index_error(); | ||
|
||
return (size_t) i; | ||
} | ||
|
||
template <> struct iterator_access<typename std::vector<bool>::iterator> { | ||
using result_type = bool; | ||
result_type operator()(typename std::vector<bool>::iterator &it) const { return *it; } | ||
}; | ||
|
||
template <typename Value> struct iterable_type_id { | ||
static constexpr auto Name = const_name("Iterable[") + | ||
make_caster<Value>::Name + | ||
const_name("]"); | ||
}; | ||
|
||
NAMESPACE_END(detail) | ||
|
||
|
||
template <typename Vector, typename... Args> | ||
class_<Vector> bind_vector(handle scope, const char *name, Args &&...args) { | ||
using ValueRef = typename detail::iterator_access<typename Vector::iterator>::result_type; | ||
using Value = std::decay_t<ValueRef>; | ||
|
||
auto cl = class_<Vector>(scope, name, std::forward<Args>(args)...) | ||
.def(init<>(), "Default constructor") | ||
|
||
.def("__len__", [](const Vector &v) { return v.size(); }) | ||
|
||
.def("__bool__", | ||
[](const Vector &v) { return !v.empty(); }, | ||
"Check whether the vector is nonempty") | ||
|
||
.def("__iter__", | ||
[](Vector &v) { | ||
return make_iterator(type<Vector>(), "Iterator", | ||
v.begin(), v.end()); | ||
}, keep_alive<0, 1>()) | ||
|
||
.def("__getitem__", | ||
[](Vector &v, Py_ssize_t i) -> ValueRef { | ||
return v[detail::wrap(i, v.size())]; | ||
}, | ||
rv_policy::reference_internal) | ||
|
||
.def("clear", [](Vector &v) { v.clear(); }, | ||
"Remove all items from list."); | ||
|
||
if constexpr (detail::is_copy_constructible_v<Value>) { | ||
cl.def(init<const Vector &>(), | ||
"Copy constructor"); | ||
|
||
cl.def("__init__", [](Vector *v, typed<iterable, detail::iterable_type_id<Value>> &seq) { | ||
new (v) Vector(); | ||
v->reserve(len_hint(seq.value)); | ||
for (handle h : seq.value) | ||
v->push_back(cast<Value>(h)); | ||
}, "Construct from an iterable object"); | ||
|
||
implicitly_convertible<iterable, Vector>(); | ||
|
||
cl.def("append", | ||
[](Vector &v, const Value &value) { v.push_back(value); }, | ||
"Append `arg` to the end of the list.") | ||
|
||
.def("insert", | ||
[](Vector &v, Py_ssize_t i, const Value &x) { | ||
if (i < 0) | ||
i += (Py_ssize_t) v.size(); | ||
if (i < 0 || (size_t) i > v.size()) | ||
throw index_error(); | ||
v.insert(v.begin() + i, x); | ||
}, | ||
"Insert object `arg1` before index `arg0`.") | ||
|
||
.def("pop", | ||
[](Vector &v, Py_ssize_t i) { | ||
size_t index = detail::wrap(i, v.size()); | ||
Value result = std::move(v[index]); | ||
v.erase(v.begin() + index); | ||
return result; | ||
}, | ||
arg("index") = -1, | ||
"Remove and return item at `index` (default last).") | ||
|
||
.def("extend", | ||
[](Vector &v, const Vector &src) { | ||
v.insert(v.end(), src.begin(), src.end()); | ||
}, | ||
"Extend `self` by appending elements from `arg`.") | ||
|
||
.def("__setitem__", | ||
[](Vector &v, Py_ssize_t i, const Value &value) { | ||
v[detail::wrap(i, v.size())] = value; | ||
}) | ||
|
||
.def("__delitem__", | ||
[](Vector &v, Py_ssize_t i) { | ||
v.erase(v.begin() + detail::wrap(i, v.size())); | ||
}) | ||
|
||
.def("__getitem__", | ||
[](const Vector &v, const slice &slice) -> Vector * { | ||
auto [start, stop, step, length] = slice.compute(v.size()); | ||
auto *seq = new Vector(); | ||
seq->reserve(length); | ||
|
||
for (size_t i = 0; i < length; ++i) { | ||
seq->push_back(v[start]); | ||
start += step; | ||
} | ||
|
||
return seq; | ||
}) | ||
|
||
.def("__setitem__", | ||
[](Vector &v, const slice &slice, const Vector &value) { | ||
auto [start, stop, step, length] = slice.compute(v.size()); | ||
|
||
if (length != value.size()) | ||
throw index_error( | ||
"The left and right hand side of the slice " | ||
"assignment have mismatched sizes!"); | ||
|
||
for (size_t i = 0; i < length; ++i) { | ||
v[start] = value[i]; | ||
start += step; | ||
} | ||
}) | ||
|
||
.def("__delitem__", | ||
[](Vector &v, const slice &slice) { | ||
auto [start, stop, step, length] = slice.compute(v.size()); | ||
if (length == 0) | ||
return; | ||
|
||
stop = start + (length - 1) * step; | ||
if (start > stop) { | ||
std::swap(start, stop); | ||
step = -step; | ||
} | ||
|
||
if (step == 1) { | ||
v.erase(v.begin() + start, v.begin() + stop + 1); | ||
} else { | ||
for (size_t i = 0; i < length; ++i) { | ||
v.erase(v.begin() + stop); | ||
stop -= step; | ||
} | ||
} | ||
}); | ||
} | ||
|
||
if constexpr (detail::is_equality_comparable_v<Value>) { | ||
cl.def(self == self) | ||
.def(self != self) | ||
|
||
.def("__contains__", | ||
[](const Vector &v, const Value &x) { | ||
return std::find(v.begin(), v.end(), x) != v.end(); | ||
}) | ||
|
||
.def("__contains__", // fallback for incompatible types | ||
[](const Vector &, handle) { return false; }) | ||
|
||
.def("count", | ||
[](const Vector &v, const Value &x) { | ||
return std::count(v.begin(), v.end(), x); | ||
}, "Return number of occurrences of `arg`.") | ||
|
||
.def("remove", | ||
[](Vector &v, const Value &x) { | ||
auto p = std::find(v.begin(), v.end(), x); | ||
if (p != v.end()) | ||
v.erase(p); | ||
else | ||
throw value_error(); | ||
}, | ||
"Remove first occurrence of `arg`."); | ||
} | ||
|
||
return cl; | ||
} | ||
|
||
NAMESPACE_END(NB_NAMESPACE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#include <vector> | ||
|
||
#include <nanobind/stl/bind_vector.h> | ||
|
||
namespace nb = nanobind; | ||
|
||
NB_MODULE(test_bind_vector_ext, m) { | ||
nb::bind_vector<std::vector<unsigned int>>(m, "VectorInt"); | ||
nb::bind_vector<std::vector<bool>>(m, "VectorBool"); | ||
|
||
struct El { | ||
explicit El(int v) : a(v) {} | ||
int a; | ||
}; | ||
|
||
// test_vector_custom | ||
nb::class_<El>(m, "El").def(nb::init<int>()) | ||
.def_readwrite("a", &El::a); | ||
nb::bind_vector<std::vector<El>>(m, "VectorEl"); | ||
nb::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl"); | ||
|
||
struct E_nc { | ||
explicit E_nc(int i) : value{i} {} | ||
E_nc(const E_nc &) = delete; | ||
E_nc &operator=(const E_nc &) = delete; | ||
E_nc(E_nc &&) = default; | ||
E_nc &operator=(E_nc &&) = default; | ||
|
||
int value; | ||
}; | ||
|
||
// test_noncopyable_containers | ||
nb::class_<E_nc>(m, "ENC") | ||
.def(nb::init<int>()) | ||
.def_readwrite("value", &E_nc::value); | ||
|
||
nb::bind_vector<std::vector<E_nc>>(m, "VectorENC"); | ||
m.def("get_vnc", [](int n) { | ||
std::vector<E_nc> result; | ||
for (int i = 1; i <= n; i++) | ||
result.emplace_back(i); | ||
return result; | ||
}); | ||
} |
Oops, something went wrong.