diff --git a/docs/changelog.md b/docs/changelog.md index 971d7bdb..4bf6dab1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Add plugin providing a precise type for `dict.get` calls (#460) - Fix internal error when an `__eq__` method throws (#461) - Fix handling of `async def` methods in stubs (#459) - Treat Thrift enums as compatible with protocols that diff --git a/pyanalyze/implementation.py b/pyanalyze/implementation.py index f37b3413..003575a3 100644 --- a/pyanalyze/implementation.py +++ b/pyanalyze/implementation.py @@ -524,6 +524,81 @@ def inner(key: Value) -> Value: return flatten_unions(inner, ctx.vars["k"]) +def _dict_get_impl(ctx: CallContext) -> ImplReturn: + default = ctx.vars["default"] + + def inner(key: Value) -> Value: + self_value = ctx.vars["self"] + if isinstance(self_value, AnnotatedValue): + self_value = self_value.value + if isinstance(key, KnownValue): + try: + hash(key.val) + except Exception: + ctx.show_error( + f"Dictionary key {key} is not hashable", + ErrorCode.unhashable_key, + arg="k", + ) + return AnyValue(AnySource.error) + if isinstance(self_value, KnownValue): + if isinstance(key, KnownValue): + try: + return_value = self_value.val[key.val] + except Exception: + return default + else: + return KnownValue(return_value) | default + # else just treat it together with DictIncompleteValue + self_value = replace_known_sequence_value(self_value) + if isinstance(self_value, TypedDictValue): + if not TypedValue(str).is_assignable(key, ctx.visitor): + ctx.show_error( + f"TypedDict key must be str, not {key}", + ErrorCode.invalid_typeddict_key, + arg="k", + ) + return AnyValue(AnySource.error) + elif isinstance(key, KnownValue): + try: + required, value = self_value.items[key.val] + # probably KeyError, but catch anything in case it's an + # unhashable str subclass or something + except Exception: + # No error here; TypedDicts may have additional keys at runtime. + pass + else: + if required: + return value + else: + return value | default + # TODO strictly we should throw an error for any non-Literal or unknown key: + # https://www.python.org/dev/peps/pep-0589/#supported-and-unsupported-operations + # Don't do that yet because it may cause too much disruption. + return default + elif isinstance(self_value, DictIncompleteValue): + val = self_value.get_value(key, ctx.visitor) + if val is UNINITIALIZED_VALUE: + return default + return val | default + elif isinstance(self_value, TypedValue): + key_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 0) + can_assign = key_type.can_assign(key, ctx.visitor) + if isinstance(can_assign, CanAssignError): + ctx.show_error( + f"Dictionary does not accept keys of type {key}", + error_code=ErrorCode.incompatible_argument, + detail=str(can_assign), + arg="key", + ) + value_type = self_value.get_generic_arg_for_type(dict, ctx.visitor, 1) + return value_type | default + else: + return AnyValue(AnySource.inference) + + return flatten_unions(inner, ctx.vars["key"]) + + def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn: key = ctx.vars["key"] default = ctx.vars["default"] @@ -1370,6 +1445,15 @@ def get_default_argspecs() -> Dict[object, Signature]: callable=dict.__getitem__, impl=_dict_getitem_impl, ), + Signature.make( + [ + SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)), + SigParameter("key", _POS_ONLY), + SigParameter("default", _POS_ONLY, default=KnownValue(None)), + ], + callable=dict.get, + impl=_dict_get_impl, + ), Signature.make( [ SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)), diff --git a/pyanalyze/test_implementation.py b/pyanalyze/test_implementation.py index e2b5aab1..dca9c5fb 100644 --- a/pyanalyze/test_implementation.py +++ b/pyanalyze/test_implementation.py @@ -663,6 +663,37 @@ def capybara(cond): ), ) + @assert_passes() + def test_dict_get(self): + from typing_extensions import TypedDict, NotRequired + from typing import Dict + + class TD(TypedDict): + a: int + b: str + c: NotRequired[str] + + def capybara(td: TD, s: str, d: Dict[str, int]): + assert_is_value(td.get("a"), TypedValue(int)) + assert_is_value(td.get("c"), TypedValue(str) | KnownValue(None)) + assert_is_value(td.get("c", 1), TypedValue(str) | KnownValue(1)) + td.get(1) # E: invalid_typeddict_key + + known = {"a": "b"} + assert_is_value(known.get("a"), KnownValue("b") | KnownValue(None)) + assert_is_value(known.get("b", 1), KnownValue(1)) + assert_is_value(known.get(s), KnownValue("b") | KnownValue(None)) + + incomplete = {**td, "b": 1, "d": s} + assert_is_value(incomplete.get("a"), TypedValue(int) | KnownValue(None)) + assert_is_value(incomplete.get("b"), KnownValue(1) | KnownValue(None)) + assert_is_value(incomplete.get("d"), TypedValue(str) | KnownValue(None)) + assert_is_value(incomplete.get("e"), KnownValue(None)) + + assert_is_value(d.get("x"), TypedValue(int) | KnownValue(None)) + assert_is_value(d.get(s), TypedValue(int) | KnownValue(None)) + d.get(1) # E: incompatible_argument + @assert_passes() def test_setdefault(self): from typing_extensions import TypedDict