-
Notifications
You must be signed in to change notification settings - Fork 1k
/
serializable_device.py
320 lines (269 loc) · 12.2 KB
/
serializable_device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Device object for converting from device specification protos"""
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
Optional,
List,
Set,
Tuple,
Type,
TYPE_CHECKING,
FrozenSet,
)
from cirq import circuits, devices
from cirq_google import serializable_gate_set
from cirq_google.api import v2
from cirq.value import Duration
if TYPE_CHECKING:
import cirq
class _GateDefinition:
    """Record describing one gate definition inside a SerializableDevice.

    Stores the gate's duration, the qubit tuples it may act on, how many
    qubits it takes, whether its target set is permutation-style, and a
    predicate that decides whether a particular operation matches this
    definition.
    """

    def __init__(
        self,
        duration: 'cirq.DURATION_LIKE',
        target_set: Set[Tuple['cirq.Qid', ...]],
        number_of_qubits: int,
        is_permutation: bool,
        can_serialize_predicate: Callable[['cirq.Operation'], bool] = lambda x: True,
    ):
        self.duration = Duration(duration)
        self.number_of_qubits = number_of_qubits
        self.target_set = target_set
        self.can_serialize_predicate = can_serialize_predicate
        self.is_permutation = is_permutation
        # Union of every qubit that appears in any allowed target tuple.
        every_qubit = set()
        for target_tuple in target_set:
            every_qubit.update(target_tuple)
        self.flattened_qubits = every_qubit

    def with_can_serialize_predicate(
        self, can_serialize_predicate: Callable[['cirq.Operation'], bool]
    ) -> '_GateDefinition':
        """Return a copy of this definition that uses a different predicate.

        Duration, target set, qubit count and the permutation flag are carried
        over unchanged; only `can_serialize_predicate` is replaced. This is
        useful when several definitions exist for the same gate type under
        different conditions, e.g. when certain angles of a gate take longer
        or are not allowed.

        Args:
            can_serialize_predicate: Predicate the returned copy should use.

        Returns:
            A new _GateDefinition.
        """
        return _GateDefinition(
            duration=self.duration,
            target_set=self.target_set,
            number_of_qubits=self.number_of_qubits,
            is_permutation=self.is_permutation,
            can_serialize_predicate=can_serialize_predicate,
        )

    def __eq__(self, other):
        # Definitions are equal when every instance attribute matches.
        if isinstance(other, self.__class__):
            return vars(self) == vars(other)
        return NotImplemented
class SerializableDevice(devices.Device):
    """Device object generated from a device specification proto.

    Given a device specification proto and gate sets that translate the
    serialized gate ids into cirq Gates, this generates a Device that can
    verify operations and circuits for the hardware described by the proto.

    The expected way to construct this class is from a proto, via the
    classmethod from_proto().

    This class only supports GridQubits and NamedQubits. NamedQubits whose
    names look like grid coordinates (such as "4_3") may be converted to
    GridQubits on deserialization.
    """

    def __init__(
        self,
        qubits: List['cirq.Qid'],
        gate_definitions: Dict[Type['cirq.Gate'], List[_GateDefinition]],
    ):
        """Constructor for SerializableDevice using python objects.

        Note that the preferred method of constructing this object is through
        the from_proto() classmethod.

        Args:
            qubits: A list of valid Qids for the device.
            gate_definitions: Maps cirq gate types to the list of device
                properties (_GateDefinitions) for that gate type.
        """
        self.qubits = qubits
        self.gate_definitions = gate_definitions

    def qubit_set(self) -> FrozenSet['cirq.Qid']:
        """Returns the device's qubits as a frozenset."""
        return frozenset(self.qubits)

    @classmethod
    def from_proto(
        cls,
        proto: v2.device_pb2.DeviceSpecification,
        gate_sets: Iterable[serializable_gate_set.SerializableGateSet],
    ) -> 'SerializableDevice':
        """Creates a SerializableDevice from a device specification proto.

        Args:
            proto: A proto describing the qubits on the device, as well as the
                supported gates and timing information.
            gate_sets: SerializableGateSets whose serializers translate the
                gate ids in `proto` into cirq Gates.

        Returns:
            A SerializableDevice with the proto's qubits and with gate
            definitions keyed by the gate types the serializers support.

        Raises:
            NotImplementedError: If a gate definition mixes SUBSET_PERMUTATION
                target sets with target sets of other orderings.
            ValueError: If a serializer uses a gate id that the device
                specification does not define.
            KeyError: If a gate lists a target set name that is not declared
                in `proto.valid_targets`.
        """
        # Store target sets, since gates refer to them by name later.
        allowed_targets: Dict[str, Set[Tuple['cirq.Qid', ...]]] = {}
        permutation_ids: Set[str] = set()
        for ts in proto.valid_targets:
            allowed_targets[ts.name] = cls._create_target_set(ts)
            if ts.target_ordering == v2.device_pb2.TargetSet.SUBSET_PERMUTATION:
                permutation_ids.add(ts.name)

        # Store gate definitions from the proto, keyed by gate id.
        gate_definitions: Dict[str, _GateDefinition] = {}
        for gs in proto.valid_gate_sets:
            for gate_def in gs.valid_gates:
                # Combine all valid targets in the gate's listed target sets.
                gate_target_set = {
                    target
                    for ts_name in gate_def.valid_targets
                    for target in allowed_targets[ts_name]
                }
                which_are_permutations = [t in permutation_ids for t in gate_def.valid_targets]
                is_permutation = any(which_are_permutations)
                if is_permutation and not all(which_are_permutations):
                    raise NotImplementedError(
                        f'Id {gate_def.id} in {gs.name} mixes '
                        'SUBSET_PERMUTATION with other types which is not '
                        'currently allowed.'
                    )
                gate_definitions[gate_def.id] = _GateDefinition(
                    duration=Duration(picos=gate_def.gate_duration_picos),
                    target_set=gate_target_set,
                    is_permutation=is_permutation,
                    number_of_qubits=gate_def.number_of_qubits,
                )

        # Loop through serializers and map gate_definitions to gate type.
        gates_by_type: Dict[Type['cirq.Gate'], List[_GateDefinition]] = {}
        for gate_set in gate_sets:
            for gate_type in gate_set.supported_gate_types():
                for serializer in gate_set.serializers[gate_type]:
                    gate_id = serializer.serialized_gate_id
                    if gate_id not in gate_definitions:
                        raise ValueError(
                            f'Serializer has {gate_id} which is not supported '
                            'by the device specification'
                        )
                    # Each serializer gets its own copy of the definition,
                    # carrying that serializer's predicate.
                    gate_def = gate_definitions[gate_id].with_can_serialize_predicate(
                        serializer.can_serialize_predicate
                    )
                    gates_by_type.setdefault(gate_type, []).append(gate_def)

        return SerializableDevice(
            qubits=[cls._qid_from_str(q) for q in proto.valid_qubits],
            gate_definitions=gates_by_type,
        )

    @staticmethod
    def _qid_from_str(id_str: str) -> 'cirq.Qid':
        """Translates a qubit id string into a cirq.Qid object.

        Tries to translate to a GridQubit if possible (e.g. '4_3'), otherwise
        falls back to a NamedQubit.
        """
        try:
            return v2.grid_qubit_from_proto_id(id_str)
        except ValueError:
            return v2.named_qubit_from_proto_id(id_str)

    @classmethod
    def _create_target_set(cls, ts: v2.device_pb2.TargetSet) -> Set[Tuple['cirq.Qid', ...]]:
        """Transforms a TargetSet proto into a set of qubit tuples.

        For SYMMETRIC target sets, the reversed tuple is added as well.
        """
        target_set = set()
        for target in ts.targets:
            qid_tuple = tuple(cls._qid_from_str(q) for q in target.ids)
            target_set.add(qid_tuple)
            if ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC:
                target_set.add(qid_tuple[::-1])
        return target_set

    def __str__(self) -> str:
        # If all qubits are grid qubits, render an appropriate text diagram.
        # An empty device has nothing to draw (and min() below would raise on
        # an empty sequence), so it falls through to the default string.
        if self.qubits and all(isinstance(q, devices.GridQubit) for q in self.qubits):
            diagram = circuits.TextDiagramDrawer()

            qubits = cast(List['cirq.GridQubit'], self.qubits)

            # Don't print out extra newlines if the row/col doesn't start at 0.
            min_col = min(q.col for q in qubits)
            min_row = min(q.row for q in qubits)

            for q in qubits:
                diagram.write(q.col - min_col, q.row - min_row, str(q))

            # Find pairs that are connected by two-qubit gates.
            Pair = Tuple['cirq.GridQubit', 'cirq.GridQubit']
            pairs = {
                cast(Pair, pair)
                for gate_defs in self.gate_definitions.values()
                for gate_def in gate_defs
                if gate_def.number_of_qubits == 2
                for pair in gate_def.target_set
                if len(pair) == 2
            }

            # Draw lines between connected pairs. Limit to horizontal/vertical
            # lines since that is all the diagram drawer can handle.
            for q1, q2 in sorted(pairs):
                if q1.row == q2.row or q1.col == q2.col:
                    diagram.grid_line(
                        q1.col - min_col, q1.row - min_row, q2.col - min_col, q2.row - min_row
                    )

            return diagram.render(
                horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True
            )

        return super().__str__()

    def _repr_pretty_(self, p: Any, cycle: bool) -> None:
        """Creates ASCII diagram for Jupyter, IPython, etc."""
        # There should never be a cycle, but just in case use the default repr.
        p.text(repr(self) if cycle else str(self))

    def _find_operation_type(self, op: 'cirq.Operation') -> Optional[_GateDefinition]:
        """Finds the gate definition matching an operation.

        Scans the gate types in insertion order; for the first type that
        `op.gate` is an instance of, returns the first definition whose
        can_serialize_predicate accepts `op`. Later matching types are
        checked if no definition of an earlier matching type accepts `op`.

        Returns:
            The matching _GateDefinition, or None if no type/predicate matches.
        """
        for type_key, gate_defs in self.gate_definitions.items():
            if isinstance(op.gate, type_key):
                for gate_def in gate_defs:
                    if gate_def.can_serialize_predicate(op):
                        return gate_def
        return None

    def duration_of(self, operation: 'cirq.Operation') -> Duration:
        """Returns the duration of `operation` from its gate definition.

        Raises:
            ValueError: If no gate definition matches the operation.
        """
        gate_def = self._find_operation_type(operation)
        if gate_def is None:
            raise ValueError(f'Operation {operation} does not have a known duration')
        return gate_def.duration

    def validate_operation(self, operation: 'cirq.Operation') -> None:
        """Checks that `operation` can run on this device.

        Raises:
            ValueError: If a qubit is not on the device, the gate is not
                supported, the qubit count is wrong, or the qubits are not a
                valid target for the gate. Single-qubit operations on a
                non-permutation gate are not checked against target sets.
        """
        for q in operation.qubits:
            if q not in self.qubits:
                raise ValueError(f'Qubit not on device: {q!r}')

        gate_def = self._find_operation_type(operation)
        if gate_def is None:
            raise ValueError(f'{operation} is not a supported gate')

        # A non-positive qubit count means the definition does not fix one.
        req_num_qubits = gate_def.number_of_qubits
        if req_num_qubits > 0:
            if len(operation.qubits) != req_num_qubits:
                raise ValueError(
                    f'{operation} has {len(operation.qubits)} '
                    f'qubits but expected {req_num_qubits}'
                )

        if gate_def.is_permutation:
            # A permutation gate can have any combination of qubits.
            if not gate_def.target_set:
                # All qubits are valid.
                return
            if not all(q in gate_def.flattened_qubits for q in operation.qubits):
                raise ValueError(f'Operation does not use valid qubits: {operation}.')
            return

        if len(operation.qubits) > 1:
            # TODO: verify args.
            # Github issue: https://github.com/quantumlib/Cirq/issues/2964

            if not gate_def.target_set:
                # All qubit combinations are valid.
                return

            qubit_tuple = tuple(operation.qubits)

            if qubit_tuple not in gate_def.target_set:
                # Target is not within the target sets specified by the gate.
                raise ValueError(f'Operation does not use valid qubit target: {operation}.')