-
Notifications
You must be signed in to change notification settings - Fork 981
/
sweeps.py
157 lines (140 loc) · 6.5 KB
/
sweeps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Dict, List, Optional
import sympy
import cirq
from cirq_google.api.v2 import run_context_pb2
from cirq_google.study.device_parameter import DeviceParameter
def _add_parameter_metadata(metadata: object, out: run_context_pb2.Sweep) -> None:
    """Copies single-sweep metadata onto ``out.single_sweep.parameter``.

    Uses duck-typing (``getattr``) so that google-internal Parameter objects are
    supported alongside ``DeviceParameter``. A non-empty ``path``, a non-None
    ``idx`` and a non-empty ``units`` are copied; missing attributes are skipped.

    Args:
        metadata: The sweep's metadata object, or None.
        out: The Sweep message whose ``single_sweep.parameter`` is populated.
    """
    if not metadata:
        return
    path = getattr(metadata, 'path', None)
    if path:
        out.single_sweep.parameter.path.extend(path)
    idx = getattr(metadata, 'idx', None)
    # Compare against None rather than relying on truthiness: 0 is a valid index
    # and must be written explicitly, otherwise it is read back as unset (None).
    if idx is not None:
        out.single_sweep.parameter.idx = idx
    units = getattr(metadata, 'units', None)
    if units:
        out.single_sweep.parameter.units = units


def sweep_to_proto(
    sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None
) -> run_context_pb2.Sweep:
    """Converts a Sweep to v2 protobuf message.

    Supported inputs are ``cirq.UnitSweep`` (left as an empty message),
    ``cirq.Product`` and ``cirq.Zip`` (encoded recursively as sweep functions),
    ``cirq.Linspace`` and ``cirq.Points`` whose key is not a sympy expression
    (encoded as single sweeps, including any parameter metadata), and
    ``cirq.ListSweep`` (encoded as a ZIP of one ``Points`` sweep per key).

    Args:
        sweep: The sweep to convert.
        out: Optional message to be populated. If not given, a new message will
            be created.

    Returns:
        Populated sweep protobuf message.

    Raises:
        ValueError: If the conversion cannot be completed successfully, e.g. for
            an unsupported sweep type, a sympy-expression key, or a ListSweep
            whose resolvers do not all assign the same parameters.
    """
    if out is None:
        out = run_context_pb2.Sweep()
    if sweep is cirq.UnitSweep:
        pass
    elif isinstance(sweep, cirq.Product):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
        for factor in sweep.factors:
            sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Zip):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for s in sweep.sweeps:
            sweep_to_proto(s, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Linspace) and not isinstance(sweep.key, sympy.Expr):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.linspace.first_point = sweep.start
        out.single_sweep.linspace.last_point = sweep.stop
        out.single_sweep.linspace.num_points = sweep.length
        _add_parameter_metadata(sweep.metadata, out)
    elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.points.points.extend(sweep.points)
        _add_parameter_metadata(sweep.metadata, out)
    elif isinstance(sweep, cirq.ListSweep):
        # Transpose the list of resolvers into one column of values per key.
        sweep_dict: Dict[str, List[float]] = {}
        for param_resolver in sweep:
            for key in param_resolver:
                sweep_dict.setdefault(cast(str, key), []).append(
                    cast(float, param_resolver.value_of(key))
                )
        # The columns are zipped back together, so every column needs exactly one
        # value per resolver; otherwise values from different resolvers would be
        # paired up silently.
        num_points = len(sweep)
        for key, values in sweep_dict.items():
            if len(values) != num_points:
                raise ValueError(
                    f'cannot convert to v2 Sweep proto: parameter {key!r} is set in '
                    f'{len(values)} of {num_points} resolvers of {sweep}'
                )
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for key, values in sweep_dict.items():
            sweep_to_proto(cirq.Points(key, values), out=out.sweep_function.sweeps.add())
    else:
        raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
    return out
def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
    """Creates a Sweep from a v2 protobuf message.

    Intended as the inverse of ``sweep_to_proto``.

    Args:
        msg: The v2 ``Sweep`` message to convert.

    Returns:
        ``cirq.UnitSweep`` if the message has no sweep set; a ``cirq.Product``
        or ``cirq.Zip`` for a sweep function; otherwise a ``cirq.Linspace`` or
        ``cirq.Points``, carrying a ``DeviceParameter`` as metadata when the
        message has a ``parameter`` field.

    Raises:
        ValueError: If the sweep function type is neither PRODUCT nor ZIP, or a
            single sweep has neither ``linspace`` nor ``points`` set.
    """
    which = msg.WhichOneof('sweep')
    if which is None:
        return cirq.UnitSweep
    if which == 'sweep_function':
        factors = [sweep_from_proto(m) for m in msg.sweep_function.sweeps]
        func_type = msg.sweep_function.function_type
        if func_type == run_context_pb2.SweepFunction.PRODUCT:
            return cirq.Product(*factors)
        if func_type == run_context_pb2.SweepFunction.ZIP:
            return cirq.Zip(*factors)
        raise ValueError(f'invalid sweep function type: {func_type}')
    if which == 'single_sweep':
        single = msg.single_sweep
        key = single.parameter_key
        metadata: Optional[DeviceParameter] = None
        if single.HasField("parameter"):
            parameter = single.parameter
            # Repeated protobuf fields are live containers tied to the message;
            # copy them into a plain list so the result does not alias `msg`.
            metadata = DeviceParameter(
                path=list(parameter.path),
                idx=parameter.idx if parameter.HasField("idx") else None,
                units=parameter.units if parameter.HasField("units") else None,
            )
        sweep_type = single.WhichOneof('sweep')
        if sweep_type == 'linspace':
            return cirq.Linspace(
                key=key,
                start=single.linspace.first_point,
                stop=single.linspace.last_point,
                length=single.linspace.num_points,
                metadata=metadata,
            )
        if sweep_type == 'points':
            # Copy for the same reason as `path` above.
            return cirq.Points(key=key, points=list(single.points.points), metadata=metadata)
        raise ValueError(f'single sweep type not set: {msg}')
    raise ValueError(f'sweep type not set: {msg}')  # pragma: no cover
def run_context_to_proto(
    sweepable: cirq.Sweepable, repetitions: int, *, out: Optional[run_context_pb2.RunContext] = None
) -> run_context_pb2.RunContext:
    """Builds or fills in a v2 RunContext message for a sweepable.

    Each sweep yielded by ``cirq.to_sweeps(sweepable)`` becomes one new entry
    appended to ``parameter_sweeps``. Every entry gets the same repetition
    count, and its ``sweep`` field is populated by ``sweep_to_proto``.

    Args:
        sweepable: The sweepable to include in the run context.
        repetitions: The number of repetitions recorded on every parameter sweep.
        out: Optional message to be populated. If not given, a new message will
            be created.

    Returns:
        Populated RunContext protobuf message (``out`` itself when provided).
    """
    context = run_context_pb2.RunContext() if out is None else out
    for one_sweep in cirq.to_sweeps(sweepable):
        entry = context.parameter_sweeps.add()
        entry.repetitions = repetitions
        sweep_to_proto(one_sweep, out=entry.sweep)
    return context