Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v3] Feature: Store open mode #1911

Merged
merged 3 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,31 @@
from typing import Protocol, runtime_checkable

from zarr.buffer import Buffer
from zarr.common import BytesLike
from zarr.common import BytesLike, OpenMode


class Store(ABC):
_mode: OpenMode

def __init__(self, mode: OpenMode = "r"):
if mode not in ("r", "r+", "w", "w-", "a"):
raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'")
self._mode = mode

@property
def mode(self) -> OpenMode:
"""Access mode of the store."""
return self._mode

@property
def writeable(self) -> bool:
"""Is the store writeable?"""
return self.mode in ("a", "w", "w-")

def _check_writable(self) -> None:
if not self.writeable:
raise ValueError("store mode does not support writing")

@abstractmethod
async def get(
self, key: str, byte_range: tuple[int | None, int | None] | None = None
Expand Down Expand Up @@ -147,6 +168,10 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
...

def close(self) -> None: # noqa: B027
"""Close the store."""
pass


@runtime_checkable
class ByteGetter(Protocol):
Expand Down
1 change: 1 addition & 0 deletions src/zarr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Selection = slice | SliceSelection
ZarrFormat = Literal[2, 3]
JSON = None | str | int | float | Enum | dict[str, "JSON"] | list["JSON"] | tuple["JSON", ...]
OpenMode = Literal["r", "r+", "a", "w", "w-"]


def product(tup: ChunkCoords) -> int:
Expand Down
10 changes: 8 additions & 2 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode
from zarr.store.local import LocalStore


Expand Down Expand Up @@ -60,13 +61,18 @@ def __eq__(self, other: Any) -> bool:
StoreLike = Store | StorePath | Path | str


def make_store_path(store_like: StoreLike) -> StorePath:
def make_store_path(store_like: StoreLike, *, mode: OpenMode | None = None) -> StorePath:
if isinstance(store_like, StorePath):
if mode is not None:
assert mode == store_like.store.mode
return store_like
elif isinstance(store_like, Store):
if mode is not None:
assert mode == store_like.mode
return StorePath(store_like)
elif isinstance(store_like, str):
return StorePath(LocalStore(Path(store_like)))
assert mode is not None
return StorePath(LocalStore(Path(store_like), mode=mode))
raise TypeError


Expand Down
8 changes: 6 additions & 2 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map, to_thread
from zarr.common import OpenMode, concurrent_map, to_thread


def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer:
Expand Down Expand Up @@ -69,7 +69,8 @@ class LocalStore(Store):

root: Path

def __init__(self, root: Path | str):
def __init__(self, root: Path | str, *, mode: OpenMode = "r"):
super().__init__(mode=mode)
if isinstance(root, str):
root = Path(root)
assert isinstance(root, Path)
Expand Down Expand Up @@ -117,6 +118,7 @@ async def get_partial_values(
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: Buffer) -> None:
self._check_writable()
assert isinstance(key, str)
if isinstance(value, bytes | bytearray):
# TODO: to support the v2 tests, we convert bytes to Buffer here
Expand All @@ -127,6 +129,7 @@ async def set(self, key: str, value: Buffer) -> None:
await to_thread(_put, path, value)

async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
self._check_writable()
args = []
for key, start, value in key_start_values:
assert isinstance(key, str)
Expand All @@ -138,6 +141,7 @@ async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]
await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def delete(self, key: str) -> None:
self._check_writable()
path = self.root / key
if path.is_dir(): # TODO: support deleting directories? shutil.rmtree?
shutil.rmtree(path)
Expand Down
9 changes: 7 additions & 2 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map
from zarr.common import OpenMode, concurrent_map
from zarr.store.core import _normalize_interval_index


Expand All @@ -17,7 +17,10 @@ class MemoryStore(Store):

_store_dict: MutableMapping[str, Buffer]

def __init__(self, store_dict: MutableMapping[str, Buffer] | None = None):
def __init__(
self, store_dict: MutableMapping[str, Buffer] | None = None, *, mode: OpenMode = "r"
):
super().__init__(mode=mode)
self._store_dict = store_dict or {}

def __str__(self) -> str:
Expand Down Expand Up @@ -47,6 +50,7 @@ async def exists(self, key: str) -> bool:
return key in self._store_dict

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
d-v-b marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(key, str)
if isinstance(value, bytes | bytearray):
# TODO: to support the v2 tests, we convert bytes to Buffer here
Expand All @@ -62,6 +66,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
self._store_dict[key] = value

async def delete(self, key: str) -> None:
self._check_writable()
try:
del self._store_dict[key]
except KeyError:
Expand Down
10 changes: 9 additions & 1 deletion src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode
from zarr.store.core import _dereference_path

if TYPE_CHECKING:
Expand All @@ -18,17 +19,22 @@ class RemoteStore(Store):

root: UPath

def __init__(self, url: UPath | str, **storage_options: dict[str, Any]):
def __init__(
self, url: UPath | str, *, mode: OpenMode = "r", **storage_options: dict[str, Any]
):
import fsspec
from upath import UPath

super().__init__(mode=mode)

if isinstance(url, str):
self.root = UPath(url, **storage_options)
else:
assert (
len(storage_options) == 0
), "If constructed with a UPath object, no additional storage_options are allowed."
self.root = url.rstrip("/")

# test instantiate file system
fs, _ = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs)
assert fs.__class__.async_impl, "FileSystem needs to support async operations."
Expand Down Expand Up @@ -67,6 +73,7 @@ async def get(
return value

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
fs, root = self._make_fs()
path = _dereference_path(root, key)
Expand All @@ -80,6 +87,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
await fs._pipe_file(path, value)

async def delete(self, key: str) -> None:
self._check_writable()
fs, root = self._make_fs()
path = _dereference_path(root, key)
if await fs._exists(path):
Expand Down
37 changes: 34 additions & 3 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, TypeVar
from typing import Any, Generic, TypeVar

import pytest

Expand Down Expand Up @@ -31,13 +31,43 @@ def get(self, store: S, key: str) -> Buffer:
raise NotImplementedError

@pytest.fixture(scope="function")
def store(self) -> Store:
return self.store_cls()
def store_kwargs(self) -> dict[str, Any]:
return {"mode": "w"}

@pytest.fixture(scope="function")
def store(self, store_kwargs: dict[str, Any]) -> Store:
return self.store_cls(**store_kwargs)

def test_store_type(self, store: S) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
assert store.mode == "w", store.mode
assert store.writeable

with pytest.raises(AttributeError):
store.mode = "w" # type: ignore

# read-only
kwargs = {**store_kwargs, "mode": "r"}
read_store = self.store_cls(**kwargs)
assert read_store.mode == "r", read_store.mode
assert not read_store.writeable

async def test_not_writable_store_raises(self, store_kwargs: dict[str, Any]) -> None:
kwargs = {**store_kwargs, "mode": "r"}
store = self.store_cls(**kwargs)
assert not store.writeable

# set
with pytest.raises(ValueError):
await store.set("foo", Buffer.from_bytes(b"bar"))

# delete
with pytest.raises(ValueError):
await store.delete("foo")

def test_store_repr(self, store: S) -> None:
raise NotImplementedError

Expand Down Expand Up @@ -72,6 +102,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
"""
Ensure that data can be written to the store using the store.set method.
"""
assert store.writeable
data_buf = Buffer.from_bytes(data)
await store.set(key, data_buf)
observed = self.get(store, key)
Expand Down
14 changes: 7 additions & 7 deletions tests/v3/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def parse_store(
store: Literal["local", "memory", "remote"], path: str
) -> LocalStore | MemoryStore | RemoteStore:
if store == "local":
return LocalStore(path)
return LocalStore(path, mode="w")
if store == "memory":
return MemoryStore()
return MemoryStore(mode="w")
if store == "remote":
return RemoteStore()
return RemoteStore(mode="w")
raise AssertionError


Expand All @@ -38,24 +38,24 @@ def path_type(request):
# todo: harmonize this with local_store fixture
@pytest.fixture
def store_path(tmpdir):
store = LocalStore(str(tmpdir))
store = LocalStore(str(tmpdir), mode="w")
p = StorePath(store)
return p


@pytest.fixture(scope="function")
def local_store(tmpdir):
return LocalStore(str(tmpdir))
return LocalStore(str(tmpdir), mode="w")


@pytest.fixture(scope="function")
def remote_store():
return RemoteStore()
return RemoteStore(mode="w")


@pytest.fixture(scope="function")
def memory_store():
return MemoryStore()
return MemoryStore(mode="w")


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def set(self, value: np.ndarray):

@pytest.fixture
def store() -> Iterator[Store]:
yield StorePath(MemoryStore())
yield StorePath(MemoryStore(mode="w"))


@pytest.fixture
Expand Down
29 changes: 16 additions & 13 deletions tests/v3/test_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import MutableMapping
from typing import Any

import pytest

Expand All @@ -10,7 +10,6 @@
from zarr.testing.store import StoreTests


@pytest.mark.parametrize("store_dict", (None, {}))
class TestMemoryStore(StoreTests[MemoryStore]):
store_cls = MemoryStore

Expand All @@ -20,21 +19,25 @@ def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(scope="function", params=[None, {}])
def store_kwargs(self, request) -> dict[str, Any]:
return {"store_dict": request.param, "mode": "w"}

@pytest.fixture(scope="function")
def store(self, store_dict: MutableMapping[str, Buffer] | None):
return MemoryStore(store_dict=store_dict)
def store(self, store_kwargs: dict[str, Any]) -> MemoryStore:
return self.store_cls(**store_kwargs)

def test_store_repr(self, store: MemoryStore) -> None:
assert str(store) == f"memory://{id(store._store_dict)}"

def test_store_supports_writes(self, store: MemoryStore) -> None:
assert True
assert store.supports_writes

def test_store_supports_listing(self, store: MemoryStore) -> None:
assert True
assert store.supports_listing

def test_store_supports_partial_writes(self, store: MemoryStore) -> None:
assert True
assert store.supports_partial_writes

def test_list_prefix(self, store: MemoryStore) -> None:
assert True
Expand All @@ -52,21 +55,21 @@ def set(self, store: LocalStore, key: str, value: Buffer) -> None:
parent.mkdir(parents=True)
(store.root / key).write_bytes(value.to_bytes())

@pytest.fixture(scope="function")
def store(self, tmpdir) -> LocalStore:
return self.store_cls(str(tmpdir))
@pytest.fixture
def store_kwargs(self, tmpdir) -> dict[str, str]:
return {"root": str(tmpdir), "mode": "w"}

def test_store_repr(self, store: LocalStore) -> None:
assert str(store) == f"file://{store.root!s}"

def test_store_supports_writes(self, store: LocalStore) -> None:
assert True
assert store.supports_writes

def test_store_supports_partial_writes(self, store: LocalStore) -> None:
assert True
assert store.supports_partial_writes

def test_store_supports_listing(self, store: LocalStore) -> None:
assert True
assert store.supports_listing

def test_list_prefix(self, store: LocalStore) -> None:
assert True
Loading
Loading