From be19e1e083cc9441e05d0f3afcccfee339a76ddc Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Tue, 22 Oct 2019 18:02:10 -0700 Subject: [PATCH] Added aten::upsample_nearest2d and aten::upsample_nearest2d_backward for TPU. Going to PyTorch CPU for XLA CPU/GPU, as those have not implemented the XLA custom-call to lower them. --- test/cpp/test_aten_xla_tensor.cpp | 35 +++++++ torch_xla/csrc/aten_xla_type.cpp | 23 +++++ torch_xla/csrc/aten_xla_type.h | 7 ++ torch_xla/csrc/ops/upsample_nearest2d.cpp | 86 ++++++++++++++++ torch_xla/csrc/ops/upsample_nearest2d.h | 29 ++++++ .../csrc/ops/upsample_nearest2d_backward.cpp | 99 +++++++++++++++++++ .../csrc/ops/upsample_nearest2d_backward.h | 34 +++++++ torch_xla/csrc/tensor.h | 7 ++ torch_xla/csrc/tensor_methods.cpp | 15 +++ 9 files changed, 335 insertions(+) create mode 100644 torch_xla/csrc/ops/upsample_nearest2d.cpp create mode 100644 torch_xla/csrc/ops/upsample_nearest2d.h create mode 100644 torch_xla/csrc/ops/upsample_nearest2d_backward.cpp create mode 100644 torch_xla/csrc/ops/upsample_nearest2d_backward.h diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 757042ebe3a..7d45df3b97c 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -2590,6 +2590,41 @@ TEST_F(AtenXlaTensorTest, TestBilinear) { }); } +TEST_F(AtenXlaTensorTest, TestUpsampleNearest2D) { + int batch_size = 2; + int h = 5; + int w = 5; + int uh = 8; + int uw = 8; + int chans = 2; + torch::Tensor input = torch::rand({batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat)); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor result = torch::upsample_nearest2d(input, {uh, uw}); + torch::Tensor xla_result = torch::upsample_nearest2d(xla_input, {uh, uw}); + AllClose(result, xla_result); + }); +} + +TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DBackward) { + int batch_size = 2; + int h = 5; + int w = 5; + int uh = 8; + int uw = 8; + int chans = 2; + auto testfn = [&](const std::vector& inputs) -> torch::Tensor { + return torch::upsample_nearest2d(inputs[0], {uh, uw}); + }; + ForEachDevice([&](const torch::Device& device) { + TestBackward( + {torch::rand({batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat).requires_grad(true))}, + device, testfn); + }); +} + TEST_F(AtenXlaTensorTest, TestAddCMul) { torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat)); torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat)); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index e3cbf67d09c..e339053a90f 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3037,6 +3037,29 @@ at::Tensor& AtenXlaType::unsqueeze_(at::Tensor& self, int64_t dim) { return self; } +at::Tensor AtenXlaType::upsample_nearest2d(const at::Tensor& self, + at::IntArrayRef output_size) { + XLATensor self_tensor = bridge::GetXlaTensor(self); + if (self_tensor.GetDevice().hw_type != DeviceType::TPU) { + return AtenXlaTypeDefault::upsample_nearest2d(self, output_size); + } + return bridge::AtenFromXlaTensor(XLATensor::upsample_nearest2d( + self_tensor, xla::util::ToVector(output_size))); +} + +at::Tensor AtenXlaType::upsample_nearest2d_backward( + const at::Tensor& grad_output, at::IntArrayRef output_size, + at::IntArrayRef input_size) { + XLATensor grad_output_tensor = bridge::GetXlaTensor(grad_output); + if (grad_output_tensor.GetDevice().hw_type != DeviceType::TPU) { + return AtenXlaTypeDefault::upsample_nearest2d_backward( + grad_output, output_size, input_size); + } + return bridge::AtenFromXlaTensor(XLATensor::upsample_nearest2d_backward( + grad_output_tensor, xla::util::ToVector(output_size), + xla::util::ToVector(input_size))); +} + at::Tensor AtenXlaType::view(const at::Tensor& self, at::IntArrayRef size) { return bridge::AtenFromXlaTensor( XLATensor::view(bridge::GetXlaTensor(self), XlaHelpers::I64List(size))); diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index c0c854ffe25..54b18aacddd 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -1124,6 +1124,13 @@ class AtenXlaType { static at::Tensor& unsqueeze_(at::Tensor& self, int64_t dim); + static at::Tensor upsample_nearest2d(const at::Tensor& self, + at::IntArrayRef output_size); + + static at::Tensor upsample_nearest2d_backward(const at::Tensor& grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size); + static at::Tensor view(const at::Tensor& self, at::IntArrayRef size); static at::Tensor view_as(const at::Tensor& self, const at::Tensor& other); diff --git a/torch_xla/csrc/ops/upsample_nearest2d.cpp b/torch_xla/csrc/ops/upsample_nearest2d.cpp new file mode 100644 index 00000000000..abf9fd70c1c --- /dev/null +++ b/torch_xla/csrc/ops/upsample_nearest2d.cpp @@ -0,0 +1,86 @@ +#include "torch_xla/csrc/ops/upsample_nearest2d.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape( + const Value& input, + tensorflow::gtl::ArraySlice output_size) { + XLA_CHECK_EQ(output_size.size(), 2); + const xla::Shape& input_shape = input.shape(); + return xla::ShapeUtil::MakeShape( + input_shape.element_type(), + {input_shape.dimensions(0), input_shape.dimensions(1), output_size[0], + output_size[1]}); +} + +std::string GetBackendConfig(bool align_corners, bool half_pixel_centers) { + return absl::StrCat("\"", align_corners, half_pixel_centers, "\""); +} + +xla::XlaOp LowerUpsampleNearest(const xla::XlaOp& input, + const xla::Shape& output_shape) { + xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); + if (input_shape.dimensions(2) == output_shape.dimensions(2) && + input_shape.dimensions(3) == output_shape.dimensions(3)) { + return input; + } + if (input_shape.dimensions(2) == 1 && input_shape.dimensions(3) == 1) { + return input + xla::Zeros(input.builder(), output_shape); + } + // XLA wants NHWC while PyTorch comes in as NCHW, so we need to transpose, + // call the kernel, and transpose back. + std::vector transpose_permute({0, 3, 2, 1}); + auto inv_transpose_permute = xla::InversePermutation(transpose_permute); + xla::Shape resized_shape = + xla::ShapeUtil::PermuteDimensions(inv_transpose_permute, output_shape); + xla::XlaOp tinput = xla::Transpose(input, transpose_permute); + xla::XlaOp resised = xla::CustomCall( + input.builder(), "ResizeNearest", {tinput}, resized_shape, + GetBackendConfig(/*align_corners=*/false, /*half_pixel_centers=*/false)); + return xla::Transpose(resised, inv_transpose_permute); +} + +} // namespace + +UpsampleNearest::UpsampleNearest(const Value& input, + std::vector output_size) + : Node(ir::OpKind(at::aten::upsample_nearest2d), {input}, + [&]() { return NodeOutputShape(input, output_size); }, + /*num_outputs=*/1, xla::util::MHash(output_size)), + output_size_(std::move(output_size)) {} + +NodePtr UpsampleNearest::Clone(OpList operands) const { + return MakeNode(operands.at(0), output_size_); +} + +XlaOpVector UpsampleNearest::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::XlaOp output = LowerUpsampleNearest(input, shape()); + return ReturnOp(output, loctx); +} + +std::string UpsampleNearest::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", output_size=(" + << absl::StrJoin(output_size_, ", ") << ")"; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/upsample_nearest2d.h b/torch_xla/csrc/ops/upsample_nearest2d.h new file mode 100644 index 00000000000..62cc383d7a7 --- /dev/null +++ b/torch_xla/csrc/ops/upsample_nearest2d.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class UpsampleNearest : public Node { + public: + UpsampleNearest(const Value& input, std::vector output_size); + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + const std::vector& output_size() const { return output_size_; } + + private: + std::vector output_size_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp new file mode 100644 index 00000000000..e9e03869117 --- /dev/null +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp @@ -0,0 +1,99 @@ +#include "torch_xla/csrc/ops/upsample_nearest2d_backward.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "tensorflow/compiler/xla/xla_client/sys_util.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape( + const Value& input, + tensorflow::gtl::ArraySlice input_size) { + return xla::ShapeUtil::MakeShape(input.shape().element_type(), input_size); +} + +std::string GetBackendConfig(bool align_corners, bool half_pixel_centers) { + return absl::StrCat("\"", align_corners, half_pixel_centers, "\""); +} + +double ResizeFactor(const xla::Shape& input_shape, + const xla::Shape& output_shape, int dim) { + return static_cast(input_shape.dimensions(dim)) / + static_cast(output_shape.dimensions(dim)); +} + +xla::XlaOp LowerUpsampleNearestBackward(const xla::XlaOp& input, + const xla::Shape& output_shape) { + static double resiple_split_factor = + xla::sys_util::GetEnvDouble("XLA_RESIZE_SPLIT_FACTOR", 3.0); + xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); + if (input_shape.dimensions(2) == output_shape.dimensions(2) && + input_shape.dimensions(3) == output_shape.dimensions(3)) { + return input; + } + // XLA wants NHWC while PyTorch comes in as NCHW, so we need to transpose, + // call the kernel, and transpose back. + std::vector transpose_permute({0, 3, 2, 1}); + auto inv_transpose_permute = xla::InversePermutation(transpose_permute); + xla::Shape resized_shape = + xla::ShapeUtil::PermuteDimensions(inv_transpose_permute, output_shape); + xla::XlaOp tinput = xla::Transpose(input, transpose_permute); + std::string backend_config = + GetBackendConfig(/*align_corners=*/false, /*half_pixel_centers=*/false); + if (ResizeFactor(input_shape, output_shape, 2) > resiple_split_factor && + ResizeFactor(input_shape, output_shape, 3) > resiple_split_factor) { + // If the resize is too large, do one dimension at a time. + xla::Shape partial_shape = resized_shape; + // Partial shape is in NHWC, while input shape is in NCHW. + partial_shape.mutable_dimensions()[1] = input_shape.dimensions(2); + tinput = xla::CustomCall(input.builder(), "ResizeNearestGrad", {tinput}, + partial_shape, backend_config); + } + xla::XlaOp resised = xla::CustomCall(input.builder(), "ResizeNearestGrad", + {tinput}, resized_shape, backend_config); + return xla::Transpose(resised, inv_transpose_permute); +} + +} // namespace + +UpsampleNearestBackward::UpsampleNearestBackward( + const Value& input, std::vector output_size, + std::vector input_size) + : Node(ir::OpKind(at::aten::upsample_nearest2d_backward), {input}, + [&]() { return NodeOutputShape(input, input_size); }, + /*num_outputs=*/1, xla::util::MHash(output_size, input_size)), + output_size_(std::move(output_size)), + input_size_(std::move(input_size)) {} + +NodePtr UpsampleNearestBackward::Clone(OpList operands) const { + return MakeNode(operands.at(0), output_size_, + input_size_); +} + +XlaOpVector UpsampleNearestBackward::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::XlaOp output = LowerUpsampleNearestBackward(input, shape()); + return ReturnOp(output, loctx); +} + +std::string UpsampleNearestBackward::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", output_size=(" + << absl::StrJoin(output_size_, ", ") << "), input_size=(" + << absl::StrJoin(input_size_, ", ") << ")"; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.h b/torch_xla/csrc/ops/upsample_nearest2d_backward.h new file mode 100644 index 00000000000..f9b06ba0d0e --- /dev/null +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class UpsampleNearestBackward : public Node { + public: + UpsampleNearestBackward(const Value& input, + std::vector output_size, + std::vector input_size); + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + const std::vector& output_size() const { return output_size_; } + + const std::vector& input_size() const { return input_size_; } + + private: + std::vector output_size_; + std::vector input_size_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 5a6e8b0fbdf..b07402f392e 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -926,6 +926,13 @@ class XLATensor { // In-place version of the method above. static void unsqueeze_(XLATensor& input, xla::int64 dim); + static XLATensor upsample_nearest2d(const XLATensor& input, + std::vector output_size); + + static XLATensor upsample_nearest2d_backward( + const XLATensor& grad_output, std::vector output_size, + std::vector input_size); + // Like reshape, but it returns a view into the original tensor. static XLATensor view( const XLATensor& input, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 240f5b7bfe7..88bc402ee23 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -98,6 +98,8 @@ #include "torch_xla/csrc/ops/tril.h" #include "torch_xla/csrc/ops/triu.h" #include "torch_xla/csrc/ops/unsqueeze.h" +#include "torch_xla/csrc/ops/upsample_nearest2d.h" +#include "torch_xla/csrc/ops/upsample_nearest2d_backward.h" #include "torch_xla/csrc/ops/view.h" #include "torch_xla/csrc/tensor.h" #include "torch_xla/csrc/tensor_ops.h" @@ -2325,6 +2327,19 @@ void XLATensor::unsqueeze_(XLATensor& input, xla::int64 dim) { ir::MakeNode(input.GetIrValue(), squeeze_dim)); } +XLATensor XLATensor::upsample_nearest2d(const XLATensor& input, + std::vector output_size) { + return input.CreateFrom(ir::MakeNode( + input.GetIrValue(), std::move(output_size))); +} + +XLATensor XLATensor::upsample_nearest2d_backward( + const XLATensor& grad_output, std::vector output_size, + std::vector input_size) { + return grad_output.CreateFrom(ir::MakeNode( + grad_output.GetIrValue(), std::move(output_size), std::move(input_size))); +} + XLATensor XLATensor::view( const XLATensor& input, tensorflow::gtl::ArraySlice output_size) {