/
circuit_dag.py
204 lines (159 loc) · 6.81 KB
/
circuit_dag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Generic, Iterator, TypeVar, cast, TYPE_CHECKING
import functools
import networkx
from cirq import ops
from cirq.circuits import circuit
if TYPE_CHECKING:
    # Only needed by type checkers: annotations such as 'cirq.Operation' are
    # written as strings, so nothing is imported at runtime.
    import cirq

# Type of the value held inside a Unique wrapper.
T = TypeVar('T')


@functools.total_ordering
class Unique(Generic[T]):
    """Wraps a value so that the wrapper compares equal only to itself.

    Wrapping equal values still yields distinct wrappers: ``5 == 5`` holds,
    yet ``Unique(5) != Unique(5)``. This class defines no ``__eq__``, so
    equality (and hashing) fall back to object identity.

    CircuitDag relies on this because networkx treats nodes that compare equal
    as the same node. Wrapping each operation keeps, for example, an ``X(q0)``
    in one moment and an ``X(q0)`` in another moment as two separate nodes.

    Ordering between wrappers of the same type is by ``id()``, which is an
    arbitrary order that stays fixed while both objects are alive.
    """

    def __init__(self, val: T) -> None:
        # The wrapped value; it plays no part in wrapper equality or ordering.
        self.val = val

    def __repr__(self) -> str:
        return 'cirq.contrib.Unique({}, {!r})'.format(id(self), self.val)

    def __lt__(self, other):
        # Only instances of this same wrapper type are orderable; anything
        # else defers to the other operand via NotImplemented.
        if isinstance(other, type(self)):
            return id(self) < id(other)
        return NotImplemented
def _disjoint_qubits(op1: 'cirq.Operation', op2: 'cirq.Operation') -> bool:
    """Returns True only if the operations have no qubits in common.

    This is the default ``can_reorder`` predicate used by CircuitDag: two
    operations that act on disjoint sets of qubits may be applied in either
    order.

    Args:
        op1: The first operation.
        op2: The second operation.

    Returns:
        True when ``op1.qubits`` and ``op2.qubits`` share no element,
        False otherwise.
    """
    return set(op1.qubits).isdisjoint(op2.qubits)
class CircuitDag(networkx.DiGraph):
    """A representation of a Circuit as a directed acyclic graph.

    Nodes of the graph are instances of Unique containing each operation of a
    circuit.

    Edges of the graph are tuples of nodes. Each edge specifies a required
    application order between two operations. The first must be applied before
    the second.

    The graph is maximalist (transitive completion): when an operation is
    appended, it gets an edge from every earlier operation it cannot be
    reordered with, and from every predecessor of each such operation.
    """

    disjoint_qubits = staticmethod(_disjoint_qubits)

    def __init__(
        self,
        can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
        incoming_graph_data: Any = None,
    ) -> None:
        """Initializes a CircuitDag.

        Args:
            can_reorder: A predicate that determines if two operations may be
                reordered. Graph edges are created for pairs of operations
                where this returns False.

                The default predicate allows reordering only when the operations
                don't share common qubits.
            incoming_graph_data: Data to initialize the graph. This can be any
                value supported by networkx.DiGraph() e.g. an edge list or
                another graph.
        """
        super().__init__(incoming_graph_data)
        self.can_reorder = can_reorder

    @staticmethod
    def make_node(op: 'cirq.Operation') -> Unique:
        """Wraps `op` in a fresh Unique so equal operations stay distinct nodes."""
        return Unique(op)

    @staticmethod
    def from_circuit(
        circuit: circuit.Circuit,
        can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
    ) -> 'CircuitDag':
        """Builds a CircuitDag from `circuit.all_operations()`, in that order.

        Args:
            circuit: The circuit whose operations become the graph's nodes.
            can_reorder: Predicate passed through to the new CircuitDag.
        """
        return CircuitDag.from_ops(circuit.all_operations(), can_reorder=can_reorder)

    @staticmethod
    def from_ops(
        *operations: 'cirq.OP_TREE',
        can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
    ) -> 'CircuitDag':
        """Builds a CircuitDag by appending each operation of the flattened OP_TREEs.

        Args:
            *operations: OP_TREEs whose operations are appended in flattened order.
            can_reorder: Predicate passed through to the new CircuitDag.
        """
        dag = CircuitDag(can_reorder=can_reorder)
        for op in ops.flatten_op_tree(operations):
            dag.append(cast(ops.Operation, op))
        return dag

    def append(self, op: 'cirq.Operation') -> None:
        """Adds `op` as a new node ordered after the operations it conflicts with.

        An edge is added to the new node from every existing node whose
        operation cannot be reordered with `op` (per `self.can_reorder`), and
        from each predecessor of such a node, so the graph stays transitively
        complete.
        """
        new_node = self.make_node(op)
        # Snapshot the node list: add_edge below inserts new_node into the graph.
        for node in list(self.nodes()):
            if not self.can_reorder(node.val, op):
                self.add_edge(node, new_node)
                for pred in self.pred[node]:
                    self.add_edge(pred, new_node)
        # Ensures the node exists even when no edge was added.
        self.add_node(new_node)

    def __eq__(self, other):
        """Graphs are equal when isomorphic with equal operations on matched nodes.

        Unique wrappers never compare equal to each other, so the wrapped
        values are copied into node attributes and compared via `node_match`.
        `can_reorder` is not part of the comparison.
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        g1 = self.copy()
        g2 = other.copy()
        for node, attr in g1.nodes(data=True):
            attr['val'] = node.val
        for node, attr in g2.nodes(data=True):
            attr['val'] = node.val

        def node_match(attr1: Dict[Any, Any], attr2: Dict[Any, Any]) -> bool:
            return attr1['val'] == attr2['val']

        return networkx.is_isomorphic(g1, g2, node_match=node_match)

    def __ne__(self, other):
        return not self == other

    __hash__ = None  # type: ignore

    def ordered_nodes(self) -> Iterator[Unique['cirq.Operation']]:
        """Yields every node exactly once, each after all of its predecessors.

        Works on a copy of the graph. After a node is yielded it is removed
        from the copy. The next node is found by starting from the first of
        the removed node's successors (or, if it had none, from an arbitrary
        remaining node) and following predecessor links up to a node that has
        no remaining predecessors.
        """
        if not self.nodes():
            return

        g = self.copy()

        def get_root_node(some_node: Unique['cirq.Operation']) -> Unique['cirq.Operation']:
            # Walk up predecessor links until reaching a node with none left.
            pred = g.pred
            while pred[some_node]:
                some_node = next(iter(pred[some_node]))
            return some_node

        def get_first_node() -> Unique['cirq.Operation']:
            return get_root_node(next(iter(g.nodes())))

        def get_next_node(succ: list) -> Unique['cirq.Operation']:
            if succ:
                return get_root_node(succ[0])
            return get_first_node()

        node = get_first_node()
        while True:
            yield node
            # Snapshot the successors before removing the node; otherwise we
            # would be reading a view into adjacency data that remove_node has
            # already detached from the graph.
            succ = list(g.succ[node])
            g.remove_node(node)
            if not g.nodes():
                return
            node = get_next_node(succ)

    def all_operations(self) -> Iterator['cirq.Operation']:
        """Yields the wrapped operations in `ordered_nodes` order."""
        return (node.val for node in self.ordered_nodes())

    def all_qubits(self) -> frozenset:
        """Returns the set of qubits acted on by any operation in the graph."""
        return frozenset(q for node in self.nodes for q in node.val.qubits)

    def to_circuit(self) -> circuit.Circuit:
        """Builds a Circuit from `all_operations()` using the EARLIEST strategy."""
        return circuit.Circuit(self.all_operations(), strategy=circuit.InsertStrategy.EARLIEST)

    def findall_nodes_until_blocked(
        self, is_blocker: Callable[['cirq.Operation'], bool]
    ) -> Iterator[Unique['cirq.Operation']]:
        """Finds all nodes before blocking ones.

        Nodes are visited in `ordered_nodes` order. A node whose operation
        satisfies `is_blocker` is not yielded, and neither are its successors.
        When the graph was built with `append`, the successors include all of
        its descendants, because append keeps the graph transitively complete.

        Args:
            is_blocker: The predicate that indicates whether or not an
                operation is blocking.

        Yields:
            The nodes that are not blocked.
        """
        remaining_dag = self.copy()
        for node in self.ordered_nodes():
            if node not in remaining_dag:
                # Already excluded as a successor of an earlier blocker.
                continue
            if is_blocker(node.val):
                successors = list(remaining_dag.succ[node])
                remaining_dag.remove_nodes_from(successors)
                remaining_dag.remove_node(node)
                continue
            yield node