13 changes: 13 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -3012,6 +3012,19 @@ TEST_F(AtenXlaTensorTest, TestBitwiseNotInPlace) {
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSgn) {
  torch::Tensor a =
      torch::randn({2, 2}, torch::TensorOptions(torch::kComplexFloat)) * 100.0;
  torch::Tensor b = torch::sgn(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = torch::sgn(xla_a);
    AllClose(b, xla_b);
  });
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::sgn", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSign) {
  torch::Tensor a =
      torch::randn({2, 2}, torch::TensorOptions(torch::kFloat)) * 100.0;
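For context on what the new test exercises: torch.sgn is the complex-aware counterpart of torch.sign. For complex z it returns z / |z| (and 0 where z == 0); for real inputs it behaves like the ordinary sign. A rough reference sketch of that semantics in Python (illustrative only, not code from this PR):

import torch

def sgn_reference(t: torch.Tensor) -> torch.Tensor:
    # Complex: z / |z|, keeping 0 where z == 0. Real: ordinary sign.
    if t.is_complex():
        mag = t.abs()
        return torch.where(mag == 0, torch.zeros_like(t), t / mag)
    return torch.sign(t)

z = torch.tensor([3 + 4j, 0 + 0j, -2 + 0j])
print(sgn_reference(z))  # roughly [0.6+0.8j, 0+0j, -1+0j]
print(torch.sgn(z))      # expected to match

The ExpectCounterChanged("xla::sgn") / ExpectCounterNotChanged("aten::.*") pair in the test then confirms the op was lowered by XLA rather than falling back to the CPU (aten::) path.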
58 changes: 45 additions & 13 deletions test/test_operations.py
@@ -281,19 +281,18 @@ def assertTensorsEqual(a, b):
b = b.to(torch.int)

diff = a - b
if a.is_floating_point():
# check that NaNs are in the same locations
nan_mask = torch.isnan(a)
self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
diff[nan_mask] = 0
# inf check if allow_inf=True
if allow_inf:
inf_mask = torch.isinf(a)
inf_sign = inf_mask.sign()
self.assertTrue(
torch.equal(inf_sign,
torch.isinf(b).sign()), message)
diff[inf_mask] = 0
# check that NaNs are in the same locations
nan_mask = torch.isnan(a)
self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
diff[nan_mask] = 0
# inf check if allow_inf=True
if allow_inf:
inf_mask = torch.isinf(a)
inf_sign = inf_mask.sign()
self.assertTrue(
torch.equal(inf_sign,
torch.isinf(b).sign()), message)
diff[inf_mask] = 0
# TODO: implement abs on CharTensor (int8)
if diff.is_signed() and diff.dtype != torch.int8:
diff = diff.abs()
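The hoisting above matters because complex tensors report is_floating_point() == False, so the old guard silently skipped the NaN/inf masking for exactly the dtypes the new test_sgn pushes through assertEqual. torch.isnan and torch.isinf do accept complex inputs (an entry is flagged when either component is non-finite), which is what makes the unguarded version work. A quick illustration, assuming current PyTorch semantics:

import torch

z = torch.tensor([complex(float('nan'), 0.0), complex(1.0, float('inf')), 1 + 2j])
print(z.is_floating_point())  # False: complex dtypes do not count as floating point
print(torch.isnan(z))         # tensor([ True, False, False])
print(torch.isinf(z))         # tensor([False,  True, False])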
@@ -991,6 +990,39 @@ def test_max_broadcast(self):
    xla_c = torch.max(xla_a, xla_b)
    self.assertEqual(c.data, xla_c.data.cpu())

  def test_sgn(self):
    xla_device = xm.xla_device()
    t = torch.randn(2, 3, dtype=torch.cfloat)
    # Generate inf+infj
    t[0][0].real.div_(0)
    t[0][0].imag.div_(0)
    # Generate nan+nanj
    t[0][1] = 0
    t[0][1].real.div_(0)
    t[0][1].imag.div_(0)
    # Generate 0+0j
    t[1][0] = 0
    # Generate inf+0j
    t[1][1].real.div_(0)
    t[1][1] = t[1][1].real.abs()
    # Generate -inf+0j
    t[1][2].real.div_(0)
    t[1][2] = t[1][1].real.abs() * -1
    a = t.sgn()
    xla_a = t.to(xla_device).sgn()
    self.assertEqual(a.data, xla_a.data.cpu())

    t = torch.randn(2, 3, dtype=torch.float32)
    t[0][0].div_(0)
    t[0][1] = 0
    t[0][1].div_(0)
    t[1][0] = 0
    t[1][2].div_(0)
    t[1][2] = t[1][1].abs() * -1
    a = t.sgn()
    xla_a = t.to(xla_device).sgn()
    self.assertEqual(a.data, xla_a.data.cpu())

  def test_index_put(self):
    xla_device = xm.xla_device()
    a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
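A note on how test_sgn builds its special values: .real and .imag of a complex tensor are writable views, so an in-place div_(0) on them turns a non-zero entry into inf and a zero entry into nan without leaving the complex dtype. A tiny standalone illustration (assuming PyTorch's view semantics for .real/.imag, which the test itself relies on):

import torch

t = torch.ones(2, dtype=torch.cfloat)
t[0].real.div_(0)  # 1 / 0 -> inf, written back through the view: t[0] becomes inf+0j
t[1] = 0
t[1].real.div_(0)  # 0 / 0 -> nan: t[1] becomes nan+0j
print(t)           # tensor([inf+0.j, nan+0.j])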
2 changes: 1 addition & 1 deletion third_party/tensorflow
Submodule tensorflow updated 4671 files
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2847,6 +2847,11 @@ at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output)));
}

at::Tensor XLANativeFunctions::sgn(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(XLATensor::sgn(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::sign(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(XLATensor::sign(bridge::GetXlaTensor(self)));
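XLA_FN_COUNTER("xla::") is what bumps the xla::sgn counter asserted in the C++ test, so the same dispatch can be sanity-checked from Python through the metrics module. A minimal sketch, assuming a working torch_xla install (device setup and metric helper names as of this era of the codebase):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
z = (torch.randn(4, dtype=torch.cfloat) * 100).to(device)
print(z.sgn().cpu())                  # forces execution through the XLA lowering
print(met.counter_value('xla::sgn'))  # expected >= 1; a CPU fallback would register aten::sgn instead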
20 changes: 20 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -188,6 +188,26 @@ xla::XlaOp BuildReciprocal(xla::XlaOp input) {
  return xla::Div(one, input);
}

xla::XlaOp BuildSgn(xla::XlaOp input) {
  xla::XlaOp num_input = ConvertToNumeric(input);
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(num_input);
  if (!(shape.element_type() == xla::PrimitiveType::C64 ||
        shape.element_type() == xla::PrimitiveType::C128)) {
    return BuildSign(input);
  }
  const xla::Shape& shape_real = XlaHelpers::ShapeOfXlaOp(xla::Real(num_input));
  xla::XlaOp nan_real =
      xla::NanValue(num_input.builder(), shape_real.element_type());
  xla::XlaOp nan_complex = xla::Complex(nan_real, nan_real);
  xla::XlaOp sign = xla::Sign(num_input);
  xla::XlaOp is_finite =
      xla::And(xla::IsFinite(xla::Real(sign)), xla::IsFinite(xla::Imag(sign)));
  // Replace non-finite tensor values (e.g. Inf, NaN) with NaN
  return xla::Select(
      is_finite, sign,
      MaybeConvertTo(nan_complex, XlaHelpers::TypeOfXlaOp(sign)));
}

xla::XlaOp BuildSign(xla::XlaOp input) {
  xla::XlaOp num_input = ConvertToNumeric(input);
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(num_input);
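In words: BuildSgn keeps the existing BuildSign path for non-complex inputs; for C64/C128 it takes xla::Sign (which for complex operands computes z / |z|) and then selects nan+nanj wherever either component of that result is non-finite, so inf and nan inputs come out as NaN, matching the eager results the new tests compare against. A rough Python mirror of the select logic, for intuition only (the names below are illustrative, not part of the PR):

import torch

def build_sgn_like(z: torch.Tensor) -> torch.Tensor:
    # "sign" step: z / |z|, keeping 0 where z == 0
    mag = z.abs()
    s = torch.where(mag == 0, torch.zeros_like(z), z / mag)
    # "select" step: any non-finite component collapses to nan+nanj
    finite = torch.isfinite(s.real) & torch.isfinite(s.imag)
    nan_c = torch.complex(torch.tensor(float('nan')), torch.tensor(float('nan'))).to(s.dtype)
    return torch.where(finite, s, nan_c)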
4 changes: 4 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -52,6 +52,10 @@ xla::XlaOp BuildSigmoid(xla::XlaOp input);
// Reciprocal(x) = 1 / x
xla::XlaOp BuildReciprocal(xla::XlaOp input);

// Computes the sgn of the complex input.
// If the input magnitude is 0 the result is 0, otherwise input / |input|.
xla::XlaOp BuildSgn(xla::XlaOp input);

// Computes the sign of the input.
// If x is NaN then 0, otherwise the actual sign
xla::XlaOp BuildSign(xla::XlaOp input);
9 changes: 9 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -124,6 +124,15 @@ NodePtr ReciprocalOp(const Value& input) {
                   std::move(lower_fn));
}

NodePtr SgnOp(const Value& input) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    return node.ReturnOp(BuildSgn(xla_input), loctx);
  };
  return GenericOp(OpKind(at::aten::sgn), {input}, input.shape(),
                   std::move(lower_fn));
}

NodePtr SignOp(const Value& input) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -75,6 +75,8 @@ NodePtr Tanh(const Value& input);

NodePtr Neg(const Value& input);

NodePtr SgnOp(const Value& input);

NodePtr SignOp(const Value& input);

NodePtr Abs(const Value& input);
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -953,6 +953,8 @@ class XLATensor {
  static XLATensor sigmoid_backward(const XLATensor& grad_output,
                                    const XLATensor& output);

  static XLATensor sgn(const XLATensor& input);

  static XLATensor sign(const XLATensor& input);

  static XLATensor sin(const XLATensor& input);
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2245,6 +2245,10 @@ XLATensor XLATensor::sigmoid_backward(const XLATensor& grad_output,
      ir::ops::SigmoidBackward(grad_output.GetIrValue(), output.GetIrValue()));
}

XLATensor XLATensor::sgn(const XLATensor& input) {
  return input.CreateFrom(ir::ops::SgnOp(input.GetIrValue()));
}

XLATensor XLATensor::sign(const XLATensor& input) {
  return input.CreateFrom(ir::ops::SignOp(input.GetIrValue()));
}
1 change: 1 addition & 0 deletions xla_native_functions.yaml
@@ -231,6 +231,7 @@ supported:
- cholesky
- qr
- erfinv
- sgn
- sign
- atan2
- fmod.Scalar