Skip to content

Commit

Permalink
Add inference of Compare nodes (#979)
Browse files Browse the repository at this point in the history
* Add inference for Compare nodes

Ref #846.

Identity checks are currently Uninferable as there is no sensible way to
infer that two Instances refer to the same object without accurately
modelling control flow.

Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com>
  • Loading branch information
nelfin and Pierre-Sassoulas committed Sep 14, 2021
1 parent 24a1118 commit 82d7faf
Show file tree
Hide file tree
Showing 2 changed files with 351 additions and 2 deletions.
94 changes: 94 additions & 0 deletions astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
"""this module contains a set of functions to handle inference on astroid trees
"""

import ast
import functools
import itertools
import operator
from typing import Any, Iterable

import wrapt

Expand Down Expand Up @@ -790,6 +792,98 @@ def infer_binop(self, context=None):
nodes.BinOp._infer_binop = _infer_binop
nodes.BinOp._infer = infer_binop

COMPARE_OPS = {
"==": operator.eq,
"!=": operator.ne,
"<": operator.lt,
"<=": operator.le,
">": operator.gt,
">=": operator.ge,
"in": lambda a, b: a in b,
"not in": lambda a, b: a not in b,
}
UNINFERABLE_OPS = {
"is",
"is not",
}


def _to_literal(node: nodes.NodeNG) -> Any:
# Can raise SyntaxError or ValueError from ast.literal_eval
# Is this the stupidest idea or the simplest idea?
return ast.literal_eval(node.as_string())


def _do_compare(
left_iter: Iterable[nodes.NodeNG], op: str, right_iter: Iterable[nodes.NodeNG]
) -> "bool | type[util.Uninferable]":
"""
If all possible combinations are either True or False, return that:
>>> _do_compare([1, 2], '<=', [3, 4])
True
>>> _do_compare([1, 2], '==', [3, 4])
False
If any item is uninferable, or if some combinations are True and some
are False, return Uninferable:
>>> _do_compare([1, 3], '<=', [2, 4])
util.Uninferable
"""
retval = None
if op in UNINFERABLE_OPS:
return util.Uninferable
op_func = COMPARE_OPS[op]

for left, right in itertools.product(left_iter, right_iter):
if left is util.Uninferable or right is util.Uninferable:
return util.Uninferable

try:
left, right = _to_literal(left), _to_literal(right)
except (SyntaxError, ValueError):
return util.Uninferable

try:
expr = op_func(left, right)
except TypeError as exc:
raise AstroidTypeError from exc

if retval is None:
retval = expr
elif retval != expr:
return util.Uninferable
# (or both, but "True | False" is basically the same)

return retval # it was all the same value


def _infer_compare(self: nodes.Compare, context: InferenceContext) -> Any:
"""Chained comparison inference logic."""
retval = True

ops = self.ops
left_node = self.left
lhs = list(left_node.infer(context=context))
# should we break early if first element is uninferable?
for op, right_node in ops:
# eagerly evaluate rhs so that values can be re-used as lhs
rhs = list(right_node.infer(context=context))
try:
retval = _do_compare(lhs, op, rhs)
except AstroidTypeError:
retval = util.Uninferable
break
if retval is not True:
break # short-circuit
lhs = rhs # continue
if retval is util.Uninferable:
yield retval
else:
yield nodes.Const(retval)


nodes.Compare._infer = _infer_compare


def _infer_augassign(self, context=None):
"""Inference logic for augmented binary operations."""
Expand Down
259 changes: 257 additions & 2 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5280,6 +5280,261 @@ def f(**kwargs):
assert next(extract_node(code).infer()).as_string() == "{'f': 1}"


@pytest.mark.parametrize(
"op,result",
[
("<", False),
("<=", True),
("==", True),
(">=", True),
(">", False),
("!=", False),
],
)
def test_compare(op, result) -> None:
code = """
123 {} 123
""".format(
op
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


@pytest.mark.xfail(reason="uninferable")
@pytest.mark.parametrize(
"op,result",
[
("is", True),
("is not", False),
],
)
def test_compare_identity(op, result) -> None:
code = """
obj = object()
obj {} obj
""".format(
op
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


@pytest.mark.parametrize(
"op,result",
[
("in", True),
("not in", False),
],
)
def test_compare_membership(op, result) -> None:
code = """
1 {} [1, 2, 3]
""".format(
op
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


@pytest.mark.parametrize(
"lhs,rhs,result",
[
(1, 1, True),
(1, 1.1, True),
(1.1, 1, False),
(1.0, 1.0, True),
("abc", "def", True),
("abc", "", False),
([], [1], True),
((1, 2), (2, 3), True),
((1, 0), (1,), False),
(True, True, True),
(True, False, False),
(False, 1, True),
(1 + 0j, 2 + 0j, util.Uninferable),
(+0.0, -0.0, True),
(0, "1", util.Uninferable),
(b"\x00", b"\x01", True),
],
)
def test_compare_lesseq_types(lhs, rhs, result) -> None:
code = """
{lhs!r} <= {rhs!r}
""".format(
lhs=lhs, rhs=rhs
)
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value == result


def test_compare_chained() -> None:
code = """
3 < 5 > 3
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


def test_compare_inferred_members() -> None:
code = """
a = 11
b = 13
a < b
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


def test_compare_instance_members() -> None:
code = """
class A:
value = 123
class B:
@property
def value(self):
return 456
A().value < B().value
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


@pytest.mark.xfail(reason="unimplemented")
def test_compare_dynamic() -> None:
code = """
class A:
def __le__(self, other):
return True
A() <= None
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is True


def test_compare_uninferable_member() -> None:
code = """
from unknown import UNKNOWN
0 <= UNKNOWN
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred is util.Uninferable


def test_compare_chained_comparisons_shortcircuit_on_false() -> None:
code = """
from unknown import UNKNOWN
2 < 1 < UNKNOWN
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred.value is False


def test_compare_chained_comparisons_continue_on_true() -> None:
code = """
from unknown import UNKNOWN
1 < 2 < UNKNOWN
"""
node = extract_node(code)
inferred = next(node.infer())
assert inferred is util.Uninferable


@pytest.mark.xfail(reason="unimplemented")
def test_compare_known_false_branch() -> None:
code = """
a = 'hello'
if 1 < 2:
a = 'goodbye'
a
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
assert isinstance(inferred[0], nodes.Const)
assert inferred[0].value == "hello"


def test_compare_ifexp_constant() -> None:
code = """
a = 'hello' if 1 < 2 else 'goodbye'
a
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
assert isinstance(inferred[0], nodes.Const)
assert inferred[0].value == "hello"


def test_compare_typeerror() -> None:
code = """
123 <= "abc"
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
assert inferred[0] is util.Uninferable


def test_compare_multiple_possibilites() -> None:
code = """
from unknown import UNKNOWN
a = 1
if UNKNOWN:
a = 2
b = 3
if UNKNOWN:
b = 4
a < b
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
# All possible combinations are true: (1 < 3), (1 < 4), (2 < 3), (2 < 4)
assert inferred[0].value is True


def test_compare_ambiguous_multiple_possibilites() -> None:
code = """
from unknown import UNKNOWN
a = 1
if UNKNOWN:
a = 3
b = 2
if UNKNOWN:
b = 4
a < b
"""
node = extract_node(code)
inferred = list(node.infer())
assert len(inferred) == 1
# Not all possible combinations are true: (1 < 2), (1 < 4), (3 !< 2), (3 < 4)
assert inferred[0] is util.Uninferable


def test_compare_nonliteral() -> None:
code = """
def func(a, b):
return (a, b) <= (1, 2) #@
"""
return_node = extract_node(code)
node = return_node.value
inferred = list(node.infer()) # should not raise ValueError
assert len(inferred) == 1
assert inferred[0] is util.Uninferable


def test_limit_inference_result_amount() -> None:
"""Test setting limit inference result amount"""
code = """
Expand Down Expand Up @@ -5560,7 +5815,7 @@ def method(self):
""",
],
)
def test_subclass_of_exception(code):
def test_subclass_of_exception(code) -> None:
inferred = next(extract_node(code).infer())
assert isinstance(inferred, Instance)
args = next(inferred.igetattr("args"))
Expand Down Expand Up @@ -5721,7 +5976,7 @@ def test(self):
),
],
)
def test_inference_is_limited_to_the_boundnode(code, instance_name):
def test_inference_is_limited_to_the_boundnode(code, instance_name) -> None:
node = extract_node(code)
inferred = next(node.infer())
assert isinstance(inferred, Instance)
Expand Down

0 comments on commit 82d7faf

Please sign in to comment.