diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index 5dd2ca1b6a..f576205d8b 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -2,6 +2,7 @@
 from operator import itemgetter
 from multiprocessing import cpu_count
+
 import json
 import logging
 import tempfile
@@ -190,14 +191,14 @@ def _calculate_checksums(self, file_infos):
         return checksums
 
     def _collect_dir(self, path_info):
         file_infos = set()
-        for root, _dirs, files in self.walk(path_info):
-            if DvcIgnore.DVCIGNORE_FILE in files:
-                raise DvcIgnoreInCollectedDirError(root)
 
-            file_infos.update(path_info / root / fname for fname in files)
+        for fname in self.walk_files(path_info):
+            if DvcIgnore.DVCIGNORE_FILE == fname.name:
+                raise DvcIgnoreInCollectedDirError(fname.parent)
+
+            file_infos.add(fname)
 
         checksums = {fi: self.state.get(fi) for fi in file_infos}
         not_in_state = {
@@ -466,7 +467,8 @@ def isdir(self, path_info):
         """
         return False
 
-    def walk(self, path_info):
+    def walk_files(self, path_info):
+        """Return a generator of `PathInfo` objects for all the files."""
         raise NotImplementedError
 
     @staticmethod
@@ -831,11 +833,7 @@ def _checkout_dir(
         self.state.save(path_info, checksum)
 
     def _remove_redundant_files(self, path_info, dir_info, force):
-        existing_files = set(
-            path_info / root / fname
-            for root, _, files in self.walk(path_info)
-            for fname in files
-        )
+        existing_files = set(self.walk_files(path_info))
 
         needed_files = {
             path_info / entry[self.PARAM_RELPATH] for entry in dir_info
diff --git a/dvc/remote/local.py b/dvc/remote/local.py
index e187f33b32..dff4337566 100644
--- a/dvc/remote/local.py
+++ b/dvc/remote/local.py
@@ -25,7 +25,6 @@
     file_md5,
     walk_files,
     relpath,
-    dvc_walk,
     makedirs,
 )
 from dvc.config import Config
@@ -144,8 +143,9 @@ def isdir(path_info):
     def getsize(path_info):
         return os.path.getsize(fspath_py35(path_info))
 
-    def walk(self, path_info):
-        return dvc_walk(path_info, self.repo.dvcignore)
+    def walk_files(self, path_info):
+        for fname in walk_files(path_info, self.repo.dvcignore):
+            yield PathInfo(fname)
 
     def get_file_checksum(self, path_info):
         return file_md5(fspath_py35(path_info))[0]
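Note on the interface change above: `walk()` produced `(root, dirs, files)` tuples that every caller had to reassemble into paths, while `walk_files()` yields ready-made `PathInfo` objects. A minimal sketch of the two calling patterns, assuming `remote` is any `RemoteBASE` subclass and `path_info` a `PathInfo` (names as used throughout this patch):

```python
# Sketch only: old vs. new calling pattern for directory traversal.

def collect_old(remote, path_info):
    # Pre-patch: rebuild each file path from (root, dirs, files) tuples.
    return {
        path_info / root / fname
        for root, _dirs, files in remote.walk(path_info)
        for fname in files
    }

def collect_new(remote, path_info):
    # Post-patch: walk_files() already yields full PathInfo objects.
    return set(remote.walk_files(path_info))
```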
diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py
index 51f07f94df..cef7f447c8 100644
--- a/dvc/remote/s3.py
+++ b/dvc/remote/s3.py
@@ -1,7 +1,10 @@
+# -*- coding: utf-8 -*-
+
 from __future__ import unicode_literals
 
 import os
 import logging
+import posixpath
 
 from funcy import cached_property
 from dvc.progress import Tqdm
@@ -36,7 +39,10 @@ def __init__(self, repo, config):
 
         self.endpoint_url = config.get(Config.SECTION_AWS_ENDPOINT_URL)
 
-        self.list_objects = config.get(Config.SECTION_AWS_LIST_OBJECTS)
+        if config.get(Config.SECTION_AWS_LIST_OBJECTS):
+            self.list_objects_api = "list_objects"
+        else:
+            self.list_objects_api = "list_objects_v2"
 
         self.use_ssl = config.get(Config.SECTION_AWS_USE_SSL, True)
 
@@ -180,27 +186,57 @@ def remove(self, path_info):
         logger.debug("Removing {}".format(path_info))
         self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)
 
-    def _list_paths(self, bucket, prefix):
+    def _list_objects(self, path_info, max_items=None):
         """ Read config for list object api, paginate through list objects."""
-        kwargs = {"Bucket": bucket, "Prefix": prefix}
-        if self.list_objects:
-            list_objects_api = "list_objects"
-        else:
-            list_objects_api = "list_objects_v2"
-        paginator = self.s3.get_paginator(list_objects_api)
+        kwargs = {
+            "Bucket": path_info.bucket,
+            "Prefix": path_info.path,
+            "PaginationConfig": {"MaxItems": max_items},
+        }
+        paginator = self.s3.get_paginator(self.list_objects_api)
         for page in paginator.paginate(**kwargs):
             contents = page.get("Contents", None)
             if not contents:
                 continue
             for item in contents:
-                yield item["Key"]
+                yield item
+
+    def _list_paths(self, path_info, max_items=None):
+        return (
+            item["Key"] for item in self._list_objects(path_info, max_items)
+        )
 
     def list_cache_paths(self):
-        return self._list_paths(self.path_info.bucket, self.path_info.path)
+        return self.walk_files(self.path_info)
 
     def exists(self, path_info):
-        paths = self._list_paths(path_info.bucket, path_info.path)
-        return any(path_info.path == path for path in paths)
+        dir_path = path_info / ""
+        fname = next(self._list_paths(path_info, max_items=1), "")
+        return path_info.path == fname or fname.startswith(dir_path.path)
+
+    def isdir(self, path_info):
+        # S3 doesn't have a concept for directories.
+        #
+        # Using `head_object` with a path pointing to a directory
+        # will throw a 404 error.
+        #
+        # A reliable way to know if a given path is a directory is by
+        # checking if there are more files sharing the same prefix
+        # with a `list_objects` call.
+        #
+        # We need to make sure the path ends with a forward slash, since
+        # otherwise we could get false positives like the following example:
+        #
+        #       bucket
+        #       └── data
+        #          ├── alice
+        #          └── alpha
+        #
+        # Using `data/al` as a prefix will return `[data/alice, data/alpha]`,
+        # while `data/al/` will return nothing.
+        #
+        dir_path = path_info / ""
+        return bool(list(self._list_paths(dir_path, max_items=1)))
 
     def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
         total = os.path.getsize(from_file)
@@ -234,3 +270,7 @@ def _generate_download_url(self, path_info, expires=3600):
         return self.s3.generate_presigned_url(
             ClientMethod="get_object", Params=params, ExpiresIn=int(expires)
         )
+
+    def walk_files(self, path_info, max_items=None):
+        for fname in self._list_paths(path_info, max_items):
+            yield path_info / posixpath.relpath(fname, path_info.path)
diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py
index 75a2a6f980..41e35135fb 100644
--- a/dvc/remote/ssh/__init__.py
+++ b/dvc/remote/ssh/__init__.py
@@ -4,6 +4,7 @@
 import itertools
 import io
 import os
+import posixpath
 import getpass
 import logging
 import threading
@@ -260,10 +261,10 @@ def list_cache_paths(self):
         with self.ssh(self.path_info) as ssh:
             return list(ssh.walk_files(self.path_info.path))
 
-    def walk(self, path_info):
+    def walk_files(self, path_info):
         with self.ssh(path_info) as ssh:
-            for entry in ssh.walk(path_info.path):
-                yield entry
+            for fname in ssh.walk_files(path_info.path):
+                yield path_info / posixpath.relpath(fname, path_info.path)
 
     def makedirs(self, path_info):
         with self.ssh(path_info) as ssh:
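The trailing-slash trick that `isdir` relies on is independent of DVC and can be reproduced with plain boto3. A standalone sketch (the client setup, bucket, and key names are illustrative, not part of the patch):

```python
# Sketch of the prefix check behind RemoteS3.isdir(), using boto3 directly.
import boto3

s3 = boto3.client("s3")

def isdir(bucket, key):
    # Force a trailing slash so a prefix such as "data/al" cannot
    # match sibling keys like "data/alice" or "data/alpha".
    prefix = key.rstrip("/") + "/"
    resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
    return resp.get("KeyCount", 0) > 0
```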
diff --git a/tests/unit/remote/test_s3.py b/tests/unit/remote/test_s3.py
new file mode 100644
index 0000000000..7861fb5a9a
--- /dev/null
+++ b/tests/unit/remote/test_s3.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+
+import pytest
+from moto import mock_s3
+
+from dvc.remote.s3 import RemoteS3
+
+
+@pytest.fixture
+def remote():
+    """Returns a RemoteS3 connected to a bucket with the following structure:
+
+        bucket
+        ├── data
+        │   ├── alice
+        │   ├── alpha
+        │   └── subdir
+        │       ├── 1
+        │       ├── 2
+        │       └── 3
+        ├── empty_dir
+        ├── empty_file
+        └── foo
+    """
+    with mock_s3():
+        remote = RemoteS3(None, {"url": "s3://bucket", "region": "us-east-1"})
+        s3 = remote.s3
+
+        s3.create_bucket(Bucket="bucket")
+        s3.put_object(Bucket="bucket", Key="empty_dir/")
+        s3.put_object(Bucket="bucket", Key="empty_file", Body=b"")
+        s3.put_object(Bucket="bucket", Key="foo", Body=b"foo")
+        s3.put_object(Bucket="bucket", Key="data/alice", Body=b"alice")
+        s3.put_object(Bucket="bucket", Key="data/alpha", Body=b"alpha")
+        s3.put_object(Bucket="bucket", Key="data/subdir/1", Body=b"1")
+        s3.put_object(Bucket="bucket", Key="data/subdir/2", Body=b"2")
+        s3.put_object(Bucket="bucket", Key="data/subdir/3", Body=b"3")
+
+        yield remote
+
+
+def test_isdir(remote):
+    test_cases = [
+        (True, "data"),
+        (True, "data/"),
+        (True, "data/subdir"),
+        (True, "empty_dir"),
+        (False, "foo"),
+        (False, "data/alice"),
+        (False, "data/al"),
+        (False, "data/subdir/1"),
+    ]
+
+    for expected, path in test_cases:
+        assert remote.isdir(remote.path_info / path) == expected
+
+
+def test_exists(remote):
+    test_cases = [
+        (True, "data"),
+        (True, "data/"),
+        (True, "data/subdir"),
+        (True, "empty_dir"),
+        (True, "empty_file"),
+        (True, "foo"),
+        (True, "data/alice"),
+        (True, "data/subdir/1"),
+        (False, "data/al"),
+        (False, "foo/"),
+    ]
+
+    for expected, path in test_cases:
+        assert remote.exists(remote.path_info / path) == expected
+
+
+def test_walk_files(remote):
+    files = [
+        remote.path_info / "data/alice",
+        remote.path_info / "data/alpha",
+        remote.path_info / "data/subdir/1",
+        remote.path_info / "data/subdir/2",
+        remote.path_info / "data/subdir/3",
+    ]
+
+    assert list(remote.walk_files(remote.path_info / "data")) == files
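Tracing the new `exists` against the fixture tree above shows why it checks both an exact key match and a directory-prefix match. A sketch reusing the moto-backed `remote` fixture (the test name is hypothetical, not part of the patch):

```python
def test_exists_semantics(remote):
    # First key listed under prefix "foo" is "foo" itself: exact match.
    assert remote.exists(remote.path_info / "foo")
    # First key under prefix "data" is "data/alice", which starts
    # with "data/": directory-prefix match.
    assert remote.exists(remote.path_info / "data")
    # First key under prefix "data/al" is also "data/alice", but it is
    # neither equal to "data/al" nor prefixed by "data/al/": no match.
    assert not remote.exists(remote.path_info / "data/al")
```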