Enable setting of Jacobian tags in ODETerm#755
Conversation
|
@patrick-kidger just a friendly ping to see if you have had/might have a chance to think about this before the workshop next week. I realise this is an impactful change that might take some digestion, so no worries if not. Note we can offload even more complexity and maintenance to lineax now with the new tags_from_checks helper. I could document that and do another lineax release and refactor this accordingly? |
patrick-kidger
left a comment
There was a problem hiding this comment.
I like the basic idea of deferring the tag-propagation problem to Lineax. That sounds like the right place for this problem to live, I think.
I am concerned that this is started to get a bit of a thorny stack to reason over: whilst many users can indeed just provide some tags, this really starts to raise the bar for what is required to be a power user. More and more subsystems that need addressing. (I see this a lot first-hand: I almost always take a look at the source code of the libraries I use, and many of them start to accrue a lot of special-cases that look a lot like this, that are largely inscrutable to the non-expert or non-author.)
I wonder if we should instead make it possible for a user to just say something like 'just use a Cholesky solver, trust me it'll work'.
(Honestly the whole tags system is never something I've been a fan of anyway...)
| # (C.f. `AbstractRungeKutta.step`.) | ||
| # If we wanted FSAL then really the correct thing to do would just be to | ||
| # write out a `ButcherTableau` and use `AbstractSDIRK`. | ||
| _inner = terms.term if isinstance(terms, WrapTerm) else terms |
There was a problem hiding this comment.
I don't think this is okay. Kind of the point of having an ABC is that we don't generally care exactly what we get, and can abstract over them.
(As a practical matter we do pierce this veil in a couple of places, but almost always just as special-case performance optimizations, not things that affect functional correctness.)
| lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), y0 | ||
| ) | ||
| residual_tags = self._residual_tags( | ||
| getattr(_inner, "tags", frozenset()), y_struct, _implicit_op_relation |
There was a problem hiding this comment.
I consider getattr to be pretty much banned in all code I write! 😄 The better choice is usually to set up an abstractmethod on the appropriate ABC, or to use an isinstance (depending on whether you want type or consumer to own the implementation).
| from equinox.internal import ω | ||
|
|
||
| from ._heuristics import is_sde, is_unsafe_sde | ||
| from ._misc import _frozenset |
There was a problem hiding this comment.
nit, underscore indicates privacy to the file it's created in. (Is the convention we use in the Equinox ecosystem.)
|
I'm tempted to for a both/and option by going with this PLUS a lineax environment LINEAX_CHECK_TAGS that defaults to True (for error checking and backwards compatibility) and allowing users to disable this if they can't be bothered to mess around with tags. The trickiest thing though is Cholesky requires awareness of whether the Jacobian is nsd or psd and this would require the user diving into the internals (it depends whether they're using DIRK or implicit Euler) to check our sign conventions which is also a footgun that tag propagation would solve for them. Can you think of any clever solution to this? Possibly the easiest is to change the sign convention of Euler and then default Cholesky to assuming psd rather than nsd when LINEAX_CHECK_TAGS is false but there might be another non-breaking change. |
In light of the new colouring methods in lineax I really wanted to see if there was an unfootgunnable way for users to provide structural tags for their syste, and I think I might have found it.
The challenge
Providing tags to implicit solvers requires user's knowledge of the exact impicit relation being solved and perhaps careful consideration of whether the solve is fully-implicit (in which case a tridiagonal ODE term may be become pentadiagonal for a FIRK2) or direct implicit.Fortunately though, diffrax currently only has DIRK solvers (Implicit Euler is just a flavour of DIRK).
The observation
Each iteration of the nonlinear solver solves a linearised form of the implicit relation—this can be represented in terms of lineax operators and our existing tag propagation logic.
The solution
Users provide tags for the ODE system Jacobian directly on to ODETerm, this should represent the structure of
jacobian(vf)(y)and is completely agnostic of solver internals. Implicit solvers use lineax tag propagation to determine which tags remain relevant. For example, if a user declares a system to be negative semi-definite (e.g. a diffusion equation) then the current DIRK and implicit Euler solvers will recognised this as positive/negative semi-definite and pick Cholesky, however if the system is positive semi-definite (anti-diffusion) the implicit solvers will realise that residual jacobian is not guaranteed to have either form of semi-definiteness for arbitrary timestep and LU will be used instead.Remaining footguns
The inherent footgun is not necessary completely removed, but the burden of complexity now lays solely at the feet of power users developing their own implicit solvers rather than the typical diffeqsolve users who can just plug in well designed implicit solvers and have it "just work". The main responsibility of implicit solver writers is to provide
implicit_op_relationmirroringimplicit_relation. Furthermore, this handling is opt-in from implicit solver designers, no mandatory (or optional) arguments are introduced into implicit solver design.It would be really nice if we can get this in before the AD workshop and provide an easy heat equation example but completely understand if this is too complicated to find the time for in that window.