Skip to content

Commit

Permalink
Replace decomposition in measurement_key_protocol with explicit imple…
Browse files Browse the repository at this point in the history
…mentations (#4471)

* Remove decomposition in measurement_key_protocol

* deprecation

* Fix deprecation description

* Simplify, removing unused method.

* mypy

* reverse function dependency order to take advantage of FrozenCircuit caching

* cache moment

* Improve deprecation message
  • Loading branch information
daxfohl committed Sep 8, 2021
1 parent 676431d commit eb3d84c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 72 deletions.
5 changes: 4 additions & 1 deletion cirq-core/cirq/circuits/circuit.py
Expand Up @@ -909,7 +909,10 @@ def qid_shape(
return protocols.qid_shape(qids)

def all_measurement_key_names(self) -> AbstractSet[str]:
return protocols.measurement_key_names(self)
return {key for op in self.all_operations() for key in protocols.measurement_key_names(op)}

def _measurement_key_names_(self) -> AbstractSet[str]:
return self.all_measurement_key_names()

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
return self._with_sliced_moments(
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/ops/moment.py
Expand Up @@ -15,6 +15,7 @@
"""A simplified time-slice of operations within a sequenced circuit."""

from typing import (
AbstractSet,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(self, *contents: 'cirq.OP_TREE') -> None:
self._qubit_to_op[q] = op

self._qubits = frozenset(self._qubit_to_op.keys())
self._measurement_key_names: Optional[AbstractSet[str]] = None

@property
def operations(self) -> Tuple['cirq.Operation', ...]:
Expand Down Expand Up @@ -217,6 +219,13 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
for op in self.operations
)

def _measurement_key_names_(self) -> AbstractSet[str]:
if self._measurement_key_names is None:
self._measurement_key_names = {
key for op in self.operations for key in protocols.measurement_key_names(op)
}
return self._measurement_key_names

def _with_key_path_(self, path: Tuple[str, ...]):
return Moment(
protocols.with_key_path(op, path) if protocols.is_measurement(op) else op
Expand Down
71 changes: 21 additions & 50 deletions cirq-core/cirq/protocols/measurement_key_protocol.py
Expand Up @@ -17,9 +17,8 @@

from typing_extensions import Protocol

from cirq._compat import deprecated, _warn_or_error
from cirq._compat import deprecated, deprecated_parameter, _warn_or_error
from cirq._doc import doc_private
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits

# This is a special indicator value used by the inverse method to determine
# whether or not the caller provided a 'default' argument.
Expand Down Expand Up @@ -155,6 +154,13 @@ def measurement_keys(val: Any, *, allow_decompose: bool = True):
return measurement_key_names(val, allow_decompose=allow_decompose)


@deprecated_parameter(
deadline='v0.14',
fix='This protocol no longer uses decomposition, so allow_decompose should be removed',
func_name='measurement_key_names',
parameter_desc='allow_decompose',
match=lambda args, kwargs: 'allow_decompose' in kwargs,
)
def measurement_key_names(val: Any, *, allow_decompose: bool = True) -> AbstractSet[str]:
"""Gets the measurement keys of measurements within the given value.
Expand All @@ -174,12 +180,6 @@ def measurement_key_names(val: Any, *, allow_decompose: bool = True) -> Abstract
result = _measurement_key_names_from_magic_methods(val)
if result is not NotImplemented and result is not None:
return result

if allow_decompose:
operations, _, _ = _try_decompose_into_operations_and_qubits(val)
if operations is not None:
return {key for op in operations for key in measurement_key_names(op)}

return set()


Expand All @@ -189,47 +189,13 @@ def _is_measurement_from_magic_method(val: Any) -> Optional[bool]:
return NotImplemented if getter is None else getter()


def _is_any_measurement(vals: List[Any], allow_decompose: bool) -> bool:
"""Given a list of objects, returns True if any of them is a measurement.
If `allow_decompose` is True, decomposes the objects and runs the measurement checks on the
constituent decomposed operations. But a decompose operation is only called if all cheaper
checks are done. A BFS for searching measurements, where "depth" is each level of decompose.
"""
vals_to_decompose = [] # type: List[Any]
while vals:
val = vals.pop(0)
result = _is_measurement_from_magic_method(val)
if result is not NotImplemented:
if result is True:
return True
if result is False:
# Do not try any other strategies if `val` was explicitly marked as
# "not measurement".
continue

keys = _measurement_key_names_from_magic_methods(val)
if keys is not NotImplemented and bool(keys) is True:
return True

if allow_decompose:
vals_to_decompose.append(val)

# If vals has finished iterating over, keep decomposing from vals_to_decompose until vals
# is populated with something.
while not vals:
if not vals_to_decompose:
# Nothing left to process, this is not a measurement.
return False
operations, _, _ = _try_decompose_into_operations_and_qubits(vals_to_decompose.pop(0))
if operations:
# Reverse the decomposed operations because measurements are typically at later
# moments.
vals = operations[::-1]

return False


@deprecated_parameter(
deadline='v0.14',
fix='This protocol no longer uses decomposition, so allow_decompose should be removed',
func_name='is_measurement',
parameter_desc='allow_decompose',
match=lambda args, kwargs: 'allow_decompose' in kwargs,
)
def is_measurement(val: Any, allow_decompose: bool = True) -> bool:
"""Determines whether or not the given value is a measurement (or contains one).
Expand All @@ -242,7 +208,12 @@ def is_measurement(val: Any, allow_decompose: bool = True) -> bool:
don't directly specify their `_is_measurement_` property will be decomposed in
order to find any measurements keys within the decomposed operations.
"""
return _is_any_measurement([val], allow_decompose)
result = _is_measurement_from_magic_method(val)
if isinstance(result, bool):
return result

keys = _measurement_key_names_from_magic_methods(val)
return keys is not NotImplemented and bool(keys)


def with_measurement_key_mapping(val: Any, key_map: Dict[str, str]):
Expand Down
31 changes: 10 additions & 21 deletions cirq-core/cirq/protocols/measurement_key_protocol_test.py
Expand Up @@ -130,15 +130,6 @@ def num_qubits(self) -> int:


def test_measurement_keys():
class Composite(cirq.Gate):
def _decompose_(self, qubits):
yield cirq.measure(qubits[0], key='inner1')
yield cirq.measure(qubits[1], key='inner2')
yield cirq.reset(qubits[0])

def num_qubits(self) -> int:
return 2

class MeasurementKeysGate(cirq.Gate):
def _measurement_key_names_(self):
return ['a', 'b']
Expand All @@ -154,28 +145,26 @@ def num_qubits(self) -> int:
return 1

a, b = cirq.LineQubit.range(2)
assert cirq.is_measurement(Composite())
with cirq.testing.assert_deprecated(deadline="v0.13"):
assert cirq.measurement_keys(Composite()) == {'inner1', 'inner2'}
with cirq.testing.assert_deprecated(deadline="v0.13"):
assert cirq.measurement_key_names(DeprecatedMagicMethod()) == {'a', 'b'}
with cirq.testing.assert_deprecated(deadline="v0.13"):
assert cirq.measurement_key_names(DeprecatedMagicMethod().on(a)) == {'a', 'b'}
assert cirq.measurement_key_names(Composite()) == {'inner1', 'inner2'}
assert cirq.measurement_key_names(Composite().on(a, b)) == {'inner1', 'inner2'}
assert not cirq.is_measurement(Composite(), allow_decompose=False)
assert cirq.measurement_key_names(Composite(), allow_decompose=False) == set()
assert cirq.measurement_key_names(Composite().on(a, b), allow_decompose=False) == set()

assert cirq.measurement_key_names(None) == set()
assert cirq.measurement_key_names([]) == set()
assert cirq.measurement_key_names(cirq.X) == set()
assert cirq.measurement_key_names(cirq.X(a)) == set()
assert cirq.measurement_key_names(None, allow_decompose=False) == set()
assert cirq.measurement_key_names([], allow_decompose=False) == set()
assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set()
with cirq.testing.assert_deprecated(deadline="v0.14"):
assert cirq.measurement_key_names(None, allow_decompose=False) == set()
with cirq.testing.assert_deprecated(deadline="v0.14"):
assert cirq.measurement_key_names([], allow_decompose=False) == set()
with cirq.testing.assert_deprecated(deadline="v0.14"):
assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set()
assert cirq.measurement_key_names(cirq.measure(a, key='out')) == {'out'}
assert cirq.measurement_key_names(cirq.measure(a, key='out'), allow_decompose=False) == {'out'}
with cirq.testing.assert_deprecated(deadline="v0.14"):
assert cirq.measurement_key_names(cirq.measure(a, key='out'), allow_decompose=False) == {
'out'
}

assert cirq.measurement_key_names(
cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))
Expand Down

0 comments on commit eb3d84c

Please sign in to comment.