Skip to content

Commit

Permalink
Fix memoryleak in filters.
Browse files Browse the repository at this point in the history
Breaking change:
- `Filter()` subclasses have to call `super()`.
  • Loading branch information
jonathanslenders committed Feb 1, 2023
1 parent 568de64 commit 7776bf9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 61 deletions.
6 changes: 5 additions & 1 deletion src/prompt_toolkit/filters/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
]


@memoized()
# NOTE: `has_focus` below should *not* be `memoized`. It can reference any user
# control. For instance, if we would contiously create new
# `PromptSession` instances, then previous instances won't be released,
# because this memoize (which caches results in the global scope) will
# still refer to each instance.
def has_focus(value: "FocusableElement") -> Condition:
"""
Enable when this buffer has the focus.
Expand Down
111 changes: 51 additions & 60 deletions src/prompt_toolkit/filters/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import weakref
from abc import ABCMeta, abstractmethod
from typing import Callable, Dict, Iterable, List, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

__all__ = ["Filter", "Never", "Always", "Condition", "FilterOrBool"]

Expand All @@ -12,6 +13,15 @@ class Filter(metaclass=ABCMeta):
The return value of ``__call__`` will tell if the feature should be active.
"""

def __init__(self) -> None:
self._and_cache: "weakref.WeakValueDictionary[Filter, _AndList]" = (
weakref.WeakValueDictionary()
)
self._or_cache: "weakref.WeakValueDictionary[Filter, _OrList]" = (
weakref.WeakValueDictionary()
)
self._invert_result: Optional[Filter] = None

@abstractmethod
def __call__(self) -> bool:
"""
Expand All @@ -23,19 +33,46 @@ def __and__(self, other: "Filter") -> "Filter":
"""
Chaining of filters using the & operator.
"""
return _and_cache[self, other]
assert isinstance(other, Filter), "Expecting filter, got %r" % other

if isinstance(other, Always):
return self
if isinstance(other, Never):
return other

if other in self._and_cache:
return self._and_cache[other]

result = _AndList([self, other])
self._and_cache[other] = result
return result

def __or__(self, other: "Filter") -> "Filter":
"""
Chaining of filters using the | operator.
"""
return _or_cache[self, other]
assert isinstance(other, Filter), "Expecting filter, got %r" % other

if isinstance(other, Always):
return other
if isinstance(other, Never):
return self

if other in self._or_cache:
return self._or_cache[other]

result = _OrList([self, other])
self._or_cache[other] = result
return result

def __invert__(self) -> "Filter":
"""
Inverting of filters using the ~ operator.
"""
return _invert_cache[self]
if self._invert_result is None:
self._invert_result = _Invert(self)

return self._invert_result

def __bool__(self) -> None:
"""
Expand All @@ -52,68 +89,13 @@ def __bool__(self) -> None:
)


class _AndCache(Dict[Tuple[Filter, Filter], "_AndList"]):
"""
Cache for And operation between filters.
(Filter classes are stateless, so we can reuse them.)
Note: This could be a memory leak if we keep creating filters at runtime.
If that is True, the filters should be weakreffed (not the tuple of
filters), and tuples should be removed when one of these filters is
removed. In practise however, there is a finite amount of filters.
"""

def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter:
a, b = filters
assert isinstance(b, Filter), "Expecting filter, got %r" % b

if isinstance(b, Always) or isinstance(a, Never):
return a
elif isinstance(b, Never) or isinstance(a, Always):
return b

result = _AndList(filters)
self[filters] = result
return result


class _OrCache(Dict[Tuple[Filter, Filter], "_OrList"]):
"""Cache for Or operation between filters."""

def __missing__(self, filters: Tuple[Filter, Filter]) -> Filter:
a, b = filters
assert isinstance(b, Filter), "Expecting filter, got %r" % b

if isinstance(b, Always) or isinstance(a, Never):
return b
elif isinstance(b, Never) or isinstance(a, Always):
return a

result = _OrList(filters)
self[filters] = result
return result


class _InvertCache(Dict[Filter, "_Invert"]):
"""Cache for inversion operator."""

def __missing__(self, filter: Filter) -> Filter:
result = _Invert(filter)
self[filter] = result
return result


_and_cache = _AndCache()
_or_cache = _OrCache()
_invert_cache = _InvertCache()


class _AndList(Filter):
"""
Result of &-operation between several filters.
"""

def __init__(self, filters: Iterable[Filter]) -> None:
super().__init__()
self.filters: List[Filter] = []

for f in filters:
Expand All @@ -135,6 +117,7 @@ class _OrList(Filter):
"""

def __init__(self, filters: Iterable[Filter]) -> None:
super().__init__()
self.filters: List[Filter] = []

for f in filters:
Expand All @@ -156,6 +139,7 @@ class _Invert(Filter):
"""

def __init__(self, filter: Filter) -> None:
super().__init__()
self.filter = filter

def __call__(self) -> bool:
Expand All @@ -173,6 +157,9 @@ class Always(Filter):
def __call__(self) -> bool:
return True

def __or__(self, other: "Filter") -> "Filter":
return self

def __invert__(self) -> "Never":
return Never()

Expand All @@ -185,6 +172,9 @@ class Never(Filter):
def __call__(self) -> bool:
return False

def __and__(self, other: "Filter") -> "Filter":
return self

def __invert__(self) -> Always:
return Always()

Expand All @@ -204,6 +194,7 @@ def feature_is_active(): # `feature_is_active` becomes a Filter.
"""

def __init__(self, func: Callable[[], bool]) -> None:
super().__init__()
self.func = func

def __call__(self) -> bool:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_memory_leaks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import gc

import pytest

from prompt_toolkit.shortcuts.prompt import PromptSession


def _count_prompt_session_instances() -> int:
# Run full GC collection first.
gc.collect()

# Count number of remaining referenced `PromptSession` instances.
objects = gc.get_objects()
return len([obj for obj in objects if isinstance(obj, PromptSession)])


# Fails in GitHub CI, probably due to GC differences.
@pytest.mark.xfail(reason="Memory leak testing fails in GitHub CI.")
def test_prompt_session_memory_leak() -> None:
before_count = _count_prompt_session_instances()

# Somehow in CI/CD, the before_count is > 0
assert before_count == 0

p = PromptSession()

after_count = _count_prompt_session_instances()
assert after_count == before_count + 1

del p

after_delete_count = _count_prompt_session_instances()
assert after_delete_count == before_count

0 comments on commit 7776bf9

Please sign in to comment.