-
Notifications
You must be signed in to change notification settings - Fork 989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow custom definition for _sympy_pass_through #3921
Allow custom definition for _sympy_pass_through #3921
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems reasonable to me. Once we settle on the design, we should update the value_of
docs (and probably the resolve_parameters
protocol docs) to describe this mechanism and how it can be used on custom types that one might want to use when resolving parameters.
cirq/study/resolver.py
Outdated
@@ -244,4 +244,9 @@ def _sympy_pass_through(val: Any) -> Optional[Any]: | |||
return val.p / val.q | |||
if val == sympy.pi: | |||
return np.pi | |||
|
|||
getter = getattr(val, '_sympy_pass_through_', None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it'd be helpful to change the name, because the point of defining a special method is to broaden this beyond the hard-coded sympy handling. In addition, the code above allows transforming values, not just passing them through (e.g. sympy integers become plain old integers). So maybe we could call the special method something like _resolver_value_
and if it exists we call it to get a value to pass through, which could be different from val
itself. The method could return NotImplemented
to indicate that the value should not be passed through. We could also change this function itself to use NotImplemented
as the sentinel value instead of None
if we want to allow things to resolve to None
(the function is only called in value_of
above, so that change wouldn't affect any other code).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also change this function itself to use NotImplemented as the sentinel value instead of None if we want to allow things to resolve to None (the function is only called in value_of above, so that change wouldn't affect any other code).
I tried this but it turns out that we're relying on None
as a sentinel in some subtle way that I couldn't quite unravel, probably because of this line: https://github.com/quantumlib/Cirq/blob/master/cirq/study/resolver.py#L175
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That None
was me :p A valid workaround would be to replace None
with a file-local class used only for flagging potential loops in recursive parameter evaluation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm...I guess replacing wouldn't resolve this entirely, since None
is checked for earlier in value_of
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@95-martin-orion thanks for the hint, got all the tests to pass now, check if this is what you had in mind?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd put a leading underscore on RecursionFlag
just to make it abundantly clear that it's a non-public object, but otherwise this looks right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to leading underscore. I'd also suggest adding a comment that this is used to mark values that are being resolved recursively to detect loops.
cirq/study/resolver_test.py
Outdated
|
||
class Bar: | ||
def _sympy_pass_through_(self): | ||
return False | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe add a case here that has _resolver_value_
method that returns NotImplemented
?
Our requirement for "magic methods" (i.e. Cirq/cirq/protocols/resolve_parameters.py Line 28 in 8cf7825
To follow the same rule for |
@95-martin-orion, the problem with adding this in |
Protocols do not require top-level methods - see Cirq/cirq/protocols/json_serialization.py Line 330 in 06e4892
|
cirq/study/resolver.py
Outdated
@@ -237,6 +237,16 @@ def _from_json_dict_(cls, param_dict, **kwargs): | |||
return cls(dict(param_dict)) | |||
|
|||
|
|||
class ResolvableValue(Protocol): | |||
@doc_private | |||
def _resolver_value_(self) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should call it _resolved_value_
to match what's in the docstring below (I think that explanation is pretty clear). Also, I'd suggest moving this protocol into resolve_parameters.py
alongside the other protocols that are related to parameter resolution, so we can document the whole process there.
cirq/study/resolver.py
Outdated
further parsing like we do with Sympy symbols. | ||
""" | ||
|
||
|
||
def _sympy_pass_through(val: Any) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename this _resolve_value
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One remaining nit, then this looks good to go.
cirq/study/resolver_test.py
Outdated
@@ -230,6 +230,32 @@ def test_resolve_unknown_type(): | |||
assert r.value_of(cirq.X) == cirq.X | |||
|
|||
|
|||
def test_custom_sympy_pass_through(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_custom_sympy_pass_through
-> test_custom_resolved_value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Fixes #3916