Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix edge-case stubtest crashes when an instance of an enum.Flag that is not a member of that enum.Flag is used as a parameter default #15933

Merged
merged 5 commits into from Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypy/stubtest.py
Expand Up @@ -1553,7 +1553,7 @@ def anytype() -> mypy.types.AnyType:
value: bool | int | str
if isinstance(runtime, bytes):
value = bytes_to_human_readable_repr(runtime)
elif isinstance(runtime, enum.Enum):
elif isinstance(runtime, enum.Enum) and isinstance(runtime.name, str):
value = runtime.name
elif isinstance(runtime, (bool, int, str)):
value = runtime
Expand Down
103 changes: 100 additions & 3 deletions mypy/test/teststubtest.py
Expand Up @@ -64,6 +64,7 @@ def __init__(self, name: str) -> None: ...

class Coroutine(Generic[_T_co, _S, _R]): ...
class Iterable(Generic[_T_co]): ...
class Iterator(Iterable[_T_co]): ...
class Mapping(Generic[_K, _V]): ...
class Match(Generic[AnyStr]): ...
class Sequence(Iterable[_T_co]): ...
Expand All @@ -86,7 +87,9 @@ def __init__(self) -> None: pass
def __repr__(self) -> str: pass
class type: ...

class tuple(Sequence[T_co], Generic[T_co]): ...
class tuple(Sequence[T_co], Generic[T_co]):
def __ge__(self, __other: tuple[T_co, ...]) -> bool: pass

class dict(Mapping[KT, VT]): ...

class function: pass
Expand All @@ -105,6 +108,39 @@ def classmethod(f: T) -> T: ...
def staticmethod(f: T) -> T: ...
"""

stubtest_enum_stub = """
import sys
from typing import Any, TypeVar, Iterator

_T = TypeVar('_T')

class EnumMeta(type):
def __len__(self) -> int: pass
def __iter__(self: type[_T]) -> Iterator[_T]: pass
def __reversed__(self: type[_T]) -> Iterator[_T]: pass
def __getitem__(self: type[_T], name: str) -> _T: pass

class Enum(metaclass=EnumMeta):
def __new__(cls: type[_T], value: object) -> _T: pass
def __repr__(self) -> str: pass
def __str__(self) -> str: pass
def __format__(self, format_spec: str) -> str: pass
def __hash__(self) -> Any: pass
def __reduce_ex__(self, proto: Any) -> Any: pass
name: str
value: Any

class Flag(Enum):
def __or__(self: _T, other: _T) -> _T: pass
def __and__(self: _T, other: _T) -> _T: pass
def __xor__(self: _T, other: _T) -> _T: pass
def __invert__(self: _T) -> _T: pass
if sys.version_info >= (3, 11):
__ror__ = __or__
__rand__ = __and__
__rxor__ = __xor__
Comment on lines +112 to +141
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was required because the standard enum fixture is missing many methods on enum.Flag, and this was causing stubtest to complain about "missing methods in the stub" when I created a test case that used enum.Flag

class Flag(Enum):
def __or__(self: _T, other: Union[int, _T]) -> _T: pass

I initially tried adding the missing methods to the standard enum fixture, but this caused many unrelated tests to fail

"""


def run_stubtest(
stub: str, runtime: str, options: list[str], config_file: str | None = None
Expand All @@ -114,6 +150,8 @@ def run_stubtest(
f.write(stubtest_builtins_stub)
with open("typing.pyi", "w") as f:
f.write(stubtest_typing_stub)
with open("enum.pyi", "w") as f:
f.write(stubtest_enum_stub)
with open(f"{TEST_MODULE_NAME}.pyi", "w") as f:
f.write(stub)
with open(f"{TEST_MODULE_NAME}.py", "w") as f:
Expand Down Expand Up @@ -954,23 +992,82 @@ def fizz(self): pass

@collect_cases
def test_enum(self) -> Iterator[Case]:
yield Case(stub="import enum", runtime="import enum", error=None)
yield Case(
stub="""
import enum
class X(enum.Enum):
a: int
b: str
c: str
""",
runtime="""
import enum
class X(enum.Enum):
a = 1
b = "asdf"
c = 2
""",
error="X.c",
)
yield Case(
stub="""
class Flags1(enum.Flag):
a: int
b: int
def foo(x: Flags1 = ...) -> None: ...
""",
runtime="""
class Flags1(enum.Flag):
a = 1
b = 2
def foo(x=Flags1.a|Flags1.b): pass
""",
error=None,
)
yield Case(
stub="""
class Flags2(enum.Flag):
a: int
b: int
def bar(x: Flags2 | None = None) -> None: ...
""",
runtime="""
class Flags2(enum.Flag):
a = 1
b = 2
def bar(x=Flags2.a|Flags2.b): pass
""",
error="bar",
)
yield Case(
stub="""
class Flags3(enum.Flag):
a: int
b: int
def baz(x: Flags3 | None = ...) -> None: ...
""",
runtime="""
class Flags3(enum.Flag):
a = 1
b = 2
def baz(x=Flags3(0)): pass
""",
error=None,
)
yield Case(
stub="""
class Flags4(enum.Flag):
a: int
b: int
def spam(x: Flags4 | None = None) -> None: ...
""",
runtime="""
class Flags4(enum.Flag):
a = 1
b = 2
def spam(x=Flags4(0)): pass
""",
error="spam",
)

@collect_cases
def test_decorator(self) -> Iterator[Case]:
Expand Down