cse does unnecessary replacements of Matrix symbols #14559

Open
stewpend0us opened this issue Mar 26, 2018 · 1 comment

@stewpend0us
Contributor

>>> from sympy import *
>>> from sympy.simplify import cse
>>> A = MatrixSymbol('A',2,2)
>>> B = MatrixSymbol('B',2,1)
>>> replacements, new = cse(Matrix(A*B))
>>> replacements
[(x0, A), (x1, B), (x2, x1[0, 0]), (x3, x1[1, 0])]
>>> new
[Matrix([
[x2*x0[0, 0] + x3*x0[0, 1]],
[x2*x0[1, 0] + x3*x0[1, 1]]])]

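The replacements (x0, A) and (x1, B) aren't wrong, I guess, but they don't seem necessary: they only rename whole matrices rather than factoring out a repeated subexpression.
I would expect: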

>>> replacements
[(x0, B[0, 0]), (x1, B[1, 0])]
>>> new
[Matrix([
[x0*A[0, 0] + x1*A[0, 1]],
[x0*A[1, 0] + x1*A[1, 1]]])]

Maybe related to #11991
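
Until cse handles this itself, one possible workaround is to post-process its output: drop every replacement whose right-hand side is a bare MatrixSymbol and substitute the original matrix back into the remaining expressions. This is only a sketch; the helper name `cse_without_matrix_aliases` is made up for illustration and is not part of SymPy, and it assumes the aliases appear exactly as in the output above.

```python
from sympy import MatrixSymbol, cse


def cse_without_matrix_aliases(exprs):
    """Run cse, then undo replacements that merely rename a MatrixSymbol.

    Hypothetical helper, not part of SymPy.
    """
    replacements, reduced = cse(exprs)
    aliases = {}  # maps an alias such as x0 back to the matrix it stands for
    kept = []
    for sym, sub in replacements:
        # Resolve aliases found so far inside this replacement.
        sub = sub.xreplace(aliases)
        if isinstance(sub, MatrixSymbol):
            aliases[sym] = sub  # whole-matrix rename: drop it
        else:
            kept.append((sym, sub))
    reduced = [e.xreplace(aliases) for e in reduced]
    return kept, reduced


A = MatrixSymbol('A', 2, 2)
B = MatrixSymbol('B', 2, 1)
product = (A*B).as_explicit()

replacements, reduced = cse_without_matrix_aliases(product)
print(replacements)
print(reduced)
```

The symbols that remain keep the numbering cse gave them, so they may not start at x0. The exact output depends on the SymPy version, which is why none is shown here.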

@rpep

rpep commented Feb 19, 2019

I also just came across this issue, and it doesn't seem to be resolved. It gets quite tricky when you convert the result to code with a printer, because the matrix alias (x0 standing in for M) is emitted as an element-by-element copy and you have to handle that case in a non-straightforward way:

>>> import sympy as sp
>>> from sympy.printing.ccode import C99CodePrinter

>>> p = C99CodePrinter()

>>> x, y, z = sp.symbols('x y z')

>>> M = sp.MatrixSymbol('M', 4, 1)
>>> A = sp.Matrix([M[0, 0],
...                x*M[0, 0] + M[1, 0],
...                y*M[0, 0] + M[2, 0],
...                z*M[0, 0] + M[3, 0],
...                x**2*M[0, 0]
...               ])

>>> sub_expressions, Apr = sp.cse(A)

>>> code = ""
>>> for var, sub_expr in sub_expressions:
...     code += p.doprint(sub_expr, assign_to=var) + "\n"
...
>>> print(code)
x0[0] = M[0];
x0[1] = M[1];
x0[2] = M[2];
x0[3] = M[3];
x1 = x0[0];

>>> print(p.doprint(sp.Matrix(Apr), assign_to='B'))
B[0] = x1;
B[1] = x*x1 + x0[1];
B[2] = x1*y + x0[2];
B[3] = x1*z + x0[3];
B[4] = pow(x, 2)*x1;
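
For code generation, the same idea can be applied inside the printing loop: skip any replacement that is just a MatrixSymbol alias and print the rest against the original matrix. The sketch below uses the top-level `sp.ccode` function rather than a printer instance; `emit_c` is a hypothetical helper name, and the generated lines would still need declarations for the temporaries before they compile.

```python
import sympy as sp


def emit_c(exprs, out_name):
    """Print cse output as C assignments, skipping whole-matrix aliases.

    Hypothetical helper, not part of SymPy.
    """
    replacements, reduced = sp.cse(exprs)
    aliases = {}
    lines = []
    for var, sub_expr in replacements:
        sub_expr = sub_expr.xreplace(aliases)
        if isinstance(sub_expr, sp.MatrixSymbol):
            aliases[var] = sub_expr  # emit no code for a pure rename
            continue
        lines.append(sp.ccode(sub_expr, assign_to=var))
    # Point the reduced expression back at the original matrices.
    result = reduced[0].xreplace(aliases)
    lines.append(sp.ccode(result, assign_to=out_name))
    return "\n".join(lines)


x, y, z = sp.symbols('x y z')
M = sp.MatrixSymbol('M', 4, 1)
A = sp.Matrix([M[0, 0],
               x*M[0, 0] + M[1, 0],
               y*M[0, 0] + M[2, 0],
               z*M[0, 0] + M[3, 0],
               x**2*M[0, 0]])

code = emit_c(A, 'B')
print(code)
```

With this, the element-by-element copy of M into x0 should no longer be emitted, and the remaining assignments index M directly.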
