10 changes: 10 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -524,6 +524,16 @@ TEST_F(AtenXlaTensorTest, TestAbs) {
  });
}

TEST_F(AtenXlaTensorTest, TestSigmoid) {
  at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
  at::Tensor b = at::sigmoid(a);
  ForEachDevice([&](const Device& device) {
    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
    at::Tensor xla_b = at::sigmoid(xla_a);
    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
}

TEST_F(AtenXlaTensorTest, TestAddCMul) {
  at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
  at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
11 changes: 11 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -495,6 +495,17 @@ at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim) const {
      XLATensor::softmax(bridge::GetXlaTensor(self), dim));
}

at::Tensor AtenXlaType::sigmoid(const at::Tensor& self) const {
  return bridge::AtenFromXlaTensor(
      XLATensor::sigmoid(bridge::GetXlaTensor(self)));
}

at::Tensor& AtenXlaType::sigmoid_(at::Tensor& self) const {
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  XLATensor::sigmoid_(self_tensor);
  return self;
}

at::Tensor AtenXlaType::max_pool2d(const at::Tensor& self,
at::IntList kernel_size, at::IntList stride,
at::IntList padding, at::IntList dilation,
3 changes: 3 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -173,6 +173,9 @@ class AtenXlaType : public AtenXlaTypeBase {

at::Tensor softmax(const at::Tensor& self, int64_t dim) const override;

at::Tensor sigmoid(const at::Tensor& self) const override;
at::Tensor& sigmoid_(at::Tensor& self) const override;

at::Tensor max_pool2d(const at::Tensor& self, at::IntList kernel_size,
at::IntList stride, at::IntList padding,
at::IntList dilation, bool ceil_mode) const override;
7 changes: 7 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -91,4 +91,11 @@ xla::XlaOp BuildTypeAs(const torch::jit::Node* node,
  return xla::ConvertElementType(operand, target_type);
}

xla::XlaOp BuildSigmoid(const xla::XlaOp& input) {
  xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp half =
      XlaHelpers::ScalarValue<float>(0.5, shape.element_type(), input.builder());
  return half + half * xla::Tanh(half * input);
}

} // namespace torch_xla
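A short note on the lowering in BuildSigmoid above: rather than forming 1 / (1 + e^{-x}) directly, it uses the tanh form of the logistic function,

    \sigma(x) = \frac{1}{1 + e^{-x}} = \tfrac{1}{2} + \tfrac{1}{2} \tanh\!\left(\tfrac{x}{2}\right),

which is exactly the half + half * xla::Tanh(half * input) expression, with half being a scalar 0.5 of the input's element type.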
2 changes: 2 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -28,4 +28,6 @@ xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output,
// Computes the rectified linear unit (replace negative elements with 0).
xla::XlaOp BuildRelu(const xla::XlaOp& input);

xla::XlaOp BuildSigmoid(const xla::XlaOp& input);

} // namespace torch_xla
10 changes: 10 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -90,6 +90,16 @@ NodePtr TransposeOp(const Value& input) {
output_shape, std::move(lower_fn));
}

NodePtr Sigmoid(const Value& input) {
  auto lower_fn = [](const ir::Node& node,
                     ir::LoweringContext* loctx) -> ir::XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    return node.ReturnOp(BuildSigmoid(xla_input), loctx);
  };
  return ir::ops::GenericOp(ir::OpKind(at::aten::sigmoid), ir::OpList{input},
                            input.shape(), std::move(lower_fn));
}

Review thread on this hunk:
Contributor: Why not reuse TranslateSigmoid from translator.cpp?
Collaborator (Author): Done

NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
c10::optional<at::Scalar> max) {
const xla::Shape& input_shape = input.shape();
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -81,6 +81,8 @@ NodePtr Sqrt(const Value& input);

NodePtr Pow(const Value& input, const Value& exponent);

NodePtr Sigmoid(const Value& input);

NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
c10::optional<at::Scalar> max);

8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -1089,6 +1089,14 @@ XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim) {
input.GetDevice());
}

XLATensor XLATensor::sigmoid(const XLATensor& input) {
  return Create(ir::ops::Sigmoid(input.GetIrValue()), input.GetDevice());
}

void XLATensor::sigmoid_(XLATensor& input) {
  input.SetIrValue(ir::ops::Sigmoid(input.GetIrValue()));
}

XLATensor XLATensor::nll_loss(const XLATensor& input, const XLATensor& target) {
return Create(ir::ops::NllLossOp(input.GetIrValue(), target.GetIrValue()),
input.GetDevice());
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -196,6 +196,9 @@ class XLATensor {

static XLATensor softmax(const XLATensor& input, xla::int64 dim);

static XLATensor sigmoid(const XLATensor& input);
static void sigmoid_(XLATensor& input);

static XLATensor ones(tensorflow::gtl::ArraySlice<const xla::int64> size,
const Device& device, at::ScalarType scalar_type);
static XLATensor ones_like(const XLATensor& input, const Device& device,
6 changes: 1 addition & 5 deletions torch_xla/csrc/translator.cpp
@@ -366,11 +366,7 @@ void TranslateSigmoid(const torch::jit::Node* node, ComputationContext* cctx,
                      xla::XlaBuilder* b) {
  XLA_CHECK_EQ(node->inputs().size(), 1);
  xla::XlaOp xla_input = cctx->OpForInput(node, 0);
-  xla::Shape xla_input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
-  xla::XlaOp half =
-      XlaHelpers::ScalarValue<float>(0.5, xla_input_shape.element_type(), b);
-  xla::XlaOp xla_output = half + half * xla::Tanh(half * xla_input);
-  cctx->AddNodeOp(node, xla_output);
+  cctx->AddNodeOp(node, BuildSigmoid(xla_input));
}

void TranslateRelu(const torch::jit::Node* node, ComputationContext* cctx,