Skip to content

Commit

Permalink
Add caching to value_equality_values decorator for auto generated m…
Browse files Browse the repository at this point in the history
…ethods. (#6275)

* Add caching to value_equality_values decorator for auto generated methods.

* Fix pylint and formatting errors

* Address nits, fix bugs and make PauliSum unhashable
  • Loading branch information
tanujkhattar committed Sep 6, 2023
1 parent 0e80fa5 commit eddb2d9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/dense_pauli_string.py
Expand Up @@ -570,6 +570,9 @@ def copy(
def __str__(self) -> str:
return super().__str__() + ' (mutable)'

def _value_equality_values_(self):
return self.coefficient, tuple(PAULI_CHARS[p] for p in self.pauli_mask)

@classmethod
def inline_gaussian_elimination(cls, rows: 'List[MutableDensePauliString]') -> None:
if not rows:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/linear_combinations.py
Expand Up @@ -357,7 +357,7 @@ def _pauli_string_from_unit(unit: UnitPauliStringT, coefficient: Union[int, floa
return PauliString(qubit_pauli_map=dict(unit), coefficient=coefficient)


@value.value_equality(approximate=True)
@value.value_equality(approximate=True, unhashable=True)
class PauliSum:
"""Represents operator defined by linear combination of PauliStrings.
Expand Down
14 changes: 11 additions & 3 deletions cirq-core/cirq/value/value_equality_attr.py
Expand Up @@ -17,7 +17,7 @@

from typing_extensions import Protocol

from cirq import protocols
from cirq import protocols, _compat


class _SupportsValueEquality(Protocol):
Expand Down Expand Up @@ -221,13 +221,21 @@ class return the existing class' type.
)
else:
setattr(cls, '_value_equality_values_cls_', lambda self: cls)
setattr(cls, '__hash__', None if unhashable else _value_equality_hash)
cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter)
setattr(cls, '_value_equality_values_', cached_values_getter)
setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash))
setattr(cls, '__eq__', _value_equality_eq)
setattr(cls, '__ne__', _value_equality_ne)

if approximate:
if not hasattr(cls, '_value_equality_approximate_values_'):
setattr(cls, '_value_equality_approximate_values_', values_getter)
setattr(cls, '_value_equality_approximate_values_', cached_values_getter)
else:
approx_values_getter = getattr(cls, '_value_equality_approximate_values_')
cached_approx_values_getter = (
approx_values_getter if unhashable else _compat.cached_method(approx_values_getter)
)
setattr(cls, '_value_equality_approximate_values_', cached_approx_values_getter)
setattr(cls, '_approx_eq_', _value_equality_approx_eq)

return cls
Expand Down

0 comments on commit eddb2d9

Please sign in to comment.