Skip to content

Commit

Permalink
fixes issue reported in #318 also in other type casters
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Oct 9, 2023
1 parent d1ad3b9 commit 5f25ae0
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 29 deletions.
5 changes: 4 additions & 1 deletion include/nanobind/nb_cast.h
Expand Up @@ -34,7 +34,10 @@ enum cast_flags : uint8_t {
convert = (1 << 0),

// Passed to the 'self' argument in a constructor call (__init__)
construct = (1 << 1)
construct = (1 << 1),

// Don't accept 'None' Python objects in the base class caster
none_disallowed = (1 << 2),
};

/**
Expand Down
7 changes: 5 additions & 2 deletions include/nanobind/stl/detail/nb_array.h
Expand Up @@ -5,8 +5,8 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Value_, typename Entry, size_t Size> struct array_caster {
NB_TYPE_CASTER(Value_, const_name(NB_TYPING_LIST "[") +
template <typename Array, typename Entry, size_t Size> struct array_caster {
NB_TYPE_CASTER(Array, const_name(NB_TYPING_LIST "[") +
make_caster<Entry>::Name + const_name("]"));

using Caster = make_caster<Entry>;
Expand All @@ -20,6 +20,9 @@ template <typename Value_, typename Entry, size_t Size> struct array_caster {
Caster caster;
bool success = o != nullptr;

if (is_base_caster_v<Caster> && !std::is_pointer_v<Entry>)
flags |= (uint8_t) cast_flags::none_disallowed;

if (success) {
for (size_t i = 0; i < Size; ++i) {
if (!caster.from_python(o[i], flags, cleanup)) {
Expand Down
29 changes: 18 additions & 11 deletions include/nanobind/stl/detail/nb_dict.h
Expand Up @@ -14,13 +14,13 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Value_, typename Key, typename Element> struct dict_caster {
NB_TYPE_CASTER(Value_, const_name(NB_TYPING_DICT "[") + make_caster<Key>::Name +
const_name(", ") + make_caster<Element>::Name +
template <typename Dict, typename Key, typename Val> struct dict_caster {
NB_TYPE_CASTER(Dict, const_name(NB_TYPING_DICT "[") + make_caster<Key>::Name +
const_name(", ") + make_caster<Val>::Name +
const_name("]"));

using KeyCaster = make_caster<Key>;
using ElementCaster = make_caster<Element>;
using ValCaster = make_caster<Val>;

bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
value.clear();
Expand All @@ -32,27 +32,34 @@ template <typename Value_, typename Key, typename Element> struct dict_caster {
}

Py_ssize_t size = NB_LIST_GET_SIZE(items);
bool success = (size >= 0);
bool success = size >= 0;

uint8_t flags_key = flags, flags_val = flags;

if (is_base_caster_v<KeyCaster> && !std::is_pointer_v<Key>)
flags_key |= (uint8_t) cast_flags::none_disallowed;
if (is_base_caster_v<ValCaster> && !std::is_pointer_v<Val>)
flags_val |= (uint8_t) cast_flags::none_disallowed;

KeyCaster key_caster;
ElementCaster element_caster;
ValCaster val_caster;
for (Py_ssize_t i = 0; i < size; ++i) {
PyObject *item = NB_LIST_GET_ITEM(items, i);
PyObject *key = NB_TUPLE_GET_ITEM(item, 0);
PyObject *element = NB_TUPLE_GET_ITEM(item, 1);
PyObject *val = NB_TUPLE_GET_ITEM(item, 1);

if (!key_caster.from_python(key, flags, cleanup)) {
if (!key_caster.from_python(key, flags_key, cleanup)) {
success = false;
break;
}

if (!element_caster.from_python(element, flags, cleanup)) {
if (!val_caster.from_python(val, flags_val, cleanup)) {
success = false;
break;
}

value.emplace(key_caster.operator cast_t<Key>(),
element_caster.operator cast_t<Element>());
val_caster.operator cast_t<Val>());
}

Py_DECREF(items);
Expand All @@ -68,7 +75,7 @@ template <typename Value_, typename Key, typename Element> struct dict_caster {
for (auto &item : src) {
object k = steal(KeyCaster::from_cpp(
forward_like<T>(item.first), policy, cleanup));
object e = steal(ElementCaster::from_cpp(
object e = steal(ValCaster::from_cpp(
forward_like<T>(item.second), policy, cleanup));

if (!k.is_valid() || !e.is_valid() ||
Expand Down
9 changes: 6 additions & 3 deletions include/nanobind/stl/detail/nb_list.h
Expand Up @@ -14,8 +14,8 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Value_, typename Entry> struct list_caster {
NB_TYPE_CASTER(Value_, const_name(NB_TYPING_LIST "[") +
template <typename List, typename Entry> struct list_caster {
NB_TYPE_CASTER(List, const_name(NB_TYPING_LIST "[") +
make_caster<Entry>::Name + const_name("]"));

using Caster = make_caster<Entry>;
Expand All @@ -32,12 +32,15 @@ template <typename Value_, typename Entry> struct list_caster {

value.clear();

if constexpr (is_detected_v<has_reserve, Value_>)
if constexpr (is_detected_v<has_reserve, List>)
value.reserve(size);

Caster caster;
bool success = o != nullptr;

if (is_base_caster_v<Caster> && !std::is_pointer_v<Entry>)
flags |= (uint8_t) cast_flags::none_disallowed;

for (size_t i = 0; i < size; ++i) {
if (!caster.from_python(o[i], flags, cleanup)) {
success = false;
Expand Down
15 changes: 9 additions & 6 deletions include/nanobind/stl/detail/nb_set.h
Expand Up @@ -14,24 +14,27 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Value_, typename Key> struct set_caster {
NB_TYPE_CASTER(Value_, const_name(NB_TYPING_SET "[") + make_caster<Key>::Name + const_name("]"));
template <typename Set, typename Key> struct set_caster {
NB_TYPE_CASTER(Set, const_name(NB_TYPING_SET "[") + make_caster<Key>::Name + const_name("]"));

using KeyCaster = make_caster<Key>;
using Caster = make_caster<Key>;

bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
value.clear();

PyObject* iter = obj_iter(src.ptr());
if (iter == nullptr) {
if (!iter) {
PyErr_Clear();
return false;
}

bool success = true;
KeyCaster key_caster;
Caster key_caster;
PyObject *key;

if (is_base_caster_v<Caster> && !std::is_pointer_v<Key>)
flags |= (uint8_t) cast_flags::none_disallowed;

while ((key = PyIter_Next(iter)) != nullptr) {
success &= key_caster.from_python(key, flags, cleanup);
Py_DECREF(key);
Expand Down Expand Up @@ -59,7 +62,7 @@ template <typename Value_, typename Key> struct set_caster {
if (ret.is_valid()) {
for (auto& key : src) {
object k = steal(
KeyCaster::from_cpp(forward_like<T>(key), policy, cleanup));
Caster::from_cpp(forward_like<T>(key), policy, cleanup));

if (!k.is_valid() || PySet_Add(ret.ptr(), k.ptr()) != 0) {
ret.reset();
Expand Down
4 changes: 3 additions & 1 deletion include/nanobind/stl/optional.h
Expand Up @@ -29,8 +29,10 @@ struct type_caster<std::optional<T>> {
type_caster() : value(std::nullopt) { }

bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
if (src.is_none())
if (src.is_none()) {
value = std::nullopt;
return true;
}

Caster caster;
if (!caster.from_python(src, flags, cleanup))
Expand Down
6 changes: 2 additions & 4 deletions include/nanobind/stl/variant.h
Expand Up @@ -52,10 +52,8 @@ template <typename... Ts> struct type_caster<std::variant<Ts...>> {
"type caster was registered to intercept this particular "
"type, which is not allowed.");

if constexpr (!std::is_pointer_v<T> && is_base_caster_v<CasterT>) {
if (src.is_none())
return false;
}
if (is_base_caster_v<CasterT> && !std::is_pointer_v<T>)
flags |= (uint8_t) cast_flags::none_disallowed;

CasterT caster;

Expand Down
2 changes: 1 addition & 1 deletion src/nb_type.cpp
Expand Up @@ -1013,7 +1013,7 @@ bool nb_type_get(const std::type_info *cpp_type, PyObject *src, uint8_t flags,
// Convert None -> nullptr
if (src == Py_None) {
*out = nullptr;
return true;
return (flags & (uint8_t) cast_flags::none_disallowed) == 0;
}

PyTypeObject *src_type = Py_TYPE(src);
Expand Down
6 changes: 6 additions & 0 deletions tests/test_stl.py
Expand Up @@ -824,3 +824,9 @@ def test69_complex_array():
def test70_vec_char():
assert isinstance(t.vector_str("123"), str)
assert isinstance(t.vector_str(["123", "345"]), list)

def test71_null_input():
with pytest.raises(TypeError):
t.vec_movable_in_value([None])
with pytest.raises(TypeError):
t.map_copyable_in_value({'a': None})

0 comments on commit 5f25ae0

Please sign in to comment.