
Commit e90a516

Promote int to float for tanh operation (consistent with Pytorch) (#6166)

1 parent dbfe5c3 · commit e90a516
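For context, eager PyTorch already promotes integral inputs to floating point for tanh; this commit makes torch_xla match. A minimal eager-mode check, using plain PyTorch with no torch_xla involved:

import torch

# Eager PyTorch promotes integral tanh inputs to the default float dtype.
a = torch.randint(0, 10, (2, 2), dtype=torch.int32)
b = torch.tanh(a)
print(b.dtype)  # torch.float32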

File tree

4 files changed, 24 insertions(+), 2 deletions(-)


test/cpp/test_aten_xla_tensor_2.cpp

Lines changed: 14 additions & 0 deletions

@@ -2411,6 +2411,20 @@ TEST_F(AtenXlaTensorTest, TestTanh) {
   });
 }
 
+// In torch, tanh works with integer inputs. The same should be true for
+// torch_xla
+TEST_F(AtenXlaTensorTest, TestTanhWithInt) {
+  torch::Tensor a = torch::rand({2, 2});
+  torch::Tensor b = torch::tanh(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::tanh(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::tanh", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestClampMinMax) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Scalar min_val(0.311);
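A rough Python mirror of the new C++ test (a sketch, assuming the standard torch_xla xla_device() API). It feeds an integer tensor so the new promotion path is actually exercised, whereas the C++ test above builds its input with torch::rand, i.e. a float tensor:

import torch
import torch_xla.core.xla_model as xm

a = torch.randint(0, 10, (2, 2), dtype=torch.int32)
b = torch.tanh(a)  # eager reference result (float32)

xla_a = a.to(xm.xla_device())
xla_b = torch.tanh(xla_a)  # lowered through torch_xla's Tanh op

# Same tolerances as the C++ test's AllClose call.
assert torch.allclose(b, xla_b.cpu(), rtol=1e-3, atol=1e-5)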

test/test_core_aten_ops.py

Lines changed: 0 additions & 1 deletion

@@ -4412,7 +4412,6 @@ def test_aten_tanh_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs)
 
-  @unittest.skip
   def test_aten_tanh_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
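Removing @unittest.skip enables test_aten_tanh_2, presumably skipped until now because its int32 input was not supported. The eager half of the check can be reproduced standalone (run_export_and_compare additionally exports the op and compares against the XLA result):

import torch

# The newly enabled case: aten.tanh on an int32 tensor.
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
kwargs = dict()
out = torch.ops.aten.tanh(*args, **kwargs)
print(out.dtype)  # torch.float32 after int-to-float promotion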

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 5 additions & 0 deletions

@@ -713,6 +713,11 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
+                          /*device=*/nullptr);
+  }
   return ReturnOp(xla::Tanh(xla_input), loctx);
 }
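XLA's Tanh is only defined for floating-point and complex element types, so the lowering now converts integral operands to F32 before emitting the op. From Python, the effect should be visible in the result dtype on an XLA device (a sketch, assuming a device is available):

import torch
import torch_xla.core.xla_model as xm

x = torch.randint(0, 10, (2, 2), dtype=torch.int32).to(xm.xla_device())
y = torch.tanh(x)
print(y.dtype)  # expected: torch.float32, via the int -> F32 convert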

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 5 additions & 1 deletion

@@ -799,7 +799,11 @@ xla::Shape TakeOutputShape(const torch::lazy::Value& input,
 }
 
 xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }
 
 xla::Shape TrilOutputShape(const torch::lazy::Value& input) {
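The shape function must report the promoted element type as well; otherwise the lazy tensor's inferred metadata would still claim an integral dtype while the lowered computation produces F32. Eager PyTorch's metadata-only counterpart can be observed with a meta tensor (a sketch; meta tensors carry shape and dtype but no data):

import torch

x = torch.empty((2, 2), dtype=torch.int32, device="meta")
y = torch.tanh(x)
print(y.shape, y.dtype)  # torch.Size([2, 2]) torch.float32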
