diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 4dcbf26236..7b014fed99 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1401,7 +1401,7 @@ struct handle_type_name { }; template <> struct handle_type_name { - static constexpr auto name = io_name("typing.SupportsInt", "int"); + static constexpr auto name = const_name("int"); }; template <> struct handle_type_name { @@ -1413,7 +1413,7 @@ struct handle_type_name { }; template <> struct handle_type_name { - static constexpr auto name = io_name("typing.SupportsFloat", "float"); + static constexpr auto name = const_name("float"); }; template <> struct handle_type_name { @@ -1534,6 +1534,21 @@ struct pyobject_caster { template class type_caster::value>> : public pyobject_caster {}; +template <> +class type_caster : public pyobject_caster { +public: + bool load(handle src, bool /* convert */) { + if (isinstance(src)) { + value = reinterpret_borrow(src); + } else if (isinstance(src)) { + value = float_(reinterpret_borrow(src)); + } else { + return false; + } + return true; + } +}; + // Our conditions for enabling moving are quite restrictive: // At compile time: // - T needs to be a non-const, non-pointer, non-reference type diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 1c136cf0f2..7d5423e549 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -209,6 +209,7 @@ TEST_SUBMODULE(pytypes, m) { m.def("get_tuple_from_iterable", [](const py::iterable &iter) { return py::tuple(iter); }); // test_float m.def("get_float", [] { return py::float_(0.0f); }); + m.def("float_roundtrip", [](py::float_ f) { return f; }); // test_list m.def("list_no_args", []() { return py::list{}; }); m.def("list_ssize_t", []() { return py::list{(py::ssize_t) 0}; }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index a199d72f0a..c1798f924c 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -61,6 +61,13 @@ def test_iterable(doc): def test_float(doc): assert doc(m.get_float) == "get_float() -> float" + assert doc(m.float_roundtrip) == "float_roundtrip(arg0: float) -> float" + f1 = m.float_roundtrip(5.5) + assert isinstance(f1, float) + assert f1 == 5.5 + f2 = m.float_roundtrip(5) + assert isinstance(f2, float) + assert f2 == 5.0 def test_list(capture, doc): @@ -917,7 +924,7 @@ def test_inplace_rshift(a, b): def test_tuple_nonempty_annotations(doc): assert ( doc(m.annotate_tuple_float_str) - == "annotate_tuple_float_str(arg0: tuple[typing.SupportsFloat, str]) -> None" + == "annotate_tuple_float_str(arg0: tuple[float, str]) -> None" ) @@ -930,7 +937,7 @@ def test_tuple_empty_annotations(doc): def test_tuple_variable_length_annotations(doc): assert ( doc(m.annotate_tuple_variable_length) - == "annotate_tuple_variable_length(arg0: tuple[typing.SupportsFloat, ...]) -> None" + == "annotate_tuple_variable_length(arg0: tuple[float, ...]) -> None" ) @@ -989,7 +996,7 @@ def test_type_annotation(doc): def test_union_annotations(doc): assert ( doc(m.annotate_union) - == "annotate_union(arg0: list[str | typing.SupportsInt | object], arg1: str, arg2: typing.SupportsInt, arg3: object) -> list[str | int | object]" + == "annotate_union(arg0: list[str | int | object], arg1: str, arg2: int, arg3: object) -> list[str | int | object]" )