Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support type stub generation for staticmethod #14934

Merged
merged 13 commits into from
Jan 8, 2024
16 changes: 10 additions & 6 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,14 @@ def is_classmethod(self, class_info: ClassInfo, name: str, obj: object) -> bool:
return inspect.ismethod(obj)

def is_staticmethod(self, class_info: ClassInfo | None, name: str, obj: object) -> bool:
if self.is_c_module:
if class_info is None:
return False
elif self.is_c_module:
raw_lookup: Mapping[str, Any] = getattr(class_info.cls, "__dict__") # noqa: B009
raw_value = raw_lookup.get(name, obj)
return isinstance(raw_value, staticmethod)
else:
return class_info is not None and isinstance(
inspect.getattr_static(class_info.cls, name), staticmethod
)
return isinstance(inspect.getattr_static(class_info.cls, name), staticmethod)

@staticmethod
def is_abstract_method(obj: object) -> bool:
Expand Down Expand Up @@ -761,7 +763,7 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
The result lines will be appended to 'output'. If necessary, any
required names will be added to 'imports'.
"""
raw_lookup = getattr(cls, "__dict__") # noqa: B009
raw_lookup: Mapping[str, Any] = getattr(cls, "__dict__") # noqa: B009
items = self.get_members(cls)
if self.resort_members:
items = sorted(items, key=lambda x: method_name_sort_key(x[0]))
Expand Down Expand Up @@ -793,7 +795,9 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
continue
attr = "__init__"
# FIXME: make this nicer
if self.is_classmethod(class_info, attr, value):
if self.is_staticmethod(class_info, attr, value):
class_info.self_var = ""
elif self.is_classmethod(class_info, attr, value):
class_info.self_var = "cls"
else:
class_info.self_var = "self"
Expand Down
15 changes: 15 additions & 0 deletions test-data/pybind11_mypy_demo/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ const Point Point::y_axis = Point(0, 1);
Point::LengthUnit Point::length_unit = Point::LengthUnit::mm;
Point::AngleUnit Point::angle_unit = Point::AngleUnit::radian;

struct Foo
{
static int some_static_method(int a, int b) { return a * 42 + b; }
static int overloaded_static_method(int value) { return value * 42; }
static double overloaded_static_method(double value) { return value * 42; }
};

} // namespace: basics

void bind_basics(py::module& basics) {
Expand Down Expand Up @@ -166,6 +173,14 @@ void bind_basics(py::module& basics) {
.value("radian", Point::AngleUnit::radian)
.value("degree", Point::AngleUnit::degree);

// Static methods
py::class_<Foo> pyFoo(basics, "Foo");

pyFoo
.def_static("some_static_method", &Foo::some_static_method, R"#(None)#", py::arg("a"), py::arg("b"))
.def_static("overloaded_static_method", py::overload_cast<int>(&Foo::overloaded_static_method), py::arg("value"))
.def_static("overloaded_static_method", py::overload_cast<double>(&Foo::overloaded_static_method), py::arg("value"));

// Module-level attributes
basics.attr("PI") = std::acos(-1);
basics.attr("__version__") = "0.0.1";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ from typing import ClassVar, List, overload
PI: float
__version__: str

class Foo:
def __init__(self, *args, **kwargs) -> None:
"""Initialize self. See help(type(self)) for accurate signature."""
@overload
@staticmethod
def overloaded_static_method(value: int) -> int:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.

1. overloaded_static_method(value: int) -> int

2. overloaded_static_method(value: float) -> float
"""
@overload
@staticmethod
def overloaded_static_method(value: float) -> float:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.

1. overloaded_static_method(value: int) -> int

2. overloaded_static_method(value: float) -> float
"""
@staticmethod
def some_static_method(a: int, b: int) -> int:
"""some_static_method(a: int, b: int) -> int

None
"""

class Point:
class AngleUnit:
__members__: ClassVar[dict] = ... # read-only
Expand Down
11 changes: 11 additions & 0 deletions test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ from typing import ClassVar, List, overload
PI: float
__version__: str

class Foo:
def __init__(self, *args, **kwargs) -> None: ...
@overload
@staticmethod
def overloaded_static_method(value: int) -> int: ...
@overload
@staticmethod
def overloaded_static_method(value: float) -> float: ...
@staticmethod
def some_static_method(a: int, b: int) -> int: ...

class Point:
class AngleUnit:
__members__: ClassVar[dict] = ... # read-only
Expand Down