solve() returns overly complex results for simple linear equations #23140

Closed
moorepants opened this issue Feb 21, 2022 · 11 comments

Comments

@moorepants
Member

moorepants commented Feb 21, 2022

I can develop a workaround for this, but I think the behavior should match what we saw prior to SymPy 1.7: solve() should return the simpler, clearer result for a linear system.

master:

In [1]: import sympy as sm

In [2]: import sympy.physics.mechanics as me

In [3]: q1, q2, q3, u1, u2, u3 = me.dynamicsymbols('q1, q2, q3, u1, u2, u3')

In [4]: eqs = me.kinematic_equations([u1, u2, u3], [q1, q2, q3], 'body', 'ZXZ')

In [5]: eqs
Out[5]: 
[-(u1(t)*sin(q3(t)) + u2(t)*cos(q3(t)))/sin(q2(t)) + Derivative(q1(t), t),
 -u1(t)*cos(q3(t)) + u2(t)*sin(q3(t)) + Derivative(q2(t), t),
 (u1(t)*sin(q3(t)) + u2(t)*cos(q3(t)))*cos(q2(t))/sin(q2(t)) - u3(t) + Derivative(q3(t), t)]

In [6]: sm.solve(eqs, [u1, u2, u3])
Out[6]: 
{u1(t): sin(q2(t))*sin(q3(t))*Derivative(q1(t), t)/(sin(q3(t))**2 + cos(q3(t))**2) + cos(q3(t))*Derivative(q2(t), t)/(sin(q3(t))**2 + cos(q3(t))**2),
 u2(t): sin(q2(t))*cos(q3(t))*Derivative(q1(t), t)/(sin(q3(t))**2 + cos(q3(t))**2) - sin(q3(t))*Derivative(q2(t), t)/(sin(q3(t))**2 + cos(q3(t))**2),
 u3(t): cos(q2(t))*Derivative(q1(t), t) + Derivative(q3(t), t)}

sympy 1.6:

In [1]: import sympy as sm

In [2]: import sympy.physics.mechanics as me

In [3]: q1, q2, q3, u1, u2, u3 = me.dynamicsymbols('q1, q2, q3, u1, u2, u3')

In [4]: eqs = me.kinematic_equations([u1, u2, u3], [q1, q2, q3], 'body', 'ZXZ')

In [5]: sm.solve(eqs, [u1, u2, u3])
Out[5]: 
{u1(t): sin(q2(t))*sin(q3(t))*Derivative(q1(t), t) + cos(q3(t))*Derivative(q2(t), t),
 u2(t): sin(q2(t))*cos(q3(t))*Derivative(q1(t), t) - sin(q3(t))*Derivative(q2(t), t),
 u3(t): cos(q2(t))*Derivative(q1(t), t) + Derivative(q3(t), t)}

I noticed this in #23130 and bisected it to f0b7996 from PR #18844.

@oscarbenjamin
Contributor

It's not obvious how that commit would lead to the different output shown.

@moorepants
Member Author

I can try bisecting again.

@moorepants
Member Author

The example shown here: #23075 (comment) gives 863824 operations on SymPy 1.7.1 and 621264 operations on SymPy 1.6.2, which is about a 39% increase in operations (roughly 140% of the 1.6.2 count).
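As a rough illustration of how such operation counts can be compared, the sketch below uses sympy.count_ops on a small expression with and without an unsimplified sin**2 + cos**2 denominator. It is an assumption that the figures above were obtained with count_ops; the symbols x and y are placeholders and not taken from the linked example.

```python
import sympy as sm

x, y = sm.symbols('x y')  # placeholder symbols for illustration only

# The same quantity, with and without a redundant trig denominator.
unsimplified = sm.cos(y) * x / (sm.sin(y)**2 + sm.cos(y)**2)
simplified = sm.cos(y) * x

# count_ops returns the number of operations in the expression tree.
ops_unsimplified = sm.count_ops(unsimplified)
ops_simplified = sm.count_ops(simplified)

# Relative growth in operation count caused by the extra denominator.
relative_increase = (ops_unsimplified - ops_simplified) / ops_simplified
print(ops_unsimplified, ops_simplified, relative_increase)
```

On large expressions the same kind of redundancy is repeated many times, which is one way the total count can grow noticeably.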

@moorepants
Member Author

I got the same commit on a second bisect:

moorepants@agni:sympy((28abe430fe...)|BISECTING)$ git bisect good
Bisecting: 0 revisions left to test after this (roughly 0 steps)
[f0b7996ca517e3b78749ec5c40e86d9d6efc0f23] feat(polys): Add DomainMatrix based on poly elements


@oscarbenjamin
Contributor

In SymPy 1.6 the solve_linear_system function was used for this case. Basically after solving the system it would pass the solutions through simplify before returning the output at this line:

sympy/sympy/solvers/solvers.py

Lines 2239 to 2241 in c3087c8

if do_simplify:
    for k, v in rv.items():
        rv[k] = simplify(v)

Without that the output in 1.6 is more complicated:

In [3]: sm.solve(eqs, [u1, u2, u3], simplify=False)
Out[3]: 
{u1(t): -sin(q2(t))**2*sin(q3(t))*Derivative(q1(t), t)/(-sin(q2(t))*sin(q3(t))**2 - sin(q2(t))*cos(q3(t))**2) - sin(q2(t))*cos(q3(t))*Derivative(q2(t), t)/(-sin(q2(t))*sin(q3(t))**2 - sin(q2(t))*cos(q3(t))**2),
 u2(t): -sin(q2(t))**2*cos(q3(t))*Derivative(q1(t), t)/(-sin(q2(t))*sin(q3(t))**2 - sin(q2(t))*cos(q3(t))**2) + sin(q2(t))*sin(q3(t))*Derivative(q2(t), t)/(-sin(q2(t))*sin(q3(t))**2 - sin(q2(t))*cos(q3(t))**2),
 u3(t): (-sin(q3(t))**2*cos(q2(t)) - cos(q2(t))*cos(q3(t))**2)*sin(q2(t))*Derivative(q1(t), t)/(-sin(q2(t))*sin(q3(t))**2 - sin(q2(t))*cos(q3(t))**2) - (sin(q3(t))**2 + cos(q3(t))**2)*sin(q2(t))*Derivative(q3(t), t)/(-sin(q2(t))*sin(q3(t))**2 - sin(q2(t))*cos(q3(t))**2)}

Personally I don't much see the point in solve calling simplify on its outputs when the caller can just do that themselves if they want. The idea for DomainMatrix is that it applies well-defined simplification rather than arbitrary heuristic simplification like simplify. It doesn't yet have a domain that can automatically handle sin and cos but when it does something like sin(t)**2 + cos(t)**2 will disappear automatically (but in a way that will be handled more efficiently than trigsimp or simplify).
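For callers who want the trig identity applied without the full heuristic simplify, a narrower option is trigsimp. The sketch below applies it to one term shaped like the unsimplified u1(t) result; plain symbols q2, q3 and dq1 stand in for q2(t), q3(t) and Derivative(q1(t), t), which is a simplifying assumption for the example.

```python
import sympy as sm

# Plain symbols standing in for q2(t), q3(t) and Derivative(q1(t), t).
q2, q3, dq1 = sm.symbols('q2 q3 dq1')

# One term of the unsimplified solution, with the sin**2 + cos**2 denominator.
raw = sm.sin(q2) * sm.sin(q3) * dq1 / (sm.sin(q3)**2 + sm.cos(q3)**2)

# trigsimp applies trigonometric identities only, unlike the broader simplify.
cleaned = sm.trigsimp(raw)

# The form that SymPy 1.6 returned for this term.
expected = sm.sin(q2) * sm.sin(q3) * dq1
print(cleaned)
```

Whether trigsimp is enough for a given downstream use depends on the expressions involved; simplify remains available when broader rewriting is wanted.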

@moorepants
Member Author

I agree that simplify() should not be called in our library code, but for a 3x3 linear solve it seems trivial to return an answer that doesn't have these sin**2 + cos**2 terms in the denominator. The expected output has changed, with rather large consequences.

@oscarbenjamin
Contributor

This is the matrix for this system:

In [2]: M, b = linear_eq_to_matrix(
   ...:     [eq.subs(zip([u1, u2, u3], [x, y, z])) for eq in eqs], [x, y, z]
   ...: )

In [3]: M
Out[3]: 
⎡    -sin(q₃(t))            -cos(q₃(t))          ⎤
⎢    ────────────           ────────────       0 ⎥
⎢     sin(q₂(t))             sin(q₂(t))          ⎥
⎢                                                ⎥
⎢     -cos(q₃(t))            sin(q₃(t))        0 ⎥
⎢                                                ⎥
⎢sin(q₃(t))⋅cos(q₂(t))  cos(q₂(t))⋅cos(q₃(t))    ⎥
⎢─────────────────────  ─────────────────────  -1⎥
⎣      sin(q₂(t))             sin(q₂(t))         ⎦

Working on pen and paper how would you solve that without at some point applying trig identities?

The determinant is given in this form and so it also comes out naturally in the inverse:

In [8]: M.det().together()
Out[8]: 
   2             2       
sin (q₃(t)) + cos (q₃(t))
─────────────────────────
        sin(q₂(t)) 

In [6]: M.inv()
Out[6]: 
⎡ -sin(q₂(t))⋅sin(q₃(t))           -cos(q₃(t))           ⎤
⎢─────────────────────────  ─────────────────────────  0 ⎥
⎢   2             2            2             2           ⎥
⎢sin (q₃(t)) + cos (q₃(t))  sin (q₃(t)) + cos (q₃(t))    ⎥
⎢                                                        ⎥
⎢ -sin(q₂(t))⋅cos(q₃(t))            sin(q₃(t))           ⎥
⎢─────────────────────────  ─────────────────────────  0 ⎥
⎢   2             2            2             2           ⎥
⎢sin (q₃(t)) + cos (q₃(t))  sin (q₃(t)) + cos (q₃(t))    ⎥
⎢                                                        ⎥
⎣       -cos(q₂(t))                     0              -1⎦

You can also see it implicitly in the LU decomposition (see the element with cos squared):

In [9]: M.LUdecomposition()
Out[9]: 
⎛                               ⎡-sin(q₃(t))         -cos(q₃(t))           ⎤    ⎞
⎜                               ⎢────────────        ────────────        0 ⎥    ⎟
⎜⎡          1            0  0⎤  ⎢ sin(q₂(t))          sin(q₂(t))           ⎥    ⎟
⎜⎢                           ⎥  ⎢                                          ⎥    ⎟
⎜⎢sin(q₂(t))⋅cos(q₃(t))      ⎥  ⎢                              2           ⎥    ⎟
⎜⎢─────────────────────  1  0⎥, ⎢                           cos (q₃(t))    ⎥, []⎟
⎜⎢      sin(q₃(t))           ⎥  ⎢     0        sin(q₃(t)) + ───────────  0 ⎥    ⎟
⎜⎢                           ⎥  ⎢                            sin(q₃(t))    ⎥    ⎟
⎜⎣     -cos(q₂(t))       0  1⎦  ⎢                                          ⎥    ⎟
⎝                               ⎣     0                   0              -1⎦    ⎠

If you apply Gaussian elimination or anything else then at some point you will get trig combinations that can be simplified.
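A smaller instance of the same effect can be seen with a 2x2 rotation matrix, used here only as an illustrative stand-in for the system above. Its determinant is sin**2 + cos**2 before any trig identity is applied, and trigsimp reduces it to 1.

```python
import sympy as sm

t = sm.symbols('t')

# A plain 2x2 rotation matrix (illustrative, not the matrix from this issue).
R = sm.Matrix([[sm.cos(t), -sm.sin(t)],
               [sm.sin(t),  sm.cos(t)]])

# The determinant naturally contains the sin/cos combination.
det_raw = R.det()

# Applying the trig identity collapses it.
det_clean = sm.trigsimp(det_raw)
print(det_raw, det_clean)
```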

Given that the previous output was obtained simply by calling simplify at the end, it's clear how to recover the old behaviour if you want that:

In [10]: sol = solve(eqs, [u1, u2, u3])

In [11]: {s: simplify(v) for s, v in sol.items()}
Out[11]: 
{u1(t): sin(q2(t))*sin(q3(t))*Derivative(q1(t), t) + cos(q3(t))*Derivative(q2(t), t),
 u2(t): sin(q2(t))*cos(q3(t))*Derivative(q1(t), t) - sin(q3(t))*Derivative(q2(t), t),
 u3(t): cos(q2(t))*Derivative(q1(t), t) + Derivative(q3(t), t)}

Longer term I have a better idea which is to make a polys domain that can represent sin and cos. The problem with sin and cos is that they are algebraically dependent which means a domain containing them is a bit more complicated to implement (#20608).
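To illustrate what algebraic dependence means here: if sin(q3) and cos(q3) are replaced by two polynomial generators s and c, the identity becomes the relation s**2 + c**2 - 1 = 0, and reducing modulo that relation removes the combinations seen in the denominators. The sketch below uses sympy.reduced for this. The symbols s and c are illustrative stand-ins, and this is not the domain design discussed in #20608, only a picture of the underlying idea.

```python
import sympy as sm

# Polynomial generators standing in for sin(q3) and cos(q3).
s, c = sm.symbols('s c')

# The Pythagorean identity written as a polynomial relation equal to zero.
relation = s**2 + c**2 - 1

# Polynomial division by the relation: the remainder is the reduced form.
quotients, remainder = sm.reduced(s**2 + c**2, [relation], s, c)

# The same works for higher powers of the combination.
quotients_sq, remainder_sq = sm.reduced(
    sm.expand((s**2 + c**2)**2), [relation], s, c
)
print(remainder, remainder_sq)  # both remainders are 1
```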

@moorepants
Member Author

I understand all of that, but downstream code has relied on the prior simplified output. Removing the simplification that shouldn't have been there in the first place does have consequences, even though the outputs are mathematically equivalent.

I'll build a workaround for the change in the physics.vector module.

@oscarbenjamin
Contributor

Removing the simplification that shouldn't have been there in the first place does have consequences

I'm not sure I realised at the time that I was effectively removing the call to simplify(). The solve() code is such a mess that it's hard to trace anything. You would expect that post-simplification would just happen at the end of the main solve() function just before returning but instead everything is peppered everywhere.

At this point the call to simplify is gone, which is good, and it's been like this for a few releases, so I think it's best to keep it that way now. We do clearly need more tests for nontrivial cases though (as well as benchmarks). You can see in #18844 that a handful of tests were changed because the form of the output had changed. Clearly there was no test that benefited from the use of simplify though.

@moorepants
Member Author

Thanks. I'll close this and we'll fix physics.vector to provide the desired output. It also exposed a poor design there, which relied on solve when it didn't need to in this case.
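One way to get the speeds without calling solve, sketched here under assumptions, is to treat the kinematic equations as an explicit linear system and solve it with a matrix method, then apply trigsimp. This is not necessarily how physics.vector was changed; the placeholder symbols x, y, z follow the substitution used earlier in the thread.

```python
import sympy as sm
import sympy.physics.mechanics as me

q1, q2, q3, u1, u2, u3 = me.dynamicsymbols('q1, q2, q3, u1, u2, u3')
eqs = me.kinematic_equations([u1, u2, u3], [q1, q2, q3], 'body', 'ZXZ')

# Swap the time-dependent speeds for plain symbols, as done earlier in the thread.
x, y, z = sm.symbols('x y z')
plain_eqs = [eq.subs({u1: x, u2: y, u3: z}) for eq in eqs]

# Write the system as M*[x, y, z] = b and solve it by LU decomposition.
M, b = sm.linear_eq_to_matrix(plain_eqs, [x, y, z])
raw = M.LUsolve(b)

# Apply only trigonometric simplification to each component.
speeds = [sm.trigsimp(expr) for expr in raw]
print(speeds)
```

The exact printed form may differ between SymPy versions, but each entry should be mathematically equal to the corresponding value in the SymPy 1.6 output above.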

@smichr
Member

smichr commented Feb 27, 2022

The solve() code is such a mess that it's hard to trace anything

The simplification happens in the solve-one or solve-many routines. And part of the reason for the peppering is that different solvers that are called do/do not need to have the solution checked. A linear system doesn't need checking, for example.
