Skip to content

Commit

Permalink
Sample independent qubit sets without merging state space (quantumlib…
Browse files Browse the repository at this point in the history
…#4110)

* split

* Allow config param split_entangled_states

* default split to off

* ensure consistent_act_on circuits have a qubit.

* lint

* lint

* mps

* lint

* lint

* run sparse by default

* fix tests

* fix tests

* fix tests

* most of sparse and dm

* clifford

* sim_base

* sim_base

* mps

* turn off on experiments with rounding error

* fix tests

* fix tests

* fix testsCreate base step result

* clifford

* mps

* mps

* mps

* tableau

* test simulator

* test simulator

* Update simulator_base.py

* Drop mps/join

* Fix clifford extract

* lint

* simplify index

* Add qubits to base class

* Fix clifford sampling

* Fix _sim_state_values

* fix tostring tests, format

* remove split/join from ch-form

* remove split/join from ch-form

* push merged state to base layer

* lint

* mypy

* mypy

* mypy

* Add default arg for zero qubit circuits

* Have last repetition reuse original state repr

* Remove cast

* Split all pure initial states by default

* Detangle on reset channels

* docstrings

* docstrings

* docstrings

* docstrings

* fix merge

* lint

* Add unit test for integer states

* format

* Add tests for splitting and joining

* remove unnecessary qubits param

* Clean up default args

* Fix failing test

* Add ActOnArgsContainer

* Add ActOnArgsContainer

* Clean up tests

* Clean up tests

* Clean up tests

* format

* Fix tests and coverage

* Add OperationTarget interface

* Fix unit tests

* mypy, lint, mocks, coverage

* coverage

* lint, tests

* lint, tests

* mypy

* mypy, tests

* remove test code

* test

* dead code

* mocks

* add log to container

* fix logs

* dead code

* unit test

* unit test

* dead code

* operationtarget samples

* StepResultBase

* Mock, format

* EmptyActOnArgs

* EmptyActOnArgs

* simplify dummyargs

* lint

* Add [] to actonargs

* rename _create_act_on_arg

* coverage

* coverage

* Default sparse sim to split=false

* format

* Default sparse sim to split=false

* Default density matrix sim to split=false

* lint

* lint

* lint

* lint

* address review comments

* lint

* Defaults back to split=false

* add error if setting state when split is enabled

* Unit tests

* coverage

* coverage

* coverage

* docs

* conflicts

* conflicts

* cover

* Add qubits to bb84

* mergedsimstate private

* q_set

* default to split=True

* Allow set_state

* Allow set_state

* format

* fix merge

* fix merge

* maintain order in sampling for determinicity.

* Pydoc fixes

* revert bb48 num_qubits change

* fix docstrings for set_state error

* Remove duplicate sample declaration from ActOnArgs

* Remove unnecessary split_untangled_states=True

* Reduce atol of dm/sv test

* Add test for sim_state propagation from step_result

* Add test for sim_state propagation from step_result

Co-authored-by: Cirq Bot <craiggidney+github+cirqbot@google.com>
  • Loading branch information
2 people authored and rht committed May 1, 2023
1 parent 880c678 commit 91239f0
Show file tree
Hide file tree
Showing 27 changed files with 347 additions and 278 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -403,6 +403,7 @@
StateVectorStepResult,
StateVectorTrialResult,
StepResult,
StepResultBase,
)

from cirq.study import (
Expand Down
65 changes: 31 additions & 34 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Expand Up @@ -117,12 +117,9 @@ def _create_partial_act_on_args(

def _create_step_result(
self,
sim_state: 'MPSState',
qubit_map: Dict['cirq.Qid', int],
sim_state: 'cirq.OperationTarget[MPSState]',
):
return MPSSimulatorStepResult(
measurements=sim_state.log_of_measurement_results, state=sim_state
)
return MPSSimulatorStepResult(sim_state)

def _create_simulator_trial_result(
self,
Expand Down Expand Up @@ -169,22 +166,22 @@ def __str__(self) -> str:
return f'measurements: {samples}\noutput state: {final}'


class MPSSimulatorStepResult(simulator.StepResult['MPSState']):
class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState', 'MPSState']):
"""A `StepResult` that can perform measurements."""

def __init__(self, state, measurements):
def __init__(
self,
sim_state: 'cirq.OperationTarget[MPSState]',
):
"""Results of a step of the simulator.
Attributes:
state: A MPSState
measurements: A dictionary from measurement gate key to measurement
results, ordered by the qubits that the measurement operates on.
qubit_map: A map from the Qubits in the Circuit to the the index
of this qubit for a canonical ordering. This canonical ordering
is used to define the state vector (see the state_vector()
method).
sim_state: The qubit:ActOnArgs lookup for this step.
"""
self.measurements = measurements
self.state = state.copy()
super().__init__(sim_state)

@property
def state(self):
return self._merged_sim_state

def __str__(self) -> str:
def bitstring(vals):
Expand All @@ -204,24 +201,6 @@ def bitstring(vals):
def _simulator_state(self):
return self.state

def sample(
self,
qubits: List[ops.Qid],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:

measurements: List[int] = []

for _ in range(repetitions):
measurements.append(
self.state.perform_measurement(
qubits, value.parse_random_state(seed), collapse_state_vector=False
)
)

return np.array(measurements, dtype=int)


@value.value_equality
class MPSState(ActOnArgs):
Expand Down Expand Up @@ -537,3 +516,21 @@ def perform_measurement(
def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Measures the axes specified by the simulator."""
return self.perform_measurement(qubits, self.prng)

def sample(
self,
qubits: Sequence[ops.Qid],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:

measurements: List[List[int]] = []

for _ in range(repetitions):
measurements.append(
self.perform_measurement(
qubits, value.parse_random_state(seed), collapse_state_vector=False
)
)

return np.array(measurements, dtype=int)
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Expand Up @@ -274,11 +274,11 @@ def test_trial_result_str():

def test_empty_step_result():
q0 = cirq.LineQubit(0)
state = ccq.mps_simulator.MPSState(qubits=(q0,), prng=value.parse_random_state(0))
step_result = ccq.mps_simulator.MPSSimulatorStepResult(state, measurements={'0': [1]})
sim = ccq.mps_simulator.MPSSimulator()
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
assert (
str(step_result)
== """0=1
== """0=0
TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=set()),
])"""
Expand Down
Expand Up @@ -20,7 +20,7 @@ def test_estimate_parallel_two_qubit_xeb_fidelity_on_grid_no_noise(tmpdir):
two_qubit_gate = cirq.ISWAP ** 0.5
cycles = [5, 10, 15]
data_collection_id = collect_grid_parallel_two_qubit_xeb_data(
sampler=cirq.Simulator(seed=34310),
sampler=cirq.Simulator(seed=34310, split_untangled_states=False),
qubits=qubits,
two_qubit_gate=two_qubit_gate,
num_circuits=2,
Expand Down Expand Up @@ -53,7 +53,9 @@ def test_estimate_parallel_two_qubit_xeb_fidelity_on_grid_depolarizing(tmpdir):
cycles = [5, 10, 15]
e = 0.01
data_collection_id = collect_grid_parallel_two_qubit_xeb_data(
sampler=cirq.DensityMatrixSimulator(noise=cirq.depolarize(e), seed=65008),
sampler=cirq.DensityMatrixSimulator(
noise=cirq.depolarize(e), seed=65008, split_untangled_states=False
),
qubits=qubits,
two_qubit_gate=two_qubit_gate,
num_circuits=2,
Expand Down
Expand Up @@ -31,7 +31,7 @@ def __init__(self, p0: float, p1: float, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE'
self.p0 = p0
self.p1 = p1
self.prng = cirq.value.parse_random_state(seed)
self.simulator = cirq.Simulator(seed=self.prng)
self.simulator = cirq.Simulator(seed=self.prng, split_untangled_states=False)

def run_sweep(
self,
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/protocols/act_on_protocol_test.py
Expand Up @@ -37,6 +37,9 @@ def copy(self):
def _act_on_fallback_(self, action, qubits, allow_decompose):
return self.fallback_result

def sample(self, qubits, repetitions=1, seed=None):
pass


op = cirq.X(cirq.LineQubit(0))

Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/sim/__init__.py
Expand Up @@ -64,6 +64,7 @@
)

from cirq.sim.simulator_base import (
StepResultBase,
SimulatorBase,
)

Expand Down
9 changes: 4 additions & 5 deletions cirq-core/cirq/sim/act_on_args.py
Expand Up @@ -60,8 +60,7 @@ def __init__(
axes: The indices of axes corresponding to the qubits that the
operation is supposed to act upon.
log_of_measurement_results: A mutable object that measurements are
being recorded into. Edit it easily by calling
`ActOnStateVectorArgs.record_measurement_result`.
being recorded into.
"""
if prng is None:
prng = cast(np.random.RandomState, np.random)
Expand All @@ -72,7 +71,7 @@ def __init__(
if log_of_measurement_results is None:
log_of_measurement_results = {}
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}
self.qubit_map = {q: i for i, q in enumerate(qubits)}
self._axes = tuple(axes)
self.prng = prng
self._log_of_measurement_results = log_of_measurement_results
Expand All @@ -89,9 +88,9 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
"""
bits = self._perform_measurement(qubits)
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(bits, invert_mask)]
if key in self.log_of_measurement_results:
if key in self._log_of_measurement_results:
raise ValueError(f"Measurement already logged to key {key!r}")
self.log_of_measurement_results[key] = corrected
self._log_of_measurement_results[key] = corrected

def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
return [self.qubit_map[q] for q in qubits]
Expand Down
26 changes: 25 additions & 1 deletion cirq-core/cirq/sim/act_on_args_container.py
Expand Up @@ -20,10 +20,14 @@
Sequence,
Optional,
Iterator,
Tuple,
Any,
Tuple,
Set,
List,
)

import numpy as np

from cirq import ops
from cirq.sim.operation_target import OperationTarget
from cirq.sim.simulator import (
Expand Down Expand Up @@ -122,6 +126,26 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
def log_of_measurement_results(self) -> Dict[str, Any]:
return self._log_of_measurement_results

def sample(
self,
qubits: List[ops.Qid],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
columns = []
selected_order: List[ops.Qid] = []
q_set = set(qubits)
for v in dict.fromkeys(self.args.values()):
qs = [q for q in v.qubits if q in q_set]
if any(qs):
column = v.sample(qs, repetitions, seed)
columns.append(column)
selected_order += qs
stacked = np.column_stack(columns)
qubit_map = {q: i for i, q in enumerate(selected_order)}
index_order = [qubit_map[q] for q in qubits]
return stacked[:, index_order]

def __getitem__(self, item: Optional['cirq.Qid']) -> TActOnArgs:
return self.args[item]

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Expand Up @@ -64,6 +64,9 @@ def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid']) -> 'EmptyActOnA
logs=self.log_of_measurement_results,
)

def sample(self, qubits, repetitions=1, seed=None):
pass


q0, q1 = qs2 = cirq.LineQubit.range(2)

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/sim/act_on_args_test.py
Expand Up @@ -25,6 +25,9 @@ def __init__(self):
def copy(self):
pass

def sample(self, qubits, repetitions=1, seed=None):
pass

def _perform_measurement(self, qubits):
return [5, 3]

Expand Down
18 changes: 16 additions & 2 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Expand Up @@ -82,8 +82,7 @@ def __init__(
prng: The pseudo random number generator to use for probabilistic
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into. Edit it easily by calling
`ActOnStateVectorArgs.record_measurement_result`.
being recorded into.
axes: The indices of axes corresponding to the qubits that the
operation is supposed to act upon.
"""
Expand Down Expand Up @@ -197,6 +196,21 @@ def transpose_to_qubit_order(
log_of_measurement_results=self.log_of_measurement_results,
)

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
indices = [self.qubit_map[q] for q in qubits]
return sim.sample_density_matrix(
self.target_tensor,
indices,
qid_shape=tuple(q.dimension for q in self.qubits),
repetitions=repetitions,
seed=seed,
)


def _strat_apply_channel_to_state(
action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid']
Expand Down
21 changes: 17 additions & 4 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Expand Up @@ -41,13 +41,12 @@ def _rewrite_deprecated_args(args, kwargs):
class ActOnStateVectorArgs(ActOnArgs):
"""State and context for an operation acting on a state vector.
There are three common ways to act on this object:
There are two common ways to act on this object:
1. Directly edit the `target_tensor` property, which is storing the state
vector of the quantum system as a numpy array with one axis per qudit.
2. Overwrite the `available_buffer` property with the new state vector, and
then pass `available_buffer` into `swap_target_tensor_for`.
3. Call `record_measurement_result(key, val)` to log a measurement result.
"""

@deprecated_parameter(
Expand Down Expand Up @@ -84,8 +83,7 @@ def __init__(
prng: The pseudo random number generator to use for probabilistic
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into. Edit it easily by calling
`ActOnStateVectorArgs.record_measurement_result`.
being recorded into.
axes: The indices of axes corresponding to the qubits that the
operation is supposed to act upon.
"""
Expand Down Expand Up @@ -255,6 +253,21 @@ def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid']) -> 'cirq.ActOnS
)
return new_args

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
indices = [self.qubit_map[q] for q in qubits]
return sim.sample_state_vector(
self.target_tensor,
indices,
qid_shape=tuple(q.dimension for q in self.qubits),
repetitions=repetitions,
seed=seed,
)


def _strat_act_on_state_vector_from_apply_unitary(
unitary_value: Any,
Expand Down
19 changes: 13 additions & 6 deletions cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py
Expand Up @@ -44,10 +44,9 @@ def _rewrite_deprecated_args(args, kwargs):

class ActOnCliffordTableauArgs(ActOnArgs):
"""State and context for an operation acting on a clifford tableau.
There are two common ways to act on this object:
1. Directly edit the `tableau` property, which is storing the clifford
tableau of the quantum system with one axis per qubit.
2. Call `record_measurement_result(key, val)` to log a measurement result.
To act on this object, directly edit the `tableau` property, which is
storing the density matrix of the quantum system with one axis per qubit.
"""

@deprecated_parameter(
Expand Down Expand Up @@ -77,8 +76,7 @@ def __init__(
prng: The pseudo random number generator to use for probabilistic
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into. Edit it easily by calling
`ActOnCliffordTableauArgs.record_measurement_result`.
being recorded into.
axes: The indices of axes corresponding to the qubits that the
operation is supposed to act upon.
"""
Expand Down Expand Up @@ -111,6 +109,15 @@ def copy(self) -> 'cirq.ActOnCliffordTableauArgs':
log_of_measurement_results=self.log_of_measurement_results.copy(),
)

def sample(
self,
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
# Unnecessary for now but can be added later if there is a use case.
raise NotImplementedError()


def _strat_act_on_clifford_tableau_from_single_qubit_decompose(
val: Any, args: 'cirq.ActOnCliffordTableauArgs', qubits: Sequence['cirq.Qid']
Expand Down

0 comments on commit 91239f0

Please sign in to comment.