Skip to content

Commit

Permalink
Eliminate simulator boilerplate by pushing iteration into base class (#…
Browse files Browse the repository at this point in the history
…4035)

Fixes #2178

Pushes iteration and run repetitions into simulator base class. Simulators only have to implement _create_act_on_args, and _create_step_result, and the base class handles the rest.

After the migration to ActOnArgs, all simulators had nearly identical code for iteration and repetitions. The few places where they differed were either bugs (state vector had a bug where noise was not accounted for in terminal measurement gates), or missed optimization opportunities (MPS and Clifford did not split out unitary prefixes when running iterations). This PR consolidates and dedupes that logic. It also lays the foundation for future improvements, such as the ActOnArgs.join/split/move we've discussed separately, to be implemented once and inherited by all simulators.
  • Loading branch information
daxfohl committed May 6, 2021
1 parent 226ca20 commit d485de8
Show file tree
Hide file tree
Showing 12 changed files with 550 additions and 408 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@
SimulatesSamples,
SimulationTrialResult,
Simulator,
SimulatorBase,
SparseSimulatorStep,
StabilizerSampler,
StateVectorMixin,
Expand Down
98 changes: 13 additions & 85 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
import numpy as np
import quimb.tensor as qtn

from cirq import circuits, devices, study, ops, protocols, value
from cirq.ops import flatten_to_ops
from cirq.sim import simulator
from cirq import devices, study, ops, protocols, value
from cirq.sim import simulator, simulator_base
from cirq.sim.act_on_args import ActOnArgs

if TYPE_CHECKING:
Expand All @@ -53,8 +52,7 @@ class MPSOptions:


class MPSSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateState[
simulator_base.SimulatorBase[
'MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState', 'MPSState'
],
):
Expand All @@ -79,10 +77,12 @@ def __init__(
noise_model = devices.NoiseModel.from_noise_model_like(noise)
if not protocols.has_mixture(noise_model):
raise ValueError(f'noise must be unitary or mixture but was {noise_model}')
self.noise = noise_model
self.prng = value.parse_random_state(seed)
self.simulation_options = simulation_options
self.grouping = grouping
super().__init__(
noise=noise,
seed=seed,
)

def _create_act_on_args(
self,
Expand All @@ -106,46 +106,20 @@ def _create_act_on_args(

return MPSState(
qubits=qubits,
prng=self.prng,
prng=self._prng,
simulation_options=self.simulation_options,
grouping=self.grouping,
initial_state=initial_state,
)

def _core_iterator(
def _create_step_result(
self,
circuit: circuits.Circuit,
sim_state: 'MPSState',
qubit_map: Dict['cirq.Qid', int],
):
"""Iterator over MPSSimulatorStepResult from Moments of a Circuit
Args:
circuit: The circuit to simulate.
sim_state: The initial state args for the simulation in the
computational basis.
Yields:
MPSStepResult from simulating a Moment of the Circuit.
"""
if len(circuit) == 0:
yield MPSSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=sim_state
)
return

noisy_moments = self.noise.noisy_moments(circuit, sorted(circuit.all_qubits()))
for op_tree in noisy_moments:
for op in flatten_to_ops(op_tree):
if protocols.is_measurement(op) or protocols.has_mixture(op):
sim_state.axes = tuple(sim_state.qubit_map[qubit] for qubit in op.qubits)
protocols.act_on(op, sim_state)
else:
raise NotImplementedError(f"Unrecognized operation: {op!r}")

yield MPSSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=sim_state
)
sim_state.log_of_measurement_results.clear()
return MPSSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=sim_state
)

def _create_simulator_trial_result(
self,
Expand All @@ -170,52 +144,6 @@ def _create_simulator_trial_result(
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)

def _run(
self, circuit: circuits.Circuit, param_resolver: study.ParamResolver, repetitions: int
) -> Dict[str, List[np.ndarray]]:
"""Repeats measurements multiple times.
Args:
circuit: The circuit to simulate.
param_resolver: A ParamResolver for determining values of
Symbols.
repetitions: How many measurements to perform
final_simulator_state: The final state of the simulator.
Returns:
A dictionay of measurement key (e.g. qubit) to a list of arrays that
are the measurements.
"""
param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
self._check_all_resolved(resolved_circuit)

measurements: Dict[str, List[np.ndarray]] = {}

for _ in range(repetitions):
all_step_results = self._base_iterator(
resolved_circuit, qubit_order=ops.QubitOrder.DEFAULT, initial_state=0
)

for step_result in all_step_results:
for k, v in step_result.measurements.items():
if not k in measurements:
measurements[k] = []
measurements[k].append(np.array(v, dtype=int))

return {k: np.array(v) for k, v in measurements.items()}

def _check_all_resolved(self, circuit):
"""Raises if the circuit contains unresolved symbols."""
if protocols.is_parameterized(circuit):
unresolved = [
op for moment in circuit for op in moment if protocols.is_parameterized(op)
]
raise ValueError(
'Circuit contains ops whose symbols were not specified in '
'parameter sweep. Ops: {}'.format(unresolved)
)


class MPSTrialResult(simulator.SimulationTrialResult):
"""A single trial reult"""
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@
StepResult,
)

from cirq.sim.simulator_base import (
SimulatorBase,
)

from cirq.sim.sparse_simulator import (
Simulator,
SparseSimulatorStep,
Expand Down
79 changes: 11 additions & 68 deletions cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,17 @@
"""

from typing import Any, Dict, List, Sequence, Union

import numpy as np

import cirq
from cirq import circuits, study, ops, protocols, value
from cirq import study, ops, protocols, value
from cirq._compat import deprecated
from cirq.ops.dense_pauli_string import DensePauliString
from cirq.protocols import act_on
from cirq.sim import clifford, simulator
from cirq.sim.simulator import check_all_resolved
from cirq.sim import clifford, simulator, simulator_base


class CliffordSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateState[
simulator_base.SimulatorBase[
'CliffordSimulatorStepResult',
'CliffordTrialResult',
'CliffordState',
Expand All @@ -60,7 +56,7 @@ def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None):
seed: The random seed to use for this simulator.
"""
self.init = True
self._prng = value.parse_random_state(seed)
super().__init__(seed=seed)

@staticmethod
def is_supported_operation(op: 'cirq.Operation') -> bool:
Expand Down Expand Up @@ -99,46 +95,16 @@ def _create_act_on_args(
qubits=qubits,
)

def _core_iterator(
def _create_step_result(
self,
circuit: circuits.Circuit,
sim_state: clifford.ActOnStabilizerCHFormArgs,
qubit_map: Dict['cirq.Qid', int],
):
"""Iterator over CliffordSimulatorStepResult from Moments of a Circuit
Args:
circuit: The circuit to simulate.
sim_state: The initial state args for the simulation in the
computational basis.
Yields:
CliffordStepResult from simulating a Moment of the Circuit.
"""

def create_state():
return CliffordState(sim_state.qubit_map, sim_state.state.copy())

if len(circuit) == 0:
yield CliffordSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=create_state()
)
return

for moment in circuit:
sim_state.log_of_measurement_results = {}

for op in moment:
try:
sim_state.axes = tuple(sim_state.qubit_map[i] for i in op.qubits)
act_on(op, sim_state)
except TypeError:
raise NotImplementedError(
f"CliffordSimulator doesn't support {op!r}"
) # type: ignore

yield CliffordSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=create_state()
)
state = CliffordState(qubit_map)
state.ch_form = sim_state.state.copy()
return CliffordSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=state
)

def _create_simulator_trial_result(
self,
Expand All @@ -151,29 +117,6 @@ def _create_simulator_trial_result(
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)

def _run(
self, circuit: circuits.Circuit, param_resolver: study.ParamResolver, repetitions: int
) -> Dict[str, List[np.ndarray]]:

param_resolver = param_resolver or study.ParamResolver({})
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
check_all_resolved(resolved_circuit)

measurements = {} # type: Dict[str, List[np.ndarray]]

for _ in range(repetitions):
all_step_results = self._base_iterator(
resolved_circuit, qubit_order=ops.QubitOrder.DEFAULT, initial_state=0
)

for step_result in all_step_results:
for k, v in step_result.measurements.items():
if not k in measurements:
measurements[k] = []
measurements[k].append(np.array(v, dtype=bool))

return {k: np.array(v) for k, v in measurements.items()}


class CliffordTrialResult(simulator.SimulationTrialResult):
def __init__(
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/clifford_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def test_non_clifford_circuit():
q0 = cirq.LineQubit(0)
circuit = cirq.Circuit()
circuit.append(cirq.T(q0))
with pytest.raises(NotImplementedError, match="support cirq.T"):
with pytest.raises(TypeError, match="support cirq.T"):
cirq.CliffordSimulator().simulate(circuit)


Expand All @@ -426,7 +426,7 @@ def test_swap():
assert not r["a"][0]
assert r["b"][0]

with pytest.raises(NotImplementedError, match="CliffordSimulator doesn't support"):
with pytest.raises(TypeError, match="CliffordSimulator doesn't support"):
cirq.CliffordSimulator().simulate((cirq.Circuit(cirq.SWAP(a, b) ** 3.5)))


Expand Down

0 comments on commit d485de8

Please sign in to comment.