From 44c269e2196214352aa08ce8e30a8752150558f5 Mon Sep 17 00:00:00 2001
From: Kyuyeun Kim
Date: Wed, 20 Aug 2025 19:26:13 +0000
Subject: [PATCH] Create mapping for FP8 torch dtypes

---
 torchax/torchax/ops/mappings.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py
index 409a6d8350be..4eb7c6996159 100644
--- a/torchax/torchax/ops/mappings.py
+++ b/torchax/torchax/ops/mappings.py
@@ -6,6 +6,14 @@
 import torch.utils.dlpack as torchdl
 import torch.utils._mode_utils as mode_utils
 
+NUMPY_UNSUPPORTED_DTYPES = {
+    torch.bfloat16: jnp.bfloat16,
+    torch.float8_e4m3fn: jnp.float8_e4m3fn,
+    torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+    torch.float8_e5m2: jnp.float8_e5m2,
+    torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
+}
+
 
 def t2j(t, use_dlpack=True):
   is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
     if res is None:
       # https://github.com/google/jax/issues/7657
       # https://github.com/google/jax/issues/17784
-      if t.dtype == torch.bfloat16:
+      if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
         nparray = (t.cpu().detach().to(torch.float32).numpy()
-                  )  # numpy don't support bfloat16
+                  )  # handle dtypes not supported by numpy
       else:
         nparray = t.cpu().detach().numpy()
       res = jnp.asarray(nparray)
-      if t.dtype == torch.bfloat16:
-        res = res.astype(jnp.bfloat16)
+      if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+        res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
       if is_bool:
         res = res.astype(jnp.bool_)
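
Note (not part of the patch): below is a minimal standalone sketch of the numpy fallback path this change generalizes from bfloat16 to the FP8 dtypes. It mirrors the patched branch of t2j rather than importing torchax; the helper name t2j_numpy_fallback is hypothetical, and the sketch assumes a PyTorch build that exposes the torch.float8_* dtypes and a JAX build whose jnp exposes the matching float8 dtypes.

import jax.numpy as jnp
import torch

# Same table the patch adds: torch dtypes that numpy cannot represent,
# mapped to their jnp equivalents so the value can be cast back after
# the round-trip through a numpy array.
NUMPY_UNSUPPORTED_DTYPES = {
    torch.bfloat16: jnp.bfloat16,
    torch.float8_e4m3fn: jnp.float8_e4m3fn,
    torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
    torch.float8_e5m2: jnp.float8_e5m2,
    torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
}


def t2j_numpy_fallback(t):  # hypothetical helper, not torchax API
  if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
    # .numpy() raises on bfloat16/float8 tensors, so widen to float32
    # first, then restore the original dtype on the JAX side.
    res = jnp.asarray(t.cpu().detach().to(torch.float32).numpy())
    return res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
  return jnp.asarray(t.cpu().detach().numpy())


t = torch.randn(4).to(torch.float8_e5m2)
print(t2j_numpy_fallback(t).dtype)  # float8_e5m2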