diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..0059bb39c7 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import abstractmethod +from abc import abstractmethod, abstractproperty from collections.abc import Mapping from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar @@ -17,6 +17,7 @@ from zarr.abc.store import ByteGetter, ByteSetter, Store from zarr.core.array_spec import ArraySpec + from zarr.core.buffer import BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType from zarr.core.indexing import SelectorTuple @@ -185,14 +186,23 @@ async def encode( class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): """Base class for array-to-array codecs.""" + codec_input: type[NDBuffer] + codec_output: type[NDBuffer] + class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" + codec_input: type[NDBuffer] + codec_output: type[Buffer] + class BytesBytesCodec(BaseCodec[Buffer, Buffer]): """Base class for bytes-to-bytes codecs.""" + codec_input: type[Buffer] + codec_output: type[Buffer] + Codec = ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec @@ -276,6 +286,11 @@ class CodecPipeline: decoding them and assembling an output array. On the write path, it encodes the chunks and writes them to a store (via ByteSetter).""" + @abstractproperty + def prototype(self) -> BufferPrototype: + """The buffer prototype of the codec pipeline""" + ... + @abstractmethod def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: """Fills in codec configuration parameters that can be automatically diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index 3c6c99c21c..4d7e62e8a1 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -8,16 +8,19 @@ from numcodecs.compat import ensure_bytes, ensure_ndarray_like from zarr.abc.codec import ArrayBytesCodec +from zarr.core.buffer import Buffer, NDBuffer from zarr.registry import get_ndbuffer_class if TYPE_CHECKING: from zarr.abc.numcodec import Numcodec from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer, NDBuffer @dataclass(frozen=True) class V2Codec(ArrayBytesCodec): + codec_input = NDBuffer + codec_output = Buffer + filters: tuple[Numcodec, ...] | None compressor: Numcodec | None diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 6a482ed6e5..8aa36a885b 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -11,6 +11,7 @@ from packaging.version import Version from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_enum, parse_named_configuration from zarr.core.dtype.common import HasItemSize @@ -19,7 +20,6 @@ from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer class BloscShuffle(Enum): @@ -88,6 +88,9 @@ def parse_blocksize(data: JSON) -> int: class BloscCodec(BytesBytesCodec): """blosc codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = False typesize: int | None diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 39c26bd4a8..19dcc5e6f6 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -34,6 +34,9 @@ class Endian(Enum): class BytesCodec(ArrayBytesCodec): """bytes codec""" + codec_input = NDBuffer + codec_output = Buffer + is_fixed_size = True endian: Endian | None diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index 9536d0d558..67c79d2cbd 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -8,19 +8,22 @@ import typing_extensions from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.common import JSON, parse_named_configuration if TYPE_CHECKING: from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer @dataclass(frozen=True) class Crc32cCodec(BytesBytesCodec): """crc32c codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = True @classmethod diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 610ca9dadd..d91fd1971b 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -7,6 +7,7 @@ from numcodecs.gzip import GZip from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_named_configuration @@ -14,7 +15,6 @@ from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer def parse_gzip_level(data: JSON) -> int: @@ -31,6 +31,9 @@ def parse_gzip_level(data: JSON) -> int: class GzipCodec(BytesBytesCodec): """gzip codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = False level: int = 5 diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index b0fd75cef7..cd12806c80 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -334,6 +334,9 @@ class ShardingCodec( ): """Sharding codec""" + codec_input = NDBuffer + codec_output = Buffer + chunk_shape: tuple[int, ...] codecs: tuple[Codec, ...] index_codecs: tuple[Codec, ...] diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index a8570b6e8f..dfa95a3e75 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -8,12 +8,12 @@ from zarr.abc.codec import ArrayArrayCodec from zarr.core.array_spec import ArraySpec +from zarr.core.buffer import NDBuffer from zarr.core.common import JSON, parse_named_configuration if TYPE_CHECKING: from typing import Self - from zarr.core.buffer import NDBuffer from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -30,6 +30,9 @@ def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]: class TransposeCodec(ArrayArrayCodec): """Transpose codec""" + codec_input = NDBuffer + codec_output = NDBuffer + is_fixed_size = True order: tuple[int, ...] diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index fa1a229855..9a0edce695 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -25,6 +25,9 @@ class VLenUTF8Codec(ArrayBytesCodec): """Variable-length UTF8 codec""" + codec_input = NDBuffer + codec_output = Buffer + @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration( @@ -71,6 +74,9 @@ def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) - @dataclass(frozen=True) class VLenBytesCodec(ArrayBytesCodec): + codec_input = NDBuffer + codec_output = Buffer + @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration( diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 27cc9a7777..2df90291f9 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -10,6 +10,7 @@ from packaging.version import Version from zarr.abc.codec import BytesBytesCodec +from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import as_numpy_array_wrapper from zarr.core.common import JSON, parse_named_configuration @@ -17,7 +18,6 @@ from typing import Self from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer def parse_zstd_level(data: JSON) -> int: @@ -38,6 +38,9 @@ def parse_checksum(data: JSON) -> bool: class ZstdCodec(BytesBytesCodec): """zstd codec""" + codec_input = Buffer + codec_output = Buffer + is_fixed_size = True level: int = 0 diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 42d6201ba9..e5eb0ff708 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -37,7 +37,6 @@ NDArrayLike, NDArrayLikeOrScalar, NDBuffer, - default_buffer_prototype, ) from zarr.core.buffer.cpu import buffer_prototype as cpu_buffer_prototype from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition, normalize_chunks @@ -1623,7 +1622,7 @@ async def example(): ``` """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -1640,7 +1639,7 @@ async def get_orthogonal_selection( prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return await self._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -1655,7 +1654,7 @@ async def get_mask_selection( prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) return await self._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -1670,7 +1669,7 @@ async def get_coordinate_selection( prototype: BufferPrototype | None = None, ) -> NDArrayLikeOrScalar: if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) out_array = await self._get_selection( indexer=indexer, out=out, fields=fields, prototype=prototype @@ -1787,7 +1786,8 @@ async def setitem( - Supports basic indexing, where the selection is contiguous and does not involve advanced indexing. """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.codec_pipeline.prototype + indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -3086,7 +3086,7 @@ def get_basic_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype return sync( self._async_array._get_selection( BasicIndexer(selection, self.shape, self.metadata.chunk_grid), @@ -3195,7 +3195,7 @@ def set_basic_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -3323,7 +3323,7 @@ def get_orthogonal_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -3442,7 +3442,7 @@ def set_orthogonal_selection( [__setitem__][zarr.Array.__setitem__] """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype) @@ -3530,7 +3530,7 @@ def get_mask_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -3620,7 +3620,7 @@ def set_mask_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -3708,7 +3708,7 @@ def get_coordinate_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) out_array = sync( self._async_array._get_selection( @@ -3800,7 +3800,7 @@ def set_coordinate_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype # setup indexer indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) @@ -3923,7 +3923,7 @@ def get_block_selection( [__setitem__][zarr.Array.__setitem__] """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -4024,7 +4024,7 @@ def set_block_selection( """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self._async_array.codec_pipeline.prototype indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 63fcda7065..9b22f83df6 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from itertools import islice, pairwise -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from zarr.abc.codec import ( @@ -14,6 +14,7 @@ Codec, CodecPipeline, ) +from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer, default_buffer_prototype from zarr.core.common import concurrent_map from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar @@ -26,7 +27,6 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.core.array_spec import ArraySpec - from zarr.core.buffer import Buffer, BufferPrototype, NDBuffer from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -82,6 +82,26 @@ class BatchedCodecPipeline(CodecPipeline): bytes_bytes_codecs: tuple[BytesBytesCodec, ...] batch_size: int + @property + def prototype(self) -> BufferPrototype: + all_codecs = self.array_array_codecs + (self.array_bytes_codec,) + self.bytes_bytes_codecs + + current_buffer = all_codecs[0].codec_output + for codec in all_codecs[1:]: + if codec.codec_input is not current_buffer: + print(codec.codec_input, current_buffer) + raise ValueError("input buffer do not match the codec's predecessor") + + current_buffer = codec.codec_output + + nd_buffer = cast(type[NDBuffer], all_codecs[0].codec_input) + buffer = cast(type[Buffer], all_codecs[0].codec_output) + + if nd_buffer is NDBuffer and buffer is Buffer: + return default_buffer_prototype() + + return BufferPrototype(nd_buffer=nd_buffer, buffer=buffer) + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return type(self).from_codecs(c.evolve_from_array_spec(array_spec=array_spec) for c in self)