Skip to content

Commit

Permalink
Resolve inconsistencies in using controlled gates & controlled operat…
Browse files Browse the repository at this point in the history
…ions (#4167)

This PR tries to resolve the inconsistencies mentioned in #4172. Specifically, 

- Equality b/w controlled `TOFFOLI`s with different qubit ordering on controls, i.e. `TOFFOLI(a,b,c).controlled_by(d)  == TOFFOLI(d,b,c).controlled_by(a)`  (and other 3Q controlled gates like CCZ, CSWAP). To achieve this, we override the `controlled` method on `CCX`, `CCZ` etc. to return a `ControlledGate` with `sub_gate = X` in case of `CCX` s.t. `TOFFOLI(a,b,c).controlled_by(d) == CCCX(a, b, d, c) == TOFFOLI(d,b,c).controlled_by(a)` instead of `CTOFFOLI`
- `gate_operation.controlled_by` now forwards the request to `gate.controlled` to first create a controlled gate and then apply it on qubits to create a `ControlledOperation`. This solves the original use case of adding specialized controls, requested in #2142, i.e. `cirq.Z(q0).controlled_by(q1) == cirq.CZ(q1, q0)`. 
- Fixes #4515


Note, this is a breaking change because
- `gate_operation.controlled_by()` can now return a `cirq.GateOperation` instead of `cirq.ControlledOperation` in cases where the underlying gates have specialized `gate.controlled()` implementations. 
  - This also leads to a change in diagram of the controlled gates with specialized controlled implementations. For eg: Controlled S gate is now plotted as `CZPowGate` (`@ --- @ ** 0.5`) instead of `ControlledOperation` with Z ** 0.5 as subgate(`@ ---- S`)
- `op.controlled_by` for 3Q gates like `CCX`, `CCZ`,  `CSWAP` will now return `ControlledOperation` with `sub_operation = <underlying non-controlled gate>`. Eg: `CCCX` (i.e. `sub_gate = X`) instead of `CTOFFOLI` (i.e. `sub_gate = TOFFOLI`) etc.
- Diagrams for `ControlledOperations` will now always have the exponent drawn on the target qubit (in case of multi qubit `sub_operation`, the exponent will always be on the first qubit if not the underlying gate does not explicitly specify a target index).
  • Loading branch information
tanujkhattar committed Oct 1, 2021
1 parent 9d0ac9c commit f48efe0
Show file tree
Hide file tree
Showing 15 changed files with 316 additions and 35 deletions.
7 changes: 1 addition & 6 deletions cirq-core/cirq/interop/quirk/cells/input_rotation_cells.py
Expand Up @@ -142,15 +142,10 @@ def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
sign_char = '-' if self.exponent_sign == -1 else ''
symbols = list(sub_result.wire_symbols)
symbols.extend(f'A{i}' for i in range(len(self.register)))
qubit_index = (
len(self.base_operation.controls)
if isinstance(self.base_operation, ops.ControlledOperation)
else 0
)
return cirq.CircuitDiagramInfo(
tuple(symbols),
exponent=f'({sign_char}A/2^{len(self.register)})',
exponent_qubit_index=qubit_index,
exponent_qubit_index=sub_result.exponent_qubit_index or 0,
auto_exponent_parens=False,
)

Expand Down
Expand Up @@ -64,9 +64,9 @@ def test_input_rotation_cells():
assert_url_to_circuit_returns(
'{"cols":[["•","Z^(A/2^n)","inputA2"]]}',
diagram="""
0: ───@───────────
0: ───@^(A/2^2)───
1: ───Z^(A/2^2)───
1: ───@───────────
2: ───A0──────────
Expand Down
30 changes: 15 additions & 15 deletions cirq-core/cirq/interop/quirk/url_to_circuit_test.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/controlled_gate_test.py
Expand Up @@ -485,7 +485,7 @@ def test_circuit_diagram():
class MockGate(cirq.testing.TwoQubitGate):
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
self.captured_diagram_args = args
return cirq.CircuitDiagramInfo(wire_symbols=tuple(['MOCK']), exponent=1, connected=True)
return cirq.CircuitDiagramInfo(wire_symbols=tuple(['M1', 'M2']), exponent=1, connected=True)


def test_uninformed_circuit_diagram_info():
Expand All @@ -496,7 +496,7 @@ def test_uninformed_circuit_diagram_info():
args = cirq.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT

assert cirq.circuit_diagram_info(cgate, args) == cirq.CircuitDiagramInfo(
wire_symbols=('@', 'MOCK'), exponent=1, connected=True
wire_symbols=('@', 'M1', 'M2'), exponent=1, connected=True, exponent_qubit_index=1
)
assert mock_gate.captured_diagram_args == args

Expand Down
13 changes: 10 additions & 3 deletions cirq-core/cirq/ops/controlled_operation.py
Expand Up @@ -237,12 +237,19 @@ def get_symbol(vals):
return f"({','.join(map(str, vals))})"

wire_symbols = (*(get_symbol(vals) for vals in self.control_values), *sub_info.wire_symbols)
exponent_qubit_index = None
if sub_info.exponent_qubit_index is not None:
exponent_qubit_index = sub_info.exponent_qubit_index + len(self.control_values)
elif sub_info.exponent is not None:
# For a multi-qubit `sub_operation`, if the `exponent_qubit_index` is None, the qubit
# on which the exponent gets drawn in the controlled case (smallest ordered qubit of
# sub_operation) can be different from the uncontrolled case (lexicographically largest
# qubit of sub_operation). See tests for example.
exponent_qubit_index = len(self.control_values)
return protocols.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=sub_info.exponent,
exponent_qubit_index=None
if sub_info.exponent_qubit_index is None
else sub_info.exponent_qubit_index + 1,
exponent_qubit_index=exponent_qubit_index,
)

def _json_dict_(self) -> Dict[str, Any]:
Expand Down
21 changes: 19 additions & 2 deletions cirq-core/cirq/ops/controlled_operation_test.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Tuple, cast
import itertools

import numpy as np
import pytest
Expand Down Expand Up @@ -247,11 +248,27 @@ def test_circuit_diagram():


class MockGate(cirq.testing.TwoQubitGate):
def __init__(self, exponent_qubit_index=None):
self._exponent_qubit_index = exponent_qubit_index

def _circuit_diagram_info_(
self, args: protocols.CircuitDiagramInfoArgs
) -> protocols.CircuitDiagramInfo:
self.captured_diagram_args = args
return cirq.CircuitDiagramInfo(wire_symbols=tuple(['MOCK']), exponent=1, connected=True)
return cirq.CircuitDiagramInfo(
wire_symbols=tuple(['M1', 'M2']),
exponent=1,
exponent_qubit_index=self._exponent_qubit_index,
connected=True,
)


def test_controlled_diagram_exponent():
for q in itertools.permutations(cirq.LineQubit.range(5)):
for idx in [None, 0, 1]:
op = MockGate(idx)(*q[:2]).controlled_by(*q[2:])
add = 0 if idx is None else idx
assert cirq.circuit_diagram_info(op).exponent_qubit_index == len(q[2:]) + add


def test_uninformed_circuit_diagram_info():
Expand All @@ -262,7 +279,7 @@ def test_uninformed_circuit_diagram_info():
args = protocols.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT

assert cirq.circuit_diagram_info(c_op, args) == cirq.CircuitDiagramInfo(
wire_symbols=('@', 'MOCK'), exponent=1, connected=True
wire_symbols=('@', 'M1', 'M2'), exponent=1, connected=True, exponent_qubit_index=1
)
assert mock_gate.captured_diagram_args == args

Expand Down
15 changes: 15 additions & 0 deletions cirq-core/cirq/ops/gate_operation.py
Expand Up @@ -19,6 +19,7 @@
AbstractSet,
Any,
cast,
Collection,
Dict,
FrozenSet,
Iterable,
Expand Down Expand Up @@ -351,5 +352,19 @@ def _equal_up_to_global_phase_(
return False
return protocols.equal_up_to_global_phase(self.gate, other.gate, atol=atol)

def controlled_by(
self,
*control_qubits: 'cirq.Qid',
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
) -> 'cirq.Operation':
if len(control_qubits) == 0:
return self
qubits = tuple(control_qubits)
return self._gate.controlled(
num_controls=len(qubits),
control_values=control_values,
control_qid_shape=tuple(q.dimension for q in qubits),
).on(*(qubits + self._qubits))


TV = TypeVar('TV', bound=raw_types.Gate)
8 changes: 4 additions & 4 deletions cirq-core/cirq/ops/raw_types_test.py
Expand Up @@ -127,13 +127,13 @@ def test_gate():


def test_op():
a, b, c = cirq.LineQubit.range(3)
a, b, c, d = cirq.LineQubit.range(4)
g = ValiGate()
op = g(a)
op = g(a, b)
assert op.controlled_by() is op
controlled_op = op.controlled_by(b, c)
controlled_op = op.controlled_by(c, d)
assert controlled_op.sub_operation == op
assert controlled_op.controls == (b, c)
assert controlled_op.controls == (c, d)


def test_op_validate():
Expand Down
85 changes: 84 additions & 1 deletion cirq-core/cirq/ops/three_qubit_gates.py
Expand Up @@ -14,7 +14,17 @@

"""Common quantum gates that target three qubits."""

from typing import AbstractSet, Any, List, Optional, Tuple, TYPE_CHECKING
from typing import (
AbstractSet,
Any,
Collection,
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np
import sympy
Expand All @@ -30,6 +40,7 @@
pauli_gates,
raw_types,
swap_gates,
raw_types,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -169,6 +180,31 @@ def __str__(self) -> str:
def _num_qubits_(self) -> int:
return 3

def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `ZPowGate` with two additional controls.
The `controlled` method of the `Gate` class, of which this class is a
child, returns a `ControlledGate` with `sub_gate = self`. This method
overrides this behavior to return a `ControlledGate` with
`sub_gate = ZPowGate`.
"""
if num_controls == 0:
return self
return controlled_gate.ControlledGate(
controlled_gate.ControlledGate(
common_gates.ZPowGate(exponent=self._exponent, global_shift=self._global_shift),
num_controls=2,
),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)


@value.value_equality()
class ThreeQubitDiagonalGate(raw_types.Gate):
Expand Down Expand Up @@ -432,6 +468,31 @@ def __str__(self) -> str:
def _num_qubits_(self) -> int:
return 3

def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `XPowGate` with two additional controls.
The `controlled` method of the `Gate` class, of which this class is a
child, returns a `ControlledGate` with `sub_gate = self`. This method
overrides this behavior to return a `ControlledGate` with
`sub_gate = XPowGate`.
"""
if num_controls == 0:
return self
return controlled_gate.ControlledGate(
controlled_gate.ControlledGate(
common_gates.XPowGate(exponent=self._exponent, global_shift=self._global_shift),
num_controls=2,
),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)


@value.value_equality()
class CSwapGate(gate_features.InterchangeableQubitsGate, raw_types.Gate):
Expand Down Expand Up @@ -580,6 +641,28 @@ def __repr__(self) -> str:
def _num_qubits_(self) -> int:
return 3

def controlled(
self,
num_controls: int = None,
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
"""Returns a controlled `SWAP` with one additional control.
The `controlled` method of the `Gate` class, of which this class is a
child, returns a `ControlledGate` with `sub_gate = self`. This method
overrides this behavior to return a `ControlledGate` with
`sub_gate = SWAP`.
"""
if num_controls == 0:
return self
return controlled_gate.ControlledGate(
controlled_gate.ControlledGate(swap_gates.SWAP, num_controls=1),
num_controls=num_controls,
control_values=control_values,
control_qid_shape=control_qid_shape,
)


CCZ = CCZPowGate()
document(
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/ops/three_qubit_gates_test.py
Expand Up @@ -39,6 +39,8 @@ def test_eigen_gates_consistent_protocols(eigen_gate_type):
(cirq.CSWAP, False),
(cirq.ThreeQubitDiagonalGate([2, 3, 5, 7, 11, 13, 17, 19]), True),
(cirq.ThreeQubitDiagonalGate([0, 0, 0, 0, 0, 0, 0, 0]), True),
(cirq.CCX, False),
(cirq.CCZ, False),
),
)
def test_consistent_protocols(gate, ignoring_global_phase):
Expand Down Expand Up @@ -320,3 +322,10 @@ def test_resolve(resolve_fn):
diagonal_gate = resolve_fn(diagonal_gate, {'b': 19})
assert diagonal_gate == cirq.ThreeQubitDiagonalGate(diagonal_angles)
assert not cirq.is_parameterized(diagonal_gate)


@pytest.mark.parametrize('gate', [cirq.CCX, cirq.CCZ, cirq.CSWAP])
def test_controlled_ops_consistency(gate):
a, b, c, d = cirq.LineQubit.range(4)
assert gate.controlled(0) is gate
assert gate(a, b, c).controlled_by(d) == gate(d, b, c).controlled_by(a)
4 changes: 4 additions & 0 deletions cirq-core/cirq/testing/__init__.py
Expand Up @@ -36,6 +36,10 @@
assert_phase_by_is_consistent_with_unitary,
)

from cirq.testing.consistent_controlled_gate_op import (
assert_controlled_and_controlled_by_identical,
)

from cirq.testing.consistent_decomposition import (
assert_decompose_is_consistent_with_unitary,
)
Expand Down
56 changes: 56 additions & 0 deletions cirq-core/cirq/testing/consistent_controlled_gate_op.py
@@ -0,0 +1,56 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Optional, Union, Collection

from cirq import protocols, devices, ops


def assert_controlled_and_controlled_by_identical(
gate: ops.Gate,
*,
num_controls: Sequence[int] = (2, 1, 3, 10),
control_values: Optional[Sequence[Optional[Sequence[Union[int, Collection[int]]]]]] = None,
) -> None:
"""Checks that gate.on().controlled_by() == gate.controlled().on()"""
if control_values is not None:
if len(num_controls) != len(control_values):
raise ValueError(f"len(num_controls) != len(control_values)")
for i, num_control in enumerate(num_controls):
control_value = control_values[i] if control_values else None
if control_value is not None and len(control_value) != num_control:
raise ValueError(f"len(control_values[{i}]) != num_controls[{i}]")
_assert_gate_consistent(gate, num_control, control_value)


def _assert_gate_consistent(
gate: ops.Gate,
num_controls: int,
control_values: Optional[Sequence[Union[int, Collection[int]]]],
) -> None:
if isinstance(gate, ops.DensePauliString) and protocols.is_parameterized(gate):
# Parameterized `DensePauliString`s cannot be applied to qubits to produce valid operations.
# TODO: This behavior should be fixed (https://github.com/quantumlib/Cirq/issues/4508)
return None
gate_controlled = gate.controlled(num_controls, control_values)
qubits = devices.LineQid.for_gate(gate_controlled)
control_qubits = qubits[:num_controls]
gate_qubits = qubits[num_controls:]
gate_controlled_on = gate_controlled.on(*control_qubits, *gate_qubits)
gate_on_controlled_by = gate.on(*gate_qubits).controlled_by(
*control_qubits, control_values=control_values
)
assert (
gate_controlled_on == gate_on_controlled_by
), "gate.controlled().on() and gate.on().controlled() should return the same operations."

0 comments on commit f48efe0

Please sign in to comment.