Skip to content

Commit

Permalink
Merge pull request #103 from skirpichev/remove-strategies
Browse files Browse the repository at this point in the history
Unbundle strategies module
  • Loading branch information
skirpichev committed Aug 23, 2015
2 parents f04c343 + 314eafd commit 26438f3
Show file tree
Hide file tree
Showing 33 changed files with 255 additions and 1,423 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,5 @@ def run_tests(self):
'Programming Language :: Python :: 3.4',
],
tests_require=['pytest'],
install_requires=['mpmath>=0.19', 'decorator']
install_requires=['mpmath>=0.19', 'decorator', 'strategies>=0.2.3']
)
156 changes: 156 additions & 0 deletions sympy/core/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
""" Generic Rules for SymPy
This file assumes knowledge of Basic and little else.
"""
from __future__ import print_function, division, absolute_import

from strategies.dispatch import dispatch

from sympy.core.basic import Basic, Atom
from sympy.utilities.iterables import sift


@dispatch(Basic)
def arguments(o):
return o.args


@dispatch((int, Atom))
def arguments(o):
return ()


@dispatch(Basic)
def operator(o):
return o.func


@dispatch((int, Atom))
def operator(o):
return o


@dispatch(type, (tuple, list))
def term(op, args):
return op(*args)


@dispatch((int, Atom), (tuple, list))
def term(op, args):
return op


# Functions that create rules


def rm_id(isid):
""" Create a rule to remove identities
isid - fn :: x -> Bool --- whether or not this element is an identity
>>> from sympy.core.strategies import rm_id
>>> from sympy import Basic
>>> remove_zeros = rm_id(lambda x: x==0)
>>> remove_zeros(Basic(1, 0, 2))
Basic(1, 2)
>>> remove_zeros(Basic(0, 0)) # If only identites then we keep one
Basic(0)
See Also
========
unpack
"""
def ident_remove(expr):
""" Remove identities """
ids = list(map(isid, arguments(expr)))
if sum(ids) == 0: # No identities. Common case
return expr
elif sum(ids) != len(ids): # there is at least one non-identity
return term(operator(expr),
[arg for arg, x in zip(arguments(expr), ids) if not x])
else:
return term(operator(expr), [arguments(expr)[0]])

return ident_remove


def glom(key, count, combine):
""" Create a rule to conglomerate identical args
>>> from sympy.core.strategies import glom
>>> from sympy import Add
>>> from sympy.abc import x
>>> key = lambda x: x.as_coeff_Mul()[1]
>>> count = lambda x: x.as_coeff_Mul()[0]
>>> combine = lambda cnt, arg: cnt * arg
>>> rl = glom(key, count, combine)
>>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
3*x + 5
Wait, how are key, count and combine supposed to work?
>>> key(2*x)
x
>>> count(2*x)
2
>>> combine(2, x)
2*x
"""
def conglomerate(expr):
""" Conglomerate together identical args x + x -> 2x """
groups = sift(arguments(expr), key)
counts = {k: sum(map(count, args)) for k, args in groups.items()}
newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
if set(newargs) != set(arguments(expr)):
return term(operator(expr), newargs)
else:
return expr

return conglomerate


def sort(key):
""" Create a rule to sort by a key function
>>> from sympy.core.strategies import sort
>>> from sympy import Basic
>>> sort_rl = sort(str)
>>> sort_rl(Basic(3, 1, 2))
Basic(1, 2, 3)
"""

def sort_rl(expr):
return term(operator(expr), sorted(arguments(expr), key=key))
return sort_rl


# Functions that are rules


def unpack(expr):
""" Rule to unpack singleton args
>>> from sympy.core.strategies import unpack
>>> from sympy import Basic
>>> unpack(Basic(2))
2
"""
if len(arguments(expr)) == 1:
return arguments(expr)[0]
else:
return expr


def flatten(expr):
""" Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)) """
cls = operator(expr)
args = []
for arg in arguments(expr):
if operator(arg) == cls:
args.extend(arguments(arg))
else:
args.append(arg)
return term(cls, args)
54 changes: 54 additions & 0 deletions sympy/core/tests/test_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from sympy.core.strategies import (rm_id, glom, flatten, unpack, sort,
arguments, operator, term)
from sympy.core import Basic, Integer, Add, Mul, symbols
from sympy.abc import x, y


def test_rm_id():
rmzeros = rm_id(lambda x: x == 0)
assert rmzeros(Basic(0, 1)) == Basic(1)
assert rmzeros(Basic(0, 0)) == Basic(0)
assert rmzeros(Basic(2, 1)) == Basic(2, 1)


def test_glom():
def key(x):
return x.as_coeff_Mul()[1]

def count(x):
return x.as_coeff_Mul()[0]

def newargs(cnt, arg):
return cnt * arg

rl = glom(key, count, newargs)

result = rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
expected = Add(3*x, 5)
assert set(result.args) == set(expected.args)

result = rl(Add(*expected.args, evaluate=False))
assert set(result.args) == set(expected.args)


def test_flatten():
assert flatten(Basic(1, 2, Basic(3, 4))) == Basic(1, 2, 3, 4)


def test_unpack():
assert unpack(Basic(2)) == 2
assert unpack(Basic(2, 3)) == Basic(2, 3)


def test_sort():
assert sort(str)(Basic(3,1,2)) == Basic(1,2,3)


def test_term():
assert arguments(2) == ()
assert arguments(Integer(2)) == ()
assert arguments(2 + x) == (2, x)
assert operator(2 + x) == Add
assert operator(Integer(2)) == Integer(2)
assert term(Add, (2, x)) == 2 + x
assert term(Integer(2), ()) == Integer(2)
22 changes: 11 additions & 11 deletions sympy/integrals/manualintegrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from __future__ import print_function, division

from collections import namedtuple
from strategies.core import switch, do_one, null_safe, condition

import sympy

from sympy.core.compatibility import reduce
from sympy.functions.elementary.trigonometric import TrigonometricFunction
from sympy.strategies.core import switch, do_one, null_safe, condition


def Rule(name, props=""):
Expand Down Expand Up @@ -703,9 +703,9 @@ def trig_cotcsc_rule(integral):


def trig_powers_products_rule(integral):
return do_one(null_safe(trig_sincos_rule),
null_safe(trig_tansec_rule),
null_safe(trig_cotcsc_rule))(integral)
return do_one([null_safe(trig_sincos_rule),
null_safe(trig_tansec_rule),
null_safe(trig_cotcsc_rule)])(integral)


def trig_substitution_rule(integral):
Expand Down Expand Up @@ -950,20 +950,20 @@ def _integral_is_subclass(integral):
return k and issubclass(k, klasses)
return _integral_is_subclass

result = do_one(
result = do_one([
null_safe(switch(key, {
sympy.Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule)),
sympy.Pow: do_one([null_safe(power_rule), null_safe(inverse_trig_rule)]),
sympy.Symbol: power_rule,
sympy.exp: exp_rule,
sympy.Add: add_rule,
sympy.Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule),
null_safe(heaviside_rule)),
sympy.Mul: do_one([null_safe(mul_rule), null_safe(trig_product_rule),
null_safe(heaviside_rule)]),
sympy.Derivative: derivative_rule,
TrigonometricFunction: trig_rule,
sympy.Heaviside: heaviside_rule,
sympy.Number: constant_rule
})),
do_one(
do_one([
null_safe(trig_rule),
null_safe(alternatives(
rewrites_rule,
Expand All @@ -979,9 +979,9 @@ def _integral_is_subclass(integral):
distribute_expand_rule),
trig_powers_products_rule
)),
null_safe(trig_substitution_rule)
null_safe(trig_substitution_rule)]
),
fallback_rule)(integral)
fallback_rule])(integral)
del _integral_cache[cachekey]
return result

Expand Down
13 changes: 8 additions & 5 deletions sympy/matrices/expressions/blockmatrix.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import print_function, division

from strategies import exhaust, condition, do_one
from strategies.core import typed
from strategies.traverse import bottom_up

from sympy import ask, Q
from sympy.core import Basic, Add, sympify
from sympy.core.compatibility import range
from sympy.strategies import typed, exhaust, condition, do_one, unpack
from sympy.strategies.traverse import bottom_up
from sympy.core.strategies import unpack
from sympy.utilities import sift

from sympy.matrices.expressions.matexpr import MatrixExpr, ZeroMatrix, Identity
Expand Down Expand Up @@ -280,11 +283,11 @@ def hasbm(expr):
return isinstance(expr, MatrixExpr) and expr.has(BlockMatrix)
rule = exhaust(
bottom_up(exhaust(condition(hasbm, typed(
{MatAdd: do_one(bc_matadd, bc_block_plus_ident),
MatMul: do_one(bc_matmul, bc_dist),
{MatAdd: do_one([bc_matadd, bc_block_plus_ident]),
MatMul: do_one([bc_matmul, bc_dist]),
Transpose: bc_transpose,
Inverse: bc_inverse,
BlockMatrix: do_one(bc_unpack, deblock)})))))
BlockMatrix: do_one([bc_unpack, deblock])})))))
result = rule(expr)
try:
return result.doit()
Expand Down
6 changes: 4 additions & 2 deletions sympy/matrices/expressions/hadamard.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import print_function, division

from strategies import condition, exhaust, do_one

from sympy.core import Mul, sympify
from sympy.strategies import unpack, flatten, condition, exhaust, do_one
from sympy.core.strategies import unpack, flatten

from sympy.matrices.expressions.matexpr import MatrixExpr, ShapeError

Expand Down Expand Up @@ -82,4 +84,4 @@ def validate(*args):
flatten)

canonicalize = exhaust(condition(lambda x: isinstance(x, HadamardProduct),
do_one(*rules)))
do_one(rules)))
8 changes: 4 additions & 4 deletions sympy/matrices/expressions/matadd.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import print_function, division

from sympy.core.compatibility import reduce
from operator import add
from strategies import condition, exhaust, do_one

from sympy.core.compatibility import reduce
from sympy.core import Add, Basic, sympify
from sympy.functions import adjoint
from sympy.matrices.matrices import MatrixBase
from sympy.matrices.expressions.transpose import transpose
from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
exhaust, do_one, glom)
from sympy.core.strategies import rm_id, unpack, flatten, sort, glom
from sympy.matrices.expressions.matexpr import MatrixExpr, ShapeError, ZeroMatrix
from sympy.utilities import default_sort_key, sift

Expand Down Expand Up @@ -120,4 +120,4 @@ def merge_explicit(matadd):
sort(default_sort_key))

canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd),
do_one(*rules)))
do_one(rules)))
Loading

0 comments on commit 26438f3

Please sign in to comment.