diff --git a/pyproject.toml b/pyproject.toml index 89f8846..f142984 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.8" dynamic = ["version"] dependencies = [ "funcy>=1.14; python_version < '3.12'", - "fsspec>=2022.10.0", + "fsspec>=2024.2.0", ] [project.urls] @@ -104,7 +104,7 @@ module = [ ] [tool.codespell] -ignore-words-list = " " +ignore-words-list = "cachable," [tool.ruff] ignore = [ diff --git a/src/dvc_objects/db.py b/src/dvc_objects/db.py index b67f77f..bb730f2 100644 --- a/src/dvc_objects/db.py +++ b/src/dvc_objects/db.py @@ -16,13 +16,15 @@ cast, ) +from fsspec.callbacks import DEFAULT_CALLBACK + from .errors import ObjectDBPermissionError -from .fs.callbacks import DEFAULT_CALLBACK from .obj import Object if TYPE_CHECKING: + from fsspec import Callback + from .fs.base import AnyFSPath, FileSystem - from .fs.callbacks import Callback logger = logging.getLogger(__name__) diff --git a/src/dvc_objects/executors.py b/src/dvc_objects/executors.py index 28c707c..68e7212 100644 --- a/src/dvc_objects/executors.py +++ b/src/dvc_objects/executors.py @@ -18,7 +18,7 @@ TypeVar, ) -from .fs.callbacks import Callback +from fsspec import Callback _T = TypeVar("_T") diff --git a/src/dvc_objects/fs/base.py b/src/dvc_objects/fs/base.py index 1eb0aec..e0a741c 100644 --- a/src/dvc_objects/fs/base.py +++ b/src/dvc_objects/fs/base.py @@ -27,15 +27,12 @@ import fsspec from fsspec.asyn import get_loop +from fsspec.callbacks import DEFAULT_CALLBACK from dvc_objects.compat import cached_property from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import ( - DEFAULT_CALLBACK, - wrap_and_branch_callback, - wrap_file, -) +from .callbacks import wrap_file from .errors import RemoteMissingDepsError if TYPE_CHECKING: @@ -706,9 +703,14 @@ def put( callback.set_size(len(from_infos)) executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) + + def put_file(from_path, to_path): + with 
callback.branched(from_path, to_path) as child: + return self.put_file(from_path, to_path, callback=child) + with executor: - put_file = wrap_and_branch_callback(callback, self.put_file) - list(executor.imap_unordered(put_file, from_infos, to_infos)) + it = executor.imap_unordered(put_file, from_infos, to_infos) + list(callback.wrap(it)) def get( self, @@ -724,9 +726,8 @@ def get( def get_file(rpath, lpath, **kwargs): localfs.makedirs(localfs.parent(lpath), exist_ok=True) - self.fs.get_file(rpath, lpath, **kwargs) - - get_file = wrap_and_branch_callback(callback, get_file) + with callback.branched(rpath, lpath) as child: + self.fs.get_file(rpath, lpath, callback=child, **kwargs) if isinstance(from_info, list) and isinstance(to_info, list): from_infos: List[AnyFSPath] = from_info @@ -737,7 +738,9 @@ def get_file(rpath, lpath, **kwargs): if not self.isdir(from_info): callback.set_size(1) - return get_file(from_info, to_info) + get_file(from_info, to_info) + callback.relative_update() + return from_infos = list(self.find(from_info)) if not from_infos: @@ -760,7 +763,8 @@ def get_file(rpath, lpath, **kwargs): callback.set_size(len(from_infos)) executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(get_file, from_infos, to_infos)) + it = executor.imap_unordered(get_file, from_infos, to_infos) + list(callback.wrap(it)) def ukey(self, path: AnyFSPath) -> str: return self.fs.ukey(path) diff --git a/src/dvc_objects/fs/callbacks.py b/src/dvc_objects/fs/callbacks.py index 41b3fd3..467e26a 100644 --- a/src/dvc_objects/fs/callbacks.py +++ b/src/dvc_objects/fs/callbacks.py @@ -1,19 +1,20 @@ -import asyncio from functools import wraps -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast +from typing import TYPE_CHECKING, BinaryIO, Optional, Type, cast import fsspec +from fsspec.callbacks import DEFAULT_CALLBACK, Callback, NoOpCallback if TYPE_CHECKING: from typing import Union - from 
dvc_objects._tqdm import Tqdm + from tqdm import tqdm -F = TypeVar("F", bound=Callable) + +__all__ = ["Callback", "NoOpCallback", "TqdmCallback", "DEFAULT_CALLBACK"] class CallbackStream: - def __init__(self, stream, callback: fsspec.Callback): + def __init__(self, stream, callback: Callback): self.stream = stream @wraps(stream.read) @@ -28,151 +29,29 @@ def __getattr__(self, attr): return getattr(self.stream, attr) -class ScopedCallback(fsspec.Callback): - def __enter__(self): - return self - - def __exit__(self, *exc_args): - self.close() - - def close(self): - """Handle here on exit.""" - - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional["Callback"] = None, - ) -> "Callback": - child = kwargs["callback"] = child or DEFAULT_CALLBACK - return child - - -class Callback(ScopedCallback): - def absolute_update(self, value: int) -> None: - value = value if value is not None else self.value - return super().absolute_update(value) - - @classmethod - def as_callback( - cls, maybe_callback: Optional[fsspec.Callback] = None - ) -> "Callback": - if maybe_callback is None: - return DEFAULT_CALLBACK - if isinstance(maybe_callback, Callback): - return maybe_callback - return FsspecCallbackWrapper(maybe_callback) - - -class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback): - pass - - -class TqdmCallback(Callback): +class TqdmCallback(fsspec.callbacks.TqdmCallback): def __init__( self, size: Optional[int] = None, value: int = 0, - progress_bar: Optional["Tqdm"] = None, + progress_bar: Optional["tqdm"] = None, + tqdm_cls: Optional[Type["tqdm"]] = None, **tqdm_kwargs, ): from dvc_objects._tqdm import Tqdm tqdm_kwargs.pop("total", None) - self._tqdm_kwargs = tqdm_kwargs - self._tqdm_cls = Tqdm - self.tqdm = progress_bar - super().__init__(size=size, value=value) - - def close(self): - if self.tqdm is not None: - self.tqdm.close() - self.tqdm = None - - def call(self, hook_name=None, **kwargs): - if self.tqdm is 
None: - self.tqdm = self._tqdm_cls(**self._tqdm_kwargs, total=self.size or -1) - self.tqdm.update_to(self.value, total=self.size) - - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional[Callback] = None, - ): + tqdm_cls = tqdm_cls or Tqdm + super().__init__( + tqdm_kwargs=tqdm_kwargs, tqdm_cls=tqdm_cls, size=size, value=value + ) + if progress_bar is not None: + self.tqdm = progress_bar + + def branched(self, path_1: "Union[str, BinaryIO]", path_2: str, **kwargs): desc = path_1 if isinstance(path_1, str) else path_2 - child = child or TqdmCallback(bytes=True, desc=desc) - return super().branch(path_1, path_2, kwargs, child=child) - - -class FsspecCallbackWrapper(Callback): - def __init__(self, callback: fsspec.Callback): - object.__setattr__(self, "_callback", callback) - - def __getattr__(self, name: str): - return getattr(self._callback, name) + return TqdmCallback(bytes=True, desc=desc) - def __setattr__(self, name: str, value: Any): - setattr(self._callback, name, value) - def absolute_update(self, value: int) -> None: - value = value if value is not None else self.value - return self._callback.absolute_update(value) - - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional["Callback"] = None, - ) -> "Callback": - if not child: - self._callback.branch(path_1, path_2, kwargs) - child = self.as_callback(kwargs.get("callback")) - return super().branch(path_1, path_2, kwargs, child=child) - - -def wrap_fn(callback: fsspec.Callback, fn: F) -> F: - @wraps(fn) - async def async_wrapper(*args, **kwargs): - res = await fn(*args, **kwargs) - callback.relative_update() - return res - - @wraps(fn) - def sync_wrapper(*args, **kwargs): - res = fn(*args, **kwargs) - callback.relative_update() - return res - - return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value] - - -def branch_callback(callback: fsspec.Callback, 
fn: F) -> F: - callback = Callback.as_callback(callback) - - @wraps(fn) - async def async_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - with callback.branch(path1, path2, kwargs): - return await fn(path1, path2, **kwargs) - - @wraps(fn) - def sync_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - with callback.branch(path1, path2, kwargs): - return fn(path1, path2, **kwargs) - - return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value] - - -def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F: - branch_wrapper = branch_callback(callback, fn) - return wrap_fn(callback, branch_wrapper) - - -def wrap_file(file, callback: fsspec.Callback) -> BinaryIO: +def wrap_file(file, callback: Callback) -> BinaryIO: return cast(BinaryIO, CallbackStream(file, callback)) - - -DEFAULT_CALLBACK = NoOpCallback() diff --git a/src/dvc_objects/fs/generic.py b/src/dvc_objects/fs/generic.py index 36bbaac..a25df9a 100644 --- a/src/dvc_objects/fs/generic.py +++ b/src/dvc_objects/fs/generic.py @@ -7,16 +7,17 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from fsspec.asyn import get_loop +from fsspec.callbacks import DEFAULT_CALLBACK from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import DEFAULT_CALLBACK, wrap_and_branch_callback from .local import LocalFileSystem, localfs from .utils import as_atomic, umask if TYPE_CHECKING: + from fsspec import Callback + from .base import AnyFSPath, FileSystem - from .callbacks import Callback logger = logging.getLogger(__name__) @@ -103,7 +104,6 @@ def copy( ) jobs = batch_size or to_fs.jobs - put_file = wrap_and_branch_callback(callback, to_fs.put_file) put_file_kwargs = {} if hasattr(to_fs.fs, "max_concurrency"): put_file_kwargs["max_concurrency"] = jobs if len(from_path) == 1 else 1 @@ -112,9 +112,10 @@ def _copy_one(from_p: "AnyFSPath", to_p: "AnyFSPath"): try: with from_fs.open(from_p, mode="rb") as 
fobj: size = from_fs.size(from_p) - return put_file( - fobj, to_p, size=size, callback=callback, **put_file_kwargs - ) + with callback.branched(from_p, to_p) as child: + return to_fs.put_file( + fobj, to_p, size=size, callback=child, **put_file_kwargs + ) except Exception as exc: # noqa: BLE001 if on_error is not None: on_error(from_p, to_p, exc) @@ -122,14 +123,16 @@ def _copy_one(from_p: "AnyFSPath", to_p: "AnyFSPath"): raise if len(from_path) == 1: - return _copy_one(from_path[0], to_path[0]) + _copy_one(from_path[0], to_path[0]) + return callback.relative_update() executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(_copy_one, from_path, to_path)) + it = executor.imap_unordered(_copy_one, from_path, to_path) + list(callback.wrap(it)) -def _put( +def _put( # noqa: C901 from_paths: List["AnyFSPath"], to_fs: "FileSystem", to_paths: List["AnyFSPath"], @@ -138,14 +141,16 @@ def _put( on_error: Optional[TransferErrorHandler] = None, ) -> None: jobs = batch_size or to_fs.jobs - put_file = wrap_and_branch_callback(callback, to_fs.put_file) put_file_kwargs = {} if hasattr(to_fs.fs, "max_concurrency"): put_file_kwargs["max_concurrency"] = jobs if len(from_paths) == 1 else 1 def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): try: - return put_file(from_path, to_path, callback=callback, **put_file_kwargs) + with callback.branched(from_path, to_path) as child: + return to_fs.put_file( + from_path, to_path, callback=child, **put_file_kwargs + ) except Exception as exc: # noqa: BLE001 if on_error is not None: on_error(from_path, to_path, exc) @@ -153,18 +158,27 @@ def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): raise if len(from_paths) == 1: - return _put_one(from_paths[0], to_paths[0]) + _put_one(from_paths[0], to_paths[0]) + return callback.relative_update() if to_fs.fs.async_impl: - put_coro = wrap_and_branch_callback(callback, to_fs.fs._put_file) + to_fs_async = to_fs.fs + + async def 
put_coro(from_path, to_path, **kwargs): + with callback.branched(from_path, to_path) as child: + return await to_fs_async._put_file( + from_path, to_path, callback=child, **kwargs + ) + loop = get_loop() fut = asyncio.run_coroutine_threadsafe( batch_coros( [ - put_coro(from_path, to_path, callback=callback, **put_file_kwargs) + put_coro(from_path, to_path, **put_file_kwargs) for from_path, to_path in zip(from_paths, to_paths) ], batch_size=jobs, + callback=callback, return_exceptions=True, ), loop, @@ -179,7 +193,8 @@ def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(_put_one, from_paths, to_paths)) + it = executor.imap_unordered(_put_one, from_paths, to_paths) + list(callback.wrap(it)) def _get( # noqa: C901 @@ -191,7 +206,6 @@ def _get( # noqa: C901 on_error: Optional[TransferErrorHandler] = None, ) -> None: jobs = batch_size or from_fs.jobs - get_file = wrap_and_branch_callback(callback, from_fs.get_file) get_file_kwargs = {} if hasattr(from_fs.fs, "max_concurrency"): get_file_kwargs["max_concurrency"] = jobs if len(from_paths) == 1 else 1 @@ -199,9 +213,10 @@ def _get( # noqa: C901 def _get_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): with as_atomic(localfs, to_path, create_parents=True) as tmp_file: try: - return get_file( - from_path, tmp_file, callback=callback, **get_file_kwargs - ) + with callback.branched(from_path, to_path) as child: + return from_fs.get_file( + from_path, tmp_file, callback=child, **get_file_kwargs + ) except Exception as exc: # noqa: BLE001 if on_error is not None: on_error(from_path, to_path, exc) @@ -209,16 +224,18 @@ def _get_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): raise if len(from_paths) == 1: - return _get_one(from_paths[0], to_paths[0]) + _get_one(from_paths[0], to_paths[0]) + return callback.relative_update() if from_fs.fs.async_impl: + from_async_fs = from_fs.fs async def 
_get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): - get_coro = wrap_and_branch_callback(callback, from_fs.fs._get_file) with as_atomic(localfs, to_path, create_parents=True) as tmp_file: - return await get_coro( - from_path, tmp_file, callback=callback, **get_file_kwargs - ) + with callback.branched(from_path, to_path) as child: + return await from_async_fs._get_file( + from_path, tmp_file, callback=child, **get_file_kwargs + ) loop = get_loop() fut = asyncio.run_coroutine_threadsafe( @@ -229,6 +246,7 @@ async def _get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): ], batch_size=jobs, return_exceptions=True, + callback=callback, ), loop, ) @@ -242,7 +260,8 @@ async def _get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(_get_one, from_paths, to_paths)) + it = executor.imap_unordered(_get_one, from_paths, to_paths) + list(callback.wrap(it)) def _try_links( diff --git a/src/dvc_objects/fs/utils.py b/src/dvc_objects/fs/utils.py index ab2c284..921972c 100644 --- a/src/dvc_objects/fs/utils.py +++ b/src/dvc_objects/fs/utils.py @@ -10,14 +10,17 @@ from secrets import token_urlsafe from typing import TYPE_CHECKING, Any, Collection, Dict, Iterator, Optional, Set, Union +from fsspec.callbacks import DEFAULT_CALLBACK + from dvc_objects.executors import ThreadPoolExecutor from . 
import system -from .callbacks import DEFAULT_CALLBACK, wrap_file +from .callbacks import wrap_file if TYPE_CHECKING: + from fsspec import Callback + from .base import AnyFSPath, FileSystem - from .callbacks import Callback logger = logging.getLogger(__name__) diff --git a/tests/fs/test_callbacks.py b/tests/fs/test_callbacks.py index dfa84e7..4620c98 100644 --- a/tests/fs/test_callbacks.py +++ b/tests/fs/test_callbacks.py @@ -1,152 +1,6 @@ -from typing import Optional +from fsspec.callbacks import Callback -import fsspec -import pytest - -from dvc_objects.fs.callbacks import ( - DEFAULT_CALLBACK, - Callback, - TqdmCallback, - branch_callback, - wrap_and_branch_callback, - wrap_file, - wrap_fn, -) - - -@pytest.mark.parametrize("api", ["set_size", "absolute_update"]) -def test_callback_with_none(request, api, mocker): - """ - Test that callback don't fail if they receive None. - The callbacks should not receive None, but there may be some - filesystems that are not compliant, we may want to maintain - maximum compatibility, and not break UI in these edge-cases. - See https://github.com/iterative/dvc/issues/7704. 
- """ - callback = Callback.as_callback() - request.addfinalizer(callback.close) - - call_mock = mocker.spy(callback, "call") - method = getattr(callback, api) - method(None) - call_mock.assert_called_once_with() - if callback is not DEFAULT_CALLBACK: - assert callback.size is None - assert callback.value == 0 - - -def test_wrap_fsspec(): - def _branch_fn(*args, callback: Optional["Callback"] = None, **kwargs): - pass - - callback = fsspec.callbacks.Callback() - assert callback.value == 0 - with Callback.as_callback(callback) as cb: - assert not isinstance(cb, TqdmCallback) - assert cb.value == 0 - cb.relative_update() - assert cb.value == 1 - assert callback.value == 1 - - with cb.branch("foo", "bar", {}) as child: - _branch_fn("foo", "bar", callback=child) - cb.relative_update() - - assert cb.value == 2 - assert callback.value == 2 - - -def ids_func(cb_type): - return f"{cb_type.__module__}.{cb_type.__qualname__}" - - -class IsDVCCallback: - def __eq__(self, value: object) -> bool: - return isinstance(value, Callback) - - -@pytest.fixture(params=[Callback, fsspec.Callback], ids=ids_func) -def cb_class(request): - return request.param - - -def test_wrap_fn_sync(mocker, cb_class): - m = mocker.MagicMock(return_value=1) - callback = cb_class() - - wrapped = wrap_fn(callback, m) - - assert wrapped("arg") == 1 - assert callback.value == 1 - m.assert_called_once_with("arg") - - -@pytest.mark.asyncio -async def test_wrap_fn_async(mocker, cb_class): - m = mocker.AsyncMock(return_value=1) - callback = cb_class() - - wrapped = wrap_fn(callback, m) - - assert await wrapped("arg") == 1 - assert callback.value == 1 - m.assert_called_once_with("arg") - - -def test_branch_fn_sync(mocker, cb_class): - m = mocker.MagicMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branch") - wrapped = branch_callback(callback, m) - - assert wrapped("arg1", "arg2") == 1 - assert callback.value == 0 - assert spy.call_count == 1 - m.assert_called_once_with("arg1", "arg2", 
callback=IsDVCCallback()) - - -@pytest.mark.asyncio -async def test_branch_fn_async(mocker, cb_class): - m = mocker.AsyncMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branch") - wrapped = branch_callback(callback, m) - - assert await wrapped("arg1", "arg2") == 1 - assert callback.value == 0 - assert spy.call_count == 1 - m.assert_called_once_with("arg1", "arg2", callback=IsDVCCallback()) - - -def test_wrap_and_branch_callback_sync(mocker, cb_class): - m = mocker.MagicMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branch") - wrapped = wrap_and_branch_callback(callback, m) - - assert wrapped("arg1", "arg2", arg3="arg3") == 1 - assert wrapped("argA", "argB", arg3="argC") == 1 - - m.assert_any_call("arg1", "arg2", arg3="arg3", callback=IsDVCCallback()) - m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback()) - assert callback.value == 2 - assert spy.call_count == 2 - - -@pytest.mark.asyncio -async def test_wrap_and_branch_callback_async(mocker, cb_class): - m = mocker.AsyncMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branch") - wrapped = wrap_and_branch_callback(callback, m) - - assert await wrapped("arg1", "arg2", arg3="arg3") == 1 - assert await wrapped("argA", "argB", arg3="argC") == 1 - - m.assert_any_call("arg1", "arg2", arg3="arg3", callback=IsDVCCallback()) - m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback()) - assert callback.value == 2 - assert spy.call_count == 2 +from dvc_objects.fs.callbacks import wrap_file def test_wrap_file(memfs): diff --git a/tests/fs/test_generic.py b/tests/fs/test_generic.py new file mode 100644 index 0000000..385fecd --- /dev/null +++ b/tests/fs/test_generic.py @@ -0,0 +1,120 @@ +import pytest +from fsspec import Callback +from fsspec.asyn import AsyncFileSystem +from fsspec.implementations.memory import MemoryFileSystem as MemoryFS + +from dvc_objects.fs.generic import copy, transfer +from 
dvc_objects.fs.local import LocalFileSystem +from dvc_objects.fs.memory import MemoryFileSystem + + +def awrap(fn): + async def inner(self, *args, **kwargs): + return fn(self.fs, *args, **kwargs) + + return inner + + +class AsyncMemoryFS(AsyncFileSystem): + cachable = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fs = MemoryFS() + self.fs.store = {} + self.fs.pseudo_dirs = [""] + + def _open(self, *args, **kwargs): + return self.fs.open(*args, **kwargs) + + _info = awrap(MemoryFS.info) + _ls = awrap(MemoryFS.ls) + _mkdir = awrap(MemoryFS.mkdir) + _makedirs = awrap(MemoryFS.makedirs) + _get_file = awrap(MemoryFS.get_file) + _put_file = awrap(MemoryFS.put_file) + _cat_file = awrap(MemoryFS.cat_file) + _pipe_file = awrap(MemoryFS.pipe_file) + _cp_file = awrap(MemoryFS.cp_file) + _rm_file = awrap(MemoryFS.rm_file) + + +fs_clses = [ + LocalFileSystem, + lambda: MemoryFileSystem(global_store=False), + lambda: MemoryFileSystem(fs=AsyncMemoryFS()), +] +fs_clses[1].__name__ = MemoryFileSystem.__name__ # type: ignore[attr-defined] +fs_clses[2].__name__ = "Async" + MemoryFileSystem.__name__ # type: ignore[attr-defined] + + +@pytest.mark.parametrize("files", [{"foo": b"foo"}, {"foo": b"foo", "bar": b"barr"}]) +@pytest.mark.parametrize("fs_cls1", fs_clses) +@pytest.mark.parametrize("fs_cls2", fs_clses) +def test_copy(tmp_path, files, fs_cls1, fs_cls2, mocker): + fs1, fs2 = fs_cls1(), fs_cls2() + fs1_root = tmp_path if isinstance(fs1, LocalFileSystem) else fs1.root_marker + fs2_root = tmp_path if isinstance(fs2, LocalFileSystem) else fs2.root_marker + src_root = fs1.join(fs1_root, "src") + dest_root = fs2.join(fs2_root, "dest") + + src_files = {fs1.join(src_root, f): c for f, c in files.items()} + dest_files = {fs2.join(dest_root, f): c for f, c in files.items()} + fs1.mkdir(src_root) + fs1.pipe(src_files) + + callback = Callback() + spy_close = mocker.spy(Callback, "close") + child_callbacks = [Callback() for _ in files] + + branched 
= mocker.patch.object(callback, "branched", side_effect=child_callbacks) + copy(fs1, list(src_files), fs2, list(dest_files), callback=callback) + + assert fs2.cat(list(dest_files)) == dest_files + + n = len(files) + # assert main callback works + assert callback.value == n + assert callback.size is None # does not set sizes + # assert child callbacks are handled correctly + assert branched.call_count == n, f"expected branched to be called {n} times" + assert spy_close.call_count == n, f"expected close to be called {n} times" + + if isinstance(fs1, LocalFileSystem) and isinstance(fs2, LocalFileSystem): + # localfs copy avoids calling set_size or update if fs supports reflink + # or, file size is less than 1GB + return + assert {c.size for c in child_callbacks} == {len(c) for c in files.values()} + assert {c.value for c in child_callbacks} == {len(c) for c in files.values()} + + +@pytest.mark.parametrize("files", [{"foo": b"foo"}, {"foo": b"foo", "bar": b"barr"}]) +@pytest.mark.parametrize( + "link_type", + [ + pytest.param("reflink", marks=pytest.mark.xfail(reason="unsupported")), + "symlink", + "hardlink", + "copy", + ], +) +def test_transfer_between_local_fses(mocker, tmp_path, files, link_type): + fs = LocalFileSystem() + fs.mkdir(fs.join(tmp_path, "src")) + fs.mkdir(fs.join(tmp_path, "dest")) + + src_files = {fs.join(tmp_path, "src", f): c for f, c in files.items()} + dest_files = {fs.join(tmp_path, "dest", f): c for f, c in files.items()} + + fs.pipe(src_files) + + callback = Callback() + + branched = mocker.patch.object(callback, "branched") + transfer( + fs, list(src_files), fs, list(dest_files), callback=callback, links=[link_type] + ) + assert fs.cat([str(tmp_path / "dest" / file) for file in files]) == dest_files + assert callback.value == len(files) + assert callback.size is None # does not set sizes + assert branched.call_count == (len(files) if link_type == "copy" else 0)