From 348a18ed14eac52b087198f75c67b63ef156b931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Sun, 4 Feb 2024 22:53:47 +0545 Subject: [PATCH 1/3] cleanup callback Structured callbacks are now available in fsspec>=2024.2.0, so we can remove most of the code here. Almost all of the `dvc_objects.fs.callbacks` is deprecated, and will be removed within 4.x minor release, including TqdmCallback. TqdmCallback can be replaced with fsspec's TqdmCallback, or can be overridden very easily. This will likely be moved to dvc-data for now. Similarly, `dvc_objects._tqdm` module is deprecated and is slated for removal within 4.x minor release. The code will likely move over to dvc-data. --- pyproject.toml | 2 +- src/dvc_objects/db.py | 6 +- src/dvc_objects/executors.py | 2 +- src/dvc_objects/fs/base.py | 7 +- src/dvc_objects/fs/callbacks.py | 114 ++++---------------------------- src/dvc_objects/fs/generic.py | 6 +- src/dvc_objects/fs/utils.py | 7 +- tests/fs/test_callbacks.py | 41 +++--------- 8 files changed, 38 insertions(+), 147 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 89f8846..058e716 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.8" dynamic = ["version"] dependencies = [ "funcy>=1.14; python_version < '3.12'", - "fsspec>=2022.10.0", + "fsspec>=2023.2.0", ] [project.urls] diff --git a/src/dvc_objects/db.py b/src/dvc_objects/db.py index b67f77f..bb730f2 100644 --- a/src/dvc_objects/db.py +++ b/src/dvc_objects/db.py @@ -16,13 +16,15 @@ cast, ) +from fsspec.callbacks import DEFAULT_CALLBACK + from .errors import ObjectDBPermissionError -from .fs.callbacks import DEFAULT_CALLBACK from .obj import Object if TYPE_CHECKING: + from fsspec import Callback + from .fs.base import AnyFSPath, FileSystem - from .fs.callbacks import Callback logger = logging.getLogger(__name__) diff --git a/src/dvc_objects/executors.py 
b/src/dvc_objects/executors.py index 28c707c..68e7212 100644 --- a/src/dvc_objects/executors.py +++ b/src/dvc_objects/executors.py @@ -18,7 +18,7 @@ TypeVar, ) -from .fs.callbacks import Callback +from fsspec import Callback _T = TypeVar("_T") diff --git a/src/dvc_objects/fs/base.py b/src/dvc_objects/fs/base.py index 1eb0aec..01a09a1 100644 --- a/src/dvc_objects/fs/base.py +++ b/src/dvc_objects/fs/base.py @@ -27,15 +27,12 @@ import fsspec from fsspec.asyn import get_loop +from fsspec.callbacks import DEFAULT_CALLBACK from dvc_objects.compat import cached_property from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import ( - DEFAULT_CALLBACK, - wrap_and_branch_callback, - wrap_file, -) +from .callbacks import wrap_and_branch_callback, wrap_file from .errors import RemoteMissingDepsError if TYPE_CHECKING: diff --git a/src/dvc_objects/fs/callbacks.py b/src/dvc_objects/fs/callbacks.py index 41b3fd3..91b1afd 100644 --- a/src/dvc_objects/fs/callbacks.py +++ b/src/dvc_objects/fs/callbacks.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast import fsspec +from fsspec.callbacks import DEFAULT_CALLBACK, Callback, NoOpCallback if TYPE_CHECKING: from typing import Union @@ -11,9 +12,11 @@ F = TypeVar("F", bound=Callable) +__all__ = ["Callback", "NoOpCallback", "TqdmCallback", "DEFAULT_CALLBACK"] + class CallbackStream: - def __init__(self, stream, callback: fsspec.Callback): + def __init__(self, stream, callback: Callback): self.stream = stream @wraps(stream.read) @@ -28,48 +31,7 @@ def __getattr__(self, attr): return getattr(self.stream, attr) -class ScopedCallback(fsspec.Callback): - def __enter__(self): - return self - - def __exit__(self, *exc_args): - self.close() - - def close(self): - """Handle here on exit.""" - - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional["Callback"] = None, - ) -> "Callback": - child = 
kwargs["callback"] = child or DEFAULT_CALLBACK - return child - - -class Callback(ScopedCallback): - def absolute_update(self, value: int) -> None: - value = value if value is not None else self.value - return super().absolute_update(value) - - @classmethod - def as_callback( - cls, maybe_callback: Optional[fsspec.Callback] = None - ) -> "Callback": - if maybe_callback is None: - return DEFAULT_CALLBACK - if isinstance(maybe_callback, Callback): - return maybe_callback - return FsspecCallbackWrapper(maybe_callback) - - -class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback): - pass - - -class TqdmCallback(Callback): +class TqdmCallback(fsspec.callbacks.TqdmCallback): def __init__( self, size: Optional[int] = None, @@ -80,58 +42,18 @@ def __init__( from dvc_objects._tqdm import Tqdm tqdm_kwargs.pop("total", None) - self._tqdm_kwargs = tqdm_kwargs - self._tqdm_cls = Tqdm - self.tqdm = progress_bar - super().__init__(size=size, value=value) - - def close(self): - if self.tqdm is not None: - self.tqdm.close() - self.tqdm = None - - def call(self, hook_name=None, **kwargs): - if self.tqdm is None: - self.tqdm = self._tqdm_cls(**self._tqdm_kwargs, total=self.size or -1) - self.tqdm.update_to(self.value, total=self.size) - - def branch( + super().__init__(tqdm_kwargs=tqdm_kwargs, tqdm_cls=Tqdm, size=size, value=value) + if progress_bar: + self.tqdm = progress_bar + + def branched( self, path_1: "Union[str, BinaryIO]", path_2: str, kwargs: Dict[str, Any], - child: Optional[Callback] = None, ): desc = path_1 if isinstance(path_1, str) else path_2 - child = child or TqdmCallback(bytes=True, desc=desc) - return super().branch(path_1, path_2, kwargs, child=child) - - -class FsspecCallbackWrapper(Callback): - def __init__(self, callback: fsspec.Callback): - object.__setattr__(self, "_callback", callback) - - def __getattr__(self, name: str): - return getattr(self._callback, name) - - def __setattr__(self, name: str, value: Any): - setattr(self._callback, name, value) - - 
def absolute_update(self, value: int) -> None: - value = value if value is not None else self.value - return self._callback.absolute_update(value) - - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional["Callback"] = None, - ) -> "Callback": - if not child: - self._callback.branch(path_1, path_2, kwargs) - child = self.as_callback(kwargs.get("callback")) - return super().branch(path_1, path_2, kwargs, child=child) + return TqdmCallback(bytes=True, desc=desc) def wrap_fn(callback: fsspec.Callback, fn: F) -> F: @@ -151,19 +73,12 @@ def sync_wrapper(*args, **kwargs): def branch_callback(callback: fsspec.Callback, fn: F) -> F: - callback = Callback.as_callback(callback) - - @wraps(fn) - async def async_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - with callback.branch(path1, path2, kwargs): - return await fn(path1, path2, **kwargs) - @wraps(fn) def sync_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - with callback.branch(path1, path2, kwargs): + with callback.branched(path1, path2): return fn(path1, path2, **kwargs) - return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value] + return callback.branch_coro(fn) if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value] def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F: @@ -173,6 +88,3 @@ def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F: def wrap_file(file, callback: fsspec.Callback) -> BinaryIO: return cast(BinaryIO, CallbackStream(file, callback)) - - -DEFAULT_CALLBACK = NoOpCallback() diff --git a/src/dvc_objects/fs/generic.py b/src/dvc_objects/fs/generic.py index 36bbaac..06b51c7 100644 --- a/src/dvc_objects/fs/generic.py +++ b/src/dvc_objects/fs/generic.py @@ -7,16 +7,18 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from fsspec.asyn import get_loop +from fsspec.callbacks import DEFAULT_CALLBACK 
from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import DEFAULT_CALLBACK, wrap_and_branch_callback +from .callbacks import wrap_and_branch_callback from .local import LocalFileSystem, localfs from .utils import as_atomic, umask if TYPE_CHECKING: + from fsspec import Callback + from .base import AnyFSPath, FileSystem - from .callbacks import Callback logger = logging.getLogger(__name__) diff --git a/src/dvc_objects/fs/utils.py b/src/dvc_objects/fs/utils.py index ab2c284..921972c 100644 --- a/src/dvc_objects/fs/utils.py +++ b/src/dvc_objects/fs/utils.py @@ -10,14 +10,17 @@ from secrets import token_urlsafe from typing import TYPE_CHECKING, Any, Collection, Dict, Iterator, Optional, Set, Union +from fsspec.callbacks import DEFAULT_CALLBACK + from dvc_objects.executors import ThreadPoolExecutor from . import system -from .callbacks import DEFAULT_CALLBACK, wrap_file +from .callbacks import wrap_file if TYPE_CHECKING: + from fsspec import Callback + from .base import AnyFSPath, FileSystem - from .callbacks import Callback logger = logging.getLogger(__name__) diff --git a/tests/fs/test_callbacks.py b/tests/fs/test_callbacks.py index dfa84e7..8587371 100644 --- a/tests/fs/test_callbacks.py +++ b/tests/fs/test_callbacks.py @@ -1,12 +1,8 @@ -from typing import Optional - import fsspec import pytest +from fsspec.callbacks import DEFAULT_CALLBACK, Callback from dvc_objects.fs.callbacks import ( - DEFAULT_CALLBACK, - Callback, - TqdmCallback, branch_callback, wrap_and_branch_callback, wrap_file, @@ -35,27 +31,6 @@ def test_callback_with_none(request, api, mocker): assert callback.value == 0 -def test_wrap_fsspec(): - def _branch_fn(*args, callback: Optional["Callback"] = None, **kwargs): - pass - - callback = fsspec.callbacks.Callback() - assert callback.value == 0 - with Callback.as_callback(callback) as cb: - assert not isinstance(cb, TqdmCallback) - assert cb.value == 0 - cb.relative_update() - assert cb.value == 1 - assert callback.value 
== 1 - - with cb.branch("foo", "bar", {}) as child: - _branch_fn("foo", "bar", callback=child) - cb.relative_update() - - assert cb.value == 2 - assert callback.value == 2 - - def ids_func(cb_type): return f"{cb_type.__module__}.{cb_type.__qualname__}" @@ -96,20 +71,20 @@ async def test_wrap_fn_async(mocker, cb_class): def test_branch_fn_sync(mocker, cb_class): m = mocker.MagicMock(return_value=1) callback = cb_class() - spy = mocker.spy(callback, "branch") + spy = mocker.spy(callback, "branched") wrapped = branch_callback(callback, m) assert wrapped("arg1", "arg2") == 1 assert callback.value == 0 assert spy.call_count == 1 - m.assert_called_once_with("arg1", "arg2", callback=IsDVCCallback()) + m.assert_called_once_with("arg1", "arg2") @pytest.mark.asyncio async def test_branch_fn_async(mocker, cb_class): m = mocker.AsyncMock(return_value=1) callback = cb_class() - spy = mocker.spy(callback, "branch") + spy = mocker.spy(callback, "branched") wrapped = branch_callback(callback, m) assert await wrapped("arg1", "arg2") == 1 @@ -121,14 +96,14 @@ async def test_branch_fn_async(mocker, cb_class): def test_wrap_and_branch_callback_sync(mocker, cb_class): m = mocker.MagicMock(return_value=1) callback = cb_class() - spy = mocker.spy(callback, "branch") + spy = mocker.spy(callback, "branched") wrapped = wrap_and_branch_callback(callback, m) assert wrapped("arg1", "arg2", arg3="arg3") == 1 assert wrapped("argA", "argB", arg3="argC") == 1 - m.assert_any_call("arg1", "arg2", arg3="arg3", callback=IsDVCCallback()) - m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback()) + m.assert_any_call("arg1", "arg2", arg3="arg3") + m.assert_any_call("argA", "argB", arg3="argC") assert callback.value == 2 assert spy.call_count == 2 @@ -137,7 +112,7 @@ def test_wrap_and_branch_callback_sync(mocker, cb_class): async def test_wrap_and_branch_callback_async(mocker, cb_class): m = mocker.AsyncMock(return_value=1) callback = cb_class() - spy = mocker.spy(callback, "branch") + spy 
= mocker.spy(callback, "branched") wrapped = wrap_and_branch_callback(callback, m) assert await wrapped("arg1", "arg2", arg3="arg3") == 1 From 67ef3ebe0b0350a053465fc1fc3d44ac27b4e097 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 5 Feb 2024 13:19:13 +0545 Subject: [PATCH 2/3] get rid of callback wrappers, add tests for generic transfers --- pyproject.toml | 2 +- src/dvc_objects/fs/base.py | 23 ++++-- src/dvc_objects/fs/callbacks.py | 55 +++----------- src/dvc_objects/fs/generic.py | 65 +++++++++++------ tests/fs/test_callbacks.py | 125 +------------------------------- tests/fs/test_generic.py | 120 ++++++++++++++++++++++++++++++ 6 files changed, 190 insertions(+), 200 deletions(-) create mode 100644 tests/fs/test_generic.py diff --git a/pyproject.toml b/pyproject.toml index 058e716..0f5942c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,7 +104,7 @@ module = [ ] [tool.codespell] -ignore-words-list = " " +ignore-words-list = "cachable," [tool.ruff] ignore = [ diff --git a/src/dvc_objects/fs/base.py b/src/dvc_objects/fs/base.py index 01a09a1..e0a741c 100644 --- a/src/dvc_objects/fs/base.py +++ b/src/dvc_objects/fs/base.py @@ -32,7 +32,7 @@ from dvc_objects.compat import cached_property from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import wrap_and_branch_callback, wrap_file +from .callbacks import wrap_file from .errors import RemoteMissingDepsError if TYPE_CHECKING: @@ -703,9 +703,14 @@ def put( callback.set_size(len(from_infos)) executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) + + def put_file(from_path, to_path): + with callback.branched(from_path, to_path) as child: + return self.put_file(from_path, to_path, callback=child) + with executor: - put_file = wrap_and_branch_callback(callback, self.put_file) - list(executor.imap_unordered(put_file, from_infos, to_infos)) + it = executor.imap_unordered(put_file, 
from_infos, to_infos) + list(callback.wrap(it)) def get( self, @@ -721,9 +726,8 @@ def get( def get_file(rpath, lpath, **kwargs): localfs.makedirs(localfs.parent(lpath), exist_ok=True) - self.fs.get_file(rpath, lpath, **kwargs) - - get_file = wrap_and_branch_callback(callback, get_file) + with callback.branched(rpath, lpath) as child: + self.fs.get_file(rpath, lpath, callback=child, **kwargs) if isinstance(from_info, list) and isinstance(to_info, list): from_infos: List[AnyFSPath] = from_info @@ -734,7 +738,9 @@ def get_file(rpath, lpath, **kwargs): if not self.isdir(from_info): callback.set_size(1) - return get_file(from_info, to_info) + get_file(from_info, to_info) + callback.relative_update() + return from_infos = list(self.find(from_info)) if not from_infos: @@ -757,7 +763,8 @@ def get_file(rpath, lpath, **kwargs): callback.set_size(len(from_infos)) executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(get_file, from_infos, to_infos)) + it = executor.imap_unordered(get_file, from_infos, to_infos) + list(callback.wrap(it)) def ukey(self, path: AnyFSPath) -> str: return self.fs.ukey(path) diff --git a/src/dvc_objects/fs/callbacks.py b/src/dvc_objects/fs/callbacks.py index 91b1afd..467e26a 100644 --- a/src/dvc_objects/fs/callbacks.py +++ b/src/dvc_objects/fs/callbacks.py @@ -1,6 +1,5 @@ -import asyncio from functools import wraps -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast +from typing import TYPE_CHECKING, BinaryIO, Optional, Type, cast import fsspec from fsspec.callbacks import DEFAULT_CALLBACK, Callback, NoOpCallback @@ -8,9 +7,8 @@ if TYPE_CHECKING: from typing import Union - from dvc_objects._tqdm import Tqdm + from tqdm import tqdm -F = TypeVar("F", bound=Callable) __all__ = ["Callback", "NoOpCallback", "TqdmCallback", "DEFAULT_CALLBACK"] @@ -36,55 +34,24 @@ def __init__( self, size: Optional[int] = None, value: int = 0, - progress_bar: 
Optional["Tqdm"] = None, + progress_bar: Optional["tqdm"] = None, + tqdm_cls: Optional[Type["tqdm"]] = None, **tqdm_kwargs, ): from dvc_objects._tqdm import Tqdm tqdm_kwargs.pop("total", None) - super().__init__(tqdm_kwargs=tqdm_kwargs, tqdm_cls=Tqdm, size=size, value=value) - if progress_bar: + tqdm_cls = tqdm_cls or Tqdm + super().__init__( + tqdm_kwargs=tqdm_kwargs, tqdm_cls=tqdm_cls, size=size, value=value + ) + if progress_bar is not None: self.tqdm = progress_bar - def branched( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - ): + def branched(self, path_1: "Union[str, BinaryIO]", path_2: str, **kwargs): desc = path_1 if isinstance(path_1, str) else path_2 return TqdmCallback(bytes=True, desc=desc) -def wrap_fn(callback: fsspec.Callback, fn: F) -> F: - @wraps(fn) - async def async_wrapper(*args, **kwargs): - res = await fn(*args, **kwargs) - callback.relative_update() - return res - - @wraps(fn) - def sync_wrapper(*args, **kwargs): - res = fn(*args, **kwargs) - callback.relative_update() - return res - - return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value] - - -def branch_callback(callback: fsspec.Callback, fn: F) -> F: - @wraps(fn) - def sync_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - with callback.branched(path1, path2): - return fn(path1, path2, **kwargs) - - return callback.branch_coro(fn) if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value] - - -def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F: - branch_wrapper = branch_callback(callback, fn) - return wrap_fn(callback, branch_wrapper) - - -def wrap_file(file, callback: fsspec.Callback) -> BinaryIO: +def wrap_file(file, callback: Callback) -> BinaryIO: return cast(BinaryIO, CallbackStream(file, callback)) diff --git a/src/dvc_objects/fs/generic.py b/src/dvc_objects/fs/generic.py index 06b51c7..a25df9a 100644 --- a/src/dvc_objects/fs/generic.py +++ 
b/src/dvc_objects/fs/generic.py @@ -11,7 +11,6 @@ from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import wrap_and_branch_callback from .local import LocalFileSystem, localfs from .utils import as_atomic, umask @@ -105,7 +104,6 @@ def copy( ) jobs = batch_size or to_fs.jobs - put_file = wrap_and_branch_callback(callback, to_fs.put_file) put_file_kwargs = {} if hasattr(to_fs.fs, "max_concurrency"): put_file_kwargs["max_concurrency"] = jobs if len(from_path) == 1 else 1 @@ -114,9 +112,10 @@ def _copy_one(from_p: "AnyFSPath", to_p: "AnyFSPath"): try: with from_fs.open(from_p, mode="rb") as fobj: size = from_fs.size(from_p) - return put_file( - fobj, to_p, size=size, callback=callback, **put_file_kwargs - ) + with callback.branched(from_p, to_p) as child: + return to_fs.put_file( + fobj, to_p, size=size, callback=child, **put_file_kwargs + ) except Exception as exc: # noqa: BLE001 if on_error is not None: on_error(from_p, to_p, exc) @@ -124,14 +123,16 @@ def _copy_one(from_p: "AnyFSPath", to_p: "AnyFSPath"): raise if len(from_path) == 1: - return _copy_one(from_path[0], to_path[0]) + _copy_one(from_path[0], to_path[0]) + return callback.relative_update() executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(_copy_one, from_path, to_path)) + it = executor.imap_unordered(_copy_one, from_path, to_path) + list(callback.wrap(it)) -def _put( +def _put( # noqa: C901 from_paths: List["AnyFSPath"], to_fs: "FileSystem", to_paths: List["AnyFSPath"], @@ -140,14 +141,16 @@ def _put( on_error: Optional[TransferErrorHandler] = None, ) -> None: jobs = batch_size or to_fs.jobs - put_file = wrap_and_branch_callback(callback, to_fs.put_file) put_file_kwargs = {} if hasattr(to_fs.fs, "max_concurrency"): put_file_kwargs["max_concurrency"] = jobs if len(from_paths) == 1 else 1 def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): try: - return put_file(from_path, to_path, callback=callback, 
**put_file_kwargs) + with callback.branched(from_path, to_path) as child: + return to_fs.put_file( + from_path, to_path, callback=child, **put_file_kwargs + ) except Exception as exc: # noqa: BLE001 if on_error is not None: on_error(from_path, to_path, exc) @@ -155,18 +158,27 @@ def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): raise if len(from_paths) == 1: - return _put_one(from_paths[0], to_paths[0]) + _put_one(from_paths[0], to_paths[0]) + return callback.relative_update() if to_fs.fs.async_impl: - put_coro = wrap_and_branch_callback(callback, to_fs.fs._put_file) + to_fs_async = to_fs.fs + + async def put_coro(from_path, to_path, **kwargs): + with callback.branched(from_path, to_path) as child: + return await to_fs_async._put_file( + from_path, to_path, callback=child, **kwargs + ) + loop = get_loop() fut = asyncio.run_coroutine_threadsafe( batch_coros( [ - put_coro(from_path, to_path, callback=callback, **put_file_kwargs) + put_coro(from_path, to_path, **put_file_kwargs) for from_path, to_path in zip(from_paths, to_paths) ], batch_size=jobs, + callback=callback, return_exceptions=True, ), loop, @@ -181,7 +193,8 @@ def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(_put_one, from_paths, to_paths)) + it = executor.imap_unordered(_put_one, from_paths, to_paths) + list(callback.wrap(it)) def _get( # noqa: C901 @@ -193,7 +206,6 @@ def _get( # noqa: C901 on_error: Optional[TransferErrorHandler] = None, ) -> None: jobs = batch_size or from_fs.jobs - get_file = wrap_and_branch_callback(callback, from_fs.get_file) get_file_kwargs = {} if hasattr(from_fs.fs, "max_concurrency"): get_file_kwargs["max_concurrency"] = jobs if len(from_paths) == 1 else 1 @@ -201,9 +213,10 @@ def _get( # noqa: C901 def _get_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): with as_atomic(localfs, to_path, create_parents=True) as tmp_file: try: - return 
get_file( - from_path, tmp_file, callback=callback, **get_file_kwargs - ) + with callback.branched(from_path, to_path) as child: + return from_fs.get_file( + from_path, tmp_file, callback=child, **get_file_kwargs + ) except Exception as exc: # noqa: BLE001 if on_error is not None: on_error(from_path, to_path, exc) @@ -211,16 +224,18 @@ def _get_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): raise if len(from_paths) == 1: - return _get_one(from_paths[0], to_paths[0]) + _get_one(from_paths[0], to_paths[0]) + return callback.relative_update() if from_fs.fs.async_impl: + from_async_fs = from_fs.fs async def _get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): - get_coro = wrap_and_branch_callback(callback, from_fs.fs._get_file) with as_atomic(localfs, to_path, create_parents=True) as tmp_file: - return await get_coro( - from_path, tmp_file, callback=callback, **get_file_kwargs - ) + with callback.branched(from_path, to_path) as child: + return await from_async_fs._get_file( + from_path, tmp_file, callback=child, **get_file_kwargs + ) loop = get_loop() fut = asyncio.run_coroutine_threadsafe( @@ -231,6 +246,7 @@ async def _get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): ], batch_size=jobs, return_exceptions=True, + callback=callback, ), loop, ) @@ -244,7 +260,8 @@ async def _get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - list(executor.imap_unordered(_get_one, from_paths, to_paths)) + it = executor.imap_unordered(_get_one, from_paths, to_paths) + list(callback.wrap(it)) def _try_links( diff --git a/tests/fs/test_callbacks.py b/tests/fs/test_callbacks.py index 8587371..4620c98 100644 --- a/tests/fs/test_callbacks.py +++ b/tests/fs/test_callbacks.py @@ -1,127 +1,6 @@ -import fsspec -import pytest -from fsspec.callbacks import DEFAULT_CALLBACK, Callback +from fsspec.callbacks import Callback -from dvc_objects.fs.callbacks import ( - branch_callback, - 
wrap_and_branch_callback, - wrap_file, - wrap_fn, -) - - -@pytest.mark.parametrize("api", ["set_size", "absolute_update"]) -def test_callback_with_none(request, api, mocker): - """ - Test that callback don't fail if they receive None. - The callbacks should not receive None, but there may be some - filesystems that are not compliant, we may want to maintain - maximum compatibility, and not break UI in these edge-cases. - See https://github.com/iterative/dvc/issues/7704. - """ - callback = Callback.as_callback() - request.addfinalizer(callback.close) - - call_mock = mocker.spy(callback, "call") - method = getattr(callback, api) - method(None) - call_mock.assert_called_once_with() - if callback is not DEFAULT_CALLBACK: - assert callback.size is None - assert callback.value == 0 - - -def ids_func(cb_type): - return f"{cb_type.__module__}.{cb_type.__qualname__}" - - -class IsDVCCallback: - def __eq__(self, value: object) -> bool: - return isinstance(value, Callback) - - -@pytest.fixture(params=[Callback, fsspec.Callback], ids=ids_func) -def cb_class(request): - return request.param - - -def test_wrap_fn_sync(mocker, cb_class): - m = mocker.MagicMock(return_value=1) - callback = cb_class() - - wrapped = wrap_fn(callback, m) - - assert wrapped("arg") == 1 - assert callback.value == 1 - m.assert_called_once_with("arg") - - -@pytest.mark.asyncio -async def test_wrap_fn_async(mocker, cb_class): - m = mocker.AsyncMock(return_value=1) - callback = cb_class() - - wrapped = wrap_fn(callback, m) - - assert await wrapped("arg") == 1 - assert callback.value == 1 - m.assert_called_once_with("arg") - - -def test_branch_fn_sync(mocker, cb_class): - m = mocker.MagicMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branched") - wrapped = branch_callback(callback, m) - - assert wrapped("arg1", "arg2") == 1 - assert callback.value == 0 - assert spy.call_count == 1 - m.assert_called_once_with("arg1", "arg2") - - -@pytest.mark.asyncio -async def 
test_branch_fn_async(mocker, cb_class): - m = mocker.AsyncMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branched") - wrapped = branch_callback(callback, m) - - assert await wrapped("arg1", "arg2") == 1 - assert callback.value == 0 - assert spy.call_count == 1 - m.assert_called_once_with("arg1", "arg2", callback=IsDVCCallback()) - - -def test_wrap_and_branch_callback_sync(mocker, cb_class): - m = mocker.MagicMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branched") - wrapped = wrap_and_branch_callback(callback, m) - - assert wrapped("arg1", "arg2", arg3="arg3") == 1 - assert wrapped("argA", "argB", arg3="argC") == 1 - - m.assert_any_call("arg1", "arg2", arg3="arg3") - m.assert_any_call("argA", "argB", arg3="argC") - assert callback.value == 2 - assert spy.call_count == 2 - - -@pytest.mark.asyncio -async def test_wrap_and_branch_callback_async(mocker, cb_class): - m = mocker.AsyncMock(return_value=1) - callback = cb_class() - spy = mocker.spy(callback, "branched") - wrapped = wrap_and_branch_callback(callback, m) - - assert await wrapped("arg1", "arg2", arg3="arg3") == 1 - assert await wrapped("argA", "argB", arg3="argC") == 1 - - m.assert_any_call("arg1", "arg2", arg3="arg3", callback=IsDVCCallback()) - m.assert_any_call("argA", "argB", arg3="argC", callback=IsDVCCallback()) - assert callback.value == 2 - assert spy.call_count == 2 +from dvc_objects.fs.callbacks import wrap_file def test_wrap_file(memfs): diff --git a/tests/fs/test_generic.py b/tests/fs/test_generic.py new file mode 100644 index 0000000..385fecd --- /dev/null +++ b/tests/fs/test_generic.py @@ -0,0 +1,120 @@ +import pytest +from fsspec import Callback +from fsspec.asyn import AsyncFileSystem +from fsspec.implementations.memory import MemoryFileSystem as MemoryFS + +from dvc_objects.fs.generic import copy, transfer +from dvc_objects.fs.local import LocalFileSystem +from dvc_objects.fs.memory import MemoryFileSystem + + +def awrap(fn): + async 
def inner(self, *args, **kwargs): + return fn(self.fs, *args, **kwargs) + + return inner + + +class AsyncMemoryFS(AsyncFileSystem): + cachable = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fs = MemoryFS() + self.fs.store = {} + self.fs.pseudo_dirs = [""] + + def _open(self, *args, **kwargs): + return self.fs.open(*args, **kwargs) + + _info = awrap(MemoryFS.info) + _ls = awrap(MemoryFS.ls) + _mkdir = awrap(MemoryFS.mkdir) + _makedirs = awrap(MemoryFS.makedirs) + _get_file = awrap(MemoryFS.get_file) + _put_file = awrap(MemoryFS.put_file) + _cat_file = awrap(MemoryFS.cat_file) + _pipe_file = awrap(MemoryFS.pipe_file) + _cp_file = awrap(MemoryFS.cp_file) + _rm_file = awrap(MemoryFS.rm_file) + + +fs_clses = [ + LocalFileSystem, + lambda: MemoryFileSystem(global_store=False), + lambda: MemoryFileSystem(fs=AsyncMemoryFS()), +] +fs_clses[1].__name__ = MemoryFileSystem.__name__ # type: ignore[attr-defined] +fs_clses[2].__name__ = "Async" + MemoryFileSystem.__name__ # type: ignore[attr-defined] + + +@pytest.mark.parametrize("files", [{"foo": b"foo"}, {"foo": b"foo", "bar": b"barr"}]) +@pytest.mark.parametrize("fs_cls1", fs_clses) +@pytest.mark.parametrize("fs_cls2", fs_clses) +def test_copy(tmp_path, files, fs_cls1, fs_cls2, mocker): + fs1, fs2 = fs_cls1(), fs_cls2() + fs1_root = tmp_path if isinstance(fs1, LocalFileSystem) else fs1.root_marker + fs2_root = tmp_path if isinstance(fs2, LocalFileSystem) else fs2.root_marker + src_root = fs1.join(fs1_root, "src") + dest_root = fs2.join(fs2_root, "dest") + + src_files = {fs1.join(src_root, f): c for f, c in files.items()} + dest_files = {fs2.join(dest_root, f): c for f, c in files.items()} + fs1.mkdir(src_root) + fs1.pipe(src_files) + + callback = Callback() + spy_close = mocker.spy(Callback, "close") + child_callbacks = [Callback() for _ in files] + + branched = mocker.patch.object(callback, "branched", side_effect=child_callbacks) + copy(fs1, list(src_files), fs2, list(dest_files), 
callback=callback) + + assert fs2.cat(list(dest_files)) == dest_files + + n = len(files) + # assert main callback works + assert callback.value == n + assert callback.size is None # does not set sizes + # assert child callbacks are handled correctly + assert branched.call_count == n, f"expected branched to be called {n} times" + assert spy_close.call_count == n, f"expected close to be called {n} times" + + if isinstance(fs1, LocalFileSystem) and isinstance(fs2, LocalFileSystem): + # localfs copy avoids calling set_size or update if fs supports reflink + # or, file size is less than 1GB + return + assert {c.size for c in child_callbacks} == {len(c) for c in files.values()} + assert {c.value for c in child_callbacks} == {len(c) for c in files.values()} + + +@pytest.mark.parametrize("files", [{"foo": b"foo"}, {"foo": b"foo", "bar": b"barr"}]) +@pytest.mark.parametrize( + "link_type", + [ + pytest.param("reflink", marks=pytest.mark.xfail(reason="unsupported")), + "symlink", + "hardlink", + "copy", + ], +) +def test_transfer_between_local_fses(mocker, tmp_path, files, link_type): + fs = LocalFileSystem() + fs.mkdir(fs.join(tmp_path, "src")) + fs.mkdir(fs.join(tmp_path, "dest")) + + src_files = {fs.join(tmp_path, "src", f): c for f, c in files.items()} + dest_files = {fs.join(tmp_path, "dest", f): c for f, c in files.items()} + + fs.pipe(src_files) + + callback = Callback() + + branched = mocker.patch.object(callback, "branched") + transfer( + fs, list(src_files), fs, list(dest_files), callback=callback, links=[link_type] + ) + assert fs.cat([str(tmp_path / "dest" / file) for file in files]) == dest_files + assert callback.value == len(files) + assert callback.size is None # does not set sizes + assert branched.call_count == (len(files) if link_type == "copy" else 0) From c8359c1e9e432abac82c35ec58f5f2481d48f0ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 5 Feb 2024 14:51:38 
+0545 Subject: [PATCH 3/3] bump fsspec req --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0f5942c..f142984 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.8" dynamic = ["version"] dependencies = [ "funcy>=1.14; python_version < '3.12'", - "fsspec>=2023.2.0", + "fsspec>=2024.2.0", ] [project.urls]