publicsympy/sympy

Subversion checkout URL

You can clone with HTTPS or Subversion.

Merge pull request #1411 from smichr/rand

`Randomization fixes`
commit 9ff4edf72e293ab00d7204fb0309b7ab90acdf44 2 parents 262bdc4 + bf84abd
smichr authored
40 sympy/core/basic.py
 @@ -1501,44 +1501,54 @@ class preorder_traversal(object): 1501 1501 .args, which in many cases can be arbitrary. 1502 1502 1503 1503 Parameters 1504 - ---------- 1504 + ========== 1505 1505 node : sympy expression 1506 1506 The expression to traverse. 1507 + key : (default None) sort key 1508 + The key used to sort args of Basic objects. When None, args of Basic 1509 + objects are processed in arbitrary order. 1507 1510 1508 1511 Yields 1509 - ------ 1512 + ====== 1510 1513 subtree : sympy expression 1511 1514 All of the subtrees in the tree. 1512 1515 1513 1516 Examples 1514 - -------- 1517 + ======== 1515 1518 >>> from sympy import symbols 1519 + >>> from sympy import symbols, default_sort_key 1516 1520 >>> from sympy.core.basic import preorder_traversal 1517 1521 >>> x, y, z = symbols('x y z') 1518 - >>> list(preorder_traversal(z*(x+y))) in ( # any of these are possible 1519 - ... [z*(x + y), z, x + y, x, y], [z*(x + y), z, x + y, y, x], 1520 - ... [z*(x + y), x + y, x, y, z], [z*(x + y), x + y, y, x, z]) 1521 - True 1522 - >>> list(preorder_traversal((x, (y, z)))) 1523 - [(x, (y, z)), x, (y, z), y, z] 1522 + 1523 + The nodes are returned in the order that they are encountered unless key 1524 + is given. 1525 + 1526 + >>> list(preorder_traversal((x + y)*z, key=None)) # doctest: +SKIP 1527 + [z*(x + y), z, x + y, y, x] 1528 + >>> list(preorder_traversal((x + y)*z, key=default_sort_key)) 1529 + [z*(x + y), z, x + y, x, y] 1524 1530 1525 1531 """ 1526 - def __init__(self, node): 1532 + def __init__(self, node, key=None): 1527 1533 self._skip_flag = False 1528 - self._pt = self._preorder_traversal(node) 1534 + self._pt = self._preorder_traversal(node, key) 1529 1535 1530 - def _preorder_traversal(self, node): 1536 + def _preorder_traversal(self, node, key): 1531 1537 yield node 1532 1538 if self._skip_flag: 1533 1539 self._skip_flag = False 1534 1540 return 1535 1541 if isinstance(node, Basic): 1536 - for arg in node.args: 1537 - for subtree in self._preorder_traversal(arg): 1542 + args = node.args 1543 + if key: 1544 + args = list(args) 1545 + args.sort(key=key) 1546 + for arg in args: 1547 + for subtree in self._preorder_traversal(arg, key): 1538 1548 yield subtree 1539 1549 elif iterable(node): 1540 1550 for item in node: 1541 - for subtree in self._preorder_traversal(item): 1551 + for subtree in self._preorder_traversal(item, key): 1542 1552 yield subtree 1543 1553 1544 1554 def skip(self):
40 sympy/core/function.py
 @@ -1347,7 +1347,7 @@ class Subs(Expr): 1347 1347 1348 1348 An example with several variables: 1349 1349 1350 - >>> Subs(f(x)*sin(y)+z, (x, y), (0, 1)) 1350 + >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1)) 1351 1351 Subs(z + f(_x)*sin(_y), (_x, _y), (0, 1)) 1352 1352 >>> _.doit() 1353 1353 z + f(0)*sin(1) @@ -1359,14 +1359,12 @@ def __new__(cls, expr, variables, point, **assumptions): 1359 1359 variables = Tuple(*sympify(variables)) 1360 1360 1361 1361 if uniq(variables) != variables: 1362 - repeated = repeated = [ v for v in set(variables) 1362 + repeated = [ v for v in set(variables) 1363 1363 if list(variables).count(v) > 1 ] 1364 1364 raise ValueError('cannot substitute expressions %s more than ' 1365 1365 'once.' % repeated) 1366 1366 1367 - if not is_sequence(point, Tuple): 1368 - point = [point] 1369 - point = Tuple(*sympify(point)) 1367 + point = Tuple(*sympify(point if is_sequence(point, Tuple) else [point])) 1370 1368 1371 1369 if len(point) != len(variables): 1372 1370 raise ValueError('Number of point values must be the same as ' @@ -1417,32 +1415,16 @@ def free_symbols(self): 1417 1415 def __eq__(self, other): 1418 1416 if not isinstance(other, Subs): 1419 1417 return False 1420 - if (len(self.point) != len(other.point) or 1421 - self.free_symbols != other.free_symbols or 1422 - sorted(self.point) != sorted(other.point)): 1418 + 1419 + if len(self.expr.free_symbols) != len(other.expr.free_symbols): 1423 1420 return False 1424 1421 1425 - # non-repeated point args 1426 - selfargs = [ v[0] for v in sorted(zip(self.variables, self.point), 1427 - key = lambda v: v[1]) if list(self.point.args).count(v[1]) == 1 ] 1428 - otherargs = [ v[0] for v in sorted(zip(other.variables, other.point), 1429 - key = lambda v: v[1]) if list(other.point.args).count(v[1]) == 1 ] 1430 - # find repeated point values and subs each associated variable 1431 - # for a single symbol 1432 - selfrepargs = [] 1433 - otherrepargs = [] 1434 - if uniq(self.point) != self.point: 1435 - repeated = uniq([ v for v in self.point if 1436 - list(self.point.args).count(v) > 1 ]) 1437 - repswap = dict(zip(repeated, [ C.Dummy() for _ in 1438 - xrange(len(repeated)) ])) 1439 - selfrepargs = [ (self.variables[i], repswap[v]) for i, v in 1440 - enumerate(self.point) if v in repeated ] 1441 - otherrepargs = [ (other.variables[i], repswap[v]) for i, v in 1442 - enumerate(other.point) if v in repeated ] 1443 - 1444 - return self.expr.subs(selfrepargs) == other.expr.subs( 1445 - tuple(zip(otherargs, selfargs))).subs(otherrepargs) 1422 + # replace points with dummies, each unique pt getting its own dummy 1423 + pts = set(self.point.args + other.point.args) 1424 + d = dict([(p, Dummy()) for p in pts]) 1425 + eq = lambda e: \ 1426 + e.expr.xreplace(dict(zip(e.variables, [d[p] for p in e.point]))) 1427 + return eq(self) == eq(other) 1446 1428 1447 1429 def __ne__(self, other): 1448 1430 return not(self == other)
7 sympy/core/tests/test_basic.py
 @@ -3,6 +3,8 @@ 3 3 4 4 from sympy.core.basic import Basic, Atom, preorder_traversal 5 5 from sympy.core.singleton import S, Singleton 6 +from sympy.core.symbol import symbols 7 +from sympy.utilities.misc import default_sort_key 6 8 7 9 from sympy.utilities.pytest import raises 8 10 @@ -117,3 +119,8 @@ def test_preorder_traversal(): 117 119 if i == b2: 118 120 pt.skip() 119 121 assert result == [expr, b21, b2, b1, b3, b2] 122 + 123 + w, x, y, z = symbols('w:z') 124 + expr = z + w*(x+y) 125 + assert list(preorder_traversal([expr], key=default_sort_key)) == \ 126 + [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
6 sympy/core/tests/test_functions.py
 @@ -136,6 +136,12 @@ def test_Lambda_equality(): 136 136 def test_Subs(): 137 137 assert Subs(f(x), x, 0).doit() == f(0) 138 138 assert Subs(f(x**2), x**2, 0).doit() == f(0) 139 + assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \ 140 + Subs(f(x, y, z), (x, y, z), (0, 0, 1)) 141 + assert Subs(f(x, y), (x, y, z), (0, 1, 1)) == \ 142 + Subs(f(x, y), (x, y, z), (0, 1, 2)) 143 + assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \ 144 + Subs(f(x, y) + z, (x, y, z), (0, 1, 0)) 139 145 assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1) 140 146 assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1) 141 147 raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
73 sympy/solvers/solvers.py
 @@ -434,10 +434,10 @@ def solve(f, *symbols, **flags): 434 434 [3] 435 435 >>> solve(Poly(x - 3), x) 436 436 [3] 437 - >>> solve(x**2 - y**2, x) 438 - [-y, y] 439 - >>> solve(x**4 - 1, x) 440 - [-1, 1, -I, I] 437 + >>> set(solve(x**2 - y**2, x)) 438 + set([-y, y]) 439 + >>> set(solve(x**4 - 1, x)) 440 + set([-1, 1, -I, I]) 441 441 442 442 * single expression with no symbol that is in the expression 443 443 @@ -455,9 +455,9 @@ def solve(f, *symbols, **flags): 455 455 456 456 >>> solve(x - 3) 457 457 [3] 458 - >>> solve(x**2 - y**2) 458 + >>> solve(x**2 - y**2) # doctest: +SKIP 459 459 [{x: -y}, {x: y}] 460 - >>> solve(z**2*x**2 - z**2*y**2) 460 + >>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP 461 461 [{x: -y}, {x: y}] 462 462 >>> solve(z**2*x - z**2*y**2) 463 463 [{x: y**2}] @@ -473,8 +473,8 @@ def solve(f, *symbols, **flags): 473 473 [x + f(x)] 474 474 >>> solve(f(x).diff(x) - f(x) - x, f(x)) 475 475 [-x + Derivative(f(x), x)] 476 - >>> solve(x + exp(x)**2, exp(x)) 477 - [-sqrt(-x), sqrt(-x)] 476 + >>> set(solve(x + exp(x)**2, exp(x))) 477 + set([-sqrt(-x), sqrt(-x)]) 478 478 479 479 * To solve for a *symbol* implicitly, use 'implicit=True': 480 480 @@ -501,17 +501,17 @@ def solve(f, *symbols, **flags): 501 501 502 502 * that are nonlinear 503 503 504 - >>> solve((a + b)*x - b**2 + 2, a, b) 505 - [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))] 504 + >>> set(solve((a + b)*x - b**2 + 2, a, b)) 505 + set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]) 506 506 507 507 * if there is no linear solution then the first successful 508 508 attempt for a nonlinear solution will be returned 509 509 510 - >>> solve(x**2 - y**2, x, y) 510 + >>> solve(x**2 - y**2, x, y) # doctest: +SKIP 511 511 [{x: -y}, {x: y}] 512 512 >>> solve(x**2 - y**2/exp(x), x, y) 513 513 [{x: 2*LambertW(y/2)}] 514 - >>> solve(x**2 - y**2/exp(x), y, x) 514 + >>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP 515 515 [{y: -x*exp(x/2)}, {y: x*exp(x/2)}] 516 516 517 517 * iterable of one or more of the above @@ -543,8 +543,8 @@ def solve(f, *symbols, **flags): 543 543 544 544 * when the system is not linear 545 545 546 - >>> solve([x**2 + y -2, y**2 - 4], x, y) 547 - [(-2, -2), (0, 2), (0, 2), (2, -2)] 546 + >>> set(solve([x**2 + y -2, y**2 - 4], x, y)) 547 + set([(-2, -2), (0, 2), (2, -2)]) 548 548 549 549 * if no symbols are given, all free symbols will be selected and a list 550 550 of mappings returned @@ -659,7 +659,7 @@ def _sympified_list(w): 659 659 # we do this to make the results returned canonical in case f 660 660 # contains a system of nonlinear equations; all other cases should 661 661 # be unambiguous 662 - symbols = sorted(symbols, key=lambda i: i.sort_key()) 662 + symbols = sorted(symbols, key=default_sort_key) 663 663 664 664 # we can solve for Function and Derivative instances by replacing them 665 665 # with Dummy symbols or functions @@ -1287,9 +1287,10 @@ def _solve_system(exprs, symbols, **flags): 1287 1287 else: 1288 1288 if len(symbols) != len(polys): 1289 1289 from sympy.utilities.iterables import subsets 1290 - free = list(reduce(set.union, 1291 - [p.free_symbols for p in polys], set() 1292 - ).intersection(symbols)) 1290 + from sympy.core.compatibility import set_union 1291 + 1292 + free = set_union(*[p.free_symbols for p in polys]) 1293 + free = list(free.intersection(symbols)) 1293 1294 free.sort(key=default_sort_key) 1294 1295 for syms in subsets(free, len(polys)): 1295 1296 try: @@ -1341,16 +1342,26 @@ def _solve_system(exprs, symbols, **flags): 1341 1342 result = [result] 1342 1343 else: 1343 1344 result = [{}] 1345 + 1346 + def _ok_syms(e, sort=False): 1347 + rv = (e.free_symbols - solved_syms) & legal 1348 + if sort: 1349 + rv = list(rv) 1350 + rv.sort(key=default_sort_key) 1351 + return rv 1352 + 1344 1353 solved_syms = set(solved_syms) # set of symbols we have solved for 1345 1354 legal = set(symbols) # what we are interested in 1346 1355 simplify_flag = flags.get('simplify', None) 1347 1356 do_simplify = flags.get('simplify', True) 1348 - # sort so equation with the fewest potential symbols is first; break ties with 1349 - # count_ops and sort_key 1350 - short = sift(failed, lambda x: len((x.free_symbols - solved_syms) & legal)) 1357 + # sort so equation with the fewest potential symbols is first; 1358 + # break ties with count_ops and default_sort_key 1359 + short = sift(failed, lambda x: len(_ok_syms(x))) 1351 1360 failed = [] 1352 - for k in sorted(short): 1353 - failed.extend(sorted(short[k], key=lambda x: (x.count_ops(), x.sort_key()))) 1361 + for k in sorted(short, key=default_sort_key): 1362 + failed.extend(sorted(sorted(short[k], 1363 + key=lambda x: x.count_ops()), 1364 + key=default_sort_key)) 1354 1365 for eq in failed: 1355 1366 newresult = [] 1356 1367 got_s = None @@ -1365,7 +1376,7 @@ def _solve_system(exprs, symbols, **flags): 1365 1376 continue 1366 1377 # search for a symbol amongst those available that 1367 1378 # can be solved for 1368 - ok_syms = (eq2.free_symbols - solved_syms) & legal 1379 + ok_syms = _ok_syms(eq2, sort=True) 1369 1380 if not ok_syms: 1370 1381 break # skip as it's independent of desired symbols 1371 1382 for s in ok_syms: @@ -1591,7 +1602,7 @@ def solve_linear_system(system, *symbols, **flags): 1591 1602 matrix = system[:, :] 1592 1603 syms = list(symbols) 1593 1604 1594 - i, m = 0, matrix.cols-1 # don't count augmentation 1605 + i, m = 0, matrix.cols - 1 # don't count augmentation 1595 1606 1596 1607 while i < matrix.rows: 1597 1608 if i == m: @@ -1611,12 +1622,12 @@ def solve_linear_system(system, *symbols, **flags): 1611 1622 break 1612 1623 else: 1613 1624 if matrix[i, m]: 1614 - # we need to know this this is always zero or not. We 1625 + # we need to know if this is always zero or not. We 1615 1626 # assume that if there are free symbols that it is not 1616 1627 # identically zero (or that there is more than one way 1617 1628 # to make this zero. Otherwise, if there are none, this 1618 1629 # is a constant and we assume that it does not simplify 1619 - # to zero XXX are there better ways to test this/ 1630 + # to zero XXX are there better ways to test this? 1620 1631 if not matrix[i, m].free_symbols: 1621 1632 return None # no solution 1622 1633 @@ -1664,7 +1675,7 @@ def solve_linear_system(system, *symbols, **flags): 1664 1675 # divide all elements in the current row by the pivot 1665 1676 matrix.row(i, lambda x, _: x * pivot_inv) 1666 1677 1667 - for k in xrange(i+1, matrix.rows): 1678 + for k in xrange(i + 1, matrix.rows): 1668 1679 if matrix[k, i]: 1669 1680 coeff = matrix[k, i] 1670 1681 @@ -1689,7 +1700,7 @@ def solve_linear_system(system, *symbols, **flags): 1689 1700 content = matrix[k, m] 1690 1701 1691 1702 # run back-substitution for variables 1692 - for j in xrange(k+1, m): 1703 + for j in xrange(k + 1, m): 1693 1704 content -= matrix[k, j]*solutions[syms[j]] 1694 1705 1695 1706 if do_simplify: @@ -1703,13 +1714,13 @@ def solve_linear_system(system, *symbols, **flags): 1703 1714 elif len(syms) > matrix.rows: 1704 1715 # this system will have infinite number of solutions 1705 1716 # dependent on exactly len(syms) - i parameters 1706 - k, solutions = i-1, {} 1717 + k, solutions = i - 1, {} 1707 1718 1708 1719 while k >= 0: 1709 1720 content = matrix[k, m] 1710 1721 1711 1722 # run back-substitution for variables 1712 - for j in xrange(k+1, i): 1723 + for j in xrange(k + 1, i): 1713 1724 content -= matrix[k, j]*solutions[syms[j]] 1714 1725 1715 1726 # run back-substitution for parameters
209 sympy/solvers/tests/test_solvers.py
 @@ -30,9 +30,6 @@ def guess_solve_strategy(eq, symbol): 30 30 return False 31 31 32 32 def test_guess_poly(): 33 - """ 34 - See solvers.guess_solve_strategy 35 - """ 36 33 # polynomial equations 37 34 assert guess_solve_strategy( S(4), x ) #== GS_POLY 38 35 assert guess_solve_strategy( x, x ) #== GS_POLY @@ -76,8 +73,8 @@ def test_guess_transcendental(): 76 73 77 74 def test_solve_args(): 78 75 #implicit symbol to solve for 79 - assert set(int(tmp) for tmp in solve(x**2-4)) == set([2,-2]) 80 - assert solve([x+y-3,x-y-5]) == {x: 4, y: -1} 76 + assert set(solve(x**2 - 4)) == set([S(2), -S(2)]) 77 + assert solve([x + y - 3, x - y - 5]) == {x: 4, y: -1} 81 78 #no symbol to solve for 82 79 assert solve(42) == [] 83 80 assert solve([1, 2]) == [] @@ -118,14 +115,14 @@ def test_solve_args(): 118 115 assert solve((x + y - 2, 2*x + 2*y - 4)) == {x: -y + 2} 119 116 120 117 def test_solve_polynomial1(): 121 - assert solve(3*x-2, x) == [Rational(2,3)] 122 - assert solve(Eq(3*x, 2), x) == [Rational(2,3)] 118 + assert solve(3*x-2, x) == [Rational(2, 3)] 119 + assert solve(Eq(3*x, 2), x) == [Rational(2, 3)] 123 120 124 - assert solve(x**2-1, x) in [[-1, 1], [1, -1]] 125 - assert solve(Eq(x**2, 1), x) in [[-1, 1], [1, -1]] 121 + assert set(solve(x**2 - 1, x)) == set([-S(1), S(1)]) 122 + assert set(solve(Eq(x**2, 1), x)) == set([-S(1), S(1)]) 126 123 127 - assert solve( x - y**3, x) == [y**3] 128 - assert sorted(solve( x - y**3, y)) == sorted([ 124 + assert solve(x - y**3, x) == [y**3] 125 + assert set(solve(x - y**3, y)) == set([ 129 126 (-x**Rational(1,3))/2 + I*sqrt(3)*x**Rational(1,3)/2, 130 127 x**Rational(1,3), 131 128 (-x**Rational(1,3))/2 - I*sqrt(3)*x**Rational(1,3)/2, @@ -149,8 +146,8 @@ def test_solve_polynomial1(): 149 146 S(4), 150 147 -2 - 3**Rational(1,2) ]) 151 148 152 - assert sorted(solve((x**2 - 1)**2 - a, x)) == \ 153 - sorted([sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), 149 + assert set(solve((x**2 - 1)**2 - a, x)) == \ 150 + set([sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), 154 151 sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))]) 155 152 156 153 def test_solve_polynomial2(): @@ -225,7 +222,7 @@ def test_tsolve(): 225 222 assert set(solve((a*x+b)*(exp(x)-3), x)) == set([-b/a, log(3)]) 226 223 assert solve(cos(x)-y, x) == [acos(y)] 227 224 assert solve(2*cos(x)-y,x)== [acos(y/2)] 228 - assert solve(Eq(cos(x), sin(x)), x) == [-3*pi/4, pi/4] 225 + assert set(solve(Eq(cos(x), sin(x)), x)) == set([-3*pi/4, pi/4]) 229 226 230 227 assert set(solve(exp(x) + exp(-x) - y, x)) == set([ 231 228 log(y/2 - sqrt(y**2 - 4)/2), @@ -240,8 +237,8 @@ def test_tsolve(): 240 237 assert solve(x+2**x, x) == [-LambertW(log(2))/log(2)] 241 238 assert solve(3*x+5+2**(-5*x+3), x) in [ 242 239 [-((25*log(2) - 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2)))], 243 - [Rational(-5, 3) + LambertW(log(2**(-10240*2**(Rational(1, 3))/3)))/(5*log(2))], 244 - [-Rational(5,3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))], 240 + [-Rational(5, 3) + LambertW(log(2**(-10240*2**(Rational(1, 3))/3)))/(5*log(2))], 241 + [-Rational(5, 3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))], 245 242 [(-25*log(2) + 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2))], 246 243 [-((25*log(2) - 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3)))/(15*log(2))], 247 244 [-(25*log(2) - 3*LambertW(log(2**(-10240*2**Rational(1, 3)/3))))/(15*log(2))], @@ -385,8 +382,10 @@ def test_issue_1694(): 385 382 eq = 4*3**(5*x + 2) - 7 386 383 ans = solve(eq, x) 387 384 assert len(ans) == 5 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) 388 - assert solve(log(x**2) - y**2/exp(x), x, y) == [{y: -sqrt(exp(x)*log(x**2))}, 389 - {y: sqrt(exp(x)*log(x**2))}] 385 + assert solve(log(x**2) - y**2/exp(x), x, y, set=True) == \ 386 + ([y], set([ 387 + (-sqrt(exp(x)*log(x**2)),), 388 + (sqrt(exp(x)*log(x**2)),)])) 390 389 assert solve(x**2*z**2 - z**2*y**2) in ([{x: y}, {x: -y}], [{x: -y}, {x: y}]) 391 390 assert solve((x - 1)/(1 + 1/(x - 1))) == [] 392 391 assert solve(x**(y*z) - x, x) == [1] @@ -401,7 +400,7 @@ def test_issue_1694(): 401 400 # 1387 402 401 assert solve(2*x/(x + 2) - 1,x) == [2] 403 402 # 1397 404 - assert solve((x**2/(7 - x)).diff(x)) == [0, 14] 403 + assert set(solve((x**2/(7 - x)).diff(x))) == set([S(0), S(14)]) 405 404 # 1596 406 405 f = Function('f') 407 406 assert solve((3 - 5*x/f(x))*f(x), f(x)) == [5*x/3] @@ -409,22 +408,23 @@ def test_issue_1694(): 409 408 assert solve(1/(5 + x)**(S(1)/5) - 9, x) == [-295244/S(59049)] 410 409 411 410 assert solve(sqrt(x) + sqrt(sqrt(x)) - 4) == [-9*sqrt(17)/2 + 49*S.Half] 412 - assert solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4)) in \ 411 + assert set(solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4))) in \ 413 412 [ 414 - [2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)], 415 - [log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)], 413 + set([2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)]), 414 + set([log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)]), 416 415 ] 417 - assert solve(Poly(exp(x) + exp(-x) - 4)) == [log(-sqrt(3) + 2), log(sqrt(3) + 2)] 418 - assert solve(x**y + x**(2*y) - 1, x) == \ 419 - [(-S.Half + sqrt(5)/2)**(1/y), (-S.Half - sqrt(5)/2)**(1/y)] 416 + assert set(solve(Poly(exp(x) + exp(-x) - 4))) == \ 417 + set([log(-sqrt(3) + 2), log(sqrt(3) + 2)]) 418 + assert set(solve(x**y + x**(2*y) - 1, x)) == \ 419 + set([(-S.Half + sqrt(5)/2)**(1/y), (-S.Half - sqrt(5)/2)**(1/y)]) 420 420 421 421 assert solve(exp(x/y)*exp(-z/y) - 2, y) == [(x - z)/log(2)] 422 422 assert solve(x**z*y**z - 2, z) in [[log(2)/(log(x) + log(y))], [log(2)/(log(x*y))]] 423 423 # if you do inversion too soon then multiple roots as for the following will 424 424 # be missed, e.g. if exp(3*x) = exp(3) -> 3*x = 3 425 425 E = S.Exp1 426 - assert solve(exp(3*x) - exp(3), x) == \ 427 - [1, log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)] 426 + assert set(solve(exp(3*x) - exp(3), x)) == \ 427 + set([S(1), log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)]) 428 428 429 429 def test_issue_2098(): 430 430 x = Symbol('x', real=True) @@ -461,9 +461,9 @@ def test_checking(): 461 461 def test_issue_1572_1364_1368(): 462 462 assert solve((sqrt(x**2 - 1) - 2)) in ([sqrt(5), -sqrt(5)], 463 463 [-sqrt(5), sqrt(5)]) 464 - assert solve((2**exp(y**2/x) + 2)/(x**2 + 15), y) == [ 464 + assert set(solve((2**exp(y**2/x) + 2)/(x**2 + 15), y)) == set([ 465 465 -sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi)), 466 - sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi))] 466 + sqrt(x)*sqrt(-log(log(2)) + log(log(2) + I*pi))]) 467 467 468 468 C1, C2 = symbols('C1 C2') 469 469 f = Function('f') @@ -480,65 +480,65 @@ def test_issue_1572_1364_1368(): 480 480 assert solve(1 - log(a + 4*x**2), x) in ( 481 481 [-sqrt(-a + E)/2, sqrt(-a + E)/2], 482 482 [sqrt(-a + E)/2, -sqrt(-a + E)/2],) 483 - assert solve((a**2 + 1) * (sin(a*x) + cos(a*x)), x) == [-pi/(4*a), 3*pi/(4*a)] 483 + assert set(solve((a**2 + 1) * (sin(a*x) + cos(a*x)), x)) == set([-pi/(4*a), 3*pi/(4*a)]) 484 484 assert solve(3 - (sinh(a*x) + cosh(a*x)), x) == [2*atanh(S.Half)/a] 485 - assert solve(3-(sinh(a*x) + cosh(a*x)**2), x) == \ 486 - [ 485 + assert set(solve(3-(sinh(a*x) + cosh(a*x)**2), x)) == \ 486 + set([ 487 487 2*atanh(-1 + sqrt(2))/a, 488 488 2*atanh(S(1)/2 + sqrt(5)/2)/a, 489 489 2*atanh(-sqrt(2) - 1)/a, 490 490 2*atanh(-sqrt(5)/2 + S(1)/2)/a 491 - ] 491 + ]) 492 492 assert solve(atan(x) - 1) == [tan(1)] 493 493 494 494 def test_issue_2033(): 495 495 r, t = symbols('r,t') 496 - assert solve([r - x**2 - y**2, tan(t) - y/x], [x, y]) == \ 497 - [ 496 + assert set(solve([r - x**2 - y**2, tan(t) - y/x], [x, y])) == \ 497 + set([ 498 498 (-sqrt(r*sin(t)**2)/tan(t), -sqrt(r*sin(t)**2)), 499 - (sqrt(r*sin(t)**2)/tan(t), sqrt(r*sin(t)**2))] 499 + (sqrt(r*sin(t)**2)/tan(t), sqrt(r*sin(t)**2))]) 500 500 assert solve([exp(x) - sin(y), 1/y - 3], [x, y]) == \ 501 501 [(log(sin(S(1)/3)), S(1)/3)] 502 502 assert solve([exp(x) - sin(y), 1/exp(y) - 3], [x, y]) == \ 503 503 [(log(-sin(log(3))), -log(3))] 504 - assert solve([exp(x) - sin(y), y**2 - 4], [x, y]) == \ 505 - [(log(-sin(2)), -2), (log(sin(2)), 2)] 504 + assert set(solve([exp(x) - sin(y), y**2 - 4], [x, y])) == \ 505 + set([(log(-sin(2)), -S(2)), (log(sin(2)), S(2))]) 506 506 eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] 507 - assert solve(eqs) == \ 508 - [ 509 - {x: log(-sqrt(-z**2 - sin(log(3)))), y: -log(3)}, 510 - {x: log(sqrt(-z**2 - sin(log(3)))), y: -log(3)}] 511 - assert solve(eqs, x, z) == \ 512 - [ 513 - {x: log(-sqrt(-z**2 + sin(y)))}, 514 - {x: log(sqrt(-z**2 + sin(y)))}] 515 - assert solve(eqs, x, y) == \ 516 - [ 507 + assert solve(eqs, set=True) == \ 508 + ([x, y], set([ 517 509 (log(-sqrt(-z**2 - sin(log(3)))), -log(3)), 518 - (log(sqrt(-z**2 - sin(log(3)))), -log(3))] 519 - assert solve(eqs, y, z) == \ 520 - [ 510 + (log(sqrt(-z**2 - sin(log(3)))), -log(3))])) 511 + assert solve(eqs, x, z, set=True) == \ 512 + ([x], set([ 513 + (log(-sqrt(-z**2 + sin(y))),), 514 + (log(sqrt(-z**2 + sin(y))),)])) 515 + assert set(solve(eqs, x, y)) == \ 516 + set([ 517 + (log(-sqrt(-z**2 - sin(log(3)))), -log(3)), 518 + (log(sqrt(-z**2 - sin(log(3)))), -log(3))]) 519 + assert set(solve(eqs, y, z)) == \ 520 + set([ 521 521 (-log(3), -sqrt(-exp(2*x) - sin(log(3)))), 522 - (-log(3), sqrt(-exp(2*x) - sin(log(3))))] 522 + (-log(3), sqrt(-exp(2*x) - sin(log(3))))]) 523 523 eqs = [exp(x)**2 - sin(y) + z, 1/exp(y) - 3] 524 - assert solve(eqs) == \ 524 + assert solve(eqs, set=True) == ([x, y], set( 525 525 [ 526 - {x: log(-sqrt(-z - sin(log(3)))), y: -log(3)}, 527 - {x: log(sqrt(-z - sin(log(3)))), y: -log(3)}] 528 - assert solve(eqs, x, z) == \ 526 + (log(-sqrt(-z - sin(log(3)))), -log(3)), 527 + (log(sqrt(-z - sin(log(3)))), -log(3))])) 528 + assert solve(eqs, x, z, set=True) == ([x], set( 529 529 [ 530 - {x: log(-sqrt(-z + sin(y)))}, 531 - {x: log(sqrt(-z + sin(y)))}] 532 - assert solve(eqs, x, y) == \ 530 + (log(-sqrt(-z + sin(y))),), 531 + (log(sqrt(-z + sin(y))),)])) 532 + assert set(solve(eqs, x, y)) == set( 533 533 [ 534 534 (log(-sqrt(-z - sin(log(3)))), -log(3)), 535 - (log(sqrt(-z - sin(log(3)))), -log(3))] 535 + (log(sqrt(-z - sin(log(3)))), -log(3))]) 536 536 assert solve(eqs, z, y) == \ 537 537 [(-exp(2*x) - sin(log(3)), -log(3))] 538 - assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4)) == \ 539 - [{x: 1, y: 3}, {x: 3, y: 1}] 540 - assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y) == \ 541 - [(1, 3), (3, 1)] 538 + assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), set=True) == ( 539 + [x, y], set([(S(1), S(3)), (S(3), S(1))])) 540 + assert set(solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y)) == \ 541 + set([(S(1), S(3)), (S(3), S(1))]) 542 542 543 543 def test_issue_2236(): 544 544 lam, a0, conc = symbols('lam a0 conc') @@ -559,11 +559,13 @@ def test_issue_2236_float(): 559 559 assert len(solve(eqs, sym, rational=False, check=False, simplify=False)) == 2 560 560 561 561 def test_issue_2668(): 562 - assert solve([x**2 + y + 4], [x]) == [(-sqrt(-y - 4),), (sqrt(-y - 4),)] 562 + assert set(solve([x**2 + y + 4], [x])) == \ 563 + set([(-sqrt(-y - 4),), (sqrt(-y - 4),)]) 563 564 564 565 def test_polysys(): 565 - assert solve([x**2 + 2/y - 2 , x + y - 3], [x, y]) == \ 566 - [(1, 2), (1 + sqrt(5), 2 - sqrt(5)), (1 - sqrt(5), 2 + sqrt(5))] 566 + assert set(solve([x**2 + 2/y - 2 , x + y - 3], [x, y])) == \ 567 + set([(S(1), S(2)), (1 + sqrt(5), 2 - sqrt(5)), 568 + (1 - sqrt(5), 2 + sqrt(5))]) 567 569 assert solve([x**2 + y - 2, x**2 + y]) == [] 568 570 # the ordering should be whatever the user requested 569 571 assert solve([x**2 + y - 3, x - y - 4], (x, y)) != solve([x**2 + y - 3, x - y - 4], (y, x)) @@ -626,10 +628,11 @@ def s_check(rv, ans): 626 628 eq = sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) 627 629 assert check(unrad(eq), 628 630 (16*x**3 - 9*x**2, [], [])) 629 - assert solve(eq, check=False) == [0, S(9)/16] 631 + assert set(solve(eq, check=False)) == set([S(0), S(9)/16]) 630 632 assert solve(eq) == [] 631 633 # but this one really does have those solutions 632 - assert solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x))) == [0, S(9)/16] 634 + assert set(solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x)))) == \ 635 + set([S.Zero, S(9)/16]) 633 636 634 637 '''real_root changes the value of the result if the solution is 635 638 simplified; `a` in the text below is the root that is not 4/5: @@ -661,8 +664,8 @@ def s_check(rv, ans): 661 664 ans = solve(sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) 662 665 assert all(abs(eq.subs(x, i).n()) < 1e-10 for i in (ra, rb)) and \ 663 666 len(ans) == 2 and \ 664 - sorted([i.n(chop=True) for i in ans]) == \ 665 - sorted([i.n(chop=True) for i in (ra, rb)]) 667 + set([i.n(chop=True) for i in ans]) == \ 668 + set([i.n(chop=True) for i in (ra, rb)]) 666 669 667 670 ans = solve(sqrt(x) + sqrt(x + 1) - \ 668 671 sqrt(1 - x) - sqrt(2 + x)) @@ -694,15 +697,15 @@ def s_check(rv, ans): 694 697 assert solve(Eq(x, sqrt(x + 6))) == [3] 695 698 assert solve(Eq(x + sqrt(x - 4), 4)) == [4] 696 699 assert solve(Eq(1, x + sqrt(2*x - 3))) == [] 697 - assert solve(Eq(sqrt(5*x + 6) - 2, x)) == [-1, 2] 698 - assert solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2)) == [5, 13] 700 + assert set(solve(Eq(sqrt(5*x + 6) - 2, x))) == set([-S(1), S(2)]) 701 + assert set(solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2))) == set([S(5), S(13)]) 699 702 assert solve(Eq(sqrt(x + 7) + 2, sqrt(3 - x))) == [-6] 700 703 # http://www.purplemath.com/modules/solverad.htm 701 704 assert solve((2*x - 5)**Rational(1, 3) - 3) == [16] 702 705 assert solve((x**3 - 3*x**2)**Rational(1, 3) + 1 - x) == [] 703 - assert solve(x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4)) == \ 704 - [-S(1)/2, -S(1)/3] 705 - assert solve(sqrt(2*x**2 - 7) - (3 - x)) == [-8, 2] 706 + assert set(solve(x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4))) == \ 707 + set([-S(1)/2, -S(1)/3]) 708 + assert set(solve(sqrt(2*x**2 - 7) - (3 - x))) == set([-S(8), S(2)]) 706 709 assert solve(sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4)) == [0] 707 710 assert solve(sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1)) == [5] 708 711 assert solve(sqrt(x)*sqrt(x - 7) - 12) == [16] @@ -784,14 +787,13 @@ def test_issue_2750(): 784 787 ) 785 788 786 789 ans = [{ 787 - dI4: -I3 + 3*I5 - 2*Q4, 790 + dQ4: I3 - I5, 788 791 dI1: -4*I2 - 8*I3 - 4*I5 - 6*I6 + 24, 792 + I4: I3 - I5, 789 793 dQ2: I2, 790 - I1: I2 + I3, 791 794 Q2: 2*I3 + 2*I5 + 3*I6, 792 - dQ4: I3 - I5, 793 - I4: I3 - I5, 794 - }] 795 + I1: I2 + I3, 796 + Q4: -I3/2 + 3*I5/2 - dI4/2}] 795 797 assert solve(e, I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4, manual=True) == ans 796 798 # the matrix solver (tested below) doesn't like this because it produces 797 799 # a zero row in the matrix. Is this related to issue 1452? @@ -837,18 +839,21 @@ def test_issue_2802(): 837 839 {f(x): 3*D} 838 840 assert solve([f(x) - 3*f(x).diff(x), f(x)**2 - y + 4], f(x), y) == \ 839 841 [{f(x): 3*D, y: 9*D**2 + 4}] 840 - assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), h(a), g(a)) == \ 841 - [{g(a): -sqrt(h(a)**2 + G/f(a)**2)}, 842 - {g(a): sqrt(h(a)**2 + G/f(a)**2)}] 842 + assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), 843 + h(a), g(a), set=True) == \ 844 + ([g(a)], set([ 845 + (-sqrt(h(a)**2 + G/f(a)**2),), 846 + (sqrt(h(a)**2 + G/f(a)**2),)])) 843 847 args = [f(x).diff(x, 2)*(f(x) + g(x)) - g(x)**2 + 2, f(x), g(x)] 844 - assert solve(*args) == \ 845 - [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))] 848 + assert set(solve(*args)) == \ 849 + set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]) 846 850 eqs = [f(x)**2 + g(x) - 2*f(x).diff(x), g(x)**2 - 4] 847 - assert solve(eqs, f(x), g(x)) == \ 848 - [{g(x): 2, f(x): -sqrt(2*D - 2)}, 849 - {g(x): 2, f(x): sqrt(2*D - 2)}, 850 - {g(x): -2, f(x): -sqrt(2*D + 2)}, 851 - {g(x): -2, f(x): sqrt(2*D + 2)}] 851 + assert solve(eqs, f(x), g(x), set=True) == \ 852 + ([f(x), g(x)], set([ 853 + (-sqrt(2*D - 2), S(2)), 854 + (sqrt(2*D - 2), S(2)), 855 + (-sqrt(2*D + 2), -S(2)), 856 + (sqrt(2*D + 2), -S(2))])) 852 857 853 858 # the underlying problem was in solve_linear that was not masking off 854 859 # anything but a Mul or Add; it now raises an error if it gets anything @@ -872,8 +877,8 @@ def test_issue_2802(): 872 877 assert solve_linear(x + Integral(x, y) - 2, symbols=[x]) == \ 873 878 (x, 2/(y + 1)) 874 879 875 - assert solve(x + exp(x)**2, exp(x)) == \ 876 - [-sqrt(-x), sqrt(-x)] 880 + assert set(solve(x + exp(x)**2, exp(x))) == \ 881 + set([-sqrt(-x), sqrt(-x)]) 877 882 assert solve(x + exp(x), x, implicit=True) == \ 878 883 [-exp(x)] 879 884 assert solve(cos(x) - sin(x), x, implicit=True) == [] @@ -885,8 +890,8 @@ def test_issue_2802(): 885 890 [-x + 3] 886 891 887 892 def test_issue_2813(): 888 - assert solve(x**2 - x - 0.1, rational=True) == \ 889 - [S(1)/2 + sqrt(35)/10, -sqrt(35)/10 + S(1)/2] 893 + assert set(solve(x**2 - x - 0.1, rational=True)) == \ 894 + set([S(1)/2 + sqrt(35)/10, -sqrt(35)/10 + S(1)/2]) 890 895 # [-0.0916079783099616, 1.09160797830996] 891 896 ans = solve(x**2 - x - 0.1, rational=False) 892 897 assert len(ans) == 2 and all(a.is_Number for a in ans) @@ -923,18 +928,18 @@ def test_check_assumptions(): 923 928 assert solve(x**2 - 1) == [1] 924 929 925 930 def test_solve_abs(): 926 - assert solve(abs(x - 7) - 8) == [-1, 15] 931 + assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)]) 927 932 928 933 def test_issue_2957(): 929 934 assert solve(tanh(x + 3)*tanh(x - 3) - 1) == [] 930 - assert solve(tanh(x - 1)*tanh(x + 1) + 1) == [ 935 + assert set(solve(tanh(x - 1)*tanh(x + 1) + 1)) == set([ 931 936 -log(2)/2 + log(-1 - I), 932 937 -log(2)/2 + log(-1 + I), 933 938 -log(2)/2 + log(1 - I), 934 - -log(2)/2 + log(1 + I)] 935 - assert solve((tanh(x + 3)*tanh(x - 3) + 1)**2) == \ 936 - [-log(2)/2 + log(-1 - I), -log(2)/2 + log(-1 + I), 937 - -log(2)/2 + log(1 - I), -log(2)/2 + log(1 + I)] 939 + -log(2)/2 + log(1 + I)]) 940 + assert set(solve((tanh(x + 3)*tanh(x - 3) + 1)**2)) == \ 941 + set([-log(2)/2 + log(-1 - I), -log(2)/2 + log(-1 + I), 942 + -log(2)/2 + log(1 - I), -log(2)/2 + log(1 + I)]) 938 943 939 944 def test_issue_2574(): 940 945 eq = -x + exp(exp(LambertW(log(x)))*LambertW(log(x))) @@ -965,4 +970,4 @@ def test_exclude(): 965 970 966 971 def test_high_order_roots(): 967 972 s = x**5 + 4*x**3 + 3*x**2 + S(7)/4 968 - assert solve(s) == Poly(s*4, domain='ZZ').all_roots() 973 + assert set(solve(s)) == set(Poly(s*4, domain='ZZ').all_roots())
44 sympy/utilities/iterables.py
 @@ -4,6 +4,7 @@ 4 4 from sympy.core import Basic, C 5 5 from sympy.core.compatibility import is_sequence, iterable #logically, these belong here 6 6 from sympy.core.compatibility import product as cartes, combinations, combinations_with_replacement 7 +from sympy.utilities.misc import default_sort_key 7 8 from sympy.utilities.exceptions import SymPyDeprecationWarning 8 9 9 10 def flatten(iterable, levels=None, cls=None): @@ -110,7 +111,7 @@ def group(container, multiple=True): 110 111 111 112 return groups 112 113 113 -def postorder_traversal(node): 114 +def postorder_traversal(node, key=None): 114 115 """ 115 116 Do a postorder traversal of a tree. 116 117 @@ -118,39 +119,46 @@ def postorder_traversal(node): 118 119 fashion. That is, it descends through the tree depth-first to yield all of 119 120 a node's children's postorder traversal before yielding the node itself. 120 121 121 - For an expression, the order of the traversal depends on the order of 122 - .args, which in many cases can be arbitrary. 123 - 124 122 Parameters 125 - ---------- 123 + ========== 126 124 node : sympy expression 127 125 The expression to traverse. 126 + key : (default None) sort key 127 + The key used to sort args of Basic objects. When None, args of Basic 128 + objects are processed in arbitrary order. 128 129 129 130 Returns 130 - ------- 131 + ======= 131 132 subtree : sympy expression 132 133 All of the subtrees in the tree. 133 134 134 135 Examples 135 - -------- 136 - >>> from sympy import symbols 136 + ======== 137 + >>> from sympy import symbols, default_sort_key 137 138 >>> from sympy.utilities.iterables import postorder_traversal 138 - >>> from sympy.abc import x, y, z 139 - >>> list(postorder_traversal(z*(x+y))) in ( # any of these are possible 140 - ... [z, y, x, x + y, z*(x + y)], [z, x, y, x + y, z*(x + y)], 141 - ... [x, y, x + y, z, z*(x + y)], [y, x, x + y, z, z*(x + y)]) 142 - True 143 - >>> list(postorder_traversal((x, (y, z)))) 144 - [x, y, z, (y, z), (x, (y, z))] 139 + >>> from sympy.abc import w, x, y, z 140 + 141 + The nodes are returned in the order that they are encountered unless key 142 + is given. 143 + 144 + >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP 145 + [z, y, x, x + y, z*(x + y), w, w + z*(x + y)] 146 + >>> list(postorder_traversal(w + (x + y)*z, key=default_sort_key)) 147 + [w, z, x, y, x + y, z*(x + y), w + z*(x + y)] 148 + 145 149 146 150 """ 147 151 if isinstance(node, Basic): 148 - for arg in node.args: 149 - for subtree in postorder_traversal(arg): 152 + args = node.args 153 + if key: 154 + args = list(args) 155 + args.sort(key=key) 156 + for arg in args: 157 + for subtree in postorder_traversal(arg, key): 150 158 yield subtree 151 159 elif iterable(node): 152 160 for item in node: 153 - for subtree in postorder_traversal(item): 161 + for subtree in postorder_traversal(item, key): 154 162 yield subtree 155 163 yield node 156 164
34 sympy/utilities/tests/test_iterables.py
 ... ... @@ -1,4 +1,4 @@ 1 -from sympy import symbols, Integral, Tuple, Dummy, Basic 1 +from sympy import symbols, Integral, Tuple, Dummy, Basic, default_sort_key 2 2 from sympy.utilities.iterables import (postorder_traversal, flatten, group, 3 3 take, subsets, variations, cartes, numbered_symbols, dict_merge, 4 4 prefixes, postfixes, sift, topological_sort, rotate_left, rotate_right, @@ -12,26 +12,22 @@ 12 12 w,x,y,z= symbols('w,x,y,z') 13 13 14 14 def test_postorder_traversal(): 15 - expr = z+w*(x+y) 16 - expected1 = [z, w, y, x, x + y, w*(x + y), z + w*(x + y)] 17 - expected2 = [z, w, x, y, x + y, w*(x + y), z + w*(x + y)] 18 - expected3 = [w, y, x, x + y, w*(x + y), z, z + w*(x + y)] 19 - expected4 = [w, x, y, x + y, w*(x + y), z, z + w*(x + y)] 20 - expected5 = [x, y, x + y, w, w*(x + y), x, x + w*(x + y)] 21 - expected6 = [y, x, x + y, w, w*(x + y), x, x + w*(x + y)] 22 - assert list(postorder_traversal(expr)) in [expected1, expected2, 23 - expected3, expected4, 24 - expected5, expected6] 25 - 26 - expr = Piecewise((x,x<1),(x**2,True)) 27 - assert list(postorder_traversal(expr)) == [ 28 - x, x, 1, x < 1, ExprCondPair(x, x < 1), x, 2, x**2, 29 - ExprCondPair.true_sentinel, 15 + expr = z + w*(x+y) 16 + expected = [z, w, x, y, x + y, w*(x + y), w*(x + y) + z] 17 + assert list(postorder_traversal(expr, key=default_sort_key)) == expected 18 + 19 + expr = Piecewise((x, x < 1), (x**2, True)) 20 + expected = [ 21 + x, 1, x, x < 1, ExprCondPair(x, x < 1), 22 + ExprCondPair.true_sentinel, 2, x, x**2, 30 23 ExprCondPair(x**2, True), Piecewise((x, x < 1), (x**2, True)) 31 - ] 24 + ] 25 + assert list(postorder_traversal(expr, key=default_sort_key)) == expected 26 + assert list(postorder_traversal([expr], key=default_sort_key)) == expected + [[expr]] 32 27 33 - assert list(postorder_traversal(Integral(x**2, (x, 0, 1)))) == [ 34 - x, 2, x**2, x, 0, 1, Tuple(x, 0, 1), 28 + assert list(postorder_traversal(Integral(x**2, (x, 0, 1)), 29 + key=default_sort_key)) == [ 30 + 2, x, x**2, 0, 1, x, Tuple(x, 0, 1), 35 31 Integral(x**2, Tuple(x, 0, 1)) 36 32 ] 37 33 assert list(postorder_traversal(('abc', ('d', 'ef')))) == [