Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge pull request #1411 from smichr/rand

Randomization fixes
  • Loading branch information...
commit 9ff4edf72e293ab00d7204fb0309b7ab90acdf44 2 parents 262bdc4 + bf84abd
Christopher Smith smichr authored
40 sympy/core/basic.py
View
@@ -1501,44 +1501,54 @@ class preorder_traversal(object):
.args, which in many cases can be arbitrary.
Parameters
- ----------
+ ==========
node : sympy expression
The expression to traverse.
+ key : (default None) sort key
+ The key used to sort args of Basic objects. When None, args of Basic
+ objects are processed in arbitrary order.
Yields
- ------
+ ======
subtree : sympy expression
All of the subtrees in the tree.
Examples
- --------
+ ========
>>> from sympy import symbols
+ >>> from sympy import symbols, default_sort_key
>>> from sympy.core.basic import preorder_traversal
>>> x, y, z = symbols('x y z')
- >>> list(preorder_traversal(z*(x+y))) in ( # any of these are possible
- ... [z*(x + y), z, x + y, x, y], [z*(x + y), z, x + y, y, x],
- ... [z*(x + y), x + y, x, y, z], [z*(x + y), x + y, y, x, z])
- True
- >>> list(preorder_traversal((x, (y, z))))
- [(x, (y, z)), x, (y, z), y, z]
+
+ The nodes are returned in the order that they are encountered unless key
+ is given.
+
+ >>> list(preorder_traversal((x + y)*z, key=None)) # doctest: +SKIP
+ [z*(x + y), z, x + y, y, x]
+ >>> list(preorder_traversal((x + y)*z, key=default_sort_key))
+ [z*(x + y), z, x + y, x, y]
"""
- def __init__(self, node):
+ def __init__(self, node, key=None):
self._skip_flag = False
- self._pt = self._preorder_traversal(node)
+ self._pt = self._preorder_traversal(node, key)
- def _preorder_traversal(self, node):
+ def _preorder_traversal(self, node, key):
yield node
if self._skip_flag:
self._skip_flag = False
return
if isinstance(node, Basic):
- for arg in node.args:
- for subtree in self._preorder_traversal(arg):
+ args = node.args
+ if key:
+ args = list(args)
+ args.sort(key=key)
+ for arg in args:
+ for subtree in self._preorder_traversal(arg, key):
yield subtree
elif iterable(node):
for item in node:
- for subtree in self._preorder_traversal(item):
+ for subtree in self._preorder_traversal(item, key):
yield subtree
def skip(self):
40 sympy/core/function.py
View
@@ -1347,7 +1347,7 @@ class Subs(Expr):
An example with several variables:
- >>> Subs(f(x)*sin(y)+z, (x, y), (0, 1))
+ >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
Subs(z + f(_x)*sin(_y), (_x, _y), (0, 1))
>>> _.doit()
z + f(0)*sin(1)
@@ -1359,14 +1359,12 @@ def __new__(cls, expr, variables, point, **assumptions):
variables = Tuple(*sympify(variables))
if uniq(variables) != variables:
- repeated = repeated = [ v for v in set(variables)
+ repeated = [ v for v in set(variables)
if list(variables).count(v) > 1 ]
raise ValueError('cannot substitute expressions %s more than '
'once.' % repeated)
- if not is_sequence(point, Tuple):
- point = [point]
- point = Tuple(*sympify(point))
+ point = Tuple(*sympify(point if is_sequence(point, Tuple) else [point]))
if len(point) != len(variables):
raise ValueError('Number of point values must be the same as '
@@ -1417,32 +1415,16 @@ def free_symbols(self):
def __eq__(self, other):
if not isinstance(other, Subs):
return False
- if (len(self.point) != len(other.point) or
- self.free_symbols != other.free_symbols or
- sorted(self.point) != sorted(other.point)):
+
+ if len(self.expr.free_symbols) != len(other.expr.free_symbols):
return False
- # non-repeated point args
- selfargs = [ v[0] for v in sorted(zip(self.variables, self.point),
- key = lambda v: v[1]) if list(self.point.args).count(v[1]) == 1 ]
- otherargs = [ v[0] for v in sorted(zip(other.variables, other.point),
- key = lambda v: v[1]) if list(other.point.args).count(v[1]) == 1 ]
- # find repeated point values and subs each associated variable
- # for a single symbol
- selfrepargs = []
- otherrepargs = []
- if uniq(self.point) != self.point:
- repeated = uniq([ v for v in self.point if
- list(self.point.args).count(v) > 1 ])
- repswap = dict(zip(repeated, [ C.Dummy() for _ in
- xrange(len(repeated)) ]))
- selfrepargs = [ (self.variables[i], repswap[v]) for i, v in
- enumerate(self.point) if v in repeated ]
- otherrepargs = [ (other.variables[i], repswap[v]) for i, v in
- enumerate(other.point) if v in repeated ]
-
- return self.expr.subs(selfrepargs) == other.expr.subs(
- tuple(zip(otherargs, selfargs))).subs(otherrepargs)
+ # replace points with dummies, each unique pt getting its own dummy
+ pts = set(self.point.args + other.point.args)
+ d = dict([(p, Dummy()) for p in pts])
+ eq = lambda e: \
+ e.expr.xreplace(dict(zip(e.variables, [d[p] for p in e.point])))
+ return eq(self) == eq(other)
def __ne__(self, other):
return not(self == other)
7 sympy/core/tests/test_basic.py
View
@@ -3,6 +3,8 @@
from sympy.core.basic import Basic, Atom, preorder_traversal
from sympy.core.singleton import S, Singleton
+from sympy.core.symbol import symbols
+from sympy.utilities.misc import default_sort_key
from sympy.utilities.pytest import raises
@@ -117,3 +119,8 @@ def test_preorder_traversal():
if i == b2:
pt.skip()
assert result == [expr, b21, b2, b1, b3, b2]
+
+ w, x, y, z = symbols('w:z')
+ expr = z + w*(x+y)
+ assert list(preorder_traversal([expr], key=default_sort_key)) == \
+ [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
6 sympy/core/tests/test_functions.py
View
@@ -136,6 +136,12 @@ def test_Lambda_equality():
def test_Subs():
assert Subs(f(x), x, 0).doit() == f(0)
assert Subs(f(x**2), x**2, 0).doit() == f(0)
+ assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \
+ Subs(f(x, y, z), (x, y, z), (0, 0, 1))
+ assert Subs(f(x, y), (x, y, z), (0, 1, 1)) == \
+ Subs(f(x, y), (x, y, z), (0, 1, 2))
+ assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \
+ Subs(f(x, y) + z, (x, y, z), (0, 1, 0))
assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1)
assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1)
raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
73 sympy/solvers/solvers.py
View
@@ -434,10 +434,10 @@ def solve(f, *symbols, **flags):
[3]
>>> solve(Poly(x - 3), x)
[3]
- >>> solve(x**2 - y**2, x)
- [-y, y]
- >>> solve(x**4 - 1, x)
- [-1, 1, -I, I]
+ >>> set(solve(x**2 - y**2, x))
+ set([-y, y])
+ >>> set(solve(x**4 - 1, x))
+ set([-1, 1, -I, I])
* single expression with no symbol that is in the expression
@@ -455,9 +455,9 @@ def solve(f, *symbols, **flags):
>>> solve(x - 3)
[3]
- >>> solve(x**2 - y**2)
+ >>> solve(x**2 - y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
- >>> solve(z**2*x**2 - z**2*y**2)
+ >>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(z**2*x - z**2*y**2)
[{x: y**2}]
@@ -473,8 +473,8 @@ def solve(f, *symbols, **flags):
[x + f(x)]
>>> solve(f(x).diff(x) - f(x) - x, f(x))
[-x + Derivative(f(x), x)]
- >>> solve(x + exp(x)**2, exp(x))
- [-sqrt(-x), sqrt(-x)]
+ >>> set(solve(x + exp(x)**2, exp(x)))
+ set([-sqrt(-x), sqrt(-x)])
* To solve for a *symbol* implicitly, use 'implicit=True':
@@ -501,17 +501,17 @@ def solve(f, *symbols, **flags):
* that are nonlinear
- >>> solve((a + b)*x - b**2 + 2, a, b)
- [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]
+ >>> set(solve((a + b)*x - b**2 + 2, a, b))
+ set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))])
* if there is no linear solution then the first successful
attempt for a nonlinear solution will be returned
- >>> solve(x**2 - y**2, x, y)
+ >>> solve(x**2 - y**2, x, y) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(x**2 - y**2/exp(x), x, y)
[{x: 2*LambertW(y/2)}]
- >>> solve(x**2 - y**2/exp(x), y, x)
+ >>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP
[{y: -x*exp(x/2)}, {y: x*exp(x/2)}]
* iterable of one or more of the above
@@ -543,8 +543,8 @@ def solve(f, *symbols, **flags):
* when the system is not linear
- >>> solve([x**2 + y -2, y**2 - 4], x, y)
- [(-2, -2), (0, 2), (0, 2), (2, -2)]
+ >>> set(solve([x**2 + y -2, y**2 - 4], x, y))
+ set([(-2, -2), (0, 2), (2, -2)])
* if no symbols are given, all free symbols will be selected and a list
of mappings returned
@@ -659,7 +659,7 @@ def _sympified_list(w):
# we do this to make the results returned canonical in case f
# contains a system of nonlinear equations; all other cases should
# be unambiguous
- symbols = sorted(symbols, key=lambda i: i.sort_key())
+ symbols = sorted(symbols, key=default_sort_key)
# we can solve for Function and Derivative instances by replacing them
# with Dummy symbols or functions
@@ -1287,9 +1287,10 @@ def _solve_system(exprs, symbols, **flags):
else:
if len(symbols) != len(polys):
from sympy.utilities.iterables import subsets
- free = list(reduce(set.union,
- [p.free_symbols for p in polys], set()
- ).intersection(symbols))
+ from sympy.core.compatibility import set_union
+
+ free = set_union(*[p.free_symbols for p in polys])
+ free = list(free.intersection(symbols))
free.sort(key=default_sort_key)
for syms in subsets(free, len(polys)):
try:
@@ -1341,16 +1342,26 @@ def _solve_system(exprs, symbols, **flags):
result = [result]
else:
result = [{}]
+
+ def _ok_syms(e, sort=False):
+ rv = (e.free_symbols - solved_syms) & legal
+ if sort:
+ rv = list(rv)
+ rv.sort(key=default_sort_key)
+ return rv
+
solved_syms = set(solved_syms) # set of symbols we have solved for
legal = set(symbols) # what we are interested in
simplify_flag = flags.get('simplify', None)
do_simplify = flags.get('simplify', True)
- # sort so equation with the fewest potential symbols is first; break ties with
- # count_ops and sort_key
- short = sift(failed, lambda x: len((x.free_symbols - solved_syms) & legal))
+ # sort so equation with the fewest potential symbols is first;
+ # break ties with count_ops and default_sort_key
+ short = sift(failed, lambda x: len(_ok_syms(x)))
failed = []
- for k in sorted(short):
- failed.extend(sorted(short[k], key=lambda x: (x.count_ops(), x.sort_key())))
+ for k in sorted(short, key=default_sort_key):
+ failed.extend(sorted(sorted(short[k],
+ key=lambda x: x.count_ops()),
+ key=default_sort_key))
for eq in failed:
newresult = []
got_s = None
@@ -1365,7 +1376,7 @@ def _solve_system(exprs, symbols, **flags):
continue
# search for a symbol amongst those available that
# can be solved for
- ok_syms = (eq2.free_symbols - solved_syms) & legal
+ ok_syms = _ok_syms(eq2, sort=True)
if not ok_syms:
break # skip as it's independent of desired symbols
for s in ok_syms:
@@ -1591,7 +1602,7 @@ def solve_linear_system(system, *symbols, **flags):
matrix = system[:, :]
syms = list(symbols)
- i, m = 0, matrix.cols-1 # don't count augmentation
+ i, m = 0, matrix.cols - 1 # don't count augmentation
while i < matrix.rows:
if i == m:
@@ -1611,12 +1622,12 @@ def solve_linear_system(system, *symbols, **flags):
break
else:
if matrix[i, m]:
- # we need to know this this is always zero or not. We
+ # we need to know if this is always zero or not. We
# assume that if there are free symbols that it is not
# identically zero (or that there is more than one way
# to make this zero. Otherwise, if there are none, this
# is a constant and we assume that it does not simplify
- # to zero XXX are there better ways to test this/
+ # to zero XXX are there better ways to test this?
if not matrix[i, m].free_symbols:
return None # no solution
@@ -1664,7 +1675,7 @@ def solve_linear_system(system, *symbols, **flags):
# divide all elements in the current row by the pivot
matrix.row(i, lambda x, _: x * pivot_inv)
- for k in xrange(i+1, matrix.rows):
+ for k in xrange(i + 1, matrix.rows):
if matrix[k, i]:
coeff = matrix[k, i]
@@ -1689,7 +1700,7 @@ def solve_linear_system(system, *symbols, **flags):
content = matrix[k, m]
# run back-substitution for variables
- for j in xrange(k+1, m):
+ for j in xrange(k + 1, m):
content -= matrix[k, j]*solutions[syms[j]]
if do_simplify:
@@ -1703,13 +1714,13 @@ def solve_linear_system(system, *symbols, **flags):
elif len(syms) > matrix.rows:
# this system will have infinite number of solutions
# dependent on exactly len(syms) - i parameters
- k, solutions = i-1, {}
+ k, solutions = i - 1, {}
while k >= 0:
content = matrix[k, m]
# run back-substitution for variables
- for j in xrange(k+1, i):
+ for j in xrange(k + 1, i):
content -= matrix[k, j]*solutions[syms[j]]
# run back-substitution for parameters
209 sympy/solvers/tests/test_solvers.py
View
@@ -30,9 +30,6 @@ def guess_solve_strategy(eq, symbol):
return False
def test_guess_poly():
- """
- See solvers.guess_solve_strategy
- """
# polynomial equations
assert guess_solve_strategy( S(4), x ) #== GS_POLY
assert guess_solve_strategy( x, x ) #== GS_POLY
@@ -76,8 +73,8 @@ def test_guess_transcendental():
def test_solve_args():
#implicit symbol to solve for
- assert set(int(tmp) for tmp in solve(x**2-4)) == set([2,-2])
- assert solve([x+y-3,x-y-5]) == {x: 4, y: -1}
+ assert set(solve(x**2 - 4)) == set([S(2), -S(2)])
+ assert solve([x + y - 3, x - y - 5]) == {x: 4, y: -1}
#no symbol to solve for
assert solve(42) == []
assert solve([1, 2]) == []
@@ -118,14 +115,14 @@ def test_solve_args():
assert solve((x + y - 2, 2*x + 2*y - 4)) == {x: -y + 2}
def test_solve_polynomial1():
- assert solve(3*x-2, x) == [Rational(2,3)]
- assert solve(Eq(3*x, 2), x) == [Rational(2,3)]
+ assert solve(3*x-2, x) == [Rational(2, 3)]
+ assert solve(Eq(3*x, 2), x) == [Rational(2, 3)]
- assert solve(x**2-1, x) in [[-1, 1], [1, -1]]
- assert solve(Eq(x**2, 1), x) in [[-1, 1], [1, -1]]
+ assert set(solve(x**2 - 1, x)) == set([-S(1), S(1)])
+ assert set(solve(Eq(x**2, 1), x)) == set([-S(1), S(1)])
- assert solve( x - y**3, x) == [y**3]
- assert sorted(solve( x - y**3, y)) == sorted([
+ assert solve(x - y**3, x) == [y**3]
+ assert set(solve(x - y**3, y)) == set([
(-x**Rational(1,3))/2 + I*sqrt(3)*x**Rational(1,3)/2,
x**Rational(1,3),
(-x**Rational(1,3))/2 - I*sqrt(3)*x**Rational(1,3)/2,
@@ -149,8 +146,8 @@ def test_solve_polynomial1():
S(4),
-2 - 3**Rational(1,2) ])
- assert sorted(solve((x**2 - 1)**2 - a, x)) == \
- sorted([sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)),
+ assert set(solve((x**2 - 1)**2 - a, x)) == \
+ set([sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)),
sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))])
def test_solve_polynomial2():
@@ -225,7 +222,7 @@ def test_tsolve():
assert set(solve((a*x+b)*(exp(x)-3), x)) == set([-b/a, log(3)])
assert solve(cos(x)-y, x) == [acos(y)]
assert solve(2*cos(x)-y,x)== [acos(y/2)]
- assert solve(Eq(cos(x), sin(x)), x) == [-3*pi/4, pi/4]
+ assert set(solve(Eq(cos(x), sin(x)), x)) == set([-3*pi/4, pi/4])
assert set(solve(exp(x) + exp(-x) - y, x)) == set([
log(y/2 - sqrt(y**2 - 4)/2),
@@ -240,8 +237,8 @@ def test_tsolve():
assert solve(x+2**x, x) == [-LambertW(log(2))/log(2)]
assert solve(3*x+5+2**(-5*x+3), x) in [
[-((25*log(2) - 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2)))],
- [Rational(-5, 3) + LambertW(log(2**(-10240*2**(Rational(1, 3))/3)))/(5*log(2))],
- [-Rational(5,3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))],
+ [-Rational(5, 3) + LambertW(log(2**(-10240*2**(Rational(1, 3))/3)))/(5*log(2))],
+ [-Rational(5, 3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))],
[(-25*log(2) + 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2))],
[-((25*log(2) - 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3)))/(15*log(2))],
[-(25*log(2) - 3*LambertW(log(2**(-10240*2**Rational(1, 3)/3))))/(15*log(2))],
@@ -385,8 +382,10 @@ def test_issue_1694():
eq = 4*3**(5*x + 2) - 7
ans = solve(eq, x)
assert len(ans) == 5 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans)
- assert solve(log(x**2) - y**2/exp(x), x, y) == [{y: -sqrt(exp(x)*log(x**2))},
- {y: sqrt(exp(x)*log(x**2))}]
+ assert solve(log(x**2) - y**2/exp(x), x, y, set=True) == \
+ ([y], set([
+ (-sqrt(exp(x)*log(x**2)),),
+ (sqrt(exp(x)*log(x**2)),)]))
assert solve(x**2*z**2 - z**2*y**2) in ([{x: y}, {x: -y}], [{x: -y}, {x: y}])
assert solve((x - 1)/(1 + 1/(x - 1))) == []
assert solve(x**(y*z) - x, x) == [1]
@@ -401,7 +400,7 @@ def test_issue_1694():
# 1387
assert solve(2*x/(x + 2) - 1,x) == [2]
# 1397
- assert solve((x**2/(7 - x)).diff(x)) == [0, 14]
+ assert set(solve((x**2/(7 - x)).diff(x))) == set([S(0), S(14)])
# 1596
f = Function('f')
assert solve((3 - 5*x/f(x))*f(x), f(x)) == [5*x/3]
@@ -409,22 +408,23 @@ def test_issue_1694():
assert solve(1/(5 + x)**(S(1)/5) - 9, x) == [-295244/S(59049)]
assert solve(sqrt(x) + sqrt(sqrt(x)) - 4) == [-9*sqrt(17)/2 + 49*S.Half]
- assert solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4)) in \
+ assert set(solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4))) in \
[
- [2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)],
- [log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)],
+ set([2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)]),
+ set([log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)]),
]
- assert solve(Poly(exp(x) + exp(-x) - 4)) == [log(-sqrt(3) + 2), log(sqrt(3) + 2)]
- assert solve(x**y + x**(2*y) - 1, x) == \
- [(-S.Half + sqrt(5)/2)**(1/y), (-S.Half - sqrt(5)/2)**(1/y)]
+ assert set(solve(Poly(exp(x) + exp(-x) - 4))) == \
+ set([log(-sqrt(3) + 2), log(sqrt(3) + 2)])
+ assert set(solve(x**y + x**(2*y) - 1, x)) == \
+ set([(-S.Half + sqrt(5)/2)**(1/y), (-S.Half - sqrt(5)/2)**(1/y)])
assert solve(exp(x/y)*exp(-z/y) - 2, y) == [(x - z)/log(2)]
assert solve(x**z*y**z - 2, z) in [[log(2)/(log(x) + log(y))], [log(2)/(log(x*y))]]
# if you do inversion too soon then multiple roots as for the following will
# be missed, e.g. if exp(3*x) = exp(3) -> 3*x = 3
E = S.Exp1
- assert solve(exp(3*x) - exp(3), x) == \
- [1, log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)]
+ assert set(solve(exp(3*x) - exp(3), x)) == \
+ set([S(1), log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)])
def test_issue_2098():
x = Symbol('x', real=True)
@@ -461,9 +461,9 @@ def test_checking():
def test_issue_1572_1364_1368():
assert solve((sqrt(x**2 - 1) - 2)) in ([sqrt(5), -sqrt(5)],
[-sqrt(5), sqrt(5)])
- assert solve((2**exp(y**2/x) + 2)/(x**2 + 15), y) == [
+ assert set(solve((2**exp(y**2/x) + 2)/(x**2 + 15), y)) == set([
-sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi)),
- sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi))]
+ sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi))])
C1, C2 = symbols('C1 C2')
f = Function('f')
@@ -480,65 +480,65 @@ def test_issue_1572_1364_1368():
assert solve(1 - log(a + 4*x**2), x) in (
[-sqrt(-a + E)/2, sqrt(-a + E)/2],
[sqrt(-a + E)/2, -sqrt(-a + E)/2],)
- assert solve((a**2 + 1) * (sin(a*x) + cos(a*x)), x) == [-pi/(4*a), 3*pi/(4*a)]
+ assert set(solve((a**2 + 1) * (sin(a*x) + cos(a*x)), x)) == set([-pi/(4*a), 3*pi/(4*a)])
assert solve(3 - (sinh(a*x) + cosh(a*x)), x) == [2*atanh(S.Half)/a]
- assert solve(3-(sinh(a*x) + cosh(a*x)**2), x) == \
- [
+ assert set(solve(3-(sinh(a*x) + cosh(a*x)**2), x)) == \
+ set([
2*atanh(-1 + sqrt(2))/a,
2*atanh(S(1)/2 + sqrt(5)/2)/a,
2*atanh(-sqrt(2) - 1)/a,
2*atanh(-sqrt(5)/2 + S(1)/2)/a
- ]
+ ])
assert solve(atan(x) - 1) == [tan(1)]
def test_issue_2033():
r, t = symbols('r,t')
- assert solve([r - x**2 - y**2, tan(t) - y/x], [x, y]) == \
- [
+ assert set(solve([r - x**2 - y**2, tan(t) - y/x], [x, y])) == \
+ set([
(-sqrt(r*sin(t)**2)/tan(t), -sqrt(r*sin(t)**2)),
- (sqrt(r*sin(t)**2)/tan(t), sqrt(r*sin(t)**2))]
+ (sqrt(r*sin(t)**2)/tan(t), sqrt(r*sin(t)**2))])
assert solve([exp(x) - sin(y), 1/y - 3], [x, y]) == \
[(log(sin(S(1)/3)), S(1)/3)]
assert solve([exp(x) - sin(y), 1/exp(y) - 3], [x, y]) == \
[(log(-sin(log(3))), -log(3))]
- assert solve([exp(x) - sin(y), y**2 - 4], [x, y]) == \
- [(log(-sin(2)), -2), (log(sin(2)), 2)]
+ assert set(solve([exp(x) - sin(y), y**2 - 4], [x, y])) == \
+ set([(log(-sin(2)), -S(2)), (log(sin(2)), S(2))])
eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3]
- assert solve(eqs) == \
- [
- {x: log(-sqrt(-z**2 - sin(log(3)))), y: -log(3)},
- {x: log(sqrt(-z**2 - sin(log(3)))), y: -log(3)}]
- assert solve(eqs, x, z) == \
- [
- {x: log(-sqrt(-z**2 + sin(y)))},
- {x: log(sqrt(-z**2 + sin(y)))}]
- assert solve(eqs, x, y) == \
- [
+ assert solve(eqs, set=True) == \
+ ([x, y], set([
(log(-sqrt(-z**2 - sin(log(3)))), -log(3)),
- (log(sqrt(-z**2 - sin(log(3)))), -log(3))]
- assert solve(eqs, y, z) == \
- [
+ (log(sqrt(-z**2 - sin(log(3)))), -log(3))]))
+ assert solve(eqs, x, z, set=True) == \
+ ([x], set([
+ (log(-sqrt(-z**2 + sin(y))),),
+ (log(sqrt(-z**2 + sin(y))),)]))
+ assert set(solve(eqs, x, y)) == \
+ set([
+ (log(-sqrt(-z**2 - sin(log(3)))), -log(3)),
+ (log(sqrt(-z**2 - sin(log(3)))), -log(3))])
+ assert set(solve(eqs, y, z)) == \
+ set([
(-log(3), -sqrt(-exp(2*x) - sin(log(3)))),
- (-log(3), sqrt(-exp(2*x) - sin(log(3))))]
+ (-log(3), sqrt(-exp(2*x) - sin(log(3))))])
eqs = [exp(x)**2 - sin(y) + z, 1/exp(y) - 3]
- assert solve(eqs) == \
+ assert solve(eqs, set=True) == ([x, y], set(
[
- {x: log(-sqrt(-z - sin(log(3)))), y: -log(3)},
- {x: log(sqrt(-z - sin(log(3)))), y: -log(3)}]
- assert solve(eqs, x, z) == \
+ (log(-sqrt(-z - sin(log(3)))), -log(3)),
+ (log(sqrt(-z - sin(log(3)))), -log(3))]))
+ assert solve(eqs, x, z, set=True) == ([x], set(
[
- {x: log(-sqrt(-z + sin(y)))},
- {x: log(sqrt(-z + sin(y)))}]
- assert solve(eqs, x, y) == \
+ (log(-sqrt(-z + sin(y))),),
+ (log(sqrt(-z + sin(y))),)]))
+ assert set(solve(eqs, x, y)) == set(
[
(log(-sqrt(-z - sin(log(3)))), -log(3)),
- (log(sqrt(-z - sin(log(3)))), -log(3))]
+ (log(sqrt(-z - sin(log(3)))), -log(3))])
assert solve(eqs, z, y) == \
[(-exp(2*x) - sin(log(3)), -log(3))]
- assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4)) == \
- [{x: 1, y: 3}, {x: 3, y: 1}]
- assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y) == \
- [(1, 3), (3, 1)]
+ assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), set=True) == (
+ [x, y], set([(S(1), S(3)), (S(3), S(1))]))
+ assert set(solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y)) == \
+ set([(S(1), S(3)), (S(3), S(1))])
def test_issue_2236():
lam, a0, conc = symbols('lam a0 conc')
@@ -559,11 +559,13 @@ def test_issue_2236_float():
assert len(solve(eqs, sym, rational=False, check=False, simplify=False)) == 2
def test_issue_2668():
- assert solve([x**2 + y + 4], [x]) == [(-sqrt(-y - 4),), (sqrt(-y - 4),)]
+ assert set(solve([x**2 + y + 4], [x])) == \
+ set([(-sqrt(-y - 4),), (sqrt(-y - 4),)])
def test_polysys():
- assert solve([x**2 + 2/y - 2 , x + y - 3], [x, y]) == \
- [(1, 2), (1 + sqrt(5), 2 - sqrt(5)), (1 - sqrt(5), 2 + sqrt(5))]
+ assert set(solve([x**2 + 2/y - 2 , x + y - 3], [x, y])) == \
+ set([(S(1), S(2)), (1 + sqrt(5), 2 - sqrt(5)),
+ (1 - sqrt(5), 2 + sqrt(5))])
assert solve([x**2 + y - 2, x**2 + y]) == []
# the ordering should be whatever the user requested
assert solve([x**2 + y - 3, x - y - 4], (x, y)) != solve([x**2 + y - 3, x - y - 4], (y, x))
@@ -626,10 +628,11 @@ def s_check(rv, ans):
eq = sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x))
assert check(unrad(eq),
(16*x**3 - 9*x**2, [], []))
- assert solve(eq, check=False) == [0, S(9)/16]
+ assert set(solve(eq, check=False)) == set([S(0), S(9)/16])
assert solve(eq) == []
# but this one really does have those solutions
- assert solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x))) == [0, S(9)/16]
+ assert set(solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x)))) == \
+ set([S.Zero, S(9)/16])
'''real_root changes the value of the result if the solution is
simplified; `a` in the text below is the root that is not 4/5:
@@ -661,8 +664,8 @@ def s_check(rv, ans):
ans = solve(sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5)
assert all(abs(eq.subs(x, i).n()) < 1e-10 for i in (ra, rb)) and \
len(ans) == 2 and \
- sorted([i.n(chop=True) for i in ans]) == \
- sorted([i.n(chop=True) for i in (ra, rb)])
+ set([i.n(chop=True) for i in ans]) == \
+ set([i.n(chop=True) for i in (ra, rb)])
ans = solve(sqrt(x) + sqrt(x + 1) - \
sqrt(1 - x) - sqrt(2 + x))
@@ -694,15 +697,15 @@ def s_check(rv, ans):
assert solve(Eq(x, sqrt(x + 6))) == [3]
assert solve(Eq(x + sqrt(x - 4), 4)) == [4]
assert solve(Eq(1, x + sqrt(2*x - 3))) == []
- assert solve(Eq(sqrt(5*x + 6) - 2, x)) == [-1, 2]
- assert solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2)) == [5, 13]
+ assert set(solve(Eq(sqrt(5*x + 6) - 2, x))) == set([-S(1), S(2)])
+ assert set(solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2))) == set([S(5), S(13)])
assert solve(Eq(sqrt(x + 7) + 2, sqrt(3 - x))) == [-6]
# http://www.purplemath.com/modules/solverad.htm
assert solve((2*x - 5)**Rational(1, 3) - 3) == [16]
assert solve((x**3 - 3*x**2)**Rational(1, 3) + 1 - x) == []
- assert solve(x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4)) == \
- [-S(1)/2, -S(1)/3]
- assert solve(sqrt(2*x**2 - 7) - (3 - x)) == [-8, 2]
+ assert set(solve(x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4))) == \
+ set([-S(1)/2, -S(1)/3])
+ assert set(solve(sqrt(2*x**2 - 7) - (3 - x))) == set([-S(8), S(2)])
assert solve(sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4)) == [0]
assert solve(sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1)) == [5]
assert solve(sqrt(x)*sqrt(x - 7) - 12) == [16]
@@ -784,14 +787,13 @@ def test_issue_2750():
)
ans = [{
- dI4: -I3 + 3*I5 - 2*Q4,
+ dQ4: I3 - I5,
dI1: -4*I2 - 8*I3 - 4*I5 - 6*I6 + 24,
+ I4: I3 - I5,
dQ2: I2,
- I1: I2 + I3,
Q2: 2*I3 + 2*I5 + 3*I6,
- dQ4: I3 - I5,
- I4: I3 - I5,
- }]
+ I1: I2 + I3,
+ Q4: -I3/2 + 3*I5/2 - dI4/2}]
assert solve(e, I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4, manual=True) == ans
# the matrix solver (tested below) doesn't like this because it produces
# a zero row in the matrix. Is this related to issue 1452?
@@ -837,18 +839,21 @@ def test_issue_2802():
{f(x): 3*D}
assert solve([f(x) - 3*f(x).diff(x), f(x)**2 - y + 4], f(x), y) == \
[{f(x): 3*D, y: 9*D**2 + 4}]
- assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), h(a), g(a)) == \
- [{g(a): -sqrt(h(a)**2 + G/f(a)**2)},
- {g(a): sqrt(h(a)**2 + G/f(a)**2)}]
+ assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a),
+ h(a), g(a), set=True) == \
+ ([g(a)], set([
+ (-sqrt(h(a)**2 + G/f(a)**2),),
+ (sqrt(h(a)**2 + G/f(a)**2),)]))
args = [f(x).diff(x, 2)*(f(x) + g(x)) - g(x)**2 + 2, f(x), g(x)]
- assert solve(*args) == \
- [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]
+ assert set(solve(*args)) == \
+ set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))])
eqs = [f(x)**2 + g(x) - 2*f(x).diff(x), g(x)**2 - 4]
- assert solve(eqs, f(x), g(x)) == \
- [{g(x): 2, f(x): -sqrt(2*D - 2)},
- {g(x): 2, f(x): sqrt(2*D - 2)},
- {g(x): -2, f(x): -sqrt(2*D + 2)},
- {g(x): -2, f(x): sqrt(2*D + 2)}]
+ assert solve(eqs, f(x), g(x), set=True) == \
+ ([f(x), g(x)], set([
+ (-sqrt(2*D - 2), S(2)),
+ (sqrt(2*D - 2), S(2)),
+ (-sqrt(2*D + 2), -S(2)),
+ (sqrt(2*D + 2), -S(2))]))
# the underlying problem was in solve_linear that was not masking off
# anything but a Mul or Add; it now raises an error if it gets anything
@@ -872,8 +877,8 @@ def test_issue_2802():
assert solve_linear(x + Integral(x, y) - 2, symbols=[x]) == \
(x, 2/(y + 1))
- assert solve(x + exp(x)**2, exp(x)) == \
- [-sqrt(-x), sqrt(-x)]
+ assert set(solve(x + exp(x)**2, exp(x))) == \
+ set([-sqrt(-x), sqrt(-x)])
assert solve(x + exp(x), x, implicit=True) == \
[-exp(x)]
assert solve(cos(x) - sin(x), x, implicit=True) == []
@@ -885,8 +890,8 @@ def test_issue_2802():
[-x + 3]
def test_issue_2813():
- assert solve(x**2 - x - 0.1, rational=True) == \
- [S(1)/2 + sqrt(35)/10, -sqrt(35)/10 + S(1)/2]
+ assert set(solve(x**2 - x - 0.1, rational=True)) == \
+ set([S(1)/2 + sqrt(35)/10, -sqrt(35)/10 + S(1)/2])
# [-0.0916079783099616, 1.09160797830996]
ans = solve(x**2 - x - 0.1, rational=False)
assert len(ans) == 2 and all(a.is_Number for a in ans)
@@ -923,18 +928,18 @@ def test_check_assumptions():
assert solve(x**2 - 1) == [1]
def test_solve_abs():
- assert solve(abs(x - 7) - 8) == [-1, 15]
+ assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)])
def test_issue_2957():
assert solve(tanh(x + 3)*tanh(x - 3) - 1) == []
- assert solve(tanh(x - 1)*tanh(x + 1) + 1) == [
+ assert set(solve(tanh(x - 1)*tanh(x + 1) + 1)) == set([
-log(2)/2 + log(-1 - I),
-log(2)/2 + log(-1 + I),
-log(2)/2 + log(1 - I),
- -log(2)/2 + log(1 + I)]
- assert solve((tanh(x + 3)*tanh(x - 3) + 1)**2) == \
- [-log(2)/2 + log(-1 - I), -log(2)/2 + log(-1 + I),
- -log(2)/2 + log(1 - I), -log(2)/2 + log(1 + I)]
+ -log(2)/2 + log(1 + I)])
+ assert set(solve((tanh(x + 3)*tanh(x - 3) + 1)**2)) == \
+ set([-log(2)/2 + log(-1 - I), -log(2)/2 + log(-1 + I),
+ -log(2)/2 + log(1 - I), -log(2)/2 + log(1 + I)])
def test_issue_2574():
eq = -x + exp(exp(LambertW(log(x)))*LambertW(log(x)))
@@ -965,4 +970,4 @@ def test_exclude():
def test_high_order_roots():
s = x**5 + 4*x**3 + 3*x**2 + S(7)/4
- assert solve(s) == Poly(s*4, domain='ZZ').all_roots()
+ assert set(solve(s)) == set(Poly(s*4, domain='ZZ').all_roots())
44 sympy/utilities/iterables.py
View
@@ -4,6 +4,7 @@
from sympy.core import Basic, C
from sympy.core.compatibility import is_sequence, iterable #logically, these belong here
from sympy.core.compatibility import product as cartes, combinations, combinations_with_replacement
+from sympy.utilities.misc import default_sort_key
from sympy.utilities.exceptions import SymPyDeprecationWarning
def flatten(iterable, levels=None, cls=None):
@@ -110,7 +111,7 @@ def group(container, multiple=True):
return groups
-def postorder_traversal(node):
+def postorder_traversal(node, key=None):
"""
Do a postorder traversal of a tree.
@@ -118,39 +119,46 @@ def postorder_traversal(node):
fashion. That is, it descends through the tree depth-first to yield all of
a node's children's postorder traversal before yielding the node itself.
- For an expression, the order of the traversal depends on the order of
- .args, which in many cases can be arbitrary.
-
Parameters
- ----------
+ ==========
node : sympy expression
The expression to traverse.
+ key : (default None) sort key
+ The key used to sort args of Basic objects. When None, args of Basic
+ objects are processed in arbitrary order.
Returns
- -------
+ =======
subtree : sympy expression
All of the subtrees in the tree.
Examples
- --------
- >>> from sympy import symbols
+ ========
+ >>> from sympy import symbols, default_sort_key
>>> from sympy.utilities.iterables import postorder_traversal
- >>> from sympy.abc import x, y, z
- >>> list(postorder_traversal(z*(x+y))) in ( # any of these are possible
- ... [z, y, x, x + y, z*(x + y)], [z, x, y, x + y, z*(x + y)],
- ... [x, y, x + y, z, z*(x + y)], [y, x, x + y, z, z*(x + y)])
- True
- >>> list(postorder_traversal((x, (y, z))))
- [x, y, z, (y, z), (x, (y, z))]
+ >>> from sympy.abc import w, x, y, z
+
+ The nodes are returned in the order that they are encountered unless key
+ is given.
+
+ >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
+ [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
+ >>> list(postorder_traversal(w + (x + y)*z, key=default_sort_key))
+ [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]
+
"""
if isinstance(node, Basic):
- for arg in node.args:
- for subtree in postorder_traversal(arg):
+ args = node.args
+ if key:
+ args = list(args)
+ args.sort(key=key)
+ for arg in args:
+ for subtree in postorder_traversal(arg, key):
yield subtree
elif iterable(node):
for item in node:
- for subtree in postorder_traversal(item):
+ for subtree in postorder_traversal(item, key):
yield subtree
yield node
34 sympy/utilities/tests/test_iterables.py
View
@@ -1,4 +1,4 @@
-from sympy import symbols, Integral, Tuple, Dummy, Basic
+from sympy import symbols, Integral, Tuple, Dummy, Basic, default_sort_key
from sympy.utilities.iterables import (postorder_traversal, flatten, group,
take, subsets, variations, cartes, numbered_symbols, dict_merge,
prefixes, postfixes, sift, topological_sort, rotate_left, rotate_right,
@@ -12,26 +12,22 @@
w,x,y,z= symbols('w,x,y,z')
def test_postorder_traversal():
- expr = z+w*(x+y)
- expected1 = [z, w, y, x, x + y, w*(x + y), z + w*(x + y)]
- expected2 = [z, w, x, y, x + y, w*(x + y), z + w*(x + y)]
- expected3 = [w, y, x, x + y, w*(x + y), z, z + w*(x + y)]
- expected4 = [w, x, y, x + y, w*(x + y), z, z + w*(x + y)]
- expected5 = [x, y, x + y, w, w*(x + y), x, x + w*(x + y)]
- expected6 = [y, x, x + y, w, w*(x + y), x, x + w*(x + y)]
- assert list(postorder_traversal(expr)) in [expected1, expected2,
- expected3, expected4,
- expected5, expected6]
-
- expr = Piecewise((x,x<1),(x**2,True))
- assert list(postorder_traversal(expr)) == [
- x, x, 1, x < 1, ExprCondPair(x, x < 1), x, 2, x**2,
- ExprCondPair.true_sentinel,
+ expr = z + w*(x+y)
+ expected = [z, w, x, y, x + y, w*(x + y), w*(x + y) + z]
+ assert list(postorder_traversal(expr, key=default_sort_key)) == expected
+
+ expr = Piecewise((x, x < 1), (x**2, True))
+ expected = [
+ x, 1, x, x < 1, ExprCondPair(x, x < 1),
+ ExprCondPair.true_sentinel, 2, x, x**2,
ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True))
- ]
+ ]
+ assert list(postorder_traversal(expr, key=default_sort_key)) == expected
+ assert list(postorder_traversal([expr], key=default_sort_key)) == expected + [[expr]]
- assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == [
- x, 2, x**2, x, 0, 1, Tuple(x, 0, 1),
+ assert list(postorder_traversal(Integral(x**2, (x, 0, 1)),
+ key=default_sort_key)) == [
+ 2, x, x**2, 0, 1, x, Tuple(x, 0, 1),
Integral(x**2, Tuple(x, 0, 1))
]
assert list(postorder_traversal(('abc', ('d', 'ef')))) == [
Please sign in to comment.
Something went wrong with that request. Please try again.