Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import logging
import os
from collections import defaultdict
Expand All @@ -18,6 +19,14 @@ class MissingParamsError(DvcException):
pass


class MissingParamsFile(DvcException):
pass


class ParamsIsADirectoryError(DvcException):
pass


class BadParamFileError(DvcException):
pass

Expand All @@ -27,9 +36,9 @@ class ParamsDependency(Dependency):
PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)}
DEFAULT_PARAMS_FILE = "params.yaml"

def __init__(self, stage, path, params):
def __init__(self, stage, path, params=None, repo=None):
info = {}
self.params = []
self.params = params or []
if params:
if isinstance(params, list):
self.params = params
Expand All @@ -43,6 +52,7 @@ def __init__(self, stage, path, params):
path
or os.path.join(stage.repo.root_dir, self.DEFAULT_PARAMS_FILE),
info=info,
repo=repo,
)

def dumpd(self):
Expand Down Expand Up @@ -88,19 +98,44 @@ def workspace_status(self):
def status(self):
return self.workspace_status()

def _read(self):
def validate_filepath(self):
if not self.exists:
return {}
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), str(self)
)
if self.isdir():
raise IsADirectoryError(
errno.EISDIR, os.strerror(errno.EISDIR), str(self)
)

def read_file(self):
_, ext = os.path.splitext(self.fs_path)
loader = LOADERS[ext]

try:
self.validate_filepath()
except FileNotFoundError as exc:
raise MissingParamsFile(
f"Parameters file '{self}' does not exist"
) from exc
except IsADirectoryError as exc:
raise ParamsIsADirectoryError(
f"'{self}' is a directory, expected a parameters file"
) from exc

suffix = self.repo.fs.path.suffix(self.fs_path).lower()
loader = LOADERS[suffix]
try:
return loader(self.fs_path, fs=self.repo.fs)
except ParseError as exc:
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
) from exc

def _read(self):
try:
return self.read_file()
except MissingParamsFile:
return {}

def read_params_d(self, **kwargs):
config = self._read()

Expand Down
4 changes: 2 additions & 2 deletions dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def __init__(
desc=None,
isexec=False,
remote=None,
repo=None,
):
self.repo = stage.repo if stage else None

self.repo = stage.repo if not repo and stage else repo
fs_cls, fs_config, fs_path = get_cloud_fs(self.repo, url=path)
self.fs = fs_cls(**fs_config)

Expand Down
42 changes: 27 additions & 15 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Callable,
Dict,
Iterable,
List,
Optional,
TextIO,
Tuple,
Expand Down Expand Up @@ -168,26 +167,29 @@ def _check_stage_exists(
)


def loadd_params(path: str) -> Dict[str, List[str]]:
from dvc.utils.serialize import LOADERS

_, ext = os.path.splitext(path)
return {path: list(LOADERS[ext](path))}


def validate_prompts(key: str, value: str) -> Union[Any, Tuple[Any, str]]:
def validate_prompts(
repo: "Repo", key: str, value: str
) -> Union[Any, Tuple[Any, str]]:
from dvc.ui.prompt import InvalidResponse

if key == "params":
import errno

from dvc.dependency.param import ParamsDependency

assert isinstance(value, str)
msg_format = (
"[prompt.invalid]'{0}' {1}. "
"Please retry with an existing parameters file."
)
if not os.path.exists(value):
raise InvalidResponse(msg_format.format(value, "does not exist"))
if os.path.isdir(value):
raise InvalidResponse(msg_format.format(value, "is a directory"))
try:
ParamsDependency(None, value, repo=repo).validate_filepath()
except (IsADirectoryError, FileNotFoundError) as e:
suffices = {
errno.EISDIR: "is a directory",
errno.ENOENT: "does not exist",
}
raise InvalidResponse(msg_format.format(value, suffices[e.errno]))
elif key in ("code", "data"):
if not os.path.exists(value):
return value, (
Expand Down Expand Up @@ -220,7 +222,7 @@ def init(
if interactive:
defaults = init_interactive(
name,
validator=validate_prompts,
validator=partial(validate_prompts, repo),
defaults=defaults,
live=with_live,
provided=overrides,
Expand All @@ -242,7 +244,17 @@ def init(
params_kv = []
params = context.get("params")
if params:
params_kv.append(loadd_params(params))
from dvc.dependency.param import (
MissingParamsFile,
ParamsDependency,
ParamsIsADirectoryError,
)

try:
params_d = ParamsDependency(None, params, repo=repo).read_file()
except (MissingParamsFile, ParamsIsADirectoryError) as exc:
raise DvcException(f"{exc}.") # swallow cause for display
params_kv.append({params: list(params_d.keys())})

checkpoint_out = bool(context.get("live"))
models = context.get("models")
Expand Down