diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 55f2870cadb4..84d50b7086c6 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import NamedTuple + from mypy.argmap import map_actuals_to_formals from mypy.fixup import TypeFixer from mypy.nodes import ( @@ -16,9 +18,11 @@ JsonDict, NameExpr, Node, + OverloadedFuncDef, PassStmt, RefExpr, SymbolTableNode, + TypeInfo, Var, ) from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface @@ -209,24 +213,99 @@ def add_method( ) +class MethodSpec(NamedTuple): + """Represents a method signature to be added, except for `name`.""" + + args: list[Argument] + return_type: Type + self_type: Type | None = None + tvar_defs: list[TypeVarType] | None = None + + def add_method_to_class( api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, cls: ClassDef, name: str, + # MethodSpec items kept for backward compatibility: args: list[Argument], return_type: Type, self_type: Type | None = None, - tvar_def: TypeVarType | None = None, + tvar_def: list[TypeVarType] | TypeVarType | None = None, is_classmethod: bool = False, is_staticmethod: bool = False, -) -> None: +) -> FuncDef | Decorator: """Adds a new method to a class definition.""" + _prepare_class_namespace(cls, name) - assert not ( - is_classmethod is True and is_staticmethod is True - ), "Can't add a new method that's both staticmethod and classmethod." + if tvar_def is not None and not isinstance(tvar_def, list): + tvar_def = [tvar_def] + + func, sym = _add_method_by_spec( + api, + cls.info, + name, + MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def), + is_classmethod=is_classmethod, + is_staticmethod=is_staticmethod, + ) + cls.info.names[name] = sym + cls.info.defn.defs.body.append(func) + return func + +def add_overloaded_method_to_class( + api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, + cls: ClassDef, + name: str, + items: list[MethodSpec], + is_classmethod: bool = False, + is_staticmethod: bool = False, +) -> OverloadedFuncDef: + """Adds a new overloaded method to a class definition.""" + assert len(items) >= 2, "Overloads must contain at least two cases" + + # Save old definition, if it exists. + _prepare_class_namespace(cls, name) + + # Create function bodies for each passed method spec. + funcs: list[Decorator | FuncDef] = [] + for item in items: + func, _sym = _add_method_by_spec( + api, + cls.info, + name=name, + spec=item, + is_classmethod=is_classmethod, + is_staticmethod=is_staticmethod, + ) + if isinstance(func, FuncDef): + var = Var(func.name, func.type) + var.set_line(func.line) + func.is_decorated = True + func.deco_line = func.line + + deco = Decorator(func, [], var) + else: + deco = func + deco.is_overload = True + funcs.append(deco) + + # Create the final OverloadedFuncDef node: + overload_def = OverloadedFuncDef(funcs) + overload_def.info = cls.info + overload_def.is_class = is_classmethod + overload_def.is_static = is_staticmethod + sym = SymbolTableNode(MDEF, overload_def) + sym.plugin_generated = True + + cls.info.names[name] = sym + cls.info.defn.defs.body.append(overload_def) + return overload_def + + +def _prepare_class_namespace(cls: ClassDef, name: str) -> None: info = cls.info + assert info # First remove any previously generated methods with the same name # to avoid clashes and problems in the semantic analyzer. @@ -235,6 +314,29 @@ def add_method_to_class( if sym.plugin_generated and isinstance(sym.node, FuncDef): cls.defs.body.remove(sym.node) + # NOTE: we would like the plugin generated node to dominate, but we still + # need to keep any existing definitions so they get semantically analyzed. + if name in info.names: + # Get a nice unique name instead. + r_name = get_unique_redefinition_name(name, info.names) + info.names[r_name] = info.names[name] + + +def _add_method_by_spec( + api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, + info: TypeInfo, + name: str, + spec: MethodSpec, + *, + is_classmethod: bool, + is_staticmethod: bool, +) -> tuple[FuncDef | Decorator, SymbolTableNode]: + args, return_type, self_type, tvar_defs = spec + + assert not ( + is_classmethod is True and is_staticmethod is True + ), "Can't add a new method that's both staticmethod and classmethod." + if isinstance(api, SemanticAnalyzerPluginInterface): function_type = api.named_type("builtins.function") else: @@ -258,8 +360,8 @@ def add_method_to_class( arg_kinds.append(arg.kind) signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) - if tvar_def: - signature.variables = [tvar_def] + if tvar_defs: + signature.variables = tvar_defs func = FuncDef(name, args, Block([PassStmt()])) func.info = info @@ -269,13 +371,6 @@ def add_method_to_class( func._fullname = info.fullname + "." + name func.line = info.line - # NOTE: we would like the plugin generated node to dominate, but we still - # need to keep any existing definitions so they get semantically analyzed. - if name in info.names: - # Get a nice unique name instead. - r_name = get_unique_redefinition_name(name, info.names) - info.names[r_name] = info.names[name] - # Add decorator for is_staticmethod. It's unnecessary for is_classmethod. if is_staticmethod: func.is_decorated = True @@ -286,12 +381,12 @@ def add_method_to_class( dec = Decorator(func, [], v) dec.line = info.line sym = SymbolTableNode(MDEF, dec) - else: - sym = SymbolTableNode(MDEF, func) - sym.plugin_generated = True - info.names[name] = sym + sym.plugin_generated = True + return dec, sym - info.defn.defs.body.append(func) + sym = SymbolTableNode(MDEF, func) + sym.plugin_generated = True + return func, sym def add_attribute_to_class( @@ -304,7 +399,7 @@ def add_attribute_to_class( override_allow_incompatible: bool = False, fullname: str | None = None, is_classvar: bool = False, -) -> None: +) -> Var: """ Adds a new attribute to a class definition. This currently only generates the symbol table entry and no corresponding AssignmentStatement @@ -335,6 +430,7 @@ def add_attribute_to_class( info.names[name] = SymbolTableNode( MDEF, node, plugin_generated=True, no_serialize=no_serialize ) + return node def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type: diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index 9a0668f98c21..22374d09cf9f 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -1011,13 +1011,35 @@ class BaseAddMethod: pass class MyClass(BaseAddMethod): pass -my_class = MyClass() reveal_type(MyClass.foo_classmethod) # N: Revealed type is "def ()" reveal_type(MyClass.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str" + +my_class = MyClass() +reveal_type(my_class.foo_classmethod) # N: Revealed type is "def ()" +reveal_type(my_class.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str" [file mypy.ini] \[mypy] plugins=/test-data/unit/plugins/add_classmethod.py +[case testAddOverloadedMethodPlugin] +# flags: --config-file tmp/mypy.ini +class AddOverloadedMethod: pass + +class MyClass(AddOverloadedMethod): + pass + +reveal_type(MyClass.method) # N: Revealed type is "Overload(def (self: __main__.MyClass, arg: builtins.int) -> builtins.str, def (self: __main__.MyClass, arg: builtins.str) -> builtins.int)" +reveal_type(MyClass.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +reveal_type(MyClass.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" + +my_class = MyClass() +reveal_type(my_class.method) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +reveal_type(my_class.clsmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +reveal_type(my_class.stmethod) # N: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/add_overloaded_method.py + [case testCustomErrorCodePlugin] # flags: --config-file tmp/mypy.ini --show-error-codes def main() -> int: diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index fcab0545b982..b4cd21aa552c 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -5935,6 +5935,44 @@ tmp/b.py:4: note: Revealed type is "def ()" tmp/b.py:5: note: Revealed type is "def (builtins.int) -> builtins.str" tmp/b.py:6: note: Revealed type is "def ()" tmp/b.py:7: note: Revealed type is "def (builtins.int) -> builtins.str" + +[case testIncrementalAddOverloadedMethodPlugin] +# flags: --config-file tmp/mypy.ini +import b + +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/add_overloaded_method.py + +[file a.py] +class AddOverloadedMethod: pass + +class MyClass(AddOverloadedMethod): + pass + +[file b.py] +import a + +[file b.py.2] +import a + +reveal_type(a.MyClass.method) +reveal_type(a.MyClass.clsmethod) +reveal_type(a.MyClass.stmethod) + +my_class = a.MyClass() +reveal_type(my_class.method) +reveal_type(my_class.clsmethod) +reveal_type(my_class.stmethod) +[rechecked b] +[out2] +tmp/b.py:3: note: Revealed type is "Overload(def (self: a.MyClass, arg: builtins.int) -> builtins.str, def (self: a.MyClass, arg: builtins.str) -> builtins.int)" +tmp/b.py:4: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +tmp/b.py:5: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +tmp/b.py:8: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +tmp/b.py:9: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" +tmp/b.py:10: note: Revealed type is "Overload(def (arg: builtins.int) -> builtins.str, def (arg: builtins.str) -> builtins.int)" + [case testGenericNamedTupleSerialization] import b [file a.py] diff --git a/test-data/unit/deps.test b/test-data/unit/deps.test index c3295b79e4ed..5e77ff1d85e0 100644 --- a/test-data/unit/deps.test +++ b/test-data/unit/deps.test @@ -1387,12 +1387,13 @@ class B(A): -> , m -> -> , m.B.__init__ - -> , m.B.__mypy-replace + -> , m, m.B.__mypy-replace -> -> -> -> m, m.A, m.B -> m + -> m -> m -> m.B -> m @@ -1419,12 +1420,13 @@ class B(A): -> -> , m.B.__init__ -> - -> , m.B.__mypy-replace + -> , m, m.B.__mypy-replace -> -> -> -> m, m.A, m.B -> m + -> m -> m -> m.B -> m diff --git a/test-data/unit/plugins/add_overloaded_method.py b/test-data/unit/plugins/add_overloaded_method.py new file mode 100644 index 000000000000..efda848f790c --- /dev/null +++ b/test-data/unit/plugins/add_overloaded_method.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Callable + +from mypy.nodes import ARG_POS, Argument, Var +from mypy.plugin import ClassDefContext, Plugin +from mypy.plugins.common import MethodSpec, add_overloaded_method_to_class + + +class OverloadedMethodPlugin(Plugin): + def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: + if "AddOverloadedMethod" in fullname: + return add_overloaded_method_hook + return None + + +def add_overloaded_method_hook(ctx: ClassDefContext) -> None: + add_overloaded_method_to_class(ctx.api, ctx.cls, "method", _generate_method_specs(ctx)) + add_overloaded_method_to_class( + ctx.api, ctx.cls, "clsmethod", _generate_method_specs(ctx), is_classmethod=True + ) + add_overloaded_method_to_class( + ctx.api, ctx.cls, "stmethod", _generate_method_specs(ctx), is_staticmethod=True + ) + + +def _generate_method_specs(ctx: ClassDefContext) -> list[MethodSpec]: + return [ + MethodSpec( + args=[Argument(Var("arg"), ctx.api.named_type("builtins.int"), None, ARG_POS)], + return_type=ctx.api.named_type("builtins.str"), + ), + MethodSpec( + args=[Argument(Var("arg"), ctx.api.named_type("builtins.str"), None, ARG_POS)], + return_type=ctx.api.named_type("builtins.int"), + ), + ] + + +def plugin(version: str) -> type[OverloadedMethodPlugin]: + return OverloadedMethodPlugin