Skip to content

Commit

Permalink
Add tags to MeasureInfo (#3694)
Browse files Browse the repository at this point in the history
* Add tags to MeasureInfo

- When we call cirq.google.api.v2.find_measurements, we need a way
to retrieve tags that were originally on the circuit.
- This adds tags to the NamedTuple that is returned by this function
that includes tags.
- This also changes MeasureInfo into a dataclass.

This does change the output of this function, but, since it is an
additive change, it should be minimally breaking.
  • Loading branch information
dstrain115 committed Jan 25, 2021
1 parent 680ae21 commit fbafbd5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 19 deletions.
35 changes: 21 additions & 14 deletions cirq/google/api/v2/results.py
Expand Up @@ -11,10 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, TYPE_CHECKING

from typing import (
cast,
Dict,
Hashable,
Iterable,
Iterator,
List,
Optional,
Set,
TYPE_CHECKING,
)
from collections import OrderedDict
import dataclasses
import numpy as np

from cirq.google.api import v2
Expand All @@ -28,17 +37,8 @@
import cirq


class MeasureInfo(
NamedTuple(
'MeasureInfo',
[
('key', str),
('qubits', List['cirq.GridQubit']),
('slot', int),
('invert_mask', List[bool]),
],
)
):
@dataclasses.dataclass
class MeasureInfo:
"""Extra info about a single measurement within a circuit.
Attributes:
Expand All @@ -53,6 +53,12 @@ class MeasureInfo(
be flipped for each of the qubits in the qubits field.
"""

key: str
qubits: List['cirq.GridQubit']
slot: int
invert_mask: List[bool]
tags: List[Hashable]


def find_measurements(program: 'cirq.Circuit') -> List[MeasureInfo]:
"""Find measurements in the given program (circuit).
Expand Down Expand Up @@ -86,6 +92,7 @@ def _circuit_measurements(circuit: 'cirq.Circuit') -> Iterator[MeasureInfo]:
qubits=_grid_qubits(op),
slot=i,
invert_mask=list(op.gate.full_invert_mask()),
tags=list(op.tags),
)


Expand Down
35 changes: 30 additions & 5 deletions cirq/google/api/v2/results_test.py
Expand Up @@ -18,7 +18,7 @@ def test_pack_bits(reps):
q = cirq.GridQubit # For brevity.


def _check_measurement(m, key, qubits, slot, invert_mask=None):
def _check_measurement(m, key, qubits, slot, invert_mask=None, tags=None):
assert m.key == key
assert m.qubits == qubits
assert m.slot == slot
Expand All @@ -27,6 +27,10 @@ def _check_measurement(m, key, qubits, slot, invert_mask=None):
else:
assert len(m.invert_mask) == len(m.qubits)
assert m.invert_mask == [False] * len(m.qubits)
if tags is not None:
assert m.tags == tags
else:
assert len(m.tags) == 0


def test_find_measurements_simple_circuit():
Expand All @@ -51,6 +55,27 @@ def test_find_measurements_invert_mask():
_check_measurement(m, 'k', [q(0, 0), q(0, 1), q(0, 2)], 0, [False, True, True])


def test_find_measurements_with_tags():
circuit = cirq.Circuit()
circuit.append(
cirq.measure(q(0, 0), q(0, 1), q(0, 2), key='k', invert_mask=[False, True, True]).with_tags(
cirq.google.CalibrationTag('special')
)
)
measurements = v2.find_measurements(circuit)

assert len(measurements) == 1
m = measurements[0]
_check_measurement(
m,
'k',
[q(0, 0), q(0, 1), q(0, 2)],
0,
[False, True, True],
[cirq.google.CalibrationTag('special')],
)


def test_find_measurements_fill_mask():
circuit = cirq.Circuit()
circuit.append(cirq.measure(q(0, 0), q(0, 1), q(0, 2), key='k', invert_mask=[False, True]))
Expand Down Expand Up @@ -110,7 +135,7 @@ def test_multiple_measurements_shared_slots():


def test_results_to_proto():
measurements = [v2.MeasureInfo('foo', [q(0, 0)], slot=0, invert_mask=[False])]
measurements = [v2.MeasureInfo('foo', [q(0, 0)], slot=0, invert_mask=[False], tags=[])]
trial_results = [
[
cirq.Result.from_single_parameter_set(
Expand Down Expand Up @@ -157,7 +182,7 @@ def test_results_to_proto():


def test_results_to_proto_sweep_repetitions():
measurements = [v2.MeasureInfo('foo', [q(0, 0)], slot=0, invert_mask=[False])]
measurements = [v2.MeasureInfo('foo', [q(0, 0)], slot=0, invert_mask=[False], tags=[])]
trial_results = [
[
cirq.Result.from_single_parameter_set(
Expand All @@ -181,7 +206,7 @@ def test_results_to_proto_sweep_repetitions():
def test_results_from_proto_qubit_ordering():
measurements = [
v2.MeasureInfo(
'foo', [q(0, 0), q(0, 1), q(1, 1)], slot=0, invert_mask=[False, False, False]
'foo', [q(0, 0), q(0, 1), q(1, 1)], slot=0, invert_mask=[False, False, False], tags=[]
)
]
proto = v2.result_pb2.Result()
Expand Down Expand Up @@ -225,7 +250,7 @@ def test_results_from_proto_qubit_ordering():
def test_results_from_proto_duplicate_qubit():
measurements = [
v2.MeasureInfo(
'foo', [q(0, 0), q(0, 1), q(1, 1)], slot=0, invert_mask=[False, False, False]
'foo', [q(0, 0), q(0, 1), q(1, 1)], slot=0, invert_mask=[False, False, False], tags=[]
)
]
proto = v2.result_pb2.Result()
Expand Down

0 comments on commit fbafbd5

Please sign in to comment.