diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 311e8a82a2..f35a7e0b57 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -1,3 +1,4 @@ +import errno import logging import os from collections import defaultdict @@ -18,6 +19,14 @@ class MissingParamsError(DvcException): pass +class MissingParamsFile(DvcException): + pass + + +class ParamsIsADirectoryError(DvcException): + pass + + class BadParamFileError(DvcException): pass @@ -27,9 +36,9 @@ class ParamsDependency(Dependency): PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)} DEFAULT_PARAMS_FILE = "params.yaml" - def __init__(self, stage, path, params): + def __init__(self, stage, path, params=None, repo=None): info = {} - self.params = [] + self.params = params or [] if params: if isinstance(params, list): self.params = params @@ -43,6 +52,7 @@ def __init__(self, stage, path, params): path or os.path.join(stage.repo.root_dir, self.DEFAULT_PARAMS_FILE), info=info, + repo=repo, ) def dumpd(self): @@ -88,12 +98,31 @@ def workspace_status(self): def status(self): return self.workspace_status() - def _read(self): + def validate_filepath(self): if not self.exists: - return {} + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), str(self) + ) + if self.isdir(): + raise IsADirectoryError( + errno.EISDIR, os.strerror(errno.EISDIR), str(self) + ) + + def read_file(self): + _, ext = os.path.splitext(self.fs_path) + loader = LOADERS[ext] + + try: + self.validate_filepath() + except FileNotFoundError as exc: + raise MissingParamsFile( + f"Parameters file '{self}' does not exist" + ) from exc + except IsADirectoryError as exc: + raise ParamsIsADirectoryError( + f"'{self}' is a directory, expected a parameters file" + ) from exc - suffix = self.repo.fs.path.suffix(self.fs_path).lower() - loader = LOADERS[suffix] try: return loader(self.fs_path, fs=self.repo.fs) except ParseError as exc: @@ -101,6 +130,12 @@ def _read(self): f"Unable to read parameters from '{self}'" ) from exc + def _read(self): + try: + return self.read_file() + except MissingParamsFile: + return {} + def read_params_d(self, **kwargs): config = self._read() diff --git a/dvc/output.py b/dvc/output.py index bc3c0964c8..87bd4fa477 100644 --- a/dvc/output.py +++ b/dvc/output.py @@ -291,9 +291,9 @@ def __init__( desc=None, isexec=False, remote=None, + repo=None, ): - self.repo = stage.repo if stage else None - + self.repo = stage.repo if not repo and stage else repo fs_cls, fs_config, fs_path = get_cloud_fs(self.repo, url=path) self.fs = fs_cls(**fs_config) diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index 613906b1ae..f347ba345a 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -8,7 +8,6 @@ Callable, Dict, Iterable, - List, Optional, TextIO, Tuple, @@ -168,26 +167,29 @@ def _check_stage_exists( ) -def loadd_params(path: str) -> Dict[str, List[str]]: - from dvc.utils.serialize import LOADERS - - _, ext = os.path.splitext(path) - return {path: list(LOADERS[ext](path))} - - -def validate_prompts(key: str, value: str) -> Union[Any, Tuple[Any, str]]: +def validate_prompts( + repo: "Repo", key: str, value: str +) -> Union[Any, Tuple[Any, str]]: from dvc.ui.prompt import InvalidResponse if key == "params": + import errno + + from dvc.dependency.param import ParamsDependency + assert isinstance(value, str) msg_format = ( "[prompt.invalid]'{0}' {1}. " "Please retry with an existing parameters file." ) - if not os.path.exists(value): - raise InvalidResponse(msg_format.format(value, "does not exist")) - if os.path.isdir(value): - raise InvalidResponse(msg_format.format(value, "is a directory")) + try: + ParamsDependency(None, value, repo=repo).validate_filepath() + except (IsADirectoryError, FileNotFoundError) as e: + suffices = { + errno.EISDIR: "is a directory", + errno.ENOENT: "does not exist", + } + raise InvalidResponse(msg_format.format(value, suffices[e.errno])) elif key in ("code", "data"): if not os.path.exists(value): return value, ( @@ -220,7 +222,7 @@ def init( if interactive: defaults = init_interactive( name, - validator=validate_prompts, + validator=partial(validate_prompts, repo), defaults=defaults, live=with_live, provided=overrides, @@ -242,7 +244,17 @@ def init( params_kv = [] params = context.get("params") if params: - params_kv.append(loadd_params(params)) + from dvc.dependency.param import ( + MissingParamsFile, + ParamsDependency, + ParamsIsADirectoryError, + ) + + try: + params_d = ParamsDependency(None, params, repo=repo).read_file() + except (MissingParamsFile, ParamsIsADirectoryError) as exc: + raise DvcException(f"{exc}.") # swallow cause for display + params_kv.append({params: list(params_d.keys())}) checkpoint_out = bool(context.get("live")) models = context.get("models")