From fb8f2e19d58886244289ea0ba2054ca0ca605957 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Fri, 17 Apr 2020 22:41:24 +0545 Subject: [PATCH 1/3] refactor: extract stage loading outside of dvcfile --- dvc/command/pipeline.py | 4 +- dvc/dvcfile.py | 297 +++++++--------------------- dvc/loader.py | 132 +++++++++++++ dvc/lockfile.py | 35 +--- dvc/output/base.py | 4 + dvc/repo/__init__.py | 25 +-- dvc/repo/add.py | 2 +- dvc/repo/imp_url.py | 2 +- dvc/repo/lock.py | 2 +- dvc/repo/remove.py | 2 +- dvc/repo/reproduce.py | 5 +- dvc/repo/run.py | 2 +- dvc/repo/update.py | 3 +- dvc/serialize.py | 89 +++++++++ dvc/stage/__init__.py | 7 +- tests/func/test_checkout.py | 2 +- tests/func/test_dvcfile.py | 33 ++-- tests/func/test_import.py | 4 +- tests/func/test_repro.py | 4 +- tests/func/test_repro_multistage.py | 32 ++- tests/func/test_run_multistage.py | 2 +- tests/func/test_stage.py | 32 +-- tests/func/test_update.py | 4 +- tests/unit/test_lockfile.py | 13 +- 24 files changed, 387 insertions(+), 350 deletions(-) create mode 100644 dvc/loader.py create mode 100644 dvc/serialize.py diff --git a/dvc/command/pipeline.py b/dvc/command/pipeline.py index 9f8058a113..a8d3d0008e 100644 --- a/dvc/command/pipeline.py +++ b/dvc/command/pipeline.py @@ -25,7 +25,7 @@ def _show(self, target, commands, outs, locked): from dvc.utils import parse_target path, name = parse_target(target) - stage = dvcfile.Dvcfile(self.repo, path).load_one(name) + stage = dvcfile.Dvcfile(self.repo, path).stages[name] G = self.repo.pipeline_graph stages = networkx.dfs_postorder_nodes(G, stage) if locked: @@ -49,7 +49,7 @@ def _build_graph(self, target, commands=False, outs=False): from dvc.utils import parse_target path, name = parse_target(target) - target_stage = dvcfile.Dvcfile(self.repo, path).load_one(name) + target_stage = dvcfile.Dvcfile(self.repo, path).stages[name] G = get_pipeline(self.repo.pipelines, target_stage) nodes = set() diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index dd7cb94430..cd8d2c3cb2 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -2,14 +2,16 @@ import re import logging -from funcy import project - import dvc.prompt as prompt from voluptuous import MultipleInvalid - -from dvc import dependency, output +from dvc import serialize from dvc.exceptions import DvcException +from dvc.loader import SingleStageLoader, StageLoader +from dvc.schema import ( + COMPILED_SINGLE_STAGE_SCHEMA, + COMPILED_MULTI_STAGE_SCHEMA, +) from dvc.stage.exceptions import ( StageFileBadNameError, StageFileDoesNotExistError, @@ -18,11 +20,10 @@ StageFileAlreadyExistsError, ) from dvc.utils import relpath -from dvc.utils.collections import apply_diff from dvc.utils.stage import ( - parse_stage_for_update, dump_stage_file, parse_stage, + parse_stage_for_update, ) logger = logging.getLogger(__name__) @@ -33,36 +34,8 @@ class MultiStageFileLoadError(DvcException): - def __init__(self): - super().__init__("Cannot load multi-stage file.") - - -def _serialize_stage(stage): - outs_bucket = {} - for o in stage.outs: - bucket_key = ["metrics"] if o.metric else ["outs"] - - if not o.metric and o.persist: - bucket_key += ["persist"] - if not o.use_cache: - bucket_key += ["no_cache"] - key = "_".join(bucket_key) - outs_bucket[key] = outs_bucket.get(key, []) + [o.def_path] - - return { - stage.name: { - key: value - for key, value in { - stage.PARAM_CMD: stage.cmd, - stage.PARAM_WDIR: stage.resolve_wdir(), - stage.PARAM_DEPS: [d.def_path for d in stage.deps], - **outs_bucket, - stage.PARAM_LOCKED: stage.locked, - stage.PARAM_ALWAYS_CHANGED: stage.always_changed, - }.items() - if value - } - } + def __init__(self, file): + super().__init__("Cannot load multi-stage file: '{}'".format(file)) class Dvcfile: @@ -71,7 +44,37 @@ def __init__(self, repo, path): self.path, self.tag = self._get_path_tag(path) def __repr__(self): - return "{}: {}".format(DVC_FILE, self.path) + return "{}: {}".format( + DVC_FILE, relpath(self.path, self.repo.root_dir) + ) + + def __str__(self): + return "{}: {}".format(DVC_FILE, self.relpath) + + @property + def relpath(self): + return relpath(self.path) + + @property + def stage(self): + data, raw = self._load() + if not self.is_multi_stage(data): + return SingleStageLoader.load_stage(self, data, raw) + raise MultiStageFileLoadError(self.path) + + @property + def lockfile(self): + return os.path.splitext(self.path)[0] + ".lock" + + @property + def stages(self): + from . import lockfile + + data, raw = self._load() + if self.is_multi_stage(data): + lockfile_data = lockfile.load(self.repo, self.lockfile) + return StageLoader(self, data.get("stages", {}), lockfile_data) + return SingleStageLoader(self, data, raw) @classmethod def is_valid_filename(cls, path): @@ -113,13 +116,11 @@ def _get_path_tag(s): return s, None return match.group("path"), match.group("tag") - @property - def lockfile(self): - return os.path.splitext(self.path)[0] + ".lock" - def dump(self, stage, update_dvcfile=False): """Dumps given stage appropriately in the dvcfile.""" - if not hasattr(stage, "name"): + from dvc.stage import create_stage, PipelineStage, Stage + + if not isinstance(stage, PipelineStage): self.dump_single_stage(stage) return @@ -127,11 +128,7 @@ def dump(self, stage, update_dvcfile=False): if update_dvcfile and not stage.is_data_source: self.dump_multistage_dvcfile(stage) - from .stage import Stage, create_stage - - for out in stage.outs: - if not out.use_cache: - continue + for out in filter(lambda o: o.use_cache, stage.outs): s = create_stage( Stage, stage.repo, @@ -145,29 +142,24 @@ def dump(self, stage, update_dvcfile=False): def dump_lockfile(self, stage): from . import lockfile - lockfile.dump(self.repo, self.lockfile, stage) + lockfile.dump(self.repo, self.lockfile, serialize.to_lockfile(stage)) self.repo.scm.track_file(relpath(self.lockfile)) def dump_multistage_dvcfile(self, stage): - from dvc.utils.stage import parse_stage_for_update, dump_stage_file - from dvc.schema import COMPILED_MULTI_STAGE_SCHEMA - - path = self.path - if not os.path.exists(path): - open(path, "w+").close() + data = {} + if self.exists(): + with open(self.path, "r") as fd: + data = parse_stage_for_update(fd.read(), self.path) + if not self.is_multi_stage(data): + raise MultiStageFileLoadError(self.path) + else: + open(self.path, "w+").close() - with open(path, "r") as fd: - data = parse_stage_for_update(fd.read(), path) - - if not self.is_multi_stage(data): - raise MultiStageFileLoadError - - # handle this in Stage::dumpd() data["stages"] = data.get("stages", {}) - data["stages"].update(_serialize_stage(stage)) + data["stages"].update(serialize.to_dvcfile(stage)) - dump_stage_file(path, COMPILED_MULTI_STAGE_SCHEMA(data)) - self.repo.scm.track_file(relpath(path)) + dump_stage_file(self.path, COMPILED_MULTI_STAGE_SCHEMA(data)) + self.repo.scm.track_file(relpath(self.path)) def dump_single_stage(self, stage): self.check_dvc_filename(self.path) @@ -175,25 +167,8 @@ def dump_single_stage(self, stage): logger.debug( "Saving information to '{file}'.".format(file=relpath(self.path)) ) - state = stage.dumpd() - - # When we load a stage we parse yaml with a fast parser, which strips - # off all the comments and formatting. To retain those on update we do - # a trick here: - # - reparse the same yaml text with a slow but smart ruamel yaml parser - # - apply changes to a returned structure - # - serialize it - if stage._stage_text is not None: - saved_state = parse_stage_for_update(stage._stage_text, self.path) - # Stage doesn't work with meta in any way, so .dumpd() doesn't - # have it. We simply copy it over. - if "meta" in saved_state: - state["meta"] = saved_state["meta"] - apply_diff(state, saved_state) - state = saved_state - - dump_stage_file(self.path, state) + dump_stage_file(self.path, serialize.to_single_stage_file(stage)) self.repo.scm.track_file(relpath(self.path)) def _load(self): @@ -210,169 +185,33 @@ def _load(self): d = parse_stage(stage_text, self.path) return d, stage_text - def load_one(self, target=None): - data, raw = self._load() - if not self.is_multi_stage(data): - if target: - logger.warning( - "Ignoring target name '%s' as it's a single stage file.", - target, - ) - return self._load_single_stage(data, raw) - - if not target: - raise DvcException( - "No target provided for multi-stage file '{}'.".format( - self.path - ) - ) - - if not self.has_stage(name=target, data=data): - raise DvcException( - "Target '{}' does not exist " - "inside '{}' multi-stage file.".format(target, self.path) - ) - - stages = self._load_multi_stage( - {"stages": {target: self._get_stage_data(target, data)}} - ) - assert stages - return stages[0] - - @staticmethod - def _get_stage_data(name, data): - return data.get("stages", {}).get(name) - - def has_stage(self, name, data=None): - if not data: - data, _ = self._load() - return bool(self._get_stage_data(name, data)) - - def load(self): - """Loads single stage.""" - data, raw = self._load() - if not self.is_multi_stage(data): - return self._load_single_stage(data, raw) - - raise MultiStageFileLoadError - - def load_all(self): - data, raw = self._load() - return ( - [self._load_single_stage(data, raw)] - if not self.is_multi_stage(data) - else self._load_multi_stage(data) - ) - - def _load_single_stage(self, d: dict, stage_text: str): - from dvc.stage import Stage, loads_from - - path = os.path.abspath(self.path) - wdir = os.path.abspath( - os.path.join(os.path.dirname(path), d.get(Stage.PARAM_WDIR, ".")) - ) - stage = loads_from(Stage, self.repo, path, wdir, d) - stage._stage_text, stage.tag = stage_text, self.tag - stage.deps = dependency.loadd_from( - stage, d.get(Stage.PARAM_DEPS) or [] - ) - stage.outs = output.loadd_from(stage, d.get(Stage.PARAM_OUTS) or []) - - return stage - - def load_multi(self): - data, _ = self._load() - if self.is_multi_stage(data): - return self._load_multi_stage(data) - raise DvcException( - "Cannot load multiple stages from single stage file." - ) - - def _load_multi_stage(self, data): - from . import lockfile - from .stage import PipelineStage, Stage, loads_from - - stages = [] - path = os.path.abspath(self.path) - lock_data = lockfile.load(self.repo, self.lockfile) - for stage_name, d in data.get("stages", {}).items(): - lock_stage_data = lock_data.get(stage_name, {}) - wdir = os.path.abspath( - os.path.join( - os.path.dirname(path), d.get(Stage.PARAM_WDIR, ".") - ) - ) - stage = loads_from(PipelineStage, self.repo, path, wdir, d) - stage.name = stage_name - stage.cmd_changed = lock_stage_data.get(Stage.PARAM_CMD) != d.get( - Stage.PARAM_CMD - ) - - stage._fill_stage_dependencies(**project(d, ["deps"])) - stage._fill_stage_outputs(**d) - stages.append(stage) - - for dep in stage.deps: - dep.info[dep.remote.PARAM_CHECKSUM] = lock_stage_data.get( - Stage.PARAM_DEPS, {} - ).get(dep.def_path) - - if stage.cmd_changed: - continue - - for out in stage.outs: - out.info[out.remote.PARAM_CHECKSUM] = lock_stage_data.get( - Stage.PARAM_OUTS, {} - ).get(out.def_path) - - return stages - @staticmethod def validate_single_stage(d, fname=None): - from dvc.schema import COMPILED_SINGLE_STAGE_SCHEMA - - try: - COMPILED_SINGLE_STAGE_SCHEMA(d) - except MultipleInvalid as exc: - raise StageFileFormatError(fname, exc) + Dvcfile._validate(COMPILED_SINGLE_STAGE_SCHEMA, d, fname) @staticmethod def validate_multi_stage(d, fname=None): - from dvc.schema import COMPILED_MULTI_STAGE_SCHEMA + Dvcfile._validate(COMPILED_MULTI_STAGE_SCHEMA, d, fname) + @staticmethod + def _validate(schema, d, fname=None): try: - COMPILED_MULTI_STAGE_SCHEMA(d) + schema(d) except MultipleInvalid as exc: raise StageFileFormatError(fname, exc) - @staticmethod - def validate(d, fname=None): - Dvcfile.validate_single_stage(d, fname) - def is_multi_stage(self, d=None): - # TODO: maybe the following heuristics is enough? if d is None: d = self._load()[0] - check_multi_stage = d.get("stages") or not d - exc = None + check_multi_stage = d.get("stages") if check_multi_stage: - try: - self.validate_multi_stage(d, self.path) - return True - except StageFileFormatError as _exc: - exc = _exc - - try: - self.validate_single_stage(d, self.path) - return False - except StageFileFormatError: - if check_multi_stage: - raise exc + self.validate_multi_stage(d, self.path) + return True - self.validate_multi_stage(d, self.path) - return True + self.validate_single_stage(d, self.path) + return False - def overwrite_with_prompt(self, force=False): + def remove_with_prompt(self, force=False): if not self.exists(): return diff --git a/dvc/loader.py b/dvc/loader.py new file mode 100644 index 0000000000..6e850e5b9c --- /dev/null +++ b/dvc/loader.py @@ -0,0 +1,132 @@ +import collections +import logging +import os + +from copy import deepcopy +from itertools import chain + +from dvc import dependency, output +from dvc.exceptions import DvcException +from dvc.utils import relpath + +logger = logging.getLogger(__name__) + + +class StageNotFound(KeyError, DvcException): + def __init__(self, file, name): + super().__init__( + "Stage with '{}' name not found inside '{}' file".format( + name, relpath(file) + ) + ) + + +def resolve_paths(path, wdir=None): + path = os.path.abspath(path) + wdir = wdir or os.curdir + wdir = os.path.abspath(os.path.join(os.path.dirname(path), wdir)) + return path, wdir + + +class StageLoader(collections.abc.Mapping): + def __init__(self, dvcfile, stages_data, lockfile_data=None): + self.dvcfile = dvcfile + self.stages_data = stages_data or {} + self.lockfile_data = lockfile_data or {} + + @staticmethod + def _fill_lock_checksums(stage, lock_data): + from .stage import Stage + + outs = stage.outs if not stage.cmd_changed else [] + items = chain( + ((Stage.PARAM_DEPS, dep) for dep in stage.deps), + ((Stage.PARAM_OUTS, out) for out in outs), + ) + for key, item in items: + item.checksum = lock_data.get(key, {}).get(item.def_path) + + @classmethod + def load_stage(cls, dvcfile, name, stage_data, lock_data): + from .stage import PipelineStage, Stage, loads_from + + path, wdir = resolve_paths( + dvcfile.path, stage_data.get(Stage.PARAM_WDIR) + ) + stage = loads_from(PipelineStage, dvcfile.repo, path, wdir, stage_data) + stage.name = name + stage._fill_stage_dependencies(**stage_data) + stage._fill_stage_outputs(**stage_data) + if lock_data: + stage.cmd_changed = lock_data.get( + Stage.PARAM_CMD + ) != stage_data.get(Stage.PARAM_CMD) + cls._fill_lock_checksums(stage, lock_data) + + return stage + + def __getitem__(self, name): + if name not in self: + raise StageNotFound(self.dvcfile.path, name) + + if not self.lockfile_data.get(name): + logger.warning( + "No lock entry found for '%s:%s'", self.dvcfile.relpath, name + ) + return self.load_stage( + self.dvcfile, + name, + self.stages_data[name], + self.lockfile_data.get(name, {}), + ) + + def __iter__(self): + return iter(self.stages_data) + + def __len__(self): + return len(self.stages_data) + + def __contains__(self, name): + return name in self.stages_data + + +class SingleStageLoader(collections.abc.Mapping): + def __init__(self, dvcfile, stage_data, stage_text=None): + self.dvcfile = dvcfile + self.stage_data = stage_data or {} + self.stage_text = stage_text + + def __getitem__(self, item): + if item: + logger.warning( + "Ignoring name '%s' for single stage in '%s'.", + item, + self.dvcfile, + ) + # during `load`, we remove attributes from stage data, so as to + # not duplicate, therefore, for MappingView, we need to deepcopy. + return self.load_stage( + self.dvcfile, deepcopy(self.stage_data), self.stage_text + ) + + @classmethod + def load_stage(cls, dvcfile, d, stage_text): + from dvc.stage import Stage, loads_from + + path, wdir = resolve_paths(dvcfile.path, d.get(Stage.PARAM_WDIR)) + stage = loads_from(Stage, dvcfile.repo, path, wdir, d) + stage._stage_text, stage.tag = stage_text, dvcfile.tag + stage.deps = dependency.loadd_from( + stage, d.get(Stage.PARAM_DEPS) or [] + ) + stage.outs = output.loadd_from(stage, d.get(Stage.PARAM_OUTS) or []) + return stage + + def __iter__(self): + return iter([None]) + + def __contains__(self, item): + return False + + def __len__(self): + return 1 diff --git a/dvc/lockfile.py b/dvc/lockfile.py index 5f39d1e5c7..ee47bd3b91 100644 --- a/dvc/lockfile.py +++ b/dvc/lockfile.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from dvc.repo import Repo - from dvc.stage import PipelineStage class LockfileCorruptedError(DvcException): @@ -17,36 +16,6 @@ def __init__(self, path): super().__init__("Lockfile '{}' is corrupted.".format(path)) -def serialize_stage(stage: "PipelineStage") -> OrderedDict: - assert stage.cmd - assert stage.name - - deps = OrderedDict( - [ - (dep.def_path, dep.remote.get_checksum(dep.path_info),) - for dep in stage.deps - if dep.remote.get_checksum(dep.path_info) - ] - ) - outs = OrderedDict( - [ - (out.def_path, out.remote.get_checksum(out.path_info),) - for out in stage.outs - if out.remote.get_checksum(out.path_info) - ] - ) - return OrderedDict( - [ - ( - stage.name, - OrderedDict( - [("cmd", stage.cmd), ("deps", deps,), ("outs", outs)] - ), - ) - ] - ) - - def exists(repo: "Repo", path: str) -> bool: return repo.tree.exists(path) @@ -70,9 +39,7 @@ def load(repo: "Repo", path: str) -> dict: raise LockfileCorruptedError(path) -def dump(repo: "Repo", path: str, stage: "PipelineStage"): - stage_data = serialize_stage(stage) - +def dump(repo: "Repo", path: str, stage_data: dict): if not exists(repo, path): data = stage_data else: diff --git a/dvc/output/base.py b/dvc/output/base.py index 2948c117f5..2e555119cb 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -151,6 +151,10 @@ def cache_path(self): def checksum(self): return self.info.get(self.remote.PARAM_CHECKSUM) + @checksum.setter + def checksum(self, checksum): + self.info[self.remote.PARAM_CHECKSUM] = checksum + @property def is_dir_checksum(self): return self.remote.is_dir_checksum(self.checksum) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 39920cf7ca..c2931f11d1 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -211,7 +211,7 @@ def collect_for_pipelines( dvcfile = Dvcfile(self, path) dvcfile.check_file_exists() - return [dvcfile.load_one(name)] + return [dvcfile.stages[name]] def collect(self, target, with_deps=False, recursive=False, graph=None): import networkx as nx @@ -225,7 +225,7 @@ def collect(self, target, with_deps=False, recursive=False, graph=None): if recursive and os.path.isdir(target): return self._collect_inside(target, graph or self.graph) - stage = Dvcfile(self, target).load() + stage = Dvcfile(self, target).stage # Optimization: do not collect the graph for a specific target if not with_deps: @@ -242,7 +242,7 @@ def collect_granular(self, target, *args, **kwargs): # Optimization: do not collect the graph for a specific .dvc target if Dvcfile.is_valid_filename(target) and not kwargs.get("with_deps"): - return [(Dvcfile(self, target).load(), None)] + return [(Dvcfile(self, target).stage, None)] try: (out,) = self.find_outs_by_path(target, strict=False) @@ -308,7 +308,7 @@ def used_cache( return cache - def _collect_graph(self, stages=None): + def _collect_graph(self, stages): """Generate a graph by using the given stages on the given directory The nodes of the graph are the stage's path relative to the root. @@ -357,10 +357,9 @@ def _collect_graph(self, stages=None): G = nx.DiGraph() stages = stages or self.stages - stages = [stage for stage in stages if stage] outs = Trie() # Use trie to efficiently find overlapping outs and deps - for stage in stages: + for stage in filter(bool, stages): for out in stage.outs: out_key = out.path_info.parts @@ -457,15 +456,8 @@ def _collect_stages(self): path = os.path.join(root, fname) if not Dvcfile.is_valid_filename(path): continue - stgs = Dvcfile(self, path).load_all() - ignored_outs.extend( - out - for stage in stgs - if isinstance(stage, PipelineStage) - for out in stage.outs - ) - - for stage in stgs: + dvcfile = Dvcfile(self, path) + for stage in dvcfile.stages.values(): stages = ( output_stages if stage.is_data_source @@ -477,9 +469,12 @@ def _collect_stages(self): or stage.is_data_source ): single_stages.append(stage) + for out in stage.outs: if out.scheme == "local": outs.add(out.fspath) + if isinstance(stage, PipelineStage): + ignored_outs.append(out) dirs[:] = [d for d in dirs if os.path.join(root, d) not in outs] diff --git a/dvc/repo/add.py b/dvc/repo/add.py index c597418083..05fa020c1f 100644 --- a/dvc/repo/add.py +++ b/dvc/repo/add.py @@ -128,7 +128,7 @@ def _create_stages(repo, targets, fname, pbar=None): path, wdir, out = resolve_paths(repo, out) stage = create_stage(Stage, repo, fname or path, wdir=wdir, outs=[out]) if stage: - Dvcfile(repo, stage.path).overwrite_with_prompt(force=True) + Dvcfile(repo, stage.path).remove_with_prompt(force=True) repo._reset() diff --git a/dvc/repo/imp_url.py b/dvc/repo/imp_url.py index 32cdf3de7d..dbf71c0aeb 100644 --- a/dvc/repo/imp_url.py +++ b/dvc/repo/imp_url.py @@ -33,7 +33,7 @@ def imp_url(self, url, out=None, fname=None, erepo=None, locked=True): return None dvcfile = Dvcfile(self, stage.path) - dvcfile.overwrite_with_prompt(force=True) + dvcfile.remove_with_prompt(force=True) self.check_modified_graph([stage]) diff --git a/dvc/repo/lock.py b/dvc/repo/lock.py index d9f8d5f83a..9e26941148 100644 --- a/dvc/repo/lock.py +++ b/dvc/repo/lock.py @@ -8,7 +8,7 @@ def lock(self, target, unlock=False): path, target = parse_target(target) dvcfile = dvcfile.Dvcfile(self, path) - stage = dvcfile.load_one(target) + stage = dvcfile.stages[target] stage.locked = False if unlock else True dvcfile.dump(stage, update_dvcfile=True) diff --git a/dvc/repo/remove.py b/dvc/repo/remove.py index 1001d2c967..c24ede47ef 100644 --- a/dvc/repo/remove.py +++ b/dvc/repo/remove.py @@ -5,7 +5,7 @@ def remove(self, target, outs_only=False): from ..dvcfile import Dvcfile - stage = Dvcfile(self, target).load() + stage = Dvcfile(self, target).stage if outs_only: stage.remove_outs(force=True) else: diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 2e9697c886..a0ae1c35af 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -61,7 +61,7 @@ def reproduce( all_pipelines=False, **kwargs ): - from .. import dvcfile + from ..dvcfile import Dvcfile from dvc.utils import parse_target if not target and not all_pipelines: @@ -81,7 +81,8 @@ def reproduce( if all_pipelines: pipelines = active_pipelines else: - stage = dvcfile.Dvcfile(self, path).load_one(name) + dvcfile = Dvcfile(self, path) + stage = dvcfile.stages[name] pipelines = [get_pipeline(active_pipelines, stage)] targets = [] diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 274b956993..806cc8da65 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -53,7 +53,7 @@ def run(self, fname=None, no_exec=False, **kwargs): relpath(dvcfile.path) ) ) - dvcfile.overwrite_with_prompt(force=kwargs.get("overwrite", True)) + dvcfile.remove_with_prompt(force=kwargs.get("overwrite", True)) self.check_modified_graph([stage], self.pipeline_stages) if not no_exec: diff --git a/dvc/repo/update.py b/dvc/repo/update.py index f9bee48f3f..5ade233669 100644 --- a/dvc/repo/update.py +++ b/dvc/repo/update.py @@ -6,8 +6,7 @@ def update(self, target, rev=None): from ..dvcfile import Dvcfile dvcfile = Dvcfile(self, target) - stage = dvcfile.load() - + stage = dvcfile.stage stage.update(rev) dvcfile.dump(stage) diff --git a/dvc/serialize.py b/dvc/serialize.py new file mode 100644 index 0000000000..c6594aefbf --- /dev/null +++ b/dvc/serialize.py @@ -0,0 +1,89 @@ +from collections import OrderedDict +from typing import TYPE_CHECKING + +from dvc.utils.collections import apply_diff +from dvc.utils.stage import parse_stage_for_update + +if TYPE_CHECKING: + from dvc.stage import PipelineStage, Stage + + +def _get_outs(stage: "PipelineStage"): + outs_bucket = {} + for o in stage.outs: + bucket_key = ["metrics"] if o.metric else ["outs"] + + if not o.metric and o.persist: + bucket_key += ["persist"] + if not o.use_cache: + bucket_key += ["no_cache"] + key = "_".join(bucket_key) + outs_bucket[key] = outs_bucket.get(key, []) + [o.def_path] + return outs_bucket + + +def to_dvcfile(stage: "PipelineStage"): + return { + stage.name: { + key: value + for key, value in { + stage.PARAM_CMD: stage.cmd, + stage.PARAM_WDIR: stage.resolve_wdir(), + stage.PARAM_DEPS: [d.def_path for d in stage.deps], + **_get_outs(stage), + stage.PARAM_LOCKED: stage.locked, + stage.PARAM_ALWAYS_CHANGED: stage.always_changed, + }.items() + if value + } + } + + +def to_lockfile(stage: "PipelineStage") -> OrderedDict: + assert stage.cmd + assert stage.name + + deps = OrderedDict( + [ + (dep.def_path, dep.remote.get_checksum(dep.path_info),) + for dep in stage.deps + if dep.remote.get_checksum(dep.path_info) + ] + ) + outs = OrderedDict( + [ + (out.def_path, out.remote.get_checksum(out.path_info),) + for out in stage.outs + if out.remote.get_checksum(out.path_info) + ] + ) + return OrderedDict( + [ + ( + stage.name, + OrderedDict( + [("cmd", stage.cmd), ("deps", deps,), ("outs", outs)] + ), + ) + ] + ) + + +def to_single_stage_file(stage: "Stage"): + state = stage.dumpd() + + # When we load a stage we parse yaml with a fast parser, which strips + # off all the comments and formatting. To retain those on update we do + # a trick here: + # - reparse the same yaml text with a slow but smart ruamel yaml parser + # - apply changes to a returned structure + # - serialize it + if stage._stage_text is not None: + saved_state = parse_stage_for_update(stage._stage_text, stage.path) + # Stage doesn't work with meta in any way, so .dumpd() doesn't + # have it. We simply copy it over. + if "meta" in saved_state: + state["meta"] = saved_state["meta"] + apply_diff(state, saved_state) + state = saved_state + return state diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index bd313107df..1120772ed6 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -348,7 +348,7 @@ def can_be_skipped(self): ) def reload(self): - return self.dvcfile.load() + return self.dvcfile.stage @property def is_cached(self): @@ -776,12 +776,11 @@ def __repr__(self): def _changed(self): if self.cmd_changed: logger.warning("'cmd' of {} has changed.".format(self)) - return self.cmd_changed or self._changed_deps() or self._changed_outs() def reload(self): - return self.dvcfile.load_one(self.name) + return self.dvcfile.stages[self.name] @property def is_cached(self): - return self.dvcfile.has_stage(name=self.name) and super().is_cached + return self.name in self.dvcfile.stages and super().is_cached diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index 20d59a5ee5..b6380c4f66 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -586,7 +586,7 @@ def test_checkout_stats_on_failure(tmp_dir, dvc, scm): {"foo": "foo", "dir": {"subdir": {"file": "file"}}, "other": "other"}, commit="initial", ) - stage = Dvcfile(dvc, "foo.dvc").load() + stage = Dvcfile(dvc, "foo.dvc").stage tmp_dir.dvc_gen({"foo": "foobar", "other": "other other"}, commit="second") # corrupt cache diff --git a/tests/func/test_dvcfile.py b/tests/func/test_dvcfile.py index e1160f946d..3b871cce32 100644 --- a/tests/func/test_dvcfile.py +++ b/tests/func/test_dvcfile.py @@ -1,8 +1,8 @@ import pytest from dvc.dvcfile import Dvcfile -from dvc.exceptions import DvcException from dvc.stage import Stage +from dvc.loader import StageNotFound from dvc.stage.exceptions import StageFileDoesNotExistError @@ -15,7 +15,7 @@ def test_run_load_one_for_multistage(tmp_dir, dvc): outs_persist_no_cache=["foo2"], always_changed=True, ) - stage2 = Dvcfile(dvc, "Dvcfile").load_one("copy-foo-foo2") + stage2 = Dvcfile(dvc, "Dvcfile").stages["copy-foo-foo2"] assert stage1 == stage2 foo_out = stage2.outs[0] assert stage2.cmd == "cp foo foo2" @@ -29,7 +29,7 @@ def test_run_load_one_for_multistage(tmp_dir, dvc): def test_run_load_one_for_multistage_non_existing(tmp_dir, dvc): with pytest.raises(StageFileDoesNotExistError): - Dvcfile(dvc, "Dvcfile").load_one("copy-foo-foo2") + assert Dvcfile(dvc, "Dvcfile").stages.get("copy-foo-foo2") def test_run_load_one_for_multistage_non_existing_stage_name(tmp_dir, dvc): @@ -41,9 +41,8 @@ def test_run_load_one_for_multistage_non_existing_stage_name(tmp_dir, dvc): metrics=["foo2"], always_changed=True, ) - with pytest.raises(DvcException): - # TODO: Better exception - Dvcfile(dvc, stage.path).load_one("random-name") + with pytest.raises(StageNotFound): + assert Dvcfile(dvc, stage.path).stages["random-name"] def test_run_load_one_on_single_stage(tmp_dir, dvc): @@ -51,8 +50,8 @@ def test_run_load_one_on_single_stage(tmp_dir, dvc): stage = dvc.run( cmd="cp foo foo2", deps=["foo"], metrics=["foo2"], always_changed=True, ) - Dvcfile(dvc, stage.path).load_one("random-name") - Dvcfile(dvc, stage.path).load_one() + assert Dvcfile(dvc, stage.path).stages.get("random-name") + assert Dvcfile(dvc, stage.path).stage def test_has_stage_with_name(tmp_dir, dvc): @@ -65,8 +64,8 @@ def test_has_stage_with_name(tmp_dir, dvc): always_changed=True, ) dvcfile = Dvcfile(dvc, "Dvcfile") - assert dvcfile.has_stage("copy-foo-foo2") - assert not dvcfile.has_stage("copy") + assert "copy-foo-foo2" in dvcfile.stages + assert "copy" not in dvcfile.stages def test_load_all_multistage(tmp_dir, dvc): @@ -78,9 +77,9 @@ def test_load_all_multistage(tmp_dir, dvc): metrics=["foo2"], always_changed=True, ) - stages = Dvcfile(dvc, "Dvcfile").load_all() + stages = Dvcfile(dvc, "Dvcfile").stages.values() assert len(stages) == 1 - assert stages[0] == stage1 + assert list(stages) == [stage1] tmp_dir.gen("bar", "bar") stage2 = dvc.run( @@ -90,7 +89,7 @@ def test_load_all_multistage(tmp_dir, dvc): metrics=["bar2"], always_changed=True, ) - assert set(Dvcfile(dvc, "Dvcfile").load_all()) == {stage2, stage1} + assert set(Dvcfile(dvc, "Dvcfile").stages.values()) == {stage2, stage1} def test_load_all_singlestage(tmp_dir, dvc): @@ -98,9 +97,9 @@ def test_load_all_singlestage(tmp_dir, dvc): stage1 = dvc.run( cmd="cp foo foo2", deps=["foo"], metrics=["foo2"], always_changed=True, ) - stages = Dvcfile(dvc, "foo2.dvc").load_all() + stages = Dvcfile(dvc, "foo2.dvc").stages.values() assert len(stages) == 1 - assert stages == [stage1] + assert list(stages) == [stage1] def test_load_singlestage(tmp_dir, dvc): @@ -108,7 +107,7 @@ def test_load_singlestage(tmp_dir, dvc): stage1 = dvc.run( cmd="cp foo foo2", deps=["foo"], metrics=["foo2"], always_changed=True, ) - assert Dvcfile(dvc, "foo2.dvc").load() == stage1 + assert Dvcfile(dvc, "foo2.dvc").stage == stage1 def test_load_multistage(tmp_dir, dvc): @@ -123,7 +122,7 @@ def test_load_multistage(tmp_dir, dvc): always_changed=True, ) with pytest.raises(MultiStageFileLoadError): - Dvcfile(dvc, "Dvcfile").load() + Dvcfile(dvc, "Dvcfile").stage def test_is_multistage(tmp_dir, dvc): diff --git a/tests/func/test_import.py b/tests/func/test_import.py index 58a2406719..c57a025829 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -182,7 +182,7 @@ def test_pull_imported_stage(tmp_dir, dvc, erepo_dir): erepo_dir.dvc_gen("foo", "foo content", commit="create foo") dvc.imp(fspath(erepo_dir), "foo", "foo_imported") - dst_stage = Dvcfile(dvc, "foo_imported.dvc").load() + dst_stage = Dvcfile(dvc, "foo_imported.dvc").stage dst_cache = dst_stage.outs[0].cache_path remove("foo_imported") @@ -232,7 +232,7 @@ def test_download_error_pulling_imported_stage(tmp_dir, dvc, erepo_dir): erepo_dir.dvc_gen("foo", "foo content", commit="create foo") dvc.imp(fspath(erepo_dir), "foo", "foo_imported") - dst_stage = Dvcfile(dvc, "foo_imported.dvc").load() + dst_stage = Dvcfile(dvc, "foo_imported.dvc").stage dst_cache = dst_stage.outs[0].cache_path remove("foo_imported") diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index f00d071e14..eecc3e4001 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -193,7 +193,7 @@ def test_nested(self): # be processed before dir1 to load error.dvc first. self.dvc.pipeline_stages = [ nested_stage, - Dvcfile(self.dvc, error_stage_path).load(), + Dvcfile(self.dvc, error_stage_path).stage, ] with patch.object(self.dvc, "_reset"): # to prevent `stages` resetting @@ -1320,7 +1320,7 @@ def test_force_with_dependencies(self): ret = main(["repro", "--force", "datetime.dvc"]) self.assertEqual(ret, 0) - repro_out = Dvcfile(self.dvc, "datetime.dvc").load().outs[0] + repro_out = Dvcfile(self.dvc, "datetime.dvc").stage.outs[0] self.assertNotEqual(run_out.checksum, repro_out.checksum) diff --git a/tests/func/test_repro_multistage.py b/tests/func/test_repro_multistage.py index 86852a9e69..78c6756732 100644 --- a/tests/func/test_repro_multistage.py +++ b/tests/func/test_repro_multistage.py @@ -3,7 +3,9 @@ import pytest +from dvc.exceptions import CyclicGraphError from dvc.stage import PipelineStage +from dvc.utils.stage import dump_stage_file from tests.func import test_repro from dvc.main import main @@ -38,7 +40,6 @@ class TestReproFailMultiStage(MultiStageRun, test_repro.TestReproFail): class TestReproCyclicGraphMultiStage( MultiStageRun, test_repro.TestReproCyclicGraph ): - # TODO: Also test with new-style forced dump pass @@ -238,7 +239,6 @@ def test_repro_when_cmd_changes(tmp_dir, dvc, run_copy): def test_repro_when_new_deps_is_added_in_dvcfile(tmp_dir, dvc, run_copy): from dvc.dvcfile import Dvcfile - from dvc.utils.stage import dump_stage_file tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen({"foo": "foo", "bar": "bar"}) @@ -262,7 +262,6 @@ def test_repro_when_new_deps_is_added_in_dvcfile(tmp_dir, dvc, run_copy): def test_repro_when_new_outs_is_added_in_dvcfile(tmp_dir, dvc): from dvc.dvcfile import Dvcfile - from dvc.utils.stage import dump_stage_file tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen({"foo": "foo", "bar": "bar"}) @@ -286,7 +285,6 @@ def test_repro_when_new_outs_is_added_in_dvcfile(tmp_dir, dvc): def test_repro_when_new_deps_is_moved(tmp_dir, dvc): from dvc.dvcfile import Dvcfile - from dvc.utils.stage import dump_stage_file tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen({"foo": "foo", "bar": "foo"}) @@ -314,7 +312,6 @@ def test_repro_when_new_deps_is_moved(tmp_dir, dvc): def test_repro_when_new_out_overlaps_others_stage_outs(tmp_dir, dvc): - from dvc.utils.stage import dump_stage_file from dvc.exceptions import OverlappingOutputPathsError tmp_dir.gen({"dir": {"file1": "file1"}, "foo": "foo"}) @@ -336,7 +333,6 @@ def test_repro_when_new_out_overlaps_others_stage_outs(tmp_dir, dvc): def test_repro_when_new_deps_added_does_not_exist(tmp_dir, dvc): - from dvc.utils.stage import dump_stage_file from dvc.exceptions import ReproductionError tmp_dir.gen("copy.py", COPY_SCRIPT) @@ -358,7 +354,6 @@ def test_repro_when_new_deps_added_does_not_exist(tmp_dir, dvc): def test_repro_when_new_outs_added_does_not_exist(tmp_dir, dvc): - from dvc.utils.stage import dump_stage_file from dvc.exceptions import ReproductionError tmp_dir.gen("copy.py", COPY_SCRIPT) @@ -380,8 +375,6 @@ def test_repro_when_new_outs_added_does_not_exist(tmp_dir, dvc): def test_repro_when_lockfile_gets_deleted(tmp_dir, dvc): - from dvc.utils.stage import dump_stage_file - tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen("foo", "foo") dump_stage_file( @@ -410,3 +403,24 @@ def test_repro_when_lockfile_gets_deleted(tmp_dir, dvc): ) assert os.path.exists("foobar.dvc") + + +def test_cyclic_graph_error(tmp_dir, dvc, run_copy): + tmp_dir.gen("foo", "foo") + run_copy("foo", "bar", name="copy-foo-bar") + run_copy("bar", "baz", name="copy-bar-baz") + run_copy("baz", "foobar", name="copy-baz-foobar") + + stage_dump = { + "stages": { + "copy-baz-foo": { + "cmd": "echo baz > foo", + "deps": ["baz"], + "outs": ["foo"], + } + } + } + dump_stage_file("cycle.dvc", stage_dump) + + with pytest.raises(CyclicGraphError): + dvc.reproduce("cycle.dvc:copy-baz-foo") diff --git a/tests/func/test_run_multistage.py b/tests/func/test_run_multistage.py index 574efc7b67..35594c1268 100644 --- a/tests/func/test_run_multistage.py +++ b/tests/func/test_run_multistage.py @@ -39,7 +39,7 @@ def test_run_multi_stage_repeat(tmp_dir, dvc, run_copy): run_copy("foo1", "foo2", name="copy-foo1-foo2") run_copy("foo2", "foo3") - stages = Dvcfile(dvc, DVC_FILE).load_multi() + stages = list(Dvcfile(dvc, DVC_FILE).stages.values()) assert len(stages) == 2 assert all(isinstance(stage, PipelineStage) for stage in stages) assert set(stage.name for stage in stages) == { diff --git a/tests/func/test_stage.py b/tests/func/test_stage.py index 08e9b488a2..b0a751e5c7 100644 --- a/tests/func/test_stage.py +++ b/tests/func/test_stage.py @@ -16,40 +16,40 @@ def test_cmd_obj(): with pytest.raises(StageFileFormatError): - Dvcfile.validate({Stage.PARAM_CMD: {}}) + Dvcfile.validate_single_stage({Stage.PARAM_CMD: {}}) def test_cmd_none(): - Dvcfile.validate({Stage.PARAM_CMD: None}) + Dvcfile.validate_single_stage({Stage.PARAM_CMD: None}) def test_no_cmd(): - Dvcfile.validate({}) + Dvcfile.validate_single_stage({}) def test_cmd_str(): - Dvcfile.validate({Stage.PARAM_CMD: "cmd"}) + Dvcfile.validate_single_stage({Stage.PARAM_CMD: "cmd"}) def test_object(): with pytest.raises(StageFileFormatError): - Dvcfile.validate({Stage.PARAM_DEPS: {}}) + Dvcfile.validate_single_stage({Stage.PARAM_DEPS: {}}) with pytest.raises(StageFileFormatError): - Dvcfile.validate({Stage.PARAM_OUTS: {}}) + Dvcfile.validate_single_stage({Stage.PARAM_OUTS: {}}) def test_none(): - Dvcfile.validate({Stage.PARAM_DEPS: None}) - Dvcfile.validate({Stage.PARAM_OUTS: None}) + Dvcfile.validate_single_stage({Stage.PARAM_DEPS: None}) + Dvcfile.validate_single_stage({Stage.PARAM_OUTS: None}) def test_empty_list(): d = {Stage.PARAM_DEPS: []} - Dvcfile.validate(d) + Dvcfile.validate_single_stage(d) d = {Stage.PARAM_OUTS: []} - Dvcfile.validate(d) + Dvcfile.validate_single_stage(d) def test_list(): @@ -59,12 +59,12 @@ def test_list(): {OutputLOCAL.PARAM_PATH: "baz"}, ] d = {Stage.PARAM_DEPS: lst} - Dvcfile.validate(d) + Dvcfile.validate_single_stage(d) lst[0][OutputLOCAL.PARAM_CACHE] = True lst[1][OutputLOCAL.PARAM_CACHE] = False d = {Stage.PARAM_OUTS: lst} - Dvcfile.validate(d) + Dvcfile.validate_single_stage(d) class TestReload(TestDvc): @@ -82,7 +82,7 @@ def test(self): dump_stage_file(stage.relpath, d) dvcfile = Dvcfile(self.dvc, stage.relpath) - stage = dvcfile.load() + stage = dvcfile.stage self.assertTrue(stage is not None) dvcfile.dump(stage) @@ -106,7 +106,7 @@ def test_ignored_in_checksum(self): self.assertNotIn(Stage.PARAM_WDIR, d.keys()) with self.dvc.lock, self.dvc.state: - stage = Dvcfile(self.dvc, stage.relpath).load() + stage = Dvcfile(self.dvc, stage.relpath).stage self.assertFalse(stage.changed()) @@ -157,7 +157,7 @@ def test_md5_ignores_comments(tmp_dir, dvc): with open(stage.path, "a") as f: f.write("# End comment\n") - new_stage = Dvcfile(dvc, stage.path).load() + new_stage = Dvcfile(dvc, stage.path).stage assert not new_stage.changed_md5() @@ -171,7 +171,7 @@ def test_meta_is_preserved(tmp_dir, dvc): # Loading and dumping to test that it works and meta is retained dvcfile = Dvcfile(dvc, stage.path) - new_stage = dvcfile.load() + new_stage = dvcfile.stage dvcfile.dump(new_stage) new_data = load_stage_file(stage.path) diff --git a/tests/func/test_update.py b/tests/func/test_update.py index a77fe9ea61..dc9bae7c15 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -28,7 +28,7 @@ def test_update_import(tmp_dir, dvc, erepo_dir, cached): dvc.update(stage.path) assert (tmp_dir / "version").read_text() == "updated" - stage = Dvcfile(dvc, stage.path).load() + stage = Dvcfile(dvc, stage.path).stage assert stage.deps[0].def_repo["rev_lock"] == new_rev @@ -71,7 +71,7 @@ def test_update_import_after_remote_updates_to_dvc(tmp_dir, dvc, erepo_dir): assert imported.is_file() assert imported.read_text() == "updated" - stage = Dvcfile(dvc, stage.path).load() + stage = Dvcfile(dvc, stage.path).stage assert stage.deps[0].def_repo == { "url": fspath(erepo_dir), "rev": "branch", diff --git a/tests/unit/test_lockfile.py b/tests/unit/test_lockfile.py index 58f2f5c47d..294e2fd333 100644 --- a/tests/unit/test_lockfile.py +++ b/tests/unit/test_lockfile.py @@ -1,13 +1,12 @@ from dvc.stage import PipelineStage -from dvc import lockfile +from dvc import lockfile, serialize import json import pytest def test_stage_dump_no_outs_deps(tmp_dir, dvc): stage = PipelineStage(name="s1", repo=dvc, path="path", cmd="command") - - lockfile.dump(dvc, "path.lock", stage) + lockfile.dump(dvc, "path.lock", serialize.to_lockfile(stage)) assert lockfile.load(dvc, "path.lock") == { "s1": {"cmd": "command", "deps": {}, "outs": {}} } @@ -19,7 +18,7 @@ def test_stage_dump_when_already_exists(tmp_dir, dvc): json.dump(data, f) stage = PipelineStage(name="s2", repo=dvc, path="path", cmd="command2") - lockfile.dump(dvc, "path.lock", stage) + lockfile.dump(dvc, "path.lock", serialize.to_lockfile(stage)) assert lockfile.load(dvc, "path.lock") == { **data, "s2": {"cmd": "command2", "deps": {}, "outs": {}}, @@ -38,7 +37,7 @@ def test_stage_dump_with_deps_and_outs(tmp_dir, dvc): json.dump(data, f) stage = PipelineStage(name="s2", repo=dvc, path="path", cmd="command2") - lockfile.dump(dvc, "path.lock", stage) + lockfile.dump(dvc, "path.lock", serialize.to_lockfile(stage)) assert lockfile.load(dvc, "path.lock") == { **data, "s2": {"cmd": "command2", "deps": {}, "outs": {}}, @@ -47,9 +46,9 @@ def test_stage_dump_with_deps_and_outs(tmp_dir, dvc): def test_stage_overwrites_if_already_exists(tmp_dir, dvc): stage = PipelineStage(name="s2", repo=dvc, path="path", cmd="command2") - lockfile.dump(dvc, "path.lock", stage) + lockfile.dump(dvc, "path.lock", serialize.to_lockfile(stage)) stage = PipelineStage(name="s2", repo=dvc, path="path", cmd="command3") - lockfile.dump(dvc, "path.lock", stage) + lockfile.dump(dvc, "path.lock", serialize.to_lockfile(stage)) assert lockfile.load(dvc, "path.lock") == { "s2": {"cmd": "command3", "deps": {}, "outs": {}}, } From 4f1746a762f68d3bccb120bcdd7b48feb79ae657 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Mon, 20 Apr 2020 20:03:39 +0545 Subject: [PATCH 2/3] repo: simplify stage collection --- dvc/repo/__init__.py | 48 ++++++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index c2931f11d1..66e33d2e7d 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -446,9 +446,8 @@ def _collect_stages(self): from dvc.stage import PipelineStage pipeline_stages = [] - single_stages = [] output_stages = [] - ignored_outs = [] + ignored_outs = set() outs = set() for root, dirs, files in self.tree.walk(self.root_dir): @@ -458,43 +457,34 @@ def _collect_stages(self): continue dvcfile = Dvcfile(self, path) for stage in dvcfile.stages.values(): - stages = ( - output_stages - if stage.is_data_source - else pipeline_stages + if isinstance(stage, PipelineStage): + ignored_outs.update( + out.path_info for out in stage.outs + ) + pipeline_stages.append(stage) + else: + # Old single-stages are used for both + # outputs and pipelines. + output_stages.append(stage) + + outs.update( + out.fspath + for out in stage.outs + if out.scheme == "local" ) - stages.append(stage) - if not ( - isinstance(stage, PipelineStage) - or stage.is_data_source - ): - single_stages.append(stage) - - for out in stage.outs: - if out.scheme == "local": - outs.add(out.fspath) - if isinstance(stage, PipelineStage): - ignored_outs.append(out) dirs[:] = [d for d in dirs if os.path.join(root, d) not in outs] # DVC files are generated by multi-stage for data management. # We need to ignore those stages for pipelines_stages, but still # should be collected for output stages. - _output_stages = [ + pipeline_stages.extend( stage for stage in output_stages - if all( - stage.outs and out.fspath != stage.outs[0].fspath - for out in ignored_outs - ) - ] - # Old single-stages are used for both outputs and pipelines - # so they go into both buckets: pipeline_stages and output stages. - return ( - output_stages + single_stages, - pipeline_stages + _output_stages, + if not stage.outs + or all(out.path_info not in ignored_outs for out in stage.outs) ) + return output_stages, pipeline_stages def find_outs_by_path(self, path, outs=None, recursive=False, strict=True): if not outs: From 4e09da10dae31c5df9b9046a348597fb8643960c Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Mon, 20 Apr 2020 20:29:59 +0545 Subject: [PATCH 3/3] fix typing --- dvc/lockfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/lockfile.py b/dvc/lockfile.py index ee47bd3b91..19b37317c3 100644 --- a/dvc/lockfile.py +++ b/dvc/lockfile.py @@ -25,7 +25,7 @@ def read(repo: "Repo", path: str) -> dict: return json.load(f, object_pairs_hook=OrderedDict) -def write(repo: "Repo", path: str, data: dict) -> dict: +def write(repo: "Repo", path: str, data: dict) -> None: with repo.tree.open(path, "w+") as f: json.dump(data, f)