Skip to content

Commit

Permalink
Update ZPow.controlled specialization (#2572)
Browse files Browse the repository at this point in the history
Respect Z.controlled(a).controlled(b) === Z.controlled(b+a)
  • Loading branch information
smitsanghavi authored and CirqBot committed Nov 21, 2019
1 parent 52b14db commit e2446d4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
17 changes: 11 additions & 6 deletions cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from cirq import protocols, value
from cirq._compat import proper_repr
from cirq._doc import document
from cirq.ops import gate_features, eigen_gate, raw_types
from cirq.ops import controlled_gate, gate_features, eigen_gate, raw_types

from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -387,15 +387,20 @@ def controlled(self,
control_qid_shape: Optional[Tuple[int, ...]] = None
) -> raw_types.Gate:
"""
Specialize controlled for ZPow to return corresponding CZPow when
controlled by a single qubit.
Specialize controlled for ZPow to return corresponding controlled CZPow
when the last control (which acts first semantically) is a default-type
control qubit.
"""
result = super().controlled(num_controls, control_values,
control_qid_shape)
if (result.control_values == ((1,),) and # type: ignore
result.control_qid_shape == (2,)): # type: ignore
if (isinstance(result, controlled_gate.ControlledGate) and
result.control_values[-1] == (1,) and
result.control_qid_shape[-1] == 2):
return cirq.CZPowGate(exponent=self._exponent,
global_shift=self._global_shift)
global_shift=self._global_shift).controlled(
result.num_controls() - 1,
result.control_values[:-1],
result.control_qid_shape[:-1])
return result

def _eigen_components(self):
Expand Down
50 changes: 47 additions & 3 deletions cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,59 @@ def test_z_init():
assert cirq.Z**0.5 != cirq.Z**-0.5
assert (cirq.Z**-1)**0.5 == cirq.Z**-0.5
assert cirq.Z**-1 == cirq.Z
assert cirq.Z.controlled(num_controls=2) == cirq.ControlledGate(
cirq.Z, num_controls=2)


def test_z_control():
# Single qubit control on Z gives a CZ
assert cirq.Z.controlled() == cirq.CZ
assert cirq.Z.controlled(num_controls=1) == cirq.CZ
assert cirq.Z.controlled(control_values=((1,),)) == cirq.CZ
assert cirq.Z.controlled(control_qid_shape=(2,)) == cirq.CZ

# Also works for any ZPow.
assert cirq.ZPowGate(exponent=5).controlled() == cirq.CZPowGate(exponent=5)

# For multi-qudit controls, if the last control is a qubit with control
# value 1, construct a CZ leaving the rest of the controls as is.
assert cirq.Z.controlled().controlled() == cirq.ControlledGate(
cirq.CZ, num_controls=1)
assert cirq.Z.controlled(num_controls=2) == cirq.ControlledGate(
cirq.CZ, num_controls=1)
assert cirq.Z.controlled(control_values=((0,), (0,),
(1,))) == cirq.ControlledGate(
cirq.CZ,
num_controls=2,
control_values=((0,), (0,)))
assert cirq.Z.controlled(control_qid_shape=(3, 3,
2)) == cirq.ControlledGate(
cirq.CZ,
num_controls=2,
control_qid_shape=(3, 3))
assert cirq.Z.controlled(control_qid_shape=(2,)).controlled(
control_qid_shape=(3,)).controlled(
control_qid_shape=(4,)) == cirq.ControlledGate(
cirq.CZ, num_controls=2, control_qid_shape=(3, 4))

# When a control_value 1 qubit is not acting first, results in a regular
# ControlledGate on Z instance.
assert cirq.Z.controlled(num_controls=1,
control_qid_shape=(3,)) == cirq.ControlledGate(
cirq.Z, num_controls=1, control_qid_shape=(3,))
assert z.controlled() == cirq.CZPowGate(exponent=5)
assert cirq.Z.controlled(control_values=((0,), (1,),
(0,))) == cirq.ControlledGate(
cirq.Z,
num_controls=3,
control_values=((0,), (1,),
(0,)))
assert cirq.Z.controlled(control_qid_shape=(3, 2,
3)) == cirq.ControlledGate(
cirq.Z,
num_controls=3,
control_qid_shape=(3, 2, 3))
assert cirq.Z.controlled(control_qid_shape=(3,)).controlled(
control_qid_shape=(2,)).controlled(
control_qid_shape=(4,)) == cirq.ControlledGate(
cirq.Z, num_controls=3, control_qid_shape=(3, 2, 4))


def test_rot_gates_eq():
Expand Down

0 comments on commit e2446d4

Please sign in to comment.