Skip to content

Commit

Permalink
Guard confusion_map usage (#5534)
Browse files Browse the repository at this point in the history
Fixes the breakage created by #5480.
  • Loading branch information
95-martin-orion committed Jun 16, 2022
1 parent f4f9ac6 commit 25388af
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
14 changes: 13 additions & 1 deletion cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,19 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq

if not isinstance(sim_state, SimulationState):
return NotImplemented
sim_state.measure(qubits, self.key, self.full_invert_mask(), self.confusion_map)
try:
sim_state.measure(
qubits, self.key, self.full_invert_mask(), confusion_map=self.confusion_map
)
except TypeError as e:
# Ensure that the error was due to confusion_map.
if not any("unexpected keyword argument 'confusion_map'" in arg for arg in e.args):
raise
_compat._warn_or_error(
"Starting in v0.16, SimulationState subclasses will be required to accept "
"a 'confusion_map' argument. See SimulationState.measure for details."
)
sim_state.measure(qubits, self.key, self.full_invert_mask())
return True


Expand Down
57 changes: 56 additions & 1 deletion cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast
import numpy as np
import pytest

import cirq
from cirq.type_workarounds import NotImplementedType


@pytest.mark.parametrize(
Expand Down Expand Up @@ -529,3 +530,57 @@ def test_act_on_qutrit():
)
cirq.act_on(m, args)
assert args.log_of_measurement_results == {'out': [0, 0]}


def test_act_on_no_confusion_map_deprecated():
class OldSimState(cirq.StateVectorSimulationState):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.measured = False

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> Union[bool, NotImplementedType]:
return NotImplemented # coverage: ignore

def measure( # type: ignore
self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]
):
self.measured = True

qubits = cirq.LineQubit.range(2)
old_state = OldSimState(qubits=qubits)
m = cirq.measure(*qubits, key='test')
with cirq.testing.assert_deprecated('confusion_map', deadline='v0.16'):
cirq.act_on(m, old_state)
assert old_state.measured


def test_act_on_no_confusion_map_scope_limited():
error_msg = "error from deeper in measure"

class ErrorProneSimState(cirq.StateVectorSimulationState):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.measured = False

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> Union[bool, NotImplementedType]:
return NotImplemented # coverage: ignore

def measure(
self,
qubits: Sequence['cirq.Qid'],
key: str,
invert_mask: Sequence[bool],
confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None,
):
raise TypeError(error_msg)

# Verify that the check doesn't prevent other errors from being raised
qubits = cirq.LineQubit.range(2)
sv_state = ErrorProneSimState(qubits=qubits)
m = cirq.measure(*qubits, key='test')
with pytest.raises(TypeError, match=error_msg):
cirq.act_on(m, sv_state)
5 changes: 3 additions & 2 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def measure(
qubits: Sequence['cirq.Qid'],
key: str,
invert_mask: Sequence[bool],
confusion_map: Dict[Tuple[int, ...], np.ndarray],
confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None,
):
"""Measures the qubits and records to `log_of_measurement_results`.
Expand All @@ -123,7 +123,8 @@ def measure(
ValueError: If a measurement key has already been logged to a key.
"""
bits = self._perform_measurement(qubits)
confused = self._confuse_result(bits, qubits, confusion_map)
if confusion_map is not None:
confused = self._confuse_result(bits, qubits, confusion_map)
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(confused, invert_mask)]

This comment has been minimized.

Copy link
@pavoljuhas

pavoljuhas Jul 8, 2022

Collaborator

@95-martin-orion - looks like confused would be undefined on line 128 when confusion_map is None.

self._classical_data.record_measurement(
value.MeasurementKey.parse_serialized(key), corrected, qubits
Expand Down

0 comments on commit 25388af

Please sign in to comment.