diff --git a/cirq-core/cirq/testing/consistent_channels.py b/cirq-core/cirq/testing/consistent_channels.py index a9e65384dc8..1b7029955fd 100644 --- a/cirq-core/cirq/testing/consistent_channels.py +++ b/cirq-core/cirq/testing/consistent_channels.py @@ -33,8 +33,8 @@ def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8) """Asserts that a given gate is a mixture and the mixture probabilities sum to one.""" assert cirq.has_mixture(gate), f"Give gate {gate!r} does not return for cirq.has_mixture." mixture = cirq.mixture(gate) - total = np.sum(k for k, v in mixture) - assert total - 1 <= atol + rtol * np.abs(total), ( + total = np.sum(np.fromiter((k for k, v in mixture), dtype=float)) + assert np.abs(1 - total) <= atol + rtol * np.abs(total), ( f"The mixture for gate {gate!r} did not return coefficients that sum to 1. Summed to " f"{total}." ) diff --git a/cirq-core/cirq/testing/consistent_channels_test.py b/cirq-core/cirq/testing/consistent_channels_test.py index 57662d79f72..75ef7770653 100644 --- a/cirq-core/cirq/testing/consistent_channels_test.py +++ b/cirq-core/cirq/testing/consistent_channels_test.py @@ -45,3 +45,48 @@ def test_assert_consistent_channel_invalid(): def test_assert_consistent_channel_not_kraus(): with pytest.raises(AssertionError, match="12.*has_kraus"): cirq.testing.assert_consistent_channel(12) + + +def test_assert_consistent_mixture_valid(): + mixture = cirq.X.with_probability(0.1) + cirq.testing.assert_consistent_mixture(mixture) + + +def test_assert_consistent_mixture_not_mixture(): + not_mixture = cirq.amplitude_damp(0.1) + with pytest.raises(AssertionError, match="has_mixture"): + cirq.testing.assert_consistent_mixture(not_mixture) + + +class _MixtureGate(cirq.testing.SingleQubitGate): + def __init__(self, p, q): + self._p = p + self._q = q + super().__init__() + + def _mixture_(self): + return (self._p, cirq.unitary(cirq.I)), (self._q, cirq.unitary(cirq.X)) + + +def test_assert_consistent_mixture_not_normalized(): + mixture = _MixtureGate(0.1, 0.85) + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture) + + mixture = _MixtureGate(0.2, 0.85) + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture) + + +def test_assert_consistent_mixture_tolerances(): + + # This gate is 1e-5 off being properly normalized. + mixture = _MixtureGate(0.1, 0.9 - 1e-5) + # Defaults of rtol=1e-5, atol=1e-8 are fine. + cirq.testing.assert_consistent_mixture(mixture) + + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture, rtol=0, atol=1e-6) + + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture, rtol=1e-6, atol=0)