Skip to content

Commit

Permalink
Add MeasurementGate.with_key to change measurement key (#3123)
Browse files Browse the repository at this point in the history
The "rekey" operation is useful for some circuit transformations we want to implement.
Defining it on the MeasurementGate ensures that all the gate properties get copied properly when rekeying.

Review: @balopat
  • Loading branch information
maffoo committed Jul 7, 2020
1 parent b37f226 commit 4e44446
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
16 changes: 3 additions & 13 deletions cirq/google/api/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class MeasureInfo(
"""


def find_measurements(program: 'cirq.Circuit',) -> List[MeasureInfo]:
def find_measurements(program: 'cirq.Circuit') -> List[MeasureInfo]:
"""Find measurements in the given program (circuit).
Returns:
Expand All @@ -77,12 +77,11 @@ def find_measurements(program: 'cirq.Circuit',) -> List[MeasureInfo]:
def _circuit_measurements(circuit: 'cirq.Circuit') -> Iterator[MeasureInfo]:
for i, moment in enumerate(circuit):
for op in moment:
if (isinstance(op, ops.GateOperation) and
isinstance(op.gate, ops.MeasurementGate)):
if isinstance(op.gate, ops.MeasurementGate):
yield MeasureInfo(key=op.gate.key,
qubits=_grid_qubits(op),
slot=i,
invert_mask=_full_mask(op))
invert_mask=list(op.gate.full_invert_mask()))


def _grid_qubits(op: 'cirq.Operation') -> List['cirq.GridQubit']:
Expand All @@ -91,15 +90,6 @@ def _grid_qubits(op: 'cirq.Operation') -> List['cirq.GridQubit']:
return cast(List['cirq.GridQubit'], list(op.qubits))


def _full_mask(op: 'cirq.GateOperation') -> List[bool]:
invert_mask = list(cast(ops.MeasurementGate, op.gate).invert_mask)
len_missing_mask = len(op.qubits) - len(invert_mask)
if len_missing_mask > 0:
return invert_mask + [False] * len_missing_mask
else:
return invert_mask


def pack_bits(bits: np.ndarray) -> bytes:
"""Pack bits given as a numpy array of bools into bytes."""
# Pad length to multiple of 8 if needed.
Expand Down
10 changes: 9 additions & 1 deletion cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def __init__(self,
def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def with_key(self, key: str) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
return MeasurementGate(self.num_qubits(),
key=key,
invert_mask=self.invert_mask,
qid_shape=self._qid_shape)

def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
"""Toggles whether or not the measurement inverts various outputs."""
old_mask = self.invert_mask or ()
Expand All @@ -81,7 +88,8 @@ def with_bits_flipped(self, *bit_positions: int) -> 'MeasurementGate':
new_mask[b] = not new_mask[b]
return MeasurementGate(self.num_qubits(),
key=self.key,
invert_mask=tuple(new_mask))
invert_mask=tuple(new_mask),
qid_shape=self._qid_shape)

def full_invert_mask(self):
"""Returns the invert mask for all qubits.
Expand Down
38 changes: 38 additions & 0 deletions cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,44 @@ def test_measurement_full_invert_mask():
2, 'a', invert_mask=(True,)).full_invert_mask() == (True, False))


@pytest.mark.parametrize('gate', [
cirq.MeasurementGate(1, 'a'),
cirq.MeasurementGate(1, 'a', invert_mask=(True,)),
cirq.MeasurementGate(1, 'a', qid_shape=(3,)),
cirq.MeasurementGate(2, 'a', invert_mask=(True, False), qid_shape=(2, 3)),
])
def test_measurement_with_key(gate):
gate1 = gate.with_key('b')
assert gate1.key == 'b'
assert gate1.num_qubits() == gate.num_qubits()
assert gate1.invert_mask == gate.invert_mask
assert cirq.qid_shape(gate1) == cirq.qid_shape(gate)
gate2 = gate1.with_key('a')
assert gate2 == gate


@pytest.mark.parametrize('num_qubits, mask, bits, flipped', [
(1, (), [0], (True,)),
(3, (False,), [1], (False, True)),
(3, (False, False), [0, 2], (True, False, True)),
])
def test_measurement_with_bits_flipped(num_qubits, mask, bits, flipped):
gate = cirq.MeasurementGate(num_qubits,
key='a',
invert_mask=mask,
qid_shape=(3,) * num_qubits)

gate1 = gate.with_bits_flipped(*bits)
assert gate1.key == gate.key
assert gate1.num_qubits() == gate.num_qubits()
assert gate1.invert_mask == flipped
assert cirq.qid_shape(gate1) == cirq.qid_shape(gate)

# Flipping bits again restores the mask (but may have extended it).
gate2 = gate1.with_bits_flipped(*bits)
assert gate2.full_invert_mask() == gate.full_invert_mask()


def test_qudit_measure_qasm():
assert cirq.qasm(cirq.measure(cirq.LineQid(0, 3), key='a'),
args=cirq.QasmArgs(),
Expand Down

0 comments on commit 4e44446

Please sign in to comment.