diff --git a/dvc/command/repro.py b/dvc/command/repro.py index ecad1c4a2d..6060fa8d40 100644 --- a/dvc/command/repro.py +++ b/dvc/command/repro.py @@ -44,6 +44,7 @@ def run(self): queue=self.args.queue, run_all=self.args.run_all, jobs=self.args.jobs, + params=self.args.params, ) if len(stages) == 0: @@ -177,6 +178,13 @@ def add_parser(subparsers, parent_parser): default=False, help=argparse.SUPPRESS, ) + repro_parser.add_argument( + "--params", + action="append", + default=[], + help="Declare parameter values for an experiment.", + metavar="[:]", + ) repro_parser.add_argument( "--queue", action="store_true", default=False, help=argparse.SUPPRESS ) diff --git a/dvc/exceptions.py b/dvc/exceptions.py index 204c378fdd..b49a71b006 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -194,6 +194,14 @@ def __init__(self, path): ) +class TOMLFileCorruptedError(DvcException): + def __init__(self, path): + path = relpath(path) + super().__init__( + f"unable to read: '{path}', TOML file structure is corrupted" + ) + + class RecursiveAddingWhileUsingFilename(DvcException): def __init__(self): super().__init__( diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index bcabf99966..20963e1faa 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -2,6 +2,8 @@ import os import re import tempfile +from collections import defaultdict +from collections.abc import Mapping from concurrent.futures import ProcessPoolExecutor, as_completed from contextlib import contextmanager from typing import Iterable, Optional @@ -9,6 +11,7 @@ from funcy import cached_property from dvc.exceptions import DvcException +from dvc.path_info import PathInfo from dvc.repo.experiments.executor import ExperimentExecutor, LocalExecutor from dvc.scm.git import Git from dvc.stage.serialize import to_lockfile @@ -139,21 +142,39 @@ def _scm_checkout(self, rev): logger.debug("Checking out experiment commit '%s'", rev) self.scm.checkout(rev) - def _stash_exp(self, *args, **kwargs): + def _stash_exp(self, *args, params: Optional[dict] = None, **kwargs): """Stash changes from the current (parent) workspace as an experiment. + + Args: + params: Optional dictionary of parameter values to be used. + Values take priority over any parameters specified in the + user's workspace. """ rev = self.scm.get_rev() + + # patch user's workspace into experiments clone tmp = tempfile.NamedTemporaryFile(delete=False).name try: self.repo.scm.repo.git.diff(patch=True, output=tmp) if os.path.getsize(tmp): logger.debug("Patching experiment workspace") self.scm.repo.git.apply(tmp) - else: + elif not params: + # experiment matches original baseline raise UnchangedExperimentError(rev) finally: remove(tmp) + + # update experiment params from command line + if params: + self._update_params(params) + + # save additional repro command line arguments self._pack_args(*args, **kwargs) + + # save experiment as a stash commit w/message containing baseline rev + # (stash commits are merge commits and do not contain a parent commit + # SHA) msg = f"{self.STASH_MSG_PREFIX}{rev}" self.scm.repo.git.stash("push", "-m", msg) return self.scm.resolve_rev("stash@{0}") @@ -166,6 +187,36 @@ def _unpack_args(self, tree=None): args_file = os.path.join(self.exp_dvc.tmp_dir, self.PACKED_ARGS_FILE) return ExperimentExecutor.unpack_repro_args(args_file, tree=tree) + def _update_params(self, params: dict): + """Update experiment params files with the specified values.""" + from dvc.utils.toml import dump_toml, parse_toml_for_update + from dvc.utils.yaml import dump_yaml, parse_yaml_for_update + + logger.debug("Using experiment params '%s'", params) + + # recursive dict update + def _update(dict_, other): + for key, value in other.items(): + if isinstance(value, Mapping): + dict_[key] = _update(dict_.get(key, {}), value) + else: + dict_[key] = value + return dict_ + + loaders = defaultdict(lambda: parse_yaml_for_update) + loaders.update({".toml": parse_toml_for_update}) + dumpers = defaultdict(lambda: dump_yaml) + dumpers.update({".toml": dump_toml}) + + for params_fname in params: + path = PathInfo(self.exp_dvc.root_dir) / params_fname + with self.exp_dvc.tree.open(path, "r") as fobj: + text = fobj.read() + suffix = path.suffix.lower() + data = loaders[suffix](text, path) + _update(data, params[params_fname]) + dumpers[suffix](path, data) + def _commit(self, exp_hash, check_exists=True, branch=True): """Commit stages as an experiment and return the commit SHA.""" if not self.scm.is_dirty(): @@ -207,7 +258,7 @@ def reproduce_queued(self, **kwargs): ) return results - def new(self, *args, workspace=True, **kwargs): + def new(self, *args, **kwargs): """Create a new experiment. Experiment will be reproduced and checked out into the user's @@ -215,15 +266,11 @@ def new(self, *args, workspace=True, **kwargs): """ rev = self.repo.scm.get_rev() self._scm_checkout(rev) - if workspace: - try: - stash_rev = self._stash_exp(*args, **kwargs) - except UnchangedExperimentError as exc: - logger.info("Reproducing existing experiment '%s'.", rev[:7]) - raise exc - else: - # configure params via command line here - pass + try: + stash_rev = self._stash_exp(*args, **kwargs) + except UnchangedExperimentError as exc: + logger.info("Reproducing existing experiment '%s'.", rev[:7]) + raise exc logger.debug( "Stashed experiment '%s' for future execution.", stash_rev[:7] ) @@ -365,8 +412,10 @@ def checkout_exp(self, rev): tmp = tempfile.NamedTemporaryFile(delete=False).name self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp) - logger.debug("Stashing workspace changes.") - self.repo.scm.repo.git.stash("push") + dirty = self.repo.scm.is_dirty() + if dirty: + logger.debug("Stashing workspace changes.") + self.repo.scm.repo.git.stash("push") try: if os.path.getsize(tmp): @@ -379,7 +428,8 @@ def checkout_exp(self, rev): raise DvcException("failed to apply experiment changes.") finally: remove(tmp) - self._unstash_workspace() + if dirty: + self._unstash_workspace() if need_checkout: dvc_checkout(self.repo) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index f94f9dd89f..1f60117b45 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -60,7 +60,7 @@ def reproduce( recursive=False, pipeline=False, all_pipelines=False, - **kwargs + **kwargs, ): from dvc.utils import parse_target @@ -71,6 +71,7 @@ def reproduce( ) experiment = kwargs.pop("experiment", False) + params = _parse_params(kwargs.pop("params", [])) queue = kwargs.pop("queue", False) run_all = kwargs.pop("run_all", False) jobs = kwargs.pop("jobs", 1) @@ -81,6 +82,7 @@ def reproduce( target=target, recursive=recursive, all_pipelines=all_pipelines, + params=params, queue=queue, run_all=run_all, jobs=jobs, @@ -116,6 +118,31 @@ def reproduce( return _reproduce_stages(active_graph, targets, **kwargs) +def _parse_params(path_params): + from flatten_json import unflatten + from yaml import safe_load, YAMLError + from dvc.dependency.param import ParamsDependency + + ret = {} + for path_param in path_params: + path, _, params_str = path_param.rpartition(":") + # remove empty strings from params, on condition such as `-p "file1:"` + params = {} + for param_str in filter(bool, params_str.split(",")): + try: + # interpret value strings using YAML rules + key, value = param_str.split("=") + params[key] = safe_load(value) + except (ValueError, YAMLError): + raise InvalidArgumentError( + f"Invalid param/value pair '{param_str}'" + ) + if not path: + path = ParamsDependency.DEFAULT_PARAMS_FILE + ret[path] = unflatten(params, ".") + return ret + + def _reproduce_experiments(repo, run_all=False, jobs=1, **kwargs): if run_all: return repo.experiments.reproduce_queued(jobs=jobs) diff --git a/dvc/utils/toml.py b/dvc/utils/toml.py new file mode 100644 index 0000000000..74aa59da38 --- /dev/null +++ b/dvc/utils/toml.py @@ -0,0 +1,21 @@ +import toml + +from dvc.exceptions import TOMLFileCorruptedError + + +def parse_toml_for_update(text, path): + """Parses text into Python structure. + + NOTE: Python toml package does not currently use ordered dicts, so + keys may be re-ordered between load/dump, but this function will at + least preserve comments. + """ + try: + return toml.loads(text, decoder=toml.TomlPreserveCommentDecoder()) + except toml.TomlDecodeError as exc: + raise TOMLFileCorruptedError(path) from exc + + +def dump_toml(path, data): + with open(path, "w", encoding="utf-8") as fobj: + toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder()) diff --git a/tests/unit/command/test_repro.py b/tests/unit/command/test_repro.py index 0b2302ef74..ef86a7549d 100644 --- a/tests/unit/command/test_repro.py +++ b/tests/unit/command/test_repro.py @@ -15,6 +15,7 @@ "recursive": False, "force_downstream": False, "experiment": False, + "params": [], "queue": False, "run_all": False, "jobs": None,