From 56713720ad7977398889123cb8b7db4a12969ccd Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Tue, 23 Apr 2013 10:33:05 +0545 Subject: [PATCH] add bivariate solving to solvers --- sympy/solvers/bivariate.py | 403 ++++++++++++++++++++++++++++ sympy/solvers/solvers.py | 70 ++++- sympy/solvers/tests/test_ode.py | 10 +- sympy/solvers/tests/test_solvers.py | 41 ++- 4 files changed, 490 insertions(+), 34 deletions(-) create mode 100644 sympy/solvers/bivariate.py diff --git a/sympy/solvers/bivariate.py b/sympy/solvers/bivariate.py new file mode 100644 index 000000000000..6d1b85b891dd --- /dev/null +++ b/sympy/solvers/bivariate.py @@ -0,0 +1,403 @@ +from sympy.core.add import Add +from sympy.core.compatibility import ordered +from sympy.core.function import Function, expand_log +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Wild) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.miscellaneous import root +from sympy.polys.polytools import (factor, Poly, primitive) +from sympy.simplify.simplify import (_mexpand, collect, separatevars) +from sympy.solvers.solvers import solve + + +def _filtered_gens(poly, symbol): + """process the generators of ``poly``, returning the set of generators that + have ``symbol``. If there are two generators that are inverses of each other, + prefer the one that has no denominator. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _filtered_gens + >>> from sympy import Poly, exp + >>> from sympy.abc import x + >>> _filtered_gens(Poly(x + 1/x + exp(x)), x) + set([x, exp(x)]) + + """ + gens = set([g for g in poly.gens if symbol in g.free_symbols]) + for g in list(gens): + ag = 1/g + if g in gens and ag in gens: + if ag.as_numer_denom()[1] is not S.One: + g = ag + gens.remove(g) + return gens + + +def _mostfunc(lhs, func, X=None): + """Returns the term in lhs which contains the most of the + func-type things e.g. log(log(x)) wins over log(x) if both terms appear. + + ``func`` can be a function (exp, log, etc...) or any other SymPy object, + like Pow. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _mostfunc + >>> from sympy.functions.elementary.exponential import exp + >>> from sympy.utilities.pytest import raises + >>> from sympy.abc import x, y + >>> _mostfunc(exp(x) + exp(exp(x) + 2), exp) + exp(exp(x) + 2) + >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x) + exp(x) + >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x) + exp(x) + >>> _mostfunc(x, exp, x) is None + True + >>> _mostfunc(exp(x) + exp(x*y), exp, x) + exp(x) + """ + fterms = [tmp for tmp in lhs.atoms(func) if (not X or + X.is_Symbol and X in tmp.free_symbols or + not X.is_Symbol and tmp.has(X))] + if len(fterms) == 1: + return fterms[0] + elif fterms: + return max(list(ordered(fterms)), key=lambda x: x.count(func)) + return None + + +def _linab(arg, symbol): + """Return ``a, b, X`` assuming ``arg`` can be written as ``a*X + b`` + where ``X`` is a symbol-dependent factor and ``a`` and ``b`` are + independent of ``symbol``. + + Examples + ======== + + >>> from sympy.functions.elementary.exponential import exp + >>> from sympy.solvers.bivariate import _linab + >>> from sympy.abc import x, y + >>> from sympy import S + >>> _linab(S(2), x) + (2, 0, 1) + >>> _linab(2*x, x) + (2, 0, x) + >>> _linab(y + y*x + 2*x, x) + (y + 2, y, x) + >>> _linab(3 + 2*exp(x), x) + (2, 3, exp(x)) + """ + + arg = arg.expand() + ind, dep = arg.as_independent(symbol) + if not arg.is_Add: + b = 0 + a, x = ind, dep + else: + b = ind + a, x = separatevars(dep).as_independent(symbol, as_Add=False) + if x.could_extract_minus_sign(): + a = -a + x = -x + return a, b, x + + +def _lambert(eq, x): + """ + Given an expression assumed to be in the form + ``F(X, a..f) = a*log(b*X + c) + d*X + f = 0`` + return the Lambert solution if possible: + ``X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(-f/a))``. + """ + eq = _mexpand(expand_log(eq)) + mainlog = _mostfunc(eq, log, x) + if not mainlog: + return [] # violated assumptions + other = eq.subs(mainlog, 0) + if not x in other.free_symbols: + return [] # violated assumptions + d, f, X2 = _linab(other, x) + logterm = eq - other + a = logterm.subs(mainlog, 1) + if x in a.free_symbols: + return [] # violated assumptions + logarg = (logterm/a).args[0] + b, c, X1 = _linab(logarg, x) + if X1*X2 == 1: + X1 = 1/X1 + if X1 != X2: + return [] + + u = Dummy('rhs') + rhs = -c/b + (a/d)*LambertW(d/(a*b)*exp(c*d/a/b)*exp(-f/a)) + + # if W's arg is between -1/e and 0 there is a -1 branch solution, too. + + # Check here to see if exp(W(s)) appears and return s/W(s) instead? + + solns = solve(X1 - u, x) + for i, tmp in enumerate(solns): + solns[i] = tmp.subs(u,rhs) + if solns == [2]: + # hack: how to get this as a soln for x**2 - 2**x = 0? + solns = [-2/log(2)*LambertW(log(2)/2)] + return solns + + +def _solve_lambert(f, symbol, gens): + """Return solution to ``f`` if it is a Lambert-type expression + else raise NotImplementedError. + + The equality, ``f(x, a..f) = a*log(b*X + c) + d*X - f = 0`` has the + solution, `X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))`. There + are a variety of forms for `f(X, a..f)` as enumerated below: + + 1a1) + if B**B = R for R not [0, 1] then + log(B) + log(log(B)) = log(log(R)) + X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R)) + 1a2) + if B*(b*log(B) + c)**a = R then + log(B) + a*log(b*log(B) + c) = log(R) + X = log(B); d=1, f=log(R) + 1b) + if a*log(b*B + c) + d*B = R then + X = B, f = R + 2a) + if (b*B + c)*exp(d*B + g) = R then + log(b*B + c) + d*B + g = log(R) + a = 1, f = log(R) - g, X = B + 2b) + if -b*B + g*exp(d*B + h) = c then + log(g) + d*B + h - log(b*B + c) = 0 + a = -1, f = -h - log(g), X = B + 3) + if d*p**(a*B + g) - b*B = c then + log(d) + (a*B + g)*log(p) - log(c + b*B) = 0 + a = -1, d = a*log(p), f = -log(d) - g*log(p) + """ + + nrhs, lhs = f.as_independent(symbol) + rhs = -nrhs + lamcheck = [tmp for tmp in gens + if (tmp.func in [exp, log] or + (tmp.is_Pow and symbol in tmp.exp.free_symbols))] + if not lamcheck: + raise NotImplementedError() + + if lhs.is_Mul: + lhs = expand_log(log(lhs)) + rhs = log(rhs) + + lhs = factor(lhs) + + # For the 1st two, collect on main log + # 1a1) B**B = R != 0 (when 0, there is only a solution if the base is 0, + # but if it is, the exp is 0 and 0**0=1 + # comes back as B*log(B) = log(R) + # 1a2) B*(a + b*log(B))**p = R or with monomial expanded or with whole + # thing expanded comes back unchanged + # log(B) + p*log(a + b*log(B)) = log(R) + # lhs is Mul: + # factor each mul term + # expand log of both sides to give: + # log(B) + log(log(B)) = log(log(R)) + # 1b) d*log(a*B + b) + c*B = R + # lhs is Add: + # expand log of both sides to give: + # log(log(a*B + b)) - log(R - c*B) = -log(d) + + rv = [] + if not rv: + mainlog = _mostfunc(lhs, log, symbol) + if mainlog: + was = lhs + lhs = collect(lhs, mainlog) + if lhs.is_Mul and rhs != 0: + if was != lhs: + factored_args = [] + for arg in lhs.args: + factored_args.append(factor(arg)) + lhs = Mul(*factored_args) + soln = _lambert(log(lhs) - log(rhs), symbol) + elif lhs.is_Add: + other = lhs.subs(mainlog, 0) + if (other.is_Pow or other.is_Mul and + [tmp for tmp in other.atoms(Pow) + if symbol in tmp.free_symbols]): + diff = log(other) - log(rhs - (lhs - other)) + soln = _lambert(expand_log(diff), symbol) + else: + #it's ready to go + soln = _lambert(lhs - rhs, symbol) + else: + soln = [] + for si in soln: + rv.append(si) + + # For the next two, + # collect on main exp + # 2a) (b*B + c)*exp(d*B + g) = R + # lhs is mul: + # log to give + # log(b*B + c) + d*B = log(R) - g + # 2b) -b*B + g*exp(d*B + h) = R + # lhs is add: + # add b*B + # log and rearrange + # log(R + b*B) - d*B = log(g) + h + + if not rv: + mainexp = _mostfunc(lhs, exp, symbol) + if mainexp: + lhs = collect(lhs, mainexp) + if lhs.is_Mul and rhs != 0: + soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol) + elif lhs.is_Add: + # move all but mainexp-containing term to rhs + other = lhs.subs(mainexp, 0) + mainterm = lhs - other + rhs=rhs - other + if (mainterm.could_extract_minus_sign() and + rhs.could_extract_minus_sign()): + mainterm *= -1 + rhs *= -1 + diff = log(mainterm) - log(rhs) + soln = _lambert(expand_log(diff), symbol) + else: + soln = [] + for si in soln: + rv.append(si) + + # 3) d*p**(a*B + b) + c*B = R + # collect on main pow + # log(R - c*B) - a*B*log(p) = log(d) + b*log(p) + + if not rv: + mainpow = _mostfunc(lhs, Pow, symbol) + if mainpow and symbol in mainpow.exp.free_symbols: + lhs = collect(lhs, mainpow) + if lhs.is_Mul and rhs != 0: + soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol) + elif lhs.is_Add: + # move all but mainpow-containing term to rhs + other = lhs.subs(mainpow, 0) + mainterm = lhs - other + rhs = rhs - other + if (mainterm.could_extract_minus_sign() and + rhs.could_extract_minus_sign()): + mainterm *= -1 + rhs *= -1 + diff = log(mainterm) - log(rhs) + soln = _lambert(expand_log(diff), symbol) + else: + soln=[] + for si in soln: + rv.append(si) + + if not rv: + raise NotImplementedError('%s does not appear to have a solution in ' + 'terms of LambertW' % f) + + return rv + + +def bivariate_type(f, x, y, **kwargs): + """Given an expression, f, 3 tests will be done to see what type + of composite bivariate it might be, options for u(x, y) are:: + + x*y + x+y + x*y+x + x*y+y + + If it matches one of these types, ``u(x, y)``, ``P(u)`` and dummy + variable ``u`` will be returned. Solving ``P(u)`` for ``u`` and + equating the solutions to ``u(x, y)`` and then solving for ``x`` or + ``y`` is equivalent to solving the original expression for ``x`` or + ``y``. If ``x`` and ``y`` represent two functions in the same + variable, e.g. ``x = g(t)`` and ``y = h(t)``, then if ``u(x, y) - p`` + can be solved for ``t`` then these represent the solutions to + ``P(u) = 0`` when ``p`` are the solutions of ``P(u) = 0``. + + Only positive values of ``u`` are considered. + + Examples + ======== + + >>> from sympy.solvers.solvers import solve + >>> from sympy.solvers.bivariate import bivariate_type + >>> from sympy.abc import x, y + >>> eq = (x**2 - 3).subs(x, x + y) + >>> bivariate_type(eq, x, y) + (x + y, _u**2 - 3, _u) + >>> uxy, pu, u = _ + >>> usol = solve(pu, u); usol + [sqrt(3)] + >>> [solve(uxy - s) for s in solve(pu, u)] + [[{x: -y + sqrt(3)}]] + >>> all(eq.subs(s).equals(0) for sol in _ for s in sol) + True + + """ + + u = Dummy('u', positive=True) + + if kwargs.pop('first', True): + p = Poly(f, x, y) + f = p.as_expr() + _x = Dummy() + _y = Dummy() + rv = bivariate_type(Poly(f.subs({x: _x, y: _y}), _x, _y), _x, _y, first=False) + if rv: + reps = {_x: x, _y: y} + return rv[0].xreplace(reps), rv[1].xreplace(reps), rv[2] + return + + p = f + f = p.as_expr() + + # f(x*y) + args = Add.make_args(p.as_expr()) + new = [] + for a in args: + a = _mexpand(a.subs(x, u/y)) + free = a.free_symbols + if x in free or y in free: + break + new.append(a) + else: + return x*y, Add(*new), u + + def ok(f, v, c): + new = _mexpand(f.subs(v, c)) + free = new.free_symbols + return None if (x in free or y in free) else new + + # f(a*x + b*y) + new = [] + d = p.degree(x) + if p.degree(y) == d: + a = root(p.coeff_monomial(x**d), d) + b = root(p.coeff_monomial(y**d), d) + new = ok(f, x, (u - b*y)/a) + if new is not None: + return a*x + b*y, new, u + + # f(a*x*y + b*y) + new = [] + d = p.degree(x) + if p.degree(y) == d: + for itry in range(2): + a = root(p.coeff_monomial(x**d*y**d), d) + b = root(p.coeff_monomial(y**d), d) + new = ok(f, x, (u - b*y)/a/y) + if new is not None: + return a*x*y + b*y, new, u + x, y = y, x diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index 6eda9875a92f..e7ca69576908 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -1,4 +1,4 @@ -""" +""" This module contain solvers for all kinds of equations: - algebraic or transcendental, use solve() @@ -21,7 +21,7 @@ from sympy.core.exprtools import factor_terms from sympy.core.function import (expand_mul, expand_multinomial, expand_log, Derivative, AppliedUndef, UndefinedFunction, nfloat, - count_ops, Function) + count_ops, Function, expand_power_exp) from sympy.core.numbers import ilcm, Float from sympy.core.relational import Relational from sympy.logic.boolalg import And, Or @@ -2132,16 +2132,41 @@ def _tsolve(eq, sym, **flags): [LambertW(2)/2] """ + + eq2 = eq.subs(sym, _x) + if _patterns is None: _generate_patterns() - eq2 = eq.subs(sym, _x) - for p, sol in _patterns: - m = eq2.match(p) - if m: - soln = sol.subs(m).subs(_x, sym) - if sym not in soln.free_symbols: - return [soln] + # try pattern matching - two passes, once without simplification + # and once with + for itry in range(2): + for p, sol in _patterns: + m = eq2.match(p) + if m: + soln = sol.subs(m).subs(_x, sym) + if sym not in soln.free_symbols: + return [soln] + if itry == 0: + # lambert forms may need some help being recognized, e.g. changing + # 2**(3*x) + x**3*log(2)**3 + 3*x**2*log(2)**2 + 3*x*log(2) + 1 + # to 2**(3*x) + (x*log(2) + 1)**3 + g = _filtered_gens(eq.as_poly(), sym) + up_or_log = set() + for gi in g: + if gi.func is exp or gi.func is log: + up_or_log.add(gi) + elif gi.is_Pow: + gisimp = powdenest(expand_power_exp(gi)) + if gisimp.is_Pow and sym in gisimp.exp.free_symbols: + up_or_log.add(gi) + down = g.difference(up_or_log) + eq_down = expand_power_exp(eq).subs( + dict(zip(up_or_log, [0]*len(up_or_log)))) + eq2 = expand_power_exp(factor(eq_down) + (eq - eq_down)) + eq2 = eq2.subs(sym, _x) + + # continue with heuristics rhs, lhs = _invert(eq, sym) if lhs.is_Add: @@ -2180,6 +2205,29 @@ def _tsolve(eq, sym, **flags): if rewrite != lhs: return _solve(rewrite - rhs, sym) + + # maybe it was a harder lambert pattern + if flags.pop('bivariate', True): + try: + poly = lhs.as_poly() + g = _filtered_gens(poly, sym) + return _solve_lambert(lhs - rhs, sym, g) + except NotImplementedError: + # maybe it's a convoluted function + if len(g) == 2: + try: + gpu = bivariate_type(lhs - rhs, *g) + if gpu is None: + raise NotImplementedError + g, p, u = gpu + flags['bivariate'] = False + inversion = _tsolve(g - u, sym, **flags) + if inversion: + sol = _solve(p, u, **flags) + return [i.subs(u, s) for i in inversion for s in sol] + except NotImplementedError: + pass + if flags.pop('force', True): flags['force'] = False pos, reps = posify(lhs - rhs) @@ -2698,3 +2746,7 @@ def _norm2(a, b): eq = neq[0] return (_canonical(eq), cov, list(dens)) + + +from sympy.solvers.bivariate import ( + bivariate_type, _solve_lambert, _filtered_gens) diff --git a/sympy/solvers/tests/test_ode.py b/sympy/solvers/tests/test_ode.py index aca4b258ca5b..8793ef71575f 100644 --- a/sympy/solvers/tests/test_ode.py +++ b/sympy/solvers/tests/test_ode.py @@ -316,7 +316,8 @@ def test_1st_exact1(): eq4 = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) eq5 = 2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x) sol1 = Eq(f(x), acos((C1)/cos(x))) - sol2 = Eq(log(f(x)) + x/f(x) + x**2, C1) + sol2 = Eq(f(x), C1*exp(-x**2 + LambertW(C2*x*exp(x**2)))) + sol2b = Eq(log(f(x)) + x/f(x) + x**2, C1) sol3 = Eq(f(x)*sin(x) + cos(f(x)) + x**2 + f(x)**2, C1) sol4 = Eq(x*cos(f(x)) + f(x)**3/3, C1) sol5 = Eq(x**2*f(x) + f(x)**3/3, C1) @@ -326,7 +327,8 @@ def test_1st_exact1(): assert dsolve(eq4, hint='1st_exact') == sol4 assert dsolve(eq5, hint='1st_exact', simplify=False) == sol5 assert checkodesol(eq1, sol1, order=1, solve_for_func=False)[0] - assert checkodesol(eq2, sol2, order=1, solve_for_func=False)[0] + # simplification doesn't handle LambertW well enough to verify + assert checkodesol(eq2, sol2b, order=1, solve_for_func=False)[0] assert checkodesol(eq3, sol3, order=1, solve_for_func=False)[0] assert checkodesol(eq4, sol4, order=1, solve_for_func=False)[0] assert checkodesol(eq5, sol5, order=1, solve_for_func=False)[0] @@ -528,9 +530,9 @@ def test_1st_homogeneous_coeff_ode(): eq8 = x + f(x) - (x - f(x))*f(x).diff(x) sol1 = Eq(log(x), C1 - log(f(x)*sin(f(x)/x)/x)) sol2 = Eq(log(x), C1 + log(sqrt(cos(f(x)/x) - 1)/sqrt(cos(f(x)/x) + 1))) - sol3 = Eq(log(f(x)), C1 + log(log(f(x)/x) - 1)) + sol3 = Eq(f(x), x*exp(-LambertW(C1*x) + 1)) sol4 = Eq(log(f(x)), C1 - 2*exp(x/f(x))) - sol5 = Eq(log(x), C1 - x**2/(2*f(x)**2) - log(sqrt(f(x)/x))) + sol5 = Eq(f(x), C1*exp(LambertW(C2*x**4)/2)/x) sol6 = Eq(log(x), C1 + exp(-f(x)/x)*sin(f(x)/x)/2 + exp(-f(x)/x)*cos(f(x)/x)/2) sol7 = Eq(log(f(x)), C1 - 2*sqrt(-x/f(x) + 1)) diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index 62b6ac0bdc28..83054b3b4fa0 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -552,14 +552,6 @@ def test_issue_2098(): assert solve(z**2*x**2 - z**2*y**2/exp(x), y, x, z) == [{y: x*exp(x/2)}] -@XFAIL -def test_failing(): - # better Lambert detection is needed if the expression is expanded - # this case has a double generator: (7**x, x); this will pass if the - # x-terms are factored - assert solve((2*(3*x + 4)**5 - 6*7**(3*x + 9)).expand(), x) - - def test_checking(): assert set( solve(x*(x - y/x), x, check=False)) == set([sqrt(y), S(0), -sqrt(y)]) @@ -877,19 +869,6 @@ def test_unrad_slow(): raises(NotImplementedError, lambda: solve(eq)) # not other code errors -@XFAIL -def test_multivariate(): - assert solve( - (x**2 - 2*x + 1).subs(x, log(x) + 3*x)) == [LambertW(3*S.Exp1)/3] - assert solve((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1)) == \ - [LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3] - assert solve((x**2 - 2*x - 2).subs(x, log(x) + 3*x)) == \ - [LambertW(3*exp(1 - sqrt(3)))/3, LambertW(3*exp(1 + sqrt(3)))/3] - assert solve(x*log(x) + 3*x + 1, x) == [exp(-3 + LambertW(-exp(3)))] - # symmetry - assert solve(3*sin(x) - x*sin(3), x) == [3] - - def test__invert(): assert _invert(x - 2) == (2, x) assert _invert(2) == (2, 0) @@ -1251,3 +1230,23 @@ def test_issues_3720_3721_3722(): assert solve(log(x + 1) - log(2*x - 1)) == [2] x = symbols('x') assert solve(2**x + 4**x) == [I*pi/log(2)] + + +def test_lambert_multivariate(): + from sympy.abc import x + assert solve((x**2 - 2*x + 1).subs(x, log(x) + 3*x)) == [LambertW(3*S.Exp1)/3] + assert solve((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1)) == \ + [LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3] + assert solve((x**2 - 2*x - 2).subs(x, log(x) + 3*x)) == \ + [LambertW(3*exp(1 + sqrt(3)))/3, LambertW(3*exp(-sqrt(3) + 1))/3] + assert solve(x*log(x) + 3*x + 1, x) == [exp(-3 + LambertW(-exp(3)))] + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solve(eq) == [LambertW(3*exp(-LambertW(3)))] + assert solve((2*(3*x + 4)**5 - 6*7**(3*x + 9)).expand(), x) == \ + [S(-5)*LambertW(-7*3**(S(1)/5)*log(7)/5)/(3*log(7)) + S(-4)/3] + + +@XFAIL +def test_by_symmetry(): + from sympy.abc import x + assert solve(3*sin(x) - x*sin(3), x) == [3]