Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: removing typing and duplicate class_ for KeysView/ValuesView/ItemsView. #4985

Merged
merged 2 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 25 additions & 54 deletions include/pybind11/stl_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,49 +645,50 @@ auto map_if_insertion_operator(Class_ &cl, std::string const &name)
"Return the canonical string representation of this map.");
}

template <typename KeyType>
struct keys_view {
virtual size_t len() = 0;
virtual iterator iter() = 0;
virtual bool contains(const KeyType &k) = 0;
virtual bool contains(const object &k) = 0;
virtual bool contains(const handle &k) = 0;
virtual ~keys_view() = default;
};

template <typename MappedType>
struct values_view {
virtual size_t len() = 0;
virtual iterator iter() = 0;
virtual ~values_view() = default;
};

template <typename KeyType, typename MappedType>
struct items_view {
virtual size_t len() = 0;
virtual iterator iter() = 0;
virtual ~items_view() = default;
};

template <typename Map, typename KeysView>
struct KeysViewImpl : public KeysView {
template <typename Map>
struct KeysViewImpl : public detail::keys_view {
explicit KeysViewImpl(Map &map) : map(map) {}
size_t len() override { return map.size(); }
iterator iter() override { return make_key_iterator(map.begin(), map.end()); }
bool contains(const typename Map::key_type &k) override { return map.find(k) != map.end(); }
bool contains(const object &) override { return false; }
bool contains(const handle &k) override {
try {
return map.find(k.template cast<typename Map::key_type>()) != map.end();
} catch (const cast_error &) {
return false;
}
}
Map &map;
};

template <typename Map, typename ValuesView>
struct ValuesViewImpl : public ValuesView {
template <typename Map>
struct ValuesViewImpl : public detail::values_view {
explicit ValuesViewImpl(Map &map) : map(map) {}
size_t len() override { return map.size(); }
iterator iter() override { return make_value_iterator(map.begin(), map.end()); }
Map &map;
};

template <typename Map, typename ItemsView>
struct ItemsViewImpl : public ItemsView {
template <typename Map>
struct ItemsViewImpl : public detail::items_view {
explicit ItemsViewImpl(Map &map) : map(map) {}
size_t len() override { return map.size(); }
iterator iter() override { return make_iterator(map.begin(), map.end()); }
Expand All @@ -700,11 +701,9 @@ template <typename Map, typename holder_type = std::unique_ptr<Map>, typename...
class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&...args) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
using StrippedKeyType = detail::remove_cvref_t<KeyType>;
using StrippedMappedType = detail::remove_cvref_t<MappedType>;
using KeysView = detail::keys_view<StrippedKeyType>;
using ValuesView = detail::values_view<StrippedMappedType>;
using ItemsView = detail::items_view<StrippedKeyType, StrippedMappedType>;
using KeysView = detail::keys_view;
using ValuesView = detail::values_view;
using ItemsView = detail::items_view;
using Class_ = class_<Map, holder_type>;

// If either type is a non-module-local bound type then make the map binding non-local as well;
Expand All @@ -718,39 +717,20 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
}

Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
static constexpr auto key_type_descr = detail::make_caster<KeyType>::name;
static constexpr auto mapped_type_descr = detail::make_caster<MappedType>::name;
std::string key_type_name(key_type_descr.text), mapped_type_name(mapped_type_descr.text);

// If key type isn't properly wrapped, fall back to C++ names
if (key_type_name == "%") {
key_type_name = detail::type_info_description(typeid(KeyType));
}
// Similarly for value type:
if (mapped_type_name == "%") {
mapped_type_name = detail::type_info_description(typeid(MappedType));
}

// Wrap KeysView[KeyType] if it wasn't already wrapped
// Wrap KeysView if it wasn't already wrapped
if (!detail::get_type_info(typeid(KeysView))) {
class_<KeysView> keys_view(
scope, ("KeysView[" + key_type_name + "]").c_str(), pybind11::module_local(local));
class_<KeysView> keys_view(scope, "KeysView", pybind11::module_local(local));
keys_view.def("__len__", &KeysView::len);
keys_view.def("__iter__",
&KeysView::iter,
keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */
);
keys_view.def("__contains__",
static_cast<bool (KeysView::*)(const KeyType &)>(&KeysView::contains));
// Fallback for when the object is not of the key type
keys_view.def("__contains__",
static_cast<bool (KeysView::*)(const object &)>(&KeysView::contains));
keys_view.def("__contains__", &KeysView::contains);
}
// Similarly for ValuesView:
if (!detail::get_type_info(typeid(ValuesView))) {
class_<ValuesView> values_view(scope,
("ValuesView[" + mapped_type_name + "]").c_str(),
pybind11::module_local(local));
class_<ValuesView> values_view(scope, "ValuesView", pybind11::module_local(local));
values_view.def("__len__", &ValuesView::len);
values_view.def("__iter__",
&ValuesView::iter,
Expand All @@ -759,10 +739,7 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&
}
// Similarly for ItemsView:
if (!detail::get_type_info(typeid(ItemsView))) {
class_<ItemsView> items_view(
scope,
("ItemsView[" + key_type_name + ", ").append(mapped_type_name + "]").c_str(),
pybind11::module_local(local));
class_<ItemsView> items_view(scope, "ItemsView", pybind11::module_local(local));
items_view.def("__len__", &ItemsView::len);
items_view.def("__iter__",
&ItemsView::iter,
Expand All @@ -788,25 +765,19 @@ class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args &&

cl.def(
"keys",
[](Map &m) {
return std::unique_ptr<KeysView>(new detail::KeysViewImpl<Map, KeysView>(m));
},
[](Map &m) { return std::unique_ptr<KeysView>(new detail::KeysViewImpl<Map>(m)); },
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
);

cl.def(
"values",
[](Map &m) {
return std::unique_ptr<ValuesView>(new detail::ValuesViewImpl<Map, ValuesView>(m));
},
[](Map &m) { return std::unique_ptr<ValuesView>(new detail::ValuesViewImpl<Map>(m)); },
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
);

cl.def(
"items",
[](Map &m) {
return std::unique_ptr<ItemsView>(new detail::ItemsViewImpl<Map, ItemsView>(m));
},
[](Map &m) { return std::unique_ptr<ItemsView>(new detail::ItemsViewImpl<Map>(m)); },
keep_alive<0, 1>() /* Essential: keep map alive while view exists */
);

Expand Down
10 changes: 10 additions & 0 deletions tests/test_stl_binders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ TEST_SUBMODULE(stl_binders, m) {
py::bind_map<std::unordered_map<std::string, double const>>(m,
"UnorderedMapStringDoubleConst");

// test_map_view_types
py::bind_map<std::map<std::string, float>>(m, "MapStringFloat");
py::bind_map<std::unordered_map<std::string, float>>(m, "UnorderedMapStringFloat");

py::bind_map<std::map<std::pair<double, int>, int32_t>>(m, "MapPairDoubleIntInt32");
py::bind_map<std::map<std::pair<double, int>, int64_t>>(m, "MapPairDoubleIntInt64");

py::bind_map<std::map<int, py::object>>(m, "MapIntObject");
py::bind_map<std::map<std::string, py::object>>(m, "MapStringObject");

py::class_<E_nc>(m, "ENC").def(py::init<int>()).def_readwrite("value", &E_nc::value);

// test_noncopyable_containers
Expand Down
30 changes: 27 additions & 3 deletions tests/test_stl_binders.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ def test_map_view_types():
map_string_double_const = m.MapStringDoubleConst()
unordered_map_string_double_const = m.UnorderedMapStringDoubleConst()

assert map_string_double.keys().__class__.__name__ == "KeysView[str]"
assert map_string_double.values().__class__.__name__ == "ValuesView[float]"
assert map_string_double.items().__class__.__name__ == "ItemsView[str, float]"
assert map_string_double.keys().__class__.__name__ == "KeysView"
assert map_string_double.values().__class__.__name__ == "ValuesView"
assert map_string_double.items().__class__.__name__ == "ItemsView"

keys_type = type(map_string_double.keys())
assert type(unordered_map_string_double.keys()) is keys_type
Expand All @@ -336,6 +336,30 @@ def test_map_view_types():
assert type(map_string_double_const.items()) is items_type
assert type(unordered_map_string_double_const.items()) is items_type

map_string_float = m.MapStringFloat()
unordered_map_string_float = m.UnorderedMapStringFloat()

assert type(map_string_float.keys()) is keys_type
assert type(unordered_map_string_float.keys()) is keys_type
assert type(map_string_float.values()) is values_type
assert type(unordered_map_string_float.values()) is values_type
assert type(map_string_float.items()) is items_type
assert type(unordered_map_string_float.items()) is items_type

map_pair_double_int_int32 = m.MapPairDoubleIntInt32()
map_pair_double_int_int64 = m.MapPairDoubleIntInt64()

assert type(map_pair_double_int_int32.values()) is values_type
assert type(map_pair_double_int_int64.values()) is values_type

map_int_object = m.MapIntObject()
map_string_object = m.MapStringObject()

assert type(map_int_object.keys()) is keys_type
assert type(map_string_object.keys()) is keys_type
assert type(map_int_object.items()) is items_type
assert type(map_string_object.items()) is items_type


def test_recursive_vector():
recursive_vector = m.RecursiveVector()
Expand Down