From a95f009617bcdc613cfbca169fbedbfea1de85b7 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 9 May 2023 17:57:06 -0700 Subject: [PATCH] Override gate.controlled() for GlobalPhaseGate to return a ZPowGate (#6073) * Override gate.controlled() for GlobalPhaseGate to return a ZPowGate * Test unitary equivalence * Override controlled only if gate is not parameterized * Fix typo * Fix type check * another attempt at fixing types * Add a comment and additional tests --- cirq-core/cirq/ops/global_phase_op.py | 30 ++++++++++++++++++++-- cirq-core/cirq/ops/global_phase_op_test.py | 21 +++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index aae1c8db86d..e1a66272244 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -13,14 +13,14 @@ # limitations under the License. """A no-qubit global phase operation.""" -from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union +from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union, Optional, Collection import numpy as np import sympy import cirq from cirq import value, protocols -from cirq.ops import raw_types +from cirq.ops import raw_types, controlled_gate, control_values as cv from cirq.type_workarounds import NotImplementedType @@ -91,6 +91,32 @@ def _resolve_parameters_( coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive) return GlobalPhaseGate(coefficient=coefficient) + def controlled( + self, + num_controls: Optional[int] = None, + control_values: Optional[ + Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]] + ] = None, + control_qid_shape: Optional[Tuple[int, ...]] = None, + ) -> raw_types.Gate: + result = super().controlled(num_controls, control_values, control_qid_shape) + if ( + not self._is_parameterized_() + and isinstance(result, controlled_gate.ControlledGate) + and isinstance(result.control_values, cv.ProductOfSums) + and result.control_values[-1] == (1,) + and result.control_qid_shape[-1] == 2 + ): + # A `GlobalPhaseGate` controlled on a qubit in state `|1>` is equivalent + # to applying a `ZPowGate`. This override ensures that `global_phase_gate.controlled()` + # returns a `ZPowGate` instead of a `ControlledGate(sub_gate=global_phase_gate)`. + coefficient = complex(self.coefficient) + exponent = float(np.angle(coefficient) / np.pi) + return cirq.ZPowGate(exponent=exponent).controlled( + result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1] + ) + return result + def global_phase_operation( coefficient: 'cirq.TParamValComplex', atol: float = 1e-8 diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index 9b4ff15b332..37aa560a0ac 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -279,3 +279,24 @@ def test_resolve_error(resolve_fn): gpt = cirq.GlobalPhaseGate(coefficient=t) with pytest.raises(ValueError, match='Coefficient is not unitary'): resolve_fn(gpt, {'t': -2}) + + +@pytest.mark.parametrize( + 'coeff, exp', [(-1, 1), (1j, 0.5), (-1j, -0.5), (1 / np.sqrt(2) * (1 + 1j), 0.25)] +) +def test_global_phase_gate_controlled(coeff, exp): + g = cirq.GlobalPhaseGate(coeff) + op = cirq.global_phase_operation(coeff) + q = cirq.LineQubit.range(3) + for num_controls, target_gate in zip(range(1, 4), [cirq.Z, cirq.CZ, cirq.CCZ]): + assert g.controlled(num_controls) == target_gate**exp + np.testing.assert_allclose( + cirq.unitary(cirq.ControlledGate(g, num_controls)), + cirq.unitary(g.controlled(num_controls)), + ) + assert op.controlled_by(*q[:num_controls]) == target_gate(*q[:num_controls]) ** exp + assert g.controlled(control_values=[0]) == cirq.ControlledGate(g, control_values=[0]) + xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1))) + assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate( + g, control_values=xor_control_values + )