diff --git a/cirq-core/cirq/ion/ion_device_test.py b/cirq-core/cirq/ion/ion_device_test.py index 94736d346f0..8a79beb13e9 100644 --- a/cirq-core/cirq/ion/ion_device_test.py +++ b/cirq-core/cirq/ion/ion_device_test.py @@ -106,7 +106,7 @@ def test_validate_measurement_non_adjacent_qubits_ok(): d = ion_device(3) d.validate_operation( - cirq.GateOperation(cirq.MeasurementGate(2), (cirq.LineQubit(0), cirq.LineQubit(1))) + cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(1))) ) diff --git a/cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py b/cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py index 360c51040f1..e82b04d71d1 100644 --- a/cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py +++ b/cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py @@ -179,7 +179,7 @@ def test_validate_moment_errors(): m = cirq.Moment(cirq.X.on_each(*(d.qubit_list()[1:]))) with pytest.raises(ValueError, match="Bad number of simultaneous XY gates"): d.validate_moment(m) - m = cirq.Moment([cirq.MeasurementGate(1).on(q00), cirq.Z.on(q01)]) + m = cirq.Moment([cirq.MeasurementGate(1, 'a').on(q00), cirq.Z.on(q01)]) with pytest.raises( ValueError, match="Measurements can't be simultaneous with other operations" ): diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index 7d08f7cb605..b6621d29718 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -29,43 +29,43 @@ def test_eval_repr(): @pytest.mark.parametrize('num_qubits', [1, 2, 4]) def test_measure_init(num_qubits): - assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits + assert cirq.MeasurementGate(num_qubits, 'a').num_qubits() == num_qubits assert cirq.MeasurementGate(num_qubits, key='a').key == 'a' assert cirq.MeasurementGate(num_qubits, key='a').mkey == cirq.MeasurementKey('a') assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')).key == 'a' assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')) == cirq.MeasurementGate( num_qubits, key='a' ) - assert cirq.MeasurementGate(num_qubits, invert_mask=(True,)).invert_mask == (True,) - assert cirq.qid_shape(cirq.MeasurementGate(num_qubits)) == (2,) * num_qubits - assert cirq.qid_shape(cirq.MeasurementGate(3, qid_shape=(1, 2, 3))) == (1, 2, 3) - assert cirq.qid_shape(cirq.MeasurementGate(qid_shape=(1, 2, 3))) == (1, 2, 3) + assert cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,)).invert_mask == (True,) + assert cirq.qid_shape(cirq.MeasurementGate(num_qubits, 'a')) == (2,) * num_qubits + assert cirq.qid_shape(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))) == (1, 2, 3) + assert cirq.qid_shape(cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3))) == (1, 2, 3) with pytest.raises(ValueError, match='len.* >'): - cirq.MeasurementGate(5, invert_mask=(True,) * 6) + cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6) with pytest.raises(ValueError, match='len.* !='): - cirq.MeasurementGate(5, qid_shape=(1, 2)) + cirq.MeasurementGate(5, 'a', qid_shape=(1, 2)) + with pytest.raises(ValueError, match='cannot be empty'): + cirq.MeasurementGate(2, qid_shape=(1, 2)) with pytest.raises(ValueError, match='Specify either'): cirq.MeasurementGate() @pytest.mark.parametrize('num_qubits', [1, 2, 4]) def test_has_stabilizer_effect(num_qubits): - assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits)) + assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a')) def test_measurement_eq(): eq = cirq.testing.EqualsTester() eq.make_equality_group( - lambda: cirq.MeasurementGate(1, ''), - lambda: cirq.MeasurementGate(1, '', invert_mask=()), - lambda: cirq.MeasurementGate(1, '', qid_shape=(2,)), + lambda: cirq.MeasurementGate(1, 'a'), + lambda: cirq.MeasurementGate(1, 'a', invert_mask=()), + lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)), ) - eq.add_equality_group(cirq.MeasurementGate(1, 'a')) eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(True,))) eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(False,))) eq.add_equality_group(cirq.MeasurementGate(1, 'b')) eq.add_equality_group(cirq.MeasurementGate(2, 'a')) - eq.add_equality_group(cirq.MeasurementGate(2, '')) eq.add_equality_group( cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2)) ) @@ -154,7 +154,6 @@ def test_qudit_measure_quil(): def test_measurement_gate_diagram(): # Shows key. - assert cirq.circuit_diagram_info(cirq.MeasurementGate(1)) == cirq.CircuitDiagramInfo(("M('')",)) assert cirq.circuit_diagram_info( cirq.MeasurementGate(1, key='test') ) == cirq.CircuitDiagramInfo(("M('test')",)) @@ -162,7 +161,7 @@ def test_measurement_gate_diagram(): # Uses known qubit count. assert ( cirq.circuit_diagram_info( - cirq.MeasurementGate(3), + cirq.MeasurementGate(3, 'a'), cirq.CircuitDiagramInfoArgs( known_qubits=None, known_qubit_count=3, @@ -171,13 +170,13 @@ def test_measurement_gate_diagram(): qubit_map=None, ), ) - == cirq.CircuitDiagramInfo(("M('')", 'M', 'M')) + == cirq.CircuitDiagramInfo(("M('a')", 'M', 'M')) ) # Shows invert mask. assert cirq.circuit_diagram_info( - cirq.MeasurementGate(2, invert_mask=(False, True)) - ) == cirq.CircuitDiagramInfo(("M('')", "!M")) + cirq.MeasurementGate(2, 'a', invert_mask=(False, True)) + ) == cirq.CircuitDiagramInfo(("M('a')", "!M")) # Omits key when it is the default. a = cirq.NamedQubit('a') @@ -210,12 +209,12 @@ def test_measurement_gate_diagram(): def test_measurement_channel(): np.testing.assert_allclose( - cirq.kraus(cirq.MeasurementGate(1)), + cirq.kraus(cirq.MeasurementGate(1, 'a')), (np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])), ) # yapf: disable np.testing.assert_allclose( - cirq.kraus(cirq.MeasurementGate(2)), + cirq.kraus(cirq.MeasurementGate(2, 'a')), (np.array([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], @@ -233,7 +232,7 @@ def test_measurement_channel(): [0, 0, 0, 0], [0, 0, 0, 1]]))) np.testing.assert_allclose( - cirq.kraus(cirq.MeasurementGate(2, qid_shape=(2, 3))), + cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))), (np.diag([1, 0, 0, 0, 0, 0]), np.diag([0, 1, 0, 0, 0, 0]), np.diag([0, 0, 1, 0, 0, 0]), @@ -248,21 +247,21 @@ def test_measurement_qubit_count_vs_mask_length(): b = cirq.NamedQubit('b') c = cirq.NamedQubit('c') - _ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True,)).on(a) - _ = cirq.MeasurementGate(num_qubits=2, invert_mask=(True, False)).on(a, b) - _ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b, c) + _ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True,)).on(a) + _ = cirq.MeasurementGate(num_qubits=2, key='a', invert_mask=(True, False)).on(a, b) + _ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b, c) with pytest.raises(ValueError): - _ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True, False)).on(a) + _ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a) with pytest.raises(ValueError): - _ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b) + _ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b) def test_consistent_protocols(): for n in range(1, 5): - gate = cirq.MeasurementGate(num_qubits=n) + gate = cirq.MeasurementGate(num_qubits=n, key='a') cirq.testing.assert_implements_consistent_protocols(gate) - gate = cirq.MeasurementGate(num_qubits=n, qid_shape=(3,) * n) + gate = cirq.MeasurementGate(num_qubits=n, key='a', qid_shape=(3,) * n) cirq.testing.assert_implements_consistent_protocols(gate) diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index d99c14d7c1d..dfe21d8e6e4 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -1226,7 +1226,7 @@ def test_separated_measurements(): cirq.H(a), cirq.H(b), cirq.CZ(a, b), - cirq.measure(a, key=''), + cirq.measure(a, key='a'), cirq.CZ(a, b), cirq.H(b), cirq.measure(b, key='zero'), diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index ecd03e00aaa..19cf2f4d146 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -41,6 +41,8 @@ class MeasurementKey: path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) def __post_init__(self): + if not self.name: + raise ValueError("Measurement key name cannot be empty") if MEASUREMENT_KEY_SEPARATOR in self.name: raise ValueError( f'Invalid key name: {self.name}\n{MEASUREMENT_KEY_SEPARATOR} is not allowed in ' diff --git a/cirq-core/cirq/value/measurement_key_test.py b/cirq-core/cirq/value/measurement_key_test.py index 08dfad85a8d..34402c9fcff 100644 --- a/cirq-core/cirq/value/measurement_key_test.py +++ b/cirq-core/cirq/value/measurement_key_test.py @@ -20,8 +20,8 @@ def test_empty_init(): with pytest.raises(TypeError, match='required positional argument'): _ = cirq.MeasurementKey() - mkey = cirq.MeasurementKey('') - assert mkey.name == '' + with pytest.raises(ValueError, match='cannot be empty'): + _ = cirq.MeasurementKey('') def test_nested_key(): diff --git a/cirq-google/cirq_google/devices/xmon_device_test.py b/cirq-google/cirq_google/devices/xmon_device_test.py index 747af4cf0ae..e1a5ba15339 100644 --- a/cirq-google/cirq_google/devices/xmon_device_test.py +++ b/cirq-google/cirq_google/devices/xmon_device_test.py @@ -14,8 +14,8 @@ import pytest -import cirq import cirq_google as cg +import cirq def square_device(width: int, height: int, holes=()) -> cg.XmonDevice: @@ -133,7 +133,9 @@ def test_validate_measurement_non_adjacent_qubits_ok(): d = square_device(3, 3) d.validate_operation( - cirq.GateOperation(cirq.MeasurementGate(2), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0))) + cirq.GateOperation( + cirq.MeasurementGate(2, 'a'), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0)) + ) ) diff --git a/cirq-ionq/cirq_ionq/ionq_devices_test.py b/cirq-ionq/cirq_ionq/ionq_devices_test.py index 60b86c260a4..355c3122148 100644 --- a/cirq-ionq/cirq_ionq/ionq_devices_test.py +++ b/cirq-ionq/cirq_ionq/ionq_devices_test.py @@ -38,9 +38,9 @@ cirq.YY ** 0.5, cirq.ZZ ** 0.5, cirq.SWAP, - cirq.MeasurementGate(num_qubits=1), - cirq.MeasurementGate(num_qubits=2), - cirq.MeasurementGate(num_qubits=10), + cirq.MeasurementGate(num_qubits=1, key='a'), + cirq.MeasurementGate(num_qubits=2, key='b'), + cirq.MeasurementGate(num_qubits=10, key='c'), ) diff --git a/cirq-pasqal/cirq_pasqal/pasqal_device_test.py b/cirq-pasqal/cirq_pasqal/pasqal_device_test.py index 9bcc13f031e..4c0773f1653 100644 --- a/cirq-pasqal/cirq_pasqal/pasqal_device_test.py +++ b/cirq-pasqal/cirq_pasqal/pasqal_device_test.py @@ -88,11 +88,11 @@ def test_decompose_error(): # MeasurementGate is not a GateOperation with pytest.raises(TypeError): - d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2)) + d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2, key='a')) # It has to be made into one assert d.is_pasqal_device_op( cirq.ops.GateOperation( - cirq.ops.MeasurementGate(2), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')] + cirq.ops.MeasurementGate(2, 'b'), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')] ) )