diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 6bf9f7d7886..eb776e5fd4f 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -452,12 +452,13 @@ ActOnStabilizerCHFormArgs, ActOnStabilizerArgs, ActOnStateVectorArgs, - StabilizerStateChForm, CIRCUIT_LIKE, CliffordSimulator, CliffordState, CliffordSimulatorStepResult, + CliffordTableauSimulationState, CliffordTrialResult, + DensityMatrixSimulationState, DensityMatrixSimulator, DensityMatrixSimulatorState, DensityMatrixStepResult, @@ -477,13 +478,20 @@ SimulatesIntermediateState, SimulatesIntermediateStateVector, SimulatesSamples, + SimulationProductState, + SimulationState, + SimulationStateBase, SimulationTrialResult, SimulationTrialResultBase, Simulator, SimulatorBase, SparseSimulatorStep, + StabilizerChFormSimulationState, StabilizerSampler, + StabilizerSimulationState, + StabilizerStateChForm, StateVectorMixin, + StateVectorSimulationState, StateVectorSimulatorState, StateVectorStepResult, StateVectorTrialResult, diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 48e46976762..76b7a6492e0 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -350,17 +350,17 @@ def mapped_op(self, deep: bool = False) -> 'cirq.CircuitOperation': def _decompose_(self) -> Iterator['cirq.Operation']: return self.mapped_circuit(deep=False).all_operations() - def _act_on_(self, args: 'cirq.OperationTarget') -> bool: + def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool: if self.repeat_until: circuit = self._mapped_single_loop() while True: for op in circuit.all_operations(): - protocols.act_on(op, args) - if self.repeat_until.resolve(args.classical_data): + protocols.act_on(op, sim_state) + if self.repeat_until.resolve(sim_state.classical_data): break else: for op in self._decompose_(): - protocols.act_on(op, args) + protocols.act_on(op, sim_state) return True # Methods for string representation of the operation. diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index b8ea3f1f751..3335bcb7794 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -27,7 +27,7 @@ from cirq import devices, protocols, qis, value from cirq._compat import deprecated_parameter from cirq.sim import simulator, simulator_base -from cirq.sim.act_on_args import ActOnArgs +from cirq.sim.simulation_state import SimulationState if TYPE_CHECKING: import cirq @@ -115,14 +115,14 @@ def _create_partial_act_on_args( classical_data=classical_data, ) - def _create_step_result(self, sim_state: 'cirq.OperationTarget[MPSState]'): + def _create_step_result(self, sim_state: 'cirq.SimulationStateBase[MPSState]'): return MPSSimulatorStepResult(sim_state) def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[MPSState]', + final_simulator_state: 'cirq.SimulationStateBase[MPSState]', ) -> 'MPSTrialResult': """Creates a single trial results with the measurements. @@ -148,7 +148,7 @@ def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[MPSState]', + final_simulator_state: 'cirq.SimulationStateBase[MPSState]', ) -> None: super().__init__( params=params, measurements=measurements, final_simulator_state=final_simulator_state @@ -175,10 +175,10 @@ def _repr_pretty_(self, p: Any, cycle: bool): class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState']): """A `StepResult` that can perform measurements.""" - def __init__(self, sim_state: 'cirq.OperationTarget[MPSState]'): + def __init__(self, sim_state: 'cirq.SimulationStateBase[MPSState]'): """Results of a step of the simulator. Attributes: - sim_state: The qubit:ActOnArgs lookup for this step. + sim_state: The qubit:SimulationState lookup for this step. """ super().__init__(sim_state) @@ -560,7 +560,7 @@ def sample( @value.value_equality -class MPSState(ActOnArgs[_MPSHandler]): +class MPSState(SimulationState[_MPSHandler]): """A state of the MPS simulation.""" @deprecated_parameter( diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index dfd241f6d27..0d20cf9047e 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -183,7 +183,7 @@ def test_cnot_flipped(): ) -def test_act_on_args(): +def test_simulation_state(): q0, q1 = qubit_order = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.CNOT(q1, q0)) mps_simulator = ccq.mps_simulator.MPSSimulator() @@ -537,7 +537,7 @@ def test_state_copy(): assert not np.shares_memory(x[i], y[i]) -def test_state_act_on_args_initializer(): +def test_simulation_state_initializer(): s = ccq.mps_simulator.MPSState( qubits=(cirq.LineQubit(0),), prng=np.random.RandomState(0), diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 28d9c22a003..ae067f6e42f 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -171,9 +171,9 @@ def _circuit_diagram_info_( def _json_dict_(self) -> Dict[str, Any]: return {'conditions': self._conditions, 'sub_operation': self._sub_operation} - def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - if all(c.resolve(args.classical_data) for c in self._conditions): - protocols.act_on(self._sub_operation, args) + def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool: + if all(c.resolve(sim_state.classical_data) for c in self._conditions): + protocols.act_on(self._sub_operation, sim_state) return True def _with_measurement_key_mapping_( diff --git a/cirq-core/cirq/ops/clifford_gate.py b/cirq-core/cirq/ops/clifford_gate.py index 06cfcb9afc2..78b0dade8f2 100644 --- a/cirq-core/cirq/ops/clifford_gate.py +++ b/cirq-core/cirq/ops/clifford_gate.py @@ -277,7 +277,9 @@ def _generate_clifford_from_known_gate( ) -> Union['SingleQubitCliffordGate', 'CliffordGate']: qubits = devices.LineQubit.range(num_qubits) t = qis.CliffordTableau(num_qubits=num_qubits) - args = sim.ActOnCliffordTableauArgs(tableau=t, qubits=qubits, prng=np.random.RandomState()) + args = sim.CliffordTableauSimulationState( + tableau=t, qubits=qubits, prng=np.random.RandomState() + ) protocols.act_on(gate, args, qubits, allow_decompose=False) if num_qubits == 1: @@ -339,7 +341,7 @@ def from_op_list( ) base_tableau = qis.CliffordTableau(len(qubit_order)) - args = sim.clifford.ActOnCliffordTableauArgs( + args = sim.clifford.CliffordTableauSimulationState( tableau=base_tableau, qubits=qubit_order, prng=np.random.RandomState(0) # unused ) for op in operations: @@ -444,7 +446,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': ) def _act_on_( - self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid'] + self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid'] ) -> Union[NotImplementedType, bool]: # Note the computation complexity difference between _decompose_ and _act_on_. @@ -453,15 +455,15 @@ def _act_on_( # 1. Direct act_on is O(n^3) -- two matrices multiplication # 2. Decomposition is O(m^3)+O(k*n^2) -- Decomposition complexity + k * One/two-qubits Ops # So when m << n, the decomposition is more efficient. - if isinstance(args, sim.clifford.ActOnCliffordTableauArgs): - axes = args.get_axes(qubits) + if isinstance(sim_state, sim.clifford.CliffordTableauSimulationState): + axes = sim_state.get_axes(qubits) # This padding is important and cannot be omitted. - padded_tableau = _pad_tableau(self._clifford_tableau, len(args.qubits), axes) - args._state = args.tableau.then(padded_tableau) + padded_tableau = _pad_tableau(self._clifford_tableau, len(sim_state.qubits), axes) + sim_state._state = sim_state.tableau.then(padded_tableau) return True - if isinstance(args, sim.clifford.ActOnStabilizerCHFormArgs): - # Do we know how to apply CliffordTableau on ActOnStabilizerCHFormArgs? + if isinstance(sim_state, sim.clifford.StabilizerChFormSimulationState): # coverage: ignore + # Do we know how to apply CliffordTableau on StabilizerChFormSimulationState? # It should be unlike because CliffordTableau ignores the global phase but CHForm # is aimed to fix that. return NotImplemented @@ -706,10 +708,10 @@ def __pow__(self, exponent) -> 'SingleQubitCliffordGate': def _act_on_( self, - args: 'cirq.OperationTarget', # pylint: disable=unused-argument + sim_state: 'cirq.SimulationStateBase', # pylint: disable=unused-argument qubits: Sequence['cirq.Qid'], # pylint: disable=unused-argument ): - # TODO(#5256) Add the implementation of _act_on_ with ActOnCliffordTableauArgs. + # TODO(#5256) Add the implementation of _act_on_ with CliffordTableauSimulationState. return NotImplemented # Single Clifford Gate decomposition is more efficient than the general Tableau decomposition. diff --git a/cirq-core/cirq/ops/clifford_gate_test.py b/cirq-core/cirq/ops/clifford_gate_test.py index faf3e7c5a27..adc5c6f43fa 100644 --- a/cirq-core/cirq/ops/clifford_gate_test.py +++ b/cirq-core/cirq/ops/clifford_gate_test.py @@ -19,7 +19,7 @@ import pytest import cirq -from cirq.protocols.act_on_protocol_test import DummyActOnArgs +from cirq.protocols.act_on_protocol_test import DummySimulationState from cirq.testing import EqualsTester, assert_allclose_up_to_global_phase _bools = (False, True) @@ -783,10 +783,10 @@ def test_clifford_gate_act_on_small_case(): # Note this is also covered by the `from_op_list` one, etc. qubits = cirq.LineQubit.range(5) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=5), qubits=qubits, prng=np.random.RandomState() ) - expected_args = cirq.ActOnCliffordTableauArgs( + expected_args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=5), qubits=qubits, prng=np.random.RandomState() ) cirq.act_on(cirq.H, expected_args, qubits=[qubits[0]], allow_decompose=False) @@ -818,8 +818,8 @@ def test_clifford_gate_act_on_large_case(): t1 = cirq.CliffordTableau(num_qubits=n) t2 = cirq.CliffordTableau(num_qubits=n) qubits = cirq.LineQubit.range(n) - args1 = cirq.ActOnCliffordTableauArgs(tableau=t1, qubits=qubits, prng=prng) - args2 = cirq.ActOnCliffordTableauArgs(tableau=t2, qubits=qubits, prng=prng) + args1 = cirq.CliffordTableauSimulationState(tableau=t1, qubits=qubits, prng=prng) + args2 = cirq.CliffordTableauSimulationState(tableau=t2, qubits=qubits, prng=prng) ops = [] for _ in range(num_ops): g = prng.randint(len(gate_candidate)) @@ -838,7 +838,7 @@ def test_clifford_gate_act_on_ch_form(): # Although we don't support CH_form from the _act_on_, it will fall back # to the decomposititon method and apply it through decomposed ops. # Here we run it for the coverage only. - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( initial_state=cirq.StabilizerStateChForm(num_qubits=2, initial_state=1), qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), @@ -849,4 +849,4 @@ def test_clifford_gate_act_on_ch_form(): def test_clifford_gate_act_on_fail(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.CliffordGate.X, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.CliffordGate.X, DummySimulationState(), qubits=()) diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index dda34bfb96f..b8f15f53860 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -713,7 +713,7 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio def _qid_shape_(self): return (self._dimension,) - def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']): + def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid']): if len(qubits) != 1: return NotImplemented @@ -734,12 +734,15 @@ def _unitary_(self): u[:inc] = np.eye(self.dimension)[-inc:] return u - from cirq.sim import act_on_args + from cirq.sim import simulation_state - if isinstance(args, act_on_args.ActOnArgs) and not args.can_represent_mixed_states: - result = args._perform_measurement(qubits)[0] + if ( + isinstance(sim_state, simulation_state.SimulationState) + and not sim_state.can_represent_mixed_states + ): + result = sim_state._perform_measurement(qubits)[0] gate = PlusGate(self.dimension, self.dimension - result) - protocols.act_on(gate, args, qubits) + protocols.act_on(gate, sim_state, qubits) return True return NotImplemented diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index 0f87a522843..806d37a426b 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -18,7 +18,7 @@ import pytest import cirq -from cirq.protocols.act_on_protocol_test import DummyActOnArgs +from cirq.protocols.act_on_protocol_test import DummySimulationState X = np.array([[0, 1], [1, 0]]) Y = np.array([[0, -1j], [1j, 0]]) @@ -489,9 +489,9 @@ def test_reset_channel_text_diagram(): def test_reset_act_on(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.ResetChannel(), DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.ResetChannel(), DummySimulationState(), qubits=()) - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -714,7 +714,7 @@ def test_bit_flip_channel_text_diagram(): def test_stabilizer_supports_depolarize(): with pytest.raises(TypeError, match="act_on"): for _ in range(100): - cirq.act_on(cirq.depolarize(3 / 4), DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.depolarize(3 / 4), DummySimulationState(), qubits=()) q = cirq.LineQubit(0) c = cirq.Circuit(cirq.depolarize(3 / 4).on(q), cirq.measure(q, key='m')) diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index d2fcb6f5de9..4b6ca57179c 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -17,7 +17,7 @@ import sympy import cirq -from cirq.protocols.act_on_protocol_test import DummyActOnArgs +from cirq.protocols.act_on_protocol_test import DummySimulationState H = np.array([[1, 1], [1, -1]]) * np.sqrt(0.5) HH = cirq.kron(H, H) @@ -289,11 +289,11 @@ def test_h_str(): def test_x_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.X, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.X, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -338,11 +338,11 @@ def _unitary_(self): def test_y_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.Y, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.Y, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -376,13 +376,13 @@ def test_y_act_on_tableau(): def test_z_h_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.Z, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.Z, DummySimulationState(), qubits=()) with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.H, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.H, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) flipped_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=23) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -429,10 +429,10 @@ def test_z_h_act_on_tableau(): def test_cx_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.CX, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.CX, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -473,10 +473,10 @@ def test_cx_act_on_tableau(): def test_cz_act_on_tableau(): with pytest.raises(TypeError, match="Failed to act"): - cirq.act_on(cirq.CZ, DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.CZ, DummySimulationState(), qubits=()) original_tableau = cirq.CliffordTableau(num_qubits=5, initial_state=31) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -516,12 +516,12 @@ def test_cz_act_on_tableau(): def test_cz_act_on_equivalent_to_h_cx_h_tableau(): - args1 = cirq.ActOnCliffordTableauArgs( + args1 = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=2), qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), ) - args2 = cirq.ActOnCliffordTableauArgs( + args2 = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=2), qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), @@ -583,7 +583,7 @@ def test_act_on_ch_form(input_gate_sequence, outcome): else: assert num_qubits == 2 qubits = cirq.LineQubit.range(2) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(2), prng=np.random.RandomState(), initial_state=original_state.copy(), diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index f928a29a419..f23a1a0b30e 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -254,10 +254,10 @@ def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']] return getter() return NotImplemented - def _act_on_(self, args: 'cirq.OperationTarget'): + def _act_on_(self, sim_state: 'cirq.SimulationStateBase'): getter = getattr(self.gate, '_act_on_', None) if getter is not None: - return getter(args, self.qubits) + return getter(sim_state, self.qubits) return NotImplemented def _is_parameterized_(self) -> bool: diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index 5c55b5354e1..b8c9c5a0202 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -44,7 +44,7 @@ def test_protocols(): @pytest.mark.parametrize('phase', [1, 1j, -1]) def test_act_on_tableau(phase): original_tableau = cirq.CliffordTableau(0) - args = cirq.ActOnCliffordTableauArgs(original_tableau.copy(), np.random.RandomState()) + args = cirq.CliffordTableauSimulationState(original_tableau.copy(), np.random.RandomState()) cirq.act_on(cirq.global_phase_operation(phase), args, allow_decompose=False) assert args.tableau == original_tableau @@ -52,7 +52,7 @@ def test_act_on_tableau(phase): @pytest.mark.parametrize('phase', [1, 1j, -1]) def test_act_on_ch_form(phase): state = cirq.StabilizerStateChForm(0) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=[], prng=np.random.RandomState(), initial_state=state ) cirq.act_on(cirq.global_phase_operation(phase), args, allow_decompose=False) @@ -241,7 +241,7 @@ def test_gate_protocols(): @pytest.mark.parametrize('phase', [1, 1j, -1]) def test_gate_act_on_tableau(phase): original_tableau = cirq.CliffordTableau(0) - args = cirq.ActOnCliffordTableauArgs(original_tableau.copy(), np.random.RandomState()) + args = cirq.CliffordTableauSimulationState(original_tableau.copy(), np.random.RandomState()) cirq.act_on(cirq.GlobalPhaseGate(phase), args, qubits=(), allow_decompose=False) assert args.tableau == original_tableau @@ -249,7 +249,7 @@ def test_gate_act_on_tableau(phase): @pytest.mark.parametrize('phase', [1, 1j, -1]) def test_gate_act_on_ch_form(phase): state = cirq.StabilizerStateChForm(0) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=[], prng=np.random.RandomState(), initial_state=state ) cirq.act_on(cirq.GlobalPhaseGate(phase), args, qubits=(), allow_decompose=False) diff --git a/cirq-core/cirq/ops/identity.py b/cirq-core/cirq/ops/identity.py index 435bdabc008..dfd3fded96e 100644 --- a/cirq-core/cirq/ops/identity.py +++ b/cirq-core/cirq/ops/identity.py @@ -61,7 +61,7 @@ def __init__( if len(self._qid_shape) != num_qubits: raise ValueError('len(qid_shape) != num_qubits') - def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']): + def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid']): return True def _qid_shape_(self) -> Tuple[int, ...]: diff --git a/cirq-core/cirq/ops/identity_test.py b/cirq-core/cirq/ops/identity_test.py index c782cd39228..a3ce014b293 100644 --- a/cirq-core/cirq/ops/identity_test.py +++ b/cirq-core/cirq/ops/identity_test.py @@ -205,6 +205,6 @@ def with_qubits(self, *new_qubits): def test_identity_short_circuits_act_on(): - args = mock.Mock(cirq.ActOnArgs) + args = mock.Mock(cirq.SimulationState) args._act_on_fallback_.side_effect = mock.Mock(side_effect=Exception('No!')) cirq.act_on(cirq.IdentityGate(1)(cirq.LineQubit(0)), args) diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index 66f16080ea5..8f065886f7a 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -260,12 +260,12 @@ def _from_json_dict_(cls, num_qubits, key, invert_mask, qid_shape=None, **kwargs def _has_stabilizer_effect_(self) -> Optional[bool]: return True - def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']) -> bool: - from cirq.sim import ActOnArgs + def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid']) -> bool: + from cirq.sim import SimulationState - if not isinstance(args, ActOnArgs): + if not isinstance(sim_state, SimulationState): return NotImplemented - args.measure(qubits, self.key, self.full_invert_mask()) + sim_state.measure(qubits, self.key, self.full_invert_mask()) return True diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index ae29c4f792d..91a31be530f 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -284,7 +284,7 @@ def test_act_on_state_vector(): a, b = [cirq.LineQubit(3), cirq.LineQubit(1)] m = cirq.measure(a, b, key='out', invert_mask=(True,)) - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -294,7 +294,7 @@ def test_act_on_state_vector(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [1, 0]} - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -306,7 +306,7 @@ def test_act_on_state_vector(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [1, 1]} - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -331,7 +331,7 @@ def test_act_on_clifford_tableau(): # The below assertion does not fail since it ignores non-unitary operations cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=5, initial_state=0), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -339,7 +339,7 @@ def test_act_on_clifford_tableau(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [1, 0]} - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=5, initial_state=8), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -348,7 +348,7 @@ def test_act_on_clifford_tableau(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [1, 1]} - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=5, initial_state=10), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), @@ -369,20 +369,20 @@ def test_act_on_stabilizer_ch_form(): # The below assertion does not fail since it ignores non-unitary operations cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), initial_state=0 ) cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [1, 0]} - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), initial_state=8 ) cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [1, 1]} - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), initial_state=10 ) cirq.act_on(m, args) @@ -399,7 +399,7 @@ def test_act_on_qutrit(): a, b = [cirq.LineQid(3, dimension=3), cirq.LineQid(1, dimension=3)] m = cirq.measure(a, b, key='out', invert_mask=(True,)) - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), qubits=cirq.LineQid.range(5, dimension=3), prng=np.random.RandomState(), @@ -411,7 +411,7 @@ def test_act_on_qutrit(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [2, 2]} - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), qubits=cirq.LineQid.range(5, dimension=3), prng=np.random.RandomState(), @@ -423,7 +423,7 @@ def test_act_on_qutrit(): cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [2, 1]} - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(shape=(3, 3, 3, 3, 3)), qubits=cirq.LineQid.range(5, dimension=3), prng=np.random.RandomState(), diff --git a/cirq-core/cirq/ops/random_gate_channel_test.py b/cirq-core/cirq/ops/random_gate_channel_test.py index 8e46fdcff99..57b776c262a 100644 --- a/cirq-core/cirq/ops/random_gate_channel_test.py +++ b/cirq-core/cirq/ops/random_gate_channel_test.py @@ -215,13 +215,13 @@ def test_stabilizer_supports_probability(): def test_unsupported_stabilizer_safety(): - from cirq.protocols.act_on_protocol_test import DummyActOnArgs + from cirq.protocols.act_on_protocol_test import DummySimulationState with pytest.raises(TypeError, match="act_on"): for _ in range(100): - cirq.act_on(cirq.X.with_probability(0.5), DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.X.with_probability(0.5), DummySimulationState(), qubits=()) with pytest.raises(TypeError, match="act_on"): - cirq.act_on(cirq.X.with_probability(sympy.Symbol('x')), DummyActOnArgs(), qubits=()) + cirq.act_on(cirq.X.with_probability(sympy.Symbol('x')), DummySimulationState(), qubits=()) q = cirq.LineQubit(0) c = cirq.Circuit((cirq.X(q) ** 0.25).with_probability(0.5), cirq.measure(q, key='m')) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 9b07c9661a3..bb394fc9c5f 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -797,10 +797,10 @@ def _is_parameterized_(self) -> bool: protocols.is_parameterized(tag) for tag in self.tags ) - def _act_on_(self, args: 'cirq.OperationTarget') -> bool: + def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool: sub = getattr(self.sub_operation, "_act_on_", None) if sub is not None: - return sub(args) + return sub(sim_state) return NotImplemented def _parameter_names_(self) -> AbstractSet[str]: diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 92e1f6ee65d..a040351d3cc 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -735,14 +735,14 @@ class YesActOn(cirq.Gate): def _num_qubits_(self) -> int: return 1 - def _act_on_(self, args, qubits): + def _act_on_(self, sim_state, qubits): return True class NoActOn(cirq.Gate): def _num_qubits_(self) -> int: return 1 - def _act_on_(self, args, qubits): + def _act_on_(self, sim_state, qubits): return NotImplemented class MissingActOn(cirq.Operation): @@ -754,9 +754,9 @@ def qubits(self): pass q = cirq.LineQubit(1) - from cirq.protocols.act_on_protocol_test import DummyActOnArgs + from cirq.protocols.act_on_protocol_test import DummySimulationState - args = DummyActOnArgs() + args = DummySimulationState() cirq.act_on(YesActOn()(q).with_tags("test"), args) with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(NoActOn()(q).with_tags("test"), args) diff --git a/cirq-core/cirq/protocols/act_on_protocol.py b/cirq-core/cirq/protocols/act_on_protocol.py index a38515b30ae..5080509ad66 100644 --- a/cirq-core/cirq/protocols/act_on_protocol.py +++ b/cirq-core/cirq/protocols/act_on_protocol.py @@ -28,11 +28,11 @@ class SupportsActOn(Protocol): """An object that explicitly specifies how to act on simulator states.""" @doc_private - def _act_on_(self, args: 'cirq.OperationTarget') -> Union[NotImplementedType, bool]: + def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> Union[NotImplementedType, bool]: """Applies an action to the given argument, if it is a supported type. For example, unitary operations can implement an `_act_on_` method that - checks if `isinstance(args, cirq.ActOnStateVectorArgs)` and, if so, + checks if `isinstance(args, cirq.StateVectorSimulationState)` and, if so, apply their unitary effect to the state vector. The global `cirq.act_on` method looks for whether or not the given @@ -43,8 +43,8 @@ def _act_on_(self, args: 'cirq.OperationTarget') -> Union[NotImplementedType, bo as gates should use `SupportsActOnQubits`. Args: - args: An object of unspecified type. The method must check if this - object is of a recognized type and act on it if so. + sim_state: An object of unspecified type. The method must check if + this object is of a recognized type and act on it if so. Returns: True: The receiving object (`self`) acted on the argument. @@ -59,12 +59,12 @@ class SupportsActOnQubits(Protocol): @doc_private def _act_on_( - self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid'] + self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid'] ) -> Union[NotImplementedType, bool]: """Applies an action to the given argument, if it is a supported type. For example, unitary operations can implement an `_act_on_` method that - checks if `isinstance(args, cirq.ActOnStateVectorArgs)` and, if so, + checks if `isinstance(args, cirq.StateVectorSimulationState)` and, if so, apply their unitary effect to the state vector. The global `cirq.act_on` method looks for whether or not the given @@ -74,8 +74,8 @@ def _act_on_( If implementing this on an `Operation`, use `SupportsActOn` instead. Args: - args: An object of unspecified type. The method must check if this - object is of a recognized type and act on it if so. + sim_state: An object of unspecified type. The method must check if + this object is of a recognized type and act on it if so. qubits: The sequence of qubits to use when applying the action. Returns: @@ -88,7 +88,7 @@ def _act_on_( def act_on( action: Any, - args: 'cirq.OperationTarget', + args: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid'] = None, *, allow_decompose: bool = True, @@ -97,7 +97,7 @@ def act_on( For example, the action may be a `cirq.Operation` and the state argument may represent the internal state of a state vector simulator (a - `cirq.ActOnStateVectorArgs`). + `cirq.StateVectorSimulationState`). For non-operations, the `qubits` argument must be explicitly supplied. diff --git a/cirq-core/cirq/protocols/act_on_protocol_test.py b/cirq-core/cirq/protocols/act_on_protocol_test.py index d3ba8113932..bc0cd4c6a4a 100644 --- a/cirq-core/cirq/protocols/act_on_protocol_test.py +++ b/cirq-core/cirq/protocols/act_on_protocol_test.py @@ -28,7 +28,7 @@ def measure(self, axes, seed=None): pass -class DummyActOnArgs(cirq.ActOnArgs): +class DummySimulationState(cirq.SimulationState): def __init__(self, fallback_result: Any = NotImplemented): super().__init__(prng=np.random.RandomState(), state=DummyQuantumState()) self.fallback_result = fallback_result @@ -43,18 +43,18 @@ def _act_on_fallback_( def test_act_on_fallback_succeeds(): - args = DummyActOnArgs(fallback_result=True) + args = DummySimulationState(fallback_result=True) cirq.act_on(op, args) def test_act_on_fallback_fails(): - args = DummyActOnArgs(fallback_result=NotImplemented) + args = DummySimulationState(fallback_result=NotImplemented) with pytest.raises(TypeError, match='Failed to act'): cirq.act_on(op, args) def test_act_on_fallback_errors(): - args = DummyActOnArgs(fallback_result=False) + args = DummySimulationState(fallback_result=False) with pytest.raises(ValueError, match='_act_on_fallback_ must return True or NotImplemented'): cirq.act_on(op, args) @@ -68,10 +68,10 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: pass - def _act_on_(self, args): + def _act_on_(self, sim_state): return False - args = DummyActOnArgs(fallback_result=True) + args = DummySimulationState(fallback_result=True) with pytest.raises(ValueError, match='_act_on_ must return True or NotImplemented'): cirq.act_on(Op(), args) @@ -85,7 +85,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: pass - args = DummyActOnArgs() + args = DummySimulationState() with pytest.raises( ValueError, match='Calls to act_on should not supply qubits if the action is an Operation' ): @@ -93,6 +93,6 @@ def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf: def test_qubits_should_be_defined_for_operations(): - args = DummyActOnArgs() + args = DummySimulationState() with pytest.raises(ValueError, match='Calls to act_on should'): cirq.act_on(cirq.KrausChannel([np.array([[1, 0], [0, 0]])]), args, qubits=None) diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index a133ca3e8f1..434a12c5e6c 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -91,7 +91,15 @@ 'ApplyChannelArgs', 'ApplyMixtureArgs', 'ApplyUnitaryArgs', + 'CliffordTableauSimulationState', + 'DensityMatrixSimulationState', 'OperationTarget', + 'SimulationProductState', + 'SimulationState', + 'SimulationStateBase', + 'StabilizerChFormSimulationState', + 'StabilizerSimulationState', + 'StateVectorSimulationState', # Abstract base class for creating compilation targets. 'CompilationTargetGateset', 'TwoQubitCompilationTargetGateset', diff --git a/cirq-core/cirq/sim/__init__.py b/cirq-core/cirq/sim/__init__.py index 307047c8d30..116ccb7c237 100644 --- a/cirq-core/cirq/sim/__init__.py +++ b/cirq-core/cirq/sim/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. """Classes for circuit simulators and base implementations of these classes.""" -from typing import Tuple, Dict from cirq.sim.act_on_args import ActOnArgs @@ -23,7 +22,22 @@ from cirq.sim.act_on_state_vector_args import ActOnStateVectorArgs -from cirq.sim.density_matrix_utils import measure_density_matrix, sample_density_matrix +from cirq.sim.clifford import ( + ActOnCliffordTableauArgs, + ActOnStabilizerArgs, + ActOnStabilizerCHFormArgs, + CliffordSimulator, + CliffordSimulatorStepResult, + CliffordState, + CliffordTrialResult, + CliffordTableauSimulationState, + StabilizerChFormSimulationState, + StabilizerSampler, + StabilizerSimulationState, + StabilizerStateChForm, +) + +from cirq.sim.density_matrix_simulation_state import DensityMatrixSimulationState from cirq.sim.density_matrix_simulator import ( DensityMatrixSimulator, @@ -32,7 +46,7 @@ DensityMatrixTrialResult, ) -from cirq.sim.operation_target import OperationTarget +from cirq.sim.density_matrix_utils import measure_density_matrix, sample_density_matrix from cirq.sim.mux import ( CIRCUIT_LIKE, @@ -42,6 +56,14 @@ sample_sweep, ) +from cirq.sim.operation_target import OperationTarget + +from cirq.sim.simulation_product_state import SimulationProductState + +from cirq.sim.simulation_state import SimulationState + +from cirq.sim.simulation_state_base import SimulationStateBase + from cirq.sim.simulator import ( SimulatesAmplitudes, SimulatesExpectationValues, @@ -56,23 +78,13 @@ from cirq.sim.sparse_simulator import Simulator, SparseSimulatorStep +from cirq.sim.state_vector import measure_state_vector, sample_state_vector, StateVectorMixin + +from cirq.sim.state_vector_simulation_state import StateVectorSimulationState + from cirq.sim.state_vector_simulator import ( SimulatesIntermediateStateVector, StateVectorSimulatorState, StateVectorStepResult, StateVectorTrialResult, ) - -from cirq.sim.state_vector import measure_state_vector, sample_state_vector, StateVectorMixin - -from cirq.sim.clifford import ( - ActOnCliffordTableauArgs, - ActOnStabilizerCHFormArgs, - ActOnStabilizerArgs, - StabilizerSampler, - StabilizerStateChForm, - CliffordSimulator, - CliffordState, - CliffordTrialResult, - CliffordSimulatorStepResult, -) diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 6fc0b679833..b1f8668bfba 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,381 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Objects and methods for acting efficiently on a state tensor.""" -import abc -import copy -from typing import ( - Any, - cast, - Dict, - Generic, - Iterator, - List, - Optional, - Sequence, - TypeVar, - TYPE_CHECKING, - Tuple, -) -import numpy as np +from cirq import _compat +from cirq.sim.simulation_state import SimulationState -from cirq import protocols, value -from cirq._compat import _warn_or_error, deprecated, deprecated_parameter -from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits -from cirq.sim.operation_target import OperationTarget -TSelf = TypeVar('TSelf', bound='ActOnArgs') -TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation') - -if TYPE_CHECKING: - import cirq - - -class ActOnArgs(OperationTarget, Generic[TState], metaclass=abc.ABCMeta): - """State and context for an operation acting on a state tensor.""" - - @deprecated_parameter( - deadline='v0.16', - fix='Use kwargs instead of positional args', - parameter_desc='args', - match=lambda args, kwargs: len(args) > 1, - ) - @deprecated_parameter( - deadline='v0.16', - fix='Replace log_of_measurement_results with' - ' classical_data=cirq.ClassicalDataDictionaryStore(_records=logs).', - parameter_desc='log_of_measurement_results', - match=lambda args, kwargs: 'log_of_measurement_results' in kwargs, - ) - def __init__( - self, - prng: Optional[np.random.RandomState] = None, - qubits: Optional[Sequence['cirq.Qid']] = None, - log_of_measurement_results: Optional[Dict[str, List[int]]] = None, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - state: Optional[TState] = None, - ): - """Inits ActOnArgs. - - Args: - prng: The pseudo random number generator to use for probabilistic - effects. - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - log_of_measurement_results: A mutable object that measurements are - being recorded into. - classical_data: The shared classical data container for this - simulation. - state: The underlying quantum state of the simulation. - """ - if qubits is None: - qubits = () - classical_data = classical_data or value.ClassicalDataDictionaryStore( - _records={ - value.MeasurementKey.parse_serialized(k): [tuple(v)] - for k, v in (log_of_measurement_results or {}).items() - } - ) - super().__init__(qubits=qubits, classical_data=classical_data) - if prng is None: - prng = cast(np.random.RandomState, np.random) - self._prng = prng - self._state = cast(TState, state) - if state is None: - _warn_or_error('This function will require a valid `state` input in cirq v0.16.') - - @property - def prng(self) -> np.random.RandomState: - return self._prng - - def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]): - """Measures the qubits and records to `log_of_measurement_results`. - - Any bitmasks will be applied to the measurement record. - - Args: - qubits: The qubits to measure. - key: The key the measurement result should be logged under. Note - that operations should only store results under keys they have - declared in a `_measurement_key_names_` method. - invert_mask: The invert mask for the measurement. - - Raises: - ValueError: If a measurement key has already been logged to a key. - """ - bits = self._perform_measurement(qubits) - corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] - self._classical_data.record_measurement( - value.MeasurementKey.parse_serialized(key), corrected, qubits - ) - - def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: - return [self.qubit_map[q] for q in qubits] - - def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: - """Delegates the call to measure the density matrix.""" - if self._state is not None: - return self._state.measure(self.get_axes(qubits), self.prng) - raise NotImplementedError() - - def sample( - self, - qubits: Sequence['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - if self._state is not None: - return self._state.sample(self.get_axes(qubits), repetitions, seed) - raise NotImplementedError() - - def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: - """Creates a copy of the object. - - Args: - deep_copy_buffers: If True, buffers will also be deep-copied. - Otherwise the copy will share a reference to the original object's - buffers. - - Returns: - A copied instance. - """ - args = copy.copy(self) - args._classical_data = self._classical_data.copy() - if self._state is not None: - args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers) - else: - _warn_or_error( - 'Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor. The `_on_`' - ' overrides will be removed in cirq v0.16.' - ) - self._on_copy(args, deep_copy_buffers) - return args - - @deprecated( - deadline='v0.16', - fix='Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor.', - ) - def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True): - """Subclasses should implement this with any additional state copy - functionality.""" - - def create_merged_state(self: TSelf) -> TSelf: - """Creates a final merged state.""" - return self - - def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf: - """Joins two state spaces together.""" - args = self if inplace else copy.copy(self) - if self._state is not None and other._state is not None: - args._state = self._state.kron(other._state) - else: - _warn_or_error( - 'Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor. The `_on_`' - ' overrides will be removed in cirq v0.16.' - ) - self._on_kronecker_product(other, args) - args._set_qubits(self.qubits + other.qubits) - return args - - @deprecated( - deadline='v0.16', - fix='Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor.', - ) - def _on_kronecker_product(self: TSelf, other: TSelf, target: TSelf): - """Subclasses should implement this with any additional state product - functionality, if supported.""" - - def with_qubits(self: TSelf, qubits) -> TSelf: - """Extend current state space with added qubits. - - The state of the added qubits is the default value set in the - subclasses. A new state space is created as the Kronecker product of - the original one and the added one. - - Args: - qubits: The qubits to be added to the state space. - - Regurns: - A new subclass object containing the extended state space. - """ - new_space = type(self)(qubits=qubits) - return self.kronecker_product(new_space) - - def factor( - self: TSelf, qubits: Sequence['cirq.Qid'], *, validate=True, atol=1e-07, inplace=False - ) -> Tuple[TSelf, TSelf]: - """Splits two state spaces after a measurement or reset.""" - extracted = copy.copy(self) - remainder = self if inplace else copy.copy(self) - if self._state is not None: - e, r = self._state.factor(self.get_axes(qubits), validate=validate, atol=atol) - extracted._state = e - remainder._state = r - else: - _warn_or_error( - 'Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor. The `_on_`' - ' overrides will be removed in cirq v0.16.' - ) - self._on_factor(qubits, extracted, remainder, validate, atol) - extracted._set_qubits(qubits) - remainder._set_qubits([q for q in self.qubits if q not in qubits]) - return extracted, remainder - - @property - def allows_factoring(self): - """Subclasses that allow factorization should override this.""" - return self._state.supports_factor if self._state is not None else False - - @deprecated( - deadline='v0.16', - fix='Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor.', - ) - def _on_factor( - self: TSelf, - qubits: Sequence['cirq.Qid'], - extracted: TSelf, - remainder: TSelf, - validate=True, - atol=1e-07, - ): - """Subclasses should implement this with any additional state factor - functionality, if supported.""" - - def transpose_to_qubit_order( - self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False - ) -> TSelf: - """Physically reindexes the state by the new basis. - - Args: - qubits: The desired qubit order. - inplace: True to perform this operation inplace. - - Returns: - The state with qubit order transposed and underlying representation - updated. - - Raises: - ValueError: If the provided qubits do not match the existing ones. - """ - if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits): - raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}') - args = self if inplace else copy.copy(self) - if self._state is not None: - args._state = self._state.reindex(self.get_axes(qubits)) - else: - _warn_or_error( - 'Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor. The `_on_`' - ' overrides will be removed in cirq v0.16.' - ) - self._on_transpose_to_qubit_order(qubits, args) - args._set_qubits(qubits) - return args - - @deprecated( - deadline='v0.16', - fix='Pass a `QuantumStateRepresentation` into the `ActOnArgs` constructor.', - ) - def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], target: TSelf): - """Subclasses should implement this with any additional state transpose - functionality, if supported.""" - - @property # type: ignore - @deprecated(deadline='v0.16', fix='Remove this call, it always returns False.') - def ignore_measurement_results(self) -> bool: - return False - - @property - def qubits(self) -> Tuple['cirq.Qid', ...]: - return self._qubits - - def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): - """Swaps two qubits. - - This only affects the index, and does not modify the underlying - state. - - Args: - q1: The first qubit to swap. - q2: The second qubit to swap. - inplace: True to swap the qubits in the current object, False to - create a copy with the qubits swapped. - - Returns: - The original object with the qubits swapped if inplace is - requested, or a copy of the original object with the qubits swapped - otherwise. - - Raises: - ValueError: If the qubits are of different dimensionality. - """ - if q1.dimension != q2.dimension: - raise ValueError(f'Cannot swap different dimensions: q1={q1}, q2={q2}') - - args = self if inplace else copy.copy(self) - i1 = self.qubits.index(q1) - i2 = self.qubits.index(q2) - qubits = list(args.qubits) - qubits[i1], qubits[i2] = qubits[i2], qubits[i1] - args._set_qubits(qubits) - return args - - def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): - """Renames `q1` to `q2`. - - Args: - q1: The qubit to rename. - q2: The new name. - inplace: True to rename the qubit in the current object, False to - create a copy with the qubit renamed. - - Returns: - The original object with the qubits renamed if inplace is - requested, or a copy of the original object with the qubits renamed - otherwise. - - Raises: - ValueError: If the qubits are of different dimensionality. - """ - if q1.dimension != q2.dimension: - raise ValueError(f'Cannot rename to different dimensions: q1={q1}, q2={q2}') - - args = self if inplace else copy.copy(self) - i1 = self.qubits.index(q1) - qubits = list(args.qubits) - qubits[i1] = q2 - args._set_qubits(qubits) - return args - - def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf: - if item not in self.qubit_map: - raise IndexError(f'{item} not in {self.qubits}') - return self - - def __len__(self) -> int: - return len(self.qubits) - - def __iter__(self) -> Iterator[Optional['cirq.Qid']]: - return iter(self.qubits) - - @property - def can_represent_mixed_states(self) -> bool: - return self._state.can_represent_mixed_states if self._state is not None else False - - -def strat_act_on_from_apply_decompose( - val: Any, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid'] -) -> bool: - operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val) - assert len(qubits1) == len(qubits) - qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)} - if operations is None: - return NotImplemented - for operation in operations: - operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits]) - protocols.act_on(operation, args) - return True - - -TActOnArgs = TypeVar('TActOnArgs', bound=ActOnArgs) +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.SimulationState instead.') +class ActOnArgs(SimulationState): + pass diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 0b4e0856b64..ca5453faca1 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,160 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import abc -from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, Sequence, TYPE_CHECKING +from cirq import _compat +from cirq.sim.simulation_product_state import SimulationProductState -import numpy as np -from cirq import ops, protocols, value -from cirq.sim.act_on_args import TActOnArgs -from cirq.sim.operation_target import OperationTarget - -if TYPE_CHECKING: - import cirq - - -class ActOnArgsContainer(Generic[TActOnArgs], OperationTarget[TActOnArgs], abc.Mapping): - """A container for a `Qid`-to-`ActOnArgs` dictionary.""" - - def __init__( - self, - args: Dict[Optional['cirq.Qid'], TActOnArgs], - qubits: Sequence['cirq.Qid'], - split_untangled_states: bool, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Initializes the class. - - Args: - args: The `ActOnArgs` dictionary. This will not be copied; the - original reference will be kept here. - qubits: The canonical ordering of qubits. - split_untangled_states: If True, optimizes operations by running - unentangled qubit sets independently and merging those states - at the end. - classical_data: The shared classical data container for this - simulation. - """ - classical_data = classical_data or value.ClassicalDataDictionaryStore() - super().__init__(qubits=qubits, classical_data=classical_data) - self._args = args - self._split_untangled_states = split_untangled_states - - @property - def args(self) -> Mapping[Optional['cirq.Qid'], TActOnArgs]: - return self._args - - @property - def split_untangled_states(self) -> bool: - return self._split_untangled_states - - def create_merged_state(self) -> TActOnArgs: - if not self.split_untangled_states: - return self.args[None] - final_args = self.args[None] - for args in set([self.args[k] for k in self.args.keys() if k is not None]): - final_args = final_args.kronecker_product(args) - return final_args.transpose_to_qubit_order(self.qubits) - - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> bool: - gate_opt = ( - action - if isinstance(action, ops.Gate) - else action.gate - if isinstance(action, ops.Operation) - else None - ) - - if isinstance(gate_opt, ops.IdentityGate): - return True - - if ( - isinstance(gate_opt, ops.SwapPowGate) - and gate_opt.exponent % 2 == 1 - and gate_opt.global_shift == 0 - ): - q0, q1 = qubits - args0 = self.args[q0] - args1 = self.args[q1] - if args0 is args1: - args0.swap(q0, q1, inplace=True) - else: - self._args[q0] = args1.rename(q1, q0, inplace=True) - self._args[q1] = args0.rename(q0, q1, inplace=True) - return True - - # Go through the op's qubits and join any disparate ActOnArgs states - # into a new combined state. - op_args_opt: Optional[TActOnArgs] = None - for q in qubits: - if op_args_opt is None: - op_args_opt = self.args[q] - elif q not in op_args_opt.qubits: - op_args_opt = op_args_opt.kronecker_product(self.args[q]) - op_args = op_args_opt or self.args[None] - - # (Backfill the args map with the new value) - for q in op_args.qubits: - self._args[q] = op_args - - # Act on the args with the operation - act_on_qubits = qubits if isinstance(action, ops.Gate) else None - protocols.act_on(action, op_args, act_on_qubits, allow_decompose=allow_decompose) - - # Decouple any measurements or resets - if self.split_untangled_states and isinstance( - gate_opt, (ops.ResetChannel, ops.MeasurementGate) - ): - for q in qubits: - if op_args.allows_factoring: - q_args, op_args = op_args.factor((q,), validate=False) - self._args[q] = q_args - - # (Backfill the args map with the new value) - for q in op_args.qubits: - self._args[q] = op_args - return True - - def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]': - classical_data = self._classical_data.copy() - copies = {} - for act_on_args in set(self.args.values()): - copies[act_on_args] = act_on_args.copy(deep_copy_buffers) - for copy in copies.values(): - copy._classical_data = classical_data - args = {q: copies[a] for q, a in self.args.items()} - return ActOnArgsContainer( - args, self.qubits, self.split_untangled_states, classical_data=classical_data - ) - - def sample( - self, - qubits: List['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - columns = [] - selected_order: List[ops.Qid] = [] - q_set = set(qubits) - for v in dict.fromkeys(self.args.values()): - qs = [q for q in v.qubits if q in q_set] - if any(qs): - column = v.sample(qs, repetitions, seed) - columns.append(column) - selected_order += qs - stacked = np.column_stack(columns) - qubit_map = {q: i for i, q in enumerate(selected_order)} - index_order = [qubit_map[q] for q in qubits] - return stacked[:, index_order] - - def __getitem__(self, item: Optional['cirq.Qid']) -> TActOnArgs: - return self.args[item] - - def __len__(self) -> int: - return len(self.args) - - def __iter__(self) -> Iterator[Optional['cirq.Qid']]: - return iter(self.args) +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.SimulationProductState instead.') +class ActOnArgsContainer(SimulationProductState): + pass diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index c596cbe5631..b7c840b996d 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,329 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Objects and methods for acting efficiently on a density matrix.""" -from typing import Any, Callable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union +from cirq import _compat +from cirq.sim.density_matrix_simulation_state import DensityMatrixSimulationState -import numpy as np -from cirq import protocols, qis, sim -from cirq._compat import proper_repr -from cirq.linalg import transformations -from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose - -if TYPE_CHECKING: - import cirq - from numpy.typing import DTypeLike - - -class _BufferedDensityMatrix(qis.QuantumStateRepresentation): - """Contains the density matrix and buffers for efficient state evolution.""" - - def __init__(self, density_matrix: np.ndarray, buffer: Optional[List[np.ndarray]] = None): - """Initializes the object with the inputs. - - This initializer creates the buffer if necessary. - - Args: - density_matrix: The density matrix, must be correctly formatted. The data is not - checked for validity here due to performance concerns. - buffer: Optional, must be length 3 and same shape as the density matrix. If not - provided, a buffer will be created automatically. - Raises: - ValueError: If the array is not the shape of a density matrix. - """ - self._density_matrix = density_matrix - if buffer is None: - buffer = [np.empty_like(density_matrix) for _ in range(3)] - self._buffer = buffer - if len(density_matrix.shape) % 2 != 0: - raise ValueError('The dimension of target_tensor is not divisible by 2.') - self._qid_shape = density_matrix.shape[: len(density_matrix.shape) // 2] - - @classmethod - def create( - cls, - *, - initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, - qid_shape: Optional[Tuple[int, ...]] = None, - dtype: Optional['DTypeLike'] = None, - buffer: Optional[List[np.ndarray]] = None, - ): - """Creates a buffered density matrix with the requested state. - - Args: - initial_state: The initial state for the simulation in the computational basis. - qid_shape: The shape of the density matrix, if the initial state is provided as an int. - dtype: The desired dtype of the density matrix. - buffer: Optional, must be length 3 and same shape as the density matrix. If not - provided, a buffer will be created automatically. - Raises: - ValueError: If initial state is provided as integer, but qid_shape is not provided. - """ - if not isinstance(initial_state, np.ndarray): - if qid_shape is None: - raise ValueError('qid_shape must be provided if initial_state is not ndarray') - density_matrix = qis.to_valid_density_matrix( - initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype - ).reshape(qid_shape * 2) - else: - if qid_shape is not None: - if dtype and initial_state.dtype != dtype: - initial_state = initial_state.astype(dtype) - density_matrix = qis.to_valid_density_matrix( - initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype - ).reshape(qid_shape * 2) - else: - density_matrix = initial_state - if np.may_share_memory(density_matrix, initial_state): - density_matrix = density_matrix.copy() - density_matrix = density_matrix.astype(dtype, copy=False) - return cls(density_matrix, buffer) - - def copy(self, deep_copy_buffers: bool = True) -> '_BufferedDensityMatrix': - """Copies the object. - - Args: - deep_copy_buffers: True by default, False to reuse the existing buffers. - Returns: - A copy of the object. - """ - return _BufferedDensityMatrix( - density_matrix=self._density_matrix.copy(), - buffer=[b.copy() for b in self._buffer] if deep_copy_buffers else self._buffer, - ) - - def kron(self, other: '_BufferedDensityMatrix') -> '_BufferedDensityMatrix': - """Creates the Kronecker product with the other density matrix. - - Args: - other: The density matrix with which to kron. - Returns: - The Kronecker product of the two density matrices. - """ - density_matrix = transformations.density_matrix_kronecker_product( - self._density_matrix, other._density_matrix - ) - return _BufferedDensityMatrix(density_matrix=density_matrix) - - def factor( - self, axes: Sequence[int], *, validate=True, atol=1e-07 - ) -> Tuple['_BufferedDensityMatrix', '_BufferedDensityMatrix']: - """Factors out the desired axes. - - Args: - axes: The axes to factor out. Only the left axes should be provided. For example, to - extract [C,A] from density matrix of shape [A,B,C,D,A,B,C,D], `axes` should be - [2,0], and the return value will be two density matrices ([C,A,C,A], [B,D,B,D]). - validate: Perform a validation that the density matrix factors cleanly. - atol: The absolute tolerance for the validation. - Returns: - A tuple with the `(extracted, remainder)` density matrices, where `extracted` means - the sub-matrix which corresponds to the axes requested, and with the axes in the - requested order, and where `remainder` means the sub-matrix on the remaining axes, - in the same order as the original density matrix. - """ - extracted_tensor, remainder_tensor = transformations.factor_density_matrix( - self._density_matrix, axes, validate=validate, atol=atol - ) - extracted = _BufferedDensityMatrix(density_matrix=extracted_tensor) - remainder = _BufferedDensityMatrix(density_matrix=remainder_tensor) - return extracted, remainder - - def reindex(self, axes: Sequence[int]) -> '_BufferedDensityMatrix': - """Transposes the axes of a density matrix to a specified order. - - Args: - axes: The desired axis order. Only the left axes should be provided. For example, to - transpose [A,B,C,A,B,C] to [C,B,A,C,B,A], `axes` should be [2,1,0]. - Returns: - The transposed density matrix. - """ - new_tensor = transformations.transpose_density_matrix_to_axis_order( - self._density_matrix, axes - ) - return _BufferedDensityMatrix(density_matrix=new_tensor) - - def apply_channel(self, action: Any, axes: Sequence[int]) -> bool: - """Apply channel to state. - - Args: - action: The value with a channel to apply. - axes: The axes on which to apply the channel. - Returns: - True if the action succeeded. - """ - result = protocols.apply_channel( - action, - args=protocols.ApplyChannelArgs( - target_tensor=self._density_matrix, - out_buffer=self._buffer[0], - auxiliary_buffer0=self._buffer[1], - auxiliary_buffer1=self._buffer[2], - left_axes=axes, - right_axes=[e + len(self._qid_shape) for e in axes], - ), - default=None, - ) - if result is None: - return False - for i in range(len(self._buffer)): - if result is self._buffer[i]: - self._buffer[i] = self._density_matrix - self._density_matrix = result - return True - - def measure( - self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None - ) -> List[int]: - """Measures the density matrix. - - Args: - axes: The axes to measure. - seed: The random number seed to use. - Returns: - The measurements in order. - """ - bits, _ = sim.measure_density_matrix( - self._density_matrix, - axes, - out=self._density_matrix, - qid_shape=self._qid_shape, - seed=seed, - ) - return bits - - def sample( - self, - axes: Sequence[int], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - """Samples the density matrix. - - Args: - axes: The axes to sample. - repetitions: The number of samples to make. - seed: The random number seed to use. - Returns: - The samples in order. - """ - return sim.sample_density_matrix( - self._density_matrix, - axes, - qid_shape=self._qid_shape, - repetitions=repetitions, - seed=seed, - ) - - @property - def supports_factor(self) -> bool: - return True - - @property - def can_represent_mixed_states(self) -> bool: - return True - - -class ActOnDensityMatrixArgs(ActOnArgs[_BufferedDensityMatrix]): - """State and context for an operation acting on a density matrix. - - To act on this object, directly edit the `target_tensor` property, which is - storing the density matrix of the quantum system with one axis per qubit. - """ - - def __init__( - self, - *, - available_buffer: Optional[List[np.ndarray]] = None, - qid_shape: Optional[Tuple[int, ...]] = None, - prng: Optional[np.random.RandomState] = None, - qubits: Optional[Sequence['cirq.Qid']] = None, - initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, - dtype: Type[np.number] = np.complex64, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Inits ActOnDensityMatrixArgs. - - Args: - available_buffer: A workspace with the same shape and dtype as - `target_tensor`. Used by operations that cannot be applied to - `target_tensor` inline, in order to avoid unnecessary - allocations. - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - qid_shape: The shape of the target tensor. - prng: The pseudo random number generator to use for probabilistic - effects. - initial_state: The initial state for the simulation in the - computational basis. - dtype: The `numpy.dtype` of the inferred state vector. One of - `numpy.complex64` or `numpy.complex128`. Only used when - `target_tenson` is None. - classical_data: The shared classical data container for this - simulation. - - Raises: - ValueError: The dimension of `target_tensor` is not divisible by 2 - and `qid_shape` is not provided. - """ - state = _BufferedDensityMatrix.create( - initial_state=initial_state, - qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, - dtype=dtype, - buffer=available_buffer, - ) - super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> bool: - strats: List[Callable[[Any, Any, Sequence['cirq.Qid']], bool]] = [ - _strat_apply_channel_to_state - ] - if allow_decompose: - strats.append(strat_act_on_from_apply_decompose) - - # Try each strategy, stopping if one works. - for strat in strats: - result = strat(action, self, qubits) - if result is False: - break # coverage: ignore - if result is True: - return True - assert result is NotImplemented, str(result) - raise TypeError( - "Can't simulate operations that don't implement " - "SupportsUnitary, SupportsConsistentApplyUnitary, " - "SupportsMixture or SupportsKraus or is a measurement: {!r}".format(action) - ) - - def __repr__(self) -> str: - return ( - 'cirq.ActOnDensityMatrixArgs(' - f'initial_state={proper_repr(self.target_tensor)},' - f' qid_shape={self.qid_shape!r},' - f' qubits={self.qubits!r},' - f' classical_data={self.classical_data!r})' - ) - - @property - def target_tensor(self): - return self._state._density_matrix - - @property - def available_buffer(self): - return self._state._buffer - - @property - def qid_shape(self): - return self._state._qid_shape - - -def _strat_apply_channel_to_state( - action: Any, args: 'cirq.ActOnDensityMatrixArgs', qubits: Sequence['cirq.Qid'] -) -> bool: - """Apply channel to state.""" - return True if args._state.apply_channel(action, args.get_axes(qubits)) else NotImplemented +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.DensityMatrixSimulationState instead.') +class ActOnDensityMatrixArgs(DensityMatrixSimulationState): + pass diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index 93c258de2ba..1aaf1c39c81 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,491 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Objects and methods for acting efficiently on a state vector.""" -from typing import Any, Callable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union +from cirq import _compat +from cirq.sim.state_vector_simulation_state import StateVectorSimulationState -import numpy as np -from cirq import _compat, linalg, protocols, qis, sim -from cirq._compat import proper_repr -from cirq.linalg import transformations -from cirq.sim.act_on_args import ActOnArgs, strat_act_on_from_apply_decompose - -if TYPE_CHECKING: - import cirq - from numpy.typing import DTypeLike - - -class _BufferedStateVector(qis.QuantumStateRepresentation): - """Contains the state vector and buffer for efficient state evolution.""" - - def __init__(self, state_vector: np.ndarray, buffer: Optional[np.ndarray] = None): - """Initializes the object with the inputs. - - This initializer creates the buffer if necessary. - - Args: - state_vector: The state vector, must be correctly formatted. The data is not checked - for validity here due to performance concerns. - buffer: Optional, must be same shape as the state vector. If not provided, a buffer - will be created automatically. - """ - self._state_vector = state_vector - if buffer is None: - buffer = np.empty_like(state_vector) - self._buffer = buffer - self._qid_shape = state_vector.shape - - @classmethod - def create( - cls, - *, - initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, - qid_shape: Optional[Tuple[int, ...]] = None, - dtype: Optional['DTypeLike'] = None, - buffer: Optional[List[np.ndarray]] = None, - ): - """Initializes the object with the inputs. - - This initializer creates the buffer if necessary. - - Args: - initial_state: The density matrix, must be correctly formatted. The data is not - checked for validity here due to performance concerns. - qid_shape: The shape of the density matrix, if the initial state is provided as an int. - dtype: The dtype of the density matrix, if the initial state is provided as an int. - buffer: Optional, must be length 3 and same shape as the density matrix. If not - provided, a buffer will be created automatically. - Raises: - ValueError: If initial state is provided as integer, but qid_shape is not provided. - """ - if not isinstance(initial_state, np.ndarray): - if qid_shape is None: - raise ValueError('qid_shape must be provided if initial_state is not ndarray') - state_vector = qis.to_valid_state_vector( - initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype - ).reshape(qid_shape) - else: - if qid_shape is not None: - state_vector = initial_state.reshape(qid_shape) - else: - state_vector = initial_state - if np.may_share_memory(state_vector, initial_state): - state_vector = state_vector.copy() - state_vector = state_vector.astype(dtype, copy=False) - return cls(state_vector, buffer) - - def copy(self, deep_copy_buffers: bool = True) -> '_BufferedStateVector': - """Copies the object. - - Args: - deep_copy_buffers: True by default, False to reuse the existing buffers. - Returns: - A copy of the object. - """ - return _BufferedStateVector( - state_vector=self._state_vector.copy(), - buffer=self._buffer.copy() if deep_copy_buffers else self._buffer, - ) - - def kron(self, other: '_BufferedStateVector') -> '_BufferedStateVector': - """Creates the Kronecker product with the other state vector. - - Args: - other: The state vector with which to kron. - Returns: - The Kronecker product of the two state vectors. - """ - target_tensor = transformations.state_vector_kronecker_product( - self._state_vector, other._state_vector - ) - return _BufferedStateVector(state_vector=target_tensor, buffer=np.empty_like(target_tensor)) - - def factor( - self, axes: Sequence[int], *, validate=True, atol=1e-07 - ) -> Tuple['_BufferedStateVector', '_BufferedStateVector']: - """Factors a state vector into two independent state vectors. - - This function should only be called on state vectors that are known to be separable, such - as immediately after a measurement or reset operation. It does not verify that the provided - state vector is indeed separable, and will return nonsense results for vectors - representing entangled states. - - Args: - axes: The axes to factor out. - validate: Perform a validation that the state vector factors cleanly. - atol: The absolute tolerance for the validation. - - Returns: - A tuple with the `(extracted, remainder)` state vectors, where `extracted` means the - sub-state vector which corresponds to the axes requested, and with the axes in the - requested order, and where `remainder` means the sub-state vector on the remaining - axes, in the same order as the original state vector. - """ - extracted_tensor, remainder_tensor = transformations.factor_state_vector( - self._state_vector, axes, validate=validate, atol=atol - ) - extracted = _BufferedStateVector( - state_vector=extracted_tensor, buffer=np.empty_like(extracted_tensor) - ) - remainder = _BufferedStateVector( - state_vector=remainder_tensor, buffer=np.empty_like(remainder_tensor) - ) - return extracted, remainder - - def reindex(self, axes: Sequence[int]) -> '_BufferedStateVector': - """Transposes the axes of a state vector to a specified order. - - Args: - axes: The desired axis order. - Returns: - The transposed state vector. - """ - new_tensor = transformations.transpose_state_vector_to_axis_order(self._state_vector, axes) - return _BufferedStateVector(state_vector=new_tensor, buffer=np.empty_like(new_tensor)) - - def apply_unitary(self, action: Any, axes: Sequence[int]) -> bool: - """Apply unitary to state. - - Args: - action: The value with a unitary to apply. - axes: The axes on which to apply the unitary. - Returns: - True if the operation succeeded. - """ - new_target_tensor = protocols.apply_unitary( - action, - protocols.ApplyUnitaryArgs( - target_tensor=self._state_vector, available_buffer=self._buffer, axes=axes - ), - allow_decompose=False, - default=NotImplemented, - ) - if new_target_tensor is NotImplemented: - return False - self._swap_target_tensor_for(new_target_tensor) - return True - - def apply_mixture(self, action: Any, axes: Sequence[int], prng) -> Optional[int]: - """Apply mixture to state. - - Args: - action: The value with a mixture to apply. - axes: The axes on which to apply the mixture. - prng: The pseudo random number generator to use. - Returns: - The mixture index if the operation succeeded, otherwise None. - """ - mixture = protocols.mixture(action, default=None) - if mixture is None: - return None - probabilities, unitaries = zip(*mixture) - - index = prng.choice(range(len(unitaries)), p=probabilities) - shape = protocols.qid_shape(action) * 2 - unitary = unitaries[index].astype(self._state_vector.dtype).reshape(shape) - linalg.targeted_left_multiply(unitary, self._state_vector, axes, out=self._buffer) - self._swap_target_tensor_for(self._buffer) - return index - - def apply_channel(self, action: Any, axes: Sequence[int], prng) -> Optional[int]: - """Apply channel to state. - - Args: - action: The value with a channel to apply. - axes: The axes on which to apply the channel. - prng: The pseudo random number generator to use. - Returns: - The kraus index if the operation succeeded, otherwise None. - """ - kraus_operators = protocols.kraus(action, default=None) - if kraus_operators is None: - return None - - def prepare_into_buffer(k: int): - linalg.targeted_left_multiply( - left_matrix=kraus_tensors[k], - right_target=self._state_vector, - target_axes=axes, - out=self._buffer, - ) - - shape = protocols.qid_shape(action) - kraus_tensors = [ - e.reshape(shape * 2).astype(self._state_vector.dtype) for e in kraus_operators - ] - p = prng.random() - weight = None - fallback_weight = 0 - fallback_weight_index = 0 - index = None - for index in range(len(kraus_tensors)): - prepare_into_buffer(index) - weight = np.linalg.norm(self._buffer) ** 2 - - if weight > fallback_weight: - fallback_weight_index = index - fallback_weight = weight - - p -= weight - if p < 0: - break - - assert weight is not None, "No Kraus operators" - if p >= 0 or weight == 0: - # Floating point error resulted in a malformed sample. - # Fall back to the most likely case. - prepare_into_buffer(fallback_weight_index) - weight = fallback_weight - index = fallback_weight_index - - self._buffer /= np.sqrt(weight) - self._swap_target_tensor_for(self._buffer) - return index - - def measure( - self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None - ) -> List[int]: - """Measures the state vector. - - Args: - axes: The axes to measure. - seed: The random number seed to use. - Returns: - The measurements in order. - """ - bits, _ = sim.measure_state_vector( - self._state_vector, axes, out=self._state_vector, qid_shape=self._qid_shape, seed=seed - ) - return bits - - def sample( - self, - axes: Sequence[int], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - """Samples the state vector. - - Args: - axes: The axes to sample. - repetitions: The number of samples to make. - seed: The random number seed to use. - Returns: - The samples in order. - """ - return sim.sample_state_vector( - self._state_vector, axes, qid_shape=self._qid_shape, repetitions=repetitions, seed=seed - ) - - def _swap_target_tensor_for(self, new_target_tensor: np.ndarray): - """Gives a new state vector for the system. - - Typically, the new state vector should be `args.available_buffer` where - `args` is this `cirq.ActOnStateVectorArgs` instance. - - Args: - new_target_tensor: The new system state. Must have the same shape - and dtype as the old system state. - """ - if new_target_tensor is self._buffer: - self._buffer = self._state_vector - self._state_vector = new_target_tensor - - @property - def supports_factor(self) -> bool: - return True - - -class ActOnStateVectorArgs(ActOnArgs[_BufferedStateVector]): - """State and context for an operation acting on a state vector. - - There are two common ways to act on this object: - - 1. Directly edit the `target_tensor` property, which is storing the state - vector of the quantum system as a numpy array with one axis per qudit. - 2. Overwrite the `available_buffer` property with the new state vector, and - then pass `available_buffer` into `swap_target_tensor_for`. - """ - - def __init__( - self, - *, - available_buffer: Optional[np.ndarray] = None, - prng: Optional[np.random.RandomState] = None, - qubits: Optional[Sequence['cirq.Qid']] = None, - initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, - dtype: Type[np.number] = np.complex64, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Inits ActOnStateVectorArgs. - - Args: - available_buffer: A workspace with the same shape and dtype as - `target_tensor`. Used by operations that cannot be applied to - `target_tensor` inline, in order to avoid unnecessary - allocations. Passing `available_buffer` into - `swap_target_tensor_for` will swap it for `target_tensor`. - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - prng: The pseudo random number generator to use for probabilistic - effects. - initial_state: The initial state for the simulation in the - computational basis. - dtype: The `numpy.dtype` of the inferred state vector. One of - `numpy.complex64` or `numpy.complex128`. Only used when - `target_tenson` is None. - classical_data: The shared classical data container for this - simulation. - """ - state = _BufferedStateVector.create( - initial_state=initial_state, - qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, - dtype=dtype, - buffer=available_buffer, - ) - super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - - @_compat.deprecated( - deadline='v0.16', fix='None, this function was unintentionally made public.' - ) - def swap_target_tensor_for(self, new_target_tensor: np.ndarray): - """Gives a new state vector for the system. - - Typically, the new state vector should be `args.available_buffer` where - `args` is this `cirq.ActOnStateVectorArgs` instance. - - Args: - new_target_tensor: The new system state. Must have the same shape - and dtype as the old system state. - """ - self._state._swap_target_tensor_for(new_target_tensor) - - @_compat.deprecated( - deadline='v0.16', fix='None, this function was unintentionally made public.' - ) - def subspace_index( - self, axes: Sequence[int], little_endian_bits_int: int = 0, *, big_endian_bits_int: int = 0 - ) -> Tuple[Union[slice, int, 'ellipsis'], ...]: - """An index for the subspace where the target axes equal a value. - - Args: - axes: The qubits that are specified by the index bits. - little_endian_bits_int: The desired value of the qubits at the - targeted `axes`, packed into an integer. The least significant - bit of the integer is the desired bit for the first axis, and - so forth in increasing order. Can't be specified at the same - time as `big_endian_bits_int`. - - When operating on qudits instead of qubits, the same basic logic - applies but in a different basis. For example, if the target - axes have dimension [a:2, b:3, c:2] then the integer 10 - decomposes into [a=0, b=2, c=1] via 7 = 1*(3*2) + 2*(2) + 0. - big_endian_bits_int: The desired value of the qubits at the - targeted `axes`, packed into an integer. The most significant - bit of the integer is the desired bit for the first axis, and - so forth in decreasing order. Can't be specified at the same - time as `little_endian_bits_int`. - - When operating on qudits instead of qubits, the same basic logic - applies but in a different basis. For example, if the target - axes have dimension [a:2, b:3, c:2] then the integer 10 - decomposes into [a=1, b=2, c=0] via 7 = 1*(3*2) + 2*(2) + 0. - - Returns: - A value that can be used to index into `target_tensor` and - `available_buffer`, and manipulate only the part of Hilbert space - corresponding to a given bit assignment. - - Example: - If `target_tensor` is a 4 qubit tensor and `axes` is `[1, 3]` and - then this method will return the following when given - `little_endian_bits=0b01`: - - `(slice(None), 0, slice(None), 1, Ellipsis)` - - Therefore the following two lines would be equivalent: - - args.target_tensor[args.subspace_index(0b01)] += 1 - - args.target_tensor[:, 0, :, 1] += 1 - """ - return linalg.slice_for_qubits_equal_to( - axes, - little_endian_qureg_value=little_endian_bits_int, - big_endian_qureg_value=big_endian_bits_int, - qid_shape=self.target_tensor.shape, - ) - - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> bool: - strats: List[Callable[[Any, Any, Sequence['cirq.Qid']], bool]] = [ - _strat_act_on_state_vector_from_apply_unitary, - _strat_act_on_state_vector_from_mixture, - _strat_act_on_state_vector_from_channel, - ] - if allow_decompose: - strats.append(strat_act_on_from_apply_decompose) # type: ignore - - # Try each strategy, stopping if one works. - for strat in strats: - result = strat(action, self, qubits) - if result is False: - break # coverage: ignore - if result is True: - return True - assert result is NotImplemented, str(result) - raise TypeError( - "Can't simulate operations that don't implement " - "SupportsUnitary, SupportsConsistentApplyUnitary, " - "SupportsMixture or is a measurement: {!r}".format(action) - ) - - def __repr__(self) -> str: - return ( - 'cirq.ActOnStateVectorArgs(' - f'initial_state={proper_repr(self.target_tensor)},' - f' qubits={self.qubits!r},' - f' classical_data={self.classical_data!r})' - ) - - @property - def target_tensor(self): - return self._state._state_vector - - @property - def available_buffer(self): - return self._state._buffer - - -def _strat_act_on_state_vector_from_apply_unitary( - action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] -) -> bool: - return True if args._state.apply_unitary(action, args.get_axes(qubits)) else NotImplemented - - -def _strat_act_on_state_vector_from_mixture( - action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] -) -> bool: - index = args._state.apply_mixture(action, args.get_axes(qubits), args.prng) - if index is None: - return NotImplemented - if protocols.is_measurement(action): - key = protocols.measurement_key_name(action) - args._classical_data.record_channel_measurement(key, index) - return True - - -def _strat_act_on_state_vector_from_channel( - action: Any, args: 'cirq.ActOnStateVectorArgs', qubits: Sequence['cirq.Qid'] -) -> bool: - index = args._state.apply_channel(action, args.get_axes(qubits), args.prng) - if index is None: - return NotImplemented - if protocols.is_measurement(action): - key = protocols.measurement_key_name(action) - args._classical_data.record_channel_measurement(key, index) - return True +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.StateVectorSimulationState instead.') +class ActOnStateVectorArgs(StateVectorSimulationState): + pass diff --git a/cirq-core/cirq/sim/clifford/__init__.py b/cirq-core/cirq/sim/clifford/__init__.py index d466ca209b5..1de27e171e6 100644 --- a/cirq-core/cirq/sim/clifford/__init__.py +++ b/cirq-core/cirq/sim/clifford/__init__.py @@ -1,17 +1,24 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice -from cirq.sim.clifford.act_on_clifford_tableau_args import ActOnCliffordTableauArgs -from cirq.sim.clifford.act_on_stabilizer_ch_form_args import ActOnStabilizerCHFormArgs +from cirq.sim.clifford.act_on_clifford_tableau_args import ActOnCliffordTableauArgs from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs -from cirq.sim.clifford.stabilizer_state_ch_form import StabilizerStateChForm +from cirq.sim.clifford.act_on_stabilizer_ch_form_args import ActOnStabilizerCHFormArgs from cirq.sim.clifford.clifford_simulator import ( CliffordSimulator, + CliffordSimulatorStepResult, CliffordState, CliffordTrialResult, - CliffordSimulatorStepResult, ) +from cirq.sim.clifford.clifford_tableau_simulation_state import CliffordTableauSimulationState + +from cirq.sim.clifford.stabilizer_ch_form_simulation_state import StabilizerChFormSimulationState + from cirq.sim.clifford.stabilizer_sampler import StabilizerSampler + +from cirq.sim.clifford.stabilizer_simulation_state import StabilizerSimulationState + +from cirq.sim.clifford.stabilizer_state_ch_form import StabilizerStateChForm diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 0008f859839..4c0c9ea60ae 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -1,4 +1,4 @@ -# Copyright 2018 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,45 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A protocol for implementing high performance clifford tableau evolutions - for Clifford Simulator.""" -from typing import Optional, Sequence, TYPE_CHECKING +from cirq import _compat +from cirq.sim.clifford.clifford_tableau_simulation_state import CliffordTableauSimulationState -import numpy as np -from cirq.qis import clifford_tableau -from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs - -if TYPE_CHECKING: - import cirq - - -class ActOnCliffordTableauArgs(ActOnStabilizerArgs[clifford_tableau.CliffordTableau]): - """State and context for an operation acting on a clifford tableau.""" - - def __init__( - self, - tableau: 'cirq.CliffordTableau', - prng: Optional[np.random.RandomState] = None, - qubits: Optional[Sequence['cirq.Qid']] = None, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Inits ActOnCliffordTableauArgs. - - Args: - tableau: The CliffordTableau to act on. Operations are expected to - perform inplace edits of this object. - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - prng: The pseudo random number generator to use for probabilistic - effects. - classical_data: The shared classical data container for this - simulation. - """ - super().__init__(state=tableau, prng=prng, qubits=qubits, classical_data=classical_data) - - @property - def tableau(self) -> 'cirq.CliffordTableau': - return self.state +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.CliffordTableauSimulationState instead.') +class ActOnCliffordTableauArgs(CliffordTableauSimulationState): + pass diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py index d95b072eac1..63758b5a982 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_args.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,175 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -from typing import Any, Dict, Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union +from cirq import _compat +from cirq.sim.clifford.stabilizer_simulation_state import StabilizerSimulationState -import numpy as np -from cirq import linalg, ops, protocols -from cirq._compat import deprecated_parameter -from cirq.ops import common_gates, global_phase_op, matrix_gates, swap_gates -from cirq.ops.clifford_gate import SingleQubitCliffordGate -from cirq.protocols import has_unitary, num_qubits, unitary -from cirq.sim.act_on_args import ActOnArgs -from cirq.type_workarounds import NotImplementedType - -if TYPE_CHECKING: - import cirq - - -TStabilizerState = TypeVar('TStabilizerState', bound='cirq.StabilizerState') - - -class ActOnStabilizerArgs( - ActOnArgs[TStabilizerState], Generic[TStabilizerState], metaclass=abc.ABCMeta -): - """Abstract wrapper around a stabilizer state for the act_on protocol.""" - - @deprecated_parameter( - deadline='v0.16', - fix='Use kwargs instead of positional args', - parameter_desc='args', - match=lambda args, kwargs: len(args) > 1, - ) - @deprecated_parameter( - deadline='v0.16', - fix='Replace log_of_measurement_results with' - ' classical_data=cirq.ClassicalDataDictionaryStore(_records=logs).', - parameter_desc='log_of_measurement_results', - match=lambda args, kwargs: 'log_of_measurement_results' in kwargs, - ) - def __init__( - self, - state: TStabilizerState, - prng: Optional[np.random.RandomState] = None, - log_of_measurement_results: Optional[Dict[str, List[int]]] = None, - qubits: Optional[Sequence['cirq.Qid']] = None, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Initializes the ActOnStabilizerArgs. - - Args: - state: The quantum stabilizer state to use in the simulation or - act_on invocation. - prng: The pseudo random number generator to use for probabilistic - effects. - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - log_of_measurement_results: A mutable object that measurements are - being recorded into. - classical_data: The shared classical data container for this - simulation. - """ - if log_of_measurement_results is not None: - super().__init__( - state=state, - prng=prng, - qubits=qubits, - log_of_measurement_results=log_of_measurement_results, - classical_data=classical_data, - ) - else: - super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - - @property - def state(self) -> TStabilizerState: - return self._state - - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> Union[bool, NotImplementedType]: - strats = [self._strat_apply_gate, self._strat_apply_mixture] - if allow_decompose: - strats.append(self._strat_decompose) - strats.append(self._strat_act_from_single_qubit_decompose) - for strat in strats: - result = strat(action, qubits) # type: ignore - if result is True: - return True - assert result is NotImplemented, str(result) - - return NotImplemented - - def _swap( - self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): - """Apply a SWAP gate""" - if exponent % 1 != 0: - raise ValueError('Swap exponent must be integer') # coverage: ignore - self._state.apply_cx(control_axis, target_axis) - self._state.apply_cx(target_axis, control_axis, exponent, global_shift) - self._state.apply_cx(control_axis, target_axis) - - def _strat_apply_gate(self, val: Any, qubits: Sequence['cirq.Qid']) -> bool: - if not protocols.has_stabilizer_effect(val): - return NotImplemented - gate = val.gate if isinstance(val, ops.Operation) else val - axes = self.get_axes(qubits) - if isinstance(gate, common_gates.XPowGate): - self._state.apply_x(axes[0], gate.exponent, gate.global_shift) - elif isinstance(gate, common_gates.YPowGate): - self._state.apply_y(axes[0], gate.exponent, gate.global_shift) - elif isinstance(gate, common_gates.ZPowGate): - self._state.apply_z(axes[0], gate.exponent, gate.global_shift) - elif isinstance(gate, common_gates.HPowGate): - self._state.apply_h(axes[0], gate.exponent, gate.global_shift) - elif isinstance(gate, common_gates.CXPowGate): - self._state.apply_cx(axes[0], axes[1], gate.exponent, gate.global_shift) - elif isinstance(gate, common_gates.CZPowGate): - self._state.apply_cz(axes[0], axes[1], gate.exponent, gate.global_shift) - elif isinstance(gate, global_phase_op.GlobalPhaseGate): - self._state.apply_global_phase(gate.coefficient) - elif isinstance(gate, swap_gates.SwapPowGate): - self._swap(axes[0], axes[1], gate.exponent, gate.global_shift) - else: - return NotImplemented - return True - - def _strat_apply_mixture(self, val: Any, qubits: Sequence['cirq.Qid']) -> bool: - mixture = protocols.mixture(val, None) - if mixture is None: - return NotImplemented - if not all(linalg.is_unitary(m) for _, m in mixture): - return NotImplemented - probabilities, unitaries = zip(*mixture) - index = self.prng.choice(len(unitaries), p=probabilities) - return self._strat_act_from_single_qubit_decompose( - matrix_gates.MatrixGate(unitaries[index]), qubits - ) - - def _strat_act_from_single_qubit_decompose( - self, val: Any, qubits: Sequence['cirq.Qid'] - ) -> bool: - if num_qubits(val) == 1: - if not has_unitary(val): - return NotImplemented - u = unitary(val) - clifford_gate = SingleQubitCliffordGate.from_unitary(u) - if clifford_gate is not None: - # Gather the effective unitary applied so as to correct for the - # global phase later. - final_unitary = np.eye(2) - for axis, quarter_turns in clifford_gate.decompose_rotation(): - gate = axis ** (quarter_turns / 2) - self._strat_apply_gate(gate, qubits) - final_unitary = np.matmul(unitary(gate), final_unitary) - - # Find the entry with the largest magnitude in the input unitary. - k = max(np.ndindex(*u.shape), key=lambda t: abs(u[t])) - # Correct the global phase that wasn't conserved in the above - # decomposition. - self._state.apply_global_phase(u[k] / final_unitary[k]) - return True - - return NotImplemented - - def _strat_decompose(self, val: Any, qubits: Sequence['cirq.Qid']) -> bool: - gate = val.gate if isinstance(val, ops.Operation) else val - operations = protocols.decompose_once_with_qubits(gate, qubits, None) - if operations is None or not all(protocols.has_stabilizer_effect(op) for op in operations): - return NotImplemented - for op in operations: - protocols.act_on(op, self) - return True +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.StabilizerSimulationState instead.') +class ActOnStabilizerArgs(StabilizerSimulationState): + pass diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 66069f7ea4f..b04b18083ec 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -1,4 +1,4 @@ -# Copyright 2020 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,65 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, TYPE_CHECKING, Union +from cirq import _compat +from cirq.sim.clifford.stabilizer_ch_form_simulation_state import StabilizerChFormSimulationState -import numpy as np -from cirq._compat import proper_repr -from cirq.sim.clifford import stabilizer_state_ch_form -from cirq.sim.clifford.act_on_stabilizer_args import ActOnStabilizerArgs - -if TYPE_CHECKING: - import cirq - - -class ActOnStabilizerCHFormArgs( - ActOnStabilizerArgs[stabilizer_state_ch_form.StabilizerStateChForm] -): - """Wrapper around a stabilizer state in CH form for the act_on protocol.""" - - def __init__( - self, - *, - prng: Optional[np.random.RandomState] = None, - qubits: Optional[Sequence['cirq.Qid']] = None, - initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0, - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Initializes with the given state and the axes for the operation. - - Args: - qubits: Determines the canonical ordering of the qubits. This - is often used in specifying the initial state, i.e. the - ordering of the computational basis states. - prng: The pseudo random number generator to use for probabilistic - effects. - initial_state: The initial state for the simulation. This can be a - full CH form passed by reference which will be modified inplace, - or a big-endian int in the computational basis. If the state is - an integer, qubits must be provided in order to determine - array sizes. - classical_data: The shared classical data container for this - simulation. - - Raises: - ValueError: If initial state is an integer but qubits are not - provided. - """ - if isinstance(initial_state, int): - if qubits is None: - raise ValueError('Must specify qubits if initial state is integer') - initial_state = stabilizer_state_ch_form.StabilizerStateChForm( - len(qubits), initial_state - ) - super().__init__( - state=initial_state, prng=prng, qubits=qubits, classical_data=classical_data - ) - - def __repr__(self) -> str: - return ( - 'cirq.ActOnStabilizerCHFormArgs(' - f'initial_state={proper_repr(self.state)},' - f' qubits={self.qubits!r},' - f' classical_data={self.classical_data!r})' - ) +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.StabilizerChFormSimulationState instead.') +class ActOnStabilizerCHFormArgs(StabilizerChFormSimulationState): + pass diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 1919222d618..a2ab97d0f4e 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -43,7 +43,7 @@ class CliffordSimulator( simulator_base.SimulatorBase[ 'cirq.CliffordSimulatorStepResult', 'cirq.CliffordTrialResult', - 'cirq.ActOnStabilizerCHFormArgs', + 'cirq.StabilizerChFormSimulationState', ] ): """An efficient simulator for Clifford circuits.""" @@ -69,10 +69,10 @@ def is_supported_operation(op: 'cirq.Operation') -> bool: def _create_partial_act_on_args( self, - initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], + initial_state: Union[int, 'cirq.StabilizerChFormSimulationState'], qubits: Sequence['cirq.Qid'], classical_data: 'cirq.ClassicalDataStore', - ) -> 'cirq.ActOnStabilizerCHFormArgs': + ) -> 'cirq.StabilizerChFormSimulationState': """Creates the ActOnStabilizerChFormArgs for a circuit. Args: @@ -88,10 +88,10 @@ def _create_partial_act_on_args( Returns: ActOnStabilizerChFormArgs for the circuit. """ - if isinstance(initial_state, clifford.ActOnStabilizerCHFormArgs): + if isinstance(initial_state, clifford.StabilizerChFormSimulationState): return initial_state - return clifford.ActOnStabilizerCHFormArgs( + return clifford.StabilizerChFormSimulationState( prng=self._prng, classical_data=classical_data, qubits=qubits, @@ -99,7 +99,7 @@ def _create_partial_act_on_args( ) def _create_step_result( - self, sim_state: 'cirq.OperationTarget[clifford.ActOnStabilizerCHFormArgs]' + self, sim_state: 'cirq.SimulationStateBase[clifford.StabilizerChFormSimulationState]' ): return CliffordSimulatorStepResult(sim_state=sim_state) @@ -107,7 +107,7 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]', + final_simulator_state: 'cirq.SimulationStateBase[cirq.StabilizerChFormSimulationState]', ): return CliffordTrialResult( @@ -116,14 +116,14 @@ def _create_simulator_trial_result( class CliffordTrialResult( - simulator_base.SimulationTrialResultBase['clifford.ActOnStabilizerCHFormArgs'] + simulator_base.SimulationTrialResultBase['clifford.StabilizerChFormSimulationState'] ): @simulator._deprecated_step_result_parameter(old_position=3) def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStabilizerCHFormArgs]', + final_simulator_state: 'cirq.SimulationStateBase[cirq.StabilizerChFormSimulationState]', ) -> None: super().__init__( params=params, measurements=measurements, final_simulator_state=final_simulator_state @@ -146,13 +146,17 @@ def _repr_pretty_(self, p: Any, cycle: bool): p.text("cirq.CliffordTrialResult(...)" if cycle else self.__str__()) -class CliffordSimulatorStepResult(simulator_base.StepResultBase['cirq.ActOnStabilizerCHFormArgs']): +class CliffordSimulatorStepResult( + simulator_base.StepResultBase['cirq.StabilizerChFormSimulationState'] +): """A `StepResult` that includes `StateVectorMixin` methods.""" - def __init__(self, sim_state: 'cirq.OperationTarget[clifford.ActOnStabilizerCHFormArgs]'): + def __init__( + self, sim_state: 'cirq.SimulationStateBase[clifford.StabilizerChFormSimulationState]' + ): """Results of a step of the simulator. Attributes: - sim_state: The qubit:ActOnArgs lookup for this step. + sim_state: The qubit:SimulationState lookup for this step. """ super().__init__(sim_state) self._clifford_state = None @@ -238,7 +242,7 @@ def state_vector(self): return self.ch_form.state_vector() def apply_unitary(self, op: 'cirq.Operation'): - ch_form_args = clifford.ActOnStabilizerCHFormArgs( + ch_form_args = clifford.StabilizerChFormSimulationState( prng=np.random.RandomState(), qubits=self.qubit_map.keys(), initial_state=self.ch_form ) try: @@ -268,7 +272,7 @@ def apply_measurement( state = self.copy() classical_data = value.ClassicalDataDictionaryStore() - ch_form_args = clifford.ActOnStabilizerCHFormArgs( + ch_form_args = clifford.StabilizerChFormSimulationState( prng=prng, classical_data=classical_data, qubits=self.qubit_map.keys(), diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index c3bfca63d90..3abb33bf279 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -99,7 +99,7 @@ def test_simulate_initial_state(): ) -def test_simulate_act_on_args(): +def test_simulation_state(): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() for b0 in [0, 1]: @@ -211,7 +211,7 @@ def test_clifford_state_initial_state(): def test_clifford_trial_result_repr(): q0 = cirq.LineQubit(0) - final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) + final_simulator_state = cirq.StabilizerChFormSimulationState(qubits=[q0]) assert ( repr( cirq.CliffordTrialResult( @@ -222,7 +222,7 @@ def test_clifford_trial_result_repr(): ) == "cirq.SimulationTrialResult(params=cirq.ParamResolver({}), " "measurements={'m': array([[1]])}, " - "final_simulator_state=cirq.ActOnStabilizerCHFormArgs(" + "final_simulator_state=cirq.StabilizerChFormSimulationState(" "initial_state=StabilizerStateChForm(num_qubits=1), " "qubits=(cirq.LineQubit(0),), " "classical_data=cirq.ClassicalDataDictionaryStore()))" @@ -231,7 +231,7 @@ def test_clifford_trial_result_repr(): def test_clifford_trial_result_str(): q0 = cirq.LineQubit(0) - final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) + final_simulator_state = cirq.StabilizerChFormSimulationState(qubits=[q0]) assert ( str( cirq.CliffordTrialResult( @@ -247,7 +247,7 @@ def test_clifford_trial_result_str(): def test_clifford_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) - final_simulator_state = cirq.ActOnStabilizerCHFormArgs(qubits=[q0]) + final_simulator_state = cirq.StabilizerChFormSimulationState(qubits=[q0]) result = cirq.CliffordTrialResult( params=cirq.ParamResolver({}), measurements={'m': np.array([[1]])}, diff --git a/cirq-core/cirq/sim/clifford/clifford_tableau_simulation_state.py b/cirq-core/cirq/sim/clifford/clifford_tableau_simulation_state.py new file mode 100644 index 00000000000..4937aa7eeff --- /dev/null +++ b/cirq-core/cirq/sim/clifford/clifford_tableau_simulation_state.py @@ -0,0 +1,55 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A protocol for implementing high performance clifford tableau evolutions + for Clifford Simulator.""" + +from typing import Optional, Sequence, TYPE_CHECKING + +import numpy as np + +from cirq.qis import clifford_tableau +from cirq.sim.clifford.stabilizer_simulation_state import StabilizerSimulationState + +if TYPE_CHECKING: + import cirq + + +class CliffordTableauSimulationState(StabilizerSimulationState[clifford_tableau.CliffordTableau]): + """State and context for an operation acting on a clifford tableau.""" + + def __init__( + self, + tableau: 'cirq.CliffordTableau', + prng: Optional[np.random.RandomState] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Inits CliffordTableauSimulationState. + + Args: + tableau: The CliffordTableau to act on. Operations are expected to + perform inplace edits of this object. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + prng: The pseudo random number generator to use for probabilistic + effects. + classical_data: The shared classical data container for this + simulation. + """ + super().__init__(state=tableau, prng=prng, qubits=qubits, classical_data=classical_data) + + @property + def tableau(self) -> 'cirq.CliffordTableau': + return self.state diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py b/cirq-core/cirq/sim/clifford/clifford_tableau_simulation_state_test.py similarity index 90% rename from cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py rename to cirq-core/cirq/sim/clifford/clifford_tableau_simulation_state_test.py index 8135841841e..dfa7f4fa90f 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_tableau_simulation_state_test.py @@ -33,7 +33,7 @@ def _unitary_(self): return np.array([[0, -1j], [1j, 0]]) original_tableau = cirq.CliffordTableau(num_qubits=3) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), @@ -42,13 +42,13 @@ def _unitary_(self): cirq.act_on(UnitaryXGate(), args, [cirq.LineQubit(1)]) assert args.tableau == cirq.CliffordTableau(num_qubits=3, initial_state=2) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), ) cirq.act_on(UnitaryYGate(), args, [cirq.LineQubit(1)]) - expected_args = cirq.ActOnCliffordTableauArgs( + expected_args = cirq.CliffordTableauSimulationState( tableau=original_tableau.copy(), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), @@ -64,7 +64,7 @@ class NoDetails: class NoDetailsSingleQubitGate(cirq.testing.SingleQubitGate): pass - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=3), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), @@ -78,13 +78,13 @@ class NoDetailsSingleQubitGate(cirq.testing.SingleQubitGate): def test_copy(): - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=3), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), ) args1 = args.copy() - assert isinstance(args1, cirq.ActOnCliffordTableauArgs) + assert isinstance(args1, cirq.CliffordTableauSimulationState) assert args is not args1 assert args.tableau is not args1.tableau assert args.tableau == args1.tableau diff --git a/cirq-core/cirq/sim/clifford/stabilizer_ch_form_simulation_state.py b/cirq-core/cirq/sim/clifford/stabilizer_ch_form_simulation_state.py new file mode 100644 index 00000000000..ba261ced412 --- /dev/null +++ b/cirq-core/cirq/sim/clifford/stabilizer_ch_form_simulation_state.py @@ -0,0 +1,76 @@ +# Copyright 2020 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, TYPE_CHECKING, Union + +import numpy as np + +from cirq._compat import proper_repr +from cirq.sim.clifford import stabilizer_state_ch_form +from cirq.sim.clifford.stabilizer_simulation_state import StabilizerSimulationState + +if TYPE_CHECKING: + import cirq + + +class StabilizerChFormSimulationState( + StabilizerSimulationState[stabilizer_state_ch_form.StabilizerStateChForm] +): + """Wrapper around a stabilizer state in CH form for the act_on protocol.""" + + def __init__( + self, + *, + prng: Optional[np.random.RandomState] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Initializes with the given state and the axes for the operation. + + Args: + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + prng: The pseudo random number generator to use for probabilistic + effects. + initial_state: The initial state for the simulation. This can be a + full CH form passed by reference which will be modified inplace, + or a big-endian int in the computational basis. If the state is + an integer, qubits must be provided in order to determine + array sizes. + classical_data: The shared classical data container for this + simulation. + + Raises: + ValueError: If initial state is an integer but qubits are not + provided. + """ + if isinstance(initial_state, int): + if qubits is None: + raise ValueError('Must specify qubits if initial state is integer') + initial_state = stabilizer_state_ch_form.StabilizerStateChForm( + len(qubits), initial_state + ) + super().__init__( + state=initial_state, prng=prng, qubits=qubits, classical_data=classical_data + ) + + def __repr__(self) -> str: + return ( + 'cirq.StabilizerChFormSimulationState(' + f'initial_state={proper_repr(self.state)},' + f' qubits={self.qubits!r},' + f' classical_data={self.classical_data!r})' + ) diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py b/cirq-core/cirq/sim/clifford/stabilizer_ch_form_simulation_state_test.py similarity index 77% rename from cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py rename to cirq-core/cirq/sim/clifford/stabilizer_ch_form_simulation_state_test.py index 32d6d7aa922..89ea2019f17 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_ch_form_simulation_state_test.py @@ -19,17 +19,17 @@ def test_init_state(): - args = cirq.ActOnStabilizerCHFormArgs(qubits=cirq.LineQubit.range(1), initial_state=1) + args = cirq.StabilizerChFormSimulationState(qubits=cirq.LineQubit.range(1), initial_state=1) np.testing.assert_allclose(args.state.state_vector(), [0, 1]) with pytest.raises(ValueError, match='Must specify qubits'): - _ = cirq.ActOnStabilizerCHFormArgs(initial_state=1) + _ = cirq.StabilizerChFormSimulationState(initial_state=1) def test_cannot_act(): class NoDetails(cirq.testing.SingleQubitGate): pass - args = cirq.ActOnStabilizerCHFormArgs(qubits=[], prng=np.random.RandomState()) + args = cirq.StabilizerChFormSimulationState(qubits=[], prng=np.random.RandomState()) with pytest.raises(TypeError, match="Failed to act"): cirq.act_on(NoDetails(), args, qubits=()) @@ -37,14 +37,14 @@ class NoDetails(cirq.testing.SingleQubitGate): def test_gate_with_act_on(): class CustomGate(cirq.testing.SingleQubitGate): - def _act_on_(self, args, qubits): - if isinstance(args, cirq.ActOnStabilizerCHFormArgs): - qubit = args.qubit_map[qubits[0]] - args.state.gamma[qubit] += 1 + def _act_on_(self, sim_state, qubits): + if isinstance(sim_state, cirq.StabilizerChFormSimulationState): + qubit = sim_state.qubit_map[qubits[0]] + sim_state.state.gamma[qubit] += 1 return True state = cirq.StabilizerStateChForm(num_qubits=3) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), initial_state=state ) @@ -61,11 +61,11 @@ def num_qubits(self) -> int: def _unitary_(self): return np.array([[0, -1j], [1j, 0]]) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(3), prng=np.random.RandomState() ) cirq.act_on(UnitaryYGate(), args, [cirq.LineQubit(1)]) - expected_args = cirq.ActOnStabilizerCHFormArgs( + expected_args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(3), prng=np.random.RandomState() ) cirq.act_on(cirq.Y, expected_args, [cirq.LineQubit(1)]) @@ -80,11 +80,11 @@ def num_qubits(self) -> int: def _unitary_(self): return np.array([[1, 1], [1, -1]]) / (2**0.5) - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(3), prng=np.random.RandomState() ) cirq.act_on(UnitaryHGate(), args, [cirq.LineQubit(1)]) - expected_args = cirq.ActOnStabilizerCHFormArgs( + expected_args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(3), prng=np.random.RandomState() ) cirq.act_on(cirq.H, expected_args, [cirq.LineQubit(1)]) @@ -92,11 +92,11 @@ def _unitary_(self): def test_copy(): - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=cirq.LineQubit.range(3), prng=np.random.RandomState() ) args1 = args.copy() - assert isinstance(args1, cirq.ActOnStabilizerCHFormArgs) + assert isinstance(args1, cirq.StabilizerChFormSimulationState) assert args is not args1 assert args.state is not args1.state np.testing.assert_equal(args.state.state_vector(), args1.state.state_vector()) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py index 73bab9091f1..ec589389f16 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py @@ -19,7 +19,7 @@ import cirq from cirq import protocols, value from cirq.qis.clifford_tableau import CliffordTableau -from cirq.sim.clifford.act_on_clifford_tableau_args import ActOnCliffordTableauArgs +from cirq.sim.clifford.clifford_tableau_simulation_state import CliffordTableauSimulationState from cirq.work import sampler @@ -53,7 +53,7 @@ def _run(self, circuit: 'cirq.AbstractCircuit', repetitions: int) -> Dict[str, n qubits = circuit.all_qubits() for _ in range(repetitions): - state = ActOnCliffordTableauArgs( + state = CliffordTableauSimulationState( CliffordTableau(num_qubits=len(qubits)), qubits=list(qubits), prng=self._prng ) for op in circuit.all_operations(): diff --git a/cirq-core/cirq/sim/clifford/stabilizer_simulation_state.py b/cirq-core/cirq/sim/clifford/stabilizer_simulation_state.py new file mode 100644 index 00000000000..e24c19e8fae --- /dev/null +++ b/cirq-core/cirq/sim/clifford/stabilizer_simulation_state.py @@ -0,0 +1,186 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Dict, Generic, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union + +import numpy as np + +from cirq import linalg, ops, protocols +from cirq._compat import deprecated_parameter +from cirq.ops import common_gates, global_phase_op, matrix_gates, swap_gates +from cirq.ops.clifford_gate import SingleQubitCliffordGate +from cirq.protocols import has_unitary, num_qubits, unitary +from cirq.sim.simulation_state import SimulationState +from cirq.type_workarounds import NotImplementedType + +if TYPE_CHECKING: + import cirq + + +TStabilizerState = TypeVar('TStabilizerState', bound='cirq.StabilizerState') + + +class StabilizerSimulationState( + SimulationState[TStabilizerState], Generic[TStabilizerState], metaclass=abc.ABCMeta +): + """Abstract wrapper around a stabilizer state for the act_on protocol.""" + + @deprecated_parameter( + deadline='v0.16', + fix='Use kwargs instead of positional args', + parameter_desc='args', + match=lambda args, kwargs: len(args) > 1, + ) + @deprecated_parameter( + deadline='v0.16', + fix='Replace log_of_measurement_results with' + ' classical_data=cirq.ClassicalDataDictionaryStore(_records=logs).', + parameter_desc='log_of_measurement_results', + match=lambda args, kwargs: 'log_of_measurement_results' in kwargs, + ) + def __init__( + self, + state: TStabilizerState, + prng: Optional[np.random.RandomState] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Initializes the StabilizerSimulationState. + + Args: + state: The quantum stabilizer state to use in the simulation or + act_on invocation. + prng: The pseudo random number generator to use for probabilistic + effects. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + log_of_measurement_results: A mutable object that measurements are + being recorded into. + classical_data: The shared classical data container for this + simulation. + """ + if log_of_measurement_results is not None: + super().__init__( + state=state, + prng=prng, + qubits=qubits, + log_of_measurement_results=log_of_measurement_results, + classical_data=classical_data, + ) + else: + super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + + @property + def state(self) -> TStabilizerState: + return self._state + + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> Union[bool, NotImplementedType]: + strats = [self._strat_apply_gate, self._strat_apply_mixture] + if allow_decompose: + strats.append(self._strat_decompose) + strats.append(self._strat_act_from_single_qubit_decompose) + for strat in strats: + result = strat(action, qubits) # type: ignore + if result is True: + return True + assert result is NotImplemented, str(result) + + return NotImplemented + + def _swap( + self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 + ): + """Apply a SWAP gate""" + if exponent % 1 != 0: + raise ValueError('Swap exponent must be integer') # coverage: ignore + self._state.apply_cx(control_axis, target_axis) + self._state.apply_cx(target_axis, control_axis, exponent, global_shift) + self._state.apply_cx(control_axis, target_axis) + + def _strat_apply_gate(self, val: Any, qubits: Sequence['cirq.Qid']) -> bool: + if not protocols.has_stabilizer_effect(val): + return NotImplemented + gate = val.gate if isinstance(val, ops.Operation) else val + axes = self.get_axes(qubits) + if isinstance(gate, common_gates.XPowGate): + self._state.apply_x(axes[0], gate.exponent, gate.global_shift) + elif isinstance(gate, common_gates.YPowGate): + self._state.apply_y(axes[0], gate.exponent, gate.global_shift) + elif isinstance(gate, common_gates.ZPowGate): + self._state.apply_z(axes[0], gate.exponent, gate.global_shift) + elif isinstance(gate, common_gates.HPowGate): + self._state.apply_h(axes[0], gate.exponent, gate.global_shift) + elif isinstance(gate, common_gates.CXPowGate): + self._state.apply_cx(axes[0], axes[1], gate.exponent, gate.global_shift) + elif isinstance(gate, common_gates.CZPowGate): + self._state.apply_cz(axes[0], axes[1], gate.exponent, gate.global_shift) + elif isinstance(gate, global_phase_op.GlobalPhaseGate): + self._state.apply_global_phase(gate.coefficient) + elif isinstance(gate, swap_gates.SwapPowGate): + self._swap(axes[0], axes[1], gate.exponent, gate.global_shift) + else: + return NotImplemented + return True + + def _strat_apply_mixture(self, val: Any, qubits: Sequence['cirq.Qid']) -> bool: + mixture = protocols.mixture(val, None) + if mixture is None: + return NotImplemented + if not all(linalg.is_unitary(m) for _, m in mixture): + return NotImplemented # coverage: ignore + probabilities, unitaries = zip(*mixture) + index = self.prng.choice(len(unitaries), p=probabilities) + return self._strat_act_from_single_qubit_decompose( + matrix_gates.MatrixGate(unitaries[index]), qubits + ) + + def _strat_act_from_single_qubit_decompose( + self, val: Any, qubits: Sequence['cirq.Qid'] + ) -> bool: + if num_qubits(val) == 1: + if not has_unitary(val): + return NotImplemented + u = unitary(val) + clifford_gate = SingleQubitCliffordGate.from_unitary(u) + if clifford_gate is not None: + # Gather the effective unitary applied so as to correct for the + # global phase later. + final_unitary = np.eye(2) + for axis, quarter_turns in clifford_gate.decompose_rotation(): + gate = axis ** (quarter_turns / 2) + self._strat_apply_gate(gate, qubits) + final_unitary = np.matmul(unitary(gate), final_unitary) + + # Find the entry with the largest magnitude in the input unitary. + k = max(np.ndindex(*u.shape), key=lambda t: abs(u[t])) + # Correct the global phase that wasn't conserved in the above + # decomposition. + self._state.apply_global_phase(u[k] / final_unitary[k]) + return True + + return NotImplemented + + def _strat_decompose(self, val: Any, qubits: Sequence['cirq.Qid']) -> bool: + gate = val.gate if isinstance(val, ops.Operation) else val + operations = protocols.decompose_once_with_qubits(gate, qubits, None) + if operations is None or not all(protocols.has_stabilizer_effect(op) for op in operations): + return NotImplemented + for op in operations: + protocols.act_on(op, self) + return True diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args_test.py b/cirq-core/cirq/sim/clifford/stabilizer_simulation_state_test.py similarity index 90% rename from cirq-core/cirq/sim/clifford/act_on_stabilizer_args_test.py rename to cirq-core/cirq/sim/clifford/stabilizer_simulation_state_test.py index e72cfc5dbe3..50559d9cc67 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_args_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_simulation_state_test.py @@ -22,7 +22,7 @@ def test_apply_gate(): q0, q1 = cirq.LineQubit.range(2) state = Mock() - args = cirq.ActOnStabilizerArgs(state=state, qubits=[q0, q1]) + args = cirq.StabilizerSimulationState(state=state, qubits=[q0, q1]) assert args._strat_apply_gate(cirq.X, [q0]) is True state.apply_x.assert_called_with(0, 1.0, 0.0) @@ -83,7 +83,7 @@ def test_apply_gate(): def test_apply_mixture(): q0 = cirq.LineQubit(0) state = Mock() - args = cirq.ActOnStabilizerArgs(state=state, qubits=[q0]) + args = cirq.StabilizerSimulationState(state=state, qubits=[q0]) for _ in range(100): assert args._strat_apply_mixture(cirq.BitFlipChannel(0.5), [q0]) is True @@ -94,7 +94,7 @@ def test_apply_mixture(): def test_act_from_single_qubit_decompose(): q0 = cirq.LineQubit(0) state = Mock() - args = cirq.ActOnStabilizerArgs(state=state, qubits=[q0]) + args = cirq.StabilizerSimulationState(state=state, qubits=[q0]) assert ( args._strat_act_from_single_qubit_decompose( cirq.MatrixGate(np.array([[0, 1], [1, 0]])), [q0] @@ -114,13 +114,13 @@ def _qid_shape_(self): q0 = cirq.LineQubit(0) state = Mock() - args = cirq.ActOnStabilizerArgs(state=state, qubits=[q0]) + args = cirq.StabilizerSimulationState(state=state, qubits=[q0]) assert args._strat_decompose(XContainer(), [q0]) is True state.apply_x.assert_called_with(0, 1.0, 0.0) def test_deprecated(): with cirq.testing.assert_deprecated('log_of_measurement_results', deadline='v0.16', count=2): - _ = cirq.ActOnStabilizerArgs(state=0, log_of_measurement_results={}) + _ = cirq.StabilizerSimulationState(state=0, log_of_measurement_results={}) with cirq.testing.assert_deprecated('positional', deadline='v0.16'): - _ = cirq.ActOnStabilizerArgs(0) + _ = cirq.StabilizerSimulationState(0) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py index 30ef642a745..142cd978cd0 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form_test.py @@ -66,7 +66,7 @@ def test_run(): state = cirq.StabilizerStateChForm(num_qubits=3) classical_data = cirq.ClassicalDataDictionaryStore() for op in circuit.all_operations(): - args = cirq.ActOnStabilizerCHFormArgs( + args = cirq.StabilizerChFormSimulationState( qubits=list(circuit.all_qubits()), prng=np.random.RandomState(), classical_data=classical_data, diff --git a/cirq-core/cirq/sim/density_matrix_simulation_state.py b/cirq-core/cirq/sim/density_matrix_simulation_state.py new file mode 100644 index 00000000000..1486be355b4 --- /dev/null +++ b/cirq-core/cirq/sim/density_matrix_simulation_state.py @@ -0,0 +1,340 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Objects and methods for acting efficiently on a density matrix.""" + +from typing import Any, Callable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union + +import numpy as np + +from cirq import protocols, qis, sim +from cirq._compat import proper_repr +from cirq.linalg import transformations +from cirq.sim.simulation_state import SimulationState, strat_act_on_from_apply_decompose + +if TYPE_CHECKING: + import cirq + from numpy.typing import DTypeLike + + +class _BufferedDensityMatrix(qis.QuantumStateRepresentation): + """Contains the density matrix and buffers for efficient state evolution.""" + + def __init__(self, density_matrix: np.ndarray, buffer: Optional[List[np.ndarray]] = None): + """Initializes the object with the inputs. + + This initializer creates the buffer if necessary. + + Args: + density_matrix: The density matrix, must be correctly formatted. The data is not + checked for validity here due to performance concerns. + buffer: Optional, must be length 3 and same shape as the density matrix. If not + provided, a buffer will be created automatically. + Raises: + ValueError: If the array is not the shape of a density matrix. + """ + self._density_matrix = density_matrix + if buffer is None: + buffer = [np.empty_like(density_matrix) for _ in range(3)] + self._buffer = buffer + if len(density_matrix.shape) % 2 != 0: + # coverage: ignore + raise ValueError('The dimension of target_tensor is not divisible by 2.') + self._qid_shape = density_matrix.shape[: len(density_matrix.shape) // 2] + + @classmethod + def create( + cls, + *, + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, + qid_shape: Optional[Tuple[int, ...]] = None, + dtype: Optional['DTypeLike'] = None, + buffer: Optional[List[np.ndarray]] = None, + ): + """Creates a buffered density matrix with the requested state. + + Args: + initial_state: The initial state for the simulation in the computational basis. + qid_shape: The shape of the density matrix, if the initial state is provided as an int. + dtype: The desired dtype of the density matrix. + buffer: Optional, must be length 3 and same shape as the density matrix. If not + provided, a buffer will be created automatically. + Raises: + ValueError: If initial state is provided as integer, but qid_shape is not provided. + """ + if not isinstance(initial_state, np.ndarray): + if qid_shape is None: + raise ValueError('qid_shape must be provided if initial_state is not ndarray') + density_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape * 2) + else: + if qid_shape is not None: + if dtype and initial_state.dtype != dtype: + initial_state = initial_state.astype(dtype) + density_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape * 2) + else: + density_matrix = initial_state # coverage: ignore + if np.may_share_memory(density_matrix, initial_state): + density_matrix = density_matrix.copy() + density_matrix = density_matrix.astype(dtype, copy=False) + return cls(density_matrix, buffer) + + def copy(self, deep_copy_buffers: bool = True) -> '_BufferedDensityMatrix': + """Copies the object. + + Args: + deep_copy_buffers: True by default, False to reuse the existing buffers. + Returns: + A copy of the object. + """ + return _BufferedDensityMatrix( + density_matrix=self._density_matrix.copy(), + buffer=[b.copy() for b in self._buffer] if deep_copy_buffers else self._buffer, + ) + + def kron(self, other: '_BufferedDensityMatrix') -> '_BufferedDensityMatrix': + """Creates the Kronecker product with the other density matrix. + + Args: + other: The density matrix with which to kron. + Returns: + The Kronecker product of the two density matrices. + """ + density_matrix = transformations.density_matrix_kronecker_product( + self._density_matrix, other._density_matrix + ) + return _BufferedDensityMatrix(density_matrix=density_matrix) + + def factor( + self, axes: Sequence[int], *, validate=True, atol=1e-07 + ) -> Tuple['_BufferedDensityMatrix', '_BufferedDensityMatrix']: + """Factors out the desired axes. + + Args: + axes: The axes to factor out. Only the left axes should be provided. For example, to + extract [C,A] from density matrix of shape [A,B,C,D,A,B,C,D], `axes` should be + [2,0], and the return value will be two density matrices ([C,A,C,A], [B,D,B,D]). + validate: Perform a validation that the density matrix factors cleanly. + atol: The absolute tolerance for the validation. + Returns: + A tuple with the `(extracted, remainder)` density matrices, where `extracted` means + the sub-matrix which corresponds to the axes requested, and with the axes in the + requested order, and where `remainder` means the sub-matrix on the remaining axes, + in the same order as the original density matrix. + """ + extracted_tensor, remainder_tensor = transformations.factor_density_matrix( + self._density_matrix, axes, validate=validate, atol=atol + ) + extracted = _BufferedDensityMatrix(density_matrix=extracted_tensor) + remainder = _BufferedDensityMatrix(density_matrix=remainder_tensor) + return extracted, remainder + + def reindex(self, axes: Sequence[int]) -> '_BufferedDensityMatrix': + """Transposes the axes of a density matrix to a specified order. + + Args: + axes: The desired axis order. Only the left axes should be provided. For example, to + transpose [A,B,C,A,B,C] to [C,B,A,C,B,A], `axes` should be [2,1,0]. + Returns: + The transposed density matrix. + """ + new_tensor = transformations.transpose_density_matrix_to_axis_order( + self._density_matrix, axes + ) + return _BufferedDensityMatrix(density_matrix=new_tensor) + + def apply_channel(self, action: Any, axes: Sequence[int]) -> bool: + """Apply channel to state. + + Args: + action: The value with a channel to apply. + axes: The axes on which to apply the channel. + Returns: + True if the action succeeded. + """ + result = protocols.apply_channel( + action, + args=protocols.ApplyChannelArgs( + target_tensor=self._density_matrix, + out_buffer=self._buffer[0], + auxiliary_buffer0=self._buffer[1], + auxiliary_buffer1=self._buffer[2], + left_axes=axes, + right_axes=[e + len(self._qid_shape) for e in axes], + ), + default=None, + ) + if result is None: + return False + for i in range(len(self._buffer)): + if result is self._buffer[i]: + self._buffer[i] = self._density_matrix + self._density_matrix = result + return True + + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the density matrix. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in order. + """ + bits, _ = sim.measure_density_matrix( + self._density_matrix, + axes, + out=self._density_matrix, + qid_shape=self._qid_shape, + seed=seed, + ) + return bits + + def sample( + self, + axes: Sequence[int], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + """Samples the density matrix. + + Args: + axes: The axes to sample. + repetitions: The number of samples to make. + seed: The random number seed to use. + Returns: + The samples in order. + """ + return sim.sample_density_matrix( + self._density_matrix, + axes, + qid_shape=self._qid_shape, + repetitions=repetitions, + seed=seed, + ) + + @property + def supports_factor(self) -> bool: + return True + + @property + def can_represent_mixed_states(self) -> bool: + return True + + +class DensityMatrixSimulationState(SimulationState[_BufferedDensityMatrix]): + """State and context for an operation acting on a density matrix. + + To act on this object, directly edit the `target_tensor` property, which is + storing the density matrix of the quantum system with one axis per qubit. + """ + + def __init__( + self, + *, + available_buffer: Optional[List[np.ndarray]] = None, + qid_shape: Optional[Tuple[int, ...]] = None, + prng: Optional[np.random.RandomState] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, + dtype: Type[np.number] = np.complex64, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Inits DensityMatrixSimulationState. + + Args: + available_buffer: A workspace with the same shape and dtype as + `target_tensor`. Used by operations that cannot be applied to + `target_tensor` inline, in order to avoid unnecessary + allocations. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + qid_shape: The shape of the target tensor. + prng: The pseudo random number generator to use for probabilistic + effects. + initial_state: The initial state for the simulation in the + computational basis. + dtype: The `numpy.dtype` of the inferred state vector. One of + `numpy.complex64` or `numpy.complex128`. Only used when + `target_tenson` is None. + classical_data: The shared classical data container for this + simulation. + + Raises: + ValueError: The dimension of `target_tensor` is not divisible by 2 + and `qid_shape` is not provided. + """ + state = _BufferedDensityMatrix.create( + initial_state=initial_state, + qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, + dtype=dtype, + buffer=available_buffer, + ) + super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> bool: + strats: List[Callable[[Any, Any, Sequence['cirq.Qid']], bool]] = [ + _strat_apply_channel_to_state + ] + if allow_decompose: + strats.append(strat_act_on_from_apply_decompose) + + # Try each strategy, stopping if one works. + for strat in strats: + result = strat(action, self, qubits) + if result is False: + break # coverage: ignore + if result is True: + return True + assert result is NotImplemented, str(result) + raise TypeError( + "Can't simulate operations that don't implement " + "SupportsUnitary, SupportsConsistentApplyUnitary, " + "SupportsMixture or SupportsKraus or is a measurement: {!r}".format(action) + ) + + def __repr__(self) -> str: + return ( + 'cirq.DensityMatrixSimulationState(' + f'initial_state={proper_repr(self.target_tensor)},' + f' qid_shape={self.qid_shape!r},' + f' qubits={self.qubits!r},' + f' classical_data={self.classical_data!r})' + ) + + @property + def target_tensor(self): + return self._state._density_matrix + + @property + def available_buffer(self): + return self._state._buffer + + @property + def qid_shape(self): + return self._state._qid_shape + + +def _strat_apply_channel_to_state( + action: Any, args: 'cirq.DensityMatrixSimulationState', qubits: Sequence['cirq.Qid'] +) -> bool: + """Apply channel to state.""" + return True if args._state.apply_channel(action, args.get_axes(qubits)) else NotImplemented diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py b/cirq-core/cirq/sim/density_matrix_simulation_state_test.py similarity index 83% rename from cirq-core/cirq/sim/act_on_density_matrix_args_test.py rename to cirq-core/cirq/sim/density_matrix_simulation_state_test.py index 559759f5f10..03b5e5d497f 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulation_state_test.py @@ -23,7 +23,7 @@ def test_default_parameter(): tensor = cirq.to_valid_density_matrix( 0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64 ) - args = cirq.ActOnDensityMatrixArgs(qubits=cirq.LineQubit.range(1), initial_state=0) + args = cirq.DensityMatrixSimulationState(qubits=cirq.LineQubit.range(1), initial_state=0) np.testing.assert_almost_equal(args.target_tensor, tensor) assert len(args.available_buffer) == 3 for buffer in args.available_buffer: @@ -33,7 +33,7 @@ def test_default_parameter(): def test_shallow_copy_buffers(): - args = cirq.ActOnDensityMatrixArgs(qubits=cirq.LineQubit.range(1), initial_state=0) + args = cirq.DensityMatrixSimulationState(qubits=cirq.LineQubit.range(1), initial_state=0) copy = args.copy(deep_copy_buffers=False) assert copy.available_buffer is args.available_buffer @@ -46,7 +46,7 @@ def num_qubits(self) -> int: def _decompose_(self, qubits): yield cirq.X(*qubits) - args = cirq.ActOnDensityMatrixArgs( + args = cirq.DensityMatrixSimulationState( qubits=cirq.LineQubit.range(1), prng=np.random.RandomState(), initial_state=0, @@ -63,7 +63,7 @@ def test_cannot_act(): class NoDetails: pass - args = cirq.ActOnDensityMatrixArgs( + args = cirq.DensityMatrixSimulationState( qubits=cirq.LineQubit.range(1), prng=np.random.RandomState(), initial_state=0, @@ -74,7 +74,7 @@ class NoDetails: def test_with_qubits(): - original = cirq.ActOnDensityMatrixArgs( + original = cirq.DensityMatrixSimulationState( qubits=cirq.LineQubit.range(1), initial_state=1, dtype=np.complex64 ) extened = original.with_qubits(cirq.LineQubit.range(1, 2)) @@ -89,17 +89,17 @@ def test_with_qubits(): def test_qid_shape_error(): with pytest.raises(ValueError, match="qid_shape must be provided"): - cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0) + cirq.sim.density_matrix_simulation_state._BufferedDensityMatrix.create(initial_state=0) def test_initial_state_vector(): qubits = cirq.LineQubit.range(3) - args = cirq.ActOnDensityMatrixArgs( + args = cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64 ) assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) - args2 = cirq.ActOnDensityMatrixArgs( + args2 = cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64 ) assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) @@ -107,12 +107,12 @@ def test_initial_state_vector(): def test_initial_state_matrix(): qubits = cirq.LineQubit.range(3) - args = cirq.ActOnDensityMatrixArgs( + args = cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64 ) assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) - args2 = cirq.ActOnDensityMatrixArgs( + args2 = cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64 ) assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) @@ -121,19 +121,19 @@ def test_initial_state_matrix(): def test_initial_state_bad_shape(): qubits = cirq.LineQubit.range(3) with pytest.raises(ValueError, match="Invalid quantum state"): - cirq.ActOnDensityMatrixArgs( + cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64 ) with pytest.raises(ValueError, match="Invalid quantum state"): - cirq.ActOnDensityMatrixArgs( + cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64 ) with pytest.raises(ValueError, match="Invalid quantum state"): - cirq.ActOnDensityMatrixArgs( + cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64 ) with pytest.raises(ValueError, match="Invalid quantum state"): - cirq.ActOnDensityMatrixArgs( + cirq.DensityMatrixSimulationState( qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64 ) diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 1a3188e654d..09ff6f98781 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -18,7 +18,7 @@ from cirq import ops, protocols, study, value from cirq._compat import deprecated_class, deprecated_parameter, proper_repr -from cirq.sim import simulator, act_on_density_matrix_args, simulator_base +from cirq.sim import simulator, density_matrix_simulation_state, simulator_base if TYPE_CHECKING: import cirq @@ -29,7 +29,7 @@ class DensityMatrixSimulator( simulator_base.SimulatorBase[ 'cirq.DensityMatrixStepResult', 'cirq.DensityMatrixTrialResult', - 'cirq.ActOnDensityMatrixArgs', + 'cirq.DensityMatrixSimulationState', ], simulator.SimulatesExpectationValues, ): @@ -147,11 +147,13 @@ def __init__( def _create_partial_act_on_args( self, - initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.ActOnDensityMatrixArgs'], + initial_state: Union[ + np.ndarray, 'cirq.STATE_VECTOR_LIKE', 'cirq.DensityMatrixSimulationState' + ], qubits: Sequence['cirq.Qid'], classical_data: 'cirq.ClassicalDataStore', - ) -> 'cirq.ActOnDensityMatrixArgs': - """Creates the ActOnDensityMatrixArgs for a circuit. + ) -> 'cirq.DensityMatrixSimulationState': + """Creates the DensityMatrixSimulationState for a circuit. Args: initial_state: The initial state for the simulation in the @@ -163,12 +165,12 @@ def _create_partial_act_on_args( simulation. Returns: - ActOnDensityMatrixArgs for the circuit. + DensityMatrixSimulationState for the circuit. """ - if isinstance(initial_state, act_on_density_matrix_args.ActOnDensityMatrixArgs): + if isinstance(initial_state, density_matrix_simulation_state.DensityMatrixSimulationState): return initial_state - return act_on_density_matrix_args.ActOnDensityMatrixArgs( + return density_matrix_simulation_state.DensityMatrixSimulationState( qubits=qubits, prng=self._prng, classical_data=classical_data, @@ -179,14 +181,16 @@ def _create_partial_act_on_args( def _can_be_in_run_prefix(self, val: Any): return not protocols.measurement_keys_touched(val) - def _create_step_result(self, sim_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]'): + def _create_step_result( + self, sim_state: 'cirq.SimulationStateBase[cirq.DensityMatrixSimulationState]' + ): return DensityMatrixStepResult(sim_state=sim_state, dtype=self._dtype) def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', + final_simulator_state: 'cirq.SimulationStateBase[cirq.DensityMatrixSimulationState]', ) -> 'cirq.DensityMatrixTrialResult': return DensityMatrixTrialResult( params=params, measurements=measurements, final_simulator_state=final_simulator_state @@ -227,7 +231,7 @@ def simulate_expectation_values_sweep( return swept_evs -class DensityMatrixStepResult(simulator_base.StepResultBase['cirq.ActOnDensityMatrixArgs']): +class DensityMatrixStepResult(simulator_base.StepResultBase['cirq.DensityMatrixSimulationState']): """A single step in the simulation of the DensityMatrixSimulator. Attributes: @@ -243,14 +247,14 @@ class DensityMatrixStepResult(simulator_base.StepResultBase['cirq.ActOnDensityMa ) def __init__( self, - sim_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', + sim_state: 'cirq.SimulationStateBase[cirq.DensityMatrixSimulationState]', simulator: 'cirq.DensityMatrixSimulator' = None, dtype: 'DTypeLike' = np.complex64, ): """DensityMatrixStepResult. Args: - sim_state: The qubit:ActOnArgs lookup for this step. + sim_state: The qubit:SimulationState lookup for this step. simulator: The simulator used to create this. dtype: The `numpy.dtype` used by the simulation. One of `numpy.complex64` or `numpy.complex128`. @@ -342,7 +346,9 @@ def __repr__(self) -> str: @value.value_equality(unhashable=True) class DensityMatrixTrialResult( - simulator_base.SimulationTrialResultBase[act_on_density_matrix_args.ActOnDensityMatrixArgs] + simulator_base.SimulationTrialResultBase[ + density_matrix_simulation_state.DensityMatrixSimulationState + ] ): """A `SimulationTrialResult` for `DensityMatrixSimulator` runs. @@ -386,7 +392,7 @@ def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', + final_simulator_state: 'cirq.SimulationStateBase[cirq.DensityMatrixSimulationState]', ) -> None: super().__init__( params=params, measurements=measurements, final_simulator_state=final_simulator_state diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index fc805a70554..bff52775e6c 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -591,7 +591,7 @@ def test_simulate_initial_state(dtype: Type[np.number], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_act_on_args(dtype: Type[np.number], split: bool): +def test_simulation_state(dtype: Type[np.number], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -976,7 +976,7 @@ def test_density_matrix_simulator_state_repr(): def test_density_matrix_trial_result_eq(): q0 = cirq.LineQubit(0) - final_simulator_state = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.DensityMatrixSimulationState( initial_state=np.ones((2, 2)) * 0.5, qubits=[q0] ) eq = cirq.testing.EqualsTester() @@ -1010,7 +1010,7 @@ def test_density_matrix_trial_result_eq(): def test_density_matrix_trial_result_qid_shape(): q0, q1 = cirq.LineQubit.range(2) - final_simulator_state = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.DensityMatrixSimulationState( initial_state=np.ones((4, 4)) / 4, qubits=[q0, q1] ) assert cirq.qid_shape( @@ -1021,7 +1021,7 @@ def test_density_matrix_trial_result_qid_shape(): ) ) == (2, 2) q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) - final_simulator_state = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.DensityMatrixSimulationState( initial_state=np.ones((12, 12)) / 12, qubits=[q0, q1] ) assert cirq.qid_shape( @@ -1036,7 +1036,7 @@ def test_density_matrix_trial_result_qid_shape(): def test_density_matrix_trial_result_repr(): q0 = cirq.LineQubit(0) dtype = np.complex64 - final_simulator_state = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.DensityMatrixSimulationState( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1053,7 +1053,7 @@ def test_density_matrix_trial_result_repr(): "cirq.DensityMatrixTrialResult(" "params=cirq.ParamResolver({'s': 1}), " "measurements={'m': np.array([[1]], dtype=np.int32)}, " - "final_simulator_state=cirq.ActOnDensityMatrixArgs(" + "final_simulator_state=cirq.DensityMatrixSimulationState(" "initial_state=np.array([[(0.5+0j), (0.5+0j)], [(0.5+0j), (0.5+0j)]], dtype=np.complex64), " "qid_shape=(2,), " "qubits=(cirq.LineQubit(0),), " @@ -1124,7 +1124,7 @@ def test_works_on_pauli_string(): def test_density_matrix_trial_result_str(): q0 = cirq.LineQubit(0) dtype = np.complex64 - final_simulator_state = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.DensityMatrixSimulationState( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1149,7 +1149,7 @@ def test_density_matrix_trial_result_str(): def test_density_matrix_trial_result_repr_pretty(): q0 = cirq.LineQubit(0) dtype = np.complex64 - final_simulator_state = cirq.ActOnDensityMatrixArgs( + final_simulator_state = cirq.DensityMatrixSimulationState( available_buffer=[], qid_shape=(2,), prng=np.random.RandomState(0), @@ -1593,7 +1593,7 @@ def test_sweep_unparameterized_prefix_not_repeated_even_non_unitaries(): class NonUnitaryOp(cirq.Operation): count = 0 - def _act_on_(self, args): + def _act_on_(self, sim_state): self.count += 1 return True diff --git a/cirq-core/cirq/sim/operation_target.py b/cirq-core/cirq/sim/operation_target.py index 835062ff128..80448fffea1 100644 --- a/cirq-core/cirq/sim/operation_target.py +++ b/cirq-core/cirq/sim/operation_target.py @@ -1,4 +1,4 @@ -# Copyright 2021 The Cirq Developers +# Copyright 2022 The Cirq Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,124 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""An interface for quantum states as targets for operations.""" -import abc -from typing import ( - Any, - Dict, - Generic, - Iterator, - List, - Mapping, - Optional, - Sequence, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) -import numpy as np +from cirq import _compat +from cirq.sim.simulation_state_base import SimulationStateBase -from cirq import protocols, value -from cirq.type_workarounds import NotImplementedType -if TYPE_CHECKING: - import cirq - - -TSelfTarget = TypeVar('TSelfTarget', bound='OperationTarget') -TActOnArgs = TypeVar('TActOnArgs', bound='cirq.ActOnArgs') - - -class OperationTarget(Generic[TActOnArgs], metaclass=abc.ABCMeta): - """An interface for quantum states as targets for operations.""" - - def __init__( - self, - *, - qubits: Sequence['cirq.Qid'], - classical_data: Optional['cirq.ClassicalDataStore'] = None, - ): - """Initializes the class. - - Args: - qubits: The canonical ordering of qubits. - classical_data: The shared classical data container for this - simulation. - """ - self._set_qubits(tuple(qubits)) - self._classical_data = classical_data or value.ClassicalDataDictionaryStore() - - @property - def qubits(self) -> Tuple['cirq.Qid', ...]: - return self._qubits - - @property - def qubit_map(self) -> Mapping['cirq.Qid', int]: - return self._qubit_map - - def _set_qubits(self, qubits: Sequence['cirq.Qid']): - self._qubits = tuple(qubits) - self._qubit_map = {q: i for i, q in enumerate(self.qubits)} - - @property - def classical_data(self) -> 'cirq.ClassicalDataStoreReader': - return self._classical_data - - @abc.abstractmethod - def create_merged_state(self) -> TActOnArgs: - """Creates a final merged state.""" - - @abc.abstractmethod - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> Union[bool, NotImplementedType]: - """Handles the act_on protocol fallback implementation. - - Args: - action: A gate, operation, or other to act on. - qubits: The applicable qubits if a gate is passed as the action. - allow_decompose: Flag to allow decomposition. - - Returns: - True if the fallback applies, else NotImplemented.""" - - def apply_operation(self, op: 'cirq.Operation'): - protocols.act_on(op, self) - - @abc.abstractmethod - def copy(self: TSelfTarget, deep_copy_buffers: bool = True) -> TSelfTarget: - """Creates a copy of the object. - - Args: - deep_copy_buffers: If True, buffers will also be deep-copied. - Otherwise the copy will share a reference to the original object's - buffers. - - Returns: - A copied instance. - """ - - @property - def log_of_measurement_results(self) -> Dict[str, List[int]]: - """Gets the log of measurement results.""" - return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} - - @abc.abstractmethod - def sample( - self, - qubits: List['cirq.Qid'], - repetitions: int = 1, - seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, - ) -> np.ndarray: - """Samples the state value.""" - - def __getitem__(self, item: Optional['cirq.Qid']) -> TActOnArgs: - """Gets the item associated with the qubit.""" - - def __len__(self) -> int: - """Gets the number of items in the mapping.""" - - def __iter__(self) -> Iterator[Optional['cirq.Qid']]: - """Iterates the keys of the mapping.""" +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq.SimulationStateBase instead.') +class OperationTarget(SimulationStateBase): + pass diff --git a/cirq-core/cirq/sim/simulation_product_state.py b/cirq-core/cirq/sim/simulation_product_state.py new file mode 100644 index 00000000000..c42a07103d9 --- /dev/null +++ b/cirq-core/cirq/sim/simulation_product_state.py @@ -0,0 +1,175 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import abc +from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, Sequence, TYPE_CHECKING + +import numpy as np + +from cirq import ops, protocols, value +from cirq.sim.simulation_state import TSimulationState +from cirq.sim.simulation_state_base import SimulationStateBase + +if TYPE_CHECKING: + import cirq + + +class SimulationProductState( + Generic[TSimulationState], SimulationStateBase[TSimulationState], abc.Mapping +): + """A container for a `Qid`-to-`SimulationState` dictionary.""" + + def __init__( + self, + args: Dict[Optional['cirq.Qid'], TSimulationState], + qubits: Sequence['cirq.Qid'], + split_untangled_states: bool, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Initializes the class. + + Args: + args: The `SimulationState` dictionary. This will not be copied; the + original reference will be kept here. + qubits: The canonical ordering of qubits. + split_untangled_states: If True, optimizes operations by running + unentangled qubit sets independently and merging those states + at the end. + classical_data: The shared classical data container for this + simulation. + """ + classical_data = classical_data or value.ClassicalDataDictionaryStore() + super().__init__(qubits=qubits, classical_data=classical_data) + self._args = args + self._split_untangled_states = split_untangled_states + + @property + def args(self) -> Mapping[Optional['cirq.Qid'], TSimulationState]: + return self._args + + @property + def split_untangled_states(self) -> bool: + return self._split_untangled_states + + def create_merged_state(self) -> TSimulationState: + if not self.split_untangled_states: + return self.args[None] + final_args = self.args[None] + for args in set([self.args[k] for k in self.args.keys() if k is not None]): + final_args = final_args.kronecker_product(args) + return final_args.transpose_to_qubit_order(self.qubits) + + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> bool: + gate_opt = ( + action + if isinstance(action, ops.Gate) + else action.gate + if isinstance(action, ops.Operation) + else None + ) + + if isinstance(gate_opt, ops.IdentityGate): + return True + + if ( + isinstance(gate_opt, ops.SwapPowGate) + and gate_opt.exponent % 2 == 1 + and gate_opt.global_shift == 0 + ): + q0, q1 = qubits + args0 = self.args[q0] + args1 = self.args[q1] + if args0 is args1: + args0.swap(q0, q1, inplace=True) + else: + self._args[q0] = args1.rename(q1, q0, inplace=True) + self._args[q1] = args0.rename(q0, q1, inplace=True) + return True + + # Go through the op's qubits and join any disparate SimulationState states + # into a new combined state. + op_args_opt: Optional[TSimulationState] = None + for q in qubits: + if op_args_opt is None: + op_args_opt = self.args[q] + elif q not in op_args_opt.qubits: + op_args_opt = op_args_opt.kronecker_product(self.args[q]) + op_args = op_args_opt or self.args[None] + + # (Backfill the args map with the new value) + for q in op_args.qubits: + self._args[q] = op_args + + # Act on the args with the operation + act_on_qubits = qubits if isinstance(action, ops.Gate) else None + protocols.act_on(action, op_args, act_on_qubits, allow_decompose=allow_decompose) + + # Decouple any measurements or resets + if self.split_untangled_states and isinstance( + gate_opt, (ops.ResetChannel, ops.MeasurementGate) + ): + for q in qubits: + if op_args.allows_factoring: + q_args, op_args = op_args.factor((q,), validate=False) + self._args[q] = q_args + + # (Backfill the args map with the new value) + for q in op_args.qubits: + self._args[q] = op_args + return True + + def copy( + self, deep_copy_buffers: bool = True + ) -> 'cirq.SimulationProductState[TSimulationState]': + classical_data = self._classical_data.copy() + copies = {} + for sim_state in set(self.args.values()): + copies[sim_state] = sim_state.copy(deep_copy_buffers) + for copy in copies.values(): + copy._classical_data = classical_data + args = {q: copies[a] for q, a in self.args.items()} + return SimulationProductState( + args, self.qubits, self.split_untangled_states, classical_data=classical_data + ) + + def sample( + self, + qubits: List['cirq.Qid'], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + columns = [] + selected_order: List[ops.Qid] = [] + q_set = set(qubits) + for v in dict.fromkeys(self.args.values()): + qs = [q for q in v.qubits if q in q_set] + if any(qs): + column = v.sample(qs, repetitions, seed) + columns.append(column) + selected_order += qs + stacked = np.column_stack(columns) + qubit_map = {q: i for i, q in enumerate(selected_order)} + index_order = [qubit_map[q] for q in qubits] + return stacked[:, index_order] + + def __getitem__(self, item: Optional['cirq.Qid']) -> TSimulationState: + return self.args[item] + + def __len__(self) -> int: + return len(self.args) + + def __iter__(self) -> Iterator[Optional['cirq.Qid']]: + return iter(self.args) diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/simulation_product_state_test.py similarity index 94% rename from cirq-core/cirq/sim/act_on_args_container_test.py rename to cirq-core/cirq/sim/simulation_product_state_test.py index 92dde560fdb..ab8c5082950 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/simulation_product_state_test.py @@ -37,7 +37,7 @@ def reindex(self, axes): return self -class EmptyActOnArgs(cirq.ActOnArgs): +class EmptySimulationState(cirq.SimulationState): def __init__(self, qubits, classical_data): super().__init__(state=EmptyQuantumState(), qubits=qubits, classical_data=classical_data) @@ -53,19 +53,19 @@ def _act_on_fallback_( def create_container( qubits: Sequence['cirq.Qid'], split_untangled_states=True -) -> cirq.ActOnArgsContainer[EmptyActOnArgs]: - args_map: Dict[Optional['cirq.Qid'], EmptyActOnArgs] = {} +) -> cirq.SimulationProductState[EmptySimulationState]: + args_map: Dict[Optional['cirq.Qid'], EmptySimulationState] = {} log = cirq.ClassicalDataDictionaryStore() if split_untangled_states: for q in reversed(qubits): - args_map[q] = EmptyActOnArgs([q], log) - args_map[None] = EmptyActOnArgs((), log) + args_map[q] = EmptySimulationState([q], log) + args_map[None] = EmptySimulationState((), log) else: - args = EmptyActOnArgs(qubits, log) + args = EmptySimulationState(qubits, log) for q in qubits: args_map[q] = args - args_map[None] = args if not split_untangled_states else EmptyActOnArgs((), log) - return cirq.ActOnArgsContainer(args_map, qubits, split_untangled_states, classical_data=log) + args_map[None] = args if not split_untangled_states else EmptySimulationState((), log) + return cirq.SimulationProductState(args_map, qubits, split_untangled_states, classical_data=log) def test_entanglement_causes_join(): diff --git a/cirq-core/cirq/sim/simulation_state.py b/cirq-core/cirq/sim/simulation_state.py new file mode 100644 index 00000000000..401d6188777 --- /dev/null +++ b/cirq-core/cirq/sim/simulation_state.py @@ -0,0 +1,391 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Objects and methods for acting efficiently on a state tensor.""" +import abc +import copy +from typing import ( + Any, + cast, + Dict, + Generic, + Iterator, + List, + Optional, + Sequence, + TypeVar, + TYPE_CHECKING, + Tuple, +) + +import numpy as np + +from cirq import protocols, value +from cirq._compat import _warn_or_error, deprecated, deprecated_parameter +from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits +from cirq.sim.simulation_state_base import SimulationStateBase + +TSelf = TypeVar('TSelf', bound='SimulationState') +TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation') + +if TYPE_CHECKING: + import cirq + + +class SimulationState(SimulationStateBase, Generic[TState], metaclass=abc.ABCMeta): + """State and context for an operation acting on a state tensor.""" + + @deprecated_parameter( + deadline='v0.16', + fix='Use kwargs instead of positional args', + parameter_desc='args', + match=lambda args, kwargs: len(args) > 1, + ) + @deprecated_parameter( + deadline='v0.16', + fix='Replace log_of_measurement_results with' + ' classical_data=cirq.ClassicalDataDictionaryStore(_records=logs).', + parameter_desc='log_of_measurement_results', + match=lambda args, kwargs: 'log_of_measurement_results' in kwargs, + ) + def __init__( + self, + prng: Optional[np.random.RandomState] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + log_of_measurement_results: Optional[Dict[str, List[int]]] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + state: Optional[TState] = None, + ): + """Inits SimulationState. + + Args: + prng: The pseudo random number generator to use for probabilistic + effects. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + log_of_measurement_results: A mutable object that measurements are + being recorded into. + classical_data: The shared classical data container for this + simulation. + state: The underlying quantum state of the simulation. + """ + if qubits is None: + qubits = () + classical_data = classical_data or value.ClassicalDataDictionaryStore( + _records={ + value.MeasurementKey.parse_serialized(k): [tuple(v)] + for k, v in (log_of_measurement_results or {}).items() + } + ) + super().__init__(qubits=qubits, classical_data=classical_data) + if prng is None: + prng = cast(np.random.RandomState, np.random) + self._prng = prng + self._state = cast(TState, state) + if state is None: + _warn_or_error('This function will require a valid `state` input in cirq v0.16.') + + @property + def prng(self) -> np.random.RandomState: + return self._prng + + def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]): + """Measures the qubits and records to `log_of_measurement_results`. + + Any bitmasks will be applied to the measurement record. + + Args: + qubits: The qubits to measure. + key: The key the measurement result should be logged under. Note + that operations should only store results under keys they have + declared in a `_measurement_key_names_` method. + invert_mask: The invert mask for the measurement. + + Raises: + ValueError: If a measurement key has already been logged to a key. + """ + bits = self._perform_measurement(qubits) + corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)] + self._classical_data.record_measurement( + value.MeasurementKey.parse_serialized(key), corrected, qubits + ) + + def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]: + return [self.qubit_map[q] for q in qubits] + + def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: + """Delegates the call to measure the density matrix.""" + if self._state is not None: + return self._state.measure(self.get_axes(qubits), self.prng) + raise NotImplementedError() + + def sample( + self, + qubits: Sequence['cirq.Qid'], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + if self._state is not None: + return self._state.sample(self.get_axes(qubits), repetitions, seed) + raise NotImplementedError() + + def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf: + """Creates a copy of the object. + + Args: + deep_copy_buffers: If True, buffers will also be deep-copied. + Otherwise the copy will share a reference to the original object's + buffers. + + Returns: + A copied instance. + """ + args = copy.copy(self) + args._classical_data = self._classical_data.copy() + if self._state is not None: + args._state = self._state.copy(deep_copy_buffers=deep_copy_buffers) + else: + _warn_or_error( + 'Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.' + ' The `_on_` overrides will be removed in cirq v0.16.' + ) + self._on_copy(args, deep_copy_buffers) + return args + + @deprecated( + deadline='v0.16', + fix='Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.', + ) + def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True): + """Subclasses should implement this with any additional state copy + functionality.""" + + def create_merged_state(self: TSelf) -> TSelf: + """Creates a final merged state.""" + return self + + def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf: + """Joins two state spaces together.""" + args = self if inplace else copy.copy(self) + if self._state is not None and other._state is not None: + args._state = self._state.kron(other._state) + else: + _warn_or_error( + 'Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.' + ' The `_on_` overrides will be removed in cirq v0.16.' + ) + self._on_kronecker_product(other, args) + args._set_qubits(self.qubits + other.qubits) + return args + + @deprecated( + deadline='v0.16', + fix='Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.', + ) + def _on_kronecker_product(self: TSelf, other: TSelf, target: TSelf): + """Subclasses should implement this with any additional state product + functionality, if supported.""" + + def with_qubits(self: TSelf, qubits) -> TSelf: + """Extend current state space with added qubits. + + The state of the added qubits is the default value set in the + subclasses. A new state space is created as the Kronecker product of + the original one and the added one. + + Args: + qubits: The qubits to be added to the state space. + + Regurns: + A new subclass object containing the extended state space. + """ + new_space = type(self)(qubits=qubits) + return self.kronecker_product(new_space) + + def factor( + self: TSelf, qubits: Sequence['cirq.Qid'], *, validate=True, atol=1e-07, inplace=False + ) -> Tuple[TSelf, TSelf]: + """Splits two state spaces after a measurement or reset.""" + extracted = copy.copy(self) + remainder = self if inplace else copy.copy(self) + if self._state is not None: + e, r = self._state.factor(self.get_axes(qubits), validate=validate, atol=atol) + extracted._state = e + remainder._state = r + else: + _warn_or_error( + 'Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.' + ' The `_on_` overrides will be removed in cirq v0.16.' + ) + self._on_factor(qubits, extracted, remainder, validate, atol) + extracted._set_qubits(qubits) + remainder._set_qubits([q for q in self.qubits if q not in qubits]) + return extracted, remainder + + @property + def allows_factoring(self): + """Subclasses that allow factorization should override this.""" + return self._state.supports_factor if self._state is not None else False + + @deprecated( + deadline='v0.16', + fix='Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.', + ) + def _on_factor( + self: TSelf, + qubits: Sequence['cirq.Qid'], + extracted: TSelf, + remainder: TSelf, + validate=True, + atol=1e-07, + ): + """Subclasses should implement this with any additional state factor + functionality, if supported.""" + + def transpose_to_qubit_order( + self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False + ) -> TSelf: + """Physically reindexes the state by the new basis. + + Args: + qubits: The desired qubit order. + inplace: True to perform this operation inplace. + + Returns: + The state with qubit order transposed and underlying representation + updated. + + Raises: + ValueError: If the provided qubits do not match the existing ones. + """ + if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits): + raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}') + args = self if inplace else copy.copy(self) + if self._state is not None: + args._state = self._state.reindex(self.get_axes(qubits)) + else: + _warn_or_error( + 'Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.' + ' The `_on_` overrides will be removed in cirq v0.16.' + ) + self._on_transpose_to_qubit_order(qubits, args) + args._set_qubits(qubits) + return args + + @deprecated( + deadline='v0.16', + fix='Pass a `QuantumStateRepresentation` into the `SimulationState` constructor.', + ) + def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], target: TSelf): + """Subclasses should implement this with any additional state transpose + functionality, if supported.""" + + @property # type: ignore + @deprecated(deadline='v0.16', fix='Remove this call, it always returns False.') + def ignore_measurement_results(self) -> bool: + return False + + @property + def qubits(self) -> Tuple['cirq.Qid', ...]: + return self._qubits + + def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): + """Swaps two qubits. + + This only affects the index, and does not modify the underlying + state. + + Args: + q1: The first qubit to swap. + q2: The second qubit to swap. + inplace: True to swap the qubits in the current object, False to + create a copy with the qubits swapped. + + Returns: + The original object with the qubits swapped if inplace is + requested, or a copy of the original object with the qubits swapped + otherwise. + + Raises: + ValueError: If the qubits are of different dimensionality. + """ + if q1.dimension != q2.dimension: + raise ValueError(f'Cannot swap different dimensions: q1={q1}, q2={q2}') + + args = self if inplace else copy.copy(self) + i1 = self.qubits.index(q1) + i2 = self.qubits.index(q2) + qubits = list(args.qubits) + qubits[i1], qubits[i2] = qubits[i2], qubits[i1] + args._set_qubits(qubits) + return args + + def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): + """Renames `q1` to `q2`. + + Args: + q1: The qubit to rename. + q2: The new name. + inplace: True to rename the qubit in the current object, False to + create a copy with the qubit renamed. + + Returns: + The original object with the qubits renamed if inplace is + requested, or a copy of the original object with the qubits renamed + otherwise. + + Raises: + ValueError: If the qubits are of different dimensionality. + """ + if q1.dimension != q2.dimension: + raise ValueError(f'Cannot rename to different dimensions: q1={q1}, q2={q2}') + + args = self if inplace else copy.copy(self) + i1 = self.qubits.index(q1) + qubits = list(args.qubits) + qubits[i1] = q2 + args._set_qubits(qubits) + return args + + def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf: + if item not in self.qubit_map: + raise IndexError(f'{item} not in {self.qubits}') + return self + + def __len__(self) -> int: + return len(self.qubits) + + def __iter__(self) -> Iterator[Optional['cirq.Qid']]: + return iter(self.qubits) + + @property + def can_represent_mixed_states(self) -> bool: + return self._state.can_represent_mixed_states if self._state is not None else False + + +def strat_act_on_from_apply_decompose( + val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid'] +) -> bool: + operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val) + assert len(qubits1) == len(qubits) + qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)} + if operations is None: + return NotImplemented + for operation in operations: + operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits]) + protocols.act_on(operation, args) + return True + + +TSimulationState = TypeVar('TSimulationState', bound=SimulationState) diff --git a/cirq-core/cirq/sim/simulation_state_base.py b/cirq-core/cirq/sim/simulation_state_base.py new file mode 100644 index 00000000000..89033c7ac44 --- /dev/null +++ b/cirq-core/cirq/sim/simulation_state_base.py @@ -0,0 +1,134 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An interface for quantum states as targets for operations.""" +import abc +from typing import ( + Any, + Dict, + Generic, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) + +import numpy as np + +from cirq import protocols, value +from cirq.type_workarounds import NotImplementedType + +if TYPE_CHECKING: + import cirq + + +TSelfTarget = TypeVar('TSelfTarget', bound='SimulationStateBase') +TSimulationState = TypeVar('TSimulationState', bound='cirq.SimulationState') + + +class SimulationStateBase(Generic[TSimulationState], metaclass=abc.ABCMeta): + """An interface for quantum states as targets for operations.""" + + def __init__( + self, + *, + qubits: Sequence['cirq.Qid'], + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Initializes the class. + + Args: + qubits: The canonical ordering of qubits. + classical_data: The shared classical data container for this + simulation. + """ + self._set_qubits(tuple(qubits)) + self._classical_data = classical_data or value.ClassicalDataDictionaryStore() + + @property + def qubits(self) -> Tuple['cirq.Qid', ...]: + return self._qubits + + @property + def qubit_map(self) -> Mapping['cirq.Qid', int]: + return self._qubit_map + + def _set_qubits(self, qubits: Sequence['cirq.Qid']): + self._qubits = tuple(qubits) + self._qubit_map = {q: i for i, q in enumerate(self.qubits)} + + @property + def classical_data(self) -> 'cirq.ClassicalDataStoreReader': + return self._classical_data + + @abc.abstractmethod + def create_merged_state(self) -> TSimulationState: + """Creates a final merged state.""" + + @abc.abstractmethod + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> Union[bool, NotImplementedType]: + """Handles the act_on protocol fallback implementation. + + Args: + action: A gate, operation, or other to act on. + qubits: The applicable qubits if a gate is passed as the action. + allow_decompose: Flag to allow decomposition. + + Returns: + True if the fallback applies, else NotImplemented.""" + + def apply_operation(self, op: 'cirq.Operation'): + protocols.act_on(op, self) + + @abc.abstractmethod + def copy(self: TSelfTarget, deep_copy_buffers: bool = True) -> TSelfTarget: + """Creates a copy of the object. + + Args: + deep_copy_buffers: If True, buffers will also be deep-copied. + Otherwise the copy will share a reference to the original object's + buffers. + + Returns: + A copied instance. + """ + + @property + def log_of_measurement_results(self) -> Dict[str, List[int]]: + """Gets the log of measurement results.""" + return {str(k): list(self.classical_data.get_digits(k)) for k in self.classical_data.keys()} + + @abc.abstractmethod + def sample( + self, + qubits: List['cirq.Qid'], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + """Samples the state value.""" + + def __getitem__(self, item: Optional['cirq.Qid']) -> TSimulationState: + """Gets the item associated with the qubit.""" + + def __len__(self) -> int: + """Gets the number of items in the mapping.""" + + def __iter__(self) -> Iterator[Optional['cirq.Qid']]: + """Iterates the keys of the mapping.""" diff --git a/cirq-core/cirq/sim/act_on_args_test.py b/cirq-core/cirq/sim/simulation_state_test.py similarity index 89% rename from cirq-core/cirq/sim/act_on_args_test.py rename to cirq-core/cirq/sim/simulation_state_test.py index a8b9a4c51ad..6c64ceee268 100644 --- a/cirq-core/cirq/sim/act_on_args_test.py +++ b/cirq-core/cirq/sim/simulation_state_test.py @@ -18,7 +18,7 @@ import pytest import cirq -from cirq.sim import act_on_args +from cirq.sim import simulation_state class DummyQuantumState(cirq.QuantumStateRepresentation): @@ -32,7 +32,7 @@ def reindex(self, axes): return self -class DummyArgs(cirq.ActOnArgs): +class DummySimulationState(cirq.SimulationState): def __init__(self): super().__init__(state=DummyQuantumState(), qubits=cirq.LineQubit.range(2)) @@ -43,7 +43,7 @@ def _act_on_fallback_( def test_measurements(): - args = DummyArgs() + args = DummySimulationState() args.measure([cirq.LineQubit(0)], "test", [False]) assert args.log_of_measurement_results["test"] == [5] @@ -56,12 +56,14 @@ def num_qubits(self) -> int: def _decompose_(self, qubits): yield cirq.X(*qubits) - args = DummyArgs() - assert act_on_args.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)]) + args = DummySimulationState() + assert simulation_state.strat_act_on_from_apply_decompose( + Composite(), args, [cirq.LineQubit(0)] + ) def test_mapping(): - args = DummyArgs() + args = DummySimulationState() assert list(iter(args)) == cirq.LineQubit.range(2) r1 = args[cirq.LineQubit(0)] assert args is r1 @@ -72,7 +74,7 @@ def test_mapping(): def test_swap_bad_dimensions(): q0 = cirq.LineQubit(0) q1 = cirq.LineQid(1, 3) - args = DummyArgs() + args = DummySimulationState() with pytest.raises(ValueError, match='Cannot swap different dimensions'): args.swap(q0, q1) @@ -80,14 +82,14 @@ def test_swap_bad_dimensions(): def test_rename_bad_dimensions(): q0 = cirq.LineQubit(0) q1 = cirq.LineQid(1, 3) - args = DummyArgs() + args = DummySimulationState() with pytest.raises(ValueError, match='Cannot rename to different dimensions'): args.rename(q0, q1) def test_transpose_qubits(): q0, q1, q2 = cirq.LineQubit.range(3) - args = DummyArgs() + args = DummySimulationState() assert args.transpose_to_qubit_order((q1, q0)).qubits == (q1, q0) with pytest.raises(ValueError, match='Qubits do not match'): args.transpose_to_qubit_order((q0, q2)) @@ -96,7 +98,7 @@ def test_transpose_qubits(): def test_field_getters(): - args = DummyArgs() + args = DummySimulationState() assert args.prng is np.random assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))} with cirq.testing.assert_deprecated('always returns False', deadline='v0.16'): @@ -104,7 +106,7 @@ def test_field_getters(): def test_on_methods_deprecated(): - class OldStyleArgs(cirq.ActOnArgs): + class OldStyleArgs(cirq.SimulationState): def _act_on_fallback_(self, action, qubits, allow_decompose=True): pass @@ -121,7 +123,7 @@ def _act_on_fallback_(self, action, qubits, allow_decompose=True): def test_on_methods_deprecated_if_implemented(): - class OldStyleArgs(cirq.ActOnArgs): + class OldStyleArgs(cirq.SimulationState): def _act_on_fallback_(self, action, qubits, allow_decompose=True): pass @@ -150,7 +152,7 @@ def _on_transpose_to_qubit_order(self, qubits, target): def test_deprecated(): - class DeprecatedArgs(cirq.ActOnArgs): + class DeprecatedArgs(cirq.SimulationState): def _act_on_fallback_(self, action, qubits, allow_decompose=True): pass diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index 2d9a160474d..1c9ed276151 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -50,7 +50,7 @@ import numpy as np from cirq import _compat, circuits, ops, protocols, study, value, work -from cirq.sim.operation_target import OperationTarget +from cirq.sim.simulation_state_base import SimulationStateBase if TYPE_CHECKING: import cirq @@ -586,7 +586,7 @@ def simulate_sweep_iter( is often used in specifying the initial state, i.e. the ordering of the computational basis states. initial_state: The initial state for the simulation. This can be - either a raw state or an `OperationTarget`. The form of the + either a raw state or an `SimulationStateBase`. The form of the raw state depends on the simulation implementation. See documentation of the implementing class for details. @@ -598,7 +598,7 @@ def simulate_sweep_iter( for param_resolver in study.to_resolvers(params): state = ( initial_state.copy() - if isinstance(initial_state, OperationTarget) + if isinstance(initial_state, SimulationStateBase) else initial_state ) all_step_results = self.simulate_moment_steps( @@ -643,7 +643,7 @@ def simulate_moment_steps( is often used in specifying the initial state, i.e. the ordering of the computational basis states. initial_state: The initial state for the simulation. This can be - either a raw state or a `TActOnArgs`. The form of the + either a raw state or a `TSimulationState`. The form of the raw state depends on the simulation implementation. See documentation of the implementing class for details. @@ -684,8 +684,8 @@ def _base_iterator( StepResults from simulating a Moment of the Circuit. """ qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits()) - act_on_args = self._create_act_on_args(initial_state, qubits) - return self._core_iterator(circuit, act_on_args) + sim_state = self._create_act_on_args(initial_state, qubits) + return self._core_iterator(circuit, sim_state) @abc.abstractmethod def _create_act_on_args( diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index e147a708c1a..74a1fba4f84 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -34,10 +34,10 @@ import numpy as np from cirq import ops, protocols, study, value, devices -from cirq.sim import ActOnArgsContainer from cirq.sim import simulator -from cirq.sim.act_on_args import TActOnArgs -from cirq.sim.operation_target import OperationTarget +from cirq.sim.simulation_product_state import SimulationProductState +from cirq.sim.simulation_state import TSimulationState +from cirq.sim.simulation_state_base import SimulationStateBase from cirq.sim.simulator import ( TSimulationTrialResult, SimulatesIntermediateState, @@ -56,9 +56,9 @@ class SimulatorBase( - Generic[TStepResultBase, TSimulationTrialResult, TActOnArgs], + Generic[TStepResultBase, TSimulationTrialResult, TSimulationState], SimulatesIntermediateState[ - TStepResultBase, TSimulationTrialResult, OperationTarget[TActOnArgs] + TStepResultBase, TSimulationTrialResult, SimulationStateBase[TSimulationState] ], SimulatesSamples, metaclass=abc.ABCMeta, @@ -119,8 +119,8 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], classical_data: 'cirq.ClassicalDataStore', - ) -> TActOnArgs: - """Creates an instance of the TActOnArgs class for the simulator. + ) -> TSimulationState: + """Creates an instance of the TSimulationState class for the simulator. It represents the supplied qubits initialized to the provided state. @@ -134,11 +134,13 @@ def _create_partial_act_on_args( """ @abc.abstractmethod - def _create_step_result(self, sim_state: OperationTarget[TActOnArgs]) -> TStepResultBase: + def _create_step_result( + self, sim_state: SimulationStateBase[TSimulationState] + ) -> TStepResultBase: """This method should be implemented to create a step result. Args: - sim_state: The OperationTarget for this trial. + sim_state: The SimulationStateBase for this trial. Returns: The StepResult. @@ -169,7 +171,7 @@ def _can_be_in_run_prefix(self, val: Any): def _core_iterator( self, circuit: 'cirq.AbstractCircuit', - sim_state: OperationTarget[TActOnArgs], + sim_state: SimulationStateBase[TSimulationState], all_measurements_are_terminal: bool = False, ) -> Iterator[TStepResultBase]: """Standard iterator over StepResult from Moments of a Circuit. @@ -224,7 +226,7 @@ def _run( resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) check_all_resolved(resolved_circuit) qubits = tuple(sorted(resolved_circuit.all_qubits())) - act_on_args = self._create_act_on_args(0, qubits) + sim_state = self._create_act_on_args(0, qubits) prefix, general_suffix = ( split_into_matching_protocol_then_general(resolved_circuit, self._can_be_in_run_prefix) @@ -232,13 +234,13 @@ def _run( else (resolved_circuit[0:0], resolved_circuit) ) step_result = None - for step_result in self._core_iterator(circuit=prefix, sim_state=act_on_args): + for step_result in self._core_iterator(circuit=prefix, sim_state=sim_state): pass general_ops = list(general_suffix.all_operations()) if all(isinstance(op.gate, ops.MeasurementGate) for op in general_ops): for step_result in self._core_iterator( - circuit=general_suffix, sim_state=act_on_args, all_measurements_are_terminal=True + circuit=general_suffix, sim_state=sim_state, all_measurements_are_terminal=True ): pass assert step_result is not None @@ -251,9 +253,9 @@ def _run( for i in range(repetitions): for step_result in self._core_iterator( general_suffix, - sim_state=act_on_args.copy(deep_copy_buffers=False) + sim_state=sim_state.copy(deep_copy_buffers=False) if i < repetitions - 1 - else act_on_args, + else sim_state, ): pass for k, r in step_result._classical_data.records.items(): @@ -286,7 +288,7 @@ def simulate_sweep_iter( is often used in specifying the initial state, i.e. the ordering of the computational basis states. initial_state: The initial state for the simulation. This can be - either a raw state or an `OperationTarget`. The form of the + either a raw state or an `SimulationStateBase`. The form of the raw state depends on the simulation implementation. See documentation of the implementing class for details. @@ -314,13 +316,13 @@ def sweep_prefixable(op: 'cirq.Operation'): def _create_act_on_args( self, initial_state: Any, qubits: Sequence['cirq.Qid'] - ) -> OperationTarget[TActOnArgs]: - if isinstance(initial_state, OperationTarget): + ) -> SimulationStateBase[TSimulationState]: + if isinstance(initial_state, SimulationStateBase): return initial_state classical_data = value.ClassicalDataDictionaryStore() if self._split_untangled_states: - args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {} + args_map: Dict[Optional['cirq.Qid'], TSimulationState] = {} if isinstance(initial_state, int): for q in reversed(qubits): args_map[q] = self._create_partial_act_on_args( @@ -336,7 +338,7 @@ def _create_act_on_args( for q in qubits: args_map[q] = args args_map[None] = self._create_partial_act_on_args(0, (), classical_data) - return ActOnArgsContainer( + return SimulationProductState( args_map, qubits, self._split_untangled_states, classical_data=classical_data ) else: @@ -345,17 +347,19 @@ def _create_act_on_args( ) -class StepResultBase(Generic[TActOnArgs], StepResult[OperationTarget[TActOnArgs]], abc.ABC): +class StepResultBase( + Generic[TSimulationState], StepResult[SimulationStateBase[TSimulationState]], abc.ABC +): """A base class for step results.""" - def __init__(self, sim_state: OperationTarget[TActOnArgs]): + def __init__(self, sim_state: SimulationStateBase[TSimulationState]): """Initializes the step result. Args: - sim_state: The `OperationTarget` for this step. + sim_state: The `SimulationStateBase` for this step. """ super().__init__(sim_state) - self._merged_sim_state_cache: Optional[TActOnArgs] = None + self._merged_sim_state_cache: Optional[TSimulationState] = None qubits = sim_state.qubits self._qubits = qubits self._qubit_mapping = {q: i for i, q in enumerate(qubits)} @@ -366,7 +370,7 @@ def _qid_shape_(self): return self._qubit_shape @property - def _merged_sim_state(self) -> TActOnArgs: + def _merged_sim_state(self) -> TSimulationState: if self._merged_sim_state_cache is None: self._merged_sim_state_cache = self._sim_state.create_merged_state() return self._merged_sim_state_cache @@ -381,7 +385,7 @@ def sample( class SimulationTrialResultBase( - SimulationTrialResult[OperationTarget[TActOnArgs]], Generic[TActOnArgs], abc.ABC + SimulationTrialResult[SimulationStateBase[TSimulationState]], Generic[TSimulationState], abc.ABC ): """A base class for trial results.""" @@ -390,7 +394,7 @@ def __init__( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', + final_simulator_state: 'cirq.SimulationStateBase[TSimulationState]', ) -> None: """Initializes the `SimulationTrialResultBase` class. @@ -404,9 +408,9 @@ def __init__( trial finishes. """ super().__init__(params, measurements, final_simulator_state=final_simulator_state) - self._merged_sim_state_cache: Optional[TActOnArgs] = None + self._merged_sim_state_cache: Optional[TSimulationState] = None - def get_state_containing_qubit(self, qubit: 'cirq.Qid') -> TActOnArgs: + def get_state_containing_qubit(self, qubit: 'cirq.Qid') -> TSimulationState: """Returns the independent state space containing the qubit. Args: @@ -416,17 +420,17 @@ def get_state_containing_qubit(self, qubit: 'cirq.Qid') -> TActOnArgs: The state space containing the qubit.""" return self._final_simulator_state[qubit] - def _get_substates(self) -> Sequence[TActOnArgs]: + def _get_substates(self) -> Sequence[TSimulationState]: state = self._final_simulator_state - if isinstance(state, ActOnArgsContainer): - substates: Dict[TActOnArgs, int] = {} + if isinstance(state, SimulationProductState): + substates: Dict[TSimulationState, int] = {} for q in state.qubits: substates[self.get_state_containing_qubit(q)] = 0 substates[state[None]] = 0 return tuple(substates.keys()) return [state.create_merged_state()] - def _get_merged_sim_state(self) -> TActOnArgs: + def _get_merged_sim_state(self) -> TSimulationState: if self._merged_sim_state_cache is None: self._merged_sim_state_cache = self._final_simulator_state.create_merged_state() return self._merged_sim_state_cache diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 47e3170fc24..ec2042b673b 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -56,7 +56,7 @@ def copy(self, deep_copy_buffers: bool = True) -> 'CountingState': ) -class CountingActOnArgs(cirq.ActOnArgs[CountingState]): +class CountingSimulationState(cirq.SimulationState[CountingState]): def __init__(self, state, qubits, classical_data): state_obj = CountingState(state) super().__init__(state=state_obj, qubits=qubits, classical_data=classical_data) @@ -80,13 +80,13 @@ def measurement_count(self): return self._state.measurement_count -class SplittableCountingActOnArgs(CountingActOnArgs): +class SplittableCountingSimulationState(CountingSimulationState): @property def allows_factoring(self): return True -class CountingStepResult(cirq.StepResultBase[CountingActOnArgs]): +class CountingStepResult(cirq.StepResultBase[CountingSimulationState]): def sample( self, qubits: List[cirq.Qid], @@ -98,16 +98,16 @@ def sample( measurements.append(self._merged_sim_state._perform_measurement(qubits)) return np.array(measurements, dtype=int) - def _simulator_state(self) -> CountingActOnArgs: + def _simulator_state(self) -> CountingSimulationState: return self._merged_sim_state -class CountingTrialResult(cirq.SimulationTrialResultBase[CountingActOnArgs]): +class CountingTrialResult(cirq.SimulationTrialResultBase[CountingSimulationState]): pass class CountingSimulator( - cirq.SimulatorBase[CountingStepResult, CountingTrialResult, CountingActOnArgs] + cirq.SimulatorBase[CountingStepResult, CountingTrialResult, CountingSimulationState] ): def __init__(self, noise=None, split_untangled_states=False): super().__init__(noise=noise, split_untangled_states=split_untangled_states) @@ -117,21 +117,23 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], classical_data: cirq.ClassicalDataStore, - ) -> CountingActOnArgs: - return CountingActOnArgs(qubits=qubits, state=initial_state, classical_data=classical_data) + ) -> CountingSimulationState: + return CountingSimulationState( + qubits=qubits, state=initial_state, classical_data=classical_data + ) def _create_simulator_trial_result( self, params: cirq.ParamResolver, measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[CountingActOnArgs]', + final_simulator_state: 'cirq.SimulationStateBase[CountingSimulationState]', ) -> CountingTrialResult: return CountingTrialResult( params, measurements, final_simulator_state=final_simulator_state ) def _create_step_result( - self, sim_state: cirq.OperationTarget[CountingActOnArgs] + self, sim_state: cirq.SimulationStateBase[CountingSimulationState] ) -> CountingStepResult: return CountingStepResult(sim_state) @@ -145,8 +147,8 @@ def _create_partial_act_on_args( initial_state: Any, qubits: Sequence['cirq.Qid'], classical_data: cirq.ClassicalDataStore, - ) -> CountingActOnArgs: - return SplittableCountingActOnArgs( + ) -> CountingSimulationState: + return SplittableCountingSimulationState( qubits=qubits, state=initial_state, classical_data=classical_data ) @@ -205,7 +207,7 @@ def test_noise_applied_measurement_gate(): def test_cannot_act(): class BadOp(TestOp): - def _act_on_(self, args): + def _act_on_(self, sim_state): raise TypeError() sim = CountingSimulator() @@ -260,7 +262,7 @@ def test_integer_initial_state_is_split(): def test_integer_initial_state_is_not_split_if_disabled(): sim = SplittableCountingSimulator(split_untangled_states=False) args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, SplittableCountingActOnArgs) + assert isinstance(args, SplittableCountingSimulationState) assert args[q0] is args[q1] assert args.state == 2 @@ -268,8 +270,8 @@ def test_integer_initial_state_is_not_split_if_disabled(): def test_integer_initial_state_is_not_split_if_impossible(): sim = CountingSimulator() args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, CountingActOnArgs) - assert not isinstance(args, SplittableCountingActOnArgs) + assert isinstance(args, CountingSimulationState) + assert not isinstance(args, SplittableCountingSimulationState) assert args[q0] is args[q1] assert args.state == 2 @@ -306,20 +308,20 @@ def test_measurement_causes_split(): def test_measurement_does_not_split_if_disabled(): sim = SplittableCountingSimulator(split_untangled_states=False) args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, SplittableCountingActOnArgs) + assert isinstance(args, SplittableCountingSimulationState) args.apply_operation(cirq.measure(q0)) - assert isinstance(args, SplittableCountingActOnArgs) + assert isinstance(args, SplittableCountingSimulationState) assert args[q0] is args[q1] def test_measurement_does_not_split_if_impossible(): sim = CountingSimulator() args = sim._create_act_on_args(2, (q0, q1)) - assert isinstance(args, CountingActOnArgs) - assert not isinstance(args, SplittableCountingActOnArgs) + assert isinstance(args, CountingSimulationState) + assert not isinstance(args, SplittableCountingSimulationState) args.apply_operation(cirq.measure(q0)) - assert isinstance(args, CountingActOnArgs) - assert not isinstance(args, SplittableCountingActOnArgs) + assert isinstance(args, CountingSimulationState) + assert not isinstance(args, SplittableCountingSimulationState) assert args[q0] is args[q1] @@ -360,7 +362,7 @@ def __init__(self, *, has_unitary: bool): self.count = 0 self.has_unitary = has_unitary - def _act_on_(self, args): + def _act_on_(self, sim_state): self.count += 1 return True diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index 52b886196b8..9352d358a83 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -22,7 +22,7 @@ import cirq from cirq import study -from cirq.sim.act_on_args import TActOnArgs +from cirq.sim.simulation_state import TSimulationState from cirq.sim.simulator import ( TStepResult, SimulatesAmplitudes, @@ -63,8 +63,8 @@ def sample(self, qubits, repetitions=1, seed=None): class SimulatesIntermediateStateImpl( - Generic[TStepResult, TActOnArgs], - SimulatesIntermediateState[TStepResult, 'SimulationTrialResult', TActOnArgs], + Generic[TStepResult, TSimulationState], + SimulatesIntermediateState[TStepResult, 'SimulationTrialResult', TSimulationState], metaclass=abc.ABCMeta, ): """A SimulatesIntermediateState that uses the default SimulationTrialResult type.""" @@ -73,7 +73,7 @@ def _create_simulator_trial_result( self, params: study.ParamResolver, measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[TActOnArgs]', + final_simulator_state: 'cirq.SimulationStateBase[TSimulationState]', ) -> 'SimulationTrialResult': """This method creates a default trial result. diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index d95182f4794..fbc8eeae29f 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -20,7 +20,7 @@ from cirq import ops from cirq._compat import deprecated_parameter -from cirq.sim import simulator, state_vector, state_vector_simulator, act_on_state_vector_args +from cirq.sim import simulator, state_vector, state_vector_simulator, state_vector_simulation_state if TYPE_CHECKING: import cirq @@ -155,11 +155,11 @@ def __init__( def _create_partial_act_on_args( self, - initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.ActOnStateVectorArgs'], + initial_state: Union['cirq.STATE_VECTOR_LIKE', 'cirq.StateVectorSimulationState'], qubits: Sequence['cirq.Qid'], classical_data: 'cirq.ClassicalDataStore', ): - """Creates the ActOnStateVectorArgs for a circuit. + """Creates the StateVectorSimulationState for a circuit. Args: initial_state: The initial state for the simulation in the @@ -171,12 +171,12 @@ def _create_partial_act_on_args( simulation. Returns: - ActOnStateVectorArgs for the circuit. + StateVectorSimulationState for the circuit. """ - if isinstance(initial_state, act_on_state_vector_args.ActOnStateVectorArgs): + if isinstance(initial_state, state_vector_simulation_state.StateVectorSimulationState): return initial_state - return act_on_state_vector_args.ActOnStateVectorArgs( + return state_vector_simulation_state.StateVectorSimulationState( qubits=qubits, prng=self._prng, classical_data=classical_data, @@ -184,7 +184,9 @@ def _create_partial_act_on_args( dtype=self._dtype, ) - def _create_step_result(self, sim_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]'): + def _create_step_result( + self, sim_state: 'cirq.SimulationStateBase[cirq.StateVectorSimulationState]' + ): return SparseSimulatorStep(sim_state=sim_state, dtype=self._dtype) def simulate_expectation_values_sweep_iter( @@ -228,14 +230,14 @@ class SparseSimulatorStep( ) def __init__( self, - sim_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', + sim_state: 'cirq.SimulationStateBase[cirq.StateVectorSimulationState]', simulator: 'cirq.Simulator' = None, dtype: 'DTypeLike' = np.complex64, ): """Results of a step of the simulator. Args: - sim_state: The qubit:ActOnArgs lookup for this step. + sim_state: The qubit:SimulationState lookup for this step. simulator: The simulator used to create this. dtype: The `numpy.dtype` used by the simulation. One of `numpy.complex64` or `numpy.complex128`. diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index cd962051148..3329dc8135d 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -465,7 +465,7 @@ def test_simulate_initial_state(dtype: Type[np.number], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_act_on_args(dtype: Type[np.number], split: bool): +def test_simulation_state(dtype: Type[np.number], split: bool): q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -749,7 +749,7 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs): def test_simulator_step_state_mixin(): qubits = cirq.LineQubit.range(2) - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.array([0, 1, 0, 0]).reshape((2, 2)), prng=cirq.value.parse_random_state(0), qubits=qubits, @@ -1343,7 +1343,7 @@ def test_nondeterministic_mixture_noise(): assert result1 != result2 -def test_act_on_args_pure_state_creation(): +def test_pure_state_creation(): sim = cirq.Simulator() qids = cirq.LineQubit.range(3) shape = cirq.qid_shape(qids) diff --git a/cirq-core/cirq/sim/state_vector_simulation_state.py b/cirq-core/cirq/sim/state_vector_simulation_state.py new file mode 100644 index 00000000000..1111538f5ee --- /dev/null +++ b/cirq-core/cirq/sim/state_vector_simulation_state.py @@ -0,0 +1,501 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Objects and methods for acting efficiently on a state vector.""" + +from typing import Any, Callable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union + +import numpy as np + +from cirq import _compat, linalg, protocols, qis, sim +from cirq._compat import proper_repr +from cirq.linalg import transformations +from cirq.sim.simulation_state import SimulationState, strat_act_on_from_apply_decompose + +if TYPE_CHECKING: + import cirq + from numpy.typing import DTypeLike + + +class _BufferedStateVector(qis.QuantumStateRepresentation): + """Contains the state vector and buffer for efficient state evolution.""" + + def __init__(self, state_vector: np.ndarray, buffer: Optional[np.ndarray] = None): + """Initializes the object with the inputs. + + This initializer creates the buffer if necessary. + + Args: + state_vector: The state vector, must be correctly formatted. The data is not checked + for validity here due to performance concerns. + buffer: Optional, must be same shape as the state vector. If not provided, a buffer + will be created automatically. + """ + self._state_vector = state_vector + if buffer is None: + buffer = np.empty_like(state_vector) + self._buffer = buffer + self._qid_shape = state_vector.shape + + @classmethod + def create( + cls, + *, + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, + qid_shape: Optional[Tuple[int, ...]] = None, + dtype: Optional['DTypeLike'] = None, + buffer: Optional[List[np.ndarray]] = None, + ): + """Initializes the object with the inputs. + + This initializer creates the buffer if necessary. + + Args: + initial_state: The density matrix, must be correctly formatted. The data is not + checked for validity here due to performance concerns. + qid_shape: The shape of the density matrix, if the initial state is provided as an int. + dtype: The dtype of the density matrix, if the initial state is provided as an int. + buffer: Optional, must be length 3 and same shape as the density matrix. If not + provided, a buffer will be created automatically. + Raises: + ValueError: If initial state is provided as integer, but qid_shape is not provided. + """ + if not isinstance(initial_state, np.ndarray): + if qid_shape is None: + raise ValueError('qid_shape must be provided if initial_state is not ndarray') + state_vector = qis.to_valid_state_vector( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape) + else: + if qid_shape is not None: + state_vector = initial_state.reshape(qid_shape) + else: + state_vector = initial_state + if np.may_share_memory(state_vector, initial_state): + state_vector = state_vector.copy() + state_vector = state_vector.astype(dtype, copy=False) + return cls(state_vector, buffer) + + def copy(self, deep_copy_buffers: bool = True) -> '_BufferedStateVector': + """Copies the object. + + Args: + deep_copy_buffers: True by default, False to reuse the existing buffers. + Returns: + A copy of the object. + """ + return _BufferedStateVector( + state_vector=self._state_vector.copy(), + buffer=self._buffer.copy() if deep_copy_buffers else self._buffer, + ) + + def kron(self, other: '_BufferedStateVector') -> '_BufferedStateVector': + """Creates the Kronecker product with the other state vector. + + Args: + other: The state vector with which to kron. + Returns: + The Kronecker product of the two state vectors. + """ + target_tensor = transformations.state_vector_kronecker_product( + self._state_vector, other._state_vector + ) + return _BufferedStateVector(state_vector=target_tensor, buffer=np.empty_like(target_tensor)) + + def factor( + self, axes: Sequence[int], *, validate=True, atol=1e-07 + ) -> Tuple['_BufferedStateVector', '_BufferedStateVector']: + """Factors a state vector into two independent state vectors. + + This function should only be called on state vectors that are known to be separable, such + as immediately after a measurement or reset operation. It does not verify that the provided + state vector is indeed separable, and will return nonsense results for vectors + representing entangled states. + + Args: + axes: The axes to factor out. + validate: Perform a validation that the state vector factors cleanly. + atol: The absolute tolerance for the validation. + + Returns: + A tuple with the `(extracted, remainder)` state vectors, where `extracted` means the + sub-state vector which corresponds to the axes requested, and with the axes in the + requested order, and where `remainder` means the sub-state vector on the remaining + axes, in the same order as the original state vector. + """ + extracted_tensor, remainder_tensor = transformations.factor_state_vector( + self._state_vector, axes, validate=validate, atol=atol + ) + extracted = _BufferedStateVector( + state_vector=extracted_tensor, buffer=np.empty_like(extracted_tensor) + ) + remainder = _BufferedStateVector( + state_vector=remainder_tensor, buffer=np.empty_like(remainder_tensor) + ) + return extracted, remainder + + def reindex(self, axes: Sequence[int]) -> '_BufferedStateVector': + """Transposes the axes of a state vector to a specified order. + + Args: + axes: The desired axis order. + Returns: + The transposed state vector. + """ + new_tensor = transformations.transpose_state_vector_to_axis_order(self._state_vector, axes) + return _BufferedStateVector(state_vector=new_tensor, buffer=np.empty_like(new_tensor)) + + def apply_unitary(self, action: Any, axes: Sequence[int]) -> bool: + """Apply unitary to state. + + Args: + action: The value with a unitary to apply. + axes: The axes on which to apply the unitary. + Returns: + True if the operation succeeded. + """ + new_target_tensor = protocols.apply_unitary( + action, + protocols.ApplyUnitaryArgs( + target_tensor=self._state_vector, available_buffer=self._buffer, axes=axes + ), + allow_decompose=False, + default=NotImplemented, + ) + if new_target_tensor is NotImplemented: + return False + self._swap_target_tensor_for(new_target_tensor) + return True + + def apply_mixture(self, action: Any, axes: Sequence[int], prng) -> Optional[int]: + """Apply mixture to state. + + Args: + action: The value with a mixture to apply. + axes: The axes on which to apply the mixture. + prng: The pseudo random number generator to use. + Returns: + The mixture index if the operation succeeded, otherwise None. + """ + mixture = protocols.mixture(action, default=None) + if mixture is None: + return None + probabilities, unitaries = zip(*mixture) + + index = prng.choice(range(len(unitaries)), p=probabilities) + shape = protocols.qid_shape(action) * 2 + unitary = unitaries[index].astype(self._state_vector.dtype).reshape(shape) + linalg.targeted_left_multiply(unitary, self._state_vector, axes, out=self._buffer) + self._swap_target_tensor_for(self._buffer) + return index + + def apply_channel(self, action: Any, axes: Sequence[int], prng) -> Optional[int]: + """Apply channel to state. + + Args: + action: The value with a channel to apply. + axes: The axes on which to apply the channel. + prng: The pseudo random number generator to use. + Returns: + The kraus index if the operation succeeded, otherwise None. + """ + kraus_operators = protocols.kraus(action, default=None) + if kraus_operators is None: + return None + + def prepare_into_buffer(k: int): + linalg.targeted_left_multiply( + left_matrix=kraus_tensors[k], + right_target=self._state_vector, + target_axes=axes, + out=self._buffer, + ) + + shape = protocols.qid_shape(action) + kraus_tensors = [ + e.reshape(shape * 2).astype(self._state_vector.dtype) for e in kraus_operators + ] + p = prng.random() + weight = None + fallback_weight = 0 + fallback_weight_index = 0 + index = None + for index in range(len(kraus_tensors)): + prepare_into_buffer(index) + weight = np.linalg.norm(self._buffer) ** 2 + + if weight > fallback_weight: + fallback_weight_index = index + fallback_weight = weight + + p -= weight + if p < 0: + break + + assert weight is not None, "No Kraus operators" + if p >= 0 or weight == 0: + # Floating point error resulted in a malformed sample. + # Fall back to the most likely case. + prepare_into_buffer(fallback_weight_index) + weight = fallback_weight + index = fallback_weight_index + + self._buffer /= np.sqrt(weight) + self._swap_target_tensor_for(self._buffer) + return index + + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the state vector. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in order. + """ + bits, _ = sim.measure_state_vector( + self._state_vector, axes, out=self._state_vector, qid_shape=self._qid_shape, seed=seed + ) + return bits + + def sample( + self, + axes: Sequence[int], + repetitions: int = 1, + seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, + ) -> np.ndarray: + """Samples the state vector. + + Args: + axes: The axes to sample. + repetitions: The number of samples to make. + seed: The random number seed to use. + Returns: + The samples in order. + """ + return sim.sample_state_vector( + self._state_vector, axes, qid_shape=self._qid_shape, repetitions=repetitions, seed=seed + ) + + def _swap_target_tensor_for(self, new_target_tensor: np.ndarray): + """Gives a new state vector for the system. + + Typically, the new state vector should be `args.available_buffer` where + `args` is this `cirq.StateVectorSimulationState` instance. + + Args: + new_target_tensor: The new system state. Must have the same shape + and dtype as the old system state. + """ + if new_target_tensor is self._buffer: + self._buffer = self._state_vector + self._state_vector = new_target_tensor + + @property + def supports_factor(self) -> bool: + return True + + +class StateVectorSimulationState(SimulationState[_BufferedStateVector]): + """State and context for an operation acting on a state vector. + + There are two common ways to act on this object: + + 1. Directly edit the `target_tensor` property, which is storing the state + vector of the quantum system as a numpy array with one axis per qudit. + 2. Overwrite the `available_buffer` property with the new state vector, and + then pass `available_buffer` into `swap_target_tensor_for`. + """ + + def __init__( + self, + *, + available_buffer: Optional[np.ndarray] = None, + prng: Optional[np.random.RandomState] = None, + qubits: Optional[Sequence['cirq.Qid']] = None, + initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0, + dtype: Type[np.number] = np.complex64, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Inits StateVectorSimulationState. + + Args: + available_buffer: A workspace with the same shape and dtype as + `target_tensor`. Used by operations that cannot be applied to + `target_tensor` inline, in order to avoid unnecessary + allocations. Passing `available_buffer` into + `swap_target_tensor_for` will swap it for `target_tensor`. + qubits: Determines the canonical ordering of the qubits. This + is often used in specifying the initial state, i.e. the + ordering of the computational basis states. + prng: The pseudo random number generator to use for probabilistic + effects. + initial_state: The initial state for the simulation in the + computational basis. + dtype: The `numpy.dtype` of the inferred state vector. One of + `numpy.complex64` or `numpy.complex128`. Only used when + `target_tenson` is None. + classical_data: The shared classical data container for this + simulation. + """ + state = _BufferedStateVector.create( + initial_state=initial_state, + qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None, + dtype=dtype, + buffer=available_buffer, + ) + super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) + + @_compat.deprecated( + deadline='v0.16', fix='None, this function was unintentionally made public.' + ) + def swap_target_tensor_for(self, new_target_tensor: np.ndarray): + """Gives a new state vector for the system. + + Typically, the new state vector should be `args.available_buffer` where + `args` is this `cirq.StateVectorSimulationState` instance. + + Args: + new_target_tensor: The new system state. Must have the same shape + and dtype as the old system state. + """ + self._state._swap_target_tensor_for(new_target_tensor) + + @_compat.deprecated( + deadline='v0.16', fix='None, this function was unintentionally made public.' + ) + def subspace_index( + self, axes: Sequence[int], little_endian_bits_int: int = 0, *, big_endian_bits_int: int = 0 + ) -> Tuple[Union[slice, int, 'ellipsis'], ...]: + """An index for the subspace where the target axes equal a value. + + Args: + axes: The qubits that are specified by the index bits. + little_endian_bits_int: The desired value of the qubits at the + targeted `axes`, packed into an integer. The least significant + bit of the integer is the desired bit for the first axis, and + so forth in increasing order. Can't be specified at the same + time as `big_endian_bits_int`. + + When operating on qudits instead of qubits, the same basic logic + applies but in a different basis. For example, if the target + axes have dimension [a:2, b:3, c:2] then the integer 10 + decomposes into [a=0, b=2, c=1] via 7 = 1*(3*2) + 2*(2) + 0. + big_endian_bits_int: The desired value of the qubits at the + targeted `axes`, packed into an integer. The most significant + bit of the integer is the desired bit for the first axis, and + so forth in decreasing order. Can't be specified at the same + time as `little_endian_bits_int`. + + When operating on qudits instead of qubits, the same basic logic + applies but in a different basis. For example, if the target + axes have dimension [a:2, b:3, c:2] then the integer 10 + decomposes into [a=1, b=2, c=0] via 7 = 1*(3*2) + 2*(2) + 0. + + Returns: + A value that can be used to index into `target_tensor` and + `available_buffer`, and manipulate only the part of Hilbert space + corresponding to a given bit assignment. + + Example: + If `target_tensor` is a 4 qubit tensor and `axes` is `[1, 3]` and + then this method will return the following when given + `little_endian_bits=0b01`: + + `(slice(None), 0, slice(None), 1, Ellipsis)` + + Therefore the following two lines would be equivalent: + + args.target_tensor[args.subspace_index(0b01)] += 1 + + args.target_tensor[:, 0, :, 1] += 1 + """ + return linalg.slice_for_qubits_equal_to( + axes, + little_endian_qureg_value=little_endian_bits_int, + big_endian_qureg_value=big_endian_bits_int, + qid_shape=self.target_tensor.shape, + ) + + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> bool: + strats: List[Callable[[Any, Any, Sequence['cirq.Qid']], bool]] = [ + _strat_act_on_state_vector_from_apply_unitary, + _strat_act_on_state_vector_from_mixture, + _strat_act_on_state_vector_from_channel, + ] + if allow_decompose: + strats.append(strat_act_on_from_apply_decompose) # type: ignore + + # Try each strategy, stopping if one works. + for strat in strats: + result = strat(action, self, qubits) + if result is False: + break # coverage: ignore + if result is True: + return True + assert result is NotImplemented, str(result) + raise TypeError( + "Can't simulate operations that don't implement " + "SupportsUnitary, SupportsConsistentApplyUnitary, " + "SupportsMixture or is a measurement: {!r}".format(action) + ) + + def __repr__(self) -> str: + return ( + 'cirq.StateVectorSimulationState(' + f'initial_state={proper_repr(self.target_tensor)},' + f' qubits={self.qubits!r},' + f' classical_data={self.classical_data!r})' + ) + + @property + def target_tensor(self): + return self._state._state_vector + + @property + def available_buffer(self): + return self._state._buffer + + +def _strat_act_on_state_vector_from_apply_unitary( + action: Any, args: 'cirq.StateVectorSimulationState', qubits: Sequence['cirq.Qid'] +) -> bool: + return True if args._state.apply_unitary(action, args.get_axes(qubits)) else NotImplemented + + +def _strat_act_on_state_vector_from_mixture( + action: Any, args: 'cirq.StateVectorSimulationState', qubits: Sequence['cirq.Qid'] +) -> bool: + index = args._state.apply_mixture(action, args.get_axes(qubits), args.prng) + if index is None: + return NotImplemented + if protocols.is_measurement(action): + key = protocols.measurement_key_name(action) + args._classical_data.record_channel_measurement(key, index) + return True + + +def _strat_act_on_state_vector_from_channel( + action: Any, args: 'cirq.StateVectorSimulationState', qubits: Sequence['cirq.Qid'] +) -> bool: + index = args._state.apply_channel(action, args.get_axes(qubits), args.prng) + if index is None: + return NotImplemented + if protocols.is_measurement(action): + key = protocols.measurement_key_name(action) + args._classical_data.record_channel_measurement(key, index) + return True diff --git a/cirq-core/cirq/sim/act_on_state_vector_args_test.py b/cirq-core/cirq/sim/state_vector_simulation_state_test.py similarity index 92% rename from cirq-core/cirq/sim/act_on_state_vector_args_test.py rename to cirq-core/cirq/sim/state_vector_simulation_state_test.py index dd08eeca095..4dc54331e78 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args_test.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state_test.py @@ -24,7 +24,7 @@ def test_default_parameter(): dtype = np.complex64 tensor = cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64) qubits = cirq.LineQubit.range(3) - args = cirq.ActOnStateVectorArgs(qubits=qubits, initial_state=tensor, dtype=dtype) + args = cirq.StateVectorSimulationState(qubits=qubits, initial_state=tensor, dtype=dtype) qid_shape = cirq.protocols.qid_shape(qubits) tensor = np.reshape(tensor, qid_shape) np.testing.assert_almost_equal(args.target_tensor, tensor) @@ -34,7 +34,7 @@ def test_default_parameter(): def test_infer_target_tensor(): dtype = np.complex64 - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( qubits=cirq.LineQubit.range(2), initial_state=np.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), dtype=dtype, @@ -44,7 +44,9 @@ def test_infer_target_tensor(): np.array([[1.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 0.0 + 0.0j]], dtype=dtype), ) - args = cirq.ActOnStateVectorArgs(qubits=cirq.LineQubit.range(2), initial_state=0, dtype=dtype) + args = cirq.StateVectorSimulationState( + qubits=cirq.LineQubit.range(2), initial_state=0, dtype=dtype + ) np.testing.assert_almost_equal( args.target_tensor, np.array([[1.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 0.0 + 0.0j]], dtype=dtype), @@ -52,7 +54,7 @@ def test_infer_target_tensor(): def test_shallow_copy_buffers(): - args = cirq.ActOnStateVectorArgs(qubits=cirq.LineQubit.range(1), initial_state=0) + args = cirq.StateVectorSimulationState(qubits=cirq.LineQubit.range(1), initial_state=0) copy = args.copy(deep_copy_buffers=False) assert copy.available_buffer is args.available_buffer @@ -65,7 +67,7 @@ def num_qubits(self) -> int: def _decompose_(self, qubits): yield cirq.X(*qubits) - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty((2, 2, 2), dtype=np.complex64), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), @@ -83,7 +85,7 @@ def test_cannot_act(): class NoDetails: pass - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty((2, 2, 2), dtype=np.complex64), qubits=cirq.LineQubit.range(3), prng=np.random.RandomState(), @@ -107,7 +109,7 @@ def _kraus_(self): mock_prng = mock.Mock() mock_prng.random.return_value = 1 / 3 + 1e-6 - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty_like(initial_state), qubits=cirq.LineQubit.range(4), prng=mock_prng, @@ -126,7 +128,7 @@ def _kraus_(self): ) mock_prng.random.return_value = 1 / 3 - 1e-6 - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty_like(initial_state), qubits=cirq.LineQubit.range(4), prng=mock_prng, @@ -163,7 +165,7 @@ def _kraus_(self): def get_result(state: np.ndarray, sample: float): mock_prng.random.return_value = sample - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty_like(state), qubits=cirq.LineQubit.range(4), prng=mock_prng, @@ -218,7 +220,7 @@ def _kraus_(self): mock_prng = mock.Mock() mock_prng.random.return_value = 0.9999 - args = cirq.ActOnStateVectorArgs( + args = cirq.StateVectorSimulationState( available_buffer=np.empty(2, dtype=np.complex64), qubits=cirq.LineQubit.range(1), prng=mock_prng, @@ -267,7 +269,7 @@ def test_measured_mixture(): def test_with_qubits(): - original = cirq.ActOnStateVectorArgs( + original = cirq.StateVectorSimulationState( qubits=cirq.LineQubit.range(2), initial_state=1, dtype=np.complex64 ) extened = original.with_qubits(cirq.LineQubit.range(2, 4)) @@ -282,11 +284,11 @@ def test_with_qubits(): def test_qid_shape_error(): with pytest.raises(ValueError, match="qid_shape must be provided"): - cirq.sim.act_on_state_vector_args._BufferedStateVector.create(initial_state=0) + cirq.sim.state_vector_simulation_state._BufferedStateVector.create(initial_state=0) def test_deprecated_methods(): - args = cirq.ActOnStateVectorArgs(qubits=[cirq.LineQubit(0)]) + args = cirq.StateVectorSimulationState(qubits=[cirq.LineQubit(0)]) with cirq.testing.assert_deprecated('unintentionally made public', deadline='v0.16'): args.subspace_index([0], 0) with cirq.testing.assert_deprecated('unintentionally made public', deadline='v0.16'): diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index a473565ba3e..cbf965dc99d 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -43,7 +43,7 @@ class SimulatesIntermediateStateVector( Generic[TStateVectorStepResult], simulator_base.SimulatorBase[ - TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.ActOnStateVectorArgs', + TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState', ], simulator.SimulatesAmplitudes, metaclass=abc.ABCMeta, @@ -69,7 +69,7 @@ def _create_simulator_trial_result( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', + final_simulator_state: 'cirq.SimulationStateBase[cirq.StateVectorSimulationState]', ) -> 'cirq.StateVectorTrialResult': return StateVectorTrialResult( params=params, measurements=measurements, final_simulator_state=final_simulator_state @@ -102,7 +102,7 @@ def compute_amplitudes_sweep_iter( class StateVectorStepResult( - simulator_base.StepResultBase['cirq.ActOnStateVectorArgs'], metaclass=abc.ABCMeta + simulator_base.StepResultBase['cirq.StateVectorSimulationState'], metaclass=abc.ABCMeta ): pass @@ -132,7 +132,7 @@ def _value_equality_values_(self) -> Any: @value.value_equality(unhashable=True) class StateVectorTrialResult( state_vector.StateVectorMixin, - simulator_base.SimulationTrialResultBase['cirq.ActOnStateVectorArgs'], + simulator_base.SimulationTrialResultBase['cirq.StateVectorSimulationState'], ): """A `SimulationTrialResult` that includes the `StateVectorMixin` methods. @@ -144,7 +144,7 @@ def __init__( self, params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_simulator_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', + final_simulator_state: 'cirq.SimulationStateBase[cirq.StateVectorSimulationState]', ) -> None: super().__init__( params=params, diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index fdae8ece86e..6bff10df026 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -20,7 +20,7 @@ def test_state_vector_trial_result_repr(): q0 = cirq.NamedQubit('a') - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( available_buffer=np.array([0, 1], dtype=np.complex64), prng=np.random.RandomState(0), qubits=[q0], @@ -36,7 +36,7 @@ def test_state_vector_trial_result_repr(): "cirq.StateVectorTrialResult(" "params=cirq.ParamResolver({'s': 1}), " "measurements={'m': np.array([[1]], dtype=np.int32)}, " - "final_simulator_state=cirq.ActOnStateVectorArgs(" + "final_simulator_state=cirq.StateVectorSimulationState(" "initial_state=np.array([0j, (1+0j)], dtype=np.complex64), " "qubits=(cirq.NamedQubit('a'),), " "classical_data=cirq.ClassicalDataDictionaryStore()))" @@ -55,7 +55,7 @@ def test_state_vector_simulator_state_repr(): def test_state_vector_trial_result_equality(): eq = cirq.testing.EqualsTester() - final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([])) + final_simulator_state = cirq.StateVectorSimulationState(initial_state=np.array([])) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({}), @@ -82,7 +82,7 @@ def test_state_vector_trial_result_equality(): final_simulator_state=final_simulator_state, ) ) - final_simulator_state = cirq.ActOnStateVectorArgs(initial_state=np.array([1])) + final_simulator_state = cirq.StateVectorSimulationState(initial_state=np.array([1])) eq.add_equality_group( cirq.StateVectorTrialResult( params=cirq.ParamResolver({'s': 1}), @@ -94,7 +94,7 @@ def test_state_vector_trial_result_equality(): def test_state_vector_trial_result_state_mixin(): qubits = cirq.LineQubit.range(2) - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( qubits=qubits, initial_state=np.array([0, 1, 0, 0]) ) result = cirq.StateVectorTrialResult( @@ -110,7 +110,7 @@ def test_state_vector_trial_result_state_mixin(): def test_state_vector_trial_result_qid_shape(): - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( qubits=[cirq.NamedQubit('a')], initial_state=np.array([0, 1]) ) trial_result = cirq.StateVectorTrialResult( @@ -120,7 +120,7 @@ def test_state_vector_trial_result_qid_shape(): ) assert cirq.qid_shape(trial_result) == (2,) - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( qubits=cirq.LineQid.for_qid_shape((3, 2)), initial_state=np.array([0, 0, 0, 0, 1, 0]) ) trial_result = cirq.StateVectorTrialResult( @@ -134,7 +134,7 @@ def test_state_vector_trial_result_qid_shape(): def test_state_vector_trial_state_vector_is_copy(): final_state_vector = np.array([0, 1], dtype=np.complex64) qubit_map = {cirq.NamedQubit('a'): 0} - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( qubits=list(qubit_map), initial_state=final_state_vector ) trial_result = cirq.StateVectorTrialResult( @@ -145,7 +145,7 @@ def test_state_vector_trial_state_vector_is_copy(): def test_str_big(): qs = cirq.LineQubit.range(10) - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( prng=np.random.RandomState(0), qubits=qs, initial_state=np.array([1] * 2**10, dtype=np.complex64) * 0.03125, @@ -156,7 +156,7 @@ def test_str_big(): def test_pretty_print(): - final_simulator_state = cirq.ActOnStateVectorArgs( + final_simulator_state = cirq.StateVectorSimulationState( available_buffer=np.array([1]), prng=np.random.RandomState(0), qubits=[], diff --git a/cirq-core/cirq/testing/consistent_act_on.py b/cirq-core/cirq/testing/consistent_act_on.py index 0bd30f378ec..619f5bfa59c 100644 --- a/cirq-core/cirq/testing/consistent_act_on.py +++ b/cirq-core/cirq/testing/consistent_act_on.py @@ -22,11 +22,11 @@ from cirq.ops.dense_pauli_string import DensePauliString from cirq import protocols from cirq.qis import clifford_tableau -from cirq.sim import act_on_state_vector_args, final_state_vector +from cirq.sim import state_vector_simulation_state, final_state_vector from cirq.sim.clifford import ( - act_on_clifford_tableau_args, + clifford_tableau_simulation_state, stabilizer_state_ch_form, - act_on_stabilizer_ch_form_args, + stabilizer_ch_form_simulation_state, ) @@ -46,7 +46,7 @@ def state_vector_has_stabilizer(state_vector: np.ndarray, stabilizer: DensePauli """ qubits = LineQubit.range(protocols.num_qubits(stabilizer)) - args = act_on_state_vector_args.ActOnStateVectorArgs( + args = state_vector_simulation_state.StateVectorSimulationState( available_buffer=np.empty_like(state_vector), qubits=qubits, prng=np.random.RandomState(), @@ -71,7 +71,7 @@ def assert_all_implemented_act_on_effects_match_unitary( Args: val: A gate or operation that may be an input to protocols.act_on. assert_tableau_implemented: asserts that protocols.act_on() works with - val and ActOnCliffordTableauArgs inputs. + val and CliffordTableauSimulationState inputs. assert_ch_form_implemented: asserts that protocols.act_on() works with val and ActOnStabilizerStateChFormArgs inputs. """ @@ -159,7 +159,7 @@ def _final_clifford_tableau( the tableau otherwise.""" tableau = clifford_tableau.CliffordTableau(len(qubit_map)) - args = act_on_clifford_tableau_args.ActOnCliffordTableauArgs( + args = clifford_tableau_simulation_state.CliffordTableauSimulationState( tableau=tableau, qubits=list(qubit_map.keys()), prng=np.random.RandomState() ) for op in circuit.all_operations(): @@ -187,7 +187,7 @@ def _final_stabilizer_state_ch_form( returns the StabilizerStateChForm otherwise.""" stabilizer_ch_form = stabilizer_state_ch_form.StabilizerStateChForm(len(qubit_map)) - args = act_on_stabilizer_ch_form_args.ActOnStabilizerCHFormArgs( + args = stabilizer_ch_form_simulation_state.StabilizerChFormSimulationState( qubits=list(qubit_map.keys()), prng=np.random.RandomState(), initial_state=stabilizer_ch_form, diff --git a/cirq-core/cirq/testing/consistent_act_on_test.py b/cirq-core/cirq/testing/consistent_act_on_test.py index 34be8031b95..3de9612c90f 100644 --- a/cirq-core/cirq/testing/consistent_act_on_test.py +++ b/cirq-core/cirq/testing/consistent_act_on_test.py @@ -24,10 +24,10 @@ class GoodGate(cirq.testing.SingleQubitGate): def _unitary_(self): return np.array([[0, 1], [1, 0]]) - def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']): - if isinstance(args, cirq.ActOnCliffordTableauArgs): - tableau = args.tableau - q = args.qubit_map[qubits[0]] + def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid']): + if isinstance(sim_state, cirq.CliffordTableauSimulationState): + tableau = sim_state.tableau + q = sim_state.qubit_map[qubits[0]] tableau.rs[:] ^= tableau.zs[:, q] return True return NotImplemented @@ -37,10 +37,10 @@ class BadGate(cirq.testing.SingleQubitGate): def _unitary_(self): return np.array([[0, 1j], [1, 0]]) - def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']): - if isinstance(args, cirq.ActOnCliffordTableauArgs): - tableau = args.tableau - q = args.qubit_map[qubits[0]] + def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq.Qid']): + if isinstance(sim_state, cirq.CliffordTableauSimulationState): + tableau = sim_state.tableau + q = sim_state.qubit_map[qubits[0]] tableau.rs[:] ^= tableau.zs[:, q] return True return NotImplemented diff --git a/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition.py b/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition.py index 1e51d8b866e..5018e510125 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition.py @@ -26,7 +26,7 @@ def _X( q: int, - args: sim.ActOnCliffordTableauArgs, + args: sim.CliffordTableauSimulationState, operations: List[ops.Operation], qubits: List['cirq.Qid'], ): @@ -36,7 +36,7 @@ def _X( def _Z( q: int, - args: sim.ActOnCliffordTableauArgs, + args: sim.CliffordTableauSimulationState, operations: List[ops.Operation], qubits: List['cirq.Qid'], ): @@ -46,7 +46,7 @@ def _Z( def _Sdg( q: int, - args: sim.ActOnCliffordTableauArgs, + args: sim.CliffordTableauSimulationState, operations: List[ops.Operation], qubits: List['cirq.Qid'], ): @@ -57,7 +57,7 @@ def _Sdg( def _H( q: int, - args: sim.ActOnCliffordTableauArgs, + args: sim.CliffordTableauSimulationState, operations: List[ops.Operation], qubits: List['cirq.Qid'], ): @@ -68,7 +68,7 @@ def _H( def _CNOT( q1: int, q2: int, - args: sim.ActOnCliffordTableauArgs, + args: sim.CliffordTableauSimulationState, operations: List[ops.Operation], qubits: List['cirq.Qid'], ): @@ -79,7 +79,7 @@ def _CNOT( def _SWAP( q1: int, q2: int, - args: sim.ActOnCliffordTableauArgs, + args: sim.CliffordTableauSimulationState, operations: List[ops.Operation], qubits: List['cirq.Qid'], ): @@ -114,7 +114,9 @@ def decompose_clifford_tableau_to_operations( t: qis.CliffordTableau = clifford_tableau.copy() operations: List[ops.Operation] = [] - args = sim.ActOnCliffordTableauArgs(tableau=t, qubits=qubits, prng=np.random.RandomState()) + args = sim.CliffordTableauSimulationState( + tableau=t, qubits=qubits, prng=np.random.RandomState() + ) _X_with_ops = functools.partial(_X, args=args, operations=operations, qubits=qubits) _Z_with_ops = functools.partial(_Z, args=args, operations=operations, qubits=qubits) diff --git a/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition_test.py b/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition_test.py index e635ff7ae2f..29b3f4b43e1 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition_test.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/clifford_decomposition_test.py @@ -39,7 +39,7 @@ def test_misaligned_qubits(): def test_clifford_decompose_one_qubit(): """Two random instance for one qubit decomposition.""" qubits = cirq.LineQubit.range(1) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=1), qubits=qubits, prng=np.random.RandomState() ) cirq.act_on(cirq.X, args, qubits=[qubits[0]], allow_decompose=False) @@ -51,7 +51,7 @@ def test_clifford_decompose_one_qubit(): assert_allclose_up_to_global_phase(cirq.unitary(expect_circ), cirq.unitary(circ), atol=1e-7) qubits = cirq.LineQubit.range(1) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=1), qubits=qubits, prng=np.random.RandomState() ) cirq.act_on(cirq.Z, args, qubits=[qubits[0]], allow_decompose=False) @@ -74,7 +74,7 @@ def test_clifford_decompose_one_qubit(): def test_clifford_decompose_two_qubits(): """Two random instance for two qubits decomposition.""" qubits = cirq.LineQubit.range(2) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=2), qubits=qubits, prng=np.random.RandomState() ) cirq.act_on(cirq.H, args, qubits=[qubits[0]], allow_decompose=False) @@ -85,7 +85,7 @@ def test_clifford_decompose_two_qubits(): assert_allclose_up_to_global_phase(cirq.unitary(expect_circ), cirq.unitary(circ), atol=1e-7) qubits = cirq.LineQubit.range(2) - args = cirq.ActOnCliffordTableauArgs( + args = cirq.CliffordTableauSimulationState( tableau=cirq.CliffordTableau(num_qubits=2), qubits=qubits, prng=np.random.RandomState() ) cirq.act_on(cirq.H, args, qubits=[qubits[0]], allow_decompose=False) @@ -118,7 +118,7 @@ def test_clifford_decompose_by_unitary(): t = cirq.CliffordTableau(num_qubits=n) qubits = cirq.LineQubit.range(n) expect_circ = cirq.Circuit() - args = cirq.ActOnCliffordTableauArgs(tableau=t, qubits=qubits, prng=prng) + args = cirq.CliffordTableauSimulationState(tableau=t, qubits=qubits, prng=prng) for _ in range(num_ops): g = prng.randint(len(gate_candidate)) indices = (prng.randint(n),) if g < 5 else prng.choice(n, 2, replace=False) @@ -145,7 +145,7 @@ def test_clifford_decompose_by_reconstruction(): t = cirq.CliffordTableau(num_qubits=n) qubits = cirq.LineQubit.range(n) expect_circ = cirq.Circuit() - args = cirq.ActOnCliffordTableauArgs(tableau=t, qubits=qubits, prng=prng) + args = cirq.CliffordTableauSimulationState(tableau=t, qubits=qubits, prng=prng) for _ in range(num_ops): g = prng.randint(len(gate_candidate)) indices = (prng.randint(n),) if g < 5 else prng.choice(n, 2, replace=False) @@ -156,7 +156,7 @@ def test_clifford_decompose_by_reconstruction(): ops = cirq.decompose_clifford_tableau_to_operations(qubits, args.tableau) reconstruct_t = cirq.CliffordTableau(num_qubits=n) - reconstruct_args = cirq.ActOnCliffordTableauArgs( + reconstruct_args = cirq.CliffordTableauSimulationState( tableau=reconstruct_t, qubits=qubits, prng=prng ) for op in ops: diff --git a/cirq-google/cirq_google/calibration/engine_simulator.py b/cirq-google/cirq_google/calibration/engine_simulator.py index 69eba099950..761c649d1ab 100644 --- a/cirq-google/cirq_google/calibration/engine_simulator.py +++ b/cirq-google/cirq_google/calibration/engine_simulator.py @@ -456,14 +456,14 @@ def simulate( def _create_partial_act_on_args( self, - initial_state: Union[int, cirq.ActOnStateVectorArgs], + initial_state: Union[int, cirq.StateVectorSimulationState], qubits: Sequence[cirq.Qid], classical_data: cirq.ClassicalDataStore, - ) -> cirq.ActOnStateVectorArgs: + ) -> cirq.StateVectorSimulationState: # Needs an implementation since it's abstract but will never actually be called. raise NotImplementedError() - def _create_step_result(self, sim_state: cirq.OperationTarget) -> cirq.SparseSimulatorStep: + def _create_step_result(self, sim_state: cirq.SimulationStateBase) -> cirq.SparseSimulatorStep: # Needs an implementation since it's abstract but will never actually be called. raise NotImplementedError() diff --git a/examples/direct_fidelity_estimation.py b/examples/direct_fidelity_estimation.py index 20085320be0..322b36c648c 100644 --- a/examples/direct_fidelity_estimation.py +++ b/examples/direct_fidelity_estimation.py @@ -369,7 +369,7 @@ def direct_fidelity_estimation( clifford_tableau = cirq.CliffordTableau(n_qubits) try: for gate in circuit.all_operations(): - tableau_args = clifford.ActOnCliffordTableauArgs( + tableau_args = clifford.CliffordTableauSimulationState( tableau=clifford_tableau, qubits=qubits ) cirq.act_on(gate, tableau_args)