Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ at::Tensor CreateEmptyTensor(at::IntList size,
return at::empty(size, options.device(at::kCPU));
}

at::Tensor CreateRandTensor(at::IntArrayRef size,
at::Generator* generator,
at::Tensor CreateRandTensor(at::IntArrayRef size, at::Generator* generator,
const at::TensorOptions& options) {
return at::randn(size, generator, options.device(at::DeviceType::CPU));
}
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ at::Tensor CreateEmptyTensor(at::IntList size,
const at::TensorOptions& options);

// Helper function which creates a random CPU ATEN tensor.
at::Tensor CreateRandTensor(at::IntArrayRef size,
at::Generator* generator,
at::Tensor CreateRandTensor(at::IntArrayRef size, at::Generator* generator,
const at::TensorOptions& options);
at::Tensor CreateRandTensor(at::IntArrayRef size,
const at::TensorOptions& options);
Expand Down
41 changes: 41 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,31 @@ at::Tensor AtenXlaType::max_pool2d(const at::Tensor& self,
XlaHelpers::I64List(padding)));
}

std::tuple<at::Tensor, at::Tensor> AtenXlaType::max_pool2d_with_indices(
const at::Tensor& self, at::IntList kernel_size, at::IntList stride,
at::IntList padding, at::IntList dilation, bool ceil_mode) const {
// Lowering when ceil_mode or dilation is set not supported yet.
if (ceil_mode || IsNonTrivialDilation(dilation)) {
return AtenXlaTypeBase::max_pool2d_with_indices(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
// TODO(asuhan): Here we return a placeholder tensor for the indices we hope
// to never evaluate, which works for the backward of max_pool2d. However, the
// user could request the indices to be returned, in which case we'd throw. We
// need to either provide a lowering or improve our infrastructure to be able
// to route to ATen the evaluation of outputs we hope to be unused.
XLATensor result = bridge::GetXlaTensor(self).max_pool2d(
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding));
xla::Shape indices_shape = result.shape();
indices_shape.set_element_type(xla::PrimitiveType::S64);
XLATensor indices_not_supported =
XLATensor::not_supported(at::aten::max_pool2d_with_indices, indices_shape,
bridge::GetXlaTensor(self).GetDevice());
return std::make_tuple(bridge::AtenFromXlaTensor(result),
bridge::AtenFromXlaTensor(indices_not_supported));
}

at::Tensor AtenXlaType::avg_pool2d(const at::Tensor& self,
at::IntList kernel_size, at::IntList stride,
at::IntList padding, bool ceil_mode,
Expand Down Expand Up @@ -208,6 +233,22 @@ at::Tensor AtenXlaType::avg_pool2d_backward(const at::Tensor& grad_output,
XlaHelpers::I64List(padding), count_include_pad));
}

at::Tensor AtenXlaType::max_pool2d_with_indices_backward(
const at::Tensor& grad_output, const at::Tensor& self,
at::IntList kernel_size, at::IntList stride, at::IntList padding,
at::IntList dilation, bool ceil_mode, const at::Tensor& indices) const {
// Lowering when ceil_mode or dilation is set not supported yet.
if (ceil_mode || IsNonTrivialDilation(dilation)) {
return AtenXlaTypeBase::max_pool2d_with_indices_backward(
grad_output, self, kernel_size, stride, padding, dilation, ceil_mode,
indices);
}
return bridge::AtenFromXlaTensor(XLATensor::max_pool2d_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding)));
}

at::Tensor AtenXlaType::_log_softmax_backward_data(
const at::Tensor& grad_output, const at::Tensor& output, int64_t dim,
const at::Tensor& /* self*/) const {
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class AtenXlaType : public AtenXlaTypeBase {
at::IntList stride, at::IntList padding,
at::IntList dilation, bool ceil_mode) const override;

std::tuple<at::Tensor, at::Tensor> max_pool2d_with_indices(
const at::Tensor& self, at::IntList kernel_size, at::IntList stride,
at::IntList padding, at::IntList dilation, bool ceil_mode) const override;

at::Tensor avg_pool2d(const at::Tensor& self, at::IntList kernel_size,
at::IntList stride, at::IntList padding, bool ceil_mode,
bool count_include_pad) const override;
Expand All @@ -74,6 +78,12 @@ class AtenXlaType : public AtenXlaTypeBase {
at::IntList padding, bool ceil_mode,
bool count_include_pad) const override;

at::Tensor max_pool2d_with_indices_backward(
const at::Tensor& grad_output, const at::Tensor& self,
at::IntList kernel_size, at::IntList stride, at::IntList padding,
at::IntList dilation, bool ceil_mode,
const at::Tensor& indices) const override;

at::Tensor _log_softmax_backward_data(const at::Tensor& grad_output,
const at::Tensor& output, int64_t dim,
const at::Tensor& self) const override;
Expand Down
66 changes: 66 additions & 0 deletions torch_xla/csrc/ops/max_pool2d_backward.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "ops/max_pool2d_backward.h"
#include "lowering_context.h"
#include "ops/infer_output_shape.h"
#include "pooling.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(
const Value& grad_output, const Value& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
auto lower_for_shape_fn =
[stride, padding,
kernel_size](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2);
return BuildMaxPool2dBackward(/*out_backprop=*/operands[0],
/*input=*/operands[1], kernel_size, stride,
padding);
};
return InferOutputShape({grad_output.node->shape(), input.node->shape()},
lower_for_shape_fn);
}

} // namespace

MaxPool2dBackward::MaxPool2dBackward(
const Value& grad_output, const Value& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding)
: Node(ir::OpKind(at::aten::max_pool2d_with_indices_backward),
{grad_output, input},
NodeOutputShape(grad_output, input, kernel_size, stride, padding),
/*num_outputs=*/1, xla::util::MHash(kernel_size, stride, padding)),
kernel_size_(kernel_size.begin(), kernel_size.end()),
stride_(stride.begin(), stride.end()),
padding_(padding.begin(), padding.end()) {}

XlaOpVector MaxPool2dBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp input = loctx->GetOutputOp(operand(1));
xla::XlaOp output = BuildMaxPool2dBackward(
/*out_backprop=*/grad_output, /*input=*/input, kernel_size_, stride_,
padding_);
return ReturnOp(output, loctx);
}

std::string MaxPool2dBackward::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", kernel_size=["
<< absl::StrJoin(kernel_size_, ", ") << "], stride=["
<< absl::StrJoin(stride_, ", ") << "], padding=["
<< absl::StrJoin(padding_, ", ") << "]";
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
35 changes: 35 additions & 0 deletions torch_xla/csrc/ops/max_pool2d_backward.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class MaxPool2dBackward : public Node {
public:
MaxPool2dBackward(const Value& grad_output, const Value& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding);

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

const std::vector<xla::int64>& kernel_size() const { return kernel_size_; }

const std::vector<xla::int64>& stride() const { return stride_; }

const std::vector<xla::int64>& padding() const { return padding_; }

private:
// The parameters of the pooling. Ceil mode not supported yet.
std::vector<xla::int64> kernel_size_;
std::vector<xla::int64> stride_;
std::vector<xla::int64> padding_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
9 changes: 9 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ NodePtr NllLossBackwardOp(const Value& logits, const Value& labels) {
std::move(lower_fn));
}

NodePtr NotSupportedOp(c10::Symbol node_symbol, xla::Shape shape) {
auto lower_fn = [](const ir::Node& node,
ir::LoweringContext* loctx) -> ir::XlaOpVector {
XLA_ERROR() << "Node not supported: " << node.ToString();
};
return ir::ops::GenericOp(ir::OpKind(node_symbol), {}, std::move(shape),
std::move(lower_fn));
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ NodePtr NllLossOp(const Value& logits, const Value& labels);

NodePtr NllLossBackwardOp(const Value& logits, const Value& labels);

// Placeholder node which is never to be used. Using it would throw an error
// during lowering.
NodePtr NotSupportedOp(c10::Symbol node_symbol, xla::Shape shape);

} // namespace ops
} // namespace ir
} // namespace torch_xla
32 changes: 18 additions & 14 deletions torch_xla/csrc/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,6 @@ PoolingOpAttributes Pooling2DOpAttributes(
return {kernel_size, stride, padding};
}

// Extract the pooling attributes for the given 2D pooling operator "node".
PoolingOpAttributes Pooling2DOpAttributes(const torch::jit::Node* pooling_2d) {
const auto kernel_size_attr = XlaHelpers::I64List(
pooling_2d->get<std::vector<int64_t>>(at::attr::kernel_size).value());
const auto stride_attr = XlaHelpers::I64List(
pooling_2d->get<std::vector<int64_t>>(at::attr::stride).value());
const auto padding_attr = XlaHelpers::I64List(
pooling_2d->get<std::vector<int64_t>>(at::attr::padding).value());
return Pooling2DOpAttributes(/*kernel_size_attr=*/kernel_size_attr,
/*stride_attr=*/stride_attr,
/*padding_attr=*/padding_attr);
}

void CheckAvgPool2DIsSupported(const torch::jit::Node* node) {
const auto node_inputs = node->inputs();
XLA_CHECK_GE(node_inputs.size(), size_t(6));
Expand Down Expand Up @@ -136,14 +123,31 @@ xla::XlaOp BuildMaxPool2d(
xla::XlaOp BuildMaxPool2dBackward(const torch::jit::Node* node,
const xla::XlaOp& out_backprop,
const xla::XlaOp& input) {
const auto kernel_size =
node->get<std::vector<int64_t>>(at::attr::kernel_size).value();
const auto stride = node->get<std::vector<int64_t>>(at::attr::stride).value();
const auto padding =
node->get<std::vector<int64_t>>(at::attr::padding).value();
return BuildMaxPool2dBackward(
out_backprop, input, XlaHelpers::I64List(kernel_size),
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding));
}

xla::XlaOp BuildMaxPool2dBackward(
const xla::XlaOp& out_backprop, const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
xla::XlaBuilder* builder = out_backprop.builder();
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp init_value =
XlaHelpers::ScalarValue<float>(0, input_shape.element_type(), builder);
xla::XlaComputation select = CreateGeComputation(input_shape.element_type());
xla::XlaComputation scatter =
XlaHelpers::CreateAddComputation(input_shape.element_type());
PoolingOpAttributes pooling_op_attributes = Pooling2DOpAttributes(node);
PoolingOpAttributes pooling_op_attributes =
Pooling2DOpAttributes(/*kernel_size_attr=*/kernel_size,
/*stride_attr=*/stride, /*padding_attr=*/padding);
std::vector<std::pair<xla::int64, xla::int64>> window_padding;
window_padding.resize(2);
window_padding.insert(window_padding.end(),
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ xla::XlaOp BuildMaxPool2dBackward(const torch::jit::Node* node,
const xla::XlaOp& out_backprop,
const xla::XlaOp& input);

// Same as above, with kernel size, stride and padding provided as parameters.
xla::XlaOp BuildMaxPool2dBackward(
const xla::XlaOp& out_backprop, const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding);

// Computes average pooling for the given input with the attributes specified in
// the given node.
xla::XlaOp BuildAvgPool2d(const torch::jit::Node* node,
Expand Down
25 changes: 21 additions & 4 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ops/generic.h"
#include "ops/infer_output_shape.h"
#include "ops/max_pool2d.h"
#include "ops/max_pool2d_backward.h"
#include "ops/ops.h"
#include "ops/scalar.h"
#include "ops/softmax.h"
Expand Down Expand Up @@ -561,7 +562,7 @@ XLATensor XLATensor::conv2d(
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
bool use_full_conv_precision) const {
ir::NodePtr ir_node = std::make_shared<ir::ops::Conv2d>(
ir::NodePtr ir_node = ir::MakeNode<ir::ops::Conv2d>(
GetIrNode(), weight.GetIrNode(), bias.GetIrNode(), stride, padding,
use_full_conv_precision);
return Create(ir_node, GetDevice());
Expand All @@ -573,8 +574,8 @@ XLATensor XLATensor::conv2d(
tensorflow::gtl::ArraySlice<const xla::int64> padding,
bool use_full_conv_precision) const {
ir::NodePtr ir_node =
std::make_shared<ir::ops::Conv2d>(GetIrNode(), weight.GetIrNode(), stride,
padding, use_full_conv_precision);
ir::MakeNode<ir::ops::Conv2d>(GetIrNode(), weight.GetIrNode(), stride,
padding, use_full_conv_precision);
return Create(ir_node, GetDevice());
}

Expand Down Expand Up @@ -624,13 +625,24 @@ XLATensor XLATensor::avg_pool2d_backward(
out_backprop.GetDevice());
}

XLATensor XLATensor::max_pool2d_backward(
const XLATensor& out_backprop, const XLATensor& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
return Create(ir::MakeNode<ir::ops::MaxPool2dBackward>(
out_backprop.GetIrNode(), input.GetIrNode(), kernel_size,
stride, padding),
out_backprop.GetDevice());
}

std::tuple<XLATensor, XLATensor, XLATensor> XLATensor::conv2d_backward(
const XLATensor& out_backprop, const XLATensor& input,
const XLATensor& weight,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
bool use_full_conv_precision) {
const auto node = std::make_shared<ir::ops::Conv2dBackward>(
ir::NodePtr node = ir::MakeNode<ir::ops::Conv2dBackward>(
out_backprop.GetIrNode(), input.GetIrNode(), weight.GetIrNode(), stride,
padding, use_full_conv_precision);
XLATensor grad_input = Create(ir::Value(node, 0), out_backprop.GetDevice());
Expand Down Expand Up @@ -704,6 +716,11 @@ XLATensor XLATensor::nll_loss_backward(const XLATensor& input,
input.GetDevice());
}

XLATensor XLATensor::not_supported(c10::Symbol node_symbol, xla::Shape shape,
const Device& device) {
return Create(ir::ops::NotSupportedOp(node_symbol, shape), device);
}

XLATensor XLATensor::cross_replica_sum(
const std::vector<std::vector<xla::int64>>& groups) const {
ir::NodePtr crs = ir::ops::CrossReplicaSumOp(GetIrNode(), groups);
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ class XLATensor {
tensorflow::gtl::ArraySlice<const xla::int64> padding,
bool count_include_pad);

static XLATensor max_pool2d_backward(
const XLATensor& out_backprop, const XLATensor& input,
tensorflow::gtl::ArraySlice<const xla::int64> kernel_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding);

static std::tuple<XLATensor, XLATensor, XLATensor> conv2d_backward(
const XLATensor& out_backprop, const XLATensor& input,
const XLATensor& weight,
Expand All @@ -178,6 +184,9 @@ class XLATensor {
static XLATensor nll_loss_backward(const XLATensor& input,
const XLATensor& target);

static XLATensor not_supported(c10::Symbol node_symbol, xla::Shape shape,
const Device& device);

XLATensor cross_replica_sum(
const std::vector<std::vector<xla::int64>>& groups) const;

Expand Down