@@ -241,10 +241,8 @@ xla::XlaOp BuildThnnConv2dBackwardWeight(
241241}
242242
243243std::vector<std::pair<xla::int64, xla::int64>> MakePadding (
244- const torch::jit::Node* node ) {
244+ tensorflow::gtl::ArraySlice<xla::int64> padding ) {
245245 std::vector<std::pair<xla::int64, xla::int64>> dims_padding;
246- const auto padding =
247- node->get <std::vector<int64_t >>(at::attr::padding).value ();
248246 for (const auto dim_padding : padding) {
249247 dims_padding.emplace_back (dim_padding, dim_padding);
250248 }
@@ -257,13 +255,25 @@ xla::XlaOp BuildConvolution(
257255 const torch::jit::Node* node, const xla::XlaOp& input,
258256 const xla::XlaOp& kernel,
259257 const xla::PrecisionConfig::Precision conv_precision) {
260- const auto window_strides = XlaHelpers::I64List (
261- node->get <std::vector<int64_t >>(at::attr::stride).value ());
262- const auto dims_padding = MakePadding (node);
258+ const auto stride = node->get <std::vector<int64_t >>(at::attr::stride).value ();
259+ const auto padding =
260+ node->get <std::vector<int64_t >>(at::attr::padding).value ();
261+ xla::PrecisionConfig precision_config =
262+ XlaHelpers::BuildPrecisionConfig (conv_precision);
263+ return BuildConvolution (input, kernel, XlaHelpers::I64List (stride),
264+ XlaHelpers::I64List (padding), conv_precision);
265+ }
266+
267+ xla::XlaOp BuildConvolution (
268+ const xla::XlaOp& input, const xla::XlaOp& kernel,
269+ tensorflow::gtl::ArraySlice<xla::int64> stride,
270+ tensorflow::gtl::ArraySlice<xla::int64> padding,
271+ const xla::PrecisionConfig::Precision conv_precision) {
272+ const auto dims_padding = MakePadding (padding);
263273 xla::PrecisionConfig precision_config =
264274 XlaHelpers::BuildPrecisionConfig (conv_precision);
265275 return xla::ConvWithGeneralPadding (
266- input, kernel, window_strides , dims_padding,
276+ input, kernel, stride , dims_padding,
267277 /* feature_group_count*/ 1 , /* batch_group_count=*/ 1 , &precision_config);
268278}
269279
@@ -273,7 +283,20 @@ xla::XlaOp BuildConvolutionBias(
273283 const xla::PrecisionConfig::Precision conv_precision) {
274284 const auto node_inputs = node->inputs ();
275285 XLA_CHECK_GE (node_inputs.size (), size_t (4 ));
276- const auto conv = BuildConvolution (node, input, kernel, conv_precision);
286+ const auto stride = node->get <std::vector<int64_t >>(at::attr::stride).value ();
287+ const auto padding =
288+ node->get <std::vector<int64_t >>(at::attr::padding).value ();
289+ return BuildConvolutionBias (input, kernel, bias, XlaHelpers::I64List (stride),
290+ XlaHelpers::I64List (padding), conv_precision);
291+ }
292+
293+ xla::XlaOp BuildConvolutionBias (
294+ const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias,
295+ tensorflow::gtl::ArraySlice<xla::int64> stride,
296+ tensorflow::gtl::ArraySlice<xla::int64> padding,
297+ const xla::PrecisionConfig::Precision conv_precision) {
298+ const auto conv =
299+ BuildConvolution (input, kernel, stride, padding, conv_precision);
277300 auto broadcast_sizes = XlaHelpers::SizesOfXlaOp (conv);
278301 XLA_CHECK_EQ (broadcast_sizes.size (), 4 );
279302 // Remove the channels dimension.
0 commit comments