2 changes: 1 addition & 1 deletion .circleci/build.sh
@@ -46,7 +46,7 @@ sudo apt-get -qq install clang-7 clang++-7
# Bazel dependencies
sudo apt-get -qq install pkg-config zip zlib1g-dev unzip
# XLA build requires Bazel
wget https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-linux-x86_64.sh
wget https://github.com/bazelbuild/bazel/releases/download/0.24.1/bazel-0.24.1-installer-linux-x86_64.sh
chmod +x bazel-*.sh
sudo ./bazel-*.sh
BAZEL="$(which bazel)"
2 changes: 1 addition & 1 deletion kokoro/ubuntu/common.sh
@@ -6,7 +6,7 @@ set -e
set -x

function install_bazel() {
local BAZEL_VERSION="0.22.0"
local BAZEL_VERSION="0.24.1"
local BAZEL_FILE="bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh"
sudo apt-get install pkg-config zip zlib1g-dev unzip
curl -L -O "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/${BAZEL_FILE}"
2 changes: 1 addition & 1 deletion third_party/tensorflow
Submodule tensorflow updated 3358 files
291 changes: 71 additions & 220 deletions torch_xla/csrc/convolution.cpp
@@ -1,240 +1,89 @@
#include "torch_xla/csrc/convolution.h"

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "third_party/xla_client/debug_macros.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/translator.h"

namespace torch_xla {
namespace {

// Computes the input gradient for a convolution.
xla::XlaOp BuildThnnConv2dBackwardInput(
const xla::XlaOp& grad_output, const xla::XlaOp& weight,
tensorflow::gtl::ArraySlice<const xla::int64> input_size,
tensorflow::gtl::ArraySlice<const xla::int64> stride_attr,
tensorflow::gtl::ArraySlice<const xla::int64> padding_attr) {
XLA_CHECK_EQ(stride_attr.size(), 2);
// Adjust input size to account for specified padding.
auto padded_input_size = xla::util::ToVector<xla::int64>(input_size);
for (int i = 0; i < 2; ++i) {
padded_input_size[2 + i] += 2 * padding_attr[i];
}
tensorflow::TensorShape input_shape(padded_input_size);
xla::XlaOp filter = xla::Transpose(weight, {2, 3, 1, 0});
xla::XlaBuilder* builder = grad_output.builder();
const auto filter_size = XlaHelpers::SizesOfXlaOp(filter);
tensorflow::TensorShape filter_shape(filter_size);
tensorflow::TensorShape out_backprop_shape(
XlaHelpers::SizesOfXlaOp(grad_output));
std::vector<int> strides{1, 1};
std::copy(stride_attr.begin(), stride_attr.end(),
std::back_inserter(strides));
tensorflow::ConvBackpropDimensions dims;
constexpr int num_spatial_dims = 2;
std::vector<int> dilations{1, 1, 1, 1};
xla::Status status = ConvBackpropComputeDimensionsV2(
"thnn_conv2d_backward", num_spatial_dims, input_shape, filter_shape,
out_backprop_shape, dilations, strides, tensorflow::Padding::VALID,
/*explicit_paddings=*/{}, tensorflow::TensorFormat::FORMAT_NCHW, &dims);
XLA_CHECK_OK(status);

constexpr int batch_dim = 0;
constexpr int feature_dim = 1;

// The input gradients are computed by a convolution of the output
// gradients and the filter, with some appropriate padding. See the
// comment at the top of conv_grad_ops.h for details.

xla::ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(batch_dim);
dnums.set_output_batch_dimension(batch_dim);
dnums.set_input_feature_dimension(feature_dim);
dnums.set_output_feature_dimension(feature_dim);

// TF filter shape is [ H, W, ..., inC, outC ]
// Transpose the input and output features for computing the gradient.
dnums.set_kernel_input_feature_dimension(num_spatial_dims + 1);
dnums.set_kernel_output_feature_dimension(num_spatial_dims);

std::vector<xla::int64> kernel_spatial_dims(num_spatial_dims);
std::vector<std::pair<xla::int64, xla::int64>> padding(num_spatial_dims);
std::vector<xla::int64> lhs_dilation(num_spatial_dims);
std::vector<xla::int64> rhs_dilation(num_spatial_dims);
std::vector<xla::int64> ones(num_spatial_dims, 1);
for (int i = 0; i < num_spatial_dims; ++i) {
xla::int64 dim = 2 + i;
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(i);
dnums.add_output_spatial_dimensions(dim);

kernel_spatial_dims[i] = i;
padding[i] = {dims.spatial_dims[i].pad_before,
dims.spatial_dims[i].pad_after};
lhs_dilation[i] = dims.spatial_dims[i].stride;
rhs_dilation[i] = dilations[dim];
// Create a TF convolution metadata structure out of PyTorch convolution
// attributes.
tensorflow::ConvOpAttrs MakeConvOpAttrs(
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_dilation) {
int num_spatial_dims = spatial_stride.size();
XLA_CHECK_EQ(spatial_padding.size(), num_spatial_dims);
XLA_CHECK_EQ(spatial_dilation.size(), num_spatial_dims);
tensorflow::ConvOpAttrs conv_op_attrs;
conv_op_attrs.depthwise = false;
conv_op_attrs.num_spatial_dims = num_spatial_dims;
// Stride, dilation and padding must be set for the batch and feature in the
// TF convolution metadata. Set them to 1 (stride and dilation) or 0 (padding)
// for the batch and feature dimensions.
conv_op_attrs.dilations = {1, 1};
std::copy(spatial_dilation.begin(), spatial_dilation.end(),
std::back_inserter(conv_op_attrs.dilations));
conv_op_attrs.strides = {1, 1};
std::copy(spatial_stride.begin(), spatial_stride.end(),
std::back_inserter(conv_op_attrs.strides));
conv_op_attrs.padding = tensorflow::Padding::EXPLICIT;
conv_op_attrs.explicit_paddings.resize(4);
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
}
conv_op_attrs.data_format = tensorflow::TensorFormat::FORMAT_NCHW;
return conv_op_attrs;
}
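
A minimal usage sketch of MakeConvOpAttrs (illustrative values only): for a PyTorch Conv2d with stride {2, 2}, padding {1, 1} and dilation {1, 1}, the code above yields the following attributes.

tensorflow::ConvOpAttrs attrs =
    MakeConvOpAttrs(/*spatial_stride=*/{2, 2}, /*spatial_padding=*/{1, 1},
                    /*spatial_dilation=*/{1, 1});
// attrs.strides           == {1, 1, 2, 2}  (batch, feature, H, W)
// attrs.dilations         == {1, 1, 1, 1}
// attrs.explicit_paddings == {0, 0, 0, 0, 1, 1, 1, 1}
//   (four leading zeros from resize(4), then {pad_h, pad_h, pad_w, pad_w})
// attrs.data_format       == tensorflow::TensorFormat::FORMAT_NCHW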

// Mirror the filter in the spatial dimensions.
xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
std::vector<xla::int64> FilterTransposePermutation() { return {2, 3, 1, 0}; }

// We'll need to undo the initial input padding once on the input backprop
// result since edges are constant and have to be discarded for the gradient.
xla::PaddingConfig padding_config;
for (int i = 0; i < 2; ++i) {
padding_config.add_dimensions();
}
for (int i = 0; i < 2; ++i) {
xla::PaddingConfig::PaddingConfigDimension* dims =
padding_config.add_dimensions();
dims->set_edge_padding_low(-padding_attr[i]);
dims->set_edge_padding_high(-padding_attr[i]);
}

// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
// Computes the input gradient for a convolution.
xla::XlaOp BuildThnnConv2dBackwardInput(
const xla::XlaOp& grad_output, const xla::XlaOp& kernel,
const xla::Shape& input_shape,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding) {
tensorflow::ConvOpAttrs conv_op_attrs =
MakeConvOpAttrs(spatial_stride, spatial_padding, {1, 1});
xla::XlaOp kernel_transposed =
xla::Transpose(kernel, FilterTransposePermutation());
xla::PrecisionConfig precision_config =
XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
xla::Shape weight_shape = XlaHelpers::ShapeOfXlaOp(weight);
return xla::Pad(
xla::ConvGeneralDilated(grad_output, mirrored_weights,
/*window_strides=*/ones, padding, lhs_dilation,
rhs_dilation, dnums,
/*feature_group_count=*/1,
/*batch_group_count=*/1, &precision_config),
XlaHelpers::ScalarValue<float>(0, weight_shape.element_type(), builder),
padding_config);
return ConsumeValue(tensorflow::MakeXlaBackpropInputConvOp(
"thnn_conv2d_backward", input_shape, kernel_transposed, grad_output,
conv_op_attrs, &precision_config));
}
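
A hypothetical call, for illustration: assume an NCHW input of shape f32[8, 16, 32, 32], an OIHW kernel of shape f32[32, 16, 3, 3], stride 2 and padding 1, with grad_output being the incoming gradient of shape f32[8, 32, 16, 16].

xla::Shape input_shape = xla::ShapeUtil::MakeShape(xla::F32, {8, 16, 32, 32});
xla::XlaOp grad_input = BuildThnnConv2dBackwardInput(
    grad_output, kernel, input_shape,
    /*spatial_stride=*/{2, 2}, /*spatial_padding=*/{1, 1});
// grad_input has the shape of the forward input: f32[8, 16, 32, 32].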

// Computes the weight gradient for a convolution.
// Computes the kernel gradient for a convolution.
xla::XlaOp BuildThnnConv2dBackwardWeight(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::XlaOp& weight,
tensorflow::gtl::ArraySlice<const xla::int64> stride_attr,
tensorflow::gtl::ArraySlice<const xla::int64> padding_attr) {
constexpr int n_dim = 0;
constexpr int c_dim = 1;
XLA_CHECK_EQ(padding_attr.size(), 2);
// Adjust input size to account for specified padding.
auto input_size = XlaHelpers::SizesOfXlaOp(input);
for (int i = 0; i < 2; ++i) {
input_size[2 + i] += 2 * padding_attr[i];
}
tensorflow::TensorShape activations_shape(input_size);
const auto filter_size = XlaHelpers::SizesOfXlaOp(weight);
std::vector<xla::int64> filter_size_backward{filter_size[2], filter_size[3],
filter_size[1], filter_size[0]};
tensorflow::TensorShape filter_shape(filter_size_backward);
tensorflow::TensorShape out_backprop_shape(
XlaHelpers::SizesOfXlaOp(grad_output));
std::vector<int> strides{1, 1};
std::copy(stride_attr.begin(), stride_attr.end(),
std::back_inserter(strides));
tensorflow::ConvBackpropDimensions dims;
constexpr int num_spatial_dims = 2;
std::vector<int> dilations{1, 1, 1, 1};
xla::Status status = ConvBackpropComputeDimensionsV2(
"thnn_conv2d_backward", num_spatial_dims, activations_shape, filter_shape,
out_backprop_shape, dilations, strides, tensorflow::Padding::VALID,
/*explicit_paddings=*/{}, tensorflow::TensorFormat::FORMAT_NCHW, &dims);
XLA_CHECK(status.ok()) << status.error_message();

// The filter gradients are computed by a convolution of the input
// activations and the output gradients, with some appropriate padding.
// See the comment at the top of conv_grad_ops.h for details.

xla::ConvolutionDimensionNumbers dnums;

// The activations (inputs) form the LHS of the convolution.
// Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
// For the gradient computation, we flip the roles of the batch and
// feature dimensions.
// Each spatial entry has size in_depth * batch

// Swap n_dim and c_dim in the activations.
dnums.set_input_batch_dimension(c_dim);
dnums.set_input_feature_dimension(n_dim);

// The gradients become the RHS of the convolution.
// The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
// where the batch becomes the input feature for the convolution.
dnums.set_kernel_input_feature_dimension(n_dim);
dnums.set_kernel_output_feature_dimension(c_dim);

std::vector<std::pair<xla::int64, xla::int64>> padding(num_spatial_dims);
std::vector<xla::int64> rhs_dilation(num_spatial_dims);
std::vector<xla::int64> window_strides(num_spatial_dims);
std::vector<xla::int64> ones(num_spatial_dims, 1);

// Tensorflow filter shape is [ H, W, ..., inC, outC ].
for (int i = 0; i < num_spatial_dims; ++i) {
dnums.add_output_spatial_dimensions(i);
}
dnums.set_output_batch_dimension(num_spatial_dims);
dnums.set_output_feature_dimension(num_spatial_dims + 1);

for (int i = 0; i < num_spatial_dims; ++i) {
xla::int64 dim = 2 + i;
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);

// We will also need to pad the input with zeros such that after the
// convolution, we get the right size for the filter.
// The padded_in_rows should be such that when we convolve this with the
// expanded_out_rows as a filter, we should get filter_rows back.
//
const xla::int64 padded_in_size =
dims.spatial_dims[i].expanded_output_size +
(dims.spatial_dims[i].filter_size - 1) * dilations[dim];

// However it can be smaller than input_rows: in this
// case it means some of the inputs are not used.
//
// An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
//
// INPUT = [ A B C ]
//
// FILTER = [ x y ]
//
// and the output will only have one column: a = A * x + B * y
//
// and input "C" is not used at all.
//
// We apply negative padding in this case.
const xla::int64 pad_total =
padded_in_size - dims.spatial_dims[i].input_size;

// Pad the bottom/right side with the remaining space.
const xla::int64 pad_before = 0;

padding[i] = {pad_before, pad_total - pad_before};
rhs_dilation[i] = dims.spatial_dims[i].stride;
window_strides[i] = dilations[dim];
}
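
// Working the example above through the numbers (assuming, as in TF's
// conv_grad_ops, expanded_output_size == (output_size - 1) * stride + 1):
// with input_cols = 3, filter_cols = 2, stride = 2 and dilation = 1, the
// VALID output has output_size = 1, so expanded_output_size = 1,
// padded_in_size = 1 + (2 - 1) * 1 = 2, and pad_total = 2 - 3 = -1, i.e.
// one column of negative padding that trims the unused input column "C".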

// Redo the initial input padding.
xla::PaddingConfig padding_config =
XlaHelpers::MakeXlaPaddingConfig(XlaHelpers::I64List(padding_attr));

xla::XlaBuilder* builder = grad_output.builder();
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp padded_input = xla::Pad(
input,
XlaHelpers::ScalarValue<float>(0, input_shape.element_type(), builder),
padding_config);

const xla::Shape& kernel_shape,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding) {
tensorflow::ConvOpAttrs conv_op_attrs =
MakeConvOpAttrs(spatial_stride, spatial_padding, {1, 1});
auto inv_transpose_permutation =
xla::InversePermutation(FilterTransposePermutation());
xla::Shape transposed_weight_shape = xla::ShapeUtil::PermuteDimensions(
inv_transpose_permutation, kernel_shape);
xla::PrecisionConfig precision_config =
XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
xla::XlaOp conv = ConsumeValue(MakeXlaBackpropFilterConvOp(
"thnn_conv2d_backward", input, transposed_weight_shape, grad_output,
conv_op_attrs, &precision_config));

// Reorder the dimensions of the filter gradient to match the NCHW convention
// of PyTorch. The original result of the convolution has the spatial and
// feature dimensions swapped and the spatial dimensions reversed.
return xla::Transpose(xla::ConvGeneralDilated(
padded_input, grad_output, window_strides, padding,
/*lhs_dilation=*/ones, rhs_dilation, dnums,
/*feature_group_count=*/1,
/*batch_group_count=*/1, &precision_config),
{3, 2, 0, 1});
return xla::Transpose(conv, inv_transpose_permutation);
}
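
A small sketch of the layout round-trip above, for a hypothetical PyTorch kernel of shape f32[32, 16, 3, 3] (OIHW):

// FilterTransposePermutation()          == {2, 3, 1, 0}  // OIHW -> HWIO
// xla::InversePermutation({2, 3, 1, 0}) == {3, 2, 0, 1}  // HWIO -> OIHW
// MakeXlaBackpropFilterConvOp therefore works with a TF-layout (HWIO) filter
// shape of f32[3, 3, 16, 32] and returns the kernel gradient in that layout;
// the final xla::Transpose with {3, 2, 0, 1} restores the OIHW layout that
// PyTorch expects for the weight gradient.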

xla::XlaOp BuildGradBias(xla::XlaOp grad_output) {
@@ -313,24 +162,24 @@ xla::XlaOp BuildConvolutionBias(
Conv2DGrads BuildConv2dBackward(const torch::jit::Node* node,
const xla::XlaOp& grad_output,
const xla::XlaOp& input,
const xla::XlaOp& weight) {
const xla::XlaOp& kernel) {
const auto stride = node->get<std::vector<int64_t>>(at::attr::stride).value();
const auto padding =
node->get<std::vector<int64_t>>(at::attr::padding).value();
return BuildConv2dBackward(grad_output, input, weight,
return BuildConv2dBackward(grad_output, input, kernel,
XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding));
}

Conv2DGrads BuildConv2dBackward(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding) {
xla::XlaOp grad_input = BuildThnnConv2dBackwardInput(
grad_output, weight, XlaHelpers::SizesOfXlaOp(input), stride, padding);
grad_output, kernel, XlaHelpers::ShapeOfXlaOp(input), stride, padding);
xla::XlaOp grad_weight = BuildThnnConv2dBackwardWeight(
grad_output, input, weight, stride, padding);
grad_output, input, XlaHelpers::ShapeOfXlaOp(kernel), stride, padding);
xla::XlaOp grad_bias = BuildGradBias(grad_output);
return {grad_input, grad_weight, grad_bias};
}
@@ -348,8 +197,10 @@ xla::XlaOp BuildTransposedConvolution(
(input_shape.dimensions(2 + spatial_dim) - 1) * stride[spatial_dim] -
2 * padding[spatial_dim] + kernel_shape.dimensions(2 + spatial_dim));
}
return BuildThnnConv2dBackwardInput(input, kernel, input_size, stride,
padding);
return BuildThnnConv2dBackwardInput(
input, kernel,
xla::ShapeUtil::MakeShape(input_shape.element_type(), input_size), stride,
padding);
}
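
A worked instance of the size computation above (dilation 1 and no output_padding assumed):

// For a spatial input size of 5, stride 2, padding 1 and kernel size 3:
//   (5 - 1) * 2 - 2 * 1 + 3 = 9,
// matching PyTorch's ConvTranspose2d output-size formula for dilation 1 and
// no output_padding. The transposed convolution is then expressed as the
// input gradient of an ordinary convolution over that size-9 "forward input".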

xla::XlaOp BuildTransposedConvolutionBias(
@@ -375,7 +226,7 @@ Conv2DGrads BuildTransposedConvolutionBackward(
xla::XlaOp grad_input =
BuildConvolution(grad_output, kernel, stride, padding);
xla::XlaOp grad_weight = BuildThnnConv2dBackwardWeight(
input, grad_output, kernel, stride, padding);
input, grad_output, XlaHelpers::ShapeOfXlaOp(kernel), stride, padding);
xla::XlaOp grad_bias = BuildGradBias(grad_output);
return {grad_input, grad_weight, grad_bias};
}
4 changes: 2 additions & 2 deletions torch_xla/csrc/convolution.h
@@ -49,12 +49,12 @@ struct Conv2DGrads {
Conv2DGrads BuildConv2dBackward(const torch::jit::Node* node,
const xla::XlaOp& grad_output,
const xla::XlaOp& input,
const xla::XlaOp& weight);
const xla::XlaOp& kernel);

// Same as above, with stride and padding provided as parameters.
Conv2DGrads BuildConv2dBackward(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding);
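
A minimal usage sketch of the overload above. grad_output, input and kernel are assumed to be existing xla::XlaOp values built on the same builder, and the stride/padding values are illustrative only:

Conv2DGrads grads = BuildConv2dBackward(grad_output, input, kernel,
                                        /*stride=*/{2, 2},
                                        /*padding=*/{1, 1});
// grads holds, in order, the input, weight and bias gradients (see the
// aggregate return in convolution.cpp); the input gradient has the shape of
// input, the weight gradient the (OIHW) shape of kernel, and the bias
// gradient one element per output channel.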
