Skip to content

Commit

Permalink
Make RouteCQC errorout on intermediate measurements on 3+ qubits (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shef4 committed Nov 10, 2023
1 parent aa312bc commit 11ae0bd
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
19 changes: 16 additions & 3 deletions cirq-core/cirq/transformers/routing/route_circuit_cqc.py
Expand Up @@ -15,7 +15,7 @@
"""Heuristic qubit routing algorithm based on arxiv:1902.08091."""

from typing import Any, Dict, List, Optional, Set, Sequence, Tuple, TYPE_CHECKING
from itertools import combinations
import itertools
import networkx as nx

from cirq import circuits, ops, protocols
Expand Down Expand Up @@ -48,7 +48,9 @@ def _disjoint_nc2_combinations(
Returns:
All 2-combinations between qubit pairs that are disjoint.
"""
return [pair for pair in combinations(qubit_pairs, 2) if set(pair[0]).isdisjoint(pair[1])]
return [
pair for pair in itertools.combinations(qubit_pairs, 2) if set(pair[0]).isdisjoint(pair[1])
]


@transformer_api.transformer
Expand Down Expand Up @@ -245,17 +247,28 @@ def _get_one_and_two_qubit_ops_as_timesteps(
The i'th entry in the nested two-qubit and single-qubit ops correspond to the two-qubit
gates and single-qubit gates of the i'th timesteps respectively. When constructing the
output routed circuit, single-qubit operations are inserted before two-qubit operations.
Raises:
ValueError: if circuit has intermediate measurement op's that act on 3 or more qubits.
"""
two_qubit_circuit = circuits.Circuit()
single_qubit_ops: List[List[cirq.Operation]] = []

if any(
protocols.num_qubits(op) > 2 and protocols.is_measurement(op)
for op in itertools.chain(*circuit.moments[:-1])
):
# There is at least one non-terminal measurement on 3+ qubits
raise ValueError('Non-terminal measurements on three or more qubits are not supported')

for moment in circuit:
for op in moment:
timestep = two_qubit_circuit.earliest_available_moment(op)
single_qubit_ops.extend([] for _ in range(timestep + 1 - len(single_qubit_ops)))
two_qubit_circuit.append(
circuits.Moment() for _ in range(timestep + 1 - len(two_qubit_circuit))
)
if protocols.num_qubits(op) == 2 and not protocols.is_measurement(op):
if protocols.num_qubits(op) == 2:
two_qubit_circuit[timestep] = two_qubit_circuit[timestep].with_operation(op)
else:
single_qubit_ops[timestep].append(op)
Expand Down
32 changes: 32 additions & 0 deletions cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py
Expand Up @@ -107,6 +107,38 @@ def test_circuit_with_measurement_gates():
cirq.testing.assert_same_circuits(routed_circuit, circuit)


def test_circuit_with_valid_intermediate_multi_qubit_measurement_gates():
device = cirq.testing.construct_ring_device(3)
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
q = cirq.LineQubit.range(2)
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(2)})

valid_circuit = cirq.Circuit(cirq.measure_each(*q), cirq.H.on_each(q))

c_routed = router(
valid_circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
)
device.validate_circuit(c_routed)


def test_circuit_with_invalid_intermediate_multi_qubit_measurement_gates():
device = cirq.testing.construct_ring_device(3)
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
q = cirq.LineQubit.range(3)
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)})

invalid_circuit = cirq.Circuit(cirq.MeasurementGate(3).on(*q), cirq.H.on_each(*q))

with pytest.raises(ValueError):
_ = router(
invalid_circuit,
initial_mapper=hard_coded_mapper,
context=cirq.TransformerContext(deep=True),
)


def test_circuit_with_non_unitary_and_global_phase():
device = cirq.testing.construct_ring_device(4)
device_graph = device.metadata.nx_graph
Expand Down

0 comments on commit 11ae0bd

Please sign in to comment.