Skip to content

Commit

Permalink
Merge pull request #19507 from gschintgen/fix-solve-trig1
Browse files Browse the repository at this point in the history
Improve _solve_trig1 (handle rational & symbolic coefficients)
  • Loading branch information
oscarbenjamin committed Jun 10, 2020
2 parents 6adabb3 + 3c8b95b commit d9bc559
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 38 deletions.
110 changes: 83 additions & 27 deletions sympy/solvers/solveset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from sympy.ntheory.factor_ import divisors
from sympy.ntheory.residue_ntheory import discrete_log, nthroot_mod
from sympy.polys import (roots, Poly, degree, together, PolynomialError,
RootOf, factor)
RootOf, factor, lcm, gcd)
from sympy.polys.polyerrors import CoercionFailed
from sympy.polys.polytools import invert
from sympy.solvers.solvers import (checksol, denoms, unrad,
Expand Down Expand Up @@ -532,66 +532,122 @@ def _solve_as_rational(f, symbol, domain):
return valid_solns - invalid_solns


class _SolveTrig1Error(Exception):
"""Raised when _solve_trig1 heuristics do not apply"""

def _solve_trig(f, symbol, domain):
"""Function to call other helpers to solve trigonometric equations """
sol1 = sol = None
sol = None
try:
sol1 = _solve_trig1(f, symbol, domain)
except NotImplementedError:
pass
if sol1 is None or isinstance(sol1, ConditionSet):
sol = _solve_trig1(f, symbol, domain)
except _SolveTrig1Error:
try:
sol = _solve_trig2(f, symbol, domain)
except ValueError:
sol = sol1
if isinstance(sol1, ConditionSet) and isinstance(sol, ConditionSet):
if sol1.count_ops() < sol.count_ops():
sol = sol1
else:
sol = sol1
if sol is None:
raise NotImplementedError(filldedent('''
Solution to this kind of trigonometric equations
is yet to be implemented'''))
raise NotImplementedError(filldedent('''
Solution to this kind of trigonometric equations
is yet to be implemented'''))
return sol


def _solve_trig1(f, symbol, domain):
"""Primary helper to solve trigonometric and hyperbolic equations"""
"""Primary solver for trigonometric and hyperbolic equations
Returns either the solution set as a ConditionSet (auto-evaluated to a
union of ImageSets if no variables besides 'symbol' are involved) or
raises _SolveTrig1Error if f == 0 can't be solved.
Notes
=====
Algorithm:
1. Do a change of variable x -> mu*x in arguments to trigonometric and
hyperbolic functions, in order to reduce them to small integers. (This
step is crucial to keep the degrees of the polynomials of step 4 low.)
2. Rewrite trigonometric/hyperbolic functions as exponentials.
3. Proceed to a 2nd change of variable, replacing exp(I*x) or exp(x) by y.
4. Solve the resulting rational equation.
5. Use invert_complex or invert_real to return to the original variable.
6. If the coefficients of 'symbol' were symbolic in nature, add the
necessary consistency conditions in a ConditionSet.
"""
# Prepare change of variable
x = Dummy('x')
if _is_function_class_equation(HyperbolicFunction, f, symbol):
cov = exp(symbol)
cov = exp(x)
inverter = invert_real if domain.is_subset(S.Reals) else invert_complex
else:
cov = exp(I*symbol)
cov = exp(I*x)
inverter = invert_complex

f = trigsimp(f)
f_original = f
trig_functions = f.atoms(TrigonometricFunction, HyperbolicFunction)
trig_arguments = [e.args[0] for e in trig_functions]
# trigsimp may have reduced the equation to an expression
# that is independent of 'symbol' (e.g. cos**2+sin**2)
if not any(a.has(symbol) for a in trig_arguments):
return solveset(f_original, symbol, domain)

denominators = []
numerators = []
for ar in trig_arguments:
try:
poly_ar = Poly(ar, symbol)
except PolynomialError:
raise _SolveTrig1Error("trig argument is not a polynomial")
if poly_ar.degree() > 1: # degree >1 still bad
raise _SolveTrig1Error("degree of variable must not exceed one")
if poly_ar.degree() == 0: # degree 0, don't care
continue
c = poly_ar.all_coeffs()[0] # got the coefficient of 'symbol'
numerators.append(fraction(c)[0])
denominators.append(fraction(c)[1])

mu = lcm(denominators)/gcd(numerators)
f = f.subs(symbol, mu*x)
f = f.rewrite(exp)
f = together(f)
g, h = fraction(f)
y = Dummy('y')
g, h = g.expand(), h.expand()
g, h = g.subs(cov, y), h.subs(cov, y)
if g.has(symbol) or h.has(symbol):
return ConditionSet(symbol, Eq(f, 0), domain)
if g.has(x) or h.has(x):
raise _SolveTrig1Error("change of variable not possible")

solns = solveset_complex(g, y) - solveset_complex(h, y)
if isinstance(solns, ConditionSet):
raise NotImplementedError
raise _SolveTrig1Error("polynomial has ConditionSet solution")

if isinstance(solns, FiniteSet):
if any(isinstance(s, RootOf) for s in solns):
raise NotImplementedError
raise _SolveTrig1Error("polynomial results in RootOf object")
# revert the change of variable
cov = cov.subs(x, symbol/mu)
result = Union(*[inverter(cov, s, symbol)[1] for s in solns])
# avoid spurious intersections with C in solution set
# In case of symbolic coefficients, the solution set is only valid
# if numerator and denominator of mu are non-zero.
if mu.has(Symbol):
syms = (mu).atoms(Symbol)
munum, muden = fraction(mu)
condnum = munum.as_independent(*syms, as_Add=False)[1]
condden = muden.as_independent(*syms, as_Add=False)[1]
cond = And(Ne(condnum, 0), Ne(condden, 0))
else:
cond = True
# Actual conditions are returned as part of the ConditionSet. Adding an
# intersection with C would only complicate some solution sets due to
# current limitations of intersection code. (e.g. #19154)
if domain is S.Complexes:
return result
# This is a slight abuse of ConditionSet. Ideally this should
# be some kind of "PiecewiseSet". (See #19507 discussion)
return ConditionSet(symbol, cond, result)
else:
return Intersection(result, domain)
return ConditionSet(symbol, cond, Intersection(result, domain))
elif solns is S.EmptySet:
return S.EmptySet
else:
return ConditionSet(symbol, Eq(f_original, 0), domain)
raise _SolveTrig1Error("polynomial solutions must form FiniteSet")


def _solve_trig2(f, symbol, domain):
Expand Down
161 changes: 150 additions & 11 deletions sympy/solvers/tests/test_solveset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sympy.sets.fancysets import ImageSet
from sympy.sets.sets import (Complement, EmptySet, FiniteSet,
Intersection, Interval, Union, imageset, ProductSet)
from sympy.simplify import simplify
from sympy.tensor.indexed import Indexed
from sympy.utilities.iterables import numbered_symbols

Expand Down Expand Up @@ -815,7 +816,53 @@ def test_solve_trig():
assert dumeq(solveset_real(sin(2*x)*cos(x) + cos(2*x)*sin(x)-1, x),
ImageSet(Lambda(n, n*pi*Rational(2, 3) + pi/6), S.Integers))

# Tests for _solve_trig2() function
assert dumeq(solveset_real(2*tan(x)*sin(x) + 1, x), Union(
ImageSet(Lambda(n, 2*n*pi + atan(sqrt(2)*sqrt(-1 + sqrt(17))/
(1 - sqrt(17))) + pi), S.Integers),
ImageSet(Lambda(n, 2*n*pi - atan(sqrt(2)*sqrt(-1 + sqrt(17))/
(1 - sqrt(17))) + pi), S.Integers)))

assert dumeq(solveset_real(cos(2*x)*cos(4*x) - 1, x),
ImageSet(Lambda(n, n*pi), S.Integers))

assert dumeq(solveset(sin(x/10) + Rational(3, 4)), Union(
ImageSet(Lambda(n, 20*n*pi + 10*atan(3*sqrt(7)/7) + 10*pi), S.Integers),
ImageSet(Lambda(n, 20*n*pi - 10*atan(3*sqrt(7)/7) + 20*pi), S.Integers)))

assert dumeq(solveset(cos(x/15) + cos(x/5)), Union(
ImageSet(Lambda(n, 30*n*pi + 15*pi/2), S.Integers),
ImageSet(Lambda(n, 30*n*pi + 45*pi/2), S.Integers),
ImageSet(Lambda(n, 30*n*pi + 75*pi/4), S.Integers),
ImageSet(Lambda(n, 30*n*pi + 45*pi/4), S.Integers),
ImageSet(Lambda(n, 30*n*pi + 105*pi/4), S.Integers),
ImageSet(Lambda(n, 30*n*pi + 15*pi/4), S.Integers)))

assert dumeq(solveset(sec(sqrt(2)*x/3) + 5), Union(
ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi - pi + atan(2*sqrt(6)))/2), S.Integers),
ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi - atan(2*sqrt(6)) + pi)/2), S.Integers)))

assert dumeq(simplify(solveset(tan(pi*x) - cot(pi/2*x))), Union(
ImageSet(Lambda(n, 4*n + 1), S.Integers),
ImageSet(Lambda(n, 4*n + 3), S.Integers),
ImageSet(Lambda(n, 4*n + Rational(7, 3)), S.Integers),
ImageSet(Lambda(n, 4*n + Rational(5, 3)), S.Integers),
ImageSet(Lambda(n, 4*n + Rational(11, 3)), S.Integers),
ImageSet(Lambda(n, 4*n + Rational(1, 3)), S.Integers)))

assert dumeq(solveset(cos(9*x)), Union(
ImageSet(Lambda(n, 2*n*pi/9 + pi/18), S.Integers),
ImageSet(Lambda(n, 2*n*pi/9 + pi/6), S.Integers)))

assert dumeq(solveset(sin(8*x) + cot(12*x), x, S.Reals), Union(
ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers),
ImageSet(Lambda(n, n*pi/2 + 3*pi/8), S.Integers),
ImageSet(Lambda(n, n*pi/2 + 5*pi/16), S.Integers),
ImageSet(Lambda(n, n*pi/2 + 3*pi/16), S.Integers),
ImageSet(Lambda(n, n*pi/2 + 7*pi/16), S.Integers),
ImageSet(Lambda(n, n*pi/2 + pi/16), S.Integers)))

# This is the only remaining solveset test that actually ends up being solved
# by _solve_trig2(). All others are handled by the improved _solve_trig1.
assert dumeq(solveset_real(2*cos(x)*cos(2*x) - 1, x),
Union(ImageSet(Lambda(n, 2*n*pi + 2*atan(sqrt(-2*2**Rational(1, 3)*(67 +
9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 +
Expand All @@ -825,14 +872,10 @@ def test_solve_trig():
9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6))) +
2*pi), S.Integers)))

assert dumeq(solveset_real(2*tan(x)*sin(x) + 1, x), Union(
ImageSet(Lambda(n, 2*n*pi + atan(sqrt(2)*sqrt(-1 +sqrt(17))/
(1 - sqrt(17))) + pi), S.Integers),
ImageSet(Lambda(n, 2*n*pi - atan(sqrt(2)*sqrt(-1 + sqrt(17))/
(1 - sqrt(17))) + pi), S.Integers)))

assert dumeq(solveset_real(cos(2*x)*cos(4*x) - 1, x),
ImageSet(Lambda(n, n*pi), S.Integers))
# issue #16870
assert dumeq(simplify(solveset(sin(x/180*pi) - S.Half, x, S.Reals)), Union(
ImageSet(Lambda(n, 360*n + 150), S.Integers),
ImageSet(Lambda(n, 360*n + 30), S.Integers)))


def test_solve_hyperbolic():
Expand All @@ -844,31 +887,127 @@ def test_solve_hyperbolic():
assert solveset_real(sinh(x) + sech(x), x) == FiniteSet(
log(sqrt(sqrt(5) - 2)))
assert solveset_real(3*cosh(2*x) - 5, x) == FiniteSet(
log(sqrt(3)/3), log(sqrt(3)))
-log(3)/2, log(3)/2)
assert solveset_real(sinh(x - 3) - 2, x) == FiniteSet(
log((2 + sqrt(5))*exp(3)))
assert solveset_real(cosh(2*x) + 2*sinh(x) - 5, x) == FiniteSet(
log(-2 + sqrt(5)), log(1 + sqrt(2)))
assert solveset_real((coth(x) + sinh(2*x))/cosh(x) - 3, x) == FiniteSet(
log(S.Half + sqrt(5)/2), log(1 + sqrt(2)))
assert solveset_real(cosh(x)*sinh(x) - 2, x) == FiniteSet(
log(sqrt(4 + sqrt(17))))
log(4 + sqrt(17))/2)
assert solveset_real(sinh(x) + tanh(x) - 1, x) == FiniteSet(
log(sqrt(2)/2 + sqrt(-S(1)/2 + sqrt(2))))

assert dumeq(solveset_complex(sinh(x) - I/2, x), Union(
ImageSet(Lambda(n, I*(2*n*pi + 5*pi/6)), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi/6)), S.Integers)))

assert dumeq(solveset_complex(sinh(x) + sech(x), x), Union(
ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(-2 + sqrt(5)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sqrt(-2 + sqrt(5)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi - pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers)))

assert dumeq(solveset(sinh(x/10) + Rational(3, 4)), Union(
ImageSet(Lambda(n, 10*I*(2*n*pi + pi) + 10*log(2)), S.Integers),
ImageSet(Lambda(n, 20*n*I*pi - 10*log(2)), S.Integers)))

assert dumeq(solveset(cosh(x/15) + cosh(x/5)), Union(
ImageSet(Lambda(n, 15*I*(2*n*pi + pi/2)), S.Integers),
ImageSet(Lambda(n, 15*I*(2*n*pi - pi/2)), S.Integers),
ImageSet(Lambda(n, 15*I*(2*n*pi - 3*pi/4)), S.Integers),
ImageSet(Lambda(n, 15*I*(2*n*pi + 3*pi/4)), S.Integers),
ImageSet(Lambda(n, 15*I*(2*n*pi - pi/4)), S.Integers),
ImageSet(Lambda(n, 15*I*(2*n*pi + pi/4)), S.Integers)))

assert dumeq(solveset(sech(sqrt(2)*x/3) + 5), Union(
ImageSet(Lambda(n, 3*sqrt(2)*I*(2*n*pi - pi + atan(2*sqrt(6)))/2), S.Integers),
ImageSet(Lambda(n, 3*sqrt(2)*I*(2*n*pi - atan(2*sqrt(6)) + pi)/2), S.Integers)))

assert dumeq(solveset(tanh(pi*x) - coth(pi/2*x)), Union(
ImageSet(Lambda(n, 2*I*(2*n*pi + pi/2)/pi), S.Integers),
ImageSet(Lambda(n, 2*I*(2*n*pi - pi/2)/pi), S.Integers)))

assert dumeq(solveset(cosh(9*x)), Union(
ImageSet(Lambda(n, I*(2*n*pi + pi/2)/9), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi - pi/2)/9), S.Integers)))

# issues #9606 / #9531:
assert solveset(sinh(x), x, S.Reals) == FiniteSet(0)
assert dumeq(solveset(sinh(x), x, S.Complexes), Union(
ImageSet(Lambda(n, I*(2*n*pi + pi)), S.Integers),
ImageSet(Lambda(n, 2*n*I*pi), S.Integers)))

# issues #11218 / #18427
assert dumeq(solveset(sin(pi*x), x, S.Reals), Union(
ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers),
ImageSet(Lambda(n, 2*n), S.Integers)))
assert dumeq(solveset(sin(pi*x), x), Union(
ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers),
ImageSet(Lambda(n, 2*n), S.Integers)))

# issue #17543
assert dumeq(simplify(solveset(I*cot(8*x - 8*E), x)), Union(
ImageSet(Lambda(n, n*pi/4 - 13*pi/16 + E), S.Integers),
ImageSet(Lambda(n, n*pi/4 - 11*pi/16 + E), S.Integers)))


def test_solve_trig_hyp_symbolic():
# actual solver: _solve_trig1
assert dumeq(solveset(sin(a*x), x), ConditionSet(x, Ne(a, 0), Union(
ImageSet(Lambda(n, (2*n*pi + pi)/a), S.Integers),
ImageSet(Lambda(n, 2*n*pi/a), S.Integers))))

assert dumeq(solveset(cosh(x/a), x), ConditionSet(x, Ne(a, 0), Union(
ImageSet(Lambda(n, I*a*(2*n*pi + pi/2)), S.Integers),
ImageSet(Lambda(n, I*a*(2*n*pi - pi/2)), S.Integers))))

assert dumeq(solveset(sin(2*sqrt(3)/3*a**2/(b*pi)*x)
+ cos(4*sqrt(3)/3*a**2/(b*pi)*x), x),
ConditionSet(x, Ne(b, 0) & Ne(a**2, 0), Union(
ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi + pi/2)/(2*a**2)), S.Integers),
ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - 5*pi/6)/(2*a**2)), S.Integers),
ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - pi/6)/(2*a**2)), S.Integers))))

assert dumeq(simplify(solveset(cot((1 + I)*x) - cot((3 + 3*I)*x), x)), Union(
ImageSet(Lambda(n, pi*(1 - I)*(4*n + 1)/4), S.Integers),
ImageSet(Lambda(n, pi*(1 - I)*(4*n - 1)/4), S.Integers)))

assert dumeq(solveset(cosh((a**2 + 1)*x) - 3, x),
ConditionSet(x, Ne(a**2 + 1, 0), Union(
ImageSet(Lambda(n, (2*n*I*pi + log(3 - 2*sqrt(2)))/(a**2 + 1)), S.Integers),
ImageSet(Lambda(n, (2*n*I*pi + log(2*sqrt(2) + 3))/(a**2 + 1)), S.Integers))))

ar = Symbol('ar', real=True)
assert solveset(cosh((ar**2 + 1)*x) - 2, x, S.Reals) == FiniteSet(
log(sqrt(3) + 2)/(ar**2 + 1), log(2 - sqrt(3))/(ar**2 + 1))


def test_issue_9616():
assert dumeq(solveset(sinh(x) + tanh(x) - 1, x), Union(
ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi)
+ log(sqrt(1 + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2))))
+ log(sqrt(1 + sqrt(2)))), S.Integers)))
f1 = (sinh(x)).rewrite(exp)
f2 = (tanh(x)).rewrite(exp)
assert dumeq(solveset(f1 + f2 - 1, x), Union(
Complement(ImageSet(
Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)),
Complement(ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2))))
+ log(sqrt(1 + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)),
Complement(ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi)
+ log(sqrt(1 + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)),
Complement(
ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers),
ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers))))


def test_solve_invalid_sol():
assert 0 not in solveset_real(sin(x)/x, x)
Expand Down

0 comments on commit d9bc559

Please sign in to comment.