Skip to content

Commit

Permalink
Use unordered NEQ comparison for vec512 operator!= implementations (#…
Browse files Browse the repository at this point in the history
…97466)

This is consistent with the vec256 operator!= implementations. _CMP_NEQ_UQ is the logical opposite of _CMP_EQ_OQ comparison used in the operator== implementations.

Using the ordered NEQ operation results in nan != nan being false which is incorrect.

Pull Request resolved: #97466
Approved by: https://github.com/jgong5, https://github.com/sanchitintel
  • Loading branch information
bjhargrave authored and pytorchmergebot committed Mar 25, 2023
1 parent c757647 commit ee934fd
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BF
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) {
auto zero_vec = _mm512_set1_epi32(0);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_OQ);
auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
});
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ template <> class Vectorized<c10::complex<double>> {
0xFFFFFFFFFFFFFFFF));
}
Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
0xFFFFFFFFFFFFFFFF));
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ template <> class Vectorized<c10::complex<float>> {
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF));
}
Vectorized<c10::complex<float>> operator!=(const Vectorized<c10::complex<float>>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ);
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF));
}
Vectorized<c10::complex<float>> operator<(const Vectorized<c10::complex<float>>& other) const {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_double.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ template <> class Vectorized<double> {
}

Vectorized<double> operator!=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ template <> class Vectorized<float> {
}

Vectorized<float> operator!=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ);
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Expand Down

0 comments on commit ee934fd

Please sign in to comment.