diff --git a/dvc/repo/commit.py b/dvc/repo/commit.py index a0405b61f0..d58b4775a2 100644 --- a/dvc/repo/commit.py +++ b/dvc/repo/commit.py @@ -37,20 +37,32 @@ def prompt_to_commit(stage, changes, force=False): @locked def commit( - self, target, with_deps=False, recursive=False, force=False, + self, + target, + with_deps=False, + recursive=False, + force=False, + allow_missing=False, + data_only=False, ): from dvc.dvcfile import Dvcfile - stages_info = self.stage.collect_granular( - target, with_deps=with_deps, recursive=recursive - ) + stages_info = [ + info + for info in self.stage.collect_granular( + target, with_deps=with_deps, recursive=recursive + ) + if not data_only or info.stage.is_data_source + ] for stage_info in stages_info: stage = stage_info.stage changes = stage.changed_entries() if any(changes): prompt_to_commit(stage, changes, force=force) - stage.save() - stage.commit(filter_info=stage_info.filter_info) + stage.save(allow_missing=allow_missing) + stage.commit( + filter_info=stage_info.filter_info, allow_missing=allow_missing + ) Dvcfile(self, stage.path).dump(stage, update_pipeline=False) return [s.stage for s in stages_info] diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 08451ed41c..cf84180c4e 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -186,6 +186,10 @@ def _stash_exp( if params: self._update_params(params) + # DVC commit data deps to preserve state across workspace + # & tempdir runs + self._stash_commit_deps(*args, **kwargs) + if resume_rev: if branch: branch_name = ExpRefInfo.from_ref(branch).name @@ -242,6 +246,25 @@ def _stash_exp( return stash_rev + def _stash_commit_deps(self, *args, **kwargs): + if len(args): + targets = args[0] + else: + targets = kwargs.get("targets") + if isinstance(targets, str): + targets = [targets] + elif not targets: + targets = [None] + for target in targets: + self.repo.commit( + target, + with_deps=True, + recursive=kwargs.get("recursive", False), + force=True, + allow_missing=True, + data_only=True, + ) + def _stash_msg( self, rev: str, diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index bbbed8beb4..9f60813490 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -433,15 +433,21 @@ def compute_md5(self): return m def save(self, allow_missing=False): - self.save_deps() + self.save_deps(allow_missing=allow_missing) self.save_outs(allow_missing=allow_missing) self.md5 = self.compute_md5() self.repo.stage_cache.save(self) - def save_deps(self): + def save_deps(self, allow_missing=False): + from dvc.dependency.base import DependencyDoesNotExistError + for dep in self.deps: - dep.save() + try: + dep.save() + except DependencyDoesNotExistError: + if not allow_missing: + raise def save_outs(self, allow_missing=False): from dvc.output.base import OutputDoesNotExistError diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 9e3d6bac5c..6b9dfa5a56 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -1,3 +1,4 @@ +import itertools import logging import os import stat @@ -624,3 +625,52 @@ def test_fix_exp_head(tmp_dir, scm, tail): scm.set_ref(EXEC_BASELINE, "refs/heads/master") assert EXEC_BASELINE + tail == fix_exp_head(scm, head) assert "foo" + tail == fix_exp_head(scm, "foo" + tail) + + +@pytest.mark.parametrize( + "workspace, params, target", + itertools.product((True, False), ("foo: 1", "foo: 2"), (True, False)), +) +def test_modified_data_dep(tmp_dir, scm, dvc, workspace, params, target): + tmp_dir.dvc_gen("data", "data") + tmp_dir.gen("copy.py", COPY_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + exp_stage = dvc.run( + cmd="python copy.py params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + name="copy-file", + deps=["copy.py", "data"], + ) + scm.add( + [ + "dvc.yaml", + "dvc.lock", + "copy.py", + "params.yaml", + "metrics.yaml", + "data.dvc", + ".gitignore", + ] + ) + scm.commit("init") + + tmp_dir.gen("params.yaml", params) + tmp_dir.gen("data", "modified") + + results = dvc.experiments.run( + exp_stage.addressing if target else None, tmp_dir=not workspace + ) + exp = first(results) + + for rev in dvc.brancher(revs=[exp]): + if rev != exp: + continue + with dvc.repo_fs.open(tmp_dir / "metrics.yaml") as fobj: + assert fobj.read().strip() == params + with dvc.repo_fs.open(tmp_dir / "data") as fobj: + assert fobj.read().strip() == "modified" + + if workspace: + assert (tmp_dir / "metrics.yaml").read_text().strip() == params + assert (tmp_dir / "data").read_text().strip() == "modified"