Skip to content

Commit

Permalink
Optimize ParamResolver.value_of (#6341)
Browse files Browse the repository at this point in the history
Review: @dstrain115
  • Loading branch information
maffoo committed Nov 15, 2023
1 parent 392083b commit 8d07cab
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 59 deletions.
7 changes: 2 additions & 5 deletions cirq-core/cirq/sim/simulator_test.py
Expand Up @@ -134,8 +134,7 @@ def steps(*args, **kwargs):

simulator.simulate_moment_steps.side_effect = steps
circuit = mock.Mock(cirq.Circuit)
param_resolver = mock.Mock(cirq.ParamResolver)
param_resolver.param_dict = {}
param_resolver = cirq.ParamResolver({})
qubit_order = mock.Mock(cirq.QubitOrder)
result = simulator.simulate(
program=circuit, param_resolver=param_resolver, qubit_order=qubit_order, initial_state=2
Expand Down Expand Up @@ -163,9 +162,7 @@ def steps(*args, **kwargs):

simulator.simulate_moment_steps.side_effect = steps
circuit = mock.Mock(cirq.Circuit)
param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)]
for resolver in param_resolvers:
resolver.param_dict = {}
param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})]
qubit_order = mock.Mock(cirq.QubitOrder)
results = simulator.simulate_sweep(
program=circuit, params=param_resolvers, qubit_order=qubit_order, initial_state=2
Expand Down
82 changes: 41 additions & 41 deletions cirq-core/cirq/study/resolver.py
Expand Up @@ -36,8 +36,11 @@
ParamResolverOrSimilarType, """Something that can be used to turn parameters into values."""
)

# Used to mark values that are not found in a dict.
_NOT_FOUND = object()

# Used to mark values that are being resolved recursively to detect loops.
_RecursionFlag = object()
_RECURSION_FLAG = object()


def _is_param_resolver_or_similar_type(obj: Any):
Expand Down Expand Up @@ -72,7 +75,7 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None

self._param_hash: Optional[int] = None
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
for key in self._param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}
Expand Down Expand Up @@ -120,32 +123,30 @@ def value_of(
if v is not NotImplemented:
return v

# Handles 2 cases:
# Input is a string and maps to a number in the dictionary
# Input is a symbol and maps to a number in the dictionary
# In both cases, return it directly.
if value in self.param_dict:
# Note: if the value is in the dictionary, it will be a key type
# Add a cast to make mypy happy.
param_value = self.param_dict[cast('cirq.TParamKey', value)]
# Handle string or symbol
if isinstance(value, (str, sympy.Symbol)):
string = value if isinstance(value, str) else value.name
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value)
param_value = self._param_dict.get(string, _NOT_FOUND)
if param_value is _NOT_FOUND:
param_value = self._param_dict.get(symbol, _NOT_FOUND)
if param_value is _NOT_FOUND:
# Symbol or string cannot be resolved if not in param dict; return as symbol.
return symbol
v = _resolve_value(param_value)
if v is not NotImplemented:
return v
if isinstance(param_value, str):
param_value = sympy.Symbol(param_value)
elif not isinstance(param_value, sympy.Basic):
return value # type: ignore[return-value]
if recursive:
param_value = self._value_of_recursive(value)
return param_value # type: ignore[return-value]

# Input is a string and is not in the dictionary.
# Treat it as a symbol instead.
if isinstance(value, str):
# If the string is in the param_dict as a value, return it.
# Otherwise, try using the symbol instead.
return self.value_of(sympy.Symbol(value), recursive)

# Input is a symbol (sympy.Symbol('a')) and its string maps to a number
# in the dictionary ({'a': 1.0}). Return it.
if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
param_value = self.param_dict[value.name]
v = _resolve_value(param_value)
if v is not NotImplemented:
return v
if not isinstance(value, sympy.Basic):
# No known way to resolve this variable, return unchanged.
return value

# The following resolves common sympy expressions
# If sympy did its job and wasn't slower than molasses,
Expand All @@ -171,10 +172,6 @@ def value_of(
return np.float_power(cast(complex, base), cast(complex, exponent))
return np.power(cast(complex, base), cast(complex, exponent))

if not isinstance(value, sympy.Basic):
# No known way to resolve this variable, return unchanged.
return value

# Input is either a sympy formula or the dictionary maps to a
# formula. Use sympy to resolve the value.
# Note that sympy.subs() is slow, so we want to avoid this and
Expand All @@ -186,7 +183,7 @@ def value_of(
# Note that a sympy.SympifyError here likely means
# that one of the expressions was not parsable by sympy
# (such as a function returning NotImplemented)
v = value.subs(self.param_dict, simultaneous=True)
v = value.subs(self._param_dict, simultaneous=True)

if v.free_symbols:
return v
Expand All @@ -197,23 +194,26 @@ def value_of(
else:
return float(v)

return self._value_of_recursive(value)

def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex':
# Recursive parameter resolution. We can safely assume that value is a
# single symbol, since combinations are handled earlier in the method.
if value in self._deep_eval_map:
v = self._deep_eval_map[value]
if v is not _RecursionFlag:
return v
raise RecursionError('Evaluation of {value} indirectly contains itself.')
if v is _RECURSION_FLAG:
raise RecursionError('Evaluation of {value} indirectly contains itself.')
return v

# There isn't a full evaluation for 'value' yet. Until it's ready,
# map value to None to identify loops in component evaluation.
self._deep_eval_map[value] = _RecursionFlag # type: ignore
self._deep_eval_map[value] = _RECURSION_FLAG # type: ignore

v = self.value_of(value, recursive=False)
if v == value:
self._deep_eval_map[value] = v
else:
self._deep_eval_map[value] = self.value_of(v, recursive)
self._deep_eval_map[value] = self.value_of(v, recursive=True)
return self._deep_eval_map[value]

def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'ParamResolver':
Expand All @@ -224,17 +224,17 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P
new_dict.update(
{k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc]
)
if recursive and self.param_dict:
if recursive and self._param_dict:
new_resolver = ParamResolver(cast(ParamDictType, new_dict))
# Resolve down to single-step mappings.
return ParamResolver()._resolve_parameters_(new_resolver, recursive=True)
return ParamResolver(cast(ParamDictType, new_dict))

def __iter__(self) -> Iterator[Union[str, sympy.Expr]]:
return iter(self.param_dict)
return iter(self._param_dict)

def __bool__(self) -> bool:
return bool(self.param_dict)
return bool(self._param_dict)

def __getitem__(
self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex']
Expand All @@ -243,29 +243,29 @@ def __getitem__(

def __hash__(self) -> int:
if self._param_hash is None:
self._param_hash = hash(frozenset(self.param_dict.items()))
self._param_hash = hash(frozenset(self._param_dict.items()))
return self._param_hash

def __eq__(self, other):
if not isinstance(other, ParamResolver):
return NotImplemented
return self.param_dict == other.param_dict
return self._param_dict == other._param_dict

def __ne__(self, other):
return not self == other

def __repr__(self) -> str:
param_dict_repr = (
'{'
+ ', '.join([f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self.param_dict.items()])
+ ', '.join(f'{proper_repr(k)}: {proper_repr(v)}' for k, v in self._param_dict.items())
+ '}'
)
return f'cirq.ParamResolver({param_dict_repr})'

def _json_dict_(self) -> Dict[str, Any]:
return {
# JSON requires mappings to have keys of basic types.
'param_dict': list(self.param_dict.items())
'param_dict': list(self._param_dict.items())
}

@classmethod
Expand Down
25 changes: 13 additions & 12 deletions cirq-core/cirq/study/resolver_test.py
Expand Up @@ -53,10 +53,10 @@ def test_value_of_transformed_types(val, resolved):

@pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)])
def test_value_of_substituted_types(val, resolved):
_assert_consistent_resolution(val, resolved, True)
_assert_consistent_resolution(val, resolved)


def _assert_consistent_resolution(v, resolved, subs_called=False):
def _assert_consistent_resolution(v, resolved):
"""Asserts that parameter resolution works consistently.
The ParamResolver.value_of method can resolve any Sympy expression -
Expand All @@ -70,7 +70,7 @@ def _assert_consistent_resolution(v, resolved, subs_called=False):
Args:
v: the value to resolve
resolved: the expected resolution result
subs_called: if True, it is expected that the slow subs method is called
Raises:
AssertionError in case resolution assertion fail.
"""
Expand All @@ -93,9 +93,7 @@ def subs(self, *args, **kwargs):
# symbol based resolution
s = SubsAwareSymbol('a')
assert r.value_of(s) == resolved, f"expected {resolved}, got {r.value_of(s)}"
assert (
subs_called == s.called
), f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
assert not s.called, f"For pass-through type {type(v)} sympy.subs shouldn't have been called."
assert isinstance(
r.value_of(s), type(resolved)
), f"expected {type(resolved)} got {type(r.value_of(s))}"
Expand Down Expand Up @@ -243,15 +241,18 @@ def _resolved_value_(self):


def test_custom_value_not_implemented():
class Bar:
class BarImplicit:
pass

class BarExplicit:
def _resolved_value_(self):
return NotImplemented

b = sympy.Symbol('b')
bar = Bar()
r = cirq.ParamResolver({b: bar})
with pytest.raises(sympy.SympifyError):
_ = r.value_of(b)
for cls in [BarImplicit, BarExplicit]:
b = sympy.Symbol('b')
bar = cls()
r = cirq.ParamResolver({b: bar})
assert r.value_of(b) == b


def test_compose():
Expand Down
2 changes: 1 addition & 1 deletion cirq-ionq/cirq_ionq/sampler_test.py
Expand Up @@ -100,7 +100,7 @@ def test_sampler_multiple_jobs():
results = sampler.sample(
program=circuit,
repetitions=4,
params=[cirq.ParamResolver({x: '0.5'}), cirq.ParamResolver({x: '0.6'})],
params=[cirq.ParamResolver({x: 0.5}), cirq.ParamResolver({x: 0.6})],
)
pd.testing.assert_frame_equal(
results,
Expand Down

0 comments on commit 8d07cab

Please sign in to comment.