25 changes: 25 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -60,11 +60,36 @@
solve_triangular,
)

from pytensor.tensor.slinalg import BlockDiagonal
Member: Make sure you have pre-commit installed and that you've run pre-commit install in your dev environment. You have doubled imports and other issues this tool will help you catch.


logger = logging.getLogger(__name__)
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)


from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.graph import Apply

def fuse_blockdiagonal(node):
Member: You need to register the rewrite using one or more of the rewrite registration decorators. I suggest @register_canonicalize to start. You also need to pass in which Op you are registering. Check the other rewrites to see how it works.
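The registration pattern the reviewer describes can be sketched with a toy registry. This is plain Python, not pytensor's real API: `REWRITE_DB`, `register_canonicalize`, `node_rewriter`, and `BlockDiag` are all illustrative stand-ins.

```python
# Toy registry illustrating decorator-based rewrite registration.
# All names here are stand-ins, not pytensor's actual implementation.
REWRITE_DB = []

def register_canonicalize(rewrite):
    # The outer decorator files the rewrite so a "canonicalize" pass can run it.
    REWRITE_DB.append(rewrite)
    return rewrite

def node_rewriter(tracked_ops):
    # The inner decorator records which Op types the rewrite should fire on.
    def decorator(fn):
        fn.tracked_ops = tuple(tracked_ops)
        return fn
    return decorator

class BlockDiag:
    """Stand-in for the real BlockDiagonal Op."""

@register_canonicalize
@node_rewriter([BlockDiag])
def local_fuse_blockdiag(node):
    # A rewrite returns a replacement, or None when it changes nothing.
    return None

print(len(REWRITE_DB))  # 1
```

In pytensor itself the decorated function also receives the fgraph as its first argument; the stub only shows the registration shape, not the real signature.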

# Only process if this node is a BlockDiagonal
if not isinstance(node.owner.op, BlockDiagonal):
return node
Member: Return None from the rewrite if it didn't do anything.


new_inputs = []
changed = False
for inp in node.owner.inputs:
# If input is itself a BlockDiagonal, flatten its inputs
if inp.owner and isinstance(inp.owner.op, BlockDiagonal):
new_inputs.extend(inp.owner.inputs)
changed = True
else:
new_inputs.append(inp)

if changed:
# Return a new fused BlockDiagonal with all inputs
return BlockDiagonal(len(new_inputs))(*new_inputs)
return node
Member: Return None from a rewrite if it didn't do anything.
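The flattening logic under review, together with the return-None convention both comments ask for, can be sketched on a toy nested-tuple model (no pytensor; `("blockdiag", ...)` tuples are a hypothetical stand-in for BlockDiagonal nodes):

```python
# Toy model of the fusion logic: a node is ("blockdiag", child, ...) and
# leaf inputs are plain strings. flatten_once splices in children that are
# themselves blockdiag nodes and returns None when nothing changed.
def flatten_once(node):
    tag, *children = node
    new_children, changed = [], False
    for child in children:
        if isinstance(child, tuple) and child[0] == "blockdiag":
            new_children.extend(child[1:])
            changed = True
        else:
            new_children.append(child)
    return ("blockdiag", *new_children) if changed else None

# Deep nesting is handled by re-applying the one-level rewrite until it
# returns None, which is how a graph rewriter drives rewrites to a fixed point.
node = ("blockdiag", ("blockdiag", ("blockdiag", "x", "y"), "z"), "w")
while (out := flatten_once(node)) is not None:
    node = out
print(node)  # ('blockdiag', 'x', 'y', 'z', 'w')
```

This also shows why returning None matters: the driver uses it as the termination signal for the fixed-point loop.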



def is_matrix_transpose(x: TensorVariable) -> bool:
"""Check if a variable corresponds to a transpose of the last two axes"""
node = x.owner
43 changes: 43 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
@@ -43,7 +43,50 @@
from tests import unittest_tools as utt
from tests.test_rop import break_op

from pytensor.tensor.rewriting.linalg import fuse_blockdiagonal


def test_nested_blockdiag_fusion():
# Create matrix variables
x = pt.matrix("x")
Member: Instead of pt.matrix, use pt.tensor('x', shape=(3, 3)), for example, and give all the variables static shapes. The reason is that I want to test that the fused Blockwise comes out with the correct static shape.
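The static shape the reviewer wants checked is easy to compute by hand: a block-diagonal stacks its blocks along the diagonal, so the output shape is the elementwise sum of the input shapes. A minimal sketch (the shapes are illustrative):

```python
# Static output shape of a fused block-diagonal: rows and columns
# both sum across the blocks.
shapes = [(3, 3), (2, 2), (4, 4)]
out_shape = (sum(r for r, _ in shapes), sum(c for _, c in shapes))
print(out_shape)  # (9, 9)
```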

y = pt.matrix("y")
z = pt.matrix("z")

# Nested BlockDiagonal
inner = BlockDiagonal(2)(x, y)
outer = BlockDiagonal(2)(inner, z)

# Count number of BlockDiagonal ops before fusion
nodes_before = ancestors([outer])
initial_count = sum(
1 for node in nodes_before
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert initial_count > 1, "Setup failed: should have nested BlockDiagonal"
Member: You control the setup, so directly assert initial_count == 2.

But on that note, make sure to test deeper nesting as well.

# Apply the rewrite
fused = fuse_blockdiagonal(outer)
Member (on lines +67 to +68): You don't want to actually call the rewrite. Instead, compile the function using pytensor.function, then check that the rewrite was correctly applied by looking at the compiled graph. Check here for a template to follow.

Member: Calling rewrite_graph followed by assert_equal_computations is also a fine test, unless you are too uncertain and want to evaluate against something provably correct.


# Count number of BlockDiagonal ops after fusion
nodes_after = ancestors([fused])
Member: No need to look at only ancestors. Once you have a compiled function, you can look at all the nodes with fn.maker.fgraph.apply_nodes (see the SVD test I linked above).

fused_count = sum(
1 for node in nodes_after
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
)
assert fused_count == 1, "Nested BlockDiagonal ops were not fused"
Member: Also test that the n_inputs property of the new BlockDiagonal is correctly set.


# Check that all original inputs are preserved
fused_inputs = [
inp
for node in ancestors([fused])
if getattr(node, "owner", None) and isinstance(node.owner.op, BlockDiagonal)
for inp in node.owner.inputs
]
assert set(fused_inputs) == {x, y, z}, "Inputs were not correctly fused"
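The extra checks the reviewers request here (input preservation plus the n_inputs arity) can be sketched on a toy nested-tuple model rather than real pytensor graphs; `leaves` and the `("blockdiag", ...)` encoding are illustrative stand-ins only:

```python
# Toy nested-tuple stand-in for BlockDiagonal graphs: a node is
# ("blockdiag", child, ...) and leaf inputs are strings.
def leaves(node):
    # Recursively collect the non-blockdiag inputs, left to right.
    if isinstance(node, tuple) and node[0] == "blockdiag":
        out = []
        for child in node[1:]:
            out.extend(leaves(child))
        return out
    return [node]

nested = ("blockdiag", ("blockdiag", "x", "y"), "z")
fused = ("blockdiag", *leaves(nested))  # the fully fused form

# All original inputs survive, in order...
assert leaves(nested) == ["x", "y", "z"]
# ...and the fused node's arity (the n_inputs analogue) matches the leaf count.
assert len(fused) - 1 == 3
```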




def test_matrix_inverse_rop_lop():
rtol = 1e-7 if config.floatX == "float64" else 1e-5
mx = matrix("mx")