Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class RelPath(str):
}
LOCAL_COMMON = {
"type": supported_cache_type,
Optional("protected", default=False): Bool,
Optional("protected", default=False): Bool, # obsoleted
"shared": All(Lower, Choices("group")),
Optional("slow_link_warning", default=True): Bool,
}
Expand Down
36 changes: 24 additions & 12 deletions dvc/remote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,24 @@ class RemoteBASE(object):
DEFAULT_NO_TRAVERSE = True
DEFAULT_VERIFY = False

CACHE_MODE = None
SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}

state = StateNoop()

def __init__(self, repo, config):
self.repo = repo

self._check_requires(config)

shared = config.get("shared")
self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared]

self.checksum_jobs = (
config.get("checksum_jobs")
or (self.repo and self.repo.config["core"].get("checksum_jobs"))
or self.CHECKSUM_JOBS
)
self.protected = False
self.no_traverse = config.get("no_traverse", self.DEFAULT_NO_TRAVERSE)
self.verify = config.get("verify", self.DEFAULT_VERIFY)
self._dir_info = {}
Expand Down Expand Up @@ -221,7 +226,7 @@ def get_dir_checksum(self, path_info):
new_info = self.cache.checksum_to_path_info(checksum)
if self.cache.changed_cache_file(checksum):
self.cache.makedirs(new_info.parent)
self.cache.move(tmp_info, new_info)
self.cache.move(tmp_info, new_info, mode=self.CACHE_MODE)

self.state.save(path_info, checksum)
self.state.save(new_info, checksum)
Expand Down Expand Up @@ -409,12 +414,8 @@ def _do_link(self, from_info, to_info, link_method):

link_method(from_info, to_info)

if self.protected:
self.protect(to_info)

logger.debug(
"Created %s'%s': %s -> %s",
"protected " if self.protected else "",
self.cache_types[0],
from_info,
to_info,
Expand All @@ -425,14 +426,11 @@ def _save_file(self, path_info, checksum, save_link=True):

cache_info = self.checksum_to_path_info(checksum)
if self.changed_cache(checksum):
self.move(path_info, cache_info)
self.move(path_info, cache_info, mode=self.CACHE_MODE)
self.link(cache_info, path_info)
elif self.iscopy(path_info) and self._cache_is_copy(path_info):
# Default relink procedure involves unneeded copy
if self.protected:
self.protect(path_info)
else:
self.unprotect(path_info)
self.unprotect(path_info)
else:
self.remove(path_info)
self.link(cache_info, path_info)
Expand Down Expand Up @@ -656,7 +654,8 @@ def open(self, path_info, mode="r", encoding=None):
def remove(self, path_info):
raise RemoteActionNotImplemented("remove", self.scheme)

def move(self, from_info, to_info):
def move(self, from_info, to_info, mode=None):
    """Move ``from_info`` to ``to_info`` as a copy followed by a remove.

    ``mode`` must be left as None here: this implementation asserts so
    and never applies any permissions to the destination.
    """
    assert mode is None
    self.copy(from_info, to_info)
    # The source is only removed after the copy call has returned.
    self.remove(from_info)

Expand Down Expand Up @@ -718,6 +717,9 @@ def gc(self, named_cache):
removed = True
return removed

def is_protected(self, path_info):
    """Report whether ``path_info`` is protected (read-only).

    This implementation always answers False, so ``changed_cache_file``
    never takes its "assume unchanged" shortcut unless this method is
    overridden.
    """
    return False

def changed_cache_file(self, checksum):
"""Compare the given checksum with the (corresponding) actual one.

Expand All @@ -730,7 +732,14 @@ def changed_cache_file(self, checksum):

- Remove the file from cache if it doesn't match the actual checksum
"""

cache_info = self.checksum_to_path_info(checksum)
if self.is_protected(cache_info):
logger.debug(
"Assuming '{}' is unchanged since it is read-only", cache_info
)
return False

actual = self.get_checksum(cache_info)

logger.debug(
Expand All @@ -744,6 +753,9 @@ def changed_cache_file(self, checksum):
return True

if actual.split(".")[0] == checksum.split(".")[0]:
# making cache file read-only so we don't need to check it
# next time
self.protect(cache_info)
return False

if self.exists(cache_info):
Expand Down
54 changes: 29 additions & 25 deletions dvc/remote/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,11 @@ class RemoteLOCAL(RemoteBASE):

DEFAULT_CACHE_TYPES = ["reflink", "copy"]

CACHE_MODE = 0o444
SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)}

def __init__(self, repo, config):
super().__init__(repo, config)
self.protected = config.get("protected", False)

shared = config.get("shared")
self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared]

if self.protected:
# cache files are set to be read-only for everyone
self._file_mode = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH

self.cache_dir = config.get("url")
self._dir_info = {}

Expand Down Expand Up @@ -142,23 +134,25 @@ def remove(self, path_info):
if self.exists(path_info):
remove(path_info.fspath)

def move(self, from_info, to_info):
def move(self, from_info, to_info, mode=None):
if from_info.scheme != "local" or to_info.scheme != "local":
raise NotImplementedError

self.makedirs(to_info.parent)

if self.isfile(from_info):
mode = self._file_mode
else:
mode = self._dir_mode
if mode is None:
if self.isfile(from_info):
mode = self._file_mode
else:
mode = self._dir_mode

move(from_info, to_info, mode=mode)

def copy(self, from_info, to_info):
tmp_info = to_info.parent / tmp_fname(to_info.name)
try:
System.copy(from_info, tmp_info)
os.chmod(fspath_py35(tmp_info), self._file_mode)
os.rename(fspath_py35(tmp_info), fspath_py35(to_info))
except Exception:
self.remove(tmp_info)
Expand Down Expand Up @@ -202,9 +196,13 @@ def hardlink(self, from_info, to_info):
def is_hardlink(path_info):
return System.is_hardlink(path_info)

@staticmethod
def reflink(from_info, to_info):
System.reflink(from_info, to_info)
def reflink(self, from_info, to_info):
    """Reflink ``from_info`` to ``to_info`` and give it ``self._file_mode``.

    The link is created under a temporary name next to the destination,
    chmod-ed, and only then renamed into place, so the destination path
    appears with its permissions already set.

    Raises:
        Exception: whatever the reflink, chmod or rename step raised; the
            temporary file is removed before the error is propagated.
    """
    tmp_info = to_info.parent / tmp_fname(to_info.name)
    try:
        System.reflink(from_info, tmp_info)
        # NOTE: reflink has its own separate inode, so you can set
        # permissions that are different from the source.
        os.chmod(fspath_py35(tmp_info), self._file_mode)
        os.rename(fspath_py35(tmp_info), fspath_py35(to_info))
    except Exception:
        # Do not leave a stray temporary file next to the destination
        # when any step fails (`copy` cleans up its temp file likewise).
        self.remove(tmp_info)
        raise

def cache_exists(self, checksums, jobs=None, name=None):
return [
Expand Down Expand Up @@ -402,8 +400,7 @@ def _log_missing_caches(checksum_info_dict):
)
logger.warning(msg)

@staticmethod
def _unprotect_file(path):
def _unprotect_file(self, path):
if System.is_symlink(path) or System.is_hardlink(path):
logger.debug("Unprotecting '{}'".format(path))
tmp = os.path.join(os.path.dirname(path), "." + uuid())
Expand All @@ -423,13 +420,13 @@ def _unprotect_file(path):
"a symlink or a hardlink.".format(path)
)

os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE)
os.chmod(path, self._file_mode)

def _unprotect_dir(self, path):
assert is_working_tree(self.repo.tree)

for fname in self.repo.tree.walk_files(path):
RemoteLOCAL._unprotect_file(fname)
self._unprotect_file(fname)

def unprotect(self, path_info):
path = path_info.fspath
Expand All @@ -441,12 +438,11 @@ def unprotect(self, path_info):
if os.path.isdir(path):
self._unprotect_dir(path)
else:
RemoteLOCAL._unprotect_file(path)
self._unprotect_file(path)

@staticmethod
def protect(path_info):
def protect(self, path_info):
path = fspath_py35(path_info)
mode = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH
mode = self.CACHE_MODE

try:
os.chmod(path, mode)
Expand Down Expand Up @@ -519,3 +515,11 @@ def _get_unpacked_dir_names(self, checksums):
if self.is_dir_checksum(c):
unpacked.add(c + self.UNPACKED_DIR_SUFFIX)
return unpacked

def is_protected(self, path_info):
    """Tell whether ``path_info`` carries exactly the cache file mode.

    Returns:
        bool: True when the permission bits of the file equal
        ``self.CACHE_MODE``; False when they differ or when the file
        cannot be stat-ed (for instance because it does not exist).
    """
    try:
        # Stat directly instead of checking for existence first: a
        # separate check leaves a window in which the file can vanish
        # before the stat call and turn a "not protected" answer into a
        # crash.
        mode = os.stat(fspath_py35(path_info)).st_mode
    except OSError:
        return False

    return stat.S_IMODE(mode) == self.CACHE_MODE
3 changes: 2 additions & 1 deletion dvc/remote/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ def remove(self, path_info):
with self.ssh(path_info) as ssh:
ssh.remove(path_info.path)

def move(self, from_info, to_info):
def move(self, from_info, to_info, mode=None):
assert mode is None
if from_info.scheme != self.scheme or to_info.scheme != self.scheme:
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions tests/func/remote/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_dir_cache_changed_on_single_cache_file_modification(tmp_dir, dvc):
assert os.path.exists(unpacked_dir)

cache_file_path = dvc.cache.local.get(file_md5)
os.chmod(cache_file_path, 0o644)
with open(cache_file_path, "a") as fobj:
fobj.write("modification")

Expand Down
19 changes: 13 additions & 6 deletions tests/func/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,12 @@ def test(self):
ret = main(["config", "cache.type", "hardlink"])
self.assertEqual(ret, 0)

ret = main(["config", "cache.protected", "true"])
self.assertEqual(ret, 0)

ret = main(["add", self.FOO])
self.assertEqual(ret, 0)

self.assertFalse(os.access(self.FOO, os.W_OK))
self.assertTrue(System.is_hardlink(self.FOO))

ret = main(["unprotect", self.FOO])
self.assertEqual(ret, 0)

Expand Down Expand Up @@ -561,7 +561,6 @@ def test_readding_dir_should_not_unprotect_all(tmp_dir, dvc, mocker):
tmp_dir.gen("dir/data", "data")

dvc.cache.local.cache_types = ["symlink"]
dvc.cache.local.protected = True

dvc.add("dir")
tmp_dir.gen("dir/new_file", "new_file_content")
Expand Down Expand Up @@ -618,15 +617,23 @@ def test_should_relink_on_repeated_add(
@pytest.mark.parametrize("link", ["hardlink", "symlink", "copy"])
def test_should_protect_on_repeated_add(link, tmp_dir, dvc):
dvc.cache.local.cache_types = [link]
dvc.cache.local.protected = True

tmp_dir.dvc_gen({"foo": "foo"})

dvc.unprotect("foo")

dvc.add("foo")

assert not os.access("foo", os.W_OK)
assert not os.access(
os.path.join(".dvc", "cache", "ac", "bd18db4cc2f85cedef654fccc4a4d8"),
os.W_OK,
)

# NOTE: Windows symlink perms don't propagate to the target
if link == "copy" or (link == "symlink" and os.name == "nt"):
assert os.access("foo", os.W_OK)
else:
assert not os.access("foo", os.W_OK)


def test_escape_gitignore_entries(tmp_dir, scm, dvc):
Expand Down
9 changes: 5 additions & 4 deletions tests/func/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import shutil

import pytest

Expand All @@ -9,6 +8,8 @@
from dvc.exceptions import FileMissingError
from dvc.main import main
from dvc.path_info import URLInfo
from dvc.utils.fs import remove

from tests.remotes import Azure, GCP, HDFS, Local, OSS, S3, SSH


Expand Down Expand Up @@ -65,7 +66,7 @@ def test_open(remote_url, tmp_dir, dvc):
run_dvc("push")

# Remove cache to force download
shutil.rmtree(dvc.cache.local.cache_dir)
remove(dvc.cache.local.cache_dir)

with api.open("foo") as fd:
assert fd.read() == "foo-text"
Expand All @@ -85,7 +86,7 @@ def test_open_external(remote_url, erepo_dir):
erepo_dir.dvc.push(all_branches=True)

# Remove cache to force download
shutil.rmtree(erepo_dir.dvc.cache.local.cache_dir)
remove(erepo_dir.dvc.cache.local.cache_dir)

# Using file url to force clone to tmp repo
repo_url = "file://{}".format(erepo_dir)
Expand All @@ -101,7 +102,7 @@ def test_missing(remote_url, tmp_dir, dvc):
run_dvc("remote", "add", "-d", "upstream", remote_url)

# Remove cache to make foo missing
shutil.rmtree(dvc.cache.local.cache_dir)
remove(dvc.cache.local.cache_dir)

with pytest.raises(FileMissingError):
api.read("foo")
Expand Down
12 changes: 6 additions & 6 deletions tests/func/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ def test_default_cache_type(dvc):

@pytest.mark.skipif(os.name == "nt", reason="Not supported for Windows.")
@pytest.mark.parametrize(
"protected,dir_mode,file_mode",
[(False, 0o775, 0o664), (True, 0o775, 0o444)],
"group, dir_mode", [(False, 0o755), (True, 0o775)],
)
def test_shared_cache(tmp_dir, dvc, protected, dir_mode, file_mode):
with dvc.config.edit() as conf:
conf["cache"].update({"shared": "group", "protected": str(protected)})
def test_shared_cache(tmp_dir, dvc, group, dir_mode):
if group:
with dvc.config.edit() as conf:
conf["cache"].update({"shared": "group"})
dvc.cache = Cache(dvc)

tmp_dir.dvc_gen(
Expand All @@ -203,4 +203,4 @@ def test_shared_cache(tmp_dir, dvc, protected, dir_mode, file_mode):

for fname in fnames:
path = os.path.join(root, fname)
assert stat.S_IMODE(os.stat(path).st_mode) == file_mode
assert stat.S_IMODE(os.stat(path).st_mode) == 0o444
Loading