Skip to content

Commit

Permalink
Backwards-compatibility behavior for resolve_parameters. (#3719)
Browse files Browse the repository at this point in the history
Fixes #3714.
  • Loading branch information
95-martin-orion committed Feb 5, 2021
1 parent abfa2af commit a163a39
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
16 changes: 15 additions & 1 deletion cirq/protocols/resolve_parameters.py
Expand Up @@ -142,6 +142,8 @@ def resolve_parameters(
Raises:
RecursionError if the ParamResolver detects a loop in resolution.
ValueError if `recursive=False` is passed to an external
_resolve_parameters_ method with no `recursive` parameter.
"""
if not param_resolver:
return val
Expand All @@ -154,7 +156,19 @@ def resolve_parameters(
return type(val)(resolve_parameters(e, param_resolver, recursive) for e in val)

getter = getattr(val, '_resolve_parameters_', None)
result = NotImplemented if getter is None else getter(param_resolver, recursive)
if getter is None:
result = NotImplemented
# Backwards-compatibility for external _resolve_parameters_ methods.
# TODO: remove in Cirq v0.11.0
elif 'recursive' in getter.__code__.co_varnames:
result = getter(param_resolver, recursive)
else:
if not recursive:
raise ValueError(
f'Object type {type(val)} does not support non-recursive parameter resolution.'
' This must be updated before Cirq v0.11.'
)
result = getter(param_resolver)

if result is not NotImplemented:
return result
Expand Down
19 changes: 19 additions & 0 deletions cirq/protocols/resolve_parameters_test.py
Expand Up @@ -143,3 +143,22 @@ def test_recursive_resolve():
assert cirq.resolve_parameters_once(a, resolver) == b
with pytest.raises(RecursionError):
_ = cirq.resolve_parameters(a, resolver)


# TODO: remove in Cirq v0.11
def test_backwards_compatible():
a, b, c = [sympy.Symbol(l) for l in 'abc']
resolver = cirq.ParamResolver({a: b + 3, b: c + 2, c: 1})

class SymbolSum:
def __init__(self, *syms):
self.syms = [*syms]

def _resolve_parameters_(self, pr):
return sum([cirq.resolve_parameters(s, pr) for s in self.syms])

ssum = SymbolSum(a, b, c)
assert cirq.resolve_parameters(ssum, resolver) == 10
assert cirq.resolve_parameters(ssum, resolver, recursive=True) == 10
with pytest.raises(ValueError):
_ = cirq.resolve_parameters(ssum, resolver, recursive=False)

0 comments on commit a163a39

Please sign in to comment.