Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash in astdiff and clean it up #14497

Merged
merged 1 commit into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'

from __future__ import annotations

from typing import Sequence, Tuple, cast
from typing import Sequence, Tuple, Union, cast
from typing_extensions import TypeAlias as _TypeAlias

from mypy.expandtype import expand_type
Expand Down Expand Up @@ -109,11 +109,17 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
# snapshots are immutable).
#
# For example, the snapshot of the 'int' type is ('Instance', 'builtins.int', ()).
SnapshotItem: _TypeAlias = Tuple[object, ...]

# Type snapshots are strict, they must be hashable and ordered (e.g. for Unions).
Primitive: _TypeAlias = Union[str, float, int, bool] # float is for Literal[3.14] support.
SnapshotItem: _TypeAlias = Tuple[Union[Primitive, "SnapshotItem"], ...]

# Symbol snapshots can be more lenient.
SymbolSnapshot: _TypeAlias = Tuple[object, ...]


def compare_symbol_table_snapshots(
name_prefix: str, snapshot1: dict[str, SnapshotItem], snapshot2: dict[str, SnapshotItem]
name_prefix: str, snapshot1: dict[str, SymbolSnapshot], snapshot2: dict[str, SymbolSnapshot]
) -> set[str]:
"""Return names that are different in two snapshots of a symbol table.

Expand Down Expand Up @@ -155,7 +161,7 @@ def compare_symbol_table_snapshots(
return triggers


def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SnapshotItem]:
def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SymbolSnapshot]:
"""Create a snapshot description that represents the state of a symbol table.

The snapshot has a representation based on nested tuples and dicts
Expand All @@ -165,7 +171,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
things defined in other modules are represented just by the names of
the targets.
"""
result: dict[str, SnapshotItem] = {}
result: dict[str, SymbolSnapshot] = {}
for name, symbol in table.items():
node = symbol.node
# TODO: cross_ref?
Expand Down Expand Up @@ -206,7 +212,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
return result


def snapshot_definition(node: SymbolNode | None, common: tuple[object, ...]) -> tuple[object, ...]:
def snapshot_definition(node: SymbolNode | None, common: SymbolSnapshot) -> SymbolSnapshot:
"""Create a snapshot description of a symbol table node.

The representation is nested tuples and dicts. Only externally
Expand Down Expand Up @@ -289,11 +295,11 @@ def snapshot_type(typ: Type) -> SnapshotItem:
return typ.accept(SnapshotTypeVisitor())


def snapshot_optional_type(typ: Type | None) -> SnapshotItem | None:
def snapshot_optional_type(typ: Type | None) -> SnapshotItem:
if typ:
return snapshot_type(typ)
else:
return None
return ("<not set>",)


def snapshot_types(types: Sequence[Type]) -> SnapshotItem:
Expand Down Expand Up @@ -395,7 +401,7 @@ def visit_parameters(self, typ: Parameters) -> SnapshotItem:
"Parameters",
snapshot_types(typ.arg_types),
tuple(encode_optional_str(name) for name in typ.arg_names),
tuple(typ.arg_kinds),
tuple(k.value for k in typ.arg_kinds),
)

def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
Expand All @@ -406,7 +412,7 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
snapshot_types(typ.arg_types),
snapshot_type(typ.ret_type),
tuple(encode_optional_str(name) for name in typ.arg_names),
tuple(typ.arg_kinds),
tuple(k.value for k in typ.arg_kinds),
typ.is_type_obj(),
typ.is_ellipsis_args,
snapshot_types(typ.variables),
Expand Down Expand Up @@ -463,7 +469,7 @@ def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem:
return ("TypeAliasType", typ.alias.fullname, snapshot_types(typ.args))


def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[object, ...]:
def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> SymbolSnapshot:
"""Create a snapshot of the signature of a function that has no explicit signature.

If the arguments to a function without signature change, it must be
Expand All @@ -475,7 +481,7 @@ def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[obje
if isinstance(func, FuncItem):
return (tuple(func.arg_names), tuple(func.arg_kinds))
else:
result = []
result: list[SymbolSnapshot] = []
for item in func.items:
if isinstance(item, Decorator):
if item.var.type:
Expand Down
10 changes: 7 additions & 3 deletions mypy/server/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@
semantic_analysis_for_scc,
semantic_analysis_for_targets,
)
from mypy.server.astdiff import SnapshotItem, compare_symbol_table_snapshots, snapshot_symbol_table
from mypy.server.astdiff import (
SymbolSnapshot,
compare_symbol_table_snapshots,
snapshot_symbol_table,
)
from mypy.server.astmerge import merge_asts
from mypy.server.aststrip import SavedAttributes, strip_target
from mypy.server.deps import get_dependencies_of_target, merge_dependencies
Expand Down Expand Up @@ -417,7 +421,7 @@ def update_module(

t0 = time.time()
# Record symbol table snapshot of old version the changed module.
old_snapshots: dict[str, dict[str, SnapshotItem]] = {}
old_snapshots: dict[str, dict[str, SymbolSnapshot]] = {}
if module in manager.modules:
snapshot = snapshot_symbol_table(module, manager.modules[module].names)
old_snapshots[module] = snapshot
Expand Down Expand Up @@ -751,7 +755,7 @@ def get_sources(

def calculate_active_triggers(
manager: BuildManager,
old_snapshots: dict[str, dict[str, SnapshotItem]],
old_snapshots: dict[str, dict[str, SymbolSnapshot]],
new_modules: dict[str, MypyFile | None],
) -> set[str]:
"""Determine activated triggers by comparing old and new symbol tables.
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -10315,3 +10315,31 @@ a.py:3: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#varia
a.py:4: note: Revealed type is "A?"
==
a.py:4: note: Revealed type is "Union[builtins.str, builtins.int]"

[case testUnionOfSimilarCallablesCrash]
import b

[file b.py]
from a import x

[file m.py]
from typing import Union, TypeVar

T = TypeVar("T")
S = TypeVar("S")
def foo(x: T, y: S) -> Union[T, S]: ...
def f(x: int) -> int: ...
def g(*x: int) -> int: ...

[file a.py]
from m import f, g, foo
x = foo(f, g)

[file a.py.2]
from m import f, g, foo
x = foo(f, g)
reveal_type(x)
[builtins fixtures/tuple.pyi]
[out]
==
a.py:3: note: Revealed type is "Union[def (x: builtins.int) -> builtins.int, def (*x: builtins.int) -> builtins.int]"