From 3c8b036b7adb6259960b71919c25f08fda501da1 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Thu, 23 Jun 2022 16:16:36 -0700 Subject: [PATCH] Use np.complexfloating for dtypes that should be complex (#5488) This changes the type signature for dtypes that should only be `np.complex64` or `np.complex128`. There are also dtypes in `qis.states.py` but I've not changed these over, as they are often called the `dtype` of a `np.array`, and so they will often need to be cast. This fixes some `check/mypy --next` errors. Technically a breaking change in type signature (more restrictive), but we yell when types are not complex. --- cirq-core/cirq/circuits/circuit.py | 7 +- .../sim/density_matrix_simulation_state.py | 7 +- .../cirq/sim/density_matrix_simulator.py | 7 +- .../cirq/sim/density_matrix_simulator_test.py | 95 ++++++++++--------- cirq-core/cirq/sim/mux.py | 10 +- cirq-core/cirq/sim/simulator_base.py | 3 +- cirq-core/cirq/sim/sparse_simulator.py | 7 +- cirq-core/cirq/sim/sparse_simulator_test.py | 79 +++++++-------- .../cirq/sim/state_vector_simulation_state.py | 7 +- .../sim/state_vector_simulation_state_test.py | 3 +- cirq-core/cirq/sim/state_vector_simulator.py | 16 +++- 11 files changed, 127 insertions(+), 114 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 82d5fa44c87..917b422ad7f 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -64,7 +64,6 @@ if TYPE_CHECKING: import cirq - from numpy.typing import DTypeLike _TGate = TypeVar('_TGate', bound='cirq.Gate') @@ -999,7 +998,7 @@ def unitary( qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, qubits_that_should_be_present: Iterable['cirq.Qid'] = (), ignore_terminal_measurements: bool = True, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, ) -> np.ndarray: """Converts the circuit into a unitary matrix, if possible. @@ -1089,7 +1088,7 @@ def final_state_vector( qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, qubits_that_should_be_present: Iterable['cirq.Qid'] = (), ignore_terminal_measurements: Optional[bool] = None, - dtype: Optional['DTypeLike'] = None, + dtype: Optional[Type[np.complexfloating]] = None, param_resolver: 'cirq.ParamResolverOrSimilarType' = None, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: @@ -2656,7 +2655,7 @@ def _apply_unitary_circuit( circuit: 'cirq.AbstractCircuit', state: np.ndarray, qubits: Tuple['cirq.Qid', ...], - dtype: 'DTypeLike', + dtype: Type[np.complexfloating], ) -> np.ndarray: """Applies a circuit's unitary effect to the given vector or matrix. diff --git a/cirq-core/cirq/sim/density_matrix_simulation_state.py b/cirq-core/cirq/sim/density_matrix_simulation_state.py index ad1c969364e..1953e576ee4 100644 --- a/cirq-core/cirq/sim/density_matrix_simulation_state.py +++ b/cirq-core/cirq/sim/density_matrix_simulation_state.py @@ -13,7 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a density matrix.""" -from typing import Any, Callable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union import numpy as np @@ -24,7 +24,6 @@ if TYPE_CHECKING: import cirq - from numpy.typing import DTypeLike class _BufferedDensityMatrix(qis.QuantumStateRepresentation): @@ -58,7 +57,7 @@ def create( *, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, qid_shape: Optional[Tuple[int, ...]] = None, - dtype: Optional['DTypeLike'] = None, + dtype: Optional[Type[np.complexfloating]] = None, buffer: Optional[List[np.ndarray]] = None, ): """Creates a buffered density matrix with the requested state. @@ -252,7 +251,7 @@ def __init__( prng: Optional[np.random.RandomState] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits DensityMatrixSimulationState. diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index f65aa06ec29..a50fdd85ff7 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Simulator for density matrices that simulates noisy quantum circuits.""" -from typing import Any, Dict, TYPE_CHECKING, Tuple, Union, Sequence, Optional, List +from typing import Any, Dict, List, Optional, Sequence, Type, TYPE_CHECKING, Tuple, Union import numpy as np @@ -22,7 +22,6 @@ if TYPE_CHECKING: import cirq - from numpy.typing import DTypeLike class DensityMatrixSimulator( @@ -116,7 +115,7 @@ class DensityMatrixSimulator( def __init__( self, *, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, noise: 'cirq.NOISE_MODEL_LIKE' = None, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, split_untangled_states: bool = True, @@ -250,7 +249,7 @@ def __init__( self, sim_state: 'cirq.SimulationStateBase[cirq.DensityMatrixSimulationState]', simulator: 'cirq.DensityMatrixSimulator' = None, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, ): """DensityMatrixStepResult. diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 42dffbf4d40..38c13c48cee 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -13,14 +13,13 @@ # limitations under the License. import itertools import random +from typing import Type from unittest import mock import numpy as np import pytest import sympy -from numpy.typing import DTypeLike - import cirq import cirq.testing @@ -54,7 +53,7 @@ def test_invalid_dtype(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_measurements(dtype: DTypeLike, split: bool): +def test_run_no_measurements(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -65,7 +64,7 @@ def test_run_no_measurements(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_results(dtype: DTypeLike, split: bool): +def test_run_no_results(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -76,7 +75,7 @@ def test_run_no_results(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_empty_circuit(dtype: DTypeLike, split: bool): +def test_run_empty_circuit(dtype: Type[np.complexfloating], split: bool): simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with pytest.raises(ValueError, match="no measurements"): simulator.run(cirq.Circuit()) @@ -84,7 +83,7 @@ def test_run_empty_circuit(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_bit_flips(dtype: DTypeLike, split: bool): +def test_run_bit_flips(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -98,7 +97,7 @@ def test_run_bit_flips(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_bit_flips_with_dephasing(dtype: DTypeLike, split: bool): +def test_run_bit_flips_with_dephasing(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -112,7 +111,7 @@ def test_run_bit_flips_with_dephasing(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudit_increments(dtype: DTypeLike, split: bool): +def test_run_qudit_increments(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1, 2]: @@ -131,7 +130,7 @@ def test_run_qudit_increments(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_not_channel_op(dtype: DTypeLike, split: bool): +def test_run_not_channel_op(dtype: Type[np.complexfloating], split: bool): class BadOp(cirq.Operation): def __init__(self, qubits): self._qubits = qubits @@ -153,7 +152,7 @@ def with_qubits(self, *new_qubits): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_mixture(dtype: DTypeLike, split: bool): +def test_run_mixture(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0), cirq.measure(q1)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -167,7 +166,7 @@ def test_run_mixture(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudit_mixture(dtype: DTypeLike, split: bool): +def test_run_qudit_mixture(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((3, 2)) mixture = _TestMixture( [ @@ -188,7 +187,7 @@ def test_run_qudit_mixture(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_channel(dtype: DTypeLike, split: bool): +def test_run_channel(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.X(q0), cirq.amplitude_damp(0.5)(q0), cirq.measure(q0), cirq.measure(q1) @@ -205,7 +204,7 @@ def test_run_channel(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_decomposable_channel(dtype: DTypeLike, split: bool): +def test_run_decomposable_channel(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( @@ -226,7 +225,7 @@ def test_run_decomposable_channel(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudit_channel(dtype: DTypeLike, split: bool): +def test_run_qudit_channel(dtype: Type[np.complexfloating], split: bool): class TestChannel(cirq.Gate): def _qid_shape_(self): return (3,) @@ -258,7 +257,7 @@ def _kraus_(self): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_at_end_no_repetitions(dtype: DTypeLike, split: bool): +def test_run_measure_at_end_no_repetitions(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -277,7 +276,7 @@ def test_run_measure_at_end_no_repetitions(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measure_at_end(dtype: DTypeLike, split: bool): +def test_run_repetitions_measure_at_end(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -296,7 +295,7 @@ def test_run_repetitions_measure_at_end(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudits_repetitions_measure_at_end(dtype: DTypeLike, split: bool): +def test_run_qudits_repetitions_measure_at_end(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -318,7 +317,7 @@ def test_run_qudits_repetitions_measure_at_end(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measurement_not_terminal_no_repetitions(dtype: DTypeLike, split: bool): +def test_run_measurement_not_terminal_no_repetitions(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -342,7 +341,7 @@ def test_run_measurement_not_terminal_no_repetitions(dtype: DTypeLike, split: bo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measurement_not_terminal(dtype: DTypeLike, split: bool): +def test_run_repetitions_measurement_not_terminal(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -366,7 +365,9 @@ def test_run_repetitions_measurement_not_terminal(dtype: DTypeLike, split: bool) @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudits_repetitions_measurement_not_terminal(dtype: DTypeLike, split: bool): +def test_run_qudits_repetitions_measurement_not_terminal( + dtype: Type[np.complexfloating], split: bool +): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -390,7 +391,7 @@ def test_run_qudits_repetitions_measurement_not_terminal(dtype: DTypeLike, split @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_param_resolver(dtype: DTypeLike, split: bool): +def test_run_param_resolver(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -410,7 +411,7 @@ def test_run_param_resolver(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_correlations(dtype: DTypeLike, split: bool): +def test_run_correlations(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.measure(q0, q1)) @@ -422,7 +423,7 @@ def test_run_correlations(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_multiple_qubits(dtype: DTypeLike, split: bool): +def test_run_measure_multiple_qubits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -434,7 +435,7 @@ def test_run_measure_multiple_qubits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_multiple_qudits(dtype: DTypeLike, split: bool): +def test_run_measure_multiple_qudits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -448,7 +449,7 @@ def test_run_measure_multiple_qudits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_sweeps_param_resolvers(dtype: DTypeLike, split: bool): +def test_run_sweeps_param_resolvers(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -474,7 +475,7 @@ def test_run_sweeps_param_resolvers(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_no_circuit(dtype: DTypeLike, split: bool): +def test_simulate_no_circuit(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit() @@ -487,7 +488,7 @@ def test_simulate_no_circuit(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate(dtype: DTypeLike, split: bool): +def test_simulate(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1)) @@ -498,7 +499,7 @@ def test_simulate(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qudits(dtype: DTypeLike, split: bool): +def test_simulate_qudits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.XPowGate(dimension=3)(q1) ** 2) @@ -512,7 +513,7 @@ def test_simulate_qudits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) def test_reset_one_qubit_does_not_affect_partial_trace_of_other_qubits( - dtype: DTypeLike, split: bool + dtype: Type[np.complexfloating], split: bool ): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -531,7 +532,7 @@ def test_reset_one_qubit_does_not_affect_partial_trace_of_other_qubits( [cirq.testing.random_circuit(cirq.LineQubit.range(4), 5, 0.9) for _ in range(20)], ), ) -def test_simulate_compare_to_state_vector_simulator(dtype: DTypeLike, circuit): +def test_simulate_compare_to_state_vector_simulator(dtype: Type[np.complexfloating], circuit): qubits = cirq.LineQubit.range(4) pure_result = ( cirq.Simulator(dtype=dtype).simulate(circuit, qubit_order=qubits).density_matrix_of() @@ -547,7 +548,7 @@ def test_simulate_compare_to_state_vector_simulator(dtype: DTypeLike, circuit): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_bit_flips(dtype: DTypeLike, split: bool): +def test_simulate_bit_flips(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -564,7 +565,7 @@ def test_simulate_bit_flips(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qudit_increments(dtype: DTypeLike, split: bool): +def test_simulate_qudit_increments(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -584,7 +585,7 @@ def test_simulate_qudit_increments(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_initial_state(dtype: DTypeLike, split: bool): +def test_simulate_initial_state(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -598,7 +599,7 @@ def test_simulate_initial_state(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulation_state(dtype: DTypeLike, split: bool): +def test_simulation_state(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -625,7 +626,7 @@ def test_simulate_tps_initial_state(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_initial_qudit_state(dtype: DTypeLike, split: bool): +def test_simulate_initial_qudit_state(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1, 2]: @@ -645,7 +646,7 @@ def test_simulate_initial_qudit_state(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qubit_order(dtype: DTypeLike, split: bool): +def test_simulate_qubit_order(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -659,7 +660,7 @@ def test_simulate_qubit_order(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_param_resolver(dtype: DTypeLike, split: bool): +def test_simulate_param_resolver(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -678,7 +679,7 @@ def test_simulate_param_resolver(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_measure_multiple_qubits(dtype: DTypeLike, split: bool): +def test_simulate_measure_multiple_qubits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -690,7 +691,7 @@ def test_simulate_measure_multiple_qubits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_measure_multiple_qudits(dtype: DTypeLike, split: bool): +def test_simulate_measure_multiple_qudits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -704,7 +705,7 @@ def test_simulate_measure_multiple_qudits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_sweeps_param_resolver(dtype: DTypeLike, split: bool): +def test_simulate_sweeps_param_resolver(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -731,7 +732,7 @@ def test_simulate_sweeps_param_resolver(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -745,7 +746,7 @@ def test_simulate_moment_steps(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_qudits(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_qudits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) circuit = cirq.Circuit( cirq.XPowGate(dimension=2)(q0), @@ -766,7 +767,7 @@ def test_simulate_moment_steps_qudits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_empty_circuit(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_empty_circuit(dtype: Type[np.complexfloating], split: bool): circuit = cirq.Circuit() simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) step = None @@ -778,7 +779,7 @@ def test_simulate_moment_steps_empty_circuit(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_sample(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_sample(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -799,7 +800,7 @@ def test_simulate_moment_steps_sample(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_sample_qudits(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_sample_qudits(dtype: Type[np.complexfloating], split: bool): class TestGate(cirq.Gate): """Swaps the 2nd qid |0> and |2> states when the 1st is |1>.""" @@ -828,7 +829,9 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_intermediate_measurement(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_intermediate_measurement( + dtype: Type[np.complexfloating], split: bool +): q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0), cirq.H(q0)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index 288e7454af5..b11d4484e0a 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -17,7 +17,7 @@ Filename is a reference to multiplexing. """ -from typing import cast, List, Optional, Sequence, TYPE_CHECKING, Union +from typing import cast, List, Optional, Sequence, Type, TYPE_CHECKING, Union import numpy as np @@ -53,7 +53,7 @@ def sample( noise: 'cirq.NOISE_MODEL_LIKE' = None, param_resolver: Optional['cirq.ParamResolver'] = None, repetitions: int = 1, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> 'cirq.Result': """Simulates sampling from the given circuit. @@ -108,7 +108,7 @@ def final_state_vector( param_resolver: 'cirq.ParamResolverOrSimilarType' = None, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ignore_terminal_measurements: bool = False, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> 'np.ndarray': """Returns the state vector resulting from acting operations on a state. @@ -178,7 +178,7 @@ def sample_sweep( *, noise: 'cirq.NOISE_MODEL_LIKE' = None, repetitions: int = 1, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> Sequence['cirq.Result']: """Runs the supplied Circuit, mimicking quantum hardware. @@ -224,7 +224,7 @@ def final_density_matrix( initial_state: 'cirq.STATE_VECTOR_LIKE' = 0, param_resolver: 'cirq.ParamResolverOrSimilarType' = None, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, seed: Optional[Union[int, np.random.RandomState]] = None, ignore_measurement_results: bool = True, ) -> 'np.ndarray': diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 18a6b2d19d9..9a2e85ea4a0 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -26,6 +26,7 @@ Optional, Sequence, Tuple, + Type, TypeVar, TYPE_CHECKING, ) @@ -93,7 +94,7 @@ class SimulatorBase( def __init__( self, *, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, noise: 'cirq.NOISE_MODEL_LIKE' = None, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, split_untangled_states: bool = False, diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 932c231526e..dd2c160f1cc 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -14,7 +14,7 @@ """A simulator that uses numpy's einsum for sparse matrix operations.""" -from typing import Any, Iterator, List, TYPE_CHECKING, Union, Sequence, Optional +from typing import Any, Iterator, List, TYPE_CHECKING, Union, Sequence, Type, Optional import numpy as np @@ -24,7 +24,6 @@ if TYPE_CHECKING: import cirq - from numpy.typing import DTypeLike class Simulator( @@ -127,7 +126,7 @@ class Simulator( def __init__( self, *, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, noise: 'cirq.NOISE_MODEL_LIKE' = None, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, split_untangled_states: bool = True, @@ -231,7 +230,7 @@ def __init__( self, sim_state: 'cirq.SimulationStateBase[cirq.StateVectorSimulationState]', simulator: 'cirq.Simulator' = None, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, ): """Results of a step of the simulator. diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 6bfc7aa3005..764f3c0abaa 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -13,9 +13,10 @@ # limitations under the License. import itertools import random +from typing import Type + from unittest import mock import numpy as np -from numpy.typing import DTypeLike import pytest import sympy @@ -29,7 +30,7 @@ def test_invalid_dtype(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_measurements(dtype: DTypeLike, split: bool): +def test_run_no_measurements(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -40,7 +41,7 @@ def test_run_no_measurements(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_results(dtype: DTypeLike, split: bool): +def test_run_no_results(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -51,7 +52,7 @@ def test_run_no_results(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_empty_circuit(dtype: DTypeLike, split: bool): +def test_run_empty_circuit(dtype: Type[np.complexfloating], split: bool): simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with pytest.raises(ValueError, match="no measurements"): simulator.run(cirq.Circuit()) @@ -59,7 +60,7 @@ def test_run_empty_circuit(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_reset(dtype: DTypeLike, split: bool): +def test_run_reset(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit( @@ -79,7 +80,7 @@ def test_run_reset(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_bit_flips(dtype: DTypeLike, split: bool): +def test_run_bit_flips(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -93,7 +94,7 @@ def test_run_bit_flips(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_at_end_no_repetitions(dtype: DTypeLike, split: bool): +def test_run_measure_at_end_no_repetitions(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -119,7 +120,7 @@ def test_run_repetitions_terminal_measurement_stochastic(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measure_at_end(dtype: DTypeLike, split: bool): +def test_run_repetitions_measure_at_end(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -139,7 +140,7 @@ def test_run_repetitions_measure_at_end(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_invert_mask_measure_not_terminal(dtype: DTypeLike, split: bool): +def test_run_invert_mask_measure_not_terminal(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -160,7 +161,7 @@ def test_run_invert_mask_measure_not_terminal(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_partial_invert_mask_measure_not_terminal(dtype: DTypeLike, split: bool): +def test_run_partial_invert_mask_measure_not_terminal(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -181,7 +182,7 @@ def test_run_partial_invert_mask_measure_not_terminal(dtype: DTypeLike, split: b @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measurement_not_terminal_no_repetitions(dtype: DTypeLike, split: bool): +def test_run_measurement_not_terminal_no_repetitions(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -205,7 +206,7 @@ def test_run_measurement_not_terminal_no_repetitions(dtype: DTypeLike, split: bo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measurement_not_terminal(dtype: DTypeLike, split: bool): +def test_run_repetitions_measurement_not_terminal(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -230,7 +231,7 @@ def test_run_repetitions_measurement_not_terminal(dtype: DTypeLike, split: bool) @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_param_resolver(dtype: DTypeLike, split: bool): +def test_run_param_resolver(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -249,7 +250,7 @@ def test_run_param_resolver(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_mixture(dtype: DTypeLike, split: bool): +def test_run_mixture(dtype: Type[np.complexfloating], split: bool): q0 = cirq.LineQubit(0) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0)) @@ -259,7 +260,7 @@ def test_run_mixture(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_mixture_with_gates(dtype: DTypeLike, split: bool): +def test_run_mixture_with_gates(dtype: Type[np.complexfloating], split: bool): q0 = cirq.LineQubit(0) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split, seed=23) circuit = cirq.Circuit(cirq.H(q0), cirq.phase_flip(0.5)(q0), cirq.H(q0), cirq.measure(q0)) @@ -270,7 +271,7 @@ def test_run_mixture_with_gates(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_correlations(dtype: DTypeLike, split: bool): +def test_run_correlations(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.measure(q0, q1)) @@ -282,7 +283,7 @@ def test_run_correlations(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_multiple_qubits(dtype: DTypeLike, split: bool): +def test_run_measure_multiple_qubits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -294,7 +295,7 @@ def test_run_measure_multiple_qubits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_sweeps_param_resolvers(dtype: DTypeLike, split: bool): +def test_run_sweeps_param_resolvers(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -320,7 +321,7 @@ def test_run_sweeps_param_resolvers(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_random_unitary(dtype: DTypeLike, split: bool): +def test_simulate_random_unitary(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for _ in range(10): @@ -336,7 +337,7 @@ def test_simulate_random_unitary(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_no_circuit(dtype: DTypeLike, split: bool): +def test_simulate_no_circuit(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit() @@ -347,7 +348,7 @@ def test_simulate_no_circuit(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate(dtype: DTypeLike, split: bool): +def test_simulate(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1)) @@ -369,7 +370,7 @@ def _mixture_(self): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qudits(dtype: DTypeLike, split: bool): +def test_simulate_qudits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.XPowGate(dimension=3)(q0), cirq.XPowGate(dimension=4)(q1) ** 3) @@ -382,7 +383,7 @@ def test_simulate_qudits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_mixtures(dtype: DTypeLike, split: bool): +def test_simulate_mixtures(dtype: Type[np.complexfloating], split: bool): q0 = cirq.LineQubit(0) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0)) @@ -400,7 +401,7 @@ def test_simulate_mixtures(dtype: DTypeLike, split: bool): @pytest.mark.parametrize( 'dtype, split', itertools.product([np.complex64, np.complex128], [True, False]) ) -def test_simulate_qudit_mixtures(dtype: DTypeLike, split: bool): +def test_simulate_qudit_mixtures(dtype: Type[np.complexfloating], split: bool): q0 = cirq.LineQid(0, 3) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) mixture = _TestMixture( @@ -426,7 +427,7 @@ def test_simulate_qudit_mixtures(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_bit_flips(dtype: DTypeLike, split: bool): +def test_simulate_bit_flips(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -443,7 +444,7 @@ def test_simulate_bit_flips(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_initial_state(dtype: DTypeLike, split: bool): +def test_simulate_initial_state(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -457,7 +458,7 @@ def test_simulate_initial_state(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulation_state(dtype: DTypeLike, split: bool): +def test_simulation_state(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -472,7 +473,7 @@ def test_simulation_state(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qubit_order(dtype: DTypeLike, split: bool): +def test_simulate_qubit_order(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -486,7 +487,7 @@ def test_simulate_qubit_order(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_param_resolver(dtype: DTypeLike, split: bool): +def test_simulate_param_resolver(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -505,7 +506,7 @@ def test_simulate_param_resolver(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_measure_multiple_qubits(dtype: DTypeLike, split: bool): +def test_simulate_measure_multiple_qubits(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -517,7 +518,7 @@ def test_simulate_measure_multiple_qubits(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_sweeps_param_resolver(dtype: DTypeLike, split: bool): +def test_simulate_sweeps_param_resolver(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -544,7 +545,7 @@ def test_simulate_sweeps_param_resolver(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -568,7 +569,7 @@ def test_simulate_moment_steps_implicit_copy_deprecated(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_empty_circuit(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_empty_circuit(dtype: Type[np.complexfloating], split: bool): circuit = cirq.Circuit() simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) step = None @@ -580,7 +581,7 @@ def test_simulate_moment_steps_empty_circuit(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_sample(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_sample(dtype: Type[np.complexfloating], split: bool): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -601,7 +602,9 @@ def test_simulate_moment_steps_sample(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_intermediate_measurement(dtype: DTypeLike, split: bool): +def test_simulate_moment_steps_intermediate_measurement( + dtype: Type[np.complexfloating], split: bool +): q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0), cirq.H(q0)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -618,7 +621,7 @@ def test_simulate_moment_steps_intermediate_measurement(dtype: DTypeLike, split: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_expectation_values(dtype: DTypeLike, split: bool): +def test_simulate_expectation_values(dtype: Type[np.complexfloating], split: bool): # Compare with test_expectation_from_state_vector_two_qubit_states # in file: cirq/ops/linear_combinations_test.py q0, q1 = cirq.LineQubit.range(2) @@ -643,7 +646,7 @@ def test_simulate_expectation_values(dtype: DTypeLike, split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_expectation_values_terminal_measure(dtype: DTypeLike, split: bool): +def test_simulate_expectation_values_terminal_measure(dtype: Type[np.complexfloating], split: bool): q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0)) obs = cirq.Z(q0) @@ -681,7 +684,7 @@ def test_simulate_expectation_values_terminal_measure(dtype: DTypeLike, split: b @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_expectation_values_qubit_order(dtype: DTypeLike, split: bool): +def test_simulate_expectation_values_qubit_order(dtype: Type[np.complexfloating], split: bool): q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.X(q2)) obs = cirq.X(q0) + cirq.X(q1) - cirq.Z(q2) diff --git a/cirq-core/cirq/sim/state_vector_simulation_state.py b/cirq-core/cirq/sim/state_vector_simulation_state.py index 1bfd7d62226..e1bc1d08d2f 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state.py @@ -13,7 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a state vector.""" -from typing import Any, Callable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union import numpy as np @@ -24,7 +24,6 @@ if TYPE_CHECKING: import cirq - from numpy.typing import DTypeLike class _BufferedStateVector(qis.QuantumStateRepresentation): @@ -53,7 +52,7 @@ def create( *, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, qid_shape: Optional[Tuple[int, ...]] = None, - dtype: Optional['DTypeLike'] = None, + dtype: Optional[Type[np.complexfloating]] = None, buffer: Optional[List[np.ndarray]] = None, ): """Initializes the object with the inputs. @@ -326,7 +325,7 @@ def __init__( prng: Optional[np.random.RandomState] = None, qubits: Optional[Sequence['cirq.Qid']] = None, initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, classical_data: Optional['cirq.ClassicalDataStore'] = None, ): """Inits StateVectorSimulationState. diff --git a/cirq-core/cirq/sim/state_vector_simulation_state_test.py b/cirq-core/cirq/sim/state_vector_simulation_state_test.py index 4dc54331e78..515ea2b2457 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state_test.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast, Type from unittest import mock import numpy as np @@ -170,7 +171,7 @@ def get_result(state: np.ndarray, sample: float): qubits=cirq.LineQubit.range(4), prng=mock_prng, initial_state=np.copy(state), - dtype=state.dtype, + dtype=cast(Type[np.complexfloating], state.dtype), ) cirq.act_on(Decay11(), args, [cirq.LineQubit(1), cirq.LineQubit(3)]) return args.target_tensor diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index d2e9c86f0c8..ecc0f417b89 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -14,7 +14,18 @@ """Abstract classes for simulations which keep track of state vector.""" import abc -from typing import Any, Dict, Iterator, Sequence, TYPE_CHECKING, Tuple, Generic, TypeVar, Optional +from typing import ( + Any, + Dict, + Iterator, + Sequence, + Type, + TYPE_CHECKING, + Tuple, + Generic, + TypeVar, + Optional, +) import numpy as np @@ -24,7 +35,6 @@ if TYPE_CHECKING: import cirq - from numpy.typing import DTypeLike TStateVectorStepResult = TypeVar('TStateVectorStepResult', bound='StateVectorStepResult') @@ -46,7 +56,7 @@ class SimulatesIntermediateStateVector( def __init__( self, *, - dtype: 'DTypeLike' = np.complex64, + dtype: Type[np.complexfloating] = np.complex64, noise: 'cirq.NOISE_MODEL_LIKE' = None, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, split_untangled_states: bool = False,