Skip to content

Commit

Permalink
Add typing for set_local (#1837)
Browse files Browse the repository at this point in the history
* Use SuccessfulInferenceResult for locals
  • Loading branch information
cdce8p committed Nov 8, 2022
1 parent 4acf578 commit 6cf238d
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 46 deletions.
4 changes: 2 additions & 2 deletions astroid/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ def sort_locals(my_list: list[nodes.NodeNG]) -> None:
continue
for name in imported.public_names():
node.parent.set_local(name, node)
sort_locals(node.parent.scope().locals[name])
sort_locals(node.parent.scope().locals[name]) # type: ignore[assignment]
else:
node.parent.set_local(asname or name, node)
sort_locals(node.parent.scope().locals[asname or name])
sort_locals(node.parent.scope().locals[asname or name]) # type: ignore[assignment]

def delayed_assattr(self, node: nodes.AssignAttr) -> None:
"""Visit a AssAttr node
Expand Down
2 changes: 1 addition & 1 deletion astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4896,7 +4896,7 @@ def scope(self) -> LocalsDictNodeNG:

return self.parent.scope()

def set_local(self, name: str, stmt: AssignName) -> None:
def set_local(self, name: str, stmt: NodeNG) -> None:
"""Define that the given name is declared in the given statement node.
NamedExpr's in Arguments, Keyword or Comprehension are evaluated in their
parent's parent scope. So we add to their frame's locals.
Expand Down
5 changes: 2 additions & 3 deletions astroid/nodes/node_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,18 @@ def block_range(self, lineno):
"""
return lineno, self.tolineno

def set_local(self, name, stmt):
def set_local(self, name: str, stmt: NodeNG) -> None:
"""Define that the given name is declared in the given statement node.
This definition is stored on the parent scope node.
.. seealso:: :meth:`scope`
:param name: The name that is being defined.
:type name: str
:param stmt: The statement that defines the given name.
:type stmt: NodeNG
"""
assert self.parent
self.parent.set_local(name, stmt)

@overload
Expand Down
29 changes: 18 additions & 11 deletions astroid/nodes/scoped_nodes/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, TypeVar, overload

from astroid.filter_statements import _filter_stmts
from astroid.nodes import node_classes, scoped_nodes
from astroid.nodes.scoped_nodes.utils import builtin_lookup
from astroid.typing import SuccessfulInferenceResult

if TYPE_CHECKING:
from astroid import nodes
Expand All @@ -26,7 +27,7 @@ class LocalsDictNodeNG(node_classes.LookupMixIn):

# attributes below are set by the builder module or by raw factories

locals: dict[str, list[nodes.NodeNG]] = {}
locals: dict[str, list[SuccessfulInferenceResult]] = {}
"""A map of the name of a local variable to the node defining the local."""

def qname(self):
Expand Down Expand Up @@ -88,23 +89,21 @@ def _scope_lookup(self, node, name, offset=0):
# self is at the top level of a module, or is enclosed only by ClassDefs
return builtin_lookup(name)

def set_local(self, name, stmt):
def set_local(self, name: str, stmt: nodes.NodeNG) -> None:
"""Define that the given name is declared in the given statement node.
.. seealso:: :meth:`scope`
:param name: The name that is being defined.
:type name: str
:param stmt: The statement that defines the given name.
:type stmt: NodeNG
"""
# assert not stmt in self.locals.get(name, ()), (self, stmt)
self.locals.setdefault(name, []).append(stmt)

__setitem__ = set_local

def _append_node(self, child):
def _append_node(self, child: nodes.NodeNG) -> None:
"""append a child, linking it in the tree"""
# pylint: disable=no-member; depending by the class
# which uses the current class as a mixin or base class.
Expand All @@ -113,22 +112,30 @@ def _append_node(self, child):
self.body.append(child)
child.parent = self

def add_local_node(self, child_node, name=None):
@overload
def add_local_node(
self, child_node: nodes.ClassDef, name: str | None = ...
) -> None:
...

@overload
def add_local_node(self, child_node: nodes.NodeNG, name: str) -> None:
...

def add_local_node(self, child_node: nodes.NodeNG, name: str | None = None) -> None:
"""Append a child that should alter the locals of this scope node.
:param child_node: The child node that will alter locals.
:type child_node: NodeNG
:param name: The name of the local that will be altered by
the given child node.
:type name: str or None
"""
if name != "__class__":
# add __class__ node as a child will cause infinite recursion later!
self._append_node(child_node)
self.set_local(name or child_node.name, child_node)
self.set_local(name or child_node.name, child_node) # type: ignore[attr-defined]

def __getitem__(self, item: str) -> nodes.NodeNG:
def __getitem__(self, item: str) -> SuccessfulInferenceResult:
"""The first node the defines the given local.
:param item: The name of the locally defined object.
Expand Down
36 changes: 9 additions & 27 deletions astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from astroid.nodes.scoped_nodes.mixin import ComprehensionScope, LocalsDictNodeNG
from astroid.nodes.scoped_nodes.utils import builtin_lookup
from astroid.nodes.utils import Position
from astroid.typing import InferenceResult
from astroid.typing import InferenceResult, SuccessfulInferenceResult

if sys.version_info >= (3, 8):
from functools import cached_property
Expand Down Expand Up @@ -709,10 +709,7 @@ def __init__(
:type end_col_offset: Optional[int]
"""
self.locals = {}
"""A map of the name of a local variable to the node defining the local.
:type: dict(str, NodeNG)
"""
"""A map of the name of a local variable to the node defining the local."""

super().__init__(
lineno=lineno,
Expand Down Expand Up @@ -801,10 +798,7 @@ def __init__(
:type end_col_offset: Optional[int]
"""
self.locals = {}
"""A map of the name of a local variable to the node defining the local.
:type: dict(str, NodeNG)
"""
"""A map of the name of a local variable to the node defining the local."""

super().__init__(
lineno=lineno,
Expand Down Expand Up @@ -898,10 +892,7 @@ def __init__(
:type end_col_offset: Optional[int]
"""
self.locals = {}
"""A map of the name of a local variable to the node defining the local.
:type: dict(str, NodeNG)
"""
"""A map of the name of a local variable to the node defining the local."""

super().__init__(
lineno=lineno,
Expand Down Expand Up @@ -968,10 +959,7 @@ def __init__(
end_col_offset=None,
):
self.locals = {}
"""A map of the name of a local variable to the node defining it.
:type: dict(str, NodeNG)
"""
"""A map of the name of a local variable to the node defining it."""

super().__init__(
lineno=lineno,
Expand Down Expand Up @@ -1106,10 +1094,7 @@ def __init__(
:type end_col_offset: Optional[int]
"""
self.locals = {}
"""A map of the name of a local variable to the node defining it.
:type: dict(str, NodeNG)
"""
"""A map of the name of a local variable to the node defining it."""

self.args: Arguments
"""The arguments that the function takes."""
Expand Down Expand Up @@ -2003,10 +1988,7 @@ def __init__(
"""
self.instance_attrs = {}
self.locals = {}
"""A map of the name of a local variable to the node defining it.
:type: dict(str, NodeNG)
"""
"""A map of the name of a local variable to the node defining it."""

self.keywords = []
"""The keywords given to the class definition.
Expand Down Expand Up @@ -2531,7 +2513,7 @@ def getattr(
name: str,
context: InferenceContext | None = None,
class_context: bool = True,
) -> list[NodeNG]:
) -> list[SuccessfulInferenceResult]:
"""Get an attribute from this class, using Python's attribute semantic.
This method doesn't look in the :attr:`instance_attrs` dictionary
Expand All @@ -2558,7 +2540,7 @@ def getattr(
raise AttributeInferenceError(target=self, attribute=name, context=context)

# don't modify the list in self.locals!
values = list(self.locals.get(name, []))
values: list[SuccessfulInferenceResult] = list(self.locals.get(name, []))
for classnode in self.ancestors(recurs=True, context=context):
values += classnode.locals.get(name, [])

Expand Down
2 changes: 1 addition & 1 deletion astroid/nodes/scoped_nodes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def builtin_lookup(name: str) -> tuple[nodes.Module, Sequence[nodes.NodeNG]]:
if name == "__dict__":
return _builtin_astroid, ()
try:
stmts: Sequence[nodes.NodeNG] = _builtin_astroid.locals[name]
stmts: Sequence[nodes.NodeNG] = _builtin_astroid.locals[name] # type: ignore[assignment]
except KeyError:
stmts = ()
return _builtin_astroid, stmts
2 changes: 1 addition & 1 deletion astroid/raw_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def build_from_import(fromname, names):
return nodes.ImportFrom(fromname, [(name, None) for name in names])


def register_arguments(func, args=None):
def register_arguments(func: nodes.FunctionDef, args: list | None = None) -> None:
"""add given arguments to local
args is a list that may contains nested lists
Expand Down
1 change: 1 addition & 0 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def _save_assignment(self, node: nodes.AssignName | nodes.DelName) -> None:
node.root().set_local(node.name, node)
else:
assert node.parent
assert node.name
node.parent.set_local(node.name, node)

def visit_arg(self, node: ast.arg, parent: NodeNG) -> nodes.AssignName:
Expand Down

0 comments on commit 6cf238d

Please sign in to comment.