Optimize solve of block_diag #2109
| - ``b`` is also ``block_diag(B_1, ..., B_n)`` with matching block sizes: | ||
| decompose into ``block_diag(solve(A_i, B_i))`` (each per-block solve is | ||
| ``(m_i, m_i)`` instead of ``(m_i, m_total)``). | ||
| - Otherwise: split ``b`` into per-block row chunks and solve each block |
There was a problem hiding this comment.
Can't we still slice away the zero columns, and then call block_diag again to reconstruct?
Also wondering, if you do this, whether your specialization case is like block_diag(*mats)[a:b, c:d] -> mats[i] when the bounds align perfectly with one matrix. Then it could be subsumed by the composition of two rewrites. Not that we need to go there now.
It would still strip away useless zero columns, though.
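A minimal NumPy/SciPy sketch of the aligned-slice idea, assuming concrete arrays rather than symbolic graph variables (`mats`, `row_starts` and `col_starts` are illustrative names, not part of the PR). When the slice bounds coincide with one block's cumulative offsets, indexing the block-diagonal matrix gives back that block, and an aligned off-diagonal slice is all zeros:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
# Rectangular blocks on purpose: row and column offsets differ.
mats = [rng.normal(size=(2, 3)), rng.normal(size=(4, 1)), rng.normal(size=(3, 3))]
full = linalg.block_diag(*mats)

# Cumulative offsets of each block along rows and columns.
row_starts = np.cumsum([0] + [m.shape[0] for m in mats])
col_starts = np.cumsum([0] + [m.shape[1] for m in mats])

# A slice whose bounds match block i's offsets is exactly block i.
for i, m in enumerate(mats):
    sliced = full[row_starts[i]:row_starts[i + 1], col_starts[i]:col_starts[i + 1]]
    assert np.array_equal(sliced, m)

# An aligned slice that pairs block 0's rows with block 1's columns is zero.
off_diag = full[row_starts[0]:row_starts[1], col_starts[1]:col_starts[2]]
assert not off_diag.any()
```

A symbolic rewrite would need the same offsets to be known statically, which is the part this sketch does not address.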
There was a problem hiding this comment.
Ah sorry, this is meant for the non-aligned block_diag case. The point being you can do b = block_diag(...) -> b[a:b, a:b] with the cumulative A diag sizes (reducing the column dimension, not just the row dimension).
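To see why the square slice b[a:b, a:b] is only safe when the blocks line up, here is a small sketch (the helper `dropped_nonzeros` is hypothetical, written just for this illustration). It counts the nonzero entries of b that fall outside the square slices cut at A's cumulative offsets:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)

def dropped_nonzeros(A_sizes, B_blocks):
    """Count nonzeros of b lying outside the slices b[lo:hi, lo:hi] cut at A's offsets."""
    b = linalg.block_diag(*B_blocks)
    starts = np.cumsum([0, *A_sizes])
    kept = np.zeros(b.shape, dtype=bool)
    for lo, hi in zip(starts[:-1], starts[1:]):
        kept[lo:hi, lo:hi] = True
    return int(np.count_nonzero(b[~kept]))

# B blocks match A's sizes: nothing is lost by slicing rows and columns alike.
aligned = dropped_nonzeros([4, 2], [rng.normal(size=(4, 4)), rng.normal(size=(2, 2))])

# B blocks of size 3 and 3 against A blocks of size 4 and 2: entries are lost.
misaligned = dropped_nonzeros([4, 2], [rng.normal(size=(3, 3)), rng.normal(size=(3, 3))])

print(aligned, misaligned)
```

In the misaligned case the column band has to be widened to cover every B block that touches the row chunk, which is what makes the general structure less trivial.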
There was a problem hiding this comment.
There is structure, but it's non-trivial. Here's a worked example I came up with together with the bot:
import numpy as np
from scipy import linalg
rng = np.random.default_rng()
A1 = rng.normal(size=(4, 4))
A2 = rng.normal(size=(2, 2))
B1 = rng.normal(size=(3, 3))
B2 = rng.normal(size=(3, 3))
A = linalg.block_diag(A1, A2)
B = linalg.block_diag(B1, B2)
# Reference (the slow / dense way): solve A X = B as a 6x6 system.
X_ref = linalg.solve(A, B)
# Partition geometry.
A_sizes = [A1.shape[0], A2.shape[0]] # [4, 2] row-partition by A
B_sizes = [B1.shape[1], B2.shape[1]] # [3, 3] col-partition by B
N = sum(A_sizes) # 6
A_blocks = [A1, A2]
B_blocks = [B1, B2]
A_row_starts = np.cumsum([0, *A_sizes]) # [0, 4, 6]
B_col_starts = np.cumsum([0, *B_sizes]) # [0, 3, 6]
# Per A-block: figure out which B-blocks touch its rows, slice b to that
# contiguous column band, solve the small system, then zero-pad back to N cols.
row_pieces = []
for i, A_i in enumerate(A_blocks):
    r_lo, r_hi = A_row_starts[i], A_row_starts[i + 1]
    # A B-block j (cols c_j..c_{j+1}) is nonzero in rows r_lo..r_hi iff
    # its row range [c_j, c_{j+1}) overlaps [r_lo, r_hi).
    # (The B blocks are square here, so row and column offsets coincide.)
    overlapping = [
        j for j in range(len(B_blocks))
        if B_col_starts[j] < r_hi and B_col_starts[j + 1] > r_lo
    ]
    j_lo, j_hi = overlapping[0], overlapping[-1]
    band_lo = B_col_starts[j_lo]
    band_hi = B_col_starts[j_hi + 1]
    # Slice the row-chunk of B to its nonzero column band.
    chunk = B[r_lo:r_hi, band_lo:band_hi]  # (m_i, w_i)
    # Small solve: (m_i, m_i) against (m_i, w_i) instead of (m_i, N).
    sol_red = linalg.solve(A_i, chunk)  # (m_i, w_i)
    # Zero-pad left/right to (m_i, N).
    m_i = r_hi - r_lo
    sol_full = np.zeros((m_i, N))
    sol_full[:, band_lo:band_hi] = sol_red
    row_pieces.append(sol_full)
X = np.concatenate(row_pieces, axis=0)
assert np.allclose(X, X_ref)
# Concretely for this example:
# A_1 (rows 0..3): overlaps both B_1 (cols 0..2) and B_2 (cols 3..5)
# → band = cols 0..5, w_1 = 6 (no savings, full width)
# A_2 (rows 4..5): overlaps only B_2 (cols 3..5)
# → band = cols 3..5, w_2 = 3 (drops cols 0..2 — half the work)
#
# Total RHS-column work: 4*6 + 2*3 = 30, vs the naive 4*6 + 2*6 = 36.
# Bigger savings appear when more A-blocks see only one (or a few) B-blocks.
I think it's left to future work; the cases I care about don't have this structure, and I can't think of a motivating example off the top of my head.
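The column-work count from the example can be reproduced for other partitions with a short helper. This is only a sketch: `band_widths` and `rhs_column_work` are hypothetical names, and "work" here just means rows times RHS columns per block, as in the comment above, not a real flop count.

```python
import numpy as np

def band_widths(A_sizes, B_row_sizes, B_col_sizes):
    """Width of the nonzero column band of b seen by each A block's row chunk."""
    A_starts = np.cumsum([0, *A_sizes])
    B_row_starts = np.cumsum([0, *B_row_sizes])
    B_col_starts = np.cumsum([0, *B_col_sizes])
    widths = []
    for r_lo, r_hi in zip(A_starts[:-1], A_starts[1:]):
        # B blocks whose row range overlaps [r_lo, r_hi).
        overlapping = [
            j for j in range(len(B_row_sizes))
            if B_row_starts[j] < r_hi and B_row_starts[j + 1] > r_lo
        ]
        lo = B_col_starts[overlapping[0]]
        hi = B_col_starts[overlapping[-1] + 1]
        widths.append(int(hi - lo))
    return widths

def rhs_column_work(A_sizes, widths):
    """Sum over blocks of (rows in block) * (RHS columns solved against)."""
    return sum(m * w for m, w in zip(A_sizes, widths))

# The example from this thread: A sizes [4, 2], square B blocks [3, 3].
widths = band_widths([4, 2], [3, 3], [3, 3])
banded = rhs_column_work([4, 2], widths)
naive = rhs_column_work([4, 2], [6, 6])

# Fully aligned partitions, where every A block sees exactly one B block.
aligned_widths = band_widths([2, 2, 2], [2, 2, 2], [2, 2, 2])
aligned_work = rhs_column_work([2, 2, 2], aligned_widths)
```

For the fully aligned partition the banded count drops to a third of the naive 6 * 6 = 36, which is the regime where the extra bookkeeping would pay off.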
There was a problem hiding this comment.
Here is one where you don't need to know anything about B's shape statically (untested):
import numpy as np
from scipy import linalg
def solve_blockdiag_blockdiag(A_blocks, B_blocks):
    """
    Solve block_diag(*A_blocks) @ X = block_diag(*B_blocks) for X.
    A_i is square (m_i, m_i). B_j is (p_j, q_j). Requires sum(m_i) == sum(p_j).
    Single output buffer: zero-fill once, write each per-block solve directly
    into its (row range, column band) slot. No per-block pad, no concat.
    """
    A_sizes = np.array([A.shape[0] for A in A_blocks])
    B_row_sizes = np.array([B.shape[0] for B in B_blocks])
    B_col_sizes = np.array([B.shape[1] for B in B_blocks])
    N_rows = A_sizes.sum()
    N_cols = B_col_sizes.sum()
    A_row_starts = np.concatenate([[0], np.cumsum(A_sizes)])
    B_row_starts = np.concatenate([[0], np.cumsum(B_row_sizes)])
    B_col_starts = np.concatenate([[0], np.cumsum(B_col_sizes)])
    r_lo = A_row_starts[:-1]
    r_hi = A_row_starts[1:]
    # Vectorized band lookup.
    j_lo = np.searchsorted(B_row_starts[1:], r_lo, side='right')
    j_hi_plus_1 = np.searchsorted(B_row_starts, r_hi, side='left')
    band_lo = B_col_starts[j_lo]
    band_hi = B_col_starts[j_hi_plus_1]
    b = linalg.block_diag(*B_blocks)
    # Single output buffer. Zero-fill once; write each per-block solve in place.
    X = np.zeros((N_rows, N_cols))
    for i, A_i in enumerate(A_blocks):
        chunk = b[r_lo[i]:r_hi[i], band_lo[i]:band_hi[i]]
        X[r_lo[i]:r_hi[i], band_lo[i]:band_hi[i]] = linalg.solve(A_i, chunk)
    return X
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    cases = [
        ("original example", [(4, 4), (2, 2)], [(3, 3), (3, 3)]),
        ("misaligned rect", [(3, 3), (5, 5), (4, 4)], [(2, 3), (4, 2), (3, 5), (3, 4)]),
        ("aligned square", [(3, 3), (4, 4), (2, 2)], [(3, 5), (4, 2), (2, 3)]),
        ("1 A, 4 B", [(10, 10)], [(2, 1), (3, 7), (1, 3), (4, 2)]),
        ("4 A, 1 B", [(2, 2), (3, 3), (1, 1), (4, 4)], [(10, 6)]),
    ]
    for label, A_shapes, B_shapes in cases:
        A_blocks = [rng.normal(size=s) for s in A_shapes]
        B_blocks = [rng.normal(size=s) for s in B_shapes]
        X = solve_blockdiag_blockdiag(A_blocks, B_blocks)
        X_ref = linalg.solve(linalg.block_diag(*A_blocks),
                             linalg.block_diag(*B_blocks))
        print(f"{label:20s} matches: {np.allclose(X, X_ref)}")
There was a problem hiding this comment.
The reason I might go for it is that if you lose static info, maybe you are in the nice aligned case; you just don't statically know about it for whatever reason.
This doesn't need to assert that the A blocks are square. I don't know why you asserted it though; isn't the full solve invalid anyway in that case?
| block_sizes = [block.type.shape[-1] for block in blocks] | ||
| # Rewrite is conservative: we require all component matrices to be provably square. | ||
| # It is possible to have a square matrix block_diagonal matrix comprised of non-square |
There was a problem hiding this comment.
Sure, but it would still not be invertible?
There was a problem hiding this comment.
I feel like we talked about this already, I'm getting deja vu. I don't remember what we decided then.
There was a problem hiding this comment.
Now that we get NaNs instead of a hard failure, I'm okay with doing the rewrite eagerly, unless you prove me wrong and there is a case where the whole matrix is invertible but has non-square subcomponents. I don't think there is; you will have rank deficiency one way or another.
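The rank-deficiency argument can be checked numerically. The rank of a block-diagonal matrix is the sum of the block ranks, and a (p, q) block has rank at most min(p, q), so the total can only reach N when every block is square. A small sketch with one wide and one tall block (the shapes are chosen for illustration only):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
wide = rng.normal(size=(2, 3))  # rank at most 2
tall = rng.normal(size=(3, 2))  # rank at most 2

# The assembled matrix is square (5, 5) even though neither block is.
M = linalg.block_diag(wide, tall)

# rank(M) = rank(wide) + rank(tall) <= 2 + 2 = 4 < 5, so M is singular.
rank = int(np.linalg.matrix_rank(M))
print(M.shape, rank)
```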
There was a problem hiding this comment.
Do we get NaN for trying to solve a (2, 3) (3, x) system?
There was a problem hiding this comment.
We get LinAlgError: Dimensions of A and B do not conform. Not great... can we make it return NaN instead?
There was a problem hiding this comment.
I hate throwing away perf because of static shape. This basically becomes a specify_shape that each entry is square. But okay, I don't love it either.
There was a problem hiding this comment.
I agree, I just don't have a good alternative. I could add an Assert Op?
There was a problem hiding this comment.
Or, as you say, we could just tag it shape_unsafe and force users to write good graphs.
There was a problem hiding this comment.
Yes, but let's leave it for now; it's certainly not a no-brainer.
There was a problem hiding this comment.
Given that, is this PR mergeable, or were there other outstanding requests?
Similar rewrite to #1493. Given solve(block_diag(A, B), C), we can split up C and do two smaller solves, returning concat(solve(A, C[:n]), solve(B, C[n:])).

Plays nicely with further rewrites, especially after #2032. If A is dense and B is orthogonal for example, this can end up as concat(solve(A, C[:n]), B.T @ C[n:]), which is much cheaper. Same goes for diagonality, etc.

The properties of the larger solve can be preserved, because if a block diagonal matrix is triangular, diagonal, psd, or symmetric, it must be the case that all component matrices have those properties.
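A NumPy/SciPy sketch of the identity described above, using concrete arrays in place of symbolic graph variables (this is not the rewrite itself, only a numerical check; the orthogonal block `Q` is built from a QR factorization for illustration):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n, m, k = 4, 3, 5

A = rng.normal(size=(n, n))                   # dense block
Q, _ = np.linalg.qr(rng.normal(size=(m, m)))  # orthogonal block: Q.T @ Q = I
C = rng.normal(size=(n + m, k))

# Reference: one solve against the full (n + m, n + m) block-diagonal matrix.
X_ref = linalg.solve(linalg.block_diag(A, Q), C)

# Split form: two smaller solves on the matching row chunks of C.
X_split = np.concatenate([linalg.solve(A, C[:n]), linalg.solve(Q, C[n:])], axis=0)

# With Q orthogonal, the second solve reduces to a transpose and a matmul.
X_orth = np.concatenate([linalg.solve(A, C[:n]), Q.T @ C[n:]], axis=0)

print(np.allclose(X_split, X_ref), np.allclose(X_orth, X_ref))
```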