Skip to content

Commit

Permalink
Push redundant simulator iterator code into base classes (#3650)
Browse files Browse the repository at this point in the history
First step for #2178 

This PR pushes redundant simulator code into base classes, making it easier to create simulators. The scope of the PR is currently only the iterator methods. The run methods have subsequently diverged among subclasses, such that the first two lines of each are all that is repetitive, and the overhead of consolidating those is not worth it.

This is still something that could be visited and refined later on. There may be some other higher-level changes we could make to consolidate additional logic. That would be substantial redesign though and is beyond the scope of this PR.
  • Loading branch information
daxfohl committed Jan 12, 2021
1 parent 989abad commit 5fa8a3b
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 135 deletions.
34 changes: 2 additions & 32 deletions cirq/sim/clifford/clifford_simulator.py
Expand Up @@ -41,6 +41,7 @@
from cirq.protocols import act_on, unitary
from cirq.sim import clifford, simulator
from cirq._compat import deprecated, deprecated_parameter
from cirq.sim.simulator import check_all_resolved


class CliffordSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
Expand Down Expand Up @@ -122,26 +123,6 @@ def _base_iterator(
measurements=ch_form_args.log_of_measurement_results, state=state
)

def _simulator_iterator(
self,
circuit: circuits.Circuit,
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: int,
) -> Iterator:
"""See definition in `cirq.SimulatesIntermediateState`.
Args:
inital_state: An integer specifying the inital
state in the computational basis.
"""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)
actual_initial_state = 0 if initial_state is None else initial_state

return self._base_iterator(resolved_circuit, qubit_order, actual_initial_state)

def _create_simulator_trial_result(
self,
params: study.ParamResolver,
Expand All @@ -159,7 +140,7 @@ def _run(

param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)
check_all_resolved(resolved_circuit)

measurements = {} # type: Dict[str, List[np.ndarray]]
if repetitions == 0:
Expand All @@ -179,17 +160,6 @@ def _run(

return {k: np.array(v) for k, v in measurements.items()}

def _check_all_resolved(self, circuit):
"""Raises if the circuit contains unresolved symbols."""
if protocols.is_parameterized(circuit):
unresolved = [
op for moment in circuit for op in moment if protocols.is_parameterized(op)
]
raise ValueError(
'Circuit contains ops whose symbols were not specified in '
'parameter sweep. Ops: {}'.format(unresolved)
)


class CliffordTrialResult(simulator.SimulationTrialResult):
def __init__(
Expand Down
37 changes: 2 additions & 35 deletions cirq/sim/density_matrix_simulator.py
Expand Up @@ -21,6 +21,7 @@

from cirq import circuits, ops, protocols, qis, study, value, devices
from cirq.sim import density_matrix_utils, simulator
from cirq.sim.simulator import check_all_resolved

if TYPE_CHECKING:
from typing import Tuple
Expand Down Expand Up @@ -165,7 +166,7 @@ def _run(
"""See definition in `cirq.SimulatesSamples`."""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)
check_all_resolved(resolved_circuit)

if circuit.are_all_measurements_terminal():
return self._run_sweep_sample(resolved_circuit, repetitions)
Expand Down Expand Up @@ -205,29 +206,6 @@ def _run_sweep_repeat(
measurements[k].append(np.array(v, dtype=np.uint8))
return {k: np.array(v) for k, v in measurements.items()}

def _simulator_iterator(
self,
circuit: circuits.Circuit,
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: Union[int, np.ndarray],
) -> Iterator:
"""See definition in `cirq.SimulatesIntermediateState`.
If the initial state is an int, the state is set to the computational
basis state corresponding to this state. Otherwise if the initial
state is a np.ndarray it is the full initial state, either a pure state
or the full density matrix. If it is the pure state it must be the
correct size, be normalized (an L2 norm of 1), and be safely castable
to an appropriate dtype for the simulator. If it is a mixed state
it must be correctly sized and positive semidefinite with trace one.
"""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)
actual_initial_state = 0 if initial_state is None else initial_state
return self._base_iterator(resolved_circuit, qubit_order, actual_initial_state)

def _apply_op_channel(
self, op: ops.Operation, state: _StateAndBuffers, indices: List[int]
) -> None:
Expand Down Expand Up @@ -339,17 +317,6 @@ def _create_simulator_trial_result(
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)

def _check_all_resolved(self, circuit):
"""Raises if the circuit contains unresolved symbols."""
if protocols.is_parameterized(circuit):
unresolved = [
op for moment in circuit for op in moment if protocols.is_parameterized(op)
]
raise ValueError(
'Circuit contains ops whose symbols were not specified in '
'parameter sweep. Ops: {}'.format(unresolved)
)


class DensityMatrixStepResult(simulator.StepResult):
"""A single step in the simulation of the DensityMatrixSimulator.
Expand Down
4 changes: 3 additions & 1 deletion cirq/sim/mux_test.py
Expand Up @@ -394,7 +394,9 @@ def test_final_density_matrix_noise():

def test_deprecated():
a = cirq.LineQubit(0)
with cirq.testing.assert_logs('final_wavefunction', 'final_state_vector', 'deprecated'):
with cirq.testing.assert_logs(
'final_wavefunction', 'final_state_vector', 'deprecated', count=2
):
_ = cirq.final_wavefunction([cirq.H(a)])


Expand Down
52 changes: 50 additions & 2 deletions cirq/sim/simulator.py
Expand Up @@ -35,6 +35,7 @@
import numpy as np

from cirq import circuits, ops, protocols, study, value, work
from cirq._compat import deprecated

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -250,7 +251,7 @@ class SimulatesIntermediateState(SimulatesFinalState, metaclass=abc.ABCMeta):
state at the end of a circuit, a SimulatesIntermediateState can
simulate stepping through the moments of a circuit.
Implementors of this interface should implement the _simulator_iterator
Implementors of this interface should implement the _base_iterator
method.
Note that state here refers to simulator state, which is not necessarily
Expand Down Expand Up @@ -333,7 +334,7 @@ def simulate_moment_steps(
circuit, study.ParamResolver(param_resolver), qubit_order, initial_state
)

@abc.abstractmethod
@deprecated(deadline='v0.11.0', fix='Override _base_iterator instead')
def _simulator_iterator(
self,
circuit: circuits.Circuit,
Expand All @@ -343,6 +344,43 @@ def _simulator_iterator(
) -> Iterator:
"""Iterator over StepResult from Moments of a Circuit.
If the initial state is an int, the state is set to the computational
basis state corresponding to this state. Otherwise if the initial
state is a np.ndarray it is the full initial state, either a pure state
or the full density matrix. If it is the pure state it must be the
correct size, be normalized (an L2 norm of 1), and be safely castable
to an appropriate dtype for the simulator. If it is a mixed state
it must be correctly sized and positive semidefinite with trace one.
Args:
circuit: The circuit to simulate.
param_resolver: A ParamResolver for determining values of
Symbols.
qubit_order: Determines the canonical ordering of the qubits. This
is often used in specifying the initial state, i.e. the
ordering of the computational basis states.
initial_state: The initial state for the simulation. The form of
this state depends on the simulation implementation. See
documentation of the implementing class for details.
Yields:
StepResults from simulating a Moment of the Circuit.
"""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
check_all_resolved(resolved_circuit)
actual_initial_state = 0 if initial_state is None else initial_state
return self._base_iterator(resolved_circuit, qubit_order, actual_initial_state)

@abc.abstractmethod
def _base_iterator(
self,
circuit: circuits.Circuit,
qubit_order: ops.QubitOrderOrList,
initial_state: Any,
) -> Iterator['StepResult']:
"""Iterator over StepResult from Moments of a Circuit.
Args:
circuit: The circuit to simulate.
param_resolver: A ParamResolver for determining values of
Expand Down Expand Up @@ -593,3 +631,13 @@ def _verify_unique_measurement_keys(circuit: circuits.Circuit):
duplicates = [k for k, v in result.most_common() if v > 1]
if duplicates:
raise ValueError('Measurement key {} repeated'.format(",".join(duplicates)))


def check_all_resolved(circuit):
"""Raises if the circuit contains unresolved symbols."""
if protocols.is_parameterized(circuit):
unresolved = [op for moment in circuit for op in moment if protocols.is_parameterized(op)]
raise ValueError(
'Circuit contains ops whose symbols were not specified in '
'parameter sweep. Ops: {}'.format(unresolved)
)
37 changes: 2 additions & 35 deletions cirq/sim/sparse_simulator.py
Expand Up @@ -26,6 +26,7 @@
state_vector_simulator,
act_on_state_vector_args,
)
from cirq.sim.simulator import check_all_resolved

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -135,7 +136,7 @@ def _run(
"""See definition in `cirq.SimulatesSamples`."""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)
check_all_resolved(resolved_circuit)
qubit_order = sorted(resolved_circuit.all_qubits())

# Simulate as many unitary operations as possible before having to
Expand Down Expand Up @@ -192,29 +193,6 @@ def _brute_force_samples(
measurements[k].append(np.array(v, dtype=np.uint8))
return {k: np.array(v) for k, v in measurements.items()}

def _simulator_iterator(
self,
circuit: circuits.Circuit,
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: 'cirq.STATE_VECTOR_LIKE',
) -> Iterator:
"""See definition in `cirq.SimulatesIntermediateState`.
If the initial state is an int, the state is set to the computational
basis state corresponding to this state. Otherwise if the initial
state is a np.ndarray it is the full initial state. In this case it
must be the correct size, be normalized (an L2 norm of 1), and
be safely castable to an appropriate dtype for the simulator.
"""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)
actual_initial_state = 0 if initial_state is None else initial_state
return self._base_iterator(
resolved_circuit, qubit_order, actual_initial_state, perform_measurements=True
)

def _base_iterator(
self,
circuit: circuits.Circuit,
Expand Down Expand Up @@ -254,17 +232,6 @@ def _base_iterator(
)
sim_state.log_of_measurement_results.clear()

def _check_all_resolved(self, circuit):
"""Raises if the circuit contains unresolved symbols."""
if protocols.is_parameterized(circuit):
unresolved = [
op for moment in circuit for op in moment if protocols.is_parameterized(op)
]
raise ValueError(
'Circuit contains ops whose symbols were not specified in '
'parameter sweep. Ops: {}'.format(unresolved)
)


class SparseSimulatorStep(
state_vector.StateVectorMixin, state_vector_simulator.StateVectorStepResult
Expand Down
32 changes: 3 additions & 29 deletions cirq/sim/state_vector_simulator.py
Expand Up @@ -15,11 +15,11 @@

import abc

from typing import Any, cast, Dict, Iterator, Sequence, TYPE_CHECKING, Tuple
from typing import Any, cast, Dict, Sequence, TYPE_CHECKING, Tuple

import numpy as np

from cirq import circuits, ops, study, value
from cirq import ops, study, value
from cirq.sim import simulator, state_vector
from cirq._compat import deprecated

Expand All @@ -32,35 +32,9 @@ class SimulatesIntermediateStateVector(
):
"""A simulator that accesses its state vector as it does its simulation.
Implementors of this interface should implement the _simulator_iterator
Implementors of this interface should implement the _base_iterator
method."""

@abc.abstractmethod
def _simulator_iterator(
self,
circuit: circuits.Circuit,
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: np.ndarray,
) -> Iterator:
"""Iterator over StateVectorStepResult from Moments of a Circuit.
Args:
circuit: The circuit to simulate.
param_resolver: A ParamResolver for determining values of
Symbols.
qubit_order: Determines the canonical ordering of the qubits. This
is often used in specifying the initial state, i.e. the
ordering of the computational basis states.
initial_state: The initial state for the simulation. The form of
this state depends on the simulation implementation. See
documentation of the implementing class for details.
Yields:
StateVectorStepResult from simulating a Moment of the Circuit.
"""
raise NotImplementedError()

def _create_simulator_trial_result(
self,
params: study.ParamResolver,
Expand Down
2 changes: 1 addition & 1 deletion cirq/sim/state_vector_simulator_test.py
Expand Up @@ -200,7 +200,7 @@ def sample(self, qubits, repetitions, seed):
_ = TestStepResult()

class TestSimulatesClass(cirq.sim.SimulatesIntermediateWaveFunction):
def _simulator_iterator(self, circuit, param_resolver, qubit_order, initial_state):
def _base_iterator(self, circuit, qubit_order, initial_state):
pass

with cirq.testing.assert_logs(
Expand Down

0 comments on commit 5fa8a3b

Please sign in to comment.