# sympy/sympy

Merge pull request #1411 from smichr/rand

`Randomization fixes`
• Loading branch information...
2 parents 262bdc4 + bf84abd commit 9ff4edf72e293ab00d7204fb0309b7ab90acdf44 smichr committed Jul 12, 2012
 @@ -1501,44 +1501,54 @@ class preorder_traversal(object): .args, which in many cases can be arbitrary. Parameters - ---------- + ========== node : sympy expression The expression to traverse. + key : (default None) sort key + The key used to sort args of Basic objects. When None, args of Basic + objects are processed in arbitrary order. Yields - ------ + ====== subtree : sympy expression All of the subtrees in the tree. Examples - -------- + ======== >>> from sympy import symbols + >>> from sympy import symbols, default_sort_key >>> from sympy.core.basic import preorder_traversal >>> x, y, z = symbols('x y z') - >>> list(preorder_traversal(z*(x+y))) in ( # any of these are possible - ... [z*(x + y), z, x + y, x, y], [z*(x + y), z, x + y, y, x], - ... [z*(x + y), x + y, x, y, z], [z*(x + y), x + y, y, x, z]) - True - >>> list(preorder_traversal((x, (y, z)))) - [(x, (y, z)), x, (y, z), y, z] + + The nodes are returned in the order that they are encountered unless key + is given. + + >>> list(preorder_traversal((x + y)*z, key=None)) # doctest: +SKIP + [z*(x + y), z, x + y, y, x] + >>> list(preorder_traversal((x + y)*z, key=default_sort_key)) + [z*(x + y), z, x + y, x, y] """ - def __init__(self, node): + def __init__(self, node, key=None): self._skip_flag = False - self._pt = self._preorder_traversal(node) + self._pt = self._preorder_traversal(node, key) - def _preorder_traversal(self, node): + def _preorder_traversal(self, node, key): yield node if self._skip_flag: self._skip_flag = False return if isinstance(node, Basic): - for arg in node.args: - for subtree in self._preorder_traversal(arg): + args = node.args + if key: + args = list(args) + args.sort(key=key) + for arg in args: + for subtree in self._preorder_traversal(arg, key): yield subtree elif iterable(node): for item in node: - for subtree in self._preorder_traversal(item): + for subtree in self._preorder_traversal(item, key): yield subtree def skip(self):
 @@ -1347,7 +1347,7 @@ class Subs(Expr): An example with several variables: - >>> Subs(f(x)*sin(y)+z, (x, y), (0, 1)) + >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1)) Subs(z + f(_x)*sin(_y), (_x, _y), (0, 1)) >>> _.doit() z + f(0)*sin(1) @@ -1359,14 +1359,12 @@ def __new__(cls, expr, variables, point, **assumptions): variables = Tuple(*sympify(variables)) if uniq(variables) != variables: - repeated = repeated = [ v for v in set(variables) + repeated = [ v for v in set(variables) if list(variables).count(v) > 1 ] raise ValueError('cannot substitute expressions %s more than ' 'once.' % repeated) - if not is_sequence(point, Tuple): - point = [point] - point = Tuple(*sympify(point)) + point = Tuple(*sympify(point if is_sequence(point, Tuple) else [point])) if len(point) != len(variables): raise ValueError('Number of point values must be the same as ' @@ -1417,32 +1415,16 @@ def free_symbols(self): def __eq__(self, other): if not isinstance(other, Subs): return False - if (len(self.point) != len(other.point) or - self.free_symbols != other.free_symbols or - sorted(self.point) != sorted(other.point)): + + if len(self.expr.free_symbols) != len(other.expr.free_symbols): return False - # non-repeated point args - selfargs = [ v[0] for v in sorted(zip(self.variables, self.point), - key = lambda v: v[1]) if list(self.point.args).count(v[1]) == 1 ] - otherargs = [ v[0] for v in sorted(zip(other.variables, other.point), - key = lambda v: v[1]) if list(other.point.args).count(v[1]) == 1 ] - # find repeated point values and subs each associated variable - # for a single symbol - selfrepargs = [] - otherrepargs = [] - if uniq(self.point) != self.point: - repeated = uniq([ v for v in self.point if - list(self.point.args).count(v) > 1 ]) - repswap = dict(zip(repeated, [ C.Dummy() for _ in - xrange(len(repeated)) ])) - selfrepargs = [ (self.variables[i], repswap[v]) for i, v in - enumerate(self.point) if v in repeated ] - otherrepargs = [ (other.variables[i], repswap[v]) for i, v in - enumerate(other.point) if v in repeated ] - - return self.expr.subs(selfrepargs) == other.expr.subs( - tuple(zip(otherargs, selfargs))).subs(otherrepargs) + # replace points with dummies, each unique pt getting its own dummy + pts = set(self.point.args + other.point.args) + d = dict([(p, Dummy()) for p in pts]) + eq = lambda e: \ + e.expr.xreplace(dict(zip(e.variables, [d[p] for p in e.point]))) + return eq(self) == eq(other) def __ne__(self, other): return not(self == other)
 @@ -3,6 +3,8 @@ from sympy.core.basic import Basic, Atom, preorder_traversal from sympy.core.singleton import S, Singleton +from sympy.core.symbol import symbols +from sympy.utilities.misc import default_sort_key from sympy.utilities.pytest import raises @@ -117,3 +119,8 @@ def test_preorder_traversal(): if i == b2: pt.skip() assert result == [expr, b21, b2, b1, b3, b2] + + w, x, y, z = symbols('w:z') + expr = z + w*(x+y) + assert list(preorder_traversal([expr], key=default_sort_key)) == \ + [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
 @@ -136,6 +136,12 @@ def test_Lambda_equality(): def test_Subs(): assert Subs(f(x), x, 0).doit() == f(0) assert Subs(f(x**2), x**2, 0).doit() == f(0) + assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \ + Subs(f(x, y, z), (x, y, z), (0, 0, 1)) + assert Subs(f(x, y), (x, y, z), (0, 1, 1)) == \ + Subs(f(x, y), (x, y, z), (0, 1, 2)) + assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \ + Subs(f(x, y) + z, (x, y, z), (0, 1, 0)) assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1) assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1) raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
 @@ -434,10 +434,10 @@ def solve(f, *symbols, **flags): [3] >>> solve(Poly(x - 3), x) [3] - >>> solve(x**2 - y**2, x) - [-y, y] - >>> solve(x**4 - 1, x) - [-1, 1, -I, I] + >>> set(solve(x**2 - y**2, x)) + set([-y, y]) + >>> set(solve(x**4 - 1, x)) + set([-1, 1, -I, I]) * single expression with no symbol that is in the expression @@ -455,9 +455,9 @@ def solve(f, *symbols, **flags): >>> solve(x - 3) [3] - >>> solve(x**2 - y**2) + >>> solve(x**2 - y**2) # doctest: +SKIP [{x: -y}, {x: y}] - >>> solve(z**2*x**2 - z**2*y**2) + >>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP [{x: -y}, {x: y}] >>> solve(z**2*x - z**2*y**2) [{x: y**2}] @@ -473,8 +473,8 @@ def solve(f, *symbols, **flags): [x + f(x)] >>> solve(f(x).diff(x) - f(x) - x, f(x)) [-x + Derivative(f(x), x)] - >>> solve(x + exp(x)**2, exp(x)) - [-sqrt(-x), sqrt(-x)] + >>> set(solve(x + exp(x)**2, exp(x))) + set([-sqrt(-x), sqrt(-x)]) * To solve for a *symbol* implicitly, use 'implicit=True': @@ -501,17 +501,17 @@ def solve(f, *symbols, **flags): * that are nonlinear - >>> solve((a + b)*x - b**2 + 2, a, b) - [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))] + >>> set(solve((a + b)*x - b**2 + 2, a, b)) + set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]) * if there is no linear solution then the first successful attempt for a nonlinear solution will be returned - >>> solve(x**2 - y**2, x, y) + >>> solve(x**2 - y**2, x, y) # doctest: +SKIP [{x: -y}, {x: y}] >>> solve(x**2 - y**2/exp(x), x, y) [{x: 2*LambertW(y/2)}] - >>> solve(x**2 - y**2/exp(x), y, x) + >>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP [{y: -x*exp(x/2)}, {y: x*exp(x/2)}] * iterable of one or more of the above @@ -543,8 +543,8 @@ def solve(f, *symbols, **flags): * when the system is not linear - >>> solve([x**2 + y -2, y**2 - 4], x, y) - [(-2, -2), (0, 2), (0, 2), (2, -2)] + >>> set(solve([x**2 + y -2, y**2 - 4], x, y)) + set([(-2, -2), (0, 2), (2, -2)]) * if no symbols are given, all free symbols will be selected and a list of mappings returned @@ -659,7 +659,7 @@ def _sympified_list(w): # we do this to make the results returned canonical in case f # contains a system of nonlinear equations; all other cases should # be unambiguous - symbols = sorted(symbols, key=lambda i: i.sort_key()) + symbols = sorted(symbols, key=default_sort_key) # we can solve for Function and Derivative instances by replacing them # with Dummy symbols or functions @@ -1287,9 +1287,10 @@ def _solve_system(exprs, symbols, **flags): else: if len(symbols) != len(polys): from sympy.utilities.iterables import subsets - free = list(reduce(set.union, - [p.free_symbols for p in polys], set() - ).intersection(symbols)) + from sympy.core.compatibility import set_union + + free = set_union(*[p.free_symbols for p in polys]) + free = list(free.intersection(symbols)) free.sort(key=default_sort_key) for syms in subsets(free, len(polys)): try: @@ -1341,16 +1342,26 @@ def _solve_system(exprs, symbols, **flags): result = [result] else: result = [{}] + + def _ok_syms(e, sort=False): + rv = (e.free_symbols - solved_syms) & legal + if sort: + rv = list(rv) + rv.sort(key=default_sort_key) + return rv + solved_syms = set(solved_syms) # set of symbols we have solved for legal = set(symbols) # what we are interested in simplify_flag = flags.get('simplify', None) do_simplify = flags.get('simplify', True) - # sort so equation with the fewest potential symbols is first; break ties with - # count_ops and sort_key - short = sift(failed, lambda x: len((x.free_symbols - solved_syms) & legal)) + # sort so equation with the fewest potential symbols is first; + # break ties with count_ops and default_sort_key + short = sift(failed, lambda x: len(_ok_syms(x))) failed = [] - for k in sorted(short): - failed.extend(sorted(short[k], key=lambda x: (x.count_ops(), x.sort_key()))) + for k in sorted(short, key=default_sort_key): + failed.extend(sorted(sorted(short[k], + key=lambda x: x.count_ops()), + key=default_sort_key)) for eq in failed: newresult = [] got_s = None @@ -1365,7 +1376,7 @@ def _solve_system(exprs, symbols, **flags): continue # search for a symbol amongst those available that # can be solved for - ok_syms = (eq2.free_symbols - solved_syms) & legal + ok_syms = _ok_syms(eq2, sort=True) if not ok_syms: break # skip as it's independent of desired symbols for s in ok_syms: @@ -1591,7 +1602,7 @@ def solve_linear_system(system, *symbols, **flags): matrix = system[:, :] syms = list(symbols) - i, m = 0, matrix.cols-1 # don't count augmentation + i, m = 0, matrix.cols - 1 # don't count augmentation while i < matrix.rows: if i == m: @@ -1611,12 +1622,12 @@ def solve_linear_system(system, *symbols, **flags): break else: if matrix[i, m]: - # we need to know this this is always zero or not. We + # we need to know if this is always zero or not. We # assume that if there are free symbols that it is not # identically zero (or that there is more than one way # to make this zero. Otherwise, if there are none, this # is a constant and we assume that it does not simplify - # to zero XXX are there better ways to test this/ + # to zero XXX are there better ways to test this? if not matrix[i, m].free_symbols: return None # no solution @@ -1664,7 +1675,7 @@ def solve_linear_system(system, *symbols, **flags): # divide all elements in the current row by the pivot matrix.row(i, lambda x, _: x * pivot_inv) - for k in xrange(i+1, matrix.rows): + for k in xrange(i + 1, matrix.rows): if matrix[k, i]: coeff = matrix[k, i] @@ -1689,7 +1700,7 @@ def solve_linear_system(system, *symbols, **flags): content = matrix[k, m] # run back-substitution for variables - for j in xrange(k+1, m): + for j in xrange(k + 1, m): content -= matrix[k, j]*solutions[syms[j]] if do_simplify: @@ -1703,13 +1714,13 @@ def solve_linear_system(system, *symbols, **flags): elif len(syms) > matrix.rows: # this system will have infinite number of solutions # dependent on exactly len(syms) - i parameters - k, solutions = i-1, {} + k, solutions = i - 1, {} while k >= 0: content = matrix[k, m] # run back-substitution for variables - for j in xrange(k+1, i): + for j in xrange(k + 1, i): content -= matrix[k, j]*solutions[syms[j]] # run back-substitution for parameters

#### 0 comments on commit `9ff4edf`

Please sign in to comment.