Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
250 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
nanobind/make_iterator.h: nb::make_[key,value_]iterator() | ||
This implementation is a port from pybind11 with minimal adjustments. | ||
All rights reserved. Use of this source code is governed by a | ||
BSD-style license that can be found in the LICENSE file. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <nanobind/nanobind.h> | ||
#include <nanobind/stl/pair.h> | ||
|
||
NAMESPACE_BEGIN(NB_NAMESPACE) | ||
NAMESPACE_BEGIN(detail) | ||
|
||
/* There are a large number of apparently unused template arguments because | ||
each combination requires a separate nb::class_ registration. */ | ||
template <typename Access, rv_policy Policy, typename Iterator, | ||
typename Sentinel, typename ValueType, typename... Extra> | ||
struct iterator_state { | ||
Iterator it; | ||
Sentinel end; | ||
bool first_or_done; | ||
}; | ||
|
||
// Note: these helpers take the iterator by non-const reference because some | ||
// iterators in the wild can't be dereferenced when const. | ||
template <typename Iterator> struct iterator_access { | ||
using result_type = decltype(*std::declval<Iterator &>()); | ||
result_type operator()(Iterator &it) const { return *it; } | ||
}; | ||
|
||
template <typename Iterator> struct iterator_key_access { | ||
using result_type = decltype((*std::declval<Iterator &>()).first); | ||
result_type operator()(Iterator &it) const { return (*it).first; } | ||
}; | ||
|
||
template <typename Iterator> struct iterator_value_access { | ||
using result_type = decltype((*std::declval<Iterator &>()).second); | ||
result_type operator()(Iterator &it) const { return (*it).second; } | ||
}; | ||
|
||
template <typename Access, rv_policy Policy, typename Iterator, | ||
typename Sentinel, typename ValueType, typename... Extra> | ||
iterator make_iterator_impl(handle scope, const char *name, | ||
Iterator &&first, Sentinel &&last, | ||
Extra &&...extra) { | ||
using State = iterator_state<Access, Policy, Iterator, Sentinel, ValueType, Extra...>; | ||
|
||
if (!type<State>().is_valid()) { | ||
class_<State>(scope, name) | ||
.def("__iter__", [](handle h) { return h; }) | ||
.def("__next__", | ||
[](State &s) -> ValueType { | ||
if (!s.first_or_done) | ||
++s.it; | ||
else | ||
s.first_or_done = false; | ||
|
||
if (s.it == s.end) { | ||
s.first_or_done = true; | ||
throw stop_iteration(); | ||
} | ||
|
||
return Access()(s.it); | ||
}, | ||
std::forward<Extra>(extra)..., | ||
Policy); | ||
} | ||
|
||
return borrow<iterator>(cast(State{ std::forward<Iterator>(first), | ||
std::forward<Sentinel>(last), true })); | ||
} | ||
|
||
NAMESPACE_END(detail) | ||
|
||
/// Makes a python iterator from a first and past-the-end C++ InputIterator. | ||
template <rv_policy Policy = rv_policy::reference_internal, | ||
typename Iterator, | ||
typename Sentinel, | ||
typename ValueType = typename detail::iterator_access<Iterator>::result_type, | ||
typename... Extra> | ||
iterator make_iterator(handle scope, const char *name, Iterator &&first, Sentinel &&last, Extra &&...extra) { | ||
return detail::make_iterator_impl<detail::iterator_access<Iterator>, Policy, | ||
Iterator, Sentinel, ValueType, Extra...>( | ||
scope, name, std::forward<Iterator>(first), | ||
std::forward<Sentinel>(last), std::forward<Extra>(extra)...); | ||
} | ||
|
||
/// Makes an iterator over the keys (`.first`) of a iterator over pairs from a | ||
/// first and past-the-end InputIterator. | ||
template <rv_policy Policy = rv_policy::reference_internal, typename Iterator, | ||
typename Sentinel, | ||
typename KeyType = | ||
typename detail::iterator_key_access<Iterator>::result_type, | ||
typename... Extra> | ||
iterator make_key_iterator(handle scope, const char *name, Iterator &&first, | ||
Sentinel &&last, Extra &&...extra) { | ||
return detail::make_iterator_impl<detail::iterator_key_access<Iterator>, | ||
Policy, Iterator, Sentinel, KeyType, | ||
Extra...>( | ||
scope, name, std::forward<Iterator>(first), | ||
std::forward<Sentinel>(last), std::forward<Extra>(extra)...); | ||
} | ||
|
||
/// Makes an iterator over the values (`.second`) of a iterator over pairs from a | ||
/// first and past-the-end InputIterator. | ||
template <rv_policy Policy = rv_policy::reference_internal, | ||
typename Iterator, | ||
typename Sentinel, | ||
typename ValueType = typename detail::iterator_value_access<Iterator>::result_type, | ||
typename... Extra> | ||
iterator make_value_iterator(handle scope, const char *name, Iterator &&first, Sentinel &&last, Extra &&...extra) { | ||
return detail::make_iterator_impl<detail::iterator_value_access<Iterator>, | ||
Policy, Iterator, Sentinel, ValueType, | ||
Extra...>( | ||
scope, name, std::forward<Iterator>(first), | ||
std::forward<Sentinel>(last), std::forward<Extra>(extra)...); | ||
} | ||
|
||
/// Makes an iterator over values of a container supporting `std::begin()`/`std::end()` | ||
template <rv_policy Policy = rv_policy::reference_internal, | ||
typename Type, | ||
typename... Extra> | ||
iterator make_iterator(handle scope, const char *name, Type &value, Extra &&...extra) { | ||
return make_iterator<Policy>(scope, name, std::begin(value), | ||
std::end(value), | ||
std::forward<Extra>(extra)...); | ||
} | ||
|
||
NAMESPACE_END(NB_NAMESPACE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include <nanobind/make_iterator.h> | ||
#include <nanobind/stl/unordered_map.h> | ||
#include <nanobind/stl/string.h> | ||
|
||
namespace nb = nanobind; | ||
|
||
NB_MODULE(test_make_iterator_ext, m) { | ||
struct StringMap { | ||
std::unordered_map<std::string, std::string> map; | ||
decltype(map.cbegin()) begin() const { return map.cbegin(); } | ||
decltype(map.cend()) end() const { return map.cend(); } | ||
}; | ||
|
||
nb::class_<StringMap>(m, "StringMap") | ||
.def(nb::init<>()) | ||
.def(nb::init<std::unordered_map<std::string, std::string>>()) | ||
.def("__iter__", | ||
[](const StringMap &map) { | ||
return nb::make_key_iterator(nb::type<StringMap>(), | ||
"key_iterator", | ||
map.begin(), | ||
map.end()); | ||
}, nb::keep_alive<0, 1>()) | ||
.def("items", | ||
[](const StringMap &map) { | ||
return nb::make_iterator(nb::type<StringMap>(), | ||
"item_iterator", | ||
map.begin(), | ||
map.end()); | ||
}, nb::keep_alive<0, 1>()) | ||
.def("values", [](const StringMap &map) { | ||
return nb::make_value_iterator(nb::type<StringMap>(), | ||
"value_iterator", | ||
map.begin(), | ||
map.end()); | ||
}, nb::keep_alive<0, 1>()); | ||
|
||
nb::handle mod = m; | ||
m.def("iterator_passthrough", [mod](nb::iterator s) -> nb::iterator { | ||
return nb::make_iterator(mod, "pt_iterator", std::begin(s), std::end(s)); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import test_make_iterator_ext as t | ||
import pytest | ||
|
||
data = [ | ||
{}, | ||
{ 'a' : 'b' }, | ||
{ str(i) : chr(i) for i in range(1000) } | ||
] | ||
|
||
|
||
def test01_key_iterator(): | ||
for d in data: | ||
m = t.StringMap(d) | ||
assert sorted(list(m)) == sorted(list(d)) | ||
|
||
|
||
def test02_value_iterator(): | ||
types = [] | ||
for d in data: | ||
m = t.StringMap(d) | ||
types.append(type(m.values())) | ||
assert sorted(list(m.values())) == sorted(list(d.values())) | ||
assert types[0] is types[1] and types[1] is types[2] | ||
|
||
|
||
def test03_items_iterator(): | ||
for d in data: | ||
m = t.StringMap(d) | ||
assert sorted(list(m.items())) == sorted(list(d.items())) | ||
|
||
|
||
def test04_passthrough_iterator(): | ||
for d in data: | ||
m = t.StringMap(d) | ||
assert list(t.iterator_passthrough(m.values())) == list(m.values()) |