From 9546b36d65c1fc0ea0a09392b7117743dc6c7834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Sun, 17 Oct 2021 14:34:03 +0545 Subject: [PATCH] api: traversable `files` API --- dvc/api.py | 66 +++++----- dvc/exceptions.py | 1 + dvc/repo_path.py | 306 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 338 insertions(+), 35 deletions(-) create mode 100644 dvc/repo_path.py diff --git a/dvc/api.py b/dvc/api.py index a56ed75d7f..68c4498741 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -1,10 +1,22 @@ import os from contextlib import _GeneratorContextManager as GCM +from contextlib import contextmanager +from typing import ContextManager, Iterator -from funcy import reraise - -from dvc.exceptions import OutputNotFoundError, PathMissingError +from dvc.exceptions import NoOutputInExternalRepoError, OutputNotFoundError from dvc.repo import Repo +from dvc.repo_path import RepoPath + + +def files(path=os.curdir, repo=None, rev=None) -> ContextManager[RepoPath]: + @contextmanager + def inner() -> Iterator["RepoPath"]: + with Repo.open( + repo, rev=rev, subrepos=True, uninitialized=True + ) as root_repo: + yield RepoPath(path, fs=root_repo.repo_fs) + + return inner() def get_url(path, repo=None, rev=None, remote=None): @@ -18,17 +30,11 @@ def get_url(path, repo=None, rev=None, remote=None): NOTE: This function does not check for the actual existence of the file or directory in the remote storage. """ - with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: - fs_path = _repo.fs.path.join(_repo.root_dir, path) - with reraise(FileNotFoundError, PathMissingError(path, repo)): - metadata = _repo.repo_fs.metadata(fs_path) - - if not metadata.is_dvc: - raise OutputNotFoundError(path, repo) - - cloud = metadata.repo.cloud - md5 = metadata.repo.dvcfs.info(fs_path)["md5"] - return cloud.get_url_for(remote, checksum=md5) + try: + with files(path, repo=repo, rev=rev) as path_obj: + return path_obj.url(remote=remote) + except NoOutputInExternalRepoError as exc: + raise OutputNotFoundError(exc.path, repo=repo) def open( # noqa, pylint: disable=redefined-builtin @@ -46,15 +52,15 @@ def open( # noqa, pylint: disable=redefined-builtin ) as fd: # ... Handle file object fd """ - args = (path,) - kwargs = { - "repo": repo, - "remote": remote, - "rev": rev, - "mode": mode, - "encoding": encoding, - } - return _OpenContextManager(_open, args, kwargs) + + def _open(): + with files(path, repo=repo, rev=rev) as path_obj: + with path_obj.open( # pylint: disable=not-context-manager + remote=remote, mode=mode, encoding=encoding + ) as fd: + yield fd + + return _OpenContextManager(_open, (), {}) class _OpenContextManager(GCM): @@ -70,24 +76,14 @@ def __getattr__(self, name): ) -def _open(path, repo=None, rev=None, remote=None, mode="r", encoding=None): - with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: - with _repo.open_by_relpath( - path, remote=remote, mode=mode, encoding=encoding - ) as fd: - yield fd - - def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): """ Returns the contents of a tracked file (by DVC or Git). For Git repos, HEAD is used unless a rev argument is supplied. The default remote is tried unless a remote argument is supplied. """ - with open( - path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding - ) as fd: - return fd.read() + with files(path, repo=repo, rev=rev) as path_obj: + return path_obj.read(remote=remote, mode=mode, encoding=encoding) def make_checkpoint(): diff --git a/dvc/exceptions.py b/dvc/exceptions.py index f65f307136..5dc05e8b2f 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -264,6 +264,7 @@ class NoOutputInExternalRepoError(DvcException): def __init__(self, path, external_repo_path, external_repo_url): from dvc.utils import relpath + self.path = path super().__init__( "Output '{}' not found in target repository '{}'".format( relpath(path, external_repo_path), external_repo_url diff --git a/dvc/repo_path.py b/dvc/repo_path.py new file mode 100644 index 0000000000..a45800ffd0 --- /dev/null +++ b/dvc/repo_path.py @@ -0,0 +1,306 @@ +import os +import pathlib +import sys +from itertools import chain +from typing import TYPE_CHECKING, overload + +from .exceptions import OutputNotFoundError, PathMissingError +from .types import OptStr + +if TYPE_CHECKING: + from io import ( + BufferedRandom, + BufferedReader, + BufferedWriter, + FileIO, + TextIOWrapper, + ) + from typing import IO, Any, BinaryIO, Generator, Union + + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + OpenTextModeReading, + ) + from typing_extensions import Literal + + from .fs.repo import RepoFileSystem + + +def _unsupported(method: str): + def wrapped(*args, **kwargs): + raise NotImplementedError(f"{method}() is unsupported.") + + return wrapped + + +class _PathNotSupportedMixin: + samefile = _unsupported("samefile") + absolute = _unsupported("absolute") + resolve = _unsupported("resolve") + stat = _unsupported("stat") + owner = _unsupported("owner") + group = _unsupported("group") + write_bytes = _unsupported("write_bytes") + write_text = _unsupported("write_text") + touch = _unsupported("touch") + mkdir = _unsupported("mkdir") + chmod = _unsupported("chmod") + unlink = _unsupported("unlink") + rmdir = _unsupported("rmdir") + lstat = _unsupported("lstat") + rename = _unsupported("rename") + replace = _unsupported("replace") + symlink_to = _unsupported("symlink_to") + is_mount = _unsupported("is_mount") + + +class PureRepoPath(pathlib.PurePath): + # pylint: disable=protected-access + _flavour = ( + pathlib._WindowsFlavour() # type: ignore[attr-defined] + if os.name == "nt" + else pathlib._PosixFlavour() # type: ignore[attr-defined] + ) + __slots__ = () + + +class RepoPath( # lgtm[py/conflicting-attributes] + # pylint:disable=abstract-method + _PathNotSupportedMixin, + PureRepoPath, + pathlib.Path, +): + _fs: "RepoFileSystem" + + scheme = "local" + __slots__ = ("_fs",) + + def __new__(cls, *args, **kwargs): + args_list = list(args) + repo_path = args_list.pop(0) + kw = {"init": False} if sys.version_info < (3, 10) else {} + self = super()._from_parts( # pylint: disable=unexpected-keyword-arg + args, **kw + ) + if isinstance(repo_path, RepoPath): + # pylint: disable=protected-access + kwargs["fs"] = kwargs.get("fs") or repo_path._fs + self._init(*args, **kwargs) # pylint: disable=no-member + return self + + def _from_parsed_parts(self, *args, **kwargs): + new = super()._from_parsed_parts(*args, **kwargs) + # pylint: disable=protected-access, assigning-non-slot + new._fs = self._fs + return new + + def _init( # pylint: disable=arguments-differ + self, *args, template=None, fs=None + ): + self._fs = fs # pylint: disable=disable=assigning-non-slot + if sys.version_info > (3, 10): + return + super()._init(template) # pylint: disable=no-member + + def url(self, remote: str = None) -> str: + fs = self._fs + fs_path = fs.path.join(fs.root_dir, str(self)) + try: + metadata = fs.metadata(fs_path) + except FileNotFoundError: + # pylint: disable=protected-access + raise PathMissingError(str(self), fs._main_repo) + + if not metadata.is_dvc: + raise OutputNotFoundError(str(self), metadata.repo) + + cloud = metadata.repo.cloud + md5 = metadata.repo.dvcfs.info(fs_path)["md5"] + return cloud.get_url_for(remote, checksum=md5) + + def exists(self) -> bool: + return self._fs.exists(self) + + def is_dir(self) -> bool: + return self._fs.isdir(self) + + def is_file(self) -> bool: + return self._fs.isfile(self) + + @overload + def read( + self, + mode: "OpenTextModeReading", + remote: str = None, + encoding: str = None, + errors: str = None, + ) -> str: + ... + + @overload + def read( + self, + mode: "OpenBinaryModeReading", + remote: str = None, + encoding: str = None, + errors: str = None, + ) -> bytes: + ... + + @overload + def read( + self, + mode: str = ..., + remote: str = None, + encoding: str = None, + errors: str = None, + ) -> "Union[str, bytes]": + ... + + def read( + self, + mode: str = "r", + remote: str = None, + encoding: str = None, + errors: str = None, + ): + with self.open( # pylint: disable=not-context-manager + remote=remote, mode=mode, encoding=encoding, errors=errors + ) as f: + return f.read() + + def read_bytes( # pylint: disable=arguments-differ + self, remote: str = None + ) -> bytes: + return self.read(remote=remote, mode="rb") + + def read_text( # pylint: disable=arguments-differ + self, encoding: str = None, errors: str = None, remote: str = None + ) -> str: + return self.read( + mode="r", encoding=encoding, errors=errors, remote=remote + ) + + def iterdir(self) -> "Generator[RepoPath, None, None]": + def onerror(exc): + raise exc + + repo_walk = self._fs.walk(self, onerror=onerror, dvcfiles=True) + for _, dirs, files in repo_walk: + yield from (self / entry for entry in chain(files, dirs)) + break + + # NOTE: keep in sync with Pathlib.open typehints + # pylint: disable=arguments-differ + @overload + def open( + self, + mode: "OpenTextMode" = ..., + buffering: int = ..., + encoding: OptStr = ..., + errors: OptStr = ..., + newline: OptStr = ..., + remote: OptStr = ..., + ) -> "TextIOWrapper": + ... + + # Unbuffered binary mode: returns a FileIO + @overload + def open( + self, + mode: "OpenBinaryMode", + buffering: "Literal[0]", + encoding: None = ..., + errors: None = ..., + newline: None = ..., + remote: OptStr = ..., + ) -> "FileIO": + ... + + # Buffering is on: return BufferedRandom, BufferedReader, or BufferedWriter + @overload + def open( + self, + mode: "OpenBinaryModeUpdating", + buffering: "Literal[-1, 1]" = ..., + encoding: None = ..., + errors: None = ..., + newline: None = ..., + remote: OptStr = ..., + ) -> "BufferedRandom": + ... + + @overload + def open( + self, + mode: "OpenBinaryModeWriting", + buffering: "Literal[-1, 1]" = ..., + encoding: None = ..., + errors: None = ..., + newline: None = ..., + remote: OptStr = ..., + ) -> "BufferedWriter": + ... + + @overload + def open( + self, + mode: "OpenBinaryModeReading", + buffering: "Literal[-1, 1]" = ..., + encoding: None = ..., + errors: None = ..., + newline: None = ..., + remote: OptStr = ..., + ) -> "BufferedReader": + ... + + # Buffering cannot be determined: fall back to BinaryIO + @overload + def open( + self, + mode: "OpenBinaryMode", + buffering: int, + encoding: None = ..., + errors: None = ..., + newline: None = ..., + remote: OptStr = ..., + ) -> "BinaryIO": + ... + + # Fallback if mode is not specified + @overload + def open( + self, + mode: str, + buffering: int = ..., + encoding: OptStr = ..., + errors: OptStr = ..., + newline: OptStr = ..., + remote: OptStr = ..., + ) -> "IO[Any]": + ... + + def open( + self, + mode="r", + buffering=-1, + encoding=None, + errors=None, + newline=None, + remote=None, + ): + assert buffering == -1 + assert errors is None + assert newline is None + assert mode in {"rt", "tr", "r", "rb", "br"} + + main_repo = self._fs._main_repo # pylint: disable=protected-access + return main_repo.open_by_relpath( + self, mode=mode, encoding=encoding, remote=remote + ) + + # pylint: enable=arguments-differ