
Commit efd2064

Promote sides of bitwise logical operations
1 parent 002eb2d commit efd2064
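
This change routes the lowerings of the bitwise __and__, __or__, and __xor__ ops through XlaHelpers::PromotedBinaryOp instead of applying the XLA operator directly to the raw operand ops, so that the two sides are promoted to a common element type before the operation; a regression test covers the __and__ case.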

Showing 2 changed files with 25 additions and 3 deletions.


test/cpp/test_aten_xla_tensor.cpp

Lines changed: 13 additions & 0 deletions
@@ -5514,6 +5514,19 @@ TEST_F(AtenXlaTensorTest, TestBitwiseAndScalarInPlace) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestBitwiseAndPromotion) {
+  at::Tensor input = at::rand({4, 2}, at::TensorOptions(at::kFloat));
+  at::Tensor view = input.reshape(-1);
+  at::Tensor result = at::__and__(view.gt(0), view.ne(0));
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(input, device), false);
+    at::Tensor xla_view = xla_input.reshape(-1);
+    at::Tensor xla_result = at::__and__(xla_view.gt(0), xla_view.ne(0));
+    EXPECT_TRUE(EqualValues(result, xla_result));
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestBitwiseOr) {
   at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
                                at::TensorOptions(at::kInt));
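
The test builds both operands of __and__ from comparisons on a reshaped view of the same tensor. Presumably the two comparison results can reach the lowering with different XLA element types (e.g. a pred versus a u8), so this exercises the promotion path added in bitwise_ir_ops.cpp below.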

torch_xla/csrc/ops/bitwise_ir_ops.cpp

Lines changed: 12 additions & 3 deletions
@@ -12,7 +12,10 @@ Value BitwiseAnd(const Value& node1, const Value& node2) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
     xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));
-    return node.ReturnOp(op0 & op1, loctx);
+    auto kernel = [](const xla::XlaOp& op0, const xla::XlaOp& op1) {
+      return op0 & op1;
+    };
+    return node.ReturnOp(XlaHelpers::PromotedBinaryOp(op0, op1, kernel), loctx);
   };
   return GenericOp(OpKind(at::aten::__and__), OpList{node1, node2},
                    node1.shape(), std::move(lower_fn));
@@ -22,7 +25,10 @@ Value BitwiseOr(const Value& node1, const Value& node2) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
     xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));
-    return node.ReturnOp(op0 | op1, loctx);
+    auto kernel = [](const xla::XlaOp& op0, const xla::XlaOp& op1) {
+      return op0 | op1;
+    };
+    return node.ReturnOp(XlaHelpers::PromotedBinaryOp(op0, op1, kernel), loctx);
   };
   return GenericOp(OpKind(at::aten::__or__), OpList{node1, node2},
                    node1.shape(), std::move(lower_fn));
@@ -32,7 +38,10 @@ Value BitwiseXor(const Value& node1, const Value& node2) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
     xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));
-    return node.ReturnOp(op0 ^ op1, loctx);
+    auto kernel = [](const xla::XlaOp& op0, const xla::XlaOp& op1) {
+      return op0 ^ op1;
+    };
+    return node.ReturnOp(XlaHelpers::PromotedBinaryOp(op0, op1, kernel), loctx);
   };
   return GenericOp(OpKind(at::aten::__xor__), OpList{node1, node2},
                    node1.shape(), std::move(lower_fn));
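
For context, here is a minimal sketch of what a promotion wrapper such as XlaHelpers::PromotedBinaryOp plausibly does; the sketch name, the include paths, and the promotion rule are assumptions for illustration, not the actual torch_xla implementation:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical sketch (names and promotion rule assumed): convert both
// operands to a common element type, then apply the bitwise kernel.
template <typename F>
xla::XlaOp PromotedBinaryOpSketch(xla::XlaOp op0, xla::XlaOp op1,
                                  const F& bin_op) {
  xla::XlaBuilder* builder = op0.builder();
  xla::Shape shape0 = builder->GetShape(op0).ValueOrDie();
  xla::Shape shape1 = builder->GetShape(op1).ValueOrDie();
  if (shape0.element_type() != shape1.element_type()) {
    // Assumed rule: widen both sides to the higher-precision type, e.g.
    // pred -> u8 when one side is a boolean comparison result.
    xla::PrimitiveType promoted =
        xla::ShapeUtil::HigherPrecisionElementType(shape0, shape1);
    op0 = xla::ConvertElementType(op0, promoted);
    op1 = xla::ConvertElementType(op1, promoted);
  }
  return bin_op(op0, op1);
}

With such a wrapper in place, the kernel lambdas above (op0 & op1, op0 | op1, op0 ^ op1) always see operands of matching element type, instead of tripping XLA's binary-op type check.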
