Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add implicit reexport tracking #271

Merged
merged 8 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyanalyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import implementation
from . import method_return_type
from . import node_visitor
from . import reexport
from . import safe
from . import signature
from . import stacked_scopes
Expand All @@ -34,3 +35,4 @@
used(dump_value)
used(extensions.LiteralOnly)
used(value.UNRESOLVED_VALUE) # keeping it around for now just in case
used(reexport)
8 changes: 8 additions & 0 deletions pyanalyze/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
if TYPE_CHECKING:
from .arg_spec import ArgSpecCache
from .signature import Signature
from .reexport import ImplicitReexportTracker


class Config(object):
Expand Down Expand Up @@ -285,3 +286,10 @@ def should_check_class_for_duplicate_values(self, cls: type) -> bool:
def get_additional_bases(self, typ: Union[type, super]) -> Set[type]:
"""Return additional classes that should be considered bae classes of typ."""
return set()

#
# Used by reexport.py
#
def configure_reexports(self, tracker: "ImplicitReexportTracker") -> None:
"""Override this to set some names as explicitly re-exported."""
pass
4 changes: 4 additions & 0 deletions pyanalyze/error_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ErrorCode(enum.Enum):
type_does_not_support_bool = 66
missing_return = 67
no_return_may_return = 68
implicit_reexport = 69


# Allow testing unannotated functions without too much fuss
Expand All @@ -103,6 +104,8 @@ class ErrorCode(enum.Enum):
ErrorCode.possibly_undefined_name,
ErrorCode.missing_f,
ErrorCode.bare_ignore,
# TODO: turn this on
ErrorCode.implicit_reexport,
}

ERROR_DESCRIPTION = {
Expand Down Expand Up @@ -187,6 +190,7 @@ class ErrorCode(enum.Enum):
ErrorCode.type_does_not_support_bool: "Type does not support bool()",
ErrorCode.missing_return: "Function may exit without returning a value",
ErrorCode.no_return_may_return: "Function is annotated as NoReturn but may return",
ErrorCode.implicit_reexport: "Use of implicitly re-exported name",
}


Expand Down
89 changes: 63 additions & 26 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from .error_code import ErrorCode, DISABLED_BY_DEFAULT, ERROR_DESCRIPTION
from .extensions import ParameterTypeGuard
from .find_unused import UnusedObjectFinder, used
from .reexport import ErrorContext, ImplicitReexportTracker
from .safe import safe_getattr, is_hashable, safe_in, all_of_type
from .stacked_scopes import (
AbstractConstraint,
Expand Down Expand Up @@ -677,7 +678,9 @@ def record_call(self, caller: object, callee: object) -> None:
pass


class NameCheckVisitor(node_visitor.ReplacingNodeVisitor, CanAssignContext):
class NameCheckVisitor(
node_visitor.ReplacingNodeVisitor, CanAssignContext, ErrorContext
):
"""Visitor class that infers the type and value of Python objects and detects errors."""

error_code_enum = ErrorCode
Expand All @@ -703,6 +706,7 @@ def __init__(
attribute_checker: Optional[ClassAttributeChecker] = None,
arg_spec_cache: Optional[ArgSpecCache] = None,
collector: Optional[CallSiteCollector] = None,
reexport_tracker: Optional[ImplicitReexportTracker] = None,
annotate: bool = False,
add_ignores: bool = False,
) -> None:
Expand Down Expand Up @@ -744,6 +748,9 @@ def __init__(
if arg_spec_cache is None:
arg_spec_cache = ArgSpecCache(self.config)
self.arg_spec_cache = arg_spec_cache
if reexport_tracker is None:
reexport_tracker = ImplicitReexportTracker(self.config)
self.reexport_tracker = reexport_tracker
if (
self.attribute_checker is not None
and self.module is not None
Expand Down Expand Up @@ -861,6 +868,8 @@ def check(self, ignore_missing_module: bool = False) -> List[node_visitor.Failur
and not self.has_file_level_ignore()
):
self.unused_finder.record_module_visited(self.module)
if self.module is not None and self.module.__name__ is not None:
self.reexport_tracker.record_module_completed(self.module.__name__)
except node_visitor.VisitorError:
raise
except Exception as e:
Expand Down Expand Up @@ -959,16 +968,22 @@ def _show_error_if_checking(
)

def _set_name_in_scope(
self, varname: str, node: object, value: Value = AnyValue(AnySource.inference)
self,
varname: str,
node: object,
value: Value = AnyValue(AnySource.inference),
*,
private: bool = False,
) -> Value:
current_scope = self.scopes.current_scope()
scope_type = current_scope.scope_type
if (
self.module is not None
and scope_type == ScopeType.module_scope
and varname in current_scope
):
return current_scope.get_local(varname, node, self.state)
if self.module is not None and scope_type == ScopeType.module_scope:
if self.module.__name__ is not None and not private:
self.reexport_tracker.record_exported_attribute(
self.module.__name__, varname
)
if varname in current_scope:
return current_scope.get_local(varname, node, self.state)
if scope_type == ScopeType.class_scope and isinstance(node, ast.AST):
self._check_for_class_variable_redefinition(varname, node)
current_scope.set(varname, value, node, self.state)
Expand Down Expand Up @@ -1824,10 +1839,16 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self._maybe_record_usages_from_import(node)

is_star_import = len(node.names) == 1 and node.names[0].name == "*"
force_public = self.filename.endswith("/__init__.py") and node.level == 1
if force_public:
# from .a import b implicitly sets a in the parent module's namespace.
# We allow relying on this behavior.
self._set_name_in_scope(node.module, node)
if self.scopes.scope_type() == ScopeType.module_scope and not is_star_import:
self._handle_imports(node.names)
self._handle_imports(node.names, force_public=force_public)
else:
self._simulate_import(node)
# For now we always treat star imports as public. We might revisit this later.
self._simulate_import(node, force_public=True)

def _maybe_record_usages_from_import(self, node: ast.ImportFrom) -> None:
if self.unused_finder is None or self.module is None:
Expand Down Expand Up @@ -1874,7 +1895,9 @@ def _is_unimportable_module(self, node: Union[ast.Import, ast.ImportFrom]) -> bo
for name in node.names
)

def _simulate_import(self, node: Union[ast.ImportFrom, ast.Import]) -> None:
def _simulate_import(
self, node: Union[ast.ImportFrom, ast.Import], *, force_public: bool = False
) -> None:
"""Set the names retrieved from an import node in nontrivial situations.

For simple imports (module-global imports that are not "from ... import *"), we can just
Expand All @@ -1890,13 +1913,13 @@ def _simulate_import(self, node: Union[ast.ImportFrom, ast.Import]) -> None:

"""
if self.module is None:
self._handle_imports(node.names)
self._handle_imports(node.names, force_public=force_public)
return

source_code = decompile(node)

if self._is_unimportable_module(node):
self._handle_imports(node.names)
self._handle_imports(node.names, force_public=force_public)
self.log(logging.INFO, "Ignoring import node", source_code)
return

Expand All @@ -1922,7 +1945,12 @@ def _simulate_import(self, node: Union[ast.ImportFrom, ast.Import]) -> None:
and "." not in node.module
): # not in the package
if node.level == 1 or (node.level == 0 and node.module not in sys.modules):
self._set_name_in_scope(node.module, node, TypedValue(types.ModuleType))
self._set_name_in_scope(
node.module,
node,
TypedValue(types.ModuleType),
private=not force_public,
)

with tempfile.NamedTemporaryFile(suffix=".py") as f:
f.write(source_code.encode("utf-8"))
Expand All @@ -1933,7 +1961,7 @@ def _simulate_import(self, node: Union[ast.ImportFrom, ast.Import]) -> None:
except Exception:
# sets the name of the imported module to Any so we don't get further
# errors
self._handle_imports(node.names)
self._handle_imports(node.names, force_public=force_public)
return
finally:
# clean up pyc file
Expand All @@ -1949,20 +1977,19 @@ def _simulate_import(self, node: Union[ast.ImportFrom, ast.Import]) -> None:
hasattr(builtins, name) and value == getattr(builtins, name)
):
continue
self._set_name_in_scope(name, (node, name), KnownValue(value))
self._set_name_in_scope(
name, (node, name), KnownValue(value), private=not force_public
)

def _imported_names_of_nodes(
self, names: Iterable[ast.alias]
) -> Iterable[Tuple[str, ast.alias]]:
def _handle_imports(
self, names: Iterable[ast.alias], *, force_public: bool = False
) -> None:
for node in names:
if node.asname is not None:
yield node.asname, node
self._set_name_in_scope(node.asname, node)
else:
yield node.name.split(".")[0], node

def _handle_imports(self, names: Iterable[ast.alias]) -> None:
for varname, node in self._imported_names_of_nodes(names):
self._set_name_in_scope(varname, node)
varname = node.name.split(".")[0]
self._set_name_in_scope(varname, node, private=not force_public)

# Comprehensions

Expand Down Expand Up @@ -3082,7 +3109,7 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
# statically findable
to_assign = TypedValue(BaseException)
if node.name is not None:
self._set_name_in_scope(node.name, node, value=to_assign)
self._set_name_in_scope(node.name, node, value=to_assign, private=True)

self._generic_visit_list(node.body)

Expand Down Expand Up @@ -3593,6 +3620,14 @@ def composite_from_attribute(self, node: ast.Attribute) -> Composite:
self.asynq_checker.record_attribute_access(
root_composite.value, node.attr, node
)
if (
isinstance(root_composite.value, KnownValue)
and isinstance(root_composite.value.val, types.ModuleType)
and root_composite.value.val.__name__ is not None
):
self.reexport_tracker.record_attribute_accessed(
root_composite.value.val.__name__, node.attr, node, self
)
value = self._get_attribute_with_fallback(root_composite, node.attr, node)
if self._should_use_varname_value(value):
varname_value = VariableNameValue.from_varname(
Expand Down Expand Up @@ -4238,6 +4273,7 @@ def _run_on_files(
attribute_checker_enabled = settings[ErrorCode.attribute_is_never_set]
if "arg_spec_cache" not in kwargs:
kwargs["arg_spec_cache"] = ArgSpecCache(cls.config)
kwargs.setdefault("reexport_tracker", ImplicitReexportTracker(cls.config))
if attribute_checker is None:
inner_attribute_checker_obj = attribute_checker = ClassAttributeChecker(
cls.config,
Expand Down Expand Up @@ -4283,6 +4319,7 @@ def _run_on_files(
def check_all_files(cls, *args: Any, **kwargs: Any) -> List[node_visitor.Failure]:
if "arg_spec_cache" not in kwargs:
kwargs["arg_spec_cache"] = ArgSpecCache(cls.config)
kwargs.setdefault("reexport_tracker", ImplicitReexportTracker(cls.config))
return super().check_all_files(*args, **kwargs)

@classmethod
Expand Down
66 changes: 66 additions & 0 deletions pyanalyze/reexport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""

Functionality for dealing with implicit reexports.

"""
from ast import AST
from collections import defaultdict
from dataclasses import InitVar, dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple

from .node_visitor import Failure
from .config import Config
from .error_code import ErrorCode


class ErrorContext:
all_failures: List[Failure]

def show_error(
self, node: AST, message: str, error_code: Enum
) -> Optional[Failure]:
raise NotImplementedError


@dataclass
class ImplicitReexportTracker:
config: InitVar[Config]
completed_modules: Set[str] = field(default_factory=set)
module_to_reexports: Dict[str, Set[str]] = field(
default_factory=lambda: defaultdict(set)
)
used_reexports: Dict[str, List[Tuple[str, AST, ErrorContext]]] = field(
default_factory=lambda: defaultdict(list)
)

def __post_init__(self, config: Config) -> None:
config.configure_reexports(self)

def record_exported_attribute(self, module: str, attr: str) -> None:
self.module_to_reexports[module].add(attr)

def record_module_completed(self, module: str) -> None:
self.completed_modules.add(module)
reexports = self.module_to_reexports[module]
for attr, node, ctx in self.used_reexports[module]:
if attr not in reexports:
self.show_error(module, attr, node, ctx)

def record_attribute_accessed(
self, module: str, attr: str, node: AST, ctx: ErrorContext
) -> None:
if module in self.completed_modules:
if attr not in self.module_to_reexports[module]:
self.show_error(module, attr, node, ctx)
else:
self.used_reexports[module].append((attr, node, ctx))

def show_error(self, module: str, attr: str, node: AST, ctx: ErrorContext) -> None:
failure = ctx.show_error(
node,
f"Attribute '{attr}' is not exported by module '{module}'",
ErrorCode.implicit_reexport,
)
if failure is not None:
ctx.all_failures.append(failure)