
apply linearise in init for Normal and iterative solvers #198

Merged

patrick-kidger merged 4 commits into patrick-kidger:main from jpbrodrick89:jpb/linearise-normal-iterative on Mar 1, 2026

Conversation

@jpbrodrick89
Contributor

This is not without precedent, as the same thing is already done in CG. The motivation is that both Normal and the iterative solvers perform multiple matvecs in sequence that cannot be parallelised, in which case calling linearise to cache the primal computation for JacobianLinearOperators should be more efficient. This essentially hides some complexity and decision-making stress from users, at the cost of reduced control if there really is a case where not caching the primal computation saves significant memory. Typically, when memory is the bottleneck, iterative solvers are the go-to, and I can't really envisage a case where you could run e.g. LSMR with a JacobianLinearOperator but not with a FunctionLinearOperator, but I may be way off the mark.
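For concreteness, a minimal sketch of the caching in question (the toy function and shapes are mine, not from the PR):

```python
import jax.numpy as jnp
import lineax as lx

def fn(y, args):
    return jnp.tanh(y)  # stand-in for an expensive primal computation

y = jnp.ones(3)
op = lx.JacobianLinearOperator(fn, y)

# Each op.mv(v) pushes a tangent through `fn` via a jvp, re-doing the
# primal pass every time. Linearising once caches the primal computation,
# so the sequential matvecs inside Normal/iterative solvers only pay for
# the tangent part.
lin_op = lx.linearise(op)
out = lin_op.mv(jnp.arange(3.0))
```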

If you'd rather not do this, should we remove linearise from CG?

@patrick-kidger
Owner

This looks reasonable to me!

...although we appear to have a lot of test failures (not just the knotty ongoing LSMR thing)?

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Feb 5, 2026

I think these are fixed by #200, just not merged here (this PR, I think, was the trigger to go and fix it! 😅). If you're happy with #200 I can rebase the other PRs accordingly. Happy to take a relay approach here where we iterate and merge a PR or two at a time, rather than one big sweep; just trying to break the cognitive load up into manageable chunks for you. 😊

#200 is definitely the candidate to merge first right now. I can let you know which are next most ready as we go. Sorry about the avalanche: when you get deep in the weeds, these things all knot together 😅

@jpbrodrick89
Contributor Author

I tried updating the optimistix Levenberg–Marquardt benchmarks to use the Normal wrapper and noticed some very serious regressions, so I played around with different options to get to the bottom of it. In the table below, the first row uses the benchmarks reported in the file, which were run on a different machine and are arguably not directly comparable; the second row is what I get running lineax 0.0.8 at optimistix commit 03e47dc, with the original custom NormalCholesky and NormalCG solvers; all the rest use Normal(Cholesky()) and Normal(CG()), but with different versions of lineax based on 0.1.0:

| Version                    | Normal Cholesky Runtime | Normal CG Runtime |
|----------------------------|-------------------------|-------------------|
| Reported                   | 0.0016s                 | 0.015s            |
| Custom                     | 0.0025s                 | 0.058s            |
| lineax 0.1.0               | 0.0386s                 | 0.058s            |
| composed-materialise       | 0.0138s                 | 0.058s            |
| linearise-normal-iterative | 0.0358s                 | 0.057s            |
| Both combined              | 0.0144s                 | 0.057s            |
| materialise in Normal.init | 0.0025s                 | 0.0026s           |

In summary, the Normal wrapper leads to a 15x performance regression, and linearise only makes that 10% better. #196 (my composed-materialise PR) makes a factor-of-three improvement, but this still leaves a 5x discrepancy. My interpretation is that for long computational graphs (i.e. the vmapped diffrax solve in the benchmark file), materialisation is a huge win when computing the normal matrices for a tall, thin operator, and when the solver already requires materialisation this is essentially free and transient (if the height is more than twice the width, the normal matrix will have at least as many elements as the materialised operator and its transpose).

How do you suggest we handle this (options in order of decreasing user control/visibility)?

1. Just document the behaviour and leave it up to the user to materialise their operator in advance.
2. Always materialise under the hood for "direct" solvers when the matrix is "tall and thin" (by some definition).
3. Always materialise under the hood for "direct" solvers when the matrix is not "too" wide (e.g. at most 2048 columns, taking arbitrary inspiration from where JAX decides to use MAGMA for pivoted QR).
4. Always materialise under the hood for "direct" solvers for all in/out structures.
5. Always materialise under the hood for both "direct" and "iterative" solvers for tall, thin matrices.

My preference would probably be 2 or 4. Even if it is often more performant, I think materialising for iterative solvers would be a surprise to users looking to minimise their memory requirements. Furthermore, I think the memory cost of the materialised rectangular matrix would be transient, giving more weight to option 4. For these options we would need to come up with a nice way of detecting an iterative solver. This could feasibly be whether the solver has max_steps (or tolerances) as an attribute (sketched below), or we could add a property to the solver / add an AbstractIterativeLinearSolver / AbstractDirectLinearSolver if you want something more transparent/future-proof.
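For illustration, the attribute-based heuristic might look something like this (a hypothetical helper, not lineax API):

```python
def is_iterative_solver(solver) -> bool:
    # Hypothetical heuristic: iterative solvers expose convergence controls
    # such as a step limit or tolerances; direct solvers do not.
    return any(hasattr(solver, attr) for attr in ("max_steps", "rtol", "atol"))
```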

I've intentionally left out the option of giving the Normal wrapper a materialise argument, as that just seems an unnecessary convenience on top of option 1.

@patrick-kidger
Owner

patrick-kidger commented Mar 1, 2026

IIUC, the main concern here is that Normal+Cholesky (on presumably a matrix-free operator representation, though I don't think you specify) is unnecessarily slow, in light of the fact that Cholesky will materialise the operator anyway.

What I don't understand is why Normal(Cholesky()) isn't doing that anyway. Normal.init will call Cholesky.init on op @ op.T (or op.T @ op), which then calls .as_matrix(), and this already forces a materialisation of the composed operator.

I have one guess, which is that (op @ op.T).as_matrix() does not exploit the fact that op appears twice, i.e. it computes its way through the operator twice, rather than materialising once and then doing a matmul.

Do you have a better idea of what's going on, or perhaps an MWE of the above sort if my guess is correct?


Different topic – I think we're happy with this PR / shall I merge?

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Mar 1, 2026

Yes exactly: lineax can't tell that op.T and op are related, so the linearised function is vmapped over twice. The MWE I'm using is just the optimistix Levenberg–Marquardt benchmarks in the repo, but swapping out your custom NormalCholesky for Normal(Cholesky()).
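A standalone toy illustration of the double computation (mine, not from the PR; Python-level call counting is used as a rough proxy for how often the primal graph is built):

```python
import jax.numpy as jnp
import lineax as lx

n_calls = 0

def fn(y, args):
    global n_calls
    n_calls += 1  # rough proxy for how many times the primal pass is staged out
    return jnp.tanh(y)

y = jnp.ones(3)
op = lx.JacobianLinearOperator(fn, y)

_ = (op.T @ op).as_matrix()
print(n_calls)  # expect > 1: lineax works through `fn` separately for op and op.T

n_calls = 0
mat = lx.materialise(op)
_ = (mat.T @ mat).as_matrix()
print(n_calls)  # expect 1: the Jacobian is computed once and then reused
```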

This PR is indeed a net improvement over current lineax main, so feel free to merge if you like, but using materialise instead of linearise on op in some circumstances (for Normal only) is necessary to match the custom NormalCholesky performance. If you let me know which circumstances you'd like (I've given my suggestions above), I'd be happy to code that up, either in this PR or in a follow-up; a rough sketch is below.
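As a sketch of what that special-casing might look like (illustrative names only, not lineax internals, and omitting the positive-semidefinite tagging the real wrapper would need):

```python
import lineax as lx

def make_normal_operator(operator, solver):
    # Illustrative: materialise up front for a direct solver (which will
    # materialise anyway), otherwise just cache the primal pass so that
    # sequential matvecs stay cheap and memory usage stays low.
    if isinstance(solver, lx.Cholesky):  # stand-in test for "direct" solvers
        operator = lx.materialise(operator)
    else:
        operator = lx.linearise(operator)
    return operator.T @ operator  # tall case; a wide operator would use op @ op.T
```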

@patrick-kidger patrick-kidger merged commit 995efb3 into patrick-kidger:main Mar 1, 2026
1 check passed
@patrick-kidger
Owner

Okay! In that case let's merge this one 🎉

For a follow-up, maybe it would suffice for materialise(transpose(...)) to commute into transpose(materialise(...))? Then materialise(op) @ materialise(op.T) would get optimized by XLA's CSE pass. Admittedly still perhaps at the cost of some compile time.
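At the call-site level, the proposed commutation amounts to the following (my sketch, assuming transposing a materialised operator is cheap):

```python
import jax.numpy as jnp
import lineax as lx

def fn(y, args):
    return jnp.tanh(y)

op = lx.JacobianLinearOperator(fn, jnp.ones(3))

# Without the rewrite, each factor works through `op` independently:
# slow = lx.materialise(op) @ lx.materialise(op.T)

# Commuting materialise past the transpose lets both factors share a
# single materialisation, which XLA's CSE (or plain reuse, as here) exploits:
mat = lx.materialise(op)
normal_op = mat @ mat.T
```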

Alternatively, my guess is that this isn't an optimization that belongs in Normal, but perhaps should instead be an optimization that belongs to materialise(::ComposedLinearOperator)?

These are just some guesses though, and ultimately I don't really have strong feelings. I'd even be happy with just special-casing Cholesky inside Normal (rather than introducing an AbstractDirectLinearSolver just for this purpose).

