Permalink
Browse files

Merge pull request #1411 from smichr/rand

Randomization fixes
  • Loading branch information...
2 parents 262bdc4 + bf84abd commit 9ff4edf72e293ab00d7204fb0309b7ab90acdf44 @smichr smichr committed Jul 12, 2012
View
@@ -1501,44 +1501,54 @@ class preorder_traversal(object):
.args, which in many cases can be arbitrary.
Parameters
- ----------
+ ==========
node : sympy expression
The expression to traverse.
+ key : (default None) sort key
+ The key used to sort args of Basic objects. When None, args of Basic
+ objects are processed in arbitrary order.
Yields
- ------
+ ======
subtree : sympy expression
All of the subtrees in the tree.
Examples
- --------
+ ========
>>> from sympy import symbols
+ >>> from sympy import symbols, default_sort_key
>>> from sympy.core.basic import preorder_traversal
>>> x, y, z = symbols('x y z')
- >>> list(preorder_traversal(z*(x+y))) in ( # any of these are possible
- ... [z*(x + y), z, x + y, x, y], [z*(x + y), z, x + y, y, x],
- ... [z*(x + y), x + y, x, y, z], [z*(x + y), x + y, y, x, z])
- True
- >>> list(preorder_traversal((x, (y, z))))
- [(x, (y, z)), x, (y, z), y, z]
+
+ The nodes are returned in the order that they are encountered unless key
+ is given.
+
+ >>> list(preorder_traversal((x + y)*z, key=None)) # doctest: +SKIP
+ [z*(x + y), z, x + y, y, x]
+ >>> list(preorder_traversal((x + y)*z, key=default_sort_key))
+ [z*(x + y), z, x + y, x, y]
"""
- def __init__(self, node):
+ def __init__(self, node, key=None):
self._skip_flag = False
- self._pt = self._preorder_traversal(node)
+ self._pt = self._preorder_traversal(node, key)
- def _preorder_traversal(self, node):
+ def _preorder_traversal(self, node, key):
yield node
if self._skip_flag:
self._skip_flag = False
return
if isinstance(node, Basic):
- for arg in node.args:
- for subtree in self._preorder_traversal(arg):
+ args = node.args
+ if key:
+ args = list(args)
+ args.sort(key=key)
+ for arg in args:
+ for subtree in self._preorder_traversal(arg, key):
yield subtree
elif iterable(node):
for item in node:
- for subtree in self._preorder_traversal(item):
+ for subtree in self._preorder_traversal(item, key):
yield subtree
def skip(self):
View
@@ -1347,7 +1347,7 @@ class Subs(Expr):
An example with several variables:
- >>> Subs(f(x)*sin(y)+z, (x, y), (0, 1))
+ >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
Subs(z + f(_x)*sin(_y), (_x, _y), (0, 1))
>>> _.doit()
z + f(0)*sin(1)
@@ -1359,14 +1359,12 @@ def __new__(cls, expr, variables, point, **assumptions):
variables = Tuple(*sympify(variables))
if uniq(variables) != variables:
- repeated = repeated = [ v for v in set(variables)
+ repeated = [ v for v in set(variables)
if list(variables).count(v) > 1 ]
raise ValueError('cannot substitute expressions %s more than '
'once.' % repeated)
- if not is_sequence(point, Tuple):
- point = [point]
- point = Tuple(*sympify(point))
+ point = Tuple(*sympify(point if is_sequence(point, Tuple) else [point]))
if len(point) != len(variables):
raise ValueError('Number of point values must be the same as '
@@ -1417,32 +1415,16 @@ def free_symbols(self):
def __eq__(self, other):
if not isinstance(other, Subs):
return False
- if (len(self.point) != len(other.point) or
- self.free_symbols != other.free_symbols or
- sorted(self.point) != sorted(other.point)):
+
+ if len(self.expr.free_symbols) != len(other.expr.free_symbols):
return False
- # non-repeated point args
- selfargs = [ v[0] for v in sorted(zip(self.variables, self.point),
- key = lambda v: v[1]) if list(self.point.args).count(v[1]) == 1 ]
- otherargs = [ v[0] for v in sorted(zip(other.variables, other.point),
- key = lambda v: v[1]) if list(other.point.args).count(v[1]) == 1 ]
- # find repeated point values and subs each associated variable
- # for a single symbol
- selfrepargs = []
- otherrepargs = []
- if uniq(self.point) != self.point:
- repeated = uniq([ v for v in self.point if
- list(self.point.args).count(v) > 1 ])
- repswap = dict(zip(repeated, [ C.Dummy() for _ in
- xrange(len(repeated)) ]))
- selfrepargs = [ (self.variables[i], repswap[v]) for i, v in
- enumerate(self.point) if v in repeated ]
- otherrepargs = [ (other.variables[i], repswap[v]) for i, v in
- enumerate(other.point) if v in repeated ]
-
- return self.expr.subs(selfrepargs) == other.expr.subs(
- tuple(zip(otherargs, selfargs))).subs(otherrepargs)
+ # replace points with dummies, each unique pt getting its own dummy
+ pts = set(self.point.args + other.point.args)
+ d = dict([(p, Dummy()) for p in pts])
+ eq = lambda e: \
+ e.expr.xreplace(dict(zip(e.variables, [d[p] for p in e.point])))
+ return eq(self) == eq(other)
def __ne__(self, other):
return not(self == other)
@@ -3,6 +3,8 @@
from sympy.core.basic import Basic, Atom, preorder_traversal
from sympy.core.singleton import S, Singleton
+from sympy.core.symbol import symbols
+from sympy.utilities.misc import default_sort_key
from sympy.utilities.pytest import raises
@@ -117,3 +119,8 @@ def test_preorder_traversal():
if i == b2:
pt.skip()
assert result == [expr, b21, b2, b1, b3, b2]
+
+ w, x, y, z = symbols('w:z')
+ expr = z + w*(x+y)
+ assert list(preorder_traversal([expr], key=default_sort_key)) == \
+ [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
@@ -136,6 +136,12 @@ def test_Lambda_equality():
def test_Subs():
assert Subs(f(x), x, 0).doit() == f(0)
assert Subs(f(x**2), x**2, 0).doit() == f(0)
+ assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \
+ Subs(f(x, y, z), (x, y, z), (0, 0, 1))
+ assert Subs(f(x, y), (x, y, z), (0, 1, 1)) == \
+ Subs(f(x, y), (x, y, z), (0, 1, 2))
+ assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \
+ Subs(f(x, y) + z, (x, y, z), (0, 1, 0))
assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1)
assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1)
raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
View
@@ -434,10 +434,10 @@ def solve(f, *symbols, **flags):
[3]
>>> solve(Poly(x - 3), x)
[3]
- >>> solve(x**2 - y**2, x)
- [-y, y]
- >>> solve(x**4 - 1, x)
- [-1, 1, -I, I]
+ >>> set(solve(x**2 - y**2, x))
+ set([-y, y])
+ >>> set(solve(x**4 - 1, x))
+ set([-1, 1, -I, I])
* single expression with no symbol that is in the expression
@@ -455,9 +455,9 @@ def solve(f, *symbols, **flags):
>>> solve(x - 3)
[3]
- >>> solve(x**2 - y**2)
+ >>> solve(x**2 - y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
- >>> solve(z**2*x**2 - z**2*y**2)
+ >>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(z**2*x - z**2*y**2)
[{x: y**2}]
@@ -473,8 +473,8 @@ def solve(f, *symbols, **flags):
[x + f(x)]
>>> solve(f(x).diff(x) - f(x) - x, f(x))
[-x + Derivative(f(x), x)]
- >>> solve(x + exp(x)**2, exp(x))
- [-sqrt(-x), sqrt(-x)]
+ >>> set(solve(x + exp(x)**2, exp(x)))
+ set([-sqrt(-x), sqrt(-x)])
* To solve for a *symbol* implicitly, use 'implicit=True':
@@ -501,17 +501,17 @@ def solve(f, *symbols, **flags):
* that are nonlinear
- >>> solve((a + b)*x - b**2 + 2, a, b)
- [(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]
+ >>> set(solve((a + b)*x - b**2 + 2, a, b))
+ set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))])
* if there is no linear solution then the first successful
attempt for a nonlinear solution will be returned
- >>> solve(x**2 - y**2, x, y)
+ >>> solve(x**2 - y**2, x, y) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(x**2 - y**2/exp(x), x, y)
[{x: 2*LambertW(y/2)}]
- >>> solve(x**2 - y**2/exp(x), y, x)
+ >>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP
[{y: -x*exp(x/2)}, {y: x*exp(x/2)}]
* iterable of one or more of the above
@@ -543,8 +543,8 @@ def solve(f, *symbols, **flags):
* when the system is not linear
- >>> solve([x**2 + y -2, y**2 - 4], x, y)
- [(-2, -2), (0, 2), (0, 2), (2, -2)]
+ >>> set(solve([x**2 + y -2, y**2 - 4], x, y))
+ set([(-2, -2), (0, 2), (2, -2)])
* if no symbols are given, all free symbols will be selected and a list
of mappings returned
@@ -659,7 +659,7 @@ def _sympified_list(w):
# we do this to make the results returned canonical in case f
# contains a system of nonlinear equations; all other cases should
# be unambiguous
- symbols = sorted(symbols, key=lambda i: i.sort_key())
+ symbols = sorted(symbols, key=default_sort_key)
# we can solve for Function and Derivative instances by replacing them
# with Dummy symbols or functions
@@ -1287,9 +1287,10 @@ def _solve_system(exprs, symbols, **flags):
else:
if len(symbols) != len(polys):
from sympy.utilities.iterables import subsets
- free = list(reduce(set.union,
- [p.free_symbols for p in polys], set()
- ).intersection(symbols))
+ from sympy.core.compatibility import set_union
+
+ free = set_union(*[p.free_symbols for p in polys])
+ free = list(free.intersection(symbols))
free.sort(key=default_sort_key)
for syms in subsets(free, len(polys)):
try:
@@ -1341,16 +1342,26 @@ def _solve_system(exprs, symbols, **flags):
result = [result]
else:
result = [{}]
+
+ def _ok_syms(e, sort=False):
+ rv = (e.free_symbols - solved_syms) & legal
+ if sort:
+ rv = list(rv)
+ rv.sort(key=default_sort_key)
+ return rv
+
solved_syms = set(solved_syms) # set of symbols we have solved for
legal = set(symbols) # what we are interested in
simplify_flag = flags.get('simplify', None)
do_simplify = flags.get('simplify', True)
- # sort so equation with the fewest potential symbols is first; break ties with
- # count_ops and sort_key
- short = sift(failed, lambda x: len((x.free_symbols - solved_syms) & legal))
+ # sort so equation with the fewest potential symbols is first;
+ # break ties with count_ops and default_sort_key
+ short = sift(failed, lambda x: len(_ok_syms(x)))
failed = []
- for k in sorted(short):
- failed.extend(sorted(short[k], key=lambda x: (x.count_ops(), x.sort_key())))
+ for k in sorted(short, key=default_sort_key):
+ failed.extend(sorted(sorted(short[k],
+ key=lambda x: x.count_ops()),
+ key=default_sort_key))
for eq in failed:
newresult = []
got_s = None
@@ -1365,7 +1376,7 @@ def _solve_system(exprs, symbols, **flags):
continue
# search for a symbol amongst those available that
# can be solved for
- ok_syms = (eq2.free_symbols - solved_syms) & legal
+ ok_syms = _ok_syms(eq2, sort=True)
if not ok_syms:
break # skip as it's independent of desired symbols
for s in ok_syms:
@@ -1591,7 +1602,7 @@ def solve_linear_system(system, *symbols, **flags):
matrix = system[:, :]
syms = list(symbols)
- i, m = 0, matrix.cols-1 # don't count augmentation
+ i, m = 0, matrix.cols - 1 # don't count augmentation
while i < matrix.rows:
if i == m:
@@ -1611,12 +1622,12 @@ def solve_linear_system(system, *symbols, **flags):
break
else:
if matrix[i, m]:
- # we need to know this this is always zero or not. We
+ # we need to know if this is always zero or not. We
# assume that if there are free symbols that it is not
# identically zero (or that there is more than one way
# to make this zero. Otherwise, if there are none, this
# is a constant and we assume that it does not simplify
- # to zero XXX are there better ways to test this/
+ # to zero XXX are there better ways to test this?
if not matrix[i, m].free_symbols:
return None # no solution
@@ -1664,7 +1675,7 @@ def solve_linear_system(system, *symbols, **flags):
# divide all elements in the current row by the pivot
matrix.row(i, lambda x, _: x * pivot_inv)
- for k in xrange(i+1, matrix.rows):
+ for k in xrange(i + 1, matrix.rows):
if matrix[k, i]:
coeff = matrix[k, i]
@@ -1689,7 +1700,7 @@ def solve_linear_system(system, *symbols, **flags):
content = matrix[k, m]
# run back-substitution for variables
- for j in xrange(k+1, m):
+ for j in xrange(k + 1, m):
content -= matrix[k, j]*solutions[syms[j]]
if do_simplify:
@@ -1703,13 +1714,13 @@ def solve_linear_system(system, *symbols, **flags):
elif len(syms) > matrix.rows:
# this system will have infinite number of solutions
# dependent on exactly len(syms) - i parameters
- k, solutions = i-1, {}
+ k, solutions = i - 1, {}
while k >= 0:
content = matrix[k, m]
# run back-substitution for variables
- for j in xrange(k+1, i):
+ for j in xrange(k + 1, i):
content -= matrix[k, j]*solutions[syms[j]]
# run back-substitution for parameters
Oops, something went wrong.

0 comments on commit 9ff4edf

Please sign in to comment.