From 651595d137c2c4469d97be2458f469dfbdb2c9b9 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Thu, 8 Apr 2021 04:19:16 -0500 Subject: [PATCH] unrad-related changes --- sympy/calculus/tests/test_util.py | 6 ++ sympy/solvers/solvers.py | 83 ++++++++++--------------- sympy/solvers/solveset.py | 49 +++------------ sympy/solvers/tests/test_solvers.py | 33 +++++++--- sympy/solvers/tests/test_solveset.py | 25 ++++---- sympy/stats/tests/test_continuous_rv.py | 13 +++- 6 files changed, 92 insertions(+), 117 deletions(-) diff --git a/sympy/calculus/tests/test_util.py b/sympy/calculus/tests/test_util.py index 5e377a98ff4b..befced601c89 100644 --- a/sympy/calculus/tests/test_util.py +++ b/sympy/calculus/tests/test_util.py @@ -300,6 +300,12 @@ def test_minimum(): raises(ValueError, lambda : minimum(sin(x), S.One)) +def test_issue_19869(): + t = symbols('t') + assert (maximum(sqrt(3)*(t - 1)/(3*sqrt(t**2 + 1)), t) + ) == sqrt(3)/3 + + def test_AccumBounds(): assert AccumBounds(1, 2).args == (1, 2) assert AccumBounds(1, 2).delta is S.One diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index 06d30cbdea09..16714416e7da 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -42,7 +42,7 @@ from sympy.simplify.fu import TR1, TR2i from sympy.matrices.common import NonInvertibleMatrixError from sympy.matrices import Matrix, zeros -from sympy.polys import roots, cancel, factor, Poly, degree +from sympy.polys import roots, cancel, factor, Poly from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError from sympy.polys.solvers import sympy_eqs_to_ring, solve_lin_sys @@ -3207,12 +3207,13 @@ def unrad(eq, *syms, **flags): >>> unrad(sqrt(x)*x**Rational(1, 3) + 2) (x**5 - 64, []) >>> unrad(sqrt(x) + root(x + 1, 3)) - (x**3 - x**2 - 2*x - 1, []) + (-x**3 + x**2 + 2*x + 1, []) >>> eq = sqrt(x) + root(x, 3) - 2 >>> unrad(eq) (_p**3 + _p**2 - 2, [_p, _p**6 - x]) """ + from sympy import Equality as Eq uflags = dict(check=False, simplify=False) @@ -3243,19 +3244,21 @@ def _canonical(eq, cov): for f in eq.args: if f.is_number: continue - if f.is_Pow and _take(f, True): + if f.is_Pow: args.append(f.base) else: args.append(f) eq = Mul(*args) # leave as Mul for more efficient solving # make the sign canonical - free = eq.free_symbols - if len(free) == 1: - if eq.coeff(free.pop()**degree(eq)).could_extract_minus_sign(): - eq = -eq - elif eq.could_extract_minus_sign(): - eq = -eq + margs = list(Mul.make_args(eq)) + changed = False + for i, m in enumerate(margs): + if m.could_extract_minus_sign(): + margs[i] = -m + changed = True + if changed: + eq = Mul(*margs, evaluate=False) return eq, cov @@ -3267,52 +3270,46 @@ def _Q(pow): return c.q # define the _take method that will determine whether a term is of interest - def _take(d, take_int_pow): + def _take(d): # return True if coefficient of any factor's exponent's den is not 1 for pow in Mul.make_args(d): - if not (pow.is_Symbol or pow.is_Pow): + if not pow.is_Pow: continue - b, e = pow.as_base_exp() - if not b.has(*syms): + if _Q(pow) == 1: continue - if not take_int_pow and _Q(pow) == 1: - continue - free = pow.free_symbols - if free.intersection(syms): + if pow.free_symbols & syms: return True return False _take = flags.setdefault('_take', _take) + if isinstance(eq, Eq): + eq = eq.rewrite(Add) + assert isinstance(eq, Expr) + elif not isinstance(eq, Expr): + return + cov, nwas, rpt = [flags.setdefault(k, v) for k, v in sorted(dict(cov=[], n=None, rpt=0).items())] # preconditioning eq = powdenest(factor_terms(eq, radical=True, clear=True)) - - if isinstance(eq, Relational): - eq, d = eq, 1 - else: - eq, d = eq.as_numer_denom() - + eq = eq.as_numer_denom()[0] eq = _mexpand(eq, recursive=True) if eq.is_number: - return eq, [] + return - syms = set(syms) or eq.free_symbols + # see if there are radicals in symbols of interest + syms = set(syms) or eq.free_symbols # _take uses this poly = eq.as_poly() - gens = [g for g in poly.gens if _take(g, True)] + gens = [g for g in poly.gens if _take(g)] if not gens: return - # check for trivial case - # - already a polynomial in integer powers - if all(_Q(g) == 1 for g in gens): - if (len(gens) == len(poly.gens) and d!=1): - return eq, [] - else: - return + # recast poly in terms of eigen-gens + poly = eq.as_poly(*gens) + # - an exponent has a symbol of interest (don't handle) - if any(g.as_base_exp()[1].has(*syms) for g in gens): + if any(g.exp.has(*syms) for g in gens): return def _rads_bases_lcm(poly): @@ -3323,8 +3320,6 @@ def _rads_bases_lcm(poly): rads = set() bases = set() for g in poly.gens: - if not _take(g, False): - continue q = _Q(g) if q != 1: rads.add(g) @@ -3333,9 +3328,6 @@ def _rads_bases_lcm(poly): return rads, bases, lcm rads, bases, lcm = _rads_bases_lcm(poly) - if not rads: - return - covsym = Dummy('p', nonnegative=True) # only keep in syms symbols that actually appear in radicals; @@ -3352,7 +3344,7 @@ def _rads_bases_lcm(poly): rterms = {(): []} args = Add.make_args(poly.as_expr()) for t in args: - if _take(t, False): + if _take(t): common = set(t.as_poly().gens).intersection(rads) key = tuple(sorted([drad[i] for i in common])) else: @@ -3377,14 +3369,10 @@ def _rads_bases_lcm(poly): if len(bases) == 1: b = bases.pop() if len(syms) > 1: - free = b.free_symbols - x = {g for g in gens if g.is_Symbol} & free - if not x: - x = free - x = ordered(x) + x = b.free_symbols else: x = syms - x = list(x)[0] + x = list(ordered(x))[0] try: inv = _solve(covsym**lcm - b, x, **uflags) if not inv: @@ -3394,9 +3382,6 @@ def _rads_bases_lcm(poly): return _canonical(eq, cov) except NotImplementedError: pass - else: - # no longer consider integer powers as generators - gens = [g for g in gens if _Q(g) != 1] if len(rterms) == 2: if not others: diff --git a/sympy/solvers/solveset.py b/sympy/solvers/solveset.py index eb77ef03b411..87154e61d9e7 100644 --- a/sympy/solvers/solveset.py +++ b/sympy/solvers/solveset.py @@ -551,7 +551,8 @@ def _is_function_class_equation(func_class, f, symbol): def _solve_as_rational(f, symbol, domain): """ solve rational functions""" - f = together(f, deep=True) + from sympy.core.function import _mexpand + f = together(_mexpand(f, recursive=True), deep=True) g, h = fraction(f) if not h.has(symbol): try: @@ -826,44 +827,9 @@ def _solve_as_poly(f, symbol, domain=S.Complexes): return ConditionSet(symbol, Eq(f, 0), domain) -def _has_rational_power(expr, symbol): - """ - Returns (bool, den) where bool is True if the term has a - non-integer rational power and den is the denominator of the - expression's exponent. - - Examples - ======== - - >>> from sympy.solvers.solveset import _has_rational_power - >>> from sympy import sqrt - >>> from sympy.abc import x - >>> _has_rational_power(sqrt(x), x) - (True, 2) - >>> _has_rational_power(x**2, x) - (False, 1) - """ - a, p, q = Wild('a'), Wild('p'), Wild('q') - pattern_match = expr.match(a*p**q) or {} - if pattern_match.get(a, S.Zero).is_zero: - return (False, S.One) - elif p not in pattern_match.keys(): - return (False, S.One) - elif isinstance(pattern_match[q], Rational) \ - and pattern_match[p].has(symbol): - if not pattern_match[q].q == S.One: - return (True, pattern_match[q].q) - - if not isinstance(pattern_match[a], Pow) \ - or isinstance(pattern_match[a], Mul): - return (False, S.One) - else: - return _has_rational_power(pattern_match[a], symbol) - - -def _solve_radical(f, symbol, solveset_solver): +def _solve_radical(f, unradf, symbol, solveset_solver): """ Helper function to solve equations with radicals """ - res = unrad(f) + res = unradf eq, cov = res if res else (f, []) if not cov: result = solveset_solver(eq, symbol) - \ @@ -1083,10 +1049,9 @@ def _solveset(f, symbol, domain, _check=False): elif isinstance(rhs_s, FiniteSet): for equation in [lhs - rhs for rhs in rhs_s]: if equation == f: - if any(_has_rational_power(g, symbol)[0] - for g in equation.args) or _has_rational_power( - equation, symbol)[0]: - result += _solve_radical(equation, + u = unrad(f) + if u: + result += _solve_radical(equation, u, symbol, solver) elif equation.has(Abs): diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index 5fc1da4edb6a..74a6d7f09bc0 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -973,6 +973,7 @@ def s_check(rv, ans): return str(rv[0]) in [str(ans[0]), str(-ans[0])] and \ str(rv[1]) == str(ans[1]) + assert unrad(1) is None assert check(unrad(sqrt(x)), (x, [])) assert check(unrad(sqrt(x) + 1), @@ -1056,7 +1057,7 @@ def s_check(rv, ans): assert solve(p + 6*I) == [] # issue 8622 assert unrad(root(x + 1, 5) - root(x, 3)) == ( - x**5 - x**3 - 3*x**2 - 3*x - 1, []) + -(x**5 - x**3 - 3*x**2 - 3*x - 1), []) # issue #8679 assert check(unrad(x + root(x, 3) + root(x, 3)**2 + sqrt(y), x), (s**3 + s**2 + s + sqrt(y), [s, s**3 - x])) @@ -1129,14 +1130,21 @@ def s_check(rv, ans): (s**12 - 2*s**8 - 8*s**7 - 8*s**6 + s**4 + 8*s**3 + 23*s**2 + 32*s + 17, [s, s**6 - x])) - # is this needed? - #assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == ( - # x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5 - cosh(x)**5, []) - raises(NotImplementedError, lambda: - unrad(sqrt(cosh(x)/x) + root(x + 1,3)*sqrt(x) - 1)) + # why does this pass + assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == ( + -(x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5 + - cosh(x)**5), []) + # and this fail? + #assert unrad(sqrt(cosh(x)/x) + root(x + 1, 3)*sqrt(x) - 1) == ( + # -s**6 + 6*s**5 - 15*s**4 + 20*s**3 - 15*s**2 + 6*s + x**5 + + # 2*x**4 + x**3 - 1, [s, s**2 - cosh(x)/x]) + + # watch for symbols in exponents assert unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1')) is None assert check(unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1'), x), (s**(2*y) + s + 1, [s, s**3 - x - y])) + # should _Q be so lenient? + assert unrad(x**(S.Half/y) + y, x) == (x**(1/y) - y**2, []) # This tests two things: that if full unrad is attempted and fails # the solution should still be found; also it tests that the use of @@ -1163,7 +1171,7 @@ def s_check(rv, ans): (3*s**13 + 3*s**11 + 6*s**10 + s**9 + 12*s**8 + 6*s**6 + 12*s**5 + 12*s**3 + 7, [s, s**15 - x])) assert check(unrad(root(x, 3) - root(x + 1, 4)/2 + root(x + 2, 3)), - (4096*s**13 + 960*s**12 + 48*s**11 - s**10 - 1728*s**4, + (s*(4096*s**9 + 960*s**8 + 48*s**7 - s**6 - 1728), [s, s**4 - x - 1])) # orig expr has two real roots: -1, -.389 assert check(unrad(root(x, 3) + root(x + 1, 4) - root(x + 2, 3)/2), (343*s**13 + 2904*s**12 + 1344*s**11 + 512*s**10 - 1323*s**9 - @@ -1212,7 +1220,7 @@ def s_check(rv, ans): sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + x*(3*x**2 - 34) + 90)**2/4 - 39304/27) - 45)**(1/3))''') assert check(unrad(eq), - (-s*(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 + + (s*-(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 + 51*12**Rational(1, 3)*s**4 - 102*2**Rational(2, 3)*3**Rational(5, 6)*s**4*I - 1620*s**3 + 1620*sqrt(3)*s**3*I + 13872*18**Rational(1, 3)*s**2 - 471648 + 471648*sqrt(3)*I), [s, s**3 - 306*x - sqrt(3)*sqrt(31212*x**2 - @@ -1235,6 +1243,13 @@ def s_check(rv, ans): assert check(unrad(eq), (-s**5 + s**3 - 3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5), [s, s**15 - x])) + # make sure buried radicals are exposed + s = sqrt(x) - 1 + assert unrad(s**2 - s**3) == (x**3 - 6*x**2 + 9*x - 4, []) + # make sure numerators which are already polynomial are rejected + assert unrad((x/(x + 1) + 3)**(-2), x) is None + + @slow def test_unrad_slow(): # this has roots with multiplicity > 1; there should be no @@ -2250,7 +2265,7 @@ def test_issue_17650(): def test_issue_17882(): eq = -8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)) - assert unrad(eq) == (4*x**2 - 12, []) + assert unrad(eq) is None def test_issue_17949(): diff --git a/sympy/solvers/tests/test_solveset.py b/sympy/solvers/tests/test_solveset.py index 4cf0c2b312dd..0970854263ee 100644 --- a/sympy/solvers/tests/test_solveset.py +++ b/sympy/solvers/tests/test_solveset.py @@ -213,6 +213,14 @@ def test_issue_18449(): # TODO: Is the above solution set definitely complete? +def test_issue_21047(): + f = (2 - x)**2 + (sqrt(x - 1) - 1)**6 + assert(solveset(f, x, S.Reals)) == FiniteSet(2) + f = (sqrt(x)-1)**2 + (sqrt(x)+1)**2 -2*x**2 + sqrt(2) + assert solveset(f, x, S.Reals) == FiniteSet( + S.Half - sqrt(2*sqrt(2) + 5)/2, S.Half + sqrt(2*sqrt(2) + 5)/2) + + def test_is_function_class_equation(): from sympy.abc import x, a assert _is_function_class_equation(TrigonometricFunction, @@ -387,16 +395,6 @@ def test_return_root_of(): CRootOf(x**6 - x + 1, 5)) -def test__has_rational_power(): - from sympy.solvers.solveset import _has_rational_power - assert _has_rational_power(sqrt(2), x)[0] is False - assert _has_rational_power(x*sqrt(2), x)[0] is False - - assert _has_rational_power(x**2*sqrt(x), x) == (True, 2) - assert _has_rational_power(sqrt(2)*x**Rational(1, 3), x) == (True, 3) - assert _has_rational_power(sqrt(x)*x**Rational(1, 3), x) == (True, 6) - - def test_solveset_sqrt_1(): assert solveset_real(sqrt(5*x + 6) - 2 - x, x) == \ FiniteSet(-S.One, S(2)) @@ -453,6 +451,9 @@ def test_solveset_sqrt_2(): eq = sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x)) assert solveset_real(eq, x) == FiniteSet() + eq = (x - 4)**2 + (sqrt(x) - 2)**4 + assert solveset_real(eq, x) == FiniteSet(-4, 4) + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) ans = solveset_real(eq, x) ra = S('''-1484/375 - 4*(-1/2 + sqrt(3)*I/2)*(-12459439/52734375 + @@ -1294,14 +1295,10 @@ def test_solveset_domain(): def test_improve_coverage(): - from sympy.solvers.solveset import _has_rational_power solution = solveset(exp(x) + sin(x), x, S.Reals) unsolved_object = ConditionSet(x, Eq(exp(x) + sin(x), 0), S.Reals) assert solution.dummy_eq(unsolved_object) - assert _has_rational_power(sin(x)*exp(x) + 1, x) == (False, S.One) - assert _has_rational_power((sin(x)**2)*(exp(x) + 1)**3, x) == (False, S.One) - def test_issue_9522(): expr1 = Eq(1/(x**2 - 4) + x, 1/(x**2 - 4) + 2) diff --git a/sympy/stats/tests/test_continuous_rv.py b/sympy/stats/tests/test_continuous_rv.py index 55bb6eb16abc..393f3509c08b 100644 --- a/sympy/stats/tests/test_continuous_rv.py +++ b/sympy/stats/tests/test_continuous_rv.py @@ -4,10 +4,11 @@ gamma, beta, Piecewise, Integral, sin, cos, tan, atan, sinh, cosh, besseli, floor, expand_func, Rational, I, re, Lambda, asin, im, lambdify, hyper, diff, Or, Mul, sign, Dummy, Sum, - factorial, binomial, erfi, besselj, besselk) + factorial, binomial, erfi, besselj, besselk, factor_terms) from sympy.functions.special.error_functions import erfinv from sympy.functions.special.hyper import meijerg -from sympy.sets.sets import Intersection, FiniteSet +from sympy.sets.sets import FiniteSet, Complement +from sympy.sets.conditionset import ConditionSet from sympy.stats import (P, E, where, density, variance, covariance, skewness, kurtosis, median, given, pspace, cdf, characteristic_function, moment_generating_function, ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime, @@ -46,7 +47,13 @@ def test_single_normal(): 2**S.Half*exp(-(x - mu)**2/(2*sigma**2))/(2*pi**S.Half*sigma)) assert P(X**2 < 1) == erf(2**S.Half/2) - assert quantile(Y)(x) == Intersection(S.Reals, FiniteSet(sqrt(2)*sigma*(sqrt(2)*mu/(2*sigma) + erfinv(2*x - 1)))) + ans = quantile(Y)(x) + eq = ans.atoms(Eq).pop() + ans = factor_terms(ans.xreplace({eq: Eq(eq.lhs.simplify().factor(), 0)} + ).xreplace({eq.atoms(Dummy).pop(): y})) + assert ans == Complement(ConditionSet(y, Eq((-mu + + y)*(2*x + erf(sqrt(2)*(mu - y)/(2*sigma)) - 1), + 0), S.Reals), FiniteSet(mu)) assert E(X, Eq(X, mu)) == mu assert median(X) == FiniteSet(0)