From 483a82c6fe6ea805229059af516bd99daccb6581 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Sat, 23 Nov 2019 08:06:32 -0800
Subject: [PATCH] Added aten::reflection_pad2d and
 aten::reflection_pad2d_backward XLA Lowerings.

---
 test/cpp/test_aten_xla_tensor.cpp          | 45 +++++++++
 torch_xla/csrc/aten_xla_type.cpp           | 16 +++
 torch_xla/csrc/aten_xla_type.h             |  7 ++
 torch_xla/csrc/data_ops.cpp                | 99 +++++++++++++++++++
 torch_xla/csrc/data_ops.h                  | 11 +++
 torch_xla/csrc/ir.cpp                      |  2 +-
 torch_xla/csrc/ops/reflection_pad2d.cpp    | 50 ++++++++++
 torch_xla/csrc/ops/reflection_pad2d.h      | 29 ++++++
 .../csrc/ops/reflection_pad2d_backward.cpp | 58 +++++++++++
 .../csrc/ops/reflection_pad2d_backward.h   | 30 ++++++
 torch_xla/csrc/tensor.h                    |  7 ++
 torch_xla/csrc/tensor_methods.cpp          | 15 +++
 12 files changed, 368 insertions(+), 1 deletion(-)
 create mode 100644 torch_xla/csrc/ops/reflection_pad2d.cpp
 create mode 100644 torch_xla/csrc/ops/reflection_pad2d.h
 create mode 100644 torch_xla/csrc/ops/reflection_pad2d_backward.cpp
 create mode 100644 torch_xla/csrc/ops/reflection_pad2d_backward.h

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 1d42d24678d..7537a1167c0 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -7735,6 +7735,51 @@ TEST_F(AtenXlaTensorTest, TestConstantPadIncomplete) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestReflectionPad2dRank3) {
+  torch::Tensor input =
+      torch::rand({2, 3, 4}, torch::TensorOptions(torch::kFloat));
+  std::vector<int64_t> pad{2, 2, 2, 2};
+  torch::Tensor output = torch::reflection_pad2d(input, pad);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::reflection_pad2d(xla_input, pad);
+    AllClose(output, xla_output);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::reflection_pad2d", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestReflectionPad2dRank4) {
+  torch::Tensor input =
+      torch::rand({2, 2, 3, 4}, torch::TensorOptions(torch::kFloat));
+  std::vector<int64_t> pad{2, 2, 2, 2};
+  torch::Tensor output = torch::reflection_pad2d(input, pad);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::reflection_pad2d(xla_input, pad);
+    AllClose(output, xla_output);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::reflection_pad2d", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestReflectionPad2dBackward) {
+  std::vector<int64_t> pad{2, 3, 1, 2};
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::reflection_pad2d(inputs[0], pad);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::rand({1, 2, 4, 4},
+                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
+        device, testfn);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestAsStrided) {
   torch::Tensor input =
       torch::rand({128, 320}, torch::TensorOptions(torch::kFloat));
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 628a14b8692..b33c6b46c13 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2234,6 +2234,22 @@ at::Tensor& AtenXlaType::reciprocal_(at::Tensor& self) {
   return self;
 }
 
+at::Tensor AtenXlaType::reflection_pad2d(const at::Tensor& self,
+                                         at::IntArrayRef padding) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::reflection_pad2d(
+      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(padding)));
+}
+
+at::Tensor AtenXlaType::reflection_pad2d_backward(const at::Tensor& grad_output,
+                                                  const at::Tensor& self,
+                                                  at::IntArrayRef padding) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::reflection_pad2d_backward(
+      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
+      xla::util::ToVector<xla::int64>(padding)));
+}
+
 at::Tensor AtenXlaType::relu(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(XLATensor::relu(bridge::GetXlaTensor(self)));
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index b0700913ff1..7e61c550ca7 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -687,6 +687,13 @@ class AtenXlaType {
 
   static at::Tensor& reciprocal_(at::Tensor& self);
 
+  static at::Tensor reflection_pad2d(const at::Tensor& self,
+                                     at::IntArrayRef padding);
+
+  static at::Tensor reflection_pad2d_backward(const at::Tensor& grad_output,
+                                              const at::Tensor& self,
+                                              at::IntArrayRef padding);
+
   static at::Tensor relu(const at::Tensor& self);
 
   static at::Tensor& relu_(at::Tensor& self);
diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp
index 56f7488241f..5ba8687c26a 100644
--- a/torch_xla/csrc/data_ops.cpp
+++ b/torch_xla/csrc/data_ops.cpp
@@ -30,6 +30,16 @@ bool IsSparseGather(const xla::Shape& input_shape,
   return index_elements < input_elements / dense_gather_factor;
 }
 
+std::vector<xla::int64> GetReflectionPad2dSpatialDims(xla::int64 rank) {
+  std::vector<xla::int64> spatial_dims;
+  if (rank == 3) {
+    return {2, 1};
+  } else if (rank == 4) {
+    return {3, 2};
+  }
+  XLA_ERROR() << "Invalid input shape for reflection_pad2d: rank=" << rank;
+}
+
 }  // namespace
 
 bool IsSparseGather(const xla::XlaOp& input, const xla::XlaOp& index,
@@ -348,4 +358,93 @@ xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source,
   return xla::Select(mask, padded_source, target);
 }
 
+xla::XlaOp BuildReflectionPad2d(
+    const xla::XlaOp& input,
+    tensorflow::gtl::ArraySlice<xla::int64> padding) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  std::vector<xla::int64> spatial_dims =
+      GetReflectionPad2dSpatialDims(input_shape.rank());
+
+  xla::XlaOp result = input;
+  for (xla::int64 i = 0; i < spatial_dims.size(); ++i) {
+    xla::int64 dim = spatial_dims[i];
+    xla::int64 dim_size = input_shape.dimensions(dim);
+    xla::int64 lhs_padding = padding[2 * i];
+    xla::int64 rhs_padding = padding[2 * i + 1];
+
+    XLA_CHECK(lhs_padding >= 0 && lhs_padding <= dim_size - 1);
+    XLA_CHECK(rhs_padding >= 0 && rhs_padding <= dim_size - 1);
+
+    xla::XlaOp reverse = xla::Rev(result, {dim});
+    xla::XlaOp lhs_pad = xla::SliceInDim(reverse, dim_size - 1 - lhs_padding,
+                                         dim_size - 1, 1, dim);
+    xla::XlaOp rhs_pad = xla::SliceInDim(reverse, 1, 1 + rhs_padding, 1, dim);
+    result = xla::ConcatInDim(input.builder(), {lhs_pad, result, rhs_pad}, dim);
+  }
+  return result;
+}
+
+xla::XlaOp BuildReflectionPad2dBackward(
+    const xla::XlaOp& grad_output, const xla::XlaOp& input,
+    tensorflow::gtl::ArraySlice<xla::int64> padding) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  const xla::Shape& grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output);
+  std::vector<xla::int64> spatial_dims =
+      GetReflectionPad2dSpatialDims(grad_output_shape.rank());
+
+  xla::XlaOp grad = grad_output;
+  for (xla::int64 i = 0; i < spatial_dims.size(); ++i) {
+    xla::int64 dim = spatial_dims[i];
+    xla::int64 dim_size = grad_output_shape.dimensions(dim);
+    xla::int64 lhs_padding = padding[2 * i];
+    xla::int64 rhs_padding = padding[2 * i + 1];
+
+    XLA_CHECK(lhs_padding >= 0 && lhs_padding <= dim_size - 1);
+    XLA_CHECK(rhs_padding >= 0 && rhs_padding <= dim_size - 1);
+
+    xla::XlaOp lhs_pad = xla::SliceInDim(grad, 0, lhs_padding, 1, dim);
+    xla::XlaOp reverse_lhs_pad = xla::Rev(lhs_pad, {dim});
+    xla::XlaOp padded_lhs_pad =
+        PadInDim(reverse_lhs_pad, dim,
+                 /*pad_lo=*/1,
+                 /*pad_hi=*/input_shape.dimensions(dim) - lhs_padding - 1);
+
+    xla::XlaOp rhs_pad =
+        xla::SliceInDim(grad, dim_size - rhs_padding, dim_size, 1, dim);
+    xla::XlaOp reverse_rhs_pad = xla::Rev(rhs_pad, {dim});
+    xla::XlaOp padded_rhs_pad =
+        PadInDim(reverse_rhs_pad, dim,
+                 /*pad_lo=*/input_shape.dimensions(dim) - rhs_padding - 1,
+                 /*pad_hi=*/1);
+
+    xla::XlaOp grad_core =
+        xla::SliceInDim(grad, lhs_padding, dim_size - rhs_padding, 1, dim);
+    grad = padded_lhs_pad + grad_core + padded_rhs_pad;
+  }
+  return grad;
+}
+
+xla::XlaOp PadInDim(const xla::XlaOp& input, xla::int64 dim, xla::int64 pad_lo,
+                    xla::int64 pad_hi, const xla::XlaOp* pad_value) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp zero;
+  if (pad_value == nullptr) {
+    zero = xla::Zero(input.builder(), input_shape.element_type());
+    pad_value = &zero;
+  }
+  xla::PaddingConfig padding_config;
+  for (xla::int64 i = 0; i < input_shape.rank(); ++i) {
+    auto* dims = padding_config.add_dimensions();
+    dims->set_interior_padding(0);
+    if (i == dim) {
+      dims->set_edge_padding_low(pad_lo);
+      dims->set_edge_padding_high(pad_hi);
+    } else {
+      dims->set_edge_padding_low(0);
+      dims->set_edge_padding_high(0);
+    }
+  }
+  return xla::Pad(input, *pad_value, padding_config);
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h
index dca8848cfbd..76b08228ad8 100644
--- a/torch_xla/csrc/data_ops.h
+++ b/torch_xla/csrc/data_ops.h
@@ -101,4 +101,15 @@ xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source,
                          xla::int64 dim, xla::int64 start, xla::int64 end,
                          xla::int64 stride);
 
+xla::XlaOp BuildReflectionPad2d(
+    const xla::XlaOp& input,
+    tensorflow::gtl::ArraySlice<xla::int64> padding);
+
+xla::XlaOp BuildReflectionPad2dBackward(
+    const xla::XlaOp& grad_output, const xla::XlaOp& input,
+    tensorflow::gtl::ArraySlice<xla::int64> padding);
+
+xla::XlaOp PadInDim(const xla::XlaOp& input, xla::int64 dim, xla::int64 pad_lo,
+                    xla::int64 pad_hi, const xla::XlaOp* pad_value = nullptr);
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index 3fe27eff2d0..1d19f8bb475 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -60,7 +60,7 @@ std::string GetCurrentScope() {
 
 ShapeCache* GetShapeCache() {
   static xla::int64 shape_cache_size =
-      xla::sys_util::GetEnvInt("XLA_IR_SHAPE_CACHE_SIZE", 1024);
+      xla::sys_util::GetEnvInt("XLA_IR_SHAPE_CACHE_SIZE", 4096);
   static ShapeCache* cache = new ShapeCache(shape_cache_size);
   return cache;
 }
diff --git a/torch_xla/csrc/ops/reflection_pad2d.cpp b/torch_xla/csrc/ops/reflection_pad2d.cpp
new file mode 100644
index 00000000000..8c4c704f897
--- /dev/null
+++ b/torch_xla/csrc/ops/reflection_pad2d.cpp
@@ -0,0 +1,50 @@
+#include "torch_xla/csrc/ops/reflection_pad2d.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/data_ops.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(
+    const Value& input, tensorflow::gtl::ArraySlice<xla::int64> padding) {
+  auto lower_for_shape_fn =
+      [&](tensorflow::gtl::ArraySlice<xla::XlaOp> operands)
+      -> xla::XlaOp { return BuildReflectionPad2d(operands[0], padding); };
+  return InferOutputShape({input.shape()}, lower_for_shape_fn);
+}
+
+}  // namespace
+
+ReflectionPad2d::ReflectionPad2d(const Value& input,
+                                 std::vector<xla::int64> padding)
+    : Node(OpKind(at::aten::reflection_pad2d), {input},
+           [&]() { return NodeOutputShape(input, padding); },
+           /*num_outputs=*/1, xla::util::MHash(padding)),
+      padding_(std::move(padding)) {}
+
+NodePtr ReflectionPad2d::Clone(OpList operands) const {
+  return MakeNode<ReflectionPad2d>(operands.at(0), padding_);
+}
+
+XlaOpVector ReflectionPad2d::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp output = BuildReflectionPad2d(input, padding_);
+  return ReturnOp(output, loctx);
+}
+
+std::string ReflectionPad2d::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", padding=(" << absl::StrJoin(padding_, ", ")
+     << ")";
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/reflection_pad2d.h b/torch_xla/csrc/ops/reflection_pad2d.h
new file mode 100644
index 00000000000..4328db42b40
--- /dev/null
+++ b/torch_xla/csrc/ops/reflection_pad2d.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <vector>
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class ReflectionPad2d : public Node {
+ public:
+  ReflectionPad2d(const Value& input, std::vector<xla::int64> padding);
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  const std::vector<xla::int64>& padding() const { return padding_; }
+
+ private:
+  std::vector<xla::int64> padding_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp
new file mode 100644
index 00000000000..1a23cc50a02
--- /dev/null
+++ b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp
@@ -0,0 +1,58 @@
+#include "torch_xla/csrc/ops/reflection_pad2d_backward.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/data_ops.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(
+    const Value& grad_output, const Value& input,
+    tensorflow::gtl::ArraySlice<xla::int64> padding) {
+  auto lower_for_shape_fn =
+      [&](tensorflow::gtl::ArraySlice<xla::XlaOp> operands)
+      -> xla::XlaOp {
+    return BuildReflectionPad2dBackward(operands[0], operands[1], padding);
+  };
+  return InferOutputShape({grad_output.shape(), input.shape()},
+                          lower_for_shape_fn);
+}
+
+}  // namespace
+
+ReflectionPad2dBackward::ReflectionPad2dBackward(
+    const Value& grad_output, const Value& input,
+    std::vector<xla::int64> padding)
+    : Node(OpKind(at::aten::reflection_pad2d_backward), {grad_output, input},
+           [&]() { return NodeOutputShape(grad_output, input, padding); },
+           /*num_outputs=*/1, xla::util::MHash(padding)),
+      padding_(std::move(padding)) {}
+
+NodePtr ReflectionPad2dBackward::Clone(OpList operands) const {
+  return MakeNode<ReflectionPad2dBackward>(operands.at(0), operands.at(1),
+                                           padding_);
+}
+
+XlaOpVector ReflectionPad2dBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp output =
+      BuildReflectionPad2dBackward(grad_output, input, padding_);
+  return ReturnOp(output, loctx);
+}
+
+std::string ReflectionPad2dBackward::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", padding=(" << absl::StrJoin(padding_, ", ")
+     << ")";
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/reflection_pad2d_backward.h b/torch_xla/csrc/ops/reflection_pad2d_backward.h
new file mode 100644
index 00000000000..358a99af947
--- /dev/null
+++ b/torch_xla/csrc/ops/reflection_pad2d_backward.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <vector>
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class ReflectionPad2dBackward : public Node {
+ public:
+  ReflectionPad2dBackward(const Value& grad_output, const Value& input,
+                          std::vector<xla::int64> padding);
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  const std::vector<xla::int64>& padding() const { return padding_; }
+
+ private:
+  std::vector<xla::int64> padding_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index aeaa5b5dc63..8fdd8cc6f2a 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -739,6 +739,13 @@ class XLATensor {
   static XLATensor reciprocal(const XLATensor& input);
   static void reciprocal_(XLATensor& input);
 
+  static XLATensor reflection_pad2d(const XLATensor& input,
+                                    std::vector<xla::int64> padding);
+
+  static XLATensor reflection_pad2d_backward(const XLATensor& grad_output,
+                                             const XLATensor& input,
+                                             std::vector<xla::int64> padding);
+
   static XLATensor relu(const XLATensor& input);
   static void relu_(XLATensor& input);
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 63dd2648800..506e5910140 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -76,6 +76,8 @@
 #include "torch_xla/csrc/ops/put.h"
 #include "torch_xla/csrc/ops/qr.h"
 #include "torch_xla/csrc/ops/randperm.h"
+#include "torch_xla/csrc/ops/reflection_pad2d.h"
+#include "torch_xla/csrc/ops/reflection_pad2d_backward.h"
 #include "torch_xla/csrc/ops/repeat.h"
 #include "torch_xla/csrc/ops/resize.h"
 #include "torch_xla/csrc/ops/rrelu_with_noise.h"
@@ -1782,6 +1784,19 @@ void XLATensor::reciprocal_(XLATensor& input) {
   input.SetIrValue(ir::ops::ReciprocalOp(input.GetIrValue()));
 }
 
+XLATensor XLATensor::reflection_pad2d(const XLATensor& input,
+                                      std::vector<xla::int64> padding) {
+  return input.CreateFrom(ir::MakeNode<ir::ops::ReflectionPad2d>(
+      input.GetIrValue(), std::move(padding)));
+}
+
+XLATensor XLATensor::reflection_pad2d_backward(
+    const XLATensor& grad_output, const XLATensor& input,
+    std::vector<xla::int64> padding) {
+  return input.CreateFrom(ir::MakeNode<ir::ops::ReflectionPad2dBackward>(
+      grad_output.GetIrValue(), input.GetIrValue(), std::move(padding)));
+}
+
 XLATensor XLATensor::relu(const XLATensor& input) {
   return input.CreateFrom(ir::ops::ReluOp(input.GetIrValue()));
 }
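
Note for reviewers (not part of the patch): below is a minimal, standalone
sketch of the per-dimension construction that BuildReflectionPad2d lowers to,
reduced from XLA ops on a spatial dimension to plain C++ on a 1-D vector. The
ReflectionPad1d helper is purely illustrative and does not exist in the tree;
its slice bounds mirror the Rev/SliceInDim/ConcatInDim calls above.

#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for one loop iteration of BuildReflectionPad2d: reverse the data
// (xla::Rev), slice the two mirrored edge regions out of the reversed copy
// (xla::SliceInDim), and concatenate them around the original data
// (xla::ConcatInDim).
std::vector<float> ReflectionPad1d(const std::vector<float>& in,
                                   int lhs_padding, int rhs_padding) {
  const int dim_size = static_cast<int>(in.size());
  // Same bounds the lowering enforces with XLA_CHECK.
  assert(lhs_padding >= 0 && lhs_padding <= dim_size - 1);
  assert(rhs_padding >= 0 && rhs_padding <= dim_size - 1);
  const std::vector<float> reverse(in.rbegin(), in.rend());  // xla::Rev
  std::vector<float> result;
  // lhs_pad = reverse[dim_size - 1 - lhs_padding, dim_size - 1)
  //         = in[lhs_padding], ..., in[1]
  result.insert(result.end(), reverse.begin() + (dim_size - 1 - lhs_padding),
                reverse.begin() + (dim_size - 1));
  result.insert(result.end(), in.begin(), in.end());  // the unpadded core
  // rhs_pad = reverse[1, 1 + rhs_padding)
  //         = in[dim_size - 2], ..., in[dim_size - 1 - rhs_padding]
  result.insert(result.end(), reverse.begin() + 1,
                reverse.begin() + 1 + rhs_padding);
  return result;
}

int main() {
  // {0, 1, 2, 3} with lhs=2, rhs=1 -> {2, 1, 0, 1, 2, 3, 2}, matching
  // torch::reflection_pad1d semantics (the edge value is never repeated).
  for (float v : ReflectionPad1d({0, 1, 2, 3}, 2, 1)) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}

The backward lowering is the same construction run in reverse: the lhs/rhs
edge regions of grad_output are sliced off, flipped back with Rev, re-aligned
to the input extent with zero padding via PadInDim, and summed with the
central slice, so each input element accumulates the gradient from every
output position its value was reflected into.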