@@ -152,34 +152,39 @@ ir::Value GetIrValueOrDefault(const XLATensor& input, at::Scalar default_value,
                          : input.GetIrValue();
 }
 
+void CheckIsIntegralOrPred(const xla::Shape& shape,
+                           const std::string& op_name) {
+  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(shape) ||
+            shape.element_type() == xla::PrimitiveType::PRED)
+      << "Operator " << op_name
+      << " is only supported for integer or boolean type tensors, got: "
+      << shape;
+}
+
 }  // namespace
 
 XLATensor XLATensor::__and__(const XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__and__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   return input.CreateFrom(
       ir::ops::BitwiseAnd(input.GetIrValue(), other_broadcasted_ir));
 }
 
 XLATensor XLATensor::__and__(const XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__and__");
   return input.CreateFrom(
       ir::ops::BitwiseAnd(input.GetIrValue(), other.GetIrValue()));
 }
 
 void XLATensor::__iand__(XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__iand__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   input.SetIrValue(
       ir::ops::BitwiseAnd(input.GetIrValue(), other_broadcasted_ir));
 }
 
 void XLATensor::__iand__(XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise and is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__iand__");
   input.SetIrValue(ir::ops::BitwiseAnd(input.GetIrValue(), other.GetIrValue()));
 }
 
@@ -192,16 +197,14 @@ void XLATensor::__ilshift__(XLATensor& input, const XLATensor& other) {
 }
 
 void XLATensor::__ior__(XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ior__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   input.SetIrValue(
       ir::ops::BitwiseOr(input.GetIrValue(), other_broadcasted_ir));
 }
 
 void XLATensor::__ior__(XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ior__");
   return input.SetIrValue(
       ir::ops::BitwiseOr(input.GetIrValue(), other.GetIrValue()));
 }
@@ -215,16 +218,14 @@ void XLATensor::__irshift__(XLATensor& input, const XLATensor& other) {
 }
 
 void XLATensor::__ixor__(XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ixor__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   input.SetIrValue(
       ir::ops::BitwiseXor(input.GetIrValue(), other_broadcasted_ir));
 }
 
 void XLATensor::__ixor__(XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__ixor__");
   input.SetIrValue(ir::ops::BitwiseXor(input.GetIrValue(), other.GetIrValue()));
 }
 
@@ -239,15 +240,13 @@ XLATensor XLATensor::__lshift__(const XLATensor& input,
 }
 
 XLATensor XLATensor::__or__(const XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__or__");
   return input.CreateFrom(
       ir::ops::BitwiseOr(input.GetIrValue(), other.GetIrValue()));
 }
 
 XLATensor XLATensor::__or__(const XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise or is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__or__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   return input.CreateFrom(
       ir::ops::BitwiseOr(input.GetIrValue(), other_broadcasted_ir));
@@ -264,16 +263,14 @@ XLATensor XLATensor::__rshift__(const XLATensor& input,
 }
 
 XLATensor XLATensor::__xor__(const XLATensor& input, at::Scalar other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__xor__");
   ir::NodePtr other_broadcasted_ir = ir::ops::ScalarOp(other, input.shape());
   return input.CreateFrom(
       ir::ops::BitwiseXor(input.GetIrValue(), other_broadcasted_ir));
 }
 
 XLATensor XLATensor::__xor__(const XLATensor& input, const XLATensor& other) {
-  XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(input.shape()))
-      << "Bitwise xor is only supported for integer type tensors";
+  CheckIsIntegralOrPred(input.shape(), "__xor__");
   return input.CreateFrom(
       ir::ops::BitwiseXor(input.GetIrValue(), other.GetIrValue()));
 }
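
In effect, the diff replaces the integral-only XLA_CHECK with a shared helper that also accepts PRED element types, so the bitwise __and__/__or__/__xor__ paths now admit boolean tensors. Below is a minimal standalone sketch of the predicate the new helper enforces; the IsIntegralOrPred name, the include path, and the main() driver are illustrative assumptions and not part of the change (the real CheckIsIntegralOrPred lives in this file's anonymous namespace and fails via XLA_CHECK rather than returning a bool).

// Sketch only: restates the accept/reject rule introduced by CheckIsIntegralOrPred.
#include <iostream>

#include "tensorflow/compiler/xla/shape_util.h"  // include path assumed from the TF/XLA tree of this era

bool IsIntegralOrPred(const xla::Shape& shape) {
  // Integer element types pass, and so does PRED (boolean) after this change.
  return xla::ShapeUtil::ElementIsIntegral(shape) ||
         shape.element_type() == xla::PrimitiveType::PRED;
}

int main() {
  xla::Shape s32 = xla::ShapeUtil::MakeShape(xla::S32, {4});
  xla::Shape pred = xla::ShapeUtil::MakeShape(xla::PRED, {4});
  xla::Shape f32 = xla::ShapeUtil::MakeShape(xla::F32, {4});
  std::cout << IsIntegralOrPred(s32) << "\n";   // 1: allowed before and after the change
  std::cout << IsIntegralOrPred(pred) << "\n";  // 1: newly allowed (boolean tensors)
  std::cout << IsIntegralOrPred(f32) << "\n";   // 0: still rejected by the XLA_CHECK
  return 0;
}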