Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,9 @@ class RelPath(str):
"bool": All(Lower, Choices("store_true", "boolean_optional")),
"list": All(Lower, Choices("nargs", "append")),
},
"hydra": {
Optional("enabled", default=False): Bool,
"config_dir": str,
"config_name": str,
},
}
22 changes: 19 additions & 3 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from funcy import cached_property

from dvc.dependency import ParamsDependency
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.ui import ui
Expand Down Expand Up @@ -521,21 +522,36 @@ def _update_params(self, params: Dict[str, List[str]]):
provided via `exp run --set-param`.

.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
https://hydra.cc/docs/advanced/override_grammar/basic/
"""
logger.debug("Using experiment params '%s'", params)

try:
from dvc.utils.hydra import apply_overrides
from dvc.utils.hydra import apply_overrides, compose_and_dump
except ValueError:
if sys.version_info >= (3, 11):
raise DvcException(
"--set-param is not supported in Python >= 3.11"
)
raise

hydra_config = self.repo.config.get("hydra", {})
hydra_enabled = hydra_config.get("enabled", False)
hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE
for path, overrides in params.items():
apply_overrides(path, overrides)
if hydra_enabled and path == hydra_output_file:
config_dir = os.path.join(
self.repo.root_dir, hydra_config.get("config_dir", "conf")
)
config_name = hydra_config.get("config_name", "config")
compose_and_dump(
path,
config_dir,
config_name,
overrides,
)
else:
apply_overrides(path, overrides)

# Force params file changes to be staged in git
# Otherwise in certain situations the changes to params file may be
Expand Down
15 changes: 12 additions & 3 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Dict, Iterable, Optional

from dvc.dependency.param import ParamsDependency
from dvc.repo import locked
from dvc.ui import ui
from dvc.utils.cli_parse import to_path_overrides
Expand Down Expand Up @@ -31,21 +32,29 @@ def run(
return repo.experiments.reproduce_celery(entries, jobs=jobs)

if params:
params = to_path_overrides(params)
path_overrides = to_path_overrides(params)
else:
path_overrides = {}

hydra_enabled = repo.config.get("hydra", {}).get("enabled", False)
hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE
if hydra_enabled and hydra_output_file not in path_overrides:
# Force `_update_params` even if `--set-param` was not used
path_overrides[hydra_output_file] = []

if queue:
if not kwargs.get("checkpoint_resume", None):
kwargs["reset"] = True
queue_entry = repo.experiments.queue_one(
repo.experiments.celery_queue,
targets=targets,
params=params,
params=path_overrides,
**kwargs,
)
name = queue_entry.name or queue_entry.stash_rev[:7]
ui.write(f"Queued experiment '{name}' for future execution.")
return {}

return repo.experiments.reproduce_one(
targets=targets, params=params, tmp_dir=tmp_dir, **kwargs
targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs
)
29 changes: 28 additions & 1 deletion dvc/utils/hydra.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, List

from hydra import compose, initialize_config_dir
from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.errors import ConfigCompositionException, OverrideParseException
Expand All @@ -10,12 +11,38 @@
from dvc.exceptions import InvalidArgumentError

from .collections import merge_dicts, remove_missing_keys, to_omegaconf
from .serialize import MODIFIERS
from .serialize import DUMPERS, MODIFIERS

if TYPE_CHECKING:
from dvc.types import StrPath


def compose_and_dump(
output_file: "StrPath",
config_dir: str,
config_name: str,
overrides: List[str],
) -> None:
"""Compose Hydra config and dumpt it to `output_file`.

Args:
output_file: File where the composed config will be dumped.
config_dir: Folder containing the Hydra config files.
Must be absolute file system path.
config_name: Name of the config file containing defaults,
without the .yaml extension.
overrides: List of `Hydra Override`_ patterns.

.. _Hydra Override:
https://hydra.cc/docs/advanced/override_grammar/basic/
"""
with initialize_config_dir(config_dir, version_base=None):
cfg = compose(config_name=config_name, overrides=overrides)

dumper = DUMPERS[Path(output_file).suffix.lower()]
dumper(output_file, OmegaConf.to_object(cfg))


def apply_overrides(path: "StrPath", overrides: List[str]) -> None:
"""Update `path` params with the provided `Hydra Override`_ patterns.

Expand Down
14 changes: 0 additions & 14 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,6 @@ def test_failed_exp_workspace(
)


@pytest.mark.parametrize(
"changes, expected",
[
[["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
[["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
],
)
def test_modify_params(params_repo, dvc, changes, expected):
dvc.experiments.run(params=changes)
# pylint: disable=unspecified-encoding
with open("params.yaml", mode="r") as fobj:
assert fobj.read().strip() == expected


def test_apply(tmp_dir, scm, dvc, exp_stage):
from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.exceptions import ApplyConflictError
Expand Down
87 changes: 87 additions & 0 deletions tests/func/experiments/test_set_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import sys

import pytest

from ..utils.test_hydra import hydra_setup


@pytest.mark.parametrize(
"changes, expected",
[
[["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
[["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
],
)
def test_modify_params(params_repo, dvc, changes, expected):
dvc.experiments.run(params=changes)
# pylint: disable=unspecified-encoding
with open("params.yaml", mode="r") as fobj:
assert fobj.read().strip() == expected


@pytest.mark.parametrize(
"hydra_enabled",
[
pytest.param(
True,
marks=pytest.mark.skipif(
sys.version_info >= (3, 11), reason="unsupported on 3.11"
),
),
False,
],
)
@pytest.mark.parametrize(
"config_dir,config_name",
[
(None, None),
(None, "bar"),
("conf", "bar"),
],
)
def test_hydra_compose_and_dump(
tmp_dir, params_repo, dvc, hydra_enabled, config_dir, config_name
):
hydra_setup(
tmp_dir,
config_dir=config_dir or "conf",
config_name=config_name or "config",
)

dvc.experiments.run()
assert (tmp_dir / "params.yaml").parse() == {
"foo": [{"bar": 1}, {"baz": 2}],
"goo": {"bag": 3.0},
"lorem": False,
}

with dvc.config.edit() as conf:
if hydra_enabled:
conf["hydra"]["enabled"] = True
if config_dir is not None:
conf["hydra"]["config_dir"] = config_dir
if config_name is not None:
conf["hydra"]["config_name"] = config_name

dvc.experiments.run()

if hydra_enabled:
assert (tmp_dir / "params.yaml").parse() == {
"db": {"driver": "mysql", "user": "omry", "pass": "secret"},
}

dvc.experiments.run(params=["db=postgresql"])
assert (tmp_dir / "params.yaml").parse() == {
"db": {
"driver": "postgresql",
"user": "foo",
"pass": "bar",
"timeout": 10,
}
}
else:
assert (tmp_dir / "params.yaml").parse() == {
"foo": [{"bar": 1}, {"baz": 2}],
"goo": {"bag": 3.0},
"lorem": False,
}
55 changes: 55 additions & 0 deletions tests/func/utils/test_hydra.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import pytest

from dvc.exceptions import InvalidArgumentError
Expand Down Expand Up @@ -123,3 +125,56 @@ def test_invalid_overrides(tmp_dir, overrides):
)
with pytest.raises(InvalidArgumentError):
apply_overrides(path=params_file.name, overrides=overrides)


def hydra_setup(tmp_dir, config_dir, config_name):
config_dir = tmp_dir / config_dir
(config_dir / "db").mkdir(parents=True)
(config_dir / f"{config_name}.yaml").dump({"defaults": [{"db": "mysql"}]})
(config_dir / "db" / "mysql.yaml").dump(
{"driver": "mysql", "user": "omry", "pass": "secret"}
)
(config_dir / "db" / "postgresql.yaml").dump(
{"driver": "postgresql", "user": "foo", "pass": "bar", "timeout": 10}
)
return str(config_dir)


@pytest.mark.skipif(sys.version_info >= (3, 11), reason="unsupported on 3.11")
@pytest.mark.parametrize("suffix", ["yaml", "toml", "json"])
@pytest.mark.parametrize(
"overrides,expected",
[
([], {"db": {"driver": "mysql", "user": "omry", "pass": "secret"}}),
(
["db=postgresql"],
{
"db": {
"driver": "postgresql",
"user": "foo",
"pass": "bar",
"timeout": 10,
}
},
),
(
["db=postgresql", "db.timeout=20"],
{
"db": {
"driver": "postgresql",
"user": "foo",
"pass": "bar",
"timeout": 20,
}
},
),
],
)
def test_compose_and_dump(tmp_dir, suffix, overrides, expected):
from dvc.utils.hydra import compose_and_dump

config_name = "config"
config_dir = hydra_setup(tmp_dir, "conf", "config")
output_file = tmp_dir / f"params.{suffix}"
compose_and_dump(output_file, config_dir, config_name, overrides)
assert output_file.parse() == expected