Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13194 from bjodah/codegen-rewrite
Add .codegen.rewriting module
- Loading branch information
Showing
3 changed files
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Classes and functions useful for rewriting expressions for optimized code | ||
generation. Some languages (or standards thereof), e.g. C99, offer specialized | ||
math functions for better performance and/or precision. | ||
Using the ``optimize`` function in this module, together with a collection of | ||
rules (represented as instances of ``Optimization``), one can rewrite the | ||
expressions for this purpose:: | ||
>>> from sympy import Symbol, exp, log | ||
>>> from sympy.codegen.rewriting import optimize, optims_c99 | ||
>>> x = Symbol('x') | ||
>>> optimize(3*exp(2*x) - 3, optims_c99) | ||
3*expm1(2*x) | ||
>>> optimize(exp(2*x) - 3, optims_c99) | ||
exp(2*x) - 3 | ||
>>> optimize(log(3*x + 3), optims_c99) | ||
log1p(x) + log(3) | ||
>>> optimize(log(2*x + 3), optims_c99) | ||
log(2*x + 3) | ||
The ``optims_c99`` imported above is tuple containing the following instances | ||
(which may be imported from ``sympy.codegen.rewriting``): | ||
- ``expm1_opt`` | ||
- ``log1p_opt`` | ||
- ``exp2_opt`` | ||
- ``log2_opt`` | ||
- ``log2const_opt`` | ||
""" | ||
from __future__ import (absolute_import, division, print_function) | ||
from itertools import tee, chain | ||
from sympy import log, Add, exp, Max, Min, Wild, Pow, expand_log, Dummy | ||
from sympy.core.compatibility import filterfalse | ||
from sympy.codegen.cfunctions import log1p, log2, exp2, expm1 | ||
|
||
|
||
class Optimization(object): | ||
""" Abstract base class for rewriting optimization. | ||
Subclasses should implement ``__call__`` taking an expression | ||
as argument. | ||
Parameters | ||
========== | ||
cost_function : callable returning number | ||
priority : number | ||
""" | ||
def __init__(self, cost_function=None, priority=1): | ||
self.cost_function = cost_function | ||
self.priority=priority | ||
|
||
|
||
class ReplaceOptim(Optimization): | ||
""" Rewriting optimization calling replace on expressions. | ||
The instance can be used as a function on expressions for which | ||
it will apply the ``replace`` method (see | ||
:meth:`sympy.core.basic.Basic.replace`). | ||
Parameters | ||
========== | ||
query : first argument passed to replace | ||
value : second argument passed to replace | ||
Examples | ||
======== | ||
>>> from sympy import Symbol, Pow | ||
>>> from sympy.codegen.rewriting import ReplaceOptim | ||
>>> from sympy.codegen.cfunctions import exp2 | ||
>>> x = Symbol('x') | ||
>>> exp2_opt = ReplaceOptim(lambda p: (isinstance(p, Pow) and p.base == 2), | ||
... lambda p: exp2(p.exp)) | ||
>>> exp2_opt(2**x) | ||
exp2(x) | ||
""" | ||
|
||
def __init__(self, query, value, **kwargs): | ||
super(ReplaceOptim, self).__init__(**kwargs) | ||
self.query = query | ||
self.value = value | ||
|
||
def __call__(self, expr): | ||
return expr.replace(self.query, self.value) | ||
|
||
|
||
def optimize(expr, optimizations): | ||
""" Apply optimizations to an expression. | ||
Parameters | ||
========== | ||
expr : expression | ||
optimizations : iterable of ``Optimization`` instances | ||
The optimizations will be sorted with respect to ``priority`` (highest first). | ||
Examples | ||
======== | ||
>>> from sympy import log, Symbol | ||
>>> from sympy.codegen.rewriting import optims_c99, optimize | ||
>>> x = Symbol('x') | ||
>>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99) | ||
log1p(x**2) + log2(x + 3) | ||
""" | ||
|
||
for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True): | ||
new_expr = optim(expr) | ||
if optim.cost_function is None: | ||
expr = new_expr | ||
else: | ||
before, after = map(lambda x: optim.cost_function(x), (expr, new_expr)) | ||
if before > after: | ||
expr = new_expr | ||
return expr | ||
|
||
|
||
exp2_opt = ReplaceOptim( | ||
lambda p: (isinstance(p, Pow) | ||
and p.base == 2), | ||
lambda p: exp2(p.exp) | ||
) | ||
|
||
_d = Wild('d', properties=[lambda x: x.is_Dummy]) | ||
_u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add]) | ||
_v = Wild('v') | ||
_w = Wild('w') | ||
|
||
|
||
log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count( | ||
lambda e: ( # division & eval of transcendentals are expensive floating point operations... | ||
(isinstance(e, Pow) and e.exp.is_negative) # division | ||
or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental | ||
) | ||
) | ||
|
||
log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w)) | ||
|
||
logsumexp_2terms_opt = ReplaceOptim( | ||
lambda l: (isinstance(l, log) | ||
and isinstance(l.args[0], Add) | ||
and len(l.args[0].args) == 2 | ||
and all(isinstance(t, exp) for t in l.args[0].args)), | ||
lambda l: ( | ||
Max(*[e.args[0] for e in l.args[0].args]) + | ||
log1p(exp(Min(*[e.args[0] for e in l.args[0].args]))) | ||
) | ||
) | ||
|
||
|
||
def _partition(predicate, iterable): | ||
iter_a, iter_b = tee(iterable) | ||
return tuple(filter(predicate, iter_a)), tuple(filterfalse(predicate, iter_b)) | ||
|
||
|
||
def _try_expm1(expr): | ||
protected, old_new = expr.replace(exp, lambda arg: Dummy(), map=True) | ||
factored = protected.factor() | ||
new_old = {v: k for k, v in old_new.items()} | ||
return factored.replace(_d - 1, lambda d: expm1(new_old[d].args[0])).xreplace(new_old) | ||
|
||
|
||
def _expm1_value(e): | ||
numbers, non_num = _partition(lambda arg: arg.is_number, e.args) | ||
non_num_exp, non_num_other = _partition(lambda arg: arg.has(exp), non_num) | ||
numsum = sum(numbers) | ||
new_exp_terms, done = [], False | ||
for exp_term in non_num_exp: | ||
if done: | ||
new_exp_terms.append(exp_term) | ||
else: | ||
looking_at = exp_term + numsum | ||
attempt = _try_expm1(looking_at) | ||
if looking_at == attempt: | ||
new_exp_terms.append(exp_term) | ||
else: | ||
done = True | ||
new_exp_terms.append(attempt) | ||
if not done: | ||
new_exp_terms.append(numsum) | ||
return e.func(*chain(new_exp_terms, non_num_other)) | ||
|
||
|
||
expm1_opt = ReplaceOptim(lambda e: e.is_Add, _expm1_value) | ||
|
||
|
||
log1p_opt = ReplaceOptim( | ||
lambda e: isinstance(e, log), | ||
lambda l: expand_log(l.replace( | ||
log, lambda arg: log(arg.factor()) | ||
)).replace(log(_u+1), log1p(_u)) | ||
) | ||
|
||
# Collections of optimizations: | ||
optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import (absolute_import, division, print_function) | ||
|
||
import pytest | ||
from sympy import log, exp, Symbol, Pow | ||
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p | ||
from sympy.codegen.rewriting import optimize, log2_opt, exp2_opt, expm1_opt, log1p_opt, optims_c99 | ||
from sympy.utilities.pytest import XFAIL | ||
|
||
|
||
def test_log2_opt(): | ||
x = Symbol('x') | ||
expr1 = 7*log(3*x + 5)/(log(2)) | ||
opt1 = optimize(expr1, [log2_opt]) | ||
assert opt1 == 7*log2(3*x + 5) | ||
assert opt1.rewrite(log) == expr1 | ||
|
||
expr2 = 3*log(5*x + 7)/(13*log(2)) | ||
opt2 = optimize(expr2, [log2_opt]) | ||
assert opt2 == 3*log2(5*x + 7)/13 | ||
assert opt2.rewrite(log) == expr2 | ||
|
||
expr3 = log(x)/log(2) | ||
opt3 = optimize(expr3, [log2_opt]) | ||
assert opt3 == log2(x) | ||
assert opt3.rewrite(log) == expr3 | ||
|
||
expr4 = log(x)/log(2) + log(x+1) | ||
opt4 = optimize(expr4, [log2_opt]) | ||
assert opt4 == log2(x) + log(2)*log2(x+1) | ||
assert opt4.rewrite(log) == expr4 | ||
|
||
expr5 = log(17) | ||
opt5 = optimize(expr5, [log2_opt]) | ||
assert opt5 == expr5 | ||
|
||
expr6 = log(x + 3)/log(2) | ||
opt6 = optimize(expr6, [log2_opt]) | ||
assert str(opt6) == 'log2(x + 3)' | ||
assert opt6.rewrite(log) == expr6 | ||
|
||
|
||
def test_exp2_opt(): | ||
x = Symbol('x') | ||
expr1 = 1 + 2**x | ||
opt1 = optimize(expr1, [exp2_opt]) | ||
assert opt1 == 1 + exp2(x) | ||
assert opt1.rewrite(Pow) == expr1 | ||
|
||
expr2 = 1 + 3**x | ||
assert expr2 == optimize(expr2, [exp2_opt]) | ||
|
||
|
||
def test_expm1_opt(): | ||
x = Symbol('x') | ||
|
||
expr1 = exp(x) - 1 | ||
opt1 = optimize(expr1, [expm1_opt]) | ||
assert expm1(x) - opt1 == 0 | ||
assert opt1.rewrite(exp) == expr1 | ||
|
||
expr2 = 3*exp(x) - 3 | ||
opt2 = optimize(expr2, [expm1_opt]) | ||
assert 3*expm1(x) == opt2 | ||
assert opt2.rewrite(exp) == expr2 | ||
|
||
expr3 = 3*exp(x) - 5 | ||
assert expr3 == optimize(expr3, [expm1_opt]) | ||
|
||
expr4 = 3*exp(x) + log(x) - 3 | ||
opt4 = optimize(expr4, [expm1_opt]) | ||
assert 3*expm1(x) + log(x) == opt4 | ||
assert opt4.rewrite(exp) == expr4 | ||
|
||
expr5 = 3*exp(2*x) - 3 | ||
opt5 = optimize(expr5, [expm1_opt]) | ||
assert 3*expm1(2*x) == opt5 | ||
assert opt5.rewrite(exp) == expr5 | ||
|
||
|
||
@XFAIL | ||
def test_expm1_two_exp_terms(): | ||
x, y = map(Symbol, 'x y'.split()) | ||
expr1 = exp(x) + exp(y) - 2 | ||
opt1 = optimize(expr1, [expm1_opt]) | ||
assert opt1 == expm1(x) + expm1(y) | ||
|
||
|
||
def test_log1p_opt(): | ||
x = Symbol('x') | ||
expr1 = log(x + 1) | ||
opt1 = optimize(expr1, [log1p_opt]) | ||
assert log1p(x) - opt1 == 0 | ||
assert opt1.rewrite(log) == expr1 | ||
|
||
expr2 = log(3*x + 3) | ||
opt2 = optimize(expr2, [log1p_opt]) | ||
assert log1p(x) + log(3) == opt2 | ||
assert (opt2.rewrite(log) - expr2).simplify() == 0 | ||
|
||
expr3 = log(2*x + 1) | ||
opt3 = optimize(expr3, [log1p_opt]) | ||
assert log1p(2*x) - opt3 == 0 | ||
assert opt3.rewrite(log) == expr3 | ||
|
||
expr4 = log(x+3) | ||
opt4 = optimize(expr4, [log1p_opt]) | ||
assert str(opt4) == 'log(x + 3)' | ||
|
||
|
||
def test_optims_c99(): | ||
x = Symbol('x') | ||
|
||
expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1 | ||
opt1 = optimize(expr1, optims_c99).simplify() | ||
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x) | ||
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1 | ||
|
||
expr2 = log(x)/log(2) + log(x + 1) | ||
opt2 = optimize(expr2, optims_c99) | ||
assert opt2 == log2(x) + log1p(x) | ||
assert opt2.rewrite(log) == expr2 | ||
|
||
expr3 = log(x)/log(2) + log(17*x + 17) | ||
opt3 = optimize(expr3, optims_c99) | ||
delta3 = opt3 - (log2(x) + log(17) + log1p(x)) | ||
assert delta3 == 0 | ||
assert (opt3.rewrite(log) - expr3).simplify() == 0 | ||
|
||
expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17) | ||
opt4 = optimize(expr4, optims_c99).simplify() | ||
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)) | ||
assert delta4 == 0 | ||
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0 | ||
|
||
expr5 = 3*exp(2*x) - 3 | ||
opt5 = optimize(expr5, optims_c99) | ||
delta5 = opt5 - 3*expm1(2*x) | ||
assert delta5 == 0 | ||
assert opt5.rewrite(exp) == expr5 | ||
|
||
expr6 = exp(2*x) - 3 | ||
opt6 = optimize(expr6, optims_c99) | ||
delta6 = opt6 - (exp(2*x) - 3) | ||
assert delta6 == 0 | ||
|
||
expr7 = log(3*x + 3) | ||
opt7 = optimize(expr7, optims_c99) | ||
delta7 = opt7 - (log(3) + log1p(x)) | ||
assert delta7 == 0 | ||
assert (opt7.rewrite(log) - expr7).simplify() == 0 | ||
|
||
expr8 = log(2*x + 3) | ||
opt8 = optimize(expr8, optims_c99) | ||
assert opt8 == expr8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters