Skip to content

Commit

Permalink
Lock down CircuitOperation and ParamResolver (#5548)
Browse files Browse the repository at this point in the history
* Lock down CircuitOperation attributes.

* Reduce attribute lockdown

* Resolve type conflicts

* review comments

* docs and defensive copies

* document error modes

Co-authored-by: Cirq Bot <craiggidney+github+cirqbot@google.com>
  • Loading branch information
95-martin-orion and CirqBot committed Jun 22, 2022
1 parent dd4c0a2 commit 61fefe6
Show file tree
Hide file tree
Showing 11 changed files with 254 additions and 174 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -514,6 +514,7 @@
Linspace,
ListSweep,
ParamDictType,
ParamMappingType,
ParamResolver,
ParamResolverOrSimilarType,
Points,
Expand Down
387 changes: 226 additions & 161 deletions cirq-core/cirq/circuits/circuit_operation.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Expand Up @@ -994,8 +994,10 @@ def test_keys_under_parent_path():
assert cirq.measurement_key_names(op1) == {'A'}
op2 = op1.with_key_path(('B',))
assert cirq.measurement_key_names(op2) == {'B:A'}
op3 = op2.repeat(2)
assert cirq.measurement_key_names(op3) == {'B:0:A', 'B:1:A'}
op3 = cirq.with_key_path_prefix(op2, ('C',))
assert cirq.measurement_key_names(op3) == {'C:B:A'}
op4 = op3.repeat(2)
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}


def test_mapped_circuit_preserves_moments():
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Expand Up @@ -185,6 +185,7 @@
'TParamValComplex',
'TRANSFORMER',
'ParamDictType',
'ParamMappingType',
# utility:
'CliffordSimulator',
'NoiseModelFromNoiseProperties',
Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/study/__init__.py
Expand Up @@ -21,7 +21,12 @@
flatten_with_sweep,
)

from cirq.study.resolver import ParamDictType, ParamResolver, ParamResolverOrSimilarType
from cirq.study.resolver import (
ParamDictType,
ParamMappingType,
ParamResolver,
ParamResolverOrSimilarType,
)

from cirq.study.sweepable import Sweepable, to_resolvers, to_sweep, to_sweeps

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/study/flatten_expressions.py
Expand Up @@ -278,7 +278,7 @@ def value_of(
return out
# Create a new symbol
symbol = self._next_symbol(value)
self.param_dict[value] = symbol
self._param_dict[value] = symbol
self._taken_symbols.add(symbol)
return symbol

Expand All @@ -292,9 +292,9 @@ def __bool__(self) -> bool:

def __repr__(self) -> str:
if self.get_param_name == self.default_get_param_name:
return f'_ParamFlattener({self.param_dict!r})'
return f'_ParamFlattener({self._param_dict!r})'
else:
return f'_ParamFlattener({self.param_dict!r}, get_param_name={self.get_param_name!r})'
return f'_ParamFlattener({self._param_dict!r}, get_param_name={self.get_param_name!r})'

def flatten(self, val: Any) -> Any:
"""Returns a copy of `val` with any symbols or expressions replaced with
Expand Down
12 changes: 9 additions & 3 deletions cirq-core/cirq/study/resolver.py
Expand Up @@ -14,7 +14,7 @@

"""Resolves ParameterValues to assigned values."""
import numbers
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast
from typing import Any, Dict, Iterator, Mapping, Optional, TYPE_CHECKING, Union, cast

import numpy as np
import sympy
Expand All @@ -27,9 +27,11 @@


ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamValComplex']
ParamMappingType = Mapping['cirq.TParamKey', 'cirq.TParamValComplex']
document(ParamDictType, """Dictionary from symbols to values.""") # type: ignore
document(ParamMappingType, """Immutable map from symbols to values.""") # type: ignore

ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamDictType, None]
ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamMappingType, None]
document(
ParamResolverOrSimilarType, # type: ignore
"""Something that can be used to turn parameters into values.""",
Expand Down Expand Up @@ -70,12 +72,16 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
return # Already initialized. Got wrapped as part of the __new__.

self._param_hash: Optional[int] = None
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}

@property
def param_dict(self) -> ParamMappingType:
return self._param_dict

def value_of(
self, value: Union['cirq.TParamKey', 'cirq.TParamValComplex'], recursive: bool = True
) -> 'cirq.TParamValComplex':
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/work/observable_measurement.py
Expand Up @@ -531,7 +531,7 @@ def measure_grouped_settings(
for max_setting, param_resolver in itertools.product(
grouped_settings.keys(), study.to_resolvers(circuit_sweep)
):
circuit_params = param_resolver.param_dict
circuit_params = dict(param_resolver.param_dict)
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
accumulator = BitstringAccumulator(
meas_spec=meas_spec,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/work/sampler.py
Expand Up @@ -353,7 +353,7 @@ def sample_expectation_values(
# Flatten Circuit Sweep into one big list of Params.
# Keep track of their indices so we can map back.
flat_params: List['cirq.ParamDictType'] = [
pr.param_dict for pr in study.to_resolvers(params)
dict(pr.param_dict) for pr in study.to_resolvers(params)
]
circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = {
_hashable_param(param.items()): i for i, param in enumerate(flat_params)
Expand Down
2 changes: 1 addition & 1 deletion cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py
Expand Up @@ -97,7 +97,7 @@ def _get_param_dict(resolver: cirq.ParamResolverOrSimilarType) -> Dict[Union[str
"""
param_dict: Dict[Union[str, sympy.Expr], Any] = {}
if isinstance(resolver, cirq.ParamResolver):
param_dict = resolver.param_dict
param_dict = dict(resolver.param_dict)
elif isinstance(resolver, dict):
param_dict = resolver
return param_dict
Expand Down
2 changes: 1 addition & 1 deletion cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py
Expand Up @@ -60,7 +60,7 @@ def test_with_quilc_parametric_compilation(

param_resolvers: List[Union[cirq.ParamResolver, cirq.ParamDictType]]
if pass_dict:
param_resolvers = [params.param_dict for params in sweepable]
param_resolvers = [dict(params.param_dict) for params in sweepable]
else:
param_resolvers = [r for r in cirq.to_resolvers(sweepable)]
expected_results = [
Expand Down

0 comments on commit 61fefe6

Please sign in to comment.