From 837b3d78b7cda05dbb9aa727c062805272735589 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 6 Oct 2025 14:50:45 +0200 Subject: [PATCH 1/2] Small tweaks to FusionOptimizer --- pytensor/tensor/rewriting/elemwise.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index dfcdfdd471..c37a5caa47 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -652,12 +652,12 @@ def elemwise_scalar_op_has_c_code( # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} # Root variables have `None` as owner, which we can handle with a bitset of 0 - ancestors_bitset = {None: 0} + ancestors_bitsets = {None: 0} for node, node_bitflag in nodes_bitflags.items(): # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag - ancestors_bitset[node] = reduce( + ancestors_bitsets[node] = reduce( or_, - (ancestors_bitset[inp.owner] for inp in node.inputs), + (ancestors_bitsets[inp.owner] for inp in node.inputs), node_bitflag, ) # Handle root and leaf nodes gracefully @@ -666,10 +666,12 @@ def elemwise_scalar_op_has_c_code( nodes_bitflags[None] = 0 # Nothing ever depends on the special Output nodes, so just use a new bit for all of them out_bitflag = 1 << len(nodes_bitflags) - for out in fg.outputs: - for client, _ in fg_clients[out]: - if isinstance(client.op, Output): - nodes_bitflags[client] = out_bitflag + nodes_bitflags |= ( + (client, out_bitflag) + for out in fg.outputs + for client, _ in fg_clients[out] + if isinstance(client.op, Output) + ) # Start main loop to find collection of fuseable subgraphs # We store the collection in `sorted_subgraphs`, in reverse topological order @@ -726,7 +728,7 @@ def elemwise_scalar_op_has_c_code( if node_bitflag & unfuseable_ancestors_bitset: # An unfuseable ancestor of the subgraph depends on this node, can't fuse continue - elif ancestors_bitset[node] & unfuseable_clients_bitset: + elif ancestors_bitsets[node] & unfuseable_clients_bitset: # This node depends on an unfuseable client of the subgraph, can't fuse continue @@ -742,7 +744,7 @@ def elemwise_scalar_op_has_c_code( for inp in node.inputs: ancestor_node = inp.owner ancestor_bitflag = nodes_bitflags[ancestor_node] - if ancestor_bitflag & subgraph_bitset: + if (not is_ancestor) and (ancestor_bitflag & subgraph_bitset): continue if node in fuseable_clients.get(ancestor_node, ()): heappush( @@ -752,14 +754,14 @@ def elemwise_scalar_op_has_c_code( else: # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, # nor with any of the ancestor's ancestors - unfuseable_ancestors_bitset |= ancestors_bitset[ + unfuseable_ancestors_bitset |= ancestors_bitsets[ ancestor_node ] next_fuseable_clients = fuseable_clients.get(node, ()) for client, _ in fg_clients[node.outputs[0]]: client_bitflag = nodes_bitflags[client] - if client_bitflag & subgraph_bitset: + if is_ancestor and (client_bitflag & subgraph_bitset): continue if client in next_fuseable_clients: heappush(fuseables_nodes_queue, (client_bitflag, client)) From ca797da8f86bc9d43f35d75c25b961ea57622e5d Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 5 Oct 2025 03:15:36 +0200 Subject: [PATCH 2/2] Fix FusionOptimizer bug When a subgraph with multiple outputs is "implicitly" claimed, it can change the dependencies of remaining nodes. A node that depended only on a subset of the subgraph outputs now depends on all of them. Not taking this into account could lead to circular dependent Composites --- pytensor/tensor/rewriting/elemwise.py | 62 ++++++++++++++++++------- tests/tensor/rewriting/test_elemwise.py | 37 +++++++++++++++ 2 files changed, 83 insertions(+), 16 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index c37a5caa47..65d9cfafb5 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -652,7 +652,7 @@ def elemwise_scalar_op_has_c_code( # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} # Root variables have `None` as owner, which we can handle with a bitset of 0 - ancestors_bitsets = {None: 0} + ancestors_bitsets: dict[Apply | None, int] = {None: 0} for node, node_bitflag in nodes_bitflags.items(): # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag ancestors_bitsets[node] = reduce( @@ -694,9 +694,13 @@ def elemwise_scalar_op_has_c_code( # For simplicity, we always want to visit ancestors before clients # For ancestors, we want to visit the later nodes first (those that have more dependencies) # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) - # To achieve this we use the bitflag as the sorting key (which encodes the topological order) - # and negate it for ancestors. - fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + # To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order) + # and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we + # update the former when we find a fuseable subgraph, emulating the effect of recomputing the + # topological order on the remaining nodes. + fuseables_nodes_queue = [ + (-ancestors_bitsets[starting_node], starting_bitflag, starting_node) + ] heapify(fuseables_nodes_queue) # We keep 3 bitsets during the exploration of a new subgraph: @@ -715,10 +719,12 @@ def elemwise_scalar_op_has_c_code( unfuseable_clients_bitset = 0 while fuseables_nodes_queue: - node_bitflag, node = heappop(fuseables_nodes_queue) - is_ancestor = node_bitflag < 0 + node_ancestors_bitset, node_bitflag, node = heappop( + fuseables_nodes_queue + ) + is_ancestor = node_ancestors_bitset < 0 if is_ancestor: - node_bitflag = -node_bitflag + node_ancestors_bitset = -node_ancestors_bitset if node_bitflag & subgraph_bitset: # Already part of the subgraph @@ -728,7 +734,7 @@ def elemwise_scalar_op_has_c_code( if node_bitflag & unfuseable_ancestors_bitset: # An unfuseable ancestor of the subgraph depends on this node, can't fuse continue - elif ancestors_bitsets[node] & unfuseable_clients_bitset: + elif node_ancestors_bitset & unfuseable_clients_bitset: # This node depends on an unfuseable client of the subgraph, can't fuse continue @@ -749,7 +755,11 @@ def elemwise_scalar_op_has_c_code( if node in fuseable_clients.get(ancestor_node, ()): heappush( fuseables_nodes_queue, - (-ancestor_bitflag, ancestor_node), + ( + -ancestors_bitsets[ancestor_node], + ancestor_bitflag, + ancestor_node, + ), ) else: # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, @@ -764,16 +774,17 @@ def elemwise_scalar_op_has_c_code( if is_ancestor and (client_bitflag & subgraph_bitset): continue if client in next_fuseable_clients: - heappush(fuseables_nodes_queue, (client_bitflag, client)) + heappush( + fuseables_nodes_queue, + (ancestors_bitsets[client], client_bitflag, client), + ) else: # If a client is not in the node's fuseable clients set, it's nto fuseable with it, # nor any of its clients. But we don't need to keep track of those as any downstream # client we may consider later will also depend on this unfuseable client and be rejected unfuseable_clients_bitset |= client_bitflag - # Finished exploring this subgraph - all_subgraphs_bitset |= subgraph_bitset - + # Finished expansion of subgraph if subgraph_bitset == starting_bitflag: # We ended were we started, no fusion possible continue @@ -816,6 +827,18 @@ def elemwise_scalar_op_has_c_code( for out in subgraph_outputs: fuseable_clients.pop(out.owner, None) + # When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes. + # Nodes that previously depended on a subset of the fused outputs, now depend on all of them. + if len(subgraph_outputs) > 1: + subgraph_and_ancestors = ( + subgraph_bitset | unfuseable_ancestors_bitset + ) + ancestors_bitsets |= ( + (node, node_ancestors_bitset | subgraph_and_ancestors) + for node, node_ancestors_bitset in ancestors_bitsets.items() + if node_ancestors_bitset & subgraph_bitset + ) + # Add new subgraph to sorted_subgraphs # Because we start from sink nodes in reverse topological order, most times new subgraphs # don't depend on previous subgraphs, so we can just append them at the end. @@ -828,8 +851,7 @@ def elemwise_scalar_op_has_c_code( else: # But not here, so we need to find the right position for insertion. # We iterate through the previous subgraphs in topological order (reverse of the stored order). - # We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again. - # The (index + 1) of the firs iteration where the check passes is the correct insertion position. + # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes. remaining_subgraphs_bitset = all_subgraphs_bitset for index, (other_subgraph_bitset, _) in enumerate( reversed(sorted_subgraphs) @@ -840,12 +862,20 @@ def elemwise_scalar_op_has_c_code( unfuseable_ancestors_bitset & remaining_subgraphs_bitset ): break # bingo + else: # no-break + raise RuntimeError( + "Failed to find insertion point for fused subgraph" + ) sorted_subgraphs.insert( -(index + 1), (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), ) - # yield from sorted_subgraphs, discarding the subgraph_bitset + # Add subgraph to all_subgraphs_bitset + all_subgraphs_bitset |= subgraph_bitset + + # Finished exploring the whole graph + # Yield from sorted_subgraphs, discarding the subgraph_bitset yield from (io for _, io in sorted_subgraphs) max_operands = elemwise_max_operands_fct(None) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 523effb1d1..5982ee291d 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1426,6 +1426,43 @@ def test_no_warning_from_old_client(self): np.log(1 - np.exp(-2)), ) + def test_joint_circular_dependency(self): + # Test a case where fused subgraphs could induce a circular dependency + x = matrix("x") + neg = pt.neg(x) + eq = pt.eq(x.sum(axis=0), 0) + sub = pt.sub(eq, neg) + exp = pt.exp(neg.sum(axis=0)) + # We test arbitrary add and output orders, to make sure our algorithm + # is robust to valid toposort variations. + for add_order in [(exp, eq), (eq, exp)]: + add = pt.add(*add_order) + + # The naive fused graphs to consider are {sub, neg} and {add, exp, eq}, + # which is not valid because sub depends on eq, while add/exp depends on neg. + # Instead, we can either fuse both {sub, neg} and {add, exp} or just {add, exp, eq} + + for out_order in [(sub, add), (add, sub)]: + fgraph = FunctionGraph([x], out_order, clone=True) + _, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph) + # (nb_fused, nb_replaced) would be (2, 5) if we did the invalid fusion + assert (nb_fused, nb_replaced) in ((2, 4), (1, 3)) + fused_nodes = { + frozenset( + scalar_n.op for scalar_n in n.op.scalar_op.fgraph.apply_nodes + ) + for n in fgraph.apply_nodes + if isinstance(n.op, Elemwise) + and isinstance(n.op.scalar_op, Composite) + } + if nb_fused == 1: + assert fused_nodes == {frozenset((ps.add, ps.exp, ps.eq))} + else: + assert fused_nodes == { + frozenset((ps.sub, ps.neg)), + frozenset((ps.add, ps.exp)), + } + class TimesN(ps.basic.UnaryScalarOp): """