Skip to content

Commit

Permalink
Merge pull request sympy#26358 from smichr/26338
Browse files Browse the repository at this point in the history
maintain ordering in heurisch mapping
  • Loading branch information
smichr committed Mar 19, 2024
2 parents 94305f1 + ca87484 commit 9d13061
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
14 changes: 12 additions & 2 deletions sympy/integrals/heurisch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from itertools import permutations
from collections import defaultdict
from functools import reduce
from itertools import permutations

from sympy.core.add import Add
from sympy.core.basic import Basic
Expand Down Expand Up @@ -503,7 +504,16 @@ def heurisch(f, x, rewrite=False, hints=None, mappings=None, retries=3,
# optimizing the number of permutations of mapping #
assert mapping[-1][0] == x # if not, find it and correct this comment
unnecessary_permutations = [mapping.pop(-1)]
mappings = permutations(mapping)
# only permute types of objects and let the ordering
# of types take care of the order of replacement
types = defaultdict(list)
for i in mapping:
types[type(i)].append(i)
mapping = [types[i] for i in types]
def _iter_mappings():
for i in permutations(mapping):
yield [j for i in i for j in i]
mappings = _iter_mappings()
else:
unnecessary_permutations = unnecessary_permutations or []

Expand Down
34 changes: 33 additions & 1 deletion sympy/integrals/tests/test_heurisch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sympy.core.add import Add
from sympy.core.function import (Derivative, Function, diff)
from sympy.core.numbers import (I, Rational, pi)
from sympy.core.relational import Ne
from sympy.core.relational import Eq, Ne
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (LambertW, exp, log)
from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh, tanh)
Expand All @@ -12,6 +12,8 @@
from sympy.functions.special.bessel import (besselj, besselk, bessely, jn)
from sympy.functions.special.error_functions import erf
from sympy.integrals.integrals import Integral
from sympy.logic.boolalg import And
from sympy.matrices import Matrix
from sympy.simplify.ratsimp import ratsimp
from sympy.simplify.simplify import simplify
from sympy.integrals.heurisch import components, heurisch, heurisch_wrapper
Expand Down Expand Up @@ -365,3 +367,33 @@ def f(x):
Uz = integrate(f(z), z)
Ut = integrate(f(t), t)
assert Ut == Uz.subs(z, t)


def test_heurisch_complex_erf_issue_26338():
r = symbols('r', real=True)
a = exp(-r**2/(2*(2 - I)**2))
assert heurisch(a, r, hints=[]) is None # None, not a wrong soln
a = sqrt(pi)*erf((1 + I)/2)/2
assert integrate(exp(-I*r**2/2), (r, 0, 1)) == a - I*a

a = exp(-x**2/(2*(2 - I)**2))
assert heurisch(a, x, hints=[]) is None # None, not a wrong soln
a = sqrt(pi)*erf((1 + I)/2)/2
assert integrate(exp(-I*x**2/2), (x, 0, 1)) == a - I*a


def test_issue_15498():
Z0 = Function('Z0')
k01, k10, t, s= symbols('k01 k10 t s', real=True, positive=True)
m = Matrix([[exp(-k10*t)]])
_83 = Rational(83, 100) # 0.83 works, too
[a, b, c, d, e, f, g] = [100, 0.5, _83, 50, 0.6, 2, 120]
AIF_btf = a*(d*e*(1 - exp(-(t - b)/e)) + f*g*(1 - exp(-(t - b)/g)))
AIF_atf = a*(d*e*exp(-(t - b)/e)*(exp((c - b)/e) - 1
) + f*g*exp(-(t - b)/g)*(exp((c - b)/g) - 1))
AIF_sym = Piecewise((0, t < b), (AIF_btf, And(b <= t, t < c)), (AIF_atf, c <= t))
aif_eq = Eq(Z0(t), AIF_sym)
f_vec = Matrix([[k01*Z0(t)]])
integrand = m*m.subs(t, s)**-1*f_vec.subs(aif_eq.lhs, aif_eq.rhs).subs(t, s)
solution = integrate(integrand[0], (s, 0, t))
assert solution is not None # does not hang and takes less than 10 s
2 changes: 1 addition & 1 deletion sympy/integrals/tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,7 @@ def test_issue_3940():
assert integrate(exp(-x**2 + I*c*x), x) == \
-sqrt(pi)*exp(-c**2/4)*erf(I*c/2 - x)/2
assert integrate(exp(a*x**2 + b*x + c), x) == \
sqrt(pi)*exp(c)*exp(-b**2/(4*a))*erfi(sqrt(a)*x + b/(2*sqrt(a)))/(2*sqrt(a))
sqrt(pi)*exp(c - b**2/(4*a))*erfi((2*a*x + b)/(2*sqrt(a)))/(2*sqrt(a))

from sympy.core.function import expand_mul
from sympy.abc import k
Expand Down

0 comments on commit 9d13061

Please sign in to comment.