Skip to content

Commit ce69bd5

Browse files
committed
Bring new TF and fix convolution breakage due to TF2XLA changes.
1 parent 4404ef4 commit ce69bd5

File tree

2 files changed

+10
-27
lines changed

2 files changed

+10
-27
lines changed

third_party/tensorflow

Submodule tensorflow updated 1023 files

torch_xla/csrc/convolution.cpp

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "torch_xla/csrc/convolution.h"
22

3-
#include "tensorflow/compiler/xla/client/lib/conv_op_helpers.h"
3+
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
44
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
55
#include "tensorflow/core/framework/tensor_shape.h"
66
#include "tensorflow/core/kernels/conv_grad_ops.h"
@@ -11,33 +11,16 @@
1111
namespace torch_xla {
1212
namespace {
1313

14-
// Converts the tensor data format to the one required by the XLA convolution
15-
// library.
16-
xla::ConvolutionDimensionNumbers MakeConvolutionDimensionNumbers(
17-
tensorflow::TensorFormat data_format, int num_spatial_dims) {
18-
int num_dims = num_spatial_dims + 2;
19-
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
20-
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
21-
xla::ConvolutionDimensionNumbers conv_dim_numbers;
22-
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
23-
conv_dim_numbers.add_input_spatial_dimensions(
24-
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim));
25-
}
26-
conv_dim_numbers.set_input_batch_dimension(batch_dimension);
27-
conv_dim_numbers.set_input_feature_dimension(feature_dimension);
28-
return conv_dim_numbers;
29-
}
30-
3114
// Create a TF convolution metadata structure out of PyTorch convolution
3215
// attributes.
33-
xla::ConvOpAttrs MakeConvOpAttrs(
16+
tensorflow::ConvOpAttrs MakeConvOpAttrs(
3417
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
3518
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding,
3619
tensorflow::gtl::ArraySlice<const xla::int64> spatial_dilation) {
3720
int num_spatial_dims = spatial_stride.size();
3821
XLA_CHECK_EQ(spatial_padding.size(), num_spatial_dims);
3922
XLA_CHECK_EQ(spatial_dilation.size(), num_spatial_dims);
40-
xla::ConvOpAttrs conv_op_attrs;
23+
tensorflow::ConvOpAttrs conv_op_attrs;
4124
conv_op_attrs.depthwise = false;
4225
conv_op_attrs.num_spatial_dims = num_spatial_dims;
4326
// Stride, dilation and padding must be set for the batch and feature in the
@@ -49,13 +32,13 @@ xla::ConvOpAttrs MakeConvOpAttrs(
4932
conv_op_attrs.strides = {1, 1};
5033
std::copy(spatial_stride.begin(), spatial_stride.end(),
5134
std::back_inserter(conv_op_attrs.strides));
35+
conv_op_attrs.padding = tensorflow::Padding::EXPLICIT;
5236
conv_op_attrs.explicit_paddings.resize(4);
5337
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
5438
conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
5539
conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
5640
}
57-
conv_op_attrs.data_format = MakeConvolutionDimensionNumbers(
58-
tensorflow::TensorFormat::FORMAT_NCHW, num_spatial_dims);
41+
conv_op_attrs.data_format = tensorflow::TensorFormat::FORMAT_NCHW;
5942
return conv_op_attrs;
6043
}
6144

@@ -67,13 +50,13 @@ xla::XlaOp BuildThnnConv2dBackwardInput(
6750
const xla::Shape& input_shape,
6851
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
6952
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding) {
70-
xla::ConvOpAttrs conv_op_attrs =
53+
tensorflow::ConvOpAttrs conv_op_attrs =
7154
MakeConvOpAttrs(spatial_stride, spatial_padding, {1, 1});
7255
xla::XlaOp kernel_transposed =
7356
xla::Transpose(kernel, FilterTransposePermutation());
7457
xla::PrecisionConfig precision_config =
7558
XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
76-
return ConsumeValue(xla::MakeXlaBackpropInputConvOp(
59+
return ConsumeValue(tensorflow::MakeXlaBackpropInputConvOp(
7760
"thnn_conv2d_backward", input_shape, kernel_transposed, grad_output,
7861
conv_op_attrs, &precision_config));
7962
}
@@ -84,15 +67,15 @@ xla::XlaOp BuildThnnConv2dBackwardWeight(
8467
const xla::Shape& kernel_shape,
8568
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
8669
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding) {
87-
xla::ConvOpAttrs conv_op_attrs =
70+
tensorflow::ConvOpAttrs conv_op_attrs =
8871
MakeConvOpAttrs(spatial_stride, spatial_padding, {1, 1});
8972
auto inv_transpose_permutation =
9073
xla::InversePermutation(FilterTransposePermutation());
9174
xla::Shape transposed_weight_shape = xla::ShapeUtil::PermuteDimensions(
9275
inv_transpose_permutation, kernel_shape);
9376
xla::PrecisionConfig precision_config =
9477
XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
95-
xla::XlaOp conv = ConsumeValue(xla::MakeXlaBackpropFilterConvOp(
78+
xla::XlaOp conv = ConsumeValue(tensorflow::MakeXlaBackpropFilterConvOp(
9679
"thnn_conv2d_backward", input, transposed_weight_shape, grad_output,
9780
conv_op_attrs, &precision_config));
9881

0 commit comments

Comments
 (0)