Skip to content

Commit cd7cbca

Browse files
committed
[XLATensor] Add conv2d
1 parent be4409d commit cd7cbca

File tree

8 files changed

+223
-8
lines changed

8 files changed

+223
-8
lines changed

test/test_operations.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,28 @@ def test_relu(self):
11721172
out = torch_xla._XLAC.relu(xt_x).to_tensor()
11731173
self.assertEqualDbg(out.data, expected.data)
11741174

1175+
def test_conv2d(self):
  # Compares torch_xla._XLAC.conv2d against F.conv2d over strides 1..3 and
  # paddings 0..2, each with and without a bias operand.
  in_channels = 3
  out_channels = 7
  kernel_size = 5
  input = _gen_tensor(4, in_channels, 28, 28)
  # torch.Tensor(*sizes) only allocates storage and leaves its contents
  # uninitialized, so the weight and bias could hold arbitrary values
  # (including NaN/Inf) and differ from run to run. Build them with the same
  # helper used for the input so they hold well-defined values.
  weight = _gen_tensor(out_channels, in_channels, kernel_size, kernel_size)
  bias = _gen_tensor(out_channels)
  xt_input = torch_xla._XLAC.XLATensor(input)
  xt_weight = torch_xla._XLAC.XLATensor(weight)
  xt_bias = torch_xla._XLAC.XLATensor(bias)
  for stride in range(1, 4):
    for padding in range(0, 3):
      for with_bias in [True, False]:
        conv_bias = bias if with_bias else None
        conv_xt_bias = xt_bias if with_bias else None
        expected = F.conv2d(
            input, weight, conv_bias, stride=stride, padding=padding)
        # Full convolution precision is requested on the XLA side for the
        # comparison against the reference result.
        out = torch_xla._XLAC.conv2d(
            xt_input,
            xt_weight,
            conv_xt_bias,
            stride=stride,
            padding=padding,
            use_full_conv_precision=True).to_tensor()
        self.assertEqualRel(out.data, expected.data)
1194+
11751195

11761196
if __name__ == '__main__':
11771197
torch.set_default_tensor_type('torch.FloatTensor')
1198+
torch.manual_seed(42)
11781199
run_tests()

torch_xla/csrc/convolution.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,8 @@ xla::XlaOp BuildThnnConv2dBackwardWeight(
241241
}
242242

243243
std::vector<std::pair<xla::int64, xla::int64>> MakePadding(
244-
const torch::jit::Node* node) {
244+
tensorflow::gtl::ArraySlice<xla::int64> padding) {
245245
std::vector<std::pair<xla::int64, xla::int64>> dims_padding;
246-
const auto padding =
247-
node->get<std::vector<int64_t>>(at::attr::padding).value();
248246
for (const auto dim_padding : padding) {
249247
dims_padding.emplace_back(dim_padding, dim_padding);
250248
}
@@ -257,13 +255,25 @@ xla::XlaOp BuildConvolution(
257255
const torch::jit::Node* node, const xla::XlaOp& input,
258256
const xla::XlaOp& kernel,
259257
const xla::PrecisionConfig::Precision conv_precision) {
260-
const auto window_strides = XlaHelpers::I64List(
261-
node->get<std::vector<int64_t>>(at::attr::stride).value());
262-
const auto dims_padding = MakePadding(node);
258+
const auto stride = node->get<std::vector<int64_t>>(at::attr::stride).value();
259+
const auto padding =
260+
node->get<std::vector<int64_t>>(at::attr::padding).value();
261+
xla::PrecisionConfig precision_config =
262+
XlaHelpers::BuildPrecisionConfig(conv_precision);
263+
return BuildConvolution(input, kernel, XlaHelpers::I64List(stride),
264+
XlaHelpers::I64List(padding), conv_precision);
265+
}
266+
267+
// Builds a convolution of `input` with `kernel` using the given window strides.
// Each entry of `padding` is applied on both sides of the corresponding
// dimension. `conv_precision` selects the precision config passed to XLA.
xla::XlaOp BuildConvolution(
    const xla::XlaOp& input, const xla::XlaOp& kernel,
    tensorflow::gtl::ArraySlice<xla::int64> stride,
    tensorflow::gtl::ArraySlice<xla::int64> padding,
    const xla::PrecisionConfig::Precision conv_precision) {
  const std::vector<std::pair<xla::int64, xla::int64>> low_high_padding =
      MakePadding(padding);
  xla::PrecisionConfig conv_config =
      XlaHelpers::BuildPrecisionConfig(conv_precision);
  return xla::ConvWithGeneralPadding(input, kernel, stride, low_high_padding,
                                     /*feature_group_count=*/1,
                                     /*batch_group_count=*/1, &conv_config);
}
}
269279

@@ -273,7 +283,20 @@ xla::XlaOp BuildConvolutionBias(
273283
const xla::PrecisionConfig::Precision conv_precision) {
274284
const auto node_inputs = node->inputs();
275285
XLA_CHECK_GE(node_inputs.size(), size_t(4));
276-
const auto conv = BuildConvolution(node, input, kernel, conv_precision);
286+
const auto stride = node->get<std::vector<int64_t>>(at::attr::stride).value();
287+
const auto padding =
288+
node->get<std::vector<int64_t>>(at::attr::padding).value();
289+
return BuildConvolutionBias(input, kernel, bias, XlaHelpers::I64List(stride),
290+
XlaHelpers::I64List(padding), conv_precision);
291+
}
292+
293+
xla::XlaOp BuildConvolutionBias(
294+
const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias,
295+
tensorflow::gtl::ArraySlice<xla::int64> stride,
296+
tensorflow::gtl::ArraySlice<xla::int64> padding,
297+
const xla::PrecisionConfig::Precision conv_precision) {
298+
const auto conv =
299+
BuildConvolution(input, kernel, stride, padding, conv_precision);
277300
auto broadcast_sizes = XlaHelpers::SizesOfXlaOp(conv);
278301
XLA_CHECK_EQ(broadcast_sizes.size(), 4);
279302
// Remove the channels dimension.

torch_xla/csrc/convolution.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "tensorflow/compiler/xla/client/xla_builder.h"
4+
#include "tensorflow/core/lib/gtl/array_slice.h"
45
#include "torch/csrc/jit/ir.h"
56

67
namespace torch_xla {
@@ -12,12 +13,26 @@ xla::XlaOp BuildConvolution(
1213
const xla::XlaOp& kernel,
1314
const xla::PrecisionConfig::Precision conv_precision);
1415

16+
// Same as above, with stride and padding provided as parameters.
17+
xla::XlaOp BuildConvolution(
18+
const xla::XlaOp& input, const xla::XlaOp& kernel,
19+
tensorflow::gtl::ArraySlice<xla::int64> stride,
20+
tensorflow::gtl::ArraySlice<xla::int64> padding,
21+
const xla::PrecisionConfig::Precision conv_precision);
22+
1523
// Same as above, then broadcasts the bias and adds it to the result.
1624
xla::XlaOp BuildConvolutionBias(
1725
const torch::jit::Node* node, const xla::XlaOp& input,
1826
const xla::XlaOp& kernel, const xla::XlaOp& bias,
1927
const xla::PrecisionConfig::Precision conv_precision);
2028

29+
// Same as above, with stride and padding provided as parameters.
30+
xla::XlaOp BuildConvolutionBias(
31+
const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias,
32+
tensorflow::gtl::ArraySlice<xla::int64> stride,
33+
tensorflow::gtl::ArraySlice<xla::int64> padding,
34+
const xla::PrecisionConfig::Precision conv_precision);
35+
2136
struct Conv2DGrads {
2237
xla::XlaOp grad_input;
2338
xla::XlaOp grad_weight;

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,17 @@ void InitXlaTensorBindings(py::module m) {
226226
return s.str();
227227
});
228228
m.def("relu", [](std::shared_ptr<XLATensor> self) { return self->relu(); });
229+
m.def(
230+
"conv2d",
231+
[](std::shared_ptr<XLATensor> self, std::shared_ptr<XLATensor> weight,
232+
std::shared_ptr<XLATensor> bias, int stride, int padding,
233+
bool use_full_conv_precision) {
234+
return self->conv2d(weight, bias, stride, padding,
235+
use_full_conv_precision);
236+
},
237+
py::arg("input"), py::arg("weight"), py::arg("bias") = nullptr,
238+
py::arg("stride") = 1, py::arg("padding") = 0,
239+
py::arg("use_full_conv_precision") = false);
229240
}
230241

231242
} // namespace

torch_xla/csrc/ops/conv2d.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include "ops/conv2d.h"
2+
#include "convolution.h"
3+
#include "lowering_context.h"
4+
#include "ops/infer_output_shape.h"
5+
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
6+
7+
namespace torch_xla {
8+
namespace ir {
9+
namespace ops {
10+
11+
namespace {
12+
13+
// The bias doesn't matter for shape inference.
14+
xla::Shape NodeOutputShape(const NodeOperand& input, const NodeOperand& weight,
15+
int stride, int padding) {
16+
std::vector<xla::int64> stride_2d(2, stride);
17+
std::vector<xla::int64> padding_2d(2, padding);
18+
auto lower_for_shape_fn =
19+
[stride_2d,
20+
padding_2d](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
21+
-> xla::XlaOp {
22+
XLA_CHECK(operands.size() == 2 || operands.size() == 3)
23+
<< "Unexpected number of operands: " << operands.size();
24+
// The precision doesn't matter for shape inference.
25+
return BuildConvolution(operands[0], operands[1], absl::MakeSpan(stride_2d),
26+
absl::MakeSpan(padding_2d),
27+
xla::PrecisionConfig::DEFAULT);
28+
};
29+
return InferOutputShape({input.node->shape(), weight.node->shape()},
30+
lower_for_shape_fn);
31+
}
32+
33+
// Maps the boolean precision flag to the XLA precision enum: HIGHEST when full
// convolution precision is requested, DEFAULT otherwise.
xla::PrecisionConfig::Precision MakePrecisionConfig(
    bool use_full_conv_precision) {
  if (use_full_conv_precision) {
    return xla::PrecisionConfig::HIGHEST;
  }
  return xla::PrecisionConfig::DEFAULT;
}
38+
39+
} // namespace
40+
41+
Conv2d::Conv2d(const NodeOperand& input, const NodeOperand& weight,
42+
const NodeOperand& bias, int stride, int padding,
43+
bool use_full_conv_precision)
44+
: Node(ir::OpKind(at::aten::convolution), {input, weight, bias},
45+
NodeOutputShape(input, weight, stride, padding)),
46+
stride_(stride),
47+
padding_(padding),
48+
precision_(MakePrecisionConfig(use_full_conv_precision)) {}
49+
50+
Conv2d::Conv2d(const NodeOperand& input, const NodeOperand& weight, int stride,
51+
int padding, bool use_full_conv_precision)
52+
: Node(ir::OpKind(at::aten::convolution), {input, weight},
53+
NodeOutputShape(input, weight, stride, padding)),
54+
stride_(stride),
55+
padding_(padding),
56+
precision_(MakePrecisionConfig(use_full_conv_precision)) {}
57+
58+
// Lowers the node to XLA: a plain convolution when the node has two operands,
// or a convolution followed by a bias add when it has three.
XlaOpVector Conv2d::Lower(LoweringContext* loctx) const {
  // The same stride and padding are used for both spatial dimensions.
  std::vector<xla::int64> strides(2, stride_);
  std::vector<xla::int64> paddings(2, padding_);
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp kernel = loctx->GetOutputOp(operand(1));
  if (operands().size() != 3) {
    // Anything other than three operands must be the bias-free form.
    XLA_CHECK_EQ(operands().size(), 2);
    xla::XlaOp conv =
        BuildConvolution(input, kernel, absl::MakeSpan(strides),
                         absl::MakeSpan(paddings), precision_);
    return ReturnOp(conv, loctx);
  }
  xla::XlaOp bias = loctx->GetOutputOp(operand(2));
  xla::XlaOp conv_with_bias =
      BuildConvolutionBias(input, kernel, bias, absl::MakeSpan(strides),
                           absl::MakeSpan(paddings), precision_);
  return ReturnOp(conv_with_bias, loctx);
}
76+
77+
std::string Conv2d::ToString() const {
78+
std::stringstream ss;
79+
ss << Node::ToString() << ", stride=" << stride_ << ", padding=" << padding_
80+
<< ", precision=" << xla::PrecisionConfig::Precision_Name(precision_);
81+
return ss.str();
82+
}
83+
84+
} // namespace ops
85+
} // namespace ir
86+
} // namespace torch_xla

torch_xla/csrc/ops/conv2d.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include "ir.h"
4+
#include "tensorflow/compiler/xla/xla_data.pb.h"
5+
6+
namespace torch_xla {
7+
namespace ir {
8+
namespace ops {
9+
10+
// IR node for 2D convolutions with or without bias.
11+
class Conv2d : public Node {
12+
public:
13+
Conv2d(const NodeOperand& input, const NodeOperand& weight,
14+
const NodeOperand& bias, int stride, int padding,
15+
bool use_full_conv_precision);
16+
17+
Conv2d(const NodeOperand& input, const NodeOperand& weight, int stride,
18+
int padding, bool use_full_conv_precision);
19+
20+
XlaOpVector Lower(LoweringContext* loctx) const override;
21+
22+
std::string ToString() const override;
23+
24+
private:
25+
// The parameters of the convolution. Only support the same stride and padding
26+
// in both dimension for now.
27+
int stride_;
28+
int padding_;
29+
// The numeric precision to use on TPU.
30+
xla::PrecisionConfig::Precision precision_;
31+
};
32+
33+
} // namespace ops
34+
} // namespace ir
35+
} // namespace torch_xla

torch_xla/csrc/tensor.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "helpers.h"
1414
#include "lowering_context.h"
1515
#include "ops/arithmetic_ir_ops.h"
16+
#include "ops/conv2d.h"
1617
#include "ops/cross_replica_sum.h"
1718
#include "ops/device_data.h"
1819
#include "ops/generic.h"
@@ -621,6 +622,24 @@ std::shared_ptr<XLATensor> XLATensor::relu() {
621622
GetDevice());
622623
}
623624

625+
std::shared_ptr<XLATensor> XLATensor::conv2d(
626+
const std::shared_ptr<XLATensor>& weight,
627+
const std::shared_ptr<XLATensor>& bias, int stride, int padding,
628+
bool use_full_conv_precision) {
629+
std::shared_ptr<ir::ops::Conv2d> ir_node;
630+
if (bias) {
631+
ir_node = std::make_shared<ir::ops::Conv2d>(
632+
ir::NodeOperand(GetIrNode()), ir::NodeOperand(weight->GetIrNode()),
633+
ir::NodeOperand(bias->GetIrNode()), stride, padding,
634+
use_full_conv_precision);
635+
} else {
636+
ir_node = std::make_shared<ir::ops::Conv2d>(
637+
ir::NodeOperand(GetIrNode()), ir::NodeOperand(weight->GetIrNode()),
638+
stride, padding, use_full_conv_precision);
639+
}
640+
return Create(ir_node, GetDevice());
641+
}
642+
624643
std::shared_ptr<XLATensor> XLATensor::cross_replica_sum(
625644
const std::vector<std::vector<xla::int64>>& groups) {
626645
ir::NodePtr crs =

torch_xla/csrc/tensor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ class XLATensor {
150150
// Additional operations which are part of the PyTorch Tensor functionality.
151151
std::shared_ptr<XLATensor> relu();
152152

153+
std::shared_ptr<XLATensor> conv2d(const std::shared_ptr<XLATensor>& weight,
154+
const std::shared_ptr<XLATensor>& bias,
155+
int stride, int padding,
156+
bool use_full_conv_precision);
157+
153158
std::shared_ptr<XLATensor> cross_replica_sum(
154159
const std::vector<std::vector<xla::int64>>& groups);
155160

0 commit comments

Comments
 (0)