diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index a8d591f4ad..ffa14a4317 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -146,9 +146,7 @@ def run(self): if not self.repo.experiments: return 0 - self.repo.experiments.checkout( - self.args.experiment, force=self.args.force - ) + self.repo.experiments.checkout(self.args.experiment) return 0 @@ -287,14 +285,6 @@ def add_parser(subparsers, parent_parser): help=EXPERIMENTS_CHECKOUT_HELP, formatter_class=argparse.RawDescriptionHelpFormatter, ) - experiments_checkout_parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - help="Overwrite your current workspace with changes from the " - "experiment.", - ) experiments_checkout_parser.add_argument( "experiment", help="Checkout this experiment.", ) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index fe899762d6..bcabf99966 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -3,6 +3,7 @@ import re import tempfile from concurrent.futures import ProcessPoolExecutor, as_completed +from contextlib import contextmanager from typing import Iterable, Optional from funcy import cached_property @@ -26,7 +27,15 @@ def hash_exp(stages): class UnchangedExperimentError(DvcException): def __init__(self, rev): - super().__init__("Experiment identical to baseline '{rev[:7]}'.") + super().__init__(f"Experiment identical to baseline '{rev[:7]}'.") + self.rev = rev + + +class BaselineMismatchError(DvcException): + def __init__(self, rev): + super().__init__( + f"Experiment is not derived from current baseline '{rev[:7]}'." + ) self.rev = rev @@ -79,6 +88,13 @@ def exp_dvc(self): return Repo(self.exp_dvc_dir) + @contextmanager + def chdir(self): + cwd = os.getcwd() + os.chdir(self.exp_dvc.root_dir) + yield self.exp_dvc.root_dir + os.chdir(cwd) + @cached_property def args_file(self): return os.path.join(self.exp_dvc.tmp_dir, self.PACKED_ARGS_FILE) @@ -120,7 +136,7 @@ def _scm_checkout(self, rev): self.scm.repo.heads[0].checkout() if not Git.is_sha(rev) or not self.scm.has_rev(rev): self.scm.pull() - logger.debug("Checking out base experiment commit '%s'", rev) + logger.debug("Checking out experiment commit '%s'", rev) self.scm.checkout(rev) def _stash_exp(self, *args, **kwargs): @@ -176,7 +192,7 @@ def reproduce_one(self, queue=False, **kwargs): return [stash_rev] results = self.reproduce([stash_rev], keep_stash=False) for exp_rev in results: - self.checkout_exp(exp_rev, force=True) + self.checkout_exp(exp_rev) return results def reproduce_queued(self, **kwargs): @@ -338,30 +354,53 @@ def _collect_output(self, rev: str, executor: ExperimentExecutor): src = executor.path_info / relpath(fname, tree.tree_root) copyfile(src, fname) - def checkout_exp(self, rev, force=False): + def checkout_exp(self, rev): """Checkout an experiment to the user's workspace.""" from git.exc import GitCommandError from dvc.repo.checkout import _checkout as dvc_checkout - if force: - self.repo.scm.repo.git.reset(hard=True) + self._check_baseline(rev) self._scm_checkout(rev) tmp = tempfile.NamedTemporaryFile(delete=False).name self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp) + + logger.debug("Stashing workspace changes.") + self.repo.scm.repo.git.stash("push") + try: if os.path.getsize(tmp): logger.debug("Patching local workspace") self.repo.scm.repo.git.apply(tmp, reverse=True) - dvc_checkout(self.repo) + need_checkout = True + else: + need_checkout = False except GitCommandError: - raise DvcException( - "Checkout failed, experiment contains changes which " - "conflict with your current workspace. To overwrite " - "your workspace, use `dvc experiments checkout --force`." - ) + raise DvcException("failed to apply experiment changes.") finally: remove(tmp) + self._unstash_workspace() + + if need_checkout: + dvc_checkout(self.repo) + + def _check_baseline(self, exp_rev): + baseline_sha = self.repo.scm.get_rev() + exp_commit = self.scm.repo.rev_parse(exp_rev) + for parent in exp_commit.parents: + if parent.hexsha == baseline_sha: + return + raise BaselineMismatchError(baseline_sha) + + def _unstash_workspace(self): + # Essentially we want `git stash pop` with `-X ours` merge strategy + # to prefer the applied experiment changes over stashed workspace + # changes. git stash doesn't support merge strategy parameters, but we + # can do it ourselves with checkout/reset. + logger.debug("Unstashing workspace changes.") + self.repo.scm.repo.git.checkout("--ours", "stash@{0}", "--", ".") + self.repo.scm.repo.git.reset("HEAD") + self.repo.scm.repo.git.stash("drop", "stash@{0}") def checkout(self, *args, **kwargs): from dvc.repo.experiments.checkout import checkout