From 1b9b4cdacf71c2b56e4971fc24fb6501fcf04280 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 21 Oct 2021 16:59:45 +0800 Subject: [PATCH 1/8] Add experiment name check. 1. Add experiment name check (https://git-scm.com/docs/git-check-ref-format) 2. Add duplicate exp name check. 3. Add some unit test for it. --- dvc/command/experiments.py | 27 ++++++++ tests/unit/command/test_experiments.py | 91 ++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index b2616d8818..dac04057eb 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -594,9 +594,29 @@ def run(self): class CmdExperimentsRun(CmdRepro): + @staticmethod + def _check_ref_format(name: Optional[str]): + import re + + if not name: + return + + invalid_ref_re = ( + r".*(\/\.|\.lock\/|\.\.|[\00-\040\177]" + r"|\^|~|:|\?|\*|\[|\]|\/{2,}|@\{}|\\).*" + r"|^(\.|\/).*|.*(\.lock|\.|\/)$|^@$" + ) + if re.compile(invalid_ref_re).match(name): + raise InvalidArgumentError( + f"Invalid exp name {name}, the exp name must follow rules in " + "https://git-scm.com/docs/git-check-ref-format" + ) + def run(self): from dvc.compare import show_metrics + self._check_ref_format(self.args.name) + if self.args.checkpoint_resume: if self.args.reset: raise InvalidArgumentError( @@ -607,6 +627,13 @@ def run(self): "--rev can only be used in conjunction with " "--queue or --temp." 
) + elif self.args.name: + exps = self.repo.experiments.ls() + for _, exp_list in exps.items(): + if self.args.name in exp_list: + raise InvalidArgumentError( + f"Duplicate experiment name {self.args.name}" + ) if self.args.reset: ui.write("Any existing checkpoints will be reset and re-run.") diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 3b00c47167..bf0cc6d4b4 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -590,3 +590,94 @@ def test_experiments_init_config(dvc, mocker): "plots": "plots", "live": "dvclive", } + + +def test_run_check_ref_format(): + fun = CmdExperimentsRun._check_ref_format + + # They can include slash / for hierarchical (directory) grouping, + # but no slash-separated component can begin with a dot . or end + # with the sequence .lock. + + fun("na.me") + with pytest.raises(InvalidArgumentError): + fun(".name") + with pytest.raises(InvalidArgumentError): + fun("group/.name") + + fun("name.lock1") + with pytest.raises(InvalidArgumentError): + fun("name.lock") + with pytest.raises(InvalidArgumentError): + fun("group/name.lock") + + # They cannot have two consecutive dots .. anywhere. + + with pytest.raises(InvalidArgumentError): + fun("na..me") + + # They cannot have ASCII control characters (i.e. bytes whose values + # are lower than \040, or \177 DEL), space, tilde ~, caret ^, or colon + # : anywhere. + with pytest.raises(InvalidArgumentError): + fun("na\05me") + with pytest.raises(InvalidArgumentError): + fun("na\177me") + with pytest.raises(InvalidArgumentError): + fun("na me") + with pytest.raises(InvalidArgumentError): + fun("na~me") + with pytest.raises(InvalidArgumentError): + fun("na^me") + with pytest.raises(InvalidArgumentError): + fun("na:me") + + # They cannot have question-mark ?, asterisk *, or open bracket [ + # anywhere. See the --refspec-pattern option below for an exception + # to this rule. 
+ + with pytest.raises(InvalidArgumentError): + fun("na?me") + with pytest.raises(InvalidArgumentError): + fun("na*me") + with pytest.raises(InvalidArgumentError): + fun("na[me") + + # They cannot begin or end with a slash / or contain multiple + # consecutive slashes (see the --normalize option below for an + # exception to this rule) + + with pytest.raises(InvalidArgumentError): + fun("/name") + with pytest.raises(InvalidArgumentError): + fun("name/") + with pytest.raises(InvalidArgumentError): + fun("na//me") + + # They cannot end with a dot .. + + with pytest.raises(InvalidArgumentError): + fun("name.") + + # They cannot contain a sequence @{. + + with pytest.raises(InvalidArgumentError): + fun("na@{me.") + + # They cannot be the single character @. + + fun("@name") + with pytest.raises(InvalidArgumentError): + fun("@") + + # They cannot contain a \. + with pytest.raises(InvalidArgumentError): + fun("na\\me") + + +def test_run_duplicate_exp(dvc, scm, mocker): + mocker.patch("dvc.repo.experiments.ls.ls", return_value={"rev": ["name"]}) + + cli_args = parse_args(["experiments", "run", "-n", "name"]) + with pytest.raises(InvalidArgumentError): + CmdExperimentsRun(cli_args).run() From 8c944b6defa5e4ff61eecb8605dcb4b3dd0377bc Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Thu, 21 Oct 2021 18:50:56 +0800 Subject: [PATCH 2/8] Ban slash / in dvc exp names --- dvc/command/experiments.py | 1 + tests/unit/command/test_experiments.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index dac04057eb..2caa6d0411 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -602,6 +602,7 @@ def _check_ref_format(name: Optional[str]): return invalid_ref_re = ( + r".*\/.*|" r".*(\/\.|\.lock\/|\.\.|[\00-\040\177]" r"|\^|~|:|\?|\*|\[|\]|\/{2,}|@\{}|\\).*" r"|^(\.|\/).*|.*(\.lock|\.|\/)$|^@$" diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 
bf0cc6d4b4..521efaafe5 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -595,6 +595,10 @@ def test_experiments_init_config(dvc, mocker): def test_run_check_ref_format(): fun = CmdExperimentsRun._check_ref_format + # Forbid slash / here because we didn't support it for now. + with pytest.raises(InvalidArgumentError): + fun("group/name") + # They can include slash / for hierarchical (directory) grouping, # but no slash-separated component can begin with a dot . or end # with the sequence .lock. From b812d614b205497ed724813605458906b68d172a Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Sun, 24 Oct 2021 10:33:19 +0800 Subject: [PATCH 3/8] Use dulwich backend for ref name checking --- dvc/command/experiments.py | 28 ------- dvc/repo/experiments/__init__.py | 16 ++++ dvc/repo/experiments/run.py | 1 + dvc/repo/experiments/utils.py | 9 ++ dvc/scm/git/__init__.py | 1 + dvc/scm/git/backend/dulwich/__init__.py | 5 ++ tests/func/experiments/test_experiments.py | 18 ++++ tests/unit/command/test_experiments.py | 95 ---------------------- tests/unit/repo/experiments/test_utils.py | 22 ++++- 9 files changed, 71 insertions(+), 124 deletions(-) diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index 2caa6d0411..b2616d8818 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -594,30 +594,9 @@ def run(self): class CmdExperimentsRun(CmdRepro): - @staticmethod - def _check_ref_format(name: Optional[str]): - import re - - if not name: - return - - invalid_ref_re = ( - r".*\/.*|" - r".*(\/\.|\.lock\/|\.\.|[\00-\040\177]" - r"|\^|~|:|\?|\*|\[|\]|\/{2,}|@\{}|\\).*" - r"|^(\.|\/).*|.*(\.lock|\.|\/)$|^@$" - ) - if re.compile(invalid_ref_re).match(name): - raise InvalidArgumentError( - f"Invalid exp name {name}, the exp name must follow rules in " - "https://git-scm.com/docs/git-check-ref-format" - ) - def run(self): from dvc.compare import show_metrics - self._check_ref_format(self.args.name) - if 
self.args.checkpoint_resume: if self.args.reset: raise InvalidArgumentError( @@ -628,13 +607,6 @@ def run(self): "--rev can only be used in conjunction with " "--queue or --temp." ) - elif self.args.name: - exps = self.repo.experiments.ls() - for _, exp_list in exps.items(): - if self.args.name in exp_list: - raise InvalidArgumentError( - f"Duplicate experiment name {self.args.name}" - ) if self.args.reset: ui.write("Any existing checkpoints will be reset and re-run.") diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 30b9a5d55e..1d63d18d3f 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -473,6 +473,18 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False): "\tdvc exp branch \n" ) + def _ref_name_validation(self, name: Optional[str], force: bool): + from .utils import check_ref_format + + if name is None: + return + + baseline_sha = self.repo.scm.get_rev() + exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name) + check_ref_format(self.scm, exp_ref) + if not force and self.scm.get_ref(str(exp_ref)): + raise ExperimentExistsError(name) + @scm_locked def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): """Create a new experiment. 
@@ -485,6 +497,10 @@ def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): *args, resume_rev=checkpoint_resume, **kwargs ) + name = kwargs.get("name", None) + force = kwargs.get("force", False) + self._ref_name_validation(name, force) + return self._stash_exp(*args, **kwargs) def _resume_checkpoint( diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index ea042a2161..a5f3961a71 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -29,6 +29,7 @@ def run( if params: params = loads_param_overrides(params) + return repo.experiments.reproduce_one( targets=targets, params=params, tmp_dir=tmp_dir, **kwargs ) diff --git a/dvc/repo/experiments/utils.py b/dvc/repo/experiments/utils.py index c2a8f9cb05..1c68a2b0a9 100644 --- a/dvc/repo/experiments/utils.py +++ b/dvc/repo/experiments/utils.py @@ -157,3 +157,12 @@ def resolve_exp_ref( msg.extend([f"\t{info}" for info in exp_ref_list]) raise InvalidArgumentError("\n".join(msg)) return exp_ref_list[0] + + +def check_ref_format(scm: "Git", ref: ExpRefInfo): + # "/" forbidden, only in dvc exp as we didn't support it for now. 
+ if not scm.check_ref_format(str(ref)) or "/" in ref.name: + raise InvalidArgumentError( + f"Invalid exp name {ref.name}, the exp name must follow rules in " + "https://git-scm.com/docs/git-check-ref-format" + ) diff --git a/dvc/scm/git/__init__.py b/dvc/scm/git/__init__.py index 5d60ec232f..ba8e76c4ce 100644 --- a/dvc/scm/git/__init__.py +++ b/dvc/scm/git/__init__.py @@ -346,6 +346,7 @@ def get_fs(self, rev: str): status = partialmethod(_backend_func, "status") merge = partialmethod(_backend_func, "merge") validate_git_remote = partialmethod(_backend_func, "validate_git_remote") + check_ref_format = partialmethod(_backend_func, "check_ref_format") def resolve_rev(self, rev: str) -> str: from dvc.repo.experiments.utils import exp_refs_by_name diff --git a/dvc/scm/git/backend/dulwich/__init__.py b/dvc/scm/git/backend/dulwich/__init__.py index 84bfd2f4f7..1f706a8e2d 100644 --- a/dvc/scm/git/backend/dulwich/__init__.py +++ b/dvc/scm/git/backend/dulwich/__init__.py @@ -681,3 +681,8 @@ def validate_git_remote(self, url: str, **kwargs): os.path.join("", path) ): raise InvalidRemoteSCMRepo(url) + + def check_ref_format(self, refname: str): + from dulwich.refs import check_ref_format + + return check_ref_format(refname.encode()) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index cf94871f5b..38c83f4f58 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -53,6 +53,7 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): tmp_dir=not workspace, ) + new_mock = mocker.spy(dvc.experiments, "_stash_exp") with pytest.raises(ExperimentExistsError): dvc.experiments.run( exp_stage.addressing, @@ -61,6 +62,8 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): tmp_dir=not workspace, ) + new_mock.assert_not_called() + results = dvc.experiments.run( exp_stage.addressing, name="foo", @@ -68,6 +71,7 @@ def 
test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): force=True, tmp_dir=not workspace, ) + exp = first(results) fs = scm.get_fs(exp) @@ -685,3 +689,17 @@ def test_exp_run_recursive(tmp_dir, scm, dvc, run_copy_metrics): ) assert dvc.experiments.run(".", recursive=True) assert (tmp_dir / "metric.json").parse() == {"foo": 1} + + +def test_experiment_name_invalid(tmp_dir, scm, dvc, exp_stage, mocker): + from dvc.exceptions import InvalidArgumentError + + new_mock = mocker.spy(dvc.experiments, "_stash_exp") + with pytest.raises(InvalidArgumentError): + dvc.experiments.run( + exp_stage.addressing, + name="fo/o", + params=["foo=3"], + ) + + new_mock.assert_not_called() diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 521efaafe5..3b00c47167 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -590,98 +590,3 @@ def test_experiments_init_config(dvc, mocker): "plots": "plots", "live": "dvclive", } - - -def test_run_check_ref_format(): - fun = CmdExperimentsRun._check_ref_format - - # Forbid slash / here because we didn't support it for now. - with pytest.raises(InvalidArgumentError): - fun("group/name") - - # They can include slash / for hierarchical (directory) grouping, - # but no slash-separated component can begin with a dot . or end - # with the sequence .lock. - - fun("na.me") - with pytest.raises(InvalidArgumentError): - fun(".name") - with pytest.raises(InvalidArgumentError): - fun("group/.name") - - fun("name.lock1") - with pytest.raises(InvalidArgumentError): - fun("name.lock") - with pytest.raises(InvalidArgumentError): - fun("group/name.lock") - - # They cannot have two consecutive dots .. anywhere. - - with pytest.raises(InvalidArgumentError): - fun("na..me") - - # They cannot have ASCII control characters (i.e. bytes whose values - # are lower than \040, or \177 DEL), space, tilde ~, caret ^, or colon - # : anywhere. 
- with pytest.raises(InvalidArgumentError): - fun("na\05me") - with pytest.raises(InvalidArgumentError): - fun("na\177me") - with pytest.raises(InvalidArgumentError): - fun("na me") - with pytest.raises(InvalidArgumentError): - fun("na~me") - with pytest.raises(InvalidArgumentError): - fun("na^me") - with pytest.raises(InvalidArgumentError): - fun("na:me") - - # They cannot have question-mark ?, asterisk *, or open bracket [ - # anywhere. See the --refspec-pattern option below for an exception - # to this rule. - - with pytest.raises(InvalidArgumentError): - fun("na?me") - with pytest.raises(InvalidArgumentError): - fun("na*me") - with pytest.raises(InvalidArgumentError): - fun("na[me") - - # They cannot begin or end with a slash / or contain multiple - # consecutive slashes (see the --normalize option below for an - # exception to this rule) - - with pytest.raises(InvalidArgumentError): - fun("/name") - with pytest.raises(InvalidArgumentError): - fun("name/") - with pytest.raises(InvalidArgumentError): - fun("na//me") - - # They cannot end with a dot .. - - with pytest.raises(InvalidArgumentError): - fun("name.") - - # They cannot contain a sequence @{. - - with pytest.raises(InvalidArgumentError): - fun("na@{me.") - - # They cannot be the single character @. - - fun("@name") - with pytest.raises(InvalidArgumentError): - fun("@") - - # They cannot contain a \. 
- with pytest.raises(InvalidArgumentError): - fun("na\\me") - - -def test_run_duplicate_exp(dvc, scm, mocker): - mocker.patch("dvc.repo.experiments.ls.ls", return_value={"rev": ["name"]}) - - cli_args = parse_args(["experiments", "run", "-n", "name"]) - with pytest.raises(InvalidArgumentError): - CmdExperimentsRun(cli_args).run() diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index b0fd3da808..81f8a3fa4d 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -1,7 +1,8 @@ import pytest +from dvc.exceptions import InvalidArgumentError from dvc.repo.experiments.base import EXPS_NAMESPACE, ExpRefInfo -from dvc.repo.experiments.utils import resolve_exp_ref +from dvc.repo.experiments.utils import check_ref_format, resolve_exp_ref def commit_exp_ref(tmp_dir, scm, file="foo", contents="foo", name="foo"): @@ -25,3 +26,22 @@ def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): remote_ref_info = resolve_exp_ref(scm, "foo" if name_only else ref, remote) assert isinstance(remote_ref_info, ExpRefInfo) assert str(remote_ref_info) == ref + + +def test_run_check_ref_format(scm): + baseline_rev = "b05eecc666734e899f79af228ff49a7ae5a18cc0" + + def fun(name): + ref = ExpRefInfo(baseline_rev, name) + check_ref_format(scm, ref) + + # Forbid slash / here because we didn't support it for now. 
+ with pytest.raises(InvalidArgumentError): + fun("group/name") + + fun("name") + + with pytest.raises(InvalidArgumentError): + fun("na\05me") + with pytest.raises(InvalidArgumentError): + fun("na me") From 715d1998bbccfb851bb15088e623220f698bf987 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Sun, 24 Oct 2021 11:09:11 +0800 Subject: [PATCH 4/8] Some bug fix --- dvc/repo/experiments/__init__.py | 14 ++++++++------ dvc/repo/experiments/run.py | 1 - tests/func/experiments/test_experiments.py | 5 +---- tests/unit/repo/experiments/test_utils.py | 4 ++-- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 1d63d18d3f..19f7891f18 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -473,7 +473,7 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False): "\tdvc exp branch \n" ) - def _ref_name_validation(self, name: Optional[str], force: bool): + def _ref_name_validation(self, name: Optional[str], **kwargs): from .utils import check_ref_format if name is None: @@ -482,7 +482,10 @@ def _ref_name_validation(self, name: Optional[str], force: bool): baseline_sha = self.repo.scm.get_rev() exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name) check_ref_format(self.scm, exp_ref) - if not force and self.scm.get_ref(str(exp_ref)): + + reset = kwargs.get("reset", False) + force = kwargs.get("force", False) + if not (force or reset) and self.scm.get_ref(str(exp_ref)): raise ExperimentExistsError(name) @scm_locked @@ -497,11 +500,10 @@ def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): *args, resume_rev=checkpoint_resume, **kwargs ) - name = kwargs.get("name", None) - force = kwargs.get("force", False) - self._ref_name_validation(name, force) + name = kwargs.pop("name", None) + self._ref_name_validation(name, **kwargs) - return self._stash_exp(*args, **kwargs) + return self._stash_exp(*args, name=name, **kwargs) def 
_resume_checkpoint( self, *args, resume_rev: Optional[str] = None, **kwargs diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index a5f3961a71..ea042a2161 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -29,7 +29,6 @@ def run( if params: params = loads_param_overrides(params) - return repo.experiments.reproduce_one( targets=targets, params=params, tmp_dir=tmp_dir, **kwargs ) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 38c83f4f58..ee44e83753 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -61,7 +61,6 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): params=["foo=3"], tmp_dir=not workspace, ) - new_mock.assert_not_called() results = dvc.experiments.run( @@ -71,7 +70,6 @@ def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace): force=True, tmp_dir=not workspace, ) - exp = first(results) fs = scm.get_fs(exp) @@ -698,8 +696,7 @@ def test_experiment_name_invalid(tmp_dir, scm, dvc, exp_stage, mocker): with pytest.raises(InvalidArgumentError): dvc.experiments.run( exp_stage.addressing, - name="fo/o", + name="fo^o", params=["foo=3"], ) - new_mock.assert_not_called() diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index 81f8a3fa4d..4fce74244c 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -35,12 +35,12 @@ def fun(name): ref = ExpRefInfo(baseline_rev, name) check_ref_format(scm, ref) + fun("name") + # Forbid slash / here because we didn't support it for now. 
with pytest.raises(InvalidArgumentError): fun("group/name") - fun("name") - with pytest.raises(InvalidArgumentError): fun("na\05me") with pytest.raises(InvalidArgumentError): From 01250832ce51c9903311095e9355df599de1b1a3 Mon Sep 17 00:00:00 2001 From: Gao Date: Mon, 25 Oct 2021 11:38:12 +0800 Subject: [PATCH 5/8] Update dvc/repo/experiments/__init__.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Peter Rowlands (변기호) --- dvc/repo/experiments/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 19f7891f18..ee1a4d6814 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -479,7 +479,7 @@ def _ref_name_validation(self, name: Optional[str], **kwargs): if name is None: return - baseline_sha = self.repo.scm.get_rev() + baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev() exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name) check_ref_format(self.scm, exp_ref) From 9038fdad4bcc6b28e23aca4c499023ee5cba4408 Mon Sep 17 00:00:00 2001 From: Gao Date: Mon, 25 Oct 2021 11:38:33 +0800 Subject: [PATCH 6/8] Update dvc/repo/experiments/__init__.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Peter Rowlands (변기호) --- dvc/repo/experiments/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index ee1a4d6814..bc6da14e8c 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -473,7 +473,7 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False): "\tdvc exp branch \n" ) - def _ref_name_validation(self, name: Optional[str], **kwargs): + def _validate_ref_name(self, name: Optional[str], **kwargs): from .utils import check_ref_format if name is None: From 
2120cde5eb205141b92f77074eeb4b6a14d931ab Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 25 Oct 2021 12:32:44 +0800 Subject: [PATCH 7/8] Some review changes --- dvc/repo/experiments/__init__.py | 2 +- tests/unit/repo/experiments/test_utils.py | 24 +++++++++-------------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index bc6da14e8c..0090118d56 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -501,7 +501,7 @@ def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): ) name = kwargs.pop("name", None) - self._ref_name_validation(name, **kwargs) + self._validate_ref_name(name, **kwargs) return self._stash_exp(*args, name=name, **kwargs) diff --git a/tests/unit/repo/experiments/test_utils.py b/tests/unit/repo/experiments/test_utils.py index 4fce74244c..0d5cb8ddf0 100644 --- a/tests/unit/repo/experiments/test_utils.py +++ b/tests/unit/repo/experiments/test_utils.py @@ -28,20 +28,14 @@ def test_resolve_exp_ref(tmp_dir, scm, git_upstream, name_only, use_url): assert str(remote_ref_info) == ref -def test_run_check_ref_format(scm): - baseline_rev = "b05eecc666734e899f79af228ff49a7ae5a18cc0" +@pytest.mark.parametrize( + "name,result", [("name", True), ("group/name", False), ("na me", False)] +) +def test_run_check_ref_format(scm, name, result): - def fun(name): - ref = ExpRefInfo(baseline_rev, name) + ref = ExpRefInfo("abc123", name) + if result: check_ref_format(scm, ref) - - fun("name") - - # Forbid slash / here because we didn't support it for now. 
- with pytest.raises(InvalidArgumentError): - fun("group/name") - - with pytest.raises(InvalidArgumentError): - fun("na\05me") - with pytest.raises(InvalidArgumentError): - fun("na me") + else: + with pytest.raises(InvalidArgumentError): + check_ref_format(scm, ref) From 3f79e7b5a0d36805c62b7018b90dc44cfd3a63c1 Mon Sep 17 00:00:00 2001 From: karajan1001 Date: Mon, 25 Oct 2021 18:02:23 +0800 Subject: [PATCH 8/8] Make some functions more reusable. --- dvc/repo/experiments/__init__.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index 0090118d56..43d1f1db1c 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -473,20 +473,16 @@ def _log_reproduced(self, revs: Iterable[str], tmp_dir: bool = False): "\tdvc exp branch \n" ) - def _validate_ref_name(self, name: Optional[str], **kwargs): + def _validate_new_ref(self, exp_ref: ExpRefInfo): from .utils import check_ref_format - if name is None: + if not exp_ref.name: return - baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev() - exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name) check_ref_format(self.scm, exp_ref) - reset = kwargs.get("reset", False) - force = kwargs.get("force", False) - if not (force or reset) and self.scm.get_ref(str(exp_ref)): - raise ExperimentExistsError(name) + if self.scm.get_ref(str(exp_ref)): + raise ExperimentExistsError(exp_ref.name) @scm_locked def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): @@ -500,10 +496,17 @@ def new(self, *args, checkpoint_resume: Optional[str] = None, **kwargs): *args, resume_rev=checkpoint_resume, **kwargs ) - name = kwargs.pop("name", None) - self._validate_ref_name(name, **kwargs) + name = kwargs.get("name", None) + baseline_sha = kwargs.get("baseline_rev") or self.repo.scm.get_rev() + exp_ref = ExpRefInfo(baseline_sha=baseline_sha, name=name) + + try: + 
self._validate_new_ref(exp_ref) + except ExperimentExistsError as err: + if not (kwargs.get("force", False) or kwargs.get("reset", False)): + raise err - return self._stash_exp(*args, name=name, **kwargs) + return self._stash_exp(*args, **kwargs) def _resume_checkpoint( self, *args, resume_rev: Optional[str] = None, **kwargs