Skip to content

Commit

Permalink
Added __str__ and __repr__ for MappingManager and pushed name 'Mappin…
Browse files Browse the repository at this point in the history
…gMananger' to 'cirq' namespace (#5828)

* added __str__ and __repr__ for MappingManager

* made MappingManager not serializable

* removed unused import

* addressed comments

* fixed bug with edges not being sorted for graph equality testing

* fixed bug with digraphs repr method in MappingManager and added test for it

* made MappingManager serializable

* removed print statements

* ready for merging

* nit

* fix lint

* removed serialization

* removed unused imports

* fixed nit

* removed debug print
  • Loading branch information
ammareltigani committed Aug 19, 2022
1 parent 4a64b1e commit 685bddd
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 9 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@
eject_z,
expand_composite,
is_negligible_turn,
MappingManager,
map_moments,
map_operations,
map_operations_and_unroll,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
# Transformers
'TransformerLogger',
'TransformerContext',
# Routing utilities
'MappingManager',
# global objects
'CONTROL_TAG',
'PAULI_BASIS',
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
two_qubit_gate_product_tabulation,
)

from cirq.transformers.routing import MappingManager

from cirq.transformers.target_gatesets import (
create_transformer_with_kwargs,
CompilationTargetGateset,
Expand Down
38 changes: 34 additions & 4 deletions cirq-core/cirq/transformers/routing/mapping_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Manages the mapping from logical to physical qubits during a routing procedure."""

from typing import Dict, Sequence, TYPE_CHECKING
from cirq._compat import cached_method
import networkx as nx

from cirq import protocols
from cirq import protocols, value, _compat

if TYPE_CHECKING:
import cirq


@value.value_equality
class MappingManager:
"""Class that manages the mapping from logical to physical qubits.
Expand All @@ -36,11 +36,24 @@ def __init__(
) -> None:
"""Initializes MappingManager.
Sorts the nodes and edges in the device graph to guarantee graph equality. If undirected,
also sorts the nodes within each edge.
Args:
device_graph: connectivity graph of qubits in the hardware device.
initial_mapping: the initial mapping of logical (keys) to physical qubits (values).
"""
self.device_graph = device_graph
if nx.is_directed(device_graph):
self.device_graph = nx.DiGraph()
self.device_graph.add_nodes_from(sorted(list(device_graph.nodes(data=True))))
self.device_graph.add_edges_from(sorted(list(device_graph.edges)))
else:
self.device_graph = nx.Graph()
self.device_graph.add_nodes_from(sorted(list(device_graph.nodes(data=True))))
self.device_graph.add_edges_from(
sorted(list(sorted(edge) for edge in device_graph.edges))
)

self._map = initial_mapping.copy()
self._inverse_map = {v: k for k, v in self._map.items()}
self._induced_subgraph = nx.induced_subgraph(self.device_graph, self._map.values())
Expand Down Expand Up @@ -130,6 +143,23 @@ def shortest_path(self, lq1: 'cirq.Qid', lq2: 'cirq.Qid') -> Sequence['cirq.Qid'
physical_shortest_path = self._physical_shortest_path(self._map[lq1], self._map[lq2])
return [self._inverse_map[pq] for pq in physical_shortest_path]

@cached_method
@_compat.cached_method
def _physical_shortest_path(self, pq1: 'cirq.Qid', pq2: 'cirq.Qid') -> Sequence['cirq.Qid']:
return nx.shortest_path(self._induced_subgraph, pq1, pq2)

def _value_equality_values_(self):
graph_equality = (
tuple(self.device_graph.nodes),
tuple(self.device_graph.edges),
nx.is_directed(self.device_graph),
)
map_equality = tuple(sorted(self._map.items()))
return (graph_equality, map_equality)

def __repr__(self) -> str:
graph_type = type(self.device_graph).__name__
return (
f'cirq.MappingManager('
f'nx.{graph_type}({dict(self.device_graph.adjacency())}),'
f' {self._map})'
)
90 changes: 85 additions & 5 deletions cirq-core/cirq/transformers/routing/mapping_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def construct_device_graph_and_mapping():

def test_induced_subgraph():
device_graph, initial_mapping, _ = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

expected_induced_subgraph = nx.Graph(
[
Expand All @@ -55,7 +55,7 @@ def test_induced_subgraph():

def test_mapped_op():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

assert mm.mapped_op(cirq.CNOT(q[1], q[3])).qubits == (
cirq.NamedQubit("a"),
Expand All @@ -82,7 +82,7 @@ def test_mapped_op():

def test_distance_on_device_and_can_execute():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

# adjacent qubits have distance 1 and are thus executable
assert mm.dist_on_device(q[1], q[3]) == 1
Expand All @@ -108,7 +108,7 @@ def test_distance_on_device_and_can_execute():

def test_apply_swap():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

# swapping non-adjacent qubits raises error
with pytest.raises(ValueError):
Expand All @@ -130,7 +130,7 @@ def test_apply_swap():

def test_shortest_path():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

one_to_four = [q[1], q[3], q[2], q[4]]
assert mm.shortest_path(q[1], q[2]) == one_to_four[:3]
Expand All @@ -143,3 +143,83 @@ def test_shortest_path():
one_to_four[1], one_to_four[2] = one_to_four[2], one_to_four[1]
assert mm.shortest_path(q[1], q[4]) == one_to_four
assert mm.shortest_path(q[1], q[2]) == [q[1], q[2]]


def test_value_equality():
equals_tester = cirq.testing.EqualsTester()
device_graph, initial_mapping, q = construct_device_graph_and_mapping()

mm = cirq.MappingManager(device_graph, initial_mapping)

# same as 'device_graph' but with different insertion order of edges
diff_edge_order = nx.Graph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
]
)
mm_edge_order = cirq.MappingManager(diff_edge_order, initial_mapping)
equals_tester.add_equality_group(mm, mm_edge_order)

# same as 'device_graph' but with directed edges (DiGraph)
device_digraph = nx.DiGraph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
]
)
mm_digraph = cirq.MappingManager(device_digraph, initial_mapping)
equals_tester.add_equality_group(mm_digraph)

# same as 'device_graph' but with an added isolated node
isolated_vertex_graph = nx.Graph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
]
)
isolated_vertex_graph.add_node(cirq.NamedQubit("z"))
mm = cirq.MappingManager(isolated_vertex_graph, initial_mapping)
equals_tester.add_equality_group(isolated_vertex_graph)

# mapping manager with same initial graph and initial mapping as 'mm' but with different
# current state
mm_with_swap = cirq.MappingManager(device_graph, initial_mapping)
mm_with_swap.apply_swap(q[1], q[3])
equals_tester.add_equality_group(mm_with_swap)


def test_repr():
device_graph, initial_mapping, _ = construct_device_graph_and_mapping()
mm = cirq.MappingManager(device_graph, initial_mapping)
cirq.testing.assert_equivalent_repr(mm, setup_code='import cirq\nimport networkx as nx')

device_digraph = nx.DiGraph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
]
)
mm_digraph = cirq.MappingManager(device_digraph, initial_mapping)
cirq.testing.assert_equivalent_repr(mm_digraph, setup_code='import cirq\nimport networkx as nx')


def test_str():
device_graph, initial_mapping, _ = construct_device_graph_and_mapping()
mm = cirq.MappingManager(device_graph, initial_mapping)
assert (
str(mm)
== f'cirq.MappingManager(nx.Graph({dict(device_graph.adjacency())}), {initial_mapping})'
)

0 comments on commit 685bddd

Please sign in to comment.