Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6662,6 +6662,87 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatch) {
}
}

TEST_F(AtenXlaTensorTest, TestMaxUnpool2D) {
int kernel_size = 2;
torch::Tensor input =
torch::rand({2, 2, 8, 8}, torch::TensorOptions(torch::kFloat));
for (int stride = 1; stride <= 2; ++stride) {
for (int padding = 0; padding <= 1; ++padding) {
// Test ceil_mode=true through the CPU interop.
for (bool ceil_mode : {false, true}) {
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
torch::Tensor output;
torch::Tensor indices;
std::tie(output, indices) = torch::max_pool2d_with_indices(
input, /*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation},
/*ceil_mode=*/ceil_mode);

std::vector<int64_t> output_size({input.size(2), input.size(3)});
at::Tensor utensor =
torch::max_unpool2d(output, indices, output_size);

ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_output = CopyToDevice(output, device);
torch::Tensor xla_indices = CopyToDevice(indices, device);
at::Tensor xla_utensor =
torch::max_unpool2d(xla_output, xla_indices, output_size);
AllClose(utensor, xla_utensor);
});
}
}
}
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::max_unpool2d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestMaxUnpool3D) {
int kernel_size = 2;
torch::Tensor input =
torch::rand({2, 2, 8, 8, 8}, torch::TensorOptions(torch::kFloat));
for (int stride = 1; stride <= 2; ++stride) {
for (int padding = 0; padding <= 1; ++padding) {
// Test ceil_mode=true through the CPU interop.
for (bool ceil_mode : {false, true}) {
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
torch::Tensor output;
torch::Tensor indices;
std::tie(output, indices) = torch::max_pool3d_with_indices(
input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
/*stride=*/{stride, stride, stride},
/*padding=*/{padding, padding, padding},
/*dilation=*/{dilation, dilation, dilation},
/*ceil_mode=*/ceil_mode);

std::vector<int64_t> output_size(
{input.size(2), input.size(3), input.size(4)});
at::Tensor utensor = torch::max_unpool3d(
output, indices, output_size, /*stride=*/{stride, stride, stride},
/*padding=*/{padding, padding, padding});

ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_output = CopyToDevice(output, device);
torch::Tensor xla_indices = CopyToDevice(indices, device);
at::Tensor xla_utensor =
torch::max_unpool3d(xla_output, xla_indices, output_size,
/*stride=*/{stride, stride, stride},
/*padding=*/{padding, padding, padding});
AllClose(utensor, xla_utensor);
});
}
}
}
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::max_unpool3d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestNllLoss) {
int batch = 6;
int classes = 2;
Expand Down Expand Up @@ -8727,6 +8808,84 @@ TEST_F(AtenXlaTensorTest, TestMaxPool3DNoBatchBackward) {
}
}

TEST_F(AtenXlaTensorTest, TestMaxUnpool2DBackward) {
int kernel_size = 2;
torch::Tensor input =
torch::rand({2, 2, 8, 8}, torch::TensorOptions(torch::kFloat));
for (int stride = 1; stride <= 2; ++stride) {
for (int padding = 0; padding <= 1; ++padding) {
// Test ceil_mode=true through the CPU interop.
for (bool ceil_mode : {false, true}) {
for (int dilation = 1; dilation <= 2; ++dilation) {
torch::Tensor output;
torch::Tensor indices;
std::tie(output, indices) = torch::max_pool2d_with_indices(
input, /*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation},
/*ceil_mode=*/ceil_mode);

std::vector<int64_t> output_size({input.size(2), input.size(3)});
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::max_unpool2d(inputs[0], inputs[1], output_size);
};

ForEachDevice([&](const torch::Device& device) {
TestBackward({output.requires_grad_(true), indices}, device,
testfn);
});
}
}
}
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::max_unpool2d_backward",
cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestMaxUnpool3DBackward) {
int kernel_size = 2;
torch::Tensor input =
torch::rand({2, 2, 8, 8, 8}, torch::TensorOptions(torch::kFloat));
for (int stride = 1; stride <= 2; ++stride) {
for (int padding = 0; padding <= 1; ++padding) {
// Test ceil_mode=true through the CPU interop.
for (bool ceil_mode : {false, true}) {
for (int dilation = 1; dilation <= 2; ++dilation) {
torch::Tensor output;
torch::Tensor indices;
std::tie(output, indices) = torch::max_pool3d_with_indices(
input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
/*stride=*/{stride, stride, stride},
/*padding=*/{padding, padding, padding},
/*dilation=*/{dilation, dilation, dilation},
/*ceil_mode=*/ceil_mode);

std::vector<int64_t> output_size(
{input.size(2), input.size(3), input.size(4)});
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::max_unpool3d(inputs[0], inputs[1], output_size,
/*stride=*/{stride, stride, stride},
/*padding=*/{padding, padding, padding});
};

ForEachDevice([&](const torch::Device& device) {
TestBackward({output.requires_grad_(true), indices}, device,
testfn);
});
}
}
}
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::max_unpool3d_backward",
cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestTanhBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::tanh(inputs[0]);
Expand Down
44 changes: 44 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1906,6 +1906,50 @@ std::tuple<at::Tensor, at::Tensor> AtenXlaType::max_pool3d_with_indices(
bridge::AtenFromXlaTensor(std::get<1>(outputs)));
}

at::Tensor AtenXlaType::max_unpool2d(const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::max_unpool(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices),
xla::util::ToVector<xla::int64>(output_size)));
}

at::Tensor AtenXlaType::max_unpool2d_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::max_unpool_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
bridge::GetXlaTensor(indices),
xla::util::ToVector<xla::int64>(output_size)));
}

at::Tensor AtenXlaType::max_unpool3d(const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size,
at::IntArrayRef stride,
at::IntArrayRef padding) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::max_unpool(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices),
xla::util::ToVector<xla::int64>(output_size)));
}

at::Tensor AtenXlaType::max_unpool3d_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size,
at::IntArrayRef stride,
at::IntArrayRef padding) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::max_unpool_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
bridge::GetXlaTensor(indices),
xla::util::ToVector<xla::int64>(output_size)));
}

at::Tensor AtenXlaType::mean(const at::Tensor& self,
c10::optional<at::ScalarType> dtype) {
XLA_FN_COUNTER("xla::");
Expand Down
22 changes: 22 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,28 @@ class AtenXlaType {
at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
const at::Tensor& indices);

static at::Tensor max_unpool2d(const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size);

static at::Tensor max_unpool2d_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size);

static at::Tensor max_unpool3d(const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size,
at::IntArrayRef stride,
at::IntArrayRef padding);

static at::Tensor max_unpool3d_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& indices,
at::IntArrayRef output_size,
at::IntArrayRef stride,
at::IntArrayRef padding);

static at::Tensor mean(const at::Tensor& self,
c10::optional<at::ScalarType> dtype);

Expand Down
26 changes: 26 additions & 0 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,32 @@ xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* input_shape) {
return DynamicReshape(input, {input_elements});
}

xla::XlaOp XlaHelpers::FlattenDimRange(xla::XlaOp input, xla::int64 start,
xla::int64 range,
xla::Shape* input_shape) {
xla::util::MaybePtr<xla::Shape> input_shape_tmp(input_shape);
*input_shape_tmp = ShapeOfXlaOp(input);

std::vector<xla::int64> sizes;
xla::int64 flat_size = -1;
for (xla::int64 dim = 0; dim < input_shape_tmp->rank(); ++dim) {
if (dim < start || dim >= start + range) {
if (flat_size >= 0) {
sizes.push_back(flat_size);
flat_size = -1;
}
sizes.push_back(input_shape_tmp->dimensions(dim));
} else {
flat_size =
(flat_size < 0 ? 1 : flat_size) * input_shape_tmp->dimensions(dim);
}
}
if (flat_size >= 0) {
sizes.push_back(flat_size);
}
return DynamicReshape(input, sizes);
}

std::vector<xla::int64> XlaHelpers::MakeTransposePermutation(xla::int64 dim0,
xla::int64 dim1,
xla::int64 rank) {
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ class XlaHelpers {
static xla::XlaOp Flatten(xla::XlaOp input,
xla::Shape* input_shape = nullptr);

static xla::XlaOp FlattenDimRange(xla::XlaOp input, xla::int64 start,
xla::int64 range,
xla::Shape* input_shape = nullptr);

// Gathers the input using the order specified by the permutation. For each i,
// output[i] = input[permutation[i]]. The given permutation must be the same
// size as the input.
Expand Down
65 changes: 65 additions & 0 deletions torch_xla/csrc/ops/max_unpool_nd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "torch_xla/csrc/ops/max_unpool_nd.h"

#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/pooling.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& input, const Value& indices,
absl::Span<const xla::int64> output_size) {
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildMaxUnpoolNd(GetCurrentDevice(), operands[0], operands[1],
output_size);
};
return InferOutputShape({input.shape(), indices.shape()}, shape_fn);
}

c10::Symbol MaxUnpoolNdSymbol(xla::int64 spatial_dim_count) {
switch (spatial_dim_count) {
case 2:
return at::aten::max_unpool2d;
case 3:
return at::aten::max_unpool3d;
default:
XLA_ERROR() << "Invalid number of spatial dimensions: "
<< spatial_dim_count;
}
}

} // namespace

MaxUnpoolNd::MaxUnpoolNd(const Value& input, const Value& indices,
std::vector<xla::int64> output_size)
: Node(ir::OpKind(MaxUnpoolNdSymbol(output_size.size())), {input, indices},
[&]() { return NodeOutputShape(input, indices, output_size); },
/*num_outputs=*/1, xla::util::MHash(output_size)),
output_size_(std::move(output_size)) {}

NodePtr MaxUnpoolNd::Clone(OpList operands) const {
return MakeNode<MaxUnpoolNd>(operands.at(0), operands.at(1), output_size_);
}

XlaOpVector MaxUnpoolNd::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp indices = loctx->GetOutputOp(operand(1));
xla::XlaOp output =
BuildMaxUnpoolNd(loctx->device(), input, indices, output_size_);
return ReturnOp(output, loctx);
}

std::string MaxUnpoolNd::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", output_size=("
<< absl::StrJoin(output_size_, ", ") << ")";
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
28 changes: 28 additions & 0 deletions torch_xla/csrc/ops/max_unpool_nd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class MaxUnpoolNd : public Node {
public:
MaxUnpoolNd(const Value& input, const Value& indices,
std::vector<xla::int64> output_size);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::vector<xla::int64>& output_size() const { return output_size_; }

private:
std::vector<xla::int64> output_size_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
Loading