From eb3d84c4cbf8782d9c2c06d807182dd767770efe Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Wed, 8 Sep 2021 10:14:56 -0700 Subject: [PATCH] Replace decomposition in measurement_key_protocol with explicit implementations (#4471) * Remove decomposition in measurement_key_protocol * deprecation * Fix deprecation description * Simplify, removing unused method. * mypy * reverse function dependency order to take advantage of FrozenCircuit caching * cache moment * Improve deprecation message --- cirq-core/cirq/circuits/circuit.py | 5 +- cirq-core/cirq/ops/moment.py | 9 +++ .../protocols/measurement_key_protocol.py | 71 ++++++------------- .../measurement_key_protocol_test.py | 31 +++----- 4 files changed, 44 insertions(+), 72 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index e70d67040e9..a4c21bf6831 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -909,7 +909,10 @@ def qid_shape( return protocols.qid_shape(qids) def all_measurement_key_names(self) -> AbstractSet[str]: - return protocols.measurement_key_names(self) + return {key for op in self.all_operations() for key in protocols.measurement_key_names(op)} + + def _measurement_key_names_(self) -> AbstractSet[str]: + return self.all_measurement_key_names() def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): return self._with_sliced_moments( diff --git a/cirq-core/cirq/ops/moment.py b/cirq-core/cirq/ops/moment.py index 8a3ad36d299..cfe871e4bec 100644 --- a/cirq-core/cirq/ops/moment.py +++ b/cirq-core/cirq/ops/moment.py @@ -15,6 +15,7 @@ """A simplified time-slice of operations within a sequenced circuit.""" from typing import ( + AbstractSet, Any, Callable, Dict, @@ -91,6 +92,7 @@ def __init__(self, *contents: 'cirq.OP_TREE') -> None: self._qubit_to_op[q] = op self._qubits = frozenset(self._qubit_to_op.keys()) + self._measurement_key_names: Optional[AbstractSet[str]] = None @property def operations(self) -> Tuple['cirq.Operation', ...]: @@ -217,6 +219,13 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): for op in self.operations ) + def _measurement_key_names_(self) -> AbstractSet[str]: + if self._measurement_key_names is None: + self._measurement_key_names = { + key for op in self.operations for key in protocols.measurement_key_names(op) + } + return self._measurement_key_names + def _with_key_path_(self, path: Tuple[str, ...]): return Moment( protocols.with_key_path(op, path) if protocols.is_measurement(op) else op diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index a9da9aa1a3f..af237d21892 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -17,9 +17,8 @@ from typing_extensions import Protocol -from cirq._compat import deprecated, _warn_or_error +from cirq._compat import deprecated, deprecated_parameter, _warn_or_error from cirq._doc import doc_private -from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits # This is a special indicator value used by the inverse method to determine # whether or not the caller provided a 'default' argument. @@ -155,6 +154,13 @@ def measurement_keys(val: Any, *, allow_decompose: bool = True): return measurement_key_names(val, allow_decompose=allow_decompose) +@deprecated_parameter( + deadline='v0.14', + fix='This protocol no longer uses decomposition, so allow_decompose should be removed', + func_name='measurement_key_names', + parameter_desc='allow_decompose', + match=lambda args, kwargs: 'allow_decompose' in kwargs, +) def measurement_key_names(val: Any, *, allow_decompose: bool = True) -> AbstractSet[str]: """Gets the measurement keys of measurements within the given value. @@ -174,12 +180,6 @@ def measurement_key_names(val: Any, *, allow_decompose: bool = True) -> Abstract result = _measurement_key_names_from_magic_methods(val) if result is not NotImplemented and result is not None: return result - - if allow_decompose: - operations, _, _ = _try_decompose_into_operations_and_qubits(val) - if operations is not None: - return {key for op in operations for key in measurement_key_names(op)} - return set() @@ -189,47 +189,13 @@ def _is_measurement_from_magic_method(val: Any) -> Optional[bool]: return NotImplemented if getter is None else getter() -def _is_any_measurement(vals: List[Any], allow_decompose: bool) -> bool: - """Given a list of objects, returns True if any of them is a measurement. - - If `allow_decompose` is True, decomposes the objects and runs the measurement checks on the - constituent decomposed operations. But a decompose operation is only called if all cheaper - checks are done. A BFS for searching measurements, where "depth" is each level of decompose. - """ - vals_to_decompose = [] # type: List[Any] - while vals: - val = vals.pop(0) - result = _is_measurement_from_magic_method(val) - if result is not NotImplemented: - if result is True: - return True - if result is False: - # Do not try any other strategies if `val` was explicitly marked as - # "not measurement". - continue - - keys = _measurement_key_names_from_magic_methods(val) - if keys is not NotImplemented and bool(keys) is True: - return True - - if allow_decompose: - vals_to_decompose.append(val) - - # If vals has finished iterating over, keep decomposing from vals_to_decompose until vals - # is populated with something. - while not vals: - if not vals_to_decompose: - # Nothing left to process, this is not a measurement. - return False - operations, _, _ = _try_decompose_into_operations_and_qubits(vals_to_decompose.pop(0)) - if operations: - # Reverse the decomposed operations because measurements are typically at later - # moments. - vals = operations[::-1] - - return False - - +@deprecated_parameter( + deadline='v0.14', + fix='This protocol no longer uses decomposition, so allow_decompose should be removed', + func_name='is_measurement', + parameter_desc='allow_decompose', + match=lambda args, kwargs: 'allow_decompose' in kwargs, +) def is_measurement(val: Any, allow_decompose: bool = True) -> bool: """Determines whether or not the given value is a measurement (or contains one). @@ -242,7 +208,12 @@ def is_measurement(val: Any, allow_decompose: bool = True) -> bool: don't directly specify their `_is_measurement_` property will be decomposed in order to find any measurements keys within the decomposed operations. """ - return _is_any_measurement([val], allow_decompose) + result = _is_measurement_from_magic_method(val) + if isinstance(result, bool): + return result + + keys = _measurement_key_names_from_magic_methods(val) + return keys is not NotImplemented and bool(keys) def with_measurement_key_mapping(val: Any, key_map: Dict[str, str]): diff --git a/cirq-core/cirq/protocols/measurement_key_protocol_test.py b/cirq-core/cirq/protocols/measurement_key_protocol_test.py index 12bee12accd..fab7518df7e 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol_test.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol_test.py @@ -130,15 +130,6 @@ def num_qubits(self) -> int: def test_measurement_keys(): - class Composite(cirq.Gate): - def _decompose_(self, qubits): - yield cirq.measure(qubits[0], key='inner1') - yield cirq.measure(qubits[1], key='inner2') - yield cirq.reset(qubits[0]) - - def num_qubits(self) -> int: - return 2 - class MeasurementKeysGate(cirq.Gate): def _measurement_key_names_(self): return ['a', 'b'] @@ -154,28 +145,26 @@ def num_qubits(self) -> int: return 1 a, b = cirq.LineQubit.range(2) - assert cirq.is_measurement(Composite()) - with cirq.testing.assert_deprecated(deadline="v0.13"): - assert cirq.measurement_keys(Composite()) == {'inner1', 'inner2'} with cirq.testing.assert_deprecated(deadline="v0.13"): assert cirq.measurement_key_names(DeprecatedMagicMethod()) == {'a', 'b'} with cirq.testing.assert_deprecated(deadline="v0.13"): assert cirq.measurement_key_names(DeprecatedMagicMethod().on(a)) == {'a', 'b'} - assert cirq.measurement_key_names(Composite()) == {'inner1', 'inner2'} - assert cirq.measurement_key_names(Composite().on(a, b)) == {'inner1', 'inner2'} - assert not cirq.is_measurement(Composite(), allow_decompose=False) - assert cirq.measurement_key_names(Composite(), allow_decompose=False) == set() - assert cirq.measurement_key_names(Composite().on(a, b), allow_decompose=False) == set() assert cirq.measurement_key_names(None) == set() assert cirq.measurement_key_names([]) == set() assert cirq.measurement_key_names(cirq.X) == set() assert cirq.measurement_key_names(cirq.X(a)) == set() - assert cirq.measurement_key_names(None, allow_decompose=False) == set() - assert cirq.measurement_key_names([], allow_decompose=False) == set() - assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set() + with cirq.testing.assert_deprecated(deadline="v0.14"): + assert cirq.measurement_key_names(None, allow_decompose=False) == set() + with cirq.testing.assert_deprecated(deadline="v0.14"): + assert cirq.measurement_key_names([], allow_decompose=False) == set() + with cirq.testing.assert_deprecated(deadline="v0.14"): + assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set() assert cirq.measurement_key_names(cirq.measure(a, key='out')) == {'out'} - assert cirq.measurement_key_names(cirq.measure(a, key='out'), allow_decompose=False) == {'out'} + with cirq.testing.assert_deprecated(deadline="v0.14"): + assert cirq.measurement_key_names(cirq.measure(a, key='out'), allow_decompose=False) == { + 'out' + } assert cirq.measurement_key_names( cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))