-
Couldn't load subscription status.
- Fork 146
WIP: Add rewrite to fuse nested BlockDiag Ops #1671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,11 +60,36 @@ | |
| solve_triangular, | ||
| ) | ||
|
|
||
| from pytensor.tensor.slinalg import BlockDiagonal | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) | ||
|
|
||
|
|
||
| from pytensor.tensor.slinalg import BlockDiagonal | ||
| from pytensor.graph import Apply | ||
|
|
||
| def fuse_blockdiagonal(node): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to register the rewrite using one or more of the rewrite registration decorators. I suggest |
||
| # Only process if this node is a BlockDiagonal | ||
| if not isinstance(node.owner.op, BlockDiagonal): | ||
| return node | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Return |
||
|
|
||
| new_inputs = [] | ||
| changed = False | ||
| for inp in node.owner.inputs: | ||
| # If input is itself a BlockDiagonal, flatten its inputs | ||
| if inp.owner and isinstance(inp.owner.op, BlockDiagonal): | ||
| new_inputs.extend(inp.owner.inputs) | ||
| changed = True | ||
| else: | ||
| new_inputs.append(inp) | ||
|
|
||
| if changed: | ||
| # Return a new fused BlockDiagonal with all inputs | ||
| return BlockDiagonal(len(new_inputs))(*new_inputs) | ||
| return node | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Return |
||
|
|
||
|
|
||
| def is_matrix_transpose(x: TensorVariable) -> bool: | ||
| """Check if a variable corresponds to a transpose of the last two axes""" | ||
| node = x.owner | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,7 +43,50 @@ | |
| from tests import unittest_tools as utt | ||
| from tests.test_rop import break_op | ||
|
|
||
| from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal | ||
|
|
||
|
|
||
| def test_nested_blockdiag_fusion(): | ||
| # Create matrix variables | ||
| x = pt.matrix("x") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of |
||
| y = pt.matrix("y") | ||
| z = pt.matrix("z") | ||
|
|
||
| # Nested BlockDiagonal | ||
| inner = BlockDiagonal(2)(x, y) | ||
| outer = BlockDiagonal(2)(inner, z) | ||
|
|
||
| # Count number of BlockDiagonal ops before fusion | ||
| nodes_before = ancestors([outer]) | ||
| initial_count = sum( | ||
| 1 for node in nodes_before | ||
| if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) | ||
| ) | ||
| assert initial_count > 1, "Setup failed: should have nested BlockDiagonal" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You control the setup, so directly assert But on that note, make sure to test a deeper nesting as well. |
||
|
|
||
| # Apply the rewrite | ||
| fused = fuse_blockdiagonal(outer) | ||
|
Comment on lines
+67
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't want to actually call the rewrite. Instead, compile the function using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. calling |
||
|
|
||
| # Count number of BlockDiagonal ops after fusion | ||
| nodes_after = ancestors([fused]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to look at only ancestors. Once you have a compiled function, you can look at all the nodes with |
||
| fused_count = sum( | ||
| 1 for node in nodes_after | ||
| if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) | ||
| ) | ||
| assert fused_count == 1, "Nested BlockDiagonal ops were not fused" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also test that the |
||
|
|
||
| # Check that all original inputs are preserved | ||
| fused_inputs = [ | ||
| inp | ||
| for node in ancestors([fused]) | ||
| if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal) | ||
| for inp in node.owner.inputs | ||
| ] | ||
| assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused" | ||
|
|
||
|
|
||
|
|
||
|
|
||
| def test_matrix_inverse_rop_lop(): | ||
| rtol = 1e-7 if config.floatX == "float64" else 1e-5 | ||
| mx = matrix("mx") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure you have
pre-commitand you've donepre-commit installin your dev environment. You have doubled imports and other issues this tool with help you check.