@@ -12,7 +12,10 @@ Value BitwiseAnd(const Value& node1, const Value& node2) {
1212 auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
1313 xla::XlaOp op0 = loctx->GetOutputOp (node.operand (0 ));
1414 xla::XlaOp op1 = loctx->GetOutputOp (node.operand (1 ));
15- return node.ReturnOp (op0 & op1, loctx);
15+ auto kernel = [](const xla::XlaOp& op0, const xla::XlaOp& op1) {
16+ return op0 & op1;
17+ };
18+ return node.ReturnOp (XlaHelpers::PromotedBinaryOp (op0, op1, kernel), loctx);
1619 };
1720 return GenericOp (OpKind (at::aten::__and__), OpList{node1, node2},
1821 node1.shape (), std::move (lower_fn));
@@ -22,7 +25,10 @@ Value BitwiseOr(const Value& node1, const Value& node2) {
2225 auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
2326 xla::XlaOp op0 = loctx->GetOutputOp (node.operand (0 ));
2427 xla::XlaOp op1 = loctx->GetOutputOp (node.operand (1 ));
25- return node.ReturnOp (op0 | op1, loctx);
28+ auto kernel = [](const xla::XlaOp& op0, const xla::XlaOp& op1) {
29+ return op0 | op1;
30+ };
31+ return node.ReturnOp (XlaHelpers::PromotedBinaryOp (op0, op1, kernel), loctx);
2632 };
2733 return GenericOp (OpKind (at::aten::__or__), OpList{node1, node2},
2834 node1.shape (), std::move (lower_fn));
@@ -32,7 +38,10 @@ Value BitwiseXor(const Value& node1, const Value& node2) {
3238 auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
3339 xla::XlaOp op0 = loctx->GetOutputOp (node.operand (0 ));
3440 xla::XlaOp op1 = loctx->GetOutputOp (node.operand (1 ));
35- return node.ReturnOp (op0 ^ op1, loctx);
41+ auto kernel = [](const xla::XlaOp& op0, const xla::XlaOp& op1) {
42+ return op0 ^ op1;
43+ };
44+ return node.ReturnOp (XlaHelpers::PromotedBinaryOp (op0, op1, kernel), loctx);
3645 };
3746 return GenericOp (OpKind (at::aten::__xor__), OpList{node1, node2},
3847 node1.shape (), std::move (lower_fn));
0 commit comments