Merged
19 changes: 8 additions & 11 deletions dvc/remote/base.py
@@ -2,6 +2,7 @@

from operator import itemgetter
from multiprocessing import cpu_count

import json
import logging
import tempfile
@@ -190,14 +191,13 @@ def _calculate_checksums(self, file_infos):
         return checksums

     def _collect_dir(self, path_info):

         file_infos = set()
-        for root, _dirs, files in self.walk(path_info):
-
-            if DvcIgnore.DVCIGNORE_FILE in files:
-                raise DvcIgnoreInCollectedDirError(root)
+        for fname in self.walk_files(path_info):
+            if DvcIgnore.DVCIGNORE_FILE == fname.name:
+                raise DvcIgnoreInCollectedDirError(fname.parent)

-            file_infos.update(path_info / root / fname for fname in files)
+            file_infos.add(fname)

         checksums = {fi: self.state.get(fi) for fi in file_infos}
         not_in_state = {
@@ -466,7 +466,8 @@ def isdir(self, path_info):
         """
         return False

-    def walk(self, path_info):
+    def walk_files(self, path_info):
+        """Return a generator with `PathInfo`s to all the files"""
         raise NotImplementedError

     @staticmethod
@@ -831,11 +832,7 @@ def _checkout_dir(
         self.state.save(path_info, checksum)

     def _remove_redundant_files(self, path_info, dir_info, force):
-        existing_files = set(
-            path_info / root / fname
-            for root, _, files in self.walk(path_info)
-            for fname in files
-        )
+        existing_files = set(self.walk_files(path_info))

         needed_files = {
             path_info / entry[self.PARAM_RELPATH] for entry in dir_info
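The new `walk_files` contract is simply "yield one path object per file under the given prefix", which is what lets `_collect_dir` and `_remove_redundant_files` drop the root/dirs/files bookkeeping. A minimal, self-contained sketch of the same idea (not part of this diff; `pathlib.Path` stands in for DVC's `PathInfo` purely for illustration):

import os
from pathlib import Path


def walk_files(root):
    """Yield a Path for every file below ``root`` (stand-in for the new API)."""
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            yield Path(dirpath) / fname


def collect_dir(root):
    """Rough analogue of the simplified _collect_dir: a flat set of file paths."""
    return set(walk_files(root))

Because consumers only ever see flat file paths, the .dvcignore check can be done per file (via `fname.name`) instead of per directory listing.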
6 changes: 3 additions & 3 deletions dvc/remote/local.py
@@ -25,7 +25,6 @@
     file_md5,
     walk_files,
     relpath,
-    dvc_walk,
     makedirs,
 )
 from dvc.config import Config
@@ -144,8 +143,9 @@ def isdir(path_info):
     def getsize(path_info):
         return os.path.getsize(fspath_py35(path_info))

-    def walk(self, path_info):
-        return dvc_walk(path_info, self.repo.dvcignore)
+    def walk_files(self, path_info):
+        for fname in walk_files(path_info, self.repo.dvcignore):
+            yield PathInfo(fname)

     def get_file_checksum(self, path_info):
         return file_md5(fspath_py35(path_info))[0]
64 changes: 52 additions & 12 deletions dvc/remote/s3.py
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import os
import logging
import posixpath
from funcy import cached_property

from dvc.progress import Tqdm
@@ -36,7 +39,10 @@ def __init__(self, repo, config):

         self.endpoint_url = config.get(Config.SECTION_AWS_ENDPOINT_URL)

-        self.list_objects = config.get(Config.SECTION_AWS_LIST_OBJECTS)
+        if config.get(Config.SECTION_AWS_LIST_OBJECTS):
+            self.list_objects_api = "list_objects"
+        else:
+            self.list_objects_api = "list_objects_v2"

         self.use_ssl = config.get(Config.SECTION_AWS_USE_SSL, True)

@@ -180,27 +186,57 @@ def remove(self, path_info):
         logger.debug("Removing {}".format(path_info))
         self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path)

-    def _list_paths(self, bucket, prefix):
+    def _list_objects(self, path_info, max_items=None):
         """ Read config for list object api, paginate through list objects."""
-        kwargs = {"Bucket": bucket, "Prefix": prefix}
-        if self.list_objects:
-            list_objects_api = "list_objects"
-        else:
-            list_objects_api = "list_objects_v2"
-        paginator = self.s3.get_paginator(list_objects_api)
+        kwargs = {
+            "Bucket": path_info.bucket,
+            "Prefix": path_info.path,
+            "PaginationConfig": {"MaxItems": max_items},
+        }
+        paginator = self.s3.get_paginator(self.list_objects_api)
         for page in paginator.paginate(**kwargs):
             contents = page.get("Contents", None)
             if not contents:
                 continue
             for item in contents:
-                yield item["Key"]
+                yield item

+    def _list_paths(self, path_info, max_items=None):
+        return (
+            item["Key"] for item in self._list_objects(path_info, max_items)
+        )
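The pagination that the new `_list_objects` relies on can be reproduced with plain boto3; the bucket and prefix below are hypothetical, and `MaxItems` caps the total number of keys yielded, which is how `exists()` and `isdir()` stop after the first match:

import boto3

s3 = boto3.client("s3")

# Page through keys under a prefix; PaginationConfig={"MaxItems": 1} makes the
# paginator stop after yielding a single key.
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(
    Bucket="my-bucket",  # hypothetical bucket
    Prefix="data/",      # hypothetical prefix
    PaginationConfig={"MaxItems": 1},
)
for page in pages:
    for item in page.get("Contents", []):
        print(item["Key"])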

     def list_cache_paths(self):
-        return self._list_paths(self.path_info.bucket, self.path_info.path)
+        return self.walk_files(self.path_info)

     def exists(self, path_info):
-        paths = self._list_paths(path_info.bucket, path_info.path)
-        return any(path_info.path == path for path in paths)
+        dir_path = path_info / ""
+        fname = next(self._list_paths(path_info, max_items=1), "")
+        return path_info.path == fname or fname.startswith(dir_path.path)

+    def isdir(self, path_info):
+        # S3 doesn't have a concept for directories.
+        #
+        # Using `head_object` with a path pointing to a directory
+        # will throw a 404 error.
+        #
+        # A reliable way to know if a given path is a directory is by
+        # checking if there are more files sharing the same prefix
+        # with a `list_objects` call.
+        #
+        # We need to make sure that the path ends with a forward slash,
+        # since we can end with false-positives like the following example:
+        #
+        #       bucket
+        #       └── data
+        #          ├── alice
+        #          └── alpha
+        #
+        # Using `data/al` as prefix will return `[data/alice, data/alpha]`,
+        # While `data/al/` will return nothing.
+        #
+        dir_path = path_info / ""
+        return bool(list(self._list_paths(dir_path, max_items=1)))
+
     def _upload(self, from_file, to_info, name=None, no_progress_bar=False):
         total = os.path.getsize(from_file)
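Put against the moto-backed fixture defined in the new test file below (keys `data/alice`, `data/alpha`, `data/subdir/...`), the trailing-slash rules in `exists` and `isdir` work out as follows (illustrative snippet only, assuming that fixture as `remote`):

# Keys "data/alice" and "data/alpha" exist in the bucket.
remote.isdir(remote.path_info / "data")         # True:  prefix "data/" matches "data/alice"
remote.isdir(remote.path_info / "data/al")      # False: prefix "data/al/" matches nothing
remote.exists(remote.path_info / "data/al")     # False: first key "data/alice" is neither an
                                                #        exact match nor under "data/al/"
remote.exists(remote.path_info / "data/alice")  # True:  exact key match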
@@ -234,3 +270,7 @@ def _generate_download_url(self, path_info, expires=3600):
         return self.s3.generate_presigned_url(
             ClientMethod="get_object", Params=params, ExpiresIn=int(expires)
         )
+
+    def walk_files(self, path_info, max_items=None):
+        for fname in self._list_paths(path_info, max_items):
+            yield path_info / posixpath.relpath(fname, path_info.path)
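`walk_files` re-anchors each key returned by S3 onto the queried `path_info`; the `posixpath.relpath` call strips the shared, bucket-relative prefix before re-joining. A tiny standalone illustration (values are hypothetical):

import posixpath

prefix = "data"        # path_info.path of the queried directory
key = "data/subdir/1"  # bucket-relative key as returned by S3

# relpath removes the shared prefix, leaving the part that can be joined back
# onto the PathInfo of the queried directory.
print(posixpath.relpath(key, prefix))  # -> "subdir/1"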
7 changes: 4 additions & 3 deletions dvc/remote/ssh/__init__.py
@@ -4,6 +4,7 @@
 import itertools
 import io
 import os
+import posixpath
 import getpass
 import logging
 import threading
@@ -260,10 +261,10 @@ def list_cache_paths(self):
         with self.ssh(self.path_info) as ssh:
             return list(ssh.walk_files(self.path_info.path))

-    def walk(self, path_info):
+    def walk_files(self, path_info):
         with self.ssh(path_info) as ssh:
-            for entry in ssh.walk(path_info.path):
-                yield entry
+            for fname in ssh.walk_files(path_info.path):
+                yield path_info / posixpath.relpath(fname, path_info.path)

     def makedirs(self, path_info):
         with self.ssh(path_info) as ssh:
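`SSHConnection.walk_files` is DVC's own helper and is used here unchanged; as a rough mental model only (an assumption, not DVC's actual implementation), recursively listing files over SFTP with paramiko looks roughly like this:

import stat
import posixpath


def sftp_walk_files(sftp, path):
    """Yield every file path under ``path``.

    ``sftp`` is assumed to be a connected ``paramiko.SFTPClient``.
    """
    for entry in sftp.listdir_attr(path):
        child = posixpath.join(path, entry.filename)
        if stat.S_ISDIR(entry.st_mode):
            # Descend into sub-directories and keep yielding files.
            for fname in sftp_walk_files(sftp, child):
                yield fname
        else:
            yield child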
85 changes: 85 additions & 0 deletions tests/unit/remote/test_s3.py
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-

import pytest
from moto import mock_s3

from dvc.remote.s3 import RemoteS3


@pytest.fixture
def remote():
    """Returns a RemoteS3 connected to a bucket with the following structure:

        bucket
        ├── data
        │  ├── alice
        │  ├── alpha
        │  └── subdir
        │    ├── 1
        │    ├── 2
        │    └── 3
        ├── empty_dir
        ├── empty_file
        └── foo
    """
    with mock_s3():
        remote = RemoteS3(None, {"url": "s3://bucket", "region": "us-east-1"})
        s3 = remote.s3

        s3.create_bucket(Bucket="bucket")
        s3.put_object(Bucket="bucket", Key="empty_dir/")
        s3.put_object(Bucket="bucket", Key="empty_file", Body=b"")
        s3.put_object(Bucket="bucket", Key="foo", Body=b"foo")
        s3.put_object(Bucket="bucket", Key="data/alice", Body=b"alice")
        s3.put_object(Bucket="bucket", Key="data/alpha", Body=b"alpha")
        s3.put_object(Bucket="bucket", Key="data/subdir/1", Body=b"1")
        s3.put_object(Bucket="bucket", Key="data/subdir/2", Body=b"2")
        s3.put_object(Bucket="bucket", Key="data/subdir/3", Body=b"3")

        yield remote


def test_isdir(remote):
    test_cases = [
        (True, "data"),
        (True, "data/"),
        (True, "data/subdir"),
        (True, "empty_dir"),
        (False, "foo"),
        (False, "data/alice"),
        (False, "data/al"),
        (False, "data/subdir/1"),
    ]

    for expected, path in test_cases:
        assert remote.isdir(remote.path_info / path) == expected


def test_exists(remote):
    test_cases = [
        (True, "data"),
        (True, "data/"),
        (True, "data/subdir"),
        (True, "empty_dir"),
        (True, "empty_file"),
        (True, "foo"),
        (True, "data/alice"),
        (True, "data/subdir/1"),
        (False, "data/al"),
        (False, "foo/"),
    ]

    for expected, path in test_cases:
        assert remote.exists(remote.path_info / path) == expected


def test_walk_files(remote):
    files = [
        remote.path_info / "data/alice",
        remote.path_info / "data/alpha",
        remote.path_info / "data/subdir/1",
        remote.path_info / "data/subdir/2",
        remote.path_info / "data/subdir/3",
    ]

    assert list(remote.walk_files(remote.path_info / "data")) == files
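The expected list in `test_walk_files` is written in ascending order because S3 list operations (and moto's emulation of them) return keys sorted in UTF-8 binary order, so the generator's output can be compared against a pre-sorted list directly. For example:

keys = ["data/subdir/1", "data/alice", "data/alpha"]
# S3 would list these keys in ascending UTF-8 binary order:
assert sorted(keys) == ["data/alice", "data/alpha", "data/subdir/1"]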