Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Obs] 4.4 - Checkpointing #4352

Merged
merged 7 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
74 changes: 71 additions & 3 deletions cirq-core/cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
import abc
import dataclasses
import itertools
import os
import tempfile
import warnings
from typing import Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence

import numpy as np
import sympy

from cirq import circuits, study, ops, value
from cirq._doc import document
from cirq.protocols import json_serializable_dataclass
from cirq.protocols import json_serializable_dataclass, to_json
from cirq.work.observable_measurement_data import BitstringAccumulator
from cirq.work.observable_settings import (
InitObsSetting,
Expand Down Expand Up @@ -354,6 +355,47 @@ def _to_sweep(param_tuples):
return to_sweep


def _parse_checkpoint_options(
checkpoint: bool, checkpoint_fn: Optional[str], checkpoint_other_fn: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
"""Parse the checkpoint-oriented options in `measure_grouped_settings`.

This function contains the validation and defaults logic. Please see
`measure_grouped_settings` for documentation on these args.

Returns:
checkpoint_fn, checkpoint_other_fn: Parsed or default filenames for primary and previous
checkpoint files.
"""
if not checkpoint:
if checkpoint_fn is not None or checkpoint_other_fn is not None:
raise ValueError(
"Checkpoint filenames were provided but `checkpoint` was set to False."
)
return None, None

if checkpoint_fn is None:
checkpoint_dir = tempfile.mkdtemp()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we call mkdtemp in tests, we should clean up the directories it creates to ensure tests are hermetic. TemporaryDirectory may help with this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right. I added the pytest tmpdir fixture to the test. For the actual code, we want the checkpoint files to stick around

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would os.mkdir be preferable, given that the files are meant to outlive the Python process (i.e. they are non-temporary)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mkdir makes a directory but you'd still need to choose its location. This will give you a file in /tmp which gets cleaned up after like 30 days or whatever your computer's policy is.

chk_basename = 'observables'
checkpoint_fn = f'{checkpoint_dir}/{chk_basename}.json'

if checkpoint_other_fn is None:
checkpoint_dir = os.path.dirname(checkpoint_fn)
chk_basename = os.path.basename(checkpoint_fn)
chk_basename, _, ext = chk_basename.rpartition('.')
if ext != 'json':
raise ValueError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason we prefer an error for this instead of e.g. f'{chk_basename}.prev.{ext}' below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original intent was to limit the amount of "magic" it would do for you, and if you're not following normal filename semantics then there's no telling what's going on. Specifically: if you don't have a file extension rpartition will give you ('', '', 'filename' which would give a weird automatic checkpoint_other_fn.

If you had something like .tar.gz the automatic filename would be basename.tar.prev.gz which is also weird.

Do you think I should relax it? I could meet halfway and accept other file extensions but reject the case where there's no file extension, i.e. there's no '.' in the name.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some validation that it follows filname.ext pattern. I left in the json extension check. Let me know what you think and I can remove the json extension check. I don't have a strong opinion either way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong opinions on my end either, but your comment helped clarify the reasoning for me. I think the current behavior should be fine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok keeping the current behavior. We can loosen the json extension check further if the need arises

"Please use a `.json` filename or fully "
"specify checkpoint_fn and checkpoint_other_fn"
)
if checkpoint_dir == '':
checkpoint_other_fn = f'{chk_basename}.prev.json'
else:
checkpoint_other_fn = f'{checkpoint_dir}/{chk_basename}.prev.json'

return checkpoint_fn, checkpoint_other_fn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If two separate files are required, we should return an error if these are equal.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. As written, it will still "work" as expected if you use the same filename for both. You just may get corruption if the process dies during the checkpointing process. Let me think about this.

  1. it's confusing and the user may be making a mistake
  2. a power user may not want two different files around and specifically wants to roll the dice with non-atomic checkpointing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test that uses the same name for both. Let me know if you think I should disallow this behavior. I don't have a strong opinion either way

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is case (2) expected to be common? e.g. are these files particularly large, or are users expected to create many of them at once? Otherwise, I think the risk of confusion from (1) outweighs the utility of this behavior.

If case (2) is common, it may still be useful to guard it with an "are you sure?"-type flag, similar to the permit_terminal_measurements guard on simulate_expectation_values:

if not permit_terminal_measurements and program.are_any_measurements_terminal():
raise ValueError(
'Provided circuit has terminal measurements, which may '
'skew expectation values. If this is intentional, set '
'permit_terminal_measurements=True.'
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I banned it



def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool:
"""Helper function to go through init_states and determine if any of them need an
initialization layer of single-qubit gates."""
Expand All @@ -371,6 +413,9 @@ def measure_grouped_settings(
*,
readout_symmetrization: bool = False,
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None,
checkpoint: bool = False,
checkpoint_fn: Optional[str] = None,
checkpoint_other_fn: Optional[str] = None,
) -> List[BitstringAccumulator]:
"""Measure a suite of grouped InitObsSetting settings.

Expand Down Expand Up @@ -398,7 +443,23 @@ def measure_grouped_settings(
circuit_sweep: Additional parameter sweeps for parameters contained
in `circuit`. The total sweep is the product of the circuit sweep
with parameter settings for the single-qubit basis-change rotations.
checkpoint: If set to True, save cumulative raw results at the end
of each iteration of the sampling loop.
checkpoint_fn: The filename for the checkpoint file. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used.
checkpoint_other_fn: The filename for another checkpoint file, which
contains the previous checkpoint. This lets us avoid losing data if
a failure occurs during checkpoint writing. If `checkpoint`
is set to True and this is not specified, a file in a temporary
directory will be used. If `checkpoint` is set to True and
`checkpoint_fn` is specified but this argument is *not* specified,
"{checkpoint_fn}.prev.json" will be used.
"""

checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options(
checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn
)
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits})
qubit_to_index = {q: i for i, q in enumerate(qubits)}

Expand Down Expand Up @@ -466,4 +527,11 @@ def measure_grouped_settings(
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z'])
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe'))

if checkpoint:
assert checkpoint_fn is not None, 'mypy'
assert checkpoint_other_fn is not None, 'mypy'
if os.path.exists(checkpoint_fn):
os.rename(checkpoint_fn, checkpoint_other_fn)
to_json(list(accumulators.values()), checkpoint_fn)
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved

return list(accumulators.values())
48 changes: 44 additions & 4 deletions cirq-core/cirq/work/observable_measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_aggregate_n_repetitions,
_check_meas_specs_still_todo,
StoppingCriteria,
_parse_checkpoint_options,
)


Expand Down Expand Up @@ -155,7 +156,6 @@ def test_params_and_settings():


def test_subdivide_meas_specs():

qubits = cirq.LineQubit.range(2)
q0, q1 = qubits
setting = cw.InitObsSetting(
Expand Down Expand Up @@ -364,8 +364,47 @@ def test_meas_spec_still_todo_lots_of_params(monkeypatch):
)


@pytest.mark.parametrize('with_circuit_sweep', (True, False))
def test_measure_grouped_settings(with_circuit_sweep):
def test_checkpoint_options():
# There are three ~binary options (the latter two can be either specified or `None`. We
# test those 2^3 cases.

assert _parse_checkpoint_options(False, None, None) == (None, None)
with pytest.raises(ValueError):
_parse_checkpoint_options(False, 'test', None)
with pytest.raises(ValueError):
_parse_checkpoint_options(False, None, 'test')
with pytest.raises(ValueError):
_parse_checkpoint_options(False, 'test1', 'test2')

chk, chkprev = _parse_checkpoint_options(True, None, None)
assert chk.startswith('/') # absolute temp path
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved
assert chk.endswith('observables.json')
assert chkprev.startswith('/') # absolute temp path
assert chkprev.endswith('observables.prev.json')

chk, chkprev = _parse_checkpoint_options(True, None, 'prev.json')
assert chk.startswith('/') # absolute temp path
assert chk.endswith('observables.json')
assert chkprev == 'prev.json'

chk, chkprev = _parse_checkpoint_options(True, 'my_fancy_observables.json', None)
assert chk == 'my_fancy_observables.json'
assert chkprev == 'my_fancy_observables.prev.json'

chk, chkprev = _parse_checkpoint_options(True, 'my_fancy/observables.json', None)
assert chk == 'my_fancy/observables.json'
assert chkprev == 'my_fancy/observables.prev.json'

with pytest.raises(ValueError, match=r'Please use a `.json` filename.*'):
_parse_checkpoint_options(True, 'my_fancy_observables', None)

chk, chkprev = _parse_checkpoint_options(True, 'test1', 'test2')
assert chk == 'test1'
assert chkprev == 'test2'


@pytest.mark.parametrize(('with_circuit_sweep', 'checkpoint'), [(True, True), (False, False)])
def test_measure_grouped_settings(with_circuit_sweep, checkpoint):
qubits = cirq.LineQubit.range(1)
(q,) = qubits
tests = [
Expand All @@ -392,8 +431,9 @@ def test_measure_grouped_settings(with_circuit_sweep):
circuit=circuit,
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
circuit_sweep=ss,
checkpoint=checkpoint,
)
if with_circuit_sweep:
for result in results:
Expand Down