Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions pytensor/graph/rewriting/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,25 +310,29 @@ class EquilibriumDB(RewriteDatabase):
"""

def __init__(
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
self,
ignore_newtrees: bool = True,
tracks_on_change_inputs: bool = False,
eq_rewriter_class=pytensor_rewriting.EquilibriumGraphRewriter,
):
"""

Parameters
----------
ignore_newtrees
If ``False``, apply rewrites to new nodes introduced during
rewriting.

If ``False``, apply rewrites to new nodes introduced during rewritings.
tracks_on_change_inputs
If ``True``, re-apply rewrites on nodes with changed inputs.
eq_rewriter_class: EquilibriumGraphRewriter class, optional
The class used to create the equilibrium rewriter. Defaults to EquilibriumGraphRewriter.

"""
super().__init__()
self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__: dict[str, bool] = {}
self.__cleanup__: dict[str, bool] = {}
self.eq_rewriter_class = eq_rewriter_class

def register(
self,
Expand Down Expand Up @@ -360,7 +364,7 @@ def query(self, *tags, **kwtags):
final_rewriters = None
if len(cleanup_rewriters) == 0:
cleanup_rewriters = None
return pytensor_rewriting.EquilibriumGraphRewriter(
return self.eq_rewriter_class(
rewriters,
max_use_ratio=config.optdb__max_use_ratio,
ignore_newtrees=self.ignore_newtrees,
Expand Down
14 changes: 12 additions & 2 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
copy_stack_trace,
in2out,
Expand Down Expand Up @@ -2517,12 +2518,21 @@ def scan_push_out_dot1(fgraph, node):
return False


class ScanEquilibriumGraphRewriter(EquilibriumGraphRewriter):
"""Subclass of EquilibriumGraphRewriter that aborts early if there are no Scan Ops in the graph"""

def apply(self, fgraph, start_from=None):
if not any(isinstance(node.op, Scan) for node in fgraph.apply_nodes):
return
super().apply(fgraph=fgraph, start_from=start_from)


# I've added an equilibrium because later scan optimization in the sequence
# can make it such that earlier optimizations should apply. However, in
# general I do not expect the sequence to run more then once
scan_eqopt1 = EquilibriumDB()
scan_eqopt1 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter)
scan_seqopt1 = SequenceDB()
scan_eqopt2 = EquilibriumDB()
scan_eqopt2 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter)

# scan_eqopt1 before ShapeOpt at 0.1
# This is needed to don't have ShapeFeature trac old Scan that we
Expand Down