diff --git a/changes/3560.bugfix.md b/changes/3560.bugfix.md
new file mode 100644
index 0000000000..c3306cb6ac
--- /dev/null
+++ b/changes/3560.bugfix.md
@@ -0,0 +1 @@
+Improve the performance of writes to large shards by up to 10x.
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py
index b0fd75cef7..8124ea44ea 100644
--- a/src/zarr/codecs/sharding.py
+++ b/src/zarr/codecs/sharding.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Mapping, MutableMapping
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, replace
 from enum import Enum
 from functools import lru_cache
 from operator import itemgetter
@@ -54,15 +54,15 @@
 from zarr.registry import get_ndbuffer_class, get_pipeline_class
 
 if TYPE_CHECKING:
-    from collections.abc import Awaitable, Callable, Iterator
+    from collections.abc import Iterator
     from typing import Self
 
     from zarr.core.common import JSON
     from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
 
 MAX_UINT_64 = 2**64 - 1
-ShardMapping = Mapping[tuple[int, ...], Buffer]
-ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer]
+ShardMapping = Mapping[tuple[int, ...], Buffer | None]
+ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None]
 
 
 class ShardingCodecIndexLocation(Enum):
@@ -219,114 +219,6 @@ def __len__(self) -> int:
     def __iter__(self) -> Iterator[tuple[int, ...]]:
         return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
 
-    def is_empty(self) -> bool:
-        return self.index.is_all_empty()
-
-
-class _ShardBuilder(_ShardReader, ShardMutableMapping):
-    buf: Buffer
-    index: _ShardIndex
-
-    @classmethod
-    def merge_with_morton_order(
-        cls,
-        chunks_per_shard: tuple[int, ...],
-        tombstones: set[tuple[int, ...]],
-        *shard_dicts: ShardMapping,
-    ) -> _ShardBuilder:
-        obj = cls.create_empty(chunks_per_shard)
-        for chunk_coords in morton_order_iter(chunks_per_shard):
-            if chunk_coords in tombstones:
-                continue
-            for shard_dict in shard_dicts:
-                maybe_value = shard_dict.get(chunk_coords, None)
-                if maybe_value is not None:
-                    obj[chunk_coords] = maybe_value
-                    break
-        return obj
-
-    @classmethod
-    def create_empty(
-        cls, chunks_per_shard: tuple[int, ...], buffer_prototype: BufferPrototype | None = None
-    ) -> _ShardBuilder:
-        if buffer_prototype is None:
-            buffer_prototype = default_buffer_prototype()
-        obj = cls()
-        obj.buf = buffer_prototype.buffer.create_zero_length()
-        obj.index = _ShardIndex.create_empty(chunks_per_shard)
-        return obj
-
-    def __setitem__(self, chunk_coords: tuple[int, ...], value: Buffer) -> None:
-        chunk_start = len(self.buf)
-        chunk_length = len(value)
-        self.buf += value
-        self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
-
-    def __delitem__(self, chunk_coords: tuple[int, ...]) -> None:
-        raise NotImplementedError
-
-    async def finalize(
-        self,
-        index_location: ShardingCodecIndexLocation,
-        index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
-    ) -> Buffer:
-        index_bytes = await index_encoder(self.index)
-        if index_location == ShardingCodecIndexLocation.start:
-            empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64
-            self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
-            index_bytes = await index_encoder(self.index)  # encode again with corrected offsets
-            out_buf = index_bytes + self.buf
-        else:
-            out_buf = self.buf + index_bytes
-        return out_buf
-
-
-@dataclass(frozen=True)
-class _MergingShardBuilder(ShardMutableMapping):
-    old_dict: _ShardReader
-    new_dict: _ShardBuilder
-    tombstones: set[tuple[int, ...]] = field(default_factory=set)
-
-    def __getitem__(self, chunk_coords: tuple[int, ...]) -> Buffer:
-        chunk_bytes_maybe = self.new_dict.get(chunk_coords)
-        if chunk_bytes_maybe is not None:
-            return chunk_bytes_maybe
-        return self.old_dict[chunk_coords]
-
-    def __setitem__(self, chunk_coords: tuple[int, ...], value: Buffer) -> None:
-        self.new_dict[chunk_coords] = value
-
-    def __delitem__(self, chunk_coords: tuple[int, ...]) -> None:
-        self.tombstones.add(chunk_coords)
-
-    def __len__(self) -> int:
-        return self.old_dict.__len__()
-
-    def __iter__(self) -> Iterator[tuple[int, ...]]:
-        return self.old_dict.__iter__()
-
-    def is_empty(self) -> bool:
-        full_chunk_coords_map = self.old_dict.index.get_full_chunk_map()
-        full_chunk_coords_map = np.logical_or(
-            full_chunk_coords_map, self.new_dict.index.get_full_chunk_map()
-        )
-        for tombstone in self.tombstones:
-            full_chunk_coords_map[tombstone] = False
-        return bool(np.array_equiv(full_chunk_coords_map, False))
-
-    async def finalize(
-        self,
-        index_location: ShardingCodecIndexLocation,
-        index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
-    ) -> Buffer:
-        shard_builder = _ShardBuilder.merge_with_morton_order(
-            self.new_dict.index.chunks_per_shard,
-            self.tombstones,
-            self.new_dict,
-            self.old_dict,
-        )
-        return await shard_builder.finalize(index_location, index_encoder)
-
 
 @dataclass(frozen=True)
 class ShardingCodec(
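The deleted `_ShardBuilder.__setitem__` grew a single contiguous buffer with `self.buf += value`, so every chunk write re-copied all bytes accumulated so far — quadratic in total shard size. The hunks below replace this with a plain dict of per-chunk buffers that is concatenated exactly once. A minimal sketch of the two access patterns, with plain NumPy arrays standing in for zarr's CPU `Buffer` (not part of the diff):

```python
import numpy as np

parts = [np.full(64_000, 1, dtype="B") for _ in range(1_000)]

# Old pattern: one concatenate per chunk against the growing shard buffer,
# copying everything written so far each time -> O(n^2) bytes moved.
acc = np.empty(0, dtype="B")
for p in parts:
    acc = np.concatenate((acc, p))

# New pattern: gather all parts, then a single concatenate at the end,
# as the new Buffer.combine() does -> O(n) bytes moved.
out = np.concatenate(parts)

assert np.array_equal(acc, out)
```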
@@ -573,7 +465,7 @@ async def _encode_single(
             )
         )
 
-        shard_builder = _ShardBuilder.create_empty(chunks_per_shard)
+        shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))
 
         await self.codec_pipeline.write(
             [
@@ -589,7 +481,11 @@
             shard_array,
         )
 
-        return await shard_builder.finalize(self.index_location, self._encode_shard_index)
+        return await self._encode_shard_dict(
+            shard_builder,
+            chunks_per_shard=chunks_per_shard,
+            buffer_prototype=default_buffer_prototype(),
+        )
 
     async def _encode_partial_single(
         self,
@@ -603,15 +499,13 @@
         chunks_per_shard = self._get_chunks_per_shard(shard_spec)
         chunk_spec = self._get_chunk_spec(shard_spec)
 
-        shard_dict = _MergingShardBuilder(
-            await self._load_full_shard_maybe(
-                byte_getter=byte_setter,
-                prototype=chunk_spec.prototype,
-                chunks_per_shard=chunks_per_shard,
-            )
-            or _ShardReader.create_empty(chunks_per_shard),
-            _ShardBuilder.create_empty(chunks_per_shard),
+        shard_reader = await self._load_full_shard_maybe(
+            byte_getter=byte_setter,
+            prototype=chunk_spec.prototype,
+            chunks_per_shard=chunks_per_shard,
         )
+        shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
+        shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
 
         indexer = list(
             get_indexer(
@@ -632,16 +526,57 @@
             ],
             shard_array,
         )
+        buf = await self._encode_shard_dict(
+            shard_dict,
+            chunks_per_shard=chunks_per_shard,
+            buffer_prototype=default_buffer_prototype(),
+        )
 
-        if shard_dict.is_empty():
+        if buf is None:
             await byte_setter.delete()
         else:
-            await byte_setter.set(
-                await shard_dict.finalize(
-                    self.index_location,
-                    self._encode_shard_index,
-                )
-            )
+            await byte_setter.set(buf)
+
+    async def _encode_shard_dict(
+        self,
+        map: ShardMapping,
+        chunks_per_shard: tuple[int, ...],
+        buffer_prototype: BufferPrototype,
+    ) -> Buffer | None:
+        index = _ShardIndex.create_empty(chunks_per_shard)
+
+        buffers = []
+
+        template = buffer_prototype.buffer.create_zero_length()
+        chunk_start = 0
+        for chunk_coords in morton_order_iter(chunks_per_shard):
+            value = map.get(chunk_coords)
+            if value is None:
+                continue
+
+            if len(value) == 0:
+                continue
+
+            chunk_length = len(value)
+            buffers.append(value)
+            index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
+            chunk_start += chunk_length
+
+        if len(buffers) == 0:
+            return None
+
+        index_bytes = await self._encode_shard_index(index)
+        if self.index_location == ShardingCodecIndexLocation.start:
+            empty_chunks_mask = index.offsets_and_lengths[..., 0] == MAX_UINT_64
+            index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
+            index_bytes = await self._encode_shard_index(
+                index
+            )  # encode again with corrected offsets
+            buffers.insert(0, index_bytes)
+        else:
+            buffers.append(index_bytes)
+
+        return template.combine(buffers)
 
     def _is_total_shard(
         self, all_chunk_coords: set[tuple[int, ...]], chunks_per_shard: tuple[int, ...]
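For reference, `_encode_shard_dict` above packs the non-empty chunks in Morton order, recording an `(offset, length)` pair per chunk; when the index lives at the start of the shard, every offset is then shifted by the encoded index size, which is why `_encode_shard_index` runs a second time. A simplified, self-contained model of that bookkeeping (an illustration only: `bytes` stands in for `Buffer`, and an explicit coordinate list stands in for `morton_order_iter`):

```python
from collections.abc import Mapping

def layout_shard(
    chunks: Mapping[tuple[int, ...], bytes | None],
    order: list[tuple[int, ...]],
    index_nbytes: int,
    index_at_start: bool,
) -> dict[tuple[int, ...], tuple[int, int]]:
    """Return {chunk_coords: (offset, length)} as the shard index would record it."""
    table: dict[tuple[int, ...], tuple[int, int]] = {}
    pos = 0
    for coords in order:
        value = chunks.get(coords)
        if not value:  # None or zero-length stays "empty" (MAX_UINT_64 in the real index)
            continue
        table[coords] = (pos, len(value))
        pos += len(value)
    if index_at_start:
        # mirrors the second encode pass that corrects offsets by len(index_bytes)
        table = {k: (off + index_nbytes, n) for k, (off, n) in table.items()}
    return table

# Two chunks of 10 and 20 bytes behind a 16-byte index at the shard start:
print(layout_shard({(0, 0): b"a" * 10, (0, 1): b"b" * 20}, [(0, 0), (0, 1)], 16, True))
# -> {(0, 0): (16, 10), (0, 1): (36, 20)}
```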
diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py
index 189916dc91..ddd3073af2 100644
--- a/src/zarr/core/buffer/core.py
+++ b/src/zarr/core/buffer/core.py
@@ -2,6 +2,7 @@
 
 import sys
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -294,9 +295,13 @@ def __len__(self) -> int:
         return self._data.size
 
     @abstractmethod
+    def combine(self, others: Iterable[Buffer]) -> Self:
+        """Concatenate many buffers"""
+        ...
+
     def __add__(self, other: Buffer) -> Self:
         """Concatenate two buffers"""
-        ...
+        return self.combine([other])
 
     def __eq__(self, other: object) -> bool:
         # Another Buffer class can override this to choose a more efficient path
diff --git a/src/zarr/core/buffer/cpu.py b/src/zarr/core/buffer/cpu.py
index 415b9d928c..58275d2843 100644
--- a/src/zarr/core/buffer/cpu.py
+++ b/src/zarr/core/buffer/cpu.py
@@ -107,14 +107,13 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
         """
         return np.asanyarray(self._data)
 
-    def __add__(self, other: core.Buffer) -> Self:
-        """Concatenate two buffers"""
-
-        other_array = other.as_array_like()
-        assert other_array.dtype == np.dtype("B")
-        return self.__class__(
-            np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
-        )
+    def combine(self, others: Iterable[core.Buffer]) -> Self:
+        data = [np.asanyarray(self._data)]
+        for buf in others:
+            other_array = buf.as_array_like()
+            assert other_array.dtype == np.dtype("B")
+            data.append(np.asanyarray(other_array))
+        return self.__class__(np.concatenate(data))
 
 
 class NDBuffer(core.NDBuffer):
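With `combine` as the abstract primitive, backends implement only the batched path and inherit pairwise `+` as `self.combine([other])`. A quick usage sketch of the new CPU path (assuming the existing `from_bytes`/`to_bytes` helpers on `zarr.core.buffer.cpu.Buffer`; not part of the diff):

```python
import zarr.core.buffer.cpu as cpu

a = cpu.Buffer.from_bytes(b"abc")
b = cpu.Buffer.from_bytes(b"def")
c = cpu.Buffer.from_bytes(b"ghi")

# One allocation and one copy for any number of parts:
assert a.combine([b, c]).to_bytes() == b"abcdefghi"

# Pairwise + keeps working; the base class now routes it through combine([other]):
assert (a + b).to_bytes() == b"abcdef"
```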
diff --git a/src/zarr/core/buffer/gpu.py b/src/zarr/core/buffer/gpu.py
index bfe977c50f..2a591884ae 100644
--- a/src/zarr/core/buffer/gpu.py
+++ b/src/zarr/core/buffer/gpu.py
@@ -107,14 +107,15 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self:
     def as_numpy_array(self) -> npt.NDArray[Any]:
         return cast("npt.NDArray[Any]", cp.asnumpy(self._data))
 
-    def __add__(self, other: core.Buffer) -> Self:
-        other_array = other.as_array_like()
-        assert other_array.dtype == np.dtype("B")
-        gpu_other = Buffer(other_array)
-        gpu_other_array = gpu_other.as_array_like()
-        return self.__class__(
-            cp.concatenate((cp.asanyarray(self._data), cp.asanyarray(gpu_other_array)))
-        )
+    def combine(self, others: Iterable[core.Buffer]) -> Self:
+        data = [cp.asanyarray(self._data)]
+        for other in others:
+            other_array = other.as_array_like()
+            assert other_array.dtype == np.dtype("B")
+            gpu_other = Buffer(other_array)
+            gpu_other_array = gpu_other.as_array_like()
+            data.append(cp.asanyarray(gpu_other_array))
+        return self.__class__(cp.concatenate(data))
 
 
 class NDBuffer(core.NDBuffer):
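The refactor is behavior-preserving: folding buffers with the old pairwise `+` must yield the same bytes as a single `combine` call. A property-style check one could add (hypothetical test sketch, CPU prototype only, not part of the diff):

```python
from functools import reduce
import zarr.core.buffer.cpu as cpu

parts = [cpu.Buffer.from_bytes(bytes([65 + i]) * (i + 1)) for i in range(5)]
head, tail = parts[0], parts[1:]

batched = head.combine(tail).to_bytes()          # single concatenation
folded = reduce(lambda x, y: x + y, parts).to_bytes()  # old pairwise semantics
assert batched == folded == b"ABBCCCDDDDEEEEE"
```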