Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def run(self):
if not self.repo.experiments:
return 0

self.repo.experiments.checkout(
self.args.experiment, force=self.args.force
)
self.repo.experiments.checkout(self.args.experiment)

return 0

Expand Down Expand Up @@ -287,14 +285,6 @@ def add_parser(subparsers, parent_parser):
help=EXPERIMENTS_CHECKOUT_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
experiments_checkout_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Overwrite your current workspace with changes from the "
"experiment.",
)
experiments_checkout_parser.add_argument(
"experiment", help="Checkout this experiment.",
)
Expand Down
63 changes: 51 additions & 12 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import contextmanager
from typing import Iterable, Optional

from funcy import cached_property
Expand All @@ -26,7 +27,15 @@ def hash_exp(stages):

class UnchangedExperimentError(DvcException):
def __init__(self, rev):
super().__init__("Experiment identical to baseline '{rev[:7]}'.")
super().__init__(f"Experiment identical to baseline '{rev[:7]}'.")
self.rev = rev


class BaselineMismatchError(DvcException):
def __init__(self, rev):
super().__init__(
f"Experiment is not derived from current baseline '{rev[:7]}'."
)
self.rev = rev


Expand Down Expand Up @@ -79,6 +88,13 @@ def exp_dvc(self):

return Repo(self.exp_dvc_dir)

@contextmanager
def chdir(self):
cwd = os.getcwd()
os.chdir(self.exp_dvc.root_dir)
yield self.exp_dvc.root_dir
os.chdir(cwd)

Comment on lines +91 to +97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't see it being used anywhere

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's used in experiments show, but got removed in the last PR by mistake.

https://github.com/iterative/dvc/blob/11d906ee517e1980ce9884427a33b17253b1e4c2/dvc/repo/experiments/show.py#L61

@cached_property
def args_file(self):
return os.path.join(self.exp_dvc.tmp_dir, self.PACKED_ARGS_FILE)
Expand Down Expand Up @@ -120,7 +136,7 @@ def _scm_checkout(self, rev):
self.scm.repo.heads[0].checkout()
if not Git.is_sha(rev) or not self.scm.has_rev(rev):
self.scm.pull()
logger.debug("Checking out base experiment commit '%s'", rev)
logger.debug("Checking out experiment commit '%s'", rev)
self.scm.checkout(rev)

def _stash_exp(self, *args, **kwargs):
Expand Down Expand Up @@ -176,7 +192,7 @@ def reproduce_one(self, queue=False, **kwargs):
return [stash_rev]
results = self.reproduce([stash_rev], keep_stash=False)
for exp_rev in results:
self.checkout_exp(exp_rev, force=True)
self.checkout_exp(exp_rev)
return results

def reproduce_queued(self, **kwargs):
Expand Down Expand Up @@ -338,30 +354,53 @@ def _collect_output(self, rev: str, executor: ExperimentExecutor):
src = executor.path_info / relpath(fname, tree.tree_root)
copyfile(src, fname)

def checkout_exp(self, rev, force=False):
def checkout_exp(self, rev):
"""Checkout an experiment to the user's workspace."""
from git.exc import GitCommandError
from dvc.repo.checkout import _checkout as dvc_checkout

if force:
self.repo.scm.repo.git.reset(hard=True)
self._check_baseline(rev)
self._scm_checkout(rev)

tmp = tempfile.NamedTemporaryFile(delete=False).name
self.scm.repo.head.commit.diff("HEAD~1", patch=True, output=tmp)

logger.debug("Stashing workspace changes.")
self.repo.scm.repo.git.stash("push")

try:
if os.path.getsize(tmp):
logger.debug("Patching local workspace")
self.repo.scm.repo.git.apply(tmp, reverse=True)
dvc_checkout(self.repo)
need_checkout = True
else:
need_checkout = False
except GitCommandError:
raise DvcException(
"Checkout failed, experiment contains changes which "
"conflict with your current workspace. To overwrite "
"your workspace, use `dvc experiments checkout --force`."
)
raise DvcException("failed to apply experiment changes.")
finally:
remove(tmp)
self._unstash_workspace()

if need_checkout:
dvc_checkout(self.repo)

def _check_baseline(self, exp_rev):
baseline_sha = self.repo.scm.get_rev()
exp_commit = self.scm.repo.rev_parse(exp_rev)
for parent in exp_commit.parents:
if parent.hexsha == baseline_sha:
return
raise BaselineMismatchError(baseline_sha)

def _unstash_workspace(self):
# Essentially we want `git stash pop` with `-X ours` merge strategy
# to prefer the applied experiment changes over stashed workspace
# changes. git stash doesn't support merge strategy parameters, but we
# can do it ourselves with checkout/reset.
logger.debug("Unstashing workspace changes.")
self.repo.scm.repo.git.checkout("--ours", "stash@{0}", "--", ".")
self.repo.scm.repo.git.reset("HEAD")
self.repo.scm.repo.git.stash("drop", "stash@{0}")

def checkout(self, *args, **kwargs):
from dvc.repo.experiments.checkout import checkout
Expand Down