diff --git a/cirq-core/cirq/transformers/optimize_for_target_gateset.py b/cirq-core/cirq/transformers/optimize_for_target_gateset.py index d028366db14..5c71ec9c8ee 100644 --- a/cirq-core/cirq/transformers/optimize_for_target_gateset.py +++ b/cirq-core/cirq/transformers/optimize_for_target_gateset.py @@ -14,8 +14,9 @@ """Transformers to rewrite a circuit using gates from a given target gateset.""" -from typing import Optional, Callable, TYPE_CHECKING +from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING +from cirq import circuits from cirq.protocols import decompose_protocol as dp from cirq.transformers import transformer_api, transformer_primitives @@ -38,6 +39,7 @@ def _decompose_operations_to_target_gateset( gateset: Optional['cirq.Gateset'] = None, decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented, ignore_failures: bool = True, + tags_to_decompose: Sequence[Hashable] = (), ) -> 'cirq.Circuit': """Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`. @@ -56,6 +58,8 @@ def _decompose_operations_to_target_gateset( - `None` or `NotImplemented` if does not know how to decompose a given `op`. ignore_failures: If set, operations that fail to convert are left unchanged. If not set, conversion failures raise a ValueError. + tags_to_decompose: `cirq.CircuitOperation`s tagged with any of `tags_to_decompose` will + be decomposed even if context.deep is True. Returns: An equivalent circuit containing gates accepted by `gateset`. @@ -65,6 +69,13 @@ def _decompose_operations_to_target_gateset( """ def map_func(op: 'cirq.Operation', moment_index: int): + if ( + context + and context.deep + and isinstance(op.untagged, circuits.CircuitOperation) + and set(op.tags).isdisjoint(tags_to_decompose) + ): + return op return dp.decompose( op, intercepting_decomposer=lambda o: decomposer(o, moment_index), @@ -77,7 +88,10 @@ def map_func(op: 'cirq.Operation', moment_index: int): ) return transformer_primitives.map_operations_and_unroll( - circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else () + circuit, + map_func, + tags_to_ignore=context.tags_to_ignore if context else (), + deep=context.deep if context else False, ).unfreeze(copy=False) @@ -122,6 +136,7 @@ def optimize_for_target_gateset( gateset=gateset, decomposer=gateset.decompose_to_target_gateset, ignore_failures=ignore_failures, + tags_to_decompose=(gateset._intermediate_result_tag,), ) for transformer in gateset.postprocess_transformers: diff --git a/cirq-core/cirq/transformers/optimize_for_target_gateset_test.py b/cirq-core/cirq/transformers/optimize_for_target_gateset_test.py index e923ceb60db..0efb950fa7a 100644 --- a/cirq-core/cirq/transformers/optimize_for_target_gateset_test.py +++ b/cirq-core/cirq/transformers/optimize_for_target_gateset_test.py @@ -196,3 +196,53 @@ def test_optimize_for_target_gateset(): _ = cirq.optimize_for_target_gateset( c_orig, gateset=gateset, context=context, ignore_failures=False ) + + +def test_optimize_for_target_gateset_deep(): + q0, q1 = cirq.LineQubit.range(2) + c_nested = cirq.FrozenCircuit(cirq.CX(q0, q1)) + c_orig = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.H(q0), cirq.CircuitOperation(c_nested).repeat(3)) + ).repeat(5) + ) + c_expected = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.single_qubit_matrix_to_phxz(cirq.unitary(cirq.H(q0))).on(q0), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.MatrixGate(c_nested.unitary(qubit_order=[q0, q1]), name="M").on(q0, q1) + ) + ).repeat(3), + ) + ).repeat(5) + ) + gateset = MatrixGateTargetGateset() + context = cirq.TransformerContext(deep=True) + c_new = cirq.optimize_for_target_gateset(c_orig, gateset=gateset, context=context) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_new, c_expected) + cirq.testing.assert_has_diagram( + c_orig, + ''' + [ [ 0: ───@─── ] ] + [ 0: ───H───[ │ ]──────────── ] +0: ───[ [ 1: ───X─── ](loops=3) ]──────────── + [ │ ] + [ 1: ───────#2──────────────────────── ](loops=5) + │ +1: ───#2────────────────────────────────────────────────── +''', + ) + cirq.testing.assert_has_diagram( + c_new, + ''' + [ [ 0: ───M[1]─── ] ] + [ 0: ───PhXZ(a=-0.5,x=0.5,z=-1)───[ │ ]──────────── ] +0: ───[ [ 1: ───M[2]─── ](loops=3) ]──────────── + [ │ ] + [ 1: ─────────────────────────────#2─────────────────────────── ](loops=5) + │ +1: ───#2─────────────────────────────────────────────────────────────────────────── +''', + )