// Patch overview: wires max_unpool2d / max_unpool3d and their backward
// variants into the XLA backend (AtenXlaType entry points -> XLATensor methods
// -> MaxUnpoolNd / MaxUnpoolNdBackward IR nodes -> BuildMaxUnpoolNd* lowering),
// adds the XlaHelpers::FlattenDimRange helper they use, and adds C++ tests.
// Commentary lines starting with "//" at column 0 sit between per-file diffs,
// outside every hunk; they are not part of any file's content.
// NOTE(review): in this copy several template argument lists look missing
// (e.g. "std::vector output_size", "absl::Span output_size", "ir::MakeNode(")
// -- confirm against the original commit before applying.
//
// test/cpp/test_aten_xla_tensor.cpp: adds TestMaxUnpool2D / TestMaxUnpool3D and
// their *Backward counterparts. Each test builds (output, indices) with
// torch::max_pool{2,3}d_with_indices over stride 1..2, padding 0..1, both
// ceil_mode values and dilation 1..2. The forward tests compare
// torch::max_unpool{2,3}d on the original tensors with the same call on copies
// made by CopyToDevice (AllClose); the backward tests hand the call to
// TestBackward. Every test then expects no "aten::.*" counter change and a
// change of the matching "xla::max_unpool*" counter.
// NOTE(review): the in-test comments say "through the CPU interop", but the
// pooling call is made on `input` before any CopyToDevice -- confirm wording.
diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 7ae27192eca..53f0a0b33b1 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -6662,6 +6662,87 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatch) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestMaxUnpool2D) {
+  int kernel_size = 2;
+  torch::Tensor input =
+      torch::rand({2, 2, 8, 8}, torch::TensorOptions(torch::kFloat));
+  for (int stride = 1; stride <= 2; ++stride) {
+    for (int padding = 0; padding <= 1; ++padding) {
+      // Test ceil_mode=true through the CPU interop.
+      for (bool ceil_mode : {false, true}) {
+        // Test dilation through the CPU interop.
+        for (int dilation = 1; dilation <= 2; ++dilation) {
+          torch::Tensor output;
+          torch::Tensor indices;
+          std::tie(output, indices) = torch::max_pool2d_with_indices(
+              input, /*kernel_size=*/{kernel_size, kernel_size},
+              /*stride=*/{stride, stride},
+              /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation},
+              /*ceil_mode=*/ceil_mode);
+
+          std::vector output_size({input.size(2), input.size(3)});
+          at::Tensor utensor =
+              torch::max_unpool2d(output, indices, output_size);
+
+          ForEachDevice([&](const torch::Device& device) {
+            torch::Tensor xla_output = CopyToDevice(output, device);
+            torch::Tensor xla_indices = CopyToDevice(indices, device);
+            at::Tensor xla_utensor =
+                torch::max_unpool2d(xla_output, xla_indices, output_size);
+            AllClose(utensor, xla_utensor);
+          });
+        }
+      }
+    }
+  }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::max_unpool2d", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestMaxUnpool3D) {
+  int kernel_size = 2;
+  torch::Tensor input =
+      torch::rand({2, 2, 8, 8, 8}, torch::TensorOptions(torch::kFloat));
+  for (int stride = 1; stride <= 2; ++stride) {
+    for (int padding = 0; padding <= 1; ++padding) {
+      // Test ceil_mode=true through the CPU interop.
+      for (bool ceil_mode : {false, true}) {
+        // Test dilation through the CPU interop.
+        for (int dilation = 1; dilation <= 2; ++dilation) {
+          torch::Tensor output;
+          torch::Tensor indices;
+          std::tie(output, indices) = torch::max_pool3d_with_indices(
+              input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
+              /*stride=*/{stride, stride, stride},
+              /*padding=*/{padding, padding, padding},
+              /*dilation=*/{dilation, dilation, dilation},
+              /*ceil_mode=*/ceil_mode);
+
+          std::vector output_size(
+              {input.size(2), input.size(3), input.size(4)});
+          at::Tensor utensor = torch::max_unpool3d(
+              output, indices, output_size, /*stride=*/{stride, stride, stride},
+              /*padding=*/{padding, padding, padding});
+
+          ForEachDevice([&](const torch::Device& device) {
+            torch::Tensor xla_output = CopyToDevice(output, device);
+            torch::Tensor xla_indices = CopyToDevice(indices, device);
+            at::Tensor xla_utensor =
+                torch::max_unpool3d(xla_output, xla_indices, output_size,
+                                    /*stride=*/{stride, stride, stride},
+                                    /*padding=*/{padding, padding, padding});
+            AllClose(utensor, xla_utensor);
+          });
+        }
+      }
+    }
+  }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::max_unpool3d", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestNllLoss) {
   int batch = 6;
   int classes = 2;
@@ -8727,6 +8808,84 @@ TEST_F(AtenXlaTensorTest, TestMaxPool3DNoBatchBackward) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestMaxUnpool2DBackward) {
+  int kernel_size = 2;
+  torch::Tensor input =
+      torch::rand({2, 2, 8, 8}, torch::TensorOptions(torch::kFloat));
+  for (int stride = 1; stride <= 2; ++stride) {
+    for (int padding = 0; padding <= 1; ++padding) {
+      // Test ceil_mode=true through the CPU interop.
+      for (bool ceil_mode : {false, true}) {
+        for (int dilation = 1; dilation <= 2; ++dilation) {
+          torch::Tensor output;
+          torch::Tensor indices;
+          std::tie(output, indices) = torch::max_pool2d_with_indices(
+              input, /*kernel_size=*/{kernel_size, kernel_size},
+              /*stride=*/{stride, stride},
+              /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation},
+              /*ceil_mode=*/ceil_mode);
+
+          std::vector output_size({input.size(2), input.size(3)});
+          auto testfn =
+              [&](const std::vector& inputs) -> torch::Tensor {
+            return torch::max_unpool2d(inputs[0], inputs[1], output_size);
+          };
+
+          ForEachDevice([&](const torch::Device& device) {
+            TestBackward({output.requires_grad_(true), indices}, device,
+                         testfn);
+          });
+        }
+      }
+    }
+  }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::max_unpool2d_backward",
+                       cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestMaxUnpool3DBackward) {
+  int kernel_size = 2;
+  torch::Tensor input =
+      torch::rand({2, 2, 8, 8, 8}, torch::TensorOptions(torch::kFloat));
+  for (int stride = 1; stride <= 2; ++stride) {
+    for (int padding = 0; padding <= 1; ++padding) {
+      // Test ceil_mode=true through the CPU interop.
+      for (bool ceil_mode : {false, true}) {
+        for (int dilation = 1; dilation <= 2; ++dilation) {
+          torch::Tensor output;
+          torch::Tensor indices;
+          std::tie(output, indices) = torch::max_pool3d_with_indices(
+              input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
+              /*stride=*/{stride, stride, stride},
+              /*padding=*/{padding, padding, padding},
+              /*dilation=*/{dilation, dilation, dilation},
+              /*ceil_mode=*/ceil_mode);
+
+          std::vector output_size(
+              {input.size(2), input.size(3), input.size(4)});
+          auto testfn =
+              [&](const std::vector& inputs) -> torch::Tensor {
+            return torch::max_unpool3d(inputs[0], inputs[1], output_size,
+                                       /*stride=*/{stride, stride, stride},
+                                       /*padding=*/{padding, padding, padding});
+          };
+
+          ForEachDevice([&](const torch::Device& device) {
+            TestBackward({output.requires_grad_(true), indices}, device,
+                         testfn);
+          });
+        }
+      }
+    }
+  }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::max_unpool3d_backward",
+                       cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestTanhBackward) {
   auto testfn = [&](const std::vector& inputs) -> torch::Tensor {
     return torch::tanh(inputs[0]);
// torch_xla/csrc/aten_xla_type.cpp: adds the four AtenXlaType entry points.
// Each unwraps its tensors with bridge::GetXlaTensor, forwards to
// XLATensor::max_unpool or XLATensor::max_unpool_backward, and wraps the result
// with bridge::AtenFromXlaTensor. The 3d overloads take `stride` and `padding`
// but pass only `output_size` on.
// NOTE(review): confirm that ignoring stride/padding in the 3d path is intended.
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index a2dec8c6641..986a49ae408 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1906,6 +1906,50 @@ std::tuple AtenXlaType::max_pool3d_with_indices(
                          bridge::AtenFromXlaTensor(std::get<1>(outputs)));
 }
 
+at::Tensor AtenXlaType::max_unpool2d(const at::Tensor& self,
+                                     const at::Tensor& indices,
+                                     at::IntArrayRef output_size) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::max_unpool(
+      bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices),
+      xla::util::ToVector(output_size)));
+}
+
+at::Tensor AtenXlaType::max_unpool2d_backward(const at::Tensor& grad_output,
+                                              const at::Tensor& self,
+                                              const at::Tensor& indices,
+                                              at::IntArrayRef output_size) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::max_unpool_backward(
+      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
+      bridge::GetXlaTensor(indices),
+      xla::util::ToVector(output_size)));
+}
+
+at::Tensor AtenXlaType::max_unpool3d(const at::Tensor& self,
+                                     const at::Tensor& indices,
+                                     at::IntArrayRef output_size,
+                                     at::IntArrayRef stride,
+                                     at::IntArrayRef padding) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::max_unpool(
+      bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices),
+      xla::util::ToVector(output_size)));
+}
+
+at::Tensor AtenXlaType::max_unpool3d_backward(const at::Tensor& grad_output,
+                                              const at::Tensor& self,
+                                              const at::Tensor& indices,
+                                              at::IntArrayRef output_size,
+                                              at::IntArrayRef stride,
+                                              at::IntArrayRef padding) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::AtenFromXlaTensor(XLATensor::max_unpool_backward(
+      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
+      bridge::GetXlaTensor(indices),
+      xla::util::ToVector(output_size)));
+}
+
 at::Tensor AtenXlaType::mean(const at::Tensor& self,
                              c10::optional dtype) {
   XLA_FN_COUNTER("xla::");
// torch_xla/csrc/aten_xla_type.h: declares the four static entry points defined
// in aten_xla_type.cpp (2d forward/backward, 3d forward/backward).
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 3ce54284828..a852f905210 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -578,6 +578,28 @@ class AtenXlaType {
       at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
       const at::Tensor& indices);
 
+  static at::Tensor max_unpool2d(const at::Tensor& self,
+                                 const at::Tensor& indices,
+                                 at::IntArrayRef output_size);
+
+  static at::Tensor max_unpool2d_backward(const at::Tensor& grad_output,
+                                          const at::Tensor& self,
+                                          const at::Tensor& indices,
+                                          at::IntArrayRef output_size);
+
+  static at::Tensor max_unpool3d(const at::Tensor& self,
+                                 const at::Tensor& indices,
+                                 at::IntArrayRef output_size,
+                                 at::IntArrayRef stride,
+                                 at::IntArrayRef padding);
+
+  static at::Tensor max_unpool3d_backward(const at::Tensor& grad_output,
+                                          const at::Tensor& self,
+                                          const at::Tensor& indices,
+                                          at::IntArrayRef output_size,
+                                          at::IntArrayRef stride,
+                                          at::IntArrayRef padding);
+
   static at::Tensor mean(const at::Tensor& self,
                          c10::optional dtype);
 
// torch_xla/csrc/helpers.cpp: adds XlaHelpers::FlattenDimRange. It walks the
// operand's dimensions in order: dimensions with index in
// [start, start + range) are multiplied into one running size (flat_size, -1
// meaning "nothing accumulated"), every other dimension is copied as-is, and
// the accumulated size is emitted when the next outside dimension is reached or
// at the end. The resulting sizes are applied with DynamicReshape. The operand
// shape is written through `input_shape` via xla::util::MaybePtr (presumably a
// scratch value is used when the pointer is null -- confirm).
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 73fccc2906a..aa1df91372f 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -386,6 +386,32 @@ xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* input_shape) {
   return DynamicReshape(input, {input_elements});
 }
 
+xla::XlaOp XlaHelpers::FlattenDimRange(xla::XlaOp input, xla::int64 start,
+                                       xla::int64 range,
+                                       xla::Shape* input_shape) {
+  xla::util::MaybePtr input_shape_tmp(input_shape);
+  *input_shape_tmp = ShapeOfXlaOp(input);
+
+  std::vector sizes;
+  xla::int64 flat_size = -1;
+  for (xla::int64 dim = 0; dim < input_shape_tmp->rank(); ++dim) {
+    if (dim < start || dim >= start + range) {
+      if (flat_size >= 0) {
+        sizes.push_back(flat_size);
+        flat_size = -1;
+      }
+      sizes.push_back(input_shape_tmp->dimensions(dim));
+    } else {
+      flat_size =
+          (flat_size < 0 ? 1 : flat_size) * input_shape_tmp->dimensions(dim);
+    }
+  }
+  if (flat_size >= 0) {
+    sizes.push_back(flat_size);
+  }
+  return DynamicReshape(input, sizes);
+}
+
 std::vector XlaHelpers::MakeTransposePermutation(xla::int64 dim0,
                                                  xla::int64 dim1,
                                                  xla::int64 rank) {
// torch_xla/csrc/helpers.h: declares FlattenDimRange next to Flatten;
// `input_shape` defaults to nullptr.
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index 902d8e68260..850b7be3e4a 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -229,6 +229,10 @@ class XlaHelpers {
   static xla::XlaOp Flatten(xla::XlaOp input,
                             xla::Shape* input_shape = nullptr);
 
+  static xla::XlaOp FlattenDimRange(xla::XlaOp input, xla::int64 start,
+                                    xla::int64 range,
+                                    xla::Shape* input_shape = nullptr);
+
   // Gathers the input using the order specified by the permutation. For each i,
   // output[i] = input[permutation[i]]. The given permutation must be the same
   // size as the input.
// torch_xla/csrc/ops/max_unpool_nd.cpp (new file): IR node for the forward op.
// The op kind is picked from output_size.size() (2 -> at::aten::max_unpool2d,
// 3 -> at::aten::max_unpool3d, anything else -> XLA_ERROR). The node's shape is
// obtained by running BuildMaxUnpoolNd through InferOutputShape, output_size is
// folded into the node hash via xla::util::MHash, and Lower() emits
// BuildMaxUnpoolNd on the two operands using the lowering context's device.
// NOTE(review): ToString() uses absl::StrJoin and std::stringstream, but no
// header for either is included directly here -- presumably they arrive
// transitively; confirm.
diff --git a/torch_xla/csrc/ops/max_unpool_nd.cpp b/torch_xla/csrc/ops/max_unpool_nd.cpp
new file mode 100644
index 00000000000..5ec814884cc
--- /dev/null
+++ b/torch_xla/csrc/ops/max_unpool_nd.cpp
@@ -0,0 +1,65 @@
+#include "torch_xla/csrc/ops/max_unpool_nd.h"
+
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/pooling.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(const Value& input, const Value& indices,
+                           absl::Span output_size) {
+  auto shape_fn = [&](absl::Span operands) -> xla::XlaOp {
+    return BuildMaxUnpoolNd(GetCurrentDevice(), operands[0], operands[1],
+                            output_size);
+  };
+  return InferOutputShape({input.shape(), indices.shape()}, shape_fn);
+}
+
+c10::Symbol MaxUnpoolNdSymbol(xla::int64 spatial_dim_count) {
+  switch (spatial_dim_count) {
+    case 2:
+      return at::aten::max_unpool2d;
+    case 3:
+      return at::aten::max_unpool3d;
+    default:
+      XLA_ERROR() << "Invalid number of spatial dimensions: "
+                  << spatial_dim_count;
+  }
+}
+
+}  // namespace
+
+MaxUnpoolNd::MaxUnpoolNd(const Value& input, const Value& indices,
+                         std::vector output_size)
+    : Node(ir::OpKind(MaxUnpoolNdSymbol(output_size.size())), {input, indices},
+           [&]() { return NodeOutputShape(input, indices, output_size); },
+           /*num_outputs=*/1, xla::util::MHash(output_size)),
+      output_size_(std::move(output_size)) {}
+
+NodePtr MaxUnpoolNd::Clone(OpList operands) const {
+  return MakeNode(operands.at(0), operands.at(1), output_size_);
+}
+
+XlaOpVector MaxUnpoolNd::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp indices = loctx->GetOutputOp(operand(1));
+  xla::XlaOp output =
+      BuildMaxUnpoolNd(loctx->device(), input, indices, output_size_);
+  return ReturnOp(output, loctx);
+}
+
+std::string MaxUnpoolNd::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", output_size=("
+     << absl::StrJoin(output_size_, ", ") << ")";
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
// torch_xla/csrc/ops/max_unpool_nd.h (new file): declares MaxUnpoolNd, a Node
// taking (input, indices) and storing output_size_, with Clone / Lower /
// ToString overrides and an output_size() accessor.
diff --git a/torch_xla/csrc/ops/max_unpool_nd.h b/torch_xla/csrc/ops/max_unpool_nd.h
new file mode 100644
index 00000000000..50bea7f5903
--- /dev/null
+++ b/torch_xla/csrc/ops/max_unpool_nd.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class MaxUnpoolNd : public Node {
+ public:
+  MaxUnpoolNd(const Value& input, const Value& indices,
+              std::vector output_size);
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  const std::vector& output_size() const { return output_size_; }
+
+ private:
+  std::vector output_size_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
// torch_xla/csrc/ops/max_unpool_nd_backward.cpp (new file): IR node for the
// backward op, structured like the forward node but with three operands
// (grad_output, input, indices). The op kind is at::aten::max_unpool2d_backward
// or at::aten::max_unpool3d_backward depending on output_size.size(); shape
// inference and Lower() both go through BuildMaxUnpoolNdBackward.
// NOTE(review): same absl::StrJoin / std::stringstream include question as in
// max_unpool_nd.cpp -- confirm.
diff --git a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp
new file mode 100644
index 00000000000..c19d411e64a
--- /dev/null
+++ b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp
@@ -0,0 +1,74 @@
+#include "torch_xla/csrc/ops/max_unpool_nd_backward.h"
+
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/pooling.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(const Value& grad_output, const Value& input,
+                           const Value& indices,
+                           absl::Span output_size) {
+  auto shape_fn = [&](absl::Span operands) -> xla::XlaOp {
+    return BuildMaxUnpoolNdBackward(operands[0], operands[1], operands[2],
+                                    output_size);
+  };
+  return InferOutputShape({grad_output.shape(), input.shape(), indices.shape()},
+                          shape_fn);
+}
+
+c10::Symbol MaxUnpoolNdBackwardSymbol(xla::int64 spatial_dim_count) {
+  switch (spatial_dim_count) {
+    case 2:
+      return at::aten::max_unpool2d_backward;
+    case 3:
+      return at::aten::max_unpool3d_backward;
+    default:
+      XLA_ERROR() << "Invalid number of spatial dimensions: "
+                  << spatial_dim_count;
+  }
+}
+
+}  // namespace
+
+MaxUnpoolNdBackward::MaxUnpoolNdBackward(const Value& grad_output,
+                                         const Value& input,
+                                         const Value& indices,
+                                         std::vector output_size)
+    : Node(ir::OpKind(MaxUnpoolNdBackwardSymbol(output_size.size())),
+           {grad_output, input, indices},
+           [&]() {
+             return NodeOutputShape(grad_output, input, indices, output_size);
+           },
+           /*num_outputs=*/1, xla::util::MHash(output_size)),
+      output_size_(std::move(output_size)) {}
+
+NodePtr MaxUnpoolNdBackward::Clone(OpList operands) const {
+  return MakeNode(operands.at(0), operands.at(1),
+                  operands.at(2), output_size_);
+}
+
+XlaOpVector MaxUnpoolNdBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp indices = loctx->GetOutputOp(operand(2));
+  xla::XlaOp output =
+      BuildMaxUnpoolNdBackward(grad_output, input, indices, output_size_);
+  return ReturnOp(output, loctx);
+}
+
+std::string MaxUnpoolNdBackward::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", output_size=("
+     << absl::StrJoin(output_size_, ", ") << ")";
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
// torch_xla/csrc/ops/max_unpool_nd_backward.h (new file): declares
// MaxUnpoolNdBackward, a Node taking (grad_output, input, indices) and storing
// output_size_, with the same overrides and accessor as MaxUnpoolNd.
diff --git a/torch_xla/csrc/ops/max_unpool_nd_backward.h b/torch_xla/csrc/ops/max_unpool_nd_backward.h
new file mode 100644
index 00000000000..df9874fcd85
--- /dev/null
+++ b/torch_xla/csrc/ops/max_unpool_nd_backward.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+class MaxUnpoolNdBackward : public Node {
+ public:
+  MaxUnpoolNdBackward(const Value& grad_output, const Value& input,
+                      const Value& indices,
+                      std::vector output_size);
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  const std::vector& output_size() const { return output_size_; }
+
+ private:
+  std::vector output_size_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
// torch_xla/csrc/pooling.cpp: three hunks.
//  1. Adds includes for xla/client/lib/slicing.h and xla_lower_util.h.
//  2. Drops the `padded_input_shape` local from ComputeMaxPoolIndices.
//  3. Adds the two lowerings:
//     - BuildMaxUnpoolNd checks that the input rank equals
//       2 + output_size.size(), creates zeros of shape
//       [dim0, dim1, product(output_size)], flattens dimensions 2.. of `input`
//       and `indices` with FlattenDimRange, scatters along dim 2 with a
//       combiner that returns its second argument, then reshapes to
//       [dim0, dim1, output_size...].
//     - BuildMaxUnpoolNdBackward flattens dimensions 2.. of `grad_output` and
//       `indices`, gathers along dim 2 with xla::TorchGather, and reshapes the
//       result to the shape of `input`.
// NOTE(review): the backward builder has no rank check matching the forward
// one -- confirm whether one is wanted.
diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp
index b5b21c256b6..3df17779848 100644
--- a/torch_xla/csrc/pooling.cpp
+++ b/torch_xla/csrc/pooling.cpp
@@ -4,11 +4,13 @@
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/pooling.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
 #include "tensorflow/compiler/xla/xla_client/util.h"
 #include "torch_xla/csrc/data_ops.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/xla_lower_util.h"
 
 namespace torch_xla {
 namespace {
@@ -256,7 +258,6 @@ xla::XlaOp ComputeMaxPoolIndices(
   // We loop through every window and compute the index. The slow code will only
   // be executed if the caller actually uses the indices, and only if the reduce
   // windows overlap.
-  const xla::Shape& padded_input_shape = XlaHelpers::ShapeOfXlaOp(padded_input);
   xla::XlaOp iota = CreatePoolIndicesIota(input_shape, padded_input.builder());
   xla::XlaOp padded_iota =
       xla::Pad(iota, xla::MaxValue(padded_input.builder(), kIndicesType),
@@ -420,6 +421,49 @@ xla::XlaOp BuildMaxPoolNdBackward(xla::XlaOp out_backprop, xla::XlaOp input,
       /*spatial_dim_count=*/spatial_dim_count);
 }
 
+xla::XlaOp BuildMaxUnpoolNd(const Device& device, xla::XlaOp input,
+                            xla::XlaOp indices,
+                            absl::Span output_size) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  XLA_CHECK_EQ(input_shape.rank(), 2 + output_size.size());
+
+  xla::Shape zeros_shape = xla::ShapeUtil::MakeShape(
+      input_shape.element_type(),
+      {input_shape.dimensions(0), input_shape.dimensions(1),
+       xla::util::Multiply(output_size)});
+  xla::XlaOp zeros = xla::Zeros(input.builder(), zeros_shape);
+  xla::XlaOp flat_input =
+      XlaHelpers::FlattenDimRange(input, 2, output_size.size());
+  xla::XlaOp flat_indices =
+      XlaHelpers::FlattenDimRange(indices, 2, output_size.size());
+
+  auto combiner_fn = [](xla::XlaOp x, xla::XlaOp y) -> xla::XlaOp { return y; };
+  xla::XlaOp scatter_result =
+      CreateScatter(device, zeros, flat_indices, flat_input,
+                    /*dim=*/2, combiner_fn);
+
+  std::vector result_sizes(
+      {input_shape.dimensions(0), input_shape.dimensions(1)});
+  result_sizes.insert(result_sizes.end(), output_size.begin(),
+                      output_size.end());
+  return XlaHelpers::DynamicReshape(scatter_result, result_sizes);
+}
+
+xla::XlaOp BuildMaxUnpoolNdBackward(xla::XlaOp grad_output, xla::XlaOp input,
+                                    xla::XlaOp indices,
+                                    absl::Span output_size) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp flat_grad_output =
+      XlaHelpers::FlattenDimRange(grad_output, 2, output_size.size());
+  xla::XlaOp flat_indices =
+      XlaHelpers::FlattenDimRange(indices, 2, output_size.size());
+  xla::XlaOp gather_result = xla::TorchGather(
+      flat_grad_output, flat_indices, /*dim=*/2,
+      IsSparseGather(flat_grad_output, flat_indices, /*dim=*/2));
+
+  return XlaHelpers::DynamicReshapeAs(gather_result, input_shape);
+}
+
 xla::XlaOp BuildAvgPoolNd(xla::XlaOp input, xla::int64 spatial_dim_count,
                           absl::Span kernel_size,
                           absl::Span stride,
// torch_xla/csrc/pooling.h: includes torch_xla/csrc/device.h (BuildMaxUnpoolNd
// takes a `const Device&`) and declares the two new builders.
diff --git a/torch_xla/csrc/pooling.h b/torch_xla/csrc/pooling.h
index f309197c1a4..885fc90965e 100644
--- a/torch_xla/csrc/pooling.h
+++ b/torch_xla/csrc/pooling.h
@@ -2,6 +2,7 @@
 
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "torch_xla/csrc/device.h"
 
 namespace torch_xla {
 
@@ -40,6 +41,14 @@ xla::XlaOp BuildAvgPoolNdBackward(xla::XlaOp out_backprop, xla::XlaOp input,
                                   absl::Span padding,
                                   bool ceil_mode, bool count_include_pad);
 
+xla::XlaOp BuildMaxUnpoolNd(const Device& device, xla::XlaOp input,
+                            xla::XlaOp indices,
+                            absl::Span output_size);
+
+xla::XlaOp BuildMaxUnpoolNdBackward(xla::XlaOp grad_output, xla::XlaOp input,
+                                    xla::XlaOp indices,
+                                    absl::Span output_size);
+
 // Computes adaptive average pooling for the given input and output size.
 xla::XlaOp BuildAdaptiveAvgPool2d(xla::XlaOp input,
                                   absl::Span output_size);
// torch_xla/csrc/tensor.h: declares XLATensor::max_unpool and
// XLATensor::max_unpool_backward; both take output_size by value.
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 60b1767e6da..5e95983d66a 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -713,6 +713,14 @@ class XLATensor {
                                         std::vector padding,
                                         bool ceil_mode);
 
+  static XLATensor max_unpool(const XLATensor& input, const XLATensor& indices,
+                              std::vector output_size);
+
+  static XLATensor max_unpool_backward(const XLATensor& grad_output,
+                                       const XLATensor& input,
+                                       const XLATensor& indices,
+                                       std::vector output_size);
+
   static XLATensor mean(const XLATensor& input,
                         std::vector dimensions,
                         bool keep_reduced_dimensions,
// torch_xla/csrc/tensor_methods.cpp: includes the two new op headers and
// defines the XLATensor methods. Each builds its node with ir::MakeNode from
// the operands' IR values plus the moved output_size; the forward result is
// created from `input` (input.CreateFrom) and the backward result from
// `grad_output` (grad_output.CreateFrom).
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index a6f97839c33..e8aac10b819 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -69,6 +69,8 @@
 #include "torch_xla/csrc/ops/max_in_dim.h"
 #include "torch_xla/csrc/ops/max_pool_nd.h"
 #include "torch_xla/csrc/ops/max_pool_nd_backward.h"
+#include "torch_xla/csrc/ops/max_unpool_nd.h"
+#include "torch_xla/csrc/ops/max_unpool_nd_backward.h"
 #include "torch_xla/csrc/ops/mean.h"
 #include "torch_xla/csrc/ops/min_in_dim.h"
 #include "torch_xla/csrc/ops/mse_loss.h"
@@ -1642,6 +1644,22 @@ XLATensor XLATensor::max_pool_nd_backward(const XLATensor& out_backprop,
       ceil_mode));
 }
 
+XLATensor XLATensor::max_unpool(const XLATensor& input,
+                                const XLATensor& indices,
+                                std::vector output_size) {
+  return input.CreateFrom(ir::MakeNode(
+      input.GetIrValue(), indices.GetIrValue(), std::move(output_size)));
+}
+
+XLATensor XLATensor::max_unpool_backward(const XLATensor& grad_output,
+                                         const XLATensor& input,
+                                         const XLATensor& indices,
+                                         std::vector output_size) {
+  return grad_output.CreateFrom(ir::MakeNode(
+      grad_output.GetIrValue(), input.GetIrValue(), indices.GetIrValue(),
+      std::move(output_size)));
+}
+
 XLATensor XLATensor::mean(const XLATensor& input,
                           std::vector dimensions,
                           bool keep_reduced_dimensions,