Skip to content

Commit a1b571c

Browse files
committed
Fix elementwise arithmetic with zero sized dimensions
Fixes #893.
1 parent 61fe256 commit a1b571c

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,18 @@ TEST_F(AtenXlaTensorTest, TestAddScalarInPlace) {
189189
});
190190
}
191191

192+
TEST_F(AtenXlaTensorTest, TestAddZeroSizeDim) {
193+
torch::Tensor a = torch::rand({0, 2}, torch::TensorOptions(torch::kFloat));
194+
torch::Tensor b = torch::rand({1, 2}, torch::TensorOptions(torch::kFloat));
195+
torch::Tensor c = torch::add(a, b);
196+
ForEachDevice([&](const torch::Device& device) {
197+
torch::Tensor xla_a = CopyToDevice(a, device);
198+
torch::Tensor xla_b = CopyToDevice(b, device);
199+
torch::Tensor xla_c = torch::add(xla_a, xla_b);
200+
AllClose(c, xla_c);
201+
});
202+
}
203+
192204
TEST_F(AtenXlaTensorTest, TestSub) {
193205
torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
194206
torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));

torch_xla/csrc/helpers.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
307307
xla::int64 dim2 = shape2_dims[shape2_dims.size() - min_size + i];
308308
XLA_CHECK(dim1 == dim2 || dim1 == 1 || dim2 == 1)
309309
<< shape1 << " and " << shape2;
310-
dimensions.push_back(std::max<xla::int64>(dim1, dim2));
310+
if (dim1 == 0 || dim2 == 0) {
311+
dimensions.push_back(0);
312+
} else {
313+
dimensions.push_back(std::max<xla::int64>(dim1, dim2));
314+
}
311315
}
312316
return xla::ShapeUtil::MakeShape(shape1.element_type(), dimensions);
313317
}

0 commit comments

Comments
 (0)