From 77a0b7e207d601870a72b08dde5391e4e0a32724 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 21 Jun 2022 13:29:06 -0700 Subject: [PATCH 1/3] Add cached_method decorator for per-instance method caches --- cirq-core/cirq/_compat.py | 51 +++++++++++++++++++++++++++++++++- cirq-core/cirq/_compat_test.py | 34 ++++++++++++++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index e17597060c7..7752f06be09 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -23,7 +23,7 @@ import traceback import warnings from types import ModuleType -from typing import Any, Callable, Optional, Dict, Tuple, Type, Set +from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar import numpy as np import pandas as pd @@ -39,6 +39,55 @@ from backports.cached_property import cached_property # type: ignore[no-redef] +TFunc = TypeVar('TFunc', bound=Callable) + + +@overload +def cached_method(__func: TFunc) -> TFunc: + ... + + +@overload +def cached_method(*, maxsize: int = 128) -> Callable[[TFunc], TFunc]: + ... + + +def cached_method(*args: Any, maxsize: int = 128) -> Any: + """Decorator that adds a per-instance LRU cache for a method. + + Can be applied with or without parameters to customize the underlying cache: + + @cached_method + def foo(self, name: str) -> int: + ... + + @cached_method(maxsize=1000) + def bar(self, name: str) -> int: + ... + """ + + def decorator(func): + cache_name = f'_{func.__name__}_cache' + + @functools.wraps(func) + def wrapped(self, *args, **kwargs): + cached = getattr(self, cache_name, None) + if cached is None: + + @functools.lru_cache(maxsize=maxsize) + def cached_func(*args, **kwargs): + return func(self, *args, **kwargs) + + object.__setattr__(self, cache_name, cached_func) + cached = cached_func + return cached(*args, **kwargs) + + return wrapped + + return decorator(args[0]) if args else decorator + + + def proper_repr(value: Any) -> str: """Overrides sympy and numpy returning repr strings that don't parse.""" diff --git a/cirq-core/cirq/_compat_test.py b/cirq-core/cirq/_compat_test.py index c1ce6cab838..62cb4fdd649 100644 --- a/cirq-core/cirq/_compat_test.py +++ b/cirq-core/cirq/_compat_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import dataclasses import importlib import logging @@ -21,7 +22,7 @@ import types import warnings from types import ModuleType -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional, Tuple from importlib.machinery import ModuleSpec from unittest import mock @@ -35,6 +36,7 @@ import cirq.testing from cirq._compat import ( block_overlapping_deprecation, + cached_method, cached_property, proper_repr, dataclass_repr, @@ -985,3 +987,33 @@ def bar(self): bar2 = foo.bar assert bar2 is bar assert foo.bar_calls == 1 + + +class Bar: + def __init__(self): + self.foo_calls: Dict[int, int] = collections.Counter() + self.bar_calls: Dict[int, int] = collections.Counter() + + @cached_method + def foo(self, n: int) -> Tuple[int, int]: + self.foo_calls[n] += 1 + return (id(self), n) + + @cached_method(maxsize=1) + def bar(self, n: int) -> Tuple[int, int]: + self.bar_calls[n] += 1 + return (id(self), 2 * n) + + +def test_cached_method(): + b = Bar() + assert b.foo(123) == b.foo(123) == b.foo(123) == (id(b), 123) + assert b.foo(234) == b.foo(234) == b.foo(234) == (id(b), 234) + assert b.foo_calls == {123: 1, 234: 1} + + assert b.bar(123) == b.bar(123) == (id(b), 123 * 2) + assert b.bar_calls == {123: 1} + assert b.bar(234) == b.bar(234) == (id(b), 234 * 2) + assert b.bar_calls == {123: 1, 234: 1} + assert b.bar(123) == b.bar(123) == (id(b), 123 * 2) + assert b.bar_calls == {123: 2, 234: 1} From 296cefed22ceba5f90e315ac9c44add43891d4c8 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Wed, 22 Jun 2022 01:29:09 -0700 Subject: [PATCH 2/3] Format --- cirq-core/cirq/_compat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 7752f06be09..246e0dd4b27 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -87,7 +87,6 @@ def cached_func(*args, **kwargs): return decorator(args[0]) if args else decorator - def proper_repr(value: Any) -> str: """Overrides sympy and numpy returning repr strings that don't parse.""" From 5f1d631a27017d6b153cb9544769e85d3c889c1a Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 23 Jun 2022 14:29:55 -0700 Subject: [PATCH 3/3] Let cached_method have exactly one optional positional argument Avoid repeated variable name `args`. --- cirq-core/cirq/_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 246e0dd4b27..7cc28cdf559 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -52,7 +52,7 @@ def cached_method(*, maxsize: int = 128) -> Callable[[TFunc], TFunc]: ... -def cached_method(*args: Any, maxsize: int = 128) -> Any: +def cached_method(method: Optional[TFunc] = None, *, maxsize: int = 128) -> Any: """Decorator that adds a per-instance LRU cache for a method. Can be applied with or without parameters to customize the underlying cache: @@ -84,7 +84,7 @@ def cached_func(*args, **kwargs): return wrapped - return decorator(args[0]) if args else decorator + return decorator if method is None else decorator(method) def proper_repr(value: Any) -> str: