diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 820754f9a06b..f8c42209b026 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -4319,6 +4319,19 @@ TEST_F(AtenXlaTensorTest, TestLogSoftmaxBackward) { } } +TEST_F(AtenXlaTensorTest, TestSoftmaxBackward) { + for (int dim = -4; dim < 4; ++dim) { + auto testfn = [&](const std::vector& inputs) -> at::Tensor { + return at::softmax(inputs[0], dim); + }; + + ForEachDevice([&](const Device& device) { + TestBackward({at::rand({5, 3, 4, 2}, at::TensorOptions(at::kFloat))}, + device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + }); + } +} + TEST_F(AtenXlaTensorTest, TestReluBackward) { auto testfn = [&](const std::vector& inputs) -> at::Tensor { return at::relu(inputs[0]); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d91d4210f3c1..9094282face4 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1769,6 +1769,14 @@ at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim) const { XLATensor::softmax(bridge::GetXlaTensor(self), dim)); } +at::Tensor AtenXlaType::_softmax_backward_data(const at::Tensor& grad_output, + const at::Tensor& output, + int64_t dim, + const at::Tensor& self) const { + return bridge::AtenFromXlaTensor(XLATensor::softmax_backward( + bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim)); +} + at::Tensor AtenXlaType::sigmoid(const at::Tensor& self) const { return bridge::AtenFromXlaTensor( XLATensor::sigmoid(bridge::GetXlaTensor(self))); diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index d154a1b1ea19..a7174a41b408 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -555,6 +555,9 @@ class AtenXlaType : public AtenXlaTypeBase { bool half_to_float) const override; at::Tensor softmax(const at::Tensor& self, int64_t dim) const override; + at::Tensor _softmax_backward_data(const at::Tensor& grad_output, + const at::Tensor& output, int64_t dim, + const at::Tensor& self) const override; at::Tensor sigmoid(const at::Tensor& self) const override; at::Tensor& sigmoid_(at::Tensor& self) const override; diff --git a/torch_xla/csrc/ops/log_softmax_backward.cpp b/torch_xla/csrc/ops/log_softmax_backward.cpp index 20fb3993962c..3d3e1edc19cb 100644 --- a/torch_xla/csrc/ops/log_softmax_backward.cpp +++ b/torch_xla/csrc/ops/log_softmax_backward.cpp @@ -8,28 +8,11 @@ namespace torch_xla { namespace ir { namespace ops { -namespace { - -xla::Shape NodeOutputShape(const Value& grad_output, const Value& output, - xla::int64 dim) { - auto lower_for_shape_fn = - [dim](tensorflow::gtl::ArraySlice operands) - -> xla::XlaOp { - XLA_CHECK_EQ(operands.size(), 2) - << "Unexpected number of operands: " << operands.size(); - return BuildLogSoftmaxGrad(/*grad_output=*/operands[0], - /*output=*/operands[1], dim); - }; - return InferOutputShape({grad_output.shape(), output.shape()}, - lower_for_shape_fn); -} - -} // namespace LogSoftmaxBackward::LogSoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim) : Node(ir::OpKind(at::aten::_log_softmax_backward_data), - {grad_output, output}, NodeOutputShape(grad_output, output, dim), + {grad_output, output}, grad_output.shape(), /*num_outputs=*/1, xla::util::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index f698f871102c..2dd335693cae 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -16,7 +16,9 @@ #include "torch_xla/csrc/ops/arithmetic_ir_ops.h" #include "torch_xla/csrc/ops/constant.h" #include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/ops/log_softmax_backward.h" #include "torch_xla/csrc/ops/permute.h" +#include "torch_xla/csrc/ops/softmax_backward.h" #include "torch_xla/csrc/ops/sum.h" #include "torch_xla/csrc/pooling.h" #include "torch_xla/csrc/tensor_util.h" @@ -138,10 +140,9 @@ NodePtr ReluOp(const Value& input) { } NodePtr TransposeOp(const Value& input, xla::int64 dim0, xla::int64 dim1) { - return ir::MakeNode( - input, - XlaHelpers::MakeTransposePermutation(/*dim0=*/dim0, /*dim1=*/dim1, - /*rank=*/input.shape().rank())); + return ir::MakeNode(input, XlaHelpers::MakeTransposePermutation( + /*dim0=*/dim0, /*dim1=*/dim1, + /*rank=*/input.shape().rank())); } NodePtr Sigmoid(const Value& input) { @@ -153,6 +154,20 @@ NodePtr Sigmoid(const Value& input) { std::move(lower_fn)); } +NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output, + xla::int64 dim) { + return ir::MakeNode( + grad_output, output, + XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output.shape().rank())); +} + +NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output, + xla::int64 dim) { + return ir::MakeNode( + grad_output, output, + XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output.shape().rank())); +} + NodePtr Clamp(const Value& input, c10::optional min, c10::optional max) { const xla::Shape& input_shape = input.shape(); diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index f1a5820c994d..ec8c4afe80a7 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -107,6 +107,12 @@ NodePtr TransposeOp(const Value& input, xla::int64 dim0, xla::int64 dim1); NodePtr Sigmoid(const Value& input); +NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output, + xla::int64 dim); + +NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output, + xla::int64 dim); + NodePtr Clamp(const Value& input, c10::optional min, c10::optional max); diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp new file mode 100644 index 000000000000..708f11b3e1ee --- /dev/null +++ b/torch_xla/csrc/ops/softmax_backward.cpp @@ -0,0 +1,35 @@ +#include "torch_xla/csrc/ops/softmax_backward.h" +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/softmax_builder.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +SoftmaxBackward::SoftmaxBackward(const Value& grad_output, const Value& output, + xla::int64 dim) + : Node(ir::OpKind(at::aten::_softmax_backward_data), {grad_output, output}, + grad_output.shape(), + /*num_outputs=*/1, xla::util::MHash(dim)), + dim_(dim) {} + +XlaOpVector SoftmaxBackward::Lower(LoweringContext* loctx) const { + xla::XlaOp grad_output = loctx->GetOutputOp(operand(0)); + xla::XlaOp output = loctx->GetOutputOp(operand(1)); + xla::XlaOp grad_input = + BuildSoftmaxGrad(/*grad_output=*/grad_output, /*output=*/output, dim_); + return ReturnOp(grad_input, loctx); +} + +std::string SoftmaxBackward::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", dim=" << dim_; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/softmax_backward.h b/torch_xla/csrc/ops/softmax_backward.h new file mode 100644 index 000000000000..8786c9fc39f2 --- /dev/null +++ b/torch_xla/csrc/ops/softmax_backward.h @@ -0,0 +1,27 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class SoftmaxBackward : public Node { + public: + SoftmaxBackward(const Value& grad_output, const Value& output, + xla::int64 dim); + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + xla::int64 dim() const { return dim_; } + + private: + // The dimension along which the result is computed. + xla::int64 dim_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/softmax_builder.cpp b/torch_xla/csrc/softmax_builder.cpp index 725762caf5f2..a820e8696256 100644 --- a/torch_xla/csrc/softmax_builder.cpp +++ b/torch_xla/csrc/softmax_builder.cpp @@ -46,6 +46,18 @@ SoftMaxPartials LogSoftmaxPartials(const xla::XlaOp& logits, xla::int64 dim) { return {std::move(broadcast_dimensions), shifted_logits, exp_shifted, reduce}; } +xla::XlaOp SoftmaxSumOfGrad(const xla::XlaOp& grad_output, xla::int64 dim) { + xla::Shape grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); + auto broadcast_dimensions = + BroadcastDimensions(grad_output_shape.rank(), dim); + const auto init_value = XlaHelpers::ScalarValue( + 0, grad_output_shape.element_type(), grad_output.builder()); + return xla::Reduce( + grad_output, init_value, + XlaHelpers::CreateAddComputation(grad_output_shape.element_type()), + {dim}); +} + } // namespace xla::XlaOp BuildLogSoftmax(const torch::jit::Node* node, @@ -73,24 +85,10 @@ xla::XlaOp BuildLogSoftmaxGrad(const torch::jit::Node* node, xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output, const xla::XlaOp& output, xla::int64 dim) { // Inspired from tf2xla. - auto input_size = XlaHelpers::SizesOfXlaOp(grad_output); - std::vector broadcast_dimensions; - for (size_t broadcast_dim = 0; broadcast_dim < input_size.size(); - ++broadcast_dim) { - if (broadcast_dim == dim) { - continue; - } - broadcast_dimensions.push_back(broadcast_dim); - } - - xla::XlaBuilder* builder = grad_output.builder(); - xla::Shape output_shape = XlaHelpers::ShapeOfXlaOp(output); - const auto init_value = - XlaHelpers::ScalarValue(0, output_shape.element_type(), builder); - const auto sum = xla::Reduce( - grad_output, init_value, - XlaHelpers::CreateAddComputation(output_shape.element_type()), {dim}); - + xla::XlaOp sum = SoftmaxSumOfGrad(grad_output, dim); + xla::Shape grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); + auto broadcast_dimensions = + BroadcastDimensions(grad_output_shape.rank(), dim); return xla::Sub(grad_output, xla::Mul(xla::Exp(output), sum, broadcast_dimensions)); } @@ -100,4 +98,13 @@ xla::XlaOp BuildSoftmax(const xla::XlaOp& logits, xla::int64 dim) { return xla::Div(parts.exp_shifted, parts.reduce, parts.broadcast_dimensions); } +xla::XlaOp BuildSoftmaxGrad(const xla::XlaOp& grad_output, + const xla::XlaOp& output, xla::int64 dim) { + xla::XlaOp sum = SoftmaxSumOfGrad(xla::Mul(grad_output, output), dim); + xla::Shape grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); + auto broadcast_dimensions = + BroadcastDimensions(grad_output_shape.rank(), dim); + return xla::Mul(output, xla::Sub(grad_output, sum, broadcast_dimensions)); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/softmax_builder.h b/torch_xla/csrc/softmax_builder.h index 607ab5d26871..66abe5e73c64 100644 --- a/torch_xla/csrc/softmax_builder.h +++ b/torch_xla/csrc/softmax_builder.h @@ -24,4 +24,7 @@ xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output, xla::XlaOp BuildSoftmax(const xla::XlaOp& logits, xla::int64 dim); +xla::XlaOp BuildSoftmaxGrad(const xla::XlaOp& grad_output, + const xla::XlaOp& output, xla::int64 dim); + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index d04befd31aea..5989ad24f10c 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -651,6 +651,8 @@ class XLATensor { xla::int64 reduction); static XLATensor softmax(const XLATensor& input, xla::int64 dim); + static XLATensor softmax_backward(const XLATensor& grad_output, + const XLATensor& output, xla::int64 dim); static std::vector split(const XLATensor& input, xla::int64 split_size, xla::int64 dim); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 8e531969b790..1401e44fa820 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -52,7 +52,6 @@ #include "torch_xla/csrc/ops/leaky_relu.h" #include "torch_xla/csrc/ops/leaky_relu_backward.h" #include "torch_xla/csrc/ops/log_softmax.h" -#include "torch_xla/csrc/ops/log_softmax_backward.h" #include "torch_xla/csrc/ops/masked_fill.h" #include "torch_xla/csrc/ops/max_pool2d.h" #include "torch_xla/csrc/ops/max_pool2d_backward.h" @@ -976,10 +975,8 @@ XLATensor XLATensor::log_softmax(const XLATensor& input, xla::int64 dim) { XLATensor XLATensor::log_softmax_backward(const XLATensor& grad_output, const XLATensor& output, xla::int64 dim) { - return grad_output.CreateFrom(ir::MakeNode( - grad_output.GetIrValue(), output.GetIrValue(), - XlaHelpers::GetCanonicalDimensionIndex( - dim, grad_output.shape().get().rank()))); + return grad_output.CreateFrom(ir::ops::LogSoftmaxBackwardOp( + grad_output.GetIrValue(), output.GetIrValue(), dim)); } XLATensor XLATensor::log1p(const XLATensor& input) { @@ -1432,6 +1429,12 @@ XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim) { XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank()))); } +XLATensor XLATensor::softmax_backward(const XLATensor& grad_output, + const XLATensor& output, xla::int64 dim) { + return grad_output.CreateFrom(ir::ops::SoftmaxBackwardOp( + grad_output.GetIrValue(), output.GetIrValue(), dim)); +} + std::vector XLATensor::split(const XLATensor& input, xla::int64 split_size, xla::int64 dim) { auto input_shape = input.shape();