/
resolve_parameters.py
159 lines (124 loc) · 5.39 KB
/
resolve_parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from typing import AbstractSet, Any, TypeVar, TYPE_CHECKING
from typing_extensions import Protocol
import sympy
from cirq import study
from cirq._doc import document
if TYPE_CHECKING:
import cirq
# Generic type variable. No definition visible in this module references it;
# presumably kept for default-value typing -- confirm usage before removing.
TDefault = TypeVar('TDefault')
class SupportsParameterization(Protocol):
    """An object that can be parameterized by Symbols and resolved
    via a ParamResolver.

    The magic methods declared here are the hooks looked up by the
    module-level helpers `is_parameterized`, `parameter_names` and
    `resolve_parameters`. Each helper treats a missing method, or a method
    that returns NotImplemented, as "not implemented" and applies its own
    fallback.
    """

    @document
    def _is_parameterized_(self: Any) -> bool:
        """Whether the object is parameterized by any Symbols that require
        resolution.

        Returns:
            True if the object has any unresolved Symbols and False
            otherwise.
        """

    @document
    def _parameter_names_(self: Any) -> AbstractSet[str]:
        """Returns a collection of string names of parameters that require
        resolution.

        If _is_parameterized_ is False, the collection is empty. The converse
        is not necessarily true, because some objects may report that they
        are parameterized when they contain symbolic constants which need to
        be evaluated, but no free symbols.

        Returns:
            The names of the parameters that still require resolution.
        """

    @document
    def _resolve_parameters_(self: Any, param_resolver: 'cirq.ParamResolver'):
        """Resolve the parameters in the effect.

        Args:
            param_resolver: The resolver used to look up symbol values.
                `resolve_parameters` always passes a `cirq.ParamResolver`.

        Returns:
            The resolved object. `resolve_parameters` hands this value back
            to its caller unchanged, unless it is NotImplemented, in which
            case the original object is returned instead.
        """
def is_parameterized(val: Any) -> bool:
    """Determines whether a value still depends on unresolved Symbols.

    The checks are applied in this order:

    1. Any `sympy.Basic` instance counts as parameterized.
    2. Any other `numbers.Number` is not parameterized.
    3. A list or tuple is parameterized if at least one element is.
    4. Otherwise the value's `_is_parameterized_` magic method decides.
    5. If that method is missing or returns NotImplemented, the value is
       parameterized exactly when `parameter_names(val)` is non-empty.

    Args:
        val: The object to inspect.

    Returns:
        True if the value has unresolved Symbols and False otherwise. A
        result produced by `_is_parameterized_` is passed through as-is.
    """
    # The sympy check must come first: it takes precedence over the
    # numbers.Number check for values that satisfy both.
    if isinstance(val, sympy.Basic):
        return True
    if isinstance(val, numbers.Number):
        return False

    if isinstance(val, (list, tuple)):
        for element in val:
            if is_parameterized(element):
                return True
        return False

    magic_method = getattr(val, '_is_parameterized_', None)
    if magic_method is not None:
        answer = magic_method()
        if answer is not NotImplemented:
            return answer

    # No usable magic method: infer the answer from the parameter names.
    return bool(parameter_names(val))
def parameter_names(val: Any) -> AbstractSet[str]:
    """Collects the names of the parameters that a value depends on.

    Args:
        val: Object for which to find the parameter names.

    Returns:
        For a `sympy.Basic` value, the names of its free symbols. For a
        number, an empty set. For a list or tuple, the union of the names
        found in its elements. For any other object, whatever its
        `_parameter_names_` magic method returns, or an empty set if that
        method is missing or returns NotImplemented.
    """
    if isinstance(val, sympy.Basic):
        return {free_symbol.name for free_symbol in val.free_symbols}
    if isinstance(val, numbers.Number):
        return set()

    if isinstance(val, (list, tuple)):
        collected = set()
        for element in val:
            collected.update(parameter_names(element))
        return collected

    magic_method = getattr(val, '_parameter_names_', None)
    if magic_method is None:
        return set()
    names = magic_method()
    return set() if names is NotImplemented else names
def parameter_symbols(val: Any) -> AbstractSet[sympy.Symbol]:
    """Collects the parameters of a value as sympy Symbols.

    The symbols are built from the names reported by `parameter_names`, so
    this function returns an empty set whenever `parameter_names` does.

    Args:
        val: Object for which to find the parameter symbols.

    Returns:
        A set with one `sympy.Symbol` per parameter name of `val`.
    """
    names = parameter_names(val)
    return set(map(sympy.Symbol, names))
def resolve_parameters(
        val: Any,
        param_resolver: 'cirq.ParamResolverOrSimilarType') -> Any:
    """Resolves symbol parameters in the effect using the param resolver.

    This function will use the `_resolve_parameters_` magic method
    of `val` to resolve any Symbols with concrete values from the given
    parameter resolver.

    Args:
        val: The object to resolve (e.g. the gate, operation, etc). Sympy
            expressions are resolved directly, and lists and tuples
            (including namedtuples) are resolved element by element.
        param_resolver: The object to use for resolving all symbols. A
            falsy resolver (e.g. None or an empty dict) means there is
            nothing to resolve.

    Returns:
        A gate or operation of the same type, but with all Symbols
        replaced with floats according to the given ParamResolver.
        If `param_resolver` is falsy, or if `val` has no
        `_resolve_parameters_` method or that method returns
        NotImplemented, `val` itself is returned.
    """
    if not param_resolver:
        return val

    # Ensure it's a dictionary wrapped in a ParamResolver.
    param_resolver = study.ParamResolver(param_resolver)
    if isinstance(val, sympy.Basic):
        return param_resolver.value_of(val)

    if isinstance(val, (list, tuple)):
        resolved = (resolve_parameters(e, param_resolver) for e in val)
        # namedtuple constructors take one positional argument per field, so
        # passing them a single iterable fails. Their `_make` classmethod is
        # the supported way to build an instance from an iterable.
        make = getattr(type(val), '_make', None)
        if make is not None:
            return make(resolved)
        return type(val)(resolved)

    getter = getattr(val, '_resolve_parameters_', None)
    result = NotImplemented if getter is None else getter(param_resolver)
    return val if result is NotImplemented else result