Skip to content

Commit

Permalink
Override gate.controlled() for GlobalPhaseGate to return a ZPowGate (#…
Browse files Browse the repository at this point in the history
…6073)

* Override gate.controlled() for GlobalPhaseGate to return a ZPowGate

* Test unitary equivalence

* Override controlled only if gate is not parameterized

* Fix typo

* Fix type check

* another attempt at fixing types

* Add a comment and additional tests
  • Loading branch information
tanujkhattar committed May 10, 2023
1 parent f2cd706 commit a95f009
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
30 changes: 28 additions & 2 deletions cirq-core/cirq/ops/global_phase_op.py
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""A no-qubit global phase operation."""

from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union
from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union, Optional, Collection

import numpy as np
import sympy

import cirq
from cirq import value, protocols
from cirq.ops import raw_types
from cirq.ops import raw_types, controlled_gate, control_values as cv
from cirq.type_workarounds import NotImplementedType


Expand Down Expand Up @@ -91,6 +91,32 @@ def _resolve_parameters_(
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
return GlobalPhaseGate(coefficient=coefficient)

def controlled(
self,
num_controls: Optional[int] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
not self._is_parameterized_()
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
# A `GlobalPhaseGate` controlled on a qubit in state `|1>` is equivalent
# to applying a `ZPowGate`. This override ensures that `global_phase_gate.controlled()`
# returns a `ZPowGate` instead of a `ControlledGate(sub_gate=global_phase_gate)`.
coefficient = complex(self.coefficient)
exponent = float(np.angle(coefficient) / np.pi)
return cirq.ZPowGate(exponent=exponent).controlled(
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
)
return result


def global_phase_operation(
coefficient: 'cirq.TParamValComplex', atol: float = 1e-8
Expand Down
21 changes: 21 additions & 0 deletions cirq-core/cirq/ops/global_phase_op_test.py
Expand Up @@ -279,3 +279,24 @@ def test_resolve_error(resolve_fn):
gpt = cirq.GlobalPhaseGate(coefficient=t)
with pytest.raises(ValueError, match='Coefficient is not unitary'):
resolve_fn(gpt, {'t': -2})


@pytest.mark.parametrize(
'coeff, exp', [(-1, 1), (1j, 0.5), (-1j, -0.5), (1 / np.sqrt(2) * (1 + 1j), 0.25)]
)
def test_global_phase_gate_controlled(coeff, exp):
g = cirq.GlobalPhaseGate(coeff)
op = cirq.global_phase_operation(coeff)
q = cirq.LineQubit.range(3)
for num_controls, target_gate in zip(range(1, 4), [cirq.Z, cirq.CZ, cirq.CCZ]):
assert g.controlled(num_controls) == target_gate**exp
np.testing.assert_allclose(
cirq.unitary(cirq.ControlledGate(g, num_controls)),
cirq.unitary(g.controlled(num_controls)),
)
assert op.controlled_by(*q[:num_controls]) == target_gate(*q[:num_controls]) ** exp
assert g.controlled(control_values=[0]) == cirq.ControlledGate(g, control_values=[0])
xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1)))
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
g, control_values=xor_control_values
)

0 comments on commit a95f009

Please sign in to comment.