Implement `linalg.BandedDot` #1415

@jessegrabowski

Description


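Benchmarking in #1323 showed that the banded (tridiagonal) case can get large speedups from a specialized banded matrix-vector product. This issue asks for a `BandedDot` Op that uses `xgbmv` (the BLAS general band matrix-vector multiply, where `x` stands for the `s`/`d`/`c`/`z` type prefix) to realize these speedups.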
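
A minimal sketch of the numerical core such an Op could wrap, assuming SciPy's low-level BLAS wrappers (`scipy.linalg.get_blas_funcs`) and LAPACK-style general band storage. It is not a PyTensor Op, and `to_band_storage` and `banded_dot` are hypothetical helper names used only for illustration.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def to_band_storage(A, kl, ku):
    """Pack a dense (m, n) matrix into LAPACK general-band storage.

    A[i, j] with -kl <= j - i <= ku is stored at ab[ku + i - j, j], so ab
    has one row per diagonal (kl + ku + 1 rows) and n columns.
    """
    m, n = A.shape
    # Fortran order matches what BLAS expects and avoids a copy per call.
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="F")
    for j in range(n):
        # Only rows i inside the band contribute to column j.
        for i in range(max(0, j - ku), min(m, j + kl + 1)):
            ab[ku + i - j, j] = A[i, j]
    return ab


def banded_dot(ab, x, m, kl, ku):
    """Compute y = A @ x from band storage via BLAS ?gbmv (alpha=1, beta=0)."""
    # get_blas_funcs resolves the type prefix (the "x" in xgbmv: s, d, c or z)
    # from the dtypes of its arguments.
    gbmv = get_blas_funcs("gbmv", (ab, x))
    n = ab.shape[1]
    return gbmv(m, n, kl, ku, 1.0, ab, x)


# Example: a tridiagonal matrix has kl = ku = 1.
rng = np.random.default_rng(0)
n = 5
A = (
    np.diag(rng.normal(size=n))
    + np.diag(rng.normal(size=n - 1), 1)
    + np.diag(rng.normal(size=n - 1), -1)
)
x = rng.normal(size=n)
ab = to_band_storage(A, kl=1, ku=1)
print(np.allclose(banded_dot(ab, x, m=n, kl=1, ku=1), A @ x))
```

Band storage only holds `kl + ku + 1` diagonals, so the work and memory scale with the bandwidth rather than with `n * n`, which is where the speedup over a dense product comes from.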

In the future it would be nice to be able to rewrite into this Op in cases where we can detect that it applies, but I don't think that's necessary on a first pass. Just having the functionality lying around will be nice.

Note that JAX doesn't use `xgbmv` for this in the tridiagonal case. It has `_tridiagonal_product`, which computes the product directly using JAX primitive Ops. This might be preferable because it would require no extra dispatch work, but it would only cover the tridiagonal case, not the general banded case. Maybe we want both?
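
For comparison, a NumPy sketch of the direct tridiagonal approach: three elementwise products and two shifted additions, with no BLAS call. It is written in the spirit of the direct method described above, not copied from JAX's `_tridiagonal_product`; the function name `tridiagonal_dot` and the `(dl, d, du)` diagonal convention are assumptions.

```python
import numpy as np


def tridiagonal_dot(dl, d, du, x):
    """Compute A @ x for tridiagonal A given as its three diagonals.

    dl: sub-diagonal,   dl[i] = A[i + 1, i], length n - 1
    d:  main diagonal,  d[i]  = A[i, i],     length n
    du: super-diagonal, du[i] = A[i, i + 1], length n - 1
    """
    y = d * x             # main-diagonal contribution
    y[:-1] += du * x[1:]  # adds A[i, i + 1] * x[i + 1]
    y[1:] += dl * x[:-1]  # adds A[i, i - 1] * x[i - 1]
    return y
```

This form only ever touches the three diagonals, which is why it cannot be extended to a general bandwidth without writing one shifted term per diagonal.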

At minimum, we should benchmark `xgbmv` against the direct method in the tridiagonal case.
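
A rough harness for that comparison, assuming `scipy.linalg.blas.dgbmv` (float64 only) and NumPy; `make_tridiagonal`, `direct_tridiagonal_dot` and `benchmark` are hypothetical names, and the direct product is redefined so the snippet stands alone. It reports per-call timings and claims no particular result. The dense baseline allocates an `n` by `n` matrix, so the default sizes are kept small.

```python
import timeit

import numpy as np
from scipy.linalg.blas import dgbmv


def make_tridiagonal(n, rng):
    """Random tridiagonal system as diagonals, band storage, and dense matrix."""
    dl = rng.normal(size=n - 1)  # sub-diagonal
    d = rng.normal(size=n)       # main diagonal
    du = rng.normal(size=n - 1)  # super-diagonal
    # LAPACK band storage with kl = ku = 1: ab[ku + i - j, j] = A[i, j].
    # Fortran order so dgbmv does not copy the array on every call.
    ab = np.zeros((3, n), order="F")
    ab[0, 1:] = du   # super-diagonal, shifted right by one column
    ab[1, :] = d     # main diagonal
    ab[2, :-1] = dl  # sub-diagonal, shifted left by one column
    A = np.diag(d) + np.diag(du, 1) + np.diag(dl, -1)
    return dl, d, du, ab, A


def direct_tridiagonal_dot(dl, d, du, x):
    """Direct tridiagonal product using only elementwise NumPy ops."""
    y = d * x
    y[:-1] += du * x[1:]
    y[1:] += dl * x[:-1]
    return y


def benchmark(sizes=(10, 100, 1_000), number=200, seed=0):
    """Return {(n, method): seconds per call} for three ways of computing A @ x."""
    rng = np.random.default_rng(seed)
    results = {}
    for n in sizes:
        dl, d, du, ab, A = make_tridiagonal(n, rng)
        x = rng.normal(size=n)
        candidates = {
            "dense": lambda: A @ x,
            "dgbmv": lambda: dgbmv(n, n, 1, 1, 1.0, ab, x),
            "direct": lambda: direct_tridiagonal_dot(dl, d, du, x),
        }
        reference = A @ x
        for name, fn in candidates.items():
            # Check correctness before timing, so a wrong method cannot "win".
            np.testing.assert_allclose(fn(), reference, atol=1e-12)
            results[(n, name)] = timeit.timeit(fn, number=number) / number
    return results


if __name__ == "__main__":
    for (n, name), seconds in benchmark().items():
        print(f"n={n:>5}  {name:<7} {seconds * 1e6:9.2f} us/call")
```

Timing all three on the same random inputs, after checking they agree with the dense result, keeps the comparison honest; absolute numbers will depend on the BLAS build and machine.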

Activity

jessegrabowski commented on May 23, 2025

@benmaier, tagging you because this is related to the PDE work.

jessegrabowski commented on May 23, 2025

I updated my benchmark gist to include the "direct tridiagonal" method that JAX uses. It is much slower than `xgbmv`, but maybe that's not surprising? It only beats the dense `A @ b` on the very largest cases.

linked a pull request that will close this issue on May 23, 2025
Implement `linalg.BandedDot` · Issue #1415 · pymc-devs/pytensor