Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added __str__ and __repr__ for MappingManager and pushed name 'MappingMananger' to 'cirq' namespace #5828

Merged
Merged
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@
eject_z,
expand_composite,
is_negligible_turn,
MappingManager,
map_moments,
map_operations,
map_operations_and_unroll,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
# Transformers
'TransformerLogger',
'TransformerContext',
# Routing classes
'MappingManager',
# global objects
'CONTROL_TAG',
'PAULI_BASIS',
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
two_qubit_gate_product_tabulation,
)

from cirq.transformers.routing import MappingManager

from cirq.transformers.target_gatesets import (
create_transformer_with_kwargs,
CompilationTargetGateset,
Expand Down
31 changes: 28 additions & 3 deletions cirq-core/cirq/transformers/routing/mapping_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Manages the mapping from logical to physical qubits during a routing procedure."""

from typing import Dict, Sequence, TYPE_CHECKING
from cirq._compat import cached_method
import networkx as nx

from cirq import protocols
from cirq import protocols, value, _compat

if TYPE_CHECKING:
import cirq


@value.value_equality
class MappingManager:
"""Class that manages the mapping from logical to physical qubits.

Expand Down Expand Up @@ -130,6 +130,31 @@ def shortest_path(self, lq1: 'cirq.Qid', lq2: 'cirq.Qid') -> Sequence['cirq.Qid'
physical_shortest_path = self._physical_shortest_path(self._map[lq1], self._map[lq2])
return [self._inverse_map[pq] for pq in physical_shortest_path]

@cached_method
@_compat.cached_method
def _physical_shortest_path(self, pq1: 'cirq.Qid', pq2: 'cirq.Qid') -> Sequence['cirq.Qid']:
return nx.shortest_path(self._induced_subgraph, pq1, pq2)

def _value_equality_values_(self):
if nx.is_directed(self.device_graph):
graph_equality = (
tuple(sorted(self.device_graph.nodes)),
tuple(sorted(self.device_graph.edges)),
True,
)
else:
graph_equality = (
tuple(sorted(self.device_graph.nodes)),
tuple(sorted(tuple(sorted(edge)) for edge in self.device_graph.edges)),
False,
)
map_equality = tuple(sorted(self._map.items()))
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved
return (graph_equality, map_equality)

def __str__(self) -> str:
return self.__repr__()
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self) -> str:
graph_type = 'nx.DiGraph' if nx.is_directed(self.device_graph) else 'nx.Graph'
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved
return (
f'cirq.MappingManager({graph_type}({dict(self.device_graph.adjacency())}), {self._map})'
)
92 changes: 87 additions & 5 deletions cirq-core/cirq/transformers/routing/mapping_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def construct_device_graph_and_mapping():

def test_induced_subgraph():
device_graph, initial_mapping, _ = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

expected_induced_subgraph = nx.Graph(
[
Expand All @@ -55,7 +55,7 @@ def test_induced_subgraph():

def test_mapped_op():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

assert mm.mapped_op(cirq.CNOT(q[1], q[3])).qubits == (
cirq.NamedQubit("a"),
Expand All @@ -82,7 +82,7 @@ def test_mapped_op():

def test_distance_on_device_and_can_execute():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

# adjacent qubits have distance 1 and are thus executable
assert mm.dist_on_device(q[1], q[3]) == 1
Expand All @@ -108,7 +108,7 @@ def test_distance_on_device_and_can_execute():

def test_apply_swap():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

# swapping non-adjacent qubits raises error
with pytest.raises(ValueError):
Expand All @@ -130,7 +130,7 @@ def test_apply_swap():

def test_shortest_path():
device_graph, initial_mapping, q = construct_device_graph_and_mapping()
mm = cirq.transformers.routing.MappingManager(device_graph, initial_mapping)
mm = cirq.MappingManager(device_graph, initial_mapping)

one_to_four = [q[1], q[3], q[2], q[4]]
assert mm.shortest_path(q[1], q[2]) == one_to_four[:3]
Expand All @@ -143,3 +143,85 @@ def test_shortest_path():
one_to_four[1], one_to_four[2] = one_to_four[2], one_to_four[1]
assert mm.shortest_path(q[1], q[4]) == one_to_four
assert mm.shortest_path(q[1], q[2]) == [q[1], q[2]]


ammareltigani marked this conversation as resolved.
Show resolved Hide resolved
def test_value_equality():
equals_tester = cirq.testing.EqualsTester()
device_graph, initial_mapping, q = construct_device_graph_and_mapping()

mm = cirq.MappingManager(device_graph, initial_mapping)

# same as 'device_graph' but with different insertion order of edges
diff_edge_order = nx.Graph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
]
)
mm_edge_order = cirq.MappingManager(diff_edge_order, initial_mapping)
# print(mm._value_equality_values_())
# print(mm_edge_order._value_equality_values_())
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved
equals_tester.add_equality_group(mm, mm_edge_order)

# same as 'device_graph' but with directed edges (DiGraph)
device_digraph = nx.DiGraph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
]
)
mm_digraph = cirq.MappingManager(device_digraph, initial_mapping)
equals_tester.add_equality_group(mm_digraph)

# same as 'device_graph' but with an added isolated node
isolated_vertex_graph = nx.Graph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
]
)
isolated_vertex_graph.add_node(cirq.NamedQubit("z"))
mm = cirq.MappingManager(isolated_vertex_graph, initial_mapping)
equals_tester.add_equality_group(isolated_vertex_graph)

# mapping manager with same initial graph and initial mapping as 'mm' but with different
# current state
mm_with_swap = cirq.MappingManager(device_graph, initial_mapping)
mm_with_swap.apply_swap(q[1], q[3])
equals_tester.add_equality_group(mm_with_swap)


def test_repr():
device_graph, initial_mapping, _ = construct_device_graph_and_mapping()
mm = cirq.MappingManager(device_graph, initial_mapping)
cirq.testing.assert_equivalent_repr(mm, setup_code='import cirq\nimport networkx as nx')
ammareltigani marked this conversation as resolved.
Show resolved Hide resolved

device_digraph = nx.DiGraph(
[
(cirq.NamedQubit("a"), cirq.NamedQubit("b")),
(cirq.NamedQubit("b"), cirq.NamedQubit("c")),
(cirq.NamedQubit("c"), cirq.NamedQubit("d")),
(cirq.NamedQubit("a"), cirq.NamedQubit("e")),
(cirq.NamedQubit("e"), cirq.NamedQubit("d")),
]
)
mm_digraph = cirq.MappingManager(device_digraph, initial_mapping)
cirq.testing.assert_equivalent_repr(mm_digraph, setup_code='import cirq\nimport networkx as nx')


def test_str():
device_graph, initial_mapping, _ = construct_device_graph_and_mapping()
mm = cirq.MappingManager(device_graph, initial_mapping)
assert (
str(mm)
== f'cirq.MappingManager(nx.Graph({dict(device_graph.adjacency())}), {initial_mapping})'
)