Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions torchax/torchax/ops/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
import torch.utils.dlpack as torchdl
import torch.utils._mode_utils as mode_utils

NUMPY_UNSUPPORTED_DTYPES = {
torch.bfloat16: jnp.bfloat16,
torch.float8_e4m3fn: jnp.float8_e4m3fn,
torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
torch.float8_e5m2: jnp.float8_e5m2,
torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
}


def t2j(t, use_dlpack=True):
is_bool = False
Expand All @@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
if res is None:
# https://github.com/google/jax/issues/7657
# https://github.com/google/jax/issues/17784
if t.dtype == torch.bfloat16:
if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
nparray = (t.cpu().detach().to(torch.float32).numpy()
) # numpy don't support bfloat16
) # handle dtypes not supported by numpy
else:
nparray = t.cpu().detach().numpy()
res = jnp.asarray(nparray)
if t.dtype == torch.bfloat16:
res = res.astype(jnp.bfloat16)
if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])

if is_bool:
res = res.astype(jnp.bool_)
Expand Down