
Commit 48e12e3

Fix non_zero cast issue (#4243)
* Make the nonzero result reflect the real type
* Make SizeNode's type hardware dependent
* Fix CPU and GPU error
* Fix GPU and TPU C++ test
1 parent d668f2c commit 48e12e3
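
In user-facing terms, the issue fixed here is the one covered by the new test_nonzero_cast test added below. Roughly, the following (a minimal sketch, not part of the commit itself, assuming torch_xla is installed and an XLA device is available) previously failed because the nonzero result did not reflect the backend's real index type:

import torch
import torch_xla.core.xla_model as xm

t1 = torch.ones(5, 2, device=xm.xla_device())
# Before this commit, the nonzero result was unconditionally labeled int64,
# which clashed with the s32 index type used on TPU; casting the result is
# now expected to work on every backend.
t2 = torch.nonzero(t1.int()).float()
xm.mark_step()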

File tree

4 files changed (+17, -4 lines)

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -5140,7 +5140,7 @@ TEST_F(AtenXlaTensorTest, TestNonzero) {
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor xla_a = CopyToDevice(a, device);
     torch::Tensor xla_b = torch::nonzero(xla_a);
-    AllClose(b, xla_b);
+    AllClose(b, torch::_cast_Long(xla_b));
 
     if (DebugUtil::ExperimentEnabled("nonzero")) {
       // If the nonzero support is enabled, we must not see any aten:: calls.
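
The explicit cast is needed because torch::nonzero on the reference CPU tensor always yields int64, while the XLA result now carries the backend's index type, which is s32 on TPU. A rough Python analogue of the updated assertion (for illustration only, not part of the commit):

import torch
import torch_xla.core.xla_model as xm

a = torch.zeros(4, 2)
a[0][1] = 1.0
b = torch.nonzero(a)                          # CPU reference is always int64
xla_b = torch.nonzero(a.to(xm.xla_device()))  # backend index type (s32 on TPU)
# Cast to int64 before comparing, mirroring torch::_cast_Long in the C++ test.
assert torch.equal(b, xla_b.cpu().long())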

test/test_operations.py

Lines changed: 8 additions & 0 deletions
@@ -749,6 +749,14 @@ def test_masked_select_shape(self):
         torch.masked_select(x, mask), 0)
     self.assertEqual(x_dim0_shape.item(), 3)
 
+  def test_nonzero_cast(self):
+    t1 = torch.ones(5, 2, device=xm.xla_device())
+    # Result of the nonzero should be the index type. Currently
+    # index type is s64 on cpu and gpu, but s32 on TPU. We should be
+    # able to cast it to any other type without error.
+    t2 = torch.nonzero(t1.int()).float()
+    xm.mark_step()
+
 
 class TestOptimizationBarrier(XlaTestCase):

torch_xla/csrc/ops/dynamic_ir.cpp

Lines changed: 5 additions & 2 deletions
@@ -5,6 +5,7 @@
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
 #include "torch_xla/csrc/tensor.h"
+#include "torch_xla/csrc/tensor_util.h"
 
 namespace torch_xla {
 
@@ -23,8 +24,10 @@ const std::shared_ptr<torch::lazy::DimensionNode> DimCast(
 
 SizeNode::SizeNode(torch::lazy::Value input, size_t dim)
     : XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size")},
-              {input}, xla::ShapeUtil::MakeShape(xla::S64, {}), 1,
-              torch::lazy::MHash(dim)),
+              {input},
+              xla::ShapeUtil::MakeShape(
+                  GetShapeDimensionType(/*device=*/nullptr), {}),
+              1, torch::lazy::MHash(dim)),
       dim_(dim) {
   // Not all IR has torch::lazy::shape now, use xla::shape to unblock
   // the development.
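
The hard-coded xla::S64 is replaced so that SizeNode's scalar type matches the backend's dimension type. A hedged sketch of where this shows up, assuming the internal torch_xla._XLAC._get_xla_tensor_dimension_size binding used by the neighbouring test_masked_select_shape test:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.randn(3, 3, device=xm.xla_device())
mask = x.ge(0.5)
# The size of dimension 0 of the masked_select result is a SizeNode; its
# scalar type now follows the backend's dimension type (s64 on CPU/GPU,
# s32 on TPU) rather than a hard-coded s64.
dim0 = torch_xla._XLAC._get_xla_tensor_dimension_size(
    torch.masked_select(x, mask), 0)
print(dim0.item())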

torch_xla/csrc/tensor_methods.cpp

Lines changed: 3 additions & 1 deletion
@@ -1876,7 +1876,9 @@ std::pair<XLATensorPtr, XLATensorPtr> XLATensor::nms(
 XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input) {
   torch::lazy::NodePtr node =
       torch::lazy::MakeNode<NonZero>(input->GetIrValue());
-  return input->CreateFrom(torch::lazy::Value(node, 0), at::ScalarType::Long);
+  // Nonzero result type should not depend on input type, hence we shouldn't
+  // use input->CreateFrom which will inherit the logical_element_type.
+  return Create(torch::lazy::Value(node, 0), input->GetDevice());
 }
 
 XLATensorPtr XLATensor::norm(const XLATensorPtr& input,

0 commit comments