diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol.py b/cirq-core/cirq/protocols/apply_unitary_protocol.py index 5c9cc396cdc..3f8a36bf004 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """A protocol for implementing high performance unitary left-multiplies.""" - +import warnings from typing import Any, cast, Iterable, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union import numpy as np @@ -342,7 +342,6 @@ def apply_unitary( TypeError: `unitary_value` doesn't have a unitary effect and `default` wasn't specified. """ - # Decide on order to attempt application strategies. if len(args.axes) <= 4: strats = [ @@ -360,12 +359,15 @@ def apply_unitary( strats.remove(_strat_apply_unitary_from_decompose) # Try each strategy, stopping if one works. - for strat in strats: - result = strat(unitary_value, args) - if result is None: - break - if result is not NotImplemented: - return result + # Also catch downcasting warnings and throw an error: #2041 + with warnings.catch_warnings(): + warnings.filterwarnings(action="error", category=np.ComplexWarning) + for strat in strats: + result = strat(unitary_value, args) + if result is None: + break + if result is not NotImplemented: + return result # Don't know how to apply. Fallback to specified default behavior. if default is not RaiseTypeErrorIfNotProvided: diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py index 0f41ec77d20..1b455c6bd69 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py @@ -704,3 +704,16 @@ def test_apply_unitary_args_with_axes_transposed_to_start(): assert args.target_tensor[1, 2, 3, 4] == 1 new_args.available_buffer[2, 4, 1, 3] = 2 assert args.available_buffer[1, 2, 3, 4] == 2 + + +def test_cast_to_complex(): + y0 = cirq.PauliString({cirq.LineQubit(0): cirq.Y}) + state = 0.5 * np.eye(2) + args = cirq.ApplyUnitaryArgs( + target_tensor=state, available_buffer=np.zeros_like(state), axes=(0,) + ) + + with pytest.raises( + np.ComplexWarning, match='Casting complex values to real discards the imaginary part' + ): + cirq.apply_unitary(y0, args)