-
-
Notifications
You must be signed in to change notification settings - Fork 363
Zstd Codec on the GPU #2863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Zstd Codec on the GPU #2863
Changes from all commits
49d5ee8
d548adc
a8c0db3
69aa274
10e1bc9
771c0c1
ec07100
d1c37a3
69ea74e
1b85fdc
f5c7814
7671274
d558ef8
f16d730
f0db57d
048ad48
2282cb9
c6460b5
3b5e294
dd825dc
f89b232
7a4b037
398b4d1
dd69543
76f7560
8b5b3f1
7af3a16
996fbc0
d24d027
090349c
83c53b0
eb50521
de3b577
ac14838
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Added GPU-accelerated Zstd Codec | ||
|
||
This adds support for decoding with the Zstd Codec on NVIDIA GPUs using the | ||
nvidia-nvcomp library. | ||
|
||
With `zarr.config.enable_gpu()`, buffers will be decoded using the GPU | ||
and the output will reside in device memory. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
.. _user-guide-gpu: | ||
|
||
Using GPUs with Zarr | ||
==================== | ||
|
||
Zarr can use GPUs to accelerate your workload by running | ||
:meth:`zarr.config.enable_gpu`. | ||
|
||
Reading data into device memory | ||
------------------------------- | ||
|
||
:meth:`zarr.config.enable_gpu` configures Zarr to use GPU memory for the data | ||
buffers used internally by Zarr. | ||
|
||
.. code-block:: python | ||
|
||
>>> import zarr | ||
>>> import cupy as cp # doctest: +SKIP | ||
>>> zarr.config.enable_gpu() # doctest: +SKIP | ||
>>> store = zarr.storage.MemoryStore() # doctest: +SKIP | ||
>>> z = zarr.create_array( # doctest: +SKIP | ||
... store=store, shape=(100, 100), chunks=(10, 10), dtype="float32", | ||
... ) | ||
>>> type(z[:10, :10]) # doctest: +SKIP | ||
cupy.ndarray | ||
|
||
Note that the output type is a ``cupy.ndarray`` rather than a NumPy array. | ||
|
||
For supported codecs, data will be decoded using the GPU via the `nvcomp`_ | ||
library. See :ref:`user-guide-config` for more. Issues and feature requests | ||
for NVIDIA nvCOMP can be reported in the `nvcomp issue tracker`_. | ||
|
||
.. _nvcomp: https://docs.nvidia.com/cuda/nvcomp/samples/python_samples.html | ||
.. _nvcomp issue tracker: https://github.com/NVIDIA/CUDALibrarySamples/issues |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,7 @@ remote = [ | |
] | ||
gpu = [ | ||
"cupy-cuda12x", | ||
"nvidia-nvcomp-cu12", | ||
] | ||
cli = ["typer"] | ||
# Development extras | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from dataclasses import dataclass | ||
from functools import cached_property | ||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
|
||
from zarr.abc.codec import BytesBytesCodec | ||
from zarr.core.common import JSON, parse_named_configuration | ||
from zarr.registry import register_codec | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterable | ||
from typing import Self | ||
|
||
from zarr.core.array_spec import ArraySpec | ||
from zarr.core.buffer import Buffer | ||
|
||
try: | ||
import cupy as cp | ||
except ImportError: # pragma: no cover | ||
cp = None | ||
|
||
try: | ||
from nvidia import nvcomp | ||
except ImportError: # pragma: no cover | ||
nvcomp = None | ||
|
||
|
||
def _parse_zstd_level(data: JSON) -> int: | ||
if isinstance(data, int): | ||
if data >= 23: | ||
raise ValueError(f"Value must be less than or equal to 22. Got {data} instead.") | ||
return data | ||
raise TypeError(f"Got value with type {type(data)}, but expected an int.") | ||
|
||
|
||
def _parse_checksum(data: JSON) -> bool: | ||
if isinstance(data, bool): | ||
return data | ||
raise TypeError(f"Expected bool. Got {type(data)}.") | ||
|
||
|
||
@dataclass(frozen=True)
class NvcompZstdCodec(BytesBytesCodec):
    """GPU-accelerated Zstd codec backed by the nvidia-nvcomp library.

    Registered under the standard ``"zstd"`` codec name, so it is a drop-in
    replacement for the CPU Zstd codec when GPU support is enabled.  Encoded
    and decoded buffers reside in device memory (CuPy arrays).

    Parameters
    ----------
    level : int, default 0
        Zstd compression level; must be less than or equal to 22.
    checksum : bool, default False
        Whether to include a checksum with the compressed data.
    """

    is_fixed_size = True

    level: int = 0
    checksum: bool = False

    def __init__(self, *, level: int = 0, checksum: bool = False) -> None:
        # TODO: Set CUDA device appropriately here and also set CUDA stream
        # (see the discussion on devices/streams planning in the PR thread).
        level_parsed = _parse_zstd_level(level)
        checksum_parsed = _parse_checksum(checksum)

        # The dataclass is frozen, so attributes must be set via object.__setattr__.
        object.__setattr__(self, "level", level_parsed)
        object.__setattr__(self, "checksum", checksum_parsed)

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        """Build a codec instance from its Zarr metadata dictionary."""
        _, configuration_parsed = parse_named_configuration(data, "zstd")
        return cls(**configuration_parsed)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        """Serialize this codec to its Zarr metadata dictionary."""
        return {
            "name": "zstd",
            "configuration": {"level": self.level, "checksum": self.checksum},
        }

    @cached_property
    def _zstd_codec(self) -> nvcomp.Codec:
        # Lazily construct the nvcomp codec, bound to the current device/stream
        # at first use.
        device = cp.cuda.Device()  # Select the current default device
        stream = cp.cuda.get_current_stream()  # Use the current default stream
        return nvcomp.Codec(
            algorithm="Zstd",
            bitstream_kind=nvcomp.BitstreamKind.RAW,
            device_id=device.id,
            cuda_stream=stream.ptr,
        )

    def _convert_to_nvcomp_arrays(
        self,
        chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
    ) -> tuple[list[nvcomp.Array], list[int]]:
        """Split out None chunks and wrap the remaining buffers as nvcomp arrays.

        Returns the nvcomp arrays for the non-None chunks together with the
        positions of the None chunks, so they can be re-inserted afterwards.
        """
        none_indices = [i for i, (b, _) in enumerate(chunks_and_specs) if b is None]
        filtered_inputs = [b.as_array_like() for b, _ in chunks_and_specs if b is not None]
        # TODO: add CUDA stream here
        return nvcomp.as_arrays(filtered_inputs), none_indices

    def _convert_from_nvcomp_arrays(
        self,
        arrays: Iterable[nvcomp.Array],
        chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
    ) -> Iterable[Buffer | None]:
        """Convert nvcomp output arrays back to Zarr buffers (None passes through)."""
        # Use an explicit ``is not None`` check: ``arrays`` contains None
        # placeholders for skipped chunks, and relying on the truthiness of an
        # nvcomp.Array is not well defined.
        return [
            spec.prototype.buffer.from_array_like(cp.array(a, dtype=np.dtype("B"), copy=False))
            if a is not None
            else None
            for a, (_, spec) in zip(arrays, chunks_and_specs, strict=True)
        ]

    async def decode(
        self,
        chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
    ) -> Iterable[Buffer | None]:
        """Decodes a batch of chunks.
        Chunks can be None in which case they are ignored by the codec.

        Parameters
        ----------
        chunks_and_specs : Iterable[tuple[Buffer | None, ArraySpec]]
            Ordered set of encoded chunks with their accompanying chunk spec.

        Returns
        -------
        Iterable[Buffer | None]
        """
        chunks_and_specs = list(chunks_and_specs)

        # Convert to nvcomp arrays
        filtered_inputs, none_indices = self._convert_to_nvcomp_arrays(chunks_and_specs)

        outputs = self._zstd_codec.decode(filtered_inputs) if len(filtered_inputs) > 0 else []

        # Record an event on the current stream *after* the decode has been
        # queued; synchronizing an unrecorded event returns immediately and
        # would not wait for the GPU work.
        event = cp.cuda.Event()
        event.record()
        # Wait for decode to complete in a separate async thread
        await asyncio.to_thread(event.synchronize)

        for index in none_indices:
            outputs.insert(index, None)

        return self._convert_from_nvcomp_arrays(outputs, chunks_and_specs)

    async def encode(
        self,
        chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
    ) -> Iterable[Buffer | None]:
        """Encodes a batch of chunks.
        Chunks can be None in which case they are ignored by the codec.

        Parameters
        ----------
        chunks_and_specs : Iterable[tuple[Buffer | None, ArraySpec]]
            Ordered set of to-be-encoded chunks with their accompanying chunk spec.

        Returns
        -------
        Iterable[Buffer | None]
        """
        # TODO: Make this actually async
        chunks_and_specs = list(chunks_and_specs)

        # Convert to nvcomp arrays
        filtered_inputs, none_indices = self._convert_to_nvcomp_arrays(chunks_and_specs)

        outputs = self._zstd_codec.encode(filtered_inputs) if len(filtered_inputs) > 0 else []

        # Record an event on the current stream *after* the encode has been
        # queued; synchronizing an unrecorded event returns immediately and
        # would not wait for the GPU work.
        event = cp.cuda.Event()
        event.record()
        # Wait for encode to complete in a separate async thread
        await asyncio.to_thread(event.synchronize)

        for index in none_indices:
            outputs.insert(index, None)

        return self._convert_from_nvcomp_arrays(outputs, chunks_and_specs)

    def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        # Zstd output size is data-dependent; no fixed-size estimate is available.
        raise NotImplementedError


register_codec("zstd", NvcompZstdCodec)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,6 @@ | |
from zarr.codecs._v2 import V2Codec | ||
from zarr.codecs.bytes import BytesCodec | ||
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec | ||
from zarr.codecs.zstd import ZstdCodec | ||
from zarr.core._info import ArrayInfo | ||
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, parse_array_config | ||
from zarr.core.attributes import Attributes | ||
|
@@ -128,6 +127,7 @@ | |
_parse_array_array_codec, | ||
_parse_array_bytes_codec, | ||
_parse_bytes_bytes_codec, | ||
get_codec_class, | ||
get_pipeline_class, | ||
) | ||
from zarr.storage._common import StorePath, ensure_no_existing_node, make_store_path | ||
|
@@ -5036,9 +5036,9 @@ def default_compressors_v3(dtype: ZDType[Any, Any]) -> tuple[BytesBytesCodec, .. | |
""" | ||
Given a data type, return the default compressors for that data type. | ||
|
||
This is just a tuple containing ``ZstdCodec`` | ||
This is just a tuple containing an instance of the default "zstd" codec class. | ||
""" | ||
return (ZstdCodec(),) | ||
return (cast(BytesBytesCodec, get_codec_class("zstd")()),) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is the extra There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
|
||
def default_serializer_v3(dtype: ZDType[Any, Any]) -> ArrayBytesCodec: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,9 +8,6 @@ | |
cast, | ||
) | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from zarr.core.buffer import core | ||
from zarr.core.buffer.core import ArrayLike, BufferPrototype, NDArrayLike | ||
from zarr.errors import ZarrUserWarning | ||
|
@@ -23,8 +20,9 @@ | |
from collections.abc import Iterable | ||
from typing import Self | ||
|
||
from zarr.core.common import BytesLike | ||
import numpy.typing as npt | ||
|
||
from zarr.core.common import BytesLike | ||
try: | ||
import cupy as cp | ||
except ImportError: | ||
|
@@ -54,14 +52,14 @@ class Buffer(core.Buffer): | |
|
||
def __init__(self, array_like: ArrayLike) -> None: | ||
if cp is None: | ||
raise ImportError( | ||
raise ImportError( # pragma: no cover | ||
"Cannot use zarr.buffer.gpu.Buffer without cupy. Please install cupy." | ||
) | ||
|
||
if array_like.ndim != 1: | ||
raise ValueError("array_like: only 1-dim allowed") | ||
if array_like.dtype != np.dtype("B"): | ||
raise ValueError("array_like: only byte dtype allowed") | ||
if array_like.dtype.itemsize != 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new tests in I'd like to get us to a point where we don't care as much about the details of the buffer passed in here. This is an OK start I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what exactly does this check for? It's not clear to me why any numpy array that can be viewed as bytes wouldn't be allowed in here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and same for the dimensionality check, since any N-dimensional numpy array can be viewed as a 1D array. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I'm not really sure... I agree that the actual data we store internally here needs to be a byte dtype. Just doing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i'm not even convinced that we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could even express this formally by:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (super out of scope for this PR ofc) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think this is mainly because But I agree that we can probably merge those two and make |
||
raise ValueError("array_like: only dtypes with itemsize=1 allowed") | ||
|
||
if not hasattr(array_like, "__cuda_array_interface__"): | ||
# Slow copy based path for arrays that don't support the __cuda_array_interface__ | ||
|
@@ -108,13 +106,13 @@ def as_numpy_array(self) -> npt.NDArray[Any]: | |
return cast("npt.NDArray[Any]", cp.asnumpy(self._data)) | ||
|
||
def __add__(self, other: core.Buffer) -> Self: | ||
other_array = other.as_array_like() | ||
assert other_array.dtype == np.dtype("B") | ||
gpu_other = Buffer(other_array) | ||
gpu_other_array = gpu_other.as_array_like() | ||
return self.__class__( | ||
cp.concatenate((cp.asanyarray(self._data), cp.asanyarray(gpu_other_array))) | ||
) | ||
other_array = cp.asanyarray(other.as_array_like()) | ||
left = self._data | ||
if left.dtype != other_array.dtype: | ||
other_array = other_array.view(left.dtype) | ||
|
||
buffer = cp.concatenate([left, other_array]) | ||
return type(self)(buffer) | ||
|
||
|
||
class NDBuffer(core.NDBuffer): | ||
|
@@ -144,7 +142,7 @@ class NDBuffer(core.NDBuffer): | |
|
||
def __init__(self, array: NDArrayLike) -> None: | ||
if cp is None: | ||
raise ImportError( | ||
raise ImportError( # pragma: no cover | ||
"Cannot use zarr.buffer.gpu.NDBuffer without cupy. Please install cupy." | ||
) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.