Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/torch_test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
'test_std_mean_some_dims',
'test_zeros_like',
'test_histc',
'test_bool_sub',
'test_bool_tensor_comparison_ops',
'test_bool_tensor_value_change',
'test_addcmul',
Expand Down
16 changes: 16 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ bool IsOperationOnType(const c10::optional<at::ScalarType>& opt_dtype,
return tensor_type == type;
}

void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) {
XLA_CHECK(type1 != at::kBool || type2 != at::kBool)
<< "Subtraction, the `-` operator, with two bool tensors is not "
"supported. Use the `^` or `logical_xor()` operator instead.";
XLA_CHECK(type1 != at::kBool && type2 != at::kBool)
<< "Subtraction, the `-` operator, with a bool tensor is not "
"supported. If you are trying to invert a mask, use the `~` or "
"`logical_not()` operator instead.";
}

void AtenInitialize() {
RegisterAtenTypeFunctions();
XLATensorImpl::AtenInitialize();
Expand Down Expand Up @@ -2388,12 +2398,14 @@ at::Tensor& AtenXlaType::rsqrt_(at::Tensor& self) {

at::Tensor AtenXlaType::rsub(const at::Tensor& self, const at::Tensor& other,
at::Scalar alpha) {
CheckSubOperandTypes(self.scalar_type(), other.scalar_type());
return bridge::AtenFromXlaTensor(XLATensor::rsub(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other), alpha));
}

at::Tensor AtenXlaType::rsub(const at::Tensor& self, at::Scalar other,
at::Scalar alpha) {
CheckSubOperandTypes(self.scalar_type(), GetScalarType(other));
return bridge::AtenFromXlaTensor(
XLATensor::rsub(bridge::GetXlaTensor(self), other, alpha));
}
Expand Down Expand Up @@ -2627,6 +2639,7 @@ at::Tensor AtenXlaType::stack(at::TensorList tensors, int64_t dim) {

at::Tensor AtenXlaType::sub(const at::Tensor& self, const at::Tensor& other,
at::Scalar alpha) {
CheckSubOperandTypes(self.scalar_type(), other.scalar_type());
XLATensor self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(XLATensor::sub(
self_tensor, bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice()),
Expand All @@ -2635,12 +2648,14 @@ at::Tensor AtenXlaType::sub(const at::Tensor& self, const at::Tensor& other,

at::Tensor AtenXlaType::sub(const at::Tensor& self, at::Scalar other,
at::Scalar alpha) {
CheckSubOperandTypes(self.scalar_type(), GetScalarType(other));
return bridge::AtenFromXlaTensor(
XLATensor::sub(bridge::GetXlaTensor(self), other, alpha));
}

at::Tensor& AtenXlaType::sub_(at::Tensor& self, const at::Tensor& other,
at::Scalar alpha) {
CheckSubOperandTypes(self.scalar_type(), other.scalar_type());
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::sub_(self_tensor,
bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice()),
Expand All @@ -2650,6 +2665,7 @@ at::Tensor& AtenXlaType::sub_(at::Tensor& self, const at::Tensor& other,

at::Tensor& AtenXlaType::sub_(at::Tensor& self, at::Scalar other,
at::Scalar alpha) {
CheckSubOperandTypes(self.scalar_type(), GetScalarType(other));
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::sub_(self_tensor, other, alpha);
return self;
Expand Down
13 changes: 13 additions & 0 deletions torch_xla/csrc/torch_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,17 @@ at::Tensor CopyTensor(const at::Tensor& ref, at::ScalarType dest_type) {
return ref.to(ref.options().dtype(dest_type), /*non_blocking=*/false,
/*copy=*/true);
}

at::ScalarType GetScalarType(at::Scalar scalar) {
if (scalar.isFloatingPoint()) {
return at::kDouble;
} else if (scalar.isBoolean()) {
return at::kBool;
} else if (scalar.isComplex()) {
return at::kComplexDouble;
} else {
XLA_CHECK(scalar.isIntegral(/*includeBool=*/false));
return at::kLong;
}
}
} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/torch_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ at::Tensor CopyTensor(const at::Tensor& ref);
// Same as above, with an additional cast.
at::Tensor CopyTensor(const at::Tensor& ref, at::ScalarType dest_type);

// Return at::ScalarType from at::Scalar
at::ScalarType GetScalarType(at::Scalar scalar);

template <typename T, typename S>
T OptionalOr(const c10::optional<S>& value, T defval) {
return value ? static_cast<T>(*value) : defval;
Expand Down