Skip to content

Commit

Permalink
Add support for deep=True to cirq.align_left and cirq.align_right
Browse files Browse the repository at this point in the history
… transformers (#5112)

- Adds support to recursively run `cirq.align_left` and `cirq.align_right` transformers on circuits wrapped inside a circuit operation by setting `deep=True` in transformer context.
- Part of #5039
  • Loading branch information
tanujkhattar committed Mar 22, 2022
1 parent ca4bb72 commit 04d58c4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
7 changes: 5 additions & 2 deletions cirq-core/cirq/transformers/align.py
Expand Up @@ -14,6 +14,7 @@

"""Transformer passes which align operations to the left or right of the circuit."""

import dataclasses
from typing import Optional, TYPE_CHECKING
from cirq import circuits, ops
from cirq.transformers import transformer_api
Expand All @@ -22,7 +23,7 @@
import cirq


@transformer_api.transformer
@transformer_api.transformer(add_deep_support=True)
def align_left(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.Circuit':
Expand Down Expand Up @@ -54,7 +55,7 @@ def align_left(
return ret


@transformer_api.transformer
@transformer_api.transformer(add_deep_support=True)
def align_right(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.Circuit':
Expand All @@ -70,4 +71,6 @@ def align_right(
Returns:
Copy of the transformed input circuit.
"""
if context is not None and context.deep is True:
context = dataclasses.replace(context, deep=False)
return align_left(circuit[::-1], context=context)[::-1]
68 changes: 68 additions & 0 deletions cirq-core/cirq/transformers/align_test.py
Expand Up @@ -71,6 +71,41 @@ def test_align_left_no_compile_context():
)


def test_align_left_deep():
q1, q2 = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(
[
cirq.Moment([cirq.X(q1)]),
cirq.Moment([cirq.Y(q2)]),
cirq.Moment([cirq.Z(q1), cirq.Y(q2).with_tags("nocompile")]),
cirq.Moment([cirq.Y(q1)]),
cirq.measure(q2, key='a'),
cirq.Z(q1).with_classical_controls('a'),
]
)
c_nested_aligned = cirq.FrozenCircuit(
cirq.Moment(cirq.X(q1), cirq.Y(q2)),
cirq.Moment(cirq.Z(q1)),
cirq.Moment([cirq.Y(q1), cirq.Y(q2).with_tags("nocompile")]),
cirq.measure(q2, key='a'),
cirq.Z(q1).with_classical_controls('a'),
)
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
c_nested_aligned,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
c_nested_aligned,
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
)
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
cirq.testing.assert_same_circuits(cirq.align_left(c_orig, context=context), c_expected)


def test_align_left_subset_of_operations():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
Expand Down Expand Up @@ -133,6 +168,39 @@ def test_align_right_no_compile_context():
)


def test_align_right_deep():
q1, q2 = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(
cirq.Moment([cirq.X(q1)]),
cirq.Moment([cirq.Y(q1), cirq.X(q2).with_tags("nocompile")]),
cirq.Moment([cirq.X(q2)]),
cirq.Moment([cirq.Y(q1)]),
cirq.measure(q1, key='a'),
cirq.Z(q2).with_classical_controls('a'),
)
c_nested_aligned = cirq.FrozenCircuit(
cirq.Moment([cirq.X(q1), cirq.X(q2).with_tags("nocompile")]),
[cirq.Y(q1), cirq.Y(q1)],
cirq.Moment(cirq.measure(q1, key='a'), cirq.X(q2)),
cirq.Z(q2).with_classical_controls('a'),
)
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
c_nested_aligned,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"),
cirq.Moment(),
c_nested_aligned,
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
)
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
cirq.testing.assert_same_circuits(cirq.align_right(c_orig, context=context), c_expected)


def test_classical_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down

0 comments on commit 04d58c4

Please sign in to comment.