Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up parameter resolution by checking if val is parameterized and caching the boolean #6023

Merged
merged 4 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An immutable version of the Circuit data structure."""
from typing import FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union
from typing import AbstractSet, FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np

Expand Down Expand Up @@ -83,6 +83,10 @@ def _num_qubits_(self) -> int:
def _qid_shape_(self) -> Tuple[int, ...]:
return super()._qid_shape_()

@_compat.cached_method
def _has_unitary_(self) -> bool:
return super()._has_unitary_()

@_compat.cached_method
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return super()._unitary_()
Expand Down Expand Up @@ -124,6 +128,14 @@ def are_all_measurements_terminal(self) -> bool:
def all_measurement_key_names(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self.all_measurement_key_objs())

@_compat.cached_method
def _is_parameterized_(self) -> bool:
return super()._is_parameterized_()

@_compat.cached_method
def _parameter_names_(self) -> AbstractSet[str]:
return super()._parameter_names_()

def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()

Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

import numpy as np

from cirq import protocols, ops, qis
from cirq import protocols, ops, qis, _compat
from cirq._import import LazyLoader
from cirq.ops import raw_types, op_tree
from cirq.protocols import circuit_diagram_info_protocol
Expand Down Expand Up @@ -237,9 +237,11 @@ def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Mom
if qubits.isdisjoint(frozenset(operation.qubits))
)

@_compat.cached_method()
def _is_parameterized_(self) -> bool:
return any(protocols.is_parameterized(op) for op in self)

@_compat.cached_method()
def _parameter_names_(self) -> AbstractSet[str]:
return {name for op in self for name in protocols.parameter_names(op)}

Expand All @@ -265,6 +267,7 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
for op in self.operations
)

@_compat.cached_method()
def _measurement_key_names_(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

Expand Down Expand Up @@ -332,6 +335,7 @@ def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
def __ne__(self, other) -> bool:
return not self == other

@_compat.cached_method()
def __hash__(self):
return hash((Moment, self._sorted_operations_()))

Expand Down Expand Up @@ -405,6 +409,7 @@ def expand_to(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Moment':
operations.append(ops.I(q))
return Moment(*operations)

@_compat.cached_method()
def _has_kraus_(self) -> bool:
"""Returns True if self has a Kraus representation and self uses <= 10 qubits."""
return all(protocols.has_kraus(op) for op in self.operations) and len(self.qubits) <= 10
Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/ops/pauli_sum_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import sympy

from cirq import linalg, protocols, value
from cirq import linalg, protocols, value, _compat
from cirq.ops import linear_combinations, pauli_string_phasor

if TYPE_CHECKING:
Expand Down Expand Up @@ -78,6 +78,10 @@ def _value_equality_values_(self) -> Any:
def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'PauliSumExponential':
return PauliSumExponential(self._pauli_sum.with_qubits(*new_qubits), self._exponent)

@_compat.cached_method
def _is_parameterized_(self) -> bool:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
return protocols.is_parameterized(self._exponent)

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'PauliSumExponential':
Expand Down Expand Up @@ -109,6 +113,7 @@ def matrix(self) -> np.ndarray:
ret = np.kron(ret, protocols.unitary(pauli_string_exp))
return ret

@_compat.cached_method
def _has_unitary_(self) -> bool:
return linalg.is_unitary(self.matrix())

Expand Down
17 changes: 15 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from cirq import protocols, value
from cirq._import import LazyLoader
from cirq._compat import __cirq_debug__
from cirq._compat import __cirq_debug__, cached_method
from cirq.type_workarounds import NotImplementedType
from cirq.ops import control_values as cv

Expand Down Expand Up @@ -111,7 +111,8 @@ def with_dimension(self, dimension: int) -> 'Qid':
def _cmp_tuple(self):
return (type(self).__name__, repr(type(self)), self._comparison_key(), self.dimension)

def __hash__(self):
@cached_method
def __hash__(self) -> int:
return hash((Qid, self._comparison_key()))

def __eq__(self, other):
Expand Down Expand Up @@ -505,6 +506,7 @@ def _num_qubits_(self) -> int:
"""
return len(self.qubits)

@cached_method
def _qid_shape_(self) -> Tuple[int, ...]:
return protocols.qid_shape(self.qubits)

Expand Down Expand Up @@ -839,6 +841,7 @@ def _apply_unitary_(
) -> Union[np.ndarray, None, NotImplementedType]:
return protocols.apply_unitary(self.sub_operation, args, default=None)

@cached_method
def _has_unitary_(self) -> bool:
return protocols.has_unitary(self.sub_operation)

Expand All @@ -850,30 +853,36 @@ def _commutes_(
) -> Union[bool, NotImplementedType, None]:
return protocols.commutes(self.sub_operation, other, atol=atol)

@cached_method
def _has_mixture_(self) -> bool:
return protocols.has_mixture(self.sub_operation)

def _mixture_(self) -> Sequence[Tuple[float, Any]]:
return protocols.mixture(self.sub_operation, NotImplemented)

@cached_method
def _has_kraus_(self) -> bool:
return protocols.has_kraus(self.sub_operation)

def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.kraus(self.sub_operation, NotImplemented)

@cached_method
def _measurement_key_names_(self) -> FrozenSet[str]:
return protocols.measurement_key_names(self.sub_operation)

@cached_method
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return protocols.measurement_key_objs(self.sub_operation)

@cached_method
def _is_measurement_(self) -> bool:
sub = getattr(self.sub_operation, "_is_measurement_", None)
if sub is not None:
return sub()
return NotImplemented

@cached_method
def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self.sub_operation) or any(
protocols.is_parameterized(tag) for tag in self.tags
Expand All @@ -885,6 +894,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool:
return sub(sim_state)
return NotImplemented

@cached_method
def _parameter_names_(self) -> AbstractSet[str]:
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
return protocols.parameter_names(self.sub_operation) | tag_params
Expand All @@ -909,6 +919,7 @@ def _circuit_diagram_info_(
) + sub_op_info.wire_symbols[1:]
return sub_op_info

@cached_method
def _trace_distance_bound_(self) -> float:
return protocols.trace_distance_bound(self.sub_operation)

Expand Down Expand Up @@ -980,9 +991,11 @@ def _has_unitary_(self):
for op in protocols.decompose_once_with_qubits(self._original, qubits)
)

@cached_method
def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self._original)

@cached_method
def _parameter_names_(self) -> AbstractSet[str]:
return protocols.parameter_names(self._original)

Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/protocols/resolve_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def resolve_parameters(
if isinstance(val, (list, tuple)):
return cast(T, type(val)(resolve_parameters(e, param_resolver, recursive) for e in val))

is_parameterized = getattr(val, '_is_parameterized_', None)
if is_parameterized is not None and not is_parameterized():
return val

getter = getattr(val, '_resolve_parameters_', None)
if getter is None:
result = NotImplemented
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/protocols/resolve_parameters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def __init__(self, var):
self.parameter = var

def _is_parameterized_(self) -> bool:
return self.parameter == 0
return self.parameter != 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in the existing test, the _is_paramaterized_ returns True when the class is actually not paramaterized (i.e. self.parameter is 0 and there's nothing to resolve).

As a result cirq.is_parameterized(val) returns False when val is actually parameterized and therefore the output of cirq.resolve_parameters(val) changes.

So this change fixes the test to make the behavior of _is_paramaterized_ and resolve_parameters consistent.


def _resolve_parameters_(self, resolver: ParamResolver, recursive: bool):
self.parameter = resolver.value_of(self.parameter, recursive)
return self

assert not cirq.is_parameterized(NoMethod())
assert not cirq.is_parameterized(ReturnsNotImplemented())
assert not cirq.is_parameterized(SimpleParameterSwitch('a'))
assert cirq.is_parameterized(SimpleParameterSwitch(0))
assert cirq.is_parameterized(SimpleParameterSwitch('a'))
assert not cirq.is_parameterized(SimpleParameterSwitch(0))

ni = ReturnsNotImplemented()
d = {'a': 0}
Expand Down