From 0cfb0a3aa7759907a6e9496ec1d245893295e71e Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Tue, 10 Nov 2020 04:35:42 +0200 Subject: [PATCH] dvc: refactor cache/remote relations --- dvc/cache/__init__.py | 47 +-- dvc/cache/base.py | 312 +++++++++++++++++- dvc/cache/local.py | 428 ++----------------------- dvc/{remote => cache}/ssh.py | 23 +- dvc/data_cloud.py | 18 +- dvc/path_info.py | 1 - dvc/remote/__init__.py | 3 - dvc/remote/base.py | 523 ++++++++++++++++++++++++------- dvc/repo/experiments/__init__.py | 2 +- dvc/repo/gc.py | 27 +- dvc/repo/status.py | 2 +- dvc/stage/cache.py | 2 +- dvc/tree/base.py | 199 ------------ dvc/tree/local.py | 32 +- dvc/tree/ssh/__init__.py | 17 - tests/func/test_cache.py | 2 +- tests/func/test_data_cloud.py | 6 +- tests/func/test_gc.py | 5 +- tests/unit/cache/__init__.py | 0 tests/unit/remote/test_base.py | 72 +++-- tests/unit/remote/test_local.py | 2 +- 21 files changed, 864 insertions(+), 859 deletions(-) rename dvc/{remote => cache}/ssh.py (76%) create mode 100644 tests/unit/cache/__init__.py diff --git a/dvc/cache/__init__.py b/dvc/cache/__init__.py index 0cfc72700a..fc78f47cde 100644 --- a/dvc/cache/__init__.py +++ b/dvc/cache/__init__.py @@ -4,6 +4,30 @@ from ..scheme import Schemes +def get_cloud_cache(tree): + from .base import CloudCache + from .local import LocalCache + from .ssh import SSHCache + + if tree.scheme == Schemes.LOCAL: + return LocalCache(tree) + + if tree.scheme == Schemes.SSH: + return SSHCache(tree) + + return CloudCache(tree) + + +def _get_cache(repo, settings): + from ..tree import get_cloud_tree + + if not settings: + return None + + tree = get_cloud_tree(repo, **settings) + return get_cloud_cache(tree) + + class Cache: """Class that manages cache locations of a DVC repo. @@ -15,20 +39,16 @@ class Cache: CLOUD_SCHEMES = [Schemes.S3, Schemes.GS, Schemes.SSH, Schemes.HDFS] def __init__(self, repo): - from ..tree import get_cloud_tree - from .base import CloudCache - from .local import LocalCache - self.repo = repo self.config = config = repo.config["cache"] + self._cache = {} local = config.get("local") if local: settings = {"name": local} elif "dir" not in config: - self.local = None - return + settings = None else: from ..config import LOCAL_COMMON @@ -37,21 +57,12 @@ def __init__(self, repo): if opt in config: settings[str(opt)] = config.get(opt) - tree = get_cloud_tree(repo, **settings) - - self._cache = {} - self._cache[Schemes.LOCAL] = LocalCache(tree) + self._cache[Schemes.LOCAL] = _get_cache(repo, settings) for scheme in self.CLOUD_SCHEMES: - remote = self.config.get(scheme) - if remote: - tree = get_cloud_tree(self.repo, name=remote) - cache = CloudCache(tree) - else: - cache = None - - self._cache[scheme] = cache + settings = {"name": remote} if remote else None + self._cache[scheme] = _get_cache(repo, settings) def __getattr__(self, name): try: diff --git a/dvc/cache/base.py b/dvc/cache/base.py index 35f4a1af7d..4454dff51e 100644 --- a/dvc/cache/base.py +++ b/dvc/cache/base.py @@ -1,5 +1,7 @@ +import itertools import json import logging +from concurrent.futures import ThreadPoolExecutor from copy import copy from funcy import decorator @@ -20,19 +22,6 @@ logger = logging.getLogger(__name__) -STATUS_OK = 1 -STATUS_MISSING = 2 -STATUS_NEW = 3 -STATUS_DELETED = 4 - -STATUS_MAP = { - # (local_exists, remote_exists) - (True, True): STATUS_OK, - (False, False): STATUS_MISSING, - (True, False): STATUS_NEW, - (False, True): STATUS_DELETED, -} - class DirCacheError(DvcException): def __init__(self, hash_): @@ -646,3 
+635,300 @@ def set_dir_info(self, hash_info): hash_info.dir_info = self.get_dir_cache(hash_info) hash_info.nfiles = hash_info.dir_info.nfiles + + def _list_paths(self, prefix=None, progress_callback=None): + if prefix: + if len(prefix) > 2: + path_info = self.tree.path_info / prefix[:2] / prefix[2:] + else: + path_info = self.tree.path_info / prefix[:2] + prefix = True + else: + path_info = self.tree.path_info + prefix = False + if progress_callback: + for file_info in self.tree.walk_files(path_info, prefix=prefix): + progress_callback() + yield file_info.path + else: + yield from self.tree.walk_files(path_info, prefix=prefix) + + def _path_to_hash(self, path): + parts = self.tree.PATH_CLS(path).parts[-2:] + + if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): + raise ValueError(f"Bad cache file path '{path}'") + + return "".join(parts) + + def list_hashes(self, prefix=None, progress_callback=None): + """Iterate over hashes in this tree. + + If `prefix` is specified, only hashes which begin with `prefix` + will be returned. + """ + for path in self._list_paths(prefix, progress_callback): + try: + yield self._path_to_hash(path) + except ValueError: + logger.debug( + "'%s' doesn't look like a cache file, skipping", path + ) + + def _hashes_with_limit(self, limit, prefix=None, progress_callback=None): + count = 0 + for hash_ in self.list_hashes(prefix, progress_callback): + yield hash_ + count += 1 + if count > limit: + logger.debug( + "`list_hashes()` returned max '{}' hashes, " + "skipping remaining results".format(limit) + ) + return + + def _max_estimation_size(self, hashes): + # Max remote size allowed for us to use traverse method + return max( + self.tree.TRAVERSE_THRESHOLD_SIZE, + len(hashes) + / self.tree.TRAVERSE_WEIGHT_MULTIPLIER + * self.tree.LIST_OBJECT_PAGE_SIZE, + ) + + def _estimate_remote_size(self, hashes=None, name=None): + """Estimate tree size based on number of entries beginning with + "00..." prefix. + """ + prefix = "0" * self.tree.TRAVERSE_PREFIX_LEN + total_prefixes = pow(16, self.tree.TRAVERSE_PREFIX_LEN) + if hashes: + max_hashes = self._max_estimation_size(hashes) + else: + max_hashes = None + + with Tqdm( + desc="Estimating size of " + + (f"cache in '{name}'" if name else "remote cache"), + unit="file", + ) as pbar: + + def update(n=1): + pbar.update(n * total_prefixes) + + if max_hashes: + hashes = self._hashes_with_limit( + max_hashes / total_prefixes, prefix, update + ) + else: + hashes = self.list_hashes(prefix, update) + + remote_hashes = set(hashes) + if remote_hashes: + remote_size = total_prefixes * len(remote_hashes) + else: + remote_size = total_prefixes + logger.debug(f"Estimated remote size: {remote_size} files") + return remote_size, remote_hashes + + def list_hashes_traverse( + self, remote_size, remote_hashes, jobs=None, name=None + ): + """Iterate over all hashes found in this tree. + Hashes are fetched in parallel according to prefix, except in + cases where the remote size is very small. + + All hashes from the remote (including any from the size + estimation step passed via the `remote_hashes` argument) will be + returned. + + NOTE: For large remotes the list of hashes will be very + big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list) + and we don't really need all of it at the same time, so it makes + sense to use a generator to gradually iterate over it, without + keeping all of it in memory. 
+ """ + num_pages = remote_size / self.tree.LIST_OBJECT_PAGE_SIZE + if num_pages < 256 / self.tree.JOBS: + # Fetching prefixes in parallel requires at least 255 more + # requests, for small enough remotes it will be faster to fetch + # entire cache without splitting it into prefixes. + # + # NOTE: this ends up re-fetching hashes that were already + # fetched during remote size estimation + traverse_prefixes = [None] + initial = 0 + else: + yield from remote_hashes + initial = len(remote_hashes) + traverse_prefixes = [f"{i:02x}" for i in range(1, 256)] + if self.tree.TRAVERSE_PREFIX_LEN > 2: + traverse_prefixes += [ + "{0:0{1}x}".format(i, self.tree.TRAVERSE_PREFIX_LEN) + for i in range( + 1, pow(16, self.tree.TRAVERSE_PREFIX_LEN - 2) + ) + ] + with Tqdm( + desc="Querying " + + (f"cache in '{name}'" if name else "remote cache"), + total=remote_size, + initial=initial, + unit="file", + ) as pbar: + + def list_with_update(prefix): + return list( + self.list_hashes( + prefix=prefix, progress_callback=pbar.update + ) + ) + + with ThreadPoolExecutor( + max_workers=jobs or self.tree.JOBS + ) as executor: + in_remote = executor.map(list_with_update, traverse_prefixes,) + yield from itertools.chain.from_iterable(in_remote) + + def all(self, jobs=None, name=None): + """Iterate over all hashes in this tree. + + Hashes will be fetched in parallel threads according to prefix + (except for small remotes) and a progress bar will be displayed. + """ + logger.debug( + "Fetching all hashes from '{}'".format( + name if name else "remote cache" + ) + ) + + if not self.tree.CAN_TRAVERSE: + return self.list_hashes() + + remote_size, remote_hashes = self._estimate_remote_size(name=name) + return self.list_hashes_traverse( + remote_size, remote_hashes, jobs, name + ) + + def _remove_unpacked_dir(self, hash_): + pass + + def gc(self, used, jobs=None): + removed = False + # hashes must be sorted to ensure we always remove .dir files first + for hash_ in sorted( + self.all(jobs, str(self.tree.path_info)), + key=self.tree.is_dir_hash, + reverse=True, + ): + if hash_ in used: + continue + path_info = self.tree.hash_to_path_info(hash_) + if self.tree.is_dir_hash(hash_): + # backward compatibility + # pylint: disable=protected-access + self._remove_unpacked_dir(hash_) + self.tree.remove(path_info) + removed = True + + return removed + + def list_hashes_exists(self, hashes, jobs=None, name=None): + """Return list of the specified hashes which exist in this tree. + Hashes will be queried individually. + """ + logger.debug( + "Querying {} hashes via object_exists".format(len(hashes)) + ) + with Tqdm( + desc="Querying " + + ("cache in " + name if name else "remote cache"), + total=len(hashes), + unit="file", + ) as pbar: + + def exists_with_progress(path_info): + ret = self.tree.exists(path_info) + pbar.update_msg(str(path_info)) + return ret + + with ThreadPoolExecutor( + max_workers=jobs or self.tree.JOBS + ) as executor: + path_infos = map(self.tree.hash_to_path_info, hashes) + in_remote = executor.map(exists_with_progress, path_infos) + ret = list(itertools.compress(hashes, in_remote)) + return ret + + def hashes_exist(self, hashes, jobs=None, name=None): + """Check if the given hashes are stored in the remote. + + There are two ways of performing this check: + + - Traverse method: Get a list of all the files in the remote + (traversing the cache directory) and compare it with + the given hashes. Cache entries will be retrieved in parallel + threads according to prefix (i.e. 
entries starting with, "00...", + "01...", and so on) and a progress bar will be displayed. + + - Exists method: For each given hash, run the `exists` + method and filter the hashes that aren't on the remote. + This is done in parallel threads. + It also shows a progress bar when performing the check. + + The reason for such an odd logic is that most of the remotes + take much shorter time to just retrieve everything they have under + a certain prefix (e.g. s3, gs, ssh, hdfs). Other remotes that can + check if particular file exists much quicker, use their own + implementation of hashes_exist (see ssh, local). + + Which method to use will be automatically determined after estimating + the size of the remote cache, and comparing the estimated size with + len(hashes). To estimate the size of the remote cache, we fetch + a small subset of cache entries (i.e. entries starting with "00..."). + Based on the number of entries in that subset, the size of the full + cache can be estimated, since the cache is evenly distributed according + to hash. + + Returns: + A list with hashes that were found in the remote + """ + # Remotes which do not use traverse prefix should override + # hashes_exist() (see ssh, local) + assert self.tree.TRAVERSE_PREFIX_LEN >= 2 + + hashes = set(hashes) + if len(hashes) == 1 or not self.tree.CAN_TRAVERSE: + remote_hashes = self.list_hashes_exists(hashes, jobs, name) + return remote_hashes + + # Max remote size allowed for us to use traverse method + remote_size, remote_hashes = self._estimate_remote_size(hashes, name) + + traverse_pages = remote_size / self.tree.LIST_OBJECT_PAGE_SIZE + # For sufficiently large remotes, traverse must be weighted to account + # for performance overhead from large lists/sets. + # From testing with S3, for remotes with 1M+ files, object_exists is + # faster until len(hashes) is at least 10k~100k + if remote_size > self.tree.TRAVERSE_THRESHOLD_SIZE: + traverse_weight = ( + traverse_pages * self.tree.TRAVERSE_WEIGHT_MULTIPLIER + ) + else: + traverse_weight = traverse_pages + if len(hashes) < traverse_weight: + logger.debug( + "Large remote ('{}' hashes < '{}' traverse weight), " + "using object_exists for remaining hashes".format( + len(hashes), traverse_weight + ) + ) + return list(hashes & remote_hashes) + self.list_hashes_exists( + hashes - remote_hashes, jobs, name + ) + + logger.debug("Querying '{}' hashes via traverse".format(len(hashes))) + remote_hashes = set( + self.list_hashes_traverse(remote_size, remote_hashes, jobs, name) + ) + return list(hashes & set(remote_hashes)) diff --git a/dvc/cache/local.py b/dvc/cache/local.py index 547c890895..e2c0199e23 100644 --- a/dvc/cache/local.py +++ b/dvc/cache/local.py @@ -1,54 +1,23 @@ -import errno import logging import os -from concurrent.futures import ThreadPoolExecutor, as_completed -from functools import partial, wraps -from funcy import cached_property, concat +from funcy import cached_property -from dvc.exceptions import DownloadError, UploadError from dvc.hash_info import HashInfo from dvc.path_info import PathInfo from dvc.progress import Tqdm -from ..remote.base import index_locked from ..tree.local import LocalTree -from .base import ( - STATUS_DELETED, - STATUS_MAP, - STATUS_MISSING, - STATUS_NEW, - CloudCache, - use_state, -) +from ..utils.fs import walk_files +from .base import CloudCache logger = logging.getLogger(__name__) -def _log_exceptions(func, operation): - @wraps(func) - def wrapper(from_info, to_info, *args, **kwargs): - try: - func(from_info, to_info, *args, **kwargs) - 
return 0 - except Exception as exc: # pylint: disable=broad-except - # NOTE: this means we ran out of file descriptors and there is no - # reason to try to proceed, as we will hit this error anyways. - # pylint: disable=no-member - if isinstance(exc, OSError) and exc.errno == errno.EMFILE: - raise - - logger.exception( - "failed to %s '%s' to '%s'", operation, from_info, to_info - ) - return 1 - - return wrapper - - class LocalCache(CloudCache): DEFAULT_CACHE_TYPES = ["reflink", "copy"] CACHE_MODE = LocalTree.CACHE_MODE + UNPACKED_DIR_SUFFIX = ".unpacked" def __init__(self, tree): super().__init__(tree) @@ -62,10 +31,6 @@ def cache_dir(self): def cache_dir(self, value): self.tree.path_info = PathInfo(value) if value else None - @classmethod - def supported(cls, config): # pylint: disable=unused-argument - return True - @cached_property def cache_path(self): return os.path.abspath(self.cache_dir) @@ -103,373 +68,24 @@ def _verify_link(self, path_info, link_type): super()._verify_link(path_info, link_type) - @index_locked - def status( - self, - named_cache, - remote, - jobs=None, - show_checksums=False, - download=False, - log_missing=True, - ): - # Return flattened dict containing all status info - dir_status, file_status, _ = self._status( - named_cache, - remote, - jobs=jobs, - show_checksums=show_checksums, - download=download, - log_missing=log_missing, - ) - return dict(dir_status, **file_status) - - def _status( - self, - named_cache, - remote, - jobs=None, - show_checksums=False, - download=False, - log_missing=True, - ): - """Return a tuple of (dir_status_info, file_status_info, dir_contents). - - dir_status_info contains status for .dir files, file_status_info - contains status for all other files, and dir_contents is a dict of - {dir_hash: set(file_hash, ...)} which can be used to map - a .dir file to its file contents. - """ - logger.debug( - f"Preparing to collect status from {remote.tree.path_info}" - ) - md5s = set(named_cache.scheme_keys(self.tree.scheme)) - - logger.debug("Collecting information from local cache...") - local_exists = frozenset( - self.hashes_exist(md5s, jobs=jobs, name=self.cache_dir) - ) - - # This is a performance optimization. We can safely assume that, - # if the resources that we want to fetch are already cached, - # there's no need to check the remote storage for the existence of - # those files. 
- if download and local_exists == md5s: - remote_exists = local_exists + def _list_paths(self, prefix=None, progress_callback=None): + assert self.tree.path_info is not None + if prefix: + path_info = self.tree.path_info / prefix[:2] + if not self.tree.exists(path_info): + return else: - logger.debug("Collecting information from remote cache...") - remote_exists = set() - dir_md5s = set(named_cache.dir_keys(self.tree.scheme)) - if dir_md5s: - remote_exists.update( - self._indexed_dir_hashes(named_cache, remote, dir_md5s) - ) - md5s.difference_update(remote_exists) - if md5s: - remote_exists.update( - remote.hashes_exist( - md5s, jobs=jobs, name=str(remote.tree.path_info) - ) - ) - return self._make_status( - named_cache, - show_checksums, - local_exists, - remote_exists, - log_missing, - ) - - def _make_status( - self, - named_cache, - show_checksums, - local_exists, - remote_exists, - log_missing, - ): - def make_names(hash_, names): - return {"name": hash_ if show_checksums else " ".join(names)} - - dir_status = {} - file_status = {} - dir_contents = {} - for hash_, item in named_cache[self.tree.scheme].items(): - if item.children: - dir_status[hash_] = make_names(hash_, item.names) - dir_contents[hash_] = set() - for child_hash, child in item.children.items(): - file_status[child_hash] = make_names( - child_hash, child.names - ) - dir_contents[hash_].add(child_hash) - else: - file_status[hash_] = make_names(hash_, item.names) - - self._fill_statuses(dir_status, local_exists, remote_exists) - self._fill_statuses(file_status, local_exists, remote_exists) - - if log_missing: - self._log_missing_caches(dict(dir_status, **file_status)) - - return dir_status, file_status, dir_contents - - def _indexed_dir_hashes(self, named_cache, remote, dir_md5s): - # Validate our index by verifying all indexed .dir hashes - # still exist on the remote - indexed_dirs = set(remote.index.dir_hashes()) - indexed_dir_exists = set() - if indexed_dirs: - indexed_dir_exists.update( - remote.tree.list_hashes_exists(indexed_dirs) - ) - missing_dirs = indexed_dirs.difference(indexed_dir_exists) - if missing_dirs: - logger.debug( - "Remote cache missing indexed .dir hashes '{}', " - "clearing remote index".format(", ".join(missing_dirs)) - ) - remote.index.clear() - - # Check if non-indexed (new) dir hashes exist on remote - dir_exists = dir_md5s.intersection(indexed_dir_exists) - dir_exists.update( - remote.tree.list_hashes_exists(dir_md5s - dir_exists) - ) - - # If .dir hash exists on the remote, assume directory contents - # still exists on the remote - for dir_hash in dir_exists: - file_hashes = list( - named_cache.child_keys(self.tree.scheme, dir_hash) - ) - if dir_hash not in remote.index: - logger.debug( - "Indexing new .dir '{}' with '{}' nested files".format( - dir_hash, len(file_hashes) - ) - ) - remote.index.update([dir_hash], file_hashes) - yield dir_hash - yield from file_hashes - - @staticmethod - def _fill_statuses(hash_info_dir, local_exists, remote_exists): - # Using sets because they are way faster for lookups - local = set(local_exists) - remote = set(remote_exists) - - for md5, info in hash_info_dir.items(): - status = STATUS_MAP[(md5 in local, md5 in remote)] - info["status"] = status - - def _get_plans(self, download, remote, status_info, status): - cache = [] - path_infos = [] - names = [] - hashes = [] - missing = [] - for md5, info in Tqdm( - status_info.items(), desc="Analysing status", unit="file" - ): - if info["status"] == status: - cache.append(self.tree.hash_to_path_info(md5)) - 
path_infos.append(remote.tree.hash_to_path_info(md5)) - names.append(info["name"]) - hashes.append(md5) - elif info["status"] == STATUS_MISSING: - missing.append(md5) - - if download: - to_infos = cache - from_infos = path_infos + path_info = self.tree.path_info + # NOTE: use utils.fs walk_files since tree.walk_files will not follow + # symlinks + if progress_callback: + for path in walk_files(path_info): + progress_callback() + yield path else: - to_infos = path_infos - from_infos = cache + yield from walk_files(path_info) - return (from_infos, to_infos, names, hashes), missing - - def _process( - self, - named_cache, - remote, - jobs=None, - show_checksums=False, - download=False, - ): - logger.debug( - "Preparing to {} '{}'".format( - "download data from" if download else "upload data to", - remote.tree.path_info, - ) - ) - - if download: - func = partial( - _log_exceptions(remote.tree.download, "download"), - dir_mode=self.tree.dir_mode, - file_mode=self.tree.file_mode, - ) - status = STATUS_DELETED - desc = "Downloading" - else: - func = _log_exceptions(remote.tree.upload, "upload") - status = STATUS_NEW - desc = "Uploading" - - if jobs is None: - jobs = remote.tree.JOBS - - dir_status, file_status, dir_contents = self._status( - named_cache, - remote, - jobs=jobs, - show_checksums=show_checksums, - download=download, - ) - - dir_plans, _ = self._get_plans(download, remote, dir_status, status) - file_plans, missing_files = self._get_plans( - download, remote, file_status, status - ) - - total = len(dir_plans[0]) + len(file_plans[0]) - if total == 0: - return 0 - - with Tqdm(total=total, unit="file", desc=desc) as pbar: - func = pbar.wrap_fn(func) - with ThreadPoolExecutor(max_workers=jobs) as executor: - if download: - from_infos, to_infos, names, _ = ( - d + f for d, f in zip(dir_plans, file_plans) - ) - fails = sum( - executor.map(func, from_infos, to_infos, names) - ) - else: - # for uploads, push files first, and any .dir files last - - file_futures = {} - for from_info, to_info, name, hash_ in zip(*file_plans): - file_futures[hash_] = executor.submit( - func, from_info, to_info, name - ) - dir_futures = {} - for from_info, to_info, name, dir_hash in zip(*dir_plans): - # if for some reason a file contained in this dir is - # missing both locally and in the remote, we want to - # push whatever file content we have, but should not - # push .dir file - for file_hash in missing_files: - if file_hash in dir_contents[dir_hash]: - logger.debug( - "directory '%s' contains missing files," - "skipping .dir file upload", - name, - ) - break - else: - wait_futures = { - future - for file_hash, future in file_futures.items() - if file_hash in dir_contents[dir_hash] - } - dir_futures[dir_hash] = executor.submit( - self._dir_upload, - func, - wait_futures, - from_info, - to_info, - name, - ) - fails = sum( - future.result() - for future in concat( - file_futures.values(), dir_futures.values() - ) - ) - - if fails: - if download: - remote.index.clear() - raise DownloadError(fails) - raise UploadError(fails) - - if not download: - # index successfully pushed dirs - for dir_hash, future in dir_futures.items(): - if future.result() == 0: - file_hashes = dir_contents[dir_hash] - logger.debug( - "Indexing pushed dir '{}' with " - "'{}' nested files".format(dir_hash, len(file_hashes)) - ) - remote.index.update([dir_hash], file_hashes) - - return len(dir_plans[0]) + len(file_plans[0]) - - @staticmethod - def _dir_upload(func, futures, from_info, to_info, name): - for future in as_completed(futures): - 
if future.result(): - # do not upload this .dir file if any file in this - # directory failed to upload - logger.debug( - "failed to upload full contents of '{}', " - "aborting .dir file upload".format(name) - ) - logger.error(f"failed to upload '{from_info}' to '{to_info}'") - return 1 - return func(from_info, to_info, name) - - @index_locked - def push(self, named_cache, remote, jobs=None, show_checksums=False): - return self._process( - named_cache, - remote, - jobs=jobs, - show_checksums=show_checksums, - download=False, - ) - - @use_state - @index_locked - def pull(self, named_cache, remote, jobs=None, show_checksums=False): - ret = self._process( - named_cache, - remote, - jobs=jobs, - show_checksums=show_checksums, - download=True, - ) - - if not remote.tree.verify: - for checksum in named_cache.scheme_keys("local"): - cache_file = self.tree.hash_to_path_info(checksum) - if self.tree.exists(cache_file): - # We can safely save here, as existing corrupted files will - # be removed upon status, while files corrupted during - # download will not be moved from tmp_file - # (see `BaseTree.download()`) - hash_info = HashInfo(self.tree.PARAM_CHECKSUM, checksum) - self.tree.state.save(cache_file, hash_info) - - return ret - - @staticmethod - def _log_missing_caches(hash_info_dict): - missing_caches = [ - (md5, info) - for md5, info in hash_info_dict.items() - if info["status"] == STATUS_MISSING - ] - if missing_caches: - missing_desc = "\n".join( - "name: {}, md5: {}".format(info["name"], md5) - for md5, info in missing_caches - ) - msg = ( - "Some of the cache files do not exist neither locally " - "nor on remote. Missing cache files:\n{}".format(missing_desc) - ) - logger.warning(msg) + def _remove_unpacked_dir(self, hash_): + info = self.tree.hash_to_path_info(hash_) + path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) + self.tree.remove(path_info) diff --git a/dvc/remote/ssh.py b/dvc/cache/ssh.py similarity index 76% rename from dvc/remote/ssh.py rename to dvc/cache/ssh.py index 84a261acd6..b734dd5765 100644 --- a/dvc/remote/ssh.py +++ b/dvc/cache/ssh.py @@ -1,17 +1,18 @@ import errno import itertools import logging +import posixpath from concurrent.futures import ThreadPoolExecutor from dvc.progress import Tqdm from dvc.utils import to_chunks -from .base import Remote +from .base import CloudCache logger = logging.getLogger(__name__) -class SSHRemote(Remote): +class SSHCache(CloudCache): def batch_exists(self, path_infos, callback): def _exists(chunk_and_channel): chunk, channel = chunk_and_channel @@ -47,7 +48,7 @@ def hashes_exist(self, hashes, jobs=None, name=None): remote/base. 
""" if not self.tree.CAN_TRAVERSE: - return list(set(hashes) & set(self.tree.all())) + return list(set(hashes) & set(self.all())) # possibly prompt for credentials before "Querying" progress output self.tree.ensure_credentials() @@ -71,3 +72,19 @@ def exists_with_progress(chunks): in_remote = itertools.chain.from_iterable(results) ret = list(itertools.compress(hashes, in_remote)) return ret + + def _list_paths(self, prefix=None, progress_callback=None): + if prefix: + root = posixpath.join(self.tree.path_info.path, prefix[:2]) + else: + root = self.tree.path_info.path + with self.tree.ssh(self.tree.path_info) as ssh: + if prefix and not ssh.exists(root): + return + # If we simply return an iterator then with above closes instantly + if progress_callback: + for path in ssh.walk_files(root): + progress_callback() + yield path + else: + yield from ssh.walk_files(root) diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 30f4a618d7..71f3653d77 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -62,8 +62,11 @@ def push( """ remote = self.get_remote(remote, "push") - return self.repo.cache.local.push( - cache, jobs=jobs, remote=remote, show_checksums=show_checksums, + return remote.push( + self.repo.cache.local, + cache, + jobs=jobs, + show_checksums=show_checksums, ) def pull( @@ -81,8 +84,11 @@ def pull( """ remote = self.get_remote(remote, "pull") - return self.repo.cache.local.pull( - cache, jobs=jobs, remote=remote, show_checksums=show_checksums + return remote.pull( + self.repo.cache.local, + cache, + jobs=jobs, + show_checksums=show_checksums, ) def status( @@ -107,10 +113,10 @@ def status( neither in cache, neither in cloud. """ remote = self.get_remote(remote, "status") - return self.repo.cache.local.status( + return remote.status( + self.repo.cache.local, cache, jobs=jobs, - remote=remote, show_checksums=show_checksums, log_missing=log_missing, ) diff --git a/dvc/path_info.py b/dvc/path_info.py index c170086bdb..3f35eba90f 100644 --- a/dvc/path_info.py +++ b/dvc/path_info.py @@ -124,7 +124,6 @@ def __init__(self, url): p = urlparse(url) assert not p.query and not p.params and not p.fragment assert p.password is None - self._fill_parts(p.scheme, p.hostname, p.username, p.port, p.path) @classmethod diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 0c622d4b40..7979ed7cf7 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -1,13 +1,10 @@ from ..tree import get_cloud_tree from .base import Remote from .local import LocalRemote -from .ssh import SSHRemote def get_remote(repo, **kwargs): tree = get_cloud_tree(repo, **kwargs) if tree.scheme == "local": return LocalRemote(tree) - if tree.scheme == "ssh": - return SSHRemote(tree) return Remote(tree) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 8b2bd90943..f44dc23026 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -1,19 +1,59 @@ +import errno import hashlib import logging -from functools import wraps +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import partial, wraps +from funcy import concat + +from dvc.exceptions import DownloadError, UploadError +from dvc.hash_info import HashInfo + +from ..progress import Tqdm from .index import RemoteIndex, RemoteIndexNoop logger = logging.getLogger(__name__) +STATUS_OK = 1 +STATUS_MISSING = 2 +STATUS_NEW = 3 +STATUS_DELETED = 4 + +STATUS_MAP = { + # (local_exists, remote_exists) + (True, True): STATUS_OK, + (False, False): STATUS_MISSING, + (True, False): STATUS_NEW, + (False, True): STATUS_DELETED, +} + 
+ +def _log_exceptions(func, operation): + @wraps(func) + def wrapper(from_info, to_info, *args, **kwargs): + try: + func(from_info, to_info, *args, **kwargs) + return 0 + except Exception as exc: # pylint: disable=broad-except + # NOTE: this means we ran out of file descriptors and there is no + # reason to try to proceed, as we will hit this error anyways. + # pylint: disable=no-member + if isinstance(exc, OSError) and exc.errno == errno.EMFILE: + raise + + logger.exception( + "failed to %s '%s' to '%s'", operation, from_info, to_info + ) + return 1 + + return wrapper + def index_locked(f): @wraps(f) - def wrapper(obj, named_cache, remote, *args, **kwargs): - if hasattr(remote, "index"): - with remote.index: - return f(obj, named_cache, remote, *args, **kwargs) - return f(obj, named_cache, remote, *args, **kwargs) + def wrapper(obj, *args, **kwargs): + with obj.index: + return f(obj, *args, **kwargs) return wrapper @@ -28,8 +68,11 @@ class Remote: INDEX_CLS = RemoteIndex def __init__(self, tree): + from dvc.cache import get_cloud_cache + self.tree = tree self.repo = tree.repo + self.cache = get_cloud_cache(self.tree) config = tree.config url = config.get("url") @@ -47,123 +90,393 @@ def __repr__(self): path_info=self.tree.path_info or "No path", ) - @property - def cache(self): - return getattr(self.repo.cache, self.tree.scheme) - - def hashes_exist(self, hashes, jobs=None, name=None): - """Check if the given hashes are stored in the remote. - - There are two ways of performing this check: - - - Traverse method: Get a list of all the files in the remote - (traversing the cache directory) and compare it with - the given hashes. Cache entries will be retrieved in parallel - threads according to prefix (i.e. entries starting with, "00...", - "01...", and so on) and a progress bar will be displayed. - - - Exists method: For each given hash, run the `exists` - method and filter the hashes that aren't on the remote. - This is done in parallel threads. - It also shows a progress bar when performing the check. - - The reason for such an odd logic is that most of the remotes - take much shorter time to just retrieve everything they have under - a certain prefix (e.g. s3, gs, ssh, hdfs). Other remotes that can - check if particular file exists much quicker, use their own - implementation of hashes_exist (see ssh, local). - - Which method to use will be automatically determined after estimating - the size of the remote cache, and comparing the estimated size with - len(hashes). To estimate the size of the remote cache, we fetch - a small subset of cache entries (i.e. entries starting with "00..."). - Based on the number of entries in that subset, the size of the full - cache can be estimated, since the cache is evenly distributed according - to hash. 
- - Returns: - A list with hashes that were found in the remote - """ - # Remotes which do not use traverse prefix should override - # hashes_exist() (see ssh, local) - assert self.tree.TRAVERSE_PREFIX_LEN >= 2 + @index_locked + def gc(self, *args, **kwargs): + removed = self.cache.gc(*args, **kwargs) + + if removed: + self.index.clear() + + return removed + + @index_locked + def status( + self, + cache, + named_cache, + jobs=None, + show_checksums=False, + download=False, + log_missing=True, + ): + # Return flattened dict containing all status info + dir_status, file_status, _ = self._status( + cache, + named_cache, + jobs=jobs, + show_checksums=show_checksums, + download=download, + log_missing=log_missing, + ) + return dict(dir_status, **file_status) + def hashes_exist(self, hashes, **kwargs): hashes = set(hashes) indexed_hashes = set(self.index.intersection(hashes)) hashes -= indexed_hashes + indexed_hashes = list(indexed_hashes) logger.debug("Matched '{}' indexed hashes".format(len(indexed_hashes))) if not hashes: return indexed_hashes - if len(hashes) == 1 or not self.tree.CAN_TRAVERSE: - remote_hashes = self.tree.list_hashes_exists(hashes, jobs, name) - return list(indexed_hashes) + remote_hashes + return indexed_hashes + self.cache.hashes_exist(list(hashes), **kwargs) + + def _status( + self, + cache, + named_cache, + jobs=None, + show_checksums=False, + download=False, + log_missing=True, + ): + """Return a tuple of (dir_status_info, file_status_info, dir_contents). + + dir_status_info contains status for .dir files, file_status_info + contains status for all other files, and dir_contents is a dict of + {dir_hash: set(file_hash, ...)} which can be used to map + a .dir file to its file contents. + """ + logger.debug(f"Preparing to collect status from {self.tree.path_info}") + md5s = set(named_cache.scheme_keys(cache.tree.scheme)) - # Max remote size allowed for us to use traverse method - remote_size, remote_hashes = self.tree.estimate_remote_size( - hashes, name + logger.debug("Collecting information from local cache...") + local_exists = frozenset( + cache.hashes_exist(md5s, jobs=jobs, name=cache.cache_dir) ) - traverse_pages = remote_size / self.tree.LIST_OBJECT_PAGE_SIZE - # For sufficiently large remotes, traverse must be weighted to account - # for performance overhead from large lists/sets. - # From testing with S3, for remotes with 1M+ files, object_exists is - # faster until len(hashes) is at least 10k~100k - if remote_size > self.tree.TRAVERSE_THRESHOLD_SIZE: - traverse_weight = ( - traverse_pages * self.tree.TRAVERSE_WEIGHT_MULTIPLIER - ) + # This is a performance optimization. We can safely assume that, + # if the resources that we want to fetch are already cached, + # there's no need to check the remote storage for the existence of + # those files. 
+ if download and local_exists == md5s: + remote_exists = local_exists else: - traverse_weight = traverse_pages - if len(hashes) < traverse_weight: - logger.debug( - "Large remote ('{}' hashes < '{}' traverse weight), " - "using object_exists for remaining hashes".format( - len(hashes), traverse_weight + logger.debug("Collecting information from remote cache...") + remote_exists = set() + dir_md5s = set(named_cache.dir_keys(cache.tree.scheme)) + if dir_md5s: + remote_exists.update( + self._indexed_dir_hashes(cache, named_cache, dir_md5s) + ) + md5s.difference_update(remote_exists) + if md5s: + remote_exists.update( + self.hashes_exist( + md5s, jobs=jobs, name=str(self.tree.path_info) + ) ) + return self._make_status( + cache, + named_cache, + show_checksums, + local_exists, + remote_exists, + log_missing, + ) + + def _make_status( + self, + cache, + named_cache, + show_checksums, + local_exists, + remote_exists, + log_missing, + ): + def make_names(hash_, names): + return {"name": hash_ if show_checksums else " ".join(names)} + + dir_status = {} + file_status = {} + dir_contents = {} + for hash_, item in named_cache[cache.tree.scheme].items(): + if item.children: + dir_status[hash_] = make_names(hash_, item.names) + dir_contents[hash_] = set() + for child_hash, child in item.children.items(): + file_status[child_hash] = make_names( + child_hash, child.names + ) + dir_contents[hash_].add(child_hash) + else: + file_status[hash_] = make_names(hash_, item.names) + + self._fill_statuses(dir_status, local_exists, remote_exists) + self._fill_statuses(file_status, local_exists, remote_exists) + + if log_missing: + self._log_missing_caches(dict(dir_status, **file_status)) + + return dir_status, file_status, dir_contents + + def _indexed_dir_hashes(self, cache, named_cache, dir_md5s): + # Validate our index by verifying all indexed .dir hashes + # still exist on the remote + indexed_dirs = set(self.index.dir_hashes()) + indexed_dir_exists = set() + if indexed_dirs: + indexed_dir_exists.update( + self.cache.list_hashes_exists(indexed_dirs) + ) + missing_dirs = indexed_dirs.difference(indexed_dir_exists) + if missing_dirs: + logger.debug( + "Remote cache missing indexed .dir hashes '{}', " + "clearing remote index".format(", ".join(missing_dirs)) + ) + self.index.clear() + + # Check if non-indexed (new) dir hashes exist on remote + dir_exists = dir_md5s.intersection(indexed_dir_exists) + dir_exists.update(self.cache.list_hashes_exists(dir_md5s - dir_exists)) + + # If .dir hash exists on the remote, assume directory contents + # still exists on the remote + for dir_hash in dir_exists: + file_hashes = list( + named_cache.child_keys(cache.tree.scheme, dir_hash) ) - return ( - list(indexed_hashes) - + list(hashes & remote_hashes) - + self.tree.list_hashes_exists( - hashes - remote_hashes, jobs, name + if dir_hash not in self.index: + logger.debug( + "Indexing new .dir '{}' with '{}' nested files".format( + dir_hash, len(file_hashes) + ) ) + self.index.update([dir_hash], file_hashes) + yield dir_hash + yield from file_hashes + + @staticmethod + def _fill_statuses(hash_info_dir, local_exists, remote_exists): + # Using sets because they are way faster for lookups + local = set(local_exists) + remote = set(remote_exists) + + for md5, info in hash_info_dir.items(): + status = STATUS_MAP[(md5 in local, md5 in remote)] + info["status"] = status + + def _get_plans(self, cache_obj, download, status_info, status): + cache = [] + path_infos = [] + names = [] + hashes = [] + missing = [] + for md5, info in Tqdm( + 
status_info.items(), desc="Analysing status", unit="file" + ): + if info["status"] == status: + cache.append(cache_obj.tree.hash_to_path_info(md5)) + path_infos.append(self.tree.hash_to_path_info(md5)) + names.append(info["name"]) + hashes.append(md5) + elif info["status"] == STATUS_MISSING: + missing.append(md5) + + if download: + to_infos = cache + from_infos = path_infos + else: + to_infos = path_infos + from_infos = cache + + return (from_infos, to_infos, names, hashes), missing + + def _process( + self, + cache, + named_cache, + jobs=None, + show_checksums=False, + download=False, + ): + logger.debug( + "Preparing to {} '{}'".format( + "download data from" if download else "upload data to", + self.tree.path_info, ) + ) - logger.debug("Querying '{}' hashes via traverse".format(len(hashes))) - remote_hashes = set( - self.tree.list_hashes_traverse( - remote_size, remote_hashes, jobs, name + if download: + func = partial( + _log_exceptions(self.tree.download, "download"), + dir_mode=cache.tree.dir_mode, + file_mode=cache.tree.file_mode, ) + status = STATUS_DELETED + desc = "Downloading" + else: + func = _log_exceptions(self.tree.upload, "upload") + status = STATUS_NEW + desc = "Uploading" + + if jobs is None: + jobs = self.tree.JOBS + + dir_status, file_status, dir_contents = self._status( + cache, + named_cache, + jobs=jobs, + show_checksums=show_checksums, + download=download, + ) + + dir_plans, _ = self._get_plans(cache, download, dir_status, status) + file_plans, missing_files = self._get_plans( + cache, download, file_status, status + ) + + total = len(dir_plans[0]) + len(file_plans[0]) + if total == 0: + return 0 + + with Tqdm(total=total, unit="file", desc=desc) as pbar: + func = pbar.wrap_fn(func) + with ThreadPoolExecutor(max_workers=jobs) as executor: + if download: + from_infos, to_infos, names, _ = ( + d + f for d, f in zip(dir_plans, file_plans) + ) + fails = sum( + executor.map(func, from_infos, to_infos, names) + ) + else: + # for uploads, push files first, and any .dir files last + + file_futures = {} + for from_info, to_info, name, hash_ in zip(*file_plans): + file_futures[hash_] = executor.submit( + func, from_info, to_info, name + ) + dir_futures = {} + for from_info, to_info, name, dir_hash in zip(*dir_plans): + # if for some reason a file contained in this dir is + # missing both locally and in the remote, we want to + # push whatever file content we have, but should not + # push .dir file + for file_hash in missing_files: + if file_hash in dir_contents[dir_hash]: + logger.debug( + "directory '%s' contains missing files," + "skipping .dir file upload", + name, + ) + break + else: + wait_futures = { + future + for file_hash, future in file_futures.items() + if file_hash in dir_contents[dir_hash] + } + dir_futures[dir_hash] = executor.submit( + self._dir_upload, + func, + wait_futures, + from_info, + to_info, + name, + ) + fails = sum( + future.result() + for future in concat( + file_futures.values(), dir_futures.values() + ) + ) + + if fails: + if download: + self.index.clear() + raise DownloadError(fails) + raise UploadError(fails) + + if not download: + # index successfully pushed dirs + for dir_hash, future in dir_futures.items(): + if future.result() == 0: + file_hashes = dir_contents[dir_hash] + logger.debug( + "Indexing pushed dir '{}' with " + "'{}' nested files".format(dir_hash, len(file_hashes)) + ) + self.index.update([dir_hash], file_hashes) + + return len(dir_plans[0]) + len(file_plans[0]) + + @staticmethod + def _dir_upload(func, futures, from_info, 
to_info, name): + for future in as_completed(futures): + if future.result(): + # do not upload this .dir file if any file in this + # directory failed to upload + logger.debug( + "failed to upload full contents of '{}', " + "aborting .dir file upload".format(name) + ) + logger.error(f"failed to upload '{from_info}' to '{to_info}'") + return 1 + return func(from_info, to_info, name) + + @index_locked + def push(self, cache, named_cache, jobs=None, show_checksums=False): + return self._process( + cache, + named_cache, + jobs=jobs, + show_checksums=show_checksums, + download=False, ) - return list(indexed_hashes) + list(hashes & set(remote_hashes)) - @classmethod @index_locked - def gc(cls, named_cache, remote, jobs=None): - tree = remote.tree - used = set(named_cache.scheme_keys("local")) - - if tree.scheme != "": - used.update(named_cache.scheme_keys(tree.scheme)) - - removed = False - # hashes must be sorted to ensure we always remove .dir files first - for hash_ in sorted( - tree.all(jobs, str(tree.path_info)), - key=tree.is_dir_hash, - reverse=True, - ): - if hash_ in used: - continue - path_info = tree.hash_to_path_info(hash_) - if tree.is_dir_hash(hash_): - # backward compatibility - # pylint: disable=protected-access - tree._remove_unpacked_dir(hash_) - tree.remove(path_info) - removed = True - - if removed and hasattr(remote, "index"): - remote.index.clear() - return removed + def pull(self, cache, named_cache, jobs=None, show_checksums=False): + ret = self._process( + cache, + named_cache, + jobs=jobs, + show_checksums=show_checksums, + download=True, + ) + + if not self.tree.verify: + with cache.tree.state: + for checksum in named_cache.scheme_keys("local"): + cache_file = cache.tree.hash_to_path_info(checksum) + if cache.tree.exists(cache_file): + # We can safely save here, as existing corrupted files + # will be removed upon status, while files corrupted + # during download will not be moved from tmp_file + # (see `BaseTree.download()`) + hash_info = HashInfo( + cache.tree.PARAM_CHECKSUM, checksum + ) + cache.tree.state.save(cache_file, hash_info) + + return ret + + @staticmethod + def _log_missing_caches(hash_info_dict): + missing_caches = [ + (md5, info) + for md5, info in hash_info_dict.items() + if info["status"] == STATUS_MISSING + ] + if missing_caches: + missing_desc = "\n".join( + "name: {}, md5: {}".format(info["name"], md5) + for md5, info in missing_caches + ) + msg = ( + "Some of the cache files do not exist neither locally " + "nor on remote. Missing cache files:\n{}".format(missing_desc) + ) + logger.warning(msg) diff --git a/dvc/repo/experiments/__init__.py b/dvc/repo/experiments/__init__.py index bd3002aa3a..6340a1e7be 100644 --- a/dvc/repo/experiments/__init__.py +++ b/dvc/repo/experiments/__init__.py @@ -717,7 +717,7 @@ def _collect_output(self, executor: ExperimentExecutor): @staticmethod def _process(dest_tree, src_tree, collected_files, download=False): - from dvc.cache.local import _log_exceptions + from dvc.remote.base import _log_exceptions from_infos = [] to_infos = [] diff --git a/dvc/repo/gc.py b/dvc/repo/gc.py index 79e6aa4b63..9b415b7a87 100644 --- a/dvc/repo/gc.py +++ b/dvc/repo/gc.py @@ -3,19 +3,12 @@ from dvc.cache import NamedCache from dvc.exceptions import InvalidArgumentError +from ..scheme import Schemes from . 
import locked logger = logging.getLogger(__name__) -def _do_gc(typ, remote, clist, jobs=None): - from dvc.remote.base import Remote - - removed = Remote.gc(clist, remote, jobs=jobs) - if not removed: - logger.info(f"No unused '{typ}' cache to remove.") - - def _raise_error_if_all_disabled(**kwargs): if not any(kwargs.values()): raise InvalidArgumentError( @@ -76,10 +69,18 @@ def gc( ) ) - # treat caches as remotes for garbage collection for scheme, cache in self.cache.by_scheme(): - if cache: - _do_gc(scheme, cache, used, jobs) + if not cache: + continue - if cloud: - _do_gc("remote", self.cloud.get_remote(remote, "gc -c"), used, jobs) + removed = cache.gc(set(used.scheme_keys(scheme)), jobs=jobs) + if not removed: + logger.info(f"No unused '{scheme}' cache to remove.") + + if not cloud: + return + + remote = self.cloud.get_remote(remote, "gc -c") + removed = remote.gc(set(used.scheme_keys(Schemes.LOCAL)), jobs=jobs) + if not removed: + logger.info("No unused cache to remove from remote.") diff --git a/dvc/repo/status.py b/dvc/repo/status.py index 2cee28e967..accf186093 100644 --- a/dvc/repo/status.py +++ b/dvc/repo/status.py @@ -74,7 +74,7 @@ def _cloud_status( { "bar": "deleted" } """ - import dvc.cache.base as cloud + import dvc.remote.base as cloud used = self.used_cache( targets, diff --git a/dvc/stage/cache.py b/dvc/stage/cache.py index e7c5b721ac..2aabf3fddb 100644 --- a/dvc/stage/cache.py +++ b/dvc/stage/cache.py @@ -6,9 +6,9 @@ from funcy import cached_property, first from voluptuous import Invalid -from dvc.cache.local import _log_exceptions from dvc.exceptions import DvcException from dvc.path_info import PathInfo +from dvc.remote.base import _log_exceptions from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA from dvc.utils import dict_sha256, relpath from dvc.utils.serialize import YAMLFileCorruptedError, dump_yaml, load_yaml diff --git a/dvc/tree/base.py b/dvc/tree/base.py index fb1fbc3e56..4f594ec32e 100644 --- a/dvc/tree/base.py +++ b/dvc/tree/base.py @@ -1,4 +1,3 @@ -import itertools import logging from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial @@ -282,14 +281,6 @@ def get_file_hash(self, path_info): def hash_to_path_info(self, hash_): return self.path_info / hash_[0:2] / hash_[2:] - def path_to_hash(self, path): - parts = self.PATH_CLS(path).parts[-2:] - - if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): - raise ValueError(f"Bad cache file path '{path}'") - - return "".join(parts) - def _calculate_hashes(self, file_infos): file_infos = list(file_infos) with Tqdm( @@ -449,193 +440,3 @@ def _download_file( ) move(tmp_file, to_info, mode=file_mode) - - def list_paths(self, prefix=None, progress_callback=None): - if prefix: - if len(prefix) > 2: - path_info = self.path_info / prefix[:2] / prefix[2:] - else: - path_info = self.path_info / prefix[:2] - prefix = True - else: - path_info = self.path_info - prefix = False - if progress_callback: - for file_info in self.walk_files(path_info, prefix=prefix): - progress_callback() - yield file_info.path - else: - yield from self.walk_files(path_info, prefix=prefix) - - def list_hashes(self, prefix=None, progress_callback=None): - """Iterate over hashes in this tree. - - If `prefix` is specified, only hashes which begin with `prefix` - will be returned. 
- """ - for path in self.list_paths(prefix, progress_callback): - try: - yield self.path_to_hash(path) - except ValueError: - logger.debug( - "'%s' doesn't look like a cache file, skipping", path - ) - - def all(self, jobs=None, name=None): - """Iterate over all hashes in this tree. - - Hashes will be fetched in parallel threads according to prefix - (except for small remotes) and a progress bar will be displayed. - """ - logger.debug( - "Fetching all hashes from '{}'".format( - name if name else "remote cache" - ) - ) - - if not self.CAN_TRAVERSE: - return self.list_hashes() - - remote_size, remote_hashes = self.estimate_remote_size(name=name) - return self.list_hashes_traverse( - remote_size, remote_hashes, jobs, name - ) - - def _hashes_with_limit(self, limit, prefix=None, progress_callback=None): - count = 0 - for hash_ in self.list_hashes(prefix, progress_callback): - yield hash_ - count += 1 - if count > limit: - logger.debug( - "`list_hashes()` returned max '{}' hashes, " - "skipping remaining results".format(limit) - ) - return - - def _max_estimation_size(self, hashes): - # Max remote size allowed for us to use traverse method - return max( - self.TRAVERSE_THRESHOLD_SIZE, - len(hashes) - / self.TRAVERSE_WEIGHT_MULTIPLIER - * self.LIST_OBJECT_PAGE_SIZE, - ) - - def estimate_remote_size(self, hashes=None, name=None): - """Estimate tree size based on number of entries beginning with - "00..." prefix. - """ - prefix = "0" * self.TRAVERSE_PREFIX_LEN - total_prefixes = pow(16, self.TRAVERSE_PREFIX_LEN) - if hashes: - max_hashes = self._max_estimation_size(hashes) - else: - max_hashes = None - - with Tqdm( - desc="Estimating size of " - + (f"cache in '{name}'" if name else "remote cache"), - unit="file", - ) as pbar: - - def update(n=1): - pbar.update(n * total_prefixes) - - if max_hashes: - hashes = self._hashes_with_limit( - max_hashes / total_prefixes, prefix, update - ) - else: - hashes = self.list_hashes(prefix, update) - - remote_hashes = set(hashes) - if remote_hashes: - remote_size = total_prefixes * len(remote_hashes) - else: - remote_size = total_prefixes - logger.debug(f"Estimated remote size: {remote_size} files") - return remote_size, remote_hashes - - def list_hashes_traverse( - self, remote_size, remote_hashes, jobs=None, name=None - ): - """Iterate over all hashes found in this tree. - Hashes are fetched in parallel according to prefix, except in - cases where the remote size is very small. - - All hashes from the remote (including any from the size - estimation step passed via the `remote_hashes` argument) will be - returned. - - NOTE: For large remotes the list of hashes will be very - big(e.g. 100M entries, md5 for each is 32 bytes, so ~3200Mb list) - and we don't really need all of it at the same time, so it makes - sense to use a generator to gradually iterate over it, without - keeping all of it in memory. - """ - num_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE - if num_pages < 256 / self.JOBS: - # Fetching prefixes in parallel requires at least 255 more - # requests, for small enough remotes it will be faster to fetch - # entire cache without splitting it into prefixes. 
-            #
-            # NOTE: this ends up re-fetching hashes that were already
-            # fetched during remote size estimation
-            traverse_prefixes = [None]
-            initial = 0
-        else:
-            yield from remote_hashes
-            initial = len(remote_hashes)
-            traverse_prefixes = [f"{i:02x}" for i in range(1, 256)]
-            if self.TRAVERSE_PREFIX_LEN > 2:
-                traverse_prefixes += [
-                    "{0:0{1}x}".format(i, self.TRAVERSE_PREFIX_LEN)
-                    for i in range(1, pow(16, self.TRAVERSE_PREFIX_LEN - 2))
-                ]
-        with Tqdm(
-            desc="Querying "
-            + (f"cache in '{name}'" if name else "remote cache"),
-            total=remote_size,
-            initial=initial,
-            unit="file",
-        ) as pbar:
-
-            def list_with_update(prefix):
-                return list(
-                    self.list_hashes(
-                        prefix=prefix, progress_callback=pbar.update
-                    )
-                )
-
-            with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
-                in_remote = executor.map(list_with_update, traverse_prefixes,)
-                yield from itertools.chain.from_iterable(in_remote)
-
-    def list_hashes_exists(self, hashes, jobs=None, name=None):
-        """Return list of the specified hashes which exist in this tree.
-        Hashes will be queried individually.
-        """
-        logger.debug(
-            "Querying {} hashes via object_exists".format(len(hashes))
-        )
-        with Tqdm(
-            desc="Querying "
-            + ("cache in " + name if name else "remote cache"),
-            total=len(hashes),
-            unit="file",
-        ) as pbar:
-
-            def exists_with_progress(path_info):
-                ret = self.exists(path_info)
-                pbar.update_msg(str(path_info))
-                return ret
-
-            with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor:
-                path_infos = map(self.hash_to_path_info, hashes)
-                in_remote = executor.map(exists_with_progress, path_infos)
-                ret = list(itertools.compress(hashes, in_remote))
-                return ret
-
-    def _remove_unpacked_dir(self, hash_):
-        pass
diff --git a/dvc/tree/local.py b/dvc/tree/local.py
index ce56cc78c4..aa25a12fdd 100644
--- a/dvc/tree/local.py
+++ b/dvc/tree/local.py
@@ -12,14 +12,7 @@
 from dvc.scheme import Schemes
 from dvc.system import System
 from dvc.utils import file_md5, is_exec, relpath, tmp_fname
-from dvc.utils.fs import (
-    copy_fobj_to_file,
-    copyfile,
-    makedirs,
-    move,
-    remove,
-    walk_files,
-)
+from dvc.utils.fs import copy_fobj_to_file, copyfile, makedirs, move, remove
 
 from .base import BaseTree
 
@@ -32,7 +25,6 @@ class LocalTree(BaseTree):
     PARAM_CHECKSUM = "md5"
     PARAM_PATH = "path"
     TRAVERSE_PREFIX_LEN = 2
-    UNPACKED_DIR_SUFFIX = ".unpacked"
     CACHE_MODE = 0o444
     SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)}
 
@@ -344,27 +336,5 @@ def _download(
             from_info, to_file, no_progress_bar=no_progress_bar, name=name
         )
 
-    def list_paths(self, prefix=None, progress_callback=None):
-        assert self.path_info is not None
-        if prefix:
-            path_info = self.path_info / prefix[:2]
-            if not self.exists(path_info):
-                return
-        else:
-            path_info = self.path_info
-        # NOTE: use utils.fs walk_files since tree.walk_files will not follow
-        # symlinks
-        if progress_callback:
-            for path in walk_files(path_info):
-                progress_callback()
-                yield path
-        else:
-            yield from walk_files(path_info)
-
-    def _remove_unpacked_dir(self, hash_):
-        info = self.hash_to_path_info(hash_)
-        path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX)
-        self.remove(path_info)
-
     def _reset(self):
         return self.__dict__.pop("dvcignore", None)
diff --git a/dvc/tree/ssh/__init__.py b/dvc/tree/ssh/__init__.py
index 1af9a88cab..59193a8fbd 100644
--- a/dvc/tree/ssh/__init__.py
+++ b/dvc/tree/ssh/__init__.py
@@ -2,7 +2,6 @@
 import io
 import logging
 import os
-import posixpath
 import threading
 from contextlib import closing, contextmanager
 from urllib.parse import urlparse
@@ -269,19 +268,3 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
                 progress_title=name,
                 no_progress_bar=no_progress_bar,
             )
-
-    def list_paths(self, prefix=None, progress_callback=None):
-        if prefix:
-            root = posixpath.join(self.path_info.path, prefix[:2])
-        else:
-            root = self.path_info.path
-        with self.ssh(self.path_info) as ssh:
-            if prefix and not ssh.exists(root):
-                return
-            # If we simply return an iterator then with above closes instantly
-            if progress_callback:
-                for path in ssh.walk_files(root):
-                    progress_callback()
-                    yield path
-            else:
-                yield from ssh.walk_files(root)
diff --git a/tests/func/test_cache.py b/tests/func/test_cache.py
index b32fd2606e..53c8f65643 100644
--- a/tests/func/test_cache.py
+++ b/tests/func/test_cache.py
@@ -32,7 +32,7 @@ def setUp(self):
         self.create(self.cache2, "2")
 
     def test_all(self):
-        md5_list = list(Cache(self.dvc).local.tree.all())
+        md5_list = list(Cache(self.dvc).local.all())
         self.assertEqual(len(md5_list), 2)
         self.assertIn(self.cache1_md5, md5_list)
         self.assertIn(self.cache2_md5, md5_list)
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index f370bae0af..eb687a692c 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -6,14 +6,14 @@
 from flaky.flaky_decorator import flaky
 
 from dvc.cache import NamedCache
-from dvc.cache.base import (
+from dvc.external_repo import clean_repos
+from dvc.main import main
+from dvc.remote.base import (
     STATUS_DELETED,
     STATUS_MISSING,
     STATUS_NEW,
     STATUS_OK,
 )
-from dvc.external_repo import clean_repos
-from dvc.main import main
 from dvc.stage.exceptions import StageNotFound
 from dvc.tree.local import LocalTree
 from dvc.utils.fs import move, remove
diff --git a/tests/func/test_gc.py b/tests/func/test_gc.py
index 317683cf53..90b5664a49 100644
--- a/tests/func/test_gc.py
+++ b/tests/func/test_gc.py
@@ -6,6 +6,7 @@
 import pytest
 from git import Repo
 
+from dvc.cache.local import LocalCache
 from dvc.exceptions import CollectCacheError
 from dvc.main import main
 from dvc.repo import Repo as DvcRepo
@@ -22,7 +23,7 @@ def setUp(self):
         self.dvc.add(self.DATA_DIR)
         self.good_cache = [
             self.dvc.cache.local.tree.hash_to_path_info(md5)
-            for md5 in self.dvc.cache.local.tree.all()
+            for md5 in self.dvc.cache.local.all()
         ]
 
         self.bad_cache = []
@@ -217,7 +218,7 @@ def test_gc_no_unpacked_dir(tmp_dir, dvc):
 
     os.remove("dir.dvc")
     unpackeddir = (
-        dir_stages[0].outs[0].cache_path + LocalTree.UNPACKED_DIR_SUFFIX
+        dir_stages[0].outs[0].cache_path + LocalCache.UNPACKED_DIR_SUFFIX
    )
 
     # older (pre 1.0) versions of dvc used to generate this dir
diff --git a/tests/unit/cache/__init__.py b/tests/unit/cache/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py
index f40a5e8927..1f793ed83d 100644
--- a/tests/unit/remote/test_base.py
+++ b/tests/unit/remote/test_base.py
@@ -3,8 +3,8 @@
 import mock
 import pytest
 
+from dvc.cache.base import CloudCache
 from dvc.path_info import PathInfo
-from dvc.remote.base import Remote
 from dvc.tree.base import BaseTree, RemoteCmdError, RemoteMissingDepsError
 
 
@@ -39,36 +39,36 @@ def test_cmd_error(dvc):
         BaseTree(dvc, config).remove("file")
 
 
-@mock.patch.object(BaseTree, "list_hashes_traverse")
-@mock.patch.object(BaseTree, "list_hashes_exists")
+@mock.patch.object(CloudCache, "list_hashes_traverse")
+@mock.patch.object(CloudCache, "list_hashes_exists")
 def test_hashes_exist(object_exists, traverse, dvc):
-    remote = Remote(BaseTree(dvc, {}))
+    cache = CloudCache(BaseTree(dvc, {}))
 
     # remote does not support traverse
-    remote.tree.CAN_TRAVERSE = False
+    cache.tree.CAN_TRAVERSE = False
     with mock.patch.object(
-        remote.tree, "list_hashes", return_value=list(range(256))
+        cache, "list_hashes", return_value=list(range(256))
     ):
         hashes = set(range(1000))
-        remote.hashes_exist(hashes)
+        cache.hashes_exist(hashes)
         object_exists.assert_called_with(hashes, None, None)
         traverse.assert_not_called()
 
-    remote.tree.CAN_TRAVERSE = True
+    cache.tree.CAN_TRAVERSE = True
 
     # large remote, small local
     object_exists.reset_mock()
     traverse.reset_mock()
     with mock.patch.object(
-        remote.tree, "list_hashes", return_value=list(range(256))
+        cache, "list_hashes", return_value=list(range(256))
     ):
         hashes = list(range(1000))
-        remote.hashes_exist(hashes)
+        cache.hashes_exist(hashes)
         # verify that _cache_paths_with_max() short circuits
         # before returning all 256 remote hashes
         max_hashes = math.ceil(
-            remote.tree._max_estimation_size(hashes)
-            / pow(16, remote.tree.TRAVERSE_PREFIX_LEN)
+            cache._max_estimation_size(hashes)
+            / pow(16, cache.tree.TRAVERSE_PREFIX_LEN)
         )
         assert max_hashes < 256
         object_exists.assert_called_with(
@@ -79,15 +79,15 @@ def test_hashes_exist(object_exists, traverse, dvc):
     # large remote, large local
     object_exists.reset_mock()
     traverse.reset_mock()
-    remote.tree.JOBS = 16
+    cache.tree.JOBS = 16
     with mock.patch.object(
-        remote.tree, "list_hashes", return_value=list(range(256))
+        cache, "list_hashes", return_value=list(range(256))
     ):
         hashes = list(range(1000000))
-        remote.hashes_exist(hashes)
+        cache.hashes_exist(hashes)
         object_exists.assert_not_called()
         traverse.assert_called_with(
-            256 * pow(16, remote.tree.TRAVERSE_PREFIX_LEN),
+            256 * pow(16, cache.tree.TRAVERSE_PREFIX_LEN),
             set(range(256)),
             None,
             None,
@@ -95,18 +95,18 @@ def test_hashes_exist(object_exists, traverse, dvc):
 
 
 @mock.patch.object(
-    BaseTree, "list_hashes", return_value=[],
+    CloudCache, "list_hashes", return_value=[],
 )
 @mock.patch.object(
-    BaseTree, "path_to_hash", side_effect=lambda x: x,
+    CloudCache, "_path_to_hash", side_effect=lambda x: x,
 )
 def test_list_hashes_traverse(_path_to_hash, list_hashes, dvc):
-    tree = BaseTree(dvc, {})
-    tree.path_info = PathInfo("foo")
+    cache = CloudCache(BaseTree(dvc, {}))
+    cache.tree.path_info = PathInfo("foo")
 
     # parallel traverse
-    size = 256 / tree.JOBS * tree.LIST_OBJECT_PAGE_SIZE
-    list(tree.list_hashes_traverse(size, {0}))
+    size = 256 / cache.tree.JOBS * cache.tree.LIST_OBJECT_PAGE_SIZE
+    list(cache.list_hashes_traverse(size, {0}))
     for i in range(1, 16):
         list_hashes.assert_any_call(
             prefix=f"{i:03x}", progress_callback=CallableOrNone
@@ -119,35 +119,39 @@ def test_list_hashes_traverse(_path_to_hash, list_hashes, dvc):
     # default traverse (small remote)
     size -= 1
     list_hashes.reset_mock()
-    list(tree.list_hashes_traverse(size - 1, {0}))
+    list(cache.list_hashes_traverse(size - 1, {0}))
     list_hashes.assert_called_with(
         prefix=None, progress_callback=CallableOrNone
     )
 
 
 def test_list_hashes(dvc):
-    tree = BaseTree(dvc, {})
-    tree.path_info = PathInfo("foo")
+    cache = CloudCache(BaseTree(dvc, {}))
+    cache.tree.path_info = PathInfo("foo")
 
     with mock.patch.object(
-        tree, "list_paths", return_value=["12/3456", "bar"]
+        cache, "_list_paths", return_value=["12/3456", "bar"]
    ):
-        hashes = list(tree.list_hashes())
+        hashes = list(cache.list_hashes())
         assert hashes == ["123456"]
 
 
 def test_list_paths(dvc):
-    tree = BaseTree(dvc, {})
-    tree.path_info = PathInfo("foo")
+    cache = CloudCache(BaseTree(dvc, {}))
+    cache.tree.path_info = PathInfo("foo")
 
-    with mock.patch.object(tree, "walk_files", return_value=[]) as walk_mock:
-        for _ in tree.list_paths():
+    with mock.patch.object(
+        cache.tree, "walk_files", return_value=[]
+    ) as walk_mock:
+        for _ in cache._list_paths():
             pass
-        walk_mock.assert_called_with(tree.path_info, prefix=False)
+        walk_mock.assert_called_with(cache.tree.path_info, prefix=False)
 
-        for _ in tree.list_paths(prefix="000"):
+        for _ in cache._list_paths(prefix="000"):
             pass
-        walk_mock.assert_called_with(tree.path_info / "00" / "0", prefix=True)
+        walk_mock.assert_called_with(
+            cache.tree.path_info / "00" / "0", prefix=True
+        )
 
 
 @pytest.mark.parametrize(
diff --git a/tests/unit/remote/test_local.py b/tests/unit/remote/test_local.py
index f24454b8b6..5846c1ddf4 100644
--- a/tests/unit/remote/test_local.py
+++ b/tests/unit/remote/test_local.py
@@ -29,7 +29,7 @@ def test_status_download_optimization(mocker, dvc):
     other_remote.hashes_exist.return_value = []
     other_remote.index = RemoteIndexNoop()
 
-    cache.status(infos, other_remote, download=True)
+    other_remote.status(cache, infos, download=True)
 
     assert other_remote.hashes_exist.call_count == 0