From 11ae0bd3f5aa3586f440560148ff7deb562cae9b Mon Sep 17 00:00:00 2001 From: Shef Date: Fri, 10 Nov 2023 13:45:31 -0500 Subject: [PATCH] Make `RouteCQC` errorout on intermediate measurements on 3+ qubits (#6307) --- .../transformers/routing/route_circuit_cqc.py | 19 +++++++++-- .../routing/route_circuit_cqc_test.py | 32 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/transformers/routing/route_circuit_cqc.py b/cirq-core/cirq/transformers/routing/route_circuit_cqc.py index 66d86245c9e..fdf6eda1fea 100644 --- a/cirq-core/cirq/transformers/routing/route_circuit_cqc.py +++ b/cirq-core/cirq/transformers/routing/route_circuit_cqc.py @@ -15,7 +15,7 @@ """Heuristic qubit routing algorithm based on arxiv:1902.08091.""" from typing import Any, Dict, List, Optional, Set, Sequence, Tuple, TYPE_CHECKING -from itertools import combinations +import itertools import networkx as nx from cirq import circuits, ops, protocols @@ -48,7 +48,9 @@ def _disjoint_nc2_combinations( Returns: All 2-combinations between qubit pairs that are disjoint. """ - return [pair for pair in combinations(qubit_pairs, 2) if set(pair[0]).isdisjoint(pair[1])] + return [ + pair for pair in itertools.combinations(qubit_pairs, 2) if set(pair[0]).isdisjoint(pair[1]) + ] @transformer_api.transformer @@ -245,9 +247,20 @@ def _get_one_and_two_qubit_ops_as_timesteps( The i'th entry in the nested two-qubit and single-qubit ops correspond to the two-qubit gates and single-qubit gates of the i'th timesteps respectively. When constructing the output routed circuit, single-qubit operations are inserted before two-qubit operations. + + Raises: + ValueError: if circuit has intermediate measurement op's that act on 3 or more qubits. """ two_qubit_circuit = circuits.Circuit() single_qubit_ops: List[List[cirq.Operation]] = [] + + if any( + protocols.num_qubits(op) > 2 and protocols.is_measurement(op) + for op in itertools.chain(*circuit.moments[:-1]) + ): + # There is at least one non-terminal measurement on 3+ qubits + raise ValueError('Non-terminal measurements on three or more qubits are not supported') + for moment in circuit: for op in moment: timestep = two_qubit_circuit.earliest_available_moment(op) @@ -255,7 +268,7 @@ def _get_one_and_two_qubit_ops_as_timesteps( two_qubit_circuit.append( circuits.Moment() for _ in range(timestep + 1 - len(two_qubit_circuit)) ) - if protocols.num_qubits(op) == 2 and not protocols.is_measurement(op): + if protocols.num_qubits(op) == 2: two_qubit_circuit[timestep] = two_qubit_circuit[timestep].with_operation(op) else: single_qubit_ops[timestep].append(op) diff --git a/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py b/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py index dcd55b3226c..a07161da936 100644 --- a/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py +++ b/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py @@ -107,6 +107,38 @@ def test_circuit_with_measurement_gates(): cirq.testing.assert_same_circuits(routed_circuit, circuit) +def test_circuit_with_valid_intermediate_multi_qubit_measurement_gates(): + device = cirq.testing.construct_ring_device(3) + device_graph = device.metadata.nx_graph + router = cirq.RouteCQC(device_graph) + q = cirq.LineQubit.range(2) + hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(2)}) + + valid_circuit = cirq.Circuit(cirq.measure_each(*q), cirq.H.on_each(q)) + + c_routed = router( + valid_circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) + ) + device.validate_circuit(c_routed) + + +def test_circuit_with_invalid_intermediate_multi_qubit_measurement_gates(): + device = cirq.testing.construct_ring_device(3) + device_graph = device.metadata.nx_graph + router = cirq.RouteCQC(device_graph) + q = cirq.LineQubit.range(3) + hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) + + invalid_circuit = cirq.Circuit(cirq.MeasurementGate(3).on(*q), cirq.H.on_each(*q)) + + with pytest.raises(ValueError): + _ = router( + invalid_circuit, + initial_mapper=hard_coded_mapper, + context=cirq.TransformerContext(deep=True), + ) + + def test_circuit_with_non_unitary_and_global_phase(): device = cirq.testing.construct_ring_device(4) device_graph = device.metadata.nx_graph