Skip to content

Commit

Permalink
BUG: Catch -INF in exp function and do not set underflow flag
Browse files Browse the repository at this point in the history
  • Loading branch information
r-devulap committed Jan 28, 2022
1 parent eee12b6 commit 76f8584
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions numpy/core/src/umath/loops_exponent_log.dispatch.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ simd_exp_FLOAT(npy_float * op,
@vtype@ poly, num_poly, denom_poly, quadrant;
@vtype@i vindex = _mm@vsize@_loadu_si@vsize@((@vtype@i*)&indexarr[0]);

@mask@ xmax_mask, xmin_mask, nan_mask, inf_mask;
@mask@ xmax_mask, xmin_mask, nan_mask, inf_mask, ninf_mask;
@mask@ overflow_mask = @isa@_get_partial_load_mask_ps(0, num_lanes);
@mask@ underflow_mask = @isa@_get_partial_load_mask_ps(0, num_lanes);
@mask@ load_mask = @isa@_get_full_load_mask_ps();
Expand All @@ -490,9 +490,11 @@ simd_exp_FLOAT(npy_float * op,
xmax_mask = _mm@vsize@_cmp_ps@vsub@(x, _mm@vsize@_set1_ps(xmax), _CMP_GE_OQ);
xmin_mask = _mm@vsize@_cmp_ps@vsub@(x, _mm@vsize@_set1_ps(xmin), _CMP_LE_OQ);
inf_mask = _mm@vsize@_cmp_ps@vsub@(x, inf, _CMP_EQ_OQ);
ninf_mask = _mm@vsize@_cmp_ps@vsub@(x, -inf, _CMP_EQ_OQ);
overflow_mask = @or_masks@(overflow_mask,
@xor_masks@(xmax_mask, inf_mask));
underflow_mask = @or_masks@(underflow_mask, xmin_mask);
underflow_mask = @or_masks@(underflow_mask,
@xor_masks@(xmin_mask, ninf_mask));

x = @isa@_set_masked_lanes_ps(x, zeros_f, @or_masks@(
@or_masks@(nan_mask, xmin_mask), xmax_mask));
Expand Down Expand Up @@ -748,7 +750,7 @@ AVX512F_exp_DOUBLE(npy_double * op,
__mmask8 overflow_mask = avx512_get_partial_load_mask_pd(0, num_lanes);
__mmask8 underflow_mask = avx512_get_partial_load_mask_pd(0, num_lanes);
__mmask8 load_mask = avx512_get_full_load_mask_pd();
__mmask8 xmin_mask, xmax_mask, inf_mask, nan_mask, nearzero_mask;
__mmask8 xmin_mask, xmax_mask, inf_mask, ninf_mask, nan_mask, nearzero_mask;

while (num_remaining_elements > 0) {
if (num_remaining_elements < num_lanes) {
Expand All @@ -769,14 +771,16 @@ AVX512F_exp_DOUBLE(npy_double * op,
xmax_mask = _mm512_cmp_pd_mask(x, mTH_max, _CMP_GT_OQ);
xmin_mask = _mm512_cmp_pd_mask(x, mTH_min, _CMP_LT_OQ);
inf_mask = _mm512_cmp_pd_mask(x, mTH_inf, _CMP_EQ_OQ);
ninf_mask = _mm512_cmp_pd_mask(x, -mTH_inf, _CMP_EQ_OQ);
__m512i x_abs = _mm512_and_epi64(_mm512_castpd_si512(x),
_mm512_set1_epi64(0x7FFFFFFFFFFFFFFF));
nearzero_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(x_abs),
mTH_nearzero, _CMP_LT_OQ);
nearzero_mask = _mm512_kxor(nearzero_mask, nan_mask);
overflow_mask = _mm512_kor(overflow_mask,
_mm512_kxor(xmax_mask, inf_mask));
underflow_mask = _mm512_kor(underflow_mask, xmin_mask);
underflow_mask = _mm512_kor(underflow_mask,
_mm512_kxor(xmin_mask, ninf_mask));
x = avx512_set_masked_lanes_pd(x, zeros_d,
_mm512_kor(_mm512_kor(nan_mask, xmin_mask),
_mm512_kor(xmax_mask, nearzero_mask)));
Expand Down

0 comments on commit 76f8584

Please sign in to comment.