Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reimplement 91X in libcst #143

Merged
merged 1 commit into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import TYPE_CHECKING, Any, Union

import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider

from ..base import Error, Statement
Expand Down Expand Up @@ -47,8 +46,6 @@ def __init__(self, shared_state: SharedState):
"typed_calls",
}

self.suppress_errors = False

Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
# `variables` can be saved/loaded, but need a setter to not clear the reference
@property
def variables(self) -> dict[str, str]:
Expand Down Expand Up @@ -105,17 +102,16 @@ def error(
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
return

if not self.suppress_errors:
self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
node.lineno,
node.col_offset,
self.error_codes[error_code],
*args,
)
self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
node.lineno,
node.col_offset,
self.error_codes[error_code],
*args,
)
)

def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
if not attrs:
Expand All @@ -131,8 +127,6 @@ def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:

def set_state(self, attrs: dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
if copy and hasattr(value, "copy"):
value = value.copy()
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
setattr(self, attr, value)

def save_state(self, node: ast.AST, *attrs: str, copy: bool = False):
Expand Down Expand Up @@ -163,14 +157,14 @@ def add_library(self, name: str) -> None:
self.__state.library = self.__state.library + (name,)


class Flake8TrioVisitor_cst(m.MatcherDecoratableTransformer, ABC):
class Flake8TrioVisitor_cst(cst.CSTTransformer, ABC):
jakkdl marked this conversation as resolved.
Show resolved Hide resolved
# abstract attribute by not providing a value
error_codes: dict[str, str] # pyright: reportUninitializedInstanceVariable=false
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, shared_state: SharedState):
super().__init__()
self.outer: dict[cst.BaseStatement, dict[str, Any]] = {}
self.outer: dict[cst.CSTNode, dict[str, Any]] = {}
self.__state = shared_state

self.options = self.__state.options
Expand All @@ -193,7 +187,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
value = value.copy()
setattr(self, attr, value)

def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
state = self.get_state(*attrs, copy=copy)
if node in self.outer:
# not currently used, and not gonna bother adding dedicated test
Expand All @@ -202,6 +196,9 @@ def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
else:
self.outer[node] = state

def restore_state(self, node: cst.CSTNode):
self.set_state(self.outer.pop(node, {}))

Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
def error(
self,
node: cst.CSTNode,
Expand Down
77 changes: 75 additions & 2 deletions flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import ast
from fnmatch import fnmatch
from typing import TYPE_CHECKING, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar

import libcst as cst
import libcst.matchers as m
from libcst.helpers import ensure_type, get_full_name_for_node_or_raise

from ..base import Statement
from . import (
Expand All @@ -27,6 +28,7 @@

T = TypeVar("T", bound=Flake8TrioVisitor)
T_CST = TypeVar("T_CST", bound=Flake8TrioVisitor_cst)
T_EITHER = TypeVar("T_EITHER", Flake8TrioVisitor, Flake8TrioVisitor_cst)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here you mean bound=Union[...]? The multi-arg form means "must be an exact match to one of these types".

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confused me for a minute why the multi-arg form doesn't give errors when passed subtypes, but figured out that it transforms the type.
And mypy gave me errors on the union one, but that's a bug with a PR in the pipe: python/mypy#14755



def error_class(error_class: type[T]) -> type[T]:
Expand All @@ -41,7 +43,7 @@ def error_class_cst(error_class: type[T_CST]) -> type[T_CST]:
return error_class


def disabled_by_default(error_class: type[T]) -> type[T]:
def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]:
assert error_class.error_codes
default_disabled_error_codes.extend(error_class.error_codes)
return error_class
Expand Down Expand Up @@ -86,6 +88,19 @@ def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | N
return None


def fnmatch_qualified_name_cst(
name_list: Iterable[cst.Decorator], *patterns: str
) -> str | None:
for name in name_list:
qualified_name = get_full_name_for_node_or_raise(name)

for pattern in patterns:
# strip leading "@"s for when we're working with decorators
if fnmatch(qualified_name, pattern.lstrip("@")):
return pattern
return None


# used in 103/104 and 910/911
def iter_guaranteed_once(iterable: ast.expr) -> bool:
# static container with an "elts" attribute
Expand All @@ -112,6 +127,8 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
return True
else:
return True
return False

jakkdl marked this conversation as resolved.
Show resolved Hide resolved
# check for range() with literal parameters
if (
isinstance(iterable, ast.Call)
Expand All @@ -125,6 +142,62 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
return False


def cst_literal_eval(node: cst.BaseExpression) -> Any:
ast_node = cst.Module([cst.SimpleStatementLine([cst.Expr(node)])]).code
try:
return ast.literal_eval(ast_node)
except Exception: # noqa: PIE786
return None
Comment on lines +145 to +150
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

major hoops to have to run through, but didn't find any better way of doing it for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems surprising to return None on failure, instead of raising an exception - you lose the ability to distinguish None from foobar() (which raises ValueError as it's not a literal).

We might also want to open an upstream issue - there's Integer.evaluated_value for example, but no such property on e.g. lists or tuples and fixing that would let future projects avoid this helper function.

Copy link
Member Author

@jakkdl jakkdl Feb 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, in the case of iter_guaranteed_once (literal parameter to range) it's actually plenty fine to use Integer/Float.evaluated_value; edit: nvm I forgot about range-based for loops edit 2: nvm what actually matters is range(5+5) or range(10, 5, -1) edit 3: ast.literal_eval doesn't handle 5+5 anyway 😂
I opened Instagram/LibCST#878



def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ugly, but works pretty much like the ast version.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be inclined to do the same "unparse and use ast" trick - if not we'll need some serious testing to be sure they're equivalent.

Copy link
Member Author

@jakkdl jakkdl Feb 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean using iter_guaranteed_once ?
It has all the same tests that the ast version had, and vast majority of the logic is the same, so I'm pretty confident it's equivalent actually. And if/when 103/104 is converted to libcst there will be no ast users left using it.

# static container with an "elts" attribute
if elts := getattr(iterable, "elements", []):
for elt in elts:
assert isinstance(
elt,
(
cst.Element,
cst.DictElement,
cst.StarredElement,
cst.StarredDictElement,
),
)
# recurse starred expression
if isinstance(elt, (cst.StarredElement, cst.StarredDictElement)):
if iter_guaranteed_once_cst(elt.value):
return True
else:
return True
return False

if isinstance(iterable, cst.SimpleString):
return len(ast.literal_eval(iterable.value)) > 0

# check for range() with literal parameters
if m.matches(
iterable,
m.Call(
func=m.Name("range"),
),
):
try:
return (
len(
range(
*[
cst_literal_eval(a.value)
for a in ensure_type(iterable, cst.Call).args
]
jakkdl marked this conversation as resolved.
Show resolved Hide resolved
)
)
> 0
)
except Exception: # noqa: PIE786
return False
return False


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Statement | None:
def has_exception(node: ast.expr) -> str | None:
Expand Down
25 changes: 15 additions & 10 deletions flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
the timeout can only be triggered by a checkpoint.
Checkpoints on Await, Async For and Async With
"""
# if future annotations are imported then shed will reformat away the Union use
from typing import Any, Union
from __future__ import annotations

from typing import Any
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved

import libcst as cst
import libcst.matchers as m
Expand All @@ -29,9 +30,13 @@ def __init__(self, *args: Any, **kwargs: Any):
self.has_checkpoint_stack: list[bool] = []
self.node_dict: dict[cst.With, list[AttributeCall]] = {}

def checkpoint(self) -> None:
if self.has_checkpoint_stack:
self.has_checkpoint_stack[-1] = True

def visit_With(self, node: cst.With) -> None:
if m.matches(node, m.With(asynchronous=m.Asynchronous())):
self.checkpoint_node(node)
self.checkpoint()
if res := with_has_call(
node, "fail_after", "fail_at", "move_on_after", "move_on_at", "CancelScope"
):
Expand All @@ -49,12 +54,12 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.Wit
# then: remove the with and pop out it's body
return updated_node

@m.visit(m.Await() | m.For(asynchronous=m.Asynchronous()))
# can't use m.call_if_inside(m.With), since it matches parents *or* the node itself
# need to use Union due to https://github.com/Instagram/LibCST/issues/870
def checkpoint_node(self, node: Union[cst.Await, cst.For, cst.With]):
if self.has_checkpoint_stack:
self.has_checkpoint_stack[-1] = True
def visit_For(self, node: cst.For):
if node.asynchronous is not None:
self.checkpoint()

def visit_Await(self, node: cst.Await | cst.For | cst.With):
self.checkpoint()

def visit_FunctionDef(self, node: cst.FunctionDef):
self.save_state(node, "has_checkpoint_stack", copy=True)
Expand All @@ -63,5 +68,5 @@ def visit_FunctionDef(self, node: cst.FunctionDef):
def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
self.set_state(self.outer.pop(original_node, {}))
self.restore_state(original_node)
return updated_node
15 changes: 8 additions & 7 deletions flake8_trio/visitors/visitor101.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
"""
from __future__ import annotations

from typing import Any

import libcst as cst
import libcst.matchers as m
from typing import TYPE_CHECKING, Any

from .flake8triovisitor import Flake8TrioVisitor_cst
from .helpers import (
Expand All @@ -18,6 +15,9 @@
with_has_call,
)

if TYPE_CHECKING:
import libcst as cst


@error_class_cst
class Visitor101(Flake8TrioVisitor_cst):
Expand Down Expand Up @@ -45,13 +45,14 @@ def visit_With(self, node: cst.With):
and bool(with_has_call(node, "open_nursery", *cancel_scope_names))
)

@m.leave(m.OneOf(m.With(), m.FunctionDef()))
def restore_state(
def leave_With(
self, original_node: cst.BaseStatement, updated_node: cst.BaseStatement
) -> cst.BaseStatement:
self.set_state(self.outer.pop(original_node, {}))
self.restore_state(original_node)
return updated_node

leave_FunctionDef = leave_With

def visit_FunctionDef(self, node: cst.FunctionDef):
self.save_state(node, "_yield_is_error", "_safe_decorator")
self._yield_is_error = False
Expand Down
Loading