Skip to content

Commit

Permalink
Fix crash in astdiff and clean it up (#14497)
Browse files Browse the repository at this point in the history
Ref #14329

This fixes one of the crashes reported in the issue. In fact, using
recursive type caught this crash statically, plus another subtle crash
in `snapshot_optional_type()`, _without a single false positive_ (I was
able to cleanly type also symbol table snapshots, but decided it is not
worth the churn since we only ever compare them with `==`, supported by
~every Python object). I feel triumphant :-)
  • Loading branch information
ilevkivskyi committed Jan 22, 2023
1 parent e8c844b commit a08388c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 15 deletions.
30 changes: 18 additions & 12 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'

from __future__ import annotations

from typing import Sequence, Tuple, cast
from typing import Sequence, Tuple, Union, cast
from typing_extensions import TypeAlias as _TypeAlias

from mypy.expandtype import expand_type
Expand Down Expand Up @@ -109,11 +109,17 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
# snapshots are immutable).
#
# For example, the snapshot of the 'int' type is ('Instance', 'builtins.int', ()).
SnapshotItem: _TypeAlias = Tuple[object, ...]

# Type snapshots are strict, they must be hashable and ordered (e.g. for Unions).
Primitive: _TypeAlias = Union[str, float, int, bool] # float is for Literal[3.14] support.
SnapshotItem: _TypeAlias = Tuple[Union[Primitive, "SnapshotItem"], ...]

# Symbol snapshots can be more lenient.
SymbolSnapshot: _TypeAlias = Tuple[object, ...]


def compare_symbol_table_snapshots(
name_prefix: str, snapshot1: dict[str, SnapshotItem], snapshot2: dict[str, SnapshotItem]
name_prefix: str, snapshot1: dict[str, SymbolSnapshot], snapshot2: dict[str, SymbolSnapshot]
) -> set[str]:
"""Return names that are different in two snapshots of a symbol table.
Expand Down Expand Up @@ -155,7 +161,7 @@ def compare_symbol_table_snapshots(
return triggers


def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SnapshotItem]:
def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SymbolSnapshot]:
"""Create a snapshot description that represents the state of a symbol table.
The snapshot has a representation based on nested tuples and dicts
Expand All @@ -165,7 +171,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
things defined in other modules are represented just by the names of
the targets.
"""
result: dict[str, SnapshotItem] = {}
result: dict[str, SymbolSnapshot] = {}
for name, symbol in table.items():
node = symbol.node
# TODO: cross_ref?
Expand Down Expand Up @@ -206,7 +212,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
return result


def snapshot_definition(node: SymbolNode | None, common: tuple[object, ...]) -> tuple[object, ...]:
def snapshot_definition(node: SymbolNode | None, common: SymbolSnapshot) -> SymbolSnapshot:
"""Create a snapshot description of a symbol table node.
The representation is nested tuples and dicts. Only externally
Expand Down Expand Up @@ -290,11 +296,11 @@ def snapshot_type(typ: Type) -> SnapshotItem:
return typ.accept(SnapshotTypeVisitor())


def snapshot_optional_type(typ: Type | None) -> SnapshotItem | None:
def snapshot_optional_type(typ: Type | None) -> SnapshotItem:
if typ:
return snapshot_type(typ)
else:
return None
return ("<not set>",)


def snapshot_types(types: Sequence[Type]) -> SnapshotItem:
Expand Down Expand Up @@ -396,7 +402,7 @@ def visit_parameters(self, typ: Parameters) -> SnapshotItem:
"Parameters",
snapshot_types(typ.arg_types),
tuple(encode_optional_str(name) for name in typ.arg_names),
tuple(typ.arg_kinds),
tuple(k.value for k in typ.arg_kinds),
)

def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
Expand All @@ -407,7 +413,7 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
snapshot_types(typ.arg_types),
snapshot_type(typ.ret_type),
tuple(encode_optional_str(name) for name in typ.arg_names),
tuple(typ.arg_kinds),
tuple(k.value for k in typ.arg_kinds),
typ.is_type_obj(),
typ.is_ellipsis_args,
snapshot_types(typ.variables),
Expand Down Expand Up @@ -464,7 +470,7 @@ def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem:
return ("TypeAliasType", typ.alias.fullname, snapshot_types(typ.args))


def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[object, ...]:
def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> SymbolSnapshot:
"""Create a snapshot of the signature of a function that has no explicit signature.
If the arguments to a function without signature change, it must be
Expand All @@ -476,7 +482,7 @@ def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[obje
if isinstance(func, FuncItem):
return (tuple(func.arg_names), tuple(func.arg_kinds))
else:
result = []
result: list[SymbolSnapshot] = []
for item in func.items:
if isinstance(item, Decorator):
if item.var.type:
Expand Down
10 changes: 7 additions & 3 deletions mypy/server/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@
semantic_analysis_for_scc,
semantic_analysis_for_targets,
)
from mypy.server.astdiff import SnapshotItem, compare_symbol_table_snapshots, snapshot_symbol_table
from mypy.server.astdiff import (
SymbolSnapshot,
compare_symbol_table_snapshots,
snapshot_symbol_table,
)
from mypy.server.astmerge import merge_asts
from mypy.server.aststrip import SavedAttributes, strip_target
from mypy.server.deps import get_dependencies_of_target, merge_dependencies
Expand Down Expand Up @@ -417,7 +421,7 @@ def update_module(

t0 = time.time()
# Record symbol table snapshot of old version the changed module.
old_snapshots: dict[str, dict[str, SnapshotItem]] = {}
old_snapshots: dict[str, dict[str, SymbolSnapshot]] = {}
if module in manager.modules:
snapshot = snapshot_symbol_table(module, manager.modules[module].names)
old_snapshots[module] = snapshot
Expand Down Expand Up @@ -751,7 +755,7 @@ def get_sources(

def calculate_active_triggers(
manager: BuildManager,
old_snapshots: dict[str, dict[str, SnapshotItem]],
old_snapshots: dict[str, dict[str, SymbolSnapshot]],
new_modules: dict[str, MypyFile | None],
) -> set[str]:
"""Determine activated triggers by comparing old and new symbol tables.
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -10313,3 +10313,31 @@ a.py:3: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#varia
a.py:4: note: Revealed type is "A?"
==
a.py:4: note: Revealed type is "Union[builtins.str, builtins.int]"

[case testUnionOfSimilarCallablesCrash]
import b

[file b.py]
from a import x

[file m.py]
from typing import Union, TypeVar

T = TypeVar("T")
S = TypeVar("S")
def foo(x: T, y: S) -> Union[T, S]: ...
def f(x: int) -> int: ...
def g(*x: int) -> int: ...

[file a.py]
from m import f, g, foo
x = foo(f, g)

[file a.py.2]
from m import f, g, foo
x = foo(f, g)
reveal_type(x)
[builtins fixtures/tuple.pyi]
[out]
==
a.py:3: note: Revealed type is "Union[def (x: builtins.int) -> builtins.int, def (*x: builtins.int) -> builtins.int]"

0 comments on commit a08388c

Please sign in to comment.