Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ requires-python = ">=3.8"
dynamic = ["version"]
dependencies = [
"funcy>=1.14; python_version < '3.12'",
"fsspec>=2022.10.0",
"fsspec>=2024.2.0",
]

[project.urls]
Expand Down Expand Up @@ -104,7 +104,7 @@ module = [
]

[tool.codespell]
ignore-words-list = " "
ignore-words-list = "cachable,"

[tool.ruff]
ignore = [
Expand Down
6 changes: 4 additions & 2 deletions src/dvc_objects/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
cast,
)

from fsspec.callbacks import DEFAULT_CALLBACK

from .errors import ObjectDBPermissionError
from .fs.callbacks import DEFAULT_CALLBACK
from .obj import Object

if TYPE_CHECKING:
from fsspec import Callback

from .fs.base import AnyFSPath, FileSystem
from .fs.callbacks import Callback


logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion src/dvc_objects/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
TypeVar,
)

from .fs.callbacks import Callback
from fsspec import Callback

_T = TypeVar("_T")

Expand Down
28 changes: 16 additions & 12 deletions src/dvc_objects/fs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@

import fsspec
from fsspec.asyn import get_loop
from fsspec.callbacks import DEFAULT_CALLBACK

from dvc_objects.compat import cached_property
from dvc_objects.executors import ThreadPoolExecutor, batch_coros

from .callbacks import (
DEFAULT_CALLBACK,
wrap_and_branch_callback,
wrap_file,
)
from .callbacks import wrap_file
from .errors import RemoteMissingDepsError

if TYPE_CHECKING:
Expand Down Expand Up @@ -706,9 +703,14 @@ def put(

callback.set_size(len(from_infos))
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)

def put_file(from_path, to_path):
with callback.branched(from_path, to_path) as child:
return self.put_file(from_path, to_path, callback=child)

with executor:
put_file = wrap_and_branch_callback(callback, self.put_file)
list(executor.imap_unordered(put_file, from_infos, to_infos))
it = executor.imap_unordered(put_file, from_infos, to_infos)
list(callback.wrap(it))

def get(
self,
Expand All @@ -724,9 +726,8 @@ def get(

def get_file(rpath, lpath, **kwargs):
localfs.makedirs(localfs.parent(lpath), exist_ok=True)
self.fs.get_file(rpath, lpath, **kwargs)

get_file = wrap_and_branch_callback(callback, get_file)
with callback.branched(rpath, lpath) as child:
self.fs.get_file(rpath, lpath, callback=child, **kwargs)

if isinstance(from_info, list) and isinstance(to_info, list):
from_infos: List[AnyFSPath] = from_info
Expand All @@ -737,7 +738,9 @@ def get_file(rpath, lpath, **kwargs):

if not self.isdir(from_info):
callback.set_size(1)
return get_file(from_info, to_info)
get_file(from_info, to_info)
callback.relative_update()
return

from_infos = list(self.find(from_info))
if not from_infos:
Expand All @@ -760,7 +763,8 @@ def get_file(rpath, lpath, **kwargs):
callback.set_size(len(from_infos))
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
list(executor.imap_unordered(get_file, from_infos, to_infos))
it = executor.imap_unordered(get_file, from_infos, to_infos)
list(callback.wrap(it))

def ukey(self, path: AnyFSPath) -> str:
return self.fs.ukey(path)
Expand Down
159 changes: 19 additions & 140 deletions src/dvc_objects/fs/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import asyncio
from functools import wraps
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Optional, TypeVar, cast
from typing import TYPE_CHECKING, BinaryIO, Optional, Type, cast

import fsspec
from fsspec.callbacks import DEFAULT_CALLBACK, Callback, NoOpCallback

if TYPE_CHECKING:
from typing import Union

from dvc_objects._tqdm import Tqdm
from tqdm import tqdm

F = TypeVar("F", bound=Callable)

__all__ = ["Callback", "NoOpCallback", "TqdmCallback", "DEFAULT_CALLBACK"]


class CallbackStream:
def __init__(self, stream, callback: fsspec.Callback):
def __init__(self, stream, callback: Callback):
self.stream = stream

@wraps(stream.read)
Expand All @@ -28,151 +29,29 @@ def __getattr__(self, attr):
return getattr(self.stream, attr)


class ScopedCallback(fsspec.Callback):
def __enter__(self):
return self

def __exit__(self, *exc_args):
self.close()

def close(self):
"""Handle here on exit."""

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
child = kwargs["callback"] = child or DEFAULT_CALLBACK
return child


class Callback(ScopedCallback):
def absolute_update(self, value: int) -> None:
value = value if value is not None else self.value
return super().absolute_update(value)

@classmethod
def as_callback(
cls, maybe_callback: Optional[fsspec.Callback] = None
) -> "Callback":
if maybe_callback is None:
return DEFAULT_CALLBACK
if isinstance(maybe_callback, Callback):
return maybe_callback
return FsspecCallbackWrapper(maybe_callback)


class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback):
pass


class TqdmCallback(Callback):
class TqdmCallback(fsspec.callbacks.TqdmCallback):
def __init__(
self,
size: Optional[int] = None,
value: int = 0,
progress_bar: Optional["Tqdm"] = None,
progress_bar: Optional["tqdm"] = None,
tqdm_cls: Optional[Type["tqdm"]] = None,
**tqdm_kwargs,
):
from dvc_objects._tqdm import Tqdm

tqdm_kwargs.pop("total", None)
self._tqdm_kwargs = tqdm_kwargs
self._tqdm_cls = Tqdm
self.tqdm = progress_bar
super().__init__(size=size, value=value)

def close(self):
if self.tqdm is not None:
self.tqdm.close()
self.tqdm = None

def call(self, hook_name=None, **kwargs):
if self.tqdm is None:
self.tqdm = self._tqdm_cls(**self._tqdm_kwargs, total=self.size or -1)
self.tqdm.update_to(self.value, total=self.size)

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional[Callback] = None,
):
tqdm_cls = tqdm_cls or Tqdm
super().__init__(
tqdm_kwargs=tqdm_kwargs, tqdm_cls=tqdm_cls, size=size, value=value
)
if progress_bar is None:
self.tqdm = progress_bar

def branched(self, path_1: "Union[str, BinaryIO]", path_2: str, **kwargs):
desc = path_1 if isinstance(path_1, str) else path_2
child = child or TqdmCallback(bytes=True, desc=desc)
return super().branch(path_1, path_2, kwargs, child=child)


class FsspecCallbackWrapper(Callback):
def __init__(self, callback: fsspec.Callback):
object.__setattr__(self, "_callback", callback)

def __getattr__(self, name: str):
return getattr(self._callback, name)
return TqdmCallback(bytes=True, desc=desc)

def __setattr__(self, name: str, value: Any):
setattr(self._callback, name, value)

def absolute_update(self, value: int) -> None:
value = value if value is not None else self.value
return self._callback.absolute_update(value)

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
if not child:
self._callback.branch(path_1, path_2, kwargs)
child = self.as_callback(kwargs.get("callback"))
return super().branch(path_1, path_2, kwargs, child=child)


def wrap_fn(callback: fsspec.Callback, fn: F) -> F:
@wraps(fn)
async def async_wrapper(*args, **kwargs):
res = await fn(*args, **kwargs)
callback.relative_update()
return res

@wraps(fn)
def sync_wrapper(*args, **kwargs):
res = fn(*args, **kwargs)
callback.relative_update()
return res

return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value]


def branch_callback(callback: fsspec.Callback, fn: F) -> F:
callback = Callback.as_callback(callback)

@wraps(fn)
async def async_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
with callback.branch(path1, path2, kwargs):
return await fn(path1, path2, **kwargs)

@wraps(fn)
def sync_wrapper(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
with callback.branch(path1, path2, kwargs):
return fn(path1, path2, **kwargs)

return async_wrapper if asyncio.iscoroutinefunction(fn) else sync_wrapper # type: ignore[return-value]


def wrap_and_branch_callback(callback: fsspec.Callback, fn: F) -> F:
branch_wrapper = branch_callback(callback, fn)
return wrap_fn(callback, branch_wrapper)


def wrap_file(file, callback: fsspec.Callback) -> BinaryIO:
def wrap_file(file, callback: Callback) -> BinaryIO:
return cast(BinaryIO, CallbackStream(file, callback))


DEFAULT_CALLBACK = NoOpCallback()
Loading