From 38471241cf9e1476b2e4777fd6d5da7eb5bd2af8 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 9 Mar 2021 23:09:43 +0100 Subject: [PATCH 01/76] Add Nodes needed for match statement support --- mypy-requirements.txt | 2 +- mypy/fastparse.py | 22 +++++++++++++-- mypy/nodes.py | 43 +++++++++++++++++++++++++++++ mypy/test/testcheck.py | 2 ++ mypy/visitor.py | 12 ++++++++ test-data/unit/check-python310.test | 35 +++++++++++++++++++++++ 6 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 test-data/unit/check-python310.test diff --git a/mypy-requirements.txt b/mypy-requirements.txt index b5bb625d5a56..6c497d709e47 100644 --- a/mypy-requirements.txt +++ b/mypy-requirements.txt @@ -1,6 +1,6 @@ typing_extensions>=3.7.4 mypy_extensions>=0.4.3,<0.5.0 -typed_ast>=1.4.0,<1.5.0 +typed_ast>=1.4.0,<1.5.0; python_version<'3.8' types-typing-extensions>=3.7.0 types-mypy-extensions>=0.4.0 toml diff --git a/mypy/fastparse.py b/mypy/fastparse.py index b250095c74a8..e0277e12696e 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -17,14 +17,14 @@ ClassDef, Decorator, Block, Var, OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt, RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt, PassStmt, GlobalDecl, - WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, + WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, MatchStmt, TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, LambdaExpr, ComparisonExpr, AssignmentExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - AwaitExpr, TempNode, Expression, Statement, + AwaitExpr, MatchAs, MatchOr, TempNode, Expression, Statement, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2, check_arg_names, FakeInfo, @@ -1272,6 +1272,24 @@ def visit_Index(self, n: Index) -> Node: # cast for mypyc's benefit on Python 3.9 return self.visit(cast(Any, n).value) + # Match(expr subject, match_case* cases) # python 3.10 and later + def visit_Match(self, n: ast3.Match) -> MatchStmt: + node = MatchStmt(self.visit(n.subject), + [self.visit(c.pattern) for c in n.cases], + [self.visit(c.guard) for c in n.cases], + [self.as_required_block(c.body, n.lineno) for c in n.cases]) + return self.set_line(node, n) + + # MatchAs(expr pattern, identifier name) + def visit_MatchAs(self, n: ast3.MatchAs) -> MatchAs: + node = MatchAs(self.visit(n.pattern), n.name) + return self.set_line(node, n) + + # MatchOr(expr* pattern) + def visit_MatchOr(self, n: ast3.MatchOr) -> MatchOr: + node = MatchOr([self.visit(pattern) for pattern in n.patterns]) + return self.set_line(node, n) + class TypeConverter: def __init__(self, diff --git a/mypy/nodes.py b/mypy/nodes.py index 76521e8c2b38..e8aa93ab5a2e 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1268,6 +1268,24 @@ def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_with_stmt(self) +class MatchStmt(Statement): + subject = None # type: Expression + patterns = None # type: List[Expression] + guards = None # type: List[Optional[Expression]] + bodies = None # type: List[Block] + + def __init__(self, subject: Expression, patterns: List[Expression], + guards: List[Optional[Expression]], bodies: List[Block]) -> None: + super().__init__() + self.subject = subject + self.patterns = patterns + self.guards = guards + self.bodies = bodies + + def accept(self, visitor: StatementVisitor[T]) -> T: + return visitor.visit_match_stmt(self) + + class PrintStmt(Statement): """Python 2 print statement""" @@ -2277,6 +2295,31 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_await_expr(self) +# Note: CPython considers MatchAs and MatchOr to be expressions, but they are only allowed inside match_case patterns +class MatchAs(Expression): + pattern = None # type: Expression + name = None # type: str + + def __init__(self, pattern: Expression, name: str) -> None: + super().__init__() + self.pattern = pattern + self.name = name + + def accept(self, visitor: ExpressionVisitor[T]) -> T: + return visitor.visit_match_as(self) + + +class MatchOr(Expression): + patterns = None # type: List[Expression] + + def __init__(self, patterns: List[Expression]) -> None: + super().__init__() + self.patterns = patterns + + def accept(self, visitor: ExpressionVisitor[T]) -> T: + return visitor.visit_match_or(self) + + # Constants diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 51f5d71c12ad..134f093d8c8a 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -101,6 +101,8 @@ typecheck_files.append('check-python38.test') if sys.version_info >= (3, 9): typecheck_files.append('check-python39.test') +if sys.version_info >= (3, 10): + typecheck_files.append('check-python310.test') # Special tests for platforms with case-insensitive filesystems. if sys.platform in ('darwin', 'win32'): diff --git a/mypy/visitor.py b/mypy/visitor.py index b98ec773bbe3..f13fe47668b2 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -192,6 +192,14 @@ def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass + @abstractmethod + def visit_match_as(self, o: 'mypy.nodes.MatchAs') -> T: + pass + + @abstractmethod + def visit_match_or(self, o: 'mypy.nodes.MatchOr') -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) @@ -310,6 +318,10 @@ def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: pass + @abstractmethod + def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test new file mode 100644 index 000000000000..4d0e1c387116 --- /dev/null +++ b/test-data/unit/check-python310.test @@ -0,0 +1,35 @@ +[case testSimpleMatch] +# flags: --python-version 3.10 +class A: ... +class B: ... +a: A +m: object + +match m: + case ["quit"]: + a = B() # E: Incompatible types in assignment (expression has type "B", variable has type "A") + case ["look"]: + a = A() + +reveal_type(a) + + +[case testMatchAs] +# flags: --python-version 3.10 +class A: ... +m: object + +match m: + case [x] as a: + reveal_type(a) + reveal_type(x) + + +[case testMatchOr] +# flags: --python-version 3.10 +class A: ... +m: object + +match m: + case [x] | (x): + reveal_type(x) From f9e9277723eb3b0807b069a467fc5fe4882c9260 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 00:19:42 +0100 Subject: [PATCH 02/76] Added match statement support to StrConv visitor --- mypy/strconv.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mypy/strconv.py b/mypy/strconv.py index 5cc890bd91dc..9c8096188c3d 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -311,6 +311,15 @@ def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> str: def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> str: return self.dump([o.expr, o.globals, o.locals], o) + def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> str: + a = [o.subject] # type: List[Any] + for i in range(len(o.patterns)): + a.append(('Pattern', [o.patterns[i]])) + if o.guards[i] is not None: + a.append(('Guard', [o.guards[i]])) + a.append(('Body', o.bodies[i].body)) + return self.dump(a, o) + # Expressions # Simple expressions @@ -535,6 +544,12 @@ def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> str: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: return self.dump([o.type], o) + def visit_match_as(self, o: 'mypy.nodes.MatchAs') -> str: + return self.dump([o.name, o.pattern], o) + + def visit_match_or(self, o: 'mypy.nodes.MatchOr') -> str: + return self.dump(o.patterns, o) + def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv') -> str: """Convert an array into a pretty-printed multiline string representation. From 6c4109f43ddc94afab16d42086bc6a7c620c4805 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 00:30:55 +0100 Subject: [PATCH 03/76] Added match statement support to TraverserVisitor --- mypy/traverser.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/mypy/traverser.py b/mypy/traverser.py index c4834c9acb6b..84328108df1f 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -8,12 +8,13 @@ Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt, ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt, - TryStmt, WithStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, RevealExpr, - UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, AssignmentExpr, + TryStmt, WithStmt, MatchStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, + RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, AssignmentExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension, ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom, LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr, - YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE, + YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, MatchAs, MatchOr, + Node, REVEAL_TYPE, ) @@ -156,6 +157,15 @@ def visit_with_stmt(self, o: WithStmt) -> None: targ.accept(self) o.body.accept(self) + def visit_match_stmt(self, o: MatchStmt) -> None: + o.subject.accept(self) + for i in range(len(o.patterns)): + o.patterns[i].accept(self) + guard = o.guards[i] + if guard is not None: + guard.accept(self) + o.bodies[i].accept(self) + def visit_member_expr(self, o: MemberExpr) -> None: o.expr.accept(self) @@ -279,6 +289,13 @@ def visit_await_expr(self, o: AwaitExpr) -> None: def visit_super_expr(self, o: SuperExpr) -> None: o.call.accept(self) + def visit_match_as(self, o: MatchAs) -> None: + o.pattern.accept(self) + + def visit_match_or(self, o: MatchOr) -> None: + for p in o.patterns: + p.accept(self) + def visit_import(self, o: Import) -> None: for a in o.assignments: a.accept(self) From 289d0149eee1602a94353dcdc973f98820e70bce Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 12:31:06 +0100 Subject: [PATCH 04/76] Move MatchAs and MatchOr to patterns and add literal patterns This commit introduces patterns. Instead of representing patterns as expressions, like the python ast does, we create new data structures for them. As of this commit data structures for as-patterns, or-patterns and literal patterns are in place --- mypy/fastparse.py | 178 ++++++++++++++++++++-------- mypy/nodes.py | 25 ---- mypy/patterns.py | 57 +++++++++ mypy/strconv.py | 14 ++- mypy/traverser.py | 8 +- mypy/visitor.py | 27 +++-- test-data/unit/check-python310.test | 40 +++++++ 7 files changed, 260 insertions(+), 89 deletions(-) create mode 100644 mypy/patterns.py diff --git a/mypy/fastparse.py b/mypy/fastparse.py index e0277e12696e..e9fb75be2f16 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -6,6 +6,8 @@ from typing import ( Tuple, Union, TypeVar, Callable, Sequence, Optional, Any, Dict, cast, List, overload ) + +from mypy_extensions import trait from typing_extensions import Final, Literal, overload from mypy.sharedparse import ( @@ -24,11 +26,14 @@ UnaryExpr, LambdaExpr, ComparisonExpr, AssignmentExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - AwaitExpr, MatchAs, MatchOr, TempNode, Expression, Statement, + AwaitExpr, TempNode, Expression, Statement, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2, check_arg_names, FakeInfo, ) +from mypy.patterns import ( + AsPattern, OrPattern, LiteralPattern +) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, TypeOfAny, Instance, RawExpressionType, ProperType, UnionType, @@ -47,6 +52,7 @@ # Check if we can use the stdlib ast module instead of typed_ast. if sys.version_info >= (3, 8): import ast as ast3 + assert 'kind' in ast3.Constant._fields, \ "This 3.8.0 alpha (%s) is too old; 3.8.0a3 required" % sys.version.split()[0] # TODO: Num, Str, Bytes, NameConstant, Ellipsis are deprecated in 3.8. @@ -104,6 +110,16 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, # These don't exist before 3.8 NamedExpr = Any Constant = Any + + if sys.version_info >= (3, 10): + Match = ast3.Match + MatchAs = ast3.MatchAs + MatchOr = ast3.MatchOr + else: + Match = Any + MatchAs = Any + MatchOr = Any + except ImportError: try: from typed_ast import ast35 # type: ignore[attr-defined] # noqa: F401 @@ -137,7 +153,6 @@ def parse(source: Union[str, bytes], module: Optional[str], errors: Optional[Errors] = None, options: Optional[Options] = None) -> MypyFile: - """Parse a source file, without doing any semantic analysis. Return the parse tree. If errors is not provided, raise ParseError @@ -288,35 +303,14 @@ def is_no_type_check_decorator(expr: ast3.expr) -> bool: return False -class ASTConverter: - def __init__(self, - options: Options, - is_stub: bool, - errors: Errors) -> None: - # 'C' for class, 'F' for function - self.class_and_function_stack = [] # type: List[Literal['C', 'F']] - self.imports = [] # type: List[ImportBase] - - self.options = options - self.is_stub = is_stub +@trait +class Converter: + def __init__(self, errors: Optional[Errors]): self.errors = errors - self.type_ignores = {} # type: Dict[int, List[str]] - # Cache of visit_X methods keyed by type of visited object self.visitor_cache = {} # type: Dict[type, Callable[[Optional[AST]], Any]] - def note(self, msg: str, line: int, column: int) -> None: - self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) - - def fail(self, - msg: str, - line: int, - column: int, - blocker: bool = True) -> None: - if blocker or not self.options.ignore_errors: - self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX) - def visit(self, node: Optional[AST]) -> Any: if node is None: return None @@ -334,6 +328,45 @@ def set_line(self, node: N, n: Union[ast3.expr, ast3.stmt, ast3.ExceptHandler]) node.end_line = getattr(n, "end_lineno", None) if isinstance(n, ast3.expr) else None return node + def note(self, msg: str, line: int, column: int) -> None: + if self.errors is not None: + self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) + + def fail(self, + msg: str, + line: int, + column: int, + blocker: bool = True) -> None: + if self.errors is not None: + self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX) + + +class ASTConverter(Converter): + # Errors is optional is superclass, but not here + errors = None # type: Errors + + def __init__(self, + options: Options, + is_stub: bool, + errors: Errors) -> None: + super().__init__(errors) + # 'C' for class, 'F' for function + self.class_and_function_stack = [] # type: List[Literal['C', 'F']] + self.imports = [] # type: List[ImportBase] + + self.options = options + self.is_stub = is_stub + + self.type_ignores = {} # type: Dict[int, List[str]] + + def fail(self, + msg: str, + line: int, + column: int, + blocker: bool = True) -> None: + if blocker or not self.options.ignore_errors: + super().fail(msg, line, column, blocker) + def translate_opt_expr_list(self, l: Sequence[Optional[AST]]) -> List[Optional[Expression]]: res = [] # type: List[Optional[Expression]] for e in l: @@ -1273,33 +1306,88 @@ def visit_Index(self, n: Index) -> Node: return self.visit(cast(Any, n).value) # Match(expr subject, match_case* cases) # python 3.10 and later - def visit_Match(self, n: ast3.Match) -> MatchStmt: + def visit_Match(self, n: Match) -> MatchStmt: + pattern_converter = PatternConverter(self.errors) node = MatchStmt(self.visit(n.subject), - [self.visit(c.pattern) for c in n.cases], + [pattern_converter.visit(c.pattern) for c in n.cases], [self.visit(c.guard) for c in n.cases], [self.as_required_block(c.body, n.lineno) for c in n.cases]) return self.set_line(node, n) + +class PatternConverter(Converter): + def __init__(self, errors: Optional[Errors]) -> None: + super().__init__(errors) + # MatchAs(expr pattern, identifier name) - def visit_MatchAs(self, n: ast3.MatchAs) -> MatchAs: - node = MatchAs(self.visit(n.pattern), n.name) + def visit_MatchAs(self, n: MatchAs) -> AsPattern: + node = AsPattern(self.visit(n.pattern), n.name) return self.set_line(node, n) # MatchOr(expr* pattern) - def visit_MatchOr(self, n: ast3.MatchOr) -> MatchOr: - node = MatchOr([self.visit(pattern) for pattern in n.patterns]) + def visit_MatchOr(self, n: MatchOr) -> OrPattern: + node = OrPattern([self.visit(pattern) for pattern in n.patterns]) return self.set_line(node, n) + def assert_numeric_constant(self, n: ast3.AST) -> Union[int, float, complex]: + # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 + if isinstance(n, Constant): # type: ignore[misc] + val = n.value + if isinstance(val, int) or isinstance(val, float) or isinstance(val, complex): + return val + raise RuntimeError("Only numeric literals can be used with '+' and '-'. Found " + + str(type(n))) -class TypeConverter: - def __init__(self, - errors: Optional[Errors], - line: int = -1, - override_column: int = -1, - assume_str_is_unicode: bool = True, - is_evaluated: bool = True, - ) -> None: - self.errors = errors + def visit_Constant(self, n: Constant) -> LiteralPattern: + val = n.value + if val is None or isinstance(val, bool) or isinstance(val, int) or \ + isinstance(val, float) or isinstance(val, complex) or \ + isinstance(val, str) or isinstance(val, bytes): + node = LiteralPattern(val) + else: + raise RuntimeError("Pattern not implemented for " + str(type(val))) + return self.set_line(node, n) + + def visit_UnaryOp(self, n: ast3.UnaryOp) -> LiteralPattern: + # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 + if not isinstance(n.operand, Constant): # type: ignore[misc] + raise RuntimeError("Pattern not implemented for " + str(type(n.operand))) + + value = self.assert_numeric_constant(n.operand) + + if isinstance(n.op, ast3.UAdd): + node = LiteralPattern(value) + elif isinstance(n.op, ast3.USub): + node = LiteralPattern(-value) + else: + raise RuntimeError("Pattern not implemented for " + str(type(n.op))) + + return self.set_line(node, n) + + def visit_BinOp(self, n: ast3.BinOp) -> LiteralPattern: + if isinstance(n.left, UnaryOp) and isinstance(n.left.op, ast3.USub): + left_val = -1 * self.assert_numeric_constant(n.left.operand) + else: + left_val = self.assert_numeric_constant(n.left) + right_val = self.assert_numeric_constant(n.right) + + if left_val.imag != 0 or right_val.real: + raise RuntimeError("Unsupported pattern") + + if isinstance(n.op, ast3.Add): + node = LiteralPattern(left_val + right_val) + elif isinstance(n.op, ast3.Sub): + node = LiteralPattern(left_val - right_val) + else: + raise RuntimeError("Unsupported pattern") + + return self.set_line(node, n) + + +class TypeConverter(Converter): + def __init__(self, errors: Optional[Errors], line: int = -1, override_column: int = -1, + assume_str_is_unicode: bool = True, is_evaluated: bool = True) -> None: + super().__init__(errors) self.line = line self.override_column = override_column self.node_stack = [] # type: List[AST] @@ -1362,14 +1450,6 @@ def parent(self) -> Optional[AST]: return None return self.node_stack[-2] - def fail(self, msg: str, line: int, column: int) -> None: - if self.errors: - self.errors.report(line, column, msg, blocker=True, code=codes.SYNTAX) - - def note(self, msg: str, line: int, column: int) -> None: - if self.errors: - self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) - def translate_expr_list(self, l: Sequence[ast3.expr]) -> List[Type]: return [self.visit(e) for e in l] diff --git a/mypy/nodes.py b/mypy/nodes.py index e8aa93ab5a2e..662bd4382d5f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2295,31 +2295,6 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_await_expr(self) -# Note: CPython considers MatchAs and MatchOr to be expressions, but they are only allowed inside match_case patterns -class MatchAs(Expression): - pattern = None # type: Expression - name = None # type: str - - def __init__(self, pattern: Expression, name: str) -> None: - super().__init__() - self.pattern = pattern - self.name = name - - def accept(self, visitor: ExpressionVisitor[T]) -> T: - return visitor.visit_match_as(self) - - -class MatchOr(Expression): - patterns = None # type: List[Expression] - - def __init__(self, patterns: List[Expression]) -> None: - super().__init__() - self.patterns = patterns - - def accept(self, visitor: ExpressionVisitor[T]) -> T: - return visitor.visit_match_or(self) - - # Constants diff --git a/mypy/patterns.py b/mypy/patterns.py new file mode 100644 index 000000000000..dccbbb9dd2f6 --- /dev/null +++ b/mypy/patterns.py @@ -0,0 +1,57 @@ +"""Classes for representing match statement patterns.""" +from typing import TypeVar, List, Any + +from mypy_extensions import trait + +from mypy.nodes import Node +from mypy.visitor import PatternVisitor + +# These are not real AST nodes. CPython represents patterns using the normal expression nodes. + +T = TypeVar('T') + + +@trait +class Pattern(Node): + """A pattern node.""" + + __slots__ = () + + def accept(self, visitor: PatternVisitor[T]) -> T: + raise RuntimeError('Not implemented') + + +class AsPattern(Pattern): + pattern = None # type: Pattern + name = None # type: str + + def __init__(self, pattern: Pattern, name: str) -> None: + super().__init__() + self.pattern = pattern + self.name = name + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_as_pattern(self) + + +class OrPattern(Pattern): + patterns = None # type: List[Pattern] + + def __init__(self, patterns: List[Pattern]) -> None: + super().__init__() + self.patterns = patterns + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_or_pattern(self) + + +# TODO: Do we need subclassed for the typed of literals? +class LiteralPattern(Pattern): + value = None # type: Any + + def __init__(self, value: Any): + super().__init__() + self.value = value + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_literal_pattern(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 9c8096188c3d..3baf2ca725e5 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -4,11 +4,15 @@ import os from typing import Any, List, Tuple, Optional, Union, Sequence +from typing_extensions import TYPE_CHECKING from mypy.util import short_type, IdMapper import mypy.nodes from mypy.visitor import NodeVisitor +if TYPE_CHECKING: + import mypy.patterns + class StrConv(NodeVisitor[str]): """Visitor for converting a node to a human-readable string. @@ -544,12 +548,18 @@ def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> str: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: return self.dump([o.type], o) - def visit_match_as(self, o: 'mypy.nodes.MatchAs') -> str: + def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> str: return self.dump([o.name, o.pattern], o) - def visit_match_or(self, o: 'mypy.nodes.MatchOr') -> str: + def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: return self.dump(o.patterns, o) + def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> str: + value = o.value + if isinstance(o.value, str): + value = self.str_repr(o.value) + return self.dump([value], o) + def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv') -> str: """Convert an array into a pretty-printed multiline string representation. diff --git a/mypy/traverser.py b/mypy/traverser.py index 84328108df1f..33902d415479 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -3,6 +3,7 @@ from typing import List from mypy_extensions import mypyc_attr +from mypy.patterns import AsPattern, OrPattern from mypy.visitor import NodeVisitor from mypy.nodes import ( Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, @@ -13,8 +14,7 @@ GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension, ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom, LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr, - YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, MatchAs, MatchOr, - Node, REVEAL_TYPE, + YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE, ) @@ -289,10 +289,10 @@ def visit_await_expr(self, o: AwaitExpr) -> None: def visit_super_expr(self, o: SuperExpr) -> None: o.call.accept(self) - def visit_match_as(self, o: MatchAs) -> None: + def visit_as_pattern(self, o: AsPattern) -> None: o.pattern.accept(self) - def visit_match_or(self, o: MatchOr) -> None: + def visit_or_pattern(self, o: OrPattern) -> None: for p in o.patterns: p.accept(self) diff --git a/mypy/visitor.py b/mypy/visitor.py index f13fe47668b2..e7897c45fea1 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: # break import cycle only needed for mypy import mypy.nodes + import mypy.patterns T = TypeVar('T') @@ -192,14 +193,6 @@ def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass - @abstractmethod - def visit_match_as(self, o: 'mypy.nodes.MatchAs') -> T: - pass - - @abstractmethod - def visit_match_or(self, o: 'mypy.nodes.MatchOr') -> T: - pass - @trait @mypyc_attr(allow_interpreted_subclasses=True) @@ -325,7 +318,23 @@ def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: @trait @mypyc_attr(allow_interpreted_subclasses=True) -class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T]): +class PatternVisitor(Generic[T]): + @abstractmethod + def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: + pass + + @abstractmethod + def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: + pass + + @abstractmethod + def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> T: + pass + + +@trait +@mypyc_attr(allow_interpreted_subclasses=True) +class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T], PatternVisitor[T]): """Empty base class for parse tree node visitors. The T type argument specifies the return type of the visit diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 4d0e1c387116..bf80434c8e28 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -33,3 +33,43 @@ m: object match m: case [x] | (x): reveal_type(x) + + +[case testPatternLiteral] +# flags: --python-version 3.10 +m: object +match m: + case 1: + pass + case -1: + pass + case 1+2j: + pass + case -1+2j: + pass + case 1-2j: + pass + case -1-2j: + pass + case "str": + pass + case b"bytes": + pass + case r"raw_string": + pass + case None: + pass + case True: + pass + case False: + pass + + +[case testCapturePattern] +# flags: --python-version 3.10 +m: object +match m: + case x: + pass + case longName: + pass From 1973a88b265d824a6fb5d5db1172ea90a8a35ba5 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 12:51:50 +0100 Subject: [PATCH 05/76] Added nodes for capture and wildcard patterns --- mypy/fastparse.py | 10 +++++++--- mypy/patterns.py | 18 +++++++++++++++++- mypy/visitor.py | 8 ++++++++ test-data/unit/check-python310.test | 2 +- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index e9fb75be2f16..b80f24277948 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -32,7 +32,7 @@ FakeInfo, ) from mypy.patterns import ( - AsPattern, OrPattern, LiteralPattern + AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern ) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, @@ -52,7 +52,6 @@ # Check if we can use the stdlib ast module instead of typed_ast. if sys.version_info >= (3, 8): import ast as ast3 - assert 'kind' in ast3.Constant._fields, \ "This 3.8.0 alpha (%s) is too old; 3.8.0a3 required" % sys.version.split()[0] # TODO: Num, Str, Bytes, NameConstant, Ellipsis are deprecated in 3.8. @@ -119,7 +118,6 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, Match = Any MatchAs = Any MatchOr = Any - except ImportError: try: from typed_ast import ast35 # type: ignore[attr-defined] # noqa: F401 @@ -1383,6 +1381,12 @@ def visit_BinOp(self, n: ast3.BinOp) -> LiteralPattern: return self.set_line(node, n) + def visit_Name(self, n: ast3.Name) -> Union[WildcardPattern, CapturePattern]: + if n.id == '_': + return WildcardPattern() + else: + return CapturePattern(n.id) + class TypeConverter(Converter): def __init__(self, errors: Optional[Errors], line: int = -1, override_column: int = -1, diff --git a/mypy/patterns.py b/mypy/patterns.py index dccbbb9dd2f6..37bdb0514b85 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -45,7 +45,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_or_pattern(self) -# TODO: Do we need subclassed for the typed of literals? +# TODO: Do we need subclasses for the types of literals? class LiteralPattern(Pattern): value = None # type: Any @@ -55,3 +55,19 @@ def __init__(self, value: Any): def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_literal_pattern(self) + + +class CapturePattern(Pattern): + name = None # type: str + + def __init__(self, name: str): + super().__init__() + self.name = name + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_capture_pattern(self) + + +class WildcardPattern(Pattern): + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_wildcard_pattern(self) diff --git a/mypy/visitor.py b/mypy/visitor.py index e7897c45fea1..0107f9f6b8f4 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -331,6 +331,14 @@ def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> T: pass + @abstractmethod + def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> T: + pass + + @abstractmethod + def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index bf80434c8e28..c3fc81c3ebed 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -35,7 +35,7 @@ match m: reveal_type(x) -[case testPatternLiteral] +[case testLiteralPattern] # flags: --python-version 3.10 m: object match m: From cc92db44f73a0963fa3312e20f57b80c0bd0eb23 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 15:04:11 +0100 Subject: [PATCH 06/76] Added nodes for value, sequence and mapping patterns --- mypy/fastparse.py | 97 ++++++++++++++++++++++++----- mypy/patterns.py | 58 ++++++++++++++++- mypy/visitor.py | 16 +++++ test-data/unit/check-python310.test | 82 ++++++++++++++++++++++++ test-requirements.txt | 1 + 5 files changed, 235 insertions(+), 19 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index b80f24277948..f4e78278874c 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -32,7 +32,8 @@ FakeInfo, ) from mypy.patterns import ( - AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern + AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, + SequencePattern, StarredPattern, MappingPattern ) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, @@ -1305,7 +1306,7 @@ def visit_Index(self, n: Index) -> Node: # Match(expr subject, match_case* cases) # python 3.10 and later def visit_Match(self, n: Match) -> MatchStmt: - pattern_converter = PatternConverter(self.errors) + pattern_converter = PatternConverter(self.options, self.errors) node = MatchStmt(self.visit(n.subject), [pattern_converter.visit(c.pattern) for c in n.cases], [self.visit(c.guard) for c in n.cases], @@ -1314,9 +1315,14 @@ def visit_Match(self, n: Match) -> MatchStmt: class PatternConverter(Converter): - def __init__(self, errors: Optional[Errors]) -> None: + # Errors is optional is superclass, but not here + errors = None # type: Errors + + def __init__(self, options: Options, errors: Errors) -> None: super().__init__(errors) + self.options = options + # MatchAs(expr pattern, identifier name) def visit_MatchAs(self, n: MatchAs) -> AsPattern: node = AsPattern(self.visit(n.pattern), n.name) @@ -1327,25 +1333,16 @@ def visit_MatchOr(self, n: MatchOr) -> OrPattern: node = OrPattern([self.visit(pattern) for pattern in n.patterns]) return self.set_line(node, n) - def assert_numeric_constant(self, n: ast3.AST) -> Union[int, float, complex]: - # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 - if isinstance(n, Constant): # type: ignore[misc] - val = n.value - if isinstance(val, int) or isinstance(val, float) or isinstance(val, complex): - return val - raise RuntimeError("Only numeric literals can be used with '+' and '-'. Found " - + str(type(n))) - + # Constant(constant value) def visit_Constant(self, n: Constant) -> LiteralPattern: val = n.value - if val is None or isinstance(val, bool) or isinstance(val, int) or \ - isinstance(val, float) or isinstance(val, complex) or \ - isinstance(val, str) or isinstance(val, bytes): + if val is None or isinstance(val, (bool, int, float, complex, str, bytes)): node = LiteralPattern(val) else: raise RuntimeError("Pattern not implemented for " + str(type(val))) return self.set_line(node, n) + # UnaryOp(unaryop op, expr operand) def visit_UnaryOp(self, n: ast3.UnaryOp) -> LiteralPattern: # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 if not isinstance(n.operand, Constant): # type: ignore[misc] @@ -1362,6 +1359,7 @@ def visit_UnaryOp(self, n: ast3.UnaryOp) -> LiteralPattern: return self.set_line(node, n) + # BinOp(expr left, operator op, expr right) def visit_BinOp(self, n: ast3.BinOp) -> LiteralPattern: if isinstance(n.left, UnaryOp) and isinstance(n.left.op, ast3.USub): left_val = -1 * self.assert_numeric_constant(n.left.operand) @@ -1381,11 +1379,76 @@ def visit_BinOp(self, n: ast3.BinOp) -> LiteralPattern: return self.set_line(node, n) + def assert_numeric_constant(self, n: ast3.AST) -> Union[int, float, complex]: + # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 + if isinstance(n, Constant): # type: ignore[misc] + val = n.value + if isinstance(val, (int, float, complex)): + return val + raise RuntimeError("Only numeric literals can be used with '+' and '-'. Found " + + str(type(n))) + + # Name(identifier id, expr_context ctx) def visit_Name(self, n: ast3.Name) -> Union[WildcardPattern, CapturePattern]: + node = None # type: Optional[Union[WildcardPattern, CapturePattern]] if n.id == '_': - return WildcardPattern() + node = WildcardPattern() else: - return CapturePattern(n.id) + node = CapturePattern(n.id) + + return self.set_line(node, n) + + # Attribute(expr value, identifier attr, expr_context ctx) + def visit_Attribute(self, n: ast3.Attribute) -> ValuePattern: + # We can directly call `visit_Attribute`, as we know the type of n + node = ASTConverter(self.options, False, self.errors).visit_Attribute(n) + if not isinstance(node, MemberExpr): + raise RuntimeError("Unsupported pattern") + return self.set_line(ValuePattern(node), n) + + # List(expr* elts, expr_context ctx) + def visit_List(self, n: ast3.List) -> SequencePattern: + return self.make_sequence(n) + + # Tuple(expr* elts, expr_context ctx) + def visit_Tuple(self, n: ast3.Tuple) -> SequencePattern: + return self.make_sequence(n) + + def make_sequence(self, n: Union[ast3.List, ast3.Tuple]) -> SequencePattern: + patterns = [self.visit(p) for p in n.elts] + stars = [p for p in patterns if isinstance(p, StarredPattern)] + if len(stars) >= 2: + raise RuntimeError("Unsupported pattern") + + node = SequencePattern(patterns) + return self.set_line(node, n) + + # Starred(expr value, expr_context ctx) + def visit_Starred(self, n: ast3.Starred) -> StarredPattern: + expr = n.value + if not isinstance(expr, Name): + raise RuntimeError("Unsupported Pattern") + node = StarredPattern(expr.id) + + return self.set_line(node, n) + + # Dict(expr* keys, expr* values) + def visit_Dict(self, n: ast3.Dict) -> MappingPattern: + keys = [self.visit(k) for k in n.keys] + values = [self.visit(v) for v in n.values] + + if keys[-1] is None: + rest = values.pop() + keys.pop() + else: + rest = None + + for key in keys: + if not isinstance(key, (LiteralPattern, ValuePattern)): + raise RuntimeError("Unsupported Pattern") + + node = MappingPattern(keys, values, rest) + return self.set_line(node, n) class TypeConverter(Converter): diff --git a/mypy/patterns.py b/mypy/patterns.py index 37bdb0514b85..2249328a7650 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -1,9 +1,9 @@ """Classes for representing match statement patterns.""" -from typing import TypeVar, List, Any +from typing import TypeVar, List, Any, Union, Optional from mypy_extensions import trait -from mypy.nodes import Node +from mypy.nodes import Node, MemberExpr from mypy.visitor import PatternVisitor # These are not real AST nodes. CPython represents patterns using the normal expression nodes. @@ -71,3 +71,57 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class WildcardPattern(Pattern): def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_wildcard_pattern(self) + + +class ValuePattern(Pattern): + expr = None # type: MemberExpr + + def __init__(self, expr: MemberExpr): + super().__init__() + self.expr = expr + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_value_pattern(self) + + +class SequencePattern(Pattern): + patterns = None # type: List[Pattern] + + def __init__(self, patterns: List[Pattern]): + super().__init__() + self.patterns = patterns + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_sequence_pattern(self) + + +# TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our +# type hierarchy. Should it be? +class StarredPattern(Pattern): + name = None # type: str + + def __init__(self, name: str): + super().__init__() + self.name = name + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_starred_pattern(self) + + +MappingKeyPattern = Union[LiteralPattern, ValuePattern] + + +class MappingPattern(Pattern): + keys = None # type: List[MappingKeyPattern] + values = None # type: List[Pattern] + rest = None # type: Optional[CapturePattern] + + def __init__(self, keys: List[MappingKeyPattern], values: List[Pattern], + rest: Optional[CapturePattern]): + super().__init__() + self.keys = keys + self.values = values + self.rest = rest + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_mapping_pattern(self) diff --git a/mypy/visitor.py b/mypy/visitor.py index 0107f9f6b8f4..ebbeba202f63 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -339,6 +339,22 @@ def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> T: def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> T: pass + @abstractmethod + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + pass + + @abstractmethod + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: + pass + + @abstractmethod + def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: + pass + + @abstractmethod + def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index c3fc81c3ebed..778d5dd5db37 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -73,3 +73,85 @@ match m: pass case longName: pass + +[case testWildcardPattern] +# flags: --python-version 3.10 +m: object +match m: + case _: + pass + + +[case testValuePattern] +# flags: --python-version 3.10 +class A: + b = 1 +a = A() +m: object + +match m: + case a.b: + pass + + +[case testGroupPattern] +# flags: --python-version 3.10 +m: object + +match m: + case (1): + pass + + +[case testSequencePattern] +# flags: --python-version 3.10 +m: object + +match m: + case []: + pass + case (): + pass + case [1]: + pass + case (1,): + pass + case [1, 2, 3]: + pass + case (1, 2, 3): + pass + case [1, *a, 2]: + pass + case (1, *a, 2): + pass + case [1, *_, 2]: + pass + case (1, *_, 2): + pass + + +[case testMappingPattern] +# flags: --python-version 3.10 +class A: + b = 'l' + c = 2 +a = A() +m: object + +match m: + case {'k': v}: + pass + case {a.b: v}: + pass + case {1: v}: + pass + case {a.c: v}: + pass + case {'k': v1, a.b: v2, 1: v3, a.c: v4}: + pass + case {'k': 1, a.b: "str", 1: b'bytes', a.c: None}: + pass + case {'k': v, **r}: + pass + case {**r}: + pass diff --git a/test-requirements.txt b/test-requirements.txt index 2d83221c2f7a..8986b4115cd5 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -13,6 +13,7 @@ pytest-forked>=1.3.0,<2.0.0 pytest-cov>=2.10.0,<3.0.0 typing>=3.5.2; python_version < '3.5' py>=1.5.2 +typed_ast>=1.4.0,<1.5.0 virtualenv<20 setuptools!=50 importlib-metadata==0.20 From 057716c3969383beb6c6e5230233023178671fd1 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 15:44:10 +0100 Subject: [PATCH 07/76] Added nodes for class patterns --- mypy/fastparse.py | 22 ++++++++++++++++++++-- mypy/patterns.py | 20 +++++++++++++++++++- mypy/visitor.py | 4 ++++ test-data/unit/check-python310.test | 23 +++++++++++++++++++++++ 4 files changed, 66 insertions(+), 3 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index f4e78278874c..3af29e747764 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -26,14 +26,14 @@ UnaryExpr, LambdaExpr, ComparisonExpr, AssignmentExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - AwaitExpr, TempNode, Expression, Statement, + AwaitExpr, TempNode, RefExpr, Expression, Statement, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2, check_arg_names, FakeInfo, ) from mypy.patterns import ( AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, - SequencePattern, StarredPattern, MappingPattern + SequencePattern, StarredPattern, MappingPattern, ClassPattern ) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, @@ -1450,6 +1450,24 @@ def visit_Dict(self, n: ast3.Dict) -> MappingPattern: node = MappingPattern(keys, values, rest) return self.set_line(node, n) + # Call(expr func, expr* args, keyword* keywords) + def visit_Call(self, n: ast3.Call) -> ClassPattern: + def raise_if_none(value: Optional[str]) -> str: + if value is None: + raise RuntimeError("Unsupported Pattern") + else: + return value + + class_ref = ASTConverter(self.options, False, self.errors).visit(n.func) + if not isinstance(class_ref, RefExpr): + raise RuntimeError("Unsupported Pattern") + positionals = [self.visit(p) for p in n.args] + keyword_keys = [raise_if_none(keyword.arg) for keyword in n.keywords] + keyword_values = [self.visit(keyword.value) for keyword in n.keywords] + + node = ClassPattern(class_ref, positionals, keyword_keys, keyword_values) + return self.set_line(node, n) + class TypeConverter(Converter): def __init__(self, errors: Optional[Errors], line: int = -1, override_column: int = -1, diff --git a/mypy/patterns.py b/mypy/patterns.py index 2249328a7650..023ad204a504 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -3,7 +3,7 @@ from mypy_extensions import trait -from mypy.nodes import Node, MemberExpr +from mypy.nodes import Node, MemberExpr, RefExpr from mypy.visitor import PatternVisitor # These are not real AST nodes. CPython represents patterns using the normal expression nodes. @@ -125,3 +125,21 @@ def __init__(self, keys: List[MappingKeyPattern], values: List[Pattern], def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_mapping_pattern(self) + + +class ClassPattern(Pattern): + class_ref = None # type: RefExpr + positionals = None # type: List[Pattern] + keyword_keys = None # type: List[str] + keyword_values = None # type: List[Pattern] + + def __init__(self, class_ref: RefExpr, positionals: List[Pattern], keyword_keys: List[str], + keyword_values: List[Pattern]): + super().__init__() + self.class_ref = class_ref + self.positionals = positionals + self.keyword_keys = keyword_keys + self.keyword_values = keyword_values + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_class_pattern(self) diff --git a/mypy/visitor.py b/mypy/visitor.py index ebbeba202f63..3d8b92901997 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -355,6 +355,10 @@ def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: pass + @abstractmethod + def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 778d5dd5db37..a65eabefc21d 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -155,3 +155,26 @@ match m: pass case {**r}: pass + + +[case testClassPattern] +# flags: --python-version 3.10 +class A: + pass +class B: + __match_args__ = ('a', 'b') + a: int + b: int + +m: object + +match m: + case A(): + pass + case B(1, 2): + pass + case B(1, b=2): + pass + case B(a=1, b=2): + pass +[builtins fixtures/tuple.pyi] From 6bfb098fa0320f88e8f211f0330ca7c19b82efea Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 18:31:00 +0100 Subject: [PATCH 08/76] Added parser tests for pattern matching and tweaks to pattern classes --- mypy/fastparse.py | 4 +- mypy/patterns.py | 12 +- mypy/strconv.py | 36 +- mypy/test/testparse.py | 3 + test-data/unit/check-python310.test | 12 +- test-data/unit/parse-python310.test | 515 ++++++++++++++++++++++++++++ 6 files changed, 572 insertions(+), 10 deletions(-) create mode 100644 test-data/unit/parse-python310.test diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 3af29e747764..e249fab228c2 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1394,7 +1394,7 @@ def visit_Name(self, n: ast3.Name) -> Union[WildcardPattern, CapturePattern]: if n.id == '_': node = WildcardPattern() else: - node = CapturePattern(n.id) + node = CapturePattern(ASTConverter(self.options, False, self.errors).visit_Name(n)) return self.set_line(node, n) @@ -1428,7 +1428,7 @@ def visit_Starred(self, n: ast3.Starred) -> StarredPattern: expr = n.value if not isinstance(expr, Name): raise RuntimeError("Unsupported Pattern") - node = StarredPattern(expr.id) + node = StarredPattern(self.visit_Name(expr)) return self.set_line(node, n) diff --git a/mypy/patterns.py b/mypy/patterns.py index 023ad204a504..af5ca29df393 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -3,7 +3,7 @@ from mypy_extensions import trait -from mypy.nodes import Node, MemberExpr, RefExpr +from mypy.nodes import Node, MemberExpr, RefExpr, NameExpr from mypy.visitor import PatternVisitor # These are not real AST nodes. CPython represents patterns using the normal expression nodes. @@ -58,9 +58,9 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class CapturePattern(Pattern): - name = None # type: str + name = None # type: NameExpr - def __init__(self, name: str): + def __init__(self, name: NameExpr): super().__init__() self.name = name @@ -98,11 +98,11 @@ def accept(self, visitor: PatternVisitor[T]) -> T: # TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our # type hierarchy. Should it be? class StarredPattern(Pattern): - name = None # type: str + capture = None # type: Union[CapturePattern, WildcardPattern] - def __init__(self, name: str): + def __init__(self, capture: Union[CapturePattern, WildcardPattern]): super().__init__() - self.name = name + self.capture = capture def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_starred_pattern(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 3baf2ca725e5..ecccea57a636 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -549,6 +549,7 @@ def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: return self.dump([o.type], o) def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> str: + # We display the name first for better readability return self.dump([o.name, o.pattern], o) def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: @@ -557,9 +558,42 @@ def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> str: value = o.value if isinstance(o.value, str): - value = self.str_repr(o.value) + value = "'" + self.str_repr(o.value) + "'" return self.dump([value], o) + def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> str: + return self.dump([o.name], o) + + def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> str: + return self.dump([], o) + + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> str: + return self.dump([o.expr], o) + + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> str: + return self.dump(o.patterns, o) + + def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> str: + return self.dump([o.capture], o) + + def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> str: + a = [] # type: List[Any] + for i in range(len(o.keys)): + a.append(('Key', [o.keys[i]])) + a.append(('Value', [o.values[i]])) + if o.rest is not None: + a.append(('Rest', [o.rest])) + return self.dump(a, o) + + def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> str: + a = [o.class_ref] # type: List[Any] + if len(o.positionals) > 0: + a.append(('Positionals', o.positionals)) + for i in range(len(o.keyword_keys)): + a.append(('Keyword', [o.keyword_keys[i], o.keyword_values[i]])) + + return self.dump(a, o) + def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv') -> str: """Convert an array into a pretty-printed multiline string representation. diff --git a/mypy/test/testparse.py b/mypy/test/testparse.py index e9ff6839bc2c..c8147774cd86 100644 --- a/mypy/test/testparse.py +++ b/mypy/test/testparse.py @@ -16,6 +16,7 @@ class ParserSuite(DataSuite): required_out_section = True base_path = '.' files = ['parse.test', + 'parse-python310.test', 'parse-python2.test'] def run_case(self, testcase: DataDrivenTestCase) -> None: @@ -31,6 +32,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None: if testcase.file.endswith('python2.test'): options.python_version = defaults.PYTHON2_VERSION + elif testcase.file.endswith('python310.test'): + options.python_version = (3, 10) else: options.python_version = defaults.PYTHON3_VERSION diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index a65eabefc21d..defe1038fa63 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -35,6 +35,16 @@ match m: reveal_type(x) +[case testMatchWithGuard] +# flags: --python-version 3.10 +m: object + +match m: + case 1 if False: + pass + case 2: + pass + [case testLiteralPattern] # flags: --python-version 3.10 m: object @@ -149,7 +159,7 @@ match m: pass case {'k': v1, a.b: v2, 1: v3, a.c: v4}: pass - case {'k': 1, a.b: "str", 1: b'bytes', a.c: None}: + case {'k1': 1, 'k2': "str", 'k3': b'bytes', 'k4': None}: pass case {'k': v, **r}: pass diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test new file mode 100644 index 000000000000..1216408d21a9 --- /dev/null +++ b/test-data/unit/parse-python310.test @@ -0,0 +1,515 @@ +-- Test cases for parser -- Python 3.10 syntax (match statement) +-- +-- See parse.test for a description of this file format. + +[case testSimpleMatch] +match a: + case 1: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + LiteralPattern:2( + 1)) + Body( + PassStmt:3()))) + + +[case testAsPattern] +match a: + case 1 as b: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + AsPattern:2( + b + LiteralPattern:2( + 1))) + Body( + PassStmt:3()))) + + +[case testLiteralPattern] +match a: + case 1: + pass + case -1: + pass + case 1+2j: + pass + case -1+2j: + pass + case 1-2j: + pass + case -1-2j: + pass + case "str": + pass + case b"bytes": + pass + case r"raw_string": + pass + case None: + pass + case True: + pass + case False: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + LiteralPattern:2( + 1)) + Body( + PassStmt:3()) + Pattern( + LiteralPattern:4( + -1)) + Body( + PassStmt:5()) + Pattern( + LiteralPattern:6( + (1+2j))) + Body( + PassStmt:7()) + Pattern( + LiteralPattern:8( + (-1+2j))) + Body( + PassStmt:9()) + Pattern( + LiteralPattern:10( + (1-2j))) + Body( + PassStmt:11()) + Pattern( + LiteralPattern:12( + (-1-2j))) + Body( + PassStmt:13()) + Pattern( + LiteralPattern:14( + 'str')) + Body( + PassStmt:15()) + Pattern( + LiteralPattern:16( + b'bytes')) + Body( + PassStmt:17()) + Pattern( + LiteralPattern:18( + 'raw_string')) + Body( + PassStmt:19()) + Pattern( + LiteralPattern:20()) + Body( + PassStmt:21()) + Pattern( + LiteralPattern:22( + True)) + Body( + PassStmt:23()) + Pattern( + LiteralPattern:24( + False)) + Body( + PassStmt:25()))) + +[case testCapturePattern] +match a: + case x: + pass + case longName: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + CapturePattern:2( + NameExpr(x))) + Body( + PassStmt:3()) + Pattern( + CapturePattern:4( + NameExpr(longName))) + Body( + PassStmt:5()))) + +[case testWildcardPattern] +match a: + case _: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + WildcardPattern:2()) + Body( + PassStmt:3()))) + +[case testValuePattern] +match a: + case b.c: + pass + case b.c.d.e.f: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ValuePattern:2( + MemberExpr:2( + NameExpr(b) + c))) + Body( + PassStmt:3()) + Pattern( + ValuePattern:4( + MemberExpr:4( + MemberExpr:4( + MemberExpr:4( + MemberExpr:4( + NameExpr(b) + c) + d) + e) + f))) + Body( + PassStmt:5()))) + +[case testGroupPattern] +# This is optimized out by the compiler. It doesn't appear in the ast +match a: + case (1): + pass +[out] +MypyFile:1( + MatchStmt:2( + NameExpr(a) + Pattern( + LiteralPattern:3( + 1)) + Body( + PassStmt:4()))) + +[case testSequencePattern] +match a: + case []: + pass + case (): + pass + case [1]: + pass + case (1,): + pass + case [1, 2, 3]: + pass + case (1, 2, 3): + pass + case [1, *a, 2]: + pass + case (1, *a, 2): + pass + case [1, *_, 2]: + pass + case (1, *_, 2): + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + SequencePattern:2()) + Body( + PassStmt:3()) + Pattern( + SequencePattern:4()) + Body( + PassStmt:5()) + Pattern( + SequencePattern:6( + LiteralPattern:6( + 1))) + Body( + PassStmt:7()) + Pattern( + SequencePattern:8( + LiteralPattern:8( + 1))) + Body( + PassStmt:9()) + Pattern( + SequencePattern:10( + LiteralPattern:10( + 1) + LiteralPattern:10( + 2) + LiteralPattern:10( + 3))) + Body( + PassStmt:11()) + Pattern( + SequencePattern:12( + LiteralPattern:12( + 1) + LiteralPattern:12( + 2) + LiteralPattern:12( + 3))) + Body( + PassStmt:13()) + Pattern( + SequencePattern:14( + LiteralPattern:14( + 1) + StarredPattern:14( + CapturePattern:14( + NameExpr(a))) + LiteralPattern:14( + 2))) + Body( + PassStmt:15()) + Pattern( + SequencePattern:16( + LiteralPattern:16( + 1) + StarredPattern:16( + CapturePattern:16( + NameExpr(a))) + LiteralPattern:16( + 2))) + Body( + PassStmt:17()) + Pattern( + SequencePattern:18( + LiteralPattern:18( + 1) + StarredPattern:18( + WildcardPattern:18()) + LiteralPattern:18( + 2))) + Body( + PassStmt:19()) + Pattern( + SequencePattern:20( + LiteralPattern:20( + 1) + StarredPattern:20( + WildcardPattern:20()) + LiteralPattern:20( + 2))) + Body( + PassStmt:21()))) + +[case testMappingPattern] +match a: + case {'k': v}: + pass + case {a.b: v}: + pass + case {1: v}: + pass + case {a.c: v}: + pass + case {'k': v1, a.b: v2, 1: v3, a.c: v4}: + pass + case {'k1': 1, 'k2': "str", 'k3': b'bytes', 'k4': None}: + pass + case {'k': v, **r}: + pass + case {**r}: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + MappingPattern:2( + Key( + LiteralPattern:2( + 'k')) + Value( + CapturePattern:2( + NameExpr(v))))) + Body( + PassStmt:3()) + Pattern( + MappingPattern:4( + Key( + ValuePattern:4( + MemberExpr:4( + NameExpr(a) + b))) + Value( + CapturePattern:4( + NameExpr(v))))) + Body( + PassStmt:5()) + Pattern( + MappingPattern:6( + Key( + LiteralPattern:6( + 1)) + Value( + CapturePattern:6( + NameExpr(v))))) + Body( + PassStmt:7()) + Pattern( + MappingPattern:8( + Key( + ValuePattern:8( + MemberExpr:8( + NameExpr(a) + c))) + Value( + CapturePattern:8( + NameExpr(v))))) + Body( + PassStmt:9()) + Pattern( + MappingPattern:10( + Key( + LiteralPattern:10( + 'k')) + Value( + CapturePattern:10( + NameExpr(v1))) + Key( + ValuePattern:10( + MemberExpr:10( + NameExpr(a) + b))) + Value( + CapturePattern:10( + NameExpr(v2))) + Key( + LiteralPattern:10( + 1)) + Value( + CapturePattern:10( + NameExpr(v3))) + Key( + ValuePattern:10( + MemberExpr:10( + NameExpr(a) + c))) + Value( + CapturePattern:10( + NameExpr(v4))))) + Body( + PassStmt:11()) + Pattern( + MappingPattern:12( + Key( + LiteralPattern:12( + 'k1')) + Value( + LiteralPattern:12( + 1)) + Key( + LiteralPattern:12( + 'k2')) + Value( + LiteralPattern:12( + 'str')) + Key( + LiteralPattern:12( + 'k3')) + Value( + LiteralPattern:12( + b'bytes')) + Key( + LiteralPattern:12( + 'k4')) + Value( + LiteralPattern:12()))) + Body( + PassStmt:13()) + Pattern( + MappingPattern:14( + Key( + LiteralPattern:14( + 'k')) + Value( + CapturePattern:14( + NameExpr(v))) + Rest( + CapturePattern:14( + NameExpr(r))))) + Body( + PassStmt:15()) + Pattern( + MappingPattern:16( + Rest( + CapturePattern:16( + NameExpr(r))))) + Body( + PassStmt:17()))) + +[case testClassPattern] +match a: + case A(): + pass + case B(1, 2): + pass + case B(1, b=2): + pass + case B(a=1, b=2): + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ClassPattern:2( + NameExpr(A))) + Body( + PassStmt:3()) + Pattern( + ClassPattern:4( + NameExpr(B) + Positionals( + LiteralPattern:4( + 1) + LiteralPattern:4( + 2)))) + Body( + PassStmt:5()) + Pattern( + ClassPattern:6( + NameExpr(B) + Positionals( + LiteralPattern:6( + 1)) + Keyword( + b + LiteralPattern:6( + 2)))) + Body( + PassStmt:7()) + Pattern( + ClassPattern:8( + NameExpr(B) + Keyword( + a + LiteralPattern:8( + 1)) + Keyword( + b + LiteralPattern:8( + 2)))) + Body( + PassStmt:9()))) \ No newline at end of file From a74b0e6ea05d4bd06aae23806660fa9fdca9440e Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 19:27:04 +0100 Subject: [PATCH 09/76] Added patterns to NodeVisitor and TraverserVisitor --- mypy/traverser.py | 31 ++++++++++++++++++++++++++++++- mypy/visitor.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/mypy/traverser.py b/mypy/traverser.py index 33902d415479..8693d2038981 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -3,7 +3,8 @@ from typing import List from mypy_extensions import mypyc_attr -from mypy.patterns import AsPattern, OrPattern +from mypy.patterns import AsPattern, OrPattern, CapturePattern, ValuePattern, SequencePattern, \ + StarredPattern, MappingPattern, ClassPattern from mypy.visitor import NodeVisitor from mypy.nodes import ( Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, @@ -296,6 +297,34 @@ def visit_or_pattern(self, o: OrPattern) -> None: for p in o.patterns: p.accept(self) + def visit_capture_pattern(self, o: CapturePattern) -> None: + o.name.accept(self) + + def visit_value_pattern(self, o: ValuePattern) -> None: + o.expr.accept(self) + + def visit_sequence_pattern(self, o: SequencePattern) -> None: + for p in o.patterns: + p.accept(self) + + def visit_starred_patten(self, o: StarredPattern) -> None: + o.capture.accept(self) + + def visit_mapping_pattern(self, o: MappingPattern) -> None: + for key in o.keys: + key.accept(self) + for value in o.values: + value.accept(self) + if o.rest is not None: + o.rest.accept(self) + + def visit_class_pattern(self, o: ClassPattern) -> None: + o.class_ref.accept(self) + for p in o.positionals: + p.accept(self) + for v in o.keyword_values: + v.accept(self) + def visit_import(self, o: Import) -> None: for a in o.assignments: a.accept(self) diff --git a/mypy/visitor.py b/mypy/visitor.py index 3d8b92901997..5dac830cfe17 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -478,6 +478,9 @@ def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: pass + def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: + pass + # Expressions (default no-op implementation) def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: @@ -611,3 +614,35 @@ def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass + + # Patterns + + def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: + pass + + def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: + pass + + def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> T: + pass + + def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> T: + pass + + def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> T: + pass + + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + pass + + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: + pass + + def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: + pass + + def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: + pass + + def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> T: + pass From fc3c1d2033eee5ea588d165de7062a88221c08c8 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 22:14:54 +0100 Subject: [PATCH 10/76] Add missing parse test and prevent tests from running on < 3.10 --- mypy/test/testparse.py | 4 +++- test-data/unit/parse-python310.test | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mypy/test/testparse.py b/mypy/test/testparse.py index c8147774cd86..1587147c0777 100644 --- a/mypy/test/testparse.py +++ b/mypy/test/testparse.py @@ -16,9 +16,11 @@ class ParserSuite(DataSuite): required_out_section = True base_path = '.' files = ['parse.test', - 'parse-python310.test', 'parse-python2.test'] + if sys.version_info >= (3, 10): + files.append('parse-python310.test') + def run_case(self, testcase: DataDrivenTestCase) -> None: test_parser(testcase) diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test index 1216408d21a9..6d032a138760 100644 --- a/test-data/unit/parse-python310.test +++ b/test-data/unit/parse-python310.test @@ -16,6 +16,35 @@ MypyFile:1( Body( PassStmt:3()))) +[case testMatchWithGuard] +match a: + case 1 if f(): + pass + case d if d > 5: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + LiteralPattern:2( + 1)) + Guard( + CallExpr:2( + NameExpr(f) + Args())) + Body( + PassStmt:3()) + Pattern( + CapturePattern:4( + NameExpr(d))) + Guard( + ComparisonExpr:4( + > + NameExpr(d) + IntExpr(5))) + Body( + PassStmt:5()))) [case testAsPattern] match a: From db041867ef8f5c1c47a6453bc7e9fe21983ce9ff Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 10 Mar 2021 22:34:33 +0100 Subject: [PATCH 11/76] Add parse tests for open sequence patterns --- test-data/unit/parse-python310.test | 77 +++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test index 6d032a138760..e09e2ffac9fe 100644 --- a/test-data/unit/parse-python310.test +++ b/test-data/unit/parse-python310.test @@ -243,18 +243,26 @@ match a: pass case (1,): pass + case 1,: + pass case [1, 2, 3]: pass case (1, 2, 3): pass + case 1, 2, 3: + pass case [1, *a, 2]: pass case (1, *a, 2): pass + case 1, *a, 2: + pass case [1, *_, 2]: pass case (1, *_, 2): pass + case 1, *_, 2: + pass [out] MypyFile:1( MatchStmt:1( @@ -282,11 +290,7 @@ MypyFile:1( Pattern( SequencePattern:10( LiteralPattern:10( - 1) - LiteralPattern:10( - 2) - LiteralPattern:10( - 3))) + 1))) Body( PassStmt:11()) Pattern( @@ -303,22 +307,20 @@ MypyFile:1( SequencePattern:14( LiteralPattern:14( 1) - StarredPattern:14( - CapturePattern:14( - NameExpr(a))) LiteralPattern:14( - 2))) + 2) + LiteralPattern:14( + 3))) Body( PassStmt:15()) Pattern( SequencePattern:16( LiteralPattern:16( 1) - StarredPattern:16( - CapturePattern:16( - NameExpr(a))) LiteralPattern:16( - 2))) + 2) + LiteralPattern:16( + 3))) Body( PassStmt:17()) Pattern( @@ -326,7 +328,8 @@ MypyFile:1( LiteralPattern:18( 1) StarredPattern:18( - WildcardPattern:18()) + CapturePattern:18( + NameExpr(a))) LiteralPattern:18( 2))) Body( @@ -336,11 +339,53 @@ MypyFile:1( LiteralPattern:20( 1) StarredPattern:20( - WildcardPattern:20()) + CapturePattern:20( + NameExpr(a))) LiteralPattern:20( 2))) Body( - PassStmt:21()))) + PassStmt:21()) + Pattern( + SequencePattern:22( + LiteralPattern:22( + 1) + StarredPattern:22( + CapturePattern:22( + NameExpr(a))) + LiteralPattern:22( + 2))) + Body( + PassStmt:23()) + Pattern( + SequencePattern:24( + LiteralPattern:24( + 1) + StarredPattern:24( + WildcardPattern:24()) + LiteralPattern:24( + 2))) + Body( + PassStmt:25()) + Pattern( + SequencePattern:26( + LiteralPattern:26( + 1) + StarredPattern:26( + WildcardPattern:26()) + LiteralPattern:26( + 2))) + Body( + PassStmt:27()) + Pattern( + SequencePattern:28( + LiteralPattern:28( + 1) + StarredPattern:28( + WildcardPattern:28()) + LiteralPattern:28( + 2))) + Body( + PassStmt:29()))) [case testMappingPattern] match a: From eeca25b5e228108c331514422bfe63d43836746b Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Fri, 12 Mar 2021 12:59:20 +0100 Subject: [PATCH 12/76] Added match statement support to SemanticAnalyzerPreAnalysis --- mypy/reachability.py | 21 ++++++++++++++++++--- mypy/semanal_pass1.py | 15 +++++++++++++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mypy/reachability.py b/mypy/reachability.py index 5ee813dc982c..0911e1ae8eed 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -4,9 +4,9 @@ from typing_extensions import Final from mypy.nodes import ( - Expression, IfStmt, Block, AssertStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, ComparisonExpr, - StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, Import, ImportFrom, - ImportAll, LITERAL_YES + Expression, IfStmt, Block, AssertStmt, MatchStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, + ComparisonExpr, StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, + Import, ImportFrom, ImportAll, LITERAL_YES ) from mypy.options import Options from mypy.traverser import TraverserVisitor @@ -54,6 +54,21 @@ def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: break +def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None: + for i, guard in enumerate(s.guards): + # Right now we only consider the guard to infer unreachability. + # In the future this could also consider the pattern + if guard is not None: + result = infer_condition_value(guard, options) + if result in (ALWAYS_FALSE, MYPY_FALSE): + # The guard is considered always false, so we skip the case body. + mark_block_unreachable(s.bodies[i]) + elif result == MYPY_TRUE: + # This condition is false at runtime; this will affect + # import priorities. + mark_block_mypy_only(s.bodies[i]) + + def assert_will_always_fail(s: AssertStmt, options: Options) -> bool: return infer_condition_value(s.expr, options) in (ALWAYS_FALSE, MYPY_FALSE) diff --git a/mypy/semanal_pass1.py b/mypy/semanal_pass1.py index 0296788e3990..2b096f08082a 100644 --- a/mypy/semanal_pass1.py +++ b/mypy/semanal_pass1.py @@ -2,11 +2,14 @@ from mypy.nodes import ( MypyFile, AssertStmt, IfStmt, Block, AssignmentStmt, ExpressionStmt, ReturnStmt, ForStmt, - Import, ImportAll, ImportFrom, ClassDef, FuncDef + MatchStmt, Import, ImportAll, ImportFrom, ClassDef, FuncDef ) from mypy.traverser import TraverserVisitor from mypy.options import Options -from mypy.reachability import infer_reachability_of_if_statement, assert_will_always_fail +from mypy.reachability import ( + infer_reachability_of_if_statement, assert_will_always_fail, + infer_reachability_of_match_statement +) class SemanticAnalyzerPreAnalysis(TraverserVisitor): @@ -102,6 +105,14 @@ def visit_block(self, b: Block) -> None: return super().visit_block(b) + def visit_match_stmt(self, s: MatchStmt) -> None: + infer_reachability_of_match_statement(s, self.options) + for guard in s.guards: + if guard is not None: + guard.accept(self) + for body in s.bodies: + body.accept(self) + # The remaining methods are an optimization: don't visit nested expressions # of common statements, since they can have no effect. From cc76f59d3c51a66a7463654f5107c2e3b1f4ac21 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Fri, 12 Mar 2021 13:59:01 +0100 Subject: [PATCH 13/76] Also consider pattern in SemanticAnalyzerPreAnalysis --- mypy/fastparse.py | 14 +++++++++++--- mypy/nodes.py | 7 +++++-- mypy/patterns.py | 1 + mypy/reachability.py | 37 +++++++++++++++++++++++++++---------- 4 files changed, 44 insertions(+), 15 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index e249fab228c2..48da283c9432 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -33,7 +33,7 @@ ) from mypy.patterns import ( AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, - SequencePattern, StarredPattern, MappingPattern, ClassPattern + SequencePattern, StarredPattern, MappingPattern, MappingKeyPattern, ClassPattern, Pattern ) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, @@ -1323,6 +1323,9 @@ def __init__(self, options: Options, errors: Errors) -> None: self.options = options + def visit(self, node: Optional[AST]) -> Pattern: + return super().visit(node) + # MatchAs(expr pattern, identifier name) def visit_MatchAs(self, n: MatchAs) -> AsPattern: node = AsPattern(self.visit(n.pattern), n.name) @@ -1443,12 +1446,17 @@ def visit_Dict(self, n: ast3.Dict) -> MappingPattern: else: rest = None + checked_keys = self.assert_key_patterns(keys) + + node = MappingPattern(checked_keys, values, rest) + return self.set_line(node, n) + + def assert_key_patterns(self, keys: List[Pattern]) -> List[MappingKeyPattern]: for key in keys: if not isinstance(key, (LiteralPattern, ValuePattern)): raise RuntimeError("Unsupported Pattern") - node = MappingPattern(keys, values, rest) - return self.set_line(node, n) + return cast(List[MappingKeyPattern], keys) # Call(expr func, expr* args, keyword* keywords) def visit_Call(self, n: ast3.Call) -> ClassPattern: diff --git a/mypy/nodes.py b/mypy/nodes.py index 662bd4382d5f..c5016c80c438 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -16,6 +16,9 @@ from mypy.bogus_type import Bogus +if TYPE_CHECKING: + from mypy.patterns import Pattern + class Context: """Base type for objects that are valid as error message locations.""" @@ -1270,11 +1273,11 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class MatchStmt(Statement): subject = None # type: Expression - patterns = None # type: List[Expression] + patterns = None # type: List['Pattern'] guards = None # type: List[Optional[Expression]] bodies = None # type: List[Block] - def __init__(self, subject: Expression, patterns: List[Expression], + def __init__(self, subject: Expression, patterns: List['Pattern'], guards: List[Optional[Expression]], bodies: List[Block]) -> None: super().__init__() self.subject = subject diff --git a/mypy/patterns.py b/mypy/patterns.py index af5ca29df393..57131bd0a3f1 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -108,6 +108,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_starred_pattern(self) +# When changing this make sure to also change assert_key_pattern in fastparse.py PatternConverter MappingKeyPattern = Union[LiteralPattern, ValuePattern] diff --git a/mypy/reachability.py b/mypy/reachability.py index 0911e1ae8eed..2a300984d4b1 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -9,6 +9,7 @@ Import, ImportFrom, ImportAll, LITERAL_YES ) from mypy.options import Options +from mypy.patterns import Pattern, WildcardPattern, CapturePattern from mypy.traverser import TraverserVisitor from mypy.literals import literal @@ -56,17 +57,26 @@ def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None: for i, guard in enumerate(s.guards): - # Right now we only consider the guard to infer unreachability. - # In the future this could also consider the pattern + pattern_value = infer_pattern_value(s.patterns[i]) + if guard is not None: - result = infer_condition_value(guard, options) - if result in (ALWAYS_FALSE, MYPY_FALSE): - # The guard is considered always false, so we skip the case body. - mark_block_unreachable(s.bodies[i]) - elif result == MYPY_TRUE: - # This condition is false at runtime; this will affect - # import priorities. - mark_block_mypy_only(s.bodies[i]) + guard_value = infer_condition_value(guard, options) + else: + guard_value = ALWAYS_TRUE + + if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) \ + or guard_value in (ALWAYS_FALSE, MYPY_FALSE): + # The case is considered always false, so we skip the case body. + mark_block_unreachable(s.bodies[i]) + elif pattern_value in (ALWAYS_FALSE, MYPY_TRUE) \ + and guard_value in (ALWAYS_TRUE, MYPY_TRUE): + for body in s.bodies[i + 1:]: + mark_block_unreachable(body) + + if guard_value == MYPY_TRUE: + # This condition is false at runtime; this will affect + # import priorities. + mark_block_mypy_only(s.bodies[i]) def assert_will_always_fail(s: AssertStmt, options: Options) -> bool: @@ -124,6 +134,13 @@ def infer_condition_value(expr: Expression, options: Options) -> int: return result +def infer_pattern_value(pattern: Pattern) -> int: + if isinstance(pattern, (WildcardPattern, CapturePattern)): + return ALWAYS_TRUE + else: + return TRUTH_VALUE_UNKNOWN + + def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int: """Consider whether expr is a comparison involving sys.version_info. From 36a45b34c2935eb6faabbd27c52d0bf43335e523 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Fri, 12 Mar 2021 14:10:59 +0100 Subject: [PATCH 14/76] Use common base classes instead of unions for patterns --- mypy/fastparse.py | 2 +- mypy/patterns.py | 32 +++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 48da283c9432..99a428cdd635 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1453,7 +1453,7 @@ def visit_Dict(self, n: ast3.Dict) -> MappingPattern: def assert_key_patterns(self, keys: List[Pattern]) -> List[MappingKeyPattern]: for key in keys: - if not isinstance(key, (LiteralPattern, ValuePattern)): + if not isinstance(key, MappingKeyPattern): raise RuntimeError("Unsupported Pattern") return cast(List[MappingKeyPattern], keys) diff --git a/mypy/patterns.py b/mypy/patterns.py index 57131bd0a3f1..b2b47a7cc9bf 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -1,5 +1,5 @@ """Classes for representing match statement patterns.""" -from typing import TypeVar, List, Any, Union, Optional +from typing import TypeVar, List, Any, Optional from mypy_extensions import trait @@ -21,6 +21,20 @@ def accept(self, visitor: PatternVisitor[T]) -> T: raise RuntimeError('Not implemented') +@trait +class AlwaysTruePattern(Pattern): + """A pattern that is always matches""" + + __slots__ = () + + +@trait +class MappingKeyPattern(Pattern): + """A pattern that can be used as a key in a mapping pattern""" + + __slots__ = () + + class AsPattern(Pattern): pattern = None # type: Pattern name = None # type: str @@ -46,7 +60,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: # TODO: Do we need subclasses for the types of literals? -class LiteralPattern(Pattern): +class LiteralPattern(MappingKeyPattern): value = None # type: Any def __init__(self, value: Any): @@ -57,7 +71,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_literal_pattern(self) -class CapturePattern(Pattern): +class CapturePattern(AlwaysTruePattern): name = None # type: NameExpr def __init__(self, name: NameExpr): @@ -68,12 +82,12 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_capture_pattern(self) -class WildcardPattern(Pattern): +class WildcardPattern(AlwaysTruePattern): def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_wildcard_pattern(self) -class ValuePattern(Pattern): +class ValuePattern(MappingKeyPattern): expr = None # type: MemberExpr def __init__(self, expr: MemberExpr): @@ -98,9 +112,9 @@ def accept(self, visitor: PatternVisitor[T]) -> T: # TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our # type hierarchy. Should it be? class StarredPattern(Pattern): - capture = None # type: Union[CapturePattern, WildcardPattern] + capture = None # type: AlwaysTruePattern - def __init__(self, capture: Union[CapturePattern, WildcardPattern]): + def __init__(self, capture: AlwaysTruePattern): super().__init__() self.capture = capture @@ -108,10 +122,6 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_starred_pattern(self) -# When changing this make sure to also change assert_key_pattern in fastparse.py PatternConverter -MappingKeyPattern = Union[LiteralPattern, ValuePattern] - - class MappingPattern(Pattern): keys = None # type: List[MappingKeyPattern] values = None # type: List[Pattern] From 7facb0d47a26a91941c96bc9e9504239cf9d61c9 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Fri, 12 Mar 2021 18:28:17 +0100 Subject: [PATCH 15/76] Added match statement support for variable renaming (allow-redefinition) --- mypy/renaming.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mypy/renaming.py b/mypy/renaming.py index 56eb623afe8a..584b34a99b87 100644 --- a/mypy/renaming.py +++ b/mypy/renaming.py @@ -3,9 +3,10 @@ from mypy.nodes import ( Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr, - WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, StarExpr, ImportFrom, - MemberExpr, IndexExpr, Import, ClassDef + WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr, + ImportFrom, MemberExpr, IndexExpr, Import, ClassDef ) +from mypy.patterns import CapturePattern from mypy.traverser import TraverserVisitor # Scope kinds @@ -173,6 +174,21 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lvalue in s.lvalues: self.analyze_lvalue(lvalue) + def visit_match_stmt(self, s: MatchStmt) -> None: + for i in range(len(s.patterns)): + self.enter_block() + s.patterns[i].accept(self) + guard = s.guards[i] + if guard is not None: + guard.accept(self) + # We already entered a block, so visit this block's statements directly + for stmt in s.bodies[i].body: + stmt.accept(self) + self.leave_block() + + def visit_capture_pattern(self, p: CapturePattern) -> None: + self.analyze_lvalue(p.name) + def analyze_lvalue(self, lvalue: Lvalue, is_nested: bool = False) -> None: """Process assignment; in particular, keep track of (re)defined names. From 1bc8a4bac2588638da2616cfd721ea9ee619d1bf Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Sat, 13 Mar 2021 12:55:04 +0100 Subject: [PATCH 16/76] Added match statement support for SemanticAnalyzer --- mypy/fastparse.py | 2 +- mypy/patterns.py | 4 +- mypy/semanal.py | 62 ++++++- mypy/strconv.py | 3 +- mypy/test/helpers.py | 2 + mypy/test/testsemanal.py | 7 + test-data/unit/parse-python310.test | 21 ++- test-data/unit/semanal-errors-python310.test | 25 +++ test-data/unit/semanal-python310.test | 184 +++++++++++++++++++ 9 files changed, 300 insertions(+), 10 deletions(-) create mode 100644 test-data/unit/semanal-errors-python310.test create mode 100644 test-data/unit/semanal-python310.test diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 99a428cdd635..1c8cf637cd23 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1328,7 +1328,7 @@ def visit(self, node: Optional[AST]) -> Pattern: # MatchAs(expr pattern, identifier name) def visit_MatchAs(self, n: MatchAs) -> AsPattern: - node = AsPattern(self.visit(n.pattern), n.name) + node = AsPattern(self.visit(n.pattern), NameExpr(n.name)) return self.set_line(node, n) # MatchOr(expr* pattern) diff --git a/mypy/patterns.py b/mypy/patterns.py index b2b47a7cc9bf..f63e7c16057c 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -37,9 +37,9 @@ class MappingKeyPattern(Pattern): class AsPattern(Pattern): pattern = None # type: Pattern - name = None # type: str + name = None # type: NameExpr - def __init__(self, pattern: Pattern, name: str) -> None: + def __init__(self, pattern: Pattern, name: NameExpr) -> None: super().__init__() self.pattern = pattern self.name = name diff --git a/mypy/semanal.py b/mypy/semanal.py index 115e490d3e0e..0f324562ddef 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -76,7 +76,12 @@ get_nongen_builtins, get_member_expr_fullname, REVEAL_TYPE, REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_source_versions, EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, - ParamSpecExpr + ParamSpecExpr, MatchStmt +) +from mypy.patterns import ( + AsPattern, OrPattern, CapturePattern, ValuePattern, SequencePattern, StarredPattern, + MappingPattern, ClassPattern + ) from mypy.tvar_scope import TypeVarLikeScope from mypy.typevars import fill_typevars @@ -116,8 +121,8 @@ from mypy.semanal_enum import EnumCallAnalyzer from mypy.semanal_newtype import NewTypeAnalyzer from mypy.reachability import ( - infer_reachability_of_if_statement, infer_condition_value, ALWAYS_FALSE, ALWAYS_TRUE, - MYPY_TRUE, MYPY_FALSE + infer_reachability_of_if_statement, infer_reachability_of_match_statement, + infer_condition_value, ALWAYS_FALSE, ALWAYS_TRUE, MYPY_TRUE, MYPY_FALSE ) from mypy.mro import calculate_mro, MroError @@ -3501,6 +3506,17 @@ def visit_exec_stmt(self, s: ExecStmt) -> None: if s.locals: s.locals.accept(self) + def visit_match_stmt(self, s: MatchStmt) -> None: + self.statement = s + infer_reachability_of_match_statement(s, self.options) + s.subject.accept(self) + for i in range(len(s.patterns)): + s.patterns[i].accept(self) + guard = s.guards[i] + if guard is not None: + guard.accept(self) + self.visit_block(s.bodies[i]) + # # Expressions # @@ -3965,6 +3981,46 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: self.fail("'await' outside coroutine ('async def')", expr) expr.expr.accept(self) + # + # Patterns + # + + def visit_as_pattern(self, p: AsPattern) -> None: + p.pattern.accept(self) + self.analyze_lvalue(p.name) + + def visit_or_pattern(self, p: OrPattern) -> None: + for pattern in p.patterns: + pattern.accept(self) + + def visit_capture_pattern(self, p: CapturePattern) -> None: + self.analyze_lvalue(p.name) + + def visit_value_pattern(self, p: ValuePattern) -> None: + p.expr.accept(self) + + def visit_sequence_pattern(self, p: SequencePattern) -> None: + for pattern in p.patterns: + pattern.accept(self) + + def visit_starred_pattern(self, p: StarredPattern) -> None: + p.capture.accept(self) + + def visit_mapping_pattern(self, p: MappingPattern) -> None: + for key in p.keys: + key.accept(self) + for value in p.values: + value.accept(self) + if p.rest is not None: + p.rest.accept(self) + + def visit_class_pattern(self, p: ClassPattern) -> None: + p.class_ref.accept(self) + for p in p.positionals: + p.accept(self) + for v in p.keyword_values: + v.accept(self) + # # Lookup functions # diff --git a/mypy/strconv.py b/mypy/strconv.py index ecccea57a636..05296edb6992 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -549,8 +549,7 @@ def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: return self.dump([o.type], o) def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> str: - # We display the name first for better readability - return self.dump([o.name, o.pattern], o) + return self.dump([o.pattern, o.name], o) def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: return self.dump(o.patterns, o) diff --git a/mypy/test/helpers.py b/mypy/test/helpers.py index 077c2f369cda..298f30a88116 100644 --- a/mypy/test/helpers.py +++ b/mypy/test/helpers.py @@ -284,6 +284,8 @@ def num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int: def testfile_pyversion(path: str) -> Tuple[int, int]: if path.endswith('python2.test'): return defaults.PYTHON2_VERSION + elif path.endswith('python310.test'): + return 3, 10 else: return defaults.PYTHON3_VERSION diff --git a/mypy/test/testsemanal.py b/mypy/test/testsemanal.py index e42a84e8365b..afaa911ae50e 100644 --- a/mypy/test/testsemanal.py +++ b/mypy/test/testsemanal.py @@ -1,6 +1,7 @@ """Semantic analyzer test cases""" import os.path +import sys from typing import Dict, List @@ -35,6 +36,10 @@ 'semanal-python2.test'] +if sys.version_info >= (3, 10): + semanal_files.append('semanal-python310.test') + + def get_semanal_options(program_text: str, testcase: DataDrivenTestCase) -> Options: options = parse_options(program_text, testcase, 1) options.use_builtins_fixtures = True @@ -101,6 +106,8 @@ def test_semanal(testcase: DataDrivenTestCase) -> None: class SemAnalErrorSuite(DataSuite): files = ['semanal-errors.test'] + if sys.version_info >= (3, 10): + semanal_files.append('semanal-errors-python310.test') def run_case(self, testcase: DataDrivenTestCase) -> None: test_semanal_error(testcase) diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test index e09e2ffac9fe..9f3669f5b0b8 100644 --- a/test-data/unit/parse-python310.test +++ b/test-data/unit/parse-python310.test @@ -16,6 +16,23 @@ MypyFile:1( Body( PassStmt:3()))) + +[case testTupleMatch] +match a, b: + case 1: + pass +[out] +MypyFile:1( + MatchStmt:1( + TupleExpr:1( + NameExpr(a) + NameExpr(b)) + Pattern( + LiteralPattern:2( + 1)) + Body( + PassStmt:3()))) + [case testMatchWithGuard] match a: case 1 if f(): @@ -56,9 +73,9 @@ MypyFile:1( NameExpr(a) Pattern( AsPattern:2( - b LiteralPattern:2( - 1))) + 1) + NameExpr(b))) Body( PassStmt:3()))) diff --git a/test-data/unit/semanal-errors-python310.test b/test-data/unit/semanal-errors-python310.test new file mode 100644 index 000000000000..e8eda6ca87d7 --- /dev/null +++ b/test-data/unit/semanal-errors-python310.test @@ -0,0 +1,25 @@ +[case testMatchUndefinedSubject] +import typing +match x: + case _: + pass +[out] +main:2: error: Name 'x' is not defined + +[case testNoneBindingWildcardPattern] +import typing +x = 1 +match x: + case _: + _ +[out] +main:5: error: Name '_' is not defined + +[case testNoneBindingStarredWildcardPattern] +import typing +x = 1 +match x: + case [*_]: + _ +[out] +main:5: error: Name '_' is not defined diff --git a/test-data/unit/semanal-python310.test b/test-data/unit/semanal-python310.test new file mode 100644 index 000000000000..b1e6fef0bf81 --- /dev/null +++ b/test-data/unit/semanal-python310.test @@ -0,0 +1,184 @@ +-- Python 3.10 semantic analysis test cases. + +[case testCapturePattern] +x = 1 +match x: + case a: + a +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + CapturePattern:3( + NameExpr(a* [__main__.a]))) + Body( + ExpressionStmt:4( + NameExpr(a [__main__.a]))))) + +[case testCapturePatternOutliving] +x = 1 +match x: + case a: + pass +a +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + CapturePattern:3( + NameExpr(a* [__main__.a]))) + Body( + PassStmt:4())) + ExpressionStmt:5( + NameExpr(a [__main__.a]))) + +[case testNestedCapturePatterns] +x = 1 +match x: + case ([a], {'k': b}): + a + b +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + SequencePattern:3( + SequencePattern:3( + CapturePattern:3( + NameExpr(a* [__main__.a]))) + MappingPattern:3( + Key( + LiteralPattern:3( + 'k')) + Value( + CapturePattern:3( + NameExpr(b* [__main__.b])))))) + Body( + ExpressionStmt:4( + NameExpr(a [__main__.a])) + ExpressionStmt:5( + NameExpr(b [__main__.b]))))) + +[case testAsPattern] +x = 1 +match x: + case 1 as a: + a +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + LiteralPattern:3( + 1) + NameExpr(a* [__main__.a]))) + Body( + ExpressionStmt:4( + NameExpr(a [__main__.a]))))) + +[case testGuard] +x = 1 +a = 1 +match x: + case 1 if a: + pass +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + AssignmentStmt:2( + NameExpr(a* [__main__.a]) + IntExpr(1)) + MatchStmt:3( + NameExpr(x [__main__.x]) + Pattern( + LiteralPattern:4( + 1)) + Guard( + NameExpr(a [__main__.a])) + Body( + PassStmt:5()))) + +[case testCapturePatternInGuard] +x = 1 +match x: + case a if a: + pass +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + CapturePattern:3( + NameExpr(a* [__main__.a]))) + Guard( + NameExpr(a [__main__.a])) + Body( + PassStmt:4()))) + +[case testAsPatternInGuard] +x = 1 +match x: + case 1 as a if a: + pass +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + LiteralPattern:3( + 1) + NameExpr(a* [__main__.a]))) + Guard( + NameExpr(a [__main__.a])) + Body( + PassStmt:4()))) + +[case testValuePattern] +import _a + +x = 1 +match x: + case _a.b: + pass +[file _a.py] +b = 1 +[out] +MypyFile:1( + Import:1(_a) + AssignmentStmt:3( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:4( + NameExpr(x [__main__.x]) + Pattern( + ValuePattern:5( + MemberExpr:5( + NameExpr(_a) + b [_a.b]))) + Body( + PassStmt:6()))) From d26d81c3c3305d8912d639a06ee5561241c6115f Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 5 Apr 2021 21:53:15 +0200 Subject: [PATCH 17/76] Moved int from types to builtin stubs in order to extend int --- test-data/unit/lib-stub/builtins.pyi | 1 + test-data/unit/lib-stub/types.pyi | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test-data/unit/lib-stub/builtins.pyi b/test-data/unit/lib-stub/builtins.pyi index 7ba4002ed4ac..70c24743b62a 100644 --- a/test-data/unit/lib-stub/builtins.pyi +++ b/test-data/unit/lib-stub/builtins.pyi @@ -12,6 +12,7 @@ class type: class int: def __add__(self, other: int) -> int: pass class float: pass +class bool(int): pass class str: pass class bytes: pass diff --git a/test-data/unit/lib-stub/types.pyi b/test-data/unit/lib-stub/types.pyi index 02113aea3834..8c3fd838d937 100644 --- a/test-data/unit/lib-stub/types.pyi +++ b/test-data/unit/lib-stub/types.pyi @@ -4,7 +4,5 @@ _T = TypeVar('_T') def coroutine(func: _T) -> _T: pass -class bool: ... - class ModuleType: __file__ = ... # type: str From 3effadc803c106361edb7dcd24660799e3b10327 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 6 Apr 2021 14:50:12 +0200 Subject: [PATCH 18/76] Added initial version of match statement type checking --- mypy/checker.py | 53 ++- mypy/checkexpr.py | 10 +- mypy/checkpattern.py | 312 +++++++++++++ mypy/errorcodes.py | 2 + mypy/fastparse.py | 15 +- mypy/message_registry.py | 1 + mypy/messages.py | 2 +- mypy/patterns.py | 28 +- mypy/semanal.py | 4 +- mypy/strconv.py | 2 +- mypy/traverser.py | 5 +- test-data/unit/check-python310.test | 604 ++++++++++++++++++++------ test-data/unit/parse-python310.test | 200 ++++++--- test-data/unit/semanal-python310.test | 12 +- 14 files changed, 1035 insertions(+), 215 deletions(-) create mode 100644 mypy/checkpattern.py diff --git a/mypy/checker.py b/mypy/checker.py index fce7e7d7a08e..499c9d07fc60 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -10,6 +10,7 @@ ) from typing_extensions import Final +from mypy.checkpattern import PatternChecker from mypy.errors import Errors, report_internal_error from mypy.nodes import ( SymbolTable, Statement, MypyFile, Var, Expression, Lvalue, Node, @@ -24,8 +25,7 @@ Import, ImportFrom, ImportAll, ImportBase, TypeAlias, ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF, CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr, - is_final_node, - ARG_NAMED) + is_final_node, ARG_NAMED, MatchStmt) from mypy import nodes from mypy.literals import literal, literal_hash, Key from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any @@ -1385,6 +1385,19 @@ def check_setattr_method(self, typ: Type, context: Context) -> None: if not is_subtype(typ, method_type): self.msg.invalid_signature_for_special_method(typ, context, '__setattr__') + def check_match_args(self, var: Var, typ: Type, context: Context) -> None: + """Check that __match_args__ is final and contains literal strings""" + + if not var.is_final: + self.note("__match_args__ must be final for checking of match statements to work", + context, code=codes.LITERAL_REQ) + + typ = get_proper_type(typ) + if not isinstance(typ, TupleType) or \ + not all([is_string_literal(item) for item in typ.items]): + self.msg.note("__match_args__ must be a tuple containing string literals for checking " + "of match statements to work", context, code=codes.LITERAL_REQ) + def expand_typevars(self, defn: FuncItem, typ: CallableType) -> List[Tuple[FuncItem, CallableType]]: # TODO use generator @@ -2066,6 +2079,10 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type else: self.check_getattr_method(signature, lvalue, name) + if name == '__match_args__' and inferred is not None: + typ = self.expr_checker.accept(rvalue) + self.check_match_args(inferred, typ, lvalue) + # Defer PartialType's super type checking. if (isinstance(lvalue, RefExpr) and not (isinstance(lvalue_type, PartialType) and lvalue_type.type is None)): @@ -3704,6 +3721,33 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: self.binder.handle_continue() return None + def visit_match_stmt(self, s: MatchStmt) -> None: + with self.binder.frame_context(can_skip=False, fall_through=0): + t = get_proper_type(self.expr_checker.accept(s.subject)) + + if isinstance(t, DeletedType): + self.msg.deleted_as_rvalue(t, s) + + pattern_checker = PatternChecker(self, self.msg, self.plugin, s.subject, t) + + for p, g, b in zip(s.patterns, s.guards, s.bodies): + if not b.is_unreachable: + type_map = pattern_checker.check_pattern(p) + else: + type_map = None + with self.binder.frame_context(can_skip=True, fall_through=2): + self.push_type_map(type_map) + if g is not None: + gt = get_proper_type(self.expr_checker.accept(g)) + + if isinstance(gt, DeletedType): + self.msg.deleted_as_rvalue(gt, s) + + if_map, else_map = self.find_isinstance_check(g) + + self.push_type_map(if_map) + self.accept(b) + def make_fake_typeinfo(self, curr_module_fullname: str, class_gen_name: str, @@ -5801,6 +5845,11 @@ def is_private(node_name: str) -> bool: return node_name.startswith('__') and not node_name.endswith('__') +def is_string_literal(typ: Type) -> bool: + strs = try_getting_str_literals_from_type(typ) + return strs is not None and len(strs) == 1 + + def has_bool_item(typ: ProperType) -> bool: """Return True if type is 'bool' or a union with a 'bool' item.""" if is_named_instance(typ, 'builtins.bool'): diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index cbfa5dbc0b4e..6fa37a446367 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3015,7 +3015,11 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) else: return union - def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: + def visit_typeddict_index_expr(self, td_type: TypedDictType, + index: Expression, + local_errors: Optional[MessageBuilder] = None + ) -> Type: + local_errors = local_errors or self.msg if isinstance(index, (StrExpr, UnicodeExpr)): key_names = [index.value] else: @@ -3035,14 +3039,14 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) and key_type.fallback.type.fullname != 'builtins.bytes'): key_names.append(key_type.value) else: - self.msg.typeddict_key_must_be_string_literal(td_type, index) + local_errors.typeddict_key_must_be_string_literal(td_type, index) return AnyType(TypeOfAny.from_error) value_types = [] for key_name in key_names: value_type = td_type.items.get(key_name) if value_type is None: - self.msg.typeddict_key_not_found(td_type, key_name, index) + local_errors.typeddict_key_not_found(td_type, key_name, index) return AnyType(TypeOfAny.from_error) else: value_types.append(value_type) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py new file mode 100644 index 000000000000..a66b9b3b47a6 --- /dev/null +++ b/mypy/checkpattern.py @@ -0,0 +1,312 @@ +"""Pattern checker. This file is conceptually part of TypeChecker.""" +from typing import List, Optional, Union, Tuple + +from mypy import message_registry +from mypy.expandtype import expand_type_by_instance +from mypy.join import join_types + +from mypy.messages import MessageBuilder +from mypy.nodes import Expression, NameExpr, ARG_POS, TypeAlias, TypeInfo +from mypy.patterns import ( + Pattern, AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, + SequencePattern, StarredPattern, MappingPattern, ClassPattern, MappingKeyPattern +) +from mypy.plugin import Plugin +from mypy.subtypes import is_subtype, find_member, is_equivalent +from mypy.typeops import try_getting_str_literals_from_type +from mypy.types import ( + ProperType, AnyType, TypeOfAny, Instance, Type, NoneType, UninhabitedType, get_proper_type, + TypedDictType, TupleType +) +from mypy.typevars import fill_typevars +from mypy.visitor import PatternVisitor +import mypy.checker + + +class PatternChecker(PatternVisitor[Optional[Type]]): + """Pattern checker. + + This class checks if a pattern can match a type, what the type can be narrowed to, and what + type capture patterns should be inferred as. + """ + + # Some services are provided by a TypeChecker instance. + chk = None # type: mypy.checker.TypeChecker + # This is shared with TypeChecker, but stored also here for convenience. + msg = None # type: MessageBuilder + # Currently unused + plugin = None # type: Plugin + # The expression being matched against the pattern + subject = None # type: Expression + # Type of the subject to check the (sub)pattern against + type_stack = [] # type: List[ProperType] + + def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: Plugin, + subject: Expression, subject_type: ProperType) -> None: + self.chk = chk + self.msg = msg + self.plugin = plugin + self.subject = subject + self.type_stack.append(subject_type) + + def check_pattern(self, o: Pattern) -> 'mypy.checker.TypeMap': + pattern_type = self.visit(o) + if pattern_type is None: + # This case is unreachable + return None + elif is_equivalent(self.type_stack[-1], pattern_type): + # No need to narrow + return {} + else: + return {self.subject: pattern_type} + + def visit(self, o: Pattern) -> Optional[Type]: + return o.accept(self) + + def visit_as_pattern(self, o: AsPattern) -> Optional[Type]: + return self.type_stack[-1] + + def visit_or_pattern(self, o: OrPattern) -> Optional[Type]: + return self.type_stack[-1] + + def visit_literal_pattern(self, o: LiteralPattern) -> Optional[Type]: + literal_type = self.get_literal_type(o.value) + return get_more_specific_type(literal_type, self.type_stack[-1]) + + def get_literal_type(self, l: Union[int, complex, float, str, bytes, None]) -> Type: + # TODO: Should we use ExprNodes instead of the raw value here? + if isinstance(l, int): + return self.chk.named_type("builtins.int") + elif isinstance(l, complex): + return self.chk.named_type("builtins.complex") + elif isinstance(l, float): + return self.chk.named_type("builtins.float") + elif isinstance(l, str): + return self.chk.named_type("builtins.str") + elif isinstance(l, bytes): + return self.chk.named_type("builtins.bytes") + elif isinstance(l, bool): + return self.chk.named_type("builtins.bool") + elif l is None: + return NoneType() + else: + assert False, "Invalid literal in literal pattern" + + def visit_capture_pattern(self, o: CapturePattern) -> Optional[Type]: + self.check_capture(o.name) + return self.type_stack[-1] + + def check_capture(self, capture: NameExpr) -> None: + capture_type, _, inferred = self.chk.check_lvalue(capture) + if capture_type: + self.chk.check_subtype(capture_type, self.type_stack[-1], capture, + msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, + subtype_label="pattern captures type", + supertype_label="variable has type") + else: + assert inferred is not None + self.chk.infer_variable_type(inferred, capture, self.type_stack[-1], self.subject) + + def visit_wildcard_pattern(self, o: WildcardPattern) -> Optional[Type]: + return self.type_stack[-1] + + def visit_value_pattern(self, o: ValuePattern) -> Optional[Type]: + typ = self.chk.expr_checker.accept(o.expr) + return get_more_specific_type(typ, self.type_stack[-1]) + + def visit_sequence_pattern(self, o: SequencePattern) -> Optional[Type]: + current_type = self.type_stack[-1] + inner_type = get_proper_type(self.get_sequence_type(current_type)) + if inner_type is None: + if is_subtype(self.chk.named_type("typing.Iterable"), current_type): + # Current type is more general, but the actual value could still be iterable + inner_type = self.chk.named_type("builtins.object") + else: + # Pattern can't match + return None + + assert isinstance(current_type, Instance) + self.type_stack.append(inner_type) + new_inner_type = UninhabitedType() # type: Type + for p in o.patterns: + pattern_type = self.visit(p) + if pattern_type is None: + return None + new_inner_type = join_types(new_inner_type, pattern_type) + self.type_stack.pop() + iterable = self.chk.named_generic_type("typing.Iterable", [new_inner_type]) + if self.chk.type_is_iterable(current_type): + empty_type = fill_typevars(current_type.type) + partial_type = expand_type_by_instance(empty_type, iterable) + new_type = expand_type_by_instance(partial_type, current_type) + else: + new_type = iterable + + if is_subtype(new_type, current_type): + return new_type + else: + return current_type + + def get_sequence_type(self, t: ProperType) -> Optional[Type]: + if isinstance(t, AnyType): + return AnyType(TypeOfAny.from_another_any, t) + + if self.chk.type_is_iterable(t) and isinstance(t, Instance): + return self.chk.iterable_item_type(t) + else: + return None + + def visit_starred_pattern(self, o: StarredPattern) -> Optional[Type]: + if not isinstance(o.capture, WildcardPattern): + list_type = self.chk.named_generic_type('builtins.list', [self.type_stack[-1]]) + self.type_stack.append(list_type) + self.visit_capture_pattern(o.capture) + self.type_stack.pop() + return self.type_stack[-1] + + def visit_mapping_pattern(self, o: MappingPattern) -> Optional[Type]: + current_type = self.type_stack[-1] + can_match = True + for key, value in zip(o.keys, o.values): + inner_type = self.get_mapping_item_type(o, current_type, key) + if inner_type is None: + can_match = False + inner_type = self.chk.named_type("builtins.object") + inner_type = get_proper_type(inner_type) + self.type_stack.append(inner_type) + if self.visit(value) is None: + can_match = False + self.type_stack.pop() + if can_match: + return self.type_stack[-1] + else: + return None + + def get_mapping_item_type(self, + pattern: MappingPattern, + mapping_type: Type, + key_pattern: MappingKeyPattern + ) -> Optional[Type]: + mapping_type = get_proper_type(mapping_type) + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + if isinstance(mapping_type, TypedDictType): + result = self.chk.expr_checker.visit_typeddict_index_expr(mapping_type, + key_pattern.expr, + local_errors=local_errors + ) # type: Optional[Type] + # If we can't determine the type statically fall back to treating it as a normal + # mapping + if local_errors.is_errors(): + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + result = self.get_simple_mapping_item_type(pattern, + mapping_type, + key_pattern, + local_errors) + + if local_errors.is_errors(): + result = None + else: + result = self.get_simple_mapping_item_type(pattern, + mapping_type, + key_pattern, + local_errors) + return result + + def get_simple_mapping_item_type(self, + pattern: MappingPattern, + mapping_type: Type, + key_pattern: MappingKeyPattern, + local_errors: MessageBuilder + ) -> Type: + result, _ = self.chk.expr_checker.check_method_call_by_name('__getitem__', + mapping_type, + [key_pattern.expr], + [ARG_POS], + pattern, + local_errors=local_errors) + return result + + def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: + can_match = True + + class_name = o.class_ref.fullname + assert class_name is not None + sym = self.chk.lookup_qualified(class_name) + if isinstance(sym.node, (TypeAlias, TypeInfo)): + typ = self.chk.named_type(class_name) + else: + self.msg.fail("Class pattern must be a type. Found '{}'".format(sym.type), o.class_ref) + typ = self.chk.named_type("builtins.object") + can_match = False + match_args_type = find_member("__match_args__", typ, typ) + + if match_args_type is None and can_match: + if len(o.positionals) >= 1: + self.msg.fail("Class doesn't define __match_args__", o) + + proper_match_args_type = get_proper_type(match_args_type) + if isinstance(proper_match_args_type, TupleType): + match_arg_names = get_match_arg_names(proper_match_args_type) + + if len(o.positionals) > len(match_arg_names): + self.msg.fail("Too many positional patterns for class pattern", o) + match_arg_names += [None] * (len(o.positionals) - len(match_arg_names)) + else: + match_arg_names = [None] * len(o.positionals) + + positional_names = set() + keyword_pairs = [] # type: List[Tuple[Optional[str], Pattern]] + + for arg_name, pos in zip(match_arg_names, o.positionals): + keyword_pairs.append((arg_name, pos)) + positional_names.add(arg_name) + + keyword_names = set() + for key, value in zip(o.keyword_keys, o.keyword_values): + keyword_pairs.append((key, value)) + if key in match_arg_names: + self.msg.fail("Keyword '{}' already matches a positional pattern".format(key), + value) + elif key in keyword_names: + self.msg.fail("Duplicate keyword pattern '{}'".format(key), value) + keyword_names.add(key) + + for keyword, pattern in keyword_pairs: + if keyword is not None: + key_type = find_member(keyword, typ, typ) + if key_type is None: + key_type = self.chk.named_type("builtins.object") + else: + key_type = self.chk.named_type("builtins.object") + + self.type_stack.append(get_proper_type(key_type)) + if self.visit(pattern) is None: + can_match = False + self.type_stack.pop() + + if can_match: + return get_more_specific_type(self.type_stack[-1], typ) + else: + return None + + +def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: + args = [] # type: List[Optional[str]] + for item in typ.items: + values = try_getting_str_literals_from_type(item) + if values is None or len(values) != 1: + args.append(None) + else: + args.append(values[0]) + return args + + +def get_more_specific_type(left: Type, right: Type) -> Optional[Type]: + if is_subtype(left, right): + return left + elif is_subtype(right, left): + return right + else: + return None diff --git a/mypy/errorcodes.py b/mypy/errorcodes.py index b0d0ad1f1cbe..ce36eb143a29 100644 --- a/mypy/errorcodes.py +++ b/mypy/errorcodes.py @@ -92,6 +92,8 @@ def __str__(self) -> str: 'General') # type: Final EXIT_RETURN = ErrorCode( 'exit-return', "Warn about too general return type for '__exit__'", 'General') # type: Final +LITERAL_REQ = ErrorCode( + "literal-required", "Check that value is a literal", 'General') # type: Final # These error codes aren't enabled by default. NO_UNTYPED_DEF = ErrorCode( diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 1c8cf637cd23..53490ae8d594 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1318,6 +1318,9 @@ class PatternConverter(Converter): # Errors is optional is superclass, but not here errors = None # type: Errors + has_sequence = False # type: bool + has_mapping = False # type: bool + def __init__(self, options: Options, errors: Errors) -> None: super().__init__(errors) @@ -1340,7 +1343,7 @@ def visit_MatchOr(self, n: MatchOr) -> OrPattern: def visit_Constant(self, n: Constant) -> LiteralPattern: val = n.value if val is None or isinstance(val, (bool, int, float, complex, str, bytes)): - node = LiteralPattern(val) + node = LiteralPattern(val, ASTConverter(self.options, False, self.errors).visit(n)) else: raise RuntimeError("Pattern not implemented for " + str(type(val))) return self.set_line(node, n) @@ -1354,9 +1357,9 @@ def visit_UnaryOp(self, n: ast3.UnaryOp) -> LiteralPattern: value = self.assert_numeric_constant(n.operand) if isinstance(n.op, ast3.UAdd): - node = LiteralPattern(value) + node = LiteralPattern(value, ASTConverter(self.options, False, self.errors).visit(n)) elif isinstance(n.op, ast3.USub): - node = LiteralPattern(-value) + node = LiteralPattern(-value, ASTConverter(self.options, False, self.errors).visit(n)) else: raise RuntimeError("Pattern not implemented for " + str(type(n.op))) @@ -1374,9 +1377,11 @@ def visit_BinOp(self, n: ast3.BinOp) -> LiteralPattern: raise RuntimeError("Unsupported pattern") if isinstance(n.op, ast3.Add): - node = LiteralPattern(left_val + right_val) + node = LiteralPattern(left_val + right_val, + ASTConverter(self.options, False, self.errors).visit(n)) elif isinstance(n.op, ast3.Sub): - node = LiteralPattern(left_val - right_val) + node = LiteralPattern(left_val - right_val, + ASTConverter(self.options, False, self.errors).visit(n)) else: raise RuntimeError("Unsupported pattern") diff --git a/mypy/message_registry.py b/mypy/message_registry.py index b25f055bccf8..9d14d086c37d 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -45,6 +45,7 @@ INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' # type: Final INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION = \ 'Incompatible types in string interpolation' # type: Final +INCOMPATIBLE_TYPES_IN_CAPTURE = 'Incompatible types in capture pattern' # type: Final MUST_HAVE_NONE_RETURN_TYPE = 'The return type of "{}" must be None' # type: Final INVALID_TUPLE_INDEX_TYPE = 'Invalid tuple index type' # type: Final TUPLE_INDEX_OUT_OF_RANGE = 'Tuple index out of range' # type: Final diff --git a/mypy/messages.py b/mypy/messages.py index 5d0112be06b0..792cb4e18053 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1156,7 +1156,7 @@ def typeddict_key_must_be_string_literal( context: Context) -> None: self.fail( 'TypedDict key must be a string literal; expected one of {}'.format( - format_item_name_list(typ.items.keys())), context) + format_item_name_list(typ.items.keys())), context, code=codes.LITERAL_REQ) def typeddict_key_not_found( self, diff --git a/mypy/patterns.py b/mypy/patterns.py index f63e7c16057c..cb0d7260cd4b 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -1,9 +1,9 @@ """Classes for representing match statement patterns.""" -from typing import TypeVar, List, Any, Optional +from typing import TypeVar, List, Optional, Union from mypy_extensions import trait -from mypy.nodes import Node, MemberExpr, RefExpr, NameExpr +from mypy.nodes import Node, MemberExpr, RefExpr, NameExpr, Expression from mypy.visitor import PatternVisitor # These are not real AST nodes. CPython represents patterns using the normal expression nodes. @@ -32,7 +32,11 @@ class AlwaysTruePattern(Pattern): class MappingKeyPattern(Pattern): """A pattern that can be used as a key in a mapping pattern""" - __slots__ = () + __slots__ = ("expr",) + + def __init__(self, expr: Expression) -> None: + super().__init__() + self.expr = expr class AsPattern(Pattern): @@ -59,12 +63,15 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_or_pattern(self) -# TODO: Do we need subclasses for the types of literals? +LiteralPatternType = Union[int, complex, float, str, bytes, bool, None] + + class LiteralPattern(MappingKeyPattern): - value = None # type: Any + value = None # type: LiteralPatternType + expr = None # type: Expression - def __init__(self, value: Any): - super().__init__() + def __init__(self, value: LiteralPatternType, expr: Expression): + super().__init__(expr) self.value = value def accept(self, visitor: PatternVisitor[T]) -> T: @@ -91,8 +98,7 @@ class ValuePattern(MappingKeyPattern): expr = None # type: MemberExpr def __init__(self, expr: MemberExpr): - super().__init__() - self.expr = expr + super().__init__(expr) def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_value_pattern(self) @@ -112,9 +118,9 @@ def accept(self, visitor: PatternVisitor[T]) -> T: # TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our # type hierarchy. Should it be? class StarredPattern(Pattern): - capture = None # type: AlwaysTruePattern + capture = None # type: Union[WildcardPattern, CapturePattern] - def __init__(self, capture: AlwaysTruePattern): + def __init__(self, capture: Union[WildcardPattern, CapturePattern]): super().__init__() self.capture = capture diff --git a/mypy/semanal.py b/mypy/semanal.py index 0f324562ddef..b3d21d3bd689 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -4016,8 +4016,8 @@ def visit_mapping_pattern(self, p: MappingPattern) -> None: def visit_class_pattern(self, p: ClassPattern) -> None: p.class_ref.accept(self) - for p in p.positionals: - p.accept(self) + for pos in p.positionals: + pos.accept(self) for v in p.keyword_values: v.accept(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 05296edb6992..a12798fe73e9 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -558,7 +558,7 @@ def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> str: value = o.value if isinstance(o.value, str): value = "'" + self.str_repr(o.value) + "'" - return self.dump([value], o) + return self.dump([value, o.expr], o) def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> str: return self.dump([o.name], o) diff --git a/mypy/traverser.py b/mypy/traverser.py index 8693d2038981..57f92908dbff 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -4,7 +4,7 @@ from mypy_extensions import mypyc_attr from mypy.patterns import AsPattern, OrPattern, CapturePattern, ValuePattern, SequencePattern, \ - StarredPattern, MappingPattern, ClassPattern + StarredPattern, MappingPattern, ClassPattern, LiteralPattern from mypy.visitor import NodeVisitor from mypy.nodes import ( Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, @@ -297,6 +297,9 @@ def visit_or_pattern(self, o: OrPattern) -> None: for p in o.patterns: p.accept(self) + def visit_literal_pattern(self, o: LiteralPattern) -> None: + o.expr.accept(self) + def visit_capture_pattern(self, o: CapturePattern) -> None: o.name.accept(self) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index defe1038fa63..c06a29ee59e2 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1,190 +1,546 @@ -[case testSimpleMatch] -# flags: --python-version 3.10 +-- Capture Pattern -- +[case testCapturePatternType] +class A: ... +m: A + +match m: + case a: + reveal_type(a) # N: Revealed type is '__main__.A' + +[case testCapturePatternPreexistingSame] class A: ... -class B: ... a: A +m: A + +match m: + case a: + reveal_type(a) # N: Revealed type is '__main__.A' + +[case testCapturePatternPreexistingIncompatible] +class A: ... +class B: ... +a: B +m: A + +match m: + case a: # E: Incompatible types in capture pattern (pattern captures type "B", variable has type "A") + reveal_type(a) # N: Revealed type is '__main__.B' + + +-- Literal Pattern -- +[case testLiteralPatternNarrows] m: object match m: - case ["quit"]: - a = B() # E: Incompatible types in assignment (expression has type "B", variable has type "A") - case ["look"]: - a = A() + case 1: + reveal_type(m) # N: Revealed type is 'builtins.int' -reveal_type(a) +[case testLiteralPatternAlreadyNarrower] +m: bool +match m: + case 1: + reveal_type(m) # N: Revealed type is 'builtins.bool' -[case testMatchAs] -# flags: --python-version 3.10 -class A: ... +[case testLiteralPatternUnreachable] +m: int + +match m: + case "str": + reveal_type(m) + + +-- Value Pattern -- +[case testValuePatternNarrows] +import b m: object match m: - case [x] as a: - reveal_type(a) - reveal_type(x) + case b.b: + reveal_type(m) # N: Revealed type is 'builtins.int' +[file b.py] +b: int +[case testValuePatternAlreadyNarrower] +import b +m: bool -[case testMatchOr] -# flags: --python-version 3.10 -class A: ... +match m: + case b.b: + reveal_type(m) # N: Revealed type is 'builtins.bool' +[file b.py] +b: int + +[case testValuePatternUnreachable] +import b +m: int + +match m: + case b.b: + reveal_type(m) +[file b.py] +b: str + + +-- Sequence Pattern -- +[case testSequenceCPatternCaptures] +from typing import List +m: List[int] + +match m: + case [a]: + reveal_type(a) # N: Revealed type is 'builtins.int*' +[builtins fixtures/list.pyi] + +[case testSequencePatternCapturesStarred] +from typing import Iterable +m: Iterable[int] + +match m: + case [a, *b]: + reveal_type(a) # N: Revealed type is 'builtins.int' + reveal_type(b) # N: Revealed type is 'builtins.list[builtins.int]' +[builtins fixtures/list.pyi] + +[case testSequencePatternNarrowsInner] +from typing import Iterable +m: Iterable[object] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is 'typing.Iterable[builtins.int]' + +[case testSequencePatternNarrowsOuter] +from typing import Sequence m: object match m: - case [x] | (x): - reveal_type(x) + case [1, True]: + reveal_type(m) # N: Revealed type is 'typing.Iterable[builtins.int]' + +[case testSequencePatternAlreadyNarrowerInner] +from typing import Iterable +m: Iterable[bool] +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is 'typing.Iterable[builtins.bool]' -[case testMatchWithGuard] -# flags: --python-version 3.10 +[case testSequencePatternAlreadyNarrowerOuter] +from typing import Sequence +m: Sequence[object] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is 'typing.Sequence[builtins.int]' + +[case testSequencePatternAlreadyNarrowerBoth] +from typing import Sequence +m: Sequence[bool] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is 'typing.Sequence[builtins.bool]' + +[case testNestedSequencePatternNarrowsInner] +from typing import Iterable +m: Iterable[Iterable[object]] + +match m: + case [[1], [True]]: + reveal_type(m) # N: Revealed type is 'typing.Iterable[typing.Iterable[builtins.int]]' + +[case testNestedSequencePatternNarrowsOuter] +from typing import Iterable m: object match m: - case 1 if False: - pass - case 2: - pass + case [[1], [True]]: + reveal_type(m) # N: Revealed type is 'typing.Iterable[typing.Iterable[builtins.int]]' + + +[case testSequencePatternDoesntNarrowInvariant] +from typing import List +m: List[object] -[case testLiteralPattern] -# flags: --python-version 3.10 +match m: + case [1]: + reveal_type(m) # N: Revealed type is 'builtins.list[builtins.object]' +[builtins fixtures/list.pyi] + + +-- Mapping Pattern -- +[case testMappingPatternCaptures] +from typing import Dict +import b +m: Dict[str, int] + +match m: + case {"key": v}: + reveal_type(v) # N: Revealed type is 'builtins.int*' + case {b.b: v2}: + reveal_type(v2) # N: Revealed type is 'builtins.int*' +[file b.py] +b: str +[builtins fixtures/dict.pyi] + +[case testMappingPatternCapturesWrongKeyType] +# This is not actually unreachable, as a subclass of dict could accept keys with different types +from typing import Dict +import b +m: Dict[str, int] + +match m: + case {1: v}: + reveal_type(v) # N: Revealed type is 'builtins.int*' + case {b.b: v2}: + reveal_type(v2) # N: Revealed type is 'builtins.int*' +[file b.py] +b: int +[builtins fixtures/dict.pyi] + +[case testMappingPatternCapturesTypedDict] +from typing import TypedDict + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {"a": v}: + reveal_type(v) # N: Revealed type is 'builtins.str' + case {"b": v2}: + reveal_type(v2) # N: Revealed type is 'builtins.int' + case {"a": v3, "b": v4}: + reveal_type(v3) # N: Revealed type is 'builtins.str' + reveal_type(v4) # N: Revealed type is 'builtins.int' + case {"o": v5}: + reveal_type(v5) # N: Revealed type is 'builtins.object*' +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCapturesTypedDictWithLiteral] +from typing import TypedDict +import b + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {b.a: v}: + reveal_type(v) # N: Revealed type is 'builtins.str' + case {b.b: v2}: + reveal_type(v2) # N: Revealed type is 'builtins.int' + case {b.a: v3, b.b: v4}: + reveal_type(v3) # N: Revealed type is 'builtins.str' + reveal_type(v4) # N: Revealed type is 'builtins.int' + case {b.o: v5}: + reveal_type(v5) # N: Revealed type is 'builtins.object*' +[file b.py] +from typing import Final, Literal +a: Final = "a" +b: Literal["b"] = "b" +o: Final[str] = "o" +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCapturesTypedDictWithNonLiteral] +from typing import TypedDict +import b + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {b.a: v}: + reveal_type(v) # N: Revealed type is 'builtins.object*' +[file b.py] +from typing import Final, Literal +a: str +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCapturesTypedDictUnreachable] +# TypedDict keys are always str, so this is actually unreachable +from typing import TypedDict +import b + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {1: v}: + reveal_type(v) + case {b.b: v2}: + reveal_type(v2) +[file b.py] +b: int +[typing fixtures/typing-typeddict.pyi] + +-- Mapping patterns currently don't narrow -- + +-- Class Pattern -- +[case testClassPatternCapturePositional] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is 'builtins.str' + reveal_type(j) # N: Revealed type is 'builtins.int' +[builtins fixtures/tuple.pyi] + +[case testClassPatternMemberClassCapturePositional] +import b + +m: b.A + +match m: + case b.A(i, j): + reveal_type(i) # N: Revealed type is 'builtins.str' + reveal_type(j) # N: Revealed type is 'builtins.int' +[file b.py] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int +[builtins fixtures/tuple.pyi] + +[case testClassPatternCaptureKeyword] +class A: + a: str + b: int + +m: A + +match m: + case A(a=i, b=j): + reveal_type(i) # N: Revealed type is 'builtins.str' + reveal_type(j) # N: Revealed type is 'builtins.int' + +[case testClassPatternCaptureSelf] m: object + match m: - case 1: - pass - case -1: - pass - case 1+2j: - pass - case -1+2j: - pass - case 1-2j: - pass - case -1-2j: - pass - case "str": - pass - case b"bytes": - pass - case r"raw_string": - pass - case None: - pass - case True: - pass - case False: - pass + case bool(a): + reveal_type(a) # N: Revealed type is 'builtins.bool' + case bytearray(b): + reveal_type(b) # N: Revealed type is 'builtins.bytearray' + case bytes(c): + reveal_type(c) # N: Revealed type is 'builtins.bytes' + case dict(d): + reveal_type(d) # N: Revealed type is 'builtins.dict' + case float(e): + reveal_type(e) # N: Revealed type is 'builtins.float' + case frozenset(f): + reveal_type(f) # N: Revealed type is 'builtins.frozenset' + case int(g): + reveal_type(g) # N: Revealed type is 'builtins.int' + case list(h): + reveal_type(h) # N: Revealed type is 'builtins.list' + case set(i): + reveal_type(i) # N: Revealed type is 'builtins.set' + case str(j): + reveal_type(j) # N: Revealed type is 'builtins.str' + case tuple(k): + reveal_type(k) # N: Revealed type is 'builtins.tuple' +[builtins fixtures/dict.pyi] +[builtins fixtures/list.pyi] +[builtins fixtures/set.pyi] +[builtins fixtures/tuple.pyi] +[case testClassPatternNarrows] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int -[case testCapturePattern] -# flags: --python-version 3.10 m: object + match m: - case x: - pass - case longName: - pass + case A(): + reveal_type(m) # N: Revealed type is '__main__.A' + case A(i, j): + reveal_type(m) # N: Revealed type is '__main__.A' +[builtins fixtures/tuple.pyi] + +[case testClassPatternAlreadyNarrower] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int +class B(A): ... + +m: B + +match m: + case A(): + reveal_type(m) # N: Revealed type is '__main__.B' + case A(i, j): + reveal_type(m) # N: Revealed type is '__main__.B' +[builtins fixtures/tuple.pyi] + +[case testClassPatternUnreachable] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int +class B: ... + +m: B + +match m: + case A(): + reveal_type(m) + case A(i, j): + reveal_type(m) +[builtins fixtures/tuple.pyi] + +[case testClassPatternNonexistentKeyword] +class A: ... -[case testWildcardPattern] -# flags: --python-version 3.10 m: object + match m: - case _: + case A(a=j): + reveal_type(m) # N: Revealed type is '__main__.A' + reveal_type(j) # N: Revealed type is 'builtins.object' + +[case testClassPatternDuplicateKeyword] +class A: + a: str + +m: object + +match m: + case A(a=i, a=j): # E: Duplicate keyword pattern 'a' pass +[case testClassPatternDuplicateImplicitKeyword] +from typing import Final -[case testValuePattern] -# flags: --python-version 3.10 class A: - b = 1 -a = A() + __match_args__: Final = ("a",) + a: str + m: object match m: - case a.b: + case A(i, a=j): # E: Keyword 'a' already matches a positional pattern pass +[builtins fixtures/tuple.pyi] + +[case testClassPatternTooManyPositionals] +from typing import Final +class A: + __match_args__: Final = ("a", "b") + a: str + b: int -[case testGroupPattern] -# flags: --python-version 3.10 m: object match m: - case (1): + case A(i, j, k): # E: Too many positional patterns for class pattern pass +[builtins fixtures/tuple.pyi] +[case testClassPatternIsNotType] +a = 1 +m: object + +match m: + case a(i, j): # E: Class pattern must be a type. Found 'builtins.int' + reveal_type(i) + reveal_type(j) + +[case testNonFinalMatchArgs] +class A: + __match_args__ = ("a", "b") # N: __match_args__ must be final for checking of match statements to work + a: str + b: int -[case testSequencePattern] -# flags: --python-version 3.10 m: object match m: - case []: - pass - case (): - pass - case [1]: - pass - case (1,): - pass - case [1, 2, 3]: - pass - case (1, 2, 3): - pass - case [1, *a, 2]: - pass - case (1, *a, 2): - pass - case [1, *_, 2]: - pass - case (1, *_, 2): - pass + case A(i, j): + reveal_type(i) # N: Revealed type is 'builtins.object' + reveal_type(j) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] +[case testAnyTupleMatchArgs] +from typing import Tuple, Any -[case testMappingPattern] -# flags: --python-version 3.10 class A: - b = 'l' - c = 2 -a = A() + __match_args__: Tuple[Any, ...] + a: str + b: int + m: object match m: - case {'k': v}: - pass - case {a.b: v}: - pass - case {1: v}: - pass - case {a.c: v}: - pass - case {'k': v1, a.b: v2, 1: v3, a.c: v4}: - pass - case {'k1': 1, 'k2': "str", 'k3': b'bytes', 'k4': None}: - pass - case {'k': v, **r}: - pass - case {**r}: - pass + case A(i, j, k): + reveal_type(i) # N: Revealed type is 'builtins.object' + reveal_type(j) # N: Revealed type is 'builtins.object' + reveal_type(k) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] +[case testNonLiteralMatchArgs] +from typing import Final -[case testClassPattern] -# flags: --python-version 3.10 +b: str = "b" class A: - pass -class B: - __match_args__ = ('a', 'b') - a: int + __match_args__: Final = ("a", b) # N: __match_args__ must be a tuple containing string literals for checking of match statements to work + a: str b: int m: object match m: - case A(): - pass - case B(1, 2): - pass - case B(1, b=2): - pass - case B(a=1, b=2): + case A(i, j, k): # E: Too many positional patterns for class pattern pass + case A(i, j): + reveal_type(i) # N: Revealed type is 'builtins.str' + reveal_type(j) # N: Revealed type is 'builtins.object' +[builtins fixtures/tuple.pyi] + +[case testExternalMatchArgs] +from typing import Final, Literal + +args: Final = ("a", "b") +class A: + __match_args__: Final = args + a: str + b: int + +arg: Final = "a" +arg2: Literal["b"] = "b" +class B: + __match_args__: Final = (arg, arg2) + a: str + b: int + [builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test index 9f3669f5b0b8..f6e2b998d176 100644 --- a/test-data/unit/parse-python310.test +++ b/test-data/unit/parse-python310.test @@ -12,7 +12,8 @@ MypyFile:1( NameExpr(a) Pattern( LiteralPattern:2( - 1)) + 1 + IntExpr(1))) Body( PassStmt:3()))) @@ -29,7 +30,8 @@ MypyFile:1( NameExpr(b)) Pattern( LiteralPattern:2( - 1)) + 1 + IntExpr(1))) Body( PassStmt:3()))) @@ -45,7 +47,8 @@ MypyFile:1( NameExpr(a) Pattern( LiteralPattern:2( - 1)) + 1 + IntExpr(1))) Guard( CallExpr:2( NameExpr(f) @@ -74,7 +77,8 @@ MypyFile:1( Pattern( AsPattern:2( LiteralPattern:2( - 1) + 1 + IntExpr(1)) NameExpr(b))) Body( PassStmt:3()))) @@ -112,61 +116,91 @@ MypyFile:1( NameExpr(a) Pattern( LiteralPattern:2( - 1)) + 1 + IntExpr(1))) Body( PassStmt:3()) Pattern( LiteralPattern:4( - -1)) + -1 + UnaryExpr:4( + - + IntExpr(1)))) Body( PassStmt:5()) Pattern( LiteralPattern:6( - (1+2j))) + (1+2j) + OpExpr:6( + + + IntExpr(1) + ComplexExpr(2j)))) Body( PassStmt:7()) Pattern( LiteralPattern:8( - (-1+2j))) + (-1+2j) + OpExpr:8( + + + UnaryExpr:8( + - + IntExpr(1)) + ComplexExpr(2j)))) Body( PassStmt:9()) Pattern( LiteralPattern:10( - (1-2j))) + (1-2j) + OpExpr:10( + - + IntExpr(1) + ComplexExpr(2j)))) Body( PassStmt:11()) Pattern( LiteralPattern:12( - (-1-2j))) + (-1-2j) + OpExpr:12( + - + UnaryExpr:12( + - + IntExpr(1)) + ComplexExpr(2j)))) Body( PassStmt:13()) Pattern( LiteralPattern:14( - 'str')) + 'str' + StrExpr(str))) Body( PassStmt:15()) Pattern( LiteralPattern:16( - b'bytes')) + b'bytes' + BytesExpr(bytes))) Body( PassStmt:17()) Pattern( LiteralPattern:18( - 'raw_string')) + 'raw_string' + StrExpr(raw_string))) Body( PassStmt:19()) Pattern( - LiteralPattern:20()) + LiteralPattern:20( + NameExpr(None))) Body( PassStmt:21()) Pattern( LiteralPattern:22( - True)) + True + NameExpr(True))) Body( PassStmt:23()) Pattern( LiteralPattern:24( - False)) + False + NameExpr(False))) Body( PassStmt:25()))) @@ -246,7 +280,8 @@ MypyFile:1( NameExpr(a) Pattern( LiteralPattern:3( - 1)) + 1 + IntExpr(1))) Body( PassStmt:4()))) @@ -295,112 +330,136 @@ MypyFile:1( Pattern( SequencePattern:6( LiteralPattern:6( - 1))) + 1 + IntExpr(1)))) Body( PassStmt:7()) Pattern( SequencePattern:8( LiteralPattern:8( - 1))) + 1 + IntExpr(1)))) Body( PassStmt:9()) Pattern( SequencePattern:10( LiteralPattern:10( - 1))) + 1 + IntExpr(1)))) Body( PassStmt:11()) Pattern( SequencePattern:12( LiteralPattern:12( - 1) + 1 + IntExpr(1)) LiteralPattern:12( - 2) + 2 + IntExpr(2)) LiteralPattern:12( - 3))) + 3 + IntExpr(3)))) Body( PassStmt:13()) Pattern( SequencePattern:14( LiteralPattern:14( - 1) + 1 + IntExpr(1)) LiteralPattern:14( - 2) + 2 + IntExpr(2)) LiteralPattern:14( - 3))) + 3 + IntExpr(3)))) Body( PassStmt:15()) Pattern( SequencePattern:16( LiteralPattern:16( - 1) + 1 + IntExpr(1)) LiteralPattern:16( - 2) + 2 + IntExpr(2)) LiteralPattern:16( - 3))) + 3 + IntExpr(3)))) Body( PassStmt:17()) Pattern( SequencePattern:18( LiteralPattern:18( - 1) + 1 + IntExpr(1)) StarredPattern:18( CapturePattern:18( NameExpr(a))) LiteralPattern:18( - 2))) + 2 + IntExpr(2)))) Body( PassStmt:19()) Pattern( SequencePattern:20( LiteralPattern:20( - 1) + 1 + IntExpr(1)) StarredPattern:20( CapturePattern:20( NameExpr(a))) LiteralPattern:20( - 2))) + 2 + IntExpr(2)))) Body( PassStmt:21()) Pattern( SequencePattern:22( LiteralPattern:22( - 1) + 1 + IntExpr(1)) StarredPattern:22( CapturePattern:22( NameExpr(a))) LiteralPattern:22( - 2))) + 2 + IntExpr(2)))) Body( PassStmt:23()) Pattern( SequencePattern:24( LiteralPattern:24( - 1) + 1 + IntExpr(1)) StarredPattern:24( WildcardPattern:24()) LiteralPattern:24( - 2))) + 2 + IntExpr(2)))) Body( PassStmt:25()) Pattern( SequencePattern:26( LiteralPattern:26( - 1) + 1 + IntExpr(1)) StarredPattern:26( WildcardPattern:26()) LiteralPattern:26( - 2))) + 2 + IntExpr(2)))) Body( PassStmt:27()) Pattern( SequencePattern:28( LiteralPattern:28( - 1) + 1 + IntExpr(1)) StarredPattern:28( WildcardPattern:28()) LiteralPattern:28( - 2))) + 2 + IntExpr(2)))) Body( PassStmt:29()))) @@ -430,7 +489,8 @@ MypyFile:1( MappingPattern:2( Key( LiteralPattern:2( - 'k')) + 'k' + StrExpr(k))) Value( CapturePattern:2( NameExpr(v))))) @@ -452,7 +512,8 @@ MypyFile:1( MappingPattern:6( Key( LiteralPattern:6( - 1)) + 1 + IntExpr(1))) Value( CapturePattern:6( NameExpr(v))))) @@ -474,7 +535,8 @@ MypyFile:1( MappingPattern:10( Key( LiteralPattern:10( - 'k')) + 'k' + StrExpr(k))) Value( CapturePattern:10( NameExpr(v1))) @@ -488,7 +550,8 @@ MypyFile:1( NameExpr(v2))) Key( LiteralPattern:10( - 1)) + 1 + IntExpr(1))) Value( CapturePattern:10( NameExpr(v3))) @@ -506,34 +569,43 @@ MypyFile:1( MappingPattern:12( Key( LiteralPattern:12( - 'k1')) + 'k1' + StrExpr(k1))) Value( LiteralPattern:12( - 1)) + 1 + IntExpr(1))) Key( LiteralPattern:12( - 'k2')) + 'k2' + StrExpr(k2))) Value( LiteralPattern:12( - 'str')) + 'str' + StrExpr(str))) Key( LiteralPattern:12( - 'k3')) + 'k3' + StrExpr(k3))) Value( LiteralPattern:12( - b'bytes')) + b'bytes' + BytesExpr(bytes))) Key( LiteralPattern:12( - 'k4')) + 'k4' + StrExpr(k4))) Value( - LiteralPattern:12()))) + LiteralPattern:12( + NameExpr(None))))) Body( PassStmt:13()) Pattern( MappingPattern:14( Key( LiteralPattern:14( - 'k')) + 'k' + StrExpr(k))) Value( CapturePattern:14( NameExpr(v))) @@ -574,9 +646,11 @@ MypyFile:1( NameExpr(B) Positionals( LiteralPattern:4( - 1) + 1 + IntExpr(1)) LiteralPattern:4( - 2)))) + 2 + IntExpr(2))))) Body( PassStmt:5()) Pattern( @@ -584,11 +658,13 @@ MypyFile:1( NameExpr(B) Positionals( LiteralPattern:6( - 1)) + 1 + IntExpr(1))) Keyword( b LiteralPattern:6( - 2)))) + 2 + IntExpr(2))))) Body( PassStmt:7()) Pattern( @@ -597,10 +673,12 @@ MypyFile:1( Keyword( a LiteralPattern:8( - 1)) + 1 + IntExpr(1))) Keyword( b LiteralPattern:8( - 2)))) + 2 + IntExpr(2))))) Body( - PassStmt:9()))) \ No newline at end of file + PassStmt:9()))) diff --git a/test-data/unit/semanal-python310.test b/test-data/unit/semanal-python310.test index b1e6fef0bf81..a3f2924dc130 100644 --- a/test-data/unit/semanal-python310.test +++ b/test-data/unit/semanal-python310.test @@ -61,7 +61,8 @@ MypyFile:1( MappingPattern:3( Key( LiteralPattern:3( - 'k')) + 'k' + StrExpr(k))) Value( CapturePattern:3( NameExpr(b* [__main__.b])))))) @@ -86,7 +87,8 @@ MypyFile:1( Pattern( AsPattern:3( LiteralPattern:3( - 1) + 1 + IntExpr(1)) NameExpr(a* [__main__.a]))) Body( ExpressionStmt:4( @@ -110,7 +112,8 @@ MypyFile:1( NameExpr(x [__main__.x]) Pattern( LiteralPattern:4( - 1)) + 1 + IntExpr(1))) Guard( NameExpr(a [__main__.a])) Body( @@ -151,7 +154,8 @@ MypyFile:1( Pattern( AsPattern:3( LiteralPattern:3( - 1) + 1 + IntExpr(1)) NameExpr(a* [__main__.a]))) Guard( NameExpr(a [__main__.a])) From 3aab780966513459a152abc57e749d07bac4795d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 6 Apr 2021 18:23:36 +0200 Subject: [PATCH 19/76] Use double quotes in error messages --- mypy/checkpattern.py | 6 +- test-data/unit/check-python310.test | 132 ++++++++++++++-------------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index a66b9b3b47a6..bb707a755afc 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -237,7 +237,7 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: if isinstance(sym.node, (TypeAlias, TypeInfo)): typ = self.chk.named_type(class_name) else: - self.msg.fail("Class pattern must be a type. Found '{}'".format(sym.type), o.class_ref) + self.msg.fail('Class pattern must be a type. Found "{}"'.format(sym.type), o.class_ref) typ = self.chk.named_type("builtins.object") can_match = False match_args_type = find_member("__match_args__", typ, typ) @@ -267,10 +267,10 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: for key, value in zip(o.keyword_keys, o.keyword_values): keyword_pairs.append((key, value)) if key in match_arg_names: - self.msg.fail("Keyword '{}' already matches a positional pattern".format(key), + self.msg.fail('Keyword "{}" already matches a positional pattern'.format(key), value) elif key in keyword_names: - self.msg.fail("Duplicate keyword pattern '{}'".format(key), value) + self.msg.fail('Duplicate keyword pattern "{}"'.format(key), value) keyword_names.add(key) for keyword, pattern in keyword_pairs: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index c06a29ee59e2..d32ccfc2327c 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -5,7 +5,7 @@ m: A match m: case a: - reveal_type(a) # N: Revealed type is '__main__.A' + reveal_type(a) # N: Revealed type is "__main__.A" [case testCapturePatternPreexistingSame] class A: ... @@ -14,7 +14,7 @@ m: A match m: case a: - reveal_type(a) # N: Revealed type is '__main__.A' + reveal_type(a) # N: Revealed type is "__main__.A" [case testCapturePatternPreexistingIncompatible] class A: ... @@ -24,7 +24,7 @@ m: A match m: case a: # E: Incompatible types in capture pattern (pattern captures type "B", variable has type "A") - reveal_type(a) # N: Revealed type is '__main__.B' + reveal_type(a) # N: Revealed type is "__main__.B" -- Literal Pattern -- @@ -33,14 +33,14 @@ m: object match m: case 1: - reveal_type(m) # N: Revealed type is 'builtins.int' + reveal_type(m) # N: Revealed type is "builtins.int" [case testLiteralPatternAlreadyNarrower] m: bool match m: case 1: - reveal_type(m) # N: Revealed type is 'builtins.bool' + reveal_type(m) # N: Revealed type is "builtins.bool" [case testLiteralPatternUnreachable] m: int @@ -57,7 +57,7 @@ m: object match m: case b.b: - reveal_type(m) # N: Revealed type is 'builtins.int' + reveal_type(m) # N: Revealed type is "builtins.int" [file b.py] b: int @@ -67,7 +67,7 @@ m: bool match m: case b.b: - reveal_type(m) # N: Revealed type is 'builtins.bool' + reveal_type(m) # N: Revealed type is "builtins.bool" [file b.py] b: int @@ -89,7 +89,7 @@ m: List[int] match m: case [a]: - reveal_type(a) # N: Revealed type is 'builtins.int*' + reveal_type(a) # N: Revealed type is "builtins.int*" [builtins fixtures/list.pyi] [case testSequencePatternCapturesStarred] @@ -98,8 +98,8 @@ m: Iterable[int] match m: case [a, *b]: - reveal_type(a) # N: Revealed type is 'builtins.int' - reveal_type(b) # N: Revealed type is 'builtins.list[builtins.int]' + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] [case testSequencePatternNarrowsInner] @@ -108,7 +108,7 @@ m: Iterable[object] match m: case [1, True]: - reveal_type(m) # N: Revealed type is 'typing.Iterable[builtins.int]' + reveal_type(m) # N: Revealed type is "typing.Iterable[builtins.int]" [case testSequencePatternNarrowsOuter] from typing import Sequence @@ -116,7 +116,7 @@ m: object match m: case [1, True]: - reveal_type(m) # N: Revealed type is 'typing.Iterable[builtins.int]' + reveal_type(m) # N: Revealed type is "typing.Iterable[builtins.int]" [case testSequencePatternAlreadyNarrowerInner] from typing import Iterable @@ -124,7 +124,7 @@ m: Iterable[bool] match m: case [1, True]: - reveal_type(m) # N: Revealed type is 'typing.Iterable[builtins.bool]' + reveal_type(m) # N: Revealed type is "typing.Iterable[builtins.bool]" [case testSequencePatternAlreadyNarrowerOuter] from typing import Sequence @@ -132,7 +132,7 @@ m: Sequence[object] match m: case [1, True]: - reveal_type(m) # N: Revealed type is 'typing.Sequence[builtins.int]' + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" [case testSequencePatternAlreadyNarrowerBoth] from typing import Sequence @@ -140,7 +140,7 @@ m: Sequence[bool] match m: case [1, True]: - reveal_type(m) # N: Revealed type is 'typing.Sequence[builtins.bool]' + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" [case testNestedSequencePatternNarrowsInner] from typing import Iterable @@ -148,7 +148,7 @@ m: Iterable[Iterable[object]] match m: case [[1], [True]]: - reveal_type(m) # N: Revealed type is 'typing.Iterable[typing.Iterable[builtins.int]]' + reveal_type(m) # N: Revealed type is "typing.Iterable[typing.Iterable[builtins.int]]" [case testNestedSequencePatternNarrowsOuter] from typing import Iterable @@ -156,7 +156,7 @@ m: object match m: case [[1], [True]]: - reveal_type(m) # N: Revealed type is 'typing.Iterable[typing.Iterable[builtins.int]]' + reveal_type(m) # N: Revealed type is "typing.Iterable[typing.Iterable[builtins.int]]" [case testSequencePatternDoesntNarrowInvariant] @@ -165,7 +165,7 @@ m: List[object] match m: case [1]: - reveal_type(m) # N: Revealed type is 'builtins.list[builtins.object]' + reveal_type(m) # N: Revealed type is "builtins.list[builtins.object]" [builtins fixtures/list.pyi] @@ -177,9 +177,9 @@ m: Dict[str, int] match m: case {"key": v}: - reveal_type(v) # N: Revealed type is 'builtins.int*' + reveal_type(v) # N: Revealed type is "builtins.int*" case {b.b: v2}: - reveal_type(v2) # N: Revealed type is 'builtins.int*' + reveal_type(v2) # N: Revealed type is "builtins.int*" [file b.py] b: str [builtins fixtures/dict.pyi] @@ -192,9 +192,9 @@ m: Dict[str, int] match m: case {1: v}: - reveal_type(v) # N: Revealed type is 'builtins.int*' + reveal_type(v) # N: Revealed type is "builtins.int*" case {b.b: v2}: - reveal_type(v2) # N: Revealed type is 'builtins.int*' + reveal_type(v2) # N: Revealed type is "builtins.int*" [file b.py] b: int [builtins fixtures/dict.pyi] @@ -210,14 +210,14 @@ m: A match m: case {"a": v}: - reveal_type(v) # N: Revealed type is 'builtins.str' + reveal_type(v) # N: Revealed type is "builtins.str" case {"b": v2}: - reveal_type(v2) # N: Revealed type is 'builtins.int' + reveal_type(v2) # N: Revealed type is "builtins.int" case {"a": v3, "b": v4}: - reveal_type(v3) # N: Revealed type is 'builtins.str' - reveal_type(v4) # N: Revealed type is 'builtins.int' + reveal_type(v3) # N: Revealed type is "builtins.str" + reveal_type(v4) # N: Revealed type is "builtins.int" case {"o": v5}: - reveal_type(v5) # N: Revealed type is 'builtins.object*' + reveal_type(v5) # N: Revealed type is "builtins.object*" [typing fixtures/typing-typeddict.pyi] [case testMappingPatternCapturesTypedDictWithLiteral] @@ -232,14 +232,14 @@ m: A match m: case {b.a: v}: - reveal_type(v) # N: Revealed type is 'builtins.str' + reveal_type(v) # N: Revealed type is "builtins.str" case {b.b: v2}: - reveal_type(v2) # N: Revealed type is 'builtins.int' + reveal_type(v2) # N: Revealed type is "builtins.int" case {b.a: v3, b.b: v4}: - reveal_type(v3) # N: Revealed type is 'builtins.str' - reveal_type(v4) # N: Revealed type is 'builtins.int' + reveal_type(v3) # N: Revealed type is "builtins.str" + reveal_type(v4) # N: Revealed type is "builtins.int" case {b.o: v5}: - reveal_type(v5) # N: Revealed type is 'builtins.object*' + reveal_type(v5) # N: Revealed type is "builtins.object*" [file b.py] from typing import Final, Literal a: Final = "a" @@ -259,7 +259,7 @@ m: A match m: case {b.a: v}: - reveal_type(v) # N: Revealed type is 'builtins.object*' + reveal_type(v) # N: Revealed type is "builtins.object*" [file b.py] from typing import Final, Literal a: str @@ -300,8 +300,8 @@ m: A match m: case A(i, j): - reveal_type(i) # N: Revealed type is 'builtins.str' - reveal_type(j) # N: Revealed type is 'builtins.int' + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] [case testClassPatternMemberClassCapturePositional] @@ -311,8 +311,8 @@ m: b.A match m: case b.A(i, j): - reveal_type(i) # N: Revealed type is 'builtins.str' - reveal_type(j) # N: Revealed type is 'builtins.int' + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" [file b.py] from typing import Final @@ -331,35 +331,35 @@ m: A match m: case A(a=i, b=j): - reveal_type(i) # N: Revealed type is 'builtins.str' - reveal_type(j) # N: Revealed type is 'builtins.int' + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" [case testClassPatternCaptureSelf] m: object match m: case bool(a): - reveal_type(a) # N: Revealed type is 'builtins.bool' + reveal_type(a) # N: Revealed type is "builtins.bool" case bytearray(b): - reveal_type(b) # N: Revealed type is 'builtins.bytearray' + reveal_type(b) # N: Revealed type is "builtins.bytearray" case bytes(c): - reveal_type(c) # N: Revealed type is 'builtins.bytes' + reveal_type(c) # N: Revealed type is "builtins.bytes" case dict(d): - reveal_type(d) # N: Revealed type is 'builtins.dict' + reveal_type(d) # N: Revealed type is "builtins.dict" case float(e): - reveal_type(e) # N: Revealed type is 'builtins.float' + reveal_type(e) # N: Revealed type is "builtins.float" case frozenset(f): - reveal_type(f) # N: Revealed type is 'builtins.frozenset' + reveal_type(f) # N: Revealed type is "builtins.frozenset" case int(g): - reveal_type(g) # N: Revealed type is 'builtins.int' + reveal_type(g) # N: Revealed type is "builtins.int" case list(h): - reveal_type(h) # N: Revealed type is 'builtins.list' + reveal_type(h) # N: Revealed type is "builtins.list" case set(i): - reveal_type(i) # N: Revealed type is 'builtins.set' + reveal_type(i) # N: Revealed type is "builtins.set" case str(j): - reveal_type(j) # N: Revealed type is 'builtins.str' + reveal_type(j) # N: Revealed type is "builtins.str" case tuple(k): - reveal_type(k) # N: Revealed type is 'builtins.tuple' + reveal_type(k) # N: Revealed type is "builtins.tuple" [builtins fixtures/dict.pyi] [builtins fixtures/list.pyi] [builtins fixtures/set.pyi] @@ -377,9 +377,9 @@ m: object match m: case A(): - reveal_type(m) # N: Revealed type is '__main__.A' + reveal_type(m) # N: Revealed type is "__main__.A" case A(i, j): - reveal_type(m) # N: Revealed type is '__main__.A' + reveal_type(m) # N: Revealed type is "__main__.A" [builtins fixtures/tuple.pyi] [case testClassPatternAlreadyNarrower] @@ -395,9 +395,9 @@ m: B match m: case A(): - reveal_type(m) # N: Revealed type is '__main__.B' + reveal_type(m) # N: Revealed type is "__main__.B" case A(i, j): - reveal_type(m) # N: Revealed type is '__main__.B' + reveal_type(m) # N: Revealed type is "__main__.B" [builtins fixtures/tuple.pyi] [case testClassPatternUnreachable] @@ -425,8 +425,8 @@ m: object match m: case A(a=j): - reveal_type(m) # N: Revealed type is '__main__.A' - reveal_type(j) # N: Revealed type is 'builtins.object' + reveal_type(m) # N: Revealed type is "__main__.A" + reveal_type(j) # N: Revealed type is "builtins.object" [case testClassPatternDuplicateKeyword] class A: @@ -435,7 +435,7 @@ class A: m: object match m: - case A(a=i, a=j): # E: Duplicate keyword pattern 'a' + case A(a=i, a=j): # E: Duplicate keyword pattern "a" pass [case testClassPatternDuplicateImplicitKeyword] @@ -448,7 +448,7 @@ class A: m: object match m: - case A(i, a=j): # E: Keyword 'a' already matches a positional pattern + case A(i, a=j): # E: Keyword "a" already matches a positional pattern pass [builtins fixtures/tuple.pyi] @@ -472,7 +472,7 @@ a = 1 m: object match m: - case a(i, j): # E: Class pattern must be a type. Found 'builtins.int' + case a(i, j): # E: Class pattern must be a type. Found "builtins.int" reveal_type(i) reveal_type(j) @@ -486,8 +486,8 @@ m: object match m: case A(i, j): - reveal_type(i) # N: Revealed type is 'builtins.object' - reveal_type(j) # N: Revealed type is 'builtins.object' + reveal_type(i) # N: Revealed type is "builtins.object" + reveal_type(j) # N: Revealed type is "builtins.object" [builtins fixtures/tuple.pyi] [case testAnyTupleMatchArgs] @@ -502,9 +502,9 @@ m: object match m: case A(i, j, k): - reveal_type(i) # N: Revealed type is 'builtins.object' - reveal_type(j) # N: Revealed type is 'builtins.object' - reveal_type(k) # N: Revealed type is 'builtins.object' + reveal_type(i) # N: Revealed type is "builtins.object" + reveal_type(j) # N: Revealed type is "builtins.object" + reveal_type(k) # N: Revealed type is "builtins.object" [builtins fixtures/tuple.pyi] [case testNonLiteralMatchArgs] @@ -522,8 +522,8 @@ match m: case A(i, j, k): # E: Too many positional patterns for class pattern pass case A(i, j): - reveal_type(i) # N: Revealed type is 'builtins.str' - reveal_type(j) # N: Revealed type is 'builtins.object' + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.object" [builtins fixtures/tuple.pyi] [case testExternalMatchArgs] From 3a71c3242533830d9721d4c445556eeb9cb4bae7 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 6 Apr 2021 20:07:43 +0200 Subject: [PATCH 20/76] Add support for match statement self matching class patterns --- mypy/checkpattern.py | 101 +++++++++++++++++++------ test-data/unit/check-python310.test | 77 ++++++++++++++++--- test-data/unit/fixtures/primitives.pyi | 6 +- 3 files changed, 152 insertions(+), 32 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index bb707a755afc..2bf6fee0fe49 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -23,6 +23,21 @@ import mypy.checker +self_match_type_names = [ + "builtins.bool", + "builtins.bytearray", + "builtins.bytes", + "builtins.dict", + "builtins.float", + "builtins.frozenset", + "builtins.int", + "builtins.list", + "builtins.set", + "builtins.str", + "builtins.tuple", +] + + class PatternChecker(PatternVisitor[Optional[Type]]): """Pattern checker. @@ -41,6 +56,8 @@ class PatternChecker(PatternVisitor[Optional[Type]]): # Type of the subject to check the (sub)pattern against type_stack = [] # type: List[ProperType] + self_match_types = None # type: List[Type] + def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: Plugin, subject: Expression, subject_type: ProperType) -> None: self.chk = chk @@ -49,6 +66,8 @@ def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: self.subject = subject self.type_stack.append(subject_type) + self.self_match_types = self.generate_self_match_types() + def check_pattern(self, o: Pattern) -> 'mypy.checker.TypeMap': pattern_type = self.visit(o) if pattern_type is None: @@ -229,39 +248,60 @@ def get_simple_mapping_item_type(self, return result def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: - can_match = True - + current_type = self.type_stack[-1] class_name = o.class_ref.fullname assert class_name is not None sym = self.chk.lookup_qualified(class_name) + if isinstance(sym.node, TypeAlias) and not sym.node.no_args: + self.msg.fail("Class pattern class must not be a type alias with type parameters", o) + return None if isinstance(sym.node, (TypeAlias, TypeInfo)): typ = self.chk.named_type(class_name) else: self.msg.fail('Class pattern must be a type. Found "{}"'.format(sym.type), o.class_ref) - typ = self.chk.named_type("builtins.object") - can_match = False - match_args_type = find_member("__match_args__", typ, typ) + return None - if match_args_type is None and can_match: - if len(o.positionals) >= 1: - self.msg.fail("Class doesn't define __match_args__", o) + keyword_pairs = [] # type: List[Tuple[Optional[str], Pattern]] + match_arg_names = [] # type: List[Optional[str]] - proper_match_args_type = get_proper_type(match_args_type) - if isinstance(proper_match_args_type, TupleType): - match_arg_names = get_match_arg_names(proper_match_args_type) + can_match = True - if len(o.positionals) > len(match_arg_names): - self.msg.fail("Too many positional patterns for class pattern", o) - match_arg_names += [None] * (len(o.positionals) - len(match_arg_names)) + if self.should_self_match(typ): + if len(o.positionals) >= 1: + self.type_stack.append(typ) + if self.visit(o.positionals[0]) is None: + can_match = False + self.type_stack.pop() + + if len(o.positionals) > 1: + self.msg.fail("Too many positional patterns for class pattern", o) + self.type_stack.append(self.chk.named_type("builtins.object")) + for p in o.positionals[1:]: + if self.visit(p) is None: + can_match = False + self.type_stack.pop() else: - match_arg_names = [None] * len(o.positionals) + match_args_type = find_member("__match_args__", typ, typ) - positional_names = set() - keyword_pairs = [] # type: List[Tuple[Optional[str], Pattern]] + if match_args_type is None and can_match: + if len(o.positionals) >= 1: + self.msg.fail("Class doesn't define __match_args__", o) - for arg_name, pos in zip(match_arg_names, o.positionals): - keyword_pairs.append((arg_name, pos)) - positional_names.add(arg_name) + proper_match_args_type = get_proper_type(match_args_type) + if isinstance(proper_match_args_type, TupleType): + match_arg_names = get_match_arg_names(proper_match_args_type) + + if len(o.positionals) > len(match_arg_names): + self.msg.fail("Too many positional patterns for class pattern", o) + match_arg_names += [None] * (len(o.positionals) - len(match_arg_names)) + else: + match_arg_names = [None] * len(o.positionals) + + positional_names = set() + + for arg_name, pos in zip(match_arg_names, o.positionals): + keyword_pairs.append((arg_name, pos)) + positional_names.add(arg_name) keyword_names = set() for key, value in zip(o.keyword_keys, o.keyword_values): @@ -275,7 +315,7 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: for keyword, pattern in keyword_pairs: if keyword is not None: - key_type = find_member(keyword, typ, typ) + key_type = find_member(keyword, typ, current_type) if key_type is None: key_type = self.chk.named_type("builtins.object") else: @@ -287,10 +327,27 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: self.type_stack.pop() if can_match: - return get_more_specific_type(self.type_stack[-1], typ) + return get_more_specific_type(current_type, typ) else: return None + def should_self_match(self, typ: Type): + for other in self.self_match_types: + if is_subtype(typ, other): + return True + return False + + def generate_self_match_types(self) -> List[Type]: + types = [] + for name in self_match_type_names: + try: + types.append(self.chk.named_type(name)) + except KeyError: + # Some built in types are not defined in all test cases + pass + + return types + def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: args = [] # type: List[Optional[str]] diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index d32ccfc2327c..19e91ef60437 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -345,25 +345,84 @@ match m: case bytes(c): reveal_type(c) # N: Revealed type is "builtins.bytes" case dict(d): - reveal_type(d) # N: Revealed type is "builtins.dict" + reveal_type(d) # N: Revealed type is "builtins.dict[Any, Any]" case float(e): reveal_type(e) # N: Revealed type is "builtins.float" case frozenset(f): - reveal_type(f) # N: Revealed type is "builtins.frozenset" + reveal_type(f) # N: Revealed type is "builtins.frozenset[Any]" case int(g): reveal_type(g) # N: Revealed type is "builtins.int" case list(h): - reveal_type(h) # N: Revealed type is "builtins.list" + reveal_type(h) # N: Revealed type is "builtins.list[Any]" case set(i): - reveal_type(i) # N: Revealed type is "builtins.set" + reveal_type(i) # N: Revealed type is "builtins.set[Any]" case str(j): reveal_type(j) # N: Revealed type is "builtins.str" case tuple(k): - reveal_type(k) # N: Revealed type is "builtins.tuple" -[builtins fixtures/dict.pyi] -[builtins fixtures/list.pyi] -[builtins fixtures/set.pyi] -[builtins fixtures/tuple.pyi] + reveal_type(k) # N: Revealed type is "builtins.tuple[Any]" +[builtins fixtures/primitives.pyi] + +[case testClassPatternCaptureGeneric] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +m: object + +match m: + case A(a=i): + reveal_type(m) # N: Revealed type is "__main__.A[Any]" + reveal_type(i) # N: Revealed type is "Any" + +[case testClassPatternCaptureGenericAlreadyKnown] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +m: A[int] + +match m: + case A(a=i): + reveal_type(m) # N: Revealed type is "__main__.A[builtins.int]" + reveal_type(i) # N: Revealed type is "builtins.int*" + +[case testClassPatternCaptureFilledGenericTypeAlias] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +B = A[int] + +m: object + +match m: + case B(a=i): # E: Class pattern class must not be a type alias with type parameters + reveal_type(i) + +[case testClassPatternCaptureGenericTypeAlias] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +B = A + +m: object + +match m: + case B(a=i): + pass [case testClassPatternNarrows] from typing import Final diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 71f59a9c1d8c..24cb5ea45ff2 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -1,5 +1,5 @@ # builtins stub with non-generic primitive types -from typing import Generic, TypeVar, Sequence, Iterator, Mapping +from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable T = TypeVar('T') V = TypeVar('V') @@ -48,5 +48,9 @@ class list(Sequence[T]): def __getitem__(self, item: int) -> T: pass class dict(Mapping[T, V]): def __iter__(self) -> Iterator[T]: pass +class set(Iterable[T]): + def __iter__(self) -> Iterator[T]: pass +class frozenset(Iterable[T]): + def __iter__(self) -> Iterator[T]: pass class function: pass class ellipsis: pass From b7468d91771f26d7c4b771e853e252bd3eece600 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 7 Apr 2021 15:46:39 +0200 Subject: [PATCH 21/76] Replace more single quotes in tests with double quotes --- test-data/unit/semanal-errors-python310.test | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test-data/unit/semanal-errors-python310.test b/test-data/unit/semanal-errors-python310.test index e8eda6ca87d7..28fb43eea6ee 100644 --- a/test-data/unit/semanal-errors-python310.test +++ b/test-data/unit/semanal-errors-python310.test @@ -4,7 +4,7 @@ match x: case _: pass [out] -main:2: error: Name 'x' is not defined +main:2: error: Name "x" is not defined [case testNoneBindingWildcardPattern] import typing @@ -13,7 +13,7 @@ match x: case _: _ [out] -main:5: error: Name '_' is not defined +main:5: error: Name "_" is not defined [case testNoneBindingStarredWildcardPattern] import typing @@ -22,4 +22,4 @@ match x: case [*_]: _ [out] -main:5: error: Name '_' is not defined +main:5: error: Name "_" is not defined From 943f2ff1c5e07f58e73df8b08d30dae9759820ae Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 7 Apr 2021 15:47:53 +0200 Subject: [PATCH 22/76] Add support for __match_args__ to dataclass --- mypy/checkpattern.py | 4 +-- mypy/plugins/common.py | 27 ++++++++++++++++++++ mypy/plugins/dataclasses.py | 12 +++++++-- mypy/util.py | 1 + test-data/unit/check-incomplete-fixture.test | 19 -------------- test-data/unit/check-python310.test | 15 +++++++++++ test-data/unit/deps.test | 1 + test-data/unit/fixtures/attr.pyi | 1 + test-data/unit/lib-stub/builtins.pyi | 1 + 9 files changed, 58 insertions(+), 23 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 2bf6fee0fe49..0ec9a1886e86 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -331,14 +331,14 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: else: return None - def should_self_match(self, typ: Type): + def should_self_match(self, typ: Type) -> bool: for other in self.self_match_types: if is_subtype(typ, other): return True return False def generate_self_match_types(self) -> List[Type]: - types = [] + types = [] # type: List[Type] for name in self_match_type_names: try: types.append(self.chk.named_type(name)) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 536022a1e09e..0ee1e23017a4 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -154,6 +154,33 @@ def add_method_to_class( info.defn.defs.body.append(func) +def add_attribute_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + name: str, + typ: Type, + final: bool = False, +) -> None: + """ + Adds a new attribute to a class definition. + This currently only generates the symbol table entry and no corresponding AssignmentStatement + """ + info = cls.info + + # NOTE: we would like the plugin generated node to dominate, but we still + # need to keep any existing definitions so they get semantically analyzed. + if name in info.names: + # Get a nice unique name instead. + r_name = get_unique_redefinition_name(name, info.names) + info.names[r_name] = info.names[name] + + node = Var(name, typ) + node.info = info + node.is_final = final + node._fullname = api.qualified_name(name) + info.names[name] = SymbolTableNode(MDEF, node, plugin_generated=True) + + def deserialize_and_fixup_type( data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface ) -> Type: diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 5765e0599759..9fc327564169 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -10,10 +10,12 @@ ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( - add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, + add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, add_attribute_to_class, ) from mypy.typeops import map_type_from_supertype -from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type +from mypy.types import ( + Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type, TupleType, LiteralType, +) from mypy.server.trigger import make_wildcard_trigger # The set of decorators that generate dataclasses. @@ -173,6 +175,12 @@ def transform(self) -> None: self.reset_init_only_vars(info, attributes) + # Add __match_args__. This is always added + str_type = ctx.api.named_type("__builtins__.str") + literals = [LiteralType(attr.name, str_type) for attr in attributes] # type: List[Type] + match_args_type = TupleType(literals, ctx.api.named_type("__builtins__.tuple")) + add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) + info.metadata['dataclass'] = { 'attributes': [attr.serialize() for attr in attributes], 'frozen': decorator_arguments['frozen'], diff --git a/mypy/util.py b/mypy/util.py index e34dffcd3ab0..9cc01795be09 100644 --- a/mypy/util.py +++ b/mypy/util.py @@ -281,6 +281,7 @@ def id(self, o: object) -> int: def get_prefix(fullname: str) -> str: """Drop the final component of a qualified name (e.g. ('x.y' -> 'x').""" + pass return fullname.rsplit('.', 1)[0] diff --git a/test-data/unit/check-incomplete-fixture.test b/test-data/unit/check-incomplete-fixture.test index d083d2f9f2d2..dab986719369 100644 --- a/test-data/unit/check-incomplete-fixture.test +++ b/test-data/unit/check-incomplete-fixture.test @@ -58,25 +58,6 @@ main:1: error: Name "isinstance" is not defined main:1: note: Maybe your test fixture does not define "builtins.isinstance"? main:1: note: Consider adding [builtins fixtures/isinstancelist.pyi] to your test description -[case testTupleMissingFromStubs1] -tuple() -[out] -main:1: error: Name "tuple" is not defined -main:1: note: Maybe your test fixture does not define "builtins.tuple"? -main:1: note: Consider adding [builtins fixtures/tuple.pyi] to your test description -main:1: note: Did you forget to import it from "typing"? (Suggestion: "from typing import Tuple") - -[case testTupleMissingFromStubs2] -tuple() -from typing import Tuple -x: Tuple[int, str] -[out] -main:1: error: Name "tuple" is not defined -main:1: note: Maybe your test fixture does not define "builtins.tuple"? -main:1: note: Consider adding [builtins fixtures/tuple.pyi] to your test description -main:1: note: Did you forget to import it from "typing"? (Suggestion: "from typing import Tuple") -main:3: error: Name "tuple" is not defined - [case testClassmethodMissingFromStubs] class A: @classmethod diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 19e91ef60437..63e7f0ed1c09 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -362,6 +362,21 @@ match m: reveal_type(k) # N: Revealed type is "builtins.tuple[Any]" [builtins fixtures/primitives.pyi] +[case testClassPatternCaptureDataclass] +from dataclasses import dataclass + +@dataclass +class A: + a: str + b: int + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" + [case testClassPatternCaptureGeneric] from typing import Generic, TypeVar diff --git a/test-data/unit/deps.test b/test-data/unit/deps.test index 8c074abc83a2..0da9fb61bd64 100644 --- a/test-data/unit/deps.test +++ b/test-data/unit/deps.test @@ -1436,6 +1436,7 @@ class B(A): -> , m -> -> , m.B.__init__ + -> -> -> -> diff --git a/test-data/unit/fixtures/attr.pyi b/test-data/unit/fixtures/attr.pyi index deb1906d931e..bea65e7131d8 100644 --- a/test-data/unit/fixtures/attr.pyi +++ b/test-data/unit/fixtures/attr.pyi @@ -24,4 +24,5 @@ class complex: class str: pass class unicode: pass +class tuple: pass class ellipsis: pass diff --git a/test-data/unit/lib-stub/builtins.pyi b/test-data/unit/lib-stub/builtins.pyi index 70c24743b62a..c151425bc8e0 100644 --- a/test-data/unit/lib-stub/builtins.pyi +++ b/test-data/unit/lib-stub/builtins.pyi @@ -17,6 +17,7 @@ class bool(int): pass class str: pass class bytes: pass +class tuple: pass class function: pass class ellipsis: pass From b41fa30f29f78b34f8848469a8f25582ddaccffc Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 7 Apr 2021 15:51:00 +0200 Subject: [PATCH 23/76] Minor code cleanup --- mypy/plugins/dataclasses.py | 2 +- mypy/util.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 9fc327564169..c991521046bc 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -14,7 +14,7 @@ ) from mypy.typeops import map_type_from_supertype from mypy.types import ( - Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type, TupleType, LiteralType, + Type, Instance, NoneType, TypeVarDef, TypeVarType, TupleType, LiteralType, get_proper_type, ) from mypy.server.trigger import make_wildcard_trigger diff --git a/mypy/util.py b/mypy/util.py index 9cc01795be09..e34dffcd3ab0 100644 --- a/mypy/util.py +++ b/mypy/util.py @@ -281,7 +281,6 @@ def id(self, o: object) -> int: def get_prefix(fullname: str) -> str: """Drop the final component of a qualified name (e.g. ('x.y' -> 'x').""" - pass return fullname.rsplit('.', 1)[0] From 037a140cd9a11d715228c1fbf7556dbc641419fd Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 7 Apr 2021 16:52:54 +0200 Subject: [PATCH 24/76] Add support __match_args__ to namedtuple --- mypy/checkpattern.py | 3 +++ mypy/semanal_namedtuple.py | 6 ++++- test-data/unit/check-python310.test | 27 +++++++++++++++++++ test-data/unit/merge.test | 40 +++++++++++++++-------------- 4 files changed, 56 insertions(+), 20 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 0ec9a1886e86..309e497a152c 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -332,6 +332,9 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: return None def should_self_match(self, typ: Type) -> bool: + proper_type = get_proper_type(typ) + if isinstance(proper_type, Instance) and proper_type.type.is_named_tuple: + return False for other in self.self_match_types: if is_subtype(typ, other): return True diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index 0067fba22322..27e68008236c 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -9,7 +9,7 @@ from mypy.types import ( Type, TupleType, AnyType, TypeOfAny, TypeVarDef, CallableType, TypeType, TypeVarType, - UnboundType, + UnboundType, LiteralType, ) from mypy.semanal_shared import ( SemanticAnalyzerInterface, set_callable_name, calculate_tuple_fallback, PRIORITY_FALLBACKS @@ -382,6 +382,9 @@ def build_namedtuple_typeinfo(self, iterable_type = self.api.named_type_or_none('typing.Iterable', [implicit_any]) function_type = self.api.named_type('__builtins__.function') + literals = [LiteralType(item, strtype) for item in items] # type: List[Type] + match_args_type = TupleType(literals, basetuple_type) + info = self.api.basic_new_typeinfo(name, fallback) info.is_named_tuple = True tuple_base = TupleType(types, fallback) @@ -420,6 +423,7 @@ def add_field(var: Var, is_initialized_in_class: bool = False, add_field(Var('_source', strtype), is_initialized_in_class=True) add_field(Var('__annotations__', ordereddictype), is_initialized_in_class=True) add_field(Var('__doc__', strtype), is_initialized_in_class=True) + add_field(Var('__match_args__', match_args_type), is_initialized_in_class=True) tvd = TypeVarDef(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME, -1, [], info.tuple_type) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 63e7f0ed1c09..95d240526dbe 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -372,6 +372,33 @@ class A: m: A +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" + +[case testClassPatternCaptureNamedTupleInline] +from collections import namedtuple + +A = namedtuple("A", ["a", "b"]) + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "Any" + reveal_type(j) # N: Revealed type is "Any" +[builtins fixtures/list.pyi] + +[case testClassPatternCaptureNamedTupleClass] +from typing import NamedTuple + +class A(NamedTuple): + a: str + b: int + +m: A + match m: case A(i, j): reveal_type(i) # N: Revealed type is "builtins.str" diff --git a/test-data/unit/merge.test b/test-data/unit/merge.test index 836ad87857f8..b5d68899f019 100644 --- a/test-data/unit/merge.test +++ b/test-data/unit/merge.test @@ -671,15 +671,16 @@ TypeInfo<2>( _NT<6> __annotations__<7> (builtins.object<1>) __doc__<8> (builtins.str<9>) - __new__<10> - _asdict<11> - _field_defaults<12> (builtins.object<1>) - _field_types<13> (builtins.object<1>) - _fields<14> (Tuple[builtins.str<9>]) - _make<15> - _replace<16> - _source<17> (builtins.str<9>) - x<18> (target.A<0>))) + __match_args__<10> (Tuple[Literal['x']]) + __new__<11> + _asdict<12> + _field_defaults<13> (builtins.object<1>) + _field_types<14> (builtins.object<1>) + _fields<15> (Tuple[builtins.str<9>]) + _make<16> + _replace<17> + _source<18> (builtins.str<9>) + x<19> (target.A<0>))) ==> TypeInfo<0>( Name(target.A) @@ -694,16 +695,17 @@ TypeInfo<2>( _NT<6> __annotations__<7> (builtins.object<1>) __doc__<8> (builtins.str<9>) - __new__<10> - _asdict<11> - _field_defaults<12> (builtins.object<1>) - _field_types<13> (builtins.object<1>) - _fields<14> (Tuple[builtins.str<9>, builtins.str<9>]) - _make<15> - _replace<16> - _source<17> (builtins.str<9>) - x<18> (target.A<0>) - y<19> (target.A<0>))) + __match_args__<10> (Tuple[Literal['x'], Literal['y']]) + __new__<11> + _asdict<12> + _field_defaults<13> (builtins.object<1>) + _field_types<14> (builtins.object<1>) + _fields<15> (Tuple[builtins.str<9>, builtins.str<9>]) + _make<16> + _replace<17> + _source<18> (builtins.str<9>) + x<19> (target.A<0>) + y<20> (target.A<0>))) [case testUnionType_types] import target From fb9ae477a4c34a8634bb76a4b8e775d4d2f0cc31 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 7 Apr 2021 18:17:47 +0200 Subject: [PATCH 25/76] Add support for match statement as pattern and change type_stack to Type --- mypy/checkpattern.py | 30 +++++++++++++---------- test-data/unit/check-python310.test | 37 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 309e497a152c..efe3d0bd4193 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -54,12 +54,12 @@ class PatternChecker(PatternVisitor[Optional[Type]]): # The expression being matched against the pattern subject = None # type: Expression # Type of the subject to check the (sub)pattern against - type_stack = [] # type: List[ProperType] + type_stack = [] # type: List[Type] self_match_types = None # type: List[Type] def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: Plugin, - subject: Expression, subject_type: ProperType) -> None: + subject: Expression, subject_type: Type) -> None: self.chk = chk self.msg = msg self.plugin = plugin @@ -83,7 +83,14 @@ def visit(self, o: Pattern) -> Optional[Type]: return o.accept(self) def visit_as_pattern(self, o: AsPattern) -> Optional[Type]: - return self.type_stack[-1] + typ = self.visit(o.pattern) + specific_type = get_more_specific_type(typ, self.type_stack[-1]) + if specific_type is None: + return None + self.type_stack.append(specific_type) + self.check_capture(o.name) + self.type_stack.pop() + return typ def visit_or_pattern(self, o: OrPattern) -> Optional[Type]: return self.type_stack[-1] @@ -135,7 +142,7 @@ def visit_value_pattern(self, o: ValuePattern) -> Optional[Type]: def visit_sequence_pattern(self, o: SequencePattern) -> Optional[Type]: current_type = self.type_stack[-1] - inner_type = get_proper_type(self.get_sequence_type(current_type)) + inner_type = self.get_sequence_type(get_proper_type(current_type)) if inner_type is None: if is_subtype(self.chk.named_type("typing.Iterable"), current_type): # Current type is more general, but the actual value could still be iterable @@ -191,7 +198,6 @@ def visit_mapping_pattern(self, o: MappingPattern) -> Optional[Type]: if inner_type is None: can_match = False inner_type = self.chk.named_type("builtins.object") - inner_type = get_proper_type(inner_type) self.type_stack.append(inner_type) if self.visit(value) is None: can_match = False @@ -206,7 +212,6 @@ def get_mapping_item_type(self, mapping_type: Type, key_pattern: MappingKeyPattern ) -> Optional[Type]: - mapping_type = get_proper_type(mapping_type) local_errors = self.msg.clean_copy() local_errors.disable_count = 0 if isinstance(mapping_type, TypedDictType): @@ -321,7 +326,7 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: else: key_type = self.chk.named_type("builtins.object") - self.type_stack.append(get_proper_type(key_type)) + self.type_stack.append(key_type) if self.visit(pattern) is None: can_match = False self.type_stack.pop() @@ -331,9 +336,8 @@ def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: else: return None - def should_self_match(self, typ: Type) -> bool: - proper_type = get_proper_type(typ) - if isinstance(proper_type, Instance) and proper_type.type.is_named_tuple: + def should_self_match(self, typ: ProperType) -> bool: + if isinstance(typ, Instance) and typ.type.is_named_tuple: return False for other in self.self_match_types: if is_subtype(typ, other): @@ -363,8 +367,10 @@ def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: return args -def get_more_specific_type(left: Type, right: Type) -> Optional[Type]: - if is_subtype(left, right): +def get_more_specific_type(left: Optional[Type], right: Optional[Type]) -> Optional[Type]: + if left is None or right is None: + return None + elif is_subtype(left, right): return left elif is_subtype(right, left): return right diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 95d240526dbe..ce07e02fe453 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -390,6 +390,19 @@ match m: reveal_type(j) # N: Revealed type is "Any" [builtins fixtures/list.pyi] +[case testClassPatternCaptureNamedTupleInlineTyped] +from typing import NamedTuple + +A = NamedTuple("A", [("a", str), ("b", int)]) + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/list.pyi] + [case testClassPatternCaptureNamedTupleClass] from typing import NamedTuple @@ -645,3 +658,27 @@ class B: [builtins fixtures/tuple.pyi] [typing fixtures/typing-medium.pyi] + + +-- as pattern -- +[case testAsPattern] +m: int + +match m: + case x as l: + reveal_type(x) # N: Revealed type is "builtins.int" + reveal_type(l) # N: Revealed type is "builtins.int" + +[case testAsPatternNarrows] +m: object + +match m: + case int() as l: + reveal_type(l) # N: Revealed type is "builtins.int" + +[case testAsPatternAlreadyNarrower] +m: bool + +match m: + case int() as l: + reveal_type(l) # N: Revealed type is "builtins.bool" From 99c1bbf9ea87f26f9fdc25404157fda8d76009bd Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 8 Apr 2021 09:58:34 +0200 Subject: [PATCH 26/76] Make PatternChecker infer unions for patterns with the same name --- mypy/checker.py | 15 ++--- mypy/checkpattern.py | 91 +++++++++++++++++++++-------- test-data/unit/check-python310.test | 58 ++++++++++++------ 3 files changed, 113 insertions(+), 51 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 499c9d07fc60..54cd24bf6305 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3730,20 +3730,21 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_checker = PatternChecker(self, self.msg, self.plugin, s.subject, t) - for p, g, b in zip(s.patterns, s.guards, s.bodies): - if not b.is_unreachable: - type_map = pattern_checker.check_pattern(p) - else: - type_map = None + type_maps = pattern_checker.check_patterns(s.patterns) + + for b, g, tm in zip(s.bodies, s.guards, type_maps): with self.binder.frame_context(can_skip=True, fall_through=2): - self.push_type_map(type_map) + if not b.is_unreachable: + self.push_type_map(tm) + else: + self.push_type_map(None) if g is not None: gt = get_proper_type(self.expr_checker.accept(g)) if isinstance(gt, DeletedType): self.msg.deleted_as_rvalue(gt, s) - if_map, else_map = self.find_isinstance_check(g) + if_map, _ = self.find_isinstance_check(g) self.push_type_map(if_map) self.accept(b) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index efe3d0bd4193..8feb46f413d5 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,12 +1,13 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" -from typing import List, Optional, Union, Tuple +from collections import defaultdict +from typing import List, Optional, Union, Tuple, Dict from mypy import message_registry from mypy.expandtype import expand_type_by_instance from mypy.join import join_types from mypy.messages import MessageBuilder -from mypy.nodes import Expression, NameExpr, ARG_POS, TypeAlias, TypeInfo +from mypy.nodes import Expression, NameExpr, ARG_POS, TypeAlias, TypeInfo, Var from mypy.patterns import ( Pattern, AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, ClassPattern, MappingKeyPattern @@ -16,7 +17,7 @@ from mypy.typeops import try_getting_str_literals_from_type from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, NoneType, UninhabitedType, get_proper_type, - TypedDictType, TupleType + TypedDictType, TupleType, UnionType ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor @@ -43,6 +44,11 @@ class PatternChecker(PatternVisitor[Optional[Type]]): This class checks if a pattern can match a type, what the type can be narrowed to, and what type capture patterns should be inferred as. + + visit_ methods return the type a pattern narrows to or None if the pattern can't match the + subject. They should not be called directly, as they change the state. + + Use check_patterns() instead. A new PatternChecker should be used for each match statement. """ # Some services are provided by a TypeChecker instance. @@ -55,6 +61,10 @@ class PatternChecker(PatternVisitor[Optional[Type]]): subject = None # type: Expression # Type of the subject to check the (sub)pattern against type_stack = [] # type: List[Type] + # TODO: This type looks kind of ugly + captured_types = None # type: Dict[Var, List[Tuple[NameExpr, Type]]] + + current_type_map = {} # type: 'mypy.checker.TypeMap' self_match_types = None # type: List[Type] @@ -66,18 +76,50 @@ def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: self.subject = subject self.type_stack.append(subject_type) + self.captured_types = defaultdict(list) self.self_match_types = self.generate_self_match_types() - def check_pattern(self, o: Pattern) -> 'mypy.checker.TypeMap': - pattern_type = self.visit(o) - if pattern_type is None: - # This case is unreachable - return None - elif is_equivalent(self.type_stack[-1], pattern_type): - # No need to narrow - return {} - else: - return {self.subject: pattern_type} + def check_patterns(self, patterns: List[Pattern]) -> List['mypy.checker.TypeMap']: + type_maps = [] # type: List['mypy.checker.TypeMap'] + for pattern in patterns: + self.current_type_map = {} # type: Dict[Expression, Type] + pattern_type = self.visit(pattern) + if pattern_type is None: + # This case is unreachable + type_maps.append(None) + elif not is_equivalent(self.type_stack[-1], pattern_type): + self.current_type_map[self.subject] = pattern_type + type_maps.append(self.current_type_map) + else: + type_maps.append(self.current_type_map) + + type_maps = self.infer_types(type_maps) + return type_maps + + def infer_types(self, type_maps: List['mypy.checker.TypeMap']) -> List['mypy.checker.TypeMap']: + for var, captures in self.captured_types.items(): + conflict = False + types = [] # type: List[Type] + for expr, typ in captures: + types.append(typ) + + previous_type, _, inferred = self.chk.check_lvalue(expr) + if previous_type is not None: + conflict = True + self.chk.check_subtype(typ, previous_type, expr, + msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, + subtype_label="pattern captures type", + supertype_label="variable has type") + for type_map in type_maps: + if type_map is not None and expr in type_map: + del type_map[expr] + + if not conflict: + new_type = UnionType.make_union(types) + # Infer the union type at the first occurrence + self.chk.infer_variable_type(var, captures[0][0], new_type, captures[0][0]) + + return type_maps def visit(self, o: Pattern) -> Optional[Type]: return o.accept(self) @@ -123,15 +165,12 @@ def visit_capture_pattern(self, o: CapturePattern) -> Optional[Type]: return self.type_stack[-1] def check_capture(self, capture: NameExpr) -> None: - capture_type, _, inferred = self.chk.check_lvalue(capture) - if capture_type: - self.chk.check_subtype(capture_type, self.type_stack[-1], capture, - msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, - subtype_label="pattern captures type", - supertype_label="variable has type") - else: - assert inferred is not None - self.chk.infer_variable_type(inferred, capture, self.type_stack[-1], self.subject) + node = capture.node + assert isinstance(node, Var) + self.captured_types[node].append((capture, self.type_stack[-1])) + type_map = self.current_type_map + if type_map is not None: + type_map[capture] = self.type_stack[-1] def visit_wildcard_pattern(self, o: WildcardPattern) -> Optional[Type]: return self.type_stack[-1] @@ -151,7 +190,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> Optional[Type]: # Pattern can't match return None - assert isinstance(current_type, Instance) self.type_stack.append(inner_type) new_inner_type = UninhabitedType() # type: Type for p in o.patterns: @@ -162,9 +200,11 @@ def visit_sequence_pattern(self, o: SequencePattern) -> Optional[Type]: self.type_stack.pop() iterable = self.chk.named_generic_type("typing.Iterable", [new_inner_type]) if self.chk.type_is_iterable(current_type): - empty_type = fill_typevars(current_type.type) + proper_type = get_proper_type(current_type) + assert isinstance(proper_type, Instance) + empty_type = fill_typevars(proper_type.type) partial_type = expand_type_by_instance(empty_type, iterable) - new_type = expand_type_by_instance(partial_type, current_type) + new_type = expand_type_by_instance(partial_type, proper_type) else: new_type = iterable @@ -214,6 +254,7 @@ def get_mapping_item_type(self, ) -> Optional[Type]: local_errors = self.msg.clean_copy() local_errors.disable_count = 0 + mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): result = self.chk.expr_checker.visit_typeddict_index_expr(mapping_type, key_pattern.expr, diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index ce07e02fe453..44ab671fc097 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -7,25 +7,6 @@ match m: case a: reveal_type(a) # N: Revealed type is "__main__.A" -[case testCapturePatternPreexistingSame] -class A: ... -a: A -m: A - -match m: - case a: - reveal_type(a) # N: Revealed type is "__main__.A" - -[case testCapturePatternPreexistingIncompatible] -class A: ... -class B: ... -a: B -m: A - -match m: - case a: # E: Incompatible types in capture pattern (pattern captures type "B", variable has type "A") - reveal_type(a) # N: Revealed type is "__main__.B" - -- Literal Pattern -- [case testLiteralPatternNarrows] @@ -682,3 +663,42 @@ m: bool match m: case int() as l: reveal_type(l) # N: Revealed type is "builtins.bool" + + +-- Interactions -- +[case testCapturePatternMultipleCaptures] +m: object + +match m: + case int(x): + reveal_type(x) # N: Revealed type is "builtins.int" + case str(x): + reveal_type(x) # N: Revealed type is "builtins.str" + +reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]" + +[case testCapturePatternPreexistingSame] +a: int +m: int + +match m: + case a: + reveal_type(a) # N: Revealed type is "builtins.int" + +[case testCapturePatternPreexistingIncompatible] +a: str +m: int + +match m: + case a: # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testCapturePatternPreexistingIncompatibleLater] +a: str +m: object + +match m: + case str(a): + reveal_type(a) # N: Revealed type is "builtins.str" + case int(a): # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "builtins.str" From 0e47d6f42469cc9e4b6c8d463f21ea2b264f5a64 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 12 Apr 2021 23:54:49 +0200 Subject: [PATCH 27/76] Implemented match_args parameter for dataclasses (bpo-43764) This also needs a typeshed change to work. --- mypy/plugins/dataclasses.py | 13 ++++++----- test-data/unit/check-python310.test | 30 +++++++++++++++++++++++++ test-data/unit/lib-stub/dataclasses.pyi | 2 +- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index c991521046bc..9eb6659859dd 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -112,6 +112,7 @@ def transform(self) -> None: 'eq': _get_decorator_bool_argument(self._ctx, 'eq', True), 'order': _get_decorator_bool_argument(self._ctx, 'order', False), 'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False), + 'match_args': _get_decorator_bool_argument(self._ctx, 'match_args', True), } # If there are no attributes, it may be that the semantic analyzer has not @@ -175,11 +176,13 @@ def transform(self) -> None: self.reset_init_only_vars(info, attributes) - # Add __match_args__. This is always added - str_type = ctx.api.named_type("__builtins__.str") - literals = [LiteralType(attr.name, str_type) for attr in attributes] # type: List[Type] - match_args_type = TupleType(literals, ctx.api.named_type("__builtins__.tuple")) - add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) + if (decorator_arguments['match_args'] and + ('__match_args__' not in info.names or info.names['__match_args__'].plugin_generated) and + attributes): + str_type = ctx.api.named_type("__builtins__.str") + literals = [LiteralType(attr.name, str_type) for attr in attributes if attr.is_in_init] # type: List[Type] + match_args_type = TupleType(literals, ctx.api.named_type("__builtins__.tuple")) + add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) info.metadata['dataclass'] = { 'attributes': [attr.serialize() for attr in attributes], diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 44ab671fc097..8a359f9c311a 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -358,6 +358,36 @@ match m: reveal_type(i) # N: Revealed type is "builtins.str" reveal_type(j) # N: Revealed type is "builtins.int" +[case testClassPatternCaptureDataclassNoMatchArgs] +from dataclasses import dataclass + +@dataclass(match_args=False) +class A: + a: str + b: int + +m: A + +match m: + case A(i, j): # E: Class doesn't define __match_args__ + pass + +[case testClassPatternCaptureDataclassPartialMatchArgs] +from dataclasses import dataclass, field + +@dataclass +class A: + a: str + b: int = field(init=False) + +m: A + +match m: + case A(i, j): # E: Too many positional patterns for class pattern + pass + case A(k): + reveal_type(k) # N: Revealed type is "builtins.str" + [case testClassPatternCaptureNamedTupleInline] from collections import namedtuple diff --git a/test-data/unit/lib-stub/dataclasses.pyi b/test-data/unit/lib-stub/dataclasses.pyi index 160cfcd066ba..7a6cb10d7b88 100644 --- a/test-data/unit/lib-stub/dataclasses.pyi +++ b/test-data/unit/lib-stub/dataclasses.pyi @@ -11,7 +11,7 @@ def dataclass(_cls: Type[_T]) -> Type[_T]: ... @overload def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ..., - unsafe_hash: bool = ..., frozen: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ... + unsafe_hash: bool = ..., frozen: bool = ..., match_args: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ... @overload From 77b997b2b72f4fd675bb6712bffd8e933e455867 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 15 Apr 2021 01:16:26 +0200 Subject: [PATCH 28/76] Refactor PatternChecker This is a big commit. It changes the design of PatternChecker to be more in line with ExpressionChecker. It allows reusing PatternChecker instances and will make supporting OrPattern easier. --- mypy/checker.py | 64 ++++- mypy/checkpattern.py | 398 +++++++++++++------------- mypy/fastparse.py | 2 +- mypy/patterns.py | 26 +- mypy/plugins/dataclasses.py | 6 +- mypy/semanal.py | 2 +- test-data/unit/check-python310.test | 56 +++- test-data/unit/parse-python310.test | 3 +- test-data/unit/semanal-python310.test | 6 +- 9 files changed, 327 insertions(+), 236 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 54cd24bf6305..a177ca92abd9 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,6 +2,7 @@ import itertools import fnmatch +from collections import defaultdict from contextlib import contextmanager from typing import ( @@ -10,7 +11,6 @@ ) from typing_extensions import Final -from mypy.checkpattern import PatternChecker from mypy.errors import Errors, report_internal_error from mypy.nodes import ( SymbolTable, Statement, MypyFile, Var, Expression, Lvalue, Node, @@ -45,6 +45,7 @@ from mypy.checkmember import ( analyze_member_access, analyze_descriptor_access, type_object_type, ) +from mypy.checkpattern import PatternChecker, PatternType from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, @@ -164,6 +165,8 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # Helper for type checking expressions expr_checker = None # type: mypy.checkexpr.ExpressionChecker + pattern_checker = None # type: PatternChecker + tscope = None # type: Scope scope = None # type: CheckerScope # Stack of function return types @@ -220,6 +223,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option self.msg = MessageBuilder(errors, modules) self.plugin = plugin self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg, self.plugin) + self.pattern_checker = PatternChecker(self, self.msg, self.plugin) self.tscope = Scope() self.scope = CheckerScope(tree) self.binder = ConditionalTypeBinder() @@ -3723,21 +3727,23 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=0): - t = get_proper_type(self.expr_checker.accept(s.subject)) + subject_type = get_proper_type(self.expr_checker.accept(s.subject)) - if isinstance(t, DeletedType): - self.msg.deleted_as_rvalue(t, s) + if isinstance(subject_type, DeletedType): + self.msg.deleted_as_rvalue(subject_type, s) - pattern_checker = PatternChecker(self, self.msg, self.plugin, s.subject, t) + pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] - type_maps = pattern_checker.check_patterns(s.patterns) + type_maps = get_type_maps_from_pattern_types(pattern_types) + self.infer_names_from_type_maps(type_maps) - for b, g, tm in zip(s.bodies, s.guards, type_maps): + for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): with self.binder.frame_context(can_skip=True, fall_through=2): - if not b.is_unreachable: - self.push_type_map(tm) - else: + if b.is_unreachable or pattern_type.type is None: self.push_type_map(None) + else: + self.binder.put(s.subject, pattern_type.type) + self.push_type_map(pattern_type.captures) if g is not None: gt = get_proper_type(self.expr_checker.accept(g)) @@ -3749,6 +3755,39 @@ def visit_match_stmt(self, s: MatchStmt) -> None: self.push_type_map(if_map) self.accept(b) + def infer_names_from_type_maps(self, type_maps: List[TypeMap]) -> None: + all_captures = defaultdict(list) # type: Dict[Var, List[Tuple[NameExpr, Type]]] + for tm in type_maps: + if tm is not None: + for expr, typ in tm.items(): + if isinstance(expr, NameExpr): + node = expr.node + assert isinstance(node, Var) + all_captures[node].append((expr, typ)) + + for var, captures in all_captures.items(): + conflict = False + types = [] # type: List[Type] + for expr, typ in captures: + types.append(typ) + + previous_type, _, inferred = self.check_lvalue(expr) + if previous_type is not None: + conflict = True + self.check_subtype(typ, previous_type, expr, + msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, + subtype_label="pattern captures type", + supertype_label="variable has type") + for type_map in type_maps: + if type_map is not None and expr in type_map: + del type_map[expr] + + if not conflict: + new_type = UnionType.make_union(types) + # Infer the union type at the first occurrence + first_occurrence, _ = captures[0] + self.infer_variable_type(var, first_occurrence, new_type, first_occurrence) + def make_fake_typeinfo(self, curr_module_fullname: str, class_gen_name: str, @@ -5870,3 +5909,8 @@ def collapse_walrus(e: Expression) -> Expression: if isinstance(e, AssignmentExpr): return e.target return e + + +def get_type_maps_from_pattern_types(pattern_types: List[PatternType]) -> List[TypeMap]: + return [pattern_type.captures if pattern_type is not None else None + for pattern_type in pattern_types] diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 8feb46f413d5..1054464c8897 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,28 +1,25 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" -from collections import defaultdict -from typing import List, Optional, Union, Tuple, Dict +from typing import List, Optional, Union, Tuple, Dict, NamedTuple, Set -from mypy import message_registry +import mypy.checker from mypy.expandtype import expand_type_by_instance from mypy.join import join_types - +from mypy.literals import literal_hash from mypy.messages import MessageBuilder -from mypy.nodes import Expression, NameExpr, ARG_POS, TypeAlias, TypeInfo, Var +from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr from mypy.patterns import ( Pattern, AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, ClassPattern, MappingKeyPattern ) from mypy.plugin import Plugin -from mypy.subtypes import is_subtype, find_member, is_equivalent +from mypy.subtypes import is_subtype, find_member from mypy.typeops import try_getting_str_literals_from_type from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, NoneType, UninhabitedType, get_proper_type, - TypedDictType, TupleType, UnionType + TypedDictType, TupleType ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor -import mypy.checker - self_match_type_names = [ "builtins.bool", @@ -39,16 +36,19 @@ ] -class PatternChecker(PatternVisitor[Optional[Type]]): +PatternType = NamedTuple( + 'PatternType', + [ + ('type', Optional[Type]), + ('captures', Dict[Expression, Type]), + ]) + + +class PatternChecker(PatternVisitor[PatternType]): """Pattern checker. This class checks if a pattern can match a type, what the type can be narrowed to, and what type capture patterns should be inferred as. - - visit_ methods return the type a pattern narrows to or None if the pattern can't match the - subject. They should not be called directly, as they change the state. - - Use check_patterns() instead. A new PatternChecker should be used for each match statement. """ # Some services are provided by a TypeChecker instance. @@ -59,161 +59,113 @@ class PatternChecker(PatternVisitor[Optional[Type]]): plugin = None # type: Plugin # The expression being matched against the pattern subject = None # type: Expression - # Type of the subject to check the (sub)pattern against - type_stack = [] # type: List[Type] - # TODO: This type looks kind of ugly - captured_types = None # type: Dict[Var, List[Tuple[NameExpr, Type]]] - current_type_map = {} # type: 'mypy.checker.TypeMap' + subject_type = None # type: Type + # Type of the subject to check the (sub)pattern against + type_context = None # type: List[Type] self_match_types = None # type: List[Type] - def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: Plugin, - subject: Expression, subject_type: Type) -> None: + def __init__(self, + chk: 'mypy.checker.TypeChecker', + msg: MessageBuilder, plugin: Plugin + ) -> None: self.chk = chk self.msg = msg self.plugin = plugin - self.subject = subject - self.type_stack.append(subject_type) - self.captured_types = defaultdict(list) + self.type_context = [] self.self_match_types = self.generate_self_match_types() - def check_patterns(self, patterns: List[Pattern]) -> List['mypy.checker.TypeMap']: - type_maps = [] # type: List['mypy.checker.TypeMap'] - for pattern in patterns: - self.current_type_map = {} # type: Dict[Expression, Type] - pattern_type = self.visit(pattern) - if pattern_type is None: - # This case is unreachable - type_maps.append(None) - elif not is_equivalent(self.type_stack[-1], pattern_type): - self.current_type_map[self.subject] = pattern_type - type_maps.append(self.current_type_map) - else: - type_maps.append(self.current_type_map) - - type_maps = self.infer_types(type_maps) - return type_maps - - def infer_types(self, type_maps: List['mypy.checker.TypeMap']) -> List['mypy.checker.TypeMap']: - for var, captures in self.captured_types.items(): - conflict = False - types = [] # type: List[Type] - for expr, typ in captures: - types.append(typ) - - previous_type, _, inferred = self.chk.check_lvalue(expr) - if previous_type is not None: - conflict = True - self.chk.check_subtype(typ, previous_type, expr, - msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, - subtype_label="pattern captures type", - supertype_label="variable has type") - for type_map in type_maps: - if type_map is not None and expr in type_map: - del type_map[expr] - - if not conflict: - new_type = UnionType.make_union(types) - # Infer the union type at the first occurrence - self.chk.infer_variable_type(var, captures[0][0], new_type, captures[0][0]) - - return type_maps - - def visit(self, o: Pattern) -> Optional[Type]: - return o.accept(self) - - def visit_as_pattern(self, o: AsPattern) -> Optional[Type]: - typ = self.visit(o.pattern) - specific_type = get_more_specific_type(typ, self.type_stack[-1]) - if specific_type is None: - return None - self.type_stack.append(specific_type) - self.check_capture(o.name) - self.type_stack.pop() - return typ + def accept(self, o: Pattern, type_context: Type) -> PatternType: + self.type_context.append(type_context) + result = o.accept(self) + self.type_context.pop() + + return result - def visit_or_pattern(self, o: OrPattern) -> Optional[Type]: - return self.type_stack[-1] + def visit_as_pattern(self, o: AsPattern) -> PatternType: + pattern_type = self.accept(o.pattern, self.type_context[-1]) + typ, type_map = pattern_type + if typ is None: + return pattern_type + as_pattern_type = self.accept(o.name, typ) + self.update_type_map(type_map, as_pattern_type.captures) + return PatternType(typ, type_map) - def visit_literal_pattern(self, o: LiteralPattern) -> Optional[Type]: + def visit_or_pattern(self, o: OrPattern) -> PatternType: + # TODO + return PatternType(self.type_context[-1], {}) + + def visit_literal_pattern(self, o: LiteralPattern) -> PatternType: literal_type = self.get_literal_type(o.value) - return get_more_specific_type(literal_type, self.type_stack[-1]) + typ = get_more_specific_type(literal_type, self.type_context[-1]) + return PatternType(typ, {}) def get_literal_type(self, l: Union[int, complex, float, str, bytes, None]) -> Type: - # TODO: Should we use ExprNodes instead of the raw value here? - if isinstance(l, int): - return self.chk.named_type("builtins.int") + if l is None: + typ = NoneType() # type: Type + elif isinstance(l, int): + typ = self.chk.named_type("builtins.int") elif isinstance(l, complex): - return self.chk.named_type("builtins.complex") + typ = self.chk.named_type("builtins.complex") elif isinstance(l, float): - return self.chk.named_type("builtins.float") + typ = self.chk.named_type("builtins.float") elif isinstance(l, str): - return self.chk.named_type("builtins.str") + typ = self.chk.named_type("builtins.str") elif isinstance(l, bytes): - return self.chk.named_type("builtins.bytes") + typ = self.chk.named_type("builtins.bytes") elif isinstance(l, bool): - return self.chk.named_type("builtins.bool") - elif l is None: - return NoneType() + typ = self.chk.named_type("builtins.bool") else: assert False, "Invalid literal in literal pattern" - def visit_capture_pattern(self, o: CapturePattern) -> Optional[Type]: - self.check_capture(o.name) - return self.type_stack[-1] + return typ - def check_capture(self, capture: NameExpr) -> None: - node = capture.node + def visit_capture_pattern(self, o: CapturePattern) -> PatternType: + node = o.name.node assert isinstance(node, Var) - self.captured_types[node].append((capture, self.type_stack[-1])) - type_map = self.current_type_map - if type_map is not None: - type_map[capture] = self.type_stack[-1] + return PatternType(self.type_context[-1], {o.name: self.type_context[-1]}) - def visit_wildcard_pattern(self, o: WildcardPattern) -> Optional[Type]: - return self.type_stack[-1] + def visit_wildcard_pattern(self, o: WildcardPattern) -> PatternType: + return PatternType(self.type_context[-1], {}) - def visit_value_pattern(self, o: ValuePattern) -> Optional[Type]: + def visit_value_pattern(self, o: ValuePattern) -> PatternType: typ = self.chk.expr_checker.accept(o.expr) - return get_more_specific_type(typ, self.type_stack[-1]) + specific_typ = get_more_specific_type(typ, self.type_context[-1]) + return PatternType(specific_typ, {}) - def visit_sequence_pattern(self, o: SequencePattern) -> Optional[Type]: - current_type = self.type_stack[-1] - inner_type = self.get_sequence_type(get_proper_type(current_type)) + def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: + current_type = self.type_context[-1] + inner_type = self.get_sequence_type(current_type) if inner_type is None: if is_subtype(self.chk.named_type("typing.Iterable"), current_type): # Current type is more general, but the actual value could still be iterable inner_type = self.chk.named_type("builtins.object") else: - # Pattern can't match - return None + return early_non_match() - self.type_stack.append(inner_type) new_inner_type = UninhabitedType() # type: Type + captures = {} # type: Dict[Expression, Type] + can_match = True for p in o.patterns: - pattern_type = self.visit(p) - if pattern_type is None: - return None - new_inner_type = join_types(new_inner_type, pattern_type) - self.type_stack.pop() - iterable = self.chk.named_generic_type("typing.Iterable", [new_inner_type]) - if self.chk.type_is_iterable(current_type): - proper_type = get_proper_type(current_type) - assert isinstance(proper_type, Instance) - empty_type = fill_typevars(proper_type.type) - partial_type = expand_type_by_instance(empty_type, iterable) - new_type = expand_type_by_instance(partial_type, proper_type) - else: - new_type = iterable + pattern_type = self.accept(p, inner_type) + typ, type_map = pattern_type + if typ is None: + can_match = False + else: + new_inner_type = join_types(new_inner_type, typ) + self.update_type_map(captures, type_map) - if is_subtype(new_type, current_type): - return new_type - else: - return current_type + new_type = None # type: Optional[Type] + if can_match: + new_type = self.construct_iterable_child(current_type, new_inner_type) + if not is_subtype(new_type, current_type): + new_type = current_type + return PatternType(new_type, captures) - def get_sequence_type(self, t: ProperType) -> Optional[Type]: + def get_sequence_type(self, t: Type) -> Optional[Type]: + t = get_proper_type(t) if isinstance(t, AnyType): return AnyType(TypeOfAny.from_another_any, t) @@ -222,30 +174,36 @@ def get_sequence_type(self, t: ProperType) -> Optional[Type]: else: return None - def visit_starred_pattern(self, o: StarredPattern) -> Optional[Type]: - if not isinstance(o.capture, WildcardPattern): - list_type = self.chk.named_generic_type('builtins.list', [self.type_stack[-1]]) - self.type_stack.append(list_type) - self.visit_capture_pattern(o.capture) - self.type_stack.pop() - return self.type_stack[-1] + def visit_starred_pattern(self, o: StarredPattern) -> PatternType: + if isinstance(o.capture, CapturePattern): + list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]]) + pattern_type = self.accept(o.capture, list_type) + captures = pattern_type.captures + elif isinstance(o.capture, WildcardPattern): + captures = {} + else: + assert False + return PatternType(self.type_context[-1], captures) - def visit_mapping_pattern(self, o: MappingPattern) -> Optional[Type]: - current_type = self.type_stack[-1] + def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: + current_type = self.type_context[-1] can_match = True + captures = {} # type: Dict[Expression, Type] for key, value in zip(o.keys, o.values): inner_type = self.get_mapping_item_type(o, current_type, key) if inner_type is None: can_match = False inner_type = self.chk.named_type("builtins.object") - self.type_stack.append(inner_type) - if self.visit(value) is None: + pattern_type = self.accept(value, inner_type) + if pattern_type is None: can_match = False - self.type_stack.pop() + else: + self.update_type_map(captures, pattern_type.captures) if can_match: - return self.type_stack[-1] + new_type = self.type_context[-1] # type: Optional[Type] else: - return None + new_type = None + return PatternType(new_type, captures) def get_mapping_item_type(self, pattern: MappingPattern, @@ -293,89 +251,103 @@ def get_simple_mapping_item_type(self, local_errors=local_errors) return result - def visit_class_pattern(self, o: ClassPattern) -> Optional[Type]: - current_type = self.type_stack[-1] + def visit_class_pattern(self, o: ClassPattern) -> PatternType: + current_type = self.type_context[-1] + + # + # Check class type + # class_name = o.class_ref.fullname assert class_name is not None sym = self.chk.lookup_qualified(class_name) if isinstance(sym.node, TypeAlias) and not sym.node.no_args: self.msg.fail("Class pattern class must not be a type alias with type parameters", o) - return None + return early_non_match() if isinstance(sym.node, (TypeAlias, TypeInfo)): typ = self.chk.named_type(class_name) else: self.msg.fail('Class pattern must be a type. Found "{}"'.format(sym.type), o.class_ref) - return None + return early_non_match() + # + # Convert positional to keyword patterns + # keyword_pairs = [] # type: List[Tuple[Optional[str], Pattern]] - match_arg_names = [] # type: List[Optional[str]] + match_arg_set = set() # type: Set[str] - can_match = True - - if self.should_self_match(typ): - if len(o.positionals) >= 1: - self.type_stack.append(typ) - if self.visit(o.positionals[0]) is None: - can_match = False - self.type_stack.pop() + captures = {} # type: Dict[Expression, Type] + if len(o.positionals) != 0: + if self.should_self_match(typ): if len(o.positionals) > 1: self.msg.fail("Too many positional patterns for class pattern", o) - self.type_stack.append(self.chk.named_type("builtins.object")) - for p in o.positionals[1:]: - if self.visit(p) is None: - can_match = False - self.type_stack.pop() - else: - match_args_type = find_member("__match_args__", typ, typ) - - if match_args_type is None and can_match: - if len(o.positionals) >= 1: - self.msg.fail("Class doesn't define __match_args__", o) - - proper_match_args_type = get_proper_type(match_args_type) - if isinstance(proper_match_args_type, TupleType): - match_arg_names = get_match_arg_names(proper_match_args_type) - - if len(o.positionals) > len(match_arg_names): - self.msg.fail("Too many positional patterns for class pattern", o) - match_arg_names += [None] * (len(o.positionals) - len(match_arg_names)) + pattern_type = self.accept(o.positionals[0], typ) + if pattern_type.type is None: + return pattern_type + captures = pattern_type.captures else: - match_arg_names = [None] * len(o.positionals) + match_args_type = find_member("__match_args__", typ, typ) - positional_names = set() - - for arg_name, pos in zip(match_arg_names, o.positionals): - keyword_pairs.append((arg_name, pos)) - positional_names.add(arg_name) - - keyword_names = set() + if match_args_type is None: + self.msg.fail("Class doesn't define __match_args__", o) + return early_non_match() + + proper_match_args_type = get_proper_type(match_args_type) + if isinstance(proper_match_args_type, TupleType): + match_arg_names = get_match_arg_names(proper_match_args_type) + + if len(o.positionals) > len(match_arg_names): + self.msg.fail("Too many positional patterns for class pattern", o) + return early_non_match() + else: + match_arg_names = [None] * len(o.positionals) + + for arg_name, pos in zip(match_arg_names, o.positionals): + keyword_pairs.append((arg_name, pos)) + if arg_name is not None: + match_arg_set.add(arg_name) + + # + # Check for duplicate patterns + # + keyword_arg_set = set() + has_duplicates = False for key, value in zip(o.keyword_keys, o.keyword_values): keyword_pairs.append((key, value)) - if key in match_arg_names: + if key in match_arg_set: self.msg.fail('Keyword "{}" already matches a positional pattern'.format(key), value) - elif key in keyword_names: + has_duplicates = True + elif key in keyword_arg_set: self.msg.fail('Duplicate keyword pattern "{}"'.format(key), value) - keyword_names.add(key) + has_duplicates = True + keyword_arg_set.add(key) + if has_duplicates: + return early_non_match() + + # + # Check keyword patterns + # + can_match = True for keyword, pattern in keyword_pairs: + key_type = None # type: Optional[Type] if keyword is not None: key_type = find_member(keyword, typ, current_type) - if key_type is None: - key_type = self.chk.named_type("builtins.object") - else: - key_type = self.chk.named_type("builtins.object") + if key_type is None: + key_type = AnyType(TypeOfAny.implementation_artifact) - self.type_stack.append(key_type) - if self.visit(pattern) is None: + pattern_type = self.accept(pattern, key_type) + if pattern_type is None: can_match = False - self.type_stack.pop() + else: + self.update_type_map(captures, pattern_type.captures) if can_match: - return get_more_specific_type(current_type, typ) + new_type = get_more_specific_type(current_type, typ) else: - return None + new_type = None + return PatternType(new_type, captures) def should_self_match(self, typ: ProperType) -> bool: if isinstance(typ, Instance) and typ.type.is_named_tuple: @@ -396,6 +368,34 @@ def generate_self_match_types(self) -> List[Type]: return types + def update_type_map(self, + original_type_map: Dict[Expression, Type], + extra_type_map: Dict[Expression, Type] + ) -> None: + # Calculating this would not be needed if TypeMap directly used literal hashes instead of + # expressions, as suggested in the TODO above it's definition + already_captured = set(literal_hash(expr) for expr in original_type_map) + for expr, typ in extra_type_map.items(): + if literal_hash(expr) in already_captured: + assert isinstance(expr, NameExpr) + node = expr.node + assert node is not None + self.msg.fail('Multiple assignments to name "{}" in pattern'.format(node.name), + expr) + else: + original_type_map[expr] = typ + + def construct_iterable_child(self, outer_type: Type, inner_type: Type) -> Type: + iterable = self.chk.named_generic_type("typing.Iterable", [inner_type]) + if self.chk.type_is_iterable(outer_type): + proper_type = get_proper_type(outer_type) + assert isinstance(proper_type, Instance) + empty_type = fill_typevars(proper_type.type) + partial_type = expand_type_by_instance(empty_type, iterable) + return expand_type_by_instance(partial_type, proper_type) + else: + return iterable + def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: args = [] # type: List[Optional[str]] @@ -417,3 +417,7 @@ def get_more_specific_type(left: Optional[Type], right: Optional[Type]) -> Optio return right else: return None + + +def early_non_match() -> PatternType: + return PatternType(None, {}) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 53490ae8d594..b9ac434e4301 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1331,7 +1331,7 @@ def visit(self, node: Optional[AST]) -> Pattern: # MatchAs(expr pattern, identifier name) def visit_MatchAs(self, n: MatchAs) -> AsPattern: - node = AsPattern(self.visit(n.pattern), NameExpr(n.name)) + node = AsPattern(self.visit(n.pattern), self.set_line(CapturePattern(NameExpr(n.name)), n)) return self.set_line(node, n) # MatchOr(expr* pattern) diff --git a/mypy/patterns.py b/mypy/patterns.py index cb0d7260cd4b..58a1d9417d34 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -39,11 +39,22 @@ def __init__(self, expr: Expression) -> None: self.expr = expr +class CapturePattern(AlwaysTruePattern): + name = None # type: NameExpr + + def __init__(self, name: NameExpr): + super().__init__() + self.name = name + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_capture_pattern(self) + + class AsPattern(Pattern): pattern = None # type: Pattern - name = None # type: NameExpr + name = None # type: CapturePattern - def __init__(self, pattern: Pattern, name: NameExpr) -> None: + def __init__(self, pattern: Pattern, name: CapturePattern) -> None: super().__init__() self.pattern = pattern self.name = name @@ -78,17 +89,6 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_literal_pattern(self) -class CapturePattern(AlwaysTruePattern): - name = None # type: NameExpr - - def __init__(self, name: NameExpr): - super().__init__() - self.name = name - - def accept(self, visitor: PatternVisitor[T]) -> T: - return visitor.visit_capture_pattern(self) - - class WildcardPattern(AlwaysTruePattern): def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_wildcard_pattern(self) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 9eb6659859dd..77ade2fc1f9f 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -177,10 +177,12 @@ def transform(self) -> None: self.reset_init_only_vars(info, attributes) if (decorator_arguments['match_args'] and - ('__match_args__' not in info.names or info.names['__match_args__'].plugin_generated) and + ('__match_args__' not in info.names or + info.names['__match_args__'].plugin_generated) and attributes): str_type = ctx.api.named_type("__builtins__.str") - literals = [LiteralType(attr.name, str_type) for attr in attributes if attr.is_in_init] # type: List[Type] + literals = [LiteralType(attr.name, str_type) + for attr in attributes if attr.is_in_init] # type: List[Type] match_args_type = TupleType(literals, ctx.api.named_type("__builtins__.tuple")) add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) diff --git a/mypy/semanal.py b/mypy/semanal.py index b3d21d3bd689..ce93d2f6715b 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -3987,7 +3987,7 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: def visit_as_pattern(self, p: AsPattern) -> None: p.pattern.accept(self) - self.analyze_lvalue(p.name) + p.name.accept(self) def visit_or_pattern(self, p: OrPattern) -> None: for pattern in p.patterns: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 8a359f9c311a..789a27cd66e9 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -64,7 +64,7 @@ b: str -- Sequence Pattern -- -[case testSequenceCPatternCaptures] +[case testSequencePatternCaptures] from typing import List m: List[int] @@ -343,6 +343,34 @@ match m: reveal_type(k) # N: Revealed type is "builtins.tuple[Any]" [builtins fixtures/primitives.pyi] +[case testClassPatternNarrowSelfCapture] +m: object + +match m: + case bool(): + reveal_type(m) # N: Revealed type is "builtins.bool" + case bytearray(): + reveal_type(m) # N: Revealed type is "builtins.bytearray" + case bytes(): + reveal_type(m) # N: Revealed type is "builtins.bytes" + case dict(): + reveal_type(m) # N: Revealed type is "builtins.dict[Any, Any]" + case float(): + reveal_type(m) # N: Revealed type is "builtins.float" + case frozenset(): + reveal_type(m) # N: Revealed type is "builtins.frozenset[Any]" + case int(): + reveal_type(m) # N: Revealed type is "builtins.int" + case list(): + reveal_type(m) # N: Revealed type is "builtins.list[Any]" + case set(): + reveal_type(m) # N: Revealed type is "builtins.set[Any]" + case str(): + reveal_type(m) # N: Revealed type is "builtins.str" + case tuple(): + reveal_type(m) # N: Revealed type is "builtins.tuple[Any]" +[builtins fixtures/primitives.pyi] + [case testClassPatternCaptureDataclass] from dataclasses import dataclass @@ -551,7 +579,7 @@ m: object match m: case A(a=j): reveal_type(m) # N: Revealed type is "__main__.A" - reveal_type(j) # N: Revealed type is "builtins.object" + reveal_type(j) # N: Revealed type is "Any" [case testClassPatternDuplicateKeyword] class A: @@ -611,8 +639,8 @@ m: object match m: case A(i, j): - reveal_type(i) # N: Revealed type is "builtins.object" - reveal_type(j) # N: Revealed type is "builtins.object" + reveal_type(i) # N: Revealed type is "Any" + reveal_type(j) # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] [case testAnyTupleMatchArgs] @@ -627,9 +655,9 @@ m: object match m: case A(i, j, k): - reveal_type(i) # N: Revealed type is "builtins.object" - reveal_type(j) # N: Revealed type is "builtins.object" - reveal_type(k) # N: Revealed type is "builtins.object" + reveal_type(i) # N: Revealed type is "Any" + reveal_type(j) # N: Revealed type is "Any" + reveal_type(k) # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] [case testNonLiteralMatchArgs] @@ -648,7 +676,7 @@ match m: pass case A(i, j): reveal_type(i) # N: Revealed type is "builtins.str" - reveal_type(j) # N: Revealed type is "builtins.object" + reveal_type(j) # N: Revealed type is "Any" [builtins fixtures/tuple.pyi] [case testExternalMatchArgs] @@ -696,7 +724,7 @@ match m: -- Interactions -- -[case testCapturePatternMultipleCaptures] +[case testCapturePatternMultipleCases] m: object match m: @@ -707,6 +735,16 @@ match m: reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]" +[case testCapturePatternMultipleCaptures] +from typing import Iterable + +m: Iterable[int] + +match m: + case [x, x]: # E: Multiple assignments to name "x" in pattern + reveal_type(x) # N: Revealed type is "builtins.int" +[builtins fixtures/list.pyi] + [case testCapturePatternPreexistingSame] a: int m: int diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test index f6e2b998d176..4f87fbf0e77e 100644 --- a/test-data/unit/parse-python310.test +++ b/test-data/unit/parse-python310.test @@ -79,7 +79,8 @@ MypyFile:1( LiteralPattern:2( 1 IntExpr(1)) - NameExpr(b))) + CapturePattern:2( + NameExpr(b)))) Body( PassStmt:3()))) diff --git a/test-data/unit/semanal-python310.test b/test-data/unit/semanal-python310.test index a3f2924dc130..e889b12c0643 100644 --- a/test-data/unit/semanal-python310.test +++ b/test-data/unit/semanal-python310.test @@ -89,7 +89,8 @@ MypyFile:1( LiteralPattern:3( 1 IntExpr(1)) - NameExpr(a* [__main__.a]))) + CapturePattern:3( + NameExpr(a* [__main__.a])))) Body( ExpressionStmt:4( NameExpr(a [__main__.a]))))) @@ -156,7 +157,8 @@ MypyFile:1( LiteralPattern:3( 1 IntExpr(1)) - NameExpr(a* [__main__.a]))) + CapturePattern:3( + NameExpr(a* [__main__.a])))) Guard( NameExpr(a [__main__.a])) Body( From 5228bc7c64d59ccccea736e134707a180571a317 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 20 Apr 2021 16:48:03 +0200 Subject: [PATCH 29/76] Add match statement support for or pattern --- mypy/checker.py | 9 +--- mypy/checkpattern.py | 65 +++++++++++++++++++++++++---- test-data/unit/check-python310.test | 37 +++++++++++++++- 3 files changed, 95 insertions(+), 16 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a177ca92abd9..baaf48047a1f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -45,7 +45,7 @@ from mypy.checkmember import ( analyze_member_access, analyze_descriptor_access, type_object_type, ) -from mypy.checkpattern import PatternChecker, PatternType +from mypy.checkpattern import PatternChecker from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, @@ -3734,7 +3734,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] - type_maps = get_type_maps_from_pattern_types(pattern_types) + type_maps = [t.captures for t in pattern_types] # type: List[TypeMap] self.infer_names_from_type_maps(type_maps) for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): @@ -5909,8 +5909,3 @@ def collapse_walrus(e: Expression) -> Expression: if isinstance(e, AssignmentExpr): return e.target return e - - -def get_type_maps_from_pattern_types(pattern_types: List[PatternType]) -> List[TypeMap]: - return [pattern_type.captures if pattern_type is not None else None - for pattern_type in pattern_types] diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 1054464c8897..97920b04ffb7 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,4 +1,5 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" +from collections import defaultdict from typing import List, Optional, Union, Tuple, Dict, NamedTuple, Set import mypy.checker @@ -16,7 +17,7 @@ from mypy.typeops import try_getting_str_literals_from_type from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, NoneType, UninhabitedType, get_proper_type, - TypedDictType, TupleType + TypedDictType, TupleType, UnionType ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor @@ -94,8 +95,49 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType: return PatternType(typ, type_map) def visit_or_pattern(self, o: OrPattern) -> PatternType: - # TODO - return PatternType(self.type_context[-1], {}) + + # + # Check all the subpatterns + # + pattern_types = [] + for pattern in o.patterns: + pattern_types.append(self.accept(pattern, self.type_context[-1])) + + # + # Collect the final type + # + types = [] + for pattern_type in pattern_types: + if pattern_type.type is not None: + types.append(pattern_type.type) + + # + # Check the capture types + # + capture_types = defaultdict(list) # type: Dict[Var, List[Tuple[Expression, Type]]] + # Collect captures from the first subpattern + for expr, typ in pattern_types[0].captures.items(): + node = get_var(expr) + capture_types[node].append((expr, typ)) + + # Check if other subpatterns capture the same names + for i, pattern_type in enumerate(pattern_types[1:]): + vars = {get_var(expr) for expr, _ in pattern_type.captures.items()} + if capture_types.keys() != vars: + self.msg.fail("Alternative patterns bind different names", o.patterns[i]) + for expr, typ in pattern_type.captures.items(): + node = get_var(expr) + capture_types[node].append((expr, typ)) + + captures = {} # type: Dict[Expression, Type] + for var, capture_list in capture_types.items(): + typ = UninhabitedType() + for _, other in capture_list: + typ = join_types(typ, other) + + captures[capture_list[0][0]] = typ + + return PatternType(UnionType.make_union(types), captures) def visit_literal_pattern(self, o: LiteralPattern) -> PatternType: literal_type = self.get_literal_type(o.value) @@ -123,8 +165,6 @@ def get_literal_type(self, l: Union[int, complex, float, str, bytes, None]) -> T return typ def visit_capture_pattern(self, o: CapturePattern) -> PatternType: - node = o.name.node - assert isinstance(node, Var) return PatternType(self.type_context[-1], {o.name: self.type_context[-1]}) def visit_wildcard_pattern(self, o: WildcardPattern) -> PatternType: @@ -377,9 +417,7 @@ def update_type_map(self, already_captured = set(literal_hash(expr) for expr in original_type_map) for expr, typ in extra_type_map.items(): if literal_hash(expr) in already_captured: - assert isinstance(expr, NameExpr) - node = expr.node - assert node is not None + node = get_var(expr) self.msg.fail('Multiple assignments to name "{}" in pattern'.format(node.name), expr) else: @@ -421,3 +459,14 @@ def get_more_specific_type(left: Optional[Type], right: Optional[Type]) -> Optio def early_non_match() -> PatternType: return PatternType(None, {}) + + +def get_var(expr: Expression) -> Var: + """ + Warning: this in only true for expressions captured by a match statement. + Don't call it from anywhere else + """ + assert isinstance(expr, NameExpr) + node = expr.node + assert isinstance(node, Var) + return node diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 789a27cd66e9..946e2268d817 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -699,7 +699,7 @@ class B: [typing fixtures/typing-medium.pyi] --- as pattern -- +-- As Pattern -- [case testAsPattern] m: int @@ -723,6 +723,41 @@ match m: reveal_type(l) # N: Revealed type is "builtins.bool" +-- Or Pattern -- +[case testOrPatternNarrows] +m: object + +match m: + case 1 | 2: + reveal_type(m) # N: Revealed type is "builtins.int" + +[case testOrPatternNarrowsUnion] +m: object + +match m: + case 1 | "foo": + reveal_type(m) # N: Revealed type is "Union[builtins.int, builtins.str]" + +[case testOrPatterCapturesMissing] +from typing import List +m: List[int] + +match m: + case [x, y] | list(x): # E: Alternative patterns bind different names + reveal_type(x) # N: Revealed type is "builtins.object" + reveal_type(y) # N: Revealed type is "builtins.int*" +[builtins fixtures/list.pyi] + +[case testOrPatternCapturesJoin] +m: object + +match m: + case list(x) | dict(x): + reveal_type(x) # N: Revealed type is "builtins.object" +[builtins fixtures/list.pyi] +[builtins fixtures/dict.pyi] + + -- Interactions -- [case testCapturePatternMultipleCases] m: object From bba61faf79f2fcd3aafa4ca4dbf2dfadadbb39bb Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 26 Apr 2021 15:45:42 +0200 Subject: [PATCH 30/76] Fix make_simplified_union for instances with last_known_value --- mypy/typeops.py | 14 +++++++++++++- test-data/unit/check-python38.test | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 3b5ca73f8713..1760e9c00503 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -367,7 +367,8 @@ def make_simplified_union(items: Sequence[Type], # Keep track of the truishness info for deleted subtypes which can be relevant cbt = cbf = False for j, tj in enumerate(items): - if i != j and is_proper_subtype(tj, ti, keep_erased_types=keep_erased): + if i != j and is_proper_subtype(tj, ti, keep_erased_types=keep_erased) and \ + is_redundant_literal_instance(ti, tj): # We found a redundant item in the union. removed.add(j) cbt = cbt or tj.can_be_true @@ -805,3 +806,14 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool return True # TODO: support other types (see ExpressionChecker.has_member())? return False + + +def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool: + if not isinstance(general, Instance) or general.last_known_value is None: + return True + if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value: + return True + if isinstance(specific, UninhabitedType): + return True + + return False diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index 3e054a45400b..c970f56a864f 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -268,7 +268,7 @@ def f(x: int = (c := 4)) -> int: f(x=(y7 := 3)) reveal_type(y7) # N: Revealed type is "builtins.int" - reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "Literal[3]?" + reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "builtins.int" y8 # E: Name "y8" is not defined y7 = 1.0 # E: Incompatible types in assignment (expression has type "float", variable has type "int") From 39478fd8e7d5db6923847106425f39aa583036b9 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 27 Apr 2021 08:51:08 +0200 Subject: [PATCH 31/76] Infer literals from literal patterns in match statements --- mypy/checkpattern.py | 33 ++++++----------------------- mypy/semanal.py | 7 ++++-- test-data/unit/check-python310.test | 20 ++++++++++++++--- 3 files changed, 29 insertions(+), 31 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 97920b04ffb7..f90f83faeaab 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,6 +1,6 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Union, Tuple, Dict, NamedTuple, Set +from typing import List, Optional, Tuple, Dict, NamedTuple, Set import mypy.checker from mypy.expandtype import expand_type_by_instance @@ -14,10 +14,10 @@ ) from mypy.plugin import Plugin from mypy.subtypes import is_subtype, find_member -from mypy.typeops import try_getting_str_literals_from_type +from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union from mypy.types import ( - ProperType, AnyType, TypeOfAny, Instance, Type, NoneType, UninhabitedType, get_proper_type, - TypedDictType, TupleType, UnionType + ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, + TypedDictType, TupleType ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor @@ -137,33 +137,14 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: captures[capture_list[0][0]] = typ - return PatternType(UnionType.make_union(types), captures) + union_type = make_simplified_union(types) + return PatternType(union_type, captures) def visit_literal_pattern(self, o: LiteralPattern) -> PatternType: - literal_type = self.get_literal_type(o.value) + literal_type = self.chk.expr_checker.accept(o.expr) typ = get_more_specific_type(literal_type, self.type_context[-1]) return PatternType(typ, {}) - def get_literal_type(self, l: Union[int, complex, float, str, bytes, None]) -> Type: - if l is None: - typ = NoneType() # type: Type - elif isinstance(l, int): - typ = self.chk.named_type("builtins.int") - elif isinstance(l, complex): - typ = self.chk.named_type("builtins.complex") - elif isinstance(l, float): - typ = self.chk.named_type("builtins.float") - elif isinstance(l, str): - typ = self.chk.named_type("builtins.str") - elif isinstance(l, bytes): - typ = self.chk.named_type("builtins.bytes") - elif isinstance(l, bool): - typ = self.chk.named_type("builtins.bool") - else: - assert False, "Invalid literal in literal pattern" - - return typ - def visit_capture_pattern(self, o: CapturePattern) -> PatternType: return PatternType(self.type_context[-1], {o.name: self.type_context[-1]}) diff --git a/mypy/semanal.py b/mypy/semanal.py index ce93d2f6715b..becd4168eaee 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -79,8 +79,8 @@ ParamSpecExpr, MatchStmt ) from mypy.patterns import ( - AsPattern, OrPattern, CapturePattern, ValuePattern, SequencePattern, StarredPattern, - MappingPattern, ClassPattern + AsPattern, OrPattern, CapturePattern, LiteralPattern, ValuePattern, SequencePattern, + StarredPattern, MappingPattern, ClassPattern ) from mypy.tvar_scope import TypeVarLikeScope @@ -3996,6 +3996,9 @@ def visit_or_pattern(self, p: OrPattern) -> None: def visit_capture_pattern(self, p: CapturePattern) -> None: self.analyze_lvalue(p.name) + def visit_literal_pattern(self, p: LiteralPattern) -> None: + p.expr.accept(self) + def visit_value_pattern(self, p: ValuePattern) -> None: p.expr.accept(self) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 946e2268d817..09b9c693ed09 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -14,7 +14,7 @@ m: object match m: case 1: - reveal_type(m) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "Literal[1]?" [case testLiteralPatternAlreadyNarrower] m: bool @@ -715,6 +715,13 @@ match m: case int() as l: reveal_type(l) # N: Revealed type is "builtins.int" +[case testAsPatternCapturesOr] +m: object + +match m: + case 1 | 2 as n: + reveal_type(n) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" + [case testAsPatternAlreadyNarrower] m: bool @@ -729,14 +736,21 @@ m: object match m: case 1 | 2: - reveal_type(m) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" + +[case testOrPatternNarrowsStr] +m: object + +match m: + case "foo" | "bar": + reveal_type(m) # N: Revealed type is "Union[Literal['foo']?, Literal['bar']?]" [case testOrPatternNarrowsUnion] m: object match m: case 1 | "foo": - reveal_type(m) # N: Revealed type is "Union[builtins.int, builtins.str]" + reveal_type(m) # N: Revealed type is "Union[Literal[1]?, Literal['foo']?]" [case testOrPatterCapturesMissing] from typing import List From a865882bc259e8c0c1f93cff6c0798df2d73bcac Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 27 Apr 2021 10:40:24 +0200 Subject: [PATCH 32/76] Infer the rest pattern for mapping patterns --- mypy/checkpattern.py | 7 +++++++ test-data/unit/check-python310.test | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index f90f83faeaab..fff132fc73c7 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -220,6 +220,13 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: can_match = False else: self.update_type_map(captures, pattern_type.captures) + + if o.rest is not None: + # TODO: Infer dict type args + rest_type = self.accept(o.rest, self.chk.named_type("builtins.dict")) + assert rest_type is not None + self.update_type_map(captures, rest_type.captures) + if can_match: new_type = self.type_context[-1] # type: Optional[Type] else: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 09b9c693ed09..212822dfb120 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -266,6 +266,14 @@ match m: b: int [typing fixtures/typing-typeddict.pyi] +[case testMappingPatternCaptureRest] +m: object + +match m: + case {'k': 1, **r}: + reveal_type(r) # N: Revealed type is "builtins.dict[Any, Any]" +[builtins fixtures/dict.pyi] + -- Mapping patterns currently don't narrow -- -- Class Pattern -- From c35d92106ab5479618b0db6b618434d9126724ea Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Fri, 30 Apr 2021 19:21:40 +0200 Subject: [PATCH 33/76] Adjust ASTConverter to use dedicated pattern nodes The dedicated nodes will land in 3.10.0b1. This removes the need for a seperate PatternConverter. Patterns are now handled by the ASTConverter itself. This commit also undoes some changes made to accomondate the PatternConverter --- mypy-requirements.txt | 2 +- mypy/checkpattern.py | 73 +++---- mypy/fastparse.py | 289 +++++++++----------------- mypy/patterns.py | 72 ++----- mypy/reachability.py | 7 +- mypy/renaming.py | 7 +- mypy/semanal.py | 19 +- mypy/strconv.py | 15 +- mypy/traverser.py | 20 +- mypy/visitor.py | 22 +- test-data/unit/parse-python310.test | 274 +++++++++--------------- test-data/unit/semanal-python310.test | 50 +++-- test-requirements.txt | 5 +- 13 files changed, 317 insertions(+), 538 deletions(-) diff --git a/mypy-requirements.txt b/mypy-requirements.txt index 6c497d709e47..b5bb625d5a56 100644 --- a/mypy-requirements.txt +++ b/mypy-requirements.txt @@ -1,6 +1,6 @@ typing_extensions>=3.7.4 mypy_extensions>=0.4.3,<0.5.0 -typed_ast>=1.4.0,<1.5.0; python_version<'3.8' +typed_ast>=1.4.0,<1.5.0 types-typing-extensions>=3.7.0 types-mypy-extensions>=0.4.0 toml diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index fff132fc73c7..8e543e3dfb90 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -9,15 +9,15 @@ from mypy.messages import MessageBuilder from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr from mypy.patterns import ( - Pattern, AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, - SequencePattern, StarredPattern, MappingPattern, ClassPattern, MappingKeyPattern + Pattern, AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, + ClassPattern, SingletonPattern ) from mypy.plugin import Plugin from mypy.subtypes import is_subtype, find_member from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, - TypedDictType, TupleType + TypedDictType, TupleType, NoneType ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor @@ -86,12 +86,18 @@ def accept(self, o: Pattern, type_context: Type) -> PatternType: return result def visit_as_pattern(self, o: AsPattern) -> PatternType: - pattern_type = self.accept(o.pattern, self.type_context[-1]) - typ, type_map = pattern_type - if typ is None: - return pattern_type - as_pattern_type = self.accept(o.name, typ) - self.update_type_map(type_map, as_pattern_type.captures) + current_type = self.type_context[-1] + if o.pattern is not None: + pattern_type = self.accept(o.pattern, current_type) + typ, type_map = pattern_type + else: + typ, type_map = current_type, {} + + if typ is not None and o.name is not None: + typ = get_more_specific_type(typ, current_type) + if typ is not None: + type_map[o.name] = typ + return PatternType(typ, type_map) def visit_or_pattern(self, o: OrPattern) -> PatternType: @@ -140,22 +146,23 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: union_type = make_simplified_union(types) return PatternType(union_type, captures) - def visit_literal_pattern(self, o: LiteralPattern) -> PatternType: - literal_type = self.chk.expr_checker.accept(o.expr) - typ = get_more_specific_type(literal_type, self.type_context[-1]) - return PatternType(typ, {}) - - def visit_capture_pattern(self, o: CapturePattern) -> PatternType: - return PatternType(self.type_context[-1], {o.name: self.type_context[-1]}) - - def visit_wildcard_pattern(self, o: WildcardPattern) -> PatternType: - return PatternType(self.type_context[-1], {}) - def visit_value_pattern(self, o: ValuePattern) -> PatternType: typ = self.chk.expr_checker.accept(o.expr) specific_typ = get_more_specific_type(typ, self.type_context[-1]) return PatternType(specific_typ, {}) + def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: + value = o.value + if isinstance(value, bool): + typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool") + elif value is None: + typ = NoneType() + else: + assert False + + specific_type = get_more_specific_type(typ, self.type_context[-1]) + return PatternType(specific_type, {}) + def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: current_type = self.type_context[-1] inner_type = self.get_sequence_type(current_type) @@ -196,14 +203,10 @@ def get_sequence_type(self, t: Type) -> Optional[Type]: return None def visit_starred_pattern(self, o: StarredPattern) -> PatternType: - if isinstance(o.capture, CapturePattern): + captures = {} # type: Dict[Expression, Type] + if o.capture is not None: list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]]) - pattern_type = self.accept(o.capture, list_type) - captures = pattern_type.captures - elif isinstance(o.capture, WildcardPattern): - captures = {} - else: - assert False + captures[o.capture] = list_type return PatternType(self.type_context[-1], captures) def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: @@ -223,9 +226,7 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: if o.rest is not None: # TODO: Infer dict type args - rest_type = self.accept(o.rest, self.chk.named_type("builtins.dict")) - assert rest_type is not None - self.update_type_map(captures, rest_type.captures) + captures[o.rest] = self.chk.named_type("builtins.dict") if can_match: new_type = self.type_context[-1] # type: Optional[Type] @@ -236,14 +237,14 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: def get_mapping_item_type(self, pattern: MappingPattern, mapping_type: Type, - key_pattern: MappingKeyPattern + key: Expression ) -> Optional[Type]: local_errors = self.msg.clean_copy() local_errors.disable_count = 0 mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): result = self.chk.expr_checker.visit_typeddict_index_expr(mapping_type, - key_pattern.expr, + key, local_errors=local_errors ) # type: Optional[Type] # If we can't determine the type statically fall back to treating it as a normal @@ -253,7 +254,7 @@ def get_mapping_item_type(self, local_errors.disable_count = 0 result = self.get_simple_mapping_item_type(pattern, mapping_type, - key_pattern, + key, local_errors) if local_errors.is_errors(): @@ -261,19 +262,19 @@ def get_mapping_item_type(self, else: result = self.get_simple_mapping_item_type(pattern, mapping_type, - key_pattern, + key, local_errors) return result def get_simple_mapping_item_type(self, pattern: MappingPattern, mapping_type: Type, - key_pattern: MappingKeyPattern, + key: Expression, local_errors: MessageBuilder ) -> Type: result, _ = self.chk.expr_checker.check_method_call_by_name('__getitem__', mapping_type, - [key_pattern.expr], + [key], [ARG_POS], pattern, local_errors=local_errors) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index b9ac434e4301..41326fd0f16a 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -7,7 +7,6 @@ Tuple, Union, TypeVar, Callable, Sequence, Optional, Any, Dict, cast, List, overload ) -from mypy_extensions import trait from typing_extensions import Final, Literal, overload from mypy.sharedparse import ( @@ -32,8 +31,8 @@ FakeInfo, ) from mypy.patterns import ( - AsPattern, OrPattern, LiteralPattern, CapturePattern, WildcardPattern, ValuePattern, - SequencePattern, StarredPattern, MappingPattern, MappingKeyPattern, ClassPattern, Pattern + AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, + ClassPattern, SingletonPattern ) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, @@ -113,10 +112,22 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, if sys.version_info >= (3, 10): Match = ast3.Match + MatchValue = ast3.MatchValue + MatchSingleton = ast3.MatchSingleton + MatchSequence = ast3.MatchSequence + MatchStar = ast3.MatchStar + MatchMapping = ast3.MatchMapping + MatchClass = ast3.MatchClass MatchAs = ast3.MatchAs MatchOr = ast3.MatchOr else: Match = Any + MatchValue = Any + MatchSingleton = Any + MatchSequence = Any + MatchStar = Any + MatchMapping = Any + MatchClass = Any MatchAs = Any MatchOr = Any except ImportError: @@ -152,6 +163,7 @@ def parse(source: Union[str, bytes], module: Optional[str], errors: Optional[Errors] = None, options: Optional[Options] = None) -> MypyFile: + """Parse a source file, without doing any semantic analysis. Return the parse tree. If errors is not provided, raise ParseError @@ -302,14 +314,35 @@ def is_no_type_check_decorator(expr: ast3.expr) -> bool: return False -@trait -class Converter: - def __init__(self, errors: Optional[Errors]): +class ASTConverter: + def __init__(self, + options: Options, + is_stub: bool, + errors: Errors) -> None: + # 'C' for class, 'F' for function + self.class_and_function_stack = [] # type: List[Literal['C', 'F']] + self.imports = [] # type: List[ImportBase] + + self.options = options + self.is_stub = is_stub self.errors = errors + self.type_ignores = {} # type: Dict[int, List[str]] + # Cache of visit_X methods keyed by type of visited object self.visitor_cache = {} # type: Dict[type, Callable[[Optional[AST]], Any]] + def note(self, msg: str, line: int, column: int) -> None: + self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) + + def fail(self, + msg: str, + line: int, + column: int, + blocker: bool = True) -> None: + if blocker or not self.options.ignore_errors: + self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX) + def visit(self, node: Optional[AST]) -> Any: if node is None: return None @@ -327,45 +360,6 @@ def set_line(self, node: N, n: Union[ast3.expr, ast3.stmt, ast3.ExceptHandler]) node.end_line = getattr(n, "end_lineno", None) if isinstance(n, ast3.expr) else None return node - def note(self, msg: str, line: int, column: int) -> None: - if self.errors is not None: - self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) - - def fail(self, - msg: str, - line: int, - column: int, - blocker: bool = True) -> None: - if self.errors is not None: - self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX) - - -class ASTConverter(Converter): - # Errors is optional is superclass, but not here - errors = None # type: Errors - - def __init__(self, - options: Options, - is_stub: bool, - errors: Errors) -> None: - super().__init__(errors) - # 'C' for class, 'F' for function - self.class_and_function_stack = [] # type: List[Literal['C', 'F']] - self.imports = [] # type: List[ImportBase] - - self.options = options - self.is_stub = is_stub - - self.type_ignores = {} # type: Dict[int, List[str]] - - def fail(self, - msg: str, - line: int, - column: int, - blocker: bool = True) -> None: - if blocker or not self.options.ignore_errors: - super().fail(msg, line, column, blocker) - def translate_opt_expr_list(self, l: Sequence[Optional[AST]]) -> List[Optional[Expression]]: res = [] # type: List[Optional[Expression]] for e in l: @@ -1306,186 +1300,83 @@ def visit_Index(self, n: Index) -> Node: # Match(expr subject, match_case* cases) # python 3.10 and later def visit_Match(self, n: Match) -> MatchStmt: - pattern_converter = PatternConverter(self.options, self.errors) node = MatchStmt(self.visit(n.subject), - [pattern_converter.visit(c.pattern) for c in n.cases], + [self.visit(c.pattern) for c in n.cases], [self.visit(c.guard) for c in n.cases], [self.as_required_block(c.body, n.lineno) for c in n.cases]) return self.set_line(node, n) - -class PatternConverter(Converter): - # Errors is optional is superclass, but not here - errors = None # type: Errors - - has_sequence = False # type: bool - has_mapping = False # type: bool - - def __init__(self, options: Options, errors: Errors) -> None: - super().__init__(errors) - - self.options = options - - def visit(self, node: Optional[AST]) -> Pattern: - return super().visit(node) - - # MatchAs(expr pattern, identifier name) - def visit_MatchAs(self, n: MatchAs) -> AsPattern: - node = AsPattern(self.visit(n.pattern), self.set_line(CapturePattern(NameExpr(n.name)), n)) - return self.set_line(node, n) - - # MatchOr(expr* pattern) - def visit_MatchOr(self, n: MatchOr) -> OrPattern: - node = OrPattern([self.visit(pattern) for pattern in n.patterns]) + def visit_MatchValue(self, n: MatchValue) -> ValuePattern: + node = ValuePattern(self.visit(n.value)) return self.set_line(node, n) - # Constant(constant value) - def visit_Constant(self, n: Constant) -> LiteralPattern: - val = n.value - if val is None or isinstance(val, (bool, int, float, complex, str, bytes)): - node = LiteralPattern(val, ASTConverter(self.options, False, self.errors).visit(n)) - else: - raise RuntimeError("Pattern not implemented for " + str(type(val))) + def visit_MatchSingleton(self, n: MatchSingleton) -> SingletonPattern: + node = SingletonPattern(n.value) return self.set_line(node, n) - # UnaryOp(unaryop op, expr operand) - def visit_UnaryOp(self, n: ast3.UnaryOp) -> LiteralPattern: - # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 - if not isinstance(n.operand, Constant): # type: ignore[misc] - raise RuntimeError("Pattern not implemented for " + str(type(n.operand))) - - value = self.assert_numeric_constant(n.operand) - - if isinstance(n.op, ast3.UAdd): - node = LiteralPattern(value, ASTConverter(self.options, False, self.errors).visit(n)) - elif isinstance(n.op, ast3.USub): - node = LiteralPattern(-value, ASTConverter(self.options, False, self.errors).visit(n)) - else: - raise RuntimeError("Pattern not implemented for " + str(type(n.op))) + def visit_MatchSequence(self, n: MatchSequence) -> SequencePattern: + patterns = [self.visit(p) for p in n.patterns] + stars = [p for p in patterns if isinstance(p, StarredPattern)] + assert len(stars) < 2 + node = SequencePattern(patterns) return self.set_line(node, n) - # BinOp(expr left, operator op, expr right) - def visit_BinOp(self, n: ast3.BinOp) -> LiteralPattern: - if isinstance(n.left, UnaryOp) and isinstance(n.left.op, ast3.USub): - left_val = -1 * self.assert_numeric_constant(n.left.operand) + def visit_MatchStar(self, n: MatchStar) -> StarredPattern: + if n.name is None: + node = StarredPattern(None) else: - left_val = self.assert_numeric_constant(n.left) - right_val = self.assert_numeric_constant(n.right) - - if left_val.imag != 0 or right_val.real: - raise RuntimeError("Unsupported pattern") - - if isinstance(n.op, ast3.Add): - node = LiteralPattern(left_val + right_val, - ASTConverter(self.options, False, self.errors).visit(n)) - elif isinstance(n.op, ast3.Sub): - node = LiteralPattern(left_val - right_val, - ASTConverter(self.options, False, self.errors).visit(n)) - else: - raise RuntimeError("Unsupported pattern") + node = StarredPattern(NameExpr(n.name)) return self.set_line(node, n) - def assert_numeric_constant(self, n: ast3.AST) -> Union[int, float, complex]: - # Constant is Any on python < 3.8, but this code is only reachable on python >= 3.10 - if isinstance(n, Constant): # type: ignore[misc] - val = n.value - if isinstance(val, (int, float, complex)): - return val - raise RuntimeError("Only numeric literals can be used with '+' and '-'. Found " - + str(type(n))) + def visit_MatchMapping(self, n: MatchMapping) -> MappingPattern: + keys = [self.visit(k) for k in n.keys] + values = [self.visit(v) for v in n.patterns] - # Name(identifier id, expr_context ctx) - def visit_Name(self, n: ast3.Name) -> Union[WildcardPattern, CapturePattern]: - node = None # type: Optional[Union[WildcardPattern, CapturePattern]] - if n.id == '_': - node = WildcardPattern() + if n.rest is None: + rest = None else: - node = CapturePattern(ASTConverter(self.options, False, self.errors).visit_Name(n)) - - return self.set_line(node, n) - - # Attribute(expr value, identifier attr, expr_context ctx) - def visit_Attribute(self, n: ast3.Attribute) -> ValuePattern: - # We can directly call `visit_Attribute`, as we know the type of n - node = ASTConverter(self.options, False, self.errors).visit_Attribute(n) - if not isinstance(node, MemberExpr): - raise RuntimeError("Unsupported pattern") - return self.set_line(ValuePattern(node), n) - - # List(expr* elts, expr_context ctx) - def visit_List(self, n: ast3.List) -> SequencePattern: - return self.make_sequence(n) - - # Tuple(expr* elts, expr_context ctx) - def visit_Tuple(self, n: ast3.Tuple) -> SequencePattern: - return self.make_sequence(n) - - def make_sequence(self, n: Union[ast3.List, ast3.Tuple]) -> SequencePattern: - patterns = [self.visit(p) for p in n.elts] - stars = [p for p in patterns if isinstance(p, StarredPattern)] - if len(stars) >= 2: - raise RuntimeError("Unsupported pattern") + rest = NameExpr(n.rest) - node = SequencePattern(patterns) + node = MappingPattern(keys, values, rest) return self.set_line(node, n) - # Starred(expr value, expr_context ctx) - def visit_Starred(self, n: ast3.Starred) -> StarredPattern: - expr = n.value - if not isinstance(expr, Name): - raise RuntimeError("Unsupported Pattern") - node = StarredPattern(self.visit_Name(expr)) + def visit_MatchClass(self, n: MatchClass) -> ClassPattern: + class_ref = self.visit(n.cls) + assert isinstance(class_ref, RefExpr) + positionals = [self.visit(p) for p in n.patterns] + keyword_keys = n.kwd_attrs + keyword_values = [self.visit(p) for p in n.kwd_patterns] + node = ClassPattern(class_ref, positionals, keyword_keys, keyword_values) return self.set_line(node, n) - # Dict(expr* keys, expr* values) - def visit_Dict(self, n: ast3.Dict) -> MappingPattern: - keys = [self.visit(k) for k in n.keys] - values = [self.visit(v) for v in n.values] - - if keys[-1] is None: - rest = values.pop() - keys.pop() + # MatchAs(expr pattern, identifier name) + def visit_MatchAs(self, n: MatchAs) -> AsPattern: + if n.name is None: + name = None else: - rest = None - - checked_keys = self.assert_key_patterns(keys) - - node = MappingPattern(checked_keys, values, rest) + name = NameExpr(n.name) + name = self.set_line(name, n) + node = AsPattern(self.visit(n.pattern), name) return self.set_line(node, n) - def assert_key_patterns(self, keys: List[Pattern]) -> List[MappingKeyPattern]: - for key in keys: - if not isinstance(key, MappingKeyPattern): - raise RuntimeError("Unsupported Pattern") - - return cast(List[MappingKeyPattern], keys) - - # Call(expr func, expr* args, keyword* keywords) - def visit_Call(self, n: ast3.Call) -> ClassPattern: - def raise_if_none(value: Optional[str]) -> str: - if value is None: - raise RuntimeError("Unsupported Pattern") - else: - return value - - class_ref = ASTConverter(self.options, False, self.errors).visit(n.func) - if not isinstance(class_ref, RefExpr): - raise RuntimeError("Unsupported Pattern") - positionals = [self.visit(p) for p in n.args] - keyword_keys = [raise_if_none(keyword.arg) for keyword in n.keywords] - keyword_values = [self.visit(keyword.value) for keyword in n.keywords] - - node = ClassPattern(class_ref, positionals, keyword_keys, keyword_values) + # MatchOr(expr* pattern) + def visit_MatchOr(self, n: MatchOr) -> OrPattern: + node = OrPattern([self.visit(pattern) for pattern in n.patterns]) return self.set_line(node, n) -class TypeConverter(Converter): - def __init__(self, errors: Optional[Errors], line: int = -1, override_column: int = -1, - assume_str_is_unicode: bool = True, is_evaluated: bool = True) -> None: - super().__init__(errors) +class TypeConverter: + def __init__(self, + errors: Optional[Errors], + line: int = -1, + override_column: int = -1, + assume_str_is_unicode: bool = True, + is_evaluated: bool = True, + ) -> None: + self.errors = errors self.line = line self.override_column = override_column self.node_stack = [] # type: List[AST] @@ -1548,6 +1439,14 @@ def parent(self) -> Optional[AST]: return None return self.node_stack[-2] + def fail(self, msg: str, line: int, column: int) -> None: + if self.errors: + self.errors.report(line, column, msg, blocker=True, code=codes.SYNTAX) + + def note(self, msg: str, line: int, column: int) -> None: + if self.errors: + self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) + def translate_expr_list(self, l: Sequence[ast3.expr]) -> List[Type]: return [self.visit(e) for e in l] diff --git a/mypy/patterns.py b/mypy/patterns.py index 58a1d9417d34..28b270f95077 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -3,7 +3,7 @@ from mypy_extensions import trait -from mypy.nodes import Node, MemberExpr, RefExpr, NameExpr, Expression +from mypy.nodes import Node, RefExpr, NameExpr, Expression from mypy.visitor import PatternVisitor # These are not real AST nodes. CPython represents patterns using the normal expression nodes. @@ -28,33 +28,11 @@ class AlwaysTruePattern(Pattern): __slots__ = () -@trait -class MappingKeyPattern(Pattern): - """A pattern that can be used as a key in a mapping pattern""" - - __slots__ = ("expr",) - - def __init__(self, expr: Expression) -> None: - super().__init__() - self.expr = expr - - -class CapturePattern(AlwaysTruePattern): - name = None # type: NameExpr - - def __init__(self, name: NameExpr): - super().__init__() - self.name = name - - def accept(self, visitor: PatternVisitor[T]) -> T: - return visitor.visit_capture_pattern(self) - - class AsPattern(Pattern): - pattern = None # type: Pattern - name = None # type: CapturePattern + pattern = None # type: Optional[Pattern] + name = None # type: Optional[NameExpr] - def __init__(self, pattern: Pattern, name: CapturePattern) -> None: + def __init__(self, pattern: Optional[Pattern], name: Optional[NameExpr]) -> None: super().__init__() self.pattern = pattern self.name = name @@ -74,34 +52,26 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_or_pattern(self) -LiteralPatternType = Union[int, complex, float, str, bytes, bool, None] - - -class LiteralPattern(MappingKeyPattern): - value = None # type: LiteralPatternType +class ValuePattern(Pattern): expr = None # type: Expression - def __init__(self, value: LiteralPatternType, expr: Expression): - super().__init__(expr) - self.value = value - - def accept(self, visitor: PatternVisitor[T]) -> T: - return visitor.visit_literal_pattern(self) - + def __init__(self, expr: Expression): + super().__init__() + self.expr = expr -class WildcardPattern(AlwaysTruePattern): def accept(self, visitor: PatternVisitor[T]) -> T: - return visitor.visit_wildcard_pattern(self) + return visitor.visit_value_pattern(self) -class ValuePattern(MappingKeyPattern): - expr = None # type: MemberExpr +class SingletonPattern(Pattern): + value = None # type: Union[bool, None] - def __init__(self, expr: MemberExpr): - super().__init__(expr) + def __init__(self, value: Union[bool, None]): + super().__init__() + self.value = value def accept(self, visitor: PatternVisitor[T]) -> T: - return visitor.visit_value_pattern(self) + return visitor.visit_singleton_pattern(self) class SequencePattern(Pattern): @@ -118,9 +88,9 @@ def accept(self, visitor: PatternVisitor[T]) -> T: # TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our # type hierarchy. Should it be? class StarredPattern(Pattern): - capture = None # type: Union[WildcardPattern, CapturePattern] + capture = None # type: Optional[NameExpr] - def __init__(self, capture: Union[WildcardPattern, CapturePattern]): + def __init__(self, capture: Optional[NameExpr]): super().__init__() self.capture = capture @@ -129,12 +99,12 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class MappingPattern(Pattern): - keys = None # type: List[MappingKeyPattern] + keys = None # type: List[Expression] values = None # type: List[Pattern] - rest = None # type: Optional[CapturePattern] + rest = None # type: Optional[NameExpr] - def __init__(self, keys: List[MappingKeyPattern], values: List[Pattern], - rest: Optional[CapturePattern]): + def __init__(self, keys: List[Expression], values: List[Pattern], + rest: Optional[NameExpr]): super().__init__() self.keys = keys self.values = values diff --git a/mypy/reachability.py b/mypy/reachability.py index 2a300984d4b1..d6a49a496ce6 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -9,7 +9,7 @@ Import, ImportFrom, ImportAll, LITERAL_YES ) from mypy.options import Options -from mypy.patterns import Pattern, WildcardPattern, CapturePattern +from mypy.patterns import Pattern, AsPattern, OrPattern from mypy.traverser import TraverserVisitor from mypy.literals import literal @@ -135,7 +135,10 @@ def infer_condition_value(expr: Expression, options: Options) -> int: def infer_pattern_value(pattern: Pattern) -> int: - if isinstance(pattern, (WildcardPattern, CapturePattern)): + if isinstance(pattern, AsPattern) and pattern.pattern is None: + return ALWAYS_TRUE + elif isinstance(pattern, OrPattern) and \ + any(infer_pattern_value(p) == ALWAYS_TRUE for p in pattern.patterns): return ALWAYS_TRUE else: return TRUTH_VALUE_UNKNOWN diff --git a/mypy/renaming.py b/mypy/renaming.py index 584b34a99b87..cb83225b6216 100644 --- a/mypy/renaming.py +++ b/mypy/renaming.py @@ -6,7 +6,7 @@ WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr, ImportFrom, MemberExpr, IndexExpr, Import, ClassDef ) -from mypy.patterns import CapturePattern +from mypy.patterns import AsPattern from mypy.traverser import TraverserVisitor # Scope kinds @@ -186,8 +186,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None: stmt.accept(self) self.leave_block() - def visit_capture_pattern(self, p: CapturePattern) -> None: - self.analyze_lvalue(p.name) + def visit_capture_pattern(self, p: AsPattern) -> None: + if p.name is not None: + self.analyze_lvalue(p.name) def analyze_lvalue(self, lvalue: Lvalue, is_nested: bool = False) -> None: """Process assignment; in particular, keep track of (re)defined names. diff --git a/mypy/semanal.py b/mypy/semanal.py index becd4168eaee..a54925fb53d7 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -79,7 +79,7 @@ ParamSpecExpr, MatchStmt ) from mypy.patterns import ( - AsPattern, OrPattern, CapturePattern, LiteralPattern, ValuePattern, SequencePattern, + AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, ClassPattern ) @@ -3986,19 +3986,15 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: # def visit_as_pattern(self, p: AsPattern) -> None: - p.pattern.accept(self) - p.name.accept(self) + if p.pattern is not None: + p.pattern.accept(self) + if p.name is not None: + self.analyze_lvalue(p.name) def visit_or_pattern(self, p: OrPattern) -> None: for pattern in p.patterns: pattern.accept(self) - def visit_capture_pattern(self, p: CapturePattern) -> None: - self.analyze_lvalue(p.name) - - def visit_literal_pattern(self, p: LiteralPattern) -> None: - p.expr.accept(self) - def visit_value_pattern(self, p: ValuePattern) -> None: p.expr.accept(self) @@ -4007,7 +4003,8 @@ def visit_sequence_pattern(self, p: SequencePattern) -> None: pattern.accept(self) def visit_starred_pattern(self, p: StarredPattern) -> None: - p.capture.accept(self) + if p.capture is not None: + self.analyze_lvalue(p.capture) def visit_mapping_pattern(self, p: MappingPattern) -> None: for key in p.keys: @@ -4015,7 +4012,7 @@ def visit_mapping_pattern(self, p: MappingPattern) -> None: for value in p.values: value.accept(self) if p.rest is not None: - p.rest.accept(self) + self.analyze_lvalue(p.rest) def visit_class_pattern(self, p: ClassPattern) -> None: p.class_ref.accept(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index a12798fe73e9..5ecf75c35ba0 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -554,21 +554,12 @@ def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> str: def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: return self.dump(o.patterns, o) - def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> str: - value = o.value - if isinstance(o.value, str): - value = "'" + self.str_repr(o.value) + "'" - return self.dump([value, o.expr], o) - - def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> str: - return self.dump([o.name], o) - - def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> str: - return self.dump([], o) - def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> str: return self.dump([o.expr], o) + def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> str: + return self.dump([o.value], o) + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> str: return self.dump(o.patterns, o) diff --git a/mypy/traverser.py b/mypy/traverser.py index 57f92908dbff..cea2cc3580eb 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -3,8 +3,10 @@ from typing import List from mypy_extensions import mypyc_attr -from mypy.patterns import AsPattern, OrPattern, CapturePattern, ValuePattern, SequencePattern, \ - StarredPattern, MappingPattern, ClassPattern, LiteralPattern +from mypy.patterns import ( + AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, + ClassPattern +) from mypy.visitor import NodeVisitor from mypy.nodes import ( Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, @@ -291,18 +293,15 @@ def visit_super_expr(self, o: SuperExpr) -> None: o.call.accept(self) def visit_as_pattern(self, o: AsPattern) -> None: - o.pattern.accept(self) + if o.pattern is not None: + o.pattern.accept(self) + if o.name is not None: + o.name.accept(self) def visit_or_pattern(self, o: OrPattern) -> None: for p in o.patterns: p.accept(self) - def visit_literal_pattern(self, o: LiteralPattern) -> None: - o.expr.accept(self) - - def visit_capture_pattern(self, o: CapturePattern) -> None: - o.name.accept(self) - def visit_value_pattern(self, o: ValuePattern) -> None: o.expr.accept(self) @@ -311,7 +310,8 @@ def visit_sequence_pattern(self, o: SequencePattern) -> None: p.accept(self) def visit_starred_patten(self, o: StarredPattern) -> None: - o.capture.accept(self) + if o.capture is not None: + o.capture.accept(self) def visit_mapping_pattern(self, o: MappingPattern) -> None: for key in o.keys: diff --git a/mypy/visitor.py b/mypy/visitor.py index 5dac830cfe17..9d3ebb6818b4 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -328,19 +328,11 @@ def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: pass @abstractmethod - def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> T: - pass - - @abstractmethod - def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> T: - pass - - @abstractmethod - def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> T: + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: pass @abstractmethod - def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> T: pass @abstractmethod @@ -623,16 +615,10 @@ def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: pass - def visit_literal_pattern(self, o: 'mypy.patterns.LiteralPattern') -> T: - pass - - def visit_capture_pattern(self, o: 'mypy.patterns.CapturePattern') -> T: - pass - - def visit_wildcard_pattern(self, o: 'mypy.patterns.WildcardPattern') -> T: + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: pass - def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> T: pass def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test index 4f87fbf0e77e..87e0e9d5d283 100644 --- a/test-data/unit/parse-python310.test +++ b/test-data/unit/parse-python310.test @@ -11,8 +11,7 @@ MypyFile:1( MatchStmt:1( NameExpr(a) Pattern( - LiteralPattern:2( - 1 + ValuePattern:2( IntExpr(1))) Body( PassStmt:3()))) @@ -29,8 +28,7 @@ MypyFile:1( NameExpr(a) NameExpr(b)) Pattern( - LiteralPattern:2( - 1 + ValuePattern:2( IntExpr(1))) Body( PassStmt:3()))) @@ -46,8 +44,7 @@ MypyFile:1( MatchStmt:1( NameExpr(a) Pattern( - LiteralPattern:2( - 1 + ValuePattern:2( IntExpr(1))) Guard( CallExpr:2( @@ -56,7 +53,7 @@ MypyFile:1( Body( PassStmt:3()) Pattern( - CapturePattern:4( + AsPattern:4( NameExpr(d))) Guard( ComparisonExpr:4( @@ -76,11 +73,9 @@ MypyFile:1( NameExpr(a) Pattern( AsPattern:2( - LiteralPattern:2( - 1 + ValuePattern:2( IntExpr(1)) - CapturePattern:2( - NameExpr(b)))) + NameExpr(b))) Body( PassStmt:3()))) @@ -116,22 +111,19 @@ MypyFile:1( MatchStmt:1( NameExpr(a) Pattern( - LiteralPattern:2( - 1 + ValuePattern:2( IntExpr(1))) Body( PassStmt:3()) Pattern( - LiteralPattern:4( - -1 + ValuePattern:4( UnaryExpr:4( - IntExpr(1)))) Body( PassStmt:5()) Pattern( - LiteralPattern:6( - (1+2j) + ValuePattern:6( OpExpr:6( + IntExpr(1) @@ -139,8 +131,7 @@ MypyFile:1( Body( PassStmt:7()) Pattern( - LiteralPattern:8( - (-1+2j) + ValuePattern:8( OpExpr:8( + UnaryExpr:8( @@ -150,8 +141,7 @@ MypyFile:1( Body( PassStmt:9()) Pattern( - LiteralPattern:10( - (1-2j) + ValuePattern:10( OpExpr:10( - IntExpr(1) @@ -159,8 +149,7 @@ MypyFile:1( Body( PassStmt:11()) Pattern( - LiteralPattern:12( - (-1-2j) + ValuePattern:12( OpExpr:12( - UnaryExpr:12( @@ -170,38 +159,32 @@ MypyFile:1( Body( PassStmt:13()) Pattern( - LiteralPattern:14( - 'str' + ValuePattern:14( StrExpr(str))) Body( PassStmt:15()) Pattern( - LiteralPattern:16( - b'bytes' + ValuePattern:16( BytesExpr(bytes))) Body( PassStmt:17()) Pattern( - LiteralPattern:18( - 'raw_string' + ValuePattern:18( StrExpr(raw_string))) Body( PassStmt:19()) Pattern( - LiteralPattern:20( - NameExpr(None))) + SingletonPattern:20()) Body( PassStmt:21()) Pattern( - LiteralPattern:22( - True - NameExpr(True))) + SingletonPattern:22( + True)) Body( PassStmt:23()) Pattern( - LiteralPattern:24( - False - NameExpr(False))) + SingletonPattern:24( + False)) Body( PassStmt:25()))) @@ -216,12 +199,12 @@ MypyFile:1( MatchStmt:1( NameExpr(a) Pattern( - CapturePattern:2( + AsPattern:2( NameExpr(x))) Body( PassStmt:3()) Pattern( - CapturePattern:4( + AsPattern:4( NameExpr(longName))) Body( PassStmt:5()))) @@ -235,7 +218,7 @@ MypyFile:1( MatchStmt:1( NameExpr(a) Pattern( - WildcardPattern:2()) + AsPattern:2()) Body( PassStmt:3()))) @@ -280,8 +263,7 @@ MypyFile:1( MatchStmt:2( NameExpr(a) Pattern( - LiteralPattern:3( - 1 + ValuePattern:3( IntExpr(1))) Body( PassStmt:4()))) @@ -330,136 +312,106 @@ MypyFile:1( PassStmt:5()) Pattern( SequencePattern:6( - LiteralPattern:6( - 1 + ValuePattern:6( IntExpr(1)))) Body( PassStmt:7()) Pattern( SequencePattern:8( - LiteralPattern:8( - 1 + ValuePattern:8( IntExpr(1)))) Body( PassStmt:9()) Pattern( SequencePattern:10( - LiteralPattern:10( - 1 + ValuePattern:10( IntExpr(1)))) Body( PassStmt:11()) Pattern( SequencePattern:12( - LiteralPattern:12( - 1 + ValuePattern:12( IntExpr(1)) - LiteralPattern:12( - 2 + ValuePattern:12( IntExpr(2)) - LiteralPattern:12( - 3 + ValuePattern:12( IntExpr(3)))) Body( PassStmt:13()) Pattern( SequencePattern:14( - LiteralPattern:14( - 1 + ValuePattern:14( IntExpr(1)) - LiteralPattern:14( - 2 + ValuePattern:14( IntExpr(2)) - LiteralPattern:14( - 3 + ValuePattern:14( IntExpr(3)))) Body( PassStmt:15()) Pattern( SequencePattern:16( - LiteralPattern:16( - 1 + ValuePattern:16( IntExpr(1)) - LiteralPattern:16( - 2 + ValuePattern:16( IntExpr(2)) - LiteralPattern:16( - 3 + ValuePattern:16( IntExpr(3)))) Body( PassStmt:17()) Pattern( SequencePattern:18( - LiteralPattern:18( - 1 + ValuePattern:18( IntExpr(1)) StarredPattern:18( - CapturePattern:18( - NameExpr(a))) - LiteralPattern:18( - 2 + NameExpr(a)) + ValuePattern:18( IntExpr(2)))) Body( PassStmt:19()) Pattern( SequencePattern:20( - LiteralPattern:20( - 1 + ValuePattern:20( IntExpr(1)) StarredPattern:20( - CapturePattern:20( - NameExpr(a))) - LiteralPattern:20( - 2 + NameExpr(a)) + ValuePattern:20( IntExpr(2)))) Body( PassStmt:21()) Pattern( SequencePattern:22( - LiteralPattern:22( - 1 + ValuePattern:22( IntExpr(1)) StarredPattern:22( - CapturePattern:22( - NameExpr(a))) - LiteralPattern:22( - 2 + NameExpr(a)) + ValuePattern:22( IntExpr(2)))) Body( PassStmt:23()) Pattern( SequencePattern:24( - LiteralPattern:24( - 1 + ValuePattern:24( IntExpr(1)) - StarredPattern:24( - WildcardPattern:24()) - LiteralPattern:24( - 2 + StarredPattern:24() + ValuePattern:24( IntExpr(2)))) Body( PassStmt:25()) Pattern( SequencePattern:26( - LiteralPattern:26( - 1 + ValuePattern:26( IntExpr(1)) - StarredPattern:26( - WildcardPattern:26()) - LiteralPattern:26( - 2 + StarredPattern:26() + ValuePattern:26( IntExpr(2)))) Body( PassStmt:27()) Pattern( SequencePattern:28( - LiteralPattern:28( - 1 + ValuePattern:28( IntExpr(1)) - StarredPattern:28( - WildcardPattern:28()) - LiteralPattern:28( - 2 + StarredPattern:28() + ValuePattern:28( IntExpr(2)))) Body( PassStmt:29()))) @@ -489,137 +441,109 @@ MypyFile:1( Pattern( MappingPattern:2( Key( - LiteralPattern:2( - 'k' - StrExpr(k))) + StrExpr(k)) Value( - CapturePattern:2( + AsPattern:2( NameExpr(v))))) Body( PassStmt:3()) Pattern( MappingPattern:4( Key( - ValuePattern:4( - MemberExpr:4( - NameExpr(a) - b))) + MemberExpr:4( + NameExpr(a) + b)) Value( - CapturePattern:4( + AsPattern:4( NameExpr(v))))) Body( PassStmt:5()) Pattern( MappingPattern:6( Key( - LiteralPattern:6( - 1 - IntExpr(1))) + IntExpr(1)) Value( - CapturePattern:6( + AsPattern:6( NameExpr(v))))) Body( PassStmt:7()) Pattern( MappingPattern:8( Key( - ValuePattern:8( - MemberExpr:8( - NameExpr(a) - c))) + MemberExpr:8( + NameExpr(a) + c)) Value( - CapturePattern:8( + AsPattern:8( NameExpr(v))))) Body( PassStmt:9()) Pattern( MappingPattern:10( Key( - LiteralPattern:10( - 'k' - StrExpr(k))) + StrExpr(k)) Value( - CapturePattern:10( + AsPattern:10( NameExpr(v1))) Key( - ValuePattern:10( - MemberExpr:10( - NameExpr(a) - b))) + MemberExpr:10( + NameExpr(a) + b)) Value( - CapturePattern:10( + AsPattern:10( NameExpr(v2))) Key( - LiteralPattern:10( - 1 - IntExpr(1))) + IntExpr(1)) Value( - CapturePattern:10( + AsPattern:10( NameExpr(v3))) Key( - ValuePattern:10( - MemberExpr:10( - NameExpr(a) - c))) + MemberExpr:10( + NameExpr(a) + c)) Value( - CapturePattern:10( + AsPattern:10( NameExpr(v4))))) Body( PassStmt:11()) Pattern( MappingPattern:12( Key( - LiteralPattern:12( - 'k1' - StrExpr(k1))) + StrExpr(k1)) Value( - LiteralPattern:12( - 1 + ValuePattern:12( IntExpr(1))) Key( - LiteralPattern:12( - 'k2' - StrExpr(k2))) + StrExpr(k2)) Value( - LiteralPattern:12( - 'str' + ValuePattern:12( StrExpr(str))) Key( - LiteralPattern:12( - 'k3' - StrExpr(k3))) + StrExpr(k3)) Value( - LiteralPattern:12( - b'bytes' + ValuePattern:12( BytesExpr(bytes))) Key( - LiteralPattern:12( - 'k4' - StrExpr(k4))) + StrExpr(k4)) Value( - LiteralPattern:12( - NameExpr(None))))) + SingletonPattern:12()))) Body( PassStmt:13()) Pattern( MappingPattern:14( Key( - LiteralPattern:14( - 'k' - StrExpr(k))) + StrExpr(k)) Value( - CapturePattern:14( + AsPattern:14( NameExpr(v))) Rest( - CapturePattern:14( - NameExpr(r))))) + NameExpr(r)))) Body( PassStmt:15()) Pattern( MappingPattern:16( Rest( - CapturePattern:16( - NameExpr(r))))) + NameExpr(r)))) Body( PassStmt:17()))) @@ -646,11 +570,9 @@ MypyFile:1( ClassPattern:4( NameExpr(B) Positionals( - LiteralPattern:4( - 1 + ValuePattern:4( IntExpr(1)) - LiteralPattern:4( - 2 + ValuePattern:4( IntExpr(2))))) Body( PassStmt:5()) @@ -658,13 +580,11 @@ MypyFile:1( ClassPattern:6( NameExpr(B) Positionals( - LiteralPattern:6( - 1 + ValuePattern:6( IntExpr(1))) Keyword( b - LiteralPattern:6( - 2 + ValuePattern:6( IntExpr(2))))) Body( PassStmt:7()) @@ -673,13 +593,11 @@ MypyFile:1( NameExpr(B) Keyword( a - LiteralPattern:8( - 1 + ValuePattern:8( IntExpr(1))) Keyword( b - LiteralPattern:8( - 2 + ValuePattern:8( IntExpr(2))))) Body( PassStmt:9()))) diff --git a/test-data/unit/semanal-python310.test b/test-data/unit/semanal-python310.test index e889b12c0643..a009636575dc 100644 --- a/test-data/unit/semanal-python310.test +++ b/test-data/unit/semanal-python310.test @@ -13,7 +13,7 @@ MypyFile:1( MatchStmt:2( NameExpr(x [__main__.x]) Pattern( - CapturePattern:3( + AsPattern:3( NameExpr(a* [__main__.a]))) Body( ExpressionStmt:4( @@ -33,7 +33,7 @@ MypyFile:1( MatchStmt:2( NameExpr(x [__main__.x]) Pattern( - CapturePattern:3( + AsPattern:3( NameExpr(a* [__main__.a]))) Body( PassStmt:4())) @@ -56,15 +56,13 @@ MypyFile:1( Pattern( SequencePattern:3( SequencePattern:3( - CapturePattern:3( + AsPattern:3( NameExpr(a* [__main__.a]))) MappingPattern:3( Key( - LiteralPattern:3( - 'k' - StrExpr(k))) + StrExpr(k)) Value( - CapturePattern:3( + AsPattern:3( NameExpr(b* [__main__.b])))))) Body( ExpressionStmt:4( @@ -72,6 +70,27 @@ MypyFile:1( ExpressionStmt:5( NameExpr(b [__main__.b]))))) +[case testMappingPatternRest] +x = 1 +match x: + case {**r}: + r +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + MappingPattern:3( + Rest( + NameExpr(r* [__main__.r])))) + Body( + ExpressionStmt:4( + NameExpr(r [__main__.r]))))) + + [case testAsPattern] x = 1 match x: @@ -86,11 +105,9 @@ MypyFile:1( NameExpr(x [__main__.x]) Pattern( AsPattern:3( - LiteralPattern:3( - 1 + ValuePattern:3( IntExpr(1)) - CapturePattern:3( - NameExpr(a* [__main__.a])))) + NameExpr(a* [__main__.a]))) Body( ExpressionStmt:4( NameExpr(a [__main__.a]))))) @@ -112,8 +129,7 @@ MypyFile:1( MatchStmt:3( NameExpr(x [__main__.x]) Pattern( - LiteralPattern:4( - 1 + ValuePattern:4( IntExpr(1))) Guard( NameExpr(a [__main__.a])) @@ -133,7 +149,7 @@ MypyFile:1( MatchStmt:2( NameExpr(x [__main__.x]) Pattern( - CapturePattern:3( + AsPattern:3( NameExpr(a* [__main__.a]))) Guard( NameExpr(a [__main__.a])) @@ -154,11 +170,9 @@ MypyFile:1( NameExpr(x [__main__.x]) Pattern( AsPattern:3( - LiteralPattern:3( - 1 + ValuePattern:3( IntExpr(1)) - CapturePattern:3( - NameExpr(a* [__main__.a])))) + NameExpr(a* [__main__.a]))) Guard( NameExpr(a [__main__.a])) Body( diff --git a/test-requirements.txt b/test-requirements.txt index 8986b4115cd5..9b70608a9c48 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -6,14 +6,13 @@ flake8-bugbear; python_version >= '3.5' flake8-pyi>=20.5; python_version >= '3.6' lxml>=4.4.0 psutil>=4.0 -# pytest 6.2 does not support Python 3.5 -pytest>=6.1.0,<6.2.0 +# pytest 6.2.3 does not support Python 3.10 +pytest>6.2.3 pytest-xdist>=1.34.0,<2.0.0 pytest-forked>=1.3.0,<2.0.0 pytest-cov>=2.10.0,<3.0.0 typing>=3.5.2; python_version < '3.5' py>=1.5.2 -typed_ast>=1.4.0,<1.5.0 virtualenv<20 setuptools!=50 importlib-metadata==0.20 From 1aae2c1ed4e27603efd4c12d59c88dda9dfd5fc7 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Sat, 1 May 2021 13:19:16 +0200 Subject: [PATCH 34/76] Improve type inference for tuples checked against sequence patterns --- mypy/checkpattern.py | 98 ++++++++++++++++++++++++----- test-data/unit/check-python310.test | 78 +++++++++++++++++++++++ 2 files changed, 161 insertions(+), 15 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 8e543e3dfb90..2179eecfb8e8 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,6 +1,6 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Tuple, Dict, NamedTuple, Set +from typing import List, Optional, Tuple, Dict, NamedTuple, Set, cast import mypy.checker from mypy.expandtype import expand_type_by_instance @@ -164,29 +164,76 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: return PatternType(specific_type, {}) def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: - current_type = self.type_context[-1] - inner_type = self.get_sequence_type(current_type) - if inner_type is None: - if is_subtype(self.chk.named_type("typing.Iterable"), current_type): - # Current type is more general, but the actual value could still be iterable - inner_type = self.chk.named_type("builtins.object") - else: + # + # check for existence of a starred pattern + # + current_type = get_proper_type(self.type_context[-1]) + can_match = True + star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)] + star_position = None # type: Optional[int] + if len(star_positions) == 1: + star_position = star_positions[0] + elif len(star_positions) >= 2: + assert False, "Parser should prevent multiple starred patterns" + required_patterns = len(o.patterns) + if star_position is not None: + required_patterns -= 1 + + # + # get inner types of original type + # + if isinstance(current_type, TupleType): + inner_types = current_type.items + size_diff = len(inner_types) - required_patterns + if size_diff < 0: + return early_non_match() + elif size_diff > 0 and star_position is None: return early_non_match() + else: + inner_type = self.get_sequence_type(current_type) + if inner_type is None: + if is_subtype(self.chk.named_type("typing.Iterable"), current_type): + # Current type is more general, but the actual value could still be iterable + inner_type = self.chk.named_type("builtins.object") + else: + return early_non_match() + inner_types = [inner_type] * len(o.patterns) - new_inner_type = UninhabitedType() # type: Type + # + # match inner patterns + # + contracted_new_inner_types = [] # type: List[Type] captures = {} # type: Dict[Expression, Type] - can_match = True - for p in o.patterns: - pattern_type = self.accept(p, inner_type) + + contracted_inner_types = self.contract_starred_pattern_types(inner_types, star_position, required_patterns) + for p, t in zip(o.patterns, contracted_inner_types): + pattern_type = self.accept(p, t) typ, type_map = pattern_type if typ is None: can_match = False else: - new_inner_type = join_types(new_inner_type, typ) + contracted_new_inner_types.append(typ) self.update_type_map(captures, type_map) + new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, star_position, len(inner_types)) - new_type = None # type: Optional[Type] - if can_match: + # + # Calculate new type + # + if not can_match: + new_type = None # type: Optional[Type] + elif isinstance(current_type, TupleType): + specific_inner_types = [] + for inner_type, new_inner_type in zip(inner_types, new_inner_types): + specific_inner_types.append(get_more_specific_type(inner_type, new_inner_type)) + if all(typ is not None for typ in specific_inner_types): + specific_inner_types_cast = cast(List[Type], specific_inner_types) + new_type = TupleType(specific_inner_types_cast, current_type.partial_fallback) + else: + new_type = None + else: + new_inner_type = UninhabitedType() + for typ in new_inner_types: + new_inner_type = join_types(new_inner_type, typ) new_type = self.construct_iterable_child(current_type, new_inner_type) if not is_subtype(new_type, current_type): new_type = current_type @@ -202,6 +249,27 @@ def get_sequence_type(self, t: Type) -> Optional[Type]: else: return None + def contract_starred_pattern_types(self, types: List[Type], star_pos: Optional[int], num_patterns: int) -> List[Type]: + if star_pos is None: + return types + new_types = types[:star_pos] + star_length = len(types) - num_patterns + new_types.append(make_simplified_union(types[star_pos:star_pos+star_length])) + new_types += types[star_pos+star_length:] + + return new_types + + def expand_starred_pattern_types(self, types: List[Type], star_pos: Optional[int], num_types: int) -> List[Type]: + if star_pos is None: + return types + new_types = types[:star_pos] + star_length = num_types - len(types) + 1 + new_types += [types[star_pos]] * star_length + new_types += types[star_pos+1:] + + return new_types + + def visit_starred_pattern(self, o: StarredPattern) -> PatternType: captures = {} # type: Dict[Expression, Type] if o.capture is not None: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 212822dfb120..9ee77ea07056 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -149,6 +149,84 @@ match m: reveal_type(m) # N: Revealed type is "builtins.list[builtins.object]" [builtins fixtures/list.pyi] +[case testSequencePatternCapturesTuple] +from typing import Tuple +m: Tuple[int, str, bool] + +match m: + case [a, b, c]: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(c) # N: Revealed type is "builtins.bool" + reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleToLong] +from typing import Tuple +m: Tuple[int, str] + +match m: + case [a, b, c]: + reveal_type(a) + reveal_type(b) + reveal_type(c) +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleToShort] +from typing import Tuple +m: Tuple[int, str, bool] + +match m: + case [a, b]: + reveal_type(a) + reveal_type(b) +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleNarrows] +from typing import Tuple +m: Tuple[object, object] + +match m: + case [1, "str"]: + reveal_type(m) # N: Revealed type is "Tuple[Literal[1]?, Literal['str']?]" +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleStarred] +from typing import Tuple +m: Tuple[int, str, bool] + +match m: + case [a, *b, c]: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.list[builtins.str]" + reveal_type(c) # N: Revealed type is "builtins.bool" + reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleStarredUnion] +from typing import Tuple +m: Tuple[int, str, float, bool] + +match m: + case [a, *b, c]: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.list[Union[builtins.str, builtins.float]]" + reveal_type(c) # N: Revealed type is "builtins.bool" + reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.float, builtins.bool]" +[builtins fixtures/list.pyi] + + +[case testSequencePatternTupleStarredTooShort] +from typing import Tuple +m: Tuple[int] +reveal_type(m) # N: Revealed type is "Tuple[builtins.int]" + +match m: + case [a, *b, c]: + reveal_type(a) + reveal_type(b) + reveal_type(c) +[builtins fixtures/list.pyi] -- Mapping Pattern -- [case testMappingPatternCaptures] From 32a7b71e9a2c708705b27193413728896454337f Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 3 May 2021 12:46:40 +0200 Subject: [PATCH 35/76] Fix linter errors --- mypy/checkpattern.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 2179eecfb8e8..05fbe3491333 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -205,7 +205,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: contracted_new_inner_types = [] # type: List[Type] captures = {} # type: Dict[Expression, Type] - contracted_inner_types = self.contract_starred_pattern_types(inner_types, star_position, required_patterns) + contracted_inner_types = self.contract_starred_pattern_types(inner_types, + star_position, + required_patterns) for p, t in zip(o.patterns, contracted_inner_types): pattern_type = self.accept(p, t) typ, type_map = pattern_type @@ -214,7 +216,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: else: contracted_new_inner_types.append(typ) self.update_type_map(captures, type_map) - new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, star_position, len(inner_types)) + new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, + star_position, + len(inner_types)) # # Calculate new type @@ -249,7 +253,11 @@ def get_sequence_type(self, t: Type) -> Optional[Type]: else: return None - def contract_starred_pattern_types(self, types: List[Type], star_pos: Optional[int], num_patterns: int) -> List[Type]: + def contract_starred_pattern_types(self, + types: List[Type], + star_pos: Optional[int], + num_patterns: int + ) -> List[Type]: if star_pos is None: return types new_types = types[:star_pos] @@ -259,7 +267,11 @@ def contract_starred_pattern_types(self, types: List[Type], star_pos: Optional[i return new_types - def expand_starred_pattern_types(self, types: List[Type], star_pos: Optional[int], num_types: int) -> List[Type]: + def expand_starred_pattern_types(self, + types: List[Type], + star_pos: Optional[int], + num_types: int + ) -> List[Type]: if star_pos is None: return types new_types = types[:star_pos] @@ -269,7 +281,6 @@ def expand_starred_pattern_types(self, types: List[Type], star_pos: Optional[int return new_types - def visit_starred_pattern(self, o: StarredPattern) -> PatternType: captures = {} # type: Dict[Expression, Type] if o.capture is not None: From 25fe496c9f114dc0b26ce16fe895adada89eb9ff Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 10 May 2021 13:02:51 +0200 Subject: [PATCH 36/76] Sync typeshed Source commit: https://github.com/python/typeshed/commit/a2058829faa338a9d398fccdcd969b7a9e19f282 --- mypy/typeshed/stdlib/@python2/__builtin__.pyi | 3 +- mypy/typeshed/stdlib/@python2/builtins.pyi | 3 +- mypy/typeshed/stdlib/@python2/typing.pyi | 10 +- mypy/typeshed/stdlib/VERSIONS | 184 ++++----- mypy/typeshed/stdlib/__main__.pyi | 3 + mypy/typeshed/stdlib/_ast.pyi | 34 ++ mypy/typeshed/stdlib/_bisect.pyi | 37 +- mypy/typeshed/stdlib/_bootlocale.pyi | 5 +- mypy/typeshed/stdlib/_collections_abc.pyi | 1 - mypy/typeshed/stdlib/_curses.pyi | 22 +- .../typeshed/stdlib/_importlib_modulespec.pyi | 50 --- mypy/typeshed/stdlib/argparse.pyi | 2 +- mypy/typeshed/stdlib/array.pyi | 5 +- mypy/typeshed/stdlib/asyncio/base_events.pyi | 18 +- mypy/typeshed/stdlib/asyncio/events.pyi | 25 +- mypy/typeshed/stdlib/asyncio/subprocess.pyi | 139 +++++-- mypy/typeshed/stdlib/asyncio/tasks.pyi | 353 ++++++++++++------ mypy/typeshed/stdlib/base64.pyi | 4 + mypy/typeshed/stdlib/builtins.pyi | 51 ++- mypy/typeshed/stdlib/calendar.pyi | 1 + mypy/typeshed/stdlib/contextlib.pyi | 17 +- mypy/typeshed/stdlib/contextvars.pyi | 6 +- mypy/typeshed/stdlib/datetime.pyi | 18 +- mypy/typeshed/stdlib/email/errors.pyi | 18 + mypy/typeshed/stdlib/enum.pyi | 20 +- mypy/typeshed/stdlib/formatter.pyi | 198 +++++----- mypy/typeshed/stdlib/ftplib.pyi | 2 +- mypy/typeshed/stdlib/functools.pyi | 15 +- mypy/typeshed/stdlib/genericpath.pyi | 9 +- mypy/typeshed/stdlib/glob.pyi | 19 +- mypy/typeshed/stdlib/html/parser.pyi | 10 + mypy/typeshed/stdlib/http/client.pyi | 39 +- mypy/typeshed/stdlib/http/cookiejar.pyi | 48 ++- mypy/typeshed/stdlib/importlib/abc.pyi | 22 +- mypy/typeshed/stdlib/importlib/machinery.pyi | 26 +- mypy/typeshed/stdlib/inspect.pyi | 44 ++- mypy/typeshed/stdlib/itertools.pyi | 13 +- mypy/typeshed/stdlib/locale.pyi | 2 +- mypy/typeshed/stdlib/lzma.pyi | 4 +- mypy/typeshed/stdlib/mmap.pyi | 53 ++- .../stdlib/multiprocessing/__init__.pyi | 21 +- .../stdlib/multiprocessing/context.pyi | 71 +++- mypy/typeshed/stdlib/multiprocessing/pool.pyi | 6 + .../stdlib/multiprocessing/sharedctypes.pyi | 124 ++++-- mypy/typeshed/stdlib/nntplib.pyi | 2 +- mypy/typeshed/stdlib/ntpath.pyi | 28 +- mypy/typeshed/stdlib/os/__init__.pyi | 24 +- mypy/typeshed/stdlib/os/path.pyi | 58 ++- mypy/typeshed/stdlib/parser.pyi | 36 +- mypy/typeshed/stdlib/platform.pyi | 2 +- mypy/typeshed/stdlib/posix.pyi | 43 ++- mypy/typeshed/stdlib/posixpath.pyi | 18 +- mypy/typeshed/stdlib/pprint.pyi | 58 ++- mypy/typeshed/stdlib/re.pyi | 4 +- mypy/typeshed/stdlib/readline.pyi | 35 +- mypy/typeshed/stdlib/selectors.pyi | 2 + mypy/typeshed/stdlib/signal.pyi | 7 +- mypy/typeshed/stdlib/site.pyi | 5 +- mypy/typeshed/stdlib/smtplib.pyi | 4 +- mypy/typeshed/stdlib/sqlite3/dbapi2.pyi | 26 +- mypy/typeshed/stdlib/sre_constants.pyi | 6 +- mypy/typeshed/stdlib/symbol.pyi | 176 ++++----- mypy/typeshed/stdlib/sys.pyi | 14 +- mypy/typeshed/stdlib/tarfile.pyi | 4 +- mypy/typeshed/stdlib/termios.pyi | 12 +- mypy/typeshed/stdlib/time.pyi | 3 +- mypy/typeshed/stdlib/tkinter/__init__.pyi | 4 +- mypy/typeshed/stdlib/tkinter/font.pyi | 9 +- mypy/typeshed/stdlib/traceback.pyi | 43 ++- mypy/typeshed/stdlib/types.pyi | 15 +- mypy/typeshed/stdlib/typing.pyi | 34 +- mypy/typeshed/stdlib/typing_extensions.pyi | 33 +- mypy/typeshed/stdlib/urllib/request.pyi | 11 +- mypy/typeshed/stdlib/urllib/robotparser.pyi | 2 +- mypy/typeshed/stdlib/webbrowser.pyi | 4 +- mypy/typeshed/stdlib/xml/dom/minicompat.pyi | 18 +- mypy/typeshed/stdlib/xml/dom/minidom.pyi | 312 +++++++++++++++- mypy/typeshed/stdlib/xml/dom/xmlbuilder.pyi | 3 + mypy/typeshed/stdlib/xxlimited.pyi | 13 +- test-data/unit/lib-stub/array.pyi.bak | 15 + 80 files changed, 1955 insertions(+), 883 deletions(-) create mode 100644 mypy/typeshed/stdlib/__main__.pyi delete mode 100644 mypy/typeshed/stdlib/_importlib_modulespec.pyi create mode 100644 test-data/unit/lib-stub/array.pyi.bak diff --git a/mypy/typeshed/stdlib/@python2/__builtin__.pyi b/mypy/typeshed/stdlib/@python2/__builtin__.pyi index ed42c1e8f380..4eb5424d8da7 100644 --- a/mypy/typeshed/stdlib/@python2/__builtin__.pyi +++ b/mypy/typeshed/stdlib/@python2/__builtin__.pyi @@ -44,7 +44,7 @@ from typing import ( ValuesView, overload, ) -from typing_extensions import Literal +from typing_extensions import Literal, final class _SupportsIndex(Protocol): def __index__(self) -> int: ... @@ -568,6 +568,7 @@ class memoryview(Sized, Container[str]): def tobytes(self) -> bytes: ... def tolist(self) -> List[int]: ... +@final class bool(int): def __new__(cls: Type[_T], __o: object = ...) -> _T: ... @overload diff --git a/mypy/typeshed/stdlib/@python2/builtins.pyi b/mypy/typeshed/stdlib/@python2/builtins.pyi index ed42c1e8f380..4eb5424d8da7 100644 --- a/mypy/typeshed/stdlib/@python2/builtins.pyi +++ b/mypy/typeshed/stdlib/@python2/builtins.pyi @@ -44,7 +44,7 @@ from typing import ( ValuesView, overload, ) -from typing_extensions import Literal +from typing_extensions import Literal, final class _SupportsIndex(Protocol): def __index__(self) -> int: ... @@ -568,6 +568,7 @@ class memoryview(Sized, Container[str]): def tobytes(self) -> bytes: ... def tolist(self) -> List[int]: ... +@final class bool(int): def __new__(cls: Type[_T], __o: object = ...) -> _T: ... @overload diff --git a/mypy/typeshed/stdlib/@python2/typing.pyi b/mypy/typeshed/stdlib/@python2/typing.pyi index e134d17415b0..dfcac95197b0 100644 --- a/mypy/typeshed/stdlib/@python2/typing.pyi +++ b/mypy/typeshed/stdlib/@python2/typing.pyi @@ -5,7 +5,6 @@ from types import CodeType, FrameType, TracebackType # Definitions of special type checking related constructs. Their definitions # are not used, so their value does not matter. -overload = object() Any = object() class TypeVar: @@ -40,6 +39,7 @@ Final: _SpecialForm = ... _F = TypeVar("_F", bound=Callable[..., Any]) def final(f: _F) -> _F: ... +def overload(f: _F) -> _F: ... Literal: _SpecialForm = ... # TypedDict is a (non-subscriptable) special form. @@ -63,11 +63,9 @@ _KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers. _VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers. _T_contra = TypeVar("_T_contra", contravariant=True) # Ditto contravariant. _TC = TypeVar("_TC", bound=Type[object]) -_C = TypeVar("_C", bound=Callable[..., Any]) -no_type_check = object() - -def no_type_check_decorator(decorator: _C) -> _C: ... +def no_type_check(f: _F) -> _F: ... +def no_type_check_decorator(decorator: _F) -> _F: ... # Type aliases and type constructors @@ -492,4 +490,4 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): def NewType(name: str, tp: Type[_T]) -> Type[_T]: ... # This itself is only available during type checking -def type_check_only(func_or_cls: _C) -> _C: ... +def type_check_only(func_or_cls: _F) -> _F: ... diff --git a/mypy/typeshed/stdlib/VERSIONS b/mypy/typeshed/stdlib/VERSIONS index 872219720ffc..44557650e7bc 100644 --- a/mypy/typeshed/stdlib/VERSIONS +++ b/mypy/typeshed/stdlib/VERSIONS @@ -1,5 +1,6 @@ __future__: 2.7 -_ast: 3.6 +__main__: 2.7 +_ast: 2.7 _bisect: 2.7 _bootlocale: 3.6 _codecs: 2.7 @@ -13,48 +14,47 @@ _dummy_thread: 3.6 _dummy_threading: 2.7 _heapq: 2.7 _imp: 3.6 -_importlib_modulespec: 3.6 -_json: 3.6 -_markupbase: 3.6 +_json: 2.7 +_markupbase: 2.7 _msi: 2.7 _operator: 3.6 -_osx_support: 3.6 +_osx_support: 2.7 _posixsubprocess: 3.6 _py_abc: 3.7 _pydecimal: 3.6 _random: 2.7 _sitebuiltins: 3.6 _stat: 3.6 -_thread: 3.6 +_thread: 2.7 _threading_local: 3.6 -_tkinter: 3.6 +_tkinter: 2.7 _tracemalloc: 3.6 _typeshed: 2.7 _warnings: 2.7 _weakref: 2.7 _weakrefset: 2.7 _winapi: 3.6 -abc: 3.6 +abc: 2.7 aifc: 2.7 antigravity: 2.7 argparse: 2.7 array: 2.7 -ast: 3.6 +ast: 2.7 asynchat: 2.7 -asyncio: 3.6 +asyncio: 3.4 asyncore: 2.7 -atexit: 3.6 +atexit: 2.7 audioop: 2.7 base64: 2.7 bdb: 2.7 binascii: 2.7 binhex: 2.7 bisect: 2.7 -builtins: 3.6 +builtins: 3.0 bz2: 2.7 cProfile: 2.7 calendar: 2.7 -cgi: 3.6 +cgi: 2.7 cgitb: 2.7 chunk: 2.7 cmath: 2.7 @@ -62,73 +62,73 @@ cmd: 2.7 code: 2.7 codecs: 2.7 codeop: 2.7 -collections: 3.6 +collections: 2.7 colorsys: 2.7 -compileall: 3.6 -concurrent: 3.6 -configparser: 3.6 +compileall: 2.7 +concurrent: 3.2 +configparser: 2.7 contextlib: 2.7 contextvars: 3.7 copy: 2.7 -copyreg: 3.6 +copyreg: 2.7 crypt: 2.7 csv: 2.7 ctypes: 2.7 curses: 2.7 dataclasses: 3.7 datetime: 2.7 -dbm: 3.6 +dbm: 2.7 decimal: 2.7 difflib: 2.7 dis: 2.7 -distutils: 3.6 +distutils: 2.7 doctest: 2.7 dummy_threading: 2.7 -email: 3.6 -encodings: 3.6 +email: 2.7 +encodings: 2.7 ensurepip: 2.7 -enum: 3.6 +enum: 3.4 errno: 2.7 -faulthandler: 3.6 -fcntl: 3.6 +faulthandler: 3.3 +fcntl: 2.7 filecmp: 2.7 fileinput: 2.7 -fnmatch: 3.6 +fnmatch: 2.7 formatter: 2.7 fractions: 2.7 ftplib: 2.7 -functools: 3.6 -gc: 3.6 +functools: 2.7 +gc: 2.7 genericpath: 2.7 -getopt: 3.6 -getpass: 3.6 -gettext: 3.6 -glob: 3.6 +getopt: 2.7 +getpass: 2.7 +gettext: 2.7 +glob: 2.7 graphlib: 3.9 grp: 2.7 -gzip: 3.6 -hashlib: 3.6 -heapq: 3.6 +gzip: 2.7 +hashlib: 2.7 +heapq: 2.7 hmac: 2.7 -html: 3.6 -http: 3.6 +html: 2.7 +http: 3.0 imaplib: 2.7 imghdr: 2.7 -imp: 3.6 -importlib: 3.6 -inspect: 3.6 -io: 3.6 -ipaddress: 3.6 -itertools: 3.6 -json: 3.6 +imp: 2.7 +importlib: 2.7 +inspect: 2.7 +io: 2.7 +ipaddress: 2.7 +itertools: 2.7 +json: 2.7 keyword: 2.7 lib2to3: 2.7 linecache: 2.7 locale: 2.7 -logging: 3.6 -lzma: 3.6 +logging: 2.7 +lzma: 3.3 macpath: 2.7 -macurl2path: 3.6 +macurl2path: 2.7 mailbox: 2.7 mailcap: 2.7 marshal: 2.7 @@ -138,30 +138,30 @@ mmap: 2.7 modulefinder: 2.7 msilib: 2.7 msvcrt: 2.7 -multiprocessing: 3.6 +multiprocessing: 2.7 netrc: 2.7 nis: 2.7 -nntplib: 3.6 -ntpath: 3.6 -nturl2path: 3.6 +nntplib: 2.7 +ntpath: 2.7 +nturl2path: 2.7 numbers: 2.7 opcode: 2.7 operator: 2.7 optparse: 2.7 -os: 3.6 -ossaudiodev: 3.6 +os: 2.7 +ossaudiodev: 2.7 parser: 2.7 -pathlib: 3.6 +pathlib: 3.4 pdb: 2.7 pickle: 2.7 pickletools: 2.7 -pipes: 3.6 +pipes: 2.7 pkgutil: 2.7 -platform: 3.6 +platform: 2.7 plistlib: 2.7 poplib: 2.7 -posix: 3.6 -posixpath: 3.6 +posix: 2.7 +posixpath: 2.7 pprint: 2.7 profile: 2.7 pstats: 2.7 @@ -172,86 +172,86 @@ pyclbr: 2.7 pydoc: 2.7 pydoc_data: 2.7 pyexpat: 2.7 -queue: 3.6 +queue: 2.7 quopri: 2.7 -random: 3.6 -re: 3.6 +random: 2.7 +re: 2.7 readline: 2.7 -reprlib: 3.6 -resource: 3.6 +reprlib: 2.7 +resource: 2.7 rlcompleter: 2.7 -runpy: 3.6 +runpy: 2.7 sched: 2.7 secrets: 3.6 select: 2.7 -selectors: 3.6 -shelve: 3.6 -shlex: 3.6 +selectors: 3.4 +shelve: 2.7 +shlex: 2.7 shutil: 2.7 -signal: 3.6 +signal: 2.7 site: 2.7 smtpd: 2.7 -smtplib: 3.6 +smtplib: 2.7 sndhdr: 2.7 socket: 2.7 -socketserver: 3.6 -spwd: 3.6 +socketserver: 2.7 +spwd: 2.7 sqlite3: 2.7 sre_compile: 2.7 -sre_constants: 3.6 -sre_parse: 3.6 +sre_constants: 2.7 +sre_parse: 2.7 ssl: 2.7 -stat: 3.6 -statistics: 3.6 -string: 3.6 +stat: 2.7 +statistics: 3.4 +string: 2.7 stringprep: 2.7 struct: 2.7 -subprocess: 3.6 +subprocess: 2.7 sunau: 2.7 -symbol: 3.6 +symbol: 2.7 symtable: 2.7 -sys: 3.6 +sys: 2.7 sysconfig: 2.7 syslog: 2.7 tabnanny: 2.7 tarfile: 2.7 telnetlib: 2.7 -tempfile: 3.6 +tempfile: 2.7 termios: 2.7 -textwrap: 3.6 +textwrap: 2.7 this: 2.7 threading: 2.7 time: 2.7 timeit: 2.7 -tkinter: 3.6 +tkinter: 3.0 token: 2.7 -tokenize: 3.6 +tokenize: 2.7 trace: 2.7 traceback: 2.7 -tracemalloc: 3.6 +tracemalloc: 3.4 tty: 2.7 turtle: 2.7 -types: 3.6 -typing: 3.6 +types: 2.7 +typing: 3.5 typing_extensions: 2.7 unicodedata: 2.7 -unittest: 3.6 -urllib: 3.6 +unittest: 2.7 +urllib: 2.7 uu: 2.7 uuid: 2.7 -venv: 3.6 +venv: 3.3 warnings: 2.7 wave: 2.7 weakref: 2.7 webbrowser: 2.7 -winreg: 3.6 +winreg: 2.7 winsound: 2.7 wsgiref: 2.7 xdrlib: 2.7 xml: 2.7 -xmlrpc: 3.6 +xmlrpc: 3.0 xxlimited: 3.6 -zipapp: 3.6 +zipapp: 3.5 zipfile: 2.7 zipimport: 2.7 zlib: 2.7 diff --git a/mypy/typeshed/stdlib/__main__.pyi b/mypy/typeshed/stdlib/__main__.pyi new file mode 100644 index 000000000000..e27843e53382 --- /dev/null +++ b/mypy/typeshed/stdlib/__main__.pyi @@ -0,0 +1,3 @@ +from typing import Any + +def __getattr__(name: str) -> Any: ... diff --git a/mypy/typeshed/stdlib/_ast.pyi b/mypy/typeshed/stdlib/_ast.pyi index fd0bc107dfaf..2d0d92d83032 100644 --- a/mypy/typeshed/stdlib/_ast.pyi +++ b/mypy/typeshed/stdlib/_ast.pyi @@ -375,3 +375,37 @@ class alias(AST): class withitem(AST): context_expr: expr optional_vars: Optional[expr] + +if sys.version_info >= (3, 10): + class Match(stmt): + subject: expr + cases: typing.List[match_case] + class pattern(AST): ... + # Without the alias, Pyright complains variables named pattern are recursively defined + _pattern = pattern + class match_case(AST): + pattern: _pattern + guard: Optional[expr] + body: typing.List[stmt] + class MatchValue(pattern): + value: expr + class MatchSingleton(pattern): + value: Optional[bool] + class MatchSequence(pattern): + patterns: typing.List[pattern] + class MatchStar(pattern): + name: Optional[_identifier] + class MatchMapping(pattern): + keys: typing.List[expr] + patterns: typing.List[pattern] + rest: Optional[_identifier] + class MatchClass(pattern): + cls: expr + patterns: typing.List[pattern] + kwd_attrs: typing.List[_identifier] + kwd_patterns: typing.List[pattern] + class MatchAs(pattern): + pattern: Optional[_pattern] + name: Optional[_identifier] + class MatchOr(pattern): + patterns: typing.List[pattern] diff --git a/mypy/typeshed/stdlib/_bisect.pyi b/mypy/typeshed/stdlib/_bisect.pyi index 1e909c2a77d3..3ca863a2e939 100644 --- a/mypy/typeshed/stdlib/_bisect.pyi +++ b/mypy/typeshed/stdlib/_bisect.pyi @@ -1,8 +1,35 @@ -from typing import MutableSequence, Optional, Sequence, TypeVar +import sys +from _typeshed import SupportsLessThan +from typing import Callable, MutableSequence, Optional, Sequence, TypeVar _T = TypeVar("_T") -def bisect_left(a: Sequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> int: ... -def bisect_right(a: Sequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> int: ... -def insort_left(a: MutableSequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> None: ... -def insort_right(a: MutableSequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> None: ... +if sys.version_info >= (3, 10): + def bisect_left( + a: Sequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ..., *, key: Optional[Callable[[_T], SupportsLessThan]] = ... + ) -> int: ... + def bisect_right( + a: Sequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ..., *, key: Optional[Callable[[_T], SupportsLessThan]] = ... + ) -> int: ... + def insort_left( + a: MutableSequence[_T], + x: _T, + lo: int = ..., + hi: Optional[int] = ..., + *, + key: Optional[Callable[[_T], SupportsLessThan]] = ..., + ) -> None: ... + def insort_right( + a: MutableSequence[_T], + x: _T, + lo: int = ..., + hi: Optional[int] = ..., + *, + key: Optional[Callable[[_T], SupportsLessThan]] = ..., + ) -> None: ... + +else: + def bisect_left(a: Sequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> int: ... + def bisect_right(a: Sequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> int: ... + def insort_left(a: MutableSequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> None: ... + def insort_right(a: MutableSequence[_T], x: _T, lo: int = ..., hi: Optional[int] = ...) -> None: ... diff --git a/mypy/typeshed/stdlib/_bootlocale.pyi b/mypy/typeshed/stdlib/_bootlocale.pyi index ee2d89347a9f..73e7b6b546bd 100644 --- a/mypy/typeshed/stdlib/_bootlocale.pyi +++ b/mypy/typeshed/stdlib/_bootlocale.pyi @@ -1 +1,4 @@ -def getpreferredencoding(do_setlocale: bool = ...) -> str: ... +import sys + +if sys.version_info < (3, 10): + def getpreferredencoding(do_setlocale: bool = ...) -> str: ... diff --git a/mypy/typeshed/stdlib/_collections_abc.pyi b/mypy/typeshed/stdlib/_collections_abc.pyi index 357c1f91a735..27d5234432f3 100644 --- a/mypy/typeshed/stdlib/_collections_abc.pyi +++ b/mypy/typeshed/stdlib/_collections_abc.pyi @@ -26,7 +26,6 @@ from typing import ( ValuesView as ValuesView, ) -# Without the real definition, mypy and pytest both think that __all__ is empty, so re-exports nothing __all__ = [ "Awaitable", "Coroutine", diff --git a/mypy/typeshed/stdlib/_curses.pyi b/mypy/typeshed/stdlib/_curses.pyi index cf11bb40d2ee..1ccd54e35edd 100644 --- a/mypy/typeshed/stdlib/_curses.pyi +++ b/mypy/typeshed/stdlib/_curses.pyi @@ -378,11 +378,11 @@ class _CursesWindow: def addstr(self, str: str, attr: int = ...) -> None: ... @overload def addstr(self, y: int, x: int, str: str, attr: int = ...) -> None: ... - def attroff(self, attr: int) -> None: ... - def attron(self, attr: int) -> None: ... - def attrset(self, attr: int) -> None: ... - def bkgd(self, ch: _chtype, attr: int = ...) -> None: ... - def bkgdset(self, ch: _chtype, attr: int = ...) -> None: ... + def attroff(self, __attr: int) -> None: ... + def attron(self, __attr: int) -> None: ... + def attrset(self, __attr: int) -> None: ... + def bkgd(self, __ch: _chtype, __attr: int = ...) -> None: ... + def bkgdset(self, __ch: _chtype, __attr: int = ...) -> None: ... def border( self, ls: _chtype = ..., @@ -420,8 +420,8 @@ class _CursesWindow: def derwin(self, begin_y: int, begin_x: int) -> _CursesWindow: ... @overload def derwin(self, nlines: int, ncols: int, begin_y: int, begin_x: int) -> _CursesWindow: ... - def echochar(self, ch: _chtype, attr: int = ...) -> None: ... - def enclose(self, y: int, x: int) -> bool: ... + def echochar(self, __ch: _chtype, __attr: int = ...) -> None: ... + def enclose(self, __y: int, __x: int) -> bool: ... def erase(self) -> None: ... def getbegyx(self) -> Tuple[int, int]: ... def getbkgd(self) -> Tuple[int, int]: ... @@ -478,7 +478,7 @@ class _CursesWindow: def instr(self, n: int = ...) -> _chtype: ... @overload def instr(self, y: int, x: int, n: int = ...) -> _chtype: ... - def is_linetouched(self, line: int) -> bool: ... + def is_linetouched(self, __line: int) -> bool: ... def is_wintouched(self) -> bool: ... def keypad(self, yes: bool) -> None: ... def leaveok(self, yes: bool) -> None: ... @@ -500,8 +500,8 @@ class _CursesWindow: def overwrite( self, destwin: _CursesWindow, sminrow: int, smincol: int, dminrow: int, dmincol: int, dmaxrow: int, dmaxcol: int ) -> None: ... - def putwin(self, file: IO[Any]) -> None: ... - def redrawln(self, beg: int, num: int) -> None: ... + def putwin(self, __file: IO[Any]) -> None: ... + def redrawln(self, __beg: int, __num: int) -> None: ... def redrawwin(self) -> None: ... @overload def refresh(self) -> None: ... @@ -510,7 +510,7 @@ class _CursesWindow: def resize(self, nlines: int, ncols: int) -> None: ... def scroll(self, lines: int = ...) -> None: ... def scrollok(self, flag: bool) -> None: ... - def setscrreg(self, top: int, bottom: int) -> None: ... + def setscrreg(self, __top: int, __bottom: int) -> None: ... def standend(self) -> None: ... def standout(self) -> None: ... @overload diff --git a/mypy/typeshed/stdlib/_importlib_modulespec.pyi b/mypy/typeshed/stdlib/_importlib_modulespec.pyi deleted file mode 100644 index 114b78e1061d..000000000000 --- a/mypy/typeshed/stdlib/_importlib_modulespec.pyi +++ /dev/null @@ -1,50 +0,0 @@ -# ModuleSpec, ModuleType, Loader are part of a dependency cycle. -# They are officially defined/exported in other places: -# -# - ModuleType in types -# - Loader in importlib.abc -# - ModuleSpec in importlib.machinery (3.4 and later only) -# -# _Loader is the PEP-451-defined interface for a loader type/object. - -from abc import ABCMeta -from typing import Any, Dict, List, Optional, Protocol - -class _Loader(Protocol): - def load_module(self, fullname: str) -> ModuleType: ... - -class ModuleSpec: - def __init__( - self, - name: str, - loader: Optional[Loader], - *, - origin: Optional[str] = ..., - loader_state: Any = ..., - is_package: Optional[bool] = ..., - ) -> None: ... - name: str - loader: Optional[_Loader] - origin: Optional[str] - submodule_search_locations: Optional[List[str]] - loader_state: Any - cached: Optional[str] - parent: Optional[str] - has_location: bool - -class ModuleType: - __name__: str - __file__: str - __dict__: Dict[str, Any] - __loader__: Optional[_Loader] - __package__: Optional[str] - __spec__: Optional[ModuleSpec] - def __init__(self, name: str, doc: Optional[str] = ...) -> None: ... - -class Loader(metaclass=ABCMeta): - def load_module(self, fullname: str) -> ModuleType: ... - def module_repr(self, module: ModuleType) -> str: ... - def create_module(self, spec: ModuleSpec) -> Optional[ModuleType]: ... - # Not defined on the actual class for backwards-compatibility reasons, - # but expected in new code. - def exec_module(self, module: ModuleType) -> None: ... diff --git a/mypy/typeshed/stdlib/argparse.pyi b/mypy/typeshed/stdlib/argparse.pyi index 3dd6e56175dd..9dceaabd4631 100644 --- a/mypy/typeshed/stdlib/argparse.pyi +++ b/mypy/typeshed/stdlib/argparse.pyi @@ -263,7 +263,7 @@ class HelpFormatter: def end_section(self) -> None: ... def add_text(self, text: Optional[Text]) -> None: ... def add_usage( - self, usage: Text, actions: Iterable[Action], groups: Iterable[_ArgumentGroup], prefix: Optional[Text] = ... + self, usage: Optional[Text], actions: Iterable[Action], groups: Iterable[_ArgumentGroup], prefix: Optional[Text] = ... ) -> None: ... def add_argument(self, action: Action) -> None: ... def add_arguments(self, actions: Iterable[Action]) -> None: ... diff --git a/mypy/typeshed/stdlib/array.pyi b/mypy/typeshed/stdlib/array.pyi index c7e1ef0bb8fa..498bf92919f1 100644 --- a/mypy/typeshed/stdlib/array.pyi +++ b/mypy/typeshed/stdlib/array.pyi @@ -33,7 +33,10 @@ class array(MutableSequence[_T], Generic[_T]): def fromfile(self, __f: BinaryIO, __n: int) -> None: ... def fromlist(self, __list: List[_T]) -> None: ... def fromunicode(self, __ustr: str) -> None: ... - def index(self, __v: _T) -> int: ... # type: ignore # Overrides Sequence + if sys.version_info >= (3, 10): + def index(self, __v: _T, __start: int = ..., __stop: int = ...) -> int: ... + else: + def index(self, __v: _T) -> int: ... # type: ignore # Overrides Sequence def insert(self, __i: int, __v: _T) -> None: ... def pop(self, __i: int = ...) -> _T: ... if sys.version_info < (3,): diff --git a/mypy/typeshed/stdlib/asyncio/base_events.pyi b/mypy/typeshed/stdlib/asyncio/base_events.pyi index 90235741f40c..c2ad4a609c01 100644 --- a/mypy/typeshed/stdlib/asyncio/base_events.pyi +++ b/mypy/typeshed/stdlib/asyncio/base_events.pyi @@ -315,9 +315,9 @@ class BaseEventLoop(AbstractEventLoop, metaclass=ABCMeta): protocol_factory: _ProtocolFactory, cmd: Union[bytes, str], *, - stdin: Any = ..., - stdout: Any = ..., - stderr: Any = ..., + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., universal_newlines: Literal[False] = ..., shell: Literal[True] = ..., bufsize: Literal[0] = ..., @@ -329,10 +329,16 @@ class BaseEventLoop(AbstractEventLoop, metaclass=ABCMeta): async def subprocess_exec( self, protocol_factory: _ProtocolFactory, + program: Any, *args: Any, - stdin: Any = ..., - stdout: Any = ..., - stderr: Any = ..., + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., **kwargs: Any, ) -> _TransProtPair: ... def add_reader(self, fd: FileDescriptorLike, callback: Callable[..., Any], *args: Any) -> None: ... diff --git a/mypy/typeshed/stdlib/asyncio/events.pyi b/mypy/typeshed/stdlib/asyncio/events.pyi index 6c9717f873cd..9159af4eb20b 100644 --- a/mypy/typeshed/stdlib/asyncio/events.pyi +++ b/mypy/typeshed/stdlib/asyncio/events.pyi @@ -9,6 +9,7 @@ from asyncio.transports import BaseTransport from asyncio.unix_events import AbstractChildWatcher from socket import AddressFamily, SocketKind, _Address, _RetAddress, socket from typing import IO, Any, Awaitable, Callable, Dict, Generator, List, Optional, Sequence, Tuple, TypeVar, Union, overload +from typing_extensions import Literal if sys.version_info >= (3, 7): from contextvars import Context @@ -399,19 +400,31 @@ class AbstractEventLoop(metaclass=ABCMeta): protocol_factory: _ProtocolFactory, cmd: Union[bytes, str], *, - stdin: Any = ..., - stdout: Any = ..., - stderr: Any = ..., + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., + text: Literal[False, None] = ..., **kwargs: Any, ) -> _TransProtPair: ... @abstractmethod async def subprocess_exec( self, protocol_factory: _ProtocolFactory, + program: Any, *args: Any, - stdin: Any = ..., - stdout: Any = ..., - stderr: Any = ..., + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., **kwargs: Any, ) -> _TransProtPair: ... @abstractmethod diff --git a/mypy/typeshed/stdlib/asyncio/subprocess.pyi b/mypy/typeshed/stdlib/asyncio/subprocess.pyi index 58e1bd40d0eb..d443625db28a 100644 --- a/mypy/typeshed/stdlib/asyncio/subprocess.pyi +++ b/mypy/typeshed/stdlib/asyncio/subprocess.pyi @@ -1,13 +1,14 @@ +import subprocess import sys +from _typeshed import AnyPath from asyncio import events, protocols, streams, transports -from typing import IO, Any, Optional, Tuple, Union +from typing import IO, Any, Callable, Optional, Tuple, Union +from typing_extensions import Literal if sys.version_info >= (3, 8): - from os import PathLike - - _ExecArg = Union[str, bytes, PathLike[str], PathLike[bytes]] + _ExecArg = AnyPath else: - _ExecArg = Union[str, bytes] # Union used instead of AnyStr due to mypy issue #1236 + _ExecArg = Union[str, bytes] PIPE: int STDOUT: int @@ -39,22 +40,112 @@ class Process: def kill(self) -> None: ... async def communicate(self, input: Optional[bytes] = ...) -> Tuple[bytes, bytes]: ... -async def create_subprocess_shell( - cmd: Union[str, bytes], # Union used instead of AnyStr due to mypy issue #1236 - stdin: Union[int, IO[Any], None] = ..., - stdout: Union[int, IO[Any], None] = ..., - stderr: Union[int, IO[Any], None] = ..., - loop: Optional[events.AbstractEventLoop] = ..., - limit: int = ..., - **kwds: Any, -) -> Process: ... -async def create_subprocess_exec( - program: _ExecArg, - *args: _ExecArg, - stdin: Union[int, IO[Any], None] = ..., - stdout: Union[int, IO[Any], None] = ..., - stderr: Union[int, IO[Any], None] = ..., - loop: Optional[events.AbstractEventLoop] = ..., - limit: int = ..., - **kwds: Any, -) -> Process: ... +if sys.version_info >= (3, 10): + async def create_subprocess_shell( + cmd: Union[str, bytes], + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + limit: int = ..., + *, + # These parameters are forced to these values by BaseEventLoop.subprocess_shell + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., + text: Literal[False, None] = ..., + # These parameters are taken by subprocess.Popen, which this ultimately delegates to + executable: Optional[AnyPath] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + close_fds: bool = ..., + cwd: Optional[AnyPath] = ..., + env: Optional[subprocess._ENV] = ..., + startupinfo: Optional[Any] = ..., + creationflags: int = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Any = ..., + ) -> Process: ... + async def create_subprocess_exec( + program: _ExecArg, + *args: _ExecArg, + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + limit: int = ..., + # These parameters are forced to these values by BaseEventLoop.subprocess_shell + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., + # These parameters are taken by subprocess.Popen, which this ultimately delegates to + text: Optional[bool] = ..., + executable: Optional[AnyPath] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + close_fds: bool = ..., + cwd: Optional[AnyPath] = ..., + env: Optional[subprocess._ENV] = ..., + startupinfo: Optional[Any] = ..., + creationflags: int = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Any = ..., + ) -> Process: ... + +else: + async def create_subprocess_shell( + cmd: Union[str, bytes], + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + loop: Optional[events.AbstractEventLoop] = ..., + limit: int = ..., + *, + # These parameters are forced to these values by BaseEventLoop.subprocess_shell + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., + text: Literal[False, None] = ..., + # These parameters are taken by subprocess.Popen, which this ultimately delegates to + executable: Optional[AnyPath] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + close_fds: bool = ..., + cwd: Optional[AnyPath] = ..., + env: Optional[subprocess._ENV] = ..., + startupinfo: Optional[Any] = ..., + creationflags: int = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Any = ..., + ) -> Process: ... + async def create_subprocess_exec( + program: _ExecArg, + *args: _ExecArg, + stdin: Union[int, IO[Any], None] = ..., + stdout: Union[int, IO[Any], None] = ..., + stderr: Union[int, IO[Any], None] = ..., + loop: Optional[events.AbstractEventLoop] = ..., + limit: int = ..., + # These parameters are forced to these values by BaseEventLoop.subprocess_shell + universal_newlines: Literal[False] = ..., + shell: Literal[True] = ..., + bufsize: Literal[0] = ..., + encoding: None = ..., + errors: None = ..., + # These parameters are taken by subprocess.Popen, which this ultimately delegates to + text: Optional[bool] = ..., + executable: Optional[AnyPath] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + close_fds: bool = ..., + cwd: Optional[AnyPath] = ..., + env: Optional[subprocess._ENV] = ..., + startupinfo: Optional[Any] = ..., + creationflags: int = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Any = ..., + ) -> Process: ... diff --git a/mypy/typeshed/stdlib/asyncio/tasks.pyi b/mypy/typeshed/stdlib/asyncio/tasks.pyi index 65f61b887734..0ff0e3895252 100644 --- a/mypy/typeshed/stdlib/asyncio/tasks.pyi +++ b/mypy/typeshed/stdlib/asyncio/tasks.pyi @@ -25,9 +25,14 @@ FIRST_EXCEPTION: str FIRST_COMPLETED: str ALL_COMPLETED: str -def as_completed( - fs: Iterable[_FutureT[_T]], *, loop: Optional[AbstractEventLoop] = ..., timeout: Optional[float] = ... -) -> Iterator[Future[_T]]: ... +if sys.version_info >= (3, 10): + def as_completed(fs: Iterable[_FutureT[_T]], *, timeout: Optional[float] = ...) -> Iterator[Future[_T]]: ... + +else: + def as_completed( + fs: Iterable[_FutureT[_T]], *, loop: Optional[AbstractEventLoop] = ..., timeout: Optional[float] = ... + ) -> Iterator[Future[_T]]: ... + @overload def ensure_future(coro_or_future: _FT, *, loop: Optional[AbstractEventLoop] = ...) -> _FT: ... # type: ignore @overload @@ -40,126 +45,230 @@ def ensure_future(coro_or_future: Awaitable[_T], *, loop: Optional[AbstractEvent # of tasks passed; however, Tuple is used similar to the annotation for # zip() because typing does not support variadic type variables. See # typing PR #1550 for discussion. -@overload -def gather( - coro_or_future1: _FutureT[_T1], *, loop: Optional[AbstractEventLoop] = ..., return_exceptions: Literal[False] = ... -) -> Future[Tuple[_T1]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: Literal[False] = ..., -) -> Future[Tuple[_T1, _T2]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - coro_or_future3: _FutureT[_T3], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: Literal[False] = ..., -) -> Future[Tuple[_T1, _T2, _T3]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - coro_or_future3: _FutureT[_T3], - coro_or_future4: _FutureT[_T4], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: Literal[False] = ..., -) -> Future[Tuple[_T1, _T2, _T3, _T4]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - coro_or_future3: _FutureT[_T3], - coro_or_future4: _FutureT[_T4], - coro_or_future5: _FutureT[_T5], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: Literal[False] = ..., -) -> Future[Tuple[_T1, _T2, _T3, _T4, _T5]]: ... -@overload -def gather( - coro_or_future1: _FutureT[Any], - coro_or_future2: _FutureT[Any], - coro_or_future3: _FutureT[Any], - coro_or_future4: _FutureT[Any], - coro_or_future5: _FutureT[Any], - coro_or_future6: _FutureT[Any], - *coros_or_futures: _FutureT[Any], - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: bool = ..., -) -> Future[List[Any]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], *, loop: Optional[AbstractEventLoop] = ..., return_exceptions: bool = ... -) -> Future[Tuple[Union[_T1, BaseException]]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: bool = ..., -) -> Future[Tuple[Union[_T1, BaseException], Union[_T2, BaseException]]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - coro_or_future3: _FutureT[_T3], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: bool = ..., -) -> Future[Tuple[Union[_T1, BaseException], Union[_T2, BaseException], Union[_T3, BaseException]]]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - coro_or_future3: _FutureT[_T3], - coro_or_future4: _FutureT[_T4], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: bool = ..., -) -> Future[ - Tuple[Union[_T1, BaseException], Union[_T2, BaseException], Union[_T3, BaseException], Union[_T4, BaseException]] -]: ... -@overload -def gather( - coro_or_future1: _FutureT[_T1], - coro_or_future2: _FutureT[_T2], - coro_or_future3: _FutureT[_T3], - coro_or_future4: _FutureT[_T4], - coro_or_future5: _FutureT[_T5], - *, - loop: Optional[AbstractEventLoop] = ..., - return_exceptions: bool = ..., -) -> Future[ - Tuple[ - Union[_T1, BaseException], - Union[_T2, BaseException], - Union[_T3, BaseException], - Union[_T4, BaseException], - Union[_T5, BaseException], - ] -]: ... +if sys.version_info >= (3, 10): + @overload + def gather(coro_or_future1: _FutureT[_T1], *, return_exceptions: Literal[False] = ...) -> Future[Tuple[_T1]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], coro_or_future2: _FutureT[_T2], *, return_exceptions: Literal[False] = ... + ) -> Future[Tuple[_T1, _T2]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + *, + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2, _T3]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + *, + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2, _T3, _T4]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + coro_or_future5: _FutureT[_T5], + *, + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2, _T3, _T4, _T5]]: ... + @overload + def gather( + coro_or_future1: _FutureT[Any], + coro_or_future2: _FutureT[Any], + coro_or_future3: _FutureT[Any], + coro_or_future4: _FutureT[Any], + coro_or_future5: _FutureT[Any], + coro_or_future6: _FutureT[Any], + *coros_or_futures: _FutureT[Any], + return_exceptions: bool = ..., + ) -> Future[List[Any]]: ... + @overload + def gather(coro_or_future1: _FutureT[_T1], *, return_exceptions: bool = ...) -> Future[Tuple[Union[_T1, BaseException]]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], coro_or_future2: _FutureT[_T2], *, return_exceptions: bool = ... + ) -> Future[Tuple[Union[_T1, BaseException], Union[_T2, BaseException]]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + *, + return_exceptions: bool = ..., + ) -> Future[Tuple[Union[_T1, BaseException], Union[_T2, BaseException], Union[_T3, BaseException]]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + *, + return_exceptions: bool = ..., + ) -> Future[ + Tuple[Union[_T1, BaseException], Union[_T2, BaseException], Union[_T3, BaseException], Union[_T4, BaseException]] + ]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + coro_or_future5: _FutureT[_T5], + *, + return_exceptions: bool = ..., + ) -> Future[ + Tuple[ + Union[_T1, BaseException], + Union[_T2, BaseException], + Union[_T3, BaseException], + Union[_T4, BaseException], + Union[_T5, BaseException], + ] + ]: ... + +else: + @overload + def gather( + coro_or_future1: _FutureT[_T1], *, loop: Optional[AbstractEventLoop] = ..., return_exceptions: Literal[False] = ... + ) -> Future[Tuple[_T1]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2, _T3]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2, _T3, _T4]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + coro_or_future5: _FutureT[_T5], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: Literal[False] = ..., + ) -> Future[Tuple[_T1, _T2, _T3, _T4, _T5]]: ... + @overload + def gather( + coro_or_future1: _FutureT[Any], + coro_or_future2: _FutureT[Any], + coro_or_future3: _FutureT[Any], + coro_or_future4: _FutureT[Any], + coro_or_future5: _FutureT[Any], + coro_or_future6: _FutureT[Any], + *coros_or_futures: _FutureT[Any], + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: bool = ..., + ) -> Future[List[Any]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], *, loop: Optional[AbstractEventLoop] = ..., return_exceptions: bool = ... + ) -> Future[Tuple[Union[_T1, BaseException]]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: bool = ..., + ) -> Future[Tuple[Union[_T1, BaseException], Union[_T2, BaseException]]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: bool = ..., + ) -> Future[Tuple[Union[_T1, BaseException], Union[_T2, BaseException], Union[_T3, BaseException]]]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: bool = ..., + ) -> Future[ + Tuple[Union[_T1, BaseException], Union[_T2, BaseException], Union[_T3, BaseException], Union[_T4, BaseException]] + ]: ... + @overload + def gather( + coro_or_future1: _FutureT[_T1], + coro_or_future2: _FutureT[_T2], + coro_or_future3: _FutureT[_T3], + coro_or_future4: _FutureT[_T4], + coro_or_future5: _FutureT[_T5], + *, + loop: Optional[AbstractEventLoop] = ..., + return_exceptions: bool = ..., + ) -> Future[ + Tuple[ + Union[_T1, BaseException], + Union[_T2, BaseException], + Union[_T3, BaseException], + Union[_T4, BaseException], + Union[_T5, BaseException], + ] + ]: ... + def run_coroutine_threadsafe(coro: _FutureT[_T], loop: AbstractEventLoop) -> concurrent.futures.Future[_T]: ... -def shield(arg: _FutureT[_T], *, loop: Optional[AbstractEventLoop] = ...) -> Future[_T]: ... -def sleep(delay: float, result: _T = ..., *, loop: Optional[AbstractEventLoop] = ...) -> Future[_T]: ... -@overload -def wait(fs: Iterable[_FT], *, loop: Optional[AbstractEventLoop] = ..., timeout: Optional[float] = ..., return_when: str = ...) -> Future[Tuple[Set[_FT], Set[_FT]]]: ... # type: ignore -@overload -def wait( - fs: Iterable[Awaitable[_T]], - *, - loop: Optional[AbstractEventLoop] = ..., - timeout: Optional[float] = ..., - return_when: str = ..., -) -> Future[Tuple[Set[Task[_T]], Set[Task[_T]]]]: ... -def wait_for(fut: _FutureT[_T], timeout: Optional[float], *, loop: Optional[AbstractEventLoop] = ...) -> Future[_T]: ... + +if sys.version_info >= (3, 10): + def shield(arg: _FutureT[_T]) -> Future[_T]: ... + def sleep(delay: float, result: _T = ...) -> Future[_T]: ... + @overload + def wait(fs: Iterable[_FT], *, timeout: Optional[float] = ..., return_when: str = ...) -> Future[Tuple[Set[_FT], Set[_FT]]]: ... # type: ignore + @overload + def wait( + fs: Iterable[Awaitable[_T]], *, timeout: Optional[float] = ..., return_when: str = ... + ) -> Future[Tuple[Set[Task[_T]], Set[Task[_T]]]]: ... + def wait_for(fut: _FutureT[_T], timeout: Optional[float]) -> Future[_T]: ... + +else: + def shield(arg: _FutureT[_T], *, loop: Optional[AbstractEventLoop] = ...) -> Future[_T]: ... + def sleep(delay: float, result: _T = ..., *, loop: Optional[AbstractEventLoop] = ...) -> Future[_T]: ... + @overload + def wait(fs: Iterable[_FT], *, loop: Optional[AbstractEventLoop] = ..., timeout: Optional[float] = ..., return_when: str = ...) -> Future[Tuple[Set[_FT], Set[_FT]]]: ... # type: ignore + @overload + def wait( + fs: Iterable[Awaitable[_T]], + *, + loop: Optional[AbstractEventLoop] = ..., + timeout: Optional[float] = ..., + return_when: str = ..., + ) -> Future[Tuple[Set[Task[_T]], Set[Task[_T]]]]: ... + def wait_for(fut: _FutureT[_T], timeout: Optional[float], *, loop: Optional[AbstractEventLoop] = ...) -> Future[_T]: ... class Task(Future[_T], Generic[_T]): if sys.version_info >= (3, 8): diff --git a/mypy/typeshed/stdlib/base64.pyi b/mypy/typeshed/stdlib/base64.pyi index 01d304dc7678..e217d6d3dbf2 100644 --- a/mypy/typeshed/stdlib/base64.pyi +++ b/mypy/typeshed/stdlib/base64.pyi @@ -19,6 +19,10 @@ def b32decode(s: _decodable, casefold: bool = ..., map01: Optional[bytes] = ...) def b16encode(s: _encodable) -> bytes: ... def b16decode(s: _decodable, casefold: bool = ...) -> bytes: ... +if sys.version_info >= (3, 10): + def b32hexencode(s: _encodable) -> bytes: ... + def b32hexdecode(s: _decodable, casefold: bool = ...) -> bytes: ... + if sys.version_info >= (3, 4): def a85encode(b: _encodable, *, foldspaces: bool = ..., wrapcol: int = ..., pad: bool = ..., adobe: bool = ...) -> bytes: ... def a85decode(b: _decodable, *, foldspaces: bool = ..., adobe: bool = ..., ignorechars: Union[str, bytes] = ...) -> bytes: ... diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi index de4813a5d961..69da371906d0 100644 --- a/mypy/typeshed/stdlib/builtins.pyi +++ b/mypy/typeshed/stdlib/builtins.pyi @@ -22,10 +22,11 @@ from typing import ( IO, AbstractSet, Any, + AsyncIterable, + AsyncIterator, BinaryIO, ByteString, Callable, - Container, Dict, FrozenSet, Generic, @@ -58,7 +59,7 @@ from typing import ( ValuesView, overload, ) -from typing_extensions import Literal, SupportsIndex +from typing_extensions import Literal, SupportsIndex, final if sys.version_info >= (3, 9): from types import GenericAlias @@ -143,7 +144,7 @@ class type(object): @overload def __new__(cls, o: object) -> type: ... @overload - def __new__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any], **kwds: Any) -> type: ... + def __new__(cls: Type[_TT], name: str, bases: Tuple[type, ...], namespace: Dict[str, Any], **kwds: Any) -> _TT: ... def __call__(self, *args: Any, **kwds: Any) -> Any: ... def __subclasses__(self: _TT) -> List[_TT]: ... # Note: the documentation doesnt specify what the return type is, the standard @@ -181,6 +182,8 @@ class int: def denominator(self) -> int: ... def conjugate(self) -> int: ... def bit_length(self) -> int: ... + if sys.version_info >= (3, 10): + def bit_count(self) -> int: ... def to_bytes(self, length: int, byteorder: str, *, signed: bool = ...) -> bytes: ... @classmethod def from_bytes( @@ -620,7 +623,7 @@ class bytearray(MutableSequence[int], ByteString): def __gt__(self, x: bytes) -> bool: ... def __ge__(self, x: bytes) -> bool: ... -class memoryview(Sized, Container[int]): +class memoryview(Sized, Sequence[int]): format: str itemsize: int shape: Optional[Tuple[int, ...]] @@ -664,6 +667,7 @@ class memoryview(Sized, Container[int]): else: def hex(self) -> str: ... +@final class bool(int): def __new__(cls: Type[_T], __o: object = ...) -> _T: ... @overload @@ -822,6 +826,7 @@ class dict(MutableMapping[_KT, _VT], Generic[_KT, _VT]): if sys.version_info >= (3, 9): def __class_getitem__(cls, item: Any) -> GenericAlias: ... def __or__(self, __value: Mapping[_KT, _VT]) -> Dict[_KT, _VT]: ... + def __ror__(self, __value: Mapping[_KT, _VT]) -> Dict[_KT, _VT]: ... def __ior__(self, __value: Mapping[_KT, _VT]) -> Dict[_KT, _VT]: ... class set(MutableSet[_T], Generic[_T]): @@ -959,6 +964,13 @@ _AnyStr_co = TypeVar("_AnyStr_co", str, bytes, covariant=True) class _PathLike(Protocol[_AnyStr_co]): def __fspath__(self) -> _AnyStr_co: ... +if sys.version_info >= (3, 10): + def aiter(__iterable: AsyncIterable[_T]) -> AsyncIterator[_T]: ... + @overload + async def anext(__i: AsyncIterator[_T]) -> _T: ... + @overload + async def anext(__i: AsyncIterator[_T], default: _VT) -> Union[_T, _VT]: ... + if sys.version_info >= (3, 8): def compile( source: Union[str, bytes, mod, AST], @@ -1074,29 +1086,29 @@ def max( __arg1: SupportsLessThanT, __arg2: SupportsLessThanT, *_args: SupportsLessThanT, key: None = ... ) -> SupportsLessThanT: ... @overload -def max(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsLessThanT]) -> _T: ... +def max(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsLessThan]) -> _T: ... @overload def max(__iterable: Iterable[SupportsLessThanT], *, key: None = ...) -> SupportsLessThanT: ... @overload -def max(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThanT]) -> _T: ... +def max(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThan]) -> _T: ... @overload def max(__iterable: Iterable[SupportsLessThanT], *, key: None = ..., default: _T) -> Union[SupportsLessThanT, _T]: ... @overload -def max(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsLessThanT], default: _T2) -> Union[_T1, _T2]: ... +def max(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsLessThan], default: _T2) -> Union[_T1, _T2]: ... @overload def min( __arg1: SupportsLessThanT, __arg2: SupportsLessThanT, *_args: SupportsLessThanT, key: None = ... ) -> SupportsLessThanT: ... @overload -def min(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsLessThanT]) -> _T: ... +def min(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsLessThan]) -> _T: ... @overload def min(__iterable: Iterable[SupportsLessThanT], *, key: None = ...) -> SupportsLessThanT: ... @overload -def min(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThanT]) -> _T: ... +def min(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThan]) -> _T: ... @overload def min(__iterable: Iterable[SupportsLessThanT], *, key: None = ..., default: _T) -> Union[SupportsLessThanT, _T]: ... @overload -def min(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsLessThanT], default: _T2) -> Union[_T1, _T2]: ... +def min(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsLessThan], default: _T2) -> Union[_T1, _T2]: ... @overload def next(__i: Iterator[_T]) -> _T: ... @overload @@ -1361,7 +1373,12 @@ if sys.platform == "win32": class ArithmeticError(_StandardError): ... class AssertionError(_StandardError): ... -class AttributeError(_StandardError): ... + +class AttributeError(_StandardError): + if sys.version_info >= (3, 10): + name: str + obj: object + class BufferError(_StandardError): ... class EOFError(_StandardError): ... @@ -1373,7 +1390,11 @@ class ImportError(_StandardError): class LookupError(_StandardError): ... class MemoryError(_StandardError): ... -class NameError(_StandardError): ... + +class NameError(_StandardError): + if sys.version_info >= (3, 10): + name: str + class ReferenceError(_StandardError): ... class RuntimeError(_StandardError): ... @@ -1386,6 +1407,9 @@ class SyntaxError(_StandardError): offset: Optional[int] text: Optional[str] filename: Optional[str] + if sys.version_info >= (3, 10): + end_lineno: Optional[int] + end_offset: Optional[int] class SystemError(_StandardError): ... class TypeError(_StandardError): ... @@ -1449,3 +1473,6 @@ class ImportWarning(Warning): ... class UnicodeWarning(Warning): ... class BytesWarning(Warning): ... class ResourceWarning(Warning): ... + +if sys.version_info >= (3, 10): + class EncodingWarning(Warning): ... diff --git a/mypy/typeshed/stdlib/calendar.pyi b/mypy/typeshed/stdlib/calendar.pyi index 0737062de405..ad73132b4595 100644 --- a/mypy/typeshed/stdlib/calendar.pyi +++ b/mypy/typeshed/stdlib/calendar.pyi @@ -19,6 +19,7 @@ def weekday(year: int, month: int, day: int) -> int: ... def monthrange(year: int, month: int) -> Tuple[int, int]: ... class Calendar: + firstweekday: int def __init__(self, firstweekday: int = ...) -> None: ... def getfirstweekday(self) -> int: ... def setfirstweekday(self, firstweekday: int) -> None: ... diff --git a/mypy/typeshed/stdlib/contextlib.pyi b/mypy/typeshed/stdlib/contextlib.pyi index 509bcb6e4fc3..19ef45bed241 100644 --- a/mypy/typeshed/stdlib/contextlib.pyi +++ b/mypy/typeshed/stdlib/contextlib.pyi @@ -1,6 +1,7 @@ import sys from types import TracebackType from typing import IO, Any, Callable, ContextManager, Iterable, Iterator, Optional, Type, TypeVar, overload +from typing_extensions import Protocol if sys.version_info >= (3, 5): from typing import AsyncContextManager, AsyncIterator @@ -34,8 +35,20 @@ if sys.version_info >= (3, 7): if sys.version_info < (3,): def nested(*mgr: ContextManager[Any]) -> ContextManager[Iterable[Any]]: ... -class closing(ContextManager[_T]): - def __init__(self, thing: _T) -> None: ... +class _SupportsClose(Protocol): + def close(self) -> None: ... + +_SupportsCloseT = TypeVar("_SupportsCloseT", bound=_SupportsClose) + +class closing(ContextManager[_SupportsCloseT]): + def __init__(self, thing: _SupportsCloseT) -> None: ... + +if sys.version_info >= (3, 10): + class _SupportsAclose(Protocol): + async def aclose(self) -> None: ... + _SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose) + class aclosing(AsyncContextManager[_SupportsAcloseT]): + def __init__(self, thing: _SupportsAcloseT) -> None: ... if sys.version_info >= (3, 4): class suppress(ContextManager[None]): diff --git a/mypy/typeshed/stdlib/contextvars.pyi b/mypy/typeshed/stdlib/contextvars.pyi index 810b699b668d..069a6786688e 100644 --- a/mypy/typeshed/stdlib/contextvars.pyi +++ b/mypy/typeshed/stdlib/contextvars.pyi @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, ClassVar, Generic, Iterator, Mapping, TypeVar, Union, overload +from typing import Any, Callable, ClassVar, Generic, Iterator, Mapping, Optional, TypeVar, Union, overload if sys.version_info >= (3, 9): from types import GenericAlias @@ -35,6 +35,10 @@ def copy_context() -> Context: ... # a different value. class Context(Mapping[ContextVar[Any], Any]): def __init__(self) -> None: ... + @overload + def get(self, __key: ContextVar[Any]) -> Optional[Any]: ... + @overload + def get(self, __key: ContextVar[Any], __default: Optional[Any]) -> Any: ... def run(self, callable: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: ... def copy(self) -> Context: ... def __getitem__(self, key: ContextVar[Any]) -> Any: ... diff --git a/mypy/typeshed/stdlib/datetime.pyi b/mypy/typeshed/stdlib/datetime.pyi index 5cfb42b32b31..4692f590b5b3 100644 --- a/mypy/typeshed/stdlib/datetime.pyi +++ b/mypy/typeshed/stdlib/datetime.pyi @@ -1,6 +1,6 @@ import sys from time import struct_time -from typing import AnyStr, ClassVar, Optional, SupportsAbs, Tuple, Type, TypeVar, Union, overload +from typing import AnyStr, ClassVar, NamedTuple, Optional, SupportsAbs, Tuple, Type, TypeVar, Union, overload _S = TypeVar("_S") @@ -26,6 +26,12 @@ if sys.version_info >= (3, 2): def __init__(self, offset: timedelta, name: str = ...) -> None: ... def __hash__(self) -> int: ... +if sys.version_info >= (3, 9): + class _IsoCalendarDate(NamedTuple): + year: int + week: int + weekday: int + _tzinfo = tzinfo class date: @@ -78,7 +84,10 @@ class date: def __hash__(self) -> int: ... def weekday(self) -> int: ... def isoweekday(self) -> int: ... - def isocalendar(self) -> Tuple[int, int, int]: ... + if sys.version_info >= (3, 9): + def isocalendar(self) -> _IsoCalendarDate: ... + else: + def isocalendar(self) -> Tuple[int, int, int]: ... class time: min: ClassVar[time] @@ -370,4 +379,7 @@ class datetime(date): def __hash__(self) -> int: ... def weekday(self) -> int: ... def isoweekday(self) -> int: ... - def isocalendar(self) -> Tuple[int, int, int]: ... + if sys.version_info >= (3, 9): + def isocalendar(self) -> _IsoCalendarDate: ... + else: + def isocalendar(self) -> Tuple[int, int, int]: ... diff --git a/mypy/typeshed/stdlib/email/errors.pyi b/mypy/typeshed/stdlib/email/errors.pyi index 561eb0bd0ea5..64ad41407857 100644 --- a/mypy/typeshed/stdlib/email/errors.pyi +++ b/mypy/typeshed/stdlib/email/errors.pyi @@ -1,3 +1,4 @@ +import sys from typing import Optional class MessageError(Exception): ... @@ -5,6 +6,7 @@ class MessageParseError(MessageError): ... class HeaderParseError(MessageParseError): ... class BoundaryError(MessageParseError): ... class MultipartConversionError(MessageError, TypeError): ... +class CharsetError(MessageError): ... class MessageDefect(ValueError): def __init__(self, line: Optional[str] = ...) -> None: ... @@ -14,9 +16,25 @@ class StartBoundaryNotFoundDefect(MessageDefect): ... class FirstHeaderLineIsContinuationDefect(MessageDefect): ... class MisplacedEnvelopeHeaderDefect(MessageDefect): ... class MultipartInvariantViolationDefect(MessageDefect): ... +class InvalidMultipartContentTransferEncodingDefect(MessageDefect): ... +class UndecodableBytesDefect(MessageDefect): ... class InvalidBase64PaddingDefect(MessageDefect): ... class InvalidBase64CharactersDefect(MessageDefect): ... +class InvalidBase64LengthDefect(MessageDefect): ... class CloseBoundaryNotFoundDefect(MessageDefect): ... class MissingHeaderBodySeparatorDefect(MessageDefect): ... MalformedHeaderDefect = MissingHeaderBodySeparatorDefect + +class HeaderDefect(MessageDefect): ... +class InvalidHeaderDefect(HeaderDefect): ... +class HeaderMissingRequiredValue(HeaderDefect): ... + +class NonPrintableDefect(HeaderDefect): + def __init__(self, non_printables: Optional[str]) -> None: ... + +class ObsoleteHeaderDefect(HeaderDefect): ... +class NonASCIILocalPartDefect(HeaderDefect): ... + +if sys.version_info >= (3, 10): + class InvalidDateDefect(HeaderDefect): ... diff --git a/mypy/typeshed/stdlib/enum.pyi b/mypy/typeshed/stdlib/enum.pyi index 9b44bc020d77..b5549253afc7 100644 --- a/mypy/typeshed/stdlib/enum.pyi +++ b/mypy/typeshed/stdlib/enum.pyi @@ -1,5 +1,6 @@ import sys from abc import ABCMeta +from builtins import property as _builtins_property from typing import Any, Dict, Iterator, List, Mapping, Type, TypeVar, Union _T = TypeVar("_T") @@ -15,7 +16,7 @@ class EnumMeta(ABCMeta): def __reversed__(self: Type[_T]) -> Iterator[_T]: ... def __contains__(self: Type[Any], member: object) -> bool: ... def __getitem__(self: Type[_T], name: str) -> _T: ... - @property + @_builtins_property def __members__(self: Type[_T]) -> Mapping[str, _T]: ... def __len__(self) -> int: ... @@ -74,3 +75,20 @@ class IntFlag(int, Flag): __ror__ = __or__ __rand__ = __and__ __rxor__ = __xor__ + +if sys.version_info >= (3, 10): + class StrEnum(str, Enum): + def __new__(cls: Type[_T], value: Union[int, _T]) -> _T: ... + class FlagBoundary(StrEnum): + STRICT: str + CONFORM: str + EJECT: str + KEEP: str + STRICT = FlagBoundary.STRICT + CONFORM = FlagBoundary.CONFORM + EJECT = FlagBoundary.EJECT + KEEP = FlagBoundary.KEEP + class property(_builtins_property): ... + def global_enum(cls: _S) -> _S: ... + def global_enum_repr(self: Enum) -> str: ... + def global_flag_repr(self: Flag) -> str: ... diff --git a/mypy/typeshed/stdlib/formatter.pyi b/mypy/typeshed/stdlib/formatter.pyi index 31c45592a215..d3ecaec5fab9 100644 --- a/mypy/typeshed/stdlib/formatter.pyi +++ b/mypy/typeshed/stdlib/formatter.pyi @@ -1,103 +1,99 @@ +import sys from typing import IO, Any, Iterable, List, Optional, Tuple -AS_IS: None -_FontType = Tuple[str, bool, bool, bool] -_StylesType = Tuple[Any, ...] - -class NullFormatter: - writer: Optional[NullWriter] - def __init__(self, writer: Optional[NullWriter] = ...) -> None: ... - def end_paragraph(self, blankline: int) -> None: ... - def add_line_break(self) -> None: ... - def add_hor_rule(self, *args: Any, **kw: Any) -> None: ... - def add_label_data(self, format: str, counter: int, blankline: Optional[int] = ...) -> None: ... - def add_flowing_data(self, data: str) -> None: ... - def add_literal_data(self, data: str) -> None: ... - def flush_softspace(self) -> None: ... - def push_alignment(self, align: Optional[str]) -> None: ... - def pop_alignment(self) -> None: ... - def push_font(self, x: _FontType) -> None: ... - def pop_font(self) -> None: ... - def push_margin(self, margin: int) -> None: ... - def pop_margin(self) -> None: ... - def set_spacing(self, spacing: Optional[str]) -> None: ... - def push_style(self, *styles: _StylesType) -> None: ... - def pop_style(self, n: int = ...) -> None: ... - def assert_line_data(self, flag: int = ...) -> None: ... - -class AbstractFormatter: - writer: NullWriter - align: Optional[str] - align_stack: List[Optional[str]] - font_stack: List[_FontType] - margin_stack: List[int] - spacing: Optional[str] - style_stack: Any - nospace: int - softspace: int - para_end: int - parskip: int - hard_break: int - have_label: int - def __init__(self, writer: NullWriter) -> None: ... - def end_paragraph(self, blankline: int) -> None: ... - def add_line_break(self) -> None: ... - def add_hor_rule(self, *args: Any, **kw: Any) -> None: ... - def add_label_data(self, format: str, counter: int, blankline: Optional[int] = ...) -> None: ... - def format_counter(self, format: Iterable[str], counter: int) -> str: ... - def format_letter(self, case: str, counter: int) -> str: ... - def format_roman(self, case: str, counter: int) -> str: ... - def add_flowing_data(self, data: str) -> None: ... - def add_literal_data(self, data: str) -> None: ... - def flush_softspace(self) -> None: ... - def push_alignment(self, align: Optional[str]) -> None: ... - def pop_alignment(self) -> None: ... - def push_font(self, font: _FontType) -> None: ... - def pop_font(self) -> None: ... - def push_margin(self, margin: int) -> None: ... - def pop_margin(self) -> None: ... - def set_spacing(self, spacing: Optional[str]) -> None: ... - def push_style(self, *styles: _StylesType) -> None: ... - def pop_style(self, n: int = ...) -> None: ... - def assert_line_data(self, flag: int = ...) -> None: ... - -class NullWriter: - def __init__(self) -> None: ... - def flush(self) -> None: ... - def new_alignment(self, align: Optional[str]) -> None: ... - def new_font(self, font: _FontType) -> None: ... - def new_margin(self, margin: int, level: int) -> None: ... - def new_spacing(self, spacing: Optional[str]) -> None: ... - def new_styles(self, styles: Tuple[Any, ...]) -> None: ... - def send_paragraph(self, blankline: int) -> None: ... - def send_line_break(self) -> None: ... - def send_hor_rule(self, *args: Any, **kw: Any) -> None: ... - def send_label_data(self, data: str) -> None: ... - def send_flowing_data(self, data: str) -> None: ... - def send_literal_data(self, data: str) -> None: ... - -class AbstractWriter(NullWriter): - def new_alignment(self, align: Optional[str]) -> None: ... - def new_font(self, font: _FontType) -> None: ... - def new_margin(self, margin: int, level: int) -> None: ... - def new_spacing(self, spacing: Optional[str]) -> None: ... - def new_styles(self, styles: Tuple[Any, ...]) -> None: ... - def send_paragraph(self, blankline: int) -> None: ... - def send_line_break(self) -> None: ... - def send_hor_rule(self, *args: Any, **kw: Any) -> None: ... - def send_label_data(self, data: str) -> None: ... - def send_flowing_data(self, data: str) -> None: ... - def send_literal_data(self, data: str) -> None: ... - -class DumbWriter(NullWriter): - file: IO[str] - maxcol: int - def __init__(self, file: Optional[IO[str]] = ..., maxcol: int = ...) -> None: ... - def reset(self) -> None: ... - def send_paragraph(self, blankline: int) -> None: ... - def send_line_break(self) -> None: ... - def send_hor_rule(self, *args: Any, **kw: Any) -> None: ... - def send_literal_data(self, data: str) -> None: ... - def send_flowing_data(self, data: str) -> None: ... - -def test(file: Optional[str] = ...) -> None: ... +if sys.version_info < (3, 10): + AS_IS: None + _FontType = Tuple[str, bool, bool, bool] + _StylesType = Tuple[Any, ...] + class NullFormatter: + writer: Optional[NullWriter] + def __init__(self, writer: Optional[NullWriter] = ...) -> None: ... + def end_paragraph(self, blankline: int) -> None: ... + def add_line_break(self) -> None: ... + def add_hor_rule(self, *args: Any, **kw: Any) -> None: ... + def add_label_data(self, format: str, counter: int, blankline: Optional[int] = ...) -> None: ... + def add_flowing_data(self, data: str) -> None: ... + def add_literal_data(self, data: str) -> None: ... + def flush_softspace(self) -> None: ... + def push_alignment(self, align: Optional[str]) -> None: ... + def pop_alignment(self) -> None: ... + def push_font(self, x: _FontType) -> None: ... + def pop_font(self) -> None: ... + def push_margin(self, margin: int) -> None: ... + def pop_margin(self) -> None: ... + def set_spacing(self, spacing: Optional[str]) -> None: ... + def push_style(self, *styles: _StylesType) -> None: ... + def pop_style(self, n: int = ...) -> None: ... + def assert_line_data(self, flag: int = ...) -> None: ... + class AbstractFormatter: + writer: NullWriter + align: Optional[str] + align_stack: List[Optional[str]] + font_stack: List[_FontType] + margin_stack: List[int] + spacing: Optional[str] + style_stack: Any + nospace: int + softspace: int + para_end: int + parskip: int + hard_break: int + have_label: int + def __init__(self, writer: NullWriter) -> None: ... + def end_paragraph(self, blankline: int) -> None: ... + def add_line_break(self) -> None: ... + def add_hor_rule(self, *args: Any, **kw: Any) -> None: ... + def add_label_data(self, format: str, counter: int, blankline: Optional[int] = ...) -> None: ... + def format_counter(self, format: Iterable[str], counter: int) -> str: ... + def format_letter(self, case: str, counter: int) -> str: ... + def format_roman(self, case: str, counter: int) -> str: ... + def add_flowing_data(self, data: str) -> None: ... + def add_literal_data(self, data: str) -> None: ... + def flush_softspace(self) -> None: ... + def push_alignment(self, align: Optional[str]) -> None: ... + def pop_alignment(self) -> None: ... + def push_font(self, font: _FontType) -> None: ... + def pop_font(self) -> None: ... + def push_margin(self, margin: int) -> None: ... + def pop_margin(self) -> None: ... + def set_spacing(self, spacing: Optional[str]) -> None: ... + def push_style(self, *styles: _StylesType) -> None: ... + def pop_style(self, n: int = ...) -> None: ... + def assert_line_data(self, flag: int = ...) -> None: ... + class NullWriter: + def __init__(self) -> None: ... + def flush(self) -> None: ... + def new_alignment(self, align: Optional[str]) -> None: ... + def new_font(self, font: _FontType) -> None: ... + def new_margin(self, margin: int, level: int) -> None: ... + def new_spacing(self, spacing: Optional[str]) -> None: ... + def new_styles(self, styles: Tuple[Any, ...]) -> None: ... + def send_paragraph(self, blankline: int) -> None: ... + def send_line_break(self) -> None: ... + def send_hor_rule(self, *args: Any, **kw: Any) -> None: ... + def send_label_data(self, data: str) -> None: ... + def send_flowing_data(self, data: str) -> None: ... + def send_literal_data(self, data: str) -> None: ... + class AbstractWriter(NullWriter): + def new_alignment(self, align: Optional[str]) -> None: ... + def new_font(self, font: _FontType) -> None: ... + def new_margin(self, margin: int, level: int) -> None: ... + def new_spacing(self, spacing: Optional[str]) -> None: ... + def new_styles(self, styles: Tuple[Any, ...]) -> None: ... + def send_paragraph(self, blankline: int) -> None: ... + def send_line_break(self) -> None: ... + def send_hor_rule(self, *args: Any, **kw: Any) -> None: ... + def send_label_data(self, data: str) -> None: ... + def send_flowing_data(self, data: str) -> None: ... + def send_literal_data(self, data: str) -> None: ... + class DumbWriter(NullWriter): + file: IO[str] + maxcol: int + def __init__(self, file: Optional[IO[str]] = ..., maxcol: int = ...) -> None: ... + def reset(self) -> None: ... + def send_paragraph(self, blankline: int) -> None: ... + def send_line_break(self) -> None: ... + def send_hor_rule(self, *args: Any, **kw: Any) -> None: ... + def send_literal_data(self, data: str) -> None: ... + def send_flowing_data(self, data: str) -> None: ... + def test(file: Optional[str] = ...) -> None: ... diff --git a/mypy/typeshed/stdlib/ftplib.pyi b/mypy/typeshed/stdlib/ftplib.pyi index d3e4758aec47..bdb1716549c0 100644 --- a/mypy/typeshed/stdlib/ftplib.pyi +++ b/mypy/typeshed/stdlib/ftplib.pyi @@ -22,7 +22,7 @@ class error_temp(Error): ... class error_perm(Error): ... class error_proto(Error): ... -all_errors = Tuple[Type[Exception], ...] +all_errors: Tuple[Type[Exception], ...] class FTP: debugging: int diff --git a/mypy/typeshed/stdlib/functools.pyi b/mypy/typeshed/stdlib/functools.pyi index ceb6ffcb3f11..d4a492c0102e 100644 --- a/mypy/typeshed/stdlib/functools.pyi +++ b/mypy/typeshed/stdlib/functools.pyi @@ -1,5 +1,5 @@ import sys -from _typeshed import SupportsLessThan +from _typeshed import SupportsItems, SupportsLessThan from typing import ( Any, Callable, @@ -11,6 +11,8 @@ from typing import ( NamedTuple, Optional, Sequence, + Set, + Sized, Tuple, Type, TypeVar, @@ -131,3 +133,14 @@ if sys.version_info >= (3, 8): if sys.version_info >= (3, 9): def cache(__user_function: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ... + +def _make_key( + args: Tuple[Hashable, ...], + kwds: SupportsItems[Any, Any], + typed: bool, + kwd_mark: Tuple[object, ...] = ..., + fasttypes: Set[type] = ..., + tuple: type = ..., + type: Any = ..., + len: Callable[[Sized], int] = ..., +) -> Hashable: ... diff --git a/mypy/typeshed/stdlib/genericpath.pyi b/mypy/typeshed/stdlib/genericpath.pyi index fc314f0a1658..74dc819325de 100644 --- a/mypy/typeshed/stdlib/genericpath.pyi +++ b/mypy/typeshed/stdlib/genericpath.pyi @@ -1,12 +1,7 @@ import sys from _typeshed import AnyPath -from typing import AnyStr, Sequence, Text - -if sys.version_info >= (3, 0): - def commonprefix(m: Sequence[str]) -> str: ... - -else: - def commonprefix(m: Sequence[AnyStr]) -> AnyStr: ... +from os.path import commonprefix as commonprefix +from typing import Text def exists(path: AnyPath) -> bool: ... def isfile(path: Text) -> bool: ... diff --git a/mypy/typeshed/stdlib/glob.pyi b/mypy/typeshed/stdlib/glob.pyi index 3029b258100a..42269e95d896 100644 --- a/mypy/typeshed/stdlib/glob.pyi +++ b/mypy/typeshed/stdlib/glob.pyi @@ -1,8 +1,21 @@ -from typing import AnyStr, Iterator, List, Union +import sys +from _typeshed import AnyPath +from typing import AnyStr, Iterator, List, Optional, Union def glob0(dirname: AnyStr, pattern: AnyStr) -> List[AnyStr]: ... def glob1(dirname: AnyStr, pattern: AnyStr) -> List[AnyStr]: ... -def glob(pathname: AnyStr, *, recursive: bool = ...) -> List[AnyStr]: ... -def iglob(pathname: AnyStr, *, recursive: bool = ...) -> Iterator[AnyStr]: ... + +if sys.version_info >= (3, 10): + def glob( + pathname: AnyStr, *, root_dir: Optional[AnyPath] = ..., dir_fd: Optional[int] = ..., recursive: bool = ... + ) -> List[AnyStr]: ... + def iglob( + pathname: AnyStr, *, root_dir: Optional[AnyPath] = ..., dir_fd: Optional[int] = ..., recursive: bool = ... + ) -> Iterator[AnyStr]: ... + +else: + def glob(pathname: AnyStr, *, recursive: bool = ...) -> List[AnyStr]: ... + def iglob(pathname: AnyStr, *, recursive: bool = ...) -> Iterator[AnyStr]: ... + def escape(pathname: AnyStr) -> AnyStr: ... def has_magic(s: Union[str, bytes]) -> bool: ... # undocumented diff --git a/mypy/typeshed/stdlib/html/parser.pyi b/mypy/typeshed/stdlib/html/parser.pyi index b49766bfc9b8..82431b7e1d3e 100644 --- a/mypy/typeshed/stdlib/html/parser.pyi +++ b/mypy/typeshed/stdlib/html/parser.pyi @@ -18,3 +18,13 @@ class HTMLParser(ParserBase): def handle_decl(self, decl: str) -> None: ... def handle_pi(self, data: str) -> None: ... def unknown_decl(self, data: str) -> None: ... + CDATA_CONTENT_ELEMENTS: Tuple[str, ...] + def check_for_whole_start_tag(self, i: int) -> int: ... # undocumented + def clear_cdata_mode(self) -> None: ... # undocumented + def goahead(self, end: bool) -> None: ... # undocumented + def parse_bogus_comment(self, i: int, report: bool = ...) -> int: ... # undocumented + def parse_endtag(self, i: int) -> int: ... # undocumented + def parse_html_declaration(self, i: int) -> int: ... # undocumented + def parse_pi(self, i: int) -> int: ... # undocumented + def parse_starttag(self, i: int) -> int: ... # undocumented + def set_cdata_mode(self, elem: str) -> None: ... # undocumented diff --git a/mypy/typeshed/stdlib/http/client.pyi b/mypy/typeshed/stdlib/http/client.pyi index 2f6b2df08868..c35228928d77 100644 --- a/mypy/typeshed/stdlib/http/client.pyi +++ b/mypy/typeshed/stdlib/http/client.pyi @@ -183,18 +183,33 @@ class HTTPConnection: def send(self, data: _DataType) -> None: ... class HTTPSConnection(HTTPConnection): - def __init__( - self, - host: str, - port: Optional[int] = ..., - key_file: Optional[str] = ..., - cert_file: Optional[str] = ..., - timeout: Optional[float] = ..., - source_address: Optional[Tuple[str, int]] = ..., - *, - context: Optional[ssl.SSLContext] = ..., - check_hostname: Optional[bool] = ..., - ) -> None: ... + if sys.version_info >= (3, 7): + def __init__( + self, + host: str, + port: Optional[int] = ..., + key_file: Optional[str] = ..., + cert_file: Optional[str] = ..., + timeout: Optional[float] = ..., + source_address: Optional[Tuple[str, int]] = ..., + *, + context: Optional[ssl.SSLContext] = ..., + check_hostname: Optional[bool] = ..., + blocksize: int = ..., + ) -> None: ... + else: + def __init__( + self, + host: str, + port: Optional[int] = ..., + key_file: Optional[str] = ..., + cert_file: Optional[str] = ..., + timeout: Optional[float] = ..., + source_address: Optional[Tuple[str, int]] = ..., + *, + context: Optional[ssl.SSLContext] = ..., + check_hostname: Optional[bool] = ..., + ) -> None: ... class HTTPException(Exception): ... diff --git a/mypy/typeshed/stdlib/http/cookiejar.pyi b/mypy/typeshed/stdlib/http/cookiejar.pyi index a57c7c0fbe16..9398bae00951 100644 --- a/mypy/typeshed/stdlib/http/cookiejar.pyi +++ b/mypy/typeshed/stdlib/http/cookiejar.pyi @@ -63,21 +63,39 @@ class DefaultCookiePolicy(CookiePolicy): DomainRFC2965Match: int DomainLiberal: int DomainStrict: int - def __init__( - self, - blocked_domains: Optional[Sequence[str]] = ..., - allowed_domains: Optional[Sequence[str]] = ..., - netscape: bool = ..., - rfc2965: bool = ..., - rfc2109_as_netscape: Optional[bool] = ..., - hide_cookie2: bool = ..., - strict_domain: bool = ..., - strict_rfc2965_unverifiable: bool = ..., - strict_ns_unverifiable: bool = ..., - strict_ns_domain: int = ..., - strict_ns_set_initial_dollar: bool = ..., - strict_ns_set_path: bool = ..., - ) -> None: ... + if sys.version_info >= (3, 8): + def __init__( + self, + blocked_domains: Optional[Sequence[str]] = ..., + allowed_domains: Optional[Sequence[str]] = ..., + netscape: bool = ..., + rfc2965: bool = ..., + rfc2109_as_netscape: Optional[bool] = ..., + hide_cookie2: bool = ..., + strict_domain: bool = ..., + strict_rfc2965_unverifiable: bool = ..., + strict_ns_unverifiable: bool = ..., + strict_ns_domain: int = ..., + strict_ns_set_initial_dollar: bool = ..., + strict_ns_set_path: bool = ..., + secure_protocols: Sequence[str] = ..., + ) -> None: ... + else: + def __init__( + self, + blocked_domains: Optional[Sequence[str]] = ..., + allowed_domains: Optional[Sequence[str]] = ..., + netscape: bool = ..., + rfc2965: bool = ..., + rfc2109_as_netscape: Optional[bool] = ..., + hide_cookie2: bool = ..., + strict_domain: bool = ..., + strict_rfc2965_unverifiable: bool = ..., + strict_ns_unverifiable: bool = ..., + strict_ns_domain: int = ..., + strict_ns_set_initial_dollar: bool = ..., + strict_ns_set_path: bool = ..., + ) -> None: ... def blocked_domains(self) -> Tuple[str, ...]: ... def set_blocked_domains(self, blocked_domains: Sequence[str]) -> None: ... def is_blocked(self, domain: str) -> bool: ... diff --git a/mypy/typeshed/stdlib/importlib/abc.pyi b/mypy/typeshed/stdlib/importlib/abc.pyi index 47f7f071a6c0..62b391e216fc 100644 --- a/mypy/typeshed/stdlib/importlib/abc.pyi +++ b/mypy/typeshed/stdlib/importlib/abc.pyi @@ -2,12 +2,9 @@ import sys import types from _typeshed import AnyPath from abc import ABCMeta, abstractmethod -from typing import IO, Any, Iterator, Mapping, Optional, Sequence, Tuple, Union -from typing_extensions import Literal - -# Loader is exported from this module, but for circular import reasons -# exists in its own stub file (with ModuleSpec and ModuleType). -from _importlib_modulespec import Loader as Loader, ModuleSpec # Exported +from importlib.machinery import ModuleSpec +from typing import IO, Any, Iterator, Mapping, Optional, Protocol, Sequence, Tuple, Union +from typing_extensions import Literal, runtime_checkable _Path = Union[bytes, str] @@ -38,6 +35,7 @@ class SourceLoader(ResourceLoader, ExecutionLoader, metaclass=ABCMeta): def get_source(self, fullname: str) -> Optional[str]: ... def path_stats(self, path: _Path) -> Mapping[str, Any]: ... +# Please keep in sync with sys._MetaPathFinder class MetaPathFinder(Finder): def find_module(self, fullname: str, path: Optional[Sequence[_Path]]) -> Optional[Loader]: ... def invalidate_caches(self) -> None: ... @@ -53,6 +51,17 @@ class PathEntryFinder(Finder): # Not defined on the actual class, but expected to exist. def find_spec(self, fullname: str, target: Optional[types.ModuleType] = ...) -> Optional[ModuleSpec]: ... +class Loader(metaclass=ABCMeta): + def load_module(self, fullname: str) -> types.ModuleType: ... + def module_repr(self, module: types.ModuleType) -> str: ... + def create_module(self, spec: ModuleSpec) -> Optional[types.ModuleType]: ... + # Not defined on the actual class for backwards-compatibility reasons, + # but expected in new code. + def exec_module(self, module: types.ModuleType) -> None: ... + +class _LoaderProtocol(Protocol): + def load_module(self, fullname: str) -> types.ModuleType: ... + class FileLoader(ResourceLoader, ExecutionLoader, metaclass=ABCMeta): name: str path: _Path @@ -73,7 +82,6 @@ if sys.version_info >= (3, 7): def contents(self) -> Iterator[str]: ... if sys.version_info >= (3, 9): - from typing import Protocol, runtime_checkable @runtime_checkable class Traversable(Protocol): @abstractmethod diff --git a/mypy/typeshed/stdlib/importlib/machinery.pyi b/mypy/typeshed/stdlib/importlib/machinery.pyi index 31ffdf7a431e..beae43bd3937 100644 --- a/mypy/typeshed/stdlib/importlib/machinery.pyi +++ b/mypy/typeshed/stdlib/importlib/machinery.pyi @@ -1,10 +1,26 @@ import importlib.abc import types -from typing import Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union -# ModuleSpec is exported from this module, but for circular import -# reasons exists in its own stub file (with Loader and ModuleType). -from _importlib_modulespec import Loader, ModuleSpec as ModuleSpec # Exported +# TODO: the loaders seem a bit backwards, attribute is protocol but __init__ arg isn't? +class ModuleSpec: + def __init__( + self, + name: str, + loader: Optional[importlib.abc.Loader], + *, + origin: Optional[str] = ..., + loader_state: Any = ..., + is_package: Optional[bool] = ..., + ) -> None: ... + name: str + loader: Optional[importlib.abc._LoaderProtocol] + origin: Optional[str] + submodule_search_locations: Optional[List[str]] + loader_state: Any + cached: Optional[str] + parent: Optional[str] + has_location: bool class BuiltinImporter(importlib.abc.MetaPathFinder, importlib.abc.InspectLoader): # MetaPathFinder @@ -78,7 +94,7 @@ class PathFinder: cls, fullname: str, path: Optional[Sequence[Union[bytes, str]]] = ..., target: Optional[types.ModuleType] = ... ) -> Optional[ModuleSpec]: ... @classmethod - def find_module(cls, fullname: str, path: Optional[Sequence[Union[bytes, str]]] = ...) -> Optional[Loader]: ... + def find_module(cls, fullname: str, path: Optional[Sequence[Union[bytes, str]]] = ...) -> Optional[importlib.abc.Loader]: ... SOURCE_SUFFIXES: List[str] DEBUG_BYTECODE_SUFFIXES: List[str] diff --git a/mypy/typeshed/stdlib/inspect.pyi b/mypy/typeshed/stdlib/inspect.pyi index d0627d54981b..fa49e4493b28 100644 --- a/mypy/typeshed/stdlib/inspect.pyi +++ b/mypy/typeshed/stdlib/inspect.pyi @@ -105,7 +105,18 @@ def indentsize(line: str) -> int: ... # # Introspecting callables with the Signature object # -def signature(obj: Callable[..., Any], *, follow_wrapped: bool = ...) -> Signature: ... +if sys.version_info >= (3, 10): + def signature( + obj: Callable[..., Any], + *, + follow_wrapped: bool = ..., + globals: Optional[Mapping[str, Any]] = ..., + locals: Optional[Mapping[str, Any]] = ..., + eval_str: bool = ..., + ) -> Signature: ... + +else: + def signature(obj: Callable[..., Any], *, follow_wrapped: bool = ...) -> Signature: ... class Signature: def __init__(self, parameters: Optional[Sequence[Parameter]] = ..., *, return_annotation: Any = ...) -> None: ... @@ -119,8 +130,29 @@ class Signature: def bind(self, *args: Any, **kwargs: Any) -> BoundArguments: ... def bind_partial(self, *args: Any, **kwargs: Any) -> BoundArguments: ... def replace(self, *, parameters: Optional[Sequence[Parameter]] = ..., return_annotation: Any = ...) -> Signature: ... - @classmethod - def from_callable(cls, obj: Callable[..., Any], *, follow_wrapped: bool = ...) -> Signature: ... + if sys.version_info >= (3, 10): + @classmethod + def from_callable( + cls, + obj: Callable[..., Any], + *, + follow_wrapped: bool = ..., + globals: Optional[Mapping[str, Any]] = ..., + locals: Optional[Mapping[str, Any]] = ..., + eval_str: bool = ..., + ) -> Signature: ... + else: + @classmethod + def from_callable(cls, obj: Callable[..., Any], *, follow_wrapped: bool = ...) -> Signature: ... + +if sys.version_info >= (3, 10): + def get_annotations( + obj: Union[Callable[..., Any], Type[Any], ModuleType], + *, + globals: Optional[Mapping[str, Any]] = ..., + locals: Optional[Mapping[str, Any]] = ..., + eval_str: bool = ..., + ) -> Dict[str, Any]: ... # The name is the same as the enum's name in CPython class _ParameterKind(enum.IntEnum): @@ -165,7 +197,8 @@ class BoundArguments: # TODO: The actual return type should be List[_ClassTreeItem] but mypy doesn't # seem to be supporting this at the moment: # _ClassTreeItem = Union[List[_ClassTreeItem], Tuple[type, Tuple[type, ...]]] -def getclasstree(classes: List[type], unique: bool = ...) -> Any: ... +def getclasstree(classes: List[type], unique: bool = ...) -> List[Any]: ... +def walktree(classes: List[type], children: Dict[Type[Any], List[type]], parent: Optional[Type[Any]]) -> List[Any]: ... class ArgSpec(NamedTuple): args: List[str] @@ -307,3 +340,6 @@ class Attribute(NamedTuple): object: _Object def classify_class_attrs(cls: type) -> List[Attribute]: ... + +if sys.version_info >= (3, 9): + class ClassFoundException(Exception): ... diff --git a/mypy/typeshed/stdlib/itertools.pyi b/mypy/typeshed/stdlib/itertools.pyi index 4ae9bac0e21e..7fe08ca80531 100644 --- a/mypy/typeshed/stdlib/itertools.pyi +++ b/mypy/typeshed/stdlib/itertools.pyi @@ -13,6 +13,7 @@ from typing import ( Tuple, Type, TypeVar, + Union, overload, ) from typing_extensions import Literal, SupportsIndex @@ -20,7 +21,7 @@ from typing_extensions import Literal, SupportsIndex _T = TypeVar("_T") _S = TypeVar("_S") _N = TypeVar("_N", int, float, SupportsFloat, SupportsInt, SupportsIndex, SupportsComplex) -_NStep = TypeVar("_NStep", int, float, SupportsFloat, SupportsInt, SupportsIndex, SupportsComplex) +_Step = Union[int, float, SupportsFloat, SupportsInt, SupportsIndex, SupportsComplex] Predicate = Callable[[_T], object] @@ -30,7 +31,9 @@ class count(Iterator[_N], Generic[_N]): @overload def __new__(cls) -> count[int]: ... @overload - def __new__(cls, start: _N, step: _NStep = ...) -> count[_N]: ... + def __new__(cls, start: _N, step: _Step = ...) -> count[_N]: ... + @overload + def __new__(cls, *, step: _N) -> count[_N]: ... def __next__(self) -> _N: ... def __iter__(self) -> Iterator[_N]: ... @@ -196,3 +199,9 @@ class combinations_with_replacement(Iterator[Tuple[_T, ...]], Generic[_T]): def __init__(self, iterable: Iterable[_T], r: int) -> None: ... def __iter__(self) -> Iterator[Tuple[_T, ...]]: ... def __next__(self) -> Tuple[_T, ...]: ... + +if sys.version_info >= (3, 10): + class pairwise(Iterator[_T_co], Generic[_T_co]): + def __new__(cls, __iterable: Iterable[_T]) -> pairwise[Tuple[_T, _T]]: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __next__(self) -> _T_co: ... diff --git a/mypy/typeshed/stdlib/locale.pyi b/mypy/typeshed/stdlib/locale.pyi index 920c006cda01..9be4aa2735e1 100644 --- a/mypy/typeshed/stdlib/locale.pyi +++ b/mypy/typeshed/stdlib/locale.pyi @@ -81,7 +81,7 @@ class Error(Exception): ... def setlocale(category: int, locale: Union[_str, Iterable[_str], None] = ...) -> _str: ... def localeconv() -> Mapping[_str, Union[int, _str, List[int]]]: ... -def nl_langinfo(option: int) -> _str: ... +def nl_langinfo(__key: int) -> _str: ... def getdefaultlocale(envvars: Tuple[_str, ...] = ...) -> Tuple[Optional[_str], Optional[_str]]: ... def getlocale(category: int = ...) -> Sequence[_str]: ... def getpreferredencoding(do_setlocale: bool = ...) -> _str: ... diff --git a/mypy/typeshed/stdlib/lzma.pyi b/mypy/typeshed/stdlib/lzma.pyi index b39b5ecc77d3..7290a25b3bcd 100644 --- a/mypy/typeshed/stdlib/lzma.pyi +++ b/mypy/typeshed/stdlib/lzma.pyi @@ -60,7 +60,7 @@ class LZMACompressor(object): def __init__( self, format: Optional[int] = ..., check: int = ..., preset: Optional[int] = ..., filters: Optional[_FilterChain] = ... ) -> None: ... - def compress(self, data: bytes) -> bytes: ... + def compress(self, __data: bytes) -> bytes: ... def flush(self) -> bytes: ... class LZMAError(Exception): ... @@ -161,4 +161,4 @@ def compress( data: bytes, format: int = ..., check: int = ..., preset: Optional[int] = ..., filters: Optional[_FilterChain] = ... ) -> bytes: ... def decompress(data: bytes, format: int = ..., memlimit: Optional[int] = ..., filters: Optional[_FilterChain] = ...) -> bytes: ... -def is_check_supported(check: int) -> bool: ... +def is_check_supported(__check_id: int) -> bool: ... diff --git a/mypy/typeshed/stdlib/mmap.pyi b/mypy/typeshed/stdlib/mmap.pyi index 7039e3da9870..0ba69e5896d7 100644 --- a/mypy/typeshed/stdlib/mmap.pyi +++ b/mypy/typeshed/stdlib/mmap.pyi @@ -9,11 +9,13 @@ ACCESS_COPY: int ALLOCATIONGRANULARITY: int +if sys.platform == "linux": + MAP_DENYWRITE: int + MAP_EXECUTABLE: int + if sys.platform != "win32": MAP_ANON: int MAP_ANONYMOUS: int - MAP_DENYWRITE: int - MAP_EXECUTABLE: int MAP_PRIVATE: int MAP_SHARED: int PROT_EXEC: int @@ -82,26 +84,37 @@ else: def __delitem__(self, index: Union[int, slice]) -> None: ... def __setitem__(self, index: Union[int, slice], object: bytes) -> None: ... -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 8) and sys.platform != "win32": MADV_NORMAL: int MADV_RANDOM: int MADV_SEQUENTIAL: int MADV_WILLNEED: int MADV_DONTNEED: int - MADV_REMOVE: int - MADV_DONTFORK: int - MADV_DOFORK: int - MADV_HWPOISON: int - MADV_MERGEABLE: int - MADV_UNMERGEABLE: int - MADV_SOFT_OFFLINE: int - MADV_HUGEPAGE: int - MADV_NOHUGEPAGE: int - MADV_DONTDUMP: int - MADV_DODUMP: int - MADV_FREE: int - MADV_NOSYNC: int - MADV_AUTOSYNC: int - MADV_NOCORE: int - MADV_CORE: int - MADV_PROTECT: int + + if sys.platform == "linux": + MADV_REMOVE: int + MADV_DONTFORK: int + MADV_DOFORK: int + MADV_HWPOISON: int + MADV_MERGEABLE: int + MADV_UNMERGEABLE: int + # Seems like this constant is not defined in glibc. + # See https://github.com/python/typeshed/pull/5360 for details + # MADV_SOFT_OFFLINE: int + MADV_HUGEPAGE: int + MADV_NOHUGEPAGE: int + MADV_DONTDUMP: int + MADV_DODUMP: int + MADV_FREE: int + + # This Values are defined for FreeBSD but type checkers do not support conditions for these + if sys.platform != "linux" and sys.platform != "darwin": + MADV_NOSYNC: int + MADV_AUTOSYNC: int + MADV_NOCORE: int + MADV_CORE: int + MADV_PROTECT: int + +if sys.version_info >= (3, 10) and sys.platform == "darwin": + MADV_FREE_REUSABLE: int + MADV_FREE_REUSE: int diff --git a/mypy/typeshed/stdlib/multiprocessing/__init__.pyi b/mypy/typeshed/stdlib/multiprocessing/__init__.pyi index c7de0a2d0cad..161cc72c27d8 100644 --- a/mypy/typeshed/stdlib/multiprocessing/__init__.pyi +++ b/mypy/typeshed/stdlib/multiprocessing/__init__.pyi @@ -1,6 +1,7 @@ import sys +from collections.abc import Callable, Iterable from logging import Logger -from multiprocessing import connection, pool, sharedctypes, synchronize +from multiprocessing import connection, context, pool, synchronize from multiprocessing.context import ( AuthenticationError as AuthenticationError, BaseContext, @@ -17,7 +18,7 @@ from multiprocessing.process import active_children as active_children, current_ # These are technically functions that return instances of these Queue classes. See #4313 for discussion from multiprocessing.queues import JoinableQueue as JoinableQueue, Queue as Queue, SimpleQueue as SimpleQueue from multiprocessing.spawn import freeze_support as freeze_support -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union, overload +from typing import Any, Optional, Union, overload from typing_extensions import Literal if sys.version_info >= (3, 8): @@ -32,6 +33,10 @@ if sys.platform != "win32": # Sychronization primitives _LockLike = Union[synchronize.Lock, synchronize.RLock] +RawValue = context._default_context.RawValue +RawArray = context._default_context.RawArray +Value = context._default_context.Value +Array = context._default_context.Array def Barrier(parties: int, action: Optional[Callable[..., Any]] = ..., timeout: Optional[float] = ...) -> synchronize.Barrier: ... def BoundedSemaphore(value: int = ...) -> synchronize.BoundedSemaphore: ... @@ -40,7 +45,7 @@ def Event() -> synchronize.Event: ... def Lock() -> synchronize.Lock: ... def RLock() -> synchronize.RLock: ... def Semaphore(value: int = ...) -> synchronize.Semaphore: ... -def Pipe(duplex: bool = ...) -> Tuple[connection.Connection, connection.Connection]: ... +def Pipe(duplex: bool = ...) -> tuple[connection.Connection, connection.Connection]: ... def Pool( processes: Optional[int] = ..., initializer: Optional[Callable[..., Any]] = ..., @@ -48,12 +53,6 @@ def Pool( maxtasksperchild: Optional[int] = ..., ) -> pool.Pool: ... -# Functions Array and Value are copied from context.pyi. -# See https://github.com/python/typeshed/blob/ac234f25927634e06d9c96df98d72d54dd80dfc4/stdlib/2and3/turtle.pyi#L284-L291 -# for rationale -def Array(typecode_or_type: Any, size_or_initializer: Union[int, Sequence[Any]], *, lock: bool = ...) -> sharedctypes._Array: ... -def Value(typecode_or_type: Any, *args: Any, lock: bool = ...) -> sharedctypes._Value: ... - # ----- multiprocessing function stubs ----- def allow_connection_pickling() -> None: ... def cpu_count() -> int: ... @@ -61,8 +60,8 @@ def get_logger() -> Logger: ... def log_to_stderr(level: Optional[Union[str, int]] = ...) -> Logger: ... def Manager() -> SyncManager: ... def set_executable(executable: str) -> None: ... -def set_forkserver_preload(module_names: List[str]) -> None: ... -def get_all_start_methods() -> List[str]: ... +def set_forkserver_preload(module_names: list[str]) -> None: ... +def get_all_start_methods() -> list[str]: ... def get_start_method(allow_none: bool = ...) -> Optional[str]: ... def set_start_method(method: str, force: Optional[bool] = ...) -> None: ... diff --git a/mypy/typeshed/stdlib/multiprocessing/context.pyi b/mypy/typeshed/stdlib/multiprocessing/context.pyi index 5f39d45dc32a..e02bacc0a243 100644 --- a/mypy/typeshed/stdlib/multiprocessing/context.pyi +++ b/mypy/typeshed/stdlib/multiprocessing/context.pyi @@ -1,12 +1,17 @@ +import ctypes import multiprocessing import sys +from collections.abc import Callable, Iterable, Sequence +from ctypes import _CData from logging import Logger -from multiprocessing import queues, sharedctypes, synchronize +from multiprocessing import queues, synchronize from multiprocessing.process import BaseProcess -from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, Union, overload +from multiprocessing.sharedctypes import SynchronizedArray, SynchronizedBase +from typing import Any, Optional, Type, TypeVar, Union, overload from typing_extensions import Literal _LockLike = Union[synchronize.Lock, synchronize.RLock] +_CT = TypeVar("_CT", bound=_CData) class ProcessError(Exception): ... class BufferTooShort(ProcessError): ... @@ -28,7 +33,7 @@ class BaseContext(object): @staticmethod def parent_process() -> Optional[BaseProcess]: ... @staticmethod - def active_children() -> List[BaseProcess]: ... + def active_children() -> list[BaseProcess]: ... def cpu_count(self) -> int: ... # TODO: change return to SyncManager once a stub exists in multiprocessing.managers def Manager(self) -> Any: ... @@ -53,28 +58,52 @@ class BaseContext(object): initargs: Iterable[Any] = ..., maxtasksperchild: Optional[int] = ..., ) -> multiprocessing.pool.Pool: ... - # TODO: typecode_or_type param is a ctype with a base class of _SimpleCData or array.typecode Need to figure out - # how to handle the ctype - # TODO: change return to RawValue once a stub exists in multiprocessing.sharedctypes - def RawValue(self, typecode_or_type: Any, *args: Any) -> Any: ... - # TODO: typecode_or_type param is a ctype with a base class of _SimpleCData or array.typecode Need to figure out - # how to handle the ctype - # TODO: change return to RawArray once a stub exists in multiprocessing.sharedctypes - def RawArray(self, typecode_or_type: Any, size_or_initializer: Union[int, Sequence[Any]]) -> Any: ... - # TODO: typecode_or_type param is a ctype with a base class of _SimpleCData or array.typecode Need to figure out - # how to handle the ctype - def Value(self, typecode_or_type: Any, *args: Any, lock: bool = ...) -> sharedctypes._Value: ... - # TODO: typecode_or_type param is a ctype with a base class of _SimpleCData or array.typecode Need to figure out - # how to handle the ctype + @overload + def RawValue(self, typecode_or_type: Type[_CT], *args: Any) -> _CT: ... + @overload + def RawValue(self, typecode_or_type: str, *args: Any) -> Any: ... + @overload + def RawArray(self, typecode_or_type: Type[_CT], size_or_initializer: Union[int, Sequence[Any]]) -> ctypes.Array[_CT]: ... + @overload + def RawArray(self, typecode_or_type: str, size_or_initializer: Union[int, Sequence[Any]]) -> Any: ... + @overload + def Value(self, typecode_or_type: Type[_CT], *args: Any, lock: Literal[False]) -> _CT: ... + @overload + def Value(self, typecode_or_type: Type[_CT], *args: Any, lock: Union[Literal[True], _LockLike]) -> SynchronizedBase[_CT]: ... + @overload + def Value(self, typecode_or_type: str, *args: Any, lock: Union[Literal[True], _LockLike]) -> SynchronizedBase[Any]: ... + @overload + def Value(self, typecode_or_type: Union[str, Type[_CData]], *args: Any, lock: Union[bool, _LockLike] = ...) -> Any: ... + @overload def Array( - self, typecode_or_type: Any, size_or_initializer: Union[int, Sequence[Any]], *, lock: bool = ... - ) -> sharedctypes._Array: ... + self, typecode_or_type: Type[_CT], size_or_initializer: Union[int, Sequence[Any]], *, lock: Literal[False] + ) -> _CT: ... + @overload + def Array( + self, + typecode_or_type: Type[_CT], + size_or_initializer: Union[int, Sequence[Any]], + *, + lock: Union[Literal[True], _LockLike], + ) -> SynchronizedArray[_CT]: ... + @overload + def Array( + self, typecode_or_type: str, size_or_initializer: Union[int, Sequence[Any]], *, lock: Union[Literal[True], _LockLike] + ) -> SynchronizedArray[Any]: ... + @overload + def Array( + self, + typecode_or_type: Union[str, Type[_CData]], + size_or_initializer: Union[int, Sequence[Any]], + *, + lock: Union[bool, _LockLike] = ..., + ) -> Any: ... def freeze_support(self) -> None: ... def get_logger(self) -> Logger: ... def log_to_stderr(self, level: Optional[str] = ...) -> Logger: ... def allow_connection_pickling(self) -> None: ... def set_executable(self, executable: str) -> None: ... - def set_forkserver_preload(self, module_names: List[str]) -> None: ... + def set_forkserver_preload(self, module_names: list[str]) -> None: ... if sys.platform != "win32": @overload def get_context(self, method: None = ...) -> DefaultContext: ... @@ -111,7 +140,9 @@ class DefaultContext(BaseContext): def __init__(self, context: BaseContext) -> None: ... def set_start_method(self, method: Optional[str], force: bool = ...) -> None: ... def get_start_method(self, allow_none: bool = ...) -> str: ... - def get_all_start_methods(self) -> List[str]: ... + def get_all_start_methods(self) -> list[str]: ... + +_default_context: DefaultContext if sys.platform != "win32": class ForkProcess(BaseProcess): diff --git a/mypy/typeshed/stdlib/multiprocessing/pool.pyi b/mypy/typeshed/stdlib/multiprocessing/pool.pyi index 40e5ec46f49b..d8004655cacf 100644 --- a/mypy/typeshed/stdlib/multiprocessing/pool.pyi +++ b/mypy/typeshed/stdlib/multiprocessing/pool.pyi @@ -9,6 +9,12 @@ _S = TypeVar("_S") _T = TypeVar("_T") class ApplyResult(Generic[_T]): + def __init__( + self, + pool: Pool, + callback: Optional[Callable[[_T], None]] = ..., + error_callback: Optional[Callable[[BaseException], None]] = ..., + ) -> None: ... def get(self, timeout: Optional[float] = ...) -> _T: ... def wait(self, timeout: Optional[float] = ...) -> None: ... def ready(self) -> bool: ... diff --git a/mypy/typeshed/stdlib/multiprocessing/sharedctypes.pyi b/mypy/typeshed/stdlib/multiprocessing/sharedctypes.pyi index 3979b0947287..0dc5977decd7 100644 --- a/mypy/typeshed/stdlib/multiprocessing/sharedctypes.pyi +++ b/mypy/typeshed/stdlib/multiprocessing/sharedctypes.pyi @@ -1,43 +1,101 @@ -from ctypes import _CData +import ctypes +from collections.abc import Callable, Iterable, Sequence +from ctypes import _CData, _SimpleCData, c_char from multiprocessing.context import BaseContext from multiprocessing.synchronize import _LockLike -from typing import Any, List, Optional, Sequence, Type, Union, overload +from typing import Any, Generic, Optional, Protocol, Type, TypeVar, Union, overload +from typing_extensions import Literal -class _Array: - value: Any = ... - def __init__( - self, - typecode_or_type: Union[str, Type[_CData]], - size_or_initializer: Union[int, Sequence[Any]], - *, - lock: Union[bool, _LockLike] = ..., - ) -> None: ... - def acquire(self) -> bool: ... - def release(self) -> bool: ... - def get_lock(self) -> _LockLike: ... - def get_obj(self) -> Any: ... - @overload - def __getitem__(self, key: int) -> Any: ... - @overload - def __getitem__(self, key: slice) -> List[Any]: ... - def __getslice__(self, start: int, stop: int) -> Any: ... - def __setitem__(self, key: int, value: Any) -> None: ... - -class _Value: - value: Any = ... - def __init__(self, typecode_or_type: Union[str, Type[_CData]], *args: Any, lock: Union[bool, _LockLike] = ...) -> None: ... - def get_lock(self) -> _LockLike: ... - def get_obj(self) -> Any: ... - def acquire(self) -> bool: ... - def release(self) -> bool: ... +_T = TypeVar("_T") +_CT = TypeVar("_CT", bound=_CData) +@overload +def RawValue(typecode_or_type: Type[_CT], *args: Any) -> _CT: ... +@overload +def RawValue(typecode_or_type: str, *args: Any) -> Any: ... +@overload +def RawArray(typecode_or_type: Type[_CT], size_or_initializer: Union[int, Sequence[Any]]) -> ctypes.Array[_CT]: ... +@overload +def RawArray(typecode_or_type: str, size_or_initializer: Union[int, Sequence[Any]]) -> Any: ... +@overload +def Value(typecode_or_type: Type[_CT], *args: Any, lock: Literal[False], ctx: Optional[BaseContext] = ...) -> _CT: ... +@overload +def Value( + typecode_or_type: Type[_CT], *args: Any, lock: Union[Literal[True], _LockLike], ctx: Optional[BaseContext] = ... +) -> SynchronizedBase[_CT]: ... +@overload +def Value( + typecode_or_type: str, *args: Any, lock: Union[Literal[True], _LockLike], ctx: Optional[BaseContext] = ... +) -> SynchronizedBase[Any]: ... +@overload +def Value( + typecode_or_type: Union[str, Type[_CData]], *args: Any, lock: Union[bool, _LockLike] = ..., ctx: Optional[BaseContext] = ... +) -> Any: ... +@overload +def Array( + typecode_or_type: Type[_CT], + size_or_initializer: Union[int, Sequence[Any]], + *, + lock: Literal[False], + ctx: Optional[BaseContext] = ..., +) -> _CT: ... +@overload +def Array( + typecode_or_type: Type[_CT], + size_or_initializer: Union[int, Sequence[Any]], + *, + lock: Union[Literal[True], _LockLike], + ctx: Optional[BaseContext] = ..., +) -> SynchronizedArray[_CT]: ... +@overload +def Array( + typecode_or_type: str, + size_or_initializer: Union[int, Sequence[Any]], + *, + lock: Union[Literal[True], _LockLike], + ctx: Optional[BaseContext] = ..., +) -> SynchronizedArray[Any]: ... +@overload def Array( typecode_or_type: Union[str, Type[_CData]], size_or_initializer: Union[int, Sequence[Any]], *, lock: Union[bool, _LockLike] = ..., ctx: Optional[BaseContext] = ..., -) -> _Array: ... -def Value( - typecode_or_type: Union[str, Type[_CData]], *args: Any, lock: Union[bool, _LockLike] = ..., ctx: Optional[BaseContext] = ... -) -> _Value: ... +) -> Any: ... +def copy(obj: _CT) -> _CT: ... +@overload +def synchronized(obj: _SimpleCData[_T], lock: Optional[_LockLike] = ..., ctx: Optional[Any] = ...) -> Synchronized[_T]: ... +@overload +def synchronized(obj: ctypes.Array[c_char], lock: Optional[_LockLike] = ..., ctx: Optional[Any] = ...) -> SynchronizedString: ... +@overload +def synchronized(obj: ctypes.Array[_CT], lock: Optional[_LockLike] = ..., ctx: Optional[Any] = ...) -> SynchronizedArray[_CT]: ... +@overload +def synchronized(obj: _CT, lock: Optional[_LockLike] = ..., ctx: Optional[Any] = ...) -> SynchronizedBase[_CT]: ... + +class _AcquireFunc(Protocol): + def __call__(self, block: bool = ..., timeout: Optional[float] = ...) -> bool: ... + +class SynchronizedBase(Generic[_CT]): + acquire: _AcquireFunc = ... + release: Callable[[], None] = ... + def __init__(self, obj: Any, lock: Optional[_LockLike] = ..., ctx: Optional[Any] = ...) -> None: ... + def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any, _LockLike]]: ... + def get_obj(self) -> _CT: ... + def get_lock(self) -> _LockLike: ... + def __enter__(self) -> bool: ... + def __exit__(self, *args: Any) -> None: ... + +class Synchronized(SynchronizedBase[_SimpleCData[_T]], Generic[_T]): + value: _T + +class SynchronizedArray(SynchronizedBase[ctypes.Array[_CT]], Generic[_CT]): + def __len__(self) -> int: ... + def __getitem__(self, i: int) -> _CT: ... + def __setitem__(self, i: int, o: _CT) -> None: ... + def __getslice__(self, start: int, stop: int) -> list[_CT]: ... + def __setslice__(self, start: int, stop: int, values: Iterable[_CT]) -> None: ... + +class SynchronizedString(SynchronizedArray[c_char]): + value: bytes + raw: bytes diff --git a/mypy/typeshed/stdlib/nntplib.pyi b/mypy/typeshed/stdlib/nntplib.pyi index 7e7b7b84c4f0..36fe063c8486 100644 --- a/mypy/typeshed/stdlib/nntplib.pyi +++ b/mypy/typeshed/stdlib/nntplib.pyi @@ -82,7 +82,7 @@ class _NNTPBase: def ihave(self, message_id: Any, data: Union[bytes, Iterable[bytes]]) -> str: ... def quit(self) -> str: ... def login(self, user: Optional[str] = ..., password: Optional[str] = ..., usenetrc: bool = ...) -> None: ... - def starttls(self, ssl_context: Optional[ssl.SSLContext] = ...) -> None: ... + def starttls(self, context: Optional[ssl.SSLContext] = ...) -> None: ... class NNTP(_NNTPBase): port: int diff --git a/mypy/typeshed/stdlib/ntpath.pyi b/mypy/typeshed/stdlib/ntpath.pyi index 4ee1b31c64e3..6af6cbbc2ab2 100644 --- a/mypy/typeshed/stdlib/ntpath.pyi +++ b/mypy/typeshed/stdlib/ntpath.pyi @@ -3,7 +3,8 @@ import sys from _typeshed import AnyPath, BytesPath, StrPath from genericpath import exists as exists from os import PathLike -from typing import Any, AnyStr, Optional, Sequence, Tuple, TypeVar, overload +from os.path import commonpath as commonpath, commonprefix as commonprefix, lexists as lexists +from typing import AnyStr, Optional, Tuple, TypeVar, overload _T = TypeVar("_T") @@ -52,27 +53,10 @@ def normcase(s: AnyStr) -> AnyStr: ... def normpath(path: PathLike[AnyStr]) -> AnyStr: ... @overload def normpath(path: AnyStr) -> AnyStr: ... - -if sys.platform == "win32": - @overload - def realpath(path: PathLike[AnyStr]) -> AnyStr: ... - @overload - def realpath(path: AnyStr) -> AnyStr: ... - -else: - @overload - def realpath(filename: PathLike[AnyStr]) -> AnyStr: ... - @overload - def realpath(filename: AnyStr) -> AnyStr: ... - -# In reality it returns str for sequences of StrPath and bytes for sequences -# of BytesPath, but mypy does not accept such a signature. -def commonpath(paths: Sequence[AnyPath]) -> Any: ... - -# NOTE: Empty lists results in '' (str) regardless of contained type. -# So, fall back to Any -def commonprefix(m: Sequence[AnyPath]) -> Any: ... -def lexists(path: AnyPath) -> bool: ... +@overload +def realpath(path: PathLike[AnyStr]) -> AnyStr: ... +@overload +def realpath(path: AnyStr) -> AnyStr: ... # These return float if os.stat_float_times() == True, # but int is a subclass of float. diff --git a/mypy/typeshed/stdlib/os/__init__.pyi b/mypy/typeshed/stdlib/os/__init__.pyi index 98f078be869f..24744b62b5d9 100644 --- a/mypy/typeshed/stdlib/os/__init__.pyi +++ b/mypy/typeshed/stdlib/os/__init__.pyi @@ -117,6 +117,8 @@ if sys.platform != "win32": RTLD_LOCAL: int RTLD_NODELETE: int RTLD_NOLOAD: int + +if sys.platform == "linux": RTLD_DEEPBIND: int SEEK_SET: int @@ -712,7 +714,7 @@ if sys.platform != "win32": ) -> Iterator[Tuple[str, List[str], List[str], int]]: ... if sys.platform == "linux": def getxattr(path: _FdOrAnyPath, attribute: AnyPath, *, follow_symlinks: bool = ...) -> bytes: ... - def listxattr(path: _FdOrAnyPath, *, follow_symlinks: bool = ...) -> List[str]: ... + def listxattr(path: Optional[_FdOrAnyPath] = ..., *, follow_symlinks: bool = ...) -> List[str]: ... def removexattr(path: _FdOrAnyPath, attribute: AnyPath, *, follow_symlinks: bool = ...) -> None: ... def setxattr( path: _FdOrAnyPath, attribute: AnyPath, value: bytes, flags: int = ..., *, follow_symlinks: bool = ... @@ -791,8 +793,9 @@ else: def spawnvp(mode: int, file: AnyPath, args: _ExecVArgs) -> int: ... def spawnvpe(mode: int, file: AnyPath, args: _ExecVArgs, env: _ExecEnv) -> int: ... def wait() -> Tuple[int, int]: ... # Unix only - from posix import waitid_result - def waitid(idtype: int, ident: int, options: int) -> waitid_result: ... + if sys.platform != "darwin": + from posix import waitid_result + def waitid(idtype: int, ident: int, options: int) -> waitid_result: ... def wait3(options: int) -> Tuple[int, int, Any]: ... def wait4(pid: int, options: int) -> Tuple[int, int, Any]: ... def WCOREDUMP(__status: int) -> bool: ... @@ -808,14 +811,15 @@ if sys.platform != "win32": from posix import sched_param def sched_get_priority_min(policy: int) -> int: ... # some flavors of Unix def sched_get_priority_max(policy: int) -> int: ... # some flavors of Unix - def sched_setscheduler(pid: int, policy: int, param: sched_param) -> None: ... # some flavors of Unix - def sched_getscheduler(pid: int) -> int: ... # some flavors of Unix - def sched_setparam(pid: int, param: sched_param) -> None: ... # some flavors of Unix - def sched_getparam(pid: int) -> sched_param: ... # some flavors of Unix - def sched_rr_get_interval(pid: int) -> float: ... # some flavors of Unix def sched_yield() -> None: ... # some flavors of Unix - def sched_setaffinity(pid: int, mask: Iterable[int]) -> None: ... # some flavors of Unix - def sched_getaffinity(pid: int) -> Set[int]: ... # some flavors of Unix + if sys.platform != "darwin": + def sched_setscheduler(pid: int, policy: int, param: sched_param) -> None: ... # some flavors of Unix + def sched_getscheduler(pid: int) -> int: ... # some flavors of Unix + def sched_rr_get_interval(pid: int) -> float: ... # some flavors of Unix + def sched_setparam(pid: int, param: sched_param) -> None: ... # some flavors of Unix + def sched_getparam(pid: int) -> sched_param: ... # some flavors of Unix + def sched_setaffinity(pid: int, mask: Iterable[int]) -> None: ... # some flavors of Unix + def sched_getaffinity(pid: int) -> Set[int]: ... # some flavors of Unix def cpu_count() -> Optional[int]: ... diff --git a/mypy/typeshed/stdlib/os/path.pyi b/mypy/typeshed/stdlib/os/path.pyi index 2fcbe12f8a08..66d5d6fea4f2 100644 --- a/mypy/typeshed/stdlib/os/path.pyi +++ b/mypy/typeshed/stdlib/os/path.pyi @@ -1,11 +1,10 @@ import os import sys -from _typeshed import AnyPath, BytesPath, StrPath +from _typeshed import AnyPath, BytesPath, StrPath, SupportsLessThanT from genericpath import exists as exists from os import PathLike -from typing import Any, AnyStr, Optional, Sequence, Tuple, TypeVar, overload - -_T = TypeVar("_T") +from typing import AnyStr, List, Optional, Sequence, Tuple, Union, overload +from typing_extensions import Literal # ----- os.path variables ----- supports_unicode_filenames: bool @@ -54,25 +53,45 @@ def normpath(path: PathLike[AnyStr]) -> AnyStr: ... def normpath(path: AnyStr) -> AnyStr: ... if sys.platform == "win32": - @overload - def realpath(path: PathLike[AnyStr]) -> AnyStr: ... - @overload - def realpath(path: AnyStr) -> AnyStr: ... + if sys.version_info >= (3, 10): + @overload + def realpath(path: PathLike[AnyStr], *, strict: bool = ...) -> AnyStr: ... + @overload + def realpath(path: AnyStr, *, strict: bool = ...) -> AnyStr: ... + else: + @overload + def realpath(path: PathLike[AnyStr]) -> AnyStr: ... + @overload + def realpath(path: AnyStr) -> AnyStr: ... else: - @overload - def realpath(filename: PathLike[AnyStr]) -> AnyStr: ... - @overload - def realpath(filename: AnyStr) -> AnyStr: ... + if sys.version_info >= (3, 10): + @overload + def realpath(filename: PathLike[AnyStr], *, strict: bool = ...) -> AnyStr: ... + @overload + def realpath(filename: AnyStr, *, strict: bool = ...) -> AnyStr: ... + else: + @overload + def realpath(filename: PathLike[AnyStr]) -> AnyStr: ... + @overload + def realpath(filename: AnyStr) -> AnyStr: ... -# In reality it returns str for sequences of StrPath and bytes for sequences -# of BytesPath, but mypy does not accept such a signature. -def commonpath(paths: Sequence[AnyPath]) -> Any: ... +@overload +def commonpath(paths: Sequence[StrPath]) -> str: ... +@overload +def commonpath(paths: Sequence[BytesPath]) -> bytes: ... -# NOTE: Empty lists results in '' (str) regardless of contained type. -# So, fall back to Any -def commonprefix(m: Sequence[AnyPath]) -> Any: ... -def lexists(path: AnyPath) -> bool: ... +# All overloads can return empty string. Ideally, Literal[""] would be a valid +# Iterable[T], so that Union[List[T], Literal[""]] could be used as a return +# type. But because this only works when T is str, we need Sequence[T] instead. +@overload +def commonprefix(m: Sequence[StrPath]) -> str: ... +@overload +def commonprefix(m: Sequence[BytesPath]) -> Union[bytes, Literal[""]]: ... +@overload +def commonprefix(m: Sequence[List[SupportsLessThanT]]) -> Sequence[SupportsLessThanT]: ... +@overload +def commonprefix(m: Sequence[Tuple[SupportsLessThanT, ...]]) -> Sequence[SupportsLessThanT]: ... # These return float if os.stat_float_times() == True, # but int is a subclass of float. @@ -80,6 +99,7 @@ def getatime(filename: AnyPath) -> float: ... def getmtime(filename: AnyPath) -> float: ... def getctime(filename: AnyPath) -> float: ... def getsize(filename: AnyPath) -> int: ... +def lexists(path: AnyPath) -> bool: ... def isabs(s: AnyPath) -> bool: ... def isfile(path: AnyPath) -> bool: ... def isdir(s: AnyPath) -> bool: ... diff --git a/mypy/typeshed/stdlib/parser.pyi b/mypy/typeshed/stdlib/parser.pyi index 799f25cf6a48..36fe6cafc0aa 100644 --- a/mypy/typeshed/stdlib/parser.pyi +++ b/mypy/typeshed/stdlib/parser.pyi @@ -1,22 +1,22 @@ +import sys from _typeshed import AnyPath from types import CodeType from typing import Any, List, Sequence, Text, Tuple -def expr(source: Text) -> STType: ... -def suite(source: Text) -> STType: ... -def sequence2st(sequence: Sequence[Any]) -> STType: ... -def tuple2st(sequence: Sequence[Any]) -> STType: ... -def st2list(st: STType, line_info: bool = ..., col_info: bool = ...) -> List[Any]: ... -def st2tuple(st: STType, line_info: bool = ..., col_info: bool = ...) -> Tuple[Any]: ... -def compilest(st: STType, filename: AnyPath = ...) -> CodeType: ... -def isexpr(st: STType) -> bool: ... -def issuite(st: STType) -> bool: ... - -class ParserError(Exception): ... - -class STType: - def compile(self, filename: AnyPath = ...) -> CodeType: ... - def isexpr(self) -> bool: ... - def issuite(self) -> bool: ... - def tolist(self, line_info: bool = ..., col_info: bool = ...) -> List[Any]: ... - def totuple(self, line_info: bool = ..., col_info: bool = ...) -> Tuple[Any]: ... +if sys.version_info < (3, 10): + def expr(source: Text) -> STType: ... + def suite(source: Text) -> STType: ... + def sequence2st(sequence: Sequence[Any]) -> STType: ... + def tuple2st(sequence: Sequence[Any]) -> STType: ... + def st2list(st: STType, line_info: bool = ..., col_info: bool = ...) -> List[Any]: ... + def st2tuple(st: STType, line_info: bool = ..., col_info: bool = ...) -> Tuple[Any]: ... + def compilest(st: STType, filename: AnyPath = ...) -> CodeType: ... + def isexpr(st: STType) -> bool: ... + def issuite(st: STType) -> bool: ... + class ParserError(Exception): ... + class STType: + def compile(self, filename: AnyPath = ...) -> CodeType: ... + def isexpr(self) -> bool: ... + def issuite(self) -> bool: ... + def tolist(self, line_info: bool = ..., col_info: bool = ...) -> List[Any]: ... + def totuple(self, line_info: bool = ..., col_info: bool = ...) -> Tuple[Any]: ... diff --git a/mypy/typeshed/stdlib/platform.pyi b/mypy/typeshed/stdlib/platform.pyi index 73579dff3887..217882224d74 100644 --- a/mypy/typeshed/stdlib/platform.pyi +++ b/mypy/typeshed/stdlib/platform.pyi @@ -1,6 +1,6 @@ import sys -if sys.version_info < (3, 9): +if sys.version_info < (3, 8): import os DEV_NULL = os.devnull diff --git a/mypy/typeshed/stdlib/posix.pyi b/mypy/typeshed/stdlib/posix.pyi index 5d0f69fa4394..2499463647cd 100644 --- a/mypy/typeshed/stdlib/posix.pyi +++ b/mypy/typeshed/stdlib/posix.pyi @@ -16,12 +16,13 @@ class times_result(NamedTuple): children_system: float elapsed: float -class waitid_result(NamedTuple): - si_pid: int - si_uid: int - si_signo: int - si_status: int - si_code: int +if sys.platform != "darwin": + class waitid_result(NamedTuple): + si_pid: int + si_uid: int + si_signo: int + si_status: int + si_code: int class sched_param(NamedTuple): sched_priority: int @@ -59,8 +60,9 @@ F_TEST: int F_TLOCK: int F_ULOCK: int -GRND_NONBLOCK: int -GRND_RANDOM: int +if sys.platform == "linux": + GRND_NONBLOCK: int + GRND_RANDOM: int NGROUPS_MAX: int O_APPEND: int @@ -84,12 +86,13 @@ O_SYNC: int O_TRUNC: int O_WRONLY: int -POSIX_FADV_DONTNEED: int -POSIX_FADV_NOREUSE: int -POSIX_FADV_NORMAL: int -POSIX_FADV_RANDOM: int -POSIX_FADV_SEQUENTIAL: int -POSIX_FADV_WILLNEED: int +if sys.platform != "darwin": + POSIX_FADV_DONTNEED: int + POSIX_FADV_NOREUSE: int + POSIX_FADV_NORMAL: int + POSIX_FADV_RANDOM: int + POSIX_FADV_SEQUENTIAL: int + POSIX_FADV_WILLNEED: int PRIO_PGRP: int PRIO_PROCESS: int @@ -99,7 +102,8 @@ P_ALL: int P_PGID: int P_PID: int -RTLD_DEEPBIND: int +if sys.platform == "linux": + RTLD_DEEPBIND: int RTLD_GLOBAL: int RTLD_LAZY: int RTLD_LOCAL: int @@ -107,13 +111,16 @@ RTLD_NODELETE: int RTLD_NOLOAD: int RTLD_NOW: int -SCHED_BATCH: int SCHED_FIFO: int -SCHED_IDLE: int SCHED_OTHER: int -SCHED_RESET_ON_FORK: int SCHED_RR: int +if sys.platform == "linux": + SCHED_BATCH: int + SCHED_IDLE: int +if sys.platform != "darwin": + SCHED_RESET_ON_FORK: int + SEEK_DATA: int SEEK_HOLE: int diff --git a/mypy/typeshed/stdlib/posixpath.pyi b/mypy/typeshed/stdlib/posixpath.pyi index 2fcbe12f8a08..684aec59b16f 100644 --- a/mypy/typeshed/stdlib/posixpath.pyi +++ b/mypy/typeshed/stdlib/posixpath.pyi @@ -3,7 +3,8 @@ import sys from _typeshed import AnyPath, BytesPath, StrPath from genericpath import exists as exists from os import PathLike -from typing import Any, AnyStr, Optional, Sequence, Tuple, TypeVar, overload +from os.path import commonpath as commonpath, commonprefix as commonprefix, lexists as lexists +from typing import AnyStr, Optional, Tuple, TypeVar, overload _T = TypeVar("_T") @@ -53,11 +54,11 @@ def normpath(path: PathLike[AnyStr]) -> AnyStr: ... @overload def normpath(path: AnyStr) -> AnyStr: ... -if sys.platform == "win32": +if sys.version_info >= (3, 10): @overload - def realpath(path: PathLike[AnyStr]) -> AnyStr: ... + def realpath(filename: PathLike[AnyStr], *, strict: bool = ...) -> AnyStr: ... @overload - def realpath(path: AnyStr) -> AnyStr: ... + def realpath(filename: AnyStr, *, strict: bool = ...) -> AnyStr: ... else: @overload @@ -65,15 +66,6 @@ else: @overload def realpath(filename: AnyStr) -> AnyStr: ... -# In reality it returns str for sequences of StrPath and bytes for sequences -# of BytesPath, but mypy does not accept such a signature. -def commonpath(paths: Sequence[AnyPath]) -> Any: ... - -# NOTE: Empty lists results in '' (str) regardless of contained type. -# So, fall back to Any -def commonprefix(m: Sequence[AnyPath]) -> Any: ... -def lexists(path: AnyPath) -> bool: ... - # These return float if os.stat_float_times() == True, # but int is a subclass of float. def getatime(filename: AnyPath) -> float: ... diff --git a/mypy/typeshed/stdlib/pprint.pyi b/mypy/typeshed/stdlib/pprint.pyi index 6c1133aa5c10..9484f92eca46 100644 --- a/mypy/typeshed/stdlib/pprint.pyi +++ b/mypy/typeshed/stdlib/pprint.pyi @@ -1,7 +1,19 @@ import sys from typing import IO, Any, Dict, Optional, Tuple -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 10): + def pformat( + object: object, + indent: int = ..., + width: int = ..., + depth: Optional[int] = ..., + *, + compact: bool = ..., + sort_dicts: bool = ..., + underscore_numbers: bool = ..., + ) -> str: ... + +elif sys.version_info >= (3, 8): def pformat( object: object, indent: int = ..., @@ -20,7 +32,20 @@ elif sys.version_info >= (3, 4): else: def pformat(object: object, indent: int = ..., width: int = ..., depth: Optional[int] = ...) -> str: ... -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 10): + def pp( + object: object, + stream: Optional[IO[str]] = ..., + indent: int = ..., + width: int = ..., + depth: Optional[int] = ..., + *, + compact: bool = ..., + sort_dicts: bool = ..., + underscore_numbers: bool = ..., + ) -> None: ... + +elif sys.version_info >= (3, 8): def pp( object: object, stream: Optional[IO[str]] = ..., @@ -32,7 +57,20 @@ if sys.version_info >= (3, 8): sort_dicts: bool = ..., ) -> None: ... -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 10): + def pprint( + object: object, + stream: Optional[IO[str]] = ..., + indent: int = ..., + width: int = ..., + depth: Optional[int] = ..., + *, + compact: bool = ..., + sort_dicts: bool = ..., + underscore_numbers: bool = ..., + ) -> None: ... + +elif sys.version_info >= (3, 8): def pprint( object: object, stream: Optional[IO[str]] = ..., @@ -65,7 +103,19 @@ def isrecursive(object: object) -> bool: ... def saferepr(object: object) -> str: ... class PrettyPrinter: - if sys.version_info >= (3, 8): + if sys.version_info >= (3, 10): + def __init__( + self, + indent: int = ..., + width: int = ..., + depth: Optional[int] = ..., + stream: Optional[IO[str]] = ..., + *, + compact: bool = ..., + sort_dicts: bool = ..., + underscore_numbers: bool = ..., + ) -> None: ... + elif sys.version_info >= (3, 8): def __init__( self, indent: int = ..., diff --git a/mypy/typeshed/stdlib/re.pyi b/mypy/typeshed/stdlib/re.pyi index e690e76586a0..3f2490ae4e11 100644 --- a/mypy/typeshed/stdlib/re.pyi +++ b/mypy/typeshed/stdlib/re.pyi @@ -70,9 +70,9 @@ def fullmatch(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> Optio @overload def fullmatch(pattern: Pattern[AnyStr], string: AnyStr, flags: _FlagsType = ...) -> Optional[Match[AnyStr]]: ... @overload -def split(pattern: AnyStr, string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> List[AnyStr]: ... +def split(pattern: AnyStr, string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> List[Union[AnyStr, Any]]: ... @overload -def split(pattern: Pattern[AnyStr], string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> List[AnyStr]: ... +def split(pattern: Pattern[AnyStr], string: AnyStr, maxsplit: int = ..., flags: _FlagsType = ...) -> List[Union[AnyStr, Any]]: ... @overload def findall(pattern: AnyStr, string: AnyStr, flags: _FlagsType = ...) -> List[Any]: ... @overload diff --git a/mypy/typeshed/stdlib/readline.pyi b/mypy/typeshed/stdlib/readline.pyi index 9b8e05f09942..8f28a2b2b760 100644 --- a/mypy/typeshed/stdlib/readline.pyi +++ b/mypy/typeshed/stdlib/readline.pyi @@ -1,39 +1,40 @@ import sys +from _typeshed import AnyPath from typing import Callable, Optional, Sequence _CompleterT = Optional[Callable[[str, int], Optional[str]]] _CompDispT = Optional[Callable[[str, Sequence[str], int], None]] -def parse_and_bind(string: str) -> None: ... -def read_init_file(filename: str = ...) -> None: ... +def parse_and_bind(__string: str) -> None: ... +def read_init_file(__filename: Optional[AnyPath] = ...) -> None: ... def get_line_buffer() -> str: ... -def insert_text(string: str) -> None: ... +def insert_text(__string: str) -> None: ... def redisplay() -> None: ... -def read_history_file(filename: str = ...) -> None: ... -def write_history_file(filename: str = ...) -> None: ... +def read_history_file(__filename: Optional[AnyPath] = ...) -> None: ... +def write_history_file(__filename: Optional[AnyPath] = ...) -> None: ... if sys.version_info >= (3, 5): - def append_history_file(nelements: int, filename: str = ...) -> None: ... + def append_history_file(__nelements: int, __filename: Optional[AnyPath] = ...) -> None: ... def get_history_length() -> int: ... -def set_history_length(length: int) -> None: ... +def set_history_length(__length: int) -> None: ... def clear_history() -> None: ... def get_current_history_length() -> int: ... -def get_history_item(index: int) -> str: ... -def remove_history_item(pos: int) -> None: ... -def replace_history_item(pos: int, line: str) -> None: ... -def add_history(string: str) -> None: ... +def get_history_item(__index: int) -> str: ... +def remove_history_item(__pos: int) -> None: ... +def replace_history_item(__pos: int, __line: str) -> None: ... +def add_history(__string: str) -> None: ... if sys.version_info >= (3, 6): - def set_auto_history(enabled: bool) -> None: ... + def set_auto_history(__enabled: bool) -> None: ... -def set_startup_hook(function: Optional[Callable[[], None]] = ...) -> None: ... -def set_pre_input_hook(function: Optional[Callable[[], None]] = ...) -> None: ... -def set_completer(function: _CompleterT = ...) -> None: ... +def set_startup_hook(__function: Optional[Callable[[], None]] = ...) -> None: ... +def set_pre_input_hook(__function: Optional[Callable[[], None]] = ...) -> None: ... +def set_completer(__function: _CompleterT = ...) -> None: ... def get_completer() -> _CompleterT: ... def get_completion_type() -> int: ... def get_begidx() -> int: ... def get_endidx() -> int: ... -def set_completer_delims(string: str) -> None: ... +def set_completer_delims(__string: str) -> None: ... def get_completer_delims() -> str: ... -def set_completion_display_matches_hook(function: _CompDispT = ...) -> None: ... +def set_completion_display_matches_hook(__function: _CompDispT = ...) -> None: ... diff --git a/mypy/typeshed/stdlib/selectors.pyi b/mypy/typeshed/stdlib/selectors.pyi index b019c4f9c442..94690efadbf8 100644 --- a/mypy/typeshed/stdlib/selectors.pyi +++ b/mypy/typeshed/stdlib/selectors.pyi @@ -41,6 +41,8 @@ if sys.platform != "win32": def unregister(self, fileobj: FileDescriptorLike) -> SelectorKey: ... def select(self, timeout: Optional[float] = ...) -> List[Tuple[SelectorKey, _EventMask]]: ... def get_map(self) -> Mapping[FileDescriptorLike, SelectorKey]: ... + +if sys.platform == "linux": class EpollSelector(BaseSelector): def fileno(self) -> int: ... def register(self, fileobj: FileDescriptorLike, events: _EventMask, data: Any = ...) -> SelectorKey: ... diff --git a/mypy/typeshed/stdlib/signal.pyi b/mypy/typeshed/stdlib/signal.pyi index 53d8caaca471..aa0bbf2bffd1 100644 --- a/mypy/typeshed/stdlib/signal.pyi +++ b/mypy/typeshed/stdlib/signal.pyi @@ -137,7 +137,7 @@ if sys.platform == "win32": CTRL_C_EVENT: int CTRL_BREAK_EVENT: int -if sys.platform != "win32": +if sys.platform != "win32" and sys.platform != "darwin": class struct_siginfo(Tuple[int, int, int, int, int, int, int]): def __init__(self, sequence: Iterable[int]) -> None: ... @property @@ -189,6 +189,7 @@ def signal(__signalnum: _SIGNUM, __handler: _HANDLER) -> _HANDLER: ... if sys.platform != "win32": def sigpending() -> Any: ... - def sigtimedwait(sigset: Iterable[int], timeout: float) -> Optional[struct_siginfo]: ... def sigwait(__sigset: Iterable[int]) -> _SIGNUM: ... - def sigwaitinfo(sigset: Iterable[int]) -> struct_siginfo: ... + if sys.platform != "darwin": + def sigtimedwait(sigset: Iterable[int], timeout: float) -> Optional[struct_siginfo]: ... + def sigwaitinfo(sigset: Iterable[int]) -> struct_siginfo: ... diff --git a/mypy/typeshed/stdlib/site.pyi b/mypy/typeshed/stdlib/site.pyi index e91176ac4db4..db7bbefcc794 100644 --- a/mypy/typeshed/stdlib/site.pyi +++ b/mypy/typeshed/stdlib/site.pyi @@ -1,4 +1,3 @@ -import sys from typing import Iterable, List, Optional PREFIXES: List[str] @@ -6,9 +5,7 @@ ENABLE_USER_SITE: Optional[bool] USER_SITE: Optional[str] USER_BASE: Optional[str] -if sys.version_info < (3,): - def main() -> None: ... - +def main() -> None: ... def addsitedir(sitedir: str, known_paths: Optional[Iterable[str]] = ...) -> None: ... def getsitepackages(prefixes: Optional[Iterable[str]] = ...) -> List[str]: ... def getuserbase() -> str: ... diff --git a/mypy/typeshed/stdlib/smtplib.pyi b/mypy/typeshed/stdlib/smtplib.pyi index 4f376d2b645c..48a35f8e3b67 100644 --- a/mypy/typeshed/stdlib/smtplib.pyi +++ b/mypy/typeshed/stdlib/smtplib.pyi @@ -2,7 +2,7 @@ from email.message import Message as _Message from socket import socket from ssl import SSLContext from types import TracebackType -from typing import Any, Dict, List, Optional, Pattern, Protocol, Sequence, Tuple, Type, Union, overload +from typing import Any, Dict, Optional, Pattern, Protocol, Sequence, Tuple, Type, Union, overload _Reply = Tuple[int, bytes] _SendErrs = Dict[str, _Reply] @@ -117,7 +117,7 @@ class SMTP: to_addrs: Union[str, Sequence[str]], msg: Union[bytes, str], mail_options: Sequence[str] = ..., - rcpt_options: List[str] = ..., + rcpt_options: Sequence[str] = ..., ) -> _SendErrs: ... def send_message( self, diff --git a/mypy/typeshed/stdlib/sqlite3/dbapi2.pyi b/mypy/typeshed/stdlib/sqlite3/dbapi2.pyi index 0eaab5eff2a9..74caa0e64cf1 100644 --- a/mypy/typeshed/stdlib/sqlite3/dbapi2.pyi +++ b/mypy/typeshed/stdlib/sqlite3/dbapi2.pyi @@ -102,10 +102,10 @@ else: cached_statements: int = ..., ) -> Connection: ... -def enable_callback_tracebacks(flag: bool) -> None: ... -def enable_shared_cache(do_enable: int) -> None: ... -def register_adapter(type: Type[_T], callable: Callable[[_T], Union[int, float, str, bytes]]) -> None: ... -def register_converter(typename: str, callable: Callable[[bytes], Any]) -> None: ... +def enable_callback_tracebacks(__enable: bool) -> None: ... +def enable_shared_cache(enable: int) -> None: ... +def register_adapter(__type: Type[_T], __caster: Callable[[_T], Union[int, float, str, bytes]]) -> None: ... +def register_converter(__name: str, __converter: Callable[[bytes], Any]) -> None: ... if sys.version_info < (3, 8): class Cache(object): @@ -136,17 +136,17 @@ class Connection(object): def __init__(self, *args: Any, **kwargs: Any) -> None: ... def close(self) -> None: ... def commit(self) -> None: ... - def create_aggregate(self, name: str, num_params: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... - def create_collation(self, name: str, callable: Any) -> None: ... + def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... + def create_collation(self, __name: str, __callback: Any) -> None: ... if sys.version_info >= (3, 8): - def create_function(self, name: str, num_params: int, func: Any, *, deterministic: bool = ...) -> None: ... + def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ... else: def create_function(self, name: str, num_params: int, func: Any) -> None: ... def cursor(self, cursorClass: Optional[type] = ...) -> Cursor: ... def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ... # TODO: please check in executemany() if seq_of_parameters type is possible like this - def executemany(self, sql: str, seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ... - def executescript(self, sql_script: Union[bytes, Text]) -> Cursor: ... + def executemany(self, __sql: str, __parameters: Iterable[Iterable[Any]]) -> Cursor: ... + def executescript(self, __sql_script: Union[bytes, Text]) -> Cursor: ... def interrupt(self, *args: Any, **kwargs: Any) -> None: ... def iterdump(self, *args: Any, **kwargs: Any) -> Generator[str, None, None]: ... def rollback(self, *args: Any, **kwargs: Any) -> None: ... @@ -173,7 +173,7 @@ class Connection(object): ) -> None: ... def __call__(self, *args: Any, **kwargs: Any) -> Any: ... def __enter__(self) -> Connection: ... - def __exit__(self, t: Optional[type] = ..., exc: Optional[BaseException] = ..., tb: Optional[Any] = ...) -> None: ... + def __exit__(self, t: Optional[type], exc: Optional[BaseException], tb: Optional[Any]) -> None: ... class Cursor(Iterator[Any]): arraysize: Any @@ -187,9 +187,9 @@ class Cursor(Iterator[Any]): # however, the name of the __init__ variable is unknown def __init__(self, *args: Any, **kwargs: Any) -> None: ... def close(self, *args: Any, **kwargs: Any) -> None: ... - def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ... - def executemany(self, sql: str, seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ... - def executescript(self, sql_script: Union[bytes, Text]) -> Cursor: ... + def execute(self, __sql: str, __parameters: Iterable[Any] = ...) -> Cursor: ... + def executemany(self, __sql: str, __seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ... + def executescript(self, __sql_script: Union[bytes, Text]) -> Cursor: ... def fetchall(self) -> List[Any]: ... def fetchmany(self, size: Optional[int] = ...) -> List[Any]: ... def fetchone(self) -> Any: ... diff --git a/mypy/typeshed/stdlib/sre_constants.pyi b/mypy/typeshed/stdlib/sre_constants.pyi index d66f19eda717..07a308e6f67b 100644 --- a/mypy/typeshed/stdlib/sre_constants.pyi +++ b/mypy/typeshed/stdlib/sre_constants.pyi @@ -1,3 +1,4 @@ +import sys from typing import Any, Dict, List, Optional, Union MAGIC: int @@ -72,7 +73,10 @@ REPEAT: _NamedIntConstant REPEAT_ONE: _NamedIntConstant SUBPATTERN: _NamedIntConstant MIN_REPEAT_ONE: _NamedIntConstant -RANGE_IGNORE: _NamedIntConstant +if sys.version_info >= (3, 7): + RANGE_UNI_IGNORE: _NamedIntConstant +else: + RANGE_IGNORE: _NamedIntConstant MIN_REPEAT: _NamedIntConstant MAX_REPEAT: _NamedIntConstant diff --git a/mypy/typeshed/stdlib/symbol.pyi b/mypy/typeshed/stdlib/symbol.pyi index 6fbe306fabe9..036d3191452d 100644 --- a/mypy/typeshed/stdlib/symbol.pyi +++ b/mypy/typeshed/stdlib/symbol.pyi @@ -1,90 +1,92 @@ +import sys from typing import Dict -single_input: int -file_input: int -eval_input: int -decorator: int -decorators: int -decorated: int -async_funcdef: int -funcdef: int -parameters: int -typedargslist: int -tfpdef: int -varargslist: int -vfpdef: int -stmt: int -simple_stmt: int -small_stmt: int -expr_stmt: int -annassign: int -testlist_star_expr: int -augassign: int -del_stmt: int -pass_stmt: int -flow_stmt: int -break_stmt: int -continue_stmt: int -return_stmt: int -yield_stmt: int -raise_stmt: int -import_stmt: int -import_name: int -import_from: int -import_as_name: int -dotted_as_name: int -import_as_names: int -dotted_as_names: int -dotted_name: int -global_stmt: int -nonlocal_stmt: int -assert_stmt: int -compound_stmt: int -async_stmt: int -if_stmt: int -while_stmt: int -for_stmt: int -try_stmt: int -with_stmt: int -with_item: int -except_clause: int -suite: int -test: int -test_nocond: int -lambdef: int -lambdef_nocond: int -or_test: int -and_test: int -not_test: int -comparison: int -comp_op: int -star_expr: int -expr: int -xor_expr: int -and_expr: int -shift_expr: int -arith_expr: int -term: int -factor: int -power: int -atom_expr: int -atom: int -testlist_comp: int -trailer: int -subscriptlist: int -subscript: int -sliceop: int -exprlist: int -testlist: int -dictorsetmaker: int -classdef: int -arglist: int -argument: int -comp_iter: int -comp_for: int -comp_if: int -encoding_decl: int -yield_expr: int -yield_arg: int +if sys.version_info < (3, 10): + single_input: int + file_input: int + eval_input: int + decorator: int + decorators: int + decorated: int + async_funcdef: int + funcdef: int + parameters: int + typedargslist: int + tfpdef: int + varargslist: int + vfpdef: int + stmt: int + simple_stmt: int + small_stmt: int + expr_stmt: int + annassign: int + testlist_star_expr: int + augassign: int + del_stmt: int + pass_stmt: int + flow_stmt: int + break_stmt: int + continue_stmt: int + return_stmt: int + yield_stmt: int + raise_stmt: int + import_stmt: int + import_name: int + import_from: int + import_as_name: int + dotted_as_name: int + import_as_names: int + dotted_as_names: int + dotted_name: int + global_stmt: int + nonlocal_stmt: int + assert_stmt: int + compound_stmt: int + async_stmt: int + if_stmt: int + while_stmt: int + for_stmt: int + try_stmt: int + with_stmt: int + with_item: int + except_clause: int + suite: int + test: int + test_nocond: int + lambdef: int + lambdef_nocond: int + or_test: int + and_test: int + not_test: int + comparison: int + comp_op: int + star_expr: int + expr: int + xor_expr: int + and_expr: int + shift_expr: int + arith_expr: int + term: int + factor: int + power: int + atom_expr: int + atom: int + testlist_comp: int + trailer: int + subscriptlist: int + subscript: int + sliceop: int + exprlist: int + testlist: int + dictorsetmaker: int + classdef: int + arglist: int + argument: int + comp_iter: int + comp_for: int + comp_if: int + encoding_decl: int + yield_expr: int + yield_arg: int -sym_name: Dict[int, str] + sym_name: Dict[int, str] diff --git a/mypy/typeshed/stdlib/sys.pyi b/mypy/typeshed/stdlib/sys.pyi index 9052ce32b765..d431d2733b1c 100644 --- a/mypy/typeshed/stdlib/sys.pyi +++ b/mypy/typeshed/stdlib/sys.pyi @@ -1,6 +1,7 @@ import sys from builtins import object as _object -from importlib.abc import MetaPathFinder, PathEntryFinder +from importlib.abc import Loader, PathEntryFinder +from importlib.machinery import ModuleSpec from types import FrameType, ModuleType, TracebackType from typing import ( Any, @@ -10,6 +11,7 @@ from typing import ( List, NoReturn, Optional, + Protocol, Sequence, TextIO, Tuple, @@ -24,6 +26,14 @@ _T = TypeVar("_T") # The following type alias are stub-only and do not exist during runtime _ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] _OptExcInfo = Union[_ExcInfo, Tuple[None, None, None]] +_PathSequence = Sequence[Union[bytes, str]] + +# Unlike importlib.abc.MetaPathFinder, invalidate_caches() might not exist (see python docs) +class _MetaPathFinder(Protocol): + def find_module(self, fullname: str, path: Optional[_PathSequence]) -> Optional[Loader]: ... + def find_spec( + self, fullname: str, path: Optional[_PathSequence], target: Optional[ModuleType] = ... + ) -> Optional[ModuleSpec]: ... # ----- sys variables ----- if sys.platform != "win32": @@ -48,7 +58,7 @@ last_value: Optional[BaseException] last_traceback: Optional[TracebackType] maxsize: int maxunicode: int -meta_path: List[MetaPathFinder] +meta_path: List[_MetaPathFinder] modules: Dict[str, ModuleType] path: List[str] path_hooks: List[Any] # TODO precise type; function, path to finder diff --git a/mypy/typeshed/stdlib/tarfile.pyi b/mypy/typeshed/stdlib/tarfile.pyi index 1faba6e948e5..afb88161d68b 100644 --- a/mypy/typeshed/stdlib/tarfile.pyi +++ b/mypy/typeshed/stdlib/tarfile.pyi @@ -211,10 +211,10 @@ class TarFile(Iterable[TarInfo]): def next(self) -> Optional[TarInfo]: ... if sys.version_info >= (3, 5): def extractall( - self, path: AnyPath = ..., members: Optional[List[TarInfo]] = ..., *, numeric_owner: bool = ... + self, path: AnyPath = ..., members: Optional[Iterable[TarInfo]] = ..., *, numeric_owner: bool = ... ) -> None: ... else: - def extractall(self, path: AnyPath = ..., members: Optional[List[TarInfo]] = ...) -> None: ... + def extractall(self, path: AnyPath = ..., members: Optional[Iterable[TarInfo]] = ...) -> None: ... if sys.version_info >= (3, 5): def extract( self, member: Union[str, TarInfo], path: AnyPath = ..., set_attrs: bool = ..., *, numeric_owner: bool = ... diff --git a/mypy/typeshed/stdlib/termios.pyi b/mypy/typeshed/stdlib/termios.pyi index 9eecbf68136f..0c627f4b72bd 100644 --- a/mypy/typeshed/stdlib/termios.pyi +++ b/mypy/typeshed/stdlib/termios.pyi @@ -236,11 +236,11 @@ VWERASE: int XCASE: int XTABS: int -def tcgetattr(fd: FileDescriptorLike) -> List[Any]: ... -def tcsetattr(fd: FileDescriptorLike, when: int, attributes: _Attr) -> None: ... -def tcsendbreak(fd: FileDescriptorLike, duration: int) -> None: ... -def tcdrain(fd: FileDescriptorLike) -> None: ... -def tcflush(fd: FileDescriptorLike, queue: int) -> None: ... -def tcflow(fd: FileDescriptorLike, action: int) -> None: ... +def tcgetattr(__fd: FileDescriptorLike) -> List[Any]: ... +def tcsetattr(__fd: FileDescriptorLike, __when: int, __attributes: _Attr) -> None: ... +def tcsendbreak(__fd: FileDescriptorLike, __duration: int) -> None: ... +def tcdrain(__fd: FileDescriptorLike) -> None: ... +def tcflush(__fd: FileDescriptorLike, __queue: int) -> None: ... +def tcflow(__fd: FileDescriptorLike, __action: int) -> None: ... class error(Exception): ... diff --git a/mypy/typeshed/stdlib/time.pyi b/mypy/typeshed/stdlib/time.pyi index e3315733510e..5d9c0148e386 100644 --- a/mypy/typeshed/stdlib/time.pyi +++ b/mypy/typeshed/stdlib/time.pyi @@ -14,7 +14,8 @@ timezone: int tzname: Tuple[str, str] if sys.version_info >= (3, 7) and sys.platform != "win32": - CLOCK_BOOTTIME: int # Linux + if sys.platform == "linux": + CLOCK_BOOTTIME: int CLOCK_PROF: int # FreeBSD, NetBSD, OpenBSD CLOCK_UPTIME: int # FreeBSD, OpenBSD diff --git a/mypy/typeshed/stdlib/tkinter/__init__.pyi b/mypy/typeshed/stdlib/tkinter/__init__.pyi index 3597344dd965..db1dd1584ae8 100644 --- a/mypy/typeshed/stdlib/tkinter/__init__.pyi +++ b/mypy/typeshed/stdlib/tkinter/__init__.pyi @@ -1288,7 +1288,9 @@ class Canvas(Widget, XView, YView): lower: Any def move(self, *args): ... if sys.version_info >= (3, 8): - def moveto(self, tagOrId: Union[str, _CanvasItemId], x: str = ..., y: str = ...) -> None: ... + def moveto( + self, tagOrId: Union[str, _CanvasItemId], x: Union[Literal[""], float] = ..., y: Union[Literal[""], float] = ... + ) -> None: ... def postscript(self, cnf=..., **kw): ... def tag_raise(self, *args): ... lift: Any diff --git a/mypy/typeshed/stdlib/tkinter/font.pyi b/mypy/typeshed/stdlib/tkinter/font.pyi index 81eacd507202..a19c2dd6dc67 100644 --- a/mypy/typeshed/stdlib/tkinter/font.pyi +++ b/mypy/typeshed/stdlib/tkinter/font.pyi @@ -1,3 +1,4 @@ +import sys import tkinter from typing import Any, List, Optional, Tuple, TypeVar, Union, overload from typing_extensions import Literal, TypedDict @@ -7,8 +8,6 @@ ROMAN: Literal["roman"] BOLD: Literal["bold"] ITALIC: Literal["italic"] -def nametofont(name: str) -> Font: ... - # Can contain e.g. nested sequences ('FONT DESCRIPTIONS' in font man page) _FontDescription = Union[str, Font, tkinter._TkinterSequence[Any]] @@ -95,3 +94,9 @@ class Font: def families(root: Optional[tkinter.Misc] = ..., displayof: Optional[tkinter.Misc] = ...) -> Tuple[str, ...]: ... def names(root: Optional[tkinter.Misc] = ...) -> Tuple[str, ...]: ... + +if sys.version_info >= (3, 10): + def nametofont(name: str, root: Optional[tkinter.Misc] = ...) -> Font: ... + +else: + def nametofont(name: str) -> Font: ... diff --git a/mypy/typeshed/stdlib/traceback.pyi b/mypy/typeshed/stdlib/traceback.pyi index 86635427462d..3c24ee21bee4 100644 --- a/mypy/typeshed/stdlib/traceback.pyi +++ b/mypy/typeshed/stdlib/traceback.pyi @@ -7,7 +7,17 @@ _PT = Tuple[str, int, str, Optional[str]] def print_tb(tb: Optional[TracebackType], limit: Optional[int] = ..., file: Optional[IO[str]] = ...) -> None: ... -if sys.version_info >= (3,): +if sys.version_info >= (3, 10): + def print_exception( + __exc: Optional[Type[BaseException]], + value: Optional[BaseException] = ..., + tb: Optional[TracebackType] = ..., + limit: Optional[int] = ..., + file: Optional[IO[str]] = ..., + chain: bool = ..., + ) -> None: ... + +elif sys.version_info >= (3,): def print_exception( etype: Optional[Type[BaseException]], value: Optional[BaseException], @@ -16,8 +26,6 @@ if sys.version_info >= (3,): file: Optional[IO[str]] = ..., chain: bool = ..., ) -> None: ... - def print_exc(limit: Optional[int] = ..., file: Optional[IO[str]] = ..., chain: bool = ...) -> None: ... - def print_last(limit: Optional[int] = ..., file: Optional[IO[str]] = ..., chain: bool = ...) -> None: ... else: def print_exception( @@ -27,6 +35,12 @@ else: limit: Optional[int] = ..., file: Optional[IO[str]] = ..., ) -> None: ... + +if sys.version_info >= (3,): + def print_exc(limit: Optional[int] = ..., file: Optional[IO[str]] = ..., chain: bool = ...) -> None: ... + def print_last(limit: Optional[int] = ..., file: Optional[IO[str]] = ..., chain: bool = ...) -> None: ... + +else: def print_exc(limit: Optional[int] = ..., file: Optional[IO[str]] = ...) -> None: ... def print_last(limit: Optional[int] = ..., file: Optional[IO[str]] = ...) -> None: ... @@ -44,9 +58,22 @@ else: def extract_stack(f: Optional[FrameType] = ..., limit: Optional[int] = ...) -> List[_PT]: ... def format_list(extracted_list: List[_PT]) -> List[str]: ... -def format_exception_only(etype: Optional[Type[BaseException]], value: Optional[BaseException]) -> List[str]: ... +if sys.version_info >= (3, 10): + def format_exception_only(__exc: Optional[Type[BaseException]], value: Optional[BaseException] = ...) -> List[str]: ... -if sys.version_info >= (3,): +else: + def format_exception_only(etype: Optional[Type[BaseException]], value: Optional[BaseException]) -> List[str]: ... + +if sys.version_info >= (3, 10): + def format_exception( + __exc: Optional[Type[BaseException]], + value: Optional[BaseException] = ..., + tb: Optional[TracebackType] = ..., + limit: Optional[int] = ..., + chain: bool = ..., + ) -> List[str]: ... + +elif sys.version_info >= (3,): def format_exception( etype: Optional[Type[BaseException]], value: Optional[BaseException], @@ -54,7 +81,6 @@ if sys.version_info >= (3,): limit: Optional[int] = ..., chain: bool = ..., ) -> List[str]: ... - def format_exc(limit: Optional[int] = ..., chain: bool = ...) -> str: ... else: def format_exception( @@ -63,6 +89,11 @@ else: tb: Optional[TracebackType], limit: Optional[int] = ..., ) -> List[str]: ... + +if sys.version_info >= (3,): + def format_exc(limit: Optional[int] = ..., chain: bool = ...) -> str: ... + +else: def format_exc(limit: Optional[int] = ...) -> str: ... def format_tb(tb: Optional[TracebackType], limit: Optional[int] = ...) -> List[str]: ... diff --git a/mypy/typeshed/stdlib/types.pyi b/mypy/typeshed/stdlib/types.pyi index 7fb0de7acff1..2b4d32392a45 100644 --- a/mypy/typeshed/stdlib/types.pyi +++ b/mypy/typeshed/stdlib/types.pyi @@ -1,12 +1,10 @@ import sys import typing +from importlib.abc import _LoaderProtocol +from importlib.machinery import ModuleSpec from typing import Any, Awaitable, Callable, Dict, Generic, Iterable, Iterator, Mapping, Optional, Tuple, Type, TypeVar, overload from typing_extensions import Literal, final -# ModuleType is exported from this module, but for circular import -# reasons exists in its own stub file (with ModuleSpec and Loader). -from _importlib_modulespec import ModuleType as ModuleType # Exported - # Note, all classes "defined" here require special handling. _T = TypeVar("_T") @@ -135,6 +133,15 @@ class SimpleNamespace: def __setattr__(self, name: str, value: Any) -> None: ... def __delattr__(self, name: str) -> None: ... +class ModuleType: + __name__: str + __file__: str + __dict__: Dict[str, Any] + __loader__: Optional[_LoaderProtocol] + __package__: Optional[str] + __spec__: Optional[ModuleSpec] + def __init__(self, name: str, doc: Optional[str] = ...) -> None: ... + class GeneratorType: gi_code: CodeType gi_frame: FrameType diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index 7f168a9196c2..aafb1fbdf1b3 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -12,7 +12,6 @@ if sys.version_info >= (3, 9): # Definitions of special type checking related constructs. Their definitions # are not used, so their value does not matter. -overload = object() Any = object() class TypeVar: @@ -35,6 +34,10 @@ _promote = object() class _SpecialForm: def __getitem__(self, typeargs: Any) -> object: ... +_F = TypeVar("_F", bound=Callable[..., Any]) + +def overload(func: _F) -> _F: ... + Union: _SpecialForm = ... Optional: _SpecialForm = ... Tuple: _SpecialForm = ... @@ -46,7 +49,6 @@ Type: _SpecialForm = ... ClassVar: _SpecialForm = ... if sys.version_info >= (3, 8): Final: _SpecialForm = ... - _F = TypeVar("_F", bound=Callable[..., Any]) def final(f: _F) -> _F: ... Literal: _SpecialForm = ... # TypedDict is a (non-subscriptable) special form. @@ -56,9 +58,24 @@ if sys.version_info < (3, 7): class GenericMeta(type): ... if sys.version_info >= (3, 10): + class ParamSpecArgs: + __origin__: ParamSpec + def __init__(self, origin: ParamSpec) -> None: ... + class ParamSpecKwargs: + __origin__: ParamSpec + def __init__(self, origin: ParamSpec) -> None: ... class ParamSpec: __name__: str - def __init__(self, name: str) -> None: ... + __bound__: Optional[Type[Any]] + __covariant__: bool + __contravariant__: bool + def __init__( + self, name: str, *, bound: Union[None, Type[Any], str] = ..., contravariant: bool = ..., covariant: bool = ... + ) -> None: ... + @property + def args(self) -> ParamSpecArgs: ... + @property + def kwargs(self) -> ParamSpecKwargs: ... Concatenate: _SpecialForm = ... TypeAlias: _SpecialForm = ... TypeGuard: _SpecialForm = ... @@ -79,11 +96,9 @@ _KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers. _VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers. _T_contra = TypeVar("_T_contra", contravariant=True) # Ditto contravariant. _TC = TypeVar("_TC", bound=Type[object]) -_C = TypeVar("_C", bound=Callable[..., Any]) - -no_type_check = object() -def no_type_check_decorator(decorator: _C) -> _C: ... +def no_type_check(arg: _F) -> _F: ... +def no_type_check_decorator(decorator: _F) -> _F: ... # Type aliases and type constructors @@ -666,7 +681,7 @@ class _TypedDict(Mapping[str, object], metaclass=ABCMeta): def NewType(name: str, tp: Type[_T]) -> Type[_T]: ... # This itself is only available during type checking -def type_check_only(func_or_cls: _C) -> _C: ... +def type_check_only(func_or_cls: _F) -> _F: ... if sys.version_info >= (3, 7): from types import CodeType @@ -681,3 +696,6 @@ if sys.version_info >= (3, 7): def __eq__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... + +if sys.version_info >= (3, 10): + def is_typeddict(tp: Any) -> bool: ... diff --git a/mypy/typeshed/stdlib/typing_extensions.pyi b/mypy/typeshed/stdlib/typing_extensions.pyi index 8dd41b53908a..0250866f3eb8 100644 --- a/mypy/typeshed/stdlib/typing_extensions.pyi +++ b/mypy/typeshed/stdlib/typing_extensions.pyi @@ -20,7 +20,9 @@ from typing import ( Tuple, Type as Type, TypeVar, + Union, ValuesView, + _Alias, overload as overload, ) @@ -42,7 +44,7 @@ def final(f: _F) -> _F: ... Literal: _SpecialForm = ... -def IntVar(__name: str) -> Any: ... # returns a new TypeVar +def IntVar(name: str) -> Any: ... # returns a new TypeVar # Internal mypy fallback type for all typed dicts (does not exist at runtime) class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): @@ -67,6 +69,8 @@ class _TypedDict(Mapping[str, object], metaclass=abc.ABCMeta): # TypedDict is a (non-subscriptable) special form. TypedDict: object = ... +OrderedDict = _Alias() + if sys.version_info >= (3, 3): from typing import ChainMap as ChainMap @@ -96,9 +100,6 @@ if sys.version_info >= (3, 7): Annotated: _SpecialForm = ... _AnnotatedAlias: Any = ... # undocumented -# TypeAlias is a (non-subscriptable) special form. -class TypeAlias: ... - @runtime_checkable class SupportsIndex(Protocol, metaclass=abc.ABCMeta): @abc.abstractmethod @@ -106,12 +107,26 @@ class SupportsIndex(Protocol, metaclass=abc.ABCMeta): # PEP 612 support for Python < 3.9 if sys.version_info >= (3, 10): - from typing import Concatenate as Concatenate, ParamSpec as ParamSpec + from typing import Concatenate as Concatenate, ParamSpec as ParamSpec, TypeAlias as TypeAlias, TypeGuard as TypeGuard else: + class ParamSpecArgs: + __origin__: ParamSpec + def __init__(self, origin: ParamSpec) -> None: ... + class ParamSpecKwargs: + __origin__: ParamSpec + def __init__(self, origin: ParamSpec) -> None: ... class ParamSpec: __name__: str - def __init__(self, name: str) -> None: ... + __bound__: Optional[Type[Any]] + __covariant__: bool + __contravariant__: bool + def __init__( + self, name: str, *, bound: Union[None, Type[Any], str] = ..., contravariant: bool = ..., covariant: bool = ... + ) -> None: ... + @property + def args(self) -> ParamSpecArgs: ... + @property + def kwargs(self) -> ParamSpecKwargs: ... Concatenate: _SpecialForm = ... - -# PEP 647 -TypeGuard: _SpecialForm = ... + TypeAlias: _SpecialForm = ... + TypeGuard: _SpecialForm = ... diff --git a/mypy/typeshed/stdlib/urllib/request.pyi b/mypy/typeshed/stdlib/urllib/request.pyi index bed840c00663..ba7fe8cc3508 100644 --- a/mypy/typeshed/stdlib/urllib/request.pyi +++ b/mypy/typeshed/stdlib/urllib/request.pyi @@ -21,7 +21,7 @@ from typing import ( overload, ) from urllib.error import HTTPError -from urllib.response import addinfourl +from urllib.response import addclosehook, addinfourl _T = TypeVar("_T") _UrlopenRet = Any @@ -97,7 +97,6 @@ class BaseHandler: parent: OpenerDirector def add_parent(self, parent: OpenerDirector) -> None: ... def close(self) -> None: ... - def http_error_nnn(self, req: Request, fp: IO[str], code: int, msg: int, headers: Mapping[str, str]) -> _UrlopenRet: ... class HTTPDefaultErrorHandler(BaseHandler): def http_error_default( @@ -156,6 +155,8 @@ class HTTPPasswordMgrWithPriorAuth(HTTPPasswordMgrWithDefaultRealm): class AbstractBasicAuthHandler: rx: ClassVar[Pattern[str]] # undocumented + passwd: HTTPPasswordMgr + add_password: Callable[[str, Union[str, Sequence[str]], str, str], None] def __init__(self, password_mgr: Optional[HTTPPasswordMgr] = ...) -> None: ... def http_error_auth_reqed(self, authreq: str, host: str, req: Request, headers: Mapping[str, str]) -> None: ... def http_request(self, req: Request) -> Request: ... # undocumented @@ -228,6 +229,12 @@ class ftpwrapper: # undocumented def __init__( self, user: str, passwd: str, host: str, port: int, dirs: str, timeout: Optional[float] = ..., persistent: bool = ... ) -> None: ... + def close(self) -> None: ... + def endtransfer(self) -> None: ... + def file_close(self) -> None: ... + def init(self) -> None: ... + def real_close(self) -> None: ... + def retrfile(self, file: str, type: str) -> Tuple[addclosehook, int]: ... class FTPHandler(BaseHandler): def ftp_open(self, req: Request) -> addinfourl: ... diff --git a/mypy/typeshed/stdlib/urllib/robotparser.pyi b/mypy/typeshed/stdlib/urllib/robotparser.pyi index ad96ca12bfc4..382dcee0e859 100644 --- a/mypy/typeshed/stdlib/urllib/robotparser.pyi +++ b/mypy/typeshed/stdlib/urllib/robotparser.pyi @@ -10,7 +10,7 @@ class RobotFileParser: def set_url(self, url: str) -> None: ... def read(self) -> None: ... def parse(self, lines: Iterable[str]) -> None: ... - def can_fetch(self, user_agent: str, url: str) -> bool: ... + def can_fetch(self, useragent: str, url: str) -> bool: ... def mtime(self) -> int: ... def modified(self) -> None: ... def crawl_delay(self, useragent: str) -> Optional[str]: ... diff --git a/mypy/typeshed/stdlib/webbrowser.pyi b/mypy/typeshed/stdlib/webbrowser.pyi index e29238ee07ff..322ec2764e39 100644 --- a/mypy/typeshed/stdlib/webbrowser.pyi +++ b/mypy/typeshed/stdlib/webbrowser.pyi @@ -38,7 +38,7 @@ class BackgroundBrowser(GenericBrowser): def open(self, url: Text, new: int = ..., autoraise: bool = ...) -> bool: ... class UnixBrowser(BaseBrowser): - raise_opts: List[str] + raise_opts: Optional[List[str]] background: bool redirect_stdout: bool remote_args: List[str] @@ -48,7 +48,6 @@ class UnixBrowser(BaseBrowser): def open(self, url: Text, new: int = ..., autoraise: bool = ...) -> bool: ... class Mozilla(UnixBrowser): - raise_opts: List[str] remote_args: List[str] remote_action: str remote_action_newwin: str @@ -70,7 +69,6 @@ class Chrome(UnixBrowser): background: bool class Opera(UnixBrowser): - raise_opts: List[str] remote_args: List[str] remote_action: str remote_action_newwin: str diff --git a/mypy/typeshed/stdlib/xml/dom/minicompat.pyi b/mypy/typeshed/stdlib/xml/dom/minicompat.pyi index 964e6fa3f426..aa8efd03b19f 100644 --- a/mypy/typeshed/stdlib/xml/dom/minicompat.pyi +++ b/mypy/typeshed/stdlib/xml/dom/minicompat.pyi @@ -1,3 +1,17 @@ -from typing import Any +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar -def __getattr__(name: str) -> Any: ... # incomplete +_T = TypeVar("_T") + +StringTypes: Tuple[Type[str]] + +class NodeList(List[_T]): + length: int + def item(self, index: int) -> Optional[_T]: ... + +class EmptyNodeList(Tuple[Any, ...]): + length: int + def item(self, index: int) -> None: ... + def __add__(self, other: Iterable[_T]) -> NodeList[_T]: ... # type: ignore + def __radd__(self, other: Iterable[_T]) -> NodeList[_T]: ... + +def defproperty(klass: Type[Any], name: str, doc: str) -> None: ... diff --git a/mypy/typeshed/stdlib/xml/dom/minidom.pyi b/mypy/typeshed/stdlib/xml/dom/minidom.pyi index dc128e016548..67e9b1189528 100644 --- a/mypy/typeshed/stdlib/xml/dom/minidom.pyi +++ b/mypy/typeshed/stdlib/xml/dom/minidom.pyi @@ -1,6 +1,312 @@ -from typing import IO, Any, Optional, Text, Union +import sys +import xml.dom +from typing import IO, Any, Optional, Text as _Text, TypeVar, Union +from xml.dom.xmlbuilder import DocumentLS, DOMImplementationLS from xml.sax.xmlreader import XMLReader +_T = TypeVar("_T") + def parse(file: Union[str, IO[Any]], parser: Optional[XMLReader] = ..., bufsize: Optional[int] = ...): ... -def parseString(string: Union[bytes, Text], parser: Optional[XMLReader] = ...): ... -def __getattr__(name: str) -> Any: ... # incomplete +def parseString(string: Union[bytes, _Text], parser: Optional[XMLReader] = ...): ... +def getDOMImplementation(features=...): ... + +class Node(xml.dom.Node): + namespaceURI: Optional[str] + parentNode: Any + ownerDocument: Any + nextSibling: Any + previousSibling: Any + prefix: Any + if sys.version_info >= (3, 9): + def toxml(self, encoding: Optional[Any] = ..., standalone: Optional[Any] = ...): ... + def toprettyxml( + self, indent: str = ..., newl: str = ..., encoding: Optional[Any] = ..., standalone: Optional[Any] = ... + ): ... + else: + def toxml(self, encoding: Optional[Any] = ...): ... + def toprettyxml(self, indent: str = ..., newl: str = ..., encoding: Optional[Any] = ...): ... + def hasChildNodes(self) -> bool: ... + def insertBefore(self, newChild, refChild): ... + def appendChild(self, node): ... + def replaceChild(self, newChild, oldChild): ... + def removeChild(self, oldChild): ... + def normalize(self) -> None: ... + def cloneNode(self, deep): ... + def isSupported(self, feature, version): ... + def isSameNode(self, other): ... + def getInterface(self, feature): ... + def getUserData(self, key): ... + def setUserData(self, key, data, handler): ... + childNodes: Any + def unlink(self) -> None: ... + def __enter__(self: _T) -> _T: ... + def __exit__(self, et, ev, tb) -> None: ... + +class DocumentFragment(Node): + nodeType: Any + nodeName: str + nodeValue: Any + attributes: Any + parentNode: Any + childNodes: Any + def __init__(self) -> None: ... + +class Attr(Node): + name: str + nodeType: Any + attributes: Any + specified: bool + ownerElement: Any + namespaceURI: Optional[str] + childNodes: Any + nodeName: Any + nodeValue: str + value: str + prefix: Any + def __init__( + self, qName: str, namespaceURI: Optional[str] = ..., localName: Optional[Any] = ..., prefix: Optional[Any] = ... + ) -> None: ... + def unlink(self) -> None: ... + +class NamedNodeMap: + def __init__(self, attrs, attrsNS, ownerElement) -> None: ... + def item(self, index): ... + def items(self): ... + def itemsNS(self): ... + def __contains__(self, key): ... + def keys(self): ... + def keysNS(self): ... + def values(self): ... + def get(self, name, value: Optional[Any] = ...): ... + def __len__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ge__(self, other: Any) -> bool: ... + def __gt__(self, other: Any) -> bool: ... + def __le__(self, other: Any) -> bool: ... + def __lt__(self, other: Any) -> bool: ... + def __getitem__(self, attname_or_tuple): ... + def __setitem__(self, attname, value) -> None: ... + def getNamedItem(self, name): ... + def getNamedItemNS(self, namespaceURI: str, localName): ... + def removeNamedItem(self, name): ... + def removeNamedItemNS(self, namespaceURI: str, localName): ... + def setNamedItem(self, node): ... + def setNamedItemNS(self, node): ... + def __delitem__(self, attname_or_tuple) -> None: ... + +AttributeList = NamedNodeMap + +class TypeInfo: + namespace: Any + name: Any + def __init__(self, namespace, name) -> None: ... + +class Element(Node): + nodeType: Any + nodeValue: Any + schemaType: Any + parentNode: Any + tagName: str + prefix: Any + namespaceURI: Optional[str] + childNodes: Any + nextSibling: Any + def __init__( + self, tagName, namespaceURI: Optional[str] = ..., prefix: Optional[Any] = ..., localName: Optional[Any] = ... + ) -> None: ... + def unlink(self) -> None: ... + def getAttribute(self, attname): ... + def getAttributeNS(self, namespaceURI: str, localName): ... + def setAttribute(self, attname, value) -> None: ... + def setAttributeNS(self, namespaceURI: str, qualifiedName: str, value) -> None: ... + def getAttributeNode(self, attrname): ... + def getAttributeNodeNS(self, namespaceURI: str, localName): ... + def setAttributeNode(self, attr): ... + setAttributeNodeNS: Any + def removeAttribute(self, name) -> None: ... + def removeAttributeNS(self, namespaceURI: str, localName) -> None: ... + def removeAttributeNode(self, node): ... + removeAttributeNodeNS: Any + def hasAttribute(self, name: str) -> bool: ... + def hasAttributeNS(self, namespaceURI: str, localName) -> bool: ... + def getElementsByTagName(self, name): ... + def getElementsByTagNameNS(self, namespaceURI: str, localName): ... + def writexml(self, writer, indent: str = ..., addindent: str = ..., newl: str = ...) -> None: ... + def hasAttributes(self) -> bool: ... + def setIdAttribute(self, name) -> None: ... + def setIdAttributeNS(self, namespaceURI: str, localName) -> None: ... + def setIdAttributeNode(self, idAttr) -> None: ... + +class Childless: + attributes: Any + childNodes: Any + firstChild: Any + lastChild: Any + def appendChild(self, node) -> None: ... + def hasChildNodes(self) -> bool: ... + def insertBefore(self, newChild, refChild) -> None: ... + def removeChild(self, oldChild) -> None: ... + def normalize(self) -> None: ... + def replaceChild(self, newChild, oldChild) -> None: ... + +class ProcessingInstruction(Childless, Node): + nodeType: Any + target: Any + data: Any + def __init__(self, target, data) -> None: ... + nodeValue: Any + nodeName: Any + def writexml(self, writer, indent: str = ..., addindent: str = ..., newl: str = ...) -> None: ... + +class CharacterData(Childless, Node): + ownerDocument: Any + previousSibling: Any + def __init__(self) -> None: ... + def __len__(self) -> int: ... + data: str + nodeValue: Any + def substringData(self, offset: int, count: int) -> str: ... + def appendData(self, arg: str) -> None: ... + def insertData(self, offset: int, arg: str) -> None: ... + def deleteData(self, offset: int, count: int) -> None: ... + def replaceData(self, offset: int, count: int, arg: str) -> None: ... + +class Text(CharacterData): + nodeType: Any + nodeName: str + attributes: Any + data: Any + def splitText(self, offset): ... + def writexml(self, writer, indent: str = ..., addindent: str = ..., newl: str = ...) -> None: ... + def replaceWholeText(self, content): ... + +class Comment(CharacterData): + nodeType: Any + nodeName: str + def __init__(self, data) -> None: ... + def writexml(self, writer, indent: str = ..., addindent: str = ..., newl: str = ...) -> None: ... + +class CDATASection(Text): + nodeType: Any + nodeName: str + def writexml(self, writer, indent: str = ..., addindent: str = ..., newl: str = ...) -> None: ... + +class ReadOnlySequentialNamedNodeMap: + def __init__(self, seq=...) -> None: ... + def __len__(self): ... + def getNamedItem(self, name): ... + def getNamedItemNS(self, namespaceURI: str, localName): ... + def __getitem__(self, name_or_tuple): ... + def item(self, index): ... + def removeNamedItem(self, name) -> None: ... + def removeNamedItemNS(self, namespaceURI: str, localName) -> None: ... + def setNamedItem(self, node) -> None: ... + def setNamedItemNS(self, node) -> None: ... + +class Identified: ... + +class DocumentType(Identified, Childless, Node): + nodeType: Any + nodeValue: Any + name: Any + publicId: Any + systemId: Any + internalSubset: Any + entities: Any + notations: Any + nodeName: Any + def __init__(self, qualifiedName: str) -> None: ... + def cloneNode(self, deep): ... + def writexml(self, writer, indent: str = ..., addindent: str = ..., newl: str = ...) -> None: ... + +class Entity(Identified, Node): + attributes: Any + nodeType: Any + nodeValue: Any + actualEncoding: Any + encoding: Any + version: Any + nodeName: Any + notationName: Any + childNodes: Any + def __init__(self, name, publicId, systemId, notation) -> None: ... + def appendChild(self, newChild) -> None: ... + def insertBefore(self, newChild, refChild) -> None: ... + def removeChild(self, oldChild) -> None: ... + def replaceChild(self, newChild, oldChild) -> None: ... + +class Notation(Identified, Childless, Node): + nodeType: Any + nodeValue: Any + nodeName: Any + def __init__(self, name, publicId, systemId) -> None: ... + +class DOMImplementation(DOMImplementationLS): + def hasFeature(self, feature, version) -> bool: ... + def createDocument(self, namespaceURI: str, qualifiedName: str, doctype): ... + def createDocumentType(self, qualifiedName: str, publicId, systemId): ... + def getInterface(self, feature): ... + +class ElementInfo: + tagName: Any + def __init__(self, name) -> None: ... + def getAttributeType(self, aname): ... + def getAttributeTypeNS(self, namespaceURI: str, localName): ... + def isElementContent(self): ... + def isEmpty(self): ... + def isId(self, aname): ... + def isIdNS(self, namespaceURI: str, localName): ... + +class Document(Node, DocumentLS): + implementation: Any + nodeType: Any + nodeName: str + nodeValue: Any + attributes: Any + parentNode: Any + previousSibling: Any + nextSibling: Any + actualEncoding: Any + encoding: Any + standalone: Any + version: Any + strictErrorChecking: bool + errorHandler: Any + documentURI: Any + doctype: Any + childNodes: Any + def __init__(self) -> None: ... + def appendChild(self, node): ... + documentElement: Any + def removeChild(self, oldChild): ... + def unlink(self) -> None: ... + def cloneNode(self, deep): ... + def createDocumentFragment(self): ... + def createElement(self, tagName: str): ... + def createTextNode(self, data): ... + def createCDATASection(self, data): ... + def createComment(self, data): ... + def createProcessingInstruction(self, target, data): ... + def createAttribute(self, qName) -> Attr: ... + def createElementNS(self, namespaceURI: str, qualifiedName: str): ... + def createAttributeNS(self, namespaceURI: str, qualifiedName: str) -> Attr: ... + def getElementById(self, id): ... + def getElementsByTagName(self, name: str): ... + def getElementsByTagNameNS(self, namespaceURI: str, localName): ... + def isSupported(self, feature, version): ... + def importNode(self, node, deep): ... + if sys.version_info >= (3, 9): + def writexml( + self, + writer, + indent: str = ..., + addindent: str = ..., + newl: str = ..., + encoding: Optional[Any] = ..., + standalone: Optional[Any] = ..., + ) -> None: ... + else: + def writexml( + self, writer, indent: str = ..., addindent: str = ..., newl: str = ..., encoding: Optional[Any] = ... + ) -> None: ... + def renameNode(self, n, namespaceURI: str, name): ... diff --git a/mypy/typeshed/stdlib/xml/dom/xmlbuilder.pyi b/mypy/typeshed/stdlib/xml/dom/xmlbuilder.pyi index 964e6fa3f426..d8936bdc2ab4 100644 --- a/mypy/typeshed/stdlib/xml/dom/xmlbuilder.pyi +++ b/mypy/typeshed/stdlib/xml/dom/xmlbuilder.pyi @@ -1,3 +1,6 @@ from typing import Any def __getattr__(name: str) -> Any: ... # incomplete + +class DocumentLS(Any): ... # type: ignore +class DOMImplementationLS(Any): ... # type: ignore diff --git a/mypy/typeshed/stdlib/xxlimited.pyi b/mypy/typeshed/stdlib/xxlimited.pyi index e47694586ab1..0dddbb876638 100644 --- a/mypy/typeshed/stdlib/xxlimited.pyi +++ b/mypy/typeshed/stdlib/xxlimited.pyi @@ -1,13 +1,18 @@ +import sys from typing import Any -class Null: ... class Str: ... class Xxo: def demo(self) -> None: ... -class error: ... - def foo(__i: int, __j: int) -> Any: ... def new() -> Xxo: ... -def roj(__b: Any) -> None: ... + +if sys.version_info >= (3, 10): + class Error: ... + +else: + class error: ... + class Null: ... + def roj(__b: Any) -> None: ... diff --git a/test-data/unit/lib-stub/array.pyi.bak b/test-data/unit/lib-stub/array.pyi.bak new file mode 100644 index 000000000000..74be1cd6c16e --- /dev/null +++ b/test-data/unit/lib-stub/array.pyi.bak @@ -0,0 +1,15 @@ +from typing import MutableSequence, Generic, Literal, Union, TypeVar, Text + + +_IntTypeCode = Literal["b", "B", "h", "H", "i", "I", "l", "L", "q", "Q"] +_FloatTypeCode = Literal["f", "d"] +_UnicodeTypeCode = Literal["u"] +_TypeCode = Union[_IntTypeCode, _FloatTypeCode, _UnicodeTypeCode] + + +_T = TypeVar("_T", int, float, Text) + + +class array(MutableSequence[_T], Generic[_T]): + typecode: _TypeCode + itemsize: int From 86a6ce93ac72ff113c7ccb637ee2f406f8cd0912 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 10 May 2021 17:21:05 +0200 Subject: [PATCH 37/76] Fix sequence pattern outer type check to be in line with PEP 634 --- mypy/checkpattern.py | 45 ++++++++---- test-data/unit/check-python310.test | 96 ++++++++++++++++++++----- test-data/unit/fixtures/primitives.pyi | 17 ++++- test-data/unit/fixtures/typing-full.pyi | 9 +++ test-data/unit/lib-stub/collections.pyi | 4 +- 5 files changed, 139 insertions(+), 32 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 05fbe3491333..b8be1a4b60ee 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -36,6 +36,12 @@ "builtins.tuple", ] +non_sequence_match_type_names = [ + "builtins.str", + "builtins.bytes", + "builtins.bytearray" +] + PatternType = NamedTuple( 'PatternType', @@ -67,6 +73,8 @@ class PatternChecker(PatternVisitor[PatternType]): self_match_types = None # type: List[Type] + non_sequence_match_types = None # type: List[Type] + def __init__(self, chk: 'mypy.checker.TypeChecker', msg: MessageBuilder, plugin: Plugin @@ -76,7 +84,8 @@ def __init__(self, self.plugin = plugin self.type_context = [] - self.self_match_types = self.generate_self_match_types() + self.self_match_types = self.generate_types(self_match_type_names) + self.non_sequence_match_types = self.generate_types(non_sequence_match_type_names) def accept(self, o: Pattern, type_context: Type) -> PatternType: self.type_context.append(type_context) @@ -168,7 +177,8 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # check for existence of a starred pattern # current_type = get_proper_type(self.type_context[-1]) - can_match = True + if not self.can_match_sequence(current_type): + return early_non_match() star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)] star_position = None # type: Optional[int] if len(star_positions) == 1: @@ -192,11 +202,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: else: inner_type = self.get_sequence_type(current_type) if inner_type is None: - if is_subtype(self.chk.named_type("typing.Iterable"), current_type): - # Current type is more general, but the actual value could still be iterable - inner_type = self.chk.named_type("builtins.object") - else: - return early_non_match() + inner_type = self.chk.named_type("builtins.object") inner_types = [inner_type] * len(o.patterns) # @@ -208,6 +214,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: contracted_inner_types = self.contract_starred_pattern_types(inner_types, star_position, required_patterns) + can_match = True for p, t in zip(o.patterns, contracted_inner_types): pattern_type = self.accept(p, t) typ, type_map = pattern_type @@ -465,13 +472,25 @@ def should_self_match(self, typ: ProperType) -> bool: return True return False - def generate_self_match_types(self) -> List[Type]: + def can_match_sequence(self, typ: ProperType) -> bool: + for other in self.non_sequence_match_types: + # We have to ignore promotions, as memoryview should match, but bytes, + # which it can be promoted to, shouldn't + if is_subtype(typ, other, ignore_promotions=True): + return False + sequence = self.chk.named_type("typing.Sequence") + # If the static type is more general than sequence the actual type could still match + return is_subtype(typ, sequence) or is_subtype(sequence, typ) + + def generate_types(self, type_names: List[str]) -> List[Type]: types = [] # type: List[Type] - for name in self_match_type_names: + for name in type_names: try: types.append(self.chk.named_type(name)) - except KeyError: + except KeyError as e: # Some built in types are not defined in all test cases + if not name.startswith('builtins.'): + raise e pass return types @@ -492,15 +511,15 @@ def update_type_map(self, original_type_map[expr] = typ def construct_iterable_child(self, outer_type: Type, inner_type: Type) -> Type: - iterable = self.chk.named_generic_type("typing.Iterable", [inner_type]) + sequence = self.chk.named_generic_type("typing.Sequence", [inner_type]) if self.chk.type_is_iterable(outer_type): proper_type = get_proper_type(outer_type) assert isinstance(proper_type, Instance) empty_type = fill_typevars(proper_type.type) - partial_type = expand_type_by_instance(empty_type, iterable) + partial_type = expand_type_by_instance(empty_type, sequence) return expand_type_by_instance(partial_type, proper_type) else: - return iterable + return sequence def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 9ee77ea07056..40698a96ea1a 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -74,22 +74,22 @@ match m: [builtins fixtures/list.pyi] [case testSequencePatternCapturesStarred] -from typing import Iterable -m: Iterable[int] +from typing import Sequence +m: Sequence[int] match m: case [a, *b]: - reveal_type(a) # N: Revealed type is "builtins.int" - reveal_type(b) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(a) # N: Revealed type is "builtins.int*" + reveal_type(b) # N: Revealed type is "builtins.list[builtins.int*]" [builtins fixtures/list.pyi] [case testSequencePatternNarrowsInner] -from typing import Iterable -m: Iterable[object] +from typing import Sequence +m: Sequence[object] match m: case [1, True]: - reveal_type(m) # N: Revealed type is "typing.Iterable[builtins.int]" + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" [case testSequencePatternNarrowsOuter] from typing import Sequence @@ -97,15 +97,15 @@ m: object match m: case [1, True]: - reveal_type(m) # N: Revealed type is "typing.Iterable[builtins.int]" + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" [case testSequencePatternAlreadyNarrowerInner] -from typing import Iterable -m: Iterable[bool] +from typing import Sequence +m: Sequence[bool] match m: case [1, True]: - reveal_type(m) # N: Revealed type is "typing.Iterable[builtins.bool]" + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" [case testSequencePatternAlreadyNarrowerOuter] from typing import Sequence @@ -124,20 +124,20 @@ match m: reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" [case testNestedSequencePatternNarrowsInner] -from typing import Iterable -m: Iterable[Iterable[object]] +from typing import Sequence +m: Sequence[Sequence[object]] match m: case [[1], [True]]: - reveal_type(m) # N: Revealed type is "typing.Iterable[typing.Iterable[builtins.int]]" + reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" [case testNestedSequencePatternNarrowsOuter] -from typing import Iterable +from typing import Sequence m: object match m: case [[1], [True]]: - reveal_type(m) # N: Revealed type is "typing.Iterable[typing.Iterable[builtins.int]]" + reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" [case testSequencePatternDoesntNarrowInvariant] @@ -149,6 +149,70 @@ match m: reveal_type(m) # N: Revealed type is "builtins.list[builtins.object]" [builtins fixtures/list.pyi] +[case testSequencePatternMatches] +import array, collections +from typing import Sequence, Iterable + +m1: Sequence[int] +m2: array.array[int] +m3: collections.deque[int] +m4: list[int] +m5: memoryview +m6: range +m7: tuple[int] + +m8: str +m9: bytes +m10: bytearray +m11: Iterable[int] + +match m1: + case [a]: + reveal_type(a) # N: Revealed type is "builtins.int*" + +match m2: + case [b]: + reveal_type(b) # N: Revealed type is "builtins.int*" + +match m3: + case [c]: + reveal_type(c) # N: Revealed type is "builtins.int*" + +match m4: + case [d]: + reveal_type(d) # N: Revealed type is "builtins.int*" + +match m5: + case [e]: + reveal_type(e) # N: Revealed type is "builtins.int*" + +match m6: + case [f]: + reveal_type(f) # N: Revealed type is "builtins.int*" + +match m7: + case [g]: + reveal_type(g) # N: Revealed type is "builtins.int" + +match m8: + case [h]: + reveal_type(h) + +match m9: + case [i]: + reveal_type(i) + +match m10: + case [j]: + reveal_type(j) + +match m11: + case [k]: + reveal_type(k) +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + + [case testSequencePatternCapturesTuple] from typing import Tuple m: Tuple[int, str, bool] diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 24cb5ea45ff2..8d0dcb00713c 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -1,5 +1,6 @@ # builtins stub with non-generic primitive types -from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable +from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable, overload + T = TypeVar('T') V = TypeVar('V') @@ -54,3 +55,17 @@ class frozenset(Iterable[T]): def __iter__(self) -> Iterator[T]: pass class function: pass class ellipsis: pass + +class range(Sequence[int]): + start: int + stop: int + step: int + @overload + def __init__(self, stop: int) -> None: ... + @overload + def __init__(self, start: int, stop: int, step: int = ...) -> None: ... + def count(self, value: int) -> int: ... + def index(self, value: int) -> int: ... + def __getitem__(self, i: int) -> int: ... + def __iter__(self) -> Iterator[int]: pass + def __contains__(self, other: object) -> bool: pass diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 6aa2f9d291bb..d215f9014019 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -129,6 +129,10 @@ class Sequence(Iterable[T_co], Container[T_co]): @abstractmethod def __getitem__(self, n: Any) -> T_co: pass +class MutableSequence(Sequence[T]): + @abstractmethod + def __setitem__(self, n: Any, o: T) -> None: pass + class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): def __getitem__(self, key: T) -> T_co: pass @overload @@ -142,6 +146,11 @@ class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): class MutableMapping(Mapping[T, U], metaclass=ABCMeta): def __setitem__(self, k: T, v: U) -> None: pass +@runtime_checkable +class Reversible(Iterable[T_co], Protocol): + @abstractmethod + def __reversed__(self) -> Iterator[T_co]: ... + class SupportsInt(Protocol): def __int__(self) -> int: pass diff --git a/test-data/unit/lib-stub/collections.pyi b/test-data/unit/lib-stub/collections.pyi index 71f797e565e8..7ea264f764ee 100644 --- a/test-data/unit/lib-stub/collections.pyi +++ b/test-data/unit/lib-stub/collections.pyi @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable, Sized +from typing import Any, Iterable, Union, Dict, TypeVar, Optional, Callable, Generic, Sequence, MutableMapping def namedtuple( typename: str, @@ -20,6 +20,6 @@ class defaultdict(Dict[KT, VT]): class Counter(Dict[KT, int], Generic[KT]): ... -class deque(Sized, Iterable[KT], Reversible[KT], Generic[KT]): ... +class deque(Sequence[KT], Generic[KT]): ... class ChainMap(MutableMapping[KT, VT], Generic[KT, VT]): ... From 74c7040cc4821160862c99e4293e0cb39f642a65 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 10 May 2021 17:40:55 +0200 Subject: [PATCH 38/76] Remove accidentally commited changes --- test-data/unit/fixtures/typing-full.pyi | 5 ----- test-data/unit/lib-stub/array.pyi.bak | 15 --------------- 2 files changed, 20 deletions(-) delete mode 100644 test-data/unit/lib-stub/array.pyi.bak diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index d215f9014019..faf228517691 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -146,11 +146,6 @@ class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): class MutableMapping(Mapping[T, U], metaclass=ABCMeta): def __setitem__(self, k: T, v: U) -> None: pass -@runtime_checkable -class Reversible(Iterable[T_co], Protocol): - @abstractmethod - def __reversed__(self) -> Iterator[T_co]: ... - class SupportsInt(Protocol): def __int__(self) -> int: pass diff --git a/test-data/unit/lib-stub/array.pyi.bak b/test-data/unit/lib-stub/array.pyi.bak deleted file mode 100644 index 74be1cd6c16e..000000000000 --- a/test-data/unit/lib-stub/array.pyi.bak +++ /dev/null @@ -1,15 +0,0 @@ -from typing import MutableSequence, Generic, Literal, Union, TypeVar, Text - - -_IntTypeCode = Literal["b", "B", "h", "H", "i", "I", "l", "L", "q", "Q"] -_FloatTypeCode = Literal["f", "d"] -_UnicodeTypeCode = Literal["u"] -_TypeCode = Union[_IntTypeCode, _FloatTypeCode, _UnicodeTypeCode] - - -_T = TypeVar("_T", int, float, Text) - - -class array(MutableSequence[_T], Generic[_T]): - typecode: _TypeCode - itemsize: int From c9640272bb2fb1de634a324c8d4470d60bd0ff2d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 10 May 2021 17:46:38 +0200 Subject: [PATCH 39/76] Fix failing TypeExport tests --- test-data/unit/fixtures/primitives.pyi | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 8d0dcb00713c..c72838535443 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -57,15 +57,12 @@ class function: pass class ellipsis: pass class range(Sequence[int]): - start: int - stop: int - step: int @overload - def __init__(self, stop: int) -> None: ... + def __init__(self, stop: int) -> None: pass @overload - def __init__(self, start: int, stop: int, step: int = ...) -> None: ... - def count(self, value: int) -> int: ... - def index(self, value: int) -> int: ... - def __getitem__(self, i: int) -> int: ... + def __init__(self, start: int, stop: int, step: int = ...) -> None: pass + def count(self, value: int) -> int: pass + def index(self, value: int) -> int: pass + def __getitem__(self, i: int) -> int: pass def __iter__(self) -> Iterator[int]: pass def __contains__(self, other: object) -> bool: pass From f6b1752d9d3bfd69556dca1ca5ef6027df38f2b7 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 11 May 2021 14:32:06 +0200 Subject: [PATCH 40/76] Improve type inference for mapping pattern rest --- mypy/checkpattern.py | 15 +++++++++++++-- test-data/unit/check-python310.test | 12 +++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index b8be1a4b60ee..b21cb0c10458 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -6,6 +6,7 @@ from mypy.expandtype import expand_type_by_instance from mypy.join import join_types from mypy.literals import literal_hash +from mypy.maptype import map_instance_to_supertype from mypy.messages import MessageBuilder from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr from mypy.patterns import ( @@ -311,10 +312,20 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: self.update_type_map(captures, pattern_type.captures) if o.rest is not None: - # TODO: Infer dict type args - captures[o.rest] = self.chk.named_type("builtins.dict") + mapping = self.chk.named_type("typing.Mapping") + if is_subtype(current_type, mapping) and isinstance(current_type, Instance): + mapping_inst = map_instance_to_supertype(current_type, mapping.type) + dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict") + dict_type = fill_typevars(dict_typeinfo) + rest_type = expand_type_by_instance(dict_type, mapping_inst) + else: + object_type = self.chk.named_type("builtins.object") + rest_type = self.chk.named_generic_type("builtins.dict", [object_type, object_type]) + + captures[o.rest] = rest_type if can_match: + # We can't narrow the type here, as Mapping key is invariant. new_type = self.type_context[-1] # type: Optional[Type] else: new_type = None diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 40698a96ea1a..71b31d01d713 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -413,7 +413,17 @@ m: object match m: case {'k': 1, **r}: - reveal_type(r) # N: Revealed type is "builtins.dict[Any, Any]" + reveal_type(r) # N: Revealed type is "builtins.dict[builtins.object, builtins.object]" +[builtins fixtures/dict.pyi] + +[case testMappingPatternCaptureRestFromMapping] +from typing import Mapping + +m: Mapping[str, int] + +match m: + case {'k': 1, **r}: + reveal_type(r) # N: Revealed type is "builtins.dict[builtins.str*, builtins.int*]" [builtins fixtures/dict.pyi] -- Mapping patterns currently don't narrow -- From d2b7d1e59f98a10d5f73d0ca6b36cdde5c7b8f24 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 11 May 2021 16:35:24 +0200 Subject: [PATCH 41/76] Improve class pattern subpattern type inference --- mypy/checkpattern.py | 41 ++++++++++++++++++----------- test-data/unit/check-python310.test | 30 +++++++++++++++++++++ 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index b21cb0c10458..34be8b27314e 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Dict, NamedTuple, Set, cast import mypy.checker +from mypy.checkmember import analyze_member_access from mypy.expandtype import expand_type_by_instance from mypy.join import join_types from mypy.literals import literal_hash @@ -297,7 +298,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> PatternType: return PatternType(self.type_context[-1], captures) def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: - current_type = self.type_context[-1] + current_type = get_proper_type(self.type_context[-1]) can_match = True captures = {} # type: Dict[Expression, Type] for key, value in zip(o.keys, o.values): @@ -320,7 +321,8 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: rest_type = expand_type_by_instance(dict_type, mapping_inst) else: object_type = self.chk.named_type("builtins.object") - rest_type = self.chk.named_generic_type("builtins.dict", [object_type, object_type]) + rest_type = self.chk.named_generic_type("builtins.dict", + [object_type, object_type]) captures[o.rest] = rest_type @@ -378,21 +380,31 @@ def get_simple_mapping_item_type(self, return result def visit_class_pattern(self, o: ClassPattern) -> PatternType: - current_type = self.type_context[-1] + current_type = get_proper_type(self.type_context[-1]) # # Check class type # - class_name = o.class_ref.fullname - assert class_name is not None - sym = self.chk.lookup_qualified(class_name) - if isinstance(sym.node, TypeAlias) and not sym.node.no_args: + type_info = o.class_ref.node + assert type_info is not None + if isinstance(type_info, TypeAlias) and not type_info.no_args: self.msg.fail("Class pattern class must not be a type alias with type parameters", o) return early_non_match() - if isinstance(sym.node, (TypeAlias, TypeInfo)): - typ = self.chk.named_type(class_name) + if isinstance(type_info, TypeInfo): + any_type = AnyType(TypeOfAny.implementation_artifact) + typ = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) + elif isinstance(type_info, TypeAlias): + typ = type_info.target else: - self.msg.fail('Class pattern must be a type. Found "{}"'.format(sym.type), o.class_ref) + if isinstance(type_info, Var): + name = type_info.type + else: + name = type_info.name + self.msg.fail('Class pattern must be a type. Found "{}"'.format(name), o.class_ref) + return early_non_match() + + new_type = get_more_specific_type(current_type, typ) + if new_type is None: return early_non_match() # @@ -458,9 +470,10 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: can_match = True for keyword, pattern in keyword_pairs: key_type = None # type: Optional[Type] + local_errors = self.msg.clean_copy() if keyword is not None: - key_type = find_member(keyword, typ, current_type) - if key_type is None: + key_type = analyze_member_access(keyword, new_type, pattern, False, False, False, local_errors, original_type=new_type, chk=self.chk) + if local_errors.is_errors() or key_type is None: key_type = AnyType(TypeOfAny.implementation_artifact) pattern_type = self.accept(pattern, key_type) @@ -469,9 +482,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: else: self.update_type_map(captures, pattern_type.captures) - if can_match: - new_type = get_more_specific_type(current_type, typ) - else: + if not can_match: new_type = None return PatternType(new_type, captures) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 71b31d01d713..13c3790a0f45 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -695,6 +695,36 @@ match m: reveal_type(m) # N: Revealed type is "__main__.A" [builtins fixtures/tuple.pyi] +[case testClassPatternNarrowsUnion] +from typing import Final, Union + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int + +class B: + __match_args__: Final = ("a", "b") + a: int + b: str + +m: Union[A, B] + +match m: + case A(): + reveal_type(m) # N: Revealed type is "__main__.A" + case A(i, j): + reveal_type(m) # N: Revealed type is "__main__.A" + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" + case B(): + reveal_type(m) # N: Revealed type is "__main__.B" + case B(i, j): + reveal_type(m) # N: Revealed type is "__main__.B" + reveal_type(i) # N: Revealed type is "builtins.int" + reveal_type(j) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] + [case testClassPatternAlreadyNarrower] from typing import Final From 5c478527735d1ba0a77d4e3a04ae600382791ca4 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 18 May 2021 12:29:58 +0200 Subject: [PATCH 42/76] Fix selftest failing --- mypy/checkpattern.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 34be8b27314e..613ddd2085e7 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -15,7 +15,7 @@ ClassPattern, SingletonPattern ) from mypy.plugin import Plugin -from mypy.subtypes import is_subtype, find_member +from mypy.subtypes import is_subtype from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, @@ -392,12 +392,12 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: return early_non_match() if isinstance(type_info, TypeInfo): any_type = AnyType(TypeOfAny.implementation_artifact) - typ = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) + typ = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) # type: Type elif isinstance(type_info, TypeAlias): typ = type_info.target else: if isinstance(type_info, Var): - name = type_info.type + name = str(type_info.type) else: name = type_info.name self.msg.fail('Class pattern must be a type. Found "{}"'.format(name), o.class_ref) @@ -424,9 +424,14 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: return pattern_type captures = pattern_type.captures else: - match_args_type = find_member("__match_args__", typ, typ) + local_errors = self.msg.clean_copy() + match_args_type = analyze_member_access("__match_args__", typ, o, + False, False, False, + local_errors, + original_type=typ, + chk=self.chk) - if match_args_type is None: + if local_errors.is_errors(): self.msg.fail("Class doesn't define __match_args__", o) return early_non_match() @@ -472,7 +477,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: key_type = None # type: Optional[Type] local_errors = self.msg.clean_copy() if keyword is not None: - key_type = analyze_member_access(keyword, new_type, pattern, False, False, False, local_errors, original_type=new_type, chk=self.chk) + key_type = analyze_member_access(keyword, + new_type, + pattern, + False, + False, + False, + local_errors, + original_type=new_type, + chk=self.chk) if local_errors.is_errors() or key_type is None: key_type = AnyType(TypeOfAny.implementation_artifact) @@ -486,7 +499,8 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: new_type = None return PatternType(new_type, captures) - def should_self_match(self, typ: ProperType) -> bool: + def should_self_match(self, typ: Type) -> bool: + typ = get_proper_type(typ) if isinstance(typ, Instance) and typ.type.is_named_tuple: return False for other in self.self_match_types: From 0ac21d03c608a3a0cffe8572722f660eea5d8eb1 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 18 May 2021 12:30:10 +0200 Subject: [PATCH 43/76] Fix mypyc build failing --- mypyc/irbuild/visitor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/visitor.py b/mypyc/irbuild/visitor.py index 67b8f04a7dc2..fb6d92ada368 100644 --- a/mypyc/irbuild/visitor.py +++ b/mypyc/irbuild/visitor.py @@ -16,7 +16,8 @@ FloatExpr, GeneratorExpr, GlobalDecl, LambdaExpr, ListComprehension, SetComprehension, NamedTupleExpr, NewTypeExpr, NonlocalDecl, OverloadedFuncDef, PrintStmt, RaiseStmt, RevealExpr, SetExpr, SliceExpr, StarExpr, SuperExpr, TryStmt, TypeAliasExpr, TypeApplication, - TypeVarExpr, TypedDictExpr, UnicodeExpr, WithStmt, YieldFromExpr, YieldExpr, ParamSpecExpr + TypeVarExpr, TypedDictExpr, UnicodeExpr, WithStmt, YieldFromExpr, YieldExpr, ParamSpecExpr, + MatchStmt ) from mypyc.ir.ops import Value @@ -179,6 +180,9 @@ def visit_nonlocal_decl(self, stmt: NonlocalDecl) -> None: # Pure declaration -- no runtime effect pass + def visit_match_stmt(self, stmt: MatchStmt) -> None: + self.bail("Match statements are not yet supported", stmt.line) + # Expressions def visit_name_expr(self, expr: NameExpr) -> Value: From 5f98cdd4cf3fd68ad091c1f9bdb895170977b386 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 18 May 2021 14:13:10 +0200 Subject: [PATCH 44/76] Fix another mypyc build error --- mypy/checkpattern.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 613ddd2085e7..bf250e31ccae 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,6 +1,6 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Tuple, Dict, NamedTuple, Set, cast +from typing import List, Optional, Tuple, Dict, NamedTuple, Set, cast, Union import mypy.checker from mypy.checkmember import analyze_member_access @@ -163,7 +163,7 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType: return PatternType(specific_typ, {}) def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: - value = o.value + value = o.value # type: Union[bool, None] if isinstance(value, bool): typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool") elif value is None: From f2e12f8cd7eb7cf68bb8dee67bedd642c994bb6c Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 27 Sep 2021 15:25:09 +0200 Subject: [PATCH 45/76] Fix failing tests after merge (Caused by #8578) --- test-data/unit/check-python310.test | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 13c3790a0f45..686c92bc749b 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -545,6 +545,7 @@ match m: case A(i, j): reveal_type(i) # N: Revealed type is "builtins.str" reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/dataclasses.pyi] [case testClassPatternCaptureDataclassNoMatchArgs] from dataclasses import dataclass @@ -559,6 +560,7 @@ m: A match m: case A(i, j): # E: Class doesn't define __match_args__ pass +[builtins fixtures/dataclasses.pyi] [case testClassPatternCaptureDataclassPartialMatchArgs] from dataclasses import dataclass, field @@ -575,6 +577,7 @@ match m: pass case A(k): reveal_type(k) # N: Revealed type is "builtins.str" +[builtins fixtures/dataclasses.pyi] [case testClassPatternCaptureNamedTupleInline] from collections import namedtuple From a5eb48085722a282a8e22eda279202c546783c3d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 27 Sep 2021 15:43:44 +0200 Subject: [PATCH 46/76] Fixed more failing tests after merge (Caused by #10685) --- mypy/renaming.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mypy/renaming.py b/mypy/renaming.py index 214eb9ad165c..c200e94d58e7 100644 --- a/mypy/renaming.py +++ b/mypy/renaming.py @@ -162,15 +162,14 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: def visit_match_stmt(self, s: MatchStmt) -> None: for i in range(len(s.patterns)): - self.enter_block() - s.patterns[i].accept(self) - guard = s.guards[i] - if guard is not None: - guard.accept(self) - # We already entered a block, so visit this block's statements directly - for stmt in s.bodies[i].body: - stmt.accept(self) - self.leave_block() + with self.enter_block(): + s.patterns[i].accept(self) + guard = s.guards[i] + if guard is not None: + guard.accept(self) + # We already entered a block, so visit this block's statements directly + for stmt in s.bodies[i].body: + stmt.accept(self) def visit_capture_pattern(self, p: AsPattern) -> None: if p.name is not None: From 9d5d600114208dd7985a15f0aa92444946a2a82b Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 28 Sep 2021 16:54:04 +0200 Subject: [PATCH 47/76] Adjust some tests and add new ones --- test-data/unit/check-python310.test | 71 +++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 686c92bc749b..12484725826e 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -153,22 +153,22 @@ match m: import array, collections from typing import Sequence, Iterable -m1: Sequence[int] -m2: array.array[int] -m3: collections.deque[int] -m4: list[int] -m5: memoryview -m6: range -m7: tuple[int] - -m8: str -m9: bytes -m10: bytearray -m11: Iterable[int] +m1: object +m2: Sequence[int] +m3: array.array[int] +m4: collections.deque[int] +m5: list[int] +m6: memoryview +m7: range +m8: tuple[int] + +m9: str +m10: bytes +m11: bytearray match m1: case [a]: - reveal_type(a) # N: Revealed type is "builtins.int*" + reveal_type(a) # N: Revealed type is "builtins.object" match m2: case [b]: @@ -192,11 +192,11 @@ match m6: match m7: case [g]: - reveal_type(g) # N: Revealed type is "builtins.int" + reveal_type(g) # N: Revealed type is "builtins.int*" match m8: case [h]: - reveal_type(h) + reveal_type(h) # N: Revealed type is "builtins.int" match m9: case [i]: @@ -716,16 +716,22 @@ m: Union[A, B] match m: case A(): reveal_type(m) # N: Revealed type is "__main__.A" + +match m: case A(i, j): reveal_type(m) # N: Revealed type is "__main__.A" reveal_type(i) # N: Revealed type is "builtins.str" reveal_type(j) # N: Revealed type is "builtins.int" + +match m: case B(): reveal_type(m) # N: Revealed type is "__main__.B" - case B(i, j): + +match m: + case B(k, l): reveal_type(m) # N: Revealed type is "__main__.B" - reveal_type(i) # N: Revealed type is "builtins.int" - reveal_type(j) # N: Revealed type is "builtins.str" + reveal_type(k) # N: Revealed type is "builtins.int" + reveal_type(l) # N: Revealed type is "builtins.str" [builtins fixtures/tuple.pyi] [case testClassPatternAlreadyNarrower] @@ -742,6 +748,8 @@ m: B match m: case A(): reveal_type(m) # N: Revealed type is "__main__.B" + +match m: case A(i, j): reveal_type(m) # N: Revealed type is "__main__.B" [builtins fixtures/tuple.pyi] @@ -1012,3 +1020,30 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" case int(a): # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") reveal_type(a) # N: Revealed type is "builtins.str" + + +-- Exhaustiveness -- +[case testUnionNegativeNarrowing] +from typing import Union + +m: Union[str, int] + +match m: + case str(a): + reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(m) # N: Revealed type is "builtins.str" + case b: + reveal_type(b) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "builtins.int" + +[case testOrPatternNegativeNarrowing] +from typing import Union + +m: Union[str, bytes, int] + +match m: + case str(a) | bytes(a): + reveal_type(a) # N: Revealed type is "builtins.object" + reveal_type(m) # N: Revealed type is "Union[builtins.str, builtins.bytes]" + case b: + reveal_type(b) # N: Revealed type is "builtins.int" From a68ea5983d7949db7453af322a27d8dff61ccd2b Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 29 Sep 2021 15:14:17 +0200 Subject: [PATCH 48/76] Fixed types staying narrowed after match statement --- mypy/checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 00c5c9ca1037..40931a182983 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3912,6 +3912,10 @@ def visit_match_stmt(self, s: MatchStmt) -> None: self.push_type_map(if_map) self.accept(b) + # This is needed due to a quirk in frame_context. Without it types will stay narrowed after the match. + with self.binder.frame_context(can_skip=False, fall_through=2): + pass + def infer_names_from_type_maps(self, type_maps: List[TypeMap]) -> None: all_captures = defaultdict(list) # type: Dict[Var, List[Tuple[NameExpr, Type]]] for tm in type_maps: From 43fd6e5df36c0df8e18e5cbe808613c011a35be5 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 29 Sep 2021 19:33:23 +0200 Subject: [PATCH 49/76] Run com2ann and fix errors --- mypy/checker.py | 9 +++--- mypy/checkpattern.py | 56 ++++++++++++++++++------------------- mypy/nodes.py | 8 +++--- mypy/patterns.py | 28 +++++++++---------- mypy/plugins/dataclasses.py | 4 +-- mypy/semanal_namedtuple.py | 2 +- mypy/strconv.py | 6 ++-- 7 files changed, 56 insertions(+), 57 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 40931a182983..33496a7fc0e2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3891,7 +3891,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] - type_maps = [t.captures for t in pattern_types] # type: List[TypeMap] + type_maps: List[TypeMap] = [t.captures for t in pattern_types] self.infer_names_from_type_maps(type_maps) for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): @@ -3912,12 +3912,13 @@ def visit_match_stmt(self, s: MatchStmt) -> None: self.push_type_map(if_map) self.accept(b) - # This is needed due to a quirk in frame_context. Without it types will stay narrowed after the match. + # This is needed due to a quirk in frame_context. Without it types will stay narrowed + # after the match. with self.binder.frame_context(can_skip=False, fall_through=2): pass def infer_names_from_type_maps(self, type_maps: List[TypeMap]) -> None: - all_captures = defaultdict(list) # type: Dict[Var, List[Tuple[NameExpr, Type]]] + all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list) for tm in type_maps: if tm is not None: for expr, typ in tm.items(): @@ -3928,7 +3929,7 @@ def infer_names_from_type_maps(self, type_maps: List[TypeMap]) -> None: for var, captures in all_captures.items(): conflict = False - types = [] # type: List[Type] + types: List[Type] = [] for expr, typ in captures: types.append(typ) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index bf250e31ccae..d2643750ee75 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -61,21 +61,21 @@ class PatternChecker(PatternVisitor[PatternType]): """ # Some services are provided by a TypeChecker instance. - chk = None # type: mypy.checker.TypeChecker + chk: 'mypy.checker.TypeChecker' # This is shared with TypeChecker, but stored also here for convenience. - msg = None # type: MessageBuilder + msg: MessageBuilder # Currently unused - plugin = None # type: Plugin + plugin: Plugin # The expression being matched against the pattern - subject = None # type: Expression + subject: Expression - subject_type = None # type: Type + subject_type: Type # Type of the subject to check the (sub)pattern against - type_context = None # type: List[Type] + type_context: List[Type] - self_match_types = None # type: List[Type] + self_match_types: List[Type] - non_sequence_match_types = None # type: List[Type] + non_sequence_match_types: List[Type] def __init__(self, chk: 'mypy.checker.TypeChecker', @@ -131,7 +131,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: # # Check the capture types # - capture_types = defaultdict(list) # type: Dict[Var, List[Tuple[Expression, Type]]] + capture_types: Dict[Var, List[Tuple[Expression, Type]]] = defaultdict(list) # Collect captures from the first subpattern for expr, typ in pattern_types[0].captures.items(): node = get_var(expr) @@ -146,7 +146,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: node = get_var(expr) capture_types[node].append((expr, typ)) - captures = {} # type: Dict[Expression, Type] + captures: Dict[Expression, Type] = {} for var, capture_list in capture_types.items(): typ = UninhabitedType() for _, other in capture_list: @@ -163,7 +163,7 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType: return PatternType(specific_typ, {}) def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: - value = o.value # type: Union[bool, None] + value: Union[bool, None] = o.value if isinstance(value, bool): typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool") elif value is None: @@ -182,7 +182,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if not self.can_match_sequence(current_type): return early_non_match() star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)] - star_position = None # type: Optional[int] + star_position: Optional[int] = None if len(star_positions) == 1: star_position = star_positions[0] elif len(star_positions) >= 2: @@ -210,8 +210,8 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # # match inner patterns # - contracted_new_inner_types = [] # type: List[Type] - captures = {} # type: Dict[Expression, Type] + contracted_new_inner_types: List[Type] = [] + captures: Dict[Expression, Type] = {} contracted_inner_types = self.contract_starred_pattern_types(inner_types, star_position, @@ -233,7 +233,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # Calculate new type # if not can_match: - new_type = None # type: Optional[Type] + new_type: Optional[Type] elif isinstance(current_type, TupleType): specific_inner_types = [] for inner_type, new_inner_type in zip(inner_types, new_inner_types): @@ -291,7 +291,7 @@ def expand_starred_pattern_types(self, return new_types def visit_starred_pattern(self, o: StarredPattern) -> PatternType: - captures = {} # type: Dict[Expression, Type] + captures: Dict[Expression, Type] = {} if o.capture is not None: list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]]) captures[o.capture] = list_type @@ -300,7 +300,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> PatternType: def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: current_type = get_proper_type(self.type_context[-1]) can_match = True - captures = {} # type: Dict[Expression, Type] + captures: Dict[Expression, Type] = {} for key, value in zip(o.keys, o.values): inner_type = self.get_mapping_item_type(o, current_type, key) if inner_type is None: @@ -328,7 +328,7 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: if can_match: # We can't narrow the type here, as Mapping key is invariant. - new_type = self.type_context[-1] # type: Optional[Type] + new_type: Optional[Type] = self.type_context[-1] else: new_type = None return PatternType(new_type, captures) @@ -342,10 +342,8 @@ def get_mapping_item_type(self, local_errors.disable_count = 0 mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): - result = self.chk.expr_checker.visit_typeddict_index_expr(mapping_type, - key, - local_errors=local_errors - ) # type: Optional[Type] + result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( + mapping_type, key, local_errors=local_errors) # If we can't determine the type statically fall back to treating it as a normal # mapping if local_errors.is_errors(): @@ -392,7 +390,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: return early_non_match() if isinstance(type_info, TypeInfo): any_type = AnyType(TypeOfAny.implementation_artifact) - typ = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) # type: Type + typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) elif isinstance(type_info, TypeAlias): typ = type_info.target else: @@ -410,10 +408,10 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: # # Convert positional to keyword patterns # - keyword_pairs = [] # type: List[Tuple[Optional[str], Pattern]] - match_arg_set = set() # type: Set[str] + keyword_pairs: List[Tuple[Optional[str], Pattern]] = [] + match_arg_set: Set[str] = set() - captures = {} # type: Dict[Expression, Type] + captures: Dict[Expression, Type] = {} if len(o.positionals) != 0: if self.should_self_match(typ): @@ -474,7 +472,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: # can_match = True for keyword, pattern in keyword_pairs: - key_type = None # type: Optional[Type] + key_type: Optional[Type] = None local_errors = self.msg.clean_copy() if keyword is not None: key_type = analyze_member_access(keyword, @@ -519,7 +517,7 @@ def can_match_sequence(self, typ: ProperType) -> bool: return is_subtype(typ, sequence) or is_subtype(sequence, typ) def generate_types(self, type_names: List[str]) -> List[Type]: - types = [] # type: List[Type] + types: List[Type] = [] for name in type_names: try: types.append(self.chk.named_type(name)) @@ -559,7 +557,7 @@ def construct_iterable_child(self, outer_type: Type, inner_type: Type) -> Type: def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: - args = [] # type: List[Optional[str]] + args: List[Optional[str]] = [] for item in typ.items: values = try_getting_str_literals_from_type(item) if values is None or len(values) != 1: diff --git a/mypy/nodes.py b/mypy/nodes.py index 89ccfd6aaff3..4a2e43fe698f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1282,10 +1282,10 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class MatchStmt(Statement): - subject = None # type: Expression - patterns = None # type: List['Pattern'] - guards = None # type: List[Optional[Expression]] - bodies = None # type: List[Block] + subject: Expression + patterns: List['Pattern'] + guards: List[Optional[Expression]] + bodies: List[Block] def __init__(self, subject: Expression, patterns: List['Pattern'], guards: List[Optional[Expression]], bodies: List[Block]) -> None: diff --git a/mypy/patterns.py b/mypy/patterns.py index 28b270f95077..79d394f7f629 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -29,8 +29,8 @@ class AlwaysTruePattern(Pattern): class AsPattern(Pattern): - pattern = None # type: Optional[Pattern] - name = None # type: Optional[NameExpr] + pattern: Optional[Pattern] + name: Optional[NameExpr] def __init__(self, pattern: Optional[Pattern], name: Optional[NameExpr]) -> None: super().__init__() @@ -42,7 +42,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class OrPattern(Pattern): - patterns = None # type: List[Pattern] + patterns: List[Pattern] def __init__(self, patterns: List[Pattern]) -> None: super().__init__() @@ -53,7 +53,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class ValuePattern(Pattern): - expr = None # type: Expression + expr: Expression def __init__(self, expr: Expression): super().__init__() @@ -64,7 +64,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class SingletonPattern(Pattern): - value = None # type: Union[bool, None] + value: Union[bool, None] def __init__(self, value: Union[bool, None]): super().__init__() @@ -75,7 +75,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class SequencePattern(Pattern): - patterns = None # type: List[Pattern] + patterns: List[Pattern] def __init__(self, patterns: List[Pattern]): super().__init__() @@ -88,7 +88,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: # TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our # type hierarchy. Should it be? class StarredPattern(Pattern): - capture = None # type: Optional[NameExpr] + capture: Optional[NameExpr] def __init__(self, capture: Optional[NameExpr]): super().__init__() @@ -99,9 +99,9 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class MappingPattern(Pattern): - keys = None # type: List[Expression] - values = None # type: List[Pattern] - rest = None # type: Optional[NameExpr] + keys: List[Expression] + values: List[Pattern] + rest: Optional[NameExpr] def __init__(self, keys: List[Expression], values: List[Pattern], rest: Optional[NameExpr]): @@ -115,10 +115,10 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class ClassPattern(Pattern): - class_ref = None # type: RefExpr - positionals = None # type: List[Pattern] - keyword_keys = None # type: List[str] - keyword_values = None # type: List[Pattern] + class_ref: RefExpr + positionals: List[Pattern] + keyword_keys: List[str] + keyword_values: List[Pattern] def __init__(self, class_ref: RefExpr, positionals: List[Pattern], keyword_keys: List[str], keyword_values: List[Pattern]): diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index e838b721c4ba..c5b7793e3838 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -196,8 +196,8 @@ def transform(self) -> None: info.names['__match_args__'].plugin_generated) and attributes): str_type = ctx.api.named_type("__builtins__.str") - literals = [LiteralType(attr.name, str_type) - for attr in attributes if attr.is_in_init] # type: List[Type] + literals: List[Type] = [LiteralType(attr.name, str_type) + for attr in attributes if attr.is_in_init] match_args_type = TupleType(literals, ctx.api.named_type("__builtins__.tuple")) add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index e974409fe2da..6c728690e66b 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -398,7 +398,7 @@ def build_namedtuple_typeinfo(self, iterable_type = self.api.named_type_or_none('typing.Iterable', [implicit_any]) function_type = self.api.named_type('__builtins__.function') - literals = [LiteralType(item, strtype) for item in items] # type: List[Type] + literals: List[Type] = [LiteralType(item, strtype) for item in items] match_args_type = TupleType(literals, basetuple_type) info = self.api.basic_new_typeinfo(name, fallback, line) diff --git a/mypy/strconv.py b/mypy/strconv.py index 118ad6f2b17b..22534a44971d 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -316,7 +316,7 @@ def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> str: return self.dump([o.expr, o.globals, o.locals], o) def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> str: - a = [o.subject] # type: List[Any] + a: List[Any] = [o.subject] for i in range(len(o.patterns)): a.append(('Pattern', [o.patterns[i]])) if o.guards[i] is not None: @@ -569,7 +569,7 @@ def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> str: return self.dump([o.capture], o) def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> str: - a = [] # type: List[Any] + a: List[Any] = [] for i in range(len(o.keys)): a.append(('Key', [o.keys[i]])) a.append(('Value', [o.values[i]])) @@ -578,7 +578,7 @@ def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> str: return self.dump(a, o) def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> str: - a = [o.class_ref] # type: List[Any] + a: List[Any] = [o.class_ref] if len(o.positionals) > 0: a.append(('Positionals', o.positionals)) for i in range(len(o.keyword_keys)): From d7baadf0f94741fac922f1a1fb150564d4593aaa Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 29 Sep 2021 21:27:14 +0200 Subject: [PATCH 50/76] Split conditional_type_maps into two functions --- mypy/checker.py | 59 +++++++++++++++++++++++++++++------------------- mypy/subtypes.py | 2 +- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 33496a7fc0e2..90cce47eefd1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4885,10 +4885,11 @@ def refine_identity_comparison_expression(self, if enum_name is not None: expr_type = try_expanding_enum_to_union(expr_type, enum_name) - # We intentionally use 'conditional_type_map' directly here instead of + # We intentionally use 'conditional_types' directly here instead of # 'self.conditional_type_map_with_intersection': we only compute ad-hoc # intersections when working with pure instances. - partial_type_maps.append(conditional_type_map(expr, expr_type, target_type)) + types = conditional_types(expr_type, target_type) + partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) return reduce_conditional_maps(partial_type_maps) @@ -5315,11 +5316,8 @@ def conditional_type_map_with_intersection(self, expr_type: Type, type_ranges: Optional[List[TypeRange]], ) -> Tuple[TypeMap, TypeMap]: - # For some reason, doing "yes_map, no_map = conditional_type_maps(...)" - # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. - initial_maps = conditional_type_map(expr, expr_type, type_ranges) - yes_map: TypeMap = initial_maps[0] - no_map: TypeMap = initial_maps[1] + initial_types = conditional_types(expr_type, type_ranges) + yes_map, no_map = conditional_types_to_typemaps(expr, *initial_types) if yes_map is not None or type_ranges is None: return yes_map, no_map @@ -5366,17 +5364,14 @@ def is_writable_attribute(self, node: Node) -> bool: return False -def conditional_type_map(expr: Expression, - current_type: Optional[Type], - proposed_type_ranges: Optional[List[TypeRange]], - ) -> Tuple[TypeMap, TypeMap]: - """Takes in an expression, the current type of the expression, and a - proposed type of that expression. +def conditional_types(current_type: Optional[Type], + proposed_type_ranges: Optional[List[TypeRange]], + ) -> Tuple[Optional[Type], Optional[Type]]: + """Takes in the current type and a proposed type of an expression. - Returns a 2-tuple: The first element is a map from the expression to - the proposed type, if the expression can be the proposed type. The - second element is a map from the expression to the type it would hold - if it was not the proposed type, if any. None means bot, {} means top""" + Returns a 2-tuple: The first element is the proposed type, if the expression + can be the proposed type. The second element is the type it would hold + if it was not the proposed type, if any. None means top, UninhabitedType means bot""" if proposed_type_ranges: proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) @@ -5385,27 +5380,45 @@ def conditional_type_map(expr: Expression, # We don't really know much about the proposed type, so we shouldn't # attempt to narrow anything. Instead, we broaden the expr to Any to # avoid false positives - return {expr: proposed_type}, {} + return proposed_type, None elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges - return {}, None + return None, UninhabitedType() elif not is_overlapping_types(current_type, proposed_type, prohibit_none_typevar_overlap=True): # Expression is never of any type in proposed_type_ranges - return None, {} + return UninhabitedType(), None else: # we can only restrict when the type is precise, not bounded proposed_precise_type = UnionType([type_range.item for type_range in proposed_type_ranges if not type_range.is_upper_bound]) remaining_type = restrict_subtype_away(current_type, proposed_precise_type) - return {expr: proposed_type}, {expr: remaining_type} + return proposed_type, remaining_type else: - return {expr: proposed_type}, {} + return proposed_type, None else: # An isinstance check, but we don't understand the type - return {}, {} + return None, None + + +def conditional_types_to_typemaps(expr: Expression, + yes_type: Optional[Type], + no_type: Optional[Type] + ) -> Tuple[TypeMap, TypeMap]: + maps: List[TypeMap] = [] + for typ in (yes_type, no_type): + proper_type = get_proper_type(typ) + if isinstance(proper_type, UninhabitedType): + maps.append(None) + elif proper_type is None: + maps.append({}) + else: + assert typ is not None # If proper_type is not None type is neither. + maps.append({expr: typ}) + + return cast(Tuple[TypeMap, TypeMap], tuple(maps)) def gen_unique_name(base: str, table: SymbolTable) -> str: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 63cebc8aa483..dcd06a5ba665 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1396,7 +1396,7 @@ def visit_type_type(self, left: TypeType) -> bool: if right.type.fullname == 'builtins.type': # TODO: Strictly speaking, the type builtins.type is considered equivalent to # Type[Any]. However, this would break the is_proper_subtype check in - # conditional_type_map for cases like isinstance(x, type) when the type + # conditional_types for cases like isinstance(x, type) when the type # of x is Type[int]. It's unclear what's the right way to address this. return True if right.type.fullname == 'builtins.object': From 8a3ba8eead5b23a549fb4762383acba278995599 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 29 Sep 2021 21:38:51 +0200 Subject: [PATCH 51/76] Readd removed workaround as the problem is still present --- mypy/checker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 90cce47eefd1..9e31af3fd9e6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5317,7 +5317,11 @@ def conditional_type_map_with_intersection(self, type_ranges: Optional[List[TypeRange]], ) -> Tuple[TypeMap, TypeMap]: initial_types = conditional_types(expr_type, type_ranges) - yes_map, no_map = conditional_types_to_typemaps(expr, *initial_types) + # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" + # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. + initial_maps = conditional_types_to_typemaps(expr, *initial_types) + yes_map: TypeMap = initial_maps[0] + no_map: TypeMap = initial_maps[1] if yes_map is not None or type_ranges is None: return yes_map, no_map @@ -5415,7 +5419,7 @@ def conditional_types_to_typemaps(expr: Expression, elif proper_type is None: maps.append({}) else: - assert typ is not None # If proper_type is not None type is neither. + assert typ is not None # If proper_type is not None type is neighter. maps.append({expr: typ}) return cast(Tuple[TypeMap, TypeMap], tuple(maps)) From 90aa1a79dbbb973059ee94ff14cd1f64b2392851 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 4 Oct 2021 11:57:46 +0200 Subject: [PATCH 52/76] Disable Exhaustiveness test for now --- test-data/unit/check-python310.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 12484725826e..4335d2098f6a 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1023,7 +1023,7 @@ match m: -- Exhaustiveness -- -[case testUnionNegativeNarrowing] +[case testUnionNegativeNarrowing-skip] from typing import Union m: Union[str, int] @@ -1036,7 +1036,7 @@ match m: reveal_type(b) # N: Revealed type is "builtins.int" reveal_type(m) # N: Revealed type is "builtins.int" -[case testOrPatternNegativeNarrowing] +[case testOrPatternNegativeNarrowing-skip] from typing import Union m: Union[str, bytes, int] From de723e143d017ec346ffb4f60732aca42a5a1b87 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 6 Oct 2021 21:25:45 +0200 Subject: [PATCH 53/76] Fixed checking of nested list statements failing --- mypy/checkpattern.py | 2 +- test-data/unit/check-python310.test | 11 +++++++++++ test-data/unit/fixtures/dict.pyi | 5 ++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index d2643750ee75..15ad5469e8db 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -417,7 +417,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if self.should_self_match(typ): if len(o.positionals) > 1: self.msg.fail("Too many positional patterns for class pattern", o) - pattern_type = self.accept(o.positionals[0], typ) + pattern_type = self.accept(o.positionals[0], new_type) if pattern_type.type is None: return pattern_type captures = pattern_type.captures diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 4335d2098f6a..a2fc75bc4bcc 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -139,6 +139,17 @@ match m: case [[1], [True]]: reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" +[case testMultipleNestedSequencePattern] +# From cpython test_patma.py +x = [[{0: 0}]] +match x: + case list([({-0-0j: int(real=0+0j, imag=0-0j) | (1) as z},)]): + y = 0 + +reveal_type(x) # N: Revealed type is "builtins.list[builtins.list*[builtins.dict*[builtins.int*, builtins.int*]]]" +reveal_type(y) # N: Revealed type is "builtins.int" +reveal_type(z) # N: Revealed type is "builtins.int*" +[builtins fixtures/dict.pyi] [case testSequencePatternDoesntNarrowInvariant] from typing import List diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index ab8127badd4c..476b00840d13 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -31,8 +31,11 @@ class dict(Mapping[KT, VT]): def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass def __len__(self) -> int: ... +class complex: pass class int: # for convenience - def __add__(self, x: int) -> int: pass + def __add__(self, x: Union[int, complex]) -> int: pass + def __sub__(self, x: Union[int, complex]) -> int: pass + def __neg__(self): pass class str: pass # for keyword argument key type class unicode: pass # needed for py2 docstrings From 3fb2cf21d6cf9a170a2a048ef9a60b662df3b1f1 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 6 Oct 2021 21:28:07 +0200 Subject: [PATCH 54/76] Fixed mypy crashing when list statement can't match --- mypy/checkpattern.py | 3 ++- test-data/unit/check-python310.test | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 15ad5469e8db..592976a6d389 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -232,8 +232,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # # Calculate new type # + new_type: Optional[Type] if not can_match: - new_type: Optional[Type] + new_type = None elif isinstance(current_type, TupleType): specific_inner_types = [] for inner_type, new_inner_type in zip(inner_types, new_inner_types): diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index a2fc75bc4bcc..b6b3f57f12da 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -303,6 +303,14 @@ match m: reveal_type(c) [builtins fixtures/list.pyi] +[case testNonMatchingSequencePattern] +from typing import List + +x: List[int] +match x: + case [str()]: + pass + -- Mapping Pattern -- [case testMappingPatternCaptures] from typing import Dict From ab7a2e3bad350b780969f1078c27ad8a88dc3c82 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 6 Oct 2021 22:02:31 +0200 Subject: [PATCH 55/76] Moved test to correct category I first misidentified what the actual problem causing the bug was. This test should be here. --- test-data/unit/check-python310.test | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index b6b3f57f12da..150a99050ad3 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -139,18 +139,6 @@ match m: case [[1], [True]]: reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" -[case testMultipleNestedSequencePattern] -# From cpython test_patma.py -x = [[{0: 0}]] -match x: - case list([({-0-0j: int(real=0+0j, imag=0-0j) | (1) as z},)]): - y = 0 - -reveal_type(x) # N: Revealed type is "builtins.list[builtins.list*[builtins.dict*[builtins.int*, builtins.int*]]]" -reveal_type(y) # N: Revealed type is "builtins.int" -reveal_type(z) # N: Revealed type is "builtins.int*" -[builtins fixtures/dict.pyi] - [case testSequencePatternDoesntNarrowInvariant] from typing import List m: List[object] @@ -849,6 +837,18 @@ match m: reveal_type(i) reveal_type(j) +[case testClassPatternNestedGenerics] +# From cpython test_patma.py +x = [[{0: 0}]] +match x: + case list([({-0-0j: int(real=0+0j, imag=0-0j) | (1) as z},)]): + y = 0 + +reveal_type(x) # N: Revealed type is "builtins.list[builtins.list*[builtins.dict*[builtins.int*, builtins.int*]]]" +reveal_type(y) # N: Revealed type is "builtins.int" +reveal_type(z) # N: Revealed type is "builtins.int*" +[builtins fixtures/dict.pyi] + [case testNonFinalMatchArgs] class A: __match_args__ = ("a", "b") # N: __match_args__ must be final for checking of match statements to work From 45e2abe142d8c2d4cb9afa6ba3a46b4f86cd9514 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 30 Sep 2021 12:34:40 +0200 Subject: [PATCH 56/76] Simplify conditional_types --- mypy/checker.py | 61 +++++++++++++++++++++++-------------------------- mypy/typeops.py | 2 +- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9e31af3fd9e6..82a80c6f9c19 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4889,7 +4889,7 @@ def refine_identity_comparison_expression(self, # 'self.conditional_type_map_with_intersection': we only compute ad-hoc # intersections when working with pure instances. types = conditional_types(expr_type, target_type) - partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) + partial_type_maps.append(conditional_types_to_typemaps(expr, expr_type, *types)) return reduce_conditional_maps(partial_type_maps) @@ -5319,7 +5319,7 @@ def conditional_type_map_with_intersection(self, initial_types = conditional_types(expr_type, type_ranges) # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. - initial_maps = conditional_types_to_typemaps(expr, *initial_types) + initial_maps = conditional_types_to_typemaps(expr, expr_type, *initial_types) yes_map: TypeMap = initial_maps[0] no_map: TypeMap = initial_maps[1] @@ -5368,58 +5368,55 @@ def is_writable_attribute(self, node: Node) -> bool: return False -def conditional_types(current_type: Optional[Type], +def conditional_types(current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], - ) -> Tuple[Optional[Type], Optional[Type]]: + ) -> Tuple[Type, Type]: """Takes in the current type and a proposed type of an expression. Returns a 2-tuple: The first element is the proposed type, if the expression can be the proposed type. The second element is the type it would hold - if it was not the proposed type, if any. None means top, UninhabitedType means bot""" + if it was not the proposed type, if any. UninhabitedType means unreachable""" if proposed_type_ranges: proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) - if current_type: - if isinstance(proposed_type, AnyType): - # We don't really know much about the proposed type, so we shouldn't - # attempt to narrow anything. Instead, we broaden the expr to Any to - # avoid false positives - return proposed_type, None - elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) - and is_proper_subtype(current_type, proposed_type)): - # Expression is always of one of the types in proposed_type_ranges - return None, UninhabitedType() - elif not is_overlapping_types(current_type, proposed_type, - prohibit_none_typevar_overlap=True): - # Expression is never of any type in proposed_type_ranges - return UninhabitedType(), None - else: - # we can only restrict when the type is precise, not bounded - proposed_precise_type = UnionType([type_range.item - for type_range in proposed_type_ranges - if not type_range.is_upper_bound]) - remaining_type = restrict_subtype_away(current_type, proposed_precise_type) - return proposed_type, remaining_type + if isinstance(proposed_type, AnyType): + # We don't really know much about the proposed type, so we shouldn't + # attempt to narrow anything. Instead, we broaden the expr to Any to + # avoid false positives + return proposed_type, current_type + elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) + and is_proper_subtype(current_type, proposed_type)): + # Expression is always of one of the types in proposed_type_ranges + return current_type, UninhabitedType() + elif not is_overlapping_types(current_type, proposed_type, + prohibit_none_typevar_overlap=True): + # Expression is never of any type in proposed_type_ranges + return UninhabitedType(), current_type else: - return proposed_type, None + # we can only restrict when the type is precise, not bounded + proposed_precise_type = UnionType([type_range.item + for type_range in proposed_type_ranges + if not type_range.is_upper_bound]) + remaining_type = restrict_subtype_away(current_type, proposed_precise_type) + return proposed_type, remaining_type else: # An isinstance check, but we don't understand the type - return None, None + return current_type, current_type def conditional_types_to_typemaps(expr: Expression, - yes_type: Optional[Type], - no_type: Optional[Type] + expr_type: Type, + yes_type: Type, + no_type: Type ) -> Tuple[TypeMap, TypeMap]: maps: List[TypeMap] = [] for typ in (yes_type, no_type): proper_type = get_proper_type(typ) if isinstance(proper_type, UninhabitedType): maps.append(None) - elif proper_type is None: + elif proper_type == get_proper_type(expr_type): maps.append({}) else: - assert typ is not None # If proper_type is not None type is neighter. maps.append({expr: typ}) return cast(Tuple[TypeMap, TypeMap], tuple(maps)) diff --git a/mypy/typeops.py b/mypy/typeops.py index fc0dee538d79..cfe236a54776 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -782,7 +782,7 @@ def coerce_to_literal(typ: Type) -> Type: typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] - return make_simplified_union(new_items) + return UnionType.make_union(new_items) elif isinstance(typ, Instance): if typ.last_known_value: return typ.last_known_value From 44ca265a5a1871909c954571b25877a92fc832e6 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 11 Oct 2021 14:28:25 +0200 Subject: [PATCH 57/76] Generate intersection types in match statements and add guard tests --- mypy/checker.py | 62 +++++----- mypy/checkpattern.py | 168 +++++++++++++++++----------- test-data/unit/check-python310.test | 108 +++++++++++++++++- 3 files changed, 242 insertions(+), 96 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 82a80c6f9c19..55d333bc74e4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3896,7 +3896,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): with self.binder.frame_context(can_skip=True, fall_through=2): - if b.is_unreachable or pattern_type.type is None: + if b.is_unreachable or isinstance(get_proper_type(pattern_type.type), UninhabitedType): self.push_type_map(None) else: self.binder.put(s.subject, pattern_type.type) @@ -4325,11 +4325,15 @@ def is_type_call(expr: CallExpr) -> bool: if_maps: List[TypeMap] = [] else_maps: List[TypeMap] = [] for expr in exprs_in_type_calls: - current_if_map, current_else_map = self.conditional_type_map_with_intersection( - expr, + current_if_type, current_else_type = self.conditional_types_with_intersection( type_map[expr], - type_being_compared + type_being_compared, + expr ) + current_if_map, current_else_map = conditional_types_to_typemaps(expr, + type_map[expr], + current_if_type, + current_else_type) if_maps.append(current_if_map) else_maps.append(current_else_map) @@ -4386,10 +4390,14 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - return self.conditional_type_map_with_intersection( + return conditional_types_to_typemaps( expr, type_map[expr], - get_isinstance_type(node.args[1], type_map), + *self.conditional_types_with_intersection( + type_map[expr], + get_isinstance_type(node.args[1], type_map), + expr + ) ) elif refers_to_fullname(node.callee, 'builtins.issubclass'): if len(node.args) != 2: # the error will be reported elsewhere @@ -4886,7 +4894,7 @@ def refine_identity_comparison_expression(self, expr_type = try_expanding_enum_to_union(expr_type, enum_name) # We intentionally use 'conditional_types' directly here instead of - # 'self.conditional_type_map_with_intersection': we only compute ad-hoc + # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. types = conditional_types(expr_type, target_type) partial_type_maps.append(conditional_types_to_typemaps(expr, expr_type, *types)) @@ -5307,55 +5315,55 @@ def infer_issubclass_maps(self, node: CallExpr, # Any other object whose type we don't know precisely # for example, Any or a custom metaclass. return {}, {} # unknown type - yes_map, no_map = self.conditional_type_map_with_intersection(expr, vartype, type) + yes_type, no_type = self.conditional_types_with_intersection(vartype, type, expr) + yes_map, no_map = conditional_types_to_typemaps(expr, vartype, yes_type, no_type) yes_map, no_map = map(convert_to_typetype, (yes_map, no_map)) return yes_map, no_map - def conditional_type_map_with_intersection(self, - expr: Expression, - expr_type: Type, - type_ranges: Optional[List[TypeRange]], - ) -> Tuple[TypeMap, TypeMap]: + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + ) -> Tuple[Type, Type]: initial_types = conditional_types(expr_type, type_ranges) # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. - initial_maps = conditional_types_to_typemaps(expr, expr_type, *initial_types) - yes_map: TypeMap = initial_maps[0] - no_map: TypeMap = initial_maps[1] + yes_type: Type = initial_types[0] + no_type: Type = initial_types[1] - if yes_map is not None or type_ranges is None: - return yes_map, no_map + if not isinstance(get_proper_type(yes_type), UninhabitedType) or type_ranges is None: + return yes_type, no_type # If conditions_type_map was unable to successfully narrow the expr_type # using the type_ranges and concluded if-branch is unreachable, we try # computing it again using a different algorithm that tries to generate # an ad-hoc intersection between the expr_type and the type_ranges. - expr_type = get_proper_type(expr_type) - if isinstance(expr_type, UnionType): - possible_expr_types = get_proper_types(expr_type.relevant_items()) + proper_type = get_proper_type(expr_type) + if isinstance(proper_type, UnionType): + possible_expr_types = get_proper_types(proper_type.relevant_items()) else: - possible_expr_types = [expr_type] + possible_expr_types = [proper_type] possible_target_types = [] for tr in type_ranges: item = get_proper_type(tr.item) if not isinstance(item, Instance) or tr.is_upper_bound: - return yes_map, no_map + return yes_type, no_type possible_target_types.append(item) out = [] for v in possible_expr_types: if not isinstance(v, Instance): - return yes_map, no_map + return yes_type, no_type for t in possible_target_types: - intersection = self.intersect_instances((v, t), expr) + intersection = self.intersect_instances((v, t), ctx) if intersection is None: continue out.append(intersection) if len(out) == 0: - return None, {} + return UninhabitedType(), expr_type new_yes_type = make_simplified_union(out) - return {expr: new_yes_type}, {} + return new_yes_type, expr_type def is_writable_attribute(self, node: Node) -> bool: """Check if an attribute is writable""" diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 592976a6d389..13df1be9a667 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -8,6 +8,7 @@ from mypy.join import join_types from mypy.literals import literal_hash from mypy.maptype import map_instance_to_supertype +from mypy.meet import narrow_declared_type from mypy.messages import MessageBuilder from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr from mypy.patterns import ( @@ -49,6 +50,7 @@ 'PatternType', [ ('type', Optional[Type]), + ('rest_type', Type), # For exhaustiveness checking. Not used yet ('captures', Dict[Expression, Type]), ]) @@ -100,32 +102,37 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType: current_type = self.type_context[-1] if o.pattern is not None: pattern_type = self.accept(o.pattern, current_type) - typ, type_map = pattern_type + typ, rest_type, type_map = pattern_type else: - typ, type_map = current_type, {} + typ, rest_type, type_map = current_type, UninhabitedType(), {} - if typ is not None and o.name is not None: - typ = get_more_specific_type(typ, current_type) - if typ is not None: + if not is_uninhabited(typ) and o.name is not None: + typ, _ = self.chk.conditional_types_with_intersection(current_type, + [get_type_range(typ)], + o) + if not is_uninhabited(typ): type_map[o.name] = typ - return PatternType(typ, type_map) + return PatternType(typ, rest_type, type_map) def visit_or_pattern(self, o: OrPattern) -> PatternType: + current_type = self.type_context[-1] # # Check all the subpatterns # pattern_types = [] for pattern in o.patterns: - pattern_types.append(self.accept(pattern, self.type_context[-1])) + pattern_type = self.accept(pattern, current_type) + pattern_types.append(pattern_type) + current_type = pattern_type.rest_type # # Collect the final type # types = [] for pattern_type in pattern_types: - if pattern_type.type is not None: + if not is_uninhabited(pattern_type.type): types.append(pattern_type.type) # @@ -155,12 +162,17 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: captures[capture_list[0][0]] = typ union_type = make_simplified_union(types) - return PatternType(union_type, captures) + return PatternType(union_type, current_type, captures) def visit_value_pattern(self, o: ValuePattern) -> PatternType: + current_type = self.type_context[-1] typ = self.chk.expr_checker.accept(o.expr) - specific_typ = get_more_specific_type(typ, self.type_context[-1]) - return PatternType(specific_typ, {}) + narrowed_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, + [get_type_range(typ)], + o + ) + return PatternType(narrowed_type, rest_type, {}) def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: value: Union[bool, None] = o.value @@ -171,8 +183,12 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: else: assert False - specific_type = get_more_specific_type(typ, self.type_context[-1]) - return PatternType(specific_type, {}) + narrowed_type, rest_type = self.chk.conditional_types_with_intersection( + self.type_context[-1], + [get_type_range(typ)], + o + ) + return PatternType(narrowed_type, rest_type, {}) def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # @@ -180,7 +196,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # current_type = get_proper_type(self.type_context[-1]) if not self.can_match_sequence(current_type): - return early_non_match() + return self.early_non_match() star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)] star_position: Optional[int] = None if len(star_positions) == 1: @@ -198,9 +214,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: inner_types = current_type.items size_diff = len(inner_types) - required_patterns if size_diff < 0: - return early_non_match() + return self.early_non_match() elif size_diff > 0 and star_position is None: - return early_non_match() + return self.early_non_match() else: inner_type = self.get_sequence_type(current_type) if inner_type is None: @@ -211,39 +227,57 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # match inner patterns # contracted_new_inner_types: List[Type] = [] + contracted_rest_inner_types: List[Type] = [] captures: Dict[Expression, Type] = {} contracted_inner_types = self.contract_starred_pattern_types(inner_types, - star_position, + star_position, required_patterns) can_match = True for p, t in zip(o.patterns, contracted_inner_types): pattern_type = self.accept(p, t) - typ, type_map = pattern_type - if typ is None: + typ, rest, type_map = pattern_type + if is_uninhabited(typ): can_match = False else: contracted_new_inner_types.append(typ) + contracted_rest_inner_types.append(rest) self.update_type_map(captures, type_map) new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, star_position, len(inner_types)) + rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types, + star_position, + len(inner_types)) # # Calculate new type # - new_type: Optional[Type] + new_type: Type + rest_type = current_type if not can_match: - new_type = None + new_type = UninhabitedType() elif isinstance(current_type, TupleType): - specific_inner_types = [] + narrowed_inner_types = [] + inner_rest_types = [] for inner_type, new_inner_type in zip(inner_types, new_inner_types): - specific_inner_types.append(get_more_specific_type(inner_type, new_inner_type)) - if all(typ is not None for typ in specific_inner_types): - specific_inner_types_cast = cast(List[Type], specific_inner_types) - new_type = TupleType(specific_inner_types_cast, current_type.partial_fallback) + narrowed_inner_type, inner_rest_type = self.chk.conditional_types_with_intersection( + new_inner_type, + [get_type_range(inner_type)], + o + ) + narrowed_inner_types.append(narrowed_inner_type) + inner_rest_types.append(inner_rest_type) + if all(not is_uninhabited(typ) for typ in narrowed_inner_types): + new_type = TupleType(narrowed_inner_types, current_type.partial_fallback) else: - new_type = None + new_type = UninhabitedType() + + if all(is_uninhabited(typ) for typ in inner_rest_types): + # All subpatterns always match, so we can apply negative narrowing + new_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, [get_type_range(new_type)], o + ) else: new_inner_type = UninhabitedType() for typ in new_inner_types: @@ -251,7 +285,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_type = self.construct_iterable_child(current_type, new_inner_type) if not is_subtype(new_type, current_type): new_type = current_type - return PatternType(new_type, captures) + return PatternType(new_type, rest_type, captures) def get_sequence_type(self, t: Type) -> Optional[Type]: t = get_proper_type(t) @@ -296,7 +330,7 @@ def visit_starred_pattern(self, o: StarredPattern) -> PatternType: if o.capture is not None: list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]]) captures[o.capture] = list_type - return PatternType(self.type_context[-1], captures) + return PatternType(self.type_context[-1], UninhabitedType(), captures) def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: current_type = get_proper_type(self.type_context[-1]) @@ -308,7 +342,7 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: can_match = False inner_type = self.chk.named_type("builtins.object") pattern_type = self.accept(value, inner_type) - if pattern_type is None: + if is_uninhabited(pattern_type.type): can_match = False else: self.update_type_map(captures, pattern_type.captures) @@ -329,10 +363,10 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: if can_match: # We can't narrow the type here, as Mapping key is invariant. - new_type: Optional[Type] = self.type_context[-1] + new_type = self.type_context[-1] else: - new_type = None - return PatternType(new_type, captures) + new_type = UninhabitedType() + return PatternType(new_type, current_type, captures) def get_mapping_item_type(self, pattern: MappingPattern, @@ -388,7 +422,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: assert type_info is not None if isinstance(type_info, TypeAlias) and not type_info.no_args: self.msg.fail("Class pattern class must not be a type alias with type parameters", o) - return early_non_match() + return self.early_non_match() if isinstance(type_info, TypeInfo): any_type = AnyType(TypeOfAny.implementation_artifact) typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) @@ -400,11 +434,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: else: name = type_info.name self.msg.fail('Class pattern must be a type. Found "{}"'.format(name), o.class_ref) - return early_non_match() + return self.early_non_match() - new_type = get_more_specific_type(current_type, typ) - if new_type is None: - return early_non_match() + new_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, [get_type_range(typ)], o + ) + if is_uninhabited(new_type): + return self.early_non_match() + # TODO: Do I need this? + narrowed_type = narrow_declared_type(current_type, new_type) # # Convert positional to keyword patterns @@ -418,9 +456,11 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if self.should_self_match(typ): if len(o.positionals) > 1: self.msg.fail("Too many positional patterns for class pattern", o) - pattern_type = self.accept(o.positionals[0], new_type) - if pattern_type.type is None: - return pattern_type + pattern_type = self.accept(o.positionals[0], narrowed_type) + if not is_uninhabited(pattern_type.type): + return PatternType(pattern_type.type, + join_types(rest_type, pattern_type.rest_type), + pattern_type.captures) captures = pattern_type.captures else: local_errors = self.msg.clean_copy() @@ -432,7 +472,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if local_errors.is_errors(): self.msg.fail("Class doesn't define __match_args__", o) - return early_non_match() + return self.early_non_match() proper_match_args_type = get_proper_type(match_args_type) if isinstance(proper_match_args_type, TupleType): @@ -440,7 +480,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if len(o.positionals) > len(match_arg_names): self.msg.fail("Too many positional patterns for class pattern", o) - return early_non_match() + return self.early_non_match() else: match_arg_names = [None] * len(o.positionals) @@ -466,7 +506,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: keyword_arg_set.add(key) if has_duplicates: - return early_non_match() + return self.early_non_match() # # Check keyword patterns @@ -477,7 +517,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: local_errors = self.msg.clean_copy() if keyword is not None: key_type = analyze_member_access(keyword, - new_type, + narrowed_type, pattern, False, False, @@ -488,15 +528,17 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if local_errors.is_errors() or key_type is None: key_type = AnyType(TypeOfAny.implementation_artifact) - pattern_type = self.accept(pattern, key_type) - if pattern_type is None: + inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type) + if is_uninhabited(inner_type): can_match = False else: - self.update_type_map(captures, pattern_type.captures) + self.update_type_map(captures, inner_captures) + if not is_uninhabited(inner_rest_type): + rest_type = current_type if not can_match: - new_type = None - return PatternType(new_type, captures) + new_type = UninhabitedType() + return PatternType(new_type, rest_type, captures) def should_self_match(self, typ: Type) -> bool: typ = get_proper_type(typ) @@ -556,6 +598,9 @@ def construct_iterable_child(self, outer_type: Type, inner_type: Type) -> Type: else: return sequence + def early_non_match(self) -> PatternType: + return PatternType(UninhabitedType(), self.type_context[-1], {}) + def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: args: List[Optional[str]] = [] @@ -568,21 +613,6 @@ def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: return args -def get_more_specific_type(left: Optional[Type], right: Optional[Type]) -> Optional[Type]: - if left is None or right is None: - return None - elif is_subtype(left, right): - return left - elif is_subtype(right, left): - return right - else: - return None - - -def early_non_match() -> PatternType: - return PatternType(None, {}) - - def get_var(expr: Expression) -> Var: """ Warning: this in only true for expressions captured by a match statement. @@ -592,3 +622,11 @@ def get_var(expr: Expression) -> Var: node = expr.node assert isinstance(node, Var) return node + + +def get_type_range(typ: Type) -> 'mypy.checker.TypeRange': + return mypy.checker.TypeRange(typ, is_upper_bound=False) + + +def is_uninhabited(typ: Type) -> bool: + return isinstance(get_proper_type(typ), UninhabitedType) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 150a99050ad3..5e8d68f65cd7 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -24,11 +24,13 @@ match m: reveal_type(m) # N: Revealed type is "builtins.bool" [case testLiteralPatternUnreachable] +# primitives are needed because otherwise mypy doesn't see that int and str are incompatible m: int match m: case "str": reveal_type(m) +[builtins fixtures/primitives.pyi] -- Value Pattern -- @@ -52,8 +54,23 @@ match m: [file b.py] b: int +[case testValuePatternIntersect] +import b + +class A: ... +m: A + +match m: + case b.b: + reveal_type(m) # N: Revealed type is "__main__." +[file b.py] +class B: ... +b: B + [case testValuePatternUnreachable] +# primitives are needed because otherwise mypy doesn't see that int and str are incompatible import b + m: int match m: @@ -61,6 +78,7 @@ match m: reveal_type(m) [file b.py] b: str +[builtins fixtures/primitives.pyi] -- Sequence Pattern -- @@ -235,7 +253,7 @@ match m: reveal_type(c) [builtins fixtures/list.pyi] -[case testSequencePatternTupleToShort] +[case testSequencePatternTupleTooShort] from typing import Tuple m: Tuple[int, str, bool] @@ -761,7 +779,7 @@ match m: reveal_type(m) # N: Revealed type is "__main__.B" [builtins fixtures/tuple.pyi] -[case testClassPatternUnreachable] +[case testClassPatternIntersection] from typing import Final class A: @@ -774,9 +792,9 @@ m: B match m: case A(): - reveal_type(m) + reveal_type(m) # N: Revealed type is "__main__." case A(i, j): - reveal_type(m) + reveal_type(m) # N: Revealed type is "__main__.1" [builtins fixtures/tuple.pyi] [case testClassPatternNonexistentKeyword] @@ -1041,6 +1059,72 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" +-- Guards -- +[case testSimplePatternGuard] +m: str + +def guard() -> bool: ... + +match m: + case a if guard(): + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testAlwaysTruePatternGuard] +m: str + +match m: + case a if True: + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testAlwaysFalsePatternGuard] +m: str + +match m: + case a if False: + reveal_type(a) + +[case testRedefiningPatternGuard] +# flags: --strict-optional +m: str + +match m: + case a if a := 1: # E: Incompatible types in assignment (expression has type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "" + +[case testAssigningPatternGuard] +m: str + +match m: + case a if a := "test": + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testNarrowingPatternGuard] +m: object + +match m: + case a if isinstance(a, str): + reveal_type(a) # N: Revealed type is "builtins.str" +[builtins fixtures/isinstancelist.pyi] + +[case testIncompatiblePatternGuard] +class A: ... +class B: ... + +m: A + +match m: + case a if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__." +[builtins fixtures/isinstancelist.pyi] + +[case testUnreachablePatternGuard] +m: str + +match m: + case a if isinstance(a, int): + reveal_type(a) +[builtins fixtures/isinstancelist.pyi] + -- Exhaustiveness -- [case testUnionNegativeNarrowing-skip] from typing import Union @@ -1066,3 +1150,19 @@ match m: reveal_type(m) # N: Revealed type is "Union[builtins.str, builtins.bytes]" case b: reveal_type(b) # N: Revealed type is "builtins.int" + +[case testExhaustiveReturn-skip] +def foo(value) -> int: + match value: + case "bar": + return 1 + case _: + return 2 + +[case testNoneExhaustiveReturn-skip] +def foo(value) -> int: # E: Missing return statement + match value: + case "bar": + return 1 + case 2: + return 2 From 21f7b59230bcd9d99c8bed8c9d3d2163056c8cb2 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Mon, 11 Oct 2021 15:29:39 +0200 Subject: [PATCH 58/76] Fixed broken tests --- mypy/checker.py | 3 ++- mypy/checkpattern.py | 31 ++++++++++++++++++----------- test-data/unit/check-python310.test | 9 +++++++++ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 55d333bc74e4..d5f99a2e7ba2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3896,7 +3896,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None: for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): with self.binder.frame_context(can_skip=True, fall_through=2): - if b.is_unreachable or isinstance(get_proper_type(pattern_type.type), UninhabitedType): + if b.is_unreachable or isinstance(get_proper_type(pattern_type.type), + UninhabitedType): self.push_type_map(None) else: self.binder.put(s.subject, pattern_type.type) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 13df1be9a667..34209eea7328 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,6 +1,6 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Tuple, Dict, NamedTuple, Set, cast, Union +from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union import mypy.checker from mypy.checkmember import analyze_member_access @@ -20,7 +20,7 @@ from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union from mypy.types import ( ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, - TypedDictType, TupleType, NoneType + TypedDictType, TupleType, NoneType, UnionType ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor @@ -49,7 +49,7 @@ PatternType = NamedTuple( 'PatternType', [ - ('type', Optional[Type]), + ('type', Type), ('rest_type', Type), # For exhaustiveness checking. Not used yet ('captures', Dict[Expression, Type]), ]) @@ -246,26 +246,24 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, star_position, len(inner_types)) - rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types, - star_position, - len(inner_types)) # # Calculate new type # new_type: Type - rest_type = current_type + rest_type: Type = current_type if not can_match: new_type = UninhabitedType() elif isinstance(current_type, TupleType): narrowed_inner_types = [] inner_rest_types = [] for inner_type, new_inner_type in zip(inner_types, new_inner_types): - narrowed_inner_type, inner_rest_type = self.chk.conditional_types_with_intersection( - new_inner_type, - [get_type_range(inner_type)], - o - ) + narrowed_inner_type, inner_rest_type = \ + self.chk.conditional_types_with_intersection( + new_inner_type, + [get_type_range(inner_type)], + o + ) narrowed_inner_types.append(narrowed_inner_type) inner_rest_types.append(inner_rest_type) if all(not is_uninhabited(typ) for typ in narrowed_inner_types): @@ -291,6 +289,13 @@ def get_sequence_type(self, t: Type) -> Optional[Type]: t = get_proper_type(t) if isinstance(t, AnyType): return AnyType(TypeOfAny.from_another_any, t) + if isinstance(t, UnionType): + items = [self.get_sequence_type(item) for item in t.items] + not_none_items = [item for item in items if item is not None] + if len(not_none_items) > 0: + return make_simplified_union(not_none_items) + else: + return None if self.chk.type_is_iterable(t) and isinstance(t, Instance): return self.chk.iterable_item_type(t) @@ -550,6 +555,8 @@ def should_self_match(self, typ: Type) -> bool: return False def can_match_sequence(self, typ: ProperType) -> bool: + if isinstance(typ, UnionType): + return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items) for other in self.non_sequence_match_types: # We have to ignore promotions, as memoryview should match, but bytes, # which it can be promoted to, shouldn't diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 5e8d68f65cd7..4d62a5a314dc 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -317,6 +317,15 @@ match x: case [str()]: pass +[case testSequenceUnion] +from typing import List, Union +m: Union[List[List[str]], str] + +match m: + case [list(['str'])]: + reveal_type(m) # N: Revealed type is "Union[builtins.list[builtins.list[builtins.str]], builtins.str]" +[builtins fixtures/list.pyi] + -- Mapping Pattern -- [case testMappingPatternCaptures] from typing import Dict From bfb578d90c1c66a0d6934cbeed0194084c8ee06c Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 12 Oct 2021 13:29:50 +0200 Subject: [PATCH 59/76] Fix mypy_primer differences --- mypy/checker.py | 14 ++++++++------ mypy/checkpattern.py | 19 ++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d5f99a2e7ba2..8c3edcfc5f64 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5325,8 +5325,9 @@ def conditional_types_with_intersection(self, expr_type: Type, type_ranges: Optional[List[TypeRange]], ctx: Context, + default: Optional[Type] = None ) -> Tuple[Type, Type]: - initial_types = conditional_types(expr_type, type_ranges) + initial_types = conditional_types(expr_type, type_ranges, default) # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. yes_type: Type = initial_types[0] @@ -5379,6 +5380,7 @@ def is_writable_attribute(self, node: Node) -> bool: def conditional_types(current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], + default: Optional[Type] = None ) -> Tuple[Type, Type]: """Takes in the current type and a proposed type of an expression. @@ -5392,15 +5394,15 @@ def conditional_types(current_type: Type, # We don't really know much about the proposed type, so we shouldn't # attempt to narrow anything. Instead, we broaden the expr to Any to # avoid false positives - return proposed_type, current_type + return proposed_type, default elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges - return current_type, UninhabitedType() + return default, UninhabitedType() elif not is_overlapping_types(current_type, proposed_type, prohibit_none_typevar_overlap=True): # Expression is never of any type in proposed_type_ranges - return UninhabitedType(), current_type + return UninhabitedType(), default else: # we can only restrict when the type is precise, not bounded proposed_precise_type = UnionType([type_range.item @@ -5410,7 +5412,7 @@ def conditional_types(current_type: Type, return proposed_type, remaining_type else: # An isinstance check, but we don't understand the type - return current_type, current_type + return current_type, default def conditional_types_to_typemaps(expr: Expression, @@ -5423,7 +5425,7 @@ def conditional_types_to_typemaps(expr: Expression, proper_type = get_proper_type(typ) if isinstance(proper_type, UninhabitedType): maps.append(None) - elif proper_type == get_proper_type(expr_type): + elif proper_type is None: maps.append({}) else: maps.append({expr: typ}) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 34209eea7328..bc8b7af2bdfa 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -109,7 +109,8 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType: if not is_uninhabited(typ) and o.name is not None: typ, _ = self.chk.conditional_types_with_intersection(current_type, [get_type_range(typ)], - o) + o, + default=current_type) if not is_uninhabited(typ): type_map[o.name] = typ @@ -170,11 +171,13 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType: narrowed_type, rest_type = self.chk.conditional_types_with_intersection( current_type, [get_type_range(typ)], - o + o, + default=current_type ) return PatternType(narrowed_type, rest_type, {}) def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: + current_type = self.type_context[-1] value: Union[bool, None] = o.value if isinstance(value, bool): typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool") @@ -184,9 +187,10 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: assert False narrowed_type, rest_type = self.chk.conditional_types_with_intersection( - self.type_context[-1], + current_type, [get_type_range(typ)], - o + o, + default=current_type ) return PatternType(narrowed_type, rest_type, {}) @@ -262,7 +266,8 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: self.chk.conditional_types_with_intersection( new_inner_type, [get_type_range(inner_type)], - o + o, + default=new_inner_type ) narrowed_inner_types.append(narrowed_inner_type) inner_rest_types.append(inner_rest_type) @@ -274,7 +279,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: if all(is_uninhabited(typ) for typ in inner_rest_types): # All subpatterns always match, so we can apply negative narrowing new_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, [get_type_range(new_type)], o + current_type, [get_type_range(new_type)], o, default=current_type ) else: new_inner_type = UninhabitedType() @@ -442,7 +447,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: return self.early_non_match() new_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, [get_type_range(typ)], o + current_type, [get_type_range(typ)], o, default=current_type ) if is_uninhabited(new_type): return self.early_non_match() From c0a43e8de8faa73cdaf394ac2f531090eece6f94 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 12 Oct 2021 14:10:38 +0200 Subject: [PATCH 60/76] Fix type errors --- mypy/checker.py | 65 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index b20d37dfa406..4c97e575d3d3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7,7 +7,7 @@ from typing import ( Any, Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, - Iterable, Sequence, Mapping, Generic, AbstractSet, Callable + Iterable, Sequence, Mapping, Generic, AbstractSet, Callable, overload ) from typing_extensions import Final @@ -4402,7 +4402,6 @@ def is_type_call(expr: CallExpr) -> bool: expr ) current_if_map, current_else_map = conditional_types_to_typemaps(expr, - type_map[expr], current_if_type, current_else_type) if_maps.append(current_if_map) @@ -4461,7 +4460,6 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if literal(expr) == LITERAL_TYPE: return conditional_types_to_typemaps( expr, - type_map[expr], *self.conditional_types_with_intersection( type_map[expr], get_isinstance_type(node.args[1], type_map), @@ -4987,7 +4985,7 @@ def refine_identity_comparison_expression(self, # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. types = conditional_types(expr_type, target_type) - partial_type_maps.append(conditional_types_to_typemaps(expr, expr_type, *types)) + partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) return reduce_conditional_maps(partial_type_maps) @@ -5405,21 +5403,44 @@ def infer_issubclass_maps(self, node: CallExpr, # for example, Any or a custom metaclass. return {}, {} # unknown type yes_type, no_type = self.conditional_types_with_intersection(vartype, type, expr) - yes_map, no_map = conditional_types_to_typemaps(expr, vartype, yes_type, no_type) + yes_map, no_map = conditional_types_to_typemaps(expr, yes_type, no_type) yes_map, no_map = map(convert_to_typetype, (yes_map, no_map)) return yes_map, no_map + @overload + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + ) -> Tuple[Optional[Type], Optional[Type]]: ... + + @overload + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: None + ) -> Tuple[Optional[Type], Optional[Type]]: ... + + @overload + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: Type + ) -> Tuple[Type, Type]: ... + def conditional_types_with_intersection(self, expr_type: Type, type_ranges: Optional[List[TypeRange]], ctx: Context, default: Optional[Type] = None - ) -> Tuple[Type, Type]: + ) -> Tuple[Optional[Type], Optional[Type]]: initial_types = conditional_types(expr_type, type_ranges, default) # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. - yes_type: Type = initial_types[0] - no_type: Type = initial_types[1] + yes_type: Optional[Type] = initial_types[0] + no_type: Optional[Type] = initial_types[1] if not isinstance(get_proper_type(yes_type), UninhabitedType) or type_ranges is None: return yes_type, no_type @@ -5466,10 +5487,30 @@ def is_writable_attribute(self, node: Node) -> bool: return False +@overload +def conditional_types(current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + ) -> Tuple[Optional[Type], Optional[Type]]: ... + + +@overload +def conditional_types(current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + default: None + ) -> Tuple[Type, Type]: ... + + +@overload +def conditional_types(current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + default: Type + ) -> Tuple[Type, Type]: ... + + def conditional_types(current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], default: Optional[Type] = None - ) -> Tuple[Type, Type]: + ) -> Tuple[Optional[Type], Optional[Type]]: """Takes in the current type and a proposed type of an expression. Returns a 2-tuple: The first element is the proposed type, if the expression @@ -5504,9 +5545,8 @@ def conditional_types(current_type: Type, def conditional_types_to_typemaps(expr: Expression, - expr_type: Type, - yes_type: Type, - no_type: Type + yes_type: Optional[Type], + no_type: Optional[Type] ) -> Tuple[TypeMap, TypeMap]: maps: List[TypeMap] = [] for typ in (yes_type, no_type): @@ -5516,6 +5556,7 @@ def conditional_types_to_typemaps(expr: Expression, elif proper_type is None: maps.append({}) else: + assert typ is not None maps.append({expr: typ}) return cast(Tuple[TypeMap, TypeMap], tuple(maps)) From f768993580aadaacba2178c056a349f212cd4660 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Tue, 12 Oct 2021 14:49:47 +0200 Subject: [PATCH 61/76] Fixed copy-paste error causing mypyc crash --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 4c97e575d3d3..ba6029dc1e2a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5497,7 +5497,7 @@ def conditional_types(current_type: Type, def conditional_types(current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], default: None - ) -> Tuple[Type, Type]: ... + ) -> Tuple[Optional[Type], Optional[Type]]: ... @overload From 7b60596576a43685d555ee75805c74d785b28d0d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 21:05:11 +0200 Subject: [PATCH 62/76] Added semanal tests for undefined value and class patterns --- test-data/unit/semanal-errors-python310.test | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test-data/unit/semanal-errors-python310.test b/test-data/unit/semanal-errors-python310.test index 28fb43eea6ee..68c158cddae6 100644 --- a/test-data/unit/semanal-errors-python310.test +++ b/test-data/unit/semanal-errors-python310.test @@ -6,6 +6,24 @@ match x: [out] main:2: error: Name "x" is not defined +[case testMatchUndefinedValuePattern] +import typing +x = 1 +match x: + case a.b: + pass +[out] +main:4: error: Name "a" is not defined + +[case testMatchUndefinedClassPattern] +import typing +x = 1 +match x: + case A(): + pass +[out] +main:4: error: Name "A" is not defined + [case testNoneBindingWildcardPattern] import typing x = 1 From 6cf9f587ed189a0cef9baa53e073a3620dc480a4 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 21:05:46 +0200 Subject: [PATCH 63/76] Removed redundant bool from lib-stubs --- test-data/unit/lib-stub/builtins.pyi | 1 - 1 file changed, 1 deletion(-) diff --git a/test-data/unit/lib-stub/builtins.pyi b/test-data/unit/lib-stub/builtins.pyi index b57223698b7b..8e4c744be8fa 100644 --- a/test-data/unit/lib-stub/builtins.pyi +++ b/test-data/unit/lib-stub/builtins.pyi @@ -13,7 +13,6 @@ class int: def __add__(self, other: int) -> int: pass class bool(int): pass class float: pass -class bool(int): pass class str: pass class bytes: pass From 268e181f70b2ea26d6d74b7d230c287cb6f9a8cd Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 21:11:05 +0200 Subject: [PATCH 64/76] Removed tuple from lib-stub --- test-data/unit/check-python310.test | 1 + test-data/unit/lib-stub/builtins.pyi | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 4d62a5a314dc..33a8dc2e3f0c 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -652,6 +652,7 @@ match m: case A(i, j): reveal_type(i) # N: Revealed type is "builtins.str" reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] [case testClassPatternCaptureGeneric] from typing import Generic, TypeVar diff --git a/test-data/unit/lib-stub/builtins.pyi b/test-data/unit/lib-stub/builtins.pyi index 8e4c744be8fa..8c4f504fb2e7 100644 --- a/test-data/unit/lib-stub/builtins.pyi +++ b/test-data/unit/lib-stub/builtins.pyi @@ -17,7 +17,6 @@ class float: pass class str: pass class bytes: pass -class tuple: pass class function: pass class ellipsis: pass From dbcd886073ee21cde4d97da42d2e7a84f16a382e Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 21:12:58 +0200 Subject: [PATCH 65/76] Fix typo in test --- test-data/unit/check-python310.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 33a8dc2e3f0c..d641be9b2da4 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -242,7 +242,7 @@ match m: reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" [builtins fixtures/list.pyi] -[case testSequencePatternTupleToLong] +[case testSequencePatternTupleTooLong] from typing import Tuple m: Tuple[int, str] From bd7e777baf632ffcf613037e942aa25012679248 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 21:52:52 +0200 Subject: [PATCH 66/76] Fixed testcase and marked it as skip --- test-data/unit/check-python310.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index d641be9b2da4..525c61bbbe8e 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -317,13 +317,13 @@ match x: case [str()]: pass -[case testSequenceUnion] +[case testSequenceUnion-skip] from typing import List, Union m: Union[List[List[str]], str] match m: case [list(['str'])]: - reveal_type(m) # N: Revealed type is "Union[builtins.list[builtins.list[builtins.str]], builtins.str]" + reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]" [builtins fixtures/list.pyi] -- Mapping Pattern -- From f51c83ff8997b9d144d03202034d92d3480b073f Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 22:06:54 +0200 Subject: [PATCH 67/76] Moved messages to message_registry Also added class name to missing match args message --- mypy/checkpattern.py | 24 ++++++++++++++---------- mypy/message_registry.py | 14 ++++++++++++++ test-data/unit/check-python310.test | 2 +- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index bc8b7af2bdfa..285b13359f81 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -9,6 +9,7 @@ from mypy.literals import literal_hash from mypy.maptype import map_instance_to_supertype from mypy.meet import narrow_declared_type +from mypy import message_registry from mypy.messages import MessageBuilder from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr from mypy.patterns import ( @@ -149,7 +150,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType: for i, pattern_type in enumerate(pattern_types[1:]): vars = {get_var(expr) for expr, _ in pattern_type.captures.items()} if capture_types.keys() != vars: - self.msg.fail("Alternative patterns bind different names", o.patterns[i]) + self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i]) for expr, typ in pattern_type.captures.items(): node = get_var(expr) capture_types[node].append((expr, typ)) @@ -431,7 +432,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: type_info = o.class_ref.node assert type_info is not None if isinstance(type_info, TypeAlias) and not type_info.no_args: - self.msg.fail("Class pattern class must not be a type alias with type parameters", o) + self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o) return self.early_non_match() if isinstance(type_info, TypeInfo): any_type = AnyType(TypeOfAny.implementation_artifact) @@ -443,7 +444,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: name = str(type_info.type) else: name = type_info.name - self.msg.fail('Class pattern must be a type. Found "{}"'.format(name), o.class_ref) + self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o.class_ref) return self.early_non_match() new_type, rest_type = self.chk.conditional_types_with_intersection( @@ -465,7 +466,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: if len(o.positionals) != 0: if self.should_self_match(typ): if len(o.positionals) > 1: - self.msg.fail("Too many positional patterns for class pattern", o) + self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o) pattern_type = self.accept(o.positionals[0], narrowed_type) if not is_uninhabited(pattern_type.type): return PatternType(pattern_type.type, @@ -481,7 +482,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: chk=self.chk) if local_errors.is_errors(): - self.msg.fail("Class doesn't define __match_args__", o) + self.msg.fail(message_registry.MISSING_MATCH_ARGS.format(typ), o) return self.early_non_match() proper_match_args_type = get_proper_type(match_args_type) @@ -489,7 +490,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: match_arg_names = get_match_arg_names(proper_match_args_type) if len(o.positionals) > len(match_arg_names): - self.msg.fail("Too many positional patterns for class pattern", o) + self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o) return self.early_non_match() else: match_arg_names = [None] * len(o.positionals) @@ -507,11 +508,14 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: for key, value in zip(o.keyword_keys, o.keyword_values): keyword_pairs.append((key, value)) if key in match_arg_set: - self.msg.fail('Keyword "{}" already matches a positional pattern'.format(key), - value) + self.msg.fail( + message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), + value + ) has_duplicates = True elif key in keyword_arg_set: - self.msg.fail('Duplicate keyword pattern "{}"'.format(key), value) + self.msg.fail(message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), + value) has_duplicates = True keyword_arg_set.add(key) @@ -594,7 +598,7 @@ def update_type_map(self, for expr, typ in extra_type_map.items(): if literal_hash(expr) in already_captured: node = get_var(expr) - self.msg.fail('Multiple assignments to name "{}" in pattern'.format(node.name), + self.msg.fail(message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr) else: original_type_map[expr] = typ diff --git a/mypy/message_registry.py b/mypy/message_registry.py index e04538221f0e..21374be5af64 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -158,3 +158,17 @@ "Only @runtime_checkable protocols can be used with instance and class checks" ) CANNOT_INSTANTIATE_PROTOCOL: Final = 'Cannot instantiate protocol class "{}"' + +# Match Statement +MISSING_MATCH_ARGS: Final = 'Class "{}" doesn\'t define "__match_args__"' +OR_PATTERN_ALTERNATIVE_NAMES: Final = "Alternative patterns bind different names" +CLASS_PATTERN_GENERIC_TYPE_ALIAS: Final = ( + "Class pattern class must not be a type alias with type parameters" +) +CLASS_PATTERN_TYPE_REQUIRED: Final = 'Class pattern must be a type. Found "{}"' +CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS: Final = "Too many positional patterns for class pattern" +CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL: Final = ( + 'Keyword "{}" already matches a positional pattern' +) +CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN: Final = 'Duplicate keyword pattern "{}"' +MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern' diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 525c61bbbe8e..ff4cdcfafcf1 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -592,7 +592,7 @@ class A: m: A match m: - case A(i, j): # E: Class doesn't define __match_args__ + case A(i, j): # E: Class "__main__.A" doesn't define "__match_args__" pass [builtins fixtures/dataclasses.pyi] From 4d87ca1110fe6f37b3eabb24693ba5dcfa942aaf Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 22:10:07 +0200 Subject: [PATCH 68/76] Removed outdated comment --- mypy/patterns.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mypy/patterns.py b/mypy/patterns.py index 79d394f7f629..a47d33ce3e8c 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -6,7 +6,6 @@ from mypy.nodes import Node, RefExpr, NameExpr, Expression from mypy.visitor import PatternVisitor -# These are not real AST nodes. CPython represents patterns using the normal expression nodes. T = TypeVar('T') From 9f9efccc8ddf79fc8ff893b33f28a9b74ce630b9 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 22:11:12 +0200 Subject: [PATCH 69/76] Removed full stop from error message --- mypy/message_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 21374be5af64..9945fbe4bbcb 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -165,7 +165,7 @@ CLASS_PATTERN_GENERIC_TYPE_ALIAS: Final = ( "Class pattern class must not be a type alias with type parameters" ) -CLASS_PATTERN_TYPE_REQUIRED: Final = 'Class pattern must be a type. Found "{}"' +CLASS_PATTERN_TYPE_REQUIRED: Final = 'Expected type in class pattern; found "{}"' CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS: Final = "Too many positional patterns for class pattern" CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL: Final = ( 'Keyword "{}" already matches a positional pattern' From 68bd2ba74f947b8564331bc3160db873b84eb984 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 22:13:00 +0200 Subject: [PATCH 70/76] Change test case to new error message --- test-data/unit/check-python310.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index ff4cdcfafcf1..6115e09c0368 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -861,7 +861,7 @@ a = 1 m: object match m: - case a(i, j): # E: Class pattern must be a type. Found "builtins.int" + case a(i, j): # E: Expected type in class pattern; found "builtins.int" reveal_type(i) reveal_type(j) From ff5123c8879c1db382217185642abf406bad665d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 22:31:43 +0200 Subject: [PATCH 71/76] Report error for non-existing keywords on class patterns --- mypy/checkpattern.py | 6 +++++- mypy/message_registry.py | 1 + test-data/unit/check-python310.test | 2 +- test-data/unit/fixtures/dict.pyi | 2 ++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 285b13359f81..7c098d114a77 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -539,8 +539,12 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: local_errors, original_type=new_type, chk=self.chk) + else: + key_type = AnyType(TypeOfAny.from_error) if local_errors.is_errors() or key_type is None: - key_type = AnyType(TypeOfAny.implementation_artifact) + key_type = AnyType(TypeOfAny.from_error) + self.msg.fail(message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword), + value) inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type) if is_uninhabited(inner_type): diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 9945fbe4bbcb..45f97e6ba07c 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -171,4 +171,5 @@ 'Keyword "{}" already matches a positional pattern' ) CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN: Final = 'Duplicate keyword pattern "{}"' +CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"' MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern' diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 6115e09c0368..c4180b41fd63 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -813,7 +813,7 @@ class A: ... m: object match m: - case A(a=j): + case A(a=j): # E: Class "__main__.A" has no attribute "a" reveal_type(m) # N: Revealed type is "__main__.A" reveal_type(j) # N: Revealed type is "Any" diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index 476b00840d13..37a5d6784916 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -36,6 +36,8 @@ class int: # for convenience def __add__(self, x: Union[int, complex]) -> int: pass def __sub__(self, x: Union[int, complex]) -> int: pass def __neg__(self): pass + real: int + imag: int class str: pass # for keyword argument key type class unicode: pass # needed for py2 docstrings From e4373235fedd446f913d195e56119a78db3057d2 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Wed, 20 Oct 2021 22:44:53 +0200 Subject: [PATCH 72/76] Fixed dataclasses plugin after merging master. Caused by https://github.com/python/mypy/pull/11332 --- mypy/plugins/dataclasses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 7284cfd3292b..b8557f77bca8 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -195,10 +195,10 @@ def transform(self) -> None: ('__match_args__' not in info.names or info.names['__match_args__'].plugin_generated) and attributes): - str_type = ctx.api.named_type("__builtins__.str") + str_type = ctx.api.named_type("builtins.str") literals: List[Type] = [LiteralType(attr.name, str_type) for attr in attributes if attr.is_in_init] - match_args_type = TupleType(literals, ctx.api.named_type("__builtins__.tuple")) + match_args_type = TupleType(literals, ctx.api.named_type("builtins.tuple")) add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) self._add_dataclass_fields_magic_attribute() From 68a09ee20aece8ab4ca0e807b34e1f614b45684c Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 21 Oct 2021 18:33:13 +0200 Subject: [PATCH 73/76] Added more comments and asserts to pattern classes --- mypy/nodes.py | 1 + mypy/patterns.py | 19 ++++++++++--------- mypy/plugins/common.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index c0e383b82e68..160f84aec643 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1372,6 +1372,7 @@ class MatchStmt(Statement): def __init__(self, subject: Expression, patterns: List['Pattern'], guards: List[Optional[Expression]], bodies: List[Block]) -> None: super().__init__() + assert len(patterns) == len(guards) == len(bodies) self.subject = subject self.patterns = patterns self.guards = guards diff --git a/mypy/patterns.py b/mypy/patterns.py index a47d33ce3e8c..8557fac6daf6 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -20,14 +20,12 @@ def accept(self, visitor: PatternVisitor[T]) -> T: raise RuntimeError('Not implemented') -@trait -class AlwaysTruePattern(Pattern): - """A pattern that is always matches""" - - __slots__ = () - - class AsPattern(Pattern): + # The python ast, and therefore also our ast merges capture, wildcard and as patterns into one + # for easier handling. + # If pattern is None this is a capture pattern. If name and pattern are both none this is a + # wildcard pattern. + # Only name being None should not happen but also won't break anything. pattern: Optional[Pattern] name: Optional[NameExpr] @@ -63,6 +61,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class SingletonPattern(Pattern): + # This can be exactly True, False or None value: Union[bool, None] def __init__(self, value: Union[bool, None]): @@ -84,9 +83,9 @@ def accept(self, visitor: PatternVisitor[T]) -> T: return visitor.visit_sequence_pattern(self) -# TODO: A StarredPattern is only valid within a SequencePattern. This is not guaranteed by our -# type hierarchy. Should it be? class StarredPattern(Pattern): + # None corresponds to *_ in a list pattern. It will match multiple items but won't bind them to + # a name. capture: Optional[NameExpr] def __init__(self, capture: Optional[NameExpr]): @@ -105,6 +104,7 @@ class MappingPattern(Pattern): def __init__(self, keys: List[Expression], values: List[Pattern], rest: Optional[NameExpr]): super().__init__() + assert len(keys) == len(values) self.keys = keys self.values = values self.rest = rest @@ -122,6 +122,7 @@ class ClassPattern(Pattern): def __init__(self, class_ref: RefExpr, positionals: List[Pattern], keyword_keys: List[str], keyword_values: List[Pattern]): super().__init__() + assert len(keyword_keys) == len(keyword_values) self.class_ref = class_ref self.positionals = positionals self.keyword_keys = keyword_keys diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 40fee190e702..95f4618da4a1 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -179,7 +179,7 @@ def add_attribute_to_class( node = Var(name, typ) node.info = info node.is_final = final - node._fullname = api.qualified_name(name) + node._fullname = info.fullname + '.' + name info.names[name] = SymbolTableNode(MDEF, node, plugin_generated=True) From a5baf606fc93501624eeb9f90247c5ecffd2e8f5 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Thu, 21 Oct 2021 18:50:17 +0200 Subject: [PATCH 74/76] Readd tuple missing tests --- test-data/unit/check-incomplete-fixture.test | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test-data/unit/check-incomplete-fixture.test b/test-data/unit/check-incomplete-fixture.test index b5552adda6ec..f06dad293184 100644 --- a/test-data/unit/check-incomplete-fixture.test +++ b/test-data/unit/check-incomplete-fixture.test @@ -50,6 +50,25 @@ main:1: error: Name "isinstance" is not defined main:1: note: Maybe your test fixture does not define "builtins.isinstance"? main:1: note: Consider adding [builtins fixtures/isinstancelist.pyi] to your test description +[case testTupleMissingFromStubs1] +tuple() +[out] +main:1: error: Name "tuple" is not defined +main:1: note: Maybe your test fixture does not define "builtins.tuple"? +main:1: note: Consider adding [builtins fixtures/tuple.pyi] to your test description +main:1: note: Did you forget to import it from "typing"? (Suggestion: "from typing import Tuple") + +[case testTupleMissingFromStubs2] +tuple() +from typing import Tuple +x: Tuple[int, str] +[out] +main:1: error: Name "tuple" is not defined +main:1: note: Maybe your test fixture does not define "builtins.tuple"? +main:1: note: Consider adding [builtins fixtures/tuple.pyi] to your test description +main:1: note: Did you forget to import it from "typing"? (Suggestion: "from typing import Tuple") +main:3: error: Name "tuple" is not defined + [case testClassmethodMissingFromStubs] class A: @classmethod From 53d40f83c1efe300aed08debff2d4b6ee15f5ee5 Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Sun, 19 Dec 2021 13:13:23 +0100 Subject: [PATCH 75/76] Added missing Final annotations --- mypy/checkpattern.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 7c098d114a77..1ae3ad1dfd83 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,6 +1,6 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union +from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union, Final import mypy.checker from mypy.checkmember import analyze_member_access @@ -26,7 +26,7 @@ from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor -self_match_type_names = [ +self_match_type_names: Final = [ "builtins.bool", "builtins.bytearray", "builtins.bytes", @@ -40,7 +40,7 @@ "builtins.tuple", ] -non_sequence_match_type_names = [ +non_sequence_match_type_names: Final = [ "builtins.str", "builtins.bytes", "builtins.bytearray" From 52dc066692e0ecc6f78dfe2ee3a9e316b749ab2d Mon Sep 17 00:00:00 2001 From: Adrian Freund Date: Sun, 19 Dec 2021 14:26:42 +0100 Subject: [PATCH 76/76] Add comments and clarify some names --- mypy/checker.py | 27 +++++++--------------- mypy/checkpattern.py | 53 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index c60c05509474..0abd744b8aa6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3935,7 +3935,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] type_maps: List[TypeMap] = [t.captures for t in pattern_types] - self.infer_names_from_type_maps(type_maps) + self.infer_variable_types_from_type_maps(type_maps) for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): with self.binder.frame_context(can_skip=True, fall_through=2): @@ -3961,7 +3961,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass - def infer_names_from_type_maps(self, type_maps: List[TypeMap]) -> None: + def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None: all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list) for tm in type_maps: if tm is not None: @@ -5387,14 +5387,7 @@ def conditional_types_with_intersection(self, expr_type: Type, type_ranges: Optional[List[TypeRange]], ctx: Context, - ) -> Tuple[Optional[Type], Optional[Type]]: ... - - @overload - def conditional_types_with_intersection(self, - expr_type: Type, - type_ranges: Optional[List[TypeRange]], - ctx: Context, - default: None + default: None = None ) -> Tuple[Optional[Type], Optional[Type]]: ... @overload @@ -5420,7 +5413,7 @@ def conditional_types_with_intersection(self, if not isinstance(get_proper_type(yes_type), UninhabitedType) or type_ranges is None: return yes_type, no_type - # If conditions_type_map was unable to successfully narrow the expr_type + # If conditional_types was unable to successfully narrow the expr_type # using the type_ranges and concluded if-branch is unreachable, we try # computing it again using a different algorithm that tries to generate # an ad-hoc intersection between the expr_type and the type_ranges. @@ -5465,13 +5458,7 @@ def is_writable_attribute(self, node: Node) -> bool: @overload def conditional_types(current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], - ) -> Tuple[Optional[Type], Optional[Type]]: ... - - -@overload -def conditional_types(current_type: Type, - proposed_type_ranges: Optional[List[TypeRange]], - default: None + default: None = None ) -> Tuple[Optional[Type], Optional[Type]]: ... @@ -5490,7 +5477,9 @@ def conditional_types(current_type: Type, Returns a 2-tuple: The first element is the proposed type, if the expression can be the proposed type. The second element is the type it would hold - if it was not the proposed type, if any. UninhabitedType means unreachable""" + if it was not the proposed type, if any. UninhabitedType means unreachable. + None means no new information can be inferred. If default is set it is returned + instead.""" if proposed_type_ranges: proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 1ae3ad1dfd83..2c40e856be88 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,6 +1,7 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union, Final +from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union +from typing_extensions import Final import mypy.checker from mypy.checkmember import analyze_member_access @@ -47,12 +48,15 @@ ] +# For every Pattern a PatternType can be calculated. This requires recursively calculating +# the PatternTypes of the sub-patterns first. +# Using the data in the PatternType the match subject and captured names can be narrowed/inferred. PatternType = NamedTuple( 'PatternType', [ - ('type', Type), + ('type', Type), # The type the match subject can be narrowed to ('rest_type', Type), # For exhaustiveness checking. Not used yet - ('captures', Dict[Expression, Type]), + ('captures', Dict[Expression, Type]), # The variables captured by the pattern ]) @@ -75,9 +79,11 @@ class PatternChecker(PatternVisitor[PatternType]): subject_type: Type # Type of the subject to check the (sub)pattern against type_context: List[Type] - + # Types that match against self instead of their __match_args__ if used as a class pattern + # Filled in from self_match_type_names self_match_types: List[Type] - + # Types that are sequences, but don't match sequence patterns. Filled in from + # non_sequence_match_type_names non_sequence_match_types: List[Type] def __init__(self, @@ -89,8 +95,10 @@ def __init__(self, self.plugin = plugin self.type_context = [] - self.self_match_types = self.generate_types(self_match_type_names) - self.non_sequence_match_types = self.generate_types(non_sequence_match_type_names) + self.self_match_types = self.generate_types_from_names(self_match_type_names) + self.non_sequence_match_types = self.generate_types_from_names( + non_sequence_match_type_names + ) def accept(self, o: Pattern, type_context: Type) -> PatternType: self.type_context.append(type_context) @@ -286,7 +294,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_inner_type = UninhabitedType() for typ in new_inner_types: new_inner_type = join_types(new_inner_type, typ) - new_type = self.construct_iterable_child(current_type, new_inner_type) + new_type = self.construct_sequence_child(current_type, new_inner_type) if not is_subtype(new_type, current_type): new_type = current_type return PatternType(new_type, rest_type, captures) @@ -313,6 +321,15 @@ def contract_starred_pattern_types(self, star_pos: Optional[int], num_patterns: int ) -> List[Type]: + """ + Contracts a list of types in a sequence pattern depending on the position of a starred + capture pattern. + + For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str, + bytes] the contracted types are [bool, Union[int, str], bytes]. + + If star_pos in None the types are returned unchanged. + """ if star_pos is None: return types new_types = types[:star_pos] @@ -327,6 +344,12 @@ def expand_starred_pattern_types(self, star_pos: Optional[int], num_types: int ) -> List[Type]: + """ + Undoes the contraction done by contract_starred_pattern_types. + + For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended + to lenght 4 the result is [bool, int, int, str]. + """ if star_pos is None: return types new_types = types[:star_pos] @@ -579,7 +602,7 @@ def can_match_sequence(self, typ: ProperType) -> bool: # If the static type is more general than sequence the actual type could still match return is_subtype(typ, sequence) or is_subtype(sequence, typ) - def generate_types(self, type_names: List[str]) -> List[Type]: + def generate_types_from_names(self, type_names: List[str]) -> List[Type]: types: List[Type] = [] for name in type_names: try: @@ -607,9 +630,17 @@ def update_type_map(self, else: original_type_map[expr] = typ - def construct_iterable_child(self, outer_type: Type, inner_type: Type) -> Type: + def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type: + """ + If outer_type is a child class of typing.Sequence returns a new instance of + outer_type, that is a Sequence of inner_type. If outer_type is not a child class of + typing.Sequence just returns a Sequence of inner_type + + For example: + construct_sequence_child(List[int], str) = List[str] + """ sequence = self.chk.named_generic_type("typing.Sequence", [inner_type]) - if self.chk.type_is_iterable(outer_type): + if is_subtype(outer_type, self.chk.named_type("typing.Sequence")): proper_type = get_proper_type(outer_type) assert isinstance(proper_type, Instance) empty_type = fill_typevars(proper_type.type)