Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dvc/repo/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def onerror(exc):
for dname in dirs:
info = PathInfo(root) / dname
# pylint:disable=protected-access
_, dvctree = tree._get_tree_pairs(info) # noqa
_, dvctree = tree._get_tree_pair(info) # noqa
if not dvc_only or (dvctree and dvctree.exists(info)):
dvc = tree.isdvc(info)
path = str(info.relative_to(path_info))
Expand Down
106 changes: 69 additions & 37 deletions dvc/repo/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import threading
from itertools import takewhile
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

from funcy import wrap_with
from pygtrie import StringTrie
Expand All @@ -14,6 +14,12 @@
from dvc.utils import file_md5
from dvc.utils.fs import copy_fobj_to_file, makedirs

if TYPE_CHECKING:
from dvc.repo import Repo
from dvc.tree.local import LocalTree
from dvc.tree.git import GitTree


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -124,10 +130,7 @@ def isdir(self, path): # pylint: disable=arguments-differ

out = outs[0]
if not out.is_dir_checksum:
if out.path_info != path_info:
return True
return False

return out.path_info != path_info
if out.path_info == path_info:
return True

Expand All @@ -153,12 +156,11 @@ def _add_dir(self, top, trie, out, download_callback=None, **kwargs):
dir_cache = out.get_dir_cache(**kwargs)

# pull dir contents if needed
if self.fetch:
if out.changed_cache(filter_info=top):
used_cache = out.get_used_cache(filter_info=top)
downloaded = self.repo.cloud.pull(used_cache, **kwargs)
if download_callback:
download_callback(downloaded)
if self.fetch and out.changed_cache(filter_info=top):
used_cache = out.get_used_cache(filter_info=top)
downloaded = self.repo.cloud.pull(used_cache, **kwargs)
if download_callback:
download_callback(downloaded)

for entry in dir_cache:
entry_relpath = entry[out.tree.PARAM_RELPATH]
Expand Down Expand Up @@ -250,14 +252,21 @@ class RepoTree(BaseTree): # pylint:disable=abstract-method

Args:
repo: DVC or git repo.

Any kwargs will be passed to `DvcTree()`.
subrepos: traverse to subrepos (by default, it ignores subrepos)
repo_factory: A function to initialize subrepo with, default is Repo.
kwargs: Additional keyword arguments passed to the `DvcTree()`.
"""

scheme = "local"
PARAM_CHECKSUM = "md5"

def __init__(self, repo, subrepos=False, repo_factory=None, **kwargs):
def __init__(
self,
repo,
subrepos=False,
repo_factory: Callable[[str], "Repo"] = None,
**kwargs
):
super().__init__(repo, {"url": repo.root_dir})

if not repo_factory:
Expand All @@ -271,57 +280,75 @@ def __init__(self, repo, subrepos=False, repo_factory=None, **kwargs):
self.root_dir = repo.root_dir
self._traverse_subrepos = subrepos

self._discovered_subrepos = StringTrie(separator=os.sep)
self._discovered_subrepos[self.root_dir] = repo
self._subrepos_trie = StringTrie(separator=os.sep)
"""Keeps track of each and every path with the corresponding repo."""

self._subrepos_trie[self.root_dir] = repo

self._dvctrees = {}
"""Keep a dvctree instance of each repo."""

self._dvctree_configs = kwargs

if hasattr(repo, "dvc_dir"):
self._dvctrees[repo.root_dir] = DvcTree(repo, **kwargs)

def _get_repo(self, path):
repo = self._discovered_subrepos.get(path)
def _get_repo(self, path) -> Optional["Repo"]:
"""Returns repo that the path falls in, using prefix.

If the path is already tracked/collected, it just returns the repo.

Otherwise, it collects the repos that might be in the path's parents
and then returns the appropriate one.
"""
repo = self._subrepos_trie.get(path)
if repo:
return repo

prefix, repo = self._discovered_subrepos.longest_prefix(path)
prefix, repo = self._subrepos_trie.longest_prefix(path)
if not prefix:
return None

parents = (parent.fspath for parent in PathInfo(path).parents)
dirs = [path] + list(takewhile(lambda p: p != prefix, parents))
dirs.reverse()
self._update(dirs, starting_repo=repo)
return self._discovered_subrepos.get(path)
return self._subrepos_trie.get(path)

@wrap_with(threading.Lock())
def _update(self, dirs, starting_repo):
"""Checks for subrepo in directories and updates them."""
repo = starting_repo
for d in dirs:
if self._is_dvc_repo(d):
repo = self.repo_factory(d)
self._dvctrees[repo.root_dir] = DvcTree(
repo, **self._dvctree_configs
)
self._discovered_subrepos[d] = repo
self._subrepos_trie[d] = repo

def _is_dvc_repo(self, dir_path):
    """Tell whether ``dir_path`` holds a DVC directory of its own.

    The answer is always False when subrepo traversal is switched off.
    """
    if self._traverse_subrepos:
        # Function-scope import: at module level Repo is only imported
        # under TYPE_CHECKING.
        from dvc.repo import Repo

        dvc_dir = os.path.join(dir_path, Repo.DVC_DIR)
        # dvcignore would ignore subrepos, hence `use_dvcignore=False`
        return self._main_repo.tree.isdir(dvc_dir, use_dvcignore=False)

    return False

def _get_tree_pairs(self, path) -> Tuple["BaseTree", Optional["DvcTree"]]:
def _get_tree_pair(
self, path
) -> Tuple[Union["GitTree", "LocalTree"], DvcTree]:
"""
Returns a pair of trees based on the repo the path falls in, using prefix.
"""
path = os.path.abspath(path)
repo = self._get_repo(path)
if not repo:
# path could be outside of the repo, so we just send them the main
# tree instead
return self._main_repo.tree, self._dvctrees.get(self.root_dir)

# fallback to the top-level repo if repo was not found
# this can happen if the path is outside of the repo
repo = self._get_repo(path) or self._main_repo

dvc_tree = self._dvctrees.get(repo.root_dir)
return repo.tree, dvc_tree
Expand All @@ -340,37 +367,37 @@ def open(
if "b" in mode:
encoding = None

tree, dvc_tree = self._get_tree_pairs(path)
tree, dvc_tree = self._get_tree_pair(path)
if dvc_tree and dvc_tree.exists(path):
return dvc_tree.open(path, mode=mode, encoding=encoding, **kwargs)
return tree.open(path, mode=mode, encoding=encoding)

def exists(
self, path, use_dvcignore=True
): # pylint: disable=arguments-differ
tree, dvc_tree = self._get_tree_pairs(path)
tree, dvc_tree = self._get_tree_pair(path)
return tree.exists(path) or (dvc_tree and dvc_tree.exists(path))

def isdir(self, path): # pylint: disable=arguments-differ
tree, dvc_tree = self._get_tree_pairs(path)
tree, dvc_tree = self._get_tree_pair(path)
return tree.isdir(path) or (dvc_tree and dvc_tree.isdir(path))

def isdvc(self, path, **kwargs):
_, dvc_tree = self._get_tree_pairs(path)
_, dvc_tree = self._get_tree_pair(path)
return dvc_tree is not None and dvc_tree.isdvc(path, **kwargs)

def isfile(self, path): # pylint: disable=arguments-differ
tree, dvc_tree = self._get_tree_pairs(path)
tree, dvc_tree = self._get_tree_pair(path)
return tree.isfile(path) or (dvc_tree and dvc_tree.isfile(path))

def isexec(self, path):
tree, dvc_tree = self._get_tree_pairs(path)
tree, dvc_tree = self._get_tree_pair(path)
if dvc_tree and dvc_tree.exists(path):
return dvc_tree.isexec(path)
return tree.isexec(path)

def stat(self, path):
tree, _ = self._get_tree_pairs(path)
tree, _ = self._get_tree_pair(path)
return tree.stat(path)

def _dvc_walk(self, walk):
Expand All @@ -383,7 +410,12 @@ def _dvc_walk(self, walk):
yield from self._dvc_walk(walk)

def _subrepo_walk(self, dir_path, **kwargs):
tree, dvc_tree = self._get_tree_pairs(dir_path)
"""Walk into a new repo.

NOTE: subrepo will only be discovered when walking if
ignore_subrepos is set to False.
"""
tree, dvc_tree = self._get_tree_pair(dir_path)
tree_walk = tree.walk(
dir_path, topdown=True, ignore_subrepos=not self._traverse_subrepos
)
Expand Down Expand Up @@ -464,7 +496,7 @@ def walk(
onerror(NotADirectoryError(top))
return

tree, dvc_tree = self._get_tree_pairs(top)
tree, dvc_tree = self._get_tree_pair(top)
dvc_exists = dvc_tree and dvc_tree.exists(top)
repo_exists = tree.exists(top)
if dvc_exists:
Expand Down Expand Up @@ -501,7 +533,7 @@ def get_file_hash(self, path_info):
"""
if not self.exists(path_info):
raise FileNotFoundError
_, dvc_tree = self._get_tree_pairs(path_info)
_, dvc_tree = self._get_tree_pair(path_info)
if dvc_tree and dvc_tree.exists(path_info):
try:
return dvc_tree.get_file_hash(path_info)
Expand Down