Skip to content

Commit

Permalink
[Obs] 4.4 - Checkpointing (#4352)
Browse files Browse the repository at this point in the history
Save "checkpoint" files during observable estimation.

With enough observables, enough samples, or a low enough target variance, you can construct long-running calls to this functionality. These options (when enabled) make sure data is not lost in those scenarios.

 - It's off by default
 - If you just toggle it to True, it will save data in a temporary directory. The use case envisaged here is to guard against data loss in an unforeseen interruption.
 - You can provide your own filenames. The use case here can be part of the nominal operation where you use that file as the saved results for a given run
 - We need two filenames so we can do an atomic `mv` so errors during serialization won't result in data loss. The two filenames should be on the same disk or `mv` isn't atomic. We don't enforce that.
  • Loading branch information
mpharrigan committed Aug 4, 2021
1 parent b1dc973 commit 120eb87
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 5 deletions.
84 changes: 83 additions & 1 deletion cirq-core/cirq/work/observable_measurement.py
Expand Up @@ -15,14 +15,16 @@
import abc
import dataclasses
import itertools
import os
import tempfile
import warnings
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence

import numpy as np
import sympy
from cirq import circuits, study, ops, value
from cirq._doc import document
from cirq.protocols import json_serializable_dataclass
from cirq.protocols import json_serializable_dataclass, to_json
from cirq.work.observable_measurement_data import BitstringAccumulator
from cirq.work.observable_settings import (
InitObsSetting,
Expand Down Expand Up @@ -353,6 +355,60 @@ def _to_sweep(param_tuples):
return to_sweep


def _parse_checkpoint_options(
checkpoint: bool, checkpoint_fn: Optional[str], checkpoint_other_fn: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
"""Parse the checkpoint-oriented options in `measure_grouped_settings`.
This function contains the validation and defaults logic. Please see
`measure_grouped_settings` for documentation on these args.
Returns:
checkpoint_fn, checkpoint_other_fn: Parsed or default filenames for primary and previous
checkpoint files.
"""
if not checkpoint:
if checkpoint_fn is not None or checkpoint_other_fn is not None:
raise ValueError(
"Checkpoint filenames were provided but `checkpoint` was set to False."
)
return None, None

if checkpoint_fn is None:
checkpoint_dir = tempfile.mkdtemp()
chk_basename = 'observables'
checkpoint_fn = f'{checkpoint_dir}/{chk_basename}.json'

if checkpoint_other_fn is None:
checkpoint_dir = os.path.dirname(checkpoint_fn)
chk_basename = os.path.basename(checkpoint_fn)
chk_basename, dot, ext = chk_basename.rpartition('.')
if chk_basename == '' or dot != '.' or ext == '':
raise ValueError(
f"You specified `checkpoint_fn={checkpoint_fn!r}` which does not follow the "
f"pattern of 'filename.extension'. Please follow this pattern or fully specify "
f"`checkpoint_other_fn`."
)

if ext != 'json':
raise ValueError(
"Please use a `.json` filename or fully "
"specify checkpoint_fn and checkpoint_other_fn"
)
if checkpoint_dir == '':
checkpoint_other_fn = f'{chk_basename}.prev.json'
else:
checkpoint_other_fn = f'{checkpoint_dir}/{chk_basename}.prev.json'

if checkpoint_fn == checkpoint_other_fn:
raise ValueError(
f"`checkpoint_fn` and `checkpoint_other_fn` were set to the same "
f"filename: {checkpoint_fn}. Please use two different filenames."
)

return checkpoint_fn, checkpoint_other_fn


def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool:
"""Helper function to go through init_states and determine if any of them need an
initialization layer of single-qubit gates."""
Expand All @@ -371,6 +427,9 @@ def measure_grouped_settings(
readout_symmetrization: bool = False,
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None,
readout_calibrations: Optional[BitstringAccumulator] = None,
checkpoint: bool = False,
checkpoint_fn: Optional[str] = None,
checkpoint_other_fn: Optional[str] = None,
) -> List[BitstringAccumulator]:
"""Measure a suite of grouped InitObsSetting settings.
Expand Down Expand Up @@ -399,10 +458,26 @@ def measure_grouped_settings(
in `circuit`. The total sweep is the product of the circuit sweep
with parameter settings for the single-qubit basis-change rotations.
readout_calibrations: The result of `calibrate_readout_error`.
checkpoint: If set to True, save cumulative raw results at the end
of each iteration of the sampling loop. Load in these results
with `cirq.read_json`.
checkpoint_fn: The filename for the checkpoint file. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used.
checkpoint_other_fn: The filename for another checkpoint file, which
contains the previous checkpoint. This lets us avoid losing data if
a failure occurs during checkpoint writing. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used. If `checkpoint` is set to True and
`checkpoint_fn` is specified but this argument is *not* specified,
"{checkpoint_fn}.prev.json" will be used.
"""
if readout_calibrations is not None and not readout_symmetrization:
raise ValueError("Readout calibration only works if `readout_symmetrization` is enabled.")

checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options(
checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn
)
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits})
qubit_to_index = {q: i for i, q in enumerate(qubits)}

Expand Down Expand Up @@ -471,4 +546,11 @@ def measure_grouped_settings(
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z'])
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe'))

if checkpoint:
assert checkpoint_fn is not None, 'mypy'
assert checkpoint_other_fn is not None, 'mypy'
if os.path.exists(checkpoint_fn):
os.replace(checkpoint_fn, checkpoint_other_fn)
to_json(list(accumulators.values()), checkpoint_fn)

return list(accumulators.values())
99 changes: 95 additions & 4 deletions cirq-core/cirq/work/observable_measurement_test.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile

import numpy as np
import pytest
Expand All @@ -26,6 +27,7 @@
_aggregate_n_repetitions,
_check_meas_specs_still_todo,
StoppingCriteria,
_parse_checkpoint_options,
)


Expand Down Expand Up @@ -155,7 +157,6 @@ def test_params_and_settings():


def test_subdivide_meas_specs():

qubits = cirq.LineQubit.range(2)
q0, q1 = qubits
setting = cw.InitObsSetting(
Expand Down Expand Up @@ -364,8 +365,56 @@ def test_meas_spec_still_todo_lots_of_params(monkeypatch):
)


@pytest.mark.parametrize('with_circuit_sweep', (True, False))
def test_measure_grouped_settings(with_circuit_sweep):
def test_checkpoint_options():
    """Check `_parse_checkpoint_options` for every combination of its arguments.

    `checkpoint` is a bool and each of the two filenames is either specified or
    `None`, giving the 2^3 cases below, plus cases for malformed filenames.
    """
    tmp_root = tempfile.gettempdir()

    # checkpoint=False: only the call without any filenames is accepted.
    assert _parse_checkpoint_options(False, None, None) == (None, None)
    for fn, other_fn in [('test', None), (None, 'test'), ('test1', 'test2')]:
        with pytest.raises(ValueError):
            _parse_checkpoint_options(False, fn, other_fn)

    # checkpoint=True, no filenames: both files live under the temp directory.
    primary, previous = _parse_checkpoint_options(True, None, None)
    assert primary.startswith(tmp_root)
    assert primary.endswith('observables.json')
    assert previous.startswith(tmp_root)
    assert previous.endswith('observables.prev.json')

    # checkpoint=True, only the previous-checkpoint filename given.
    primary, previous = _parse_checkpoint_options(True, None, 'prev.json')
    assert primary.startswith(tmp_root)
    assert primary.endswith('observables.json')
    assert previous == 'prev.json'

    # checkpoint=True, only the primary filename given: the other is derived.
    derived_cases = [
        ('my_fancy_observables.json', 'my_fancy_observables.prev.json'),
        ('my_fancy/observables.json', 'my_fancy/observables.prev.json'),
    ]
    for given, expected_prev in derived_cases:
        primary, previous = _parse_checkpoint_options(True, given, None)
        assert primary == given
        assert previous == expected_prev

    # A non-json extension cannot be used to derive the other filename.
    with pytest.raises(ValueError, match=r'Please use a `.json` filename.*'):
        _parse_checkpoint_options(True, 'my_fancy_observables.obs', None)

    # Neither can names that are not of the form 'filename.extension'.
    for malformed in ['my_fancy_observables', '.obs', 'obs.', '']:
        with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"):
            _parse_checkpoint_options(True, malformed, None)

    # checkpoint=True, both filenames given: they are passed through untouched.
    primary, previous = _parse_checkpoint_options(True, 'test1', 'test2')
    assert primary == 'test1'
    assert previous == 'test2'


@pytest.mark.parametrize(('with_circuit_sweep', 'checkpoint'), [(True, True), (False, False)])
def test_measure_grouped_settings(with_circuit_sweep, checkpoint, tmpdir):
qubits = cirq.LineQubit.range(1)
(q,) = qubits
tests = [
Expand All @@ -381,6 +430,11 @@ def test_measure_grouped_settings(with_circuit_sweep):
else:
ss = None

if checkpoint:
checkpoint_fn = f'{tmpdir}/obs.json'
else:
checkpoint_fn = None

for init, obs, coef in tests:
setting = cw.InitObsSetting(
init_state=init(q),
Expand All @@ -392,8 +446,10 @@ def test_measure_grouped_settings(with_circuit_sweep):
circuit=circuit,
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
circuit_sweep=ss,
checkpoint=checkpoint,
checkpoint_fn=checkpoint_fn,
)
if with_circuit_sweep:
for result in results:
Expand Down Expand Up @@ -430,3 +486,38 @@ def test_measure_grouped_settings_calibration_validation():
readout_calibrations=dummy_ro_calib,
readout_symmetrization=False, # no-no!
)


def test_measure_grouped_settings_read_checkpoint(tmpdir):
    """Check checkpoint filename validation and that a written checkpoint can be read back.

    First verifies that passing the same name for both checkpoint files raises
    a `ValueError`, then runs with two distinct files and loads the primary
    checkpoint with `cirq.read_json`.
    """
    qubits = cirq.LineQubit.range(1)
    [q] = qubits

    setting = cw.InitObsSetting(init_state=cirq.KET_ZERO(q), observable=cirq.Z(q))
    grouped_settings = {setting: [setting]}
    circuit = cirq.Circuit(cirq.I.on_each(*qubits))
    primary_fn = f'{tmpdir}/obs.json'

    def run_with_other_fn(other_fn):
        # Fresh sampler and stopping criteria for every invocation.
        return cw.measure_grouped_settings(
            circuit=circuit,
            grouped_settings=grouped_settings,
            sampler=cirq.Simulator(),
            stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
            checkpoint=True,
            checkpoint_fn=primary_fn,
            checkpoint_other_fn=other_fn,
        )

    with pytest.raises(ValueError, match=r'same filename.*'):
        _ = run_with_other_fn(f'{tmpdir}/obs.json')  # Same filename

    _ = run_with_other_fn(f'{tmpdir}/obs.prev.json')

    loaded = cirq.read_json(f'{tmpdir}/obs.json')
    (result,) = loaded  # one group
    assert result.n_repetitions == 1_000
    assert result.means() == [1.0]

0 comments on commit 120eb87

Please sign in to comment.