Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Measureable channels and mixtures (#4194)
_Partially_ addresses #3241; qudit and non-square operator support is not part of this PR. Design is presented in [this RFC](https://tinyurl.com/cirq-custom-channel). `KrausChannel` and `MatrixMixture` both serve two purposes: - Provide a base type for users to create their own noisy channels (and mixtures) without creating a new type - Allow channels (and mixtures) to capture the selected operator index in a measurement result The changes to `act_on_state_vector_args.py` enable the second item; everything else in this PR was already possible in Cirq, but previously required users to define their own classes to make use of it.
- Loading branch information
1 parent
409a412
commit f7b882c
Showing
14 changed files
with
895 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Any, Dict, Iterable, Tuple, Union | ||
import numpy as np | ||
|
||
from cirq import linalg, protocols, value | ||
from cirq._compat import proper_repr | ||
from cirq.ops import raw_types | ||
|
||
|
||
# TODO(#3241): support qudits and non-square operators. | ||
class KrausChannel(raw_types.Gate): | ||
"""A generic channel that can record the index of its selected operator. | ||
Args: | ||
kraus_ops: a list of Kraus operators, formatted as numpy array. | ||
Currently, only square-matrix operators on qubits (not qudits) are | ||
supported by this type. | ||
key: an optional measurement key string for this channel. Simulations | ||
which select a single Kraus operator to apply will store the index | ||
of that operator in the measurement result list with this key. | ||
validate: if True, validate that `kraus_ops` describe a valid channel. | ||
This validation can be slow; prefer pre-validating if possible. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
kraus_ops: Iterable[np.ndarray], | ||
key: Union[str, value.MeasurementKey, None] = None, | ||
validate: bool = False, | ||
): | ||
kraus_ops = list(kraus_ops) | ||
if not kraus_ops: | ||
raise ValueError('KrausChannel must have at least one operation.') | ||
num_qubits = np.log2(kraus_ops[0].shape[0]) | ||
if not num_qubits.is_integer() or kraus_ops[0].shape[1] != kraus_ops[0].shape[0]: | ||
raise ValueError( | ||
f'Input Kraus ops of shape {kraus_ops[0].shape} does not ' | ||
'represent a square operator over qubits.' | ||
) | ||
self._num_qubits = int(num_qubits) | ||
for i, op in enumerate(kraus_ops): | ||
if not op.shape == kraus_ops[0].shape: | ||
raise ValueError( | ||
'Inconsistent Kraus operator shapes: ' | ||
f'op[0]: {kraus_ops[0].shape}, op[{i}]: {op.shape}' | ||
) | ||
if validate and not linalg.is_cptp(kraus_ops=kraus_ops): | ||
raise ValueError('Kraus operators do not describe a CPTP map.') | ||
self._kraus_ops = kraus_ops | ||
if not isinstance(key, value.MeasurementKey) and key is not None: | ||
key = value.MeasurementKey(key) | ||
self._key = key | ||
|
||
@staticmethod | ||
def from_channel( | ||
channel: 'protocols.SupportsChannel', key: Union[str, value.MeasurementKey, None] = None | ||
): | ||
"""Creates a copy of a channel with the given measurement key.""" | ||
return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key) | ||
|
||
def __eq__(self, other) -> bool: | ||
# TODO(#3241): provide a protocol to test equivalence between channels, | ||
# ignoring measurement keys and channel/mixture distinction | ||
if not isinstance(other, KrausChannel): | ||
return NotImplemented | ||
if self._key != other._key: | ||
return False | ||
return np.allclose(self._kraus_ops, other._kraus_ops) | ||
|
||
def num_qubits(self) -> int: | ||
return self._num_qubits | ||
|
||
def _kraus_(self): | ||
return self._kraus_ops | ||
|
||
def _measurement_key_(self): | ||
if self._key is None: | ||
return NotImplemented | ||
return self._key | ||
|
||
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): | ||
if self._key is None: | ||
return NotImplemented | ||
if self._key not in key_map: | ||
return self | ||
return KrausChannel(kraus_ops=self._kraus_ops, key=key_map[str(self._key)]) | ||
|
||
def _with_key_path_(self, path: Tuple[str, ...]): | ||
return KrausChannel(kraus_ops=self._kraus_ops, key=protocols.with_key_path(self._key, path)) | ||
|
||
def __str__(self): | ||
if self._key is not None: | ||
return f'KrausChannel({self._kraus_ops}, key={self._key})' | ||
return f'KrausChannel({self._kraus_ops})' | ||
|
||
def __repr__(self): | ||
args = ['kraus_ops=[' + ', '.join(proper_repr(op) for op in self._kraus_ops) + ']'] | ||
if self._key is not None: | ||
args.append(f'key=\'{self._key}\'') | ||
return f'cirq.KrausChannel({", ".join(args)})' | ||
|
||
def _json_dict_(self) -> Dict[str, Any]: | ||
return protocols.obj_to_dict_helper(self, ['_kraus_ops', '_key']) | ||
|
||
@classmethod | ||
def _from_json_dict_(cls, _kraus_ops, _key, **kwargs): | ||
ops = [np.asarray(op) for op in _kraus_ops] | ||
return cls(kraus_ops=ops, key=_key) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import cirq | ||
import numpy as np | ||
import pytest | ||
|
||
|
||
def test_kraus_channel_from_channel(): | ||
q0 = cirq.LineQubit(0) | ||
dp = cirq.depolarize(0.1) | ||
kc = cirq.KrausChannel.from_channel(dp, key='dp') | ||
assert cirq.measurement_key(kc) == 'dp' | ||
|
||
circuit = cirq.Circuit(kc.on(q0)) | ||
sim = cirq.Simulator(seed=0) | ||
|
||
results = sim.simulate(circuit) | ||
assert 'dp' in results.measurements | ||
# The depolarizing channel has four Kraus operators. | ||
assert results.measurements['dp'] in range(4) | ||
|
||
|
||
def test_kraus_channel_equality(): | ||
dp_pt1 = cirq.depolarize(0.1) | ||
dp_pt2 = cirq.depolarize(0.2) | ||
kc_a1 = cirq.KrausChannel.from_channel(dp_pt1, key='a') | ||
kc_a2 = cirq.KrausChannel.from_channel(dp_pt2, key='a') | ||
kc_b1 = cirq.KrausChannel.from_channel(dp_pt1, key='b') | ||
|
||
# Even if their effect is the same, KrausChannels are not treated as equal | ||
# to other channels defined in Cirq. | ||
assert kc_a1 != dp_pt1 | ||
assert kc_a1 != kc_a2 | ||
assert kc_a1 != kc_b1 | ||
assert kc_a2 != kc_b1 | ||
|
||
ops = [ | ||
np.array([[1, 0], [0, 0]]), | ||
np.array([[0, 0], [0, 1]]), | ||
] | ||
x_meas = cirq.KrausChannel(ops) | ||
ops_inv = list(reversed(ops)) | ||
x_meas_inv = cirq.KrausChannel(ops_inv) | ||
# Even though these have the same effect on the circuit, their measurement | ||
# behavior differs, so they are considered non-equal. | ||
assert x_meas != x_meas_inv | ||
|
||
|
||
def test_kraus_channel_remap_keys(): | ||
dp = cirq.depolarize(0.1) | ||
kc = cirq.KrausChannel.from_channel(dp) | ||
with pytest.raises(TypeError): | ||
_ = cirq.measurement_key(kc) | ||
assert cirq.with_measurement_key_mapping(kc, {'a': 'b'}) is NotImplemented | ||
|
||
kc_x = cirq.KrausChannel.from_channel(dp, key='x') | ||
assert cirq.with_measurement_key_mapping(kc_x, {'a': 'b'}) is kc_x | ||
assert cirq.measurement_key(cirq.with_key_path(kc_x, ('path',))) == 'path:x' | ||
|
||
kc_a = cirq.KrausChannel.from_channel(dp, key='a') | ||
kc_b = cirq.KrausChannel.from_channel(dp, key='b') | ||
assert kc_a != kc_b | ||
assert cirq.with_measurement_key_mapping(kc_a, {'a': 'b'}) == kc_b | ||
|
||
|
||
def test_kraus_channel_from_kraus(): | ||
q0 = cirq.LineQubit(0) | ||
# This is equivalent to an X-basis measurement. | ||
ops = [ | ||
np.array([[1, 1], [1, 1]]) * 0.5, | ||
np.array([[1, -1], [-1, 1]]) * 0.5, | ||
] | ||
x_meas = cirq.KrausChannel(ops, key='x_meas') | ||
assert cirq.measurement_key(x_meas) == 'x_meas' | ||
|
||
circuit = cirq.Circuit(cirq.H(q0), x_meas.on(q0)) | ||
sim = cirq.Simulator(seed=0) | ||
|
||
results = sim.simulate(circuit) | ||
assert 'x_meas' in results.measurements | ||
assert results.measurements['x_meas'] == 0 | ||
|
||
|
||
def test_kraus_channel_str(): | ||
# This is equivalent to an X-basis measurement. | ||
ops = [ | ||
np.array([[1, 1], [1, 1]]) * 0.5, | ||
np.array([[1, -1], [-1, 1]]) * 0.5, | ||
] | ||
x_meas = cirq.KrausChannel(ops) | ||
assert ( | ||
str(x_meas) | ||
== """KrausChannel([array([[0.5, 0.5], | ||
[0.5, 0.5]]), array([[ 0.5, -0.5], | ||
[-0.5, 0.5]])])""" | ||
) | ||
x_meas_keyed = cirq.KrausChannel(ops, key='x_meas') | ||
assert ( | ||
str(x_meas_keyed) | ||
== """KrausChannel([array([[0.5, 0.5], | ||
[0.5, 0.5]]), array([[ 0.5, -0.5], | ||
[-0.5, 0.5]])], key=x_meas)""" | ||
) | ||
|
||
|
||
def test_kraus_channel_repr(): | ||
# This is equivalent to an X-basis measurement. | ||
ops = [ | ||
np.array([[1, 1], [1, 1]], dtype=np.complex64) * 0.5, | ||
np.array([[1, -1], [-1, 1]], dtype=np.complex64) * 0.5, | ||
] | ||
x_meas = cirq.KrausChannel(ops, key='x_meas') | ||
assert ( | ||
repr(x_meas) | ||
== """\ | ||
cirq.KrausChannel(kraus_ops=[\ | ||
np.array([[(0.5+0j), (0.5+0j)], [(0.5+0j), (0.5+0j)]], dtype=np.complex64), \ | ||
np.array([[(0.5+0j), (-0.5+0j)], [(-0.5+0j), (0.5+0j)]], dtype=np.complex64)], \ | ||
key='x_meas')""" | ||
) | ||
|
||
|
||
def test_empty_ops_fails(): | ||
ops = [] | ||
|
||
with pytest.raises(ValueError, match='must have at least one operation'): | ||
_ = cirq.KrausChannel(kraus_ops=ops, key='m') | ||
|
||
|
||
def test_ops_mismatch_fails(): | ||
op2 = np.zeros((4, 4)) | ||
op2[1][1] = 1 | ||
ops = [np.array([[1, 0], [0, 0]]), op2] | ||
|
||
with pytest.raises(ValueError, match='Inconsistent Kraus operator shapes'): | ||
_ = cirq.KrausChannel(kraus_ops=ops, key='m') | ||
|
||
|
||
def test_nonqubit_kraus_ops_fails(): | ||
ops = [ | ||
np.array([[1, 0, 0], [0, 0, 0]]), | ||
np.array([[0, 0, 0], [0, 1, 0]]), | ||
] | ||
|
||
with pytest.raises(ValueError, match='Input Kraus ops'): | ||
_ = cirq.KrausChannel(kraus_ops=ops, key='m') | ||
|
||
|
||
def test_validate(): | ||
# Not quite CPTP. | ||
ops = [ | ||
np.array([[1, 0], [0, 0]]), | ||
np.array([[0, 0], [0, 0.9]]), | ||
] | ||
with pytest.raises(ValueError, match='CPTP map'): | ||
_ = cirq.KrausChannel(kraus_ops=ops, key='m', validate=True) |
Oops, something went wrong.