From eddb2d9cbdf55576c6e7532e0a25b40995d889dd Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 5 Sep 2023 19:08:32 -0700 Subject: [PATCH] Add caching to `value_equality_values` decorator for auto generated methods. (#6275) * Add caching to value_equality_values decorator for auto generated methods. * Fix pylint and formatting errors * Address nits, fix bugs and make PauliSum unhashable --- cirq-core/cirq/ops/dense_pauli_string.py | 3 +++ cirq-core/cirq/ops/linear_combinations.py | 2 +- cirq-core/cirq/value/value_equality_attr.py | 14 +++++++++++--- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/ops/dense_pauli_string.py b/cirq-core/cirq/ops/dense_pauli_string.py index 6cf97c4eb31..9893b64f706 100644 --- a/cirq-core/cirq/ops/dense_pauli_string.py +++ b/cirq-core/cirq/ops/dense_pauli_string.py @@ -570,6 +570,9 @@ def copy( def __str__(self) -> str: return super().__str__() + ' (mutable)' + def _value_equality_values_(self): + return self.coefficient, tuple(PAULI_CHARS[p] for p in self.pauli_mask) + @classmethod def inline_gaussian_elimination(cls, rows: 'List[MutableDensePauliString]') -> None: if not rows: diff --git a/cirq-core/cirq/ops/linear_combinations.py b/cirq-core/cirq/ops/linear_combinations.py index 9f50216dab9..ed223dfe0de 100644 --- a/cirq-core/cirq/ops/linear_combinations.py +++ b/cirq-core/cirq/ops/linear_combinations.py @@ -357,7 +357,7 @@ def _pauli_string_from_unit(unit: UnitPauliStringT, coefficient: Union[int, floa return PauliString(qubit_pauli_map=dict(unit), coefficient=coefficient) -@value.value_equality(approximate=True) +@value.value_equality(approximate=True, unhashable=True) class PauliSum: """Represents operator defined by linear combination of PauliStrings. diff --git a/cirq-core/cirq/value/value_equality_attr.py b/cirq-core/cirq/value/value_equality_attr.py index 31d570430a6..f66c6549e57 100644 --- a/cirq-core/cirq/value/value_equality_attr.py +++ b/cirq-core/cirq/value/value_equality_attr.py @@ -17,7 +17,7 @@ from typing_extensions import Protocol -from cirq import protocols +from cirq import protocols, _compat class _SupportsValueEquality(Protocol): @@ -221,13 +221,21 @@ class return the existing class' type. ) else: setattr(cls, '_value_equality_values_cls_', lambda self: cls) - setattr(cls, '__hash__', None if unhashable else _value_equality_hash) + cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter) + setattr(cls, '_value_equality_values_', cached_values_getter) + setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash)) setattr(cls, '__eq__', _value_equality_eq) setattr(cls, '__ne__', _value_equality_ne) if approximate: if not hasattr(cls, '_value_equality_approximate_values_'): - setattr(cls, '_value_equality_approximate_values_', values_getter) + setattr(cls, '_value_equality_approximate_values_', cached_values_getter) + else: + approx_values_getter = getattr(cls, '_value_equality_approximate_values_') + cached_approx_values_getter = ( + approx_values_getter if unhashable else _compat.cached_method(approx_values_getter) + ) + setattr(cls, '_value_equality_approximate_values_', cached_approx_values_getter) setattr(cls, '_approx_eq_', _value_equality_approx_eq) return cls