diff --git a/ChangeLog b/ChangeLog index 1aa31f9f2..4447594a9 100644 --- a/ChangeLog +++ b/ChangeLog @@ -7,6 +7,10 @@ What's New in astroid 4.0.0? ============================ Release date: TBA +* Support constraints from ternary expressions in inference. + + Closes pylint-dev/pylint#9729 + * Handle deprecated `bool(NotImplemented)` cast in const nodes. * Add support for boolean truthiness constraints (`x`, `not x`) in inference. diff --git a/astroid/constraint.py b/astroid/constraint.py index 75a5e6aca..692d22d03 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -127,7 +127,7 @@ def satisfied_by(self, inferred: InferenceResult) -> bool: def get_constraints( expr: _NameNodes, frame: nodes.LocalsDictNodeNG -) -> dict[nodes.If, set[Constraint]]: +) -> dict[nodes.If | nodes.IfExp, set[Constraint]]: """Returns the constraints for the given expression. The returned dictionary maps the node where the constraint was generated to the @@ -137,10 +137,10 @@ def get_constraints( Currently this only supports constraints generated from if conditions. """ current_node: nodes.NodeNG | None = expr - constraints_mapping: dict[nodes.If, set[Constraint]] = {} + constraints_mapping: dict[nodes.If | nodes.IfExp, set[Constraint]] = {} while current_node is not None and current_node is not frame: parent = current_node.parent - if isinstance(parent, nodes.If): + if isinstance(parent, (nodes.If, nodes.IfExp)): branch, _ = parent.locate_child(current_node) constraints: set[Constraint] | None = None if branch == "body": diff --git a/astroid/context.py b/astroid/context.py index 3002b532c..d1aeef3bb 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -80,7 +80,9 @@ def __init__( self.extra_context: dict[SuccessfulInferenceResult, InferenceContext] = {} """Context that needs to be passed down through call stacks for call arguments.""" - self.constraints: dict[str, dict[nodes.If, set[constraint.Constraint]]] = {} + self.constraints: dict[ + str, dict[nodes.If | nodes.IfExp, set[constraint.Constraint]] + ] = {} """The constraints on nodes.""" @property diff --git a/tests/test_constraint.py b/tests/test_constraint.py index 84ef498d0..4859d4241 100644 --- a/tests/test_constraint.py +++ b/tests/test_constraint.py @@ -592,3 +592,184 @@ def method(self): assert isinstance(inferred[1], nodes.Const) assert inferred[1].value == fail_val + + +@common_params(node="x") +def test_if_exp_body( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test constraint for a variable that is used in an if exp body.""" + node1, node2 = builder.extract_node( + f""" + def f1(x = {fail_val}): + return ( + x if {condition} else None #@ + ) + + def f2(x = {satisfy_val}): + return ( + x if {condition} else None #@ + ) + """ + ) + + inferred = node1.body.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.body.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_exp_else( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test constraint for a variable that is used in an if exp else block.""" + node1, node2 = builder.extract_node( + f""" + def f1(x = {satisfy_val}): + return ( + None if {condition} else x #@ + ) + + def f2(x = {fail_val}): + return ( + None if {condition} else x #@ + ) + """ + ) + + inferred = node1.orelse.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.orelse.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_outside_if_exp( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if exp condition doesn't apply outside of the if exp.""" + nodes_ = builder.extract_node( + f""" + def f1(x = {fail_val}): + x if {condition} else None + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + None if {condition} else x + return ( + x #@ + ) + """ + ) + for node, val in zip(nodes_, (fail_val, satisfy_val)): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == val + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_nested_if_exp( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if exp condition applies within inner if exp.""" + node1, node2 = builder.extract_node( + f""" + def f1(y, x = {fail_val}): + return ( + (x if y else None) if {condition} else None #@ + ) + + def f2(y, x = {satisfy_val}): + return ( + (x if y else None) if {condition} else None #@ + ) + """ + ) + + inferred = node1.body.body.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.body.body.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + assert inferred[1] is Uninferable + + +@common_params(node="self.x") +def test_if_exp_instance_attr( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test constraint for an instance attribute in an if exp.""" + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self): + return ( + self.x if {condition} else None #@ + ) + + class A2: + def __init__(self, x = {satisfy_val}): + self.x = x + + def method(self): + return ( + self.x if {condition} else None #@ + ) + """ + ) + + inferred = node1.body.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.body.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + assert inferred[1].value is Uninferable + + +@common_params(node="self.x") +def test_if_exp_instance_attr_varname_collision( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if exp condition doesn't apply to a variable with the same name.""" + node = builder.extract_node( + f""" + class A: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self, x = {fail_val}): + return ( + x if {condition} else None #@ + ) + """ + ) + + inferred = node.body.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + assert inferred[1].value is Uninferable