From 61fefe6aee8e09efbc1f8b5f55b631221303518c Mon Sep 17 00:00:00 2001 From: Orion Martin <40585662+95-martin-orion@users.noreply.github.com> Date: Wed, 22 Jun 2022 14:08:11 -0700 Subject: [PATCH] Lock down CircuitOperation and ParamResolver (#5548) * Lock down CircuitOperation attributes. * Reduce attribute lockdown * Resolve type conflicts * review comments * docs and defensive copies * document error modes Co-authored-by: Cirq Bot --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/circuits/circuit_operation.py | 387 ++++++++++-------- .../cirq/circuits/circuit_operation_test.py | 6 +- .../cirq/protocols/json_test_data/spec.py | 1 + cirq-core/cirq/study/__init__.py | 7 +- cirq-core/cirq/study/flatten_expressions.py | 6 +- cirq-core/cirq/study/resolver.py | 12 +- cirq-core/cirq/work/observable_measurement.py | 2 +- cirq-core/cirq/work/sampler.py | 2 +- .../cirq_rigetti/circuit_sweep_executors.py | 2 +- .../circuit_sweep_executors_test.py | 2 +- 11 files changed, 254 insertions(+), 174 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 03d5f22433e..41b160a4c44 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -514,6 +514,7 @@ Linspace, ListSweep, ParamDictType, + ParamMappingType, ParamResolver, ParamResolverOrSimilarType, Points, diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index f4049685204..adc006aea25 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -17,12 +17,12 @@ applied as part of a larger circuit, a CircuitOperation will execute all component operations in order, including any nested CircuitOperations. """ -import dataclasses import math from typing import ( AbstractSet, Callable, - cast, + Mapping, + Sequence, Dict, FrozenSet, Iterator, @@ -37,7 +37,7 @@ import sympy from cirq import circuits, ops, protocols, value, study -from cirq._compat import proper_repr +from cirq._compat import cached_property, proper_repr if TYPE_CHECKING: import cirq @@ -56,8 +56,8 @@ def default_repetition_ids(repetitions: IntParam) -> Optional[List[str]]: def _full_join_string_lists( - list1: Optional[List[str]], list2: Optional[List[str]] -) -> Optional[List[str]]: + list1: Optional[Sequence[str]], list2: Optional[Sequence[str]] +) -> Optional[Sequence[str]]: if list1 is None and list2 is None: return None # coverage: ignore if list1 is None: @@ -67,127 +67,179 @@ def _full_join_string_lists( return [f'{first}{REPETITION_ID_SEPARATOR}{second}' for first in list1 for second in list2] -@dataclasses.dataclass(frozen=True) class CircuitOperation(ops.Operation): """An operation that encapsulates a circuit. This class captures modifications to the contained circuit, such as tags and loops, to support more condensed serialization. Similar to GateOperation, this type is immutable. - - Args: - circuit: The FrozenCircuit wrapped by this operation. - repetitions: How many times the circuit should be repeated. This can be - integer, or a sympy expression. If sympy, the expression must - resolve to an integer, or float within 0.001 of integer, at - runtime. - qubit_map: Remappings for qubits in the circuit. - measurement_key_map: Remappings for measurement keys in the circuit. - The keys and values should be unindexed (i.e. without repetition_ids). - The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`. - param_resolver: Resolved values for parameters in the circuit. - repetition_ids: List of identifiers for each repetition of the - CircuitOperation. If populated, the length should be equal to the - repetitions. If not populated and abs(`repetitions`) > 1, it is - initialized to strings for numbers in `range(repetitions)`. - parent_path: A tuple of identifiers for any parent CircuitOperations - containing this one. - extern_keys: The set of measurement keys defined at extern scope. The - values here are used by decomposition and simulation routines to - cache which external measurement keys exist as possible binding - targets for unbound `ClassicallyControlledOperation` keys. This - field is not intended to be set or changed manually, and should be - empty in circuits that aren't in the middle of decomposition. - use_repetition_ids: When True, any measurement key in the subcircuit - will have its path prepended with the repetition id for each - repetition. When False, this will not happen and the measurement - key will be repeated. - repeat_until: A condition that will be tested after each iteration of - the subcircuit. The subcircuit will repeat until condition returns - True, but will always run at least once, and the measurement key - need not be defined prior to the subcircuit (but must be defined in - a measurement within the subcircuit). This field is incompatible - with repetitions or repetition_ids. """ - _hash: Optional[int] = dataclasses.field(default=None, init=False) - _cached_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( - default=None, init=False - ) - _cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( - default=None, init=False - ) - _cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field( - default=None, init=False - ) - - circuit: 'cirq.FrozenCircuit' - repetitions: IntParam = 1 - qubit_map: Dict['cirq.Qid', 'cirq.Qid'] = dataclasses.field(default_factory=dict) - measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict) - param_resolver: study.ParamResolver = study.ParamResolver() - repetition_ids: Optional[List[str]] = dataclasses.field(default=None) - parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) - extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) - use_repetition_ids: bool = True - repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None) - - def __post_init__(self): - if not isinstance(self.circuit, circuits.FrozenCircuit): - raise TypeError(f'Expected circuit of type FrozenCircuit, got: {type(self.circuit)!r}') + def __init__( + self, + circuit: 'cirq.FrozenCircuit', + repetitions: int = 1, + qubit_map: Optional[Dict['cirq.Qid', 'cirq.Qid']] = None, + measurement_key_map: Optional[Dict[str, str]] = None, + param_resolver: Optional[study.ParamResolverOrSimilarType] = None, + repetition_ids: Optional[Sequence[str]] = None, + parent_path: Tuple[str, ...] = (), + extern_keys: FrozenSet['cirq.MeasurementKey'] = frozenset(), + use_repetition_ids: bool = True, + repeat_until: Optional['cirq.Condition'] = None, + ): + """Initializes a CircuitOperation. + + Args: + circuit: The FrozenCircuit wrapped by this operation. + repetitions: How many times the circuit should be repeated. This can be + integer, or a sympy expression. If sympy, the expression must + resolve to an integer, or float within 0.001 of integer, at + runtime. + qubit_map: Remappings for qubits in the circuit. + measurement_key_map: Remappings for measurement keys in the circuit. + The keys and values should be unindexed (i.e. without repetition_ids). + The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`. + param_resolver: Resolved values for parameters in the circuit. + repetition_ids: List of identifiers for each repetition of the + CircuitOperation. If populated, the length should be equal to the + repetitions. If not populated and abs(`repetitions`) > 1, it is + initialized to strings for numbers in `range(repetitions)`. + parent_path: A tuple of identifiers for any parent CircuitOperations + containing this one. + extern_keys: The set of measurement keys defined at extern scope. The + values here are used by decomposition and simulation routines to + cache which external measurement keys exist as possible binding + targets for unbound `ClassicallyControlledOperation` keys. This + field is not intended to be set or changed manually, and should be + empty in circuits that aren't in the middle of decomposition. + use_repetition_ids: When True, any measurement key in the subcircuit + will have its path prepended with the repetition id for each + repetition. When False, this will not happen and the measurement + key will be repeated. + repeat_until: A condition that will be tested after each iteration of + the subcircuit. The subcircuit will repeat until condition returns + True, but will always run at least once, and the measurement key + need not be defined prior to the subcircuit (but must be defined in + a measurement within the subcircuit). This field is incompatible + with repetitions or repetition_ids. + + Raises: + TypeError: if repetitions is not an integer or sympy expression, or if + the provided circuit is not a FrozenCircuit. + ValueError: if any of the following conditions is met. + - Negative repetitions on non-invertible circuit + - Number of repetition IDs does not match repetitions + - Repetition IDs used with parameterized repetitions + - Conflicting qubit dimensions in qubit_map + - Measurement key map has invalid key names + - repeat_until used with other repetition controls + - Key(s) in repeat_until are not modified by circuit + """ + # This fields is exclusively for use in decomposition. It should not be + # referenced outside this class. + self._extern_keys = extern_keys + + # All other fields are pseudo-private: read access is allowed via the + # @property methods, but mutation is prohibited. + self._param_resolver = study.ParamResolver(param_resolver) + self._parent_path = parent_path + + self._circuit = circuit + if not isinstance(self._circuit, circuits.FrozenCircuit): + raise TypeError(f'Expected circuit of type FrozenCircuit, got: {type(self._circuit)!r}') # Ensure that the circuit is invertible if the repetitions are negative. - if isinstance(self.repetitions, float): - if math.isclose(self.repetitions, round(self.repetitions)): - object.__setattr__(self, 'repetitions', round(self.repetitions)) - if isinstance(self.repetitions, INT_CLASSES): - if self.repetitions < 0: + self._repetitions = repetitions + self._repetition_ids = None if repetition_ids is None else list(repetition_ids) + self._use_repetition_ids = use_repetition_ids + if isinstance(self._repetitions, float): + if math.isclose(self._repetitions, round(self._repetitions)): + self._repetitions = round(self._repetitions) + if isinstance(self._repetitions, INT_CLASSES): + if self._repetitions < 0: try: - protocols.inverse(self.circuit.unfreeze()) + protocols.inverse(self._circuit.unfreeze()) except TypeError: raise ValueError('repetitions are negative but the circuit is not invertible') # Initialize repetition_ids to default, if unspecified. Else, validate their length. - loop_size = abs(self.repetitions) - if not self.repetition_ids: - object.__setattr__(self, 'repetition_ids', self._default_repetition_ids()) - elif len(self.repetition_ids) != loop_size: + loop_size = abs(self._repetitions) + if not self._repetition_ids: + self._repetition_ids = self._default_repetition_ids() + elif len(self._repetition_ids) != loop_size: raise ValueError( f'Expected repetition_ids to be a list of length {loop_size}, ' - f'got: {self.repetition_ids}' + f'got: {self._repetition_ids}' ) - elif isinstance(self.repetitions, sympy.Expr): - if self.repetition_ids is not None: + elif isinstance(self._repetitions, sympy.Expr): + if self._repetition_ids is not None: raise ValueError('Cannot use repetition ids with parameterized repetitions') else: raise TypeError( f'Only integer or sympy repetitions are allowed.\n' - f'User provided: {self.repetitions}' + f'User provided: {self._repetitions}' ) + # Disallow qid mapping dimension conflicts. + self._qubit_map = dict(qubit_map or {}) + for q, q_new in self._qubit_map.items(): + if q_new.dimension != q.dimension: + raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') + + self._measurement_key_map = dict(measurement_key_map or {}) # Disallow mapping to keys containing the `MEASUREMENT_KEY_SEPARATOR` - for mapped_key in self.measurement_key_map.values(): + for mapped_key in self._measurement_key_map.values(): if value.MEASUREMENT_KEY_SEPARATOR in mapped_key: raise ValueError( f'Mapping to invalid key: {mapped_key}. "{value.MEASUREMENT_KEY_SEPARATOR}" ' 'is not allowed for measurement keys in a CircuitOperation' ) - # Disallow qid mapping dimension conflicts. - for q, q_new in self.qubit_map.items(): - if q_new.dimension != q.dimension: - raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') - - if self.repeat_until: - if self.use_repetition_ids or self.repetitions != 1: + self._repeat_until = repeat_until + if self._repeat_until: + if self._use_repetition_ids or self._repetitions != 1: raise ValueError('Cannot use repetitions with repeat_until') if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint( - self.repeat_until.keys + self._repeat_until.keys ): raise ValueError('Infinite loop: condition is not modified in subcircuit.') - # Ensure that param_resolver is converted to an actual ParamResolver. - object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver)) + @property + def circuit(self) -> 'cirq.FrozenCircuit': + return self._circuit + + @property + def repetitions(self) -> IntParam: + return self._repetitions + + @property + def repetition_ids(self) -> Optional[Sequence[str]]: + return self._repetition_ids + + @property + def use_repetition_ids(self) -> bool: + return self._use_repetition_ids + + @property + def repeat_until(self) -> Optional['cirq.Condition']: + return self._repeat_until + + @property + def qubit_map(self) -> Mapping['cirq.Qid', 'cirq.Qid']: + return self._qubit_map + + @property + def measurement_key_map(self) -> Mapping[str, str]: + return self._measurement_key_map + + @property + def param_resolver(self) -> study.ParamResolver: + return self._param_resolver + + @property + def parent_path(self) -> Tuple[str, ...]: + return self._parent_path def base_operation(self) -> 'cirq.CircuitOperation': """Returns a copy of this operation with only the wrapped circuit. @@ -198,7 +250,20 @@ def base_operation(self) -> 'cirq.CircuitOperation': def replace(self, **changes) -> 'cirq.CircuitOperation': """Returns a copy of this operation with the specified changes.""" - return dataclasses.replace(self, **changes) + kwargs = { + 'circuit': self.circuit, + 'repetitions': self.repetitions, + 'qubit_map': self.qubit_map, + 'measurement_key_map': self.measurement_key_map, + 'param_resolver': self.param_resolver, + 'repetition_ids': self.repetition_ids, + 'parent_path': self.parent_path, + 'extern_keys': self._extern_keys, + 'use_repetition_ids': self.use_repetition_ids, + 'repeat_until': self.repeat_until, + **changes, + } + return CircuitOperation(**kwargs) # type: ignore def __eq__(self, other) -> bool: if not isinstance(other, type(self)): @@ -243,42 +308,42 @@ def _ensure_deterministic_loop_count(self): if self.repeat_until or isinstance(self.repetitions, sympy.Expr): raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') - def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: - if self._cached_measurement_key_objs is None: - circuit_keys = protocols.measurement_key_objs(self.circuit) - if circuit_keys and self.use_repetition_ids: - self._ensure_deterministic_loop_count() - if self.repetition_ids is not None: - circuit_keys = { - key.with_key_path_prefix(repetition_id) - for repetition_id in self.repetition_ids - for key in circuit_keys - } - circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} - object.__setattr__( - self, - '_cached_measurement_key_objs', - { - protocols.with_measurement_key_mapping(key, self.measurement_key_map) + @cached_property + def _measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: + circuit_keys = protocols.measurement_key_objs(self.circuit) + if circuit_keys and self.use_repetition_ids: + self._ensure_deterministic_loop_count() + if self.repetition_ids is not None: + circuit_keys = { + key.with_key_path_prefix(repetition_id) + for repetition_id in self.repetition_ids for key in circuit_keys - }, - ) - return self._cached_measurement_key_objs # type: ignore + } + circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} + return { + protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map)) + for key in circuit_keys + } + + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: + return self._measurement_key_objs def _measurement_key_names_(self) -> AbstractSet[str]: return {str(key) for key in self._measurement_key_objs_()} + @cached_property + def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']: + keys = ( + frozenset() + if not protocols.control_keys(self.circuit) + else protocols.control_keys(self._mapped_single_loop()) + ) + if self.repeat_until is not None: + keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() + return keys + def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: - if self._cached_control_keys is None: - keys = ( - frozenset() - if not protocols.control_keys(self.circuit) - else protocols.control_keys(self._mapped_single_loop()) - ) - if self.repeat_until is not None: - keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() - object.__setattr__(self, '_cached_control_keys', keys) - return self._cached_control_keys # type: ignore + return self._control_keys def _is_parameterized_(self) -> bool: return any(self._parameter_names_generator()) @@ -294,25 +359,27 @@ def _parameter_names_generator(self) -> Iterator[str]: ): yield name + @cached_property + def _mapped_any_loop(self) -> 'cirq.Circuit': + circuit = self.circuit.unfreeze() + if self.qubit_map: + circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) + if isinstance(self.repetitions, INT_CLASSES) and self.repetitions < 0: + circuit = circuit**-1 + if self.measurement_key_map: + circuit = protocols.with_measurement_key_mapping( + circuit, dict(self.measurement_key_map) + ) + if self.param_resolver: + circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) + return circuit.unfreeze(copy=False) + def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': - if self._cached_mapped_single_loop is None: - circuit = self.circuit.unfreeze() - if self.qubit_map: - circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) - if isinstance(self.repetitions, INT_CLASSES) and self.repetitions < 0: - circuit = circuit**-1 - if self.measurement_key_map: - circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) - if self.param_resolver: - circuit = protocols.resolve_parameters( - circuit, self.param_resolver, recursive=False - ) - object.__setattr__(self, '_cached_mapped_single_loop', circuit) - circuit = cast(circuits.Circuit, self._cached_mapped_single_loop) + circuit = self._mapped_any_loop if repetition_id: circuit = protocols.with_rescoped_keys(circuit, (repetition_id,)) return protocols.with_rescoped_keys( - circuit, self.parent_path, bindable_keys=self.extern_keys + circuit, self.parent_path, bindable_keys=self._extern_keys ) def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': @@ -422,24 +489,22 @@ def dict_str(d: Dict) -> str: return circuit_msg return f'{circuit_msg}({", ".join(args)})' - def __hash__(self): - if self._hash is None: - object.__setattr__( - self, - '_hash', - hash( - ( - self.circuit, - self.repetitions, - frozenset(self.qubit_map.items()), - frozenset(self.measurement_key_map.items()), - self.param_resolver, - self.parent_path, - tuple([] if self.repetition_ids is None else self.repetition_ids), - self.use_repetition_ids, - ) - ), + @cached_property + def _hash(self) -> int: + return hash( + ( + self.circuit, + self.repetitions, + frozenset(self.qubit_map.items()), + frozenset(self.measurement_key_map.items()), + self.param_resolver, + self.parent_path, + () if self.repetition_ids is None else tuple(self.repetition_ids), + self.use_repetition_ids, ) + ) + + def __hash__(self) -> int: return self._hash def _json_dict_(self): @@ -489,7 +554,7 @@ def _from_json_dict_( # Methods for constructing a similar object with one field modified. def repeat( - self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[List[str]] = None + self, repetitions: Optional[IntParam] = None, repetition_ids: Optional[Sequence[str]] = None ) -> 'CircuitOperation': """Returns a copy of this operation repeated 'repetitions' times. Each repetition instance will be identified by a single repetition_id. @@ -546,10 +611,10 @@ def __pow__(self, power: IntParam) -> 'cirq.CircuitOperation': return self.repeat(power) def _with_key_path_(self, path: Tuple[str, ...]): - return dataclasses.replace(self, parent_path=path) + return self.replace(parent_path=path) def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): - return dataclasses.replace(self, parent_path=prefix + self.parent_path) + return self.replace(parent_path=prefix + self.parent_path) def _with_rescoped_keys_( self, path: Tuple[str, ...], bindable_keys: FrozenSet['cirq.MeasurementKey'] @@ -560,9 +625,9 @@ def _with_rescoped_keys_( # the subcircuit having some 'allow_cross_circuit_binding' field set), this is the line to # change or remove. bindable_keys = frozenset(k for k in bindable_keys if len(k.path) <= len(path)) - bindable_keys |= {k.with_key_path_prefix(*path) for k in self.extern_keys} + bindable_keys |= {k.with_key_path_prefix(*path) for k in self._extern_keys} path += self.parent_path - return dataclasses.replace(self, parent_path=path, extern_keys=bindable_keys) + return self.replace(parent_path=path, extern_keys=bindable_keys) def with_key_path(self, path: Tuple[str, ...]): return self._with_key_path_(path) @@ -571,7 +636,7 @@ def with_repetition_ids(self, repetition_ids: List[str]) -> 'cirq.CircuitOperati return self.replace(repetition_ids=repetition_ids) def with_qubit_mapping( - self, qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']] + self, qubit_map: Union[Mapping['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']] ) -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated qubit mapping. @@ -631,7 +696,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.CircuitOperation': raise ValueError(f'Expected {expected} qubits, got {len(new_qubits)}.') return self.with_qubit_mapping(dict(zip(self.qubits, new_qubits))) - def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation': + def with_measurement_key_mapping(self, key_map: Mapping[str, str]) -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated key mapping. Args: @@ -665,7 +730,7 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'cirq.Circuit ) return new_op - def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation': + def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]) -> 'cirq.CircuitOperation': return self.with_measurement_key_mapping(key_map) def with_params( diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index e708202c46c..c8cc0be2e70 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -994,8 +994,10 @@ def test_keys_under_parent_path(): assert cirq.measurement_key_names(op1) == {'A'} op2 = op1.with_key_path(('B',)) assert cirq.measurement_key_names(op2) == {'B:A'} - op3 = op2.repeat(2) - assert cirq.measurement_key_names(op3) == {'B:0:A', 'B:1:A'} + op3 = cirq.with_key_path_prefix(op2, ('C',)) + assert cirq.measurement_key_names(op3) == {'C:B:A'} + op4 = op3.repeat(2) + assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'} def test_mapped_circuit_preserves_moments(): diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 6e1828074aa..24ed13ef904 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -185,6 +185,7 @@ 'TParamValComplex', 'TRANSFORMER', 'ParamDictType', + 'ParamMappingType', # utility: 'CliffordSimulator', 'NoiseModelFromNoiseProperties', diff --git a/cirq-core/cirq/study/__init__.py b/cirq-core/cirq/study/__init__.py index b1cd838a1b5..61b88fc83cc 100644 --- a/cirq-core/cirq/study/__init__.py +++ b/cirq-core/cirq/study/__init__.py @@ -21,7 +21,12 @@ flatten_with_sweep, ) -from cirq.study.resolver import ParamDictType, ParamResolver, ParamResolverOrSimilarType +from cirq.study.resolver import ( + ParamDictType, + ParamMappingType, + ParamResolver, + ParamResolverOrSimilarType, +) from cirq.study.sweepable import Sweepable, to_resolvers, to_sweep, to_sweeps diff --git a/cirq-core/cirq/study/flatten_expressions.py b/cirq-core/cirq/study/flatten_expressions.py index 2f82105190f..98014f6a2d9 100644 --- a/cirq-core/cirq/study/flatten_expressions.py +++ b/cirq-core/cirq/study/flatten_expressions.py @@ -278,7 +278,7 @@ def value_of( return out # Create a new symbol symbol = self._next_symbol(value) - self.param_dict[value] = symbol + self._param_dict[value] = symbol self._taken_symbols.add(symbol) return symbol @@ -292,9 +292,9 @@ def __bool__(self) -> bool: def __repr__(self) -> str: if self.get_param_name == self.default_get_param_name: - return f'_ParamFlattener({self.param_dict!r})' + return f'_ParamFlattener({self._param_dict!r})' else: - return f'_ParamFlattener({self.param_dict!r}, get_param_name={self.get_param_name!r})' + return f'_ParamFlattener({self._param_dict!r}, get_param_name={self.get_param_name!r})' def flatten(self, val: Any) -> Any: """Returns a copy of `val` with any symbols or expressions replaced with diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index 4ea4173ff82..fe8464cdbc3 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -14,7 +14,7 @@ """Resolves ParameterValues to assigned values.""" import numbers -from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast +from typing import Any, Dict, Iterator, Mapping, Optional, TYPE_CHECKING, Union, cast import numpy as np import sympy @@ -27,9 +27,11 @@ ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamValComplex'] +ParamMappingType = Mapping['cirq.TParamKey', 'cirq.TParamValComplex'] document(ParamDictType, """Dictionary from symbols to values.""") # type: ignore +document(ParamMappingType, """Immutable map from symbols to values.""") # type: ignore -ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamDictType, None] +ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamMappingType, None] document( ParamResolverOrSimilarType, # type: ignore """Something that can be used to turn parameters into values.""", @@ -70,12 +72,16 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None return # Already initialized. Got wrapped as part of the __new__. self._param_hash: Optional[int] = None - self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) + self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict) for key in self.param_dict: if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol): raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})') self._deep_eval_map: ParamDictType = {} + @property + def param_dict(self) -> ParamMappingType: + return self._param_dict + def value_of( self, value: Union['cirq.TParamKey', 'cirq.TParamValComplex'], recursive: bool = True ) -> 'cirq.TParamValComplex': diff --git a/cirq-core/cirq/work/observable_measurement.py b/cirq-core/cirq/work/observable_measurement.py index 5a8ccc7a900..f31f3ab59be 100644 --- a/cirq-core/cirq/work/observable_measurement.py +++ b/cirq-core/cirq/work/observable_measurement.py @@ -531,7 +531,7 @@ def measure_grouped_settings( for max_setting, param_resolver in itertools.product( grouped_settings.keys(), study.to_resolvers(circuit_sweep) ): - circuit_params = param_resolver.param_dict + circuit_params = dict(param_resolver.param_dict) meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params) accumulator = BitstringAccumulator( meas_spec=meas_spec, diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index 20423cc2bef..2184cad4687 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -353,7 +353,7 @@ def sample_expectation_values( # Flatten Circuit Sweep into one big list of Params. # Keep track of their indices so we can map back. flat_params: List['cirq.ParamDictType'] = [ - pr.param_dict for pr in study.to_resolvers(params) + dict(pr.param_dict) for pr in study.to_resolvers(params) ] circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = { _hashable_param(param.items()): i for i, param in enumerate(flat_params) diff --git a/cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py b/cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py index b32052d70d7..464662ae757 100644 --- a/cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py +++ b/cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py @@ -97,7 +97,7 @@ def _get_param_dict(resolver: cirq.ParamResolverOrSimilarType) -> Dict[Union[str """ param_dict: Dict[Union[str, sympy.Expr], Any] = {} if isinstance(resolver, cirq.ParamResolver): - param_dict = resolver.param_dict + param_dict = dict(resolver.param_dict) elif isinstance(resolver, dict): param_dict = resolver return param_dict diff --git a/cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py b/cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py index 40075fdae44..2a0a3492d8f 100644 --- a/cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py +++ b/cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py @@ -60,7 +60,7 @@ def test_with_quilc_parametric_compilation( param_resolvers: List[Union[cirq.ParamResolver, cirq.ParamDictType]] if pass_dict: - param_resolvers = [params.param_dict for params in sweepable] + param_resolvers = [dict(params.param_dict) for params in sweepable] else: param_resolvers = [r for r in cirq.to_resolvers(sweepable)] expected_results = [