From 38b6a7e2e571b04308a198c0a22b5dc53f8e7254 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Mon, 6 Jun 2022 19:25:58 +0000 Subject: [PATCH 1/4] fix --- cirq-core/cirq/ops/raw_types.py | 8 +- .../two_qubit_to_cz_test.py | 2 +- .../two_qubit_to_fsim_test.py | 2 +- .../serialization/arg_func_langs.py | 11 +- .../serialization/circuit_serializer_test.py | 10 + .../serialization/op_serializer.py | 2 +- .../serialization/op_serializer_test.py | 2 + docs/classical_control.ipynb | 311 ++++++++++++++++++ 8 files changed, 341 insertions(+), 7 deletions(-) create mode 100644 docs/classical_control.ipynb diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index d64c1fe923e..dd51abfdfa5 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -626,9 +626,11 @@ def with_classical_controls( """Returns a classically controlled version of this operation. An operation that is classically controlled is executed iff all - conditions evaluate to True. Currently the only condition type is a - measurement key. A measurement key evaluates to True iff any qubit in - the corresponding measurement operation evaluated to a non-zero value. + conditions evaluate to True. Conditions can be either a measurement key + or a user-specified `cirq.Condition`. A measurement key evaluates to + True iff any qubit in the corresponding measurement operation evaluated + to a non-zero value; `cirq.Condition` supports more complex, + user-defined conditions. If no conditions are specified, returns self. diff --git a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py index e968cb807d5..e26e42b4a52 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py @@ -265,4 +265,4 @@ def test_decompose_to_diagonal_and_circuit(v): assert cirq.is_diagonal(diagonal) combined_circuit = cirq.Circuit(cirq.MatrixGate(diagonal)(b, c), ops) circuit_unitary = combined_circuit.unitary(qubits_that_should_be_present=[b, c]) - cirq.testing.assert_allclose_up_to_global_phase(circuit_unitary, v, atol=1e-14) + cirq.testing.assert_allclose_up_to_global_phase(circuit_unitary, v, atol=2e-6) diff --git a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py index f821a9153b5..3a95ae0e2d3 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py @@ -115,7 +115,7 @@ def test_decompose_two_qubit_interaction_into_four_fsim_gates_equivalence( for operation in circuit.all_operations(): assert len(operation.qubits) < 2 or operation.gate == fsim_gate assert len(circuit) <= 4 * 3 + 5 - assert cirq.approx_eq(circuit.unitary(qubit_order=qubits), desired_unitary, atol=1e-6) + assert cirq.approx_eq(circuit.unitary(qubit_order=qubits), desired_unitary, atol=1e-4) def test_decompose_two_qubit_interaction_into_four_fsim_gates_validate(): diff --git a/cirq-google/cirq_google/serialization/arg_func_langs.py b/cirq-google/cirq_google/serialization/arg_func_langs.py index 1e4496cdc31..5d52387aea9 100644 --- a/cirq-google/cirq_google/serialization/arg_func_langs.py +++ b/cirq-google/cirq_google/serialization/arg_func_langs.py @@ -37,7 +37,16 @@ # Types for comparing floats # Includes sympy types. Needed for arg parsing. -FLOAT_TYPES = (float, int, sympy.Integer, sympy.Float, sympy.Rational, sympy.NumberSymbol) +FLOAT_TYPES = ( + float, + int, + np.integer, + np.floating, + sympy.Integer, + sympy.Float, + sympy.Rational, + sympy.NumberSymbol, +) # Supported function languages in order from least to most flexible. # Clients should use the least flexible language they can, to make it easier diff --git a/cirq-google/cirq_google/serialization/circuit_serializer_test.py b/cirq-google/cirq_google/serialization/circuit_serializer_test.py index fa75b043756..8789107baf0 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer_test.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer_test.py @@ -14,6 +14,8 @@ from typing import Dict, List import pytest + +import numpy as np import sympy from google.protobuf import json_format @@ -66,6 +68,14 @@ def circuit_proto(json: Dict, qubits: List[str]): cirq.XPowGate(exponent=0.125)(Q1), op_proto({'xpowgate': {'exponent': {'float_value': 0.125}}, 'qubit_constant_index': [0]}), ), + ( + cirq.XPowGate(exponent=np.double(0.125))(Q1), + op_proto({'xpowgate': {'exponent': {'float_value': 0.125}}, 'qubit_constant_index': [0]}), + ), + ( + cirq.XPowGate(exponent=np.short(1))(Q1), + op_proto({'xpowgate': {'exponent': {'float_value': 1.0}}, 'qubit_constant_index': [0]}), + ), ( cirq.XPowGate(exponent=sympy.Symbol('a'))(Q1), op_proto({'xpowgate': {'exponent': {'symbol': 'a'}}, 'qubit_constant_index': [0]}), diff --git a/cirq-google/cirq_google/serialization/op_serializer.py b/cirq-google/cirq_google/serialization/op_serializer.py index 4399e3be429..76e80fecc5c 100644 --- a/cirq-google/cirq_google/serialization/op_serializer.py +++ b/cirq-google/cirq_google/serialization/op_serializer.py @@ -274,7 +274,7 @@ def _value_from_gate(self, op: cirq.Operation, arg: SerializingArg) -> Optional[ def _check_type(self, value: ARG_LIKE, arg: SerializingArg) -> None: if arg.serialized_type == float: - if not isinstance(value, (float, int)): + if not isinstance(value, (float, int, np.integer, np.floating)): raise ValueError(f'Expected type convertible to float but was {type(value)}') elif arg.serialized_type == List[bool]: if not isinstance(value, (list, tuple, np.ndarray)) or not all( diff --git a/cirq-google/cirq_google/serialization/op_serializer_test.py b/cirq-google/cirq_google/serialization/op_serializer_test.py index b650e40c8ff..845fb1f701c 100644 --- a/cirq-google/cirq_google/serialization/op_serializer_test.py +++ b/cirq-google/cirq_google/serialization/op_serializer_test.py @@ -69,6 +69,8 @@ def get_val(op): TEST_CASES = ( (float, 1.0, {'arg_value': {'float_value': 1.0}}), + (float, np.short(1), {'arg_value': {'float_value': 1.0}}), + (float, np.double(1.0), {'arg_value': {'float_value': 1.0}}), (str, 'abc', {'arg_value': {'string_value': 'abc'}}), (float, 1, {'arg_value': {'float_value': 1.0}}), (List[bool], [True, False], {'arg_value': {'bool_values': {'values': [True, False]}}}), diff --git a/docs/classical_control.ipynb b/docs/classical_control.ipynb new file mode 100644 index 00000000000..bdfefbaabc9 --- /dev/null +++ b/docs/classical_control.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b952a1c0faad" + }, + "outputs": [], + "source": [ + "#@title Copyright 2022 The Cirq Developers\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3556e78efd03" + }, + "source": [ + "# Classical control" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "925dbb45c75e" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on QuantumAI\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d4c447ddd24e" + }, + "outputs": [], + "source": [ + "try:\n", + " import cirq\n", + "except ImportError:\n", + " print(\"installing cirq...\")\n", + " !pip install --quiet cirq\n", + " import cirq\n", + " print(\"installed cirq.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8ccb64c25e3a" + }, + "source": [ + "While some quantum algorithms can be defined entirely at the quantum level, there are many others (notably including [teleportation](/cirq/tutorials/educators/textbook_algorithms#quantum_teleportation) and [error correction](https://www.nature.com/articles/s41586-021-03588-y)) which rely on classical measurement results from one part of the algorithm to control operations in a later section.\n", + "\n", + "To represent this, Cirq provides the `ClassicallyControlledOperation`. Following the pattern of controlled operations, a classically-controlled version of any `Operation` can be constructed by calling its `with_classical_controls` method with the control condition(s)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b3ed39be4c06" + }, + "source": [ + "## Basic conditions\n", + "\n", + "In the example below, `X` will only be applied to `q1` if the previous measurement \"a\" returns a 1. More generally, providing some string `\"cond\"` to `with_classical_controls` creates a `ClassicallyControlledOperation` with a `KeyCondition` whose key is `\"cond\"`. A `KeyCondition` will only trigger and apply the operation it controls if a preceding measurement with the same key measured one or more qubits in the $|1\\rangle$ state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "df3dd6e3b308" + }, + "outputs": [], + "source": [ + "q0, q1, q2 = cirq.LineQubit.range(3)\n", + "circuit = cirq.Circuit(\n", + " cirq.H(q0),\n", + " cirq.measure(q0, key='a'),\n", + " cirq.X(q1).with_classical_controls('a'),\n", + ")\n", + "print(circuit)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4e416431b695" + }, + "source": [ + "Using just these conditions, we can construct the [quantum teleportation](/cirq/tutorials/educators/textbook_algorithms#quantum_teleportation) circuit:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "01ccc99c6a3c" + }, + "outputs": [], + "source": [ + "# Teleports `_message` from Alice to Bob.\n", + "alice = cirq.NamedQubit('alice')\n", + "bob = cirq.NamedQubit('bob')\n", + "message = cirq.NamedQubit('_message')\n", + "circuit = cirq.Circuit(\n", + " # Create Bell state to be shared between Alice and Bob.\n", + " cirq.H(alice),\n", + " cirq.CNOT(alice, bob),\n", + " # Create the message.\n", + " cirq.X(message) ** 0.371,\n", + " cirq.Y(message) ** 0.882,\n", + " # Bell measurement of the message and Alice's entangled qubit.\n", + " cirq.CNOT(message, alice),\n", + " cirq.H(message),\n", + " cirq.measure(message, key='M'),\n", + " cirq.measure(alice, key='A'),\n", + " # Uses the two classical bits from the Bell measurement to recover the\n", + " # original quantum message on Bob's entangled qubit.\n", + " cirq.X(bob).with_classical_controls('A'),\n", + " cirq.Z(bob).with_classical_controls('M'),\n", + ")\n", + "print(circuit)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3c64d94110ba" + }, + "source": [ + "## Sympy conditions\n", + "\n", + "Cirq also supports more complex control conditions: providing some `sympy` expression `\"expr\"` to `with_classical_controls` creates a `ClassicallyControlledOperation` with a `SympyCondition`. That condition will only trigger if `\"expr\"` evaluates to a \"truthy\" value (`bool(expr) == True`), and uses measurement results to resolve any variables in the expression.\n", + "\n", + "In this example, `X` will only be applied to `q2` if `a == b`; in other words, $|q_0q_1\\rangle$ must be either $|00\\rangle$ or $|11\\rangle$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9a7ff41b51c0" + }, + "outputs": [], + "source": [ + "import sympy\n", + "\n", + "a, b, c = sympy.symbols('a b c')\n", + "sympy_cond = sympy.Eq(a, b)\n", + "circuit = cirq.Circuit(\n", + " cirq.H.on_each(q0, q1),\n", + " cirq.measure(q0, key='a'),\n", + " cirq.measure(q1, key='b'),\n", + " cirq.X(q2).with_classical_controls(sympy_cond)\n", + ")\n", + "print(circuit)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dfb58a6f479c" + }, + "source": [ + "## Combining conditions\n", + "\n", + "Multiple conditions of either type can be specified to `with_classical_controls`, in which case the resulting `ClassicallyControlledOperation` will only trigger if _all_ conditions trigger. Similarly, calling `with_classical_controls` on an existing `ClassicallyControlledOperation` will require all new and pre-existing conditions to trigger for the operation to trigger." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8be2002669fc" + }, + "outputs": [], + "source": [ + "sympy_cond = sympy.Eq(a, 0)\n", + "circuit = cirq.Circuit(\n", + " cirq.H.on_each(q0, q1, q2),\n", + " cirq.measure(q0, q1, key='a'),\n", + " cirq.measure(q2, key='b'),\n", + " cirq.X(q0).with_classical_controls('b', sympy_cond),\n", + " cirq.CZ(q1, q2).with_classical_controls('b').with_classical_controls(sympy_cond),\n", + ")\n", + "print(circuit)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "34d0fe9226e1" + }, + "source": [ + "## Variable scope\n", + "\n", + "When used with `CircuitOperation`, classically controlled operations will be resolved using local repetition IDs, if any. This is the only way to create a non-global variable scope within a circuit. A simple example of this is shown below, where the controls inside and outside a subcircuit rely on measurements in their respective scopes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6a7441827bd6" + }, + "outputs": [], + "source": [ + "subcircuit = cirq.FrozenCircuit(\n", + " cirq.measure(q0, key='a'), cirq.X(q0).with_classical_controls('a')\n", + ")\n", + "circuit = cirq.Circuit(\n", + " cirq.measure(q0, key='a'),\n", + " cirq.CircuitOperation(subcircuit, repetitions=2),\n", + " cirq.X(q0).with_classical_controls('a')\n", + ")\n", + "print(circuit)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b0807e8edb7f" + }, + "source": [ + "More complex scoping behavior is described in the [classically controlled operation tests](https://github.com/quantumlib/Cirq/blob/master/cirq-core/cirq/ops/classically_controlled_operation_test.py)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "520a5bbbea93" + }, + "source": [ + "## Using with transformers\n", + "\n", + "Cirq [transformers](transformers.ipynb) are aware of classical control and will avoid changes which move a control before its corresponding measurement. Additionally, for some simple cases the [`defer_measurements` transformer](https://github.com/daxfohl/Cirq/blob/e68ff85e9bb0c7373572cdc212c10f226cd40b0f/cirq-core/cirq/transformers/measurement_transformers.py#L58) can convert a classically-controlled circuit into a purely-quantum circuit:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e7bda8edb27a" + }, + "outputs": [], + "source": [ + "circuit = cirq.Circuit(\n", + " cirq.measure(q0, key='a'),\n", + " cirq.X(q1).with_classical_controls('a'),\n", + " cirq.measure(q1, key='b'),\n", + ")\n", + "deferred = cirq.defer_measurements(circuit)\n", + "print(\"Original circuit:\")\n", + "print(circuit)\n", + "print(\"Measurement deferred:\")\n", + "print(deferred)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "48666318febe" + }, + "source": [ + "## Compatibility\n", + "\n", + "The Cirq built-in simulators provide support for classical control, but caution should be exercised when exporting these circuits to other environments. `ClassicallyControlledOperation` is fundamentally different from other operations in that it requires access to the measurement results, and simulators or hardware that does not explicitly support this will not be able to run `ClassicallyControlledOperation`s." + ] + } + ], + "metadata": { + "colab": { + "name": "classical_control.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 86a45775bd37b4f38caf9286688157c7af81186e Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Mon, 6 Jun 2022 19:27:14 +0000 Subject: [PATCH 2/4] Add tests --- cirq-core/cirq/testing/consistent_channels.py | 2 +- cirq-core/cirq/testing/consistent_channels_test.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/testing/consistent_channels.py b/cirq-core/cirq/testing/consistent_channels.py index a9e65384dc8..325d7a160de 100644 --- a/cirq-core/cirq/testing/consistent_channels.py +++ b/cirq-core/cirq/testing/consistent_channels.py @@ -33,7 +33,7 @@ def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8) """Asserts that a given gate is a mixture and the mixture probabilities sum to one.""" assert cirq.has_mixture(gate), f"Give gate {gate!r} does not return for cirq.has_mixture." mixture = cirq.mixture(gate) - total = np.sum(k for k, v in mixture) + total = np.sum([k for k, v in mixture]) assert total - 1 <= atol + rtol * np.abs(total), ( f"The mixture for gate {gate!r} did not return coefficients that sum to 1. Summed to " f"{total}." diff --git a/cirq-core/cirq/testing/consistent_channels_test.py b/cirq-core/cirq/testing/consistent_channels_test.py index 57662d79f72..fa73f9ed2f3 100644 --- a/cirq-core/cirq/testing/consistent_channels_test.py +++ b/cirq-core/cirq/testing/consistent_channels_test.py @@ -45,3 +45,8 @@ def test_assert_consistent_channel_invalid(): def test_assert_consistent_channel_not_kraus(): with pytest.raises(AssertionError, match="12.*has_kraus"): cirq.testing.assert_consistent_channel(12) + + +def test_assert_consistent_mixture_valid(): + mixture = cirq.X.with_probability(0.1) + cirq.testing.assert_consistent_channel(mixture) From bb6f55571102f0c6a193bde713b37d10577a728f Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Mon, 6 Jun 2022 19:28:08 +0000 Subject: [PATCH 3/4] repair master --- .../analytical_decompositions/two_qubit_to_cz_test.py | 2 +- .../analytical_decompositions/two_qubit_to_fsim_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py index e26e42b4a52..e968cb807d5 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz_test.py @@ -265,4 +265,4 @@ def test_decompose_to_diagonal_and_circuit(v): assert cirq.is_diagonal(diagonal) combined_circuit = cirq.Circuit(cirq.MatrixGate(diagonal)(b, c), ops) circuit_unitary = combined_circuit.unitary(qubits_that_should_be_present=[b, c]) - cirq.testing.assert_allclose_up_to_global_phase(circuit_unitary, v, atol=2e-6) + cirq.testing.assert_allclose_up_to_global_phase(circuit_unitary, v, atol=1e-14) diff --git a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py index 3a95ae0e2d3..f821a9153b5 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_fsim_test.py @@ -115,7 +115,7 @@ def test_decompose_two_qubit_interaction_into_four_fsim_gates_equivalence( for operation in circuit.all_operations(): assert len(operation.qubits) < 2 or operation.gate == fsim_gate assert len(circuit) <= 4 * 3 + 5 - assert cirq.approx_eq(circuit.unitary(qubit_order=qubits), desired_unitary, atol=1e-4) + assert cirq.approx_eq(circuit.unitary(qubit_order=qubits), desired_unitary, atol=1e-6) def test_decompose_two_qubit_interaction_into_four_fsim_gates_validate(): From 9874fac2a58a22d8e0c743008cc0ecded4b3e0db Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Mon, 6 Jun 2022 19:59:20 +0000 Subject: [PATCH 4/4] Add missing consistent mixture tests. Fix bug. --- cirq-core/cirq/testing/consistent_channels.py | 4 +- .../cirq/testing/consistent_channels_test.py | 42 ++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/testing/consistent_channels.py b/cirq-core/cirq/testing/consistent_channels.py index 325d7a160de..1b7029955fd 100644 --- a/cirq-core/cirq/testing/consistent_channels.py +++ b/cirq-core/cirq/testing/consistent_channels.py @@ -33,8 +33,8 @@ def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8) """Asserts that a given gate is a mixture and the mixture probabilities sum to one.""" assert cirq.has_mixture(gate), f"Give gate {gate!r} does not return for cirq.has_mixture." mixture = cirq.mixture(gate) - total = np.sum([k for k, v in mixture]) - assert total - 1 <= atol + rtol * np.abs(total), ( + total = np.sum(np.fromiter((k for k, v in mixture), dtype=float)) + assert np.abs(1 - total) <= atol + rtol * np.abs(total), ( f"The mixture for gate {gate!r} did not return coefficients that sum to 1. Summed to " f"{total}." ) diff --git a/cirq-core/cirq/testing/consistent_channels_test.py b/cirq-core/cirq/testing/consistent_channels_test.py index fa73f9ed2f3..75ef7770653 100644 --- a/cirq-core/cirq/testing/consistent_channels_test.py +++ b/cirq-core/cirq/testing/consistent_channels_test.py @@ -49,4 +49,44 @@ def test_assert_consistent_channel_not_kraus(): def test_assert_consistent_mixture_valid(): mixture = cirq.X.with_probability(0.1) - cirq.testing.assert_consistent_channel(mixture) + cirq.testing.assert_consistent_mixture(mixture) + + +def test_assert_consistent_mixture_not_mixture(): + not_mixture = cirq.amplitude_damp(0.1) + with pytest.raises(AssertionError, match="has_mixture"): + cirq.testing.assert_consistent_mixture(not_mixture) + + +class _MixtureGate(cirq.testing.SingleQubitGate): + def __init__(self, p, q): + self._p = p + self._q = q + super().__init__() + + def _mixture_(self): + return (self._p, cirq.unitary(cirq.I)), (self._q, cirq.unitary(cirq.X)) + + +def test_assert_consistent_mixture_not_normalized(): + mixture = _MixtureGate(0.1, 0.85) + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture) + + mixture = _MixtureGate(0.2, 0.85) + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture) + + +def test_assert_consistent_mixture_tolerances(): + + # This gate is 1e-5 off being properly normalized. + mixture = _MixtureGate(0.1, 0.9 - 1e-5) + # Defaults of rtol=1e-5, atol=1e-8 are fine. + cirq.testing.assert_consistent_mixture(mixture) + + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture, rtol=0, atol=1e-6) + + with pytest.raises(AssertionError, match="sum to 1"): + cirq.testing.assert_consistent_mixture(mixture, rtol=1e-6, atol=0)