Skip to content
This repository
Browse code

Merge pull request #1411 from smichr/rand

Randomization fixes
  • Loading branch information...
commit 9ff4edf72e293ab00d7204fb0309b7ab90acdf44 2 parents 262bdc4 + bf84abd
Christopher Smith smichr authored
40 sympy/core/basic.py
@@ -1501,44 +1501,54 @@ class preorder_traversal(object):
1501 1501 .args, which in many cases can be arbitrary.
1502 1502
1503 1503 Parameters
1504   - ----------
  1504 + ==========
1505 1505 node : sympy expression
1506 1506 The expression to traverse.
  1507 + key : (default None) sort key
  1508 + The key used to sort args of Basic objects. When None, args of Basic
  1509 + objects are processed in arbitrary order.
1507 1510
1508 1511 Yields
1509   - ------
  1512 + ======
1510 1513 subtree : sympy expression
1511 1514 All of the subtrees in the tree.
1512 1515
1513 1516 Examples
1514   - --------
  1517 + ========
1515 1518 >>> from sympy import symbols
  1519 + >>> from sympy import symbols, default_sort_key
1516 1520 >>> from sympy.core.basic import preorder_traversal
1517 1521 >>> x, y, z = symbols('x y z')
1518   - >>> list(preorder_traversal(z*(x+y))) in ( # any of these are possible
1519   - ... [z*(x + y), z, x + y, x, y], [z*(x + y), z, x + y, y, x],
1520   - ... [z*(x + y), x + y, x, y, z], [z*(x + y), x + y, y, x, z])
1521   - True
1522   - >>> list(preorder_traversal((x, (y, z))))
1523   - [(x, (y, z)), x, (y, z), y, z]
  1522 +
  1523 + The nodes are returned in the order that they are encountered unless key
  1524 + is given.
  1525 +
  1526 + >>> list(preorder_traversal((x + y)*z, key=None)) # doctest: +SKIP
  1527 + [z*(x + y), z, x + y, y, x]
  1528 + >>> list(preorder_traversal((x + y)*z, key=default_sort_key))
  1529 + [z*(x + y), z, x + y, x, y]
1524 1530
1525 1531 """
1526   - def __init__(self, node):
  1532 + def __init__(self, node, key=None):
1527 1533 self._skip_flag = False
1528   - self._pt = self._preorder_traversal(node)
  1534 + self._pt = self._preorder_traversal(node, key)
1529 1535
1530   - def _preorder_traversal(self, node):
  1536 + def _preorder_traversal(self, node, key):
1531 1537 yield node
1532 1538 if self._skip_flag:
1533 1539 self._skip_flag = False
1534 1540 return
1535 1541 if isinstance(node, Basic):
1536   - for arg in node.args:
1537   - for subtree in self._preorder_traversal(arg):
  1542 + args = node.args
  1543 + if key:
  1544 + args = list(args)
  1545 + args.sort(key=key)
  1546 + for arg in args:
  1547 + for subtree in self._preorder_traversal(arg, key):
1538 1548 yield subtree
1539 1549 elif iterable(node):
1540 1550 for item in node:
1541   - for subtree in self._preorder_traversal(item):
  1551 + for subtree in self._preorder_traversal(item, key):
1542 1552 yield subtree
1543 1553
1544 1554 def skip(self):
40 sympy/core/function.py
@@ -1347,7 +1347,7 @@ class Subs(Expr):
1347 1347
1348 1348 An example with several variables:
1349 1349
1350   - >>> Subs(f(x)*sin(y)+z, (x, y), (0, 1))
  1350 + >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
1351 1351 Subs(z + f(_x)*sin(_y), (_x, _y), (0, 1))
1352 1352 >>> _.doit()
1353 1353 z + f(0)*sin(1)
@@ -1359,14 +1359,12 @@ def __new__(cls, expr, variables, point, **assumptions):
1359 1359 variables = Tuple(*sympify(variables))
1360 1360
1361 1361 if uniq(variables) != variables:
1362   - repeated = repeated = [ v for v in set(variables)
  1362 + repeated = [ v for v in set(variables)
1363 1363 if list(variables).count(v) > 1 ]
1364 1364 raise ValueError('cannot substitute expressions %s more than '
1365 1365 'once.' % repeated)
1366 1366
1367   - if not is_sequence(point, Tuple):
1368   - point = [point]
1369   - point = Tuple(*sympify(point))
  1367 + point = Tuple(*sympify(point if is_sequence(point, Tuple) else [point]))
1370 1368
1371 1369 if len(point) != len(variables):
1372 1370 raise ValueError('Number of point values must be the same as '
@@ -1417,32 +1415,16 @@ def free_symbols(self):
1417 1415 def __eq__(self, other):
1418 1416 if not isinstance(other, Subs):
1419 1417 return False
1420   - if (len(self.point) != len(other.point) or
1421   - self.free_symbols != other.free_symbols or
1422   - sorted(self.point) != sorted(other.point)):
  1418 +
  1419 + if len(self.expr.free_symbols) != len(other.expr.free_symbols):
1423 1420 return False
1424 1421
1425   - # non-repeated point args
1426   - selfargs = [ v[0] for v in sorted(zip(self.variables, self.point),
1427   - key = lambda v: v[1]) if list(self.point.args).count(v[1]) == 1 ]
1428   - otherargs = [ v[0] for v in sorted(zip(other.variables, other.point),
1429   - key = lambda v: v[1]) if list(other.point.args).count(v[1]) == 1 ]
1430   - # find repeated point values and subs each associated variable
1431   - # for a single symbol
1432   - selfrepargs = []
1433   - otherrepargs = []
1434   - if uniq(self.point) != self.point:
1435   - repeated = uniq([ v for v in self.point if
1436   - list(self.point.args).count(v) > 1 ])
1437   - repswap = dict(zip(repeated, [ C.Dummy() for _ in
1438   - xrange(len(repeated)) ]))
1439   - selfrepargs = [ (self.variables[i], repswap[v]) for i, v in
1440   - enumerate(self.point) if v in repeated ]
1441   - otherrepargs = [ (other.variables[i], repswap[v]) for i, v in
1442   - enumerate(other.point) if v in repeated ]
1443   -
1444   - return self.expr.subs(selfrepargs) == other.expr.subs(
1445   - tuple(zip(otherargs, selfargs))).subs(otherrepargs)
  1422 + # replace points with dummies, each unique pt getting its own dummy
  1423 + pts = set(self.point.args + other.point.args)
  1424 + d = dict([(p, Dummy()) for p in pts])
  1425 + eq = lambda e: \
  1426 + e.expr.xreplace(dict(zip(e.variables, [d[p] for p in e.point])))
  1427 + return eq(self) == eq(other)
1446 1428
1447 1429 def __ne__(self, other):
1448 1430 return not(self == other)
7 sympy/core/tests/test_basic.py
@@ -3,6 +3,8 @@
3 3
4 4 from sympy.core.basic import Basic, Atom, preorder_traversal
5 5 from sympy.core.singleton import S, Singleton
  6 +from sympy.core.symbol import symbols
  7 +from sympy.utilities.misc import default_sort_key
6 8
7 9 from sympy.utilities.pytest import raises
8 10
@@ -117,3 +119,8 @@ def test_preorder_traversal():
117 119 if i == b2:
118 120 pt.skip()
119 121 assert result == [expr, b21, b2, b1, b3, b2]
  122 +
  123 + w, x, y, z = symbols('w:z')
  124 + expr = z + w*(x+y)
  125 + assert list(preorder_traversal([expr], key=default_sort_key)) == \
  126 + [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
6 sympy/core/tests/test_functions.py
@@ -136,6 +136,12 @@ def test_Lambda_equality():
136 136 def test_Subs():
137 137 assert Subs(f(x), x, 0).doit() == f(0)
138 138 assert Subs(f(x**2), x**2, 0).doit() == f(0)
  139 + assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \
  140 + Subs(f(x, y, z), (x, y, z), (0, 0, 1))
  141 + assert Subs(f(x, y), (x, y, z), (0, 1, 1)) == \
  142 + Subs(f(x, y), (x, y, z), (0, 1, 2))
  143 + assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \
  144 + Subs(f(x, y) + z, (x, y, z), (0, 1, 0))
139 145 assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1)
140 146 assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1)
141 147 raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
73 sympy/solvers/solvers.py
@@ -434,10 +434,10 @@ def solve(f, *symbols, **flags):
434 434 [3]
435 435 >>> solve(Poly(x - 3), x)
436 436 [3]
437   - >>> solve(x**2 - y**2, x)
438   - [-y, y]
439   - >>> solve(x**4 - 1, x)
440   - [-1, 1, -I, I]
  437 + >>> set(solve(x**2 - y**2, x))
  438 + set([-y, y])
  439 + >>> set(solve(x**4 - 1, x))
  440 + set([-1, 1, -I, I])
441 441
442 442 * single expression with no symbol that is in the expression
443 443
@@ -455,9 +455,9 @@ def solve(f, *symbols, **flags):
455 455
456 456 >>> solve(x - 3)
457 457 [3]
458   - >>> solve(x**2 - y**2)
  458 + >>> solve(x**2 - y**2) # doctest: +SKIP
459 459 [{x: -y}, {x: y}]
460   - >>> solve(z**2*x**2 - z**2*y**2)
  460 + >>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP
461 461 [{x: -y}, {x: y}]
462 462 >>> solve(z**2*x - z**2*y**2)
463 463 [{x: y**2}]
@@ -473,8 +473,8 @@ def solve(f, *symbols, **flags):
473 473 [x + f(x)]
474 474 >>> solve(f(x).diff(x) - f(x) - x, f(x))
475 475 [-x + Derivative(f(x), x)]
476   - >>> solve(x + exp(x)**2, exp(x))
477   - [-sqrt(-x), sqrt(-x)]
  476 + >>> set(solve(x + exp(x)**2, exp(x)))
  477 + set([-sqrt(-x), sqrt(-x)])
478 478
479 479 * To solve for a *symbol* implicitly, use 'implicit=True':
480 480
@@ -501,17 +501,17 @@ def solve(f, *symbols, **flags):
501 501
502 502 * that are nonlinear
503 503
504   - >>> solve((a + b)*x - b**2 + 2, a, b)
505   - [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]
  504 + >>> set(solve((a + b)*x - b**2 + 2, a, b))
  505 + set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))])
506 506
507 507 * if there is no linear solution then the first successful
508 508 attempt for a nonlinear solution will be returned
509 509
510   - >>> solve(x**2 - y**2, x, y)
  510 + >>> solve(x**2 - y**2, x, y) # doctest: +SKIP
511 511 [{x: -y}, {x: y}]
512 512 >>> solve(x**2 - y**2/exp(x), x, y)
513 513 [{x: 2*LambertW(y/2)}]
514   - >>> solve(x**2 - y**2/exp(x), y, x)
  514 + >>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP
515 515 [{y: -x*exp(x/2)}, {y: x*exp(x/2)}]
516 516
517 517 * iterable of one or more of the above
@@ -543,8 +543,8 @@ def solve(f, *symbols, **flags):
543 543
544 544 * when the system is not linear
545 545
546   - >>> solve([x**2 + y -2, y**2 - 4], x, y)
547   - [(-2, -2), (0, 2), (0, 2), (2, -2)]
  546 + >>> set(solve([x**2 + y -2, y**2 - 4], x, y))
  547 + set([(-2, -2), (0, 2), (2, -2)])
548 548
549 549 * if no symbols are given, all free symbols will be selected and a list
550 550 of mappings returned
@@ -659,7 +659,7 @@ def _sympified_list(w):
659 659 # we do this to make the results returned canonical in case f
660 660 # contains a system of nonlinear equations; all other cases should
661 661 # be unambiguous
662   - symbols = sorted(symbols, key=lambda i: i.sort_key())
  662 + symbols = sorted(symbols, key=default_sort_key)
663 663
664 664 # we can solve for Function and Derivative instances by replacing them
665 665 # with Dummy symbols or functions
@@ -1287,9 +1287,10 @@ def _solve_system(exprs, symbols, **flags):
1287 1287 else:
1288 1288 if len(symbols) != len(polys):
1289 1289 from sympy.utilities.iterables import subsets
1290   - free = list(reduce(set.union,
1291   - [p.free_symbols for p in polys], set()
1292   - ).intersection(symbols))
  1290 + from sympy.core.compatibility import set_union
  1291 +
  1292 + free = set_union(*[p.free_symbols for p in polys])
  1293 + free = list(free.intersection(symbols))
1293 1294 free.sort(key=default_sort_key)
1294 1295 for syms in subsets(free, len(polys)):
1295 1296 try:
@@ -1341,16 +1342,26 @@ def _solve_system(exprs, symbols, **flags):
1341 1342 result = [result]
1342 1343 else:
1343 1344 result = [{}]
  1345 +
  1346 + def _ok_syms(e, sort=False):
  1347 + rv = (e.free_symbols - solved_syms) & legal
  1348 + if sort:
  1349 + rv = list(rv)
  1350 + rv.sort(key=default_sort_key)
  1351 + return rv
  1352 +
1344 1353 solved_syms = set(solved_syms) # set of symbols we have solved for
1345 1354 legal = set(symbols) # what we are interested in
1346 1355 simplify_flag = flags.get('simplify', None)
1347 1356 do_simplify = flags.get('simplify', True)
1348   - # sort so equation with the fewest potential symbols is first; break ties with
1349   - # count_ops and sort_key
1350   - short = sift(failed, lambda x: len((x.free_symbols - solved_syms) & legal))
  1357 + # sort so equation with the fewest potential symbols is first;
  1358 + # break ties with count_ops and default_sort_key
  1359 + short = sift(failed, lambda x: len(_ok_syms(x)))
1351 1360 failed = []
1352   - for k in sorted(short):
1353   - failed.extend(sorted(short[k], key=lambda x: (x.count_ops(), x.sort_key())))
  1361 + for k in sorted(short, key=default_sort_key):
  1362 + failed.extend(sorted(sorted(short[k],
  1363 + key=lambda x: x.count_ops()),
  1364 + key=default_sort_key))
1354 1365 for eq in failed:
1355 1366 newresult = []
1356 1367 got_s = None
@@ -1365,7 +1376,7 @@ def _solve_system(exprs, symbols, **flags):
1365 1376 continue
1366 1377 # search for a symbol amongst those available that
1367 1378 # can be solved for
1368   - ok_syms = (eq2.free_symbols - solved_syms) & legal
  1379 + ok_syms = _ok_syms(eq2, sort=True)
1369 1380 if not ok_syms:
1370 1381 break # skip as it's independent of desired symbols
1371 1382 for s in ok_syms:
@@ -1591,7 +1602,7 @@ def solve_linear_system(system, *symbols, **flags):
1591 1602 matrix = system[:, :]
1592 1603 syms = list(symbols)
1593 1604
1594   - i, m = 0, matrix.cols-1 # don't count augmentation
  1605 + i, m = 0, matrix.cols - 1 # don't count augmentation
1595 1606
1596 1607 while i < matrix.rows:
1597 1608 if i == m:
@@ -1611,12 +1622,12 @@ def solve_linear_system(system, *symbols, **flags):
1611 1622 break
1612 1623 else:
1613 1624 if matrix[i, m]:
1614   - # we need to know this this is always zero or not. We
  1625 + # we need to know if this is always zero or not. We
1615 1626 # assume that if there are free symbols that it is not
1616 1627 # identically zero (or that there is more than one way
1617 1628 # to make this zero. Otherwise, if there are none, this
1618 1629 # is a constant and we assume that it does not simplify
1619   - # to zero XXX are there better ways to test this/
  1630 + # to zero XXX are there better ways to test this?
1620 1631 if not matrix[i, m].free_symbols:
1621 1632 return None # no solution
1622 1633
@@ -1664,7 +1675,7 @@ def solve_linear_system(system, *symbols, **flags):
1664 1675 # divide all elements in the current row by the pivot
1665 1676 matrix.row(i, lambda x, _: x * pivot_inv)
1666 1677
1667   - for k in xrange(i+1, matrix.rows):
  1678 + for k in xrange(i + 1, matrix.rows):
1668 1679 if matrix[k, i]:
1669 1680 coeff = matrix[k, i]
1670 1681
@@ -1689,7 +1700,7 @@ def solve_linear_system(system, *symbols, **flags):
1689 1700 content = matrix[k, m]
1690 1701
1691 1702 # run back-substitution for variables
1692   - for j in xrange(k+1, m):
  1703 + for j in xrange(k + 1, m):
1693 1704 content -= matrix[k, j]*solutions[syms[j]]
1694 1705
1695 1706 if do_simplify:
@@ -1703,13 +1714,13 @@ def solve_linear_system(system, *symbols, **flags):
1703 1714 elif len(syms) > matrix.rows:
1704 1715 # this system will have infinite number of solutions
1705 1716 # dependent on exactly len(syms) - i parameters
1706   - k, solutions = i-1, {}
  1717 + k, solutions = i - 1, {}
1707 1718
1708 1719 while k >= 0:
1709 1720 content = matrix[k, m]
1710 1721
1711 1722 # run back-substitution for variables
1712   - for j in xrange(k+1, i):
  1723 + for j in xrange(k + 1, i):
1713 1724 content -= matrix[k, j]*solutions[syms[j]]
1714 1725
1715 1726 # run back-substitution for parameters
209 sympy/solvers/tests/test_solvers.py
@@ -30,9 +30,6 @@ def guess_solve_strategy(eq, symbol):
30 30 return False
31 31
32 32 def test_guess_poly():
33   - """
34   - See solvers.guess_solve_strategy
35   - """
36 33 # polynomial equations
37 34 assert guess_solve_strategy( S(4), x ) #== GS_POLY
38 35 assert guess_solve_strategy( x, x ) #== GS_POLY
@@ -76,8 +73,8 @@ def test_guess_transcendental():
76 73
77 74 def test_solve_args():
78 75 #implicit symbol to solve for
79   - assert set(int(tmp) for tmp in solve(x**2-4)) == set([2,-2])
80   - assert solve([x+y-3,x-y-5]) == {x: 4, y: -1}
  76 + assert set(solve(x**2 - 4)) == set([S(2), -S(2)])
  77 + assert solve([x + y - 3, x - y - 5]) == {x: 4, y: -1}
81 78 #no symbol to solve for
82 79 assert solve(42) == []
83 80 assert solve([1, 2]) == []
@@ -118,14 +115,14 @@ def test_solve_args():
118 115 assert solve((x + y - 2, 2*x + 2*y - 4)) == {x: -y + 2}
119 116
120 117 def test_solve_polynomial1():
121   - assert solve(3*x-2, x) == [Rational(2,3)]
122   - assert solve(Eq(3*x, 2), x) == [Rational(2,3)]
  118 + assert solve(3*x-2, x) == [Rational(2, 3)]
  119 + assert solve(Eq(3*x, 2), x) == [Rational(2, 3)]
123 120
124   - assert solve(x**2-1, x) in [[-1, 1], [1, -1]]
125   - assert solve(Eq(x**2, 1), x) in [[-1, 1], [1, -1]]
  121 + assert set(solve(x**2 - 1, x)) == set([-S(1), S(1)])
  122 + assert set(solve(Eq(x**2, 1), x)) == set([-S(1), S(1)])
126 123
127   - assert solve( x - y**3, x) == [y**3]
128   - assert sorted(solve( x - y**3, y)) == sorted([
  124 + assert solve(x - y**3, x) == [y**3]
  125 + assert set(solve(x - y**3, y)) == set([
129 126 (-x**Rational(1,3))/2 + I*sqrt(3)*x**Rational(1,3)/2,
130 127 x**Rational(1,3),
131 128 (-x**Rational(1,3))/2 - I*sqrt(3)*x**Rational(1,3)/2,
@@ -149,8 +146,8 @@ def test_solve_polynomial1():
149 146 S(4),
150 147 -2 - 3**Rational(1,2) ])
151 148
152   - assert sorted(solve((x**2 - 1)**2 - a, x)) == \
153   - sorted([sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)),
  149 + assert set(solve((x**2 - 1)**2 - a, x)) == \
  150 + set([sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)),
154 151 sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))])
155 152
156 153 def test_solve_polynomial2():
@@ -225,7 +222,7 @@ def test_tsolve():
225 222 assert set(solve((a*x+b)*(exp(x)-3), x)) == set([-b/a, log(3)])
226 223 assert solve(cos(x)-y, x) == [acos(y)]
227 224 assert solve(2*cos(x)-y,x)== [acos(y/2)]
228   - assert solve(Eq(cos(x), sin(x)), x) == [-3*pi/4, pi/4]
  225 + assert set(solve(Eq(cos(x), sin(x)), x)) == set([-3*pi/4, pi/4])
229 226
230 227 assert set(solve(exp(x) + exp(-x) - y, x)) == set([
231 228 log(y/2 - sqrt(y**2 - 4)/2),
@@ -240,8 +237,8 @@ def test_tsolve():
240 237 assert solve(x+2**x, x) == [-LambertW(log(2))/log(2)]
241 238 assert solve(3*x+5+2**(-5*x+3), x) in [
242 239 [-((25*log(2) - 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2)))],
243   - [Rational(-5, 3) + LambertW(log(2**(-10240*2**(Rational(1, 3))/3)))/(5*log(2))],
244   - [-Rational(5,3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))],
  240 + [-Rational(5, 3) + LambertW(log(2**(-10240*2**(Rational(1, 3))/3)))/(5*log(2))],
  241 + [-Rational(5, 3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))],
245 242 [(-25*log(2) + 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2))],
246 243 [-((25*log(2) - 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3)))/(15*log(2))],
247 244 [-(25*log(2) - 3*LambertW(log(2**(-10240*2**Rational(1, 3)/3))))/(15*log(2))],
@@ -385,8 +382,10 @@ def test_issue_1694():
385 382 eq = 4*3**(5*x + 2) - 7
386 383 ans = solve(eq, x)
387 384 assert len(ans) == 5 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans)
388   - assert solve(log(x**2) - y**2/exp(x), x, y) == [{y: -sqrt(exp(x)*log(x**2))},
389   - {y: sqrt(exp(x)*log(x**2))}]
  385 + assert solve(log(x**2) - y**2/exp(x), x, y, set=True) == \
  386 + ([y], set([
  387 + (-sqrt(exp(x)*log(x**2)),),
  388 + (sqrt(exp(x)*log(x**2)),)]))
390 389 assert solve(x**2*z**2 - z**2*y**2) in ([{x: y}, {x: -y}], [{x: -y}, {x: y}])
391 390 assert solve((x - 1)/(1 + 1/(x - 1))) == []
392 391 assert solve(x**(y*z) - x, x) == [1]
@@ -401,7 +400,7 @@ def test_issue_1694():
401 400 # 1387
402 401 assert solve(2*x/(x + 2) - 1,x) == [2]
403 402 # 1397
404   - assert solve((x**2/(7 - x)).diff(x)) == [0, 14]
  403 + assert set(solve((x**2/(7 - x)).diff(x))) == set([S(0), S(14)])
405 404 # 1596
406 405 f = Function('f')
407 406 assert solve((3 - 5*x/f(x))*f(x), f(x)) == [5*x/3]
@@ -409,22 +408,23 @@ def test_issue_1694():
409 408 assert solve(1/(5 + x)**(S(1)/5) - 9, x) == [-295244/S(59049)]
410 409
411 410 assert solve(sqrt(x) + sqrt(sqrt(x)) - 4) == [-9*sqrt(17)/2 + 49*S.Half]
412   - assert solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4)) in \
  411 + assert set(solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4))) in \
413 412 [
414   - [2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)],
415   - [log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)],
  413 + set([2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)]),
  414 + set([log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)]),
416 415 ]
417   - assert solve(Poly(exp(x) + exp(-x) - 4)) == [log(-sqrt(3) + 2), log(sqrt(3) + 2)]
418   - assert solve(x**y + x**(2*y) - 1, x) == \
419   - [(-S.Half + sqrt(5)/2)**(1/y), (-S.Half - sqrt(5)/2)**(1/y)]
  416 + assert set(solve(Poly(exp(x) + exp(-x) - 4))) == \
  417 + set([log(-sqrt(3) + 2), log(sqrt(3) + 2)])
  418 + assert set(solve(x**y + x**(2*y) - 1, x)) == \
  419 + set([(-S.Half + sqrt(5)/2)**(1/y), (-S.Half - sqrt(5)/2)**(1/y)])
420 420
421 421 assert solve(exp(x/y)*exp(-z/y) - 2, y) == [(x - z)/log(2)]
422 422 assert solve(x**z*y**z - 2, z) in [[log(2)/(log(x) + log(y))], [log(2)/(log(x*y))]]
423 423 # if you do inversion too soon then multiple roots as for the following will
424 424 # be missed, e.g. if exp(3*x) = exp(3) -> 3*x = 3
425 425 E = S.Exp1
426   - assert solve(exp(3*x) - exp(3), x) == \
427   - [1, log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)]
  426 + assert set(solve(exp(3*x) - exp(3), x)) == \
  427 + set([S(1), log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)])
428 428
429 429 def test_issue_2098():
430 430 x = Symbol('x', real=True)
@@ -461,9 +461,9 @@ def test_checking():
461 461 def test_issue_1572_1364_1368():
462 462 assert solve((sqrt(x**2 - 1) - 2)) in ([sqrt(5), -sqrt(5)],
463 463 [-sqrt(5), sqrt(5)])
464   - assert solve((2**exp(y**2/x) + 2)/(x**2 + 15), y) == [
  464 + assert set(solve((2**exp(y**2/x) + 2)/(x**2 + 15), y)) == set([
465 465 -sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi)),
466   - sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi))]
  466 + sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi))])
467 467
468 468 C1, C2 = symbols('C1 C2')
469 469 f = Function('f')
@@ -480,65 +480,65 @@ def test_issue_1572_1364_1368():
480 480 assert solve(1 - log(a + 4*x**2), x) in (
481 481 [-sqrt(-a + E)/2, sqrt(-a + E)/2],
482 482 [sqrt(-a + E)/2, -sqrt(-a + E)/2],)
483   - assert solve((a**2 + 1) * (sin(a*x) + cos(a*x)), x) == [-pi/(4*a), 3*pi/(4*a)]
  483 + assert set(solve((a**2 + 1) * (sin(a*x) + cos(a*x)), x)) == set([-pi/(4*a), 3*pi/(4*a)])
484 484 assert solve(3 - (sinh(a*x) + cosh(a*x)), x) == [2*atanh(S.Half)/a]
485   - assert solve(3-(sinh(a*x) + cosh(a*x)**2), x) == \
486   - [
  485 + assert set(solve(3-(sinh(a*x) + cosh(a*x)**2), x)) == \
  486 + set([
487 487 2*atanh(-1 + sqrt(2))/a,
488 488 2*atanh(S(1)/2 + sqrt(5)/2)/a,
489 489 2*atanh(-sqrt(2) - 1)/a,
490 490 2*atanh(-sqrt(5)/2 + S(1)/2)/a
491   - ]
  491 + ])
492 492 assert solve(atan(x) - 1) == [tan(1)]
493 493
494 494 def test_issue_2033():
495 495 r, t = symbols('r,t')
496   - assert solve([r - x**2 - y**2, tan(t) - y/x], [x, y]) == \
497   - [
  496 + assert set(solve([r - x**2 - y**2, tan(t) - y/x], [x, y])) == \
  497 + set([
498 498 (-sqrt(r*sin(t)**2)/tan(t), -sqrt(r*sin(t)**2)),
499   - (sqrt(r*sin(t)**2)/tan(t), sqrt(r*sin(t)**2))]
  499 + (sqrt(r*sin(t)**2)/tan(t), sqrt(r*sin(t)**2))])
500 500 assert solve([exp(x) - sin(y), 1/y - 3], [x, y]) == \
501 501 [(log(sin(S(1)/3)), S(1)/3)]
502 502 assert solve([exp(x) - sin(y), 1/exp(y) - 3], [x, y]) == \
503 503 [(log(-sin(log(3))), -log(3))]
504   - assert solve([exp(x) - sin(y), y**2 - 4], [x, y]) == \
505   - [(log(-sin(2)), -2), (log(sin(2)), 2)]
  504 + assert set(solve([exp(x) - sin(y), y**2 - 4], [x, y])) == \
  505 + set([(log(-sin(2)), -S(2)), (log(sin(2)), S(2))])
506 506 eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3]
507   - assert solve(eqs) == \
508   - [
509   - {x: log(-sqrt(-z**2 - sin(log(3)))), y: -log(3)},
510   - {x: log(sqrt(-z**2 - sin(log(3)))), y: -log(3)}]
511   - assert solve(eqs, x, z) == \
512   - [
513   - {x: log(-sqrt(-z**2 + sin(y)))},
514   - {x: log(sqrt(-z**2 + sin(y)))}]
515   - assert solve(eqs, x, y) == \
516   - [
  507 + assert solve(eqs, set=True) == \
  508 + ([x, y], set([
517 509 (log(-sqrt(-z**2 - sin(log(3)))), -log(3)),
518   - (log(sqrt(-z**2 - sin(log(3)))), -log(3))]
519   - assert solve(eqs, y, z) == \
520   - [
  510 + (log(sqrt(-z**2 - sin(log(3)))), -log(3))]))
  511 + assert solve(eqs, x, z, set=True) == \
  512 + ([x], set([
  513 + (log(-sqrt(-z**2 + sin(y))),),
  514 + (log(sqrt(-z**2 + sin(y))),)]))
  515 + assert set(solve(eqs, x, y)) == \
  516 + set([
  517 + (log(-sqrt(-z**2 - sin(log(3)))), -log(3)),
  518 + (log(sqrt(-z**2 - sin(log(3)))), -log(3))])
  519 + assert set(solve(eqs, y, z)) == \
  520 + set([
521 521 (-log(3), -sqrt(-exp(2*x) - sin(log(3)))),
522   - (-log(3), sqrt(-exp(2*x) - sin(log(3))))]
  522 + (-log(3), sqrt(-exp(2*x) - sin(log(3))))])
523 523 eqs = [exp(x)**2 - sin(y) + z, 1/exp(y) - 3]
524   - assert solve(eqs) == \
  524 + assert solve(eqs, set=True) == ([x, y], set(
525 525 [
526   - {x: log(-sqrt(-z - sin(log(3)))), y: -log(3)},
527   - {x: log(sqrt(-z - sin(log(3)))), y: -log(3)}]
528   - assert solve(eqs, x, z) == \
  526 + (log(-sqrt(-z - sin(log(3)))), -log(3)),
  527 + (log(sqrt(-z - sin(log(3)))), -log(3))]))
  528 + assert solve(eqs, x, z, set=True) == ([x], set(
529 529 [
530   - {x: log(-sqrt(-z + sin(y)))},
531   - {x: log(sqrt(-z + sin(y)))}]
532   - assert solve(eqs, x, y) == \
  530 + (log(-sqrt(-z + sin(y))),),
  531 + (log(sqrt(-z + sin(y))),)]))
  532 + assert set(solve(eqs, x, y)) == set(
533 533 [
534 534 (log(-sqrt(-z - sin(log(3)))), -log(3)),
535   - (log(sqrt(-z - sin(log(3)))), -log(3))]
  535 + (log(sqrt(-z - sin(log(3)))), -log(3))])
536 536 assert solve(eqs, z, y) == \
537 537 [(-exp(2*x) - sin(log(3)), -log(3))]
538   - assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4)) == \
539   - [{x: 1, y: 3}, {x: 3, y: 1}]
540   - assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y) == \
541   - [(1, 3), (3, 1)]
  538 + assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), set=True) == (
  539 + [x, y], set([(S(1), S(3)), (S(3), S(1))]))
  540 + assert set(solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y)) == \
  541 + set([(S(1), S(3)), (S(3), S(1))])
542 542
543 543 def test_issue_2236():
544 544 lam, a0, conc = symbols('lam a0 conc')
@@ -559,11 +559,13 @@ def test_issue_2236_float():
559 559 assert len(solve(eqs, sym, rational=False, check=False, simplify=False)) == 2
560 560
561 561 def test_issue_2668():
562   - assert solve([x**2 + y + 4], [x]) == [(-sqrt(-y - 4),), (sqrt(-y - 4),)]
  562 + assert set(solve([x**2 + y + 4], [x])) == \
  563 + set([(-sqrt(-y - 4),), (sqrt(-y - 4),)])
563 564
564 565 def test_polysys():
565   - assert solve([x**2 + 2/y - 2 , x + y - 3], [x, y]) == \
566   - [(1, 2), (1 + sqrt(5), 2 - sqrt(5)), (1 - sqrt(5), 2 + sqrt(5))]
  566 + assert set(solve([x**2 + 2/y - 2 , x + y - 3], [x, y])) == \
  567 + set([(S(1), S(2)), (1 + sqrt(5), 2 - sqrt(5)),
  568 + (1 - sqrt(5), 2 + sqrt(5))])
567 569 assert solve([x**2 + y - 2, x**2 + y]) == []
568 570 # the ordering should be whatever the user requested
569 571 assert solve([x**2 + y - 3, x - y - 4], (x, y)) != solve([x**2 + y - 3, x - y - 4], (y, x))
@@ -626,10 +628,11 @@ def s_check(rv, ans):
626 628 eq = sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x))
627 629 assert check(unrad(eq),
628 630 (16*x**3 - 9*x**2, [], []))
629   - assert solve(eq, check=False) == [0, S(9)/16]
  631 + assert set(solve(eq, check=False)) == set([S(0), S(9)/16])
630 632 assert solve(eq) == []
631 633 # but this one really does have those solutions
632   - assert solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x))) == [0, S(9)/16]
  634 + assert set(solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x)))) == \
  635 + set([S.Zero, S(9)/16])
633 636
634 637 '''real_root changes the value of the result if the solution is
635 638 simplified; `a` in the text below is the root that is not 4/5:
@@ -661,8 +664,8 @@ def s_check(rv, ans):
661 664 ans = solve(sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5)
662 665 assert all(abs(eq.subs(x, i).n()) < 1e-10 for i in (ra, rb)) and \
663 666 len(ans) == 2 and \
664   - sorted([i.n(chop=True) for i in ans]) == \
665   - sorted([i.n(chop=True) for i in (ra, rb)])
  667 + set([i.n(chop=True) for i in ans]) == \
  668 + set([i.n(chop=True) for i in (ra, rb)])
666 669
667 670 ans = solve(sqrt(x) + sqrt(x + 1) - \
668 671 sqrt(1 - x) - sqrt(2 + x))
@@ -694,15 +697,15 @@ def s_check(rv, ans):
694 697 assert solve(Eq(x, sqrt(x + 6))) == [3]
695 698 assert solve(Eq(x + sqrt(x - 4), 4)) == [4]
696 699 assert solve(Eq(1, x + sqrt(2*x - 3))) == []
697   - assert solve(Eq(sqrt(5*x + 6) - 2, x)) == [-1, 2]
698   - assert solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2)) == [5, 13]
  700 + assert set(solve(Eq(sqrt(5*x + 6) - 2, x))) == set([-S(1), S(2)])
  701 + assert set(solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2))) == set([S(5), S(13)])
699 702 assert solve(Eq(sqrt(x + 7) + 2, sqrt(3 - x))) == [-6]
700 703 # http://www.purplemath.com/modules/solverad.htm
701 704 assert solve((2*x - 5)**Rational(1, 3) - 3) == [16]
702 705 assert solve((x**3 - 3*x**2)**Rational(1, 3) + 1 - x) == []
703   - assert solve(x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4)) == \
704   - [-S(1)/2, -S(1)/3]
705   - assert solve(sqrt(2*x**2 - 7) - (3 - x)) == [-8, 2]
  706 + assert set(solve(x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4))) == \
  707 + set([-S(1)/2, -S(1)/3])
  708 + assert set(solve(sqrt(2*x**2 - 7) - (3 - x))) == set([-S(8), S(2)])
706 709 assert solve(sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4)) == [0]
707 710 assert solve(sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1)) == [5]
708 711 assert solve(sqrt(x)*sqrt(x - 7) - 12) == [16]
@@ -784,14 +787,13 @@ def test_issue_2750():
784 787 )
785 788
786 789 ans = [{
787   - dI4: -I3 + 3*I5 - 2*Q4,
  790 + dQ4: I3 - I5,
788 791 dI1: -4*I2 - 8*I3 - 4*I5 - 6*I6 + 24,
  792 + I4: I3 - I5,
789 793 dQ2: I2,
790   - I1: I2 + I3,
791 794 Q2: 2*I3 + 2*I5 + 3*I6,
792   - dQ4: I3 - I5,
793   - I4: I3 - I5,
794   - }]
  795 + I1: I2 + I3,
  796 + Q4: -I3/2 + 3*I5/2 - dI4/2}]
795 797 assert solve(e, I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4, manual=True) == ans
796 798 # the matrix solver (tested below) doesn't like this because it produces
797 799 # a zero row in the matrix. Is this related to issue 1452?
@@ -837,18 +839,21 @@ def test_issue_2802():
837 839 {f(x): 3*D}
838 840 assert solve([f(x) - 3*f(x).diff(x), f(x)**2 - y + 4], f(x), y) == \
839 841 [{f(x): 3*D, y: 9*D**2 + 4}]
840   - assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), h(a), g(a)) == \
841   - [{g(a): -sqrt(h(a)**2 + G/f(a)**2)},
842   - {g(a): sqrt(h(a)**2 + G/f(a)**2)}]
  842 + assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a),
  843 + h(a), g(a), set=True) == \
  844 + ([g(a)], set([
  845 + (-sqrt(h(a)**2 + G/f(a)**2),),
  846 + (sqrt(h(a)**2 + G/f(a)**2),)]))
843 847 args = [f(x).diff(x, 2)*(f(x) + g(x)) - g(x)**2 + 2, f(x), g(x)]
844   - assert solve(*args) == \
845   - [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]
  848 + assert set(solve(*args)) == \
  849 + set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))])
846 850 eqs = [f(x)**2 + g(x) - 2*f(x).diff(x), g(x)**2 - 4]
847   - assert solve(eqs, f(x), g(x)) == \
848   - [{g(x): 2, f(x): -sqrt(2*D - 2)},
849   - {g(x): 2, f(x): sqrt(2*D - 2)},
850   - {g(x): -2, f(x): -sqrt(2*D + 2)},
851   - {g(x): -2, f(x): sqrt(2*D + 2)}]
  851 + assert solve(eqs, f(x), g(x), set=True) == \
  852 + ([f(x), g(x)], set([
  853 + (-sqrt(2*D - 2), S(2)),
  854 + (sqrt(2*D - 2), S(2)),
  855 + (-sqrt(2*D + 2), -S(2)),
  856 + (sqrt(2*D + 2), -S(2))]))
852 857
853 858 # the underlying problem was in solve_linear that was not masking off
854 859 # anything but a Mul or Add; it now raises an error if it gets anything
@@ -872,8 +877,8 @@ def test_issue_2802():
872 877 assert solve_linear(x + Integral(x, y) - 2, symbols=[x]) == \
873 878 (x, 2/(y + 1))
874 879
875   - assert solve(x + exp(x)**2, exp(x)) == \
876   - [-sqrt(-x), sqrt(-x)]
  880 + assert set(solve(x + exp(x)**2, exp(x))) == \
  881 + set([-sqrt(-x), sqrt(-x)])
877 882 assert solve(x + exp(x), x, implicit=True) == \
878 883 [-exp(x)]
879 884 assert solve(cos(x) - sin(x), x, implicit=True) == []
@@ -885,8 +890,8 @@ def test_issue_2802():
885 890 [-x + 3]
886 891
887 892 def test_issue_2813():
888   - assert solve(x**2 - x - 0.1, rational=True) == \
889   - [S(1)/2 + sqrt(35)/10, -sqrt(35)/10 + S(1)/2]
  893 + assert set(solve(x**2 - x - 0.1, rational=True)) == \
  894 + set([S(1)/2 + sqrt(35)/10, -sqrt(35)/10 + S(1)/2])
890 895 # [-0.0916079783099616, 1.09160797830996]
891 896 ans = solve(x**2 - x - 0.1, rational=False)
892 897 assert len(ans) == 2 and all(a.is_Number for a in ans)
@@ -923,18 +928,18 @@ def test_check_assumptions():
923 928 assert solve(x**2 - 1) == [1]
924 929
925 930 def test_solve_abs():
926   - assert solve(abs(x - 7) - 8) == [-1, 15]
  931 + assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)])
927 932
928 933 def test_issue_2957():
929 934 assert solve(tanh(x + 3)*tanh(x - 3) - 1) == []
930   - assert solve(tanh(x - 1)*tanh(x + 1) + 1) == [
  935 + assert set(solve(tanh(x - 1)*tanh(x + 1) + 1)) == set([
931 936 -log(2)/2 + log(-1 - I),
932 937 -log(2)/2 + log(-1 + I),
933 938 -log(2)/2 + log(1 - I),
934   - -log(2)/2 + log(1 + I)]
935   - assert solve((tanh(x + 3)*tanh(x - 3) + 1)**2) == \
936   - [-log(2)/2 + log(-1 - I), -log(2)/2 + log(-1 + I),
937   - -log(2)/2 + log(1 - I), -log(2)/2 + log(1 + I)]
  939 + -log(2)/2 + log(1 + I)])
  940 + assert set(solve((tanh(x + 3)*tanh(x - 3) + 1)**2)) == \
  941 + set([-log(2)/2 + log(-1 - I), -log(2)/2 + log(-1 + I),
  942 + -log(2)/2 + log(1 - I), -log(2)/2 + log(1 + I)])
938 943
939 944 def test_issue_2574():
940 945 eq = -x + exp(exp(LambertW(log(x)))*LambertW(log(x)))
@@ -965,4 +970,4 @@ def test_exclude():
965 970
966 971 def test_high_order_roots():
967 972 s = x**5 + 4*x**3 + 3*x**2 + S(7)/4
968   - assert solve(s) == Poly(s*4, domain='ZZ').all_roots()
  973 + assert set(solve(s)) == set(Poly(s*4, domain='ZZ').all_roots())
44 sympy/utilities/iterables.py
@@ -4,6 +4,7 @@
4 4 from sympy.core import Basic, C
5 5 from sympy.core.compatibility import is_sequence, iterable #logically, these belong here
6 6 from sympy.core.compatibility import product as cartes, combinations, combinations_with_replacement
  7 +from sympy.utilities.misc import default_sort_key
7 8 from sympy.utilities.exceptions import SymPyDeprecationWarning
8 9
9 10 def flatten(iterable, levels=None, cls=None):
@@ -110,7 +111,7 @@ def group(container, multiple=True):
110 111
111 112 return groups
112 113
113   -def postorder_traversal(node):
  114 +def postorder_traversal(node, key=None):
114 115 """
115 116 Do a postorder traversal of a tree.
116 117
@@ -118,39 +119,46 @@ def postorder_traversal(node):
118 119 fashion. That is, it descends through the tree depth-first to yield all of
119 120 a node's children's postorder traversal before yielding the node itself.
120 121
121   - For an expression, the order of the traversal depends on the order of
122   - .args, which in many cases can be arbitrary.
123   -
124 122 Parameters
125   - ----------
  123 + ==========
126 124 node : sympy expression
127 125 The expression to traverse.
  126 + key : (default None) sort key
  127 + The key used to sort args of Basic objects. When None, args of Basic
  128 + objects are processed in arbitrary order.
128 129
129 130 Returns
130   - -------
  131 + =======
131 132 subtree : sympy expression
132 133 All of the subtrees in the tree.
133 134
134 135 Examples
135   - --------
136   - >>> from sympy import symbols
  136 + ========
  137 + >>> from sympy import symbols, default_sort_key
137 138 >>> from sympy.utilities.iterables import postorder_traversal
138   - >>> from sympy.abc import x, y, z
139   - >>> list(postorder_traversal(z*(x+y))) in ( # any of these are possible
140   - ... [z, y, x, x + y, z*(x + y)], [z, x, y, x + y, z*(x + y)],
141   - ... [x, y, x + y, z, z*(x + y)], [y, x, x + y, z, z*(x + y)])
142   - True
143   - >>> list(postorder_traversal((x, (y, z))))
144   - [x, y, z, (y, z), (x, (y, z))]
  139 + >>> from sympy.abc import w, x, y, z
  140 +
  141 + The nodes are returned in the order that they are encountered unless key
  142 + is given.
  143 +
  144 + >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
  145 + [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
  146 + >>> list(postorder_traversal(w + (x + y)*z, key=default_sort_key))
  147 + [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]
  148 +
145 149
146 150 """
147 151 if isinstance(node, Basic):
148   - for arg in node.args:
149   - for subtree in postorder_traversal(arg):
  152 + args = node.args
  153 + if key:
  154 + args = list(args)
  155 + args.sort(key=key)
  156 + for arg in args:
  157 + for subtree in postorder_traversal(arg, key):
150 158 yield subtree
151 159 elif iterable(node):
152 160 for item in node:
153   - for subtree in postorder_traversal(item):
  161 + for subtree in postorder_traversal(item, key):
154 162 yield subtree
155 163 yield node
156 164
34 sympy/utilities/tests/test_iterables.py
... ... @@ -1,4 +1,4 @@
1   -from sympy import symbols, Integral, Tuple, Dummy, Basic
  1 +from sympy import symbols, Integral, Tuple, Dummy, Basic, default_sort_key
2 2 from sympy.utilities.iterables import (postorder_traversal, flatten, group,
3 3 take, subsets, variations, cartes, numbered_symbols, dict_merge,
4 4 prefixes, postfixes, sift, topological_sort, rotate_left, rotate_right,
@@ -12,26 +12,22 @@
12 12 w,x,y,z= symbols('w,x,y,z')
13 13
14 14 def test_postorder_traversal():
15   - expr = z+w*(x+y)
16   - expected1 = [z, w, y, x, x + y, w*(x + y), z + w*(x + y)]
17   - expected2 = [z, w, x, y, x + y, w*(x + y), z + w*(x + y)]
18   - expected3 = [w, y, x, x + y, w*(x + y), z, z + w*(x + y)]
19   - expected4 = [w, x, y, x + y, w*(x + y), z, z + w*(x + y)]
20   - expected5 = [x, y, x + y, w, w*(x + y), x, x + w*(x + y)]
21   - expected6 = [y, x, x + y, w, w*(x + y), x, x + w*(x + y)]
22   - assert list(postorder_traversal(expr)) in [expected1, expected2,
23   - expected3, expected4,
24   - expected5, expected6]
25   -
26   - expr = Piecewise((x,x<1),(x**2,True))
27   - assert list(postorder_traversal(expr)) == [
28   - x, x, 1, x < 1, ExprCondPair(x, x < 1), x, 2, x**2,
29   - ExprCondPair.true_sentinel,
  15 + expr = z + w*(x+y)
  16 + expected = [z, w, x, y, x + y, w*(x + y), w*(x + y) + z]
  17 + assert list(postorder_traversal(expr, key=default_sort_key)) == expected
  18 +
  19 + expr = Piecewise((x, x < 1), (x**2, True))
  20 + expected = [
  21 + x, 1, x, x < 1, ExprCondPair(x, x < 1),
  22 + ExprCondPair.true_sentinel, 2, x, x**2,
30 23 ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True))
31   - ]
  24 + ]
  25 + assert list(postorder_traversal(expr, key=default_sort_key)) == expected
  26 + assert list(postorder_traversal([expr], key=default_sort_key)) == expected + [[expr]]
32 27
33   - assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == [
34   - x, 2, x**2, x, 0, 1, Tuple(x, 0, 1),
  28 + assert list(postorder_traversal(Integral(x**2, (x, 0, 1)),
  29 + key=default_sort_key)) == [
  30 + 2, x, x**2, 0, 1, x, Tuple(x, 0, 1),
35 31 Integral(x**2, Tuple(x, 0, 1))
36 32 ]
37 33 assert list(postorder_traversal(('abc', ('d', 'ef')))) == [

0 comments on commit 9ff4edf

Please sign in to comment.
Something went wrong with that request. Please try again.