Skip to content

Commit

Permalink
Improve errors for invalid Get objects
Browse files Browse the repository at this point in the history
# Rust tests and lints will be skipped. Delete if not intended.
[ci skip-rust]

# Building wheels and fs_util will be skipped. Delete if not intended.
[ci skip-build-wheels]
  • Loading branch information
Eric-Arellano committed Aug 11, 2020
1 parent b10cca1 commit 3611d5e
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 108 deletions.
81 changes: 78 additions & 3 deletions src/python/pants/engine/internals/selectors.py
@@ -1,15 +1,18 @@
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import ast
import itertools
from dataclasses import dataclass
from functools import partial
from textwrap import dedent
from typing import (
Any,
Generator,
Generic,
Iterable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Expand All @@ -26,11 +29,83 @@
_SubjectType = TypeVar("_SubjectType")


class GetParseError(ValueError):
    """Raised when the arguments of a `Get(...)` call cannot be parsed.

    The message has the form
    `Invalid Get. {explanation} Failed for Get({rendered args}) in {source_file_name}.`
    """

    def __init__(
        self, explanation: str, *, get_args: Sequence[ast.expr], source_file_name: str
    ) -> None:
        """Build the error message from the offending `Get` arguments.

        :param explanation: Sentence describing what is wrong; inserted verbatim into the message.
        :param get_args: AST nodes for the positional arguments of the `Get` call.
        :param source_file_name: Name of the file containing the `Get`; inserted into the message.
        """

        def dotted_name(node: ast.AST) -> Optional[str]:
            # Render `a` or `a.b.c`. Anything else (a call result, a subscript, ...) has no
            # simple name, so report None and let the caller pick a fallback.
            if isinstance(node, ast.Name):
                return node.id
            if isinstance(node, ast.Attribute):
                prefix = dotted_name(node.value)
                if prefix is not None:
                    return f"{prefix}.{node.attr}"
            return None

        def render_arg(expr: ast.expr) -> str:
            is_call = isinstance(expr, ast.Call)
            # For a call, name the callee: `Foo()` for a top-level function call and
            # `Foo.create()` for a method call.
            name = dotted_name(expr.func if is_call else expr)  # type: ignore[attr-defined]
            if name is not None:
                return f"{name}()" if is_call else name
            # Fall back to the name of the ast node's class, e.g. `List`. Rendering must never
            # raise, or it would mask the parse error being reported.
            return type(expr).__name__

        rendered_args = ", ".join(render_arg(arg) for arg in get_args)
        # TODO: Add the line numbers for the `Get`. The numbers from `get_args[0].lineno` are
        # off because they're relative to the wrapping rule.
        super().__init__(
            f"Invalid Get. {explanation} Failed for Get({rendered_args}) in {source_file_name}."
        )


@frozen_after_init
@dataclass(unsafe_hash=True)
class GetConstraints(Generic[_ProductType, _SubjectType]):
product_type: Type[_ProductType]
subject_declared_type: Type[_SubjectType]
class GetConstraints:
product_type: Type
subject_declared_type: Type

@classmethod
def parse_product_and_subject_types(
    cls, get_args: Sequence[ast.expr], *, source_file_name: str
) -> Tuple[str, str]:
    """Extract the product and subject type names from the arguments of a `Get(...)` call.

    Two forms are accepted: the shorthand `Get(ProductType, SubjectType(...))` and the
    longhand `Get(ProductType, SubjectType, subject_instance)`.

    :param get_args: AST nodes for the positional arguments of the `Get` call.
    :param source_file_name: Name of the file containing the `Get`; used in error messages.
    :return: A tuple of product type name and subject type name.
    :raises GetParseError: If the arguments match neither form.
    """

    def fail(explanation: str) -> GetParseError:
        return GetParseError(explanation, get_args=get_args, source_file_name=source_file_name)

    arg_count = len(get_args)
    if not 2 <= arg_count <= 3:
        raise fail(f"Expected either two or three arguments, but got {len(get_args)} arguments.")

    product_node = get_args[0]
    if not isinstance(product_node, ast.Name):
        raise fail(
            "The first argument should be the type of the product, like `Digest` or "
            "`ProcessResult`."
        )
    product_name = product_node.id

    if arg_count == 3:
        # Longhand form: the subject type is named directly by the second argument.
        declared_subject = get_args[1]
        if not isinstance(declared_subject, ast.Name):
            raise fail(
                "Because you are using the longhand form Get(ProductType, SubjectType, "
                "subject_instance), the second argument should be a type, like `MergeDigests` or "
                "`Process`."
            )
        return product_name, declared_subject.id

    # Shorthand form: the subject type is the callee of the constructor call.
    constructor_call = get_args[1]
    if not isinstance(constructor_call, ast.Call):
        raise fail(
            "Because you are using the shorthand form Get(ProductType, "
            "SubjectType(constructor args), the second argument should be a constructor "
            "call, like `MergeDigest(...)` or `Process(...)`."
        )
    callee = constructor_call.func
    if not hasattr(callee, "id"):
        raise fail(
            "Because you are using the shorthand form Get(ProductType, "
            "SubjectType(constructor args), the second argument should be a top-level "
            "constructor function call, like `MergeDigest(...)` or `Process(...)`, rather "
            "than a method call."
        )
    return product_name, callee.id  # type: ignore[attr-defined]


@frozen_after_init
Expand Down
81 changes: 73 additions & 8 deletions src/python/pants/engine/internals/selectors_test.py
@@ -1,12 +1,77 @@
# Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import ast
import re
from typing import Any
from typing import Any, Tuple

import pytest

from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.internals.selectors import Get, GetConstraints, GetParseError, MultiGet


def parse_get_types(get: str) -> Tuple[str, str]:
    """Parse the source text of one `Get(...)` call into its product and subject type names."""
    call_statement = ast.parse(get).body[0]
    positional_args = call_statement.value.args  # type: ignore[attr-defined]
    return GetConstraints.parse_product_and_subject_types(
        positional_args, source_file_name="test.py"
    )


def test_parse_get_types_valid() -> None:
    """The longhand and shorthand forms both yield the same pair of type names."""
    for get in ("Get(P, S, subject)", "Get(P, S())"):
        assert parse_get_types(get) == ("P", "S")


def assert_parse_get_types_fails(get: str, *, expected_explanation: str) -> None:
    """Assert that parsing `get` raises a `GetParseError` with exactly the expected message."""
    expected_message = f"Invalid Get. {expected_explanation} Failed for {get} in test.py."
    with pytest.raises(GetParseError) as exc:
        parse_get_types(get)
    # Checked after the `with` block: a statement following the raising call inside the
    # block would never execute.
    assert str(exc.value) == expected_message


def test_parse_get_types_wrong_number_args() -> None:
    """Too few and too many arguments are both rejected, reporting the actual count."""
    cases = [
        ("Get()", "Expected either two or three arguments, but got 0 arguments."),
        (
            "Get(P, S1, S2(), S3.create())",
            "Expected either two or three arguments, but got 4 arguments.",
        ),
    ]
    for get, explanation in cases:
        assert_parse_get_types_fails(get, expected_explanation=explanation)


def test_parse_get_types_invalid_product() -> None:
    """A constructor call in the product position is rejected."""
    explanation = (
        "The first argument should be the type of the product, like `Digest` or "
        "`ProcessResult`."
    )
    assert_parse_get_types_fails("Get(P(), S, subject)", expected_explanation=explanation)


def test_parse_get_types_invalid_subject() -> None:
    """Each malformed subject argument produces the explanation specific to its form."""
    # Shorthand form given a bare name instead of a constructor call.
    not_a_call = (
        "Because you are using the shorthand form Get(ProductType, SubjectType(constructor "
        "args), the second argument should be a constructor call, like `MergeDigest(...)` or "
        "`Process(...)`."
    )
    # Shorthand form given a method call rather than a top-level constructor call.
    method_call = (
        "Because you are using the shorthand form Get(ProductType, SubjectType(constructor "
        "args), the second argument should be a top-level constructor function call, like "
        "`MergeDigest(...)` or `Process(...)`, rather than a method call."
    )
    # Longhand form given a constructor call where the subject type is expected.
    not_a_type = (
        "Because you are using the longhand form Get(ProductType, SubjectType, "
        "subject_instance), the second argument should be a type, like `MergeDigests` or "
        "`Process`."
    )
    cases = [
        ("Get(P, S)", not_a_call),
        ("Get(P, Subject.create())", method_call),
        ("Get(P, Subject(), subject)", not_a_type),
    ]
    for get, explanation in cases:
        assert_parse_get_types_fails(get, expected_explanation=explanation)


class AClass:
Expand All @@ -18,35 +83,35 @@ def __eq__(self, other: Any):
return type(self) == type(other)


def test_create() -> None:
def test_create_get() -> None:
get = Get(AClass, BClass, 42)
assert get.product_type is AClass
assert get.subject_declared_type is BClass
assert get.subject == 42


def test_create_abbreviated() -> None:
def test_create_get_abbreviated() -> None:
# Test the equivalence of the 1-arg and 2-arg versions.
assert Get(AClass, BClass()) == Get(AClass, BClass, BClass())


def test_invalid_abbreviated() -> None:
def test_invalid_get_abbreviated() -> None:
with pytest.raises(
expected_exception=TypeError,
match=re.escape(f"The subject argument cannot be a type, given {BClass}."),
):
Get(AClass, BClass)


def test_invalid_subject() -> None:
def test_invalid_get_subject() -> None:
with pytest.raises(
expected_exception=TypeError,
match=re.escape(f"The subject argument cannot be a type, given {BClass}."),
):
Get(AClass, BClass, BClass)


def test_invalid_subject_declared_type() -> None:
def test_invalid_get_subject_declared_type() -> None:
with pytest.raises(
expected_exception=TypeError,
match=re.escape(
Expand All @@ -56,7 +121,7 @@ def test_invalid_subject_declared_type() -> None:
Get(AClass, 1, BClass) # type: ignore[call-overload]


def test_invalid_product_type() -> None:
def test_invalid_get_product_type() -> None:
with pytest.raises(
expected_exception=TypeError,
match=re.escape(f"The product type argument must be a type, given {1} of type {type(1)}."),
Expand Down
107 changes: 16 additions & 91 deletions src/python/pants/engine/rules.py
Expand Up @@ -46,104 +46,29 @@ def side_effecting(cls):
class _RuleVisitor(ast.NodeVisitor):
"""Pull `Get` calls out of an @rule body."""

def __init__(
self, *, resolve_type: Callable[[str], Type[Any]], source_file: Optional[str] = None
) -> None:
def __init__(self, *, resolve_type: Callable[[str], Type[Any]], source_file_name: str) -> None:
super().__init__()
self._source_file = source_file or "<string>"
self._resolve_type = resolve_type
self._gets: List[GetConstraints] = []

@property
def gets(self) -> List[GetConstraints]:
return self._gets

@frozen_after_init
@dataclass(unsafe_hash=True)
class _GetDescriptor:
product_type_name: str
subject_arg_exprs: Tuple[ast.expr, ...]

def __init__(
self, product_type_expr: ast.expr, subject_arg_exprs: Iterable[ast.expr]
) -> None:
if not isinstance(product_type_expr, ast.Name):
raise ValueError(
f"Unrecognized type argument T for Get[T]: " f"{ast.dump(product_type_expr)}"
)
self.product_type_name = product_type_expr.id
self.subject_arg_exprs = tuple(subject_arg_exprs)

def _identify_source(self, node: Union[ast.expr, ast.stmt]) -> str:
start_pos = f"{node.lineno}:{node.col_offset}"
self.source_file_name = source_file_name
self.resolve_type = resolve_type
self.gets: List[GetConstraints] = []

end_lineno, end_col_offset = [
getattr(node, attr, None) for attr in ("end_lineno", "end_col_offset")
]
end_pos = f"-{end_lineno}:{end_col_offset}" if end_lineno and end_col_offset else ""

return f"{self._source_file} at {start_pos}{end_pos}"

def _extract_get_descriptor(self, call_node: ast.Call) -> Optional[_GetDescriptor]:
@staticmethod
def maybe_extract_get_args(call_node: ast.Call) -> Optional[List[ast.expr]]:
"""Check if the node looks like a Get(T, ...) call."""
if not isinstance(call_node.func, ast.Name):
return None
if call_node.func.id != "Get":
return None
return self._GetDescriptor(
product_type_expr=call_node.args[0], subject_arg_exprs=call_node.args[1:]
)

def _extract_constraints(self, get_descriptor: _GetDescriptor) -> GetConstraints[Any, Any]:
"""Parses a `Get(T, ...)` call in one of its two legal forms to return its type constraints.
:param get_descriptor: An `ast.Call` node representing a call to `Get(T, ...)`.
:return: A tuple of product type id and subject type id.
"""

def render_args():
rendered_args = ", ".join(
# Dump the Name's id to simplify output when available, falling back to the name of
# the node's class.
getattr(subject_arg, "id", type(subject_arg).__name__)
for subject_arg in get_descriptor.subject_arg_exprs
)
return f"Get({get_descriptor.product_type_name}, {rendered_args})"

if not 1 <= len(get_descriptor.subject_arg_exprs) <= 2:
raise ValueError(
f"Invalid Get. Expected either one or two args, but got: {render_args()}"
)

product_type = self._resolve_type(get_descriptor.product_type_name)

if len(get_descriptor.subject_arg_exprs) == 1:
subject_constructor = get_descriptor.subject_arg_exprs[0]
if not isinstance(subject_constructor, ast.Call):
raise ValueError(
f"Expected Get(product_type, subject_type(subject)), but got: {render_args()}"
)
constructor_type_id = subject_constructor.func.id # type: ignore[attr-defined]
return GetConstraints[Any, Any](
product_type=product_type,
subject_declared_type=self._resolve_type(constructor_type_id),
)

subject_declared_type, _ = get_descriptor.subject_arg_exprs
if not isinstance(subject_declared_type, ast.Name):
raise ValueError(
f"Expected Get(product_type, subject_declared_type, subject), but got: "
f"{render_args()}"
)
return GetConstraints[Any, Any](
product_type=product_type,
subject_declared_type=self._resolve_type(subject_declared_type.id),
)
return call_node.args

def visit_Call(self, call_node: ast.Call) -> None:
get_descriptor = self._extract_get_descriptor(call_node)
if get_descriptor:
self._gets.append(self._extract_constraints(get_descriptor))
get_args = self.maybe_extract_get_args(call_node)
if get_args is not None:
product_str, subject_str = GetConstraints.parse_product_and_subject_types(
get_args, source_file_name=self.source_file_name
)
get = GetConstraints(self.resolve_type(product_str), self.resolve_type(subject_str))
self.gets.append(get)
# Ensure we descend into e.g. MultiGet(Get(...)...) calls.
self.generic_visit(call_node)

Expand Down Expand Up @@ -210,7 +135,7 @@ def wrapper(func):
raise ValueError("The @rule decorator must be applied innermost of all decorators.")

owning_module = sys.modules[func.__module__]
source = inspect.getsource(func)
source = inspect.getsource(func) or "<string>"
source_file = inspect.getsourcefile(func)
beginning_indent = _get_starting_indent(source)
if beginning_indent:
Expand Down Expand Up @@ -245,7 +170,7 @@ def resolve_type(name):
for child in ast.iter_child_nodes(parent):
parents_table[child] = parent

rule_visitor = _RuleVisitor(source_file=source_file, resolve_type=resolve_type)
rule_visitor = _RuleVisitor(source_file_name=source_file, resolve_type=resolve_type)
rule_visitor.visit(rule_func_node)

gets = FrozenOrderedSet(rule_visitor.gets)
Expand Down

0 comments on commit 3611d5e

Please sign in to comment.