
Commit 9c1f79f

Route bitwise binary with shape checks to our implementations
1 parent 7d817be

3 files changed: +109 -35 lines


scripts/gen.py

Lines changed: 39 additions & 11 deletions
@@ -30,7 +30,9 @@ class ArgTemplate(string.Template):
 )
 
 FuncOpts = namedtuple_with_defaults(
-    'FuncOpts', 'ref_param, device_param, wparams, outfn_template, outfn_name')
+    'FuncOpts',
+    'ref_param, device_param, wparams, outfn_template, outfn_name, shape_check_indices'
+)
 
 _GRAMMAR = r"""
   start: type fnname "(" params ")"
@@ -118,14 +120,28 @@ class ArgTemplate(string.Template):
 _FN_OUT_REGEX = []
 
 _FN_REMAP = {
-    '_th_eq(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='eq'),
-    '_th_eq(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='eq'),
-    '_th_ge(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='ge'),
-    '_th_ge(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='ge'),
-    '_th_gt(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='gt'),
-    '_th_gt(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='gt'),
-    '_th_lt(Tensor, Scalar) -> Tensor': FuncOpts(outfn_name='lt'),
-    '_th_lt(Tensor, Tensor) -> Tensor': FuncOpts(outfn_name='lt'),
+    '_th_eq(Tensor, Scalar) -> Tensor':
+        FuncOpts(outfn_name='eq'),
+    '_th_eq(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='eq'),
+    '_th_ge(Tensor, Scalar) -> Tensor':
+        FuncOpts(outfn_name='ge'),
+    '_th_ge(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='ge'),
+    '_th_gt(Tensor, Scalar) -> Tensor':
+        FuncOpts(outfn_name='gt'),
+    '_th_gt(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='gt'),
+    '_th_lt(Tensor, Scalar) -> Tensor':
+        FuncOpts(outfn_name='lt'),
+    '_th_lt(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='lt'),
+    's__th_and(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='__and__', shape_check_indices=((0, 1),)),
+    's__th_or(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='__or__', shape_check_indices=((0, 1),)),
+    's__th_xor(Tensor, Tensor) -> Tensor':
+        FuncOpts(outfn_name='__xor__', shape_check_indices=((0, 1),)),
 }
 
 _TYPE_NSMAP = {
@@ -809,6 +825,15 @@ def create_call(fname, param_vars):
   return '{}({})'.format(fname, ', '.join(param_vars))
 
 
+def generate_shape_checks(param_vars, shape_check_indices, fname):
+  code = ''
+  for i, j in shape_check_indices:
+    code += ('  XLA_CHECK({}.sizes() == {}.sizes()) << "Operand shapes must be '
+             'identical for {}, mismatch for arguments {} and {}";\n').format(
+                 param_vars[i], param_vars[j], fname, i + 1, j + 1)
+  return code
+
+
 def generate_aten_remap(ctx, fname, sig, params, fnopts):
   code = '{} {}{{\n'.format(sig, 'const ' if ctx.gen_class_mode else '')
@@ -819,6 +844,8 @@ def generate_aten_remap(ctx, fname, sig, params, fnopts):
   assert fnopts.outfn_name
   fcall = create_call(fnopts.outfn_name, param_vars)
 
+  if fnopts.shape_check_indices is not None:
+    code += generate_shape_checks(param_vars, fnopts.shape_check_indices, fname)
   code += '  return {};\n'.format(fcall)
   code += '}'
   return code
@@ -907,8 +934,9 @@ def generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
   if result_assign:
     code += ('  static_cast<void>({}); // Avoid warnings in case not '
             'used\n'.format(_RESULT_NAME))
-  code += generate_exit_debug_code(
-      tree, fname, _RESULT_NAME if result_assign else None, params, param_vars)
+  code += generate_exit_debug_code(tree, fname,
+                                   _RESULT_NAME if result_assign else None,
+                                   params, param_vars)
   code += generate_return_stmt(tree, get_return_type_str(rwxtree, rwsig), fname,
                                _RESULT_NAME if result_assign else None, params,
                                param_vars, ref_param, fnopts)
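
The new helper is plain string templating, so its output can be previewed outside the
code generator. A minimal sketch of what it emits (the function body is copied from the
diff above; the parameter names 'self' and 'other' are hypothetical stand-ins for the
generator's real param_vars):

def generate_shape_checks(param_vars, shape_check_indices, fname):
  code = ''
  for i, j in shape_check_indices:
    code += ('  XLA_CHECK({}.sizes() == {}.sizes()) << "Operand shapes must be '
             'identical for {}, mismatch for arguments {} and {}";\n').format(
                 param_vars[i], param_vars[j], fname, i + 1, j + 1)
  return code

# For the 's__th_and' remap entry, shape_check_indices=((0, 1),) compares the
# first two operands, so the generated wrapper gains exactly one check:
print(generate_shape_checks(['self', 'other'], ((0, 1),), 's__th_and'))
# prints:
#   XLA_CHECK(self.sizes() == other.sizes()) << "Operand shapes must be identical for s__th_and, mismatch for arguments 1 and 2";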

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 49 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include <iostream>
 
 #include <ATen/ATen.h>
+#include <ATen/LegacyTHFunctions.h>
 #include <ATen/NativeFunctions.h>
 
 #include <torch/csrc/autograd/function.h>
@@ -5085,6 +5086,54 @@ TEST_F(AtenXlaTensorTest, TestBitwiseXorScalarInPlace) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestBitwiseAndAutograd) {
+  at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
+                               at::TensorOptions(at::kInt));
+  at::Tensor rhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
+                               at::TensorOptions(at::kInt));
+  at::Tensor result = at::legacy::th::__and__(lhs, rhs);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_lhs = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(lhs, device), false);
+    at::Tensor xla_rhs = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(rhs, device), false);
+    at::Tensor xla_result = at::legacy::th::__and__(xla_lhs, xla_rhs);
+    AllClose(result, xla_result);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestBitwiseOrAutograd) {
+  at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
+                               at::TensorOptions(at::kInt));
+  at::Tensor rhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
+                               at::TensorOptions(at::kInt));
+  at::Tensor result = at::legacy::th::__or__(lhs, rhs);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_lhs = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(lhs, device), false);
+    at::Tensor xla_rhs = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(rhs, device), false);
+    at::Tensor xla_result = at::legacy::th::__or__(xla_lhs, xla_rhs);
+    AllClose(result, xla_result);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestBitwiseXorAutograd) {
+  at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
+                               at::TensorOptions(at::kInt));
+  at::Tensor rhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
+                               at::TensorOptions(at::kInt));
+  at::Tensor result = at::legacy::th::__xor__(lhs, rhs);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_lhs = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(lhs, device), false);
+    at::Tensor xla_rhs = torch::autograd::make_variable(
+        bridge::CreateXlaTensor(rhs, device), false);
+    at::Tensor xla_result = at::legacy::th::__xor__(xla_lhs, xla_rhs);
+    AllClose(result, xla_result);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestLshift) {
   at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
   at::Tensor shift_amount = at::randint(16, input.sizes());
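
Each new test follows the same pattern: compute a reference result with the legacy TH
function on CPU tensors, replay the call on every XLA device, and compare. A rough
Python analogue of that pattern (a sketch in plain torch, with "cpu" standing in for
the XLA device, since the device string depends on the torch_xla build):

import torch

lhs = torch.randint(0, torch.iinfo(torch.int32).max, (4, 2), dtype=torch.int32)
rhs = torch.randint(0, torch.iinfo(torch.int32).max, (4, 2), dtype=torch.int32)
expected = lhs & rhs  # eager reference, the same bitwise and the C++ test exercises

device = "cpu"  # stand-in; an XLA build would target an XLA device here
result = (lhs.to(device) & rhs.to(device)).cpu()
assert torch.equal(result, expected)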

torch_xla/csrc/tensor_methods.cpp

Lines changed: 21 additions & 24 deletions
@@ -152,34 +152,39 @@ ir::Value GetIrValueOrDefault(const XLATensor& input, at::Scalar default_value,
       : input.GetIrValue();
 }
 
+void CheckIsIntegralOrPred(const xla::Shape& shape,
+                           const std::string& op_name) {
+  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(shape) ||
+            shape.element_type() == xla::PrimitiveType::PRED)
+      << "Operator " << op_name
+      << " is only supported for integer or boolean type tensors, got: "
+      << shape;
+}
+
 }  // namespace
 
 XLATensor XLATensor::__and__(const XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__and__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   return input.CreateFrom(
       ir::ops::BitwiseAnd(input.GetIrValue(), other_broadcasted_ir));
 }
 
 XLATensor XLATensor::__and__(const XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__and__");
   return input.CreateFrom(
       ir::ops::BitwiseAnd(input.GetIrValue(), other.GetIrValue()));
 }
 
 void XLATensor::__iand__(XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__iand__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   input.SetIrValue(
       ir::ops::BitwiseAnd(input.GetIrValue(), other_broadcasted_ir));
 }
 
 void XLATensor::__iand__(XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__iand__");
   input.SetIrValue(ir::ops::BitwiseAnd(input.GetIrValue(), other.GetIrValue()));
 }
 
@@ -192,16 +197,14 @@ void XLATensor::__ilshift__(XLATensor& input, const XLATensor& other) {
 }
 
 void XLATensor::__ior__(XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ior__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   input.SetIrValue(
       ir::ops::BitwiseOr(input.GetIrValue(), other_broadcasted_ir));
 }
 
 void XLATensor::__ior__(XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ior__");
   return input.SetIrValue(
       ir::ops::BitwiseOr(input.GetIrValue(), other.GetIrValue()));
 }
@@ -215,16 +218,14 @@ void XLATensor::__irshift__(XLATensor& input, const XLATensor& other) {
 }
 
 void XLATensor::__ixor__(XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ixor__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   input.SetIrValue(
       ir::ops::BitwiseXor(input.GetIrValue(), other_broadcasted_ir));
 }
 
 void XLATensor::__ixor__(XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ixor__");
   input.SetIrValue(ir::ops::BitwiseXor(input.GetIrValue(), other.GetIrValue()));
 }
 
@@ -239,15 +240,13 @@ XLATensor XLATensor::__lshift__(const XLATensor& input,
 }
 
 XLATensor XLATensor::__or__(const XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__or__");
   return input.CreateFrom(
       ir::ops::BitwiseOr(input.GetIrValue(), other.GetIrValue()));
 }
 
 XLATensor XLATensor::__or__(const XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__or__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   return input.CreateFrom(
       ir::ops::BitwiseOr(input.GetIrValue(), other_broadcasted_ir));
@@ -264,16 +263,14 @@ XLATensor XLATensor::__rshift__(const XLATensor& input,
 }
 
 XLATensor XLATensor::__xor__(const XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__xor__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   return input.CreateFrom(
       ir::ops::BitwiseXor(input.GetIrValue(), other_broadcasted_ir));
 }
 
 XLATensor XLATensor::__xor__(const XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__xor__");
   return input.CreateFrom(
       ir::ops::BitwiseXor(input.GetIrValue(), other.GetIrValue()));
 }
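
The behavioral change in this file is the relaxed dtype check: CheckIsIntegralOrPred
accepts PRED (boolean) element types in addition to integral ones, so the bitwise ops
now work on boolean tensors while still rejecting floating point. An eager-mode
analogue in plain torch (the dtype check there lives in the per-dtype kernels rather
than in CheckIsIntegralOrPred, so the error text differs, but the accepted types
match):

import torch

# Boolean (PRED) operands are valid bitwise inputs.
a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
print(a & b)  # tensor([ True, False, False])
print(a | b)  # tensor([ True,  True,  True])
print(a ^ b)  # tensor([False,  True,  True])

# Floating-point operands are still rejected.
try:
    torch.randn(3) & torch.randn(3)
except RuntimeError as err:
    print("rejected:", err)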
