From 3611d5e554c44ad55be5616e010574034df3af25 Mon Sep 17 00:00:00 2001 From: Eric Arellano Date: Tue, 11 Aug 2020 14:20:37 -0700 Subject: [PATCH] Improve errors for invalid `Get` objects # Rust tests and lints will be skipped. Delete if not intended. [ci skip-rust] # Building wheels and fs_util will be skipped. Delete if not intended. [ci skip-build-wheels] --- .../pants/engine/internals/selectors.py | 81 ++++++++++++- .../pants/engine/internals/selectors_test.py | 81 +++++++++++-- src/python/pants/engine/rules.py | 107 +++--------------- src/python/pants/engine/rules_test.py | 14 ++- 4 files changed, 175 insertions(+), 108 deletions(-) diff --git a/src/python/pants/engine/internals/selectors.py b/src/python/pants/engine/internals/selectors.py index 6f8238fbedc..12620115373 100644 --- a/src/python/pants/engine/internals/selectors.py +++ b/src/python/pants/engine/internals/selectors.py @@ -1,8 +1,10 @@ # Copyright 2015 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import ast import itertools from dataclasses import dataclass +from functools import partial from textwrap import dedent from typing import ( Any, @@ -10,6 +12,7 @@ Generic, Iterable, Optional, + Sequence, Tuple, Type, TypeVar, @@ -26,11 +29,83 @@ _SubjectType = TypeVar("_SubjectType") +class GetParseError(ValueError): + def __init__( + self, explanation: str, *, get_args: Sequence[ast.expr], source_file_name: str + ) -> None: + def render_arg(expr: ast.expr) -> str: + if isinstance(expr, ast.Name): + return expr.id + if isinstance(expr, ast.Call): + # Check if it's a top-level function call. + if hasattr(expr.func, "id"): + return f"{expr.func.id}()" # type: ignore[attr-defined] + # Check if it's a method call. + if hasattr(expr.func, "attr") and hasattr(expr.func, "value"): + return f"{expr.func.value.id}.{expr.func.attr}()" # type: ignore[attr-defined] + + # Fall back to the name of the ast node's class. + return str(type(expr)) + + rendered_args = ", ".join(render_arg(arg) for arg in get_args) + # TODO: Add the line numbers for the `Get`. The number for `get_args[0].lineno` are + # off because they're relative to the wrapping rule. + super().__init__( + f"Invalid Get. {explanation} Failed for Get({rendered_args}) in {source_file_name}." + ) + + @frozen_after_init @dataclass(unsafe_hash=True) -class GetConstraints(Generic[_ProductType, _SubjectType]): - product_type: Type[_ProductType] - subject_declared_type: Type[_SubjectType] +class GetConstraints: + product_type: Type + subject_declared_type: Type + + @classmethod + def parse_product_and_subject_types( + cls, get_args: Sequence[ast.expr], *, source_file_name: str + ) -> Tuple[str, str]: + parse_error = partial(GetParseError, get_args=get_args, source_file_name=source_file_name) + + if len(get_args) not in (2, 3): + raise parse_error( + f"Expected either two or three arguments, but got {len(get_args)} arguments." + ) + + product_expr = get_args[0] + if not isinstance(product_expr, ast.Name): + raise parse_error( + "The first argument should be the type of the product, like `Digest` or " + "`ProcessResult`." + ) + product_type = product_expr.id + + subject_args = get_args[1:] + if len(subject_args) == 1: + subject_constructor = subject_args[0] + if not isinstance(subject_constructor, ast.Call): + raise parse_error( + "Because you are using the shorthand form Get(ProductType, " + "SubjectType(constructor args), the second argument should be a constructor " + "call, like `MergeDigest(...)` or `Process(...)`." + ) + if not hasattr(subject_constructor.func, "id"): + raise parse_error( + "Because you are using the shorthand form Get(ProductType, " + "SubjectType(constructor args), the second argument should be a top-level " + "constructor function call, like `MergeDigest(...)` or `Process(...)`, rather " + "than a method call." + ) + return product_type, subject_constructor.func.id # type: ignore[attr-defined] + + subject_type, _ = subject_args + if not isinstance(subject_type, ast.Name): + raise parse_error( + "Because you are using the longhand form Get(ProductType, SubjectType, " + "subject_instance), the second argument should be a type, like `MergeDigests` or " + "`Process`." + ) + return product_type, subject_type.id @frozen_after_init diff --git a/src/python/pants/engine/internals/selectors_test.py b/src/python/pants/engine/internals/selectors_test.py index 108f022f15d..6a073cf0720 100644 --- a/src/python/pants/engine/internals/selectors_test.py +++ b/src/python/pants/engine/internals/selectors_test.py @@ -1,12 +1,77 @@ # Copyright 2016 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import ast import re -from typing import Any +from typing import Any, Tuple import pytest -from pants.engine.internals.selectors import Get, MultiGet +from pants.engine.internals.selectors import Get, GetConstraints, GetParseError, MultiGet + + +def parse_get_types(get: str) -> Tuple[str, str]: + get_args = ast.parse(get).body[0].value.args # type: ignore[attr-defined] + return GetConstraints.parse_product_and_subject_types(get_args, source_file_name="test.py") + + +def test_parse_get_types_valid() -> None: + assert parse_get_types("Get(P, S, subject)") == ("P", "S") + assert parse_get_types("Get(P, S())") == ("P", "S") + + +def assert_parse_get_types_fails(get: str, *, expected_explanation: str) -> None: + with pytest.raises(GetParseError) as exc: + parse_get_types(get) + assert str(exc.value) == f"Invalid Get. {expected_explanation} Failed for {get} in test.py." + + +def test_parse_get_types_wrong_number_args() -> None: + assert_parse_get_types_fails( + "Get()", + expected_explanation="Expected either two or three arguments, but got 0 arguments.", + ) + assert_parse_get_types_fails( + "Get(P, S1, S2(), S3.create())", + expected_explanation="Expected either two or three arguments, but got 4 arguments.", + ) + + +def test_parse_get_types_invalid_product() -> None: + assert_parse_get_types_fails( + "Get(P(), S, subject)", + expected_explanation=( + "The first argument should be the type of the product, like `Digest` or " + "`ProcessResult`." + ), + ) + + +def test_parse_get_types_invalid_subject() -> None: + assert_parse_get_types_fails( + "Get(P, S)", + expected_explanation=( + "Because you are using the shorthand form Get(ProductType, SubjectType(constructor " + "args), the second argument should be a constructor call, like `MergeDigest(...)` or " + "`Process(...)`." + ), + ) + assert_parse_get_types_fails( + "Get(P, Subject.create())", + expected_explanation=( + "Because you are using the shorthand form Get(ProductType, SubjectType(constructor " + "args), the second argument should be a top-level constructor function call, like " + "`MergeDigest(...)` or `Process(...)`, rather than a method call." + ), + ) + assert_parse_get_types_fails( + "Get(P, Subject(), subject)", + expected_explanation=( + "Because you are using the longhand form Get(ProductType, SubjectType, " + "subject_instance), the second argument should be a type, like `MergeDigests` or " + "`Process`." + ), + ) class AClass: @@ -18,19 +83,19 @@ def __eq__(self, other: Any): return type(self) == type(other) -def test_create() -> None: +def test_create_get() -> None: get = Get(AClass, BClass, 42) assert get.product_type is AClass assert get.subject_declared_type is BClass assert get.subject == 42 -def test_create_abbreviated() -> None: +def test_create_get_abbreviated() -> None: # Test the equivalence of the 1-arg and 2-arg versions. assert Get(AClass, BClass()) == Get(AClass, BClass, BClass()) -def test_invalid_abbreviated() -> None: +def test_invalid_get_abbreviated() -> None: with pytest.raises( expected_exception=TypeError, match=re.escape(f"The subject argument cannot be a type, given {BClass}."), @@ -38,7 +103,7 @@ def test_invalid_abbreviated() -> None: Get(AClass, BClass) -def test_invalid_subject() -> None: +def test_invalid_get_subject() -> None: with pytest.raises( expected_exception=TypeError, match=re.escape(f"The subject argument cannot be a type, given {BClass}."), @@ -46,7 +111,7 @@ def test_invalid_subject() -> None: Get(AClass, BClass, BClass) -def test_invalid_subject_declared_type() -> None: +def test_invalid_get_subject_declared_type() -> None: with pytest.raises( expected_exception=TypeError, match=re.escape( @@ -56,7 +121,7 @@ def test_invalid_subject_declared_type() -> None: Get(AClass, 1, BClass) # type: ignore[call-overload] -def test_invalid_product_type() -> None: +def test_invalid_get_product_type() -> None: with pytest.raises( expected_exception=TypeError, match=re.escape(f"The product type argument must be a type, given {1} of type {type(1)}."), diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index ffe535bf799..923d414ebea 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -46,104 +46,29 @@ def side_effecting(cls): class _RuleVisitor(ast.NodeVisitor): """Pull `Get` calls out of an @rule body.""" - def __init__( - self, *, resolve_type: Callable[[str], Type[Any]], source_file: Optional[str] = None - ) -> None: + def __init__(self, *, resolve_type: Callable[[str], Type[Any]], source_file_name: str) -> None: super().__init__() - self._source_file = source_file or "" - self._resolve_type = resolve_type - self._gets: List[GetConstraints] = [] - - @property - def gets(self) -> List[GetConstraints]: - return self._gets - - @frozen_after_init - @dataclass(unsafe_hash=True) - class _GetDescriptor: - product_type_name: str - subject_arg_exprs: Tuple[ast.expr, ...] - - def __init__( - self, product_type_expr: ast.expr, subject_arg_exprs: Iterable[ast.expr] - ) -> None: - if not isinstance(product_type_expr, ast.Name): - raise ValueError( - f"Unrecognized type argument T for Get[T]: " f"{ast.dump(product_type_expr)}" - ) - self.product_type_name = product_type_expr.id - self.subject_arg_exprs = tuple(subject_arg_exprs) - - def _identify_source(self, node: Union[ast.expr, ast.stmt]) -> str: - start_pos = f"{node.lineno}:{node.col_offset}" + self.source_file_name = source_file_name + self.resolve_type = resolve_type + self.gets: List[GetConstraints] = [] - end_lineno, end_col_offset = [ - getattr(node, attr, None) for attr in ("end_lineno", "end_col_offset") - ] - end_pos = f"-{end_lineno}:{end_col_offset}" if end_lineno and end_col_offset else "" - - return f"{self._source_file} at {start_pos}{end_pos}" - - def _extract_get_descriptor(self, call_node: ast.Call) -> Optional[_GetDescriptor]: + @staticmethod + def maybe_extract_get_args(call_node: ast.Call) -> Optional[List[ast.expr]]: """Check if the node looks like a Get(T, ...) call.""" if not isinstance(call_node.func, ast.Name): return None if call_node.func.id != "Get": return None - return self._GetDescriptor( - product_type_expr=call_node.args[0], subject_arg_exprs=call_node.args[1:] - ) - - def _extract_constraints(self, get_descriptor: _GetDescriptor) -> GetConstraints[Any, Any]: - """Parses a `Get(T, ...)` call in one of its two legal forms to return its type constraints. - - :param get_descriptor: An `ast.Call` node representing a call to `Get(T, ...)`. - :return: A tuple of product type id and subject type id. - """ - - def render_args(): - rendered_args = ", ".join( - # Dump the Name's id to simplify output when available, falling back to the name of - # the node's class. - getattr(subject_arg, "id", type(subject_arg).__name__) - for subject_arg in get_descriptor.subject_arg_exprs - ) - return f"Get({get_descriptor.product_type_name}, {rendered_args})" - - if not 1 <= len(get_descriptor.subject_arg_exprs) <= 2: - raise ValueError( - f"Invalid Get. Expected either one or two args, but got: {render_args()}" - ) - - product_type = self._resolve_type(get_descriptor.product_type_name) - - if len(get_descriptor.subject_arg_exprs) == 1: - subject_constructor = get_descriptor.subject_arg_exprs[0] - if not isinstance(subject_constructor, ast.Call): - raise ValueError( - f"Expected Get(product_type, subject_type(subject)), but got: {render_args()}" - ) - constructor_type_id = subject_constructor.func.id # type: ignore[attr-defined] - return GetConstraints[Any, Any]( - product_type=product_type, - subject_declared_type=self._resolve_type(constructor_type_id), - ) - - subject_declared_type, _ = get_descriptor.subject_arg_exprs - if not isinstance(subject_declared_type, ast.Name): - raise ValueError( - f"Expected Get(product_type, subject_declared_type, subject), but got: " - f"{render_args()}" - ) - return GetConstraints[Any, Any]( - product_type=product_type, - subject_declared_type=self._resolve_type(subject_declared_type.id), - ) + return call_node.args def visit_Call(self, call_node: ast.Call) -> None: - get_descriptor = self._extract_get_descriptor(call_node) - if get_descriptor: - self._gets.append(self._extract_constraints(get_descriptor)) + get_args = self.maybe_extract_get_args(call_node) + if get_args is not None: + product_str, subject_str = GetConstraints.parse_product_and_subject_types( + get_args, source_file_name=self.source_file_name + ) + get = GetConstraints(self.resolve_type(product_str), self.resolve_type(subject_str)) + self.gets.append(get) # Ensure we descend into e.g. MultiGet(Get(...)...) calls. self.generic_visit(call_node) @@ -210,7 +135,7 @@ def wrapper(func): raise ValueError("The @rule decorator must be applied innermost of all decorators.") owning_module = sys.modules[func.__module__] - source = inspect.getsource(func) + source = inspect.getsource(func) or "" source_file = inspect.getsourcefile(func) beginning_indent = _get_starting_indent(source) if beginning_indent: @@ -245,7 +170,7 @@ def resolve_type(name): for child in ast.iter_child_nodes(parent): parents_table[child] = parent - rule_visitor = _RuleVisitor(source_file=source_file, resolve_type=resolve_type) + rule_visitor = _RuleVisitor(source_file_name=source_file, resolve_type=resolve_type) rule_visitor.visit(rule_func_node) gets = FrozenOrderedSet(rule_visitor.gets) diff --git a/src/python/pants/engine/rules_test.py b/src/python/pants/engine/rules_test.py index cb0c100fcd0..b255a6bae53 100644 --- a/src/python/pants/engine/rules_test.py +++ b/src/python/pants/engine/rules_test.py @@ -16,7 +16,7 @@ from pants.engine.goal import Goal, GoalSubsystem from pants.engine.internals.native import Native from pants.engine.internals.scheduler import Scheduler -from pants.engine.internals.selectors import GetConstraints +from pants.engine.internals.selectors import GetConstraints, GetParseError from pants.engine.rules import ( Get, MissingParameterTypeAnnotation, @@ -57,7 +57,9 @@ def create_scheduler(rules, validate=True, native=None): class RuleVisitorTest(unittest.TestCase): @staticmethod def _parse_rule_gets(rule_text: str, **types: Type) -> List[GetConstraints]: - rule_visitor = _RuleVisitor(resolve_type=lambda name: types[name]) + rule_visitor = _RuleVisitor( + resolve_type=lambda name: types[name], source_file_name="parse_rules.py" + ) rule_visitor.visit(ast.parse(rule_text)) return rule_visitor.gets @@ -153,19 +155,19 @@ def test_valid_get_unresolvable_subject_declared_type(self) -> None: self._parse_rule_gets("Get(int, DNE, 'bob')") def test_invalid_get_no_subject_args(self) -> None: - with pytest.raises(ValueError): + with pytest.raises(GetParseError): self._parse_rule_gets("Get(A, )", A=int) def test_invalid_get_too_many_subject_args(self) -> None: - with pytest.raises(ValueError): + with pytest.raises(GetParseError): self._parse_rule_gets("Get(A, B, 'bob', 3)", A=int, B=str) def test_invalid_get_invalid_subject_arg_no_constructor_call(self) -> None: - with pytest.raises(ValueError): + with pytest.raises(GetParseError): self._parse_rule_gets("Get(A, 'bob')", A=int) def test_invalid_get_invalid_product_type_not_a_type_name(self) -> None: - with pytest.raises(ValueError): + with pytest.raises(GetParseError): self._parse_rule_gets("Get(call(), A('bob'))", A=str)