Skip to content

Commit

Permalink
Preserve parent CallContext when inferring nested functions (#1982)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Jan 29, 2023
1 parent dfd88f5 commit a0d219c
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 45 deletions.
5 changes: 4 additions & 1 deletion ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ Release date: TBA

* Fix issues with ``typing_extensions.TypeVar``.


* Fix ``ClassDef.fromlino`` for PyPy 3.8 (v7.3.11) if class is wrapped by a decorator.

* Preserve parent CallContext when inferring nested functions.

Closes PyCQA/pylint#8074


What's New in astroid 2.13.3?
=============================
Expand Down
34 changes: 0 additions & 34 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Const,
JoinedStr,
Name,
NodeNG,
Subscript,
Tuple,
)
Expand Down Expand Up @@ -380,36 +379,6 @@ def infer_special_alias(
return iter([class_def])


def _looks_like_typing_cast(node: Call) -> bool:
return isinstance(node, Call) and (
isinstance(node.func, Name)
and node.func.name == "cast"
or isinstance(node.func, Attribute)
and node.func.attrname == "cast"
)


def infer_typing_cast(
node: Call, ctx: context.InferenceContext | None = None
) -> Iterator[NodeNG]:
"""Infer call to cast() returning same type as casted-from var."""
if not isinstance(node.func, (Name, Attribute)):
raise UseInferenceDefault

try:
func = next(node.func.infer(context=ctx))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc
if (
not isinstance(func, FunctionDef)
or func.qname() != "typing.cast"
or len(node.args) != 2
):
raise UseInferenceDefault

return node.args[1].infer(context=ctx)


AstroidManager().register_transform(
Call,
inference_tip(infer_typing_typevar_or_newtype),
Expand All @@ -418,9 +387,6 @@ def infer_typing_cast(
AstroidManager().register_transform(
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
)
AstroidManager().register_transform(
Call, inference_tip(infer_typing_cast), _looks_like_typing_cast
)

if PY39_PLUS:
AstroidManager().register_transform(
Expand Down
6 changes: 5 additions & 1 deletion astroid/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,14 @@ def __str__(self) -> str:
class CallContext:
"""Holds information for a call site."""

__slots__ = ("args", "keywords", "callee")
__slots__ = ("args", "keywords", "callee", "parent_call_context")

def __init__(
self,
args: list[NodeNG],
keywords: list[Keyword] | None = None,
callee: NodeNG | None = None,
parent_call_context: CallContext | None = None,
):
self.args = args # Call positional arguments
if keywords:
Expand All @@ -176,6 +177,9 @@ def __init__(
arg_value_pairs = []
self.keywords = arg_value_pairs # Call keyword arguments
self.callee = callee # Function being called
self.parent_call_context = (
parent_call_context # Parent CallContext for nested calls
)


def copy_context(context: InferenceContext | None) -> InferenceContext:
Expand Down
5 changes: 4 additions & 1 deletion astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ def infer_call(
try:
if hasattr(callee, "infer_call_result"):
callcontext.callcontext = CallContext(
args=self.args, keywords=self.keywords, callee=callee
args=self.args,
keywords=self.keywords,
callee=callee,
parent_call_context=callcontext.callcontext,
)
yield from callee.infer_call_result(caller=self, context=callcontext)
except InferenceError:
Expand Down
2 changes: 1 addition & 1 deletion astroid/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def arguments_assigned_stmts(
# reset call context/name
callcontext = context.callcontext
context = copy_context(context)
context.callcontext = None
context.callcontext = callcontext.parent_call_context
args = arguments.CallSite(callcontext, context=context)
return args.infer_argument(self.parent, node_name, context)
return _arguments_infer_argname(self, node_name, context)
Expand Down
26 changes: 22 additions & 4 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,8 +2132,7 @@ class A:
pass
b = 42
a = cast(A, b)
a
cast(A, b)
"""
)
inferred = next(node.infer())
Expand All @@ -2148,14 +2147,33 @@ class A:
pass
b = 42
a = typing.cast(A, b)
a
typing.cast(A, b)
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.Const)
assert inferred.value == 42

def test_typing_cast_multiple_inference_calls(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import TypeVar, cast
T = TypeVar("T")
def ident(var: T) -> T:
return cast(T, var)
ident(2) #@
ident("Hello") #@
"""
)
i0 = next(ast_nodes[0].infer())
assert isinstance(i0, nodes.Const)
assert i0.value == 2

i1 = next(ast_nodes[1].infer())
assert isinstance(i1, nodes.Const)
assert i1.value == "Hello"


@pytest.mark.skipif(
not HAS_TYPING_EXTENSIONS,
Expand Down
5 changes: 2 additions & 3 deletions tests/unittest_inference_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,6 @@ def g(y):
def test_inner_call_with_dynamic_argument() -> None:
"""Test function where return value is the result of a separate function call,
with a dynamic value passed to the inner function.
Currently, this is Uninferable.
"""
node = builder.extract_node(
"""
Expand All @@ -163,7 +161,8 @@ def g(y):
assert isinstance(node, nodes.NodeNG)
inferred = node.inferred()
assert len(inferred) == 1
assert inferred[0] is Uninferable
assert isinstance(inferred[0], nodes.Const)
assert inferred[0].value == 3


def test_method_const_instance_attr() -> None:
Expand Down

0 comments on commit a0d219c

Please sign in to comment.