changes/3560.bugfix.md (1 addition, 0 deletions)
@@ -0,0 +1 @@
+Improve write performance to large shards by up to 10x.
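
The speedup comes from how chunk bytes are assembled into a shard. The old _ShardBuilder grew a single buffer with self.buf += value, copying every previously appended byte on each append, so packing n chunks cost O(n^2) byte copies; the new code gathers the chunk buffers in a list and concatenates them once. A minimal sketch of the two patterns (illustrative only, not the PR's code):

    import numpy as np

    chunks = [np.ones(1_000_000, dtype="B") for _ in range(100)]

    # Old pattern: each append reallocates and copies all prior bytes -> O(n^2).
    buf = np.empty(0, dtype="B")
    for c in chunks:
        buf = np.concatenate((buf, c))

    # New pattern: one allocation, each byte copied exactly once -> O(n).
    buf = np.concatenate(chunks)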
src/zarr/codecs/sharding.py (64 additions, 129 deletions)
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from collections.abc import Iterable, Mapping, MutableMapping
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, replace
 from enum import Enum
 from functools import lru_cache
 from operator import itemgetter
@@ -54,15 +54,15 @@
 from zarr.registry import get_ndbuffer_class, get_pipeline_class

 if TYPE_CHECKING:
-    from collections.abc import Awaitable, Callable, Iterator
+    from collections.abc import Iterator
     from typing import Self

     from zarr.core.common import JSON
     from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType

 MAX_UINT_64 = 2**64 - 1
-ShardMapping = Mapping[tuple[int, ...], Buffer]
-ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer]
+ShardMapping = Mapping[tuple[int, ...], Buffer | None]
+ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None]


 class ShardingCodecIndexLocation(Enum):
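
Widening the mapping value type to Buffer | None is what lets a plain dict replace the old builder classes below: every chunk coordinate in the shard gets a key, and None marks a chunk with no encoded bytes. A hypothetical two-chunk example of the shape such a mapping takes (encoded_chunk is a made-up name):

    shard_dict: ShardMutableMapping = {
        (0, 0): encoded_chunk,  # Buffer holding this chunk's encoded bytes
        (0, 1): None,           # no data for this chunk; it is skipped on write
    }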
@@ -219,114 +219,6 @@ def __len__(self) -> int:
     def __iter__(self) -> Iterator[tuple[int, ...]]:
         return c_order_iter(self.index.offsets_and_lengths.shape[:-1])

-    def is_empty(self) -> bool:
-        return self.index.is_all_empty()
-
-
-class _ShardBuilder(_ShardReader, ShardMutableMapping):
-    buf: Buffer
-    index: _ShardIndex
-
-    @classmethod
-    def merge_with_morton_order(
-        cls,
-        chunks_per_shard: tuple[int, ...],
-        tombstones: set[tuple[int, ...]],
-        *shard_dicts: ShardMapping,
-    ) -> _ShardBuilder:
-        obj = cls.create_empty(chunks_per_shard)
-        for chunk_coords in morton_order_iter(chunks_per_shard):
-            if chunk_coords in tombstones:
-                continue
-            for shard_dict in shard_dicts:
-                maybe_value = shard_dict.get(chunk_coords, None)
-                if maybe_value is not None:
-                    obj[chunk_coords] = maybe_value
-                    break
-        return obj
-
-    @classmethod
-    def create_empty(
-        cls, chunks_per_shard: tuple[int, ...], buffer_prototype: BufferPrototype | None = None
-    ) -> _ShardBuilder:
-        if buffer_prototype is None:
-            buffer_prototype = default_buffer_prototype()
-        obj = cls()
-        obj.buf = buffer_prototype.buffer.create_zero_length()
-        obj.index = _ShardIndex.create_empty(chunks_per_shard)
-        return obj
-
-    def __setitem__(self, chunk_coords: tuple[int, ...], value: Buffer) -> None:
-        chunk_start = len(self.buf)
-        chunk_length = len(value)
-        self.buf += value
-        self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
-
-    def __delitem__(self, chunk_coords: tuple[int, ...]) -> None:
-        raise NotImplementedError
-
-    async def finalize(
-        self,
-        index_location: ShardingCodecIndexLocation,
-        index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
-    ) -> Buffer:
-        index_bytes = await index_encoder(self.index)
-        if index_location == ShardingCodecIndexLocation.start:
-            empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64
-            self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
-            index_bytes = await index_encoder(self.index)  # encode again with corrected offsets
-            out_buf = index_bytes + self.buf
-        else:
-            out_buf = self.buf + index_bytes
-        return out_buf
-
-
-@dataclass(frozen=True)
-class _MergingShardBuilder(ShardMutableMapping):
-    old_dict: _ShardReader
-    new_dict: _ShardBuilder
-    tombstones: set[tuple[int, ...]] = field(default_factory=set)
-
-    def __getitem__(self, chunk_coords: tuple[int, ...]) -> Buffer:
-        chunk_bytes_maybe = self.new_dict.get(chunk_coords)
-        if chunk_bytes_maybe is not None:
-            return chunk_bytes_maybe
-        return self.old_dict[chunk_coords]
-
-    def __setitem__(self, chunk_coords: tuple[int, ...], value: Buffer) -> None:
-        self.new_dict[chunk_coords] = value
-
-    def __delitem__(self, chunk_coords: tuple[int, ...]) -> None:
-        self.tombstones.add(chunk_coords)
-
-    def __len__(self) -> int:
-        return self.old_dict.__len__()
-
-    def __iter__(self) -> Iterator[tuple[int, ...]]:
-        return self.old_dict.__iter__()
-
-    def is_empty(self) -> bool:
-        full_chunk_coords_map = self.old_dict.index.get_full_chunk_map()
-        full_chunk_coords_map = np.logical_or(
-            full_chunk_coords_map, self.new_dict.index.get_full_chunk_map()
-        )
-        for tombstone in self.tombstones:
-            full_chunk_coords_map[tombstone] = False
-        return bool(np.array_equiv(full_chunk_coords_map, False))
-
-    async def finalize(
-        self,
-        index_location: ShardingCodecIndexLocation,
-        index_encoder: Callable[[_ShardIndex], Awaitable[Buffer]],
-    ) -> Buffer:
-        shard_builder = _ShardBuilder.merge_with_morton_order(
-            self.new_dict.index.chunks_per_shard,
-            self.tombstones,
-            self.new_dict,
-            self.old_dict,
-        )
-        return await shard_builder.finalize(index_location, index_encoder)
-

 @dataclass(frozen=True)
 class ShardingCodec(
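
Everything the two deleted classes did by hand (Morton-order merging, tombstones for deleted chunks, incremental buffer appends) is folded into the plain-dict flow plus the new _encode_shard_dict method further down. Only as a rough conceptual sketch with hypothetical names, the old merge-plus-tombstone dance reduces to ordinary dict operations:

    merged = {**old_chunks, **new_chunks}  # later values win, like merge_with_morton_order
    for coords in deleted_coords:          # a tombstone becomes a None entry
        merged[coords] = None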
@@ -573,7 +465,7 @@ async def _encode_single(
             )
         )

-        shard_builder = _ShardBuilder.create_empty(chunks_per_shard)
+        shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))

         await self.codec_pipeline.write(
             [
@@ -589,7 +481,11 @@ async def _encode_single(
                 shard_array,
             )

-        return await shard_builder.finalize(self.index_location, self._encode_shard_index)
+        return await self._encode_shard_dict(
+            shard_builder,
+            chunks_per_shard=chunks_per_shard,
+            buffer_prototype=default_buffer_prototype(),
+        )

     async def _encode_partial_single(
         self,
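
dict.fromkeys over the Morton iterator gives _encode_shard_dict a complete mapping to fill: one key per chunk coordinate, each initially None. For instance (the coordinate order shown is hypothetical; the real order comes from morton_order_iter):

    coords = [(0, 0), (1, 0), (0, 1), (1, 1)]
    shard_builder = dict.fromkeys(coords)
    # {(0, 0): None, (1, 0): None, (0, 1): None, (1, 1): None}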
@@ -603,15 +499,13 @@ async def _encode_partial_single(
         chunks_per_shard = self._get_chunks_per_shard(shard_spec)
         chunk_spec = self._get_chunk_spec(shard_spec)

-        shard_dict = _MergingShardBuilder(
-            await self._load_full_shard_maybe(
-                byte_getter=byte_setter,
-                prototype=chunk_spec.prototype,
-                chunks_per_shard=chunks_per_shard,
-            )
-            or _ShardReader.create_empty(chunks_per_shard),
-            _ShardBuilder.create_empty(chunks_per_shard),
+        shard_reader = await self._load_full_shard_maybe(
+            byte_getter=byte_setter,
+            prototype=chunk_spec.prototype,
+            chunks_per_shard=chunks_per_shard,
         )
+        shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
+        shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}

         indexer = list(
             get_indexer(
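
For partial writes the mapping is seeded from the stored shard instead of with Nones: shard_reader.get(k) returns the existing encoded bytes for chunks present in the shard and None for absent ones, so after codec_pipeline.write overwrites the touched entries the dict fully describes the new shard, untouched chunks included. A hypothetical state after a partial write to chunk (1, 0):

    shard_dict = {
        (0, 0): old_bytes,  # carried over from the stored shard
        (1, 0): new_bytes,  # replaced by codec_pipeline.write
        (1, 1): None,       # empty in both the old and the new shard
    }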
@@ -632,16 +526,57 @@ async def _encode_partial_single(
                 ],
                 shard_array,
             )
+        buf = await self._encode_shard_dict(
+            shard_dict,
+            chunks_per_shard=chunks_per_shard,
+            buffer_prototype=default_buffer_prototype(),
+        )

-        if shard_dict.is_empty():
+        if buf is None:
             await byte_setter.delete()
         else:
-            await byte_setter.set(
-                await shard_dict.finalize(
-                    self.index_location,
-                    self._encode_shard_index,
-                )
-            )
+            await byte_setter.set(buf)
+
+    async def _encode_shard_dict(
+        self,
+        map: ShardMapping,
+        chunks_per_shard: tuple[int, ...],
+        buffer_prototype: BufferPrototype,
+    ) -> Buffer | None:
+        index = _ShardIndex.create_empty(chunks_per_shard)
+
+        buffers = []
+
+        template = buffer_prototype.buffer.create_zero_length()
+        chunk_start = 0
+        for chunk_coords in morton_order_iter(chunks_per_shard):
+            value = map.get(chunk_coords)
+            if value is None:
+                continue
+
+            if len(value) == 0:
+                continue
+
+            chunk_length = len(value)
+            buffers.append(value)
+            index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
+            chunk_start += chunk_length
+
+        if len(buffers) == 0:
+            return None
+
+        index_bytes = await self._encode_shard_index(index)
+        if self.index_location == ShardingCodecIndexLocation.start:
+            empty_chunks_mask = index.offsets_and_lengths[..., 0] == MAX_UINT_64
+            index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
+            index_bytes = await self._encode_shard_index(
+                index
+            )  # encode again with corrected offsets
+            buffers.insert(0, index_bytes)
+        else:
+            buffers.append(index_bytes)
+
+        return template.combine(buffers)

     def _is_total_shard(
         self, all_chunk_coords: set[tuple[int, ...]], chunks_per_shard: tuple[int, ...]
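
One subtlety in _encode_shard_dict: when index_location is start, the offsets recorded while packing are relative to the chunk data alone, so the index is encoded twice, once to learn its own byte length and again after shifting every non-empty offset by that length (the length itself does not change, since the offsets are fixed-width uint64 fields). A worked example, assuming a hypothetical 16-byte encoded index and two chunks of 100 and 50 bytes:

    # First pass records offsets relative to the chunk data:
    #   chunk A -> offset 0,   length 100
    #   chunk B -> offset 100, length 50
    # With the 16-byte index prepended, the real layout becomes
    #   [index: 0..16) [chunk A: 16..116) [chunk B: 116..166)
    # so 16 is added to each non-empty offset and the index is re-encoded.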
src/zarr/core/buffer/core.py (6 additions, 1 deletion)
@@ -2,6 +2,7 @@

 import sys
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -294,9 +295,13 @@ def __len__(self) -> int:
         return self._data.size

     @abstractmethod
+    def combine(self, others: Iterable[Buffer]) -> Self:
+        """Concatenate many buffers"""
+        ...
+
     def __add__(self, other: Buffer) -> Self:
         """Concatenate two buffers"""
-        ...
+        return self.combine([other])

     def __eq__(self, other: object) -> bool:
         # Another Buffer class can override this to choose a more efficient path
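
combine is now the primitive and __add__ a thin wrapper over it, so two-buffer concatenation keeps working unchanged while many-buffer callers pay for a single allocation. A usage sketch (buffer contents are illustrative):

    from zarr.core.buffer import default_buffer_prototype

    proto = default_buffer_prototype()
    a, b, c = (proto.buffer.from_bytes(s) for s in (b"aa", b"bb", b"cc"))

    assert (a + b + c).to_bytes() == b"aabbcc"
    assert a.combine([b, c]).to_bytes() == b"aabbcc"  # same bytes, one concatenation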
src/zarr/core/buffer/cpu.py (7 additions, 8 deletions)
@@ -107,14 +107,13 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
         """
         return np.asanyarray(self._data)

-    def __add__(self, other: core.Buffer) -> Self:
-        """Concatenate two buffers"""
-
-        other_array = other.as_array_like()
-        assert other_array.dtype == np.dtype("B")
-        return self.__class__(
-            np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
-        )
+    def combine(self, others: Iterable[core.Buffer]) -> Self:
+        data = [np.asanyarray(self._data)]
+        for buf in others:
+            other_array = buf.as_array_like()
+            assert other_array.dtype == np.dtype("B")
+            data.append(np.asanyarray(other_array))
+        return self.__class__(np.concatenate(data))


 class NDBuffer(core.NDBuffer):
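
Collecting the arrays first and calling np.concatenate once means one output allocation and exactly one copy of each source byte, versus a fresh allocation and full re-copy at every step of a chained +. The equivalent core of the new method, reduced to plain numpy (illustrative):

    import numpy as np

    arrays = [np.frombuffer(b, dtype="B") for b in (b"ab", b"cd", b"ef")]
    out = np.concatenate(arrays)  # single allocation covering all parts
    assert out.tobytes() == b"abcdef"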
src/zarr/core/buffer/gpu.py (9 additions, 8 deletions)
@@ -107,14 +107,15 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self:
     def as_numpy_array(self) -> npt.NDArray[Any]:
         return cast("npt.NDArray[Any]", cp.asnumpy(self._data))

-    def __add__(self, other: core.Buffer) -> Self:
-        other_array = other.as_array_like()
-        assert other_array.dtype == np.dtype("B")
-        gpu_other = Buffer(other_array)
-        gpu_other_array = gpu_other.as_array_like()
-        return self.__class__(
-            cp.concatenate((cp.asanyarray(self._data), cp.asanyarray(gpu_other_array)))
-        )
+    def combine(self, others: Iterable[core.Buffer]) -> Self:
+        data = [cp.asanyarray(self._data)]
+        for other in others:
+            other_array = other.as_array_like()
+            assert other_array.dtype == np.dtype("B")
+            gpu_other = Buffer(other_array)
+            gpu_other_array = gpu_other.as_array_like()
+            data.append(cp.asanyarray(gpu_other_array))
+        return self.__class__(cp.concatenate(data))


 class NDBuffer(core.NDBuffer):
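
The gpu version mirrors the cpu one; the Buffer(other_array) round-trip keeps the old behavior of staging each part on the device before concatenating, and a single cp.concatenate means one device allocation for the result rather than one per chained +. Reduced to plain cupy (illustrative; requires a CUDA device):

    import cupy as cp
    import numpy as np

    parts = [cp.asarray(np.frombuffer(b, dtype="B")) for b in (b"ab", b"cd")]
    out = cp.concatenate(parts)  # one device allocation for the result
    assert bytes(cp.asnumpy(out)) == b"abcd"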