diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 345dcf807b..294af32da3 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -723,6 +723,7 @@ def _workspace_repro(self) -> Mapping[str, str]: rev, name=entry.name, rel_cwd=relpath(os.getcwd(), self.scm.root_dir), + log_errors=False, ) if not exec_result.exp_hash or not exec_result.ref_info: diff --git a/dvc/repo/experiments/executor/base.py b/dvc/repo/experiments/executor/base.py index ce5cd920ea..e7aee41266 100644 --- a/dvc/repo/experiments/executor/base.py +++ b/dvc/repo/experiments/executor/base.py @@ -181,16 +181,16 @@ def fetch_exps( refs.append(ref) def on_diverged_ref(orig_ref: str, new_rev: str): - orig_rev = dest_scm.get_ref(orig_ref) - if dest_scm.diff(orig_rev, new_rev): - if force: - logger.debug( - "Replacing existing experiment '%s'", orig_ref, - ) - return True - if on_diverged: - checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) is not None - on_diverged(orig_ref, checkpoint) + if force: + logger.debug( + "Replacing existing experiment '%s'", orig_ref, + ) + return True + + checkpoint = self.scm.get_ref(EXEC_CHECKPOINT) is not None + self._raise_ref_conflict(dest_scm, orig_ref, new_rev, checkpoint) + if on_diverged: + on_diverged(orig_ref, checkpoint) logger.debug("Reproduced existing experiment '%s'", orig_ref) return False @@ -218,6 +218,7 @@ def reproduce( queue: Optional["Queue"] = None, rel_cwd: Optional[str] = None, name: Optional[str] = None, + log_errors: bool = True, log_level: Optional[int] = None, ) -> "ExecutorResult": """Run dvc repro and return the result. @@ -234,7 +235,7 @@ def reproduce( if queue is not None: queue.put((rev, os.getpid())) - if log_level is not None: + if log_errors and log_level is not None: cls._set_log_level(log_level) def filter_pipeline(stages): @@ -246,7 +247,7 @@ def filter_pipeline(stages): exp_ref: Optional["ExpRefInfo"] = None repro_force: bool = False - with cls._repro_dvc(dvc_dir, rel_cwd) as dvc: + with cls._repro_dvc(dvc_dir, rel_cwd, log_errors) as dvc: args, kwargs = cls._repro_args(dvc) if args: targets: Optional[Union[list, str]] = args[0] @@ -321,7 +322,9 @@ def filter_pipeline(stages): @classmethod @contextmanager - def _repro_dvc(cls, dvc_dir: Optional[str], rel_cwd: Optional[str]): + def _repro_dvc( + cls, dvc_dir: Optional[str], rel_cwd: Optional[str], log_errors: bool + ): from dvc.repo import Repo dvc = Repo(dvc_dir) @@ -342,10 +345,12 @@ def _repro_dvc(cls, dvc_dir: Optional[str], rel_cwd: Optional[str]): except CheckpointKilledError: raise except DvcException: - logger.exception("") + if log_errors: + logger.exception("") raise except Exception: - logger.exception("unexpected error") + if log_errors: + logger.exception("unexpected error") raise finally: dvc.close() @@ -396,6 +401,7 @@ def commit( logger.debug("No changes to commit") raise UnchangedExperimentError(rev) + check_conflict = False branch = scm.get_ref(EXEC_BRANCH, follow=False) if branch: old_ref = rev @@ -406,21 +412,41 @@ def commit( ref_info = ExpRefInfo(baseline_rev, name) branch = str(ref_info) old_ref = None - if not force and scm.get_ref(branch): - if checkpoint: - raise CheckpointExistsError(ref_info.name) - raise ExperimentExistsError(ref_info.name) - logger.debug("Commit to new experiment branch '%s'", branch) + if scm.get_ref(branch): + if not force: + check_conflict = True + logger.debug( + "%s existing experiment branch '%s'", + "Replace" if force else "Reuse", + branch, + ) + else: + logger.debug("Commit to new experiment branch '%s'", branch) scm.add([], update=True) scm.commit(f"dvc: commit experiment {exp_hash}", no_verify=True) new_rev = scm.get_rev() - scm.set_ref(branch, new_rev, old_ref=old_ref) + if check_conflict: + new_rev = cls._raise_ref_conflict(scm, branch, new_rev, checkpoint) + else: + scm.set_ref(branch, new_rev, old_ref=old_ref) scm.set_ref(EXEC_BRANCH, branch, symbolic=True) if checkpoint: scm.set_ref(EXEC_CHECKPOINT, new_rev) return new_rev + @staticmethod + def _raise_ref_conflict(scm, ref, new_rev, checkpoint): + # If this commit is a duplicate of the existing commit at 'ref', return + # the existing commit. Otherwise, error out and require user to re-run + # with --force as needed + orig_rev = scm.get_ref(ref) + if scm.diff(orig_rev, new_rev): + if checkpoint: + raise CheckpointExistsError(ref) + raise ExperimentExistsError(ref) + return orig_rev + @staticmethod def _set_log_level(level): from dvc.logger import disable_other_loggers