Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Release date: TBA
* Suppress ``SyntaxWarning`` for invalid escape sequences and return in finally on
Python 3.14 when parsing modules.

* Assign ``Import`` and ``ImportFrom`` nodes to module locals if used with ``global``.

Closes pylint-dev/pylint#10632


What's New in astroid 4.0.0?
============================
Expand Down
37 changes: 23 additions & 14 deletions astroid/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import textwrap
import types
import warnings
from collections.abc import Iterator, Sequence
from collections.abc import Collection, Iterator, Sequence
from io import TextIOWrapper
from tokenize import detect_encoding
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from astroid import bases, modutils, nodes, raw_building, rebuilder, util
from astroid._ast import ParserModule, get_parser_module
Expand Down Expand Up @@ -163,11 +163,11 @@ def _post_build(
module.file_encoding = encoding
self._manager.cache_module(module)
# post tree building steps after we stored the module in the cache:
for from_node in builder._import_from_nodes:
for from_node, global_names in builder._import_from_nodes:
if from_node.modname == "__future__":
for symbol, _ in from_node.names:
module.future_imports.add(symbol)
self.add_from_names_to_locals(from_node)
self.add_from_names_to_locals(from_node, global_names)
# handle delayed assattr nodes
for delayed in builder._delayed_assattr:
self.delayed_assattr(delayed)
Expand Down Expand Up @@ -210,31 +210,40 @@ def _data_build(
module = builder.visit_module(node, modname, node_file, package)
return module, builder

def add_from_names_to_locals(self, node: nodes.ImportFrom) -> None:
def add_from_names_to_locals(
self, node: nodes.ImportFrom, global_name: Collection[str]
) -> None:
"""Store imported names to the locals.

Resort the locals if coming from a delayed node
"""

def _key_func(node: nodes.NodeNG) -> int:
return node.fromlineno or 0

def sort_locals(my_list: list[nodes.NodeNG]) -> None:
my_list.sort(key=_key_func)
def add_local(parent_or_root: nodes.NodeNG, name: str) -> None:
parent_or_root.set_local(name, node)
my_list = parent_or_root.scope().locals[name]
if TYPE_CHECKING:
my_list = cast(list[nodes.NodeNG], my_list)
my_list.sort(key=lambda n: n.fromlineno or 0)

assert node.parent # It should always default to the module
module = node.root()
for name, asname in node.names:
if name == "*":
try:
imported = node.do_import_module()
except AstroidBuildingError:
continue
for name in imported.public_names():
node.parent.set_local(name, node)
sort_locals(node.parent.scope().locals[name]) # type: ignore[arg-type]
if name in global_name:
add_local(module, name)
else:
add_local(node.parent, name)
else:
node.parent.set_local(asname or name, node)
sort_locals(node.parent.scope().locals[asname or name]) # type: ignore[arg-type]
name = asname or name
if name in global_name:
add_local(module, name)
else:
add_local(node.parent, name)

def delayed_assattr(self, node: nodes.AssignAttr) -> None:
"""Visit an AssignAttr node.
Expand Down
15 changes: 10 additions & 5 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ast
import sys
import token
from collections.abc import Callable, Generator
from collections.abc import Callable, Collection, Generator
from io import StringIO
from tokenize import TokenInfo, generate_tokens
from typing import TYPE_CHECKING, Final, TypeVar, cast, overload
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
self._manager = manager
self._data = data.split("\n") if data else None
self._global_names: list[dict[str, list[nodes.Global]]] = []
self._import_from_nodes: list[nodes.ImportFrom] = []
self._import_from_nodes: list[tuple[nodes.ImportFrom, Collection[str]]] = []
self._delayed_assattr: list[nodes.AssignAttr] = []
self._visit_meths: dict[
type[ast.AST], Callable[[ast.AST, nodes.NodeNG], nodes.NodeNG]
Expand Down Expand Up @@ -1099,7 +1099,9 @@ def visit_importfrom(
parent=parent,
)
# store From names to add them to locals after building
self._import_from_nodes.append(newnode)
self._import_from_nodes.append(
(newnode, self._global_names[-1].keys() if self._global_names else ())
)
return newnode

@overload
Expand Down Expand Up @@ -1300,8 +1302,11 @@ def visit_import(self, node: ast.Import, parent: nodes.NodeNG) -> nodes.Import:
)
# save import names in parent's locals:
for name, asname in newnode.names:
name = asname or name
parent.set_local(name.split(".")[0], newnode)
name = (asname or name).split(".")[0]
if self._global_names and name in self._global_names[-1]:
parent.root().set_local(name, newnode)
else:
parent.set_local(name, newnode)
return newnode

def visit_joinedstr(
Expand Down
25 changes: 25 additions & 0 deletions tests/test_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2803,6 +2803,31 @@ class First(object, object): #@
astroid["First"].slots()


def test_import_with_global() -> None:
code = builder.parse(
"""
def f1():
global platform
from sys import platform as plat
platform = plat

def f2():
global os, RE, deque, VERSION, Path
import os
import re as RE
from collections import deque
from sys import version as VERSION
from pathlib import *
"""
)
assert "platform" in code.locals
assert "os" in code.locals
assert "RE" in code.locals
assert "deque" in code.locals
assert "VERSION" in code.locals
assert "Path" in code.locals


class TestFrameNodes:
@staticmethod
def test_frame_node():
Expand Down