From 3260bb00062a3809bfab20f3119cca88cbd2106b Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 30 Sep 2023 06:55:19 -0700 Subject: [PATCH] More progress on PEP 695 (#692) --- docs/changelog.md | 2 +- pyanalyze/annotations.py | 10 +++- pyanalyze/name_check_visitor.py | 99 +++++++++++++++++++++++++++++++-- pyanalyze/stacked_scopes.py | 1 + pyanalyze/test_type_aliases.py | 96 ++++++++++++++++++++++++++++++++ pyanalyze/value.py | 30 +++++++++- 6 files changed, 227 insertions(+), 11 deletions(-) create mode 100644 pyanalyze/test_type_aliases.py diff --git a/docs/changelog.md b/docs/changelog.md index 945ffefb..2ab0d3d2 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,9 +2,9 @@ ## Unreleased +- Partial support for PEP 695-style type aliases (#690, #692) - Fix tests to account for new `typeshed_client` release (#694) -- Partial support for PEP 695-style type aliases (#690) - Add option to disable all error codes (#659) - Add hacky fix for bugs with hashability on type objects (#689) - Show an error on calls to `typing.Any` (#688) diff --git a/pyanalyze/annotations.py b/pyanalyze/annotations.py index 5399317d..6d0a0ba5 100644 --- a/pyanalyze/annotations.py +++ b/pyanalyze/annotations.py @@ -46,6 +46,7 @@ TypeVar, Union, ) +import typing_extensions import qcore @@ -539,11 +540,14 @@ def _type_from_runtime( def make_type_var_value(tv: TypeVarLike, ctx: Context) -> TypeVarValue: - if tv.__bound__ is not None: + if ( + isinstance(tv, (TypeVar, typing_extensions.TypeVar)) + and tv.__bound__ is not None + ): bound = _type_from_runtime(tv.__bound__, ctx) else: bound = None - if isinstance(tv, TypeVar) and tv.__constraints__: + if isinstance(tv, (TypeVar, typing_extensions.TypeVar)) and tv.__constraints__: constraints = tuple( _type_from_runtime(constraint, ctx) for constraint in tv.__constraints__ ) @@ -656,7 +660,7 @@ def _type_from_value( return _type_from_runtime( value.val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack ) - elif isinstance(value, TypeVarValue): + elif isinstance(value, (TypeVarValue, TypeAliasValue)): return value elif isinstance(value, MultiValuedValue): return unite_values( diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 36920769..cf7cf50a 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -23,6 +23,7 @@ import sys import traceback import types +import typing from argparse import ArgumentParser from dataclasses import dataclass from itertools import chain @@ -100,6 +101,7 @@ from .predicates import EqualsPredicate, InPredicate from .reexport import ImplicitReexportTracker from .safe import ( + all_of_type, is_dataclass_type, is_hashable, safe_getattr, @@ -162,6 +164,8 @@ DefiniteValueExtension, DeprecatedExtension, SkipDeprecatedExtension, + TypeAlias, + TypeAliasValue, annotate_value, AnnotatedValue, AnySource, @@ -1754,22 +1758,22 @@ def visit_ClassDef(self, node: ast.ClassDef) -> Value: value, _ = self._set_name_in_scope(node.name, node, value) return value - def _get_class_object(self, node: ast.ClassDef) -> Value: + def _get_local_object(self, name: str, node: ast.AST) -> Value: if self.scopes.scope_type() == ScopeType.module_scope: - return self.scopes.get(node.name, node, self.state) + return self.scopes.get(name, node, self.state) elif ( self.scopes.scope_type() == ScopeType.class_scope and self.current_class is not None and hasattr(self.current_class, "__dict__") ): - runtime_obj = self.current_class.__dict__.get(node.name) + runtime_obj = self.current_class.__dict__.get(name) if isinstance(runtime_obj, type): return KnownValue(runtime_obj) return AnyValue(AnySource.inference) def _visit_class_and_get_value(self, node: ast.ClassDef) -> Value: if self._is_checking(): - cls_obj = self._get_class_object(node) + cls_obj = self._get_local_object(node.name, node) module = self.module if isinstance(cls_obj, MultiValuedValue) and module is not None: @@ -4506,6 +4510,93 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None: # syntax like 'x = y = 0' results in multiple targets self.visit(node.target) + if sys.version_info >= (3, 12): + + def visit_TypeAlias(self, node: ast.TypeAlias) -> Value: + assert isinstance(node.name, ast.Name) + name = node.name.id + alias_val = self._get_local_object(name, node) + if isinstance(alias_val, KnownValue) and isinstance( + alias_val.val, typing.TypeAliasType + ): + alias_obj = alias_val.val + else: + alias_obj = None + type_param_values = [] + if self._is_checking(): + if node.type_params: + with self.scopes.add_scope( + ScopeType.annotation_scope, + scope_node=node, + scope_object=alias_obj, + ): + type_param_values = [ + self.visit(param) for param in node.type_params + ] + assert all_of_type(type_param_values, TypeVarValue) + with self.scopes.add_scope( + ScopeType.annotation_scope, + scope_node=node, + scope_object=alias_obj, + ): + value = self.visit(node.value) + + else: + with self.scopes.add_scope( + ScopeType.annotation_scope, + scope_node=node, + scope_object=alias_obj, + ): + value = self.visit(node.value) + else: + value = None + if alias_obj is None: + if value is None: + alias_val = AnyValue(AnySource.inference) + else: + alias_val = TypeAliasValue( + name, + self.module.__name__ if self.module is not None else "", + TypeAlias( + lambda: type_from_value(value, self, node), + lambda: tuple(val.typevar for val in type_param_values), + ), + ) + set_value, _ = self._set_name_in_scope(name, node, alias_val) + return set_value + + def visit_TypeVar(self, node: ast.TypeVar) -> Value: + bound = constraints = None + if node.bound is not None: + if isinstance(node.bound, ast.Tuple): + constraints = [self.visit(elt) for elt in node.bound.elts] + else: + bound = self.visit(node.bound) + tv = TypeVar(node.name) + typevar = TypeVarValue( + tv, + type_from_value(bound, self, node) if bound is not None else None, + ( + tuple(type_from_value(c, self, node) for c in constraints) + if constraints is not None + else () + ), + ) + self._set_name_in_scope(node.name, node, typevar) + return typevar + + def visit_ParamSpec(self, node: ast.ParamSpec) -> Value: + ps = typing.ParamSpec(node.name) + typevar = TypeVarValue(ps, is_paramspec=True) + self._set_name_in_scope(node.name, node, typevar) + return typevar + + def visit_TypeVarTuple(self, node: ast.TypeVarTuple) -> Value: + tv = TypeVar(node.name) + typevar = TypeVarValue(tv, is_typevartuple=True) + self._set_name_in_scope(node.name, node, typevar) + return typevar + def visit_Name(self, node: ast.Name, force_read: bool = False) -> Value: return self.composite_from_name(node, force_read=force_read).value diff --git a/pyanalyze/stacked_scopes.py b/pyanalyze/stacked_scopes.py index b461db91..c9aa7ac6 100644 --- a/pyanalyze/stacked_scopes.py +++ b/pyanalyze/stacked_scopes.py @@ -91,6 +91,7 @@ class ScopeType(enum.Enum): module_scope = 2 class_scope = 3 function_scope = 4 + annotation_scope = 5 # Nodes as used in scopes can be any object, as long as they are hashable. diff --git a/pyanalyze/test_type_aliases.py b/pyanalyze/test_type_aliases.py new file mode 100644 index 00000000..0411f7a5 --- /dev/null +++ b/pyanalyze/test_type_aliases.py @@ -0,0 +1,96 @@ +# static analysis: ignore +from .test_name_check_visitor import TestNameCheckVisitorBase +from .test_node_visitor import assert_passes, skip_before + + +class TestRecursion(TestNameCheckVisitorBase): + @assert_passes() + def test(self): + from typing import Dict, List, Union + + JSON = Union[Dict[str, "JSON"], List["JSON"], int, str, float, bool, None] + + def f(x: JSON): + pass + + def capybara(): + f([]) + f([1, 2, 3]) + f([[{1}]]) # TODO this should throw an error + + +class TestTypeAliasType(TestNameCheckVisitorBase): + @assert_passes() + def test_typing_extensions(self): + from typing_extensions import TypeAliasType, assert_type + + MyType = TypeAliasType("MyType", int) + + def f(x: MyType): + assert_type(x, MyType) + assert_type(x + 1, int) + + def capybara(i: int, s: str): + f(i) + f(s) # E: incompatible_argument + + @assert_passes() + def test_typing_extensions_generic(self): + from typing_extensions import TypeAliasType, assert_type + from typing import TypeVar, Union, List, Set + + T = TypeVar("T") + MyType = TypeAliasType("MyType", Union[List[T], Set[T]], type_params=(T,)) + + def f(x: MyType[int]): + assert_type(x, MyType[int]) + assert_type(list(x), List[int]) + + def capybara(i: int, s: str): + f([i]) + f([s]) # E: incompatible_argument + + @skip_before((3, 12)) + def test_312(self): + self.assert_passes(""" + from typing_extensions import assert_type + type MyType = int + + def f(x: MyType): + assert_type(x, MyType) + assert_type(x + 1, int) + + def capybara(i: int, s: str): + f(i) + f(s) # E: incompatible_argument + """) + + @skip_before((3, 12)) + def test_312_generic(self): + self.assert_passes(""" + from typing_extensions import assert_type + type MyType[T] = list[T] | set[T] + + def f(x: MyType[int]): + assert_type(x, MyType[int]) + assert_type(list(x), list[int]) + + def capybara(i: int, s: str): + f([i]) + f([s]) # E: incompatible_argument + """) + + @skip_before((3, 12)) + def test_312_local_alias(self): + self.assert_passes(""" + from typing_extensions import assert_type + + def capybara(): + type MyType = int + def f(x: MyType): + assert_type(x, MyType) + assert_type(x + 1, int) + + f(1) + f("x") # E: incompatible_argument + """) diff --git a/pyanalyze/value.py b/pyanalyze/value.py index 9c3103fa..6b6a484c 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -25,6 +25,7 @@ def function(x: int, y: list[int], z: Any): from collections import deque from dataclasses import dataclass, field, InitVar from itertools import chain +import sys from types import FunctionType from typing import ( Any, @@ -60,9 +61,31 @@ def function(x: int, y: list[int], z: Any): KNOWN_MUTABLE_TYPES = (list, set, dict, deque) ITERATION_LIMIT = 1000 -TypeVarLike = Union[ - ExternalType["typing.TypeVar"], ExternalType["typing_extensions.ParamSpec"] -] +if sys.version_info >= (3, 11): + TypeVarLike = Union[ + ExternalType["typing.TypeVar"], + ExternalType["typing_extensions.TypeVar"], + ExternalType["typing.ParamSpec"], + ExternalType["typing_extensions.ParamSpec"], + ExternalType["typing.TypeVarTuple"], + ExternalType["typing_extensions.TypeVarTuple"], + ] +elif sys.version_info >= (3, 10): + TypeVarLike = Union[ + ExternalType["typing.TypeVar"], + ExternalType["typing_extensions.TypeVar"], + ExternalType["typing.ParamSpec"], + ExternalType["typing_extensions.ParamSpec"], + ExternalType["typing_extensions.TypeVarTuple"], + ] +else: + TypeVarLike = Union[ + ExternalType["typing.TypeVar"], + ExternalType["typing_extensions.TypeVar"], + ExternalType["typing_extensions.ParamSpec"], + ExternalType["typing_extensions.TypeVarTuple"], + ] + TypeVarMap = Mapping[TypeVarLike, ExternalType["pyanalyze.value.Value"]] BoundsMap = Mapping[TypeVarLike, Sequence[ExternalType["pyanalyze.value.Bound"]]] GenericBases = Mapping[Union[type, str], TypeVarMap] @@ -1737,6 +1760,7 @@ class TypeVarValue(Value): bound: Optional[Value] = None constraints: Sequence[Value] = () is_paramspec: bool = False + is_typevartuple: bool = False # unsupported def substitute_typevars(self, typevars: TypeVarMap) -> Value: return typevars.get(self.typevar, self)