diff --git a/cirq-core/cirq/transformers/merge_k_qubit_gates.py b/cirq-core/cirq/transformers/merge_k_qubit_gates.py index eb0cf247a5c..e62885bb121 100644 --- a/cirq-core/cirq/transformers/merge_k_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_k_qubit_gates.py @@ -23,6 +23,47 @@ import cirq +def _rewrite_merged_k_qubit_unitaries( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + k: int = 0, + rewriter: Optional[Callable[['cirq.CircuitOperation'], 'cirq.OP_TREE']] = None, + merged_circuit_op_tag: str = "_merged_k_qubit_unitaries_component", +) -> 'cirq.Circuit': + deep = context.deep if context else False + + def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': + if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)): + return op + op_untagged = op.untagged + if ( + deep + and isinstance(op_untagged, circuits.CircuitOperation) + and merged_circuit_op_tag not in op.tags + ): + return op_untagged.replace( + circuit=_rewrite_merged_k_qubit_unitaries( + op_untagged.circuit, + context=context, + k=k, + rewriter=rewriter, + merged_circuit_op_tag=merged_circuit_op_tag, + ).freeze() + ).with_tags(*op.tags) + if rewriter: + return rewriter( + cast(circuits.CircuitOperation, op_untagged) + if merged_circuit_op_tag in op.tags + else circuits.CircuitOperation(circuits.FrozenCircuit(op)) + ) + return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits) + + return transformer_primitives.map_operations_and_unroll( + circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else () + ).unfreeze(copy=False) + + @transformer_api.transformer def merge_k_qubit_unitaries( circuit: 'cirq.AbstractCircuit', @@ -54,24 +95,17 @@ def merge_k_qubit_unitaries( if k <= 0: raise ValueError(f"k should be greater than or equal to 1. Found {k}.") merged_circuit_op_tag = "_merged_k_qubit_unitaries_component" - - def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': - if not (protocols.num_qubits(op) <= k and protocols.has_unitary(op)): - return op - if rewriter: - return rewriter( - cast(circuits.CircuitOperation, op.untagged) - if merged_circuit_op_tag in op.tags - else circuits.CircuitOperation(circuits.FrozenCircuit(op)) - ) - return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits) - circuit = transformer_primitives.merge_k_qubit_unitaries_to_circuit_op( circuit, k=k, tags_to_ignore=context.tags_to_ignore if context else (), merged_circuit_op_tag=merged_circuit_op_tag, + deep=context.deep if context else False, + ) + return _rewrite_merged_k_qubit_unitaries( + circuit, + context=context, + k=k, + rewriter=rewriter, + merged_circuit_op_tag=merged_circuit_op_tag, ) - return transformer_primitives.map_operations_and_unroll( - circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else () - ).unfreeze(copy=False) diff --git a/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py index 4a39c31bab7..47b0e9efdd4 100644 --- a/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_k_qubit_gates_test.py @@ -188,3 +188,68 @@ def rewriter_replace_with_decomp(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': ║ ║ a: ═════════════════════════════════════════════════════════════════════════════════════════════@══════════════════════════════^═══''', ) + + +def test_merge_k_qubit_unitaries_deep(): + q = cirq.LineQubit.range(2) + h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])] + c_orig = cirq.Circuit( + h_cz_y, + cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])), + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"), + [cirq.CNOT(*q), cirq.CNOT(*q)], + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4), + [cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)], + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"), + ) + + def _wrap_in_cop(ops: cirq.OP_TREE, tag: str): + return cirq.CircuitOperation(cirq.FrozenCircuit(ops)).with_tags(tag) + + c_expected = cirq.Circuit( + _wrap_in_cop([h_cz_y, cirq.Y(q[1])], '1'), + cirq.Moment(cirq.X(q[0]).with_tags("ignore")), + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"), + _wrap_in_cop([cirq.CNOT(*q), cirq.CNOT(*q)], '2'), + cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y, '3'))).repeat(4), + _wrap_in_cop([cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)], '4'), + cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y, '5'))) + .repeat(5) + .with_tags("preserve_tag"), + strategy=cirq.InsertStrategy.NEW, + ) + + component_id = 0 + + def rewriter_merge_to_circuit_op(op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE': + nonlocal component_id + component_id = component_id + 1 + return op.with_tags(f'{component_id}') + + context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True) + c_new = cirq.merge_k_qubit_unitaries( + c_orig, + k=2, + context=context, + rewriter=rewriter_merge_to_circuit_op, + ) + cirq.testing.assert_same_circuits(c_new, c_expected) + + def _wrap_in_matrix_gate(ops: cirq.OP_TREE): + op = _wrap_in_cop(ops, 'temp') + return cirq.MatrixGate(cirq.unitary(op)).on(*op.qubits) + + c_expected_matrix = cirq.Circuit( + _wrap_in_matrix_gate([h_cz_y, cirq.Y(q[1])]), + cirq.Moment(cirq.X(q[0]).with_tags("ignore")), + cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"), + _wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CNOT(*q)]), + cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_matrix_gate(h_cz_y))).repeat(4), + _wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)]), + cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_matrix_gate(h_cz_y))) + .repeat(5) + .with_tags("preserve_tag"), + strategy=cirq.InsertStrategy.NEW, + ) + c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context) + cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix)