45 changes: 45 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -7735,6 +7735,51 @@ TEST_F(AtenXlaTensorTest, TestConstantPadIncomplete) {
});
}

TEST_F(AtenXlaTensorTest, TestReflectionPad2dRank3) {
torch::Tensor input =
torch::rand({2, 3, 4}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> pad{2, 2, 2, 2};
torch::Tensor output = torch::reflection_pad2d(input, pad);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::reflection_pad2d(xla_input, pad);
AllClose(output, xla_output);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::reflection_pad2d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad2dRank4) {
torch::Tensor input =
torch::rand({2, 2, 3, 4}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> pad{2, 2, 2, 2};
torch::Tensor output = torch::reflection_pad2d(input, pad);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::reflection_pad2d(xla_input, pad);
AllClose(output, xla_output);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::reflection_pad2d", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestReflectionPad2dBackward) {
std::vector<int64_t> pad{2, 3, 1, 2};
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::reflection_pad2d(inputs[0], pad);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({1, 2, 4, 4},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestAsStrided) {
torch::Tensor input =
torch::rand({128, 320}, torch::TensorOptions(torch::kFloat));
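Note on the operator semantics the new tests exercise: reflection_pad2d mirrors interior values across each border without repeating the border element, and the padding vector is ordered {left, right, top, bottom}. A minimal standalone sketch of this (illustrative only, not part of the diff, assuming an ordinary libtorch build):

// Illustrative example, not part of the diff.
#include <torch/torch.h>
#include <iostream>

int main() {
  // 1x1x2x3 input holding [[1, 2, 3], [4, 5, 6]].
  torch::Tensor input =
      torch::arange(1, 7, torch::TensorOptions(torch::kFloat)).reshape({1, 1, 2, 3});
  // Pad one column on each side and one row on each side; the padding vector
  // is {left, right, top, bottom}.
  torch::Tensor output = torch::reflection_pad2d(input, {1, 1, 1, 1});
  // Borders are mirrored without repetition, giving the 4x5 spatial map:
  //   5 4 5 6 5
  //   2 1 2 3 2
  //   5 4 5 6 5
  //   2 1 2 3 2
  std::cout << output << std::endl;
  return 0;
}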
16 changes: 16 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2234,6 +2234,22 @@ at::Tensor& AtenXlaType::reciprocal_(at::Tensor& self) {
return self;
}

at::Tensor AtenXlaType::reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::reflection_pad2d(
bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(padding)));
}

at::Tensor AtenXlaType::reflection_pad2d_backward(const at::Tensor& grad_output,
const at::Tensor& self,
at::IntArrayRef padding) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::reflection_pad2d_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
xla::util::ToVector<xla::int64>(padding)));
}

at::Tensor AtenXlaType::relu(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::relu(bridge::GetXlaTensor(self)));
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -687,6 +687,13 @@ class AtenXlaType {

static at::Tensor& reciprocal_(at::Tensor& self);

static at::Tensor reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding);

static at::Tensor reflection_pad2d_backward(const at::Tensor& grad_output,
const at::Tensor& self,
at::IntArrayRef padding);

static at::Tensor relu(const at::Tensor& self);

static at::Tensor& relu_(at::Tensor& self);
99 changes: 99 additions & 0 deletions torch_xla/csrc/data_ops.cpp
@@ -30,6 +30,16 @@ bool IsSparseGather(const xla::Shape& input_shape,
return index_elements < input_elements / dense_gather_factor;
}

std::vector<xla::int64> GetReflectionPad2dSpatialDims(xla::int64 rank) {
  if (rank == 3) {
    return {2, 1};
  } else if (rank == 4) {
    return {3, 2};
  }
  XLA_ERROR() << "Invalid input shape for reflection_pad2d: rank=" << rank;
}

} // namespace

bool IsSparseGather(const xla::XlaOp& input, const xla::XlaOp& index,
@@ -348,4 +358,93 @@ xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source,
return xla::Select(mask, padded_source, target);
}

xla::XlaOp BuildReflectionPad2d(
const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
std::vector<xla::int64> spatial_dims =
GetReflectionPad2dSpatialDims(input_shape.rank());

xla::XlaOp result = input;
for (xla::int64 i = 0; i < spatial_dims.size(); ++i) {
xla::int64 dim = spatial_dims[i];
xla::int64 dim_size = input_shape.dimensions(dim);
xla::int64 lhs_padding = padding[2 * i];
xla::int64 rhs_padding = padding[2 * i + 1];

XLA_CHECK(lhs_padding >= 0 && lhs_padding <= dim_size - 1);
XLA_CHECK(rhs_padding >= 0 && rhs_padding <= dim_size - 1);

xla::XlaOp reverse = xla::Rev(result, {dim});
xla::XlaOp lhs_pad = xla::SliceInDim(reverse, dim_size - 1 - lhs_padding,
dim_size - 1, 1, dim);
xla::XlaOp rhs_pad = xla::SliceInDim(reverse, 1, 1 + rhs_padding, 1, dim);
result = xla::ConcatInDim(input.builder(), {lhs_pad, result, rhs_pad}, dim);
}
return result;
}

xla::XlaOp BuildReflectionPad2dBackward(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
const xla::Shape& grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output);
std::vector<xla::int64> spatial_dims =
GetReflectionPad2dSpatialDims(grad_output_shape.rank());

xla::XlaOp grad = grad_output;
for (xla::int64 i = 0; i < spatial_dims.size(); ++i) {
xla::int64 dim = spatial_dims[i];
xla::int64 dim_size = grad_output_shape.dimensions(dim);
xla::int64 lhs_padding = padding[2 * i];
xla::int64 rhs_padding = padding[2 * i + 1];

XLA_CHECK(lhs_padding >= 0 && lhs_padding <= dim_size - 1);
XLA_CHECK(rhs_padding >= 0 && rhs_padding <= dim_size - 1);

xla::XlaOp lhs_pad = xla::SliceInDim(grad, 0, lhs_padding, 1, dim);
xla::XlaOp reverse_lhs_pad = xla::Rev(lhs_pad, {dim});
xla::XlaOp padded_lhs_pad =
PadInDim(reverse_lhs_pad, dim,
/*pad_lo=*/1,
/*pad_hi=*/input_shape.dimensions(dim) - lhs_padding - 1);

xla::XlaOp rhs_pad =
xla::SliceInDim(grad, dim_size - rhs_padding, dim_size, 1, dim);
xla::XlaOp reverse_rhs_pad = xla::Rev(rhs_pad, {dim});
xla::XlaOp padded_rhs_pad =
PadInDim(reverse_rhs_pad, dim,
/*pad_lo=*/input_shape.dimensions(dim) - rhs_padding - 1,
/*pad_hi=*/1);

xla::XlaOp grad_core =
xla::SliceInDim(grad, lhs_padding, dim_size - rhs_padding, 1, dim);
grad = padded_lhs_pad + grad_core + padded_rhs_pad;
}
return grad;
}

xla::XlaOp PadInDim(const xla::XlaOp& input, xla::int64 dim, xla::int64 pad_lo,
xla::int64 pad_hi, const xla::XlaOp* pad_value) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp zero;
if (pad_value == nullptr) {
zero = xla::Zero(input.builder(), input_shape.element_type());
pad_value = &zero;
}
xla::PaddingConfig padding_config;
for (xla::int64 i = 0; i < input_shape.rank(); ++i) {
auto* dims = padding_config.add_dimensions();
dims->set_interior_padding(0);
if (i == dim) {
dims->set_edge_padding_low(pad_lo);
dims->set_edge_padding_high(pad_hi);
} else {
dims->set_edge_padding_low(0);
dims->set_edge_padding_high(0);
}
}
return xla::Pad(input, *pad_value, padding_config);
}

} // namespace torch_xla
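The lowering above builds each padded dimension out of xla::Rev, xla::SliceInDim and xla::ConcatInDim, and the backward pass routes the flank gradients back through the same reflection. The following self-contained sketch (plain C++, no XLA; illustrative only and not part of the diff) reproduces the same index arithmetic in one dimension, which makes the slice bounds easier to check by hand:

#include <cstdint>
#include <iostream>
#include <vector>

// 1-D analogue of one loop iteration in BuildReflectionPad2d: reverse the
// input, take the lhs_padding elements ending just before the last element of
// the reversed vector as the left flank, take the rhs_padding elements
// starting at index 1 as the right flank, then concatenate
// {left flank, input, right flank}.
std::vector<float> ReflectPad1d(const std::vector<float>& input,
                                int64_t lhs_padding, int64_t rhs_padding) {
  const int64_t dim_size = static_cast<int64_t>(input.size());
  std::vector<float> reverse(input.rbegin(), input.rend());
  std::vector<float> lhs(reverse.begin() + (dim_size - 1 - lhs_padding),
                         reverse.begin() + (dim_size - 1));
  std::vector<float> rhs(reverse.begin() + 1,
                         reverse.begin() + 1 + rhs_padding);
  std::vector<float> result = lhs;
  result.insert(result.end(), input.begin(), input.end());
  result.insert(result.end(), rhs.begin(), rhs.end());
  return result;
}

// 1-D analogue of one loop iteration in BuildReflectionPad2dBackward: the core
// of the incoming gradient flows straight through, and each flank element is
// accumulated onto the interior position it was reflected from.
std::vector<float> ReflectPad1dBackward(const std::vector<float>& grad_output,
                                        int64_t input_size,
                                        int64_t lhs_padding,
                                        int64_t rhs_padding) {
  std::vector<float> grad_input(input_size, 0.0f);
  for (int64_t j = 0; j < input_size; ++j) {
    grad_input[j] += grad_output[lhs_padding + j];  // grad_core
  }
  for (int64_t i = 0; i < lhs_padding; ++i) {
    grad_input[lhs_padding - i] += grad_output[i];  // left flank
  }
  for (int64_t k = 0; k < rhs_padding; ++k) {
    grad_input[input_size - 2 - k] +=
        grad_output[lhs_padding + input_size + k];  // right flank
  }
  return grad_input;
}

int main() {
  // {1, 2, 3, 4} padded by 2 on each side -> {3, 2, 1, 2, 3, 4, 3, 2}.
  for (float v : ReflectPad1d({1, 2, 3, 4}, 2, 2)) std::cout << v << " ";
  std::cout << std::endl;
  // A gradient of all ones gives {1, 3, 3, 1}: the interior values 2 and 3
  // each appear three times in the padded output (core plus both flanks).
  for (float v : ReflectPad1dBackward(std::vector<float>(8, 1.0f), 4, 2, 2)) {
    std::cout << v << " ";
  }
  std::cout << std::endl;
  return 0;
}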
11 changes: 11 additions & 0 deletions torch_xla/csrc/data_ops.h
@@ -101,4 +101,15 @@ xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source,
xla::int64 dim, xla::int64 start, xla::int64 end,
xla::int64 stride);

xla::XlaOp BuildReflectionPad2d(
const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> padding);

xla::XlaOp BuildReflectionPad2dBackward(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> padding);

xla::XlaOp PadInDim(const xla::XlaOp& input, xla::int64 dim, xla::int64 pad_lo,
xla::int64 pad_hi, const xla::XlaOp* pad_value = nullptr);

} // namespace torch_xla
2 changes: 1 addition & 1 deletion torch_xla/csrc/ir.cpp
@@ -60,7 +60,7 @@ std::string GetCurrentScope() {

ShapeCache* GetShapeCache() {
static xla::int64 shape_cache_size =
xla::sys_util::GetEnvInt("XLA_IR_SHAPE_CACHE_SIZE", 1024);
xla::sys_util::GetEnvInt("XLA_IR_SHAPE_CACHE_SIZE", 4096);
static ShapeCache* cache = new ShapeCache(shape_cache_size);
return cache;
}
50 changes: 50 additions & 0 deletions torch_xla/csrc/ops/reflection_pad2d.cpp
@@ -0,0 +1,50 @@
#include "torch_xla/csrc/ops/reflection_pad2d.h"

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(
const Value& input, tensorflow::gtl::ArraySlice<const xla::int64> padding) {
auto lower_for_shape_fn =
[&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp { return BuildReflectionPad2d(operands[0], padding); };
return InferOutputShape({input.shape()}, lower_for_shape_fn);
}

} // namespace

ReflectionPad2d::ReflectionPad2d(const Value& input,
std::vector<xla::int64> padding)
: Node(OpKind(at::aten::reflection_pad2d), {input},
[&]() { return NodeOutputShape(input, padding); },
/*num_outputs=*/1, xla::util::MHash(padding)),
padding_(std::move(padding)) {}

NodePtr ReflectionPad2d::Clone(OpList operands) const {
return MakeNode<ReflectionPad2d>(operands.at(0), padding_);
}

XlaOpVector ReflectionPad2d::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = BuildReflectionPad2d(input, padding_);
return ReturnOp(output, loctx);
}

std::string ReflectionPad2d::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", padding=(" << absl::StrJoin(padding_, ", ")
<< ")";
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
29 changes: 29 additions & 0 deletions torch_xla/csrc/ops/reflection_pad2d.h
@@ -0,0 +1,29 @@
#pragma once

#include <vector>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class ReflectionPad2d : public Node {
public:
ReflectionPad2d(const Value& input, std::vector<xla::int64> padding);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::vector<xla::int64>& padding() const { return padding_; }

private:
std::vector<xla::int64> padding_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
58 changes: 58 additions & 0 deletions torch_xla/csrc/ops/reflection_pad2d_backward.cpp
@@ -0,0 +1,58 @@
#include "torch_xla/csrc/ops/reflection_pad2d_backward.h"

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(
const Value& grad_output, const Value& input,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
auto lower_for_shape_fn =
[&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp {
return BuildReflectionPad2dBackward(operands[0], operands[1], padding);
};
return InferOutputShape({grad_output.shape(), input.shape()},
lower_for_shape_fn);
}

} // namespace

ReflectionPad2dBackward::ReflectionPad2dBackward(
const Value& grad_output, const Value& input,
std::vector<xla::int64> padding)
: Node(OpKind(at::aten::reflection_pad2d_backward), {grad_output, input},
[&]() { return NodeOutputShape(grad_output, input, padding); },
/*num_outputs=*/1, xla::util::MHash(padding)),
padding_(std::move(padding)) {}

NodePtr ReflectionPad2dBackward::Clone(OpList operands) const {
return MakeNode<ReflectionPad2dBackward>(operands.at(0), operands.at(1),
padding_);
}

XlaOpVector ReflectionPad2dBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp input = loctx->GetOutputOp(operand(1));
xla::XlaOp output =
BuildReflectionPad2dBackward(grad_output, input, padding_);
return ReturnOp(output, loctx);
}

std::string ReflectionPad2dBackward::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", padding=(" << absl::StrJoin(padding_, ", ")
<< ")";
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
30 changes: 30 additions & 0 deletions torch_xla/csrc/ops/reflection_pad2d_backward.h
@@ -0,0 +1,30 @@
#pragma once

#include <vector>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class ReflectionPad2dBackward : public Node {
public:
ReflectionPad2dBackward(const Value& grad_output, const Value& input,
std::vector<xla::int64> padding);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::vector<xla::int64>& padding() const { return padding_; }

private:
std::vector<xla::int64> padding_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
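Not shown in this excerpt is the tensor-level wiring that connects the AtenXlaType entry points to these IR nodes. As a rough, hypothetical sketch only (the PR's actual tensor.h/tensor_methods changes are not among the lines shown here, and the CreateFrom/GetIrValue helpers are assumed from the surrounding codebase), the forward and backward hooks would look something like:

// Hypothetical sketch, not taken from this diff: typical XLATensor wiring for
// new IR nodes, assuming the usual CreateFrom/GetIrValue helpers.
XLATensor XLATensor::reflection_pad2d(const XLATensor& input,
                                      std::vector<xla::int64> padding) {
  return input.CreateFrom(ir::MakeNode<ir::ops::ReflectionPad2d>(
      input.GetIrValue(), std::move(padding)));
}

XLATensor XLATensor::reflection_pad2d_backward(
    const XLATensor& grad_output, const XLATensor& input,
    std::vector<xla::int64> padding) {
  return grad_output.CreateFrom(ir::MakeNode<ir::ops::ReflectionPad2dBackward>(
      grad_output.GetIrValue(), input.GetIrValue(), std::move(padding)));
}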