diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 51fd60292ec..b652b5ac7fa 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -5597,6 +5597,45 @@ TEST_F(AtenXlaTensorTest, TestL1LossBackward) {
   }
 }
 
+// Compares torch::mse_loss on each device with the reference result, for
+// every reduction mode.
+TEST_F(AtenXlaTensorTest, TestMseLoss) {
+  torch::Tensor input =
+      torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor target =
+      torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
+  for (torch::Reduction::Reduction reduction :
+       {torch::Reduction::None, torch::Reduction::Mean,
+        torch::Reduction::Sum}) {
+    torch::Tensor output = torch::mse_loss(input, target, reduction);
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      torch::Tensor xla_target = CopyToDevice(target, device);
+      torch::Tensor xla_output =
+          torch::mse_loss(xla_input, xla_target, reduction);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestMseLossBackward) {
+  for (torch::Reduction::Reduction reduction :
+       {torch::Reduction::None, torch::Reduction::Mean,
+        torch::Reduction::Sum}) {
+    auto testfn =
+        [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+      return torch::mse_loss(inputs[0], inputs[1], reduction);
+    };
+    ForEachDevice([&](const torch::Device& device) {
+      TestBackward(
+          {torch::rand({2, 4},
+                       torch::TensorOptions(torch::kFloat).requires_grad(true)),
+           torch::rand({2, 4}, torch::TensorOptions(torch::kFloat))},
+          device, testfn);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestBatchNorm1D) {
   int num_features = 3;
   torch::Tensor input =
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index e6f3a290a6c..48bfa63f8f6 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2013,6 +2013,23 @@ at::Tensor AtenXlaType::mm(const at::Tensor& self, const at::Tensor& mat2) {
                     /*weight=*/bridge::GetXlaTensor(mat2)));
 }
 
+// Unwraps the ATen tensors and forwards to XLATensor::mse_loss.
+at::Tensor AtenXlaType::mse_loss(const at::Tensor& self,
+                                 const at::Tensor& target, int64_t reduction) {
+  return bridge::AtenFromXlaTensor(XLATensor::mse_loss(
+      bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction));
+}
+
+// Unwraps the ATen tensors and forwards to XLATensor::mse_loss_backward.
+at::Tensor AtenXlaType::mse_loss_backward(const at::Tensor& grad_output,
+                                          const at::Tensor& self,
+                                          const at::Tensor& target,
+                                          int64_t reduction) {
+  return bridge::AtenFromXlaTensor(XLATensor::mse_loss_backward(
+      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
+      bridge::GetXlaTensor(target), reduction));
+}
+
 at::Tensor AtenXlaType::mul(const at::Tensor& self, const at::Tensor& other) {
   auto xlatensors = GetPromotedXlaTensorsForBinaryOp(self, other);
   return bridge::AtenFromXlaTensor(
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 1e6fe2557c6..bfab6cc1bf7 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -739,6 +739,14 @@ class AtenXlaType {
 
   static at::Tensor mm(const at::Tensor& self, const at::Tensor& mat2);
 
+  static at::Tensor mse_loss(const at::Tensor& self, const at::Tensor& target,
+                             int64_t reduction);
+
+  static at::Tensor mse_loss_backward(const at::Tensor& grad_output,
+                                      const at::Tensor& self,
+                                      const at::Tensor& target,
+                                      int64_t reduction);
+
   static at::Tensor mul(const at::Tensor& self, const at::Tensor& other);
 
   static at::Tensor mul(const at::Tensor& self, at::Scalar other);
diff --git a/torch_xla/csrc/ops/mse_loss.cpp b/torch_xla/csrc/ops/mse_loss.cpp
new file mode 100644
index 00000000000..0c5c3742772
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss.cpp
@@ -0,0 +1,56 @@
+#include "torch_xla/csrc/ops/mse_loss.h"
+
+#include <sstream>
+
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+// Obtains the node's output shape by handing BuildMseLoss to InferOutputShape
+// together with the shapes of the two operands.
+xla::Shape NodeOutputShape(const Value& input, const Value& target,
+                           ReductionMode reduction) {
+  auto lower_for_shape_fn =
+      [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
+      -> xla::XlaOp {
+    return BuildMseLoss(operands[0], operands[1], reduction);
+  };
+  return InferOutputShape({input.shape(), target.shape()}, lower_for_shape_fn);
+}
+
+}  // namespace
+
+MseLoss::MseLoss(const Value& input, const Value& target,
+                 ReductionMode reduction)
+    : Node(ir::OpKind(at::aten::mse_loss), {input, target},
+           [&]() { return NodeOutputShape(input, target, reduction); },
+           /*num_outputs=*/1,
+           xla::util::MHash(xla::util::GetEnumValue(reduction))),
+      reduction_(reduction) {}
+
+NodePtr MseLoss::Clone(OpList operands) const {
+  return MakeNode<MseLoss>(operands.at(0), operands.at(1), reduction_);
+}
+
+XlaOpVector MseLoss::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp target = loctx->GetOutputOp(operand(1));
+  return ReturnOp(BuildMseLoss(input, target, reduction_), loctx);
+}
+
+std::string MseLoss::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString()
+     << ", reduction=" << xla::util::GetEnumValue(reduction_);
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss.h b/torch_xla/csrc/ops/mse_loss.h
new file mode 100644
index 00000000000..4fb4daac226
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "tensorflow/compiler/xla/types.h"
+#include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/reduction.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+// IR node whose lowering emits BuildMseLoss(input, target, reduction).
+class MseLoss : public Node {
+ public:
+  MseLoss(const Value& input, const Value& target, ReductionMode reduction);
+
+  std::string ToString() const override;
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  ReductionMode reduction() const { return reduction_; }
+
+ private:
+  ReductionMode reduction_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss_backward.cpp b/torch_xla/csrc/ops/mse_loss_backward.cpp
new file mode 100644
index 00000000000..007e0c603ec
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss_backward.cpp
@@ -0,0 +1,65 @@
+#include "torch_xla/csrc/ops/mse_loss_backward.h"
+
+#include <sstream>
+
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/ops/mse_loss.h"
+#include "torch_xla/csrc/reduction.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+// Obtains the node's output shape by handing BuildMseLossBackward to
+// InferOutputShape together with the shapes of the three operands.
+xla::Shape NodeOutputShape(const Value& grad_output, const Value& input,
+                           const Value& target, ReductionMode reduction) {
+  auto lower_for_shape_fn =
+      [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
+      -> xla::XlaOp {
+    return BuildMseLossBackward(operands[0], operands[1], operands[2],
+                                reduction);
+  };
+  return InferOutputShape({grad_output.shape(), input.shape(), target.shape()},
+                          lower_for_shape_fn);
+}
+
+}  // namespace
+
+MseLossBackward::MseLossBackward(const Value& grad_output, const Value& input,
+                                 const Value& target, ReductionMode reduction)
+    : Node(ir::OpKind(at::aten::mse_loss_backward),
+           {grad_output, input, target},
+           [&]() {
+             return NodeOutputShape(grad_output, input, target, reduction);
+           },
+           /*num_outputs=*/1,
+           xla::util::MHash(xla::util::GetEnumValue(reduction))),
+      reduction_(reduction) {}
+
+NodePtr MseLossBackward::Clone(OpList operands) const {
+  return MakeNode<MseLossBackward>(operands.at(0), operands.at(1),
+                                   operands.at(2), reduction_);
+}
+
+XlaOpVector MseLossBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp target = loctx->GetOutputOp(operand(2));
+  return ReturnOp(BuildMseLossBackward(grad_output, input, target, reduction_),
+                  loctx);
+}
+
+std::string MseLossBackward::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString()
+     << ", reduction=" << xla::util::GetEnumValue(reduction_);
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss_backward.h b/torch_xla/csrc/ops/mse_loss_backward.h
new file mode 100644
index 00000000000..7f83b85fff8
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss_backward.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include "tensorflow/compiler/xla/types.h"
+#include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/reduction.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+// IR node whose lowering emits BuildMseLossBackward over its three operands.
+class MseLossBackward : public Node {
+ public:
+  MseLossBackward(const Value& grad_output, const Value& input,
+                  const Value& target, ReductionMode reduction);
+
+  std::string ToString() const override;
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  ReductionMode reduction() const { return reduction_; }
+
+ private:
+  ReductionMode reduction_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index 0bba5c5d8e5..7be4dfaed28 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -159,6 +159,59 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
   return xla::Select(xla::Ge(input, target), grad_value, -grad_value);
 }
 
+// Builds the squared difference (input - target)^2. With kNone the elementwise
+// result is returned; otherwise it is summed over all elements and, for kMean,
+// scaled by 1 / element count (NaN is returned when input has no elements).
+xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target,
+                        ReductionMode reduction) {
+  xla::XlaOp diff = input - target;
+  xla::XlaOp result = diff * diff;
+  if (reduction == ReductionMode::kNone) {
+    return result;
+  }
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  result = xla::ReduceAll(
+      result, xla::Zero(input.builder(), input_shape.element_type()),
+      XlaHelpers::CreateAddComputation(input_shape.element_type()));
+  if (reduction == ReductionMode::kMean) {
+    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
+    if (num_elements == 0) {
+      return xla::NanValue(input.builder(), input_shape.element_type());
+    } else {
+      xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
+          1.0 / static_cast<double>(num_elements), input_shape.element_type(),
+          input.builder());
+      result = result * scale_value;
+    }
+  }
+  return result;
+}
+
+// Builds 2 * (input - target) * grad_output, the gradient of the squared
+// difference with respect to input. For kMean, grad_output is first scaled by
+// 1 / element count of input.
+xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output,
+                                const xla::XlaOp& input,
+                                const xla::XlaOp& target,
+                                ReductionMode reduction) {
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp two = XlaHelpers::ScalarValue<double>(
+      2, input_shape.element_type(), input.builder());
+  xla::XlaOp d_input = two * (input - target);
+  if (reduction == ReductionMode::kNone) {
+    return d_input * grad_output;
+  }
+  xla::XlaOp grad_value = grad_output;
+  if (reduction == ReductionMode::kMean) {
+    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
+    xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
+        1.0 / static_cast<double>(num_elements), input_shape.element_type(),
+        input.builder());
+    grad_value = grad_output * scale_value;
+  }
+  return d_input * grad_value;
+}
+
 xla::XlaOp BuildCumulativeComputation(const xla::XlaOp& input, xla::int64 dim,
                                       const xla::XlaComputation& reducer,
                                       const xla::XlaOp& init) {
diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h
index e472527520f..518e8fa3ede 100644
--- a/torch_xla/csrc/reduction.h
+++ b/torch_xla/csrc/reduction.h
@@ -19,6 +19,14 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
                                const xla::XlaOp& target,
                                ReductionMode reduction);
 
+xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target,
+                        ReductionMode reduction);
+
+xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output,
+                                const xla::XlaOp& input,
+                                const xla::XlaOp& target,
+                                ReductionMode reduction);
+
 // Builds a mean by reducing all the dimensions listed in dimensions. If
 // keep_reduced_dimensions is true, the reduced dimensions will be retained,
 // with value 1.
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index d28c0127200..2e4ebae80cf 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -630,6 +630,14 @@ class XLATensor {
 
   static XLATensor mm(const XLATensor& input, const XLATensor& weight);
 
+  static XLATensor mse_loss(const XLATensor& input, const XLATensor& target,
+                            xla::int64 reduction);
+
+  static XLATensor mse_loss_backward(const XLATensor& grad_output,
+                                     const XLATensor& input,
+                                     const XLATensor& target,
+                                     xla::int64 reduction);
+
   static XLATensor mul(const XLATensor& input, const XLATensor& other);
   static XLATensor mul(const XLATensor& input, at::Scalar other);
   static void mul_(XLATensor& input, const XLATensor& other);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index f8cf418e0fc..240f5b7bfe7 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -62,6 +62,8 @@
 #include "torch_xla/csrc/ops/max_pool_nd_backward.h"
 #include "torch_xla/csrc/ops/mean.h"
 #include "torch_xla/csrc/ops/min_in_dim.h"
+#include "torch_xla/csrc/ops/mse_loss.h"
+#include "torch_xla/csrc/ops/mse_loss_backward.h"
 #include "torch_xla/csrc/ops/native_batch_norm_backward.h"
 #include "torch_xla/csrc/ops/native_batch_norm_forward.h"
 #include "torch_xla/csrc/ops/nll_loss.h"
@@ -1487,6 +1489,23 @@ XLATensor XLATensor::mm(const XLATensor& input, const XLATensor& weight) {
       ir::ops::Dot(input.GetIrValue(), weight.GetIrValue()));
 }
 
+// Wraps an ir::ops::MseLoss node over the input and target IR values.
+XLATensor XLATensor::mse_loss(const XLATensor& input, const XLATensor& target,
+                              xla::int64 reduction) {
+  return input.CreateFrom(ir::MakeNode<ir::ops::MseLoss>(
+      input.GetIrValue(), target.GetIrValue(), GetXlaReductionMode(reduction)));
+}
+
+// Wraps an ir::ops::MseLossBackward node over the three operand IR values.
+XLATensor XLATensor::mse_loss_backward(const XLATensor& grad_output,
+                                       const XLATensor& input,
+                                       const XLATensor& target,
+                                       xla::int64 reduction) {
+  return input.CreateFrom(ir::MakeNode<ir::ops::MseLossBackward>(
+      grad_output.GetIrValue(), input.GetIrValue(), target.GetIrValue(),
+      GetXlaReductionMode(reduction)));
+}
+
 XLATensor XLATensor::mul(const XLATensor& input, const XLATensor& other) {
   return input.CreateFrom(input.GetIrValue() * other.GetIrValue());
 }