diff --git a/cirq-core/cirq/transformers/merge_k_qubit_gates.py b/cirq-core/cirq/transformers/merge_k_qubit_gates.py index e62885bb121..a8707d18757 100644 --- a/cirq-core/cirq/transformers/merge_k_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_k_qubit_gates.py @@ -34,8 +34,6 @@ def _rewrite_merged_k_qubit_unitaries( deep = context.deep if context else False def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': - if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)): - return op op_untagged = op.untagged if ( deep @@ -51,6 +49,8 @@ def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': merged_circuit_op_tag=merged_circuit_op_tag, ).freeze() ).with_tags(*op.tags) + if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)): + return op if rewriter: return rewriter( cast(circuits.CircuitOperation, op_untagged) diff --git a/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py index 47b0e9efdd4..6dc1434dc52 100644 --- a/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py @@ -253,3 +253,27 @@ def _wrap_in_matrix_gate(ops: cirq.OP_TREE): ) c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context) cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix) + + +def test_merge_k_qubit_unitaries_deep_recurses_on_large_circuit_op(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]), cirq.CNOT(*q))) + ) + c_expected = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]))).with_tags( + "merged" + ), + cirq.CNOT(*q), + ) + ) + ) + c_new = cirq.merge_k_qubit_unitaries( + c_orig, + context=cirq.TransformerContext(deep=True), + k=1, + rewriter=lambda op: op.with_tags("merged"), + ) + cirq.testing.assert_same_circuits(c_new, c_expected)