Skip to content

Commit

Permalink
add args
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Mar 22, 2023
1 parent 112b998 commit 254881e
Show file tree
Hide file tree
Showing 28 changed files with 358 additions and 171 deletions.
134 changes: 108 additions & 26 deletions flake8_trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import functools
import keyword
import os
import re
import subprocess
import sys
import tokenize
Expand All @@ -24,8 +23,9 @@

import libcst as cst

from .base import Options
from .runner import Flake8TrioRunner, Flake8TrioRunner_cst
from .visitors import default_disabled_error_codes
from .visitors import ERROR_CLASSES, ERROR_CLASSES_CST, default_disabled_error_codes

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -76,12 +76,6 @@ def cst_parse_module_native(source: str) -> cst.Module:

def main() -> int:
parser = ArgumentParser(prog="flake8_trio")
parser.add_argument(
nargs="*",
metavar="file",
dest="files",
help="Files(s) to format, instead of autodetection.",
)
Plugin.add_options(parser)
args = parser.parse_args()
Plugin.parse_options(args)
Expand Down Expand Up @@ -124,7 +118,13 @@ def main() -> int:
class Plugin:
name = __name__
version = __version__
options: Namespace = Namespace()
standalone = True
_options: Options | None = None

@property
def options(self) -> Options:
assert self._options is not None
return self._options

def __init__(self, tree: ast.AST, lines: Sequence[str]):
super().__init__()
Expand Down Expand Up @@ -158,18 +158,64 @@ def run(self) -> Iterable[Error]:
@staticmethod
def add_options(option_manager: OptionManager | ArgumentParser):
if isinstance(option_manager, ArgumentParser):
# TODO: disable TRIO9xx calls by default
# if run as standalone
Plugin.standalone = True
add_argument = option_manager.add_argument
add_argument(
nargs="*",
metavar="file",
dest="files",
help="Files(s) to format, instead of autodetection.",
)
else: # if run as a flake8 plugin
Plugin.standalone = False
# Disable TRIO9xx calls by default
option_manager.extend_default_ignore(default_disabled_error_codes)
# add parameter to parse from flake8 config
add_argument = functools.partial( # type: ignore
option_manager.add_option, parse_from_config=True
)
add_argument("--autofix", action="store_true", required=False)

add_argument(
"--enable",
type=comma_separated_list,
default="TRIO",
required=False,
help=(
"Comma-separated list of error codes to enable, similar to flake8"
" --select but is additionally more performant as it will disable"
" non-enabled visitors from running instead of just silencing their"
" errors."
),
)
add_argument(
"--disable",
type=comma_separated_list,
default="TRIO9" if Plugin.standalone else "",
required=False,
help=(
"Comma-separated list of error codes to disable, similar to flake8"
" --ignore but is additionally more performant as it will disable"
" non-enabled visitors from running instead of just silencing their"
" errors."
),
)
add_argument(
"--autofix",
type=comma_separated_list,
default="",
required=False,
help=(
"Comma-separated list of error-codes to enable autofixing for"
"if implemented. Requires running as a standalone program."
),
)
add_argument(
"--error-on-autofix",
action="store_true",
required=False,
default=False,
help="Whether to also print an error message for autofixed errors",
)
add_argument(
"--no-checkpoint-warning-decorators",
default="asynccontextmanager",
Expand Down Expand Up @@ -208,19 +254,6 @@ def add_options(option_manager: OptionManager | ArgumentParser):
"suggesting it be replaced with {value}"
),
)
add_argument(
"--enable-visitor-codes-regex",
type=re.compile, # type: ignore[arg-type]
default=".*",
required=False,
help=(
"Regex string of visitors to enable. Can be used to disable broken "
"visitors, or instead of --select/--disable to select error codes "
"in a way that is more performant. If a visitor raises multiple codes "
"it will not be disabled unless all codes are disabled, but it will "
"not report codes matching this regex."
),
)
add_argument(
"--anyio",
# action=store_true + parse_from_config does seem to work here, despite
Expand All @@ -237,7 +270,56 @@ def add_options(option_manager: OptionManager | ArgumentParser):

@staticmethod
def parse_options(options: Namespace):
Plugin.options = options
def get_matching_codes(
patterns: Iterable[str], codes: Iterable[str], msg: str
) -> Iterable[str]:
for pattern in patterns:
for code in codes:
if code.lower().startswith(pattern.lower()):
yield code

autofix_codes: set[str] = set()
enabled_codes: set[str] = set()
disabled_codes: set[str] = {
err_code
for err_class in (*ERROR_CLASSES, *ERROR_CLASSES_CST)
for err_code in err_class.error_codes.keys() # type: ignore[attr-defined]
if len(err_code) == 7 # exclude e.g. TRIO103_anyio_trio
}

if options.autofix:
if not Plugin.standalone:
print("Cannot autofix when run as a flake8 plugin.", file=sys.stderr)
sys.exit(1)
autofix_codes.update(
get_matching_codes(options.autofix, disabled_codes, "autofix")
)

# enable codes
tmp = set(get_matching_codes(options.enable, disabled_codes, "enable"))
enabled_codes |= tmp
disabled_codes -= tmp

# disable codes
tmp = set(get_matching_codes(options.disable, enabled_codes, "disable"))

# if disable has default value, don't disable explicitly enabled codes
if options.disable == ["TRIO9"]:
tmp -= {code for code in options.enable if len(code) == 7}

disabled_codes |= tmp
enabled_codes -= tmp

Plugin._options = Options(
enable=enabled_codes,
disable=disabled_codes,
autofix=autofix_codes,
error_on_autofix=options.error_on_autofix,
no_checkpoint_warning_decorators=options.no_checkpoint_warning_decorators,
startable_in_context_manager=options.startable_in_context_manager,
trio200_blocking_calls=options.trio200_blocking_calls,
anyio=options.anyio,
)


def comma_separated_list(raw_value: str) -> list[str]:
Expand Down
19 changes: 18 additions & 1 deletion flake8_trio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@

from __future__ import annotations

from typing import Any, NamedTuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple

if TYPE_CHECKING:
from collections.abc import Collection


@dataclass
class Options:
# enable and disable have been expanded to contain full codes, and have no overlap
enable: Collection[str]
disable: Collection[str]
autofix: Collection[str]
error_on_autofix: bool
no_checkpoint_warning_decorators: Collection[str]
startable_in_context_manager: Collection[str]
trio200_blocking_calls: dict[str, str]
anyio: bool


class Statement(NamedTuple):
Expand Down
43 changes: 20 additions & 23 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

import ast
import re
from abc import ABC
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

Expand All @@ -21,44 +21,48 @@
)

if TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Iterable

from libcst import Module

from .base import Error
from .base import Error, Options
from .visitors.flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst


@dataclass
class SharedState:
options: Namespace
options: Options
problems: list[Error] = field(default_factory=list)
library: tuple[str, ...] = ()
typed_calls: dict[str, str] = field(default_factory=dict)
variables: dict[str, str] = field(default_factory=dict)


class Flake8TrioRunner(ast.NodeVisitor):
def __init__(self, options: Namespace):
class CommonRunner(ABC): # noqa: B024 # no abstract methods
def __init__(self, options: Options):
super().__init__()
self.state = SharedState(options)

def selected(self, error_codes: dict[str, str]) -> bool:
for code in error_codes:
for enabled in self.state.options.enable, self.state.options.autofix:
if code in enabled:
return True
return False


class Flake8TrioRunner(ast.NodeVisitor, CommonRunner):
def __init__(self, options: Options):
super().__init__(options)
# utility visitors that need to run before the error-checking visitors
self.utility_visitors = {v(self.state) for v in utility_visitors}

self.visitors = {
v(self.state) for v in ERROR_CLASSES if self.selected(v.error_codes)
}

def selected(self, error_codes: dict[str, str]) -> bool:
return any(
re.match(self.state.options.enable_visitor_codes_regex, code)
for code in error_codes
)

@classmethod
def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]:
def run(cls, tree: ast.AST, options: Options) -> Iterable[Error]:
runner = cls(options)
runner.visit(tree)
yield from runner.state.problems
Expand Down Expand Up @@ -104,10 +108,9 @@ def visit(self, node: ast.AST):
subclass.set_state(subclass.outer.pop(node, {}))


class Flake8TrioRunner_cst:
def __init__(self, options: Namespace, module: Module):
super().__init__()
self.state = SharedState(options)
class Flake8TrioRunner_cst(CommonRunner):
def __init__(self, options: Options, module: Module):
super().__init__(options)
self.options = options

# Could possibly enable/disable utility visitors here, if visitors declared
Expand All @@ -127,9 +130,3 @@ def run(self) -> Iterable[Error]:
for v in (*self.utility_visitors, *self.visitors):
self.module = cst.MetadataWrapper(self.module).visit(v)
yield from self.state.problems

def selected(self, error_codes: dict[str, str]) -> bool:
return any(
re.match(self.state.options.enable_visitor_codes_regex, code)
for code in error_codes
)
20 changes: 11 additions & 9 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import ast
import re
from abc import ABC
from typing import TYPE_CHECKING, Any, Union

Expand Down Expand Up @@ -99,7 +98,7 @@ def error(
), "No error code defined, but class has multiple codes"
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
elif error_code in self.options.disable:
return

self.__state.problems.append(
Expand Down Expand Up @@ -190,9 +189,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
state = self.get_state(*attrs, copy=copy)
if node in self.outer:
# not currently used, and not gonna bother adding dedicated test
# visitors atm
self.outer[node].update(state) # pragma: no cover
self.outer[node].update(state)
else:
self.outer[node] = state

Expand All @@ -211,10 +208,9 @@ def error(
), "No error code defined, but class has multiple codes"
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
elif not re.match(
self.options.enable_visitor_codes_regex, error_code
): # pragma: no cover
return
# TODO: write test for only one of 910/911 enabled/autofixed
elif error_code in self.options.disable:
return # pragma: no cover
pos = self.get_metadata(PositionProvider, node).start

self.__state.problems.append(
Expand All @@ -228,6 +224,12 @@ def error(
)
)

def autofix(self, code: str | None = None):
if code is None:
assert len(self.error_codes) == 1
code = next(iter(self.error_codes))
return code in self.options.autofix

@property
def library(self) -> tuple[str, ...]:
return self.__state.library if self.__state.library else ("trio",)
Expand Down
Loading

0 comments on commit 254881e

Please sign in to comment.