From c91b3777f5fac9a236e1f6d8d54091ecb3e63cf8 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Fri, 5 Apr 2024 14:12:21 -0700 Subject: [PATCH] Add a fallback when GetDefaultLayout is unimplemented for that backend. PiperOrigin-RevId: 622278710 --- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/python/dlpack.cc | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index abcada18a29e9d..bf89c821ae1bae 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -502,6 +502,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/pjrt_ifrt", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/python/dlpack.cc b/third_party/xla/xla/python/dlpack.cc index 0cb187ba27a76a..795e50e9d9bed8 100644 --- a/third_party/xla/xla/python/dlpack.cc +++ b/third_party/xla/xla/python/dlpack.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -34,6 +35,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "third_party/nanobind/include/nanobind/nanobind.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" @@ -424,8 +426,19 @@ absl::StatusOr DLPackManagedTensorToBuffer( // for non-default layouts, and will return wrong results if a non-default // layout is passed to a computation expecting default layouts. Remove this // special case when non-default layouts are better supported by JAX. - TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout( - element_type, dimensions)); + absl::StatusOr default_layout_from_client = + device->client()->GetDefaultLayout(element_type, dimensions); + Layout default_layout; + if (default_layout_from_client.ok()) { + default_layout = *default_layout_from_client; + } else if (absl::IsUnimplemented(default_layout_from_client.status())) { + // TODO(skyewm): consider remove the fallback path when GetDefaultLayout is + // unimplemented. + Shape host_shape = ShapeUtil::MakeShape(element_type, dimensions); + default_layout = LayoutUtil::GetWithDefaultLayout(host_shape).layout(); + } else { + return default_layout_from_client.status(); + } if (shape.layout() != default_layout) { return Unimplemented( "from_dlpack got array with non-default layout with minor-to-major "