Skip to content

Commit

Permalink
Merge pull request #26471 from haru-44/LambertW_eval
Browse files Browse the repository at this point in the history
Improvement of `LambertW` evaluation
  • Loading branch information
smichr committed Apr 8, 2024
2 parents 08b6c58 + 4dab56b commit e89ee93
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
23 changes: 17 additions & 6 deletions sympy/functions/elementary/exponential.py
Expand Up @@ -11,6 +11,7 @@
from sympy.core.numbers import Integer, Rational, pi, I
from sympy.core.parameters import global_parameters
from sympy.core.power import Pow
from sympy.core.relational import Ge
from sympy.core.singleton import S
from sympy.core.symbol import Wild, Dummy
from sympy.core.sympify import sympify
Expand Down Expand Up @@ -1161,20 +1162,30 @@ def eval(cls, x, k=None):
return S.Zero
if x is S.Exp1:
return S.One
if x == -1/S.Exp1:
return S.NegativeOne
w = Wild('w')
# W(x*log(x)) = log(x) for x >= 1/e
# e.g., W(-1/e) = -1, W(2*log(2)) = log(2)
result = x.match(w*log(w))
if result is not None and Ge(result[w]*S.Exp1, S.One) is S.true:
return log(result[w])
if x == -log(2)/2:
return -log(2)
if x == 2*log(2):
return log(2)
# W(x**(x+1)*log(x)) = x*log(x) for x > 0
# e.g., W(81*log(3)) = 3*log(3)
result = x.match(w**(w+1)*log(w))
if result is not None and result[w].is_positive is True:
return result[w]*log(result[w])
# W(e**(1/n)/n) = 1/n
# e.g., W(sqrt(e)/2) = 1/2
result = x.match(S.Exp1**(1/w)/w)
if result is not None:
return 1 / result[w]
if x == -pi/2:
return I*pi/2
if x == exp(1 + S.Exp1):
return S.Exp1
if x is S.Infinity:
return S.Infinity
if x.is_zero:
return S.Zero

if fuzzy_not(k.is_zero):
if x.is_zero:
Expand Down
4 changes: 4 additions & 0 deletions sympy/functions/elementary/tests/test_exponential.py
Expand Up @@ -605,7 +605,10 @@ def test_lambertw():
assert LambertW(0) == 0
assert LambertW(E) == 1
assert LambertW(-1/E) == -1
assert LambertW(100*log(100)) == log(100)
assert LambertW(-log(2)/2) == -log(2)
assert LambertW(81*log(3)) == 3*log(3)
assert LambertW(sqrt(E)/2) == S.Half
assert LambertW(oo) is oo
assert LambertW(0, 1) is -oo
assert LambertW(0, 42) is -oo
Expand All @@ -627,6 +630,7 @@ def test_lambertw():
assert LambertW(2, evaluate=False).is_real
p = Symbol('p', positive=True)
assert LambertW(p, evaluate=False).is_real
assert LambertW(p**(p+1)*log(p)) == p*log(p)
assert LambertW(p - 1, evaluate=False).is_real is None
assert LambertW(-p - 2/S.Exp1, evaluate=False).is_real is False
assert LambertW(S.Half, -1, evaluate=False).is_real is False
Expand Down

0 comments on commit e89ee93

Please sign in to comment.