diff --git a/cirq-core/cirq/work/observable_measurement.py b/cirq-core/cirq/work/observable_measurement.py index f31f3ab59be..5a8ccc7a900 100644 --- a/cirq-core/cirq/work/observable_measurement.py +++ b/cirq-core/cirq/work/observable_measurement.py @@ -531,7 +531,7 @@ def measure_grouped_settings( for max_setting, param_resolver in itertools.product( grouped_settings.keys(), study.to_resolvers(circuit_sweep) ): - circuit_params = dict(param_resolver.param_dict) + circuit_params = param_resolver.param_dict meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params) accumulator = BitstringAccumulator( meas_spec=meas_spec, diff --git a/cirq-core/cirq/work/observable_measurement_data.py b/cirq-core/cirq/work/observable_measurement_data.py index 95459ddef7b..9936eed2f0d 100644 --- a/cirq-core/cirq/work/observable_measurement_data.py +++ b/cirq-core/cirq/work/observable_measurement_data.py @@ -14,7 +14,7 @@ import dataclasses import datetime -from typing import Any, Dict, Iterable, List, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Mapping, Tuple, TYPE_CHECKING, Union import numpy as np import sympy @@ -107,7 +107,7 @@ class ObservableMeasuredResult: mean: float variance: float repetitions: int - circuit_params: Dict[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]] + circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]] def __repr__(self): # I wish we could use the default dataclass __repr__ but diff --git a/cirq-core/cirq/work/observable_settings.py b/cirq-core/cirq/work/observable_settings.py index 1a265cb5b6f..677297fbc91 100644 --- a/cirq-core/cirq/work/observable_settings.py +++ b/cirq-core/cirq/work/observable_settings.py @@ -14,7 +14,17 @@ import dataclasses import numbers -from typing import Union, Iterable, Dict, Optional, TYPE_CHECKING, ItemsView, Tuple, FrozenSet +from typing import ( + AbstractSet, + Mapping, + Union, + Iterable, + Dict, + Optional, + TYPE_CHECKING, + Tuple, + FrozenSet, +) import sympy @@ -143,7 +153,8 @@ def _fix_precision(val: Union[value.Scalar, sympy.Expr], precision) -> Union[int def _hashable_param( - param_tuples: ItemsView[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]], precision=1e7 + param_tuples: AbstractSet[Tuple[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]], + precision=1e7, ) -> FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]]: """Hash circuit parameters using fixed precision. @@ -166,7 +177,7 @@ class _MeasurementSpec: """ max_setting: InitObsSetting - circuit_params: Dict[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]] + circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]] def __hash__(self): return hash((self.max_setting, _hashable_param(self.circuit_params.items()))) diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index 2184cad4687..2ccf9477659 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -352,8 +352,8 @@ def sample_expectation_values( # Flatten Circuit Sweep into one big list of Params. # Keep track of their indices so we can map back. - flat_params: List['cirq.ParamDictType'] = [ - dict(pr.param_dict) for pr in study.to_resolvers(params) + flat_params: List['cirq.ParamMappingType'] = [ + pr.param_dict for pr in study.to_resolvers(params) ] circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = { _hashable_param(param.items()): i for i, param in enumerate(flat_params)