Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def measure_grouped_settings(
for max_setting, param_resolver in itertools.product(
grouped_settings.keys(), study.to_resolvers(circuit_sweep)
):
circuit_params = dict(param_resolver.param_dict)
circuit_params = param_resolver.param_dict
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
accumulator = BitstringAccumulator(
meas_spec=meas_spec,
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/work/observable_measurement_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import dataclasses
import datetime
from typing import Any, Dict, Iterable, List, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Iterable, List, Mapping, Tuple, TYPE_CHECKING, Union

import numpy as np
import sympy
Expand Down Expand Up @@ -107,7 +107,7 @@ class ObservableMeasuredResult:
mean: float
variance: float
repetitions: int
circuit_params: Dict[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]

def __repr__(self):
# I wish we could use the default dataclass __repr__ but
Expand Down
17 changes: 14 additions & 3 deletions cirq-core/cirq/work/observable_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@

import dataclasses
import numbers
from typing import Union, Iterable, Dict, Optional, TYPE_CHECKING, ItemsView, Tuple, FrozenSet
from typing import (
AbstractSet,
Mapping,
Union,
Iterable,
Dict,
Optional,
TYPE_CHECKING,
Tuple,
FrozenSet,
)

import sympy

Expand Down Expand Up @@ -143,7 +153,8 @@ def _fix_precision(val: Union[value.Scalar, sympy.Expr], precision) -> Union[int


def _hashable_param(
param_tuples: ItemsView[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]], precision=1e7
param_tuples: AbstractSet[Tuple[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]],
precision=1e7,
) -> FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]]:
"""Hash circuit parameters using fixed precision.

Expand All @@ -166,7 +177,7 @@ class _MeasurementSpec:
"""

max_setting: InitObsSetting
circuit_params: Dict[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]

def __hash__(self):
return hash((self.max_setting, _hashable_param(self.circuit_params.items())))
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ def sample_expectation_values(

# Flatten Circuit Sweep into one big list of Params.
# Keep track of their indices so we can map back.
flat_params: List['cirq.ParamDictType'] = [
dict(pr.param_dict) for pr in study.to_resolvers(params)
flat_params: List['cirq.ParamMappingType'] = [
pr.param_dict for pr in study.to_resolvers(params)
]
circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = {
_hashable_param(param.items()): i for i, param in enumerate(flat_params)
Expand Down