Skip to content

Commit 3d79831

Browse files
authored
[OpenXLA] Reimplement Convolution with XLA::Shape (#5190)
* only one func * add * add too * int32 * add * add * add * add * add * add * no conv_op_helper.h * no padding.h * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * no ConvOpAttrs * more add * more add * more add * more add * more add * more add * skip * fix one * use xla::shape * del * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * no tensorshape * format * format * format * format * format * format * format * format * format * no conv op * update * del no used * split * split * format * format * format * format * format * format * format
1 parent a04b86d commit 3d79831

File tree

5 files changed

+839
-24
lines changed

5 files changed

+839
-24
lines changed

torch_xla/csrc/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ ptxla_cc_library(
3737
"batch_norm.cpp",
3838
"convert_ops.cpp",
3939
"convolution.cpp",
40+
"convolution_helper.cpp",
4041
"cross_replica_reduces.cpp",
4142
"data_ops.cpp",
4243
"debug_util.cpp",
@@ -75,6 +76,7 @@ ptxla_cc_library(
7576
"batch_norm.h",
7677
"convert_ops.h",
7778
"convolution.h",
79+
"convolution_helper.h",
7880
"cross_replica_reduces.h",
7981
"data_ops.h",
8082
"debug_util.h",
@@ -128,7 +130,6 @@ ptxla_cc_library(
128130
"@com_google_absl//absl/strings",
129131
"@com_google_absl//absl/types:optional",
130132
"@com_google_absl//absl/types:span",
131-
"@org_tensorflow//tensorflow/compiler/tf2xla/kernels:conv_op_helpers",
132133
"@org_tensorflow//tensorflow/compiler/xla:comparison_util",
133134
"@org_tensorflow//tensorflow/compiler/xla:literal_util",
134135
"@org_tensorflow//tensorflow/compiler/xla:permutation_util",

torch_xla/csrc/convolution.cpp

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "torch_xla/csrc/convolution.h"
22

3-
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
43
#include "tensorflow/compiler/xla/client/lib/constants.h"
54
#include "torch_xla/csrc/helpers.h"
65
#include "torch_xla/csrc/runtime/debug_macros.h"
@@ -33,10 +32,8 @@ namespace {
3332
* - grad_input: conv(grad_output, weight^T) (with padding etc)
3433
* - grad_weight: conv(input^T, grad_output)
3534
*
36-
* XLA provides the following wrappers instead of calling into raw
37-
* ConvGeneralDilated.
35+
* The helpers below are inspired by the TF2XLA implementation of convolution:
3836
* https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
39-
* - MakeXlaForwardConvOp (not used in our lowering, see below)
4037
* - MakeXlaBackpropInputConvOp
4138
* - MakeXlaBackpropFilterConvOp
4239
*
@@ -74,10 +71,9 @@ namespace {
7471
* depthwise convolution, there's no need to do additional reshapes to match to
7572
* XLA expected format. This is also why we use raw ConvGeneralDilated instead
7673
* of MakeXlaForwardConvOp in forward graph. For code simplicity we still want
77-
* to use the MakeXlaBackpropInputConvOp and MakeXlaBackpropFilterConvOp given
78-
* they have many useful steps that we don't want to duplicate here, we simply
79-
* enforce depthwise = false inside those functions, so that we skip the reshape
80-
* steps XLA has with a [Hker, Wker, Cin, M] input.
74+
* to use the MakeXlaBackpropInputConvOp and MakeXlaBackpropFilterConvOp;
75+
* we simply enforce depthwise = false inside those functions, so that we skip the
76+
* reshape steps XLA has with a [Hker, Wker, Cin, M] input.
8177
*
8278
* forward: (conv with groups = G)
8379
* - input: [N, Hin, Win, Cin]
@@ -121,16 +117,15 @@ xla::XlaOp PadInputFromOutputSize(xla::XlaOp input,
121117
return PadToSize(input, expected_input_sizes);
122118
}
123119

124-
// Create a TF convolution metadata structure out of PyTorch convolution
125-
// attributes.
126-
tensorflow::ConvOpAttrs MakeConvOpAttrs(
127-
absl::Span<const int64_t> spatial_stride,
128-
absl::Span<const int64_t> spatial_padding,
129-
absl::Span<const int64_t> spatial_dilation, bool depthwise) {
120+
// Create a ConvOpAttrs structure out of PyTorch convolution attributes.
121+
ConvOpAttrs MakeConvOpAttrs(absl::Span<const int64_t> spatial_stride,
122+
absl::Span<const int64_t> spatial_padding,
123+
absl::Span<const int64_t> spatial_dilation,
124+
bool depthwise) {
130125
int num_spatial_dims = spatial_stride.size();
131126
XLA_CHECK_EQ(spatial_padding.size(), num_spatial_dims);
132127
XLA_CHECK_EQ(spatial_dilation.size(), num_spatial_dims);
133-
tensorflow::ConvOpAttrs conv_op_attrs;
128+
ConvOpAttrs conv_op_attrs;
134129
conv_op_attrs.depthwise = depthwise;
135130
conv_op_attrs.num_spatial_dims = num_spatial_dims;
136131
// Stride, dilation and padding must be set for the batch and feature in the
@@ -142,15 +137,15 @@ tensorflow::ConvOpAttrs MakeConvOpAttrs(
142137
conv_op_attrs.strides = {1, 1};
143138
std::copy(spatial_stride.begin(), spatial_stride.end(),
144139
std::back_inserter(conv_op_attrs.strides));
145-
conv_op_attrs.padding = tensorflow::Padding::EXPLICIT;
140+
conv_op_attrs.padding = Padding::EXPLICIT;
146141
// https://github.com/tensorflow/tensorflow/blob/ec81825aaf7e848d9f8ddffdf1e0d20aebe9172c/tensorflow/core/util/padding.cc#L40
147142
// explicit_paddings must have (spatial_dims + 2) * 2 elements
148143
conv_op_attrs.explicit_paddings.resize(4);
149144
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
150145
conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
151146
conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]);
152147
}
153-
conv_op_attrs.data_format = tensorflow::TensorFormat::FORMAT_NCHW;
148+
conv_op_attrs.data_format = TensorFormat::FORMAT_NCHW;
154149
return conv_op_attrs;
155150
}
156151

@@ -218,13 +213,13 @@ xla::XlaOp BuildConvBackwardInput(xla::XlaOp grad_output, xla::XlaOp kernel,
218213
absl::Span<const int64_t> spatial_padding,
219214
absl::Span<const int64_t> spatial_dilation,
220215
int64_t groups) {
221-
tensorflow::ConvOpAttrs conv_op_attrs =
216+
ConvOpAttrs conv_op_attrs =
222217
MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false);
223218
xla::XlaOp kernel_transposed =
224219
xla::Transpose(kernel, FilterTransposePermutation(input_shape.rank()));
225-
return ConsumeValue(tensorflow::MakeXlaBackpropInputConvOp(
226-
"conv_backward_input", input_shape, kernel_transposed, grad_output,
227-
conv_op_attrs));
220+
return ConsumeValue(MakeXlaBackpropInputConvOp("conv_backward_input",
221+
input_shape, kernel_transposed,
222+
grad_output, conv_op_attrs));
228223
}
229224

230225
// Computes the kernel gradient for a convolution.
@@ -234,14 +229,14 @@ xla::XlaOp BuildConvBackwardWeight(xla::XlaOp grad_output, xla::XlaOp input,
234229
absl::Span<const int64_t> spatial_padding,
235230
absl::Span<const int64_t> spatial_dilation,
236231
int64_t groups) {
237-
tensorflow::ConvOpAttrs conv_op_attrs =
232+
ConvOpAttrs conv_op_attrs =
238233
MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false);
239234
auto transpose_permutation = FilterTransposePermutation(kernel_shape.rank());
240235
auto inv_transpose_permutation =
241236
xla::InversePermutation(transpose_permutation);
242237
xla::Shape transposed_weight_shape =
243238
xla::ShapeUtil::PermuteDimensions(transpose_permutation, kernel_shape);
244-
xla::XlaOp conv = ConsumeValue(tensorflow::MakeXlaBackpropFilterConvOp(
239+
xla::XlaOp conv = ConsumeValue(MakeXlaBackpropFilterConvOp(
245240
"conv_backward_weight", input, transposed_weight_shape, grad_output,
246241
conv_op_attrs));
247242

torch_xla/csrc/convolution.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "absl/types/span.h"
55
#include "tensorflow/compiler/xla/client/xla_builder.h"
6+
#include "torch_xla/csrc/convolution_helper.h"
67

78
namespace torch_xla {
89

0 commit comments

Comments
 (0)