Support for bfloat16 data type #711

Open
jbms opened this issue Mar 16, 2021 · 0 comments

jbms commented Mar 16, 2021

NumPy does not natively support bfloat16, but JAX and TensorFlow each define a bfloat16 NumPy dtype. It would be great if the zarr format provided a way to store it.

The zarr library can already write bfloat16 data. However, the dtype is stored as "<V2" (a 2-byte void type), so reading the array back yields raw bytes unless you explicitly call .view after opening:

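import zarr
import jax.numpy as jnp
import numpy as np

bfloat16 = jnp.bfloat16
# Register the name so that np.dtype('bfloat16') resolves.
# Note: newer NumPy releases deprecate np.typeDict in favor of np.sctypeDict.
np.typeDict['bfloat16'] = bfloat16

# A plain dict serves as an in-memory store.
my_store = dict()

# Write a single bfloat16 value; zarr records the dtype as "<V2".
z1 = zarr.open(mode='w', shape=(1,), compressor=None, dtype=np.dtype(bfloat16), store=my_store)
z1[0] = np.array(42, dtype=bfloat16)
print('Original array: %r' % (z1[0],))
print('Original array with view: %r' % (z1.view(dtype=bfloat16)[0],))

# Reopen from the stored metadata; the dtype comes back as a 2-byte void.
z2 = zarr.open(mode='r', store=my_store)
print('Reopening with original dtype: %r' % (z2[0],))
print('Reopening with original dtype with view: %r' % (z2.view(dtype=bfloat16)[0],))

# Workaround: patch the stored dtype string to the registered name.
my_store['.zarray'] = my_store['.zarray'].replace(b'<V2', b'bfloat16')
z3 = zarr.open(mode='r', store=my_store)
print('With adjusted dtype: %r' % (z3[0],))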

Output is:

Original array: void(b'\x28\x42')
Original array with view: 42
Reopening with original dtype: void(b'\x28\x42')
Reopening with original dtype with view: 42
With adjusted dtype: 42
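
For reference, the raw bytes \x28\x42 in the output are the little-endian bfloat16 encoding of 42: bfloat16 is the upper 16 bits of the IEEE float32 bit pattern. The following is a minimal sketch using only the standard library. The helper names float_to_bfloat16_bytes and bfloat16_bytes_to_float are made up for illustration (they are not part of zarr, NumPy, or JAX), and NaN and overflow handling is omitted.

```python
import struct


def float_to_bfloat16_bytes(value: float) -> bytes:
    """Encode a float as little-endian bfloat16 (hypothetical helper).

    Rounds to nearest even; NaN and overflow cases are not handled.
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    (bits32,) = struct.unpack('<I', struct.pack('<f', value))
    # Add a rounding bias so that dropping the low 16 bits rounds to nearest even.
    rounding_bias = 0x7FFF + ((bits32 >> 16) & 1)
    bits16 = ((bits32 + rounding_bias) >> 16) & 0xFFFF
    return struct.pack('<H', bits16)


def bfloat16_bytes_to_float(raw: bytes) -> float:
    """Decode two little-endian bfloat16 bytes (hypothetical helper)."""
    (bits16,) = struct.unpack('<H', raw)
    # Widening to float32 is exact: the low 16 mantissa bits are zero.
    (value,) = struct.unpack('<f', struct.pack('<I', bits16 << 16))
    return value


encoded = float_to_bfloat16_bytes(42.0)
print(encoded.hex())
print(bfloat16_bytes_to_float(encoded))
```

This is only meant to show that the void value zarr returns carries the right bits; it is not how zarr or JAX perform the conversion.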

Replacing the stored dtype in the .zarray metadata with "bfloat16" seems to be the only way to get zarr to open the array as bfloat16. (This requires the data type to be registered in np.typeDict, which jax does not do but probably should.) The bare name "bfloat16" leaves no way to specify the byte order, but supporting big-endian machines may not be too important.
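
The byte-level replace used in the workaround would also rewrite any other occurrence of "<V2" in the metadata. A slightly more targeted variant might parse the JSON and change only the dtype field. This is a sketch, assuming the zarr v2 layout where .zarray holds a JSON document with a "dtype" key. The helper name patch_zarray_dtype is hypothetical, and the example metadata is a hand-written stand-in rather than real zarr output.

```python
import json


def patch_zarray_dtype(store, old_dtype, new_dtype, key='.zarray'):
    """Rewrite only the dtype field of array metadata (hypothetical helper)."""
    meta = json.loads(store[key])
    if meta.get('dtype') != old_dtype:
        raise ValueError('expected dtype %r, found %r' % (old_dtype, meta.get('dtype')))
    meta['dtype'] = new_dtype
    # Store the metadata back as bytes, matching how the dict store holds it.
    store[key] = json.dumps(meta, indent=4, sort_keys=True).encode('ascii')
    return meta


# Hand-written stand-in for the metadata of the one-element array above.
example_store = {
    '.zarray': json.dumps({
        'chunks': [1],
        'compressor': None,
        'dtype': '<V2',
        'fill_value': None,
        'filters': None,
        'order': 'C',
        'shape': [1],
        'zarr_format': 2,
    }).encode('ascii')
}

patched = patch_zarray_dtype(example_store, '<V2', 'bfloat16')
print(patched['dtype'])
```

As with the original workaround, opening the patched store still depends on NumPy being able to resolve the name "bfloat16".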
