Skip to content

Commit ae43c14

Browse files
committed
Use non recursive algorithm in rebuild_collect_shared
1 parent 40861a7 commit ae43c14

File tree

1 file changed

+47
-40
lines changed

1 file changed

+47
-40
lines changed

pytensor/compile/function/pfunc.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -179,47 +179,54 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
179179
180180
"""
181181
# this co-recurses with clone_a
182-
assert v is not None
183-
if v in clone_d:
184-
return clone_d[v]
185-
if v.owner:
186-
owner = v.owner
187-
if owner not in clone_d:
188-
for i in owner.inputs:
189-
clone_v_get_shared_updates(i, copy_inputs_over)
190-
clone_node_and_cache(
191-
owner,
192-
clone_d,
193-
strict=rebuild_strict,
194-
clone_inner_graphs=clone_inner_graphs,
195-
)
196-
return clone_d.setdefault(v, v)
197-
elif isinstance(v, SharedVariable):
198-
if v not in shared_inputs:
199-
shared_inputs.append(v)
200-
if v.default_update is not None:
201-
# Check that v should not be excluded from the default
202-
# updates list
203-
if no_default_updates is False or (
204-
isinstance(no_default_updates, list) and v not in no_default_updates
205-
):
206-
# Do not use default_update if a "real" update was
207-
# provided
208-
if v not in update_d:
209-
v_update = v.type.filter_variable(
210-
v.default_update, allow_convert=False
182+
stack = [v]
183+
try:
184+
while True:
185+
v = stack.pop()
186+
if v in clone_d:
187+
continue
188+
if (apply := v.owner) is not None:
189+
if all(i in clone_d for i in apply.inputs):
190+
# all inputs have been cloned, we can clone this node
191+
clone_node_and_cache(
192+
apply,
193+
clone_d,
194+
strict=rebuild_strict,
195+
clone_inner_graphs=clone_inner_graphs,
211196
)
212-
if not v.type.is_super(v_update.type):
213-
raise TypeError(
214-
"An update must have a type compatible with "
215-
"the original shared variable"
216-
)
217-
update_d[v] = v_update
218-
update_expr.append((v, v_update))
219-
if not copy_inputs_over:
220-
return clone_d.setdefault(v, v.clone())
221-
else:
222-
return clone_d.setdefault(v, v)
197+
else:
198+
# expand on the inputs
199+
stack.extend(apply.inputs)
200+
else:
201+
clone_d[v] = v if copy_inputs_over else v.clone()
202+
203+
# Special handling of SharedVariables
204+
if isinstance(v, SharedVariable):
205+
if v not in shared_inputs:
206+
shared_inputs.append(v)
207+
if v.default_update is not None:
208+
# Check that v should not be excluded from the default
209+
# updates list
210+
if no_default_updates is False or (
211+
isinstance(no_default_updates, list)
212+
and v not in no_default_updates
213+
):
214+
# Do not use default_update if a "real" update was
215+
# provided
216+
if v not in update_d:
217+
v_update = v.type.filter_variable(
218+
v.default_update, allow_convert=False
219+
)
220+
if not v.type.is_super(v_update.type):
221+
raise TypeError(
222+
"An update must have a type compatible with "
223+
"the original shared variable"
224+
)
225+
update_d[v] = v_update
226+
update_expr.append((v, v_update))
227+
except IndexError:
228+
pass # stack is empty
229+
return clone_d[v]
223230

224231
# initialize the clone_d mapping with the replace dictionary
225232
if replace is None:

0 commit comments

Comments
 (0)