From 124c430ed296df685094643c3952b26213bb9093 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Fri, 5 Aug 2022 20:14:10 -0500 Subject: [PATCH] dict only returned for linear system --- sympy/core/tests/test_subs.py | 1 - .../tests/test_comb_factorials.py | 1 - .../elementary/tests/test_trigonometric.py | 1 - sympy/geometry/tests/test_line.py | 1 - sympy/solvers/solvers.py | 67 ++++++++++--------- sympy/solvers/tests/test_solvers.py | 20 +++--- 6 files changed, 47 insertions(+), 44 deletions(-) diff --git a/sympy/core/tests/test_subs.py b/sympy/core/tests/test_subs.py index 9c4a5e90c8c2..4f09a034c18e 100644 --- a/sympy/core/tests/test_subs.py +++ b/sympy/core/tests/test_subs.py @@ -882,7 +882,6 @@ def test_issue_19558(): assert (sin(x) + cos(x)).subs(x, oo) == AccumBounds(-2, 2) - def test_issue_22033(): xr = Symbol('xr', real=True) e = (1/xr) diff --git a/sympy/functions/combinatorial/tests/test_comb_factorials.py b/sympy/functions/combinatorial/tests/test_comb_factorials.py index 0baac6a56ea6..d9dce2acca6a 100644 --- a/sympy/functions/combinatorial/tests/test_comb_factorials.py +++ b/sympy/functions/combinatorial/tests/test_comb_factorials.py @@ -535,7 +535,6 @@ def test_binomial_Mod(): assert Mod(binomial(253, 113, evaluate=False), r) == Mod(binomial(253, 113), r) - @slow def test_binomial_Mod_slow(): p, q = 10**5 + 3, 10**9 + 33 # prime modulo diff --git a/sympy/functions/elementary/tests/test_trigonometric.py b/sympy/functions/elementary/tests/test_trigonometric.py index 4ac1df2dd14c..5b19bd352649 100644 --- a/sympy/functions/elementary/tests/test_trigonometric.py +++ b/sympy/functions/elementary/tests/test_trigonometric.py @@ -1502,7 +1502,6 @@ def test_inverses(): assert acot(x).inverse() == cot - def test_real_imag(): a, b = symbols('a b', real=True) z = a + b*I diff --git a/sympy/geometry/tests/test_line.py b/sympy/geometry/tests/test_line.py index d892997c18f3..63f11e45f4a5 100644 --- a/sympy/geometry/tests/test_line.py +++ b/sympy/geometry/tests/test_line.py @@ -72,7 +72,6 @@ def test_angle_between(): Line3D(Point3D(5, 0, 0), z)) == acos(-sqrt(3) / 3) - def test_closing_angle(): a = Ray((0, 0), angle=0) b = Ray((1, 2), angle=pi/2) diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index afde1bd1f45c..57acf8971365 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -1146,39 +1146,46 @@ def _has_piecewise(e): # # try to get a solution ########################################################################### - as_dict = flags.get('dict', False) - as_tuple = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s] if bare_f: solution = _solve(f[0], *symbols, **flags) - assert type(solution) is list - assert not solution or type(solution[0]) is dict, solution - unpack = lambda s: s # how the solution will be unpacked - if flags.get('dict', False): - pass - elif flags.get('set', False): - unpack = as_tuple - elif len(symbols) == 1: + else: + linear, solution = _solve_system(f, symbols, **flags) + assert type(solution) is list + assert not solution or type(solution[0]) is dict, solution + + # + # postprocessing + ########################################################################### + # capture as_dict flag now (as_set already captured) + as_dict = flags.get('dict', False) + + # define how solution will get unpacked + tuple_format = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s] + if as_dict: + unpack = lambda s: s + elif as_set: + unpack = tuple_format + elif bare_f: + if len(symbols) == 1: unpack = lambda s: [i[symbols[0]] for i in s] elif len(solution) == 1 and len(solution[0]) == len(symbols): + # undetermined linear coeffs solution unpack = lambda s: s[0] elif ordered_symbols: - unpack = as_tuple + unpack = tuple_format + else: + unpack = lambda s: s else: - linear, solution = _solve_system(f, symbols, **flags) - assert type(solution) is list - assert not solution or type(solution[0]) is dict, solution if solution: if linear: unpack = lambda s: s[0] elif ordered_symbols: - unpack = as_tuple + unpack = tuple_format else: unpack = lambda s: s else: unpack = None - # - # postprocessing - ########################################################################### + # Restore masked-off objects if non_inverts and type(solution) is list: solution = [{k: v.subs(non_inverts) for k, v in s.items()} @@ -1287,8 +1294,8 @@ def _solve(f, *symbols, **flags): if len(ex) != 1: ind, dep = f.as_independent(*symbols) # (2*x - c, a*x + b) ex = ind.free_symbols & dep.free_symbols # {x, c} & {a, x, b} -> {x} - if len(ex) != 1: # e.g. (a+b)*x + b - c -> {c}, {a, b, x} - ex = dep.free_symbols - set(symbols) # {x} = {a,b,x}-{a.b} + if len(ex) != 1: # e.g. (a+b)*x + b - c -> {c}, {a, b, x} + ex = dep.free_symbols - set(symbols) # {x} = {a,b,x}-{a.b} if len(ex) == 1: ex = ex.pop() try: @@ -1767,21 +1774,23 @@ def _solve_system(exprs, symbols, **flags): for soldicts in product(*subsols): sols.append(dict(item for sd in soldicts for item in sd.items())) - # legacy "linear" value - linear = len(sols) == 1 return linear, sols polys = [] dens = set() failed = [] - result = False - linear = False + result = [] + solved_syms = [] + linear = True manual = flags.get('manual', False) checkdens = check = flags.get('check', True) for j, g in enumerate(exprs): dens.update(_simple_dens(g, symbols)) i, d = _invert(g, *symbols) + if d in symbols: + if linear: + linear = solve_linear(g, 0, [d])[0] == d g = d - i g = g.as_numer_denom()[0] if manual: @@ -1795,9 +1804,7 @@ def _solve_system(exprs, symbols, **flags): else: failed.append(g) - if not polys: - solved_syms = [] - else: + if polys: if all(p.is_linear for p in polys): n, m = len(polys), len(symbols) matrix = zeros(n, m + 1) @@ -1821,10 +1828,9 @@ def _solve_system(exprs, symbols, **flags): solved_syms = list(result[0].keys()) # there is only one result dict else: solved_syms = [] - else: - linear = True - + # linear doesn't change else: + linear = False if len(symbols) > len(polys): free = set().union(*[p.free_symbols for p in polys]) @@ -1867,6 +1873,7 @@ def _solve_system(exprs, symbols, **flags): result = result or [{}] if failed: + linear = False # For each failed equation, see if we can solve for one of the # remaining symbols from that equation. If so, we update the # solution set and continue with the next failed equation, diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index 95078cb84427..25470039c865 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -343,7 +343,7 @@ def test_solve_io(): assert [solve([*e], {y, x}, **f) for f in flags] == [lod, lod, sxy] # {x, y} assert [solve([*e], x, z, y, **f) for f in flags] == [txzy, lod, sxzy] # x, z (missing), y assert [solve([*e], {x, z, y}, **f) for f in flags] == [lod, lod, sxy] # {x, y, z (missing)} - assert solve((exp(x) - 1, y - 2)) == {x: 0, y: 2} + assert solve((exp(x) - 1, y - 2)) == [{x: 0, y: 2}] d = {x: 0, y: 1} lxy = [(0, 1)] @@ -424,12 +424,12 @@ def test_solve_args(): NotImplementedError, lambda: solve(exp(x) + sin(x) + exp(y) + sin(y))) # failed system # -- when no symbols given, 1 fails - assert solve([y, exp(x) + x]) == {x: -LambertW(1), y: 0} + assert solve([y, exp(x) + x]) == [{x: -LambertW(1), y: 0}] # both fail assert solve( - (exp(x) - x, exp(y) - y)) == {x: -LambertW(-1), y: -LambertW(-1)} + (exp(x) - x, exp(y) - y)) == [{x: -LambertW(-1), y: -LambertW(-1)}] # -- when symbols given - assert solve([y, exp(x) + x], x, y) == {y: 0, x: -LambertW(1)} + assert solve([y, exp(x) + x], x, y) == [(-LambertW(1), 0)] # symbol is a number assert solve(x**2 - pi, pi) == [x**2] # no equations @@ -1659,7 +1659,8 @@ def test_issue_5849(): v = I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4 assert solve(e, *v, manual=True, check=False, dict=True) == ans - assert solve(e, *v, manual=True, check=False) == ans[0] + assert solve(e, *v, manual=True, check=False) == [ + tuple([a.get(i, i) for i in v]) for a in ans] assert solve(e, *v, manual=True) == [] assert solve(e, *v) == [] @@ -2717,13 +2718,12 @@ def test_issue_20902(): def test_issue_21034(): a = symbols('a', real=True) system = [x - cosh(cos(4)), y - sinh(cos(a)), z - tanh(x)] - assert solve(system, x, y, z) == {x: cosh(cos(4)), z: tanh(cosh(cos(4))), - y: sinh(cos(a))} - #Constants inside hyperbolic functions should not be rewritten in terms of exp + assert solve(system, x, y, z) == [(cosh(cos(4)), sinh(cos(a)), tanh(cosh(cos(4))))] + # constants inside hyperbolic functions should not be rewritten + # in terms of exp; rewriting should only happen with hyperbolics + # containing a symbol if interest newsystem = [(exp(x) - exp(-x)) - tanh(x)*(exp(x) + exp(-x)) + x - 5] assert solve(newsystem, x) == {x: 5} - #If the variable of interest is present in hyperbolic function, only then - # it shouuld be rewritten in terms of exp and solved further def test_issue_4886():