From a091c65d82b13bceb9fcaae828a612a44be3acaa Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Mon, 11 Apr 2022 20:55:34 -0700 Subject: [PATCH 1/4] Add testing helper for consistent channel --- cirq-core/cirq/ops/common_channels_test.py | 22 ++++----- cirq-core/cirq/ops/kraus_channel_test.py | 1 + cirq-core/cirq/ops/measurement_gate_test.py | 1 + .../cirq/ops/mixed_unitary_channel_test.py | 1 + cirq-core/cirq/ops/projector.py | 2 +- .../cirq/ops/random_gate_channel_test.py | 6 +-- .../ops/state_preparation_channel_test.py | 1 + cirq-core/cirq/testing/__init__.py | 4 ++ cirq-core/cirq/testing/consistent_channel.py | 26 ++++++++++ .../cirq/testing/consistent_channel_test.py | 47 +++++++++++++++++++ 10 files changed, 96 insertions(+), 15 deletions(-) create mode 100644 cirq-core/cirq/testing/consistent_channel.py create mode 100644 cirq-core/cirq/testing/consistent_channel_test.py diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index 3808e194163..e79c3f5a768 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -62,7 +62,7 @@ def test_asymmetric_depolarizing_channel(): cirq.kraus(d), (np.sqrt(0.4) * np.eye(2), np.sqrt(0.1) * X, np.sqrt(0.2) * Y, np.sqrt(0.3) * Z), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert cirq.AsymmetricDepolarizingChannel(p_x=0, p_y=0.1, p_z=0).num_qubits() == 1 @@ -145,7 +145,7 @@ def test_depolarizing_channel(): np.sqrt(0.1) * Z, ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) def test_depolarizing_channel_two_qubits(): @@ -171,7 +171,7 @@ def test_depolarizing_channel_two_qubits(): np.sqrt(0.01) * np.kron(Z, Z), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert d.num_qubits() == 2 cirq.testing.assert_has_diagram( @@ -332,7 +332,7 @@ def test_generalized_amplitude_damping_channel(): np.sqrt(0.9) * np.array([[0.0, 0.0], [np.sqrt(0.3), 0.0]]), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert not cirq.has_mixture(d) @@ -396,7 +396,7 @@ def test_amplitude_damping_channel(): np.array([[0.0, np.sqrt(0.3)], [0.0, 0.0]]), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert not cirq.has_mixture(d) @@ -448,7 +448,7 @@ def test_reset_channel(): np.testing.assert_almost_equal( cirq.kraus(r), (np.array([[1.0, 0.0], [0.0, 0]]), np.array([[0.0, 1.0], [0.0, 0.0]])) ) - assert cirq.has_kraus(r) + cirq.testing.assert_consistent_channel(r) assert not cirq.has_mixture(r) assert cirq.qid_shape(r) == (2,) @@ -461,7 +461,7 @@ def test_reset_channel(): np.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]), ), ) # yapf: disable - assert cirq.has_kraus(r) + cirq.testing.assert_consistent_channel(r) assert not cirq.has_mixture(r) assert cirq.qid_shape(r) == (3,) @@ -538,7 +538,7 @@ def test_phase_damping_channel(): np.array([[0.0, 0.0], [0.0, np.sqrt(0.3)]]), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert not cirq.has_mixture(d) @@ -590,7 +590,7 @@ def test_phase_flip_channel(): np.testing.assert_almost_equal( cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * Z) ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) def test_phase_flip_mixture(): @@ -654,7 +654,7 @@ def test_bit_flip_channel(): np.testing.assert_almost_equal( cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * X) ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) def test_bit_flip_mixture(): @@ -760,7 +760,7 @@ def test_multi_asymmetric_depolarizing_channel(): np.testing.assert_almost_equal( cirq.kraus(d), (np.sqrt(0.8) * np.eye(4), np.sqrt(0.2) * np.kron(X, X)) ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) np.testing.assert_equal(d._num_qubits_(), 2) with pytest.raises(ValueError, match="num_qubits should be 1"): diff --git a/cirq-core/cirq/ops/kraus_channel_test.py b/cirq-core/cirq/ops/kraus_channel_test.py index 0e6164f19b7..a7cdbdf90b9 100644 --- a/cirq-core/cirq/ops/kraus_channel_test.py +++ b/cirq-core/cirq/ops/kraus_channel_test.py @@ -9,6 +9,7 @@ def test_kraus_channel_from_channel(): dp = cirq.depolarize(0.1) kc = cirq.KrausChannel.from_channel(dp, key='dp') assert cirq.measurement_key_name(kc) == 'dp' + cirq.testing.assert_consistent_channel(kc) circuit = cirq.Circuit(kc.on(q0)) sim = cirq.Simulator(seed=0) diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index 5e8e666f879..c6925579f94 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -223,6 +223,7 @@ def test_measurement_channel(): cirq.kraus(cirq.MeasurementGate(1, 'a')), (np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])), ) + cirq.testing.assert_consistent_channel(cirq.MeasurementGate(1, 'a')) # yapf: disable np.testing.assert_allclose( cirq.kraus(cirq.MeasurementGate(2, 'a')), diff --git a/cirq-core/cirq/ops/mixed_unitary_channel_test.py b/cirq-core/cirq/ops/mixed_unitary_channel_test.py index 03b671a3b49..2d1358828c8 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel_test.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel_test.py @@ -9,6 +9,7 @@ def test_matrix_mixture_from_mixture(): dp = cirq.depolarize(0.1) mm = cirq.MixedUnitaryChannel.from_mixture(dp, key='dp') assert cirq.measurement_key_name(mm) == 'dp' + cirq.testing.assert_consistent_channel(mm) circuit = cirq.Circuit(mm.on(q0)) sim = cirq.Simulator(seed=0) diff --git a/cirq-core/cirq/ops/projector.py b/cirq-core/cirq/ops/projector.py index 290ba2351e5..c8aa20cf727 100644 --- a/cirq-core/cirq/ops/projector.py +++ b/cirq-core/cirq/ops/projector.py @@ -31,7 +31,7 @@ def __init__( projector_dict: Dict[raw_types.Qid, int], coefficient: Union[int, float, complex] = 1, ): - """Contructor for ProjectorString + """Constructor for ProjectorString Args: projector_dict: A python dictionary mapping from cirq.Qid to integers. A key value pair diff --git a/cirq-core/cirq/ops/random_gate_channel_test.py b/cirq-core/cirq/ops/random_gate_channel_test.py index c54e5f91e17..36fbe749cbc 100644 --- a/cirq-core/cirq/ops/random_gate_channel_test.py +++ b/cirq-core/cirq/ops/random_gate_channel_test.py @@ -175,9 +175,9 @@ def num_qubits(self) -> int: assert not cirq.has_kraus(NoDetailsGate().with_probability(0.5)) assert cirq.kraus(NoDetailsGate().with_probability(0.5), None) is None assert cirq.kraus(cirq.X.with_probability(sympy.Symbol('x')), None) is None - assert_channel_sums_to_identity(cirq.X.with_probability(0.25)) - assert_channel_sums_to_identity(cirq.bit_flip(0.75).with_probability(0.25)) - assert_channel_sums_to_identity(cirq.amplitude_damp(0.75).with_probability(0.25)) + cirq.testing.assert_consistent_channel(cirq.X.with_probability(0.25)) + cirq.testing.assert_consistent_channel(cirq.bit_flip(0.75).with_probability(0.25)) + cirq.testing.assert_consistent_channel(cirq.amplitude_damp(0.75).with_probability(0.25)) m = cirq.kraus(cirq.X.with_probability(0.25)) assert len(m) == 2 diff --git a/cirq-core/cirq/ops/state_preparation_channel_test.py b/cirq-core/cirq/ops/state_preparation_channel_test.py index ffc454ba554..5939d770ccb 100644 --- a/cirq-core/cirq/ops/state_preparation_channel_test.py +++ b/cirq-core/cirq/ops/state_preparation_channel_test.py @@ -38,6 +38,7 @@ def test_state_prep_channel_kraus(state): qubits = cirq.LineQubit.range(2) gate = cirq.StatePreparationChannel(state)(qubits[0], qubits[1]) + cirq.testing.assert_consistent_channel(gate) state = state / np.linalg.norm(state) np.testing.assert_almost_equal( cirq.kraus(gate), diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 7124f91173e..a57445253da 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -28,6 +28,10 @@ assert_all_implemented_act_on_effects_match_unitary, ) +from cirq.testing.consistent_channel import ( + assert_consistent_channel, +) + from cirq.testing.consistent_controlled_gate_op import ( assert_controlled_and_controlled_by_identical, ) diff --git a/cirq-core/cirq/testing/consistent_channel.py b/cirq-core/cirq/testing/consistent_channel.py new file mode 100644 index 00000000000..3ad467f0f3d --- /dev/null +++ b/cirq-core/cirq/testing/consistent_channel.py @@ -0,0 +1,26 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import cirq + + +def assert_consistent_channel(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): + assert cirq.has_kraus(gate), f"Given gate {gate!r} does not return True cirq.has_kraus." + kraus_ops = cirq.kraus(gate) + assert cirq.is_cptp(kraus_ops=kraus_ops, rtol=rtol, atol=atol), ( + f"Kraus operators for {gate!r} did not sum to identity up to expected tolerances. " + f"Summed to {sum(m.T.conj() @ m for m in kraus_ops)}" + ) diff --git a/cirq-core/cirq/testing/consistent_channel_test.py b/cirq-core/cirq/testing/consistent_channel_test.py new file mode 100644 index 00000000000..e009d61a9dc --- /dev/null +++ b/cirq-core/cirq/testing/consistent_channel_test.py @@ -0,0 +1,47 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np +import cirq + + +def test_assert_consistent_channel_valid(): + channel = cirq.KrausChannel(kraus_ops=(np.array([[0, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))) + cirq.testing.assert_consistent_channel(channel) + + +def test_assert_consistent_channel_tolerances(): + # This channel is off by 1e-5 from the identity matrix in the consistency condition. + channel = cirq.KrausChannel( + kraus_ops=(np.array([[0, np.sqrt(1 - 1e-5)], [0, 0]]), np.array([[1, 0], [0, 0]])) + ) + # We are comparing to identity, so rtol is same as atol for non-zero entries. + cirq.testing.assert_consistent_channel(channel, rtol=1e-5, atol=0) + with pytest.raises(AssertionError): + cirq.testing.assert_consistent_channel(channel, rtol=1e-6, atol=0) + cirq.testing.assert_consistent_channel(channel, rtol=0, atol=1e-5) + with pytest.raises(AssertionError): + cirq.testing.assert_consistent_channel(channel, rtol=0, atol=1e-6) + + +def test_assert_consistent_channel_invalid(): + channel = cirq.KrausChannel(kraus_ops=(np.array([[1, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))) + with pytest.raises(AssertionError, match=r"cirq.KrausChannel.*Summed to [[2 1]"): + cirq.testing.assert_consistent_channel(channel) + + +def test_assert_consistent_channel_not_kraus(): + with pytest.raises(AssertionError, match="12.*has_kraus"): + cirq.testing.assert_consistent_channel(12) From 64b2c69b17c35bc1a8427adeca7329f4267b9d40 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Mon, 11 Apr 2022 21:03:29 -0700 Subject: [PATCH 2/4] remove warning --- cirq-core/cirq/testing/consistent_channel_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/testing/consistent_channel_test.py b/cirq-core/cirq/testing/consistent_channel_test.py index e009d61a9dc..57662d79f72 100644 --- a/cirq-core/cirq/testing/consistent_channel_test.py +++ b/cirq-core/cirq/testing/consistent_channel_test.py @@ -38,7 +38,7 @@ def test_assert_consistent_channel_tolerances(): def test_assert_consistent_channel_invalid(): channel = cirq.KrausChannel(kraus_ops=(np.array([[1, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))) - with pytest.raises(AssertionError, match=r"cirq.KrausChannel.*Summed to [[2 1]"): + with pytest.raises(AssertionError, match=r"cirq.KrausChannel.*2 1"): cirq.testing.assert_consistent_channel(channel) From 9434dd2ed6cb9ba6662ab287aa19c92d40cf3ae7 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Fri, 13 May 2022 23:48:03 +0000 Subject: [PATCH 3/4] missed merge --- cirq-core/cirq/ops/common_channels_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index d73edd84e21..277e8adfdab 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -443,13 +443,10 @@ def test_reset_channel(): np.testing.assert_almost_equal( cirq.kraus(r), (np.array([[1.0, 0.0], [0.0, 0]]), np.array([[0.0, 1.0], [0.0, 0.0]])) ) -<<<<<<< HEAD cirq.testing.assert_consistent_channel(r) -======= assert cirq.num_qubits(r) == 1 assert cirq.has_kraus(r) ->>>>>>> master assert not cirq.has_mixture(r) assert cirq.qid_shape(r) == (2,) From 261da79b729d3f98bb9b9b803e86d1e4789e4f41 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Sat, 14 May 2022 00:23:58 +0000 Subject: [PATCH 4/4] add consistent mixture --- cirq-core/cirq/ops/common_channels_test.py | 9 +++++++-- cirq-core/cirq/ops/measurement_gate_test.py | 1 + cirq-core/cirq/ops/mixed_unitary_channel_test.py | 1 + cirq-core/cirq/ops/random_gate_channel_test.py | 4 ++++ .../cirq/ops/state_preparation_channel_test.py | 1 + cirq-core/cirq/testing/__init__.py | 2 +- ...sistent_channel.py => consistent_channels.py} | 16 +++++++++++++++- ...annel_test.py => consistent_channels_test.py} | 0 8 files changed, 30 insertions(+), 4 deletions(-) rename cirq-core/cirq/testing/{consistent_channel.py => consistent_channels.py} (60%) rename cirq-core/cirq/testing/{consistent_channel_test.py => consistent_channels_test.py} (100%) diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index 277e8adfdab..19c9c23a343 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -63,6 +63,7 @@ def test_asymmetric_depolarizing_channel(): (np.sqrt(0.4) * np.eye(2), np.sqrt(0.1) * X, np.sqrt(0.2) * Y, np.sqrt(0.3) * Z), ) cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) assert cirq.AsymmetricDepolarizingChannel(p_x=0, p_y=0.1, p_z=0).num_qubits() == 1 @@ -141,6 +142,7 @@ def test_depolarizing_channel(): (np.sqrt(0.7) * np.eye(2), np.sqrt(0.1) * X, np.sqrt(0.1) * Y, np.sqrt(0.1) * Z), ) cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) def test_depolarizing_channel_two_qubits(): @@ -167,6 +169,7 @@ def test_depolarizing_channel_two_qubits(): ), ) cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) assert d.num_qubits() == 2 cirq.testing.assert_has_diagram( @@ -444,10 +447,9 @@ def test_reset_channel(): cirq.kraus(r), (np.array([[1.0, 0.0], [0.0, 0]]), np.array([[0.0, 1.0], [0.0, 0.0]])) ) cirq.testing.assert_consistent_channel(r) + assert not cirq.has_mixture(r) assert cirq.num_qubits(r) == 1 - assert cirq.has_kraus(r) - assert not cirq.has_mixture(r) assert cirq.qid_shape(r) == (2,) r = cirq.reset(cirq.LineQid(0, dimension=3)) @@ -589,6 +591,7 @@ def test_phase_flip_channel(): cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * Z) ) cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) def test_phase_flip_mixture(): @@ -653,6 +656,7 @@ def test_bit_flip_channel(): cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * X) ) cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) def test_bit_flip_mixture(): @@ -759,6 +763,7 @@ def test_multi_asymmetric_depolarizing_channel(): cirq.kraus(d), (np.sqrt(0.8) * np.eye(4), np.sqrt(0.2) * np.kron(X, X)) ) cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) np.testing.assert_equal(d._num_qubits_(), 2) with pytest.raises(ValueError, match="num_qubits should be 1"): diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index b31974baadf..339a582267c 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -217,6 +217,7 @@ def test_measurement_channel(): (np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])), ) cirq.testing.assert_consistent_channel(cirq.MeasurementGate(1, 'a')) + assert not cirq.has_mixture(cirq.MeasurementGate(1, 'a')) # yapf: disable np.testing.assert_allclose( cirq.kraus(cirq.MeasurementGate(2, 'a')), diff --git a/cirq-core/cirq/ops/mixed_unitary_channel_test.py b/cirq-core/cirq/ops/mixed_unitary_channel_test.py index dfdff242128..6781e48ca67 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel_test.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel_test.py @@ -10,6 +10,7 @@ def test_matrix_mixture_from_mixture(): mm = cirq.MixedUnitaryChannel.from_mixture(dp, key='dp') assert cirq.measurement_key_name(mm) == 'dp' cirq.testing.assert_consistent_channel(mm) + cirq.testing.assert_consistent_mixture(mm) circuit = cirq.Circuit(mm.on(q0)) sim = cirq.Simulator(seed=0) diff --git a/cirq-core/cirq/ops/random_gate_channel_test.py b/cirq-core/cirq/ops/random_gate_channel_test.py index 4484a1eec7c..5c1cb98ac68 100644 --- a/cirq-core/cirq/ops/random_gate_channel_test.py +++ b/cirq-core/cirq/ops/random_gate_channel_test.py @@ -168,6 +168,10 @@ def num_qubits(self) -> int: cirq.testing.assert_consistent_channel(cirq.bit_flip(0.75).with_probability(0.25)) cirq.testing.assert_consistent_channel(cirq.amplitude_damp(0.75).with_probability(0.25)) + cirq.testing.assert_consistent_mixture(cirq.X.with_probability(0.25)) + cirq.testing.assert_consistent_mixture(cirq.bit_flip(0.75).with_probability(0.25)) + assert not cirq.has_mixture(cirq.amplitude_damp(0.75).with_probability(0.25)) + m = cirq.kraus(cirq.X.with_probability(0.25)) assert len(m) == 2 np.testing.assert_allclose(m[0], cirq.unitary(cirq.X) * np.sqrt(0.25), atol=1e-8) diff --git a/cirq-core/cirq/ops/state_preparation_channel_test.py b/cirq-core/cirq/ops/state_preparation_channel_test.py index 5939d770ccb..ead3a80481f 100644 --- a/cirq-core/cirq/ops/state_preparation_channel_test.py +++ b/cirq-core/cirq/ops/state_preparation_channel_test.py @@ -39,6 +39,7 @@ def test_state_prep_channel_kraus(state): qubits = cirq.LineQubit.range(2) gate = cirq.StatePreparationChannel(state)(qubits[0], qubits[1]) cirq.testing.assert_consistent_channel(gate) + assert not cirq.has_mixture(gate) state = state / np.linalg.norm(state) np.testing.assert_almost_equal( cirq.kraus(gate), diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 026c7cc5ffb..9e0165f0c7f 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -26,7 +26,7 @@ from cirq.testing.consistent_act_on import assert_all_implemented_act_on_effects_match_unitary -from cirq.testing.consistent_channel import assert_consistent_channel +from cirq.testing.consistent_channels import assert_consistent_channel, assert_consistent_mixture from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical diff --git a/cirq-core/cirq/testing/consistent_channel.py b/cirq-core/cirq/testing/consistent_channels.py similarity index 60% rename from cirq-core/cirq/testing/consistent_channel.py rename to cirq-core/cirq/testing/consistent_channels.py index 3ad467f0f3d..a9e65384dc8 100644 --- a/cirq-core/cirq/testing/consistent_channel.py +++ b/cirq-core/cirq/testing/consistent_channels.py @@ -14,13 +14,27 @@ from typing import Any +import numpy as np + import cirq def assert_consistent_channel(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): - assert cirq.has_kraus(gate), f"Given gate {gate!r} does not return True cirq.has_kraus." + """Asserts that a given gate has Kraus operators and that they are properly normalized.""" + assert cirq.has_kraus(gate), f"Given gate {gate!r} does not return True for cirq.has_kraus." kraus_ops = cirq.kraus(gate) assert cirq.is_cptp(kraus_ops=kraus_ops, rtol=rtol, atol=atol), ( f"Kraus operators for {gate!r} did not sum to identity up to expected tolerances. " f"Summed to {sum(m.T.conj() @ m for m in kraus_ops)}" ) + + +def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): + """Asserts that a given gate is a mixture and the mixture probabilities sum to one.""" + assert cirq.has_mixture(gate), f"Give gate {gate!r} does not return for cirq.has_mixture." + mixture = cirq.mixture(gate) + total = np.sum(k for k, v in mixture) + assert total - 1 <= atol + rtol * np.abs(total), ( + f"The mixture for gate {gate!r} did not return coefficients that sum to 1. Summed to " + f"{total}." + ) diff --git a/cirq-core/cirq/testing/consistent_channel_test.py b/cirq-core/cirq/testing/consistent_channels_test.py similarity index 100% rename from cirq-core/cirq/testing/consistent_channel_test.py rename to cirq-core/cirq/testing/consistent_channels_test.py