Skip to content

Commit

Permalink
Support flexible TypedDict creation/update (#15425)
Browse files Browse the repository at this point in the history
Fixes #9408
Fixes #4122
Fixes #6462
Supersedes #13353

This PR enables two similar technically unsafe behaviors for TypedDicts,
as @JukkaL explained in
#6462 (comment)
allowing an "incomplete" TypedDict as an argument to `.update()` is
technically unsafe (and a similar argument applies to `**` syntax in
TypedDict literals). These are however very common patterns (judging
from number of duplicates to above issues), so I think we should support
them. Here is what I propose:
* Always support cases that are safe (like passing the type itself to
`update`)
* Allow popular but technically unsafe cases _by default_
* Have a new flag (as part of `--strict`) to fall back to current
behavior

Note that unfortunately we can't use just a custom new error code, since
we need to conditionally tweak some types in a plugin. Btw there are
couple TODOs I add here:
* First is for unsafe behavior for repeated TypedDict keys. This is not
new, I just noticed it when working on this
* Second is for tricky corner case involving multiple `**` items where
we may have false-negatives in strict mode.

Note that I don't test all the possible combinations here (since the
phase space is huge), but I think I am testing all main ingredients (and
I will be glad to add more if needed):
* All syntax variants for TypedDicts creation are handled
* Various shadowing/overrides scenarios
* Required vs non-required keys handling
* Union types (both as item and target types)
* Inference for generic TypedDicts
* New strictness flag

More than half of the tests I took from the original PR #13353
  • Loading branch information
ilevkivskyi committed Jun 26, 2023
1 parent 7ce3568 commit 8290bb8
Show file tree
Hide file tree
Showing 11 changed files with 659 additions and 69 deletions.
28 changes: 28 additions & 0 deletions docs/source/command_line.rst
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,34 @@ of the above sections.
assert text is not None # OK, check against None is allowed as a special case.
.. option:: --extra-checks

This flag enables additional checks that are technically correct but may be
impractical in real code. In particular, it prohibits partial overlap in
``TypedDict`` updates, and makes arguments prepended via ``Concatenate``
positional-only. For example:

.. code-block:: python
from typing import TypedDict
class Foo(TypedDict):
a: int
class Bar(TypedDict):
a: int
b: int
def test(foo: Foo, bar: Bar) -> None:
# This is technically unsafe since foo can have a subtype of Foo at
# runtime, where type of key "b" is incompatible with int, see below
bar.update(foo)
class Bad(Foo):
b: str
bad: Bad = {"a": 0, "b": "no"}
test(bad, bar)
.. option:: --strict

This flag mode enables all optional error checking flags. You can see the
Expand Down
255 changes: 194 additions & 61 deletions mypy/checkexpr.py

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,10 +826,12 @@ def add_invertible_flag(
)

add_invertible_flag(
"--strict-concatenate",
"--extra-checks",
default=False,
strict_flag=True,
help="Make arguments prepended via Concatenate be truly positional-only",
help="Enable additional checks that are technically correct but may be impractical "
"in real code. For example, this prohibits partial overlap in TypedDict updates, "
"and makes arguments prepended via Concatenate positional-only",
group=strictness_group,
)

Expand Down Expand Up @@ -1155,6 +1157,8 @@ def add_invertible_flag(
parser.add_argument(
"--disable-memoryview-promotion", action="store_true", help=argparse.SUPPRESS
)
# This flag is deprecated, it has been moved to --extra-checks
parser.add_argument("--strict-concatenate", action="store_true", help=argparse.SUPPRESS)

# options specifying code to check
code_group = parser.add_argument_group(
Expand Down Expand Up @@ -1226,8 +1230,11 @@ def add_invertible_flag(
parser.error(f"Cannot find config file '{config_file}'")

options = Options()
strict_option_set = False

def set_strict_flags() -> None:
nonlocal strict_option_set
strict_option_set = True
for dest, value in strict_flag_assignments:
setattr(options, dest, value)

Expand Down Expand Up @@ -1379,6 +1386,8 @@ def set_strict_flags() -> None:
"Warning: --enable-recursive-aliases is deprecated;"
" recursive types are enabled by default"
)
if options.strict_concatenate and not strict_option_set:
print("Warning: --strict-concatenate is deprecated; use --extra-checks instead")

# Set target.
if special_opts.modules + special_opts.packages:
Expand Down
18 changes: 18 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,6 +1757,24 @@ def need_annotation_for_var(
def explicit_any(self, ctx: Context) -> None:
self.fail('Explicit "Any" is not allowed', ctx)

def unsupported_target_for_star_typeddict(self, typ: Type, ctx: Context) -> None:
self.fail(
"Unsupported type {} for ** expansion in TypedDict".format(
format_type(typ, self.options)
),
ctx,
code=codes.TYPEDDICT_ITEM,
)

def non_required_keys_absent_with_star(self, keys: list[str], ctx: Context) -> None:
self.fail(
"Non-required {} not explicitly found in any ** item".format(
format_key_list(keys, short=True)
),
ctx,
code=codes.TYPEDDICT_ITEM,
)

def unexpected_typeddict_keys(
self,
typ: TypedDictType,
Expand Down
6 changes: 5 additions & 1 deletion mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class BuildType:
"disallow_untyped_defs",
"enable_error_code",
"enabled_error_codes",
"extra_checks",
"follow_imports_for_stubs",
"follow_imports",
"ignore_errors",
Expand Down Expand Up @@ -200,9 +201,12 @@ def __init__(self) -> None:
# This makes 1 == '1', 1 in ['1'], and 1 is '1' errors.
self.strict_equality = False

# Make arguments prepended via Concatenate be truly positional-only.
# Deprecated, use extra_checks instead.
self.strict_concatenate = False

# Enable additional checks that are technically correct but impractical.
self.extra_checks = False

# Report an error for any branches inferred to be unreachable as a result of
# type analysis.
self.warn_unreachable = False
Expand Down
27 changes: 27 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
TypedDictType,
TypeOfAny,
TypeVarType,
UnionType,
get_proper_type,
get_proper_types,
)


Expand Down Expand Up @@ -404,6 +406,31 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
assert isinstance(arg_type, TypedDictType)
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
if ctx.args and ctx.args[0]:
with ctx.api.msg.filter_errors():
inferred = get_proper_type(
ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
)
possible_tds = []
if isinstance(inferred, TypedDictType):
possible_tds = [inferred]
elif isinstance(inferred, UnionType):
possible_tds = [
t
for t in get_proper_types(inferred.relevant_items())
if isinstance(t, TypedDictType)
]
items = []
for td in possible_tds:
item = arg_type.copy_modified(
required_keys=(arg_type.required_keys | td.required_keys)
& arg_type.items.keys()
)
if not ctx.api.options.extra_checks:
item = item.copy_modified(item_names=list(td.items))
items.append(item)
if items:
arg_type = make_simplified_union(items)
return signature.copy_modified(arg_types=[arg_type])
return signature

Expand Down
4 changes: 2 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5084,14 +5084,14 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None:
For other variants of dict(...), return None.
"""
if not all(kind == ARG_NAMED for kind in call.arg_kinds):
if not all(kind in (ARG_NAMED, ARG_STAR2) for kind in call.arg_kinds):
# Must still accept those args.
for a in call.args:
a.accept(self)
return None
expr = DictExpr(
[
(StrExpr(cast(str, key)), value) # since they are all ARG_NAMED
(StrExpr(key) if key is not None else None, value)
for key, value in zip(call.arg_names, call.args)
]
)
Expand Down
10 changes: 8 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,9 @@ def visit_callable_type(self, left: CallableType) -> bool:
right,
is_compat=self._is_subtype,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
strict_concatenate=self.options.strict_concatenate if self.options else True,
strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate)
if self.options
else True,
)
elif isinstance(right, Overloaded):
return all(self._is_subtype(left, item) for item in right.items)
Expand Down Expand Up @@ -858,7 +860,11 @@ def visit_overloaded(self, left: Overloaded) -> bool:
else:
# If this one overlaps with the supertype in any way, but it wasn't
# an exact match, then it's a potential error.
strict_concat = self.options.strict_concatenate if self.options else True
strict_concat = (
(self.options.extra_checks or self.options.strict_concatenate)
if self.options
else True
)
if left_index not in matched_overloads and (
is_callable_compatible(
left_item,
Expand Down
4 changes: 4 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,6 +2437,7 @@ def copy_modified(
*,
fallback: Instance | None = None,
item_types: list[Type] | None = None,
item_names: list[str] | None = None,
required_keys: set[str] | None = None,
) -> TypedDictType:
if fallback is None:
Expand All @@ -2447,6 +2448,9 @@ def copy_modified(
items = dict(zip(self.items, item_types))
if required_keys is None:
required_keys = self.required_keys
if item_names is not None:
items = {k: v for (k, v) in items.items() if k in item_names}
required_keys &= set(item_names)
return TypedDictType(items, required_keys, fallback, self.line, self.column)

def create_anonymous_fallback(self) -> Instance:
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ reveal_type(f(n)) # N: Revealed type is "def (builtins.int, builtins.bytes) ->
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateNamedArgs]
# flags: --python-version 3.8 --strict-concatenate
# flags: --python-version 3.8 --extra-checks
# this is one noticeable deviation from PEP but I believe it is for the better
from typing_extensions import ParamSpec, Concatenate
from typing import Callable, TypeVar
Expand Down
Loading

0 comments on commit 8290bb8

Please sign in to comment.