Skip to content

Add custom jax jvp for solve_sylvester#2116

Merged
ricardoV94 merged 3 commits into
pymc-devs:mainfrom
jessegrabowski:sylvester-jvp
May 6, 2026
Merged

Add custom jax jvp for solve_sylvester#2116
ricardoV94 merged 3 commits into
pymc-devs:mainfrom
jessegrabowski:sylvester-jvp

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented May 6, 2026

Jax added support for solve_sylvester but they didn't add a VJP. There is a bunch of discussion about corner cases in this thread. I don't think we care -- these are about schur/eigenproblems. This is implicated in the autograd graph of solve_sylvester, but we can directly use adjoints. This is already what we do in python/c/numba. Kind of niche, but it comes up in statespace models using nutpie with gradient_backend='jax'.

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth a test?

Comment thread tests/link/jax/linalg/test_solvers.py Outdated
Comment on lines +188 to +190
# We're manually overriding the jax jvp for this Op, so we test the gradients too
A_bar, B_bar, C_bar = pt.grad(out.sum(), [A, B, C])
compare_jax_and_py([A, B, C], [A_bar, B_bar, C_bar], [A_val, B_val, C_val])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't test the jax jvp though?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

details details

@ricardoV94 ricardoV94 merged commit 109d8b6 into pymc-devs:main May 6, 2026
66 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request gradients jax

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants