diff --git a/cirq-core/cirq/transformers/eject_z.py b/cirq-core/cirq/transformers/eject_z.py index 7ebd12e1047..fc22718d06f 100644 --- a/cirq-core/cirq/transformers/eject_z.py +++ b/cirq-core/cirq/transformers/eject_z.py @@ -43,7 +43,7 @@ def _is_swaplike(gate: 'cirq.Gate'): return False -@transformer_api.transformer +@transformer_api.transformer(add_deep_support=True) def eject_z( circuit: 'cirq.AbstractCircuit', *, @@ -96,7 +96,7 @@ def map_func(op: 'cirq.Operation', moment_index: int) -> 'cirq.OP_TREE': gate = op.gate # Return if circuit operation. if gate is None: - return op + return [dump_tracked_phase(op.qubits), op] # Swap phases if `op` is a swap operation. if _is_swaplike(gate): diff --git a/cirq-core/cirq/transformers/eject_z_test.py b/cirq-core/cirq/transformers/eject_z_test.py index ca8743a35f1..d5bf72931fb 100644 --- a/cirq-core/cirq/transformers/eject_z_test.py +++ b/cirq-core/cirq/transformers/eject_z_test.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import dataclasses + import pytest import numpy as np import sympy @@ -39,6 +42,33 @@ def assert_optimizes( circuit = cirq.eject_z(before, eject_parameterized=eject_parameterized, context=context) cirq.testing.assert_same_circuits(circuit, expected) + # Nested sub-circuits should also get optimized. + q = before.all_qubits() + c_nested = cirq.Circuit( + [(cirq.Z ** 0.5).on_each(*q), (cirq.Y ** 0.25).on_each(*q)], + cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")), + [(cirq.Z ** 0.5).on_each(*q), (cirq.Y ** 0.25).on_each(*q)], + cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag")), + ) + c_expected = cirq.Circuit( + cirq.PhasedXPowGate(phase_exponent=0, exponent=0.25).on_each(*q), + (cirq.Z ** 0.5).on_each(*q), + cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")), + cirq.PhasedXPowGate(phase_exponent=0, exponent=0.25).on_each(*q), + (cirq.Z ** 0.5).on_each(*q), + cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")), + ) + if context is None: + context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True) + else: + context = dataclasses.replace( + context, tags_to_ignore=context.tags_to_ignore + ("ignore",), deep=True + ) + c_nested = cirq.eject_z(c_nested, context=context, eject_parameterized=eject_parameterized) + cirq.testing.assert_same_circuits(c_nested, c_expected) + c_nested = cirq.eject_z(c_nested, context=context, eject_parameterized=eject_parameterized) + cirq.testing.assert_same_circuits(c_nested, c_expected) + def assert_removes_all_z_gates(circuit: cirq.Circuit, eject_parameterized: bool = True): optimized = cirq.eject_z(circuit, eject_parameterized=eject_parameterized)