Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dsolve #21743

Merged
merged 11 commits into from Jul 22, 2021
Merged

Improve dsolve #21743

merged 11 commits into from Jul 22, 2021

Conversation

Mohitbalwani26
Copy link
Member

@Mohitbalwani26 Mohitbalwani26 commented Jul 14, 2021

References to other Issues or PRs

#18348

Brief description of what is fixed or changed

Earlier the flow of code was dsolve -> _desolve -> classify_ode (which checked for every hint whether it matches and return all the matching hints) -> particular solver's general solution was called.

This PR now changes the flow in classify_ode as whenever the user calls with explicit hint, It won't be matching all the possible solvers. It will just create the instance of that particular class and if it matches it will return.

Other comments

This also significantly improves the speed of test suite.

Release Notes

NO ENTRY

@sympy-bot
Copy link

sympy-bot commented Jul 14, 2021

Hi, I am the SymPy bot (v161). I'm here to help you write a release notes entry. Please read the guide on how to write release notes.

  • No release notes entry will be added for this pull request.
Click here to see the pull request description that was parsed.
<!-- Your title above should be a short description of what
was changed. Do not include the issue number in the title. -->

#### References to other Issues or PRs
<!-- If this pull request fixes an issue, write "Fixes #NNNN" in that exact
format, e.g. "Fixes #1234" (see
https://tinyurl.com/auto-closing for more information). Also, please
write a comment on that issue linking back to this pull request once it is
open. -->
#18348 

#### Brief description of what is fixed or changed
Earlier the flow of code was dsolve -> _desolve -> classify_ode (which checked for every hint whether it matches and return all the matching hints) -> particular solver's general solution was called.

This PR now changes the flow in classify_ode as whenever the user calls with explicit hint, It won't be matching all the possible solvers. It will just create the instance of that particular class and if it matches it will return.

#### Other comments
This also significantly improves the speed of test suite.

#### Release Notes

<!-- Write the release notes for this release below between the BEGIN and END
statements. The basic format is a bulleted list with the name of the subpackage
and the release note for this PR. For example:

* solvers
  * Added a new solver for logarithmic equations.

* functions
  * Fixed a bug with log of integers.

or if no release note(s) should be included use:

NO ENTRY

See https://github.com/sympy/sympy/wiki/Writing-Release-Notes for more
information on how to write release notes. The bot will check your release
notes automatically to see if they are formatted correctly. -->

<!-- BEGIN RELEASE NOTES -->
NO ENTRY
<!-- END RELEASE NOTES -->

@@ -1017,42 +1017,54 @@ class in it. Note that a hint may do this anyway if
# Any ODE that can be solved with a combination of algebra and
# integrals e.g.:
# d^3/dx^3(x y) = F(x)
solver_map = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be a module level constant. Perhaps it should be defined at the bottom of single.py.

Comment on lines 1054 to 1066
if user_hint not in ['default', 'all', 'all_Integral', 'best'] and user_hint in solver_map:
solver = solver_map[user_hint](ode)
if solver.matches():
for hints in solvers[solvercls]:
matching_hints[hints] = solver
if solvercls.has_integral:
matching_hints[hints + "_Integral"] = solver
matching_hints[user_hint] = solver
if solver_map[user_hint].has_integral:
matching_hints[user_hint + "_Integral"] = solver
else:
for hint in solver_map:
solver = solver_map[hint](ode)
if solver.matches():
matching_hints[hint] = solver
if solver_map[hint].has_integral:
matching_hints[hint + "_Integral"] = solver
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably restructure this slightly:

# Used when dsolve is called without an explicit hint.
# We exit early to return the first valid match
early_exit = (user_hint == 'default')

# An explicit hint has been given to dsolve
# Skip matching code for other hints
if user_hint not in ('default', 'all', 'all_Integral', 'best'):
    solver_map = ... # only the relevant items

for hint, cls in solver_map.items()
   ...
   if solver.matches():
       ...
       if early_exit:
           # maybe there needs to be a check here because of the unrefactored
           # solvers that are matched below. I guess this should only return if the
           # hint has a higher priority than all of those...
           return matching_hints

I suppose ultimately if all solvers were refactored then we wouldn't need to call classify_ode at all when a hint is given but for now the explicit hint case can be handled here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still trying to figure out what to do with the default hint as Factorable and NthorderReducible solvers call dsolve recursively.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have changed this code now but you seemed to have missed part of the point which is that by changing solver_map we can avoid duplicating the code that calls match.

@oscarbenjamin
Copy link
Contributor

Can you write an explanation of what this does in the OP?

@Mohitbalwani26
Copy link
Member Author

Can you write an explanation of what this does in the OP?

sure, I will update it

@oscarbenjamin
Copy link
Contributor

The OP doesn't explain the early exit

@oscarbenjamin
Copy link
Contributor

This also significantly improves the speed of test suite.

Can you give numbers for this?

@Mohitbalwani26
Copy link
Member Author

Can you give numbers for this?

  1. bin/doctest sympy/solvers/ode took 233.06 seconds on master and 195.56 seconds on this branch.

  2.  In [2]: import sympy
     In [3]: %time sympy.test('sympy/solvers/ode/tests/test_ode.py')   #on master
     CPU times: user 644 µs, sys: 7.56 ms, total: 8.2 ms
     Wall time: 1min 51s
     ***************************************************************************************************
     In [2]: %time sympy.test('sympy/solvers/ode/tests/test_ode.py')   #on this Branch
     CPU times: user 1.43 ms, sys: 8.07 ms, total: 9.5 ms
     Wall time: 1min 25s
  3.  In [2]: import sympy
     In [3]: %time sympy.test('sympy/solvers/ode/tests/test_single.py')   #on master
     CPU times: user 0 ns, sys: 9.96 ms, total: 9.96 ms
     Wall time: 25min 39s
     ***************************************************************************************************
     In [2]: %time sympy.test('sympy/solvers/ode/tests/test_single.py')   #on this Branch
     CPU times: user 4.75 ms, sys: 4.29 ms, total: 9.04 ms
     Wall time: 24min 43s

These doesn't include examples marked as XFAIL or Slow.
Also can you suggest some way to see the difference in speed if this is not the correct way.

@oscarbenjamin
Copy link
Contributor

Timing the test suite is good. I think that the best thing to do is to focus on a particular example and you should start with the simplest examples first. So let's take an ODE that should be trivial and time solving it on master:

In [1]: eq = f(x).diff(x, 4)

In [2]: eq
Out[2]: 
  4      
 d       
───(f(x))
  4      
dx       

In [3]: %time sol = dsolve(eq, f(x))
CPU times: user 1.37 s, sys: 15.2 ms, total: 1.38 s
Wall time: 1.38 s

In [4]: sol
Out[4]: 
                       2       3
f(x) = C+ C₂⋅x + C₃⋅x  + C₄⋅x 

Clearly it should be possible to compute this in much less than 1 second. There are several hints that match this and we can try others:

In [5]: classify_ode(eq)
Out[5]: 
('nth_algebraic',
 'nth_linear_constant_coeff_homogeneous',
 'nth_linear_euler_eq_homogeneous',
 'nth_algebraic_Integral')

In [6]: %time sol = dsolve(eq, f(x), hint='nth_linear_constant_coeff_homogeneous')
CPU times: user 1.79 s, sys: 22.3 ms, total: 1.81 s
Wall time: 1.85 s

With the PR we get better time if we specify the right hint:

In [1]: eq = f(x).diff(x, 4)

In [2]: %time sol = dsolve(eq, f(x))
CPU times: user 1.28 s, sys: 13.4 ms, total: 1.3 s
Wall time: 1.3 s

In [3]: %time sol = dsolve(eq, f(x), hint='nth_linear_constant_coeff_homogeneous')
CPU times: user 96.3 ms, sys: 2.45 ms, total: 98.8 ms
Wall time: 97.7 ms

We should be able to make this a lot faster still but you can see that it can be 10x faster if we specify the right hint. However by default it is still much slower than that which is because the nth_algebraic solver is called first and that is slow.

We can use a profiler to see what is slow:

In [4]: %prun -s cumulative dsolve(eq)

         2324044 function calls (2175047 primitive calls) in 3.251 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    197/1    0.036    0.000    3.252    3.252 {built-in method builtins.exec}
        1    0.000    0.000    3.252    3.252 ode.py:354(dsolve)
      2/1    0.000    0.000    2.945    2.945 deutils.py:134(_desolve)
        1    0.000    0.000    2.938    2.938 ode.py:811(classify_ode)
        3    0.000    0.000    2.924    0.975 single.py:283(matches)
        1    0.000    0.000    2.905    2.905 single.py:380(_matches)
        2    0.000    0.000    2.677    1.339 solvers.py:379(solve)
      3/2    0.000    0.000    2.672    1.336 solvers.py:1281(_solve)
     79/3    0.008    0.000    2.557    0.852 simplify.py:411(simplify)
        2    0.000    0.000    2.503    1.252 solvers.py:1731(<listcomp>)
        2    0.000    0.000    2.503    1.251 solvers.py:187(checksol)
  790/388    0.010    0.000    2.390    0.006 basic.py:1241(replace)
 8284/388    0.017    0.000    2.361    0.006 basic.py:1466(walk)
8284/3376    0.010    0.000    2.341    0.001 basic.py:1488(rec_replace)

Line 380 of single.py is the NthAlgebraic match method and it is slow because it calls solve which then also calls simplify.

I don't know whether the nth_algebraic solver is still needed any more but it is very slow so we should either make it faster or get rid of it. In any case a method like constant coefficients can get the solution without integrals and so should always be preferred over a method that computes the result using integrals.

We can also check what makes it slow when using the hint:

In [4]: %prun -s cumulative sol = dsolve(eq, f(x), hint='nth_linear_constant_coeff_homogeneous')

         141421 function calls (133382 primitive calls) in 0.211 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.211    0.211 {built-in method builtins.exec}
        1    0.000    0.000    0.211    0.211 <string>:1(<module>)
        1    0.000    0.000    0.210    0.210 ode.py:354(dsolve)
        1    0.000    0.000    0.178    0.178 ode.py:640(_helper_simplify)
        1    0.000    0.000    0.103    0.103 single.py:292(get_general_solution)
        1    0.000    0.000    0.103    0.103 single.py:2133(_get_general_solution)
        1    0.000    0.000    0.093    0.093 nonhomogeneous.py:251(_get_simplified_sol)
   162/12    0.007    0.000    0.085    0.007 radsimp.py:21(collect)
        1    0.000    0.000    0.075    0.075 ode.py:671(<listcomp>)
      3/1    0.000    0.000    0.075    0.075 multidimensional.py:105(wrapper)
        1    0.000    0.000    0.075    0.075 ode.py:1561(odesimp)
        2    0.000    0.000    0.069    0.034 ode.py:1916(constantsimp)
2638/2253    0.006    0.000    0.052    0.000 cache.py:69(wrapper)
      582    0.007    0.000    0.039    0.000 basic.py:1560(match)
  250/144    0.000    0.000    0.038    0.000 {built-in method builtins.sum}
       32    0.000    0.000    0.037    0.001 basic.py:1519(count)
      448    0.001    0.000    0.037    0.000 basic.py:1522(<genexpr>)
    91/32    0.005    0.000    0.035    0.001 function.py:1268(__new__)
       30    0.001    0.000    0.035    0.001 basic.py:765(subs)
        2    0.000    0.000    0.033    0.017 ode.py:1863(__remove_linear_redundancies)

Here it is slow because of simplification. If we disable that then it is faster:

In [5]: %time sol = dsolve(eq, f(x), hint='nth_linear_constant_coeff_homogeneous', simplify=False)
CPU times: user 22.9 ms, sys: 1.76 ms, total: 24.6 ms
Wall time: 24.9 ms

Now 25ms is a lot better than 1.3 seconds but I'm sure there are many things that can be optimised here because this really is a simple ODE.

The simplification code here can be improved I think because that it currently quite slow.

@Mohitbalwani26
Copy link
Member Author

@oscarbenjamin I think first we should take a look on the ordering of hints. should I take some example which can be solved by every hint and then according to the time I should place them?

@oscarbenjamin
Copy link
Contributor

I do think that we should look at changing the order of the hints and optimising things but it should be separate from this PR

@Mohitbalwani26
Copy link
Member Author

I do think that we should look at changing the order of the hints and optimising things but it should be separate from this PR

so any changes you want me to make in this PR before getting it merged?

@@ -608,7 +608,7 @@ def recur_len(l):
if all_:
retdict = {}
failed_hints = {}
gethints = classify_ode(eq, dict=True)
gethints = classify_ode(eq, dict=True,hint='all')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space after comma

if user_map[hint].has_integral:
matching_hints[hint + "_Integral"] = solver
if dict and early_exit:
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this break rather than return?

Presumably it's because some of the other hints are not yet refactored and might come earlier than the match that we just found. Which hints are not refactored?

Looking at allhints it seems that everything is refactored apart from the series solvers. If that's the case then I think it might be better to just return here. I don't see why we would want to return a series solution when it is still possible to return a non-series solution. Note that the matching code for the series solvers can be slow.

}
for solvercls in solvers:
solver = solvercls(ode)
user_hint = kwargs.get('hint','default')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a space after a comma.

@oscarbenjamin
Copy link
Contributor

There should be tests for the new behaviour of classify_ode.

@Mohitbalwani26
Copy link
Member Author

There should be tests for the new behaviour of classify_ode.

@oscarbenjamin can you please explain a bit what the test would be as behavior remains same.

@oscarbenjamin
Copy link
Contributor

The behaviour is not the same if there is an early exit

@github-actions
Copy link

github-actions bot commented Jul 20, 2021

Benchmark results from GitHub Actions

Lower numbers are good, higher numbers are bad. A ratio less than 1
means a speed up and greater than 1 means a slowdown. Green lines
beginning with + are slowdowns (the PR is slower then master or
master is slower than the previous release). Red lines beginning
with - are speedups.

Significantly changed benchmark results (PR vs master)

       before           after         ratio
     [854c87d1]       [acdedd20]
-      2.38±0.02s       1.57±0.02s     0.66  dsolve.TimeDsolve01.time_dsolve

Significantly changed benchmark results (master vs previous release)

       before           after         ratio
     [ed9a550f]       [854c87d1]
     <sympy-1.8^0>                 
+      1.23±0.01s       2.38±0.02s     1.93  dsolve.TimeDsolve01.time_dsolve
+      89.2±0.5μs       3.25±0.1ms    36.46  matrices.TimeDiagonalEigenvals.time_eigenvals
-      6.46±0.1ms      3.51±0.02ms     0.54  solve.TimeRationalSystem.time_linsolve(10)
-     1.30±0.03ms         819±10μs     0.63  solve.TimeRationalSystem.time_linsolve(5)
-     1.54±0.04ms          995±9μs     0.65  solve.TimeSparseSystem.time_linsolve_eqs(10)
-     2.88±0.05ms      1.86±0.04ms     0.65  solve.TimeSparseSystem.time_linsolve_eqs(20)
-     4.23±0.04ms      2.68±0.02ms     0.63  solve.TimeSparseSystem.time_linsolve_eqs(30)

Full benchmark results can be found as artifacts in GitHub Actions
(click on checks at the top of the PR).

#This is for new behavior of classify_ode when called internally with default, It should
# return the first hint which matches therefore, 'ordered_hints' key will not be there.
assert classify_ode(Eq(f(x).diff(x), 0), f(x),dict=True).get('ordered_hints') == None
assert classify_ode(Eq(f(x).diff(x), 0), f(x),dict=True).get('default') == 'nth_algebraic'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a space after a comma.

Why doesn't this assert what the actual output of classify_ode is?

Also is the behaviour changed in cases where a hint is specified?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't this assert what the actual output of classify_ode is?

ok, will change it to that.

Also is the behaviour changed in cases where a hint is specified?

Earlier classify_ode didn't take hint as argument but now it has so if we pass specific hint it will return solver for that and also checks for the series solution

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also the output of this test is something like:

{'order': 1,
 'nth_algebraic': <sympy.solvers.ode.single.NthAlgebraic at 0x7f94c63cc160>,
 'nth_algebraic_Integral': <sympy.solvers.ode.single.NthAlgebraic at 0x7f94c63cc160>,
 'default': 'nth_algebraic'}

so how should we check with assert?

@oscarbenjamin
Copy link
Contributor

Looks good. Thanks!

@oscarbenjamin oscarbenjamin merged commit 133beb4 into sympy:master Jul 22, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants