Skip to content

Commit

Permalink
Add typing to TransformVisitor (#2062)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord committed Mar 22, 2023
1 parent 7ed0804 commit 598e4c3
Showing 1 changed file with 90 additions and 25 deletions.
115 changes: 90 additions & 25 deletions astroid/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,34 @@

from __future__ import annotations

import collections
from typing import TYPE_CHECKING
from collections import defaultdict
from collections.abc import Callable
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload

from astroid.context import _invalidate_cache
from astroid.typing import SuccessfulInferenceResult

if TYPE_CHECKING:
from astroid import NodeNG
from astroid import nodes

_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
_Transform = Callable[
[_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult]
]
_Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]]

_Vistables = Union[
"nodes.NodeNG", List["nodes.NodeNG"], Tuple["nodes.NodeNG", ...], str, None
]
_VisitReturns = Union[
SuccessfulInferenceResult,
List[SuccessfulInferenceResult],
Tuple[SuccessfulInferenceResult, ...],
str,
None,
]


class TransformVisitor:
Expand All @@ -24,17 +45,26 @@ class TransformVisitor:
Based on its usage in AstroidManager.brain, it should not be reinstantiated.
"""

def __init__(self):
self.transforms = collections.defaultdict(list)

def _transform(self, node: NodeNG) -> NodeNG:
def __init__(self) -> None:
# The typing here is incorrect, but it's the best we can do
# Refer to register_transform and unregister_transform for the correct types
self.transforms: defaultdict[
type[SuccessfulInferenceResult],
list[
tuple[
_Transform[SuccessfulInferenceResult],
_Predicate[SuccessfulInferenceResult],
]
],
] = defaultdict(list)

def _transform(self, node: SuccessfulInferenceResult) -> SuccessfulInferenceResult:
"""Call matching transforms for the given node if any and return the
transformed node.
"""
cls = node.__class__

transforms = self.transforms[cls]
for transform_func, predicate in transforms:
for transform_func, predicate in self.transforms[cls]:
if predicate is None or predicate(node):
ret = transform_func(node)
# if the transformation function returns something, it's
Expand All @@ -47,16 +77,40 @@ def _transform(self, node: NodeNG) -> NodeNG:
break
return node

def _visit(self, node):
if hasattr(node, "_astroid_fields"):
for name in node._astroid_fields:
value = getattr(node, name)
visited = self._visit_generic(value)
if visited != value:
setattr(node, name, visited)
def _visit(self, node: nodes.NodeNG) -> SuccessfulInferenceResult:
for name in node._astroid_fields:
value = getattr(node, name)
value = cast(_Vistables, value)
visited = self._visit_generic(value)
if visited != value:
setattr(node, name, visited)
return self._transform(node)

def _visit_generic(self, node):
@overload
def _visit_generic(self, node: None) -> None:
...

@overload
def _visit_generic(self, node: str) -> str:
...

@overload
def _visit_generic(
self, node: list[nodes.NodeNG]
) -> list[SuccessfulInferenceResult]:
...

@overload
def _visit_generic(
self, node: tuple[nodes.NodeNG, ...]
) -> tuple[SuccessfulInferenceResult, ...]:
...

@overload
def _visit_generic(self, node: nodes.NodeNG) -> SuccessfulInferenceResult:
...

def _visit_generic(self, node: _Vistables) -> _VisitReturns:
if isinstance(node, list):
return [self._visit_generic(child) for child in node]
if isinstance(node, tuple):
Expand All @@ -66,21 +120,32 @@ def _visit_generic(self, node):

return self._visit(node)

def register_transform(self, node_class, transform, predicate=None) -> None:
"""Register `transform(node)` function to be applied on the given
astroid's `node_class` if `predicate` is None or returns true
def register_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
transform: _Transform[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Register `transform(node)` function to be applied on the given node.
The transform will only be applied if `predicate` is None or returns true
when called with the node as argument.
The transform function may return a value which is then used to
substitute the original node in the tree.
"""
self.transforms[node_class].append((transform, predicate))

def unregister_transform(self, node_class, transform, predicate=None) -> None:
self.transforms[node_class].append((transform, predicate)) # type: ignore[index, arg-type]

def unregister_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
transform: _Transform[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Unregister the given transform."""
self.transforms[node_class].remove((transform, predicate))
self.transforms[node_class].remove((transform, predicate)) # type: ignore[index, arg-type]

def visit(self, module):
def visit(self, module: nodes.Module) -> SuccessfulInferenceResult:
"""Walk the given astroid *tree* and transform each encountered node.
Only the nodes which have transforms registered will actually
Expand Down

0 comments on commit 598e4c3

Please sign in to comment.