Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4880,6 +4880,10 @@ TEST_F(AtenXlaTensorTest, TestScatterAddInPlace) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceSum) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand Down Expand Up @@ -4930,6 +4934,10 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceSumInPlace) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceProd) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand All @@ -4955,6 +4963,10 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceProd) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceProdInPlace) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand All @@ -4979,6 +4991,10 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceProdInPlace) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceMin) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand All @@ -5004,6 +5020,10 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMin) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceMinInPlace) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand All @@ -5028,6 +5048,10 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMinInPlace) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceMax) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand All @@ -5052,6 +5076,10 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMax) {
}

TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) {
GTEST_SKIP() << "Unrecognized `reduce` at "
"https://github.com/pytorch/xla/blob/"
"933dcc21c51676f72a41f2989f5bbba760a498c0/torch_xla/csrc/ops/"
"scatter_reduce.cpp#L42 after functionalization";
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
Expand Down
1 change: 0 additions & 1 deletion test/test_dynamic_shape_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def forward(self, x):
)
class TestDynamicShapeModels(unittest.TestCase):

@unittest.skip("Broke by functionalization")
def test_forward_pass_dynamic_input_correctness(self):
losses = []
for _ in range(2):
Expand Down
5 changes: 0 additions & 5 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,10 +749,6 @@ def test_masked_select_shape(self):
torch.masked_select(x, mask), 0)
self.assertEqual(x_dim0_shape.item(), 3)

@unittest.skip(
"Temporarily disable test. See https://github.com/pytorch/xla/issues/4501"
)
# @unittest.skip("Crash with dynamic shape")
def test_nonzero_cast(self):
t1 = torch.ones(5, 2, device=xm.xla_device())
# Result of the nonzero should be the index type. Currently
Expand Down Expand Up @@ -1481,7 +1477,6 @@ def test_scatter_add_bool(self):
xla_b.scatter_add_(0, xla_index, xla_a)
self.assertEqual(b, xla_b)

@unittest.skip("Fail with run_dynamic")
def test_squeeze_nonzero(self):

def test_fn(a):
Expand Down
12 changes: 6 additions & 6 deletions torch_xla/csrc/ops/dynamic_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ int64_t SizeEq::getDynamicValue() const {
std::string SizeEq::ToString() const { return "aten::size_eq"; }

SizeNe::SizeNe(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::ne")},
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_ne")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
Expand All @@ -169,10 +169,10 @@ int64_t SizeNe::getDynamicValue() const {
return dim_node_0->getDynamicValue() != dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeNe::ToString() const { return "aten::ne_size"; }
std::string SizeNe::ToString() const { return "aten::size_ne"; }

SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::ge")},
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_ge")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
Expand All @@ -191,10 +191,10 @@ int64_t SizeGe::getDynamicValue() const {
return dim_node_0->getDynamicValue() >= dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeGe::ToString() const { return "aten::ge_size"; }
std::string SizeGe::ToString() const { return "aten::size_ge"; }

SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::lt")},
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_lt")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
Expand All @@ -213,7 +213,7 @@ int64_t SizeLt::getDynamicValue() const {
return dim_node_0->getDynamicValue() < dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeLt::ToString() const { return "aten::lt_size"; }
std::string SizeLt::ToString() const { return "aten::size_lt"; }

SizeConstant::SizeConstant(int64_t val)
: Scalar(c10::Scalar{val},
Expand Down
8 changes: 5 additions & 3 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,13 @@ bool XLATensor::ShouldSyncIrNode() {
}

bool XLASymNodeImpl::is_bool() {
auto op = node()->op().op;
c10::Symbol op = node()->op().op;
// Reference:
// https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/symbolic_shapes.py#L403
if (op == at::aten::eq || op == at::aten::ne || op == at::aten::ge ||
op == at::aten::lt) {
if (op == c10::Symbol::fromQualString("aten::size_eq") ||
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alanwaketan I tried to use switch but c10::Symbol::fromQualString("aten::size_eq") are not constant so the build would fail.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, it wasn't like the enum before.

op == c10::Symbol::fromQualString("aten::size_ne") ||
op == c10::Symbol::fromQualString("aten::size_ge") ||
op == c10::Symbol::fromQualString("aten::size_lt")) {
return true;
}
return false;
Expand Down