From 0885c83ea967d6f7c9104c4cbb16217bc5507490 Mon Sep 17 00:00:00 2001
From: Ailing Zhang
Date: Wed, 16 Oct 2019 04:02:26 +0000
Subject: [PATCH 1/2] Added mse_loss and mse_loss_backward lowering.

---
 test/cpp/test_aten_xla_tensor.cpp        | 37 ++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp         | 15 ++++++
 torch_xla/csrc/aten_xla_type.h           |  8 +++
 torch_xla/csrc/ops/mse_loss.cpp          | 65 ++++++++++++++++++++++++
 torch_xla/csrc/ops/mse_loss.h            | 31 +++++++++++
 torch_xla/csrc/ops/mse_loss_backward.cpp | 61 ++++++++++++++++++++++
 torch_xla/csrc/ops/mse_loss_backward.h   | 29 +++++++++++
 torch_xla/csrc/reduction.cpp             | 47 +++++++++++++++++
 torch_xla/csrc/reduction.h               |  8 +++
 torch_xla/csrc/tensor.h                  |  8 +++
 torch_xla/csrc/tensor_methods.cpp        | 17 +++++++
 11 files changed, 326 insertions(+)
 create mode 100644 torch_xla/csrc/ops/mse_loss.cpp
 create mode 100644 torch_xla/csrc/ops/mse_loss.h
 create mode 100644 torch_xla/csrc/ops/mse_loss_backward.cpp
 create mode 100644 torch_xla/csrc/ops/mse_loss_backward.h

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 51fd60292ec..b652b5ac7fa 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -5597,6 +5597,43 @@ TEST_F(AtenXlaTensorTest, TestL1LossBackward) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestMseLoss) {
+  torch::Tensor input =
+      torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor target =
+      torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
+  for (torch::Reduction::Reduction reduction :
+       {torch::Reduction::None, torch::Reduction::Mean,
+        torch::Reduction::Sum}) {
+    torch::Tensor output = torch::mse_loss(input, target, reduction);
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      torch::Tensor xla_target = CopyToDevice(target, device);
+      torch::Tensor xla_output =
+          torch::mse_loss(xla_input, xla_target, reduction);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestMseLossBackward) {
+  for (torch::Reduction::Reduction reduction :
+       {torch::Reduction::None, torch::Reduction::Mean,
+        torch::Reduction::Sum}) {
+    auto testfn =
+        [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+      return torch::mse_loss(inputs[0], inputs[1], reduction);
+    };
+    ForEachDevice([&](const torch::Device& device) {
+      TestBackward(
+          {torch::rand({2, 4},
+                       torch::TensorOptions(torch::kFloat).requires_grad(true)),
+           torch::rand({2, 4}, torch::TensorOptions(torch::kFloat))},
+          device, testfn);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestBatchNorm1D) {
   int num_features = 3;
   torch::Tensor input =
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index e6f3a290a6c..48bfa63f8f6 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2013,6 +2013,21 @@ at::Tensor AtenXlaType::mm(const at::Tensor& self, const at::Tensor& mat2) {
                     /*weight=*/bridge::GetXlaTensor(mat2)));
 }
 
+at::Tensor AtenXlaType::mse_loss(const at::Tensor& self,
+                                 const at::Tensor& target, int64_t reduction) {
+  return bridge::AtenFromXlaTensor(XLATensor::mse_loss(
+      bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction));
+}
+
+at::Tensor AtenXlaType::mse_loss_backward(const at::Tensor& grad_output,
+                                          const at::Tensor& self,
+                                          const at::Tensor& target,
+                                          int64_t reduction) {
+  return bridge::AtenFromXlaTensor(XLATensor::mse_loss_backward(
+      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
+      bridge::GetXlaTensor(target), reduction));
+}
+
 at::Tensor AtenXlaType::mul(const at::Tensor& self, const at::Tensor& other) {
   auto xlatensors = GetPromotedXlaTensorsForBinaryOp(self, other);
   return bridge::AtenFromXlaTensor(
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 1e6fe2557c6..bfab6cc1bf7 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -739,6 +739,14 @@ class AtenXlaType {
 
   static at::Tensor mm(const at::Tensor& self, const at::Tensor& mat2);
 
+  static at::Tensor mse_loss(const at::Tensor& self, const at::Tensor& target,
+                             int64_t reduction);
+
+  static at::Tensor mse_loss_backward(const at::Tensor& grad_output,
+                                      const at::Tensor& self,
+                                      const at::Tensor& target,
+                                      int64_t reduction);
+
   static at::Tensor mul(const at::Tensor& self, const at::Tensor& other);
 
   static at::Tensor mul(const at::Tensor& self, at::Scalar other);
diff --git a/torch_xla/csrc/ops/mse_loss.cpp b/torch_xla/csrc/ops/mse_loss.cpp
new file mode 100644
index 00000000000..82594fd4828
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss.cpp
@@ -0,0 +1,65 @@
+#include "torch_xla/csrc/ops/mse_loss.h"
+
+#include <sstream>
+
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(const Value& input, const Value& target,
+                           xla::int64 reduction) {
+  auto lower_for_shape_fn =
+      [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
+      -> xla::XlaOp {
+    return BuildMseLoss(operands[0], operands[1],
+                        MseLoss::GetXlaReductionMode(reduction));
+  };
+  return InferOutputShape({input.shape(), target.shape()}, lower_for_shape_fn);
+}
+
+}  // namespace
+
+MseLoss::MseLoss(const Value& input, const Value& target, xla::int64 reduction)
+    : Node(ir::OpKind(at::aten::mse_loss), {input, target},
+           [&]() { return NodeOutputShape(input, target, reduction); },
+           /*num_outputs=*/1, xla::util::MHash(reduction)),
+      reduction_(reduction) {}
+
+NodePtr MseLoss::Clone(OpList operands) const {
+  return MakeNode<MseLoss>(operands.at(0), operands.at(1), reduction_);
+}
+
+XlaOpVector MseLoss::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp target = loctx->GetOutputOp(operand(1));
+  return ReturnOp(BuildMseLoss(input, target, GetXlaReductionMode(reduction_)),
+                  loctx);
+}
+
+std::string MseLoss::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", reduction=" << reduction_;
+  return ss.str();
+}
+
+ReductionMode MseLoss::GetXlaReductionMode(xla::int64 reduction) {
+  switch (reduction) {
+    case at::Reduction::Mean:
+      return ReductionMode::kMean;
+    case at::Reduction::None:
+      return ReductionMode::kNone;
+    case at::Reduction::Sum:
+      return ReductionMode::kSum;
+  }
+  XLA_ERROR() << "Unknown reduction mode: " << reduction;
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss.h b/torch_xla/csrc/ops/mse_loss.h
new file mode 100644
index 00000000000..216a2ea0982
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include "tensorflow/compiler/xla/types.h"
+#include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/reduction.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class MseLoss : public Node {
+ public:
+  MseLoss(const Value& input, const Value& target, xla::int64 reduction);
+
+  std::string ToString() const override;
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  xla::int64 reduction() const { return reduction_; }
+
+  static ReductionMode GetXlaReductionMode(xla::int64 reduction);
+
+ private:
+  xla::int64 reduction_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss_backward.cpp b/torch_xla/csrc/ops/mse_loss_backward.cpp
new file mode 100644
index 00000000000..697dce33021
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss_backward.cpp
@@ -0,0 +1,61 @@
+#include "torch_xla/csrc/ops/mse_loss_backward.h"
+
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/ops/mse_loss.h"
+#include "torch_xla/csrc/reduction.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(const Value& grad_output, const Value& input,
+                           const Value& target, xla::int64 reduction) {
+  auto lower_for_shape_fn =
+      [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
+      -> xla::XlaOp {
+    return BuildMseLossBackward(operands[0], operands[1], operands[2],
+                                MseLoss::GetXlaReductionMode(reduction));
+  };
+  return InferOutputShape({grad_output.shape(), input.shape(), target.shape()},
+                          lower_for_shape_fn);
+}
+
+}  // namespace
+
+MseLossBackward::MseLossBackward(const Value& grad_output, const Value& input,
+                                 const Value& target, xla::int64 reduction)
+    : Node(ir::OpKind(at::aten::mse_loss_backward),
+           {grad_output, input, target},
+           [&]() {
+             return NodeOutputShape(grad_output, input, target, reduction);
+           },
+           /*num_outputs=*/1, xla::util::MHash(reduction)),
+      reduction_(reduction) {}
+
+NodePtr MseLossBackward::Clone(OpList operands) const {
+  return MakeNode<MseLossBackward>(operands.at(0), operands.at(1),
+                                   operands.at(2), reduction_);
+}
+
+XlaOpVector MseLossBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp target = loctx->GetOutputOp(operand(2));
+  return ReturnOp(
+      BuildMseLossBackward(grad_output, input, target,
+                           MseLoss::GetXlaReductionMode(reduction_)),
+      loctx);
+}
+
+std::string MseLossBackward::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", reduction=" << reduction_;
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss_backward.h b/torch_xla/csrc/ops/mse_loss_backward.h
new file mode 100644
index 00000000000..a73406209c9
--- /dev/null
+++ b/torch_xla/csrc/ops/mse_loss_backward.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "tensorflow/compiler/xla/types.h"
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class MseLossBackward : public Node {
+ public:
+  MseLossBackward(const Value& grad_output, const Value& input,
+                  const Value& target, xla::int64 reduction);
+
+  std::string ToString() const override;
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  xla::int64 reduction() const { return reduction_; }
+
+ private:
+  xla::int64 reduction_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index 0bba5c5d8e5..7be4dfaed28 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -159,6 +159,53 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
   return xla::Select(xla::Ge(input, target), grad_value, -grad_value);
 }
 
+xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target,
+                        ReductionMode reduction) {
+  xla::XlaOp diff = input - target;
+  xla::XlaOp result = diff * diff;
+  if (reduction == ReductionMode::kNone) {
+    return result;
+  }
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  result = xla::ReduceAll(
+      result, xla::Zero(input.builder(), input_shape.element_type()),
+      XlaHelpers::CreateAddComputation(input_shape.element_type()));
+  if (reduction == ReductionMode::kMean) {
+    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
+    if (num_elements == 0) {
+      return xla::NanValue(input.builder(), input_shape.element_type());
+    } else {
+      xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
+          1.0 / static_cast<double>(num_elements), input_shape.element_type(),
+          input.builder());
+      result = result * scale_value;
+    }
+  }
+  return result;
+}
+
+xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output,
+                                const xla::XlaOp& input,
+                                const xla::XlaOp& target,
+                                ReductionMode reduction) {
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp two = XlaHelpers::ScalarValue<double>(
+      2, input_shape.element_type(), input.builder());
+  xla::XlaOp d_input = two * (input - target);
+  if (reduction == ReductionMode::kNone) {
+    return d_input * grad_output;
+  }
+  xla::XlaOp grad_value = grad_output;
+  if (reduction == ReductionMode::kMean) {
+    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
+    xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
+        1.0 / static_cast<double>(num_elements), input_shape.element_type(),
+        input.builder());
+    grad_value = grad_output * scale_value;
+  }
+  return d_input * grad_value;
+}
+
 xla::XlaOp BuildCumulativeComputation(const xla::XlaOp& input, xla::int64 dim,
                                       const xla::XlaComputation& reducer,
                                       const xla::XlaOp& init) {
diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h
index e472527520f..518e8fa3ede 100644
--- a/torch_xla/csrc/reduction.h
+++ b/torch_xla/csrc/reduction.h
@@ -19,6 +19,14 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
                                const xla::XlaOp& target,
                                ReductionMode reduction);
 
+xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target,
+                        ReductionMode reduction);
+
+xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output,
+                                const xla::XlaOp& input,
+                                const xla::XlaOp& target,
+                                ReductionMode reduction);
+
 // Builds a mean by reducing all the dimensions listed in dimensions. If
 // keep_reduced_dimensions is true, the reduced dimensions will be retained,
 // with value 1.
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index d28c0127200..2e4ebae80cf 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -630,6 +630,14 @@ class XLATensor {
 
   static XLATensor mm(const XLATensor& input, const XLATensor& weight);
 
+  static XLATensor mse_loss(const XLATensor& input, const XLATensor& target,
+                            xla::int64 reduction);
+
+  static XLATensor mse_loss_backward(const XLATensor& grad_output,
+                                     const XLATensor& input,
+                                     const XLATensor& target,
+                                     xla::int64 reduction);
+
   static XLATensor mul(const XLATensor& input, const XLATensor& other);
   static XLATensor mul(const XLATensor& input, at::Scalar other);
   static void mul_(XLATensor& input, const XLATensor& other);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index a68804baa57..f2dd9cb27b2 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -60,6 +60,8 @@
 #include "torch_xla/csrc/ops/max_pool_nd_backward.h"
 #include "torch_xla/csrc/ops/mean.h"
 #include "torch_xla/csrc/ops/min_in_dim.h"
+#include "torch_xla/csrc/ops/mse_loss.h"
+#include "torch_xla/csrc/ops/mse_loss_backward.h"
 #include "torch_xla/csrc/ops/native_batch_norm_backward.h"
 #include "torch_xla/csrc/ops/native_batch_norm_forward.h"
 #include "torch_xla/csrc/ops/nll_loss.h"
@@ -1472,6 +1474,21 @@ XLATensor XLATensor::mm(const XLATensor& input, const XLATensor& weight) {
       ir::ops::Dot(input.GetIrValue(), weight.GetIrValue()));
 }
 
+XLATensor XLATensor::mse_loss(const XLATensor& input, const XLATensor& target,
+                              xla::int64 reduction) {
+  return input.CreateFrom(ir::MakeNode<ir::ops::MseLoss>(
+      input.GetIrValue(), target.GetIrValue(), reduction));
+}
+
+XLATensor XLATensor::mse_loss_backward(const XLATensor& grad_output,
+                                       const XLATensor& input,
+                                       const XLATensor& target,
+                                       xla::int64 reduction) {
+  return input.CreateFrom(ir::MakeNode<ir::ops::MseLossBackward>(
+      grad_output.GetIrValue(), input.GetIrValue(), target.GetIrValue(),
+      reduction));
+}
+
 XLATensor XLATensor::mul(const XLATensor& input, const XLATensor& other) {
   return input.CreateFrom(input.GetIrValue() * other.GetIrValue());
 }

From 9b2d5a36d482e4d65fd08c5efeb56a1528f2e7aa Mon Sep 17 00:00:00 2001
From: Ailing Zhang
Date: Wed, 16 Oct 2019 21:48:12 +0000
Subject: [PATCH 2/2] Switch to map ReductionMode in tensor_methods.cpp.

---
 torch_xla/csrc/ops/mse_loss.cpp          | 29 ++++++++----------------
 torch_xla/csrc/ops/mse_loss.h            |  8 +++----
 torch_xla/csrc/ops/mse_loss_backward.cpp | 18 +++++++--------
 torch_xla/csrc/ops/mse_loss_backward.h   |  7 +++---
 torch_xla/csrc/tensor_methods.cpp        |  4 ++--
 5 files changed, 27 insertions(+), 39 deletions(-)

diff --git a/torch_xla/csrc/ops/mse_loss.cpp b/torch_xla/csrc/ops/mse_loss.cpp
index 82594fd4828..0c5c3742772 100644
--- a/torch_xla/csrc/ops/mse_loss.cpp
+++ b/torch_xla/csrc/ops/mse_loss.cpp
@@ -13,22 +13,23 @@ namespace ops {
 namespace {
 
 xla::Shape NodeOutputShape(const Value& input, const Value& target,
-                           xla::int64 reduction) {
+                           ReductionMode reduction) {
   auto lower_for_shape_fn =
       [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
       -> xla::XlaOp {
-    return BuildMseLoss(operands[0], operands[1],
-                        MseLoss::GetXlaReductionMode(reduction));
+    return BuildMseLoss(operands[0], operands[1], reduction);
   };
   return InferOutputShape({input.shape(), target.shape()}, lower_for_shape_fn);
 }
 
 }  // namespace
 
-MseLoss::MseLoss(const Value& input, const Value& target, xla::int64 reduction)
+MseLoss::MseLoss(const Value& input, const Value& target,
+                 ReductionMode reduction)
     : Node(ir::OpKind(at::aten::mse_loss), {input, target},
            [&]() { return NodeOutputShape(input, target, reduction); },
-           /*num_outputs=*/1, xla::util::MHash(reduction)),
+           /*num_outputs=*/1,
+           xla::util::MHash(xla::util::GetEnumValue(reduction))),
       reduction_(reduction) {}
 
 NodePtr MseLoss::Clone(OpList operands) const {
@@ -38,28 +39,16 @@ NodePtr MseLoss::Clone(OpList operands) const {
 XlaOpVector MseLoss::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   xla::XlaOp target = loctx->GetOutputOp(operand(1));
-  return ReturnOp(BuildMseLoss(input, target, GetXlaReductionMode(reduction_)),
-                  loctx);
+  return ReturnOp(BuildMseLoss(input, target, reduction_), loctx);
 }
 
 std::string MseLoss::ToString() const {
   std::stringstream ss;
-  ss << Node::ToString() << ", reduction=" << reduction_;
+  ss << Node::ToString()
+     << ", reduction=" << xla::util::GetEnumValue(reduction_);
   return ss.str();
 }
 
-ReductionMode MseLoss::GetXlaReductionMode(xla::int64 reduction) {
-  switch (reduction) {
-    case at::Reduction::Mean:
-      return ReductionMode::kMean;
-    case at::Reduction::None:
-      return ReductionMode::kNone;
-    case at::Reduction::Sum:
-      return ReductionMode::kSum;
-  }
-  XLA_ERROR() << "Unknown reduction mode: " << reduction;
-}
-
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/mse_loss.h b/torch_xla/csrc/ops/mse_loss.h
index 216a2ea0982..4fb4daac226 100644
--- a/torch_xla/csrc/ops/mse_loss.h
+++ b/torch_xla/csrc/ops/mse_loss.h
@@ -10,7 +10,7 @@ namespace ops {
 
 class MseLoss : public Node {
  public:
-  MseLoss(const Value& input, const Value& target, xla::int64 reduction);
+  MseLoss(const Value& input, const Value& target, ReductionMode reduction);
 
   std::string ToString() const override;
 
@@ -18,12 +18,10 @@ class MseLoss : public Node {
 
   XlaOpVector Lower(LoweringContext* loctx) const override;
 
-  xla::int64 reduction() const { return reduction_; }
-
-  static ReductionMode GetXlaReductionMode(xla::int64 reduction);
+  ReductionMode reduction() const { return reduction_; }
 
  private:
-  xla::int64 reduction_;
+  ReductionMode reduction_;
 };
 
 }  // namespace ops
diff --git a/torch_xla/csrc/ops/mse_loss_backward.cpp b/torch_xla/csrc/ops/mse_loss_backward.cpp
index 697dce33021..007e0c603ec 100644
--- a/torch_xla/csrc/ops/mse_loss_backward.cpp
+++ b/torch_xla/csrc/ops/mse_loss_backward.cpp
@@ -12,12 +12,12 @@ namespace ops {
 namespace {
 
 xla::Shape NodeOutputShape(const Value& grad_output, const Value& input,
-                           const Value& target, xla::int64 reduction) {
+                           const Value& target, ReductionMode reduction) {
   auto lower_for_shape_fn =
       [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
       -> xla::XlaOp {
     return BuildMseLossBackward(operands[0], operands[1], operands[2],
-                                MseLoss::GetXlaReductionMode(reduction));
+                                reduction);
   };
   return InferOutputShape({grad_output.shape(), input.shape(), target.shape()},
                           lower_for_shape_fn);
@@ -26,13 +26,14 @@ xla::Shape NodeOutputShape(const Value& grad_output, const Value& input,
 }  // namespace
 
 MseLossBackward::MseLossBackward(const Value& grad_output, const Value& input,
-                                 const Value& target, xla::int64 reduction)
+                                 const Value& target, ReductionMode reduction)
     : Node(ir::OpKind(at::aten::mse_loss_backward),
            {grad_output, input, target},
            [&]() {
             return NodeOutputShape(grad_output, input, target, reduction);
           },
-           /*num_outputs=*/1, xla::util::MHash(reduction)),
+           /*num_outputs=*/1,
+           xla::util::MHash(xla::util::GetEnumValue(reduction))),
       reduction_(reduction) {}
 
 NodePtr MseLossBackward::Clone(OpList operands) const {
@@ -44,15 +45,14 @@ XlaOpVector MseLossBackward::Lower(LoweringContext* loctx) const {
   xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
   xla::XlaOp input = loctx->GetOutputOp(operand(1));
   xla::XlaOp target = loctx->GetOutputOp(operand(2));
-  return ReturnOp(
-      BuildMseLossBackward(grad_output, input, target,
-                           MseLoss::GetXlaReductionMode(reduction_)),
-      loctx);
+  return ReturnOp(BuildMseLossBackward(grad_output, input, target, reduction_),
+                  loctx);
 }
 
 std::string MseLossBackward::ToString() const {
   std::stringstream ss;
-  ss << Node::ToString() << ", reduction=" << reduction_;
+  ss << Node::ToString()
+     << ", reduction=" << xla::util::GetEnumValue(reduction_);
   return ss.str();
 }
 
diff --git a/torch_xla/csrc/ops/mse_loss_backward.h b/torch_xla/csrc/ops/mse_loss_backward.h
index a73406209c9..7f83b85fff8 100644
--- a/torch_xla/csrc/ops/mse_loss_backward.h
+++ b/torch_xla/csrc/ops/mse_loss_backward.h
@@ -2,6 +2,7 @@
 
 #include "tensorflow/compiler/xla/types.h"
 #include "torch_xla/csrc/ir.h"
+#include "torch_xla/csrc/reduction.h"
 
 namespace torch_xla {
 namespace ir {
@@ -10,7 +11,7 @@ namespace ops {
 class MseLossBackward : public Node {
  public:
   MseLossBackward(const Value& grad_output, const Value& input,
-                  const Value& target, xla::int64 reduction);
+                  const Value& target, ReductionMode reduction);
 
   std::string ToString() const override;
 
@@ -18,10 +19,10 @@ class MseLossBackward : public Node {
 
   XlaOpVector Lower(LoweringContext* loctx) const override;
 
-  xla::int64 reduction() const { return reduction_; }
+  ReductionMode reduction() const { return reduction_; }
 
 private:
-  xla::int64 reduction_;
+  ReductionMode reduction_;
 };
 
 }  // namespace ops
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 5b080c63511..240f5b7bfe7 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1492,7 +1492,7 @@ XLATensor XLATensor::mm(const XLATensor& input, const XLATensor& weight) {
 XLATensor XLATensor::mse_loss(const XLATensor& input, const XLATensor& target,
                               xla::int64 reduction) {
   return input.CreateFrom(ir::MakeNode<ir::ops::MseLoss>(
-      input.GetIrValue(), target.GetIrValue(), reduction));
+      input.GetIrValue(), target.GetIrValue(), GetXlaReductionMode(reduction)));
 }
 
 XLATensor XLATensor::mse_loss_backward(const XLATensor& grad_output,
@@ -1501,7 +1501,7 @@ XLATensor XLATensor::mse_loss_backward(const XLATensor& grad_output,
                                        xla::int64 reduction) {
   return input.CreateFrom(ir::MakeNode<ir::ops::MseLossBackward>(
       grad_output.GetIrValue(), input.GetIrValue(), target.GetIrValue(),
-      reduction));
+      GetXlaReductionMode(reduction)));
 }
 
 XLATensor XLATensor::mul(const XLATensor& input, const XLATensor& other) {
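
Explanatory note (not part of the patch; the symbols x, t, g and N below are illustrative only): with input x, target t, incoming gradient g and N elements, the reduction.cpp helpers added above lower

    L_i = (x_i - t_i)^2                         (reduction = none, elementwise)
    L   = \sum_i (x_i - t_i)^2                  (reduction = sum)
    L   = \frac{1}{N} \sum_i (x_i - t_i)^2      (reduction = mean; NaN when N = 0)

for BuildMseLoss, and the matching gradient

    \frac{\partial L}{\partial x_i} = 2 (x_i - t_i) \cdot g

for BuildMseLossBackward, with the mean case additionally scaled by 1/N. This corresponds directly to the `two * (input - target)` term multiplied by `grad_value` (grad_output, optionally scaled by `1.0 / num_elements`) in the code.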