From 4020d10607ccd992c28af9424e3dec3aec79f3b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Sun, 1 Aug 2021 21:43:27 +0800 Subject: [PATCH 1/3] Fix apply_dynamic_class_hook does not work in some scenarios --- mypy/semanal.py | 51 ++++++++++++++-------- test-data/unit/check-custom-plugin.test | 56 ++++++++++++++++++------- 2 files changed, 74 insertions(+), 33 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 1c39aa0de256..0fb074786f28 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2277,25 +2277,40 @@ def analyze_lvalues(self, s: AssignmentStmt) -> None: is_final=s.is_final_def) def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None: - if len(s.lvalues) > 1: - return - lval = s.lvalues[0] - if not isinstance(lval, NameExpr) or not isinstance(s.rvalue, CallExpr): + if not isinstance(s.rvalue, CallExpr): return - call = s.rvalue - fname = None - if isinstance(call.callee, RefExpr): - fname = call.callee.fullname - # check if method call - if fname is None and isinstance(call.callee, MemberExpr): - callee_expr = call.callee.expr - if isinstance(callee_expr, RefExpr) and callee_expr.fullname: - method_name = call.callee.name - fname = callee_expr.fullname + '.' + method_name - if fname: - hook = self.plugin.get_dynamic_class_hook(fname) - if hook: - hook(DynamicClassDefContext(call, lval.name, self)) + + from .traverser import TraverserVisitor + + class CallExprVisitor(TraverserVisitor): + analyzer: SemanticAnalyzer + + def __init__(self, analyzer: SemanticAnalyzer) -> None: + super().__init__() + self.analyzer = analyzer + + def visit_call_expr(self, call: CallExpr) -> None: + fname = None + if isinstance(call.callee, RefExpr): + fname = call.callee.fullname + # check if method call + if fname is None and isinstance(call.callee, MemberExpr): + callee_expr = call.callee.expr + if isinstance(callee_expr, RefExpr) and callee_expr.fullname: + method_name = call.callee.name + fname = callee_expr.fullname + '.' + method_name + if fname: + for lval in s.lvalues: + if not isinstance(lval, NameExpr): + continue + hook = self.analyzer.plugin.get_dynamic_class_hook(fname) + if hook: + hook(DynamicClassDefContext(call, lval.name, self.analyzer)) + + super().visit_call_expr(call) + + visitor = CallExprVisitor(analyzer=self) + s.accept(visitor) def unwrap_final(self, s: AssignmentStmt) -> bool: """Strip Final[...] if present in an assignment. diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index 7c85881363d6..8f02a13cfc67 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -543,26 +543,23 @@ class Instr(Generic[T]): ... \[mypy] plugins=/test-data/unit/plugins/dyn_class.py -[case testDynamicClassPluginNegatives] +[case testDynamicClassPluginChainCall] # flags: --config-file tmp/mypy.ini -from mod import declarative_base, Column, Instr, non_declarative_base +from mod import declarative_base, Column, Instr -Bad1 = non_declarative_base() -Bad2 = Bad3 = declarative_base() +Base = declarative_base().with_optional_xxx() -class C1(Bad1): ... # E: Variable "__main__.Bad1" is not valid as a type \ - # N: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases \ - # E: Invalid base class "Bad1" -class C2(Bad2): ... # E: Variable "__main__.Bad2" is not valid as a type \ - # N: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases \ - # E: Invalid base class "Bad2" -class C3(Bad3): ... # E: Variable "__main__.Bad3" is not valid as a type \ - # N: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases \ - # E: Invalid base class "Bad3" +class Model(Base): + x: Column[int] + +reveal_type(Model().x) # N: Revealed type is "mod.Instr[builtins.int]" [file mod.py] from typing import Generic, TypeVar -def declarative_base(): ... -def non_declarative_base(): ... + +class Base: + def with_optional_xxx(self) -> Base: ... + +def declarative_base() -> Base: ... T = TypeVar('T') @@ -573,6 +570,35 @@ class Instr(Generic[T]): ... \[mypy] plugins=/test-data/unit/plugins/dyn_class.py +[case testDynamicClassPluginChainedAssignment] +# flags: --config-file tmp/mypy.ini +from mod import declarative_base + +Base1 = Base2 = declarative_base() + +class C1(Base1): ... +class C2(Base2): ... +[file mod.py] +def declarative_base(): ... +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/dyn_class.py + +[case testDynamicClassPluginNegatives] +# flags: --config-file tmp/mypy.ini +from mod import non_declarative_base + +Bad1 = non_declarative_base() + +class C1(Bad1): ... # E: Variable "__main__.Bad1" is not valid as a type \ + # N: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases \ + # E: Invalid base class "Bad1" +[file mod.py] +def non_declarative_base(): ... +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/dyn_class.py + [case testDynamicClassHookFromClassMethod] # flags: --config-file tmp/mypy.ini From 24b125fffa7d342c8562e3ad3c7ecf16e3a21e44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Mon, 2 Aug 2021 12:12:54 +0800 Subject: [PATCH 2/3] Add all_call_expressions function to get all CallExpr from a specific node --- mypy/semanal.py | 31 +++++++++---------------------- mypy/traverser.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 0fb074786f28..b10c1886da92 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -95,6 +95,7 @@ ) from mypy.typeops import function_type from mypy.type_visitor import TypeQuery +from mypy.traverser import all_call_expressions from mypy.nodes import implicit_module_attrs from mypy.typeanal import ( TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias, @@ -2279,17 +2280,11 @@ def analyze_lvalues(self, s: AssignmentStmt) -> None: def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None: if not isinstance(s.rvalue, CallExpr): return - - from .traverser import TraverserVisitor - - class CallExprVisitor(TraverserVisitor): - analyzer: SemanticAnalyzer - - def __init__(self, analyzer: SemanticAnalyzer) -> None: - super().__init__() - self.analyzer = analyzer - - def visit_call_expr(self, call: CallExpr) -> None: + call_expressions = all_call_expressions(s) + for lval in s.lvalues: + if not isinstance(lval, NameExpr): + continue + for call in call_expressions: fname = None if isinstance(call.callee, RefExpr): fname = call.callee.fullname @@ -2300,17 +2295,9 @@ def visit_call_expr(self, call: CallExpr) -> None: method_name = call.callee.name fname = callee_expr.fullname + '.' + method_name if fname: - for lval in s.lvalues: - if not isinstance(lval, NameExpr): - continue - hook = self.analyzer.plugin.get_dynamic_class_hook(fname) - if hook: - hook(DynamicClassDefContext(call, lval.name, self.analyzer)) - - super().visit_call_expr(call) - - visitor = CallExprVisitor(analyzer=self) - s.accept(visitor) + hook = self.plugin.get_dynamic_class_hook(fname) + if hook: + hook(DynamicClassDefContext(call, lval.name, self)) def unwrap_final(self, s: AssignmentStmt) -> bool: """Strip Final[...] if present in an assignment. diff --git a/mypy/traverser.py b/mypy/traverser.py index a5f993bd2fa5..7c09f09e5e60 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -378,3 +378,19 @@ def all_yield_expressions(node: Node) -> List[Tuple[YieldExpr, bool]]: v = YieldCollector() node.accept(v) return v.yield_expressions + + +class CallCollector(TraverserVisitor): + def __init__(self) -> None: + super().__init__() + self.call_expressions: List[CallExpr] = [] + + def visit_call_expr(self, o: CallExpr) -> None: + self.call_expressions.append(o) + return super().visit_call_expr(o) + + +def all_call_expressions(node: Node) -> List[CallExpr]: + v = CallCollector() + node.accept(v) + return v.call_expressions From 1e544c13bce5832ffa4a36d796c4e233680190e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Mon, 11 Oct 2021 22:27:01 +0800 Subject: [PATCH 3/3] Fix applying dynamic class hook on the wrong call expressions --- mypy/semanal.py | 38 ++++++++++++++++++++++---------------- mypy/traverser.py | 16 ---------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index b10c1886da92..89c9f9522d6e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -95,7 +95,6 @@ ) from mypy.typeops import function_type from mypy.type_visitor import TypeQuery -from mypy.traverser import all_call_expressions from mypy.nodes import implicit_module_attrs from mypy.typeanal import ( TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias, @@ -2280,24 +2279,31 @@ def analyze_lvalues(self, s: AssignmentStmt) -> None: def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None: if not isinstance(s.rvalue, CallExpr): return - call_expressions = all_call_expressions(s) + fname = None + call = s.rvalue + while True: + if isinstance(call.callee, RefExpr): + fname = call.callee.fullname + # check if method call + if fname is None and isinstance(call.callee, MemberExpr): + callee_expr = call.callee.expr + if isinstance(callee_expr, RefExpr) and callee_expr.fullname: + method_name = call.callee.name + fname = callee_expr.fullname + '.' + method_name + elif isinstance(callee_expr, CallExpr): + # check if chain call + call = callee_expr + continue + break + if not fname: + return + hook = self.plugin.get_dynamic_class_hook(fname) + if not hook: + return for lval in s.lvalues: if not isinstance(lval, NameExpr): continue - for call in call_expressions: - fname = None - if isinstance(call.callee, RefExpr): - fname = call.callee.fullname - # check if method call - if fname is None and isinstance(call.callee, MemberExpr): - callee_expr = call.callee.expr - if isinstance(callee_expr, RefExpr) and callee_expr.fullname: - method_name = call.callee.name - fname = callee_expr.fullname + '.' + method_name - if fname: - hook = self.plugin.get_dynamic_class_hook(fname) - if hook: - hook(DynamicClassDefContext(call, lval.name, self)) + hook(DynamicClassDefContext(call, lval.name, self)) def unwrap_final(self, s: AssignmentStmt) -> bool: """Strip Final[...] if present in an assignment. diff --git a/mypy/traverser.py b/mypy/traverser.py index 7c09f09e5e60..a5f993bd2fa5 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -378,19 +378,3 @@ def all_yield_expressions(node: Node) -> List[Tuple[YieldExpr, bool]]: v = YieldCollector() node.accept(v) return v.yield_expressions - - -class CallCollector(TraverserVisitor): - def __init__(self) -> None: - super().__init__() - self.call_expressions: List[CallExpr] = [] - - def visit_call_expr(self, o: CallExpr) -> None: - self.call_expressions.append(o) - return super().visit_call_expr(o) - - -def all_call_expressions(node: Node) -> List[CallExpr]: - v = CallCollector() - node.accept(v) - return v.call_expressions