Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,13 +1073,15 @@ def __init__(self):
defaultdict(lambda: defaultdict(list))
)
self.untracked_rewrites: list[NodeRewriter] = []
self.get_trackers = functools.cache(self._get_trackers)
self._cached_composed_mro = None

def add_tracker(self, rw: NodeRewriter):
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
if self._cached_composed_mro is not None:
# We shouldn't actually add_trackers after the first call to get_trackers
# But just to be safe we kill the cache here
self.get_trackers = functools.cache(self._get_trackers)
self._cached_composed_mro = None

tracks = rw.tracks()
Expand Down Expand Up @@ -1107,8 +1109,7 @@ def add_tracker(self, rw: NodeRewriter):
else:
self.tracked_instances[c].append(rw)

@functools.cache
def get_trackers(self, op: Op) -> list[NodeRewriter]:
def _get_trackers(self, op: Op) -> list[NodeRewriter]:
"""Get all the rewrites applicable to an `Op`."""

if self._cached_composed_mro is None:
Expand Down
45 changes: 45 additions & 0 deletions tests/graph/rewriting/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import gc
import operator

import pytest

from pytensor.configdefaults import config
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Constant, equal_computations
from pytensor.graph.features import Feature
from pytensor.graph.fg import FunctionGraph
Expand Down Expand Up @@ -930,3 +934,44 @@ def perform(self, *args):
local_rewriter_2,
local_rewriter_1,
]


def test_rewrite_weakref_leak():
"""Check we don't have weakref leak on our rewrites"""

def _growth(limit=10, peak_stats={}):
"""Vendoring of objgraph.growth

Source: https://github.com/mgedmin/objgraph/blob/94b1ca61a11109547442701800292dcfc7f59fc8/objgraph.py#L253
"""
gc.collect()
objects = gc.get_objects()

stats = {}
for o in objects:
n = type(o).__name__
stats[n] = stats.get(n, 0) + 1

deltas = {}
for name, count in stats.items():
old_count = peak_stats.get(name, 0)
if count > old_count:
deltas[name] = count - old_count
peak_stats[name] = count

deltas = sorted(deltas.items(), key=operator.itemgetter(1), reverse=True)

if limit:
deltas = deltas[:limit]

return [(name, stats[name], delta) for name, delta in deltas]

x = vector("x")
y = exp(x)

for i in range(20):
rewrite_graph(y, clone=False)
res = _growth()
# Only start checking after warmup
if i > 15:
assert not res, "Object counts are still growing"