From d57b38329061800882f02ec0b240933caf299931 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 15 Oct 2025 12:25:28 +0200 Subject: [PATCH 1/9] array registry infrastructure --- src/zarr/core/buffer/cpu.py | 6 +++++ src/zarr/core/buffer/gpu.py | 8 ++++++ src/zarr/registry.py | 51 ++++++++++++++++++++++++++++++++++++- 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py index 415b9d928c..077856c052 100644 --- a/src/zarr/core/buffer/cpu.py +++ b/src/zarr/core/buffer/cpu.py @@ -11,6 +11,7 @@ from zarr.core.buffer import core from zarr.registry import ( + register_array_type, register_buffer, register_ndbuffer, ) @@ -230,6 +231,11 @@ def numpy_buffer_prototype() -> core.BufferPrototype: return core.BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) +@register_array_type(np.ndarray) +def create_numpy_buffer(value: npt.ArrayLike) -> core.BufferPrototype: + return buffer_prototype + + register_buffer(Buffer, qualname="zarr.buffer.cpu.Buffer") register_ndbuffer(NDBuffer, qualname="zarr.buffer.cpu.NDBuffer") diff --git a/src/zarr/core/buffer/gpu.py b/src/zarr/core/buffer/gpu.py index bfe977c50f..c51e29d63f 100644 --- a/src/zarr/core/buffer/gpu.py +++ b/src/zarr/core/buffer/gpu.py @@ -15,6 +15,7 @@ from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike from zarr.errors import ZarrUserWarning from zarr.registry import ( + register_array_type, register_buffer, register_ndbuffer, ) @@ -228,6 +229,13 @@ def __setitem__(self, key: Any, value: Any) -> None: buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) +if cp is not None: + + @register_array_type(cp.ndarray) + def _(value: cp.ndarray) -> BufferPrototype: + return buffer_prototype + + register_buffer(Buffer, qualname="zarr.buffer.gpu.Buffer") register_ndbuffer(NDBuffer, qualname="zarr.buffer.gpu.NDBuffer") diff --git a/src/zarr/registry.py b/src/zarr/registry.py index a8dd2a1c6c..bb898ef54d 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -2,6 +2,7 @@ import warnings from collections import defaultdict +from dataclasses import dataclass, field from importlib.metadata import entry_points as get_entry_points from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -10,8 +11,11 @@ from zarr.errors import ZarrUserWarning if TYPE_CHECKING: + from collections.abc import Callable from importlib.metadata import EntryPoint + import numpy.typing as npt + from zarr.abc.codec import ( ArrayArrayCodec, ArrayBytesCodec, @@ -21,7 +25,7 @@ CodecPipeline, ) from zarr.abc.numcodec import Numcodec - from zarr.core.buffer import Buffer, NDBuffer + from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer from zarr.core.chunk_key_encodings import ChunkKeyEncoding from zarr.core.common import JSON @@ -138,6 +142,51 @@ def fully_qualified_name(cls: type) -> str: return module + "." + cls.__qualname__ +@dataclass +class ArrayTypeRegistry: + registry: dict[type[npt.ArrayLike], Callable[[npt.ArrayLike], BufferPrototype]] = field( + default_factory=dict + ) + + def register( + self, + array_type: type[npt.ArrayLike], + converter: Callable[[npt.ArrayLike], BufferPrototype], + ) -> None: + self.registry[array_type] = converter + + def lookup(self, value: npt.ArrayLike) -> BufferPrototype: + from zarr.core.buffer.core import default_buffer_prototype + + converter = self.registry.get(type(value), None) + if converter is None: + return default_buffer_prototype() + + return converter(value) + + +__array_type_registry = ArrayTypeRegistry() + + +def register_array_type( + array_type: type[npt.ArrayLike], +) -> Callable[ + [Callable[[npt.ArrayLike], BufferPrototype]], Callable[[npt.ArrayLike], BufferPrototype] +]: + def wrapper( + func: Callable[[npt.ArrayLike], BufferPrototype], + ) -> Callable[[npt.ArrayLike], BufferPrototype]: + __array_type_registry.register(array_type, func) + + return func + + return wrapper + + +def infer_prototype(value: npt.ArrayLike) -> BufferPrototype: + return __array_type_registry.lookup(value) + + def register_codec(key: str, codec_cls: type[Codec], *, qualname: str | None = None) -> None: if key not in __codec_registries: __codec_registries[key] = Registry() From b35030eba97ad83a0e12e38540837a476088bbf6 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 15 Oct 2025 12:29:02 +0200 Subject: [PATCH 2/9] infer the prototype from the array type --- src/zarr/core/array.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 42d6201ba9..9e919ffc8b 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -129,6 +129,7 @@ _parse_array_bytes_codec, _parse_bytes_bytes_codec, get_pipeline_class, + infer_prototype, ) from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path from zarr.storage._utils import _relativize_path @@ -1787,7 +1788,7 @@ async def setitem( - Supports basic indexing, where the selection is contiguous and does not involve advanced indexing. """ if prototype is None: - prototype = default_buffer_prototype() + prototype = infer_prototype(value) indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -3195,7 +3196,7 @@ def set_basic_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = infer_prototype(value) indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -3442,7 +3443,7 @@ def set_orthogonal_selection( [__setitem__][zarr.Array.__setitem__] """ if prototype is None: - prototype = default_buffer_prototype() + prototype = infer_prototype(value) indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype) @@ -3620,7 +3621,7 @@ def set_mask_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = infer_prototype(value) indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -3800,7 +3801,7 @@ def set_coordinate_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = infer_prototype(value) # setup indexer indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) @@ -4024,7 +4025,7 @@ def set_block_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = infer_prototype(value) indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) From b36c1a1a989d71f3fc170bbcf1790c452baec208 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 16 Oct 2025 15:15:20 +0200 Subject: [PATCH 3/9] add buffer declarations to codecs --- src/zarr/abc/codec.py | 9 +++++++++ src/zarr/codecs/blosc.py | 5 ++++- src/zarr/codecs/bytes.py | 3 +++ src/zarr/codecs/crc32c_.py | 5 ++++- src/zarr/codecs/transpose.py | 5 ++++- src/zarr/codecs/vlen_utf8.py | 6 ++++++ src/zarr/codecs/zstd.py | 5 ++++- 7 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..6bab38a29f 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -185,14 +185,23 @@ async def encode( class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): """Base class for array-to-array codecs.""" + codec_input: type[NDBuffer] + codec_output: type[NDBuffer] + class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" + codec_input: type[NDBuffer] + codec_output: type[Buffer] + class BytesBytesCodec(BaseCodec[Buffer, Buffer]): """Base class for bytes-to-bytes codecs.""" + codec_input: type[Buffer] + codec_output: type[Buffer] + Codec = ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 6a482ed6e5..8aa36a885b 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -11,6 +11,7 @@ from packaging.version import Version from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_enum, parse_named_configuration from zarr.core.dtype.common import HasItemSize @@ -19,7 +20,6 @@ from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer class BloscShuffle(Enum): @@ -88,6 +88,9 @@ def parse_blocksize(data: JSON) -> int: class BloscCodec(BytesBytesCodec): """blosc codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = False typesize: int | None diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 39c26bd4a8..19dcc5e6f6 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -34,6 +34,9 @@ class Endian(Enum): class BytesCodec(ArrayBytesCodec): """bytes codec""" + codec_input = NDBuffer + codec_output = Buffer + is_fixed_size = True endian: Endian | None diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index 9536d0d558..67c79d2cbd 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -8,19 +8,22 @@ import typing_extensions from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.common import JSON, parse_named_configuration if TYPE_CHECKING: from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer @dataclass(frozen=True) class Crc32cCodec(BytesBytesCodec): """crc32c codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = True @classmethod diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index a8570b6e8f..dfa95a3e75 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -8,12 +8,12 @@ from zarr.abc.codec import ArrayArrayCodec from zarr.core.array_spec import ArraySpec +from zarr.core.buffer import NDBuffer from zarr.core.common import JSON, parse_named_configuration if TYPE_CHECKING: from typing import Self - from zarr.core.buffer import NDBuffer from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -30,6 +30,9 @@ def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]: class TransposeCodec(ArrayArrayCodec): """Transpose codec""" + codec_input = NDBuffer + codec_output = NDBuffer + is_fixed_size = True order: tuple[int, ...] diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index fa1a229855..9a0edce695 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -25,6 +25,9 @@ class VLenUTF8Codec(ArrayBytesCodec): """Variable-length UTF8 codec""" + codec_input = NDBuffer + codec_output = Buffer + @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration( @@ -71,6 +74,9 @@ def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) - @dataclass(frozen=True) class VLenBytesCodec(ArrayBytesCodec): + codec_input = NDBuffer + codec_output = Buffer + @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration( diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 27cc9a7777..2df90291f9 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -10,6 +10,7 @@ from packaging.version import Version from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_named_configuration @@ -17,7 +18,6 @@ from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer def parse_zstd_level(data: JSON) -> int: @@ -38,6 +38,9 @@ def parse_checksum(data: JSON) -> bool: class ZstdCodec(BytesBytesCodec): """zstd codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = True level: int = 0 From 1f718a8edd7f4589b7511b405991aa080890ebed Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 16 Oct 2025 17:30:05 +0200 Subject: [PATCH 4/9] use the codec pipeline's prototype instead --- src/zarr/core/array.py | 33 ++++++++++++++++----------------- src/zarr/core/codec_pipeline.py | 27 +++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 9e919ffc8b..e5eb0ff708 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -37,7 +37,6 @@ NDArrayLike, NDArrayLikeOrScalar, NDBuffer, - default_buffer_prototype, ) from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks @@ -129,7 +128,6 @@ _parse_array_bytes_codec, _parse_bytes_bytes_codec, get_pipeline_class, - infer_prototype, ) from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path from zarr.storage._utils import _relativize_path @@ -1624,7 +1622,7 @@ async def example(): ``` """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -1641,7 +1639,7 @@ async def get_orthogonal_selection( prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return await self._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -1656,7 +1654,7 @@ async def get_mask_selection( prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) return await self._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -1671,7 +1669,7 @@ async def get_coordinate_selection( prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) out_array = await self._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -1788,7 +1786,8 @@ async def setitem( - Supports basic indexing, where the selection is contiguous and does not involve advanced indexing. """ if prototype is None: - prototype = infer_prototype(value) + prototype = self.codec_pipeline.prototype + indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -3087,7 +3086,7 @@ def get_basic_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype return sync( self._async_array._get_selection( BasicIndexer(selection, self.shape, self.metadata.chunk_grid), @@ -3196,7 +3195,7 @@ def set_basic_selection( """ if prototype is None: - prototype = infer_prototype(value) + prototype = self._async_array.codec_pipeline.prototype indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -3324,7 +3323,7 @@ def get_orthogonal_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -3443,7 +3442,7 @@ def set_orthogonal_selection( [__setitem__][zarr.Array.__setitem__] """ if prototype is None: - prototype = infer_prototype(value) + prototype = self._async_array.codec_pipeline.prototype indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype) @@ -3531,7 +3530,7 @@ def get_mask_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -3621,7 +3620,7 @@ def set_mask_selection( """ if prototype is None: - prototype = infer_prototype(value) + prototype = self._async_array.codec_pipeline.prototype indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -3709,7 +3708,7 @@ def get_coordinate_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) out_array = sync( self._async_array._get_selection( @@ -3801,7 +3800,7 @@ def set_coordinate_selection( """ if prototype is None: - prototype = infer_prototype(value) + prototype = self._async_array.codec_pipeline.prototype # setup indexer indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) @@ -3924,7 +3923,7 @@ def get_block_selection( [__setitem__][zarr.Array.__setitem__] """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -4025,7 +4024,7 @@ def set_block_selection( """ if prototype is None: - prototype = infer_prototype(value) + prototype = self._async_array.codec_pipeline.prototype indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 63fcda7065..9fa601c3c2 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from itertools import islice, pairwise -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from zarr.abc.codec import ( @@ -14,6 +14,7 @@ Codec, CodecPipeline, ) +from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer, default_buffer_prototype from zarr.core.common import concurrent_map from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar @@ -26,7 +27,6 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -82,6 +82,29 @@ class BatchedCodecPipeline(CodecPipeline): bytes_bytes_codecs: tuple[BytesBytesCodec, ...] batch_size: int + @property + def prototype(self) -> BufferPrototype: + all_codecs = self.array_array_codecs + (self.array_bytes_codec,) + self.bytes_bytes_codecs + + current_buffer = all_codecs[0].codec_output + for codec in all_codecs[1:]: + if codec.codec_input is not current_buffer: + print(codec.codec_input, current_buffer) + raise ValueError("input buffer do not match the codec's predecessor") + + current_buffer = codec.codec_output + + nd_buffer = cast(type[NDBuffer], all_codecs[0].codec_input) + buffer = cast(type[Buffer], all_codecs[0].codec_output) + + if nd_buffer is NDBuffer and buffer is Buffer: + return default_buffer_prototype() + + return BufferPrototype( + nd_buffer=all_codecs[0].codec_input, + buffer=all_codecs[-1].codec_output, + ) + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return type(self).from_codecs(c.evolve_from_array_spec(array_spec=array_spec) for c in self) From ca1228287a0c4f9a3ce9c65b86d18cbd4f4d93f0 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 16 Oct 2025 17:32:33 +0200 Subject: [PATCH 5/9] Revert "array registry infrastructure" This reverts commit d57b38329061800882f02ec0b240933caf299931. --- src/zarr/core/buffer/cpu.py | 6 ----- src/zarr/core/buffer/gpu.py | 8 ------ src/zarr/registry.py | 51 +------------------------------------ 3 files changed, 1 insertion(+), 64 deletions(-) diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py index 077856c052..415b9d928c 100644 --- a/src/zarr/core/buffer/cpu.py +++ b/src/zarr/core/buffer/cpu.py @@ -11,7 +11,6 @@ from zarr.core.buffer import core from zarr.registry import ( - register_array_type, register_buffer, register_ndbuffer, ) @@ -231,11 +230,6 @@ def numpy_buffer_prototype() -> core.BufferPrototype: return core.BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) -@register_array_type(np.ndarray) -def create_numpy_buffer(value: npt.ArrayLike) -> core.BufferPrototype: - return buffer_prototype - - register_buffer(Buffer, qualname="zarr.buffer.cpu.Buffer") register_ndbuffer(NDBuffer, qualname="zarr.buffer.cpu.NDBuffer") diff --git a/src/zarr/core/buffer/gpu.py b/src/zarr/core/buffer/gpu.py index c51e29d63f..bfe977c50f 100644 --- a/src/zarr/core/buffer/gpu.py +++ b/src/zarr/core/buffer/gpu.py @@ -15,7 +15,6 @@ from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike from zarr.errors import ZarrUserWarning from zarr.registry import ( - register_array_type, register_buffer, register_ndbuffer, ) @@ -229,13 +228,6 @@ def __setitem__(self, key: Any, value: Any) -> None: buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) -if cp is not None: - - @register_array_type(cp.ndarray) - def _(value: cp.ndarray) -> BufferPrototype: - return buffer_prototype - - register_buffer(Buffer, qualname="zarr.buffer.gpu.Buffer") register_ndbuffer(NDBuffer, qualname="zarr.buffer.gpu.NDBuffer") diff --git a/src/zarr/registry.py b/src/zarr/registry.py index bb898ef54d..a8dd2a1c6c 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -2,7 +2,6 @@ import warnings from collections import defaultdict -from dataclasses import dataclass, field from importlib.metadata import entry_points as get_entry_points from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -11,11 +10,8 @@ from zarr.errors import ZarrUserWarning if TYPE_CHECKING: - from collections.abc import Callable from importlib.metadata import EntryPoint - import numpy.typing as npt - from zarr.abc.codec import ( ArrayArrayCodec, ArrayBytesCodec, @@ -25,7 +21,7 @@ CodecPipeline, ) from zarr.abc.numcodec import Numcodec - from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer + from zarr.core.buffer import Buffer, NDBuffer from zarr.core.chunk_key_encodings import ChunkKeyEncoding from zarr.core.common import JSON @@ -142,51 +138,6 @@ def fully_qualified_name(cls: type) -> str: return module + "." + cls.__qualname__ -@dataclass -class ArrayTypeRegistry: - registry: dict[type[npt.ArrayLike], Callable[[npt.ArrayLike], BufferPrototype]] = field( - default_factory=dict - ) - - def register( - self, - array_type: type[npt.ArrayLike], - converter: Callable[[npt.ArrayLike], BufferPrototype], - ) -> None: - self.registry[array_type] = converter - - def lookup(self, value: npt.ArrayLike) -> BufferPrototype: - from zarr.core.buffer.core import default_buffer_prototype - - converter = self.registry.get(type(value), None) - if converter is None: - return default_buffer_prototype() - - return converter(value) - - -__array_type_registry = ArrayTypeRegistry() - - -def register_array_type( - array_type: type[npt.ArrayLike], -) -> Callable[ - [Callable[[npt.ArrayLike], BufferPrototype]], Callable[[npt.ArrayLike], BufferPrototype] -]: - def wrapper( - func: Callable[[npt.ArrayLike], BufferPrototype], - ) -> Callable[[npt.ArrayLike], BufferPrototype]: - __array_type_registry.register(array_type, func) - - return func - - return wrapper - - -def infer_prototype(value: npt.ArrayLike) -> BufferPrototype: - return __array_type_registry.lookup(value) - - def register_codec(key: str, codec_cls: type[Codec], *, qualname: str | None = None) -> None: if key not in __codec_registries: __codec_registries[key] = Registry() From 7ba5ced03b5e77f69c636ef419c58afa3d662278 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 16 Oct 2025 17:36:26 +0200 Subject: [PATCH 6/9] get typing to pass --- src/zarr/abc/codec.py | 8 +++++++- src/zarr/core/codec_pipeline.py | 5 +---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 6bab38a29f..0059bb39c7 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import abstractmethod +from abc import abstractmethod, abstractproperty from collections.abc import Mapping from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar @@ -17,6 +17,7 @@ from zarr.abc.store import ByteGetter, ByteSetter, Store from zarr.core.array_spec import ArraySpec + from zarr.core.buffer import BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType from zarr.core.indexing import SelectorTuple @@ -285,6 +286,11 @@ class CodecPipeline: decoding them and assembling an output array. On the write path, it encodes the chunks and writes them to a store (via ByteSetter).""" + @abstractproperty + def prototype(self) -> BufferPrototype: + """The buffer prototype of the codec pipeline""" + ... + @abstractmethod def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: """Fills in codec configuration parameters that can be automatically diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 9fa601c3c2..9b22f83df6 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -100,10 +100,7 @@ def prototype(self) -> BufferPrototype: if nd_buffer is NDBuffer and buffer is Buffer: return default_buffer_prototype() - return BufferPrototype( - nd_buffer=all_codecs[0].codec_input, - buffer=all_codecs[-1].codec_output, - ) + return BufferPrototype(nd_buffer=nd_buffer, buffer=buffer) def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return type(self).from_codecs(c.evolve_from_array_spec(array_spec=array_spec) for c in self) From d53ca8d18d3bb506d6064c26f24212f29deade0b Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 17 Oct 2025 11:10:08 +0200 Subject: [PATCH 7/9] also make the v2 codec declare its buffers --- src/zarr/codecs/_v2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index 3c6c99c21c..4d7e62e8a1 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -8,16 +8,19 @@ from numcodecs.compat import ensure_bytes, ensure_ndarray_like from zarr.abc.codec import ArrayBytesCodec +from zarr.core.buffer import Buffer, NDBuffer from zarr.registry import get_ndbuffer_class if TYPE_CHECKING: from zarr.abc.numcodec import Numcodec from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer, NDBuffer @dataclass(frozen=True) class V2Codec(ArrayBytesCodec): + codec_input = NDBuffer + codec_output = Buffer + filters: tuple[Numcodec, ...] | None compressor: Numcodec | None From 78287f9fcee55870723811094e3006bb75407a97 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 17 Oct 2025 11:13:50 +0200 Subject: [PATCH 8/9] more input / output declarations on codecs --- src/zarr/codecs/gzip.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 610ca9dadd..d91fd1971b 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -7,6 +7,7 @@ from numcodecs.gzip import GZip from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_named_configuration @@ -14,7 +15,6 @@ from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer def parse_gzip_level(data: JSON) -> int: @@ -31,6 +31,9 @@ def parse_gzip_level(data: JSON) -> int: class GzipCodec(BytesBytesCodec): """gzip codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = False level: int = 5 From b9b7058d931ea9d1a5538a03807fb79dd8b56ede Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Fri, 17 Oct 2025 15:52:52 +0200 Subject: [PATCH 9/9] more buffer declarations --- src/zarr/codecs/sharding.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index b0fd75cef7..cd12806c80 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -334,6 +334,9 @@ class ShardingCodec( ): """Sharding codec""" + codec_input = NDBuffer + codec_output = Buffer + chunk_shape: tuple[int, ...] codecs: tuple[Codec, ...] index_codecs: tuple[Codec, ...]