diff --git a/numexpr/complex_functions.hpp b/numexpr/complex_functions.hpp index 42775e2..e0ff8a2 100644 --- a/numexpr/complex_functions.hpp +++ b/numexpr/complex_functions.hpp @@ -390,42 +390,45 @@ nc_sinh(std::complex *x, std::complex *r) static void nc_tan(std::complex *x, std::complex *r) { - double sr,cr,shi,chi; - double rs,is,rc,ic; - double d; - double xr=x->real(), xi=x->imag(); - sr = sin(xr); - cr = cos(xr); - shi = sinh(xi); - chi = cosh(xi); - rs = sr*chi; - is = cr*shi; - rc = cr*chi; - ic = -sr*shi; - d = rc*rc + ic*ic; - r->real((rs*rc+is*ic)/d); - r->imag((is*rc-rs*ic)/d); + double xr = x->real(); + double xi = x->imag(); + double imag_part; + + double denom = cos(2*xr) + cosh(2*xi); + // handle overflows + if (xi > 20) { + imag_part = 1.0 / (1.0 + exp(-4*xi)); + } else if (xi < -20) { + imag_part = -1.0 / (1.0 + exp(4*xi)); + } else { + imag_part = sinh(2*xi) / denom; + } + double real_part = sin(2*xr) / denom; + + r->real(real_part); + r->imag(imag_part); return; } static void nc_tanh(std::complex *x, std::complex *r) { - double si,ci,shr,chr; - double rs,is,rc,ic; - double d; - double xr=x->real(), xi=x->imag(); - si = sin(xi); - ci = cos(xi); - shr = sinh(xr); - chr = cosh(xr); - rs = ci*shr; - is = si*chr; - rc = ci*chr; - ic = si*shr; - d = rc*rc + ic*ic; - r->real((rs*rc+is*ic)/d); - r->imag((is*rc-rs*ic)/d); + double xr = x->real(); + double xi = x->imag(); + double real_part; + double denom = cosh(2*xr) + cos(2*xi); + // handle overflows + if (xr > 20) { + real_part = 1.0 / (1.0 + exp(-4*xr)); + } else if (xr < -20) { + real_part = -1.0 / (1.0 + exp(4*xr)); + } else { + real_part = sinh(2*xr) / denom; + } + double imag_part = sin(2*xi) / denom; + + r->real(real_part); + r->imag(imag_part); return; } diff --git a/numexpr/tests/test_numexpr.py b/numexpr/tests/test_numexpr.py index be0b055..e8e711a 100644 --- a/numexpr/tests/test_numexpr.py +++ b/numexpr/tests/test_numexpr.py @@ -480,6 +480,13 @@ def test_bitwise_operators(self): assert_array_equal(evaluate("x | y"), x | y) # or assert_array_equal(evaluate("~x"), ~x) # invert + def test_complex_tan(self): + # old version of NumExpr had overflow problems + x = np.arange(1, 400., step=16., dtype=np.complex128) + y = 1j*np.arange(1, 400., step=16., dtype=np.complex128) + assert_array_almost_equal(evaluate("tan(x + y)"), tan(x + y)) + assert_array_almost_equal(evaluate("tanh(x + y)"), tanh(x + y)) + def test_maximum_minimum(self): for dtype in [float, double, int, np.int64]: x = arange(10, dtype=dtype)