From f722b5f1989b1d287a3aa3d5db15f41d48c6b2d9 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Wed, 20 Feb 2019 09:04:19 -0800 Subject: [PATCH] Added aten::fmod operation. --- test/cpp/test_aten_xla_tensor.cpp | 10 ++++++++++ torch_xla/csrc/aten_xla_type.cpp | 18 ++++++++++++++++++ torch_xla/csrc/aten_xla_type.h | 5 +++++ torch_xla/csrc/ops/ops.cpp | 3 ++- torch_xla/csrc/ops/ops.h | 2 ++ torch_xla/csrc/tensor.cpp | 14 ++++++++++++++ torch_xla/csrc/tensor.h | 6 +++++- 7 files changed, 56 insertions(+), 2 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 722244b3ad04..8d56abf80ca9 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -1033,6 +1033,16 @@ TEST_F(AtenXlaTensorTest, TestPow) { }); } +TEST_F(AtenXlaTensorTest, TestFmod) { + at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat)) * 100.0; + at::Tensor b = at::fmod(a, 2.0); + ForEachDevice([&](const Device& device) { + at::Tensor xla_a = bridge::CreateXlaTensor(a, device); + at::Tensor xla_b = at::fmod(xla_a, 2.0); + AllClose(b, xla_b); + }); +} + TEST_F(AtenXlaTensorTest, TestWhere) { at::Tensor a = at::rand({3, 3}, at::TensorOptions(at::kFloat)); at::Tensor b = at::rand({3, 3}, at::TensorOptions(at::kFloat)); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c4826d9a5560..915374fd9d49 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -304,6 +304,24 @@ at::Tensor& AtenXlaType::div_(at::Tensor& self, const at::Tensor& other) const { return self; } +at::Tensor AtenXlaType::fmod(const at::Tensor& self, + const at::Tensor& other) const { + return bridge::AtenFromXlaTensor( + XLATensor::fmod(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); +} + +at::Tensor& AtenXlaType::fmod_(at::Tensor& self, at::Scalar other) const { + XLATensor self_tensor = bridge::GetXlaTensor(self); + XLATensor::fmod_(self_tensor, other); + return self; +} + +at::Tensor& AtenXlaType::fmod_(at::Tensor& self, const at::Tensor& other) const { + XLATensor self_tensor = bridge::GetXlaTensor(self); + XLATensor::fmod_(self_tensor, bridge::GetXlaTensor(other)); + return self; +} + at::Tensor AtenXlaType::ne(const at::Tensor& self, at::Scalar other) const { return bridge::AtenFromXlaTensor( XLATensor::ne(bridge::GetXlaTensor(self), other)); diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index 94def9c772f9..d459a4eec5a5 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -99,6 +99,11 @@ class AtenXlaType : public AtenXlaTypeBase { const at::Tensor& other) const override; at::Tensor& div_(at::Tensor& self, const at::Tensor& other) const override; + at::Tensor fmod(const at::Tensor& self, + const at::Tensor& other) const override; + at::Tensor& fmod_(at::Tensor& self, at::Scalar other) const override; + at::Tensor& fmod_(at::Tensor& self, const at::Tensor& other) const override; + at::Tensor ne(const at::Tensor& self, at::Scalar other) const override; at::Tensor ne(const at::Tensor& self, const at::Tensor& other) const override; diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 6aca7360e209..ae350111624f 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -64,6 +64,7 @@ PTXLA_UNARY_OP(Floor, at::aten::floor, xla::Floor); PTXLA_BINARY_OP(Min, at::aten::min, xla::Min); PTXLA_BINARY_OP(Max, at::aten::max, xla::Max); PTXLA_BINARY_OP(Pow, at::aten::pow, xla::Pow); +PTXLA_BINARY_OP(Fmod, at::aten::fmod, xla::Rem); NodePtr ReciprocalOp(const Value& input) { auto lower_fn = [](const ir::Node& node, @@ -73,7 +74,7 @@ NodePtr ReciprocalOp(const Value& input) { }; return ir::ops::GenericOp(ir::OpKind(at::aten::reciprocal), ir::OpList{input}, input.shape(), std::move(lower_fn)); -} +} NodePtr ReluOp(const Value& input) { auto lower_fn = [](const ir::Node& node, diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 4c7d82b31daa..c47de46e35df 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -95,6 +95,8 @@ NodePtr ReciprocalOp(const Value& input); NodePtr Pow(const Value& input, const Value& exponent); +NodePtr Fmod(const Value& dividend, const Value& divisor); + NodePtr Sigmoid(const Value& input); NodePtr Clamp(const Value& input, c10::optional min, diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index cf46e527f8f7..e6f58f6045d2 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -591,6 +591,20 @@ void XLATensor::div_(XLATensor& input, const at::Scalar& other) { input.SetIrValue(input.GetIrValue() / constant); } +XLATensor XLATensor::fmod(const XLATensor& input, const XLATensor& other) { + return Create(ir::ops::Fmod(input.GetIrValue(), other.GetIrValue()), + input.GetDevice()); +} + +void XLATensor::fmod_(XLATensor& input, at::Scalar other) { + ir::NodePtr constant = ir::ops::ScalarOp(other, input.shape()); + input.SetIrValue(ir::ops::Fmod(input.GetIrValue(), constant)); +} + +void XLATensor::fmod_(XLATensor& input, const XLATensor& other) { + input.SetIrValue(ir::ops::Fmod(input.GetIrValue(), other.GetIrValue())); +} + void XLATensor::zero_(XLATensor& input) { input.SetIrValue(ir::ops::ScalarOp(0.0, input.shape())); } diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index bf419a39a21c..de0cb0bff009 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -119,6 +119,10 @@ class XLATensor { static void div_(XLATensor& input, const XLATensor& other); static void div_(XLATensor& input, const at::Scalar& other); + static XLATensor fmod(const XLATensor& input, const XLATensor& other); + static void fmod_(XLATensor& input, at::Scalar other); + static void fmod_(XLATensor& input, const XLATensor& other); + static void zero_(XLATensor& input); // Additional operations which are part of the PyTorch Tensor functionality. @@ -403,7 +407,7 @@ class XLATensor { const XLATensor& input, tensorflow::gtl::ArraySlice repeats); - static std::vector split(const XLATensor& self, + static std::vector split(const XLATensor& input, xla::int64 split_size, xla::int64 dim); // Squeeze out all trivial (size 1) dimensions.