Skip to content

Commit

Permalink
Add all_measurement_keys method to circuit (#2868)
Browse files Browse the repository at this point in the history
Fixes #2863
  • Loading branch information
Strilanc committed Apr 28, 2020
1 parent 675e4f0 commit d147dda
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
8 changes: 8 additions & 0 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,14 @@ def qid_shape(self,
self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_keys(self) -> List[str]:
result = []
for op in self.all_operations():
key = protocols.measurement_key(op, default=None)
if key is not None:
result.append(key)
return result

def _qid_shape_(self) -> Tuple[int, ...]:
return self.qid_shape()

Expand Down
37 changes: 37 additions & 0 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3724,3 +3724,40 @@ def test_indexing_by_numpy_integer():

assert c[np.int32(1)] == cirq.Moment([cirq.Y(q)])
assert c[np.int64(1)] == cirq.Moment([cirq.Y(q)])


def test_all_measurement_keys():

class Unknown(cirq.SingleQubitGate):

def _measurement_key_(self):
return 'test'

a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(
cirq.X(a),
cirq.CNOT(a, b),
cirq.measure(a, key='x'),
cirq.measure(b, key='y'),
cirq.reset(a),
cirq.measure(a, b, key='xy'),
Unknown().on(a),
)

# Big case.
assert c.all_measurement_keys() == ['x', 'y', 'xy', 'test']

# Empty case.
assert cirq.Circuit().all_measurement_keys() == []

# Output order matches insertion order, not qubit order.
assert cirq.Circuit(
cirq.Moment([
cirq.measure(a, key='x'),
cirq.measure(b, key='y'),
])).all_measurement_keys() == ['x', 'y']
assert cirq.Circuit(
cirq.Moment([
cirq.measure(b, key='y'),
cirq.measure(a, key='x'),
])).all_measurement_keys() == ['y', 'x']

0 comments on commit d147dda

Please sign in to comment.