Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

1400: cse is deterministic and ignores singleton expressions

deterministic

Since cse relies on making substitutions at the very beginning
in the preprocessing, and substitution is an order-sensitive procedure,
it was necessary to sort the Adds that are being identified in that
process.

The use of default_sort_key from the traversal was removed. This is not
needed in the sub-expression identification stage. The insert method
that sought to keep the expressions in order was also removed. The
default_sort_key applied in one branch of the expression tree in one
of the expressions given to cse doesn't guarantee that another expression
with the same key (as was defined in the insert method) will be in a
deterministic order -- the tie would not be broken.

So now all the expressions are collected however they come from the
tree and are assembled in a set. These (and the Add and Mul terms that
are handled individually) are sorted with the lazyDSU sort which seeks
to put the items in some canonical order. This is the source, then,
of determinism in the method.

In addition, expressions that are only used once are no longer reported
in the returned expressions, e.g. whereas one might have gotten
[[(x0, x+1)], [cos(x0)]] previously, now [[], [cos(x + 1)]] will be
returned.

The insert method used to keep the expressions arranged in a particular
order was removed since more than this is
  • Loading branch information...
commit 33fa7f8d933678f27b82f6fcb1ea816846910cdc 1 parent 88a1348
@smichr smichr authored
View
7 doc/src/modules/rewriting.rst
@@ -64,7 +64,8 @@ Common Subexpression Detection and Collection
.. module:: sympy.simplify.cse_main
Before evaluating a large expression, it is often useful to identify common
-subexpressions, collect them and evaluate them at once. This is implemented in the ``cse`` function. Examples::
+subexpressions, collect them and evaluate them at once. This is implemented
+in the ``cse`` function. Examples::
>>> from sympy import cse, sqrt, sin, pprint
>>> from sympy.abc import x
@@ -79,8 +80,8 @@ subexpressions, collect them and evaluate them at once. This is implemented in t
>>> pprint(cse(sqrt(sin(x+1) + 5 + cos(y))*sqrt(sin(x+1) + 4 + cos(y))),
... use_unicode=True)
- ⎡ ________ ________⎤⎞
- ⎝[(x₀, cos(y)), (x₁, sin(x + 1)), (x₂, x₀ + x₁)], ⎣╲╱ x + 4 ⋅╲╱ x + 5 ⎦⎠
+ ⎛ ⎡ ________ ________⎤⎞
+ ⎝[(x₀, sin(x + 1) + cos(y))], ⎣╲╱ x + 4 ⋅╲╱ x + 5 ⎦⎠
>>> pprint(cse((x-y)*(z-y) + sqrt((x-y)*(z-y))), use_unicode=True)
⎛ ⎡ ____ ⎤⎞
View
25 sympy/core/tests/test_expand.py
@@ -100,25 +100,18 @@ def test_expand_frac():
def test_issue_3022():
+ # TODO when 3460 is merged, move this import to the top
from sympy import cse
- ans = S('''([
- (x0, im(x)),
- (x1, re(x)),
- (x2, atan2(x0, x1)/2),
- (x3, sin(x2)), (x4, cos(x2)),
- (x5, x0**2 + x1**2),
- (x6, atan2(0, x5)/4),
- (x7, cos(x6)),
- (x8, sin(x6)),
- (x9, x4*x7),
- (x10, x4*x8),
- (x11, x3*x8),
- (x12, x3*x7)],
- [sqrt(2)*(x10 + I*x10 + x11 - I*x11 + x12 + I*x12 - x9 + I*x9)/
- (8*pi**(3/2)*x5**(1/4))])''')
eq = -I*exp(-3*I*pi/4)/(4*pi**(S(3)/2)*sqrt(x))
r, e = cse((eq).expand(complex=True))
- assert abs((eq - e[0].subs(reversed(r))).subs(x, 1 + 3*I)) < 1e-9
+ assert r == S('''[
+ (x0, re(x)), (x1, im(x)), (x2, atan2(x1, x0)/2), (x3, x0**2 + x1**2),
+ (x4, sin(x2)), (x5, cos(x2)), (x6, atan2(0, x3)/4), (x7, sin(x6)),
+ (x8, cos(x6)), (x9, x4*x7), (x10, x5*x7), (x11, x4*x8), (x12, x5*x8)
+ ]''')
+ assert e == S('''[
+ sqrt(2)*(x10 + I*x10 + x11 + I*x11 - x12 + I*x12 + x9 -
+ I*x9)/(8*pi**(S(3)/2)*x3**(S(1)/4))]''')
def test_expand_power_base():
View
147 sympy/simplify/cse_main.py
@@ -1,16 +1,13 @@
""" Tools for doing common subexpression elimination.
"""
-import bisect
import difflib
-from sympy.core import Basic, Mul, Add, Tuple, sympify
+from sympy.core import Basic, Mul, Add, sympify
from sympy.core.basic import preorder_traversal
-from sympy.functions.elementary.complexes import sign
from sympy.core.function import _coeff_isneg
from sympy.core.compatibility import iterable
from sympy.utilities.iterables import numbered_symbols, \
- sift, topological_sort
-from sympy.utilities.misc import default_sort_key
+ sift, topological_sort, lazyDSU_sort, small_first_keys
import cse_opts
@@ -69,16 +66,20 @@ def cse_separate(r, e):
========
>>> from sympy.simplify.cse_main import cse_separate
>>> from sympy.abc import x, y, z
- >>> from sympy import cos, exp, cse, Eq
+ >>> from sympy import cos, exp, cse, Eq, symbols
+ >>> x0, x1 = symbols('x:2')
>>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
- >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate)
- [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
- [x1 + exp(x1/x0) + cos(x0), z - 2]]
+ >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
+ ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
+ ... [x1 + exp(x1/x0) + cos(x0), z - 2]],
+ ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
+ ... [x0 + exp(x0/x1) + cos(x1), z - 2]]]
+ ...
+ True
"""
- syms = set([k for k, v in r])
- d = sift(
- e, lambda w: w.is_Equality and not bool(w.free_symbols & set(syms)))
- r, e = [r + [w.args for w in d[True]], d[False]]
+ d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol)
+ r = r + [w.args for w in d[True]]
+ e = d[False]
return [reps_toposort(r), e]
# ====end of cse postprocess idioms===========================
@@ -132,6 +133,81 @@ def postprocess_for_cse(expr, optimizations):
return expr
+def _remove_singletons(reps, exprs):
+ """
+ Helper function for cse that will remove expressions that weren't
+ used more than once.
+ """
+ u_reps = [] # the useful reps that are used more than once
+ for i, ui in enumerate(reps):
+ used = [] # where it was used
+ ri, ei = ui
+
+ # keep track of whether the substitution was used more
+ # than once. If used is None, it was never used (yet);
+ # if used is an int, that is the last place where it was
+ # used (>=0 in the reps, <0 in the expressions) and if
+ # it is True, it was used more than once.
+
+ used = None
+
+ tot = 0 # total times used so far
+
+ # search through the reps
+ for j in range(i + 1, len(reps)):
+ c = reps[j][1].count(ri)
+ if c:
+ tot += c
+ if tot > 1:
+ u_reps.append(ui)
+ used = True
+ break
+ else:
+ used = j
+
+ if used is not True:
+
+ # then search through the expressions
+
+ for j, rj in enumerate(exprs):
+ c = rj.count(ri)
+ if c:
+ # append a negative so we know that it was in the
+ # expression that used it
+ tot += c
+ if tot > 1:
+ u_reps.append(ui)
+ used = True
+ break
+ else:
+ used = j - len(exprs)
+
+ if type(used) is int:
+
+ # undo the change
+
+ rep = {ri: ei}
+ j = used
+ if j < 0:
+ exprs[j] = exprs[j].subs(rep)
+ else:
+ reps[j] = reps[j][0], reps[j][1].subs(rep)
+
+ # reuse unused symbols so a contiguous range of symbols is returned
+
+ if len(u_reps) != len(reps):
+ for i, ri in enumerate(u_reps):
+ if u_reps[i][0] != reps[i][0]:
+ rep = (u_reps[i][0], reps[i][0])
+ u_reps[i] = rep[1], u_reps[i][1].subs(*rep)
+ for j in range(i + 1, len(u_reps)):
+ u_reps[j] = u_reps[j][0], u_reps[j][1].subs(*rep)
+ for j, rj in enumerate(exprs):
+ exprs[j] = exprs[j].subs(*rep)
+
+ reps[:] = u_reps # change happens in-place
+
+
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
""" Perform common subexpression elimination on an expression.
@@ -163,7 +239,6 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None):
The reduced expressions with all of the replacements above.
"""
from sympy.matrices import Matrix
- from sympy.simplify.simplify import fraction
if symbols is None:
symbols = numbered_symbols()
@@ -174,8 +249,7 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None):
seen_subexp = set()
muls = set()
adds = set()
- to_eliminate = []
- to_eliminate_ops_count = []
+ to_eliminate = set()
if optimizations is None:
# Pull out the default here just in case there are some weird
@@ -188,29 +262,12 @@ def cse(exprs, symbols=None, optimizations=None, postprocess=None):
# Preprocess the expressions to give us better optimization opportunities.
reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
- # Find all of the repeated subexpressions.
-
- def insert(subtree):
- '''This helper will insert the subtree into to_eliminate while
- maintaining the ordering by op count and will skip the insertion
- if subtree is already present.'''
- ops_count = (
- subtree.count_ops(), subtree.is_Mul) # prefer non-Mul to Mul
- index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
- # all i up to this index have op count <= the current op count
- # so check that subtree is not yet present from this index down
- # (if necessary) to zero.
- for i in xrange(index_to_insert - 1, -1, -1):
- if to_eliminate_ops_count[i] == ops_count and \
- subtree == to_eliminate[i]:
- return # already have it
- to_eliminate_ops_count.insert(index_to_insert, ops_count)
- to_eliminate.insert(index_to_insert, subtree)
+ # Find all of the repeated subexpressions.
for expr in reduced_exprs:
if not isinstance(expr, Basic):
continue
- pt = preorder_traversal(expr, key=default_sort_key)
+ pt = preorder_traversal(expr)
for subtree in pt:
inv = 1/subtree if subtree.is_Pow else None
@@ -223,7 +280,7 @@ def insert(subtree):
if inv and _coeff_isneg(subtree.exp):
# save the form with positive exponent
subtree = inv
- insert(subtree)
+ to_eliminate.add(subtree)
pt.skip()
continue
@@ -231,7 +288,7 @@ def insert(subtree):
if _coeff_isneg(subtree.exp):
# save the form with positive exponent
subtree = inv
- insert(subtree)
+ to_eliminate.add(subtree)
pt.skip()
continue
elif subtree.is_Mul:
@@ -243,12 +300,13 @@ def insert(subtree):
# process adds - any adds that weren't repeated might contain
# subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
+ adds = lazyDSU_sort(adds, small_first_keys)
adds = [set(a.args) for a in adds]
for i in xrange(len(adds)):
for j in xrange(i + 1, len(adds)):
com = adds[i].intersection(adds[j])
if len(com) > 1:
- insert(Add(*com))
+ to_eliminate.add(Add(*com))
# remove this set of symbols so it doesn't appear again
adds[i] = adds[i].difference(com)
@@ -264,6 +322,7 @@ def insert(subtree):
# in common between the two nc parts
sm = difflib.SequenceMatcher()
+ muls = lazyDSU_sort(muls, small_first_keys)
muls = [a.args_cnc(cset=True) for a in muls]
for i in xrange(len(muls)):
if muls[i][1]:
@@ -291,7 +350,7 @@ def insert(subtree):
if len(com) < 2:
continue
- insert(Mul(*com))
+ to_eliminate.add(Mul(*com))
# remove ccom from all if there was no ncom; to update the nc part
# would require finding the subexpr and then replacing it with a
@@ -304,6 +363,13 @@ def insert(subtree):
if not ccom.difference(muls[k][0]):
muls[k][0] = muls[k][0].difference(ccom)
+ # make to_eliminate canonical; we will prefer non-Muls to Muls
+ # so make that the 2nd sort key for lazyDSU_sort (if it's a Mul
+ # the value of the key will be True which will sort after False
+ ops_mul_def__key = list(small_first_keys)
+ ops_mul_def__key.insert(1, lambda _: _.is_Mul)
+ to_eliminate = lazyDSU_sort(to_eliminate, ops_mul_def__key)
+
# Substitute symbols for all of the repeated subexpressions.
replacements = []
reduced_exprs = list(reduced_exprs)
@@ -336,6 +402,9 @@ def insert(subtree):
reduced_exprs = [postprocess_for_cse(e, optimizations)
for e in reduced_exprs]
+ # remove replacements that weren't used more than once
+ _remove_singletons(replacements, reduced_exprs)
+
if isinstance(exprs, Matrix):
reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
if postprocess is None:
View
23 sympy/simplify/cse_opts.py
@@ -1,10 +1,11 @@
""" Optimizations of the expression tree representation for better CSE
opportunities.
"""
-from sympy.core import Add, Basic, Expr, Mul, S
-from sympy.core.exprtools import factor_terms
+from sympy.core import Add, Basic, Expr, Mul
from sympy.core.basic import preorder_traversal
-
+from sympy.core.compatibility import oset
+from sympy.core.exprtools import factor_terms
+from sympy.utilities.iterables import lazyDSU_sort, small_first_keys
class Neg(Expr):
""" Stub to hold negated expression.
@@ -16,16 +17,16 @@ def sub_pre(e):
""" Replace y - x with Neg(x - y) if -1 can be extracted from y - x.
"""
# make canonical, first
- adds = {}
- for a in e.atoms(Add):
- adds[a] = a.could_extract_minus_sign()
- e = e.subs([(a, Mul(-1, -a, evaluate=False)
- if adds[a] else a) for a in adds])
+ adds = lazyDSU_sort(e.atoms(Add), small_first_keys)
+ reps = oset([a for a in adds if a.could_extract_minus_sign()])
+ e = e.subs([(a, Mul(-1, -a, evaluate=False)) for a in reps])
# now replace any persisting Adds, a, that can have -1 extracted with Neg(-a)
if isinstance(e, Basic):
- reps = dict([(a, Neg(-a)) for a in e.atoms(Add)
- if adds.get(a, a.could_extract_minus_sign())])
- e = e.xreplace(reps)
+ negs = {}
+ for a in lazyDSU_sort(e.atoms(Add), small_first_keys):
+ if a in reps or a.could_extract_minus_sign():
+ negs[a] = Neg(-a)
+ e = e.xreplace(negs)
return e
View
54 sympy/simplify/tests/test_cse.py
@@ -1,14 +1,13 @@
import itertools
-from sympy import (Add, Mul, Pow, Symbol, exp, sqrt, symbols, sympify, cse,
- Matrix, S, cos, sin, Eq)
+from sympy import (Add, Pow, Symbol, exp, sqrt, symbols, sympify, cse,
+ Matrix, S, cos, sin, Eq, Function, Tuple)
from sympy.functions.special.hyper import meijerg
from sympy.simplify import cse_main, cse_opts
from sympy.utilities.pytest import XFAIL
w, x, y, z = symbols('w,x,y,z')
-x0, x1, x2 = list(itertools.islice(cse_main.numbered_symbols(), 0, 3))
-negone = sympify(-1)
+x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11 = symbols('x:12')
def test_numbered_symbols():
@@ -94,8 +93,8 @@ def test_subtraction_opt():
e = (x - y)*(z - y) + exp((x - y)*(z - y))
substs, reduced = cse(
[e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
- assert substs == [(x0, x - y), (x1, y - z), (x2, x0*x1)]
- assert reduced == [-x2 + exp(-x2)]
+ assert substs == [(x0, (x - y)*(y - z))]
+ assert reduced == [-x0 + exp(-x0)]
assert cse(-(x - y)*(z - y) + exp(-(x - y)*(z - y))) == \
([(x0, (x - y)*(y - z))], [x0 + exp(x0)])
# issue 978
@@ -123,12 +122,9 @@ def test_multiple_expressions():
l = [(x - z)*(y - z), x - z, y - z]
substs, reduced = cse(l)
rsubsts, _ = cse(reversed(l))
- substitutions = [
- [(x0, x - z), (x1, y - z)],
- [(x0, y - z), (x1, x - z)],
- ]
- assert substs in substitutions
- assert rsubsts in substitutions
+ substitutions = [(x0, x - z), (x1, y - z)]
+ assert substs == substitutions
+ assert rsubsts == substitutions
assert reduced == [x0*x1, x0, x1]
l = [w*y + w + x + y + z, w*x*y]
assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
@@ -167,7 +163,7 @@ def test_issue_3164():
def test_dont_cse_tuples():
- from sympy import Subs, Function
+ from sympy import Subs
f = Function("f")
g = Function("g")
@@ -211,7 +207,31 @@ def test_pow_invpow():
def test_postprocess():
eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
- assert cse([eq, Eq(x, z + 1), z - 2],
- postprocess=cse_main.cse_separate) == \
- [[(x0, y + 1), (x, z + 1), (x1, x + 1)], [x1 + exp(x1/x0) +
- cos(x0), z - 2]]
+ assert cse([eq, Eq(x, z + 1), z - 2, (z+1)*(x+1)],
+ postprocess=cse_main.cse_separate) == \
+ [[(x1, y + 1), (x2, z + 1), (x, x2), (x0, x + 1)],
+ [x0 + exp(x0/x1) + cos(x1), x2 - 3, x0*x2]]
+
+def test_issue1400():
+ # previously, this gave 16 constants
+ from sympy.abc import a, b
+ B = Function('B')
+ G = Function('G')
+ t = Tuple(*
+ (a, a + S(1)/2, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
+ b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
+ sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
+ sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
+ sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
+ (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
+ sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S(1)/2, z/2, -b + 1, -2*a + b,
+ -2*a))
+
+ c = cse(t)
+ ans = (
+ [(x0, b - 1), (x1, 2*a), (x2, -b), (x3, sqrt(z)), (x4, -x0 + x1 - 1),
+ (x5, -x1), (x6, x4 + 1), (x7, B(x0, x3)), (x8, B(x4, x3)), (x9, B(x6,
+ x3)), (x10, (x3/2)**(x5 + 1)*G(b)*G(x6)), (x11, x10*x3)], [(a, a +
+ S(1)/2, x1, b, x6, x10*x7*x8, x11*x8*B(-x2, x3), x11*x7*x9,
+ x10*x9*B(-x2, x3), 1, 0, S(1)/2, z/2, -x0, -x4, x5)])
+ assert ans == c
Please sign in to comment.
Something went wrong with that request. Please try again.