Skip to content

Commit

Permalink
store takes prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed May 24, 2024
1 parent cb88b98 commit 26069f4
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 46 deletions.
17 changes: 12 additions & 5 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
from collections.abc import AsyncGenerator
from typing import Protocol, runtime_checkable

from zarr.buffer import Buffer
from zarr.buffer import Buffer, Prototype
from zarr.common import BytesLike


class Store(ABC):
@abstractmethod
async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self,
key: str,
prototype: Prototype,
byte_range: tuple[int, int | None] | None = None,
) -> Buffer | None:
"""Retrieve the value associated with a given key.
Expand All @@ -26,7 +29,7 @@ async def get(

@abstractmethod
async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, prototype: Prototype, key_ranges: list[tuple[str, tuple[int, int]]]
) -> list[Buffer | None]:
"""Retrieve possibly partial values from given key_ranges.
Expand Down Expand Up @@ -150,12 +153,16 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:

@runtime_checkable
class ByteGetter(Protocol):
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
async def get(
self, prototype: Prototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None: ...


@runtime_checkable
class ByteSetter(Protocol):
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
async def get(
self, prototype: Prototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None: ...

async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ...

Expand Down
11 changes: 11 additions & 0 deletions src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,19 @@ def as_numpy_array_wrapper(func: Callable[[npt.NDArray[Any]], bytes], buf: Buffe


class Prototype(NamedTuple):
    """Prototype of the Buffer and NDBuffer class.

    Used to parametrize which buffer implementations Zarr instantiates
    (e.g. to swap in device-backed buffers) without changing call sites.

    Attributes
    ----------
    buffer
        The Buffer class to use when Zarr needs to create new Buffer.
    nd_buffer
        The NDBuffer class to use when Zarr needs to create new NDBuffer.
    """

    # Class (not instance) used to construct new Buffer objects.
    buffer: type[Buffer]
    # Class (not instance) used to construct new NDBuffer objects.
    nd_buffer: type[NDBuffer]


# The default prototype used throughout the Zarr codebase.
default_prototype = Prototype(buffer=Buffer, nd_buffer=NDBuffer)
20 changes: 14 additions & 6 deletions src/zarr/codecs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Codec,
CodecPipeline,
)
from zarr.buffer import Buffer, NDBuffer
from zarr.buffer import Buffer, NDBuffer, Prototype
from zarr.codecs.registry import get_codec_class
from zarr.common import JSON, concurrent_map, parse_named_configuration
from zarr.config import config
Expand Down Expand Up @@ -311,8 +311,11 @@ async def read_batch(
out[out_selection] = chunk_spec.fill_value
else:
chunk_bytes_batch = await concurrent_map(
[(byte_getter,) for byte_getter, _, _, _ in batch_info],
lambda byte_getter: byte_getter.get(),
[
(byte_getter, array_spec.prototype)
for byte_getter, array_spec, _, _ in batch_info
],
lambda byte_getter, prototype: byte_getter.get(prototype),
config.get("async.concurrency"),
)
chunk_array_batch = await self.decode_batch(
Expand Down Expand Up @@ -347,15 +350,20 @@ async def write_batch(

else:
# Read existing bytes if not total slice
async def _read_key(byte_setter: ByteSetter | None) -> Buffer | None:
async def _read_key(
    byte_setter: ByteSetter | None, prototype: Prototype
) -> Buffer | None:
    # Read the existing chunk bytes for a read-modify-write cycle.
    # A None byte_setter means the caller is writing a total slice
    # (see the `is_total_slice` filter at the call site), so the old
    # bytes are not needed and no read is issued.
    if byte_setter is None:
        return None
    return await byte_setter.get(prototype=prototype)

chunk_bytes_batch: Iterable[Buffer | None]
chunk_bytes_batch = await concurrent_map(
[
(None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,)
(
None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
chunk_spec.prototype,
)
for byte_setter, chunk_spec, chunk_selection, _ in batch_info
],
_read_key,
Expand Down
35 changes: 26 additions & 9 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
CodecPipeline,
)
from zarr.array_spec import ArraySpec
from zarr.buffer import Buffer, NDBuffer, default_prototype
from zarr.buffer import Buffer, NDBuffer, Prototype, default_prototype
from zarr.chunk_grids import RegularChunkGrid
from zarr.codecs.bytes import BytesCodec
from zarr.codecs.crc32c_ import Crc32cCodec
Expand Down Expand Up @@ -67,8 +67,11 @@ class _ShardingByteGetter(ByteGetter):
shard_dict: ShardMapping
chunk_coords: ChunkCoords

async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None:
async def get(
    self, prototype: Prototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None:
    """Return the buffer stored for this getter's chunk coordinates.

    Looks up ``self.chunk_coords`` in the in-memory shard mapping and
    returns the stored Buffer, or ``None`` if the chunk is absent.
    Partial reads (``byte_range``) and non-default prototypes are not
    supported inside a shard, hence the asserts below.
    """
    assert byte_range is None, "byte_range is not supported within shards"
    assert prototype is default_prototype, "prototype is not supported within shards currently"
    return self.shard_dict.get(self.chunk_coords)


Expand Down Expand Up @@ -450,7 +453,11 @@ async def _decode_partial_single(
shard_dict: ShardMapping = {}
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
# read entire shard
shard_dict_maybe = await self._load_full_shard_maybe(byte_getter, chunks_per_shard)
shard_dict_maybe = await self._load_full_shard_maybe(
byte_getter=byte_getter,
prototype=chunk_spec.prototype,
chunks_per_shard=chunks_per_shard,
)
if shard_dict_maybe is None:
return None
shard_dict = shard_dict_maybe
Expand All @@ -463,7 +470,9 @@ async def _decode_partial_single(
for chunk_coords in all_chunk_coords:
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
chunk_bytes = await byte_getter.get(chunk_byte_slice)
chunk_bytes = await byte_getter.get(
prototype=chunk_spec.prototype, byte_range=chunk_byte_slice
)
if chunk_bytes:
shard_dict[chunk_coords] = chunk_bytes

Expand Down Expand Up @@ -530,7 +539,11 @@ async def _encode_partial_single(
chunk_spec = self._get_chunk_spec(shard_spec)

shard_dict = _MergingShardBuilder(
await self._load_full_shard_maybe(byte_setter, chunks_per_shard)
await self._load_full_shard_maybe(
byte_getter=byte_setter,
prototype=chunk_spec.prototype,
chunks_per_shard=chunks_per_shard,
)
or _ShardReader.create_empty(chunks_per_shard),
_ShardBuilder.create_empty(chunks_per_shard),
)
Expand Down Expand Up @@ -641,9 +654,13 @@ async def _load_shard_index_maybe(
) -> _ShardIndex | None:
shard_index_size = self._shard_index_size(chunks_per_shard)
if self.index_location == ShardingCodecIndexLocation.start:
index_bytes = await byte_getter.get((0, shard_index_size))
index_bytes = await byte_getter.get(
prototype=default_prototype, byte_range=(0, shard_index_size)
)
else:
index_bytes = await byte_getter.get((-shard_index_size, None))
index_bytes = await byte_getter.get(
prototype=default_prototype, byte_range=(-shard_index_size, None)
)
if index_bytes is not None:
return await self._decode_shard_index(index_bytes, chunks_per_shard)
return None
Expand All @@ -656,9 +673,9 @@ async def _load_shard_index(
) or _ShardIndex.create_empty(chunks_per_shard)

async def _load_full_shard_maybe(
self, byte_getter: ByteGetter, chunks_per_shard: ChunkCoords
self, byte_getter: ByteGetter, prototype: Prototype, chunks_per_shard: ChunkCoords
) -> _ShardReader | None:
shard_bytes = await byte_getter.get()
shard_bytes = await byte_getter.get(prototype=prototype)

return (
await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)
Expand Down
10 changes: 7 additions & 3 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, Prototype, default_prototype
from zarr.store.local import LocalStore


Expand All @@ -25,8 +25,12 @@ def __init__(self, store: Store, path: str | None = None):
self.store = store
self.path = path or ""

async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None:
return await self.store.get(self.path, byte_range)
async def get(
    self,
    prototype: Prototype = default_prototype,
    byte_range: tuple[int, int | None] | None = None,
) -> Buffer | None:
    """Read the value stored at this path from the underlying store.

    Delegates to ``self.store.get`` with this object's ``path``.
    ``prototype`` selects the Buffer class used for the result and
    defaults to the global ``default_prototype``; ``byte_range``
    optionally restricts the read to a (start, end) region.
    Returns ``None`` when the store has no entry for the path.
    """
    return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)

async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
if byte_range is not None:
Expand Down
18 changes: 9 additions & 9 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from pathlib import Path

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, Prototype
from zarr.common import concurrent_map, to_thread


def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer:
def _get(path: Path, prototype: Prototype, byte_range: tuple[int, int | None] | None) -> Buffer:
"""
Fetch a contiguous region of bytes from a file.
Expand All @@ -32,7 +32,7 @@ def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer:

end = (start + byte_range[1]) if byte_range[1] is not None else None
else:
return Buffer.from_bytes(path.read_bytes())
return prototype.buffer.from_bytes(path.read_bytes())
with path.open("rb") as f:
size = f.seek(0, io.SEEK_END)
if start is not None:
Expand All @@ -43,8 +43,8 @@ def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer:
if end is not None:
if end < 0:
end = size + end
return Buffer.from_bytes(f.read(end - f.tell()))
return Buffer.from_bytes(f.read())
return prototype.buffer.from_bytes(f.read(end - f.tell()))
return prototype.buffer.from_bytes(f.read())


def _put(
Expand Down Expand Up @@ -90,18 +90,18 @@ def __eq__(self, other: object) -> bool:
return isinstance(other, type(self)) and self.root == other.root

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, prototype: Prototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
path = self.root / key

try:
return await to_thread(_get, path, byte_range)
return await to_thread(_get, path, prototype, byte_range)
except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
return None

async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, prototype: Prototype, key_ranges: list[tuple[str, tuple[int, int]]]
) -> list[Buffer | None]:
"""
Read byte ranges from multiple keys.
Expand All @@ -117,7 +117,7 @@ async def get_partial_values(
for key, byte_range in key_ranges:
assert isinstance(key, str)
path = self.root / key
args.append((_get, path, byte_range))
args.append((_get, path, prototype, byte_range))
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: Buffer) -> None:
Expand Down
12 changes: 8 additions & 4 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import AsyncGenerator, MutableMapping

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, Prototype
from zarr.common import concurrent_map


Expand All @@ -26,7 +26,7 @@ def __repr__(self) -> str:
return f"MemoryStore({str(self)!r})"

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, prototype: Prototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
try:
Expand All @@ -38,9 +38,13 @@ async def get(
return None

async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, prototype: Prototype, key_ranges: list[tuple[str, tuple[int, int]]]
) -> list[Buffer | None]:
vals = await concurrent_map(key_ranges, self.get, limit=None)
# All the key-range arguments go with the same prototype
async def _get(key: str, byte_range: tuple[int, int | None]) -> Buffer | None:
return await self.get(key, prototype=prototype, byte_range=byte_range)

vals = await concurrent_map(key_ranges, _get, limit=None)
return vals

async def exists(self, key: str) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, Prototype
from zarr.store.core import _dereference_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,7 +49,7 @@ def _make_fs(self) -> tuple[AsyncFileSystem, str]:
return fs, root

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, prototype: Prototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
fs, root = self._make_fs()
Expand Down
8 changes: 4 additions & 4 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.buffer import Buffer, default_prototype
from zarr.testing.utils import assert_bytes_equal


Expand All @@ -28,19 +28,19 @@ def test_store_capabilities(self, store: Store) -> None:
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
await store.set(key, Buffer.from_bytes(data))
assert_bytes_equal(await store.get(key), data)
assert_bytes_equal(await store.get(key, prototype=default_prototype), data)

@pytest.mark.parametrize("key", ["foo/c/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None:
# put all of the data
await store.set(key, Buffer.from_bytes(data))
# read back just part of it
vals = await store.get_partial_values([(key, (0, 2))])
vals = await store.get_partial_values(default_prototype, [(key, (0, 2))])
assert_bytes_equal(vals[0], data[0:2])

# read back multiple parts of it at once
vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
vals = await store.get_partial_values(default_prototype, [(key, (0, 2)), (key, (2, 4))])
assert_bytes_equal(vals[0], data[0:2])
assert_bytes_equal(vals[1], data[2:4])

Expand Down
11 changes: 11 additions & 0 deletions tests/v3/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
assert isinstance(value, MyBuffer)
await super().set(key, value, byte_range)

async def get(
    self,
    key: str,
    prototype: Prototype,
    byte_range: tuple[int, int | None] | None = None,
) -> Buffer | None:
    """Store-get override that verifies the prototype plumbing.

    Asserts that every non-metadata read (keys without "json") is
    requested with the custom ``MyBuffer`` class, then delegates to the
    parent store.
    """
    # Check that non-metadata is using MyBuffer
    if "json" not in key:
        assert prototype.buffer is MyBuffer
    # BUG FIX: the original called ``super().get(key, byte_range)``,
    # which passed byte_range positionally into the parent's
    # ``prototype`` parameter (signature: get(key, prototype,
    # byte_range=None)) and dropped the prototype entirely.
    # Forward both arguments by keyword instead.
    return await super().get(key, prototype=prototype, byte_range=byte_range)


def test_nd_array_like(xp):
ary = xp.arange(10)
Expand Down
Loading

0 comments on commit 26069f4

Please sign in to comment.