35 changes: 35 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -2590,6 +2590,41 @@ TEST_F(AtenXlaTensorTest, TestBilinear) {
});
}

TEST_F(AtenXlaTensorTest, TestUpsampleNearest2D) {
int batch_size = 2;
int h = 5;
int w = 5;
int uh = 8;
int uw = 8;
int chans = 2;
torch::Tensor input = torch::rand({batch_size, chans, h, w},
torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor result = torch::upsample_nearest2d(input, {uh, uw});
torch::Tensor xla_result = torch::upsample_nearest2d(xla_input, {uh, uw});
AllClose(result, xla_result);
});
}

TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DBackward) {
int batch_size = 2;
int h = 5;
int w = 5;
int uh = 8;
int uw = 8;
int chans = 2;
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::upsample_nearest2d(inputs[0], {uh, uw});
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({batch_size, chans, h, w},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});
}

TEST_F(AtenXlaTensorTest, TestAddCMul) {
torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
23 changes: 23 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -3037,6 +3037,29 @@ at::Tensor& AtenXlaType::unsqueeze_(at::Tensor& self, int64_t dim) {
return self;
}

at::Tensor AtenXlaType::upsample_nearest2d(const at::Tensor& self,
at::IntArrayRef output_size) {
XLATensor self_tensor = bridge::GetXlaTensor(self);
if (self_tensor.GetDevice().hw_type != DeviceType::TPU) {
return AtenXlaTypeDefault::upsample_nearest2d(self, output_size);
}
return bridge::AtenFromXlaTensor(XLATensor::upsample_nearest2d(
self_tensor, xla::util::ToVector<xla::int64>(output_size)));
}

at::Tensor AtenXlaType::upsample_nearest2d_backward(
const at::Tensor& grad_output, at::IntArrayRef output_size,
at::IntArrayRef input_size) {
XLATensor grad_output_tensor = bridge::GetXlaTensor(grad_output);
if (grad_output_tensor.GetDevice().hw_type != DeviceType::TPU) {
return AtenXlaTypeDefault::upsample_nearest2d_backward(
grad_output, output_size, input_size);
}
return bridge::AtenFromXlaTensor(XLATensor::upsample_nearest2d_backward(
grad_output_tensor, xla::util::ToVector<xla::int64>(output_size),
xla::util::ToVector<xla::int64>(input_size)));
}

at::Tensor AtenXlaType::view(const at::Tensor& self, at::IntArrayRef size) {
return bridge::AtenFromXlaTensor(
XLATensor::view(bridge::GetXlaTensor(self), XlaHelpers::I64List(size)));
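
Note that only TPU devices take the new lowering: both entry points check GetDevice().hw_type first and otherwise fall back to AtenXlaTypeDefault, i.e. the reference ATen kernel, presumably because the ResizeNearest / ResizeNearestGrad custom calls used in the lowerings below are only available there.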
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -1124,6 +1124,13 @@ class AtenXlaType {

static at::Tensor& unsqueeze_(at::Tensor& self, int64_t dim);

static at::Tensor upsample_nearest2d(const at::Tensor& self,
at::IntArrayRef output_size);

static at::Tensor upsample_nearest2d_backward(const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size);

static at::Tensor view(const at::Tensor& self, at::IntArrayRef size);

static at::Tensor view_as(const at::Tensor& self, const at::Tensor& other);
86 changes: 86 additions & 0 deletions torch_xla/csrc/ops/upsample_nearest2d.cpp
@@ -0,0 +1,86 @@
#include "torch_xla/csrc/ops/upsample_nearest2d.h"

#include <sstream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(
const Value& input,
tensorflow::gtl::ArraySlice<const xla::int64> output_size) {
XLA_CHECK_EQ(output_size.size(), 2);
const xla::Shape& input_shape = input.shape();
return xla::ShapeUtil::MakeShape(
input_shape.element_type(),
{input_shape.dimensions(0), input_shape.dimensions(1), output_size[0],
output_size[1]});
}

std::string GetBackendConfig(bool align_corners, bool half_pixel_centers) {
return absl::StrCat("\"", align_corners, half_pixel_centers, "\"");
}
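
// Illustration (not part of the lowering): absl::StrCat formats bool values
// as "0"/"1", so the default call
//   GetBackendConfig(/*align_corners=*/false, /*half_pixel_centers=*/false)
// produces the four-character payload \"00\" (the two flag digits wrapped in
// literal double-quote characters), which is passed verbatim as the opaque
// backend_config of the ResizeNearest custom call below.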

xla::XlaOp LowerUpsampleNearest(const xla::XlaOp& input,
const xla::Shape& output_shape) {
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
if (input_shape.dimensions(2) == output_shape.dimensions(2) &&
input_shape.dimensions(3) == output_shape.dimensions(3)) {
return input;
}
if (input_shape.dimensions(2) == 1 && input_shape.dimensions(3) == 1) {
return input + xla::Zeros(input.builder(), output_shape);
}
// XLA wants NHWC while PyTorch comes in as NCHW, so we need to transpose,
// call the kernel, and transpose back.
std::vector<xla::int64> transpose_permute({0, 3, 2, 1});
auto inv_transpose_permute = xla::InversePermutation(transpose_permute);
xla::Shape resized_shape =
xla::ShapeUtil::PermuteDimensions(inv_transpose_permute, output_shape);
xla::XlaOp tinput = xla::Transpose(input, transpose_permute);
  xla::XlaOp resized = xla::CustomCall(
      input.builder(), "ResizeNearest", {tinput}, resized_shape,
      GetBackendConfig(/*align_corners=*/false, /*half_pixel_centers=*/false));
  return xla::Transpose(resized, inv_transpose_permute);
}

} // namespace

UpsampleNearest::UpsampleNearest(const Value& input,
std::vector<xla::int64> output_size)
: Node(ir::OpKind(at::aten::upsample_nearest2d), {input},
[&]() { return NodeOutputShape(input, output_size); },
/*num_outputs=*/1, xla::util::MHash(output_size)),
output_size_(std::move(output_size)) {}

NodePtr UpsampleNearest::Clone(OpList operands) const {
return MakeNode<UpsampleNearest>(operands.at(0), output_size_);
}

XlaOpVector UpsampleNearest::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = LowerUpsampleNearest(input, shape());
return ReturnOp(output, loctx);
}

std::string UpsampleNearest::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", output_size=("
<< absl::StrJoin(output_size_, ", ") << ")";
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
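
A side note on the transpose bookkeeping in LowerUpsampleNearest above: {0, 3, 2, 1} simply swaps dimensions 1 and 3 and is therefore its own inverse, so the same permutation both moves the channel dimension to the minor-most position before the ResizeNearest custom call and restores NCHW afterwards. A minimal, dependency-free sketch of that arithmetic (standalone illustration, not part of this change):

#include <array>
#include <cassert>

int main() {
  // NCHW input shape and the permutation used by the lowering.
  const std::array<long, 4> nchw = {2, 3, 5, 7};  // [N, C, H, W]
  const std::array<int, 4> perm = {0, 3, 2, 1};

  // xla::Transpose semantics: output dimension i is input dimension perm[i].
  std::array<long, 4> transposed{};
  for (int i = 0; i < 4; ++i) transposed[i] = nchw[perm[i]];
  assert((transposed == std::array<long, 4>{2, 7, 5, 3}));  // channels last

  // The inverse permutation maps each value back to its index; for
  // {0, 3, 2, 1} that is again {0, 3, 2, 1}.
  std::array<int, 4> inverse{};
  for (int i = 0; i < 4; ++i) inverse[perm[i]] = i;
  assert(inverse == perm);
  return 0;
}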
29 changes: 29 additions & 0 deletions torch_xla/csrc/ops/upsample_nearest2d.h
@@ -0,0 +1,29 @@
#pragma once

#include <vector>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class UpsampleNearest : public Node {
public:
UpsampleNearest(const Value& input, std::vector<xla::int64> output_size);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::vector<xla::int64>& output_size() const { return output_size_; }

private:
std::vector<xla::int64> output_size_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
99 changes: 99 additions & 0 deletions torch_xla/csrc/ops/upsample_nearest2d_backward.cpp
@@ -0,0 +1,99 @@
#include "torch_xla/csrc/ops/upsample_nearest2d_backward.h"

#include <sstream>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(
const Value& input,
tensorflow::gtl::ArraySlice<const xla::int64> input_size) {
return xla::ShapeUtil::MakeShape(input.shape().element_type(), input_size);
}

std::string GetBackendConfig(bool align_corners, bool half_pixel_centers) {
return absl::StrCat("\"", align_corners, half_pixel_centers, "\"");
}

double ResizeFactor(const xla::Shape& input_shape,
const xla::Shape& output_shape, int dim) {
return static_cast<double>(input_shape.dimensions(dim)) /
static_cast<double>(output_shape.dimensions(dim));
}

xla::XlaOp LowerUpsampleNearestBackward(const xla::XlaOp& input,
const xla::Shape& output_shape) {
  static double resize_split_factor =
      xla::sys_util::GetEnvDouble("XLA_RESIZE_SPLIT_FACTOR", 3.0);
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
if (input_shape.dimensions(2) == output_shape.dimensions(2) &&
input_shape.dimensions(3) == output_shape.dimensions(3)) {
return input;
}
// XLA wants NHWC while PyTorch comes in as NCHW, so we need to transpose,
// call the kernel, and transpose back.
std::vector<xla::int64> transpose_permute({0, 3, 2, 1});
auto inv_transpose_permute = xla::InversePermutation(transpose_permute);
xla::Shape resized_shape =
xla::ShapeUtil::PermuteDimensions(inv_transpose_permute, output_shape);
xla::XlaOp tinput = xla::Transpose(input, transpose_permute);
std::string backend_config =
GetBackendConfig(/*align_corners=*/false, /*half_pixel_centers=*/false);
  if (ResizeFactor(input_shape, output_shape, 2) > resize_split_factor &&
      ResizeFactor(input_shape, output_shape, 3) > resize_split_factor) {
// If the resize is too large, do one dimension at a time.
xla::Shape partial_shape = resized_shape;
// Partial shape is in NHWC, while input shape is in NCHW.
partial_shape.mutable_dimensions()[1] = input_shape.dimensions(2);
tinput = xla::CustomCall(input.builder(), "ResizeNearestGrad", {tinput},
partial_shape, backend_config);
}
  xla::XlaOp resized = xla::CustomCall(input.builder(), "ResizeNearestGrad",
                                       {tinput}, resized_shape, backend_config);
  return xla::Transpose(resized, inv_transpose_permute);
}

} // namespace

UpsampleNearestBackward::UpsampleNearestBackward(
const Value& input, std::vector<xla::int64> output_size,
std::vector<xla::int64> input_size)
: Node(ir::OpKind(at::aten::upsample_nearest2d_backward), {input},
[&]() { return NodeOutputShape(input, input_size); },
/*num_outputs=*/1, xla::util::MHash(output_size, input_size)),
output_size_(std::move(output_size)),
input_size_(std::move(input_size)) {}

NodePtr UpsampleNearestBackward::Clone(OpList operands) const {
return MakeNode<UpsampleNearestBackward>(operands.at(0), output_size_,
input_size_);
}

XlaOpVector UpsampleNearestBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = LowerUpsampleNearestBackward(input, shape());
return ReturnOp(output, loctx);
}

std::string UpsampleNearestBackward::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", output_size=("
<< absl::StrJoin(output_size_, ", ") << "), input_size=("
<< absl::StrJoin(input_size_, ", ") << ")";
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
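
To make the split heuristic in LowerUpsampleNearestBackward concrete: the incoming gradient has the upsampled spatial size and the produced gradient has the original input size, so ResizeFactor measures how much each spatial dimension shrinks. A small standalone sketch of that decision, using the 8x8 -> 5x5 shapes from the tests above plus a hypothetical much larger resize (the 3.0 constant mirrors the XLA_RESIZE_SPLIT_FACTOR default; helper names are illustrative only):

#include <cstdio>

// Mirrors ResizeFactor above: grad_output spatial extent over input spatial
// extent for one dimension.
double Factor(long grad_dim, long input_dim) {
  return static_cast<double>(grad_dim) / static_cast<double>(input_dim);
}

int main() {
  const double kSplitFactor = 3.0;  // XLA_RESIZE_SPLIT_FACTOR default.

  // Shapes from TestUpsampleNearest2DBackward: an 8x8 gradient reduced to 5x5.
  bool split_small = Factor(8, 5) > kSplitFactor && Factor(8, 5) > kSplitFactor;
  std::printf("8x8 -> 5x5: factor %.2f, split=%d\n", Factor(8, 5), split_small);
  // Prints factor 1.60, split=0: a single ResizeNearestGrad custom call.

  // A hypothetical 64x64 gradient reduced to 8x8.
  bool split_large = Factor(64, 8) > kSplitFactor && Factor(64, 8) > kSplitFactor;
  std::printf("64x64 -> 8x8: factor %.2f, split=%d\n", Factor(64, 8), split_large);
  // Prints factor 8.00, split=1: the reduction runs as two custom calls,
  // one spatial dimension at a time.
  return 0;
}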
34 changes: 34 additions & 0 deletions torch_xla/csrc/ops/upsample_nearest2d_backward.h
@@ -0,0 +1,34 @@
#pragma once

#include <vector>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class UpsampleNearestBackward : public Node {
public:
UpsampleNearestBackward(const Value& input,
std::vector<xla::int64> output_size,
std::vector<xla::int64> input_size);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::vector<xla::int64>& output_size() const { return output_size_; }

const std::vector<xla::int64>& input_size() const { return input_size_; }

private:
std::vector<xla::int64> output_size_;
std::vector<xla::int64> input_size_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
7 changes: 7 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -926,6 +926,13 @@ class XLATensor {
// In-place version of the method above.
static void unsqueeze_(XLATensor& input, xla::int64 dim);

static XLATensor upsample_nearest2d(const XLATensor& input,
std::vector<xla::int64> output_size);

static XLATensor upsample_nearest2d_backward(
const XLATensor& grad_output, std::vector<xla::int64> output_size,
std::vector<xla::int64> input_size);

// Like reshape, but it returns a view into the original tensor.
static XLATensor view(
const XLATensor& input,
15 changes: 15 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -98,6 +98,8 @@
#include "torch_xla/csrc/ops/tril.h"
#include "torch_xla/csrc/ops/triu.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
#include "torch_xla/csrc/ops/upsample_nearest2d.h"
#include "torch_xla/csrc/ops/upsample_nearest2d_backward.h"
#include "torch_xla/csrc/ops/view.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_ops.h"
@@ -2325,6 +2327,19 @@ void XLATensor::unsqueeze_(XLATensor& input, xla::int64 dim) {
ir::MakeNode<ir::ops::Unsqueeze>(input.GetIrValue(), squeeze_dim));
}

XLATensor XLATensor::upsample_nearest2d(const XLATensor& input,
std::vector<xla::int64> output_size) {
return input.CreateFrom(ir::MakeNode<ir::ops::UpsampleNearest>(
input.GetIrValue(), std::move(output_size)));
}

XLATensor XLATensor::upsample_nearest2d_backward(
const XLATensor& grad_output, std::vector<xla::int64> output_size,
std::vector<xla::int64> input_size) {
return grad_output.CreateFrom(ir::MakeNode<ir::ops::UpsampleNearestBackward>(
grad_output.GetIrValue(), std::move(output_size), std::move(input_size)));
}

XLATensor XLATensor::view(
const XLATensor& input,
tensorflow::gtl::ArraySlice<const xla::int64> output_size) {
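
As with the neighbouring methods in tensor_methods.cpp, these two entry points only record UpsampleNearest / UpsampleNearestBackward IR nodes on the tensor; the ResizeNearest / ResizeNearestGrad custom-call lowerings above run later, when the traced graph is actually compiled and executed.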