diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index dfcdfdd471..65d9cfafb5 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -652,12 +652,12 @@ def elemwise_scalar_op_has_c_code( # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} # Root variables have `None` as owner, which we can handle with a bitset of 0 - ancestors_bitset = {None: 0} + ancestors_bitsets: dict[Apply | None, int] = {None: 0} for node, node_bitflag in nodes_bitflags.items(): # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag - ancestors_bitset[node] = reduce( + ancestors_bitsets[node] = reduce( or_, - (ancestors_bitset[inp.owner] for inp in node.inputs), + (ancestors_bitsets[inp.owner] for inp in node.inputs), node_bitflag, ) # Handle root and leaf nodes gracefully @@ -666,10 +666,12 @@ def elemwise_scalar_op_has_c_code( nodes_bitflags[None] = 0 # Nothing ever depends on the special Output nodes, so just use a new bit for all of them out_bitflag = 1 << len(nodes_bitflags) - for out in fg.outputs: - for client, _ in fg_clients[out]: - if isinstance(client.op, Output): - nodes_bitflags[client] = out_bitflag + nodes_bitflags |= ( + (client, out_bitflag) + for out in fg.outputs + for client, _ in fg_clients[out] + if isinstance(client.op, Output) + ) # Start main loop to find collection of fuseable subgraphs # We store the collection in `sorted_subgraphs`, in reverse topological order @@ -692,9 +694,13 @@ def elemwise_scalar_op_has_c_code( # For simplicity, we always want to visit ancestors before clients # For ancestors, we want to visit the later nodes first (those that have more dependencies) # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) - # To achieve this we use the bitflag as the sorting key (which encodes the topological order) - # and negate it for ancestors. - fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + # To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order) + # and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we + # update the former when we find a fuseable subgraph, emulating the effect of recomputing the + # topological order on the remaining nodes. + fuseables_nodes_queue = [ + (-ancestors_bitsets[starting_node], starting_bitflag, starting_node) + ] heapify(fuseables_nodes_queue) # We keep 3 bitsets during the exploration of a new subgraph: @@ -713,10 +719,12 @@ def elemwise_scalar_op_has_c_code( unfuseable_clients_bitset = 0 while fuseables_nodes_queue: - node_bitflag, node = heappop(fuseables_nodes_queue) - is_ancestor = node_bitflag < 0 + node_ancestors_bitset, node_bitflag, node = heappop( + fuseables_nodes_queue + ) + is_ancestor = node_ancestors_bitset < 0 if is_ancestor: - node_bitflag = -node_bitflag + node_ancestors_bitset = -node_ancestors_bitset if node_bitflag & subgraph_bitset: # Already part of the subgraph @@ -726,7 +734,7 @@ def elemwise_scalar_op_has_c_code( if node_bitflag & unfuseable_ancestors_bitset: # An unfuseable ancestor of the subgraph depends on this node, can't fuse continue - elif ancestors_bitset[node] & unfuseable_clients_bitset: + elif node_ancestors_bitset & unfuseable_clients_bitset: # This node depends on an unfuseable client of the subgraph, can't fuse continue @@ -742,36 +750,41 @@ def elemwise_scalar_op_has_c_code( for inp in node.inputs: ancestor_node = inp.owner ancestor_bitflag = nodes_bitflags[ancestor_node] - if ancestor_bitflag & subgraph_bitset: + if (not is_ancestor) and (ancestor_bitflag & subgraph_bitset): continue if node in fuseable_clients.get(ancestor_node, ()): heappush( fuseables_nodes_queue, - (-ancestor_bitflag, ancestor_node), + ( + -ancestors_bitsets[ancestor_node], + ancestor_bitflag, + ancestor_node, + ), ) else: # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, # nor with any of the ancestor's ancestors - unfuseable_ancestors_bitset |= ancestors_bitset[ + unfuseable_ancestors_bitset |= ancestors_bitsets[ ancestor_node ] next_fuseable_clients = fuseable_clients.get(node, ()) for client, _ in fg_clients[node.outputs[0]]: client_bitflag = nodes_bitflags[client] - if client_bitflag & subgraph_bitset: + if is_ancestor and (client_bitflag & subgraph_bitset): continue if client in next_fuseable_clients: - heappush(fuseables_nodes_queue, (client_bitflag, client)) + heappush( + fuseables_nodes_queue, + (ancestors_bitsets[client], client_bitflag, client), + ) else: # If a client is not in the node's fuseable clients set, it's nto fuseable with it, # nor any of its clients. But we don't need to keep track of those as any downstream # client we may consider later will also depend on this unfuseable client and be rejected unfuseable_clients_bitset |= client_bitflag - # Finished exploring this subgraph - all_subgraphs_bitset |= subgraph_bitset - + # Finished expansion of subgraph if subgraph_bitset == starting_bitflag: # We ended were we started, no fusion possible continue @@ -814,6 +827,18 @@ def elemwise_scalar_op_has_c_code( for out in subgraph_outputs: fuseable_clients.pop(out.owner, None) + # When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes. + # Nodes that previously depended on a subset of the fused outputs, now depend on all of them. + if len(subgraph_outputs) > 1: + subgraph_and_ancestors = ( + subgraph_bitset | unfuseable_ancestors_bitset + ) + ancestors_bitsets |= ( + (node, node_ancestors_bitset | subgraph_and_ancestors) + for node, node_ancestors_bitset in ancestors_bitsets.items() + if node_ancestors_bitset & subgraph_bitset + ) + # Add new subgraph to sorted_subgraphs # Because we start from sink nodes in reverse topological order, most times new subgraphs # don't depend on previous subgraphs, so we can just append them at the end. @@ -826,8 +851,7 @@ def elemwise_scalar_op_has_c_code( else: # But not here, so we need to find the right position for insertion. # We iterate through the previous subgraphs in topological order (reverse of the stored order). - # We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again. - # The (index + 1) of the firs iteration where the check passes is the correct insertion position. + # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes. remaining_subgraphs_bitset = all_subgraphs_bitset for index, (other_subgraph_bitset, _) in enumerate( reversed(sorted_subgraphs) @@ -838,12 +862,20 @@ def elemwise_scalar_op_has_c_code( unfuseable_ancestors_bitset & remaining_subgraphs_bitset ): break # bingo + else: # no-break + raise RuntimeError( + "Failed to find insertion point for fused subgraph" + ) sorted_subgraphs.insert( -(index + 1), (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), ) - # yield from sorted_subgraphs, discarding the subgraph_bitset + # Add subgraph to all_subgraphs_bitset + all_subgraphs_bitset |= subgraph_bitset + + # Finished exploring the whole graph + # Yield from sorted_subgraphs, discarding the subgraph_bitset yield from (io for _, io in sorted_subgraphs) max_operands = elemwise_max_operands_fct(None) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 523effb1d1..5982ee291d 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1426,6 +1426,43 @@ def test_no_warning_from_old_client(self): np.log(1 - np.exp(-2)), ) + def test_joint_circular_dependency(self): + # Test a case where fused subgraphs could induce a circular dependency + x = matrix("x") + neg = pt.neg(x) + eq = pt.eq(x.sum(axis=0), 0) + sub = pt.sub(eq, neg) + exp = pt.exp(neg.sum(axis=0)) + # We test arbitrary add and output orders, to make sure our algorithm + # is robust to valid toposort variations. + for add_order in [(exp, eq), (eq, exp)]: + add = pt.add(*add_order) + + # The naive fused graphs to consider are {sub, neg} and {add, exp, eq}, + # which is not valid because sub depends on eq, while add/exp depends on neg. + # Instead, we can either fuse both {sub, neg} and {add, exp} or just {add, exp, eq} + + for out_order in [(sub, add), (add, sub)]: + fgraph = FunctionGraph([x], out_order, clone=True) + _, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph) + # (nb_fused, nb_replaced) would be (2, 5) if we did the invalid fusion + assert (nb_fused, nb_replaced) in ((2, 4), (1, 3)) + fused_nodes = { + frozenset( + scalar_n.op for scalar_n in n.op.scalar_op.fgraph.apply_nodes + ) + for n in fgraph.apply_nodes + if isinstance(n.op, Elemwise) + and isinstance(n.op.scalar_op, Composite) + } + if nb_fused == 1: + assert fused_nodes == {frozenset((ps.add, ps.exp, ps.eq))} + else: + assert fused_nodes == { + frozenset((ps.sub, ps.neg)), + frozenset((ps.add, ps.exp)), + } + class TimesN(ps.basic.UnaryScalarOp): """