From 76894bc70408bf5c71dfc1e74758f941d198813e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 20 Mar 2023 12:47:15 +0545 Subject: [PATCH] targets from index --- dvc/repo/__init__.py | 7 +++- dvc/repo/index.py | 84 +++++++++++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 27945559a7..5f5597a997 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -476,6 +476,8 @@ def used_objs( # noqa: PLR0913 belong to each ODB. If the ODB instance is None, the objects are naive and do not belong to a specific remote ODB. """ + from .index import index_from_targets + used = defaultdict(set) for _ in self.brancher( @@ -487,7 +489,10 @@ def used_objs( # noqa: PLR0913 commit_date=commit_date, num=num, ): - for odb, objs in self.index.used_objs( + index, targets = index_from_targets( + self, targets, with_deps=with_deps, recursive=recursive + ) + for odb, objs in index.used_objs( targets, remote=remote, force=force, diff --git a/dvc/repo/index.py b/dvc/repo/index.py index 156c0f2ebc..d4df2a665f 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -166,6 +166,36 @@ def _load_storage_from_out(storage_map, key, out): storage_map.add_data(FileStorage(key, dep.fs, dep.fs_path)) +def index_from_targets( + repo: "Repo", + targets: Optional[List[str]] = None, + with_deps: bool = False, + recursive: bool = False, +) -> Tuple["Index", "Optional[List[str]]"]: + from dvc.stage.exceptions import StageFileDoesNotExistError, StageNotFound + from dvc.utils import parse_target + + if targets and all(targets) and not with_deps and not recursive: + try: + indexes = [] + for target in targets: + file, name = parse_target(target) + if file and not name: + index = Index.from_file(repo, file) + else: + stages = repo.stage.collect(target) + index = Index(repo, stages=list(stages)) + indexes.append(index) + + return ( + Index._from_indexes(repo, indexes), # pylint: disable=protected-access + None, + ) + except (StageFileDoesNotExistError, StageNotFound): + pass + return repo.index, targets + + class Index: def __init__( self, @@ -182,41 +212,15 @@ def __init__( self._params = params or {} self._collected_targets: Dict[int, List["StageInfo"]] = {} - @cached_property - def rev(self) -> Optional[str]: - if not isinstance(self.repo.fs, LocalFileSystem): - return self.repo.get_rev()[:7] - return None - - def __repr__(self) -> str: - rev = self.rev or "workspace" - return f"Index({self.repo}, fs@{rev})" - @classmethod def from_repo( cls, repo: "Repo", onerror: Optional[Callable[[str, Exception], None]] = None, ) -> "Index": - stages = [] - metrics = {} - plots = {} - params = {} - onerror = onerror or repo.stage_collection_error_handler - for _, idx in collect_files(repo, onerror=onerror): - # pylint: disable=protected-access - stages.extend(idx.stages) - metrics.update(idx._metrics) - plots.update(idx._plots) - params.update(idx._params) - return cls( - repo, - stages=stages, - metrics=metrics, - plots=plots, - params=params, - ) + indexes = [index for _, index in collect_files(repo, onerror=onerror)] + return cls._from_indexes(repo, indexes) @classmethod def from_file(cls, repo: "Repo", path: str) -> "Index": @@ -231,6 +235,30 @@ def from_file(cls, repo: "Repo", path: str) -> "Index": params={path: dvcfile.params} if dvcfile.params else {}, ) + @classmethod + def _from_indexes(cls, repo: "Repo", indexes: Iterable["Index"]) -> "Index": + stages = [] + metrics = {} + plots = {} + params = {} + + for index in indexes: + stages.extend(index.stages) + metrics.update(index._metrics) # pylint: disable=protected-access + plots.update(index._plots) # pylint: disable=protected-access + params.update(index._params) # pylint: disable=protected-access + return cls(repo, stages, metrics, plots, params) + + @cached_property + def rev(self) -> Optional[str]: + if not isinstance(self.repo.fs, LocalFileSystem): + return self.repo.get_rev()[:7] + return None + + def __repr__(self) -> str: + rev = self.rev or "workspace" + return f"Index({self.repo}, fs@{rev})" + def update(self, stages: Iterable["Stage"]) -> "Index": stages = set(stages) # we remove existing stages with same hashes at first