diff --git a/astroid/inference.py b/astroid/inference.py index fd2735ebe..02be060d3 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -27,9 +27,11 @@ """this module contains a set of functions to handle inference on astroid trees """ +import ast import functools import itertools import operator +from typing import Any, Iterable import wrapt @@ -790,6 +792,98 @@ def infer_binop(self, context=None): nodes.BinOp._infer_binop = _infer_binop nodes.BinOp._infer = infer_binop +COMPARE_OPS = { + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + "in": lambda a, b: a in b, + "not in": lambda a, b: a not in b, +} +UNINFERABLE_OPS = { + "is", + "is not", +} + + +def _to_literal(node: nodes.NodeNG) -> Any: + # Can raise SyntaxError or ValueError from ast.literal_eval + # Is this the stupidest idea or the simplest idea? + return ast.literal_eval(node.as_string()) + + +def _do_compare( + left_iter: Iterable[nodes.NodeNG], op: str, right_iter: Iterable[nodes.NodeNG] +) -> "bool | type[util.Uninferable]": + """ + If all possible combinations are either True or False, return that: + >>> _do_compare([1, 2], '<=', [3, 4]) + True + >>> _do_compare([1, 2], '==', [3, 4]) + False + + If any item is uninferable, or if some combinations are True and some + are False, return Uninferable: + >>> _do_compare([1, 3], '<=', [2, 4]) + util.Uninferable + """ + retval = None + if op in UNINFERABLE_OPS: + return util.Uninferable + op_func = COMPARE_OPS[op] + + for left, right in itertools.product(left_iter, right_iter): + if left is util.Uninferable or right is util.Uninferable: + return util.Uninferable + + try: + left, right = _to_literal(left), _to_literal(right) + except (SyntaxError, ValueError): + return util.Uninferable + + try: + expr = op_func(left, right) + except TypeError as exc: + raise AstroidTypeError from exc + + if retval is None: + retval = expr + elif retval != expr: + return util.Uninferable + # (or both, but "True | False" is basically the same) + + return retval # it was all the same value + + +def _infer_compare(self: nodes.Compare, context: InferenceContext) -> Any: + """Chained comparison inference logic.""" + retval = True + + ops = self.ops + left_node = self.left + lhs = list(left_node.infer(context=context)) + # should we break early if first element is uninferable? + for op, right_node in ops: + # eagerly evaluate rhs so that values can be re-used as lhs + rhs = list(right_node.infer(context=context)) + try: + retval = _do_compare(lhs, op, rhs) + except AstroidTypeError: + retval = util.Uninferable + break + if retval is not True: + break # short-circuit + lhs = rhs # continue + if retval is util.Uninferable: + yield retval + else: + yield nodes.Const(retval) + + +nodes.Compare._infer = _infer_compare + def _infer_augassign(self, context=None): """Inference logic for augmented binary operations.""" diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index e2502539f..f721b32f7 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -5280,6 +5280,261 @@ def f(**kwargs): assert next(extract_node(code).infer()).as_string() == "{'f': 1}" +@pytest.mark.parametrize( + "op,result", + [ + ("<", False), + ("<=", True), + ("==", True), + (">=", True), + (">", False), + ("!=", False), + ], +) +def test_compare(op, result) -> None: + code = """ + 123 {} 123 + """.format( + op + ) + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value == result + + +@pytest.mark.xfail(reason="uninferable") +@pytest.mark.parametrize( + "op,result", + [ + ("is", True), + ("is not", False), + ], +) +def test_compare_identity(op, result) -> None: + code = """ + obj = object() + obj {} obj + """.format( + op + ) + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value == result + + +@pytest.mark.parametrize( + "op,result", + [ + ("in", True), + ("not in", False), + ], +) +def test_compare_membership(op, result) -> None: + code = """ + 1 {} [1, 2, 3] + """.format( + op + ) + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value == result + + +@pytest.mark.parametrize( + "lhs,rhs,result", + [ + (1, 1, True), + (1, 1.1, True), + (1.1, 1, False), + (1.0, 1.0, True), + ("abc", "def", True), + ("abc", "", False), + ([], [1], True), + ((1, 2), (2, 3), True), + ((1, 0), (1,), False), + (True, True, True), + (True, False, False), + (False, 1, True), + (1 + 0j, 2 + 0j, util.Uninferable), + (+0.0, -0.0, True), + (0, "1", util.Uninferable), + (b"\x00", b"\x01", True), + ], +) +def test_compare_lesseq_types(lhs, rhs, result) -> None: + code = """ + {lhs!r} <= {rhs!r} + """.format( + lhs=lhs, rhs=rhs + ) + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value == result + + +def test_compare_chained() -> None: + code = """ + 3 < 5 > 3 + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value is True + + +def test_compare_inferred_members() -> None: + code = """ + a = 11 + b = 13 + a < b + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value is True + + +def test_compare_instance_members() -> None: + code = """ + class A: + value = 123 + class B: + @property + def value(self): + return 456 + A().value < B().value + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value is True + + +@pytest.mark.xfail(reason="unimplemented") +def test_compare_dynamic() -> None: + code = """ + class A: + def __le__(self, other): + return True + A() <= None + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value is True + + +def test_compare_uninferable_member() -> None: + code = """ + from unknown import UNKNOWN + 0 <= UNKNOWN + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred is util.Uninferable + + +def test_compare_chained_comparisons_shortcircuit_on_false() -> None: + code = """ + from unknown import UNKNOWN + 2 < 1 < UNKNOWN + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred.value is False + + +def test_compare_chained_comparisons_continue_on_true() -> None: + code = """ + from unknown import UNKNOWN + 1 < 2 < UNKNOWN + """ + node = extract_node(code) + inferred = next(node.infer()) + assert inferred is util.Uninferable + + +@pytest.mark.xfail(reason="unimplemented") +def test_compare_known_false_branch() -> None: + code = """ + a = 'hello' + if 1 < 2: + a = 'goodbye' + a + """ + node = extract_node(code) + inferred = list(node.infer()) + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == "hello" + + +def test_compare_ifexp_constant() -> None: + code = """ + a = 'hello' if 1 < 2 else 'goodbye' + a + """ + node = extract_node(code) + inferred = list(node.infer()) + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == "hello" + + +def test_compare_typeerror() -> None: + code = """ + 123 <= "abc" + """ + node = extract_node(code) + inferred = list(node.infer()) + assert len(inferred) == 1 + assert inferred[0] is util.Uninferable + + +def test_compare_multiple_possibilites() -> None: + code = """ + from unknown import UNKNOWN + a = 1 + if UNKNOWN: + a = 2 + b = 3 + if UNKNOWN: + b = 4 + a < b + """ + node = extract_node(code) + inferred = list(node.infer()) + assert len(inferred) == 1 + # All possible combinations are true: (1 < 3), (1 < 4), (2 < 3), (2 < 4) + assert inferred[0].value is True + + +def test_compare_ambiguous_multiple_possibilites() -> None: + code = """ + from unknown import UNKNOWN + a = 1 + if UNKNOWN: + a = 3 + b = 2 + if UNKNOWN: + b = 4 + a < b + """ + node = extract_node(code) + inferred = list(node.infer()) + assert len(inferred) == 1 + # Not all possible combinations are true: (1 < 2), (1 < 4), (3 !< 2), (3 < 4) + assert inferred[0] is util.Uninferable + + +def test_compare_nonliteral() -> None: + code = """ + def func(a, b): + return (a, b) <= (1, 2) #@ + """ + return_node = extract_node(code) + node = return_node.value + inferred = list(node.infer()) # should not raise ValueError + assert len(inferred) == 1 + assert inferred[0] is util.Uninferable + + def test_limit_inference_result_amount() -> None: """Test setting limit inference result amount""" code = """ @@ -5560,7 +5815,7 @@ def method(self): """, ], ) -def test_subclass_of_exception(code): +def test_subclass_of_exception(code) -> None: inferred = next(extract_node(code).infer()) assert isinstance(inferred, Instance) args = next(inferred.igetattr("args")) @@ -5721,7 +5976,7 @@ def test(self): ), ], ) -def test_inference_is_limited_to_the_boundnode(code, instance_name): +def test_inference_is_limited_to_the_boundnode(code, instance_name) -> None: node = extract_node(code) inferred = next(node.infer()) assert isinstance(inferred, Instance)