Skip to content

Commit

Permalink
Allow custom definition for _sympy_pass_through (#3921)
Browse files Browse the repository at this point in the history
  • Loading branch information
zchen088 committed Mar 17, 2021
1 parent 229b204 commit 38cb2f2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 10 deletions.
10 changes: 10 additions & 0 deletions cirq/protocols/resolve_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def _resolve_parameters_(self: Any, param_resolver: 'cirq.ParamResolver', recurs
"""Resolve the parameters in the effect."""


class ResolvableValue(Protocol):
@doc_private
def _resolved_value_(self) -> Any:
"""Returns a resolved value during parameter resolution.
Use this to mark a custom type as "resolved", instead of requiring
further parsing like we do with Sympy symbols.
"""


def is_parameterized(val: Any) -> bool:
"""Returns whether the object is parameterized with any Symbols.
Expand Down
27 changes: 17 additions & 10 deletions cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Resolves ParameterValues to assigned values."""
import numbers
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast

import numpy as np
import sympy
from sympy.core import numbers as sympy_numbers
Expand All @@ -34,6 +35,9 @@
"""Something that can be used to turn parameters into values.""",
)

# Used to mark values that are being resolved recursively to detect loops.
_RecursionFlag = object()


class ParamResolver:
"""Resolves parameters to actual values.
Expand Down Expand Up @@ -96,8 +100,8 @@ def value_of(
"""

# Input is a pass through type, no resolution needed: return early
v = _sympy_pass_through(value)
if v is not None:
v = _resolve_value(value)
if v is not NotImplemented:
return v

# Handles 2 cases:
Expand All @@ -106,8 +110,8 @@ def value_of(
# In both cases, return it directly.
if value in self.param_dict:
param_value = self.param_dict[value]
v = _sympy_pass_through(param_value)
if v is not None:
v = _resolve_value(param_value)
if v is not NotImplemented:
return v

# Input is a string and is not in the dictionary.
Expand All @@ -121,8 +125,8 @@ def value_of(
# in the dictionary ({'a': 1.0}). Return it.
if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
param_value = self.param_dict[value.name]
v = _sympy_pass_through(param_value)
if v is not None:
v = _resolve_value(param_value)
if v is not NotImplemented:
return v

# The following resolves common sympy expressions
Expand Down Expand Up @@ -166,13 +170,13 @@ def value_of(
# single symbol, since combinations are handled earlier in the method.
if value in self._deep_eval_map:
v = self._deep_eval_map[value]
if v is not None:
if v is not _RecursionFlag:
return v
raise RecursionError('Evaluation of {value} indirectly contains itself.')

# There isn't a full evaluation for 'value' yet. Until it's ready,
# map value to None to identify loops in component evaluation.
self._deep_eval_map[value] = None
self._deep_eval_map[value] = _RecursionFlag

v = self.value_of(value, recursive=False)
if v == value:
Expand Down Expand Up @@ -235,7 +239,7 @@ def _from_json_dict_(cls, param_dict, **kwargs):
return cls(dict(param_dict))


def _sympy_pass_through(val: Any) -> Optional[Any]:
def _resolve_value(val: Any) -> Any:
if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic):
return val
if isinstance(val, sympy_numbers.IntegerConstant):
Expand All @@ -244,4 +248,7 @@ def _sympy_pass_through(val: Any) -> Optional[Any]:
return val.p / val.q
if val == sympy.pi:
return np.pi
return None

getter = getattr(val, '_resolved_value_', None)
result = NotImplemented if getter is None else getter()
return result
26 changes: 26 additions & 0 deletions cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,32 @@ def test_resolve_unknown_type():
assert r.value_of(cirq.X) == cirq.X


def test_custom_resolved_value():
class Foo:
def _resolved_value_(self):
return self

class Bar:
def _resolved_value_(self):
return NotImplemented

class Baz:
def _resolved_value_(self):
return 'Baz'

foo = Foo()
bar = Bar()
baz = Baz()

a = sympy.Symbol('a')
b = sympy.Symbol('b')
c = sympy.Symbol('c')
r = cirq.ParamResolver({a: foo, b: bar, c: baz})
assert r.value_of(a) is foo
assert r.value_of(b) is b
assert r.value_of(c) == 'Baz'


def test_compose():
"""
Calling cirq.resolve_paramters on a ParamResolver composes that resolver
Expand Down

0 comments on commit 38cb2f2

Please sign in to comment.