Skip to content

Commit

Permalink
Make isinstance/issubclass generate ad-hoc intersections
Browse files Browse the repository at this point in the history
This diff makes `isinstance(...)` and `issubclass(...)` try
generating ad-hoc intersections of Instances when possible.

For example, we previously concluded the if-branch is unreachable
in the following program. This PR makes mypy infer an ad-hoc
intersection instead.

    class A: pass
    class B: pass

    x: A
    if isinstance(x, B):
        reveal_type(x)  # N: Revealed type is 'test.<subclass of "A" and "B">'

If you try doing an `isinstance(...)` that legitimately is impossible
due to conflicting method signatures or MRO issues, we continue to
declare the branch unreachable. Passing in the `--warn-unreachable`
flag will now also report an error about this:

    # flags: --warn-unreachable
    x: str

    # E: Subclass of "str" and "bytes" cannot exist: would have
    #    incompatible method signatures
    if isinstance(x, bytes):
        reveal_type(x)  # E: Statement is unreachable

This error message has the same limitations as the other
`--warn-unreachable` ones: we suppress them if the isinstance check
is inside a function using TypeVars with multiple values.

However, we *do* end up always inferring an intersection type when
possible -- that logic is never suppressed.

I initially thought we might have to suppress the new logic as well
(see #3603 (comment)),
but it turns out this is a non-issue in practice once you add in
the check that disallows impossible intersections.

For example, when I tried running this PR on the larger of our two
internal codebases, I found about 25 distinct errors, all of which
were legitimate and unrelated to the problem discussed in the PR.

(And if we don't suppress the extra error message, we get about
100-120 errors, mostly due to tests repeatdly doing `result = blah()`
followed by `assert isinstance(result, X)` where X keeps changing.)
  • Loading branch information
Michael0x2a committed Jan 20, 2020
1 parent 861f01c commit 658150d
Show file tree
Hide file tree
Showing 9 changed files with 500 additions and 40 deletions.
151 changes: 129 additions & 22 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
get_proper_types, is_literal_type, TypeAliasType)
from mypy.sametypes import is_same_type
from mypy.messages import (
MessageBuilder, make_inferred_type_note, append_invariance_notes,
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
format_type, format_type_bare, format_type_distinctly, SUGGESTED_TEST_FIXTURES
)
import mypy.checkexpr
Expand All @@ -63,7 +63,7 @@
from mypy.maptype import map_instance_to_supertype
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
from mypy.semanal import set_callable_name, refers_to_fullname
from mypy.mro import calculate_mro
from mypy.mro import calculate_mro, MroError
from mypy.erasetype import erase_typevars, remove_instance_last_known_values, erase_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.visitor import NodeVisitor
Expand Down Expand Up @@ -1963,13 +1963,15 @@ def visit_block(self, b: Block) -> None:
return
for s in b.body:
if self.binder.is_unreachable():
if (self.options.warn_unreachable
and not self.binder.is_unreachable_warning_suppressed()
and not self.is_raising_or_empty(s)):
if self.should_report_unreachable_issues() and not self.is_raising_or_empty(s):
self.msg.unreachable_statement(s)
break
self.accept(s)

def should_report_unreachable_issues(self) -> bool:
return (self.options.warn_unreachable
and not self.binder.is_unreachable_warning_suppressed())

def is_raising_or_empty(self, s: Statement) -> bool:
"""Returns 'true' if the given statement either throws an error of some kind
or is a no-op.
Expand Down Expand Up @@ -3636,6 +3638,78 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
self.binder.handle_continue()
return None

def make_fake_typeinfo(self,
curr_module_fullname: str,
class_gen_name: str,
class_short_name: str,
bases: List[Instance],
) -> Tuple[ClassDef, TypeInfo]:
# Build the fake ClassDef and TypeInfo together.
# The ClassDef is full of lies and doesn't actually contain a body.
# Use format_bare to generate a nice name for error messages.
# We skip fully filling out a handful of TypeInfo fields because they
# should be irrelevant for a generated type like this:
# is_protocol, protocol_members, is_abstract
cdef = ClassDef(class_short_name, Block([]))
cdef.fullname = curr_module_fullname + '.' + class_gen_name
info = TypeInfo(SymbolTable(), cdef, curr_module_fullname)
cdef.info = info
info.bases = bases
calculate_mro(info)
info.calculate_metaclass_type()
return cdef, info

def intersect_instances(self,
instances: Sequence[Instance],
ctx: Context,
) -> Optional[Instance]:
curr_module = self.scope.stack[0]
assert isinstance(curr_module, MypyFile)

base_classes = []
formatted_names = []
for inst in instances:
expanded = [inst]
if inst.type.is_intersection:
expanded = inst.type.bases

for expanded_inst in expanded:
base_classes.append(expanded_inst)
formatted_names.append(format_type_bare(expanded_inst))

pretty_names_list = pretty_seq(format_type_distinctly(*base_classes, bare=True), "and")
short_name = '<subclass of {}>'.format(pretty_names_list)
full_name = gen_unique_name(short_name, curr_module.names)

old_msg = self.msg
new_msg = self.msg.clean_copy()
self.msg = new_msg
try:
cdef, info = self.make_fake_typeinfo(
curr_module.fullname,
full_name,
short_name,
base_classes,
)
self.check_multiple_inheritance(info)
info.is_intersection = True
except MroError:
if self.should_report_unreachable_issues():
old_msg.impossible_intersection(
pretty_names_list, "inconsistent method resolution order", ctx)
return None
finally:
self.msg = old_msg

if new_msg.is_errors():
if self.should_report_unreachable_issues():
self.msg.impossible_intersection(
pretty_names_list, "incompatible method signatures", ctx)
return None

curr_module.names[full_name] = SymbolTableNode(GDEF, info)
return Instance(info, [])

def intersect_instance_callable(self, typ: Instance, callable_type: CallableType) -> Instance:
"""Creates a fake type that represents the intersection of an Instance and a CallableType.
Expand All @@ -3650,20 +3724,9 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType
gen_name = gen_unique_name("<callable subtype of {}>".format(typ.type.name),
cur_module.names)

# Build the fake ClassDef and TypeInfo together.
# The ClassDef is full of lies and doesn't actually contain a body.
# Use format_bare to generate a nice name for error messages.
# We skip fully filling out a handful of TypeInfo fields because they
# should be irrelevant for a generated type like this:
# is_protocol, protocol_members, is_abstract
# Synthesize a fake TypeInfo
short_name = format_type_bare(typ)
cdef = ClassDef(short_name, Block([]))
cdef.fullname = cur_module.fullname + '.' + gen_name
info = TypeInfo(SymbolTable(), cdef, cur_module.fullname)
cdef.info = info
info.bases = [typ]
calculate_mro(info)
info.calculate_metaclass_type()
cdef, info = self.make_fake_typeinfo(cur_module.fullname, gen_name, short_name, [typ])

# Build up a fake FuncDef so we can populate the symbol table.
func_def = FuncDef('__call__', [], Block([]), callable_type)
Expand Down Expand Up @@ -3828,9 +3891,11 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
return {}, {}
expr = node.args[0]
if literal(expr) == LITERAL_TYPE:
vartype = type_map[expr]
type = get_isinstance_type(node.args[1], type_map)
return conditional_type_map(expr, vartype, type)
return self.conditional_type_map_with_intersection(
expr,
type_map[expr],
get_isinstance_type(node.args[1], type_map),
)
elif refers_to_fullname(node.callee, 'builtins.issubclass'):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
Expand Down Expand Up @@ -4309,6 +4374,10 @@ def refine_identity_comparison_expression(self,

if enum_name is not None:
expr_type = try_expanding_enum_to_union(expr_type, enum_name)

# We intentionally use 'conditional_type_map' directly here instead of
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
# intersections when working with pure instances.
partial_type_maps.append(conditional_type_map(expr, expr_type, target_type))

return reduce_conditional_maps(partial_type_maps)
Expand Down Expand Up @@ -4726,10 +4795,48 @@ def infer_issubclass_maps(self, node: CallExpr,
# Any other object whose type we don't know precisely
# for example, Any or a custom metaclass.
return {}, {} # unknown type
yes_map, no_map = conditional_type_map(expr, vartype, type)
yes_map, no_map = self.conditional_type_map_with_intersection(expr, vartype, type)
yes_map, no_map = map(convert_to_typetype, (yes_map, no_map))
return yes_map, no_map

def conditional_type_map_with_intersection(self,
expr: Expression,
expr_type: Type,
type_ranges: Optional[List[TypeRange]],
) -> Tuple[TypeMap, TypeMap]:
yes_map, no_map = conditional_type_map(expr, expr_type, type_ranges)

if yes_map is not None or type_ranges is None:
return yes_map, no_map

# If we couldn't infer anything useful, try again by trying to compute an intersection
expr_type = get_proper_type(expr_type)
if isinstance(expr_type, UnionType):
possible_expr_types = get_proper_types(expr_type.relevant_items())
else:
possible_expr_types = [expr_type]

possible_target_types = []
for tr in type_ranges:
item = get_proper_type(tr.item)
if not isinstance(item, Instance) or tr.is_upper_bound:
return yes_map, no_map
possible_target_types.append(item)

out = []
for v in possible_expr_types:
if not isinstance(v, Instance):
return yes_map, no_map
for t in possible_target_types:
intersection = self.intersect_instances([v, t], expr)
if intersection is None:
continue
out.append(intersection)
if len(out) == 0:
return None, {}
new_yes_type = make_simplified_union(out)
return {expr: new_yes_type}, {}


def conditional_type_map(expr: Expression,
current_type: Optional[Type],
Expand Down
23 changes: 18 additions & 5 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,11 @@ def has_no_attr(self,
if matches:
self.fail(
'{} has no attribute "{}"; maybe {}?{}'.format(
format_type(original_type), member, pretty_or(matches), extra),
format_type(original_type),
member,
pretty_seq(matches, "or"),
extra,
),
context,
code=codes.ATTR_DEFINED)
failed = True
Expand Down Expand Up @@ -623,7 +627,7 @@ def unexpected_keyword_argument(self, callee: CallableType, name: str, arg_type:
if not matches:
matches = best_matches(name, not_matching_type_args)
if matches:
msg += "; did you mean {}?".format(pretty_or(matches[:3]))
msg += "; did you mean {}?".format(pretty_seq(matches[:3], "or"))
self.fail(msg, context, code=codes.CALL_ARG)
module = find_defining_module(self.modules, callee)
if module:
Expand Down Expand Up @@ -1263,6 +1267,14 @@ def redundant_condition_in_assert(self, truthiness: bool, context: Context) -> N
def redundant_expr(self, description: str, truthiness: bool, context: Context) -> None:
self.fail("{} is always {}".format(description, str(truthiness).lower()), context)

def impossible_intersection(self,
formatted_base_class_list: str,
reason: str,
context: Context,
) -> None:
template = "Subclass of {} cannot exist: would have {}"
self.fail(template.format(formatted_base_class_list, reason), context)

def report_protocol_problems(self,
subtype: Union[Instance, TupleType, TypedDictType],
supertype: Instance,
Expand Down Expand Up @@ -1995,13 +2007,14 @@ def best_matches(current: str, options: Iterable[str]) -> List[str]:
reverse=True, key=lambda v: (ratios[v], v))


def pretty_or(args: List[str]) -> str:
def pretty_seq(args: Sequence[str], conjunction: str) -> str:
quoted = ['"' + a + '"' for a in args]
if len(quoted) == 1:
return quoted[0]
if len(quoted) == 2:
return "{} or {}".format(quoted[0], quoted[1])
return ", ".join(quoted[:-1]) + ", or " + quoted[-1]
return "{} {} {}".format(quoted[0], conjunction, quoted[1])
last_sep = ", " + conjunction + " "
return ", ".join(quoted[:-1]) + last_sep + quoted[-1]


def append_invariance_notes(notes: List[str], arg_type: Instance,
Expand Down
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2379,13 +2379,17 @@ class is generic then it will be a type constructor of higher kind.
# Is this a newtype type?
is_newtype = False

# Is this a synthesized intersection type?
is_intersection = False

# This is a dictionary that will be serialized and un-serialized as is.
# It is useful for plugins to add their data to save in the cache.
metadata = None # type: Dict[str, JsonDict]

FLAGS = [
'is_abstract', 'is_enum', 'fallback_to_any', 'is_named_tuple',
'is_newtype', 'is_protocol', 'runtime_protocol', 'is_final',
'is_intersection',
] # type: Final[List[str]]

def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> None:
Expand Down
4 changes: 2 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from mypy.typevars import fill_typevars
from mypy.visitor import NodeVisitor
from mypy.errors import Errors, report_internal_error
from mypy.messages import best_matches, MessageBuilder, pretty_or, SUGGESTED_TEST_FIXTURES
from mypy.messages import best_matches, MessageBuilder, pretty_seq, SUGGESTED_TEST_FIXTURES
from mypy.errorcodes import ErrorCode
from mypy import message_registry, errorcodes as codes
from mypy.types import (
Expand Down Expand Up @@ -1802,7 +1802,7 @@ def report_missing_module_attribute(self, import_id: str, source_id: str, import
alternatives = set(module.names.keys()).difference({source_id})
matches = best_matches(source_id, alternatives)[:3]
if matches:
suggestion = "; maybe {}?".format(pretty_or(matches))
suggestion = "; maybe {}?".format(pretty_seq(matches, "or"))
message += "{}".format(suggestion)
self.fail(message, context, code=codes.ATTR_DEFINED)
self.add_unknown_imported_symbol(imported_id, context)
Expand Down

0 comments on commit 658150d

Please sign in to comment.