Documentation request for scan_kind #694
Description
Hi,
I am solving an ODE where the vector field needs many complex pytree objects as arguments and also creates some of them for internal use. I have noticed an excessive amount of memory usage on CPU for a single step of Tsit5() (similar behaviour is observed for other adaptive time-stepping solvers too), and the compute time is unreasonably long. However, the same integration on the GPU works fine (at least it runs in a normal amount of time).
In this tutorial, I saw the use of scan_kind="bounded" and gave it a try. It made a huge positive difference in compute time and maximum memory usage without sacrificing compile time. Unfortunately, I couldn't find much documentation on the different options and their expected behavior. Even realising that this option existed was a bit of luck. For a single solver step, the overall speedup is ~30x, from 6 minutes (or more) down to ~10 seconds, and peak memory usage drops from ~14 GB to ~100 MB.
It would be perfect if this option were documented.
Note: As I mentioned, the vector field definition is pretty convoluted, so I cannot share a minimal example, but if you are interested, the issue occurs for this vector field. The integration is called in this function, and an example usage can be found in this test.
Package versions
jax v0.6.2
diffrax v0.7.0
equinox v0.12.2