From 536cf343aa620515ecd49bb788607bf6a0f547ae Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Wed, 10 Jun 2020 18:32:01 -0500 Subject: [PATCH] solve wrt additive generators --- sympy/solvers/solvers.py | 24 ++++++++++++++++++++++-- sympy/solvers/tests/test_solvers.py | 12 ++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index 91a68e48a658..b93dd9f0aed8 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -20,7 +20,7 @@ default_sort_key) from sympy.core.sympify import sympify from sympy.core import (S, Add, Symbol, Equality, Dummy, Expr, Mul, - Pow, Unequality) + Pow, Unequality, Wild) from sympy.core.exprtools import factor_terms from sympy.core.function import (expand_mul, expand_log, Derivative, AppliedUndef, UndefinedFunction, nfloat, @@ -1444,6 +1444,25 @@ def _solve(f, *symbols, **flags): sol = simplify(sol) return [sol] + poly = None + # check for a single non-symbol generator + dums = f_num.atoms(Dummy) + D = f_num.replace( + lambda i: isinstance(i, Add) and symbol in i.free_symbols, + lambda i: Dummy()) + if not D.is_Dummy: + dgen = D.atoms(Dummy) - dums + if len(dgen) == 1: + d = dgen.pop() + w = Wild('g') + gen = f_num.match(D.xreplace({d: w}))[w] + spart = gen.as_independent(symbol)[1].as_base_exp()[0] + if spart == symbol: + try: + poly = Poly(f_num, spart) + except PolynomialError: + pass + result = False # no solution was obtained msg = '' # there is no failure message @@ -1457,7 +1476,8 @@ def _solve(f, *symbols, **flags): # generator is not a symbol try: - poly = Poly(f_num) + if poly is None: + poly = Poly(f_num) if poly is None: raise ValueError('could not convert %s to Poly' % f_num) except GeneratorsNeeded: diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index 172fbaa3ce11..bd36a2a7fabe 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -2187,3 +2187,15 @@ def test_issue_19113_19102(): -2*atan(-sqrt(5)/2 + h + sqrt(2)*sqrt(1 - sqrt(5))/2), -2*atan(-sqrt(2)*sqrt(1 + sqrt(5))/2 + h + sqrt(5)/2)] assert solve(3*cos(x) - sin(x)) == [atan(3)] + + +def test_issue_19509(): + a = S(3)/4 + b = S(5)/8 + c = sqrt(5)/8 + d = sqrt(5)/4 + assert solve(1/(x -1)**5 - 1) == [2, + -d + a - sqrt(-b + c), + -d + a + sqrt(-b + c), + d + a - sqrt(-b - c), + d + a + sqrt(-b - c)]