Skip to content

Commit

Permalink
Disallow empty measurement keys and fix tests using empty keys (#4060)
Browse files Browse the repository at this point in the history
* Empty keys are illegal now

* merge

* fix duplicate post-inits
  • Loading branch information
smitsanghavi committed Aug 2, 2021
1 parent 9c23053 commit 409a412
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 40 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_device_test.py
Expand Up @@ -106,7 +106,7 @@ def test_validate_measurement_non_adjacent_qubits_ok():
d = ion_device(3)

d.validate_operation(
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.LineQubit(0), cirq.LineQubit(1)))
cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(1)))
)


Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py
Expand Up @@ -179,7 +179,7 @@ def test_validate_moment_errors():
m = cirq.Moment(cirq.X.on_each(*(d.qubit_list()[1:])))
with pytest.raises(ValueError, match="Bad number of simultaneous XY gates"):
d.validate_moment(m)
m = cirq.Moment([cirq.MeasurementGate(1).on(q00), cirq.Z.on(q01)])
m = cirq.Moment([cirq.MeasurementGate(1, 'a').on(q00), cirq.Z.on(q01)])
with pytest.raises(
ValueError, match="Measurements can't be simultaneous with other operations"
):
Expand Down
55 changes: 27 additions & 28 deletions cirq-core/cirq/ops/measurement_gate_test.py
Expand Up @@ -29,43 +29,43 @@ def test_eval_repr():

@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_measure_init(num_qubits):
assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits
assert cirq.MeasurementGate(num_qubits, 'a').num_qubits() == num_qubits
assert cirq.MeasurementGate(num_qubits, key='a').key == 'a'
assert cirq.MeasurementGate(num_qubits, key='a').mkey == cirq.MeasurementKey('a')
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')).key == 'a'
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')) == cirq.MeasurementGate(
num_qubits, key='a'
)
assert cirq.MeasurementGate(num_qubits, invert_mask=(True,)).invert_mask == (True,)
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits)) == (2,) * num_qubits
assert cirq.qid_shape(cirq.MeasurementGate(3, qid_shape=(1, 2, 3))) == (1, 2, 3)
assert cirq.qid_shape(cirq.MeasurementGate(qid_shape=(1, 2, 3))) == (1, 2, 3)
assert cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,)).invert_mask == (True,)
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits, 'a')) == (2,) * num_qubits
assert cirq.qid_shape(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))) == (1, 2, 3)
assert cirq.qid_shape(cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3))) == (1, 2, 3)
with pytest.raises(ValueError, match='len.* >'):
cirq.MeasurementGate(5, invert_mask=(True,) * 6)
cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6)
with pytest.raises(ValueError, match='len.* !='):
cirq.MeasurementGate(5, qid_shape=(1, 2))
cirq.MeasurementGate(5, 'a', qid_shape=(1, 2))
with pytest.raises(ValueError, match='cannot be empty'):
cirq.MeasurementGate(2, qid_shape=(1, 2))
with pytest.raises(ValueError, match='Specify either'):
cirq.MeasurementGate()


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_has_stabilizer_effect(num_qubits):
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits))
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))


def test_measurement_eq():
eq = cirq.testing.EqualsTester()
eq.make_equality_group(
lambda: cirq.MeasurementGate(1, ''),
lambda: cirq.MeasurementGate(1, '', invert_mask=()),
lambda: cirq.MeasurementGate(1, '', qid_shape=(2,)),
lambda: cirq.MeasurementGate(1, 'a'),
lambda: cirq.MeasurementGate(1, 'a', invert_mask=()),
lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)),
)
eq.add_equality_group(cirq.MeasurementGate(1, 'a'))
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(True,)))
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(False,)))
eq.add_equality_group(cirq.MeasurementGate(1, 'b'))
eq.add_equality_group(cirq.MeasurementGate(2, 'a'))
eq.add_equality_group(cirq.MeasurementGate(2, ''))
eq.add_equality_group(
cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2))
)
Expand Down Expand Up @@ -154,15 +154,14 @@ def test_qudit_measure_quil():

def test_measurement_gate_diagram():
# Shows key.
assert cirq.circuit_diagram_info(cirq.MeasurementGate(1)) == cirq.CircuitDiagramInfo(("M('')",))
assert cirq.circuit_diagram_info(
cirq.MeasurementGate(1, key='test')
) == cirq.CircuitDiagramInfo(("M('test')",))

# Uses known qubit count.
assert (
cirq.circuit_diagram_info(
cirq.MeasurementGate(3),
cirq.MeasurementGate(3, 'a'),
cirq.CircuitDiagramInfoArgs(
known_qubits=None,
known_qubit_count=3,
Expand All @@ -171,13 +170,13 @@ def test_measurement_gate_diagram():
qubit_map=None,
),
)
== cirq.CircuitDiagramInfo(("M('')", 'M', 'M'))
== cirq.CircuitDiagramInfo(("M('a')", 'M', 'M'))
)

# Shows invert mask.
assert cirq.circuit_diagram_info(
cirq.MeasurementGate(2, invert_mask=(False, True))
) == cirq.CircuitDiagramInfo(("M('')", "!M"))
cirq.MeasurementGate(2, 'a', invert_mask=(False, True))
) == cirq.CircuitDiagramInfo(("M('a')", "!M"))

# Omits key when it is the default.
a = cirq.NamedQubit('a')
Expand Down Expand Up @@ -210,12 +209,12 @@ def test_measurement_gate_diagram():

def test_measurement_channel():
np.testing.assert_allclose(
cirq.kraus(cirq.MeasurementGate(1)),
cirq.kraus(cirq.MeasurementGate(1, 'a')),
(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])),
)
# yapf: disable
np.testing.assert_allclose(
cirq.kraus(cirq.MeasurementGate(2)),
cirq.kraus(cirq.MeasurementGate(2, 'a')),
(np.array([[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
Expand All @@ -233,7 +232,7 @@ def test_measurement_channel():
[0, 0, 0, 0],
[0, 0, 0, 1]])))
np.testing.assert_allclose(
cirq.kraus(cirq.MeasurementGate(2, qid_shape=(2, 3))),
cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))),
(np.diag([1, 0, 0, 0, 0, 0]),
np.diag([0, 1, 0, 0, 0, 0]),
np.diag([0, 0, 1, 0, 0, 0]),
Expand All @@ -248,21 +247,21 @@ def test_measurement_qubit_count_vs_mask_length():
b = cirq.NamedQubit('b')
c = cirq.NamedQubit('c')

_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True,)).on(a)
_ = cirq.MeasurementGate(num_qubits=2, invert_mask=(True, False)).on(a, b)
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b, c)
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True,)).on(a)
_ = cirq.MeasurementGate(num_qubits=2, key='a', invert_mask=(True, False)).on(a, b)
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b, c)
with pytest.raises(ValueError):
_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True, False)).on(a)
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a)
with pytest.raises(ValueError):
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b)
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b)


def test_consistent_protocols():
for n in range(1, 5):
gate = cirq.MeasurementGate(num_qubits=n)
gate = cirq.MeasurementGate(num_qubits=n, key='a')
cirq.testing.assert_implements_consistent_protocols(gate)

gate = cirq.MeasurementGate(num_qubits=n, qid_shape=(3,) * n)
gate = cirq.MeasurementGate(num_qubits=n, key='a', qid_shape=(3,) * n)
cirq.testing.assert_implements_consistent_protocols(gate)


Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/sparse_simulator_test.py
Expand Up @@ -1226,7 +1226,7 @@ def test_separated_measurements():
cirq.H(a),
cirq.H(b),
cirq.CZ(a, b),
cirq.measure(a, key=''),
cirq.measure(a, key='a'),
cirq.CZ(a, b),
cirq.H(b),
cirq.measure(b, key='zero'),
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/value/measurement_key.py
Expand Up @@ -41,6 +41,8 @@ class MeasurementKey:
path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)

def __post_init__(self):
if not self.name:
raise ValueError("Measurement key name cannot be empty")
if MEASUREMENT_KEY_SEPARATOR in self.name:
raise ValueError(
f'Invalid key name: {self.name}\n{MEASUREMENT_KEY_SEPARATOR} is not allowed in '
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/value/measurement_key_test.py
Expand Up @@ -20,8 +20,8 @@
def test_empty_init():
with pytest.raises(TypeError, match='required positional argument'):
_ = cirq.MeasurementKey()
mkey = cirq.MeasurementKey('')
assert mkey.name == ''
with pytest.raises(ValueError, match='cannot be empty'):
_ = cirq.MeasurementKey('')


def test_nested_key():
Expand Down
6 changes: 4 additions & 2 deletions cirq-google/cirq_google/devices/xmon_device_test.py
Expand Up @@ -14,8 +14,8 @@

import pytest

import cirq
import cirq_google as cg
import cirq


def square_device(width: int, height: int, holes=()) -> cg.XmonDevice:
Expand Down Expand Up @@ -133,7 +133,9 @@ def test_validate_measurement_non_adjacent_qubits_ok():
d = square_device(3, 3)

d.validate_operation(
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0)))
cirq.GateOperation(
cirq.MeasurementGate(2, 'a'), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0))
)
)


Expand Down
6 changes: 3 additions & 3 deletions cirq-ionq/cirq_ionq/ionq_devices_test.py
Expand Up @@ -38,9 +38,9 @@
cirq.YY ** 0.5,
cirq.ZZ ** 0.5,
cirq.SWAP,
cirq.MeasurementGate(num_qubits=1),
cirq.MeasurementGate(num_qubits=2),
cirq.MeasurementGate(num_qubits=10),
cirq.MeasurementGate(num_qubits=1, key='a'),
cirq.MeasurementGate(num_qubits=2, key='b'),
cirq.MeasurementGate(num_qubits=10, key='c'),
)


Expand Down
4 changes: 2 additions & 2 deletions cirq-pasqal/cirq_pasqal/pasqal_device_test.py
Expand Up @@ -88,11 +88,11 @@ def test_decompose_error():

# MeasurementGate is not a GateOperation
with pytest.raises(TypeError):
d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2))
d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2, key='a'))
# It has to be made into one
assert d.is_pasqal_device_op(
cirq.ops.GateOperation(
cirq.ops.MeasurementGate(2), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')]
cirq.ops.MeasurementGate(2, 'b'), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')]
)
)

Expand Down

0 comments on commit 409a412

Please sign in to comment.