Skip to content
19 changes: 17 additions & 2 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,7 @@ struct handle_type_name<buffer> {
};
template <>
struct handle_type_name<int_> {
static constexpr auto name = io_name("typing.SupportsInt", "int");
static constexpr auto name = const_name("int");
};
template <>
struct handle_type_name<iterable> {
Expand All @@ -1413,7 +1413,7 @@ struct handle_type_name<iterator> {
};
template <>
struct handle_type_name<float_> {
static constexpr auto name = io_name("typing.SupportsFloat", "float");
static constexpr auto name = const_name("float");
};
template <>
struct handle_type_name<function> {
Expand Down Expand Up @@ -1534,6 +1534,21 @@ struct pyobject_caster {
template <typename T>
class type_caster<T, enable_if_t<is_pyobject<T>::value>> : public pyobject_caster<T> {};

template <>
class type_caster<float_> : public pyobject_caster<float_> {
public:
bool load(handle src, bool /* convert */) {
if (isinstance<float_>(src)) {
value = reinterpret_borrow<float_>(src);
} else if (isinstance<int_>(src)) {
value = float_(reinterpret_borrow<int_>(src));
} else {
return false;
}
return true;
}
};

// Our conditions for enabling moving are quite restrictive:
// At compile time:
// - T needs to be a non-const, non-pointer, non-reference type
Expand Down
1 change: 1 addition & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ TEST_SUBMODULE(pytypes, m) {
m.def("get_tuple_from_iterable", [](const py::iterable &iter) { return py::tuple(iter); });
// test_float
m.def("get_float", [] { return py::float_(0.0f); });
m.def("float_roundtrip", [](py::float_ f) { return f; });
// test_list
m.def("list_no_args", []() { return py::list{}; });
m.def("list_ssize_t", []() { return py::list{(py::ssize_t) 0}; });
Expand Down
13 changes: 10 additions & 3 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def test_iterable(doc):

def test_float(doc):
assert doc(m.get_float) == "get_float() -> float"
assert doc(m.float_roundtrip) == "float_roundtrip(arg0: float) -> float"
f1 = m.float_roundtrip(5.5)
assert isinstance(f1, float)
assert f1 == 5.5
f2 = m.float_roundtrip(5)
assert isinstance(f2, float)
assert f2 == 5.0


def test_list(capture, doc):
Expand Down Expand Up @@ -917,7 +924,7 @@ def test_inplace_rshift(a, b):
def test_tuple_nonempty_annotations(doc):
assert (
doc(m.annotate_tuple_float_str)
== "annotate_tuple_float_str(arg0: tuple[typing.SupportsFloat, str]) -> None"
== "annotate_tuple_float_str(arg0: tuple[float, str]) -> None"
)


Expand All @@ -930,7 +937,7 @@ def test_tuple_empty_annotations(doc):
def test_tuple_variable_length_annotations(doc):
assert (
doc(m.annotate_tuple_variable_length)
== "annotate_tuple_variable_length(arg0: tuple[typing.SupportsFloat, ...]) -> None"
== "annotate_tuple_variable_length(arg0: tuple[float, ...]) -> None"
)


Expand Down Expand Up @@ -989,7 +996,7 @@ def test_type_annotation(doc):
def test_union_annotations(doc):
assert (
doc(m.annotate_union)
== "annotate_union(arg0: list[str | typing.SupportsInt | object], arg1: str, arg2: typing.SupportsInt, arg3: object) -> list[str | int | object]"
== "annotate_union(arg0: list[str | int | object], arg1: str, arg2: int, arg3: object) -> list[str | int | object]"
)


Expand Down
Loading