-
Notifications
You must be signed in to change notification settings - Fork 983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Obs] 4.4 - Checkpointing #4352
Changes from 1 commit
9f4412e
8b50b6f
fbb3f95
d42060a
9c78f8d
3c0829d
292e16f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,15 +15,16 @@ | |||||||||||||
import abc | ||||||||||||||
import dataclasses | ||||||||||||||
import itertools | ||||||||||||||
import os | ||||||||||||||
import tempfile | ||||||||||||||
import warnings | ||||||||||||||
from typing import Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence | ||||||||||||||
from typing import Optional, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence | ||||||||||||||
|
||||||||||||||
import numpy as np | ||||||||||||||
import sympy | ||||||||||||||
|
||||||||||||||
from cirq import circuits, study, ops, value | ||||||||||||||
from cirq._doc import document | ||||||||||||||
from cirq.protocols import json_serializable_dataclass | ||||||||||||||
from cirq.protocols import json_serializable_dataclass, to_json | ||||||||||||||
from cirq.work.observable_measurement_data import BitstringAccumulator | ||||||||||||||
from cirq.work.observable_settings import ( | ||||||||||||||
InitObsSetting, | ||||||||||||||
|
@@ -354,6 +355,47 @@ def _to_sweep(param_tuples): | |||||||||||||
return to_sweep | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _parse_checkpoint_options( | ||||||||||||||
checkpoint: bool, checkpoint_fn: Optional[str], checkpoint_other_fn: Optional[str] | ||||||||||||||
) -> Tuple[Optional[str], Optional[str]]: | ||||||||||||||
"""Parse the checkpoint-oriented options in `measure_grouped_settings`. | ||||||||||||||
|
||||||||||||||
This function contains the validation and defaults logic. Please see | ||||||||||||||
`measure_grouped_settings` for documentation on these args. | ||||||||||||||
|
||||||||||||||
Returns: | ||||||||||||||
checkpoint_fn, checkpoint_other_fn: Parsed or default filenames for primary and previous | ||||||||||||||
checkpoint files. | ||||||||||||||
""" | ||||||||||||||
if not checkpoint: | ||||||||||||||
if checkpoint_fn is not None or checkpoint_other_fn is not None: | ||||||||||||||
raise ValueError( | ||||||||||||||
"Checkpoint filenames were provided but `checkpoint` was set to False." | ||||||||||||||
) | ||||||||||||||
return None, None | ||||||||||||||
|
||||||||||||||
if checkpoint_fn is None: | ||||||||||||||
checkpoint_dir = tempfile.mkdtemp() | ||||||||||||||
chk_basename = 'observables' | ||||||||||||||
checkpoint_fn = f'{checkpoint_dir}/{chk_basename}.json' | ||||||||||||||
|
||||||||||||||
if checkpoint_other_fn is None: | ||||||||||||||
checkpoint_dir = os.path.dirname(checkpoint_fn) | ||||||||||||||
chk_basename = os.path.basename(checkpoint_fn) | ||||||||||||||
chk_basename, _, ext = chk_basename.rpartition('.') | ||||||||||||||
if ext != 'json': | ||||||||||||||
raise ValueError( | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason we prefer an error for this instead of e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original intent was to limit the amount of "magic" it would do for you, and if you're not following normal filename semantics then there's no telling what's going on. Specifically: if you don't have a file extension If you had something like Do you think I should relax it? I could meet halfway and accept other file extensions but reject the case where there's no file extension, i.e. there's no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added some validation that it follows There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No strong opinions on my end either, but your comment helped clarify the reasoning for me. I think the current behavior should be fine. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok keeping the current behavior. We can loosen the json extension check further if the need arises |
||||||||||||||
"Please use a `.json` filename or fully " | ||||||||||||||
"specify checkpoint_fn and checkpoint_other_fn" | ||||||||||||||
) | ||||||||||||||
if checkpoint_dir == '': | ||||||||||||||
checkpoint_other_fn = f'{chk_basename}.prev.json' | ||||||||||||||
else: | ||||||||||||||
checkpoint_other_fn = f'{checkpoint_dir}/{chk_basename}.prev.json' | ||||||||||||||
|
||||||||||||||
return checkpoint_fn, checkpoint_other_fn | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If two separate files are required, we should return an error if these are equal. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. As written, it will still "work" as expected if you use the same filename for both. You just may get corruption if the process dies during the checkpointing process. Let me think about this.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a test that uses the same name for both. Let me know if you think I should disallow this behavior. I don't have a strong opinion either way There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is case (2) expected to be common? e.g. are these files particularly large, or are users expected to create many of them at once? Otherwise, I think the risk of confusion from (1) outweighs the utility of this behavior. If case (2) is common, it may still be useful to guard it with an "are you sure?"-type flag, similar to the Cirq/cirq-core/cirq/sim/sparse_simulator.py Lines 223 to 228 in fb43b84
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I banned it |
||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _needs_init_layer(grouped_settings: Dict[InitObsSetting, List[InitObsSetting]]) -> bool: | ||||||||||||||
"""Helper function to go through init_states and determine if any of them need an | ||||||||||||||
initialization layer of single-qubit gates.""" | ||||||||||||||
|
@@ -371,6 +413,9 @@ def measure_grouped_settings( | |||||||||||||
*, | ||||||||||||||
readout_symmetrization: bool = False, | ||||||||||||||
circuit_sweep: 'cirq.study.sweepable.SweepLike' = None, | ||||||||||||||
checkpoint: bool = False, | ||||||||||||||
checkpoint_fn: Optional[str] = None, | ||||||||||||||
checkpoint_other_fn: Optional[str] = None, | ||||||||||||||
) -> List[BitstringAccumulator]: | ||||||||||||||
"""Measure a suite of grouped InitObsSetting settings. | ||||||||||||||
|
||||||||||||||
|
@@ -398,7 +443,23 @@ def measure_grouped_settings( | |||||||||||||
circuit_sweep: Additional parameter sweeps for parameters contained | ||||||||||||||
in `circuit`. The total sweep is the product of the circuit sweep | ||||||||||||||
with parameter settings for the single-qubit basis-change rotations. | ||||||||||||||
checkpoint: If set to True, save cumulative raw results at the end | ||||||||||||||
of each iteration of the sampling loop. | ||||||||||||||
checkpoint_fn: The filename for the checkpoint file. If `checkpoint` | ||||||||||||||
is set to True and this is not specified, a file in a temporary | ||||||||||||||
directory will be used. | ||||||||||||||
checkpoint_other_fn: The filename for another checkpoint file, which | ||||||||||||||
contains the previous checkpoint. This lets us avoid losing data if | ||||||||||||||
a failure occurs during checkpoint writing. If `checkpoint` | ||||||||||||||
is set to True and this is not specified, a file in a temporary | ||||||||||||||
directory will be used. If `checkpoint` is set to True and | ||||||||||||||
`checkpoint_fn` is specified but this argument is *not* specified, | ||||||||||||||
"{checkpoint_fn}.prev.json" will be used. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
checkpoint_fn, checkpoint_other_fn = _parse_checkpoint_options( | ||||||||||||||
checkpoint=checkpoint, checkpoint_fn=checkpoint_fn, checkpoint_other_fn=checkpoint_other_fn | ||||||||||||||
) | ||||||||||||||
qubits = sorted({q for ms in grouped_settings.keys() for q in ms.init_state.qubits}) | ||||||||||||||
qubit_to_index = {q: i for i, q in enumerate(qubits)} | ||||||||||||||
|
||||||||||||||
|
@@ -466,4 +527,11 @@ def measure_grouped_settings( | |||||||||||||
bitstrings = np.logical_xor(flippy_ms.flips, result.measurements['z']) | ||||||||||||||
accumulator.consume_results(bitstrings.astype(np.uint8, casting='safe')) | ||||||||||||||
|
||||||||||||||
if checkpoint: | ||||||||||||||
assert checkpoint_fn is not None, 'mypy' | ||||||||||||||
assert checkpoint_other_fn is not None, 'mypy' | ||||||||||||||
if os.path.exists(checkpoint_fn): | ||||||||||||||
os.rename(checkpoint_fn, checkpoint_other_fn) | ||||||||||||||
to_json(list(accumulators.values()), checkpoint_fn) | ||||||||||||||
mpharrigan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
return list(accumulators.values()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we call
mkdtemp
in tests, we should clean up the directories it creates to ensure tests are hermetic. TemporaryDirectory may help with this.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're right. I added the pytest
tmpdir
fixture to the test. For the actual code, we want the checkpoint files to stick aroundThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would
os.mkdir
be preferable, given that the files are meant to outlive the Python process (i.e. they are non-temporary)?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mkdir
makes a directory but you'd still need to choose its location. This will give you a file in/tmp
which gets cleaned up after like 30 days or whatever your computer's policy is.