forked from scipy/scipy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lambertw.pxd
150 lines (127 loc) · 4.48 KB
/
lambertw.pxd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# -*-cython-*-
#
# Implementation of the Lambert W function [1]. Based on the mpmath
# implementation [2] and its documentation [3].
#
# Copyright: Yosef Meller, 2009
# Author email: mellerf@netvision.net.il
#
# Distributed under the same license as SciPy
#
# References:
# [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359,
# available online: http://www.apmaths.uwo.ca/~djeffrey/Offprints/W-adv-cm.pdf
# [2] mpmath source code, Subversion revision 990
# http://code.google.com/p/mpmath/source/browse/trunk/mpmath/functions.py?spec=svn994&r=992
# [3] mpmath source code, Subversion revision 994
# http://code.google.com/p/mpmath/source/browse/trunk/mpmath/function_docs.py?spec=svn994&r=994
# TODO: use a series expansion when extremely close to the branch point
# at `-1/e` and make sure that the proper branch is chosen there
import cython
from . cimport sf_error
from ._evalpoly cimport cevalpoly
cdef extern from "math.h":
double exp(double x) nogil
double log(double x) nogil
cdef extern from "numpy/npy_math.h":
double NPY_E
from ._complexstuff cimport *
# Compile-time constants shared by the routines in this file.
DEF twopi = 6.2831853071795864769252842 # 2*pi
DEF EXPN1 = 0.36787944117144232159553 # exp(-1); -EXPN1 = -1/e is the branch point of W
DEF OMEGA = 0.56714329040978387299997 # W(1, 0); returned directly for z == 1 on branch 0
@cython.cdivision(True)
cdef inline double complex lambertw_scalar(double complex z, long k, double tol) nogil:
    """Evaluate branch ``k`` of the Lambert W function at ``z``.

    An initial guess is taken from the branch-point series, the Pade
    approximant around 0, ``log(-z)`` (branch -1, real z in [-1/e, 0)),
    or the asymptotic series, and is then refined with Halley's method.

    Parameters
    ----------
    z : double complex
        Argument.
    k : long
        Branch index.
    tol : double
        Relative stopping tolerance: iteration stops as soon as
        ``|w_new - w| < tol*|w_new|``.

    Returns
    -------
    double complex
        W_k(z). Special cases handled without iteration:

        * NaN ``z`` is returned unchanged;
        * ``Re(z) == +inf`` returns ``z + 2*pi*k*i``;
        * ``Re(z) == -inf`` returns ``-z + (2*pi*k + pi)*i``;
        * ``z == 0`` returns 0 on branch 0, otherwise reports
          ``SINGULAR`` and returns -inf;
        * ``z == 1`` on branch 0 returns OMEGA.

        If Halley's method does not converge within 100 iterations, a
        ``SLOW`` error is reported and ``nan + nan*j`` is returned.
    """
    cdef int i
    cdef double absz
    cdef double complex w
    cdef double complex ew, wew, wewz, wn

    if zisnan(z):
        return z
    elif z.real == inf:
        return z + twopi*k*1j
    elif z.real == -inf:
        return -z + (twopi*k+pi)*1j
    elif z == 0:
        if k == 0:
            return z
        sf_error.error("lambertw", sf_error.SINGULAR, NULL)
        return -inf
    elif z == 1 and k == 0:
        # Split out this case because the asymptotic series blows up
        # (log(log(1)) = log(0)).
        return OMEGA

    absz = zabs(z)

    # Get an initial guess for Halley's method
    if k == 0:
        if zabs(z + EXPN1) < 0.3:
            # Close to the branch point -1/e.
            w = lambertw_branchpt(z)
        elif (-1.0 < z.real < 1.5 and zabs(z.imag) < 1.0
              and -2.5*zabs(z.imag) - 0.2 < z.real):
            # Empirically determined decision boundary where the Pade
            # approximation is more accurate.
            w = lambertw_pade0(z)
        else:
            w = lambertw_asy(z, k)
    elif k == -1:
        if absz <= EXPN1 and z.imag == 0 and z.real < 0:
            # Real z in [-1/e, 0): start from the real value log(-z).
            w = log(-z.real)
        else:
            w = lambertw_asy(z, k)
    else:
        w = lambertw_asy(z, k)

    # Halley's method; see 5.9 in [1]
    if w.real >= 0:
        # Rearrange the formula to avoid overflow in exp: dividing
        # f(w) = w*exp(w) - z through by exp(w) leaves only exp(-w),
        # whose modulus is at most 1 when Re(w) >= 0.
        for i in range(100):
            ew = zexp(-w)
            wewz = w - z*ew
            wn = w - wewz/(w + 1 - (w + 2)*wewz/(2*w + 2))
            if zabs(wn - w) < tol*zabs(wn):
                return wn
            w = wn
    else:
        for i in range(100):
            ew = zexp(w)
            wew = w*ew
            wewz = wew - z
            wn = w - wewz/(wew + ew - (w + 2)*wewz/(2*w + 2))
            if zabs(wn - w) < tol*zabs(wn):
                return wn
            w = wn

    sf_error.error("lambertw", sf_error.SLOW,
                   "iteration failed to converge: %g + %gj",
                   <double>z.real, <double>z.imag)
    return zpack(nan, nan)
@cython.cdivision(True)
cdef inline double complex lambertw_branchpt(double complex z) nogil:
    """Initial guess for W(z, 0) near the branch point z = -1/e.

    Truncated series in p = sqrt(2*(e*z + 1)):
    W(z, 0) ~ -1 + p - p**2/3; see 4.22 in [1].
    """
    # Highest power of p first, the order cevalpoly expects.
    cdef double *series = [-1.0/3.0, 1.0, -1.0]
    cdef double complex shifted = NPY_E*z + 1
    cdef double complex root = zsqrt(2*shifted)
    return cevalpoly(series, 2, root)
@cython.cdivision(True)
cdef inline double complex lambertw_pade0(double complex z) nogil:
    """(3, 2) Pade approximation for W(z, 0) around 0.

    Returns z*P(z)/Q(z) with

        P(z) = 1 + 19/10*z + 17/60*z**2
        Q(z) = 1 + 29/10*z + 101/60*z**2,

    chosen so that z*P/Q agrees with the Taylor series
    W(z) = z - z**2 + 3/2*z**3 - 8/3*z**4 + 125/24*z**5 + O(z**6),
    whose coefficients are (-n)**(n-1)/n!.
    """
    cdef:
        # Coefficients run from the highest power of z down to the
        # constant term, the order cevalpoly expects (the constant term
        # must be 1 so that W'(0) = 1).
        double *num = [
            17.0/60.0,
            19.0/10.0,
            1.0
        ]
        double *denom = [
            101.0/60.0,
            29.0/10.0,
            1.0
        ]

    # This only gets evaluated close to 0, so we don't need a more
    # careful algorithm that avoids overflow in the numerator for
    # large z.
    return z*cevalpoly(num, 2, z)/cevalpoly(denom, 2, z)
@cython.cdivision(True)
cdef inline double complex lambertw_asy(double complex z, long k) nogil:
    """Initial guess for W(z, k) from the asymptotic series.

    Keeps the first two terms, W ~ L1 - log(L1), where
    L1 = log(z) + 2*pi*k*i; see 4.20 in [1].
    """
    cdef double complex L1 = zlog(z) + twopi*k*1j
    return L1 - zlog(L1)