Skip to content

Commit

Permalink
reimplement 91X in libcst
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Feb 24, 2023
1 parent d8325ad commit d781c27
Show file tree
Hide file tree
Showing 8 changed files with 471 additions and 214 deletions.
28 changes: 14 additions & 14 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def __init__(self, shared_state: SharedState):
"typed_calls",
}

self.suppress_errors = False

# `variables` can be saved/loaded, but need a setter to not clear the reference
@property
def variables(self) -> dict[str, str]:
Expand Down Expand Up @@ -105,17 +103,16 @@ def error(
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
return

if not self.suppress_errors:
self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
node.lineno,
node.col_offset,
self.error_codes[error_code],
*args,
)
self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
node.lineno,
node.col_offset,
self.error_codes[error_code],
*args,
)
)

def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
if not attrs:
Expand Down Expand Up @@ -170,7 +167,7 @@ class Flake8TrioVisitor_cst(m.MatcherDecoratableTransformer, ABC):

def __init__(self, shared_state: SharedState):
super().__init__()
self.outer: dict[cst.BaseStatement, dict[str, Any]] = {}
self.outer: dict[cst.CSTNode, dict[str, Any]] = {}
self.__state = shared_state

self.options = self.__state.options
Expand All @@ -193,7 +190,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
value = value.copy()
setattr(self, attr, value)

def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
state = self.get_state(*attrs, copy=copy)
if node in self.outer:
# not currently used, and not gonna bother adding dedicated test
Expand All @@ -202,6 +199,9 @@ def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
else:
self.outer[node] = state

def restore_state(self, node: cst.CSTNode):
self.set_state(self.outer.pop(node, {}))

def error(
self,
node: cst.CSTNode,
Expand Down
77 changes: 75 additions & 2 deletions flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import ast
from fnmatch import fnmatch
from typing import TYPE_CHECKING, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar

import libcst as cst
import libcst.matchers as m
from libcst.helpers import ensure_type, get_full_name_for_node_or_raise

from ..base import Statement
from . import (
Expand All @@ -27,6 +28,7 @@

T = TypeVar("T", bound=Flake8TrioVisitor)
T_CST = TypeVar("T_CST", bound=Flake8TrioVisitor_cst)
T_EITHER = TypeVar("T_EITHER", Flake8TrioVisitor, Flake8TrioVisitor_cst)


def error_class(error_class: type[T]) -> type[T]:
Expand All @@ -41,7 +43,7 @@ def error_class_cst(error_class: type[T_CST]) -> type[T_CST]:
return error_class


def disabled_by_default(error_class: type[T]) -> type[T]:
def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]:
assert error_class.error_codes
default_disabled_error_codes.extend(error_class.error_codes)
return error_class
Expand Down Expand Up @@ -86,6 +88,21 @@ def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | N
return None


def fnmatch_qualified_name_cst(
name_list: Iterable[cst.Decorator], *patterns: str
) -> str | None:
for name in name_list:
if isinstance(name, cst.Call):
name = name.func
qualified_name = get_full_name_for_node_or_raise(name)

for pattern in patterns:
# strip leading "@"s for when we're working with decorators
if fnmatch(qualified_name, pattern.lstrip("@")):
return pattern
return None


# used in 103/104 and 910/911
def iter_guaranteed_once(iterable: ast.expr) -> bool:
# static container with an "elts" attribute
Expand Down Expand Up @@ -125,6 +142,62 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
return False


def cst_literal_eval(node: cst.BaseExpression) -> Any:
ast_node = cst.Module([cst.SimpleStatementLine([cst.Expr(node)])]).code
try:
return ast.literal_eval(ast_node)
except Exception: # noqa: PIE786
return None


def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool:
# static container with an "elts" attribute
if elts := getattr(iterable, "elements", []):
for elt in elts:
assert isinstance(
elt,
(
cst.Element,
cst.DictElement,
cst.StarredElement,
cst.StarredDictElement,
),
)
# recurse starred expression
if isinstance(elt, (cst.StarredElement, cst.StarredDictElement)):
if iter_guaranteed_once_cst(elt.value):
return True
else:
return True
return False

if isinstance(iterable, cst.SimpleString):
return len(ast.literal_eval(iterable.value)) > 0

# check for range() with literal parameters
if m.matches(
iterable,
m.Call(
func=m.Name("range"),
),
):
try:
return (
len(
range(
*[
cst_literal_eval(a.value)
for a in ensure_type(iterable, cst.Call).args
]
)
)
> 0
)
except Exception: # noqa: PIE786
return False
return False


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Statement | None:
def has_exception(node: ast.expr) -> str | None:
Expand Down
2 changes: 1 addition & 1 deletion flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def visit_FunctionDef(self, node: cst.FunctionDef):
def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
self.set_state(self.outer.pop(original_node, {}))
self.restore_state(original_node)
return updated_node
4 changes: 2 additions & 2 deletions flake8_trio/visitors/visitor101.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def visit_With(self, node: cst.With):
)

@m.leave(m.OneOf(m.With(), m.FunctionDef()))
def restore_state(
def restore_state_(
self, original_node: cst.BaseStatement, updated_node: cst.BaseStatement
) -> cst.BaseStatement:
self.set_state(self.outer.pop(original_node, {}))
self.restore_state(original_node)
return updated_node

def visit_FunctionDef(self, node: cst.FunctionDef):
Expand Down
Loading

0 comments on commit d781c27

Please sign in to comment.