Skip to content

Commit

Permalink
Merge pull request #13194 from bjodah/codegen-rewrite
Browse files Browse the repository at this point in the history
Add .codegen.rewriting module
  • Loading branch information
bjodah committed Aug 28, 2017
2 parents 4009719 + 62a6253 commit 83bcbc0
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 0 deletions.
201 changes: 201 additions & 0 deletions sympy/codegen/rewriting.py
@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
"""
Classes and functions useful for rewriting expressions for optimized code
generation. Some languages (or standards thereof), e.g. C99, offer specialized
math functions for better performance and/or precision.
Using the ``optimize`` function in this module, together with a collection of
rules (represented as instances of ``Optimization``), one can rewrite the
expressions for this purpose::
>>> from sympy import Symbol, exp, log
>>> from sympy.codegen.rewriting import optimize, optims_c99
>>> x = Symbol('x')
>>> optimize(3*exp(2*x) - 3, optims_c99)
3*expm1(2*x)
>>> optimize(exp(2*x) - 3, optims_c99)
exp(2*x) - 3
>>> optimize(log(3*x + 3), optims_c99)
log1p(x) + log(3)
>>> optimize(log(2*x + 3), optims_c99)
log(2*x + 3)
The ``optims_c99`` imported above is tuple containing the following instances
(which may be imported from ``sympy.codegen.rewriting``):
- ``expm1_opt``
- ``log1p_opt``
- ``exp2_opt``
- ``log2_opt``
- ``log2const_opt``
"""
from __future__ import (absolute_import, division, print_function)
from itertools import tee, chain
from sympy import log, Add, exp, Max, Min, Wild, Pow, expand_log, Dummy
from sympy.core.compatibility import filterfalse
from sympy.codegen.cfunctions import log1p, log2, exp2, expm1


class Optimization(object):
""" Abstract base class for rewriting optimization.
Subclasses should implement ``__call__`` taking an expression
as argument.
Parameters
==========
cost_function : callable returning number
priority : number
"""
def __init__(self, cost_function=None, priority=1):
self.cost_function = cost_function
self.priority=priority


class ReplaceOptim(Optimization):
""" Rewriting optimization calling replace on expressions.
The instance can be used as a function on expressions for which
it will apply the ``replace`` method (see
:meth:`sympy.core.basic.Basic.replace`).
Parameters
==========
query : first argument passed to replace
value : second argument passed to replace
Examples
========
>>> from sympy import Symbol, Pow
>>> from sympy.codegen.rewriting import ReplaceOptim
>>> from sympy.codegen.cfunctions import exp2
>>> x = Symbol('x')
>>> exp2_opt = ReplaceOptim(lambda p: (isinstance(p, Pow) and p.base == 2),
... lambda p: exp2(p.exp))
>>> exp2_opt(2**x)
exp2(x)
"""

def __init__(self, query, value, **kwargs):
super(ReplaceOptim, self).__init__(**kwargs)
self.query = query
self.value = value

def __call__(self, expr):
return expr.replace(self.query, self.value)


def optimize(expr, optimizations):
""" Apply optimizations to an expression.
Parameters
==========
expr : expression
optimizations : iterable of ``Optimization`` instances
The optimizations will be sorted with respect to ``priority`` (highest first).
Examples
========
>>> from sympy import log, Symbol
>>> from sympy.codegen.rewriting import optims_c99, optimize
>>> x = Symbol('x')
>>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
log1p(x**2) + log2(x + 3)
"""

for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
new_expr = optim(expr)
if optim.cost_function is None:
expr = new_expr
else:
before, after = map(lambda x: optim.cost_function(x), (expr, new_expr))
if before > after:
expr = new_expr
return expr


exp2_opt = ReplaceOptim(
lambda p: (isinstance(p, Pow)
and p.base == 2),
lambda p: exp2(p.exp)
)

_d = Wild('d', properties=[lambda x: x.is_Dummy])
_u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add])
_v = Wild('v')
_w = Wild('w')


log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count(
lambda e: ( # division & eval of transcendentals are expensive floating point operations...
(isinstance(e, Pow) and e.exp.is_negative) # division
or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental
)
)

log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w))

logsumexp_2terms_opt = ReplaceOptim(
lambda l: (isinstance(l, log)
and isinstance(l.args[0], Add)
and len(l.args[0].args) == 2
and all(isinstance(t, exp) for t in l.args[0].args)),
lambda l: (
Max(*[e.args[0] for e in l.args[0].args]) +
log1p(exp(Min(*[e.args[0] for e in l.args[0].args])))
)
)


def _partition(predicate, iterable):
iter_a, iter_b = tee(iterable)
return tuple(filter(predicate, iter_a)), tuple(filterfalse(predicate, iter_b))


def _try_expm1(expr):
protected, old_new = expr.replace(exp, lambda arg: Dummy(), map=True)
factored = protected.factor()
new_old = {v: k for k, v in old_new.items()}
return factored.replace(_d - 1, lambda d: expm1(new_old[d].args[0])).xreplace(new_old)


def _expm1_value(e):
numbers, non_num = _partition(lambda arg: arg.is_number, e.args)
non_num_exp, non_num_other = _partition(lambda arg: arg.has(exp), non_num)
numsum = sum(numbers)
new_exp_terms, done = [], False
for exp_term in non_num_exp:
if done:
new_exp_terms.append(exp_term)
else:
looking_at = exp_term + numsum
attempt = _try_expm1(looking_at)
if looking_at == attempt:
new_exp_terms.append(exp_term)
else:
done = True
new_exp_terms.append(attempt)
if not done:
new_exp_terms.append(numsum)
return e.func(*chain(new_exp_terms, non_num_other))


expm1_opt = ReplaceOptim(lambda e: e.is_Add, _expm1_value)


log1p_opt = ReplaceOptim(
lambda e: isinstance(e, log),
lambda l: expand_log(l.replace(
log, lambda arg: log(arg.factor())
)).replace(log(_u+1), log1p(_u))
)

# Collections of optimizations:
optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt)
155 changes: 155 additions & 0 deletions sympy/codegen/tests/test_rewriting.py
@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
from __future__ import (absolute_import, division, print_function)

import pytest
from sympy import log, exp, Symbol, Pow
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
from sympy.codegen.rewriting import optimize, log2_opt, exp2_opt, expm1_opt, log1p_opt, optims_c99
from sympy.utilities.pytest import XFAIL


def test_log2_opt():
x = Symbol('x')
expr1 = 7*log(3*x + 5)/(log(2))
opt1 = optimize(expr1, [log2_opt])
assert opt1 == 7*log2(3*x + 5)
assert opt1.rewrite(log) == expr1

expr2 = 3*log(5*x + 7)/(13*log(2))
opt2 = optimize(expr2, [log2_opt])
assert opt2 == 3*log2(5*x + 7)/13
assert opt2.rewrite(log) == expr2

expr3 = log(x)/log(2)
opt3 = optimize(expr3, [log2_opt])
assert opt3 == log2(x)
assert opt3.rewrite(log) == expr3

expr4 = log(x)/log(2) + log(x+1)
opt4 = optimize(expr4, [log2_opt])
assert opt4 == log2(x) + log(2)*log2(x+1)
assert opt4.rewrite(log) == expr4

expr5 = log(17)
opt5 = optimize(expr5, [log2_opt])
assert opt5 == expr5

expr6 = log(x + 3)/log(2)
opt6 = optimize(expr6, [log2_opt])
assert str(opt6) == 'log2(x + 3)'
assert opt6.rewrite(log) == expr6


def test_exp2_opt():
x = Symbol('x')
expr1 = 1 + 2**x
opt1 = optimize(expr1, [exp2_opt])
assert opt1 == 1 + exp2(x)
assert opt1.rewrite(Pow) == expr1

expr2 = 1 + 3**x
assert expr2 == optimize(expr2, [exp2_opt])


def test_expm1_opt():
x = Symbol('x')

expr1 = exp(x) - 1
opt1 = optimize(expr1, [expm1_opt])
assert expm1(x) - opt1 == 0
assert opt1.rewrite(exp) == expr1

expr2 = 3*exp(x) - 3
opt2 = optimize(expr2, [expm1_opt])
assert 3*expm1(x) == opt2
assert opt2.rewrite(exp) == expr2

expr3 = 3*exp(x) - 5
assert expr3 == optimize(expr3, [expm1_opt])

expr4 = 3*exp(x) + log(x) - 3
opt4 = optimize(expr4, [expm1_opt])
assert 3*expm1(x) + log(x) == opt4
assert opt4.rewrite(exp) == expr4

expr5 = 3*exp(2*x) - 3
opt5 = optimize(expr5, [expm1_opt])
assert 3*expm1(2*x) == opt5
assert opt5.rewrite(exp) == expr5


@XFAIL
def test_expm1_two_exp_terms():
x, y = map(Symbol, 'x y'.split())
expr1 = exp(x) + exp(y) - 2
opt1 = optimize(expr1, [expm1_opt])
assert opt1 == expm1(x) + expm1(y)


def test_log1p_opt():
x = Symbol('x')
expr1 = log(x + 1)
opt1 = optimize(expr1, [log1p_opt])
assert log1p(x) - opt1 == 0
assert opt1.rewrite(log) == expr1

expr2 = log(3*x + 3)
opt2 = optimize(expr2, [log1p_opt])
assert log1p(x) + log(3) == opt2
assert (opt2.rewrite(log) - expr2).simplify() == 0

expr3 = log(2*x + 1)
opt3 = optimize(expr3, [log1p_opt])
assert log1p(2*x) - opt3 == 0
assert opt3.rewrite(log) == expr3

expr4 = log(x+3)
opt4 = optimize(expr4, [log1p_opt])
assert str(opt4) == 'log(x + 3)'


def test_optims_c99():
x = Symbol('x')

expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
opt1 = optimize(expr1, optims_c99).simplify()
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1

expr2 = log(x)/log(2) + log(x + 1)
opt2 = optimize(expr2, optims_c99)
assert opt2 == log2(x) + log1p(x)
assert opt2.rewrite(log) == expr2

expr3 = log(x)/log(2) + log(17*x + 17)
opt3 = optimize(expr3, optims_c99)
delta3 = opt3 - (log2(x) + log(17) + log1p(x))
assert delta3 == 0
assert (opt3.rewrite(log) - expr3).simplify() == 0

expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
opt4 = optimize(expr4, optims_c99).simplify()
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
assert delta4 == 0
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0

expr5 = 3*exp(2*x) - 3
opt5 = optimize(expr5, optims_c99)
delta5 = opt5 - 3*expm1(2*x)
assert delta5 == 0
assert opt5.rewrite(exp) == expr5

expr6 = exp(2*x) - 3
opt6 = optimize(expr6, optims_c99)
delta6 = opt6 - (exp(2*x) - 3)
assert delta6 == 0

expr7 = log(3*x + 3)
opt7 = optimize(expr7, optims_c99)
delta7 = opt7 - (log(3) + log1p(x))
assert delta7 == 0
assert (opt7.rewrite(log) - expr7).simplify() == 0

expr8 = log(2*x + 3)
opt8 = optimize(expr8, optims_c99)
assert opt8 == expr8
6 changes: 6 additions & 0 deletions sympy/core/compatibility.py
Expand Up @@ -860,3 +860,9 @@ def cache_clear():
if sys.version_info[:2] >= (3, 3):
# 3.2 has an lru_cache with an incompatible API
from functools import lru_cache

try:
from itertools import filterfalse
except ImportError:
def filterfalse(pred, itr):
return filter(lambda x: not pred(x), itr)

0 comments on commit 83bcbc0

Please sign in to comment.