From 120eb873cd1f3b0864e1783928efe416e113b7dc Mon Sep 17 00:00:00 2001 From: Matthew Harrigan Date: Wed, 4 Aug 2021 14:27:05 -0700 Subject: [PATCH] [Obs] 4.4 - Checkpointing (#4352) Save "checkpoint" files during observable estimation. With enough observables or enough samples or low enough variables you can construct long running calls to this functionality. These options will (optionally) make sure data is not lost in those scenarios. - It's off by default - If you just toggle it to True, it will save data in a temporary directory. The use case envisaged here is to guard against data loss in an unforseen interruption - You can provide your own filenames. The use case here can be part of the nominal operation where you use that file as the saved results for a given run - We need two filenames so we can do an atomic `mv` so errors during serialization won't result in data loss. The two filenames should be on the same disk or `mv` isn't atomic. We don't enforce that. --- cirq-core/cirq/work/observable_measurement.py | 84 +++++++++++++++- .../cirq/work/observable_measurement_test.py | 99 ++++++++++++++++++- 2 files changed, 178 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/work/observable_measurement.py b/cirq-core/cirq/work/observable_measurement.py index dc97eb433b0..457cadd553f 100644 --- a/cirq-core/cirq/work/observable_measurement.py +++ b/cirq-core/cirq/work/observable_measurement.py @@ -15,6 +15,8 @@ import abc import dataclasses import itertools +import os +import tempfile import warnings from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence @@ -22,7 +24,7 @@ import sympy from cirq import circuits, study, ops, value from cirq._doc import document -from cirq.protocols import json_serializable_dataclass +from cirq.protocols import json_serializable_dataclass, to_json from cirq.work.observable_measurement_data import BitstringAccumulator from cirq.work.observable_settings import ( InitObsSetting, @@ -353,6 +355,60 @@ def _to_sweep(param_tuples): return to_sweep +def _parse_checkpoint_options( + checkpoint: bool, checkpoint_fn: Optional[str], checkpoint_other_fn: Optional[str] +) -> Tuple[Optional[str], Optional[str]]: + """Parse the checkpoint-oriented options in `measure_grouped_settings`. + + This function contains the validation and defaults logic. Please see + `measure_grouped_settings` for documentation on these args. + + Returns: + checkpoint_fn, checkpoint_other_fn: Parsed or default filenames for primary and previous + checkpoint files. + """ + if not checkpoint: + if checkpoint_fn is not None or checkpoint_other_fn is not None: + raise ValueError( + "Checkpoint filenames were provided but `checkpoint` was set to False." + ) + return None, None + + if checkpoint_fn is None: + checkpoint_dir = tempfile.mkdtemp() + chk_basename = 'observables' + checkpoint_fn = f'{checkpoint_dir}/{chk_basename}.json' + + if checkpoint_other_fn is None: + checkpoint_dir = os.path.dirname(checkpoint_fn) + chk_basename = os.path.basename(checkpoint_fn) + chk_basename, dot, ext = chk_basename.rpartition('.') + if chk_basename == '' or dot != '.' or ext == '': + raise ValueError( + f"You specified `checkpoint_fn={checkpoint_fn!r}` which does not follow the " + f"pattern of 'filename.extension'. Please follow this pattern or fully specify " + f"`checkpoint_other_fn`." + ) + + if ext != 'json': + raise ValueError( + "Please use a `.json` filename or fully " + "specify checkpoint_fn and checkpoint_other_fn" + ) + if checkpoint_dir == '': + checkpoint_other_fn = f'{chk_basename}.prev.json' + else: + checkpoint_other_fn = f'{checkpoint_dir}/{chk_basename}.prev.json' + + if checkpoint_fn == checkpoint_other_fn: + raise ValueError( + f"`checkpoint_fn` and `checkpoint_other_fn` were set to the same " + f"filename: {checkpoint_fn}. Please use two different filenames." + ) + + return checkpoint_fn, checkpoint_other_fn + + def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool: """Helper function to go through init_states and determine if any of them need an initialization layer of single-qubit gates.""" @@ -371,6 +427,9 @@ def measure_grouped_settings( readout_symmetrization: bool = False, circuit_sweep: 'cirq.study.sweepable.SweepLike' = None, readout_calibrations: Optional[BitstringAccumulator] = None, + checkpoint: bool = False, + checkpoint_fn: Optional[str] = None, + checkpoint_other_fn: Optional[str] = None, ) -> List[BitstringAccumulator]: """Measure a suite of grouped InitObsSetting settings. @@ -399,10 +458,26 @@ def measure_grouped_settings( in `circuit`. The total sweep is the product of the circuit sweep with parameter settings for the single-qubit basis-change rotations. readout_calibrations: The result of `calibrate_readout_error`. + checkpoint: If set to True, save cumulative raw results at the end + of each iteration of the sampling loop. Load in these results + with `cirq.read_json`. + checkpoint_fn: The filename for the checkpoint file. If `checkpoint` + is set to True and this is not specified, a file in a temporary + directory will be used. + checkpoint_other_fn: The filename for another checkpoint file, which + contains the previous checkpoint. This lets us avoid losing data if + a failure occurs during checkpoint writing. If `checkpoint` + is set to True and this is not specified, a file in a temporary + directory will be used. If `checkpoint` is set to True and + `checkpoint_fn` is specified but this argument is *not* specified, + "{checkpoint_fn}.prev.json" will be used. """ if readout_calibrations is not None and not readout_symmetrization: raise ValueError("Readout calibration only works if `readout_symmetrization` is enabled.") + checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options( + checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn + ) qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits}) qubit_to_index = {q: i for i, q in enumerate(qubits)} @@ -471,4 +546,11 @@ def measure_grouped_settings( bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z']) accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe')) + if checkpoint: + assert checkpoint_fn is not None, 'mypy' + assert checkpoint_other_fn is not None, 'mypy' + if os.path.exists(checkpoint_fn): + os.replace(checkpoint_fn, checkpoint_other_fn) + to_json(list(accumulators.values()), checkpoint_fn) + return list(accumulators.values()) diff --git a/cirq-core/cirq/work/observable_measurement_test.py b/cirq-core/cirq/work/observable_measurement_test.py index 052c924253d..51c19093522 100644 --- a/cirq-core/cirq/work/observable_measurement_test.py +++ b/cirq-core/cirq/work/observable_measurement_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import numpy as np import pytest @@ -26,6 +27,7 @@ _aggregate_n_repetitions, _check_meas_specs_still_todo, StoppingCriteria, + _parse_checkpoint_options, ) @@ -155,7 +157,6 @@ def test_params_and_settings(): def test_subdivide_meas_specs(): - qubits = cirq.LineQubit.range(2) q0, q1 = qubits setting = cw.InitObsSetting( @@ -364,8 +365,56 @@ def test_meas_spec_still_todo_lots_of_params(monkeypatch): ) -@pytest.mark.parametrize('with_circuit_sweep', (True, False)) -def test_measure_grouped_settings(with_circuit_sweep): +def test_checkpoint_options(): + # There are three ~binary options (the latter two can be either specified or `None`. We + # test those 2^3 cases. + + assert _parse_checkpoint_options(False, None, None) == (None, None) + with pytest.raises(ValueError): + _parse_checkpoint_options(False, 'test', None) + with pytest.raises(ValueError): + _parse_checkpoint_options(False, None, 'test') + with pytest.raises(ValueError): + _parse_checkpoint_options(False, 'test1', 'test2') + + chk, chkprev = _parse_checkpoint_options(True, None, None) + assert chk.startswith(tempfile.gettempdir()) + assert chk.endswith('observables.json') + assert chkprev.startswith(tempfile.gettempdir()) + assert chkprev.endswith('observables.prev.json') + + chk, chkprev = _parse_checkpoint_options(True, None, 'prev.json') + assert chk.startswith(tempfile.gettempdir()) + assert chk.endswith('observables.json') + assert chkprev == 'prev.json' + + chk, chkprev = _parse_checkpoint_options(True, 'my_fancy_observables.json', None) + assert chk == 'my_fancy_observables.json' + assert chkprev == 'my_fancy_observables.prev.json' + + chk, chkprev = _parse_checkpoint_options(True, 'my_fancy/observables.json', None) + assert chk == 'my_fancy/observables.json' + assert chkprev == 'my_fancy/observables.prev.json' + + with pytest.raises(ValueError, match=r'Please use a `.json` filename.*'): + _parse_checkpoint_options(True, 'my_fancy_observables.obs', None) + + with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"): + _parse_checkpoint_options(True, 'my_fancy_observables', None) + with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"): + _parse_checkpoint_options(True, '.obs', None) + with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"): + _parse_checkpoint_options(True, 'obs.', None) + with pytest.raises(ValueError, match=r"pattern of 'filename.extension'.*"): + _parse_checkpoint_options(True, '', None) + + chk, chkprev = _parse_checkpoint_options(True, 'test1', 'test2') + assert chk == 'test1' + assert chkprev == 'test2' + + +@pytest.mark.parametrize(('with_circuit_sweep', 'checkpoint'), [(True, True), (False, False)]) +def test_measure_grouped_settings(with_circuit_sweep, checkpoint, tmpdir): qubits = cirq.LineQubit.range(1) (q,) = qubits tests = [ @@ -381,6 +430,11 @@ def test_measure_grouped_settings(with_circuit_sweep): else: ss = None + if checkpoint: + checkpoint_fn = f'{tmpdir}/obs.json' + else: + checkpoint_fn = None + for init, obs, coef in tests: setting = cw.InitObsSetting( init_state=init(q), @@ -392,8 +446,10 @@ def test_measure_grouped_settings(with_circuit_sweep): circuit=circuit, grouped_settings=grouped_settings, sampler=cirq.Simulator(), - stopping_criteria=cw.RepetitionsStoppingCriteria(1_000), + stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500), circuit_sweep=ss, + checkpoint=checkpoint, + checkpoint_fn=checkpoint_fn, ) if with_circuit_sweep: for result in results: @@ -430,3 +486,38 @@ def test_measure_grouped_settings_calibration_validation(): readout_calibrations=dummy_ro_calib, readout_symmetrization=False, # no-no! ) + + +def test_measure_grouped_settings_read_checkpoint(tmpdir): + qubits = cirq.LineQubit.range(1) + (q,) = qubits + + setting = cw.InitObsSetting( + init_state=cirq.KET_ZERO(q), + observable=cirq.Z(q), + ) + grouped_settings = {setting: [setting]} + circuit = cirq.Circuit(cirq.I.on_each(*qubits)) + with pytest.raises(ValueError, match=r'same filename.*'): + _ = cw.measure_grouped_settings( + circuit=circuit, + grouped_settings=grouped_settings, + sampler=cirq.Simulator(), + stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500), + checkpoint=True, + checkpoint_fn=f'{tmpdir}/obs.json', + checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename + ) + _ = cw.measure_grouped_settings( + circuit=circuit, + grouped_settings=grouped_settings, + sampler=cirq.Simulator(), + stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500), + checkpoint=True, + checkpoint_fn=f'{tmpdir}/obs.json', + checkpoint_other_fn=f'{tmpdir}/obs.prev.json', + ) + results = cirq.read_json(f'{tmpdir}/obs.json') + (result,) = results # one group + assert result.n_repetitions == 1_000 + assert result.means() == [1.0]