From e7532cd93f164c5df6ffb7fcc3ae9dea76c64269 Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Fri, 10 May 2024 21:07:18 -0700 Subject: [PATCH] Raise a runtime error when trying to convert the `jax.Array` wrapped by `jax.core.Token` to a numpy array, as it is an internal implementation detail and the buffer has XLA token shape. PiperOrigin-RevId: 632682906 --- third_party/xla/xla/python/py_array.cc | 7 +++++++ third_party/xla/xla/python/xla_client.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 741105dea0cd65..320f2bfdd20849 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -1412,6 +1412,13 @@ StatusOr PyHostValue::AsNumPyArray( if (ifrt_array->IsDeleted()) { return InvalidArgument("DeviceArray has been deleted."); } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. + if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } auto* arr = llvm::dyn_cast_or_null(ifrt_array); if (arr != nullptr) { auto* pjrt_buffer = arr->pjrt_buffers().front().get(); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index bebb13b376efd3..37c462efc3e0f6 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -49,7 +49,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 263 +_version = 264 # Version number for MLIR:Python components. mlir_api_version = 56