Skip to content

Commit

Permalink
[Obs] 4.5 - High-level API (#4392)
Browse files Browse the repository at this point in the history
Provide two very similar high-level entries into the observable measurement framework. This provides more sensible defaults for some of the execution details and can return nicer values

```python
    df = measure_observables_df(
        circuit,
        [observable],
        cirq.Simulator(),
        stopping_criteria='variance',
        stopping_criteria_val=1e-3 ** 2,
    )
```

------------------

`measure_grouped_settings` lets you control how settings are grouped and returns "raw" data in one container (`BitstringAccumulator`) per group. This level of control is probably desirable for people seriously executing and expensive experiment on the device. The high-level API makes some simplifying assumptions

 - The qubits are the union of circuit and observable qubits
 - We will group the settings using a `GROUPER_T` function, by default "greedy"
 - We'll always initialize settings into the `|00..00>` state, so you can just provide `PauliString`s
 - You'll probably be using one of the two built-in stopping criteria and can just provide a name and a value instead of grokking these objects.
 - (for the second entrypoint) you really don't care about how things were grouped and executed, you'd just like the observable values in a dataframe table
  • Loading branch information
mpharrigan committed Aug 26, 2021
1 parent 2988803 commit 0dce2c1
Show file tree
Hide file tree
Showing 8 changed files with 593 additions and 85 deletions.
4 changes: 3 additions & 1 deletion cirq-core/cirq/work/observable_grouping.py
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Dict, List, TYPE_CHECKING, cast
from typing import Iterable, Dict, List, TYPE_CHECKING, cast, Callable

from cirq import ops, value
from cirq.work.observable_settings import InitObsSetting, _max_weight_state, _max_weight_observable

if TYPE_CHECKING:
pass

GROUPER_T = Callable[[Iterable[InitObsSetting]], Dict[InitObsSetting, List[InitObsSetting]]]


def group_settings_greedy(
settings: Iterable[InitObsSetting],
Expand Down
261 changes: 213 additions & 48 deletions cirq-core/cirq/work/observable_measurement.py

Large diffs are not rendered by default.

39 changes: 27 additions & 12 deletions cirq-core/cirq/work/observable_measurement_data.py
Expand Up @@ -14,11 +14,10 @@

import dataclasses
import datetime
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, TYPE_CHECKING, Iterable, Any

import numpy as np

from cirq import protocols, ops
from cirq import ops, protocols
from cirq._compat import proper_repr
from cirq.work.observable_settings import (
InitObsSetting,
Expand Down Expand Up @@ -81,12 +80,12 @@ def _stats_from_measurements(
return obs_mean.item(), obs_err.item()


@protocols.json_serializable_dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class ObservableMeasuredResult:
"""The result of an observable measurement.
Please see `flatten_grouped_results` or `BitstringAccumulator.results` for information on how
to get these from `measure_observables` return values.
A list of these is returned by `measure_observables`, or see `flatten_grouped_results` for
transformation of `measure_grouped_settings` BitstringAccumulators into these objects.
This is a flattened form of the contents of a `BitstringAccumulator` which may group many
simultaneously-observable settings into one object. As such, `BitstringAccumulator` has more
Expand All @@ -110,7 +109,7 @@ class ObservableMeasuredResult:

def __repr__(self):
# I wish we could use the default dataclass __repr__ but
# we need to prefix our class name with `cirq.work.`A
# we need to prefix our class name with `cirq.work.`
return (
f'cirq.work.ObservableMeasuredResult('
f'setting={self.setting!r}, '
Expand All @@ -132,6 +131,25 @@ def observable(self):
def stddev(self):
return np.sqrt(self.variance)

def as_dict(self) -> Dict[str, Any]:
"""Return the contents of this class as a dictionary.
This makes records suitable for construction of a Pandas dataframe. The circuit parameters
are flattened into the top-level of this dictionary.
"""
record = dataclasses.asdict(self)
del record['circuit_params']
del record['setting']
record['init_state'] = self.init_state
record['observable'] = self.observable

circuit_param_dict = {f'param.{k}': v for k, v in self.circuit_params.items()}
record.update(**circuit_param_dict)
return record

def _json_dict_(self):
return protocols.dataclass_json_dict(self)


def _setting_to_z_observable(setting: InitObsSetting):
qubits = setting.observable.qubits
Expand Down Expand Up @@ -271,7 +289,7 @@ def n_repetitions(self):
return len(self.bitstrings)

@property
def results(self):
def results(self) -> Iterable[ObservableMeasuredResult]:
"""Yield individual setting results as `ObservableMeasuredResult`
objects."""
for setting in self._simul_settings:
Expand All @@ -291,10 +309,7 @@ def records(self):
after chaining these results with those from other BitstringAccumulators.
"""
for result in self.results:
record = dataclasses.asdict(result)
del record['circuit_params']
record.update(**self._meas_spec.circuit_params)
yield record
yield result.as_dict()

def _json_dict_(self):
from cirq.study.result import _pack_digits
Expand Down
30 changes: 29 additions & 1 deletion cirq-core/cirq/work/observable_measurement_data_test.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import datetime
import time

Expand Down Expand Up @@ -90,14 +91,41 @@ def test_observable_measured_result():
mean=0,
variance=5 ** 2,
repetitions=4,
circuit_params={},
circuit_params={'phi': 52},
)
assert omr.stddev == 5
assert omr.observable == cirq.Y(a) * cirq.Y(b)
assert omr.init_state == cirq.Z(a) * cirq.Z(b)

cirq.testing.assert_equivalent_repr(omr)

assert omr.as_dict() == {
'init_state': cirq.Z(a) * cirq.Z(b),
'observable': cirq.Y(a) * cirq.Y(b),
'mean': 0,
'variance': 25,
'repetitions': 4,
'param.phi': 52,
}
omr2 = dataclasses.replace(
omr,
circuit_params={
'phi': 52,
'observable': 3.14, # this would be a bad but legal parameter name
'param.phi': -1,
},
)
assert omr2.as_dict() == {
'init_state': cirq.Z(a) * cirq.Z(b),
'observable': cirq.Y(a) * cirq.Y(b),
'mean': 0,
'variance': 25,
'repetitions': 4,
'param.phi': 52,
'param.observable': 3.14,
'param.param.phi': -1,
}


@pytest.fixture()
def example_bsa() -> 'cw.BitstringAccumulator':
Expand Down
94 changes: 85 additions & 9 deletions cirq-core/cirq/work/observable_measurement_test.py
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from typing import Iterable, Dict, List

import numpy as np
import pytest

import cirq
import cirq.work as cw
from cirq.work import _MeasurementSpec, BitstringAccumulator
from cirq.work import _MeasurementSpec, BitstringAccumulator, group_settings_greedy, InitObsSetting
from cirq.work.observable_measurement import (
_with_parameterized_layers,
_get_params_for_setting,
Expand All @@ -28,6 +29,11 @@
_check_meas_specs_still_todo,
StoppingCriteria,
_parse_checkpoint_options,
measure_observables_df,
CheckpointFileOptions,
VarianceStoppingCriteria,
measure_observables,
RepetitionsStoppingCriteria,
)


Expand Down Expand Up @@ -448,8 +454,7 @@ def test_measure_grouped_settings(with_circuit_sweep, checkpoint, tmpdir):
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
circuit_sweep=ss,
checkpoint=checkpoint,
checkpoint_fn=checkpoint_fn,
checkpoint=CheckpointFileOptions(checkpoint=checkpoint, checkpoint_fn=checkpoint_fn),
)
if with_circuit_sweep:
for result in results:
Expand Down Expand Up @@ -504,20 +509,91 @@ def test_measure_grouped_settings_read_checkpoint(tmpdir):
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename
checkpoint=CheckpointFileOptions(
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.json', # Same filename
),
)
_ = cw.measure_grouped_settings(
circuit=circuit,
grouped_settings=grouped_settings,
sampler=cirq.Simulator(),
stopping_criteria=cw.RepetitionsStoppingCriteria(1_000, repetitions_per_chunk=500),
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.prev.json',
checkpoint=CheckpointFileOptions(
checkpoint=True,
checkpoint_fn=f'{tmpdir}/obs.json',
checkpoint_other_fn=f'{tmpdir}/obs.prev.json',
),
)
results = cirq.read_json(f'{tmpdir}/obs.json')
(result,) = results # one group
assert result.n_repetitions == 1_000
assert result.means() == [1.0]


Q = cirq.NamedQubit('q')


@pytest.mark.parametrize(
['circuit', 'observable'],
[
(cirq.Circuit(cirq.X(Q) ** 0.2), cirq.Z(Q)),
(cirq.Circuit(cirq.X(Q) ** -0.5, cirq.Z(Q) ** 0.2), cirq.Y(Q)),
(cirq.Circuit(cirq.Y(Q) ** 0.5, cirq.Z(Q) ** 0.2), cirq.X(Q)),
],
)
def test_XYZ_point8(circuit, observable):
# each circuit, observable combination should result in the observable value of 0.8
df = measure_observables_df(
circuit,
[observable],
cirq.Simulator(seed=52),
stopping_criteria=VarianceStoppingCriteria(1e-3 ** 2),
)
assert len(df) == 1, 'one observable'
mean = df.loc[0]['mean']
np.testing.assert_allclose(0.8, mean, atol=1e-2)


def _each_in_its_own_group_grouper(
settings: Iterable[InitObsSetting],
) -> Dict[InitObsSetting, List[InitObsSetting]]:
return {setting: [setting] for setting in settings}


@pytest.mark.parametrize(
'grouper', ['greedy', group_settings_greedy, _each_in_its_own_group_grouper]
)
def test_measure_observable_grouper(grouper):
circuit = cirq.Circuit(cirq.X(Q) ** 0.2)
observables = [
cirq.Z(Q),
cirq.Z(cirq.NamedQubit('q2')),
]
results = measure_observables(
circuit,
observables,
cirq.Simulator(seed=52),
stopping_criteria=RepetitionsStoppingCriteria(50_000),
grouper=grouper,
)
assert len(results) == 2, 'two observables'
np.testing.assert_allclose(0.8, results[0].mean, atol=0.05)
np.testing.assert_allclose(1, results[1].mean, atol=1e-9)


def test_measure_observable_bad_grouper():
circuit = cirq.Circuit(cirq.X(Q) ** 0.2)
observables = [
cirq.Z(Q),
cirq.Z(cirq.NamedQubit('q2')),
]
with pytest.raises(ValueError, match=r'Unknown grouping function'):
_ = measure_observables(
circuit,
observables,
cirq.Simulator(seed=52),
stopping_criteria=RepetitionsStoppingCriteria(50_000),
grouper='super fancy grouper',
)
24 changes: 14 additions & 10 deletions cirq-core/cirq/work/observable_settings.py
Expand Up @@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Iterable, Dict, TYPE_CHECKING, Tuple
import dataclasses
from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView, Tuple, FrozenSet

from cirq import ops, value
from cirq import ops, value, protocols

if TYPE_CHECKING:
import cirq
from cirq.value.product_state import _NamedOneQubitState

# Workaround for mypy custom dataclasses
from dataclasses import dataclass as json_serializable_dataclass
else:
from cirq.protocols import json_serializable_dataclass


@json_serializable_dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class InitObsSetting:
"""A pair of initial state and observable.
Expand Down Expand Up @@ -59,6 +55,9 @@ def __repr__(self):
f'observable={self.observable!r})'
)

def _json_dict_(self):
return protocols.dataclass_json_dict(self)


def _max_weight_observable(observables: Iterable[ops.PauliString]) -> Union[None, ops.PauliString]:
"""Create a new observable that is compatible with all input observables
Expand Down Expand Up @@ -135,7 +134,9 @@ def _fix_precision(val: float, precision) -> int:
return int(val * precision)


def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7):
def _hashable_param(
param_tuples: ItemsView[str, float], precision=1e7
) -> FrozenSet[Tuple[str, float]]:
"""Hash circuit parameters using fixed precision.
Circuit parameters can be floats but we also need to use them as
Expand All @@ -144,7 +145,7 @@ def _hashable_param(param_tuples: Iterable[Tuple[str, float]], precision=1e7):
return frozenset((k, _fix_precision(v, precision)) for k, v in param_tuples)


@json_serializable_dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class _MeasurementSpec:
"""An encapsulation of all the specifications for one run of a
quantum processor.
Expand All @@ -165,3 +166,6 @@ def __repr__(self):
f'cirq.work._MeasurementSpec(max_setting={self.max_setting!r}, '
f'circuit_params={self.circuit_params!r})'
)

def _json_dict_(self):
return protocols.dataclass_json_dict(self)

0 comments on commit 0dce2c1

Please sign in to comment.