diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py index d63e7d09ac4..45801676282 100644 --- a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py @@ -170,10 +170,24 @@ def decompose_to_target_gateset(self, op: 'cirq.Operation', moment_idx: int) -> old_2q_gate_count = sum(1 for o in ops.flatten_to_ops(old_optree) if len(o.qubits) == 2) new_2q_gate_count = sum(1 for o in ops.flatten_to_ops(new_optree) if len(o.qubits) == 2) switch_to_new = ( - any(op not in self for op in ops.flatten_to_ops(old_optree)) + any( + protocols.num_qubits(op) == 2 and op not in self + for op in ops.flatten_to_ops(old_optree) + ) or new_2q_gate_count < old_2q_gate_count ) - return new_optree if switch_to_new else old_optree + if switch_to_new: + return new_optree + mapped_old_optree: List['cirq.OP_TREE'] = [] + for old_op in ops.flatten_to_ops(old_optree): + if old_op in self: + mapped_old_optree.append(old_op) + else: + decomposed_op = self._decompose_single_qubit_operation(old_op, moment_idx) + if decomposed_op is None or decomposed_op is NotImplemented: + return NotImplemented + mapped_old_optree.append(decomposed_op) + return mapped_old_optree def _decompose_single_qubit_operation( self, op: 'cirq.Operation', moment_idx: int diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py index c0fd64cd90c..c3af371e90f 100644 --- a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py @@ -199,3 +199,23 @@ def test_two_qubit_compilation_merge_and_replace_inefficient_component(): m: ═════════════════════════════════════════════════════@═══^═══ ''', ) + + +def test_two_qubit_compilation_replaces_only_if_2q_gate_count_is_less(): + class DummyTargetGateset(cirq.TwoQubitCompilationTargetGateset): + def __init__(self): + super().__init__(cirq.X, cirq.CNOT) + + def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> DecomposeResult: + q0, q1 = op.qubits + return [cirq.X.on_each(q0, q1), cirq.CNOT(q0, q1)] * 10 + + def _decompose_single_qubit_operation(self, op: 'cirq.Operation', _) -> DecomposeResult: + return cirq.X(*op.qubits) if op.gate == cirq.Y else NotImplemented + + q = cirq.LineQubit.range(2) + ops = [cirq.Y.on_each(*q), cirq.CNOT(*q), cirq.Z.on_each(*q)] + c_orig = cirq.Circuit(ops) + c_expected = cirq.Circuit(cirq.X.on_each(*q), ops[-2:]) + c_new = cirq.optimize_for_target_gateset(c_orig, gateset=DummyTargetGateset()) + cirq.testing.assert_same_circuits(c_new, c_expected)