diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index 806d37a426b..19c9c23a343 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -62,7 +62,8 @@ def test_asymmetric_depolarizing_channel(): cirq.kraus(d), (np.sqrt(0.4) * np.eye(2), np.sqrt(0.1) * X, np.sqrt(0.2) * Y, np.sqrt(0.3) * Z), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) assert cirq.AsymmetricDepolarizingChannel(p_x=0, p_y=0.1, p_z=0).num_qubits() == 1 @@ -140,7 +141,8 @@ def test_depolarizing_channel(): cirq.kraus(d), (np.sqrt(0.7) * np.eye(2), np.sqrt(0.1) * X, np.sqrt(0.1) * Y, np.sqrt(0.1) * Z), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) def test_depolarizing_channel_two_qubits(): @@ -166,7 +168,8 @@ def test_depolarizing_channel_two_qubits(): np.sqrt(0.01) * np.kron(Z, Z), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) assert d.num_qubits() == 2 cirq.testing.assert_has_diagram( @@ -327,7 +330,7 @@ def test_generalized_amplitude_damping_channel(): np.sqrt(0.9) * np.array([[0.0, 0.0], [np.sqrt(0.3), 0.0]]), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert not cirq.has_mixture(d) @@ -391,7 +394,7 @@ def test_amplitude_damping_channel(): np.array([[0.0, np.sqrt(0.3)], [0.0, 0.0]]), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert not cirq.has_mixture(d) @@ -443,10 +446,10 @@ def test_reset_channel(): np.testing.assert_almost_equal( cirq.kraus(r), (np.array([[1.0, 0.0], [0.0, 0]]), np.array([[0.0, 1.0], [0.0, 0.0]])) ) + cirq.testing.assert_consistent_channel(r) + assert not cirq.has_mixture(r) assert cirq.num_qubits(r) == 1 - assert cirq.has_kraus(r) - assert not cirq.has_mixture(r) assert cirq.qid_shape(r) == (2,) r = cirq.reset(cirq.LineQid(0, dimension=3)) @@ -458,7 +461,7 @@ def test_reset_channel(): np.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]), ), ) # yapf: disable - assert cirq.has_kraus(r) + cirq.testing.assert_consistent_channel(r) assert not cirq.has_mixture(r) assert cirq.qid_shape(r) == (3,) @@ -535,7 +538,7 @@ def test_phase_damping_channel(): np.array([[0.0, 0.0], [0.0, np.sqrt(0.3)]]), ), ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) assert not cirq.has_mixture(d) @@ -587,7 +590,8 @@ def test_phase_flip_channel(): np.testing.assert_almost_equal( cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * Z) ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) def test_phase_flip_mixture(): @@ -651,7 +655,8 @@ def test_bit_flip_channel(): np.testing.assert_almost_equal( cirq.kraus(d), (np.sqrt(1.0 - 0.3) * np.eye(2), np.sqrt(0.3) * X) ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) def test_bit_flip_mixture(): @@ -757,7 +762,8 @@ def test_multi_asymmetric_depolarizing_channel(): np.testing.assert_almost_equal( cirq.kraus(d), (np.sqrt(0.8) * np.eye(4), np.sqrt(0.2) * np.kron(X, X)) ) - assert cirq.has_kraus(d) + cirq.testing.assert_consistent_channel(d) + cirq.testing.assert_consistent_mixture(d) np.testing.assert_equal(d._num_qubits_(), 2) with pytest.raises(ValueError, match="num_qubits should be 1"): diff --git a/cirq-core/cirq/ops/kraus_channel_test.py b/cirq-core/cirq/ops/kraus_channel_test.py index ecb94f336f8..9c45834ad1c 100644 --- a/cirq-core/cirq/ops/kraus_channel_test.py +++ b/cirq-core/cirq/ops/kraus_channel_test.py @@ -9,6 +9,7 @@ def test_kraus_channel_from_channel(): dp = cirq.depolarize(0.1) kc = cirq.KrausChannel.from_channel(dp, key='dp') assert cirq.measurement_key_name(kc) == 'dp' + cirq.testing.assert_consistent_channel(kc) circuit = cirq.Circuit(kc.on(q0)) sim = cirq.Simulator(seed=0) diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index 91a31be530f..339a582267c 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -216,6 +216,8 @@ def test_measurement_channel(): cirq.kraus(cirq.MeasurementGate(1, 'a')), (np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])), ) + cirq.testing.assert_consistent_channel(cirq.MeasurementGate(1, 'a')) + assert not cirq.has_mixture(cirq.MeasurementGate(1, 'a')) # yapf: disable np.testing.assert_allclose( cirq.kraus(cirq.MeasurementGate(2, 'a')), diff --git a/cirq-core/cirq/ops/mixed_unitary_channel_test.py b/cirq-core/cirq/ops/mixed_unitary_channel_test.py index 53d862c1e97..6781e48ca67 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel_test.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel_test.py @@ -9,6 +9,8 @@ def test_matrix_mixture_from_mixture(): dp = cirq.depolarize(0.1) mm = cirq.MixedUnitaryChannel.from_mixture(dp, key='dp') assert cirq.measurement_key_name(mm) == 'dp' + cirq.testing.assert_consistent_channel(mm) + cirq.testing.assert_consistent_mixture(mm) circuit = cirq.Circuit(mm.on(q0)) sim = cirq.Simulator(seed=0) diff --git a/cirq-core/cirq/ops/projector.py b/cirq-core/cirq/ops/projector.py index b1c29237e2c..ba65196e140 100644 --- a/cirq-core/cirq/ops/projector.py +++ b/cirq-core/cirq/ops/projector.py @@ -21,7 +21,7 @@ class ProjectorString: def __init__( self, projector_dict: Dict[raw_types.Qid, int], coefficient: Union[int, float, complex] = 1 ): - """Contructor for ProjectorString + """Constructor for ProjectorString Args: projector_dict: A python dictionary mapping from cirq.Qid to integers. A key value pair diff --git a/cirq-core/cirq/ops/random_gate_channel_test.py b/cirq-core/cirq/ops/random_gate_channel_test.py index 57b776c262a..5c1cb98ac68 100644 --- a/cirq-core/cirq/ops/random_gate_channel_test.py +++ b/cirq-core/cirq/ops/random_gate_channel_test.py @@ -164,9 +164,13 @@ def num_qubits(self) -> int: assert not cirq.has_kraus(NoDetailsGate().with_probability(0.5)) assert cirq.kraus(NoDetailsGate().with_probability(0.5), None) is None assert cirq.kraus(cirq.X.with_probability(sympy.Symbol('x')), None) is None - assert_channel_sums_to_identity(cirq.X.with_probability(0.25)) - assert_channel_sums_to_identity(cirq.bit_flip(0.75).with_probability(0.25)) - assert_channel_sums_to_identity(cirq.amplitude_damp(0.75).with_probability(0.25)) + cirq.testing.assert_consistent_channel(cirq.X.with_probability(0.25)) + cirq.testing.assert_consistent_channel(cirq.bit_flip(0.75).with_probability(0.25)) + cirq.testing.assert_consistent_channel(cirq.amplitude_damp(0.75).with_probability(0.25)) + + cirq.testing.assert_consistent_mixture(cirq.X.with_probability(0.25)) + cirq.testing.assert_consistent_mixture(cirq.bit_flip(0.75).with_probability(0.25)) + assert not cirq.has_mixture(cirq.amplitude_damp(0.75).with_probability(0.25)) m = cirq.kraus(cirq.X.with_probability(0.25)) assert len(m) == 2 diff --git a/cirq-core/cirq/ops/state_preparation_channel_test.py b/cirq-core/cirq/ops/state_preparation_channel_test.py index ffc454ba554..ead3a80481f 100644 --- a/cirq-core/cirq/ops/state_preparation_channel_test.py +++ b/cirq-core/cirq/ops/state_preparation_channel_test.py @@ -38,6 +38,8 @@ def test_state_prep_channel_kraus(state): qubits = cirq.LineQubit.range(2) gate = cirq.StatePreparationChannel(state)(qubits[0], qubits[1]) + cirq.testing.assert_consistent_channel(gate) + assert not cirq.has_mixture(gate) state = state / np.linalg.norm(state) np.testing.assert_almost_equal( cirq.kraus(gate), diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 9788ad1b98b..9e0165f0c7f 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -26,6 +26,8 @@ from cirq.testing.consistent_act_on import assert_all_implemented_act_on_effects_match_unitary +from cirq.testing.consistent_channels import assert_consistent_channel, assert_consistent_mixture + from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical from cirq.testing.consistent_decomposition import ( diff --git a/cirq-core/cirq/testing/consistent_channels.py b/cirq-core/cirq/testing/consistent_channels.py new file mode 100644 index 00000000000..a9e65384dc8 --- /dev/null +++ b/cirq-core/cirq/testing/consistent_channels.py @@ -0,0 +1,40 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np + +import cirq + + +def assert_consistent_channel(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): + """Asserts that a given gate has Kraus operators and that they are properly normalized.""" + assert cirq.has_kraus(gate), f"Given gate {gate!r} does not return True for cirq.has_kraus." + kraus_ops = cirq.kraus(gate) + assert cirq.is_cptp(kraus_ops=kraus_ops, rtol=rtol, atol=atol), ( + f"Kraus operators for {gate!r} did not sum to identity up to expected tolerances. " + f"Summed to {sum(m.T.conj() @ m for m in kraus_ops)}" + ) + + +def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): + """Asserts that a given gate is a mixture and the mixture probabilities sum to one.""" + assert cirq.has_mixture(gate), f"Give gate {gate!r} does not return for cirq.has_mixture." + mixture = cirq.mixture(gate) + total = np.sum(k for k, v in mixture) + assert total - 1 <= atol + rtol * np.abs(total), ( + f"The mixture for gate {gate!r} did not return coefficients that sum to 1. Summed to " + f"{total}." + ) diff --git a/cirq-core/cirq/testing/consistent_channels_test.py b/cirq-core/cirq/testing/consistent_channels_test.py new file mode 100644 index 00000000000..57662d79f72 --- /dev/null +++ b/cirq-core/cirq/testing/consistent_channels_test.py @@ -0,0 +1,47 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np +import cirq + + +def test_assert_consistent_channel_valid(): + channel = cirq.KrausChannel(kraus_ops=(np.array([[0, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))) + cirq.testing.assert_consistent_channel(channel) + + +def test_assert_consistent_channel_tolerances(): + # This channel is off by 1e-5 from the identity matrix in the consistency condition. + channel = cirq.KrausChannel( + kraus_ops=(np.array([[0, np.sqrt(1 - 1e-5)], [0, 0]]), np.array([[1, 0], [0, 0]])) + ) + # We are comparing to identity, so rtol is same as atol for non-zero entries. + cirq.testing.assert_consistent_channel(channel, rtol=1e-5, atol=0) + with pytest.raises(AssertionError): + cirq.testing.assert_consistent_channel(channel, rtol=1e-6, atol=0) + cirq.testing.assert_consistent_channel(channel, rtol=0, atol=1e-5) + with pytest.raises(AssertionError): + cirq.testing.assert_consistent_channel(channel, rtol=0, atol=1e-6) + + +def test_assert_consistent_channel_invalid(): + channel = cirq.KrausChannel(kraus_ops=(np.array([[1, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))) + with pytest.raises(AssertionError, match=r"cirq.KrausChannel.*2 1"): + cirq.testing.assert_consistent_channel(channel) + + +def test_assert_consistent_channel_not_kraus(): + with pytest.raises(AssertionError, match="12.*has_kraus"): + cirq.testing.assert_consistent_channel(12)