From 04d58c4e489e9d22b8ae0ce30d1ea0fda2643c40 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 22 Mar 2022 05:45:50 +0530 Subject: [PATCH] Add support for deep=True to `cirq.align_left` and `cirq.align_right` transformers (#5112) - Adds support to recursively run `cirq.align_left` and `cirq.align_right` transformers on circuits wrapped inside a circuit operation by setting `deep=True` in transformer context. - Part of https://github.com/quantumlib/Cirq/issues/5039 --- cirq-core/cirq/transformers/align.py | 7 ++- cirq-core/cirq/transformers/align_test.py | 68 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/transformers/align.py b/cirq-core/cirq/transformers/align.py index e9f2561d79d..6f3f12f49d9 100644 --- a/cirq-core/cirq/transformers/align.py +++ b/cirq-core/cirq/transformers/align.py @@ -14,6 +14,7 @@ """Transformer passes which align operations to the left or right of the circuit.""" +import dataclasses from typing import Optional, TYPE_CHECKING from cirq import circuits, ops from cirq.transformers import transformer_api @@ -22,7 +23,7 @@ import cirq -@transformer_api.transformer +@transformer_api.transformer(add_deep_support=True) def align_left( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None ) -> 'cirq.Circuit': @@ -54,7 +55,7 @@ def align_left( return ret -@transformer_api.transformer +@transformer_api.transformer(add_deep_support=True) def align_right( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None ) -> 'cirq.Circuit': @@ -70,4 +71,6 @@ def align_right( Returns: Copy of the transformed input circuit. """ + if context is not None and context.deep is True: + context = dataclasses.replace(context, deep=False) return align_left(circuit[::-1], context=context)[::-1] diff --git a/cirq-core/cirq/transformers/align_test.py b/cirq-core/cirq/transformers/align_test.py index 5203f6c2b30..76525abeb85 100644 --- a/cirq-core/cirq/transformers/align_test.py +++ b/cirq-core/cirq/transformers/align_test.py @@ -71,6 +71,41 @@ def test_align_left_no_compile_context(): ) +def test_align_left_deep(): + q1, q2 = cirq.LineQubit.range(2) + c_nested = cirq.FrozenCircuit( + [ + cirq.Moment([cirq.X(q1)]), + cirq.Moment([cirq.Y(q2)]), + cirq.Moment([cirq.Z(q1), cirq.Y(q2).with_tags("nocompile")]), + cirq.Moment([cirq.Y(q1)]), + cirq.measure(q2, key='a'), + cirq.Z(q1).with_classical_controls('a'), + ] + ) + c_nested_aligned = cirq.FrozenCircuit( + cirq.Moment(cirq.X(q1), cirq.Y(q2)), + cirq.Moment(cirq.Z(q1)), + cirq.Moment([cirq.Y(q1), cirq.Y(q2).with_tags("nocompile")]), + cirq.measure(q2, key='a'), + cirq.Z(q1).with_classical_controls('a'), + ) + c_orig = cirq.Circuit( + c_nested, + cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"), + c_nested, + cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"), + ) + c_expected = cirq.Circuit( + c_nested_aligned, + cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"), + c_nested_aligned, + cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"), + ) + context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True) + cirq.testing.assert_same_circuits(cirq.align_left(c_orig, context=context), c_expected) + + def test_align_left_subset_of_operations(): q1 = cirq.NamedQubit('q1') q2 = cirq.NamedQubit('q2') @@ -133,6 +168,39 @@ def test_align_right_no_compile_context(): ) +def test_align_right_deep(): + q1, q2 = cirq.LineQubit.range(2) + c_nested = cirq.FrozenCircuit( + cirq.Moment([cirq.X(q1)]), + cirq.Moment([cirq.Y(q1), cirq.X(q2).with_tags("nocompile")]), + cirq.Moment([cirq.X(q2)]), + cirq.Moment([cirq.Y(q1)]), + cirq.measure(q1, key='a'), + cirq.Z(q2).with_classical_controls('a'), + ) + c_nested_aligned = cirq.FrozenCircuit( + cirq.Moment([cirq.X(q1), cirq.X(q2).with_tags("nocompile")]), + [cirq.Y(q1), cirq.Y(q1)], + cirq.Moment(cirq.measure(q1, key='a'), cirq.X(q2)), + cirq.Z(q2).with_classical_controls('a'), + ) + c_orig = cirq.Circuit( + c_nested, + cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"), + c_nested, + cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"), + ) + c_expected = cirq.Circuit( + c_nested_aligned, + cirq.CircuitOperation(c_nested).repeat(6).with_tags("nocompile"), + cirq.Moment(), + c_nested_aligned, + cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"), + ) + context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True) + cirq.testing.assert_same_circuits(cirq.align_right(c_orig, context=context), c_expected) + + def test_classical_control(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(