11#include " torch_xla/csrc/convolution.h"
22
3- #include " tensorflow/compiler/xla/client/lib /conv_op_helpers.h"
3+ #include " tensorflow/compiler/tf2xla/kernels /conv_op_helpers.h"
44#include " tensorflow/compiler/xla/xla_client/debug_macros.h"
55#include " tensorflow/core/framework/tensor_shape.h"
66#include " tensorflow/core/kernels/conv_grad_ops.h"
1111namespace torch_xla {
1212namespace {
1313
14- // Converts the tensor data format to the one required by the XLA convolution
15- // library.
16- xla::ConvolutionDimensionNumbers MakeConvolutionDimensionNumbers (
17- tensorflow::TensorFormat data_format, int num_spatial_dims) {
18- int num_dims = num_spatial_dims + 2 ;
19- int batch_dimension = GetTensorBatchDimIndex (num_dims, data_format);
20- int feature_dimension = GetTensorFeatureDimIndex (num_dims, data_format);
21- xla::ConvolutionDimensionNumbers conv_dim_numbers;
22- for (int spatial_dim = 0 ; spatial_dim < num_spatial_dims; ++spatial_dim) {
23- conv_dim_numbers.add_input_spatial_dimensions (
24- GetTensorSpatialDimIndex (num_dims, data_format, spatial_dim));
25- }
26- conv_dim_numbers.set_input_batch_dimension (batch_dimension);
27- conv_dim_numbers.set_input_feature_dimension (feature_dimension);
28- return conv_dim_numbers;
29- }
30-
3114// Create a TF convolution metadata structure out of PyTorch convolution
3215// attributes.
33- xla ::ConvOpAttrs MakeConvOpAttrs (
16+ tensorflow ::ConvOpAttrs MakeConvOpAttrs (
3417 tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
3518 tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding,
3619 tensorflow::gtl::ArraySlice<const xla::int64> spatial_dilation) {
3720 int num_spatial_dims = spatial_stride.size ();
3821 XLA_CHECK_EQ (spatial_padding.size (), num_spatial_dims);
3922 XLA_CHECK_EQ (spatial_dilation.size (), num_spatial_dims);
40- xla ::ConvOpAttrs conv_op_attrs;
23+ tensorflow ::ConvOpAttrs conv_op_attrs;
4124 conv_op_attrs.depthwise = false ;
4225 conv_op_attrs.num_spatial_dims = num_spatial_dims;
4326 // Stride, dilation and padding must be set for the batch and feature in the
@@ -49,13 +32,13 @@ xla::ConvOpAttrs MakeConvOpAttrs(
4932 conv_op_attrs.strides = {1 , 1 };
5033 std::copy (spatial_stride.begin (), spatial_stride.end (),
5134 std::back_inserter (conv_op_attrs.strides ));
35+ conv_op_attrs.padding = tensorflow::Padding::EXPLICIT;
5236 conv_op_attrs.explicit_paddings .resize (4 );
5337 for (int spatial_dim = 0 ; spatial_dim < num_spatial_dims; ++spatial_dim) {
5438 conv_op_attrs.explicit_paddings .push_back (spatial_padding[spatial_dim]);
5539 conv_op_attrs.explicit_paddings .push_back (spatial_padding[spatial_dim]);
5640 }
57- conv_op_attrs.data_format = MakeConvolutionDimensionNumbers (
58- tensorflow::TensorFormat::FORMAT_NCHW, num_spatial_dims);
41+ conv_op_attrs.data_format = tensorflow::TensorFormat::FORMAT_NCHW;
5942 return conv_op_attrs;
6043}
6144
@@ -67,13 +50,13 @@ xla::XlaOp BuildThnnConv2dBackwardInput(
6750 const xla::Shape& input_shape,
6851 tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
6952 tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding) {
70- xla ::ConvOpAttrs conv_op_attrs =
53+ tensorflow ::ConvOpAttrs conv_op_attrs =
7154 MakeConvOpAttrs (spatial_stride, spatial_padding, {1 , 1 });
7255 xla::XlaOp kernel_transposed =
7356 xla::Transpose (kernel, FilterTransposePermutation ());
7457 xla::PrecisionConfig precision_config =
7558 XlaHelpers::BuildPrecisionConfig (XlaHelpers::mat_mul_precision ());
76- return ConsumeValue (xla ::MakeXlaBackpropInputConvOp (
59+ return ConsumeValue (tensorflow ::MakeXlaBackpropInputConvOp (
7760 " thnn_conv2d_backward" , input_shape, kernel_transposed, grad_output,
7861 conv_op_attrs, &precision_config));
7962}
@@ -84,15 +67,15 @@ xla::XlaOp BuildThnnConv2dBackwardWeight(
8467 const xla::Shape& kernel_shape,
8568 tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
8669 tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding) {
87- xla ::ConvOpAttrs conv_op_attrs =
70+ tensorflow ::ConvOpAttrs conv_op_attrs =
8871 MakeConvOpAttrs (spatial_stride, spatial_padding, {1 , 1 });
8972 auto inv_transpose_permutation =
9073 xla::InversePermutation (FilterTransposePermutation ());
9174 xla::Shape transposed_weight_shape = xla::ShapeUtil::PermuteDimensions (
9275 inv_transpose_permutation, kernel_shape);
9376 xla::PrecisionConfig precision_config =
9477 XlaHelpers::BuildPrecisionConfig (XlaHelpers::mat_mul_precision ());
95- xla::XlaOp conv = ConsumeValue (xla ::MakeXlaBackpropFilterConvOp (
78+ xla::XlaOp conv = ConsumeValue (tensorflow ::MakeXlaBackpropFilterConvOp (
9679 " thnn_conv2d_backward" , input, transposed_weight_shape, grad_output,
9780 conv_op_attrs, &precision_config));
9881
0 commit comments