13 changes: 13 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -4319,6 +4319,19 @@ TEST_F(AtenXlaTensorTest, TestLogSoftmaxBackward) {
}
}

TEST_F(AtenXlaTensorTest, TestSoftmaxBackward) {
for (int dim = -4; dim < 4; ++dim) {
auto testfn = [&](const std::vector<at::Tensor>& inputs) -> at::Tensor {
return at::softmax(inputs[0], dim);
};

ForEachDevice([&](const Device& device) {
TestBackward({at::rand({5, 3, 4, 2}, at::TensorOptions(at::kFloat))},
device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
});
}
}

TEST_F(AtenXlaTensorTest, TestReluBackward) {
auto testfn = [&](const std::vector<at::Tensor>& inputs) -> at::Tensor {
return at::relu(inputs[0]);
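For context on what the new TestSoftmaxBackward check verifies numerically (a standard identity, stated here rather than taken from the diff): for y = softmax(x) along dimension d and upstream gradient g, the backward pass computes

\frac{\partial L}{\partial x_i} = y_i \Big( g_i - \sum_j g_j \, y_j \Big),

with the sum taken along d. A tiny worked case, purely illustrative: x = (0, 0) gives y = (0.5, 0.5); with g = (1, 0) the gradient is (0.5 (1 - 0.5), 0.5 (0 - 0.5)) = (0.25, -0.25).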
8 changes: 8 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1769,6 +1769,14 @@ at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim) const {
XLATensor::softmax(bridge::GetXlaTensor(self), dim));
}

at::Tensor AtenXlaType::_softmax_backward_data(const at::Tensor& grad_output,
const at::Tensor& output,
int64_t dim,
const at::Tensor& self) const {
return bridge::AtenFromXlaTensor(XLATensor::softmax_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim));
}

at::Tensor AtenXlaType::sigmoid(const at::Tensor& self) const {
return bridge::AtenFromXlaTensor(
XLATensor::sigmoid(bridge::GetXlaTensor(self)));
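The _softmax_backward_data override above is the ATen entry point that autograd dispatches to; the remaining files wire it down to the XLA lowering. A condensed trace of the call chain, using only names that appear in this diff:

// AtenXlaType::_softmax_backward_data(grad_output, output, dim, self)
//   -> XLATensor::softmax_backward(grad_output, output, dim)    (tensor_methods.cpp)
//     -> ir::ops::SoftmaxBackwardOp(...)                        (ops.cpp, canonicalizes dim)
//       -> ir::ops::SoftmaxBackward node                        (softmax_backward.cpp)
//         -> SoftmaxBackward::Lower -> BuildSoftmaxGrad         (softmax_builder.cpp)

Note that self is accepted only to match the ATen schema; the gradient is computed from grad_output and output alone, so it is never forwarded.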
3 changes: 3 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -555,6 +555,9 @@ class AtenXlaType : public AtenXlaTypeBase {
bool half_to_float) const override;

at::Tensor softmax(const at::Tensor& self, int64_t dim) const override;
at::Tensor _softmax_backward_data(const at::Tensor& grad_output,
const at::Tensor& output, int64_t dim,
const at::Tensor& self) const override;

at::Tensor sigmoid(const at::Tensor& self) const override;
at::Tensor& sigmoid_(at::Tensor& self) const override;
19 changes: 1 addition & 18 deletions torch_xla/csrc/ops/log_softmax_backward.cpp
@@ -8,28 +8,11 @@
namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& grad_output, const Value& output,
xla::int64 dim) {
auto lower_for_shape_fn =
[dim](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2)
<< "Unexpected number of operands: " << operands.size();
return BuildLogSoftmaxGrad(/*grad_output=*/operands[0],
/*output=*/operands[1], dim);
};
return InferOutputShape({grad_output.shape(), output.shape()},
lower_for_shape_fn);
}

} // namespace

LogSoftmaxBackward::LogSoftmaxBackward(const Value& grad_output,
const Value& output, xla::int64 dim)
: Node(ir::OpKind(at::aten::_log_softmax_backward_data),
{grad_output, output}, NodeOutputShape(grad_output, output, dim),
{grad_output, output}, grad_output.shape(),
/*num_outputs=*/1, xla::util::MHash(dim)),
dim_(dim) {}

23 changes: 19 additions & 4 deletions torch_xla/csrc/ops/ops.cpp
@@ -16,7 +16,9 @@
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
#include "torch_xla/csrc/ops/constant.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/log_softmax_backward.h"
#include "torch_xla/csrc/ops/permute.h"
#include "torch_xla/csrc/ops/softmax_backward.h"
#include "torch_xla/csrc/ops/sum.h"
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/tensor_util.h"
@@ -138,10 +140,9 @@ NodePtr ReluOp(const Value& input) {
}

NodePtr TransposeOp(const Value& input, xla::int64 dim0, xla::int64 dim1) {
return ir::MakeNode<ir::ops::Permute>(
input,
XlaHelpers::MakeTransposePermutation(/*dim0=*/dim0, /*dim1=*/dim1,
/*rank=*/input.shape().rank()));
return ir::MakeNode<Permute>(input, XlaHelpers::MakeTransposePermutation(
/*dim0=*/dim0, /*dim1=*/dim1,
/*rank=*/input.shape().rank()));
}

NodePtr Sigmoid(const Value& input) {
@@ -153,6 +154,20 @@ NodePtr Sigmoid(const Value& input) {
std::move(lower_fn));
}

NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,
xla::int64 dim) {
return ir::MakeNode<LogSoftmaxBackward>(
grad_output, output,
XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output.shape().rank()));
}

NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output,
xla::int64 dim) {
return ir::MakeNode<SoftmaxBackward>(
grad_output, output,
XlaHelpers::GetCanonicalDimensionIndex(dim, grad_output.shape().rank()));
}

NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
c10::optional<at::Scalar> max) {
const xla::Shape& input_shape = input.shape();
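Both wrapper ops above canonicalize dim before the node is built, which is what lets the new test sweep dim from -4 to 3 on a rank-4 input. XlaHelpers::GetCanonicalDimensionIndex itself is not part of this diff; the sketch below is only an assumption about its behavior (negative dims count from the back), not the actual implementation:

// Hypothetical sketch of dim canonicalization, for illustration only.
// dim = -1 on a rank-4 tensor maps to 3; out-of-range dims are rejected.
xla::int64 CanonicalDim(xla::int64 dim, xla::int64 rank) {
  xla::int64 canonical = dim < 0 ? dim + rank : dim;
  XLA_CHECK(canonical >= 0 && canonical < rank)
      << "Invalid dim " << dim << " for rank " << rank;
  return canonical;
}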
6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -107,6 +107,12 @@ NodePtr TransposeOp(const Value& input, xla::int64 dim0, xla::int64 dim1);

NodePtr Sigmoid(const Value& input);

NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,
xla::int64 dim);

NodePtr SoftmaxBackwardOp(const Value& grad_output, const Value& output,
xla::int64 dim);

NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
c10::optional<at::Scalar> max);

35 changes: 35 additions & 0 deletions torch_xla/csrc/ops/softmax_backward.cpp
@@ -0,0 +1,35 @@
#include "torch_xla/csrc/ops/softmax_backward.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/softmax_builder.h"

namespace torch_xla {
namespace ir {
namespace ops {

SoftmaxBackward::SoftmaxBackward(const Value& grad_output, const Value& output,
xla::int64 dim)
: Node(ir::OpKind(at::aten::_softmax_backward_data), {grad_output, output},
grad_output.shape(),
/*num_outputs=*/1, xla::util::MHash(dim)),
dim_(dim) {}

XlaOpVector SoftmaxBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp output = loctx->GetOutputOp(operand(1));
xla::XlaOp grad_input =
BuildSoftmaxGrad(/*grad_output=*/grad_output, /*output=*/output, dim_);
return ReturnOp(grad_input, loctx);
}

std::string SoftmaxBackward::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", dim=" << dim_;
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
27 changes: 27 additions & 0 deletions torch_xla/csrc/ops/softmax_backward.h
@@ -0,0 +1,27 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class SoftmaxBackward : public Node {
public:
SoftmaxBackward(const Value& grad_output, const Value& output,
xla::int64 dim);

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

xla::int64 dim() const { return dim_; }

private:
// The dimension along which the result is computed.
xla::int64 dim_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
43 changes: 25 additions & 18 deletions torch_xla/csrc/softmax_builder.cpp
@@ -46,6 +46,18 @@ SoftMaxPartials LogSoftmaxPartials(const xla::XlaOp& logits, xla::int64 dim) {
return {std::move(broadcast_dimensions), shifted_logits, exp_shifted, reduce};
}

xla::XlaOp SoftmaxSumOfGrad(const xla::XlaOp& grad_output, xla::int64 dim) {
xla::Shape grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output);
auto broadcast_dimensions =
BroadcastDimensions(grad_output_shape.rank(), dim);
const auto init_value = XlaHelpers::ScalarValue<float>(
0, grad_output_shape.element_type(), grad_output.builder());
return xla::Reduce(
grad_output, init_value,
XlaHelpers::CreateAddComputation(grad_output_shape.element_type()),
{dim});
}

} // namespace

xla::XlaOp BuildLogSoftmax(const torch::jit::Node* node,
@@ -73,24 +85,10 @@ xla::XlaOp BuildLogSoftmaxGrad(const torch::jit::Node* node,
xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output,
const xla::XlaOp& output, xla::int64 dim) {
// Inspired from tf2xla.
auto input_size = XlaHelpers::SizesOfXlaOp(grad_output);
std::vector<xla::int64> broadcast_dimensions;
for (size_t broadcast_dim = 0; broadcast_dim < input_size.size();
++broadcast_dim) {
if (broadcast_dim == dim) {
continue;
}
broadcast_dimensions.push_back(broadcast_dim);
}

xla::XlaBuilder* builder = grad_output.builder();
xla::Shape output_shape = XlaHelpers::ShapeOfXlaOp(output);
const auto init_value =
XlaHelpers::ScalarValue<float>(0, output_shape.element_type(), builder);
const auto sum = xla::Reduce(
grad_output, init_value,
XlaHelpers::CreateAddComputation(output_shape.element_type()), {dim});

xla::XlaOp sum = SoftmaxSumOfGrad(grad_output, dim);
xla::Shape grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output);
auto broadcast_dimensions =
BroadcastDimensions(grad_output_shape.rank(), dim);
return xla::Sub(grad_output,
xla::Mul(xla::Exp(output), sum, broadcast_dimensions));
}
@@ -100,4 +98,13 @@ xla::XlaOp BuildSoftmax(const xla::XlaOp& logits, xla::int64 dim) {
return xla::Div(parts.exp_shifted, parts.reduce, parts.broadcast_dimensions);
}

xla::XlaOp BuildSoftmaxGrad(const xla::XlaOp& grad_output,
const xla::XlaOp& output, xla::int64 dim) {
xla::XlaOp sum = SoftmaxSumOfGrad(xla::Mul(grad_output, output), dim);
xla::Shape grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output);
auto broadcast_dimensions =
BroadcastDimensions(grad_output_shape.rank(), dim);
return xla::Mul(output, xla::Sub(grad_output, sum, broadcast_dimensions));
}

} // namespace torch_xla
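To connect BuildSoftmaxGrad above with the usual identity grad_input = output * (grad_output - sum(grad_output * output, dim)): SoftmaxSumOfGrad reduces grad_output * output along dim, and broadcast_dimensions lets the rank-reduced sum broadcast back against the full-rank operands. BroadcastDimensions is defined outside the lines shown in this hunk; judging from the inline loop this PR removes from BuildLogSoftmaxGrad, it presumably lists every dimension except dim, as in this sketch (an assumption, not the helper's actual code):

// Assumed behavior of BroadcastDimensions, mirroring the loop removed from
// BuildLogSoftmaxGrad: keep every dimension index except the softmax dim so
// that a reduction over dim broadcasts back against the original shape.
std::vector<xla::int64> BroadcastDimensions(xla::int64 rank, xla::int64 dim) {
  std::vector<xla::int64> dimensions;
  for (xla::int64 i = 0; i < rank; ++i) {
    if (i != dim) {
      dimensions.push_back(i);
    }
  }
  return dimensions;
}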
3 changes: 3 additions & 0 deletions torch_xla/csrc/softmax_builder.h
@@ -24,4 +24,7 @@ xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output,

xla::XlaOp BuildSoftmax(const xla::XlaOp& logits, xla::int64 dim);

xla::XlaOp BuildSoftmaxGrad(const xla::XlaOp& grad_output,
const xla::XlaOp& output, xla::int64 dim);

} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -651,6 +651,8 @@ class XLATensor {
xla::int64 reduction);

static XLATensor softmax(const XLATensor& input, xla::int64 dim);
static XLATensor softmax_backward(const XLATensor& grad_output,
const XLATensor& output, xla::int64 dim);

static std::vector<XLATensor> split(const XLATensor& input,
xla::int64 split_size, xla::int64 dim);
13 changes: 8 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
@@ -52,7 +52,6 @@
#include "torch_xla/csrc/ops/leaky_relu.h"
#include "torch_xla/csrc/ops/leaky_relu_backward.h"
#include "torch_xla/csrc/ops/log_softmax.h"
#include "torch_xla/csrc/ops/log_softmax_backward.h"
#include "torch_xla/csrc/ops/masked_fill.h"
#include "torch_xla/csrc/ops/max_pool2d.h"
#include "torch_xla/csrc/ops/max_pool2d_backward.h"
@@ -976,10 +975,8 @@ XLATensor XLATensor::log_softmax(const XLATensor& input, xla::int64 dim) {
XLATensor XLATensor::log_softmax_backward(const XLATensor& grad_output,
const XLATensor& output,
xla::int64 dim) {
return grad_output.CreateFrom(ir::MakeNode<ir::ops::LogSoftmaxBackward>(
grad_output.GetIrValue(), output.GetIrValue(),
XlaHelpers::GetCanonicalDimensionIndex(
dim, grad_output.shape().get().rank())));
return grad_output.CreateFrom(ir::ops::LogSoftmaxBackwardOp(
grad_output.GetIrValue(), output.GetIrValue(), dim));
}

XLATensor XLATensor::log1p(const XLATensor& input) {
@@ -1432,6 +1429,12 @@ XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim) {
XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
}

XLATensor XLATensor::softmax_backward(const XLATensor& grad_output,
const XLATensor& output, xla::int64 dim) {
return grad_output.CreateFrom(ir::ops::SoftmaxBackwardOp(
grad_output.GetIrValue(), output.GetIrValue(), dim));
}

std::vector<XLATensor> XLATensor::split(const XLATensor& input,
xla::int64 split_size, xla::int64 dim) {
auto input_shape = input.shape();