Skip to content

Commit

Permalink
minor tweaks to nb::cast<T> interface
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Apr 11, 2023
1 parent b0bd2ba commit 9ae3205
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
22 changes: 8 additions & 14 deletions include/nanobind/nb_cast.h
Expand Up @@ -338,27 +338,21 @@ T cast(const detail::api<Derived> &value, bool convert = true) {
if constexpr (std::is_same_v<T, void>) {
return;
} else {
using Ti = detail::intrinsic_t<T>;
using Caster = detail::make_caster<Ti>;
using Caster = detail::make_caster<T>;
using Output = typename Caster::template Cast<T>;

static_assert(
!(std::is_reference_v<T> || std::is_pointer_v<T>) || Caster::IsClass ||
std::is_same_v<const char *, T>,
"nanobind::cast(): cannot return a reference to a temporary.");

Caster caster;
if (!caster.from_python(value.derived().ptr(),
convert ? (uint8_t) detail::cast_flags::convert
: (uint8_t) 0, nullptr))
detail::raise_cast_error();

if constexpr (std::is_same_v<T, const char *>) {
return caster.operator const char *();
} else {
static_assert(
!(std::is_reference_v<T> || std::is_pointer_v<T>) || Caster::IsClass,
"nanobind::cast(): cannot return a reference to a temporary.");

if constexpr (detail::is_pointer_v<T>)
return caster.operator Ti*();
else
return caster.operator Ti&();
}
return caster.operator Output();
}
}

Expand Down
8 changes: 8 additions & 0 deletions tests/test_functions.cpp
Expand Up @@ -212,4 +212,12 @@ NB_MODULE(test_functions_ext, m) {

return nb::cpp_function(callback);
});

m.def("test_cast_char", [](nb::handle h) {
return nb::cast<char>(h);
});

m.def("test_cast_str", [](nb::handle h) {
return nb::cast<const char *>(h);
});
}
13 changes: 13 additions & 0 deletions tests/test_functions.py
Expand Up @@ -307,3 +307,16 @@ def test34_module_docstring():
def test35_return_capture():
x = t.test_35()
assert x() == 'Test Foo'

def test36_test_char():
assert t.test_cast_char('c') == 'c'
with pytest.raises(TypeError):
assert t.test_cast_char('abc')
with pytest.raises(RuntimeError):
assert t.test_cast_char(123)

def test37_test_str():
assert t.test_cast_str('c') == 'c'
assert t.test_cast_str('abc') == 'abc'
with pytest.raises(RuntimeError):
assert t.test_cast_str(123)

0 comments on commit 9ae3205

Please sign in to comment.