Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow empty measurement keys and fix tests using empty keys #4060

Merged
merged 4 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_validate_measurement_non_adjacent_qubits_ok():
d = ion_device(3)

d.validate_operation(
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.LineQubit(0), cirq.LineQubit(1)))
cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(1)))
)


Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_validate_moment_errors():
m = cirq.Moment(cirq.X.on_each(*(d.qubit_list()[1:])))
with pytest.raises(ValueError, match="Bad number of simultaneous XY gates"):
d.validate_moment(m)
m = cirq.Moment([cirq.MeasurementGate(1).on(q00), cirq.Z.on(q01)])
m = cirq.Moment([cirq.MeasurementGate(1, 'a').on(q00), cirq.Z.on(q01)])
with pytest.raises(
ValueError, match="Measurements can't be simultaneous with other operations"
):
Expand Down
55 changes: 27 additions & 28 deletions cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,43 +29,43 @@ def test_eval_repr():

@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_measure_init(num_qubits):
assert cirq.MeasurementGate(num_qubits).num_qubits() == num_qubits
assert cirq.MeasurementGate(num_qubits, 'a').num_qubits() == num_qubits
assert cirq.MeasurementGate(num_qubits, key='a').key == 'a'
assert cirq.MeasurementGate(num_qubits, key='a').mkey == cirq.MeasurementKey('a')
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')).key == 'a'
assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')) == cirq.MeasurementGate(
num_qubits, key='a'
)
assert cirq.MeasurementGate(num_qubits, invert_mask=(True,)).invert_mask == (True,)
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits)) == (2,) * num_qubits
assert cirq.qid_shape(cirq.MeasurementGate(3, qid_shape=(1, 2, 3))) == (1, 2, 3)
assert cirq.qid_shape(cirq.MeasurementGate(qid_shape=(1, 2, 3))) == (1, 2, 3)
assert cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,)).invert_mask == (True,)
assert cirq.qid_shape(cirq.MeasurementGate(num_qubits, 'a')) == (2,) * num_qubits
assert cirq.qid_shape(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))) == (1, 2, 3)
assert cirq.qid_shape(cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3))) == (1, 2, 3)
with pytest.raises(ValueError, match='len.* >'):
cirq.MeasurementGate(5, invert_mask=(True,) * 6)
cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6)
with pytest.raises(ValueError, match='len.* !='):
cirq.MeasurementGate(5, qid_shape=(1, 2))
cirq.MeasurementGate(5, 'a', qid_shape=(1, 2))
with pytest.raises(ValueError, match='cannot be empty'):
cirq.MeasurementGate(2, qid_shape=(1, 2))
with pytest.raises(ValueError, match='Specify either'):
cirq.MeasurementGate()


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_has_stabilizer_effect(num_qubits):
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits))
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))


def test_measurement_eq():
eq = cirq.testing.EqualsTester()
eq.make_equality_group(
lambda: cirq.MeasurementGate(1, ''),
lambda: cirq.MeasurementGate(1, '', invert_mask=()),
lambda: cirq.MeasurementGate(1, '', qid_shape=(2,)),
lambda: cirq.MeasurementGate(1, 'a'),
lambda: cirq.MeasurementGate(1, 'a', invert_mask=()),
lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)),
)
eq.add_equality_group(cirq.MeasurementGate(1, 'a'))
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(True,)))
eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(False,)))
eq.add_equality_group(cirq.MeasurementGate(1, 'b'))
eq.add_equality_group(cirq.MeasurementGate(2, 'a'))
eq.add_equality_group(cirq.MeasurementGate(2, ''))
eq.add_equality_group(
cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2))
)
Expand Down Expand Up @@ -154,15 +154,14 @@ def test_qudit_measure_quil():

def test_measurement_gate_diagram():
# Shows key.
assert cirq.circuit_diagram_info(cirq.MeasurementGate(1)) == cirq.CircuitDiagramInfo(("M('')",))
assert cirq.circuit_diagram_info(
cirq.MeasurementGate(1, key='test')
) == cirq.CircuitDiagramInfo(("M('test')",))

# Uses known qubit count.
assert (
cirq.circuit_diagram_info(
cirq.MeasurementGate(3),
cirq.MeasurementGate(3, 'a'),
cirq.CircuitDiagramInfoArgs(
known_qubits=None,
known_qubit_count=3,
Expand All @@ -171,13 +170,13 @@ def test_measurement_gate_diagram():
qubit_map=None,
),
)
== cirq.CircuitDiagramInfo(("M('')", 'M', 'M'))
== cirq.CircuitDiagramInfo(("M('a')", 'M', 'M'))
)

# Shows invert mask.
assert cirq.circuit_diagram_info(
cirq.MeasurementGate(2, invert_mask=(False, True))
) == cirq.CircuitDiagramInfo(("M('')", "!M"))
cirq.MeasurementGate(2, 'a', invert_mask=(False, True))
) == cirq.CircuitDiagramInfo(("M('a')", "!M"))

# Omits key when it is the default.
a = cirq.NamedQubit('a')
Expand Down Expand Up @@ -210,12 +209,12 @@ def test_measurement_gate_diagram():

def test_measurement_channel():
np.testing.assert_allclose(
cirq.kraus(cirq.MeasurementGate(1)),
cirq.kraus(cirq.MeasurementGate(1, 'a')),
(np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])),
)
# yapf: disable
np.testing.assert_allclose(
cirq.kraus(cirq.MeasurementGate(2)),
cirq.kraus(cirq.MeasurementGate(2, 'a')),
(np.array([[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
Expand All @@ -233,7 +232,7 @@ def test_measurement_channel():
[0, 0, 0, 0],
[0, 0, 0, 1]])))
np.testing.assert_allclose(
cirq.kraus(cirq.MeasurementGate(2, qid_shape=(2, 3))),
cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))),
(np.diag([1, 0, 0, 0, 0, 0]),
np.diag([0, 1, 0, 0, 0, 0]),
np.diag([0, 0, 1, 0, 0, 0]),
Expand All @@ -248,21 +247,21 @@ def test_measurement_qubit_count_vs_mask_length():
b = cirq.NamedQubit('b')
c = cirq.NamedQubit('c')

_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True,)).on(a)
_ = cirq.MeasurementGate(num_qubits=2, invert_mask=(True, False)).on(a, b)
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b, c)
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True,)).on(a)
_ = cirq.MeasurementGate(num_qubits=2, key='a', invert_mask=(True, False)).on(a, b)
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b, c)
with pytest.raises(ValueError):
_ = cirq.MeasurementGate(num_qubits=1, invert_mask=(True, False)).on(a)
_ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a)
with pytest.raises(ValueError):
_ = cirq.MeasurementGate(num_qubits=3, invert_mask=(True, False, True)).on(a, b)
_ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b)


def test_consistent_protocols():
for n in range(1, 5):
gate = cirq.MeasurementGate(num_qubits=n)
gate = cirq.MeasurementGate(num_qubits=n, key='a')
cirq.testing.assert_implements_consistent_protocols(gate)

gate = cirq.MeasurementGate(num_qubits=n, qid_shape=(3,) * n)
gate = cirq.MeasurementGate(num_qubits=n, key='a', qid_shape=(3,) * n)
cirq.testing.assert_implements_consistent_protocols(gate)


Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def test_separated_measurements():
cirq.H(a),
cirq.H(b),
cirq.CZ(a, b),
cirq.measure(a, key=''),
cirq.measure(a, key='a'),
cirq.CZ(a, b),
cirq.H(b),
cirq.measure(b, key='zero'),
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/value/measurement_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class MeasurementKey:
path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)

def __post_init__(self):
if not self.name:
raise ValueError("Measurement key name cannot be empty")
if MEASUREMENT_KEY_SEPARATOR in self.name:
raise ValueError(
f'Invalid key name: {self.name}\n{MEASUREMENT_KEY_SEPARATOR} is not allowed in '
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/value/measurement_key_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
def test_empty_init():
with pytest.raises(TypeError, match='required positional argument'):
_ = cirq.MeasurementKey()
mkey = cirq.MeasurementKey('')
assert mkey.name == ''
with pytest.raises(ValueError, match='cannot be empty'):
_ = cirq.MeasurementKey('')


def test_nested_key():
Expand Down
6 changes: 4 additions & 2 deletions cirq-google/cirq_google/devices/xmon_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import pytest

import cirq
import cirq_google as cg
import cirq


def square_device(width: int, height: int, holes=()) -> cg.XmonDevice:
Expand Down Expand Up @@ -133,7 +133,9 @@ def test_validate_measurement_non_adjacent_qubits_ok():
d = square_device(3, 3)

d.validate_operation(
cirq.GateOperation(cirq.MeasurementGate(2), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0)))
cirq.GateOperation(
cirq.MeasurementGate(2, 'a'), (cirq.GridQubit(0, 0), cirq.GridQubit(2, 0))
)
)


Expand Down
6 changes: 3 additions & 3 deletions cirq-ionq/cirq_ionq/ionq_devices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
cirq.YY ** 0.5,
cirq.ZZ ** 0.5,
cirq.SWAP,
cirq.MeasurementGate(num_qubits=1),
cirq.MeasurementGate(num_qubits=2),
cirq.MeasurementGate(num_qubits=10),
cirq.MeasurementGate(num_qubits=1, key='a'),
cirq.MeasurementGate(num_qubits=2, key='b'),
cirq.MeasurementGate(num_qubits=10, key='c'),
)


Expand Down
4 changes: 2 additions & 2 deletions cirq-pasqal/cirq_pasqal/pasqal_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def test_decompose_error():

# MeasurementGate is not a GateOperation
with pytest.raises(TypeError):
d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2))
d.decompose_operation(cirq.ops.MeasurementGate(num_qubits=2, key='a'))
# It has to be made into one
assert d.is_pasqal_device_op(
cirq.ops.GateOperation(
cirq.ops.MeasurementGate(2), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')]
cirq.ops.MeasurementGate(2, 'b'), [cirq.NamedQubit('q0'), cirq.NamedQubit('q1')]
)
)

Expand Down