Skip to content

Commit

Permalink
Merge pull request #1411 from smichr/rand
Browse files Browse the repository at this point in the history
Randomization fixes
  • Loading branch information
smichr committed Jul 12, 2012
2 parents 262bdc4 + bf84abd commit 9ff4edf
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 214 deletions.
40 changes: 25 additions & 15 deletions sympy/core/basic.py
Expand Up @@ -1501,44 +1501,54 @@ class preorder_traversal(object):
.args, which in many cases can be arbitrary.
Parameters
----------
==========
node : sympy expression
The expression to traverse.
key : (default None) sort key
The key used to sort args of Basic objects. When None, args of Basic
objects are processed in arbitrary order.
Yields
------
======
subtree : sympy expression
All of the subtrees in the tree.
Examples
--------
========
>>> from sympy import symbols
>>> from sympy import symbols, default_sort_key
>>> from sympy.core.basic import preorder_traversal
>>> x, y, z = symbols('x y z')
>>> list(preorder_traversal(z*(x+y))) in ( # any of these are possible
... [z*(x + y), z, x + y, x, y], [z*(x + y), z, x + y, y, x],
... [z*(x + y), x + y, x, y, z], [z*(x + y), x + y, y, x, z])
True
>>> list(preorder_traversal((x, (y, z))))
[(x, (y, z)), x, (y, z), y, z]
The nodes are returned in the order that they are encountered unless key
is given.
>>> list(preorder_traversal((x + y)*z, key=None)) # doctest: +SKIP
[z*(x + y), z, x + y, y, x]
>>> list(preorder_traversal((x + y)*z, key=default_sort_key))
[z*(x + y), z, x + y, x, y]
"""
def __init__(self, node):
def __init__(self, node, key=None):
self._skip_flag = False
self._pt = self._preorder_traversal(node)
self._pt = self._preorder_traversal(node, key)

def _preorder_traversal(self, node):
def _preorder_traversal(self, node, key):
yield node
if self._skip_flag:
self._skip_flag = False
return
if isinstance(node, Basic):
for arg in node.args:
for subtree in self._preorder_traversal(arg):
args = node.args
if key:
args = list(args)
args.sort(key=key)
for arg in args:
for subtree in self._preorder_traversal(arg, key):
yield subtree
elif iterable(node):
for item in node:
for subtree in self._preorder_traversal(item):
for subtree in self._preorder_traversal(item, key):
yield subtree

def skip(self):
Expand Down
40 changes: 11 additions & 29 deletions sympy/core/function.py
Expand Up @@ -1347,7 +1347,7 @@ class Subs(Expr):
An example with several variables:
>>> Subs(f(x)*sin(y)+z, (x, y), (0, 1))
>>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
Subs(z + f(_x)*sin(_y), (_x, _y), (0, 1))
>>> _.doit()
z + f(0)*sin(1)
Expand All @@ -1359,14 +1359,12 @@ def __new__(cls, expr, variables, point, **assumptions):
variables = Tuple(*sympify(variables))

if uniq(variables) != variables:
repeated = repeated = [ v for v in set(variables)
repeated = [ v for v in set(variables)
if list(variables).count(v) > 1 ]
raise ValueError('cannot substitute expressions %s more than '
'once.' % repeated)

if not is_sequence(point, Tuple):
point = [point]
point = Tuple(*sympify(point))
point = Tuple(*sympify(point if is_sequence(point, Tuple) else [point]))

if len(point) != len(variables):
raise ValueError('Number of point values must be the same as '
Expand Down Expand Up @@ -1417,32 +1415,16 @@ def free_symbols(self):
def __eq__(self, other):
if not isinstance(other, Subs):
return False
if (len(self.point) != len(other.point) or
self.free_symbols != other.free_symbols or
sorted(self.point) != sorted(other.point)):

if len(self.expr.free_symbols) != len(other.expr.free_symbols):
return False

# non-repeated point args
selfargs = [ v[0] for v in sorted(zip(self.variables, self.point),
key = lambda v: v[1]) if list(self.point.args).count(v[1]) == 1 ]
otherargs = [ v[0] for v in sorted(zip(other.variables, other.point),
key = lambda v: v[1]) if list(other.point.args).count(v[1]) == 1 ]
# find repeated point values and subs each associated variable
# for a single symbol
selfrepargs = []
otherrepargs = []
if uniq(self.point) != self.point:
repeated = uniq([ v for v in self.point if
list(self.point.args).count(v) > 1 ])
repswap = dict(zip(repeated, [ C.Dummy() for _ in
xrange(len(repeated)) ]))
selfrepargs = [ (self.variables[i], repswap[v]) for i, v in
enumerate(self.point) if v in repeated ]
otherrepargs = [ (other.variables[i], repswap[v]) for i, v in
enumerate(other.point) if v in repeated ]

return self.expr.subs(selfrepargs) == other.expr.subs(
tuple(zip(otherargs, selfargs))).subs(otherrepargs)
# replace points with dummies, each unique pt getting its own dummy
pts = set(self.point.args + other.point.args)
d = dict([(p, Dummy()) for p in pts])
eq = lambda e: \
e.expr.xreplace(dict(zip(e.variables, [d[p] for p in e.point])))
return eq(self) == eq(other)

def __ne__(self, other):
return not(self == other)
Expand Down
7 changes: 7 additions & 0 deletions sympy/core/tests/test_basic.py
Expand Up @@ -3,6 +3,8 @@

from sympy.core.basic import Basic, Atom, preorder_traversal
from sympy.core.singleton import S, Singleton
from sympy.core.symbol import symbols
from sympy.utilities.misc import default_sort_key

from sympy.utilities.pytest import raises

Expand Down Expand Up @@ -117,3 +119,8 @@ def test_preorder_traversal():
if i == b2:
pt.skip()
assert result == [expr, b21, b2, b1, b3, b2]

w, x, y, z = symbols('w:z')
expr = z + w*(x+y)
assert list(preorder_traversal([expr], key=default_sort_key)) == \
[[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
6 changes: 6 additions & 0 deletions sympy/core/tests/test_functions.py
Expand Up @@ -136,6 +136,12 @@ def test_Lambda_equality():
def test_Subs():
assert Subs(f(x), x, 0).doit() == f(0)
assert Subs(f(x**2), x**2, 0).doit() == f(0)
assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \
Subs(f(x, y, z), (x, y, z), (0, 0, 1))
assert Subs(f(x, y), (x, y, z), (0, 1, 1)) == \
Subs(f(x, y), (x, y, z), (0, 1, 2))
assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \
Subs(f(x, y) + z, (x, y, z), (0, 1, 0))
assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1)
assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1)
raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
Expand Down
73 changes: 42 additions & 31 deletions sympy/solvers/solvers.py
Expand Up @@ -434,10 +434,10 @@ def solve(f, *symbols, **flags):
[3]
>>> solve(Poly(x - 3), x)
[3]
>>> solve(x**2 - y**2, x)
[-y, y]
>>> solve(x**4 - 1, x)
[-1, 1, -I, I]
>>> set(solve(x**2 - y**2, x))
set([-y, y])
>>> set(solve(x**4 - 1, x))
set([-1, 1, -I, I])
* single expression with no symbol that is in the expression
Expand All @@ -455,9 +455,9 @@ def solve(f, *symbols, **flags):
>>> solve(x - 3)
[3]
>>> solve(x**2 - y**2)
>>> solve(x**2 - y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(z**2*x**2 - z**2*y**2)
>>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(z**2*x - z**2*y**2)
[{x: y**2}]
Expand All @@ -473,8 +473,8 @@ def solve(f, *symbols, **flags):
[x + f(x)]
>>> solve(f(x).diff(x) - f(x) - x, f(x))
[-x + Derivative(f(x), x)]
>>> solve(x + exp(x)**2, exp(x))
[-sqrt(-x), sqrt(-x)]
>>> set(solve(x + exp(x)**2, exp(x)))
set([-sqrt(-x), sqrt(-x)])
* To solve for a *symbol* implicitly, use 'implicit=True':
Expand All @@ -501,17 +501,17 @@ def solve(f, *symbols, **flags):
* that are nonlinear
>>> solve((a + b)*x - b**2 + 2, a, b)
[(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))]
>>> set(solve((a + b)*x - b**2 + 2, a, b))
set([(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))])
* if there is no linear solution then the first successful
attempt for a nonlinear solution will be returned
>>> solve(x**2 - y**2, x, y)
>>> solve(x**2 - y**2, x, y) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(x**2 - y**2/exp(x), x, y)
[{x: 2*LambertW(y/2)}]
>>> solve(x**2 - y**2/exp(x), y, x)
>>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP
[{y: -x*exp(x/2)}, {y: x*exp(x/2)}]
* iterable of one or more of the above
Expand Down Expand Up @@ -543,8 +543,8 @@ def solve(f, *symbols, **flags):
* when the system is not linear
>>> solve([x**2 + y -2, y**2 - 4], x, y)
[(-2, -2), (0, 2), (0, 2), (2, -2)]
>>> set(solve([x**2 + y -2, y**2 - 4], x, y))
set([(-2, -2), (0, 2), (2, -2)])
* if no symbols are given, all free symbols will be selected and a list
of mappings returned
Expand Down Expand Up @@ -659,7 +659,7 @@ def _sympified_list(w):
# we do this to make the results returned canonical in case f
# contains a system of nonlinear equations; all other cases should
# be unambiguous
symbols = sorted(symbols, key=lambda i: i.sort_key())
symbols = sorted(symbols, key=default_sort_key)

# we can solve for Function and Derivative instances by replacing them
# with Dummy symbols or functions
Expand Down Expand Up @@ -1287,9 +1287,10 @@ def _solve_system(exprs, symbols, **flags):
else:
if len(symbols) != len(polys):
from sympy.utilities.iterables import subsets
free = list(reduce(set.union,
[p.free_symbols for p in polys], set()
).intersection(symbols))
from sympy.core.compatibility import set_union

free = set_union(*[p.free_symbols for p in polys])
free = list(free.intersection(symbols))
free.sort(key=default_sort_key)
for syms in subsets(free, len(polys)):
try:
Expand Down Expand Up @@ -1341,16 +1342,26 @@ def _solve_system(exprs, symbols, **flags):
result = [result]
else:
result = [{}]

def _ok_syms(e, sort=False):
rv = (e.free_symbols - solved_syms) & legal
if sort:
rv = list(rv)
rv.sort(key=default_sort_key)
return rv

solved_syms = set(solved_syms) # set of symbols we have solved for
legal = set(symbols) # what we are interested in
simplify_flag = flags.get('simplify', None)
do_simplify = flags.get('simplify', True)
# sort so equation with the fewest potential symbols is first; break ties with
# count_ops and sort_key
short = sift(failed, lambda x: len((x.free_symbols - solved_syms) & legal))
# sort so equation with the fewest potential symbols is first;
# break ties with count_ops and default_sort_key
short = sift(failed, lambda x: len(_ok_syms(x)))
failed = []
for k in sorted(short):
failed.extend(sorted(short[k], key=lambda x: (x.count_ops(), x.sort_key())))
for k in sorted(short, key=default_sort_key):
failed.extend(sorted(sorted(short[k],
key=lambda x: x.count_ops()),
key=default_sort_key))
for eq in failed:
newresult = []
got_s = None
Expand All @@ -1365,7 +1376,7 @@ def _solve_system(exprs, symbols, **flags):
continue
# search for a symbol amongst those available that
# can be solved for
ok_syms = (eq2.free_symbols - solved_syms) & legal
ok_syms = _ok_syms(eq2, sort=True)
if not ok_syms:
break # skip as it's independent of desired symbols
for s in ok_syms:
Expand Down Expand Up @@ -1591,7 +1602,7 @@ def solve_linear_system(system, *symbols, **flags):
matrix = system[:, :]
syms = list(symbols)

i, m = 0, matrix.cols-1 # don't count augmentation
i, m = 0, matrix.cols - 1 # don't count augmentation

while i < matrix.rows:
if i == m:
Expand All @@ -1611,12 +1622,12 @@ def solve_linear_system(system, *symbols, **flags):
break
else:
if matrix[i, m]:
# we need to know this this is always zero or not. We
# we need to know if this is always zero or not. We
# assume that if there are free symbols that it is not
# identically zero (or that there is more than one way
# to make this zero. Otherwise, if there are none, this
# is a constant and we assume that it does not simplify
# to zero XXX are there better ways to test this/
# to zero XXX are there better ways to test this?
if not matrix[i, m].free_symbols:
return None # no solution

Expand Down Expand Up @@ -1664,7 +1675,7 @@ def solve_linear_system(system, *symbols, **flags):
# divide all elements in the current row by the pivot
matrix.row(i, lambda x, _: x * pivot_inv)

for k in xrange(i+1, matrix.rows):
for k in xrange(i + 1, matrix.rows):
if matrix[k, i]:
coeff = matrix[k, i]

Expand All @@ -1689,7 +1700,7 @@ def solve_linear_system(system, *symbols, **flags):
content = matrix[k, m]

# run back-substitution for variables
for j in xrange(k+1, m):
for j in xrange(k + 1, m):
content -= matrix[k, j]*solutions[syms[j]]

if do_simplify:
Expand All @@ -1703,13 +1714,13 @@ def solve_linear_system(system, *symbols, **flags):
elif len(syms) > matrix.rows:
# this system will have infinite number of solutions
# dependent on exactly len(syms) - i parameters
k, solutions = i-1, {}
k, solutions = i - 1, {}

while k >= 0:
content = matrix[k, m]

# run back-substitution for variables
for j in xrange(k+1, i):
for j in xrange(k + 1, i):
content -= matrix[k, j]*solutions[syms[j]]

# run back-substitution for parameters
Expand Down

0 comments on commit 9ff4edf

Please sign in to comment.