Skip to content

Commit

Permalink
dict only returned for linear system
Browse files Browse the repository at this point in the history
  • Loading branch information
smichr committed Aug 7, 2022
1 parent d14ddee commit 124c430
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 44 deletions.
1 change: 0 additions & 1 deletion sympy/core/tests/test_subs.py
Expand Up @@ -882,7 +882,6 @@ def test_issue_19558():
assert (sin(x) + cos(x)).subs(x, oo) == AccumBounds(-2, 2)



def test_issue_22033():
xr = Symbol('xr', real=True)
e = (1/xr)
Expand Down
Expand Up @@ -535,7 +535,6 @@ def test_binomial_Mod():
assert Mod(binomial(253, 113, evaluate=False), r) == Mod(binomial(253, 113), r)



@slow
def test_binomial_Mod_slow():
p, q = 10**5 + 3, 10**9 + 33 # prime modulo
Expand Down
1 change: 0 additions & 1 deletion sympy/functions/elementary/tests/test_trigonometric.py
Expand Up @@ -1502,7 +1502,6 @@ def test_inverses():
assert acot(x).inverse() == cot



def test_real_imag():
a, b = symbols('a b', real=True)
z = a + b*I
Expand Down
1 change: 0 additions & 1 deletion sympy/geometry/tests/test_line.py
Expand Up @@ -72,7 +72,6 @@ def test_angle_between():
Line3D(Point3D(5, 0, 0), z)) == acos(-sqrt(3) / 3)



def test_closing_angle():
a = Ray((0, 0), angle=0)
b = Ray((1, 2), angle=pi/2)
Expand Down
67 changes: 37 additions & 30 deletions sympy/solvers/solvers.py
Expand Up @@ -1146,39 +1146,46 @@ def _has_piecewise(e):
#
# try to get a solution
###########################################################################
as_dict = flags.get('dict', False)
as_tuple = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s]
if bare_f:
solution = _solve(f[0], *symbols, **flags)
assert type(solution) is list
assert not solution or type(solution[0]) is dict, solution
unpack = lambda s: s # how the solution will be unpacked
if flags.get('dict', False):
pass
elif flags.get('set', False):
unpack = as_tuple
elif len(symbols) == 1:
else:
linear, solution = _solve_system(f, symbols, **flags)
assert type(solution) is list
assert not solution or type(solution[0]) is dict, solution

#
# postprocessing
###########################################################################
# capture as_dict flag now (as_set already captured)
as_dict = flags.get('dict', False)

# define how solution will get unpacked
tuple_format = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s]
if as_dict:
unpack = lambda s: s
elif as_set:
unpack = tuple_format
elif bare_f:
if len(symbols) == 1:
unpack = lambda s: [i[symbols[0]] for i in s]
elif len(solution) == 1 and len(solution[0]) == len(symbols):
# undetermined linear coeffs solution
unpack = lambda s: s[0]
elif ordered_symbols:
unpack = as_tuple
unpack = tuple_format
else:
unpack = lambda s: s
else:
linear, solution = _solve_system(f, symbols, **flags)
assert type(solution) is list
assert not solution or type(solution[0]) is dict, solution
if solution:
if linear:
unpack = lambda s: s[0]
elif ordered_symbols:
unpack = as_tuple
unpack = tuple_format
else:
unpack = lambda s: s
else:
unpack = None
#
# postprocessing
###########################################################################

# Restore masked-off objects
if non_inverts and type(solution) is list:
solution = [{k: v.subs(non_inverts) for k, v in s.items()}
Expand Down Expand Up @@ -1287,8 +1294,8 @@ def _solve(f, *symbols, **flags):
if len(ex) != 1:
ind, dep = f.as_independent(*symbols) # (2*x - c, a*x + b)
ex = ind.free_symbols & dep.free_symbols # {x, c} & {a, x, b} -> {x}
if len(ex) != 1: # e.g. (a+b)*x + b - c -> {c}, {a, b, x}
ex = dep.free_symbols - set(symbols) # {x} = {a,b,x}-{a.b}
if len(ex) != 1: # e.g. (a+b)*x + b - c -> {c}, {a, b, x}
ex = dep.free_symbols - set(symbols) # {x} = {a,b,x}-{a.b}
if len(ex) == 1:
ex = ex.pop()
try:
Expand Down Expand Up @@ -1767,21 +1774,23 @@ def _solve_system(exprs, symbols, **flags):
for soldicts in product(*subsols):
sols.append(dict(item for sd in soldicts
for item in sd.items()))
# legacy "linear" value
linear = len(sols) == 1
return linear, sols

polys = []
dens = set()
failed = []
result = False
linear = False
result = []
solved_syms = []
linear = True
manual = flags.get('manual', False)
checkdens = check = flags.get('check', True)

for j, g in enumerate(exprs):
dens.update(_simple_dens(g, symbols))
i, d = _invert(g, *symbols)
if d in symbols:
if linear:
linear = solve_linear(g, 0, [d])[0] == d
g = d - i
g = g.as_numer_denom()[0]
if manual:
Expand All @@ -1795,9 +1804,7 @@ def _solve_system(exprs, symbols, **flags):
else:
failed.append(g)

if not polys:
solved_syms = []
else:
if polys:
if all(p.is_linear for p in polys):
n, m = len(polys), len(symbols)
matrix = zeros(n, m + 1)
Expand All @@ -1821,10 +1828,9 @@ def _solve_system(exprs, symbols, **flags):
solved_syms = list(result[0].keys()) # there is only one result dict
else:
solved_syms = []
else:
linear = True

# linear doesn't change
else:
linear = False
if len(symbols) > len(polys):

free = set().union(*[p.free_symbols for p in polys])
Expand Down Expand Up @@ -1867,6 +1873,7 @@ def _solve_system(exprs, symbols, **flags):
result = result or [{}]

if failed:
linear = False
# For each failed equation, see if we can solve for one of the
# remaining symbols from that equation. If so, we update the
# solution set and continue with the next failed equation,
Expand Down
20 changes: 10 additions & 10 deletions sympy/solvers/tests/test_solvers.py
Expand Up @@ -343,7 +343,7 @@ def test_solve_io():
assert [solve([*e], {y, x}, **f) for f in flags] == [lod, lod, sxy] # {x, y}
assert [solve([*e], x, z, y, **f) for f in flags] == [txzy, lod, sxzy] # x, z (missing), y
assert [solve([*e], {x, z, y}, **f) for f in flags] == [lod, lod, sxy] # {x, y, z (missing)}
assert solve((exp(x) - 1, y - 2)) == {x: 0, y: 2}
assert solve((exp(x) - 1, y - 2)) == [{x: 0, y: 2}]

d = {x: 0, y: 1}
lxy = [(0, 1)]
Expand Down Expand Up @@ -424,12 +424,12 @@ def test_solve_args():
NotImplementedError, lambda: solve(exp(x) + sin(x) + exp(y) + sin(y)))
# failed system
# -- when no symbols given, 1 fails
assert solve([y, exp(x) + x]) == {x: -LambertW(1), y: 0}
assert solve([y, exp(x) + x]) == [{x: -LambertW(1), y: 0}]
# both fail
assert solve(
(exp(x) - x, exp(y) - y)) == {x: -LambertW(-1), y: -LambertW(-1)}
(exp(x) - x, exp(y) - y)) == [{x: -LambertW(-1), y: -LambertW(-1)}]
# -- when symbols given
assert solve([y, exp(x) + x], x, y) == {y: 0, x: -LambertW(1)}
assert solve([y, exp(x) + x], x, y) == [(-LambertW(1), 0)]
# symbol is a number
assert solve(x**2 - pi, pi) == [x**2]
# no equations
Expand Down Expand Up @@ -1659,7 +1659,8 @@ def test_issue_5849():

v = I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4
assert solve(e, *v, manual=True, check=False, dict=True) == ans
assert solve(e, *v, manual=True, check=False) == ans[0]
assert solve(e, *v, manual=True, check=False) == [
tuple([a.get(i, i) for i in v]) for a in ans]
assert solve(e, *v, manual=True) == []
assert solve(e, *v) == []

Expand Down Expand Up @@ -2717,13 +2718,12 @@ def test_issue_20902():
def test_issue_21034():
a = symbols('a', real=True)
system = [x - cosh(cos(4)), y - sinh(cos(a)), z - tanh(x)]
assert solve(system, x, y, z) == {x: cosh(cos(4)), z: tanh(cosh(cos(4))),
y: sinh(cos(a))}
#Constants inside hyperbolic functions should not be rewritten in terms of exp
assert solve(system, x, y, z) == [(cosh(cos(4)), sinh(cos(a)), tanh(cosh(cos(4))))]
# constants inside hyperbolic functions should not be rewritten
# in terms of exp; rewriting should only happen with hyperbolics
# containing a symbol if interest
newsystem = [(exp(x) - exp(-x)) - tanh(x)*(exp(x) + exp(-x)) + x - 5]
assert solve(newsystem, x) == {x: 5}
#If the variable of interest is present in hyperbolic function, only then
# it shouuld be rewritten in terms of exp and solved further


def test_issue_4886():
Expand Down

0 comments on commit 124c430

Please sign in to comment.