Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ def _workspace_repro(self) -> Mapping[str, str]:
rev,
name=entry.name,
rel_cwd=relpath(os.getcwd(), self.scm.root_dir),
log_errors=False,
)

if not exec_result.exp_hash or not exec_result.ref_info:
Expand Down
68 changes: 47 additions & 21 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ def fetch_exps(
refs.append(ref)

def on_diverged_ref(orig_ref: str, new_rev: str):
orig_rev = dest_scm.get_ref(orig_ref)
if dest_scm.diff(orig_rev, new_rev):
if force:
logger.debug(
"Replacing existing experiment '%s'", orig_ref,
)
return True
if on_diverged:
checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) is not None
on_diverged(orig_ref, checkpoint)
if force:
logger.debug(
"Replacing existing experiment '%s'", orig_ref,
)
return True

checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) is not None
self._raise_ref_conflict(dest_scm, orig_ref, new_rev, checkpoint)
if on_diverged:
on_diverged(orig_ref, checkpoint)
logger.debug("Reproduced existing experiment '%s'", orig_ref)
return False

Expand Down Expand Up @@ -218,6 +218,7 @@ def reproduce(
queue: Optional["Queue"] = None,
rel_cwd: Optional[str] = None,
name: Optional[str] = None,
log_errors: bool = True,
log_level: Optional[int] = None,
) -> "ExecutorResult":
"""Run dvc repro and return the result.
Expand All @@ -234,7 +235,7 @@ def reproduce(

if queue is not None:
queue.put((rev, os.getpid()))
if log_level is not None:
if log_errors and log_level is not None:
cls._set_log_level(log_level)

def filter_pipeline(stages):
Expand All @@ -246,7 +247,7 @@ def filter_pipeline(stages):
exp_ref: Optional["ExpRefInfo"] = None
repro_force: bool = False

with cls._repro_dvc(dvc_dir, rel_cwd) as dvc:
with cls._repro_dvc(dvc_dir, rel_cwd, log_errors) as dvc:
args, kwargs = cls._repro_args(dvc)
if args:
targets: Optional[Union[list, str]] = args[0]
Expand Down Expand Up @@ -321,7 +322,9 @@ def filter_pipeline(stages):

@classmethod
@contextmanager
def _repro_dvc(cls, dvc_dir: Optional[str], rel_cwd: Optional[str]):
def _repro_dvc(
cls, dvc_dir: Optional[str], rel_cwd: Optional[str], log_errors: bool
):
from dvc.repo import Repo

dvc = Repo(dvc_dir)
Expand All @@ -342,10 +345,12 @@ def _repro_dvc(cls, dvc_dir: Optional[str], rel_cwd: Optional[str]):
except CheckpointKilledError:
raise
except DvcException:
logger.exception("")
if log_errors:
logger.exception("")
raise
except Exception:
logger.exception("unexpected error")
if log_errors:
logger.exception("unexpected error")
raise
finally:
dvc.close()
Expand Down Expand Up @@ -396,6 +401,7 @@ def commit(
logger.debug("No changes to commit")
raise UnchangedExperimentError(rev)

check_conflict = False
branch = scm.get_ref(EXEC_BRANCH, follow=False)
if branch:
old_ref = rev
Expand All @@ -406,21 +412,41 @@ def commit(
ref_info = ExpRefInfo(baseline_rev, name)
branch = str(ref_info)
old_ref = None
if not force and scm.get_ref(branch):
if checkpoint:
raise CheckpointExistsError(ref_info.name)
raise ExperimentExistsError(ref_info.name)
logger.debug("Commit to new experiment branch '%s'", branch)
if scm.get_ref(branch):
if not force:
check_conflict = True
logger.debug(
"%s existing experiment branch '%s'",
"Replace" if force else "Reuse",
branch,
)
else:
logger.debug("Commit to new experiment branch '%s'", branch)

scm.add([], update=True)
scm.commit(f"dvc: commit experiment {exp_hash}", no_verify=True)
new_rev = scm.get_rev()
scm.set_ref(branch, new_rev, old_ref=old_ref)
if check_conflict:
new_rev = cls._raise_ref_conflict(scm, branch, new_rev, checkpoint)
else:
scm.set_ref(branch, new_rev, old_ref=old_ref)
scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
if checkpoint:
scm.set_ref(EXEC_CHECKPOINT, new_rev)
return new_rev

@staticmethod
def _raise_ref_conflict(scm, ref, new_rev, checkpoint):
# If this commit is a duplicate of the existing commit at 'ref', return
# the existing commit. Otherwise, error out and require user to re-run
# with --force as needed
orig_rev = scm.get_ref(ref)
if scm.diff(orig_rev, new_rev):
if checkpoint:
raise CheckpointExistsError(ref)
raise ExperimentExistsError(ref)
return orig_rev

@staticmethod
def _set_log_level(level):
from dvc.logger import disable_other_loggers
Expand Down