From 530200a5d3468b36bd8b91efb069d349f7f834ef Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 13:58:26 -0700 Subject: [PATCH 1/8] Deprecate final_step_result param of TrialResult --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 11 ++-- .../cirq/sim/clifford/clifford_simulator.py | 11 ++-- .../cirq/sim/density_matrix_simulator.py | 11 ++-- .../cirq/sim/density_matrix_simulator_test.py | 56 +++++++++-------- cirq-core/cirq/sim/simulator.py | 62 ++++++++++++------- cirq-core/cirq/sim/simulator_base.py | 10 +-- cirq-core/cirq/sim/simulator_base_test.py | 6 +- cirq-core/cirq/sim/simulator_test.py | 52 ++++++++-------- cirq-core/cirq/sim/state_vector.py | 6 +- cirq-core/cirq/sim/state_vector_simulator.py | 10 +-- 10 files changed, 130 insertions(+), 105 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 2a39efe6f62..4cc85b0c3b9 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -26,7 +26,7 @@ from cirq import devices, protocols, qis, value from cirq._compat import deprecated_parameter -from cirq.sim import simulator_base +from cirq.sim import simulator, simulator_base from cirq.sim.act_on_args import ActOnArgs if TYPE_CHECKING: @@ -122,7 +122,7 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'MPSSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[MPSState]', ) -> 'MPSTrialResult': """Creates a single trial results with the measurements. @@ -136,21 +136,22 @@ def _create_simulator_trial_result( A single result. """ return MPSTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState']): """A single trial reult""" + @simulator.deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'MPSSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[MPSState]', ) -> None: super().__init__( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) @property diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index b12137a1f54..c5bbb87fd39 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -36,7 +36,7 @@ import cirq from cirq import protocols, value from cirq.protocols import act_on -from cirq.sim import clifford, simulator_base +from cirq.sim import clifford, simulator, simulator_base class CliffordSimulator( @@ -107,25 +107,26 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'CliffordSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]', ): return CliffordTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) class CliffordTrialResult( simulator_base.SimulationTrialResultBase['clifford.ActOnStabilizerCHFormArgs'] ): + @simulator.deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.CliffordSimulatorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]', ) -> None: super().__init__( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) @property diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index caf140f7c01..602da2947f6 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -186,10 +186,10 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.DensityMatrixStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', ) -> 'cirq.DensityMatrixTrialResult': return DensityMatrixTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) # TODO(#4209): Deduplicate with identical code in sparse_simulator. @@ -380,14 +380,15 @@ class DensityMatrixTrialResult( trial finishes. """ + @simulator.deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.DensityMatrixStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', ) -> None: super().__init__( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) self._final_density_matrix: Optional[np.ndarray] = None @@ -418,7 +419,7 @@ def __repr__(self) -> str: return ( 'cirq.DensityMatrixTrialResult(' f'params={self.params!r}, measurements={proper_repr(self.measurements)}, ' - f'final_step_result={self._final_step_result!r})' + f'final_simulator_state={self._final_simulator_state!r})' ) def _repr_pretty_(self, p: Any, cycle: bool): diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 8e8ca9db3e4..4711bc27379 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -967,52 +967,59 @@ def test_density_matrix_simulator_state_repr(): def test_density_matrix_trial_result_eq(): q0 = cirq.LineQubit(0) - final_step_result = cirq.DensityMatrixStepResult( - cirq.ActOnDensityMatrixArgs(initial_state=np.ones((2, 2)) * 0.5, qubits=[q0]) + final_simulator_state = cirq.ActOnDensityMatrixArgs( + initial_state=np.ones((2, 2)) * 0.5, qubits=[q0] ) eq = cirq.testing.EqualsTester() eq.add_equality_group( cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), ) eq.add_equality_group( cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) eq.add_equality_group( cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) def test_density_matrix_trial_result_qid_shape(): q0, q1 = cirq.LineQubit.range(2) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1} + final_simulator_state = cirq.ActOnDensityMatrixArgs( + initial_state=np.ones((4, 4)) / 4, qubits=[q0, q1] ) assert cirq.qid_shape( cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ) ) == (2, 2) q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) - final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1} + final_simulator_state = cirq.ActOnDensityMatrixArgs( + initial_state=np.ones((12, 12)) / 12, qubits=[q0, q1] ) assert cirq.qid_shape( cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ) ) == (3, 4) @@ -1020,7 +1027,7 @@ def test_density_matrix_trial_result_qid_shape(): def test_density_matrix_trial_result_repr(): q0 = cirq.LineQubit(0) dtype = np.complex64 - args = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.ActOnDensityMatrixArgs( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1028,23 +1035,20 @@ def test_density_matrix_trial_result_repr(): initial_state=np.ones((2, 2), dtype=dtype) * 0.5, dtype=dtype, ) - final_step_result = cirq.DensityMatrixStepResult(args) trial_result = cirq.DensityMatrixTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]], dtype=np.int32)}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) expected_repr = ( "cirq.DensityMatrixTrialResult(" "params=cirq.ParamResolver({'s': 1}), " "measurements={'m': np.array([[1]], dtype=np.int32)}, " - "final_step_result=cirq.DensityMatrixStepResult(" - "sim_state=cirq.ActOnDensityMatrixArgs(" + "final_simulator_state=cirq.ActOnDensityMatrixArgs(" "initial_state=np.array([[(0.5+0j), (0.5+0j)], [(0.5+0j), (0.5+0j)]], dtype=np.complex64), " "qid_shape=(2,), " "qubits=(cirq.LineQubit(0),), " - "classical_data=cirq.ClassicalDataDictionaryStore()), " - "dtype=np.complex64))" + "classical_data=cirq.ClassicalDataDictionaryStore()))" ) assert repr(trial_result) == expected_repr assert eval(expected_repr) == trial_result @@ -1111,7 +1115,7 @@ def test_works_on_pauli_string(): def test_density_matrix_trial_result_str(): q0 = cirq.LineQubit(0) dtype = np.complex64 - args = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.ActOnDensityMatrixArgs( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1119,9 +1123,8 @@ def test_density_matrix_trial_result_str(): initial_state=np.ones((2, 2), dtype=dtype) * 0.5, dtype=dtype, ) - final_step_result = cirq.DensityMatrixStepResult(args) result = cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) # numpy varies whitespace in its representation for different versions @@ -1137,7 +1140,7 @@ def test_density_matrix_trial_result_str(): def test_density_matrix_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) dtype = np.complex64 - args = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.ActOnDensityMatrixArgs( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1145,9 +1148,8 @@ def test_density_matrix_trial_result_repr_pretty(): initial_state=np.ones((2, 2), dtype=dtype) * 0.5, dtype=dtype, ) - final_step_result = cirq.DensityMatrixStepResult(args) result = cirq.DensityMatrixTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) fake_printer = cirq.testing.FakePrinter() @@ -1523,7 +1525,7 @@ def test_large_untangled_okay(): # Validate a simulation run result = cirq.DensityMatrixSimulator().simulate(circuit) - assert set(result._final_step_result._qubits) == set(cirq.LineQubit.range(59)) + assert set(result._final_simulator_state.qubits) == set(cirq.LineQubit.range(59)) # _ = result.final_density_matrix hangs (as expected) # Validate a trial run and sampling diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index b1f3bed9483..b7005969195 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -37,7 +37,7 @@ Generic, Iterator, List, - Optional, + Mapping, Sequence, Set, Tuple, @@ -48,7 +48,7 @@ import numpy as np -from cirq import circuits, ops, protocols, study, value, work +from cirq import _compat, circuits, ops, protocols, study, value, work from cirq.sim.act_on_args import ActOnArgs from cirq.sim.operation_target import OperationTarget @@ -610,7 +610,9 @@ def simulate_sweep_iter( for k, v in step_result.measurements.items(): measurements[k] = np.array(v, dtype=np.uint8) yield self._create_simulator_trial_result( - params=param_resolver, measurements=measurements, final_step_result=step_result + params=param_resolver, + measurements=measurements, + final_simulator_state=step_result._simulator_state(), ) def simulate_moment_steps( @@ -724,7 +726,7 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: TStepResult, + final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', ) -> TSimulationTrialResult: """This method can be implemented to create a trial result. @@ -876,6 +878,34 @@ def sample_measurement_ops( ) +def deprecated_step_result_parameter( + old_position: int = 4, new_position: int = 3 +) -> Callable[[Callable], Callable]: + def rewrite_deprecated_step_result_param(args, kwargs): + has_state = len(args) > new_position or 'final_simulator_state' in kwargs + has_step_result = len(args) > old_position or 'final_step_result' in kwargs + if not has_step_result ^ has_state: + raise ValueError( + 'Exactly one of final_simulator_state and final_step_result should be provided' + ) + if len(args) > old_position: + args[new_position] = args[old_position]._simulator_state() + if old_position > new_position: + del args[old_position] + elif 'final_step_result' in kwargs: + kwargs['final_simulator_state'] = kwargs['final_step_result']._simulator_state() + del kwargs['final_step_result'] + return args, kwargs + + return _compat.deprecated_parameter( + deadline='v0.16', + fix='', + parameter_desc='final_step_result', + match=lambda args, kwargs: 'final_step_result' in kwargs or len(args) > old_position, + rewrite=rewrite_deprecated_step_result_param, + ) + + @value.value_equality(unhashable=True) class SimulationTrialResult(Generic[TSimulatorState]): """Results of a simulation by a SimulatesFinalState. @@ -893,12 +923,12 @@ class SimulationTrialResult(Generic[TSimulatorState]): measurement gate.) """ + @deprecated_step_result_parameter() def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: Optional[TSimulatorState] = None, - final_step_result: Optional['cirq.StepResult[TSimulatorState]'] = None, + final_simulator_state: TSimulatorState, ) -> None: """Initializes the `SimulationTrialResult` class. @@ -909,28 +939,14 @@ def __init__( boolean measurement results (ordered by the qubits acted on by the measurement gate.) final_simulator_state: The final simulator state. - final_step_result: The step result coming from the simulation, that - can be used to get the final simulator state. This is primarily - for cases when calculating simulator state may be expensive and - unneeded. If this is provided, then final_simulator_state - should not be, and vice versa. Raises: ValueError: If `final_step_result` and `final_simulator_state` are both None or both not None. """ - if [final_step_result, final_simulator_state].count(None) != 1: - raise ValueError( - 'Exactly one of final_simulator_state and final_step_result should be provided' - ) self.params = params self.measurements = measurements - self._final_step_result = final_step_result - self._final_simulator_state: TSimulatorState = ( - final_simulator_state - if final_simulator_state is not None - else cast('cirq.StepResult[TSimulatorState]', final_step_result)._simulator_state() - ) + self._final_simulator_state = final_simulator_state def __repr__(self) -> str: return ( @@ -962,7 +978,7 @@ def _value_equality_values_(self) -> Any: return self.params, measurements, self._final_simulator_state @property - def qubit_map(self) -> Dict['cirq.Qid', int]: + def qubit_map(self) -> Mapping['cirq.Qid', int]: """A map from Qid to index used to define the ordering of the basis in the result. """ @@ -972,7 +988,7 @@ def _qid_shape_(self) -> Tuple[int, ...]: return _qubit_map_to_shape(self.qubit_map) -def _qubit_map_to_shape(qubit_map: Dict['cirq.Qid', int]) -> Tuple[int, ...]: +def _qubit_map_to_shape(qubit_map: Mapping['cirq.Qid', int]) -> Tuple[int, ...]: qid_shape: List[int] = [-1] * len(qubit_map) try: for q, i in qubit_map.items(): diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index f5ab34fd2f4..749856742de 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -36,6 +36,7 @@ from cirq import ops, protocols, study, value, devices from cirq.sim import ActOnArgsContainer from cirq.sim.operation_target import OperationTarget +from cirq.sim import simulator from cirq.sim.simulator import ( TSimulationTrialResult, TActOnArgs, @@ -385,11 +386,12 @@ class SimulationTrialResultBase( ): """A base class for trial results.""" + @simulator.deprecated_step_result_parameter(old_position=3) def __init__( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], - final_step_result: StepResultBase[TActOnArgs], + final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', ) -> None: """Initializes the `SimulationTrialResultBase` class. @@ -399,10 +401,10 @@ def __init__( results. Measurement results are a numpy ndarray of actual boolean measurement results (ordered by the qubits acted on by the measurement gate.) - final_step_result: The step result coming from the simulation, that - can be used to get the final simulator state. + final_simulator_state: The final simulator state of the system after the + trial finishes. """ - super().__init__(params, measurements, final_step_result=final_step_result) + super().__init__(params, measurements, final_simulator_state=final_simulator_state) self._merged_sim_state_cache: Optional[TActOnArgs] = None def get_state_containing_qubit(self, qubit: 'cirq.Qid') -> TActOnArgs: diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 0fed8a667ad..4d2ad991421 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -124,9 +124,11 @@ def _create_simulator_trial_result( self, params: cirq.ParamResolver, measurements: Dict[str, np.ndarray], - final_step_result: CountingStepResult, + final_simulator_state: 'cirq.OperationTarget[CountingActOnArgs]', ) -> CountingTrialResult: - return CountingTrialResult(params, measurements, final_step_result=final_step_result) + return CountingTrialResult( + params, measurements, final_simulator_state=final_simulator_state + ) def _create_step_result( self, sim_state: cirq.OperationTarget[CountingActOnArgs] diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index f7063167471..eb9a0815d9b 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -73,7 +73,7 @@ def _create_simulator_trial_result( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], - final_step_result: TStepResult, + final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', ) -> 'SimulationTrialResult': """This method creates a default trial result. @@ -86,7 +86,7 @@ def _create_simulator_trial_result( The SimulationTrialResult. """ return SimulationTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) @@ -171,17 +171,16 @@ def steps(*args, **kwargs): program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2 ) - final_step_result = FakeStepResult(final_state=final_state) expected_results = [ cirq.SimulationTrialResult( measurements={'a': np.array([True, True])}, params=param_resolvers[0], - final_step_result=final_step_result, + final_simulator_state=final_state, ), cirq.SimulationTrialResult( measurements={'a': np.array([True, True])}, params=param_resolvers[1], - final_step_result=final_step_result, + final_simulator_state=final_state, ), ] assert results == expected_results @@ -242,46 +241,41 @@ def test_step_sample_measurement_ops_repeated_qubit(): def test_simulation_trial_result_equality(): eq = cirq.testing.EqualsTester() - final_step_result = FakeStepResult(final_state=()) eq.add_equality_group( cirq.SimulationTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=() ), cirq.SimulationTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=() ), ) eq.add_equality_group( cirq.SimulationTrialResult( - params=cirq.ParamResolver({'s': 1}), - measurements={}, - final_step_result=final_step_result, + params=cirq.ParamResolver({'s': 1}), measurements={}, final_simulator_state=() ) ) eq.add_equality_group( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(), ) ) - final_step_result._final_state = (0, 1) eq.add_equality_group( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) def test_simulation_trial_result_repr(): - final_step_result = FakeStepResult(final_state=(0, 1)) assert repr( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == ( "cirq.SimulationTrialResult(" @@ -292,13 +286,10 @@ def test_simulation_trial_result_repr(): def test_simulation_trial_result_str(): - final_step_result = FakeStepResult(final_state=(0, 1)) assert ( str( cirq.SimulationTrialResult( - params=cirq.ParamResolver({'s': 1}), - measurements={}, - final_step_result=final_step_result, + params=cirq.ParamResolver({'s': 1}), measurements={}, final_simulator_state=(0, 1) ) ) == '(no measurements)' @@ -309,7 +300,7 @@ def test_simulation_trial_result_str(): cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == 'm=1' @@ -320,7 +311,7 @@ def test_simulation_trial_result_str(): cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([1, 2, 3])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == 'm=123' @@ -331,7 +322,7 @@ def test_simulation_trial_result_str(): cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([9, 10, 11])}, - final_step_result=final_step_result, + final_simulator_state=(0, 1), ) ) == 'm=9 10 11' @@ -447,9 +438,7 @@ def _kraus_(self): def test_iter_definitions(): - dummy_trial_result = SimulationTrialResult( - params={}, measurements={}, final_step_result=FakeStepResult(final_state=[]) - ) + dummy_trial_result = SimulationTrialResult(params={}, measurements={}, final_simulator_state=[]) class FakeNonIterSimulatorImpl( SimulatesAmplitudes, SimulatesExpectationValues, SimulatesFinalState @@ -543,3 +532,14 @@ def test_trial_result_initializer(): _ = SimulationTrialResult(cirq.ParamResolver(), {}, None, None) with pytest.raises(ValueError, match='Exactly one of'): _ = SimulationTrialResult(cirq.ParamResolver(), {}, object(), mock.Mock(TStepResult)) + with pytest.raises(ValueError, match='Exactly one of'): + _ = SimulationTrialResult( + cirq.ParamResolver(), {}, final_simulator_state=None, final_step_result=None + ) + with pytest.raises(ValueError, match='Exactly one of'): + _ = SimulationTrialResult( + cirq.ParamResolver(), + {}, + final_simulator_state=object(), + final_step_result=mock.Mock(TStepResult), + ) diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index 7364cbcd352..ebccd000ca0 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -14,7 +14,7 @@ """Helpers for handling quantum state vectors.""" import abc -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence +from typing import List, Mapping, Optional, Tuple, TYPE_CHECKING, Sequence import numpy as np @@ -31,7 +31,7 @@ class StateVectorMixin: """A mixin that provide methods for objects that have a state vector.""" - def __init__(self, qubit_map: Optional[Dict['cirq.Qid', int]] = None, *args, **kwargs): + def __init__(self, qubit_map: Optional[Mapping['cirq.Qid', int]] = None, *args, **kwargs): """Inits StateVectorMixin. Args: @@ -48,7 +48,7 @@ def __init__(self, qubit_map: Optional[Dict['cirq.Qid', int]] = None, *args, **k self._qid_shape = None if qubit_map is None else qid_shape @property - def qubit_map(self) -> Dict['cirq.Qid', int]: + def qubit_map(self) -> Mapping['cirq.Qid', int]: return self._qubit_map def _qid_shape_(self) -> Tuple[int, ...]: diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index a2a0e442998..7896383aa20 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -69,10 +69,10 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.StateVectorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', ) -> 'cirq.StateVectorTrialResult': return StateVectorTrialResult( - params=params, measurements=measurements, final_step_result=final_step_result + params=params, measurements=measurements, final_simulator_state=final_simulator_state ) def compute_amplitudes_sweep_iter( @@ -143,13 +143,13 @@ def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'cirq.StateVectorStepResult', + final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', ) -> None: super().__init__( params=params, measurements=measurements, - final_step_result=final_step_result, - qubit_map=final_step_result._qubit_mapping, + final_simulator_state=final_simulator_state, + qubit_map=final_simulator_state.qubit_map, ) self._final_state_vector: Optional[np.ndarray] = None From 711ac8c4c1f2d0b24b2b46fb3419c5551e8ea991 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 15:43:44 -0700 Subject: [PATCH 2/8] Finish deprecation changes --- .../act_on_stabilizer_ch_form_args.py | 9 +++ .../cirq/sim/clifford/clifford_simulator.py | 2 +- .../sim/clifford/clifford_simulator_test.py | 20 ++--- cirq-core/cirq/sim/state_vector_simulator.py | 2 +- .../cirq/sim/state_vector_simulator_test.py | 73 +++++++------------ 5 files changed, 47 insertions(+), 59 deletions(-) diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 55a8a17c6fd..66069f7ea4f 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -16,6 +16,7 @@ import numpy as np +from cirq._compat import proper_repr from cirq.sim.clifford import stabilizer_state_ch_form from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs @@ -65,3 +66,11 @@ def __init__( super().__init__( state=initial_state, prng=prng, qubits=qubits, classical_data=classical_data ) + + def __repr__(self) -> str: + return ( + 'cirq.ActOnStabilizerCHFormArgs(' + f'initial_state={proper_repr(self.state)},' + f' qubits={self.qubits!r},' + f' classical_data={self.classical_data!r})' + ) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index c5bbb87fd39..c86124fe002 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -138,7 +138,7 @@ def final_state(self) -> 'cirq.CliffordState': def __str__(self) -> str: samples = super().__str__() - final = self._final_simulator_state + final = self._get_merged_sim_state().state return f'measurements: {samples}\noutput state: {final}' def _repr_pretty_(self, p: Any, cycle: bool): diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index 1c7c6a5a17d..4caa4867b57 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -212,32 +212,33 @@ def test_clifford_state_initial_state(): def test_clifford_trial_result_repr(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult) - final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0}) + final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) assert ( repr( cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) == "cirq.SimulationTrialResult(params=cirq.ParamResolver({}), " "measurements={'m': array([[1]])}, " - "final_simulator_state=StabilizerStateChForm(num_qubits=1))" + "final_simulator_state=cirq.ActOnStabilizerCHFormArgs(" + "initial_state=StabilizerStateChForm(num_qubits=1), " + "qubits=(cirq.LineQubit(0),), " + "classical_data=cirq.ClassicalDataDictionaryStore()))" ) def test_clifford_trial_result_str(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult) - final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0}) + final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) assert ( str( cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) == "measurements: m=1\n" @@ -247,12 +248,11 @@ def test_clifford_trial_result_str(): def test_clifford_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult) - final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0}) + final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) result = cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) cirq.testing.assert_repr_pretty(result, "measurements: m=1\n" "output state: |0⟩") diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index 7896383aa20..f32673baf1e 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -222,5 +222,5 @@ def __repr__(self) -> str: return ( 'cirq.StateVectorTrialResult(' f'params={self.params!r}, measurements={proper_repr(self.measurements)}, ' - f'final_step_result={self._final_step_result!r})' + f'final_simulator_state={self._final_simulator_state!r})' ) diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index 50d3063f398..3c757394753 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -22,29 +22,26 @@ def test_state_vector_trial_result_repr(): q0 = cirq.NamedQubit('a') - args = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.ActOnStateVectorArgs( available_buffer=np.array([0, 1], dtype=np.complex64), prng=np.random.RandomState(0), qubits=[q0], initial_state=np.array([0, 1], dtype=np.complex64), dtype=np.complex64, ) - final_step_result = cirq.SparseSimulatorStep(args) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]], dtype=np.int32)}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) expected_repr = ( "cirq.StateVectorTrialResult(" "params=cirq.ParamResolver({'s': 1}), " "measurements={'m': np.array([[1]], dtype=np.int32)}, " - "final_step_result=cirq.SparseSimulatorStep(" - "sim_state=cirq.ActOnStateVectorArgs(" + "final_simulator_state=cirq.ActOnStateVectorArgs(" "initial_state=np.array([0j, (1+0j)], dtype=np.complex64), " "qubits=(cirq.NamedQubit('a'),), " - "classical_data=cirq.ClassicalDataDictionaryStore()), " - "dtype=np.complex64))" + "classical_data=cirq.ClassicalDataDictionaryStore()))" ) assert repr(trial_result) == expected_repr assert eval(expected_repr) == trial_result @@ -59,52 +56,46 @@ def test_state_vector_simulator_state_repr(): def test_state_vector_trial_result_equality(): eq = cirq.testing.EqualsTester() - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(initial_state=np.array([])) - ) + final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([])) eq.add_equality_group( cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ), cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ), ) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(initial_state=np.array([1])) - ) + final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([1])) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) def test_state_vector_trial_result_state_mixin(): qubits = cirq.LineQubit.range(2) - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(qubits=qubits, initial_state=np.array([0, 1, 0, 0])) - ) + final_simulator_state = cirq.ActOnStateVectorArgs(qubits=qubits, initial_state=np.array([0, 1, 0, 0])) result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'a': 2}), measurements={'m': np.array([1, 2])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) rho = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) np.testing.assert_array_almost_equal(rho, result.density_matrix_of(qubits)) @@ -114,70 +105,58 @@ def test_state_vector_trial_result_state_mixin(): def test_state_vector_trial_result_qid_shape(): - qubit_map = {cirq.NamedQubit('a'): 0} - final_step_result = mock.Mock(cirq.StateVectorStepResult) - final_step_result._qubit_mapping = qubit_map - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - qubit_map=qubit_map, state_vector=np.array([0, 1]) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=[cirq.NamedQubit('a')], initial_state=np.array([0, 1]) ) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) - assert cirq.qid_shape(final_step_result._simulator_state()) == (2,) assert cirq.qid_shape(trial_result) == (2,) - q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) - qubit_map = {q0: 1, q1: 0} - final_step_result._qubit_mapping = qubit_map - final_step_result._simulator_state.return_value = cirq.StateVectorSimulatorState( - qubit_map=qubit_map, state_vector=np.array([0, 0, 0, 0, 1, 0]) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=cirq.LineQid.for_qid_shape((3, 2)), initial_state=np.array([0, 0, 0, 0, 1, 0]) ) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), measurements={'m': np.array([[2, 0]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) - assert cirq.qid_shape(final_step_result._simulator_state()) == (3, 2) assert cirq.qid_shape(trial_result) == (3, 2) def test_state_vector_trial_state_vector_is_copy(): final_state_vector = np.array([0, 1], dtype=np.complex64) qubit_map = {cirq.NamedQubit('a'): 0} - final_step_result = cirq.StateVectorStepResult( - cirq.ActOnStateVectorArgs(qubits=list(qubit_map), initial_state=final_state_vector) - ) + final_simulator_state = cirq.ActOnStateVectorArgs(qubits=list(qubit_map), initial_state=final_state_vector) trial_result = cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result + params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) - assert trial_result.state_vector() is not final_step_result._simulator_state().target_tensor + assert trial_result.state_vector() is not final_simulator_state.target_tensor def test_str_big(): qs = cirq.LineQubit.range(10) - args = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.ActOnStateVectorArgs( prng=np.random.RandomState(0), qubits=qs, initial_state=np.array([1] * 2**10, dtype=np.complex64) * 0.03125, dtype=np.complex64, ) - final_step_result = cirq.SparseSimulatorStep(args) - result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_step_result) + result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state) assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result) def test_pretty_print(): - args = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.ActOnStateVectorArgs( available_buffer=np.array([1]), prng=np.random.RandomState(0), qubits=[], initial_state=np.array([1], dtype=np.complex64), dtype=np.complex64, ) - final_step_result = cirq.SparseSimulatorStep(args) - result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_step_result) + result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state) # Test Jupyter console output from class FakePrinter: From c8a40cb8338d74d6421060dfa2b756d0c85c24e4 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 15:45:48 -0700 Subject: [PATCH 3/8] format --- .../sim/clifford/clifford_simulator_test.py | 1 - .../cirq/sim/state_vector_simulator_test.py | 18 ++++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index 4caa4867b57..c3bfca63d90 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -1,6 +1,5 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice import itertools -from unittest import mock import numpy as np import pytest diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index 3c757394753..a8f1f15cb68 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock - import numpy as np import cirq @@ -59,10 +57,14 @@ def test_state_vector_trial_result_equality(): final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([])) eq.add_equality_group( cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), cirq.StateVectorTrialResult( - params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state + params=cirq.ParamResolver({}), + measurements={}, + final_simulator_state=final_simulator_state, ), ) eq.add_equality_group( @@ -91,7 +93,9 @@ def test_state_vector_trial_result_equality(): def test_state_vector_trial_result_state_mixin(): qubits = cirq.LineQubit.range(2) - final_simulator_state = cirq.ActOnStateVectorArgs(qubits=qubits, initial_state=np.array([0, 1, 0, 0])) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=qubits, initial_state=np.array([0, 1, 0, 0]) + ) result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({'a': 2}), measurements={'m': np.array([1, 2])}, @@ -129,7 +133,9 @@ def test_state_vector_trial_result_qid_shape(): def test_state_vector_trial_state_vector_is_copy(): final_state_vector = np.array([0, 1], dtype=np.complex64) qubit_map = {cirq.NamedQubit('a'): 0} - final_simulator_state = cirq.ActOnStateVectorArgs(qubits=list(qubit_map), initial_state=final_state_vector) + final_simulator_state = cirq.ActOnStateVectorArgs( + qubits=list(qubit_map), initial_state=final_state_vector + ) trial_result = cirq.StateVectorTrialResult( params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state ) From 6ec8b2f08e8076948bc537e39ba7d2e34f73ed0f Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 15:56:23 -0700 Subject: [PATCH 4/8] Deprecate the state classes too --- .../cirq/sim/density_matrix_simulator.py | 3 +- .../cirq/sim/density_matrix_simulator_test.py | 77 +++++++++++-------- cirq-core/cirq/sim/state_vector_simulator.py | 3 +- .../cirq/sim/state_vector_simulator_test.py | 9 ++- 4 files changed, 52 insertions(+), 40 deletions(-) diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 602da2947f6..13c344a62e0 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -17,7 +17,7 @@ import numpy as np from cirq import ops, protocols, study, value -from cirq._compat import deprecated_parameter, proper_repr +from cirq._compat import deprecated_class, deprecated_parameter, proper_repr from cirq.sim import simulator, act_on_density_matrix_args, simulator_base if TYPE_CHECKING: @@ -310,6 +310,7 @@ def __repr__(self) -> str: ) +@deprecated_class(deadline='v0.16', fix='This class is no longer used.') @value.value_equality(unhashable=True) class DensityMatrixSimulatorState: """The simulator state for DensityMatrixSimulator diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 4711bc27379..fc805a70554 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -921,48 +921,57 @@ def test_simulate_expectation_values_qubit_order(dtype): assert cirq.approx_eq(result_flipped[0], 3, atol=1e-6) -def test_density_matrix_simulator_state_eq(): - q0, q1 = cirq.LineQubit.range(2) - eq = cirq.testing.EqualsTester() - eq.add_equality_group( - cirq.DensityMatrixSimulatorState(density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}), - cirq.DensityMatrixSimulatorState(density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}), - ) - eq.add_equality_group( - cirq.DensityMatrixSimulatorState(density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0}) - ) - eq.add_equality_group( - cirq.DensityMatrixSimulatorState(density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0, q1: 1}) - ) +def test_density_matrix_simulator_state_eq_deprecated(): + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16', count=4): + q0, q1 = cirq.LineQubit.range(2) + eq = cirq.testing.EqualsTester() + eq.add_equality_group( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + ), + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + ), + ) + eq.add_equality_group( + cirq.DensityMatrixSimulatorState(density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0}) + ) + eq.add_equality_group( + cirq.DensityMatrixSimulatorState( + density_matrix=np.eye(2) * 0.5, qubit_map={q0: 0, q1: 1} + ) + ) def test_density_matrix_simulator_state_qid_shape(): - q0, q1 = cirq.LineQubit.range(2) - assert cirq.qid_shape( - cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1} - ) - ) == (2, 2) - q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) - assert cirq.qid_shape( - cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1} - ) - ) == (3, 4) + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16', count=2): + q0, q1 = cirq.LineQubit.range(2) + assert cirq.qid_shape( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((4, 4)) / 4, qubit_map={q0: 0, q1: 1} + ) + ) == (2, 2) + q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) + assert cirq.qid_shape( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((12, 12)) / 12, qubit_map={q0: 0, q1: 1} + ) + ) == (3, 4) def test_density_matrix_simulator_state_repr(): - q0 = cirq.LineQubit(0) - assert ( - repr( - cirq.DensityMatrixSimulatorState( - density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16'): + q0 = cirq.LineQubit(0) + assert ( + repr( + cirq.DensityMatrixSimulatorState( + density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0} + ) ) + == "cirq.DensityMatrixSimulatorState(density_matrix=" + "np.array([[0.5, 0.5], [0.5, 0.5]]), " + "qubit_map={cirq.LineQubit(0): 0})" ) - == "cirq.DensityMatrixSimulatorState(density_matrix=" - "np.array([[0.5, 0.5], [0.5, 0.5]]), " - "qubit_map={cirq.LineQubit(0): 0})" - ) def test_density_matrix_trial_result_eq(): diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index f32673baf1e..a473565ba3e 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -30,7 +30,7 @@ import numpy as np from cirq import ops, value, qis -from cirq._compat import proper_repr +from cirq._compat import deprecated_class, proper_repr from cirq.sim import simulator, state_vector, simulator_base if TYPE_CHECKING: @@ -107,6 +107,7 @@ class StateVectorStepResult( pass +@deprecated_class(deadline='v0.16', fix='This class is no longer used.') @value.value_equality(unhashable=True) class StateVectorSimulatorState: def __init__(self, state_vector: np.ndarray, qubit_map: Dict[ops.Qid, int]) -> None: diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index a8f1f15cb68..fdae8ece86e 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -46,10 +46,11 @@ def test_state_vector_trial_result_repr(): def test_state_vector_simulator_state_repr(): - final_simulator_state = cirq.StateVectorSimulatorState( - qubit_map={cirq.NamedQubit('a'): 0}, state_vector=np.array([0, 1]) - ) - cirq.testing.assert_equivalent_repr(final_simulator_state) + with cirq.testing.assert_deprecated('no longer used', deadline='v0.16', count=4): + final_simulator_state = cirq.StateVectorSimulatorState( + qubit_map={cirq.NamedQubit('a'): 0}, state_vector=np.array([0, 1]) + ) + cirq.testing.assert_equivalent_repr(final_simulator_state) def test_state_vector_trial_result_equality(): From ed9fb0ebf7a1c7b8b112fbf3998891dd3251739a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 17:37:31 -0700 Subject: [PATCH 5/8] Add backwards-compatible function call --- cirq-core/cirq/sim/simulator.py | 21 ++++++++++++++++----- cirq-core/cirq/sim/simulator_base_test.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index b7005969195..949f2a65ece 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -29,6 +29,7 @@ import abc import collections +import inspect from typing import ( Any, Callable, @@ -609,11 +610,21 @@ def simulate_sweep_iter( for step_result in all_step_results: for k, v in step_result.measurements.items(): measurements[k] = np.array(v, dtype=np.uint8) - yield self._create_simulator_trial_result( - params=param_resolver, - measurements=measurements, - final_simulator_state=step_result._simulator_state(), - ) + if ( + 'final_simulator_state' + in inspect.signature(self._create_simulator_trial_result).parameters + ): + yield self._create_simulator_trial_result( + params=param_resolver, + measurements=measurements, + final_simulator_state=step_result._simulator_state(), + ) + else: + yield self._create_simulator_trial_result( + params=param_resolver, + measurements=measurements, + final_step_result=step_result, # type: ignore + ) def simulate_moment_steps( self, diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 4d2ad991421..c62d95e4fcb 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -390,3 +390,20 @@ def _has_unitary_(self): simulator.simulate_sweep(program=circuit, params=params) assert op1.count == 2 assert op2.count == 2 + + +def test_deprecated_final_step_result(): + class OldCountingSimulator(CountingSimulator): + def _create_simulator_trial_result( + self, + params: cirq.ParamResolver, + measurements: Dict[str, np.ndarray], + final_step_result: CountingStepResult, + ) -> CountingTrialResult: + return CountingTrialResult(params, measurements, final_step_result=final_step_result) + + sim = OldCountingSimulator() + with cirq.testing.assert_deprecated('final_step_result', deadline='0.16'): + r = sim.simulate(cirq.Circuit()) + assert r._final_simulator_state.gate_count == 0 + assert r._final_simulator_state.measurement_count == 0 From e00448de69711dc5be0aea58c50c3f6b36ae6650 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 17:39:26 -0700 Subject: [PATCH 6/8] Fix mps --- cirq-core/cirq/contrib/quimb/mps_simulator_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 5b4ce85d802..dbabe57aaaf 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -263,8 +263,7 @@ def test_measurement_str(): def test_trial_result_str(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState( + final_simulator_state = ccq.mps_simulator.MPSState( qubits=(q0,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(), @@ -274,7 +273,7 @@ def test_trial_result_str(): ccq.mps_simulator.MPSTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) ) == """measurements: m=1 @@ -286,8 +285,7 @@ def test_trial_result_str(): def test_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) - final_step_result = mock.Mock(cirq.StepResult) - final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState( + final_simulator_state = ccq.mps_simulator.MPSState( qubits=(q0,), prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(), @@ -295,7 +293,7 @@ def test_trial_result_repr_pretty(): result = ccq.mps_simulator.MPSTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, - final_step_result=final_step_result, + final_simulator_state=final_simulator_state, ) cirq.testing.assert_repr_pretty( result, From ea717a26d0bf33e8e751335e6a5441c0feef90ec Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 17:45:53 -0700 Subject: [PATCH 7/8] docstrings --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 2 +- cirq-core/cirq/sim/simulator.py | 8 +++----- cirq-core/cirq/sim/simulator_test.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 4cc85b0c3b9..874252244a6 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -130,7 +130,7 @@ def _create_simulator_trial_result( params: A ParamResolver for determining values of Symbols. measurements: A dictionary from measurement key (e.g. qubit) to the actual measurement array. - final_step_result: The final step result of the simulation. + final_simulator_state: The final state of the simulation. Returns: A single result. diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 949f2a65ece..e7879e1d7f2 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -744,7 +744,7 @@ def _create_simulator_trial_result( Args: params: The ParamResolver for this trial. measurements: The measurement results for this trial. - final_step_result: The final step result of the simulation. + final_simulator_state: The final state of the simulation. Returns: The SimulationTrialResult. @@ -889,6 +889,8 @@ def sample_measurement_ops( ) +# When removing this, also remove the check in simulate_sweep_iter. +# Basically there should be no "final_step_result" anywhere in the project afterwards. def deprecated_step_result_parameter( old_position: int = 4, new_position: int = 3 ) -> Callable[[Callable], Callable]: @@ -950,10 +952,6 @@ def __init__( boolean measurement results (ordered by the qubits acted on by the measurement gate.) final_simulator_state: The final simulator state. - - Raises: - ValueError: If `final_step_result` and `final_simulator_state` are both - None or both not None. """ self.params = params self.measurements = measurements diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index eb9a0815d9b..5071aed20b3 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -80,7 +80,7 @@ def _create_simulator_trial_result( Args: params: The ParamResolver for this trial. measurements: The measurement results for this trial. - final_step_result: The final step result of the simulation. + final_simulator_state: The final state of the simulation. Returns: The SimulationTrialResult. From 7f94982af0929b3b68e4613d1cec34dcf9fe1552 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 19:29:57 -0700 Subject: [PATCH 8/8] Fix tests, coverage --- cirq-core/cirq/contrib/quimb/mps_simulator.py | 2 +- .../cirq/contrib/quimb/mps_simulator_test.py | 1 - .../cirq/sim/clifford/clifford_simulator.py | 2 +- .../cirq/sim/density_matrix_simulator.py | 2 +- cirq-core/cirq/sim/simulator.py | 37 ++++++++++++++----- cirq-core/cirq/sim/simulator_base.py | 2 +- cirq-core/cirq/sim/simulator_base_test.py | 2 +- cirq-core/cirq/sim/simulator_test.py | 35 ++++++++++++------ 8 files changed, 57 insertions(+), 26 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 874252244a6..b8ea3f1f751 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -143,7 +143,7 @@ def _create_simulator_trial_result( class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState']): """A single trial reult""" - @simulator.deprecated_step_result_parameter(old_position=3) + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index dbabe57aaaf..dfd241f6d27 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -1,7 +1,6 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice import itertools import math -from unittest import mock import numpy as np import pytest diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index c86124fe002..1919222d618 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -118,7 +118,7 @@ def _create_simulator_trial_result( class CliffordTrialResult( simulator_base.SimulationTrialResultBase['clifford.ActOnStabilizerCHFormArgs'] ): - @simulator.deprecated_step_result_parameter(old_position=3) + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 13c344a62e0..1a3188e654d 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -381,7 +381,7 @@ class DensityMatrixTrialResult( trial finishes. """ - @simulator.deprecated_step_result_parameter(old_position=3) + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index e7879e1d7f2..371b3569003 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -620,7 +620,7 @@ def simulate_sweep_iter( final_simulator_state=step_result._simulator_state(), ) else: - yield self._create_simulator_trial_result( + yield self._create_simulator_trial_result( # pylint: disable=no-value-for-parameter, unexpected-keyword-arg, line-too-long params=param_resolver, measurements=measurements, final_step_result=step_result, # type: ignore @@ -891,24 +891,43 @@ def sample_measurement_ops( # When removing this, also remove the check in simulate_sweep_iter. # Basically there should be no "final_step_result" anywhere in the project afterwards. -def deprecated_step_result_parameter( +def _deprecated_step_result_parameter( old_position: int = 4, new_position: int = 3 ) -> Callable[[Callable], Callable]: + assert old_position >= new_position + def rewrite_deprecated_step_result_param(args, kwargs): - has_state = len(args) > new_position or 'final_simulator_state' in kwargs - has_step_result = len(args) > old_position or 'final_step_result' in kwargs - if not has_step_result ^ has_state: + args = list(args) + state = ( + kwargs['final_simulator_state'] + if 'final_simulator_state' in kwargs + else args[new_position] + if len(args) > new_position and not isinstance(args[new_position], StepResult) + else None + ) + step_result = ( + kwargs['final_step_result'] + if 'final_step_result' in kwargs + else args[old_position] + if len(args) > old_position and isinstance(args[old_position], StepResult) + else None + ) + if (step_result is None) == (state is None): raise ValueError( 'Exactly one of final_simulator_state and final_step_result should be provided' ) - if len(args) > old_position: + if len(args) > old_position and isinstance(args[old_position], StepResult): args[new_position] = args[old_position]._simulator_state() if old_position > new_position: del args[old_position] elif 'final_step_result' in kwargs: - kwargs['final_simulator_state'] = kwargs['final_step_result']._simulator_state() + sim_state = kwargs['final_step_result']._simulator_state() + if len(args) > new_position: + args[new_position] = sim_state + else: + kwargs['final_simulator_state'] = sim_state del kwargs['final_step_result'] - return args, kwargs + return tuple(args), kwargs return _compat.deprecated_parameter( deadline='v0.16', @@ -936,7 +955,7 @@ class SimulationTrialResult(Generic[TSimulatorState]): measurement gate.) """ - @deprecated_step_result_parameter() + @_deprecated_step_result_parameter() def __init__( self, params: 'cirq.ParamResolver', diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 749856742de..c2d07b0fe46 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -386,7 +386,7 @@ class SimulationTrialResultBase( ): """A base class for trial results.""" - @simulator.deprecated_step_result_parameter(old_position=3) + @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: study.ParamResolver, diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index c62d95e4fcb..47e3170fc24 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -394,7 +394,7 @@ def _has_unitary_(self): def test_deprecated_final_step_result(): class OldCountingSimulator(CountingSimulator): - def _create_simulator_trial_result( + def _create_simulator_trial_result( # type: ignore self, params: cirq.ParamResolver, measurements: Dict[str, np.ndarray], diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 5071aed20b3..26afb2a8e28 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -528,18 +528,31 @@ class FakeMissingIterSimulatorImpl( def test_trial_result_initializer(): + resolver = cirq.ParamResolver() + step = mock.Mock(cirq.StepResultBase) + step._simulator_state.return_value = 1 + state = 3 with pytest.raises(ValueError, match='Exactly one of'): - _ = SimulationTrialResult(cirq.ParamResolver(), {}, None, None) + _ = SimulationTrialResult(resolver, {}, None, None) with pytest.raises(ValueError, match='Exactly one of'): - _ = SimulationTrialResult(cirq.ParamResolver(), {}, object(), mock.Mock(TStepResult)) + _ = SimulationTrialResult(resolver, {}, state, step) with pytest.raises(ValueError, match='Exactly one of'): - _ = SimulationTrialResult( - cirq.ParamResolver(), {}, final_simulator_state=None, final_step_result=None - ) + _ = SimulationTrialResult(resolver, {}, final_simulator_state=None, final_step_result=None) with pytest.raises(ValueError, match='Exactly one of'): - _ = SimulationTrialResult( - cirq.ParamResolver(), - {}, - final_simulator_state=object(), - final_step_result=mock.Mock(TStepResult), - ) + _ = SimulationTrialResult(resolver, {}, final_simulator_state=state, final_step_result=step) + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, final_step_result=step) + assert x._final_simulator_state == 1 + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, None, final_step_result=step) + assert x._final_simulator_state == 1 + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, None, step) + assert x._final_simulator_state == 1 + with cirq.testing.assert_deprecated(deadline='v0.16'): + x = SimulationTrialResult(resolver, {}, final_simulator_state=None, final_step_result=step) + assert x._final_simulator_state == 1 + x = SimulationTrialResult(resolver, {}, state) + assert x._final_simulator_state == 3 + x = SimulationTrialResult(resolver, {}, final_simulator_state=state) + assert x._final_simulator_state == 3