Optimize solve of block_diag #2109

Merged
ricardoV94 merged 3 commits into pymc-devs:main from jessegrabowski:block-diag-pushdown
May 4, 2026
Conversation

@jessegrabowski
Member

Similar rewrite to #1493. Given solve(block_diag(A, B), C), where A is (n, n), we can split up C and do two smaller solves, returning concat(solve(A, C[:n]), solve(B, C[n:])).
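As a quick numerical check of the identity, here is a NumPy/SciPy sketch. It is not the PyTensor rewrite itself, and the block sizes and variable names are made up for illustration:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n, m, k = 3, 2, 4                      # illustrative sizes, not from the PR
A = rng.normal(size=(n, n))
B = rng.normal(size=(m, m))
C = rng.normal(size=(n + m, k))

# One dense solve against the full (n + m, n + m) block-diagonal matrix.
X_full = linalg.solve(linalg.block_diag(A, B), C)

# Two smaller solves, each against the matching row chunk of C.
X_split = np.concatenate(
    [linalg.solve(A, C[:n]), linalg.solve(B, C[n:])], axis=0
)

split_matches_full = np.allclose(X_full, X_split)
print(split_matches_full)
```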

Plays nicely with further rewrites, especially after #2032. If A is dense and B is orthogonal, for example, this can end up as concat(solve(A, C[:n]), B.T @ C[n:]), which is much cheaper. The same goes for diagonality, etc.
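The orthogonal case can be illustrated numerically as well. In this sketch the orthogonal block `Q` is an assumption built from a QR factorization, and the sizes are arbitrary; it only shows that the solve against `Q` can be replaced by a product with `Q.T`:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n, m, k = 3, 4, 2                               # illustrative sizes
A = rng.normal(size=(n, n))                     # dense block
Q, _ = np.linalg.qr(rng.normal(size=(m, m)))    # orthogonal block: Q.T @ Q == I
C = rng.normal(size=(n + m, k))

# Reference: one dense solve against the whole block-diagonal matrix.
X_ref = linalg.solve(linalg.block_diag(A, Q), C)

# After splitting, the solve against Q reduces to a matrix product with Q.T.
X_cheap = np.concatenate([linalg.solve(A, C[:n]), Q.T @ C[n:]], axis=0)

cheap_matches_ref = np.allclose(X_ref, X_cheap)
print(cheap_matches_ref)
```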

The properties of the larger solve can be preserved: if a block-diagonal matrix is triangular, diagonal, PSD, or symmetric, then every component matrix must have the same property.
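For the triangular case, a sketch along these lines shows the structured solver being reused per block. The lower-triangular blocks `L1` and `L2` are hypothetical test data, and `scipy.linalg.solve_triangular` stands in for whatever specialized solve the graph would use:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
# Lower-triangular blocks; the added multiple of the identity makes
# near-zero diagonal entries unlikely.
L1 = np.tril(rng.normal(size=(3, 3))) + 3 * np.eye(3)
L2 = np.tril(rng.normal(size=(2, 2))) + 3 * np.eye(2)
C = rng.normal(size=(5, 4))

# The block-diagonal matrix built from lower-triangular blocks is itself
# lower triangular.
L_full = linalg.block_diag(L1, L2)
full_is_lower = np.array_equal(L_full, np.tril(L_full))

# The triangular solver applies to the full matrix and to each block.
X_ref = linalg.solve_triangular(L_full, C, lower=True)
X_blocks = np.concatenate(
    [
        linalg.solve_triangular(L1, C[:3], lower=True),
        linalg.solve_triangular(L2, C[3:], lower=True),
    ],
    axis=0,
)
blocks_match = np.allclose(X_ref, X_blocks)
print(full_is_lower, blocks_match)
```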

@jessegrabowski jessegrabowski requested a review from ricardoV94 May 2, 2026 23:05
@jessegrabowski jessegrabowski added the enhancement (New feature or request), graph rewriting, and linalg (Linear algebra) labels May 2, 2026
Comment thread pytensor/tensor/rewriting/linalg/solvers.py Outdated
- ``b`` is also ``block_diag(B_1, ..., B_n)`` with matching block sizes:
decompose into ``block_diag(solve(A_i, B_i))`` (each per-block solve is
``(m_i, m_i)`` instead of ``(m_i, m_total)``).
- Otherwise: split ``b`` into per-block row chunks and solve each block
Member

Can't we still slice away the zero columns, and then call block_diag again to reconstruct?


Also wondering, if you do this, whether your specialization case is like block_diag(*mats)[a:b, c:d] -> mats[i] when the bounds align perfectly with one matrix. Then it could be subsumed by the composition of two rewrites. Not that we need to go there now.

Would still strip away useless zero columns though
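A small NumPy/SciPy sketch of the slicing identity mentioned above, with arbitrary block shapes chosen for illustration: a slice whose bounds line up with one block returns that block, and a slice over another block's columns is all zeros.

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
# Hypothetical blocks; they do not need to be square for this identity.
mats = [rng.normal(size=s) for s in [(2, 3), (4, 1), (3, 3)]]
full = linalg.block_diag(*mats)

# Cumulative offsets of each block along rows and columns.
row_starts = np.cumsum([0] + [m.shape[0] for m in mats])
col_starts = np.cumsum([0] + [m.shape[1] for m in mats])

# Bounds aligned with block i recover exactly that block.
i = 1
aligned = full[row_starts[i]:row_starts[i + 1], col_starts[i]:col_starts[i + 1]]
aligned_is_block = np.array_equal(aligned, mats[i])

# Rows of block 0 against the columns of block 2: structurally zero.
off_diag = full[row_starts[0]:row_starts[1], col_starts[2]:col_starts[3]]
off_diag_is_zero = not off_diag.any()
print(aligned_is_block, off_diag_is_zero)
```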

Member

@ricardoV94 ricardoV94 May 3, 2026

Ah, sorry, this is meant for the non-aligned block_diag case. The point is that when b = block_diag(...), you can slice b[lo:hi, lo:hi] using the cumulative diagonal sizes of A, which reduces the column dimension and not just the row dimension.

Member Author

There is structure, but it's non-trivial. Here's a worked example I came up with together with the bot:

import numpy as np
from scipy import linalg 
rng = np.random.default_rng()
A1 = rng.normal(size=(4, 4)) 
A2 = rng.normal(size=(2, 2))

B1 = rng.normal(size=(3, 3))
B2 = rng.normal(size=(3, 3))

A = linalg.block_diag(A1, A2)
B = linalg.block_diag(B1, B2)

# Reference (the slow / dense way): solve A X = B as a 6x6 system.            
X_ref = linalg.solve(A, B)

# Partition geometry.           
A_sizes = [A1.shape[0], A2.shape[0]]   # [4, 2]   row-partition by A
B_sizes = [B1.shape[1], B2.shape[1]]   # [3, 3]   col-partition by B            
N = sum(A_sizes)                       # 6            

A_blocks = [A1, A2]
B_blocks = [B1, B2]

A_row_starts = np.cumsum([0, *A_sizes])  # [0, 4, 6]
B_col_starts = np.cumsum([0, *B_sizes])  # [0, 3, 6]

# Per A-block: figure out which B-blocks touch its rows, slice b to that
# contiguous column band, solve the small system, then zero-pad back to N cols.            
row_pieces = []
for i, A_i in enumerate(A_blocks):                     
    r_lo, r_hi = A_row_starts[i], A_row_starts[i + 1]

    # B-block j occupies rows/cols [c_j, c_{j+1}) of B (the B blocks are square
    # here, so the column offsets double as row offsets). It is nonzero in rows
    # r_lo..r_hi iff that range overlaps [r_lo, r_hi).
    overlapping = [
        j for j in range(len(B_blocks))
        if B_col_starts[j] < r_hi and B_col_starts[j + 1] > r_lo
    ]
    j_lo, j_hi = overlapping[0], overlapping[-1]
    band_lo = B_col_starts[j_lo]            
    band_hi = B_col_starts[j_hi + 1]            

    # Slice the row-chunk of B to its nonzero column band.            
    chunk = B[r_lo:r_hi, band_lo:band_hi]   # (m_i, w_i)

    # Small solve: (m_i, m_i) against (m_i, w_i) instead of (m_i, N).            
    sol_red = linalg.solve(A_i, chunk)      # (m_i, w_i)            

    # Zero-pad left/right to (m_i, N).            
    m_i = r_hi - r_lo            
    sol_full = np.zeros((m_i, N))            
    sol_full[:, band_lo:band_hi] = sol_red             
    row_pieces.append(sol_full)            

X = np.concatenate(row_pieces, axis=0)            

assert np.allclose(X, X_ref)            

# Concretely for this example:            
#   A_1 (rows 0..3): overlaps both B_1 (cols 0..2) and B_2 (cols 3..5)
#       → band = cols 0..5, w_1 = 6   (no savings, full width)            
#   A_2 (rows 4..5): overlaps only B_2 (cols 3..5)            
#       → band = cols 3..5, w_2 = 3   (drops cols 0..2 — half the work)            
#            
# Total RHS-column work: 4*6 + 2*3 = 30, vs the naive 4*6 + 2*6 = 36.
# Bigger savings appear when more A-blocks see only one (or a few) B-blocks.      

I think this is best left to future work: the cases I care about don't have this structure, and I can't think of a motivating example off the top of my head.

Member

@ricardoV94 ricardoV94 May 4, 2026

Here is one where you don't need to know anything about the shape of B statically (untested):

import numpy as np
from scipy import linalg


def solve_blockdiag_blockdiag(A_blocks, B_blocks):
    """
    Solve  block_diag(*A_blocks) @ X = block_diag(*B_blocks)  for X.

    A_i is square (m_i, m_i). B_j is (p_j, q_j). Requires sum(m_i) == sum(p_j).

    Single output buffer: zero-fill once, write each per-block solve directly
    into its (row range, column band) slot. No per-block pad, no concat.
    """
    A_sizes     = np.array([A.shape[0] for A in A_blocks])
    B_row_sizes = np.array([B.shape[0] for B in B_blocks])
    B_col_sizes = np.array([B.shape[1] for B in B_blocks])

    N_rows = A_sizes.sum()
    N_cols = B_col_sizes.sum()

    A_row_starts = np.concatenate([[0], np.cumsum(A_sizes)])
    B_row_starts = np.concatenate([[0], np.cumsum(B_row_sizes)])
    B_col_starts = np.concatenate([[0], np.cumsum(B_col_sizes)])

    r_lo = A_row_starts[:-1]
    r_hi = A_row_starts[1:]

    # Vectorized band lookup.
    j_lo        = np.searchsorted(B_row_starts[1:], r_lo, side='right')
    j_hi_plus_1 = np.searchsorted(B_row_starts,     r_hi, side='left')

    band_lo = B_col_starts[j_lo]
    band_hi = B_col_starts[j_hi_plus_1]

    b = linalg.block_diag(*B_blocks)

    # Single output buffer. Zero-fill once; write each per-block solve in place.
    X = np.zeros((N_rows, N_cols))
    for i, A_i in enumerate(A_blocks):
        chunk = b[r_lo[i]:r_hi[i], band_lo[i]:band_hi[i]]
        X[r_lo[i]:r_hi[i], band_lo[i]:band_hi[i]] = linalg.solve(A_i, chunk)

    return X


if __name__ == "__main__":
    rng = np.random.default_rng(0)

    cases = [
        ("original example",   [(4,4), (2,2)],          [(3,3), (3,3)]),
        ("misaligned rect",    [(3,3), (5,5), (4,4)],   [(2,3), (4,2), (3,5), (3,4)]),
        ("aligned rows",       [(3,3), (4,4), (2,2)],   [(3,5), (4,2), (2,3)]),
        ("1 A, 4 B",           [(10,10)],               [(2,1), (3,7), (1,3), (4,2)]),
        ("4 A, 1 B",           [(2,2), (3,3), (1,1), (4,4)], [(10,6)]),
    ]

    for label, A_shapes, B_shapes in cases:
        A_blocks = [rng.normal(size=s) for s in A_shapes]
        B_blocks = [rng.normal(size=s) for s in B_shapes]
        X     = solve_blockdiag_blockdiag(A_blocks, B_blocks)
        X_ref = linalg.solve(linalg.block_diag(*A_blocks),
                             linalg.block_diag(*B_blocks))
        print(f"{label:20s}  matches: {np.allclose(X, X_ref)}")

Member

@ricardoV94 ricardoV94 May 4, 2026

The reason I might go for it: if you lose static info, you may still be in the nice aligned case without statically knowing it, for whatever reason.

This doesn't need to assert that the A blocks are square. I don't know why you asserted it, though; isn't the full solve invalid anyway in that case?

block_sizes = [block.type.shape[-1] for block in blocks]

# Rewrite is conservative: we require all component matrices to be provably square.
# It is possible to have a square matrix block_diagonal matrix comprised of non-square
Member

@ricardoV94 ricardoV94 May 4, 2026

sure, but it would still not be invertible?

Member Author

I feel like we talked about this already; I'm getting déjà vu. I don't remember what we decided then.

Member

Now that we get NaNs instead of a hard failure, I'm okay with doing the rewrite eagerly, unless you can prove me wrong with a case where the whole matrix is invertible even though some component blocks are non-square. I don't think one exists: you will have rank deficiency one way or another.
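The rank argument can be checked numerically. In this sketch the block shapes are arbitrary choices for illustration: a (2, 3) block has rank at most 2 and a (3, 2) block has rank at most 2, so the (5, 5) block-diagonal matrix has rank at most 4 and cannot be invertible.

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
wide = rng.normal(size=(2, 3))   # rank at most 2
tall = rng.normal(size=(3, 2))   # rank at most 2

# The overall matrix is square even though neither block is.
M = linalg.block_diag(wide, tall)

# The rank of a block-diagonal matrix is the sum of the block ranks,
# so here it is bounded by 2 + 2 = 4 < 5.
rank = np.linalg.matrix_rank(M)
is_rank_deficient = rank < M.shape[0]
print(M.shape, is_rank_deficient)
```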

Member

Do we get NaN when trying to solve a (2, 3) (3, x) system?

Member

We get LinAlgError: Dimensions of A and B do not conform. Not great... can we make it return NaN instead?

Member

I hate throwing away perf because of static shapes. This basically becomes a specify_shape that each entry is square. But okay, I don't love it either.

Member Author

I agree, I just don't have a good alternative. I could add an Assert Op?

Member Author

Or, as you say, we could just tag it shape_unsafe and force users to write good graphs.

Member

yes, but let's leave it for now, it's certainly not a no-brainer

Member Author

Given that, is this PR mergeable, or were there other outstanding requests?

@ricardoV94 ricardoV94 merged commit d6c6da1 into pymc-devs:main May 4, 2026
66 checks passed
