Skip to content

Commit

Permalink
Revert D24335982: explicitly error out in comparison ops when the typ…
Browse files Browse the repository at this point in the history
…es don't match

Test Plan: revert-hammer

Differential Revision:
D24335982 (60fea51)

Original commit changeset: 3dfb02bcb403

fbshipit-source-id: 00072f1b00e228bbbe295053091cf4a7a46f4668
  • Loading branch information
bdhirsh authored and facebook-github-bot committed Nov 2, 2020
1 parent 7f125bc commit b3eb0c8
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 145 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_bfloat16.h
Expand Up @@ -465,7 +465,7 @@ Vec256<BFloat16> inline Vec256<BFloat16>::operator==(const Vec256<BFloat16>& oth
}
Vec256<BFloat16> inline Vec256<BFloat16>::operator!=(const Vec256<BFloat16>& other) const {
return bfloat16_binary_op_as_fp32(*this, other, [](__m256 x, __m256 y) {
return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ);
return _mm256_cmp_ps(x, y, _CMP_NEQ_OQ);
});
}

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_complex_double.h
Expand Up @@ -309,7 +309,7 @@ template <> class Vec256<c10::complex<double>> {
return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
}
Vec256<c10::complex<double>> operator!=(const Vec256<c10::complex<double>>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
}
Vec256<c10::complex<double>> operator<(const Vec256<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_complex_float.h
Expand Up @@ -347,7 +347,7 @@ template <> class Vec256<c10::complex<float>> {
return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
}
Vec256<c10::complex<float>> operator!=(const Vec256<c10::complex<float>>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
}
Vec256<c10::complex<float>> operator<(const Vec256<c10::complex<float>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_double.h
Expand Up @@ -237,7 +237,7 @@ template <> class Vec256<double> {
}

Vec256<double> operator!=(const Vec256<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
}

Vec256<double> operator<(const Vec256<double>& other) const {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec256/vec256_float.h
Expand Up @@ -244,7 +244,7 @@ template <> class Vec256<float> {
}

Vec256<float> operator!=(const Vec256<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
}

Vec256<float> operator<(const Vec256<float>& other) const {
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/native/ReduceOps.cpp
Expand Up @@ -1008,8 +1008,8 @@ Tensor var(const Tensor& self, bool unbiased) {
return trivial_return.value();
}

// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// See https://github.com/pytorch/pytorch/pull/43858.
if (self.device().type() == kCPU) {
return at::_var(self, unbiased);
Expand Down Expand Up @@ -1040,8 +1040,8 @@ Tensor std(const Tensor& self, bool unbiased) {
return trivial_return.value();
}

// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// See https://github.com/pytorch/pytorch/pull/43858.
if (self.device().type() == kCPU) {
return at::_std(self, unbiased);
Expand Down
23 changes: 1 addition & 22 deletions aten/src/ATen/native/TensorIterator.cpp
Expand Up @@ -406,15 +406,13 @@ void TensorIterator::compute_types(const TensorIteratorConfig& config) {
op.tensor.options().dtype(common_dtype_),
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
op.current_dtype = common_dtype_;
op.target_dtype = common_dtype_;
}

// Promotes inputs by creating temporaries of the correct dtype
if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
op.original_tensor = op.tensor;
op.tensor = op.tensor.to(common_dtype_);
op.current_dtype = common_dtype_;
op.target_dtype = common_dtype_;
}
}
}
Expand Down Expand Up @@ -849,33 +847,14 @@ TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a,

TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
const Tensor& b) {
// Note [special-case bool outputs]
// We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
// has `bool` dtype. This is a performance optimization: the functional
// version of all comparison/logical ops uses a bool output tensor, and we'd like to
// avoid creating a temporary copy of the output.
// However, note that all kernels using this TensorIterator will need to special-case when
// the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
if (out.scalar_type() == kBool) {
return TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.build();
} else {
return TensorIteratorConfig()
return TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.build();
}
}

TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
Expand Down
103 changes: 50 additions & 53 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Expand Up @@ -234,16 +234,17 @@ void lshift_kernel(TensorIterator& iter) {
}

void logical_and_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
// We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
// common_dtype() is unavailable for bfloat16.
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_and_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a && b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_and_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(a && b);
Expand All @@ -253,35 +254,37 @@ void logical_and_kernel(TensorIterator& iter) {
}

void logical_or_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
// We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
// common_dtype() is unavailable for bfloat16.
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_or_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a || b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.dtype(), "logical_or_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(a || b);
});
});
});
}
}

void logical_xor_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
// We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
// common_dtype() is unavailable for bfloat16.
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_xor_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return bool(a) != bool(b);
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_xor_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(bool(a) != bool(b));
Expand All @@ -308,22 +311,21 @@ void rshift_kernel(TensorIterator& iter) {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return a >> b;
});
});
});
}
}

void lt_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "lt_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a < b;
});
[](scalar_t a, scalar_t b) -> bool {
return a < b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "lt_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -332,21 +334,20 @@ void lt_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.lt(b);
});
});
});
}
}

void le_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "le_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a <= b;
});
[](scalar_t a, scalar_t b) -> bool {
return a <= b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "le_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -355,21 +356,20 @@ void le_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.le(b);
});
});
});
}
}

void gt_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "gt_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a > b;
});
[=](scalar_t a, scalar_t b) -> bool {
return a > b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "gt_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -378,21 +378,20 @@ void gt_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.gt(b);
});
});
});
}
}

void ge_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ge_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a >= b;
});
[](scalar_t a, scalar_t b) -> bool {
return a >= b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "ge_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -401,21 +400,20 @@ void ge_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.ge(b);
});
});
});
}
}

void eq_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "eq_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a == b;
});
[](scalar_t a, scalar_t b) -> bool {
return a == b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "eq_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -424,21 +422,20 @@ void eq_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.eq(b);
});
});
});
}
}

void ne_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ne_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a != b;
});
[](scalar_t a, scalar_t b) -> bool {
return a != b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "ne_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
Expand All @@ -447,7 +444,7 @@ void ne_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.ne(b);
});
});
});
}
}

Expand Down

0 comments on commit b3eb0c8

Please sign in to comment.