Skip to content

Commit

Permalink
Measureable channels and mixtures (#4194)
Browse files Browse the repository at this point in the history
_Partially_ addresses #3241; qudit and non-square operator support is not part of this PR. Design is presented in [this RFC](https://tinyurl.com/cirq-custom-channel).

`KrausChannel` and `MatrixMixture` both serve two purposes:
- Provide a base type for users to create their own noisy channels (and mixtures) without creating a new type
- Allow channels (and mixtures) to capture the selected operator index in a measurement result

The changes to `act_on_state_vector_args.py` enable the second item; everything else in this PR was already possible in Cirq, but previously required users to define their own classes to make use of it.
  • Loading branch information
95-martin-orion committed Aug 4, 2021
1 parent 409a412 commit f7b882c
Show file tree
Hide file tree
Showing 14 changed files with 895 additions and 11 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -217,9 +217,11 @@
InterchangeableQubitsGate,
ISWAP,
ISwapPowGate,
KrausChannel,
LinearCombinationOfGates,
LinearCombinationOfOperations,
MatrixGate,
MixedUnitaryChannel,
measure,
measure_each,
MeasurementGate,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/json_resolver_cache.py
Expand Up @@ -89,10 +89,12 @@ def two_qubit_matrix_gate(matrix):
'IdentityGate': cirq.IdentityGate,
'IdentityOperation': _identity_operation_from_dict,
'InitObsSetting': cirq.work.InitObsSetting,
'KrausChannel': cirq.KrausChannel,
'LinearDict': cirq.LinearDict,
'LineQubit': cirq.LineQubit,
'LineQid': cirq.LineQid,
'MatrixGate': cirq.MatrixGate,
'MixedUnitaryChannel': cirq.MixedUnitaryChannel,
'MeasurementKey': cirq.MeasurementKey,
'MeasurementGate': cirq.MeasurementGate,
'_MeasurementSpec': cirq.work._MeasurementSpec,
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Expand Up @@ -116,13 +116,21 @@
GlobalPhaseOperation,
)

from cirq.ops.kraus_channel import (
KrausChannel,
)

from cirq.ops.linear_combinations import (
LinearCombinationOfGates,
LinearCombinationOfOperations,
PauliSum,
PauliSumLike,
)

from cirq.ops.mixed_unitary_channel import (
MixedUnitaryChannel,
)

from cirq.ops.pauli_sum_exponential import (
PauliSumExponential,
)
Expand Down
107 changes: 107 additions & 0 deletions cirq-core/cirq/ops/kraus_channel.py
@@ -0,0 +1,107 @@
from typing import Any, Dict, Iterable, Tuple, Union
import numpy as np

from cirq import linalg, protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types


# TODO(#3241): support qudits and non-square operators.
class KrausChannel(raw_types.Gate):
"""A generic channel that can record the index of its selected operator.
Args:
kraus_ops: a list of Kraus operators, formatted as numpy array.
Currently, only square-matrix operators on qubits (not qudits) are
supported by this type.
key: an optional measurement key string for this channel. Simulations
which select a single Kraus operator to apply will store the index
of that operator in the measurement result list with this key.
validate: if True, validate that `kraus_ops` describe a valid channel.
This validation can be slow; prefer pre-validating if possible.
"""

def __init__(
self,
kraus_ops: Iterable[np.ndarray],
key: Union[str, value.MeasurementKey, None] = None,
validate: bool = False,
):
kraus_ops = list(kraus_ops)
if not kraus_ops:
raise ValueError('KrausChannel must have at least one operation.')
num_qubits = np.log2(kraus_ops[0].shape[0])
if not num_qubits.is_integer() or kraus_ops[0].shape[1] != kraus_ops[0].shape[0]:
raise ValueError(
f'Input Kraus ops of shape {kraus_ops[0].shape} does not '
'represent a square operator over qubits.'
)
self._num_qubits = int(num_qubits)
for i, op in enumerate(kraus_ops):
if not op.shape == kraus_ops[0].shape:
raise ValueError(
'Inconsistent Kraus operator shapes: '
f'op[0]: {kraus_ops[0].shape}, op[{i}]: {op.shape}'
)
if validate and not linalg.is_cptp(kraus_ops=kraus_ops):
raise ValueError('Kraus operators do not describe a CPTP map.')
self._kraus_ops = kraus_ops
if not isinstance(key, value.MeasurementKey) and key is not None:
key = value.MeasurementKey(key)
self._key = key

@staticmethod
def from_channel(
channel: 'protocols.SupportsChannel', key: Union[str, value.MeasurementKey, None] = None
):
"""Creates a copy of a channel with the given measurement key."""
return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key)

def __eq__(self, other) -> bool:
# TODO(#3241): provide a protocol to test equivalence between channels,
# ignoring measurement keys and channel/mixture distinction
if not isinstance(other, KrausChannel):
return NotImplemented
if self._key != other._key:
return False
return np.allclose(self._kraus_ops, other._kraus_ops)

def num_qubits(self) -> int:
return self._num_qubits

def _kraus_(self):
return self._kraus_ops

def _measurement_key_(self):
if self._key is None:
return NotImplemented
return self._key

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
if self._key is None:
return NotImplemented
if self._key not in key_map:
return self
return KrausChannel(kraus_ops=self._kraus_ops, key=key_map[str(self._key)])

def _with_key_path_(self, path: Tuple[str, ...]):
return KrausChannel(kraus_ops=self._kraus_ops, key=protocols.with_key_path(self._key, path))

def __str__(self):
if self._key is not None:
return f'KrausChannel({self._kraus_ops}, key={self._key})'
return f'KrausChannel({self._kraus_ops})'

def __repr__(self):
args = ['kraus_ops=[' + ', '.join(proper_repr(op) for op in self._kraus_ops) + ']']
if self._key is not None:
args.append(f'key=\'{self._key}\'')
return f'cirq.KrausChannel({", ".join(args)})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['_kraus_ops', '_key'])

@classmethod
def _from_json_dict_(cls, _kraus_ops, _key, **kwargs):
ops = [np.asarray(op) for op in _kraus_ops]
return cls(kraus_ops=ops, key=_key)
154 changes: 154 additions & 0 deletions cirq-core/cirq/ops/kraus_channel_test.py
@@ -0,0 +1,154 @@
import cirq
import numpy as np
import pytest


def test_kraus_channel_from_channel():
q0 = cirq.LineQubit(0)
dp = cirq.depolarize(0.1)
kc = cirq.KrausChannel.from_channel(dp, key='dp')
assert cirq.measurement_key(kc) == 'dp'

circuit = cirq.Circuit(kc.on(q0))
sim = cirq.Simulator(seed=0)

results = sim.simulate(circuit)
assert 'dp' in results.measurements
# The depolarizing channel has four Kraus operators.
assert results.measurements['dp'] in range(4)


def test_kraus_channel_equality():
dp_pt1 = cirq.depolarize(0.1)
dp_pt2 = cirq.depolarize(0.2)
kc_a1 = cirq.KrausChannel.from_channel(dp_pt1, key='a')
kc_a2 = cirq.KrausChannel.from_channel(dp_pt2, key='a')
kc_b1 = cirq.KrausChannel.from_channel(dp_pt1, key='b')

# Even if their effect is the same, KrausChannels are not treated as equal
# to other channels defined in Cirq.
assert kc_a1 != dp_pt1
assert kc_a1 != kc_a2
assert kc_a1 != kc_b1
assert kc_a2 != kc_b1

ops = [
np.array([[1, 0], [0, 0]]),
np.array([[0, 0], [0, 1]]),
]
x_meas = cirq.KrausChannel(ops)
ops_inv = list(reversed(ops))
x_meas_inv = cirq.KrausChannel(ops_inv)
# Even though these have the same effect on the circuit, their measurement
# behavior differs, so they are considered non-equal.
assert x_meas != x_meas_inv


def test_kraus_channel_remap_keys():
dp = cirq.depolarize(0.1)
kc = cirq.KrausChannel.from_channel(dp)
with pytest.raises(TypeError):
_ = cirq.measurement_key(kc)
assert cirq.with_measurement_key_mapping(kc, {'a': 'b'}) is NotImplemented

kc_x = cirq.KrausChannel.from_channel(dp, key='x')
assert cirq.with_measurement_key_mapping(kc_x, {'a': 'b'}) is kc_x
assert cirq.measurement_key(cirq.with_key_path(kc_x, ('path',))) == 'path:x'

kc_a = cirq.KrausChannel.from_channel(dp, key='a')
kc_b = cirq.KrausChannel.from_channel(dp, key='b')
assert kc_a != kc_b
assert cirq.with_measurement_key_mapping(kc_a, {'a': 'b'}) == kc_b


def test_kraus_channel_from_kraus():
q0 = cirq.LineQubit(0)
# This is equivalent to an X-basis measurement.
ops = [
np.array([[1, 1], [1, 1]]) * 0.5,
np.array([[1, -1], [-1, 1]]) * 0.5,
]
x_meas = cirq.KrausChannel(ops, key='x_meas')
assert cirq.measurement_key(x_meas) == 'x_meas'

circuit = cirq.Circuit(cirq.H(q0), x_meas.on(q0))
sim = cirq.Simulator(seed=0)

results = sim.simulate(circuit)
assert 'x_meas' in results.measurements
assert results.measurements['x_meas'] == 0


def test_kraus_channel_str():
# This is equivalent to an X-basis measurement.
ops = [
np.array([[1, 1], [1, 1]]) * 0.5,
np.array([[1, -1], [-1, 1]]) * 0.5,
]
x_meas = cirq.KrausChannel(ops)
assert (
str(x_meas)
== """KrausChannel([array([[0.5, 0.5],
[0.5, 0.5]]), array([[ 0.5, -0.5],
[-0.5, 0.5]])])"""
)
x_meas_keyed = cirq.KrausChannel(ops, key='x_meas')
assert (
str(x_meas_keyed)
== """KrausChannel([array([[0.5, 0.5],
[0.5, 0.5]]), array([[ 0.5, -0.5],
[-0.5, 0.5]])], key=x_meas)"""
)


def test_kraus_channel_repr():
# This is equivalent to an X-basis measurement.
ops = [
np.array([[1, 1], [1, 1]], dtype=np.complex64) * 0.5,
np.array([[1, -1], [-1, 1]], dtype=np.complex64) * 0.5,
]
x_meas = cirq.KrausChannel(ops, key='x_meas')
assert (
repr(x_meas)
== """\
cirq.KrausChannel(kraus_ops=[\
np.array([[(0.5+0j), (0.5+0j)], [(0.5+0j), (0.5+0j)]], dtype=np.complex64), \
np.array([[(0.5+0j), (-0.5+0j)], [(-0.5+0j), (0.5+0j)]], dtype=np.complex64)], \
key='x_meas')"""
)


def test_empty_ops_fails():
ops = []

with pytest.raises(ValueError, match='must have at least one operation'):
_ = cirq.KrausChannel(kraus_ops=ops, key='m')


def test_ops_mismatch_fails():
op2 = np.zeros((4, 4))
op2[1][1] = 1
ops = [np.array([[1, 0], [0, 0]]), op2]

with pytest.raises(ValueError, match='Inconsistent Kraus operator shapes'):
_ = cirq.KrausChannel(kraus_ops=ops, key='m')


def test_nonqubit_kraus_ops_fails():
ops = [
np.array([[1, 0, 0], [0, 0, 0]]),
np.array([[0, 0, 0], [0, 1, 0]]),
]

with pytest.raises(ValueError, match='Input Kraus ops'):
_ = cirq.KrausChannel(kraus_ops=ops, key='m')


def test_validate():
# Not quite CPTP.
ops = [
np.array([[1, 0], [0, 0]]),
np.array([[0, 0], [0, 0.9]]),
]
with pytest.raises(ValueError, match='CPTP map'):
_ = cirq.KrausChannel(kraus_ops=ops, key='m', validate=True)

0 comments on commit f7b882c

Please sign in to comment.