Skip to content

Commit

Permalink
Add support for jump statements in partially defined vars check (#13632)
Browse files Browse the repository at this point in the history
This builds on #13601 to add support for jump statements such as `continue`,
`break`, `return`, and `raise` in the partially defined variables check. The
simplest example is:
```python
def f1() -> int:
    if int():
        x = 1
    else:
        return 0
    return x
```

Previously, mypy would generate a false positive on the last line of
the example. See test cases for more details.

Adding this support was relatively simple, given all the already
existing code.

Things that aren't supported yet: `match`, `with`, and detecting
unreachable blocks.

After this PR, when enabling this check on mypy itself, it generates 18
errors, all of which are potential bugs.
  • Loading branch information
ilinum committed Sep 12, 2022
1 parent 0f17aff commit 216a45b
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 55 deletions.
10 changes: 6 additions & 4 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,13 +2331,15 @@ def type_check_second_pass(self) -> bool:
self.time_spent_us += time_spent_us(t0)
return result

def detect_partially_defined_vars(self) -> None:
def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
manager = self.manager
if manager.errors.is_error_code_enabled(codes.PARTIALLY_DEFINED):
manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options)
self.tree.accept(
PartiallyDefinedVariableVisitor(MessageBuilder(manager.errors, manager.modules))
PartiallyDefinedVariableVisitor(
MessageBuilder(manager.errors, manager.modules), type_map
)
)

def finish_passes(self) -> None:
Expand Down Expand Up @@ -3368,7 +3370,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
graph[id].type_check_first_pass()
if not graph[id].type_checker().deferred_nodes:
unfinished_modules.discard(id)
graph[id].detect_partially_defined_vars()
graph[id].detect_partially_defined_vars(graph[id].type_map())
graph[id].finish_passes()

while unfinished_modules:
Expand All @@ -3377,7 +3379,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
continue
if not graph[id].type_check_second_pass():
unfinished_modules.discard(id)
graph[id].detect_partially_defined_vars()
graph[id].detect_partially_defined_vars(graph[id].type_map())
graph[id].finish_passes()
for id in stale:
graph[id].generate_unused_ignore_notes()
Expand Down
161 changes: 111 additions & 50 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,129 @@
from __future__ import annotations

from typing import NamedTuple

from mypy import checker
from mypy.messages import MessageBuilder
from mypy.nodes import (
AssertStmt,
AssignmentStmt,
BreakStmt,
ContinueStmt,
Expression,
ExpressionStmt,
ForStmt,
FuncDef,
FuncItem,
GeneratorExpr,
IfStmt,
ListExpr,
Lvalue,
NameExpr,
RaiseStmt,
ReturnStmt,
TupleExpr,
WhileStmt,
)
from mypy.traverser import TraverserVisitor
from mypy.traverser import ExtendedTraverserVisitor
from mypy.types import Type, UninhabitedType


class DefinedVars(NamedTuple):
"""DefinedVars contains information about variable definition at the end of a branching statement.
class BranchState:
"""BranchState contains information about variable definition at the end of a branching statement.
`if` and `match` are examples of branching statements.
`may_be_defined` contains variables that were defined in only some branches.
`must_be_defined` contains variables that were defined in all branches.
"""

may_be_defined: set[str]
must_be_defined: set[str]
def __init__(
self,
must_be_defined: set[str] | None = None,
may_be_defined: set[str] | None = None,
skipped: bool = False,
) -> None:
if may_be_defined is None:
may_be_defined = set()
if must_be_defined is None:
must_be_defined = set()

self.may_be_defined = set(may_be_defined)
self.must_be_defined = set(must_be_defined)
self.skipped = skipped


class BranchStatement:
def __init__(self, already_defined: DefinedVars) -> None:
self.already_defined = already_defined
self.defined_by_branch: list[DefinedVars] = [
DefinedVars(may_be_defined=set(), must_be_defined=set(already_defined.must_be_defined))
def __init__(self, initial_state: BranchState) -> None:
self.initial_state = initial_state
self.branches: list[BranchState] = [
BranchState(must_be_defined=self.initial_state.must_be_defined)
]

def next_branch(self) -> None:
self.defined_by_branch.append(
DefinedVars(
may_be_defined=set(), must_be_defined=set(self.already_defined.must_be_defined)
)
)
self.branches.append(BranchState(must_be_defined=self.initial_state.must_be_defined))

def record_definition(self, name: str) -> None:
assert len(self.defined_by_branch) > 0
self.defined_by_branch[-1].must_be_defined.add(name)
self.defined_by_branch[-1].may_be_defined.discard(name)

def record_nested_branch(self, vars: DefinedVars) -> None:
assert len(self.defined_by_branch) > 0
current_branch = self.defined_by_branch[-1]
current_branch.must_be_defined.update(vars.must_be_defined)
current_branch.may_be_defined.update(vars.may_be_defined)
assert len(self.branches) > 0
self.branches[-1].must_be_defined.add(name)
self.branches[-1].may_be_defined.discard(name)

def record_nested_branch(self, state: BranchState) -> None:
assert len(self.branches) > 0
current_branch = self.branches[-1]
if state.skipped:
current_branch.skipped = True
return
current_branch.must_be_defined.update(state.must_be_defined)
current_branch.may_be_defined.update(state.may_be_defined)
current_branch.may_be_defined.difference_update(current_branch.must_be_defined)

def skip_branch(self) -> None:
assert len(self.branches) > 0
self.branches[-1].skipped = True

def is_possibly_undefined(self, name: str) -> bool:
assert len(self.defined_by_branch) > 0
return name in self.defined_by_branch[-1].may_be_defined
assert len(self.branches) > 0
return name in self.branches[-1].may_be_defined

def done(self) -> DefinedVars:
assert len(self.defined_by_branch) > 0
if len(self.defined_by_branch) == 1:
# If there's only one branch, then we just return current.
# Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`).
return self.defined_by_branch[0]
def done(self) -> BranchState:
branches = [b for b in self.branches if not b.skipped]
if len(branches) == 0:
return BranchState(skipped=True)
if len(branches) == 1:
return branches[0]

# must_be_defined is the intersection of must_be_defined of all (non-skipped) branches.
must_be_defined = set(self.defined_by_branch[0].must_be_defined)
for branch_vars in self.defined_by_branch[1:]:
must_be_defined.intersection_update(branch_vars.must_be_defined)
must_be_defined = set(branches[0].must_be_defined)
for b in branches[1:]:
must_be_defined.intersection_update(b.must_be_defined)
# may_be_defined contains every variable seen in any branch that is not in must_be_defined.
all_vars = set()
for branch_vars in self.defined_by_branch:
all_vars.update(branch_vars.may_be_defined)
all_vars.update(branch_vars.must_be_defined)
for b in branches:
all_vars.update(b.may_be_defined)
all_vars.update(b.must_be_defined)
may_be_defined = all_vars.difference(must_be_defined)
return DefinedVars(may_be_defined=may_be_defined, must_be_defined=must_be_defined)
return BranchState(may_be_defined=may_be_defined, must_be_defined=must_be_defined)


class DefinedVariableTracker:
"""DefinedVariableTracker manages the state and scope for the PartiallyDefinedVariableVisitor."""

def __init__(self) -> None:
# There's always at least one scope. Within each scope, there's at least one "global" BranchStatement.
self.scopes: list[list[BranchStatement]] = [
[BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))]
]
self.scopes: list[list[BranchStatement]] = [[BranchStatement(BranchState())]]

def _scope(self) -> list[BranchStatement]:
assert len(self.scopes) > 0
return self.scopes[-1]

def enter_scope(self) -> None:
assert len(self._scope()) > 0
self.scopes.append([BranchStatement(self._scope()[-1].defined_by_branch[-1])])
self.scopes.append([BranchStatement(self._scope()[-1].branches[-1])])

def exit_scope(self) -> None:
self.scopes.pop()

def start_branch_statement(self) -> None:
assert len(self._scope()) > 0
self._scope().append(BranchStatement(self._scope()[-1].defined_by_branch[-1]))
self._scope().append(BranchStatement(self._scope()[-1].branches[-1]))

def next_branch(self) -> None:
assert len(self._scope()) > 1
Expand All @@ -113,6 +134,11 @@ def end_branch_statement(self) -> None:
result = self._scope().pop().done()
self._scope()[-1].record_nested_branch(result)

def skip_branch(self) -> None:
# Mark the current branch of the innermost branch statement as skipped.
# Only skip branch if we're outside of "root" branch statement.
# Each scope starts with one root BranchStatement, so a length greater than 1
# means at least one nested branch statement is currently open.
if len(self._scope()) > 1:
self._scope()[-1].skip_branch()

def record_declaration(self, name: str) -> None:
assert len(self.scopes) > 0
assert len(self.scopes[-1]) > 0
Expand All @@ -125,7 +151,7 @@ def is_possibly_undefined(self, name: str) -> bool:
return self._scope()[-1].is_possibly_undefined(name)


class PartiallyDefinedVariableVisitor(TraverserVisitor):
class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor):
"""Detect variables that are defined only part of the time.
This visitor detects the following case:
Expand All @@ -137,8 +163,9 @@ class PartiallyDefinedVariableVisitor(TraverserVisitor):
handled by the semantic analyzer.
"""

def __init__(self, msg: MessageBuilder) -> None:
def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> None:
self.msg = msg
self.type_map = type_map
self.tracker = DefinedVariableTracker()

def process_lvalue(self, lvalue: Lvalue) -> None:
Expand Down Expand Up @@ -175,6 +202,13 @@ def visit_func(self, o: FuncItem) -> None:
self.tracker.record_declaration(arg.variable.name)
super().visit_func(o)

def visit_generator_expr(self, o: GeneratorExpr) -> None:
# Handle the generator expression inside its own tracker scope: the scope is
# entered first, each index lvalue is passed to process_lvalue, the expression
# is traversed by the parent visitor, and the scope is exited at the end.
# NOTE(review): indentation is not preserved in this view; presumably only the
# process_lvalue call is inside the loop -- confirm against the real file.
self.tracker.enter_scope()
for idx in o.indices:
self.process_lvalue(idx)
super().visit_generator_expr(o)
self.tracker.exit_scope()

def visit_for_stmt(self, o: ForStmt) -> None:
o.expr.accept(self)
self.process_lvalue(o.index)
Expand All @@ -186,13 +220,40 @@ def visit_for_stmt(self, o: ForStmt) -> None:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_return_stmt(self, o: ReturnStmt) -> None:
# Traverse the return statement with the parent visitor, then mark the current
# branch as skipped. Skipped branches are filtered out when
# BranchStatement.done() merges the branch states.
super().visit_return_stmt(o)
self.tracker.skip_branch()

def visit_assert_stmt(self, o: AssertStmt) -> None:
# Traverse the assert statement with the parent visitor. The current branch is
# marked as skipped only when checker.is_false_literal() accepts the asserted
# expression; any other assert leaves the branch state untouched.
super().visit_assert_stmt(o)
if checker.is_false_literal(o.expr):
self.tracker.skip_branch()

def visit_raise_stmt(self, o: RaiseStmt) -> None:
# Traverse the raise statement with the parent visitor, then mark the current
# branch as skipped so it is left out when branch states are merged.
super().visit_raise_stmt(o)
self.tracker.skip_branch()

def visit_continue_stmt(self, o: ContinueStmt) -> None:
# Traverse the continue statement with the parent visitor, then mark the current
# branch as skipped so it is left out when branch states are merged.
super().visit_continue_stmt(o)
self.tracker.skip_branch()

def visit_break_stmt(self, o: BreakStmt) -> None:
# Traverse the break statement with the parent visitor, then mark the current
# branch as skipped so it is left out when branch states are merged.
super().visit_break_stmt(o)
self.tracker.skip_branch()

def visit_expression_stmt(self, o: ExpressionStmt) -> None:
# Mark the current branch as skipped when the type map records an
# UninhabitedType for the statement's expression; expressions missing from the
# type map yield None from .get() and never match.
# NOTE(review): presumably this detects calls that never return -- confirm.
# NOTE(review): indentation is lost in this view; the parent-visitor call below
# is presumably unconditional (outside the `if`) -- confirm against the real file.
if isinstance(self.type_map.get(o.expr, None), UninhabitedType):
self.tracker.skip_branch()
super().visit_expression_stmt(o)

def visit_while_stmt(self, o: WhileStmt) -> None:
o.expr.accept(self)
self.tracker.start_branch_statement()
o.body.accept(self)
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
if not checker.is_true_literal(o.expr):
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_name_expr(self, o: NameExpr) -> None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/server/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def restore(ids: list[str]) -> None:
state.type_checker().reset()
state.type_check_first_pass()
state.type_check_second_pass()
state.detect_partially_defined_vars()
state.detect_partially_defined_vars(state.type_map())
t2 = time.time()
state.finish_passes()
t3 = time.time()
Expand Down

0 comments on commit 216a45b

Please sign in to comment.