Skip to content

Commit

Permalink
Merge pull request #22840 from zouhairm/peabody/cse_matrixsymbol
Browse files Browse the repository at this point in the history
Peabody/cse matrixsymbol
  • Loading branch information
smichr committed Jan 12, 2022
2 parents 8c446f0 + 86d13ad commit c907524
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 21 deletions.
6 changes: 5 additions & 1 deletion sympy/simplify/cse_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
Substitutions containing any Symbol from ``ignore`` will be ignored.
"""
from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
from sympy.matrices.expressions.matexpr import MatrixElement
from sympy.polys.rootoftools import RootOf

if opt_subs is None:
Expand All @@ -586,7 +587,10 @@ def _find_repeated(expr):
if isinstance(expr, RootOf):
return

if isinstance(expr, Basic) and (expr.is_Atom or expr.is_Order):
if isinstance(expr, Basic) and (
expr.is_Atom or
expr.is_Order or
isinstance(expr, (MatrixSymbol, MatrixElement))):
if expr.is_Symbol:
excluded_symbols.add(expr)
return
Expand Down
4 changes: 4 additions & 0 deletions sympy/simplify/tests/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ def test_cse_MatrixSymbol():
B = MatrixSymbol("B", n, n)
assert cse(B) == ([], [B])

assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])

assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])

def test_cse_MatrixExpr():
A = MatrixSymbol('A', 3, 3)
y = MatrixSymbol('y', 3, 1)
Expand Down
23 changes: 3 additions & 20 deletions sympy/utilities/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,26 +531,9 @@ def test_multidim_c_argument_cse():
'#include "test.h"\n'
"#include <math.h>\n"
"void c(double *A, double *b, double *out) {\n"
" double x0[9];\n"
" x0[0] = A[0];\n"
" x0[1] = A[1];\n"
" x0[2] = A[2];\n"
" x0[3] = A[3];\n"
" x0[4] = A[4];\n"
" x0[5] = A[5];\n"
" x0[6] = A[6];\n"
" x0[7] = A[7];\n"
" x0[8] = A[8];\n"
" double x1[3];\n"
" x1[0] = b[0];\n"
" x1[1] = b[1];\n"
" x1[2] = b[2];\n"
" const double x2 = x1[0];\n"
" const double x3 = x1[1];\n"
" const double x4 = x1[2];\n"
" out[0] = x2*x0[0] + x3*x0[1] + x4*x0[2];\n"
" out[1] = x2*x0[3] + x3*x0[4] + x4*x0[5];\n"
" out[2] = x2*x0[6] + x3*x0[7] + x4*x0[8];\n"
" out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n"
" out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n"
" out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n"
"}\n"
)
assert code == expected
Expand Down

0 comments on commit c907524

Please sign in to comment.