From 9bc29d0bf6c84b296bab5b3762cc908514256ad6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Dahlgren?= Date: Fri, 25 Aug 2017 16:53:49 +0200 Subject: [PATCH 1/3] Add .codgen.rewriting module --- sympy/codegen/rewriting.py | 197 ++++++++++++++++++++++++++ sympy/codegen/tests/test_rewriting.py | 136 ++++++++++++++++++ 2 files changed, 333 insertions(+) create mode 100644 sympy/codegen/rewriting.py create mode 100644 sympy/codegen/tests/test_rewriting.py diff --git a/sympy/codegen/rewriting.py b/sympy/codegen/rewriting.py new file mode 100644 index 000000000000..6912cbec03dc --- /dev/null +++ b/sympy/codegen/rewriting.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +from __future__ import (absolute_import, division, print_function) +from itertools import tee, chain +try: + from itertools import filterfalse +except ImportError: + def filterfalse(pred, iter): + return filter(lambda x: not pred(x), iter) + +from sympy import log, Add, exp, Max, Min, Wild, Pow, expand_log, Dummy +from sympy.codegen.cfunctions import log1p, log2, exp2, expm1 + +""" +Classes and functions useful for rewriting expressions for optimized code +generation. Some languages (or standard thereof), e.g. C99, offer specialized +math functions which may offer better performance and/or precision. + +Using the ``optimize`` function in this module together with a collection of +rules (represented as instances of ``Optimization``) one can rewrite the +expressions for this purpose. + +Examples +-------- +>>> from sympy import Symbol, exp, log +>>> from sympy.codegen.rewriting import optimize, c99_optims +>>> x = Symbol('x') +>>> optimize(3*exp(2*x) - 3, c99_optims) +3*expm1(2*x) +>>> optimize(exp(2*x) - 3, c99_optims) +exp(2*x) - 3 +>>> optimize(log(3*x + 3), c99_optims) +log(3) + log1p(x) +>>> optimize(log(2*x + 3), c99_optims) +log(2*x + 3) + +""" + + +class Optimization(object): + """ Abstract base class for rewriting optimization. + + Subclasses should implement ``__call__`` taking an expression + as argument. + + Parameters + ---------- + cost_function : callable returning number + priority : number + + """ + def __init__(self, cost_function=None, priority=1): + self.cost_function = cost_function + self.priority=priority + + +class ReplaceOptim(Optimization): + """ Rewriting optimization calling replace on expressions. + + The instance can be used as a function on expressions for which + it will apply the ``replace`` method (see + :meth:`sympy.core.basic.Basic.replace`). + + Parameters + ---------- + query : first argument passed to replace + value : second argument passed to replace + + Examples + -------- + >>> from sympy import Symbol, Pow + >>> from sympy.codegen.rewriting import ReplaceOptim + >>> from sympy.codegen.cfunctions import exp2 + >>> x = Symbol('x') + >>> exp2_opt = ReplaceOptim(lambda p: (isinstance(p, Pow) and p.base == 2), + ... lambda p: exp2(p.exp)) + >>> exp2_opt(2**x) + exp2(x) + + """ + + def __init__(self, query, value, **kwargs): + super(ReplaceOptim, self).__init__(**kwargs) + self.query = query + self.value = value + + def __call__(self, expr): + return expr.replace(self.query, self.value) + + +def optimize(expr, optimizations): + """ Apply optimizations to an expression. + + Parameters + ---------- + expr : expression + optimizations : iterable of ``Optimization`` instances + The optimizations will be sorted with respect to ``priority`` (highest first). + + Examples + -------- + >>> from sympy import log, Symbol + >>> from sympy.codegen.rewriting import optims_c99, optimize + >>> x = Symbol('x') + >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99) + log1p(x**2) + log2(x + 3) + + """ + + for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True): + new_expr = optim(expr) + if optim.cost_function is None: + expr = new_expr + else: + before, after = map(lambda x: optim.cost_function(x), (expr, new_expr)) + if before > after: + expr = new_expr + return expr + + +exp2_opt = ReplaceOptim( + lambda p: (isinstance(p, Pow) + and p.base == 2), + lambda p: exp2(p.exp) +) + +_d = Wild('d', properties=[lambda x: x.is_Dummy]) +_u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add]) +_v = Wild('v') +_w = Wild('w') + + +log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count( + lambda e: ( # division & eval of transcendentals are expensive floating point operations... + (isinstance(e, Pow) and e.exp.is_negative) # division + or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental + ) +) + +log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w)) + +logsumexp_2terms_opt = ReplaceOptim( + lambda l: (isinstance(l, log) + and isinstance(l.args[0], Add) + and len(l.args[0].args) == 2 + and all(isinstance(t, exp) for t in l.args[0].args)), + lambda l: ( + Max(*[e.args[0] for e in l.args[0].args]) + + log1p(exp(Min(*[e.args[0] for e in l.args[0].args]))) + ) +) + + +def _partition(predicate, iterable): + iter_a, iter_b = tee(iterable) + return tuple(filter(predicate, iter_a)), tuple(filterfalse(predicate, iter_b)) + + +def _try_expm1(expr): + protected, old_new = expr.replace(exp, lambda arg: Dummy(), map=True) + factored = protected.factor() + new_old = {v: k for k, v in old_new.items()} + return factored.replace(_d - 1, lambda d: expm1(new_old[d].args[0])).xreplace(new_old) + + +def _expm1_value(e): + numbers, non_num = _partition(lambda arg: arg.is_number, e.args) + non_num_exp, non_num_other = _partition(lambda arg: arg.has(exp), non_num) + numsum = sum(numbers) + new_exp_terms, done = [], False + for exp_term in non_num_exp: + if done: + new_exp_terms.append(exp_term) + else: + looking_at = exp_term + numsum + attempt = _try_expm1(looking_at) + if looking_at == attempt: + new_exp_terms.append(exp_term) + else: + done = True + new_exp_terms.append(attempt) + if not done: + new_exp_terms.append(numsum) + return e.func(*chain(new_exp_terms, non_num_other)) + + +expm1_opt = ReplaceOptim(lambda e: e.is_Add, _expm1_value) + + +log1p_opt = ReplaceOptim( + lambda e: isinstance(e, log), + lambda l: expand_log(l.replace( + log, lambda arg: log(arg.factor()) + )).replace(log(_u+1), log1p(_u)) +) + +# Collections of optimizations: +optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt) diff --git a/sympy/codegen/tests/test_rewriting.py b/sympy/codegen/tests/test_rewriting.py new file mode 100644 index 000000000000..89a3455ac814 --- /dev/null +++ b/sympy/codegen/tests/test_rewriting.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +from __future__ import (absolute_import, division, print_function) + +import pytest +from sympy import log, exp, Symbol +from sympy.codegen.cfunctions import log2, exp2, expm1, log1p +from sympy.codegen.rewriting import optimize, log2_opt, exp2_opt, expm1_opt, log1p_opt, optims_c99 + + +def test_log2_opt(): + x = Symbol('x') + expr1 = 7*log(3*x + 5)/(log(2)) + opt1 = optimize(expr1, [log2_opt]) + assert opt1 == 7*log2(3*x + 5) + + expr2 = 3*log(5*x + 7)/(13*log(2)) + opt2 = optimize(expr2, [log2_opt]) + assert opt2 == 3*log2(5*x + 7)/13 + + expr3 = log(x)/log(2) + opt3 = optimize(expr3, [log2_opt]) + assert opt3 == log2(x) + + expr4 = log(x)/log(2) + log(x+1) + opt4 = optimize(expr4, [log2_opt]) + assert opt4 == log2(x) + log(2)*log2(x+1) + + expr5 = log(17) + opt5 = optimize(expr5, [log2_opt]) + assert opt5 == expr5 + + expr6 = log(x + 3)/log(2) + opt6 = optimize(expr6, [log2_opt]) + assert str(opt6) == 'log2(x + 3)' + + +def test_exp2_opt(): + x = Symbol('x') + expr1 = 1 + 2**x + opt1 = optimize(expr1, [exp2_opt]) + assert opt1 == 1 + exp2(x) + + expr2 = 1 + 3**x + assert expr2 == optimize(expr2, [exp2_opt]) + + +def test_expm1_opt(): + x = Symbol('x') + + expr1 = exp(x) - 1 + opt1 = optimize(expr1, [expm1_opt]) + assert expm1(x) - opt1 == 0 + + expr2 = 3*exp(x) - 3 + opt2 = optimize(expr2, [expm1_opt]) + assert 3*expm1(x) == opt2 + + expr3 = 3*exp(x) - 5 + assert expr3 == optimize(expr3, [expm1_opt]) + + expr4 = 3*exp(x) + log(x) - 3 + opt4 = optimize(expr4, [expm1_opt]) + assert 3*expm1(x) + log(x) == opt4 + + expr5 = 3*exp(2*x) - 3 + opt5 = optimize(expr5, [expm1_opt]) + assert 3*expm1(2*x) == opt5 + + +@pytest.mark.xfail +def test_expm1_two_exp_terms(): + x, y = map(Symbol, 'x y'.split()) + expr1 = exp(x) + exp(y) - 2 + opt1 = optimize(expr1, [expm1_opt]) + assert opt1 == expm1(x) + expm1(y) + + +def test_log1p_opt(): + x = Symbol('x') + expr1 = log(x + 1) + opt1 = optimize(expr1, [log1p_opt]) + assert log1p(x) - opt1 == 0 + + expr2 = log(3*x + 3) + opt2 = optimize(expr2, [log1p_opt]) + assert log1p(x) + log(3) == opt2 + + expr3 = log(2*x + 1) + opt3 = optimize(expr3, [log1p_opt]) + assert log1p(2*x) - opt3 == 0 + + expr4 = log(x+3) + opt4 = optimize(expr4, [log1p_opt]) + assert str(opt4) == 'log(x + 3)' + + +def test_optims_c99(): + x = Symbol('x') + + expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1 + opt1 = optimize(expr1, optims_c99).simplify() + assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x) + + expr2 = log(x)/log(2) + log(x + 1) + opt2 = optimize(expr2, optims_c99) + assert opt2 == log2(x) + log1p(x) + + expr3 = log(x)/log(2) + log(17*x + 17) + opt3 = optimize(expr3, optims_c99) + delta3 = opt3 - (log2(x) + log(17) + log1p(x)) + assert delta3 == 0 + + expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17) + opt4 = optimize(expr4, optims_c99).simplify() + delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)) + assert delta4 == 0 + + expr5 = 3*exp(2*x) - 3 + opt5 = optimize(expr5, optims_c99) + delta5 = opt5 - 3*expm1(2*x) + assert delta5 == 0 + + expr6 = exp(2*x) - 3 + opt6 = optimize(expr6, optims_c99) + delta6 = opt6 - (exp(2*x) - 3) + assert delta6 == 0 + + expr7 = log(3*x + 3) + opt7 = optimize(expr7, optims_c99) + delta7 = opt7 - (log(3) + log1p(x)) + assert delta7 == 0 + + expr8 = log(2*x + 3) + opt8 = optimize(expr8, optims_c99) + delta8 = opt8 - (log(2*x + 3)) + assert delta8 == 0 From d781b080d4b09d4f8cc1b63498ff761a257e07e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Dahlgren?= Date: Sat, 26 Aug 2017 07:21:14 +0200 Subject: [PATCH 2/3] Use SymPy's XFAIL, minor enhancements to tests --- sympy/codegen/rewriting.py | 8 ++++---- sympy/codegen/tests/test_rewriting.py | 27 +++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/sympy/codegen/rewriting.py b/sympy/codegen/rewriting.py index 6912cbec03dc..096eb05b1061 100644 --- a/sympy/codegen/rewriting.py +++ b/sympy/codegen/rewriting.py @@ -12,11 +12,11 @@ def filterfalse(pred, iter): """ Classes and functions useful for rewriting expressions for optimized code -generation. Some languages (or standard thereof), e.g. C99, offer specialized -math functions which may offer better performance and/or precision. +generation. Some languages (or standards thereof), e.g. C99, offer specialized +math functions for better performance and/or precision. -Using the ``optimize`` function in this module together with a collection of -rules (represented as instances of ``Optimization``) one can rewrite the +Using the ``optimize`` function in this module, together with a collection of +rules (represented as instances of ``Optimization``), one can rewrite the expressions for this purpose. Examples diff --git a/sympy/codegen/tests/test_rewriting.py b/sympy/codegen/tests/test_rewriting.py index 89a3455ac814..ff6206b058d3 100644 --- a/sympy/codegen/tests/test_rewriting.py +++ b/sympy/codegen/tests/test_rewriting.py @@ -2,9 +2,10 @@ from __future__ import (absolute_import, division, print_function) import pytest -from sympy import log, exp, Symbol +from sympy import log, exp, Symbol, Pow from sympy.codegen.cfunctions import log2, exp2, expm1, log1p from sympy.codegen.rewriting import optimize, log2_opt, exp2_opt, expm1_opt, log1p_opt, optims_c99 +from sympy.utilities.pytest import XFAIL def test_log2_opt(): @@ -12,18 +13,22 @@ def test_log2_opt(): expr1 = 7*log(3*x + 5)/(log(2)) opt1 = optimize(expr1, [log2_opt]) assert opt1 == 7*log2(3*x + 5) + assert opt1.rewrite(log) == expr1 expr2 = 3*log(5*x + 7)/(13*log(2)) opt2 = optimize(expr2, [log2_opt]) assert opt2 == 3*log2(5*x + 7)/13 + assert opt2.rewrite(log) == expr2 expr3 = log(x)/log(2) opt3 = optimize(expr3, [log2_opt]) assert opt3 == log2(x) + assert opt3.rewrite(log) == expr3 expr4 = log(x)/log(2) + log(x+1) opt4 = optimize(expr4, [log2_opt]) assert opt4 == log2(x) + log(2)*log2(x+1) + assert opt4.rewrite(log) == expr4 expr5 = log(17) opt5 = optimize(expr5, [log2_opt]) @@ -32,6 +37,7 @@ def test_log2_opt(): expr6 = log(x + 3)/log(2) opt6 = optimize(expr6, [log2_opt]) assert str(opt6) == 'log2(x + 3)' + assert opt6.rewrite(log) == expr6 def test_exp2_opt(): @@ -39,6 +45,7 @@ def test_exp2_opt(): expr1 = 1 + 2**x opt1 = optimize(expr1, [exp2_opt]) assert opt1 == 1 + exp2(x) + assert opt1.rewrite(Pow) == expr1 expr2 = 1 + 3**x assert expr2 == optimize(expr2, [exp2_opt]) @@ -50,10 +57,12 @@ def test_expm1_opt(): expr1 = exp(x) - 1 opt1 = optimize(expr1, [expm1_opt]) assert expm1(x) - opt1 == 0 + assert opt1.rewrite(exp) == expr1 expr2 = 3*exp(x) - 3 opt2 = optimize(expr2, [expm1_opt]) assert 3*expm1(x) == opt2 + assert opt2.rewrite(exp) == expr2 expr3 = 3*exp(x) - 5 assert expr3 == optimize(expr3, [expm1_opt]) @@ -61,13 +70,15 @@ def test_expm1_opt(): expr4 = 3*exp(x) + log(x) - 3 opt4 = optimize(expr4, [expm1_opt]) assert 3*expm1(x) + log(x) == opt4 + assert opt4.rewrite(exp) == expr4 expr5 = 3*exp(2*x) - 3 opt5 = optimize(expr5, [expm1_opt]) assert 3*expm1(2*x) == opt5 + assert opt5.rewrite(exp) == expr5 -@pytest.mark.xfail +@XFAIL def test_expm1_two_exp_terms(): x, y = map(Symbol, 'x y'.split()) expr1 = exp(x) + exp(y) - 2 @@ -80,14 +91,17 @@ def test_log1p_opt(): expr1 = log(x + 1) opt1 = optimize(expr1, [log1p_opt]) assert log1p(x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 expr2 = log(3*x + 3) opt2 = optimize(expr2, [log1p_opt]) assert log1p(x) + log(3) == opt2 + assert (opt2.rewrite(log) - expr2).simplify() == 0 expr3 = log(2*x + 1) opt3 = optimize(expr3, [log1p_opt]) assert log1p(2*x) - opt3 == 0 + assert opt3.rewrite(log) == expr3 expr4 = log(x+3) opt4 = optimize(expr4, [log1p_opt]) @@ -100,25 +114,30 @@ def test_optims_c99(): expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1 opt1 = optimize(expr1, optims_c99).simplify() assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x) + assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1 expr2 = log(x)/log(2) + log(x + 1) opt2 = optimize(expr2, optims_c99) assert opt2 == log2(x) + log1p(x) + assert opt2.rewrite(log) == expr2 expr3 = log(x)/log(2) + log(17*x + 17) opt3 = optimize(expr3, optims_c99) delta3 = opt3 - (log2(x) + log(17) + log1p(x)) assert delta3 == 0 + assert (opt3.rewrite(log) - expr3).simplify() == 0 expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17) opt4 = optimize(expr4, optims_c99).simplify() delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)) assert delta4 == 0 + assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0 expr5 = 3*exp(2*x) - 3 opt5 = optimize(expr5, optims_c99) delta5 = opt5 - 3*expm1(2*x) assert delta5 == 0 + assert opt5.rewrite(exp) == expr5 expr6 = exp(2*x) - 3 opt6 = optimize(expr6, optims_c99) @@ -129,8 +148,8 @@ def test_optims_c99(): opt7 = optimize(expr7, optims_c99) delta7 = opt7 - (log(3) + log1p(x)) assert delta7 == 0 + assert (opt7.rewrite(log) - expr7).simplify() == 0 expr8 = log(2*x + 3) opt8 = optimize(expr8, optims_c99) - delta8 = opt8 - (log(2*x + 3)) - assert delta8 == 0 + assert opt8 == expr8 From 62a625385257cc17b7f140aa35a59279319e0668 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Dahlgren?= Date: Sun, 27 Aug 2017 12:36:42 +0200 Subject: [PATCH 3/3] Address review (doc & compatibility) - Added filterfalse to sympy.core.compatiblity - Fixed formatting of docstrings in .codegen.rewriting - Ensured doctests are run on module level docstring of .codegen.rewriting --- sympy/codegen/rewriting.py | 66 ++++++++++++++++++++----------------- sympy/core/compatibility.py | 6 ++++ 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/sympy/codegen/rewriting.py b/sympy/codegen/rewriting.py index 096eb05b1061..be29a421c55e 100644 --- a/sympy/codegen/rewriting.py +++ b/sympy/codegen/rewriting.py @@ -1,15 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import (absolute_import, division, print_function) -from itertools import tee, chain -try: - from itertools import filterfalse -except ImportError: - def filterfalse(pred, iter): - return filter(lambda x: not pred(x), iter) - -from sympy import log, Add, exp, Max, Min, Wild, Pow, expand_log, Dummy -from sympy.codegen.cfunctions import log1p, log2, exp2, expm1 - """ Classes and functions useful for rewriting expressions for optimized code generation. Some languages (or standards thereof), e.g. C99, offer specialized @@ -17,23 +6,36 @@ def filterfalse(pred, iter): Using the ``optimize`` function in this module, together with a collection of rules (represented as instances of ``Optimization``), one can rewrite the -expressions for this purpose. - -Examples --------- ->>> from sympy import Symbol, exp, log ->>> from sympy.codegen.rewriting import optimize, c99_optims ->>> x = Symbol('x') ->>> optimize(3*exp(2*x) - 3, c99_optims) -3*expm1(2*x) ->>> optimize(exp(2*x) - 3, c99_optims) -exp(2*x) - 3 ->>> optimize(log(3*x + 3), c99_optims) -log(3) + log1p(x) ->>> optimize(log(2*x + 3), c99_optims) -log(2*x + 3) +expressions for this purpose:: + + >>> from sympy import Symbol, exp, log + >>> from sympy.codegen.rewriting import optimize, optims_c99 + >>> x = Symbol('x') + >>> optimize(3*exp(2*x) - 3, optims_c99) + 3*expm1(2*x) + >>> optimize(exp(2*x) - 3, optims_c99) + exp(2*x) - 3 + >>> optimize(log(3*x + 3), optims_c99) + log1p(x) + log(3) + >>> optimize(log(2*x + 3), optims_c99) + log(2*x + 3) + +The ``optims_c99`` imported above is tuple containing the following instances +(which may be imported from ``sympy.codegen.rewriting``): + +- ``expm1_opt`` +- ``log1p_opt`` +- ``exp2_opt`` +- ``log2_opt`` +- ``log2const_opt`` + """ +from __future__ import (absolute_import, division, print_function) +from itertools import tee, chain +from sympy import log, Add, exp, Max, Min, Wild, Pow, expand_log, Dummy +from sympy.core.compatibility import filterfalse +from sympy.codegen.cfunctions import log1p, log2, exp2, expm1 class Optimization(object): @@ -43,7 +45,7 @@ class Optimization(object): as argument. Parameters - ---------- + ========== cost_function : callable returning number priority : number @@ -61,12 +63,13 @@ class ReplaceOptim(Optimization): :meth:`sympy.core.basic.Basic.replace`). Parameters - ---------- + ========== query : first argument passed to replace value : second argument passed to replace Examples - -------- + ======== + >>> from sympy import Symbol, Pow >>> from sympy.codegen.rewriting import ReplaceOptim >>> from sympy.codegen.cfunctions import exp2 @@ -91,13 +94,14 @@ def optimize(expr, optimizations): """ Apply optimizations to an expression. Parameters - ---------- + ========== expr : expression optimizations : iterable of ``Optimization`` instances The optimizations will be sorted with respect to ``priority`` (highest first). Examples - -------- + ======== + >>> from sympy import log, Symbol >>> from sympy.codegen.rewriting import optims_c99, optimize >>> x = Symbol('x') diff --git a/sympy/core/compatibility.py b/sympy/core/compatibility.py index 524d2b3bdda6..0d10408036fa 100644 --- a/sympy/core/compatibility.py +++ b/sympy/core/compatibility.py @@ -860,3 +860,9 @@ def cache_clear(): if sys.version_info[:2] >= (3, 3): # 3.2 has an lru_cache with an incompatible API from functools import lru_cache + +try: + from itertools import filterfalse +except ImportError: + def filterfalse(pred, itr): + return filter(lambda x: not pred(x), itr)