diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h
index d608fb9fac0c..10bbe139b63f 100644
--- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h
+++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -465,7 +465,7 @@ Vec256<BFloat16> inline Vec256<BFloat16>::operator==(const Vec256<BFloat16>& other) const {
   });
 }
 Vec256<BFloat16> inline Vec256<BFloat16>::operator!=(const Vec256<BFloat16>& other) const {
   return bfloat16_binary_op_as_fp32(*this, other, [](__m256 x, __m256 y) {
-    return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ);
+    return _mm256_cmp_ps(x, y, _CMP_NEQ_OQ);
   });
 }
diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h
index 1df674383fb3..d7f5afd8b67d 100644
--- a/aten/src/ATen/cpu/vec256/vec256_complex_double.h
+++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h
@@ -309,7 +309,7 @@ template <> class Vec256<c10::complex<double>> {
     return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
   }
   Vec256<c10::complex<double>> operator!=(const Vec256<c10::complex<double>>& other) const {
-    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
+    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
   }
   Vec256<c10::complex<double>> operator<(const Vec256<c10::complex<double>>& other) const {
     TORCH_CHECK(false, "not supported for complex numbers");
diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h
index dc8ef7cc76d6..4df95dbea926 100644
--- a/aten/src/ATen/cpu/vec256/vec256_complex_float.h
+++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h
@@ -347,7 +347,7 @@ template <> class Vec256<c10::complex<float>> {
     return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
   }
   Vec256<c10::complex<float>> operator!=(const Vec256<c10::complex<float>>& other) const {
-    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
+    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
   }
   Vec256<c10::complex<float>> operator<(const Vec256<c10::complex<float>>& other) const {
     TORCH_CHECK(false, "not supported for complex numbers");
diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h
index c9b359e214fe..6b611e8d2e7a 100644
--- a/aten/src/ATen/cpu/vec256/vec256_double.h
+++ b/aten/src/ATen/cpu/vec256/vec256_double.h
@@ -237,7 +237,7 @@ template <> class Vec256<double> {
   }
 
   Vec256<double> operator!=(const Vec256<double>& other) const {
-    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
+    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
   }
 
   Vec256<double> operator<(const Vec256<double>& other) const {
diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h
index a3c4d7b89845..d83895fdf854 100644
--- a/aten/src/ATen/cpu/vec256/vec256_float.h
+++ b/aten/src/ATen/cpu/vec256/vec256_float.h
@@ -244,7 +244,7 @@ template <> class Vec256<float> {
   }
 
   Vec256<float> operator!=(const Vec256<float>& other) const {
-    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
+    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
   }
 
   Vec256<float> operator<(const Vec256<float>& other) const {
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 7cef3a1e83fe..ff6d702293b9 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1008,8 +1008,8 @@ Tensor var(const Tensor& self, bool unbiased) {
     return trivial_return.value();
   }
 
-  // NOTE: CPU performance significantly regressed when attempting to port to ATen,
-  // so this dispatches differently based on device type.
+  // NOTE: CPU performance significantly regressed when attempting to port to ATen,
+  //   so this dispatches differently based on device type.
   //   See https://github.com/pytorch/pytorch/pull/43858.
   if (self.device().type() == kCPU) {
     return at::_var(self, unbiased);
@@ -1040,8 +1040,8 @@ Tensor std(const Tensor& self, bool unbiased) {
     return trivial_return.value();
   }
 
-  // NOTE: CPU performance significantly regressed when attempting to port to ATen,
-  // so this dispatches differently based on device type.
+  // NOTE: CPU performance significantly regressed when attempting to port to ATen,
+  //   so this dispatches differently based on device type.
   //   See https://github.com/pytorch/pytorch/pull/43858.
   if (self.device().type() == kCPU) {
     return at::_std(self, unbiased);
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index cb37b2055eda..1738e838d436 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -406,7 +406,6 @@ void TensorIterator::compute_types(const TensorIteratorConfig& config) {
                                    op.tensor.options().dtype(common_dtype_),
                                    LEGACY_CONTIGUOUS_MEMORY_FORMAT);
         op.current_dtype = common_dtype_;
-        op.target_dtype = common_dtype_;
       }
 
       // Promotes inputs by creating temporaries of the correct dtype
@@ -414,7 +413,6 @@ void TensorIterator::compute_types(const TensorIteratorConfig& config) {
         op.original_tensor = op.tensor;
         op.tensor = op.tensor.to(common_dtype_);
         op.current_dtype = common_dtype_;
-        op.target_dtype = common_dtype_;
       }
     }
   }
@@ -849,33 +847,14 @@ TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a,
 }
 
 TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a, const Tensor& b) {
-  // Note [special-case bool outputs]
-  // We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
-  // has `bool` dtype. This is a performance optimization: the functional
-  // version of all comparison/logical ops uses a bool output tensor, and we'd like to
-  // avoid creating a temporary copy of the output.
-  // However, note that all kernels using this TensorIterator will need to special-case when
-  // the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
-  if (out.scalar_type() == kBool) {
-    return TensorIteratorConfig()
-    .set_check_mem_overlap(true)
-    .add_output(out)
-    .add_input(a)
-    .add_input(b)
-    .allow_cpu_scalars(true)
-    .promote_inputs_to_common_dtype(true)
-    .build();
-  } else {
-    return TensorIteratorConfig()
+  return TensorIteratorConfig()
     .set_check_mem_overlap(true)
     .add_output(out)
     .add_input(a)
     .add_input(b)
     .allow_cpu_scalars(true)
     .promote_inputs_to_common_dtype(true)
-    .cast_common_dtype_to_outputs(true)
     .build();
-  }
 }
 
 TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index bc1fdeb1ccb5..652f3ee063e1 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -234,16 +234,17 @@ void lshift_kernel(TensorIterator& iter) {
 }
 
 void logical_and_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
+  // We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
+  // common_dtype() is unavailable for bfloat16.
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_and_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> bool {
           return a && b;
         });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_and_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
           return static_cast<scalar_t>(a && b);
@@ -253,35 +254,37 @@ void logical_and_kernel(TensorIterator& iter) {
 }
 
 void logical_or_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
+  // We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
+  // common_dtype() is unavailable for bfloat16.
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_or_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> bool {
           return a || b;
         });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.dtype(), "logical_or_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
           return static_cast<scalar_t>(a || b);
         });
-      });
+    });
   }
 }
 
 void logical_xor_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
+  // We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
+  // common_dtype() is unavailable for bfloat16.
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_xor_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> bool {
           return bool(a) != bool(b);
         });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_xor_cpu", [&]() {
      cpu_kernel(iter,
        [](scalar_t a, scalar_t b) -> scalar_t {
          return static_cast<scalar_t>(bool(a) != bool(b));
@@ -308,22 +311,21 @@ void rshift_kernel(TensorIterator& iter) {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
           return a >> b;
-        });
+      });
     });
   }
 }
 
 void lt_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "lt_cpu", [&]() {
       cpu_kernel(iter,
-        [](scalar_t a, scalar_t b) -> bool {
-          return a < b;
-        });
+       [](scalar_t a, scalar_t b) -> bool {
+         return a < b;
+       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "lt_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -332,21 +334,20 @@ void lt_kernel(TensorIterator& iter) {
         [](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
           return a.lt(b);
         });
-      });
+    });
   }
 }
 
 void le_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "le_cpu", [&]() {
       cpu_kernel(iter,
-        [](scalar_t a, scalar_t b) -> bool {
-          return a <= b;
-        });
+       [](scalar_t a, scalar_t b) -> bool {
+         return a <= b;
+       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "le_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -355,21 +356,20 @@ void le_kernel(TensorIterator& iter) {
         [](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
           return a.le(b);
         });
-      });
+    });
   }
 }
 
 void gt_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "gt_cpu", [&]() {
       cpu_kernel(iter,
-        [](scalar_t a, scalar_t b) -> bool {
-          return a > b;
-        });
+       [=](scalar_t a, scalar_t b) -> bool {
+         return a > b;
+       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "gt_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -378,21 +378,20 @@ void gt_kernel(TensorIterator& iter) {
         [](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
           return a.gt(b);
         });
-      });
+    });
   }
 }
 
 void ge_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ge_cpu", [&]() {
       cpu_kernel(iter,
-        [](scalar_t a, scalar_t b) -> bool {
-          return a >= b;
-        });
+       [](scalar_t a, scalar_t b) -> bool {
+         return a >= b;
+       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "ge_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -401,21 +400,20 @@ void ge_kernel(TensorIterator& iter) {
         [](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
           return a.ge(b);
         });
-      });
+    });
   }
 }
 
 void eq_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "eq_cpu", [&]() {
      cpu_kernel(iter,
-        [](scalar_t a, scalar_t b) -> bool {
-          return a == b;
-        });
+       [](scalar_t a, scalar_t b) -> bool {
+         return a == b;
+       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "eq_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -424,21 +422,20 @@ void eq_kernel(TensorIterator& iter) {
         [](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
           return a.eq(b);
         });
-      });
+    });
   }
 }
 
 void ne_kernel(TensorIterator& iter) {
-  // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ne_cpu", [&]() {
       cpu_kernel(iter,
-        [](scalar_t a, scalar_t b) -> bool {
-          return a != b;
-        });
+       [](scalar_t a, scalar_t b) -> bool {
+         return a != b;
+       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "ne_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -447,7 +444,7 @@ void ne_kernel(TensorIterator& iter) {
         [](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
           return a.ne(b);
         });
-      });
+    });
   }
 }
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 312406fdab50..cc6d57b01d80 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1,7 +1,6 @@
 import sys
 import io
 import inspect
-import itertools
 import math
 import random
 import re
@@ -5923,57 +5922,6 @@ def test_isinf_type(self, device):
         with self.assertRaises(TypeError):
             torch.isinf(1)  # Parameter must be a tensor
 
-    @dtypes(*tuple(itertools.combinations_with_replacement(torch.testing.get_all_dtypes(), 2)))
-    def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes):
-        # issue #42660
-        # testing all combinations of broadcasting and type promotion
-        # with a range of dtypes and input shapes, and with extremal values
-        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None):
-            # working around the fact that numpy doesn't support bfloat16
-            # by letting numpy treat them as float32's
-            x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32)
-            y_np = y.cpu().numpy() if y.dtype != torch.bfloat16 else y.to(torch.float32).cpu().numpy()
-            self.compare_with_numpy(lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y),
-                                    lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np),
-                                    x_np)
-
-        complex_op_denylist = [torch.lt, torch.le, torch.gt, torch.ge]  # complex not supported
-        input_sizes = [
-            (1,),
-            (10,),
-            (10, 1),
-            (1, 10),
-            (4, 10),
-            (64, 10),
-            (12, 3)]
-        op_pairs = [(torch.lt, np.less),
-                    (torch.le, np.less_equal),
-                    (torch.gt, np.greater),
-                    (torch.ge, np.greater_equal),
-                    (torch.eq, np.equal),
-                    (torch.ne, np.not_equal),
-                    (torch.logical_and, np.logical_and),
-                    (torch.logical_or, np.logical_or),
-                    (torch.logical_xor, np.logical_xor)]
-
-        for size1 in input_sizes:
-            size2 = (2,) + size1  # perform broadcasting
-            for with_extremal in [False, True]:
-                a = self._generate_input(size1, dtypes[0], device, with_extremal)
-                b = self._generate_input(size2, dtypes[1], device, with_extremal)
-                for torch_op, numpy_op in op_pairs:
-                    if (dtypes[0].is_complex or dtypes[1].is_complex) and torch_op in complex_op_denylist:
-                        continue
-                    # functional version of op
-                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b)
-
-                    # functional comparison ops always return bool tensors
-                    self.assertEqual(torch_op(a,b).dtype, torch.bool)
-
-                    # out version of op
-                    out = torch.zeros(1, dtype=torch.complex128)  # all casts to complex128 are safe
-                    compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out)
-
     @onlyCPU
     @dtypes(torch.float)
     def test_diag(self, device, dtype):
@@ -18900,12 +18848,7 @@ def _generate_input(self, shape, dtype, device, with_extremal):
             x = torch.tensor((), dtype=dtype, device=device)
         else:
             if dtype.is_floating_point or dtype.is_complex:
-                # work around torch.randn not being implemented for bfloat16
-                if dtype == torch.bfloat16:
-                    x = torch.randn(*shape, device=device) * random.randint(30, 100)
-                    x = x.to(torch.bfloat16)
-                else:
-                    x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
+                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
                 x[torch.randn(*shape) > 0.5] = 0
                 if with_extremal and dtype.is_floating_point:
                     # Use extremal values
@@ -18916,9 +18859,6 @@ def _generate_input(self, shape, dtype, device, with_extremal):
                     x[torch.randn(*shape) > 0.5] = complex('nan')
                    x[torch.randn(*shape) > 0.5] = complex('inf')
                     x[torch.randn(*shape) > 0.5] = complex('-inf')
-            elif dtype == torch.bool:
-                x = torch.zeros(shape, dtype=dtype, device=device)
-                x[torch.randn(*shape) > 0.5] = True
             else:
                 x = torch.randint(15, 100, shape, dtype=dtype, device=device)
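
For context on the _CMP_NEQ_UQ -> _CMP_NEQ_OQ change in the vec256 headers above: the two AVX predicates differ only in how NaN operands are treated. A minimal standalone demo (not part of the patch; assumes an AVX-capable x86 compiler, e.g. gcc/clang with -mavx):

// Compare a NaN vector against itself with both "not equal" predicates.
#include <immintrin.h>
#include <cmath>
#include <cstdio>

int main() {
  __m256 nan_vec = _mm256_set1_ps(NAN);
  // _CMP_NEQ_UQ ("unordered"): a lane is "not equal" if the operands differ OR either is NaN.
  int uq = _mm256_movemask_ps(_mm256_cmp_ps(nan_vec, nan_vec, _CMP_NEQ_UQ));
  // _CMP_NEQ_OQ ("ordered"): a lane is "not equal" only if neither operand is NaN AND they differ.
  int oq = _mm256_movemask_ps(_mm256_cmp_ps(nan_vec, nan_vec, _CMP_NEQ_OQ));
  std::printf("NEQ_UQ mask: 0x%02x, NEQ_OQ mask: 0x%02x\n", uq, oq);  // prints 0xff and 0x00
  return 0;
}

With _CMP_NEQ_UQ a NaN lane compares as "not equal" even to itself (mask all ones); with _CMP_NEQ_OQ any lane containing NaN yields false (mask zero), which is the behavior this patch restores for vectorized operator!=.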
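The BFloat16 specialization above routes through bfloat16_binary_op_as_fp32, i.e. it widens each bfloat16 lane to fp32 and applies the fp32 predicate. A scalar sketch of that widening (illustrative only, not PyTorch code; the helper name is invented):

// A bfloat16 is the top 16 bits of an IEEE-754 binary32 value, so widening is a 16-bit shift.
#include <cstdint>
#include <cstring>
#include <cstdio>

float bf16_bits_to_float(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;  // place the bf16 bits in the high half
  float f;
  std::memcpy(&f, &wide, sizeof(f));
  return f;
}

int main() {
  uint16_t one_bf16 = 0x3f80;  // 1.0f truncated to bfloat16
  uint16_t two_bf16 = 0x4000;  // 2.0f truncated to bfloat16
  // Comparing the widened fp32 values gives the same ordering as comparing the bf16 values.
  std::printf("%d\n", bf16_bits_to_float(one_bf16) != bf16_bits_to_float(two_bf16));  // 1
  return 0;
}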
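On the kernel side, the bool-output branches above dispatch on iter.input_dtype() while the non-bool branches dispatch on iter.dtype(). A rough standalone sketch of why the comparison itself must run in the input dtype when the output is bool (not ATen code; lt_as is a hypothetical helper standing in for the AT_DISPATCH_* instantiation):

#include <cstdio>

// Run a "less than" after converting both operands to scalar_t, mimicking which
// type the kernel lambda would be instantiated for.
template <typename scalar_t>
bool lt_as(float a, float b) {
  return static_cast<scalar_t>(a) < static_cast<scalar_t>(b);
}

int main() {
  // float operands, bool output tensor
  float a = 1.5f, b = 2.5f;
  std::printf("dispatched on input dtype (float): %d\n", lt_as<float>(a, b));  // 1
  std::printf("dispatched on output dtype (bool): %d\n", lt_as<bool>(a, b));   // 0: both operands collapse to true
  return 0;
}

Loosely speaking, instantiating the kernel for the bool output dtype would collapse the operands to 0/1 before comparing, which is the failure mode the input_dtype() dispatch avoids.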