Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 57 additions & 25 deletions pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,12 +652,12 @@ def elemwise_scalar_op_has_c_code(
# `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())}
# Root variables have `None` as owner, which we can handle with a bitset of 0
ancestors_bitset = {None: 0}
ancestors_bitsets: dict[Apply | None, int] = {None: 0}
for node, node_bitflag in nodes_bitflags.items():
# The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
ancestors_bitset[node] = reduce(
ancestors_bitsets[node] = reduce(
or_,
(ancestors_bitset[inp.owner] for inp in node.inputs),
(ancestors_bitsets[inp.owner] for inp in node.inputs),
node_bitflag,
)
# Handle root and leaf nodes gracefully
Expand All @@ -666,10 +666,12 @@ def elemwise_scalar_op_has_c_code(
nodes_bitflags[None] = 0
# Nothing ever depends on the special Output nodes, so just use a new bit for all of them
out_bitflag = 1 << len(nodes_bitflags)
for out in fg.outputs:
for client, _ in fg_clients[out]:
if isinstance(client.op, Output):
nodes_bitflags[client] = out_bitflag
nodes_bitflags |= (
(client, out_bitflag)
for out in fg.outputs
for client, _ in fg_clients[out]
if isinstance(client.op, Output)
)

# Start main loop to find collection of fuseable subgraphs
# We store the collection in `sorted_subgraphs`, in reverse topological order
Expand All @@ -692,9 +694,13 @@ def elemwise_scalar_op_has_c_code(
# For simplicity, we always want to visit ancestors before clients
# For ancestors, we want to visit the later nodes first (those that have more dependencies)
# whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
# To achieve this we use the bitflag as the sorting key (which encodes the topological order)
# and negate it for ancestors.
fuseables_nodes_queue = [(-starting_bitflag, starting_node)]
# To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order)
# and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we
# update the former when we find a fuseable subgraph, emulating the effect of recomputing the
# topological order on the remaining nodes.
fuseables_nodes_queue = [
(-ancestors_bitsets[starting_node], starting_bitflag, starting_node)
]
heapify(fuseables_nodes_queue)

# We keep 3 bitsets during the exploration of a new subgraph:
Expand All @@ -713,10 +719,12 @@ def elemwise_scalar_op_has_c_code(
unfuseable_clients_bitset = 0

while fuseables_nodes_queue:
node_bitflag, node = heappop(fuseables_nodes_queue)
is_ancestor = node_bitflag < 0
node_ancestors_bitset, node_bitflag, node = heappop(
fuseables_nodes_queue
)
is_ancestor = node_ancestors_bitset < 0
if is_ancestor:
node_bitflag = -node_bitflag
node_ancestors_bitset = -node_ancestors_bitset

if node_bitflag & subgraph_bitset:
# Already part of the subgraph
Expand All @@ -726,7 +734,7 @@ def elemwise_scalar_op_has_c_code(
if node_bitflag & unfuseable_ancestors_bitset:
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
continue
elif ancestors_bitset[node] & unfuseable_clients_bitset:
elif node_ancestors_bitset & unfuseable_clients_bitset:
# This node depends on an unfuseable client of the subgraph, can't fuse
continue

Expand All @@ -742,36 +750,41 @@ def elemwise_scalar_op_has_c_code(
for inp in node.inputs:
ancestor_node = inp.owner
ancestor_bitflag = nodes_bitflags[ancestor_node]
if ancestor_bitflag & subgraph_bitset:
if (not is_ancestor) and (ancestor_bitflag & subgraph_bitset):
continue
if node in fuseable_clients.get(ancestor_node, ()):
heappush(
fuseables_nodes_queue,
(-ancestor_bitflag, ancestor_node),
(
-ancestors_bitsets[ancestor_node],
ancestor_bitflag,
ancestor_node,
),
)
else:
# If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
# nor with any of the ancestor's ancestors
unfuseable_ancestors_bitset |= ancestors_bitset[
unfuseable_ancestors_bitset |= ancestors_bitsets[
ancestor_node
]

next_fuseable_clients = fuseable_clients.get(node, ())
for client, _ in fg_clients[node.outputs[0]]:
client_bitflag = nodes_bitflags[client]
if client_bitflag & subgraph_bitset:
if is_ancestor and (client_bitflag & subgraph_bitset):
continue
if client in next_fuseable_clients:
heappush(fuseables_nodes_queue, (client_bitflag, client))
heappush(
fuseables_nodes_queue,
(ancestors_bitsets[client], client_bitflag, client),
)
else:
# If a client is not in the node's fuseable clients set, it's not fuseable with it,
# nor any of its clients. But we don't need to keep track of those as any downstream
# client we may consider later will also depend on this unfuseable client and be rejected
unfuseable_clients_bitset |= client_bitflag

# Finished exploring this subgraph
all_subgraphs_bitset |= subgraph_bitset

# Finished expansion of subgraph
if subgraph_bitset == starting_bitflag:
# We ended where we started, no fusion possible
continue
Expand Down Expand Up @@ -814,6 +827,18 @@ def elemwise_scalar_op_has_c_code(
for out in subgraph_outputs:
fuseable_clients.pop(out.owner, None)

# When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes.
# Nodes that previously depended on a subset of the fused outputs, now depend on all of them.
if len(subgraph_outputs) > 1:
subgraph_and_ancestors = (
subgraph_bitset | unfuseable_ancestors_bitset
)
ancestors_bitsets |= (
(node, node_ancestors_bitset | subgraph_and_ancestors)
for node, node_ancestors_bitset in ancestors_bitsets.items()
if node_ancestors_bitset & subgraph_bitset
)

# Add new subgraph to sorted_subgraphs
# Because we start from sink nodes in reverse topological order, most times new subgraphs
# don't depend on previous subgraphs, so we can just append them at the end.
Expand All @@ -826,8 +851,7 @@ def elemwise_scalar_op_has_c_code(
else:
# But not here, so we need to find the right position for insertion.
# We iterate through the previous subgraphs in topological order (reverse of the stored order).
# We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again.
# The (index + 1) of the firs iteration where the check passes is the correct insertion position.
# We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes.
remaining_subgraphs_bitset = all_subgraphs_bitset
for index, (other_subgraph_bitset, _) in enumerate(
reversed(sorted_subgraphs)
Expand All @@ -838,12 +862,20 @@ def elemwise_scalar_op_has_c_code(
unfuseable_ancestors_bitset & remaining_subgraphs_bitset
):
break # bingo
else: # no-break
raise RuntimeError(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would have raised in the failing scan test, showing our sorting failed

"Failed to find insertion point for fused subgraph"
)
sorted_subgraphs.insert(
-(index + 1),
(subgraph_bitset, (subgraph_inputs, subgraph_outputs)),
)

# yield from sorted_subgraphs, discarding the subgraph_bitset
# Add subgraph to all_subgraphs_bitset
all_subgraphs_bitset |= subgraph_bitset

# Finished exploring the whole graph
# Yield from sorted_subgraphs, discarding the subgraph_bitset
yield from (io for _, io in sorted_subgraphs)

max_operands = elemwise_max_operands_fct(None)
Expand Down
37 changes: 37 additions & 0 deletions tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,6 +1426,43 @@ def test_no_warning_from_old_client(self):
np.log(1 - np.exp(-2)),
)

def test_joint_circular_dependency(self):
    """Check that fusion never produces two mutually dependent fused subgraphs.

    The graph is built so that the naive candidates ``{sub, neg}`` and
    ``{add, exp, eq}`` cannot both be fused: ``sub`` depends on ``eq`` while
    ``add``/``exp`` depend on ``neg``. The two acceptable outcomes are fusing
    ``{sub, neg}`` together with ``{add, exp}``, or fusing only
    ``{add, exp, eq}``.
    """
    x = matrix("x")
    neg = pt.neg(x)
    eq = pt.eq(x.sum(axis=0), 0)
    sub = pt.sub(eq, neg)
    exp = pt.exp(neg.sum(axis=0))

    # Scalar ops expected inside the Composite(s) for each acceptable outcome
    single_fusion_ops = {frozenset((ps.add, ps.exp, ps.eq))}
    double_fusion_ops = {
        frozenset((ps.sub, ps.neg)),
        frozenset((ps.add, ps.exp)),
    }

    # Both the operand order of `add` and the order of the graph outputs are varied,
    # so that the check does not rely on one particular valid toposort.
    for add_operands in [(exp, eq), (eq, exp)]:
        add = pt.add(*add_operands)

        for outputs in [(sub, add), (add, sub)]:
            fgraph = FunctionGraph([x], outputs, clone=True)
            _, nb_fused, nb_replaced, *_ = FusionOptimizer().apply(fgraph)

            # The invalid joint fusion would report (2, 5) here
            assert (nb_fused, nb_replaced) in ((2, 4), (1, 3))

            # Collect, for every fused Elemwise, the set of scalar ops in its Composite
            fused_nodes = set()
            for apply_node in fgraph.apply_nodes:
                outer_op = apply_node.op
                if not isinstance(outer_op, Elemwise):
                    continue
                if not isinstance(outer_op.scalar_op, Composite):
                    continue
                fused_nodes.add(
                    frozenset(
                        scalar_node.op
                        for scalar_node in outer_op.scalar_op.fgraph.apply_nodes
                    )
                )

            expected = single_fusion_ops if nb_fused == 1 else double_fusion_ops
            assert fused_nodes == expected


class TimesN(ps.basic.UnaryScalarOp):
"""
Expand Down