From fde74f7b2c229d8d72df361ba7152946b13f1f85 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sun, 16 Dec 2018 20:28:24 -0800 Subject: [PATCH 1/2] Add intelligent indexing of tuples, NamedTuples, and TypedDict This pull request adds a preliminary implementation of intelligent indexing of tuples, NamedTuples, and TypedDicts. It uses the first approach we discussed earlier: modifying the existing plugins and special-casing code to also check if the expression has a Literal[...] type. Once I'm finished with the baseline literal types implementation, I'll look into circling back and seeing how viable the second approach is (writing some sort of plugin that replaces the signatures of methods like `.__getitem__` or `.get()` with overloads that use the appropriate literal types). --- mypy/checkexpr.py | 16 +++- mypy/plugin.py | 2 +- mypy/plugins/common.py | 18 +++- mypy/plugins/default.py | 101 +++++++++++----------- test-data/unit/check-literal.test | 135 ++++++++++++++++++++++++++++++ 5 files changed, 216 insertions(+), 56 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 88968f9735bb..9a38f1634aad 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2421,13 +2421,21 @@ def _get_value(self, index: Expression) -> Optional[int]: operand = index.expr if isinstance(operand, IntExpr): return -1 * operand.value + typ = self.accept(index) + if isinstance(typ, LiteralType) and isinstance(typ.value, int): + return typ.value return None def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: - if not isinstance(index, (StrExpr, UnicodeExpr)): - self.msg.typeddict_key_must_be_string_literal(td_type, index) - return AnyType(TypeOfAny.from_error) - item_name = index.value + if isinstance(index, (StrExpr, UnicodeExpr)): + item_name = index.value + else: + typ = self.accept(index) + if isinstance(typ, LiteralType) and isinstance(typ.value, str): + item_name = typ.value + else: + self.msg.typeddict_key_must_be_string_literal(td_type, index) + return AnyType(TypeOfAny.from_error) item_type = td_type.items.get(item_name) if item_type is None: diff --git a/mypy/plugin.py b/mypy/plugin.py index 7238dd132877..7e5ae4e86fae 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -134,7 +134,7 @@ class CheckerPluginInterface: @abstractmethod def fail(self, msg: str, ctx: Context) -> None: - """Emmit an error message at given location.""" + """Emit an error message at given location.""" raise NotImplementedError @abstractmethod diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index fab836fcf711..c1dcd6b4ca2e 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -2,11 +2,11 @@ from mypy.nodes import ( ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase, - FuncDef, PassStmt, RefExpr, SymbolTableNode, Var + FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, StrExpr, ) from mypy.plugin import ClassDefContext from mypy.semanal import set_callable_name -from mypy.types import CallableType, Overloaded, Type, TypeVarDef +from mypy.types import CallableType, Overloaded, Type, TypeVarDef, LiteralType from mypy.typevars import fill_typevars @@ -112,3 +112,17 @@ def add_method( info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True) info.defn.defs.body.append(func) + + +def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]: + """If this expression is a string literal, or if the corresponding type + is something like 'Literal["some string here"]', returns the underlying + string value. Otherwise, returns None.""" + if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str': + val = typ.value + assert isinstance(val, str) + return val + elif isinstance(expr, StrExpr): + return expr.value + else: + return None diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 544f78c3f812..a4b2eb745c52 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -6,6 +6,7 @@ from mypy.plugin import ( Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext ) +from mypy.plugins.common import try_getting_str_literal from mypy.types import ( Type, Instance, AnyType, TypeOfAny, CallableType, NoneTyp, UnionType, TypedDictType, TypeVarType @@ -170,24 +171,26 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) >= 1 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - value_type = ctx.type.items.get(key) - if value_type: - if len(ctx.arg_types) == 1: - return UnionType.make_simplified_union([value_type, NoneTyp()]) - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): - default_arg = ctx.args[1][0] - if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 - and isinstance(value_type, TypedDictType)): - # Special case '{}' as the default for a typed dict type. - return value_type.copy_modified(required_keys=set()) - else: - return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) - else: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - return AnyType(TypeOfAny.from_error) + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: + return ctx.default_return_type + + value_type = ctx.type.items.get(key) + if value_type: + if len(ctx.arg_types) == 1: + return UnionType.make_simplified_union([value_type, NoneTyp()]) + elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 + and len(ctx.args[1]) == 1): + default_arg = ctx.args[1][0] + if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 + and isinstance(value_type, TypedDictType)): + # Special case '{}' as the default for a typed dict type. + return value_type.copy_modified(required_keys=set()) + else: + return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) + else: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) + return AnyType(TypeOfAny.from_error) return ctx.default_return_type @@ -225,23 +228,23 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) >= 1 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - if key in ctx.type.required_keys: - ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) - value_type = ctx.type.items.get(key) - if value_type: - if len(ctx.args[1]) == 0: - return value_type - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): - return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) - else: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - return AnyType(TypeOfAny.from_error) - else: + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) return AnyType(TypeOfAny.from_error) + + if key in ctx.type.required_keys: + ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) + value_type = ctx.type.items.get(key) + if value_type: + if len(ctx.args[1]) == 0: + return value_type + elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 + and len(ctx.args[1]) == 1): + return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) + else: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) + return AnyType(TypeOfAny.from_error) return ctx.default_return_type @@ -271,17 +274,17 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - value_type = ctx.type.items.get(key) - if value_type: - return value_type - else: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - return AnyType(TypeOfAny.from_error) - else: + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) return AnyType(TypeOfAny.from_error) + + value_type = ctx.type.items.get(key) + if value_type: + return value_type + else: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) + return AnyType(TypeOfAny.from_error) return ctx.default_return_type @@ -296,15 +299,15 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type: if (isinstance(ctx.type, TypedDictType) and len(ctx.arg_types) == 1 and len(ctx.arg_types[0]) == 1): - if isinstance(ctx.args[0][0], StrExpr): - key = ctx.args[0][0].value - if key in ctx.type.required_keys: - ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) - elif key not in ctx.type.items: - ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) - else: + key = try_getting_str_literal(ctx.args[0][0], ctx.arg_types[0][0]) + if key is None: ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) return AnyType(TypeOfAny.from_error) + + if key in ctx.type.required_keys: + ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context) + elif key not in ctx.type.items: + ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context) return ctx.default_return_type diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index fb3941c10b47..3d51214d7349 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2080,3 +2080,138 @@ def func(x: Literal[1], y: Literal[2]) -> None: pass reveal_type(unify(func)) # E: Revealed type is '' [builtins fixtures/list.pyi] [out] + + +-- +-- Checks for intelligent indexing +-- + +[case testLiteralIntelligentIndexingTuples] +from typing import Tuple, NamedTuple +from typing_extensions import Literal + +class A: pass +class B: pass +class C: pass +class D: pass +class E: pass + +idx0: Literal[0] +idx1: Literal[1] +idx2: Literal[2] +idx3: Literal[3] +idx4: Literal[4] +idx5: Literal[5] +idx_neg1: Literal[-1] + +tup1: Tuple[A, B, C, D, E] +reveal_type(tup1[idx0]) # E: Revealed type is '__main__.A' +reveal_type(tup1[idx1]) # E: Revealed type is '__main__.B' +reveal_type(tup1[idx2]) # E: Revealed type is '__main__.C' +reveal_type(tup1[idx3]) # E: Revealed type is '__main__.D' +reveal_type(tup1[idx4]) # E: Revealed type is '__main__.E' +reveal_type(tup1[idx_neg1]) # E: Revealed type is '__main__.E' +tup1[idx5] # E: Tuple index out of range +reveal_type(tup1[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D]' +reveal_type(tup1[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E]' + +Tup2Class = NamedTuple('Tup2Class', [('a', A), ('b', B), ('c', C), ('d', D), ('e', E)]) +tup2: Tup2Class +reveal_type(tup2[idx0]) # E: Revealed type is '__main__.A' +reveal_type(tup2[idx1]) # E: Revealed type is '__main__.B' +reveal_type(tup2[idx2]) # E: Revealed type is '__main__.C' +reveal_type(tup2[idx3]) # E: Revealed type is '__main__.D' +reveal_type(tup2[idx4]) # E: Revealed type is '__main__.E' +reveal_type(tup2[idx_neg1]) # E: Revealed type is '__main__.E' +tup2[idx5] # E: Tuple index out of range +reveal_type(tup2[idx2:idx4]) # E: Revealed type is 'Tuple[__main__.C, __main__.D, fallback=__main__.Tup2Class]' +reveal_type(tup2[::idx2]) # E: Revealed type is 'Tuple[__main__.A, __main__.C, __main__.E, fallback=__main__.Tup2Class]' +[builtins fixtures/slice.pyi] +[out] + +[case testLiteralIntelligentIndexingTypedDict] +from typing_extensions import Literal +from mypy_extensions import TypedDict + +class Unrelated: pass +u: Unrelated + +class Inner(TypedDict): + a: int +class Outer(Inner, total=False): + b: str + +a_key: Literal["a"] +b_key: Literal["b"] +c_key: Literal["c"] + +d: Outer + +reveal_type(d[a_key]) # E: Revealed type is 'builtins.int' +reveal_type(d[b_key]) # E: Revealed type is 'builtins.str' +d[c_key] # E: TypedDict "Outer" has no key 'c' + +reveal_type(d.get(a_key, u)) # E: Revealed type is 'Union[builtins.int, __main__.Unrelated]' +reveal_type(d.get(b_key, u)) # E: Revealed type is 'Union[builtins.str, __main__.Unrelated]' +d.get(c_key, u) # E: TypedDict "Outer" has no key 'c' + +reveal_type(d.pop(a_key)) # E: Revealed type is 'builtins.int' \ + # E: Key 'a' of TypedDict "Outer" cannot be deleted +reveal_type(d.pop(b_key)) # E: Revealed type is 'builtins.str' +d.pop(c_key) # E: TypedDict "Outer" has no key 'c' + +del d[a_key] # E: Key 'a' of TypedDict "Outer" cannot be deleted +del d[b_key] +del d[c_key] # E: TypedDict "Outer" has no key 'c' +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] +[out] + +[case testLiteralIntelligentIndexingTypedDictPython2-skip] +# flags: --python-version 2.7 +from normal_mod import NormalDict +from unicode_mod import UnicodeDict + +from typing_extensions import Literal + +normal_dict = NormalDict(key=4) +unicode_dict = UnicodeDict(key=4) + +normal_key = "key" # type: Literal["key"] +unicode_key = u"key" # type: Literal[u"key"] + +# TODO: Make the runtime and mypy behaviors here consistent +# +# At runtime, all eight of the below operations will successfully return +# the int because b"key" == u"key" in Python 2. +# +# Mypy, in contrast, will accept all the four calls to `some_dict[...]` +# but will reject `normal_dict.get(unicode_key)` and `unicode_dict.get(unicode_key)` +# because the signature of `.get(...)` accepts only a str, not unicode. +# +# I don't think this has anything to do with Literal types: we had the same +# +# +# Tracking issue at https://github.com/python/mypy/issues/6123 +reveal_type(normal_dict[normal_key]) # E: Revealed type is 'builtins.int' +reveal_type(normal_dict[unicode_key]) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict[normal_key]) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict[unicode_key]) # E: Revealed type is 'builtins.int' + +reveal_type(normal_dict.get(normal_key)) # E: Revealed type is 'builtins.int' +reveal_type(normal_dict.get(unicode_key)) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict.get(normal_key)) # E: Revealed type is 'builtins.int' +reveal_type(unicode_dict.get(unicode_key)) # E: Revealed type is 'builtins.int' + +[file normal_mod.py] +from mypy_extensions import TypedDict +NormalDict = TypedDict('NormalDict', {'key': int}) + +[file unicode_mod.py] +from __future__ import unicode_literals +from mypy_extensions import TypedDict +UnicodeDict = TypedDict(b'UnicodeDict', {'key': int}) + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] +[out] From a30430053645de6fb2b8c42d24b4074a4f868f88 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Thu, 3 Jan 2019 21:17:17 -0800 Subject: [PATCH 2/2] Fix trailing sentence --- test-data/unit/check-literal.test | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3d51214d7349..c7b7869f2bb3 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2189,10 +2189,10 @@ unicode_key = u"key" # type: Literal[u"key"] # but will reject `normal_dict.get(unicode_key)` and `unicode_dict.get(unicode_key)` # because the signature of `.get(...)` accepts only a str, not unicode. # -# I don't think this has anything to do with Literal types: we had the same +# We get the same behavior if we replace all of the Literal[...] types for +# actual string literals. # -# -# Tracking issue at https://github.com/python/mypy/issues/6123 +# See https://github.com/python/mypy/issues/6123 for more details. reveal_type(normal_dict[normal_key]) # E: Revealed type is 'builtins.int' reveal_type(normal_dict[unicode_key]) # E: Revealed type is 'builtins.int' reveal_type(unicode_dict[normal_key]) # E: Revealed type is 'builtins.int'