Support Lineax operators for the vector field of ControlTerm #370

Open
tttc3 opened this issue Feb 10, 2024 · 4 comments
Labels
feature New feature

Comments

@tttc3
Contributor

tttc3 commented Feb 10, 2024

By providing support for lineax.AbstractLinearOperators in the vector field of a ControlTerm it may be possible to reduce the need for WeaklyDiagonalControlTerm and/or any other such specialised control terms.

The following gives an MWE:

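import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
from diffrax import ControlTerm as _ControlTerm, WeaklyDiagonalControlTerm
from diffrax._term import _prod  # private Diffrax helper; import path is an assumption

class ControlTerm(_ControlTerm):
    def prod(self, vf, control):
        if isinstance(vf, lx.AbstractLinearOperator):
            # The operator is itself a PyTree, so apply it directly rather than mapping over its leaves.
            return vf.mv(control)
        return jtu.tree_map(_prod, vf, control)

# These two are now equivalent.
ControlTerm(lx.DiagonalLinearOperator(jnp.array([1, 2, 3])), ...)
WeaklyDiagonalControlTerm(jnp.array([1, 2, 3]), ...)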

Not sure if this is something you want to support, but it occurred to me that the operator tags might also be useful for some of the diffrax solvers?

@patrick-kidger
Owner

Yup, with the release of Lineax this is something I've been considering! This also dovetails nicely with #364, so that we may wish to also introduce a type parameter for the return type of the vector field.

(One very nitty concern I do have is that mathematically speaking, we tend to interpret f(y) dx as actually being a linear function dx -> f( . ) dx (returning a nonlinear function), rather than a nonlinear function y -> f(y) (returning a linear operator). But that's probably not a super important distinction, to be honest.)
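
The two readings differ only in the order of the arguments. A small sketch, with a made-up matrix-valued f and hypothetical function names, shows that they produce the same value:

```python
import jax.numpy as jnp

def f(y):
    # Hypothetical matrix-valued vector field: nonlinear in the state y.
    return jnp.array([[jnp.sin(y[0]), 0.0], [y[0] * y[1], y[1] ** 2]])

def control_first(dx):
    # Linear in dx; returns a nonlinear function of y.
    return lambda y: f(y) @ dx

def state_first(y):
    # Nonlinear in y; returns a linear map acting on dx.
    return lambda dx: f(y) @ dx

y = jnp.array([0.5, 2.0])
dx = jnp.array([0.1, -0.3])

a = control_first(dx)(y)
b = state_first(y)(dx)
scaled = control_first(2.0 * dx)(y)  # linearity in dx: equals 2 * a
```

Returning a linear operator from the vector field corresponds to the `state_first` reading.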

I'd want to be sure that this works correctly with:

  • solvers that evaluate and store the result of .vf(...) directly.
  • BacksolveAdjoint
  • returning "complicated" linear operators like lx.JacobianLinearOperator.

In principle all of those things should be solvable -- Diffrax allows the result of the vector field to be arbitrary -- I just think we'd want to explicitly test them.

I'd be happy to take a pull request on this!

@lockwo
Contributor

lockwo commented Jun 6, 2024

I have a branch started for this, but I'd like to know the scope of the change you're looking for here. Having spent (a little) time thinking about it, I see a few options, in order of increasing impact on the package:

  1. Add support for linear operators in ControlTerm. This would really reduce the complexity of #402 (Add KL divergence terms for Latent SDEs) without changing much or adding many LoC, and would introduce no breaking changes. It might leave a slightly odd overlap, where you can use either WeaklyDiagonalControlTerm or a Lineax-based ControlTerm (maybe not a big deal).
  2. Require control terms to return linear operators (or maybe allow both operators and arrays). WeaklyDiagonalControlTerm would then be deprecated and could be removed from the package entirely (a breaking change), leaving just control terms with different operator specifications. This might require some work in term checking, but is probably doable.
  3. Make everything a linear operator (every f(y) dx), so that all vector fields work this way. This would include removing WeaklyDiagonalControlTerm and changing existing terms, and would introduce breaking changes across the board. It seems like a very substantial refactor of the core of the Term design.

I'm sure there are more nuanced or totally different options, but I'm just looking to get a feel for the scope you want with this change so I don't do a lot of unnecessary work.

@lockwo
Contributor

lockwo commented Jun 6, 2024

For context, the simplest approach to option 1 is shown in #434 (it needs more tests/docs/whatever, but the core idea of making a minimal change to allow Lineax operators in control terms is there).

lockwo mentioned this issue Jun 6, 2024

@patrick-kidger
Owner

I really like the look of #434! I think this is pretty much exactly what I had in mind.

I think if we were designing from scratch then we'd probably go with option 2*, but since that's not the reality we live in, then I think maintaining backward compatibility is worthwhile.

That said I think it might be worth marking WeaklyDiagonalControlTerm with a PendingDeprecationWarning, just to gently encourage people to use this new thing instead. (And perhaps also remove it from the generated docs?) Subclassing terms to create new matrix-vector interactions was always an advanced thing to do, so I like that this new approach simplifies and standardises that.
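
A minimal sketch of what such a warning could look like, using only the standard library. The class below is a hypothetical stand-in, not the Diffrax class, and the message text is an assumption:

```python
import warnings

class WeaklyDiagonalControlTermSketch:
    """Hypothetical stand-in used only to illustrate the warning."""

    def __init__(self, vector_field, control):
        warnings.warn(
            "WeaklyDiagonalControlTerm may be deprecated in future; consider "
            "ControlTerm with a lineax.DiagonalLinearOperator instead.",
            PendingDeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        self.vector_field = vector_field
        self.control = control

# Python's default filters ignore PendingDeprecationWarning, so enable it to observe it.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    term = WeaklyDiagonalControlTermSketch(vector_field=None, control=None)

categories = [w.category for w in caught]
```

Because this category is hidden by default, existing users would only see it when they opt in to warnings, which fits the "gently encourage" intent.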

* Why not option 3? I think we'd still need the AbstractTerm abstraction as a wrapper around lx.AbstractLinearOperator, because we need a way to wrap up multiple terms into one, and a place to put controls. Not a strong feeling though.

lockwo mentioned this issue Jun 26, 2024