From 7d3c6c16100cdcb01f60287369599ea3dd944fa8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 May 2024 17:41:06 -0700 Subject: [PATCH] Add a check that if any single dimension of an xla shape is 0 the data size is 0. PiperOrigin-RevId: 636352383 --- third_party/xla/xla/shape_util_test.cc | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index 6ebda8e53ed79a..f72cf4e4742cb9 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -1202,6 +1202,20 @@ TEST(ShapeUtilTest, Int4ShapeSize) { EXPECT_EQ(ShapeUtil::ArraySize(int4_shape2), 9216 * 6144 / 2); } +TEST(XlaShapeUtilTest, ZeroSize) { + // Verify that if any one dimension is 0 we have a zero byte buffer. + std::vector> test_cases = { + {0, 64, 128}, {128, 0, 64}, {64, 128, 0}, + {0, 63, 127}, {127, 0, 63}, {63, 127, 0}, + }; + for (const auto& dimensions : test_cases) { + xla::Shape int4_shape = xla::ShapeUtil::MakeShape(xla::S4, dimensions); + int4_shape.mutable_layout()->set_element_size_in_bits(4); + EXPECT_EQ(xla::ShapeUtil::ArrayDataSize(int4_shape), 0); + EXPECT_EQ(xla::ShapeUtil::ArraySize(int4_shape), 0); + } +} + TEST(ShapeUtilTest, DecomposeBitcastToReshape) { const Shape kInputShape = ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 16, 17, 3}, {3, 2, 1, 0});