Skip to content

Commit

Permalink
Fix a bug in Depol channel with multi qubits (#3715)
Browse files Browse the repository at this point in the history
Fixes #3685
  • Loading branch information
tonybruguier committed Feb 2, 2021
1 parent c626db0 commit 2a121ec
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 41 deletions.
9 changes: 6 additions & 3 deletions cirq/ops/common_channels.py
Expand Up @@ -15,7 +15,7 @@
"""Quantum channels that are commonly used in the literature."""

import itertools
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, TYPE_CHECKING
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -230,8 +230,8 @@ def asymmetric_depolarize(


@value.value_equality
class DepolarizingChannel(gate_features.SingleQubitGate):
"""A channel that depolarizes a qubit."""
class DepolarizingChannel(gate_features.SupportsOnEachGate, raw_types.Gate):
"""A channel that depolarizes one or several qubits."""

def __init__(self, p: float, n_qubits: int = 1) -> None:
r"""The symmetric depolarizing channel.
Expand Down Expand Up @@ -278,6 +278,9 @@ def __init__(self, p: float, n_qubits: int = 1) -> None:

self._delegate = AsymmetricDepolarizingChannel(error_probabilities=error_probabilities)

def _qid_shape_(self):
return (2,) * self._n_qubits

def _mixture_(self) -> Sequence[Tuple[float, np.ndarray]]:
return self._delegate._mixture_()

Expand Down
30 changes: 30 additions & 0 deletions cirq/ops/common_channels_test.py
Expand Up @@ -213,6 +213,36 @@ def test_depolarizing_channel_str_two_qubits():
assert str(cirq.depolarize(0.3, n_qubits=2)) == 'depolarize(p=0.3,n_qubits=2)'


def test_deprecated_on_each_for_depolarizing_channel_one_qubit():
q0 = cirq.LineQubit.range(1)
op = cirq.DepolarizingChannel(p=0.1, n_qubits=1)

op.on_each(q0)
op.on_each([q0])
with pytest.raises(ValueError, match="Gate was called with type different than Qid"):
op.on_each('bogus object')


def test_deprecated_on_each_for_depolarizing_channel_two_qubits():
q0, q1 = cirq.LineQubit.range(2)
op = cirq.DepolarizingChannel(p=0.1, n_qubits=2)

with pytest.raises(ValueError, match="one qubit"):
op.on_each(q0, q1)


def test_depolarizing_channel_apply_two_qubits():
q0, q1 = cirq.LineQubit.range(2)
op = cirq.DepolarizingChannel(p=0.1, n_qubits=2)
op(q0, q1)


def test_asymmetric_depolarizing_channel_apply_two_qubits():
q0, q1 = cirq.LineQubit.range(2)
op = cirq.AsymmetricDepolarizingChannel(error_probabilities={'XX': 0.1})
op(q0, q1)


def test_depolarizing_channel_eq():
et = cirq.testing.EqualsTester()
c = cirq.depolarize(0.0)
Expand Down
17 changes: 12 additions & 5 deletions cirq/ops/gate_features.py
Expand Up @@ -31,11 +31,8 @@ def qubit_index_to_equivalence_group_key(self, index: int) -> int:
return 0


class SingleQubitGate(raw_types.Gate, metaclass=abc.ABCMeta):
"""A gate that must be applied to exactly one qubit."""

def _num_qubits_(self) -> int:
return 1
class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta):
"""A gate that can be applied to exactly one qubit."""

def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]:
"""Returns a list of operations applying the gate to all targets.
Expand All @@ -48,7 +45,10 @@ def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_typ
Raises:
ValueError if targets are not instances of Qid or List[Qid].
ValueError if the gate operates on two or more Qids.
"""
if self._num_qubits_() > 1:
raise ValueError('This gate only supports on_each when it is a one qubit gate.')
operations = [] # type: List[raw_types.Operation]
for target in targets:
if isinstance(target, raw_types.Qid):
Expand All @@ -62,6 +62,13 @@ def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_typ
return operations


class SingleQubitGate(SupportsOnEachGate, metaclass=abc.ABCMeta):
"""A gate that must be applied to exactly one qubit."""

def _num_qubits_(self) -> int:
return 1


class TwoQubitGate(raw_types.Gate, metaclass=abc.ABCMeta):
"""A gate that must be applied to exactly two qubits."""

Expand Down
35 changes: 3 additions & 32 deletions cirq/ops/identity.py
Expand Up @@ -13,21 +13,21 @@
# limitations under the License.
"""IdentityGate."""

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING
from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING

import numpy as np
import sympy

from cirq import protocols, value
from cirq._doc import document
from cirq.ops import raw_types
from cirq.ops import gate_features, raw_types

if TYPE_CHECKING:
import cirq


@value.value_equality
class IdentityGate(raw_types.Gate):
class IdentityGate(gate_features.SupportsOnEachGate, raw_types.Gate):
"""A Gate that perform no operation on qubits.
The unitary matrix of this gate is a diagonal matrix with all 1s on the
Expand Down Expand Up @@ -64,35 +64,6 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def num_qubits(self) -> int:
return len(self._qid_shape)

def on_each(self, *targets: Union['cirq.Qid', Iterable[Any]]) -> List['cirq.Operation']:
"""Returns a list of operations that applies the single qubit identity
to each of the targets.
Args:
*targets: The qubits to apply this gate to.
Returns:
Operations applying this gate to the target qubits.
Raises:
ValueError if targets are not instances of Qid or List[Qid] or
the gate from which this is applied is not a single qubit identity
gate.
"""
if len(self._qid_shape) != 1:
raise ValueError('IdentityGate only supports on_each when it is a one qubit gate.')
operations: List['cirq.Operation'] = []
for target in targets:
if isinstance(target, raw_types.Qid):
operations.append(self.on(target))
elif isinstance(target, Iterable) and not isinstance(target, str):
operations.extend(self.on_each(*target))
else:
raise ValueError(
'Gate was called with type different than Qid. Type: {}'.format(type(target))
)
return operations

def __pow__(self, power: Any) -> Any:
if isinstance(power, (int, float, complex, sympy.Basic)):
return self
Expand Down
5 changes: 4 additions & 1 deletion cirq/ops/pauli_string_test.py
Expand Up @@ -674,7 +674,10 @@ def test_pass_operations_over_single(shift: int, sign: int):
def test_pass_operations_over_double(shift: int, t_or_f1: bool, t_or_f2: bool, neg: bool):
sign = -1 if neg else +1
q0, q1, q2 = _make_qubits(3)
X, Y, Z = (cirq.Pauli.by_relative_index(pauli, shift) for pauli in (cirq.X, cirq.Y, cirq.Z))
X, Y, Z = (
cirq.Pauli.by_relative_index(cast(cirq.Pauli, pauli), shift)
for pauli in (cirq.X, cirq.Y, cirq.Z)
)

op0 = cirq.PauliInteractionGate(Z, t_or_f1, X, t_or_f2)(q0, q1)
ps_before = cirq.PauliString(qubit_pauli_map={q0: Z, q2: Y}, coefficient=sign)
Expand Down

0 comments on commit 2a121ec

Please sign in to comment.