Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for deep=True to cirq.optimize_for_target_gateset transformer #5124

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 17 additions & 2 deletions cirq-core/cirq/transformers/optimize_for_target_gateset.py
Expand Up @@ -14,8 +14,9 @@

"""Transformers to rewrite a circuit using gates from a given target gateset."""

from typing import Optional, Callable, TYPE_CHECKING
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING

from cirq import circuits
from cirq.protocols import decompose_protocol as dp
from cirq.transformers import transformer_api, transformer_primitives

Expand All @@ -38,6 +39,7 @@ def _decompose_operations_to_target_gateset(
gateset: Optional['cirq.Gateset'] = None,
decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented,
ignore_failures: bool = True,
tags_to_decompose: Sequence[Hashable] = (),
) -> 'cirq.Circuit':
"""Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`.

Expand All @@ -56,6 +58,8 @@ def _decompose_operations_to_target_gateset(
- `None` or `NotImplemented` if does not know how to decompose a given `op`.
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
conversion failures raise a ValueError.
tags_to_decompose: `cirq.CircuitOperation`s tagged with any of `tags_to_decompose` will
be decomposed even if context.deep is True.

Returns:
An equivalent circuit containing gates accepted by `gateset`.
Expand All @@ -65,6 +69,13 @@ def _decompose_operations_to_target_gateset(
"""

def map_func(op: 'cirq.Operation', moment_index: int):
if (
context
and context.deep
and isinstance(op.untagged, circuits.CircuitOperation)
and set(op.tags).isdisjoint(tags_to_decompose)
):
return op
return dp.decompose(
op,
intercepting_decomposer=lambda o: decomposer(o, moment_index),
Expand All @@ -77,7 +88,10 @@ def map_func(op: 'cirq.Operation', moment_index: int):
)

return transformer_primitives.map_operations_and_unroll(
circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else ()
circuit,
map_func,
tags_to_ignore=context.tags_to_ignore if context else (),
deep=context.deep if context else False,
).unfreeze(copy=False)


Expand Down Expand Up @@ -122,6 +136,7 @@ def optimize_for_target_gateset(
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)

for transformer in gateset.postprocess_transformers:
Expand Down
50 changes: 50 additions & 0 deletions cirq-core/cirq/transformers/optimize_for_target_gateset_test.py
Expand Up @@ -196,3 +196,53 @@ def test_optimize_for_target_gateset():
_ = cirq.optimize_for_target_gateset(
c_orig, gateset=gateset, context=context, ignore_failures=False
)


def test_optimize_for_target_gateset_deep():
q0, q1 = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(cirq.CX(q0, q1))
c_orig = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.H(q0), cirq.CircuitOperation(c_nested).repeat(3))
).repeat(5)
)
c_expected = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.single_qubit_matrix_to_phxz(cirq.unitary(cirq.H(q0))).on(q0),
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.MatrixGate(c_nested.unitary(qubit_order=[q0, q1]), name="M").on(q0, q1)
)
).repeat(3),
)
).repeat(5)
)
gateset = MatrixGateTargetGateset()
context = cirq.TransformerContext(deep=True)
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=gateset, context=context)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_new, c_expected)
cirq.testing.assert_has_diagram(
c_orig,
'''
[ [ 0: ───@─── ] ]
[ 0: ───H───[ │ ]──────────── ]
0: ───[ [ 1: ───X─── ](loops=3) ]────────────
[ │ ]
[ 1: ───────#2──────────────────────── ](loops=5)
1: ───#2──────────────────────────────────────────────────
''',
)
cirq.testing.assert_has_diagram(
c_new,
'''
[ [ 0: ───M[1]─── ] ]
[ 0: ───PhXZ(a=-0.5,x=0.5,z=-1)───[ │ ]──────────── ]
0: ───[ [ 1: ───M[2]─── ](loops=3) ]────────────
[ │ ]
[ 1: ─────────────────────────────#2─────────────────────────── ](loops=5)
1: ───#2───────────────────────────────────────────────────────────────────────────
''',
)