Skip to content

Commit

Permalink
make shardingcodec pickleable (#2011)
Browse files Browse the repository at this point in the history
* use tmpdir for test

* type annotations

* refactor morton decode and remove destructuring in call to max

* parametrize sharding codec test by data shape

* refactor codec tests

* add test for pickling sharding codec, and make it pass

* Revert "use tmpdir for test"

This reverts commit 6ad2ca6.

* move fixtures into conftest.py

* Update tests/v3/test_codecs/test_endian.py
  • Loading branch information
d-v-b committed Jul 5, 2024
1 parent 22e3fc5 commit e84057a
Show file tree
Hide file tree
Showing 12 changed files with 1,189 additions and 1,076 deletions.
16 changes: 16 additions & 0 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ def __init__(
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

# todo: typedict return type
def __getstate__(self) -> dict[str, Any]:
return self.to_dict()

def __setstate__(self, state: dict[str, Any]) -> None:
config = state["configuration"]
object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))

# Use instance-local lru_cache to avoid memory leaks
object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")
Expand Down
35 changes: 18 additions & 17 deletions src/zarr/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,24 +1220,25 @@ def make_slice_selection(selection: Any) -> list[slice]:
return ls


def morton_order_iter(chunk_shape: ChunkCoords) -> Iterator[ChunkCoords]:
def decode_morton(z: int, chunk_shape: ChunkCoords) -> ChunkCoords:
# Inspired by compressed morton code as implemented in Neuroglancer
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
max_coords_bits = max(*bits)
input_bit = 0
input_value = z
out = [0 for _ in range(len(chunk_shape))]

for coord_bit in range(max_coords_bits):
for dim in range(len(chunk_shape)):
if coord_bit < bits[dim]:
bit = (input_value >> input_bit) & 1
out[dim] |= bit << coord_bit
input_bit += 1
return tuple(out)
def decode_morton(z: int, chunk_shape: ChunkCoords) -> ChunkCoords:
# Inspired by compressed morton code as implemented in Neuroglancer
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
max_coords_bits = max(bits)
input_bit = 0
input_value = z
out = [0] * len(chunk_shape)

for coord_bit in range(max_coords_bits):
for dim in range(len(chunk_shape)):
if coord_bit < bits[dim]:
bit = (input_value >> input_bit) & 1
out[dim] |= bit << coord_bit
input_bit += 1
return tuple(out)


def morton_order_iter(chunk_shape: ChunkCoords) -> Iterator[ChunkCoords]:
for i in range(product(chunk_shape)):
yield decode_morton(i, chunk_shape)

Expand Down
41 changes: 31 additions & 10 deletions tests/v3/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
from types import ModuleType
from typing import TYPE_CHECKING

from zarr.common import ZarrFormat
from _pytest.compat import LEGACY_PATH

from zarr.abc.store import Store
from zarr.common import ChunkCoords, MemoryOrder, ZarrFormat
from zarr.group import AsyncGroup

if TYPE_CHECKING:
from typing import Any, Literal
import pathlib
from dataclasses import dataclass, field

import numpy as np
import pytest

from zarr.store import LocalStore, MemoryStore, StorePath
Expand All @@ -26,40 +30,40 @@ def parse_store(
if store == "memory":
return MemoryStore(mode="w")
if store == "remote":
return RemoteStore(mode="w")
return RemoteStore(url=path, mode="w")
raise AssertionError


@pytest.fixture(params=[str, pathlib.Path])
def path_type(request):
def path_type(request: pytest.FixtureRequest) -> Any:
return request.param


# todo: harmonize this with local_store fixture
@pytest.fixture
def store_path(tmpdir):
def store_path(tmpdir: LEGACY_PATH) -> StorePath:
store = LocalStore(str(tmpdir), mode="w")
p = StorePath(store)
return p


@pytest.fixture(scope="function")
def local_store(tmpdir):
def local_store(tmpdir: LEGACY_PATH) -> LocalStore:
return LocalStore(str(tmpdir), mode="w")


@pytest.fixture(scope="function")
def remote_store():
return RemoteStore(mode="w")
def remote_store(url: str) -> RemoteStore:
return RemoteStore(url, mode="w")


@pytest.fixture(scope="function")
def memory_store():
def memory_store() -> MemoryStore:
return MemoryStore(mode="w")


@pytest.fixture(scope="function")
def store(request: str, tmpdir):
def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> Store:
param = request.param
return parse_store(param, str(tmpdir))

Expand All @@ -72,7 +76,7 @@ class AsyncGroupRequest:


@pytest.fixture(scope="function")
async def async_group(request: pytest.FixtureRequest, tmpdir) -> AsyncGroup:
async def async_group(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> AsyncGroup:
param: AsyncGroupRequest = request.param

store = parse_store(param.store, str(tmpdir))
Expand All @@ -90,3 +94,20 @@ def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]:
"""Fixture to parametrize over numpy-like libraries"""

yield pytest.importorskip(request.param)


@dataclass
class ArrayRequest:
shape: ChunkCoords
dtype: str
order: MemoryOrder


@pytest.fixture
def array_fixture(request: pytest.FixtureRequest) -> np.ndarray:
array_request: ArrayRequest = request.param
return (
np.arange(np.prod(array_request.shape))
.reshape(array_request.shape, order=array_request.order)
.astype(array_request.dtype)
)
Loading

0 comments on commit e84057a

Please sign in to comment.