From 5bfdbddc3b47c9089a9f3dc8fe096b346a309dd1 Mon Sep 17 00:00:00 2001 From: Sxderp Date: Sun, 25 Sep 2022 05:48:36 -0500 Subject: [PATCH] Fix magic method yields (__next__, __anext__, __aiter__) (#2400) * do not allow yield in __anext__ nor __next__ __anext__: This is a runtime constraint. __anext__ cannot be a generator. This will cause the runtime to throw a TypeError when used in an async for loop. __next__: The usefulness of having __next__ be a generator is practically nil. Doing so will cause an infinite loop with a new generator object passed to each iteration. * fix handling of __aiter__ and yield __aiter__ can only contain yield if it is an async function (async generator). Otherwise it must be a sync function that does not contain yield and returns an object that implements __anext__. * update violation class documentation * remove unnecessary parenthesis * add __anext__ and __next__ to invalid yields test * add new tests for magic methods that depend on async + yield * add missing argument when creating new tests * Make tests pass * Make tests pass Co-authored-by: sobolevn --- CHANGELOG.md | 2 + .../test_async_yield_magic_methods.py | 148 ++++++++++++++++++ .../test_methods/test_yield_magic_method.py | 7 +- .../test_function/test_asserts_count.py | 6 +- wemake_python_styleguide/constants.py | 5 +- wemake_python_styleguide/violations/oop.py | 4 +- .../visitors/ast/classes.py | 109 +++++++------ 7 files changed, 226 insertions(+), 55 deletions(-) create mode 100644 tests/test_visitors/test_ast/test_classes/test_methods/test_async_yield_magic_methods.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f1b69f52..6feb83596 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ Semantic versioning in our case means: - Fixes `WPS226` false positives on `|` use in `SomeType | AnotherType` type hints syntax - Now `-1` is not reported to be an overused expression +- Allow `__aiter__` to be async iterator +- Adds violation method name to error message of `YieldMagicMethodViolation` ### Misc diff --git a/tests/test_visitors/test_ast/test_classes/test_methods/test_async_yield_magic_methods.py b/tests/test_visitors/test_ast/test_classes/test_methods/test_async_yield_magic_methods.py new file mode 100644 index 000000000..e375df7cd --- /dev/null +++ b/tests/test_visitors/test_ast/test_classes/test_methods/test_async_yield_magic_methods.py @@ -0,0 +1,148 @@ +import pytest + +from wemake_python_styleguide.violations.oop import AsyncMagicMethodViolation +from wemake_python_styleguide.visitors.ast.classes import WrongMethodVisitor + +sync_method = """ +class Example(object): + def {0}(self, *args, **kwargs): + {1} +""" + +async_method = """ +class Example(object): + async def {0}(self, *args, **kwargs): + {1} +""" + + +@pytest.mark.parametrize('template', [ + sync_method, + async_method, +]) +@pytest.mark.parametrize('method', [ + '__aiter__', +]) +@pytest.mark.parametrize('statement', [ + 'yield', + 'yield 1', +]) +def test_yield_is_always_allowed_in_aiter( + assert_errors, + parse_ast_tree, + default_options, + template, + method, + statement, +): + """Testing that the `__aiter__` can always have `yield`.""" + tree = parse_ast_tree(template.format(method, statement)) + + visitor = WrongMethodVisitor(default_options, tree=tree) + visitor.run() + + assert_errors(visitor, []) + + +@pytest.mark.parametrize('method', [ + '__aiter__', +]) +@pytest.mark.parametrize('statement', [ + 'return some_async_iterator()', +]) +def test_wrong_async_magic_used( + assert_errors, + assert_error_text, + parse_ast_tree, + default_options, + method, + statement, +): + """Testing that the method cannot be a coroutine.""" + tree = parse_ast_tree(async_method.format(method, statement)) + + visitor = WrongMethodVisitor(default_options, tree=tree) + visitor.run() + + assert_errors(visitor, [AsyncMagicMethodViolation]) + assert_error_text(visitor, method) + + +@pytest.mark.parametrize('method', [ + '__aiter__', +]) +@pytest.mark.parametrize('statement', [ + 'yield', + 'yield 1', +]) +def test_correct_async_yield_magic_used( + assert_errors, + parse_ast_tree, + default_options, + method, + statement, +): + """Testing that the method can be an async generator.""" + tree = parse_ast_tree(async_method.format(method, statement)) + + visitor = WrongMethodVisitor(default_options, tree=tree) + visitor.run() + + assert_errors(visitor, []) + + +@pytest.mark.parametrize('method', [ + '__aiter__', +]) +@pytest.mark.parametrize('statement', [ + 'return some_async_iterator()', +]) +def test_correct_sync_magic_used( + assert_errors, + parse_ast_tree, + default_options, + method, + statement, +): + """Testing that the method can be a normal method.""" + tree = parse_ast_tree(sync_method.format(method, statement)) + + visitor = WrongMethodVisitor(default_options, tree=tree) + visitor.run() + + assert_errors(visitor, []) + + +# Examples: + +correct_nested_example = """ +class Some: + {0}def __aiter__(self): + async def inner(): + yield 1 + return inner() +""" + + +@pytest.mark.parametrize('example', [ + correct_nested_example, +]) +@pytest.mark.parametrize('mode', [ + # We don't use `mode()` fixture here, because we have a nested func. + '', # sync + 'async ', +]) +def test_correct_examples( + assert_errors, + parse_ast_tree, + default_options, + example, + mode, +): + """Testing specific real-life examples that should be working.""" + tree = parse_ast_tree(example.format(mode)) + + visitor = WrongMethodVisitor(default_options, tree=tree) + visitor.run() + + assert_errors(visitor, []) diff --git a/tests/test_visitors/test_ast/test_classes/test_methods/test_yield_magic_method.py b/tests/test_visitors/test_ast/test_classes/test_methods/test_yield_magic_method.py index 9d3f1a13f..b6d212013 100644 --- a/tests/test_visitors/test_ast/test_classes/test_methods/test_yield_magic_method.py +++ b/tests/test_visitors/test_ast/test_classes/test_methods/test_yield_magic_method.py @@ -27,6 +27,8 @@ def {0}(cls, *args, **kwargs): '__str__', '__aenter__', '__exit__', + '__anext__', + '__next__', ]) @pytest.mark.parametrize('statement', [ 'yield', @@ -35,6 +37,7 @@ def {0}(cls, *args, **kwargs): ]) def test_magic_generator( assert_errors, + assert_error_text, parse_ast_tree, default_options, code, @@ -48,6 +51,7 @@ def test_magic_generator( visitor.run() assert_errors(visitor, [YieldMagicMethodViolation]) + assert_error_text(visitor, method) @pytest.mark.parametrize('code', [ @@ -90,6 +94,7 @@ def test_magic_statement( ]) @pytest.mark.parametrize('method', [ '__iter__', + '__aiter__', '__call__', '__custom__', ]) @@ -106,7 +111,7 @@ def test_iter_generator( method, statement, ): - """Testing that magic `iter` and `call` methods with `yield` are allowed.""" + """Testing that some magic methods with `yield` are allowed.""" tree = parse_ast_tree(code.format(method, statement)) visitor = WrongMethodVisitor(default_options, tree=tree) diff --git a/tests/test_visitors/test_ast/test_complexity/test_function/test_asserts_count.py b/tests/test_visitors/test_ast/test_complexity/test_function/test_asserts_count.py index 9592b45f1..c17e535b3 100644 --- a/tests/test_visitors/test_ast/test_complexity/test_function/test_asserts_count.py +++ b/tests/test_visitors/test_ast/test_complexity/test_function/test_asserts_count.py @@ -33,10 +33,11 @@ def test_asserts_correct_count( assert_errors, parse_ast_tree, code, + mode, default_options, ): """Testing that asserts counted correctly.""" - tree = parse_ast_tree(code) + tree = parse_ast_tree(mode(code)) visitor = FunctionComplexityVisitor(default_options, tree=tree) visitor.run() @@ -54,9 +55,10 @@ def test_asserts_wrong_count( parse_ast_tree, options, code, + mode, ): """Testing that many asserts raises a warning.""" - tree = parse_ast_tree(code) + tree = parse_ast_tree(mode(code)) option_values = options(max_asserts=1) visitor = FunctionComplexityVisitor(option_values, tree=tree) diff --git a/wemake_python_styleguide/constants.py b/wemake_python_styleguide/constants.py index 60d0cce87..7b788908f 100644 --- a/wemake_python_styleguide/constants.py +++ b/wemake_python_styleguide/constants.py @@ -325,19 +325,18 @@ # Allowed to be used with ``yield`` keyword: '__call__', '__iter__', - '__anext__', '__aiter__', - '__next__', }) #: List of magic methods that are not allowed to be async. ASYNC_MAGIC_METHODS_BLACKLIST: Final = ALL_MAGIC_METHODS.difference({ # In order of appearance on # https://docs.python.org/3/reference/datamodel.html#basic-customization - # Allowed magic methods are: + # Allowed async magic methods are: '__anext__', '__aenter__', '__aexit__', + '__aiter__', '__call__', }) diff --git a/wemake_python_styleguide/violations/oop.py b/wemake_python_styleguide/violations/oop.py index 817c908ad..071318225 100644 --- a/wemake_python_styleguide/violations/oop.py +++ b/wemake_python_styleguide/violations/oop.py @@ -445,6 +445,7 @@ class AsyncMagicMethodViolation(ASTViolation): Forbid certain async magic methods. We allow to make ``__anext__``, ``__aenter__``, ``__aexit__`` async. + We allow to make ``__aiter__`` async if it is a generator (contains yield). We also allow custom magic methods to be async. See @@ -485,6 +486,7 @@ class YieldMagicMethodViolation(ASTViolation): Forbid ``yield`` inside of certain magic methods. We allow to make ``__iter__`` a generator. + We allow to make ``__aiter__`` an async generator. See :py:data:`~wemake_python_styleguide.constants.YIELD_MAGIC_METHODS_BLACKLIST` for the whole list of blacklisted generator magic methods. @@ -519,7 +521,7 @@ def __init__(self): """ - error_template = 'Found forbidden `yield` magic method usage' + error_template = 'Found forbidden `yield` magic method usage: {0}' code = 611 previous_codes = {439, 435} diff --git a/wemake_python_styleguide/visitors/ast/classes.py b/wemake_python_styleguide/visitors/ast/classes.py index 4ff53cab1..31c252922 100644 --- a/wemake_python_styleguide/visitors/ast/classes.py +++ b/wemake_python_styleguide/visitors/ast/classes.py @@ -157,12 +157,22 @@ class WrongMethodVisitor(base.BaseNodeVisitor): _staticmethod_names: ClassVar[FrozenSet[str]] = frozenset(( 'staticmethod', )) + _special_async_iter: ClassVar[FrozenSet[str]] = frozenset(( + '__aiter__', + )) def visit_any_function(self, node: types.AnyFunctionDef) -> None: """Checking class methods: async and regular.""" - self._check_decorators(node) - self._check_bound_methods(node) - self._check_method_contents(node) + node_context = nodes.get_context(node) + if isinstance(node_context, ast.ClassDef): + self._check_decorators(node) + self._check_bound_methods(node) + self._check_yield_magic_methods(node) + self._check_async_magic_methods(node) + self._check_useless_overwritten_methods( + node, + class_name=node_context.name, + ) self.generic_visit(node) def _check_decorators(self, node: types.AnyFunctionDef) -> None: @@ -172,10 +182,6 @@ def _check_decorators(self, node: types.AnyFunctionDef) -> None: self.add_violation(oop.StaticMethodViolation(node)) def _check_bound_methods(self, node: types.AnyFunctionDef) -> None: - node_context = nodes.get_context(node) - if not isinstance(node_context, ast.ClassDef): - return - if not functions.get_all_arguments(node): self.add_violation( oop.MethodWithoutArgumentsViolation(node, text=node.name), @@ -186,22 +192,59 @@ def _check_bound_methods(self, node: types.AnyFunctionDef) -> None: oop.BadMagicMethodViolation(node, text=node.name), ) - is_async = isinstance(node, ast.AsyncFunctionDef) - if is_async and access.is_magic(node.name): - if node.name in constants.ASYNC_MAGIC_METHODS_BLACKLIST: + def _check_yield_magic_methods(self, node: types.AnyFunctionDef) -> None: + if isinstance(node, ast.AsyncFunctionDef): + return + + if node.name in constants.YIELD_MAGIC_METHODS_BLACKLIST: + if walk.is_contained(node, (ast.Yield, ast.YieldFrom)): + self.add_violation( + oop.YieldMagicMethodViolation(node, text=node.name), + ) + + def _check_async_magic_methods(self, node: types.AnyFunctionDef) -> None: + if not isinstance(node, ast.AsyncFunctionDef): + return + + if node.name in self._special_async_iter: + if not walk.is_contained(node, ast.Yield): # YieldFrom not async self.add_violation( oop.AsyncMagicMethodViolation(node, text=node.name), ) + elif node.name in constants.ASYNC_MAGIC_METHODS_BLACKLIST: + self.add_violation( + oop.AsyncMagicMethodViolation(node, text=node.name), + ) - self._check_useless_overwritten_methods( - node, - class_name=node_context.name, - ) + def _check_useless_overwritten_methods( + self, + node: types.AnyFunctionDef, + class_name: str, + ) -> None: + if node.decorator_list: + # Any decorator can change logic and make this overwrite useful. + return - def _check_method_contents(self, node: types.AnyFunctionDef) -> None: - if node.name in constants.YIELD_MAGIC_METHODS_BLACKLIST: - if walk.is_contained(node, (ast.Yield, ast.YieldFrom)): - self.add_violation(oop.YieldMagicMethodViolation(node)) + call_stmt = self._get_call_stmt_of_useless_method(node) + if call_stmt is None or not isinstance(call_stmt.func, ast.Attribute): + return + + attribute = call_stmt.func + defined_method_name = node.name + if defined_method_name != attribute.attr: + return + + if not super_args.is_ordinary_super_call(attribute.value, class_name): + return + + if not function_args.is_call_matched_by_arguments(node, call_stmt): + return + + self.add_violation( + oop.UselessOverwrittenMethodViolation( + node, text=defined_method_name, + ), + ) def _get_call_stmt_of_useless_method( self, @@ -236,36 +279,6 @@ def _get_call_stmt_of_useless_method( return stmt.value return None - def _check_useless_overwritten_methods( - self, - node: types.AnyFunctionDef, - class_name: str, - ) -> None: - if node.decorator_list: - # Any decorator can change logic and make this overwrite useful. - return - - call_stmt = self._get_call_stmt_of_useless_method(node) - if call_stmt is None or not isinstance(call_stmt.func, ast.Attribute): - return - - attribute = call_stmt.func - defined_method_name = node.name - if defined_method_name != attribute.attr: - return - - if not super_args.is_ordinary_super_call(attribute.value, class_name): - return - - if not function_args.is_call_matched_by_arguments(node, call_stmt): - return - - self.add_violation( - oop.UselessOverwrittenMethodViolation( - node, text=defined_method_name, - ), - ) - @final @decorators.alias('visit_any_assign', (