From 3434785df4b62be58a28464ebeff14ba1bcdcbcf Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Fri, 15 Feb 2019 18:11:49 -0800
Subject: [PATCH] Added aten::sigmoid operations.

---
 test/cpp/test_aten_xla_tensor.cpp | 10 ++++++++++
 torch_xla/csrc/aten_xla_type.cpp  | 11 +++++++++++
 torch_xla/csrc/aten_xla_type.h    |  3 +++
 torch_xla/csrc/elementwise.cpp    |  7 +++++++
 torch_xla/csrc/elementwise.h      |  2 ++
 torch_xla/csrc/ops/ops.cpp        | 10 ++++++++++
 torch_xla/csrc/ops/ops.h          |  2 ++
 torch_xla/csrc/tensor.cpp         |  8 ++++++++
 torch_xla/csrc/tensor.h           |  3 +++
 torch_xla/csrc/translator.cpp     |  6 +-----
 10 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index ad2bda167e46..54e66bbe4a91 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -524,6 +524,16 @@ TEST_F(AtenXlaTensorTest, TestAbs) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSigmoid) {
+  at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::sigmoid(a);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::sigmoid(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestAddCMul) {
   at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 259b036fe64e..0956c4e7951f 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -495,6 +495,17 @@ at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim) const {
       XLATensor::softmax(bridge::GetXlaTensor(self), dim));
 }
 
+at::Tensor AtenXlaType::sigmoid(const at::Tensor& self) const {
+  return bridge::AtenFromXlaTensor(
+      XLATensor::sigmoid(bridge::GetXlaTensor(self)));
+}
+
+at::Tensor& AtenXlaType::sigmoid_(at::Tensor& self) const {
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  XLATensor::sigmoid_(self_tensor);
+  return self;
+}
+
 at::Tensor AtenXlaType::max_pool2d(const at::Tensor& self,
                                    at::IntList kernel_size, at::IntList stride,
                                    at::IntList padding, at::IntList dilation,
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index a60b0be0f818..d3450f658a37 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -173,6 +173,9 @@ class AtenXlaType : public AtenXlaTypeBase {
 
   at::Tensor softmax(const at::Tensor& self, int64_t dim) const override;
 
+  at::Tensor sigmoid(const at::Tensor& self) const override;
+  at::Tensor& sigmoid_(at::Tensor& self) const override;
+
   at::Tensor max_pool2d(const at::Tensor& self, at::IntList kernel_size,
                         at::IntList stride, at::IntList padding,
                         at::IntList dilation, bool ceil_mode) const override;
diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index 54fd92309057..2d3e51184766 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -91,4 +91,11 @@ xla::XlaOp BuildTypeAs(const torch::jit::Node* node,
   return xla::ConvertElementType(operand, target_type);
 }
 
+xla::XlaOp BuildSigmoid(const xla::XlaOp& input) {
+  xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp half =
+      XlaHelpers::ScalarValue(0.5, shape.element_type(), input.builder());
+  return half + half * xla::Tanh(half * input);
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h
index 49d08ac31418..b280758ecb4f 100644
--- a/torch_xla/csrc/elementwise.h
+++ b/torch_xla/csrc/elementwise.h
@@ -28,4 +28,6 @@ xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output,
 // Computes the rectified linear unit (replace negative elements with 0).
 xla::XlaOp BuildRelu(const xla::XlaOp& input);
 
+xla::XlaOp BuildSigmoid(const xla::XlaOp& input);
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 961d9a20f629..1b5e52034e34 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -90,6 +90,16 @@ NodePtr TransposeOp(const Value& input) {
                             output_shape, std::move(lower_fn));
 }
 
+NodePtr Sigmoid(const Value& input) {
+  auto lower_fn = [](const ir::Node& node,
+                     ir::LoweringContext* loctx) -> ir::XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOp(BuildSigmoid(xla_input), loctx);
+  };
+  return ir::ops::GenericOp(ir::OpKind(at::aten::sigmoid), ir::OpList{input},
+                            input.shape(), std::move(lower_fn));
+}
+
 NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
               c10::optional<at::Scalar> max) {
   const xla::Shape& input_shape = input.shape();
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 7d9b8ff1f707..29b74b6f23f9 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -81,6 +81,8 @@ NodePtr Sqrt(const Value& input);
 
 NodePtr Pow(const Value& input, const Value& exponent);
 
+NodePtr Sigmoid(const Value& input);
+
 NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
               c10::optional<at::Scalar> max);
 
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 4c616f3b684b..b503b23cf610 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -1089,6 +1089,14 @@ XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim) {
                 input.GetDevice());
 }
 
+XLATensor XLATensor::sigmoid(const XLATensor& input) {
+  return Create(ir::ops::Sigmoid(input.GetIrValue()), input.GetDevice());
+}
+
+void XLATensor::sigmoid_(XLATensor& input) {
+  input.SetIrValue(ir::ops::Sigmoid(input.GetIrValue()));
+}
+
 XLATensor XLATensor::nll_loss(const XLATensor& input, const XLATensor& target) {
   return Create(ir::ops::NllLossOp(input.GetIrValue(), target.GetIrValue()),
                 input.GetDevice());
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 9967981f06fd..387cee6964c3 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -196,6 +196,9 @@ class XLATensor {
 
   static XLATensor softmax(const XLATensor& input, xla::int64 dim);
 
+  static XLATensor sigmoid(const XLATensor& input);
+  static void sigmoid_(XLATensor& input);
+
   static XLATensor ones(tensorflow::gtl::ArraySlice<xla::int64> size,
                         const Device& device, at::ScalarType scalar_type);
   static XLATensor ones_like(const XLATensor& input, const Device& device,
diff --git a/torch_xla/csrc/translator.cpp b/torch_xla/csrc/translator.cpp
index f51907186e75..83cf775feaed 100644
--- a/torch_xla/csrc/translator.cpp
+++ b/torch_xla/csrc/translator.cpp
@@ -366,11 +366,7 @@ void TranslateSigmoid(const torch::jit::Node* node, ComputationContext* cctx,
                       xla::XlaBuilder* b) {
   XLA_CHECK_EQ(node->inputs().size(), 1);
   xla::XlaOp xla_input = cctx->OpForInput(node, 0);
-  xla::Shape xla_input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
-  xla::XlaOp half =
-      XlaHelpers::ScalarValue(0.5, xla_input_shape.element_type(), b);
-  xla::XlaOp xla_output = half + half * xla::Tanh(half * xla_input);
-  cctx->AddNodeOp(node, xla_output);
+  cctx->AddNodeOp(node, BuildSigmoid(xla_input));
 }
 
 void TranslateRelu(const torch::jit::Node* node, ComputationContext* cctx,
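
Note on the lowering: BuildSigmoid evaluates the logistic function through the
hyperbolic-tangent identity

  sigmoid(x) = 1 / (1 + exp(-x)) = 0.5 + 0.5 * tanh(0.5 * x),

so the op reuses XLA's Tanh primitive rather than emitting Exp/Div, and the
same helper is now shared by the IR lowering (ops.cpp) and the JIT translator
(translator.cpp). The snippet below is a minimal standalone sketch -- not part
of the patch, plain C++ with no XLA dependency, file name hypothetical -- that
numerically checks the identity the lowering relies on:

  // check_sigmoid_identity.cc -- illustration only, not from the patch.
  #include <cassert>
  #include <cmath>
  #include <cstdio>

  int main() {
    for (float x = -10.0f; x <= 10.0f; x += 0.25f) {
      // Textbook definition of the logistic sigmoid.
      float direct = 1.0f / (1.0f + std::exp(-x));
      // The tanh-based form used by BuildSigmoid.
      float via_tanh = 0.5f + 0.5f * std::tanh(0.5f * x);
      assert(std::fabs(direct - via_tanh) < 1e-6f);
    }
    std::printf("sigmoid/tanh identity holds on [-10, 10]\n");
    return 0;
  }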