From 5e97b9dfeba2da0a1e5c73ef316da03205a51ba1 Mon Sep 17 00:00:00 2001 From: mleila1312 Date: Thu, 4 Apr 2024 11:57:03 +0200 Subject: [PATCH] Fix cse treatment in lambdify with Derivatives Before, when there were Derivatives in expr ans args given to lambdify with cse enabled, there was an error because the cse treatment changed the arguments of the Derivative object. With this implementation, the expression is pre-treated by a function to mask the instances of Derivative objects, then the cse process is applied and finally we do a post-treatment to put back the Derivative expressions in expr. --- .mailmap | 1 + sympy/utilities/lambdify.py | 161 ++++++++++++++++++++++++- sympy/utilities/tests/test_lambdify.py | 79 ++++++++++++ 3 files changed, 236 insertions(+), 5 deletions(-) diff --git a/.mailmap b/.mailmap index ab3bc2dd057e..b6ae32c2e0f9 100644 --- a/.mailmap +++ b/.mailmap @@ -1586,6 +1586,7 @@ latot luzpaz luz paz luzpaz luz.paz mao8 +mleila1312 mohajain mohajain <45903778+mohajain@users.noreply.github.com> mohammedouahman mohit <39158356+mohitacecode@users.noreply.github.com> mohit <42018918+mohitshah3111999@users.noreply.github.com> diff --git a/sympy/utilities/lambdify.py b/sympy/utilities/lambdify.py index 23fbf1299f6d..9c51ecb3fc8c 100644 --- a/sympy/utilities/lambdify.py +++ b/sympy/utilities/lambdify.py @@ -17,9 +17,10 @@ from sympy.utilities.exceptions import sympy_deprecation_warning from sympy.utilities.decorator import doctest_depends_on from sympy.utilities.iterables import (is_sequence, iterable, - NotIterable, flatten) + NotIterable, flatten, numbered_symbols,) from sympy.utilities.misc import filldedent + __doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']} # Default namespaces, letting us define translations that can't be defined @@ -178,6 +179,151 @@ def _import(module, reload=False): # linecache. _lambdify_generated_counter = 1 +def _replace_recursively(e, dict) : + if isinstance(e, list): + return [_replace_recursively(sub_e, dict)for sub_e in e] + elif isinstance(e, tuple): + return tuple([_replace_recursively(sub_e, dict)for sub_e in e]) + else : + return e.xreplace(dict) + +def _pre_treatment_cse(args_f, expr): + r""" + This function masks Derivative that are also arguments + of the expression to prevent erros in the cse treament + in lambdify. + + The first step is to go through the expression to make + sure that we don't replace the Derivatives by symbols + already in the expression. Then we replace the + Derivatives in the expression with symbols and remember + the changes in a dictionary. + + Parameters : + args_f : the arguments of expr given in lambdify + + expr : expression given to lambdify + Return : + dictionary : dictionary of the associations + Derivative-new name + + new_expr : expression where the Derivatives + have been replaced + + """ + #Necessary librairies and dependencies + from sympy.core.function import Derivative + from sympy.core.symbol import Symbol + from sympy.core import Basic + from sympy.matrices.expressions import MatrixSymbol + from sympy.matrices.expressions.matexpr import MatrixElement + from sympy.polys.rootoftools import RootOf + + #creation of the dictionary + dictionary={} + # creation of the symbols that can't be used to replace the Derivatives in the expression + excluded_symbols = set() + symbols = numbered_symbols(cls=Symbol) + def _eliminates_symbols(expr): + # function that finds the symbols that can't be used + if not isinstance(expr, Basic): + return + + if isinstance(expr, RootOf): + return + + if isinstance(expr, Basic) and ( + expr.is_Atom or + expr.is_Order or + isinstance(expr, (MatrixSymbol, MatrixElement))): + if expr.is_Symbol: + excluded_symbols.add(expr.name) + return + #recursively goes through the expression + if iterable(expr): + args = expr + else: + args = expr.args + list(map(_eliminates_symbols, args)) + return + + if iterable(expr): + for e in expr: + if isinstance(e, Basic): + _eliminates_symbols(e) + else: + if isinstance(expr, Basic): + _eliminates_symbols(expr) + + #gets the possible symbols to replace Derivatives with + symbols = (_ for _ in symbols if _.name not in excluded_symbols) + new_expr = expr + # replaces the instances of Derivatives in the expression + + for arg in args_f: + if isinstance(arg, (Derivative)): + try: + dictionary[arg] = next(symbols) + except StopIteration: + raise ValueError("Symbols iterator ran out of symbols.") + + new_expr=_replace_recursively(new_expr, dictionary) + return dictionary, new_expr + +def _post_treatment_cse(dictionary, args, expr, cses): + r""" + This function changes back the replaced Derivatives to + their original values after passing through + _pre_treatment_cse and cse in lambdify. + + This function returns the Derivatives to their + original value in the expression and cses. + + Parameters : + dictionary : dictonary containing associations + of Derivative-new name given by _pre_treatment_cse + + args : arguments given to lambdify of expr + + expr : expression returned by cse + + cses : changes made by the cse process containing + the associations partial expression- new name + + Return : + post_cses : cses modified to return Derivatives + to their original value + + post_expr : expression where the Derivatives have + been returned back to their original values + + """ + from sympy.core.function import Derivative + post_expr = expr + post_cses= cses + for arg in args: + if isinstance(arg, Derivative): + association = [] + #checks if if the new name of the Deivative was changed by the cse process + #or if combinations of the Derivatives expressions were replaces + for i in range(len(cses)): + new_a, a = cses[i] + if a.has(dictionary[arg]): + if a == dictionary[arg]: + association = new_a + post_cses.remove((new_a, a)) + else : + a = a.xreplace({dictionary[arg] : arg}) + cses[i] = new_a, a + #Checks if the new name of the Deivative was changed by the cse process + if association == []: + # if the derivative hasn't been replaced by the cse process + post_expr = _replace_recursively(post_expr, {dictionary[arg] : arg}) + else: + # if the derivative has been replaced by the cse process + post_expr = _replace_recursively(post_expr,{association : arg}) + return post_cses, post_expr + @doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,)) def lambdify(args, expr, modules=None, printer=None, use_imps=True, @@ -277,7 +423,8 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True, 6 expr : Expr - An expression, list of expressions, or matrix to be evaluated. + An expression, list of expressions, tuple of expressions or + matrix to be evaluated. Lists may be nested. If the expression is a list, the output will also be a list. @@ -756,7 +903,6 @@ def _lambdifygenerated(x): """ from sympy.core.symbol import Symbol from sympy.core.expr import Expr - # If the user hasn't specified any modules, use what is available. if modules is None: try: @@ -866,8 +1012,13 @@ def _lambdifygenerated(x): funcprinter = _EvaluatorPrinter(printer, dummify) if cse == True: - from sympy.simplify.cse_main import cse as _cse - cses, _expr = _cse(expr, list=False) + #get the dictionary containing the Derivative in the + #arguments and their new name + dictionary, new_expr= _pre_treatment_cse(args, expr) + from sympy.simplify.cse_main import cse as cse_function + cses, _expr = cse_function(new_expr, list=False) + #puts back the instances of Derivatives inthe expression + cses, _expr= _post_treatment_cse(dictionary, args, _expr, cses) elif callable(cse): cses, _expr = cse(expr) else: diff --git a/sympy/utilities/tests/test_lambdify.py b/sympy/utilities/tests/test_lambdify.py index 7ee8e84d6dca..96483da3d407 100644 --- a/sympy/utilities/tests/test_lambdify.py +++ b/sympy/utilities/tests/test_lambdify.py @@ -1891,3 +1891,82 @@ def test_assoc_legendre_numerical_evaluation(): assert all_close(sympy_result_integer, mpmath_result_integer, tol) assert all_close(sympy_result_complex, mpmath_result_complex, tol) + +def test_derivative_issue_26404(): + r""" + test issue fixed when using cse in lambdify when some arguments + are Derivatives and they appear in the expression + """ + from sympy import (cos, sin, Matrix, symbols) + from sympy.physics.mechanics import (dynamicsymbols) + t = symbols("t") + x = Function("x")(t) + xd = x.diff(t) + xdd= xd.diff(t) + assert lambdify((xd, x), xd, cse=True)(1, 1) == 1 + assert lambdify((xd, x), xd + x, cse=True)(1, 1) == 2 + assert lambdify((xdd, xd, x), xdd*xd + x, cse=True)(3,1, 1) == 4 + assert lambdify((xd, xdd, x), xdd*xd + x, cse=True)(3,1, 1) == 4 + assert lambdify((xdd, xd, x), cos(xdd*xd) + x, cse=True)(0,1, 1) == 2.0 + #test for matrix and cases were Derivative(a,b) becomes x_n + #and cse makes a replacement x_m: x_n**2 or other + #and case where xn(n : int) is already the name of an + #element of the function + x0, m0 = symbols("x0 m0") + l1, m1 = symbols("l1 m1") + m2 = symbols("m2") + g = symbols("g") + q0, q1, q2 = Function("q0")(x0),Function("q1")(l1),Function("q2")(m0) + u1, u2 =q1.diff(l1), q2.diff(m0) + F, T1 = dynamicsymbols("F T1") + massmatrix1 = Matrix([[m0 + m1 + m2, -x0*m1*cos(q1) - x0*m2*cos(q1), + -l1*m2*cos(q2)], + [-x0*m1*cos(q1) - x0*m2*cos(q1), x0**2*m1 + x0**2*m2, + x0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2))], + [-l1*m2*cos(q2), + x0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2)), + l1**2*m2]]) + + forcing1 = Matrix([[-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) - + l1*m2*u2**2*sin(q2) + F, + g*x0*m1*sin(q1) + g*x0*m2*sin(q1) - + x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2, + g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) + + sin(q2)*cos(q1))*u1**2], + [-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) - + l1*m2*u2**2*sin(q2) + F, + g*x0*m1*sin(q1) + g*x0*m2*sin(q1) - + x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2, + g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) + + sin(q2)*cos(q1))*u1**2], + [-x0*m1*u1**2*sin(q1) - x0*m2*u1**2*sin(q1) - + l1*m2*u2**2*sin(q2) + F, + g*x0*m1*sin(q1) + g*x0*m2*sin(q1) - + x0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2, + g*l1*m2*sin(q2) - x0*l1*m2*(-sin(q1)*cos(q2) + + sin(q2)*cos(q1))*u1**2]]) + res_expected=Matrix([[ 1., 0, 0],[-1., 0, 0],[-1., 0, 0]]) + res_lamdbify=Matrix((lambdify((x0, m0 ,l1, m1, m2, g, q0, q1, q2, u1, u2, F, T1), massmatrix1 -forcing1, \ + cse=True)( 0, 0 ,0, 1, 1, 1, 0, 1, 1 , 1, 1, 1, 1))) + equal=True + for i in range(res_lamdbify.rows*res_lamdbify.cols): + equal=equal and (res_expected[i]==res_lamdbify[i]) + assert equal + # test in the case chen a list of expressions is given + expected=[[[0, 2, 18], 5], [18, 1], 0] + t1, t2, t3, t4, t5, t6, t7 = symbols("t1 t2 t3 t4 t5 t6 t7") + x1, x2, x3, x4, x5, x6, x7 = Function('x1')(t1), Function('x2')(t2), Function('x3')(t3),\ + Function('x4')(t4), Function('x5')(t5), Function('x6')(t6),\ + Function('x7')(t7) + d1, d2, d3, d4, d5, d6, d7 = x1.diff(t1), x2.diff(t2), x3.diff(t3), x4.diff(t4),\ + x5.diff(t5), x6.diff(t6), x7.diff(t7) + list_of_list = [[[d5*d1, d3, d4*d7], d6], [d4*d7, d2],d5*d1] + res_list_of_list = lambdify((d1, d2, d3, d4, d5, d6, d7),list_of_list,\ + cse=True)(0,1,2,3,4,5,6) + assert (expected[0][0][0] == res_list_of_list[0][0][0] and\ + expected[0][0][1] == res_list_of_list[0][0][1] and\ + expected[0][0][2] == res_list_of_list[0][0][2] and \ + expected[0][1] == res_list_of_list[0][1] and\ + expected[1][0] == res_list_of_list[1][0] and \ + expected[1][1] == res_list_of_list[1][1] and \ + expected[2] == res_list_of_list[2])