Skip to content

Commit

Permalink
Change measurement_keys protocol to return AbstractSet[str] (#3454)
Browse files Browse the repository at this point in the history
Fixes #3452
  • Loading branch information
maffoo committed Oct 28, 2020
1 parent c04dd6e commit 786822b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 33 deletions.
2 changes: 1 addition & 1 deletion cirq/circuits/circuit.py
Expand Up @@ -1499,7 +1499,7 @@ def qid_shape(self,
self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_keys(self) -> Tuple[str, ...]:
def all_measurement_keys(self) -> AbstractSet[str]:
return protocols.measurement_keys(self)

def _qid_shape_(self) -> Tuple[int, ...]:
Expand Down
10 changes: 5 additions & 5 deletions cirq/circuits/circuit_test.py
Expand Up @@ -3816,22 +3816,22 @@ def _measurement_key_(self):
)

# Big case.
assert c.all_measurement_keys() == ('x', 'y', 'xy', 'test')
assert c.all_measurement_keys() == {'x', 'y', 'xy', 'test'}

# Empty case.
assert cirq.Circuit().all_measurement_keys() == ()
assert cirq.Circuit().all_measurement_keys() == set()

# Output order matches insertion order, not qubit order.
# Order does not matter.
assert cirq.Circuit(
cirq.Moment([
cirq.measure(a, key='x'),
cirq.measure(b, key='y'),
])).all_measurement_keys() == ('x', 'y')
])).all_measurement_keys() == {'x', 'y'}
assert cirq.Circuit(
cirq.Moment([
cirq.measure(b, key='y'),
cirq.measure(a, key='x'),
])).all_measurement_keys() == ('y', 'x')
])).all_measurement_keys() == {'x', 'y'}


def test_deprecated():
Expand Down
15 changes: 7 additions & 8 deletions cirq/protocols/measurement_key_protocol.py
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Protocol for object that have measurement keys."""

from typing import Any, Iterable, Tuple
from typing import AbstractSet, Any, Iterable

from typing_extensions import Protocol

Expand Down Expand Up @@ -89,7 +89,7 @@ def measurement_key(val: Any, default: Any = RaiseTypeErrorIfNotProvided):
result = measurement_keys(val)

if len(result) == 1:
return result[0]
return next(iter(result))

if len(result) > 1:
raise ValueError(f'Got multiple measurement keys ({result!r}) '
Expand All @@ -102,7 +102,7 @@ def measurement_key(val: Any, default: Any = RaiseTypeErrorIfNotProvided):


def measurement_keys(val: Any, *,
allow_decompose: bool = True) -> Tuple[str, ...]:
allow_decompose: bool = True) -> AbstractSet[str]:
"""Gets the measurement keys of measurements within the given value.
Args:
Expand All @@ -121,20 +121,19 @@ def measurement_keys(val: Any, *,
getter = getattr(val, '_measurement_keys_', None)
result = NotImplemented if getter is None else getter()
if result is not NotImplemented and result is not None:
return tuple(result)
return set(result)

getter = getattr(val, '_measurement_key_', None)
result = NotImplemented if getter is None else getter()
if result is not NotImplemented and result is not None:
return result,
return {result}

if allow_decompose:
operations, _, _ = _try_decompose_into_operations_and_qubits(val)
if operations is not None:
return tuple(
key for op in operations for key in measurement_keys(op))
return {key for op in operations for key in measurement_keys(op)}

return ()
return set()


def is_measurement(val: Any) -> bool:
Expand Down
38 changes: 19 additions & 19 deletions cirq/protocols/measurement_key_protocol_test.py
Expand Up @@ -41,8 +41,8 @@ class NoMethod():

with pytest.raises(ValueError, match='multiple measurement keys'):
cirq.measurement_key(
cirq.Circuit(cirq.measure(cirq.LineQubit(0)),
cirq.measure(cirq.LineQubit(0))))
cirq.Circuit(cirq.measure(cirq.LineQubit(0), key='a'),
cirq.measure(cirq.LineQubit(0), key='b')))

assert cirq.measurement_key(NoMethod(), None) is None
assert cirq.measurement_key(NoMethod(), NotImplemented) is NotImplemented
Expand Down Expand Up @@ -110,25 +110,25 @@ def num_qubits(self) -> int:
return 1

a, b = cirq.LineQubit.range(2)
assert cirq.measurement_keys(Composite()) == ('inner1', 'inner2')
assert cirq.measurement_keys(Composite().on(a, b)) == ('inner1', 'inner2')
assert cirq.measurement_keys(Composite(), allow_decompose=False) == ()
assert cirq.measurement_keys(Composite()) == {'inner1', 'inner2'}
assert cirq.measurement_keys(Composite().on(a, b)) == {'inner1', 'inner2'}
assert cirq.measurement_keys(Composite(), allow_decompose=False) == set()
assert cirq.measurement_keys(Composite().on(a, b),
allow_decompose=False) == ()

assert cirq.measurement_keys(None) == ()
assert cirq.measurement_keys([]) == ()
assert cirq.measurement_keys(cirq.X) == ()
assert cirq.measurement_keys(cirq.X(a)) == ()
assert cirq.measurement_keys(None, allow_decompose=False) == ()
assert cirq.measurement_keys([], allow_decompose=False) == ()
assert cirq.measurement_keys(cirq.X, allow_decompose=False) == ()
assert cirq.measurement_keys(cirq.measure(a, key='out')) == ('out',)
allow_decompose=False) == set()

assert cirq.measurement_keys(None) == set()
assert cirq.measurement_keys([]) == set()
assert cirq.measurement_keys(cirq.X) == set()
assert cirq.measurement_keys(cirq.X(a)) == set()
assert cirq.measurement_keys(None, allow_decompose=False) == set()
assert cirq.measurement_keys([], allow_decompose=False) == set()
assert cirq.measurement_keys(cirq.X, allow_decompose=False) == set()
assert cirq.measurement_keys(cirq.measure(a, key='out')) == {'out'}
assert cirq.measurement_keys(cirq.measure(a, key='out'),
allow_decompose=False) == ('out',)
allow_decompose=False) == {'out'}

assert cirq.measurement_keys(
cirq.Circuit(cirq.measure(a, key='a'),
cirq.measure(b, key='2'))) == ('a', '2')
assert cirq.measurement_keys(MeasurementKeysGate()) == ('a', 'b')
assert cirq.measurement_keys(MeasurementKeysGate().on(a)) == ('a', 'b')
cirq.measure(b, key='2'))) == {'a', '2'}
assert cirq.measurement_keys(MeasurementKeysGate()) == {'a', 'b'}
assert cirq.measurement_keys(MeasurementKeysGate().on(a)) == {'a', 'b'}

0 comments on commit 786822b

Please sign in to comment.