@@ -1,6 +1,5 @@
 #include "torch_xla/csrc/convolution.h"
 
-#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
@@ -33,10 +32,8 @@ namespace {
  * - grad_input: conv(grad_output, weight^T) (with padding etc)
  * - grad_weight: conv(input^T, grad_output)
  *
- * XLA provides the following wrappers instead of calling into raw
- * ConvGeneralDilated.
+ * The helpers below are inspired by the TF2XLA implementation of convolution:
  * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
- * - MakeXlaForwardConvOp (not used in our lowering, see below)
  * - MakeXlaBackpropInputConvOp
  * - MakeXlaBackpropFilterConvOp
  *
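For intuition, here is a minimal scalar reference for the two backward formulas in the comment above, using a 1-D, stride-1, unpadded convolution. This is an illustration only, not part of the patch; the function names and the 1-D simplification are assumptions.

#include <vector>

// forward: out[i] = sum_k in[i + k] * w[k], so out has size In - K + 1.
// grad_input scatters grad_output back through the kernel, which is the
// "conv(grad_output, weight^T)" of the comment above (a full correlation
// with the flipped kernel).
std::vector<float> ConvBackwardInput(const std::vector<float>& grad_out,
                                     const std::vector<float>& w,
                                     int input_size) {
  std::vector<float> grad_in(input_size, 0.f);
  for (int i = 0; i < static_cast<int>(grad_out.size()); ++i)
    for (int k = 0; k < static_cast<int>(w.size()); ++k)
      grad_in[i + k] += grad_out[i] * w[k];
  return grad_in;
}

// grad_weight correlates the input with grad_output, i.e. the
// "conv(input^T, grad_output)" of the comment above.
std::vector<float> ConvBackwardWeight(const std::vector<float>& grad_out,
                                      const std::vector<float>& in,
                                      int kernel_size) {
  std::vector<float> grad_w(kernel_size, 0.f);
  for (int k = 0; k < kernel_size; ++k)
    for (int i = 0; i < static_cast<int>(grad_out.size()); ++i)
      grad_w[k] += grad_out[i] * in[i + k];
  return grad_w;
}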
@@ -74,10 +71,9 @@ namespace {
  * depthwise convolution, there's no need to do additional reshapes to match to
  * XLA expected format. This is also why we use raw ConvGeneralDilated instead
  * of MakeXlaForwardConvOp in forward graph. For code simplicity we still want
- * to use the MakeXlaBackpropInputConvOp and MakeXlaBackpropFilterConvOp given
- * they have many useful steps that we don't want to duplicate here, we simply
- * enforce depthwise = false inside those functions, so that we skip the reshape
- * steps XLA has with a [Hker, Wker, Cin, M] input.
+ * to use MakeXlaBackpropInputConvOp and MakeXlaBackpropFilterConvOp; we simply
+ * enforce depthwise = false inside those functions, so that we skip the
+ * reshape steps XLA has with a [Hker, Wker, Cin, M] input.
  *
  * forward: (conv with groups = G)
  * - input: [N, Hin, Win, Cin]
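The forward lowering the comment describes, grouped convolution sent straight through raw ConvGeneralDilated, could look roughly like the sketch below. This is a hedged illustration, not the helper in this file: the function name, the NCHW dimension assignments (matching the FORMAT_NCHW used later in this file), and the parameter plumbing are all assumptions.

// Hypothetical sketch only: grouped forward conv lowered directly through
// xla::ConvGeneralDilated, mapping PyTorch's `groups` onto XLA's
// feature_group_count, so no depthwise reshape is needed.
xla::XlaOp GroupedForwardConv(
    xla::XlaOp input, xla::XlaOp kernel, absl::Span<const int64_t> stride,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> dilation, int64_t groups) {
  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);            // N
  dnums.set_input_feature_dimension(1);          // Cin
  dnums.set_output_batch_dimension(0);           // N
  dnums.set_output_feature_dimension(1);         // Cout
  dnums.set_kernel_output_feature_dimension(0);  // Cout
  dnums.set_kernel_input_feature_dimension(1);   // Cin / G
  for (int i = 0; i < static_cast<int>(stride.size()); ++i) {
    dnums.add_input_spatial_dimensions(2 + i);
    dnums.add_output_spatial_dimensions(2 + i);
    dnums.add_kernel_spatial_dimensions(2 + i);
  }
  // No input (lhs) dilation in the forward pass; kernel (rhs) dilation
  // carries PyTorch's dilation attribute.
  std::vector<int64_t> ones(stride.size(), 1);
  return xla::ConvGeneralDilated(input, kernel, stride, padding,
                                 /*lhs_dilation=*/ones,
                                 /*rhs_dilation=*/dilation, dnums,
                                 /*feature_group_count=*/groups);
}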
@@ -121,16 +117,15 @@ xla::XlaOp PadInputFromOutputSize(xla::XlaOp input,
   return PadToSize(input, expected_input_sizes);
 }
 
-// Create a TF convolution metadata structure out of PyTorch convolution
-// attributes.
-tensorflow::ConvOpAttrs MakeConvOpAttrs(
-    absl::Span<const int64_t> spatial_stride,
-    absl::Span<const int64_t> spatial_padding,
-    absl::Span<const int64_t> spatial_dilation, bool depthwise) {
+// Create a ConvOpAttrs structure out of PyTorch convolution attributes.
+ConvOpAttrs MakeConvOpAttrs(absl::Span<const int64_t> spatial_stride,
+                            absl::Span<const int64_t> spatial_padding,
+                            absl::Span<const int64_t> spatial_dilation,
+                            bool depthwise) {
   int num_spatial_dims = spatial_stride.size();
   XLA_CHECK_EQ(spatial_padding.size(), num_spatial_dims);
   XLA_CHECK_EQ(spatial_dilation.size(), num_spatial_dims);
-  tensorflow::ConvOpAttrs conv_op_attrs;
+  ConvOpAttrs conv_op_attrs;
   conv_op_attrs.depthwise = depthwise;
   conv_op_attrs.num_spatial_dims = num_spatial_dims;
   // Stride, dilation and padding must be set for the batch and feature in the
@@ -142,15 +137,15 @@ tensorflow::ConvOpAttrs MakeConvOpAttrs(
   conv_op_attrs.strides = {1, 1};
   std::copy(spatial_stride.begin(), spatial_stride.end(),
             std::back_inserter(conv_op_attrs.strides));
-  conv_op_attrs.padding = tensorflow::Padding::EXPLICIT;
+  conv_op_attrs.padding = Padding::EXPLICIT;
   // https://github.com/tensorflow/tensorflow/blob/ec81825aaf7e848d9f8ddffdf1e0d20aebe9172c/tensorflow/core/util/padding.cc#L40
   // explicit_paddings must have (spatial_dims + 2) * 2 elements
   conv_op_attrs.explicit_paddings.resize(4);
   for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
     conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
     conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
   }
-  conv_op_attrs.data_format = tensorflow::TensorFormat::FORMAT_NCHW;
+  conv_op_attrs.data_format = TensorFormat::FORMAT_NCHW;
   return conv_op_attrs;
 }
 
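As a concrete illustration (not part of the patch) of what MakeConvOpAttrs produces, consider a 2-D convolution; the dilations field is populated by code elided between the hunks above and is omitted here.

// Illustrative only: expected contents for a 2-D conv.
ConvOpAttrs attrs = MakeConvOpAttrs(/*spatial_stride=*/{2, 2},
                                    /*spatial_padding=*/{1, 1},
                                    /*spatial_dilation=*/{1, 1},
                                    /*depthwise=*/false);
// attrs.strides           == {1, 1, 2, 2}              // batch, feature, H, W
// attrs.explicit_paddings == {0, 0, 0, 0, 1, 1, 1, 1}  // (spatial_dims + 2) * 2 entries
// attrs.data_format       == TensorFormat::FORMAT_NCHW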
@@ -218,13 +213,13 @@ xla::XlaOp BuildConvBackwardInput(xla::XlaOp grad_output, xla::XlaOp kernel,
                                   absl::Span<const int64_t> spatial_padding,
                                   absl::Span<const int64_t> spatial_dilation,
                                   int64_t groups) {
-  tensorflow::ConvOpAttrs conv_op_attrs =
+  ConvOpAttrs conv_op_attrs =
       MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false);
   xla::XlaOp kernel_transposed =
       xla::Transpose(kernel, FilterTransposePermutation(input_shape.rank()));
-  return ConsumeValue(tensorflow::MakeXlaBackpropInputConvOp(
-      "conv_backward_input", input_shape, kernel_transposed, grad_output,
-      conv_op_attrs));
+  return ConsumeValue(MakeXlaBackpropInputConvOp("conv_backward_input",
+                                                 input_shape, kernel_transposed,
+                                                 grad_output, conv_op_attrs));
 }
 
 // Computes the kernel gradient for a convolution.
@@ -234,14 +229,14 @@ xla::XlaOp BuildConvBackwardWeight(xla::XlaOp grad_output, xla::XlaOp input,
                                    absl::Span<const int64_t> spatial_padding,
                                    absl::Span<const int64_t> spatial_dilation,
                                    int64_t groups) {
-  tensorflow::ConvOpAttrs conv_op_attrs =
+  ConvOpAttrs conv_op_attrs =
       MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false);
   auto transpose_permutation = FilterTransposePermutation(kernel_shape.rank());
   auto inv_transpose_permutation =
       xla::InversePermutation(transpose_permutation);
   xla::Shape transposed_weight_shape =
       xla::ShapeUtil::PermuteDimensions(transpose_permutation, kernel_shape);
-  xla::XlaOp conv = ConsumeValue(tensorflow::MakeXlaBackpropFilterConvOp(
+  xla::XlaOp conv = ConsumeValue(MakeXlaBackpropFilterConvOp(
       "conv_backward_weight", input, transposed_weight_shape, grad_output,
       conv_op_attrs));
 
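Finally, a hedged usage sketch for the two builders above. The middle parameters (input_shape / kernel_shape and spatial_stride) are elided in the hunks and assumed from the surrounding bodies; the variables are presumed in scope, and all names and sizes are illustrative.

// Hypothetical call sites (parameter order assumed from the signatures above;
// grad_output, kernel, input, input_shape and kernel_shape are assumed in scope).
xla::XlaOp grad_input = BuildConvBackwardInput(
    grad_output, kernel, input_shape,
    /*spatial_stride=*/{1, 1}, /*spatial_padding=*/{0, 0},
    /*spatial_dilation=*/{1, 1}, /*groups=*/1);
xla::XlaOp grad_weight = BuildConvBackwardWeight(
    grad_output, input, kernel_shape,
    /*spatial_stride=*/{1, 1}, /*spatial_padding=*/{0, 0},
    /*spatial_dilation=*/{1, 1}, /*groups=*/1);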