diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index fb81622458..fc18636a1a 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -310,18 +310,21 @@ class EquilibriumDB(RewriteDatabase): """ def __init__( - self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False + self, + ignore_newtrees: bool = True, + tracks_on_change_inputs: bool = False, + eq_rewriter_class=pytensor_rewriting.EquilibriumGraphRewriter, ): """ Parameters ---------- ignore_newtrees - If ``False``, apply rewrites to new nodes introduced during - rewriting. - + If ``False``, apply rewrites to new nodes introduced during rewritings. tracks_on_change_inputs If ``True``, re-apply rewrites on nodes with changed inputs. + eq_rewriter_class: EquilibriumGraphRewriter class, optional + The class used to create the equilibrium rewriter. Defaults to EquilibriumGraphRewriter. """ super().__init__() @@ -329,6 +332,7 @@ def __init__( self.tracks_on_change_inputs = tracks_on_change_inputs self.__final__: dict[str, bool] = {} self.__cleanup__: dict[str, bool] = {} + self.eq_rewriter_class = eq_rewriter_class def register( self, @@ -360,7 +364,7 @@ def query(self, *tags, **kwtags): final_rewriters = None if len(cleanup_rewriters) == 0: cleanup_rewriters = None - return pytensor_rewriting.EquilibriumGraphRewriter( + return self.eq_rewriter_class( rewriters, max_use_ratio=config.optdb__max_use_ratio, ignore_newtrees=self.ignore_newtrees, diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index c5ac0a28a3..adab47d37b 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -30,6 +30,7 @@ from pytensor.graph.op import compute_test_value from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import ( + EquilibriumGraphRewriter, GraphRewriter, copy_stack_trace, in2out, @@ -2517,12 +2518,21 @@ def scan_push_out_dot1(fgraph, node): return False +class ScanEquilibriumGraphRewriter(EquilibriumGraphRewriter): + """Subclass of EquilibriumGraphRewriter that aborts early if there are no Scan Ops in the graph""" + + def apply(self, fgraph, start_from=None): + if not any(isinstance(node.op, Scan) for node in fgraph.apply_nodes): + return + super().apply(fgraph=fgraph, start_from=start_from) + + # I've added an equilibrium because later scan optimization in the sequence # can make it such that earlier optimizations should apply. However, in # general I do not expect the sequence to run more then once -scan_eqopt1 = EquilibriumDB() +scan_eqopt1 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter) scan_seqopt1 = SequenceDB() -scan_eqopt2 = EquilibriumDB() +scan_eqopt2 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter) # scan_eqopt1 before ShapeOpt at 0.1 # This is needed to don't have ShapeFeature trac old Scan that we