Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add |= and | operators support for TypedDict #16249

Merged
merged 15 commits into from
Oct 23, 2023
19 changes: 15 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7452,14 +7452,25 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> tuple[bool, st
"""
typ = get_proper_type(typ)
method = operators.op_methods[operator]
existing_method = None
if isinstance(typ, Instance):
if operator in operators.ops_with_inplace_method:
inplace_method = "__i" + method[2:]
if typ.type.has_readable_member(inplace_method):
return True, inplace_method
existing_method = _find_inplace_method(typ, method, operator)
elif isinstance(typ, TypedDictType):
existing_method = _find_inplace_method(typ.fallback, method, operator)

if existing_method is not None:
return True, existing_method
return False, method


def _find_inplace_method(inst: Instance, method: str, operator: str) -> str | None:
if operator in operators.ops_with_inplace_method:
inplace_method = "__i" + method[2:]
if inst.type.has_readable_member(inplace_method):
return inplace_method
return None


def is_valid_inferred_type(typ: Type, is_lvalue_final: bool = False) -> bool:
"""Is an inferred type valid and needs no further refinement?

Expand Down
55 changes: 53 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from __future__ import annotations

import enum
import itertools
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, ClassVar, Final, Iterable, Iterator, List, Optional, Sequence, cast
from typing_extensions import TypeAlias as _TypeAlias, overload
from typing_extensions import TypeAlias as _TypeAlias, assert_never, overload

import mypy.checker
import mypy.errorcodes as codes
Expand Down Expand Up @@ -272,6 +273,20 @@ class Finished(Exception):
"""Raised if we can terminate overload argument check early (no match)."""


@enum.unique
class UseReverse(enum.Enum):
"""Used in `visit_op_expr` to enable or disable reverse method checks."""

DEFAULT = 0
ALWAYS = 1
NEVER = 2


USE_REVERSE_DEFAULT: Final = UseReverse.DEFAULT
USE_REVERSE_ALWAYS: Final = UseReverse.ALWAYS
USE_REVERSE_NEVER: Final = UseReverse.NEVER


class ExpressionChecker(ExpressionVisitor[Type]):
"""Expression type checker.

Expand Down Expand Up @@ -3366,6 +3381,24 @@ def visit_op_expr(self, e: OpExpr) -> Type:
return proper_left_type.copy_modified(
items=proper_left_type.items + [UnpackType(mapped)]
)

use_reverse: UseReverse = USE_REVERSE_DEFAULT
if e.op == "|":
if is_named_instance(proper_left_type, "builtins.dict"):
# This is a special case for `dict | TypedDict`.
# 1. Find `dict | TypedDict` case
# 2. Switch `dict.__or__` to `TypedDict.__ror__` (the same from both runtime and typing perspective)
proper_right_type = get_proper_type(self.accept(e.right))
if isinstance(proper_right_type, TypedDictType):
use_reverse = USE_REVERSE_ALWAYS
if isinstance(proper_left_type, TypedDictType):
# This is the reverse case: `TypedDict | dict`,
# simply do not allow the reverse checking:
# do not call `__dict__.__ror__`.
proper_right_type = get_proper_type(self.accept(e.right))
if is_named_instance(proper_right_type, "builtins.dict"):
use_reverse = USE_REVERSE_NEVER

if TYPE_VAR_TUPLE in self.chk.options.enable_incomplete_feature:
# Handle tuple[X, ...] + tuple[Y, Z] = tuple[*tuple[X, ...], Y, Z].
if (
Expand All @@ -3385,7 +3418,25 @@ def visit_op_expr(self, e: OpExpr) -> Type:

if e.op in operators.op_methods:
method = operators.op_methods[e.op]
result, method_type = self.check_op(method, left_type, e.right, e, allow_reverse=True)
if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER:
result, method_type = self.check_op(
method,
base_type=left_type,
arg=e.right,
context=e,
allow_reverse=use_reverse is UseReverse.DEFAULT,
)
elif use_reverse is UseReverse.ALWAYS:
result, method_type = self.check_op(
# The reverse operator here gives better error messages:
operators.reverse_op_methods[method],
base_type=self.accept(e.right),
arg=e.left,
context=e,
allow_reverse=False,
)
else:
assert_never(use_reverse)
e.method_type = method_type
return result
else:
Expand Down
22 changes: 18 additions & 4 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,21 @@ def get_method_signature_hook(
return typed_dict_setdefault_signature_callback
elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
return typed_dict_pop_signature_callback
elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}:
return typed_dict_update_signature_callback
elif fullname == "_ctypes.Array.__setitem__":
return ctypes.array_setitem_callback
elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_callback

typed_dict_updates = set()
for n in TPDICT_FB_NAMES:
typed_dict_updates.add(n + ".update")
typed_dict_updates.add(n + ".__or__")
typed_dict_updates.add(n + ".__ror__")
typed_dict_updates.add(n + ".__ior__")

if fullname in typed_dict_updates:
return typed_dict_update_signature_callback

return None

def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
Expand Down Expand Up @@ -401,11 +410,16 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.update."""
"""Try to infer a better signature type for methods that update `TypedDict`.

This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`,
and `TypedDict.__ior__`.
"""
signature = ctx.default_signature
if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
arg_type = get_proper_type(signature.arg_types[0])
assert isinstance(arg_type, TypedDictType)
if not isinstance(arg_type, TypedDictType):
return signature
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
if ctx.args and ctx.args[0]:
Expand Down
143 changes: 143 additions & 0 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -3236,3 +3236,146 @@ def foo(x: int) -> Foo: ...
f: Foo = {**foo("no")} # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]


[case testTypedDictWith__or__method]
from typing import Dict
from mypy_extensions import TypedDict

class Foo(TypedDict):
key: int

foo1: Foo = {'key': 1}
foo2: Foo = {'key': 2}

reveal_type(foo1 | foo2) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
reveal_type(foo1 | {'key': 1}) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
reveal_type(foo1 | {'key': 'a'}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
reveal_type(foo1 | {}) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"

d1: Dict[str, int]
d2: Dict[int, str]

reveal_type(foo1 | d1) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
foo1 | d2 # E: Unsupported operand types for | ("Foo" and "Dict[int, str]")


class Bar(TypedDict):
key: int
value: str

bar: Bar
reveal_type(bar | {}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type(bar | {'key': 1, 'value': 'v'}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type(bar | {'key': 1}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type(bar | {'value': 'v'}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type(bar | {'key': 'a'}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
reveal_type(bar | {'value': 1}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
reveal_type(bar | {'key': 'a', 'value': 1}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"

reveal_type(bar | foo1) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type(bar | d1) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
bar | d2 # E: Unsupported operand types for | ("Bar" and "Dict[int, str]")
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict-iror.pyi]

[case testTypedDictWith__or__method_error]
from mypy_extensions import TypedDict

class Foo(TypedDict):
key: int

foo: Foo = {'key': 1}
foo | 1

class SubDict(dict): ...
foo | SubDict()
[out]
main:7: error: No overload variant of "__or__" of "TypedDict" matches argument type "int"
main:7: note: Possible overload variants:
main:7: note: def __or__(self, TypedDict({'key'?: int}), /) -> Foo
main:7: note: def __or__(self, Dict[str, Any], /) -> Dict[str, object]
main:10: error: No overload variant of "__ror__" of "dict" matches argument type "Foo"
main:10: note: Possible overload variants:
main:10: note: def __ror__(self, Dict[Any, Any], /) -> Dict[Any, Any]
main:10: note: def [T, T2] __ror__(self, Dict[T, T2], /) -> Dict[Union[Any, T], Union[Any, T2]]
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict-iror.pyi]

[case testTypedDictWith__ror__method]
from typing import Dict
from mypy_extensions import TypedDict

class Foo(TypedDict):
key: int

foo: Foo = {'key': 1}

reveal_type({'key': 1} | foo) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
reveal_type({'key': 'a'} | foo) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
reveal_type({} | foo) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})"
{1: 'a'} | foo # E: Dict entry 0 has incompatible type "int": "str"; expected "str": "Any"

d1: Dict[str, int]
d2: Dict[int, str]

reveal_type(d1 | foo) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
d2 | foo # E: Unsupported operand types for | ("Dict[int, str]" and "Foo")
1 | foo # E: Unsupported left operand type for | ("int")


class Bar(TypedDict):
key: int
value: str

bar: Bar
reveal_type({} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type({'key': 1, 'value': 'v'} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type({'key': 1} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type({'value': 'v'} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})"
reveal_type({'key': 'a'} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
reveal_type({'value': 1} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
reveal_type({'key': 'a', 'value': 1} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"

reveal_type(d1 | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
d2 | bar # E: Unsupported operand types for | ("Dict[int, str]" and "Bar")
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict-iror.pyi]

[case testTypedDictWith__ior__method]
from typing import Dict
from mypy_extensions import TypedDict

class Foo(TypedDict):
key: int
sobolevn marked this conversation as resolved.
Show resolved Hide resolved

foo: Foo = {'key': 1}
foo |= {'key': 2}

foo |= {}
foo |= {'key': 'a', 'b': 'a'} # E: Expected TypedDict key "key" but found keys ("key", "b") \
# E: Incompatible types (expression has type "str", TypedDict item "key" has type "int")
foo |= {'b': 2} # E: Unexpected TypedDict key "b"

d1: Dict[str, int]
d2: Dict[int, str]

foo |= d1 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'key'?: int})"
foo |= d2 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[int, str]"; expected "TypedDict({'key'?: int})"


class Bar(TypedDict):
key: int
value: str

bar: Bar
bar |= {}
bar |= {'key': 1, 'value': 'a'}
bar |= {'key': 'a', 'value': 'a', 'b': 'a'} # E: Expected TypedDict keys ("key", "value") but found keys ("key", "value", "b") \
# E: Incompatible types (expression has type "str", TypedDict item "key" has type "int")

bar |= foo
bar |= d1 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'key'?: int, 'value'?: str})"
bar |= d2 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[int, str]"; expected "TypedDict({'key'?: int, 'value'?: str})"
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict-iror.pyi]
19 changes: 18 additions & 1 deletion test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from _typeshed import SupportsKeysAndGetItem
import _typeshed
from typing import (
TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload, Optional, Union, Sequence
TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload, Optional, Union, Sequence,
Self,
)

T = TypeVar('T')
T2 = TypeVar('T2')
KT = TypeVar('KT')
VT = TypeVar('VT')

Expand Down Expand Up @@ -34,6 +36,21 @@ class dict(Mapping[KT, VT]):
def get(self, k: KT, default: Union[VT, T]) -> Union[VT, T]: pass
def __len__(self) -> int: ...

# This was actually added in 3.9:
@overload
def __or__(self, __value: dict[KT, VT]) -> dict[KT, VT]: ...
@overload
def __or__(self, __value: dict[T, T2]) -> dict[Union[KT, T], Union[VT, T2]]: ...
@overload
def __ror__(self, __value: dict[KT, VT]) -> dict[KT, VT]: ...
@overload
def __ror__(self, __value: dict[T, T2]) -> dict[Union[KT, T], Union[VT, T2]]: ...
# dict.__ior__ should be kept roughly in line with MutableMapping.update()
@overload # type: ignore[misc]
def __ior__(self, __value: _typeshed.SupportsKeysAndGetItem[KT, VT]) -> Self: ...
@overload
def __ior__(self, __value: Iterable[Tuple[KT, VT]]) -> Self: ...

class int: # for convenience
def __add__(self, x: Union[int, complex]) -> int: pass
def __radd__(self, x: int) -> int: pass
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-async.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ClassVar = 0
Final = 0
Literal = 0
NoReturn = 0
Self = 0

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Literal = 0
TypedDict = 0
NoReturn = 0
NewType = 0
Self = 0

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-medium.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ NoReturn = 0
NewType = 0
TypeAlias = 0
LiteralString = 0
Self = 0

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
Expand Down
Loading