diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 24bd4e302d..009a2b9982 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -437,15 +437,21 @@ def reproduce_queued(self, **kwargs): @scm_locked def new( - self, *args, branch: Optional[str] = None, **kwargs, + self, + *args, + branch: Optional[str] = None, + checkpoint_resume: Optional[str] = None, + **kwargs, ): """Create a new experiment. Experiment will be reproduced and checked out into the user's workspace. """ - if kwargs.get("checkpoint_resume", None) is not None: - return self._resume_checkpoint(*args, **kwargs) + if checkpoint_resume is not None: + return self._resume_checkpoint( + *args, **kwargs, checkpoint_resume=checkpoint_resume + ) if branch: rev = self.scm.resolve_rev(branch) diff --git a/tests/func/experiments/test_checkpoints.py b/tests/func/experiments/test_checkpoints.py index d214c013d9..2aed4e07dc 100644 --- a/tests/func/experiments/test_checkpoints.py +++ b/tests/func/experiments/test_checkpoints.py @@ -24,18 +24,17 @@ def test_new_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker): ).read_text().strip() == "foo: 2" -@pytest.mark.parametrize("last", [True, False]) -def test_resume_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, last): +@pytest.mark.parametrize( + "checkpoint_resume", [Experiments.LAST_CHECKPOINT, "foo"] +) +def test_resume_checkpoint( + tmp_dir, scm, dvc, checkpoint_stage, checkpoint_resume +): with pytest.raises(DvcException): - if last: - dvc.experiments.run( - checkpoint_stage.addressing, - checkpoint_resume=Experiments.LAST_CHECKPOINT, - ) - else: - dvc.experiments.run( - checkpoint_stage.addressing, checkpoint_resume="foo" - ) + dvc.experiments.run( + checkpoint_stage=checkpoint_stage.addressing, + checkpoint_resume=checkpoint_resume, + ) results = dvc.experiments.run( checkpoint_stage.addressing, params=["foo=2"] @@ -46,12 +45,12 @@ def test_resume_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, last): checkpoint_stage.addressing, checkpoint_resume="abc1234", ) - if last: - exp_rev = Experiments.LAST_CHECKPOINT - else: - exp_rev = first(results) + if checkpoint_resume != Experiments.LAST_CHECKPOINT: + checkpoint_resume = first(results) - dvc.experiments.run(checkpoint_stage.addressing, checkpoint_resume=exp_rev) + dvc.experiments.run( + checkpoint_stage.addressing, checkpoint_resume=checkpoint_resume + ) assert (tmp_dir / "foo").read_text() == "10" assert (