Skip to content

Commit

Permalink
add args
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Mar 22, 2023
1 parent 112b998 commit 334c495
Show file tree
Hide file tree
Showing 27 changed files with 255 additions and 118 deletions.
131 changes: 105 additions & 26 deletions flake8_trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import functools
import keyword
import os
import re
import subprocess
import sys
import tokenize
Expand All @@ -24,8 +23,9 @@

import libcst as cst

from .base import Options
from .runner import Flake8TrioRunner, Flake8TrioRunner_cst
from .visitors import default_disabled_error_codes
from .visitors import ERROR_CLASSES, ERROR_CLASSES_CST, default_disabled_error_codes

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -76,12 +76,6 @@ def cst_parse_module_native(source: str) -> cst.Module:

def main() -> int:
parser = ArgumentParser(prog="flake8_trio")
parser.add_argument(
nargs="*",
metavar="file",
dest="files",
help="Files(s) to format, instead of autodetection.",
)
Plugin.add_options(parser)
args = parser.parse_args()
Plugin.parse_options(args)
Expand Down Expand Up @@ -124,7 +118,13 @@ def main() -> int:
class Plugin:
name = __name__
version = __version__
options: Namespace = Namespace()
standalone = True
_options: Options | None = None

@property
def options(self) -> Options:
assert self._options is not None
return self._options

def __init__(self, tree: ast.AST, lines: Sequence[str]):
super().__init__()
Expand Down Expand Up @@ -158,18 +158,64 @@ def run(self) -> Iterable[Error]:
@staticmethod
def add_options(option_manager: OptionManager | ArgumentParser):
if isinstance(option_manager, ArgumentParser):
# TODO: disable TRIO9xx calls by default
# if run as standalone
Plugin.standalone = True
add_argument = option_manager.add_argument
add_argument(
nargs="*",
metavar="file",
dest="files",
help="Files(s) to format, instead of autodetection.",
)
else: # if run as a flake8 plugin
Plugin.standalone = False
# Disable TRIO9xx calls by default
option_manager.extend_default_ignore(default_disabled_error_codes)
# add parameter to parse from flake8 config
add_argument = functools.partial( # type: ignore
option_manager.add_option, parse_from_config=True
)
add_argument("--autofix", action="store_true", required=False)

add_argument(
"--enable",
type=comma_separated_list,
default="TRIO",
required=False,
help=(
"Comma-separated list of error codes to enable, similar to flake8"
" --select but is additionally more performant as it will disable"
" non-enabled visitors from running instead of just silencing their"
" errors."
),
)
add_argument(
"--disable",
type=comma_separated_list,
default="TRIO9" if Plugin.standalone else "",
required=False,
help=(
"Comma-separated list of error codes to disable, similar to flake8"
" --ignore but is additionally more performant as it will disable"
" non-enabled visitors from running instead of just silencing their"
" errors."
),
)
add_argument(
"--autofix",
type=comma_separated_list,
default="",
required=False,
help=(
"Comma-separated list of error-codes to enable autofixing for"
"if implemented. Requires running as a standalone program."
),
)
add_argument(
"--error-on-autofix",
action="store_true",
required=False,
default=False,
help="Whether to also print an error message for autofixed errors",
)
add_argument(
"--no-checkpoint-warning-decorators",
default="asynccontextmanager",
Expand Down Expand Up @@ -208,19 +254,6 @@ def add_options(option_manager: OptionManager | ArgumentParser):
"suggesting it be replaced with {value}"
),
)
add_argument(
"--enable-visitor-codes-regex",
type=re.compile, # type: ignore[arg-type]
default=".*",
required=False,
help=(
"Regex string of visitors to enable. Can be used to disable broken "
"visitors, or instead of --select/--disable to select error codes "
"in a way that is more performant. If a visitor raises multiple codes "
"it will not be disabled unless all codes are disabled, but it will "
"not report codes matching this regex."
),
)
add_argument(
"--anyio",
# action=store_true + parse_from_config does seem to work here, despite
Expand All @@ -237,7 +270,53 @@ def add_options(option_manager: OptionManager | ArgumentParser):

@staticmethod
def parse_options(options: Namespace):
Plugin.options = options
def get_matching_codes(
patterns: Iterable[str], codes: Iterable[str], msg: str
) -> Iterable[str]:
for pattern in patterns:
for code in codes:
if code.lower().startswith(pattern.lower()):
yield code

autofix_codes: set[str] = set()
enabled_codes: set[str] = set()
disabled_codes: set[str] = {
err_code
for err_class in (*ERROR_CLASSES, *ERROR_CLASSES_CST)
for err_code in err_class.error_codes.keys() # type: ignore[attr-defined]
if len(err_code) == 7 # exclude e.g. TRIO103_anyio_trio
}

if options.autofix:
if not Plugin.standalone:
print("Cannot autofix when run as a flake8 plugin.", file=sys.stderr)
sys.exit(1)
autofix_codes.update(
get_matching_codes(options.autofix, disabled_codes, "autofix")
)

enabled = set(get_matching_codes(options.enable, disabled_codes, "enable"))
enabled_codes |= enabled
disabled_codes -= enabled

disabled = set(get_matching_codes(options.disable, enabled_codes, "disable"))
# if disable has default value, don't disable explicitly enabled codes
if options.disable == ["TRIO9"]:
disabled -= {code for code in options.enable if len(code) == 7}

disabled_codes |= disabled
enabled_codes -= disabled

Plugin._options = Options(
enable=enabled_codes,
disable=disabled_codes,
autofix=autofix_codes,
error_on_autofix=options.error_on_autofix,
no_checkpoint_warning_decorators=options.no_checkpoint_warning_decorators,
startable_in_context_manager=options.startable_in_context_manager,
trio200_blocking_calls=options.trio200_blocking_calls,
anyio=options.anyio,
)


def comma_separated_list(raw_value: str) -> list[str]:
Expand Down
19 changes: 18 additions & 1 deletion flake8_trio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@

from __future__ import annotations

from typing import Any, NamedTuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple

if TYPE_CHECKING:
from collections.abc import Collection


@dataclass
class Options:
# enable and disable have been expanded to contain full codes, and have no overlap
enable: Collection[str]
disable: Collection[str]
autofix: Collection[str]
error_on_autofix: bool
no_checkpoint_warning_decorators: Collection[str]
startable_in_context_manager: Collection[str]
trio200_blocking_calls: dict[str, str]
anyio: bool


class Statement(NamedTuple):
Expand Down
22 changes: 7 additions & 15 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from __future__ import annotations

import ast
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

Expand All @@ -21,26 +20,25 @@
)

if TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Iterable

from libcst import Module

from .base import Error
from .base import Error, Options
from .visitors.flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst


@dataclass
class SharedState:
options: Namespace
options: Options
problems: list[Error] = field(default_factory=list)
library: tuple[str, ...] = ()
typed_calls: dict[str, str] = field(default_factory=dict)
variables: dict[str, str] = field(default_factory=dict)


class Flake8TrioRunner(ast.NodeVisitor):
def __init__(self, options: Namespace):
def __init__(self, options: Options):
super().__init__()
self.state = SharedState(options)

Expand All @@ -52,13 +50,10 @@ def __init__(self, options: Namespace):
}

def selected(self, error_codes: dict[str, str]) -> bool:
return any(
re.match(self.state.options.enable_visitor_codes_regex, code)
for code in error_codes
)
return any(code in self.state.options.enable for code in error_codes)

@classmethod
def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]:
def run(cls, tree: ast.AST, options: Options) -> Iterable[Error]:
runner = cls(options)
runner.visit(tree)
yield from runner.state.problems
Expand Down Expand Up @@ -105,7 +100,7 @@ def visit(self, node: ast.AST):


class Flake8TrioRunner_cst:
def __init__(self, options: Namespace, module: Module):
def __init__(self, options: Options, module: Module):
super().__init__()
self.state = SharedState(options)
self.options = options
Expand All @@ -129,7 +124,4 @@ def run(self) -> Iterable[Error]:
yield from self.state.problems

def selected(self, error_codes: dict[str, str]) -> bool:
return any(
re.match(self.state.options.enable_visitor_codes_regex, code)
for code in error_codes
)
return any(code in self.state.options.enable for code in error_codes)
15 changes: 9 additions & 6 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import ast
import re
from abc import ABC
from typing import TYPE_CHECKING, Any, Union

Expand Down Expand Up @@ -99,7 +98,7 @@ def error(
), "No error code defined, but class has multiple codes"
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
elif error_code in self.options.disable:
return

self.__state.problems.append(
Expand Down Expand Up @@ -211,10 +210,8 @@ def error(
), "No error code defined, but class has multiple codes"
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
elif not re.match(
self.options.enable_visitor_codes_regex, error_code
): # pragma: no cover
return
elif error_code in self.options.disable:
return # pragma: no cover
pos = self.get_metadata(PositionProvider, node).start

self.__state.problems.append(
Expand All @@ -228,6 +225,12 @@ def error(
)
)

def autofix(self, code: str | None = None):
if code is None:
assert len(self.error_codes) == 1
code = next(iter(self.error_codes))
return code in self.options.autofix

@property
def library(self) -> tuple[str, ...]:
return self.__state.library if self.__state.library else ("trio",)
Expand Down
2 changes: 1 addition & 1 deletion flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def leave_With(
for res in self.node_dict[original_node]:
self.error(res.node, res.base, res.function)

if self.options.autofix and len(updated_node.items) == 1:
if self.autofix() and len(updated_node.items) == 1:
return flatten_preserving_comments(updated_node)

return updated_node
Expand Down
5 changes: 4 additions & 1 deletion flake8_trio/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def __init__(self, *args: Any, **kwargs: Any):
self.loop_state = LoopState()
self.try_state = TryState()

def autofix(self, code: str | None = None) -> bool:
return super().autofix("TRIO911" if self.has_yield else "TRIO910")

def checkpoint_statement(self) -> cst.SimpleStatementLine:
return checkpoint_statement(self.library[0])

Expand Down Expand Up @@ -280,7 +283,7 @@ def leave_FunctionDef(
self.async_function
# updated_node does not have a position, so we must send original_node
and self.check_function_exit(original_node)
and self.options.autofix
and self.autofix()
and isinstance(updated_node.body, cst.IndentedBlock)
):
# insert checkpoint at the end of body
Expand Down
2 changes: 1 addition & 1 deletion tests/autofix_files/trio910.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def bar() -> Any:
...


# ARG --enable-visitor-codes-regex=(TRIO910)|(TRIO911)
# ARG --enable=TRIO910,TRIO911
# ARG --no-checkpoint-warning-decorator=custom_disabled_decorator


Expand Down
2 changes: 1 addition & 1 deletion tests/autofix_files/trio911.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

_: Any = ""

# ARG --enable-visitor-codes-regex=(TRIO910)|(TRIO911)
# ARG --enable=TRIO910,TRIO911


async def foo() -> Any:
Expand Down
2 changes: 1 addition & 1 deletion tests/autofix_files/trio91x_autofix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
So we make sure that import is added after it.
"""
# isort: skip_file
# ARG --enable-visitor-codes-regex=(TRIO910)|(TRIO911)
# ARG --enable=TRIO910,TRIO911

from typing import Any
import trio
Expand Down
2 changes: 1 addition & 1 deletion tests/autofix_files/trio91x_autofix.py.diff
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
+++
@@ -9,6 +9,7 @@
# ARG --enable-visitor-codes-regex=(TRIO910)|(TRIO911)
# ARG --enable=TRIO910,TRIO911

from typing import Any
+import trio
Expand Down
Loading

0 comments on commit 334c495

Please sign in to comment.