From 5866e4c5d7a8dc5a844299c0e4579fdf3e83cdd4 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 9 Apr 2020 15:53:53 +0545 Subject: [PATCH 1/2] multistage dvc --- dvc/command/pipeline.py | 7 +- dvc/command/run.py | 2 + dvc/dvcfile.py | 248 ++++++++++++++++++++++++++++++++++++++++ dvc/repo/__init__.py | 113 +++++++++++++++++- dvc/repo/reproduce.py | 10 +- dvc/repo/run.py | 14 ++- dvc/stage.py | 137 ++++++++++++++++++++++ 7 files changed, 518 insertions(+), 13 deletions(-) create mode 100644 dvc/dvcfile.py diff --git a/dvc/command/pipeline.py b/dvc/command/pipeline.py index a265f10e11..baac44561e 100644 --- a/dvc/command/pipeline.py +++ b/dvc/command/pipeline.py @@ -11,10 +11,9 @@ class CmdPipelineShow(CmdBase): def _show(self, target, commands, outs, locked): import networkx - from dvc.stage import Stage - stage = Stage.load(self.repo, target) - G = self.repo.graph + stage = self.repo._get_stage_from(self.repo.pipeline_stages, name=target, path=target) + G = self.repo.pipeline_graph stages = networkx.dfs_postorder_nodes(G, stage) if locked: @@ -36,7 +35,7 @@ def _build_graph(self, target, commands, outs): from dvc.stage import Stage from dvc.repo.graph import get_pipeline - target_stage = Stage.load(self.repo, target) + target_stage = self.repo._get_stage_from(self.repo.pipeline_stages, name=target, path=target) G = get_pipeline(self.repo.pipelines, target_stage) nodes = set() diff --git a/dvc/command/run.py b/dvc/command/run.py index fdac98bcaa..d13d0b7dd5 100644 --- a/dvc/command/run.py +++ b/dvc/command/run.py @@ -53,6 +53,7 @@ def run(self): outs_persist=self.args.outs_persist, outs_persist_no_cache=self.args.outs_persist_no_cache, always_changed=self.args.always_changed, + name=self.args.name, ) except DvcException: logger.exception("failed to run command") @@ -98,6 +99,7 @@ def add_parser(subparsers, parent_parser): default=[], help="Declare dependencies for reproducible cmd.", ) + run_parser.add_argument("-n", "--name", help="Specify name of the stage.") run_parser.add_argument( "-o", "--outs", diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py new file mode 100644 index 0000000000..1e6eb1bdec --- /dev/null +++ b/dvc/dvcfile.py @@ -0,0 +1,248 @@ +import collections +import json +import logging +import os +from typing import TYPE_CHECKING + +from funcy import cached_property + +from dvc import dependency, output +from dvc.utils import relpath, file_md5 + +if TYPE_CHECKING: + from dvc.repo import Repo + from dvc.stage import Stage + +logger = logging.getLogger(__name__) + + +class Dvcfile: + def __init__(self, repo: "Repo", path: str) -> None: + self.path = path + self.repo = repo + self.is_multi_stages = False + + @cached_property + def stages(self): + from dvc.stage import Stage, PipelineStage + from dvc.utils.stage import parse_stage + + fname, tag = Stage._get_path_tag(self.path) + # it raises the proper exceptions by priority: + # 1. when the file doesn't exists + # 2. filename is not a DVC-file + # 3. 
path doesn't represent a regular file + Stage._check_file_exists(self.repo, fname) + Stage._check_dvc_filename(fname) + Stage._check_isfile(self.repo, fname) + + with self.repo.tree.open(fname) as fd: + stage_text = fd.read() + + d = parse_stage(stage_text, fname) + + Stage.validate(d, fname=relpath(fname)) + path = os.path.abspath(fname) + + if not d.get("stages"): + stages_obj = {fname: d} + stage_cls = Stage + else: + # load lockfile and coerce + lock_file = os.path.splitext(fname)[0] + ".lock" + locks = {} + if os.path.exists(lock_file): + with open(lock_file) as fd: + locks = json.load(fd) + + self._coerce_stages_lock_deps(d, locks) + self._coerce_stages_lock_outs(d, locks) + self._coerce_stages_lock_stages(d, locks) + + stages_obj = d.get("stages", []) + self.is_multi_stages = True + stage_cls = PipelineStage + + stages = [] + for name, stage_obj in stages_obj.items(): + stage = stage_cls( + repo=self.repo, + path=path, + wdir=os.path.abspath( + os.path.join( + os.path.dirname(path), d.get(Stage.PARAM_WDIR, ".") + ) + ), + cmd=stage_obj.get(Stage.PARAM_CMD), + md5=stage_obj.get(Stage.PARAM_MD5), + locked=stage_obj.get(Stage.PARAM_LOCKED, False), + tag=tag, + always_changed=stage_obj.get( + Stage.PARAM_ALWAYS_CHANGED, False + ), + # We store stage text to apply updates to the same structure + stage_text=stage_text if not d.get("stages") else None, + ) + if stage_cls == PipelineStage: + stage.name = name + stage.dvcfile = self + + stage.deps = dependency.loadd_from( + stage, stage_obj.get(Stage.PARAM_DEPS) or [] + ) + stage.outs = output.loadd_from( + stage, stage_obj.get(Stage.PARAM_OUTS) or [] + ) + stages.append(stage) + + return stages + + def _coerce_stages_lock_outs(self, stages, locks): + for stage_id, stage in stages["stages"].items(): + stage["outs"] = [ + {"path": item, **locks.get("outs", {}).get(item, {})} + for item in stage.get("outs", []) + ] + + def _coerce_stages_lock_deps(self, stages, locks): + for stage_id, stage in stages["stages"].items(): + stage["deps"] = [ + { + "path": item, + **locks.get("deps", {}).get(stage_id, {}).get(item, {}), + } + for item in stage.get("deps", []) + ] + + def _coerce_stages_lock_stages(self, stages, locks): + for stage_id, stage in stages["stages"].items(): + stage["md5"] = locks.get("stages", {}).get(stage_id, {}).get("md5") + + def dump_multistages(self, stage, path="Dvcfile"): + from dvc.utils.stage import parse_stage_for_update, dump_stage_file + + if not os.path.exists(path): + open(path, "w+").close() + + with open(path, "r") as fd: + data = parse_stage_for_update(fd.read(), path) + + # handle this in Stage::dumpd() + data["stages"] = data.get("stages", {}) + data["stages"][stage.name] = { + "cmd": stage.cmd, + "deps": [dep.def_path for dep in stage.deps], + "outs": [out.def_path for out in stage.outs], + } + + dump_stage_file(path, data) + self.repo.scm.track_file(path) + + def _dump_lockfile(self, stage): + """ + { + "md5": 0, + "deps": { + "1_generator": { + "1.txt": { + "md5": 1 + }, + "2.txt": { + "md5": 2 + }, + "3.txt": { + "md5": 3 + } + } + }, + "outs": { + "1.txt": { + "md5": 4 + }, + "2.txt": { + "md5": 5 + } + }, + "stages": { + "1_generator": { + "md5": 6 + } + } + """ + lockfile = os.path.splitext(stage.path)[0] + ".lock" + + if not os.path.exists(lockfile): + open(lockfile, "w+").close() + + with open(lockfile, "r") as fd: + try: + lock = json.load(fd, object_pairs_hook=collections.OrderedDict) + except json.JSONDecodeError: + lock = collections.OrderedDict() + + print(lock) + lock["md5"] = file_md5(stage.path)[0] 
+ lock["deps"] = lock.get("deps", {}) + lock["outs"] = lock.get("outs", {}) + lock["stages"] = lock.get("stages", {}) + + lock["outs"].update( + { + out.def_path: {out.remote.PARAM_CHECKSUM: out.checksum} + for out in stage.outs + if out.checksum + } + ) + lock["deps"][stage.name] = { + dep.def_path: {dep.remote.PARAM_CHECKSUM: dep.checksum} + for dep in stage.deps + if dep.checksum + } + lock["stages"][stage.name] = {"md5": stage.md5 or stage._compute_md5()} + + with open(lockfile, "w") as fd: + json.dump(lock, fd) + + self.repo.scm.track_file(os.path.relpath(lockfile)) + + def _dump_checkoutstage(self, stage): + from dvc.stage import Stage + + for out in stage.outs: + if not out.use_cache: + continue + + s = Stage( + stage.repo, + # TODO: remove this after dependency graph collection is improved + out.def_path + ".pipeline" + Stage.STAGE_FILE_SUFFIX, + ) + s.outs = [out] + s.md5 = s._compute_md5() + s.dump() + self.repo.scm.track_file(s.path) + + def dump(self, stage): + from dvc.utils.stage import parse_stage_for_update + from dvc.stage import Stage + + fname = stage.path + Stage._check_dvc_filename(fname) + + logger.debug( + "Saving information to '{file}'.".format(file=relpath(fname)) + ) + + if not os.path.exists(fname): + open(stage.path, "w+").close() + + with self.repo.tree.open(fname) as fd: + text = fd.read() + saved_state = parse_stage_for_update(text, fname) + + if saved_state.get("stages") or not ( + saved_state or stage.is_data_source + ): + self.is_multi_stages = True + self._dump_lockfile(stage) + self._dump_checkoutstage(stage) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index e7f2b4dfd3..efa464b7c1 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -172,7 +172,7 @@ def _ignore(self): self.scm.ignore_list(flist) - def check_modified_graph(self, new_stages): + def check_modified_graph(self, new_stages, old_stages=None): """Generate graph including the new stage to check for errors""" # Building graph might be costly for the ones with many DVC-files, # so we provide this undocumented hack to skip it. 
See [1] for @@ -187,7 +187,7 @@ def check_modified_graph(self, new_stages): # # [1] https://github.com/iterative/dvc/issues/2671 if not getattr(self, "_skip_graph_checks", False): - self._collect_graph(self.stages + new_stages) + self._collect_graph((old_stages or self.stages) + new_stages) def collect(self, target, with_deps=False, recursive=False, graph=None): import networkx as nx @@ -211,6 +211,78 @@ def collect(self, target, with_deps=False, recursive=False, graph=None): pipeline = get_pipeline(get_pipelines(graph or self.graph), stage) return list(nx.dfs_postorder_nodes(pipeline, stage)) + def _collect_for_pipelines( + self, target, with_deps=False, recursive=False, graph=None + ): + # TODO: Refactor `collect` + import networkx as nx + from dvc.stage import Stage + + name = target + if not target: + return list(graph) if graph else self.stages + + target = os.path.abspath(target) + + if recursive and os.path.isdir(target): + stages = nx.dfs_postorder_nodes(graph or self.pipeline_graph) + return [stage for stage in stages if path_isin(stage.path, target)] + + stage = self._get_stage_from( + self.pipeline_stages, name=name, path=target + ) + # Optimization: do not collect the graph for a specific target + if not with_deps: + return [stage] + + pipeline = get_pipeline( + get_pipelines(graph or self.pipeline_graph), stage + ) + return list(nx.dfs_postorder_nodes(pipeline, stage)) + + def _get_stage_from( + self, stages=None, path=None, name=None, priority="name" + ): + # HACK: Split this into two: one that reloads a given stage and returns + # other one that can return a stage from a given name or path from + # `stages`. + # ?: Make `pipeline_stages` {name: value} pair? + stages = stages or [] + assert priority in ("name", "path") + # prioritize path, then use target + found = None + for s in stages: + if name and getattr(s, "name", None) == name: + found = s + if priority == "name": + break + + if path and s.path == path: + found = s + if priority == "path": + break + + from dvc.dvcfile import Dvcfile + + if found: + dvcfile = Dvcfile(self, found.path) + stages = dvcfile.stages + if dvcfile.is_multi_stages: + return stages[0] + for st in stages: + if getattr(st, "name", None) == name: + if priority == "name": + return st + found = st + + for st in stages: + if st.path == path: + if priority == "path": + return st + found = st + + return found + def collect_granular(self, target, *args, **kwargs): from dvc.stage import Stage @@ -396,7 +468,42 @@ def graph(self): @cached_property def pipelines(self): - return get_pipelines(self.graph) + return get_pipelines(self.pipeline_graph) + + @cached_property + def pipeline_stages(self): + # Remove code duplication + # It's okay to do it for each `stages` and `pipeline_stages`. + # Because, only one of them will be used at a given time? 
+ from dvc.dvcfile import Dvcfile + from dvc.stage import Stage + + stages = [] + outs = set() + + for root, dirs, files in self.tree.walk(self.root_dir): + for fname in files: + path = os.path.join(root, fname) + if not Stage.is_valid_filename(path) or path.endswith( + ".pipeline" + Stage.STAGE_FILE_SUFFIX + ): + continue + dvcfile = Dvcfile(self, path) + stgs = dvcfile.stages + stages.extend(stgs) + + for stage in stgs: + for out in stage.outs: + if out.scheme == "local": + outs.add(out.fspath) + + dirs[:] = [d for d in dirs if os.path.join(root, d) not in outs] + + return stages + + @cached_property + def pipeline_graph(self): + return self._collect_graph(self.pipeline_stages) @cached_property def stages(self): diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index b14245822b..24049a54aa 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -69,14 +69,16 @@ def reproduce( if not interactive: kwargs["interactive"] = self.config["core"].get("interactive", False) - active_graph = _get_active_graph(self.graph) + active_graph = _get_active_graph(self.pipeline_graph) active_pipelines = get_pipelines(active_graph) if pipeline or all_pipelines: if all_pipelines: pipelines = active_pipelines else: - stage = Stage.load(self, target) + stage = self._get_stage_from( + self.pipeline_stages, name=target, path=target + ) pipelines = [get_pipeline(active_pipelines, stage)] targets = [] @@ -85,7 +87,9 @@ def reproduce( if pipeline.in_degree(stage) == 0: targets.append(stage) else: - targets = self.collect(target, recursive=recursive, graph=active_graph) + targets = self._collect_for_pipelines( + target, recursive=recursive, graph=active_graph + ) ret = [] for target in targets: diff --git a/dvc/repo/run.py b/dvc/repo/run.py index e5d62ec872..cfd82ae570 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -5,18 +5,26 @@ @locked @scm_context def run(self, no_exec=False, **kwargs): - from dvc.stage import Stage + from dvc.stage import PipelineStage, Stage - stage = Stage.create(self, **kwargs) + stage_cls = PipelineStage if kwargs.get("name") else Stage + stage = stage_cls.create(self, **kwargs) + if stage_cls == PipelineStage: + stage.name = kwargs["name"] if stage is None: return None - self.check_modified_graph([stage]) + # TODO: check if the stage with given name already exists, don't allow that + self.check_modified_graph([stage], self.pipeline_stages) + self.pipeline_stages.append(stage) if not no_exec: stage.run(no_commit=kwargs.get("no_commit", False)) + if stage_cls == PipelineStage: + stage.dvcfile.dump_multistages(stage, stage.path) + stage.dump() return stage diff --git a/dvc/stage.py b/dvc/stage.py index 14de51501a..5ef9a9ae73 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -193,6 +193,8 @@ class Stage(object): PARAM_LOCKED: bool, PARAM_META: object, PARAM_ALWAYS_CHANGED: bool, + # TODO: Use separate schema? 
+ "stages": object, } COMPILED_SCHEMA = Schema(SCHEMA) @@ -1066,3 +1068,138 @@ def get_used_cache(self, *args, **kwargs): cache.update(out.get_used_cache(*args, **kwargs)) return cache + + +class PipelineStage(Stage): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._dvcfile = None + self._name = None + + def __hash__(self): + return hash(self.name) + + def __repr__(self): + return super().__repr__() + "({name})".format(name=self.name) + + @property + def dvcfile(self): + if self._dvcfile and self._dvcfile.path == self.path: + return self._dvcfile + + from dvc.dvcfile import Dvcfile + + self._dvcfile = Dvcfile(self.repo, self.path) + return self._dvcfile + + @dvcfile.setter + def dvcfile(self, dvcfile): + self._dvcfile = dvcfile + + @property + def name(self): + return self._name + + @name.setter + def name(self, name): + self._name = name + + def dump(self): + self.dvcfile.dump(self) + + @staticmethod + def create(repo, accompany_outs=False, **kwargs): + + wdir = kwargs.get("wdir", None) + cwd = kwargs.get("cwd", None) + fname = kwargs.get("fname", None) + + # Backward compatibility for `cwd` option + if wdir is None and cwd is not None: + if fname is not None and os.path.basename(fname) != fname: + raise StageFileBadNameError( + "DVC-file name '{fname}' may not contain subdirectories" + " if `-c|--cwd` (deprecated) is specified. Use `-w|--wdir`" + " along with `-f` to specify DVC-file path with working" + " directory.".format(fname=fname) + ) + wdir = cwd + elif wdir is None: + wdir = os.curdir + + stage = PipelineStage( + repo=repo, + wdir=wdir, + cmd=kwargs.get("cmd", None), + locked=kwargs.get("locked", False), + always_changed=kwargs.get("always_changed", False), + ) + + Stage._fill_stage_outputs(stage, **kwargs) + deps = dependency.loads_from( + stage, kwargs.get("deps", []), erepo=kwargs.get("erepo", None) + ) + params = dependency.loads_params(stage, kwargs.get("params", [])) + stage.deps = deps + params + + stage._check_circular_dependency() + stage._check_duplicated_arguments() + + from dvc.dvcfile import Dvcfile + + fname = fname or Stage.STAGE_FILE + if os.path.exists(fname): + dvcfile = Dvcfile(repo, fname) + stages = dvcfile.stages + if not dvcfile.is_multi_stages: + raise DvcException( + "%s already exists and is a multistage file", fname + ) + + stage._check_dvc_filename(fname) + + # Autodetecting wdir for add, we need to create outs first to do that, + # so we start with wdir = . and remap out paths later. + if accompany_outs and kwargs.get("wdir") is None and cwd is None: + wdir = os.path.dirname(fname) + + for out in chain(stage.outs, stage.deps): + if out.is_in_repo: + out.def_path = relpath(out.path_info, wdir) + + wdir = os.path.abspath(wdir) + + if cwd is not None: + path = os.path.join(wdir, fname) + else: + path = os.path.abspath(fname) + + Stage._check_stage_path(repo, wdir, is_wdir=kwargs.get("wdir")) + Stage._check_stage_path(repo, os.path.dirname(path)) + + stage.wdir = wdir + stage.path = path + + ignore_build_cache = kwargs.get("ignore_build_cache", False) + + # NOTE: remove outs before we check build cache + if kwargs.get("remove_outs", False): + logger.warning( + "--remove-outs is deprecated." + " It is now the default behavior," + " so there's no need to use this option anymore." 
+ ) + stage.remove_outs(ignore_remove=False) + logger.warning("Build cache is ignored when using --remove-outs.") + ignore_build_cache = True + + if os.path.exists(path) and any(out.persist for out in stage.outs): + logger.warning("Build cache is ignored when persisting outputs.") + ignore_build_cache = True + + # TODO: check if the stage is already cached and unchanged. + return stage + + def changed_md5(self): + # TODO: Use build cache to determine if things changed via `changed()` + return False From 8b50c39d0524591f5640678b1c399e683361216b Mon Sep 17 00:00:00 2001 From: "Restyled.io" Date: Thu, 9 Apr 2020 10:09:24 +0000 Subject: [PATCH 2/2] Restyled by black --- dvc/command/pipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dvc/command/pipeline.py b/dvc/command/pipeline.py index baac44561e..0b0fcd79d0 100644 --- a/dvc/command/pipeline.py +++ b/dvc/command/pipeline.py @@ -12,7 +12,9 @@ class CmdPipelineShow(CmdBase): def _show(self, target, commands, outs, locked): import networkx - stage = self.repo._get_stage_from(self.repo.pipeline_stages, name=target, path=target) + stage = self.repo._get_stage_from( + self.repo.pipeline_stages, name=target, path=target + ) G = self.repo.pipeline_graph stages = networkx.dfs_postorder_nodes(G, stage) @@ -35,7 +37,9 @@ def _build_graph(self, target, commands, outs): from dvc.stage import Stage from dvc.repo.graph import get_pipeline - target_stage = self.repo._get_stage_from(self.repo.pipeline_stages, name=target, path=target) + target_stage = self.repo._get_stage_from( + self.repo.pipeline_stages, name=target, path=target + ) G = get_pipeline(self.repo.pipelines, target_stage) nodes = set()