Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
LineTopology,
TiltedSquareLattice,
get_placements,
is_valid_placement,
draw_placements,
)

Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
LineTopology,
TiltedSquareLattice,
get_placements,
is_valid_placement,
draw_placements,
)

Expand Down
59 changes: 57 additions & 2 deletions cirq-core/cirq/devices/named_topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@
import abc
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any, Sequence, Union, Iterable, TYPE_CHECKING
from typing import (
Dict,
List,
Tuple,
Any,
Sequence,
Union,
Iterable,
TYPE_CHECKING,
Callable,
Optional,
)

import networkx as nx
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -290,13 +301,45 @@ def get_placements(
return small_to_bigs


def _is_valid_placement_helper(
big_graph: nx.Graph, small_mapped: nx.Graph, small_to_big_mapping: Dict
):
"""Helper function for `is_valid_placement` that assumes the mapping of `small_graph` has
already occurred.

This is so we don't duplicate work when checking placements during `draw_placements`.
"""
subgraph = big_graph.subgraph(small_to_big_mapping.values())
return (subgraph.nodes == small_mapped.nodes) and (subgraph.edges == small_mapped.edges)


def is_valid_placement(big_graph: nx.Graph, small_graph: nx.Graph, small_to_big_mapping: Dict):
"""Return whether the given placement is a valid placement of small_graph onto big_graph.

This is done by making sure all the nodes and edges on the mapped version of `small_graph`
are present in `big_graph`.

Args:
big_graph: A larger graph we're placing `small_graph` onto.
small_graph: A smaller, (potential) sub-graph to validate the given mapping.
small_to_big_mapping: A mapping from `small_graph` nodes to `big_graph`
nodes. After the mapping occurs, we check whether all of the mapped nodes and
edges exist on `big_graph`.
"""
small_mapped = nx.relabel_nodes(small_graph, small_to_big_mapping)
return _is_valid_placement_helper(
big_graph=big_graph, small_mapped=small_mapped, small_to_big_mapping=small_to_big_mapping
)


def draw_placements(
big_graph: nx.Graph,
small_graph: nx.Graph,
small_to_big_mappings: Sequence[Dict],
max_plots: int = 20,
axes: Sequence[plt.Axes] = None,
tilted=True,
tilted: bool = True,
bad_placement_callback: Optional[Callable[[plt.Axes, int], None]] = None,
):
"""Draw a visualization of placements from small_graph onto big_graph using Matplotlib.

Expand All @@ -312,6 +355,9 @@ def draw_placements(
`max_plots` plots.
axes: Optional list of matplotlib Axes to contain the drawings.
tilted: Whether to draw gridlike graphs in the ordinary cartesian or tilted plane.
bad_placement_callback: If provided, we check that the given mappings are valid. If not,
this callback is called. The callback should accept `ax` and `i` keyword arguments
for the current axis and mapping index, respectively.
"""
if len(small_to_big_mappings) > max_plots:
# coverage: ignore
Expand All @@ -331,6 +377,15 @@ def draw_placements(
ax = plt.gca()

small_mapped = nx.relabel_nodes(small_graph, small_to_big_map)
if bad_placement_callback is not None:
# coverage: ignore
if not _is_valid_placement_helper(
big_graph=big_graph,
small_mapped=small_mapped,
small_to_big_mapping=small_to_big_map,
):
bad_placement_callback(ax, i)

draw_gridlike(big_graph, ax=ax, tilted=tilted)
draw_gridlike(
small_mapped,
Expand Down
20 changes: 19 additions & 1 deletion cirq-core/cirq/devices/named_topologies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
import cirq
import networkx as nx
import pytest
from cirq import draw_gridlike, LineTopology, TiltedSquareLattice, get_placements, draw_placements
from cirq import (
draw_gridlike,
LineTopology,
TiltedSquareLattice,
get_placements,
draw_placements,
is_valid_placement,
)


@pytest.mark.parametrize('width, height', list(itertools.product([1, 2, 3, 24], repeat=2)))
Expand Down Expand Up @@ -119,3 +126,14 @@ def test_get_placements():
draw_placements(syc23, topo.graph, placements[::3], axes=axes)
for ax in axes:
ax.scatter.assert_called()


def test_is_valid_placement():
topo = TiltedSquareLattice(4, 2)
syc23 = TiltedSquareLattice(8, 4).graph
placements = get_placements(syc23, topo.graph)
for placement in placements:
assert is_valid_placement(syc23, topo.graph, placement)

bad_placement = topo.nodes_to_gridqubits(offset=(100, 100))
assert not is_valid_placement(syc23, topo.graph, bad_placement)
1 change: 1 addition & 0 deletions cirq-google/cirq_google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
CouldNotPlaceError,
NaiveQubitPlacer,
RandomDevicePlacer,
HardcodedQubitPlacer,
ProcessorRecord,
EngineProcessorRecord,
SimulatedProcessorRecord,
Expand Down
1 change: 1 addition & 0 deletions cirq-google/cirq_google/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,6 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]:
'cirq.google.SimulatedProcessorRecord': cirq_google.SimulatedProcessorRecord,
# pylint: disable=line-too-long
'cirq.google.SimulatedProcessorWithLocalDeviceRecord': cirq_google.SimulatedProcessorWithLocalDeviceRecord,
'cirq.google.HardcodedQubitPlacer': cirq_google.HardcodedQubitPlacer,
# pylint: enable=line-too-long
}
Loading