Commit 9e4db96

Reapply "Fix some more core aten ops (#6342)" (#6377) (#6387)

Parent: c70e4cc
6 files changed: +44, -14 lines

codegen/xla_native_functions.yaml
Lines changed: 1 addition & 1 deletion

@@ -77,6 +77,7 @@ full_codegen:
   - rsqrt
   - selu
   - sgn
+  - sigmoid
   - sign
   - silu
   - silu_backward
@@ -304,7 +305,6 @@ supported:
   - select_scatter
   - selu_
   - set_.source_Tensor
-  - sigmoid
   - sigmoid_backward
   - slice_copy.Tensor
   - slice_scatter
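
Moving sigmoid from the hand-written supported list into full_codegen means the op's ATen binding and IR node are now emitted by the code generator; only the shape function and lowering further down are written by hand. A minimal sanity check, assuming a working torch_xla install (the script below is an illustrative sketch, not part of the commit):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.randn(4, 4, device=device)
    y = torch.sigmoid(x)  # traced through the codegen'd Sigmoid IR node
    xm.mark_step()        # compile and run the pending lazy graph
    print(y.cpu())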

test/test_core_aten_ops.py
Lines changed: 16 additions & 6 deletions

@@ -1904,11 +1904,17 @@ def test_aten_gelu_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)

-  @unittest.skip
   def test_aten_gelu_1(self):
     args = (torch.randn((10, 10)).to(torch.float16),)
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.gelu,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )

   def test_aten_glu_0(self):
     args = (
@@ -3082,7 +3088,6 @@ def test_aten_native_group_norm_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)

-  @unittest.skip
   def test_aten_native_group_norm_1(self):
     args = (
         torch.randn((1, 3, 2, 10)).to(torch.float16),
@@ -3095,7 +3100,14 @@ def test_aten_native_group_norm_1(self):
         0.0,
     )
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)
+    run_export_and_compare(
+        self,
+        torch.ops.aten.native_group_norm,
+        args,
+        kwargs,
+        rtol=0.001,
+        atol=0.01,
+    )

   def test_aten_native_layer_norm_0(self):
     args = (
@@ -3411,7 +3423,6 @@ def test_aten_reciprocal_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)

-  @unittest.skip
   def test_aten_reciprocal_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
@@ -4009,7 +4020,6 @@ def test_aten_sigmoid_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs)

-  @unittest.skip
   def test_aten_sigmoid_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
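
The float16 variants of gelu and native_group_norm are un-skipped but compared under relaxed tolerances, since half precision keeps only about three significant decimal digits. The run_export_and_compare helper is defined in test_core_aten_ops.py; the sketch below shows the kind of check it performs with these rtol/atol arguments (an assumption about its internals, not its exact code):

    import torch

    def close_enough(actual, expected, rtol=0.001, atol=0.01):
      # float16 rounding makes exact comparison of reduction-heavy ops
      # (gelu, group_norm) too strict; compare within tolerances instead.
      return torch.allclose(
          actual.to(torch.float32), expected.to(torch.float32),
          rtol=rtol, atol=atol)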

torch_xla/csrc/aten_xla_type.cpp
Lines changed: 0 additions & 6 deletions

@@ -2791,12 +2791,6 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self,
   return self;
 }

-at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
-  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::sigmoid(bridge::GetXlaTensor(self)));
-}
-
 at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
                                                 const at::Tensor& output) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
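
With sigmoid handled by full_codegen, the code generator emits the XLANativeFunctions::sigmoid binding itself, so the hand-written version has to go; leaving it in place would define the same override twice.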

torch_xla/csrc/ops/ops_lower_fn.cpp
Lines changed: 12 additions & 0 deletions

@@ -684,6 +684,10 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {

 torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
+  }
   return ReturnOp(BuildReciprocal(xla_input), loctx);
 }

@@ -726,6 +730,14 @@ torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildSgn(xla_input), loctx);
 }

+torch_xla::XlaOpVector Sigmoid::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
+    xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32);
+  }
+  return ReturnOp(xla::Logistic(xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildSign(xla_input), loctx);
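
Both lowerings promote integral inputs to F32 before computing: BuildReciprocal and xla::Logistic are floating-point computations, and the promotion mirrors eager PyTorch, where these unary ops return the default float dtype for integer inputs. That eager behavior, which the re-enabled int32 tests exercise, is easy to confirm:

    import torch

    x = torch.randint(0, 10, (3,), dtype=torch.int32)
    print(torch.sigmoid(x).dtype)     # torch.float32
    print(torch.reciprocal(x).dtype)  # torch.float32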

torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Lines changed: 13 additions & 1 deletion

@@ -762,7 +762,11 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
 }

 xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
-  return GetXlaShape(input);
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
 }

 xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
@@ -804,6 +808,14 @@ xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

+xla::Shape SigmoidOutputShape(const torch::lazy::Value& input) {
+  xla::Shape result_shape = GetXlaShape(input);
+  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
+    result_shape.set_element_type(xla::PrimitiveType::F32);
+  }
+  return result_shape;
+}
+
 xla::Shape SignOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
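
These shape functions run while the lazy graph is being traced, before anything is compiled, so they must report the same F32 element type the lowerings above actually emit; keeping the integral element type here would leave the traced tensor's metadata out of sync with the compiled result.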

torch_xla/csrc/ops/ops_xla_shape_fn.h
Lines changed: 2 additions & 0 deletions

@@ -248,6 +248,8 @@ xla::Shape SeluOutputShape(const torch::lazy::Value& input);

 xla::Shape SgnOutputShape(const torch::lazy::Value& input);

+xla::Shape SigmoidOutputShape(const torch::lazy::Value& input);
+
 xla::Shape SignOutputShape(const torch::lazy::Value& input);

 xla::Shape SiluOutputShape(const torch::lazy::Value& input);
