37 changes: 37 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -5597,6 +5597,43 @@ TEST_F(AtenXlaTensorTest, TestL1LossBackward) {
}
}

TEST_F(AtenXlaTensorTest, TestMseLoss) {
torch::Tensor input =
torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor target =
torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
for (torch::Reduction::Reduction reduction :
{torch::Reduction::None, torch::Reduction::Mean,
torch::Reduction::Sum}) {
torch::Tensor output = torch::mse_loss(input, target, reduction);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_target = CopyToDevice(target, device);
torch::Tensor xla_output =
torch::mse_loss(xla_input, xla_target, reduction);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestMseLossBackward) {
for (torch::Reduction::Reduction reduction :
{torch::Reduction::None, torch::Reduction::Mean,
torch::Reduction::Sum}) {
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::mse_loss(inputs[0], inputs[1], reduction);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({2, 4},
torch::TensorOptions(torch::kFloat).requires_grad(true)),
torch::rand({2, 4}, torch::TensorOptions(torch::kFloat))},
device, testfn);
});
}
}

TEST_F(AtenXlaTensorTest, TestBatchNorm1D) {
int num_features = 3;
torch::Tensor input =
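The forward test compares the XLA result against the native CPU result for each reduction mode, and TestMseLossBackward relies on the harness's TestBackward to make the same comparison for gradients. As a quick sanity check of what AllClose is comparing, take a hypothetical input [1, 3] and target [0, 1] (illustrative values, not part of the test):

    none: [(1-0)^2, (3-1)^2] = [1, 4]
    sum:  1 + 4 = 5
    mean: 5 / 2 = 2.5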
15 changes: 15 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2013,6 +2013,21 @@ at::Tensor AtenXlaType::mm(const at::Tensor& self, const at::Tensor& mat2) {
/*weight=*/bridge::GetXlaTensor(mat2)));
}

at::Tensor AtenXlaType::mse_loss(const at::Tensor& self,
const at::Tensor& target, int64_t reduction) {
return bridge::AtenFromXlaTensor(XLATensor::mse_loss(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction));
}

at::Tensor AtenXlaType::mse_loss_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
int64_t reduction) {
return bridge::AtenFromXlaTensor(XLATensor::mse_loss_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
bridge::GetXlaTensor(target), reduction));
}

at::Tensor AtenXlaType::mul(const at::Tensor& self, const at::Tensor& other) {
auto xlatensors = GetPromotedXlaTensorsForBinaryOp(self, other);
return bridge::AtenFromXlaTensor(
8 changes: 8 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -739,6 +739,14 @@ class AtenXlaType {

static at::Tensor mm(const at::Tensor& self, const at::Tensor& mat2);

static at::Tensor mse_loss(const at::Tensor& self, const at::Tensor& target,
int64_t reduction);

static at::Tensor mse_loss_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
int64_t reduction);

static at::Tensor mul(const at::Tensor& self, const at::Tensor& other);

static at::Tensor mul(const at::Tensor& self, at::Scalar other);
54 changes: 54 additions & 0 deletions torch_xla/csrc/ops/mse_loss.cpp
@@ -0,0 +1,54 @@
#include "torch_xla/csrc/ops/mse_loss.h"

#include <ATen/core/Reduction.h>

#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& input, const Value& target,
ReductionMode reduction) {
auto lower_for_shape_fn =
[&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp {
return BuildMseLoss(operands[0], operands[1], reduction);
};
return InferOutputShape({input.shape(), target.shape()}, lower_for_shape_fn);
}

} // namespace

MseLoss::MseLoss(const Value& input, const Value& target,
ReductionMode reduction)
: Node(ir::OpKind(at::aten::mse_loss), {input, target},
[&]() { return NodeOutputShape(input, target, reduction); },
/*num_outputs=*/1,
xla::util::MHash(xla::util::GetEnumValue<ReductionMode>(reduction))),
reduction_(reduction) {}

NodePtr MseLoss::Clone(OpList operands) const {
return MakeNode<MseLoss>(operands.at(0), operands.at(1), reduction_);
}

XlaOpVector MseLoss::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp target = loctx->GetOutputOp(operand(1));
return ReturnOp(BuildMseLoss(input, target, reduction_), loctx);
}

std::string MseLoss::ToString() const {
std::stringstream ss;
ss << Node::ToString()
<< ", reduction=" << xla::util::GetEnumValue<ReductionMode>(reduction_);
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
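Two conventions in this file are worth calling out. NodeOutputShape derives the output shape by running the same BuildMseLoss lowering on placeholder operands through InferOutputShape, so the forward shape logic lives in exactly one place. And the reduction mode is folded into the node hash; since node hashes feed torch_xla's graph deduplication and compilation cache, an attribute that changes the emitted XLA presumably has to participate in the hash, or two mse_loss nodes differing only in reduction could be conflated. Schematically:

    // reduction changes the lowering, so it must also distinguish the hash:
    xla::util::MHash(xla::util::GetEnumValue<ReductionMode>(reduction))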
29 changes: 29 additions & 0 deletions torch_xla/csrc/ops/mse_loss.h
@@ -0,0 +1,29 @@
#pragma once

#include "tensorflow/compiler/xla/types.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/reduction.h"

namespace torch_xla {
namespace ir {
namespace ops {

class MseLoss : public Node {
public:
MseLoss(const Value& input, const Value& target, ReductionMode reduction);

std::string ToString() const override;

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

ReductionMode reduction() const { return reduction_; }

private:
ReductionMode reduction_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
61 changes: 61 additions & 0 deletions torch_xla/csrc/ops/mse_loss_backward.cpp
@@ -0,0 +1,61 @@
#include "torch_xla/csrc/ops/mse_loss_backward.h"

#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/mse_loss.h"
#include "torch_xla/csrc/reduction.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& grad_output, const Value& input,
const Value& target, ReductionMode reduction) {
auto lower_for_shape_fn =
[&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp {
return BuildMseLossBackward(operands[0], operands[1], operands[2],
reduction);
};
return InferOutputShape({grad_output.shape(), input.shape(), target.shape()},
lower_for_shape_fn);
}

} // namespace

MseLossBackward::MseLossBackward(const Value& grad_output, const Value& input,
const Value& target, ReductionMode reduction)
: Node(ir::OpKind(at::aten::mse_loss_backward),
{grad_output, input, target},
[&]() {
return NodeOutputShape(grad_output, input, target, reduction);
},
/*num_outputs=*/1,
xla::util::MHash(xla::util::GetEnumValue<ReductionMode>(reduction))),
reduction_(reduction) {}

NodePtr MseLossBackward::Clone(OpList operands) const {
return MakeNode<MseLossBackward>(operands.at(0), operands.at(1),
operands.at(2), reduction_);
}

XlaOpVector MseLossBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp input = loctx->GetOutputOp(operand(1));
xla::XlaOp target = loctx->GetOutputOp(operand(2));
return ReturnOp(BuildMseLossBackward(grad_output, input, target, reduction_),
loctx);
}

std::string MseLossBackward::ToString() const {
std::stringstream ss;
ss << Node::ToString()
<< ", reduction=" << xla::util::GetEnumValue<ReductionMode>(reduction_);
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
30 changes: 30 additions & 0 deletions torch_xla/csrc/ops/mse_loss_backward.h
@@ -0,0 +1,30 @@
#pragma once

#include "tensorflow/compiler/xla/types.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/reduction.h"

namespace torch_xla {
namespace ir {
namespace ops {

class MseLossBackward : public Node {
public:
MseLossBackward(const Value& grad_output, const Value& input,
const Value& target, ReductionMode reduction);

std::string ToString() const override;

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

ReductionMode reduction() const { return reduction_; }

private:
ReductionMode reduction_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
47 changes: 47 additions & 0 deletions torch_xla/csrc/reduction.cpp
@@ -159,6 +159,53 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
return xla::Select(xla::Ge(input, target), grad_value, -grad_value);
}

xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target,
ReductionMode reduction) {
xla::XlaOp diff = input - target;
xla::XlaOp result = diff * diff;
if (reduction == ReductionMode::kNone) {
return result;
}
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
result = xla::ReduceAll(
result, xla::Zero(input.builder(), input_shape.element_type()),
XlaHelpers::CreateAddComputation(input_shape.element_type()));
if (reduction == ReductionMode::kMean) {
xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
if (num_elements == 0) {
return xla::NanValue(input.builder(), input_shape.element_type());
} else {
xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
1.0 / static_cast<double>(num_elements), input_shape.element_type(),
input.builder());
result = result * scale_value;
}
}
return result;
}

xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output,
const xla::XlaOp& input,
const xla::XlaOp& target,
ReductionMode reduction) {
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp two = XlaHelpers::ScalarValue<double>(
2, input_shape.element_type(), input.builder());
xla::XlaOp d_input = two * (input - target);
if (reduction == ReductionMode::kNone) {
return d_input * grad_output;
}
xla::XlaOp grad_value = grad_output;
if (reduction == ReductionMode::kMean) {
xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
1.0 / static_cast<double>(num_elements), input_shape.element_type(),
input.builder());
grad_value = grad_output * scale_value;
}
return d_input * grad_value;
}

xla::XlaOp BuildCumulativeComputation(const xla::XlaOp& input, xla::int64 dim,
const xla::XlaComputation& reducer,
const xla::XlaOp& init) {
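The backward lowering is just the derivative of the forward definition. For input $x$, target $y$, and $N$ elements:

\[
\ell_i = (x_i - y_i)^2, \qquad
\frac{\partial \ell_i}{\partial x_i} = 2\,(x_i - y_i), \qquad
\frac{\partial}{\partial x_i}\Big(\frac{1}{N}\sum_j (x_j - y_j)^2\Big) = \frac{2}{N}\,(x_i - y_i)
\]

which is exactly what BuildMseLossBackward emits: d_input = 2 * (input - target), multiplied elementwise by grad_output for kNone, or by the scalar grad_output (scaled by 1/N for kMean) otherwise. Note that the zero-element guard only appears in the forward: under kMean an empty input would make the mean 0/0, so the forward returns NaN to match PyTorch semantics, while in the backward the gradient tensor is itself empty, so the unguarded 1/num_elements scale never multiplies any element.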
8 changes: 8 additions & 0 deletions torch_xla/csrc/reduction.h
@@ -19,6 +19,14 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
const xla::XlaOp& target,
ReductionMode reduction);

xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target,
ReductionMode reduction);

xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output,
const xla::XlaOp& input,
const xla::XlaOp& target,
ReductionMode reduction);

// Builds a mean by reducing all the dimensions listed in dimensions. If
// keep_reduced_dimensions is true, the reduced dimensions will be retained,
// with value 1.
8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -630,6 +630,14 @@ class XLATensor {

static XLATensor mm(const XLATensor& input, const XLATensor& weight);

static XLATensor mse_loss(const XLATensor& input, const XLATensor& target,
xla::int64 reduction);

static XLATensor mse_loss_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64 reduction);

static XLATensor mul(const XLATensor& input, const XLATensor& other);
static XLATensor mul(const XLATensor& input, at::Scalar other);
static void mul_(XLATensor& input, const XLATensor& other);
17 changes: 17 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -62,6 +62,8 @@
#include "torch_xla/csrc/ops/max_pool_nd_backward.h"
#include "torch_xla/csrc/ops/mean.h"
#include "torch_xla/csrc/ops/min_in_dim.h"
#include "torch_xla/csrc/ops/mse_loss.h"
#include "torch_xla/csrc/ops/mse_loss_backward.h"
#include "torch_xla/csrc/ops/native_batch_norm_backward.h"
#include "torch_xla/csrc/ops/native_batch_norm_forward.h"
#include "torch_xla/csrc/ops/nll_loss.h"
@@ -1487,6 +1489,21 @@ XLATensor XLATensor::mm(const XLATensor& input, const XLATensor& weight) {
ir::ops::Dot(input.GetIrValue(), weight.GetIrValue()));
}

XLATensor XLATensor::mse_loss(const XLATensor& input, const XLATensor& target,
xla::int64 reduction) {
return input.CreateFrom(ir::MakeNode<ir::ops::MseLoss>(
input.GetIrValue(), target.GetIrValue(), GetXlaReductionMode(reduction)));
}

XLATensor XLATensor::mse_loss_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64 reduction) {
return input.CreateFrom(ir::MakeNode<ir::ops::MseLossBackward>(
grad_output.GetIrValue(), input.GetIrValue(), target.GetIrValue(),
GetXlaReductionMode(reduction)));
}

XLATensor XLATensor::mul(const XLATensor& input, const XLATensor& other) {
return input.CreateFrom(input.GetIrValue() * other.GetIrValue());
}
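End to end, the pieces connect as follows: torch::mse_loss on XLA tensors dispatches to AtenXlaType::mse_loss, which unwraps to XLATensors and calls XLATensor::mse_loss; that records an ir::ops::MseLoss node in the lazy graph, and BuildMseLoss only runs later, when the graph is lowered to XLA. A minimal sketch of exercising that path, borrowing the ForEachDevice and CopyToDevice helpers from the test harness (illustrative, not part of this PR):

    torch::Tensor input =
        torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
    torch::Tensor target =
        torch::randn({2, 4}, torch::TensorOptions(torch::kFloat));
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor xla_input = CopyToDevice(input, device);
      torch::Tensor xla_target = CopyToDevice(target, device);
      // Records an ir::ops::MseLoss node; nothing is compiled yet.
      torch::Tensor loss =
          torch::mse_loss(xla_input, xla_target, torch::Reduction::Mean);
      // Reading the value forces lowering, which invokes BuildMseLoss.
      double value = loss.item<double>();
    });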