Skip to content

Commit

Permalink
[obs] 4.2 - Basic sampling loop (#3855)
Browse files Browse the repository at this point in the history
Loop through grouped settings, execute, accumulate bitstrings. 

The following PRs will add
 - More advanced control of how many samples to take per chunk and when to stop
 - Checkpointing
 - Higher-level APIs that are more convenient to use

Part of #3647
Follows #3792
  • Loading branch information
mpharrigan committed Mar 25, 2021
1 parent f12e221 commit 24e98fa
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 2 deletions.
3 changes: 3 additions & 0 deletions cirq/work/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
BitstringAccumulator,
flatten_grouped_results,
)
from cirq.work.observable_measurement import (
measure_grouped_settings,
)
from cirq.work.sampler import (
Sampler,
)
Expand Down
183 changes: 181 additions & 2 deletions cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

import dataclasses
import itertools
from typing import Iterable, List, Tuple, TYPE_CHECKING, Sequence, Dict
import warnings
from typing import Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence

import numpy as np
import sympy

from cirq import circuits, ops, value
from cirq import circuits, study, ops, value
from cirq.work.observable_measurement_data import BitstringAccumulator
from cirq.work.observable_settings import (
InitObsSetting,
_MeasurementSpec,
Expand Down Expand Up @@ -150,6 +152,62 @@ def _pad_setting(
return InitObsSetting(init_state=init_state, observable=obs)


def _aggregate_n_repetitions(next_chunk_repetitions: Set[int]) -> int:
"""In the future, we will allow each accumulator to request a different number
of repetitions for the next chunk. For batching efficiency, we take the
max and issue a warning in this case."""
if len(next_chunk_repetitions) == 1:
return list(next_chunk_repetitions)[0]

reps = max(next_chunk_repetitions)
warnings.warn(
f"The stopping criteria specified a various numbers of "
f"repetitions to perform next. To be able to submit as a single "
f"sweep, the largest value will be used: {reps}."
)
return reps


def _repetitions_to_do(accumulator: BitstringAccumulator, desired_reps: int) -> int:
"""Stub function to chunk desired repetitions into groups of 10,000."""
done = accumulator.n_repetitions
todo = desired_reps - done
if todo <= 0:
return 0

to_do_next = min(10_000, todo)
return to_do_next


def _check_meas_specs_still_todo(
meas_specs: List[_MeasurementSpec],
accumulators: Dict[_MeasurementSpec, BitstringAccumulator],
desired_repetitions: int,
) -> Tuple[List[_MeasurementSpec], int]:
"""Filter `meas_specs` in case some are done.
In the sampling loop in `measure_grouped_settings`, we submit
each `meas_spec` in chunks. This function contains the logic for
removing `meas_spec`s from the loop if they are done.
"""
still_todo = []
repetitions: Set[int] = set()
for meas_spec in meas_specs:
accumulator = accumulators[meas_spec]
more_repetitions = _repetitions_to_do(accumulator, desired_reps=desired_repetitions)
if more_repetitions == 0:
continue

repetitions.add(more_repetitions)
still_todo.append(meas_spec)

if len(still_todo) == 0:
return still_todo, 0

reps = _aggregate_n_repetitions(repetitions)
return still_todo, reps


@dataclasses.dataclass(frozen=True)
class _FlippyMeasSpec:
"""Internally, each MeasurementSpec class is split into two
Expand Down Expand Up @@ -211,3 +269,124 @@ def _subdivide_meas_specs(
repetitions //= 2

return flippy_mspecs, repetitions


def _to_sweep(param_tuples):
"""Turn param tuples into a sweep."""
to_sweep = [dict(pt) for pt in param_tuples]
to_sweep = study.to_sweep(to_sweep)
return to_sweep


def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool:
"""Helper function to go through init_states and determine if any of them need an
initialization layer of single-qubit gates."""
for max_setting in grouped_settings.keys():
if any(st is not value.KET_ZERO for _, st in max_setting.init_state):
return True
return False


def measure_grouped_settings(
circuit: 'cirq.Circuit',
grouped_settings: Dict[InitObsSetting, List[InitObsSetting]],
sampler: 'cirq.Sampler',
desired_repetitions: int,
*,
readout_symmetrization: bool = False,
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None,
) -> List[BitstringAccumulator]:
"""Measure a suite of grouped InitObsSetting settings.
This is a low-level API for accessing the observable measurement
framework. See also `measure_observables` and `measure_observables_df`.
Args:
circuit: The circuit. This can contain parameters, in which case
you should also specify `circuit_sweep`.
grouped_settings: A series of setting groups expressed as a dictionary.
The key is the max-weight setting used for preparing single-qubit
basis-change rotations. The value is a list of settings
compatible with the maximal setting you desire to measure.
Automated routing algorithms like `group_settings_greedy` can
be used to construct this input.
sampler: A sampler.
desired_repetitions: How many repetitions per observable.
readout_symmetrization: If set to True, each `meas_spec` will be
split into two runs: one normal and one where a bit flip is
incorporated prior to measurement. In the latter case, the
measured bit will be flipped back classically and accumulated
together. This causes readout error to appear symmetric,
p(0|0) = p(1|1).
circuit_sweep: Additional parameter sweeps for parameters contained
in `circuit`. The total sweep is the product of the circuit sweep
with parameter settings for the single-qubit basis-change rotations.
"""
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits})
qubit_to_index = {q: i for i, q in enumerate(qubits)}

needs_init_layer = _needs_init_layer(grouped_settings)
measurement_param_circuit = _with_parameterized_layers(circuit, qubits, needs_init_layer)
grouped_settings = {
_pad_setting(max_setting, qubits): settings
for max_setting, settings in grouped_settings.items()
}
circuit_sweep = study.UnitSweep if circuit_sweep is None else study.to_sweep(circuit_sweep)

# meas_spec provides a key for accumulators.
# meas_specs_todo is a mutable list. We will pop things from it as various
# specs are measured to the satisfaction of the stopping criteria
accumulators = {}
meas_specs_todo = []
for max_setting, circuit_params in itertools.product(
grouped_settings.keys(), circuit_sweep.param_tuples()
):
# The type annotation for Param is just `Iterable`.
# We make sure that it's truly a tuple.
circuit_params = dict(circuit_params)

meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
accumulator = BitstringAccumulator(
meas_spec=meas_spec,
simul_settings=grouped_settings[max_setting],
qubit_to_index=qubit_to_index,
)
accumulators[meas_spec] = accumulator
meas_specs_todo += [meas_spec]

while True:
meas_specs_todo, repetitions = _check_meas_specs_still_todo(
meas_specs=meas_specs_todo,
accumulators=accumulators,
desired_repetitions=desired_repetitions,
)
if len(meas_specs_todo) == 0:
break

flippy_meas_specs, repetitions = _subdivide_meas_specs(
meas_specs=meas_specs_todo,
repetitions=repetitions,
qubits=qubits,
readout_symmetrization=readout_symmetrization,
)

resolved_params = [
flippy_ms.param_tuples(needs_init_layer=needs_init_layer)
for flippy_ms in flippy_meas_specs
]
resolved_params = _to_sweep(resolved_params)

results = sampler.run_sweep(
program=measurement_param_circuit, params=resolved_params, repetitions=repetitions
)

assert len(results) == len(
flippy_meas_specs
), 'Not as many results received as sweeps requested!'

for flippy_ms, result in zip(flippy_meas_specs, results):
accumulator = accumulators[flippy_ms.meas_spec]
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z'])
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe'))

return list(accumulators.values())
88 changes: 88 additions & 0 deletions cirq/work/observable_measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

import cirq
import cirq.work as cw
from cirq.work import _MeasurementSpec
from cirq.work.observable_measurement import (
_with_parameterized_layers,
_get_params_for_setting,
_pad_setting,
_subdivide_meas_specs,
_aggregate_n_repetitions,
_check_meas_specs_still_todo,
)


Expand Down Expand Up @@ -187,3 +190,88 @@ def test_subdivide_meas_specs():
('beta', 0.123),
('gamma', 0.456),
]


def test_aggregate_n_repetitions():
with pytest.warns(UserWarning):
reps = _aggregate_n_repetitions({5, 6})
assert reps == 6


def test_meas_specs_still_todo():
q0, q1 = cirq.LineQubit.range(2)
setting = cw.InitObsSetting(
init_state=cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1), observable=cirq.X(q0) * cirq.Y(q1)
)
meas_spec = _MeasurementSpec(
max_setting=setting,
circuit_params={
'beta': 0.123,
'gamma': 0.456,
},
)
bsa = cw.BitstringAccumulator(
meas_spec, [], {q: i for i, q in enumerate(cirq.LineQubit.range(3))}
)

# 1. before taking any data
still_todo, reps = _check_meas_specs_still_todo(
meas_specs=[meas_spec], accumulators={meas_spec: bsa}, desired_repetitions=1_000
)
assert still_todo == [meas_spec]
assert reps == 1_000

# 2. After taking a mocked-out 997 shots.
bsa.consume_results(np.zeros((997, 3), dtype=np.uint8))
still_todo, reps = _check_meas_specs_still_todo(
meas_specs=[meas_spec], accumulators={meas_spec: bsa}, desired_repetitions=1_000
)
assert still_todo == [meas_spec]
assert reps == 3

# 3. After taking the final 3 shots
bsa.consume_results(np.zeros((reps, 3), dtype=np.uint8))
still_todo, reps = _check_meas_specs_still_todo(
meas_specs=[meas_spec], accumulators={meas_spec: bsa}, desired_repetitions=1_000
)
assert still_todo == []
assert reps == 0


@pytest.mark.parametrize('with_circuit_sweep', (True, False))
def test_measure_grouped_settings(with_circuit_sweep):
qubits = cirq.LineQubit.range(1)
(q,) = qubits
tests = [
(cirq.KET_ZERO, cirq.Z, 1),
(cirq.KET_ONE, cirq.Z, -1),
(cirq.KET_PLUS, cirq.X, 1),
(cirq.KET_MINUS, cirq.X, -1),
(cirq.KET_IMAG, cirq.Y, 1),
(cirq.KET_MINUS_IMAG, cirq.Y, -1),
]
if with_circuit_sweep:
ss = cirq.Linspace('a', 0, 1, 12)
else:
ss = None

for init, obs, coef in tests:
setting = cw.InitObsSetting(
init_state=init(q),
observable=obs(q),
)
grouped_settings = {setting: [setting]}
circuit = cirq.Circuit(cirq.I.on_each(*qubits))
results = cw.measure_grouped_settings(
circuit=circuit,
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
desired_repetitions=1_000,
circuit_sweep=ss,
)
if with_circuit_sweep:
for result in results:
assert result.means() == [coef]
else:
(result,) = results # one group
assert result.means() == [coef]

0 comments on commit 24e98fa

Please sign in to comment.