Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _apply_channel_ optimizations for reset and confusion #5917

Merged
merged 17 commits into from
Oct 13, 2022
79 changes: 78 additions & 1 deletion cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

"""Utility methods for transforming matrices or vectors."""

from typing import Tuple, Optional, Sequence, List, Union
import dataclasses
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -168,6 +169,82 @@ def targeted_left_multiply(
) # type: ignore


@dataclasses.dataclass
class _SliceConfig:
axis: int
source_index: int
target_index: int


@dataclasses.dataclass
class _BuildFromSlicesArgs:
slices: Tuple[_SliceConfig, ...]
scale: complex


def _build_from_slices(
args: Sequence[_BuildFromSlicesArgs], source: np.ndarray, out: np.ndarray
) -> np.ndarray:
"""Populates `out` from the desired slices of `source`.

This function is best described by example.

For instance in 3*3*3 3D space, one could take a cube array, take all the horizontal slices,
and add them up into the top slice leaving everything else zero. If the vertical axis was 1,
and the top was index=2, then this would be written as follows:

_build_from_slices(
[
_BuildFromSlicesArgs((_SliceConfig(axis=1, source_index=0, target_index=2),), 1),
_BuildFromSlicesArgs((_SliceConfig(axis=1, source_index=1, target_index=2),), 1),
_BuildFromSlicesArgs((_SliceConfig(axis=1, source_index=2, target_index=2),), 1),
],
source,
out,
)

When multiple slices are included in the _BuildFromSlicesArgs, this means to take the
intersection of the source space and move it to the intersection of the target space. For
example, the following takes the bottom-left edge and moves it to the top-right, leaving all
other cells zero. Assume the lateral axis is 2 and right-most index thereof is 2:

_build_from_slices(
[
_BuildFromSlicesArgs(
(
_SliceConfig(axis=1, source_index=0, target_index=2), # top
_SliceConfig(axis=2, source_index=0, target_index=2), # right
),
scale=1,
),
],
source,
out,
)

This function is useful for optimizing multiplying a state by one or more one-hot matrices,
as is common when working with Kraus components. It is more efficient than using an einsum.

Args:
args: The list of slice configurations to sum up into the output.
source: The source tensor for the slice data.
out: An output tensor that is the same shape as the source.

Returns:
The output tensor.
"""
d = len(source.shape)
out[...] = 0
for arg in args:
source_slice: List[Any] = [slice(None)] * d
target_slice: List[Any] = [slice(None)] * d
for sleis in arg.slices:
source_slice[sleis.axis] = sleis.source_index
target_slice[sleis.axis] = sleis.target_index
out[tuple(target_slice)] += arg.scale * source[tuple(source_slice)]
return out


def targeted_conjugate_about(
tensor: np.ndarray,
target: np.ndarray,
Expand Down
32 changes: 31 additions & 1 deletion cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import numpy as np

from cirq import protocols, value
from cirq.linalg import transformations
from cirq.ops import raw_types, common_gates, pauli_gates, identity


if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -734,6 +734,19 @@ def _kraus_(self) -> Iterable[np.ndarray]:
channel[:, 0, :] = np.eye(self._dimension)
return channel

def _apply_channel_(self, args: 'cirq.ApplyChannelArgs'):
configs = []
for i in range(self._dimension):
s1 = transformations._SliceConfig(
axis=args.left_axes[0], source_index=i, target_index=0
)
s2 = transformations._SliceConfig(
axis=args.right_axes[0], source_index=i, target_index=0
)
configs.append(transformations._BuildFromSlicesArgs(slices=(s1, s2), scale=1))
transformations._build_from_slices(configs, args.target_tensor, out=args.out_buffer)
return args.out_buffer

def _has_kraus_(self) -> bool:
return True

Expand Down Expand Up @@ -816,6 +829,23 @@ def __init__(self, gamma: float) -> None:
def _num_qubits_(self) -> int:
return 1

def _apply_channel_(self, args: 'cirq.ApplyChannelArgs'):
if self._gamma == 0:
return args.target_tensor
if self._gamma != 1:
return NotImplemented
configs = []
for i in range(2):
s1 = transformations._SliceConfig(
axis=args.left_axes[0], source_index=i, target_index=i
)
s2 = transformations._SliceConfig(
axis=args.right_axes[0], source_index=i, target_index=i
)
configs.append(transformations._BuildFromSlicesArgs(slices=(s1, s2), scale=1))
transformations._build_from_slices(configs, args.target_tensor, out=args.out_buffer)
return args.out_buffer

def _kraus_(self) -> Iterable[np.ndarray]:
return (
np.array([[1.0, 0.0], [0.0, np.sqrt(1.0 - self._gamma)]]),
Expand Down
16 changes: 16 additions & 0 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,13 @@ def test_reset_each():
assert op.qubits == (qubits[i],)


def test_reset_consistency():
two_d_chan = cirq.ResetChannel()
cirq.testing.assert_has_consistent_apply_channel(two_d_chan)
three_d_chan = cirq.ResetChannel(dimension=3)
cirq.testing.assert_has_consistent_apply_channel(three_d_chan)


def test_phase_damping_channel():
d = cirq.phase_damp(0.3)
np.testing.assert_almost_equal(
Expand Down Expand Up @@ -585,6 +592,15 @@ def test_phase_damping_channel_text_diagram():
)


def test_phase_damp_consistency():
full_damp = cirq.PhaseDampingChannel(gamma=1)
cirq.testing.assert_has_consistent_apply_channel(full_damp)
partial_damp = cirq.PhaseDampingChannel(gamma=0.5)
cirq.testing.assert_has_consistent_apply_channel(partial_damp)
no_damp = cirq.PhaseDampingChannel(gamma=0)
cirq.testing.assert_has_consistent_apply_channel(no_damp)


def test_phase_flip_channel():
d = cirq.phase_flip(0.3)
np.testing.assert_almost_equal(
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ def _mixture_(self) -> Sequence[Tuple[float, Any]]:
return getter()
return NotImplemented

def _apply_channel_(
self, args: 'protocols.ApplyChannelArgs'
) -> Union[np.ndarray, None, NotImplementedType]:
getter = getattr(self.gate, '_apply_channel_', None)
if getter is not None:
return getter(args)
return NotImplemented

def _has_kraus_(self) -> bool:
getter = getattr(self.gate, '_has_kraus_', None)
if getter is not None:
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from cirq.testing.circuit_compare import (
assert_circuits_with_terminal_measurements_are_equivalent,
assert_circuits_have_same_unitary_given_final_permutation,
assert_has_consistent_apply_channel,
assert_has_consistent_apply_unitary,
assert_has_consistent_apply_unitary_for_various_exponents,
assert_has_diagram,
Expand Down
44 changes: 44 additions & 0 deletions cirq-core/cirq/testing/circuit_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,50 @@ def assert_has_consistent_apply_unitary(val: Any, *, atol: float = 1e-8) -> None
np.testing.assert_allclose(actual.reshape(n, n), expected, atol=atol)


def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None:
"""Tests whether a value's _apply_channel_ is correct.

Contrasts the effects of the value's `_apply_channel_` with the superoperator calculated from
the Kraus components returned by the value's `_kraus_` method.

Args:
val: The value under test. Should have a `__pow__` method.
atol: Absolute error tolerance.
"""
# pylint: disable=unused-variable
__tracebackhide__ = True
# pylint: enable=unused-variable

kraus = protocols.kraus(val, default=None)
expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None

qid_shape = protocols.qid_shape(val)

eye = qis.eye_tensor(qid_shape * 2, dtype=np.complex128)
actual = protocols.apply_channel(
val=val,
args=protocols.ApplyChannelArgs(
target_tensor=eye,
out_buffer=np.ones_like(eye) * float('nan'),
auxiliary_buffer0=np.ones_like(eye) * float('nan'),
auxiliary_buffer1=np.ones_like(eye) * float('nan'),
left_axes=list(range(len(qid_shape))),
right_axes=list(range(len(qid_shape), len(qid_shape) * 2)),
),
default=None,
)

# If you don't have a Kraus, you shouldn't be able to apply a channel.
if expected is None:
assert actual is None

# If you applied a channel, it should match the superoperator you say you have.
if actual is not None:
assert expected is not None
n = np.product(qid_shape) ** 2
np.testing.assert_allclose(actual.reshape((n, n)), expected, atol=atol)


def _assert_apply_unitary_works_when_axes_transposed(val: Any, *, atol: float = 1e-8) -> None:
"""Tests whether a value's _apply_unitary_ handles out-of-order axes.

Expand Down
66 changes: 66 additions & 0 deletions cirq-core/cirq/testing/circuit_compare_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,72 @@ def test_assert_has_diagram():
assert expected_error in ex_info.value.args[0]


def test_assert_has_consistent_apply_channel():
class Correct:
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
args.target_tensor[...] = 0
return args.target_tensor

def _kraus_(self):
return [np.array([[0, 0], [0, 0]])]

def _num_qubits_(self):
return 1

cirq.testing.assert_has_consistent_apply_channel(Correct())

class Wrong:
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
args.target_tensor[...] = 0
return args.target_tensor

def _kraus_(self):
return [np.array([[1, 0], [0, 0]])]

def _num_qubits_(self):
return 1

with pytest.raises(AssertionError):
cirq.testing.assert_has_consistent_apply_channel(Wrong())

class NoNothing:
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
return NotImplemented

def _kraus_(self):
return NotImplemented

def _num_qubits_(self):
return 1

cirq.testing.assert_has_consistent_apply_channel(NoNothing())

class NoKraus:
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
return args.target_tensor

def _kraus_(self):
return NotImplemented

def _num_qubits_(self):
return 1

with pytest.raises(AssertionError):
cirq.testing.assert_has_consistent_apply_channel(NoKraus())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: I think it'd be nice to make this case work by having protocol.kraus compute the Kraus representation when _apply_channel_ is provided but _kraus_ isn't. Probably a good idea for a follow-up rather than this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added #5921


class NoApply:
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
return NotImplemented

def _kraus_(self):
return [np.array([[0, 0], [0, 0]])]

def _num_qubits_(self):
return 1

cirq.testing.assert_has_consistent_apply_channel(NoApply())


def test_assert_has_consistent_apply_unitary():
class IdentityReturningUnalteredWorkspace:
def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
Expand Down
27 changes: 27 additions & 0 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

from cirq import linalg, ops, protocols, value
from cirq.linalg import transformations
from cirq.transformers import transformer_api, transformer_primitives
from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements

Expand Down Expand Up @@ -341,6 +342,7 @@ def __init__(self, confusion_map: np.ndarray, shape: Sequence[int]):
if not linalg.is_cptp(kraus_ops=kraus):
raise ValueError('Confusion map has invalid probabilities.')
self._shape = tuple(shape)
self._confusion_map = confusion_map.copy()
self._kraus = tuple(kraus)

def _qid_shape_(self) -> Tuple[int, ...]:
Expand All @@ -349,6 +351,31 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def _kraus_(self) -> Tuple[np.ndarray, ...]:
return self._kraus

def _apply_channel_(self, args: 'cirq.ApplyChannelArgs'):
configs = []
for i in range(np.prod(self._shape) ** 2):
scale = self._confusion_map.flat[i]
if scale == 0:
continue
index: Any = np.unravel_index(i, self._shape * 2)
slices = []
axis_count = len(args.left_axes)
for j in range(axis_count):
s1 = transformations._SliceConfig(
axis=args.left_axes[j],
source_index=index[j],
target_index=index[j + axis_count],
)
s2 = transformations._SliceConfig(
axis=args.right_axes[j],
source_index=index[j],
target_index=index[j + axis_count],
)
slices.extend([s1, s2])
configs.append(transformations._BuildFromSlicesArgs(slices=tuple(slices), scale=scale))
transformations._build_from_slices(configs, args.target_tensor, out=args.out_buffer)
return args.out_buffer


@value.value_equality
class _ModAdd(ops.ArithmeticGate):
Expand Down
16 changes: 15 additions & 1 deletion cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sympy

import cirq
from cirq.transformers.measurement_transformers import _mod_add, _MeasurementQid
from cirq.transformers.measurement_transformers import _ConfusionChannel, _MeasurementQid, _mod_add


def assert_equivalent_to_deferred(circuit: cirq.Circuit):
Expand Down Expand Up @@ -575,3 +575,17 @@ def test_drop_terminal_nonterminal_error():

with pytest.raises(ValueError, match='Context has `deep=False`'):
_ = cirq.drop_terminal_measurements(circuit, context=None)


def test_confusion_channel_consistency():
two_d_chan = _ConfusionChannel(np.array([[0.5, 0.5], [0.4, 0.6]]), shape=(2,))
cirq.testing.assert_has_consistent_apply_channel(two_d_chan)
three_d_chan = _ConfusionChannel(
np.array([[0.5, 0.3, 0.2], [0.4, 0.5, 0.1], [0, 0, 1]]), shape=(3,)
)
cirq.testing.assert_has_consistent_apply_channel(three_d_chan)
two_q_chan = _ConfusionChannel(
np.array([[0.5, 0.3, 0.1, 0.1], [0.4, 0.5, 0.1, 0], [0, 0, 1, 0], [0, 0, 0.5, 0.5]]),
shape=(2, 2),
)
cirq.testing.assert_has_consistent_apply_channel(two_q_chan)