diff --git a/src/check_datapackage/extensions.py b/src/check_datapackage/extensions.py index afe67685..36185ca2 100644 --- a/src/check_datapackage/extensions.py +++ b/src/check_datapackage/extensions.py @@ -1,7 +1,11 @@ -import re from collections.abc import Callable -from typing import Any, Self +from dataclasses import dataclass +from operator import itemgetter +from typing import Any, Self, cast +from jsonpath import JSONPath, compile +from jsonpath.segments import JSONPathRecursiveDescentSegment +from jsonpath.selectors import NameSelector from pydantic import BaseModel, PrivateAttr, field_validator, model_validator from check_datapackage.internals import ( @@ -91,6 +95,51 @@ def apply(self, properties: dict[str, Any]) -> list[Issue]: ) +@dataclass(frozen=True) +class TargetJsonPath: + """A JSON path targeted by a `RequiredCheck`. + + Attributes: + parent (str): The JSON path to the parent of the targeted field. + field (str): The name of the targeted field. + """ + + parent: str + field: str + + +def _jsonpath_to_targets(jsonpath: JSONPath) -> list[TargetJsonPath]: + """Create a list of `TargetJsonPath`s from a `JSONPath`.""" + # Segments are path parts, e.g., `resources`, `*`, `name` for `$.resources[*].name` + if not jsonpath.segments: + return [] + + full_path = jsonpath.segments[0].token.path + last_segment = jsonpath.segments[-1] + if isinstance(last_segment, JSONPathRecursiveDescentSegment): + raise ValueError( + f"Cannot use the JSON path `{full_path}` in `RequiredCheck`" + " because it ends in the recursive descent (`..`) operator." + ) + + # Things like field names, array indices, and/or wildcards. + selectors = last_segment.selectors + if _filter(selectors, lambda selector: not isinstance(selector, NameSelector)): + raise ValueError( + f"Cannot use `RequiredCheck` for the JSON path `{full_path}`" + " because it doesn't end in a name selector." + ) + + parent = "".join(_map(jsonpath.segments[:-1], str)) + name_selectors = cast(tuple[NameSelector], selectors) + return _map( + name_selectors, + lambda selector: TargetJsonPath( + parent=str(compile(parent)), field=selector.name + ), + ) + + class RequiredCheck(BaseModel, frozen=True): """Set a specific property as required. @@ -112,22 +161,18 @@ class RequiredCheck(BaseModel, frozen=True): jsonpath: JsonPath message: str - _field_name: str = PrivateAttr() + _targets: list[TargetJsonPath] = PrivateAttr() @model_validator(mode="after") def _check_field_name_in_jsonpath(self) -> Self: - field_name_match = re.search(r"(? list[Issue]: @@ -140,16 +185,27 @@ def apply(self, properties: dict[str, Any]) -> list[Issue]: A list of `Issue`s. """ matching_paths = _get_direct_jsonpaths(self.jsonpath, properties) - indirect_parent_path = self.jsonpath.removesuffix(self._field_name) - direct_parent_paths = _get_direct_jsonpaths(indirect_parent_path, properties) + return _flat_map( + self._targets, + lambda target: self._target_to_issues(target, matching_paths, properties), + ) + + def _target_to_issues( + self, + target: TargetJsonPath, + matching_paths: list[str], + properties: dict[str, Any], + ) -> list[Issue]: + """Create a list of `Issue`s from a `TargetJsonPath`.""" + direct_parent_paths = _get_direct_jsonpaths(target.parent, properties) missing_paths = _filter( direct_parent_paths, - lambda path: f"{path}{self._field_name}" not in matching_paths, + lambda path: f"{path}.{target.field}" not in matching_paths, ) return _map( missing_paths, lambda path: Issue( - jsonpath=path + self._field_name, + jsonpath=f"{path}.{target.field}", type="required", message=self.message, ), diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 6c9a2c32..bb303b68 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -7,6 +7,7 @@ example_resource_properties, ) from check_datapackage.extensions import CustomCheck, Extensions, RequiredCheck +from check_datapackage.internals import _map from check_datapackage.issue import Issue lowercase_check = CustomCheck( @@ -157,13 +158,61 @@ def test_required_check_array_wildcard(): ] +def test_required_check_union(): + properties = example_package_properties() + del properties["licenses"] + required_check = RequiredCheck( + jsonpath="$['licenses', 'sources'] | $.resources[*]['licenses', 'sources']", + message="Package and resources must have licenses and sources.", + ) + config = Config(extensions=Extensions(required_checks=[required_check])) + + issues = check(properties, config=config) + + assert all(_map(issues, lambda issue: issue.type == "required")) + assert _map(issues, lambda issue: issue.jsonpath) == [ + "$.licenses", + "$.resources[0].licenses", + "$.resources[0].sources", + "$.sources", + ] + + +def test_required_check_non_final_recursive_descent(): + properties = example_package_properties() + properties["resources"][0]["licenses"] = [{"name": "odc-pddl"}] + required_check = RequiredCheck( + jsonpath="$..licenses[*].title", + message="Licenses must have a title.", + ) + config = Config(extensions=Extensions(required_checks=[required_check])) + + issues = check(properties, config=config) + + assert _map(issues, lambda issue: issue.jsonpath) == [ + "$.licenses[0].title", + "$.resources[0].licenses[0].title", + ] + + +def test_required_check_root(): + properties = example_package_properties() + required_check = RequiredCheck( + jsonpath="$", + message="Package must have a root.", + ) + config = Config(extensions=Extensions(required_checks=[required_check])) + + issues = check(properties, config=config) + + assert issues == [] + + @mark.parametrize( "jsonpath", [ "<><>bad.path", - "$", "..*", - "created", "$..path", "..resources", "$.resources[0].*",