Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,12 @@ def names_approx_match(a: str, b: str) -> bool:
)


def _stub_default_matches_runtime(evaluated_stub_default: object, runtime: object) -> bool:
# We want the types to match exactly, e.g. in case the stub has
# `bar: bool = True` and the runtime has `bar = 1` (or vice versa).
return evaluated_stub_default == runtime and type(evaluated_stub_default) is type(runtime)


def _verify_arg_default_value(
stub_arg: nodes.Argument, runtime_arg: inspect.Parameter
) -> Iterator[str]:
Expand Down Expand Up @@ -625,12 +631,7 @@ def _verify_arg_default_value(
if (
stub_default is not UNKNOWN
and stub_default is not ...
and (
stub_default != runtime_arg.default
# We want the types to match exactly, e.g. in case the stub has
# True and the runtime has 1 (or vice versa).
or type(stub_default) is not type(runtime_arg.default) # noqa: E721
)
and not _stub_default_matches_runtime(stub_default, runtime_arg.default)
):
yield (
f'runtime argument "{runtime_arg.name}" '
Expand Down Expand Up @@ -1019,6 +1020,37 @@ def verify_var(
yield Error(
object_path, f"variable differs from runtime type {runtime_type}", stub, runtime
)
return

# If the stub has a default value and it's not an enum,
# attempt to verify that the default value matches the runtime
if isinstance(runtime, enum.Enum) or not stub.has_explicit_value:
return
stub_type = mypy.types.get_proper_type(stub.type)
if not isinstance(stub_type, mypy.types.Instance):
return

node_for_stub_value = stub_type.last_known_value
if node_for_stub_value is None:
return

stub_value: object
if isinstance(node_for_stub_value.value, (int, bool, float)):
stub_value = node_for_stub_value.value
elif isinstance(node_for_stub_value.value, str):
fallback = node_for_stub_value.fallback.type
if fallback.fullname == "builtins.str":
stub_value = node_for_stub_value.value
elif fallback.fullname == "builtins.bytes":
stub_value = bytes(node_for_stub_value.value, "utf-8")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about these codecs to decide whether this is an actual problem, but FWIW mypyc converts BytesExpr.value into bytes in a bit more involved way.

else:
return
else:
return

if not _stub_default_matches_runtime(stub_value, runtime):
msg = f"default value for variable differs from runtime {runtime!r}"
yield Error(object_path, msg, stub, runtime, stub_desc=repr(stub_value))


@verify.register(nodes.OverloadedFuncDef)
Expand Down
72 changes: 72 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,78 @@ def read_write_attr(self, val): self._val = val
""",
error=None,
)
yield Case(
stub="""
from typing_extensions import Final
final1: Final = 1
final2: Final = 1
""",
runtime="""
final1 = 2
final2 = 1
""",
error="final1",
)
yield Case(
stub="""
from typing_extensions import Final
final3: Final = 1.5
final4: Final = 1.5
""",
runtime="""
final3 = 2.5
final4 = 1.5
""",
error="final3",
)
yield Case(
stub="""
from typing_extensions import Final
final5: Final = "foo"
final6: Final = "foo"
""",
runtime="""
final5 = "bar"
final6 = "foo"
""",
error="final5",
)
yield Case(
stub="""
from typing_extensions import Final
final7: Final = True
final8: Final = True
""",
runtime="""
final7 = False
final8 = True
""",
error="final7",
)
yield Case(
stub="""
from typing_extensions import Final
final9: Final = b"foo"
final10: Final = b"bar"
""",
runtime="""
final9 = b"bar"
final10 = b"bar"
""",
error="final9",
)
yield Case(
stub="""
from typing_extensions import Final
class MatchMaker:
__match_args__: Final = ("foo",)
""",
runtime="""
class MatchMaker:
__match_args__ = ("foo",)
""",
error=None,
)

@collect_cases
def test_type_alias(self) -> Iterator[Case]:
Expand Down