Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 95 additions & 50 deletions src/taskgraph/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@
from io import BytesIO
from pprint import pformat
from subprocess import CalledProcessError
from typing import Optional, Union
from unittest.mock import Mock
from urllib.parse import urlparse
from urllib.request import urlopen

import mozilla_repo_urls
from voluptuous import ALLOW_EXTRA, Any, Optional, Required, Schema
import msgspec

from taskgraph.util import json, yaml
from taskgraph.util.readonlydict import ReadOnlyDict
from taskgraph.util.schema import validate_schema
from taskgraph.util.schema import Schema, validate_schema
from taskgraph.util.taskcluster import find_task_id, get_artifact_url
from taskgraph.util.vcs import get_repository

Expand All @@ -30,43 +31,50 @@ class ParameterMismatch(Exception):

#: Schema for base parameters.
#: Please keep this list sorted and in sync with docs/reference/parameters.rst
base_schema = Schema(
base_schema = Schema.from_dict(
{
Required("base_repository"): str,
Optional("base_ref"): str,
Required("base_rev"): str,
Required("build_date"): int,
Required("build_number"): int,
Required("do_not_optimize"): [str],
Required("enable_always_target"): Any(bool, [str]),
Required("existing_tasks"): {str: str},
Required("files_changed"): [str],
Required("filters"): [str],
Required("head_ref"): str,
Required("head_repository"): str,
Required("head_rev"): str,
Required("head_tag"): str,
Required("level"): str,
Required("moz_build_date"): str,
Required("next_version"): Any(str, None),
Required("optimize_strategies"): Any(str, None),
Required("optimize_target_tasks"): bool,
Required("owner"): str,
Required("project"): str,
Required("pushdate"): int,
Required("pushlog_id"): str,
Required("repository_type"): str,
"base_repository": str,
"base_ref": Optional[str],
"base_rev": str,
"build_date": int,
"build_number": int,
"do_not_optimize": list[str],
"enable_always_target": Union[bool, list[str]],
"existing_tasks": dict[str, str],
"files_changed": list[str],
"filters": list[str],
"head_ref": str,
"head_repository": str,
"head_rev": str,
"head_tag": str,
"level": str,
"moz_build_date": str,
"next_version": Optional[str],
"optimize_strategies": Optional[str],
"optimize_target_tasks": bool,
"owner": str,
"project": str,
"pushdate": int,
"pushlog_id": str,
"repository_type": str,
# target-kinds is not included, since it should never be
# used at run-time
Required("target_tasks_method"): str,
Required("tasks_for"): str,
Required("version"): Any(str, None),
Optional("code-review"): {
Required("phabricator-build-target"): str,
},
}
"target_tasks_method": str,
"tasks_for": str,
"version": Optional[str],
"code-review": Schema.from_dict(
{"phabricator-build-target": str},
name="CodeReviewConfig",
optional=True,
),
},
name="BaseParametersSchema",
forbid_unknown_fields=False,
kw_only=True,
)

_parameter_extensions: list = []


def get_contents(path):
with open(path) as fh:
Expand All @@ -83,11 +91,21 @@ def _get_defaults(repo_root=None):
repo_path = repo_root or os.getcwd()
try:
repo = get_repository(repo_path)
except RuntimeError:
# Use fake values if no repo is detected.
repo = Mock(branch="", head_rev="", tool="git")
# Resolve git-backed attributes eagerly so any subprocess failures
# (e.g. Windows "dubious ownership" when safe.directory isn't honored)
# are caught by the except below instead of escaping later.
branch = repo.branch
head_rev = repo.head_rev
tool = repo.tool
files_changed = repo.get_changed_files("AM")
except (RuntimeError, CalledProcessError):
# Use fake values if no repo is detected or git refuses to operate.
repo = Mock()
repo.get_url.return_value = ""
repo.get_changed_files.return_value = []
branch = ""
head_rev = ""
tool = "git"
files_changed = []

try:
repo_url = repo.get_url()
Expand All @@ -110,11 +128,11 @@ def _get_defaults(repo_root=None):
"do_not_optimize": [],
"enable_always_target": True,
"existing_tasks": {},
"files_changed": lambda: repo.get_changed_files("AM"),
"files_changed": files_changed,
"filters": ["target_tasks_method"],
"head_ref": repo.branch or repo.head_rev,
"head_ref": branch or head_rev,
"head_repository": repo_url,
"head_rev": repo.head_rev,
"head_rev": head_rev,
"head_tag": "",
"level": "3",
"moz_build_date": datetime.now().strftime("%Y%m%d%H%M%S"),
Expand All @@ -125,7 +143,7 @@ def _get_defaults(repo_root=None):
"project": project,
"pushdate": int(time.time()),
"pushlog_id": "0",
"repository_type": repo.tool,
"repository_type": tool,
"target_tasks_method": "default",
"tasks_for": "",
"version": get_version(repo_path),
Expand All @@ -143,19 +161,27 @@ def extend_parameters_schema(schema, defaults_fn=None):
graph-configuration.

Args:
schema (Schema): The voluptuous.Schema object used to describe extended
parameters.
schema: A msgspec ``Schema`` subclass describing extended parameters.
defaults_fn (function): A function which takes no arguments and returns a
dict mapping parameter name to default value in the
event strict=False (optional).
"""
global base_schema
global defaults_functions
base_schema = base_schema.extend(schema)
if not (isinstance(schema, type) and issubclass(schema, msgspec.Struct)):
raise TypeError(
"extend_parameters_schema requires a msgspec Schema subclass; "
f"got {type(schema).__name__}"
)
_parameter_extensions.append(schema)
if defaults_fn:
defaults_functions.append(defaults_fn)


def _schema_key_names(schema) -> set:
"""Return the data-level field names declared by a parameters schema."""
return {f.encode_name for f in msgspec.structs.fields(schema)}


class Parameters(ReadOnlyDict):
"""An immutable dictionary with nicer KeyError messages on failure"""

Expand Down Expand Up @@ -214,11 +240,30 @@ def _fill_defaults(repo_root=None, **kwargs):
return kwargs

def check(self):
schema = (
base_schema if self.strict else base_schema.extend({}, extra=ALLOW_EXTRA)
)
data = dict(self.copy())
try:
validate_schema(schema, self.copy(), "Invalid parameters:")
# Validate core fields against just the subset of data owned by the
# base schema. Extension keys are validated separately below, and a
# strict-mode check rejects anything unknown to either.
base_keys = _schema_key_names(base_schema)
base_data = {k: v for k, v in data.items() if k in base_keys}
validate_schema(base_schema, base_data, "Invalid parameters:")

# Validate each registered extension against the keys it declares.
allowed = set(base_keys)
for ext in _parameter_extensions:
ext_keys = _schema_key_names(ext)
allowed |= ext_keys
ext_data = {k: data[k] for k in ext_keys if k in data}
validate_schema(ext, ext_data, "Invalid parameters:")

# Strict mode: reject any data key not covered by base or extensions.
if self.strict:
unknown = sorted(set(data) - allowed)
if unknown:
raise Exception(
"Invalid parameters:\nunknown keys: " + ", ".join(unknown)
)
except Exception as e:
raise ParameterMismatch(str(e))

Expand Down
15 changes: 9 additions & 6 deletions taskcluster/self_taskgraph/custom_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import os
from typing import Annotated, Optional

from voluptuous import All, Any, Range, Required
import msgspec

from taskgraph.parameters import extend_parameters_schema
from taskgraph.util.schema import Schema


def get_defaults(repo_root):
Expand All @@ -15,14 +17,15 @@ def get_defaults(repo_root):
}


extend_parameters_schema(
{
Required("pull_request_number"): Any(All(int, Range(min=1)), None),
},
defaults_fn=get_defaults,
CustomParametersSchema = Schema.from_dict(
{"pull_request_number": Optional[Annotated[int, msgspec.Meta(ge=1)]]},
name="CustomParametersSchema",
)


extend_parameters_schema(CustomParametersSchema, defaults_fn=get_defaults)


def decision_parameters(graph_config, parameters):
if parameters["tasks_for"] == "github-release":
parameters["target_tasks_method"] = "release"
Expand Down
21 changes: 8 additions & 13 deletions test/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import gzip
import os
from base64 import b64decode
from typing import Optional
from unittest import TestCase, mock

import mozilla_repo_urls
import pytest
from voluptuous import Optional, Required, Schema

import taskgraph # noqa: F401
from taskgraph import parameters
Expand All @@ -21,6 +21,7 @@
extend_parameters_schema,
load_parameters_file,
)
from taskgraph.util.schema import Schema

from .mockedopen import MockedOpen

Expand Down Expand Up @@ -274,20 +275,16 @@ def test_parameters_format_spec(spec, expected):


def test_extend_parameters_schema(monkeypatch):
monkeypatch.setattr(
parameters,
"base_schema",
Schema(
{
Required("foo"): str,
}
),
)
FooSchema = Schema.from_dict({"foo": str}, name="FooSchema")
BarSchema = Schema.from_dict({"bar": Optional[bool]}, name="BarSchema")

monkeypatch.setattr(parameters, "base_schema", FooSchema)
monkeypatch.setattr(
parameters,
"defaults_functions",
list(parameters.defaults_functions),
)
monkeypatch.setattr(parameters, "_parameter_extensions", [])

with pytest.raises(ParameterMismatch):
Parameters(strict=False).check()
Expand All @@ -296,9 +293,7 @@ def test_extend_parameters_schema(monkeypatch):
Parameters(foo="1", bar=True).check()

extend_parameters_schema(
{
Optional("bar"): bool,
},
BarSchema,
defaults_fn=lambda root: {"foo": "1", "bar": False},
)

Expand Down
Loading