Skip to content

Commit

Permalink
unrad-related changes
Browse files Browse the repository at this point in the history
  • Loading branch information
smichr committed Apr 10, 2021
1 parent 86cf5c3 commit 651595d
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 117 deletions.
6 changes: 6 additions & 0 deletions sympy/calculus/tests/test_util.py
Expand Up @@ -300,6 +300,12 @@ def test_minimum():
raises(ValueError, lambda : minimum(sin(x), S.One))


def test_issue_19869():
t = symbols('t')
assert (maximum(sqrt(3)*(t - 1)/(3*sqrt(t**2 + 1)), t)
) == sqrt(3)/3


def test_AccumBounds():
assert AccumBounds(1, 2).args == (1, 2)
assert AccumBounds(1, 2).delta is S.One
Expand Down
83 changes: 34 additions & 49 deletions sympy/solvers/solvers.py
Expand Up @@ -42,7 +42,7 @@
from sympy.simplify.fu import TR1, TR2i
from sympy.matrices.common import NonInvertibleMatrixError
from sympy.matrices import Matrix, zeros
from sympy.polys import roots, cancel, factor, Poly, degree
from sympy.polys import roots, cancel, factor, Poly
from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError

from sympy.polys.solvers import sympy_eqs_to_ring, solve_lin_sys
Expand Down Expand Up @@ -3207,12 +3207,13 @@ def unrad(eq, *syms, **flags):
>>> unrad(sqrt(x)*x**Rational(1, 3) + 2)
(x**5 - 64, [])
>>> unrad(sqrt(x) + root(x + 1, 3))
(x**3 - x**2 - 2*x - 1, [])
(-x**3 + x**2 + 2*x + 1, [])
>>> eq = sqrt(x) + root(x, 3) - 2
>>> unrad(eq)
(_p**3 + _p**2 - 2, [_p, _p**6 - x])
"""
from sympy import Equality as Eq

uflags = dict(check=False, simplify=False)

Expand Down Expand Up @@ -3243,19 +3244,21 @@ def _canonical(eq, cov):
for f in eq.args:
if f.is_number:
continue
if f.is_Pow and _take(f, True):
if f.is_Pow:
args.append(f.base)
else:
args.append(f)
eq = Mul(*args) # leave as Mul for more efficient solving

# make the sign canonical
free = eq.free_symbols
if len(free) == 1:
if eq.coeff(free.pop()**degree(eq)).could_extract_minus_sign():
eq = -eq
elif eq.could_extract_minus_sign():
eq = -eq
margs = list(Mul.make_args(eq))
changed = False
for i, m in enumerate(margs):
if m.could_extract_minus_sign():
margs[i] = -m
changed = True
if changed:
eq = Mul(*margs, evaluate=False)

return eq, cov

Expand All @@ -3267,52 +3270,46 @@ def _Q(pow):
return c.q

# define the _take method that will determine whether a term is of interest
def _take(d, take_int_pow):
def _take(d):
# return True if coefficient of any factor's exponent's den is not 1
for pow in Mul.make_args(d):
if not (pow.is_Symbol or pow.is_Pow):
if not pow.is_Pow:
continue
b, e = pow.as_base_exp()
if not b.has(*syms):
if _Q(pow) == 1:
continue
if not take_int_pow and _Q(pow) == 1:
continue
free = pow.free_symbols
if free.intersection(syms):
if pow.free_symbols & syms:
return True
return False
_take = flags.setdefault('_take', _take)

if isinstance(eq, Eq):
eq = eq.rewrite(Add)
assert isinstance(eq, Expr)
elif not isinstance(eq, Expr):
return

cov, nwas, rpt = [flags.setdefault(k, v) for k, v in
sorted(dict(cov=[], n=None, rpt=0).items())]

# preconditioning
eq = powdenest(factor_terms(eq, radical=True, clear=True))

if isinstance(eq, Relational):
eq, d = eq, 1
else:
eq, d = eq.as_numer_denom()

eq = eq.as_numer_denom()[0]
eq = _mexpand(eq, recursive=True)
if eq.is_number:
return eq, []
return

syms = set(syms) or eq.free_symbols
# see if there are radicals in symbols of interest
syms = set(syms) or eq.free_symbols # _take uses this
poly = eq.as_poly()
gens = [g for g in poly.gens if _take(g, True)]
gens = [g for g in poly.gens if _take(g)]
if not gens:
return

# check for trivial case
# - already a polynomial in integer powers
if all(_Q(g) == 1 for g in gens):
if (len(gens) == len(poly.gens) and d!=1):
return eq, []
else:
return
# recast poly in terms of eigen-gens
poly = eq.as_poly(*gens)

# - an exponent has a symbol of interest (don't handle)
if any(g.as_base_exp()[1].has(*syms) for g in gens):
if any(g.exp.has(*syms) for g in gens):
return

def _rads_bases_lcm(poly):
Expand All @@ -3323,8 +3320,6 @@ def _rads_bases_lcm(poly):
rads = set()
bases = set()
for g in poly.gens:
if not _take(g, False):
continue
q = _Q(g)
if q != 1:
rads.add(g)
Expand All @@ -3333,9 +3328,6 @@ def _rads_bases_lcm(poly):
return rads, bases, lcm
rads, bases, lcm = _rads_bases_lcm(poly)

if not rads:
return

covsym = Dummy('p', nonnegative=True)

# only keep in syms symbols that actually appear in radicals;
Expand All @@ -3352,7 +3344,7 @@ def _rads_bases_lcm(poly):
rterms = {(): []}
args = Add.make_args(poly.as_expr())
for t in args:
if _take(t, False):
if _take(t):
common = set(t.as_poly().gens).intersection(rads)
key = tuple(sorted([drad[i] for i in common]))
else:
Expand All @@ -3377,14 +3369,10 @@ def _rads_bases_lcm(poly):
if len(bases) == 1:
b = bases.pop()
if len(syms) > 1:
free = b.free_symbols
x = {g for g in gens if g.is_Symbol} & free
if not x:
x = free
x = ordered(x)
x = b.free_symbols
else:
x = syms
x = list(x)[0]
x = list(ordered(x))[0]
try:
inv = _solve(covsym**lcm - b, x, **uflags)
if not inv:
Expand All @@ -3394,9 +3382,6 @@ def _rads_bases_lcm(poly):
return _canonical(eq, cov)
except NotImplementedError:
pass
else:
# no longer consider integer powers as generators
gens = [g for g in gens if _Q(g) != 1]

if len(rterms) == 2:
if not others:
Expand Down
49 changes: 7 additions & 42 deletions sympy/solvers/solveset.py
Expand Up @@ -551,7 +551,8 @@ def _is_function_class_equation(func_class, f, symbol):

def _solve_as_rational(f, symbol, domain):
""" solve rational functions"""
f = together(f, deep=True)
from sympy.core.function import _mexpand
f = together(_mexpand(f, recursive=True), deep=True)
g, h = fraction(f)
if not h.has(symbol):
try:
Expand Down Expand Up @@ -826,44 +827,9 @@ def _solve_as_poly(f, symbol, domain=S.Complexes):
return ConditionSet(symbol, Eq(f, 0), domain)


def _has_rational_power(expr, symbol):
"""
Returns (bool, den) where bool is True if the term has a
non-integer rational power and den is the denominator of the
expression's exponent.
Examples
========
>>> from sympy.solvers.solveset import _has_rational_power
>>> from sympy import sqrt
>>> from sympy.abc import x
>>> _has_rational_power(sqrt(x), x)
(True, 2)
>>> _has_rational_power(x**2, x)
(False, 1)
"""
a, p, q = Wild('a'), Wild('p'), Wild('q')
pattern_match = expr.match(a*p**q) or {}
if pattern_match.get(a, S.Zero).is_zero:
return (False, S.One)
elif p not in pattern_match.keys():
return (False, S.One)
elif isinstance(pattern_match[q], Rational) \
and pattern_match[p].has(symbol):
if not pattern_match[q].q == S.One:
return (True, pattern_match[q].q)

if not isinstance(pattern_match[a], Pow) \
or isinstance(pattern_match[a], Mul):
return (False, S.One)
else:
return _has_rational_power(pattern_match[a], symbol)


def _solve_radical(f, symbol, solveset_solver):
def _solve_radical(f, unradf, symbol, solveset_solver):
""" Helper function to solve equations with radicals """
res = unrad(f)
res = unradf
eq, cov = res if res else (f, [])
if not cov:
result = solveset_solver(eq, symbol) - \
Expand Down Expand Up @@ -1083,10 +1049,9 @@ def _solveset(f, symbol, domain, _check=False):
elif isinstance(rhs_s, FiniteSet):
for equation in [lhs - rhs for rhs in rhs_s]:
if equation == f:
if any(_has_rational_power(g, symbol)[0]
for g in equation.args) or _has_rational_power(
equation, symbol)[0]:
result += _solve_radical(equation,
u = unrad(f)
if u:
result += _solve_radical(equation, u,
symbol,
solver)
elif equation.has(Abs):
Expand Down
33 changes: 24 additions & 9 deletions sympy/solvers/tests/test_solvers.py
Expand Up @@ -973,6 +973,7 @@ def s_check(rv, ans):
return str(rv[0]) in [str(ans[0]), str(-ans[0])] and \
str(rv[1]) == str(ans[1])

assert unrad(1) is None
assert check(unrad(sqrt(x)),
(x, []))
assert check(unrad(sqrt(x) + 1),
Expand Down Expand Up @@ -1056,7 +1057,7 @@ def s_check(rv, ans):
assert solve(p + 6*I) == []
# issue 8622
assert unrad(root(x + 1, 5) - root(x, 3)) == (
x**5 - x**3 - 3*x**2 - 3*x - 1, [])
-(x**5 - x**3 - 3*x**2 - 3*x - 1), [])
# issue #8679
assert check(unrad(x + root(x, 3) + root(x, 3)**2 + sqrt(y), x),
(s**3 + s**2 + s + sqrt(y), [s, s**3 - x]))
Expand Down Expand Up @@ -1129,14 +1130,21 @@ def s_check(rv, ans):
(s**12 - 2*s**8 - 8*s**7 - 8*s**6 + s**4 + 8*s**3 + 23*s**2 +
32*s + 17, [s, s**6 - x]))

# is this needed?
#assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == (
# x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5 - cosh(x)**5, [])
raises(NotImplementedError, lambda:
unrad(sqrt(cosh(x)/x) + root(x + 1,3)*sqrt(x) - 1))
# why does this pass
assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == (
-(x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5
- cosh(x)**5), [])
# and this fail?
#assert unrad(sqrt(cosh(x)/x) + root(x + 1, 3)*sqrt(x) - 1) == (
# -s**6 + 6*s**5 - 15*s**4 + 20*s**3 - 15*s**2 + 6*s + x**5 +
# 2*x**4 + x**3 - 1, [s, s**2 - cosh(x)/x])

# watch for symbols in exponents
assert unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1')) is None
assert check(unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1'), x),
(s**(2*y) + s + 1, [s, s**3 - x - y]))
# should _Q be so lenient?
assert unrad(x**(S.Half/y) + y, x) == (x**(1/y) - y**2, [])

# This tests two things: that if full unrad is attempted and fails
# the solution should still be found; also it tests that the use of
Expand All @@ -1163,7 +1171,7 @@ def s_check(rv, ans):
(3*s**13 + 3*s**11 + 6*s**10 + s**9 + 12*s**8 + 6*s**6 + 12*s**5 +
12*s**3 + 7, [s, s**15 - x]))
assert check(unrad(root(x, 3) - root(x + 1, 4)/2 + root(x + 2, 3)),
(4096*s**13 + 960*s**12 + 48*s**11 - s**10 - 1728*s**4,
(s*(4096*s**9 + 960*s**8 + 48*s**7 - s**6 - 1728),
[s, s**4 - x - 1])) # orig expr has two real roots: -1, -.389
assert check(unrad(root(x, 3) + root(x + 1, 4) - root(x + 2, 3)/2),
(343*s**13 + 2904*s**12 + 1344*s**11 + 512*s**10 - 1323*s**9 -
Expand Down Expand Up @@ -1212,7 +1220,7 @@ def s_check(rv, ans):
sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + x*(3*x**2
- 34) + 90)**2/4 - 39304/27) - 45)**(1/3))''')
assert check(unrad(eq),
(-s*(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 +
(s*-(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 +
51*12**Rational(1, 3)*s**4 - 102*2**Rational(2, 3)*3**Rational(5, 6)*s**4*I - 1620*s**3 +
1620*sqrt(3)*s**3*I + 13872*18**Rational(1, 3)*s**2 - 471648 +
471648*sqrt(3)*I), [s, s**3 - 306*x - sqrt(3)*sqrt(31212*x**2 -
Expand All @@ -1235,6 +1243,13 @@ def s_check(rv, ans):
assert check(unrad(eq),
(-s**5 + s**3 - 3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5), [s, s**15 - x]))

# make sure buried radicals are exposed
s = sqrt(x) - 1
assert unrad(s**2 - s**3) == (x**3 - 6*x**2 + 9*x - 4, [])
# make sure numerators which are already polynomial are rejected
assert unrad((x/(x + 1) + 3)**(-2), x) is None


@slow
def test_unrad_slow():
# this has roots with multiplicity > 1; there should be no
Expand Down Expand Up @@ -2250,7 +2265,7 @@ def test_issue_17650():

def test_issue_17882():
eq = -8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3))
assert unrad(eq) == (4*x**2 - 12, [])
assert unrad(eq) is None


def test_issue_17949():
Expand Down

0 comments on commit 651595d

Please sign in to comment.