From d2afe7f0017335b22f546226d718eb476a6062f5 Mon Sep 17 00:00:00 2001
From: Pim Schellart
Date: Wed, 12 Oct 2016 12:00:53 -0400
Subject: [PATCH] Accept any sequence type as std::vector (or std::list)
---
include/pybind11/pytypes.h | 26 ++++++++++++++++++++++++++
include/pybind11/stl.h | 12 ++++++------
tests/test_python_types.py | 8 ++++++++
3 files changed, 40 insertions(+), 6 deletions(-)
diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h
index 4e5d7f3f8c..61599d2b3f 100644
--- a/include/pybind11/pytypes.h
+++ b/include/pybind11/pytypes.h
@@ -29,12 +29,14 @@ namespace accessor_policies {
struct obj_attr;
struct str_attr;
struct generic_item;
+ struct sequence_item;
struct list_item;
struct tuple_item;
}
using obj_attr_accessor = accessor;
using str_attr_accessor = accessor;
using item_accessor = accessor;
+using sequence_accessor = accessor;
using list_accessor = accessor;
using tuple_accessor = accessor;
@@ -261,6 +263,23 @@ struct generic_item {
}
};
+struct sequence_item {
+ using key_type = size_t;
+
+ static object get(handle obj, size_t index) {
+ PyObject *result = PySequence_GetItem(obj.ptr(), static_cast(index));
+ if (!result) { throw error_already_set(); }
+ return {result, true};
+ }
+
+ static void set(handle obj, size_t index, handle val) {
+ // PySequence_SetItem does not steal a reference to 'val'
+ if (PySequence_SetItem(obj.ptr(), static_cast(index), val.ptr()) != 0) {
+ throw error_already_set();
+ }
+ }
+};
+
struct list_item {
using key_type = size_t;
@@ -673,6 +692,13 @@ class dict : public object {
bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; }
};
+class sequence : public object {
+public:
+ PYBIND11_OBJECT(sequence, object, PySequence_Check)
+ size_t size() const { return (size_t) PySequence_Size(m_ptr); }
+ detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
+};
+
class list : public object {
public:
PYBIND11_OBJECT(list, object, PyList_Check)
diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h
index 1f052ee0b8..e5c6e3c7e1 100644
--- a/include/pybind11/stl.h
+++ b/include/pybind11/stl.h
@@ -97,13 +97,13 @@ template struct list_caster {
using value_conv = make_caster;
bool load(handle src, bool convert) {
- list l(src, true);
- if (!l.check())
+ sequence s(src, true);
+ if (!s.check())
return false;
value_conv conv;
value.clear();
- reserve_maybe(l, &value);
- for (auto it : l) {
+ reserve_maybe(s, &value);
+ for (auto it : s) {
if (!conv.load(it, convert))
return false;
value.push_back((Value) conv);
@@ -113,8 +113,8 @@ template struct list_caster {
template ().reserve(0)), void>::value, int> = 0>
- void reserve_maybe(list l, Type *) { value.reserve(l.size()); }
- void reserve_maybe(list, void *) { }
+ void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
+ void reserve_maybe(sequence, void *) { }
static handle cast(const Type &src, return_value_policy policy, handle parent) {
list l(src.size());
diff --git a/tests/test_python_types.py b/tests/test_python_types.py
index 628e46124b..9bba8249ce 100644
--- a/tests/test_python_types.py
+++ b/tests/test_python_types.py
@@ -71,6 +71,14 @@ def test_instance(capture):
list item 0: value
list item 1: value2
"""
+ with capture:
+ list_result = instance.get_list_2()
+ list_result.append('value2')
+ instance.print_list_2(tuple(list_result))
+ assert capture.unordered == """
+ list item 0: value
+ list item 1: value2
+ """
array_result = instance.get_array()
assert array_result == ['array entry 1', 'array entry 2']
with capture: