Skip to content

Commit

Permalink
Add a fallback when GetDefaultLayout is unimplemented for that backend.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622278710
  • Loading branch information
jyingl3 authored and tensorflower-gardener committed Apr 5, 2024
1 parent 7b54481 commit c91b377
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/python/BUILD
Expand Up @@ -502,6 +502,7 @@ cc_library(
"//xla/python/ifrt",
"//xla/python/pjrt_ifrt",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
Expand Down
17 changes: 15 additions & 2 deletions third_party/xla/xla/python/dlpack.cc
Expand Up @@ -27,13 +27,15 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "llvm/Support/Casting.h"
#include "third_party/nanobind/include/nanobind/nanobind.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
Expand Down Expand Up @@ -424,8 +426,19 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
// for non-default layouts, and will return wrong results if a non-default
// layout is passed to a computation expecting default layouts. Remove this
// special case when non-default layouts are better supported by JAX.
TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout(
element_type, dimensions));
absl::StatusOr<Layout> default_layout_from_client =
device->client()->GetDefaultLayout(element_type, dimensions);
Layout default_layout;
if (default_layout_from_client.ok()) {
default_layout = *default_layout_from_client;
} else if (absl::IsUnimplemented(default_layout_from_client.status())) {
// TODO(skyewm): consider remove the fallback path when GetDefaultLayout is
// unimplemented.
Shape host_shape = ShapeUtil::MakeShape(element_type, dimensions);
default_layout = LayoutUtil::GetWithDefaultLayout(host_shape).layout();
} else {
return default_layout_from_client.status();
}
if (shape.layout() != default_layout) {
return Unimplemented(
"from_dlpack got array with non-default layout with minor-to-major "
Expand Down

0 comments on commit c91b377

Please sign in to comment.