Skip to content

Commit

Permalink
Merge pull request #31 from mtreinish/add-retworkx-support
Browse files Browse the repository at this point in the history
Add support for using retworkx as input and output from Matching
  • Loading branch information
oscarhiggott committed May 10, 2022
2 parents b64cca6 + 73bfb6f commit 2b67ffe
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 5 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def build_extension(self, ext):
packages=find_packages("src"),
package_dir={'':'src'},
cmdclass=dict(build_ext=CMakeBuild),
install_requires=['scipy', 'numpy', 'networkx','matplotlib'],
install_requires=['scipy', 'numpy', 'networkx','retworkx>=0.11.0','matplotlib'],
classifiers=[
"License :: OSI Approved :: Apache Software License"
],
long_description=long_description,
long_description_content_type='text/markdown',
python_requires='>=3',
zip_safe=False,
)
)
99 changes: 96 additions & 3 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import matplotlib.cbook
import numpy as np
import networkx as nx
import retworkx as rx
import scipy
from scipy.sparse import csc_matrix

Expand Down Expand Up @@ -60,7 +61,7 @@ class Matching:
fault ids, boundaries and error probabilities.
"""
def __init__(self,
H: Union[scipy.sparse.spmatrix, np.ndarray, nx.Graph, List[List[int]]] = None,
H: Union[scipy.sparse.spmatrix, np.ndarray, rx.PyGraph, nx.Graph, List[List[int]]] = None,
spacelike_weights: Union[float, np.ndarray, List[float]] = None,
error_probabilities: Union[float, np.ndarray, List[float]] = None,
repetitions: int = None,
Expand Down Expand Up @@ -156,7 +157,7 @@ def __init__(self,
self.matching_graph = MatchingGraph()
if H is None:
return
if not isinstance(H, nx.Graph):
if not isinstance(H, (nx.Graph, rx.PyGraph)):
try:
H = csc_matrix(H)
except TypeError:
Expand All @@ -165,8 +166,10 @@ def __init__(self,
self.load_from_check_matrix(H, spacelike_weights, error_probabilities,
repetitions, timelike_weights, measurement_error_probabilities,
**kwargs)
else:
elif isinstance(H, nx.Graph):
self.load_from_networkx(H)
else:
self.load_from_retworkx(H)
if precompute_shortest_paths:
self.matching_graph.compute_all_pairs_shortest_paths()

Expand Down Expand Up @@ -310,6 +313,74 @@ def load_from_networkx(self, graph: nx.Graph) -> None:
g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)
self.matching_graph = g

def load_from_retworkx(self, graph: rx.PyGraph) -> None:
r"""
Load a matching graph from a retworkX graph
Parameters
----------
graph : retworkx.PyGraph
Each edge in the retworkx graph can have dictionary payload with keys
``fault_ids``, ``weight`` and ``error_probability``. ``fault_ids`` should be
an int or a set of ints. Each fault id corresponds to a self-inverse fault
that is flipped when the corresponding edge is flipped. These self-inverse
faults could correspond to physical Pauli errors (physical frame changes)
or to the logical observables that are flipped by the fault
(a logical frame change, equivalent to an obersvable ID in an error instruction in a Stim
detector error model). The `fault_ids` attribute was previously named `qubit_id` in an
earlier version of PyMatching, and `qubit_id` is still accepted instead of `fault_ids` in order
to maintain backward compatibility.
Each ``weight`` attribute should be a non-negative float. If
every edge is assigned an error_probability between zero and one,
then the ``add_noise`` method can be used to simulate noise and
flip edges independently in the graph.
Examples
--------
>>> import pymatching
>>> import retworkx as rx
>>> import math
>>> g = rx.PyGraph()
>>> matching = g.add_nodes_from([{} for _ in range(3)])
>>> edge_a =g.add_edge(0, 1, dict(fault_ids=0, weight=math.log((1-0.1)/0.1), error_probability=0.1))
>>> edge_b = g.add_edge(1, 2, dict(fault_ids=1, weight=math.log((1-0.15)/0.15), error_probability=0.15))
>>> g[0]['is_boundary'] = True
>>> g[2]['is_boundary'] = True
>>> m = pymatching.Matching(g)
>>> m
<pymatching.Matching object with 1 detector, 2 boundary nodes, and 2 edges>
"""
if not isinstance(graph, rx.PyGraph):
raise TypeError("G must be a retworkx graph")
boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)}
num_nodes = len(graph)
g = MatchingGraph(self.num_detectors, boundary)
for (u, v, attr) in graph.weighted_edge_list():
u, v = int(u), int(v)
if "fault_ids" in attr and "qubit_id" in attr:
raise ValueError("Both `fault_ids` and `qubit_id` were provided as edge attributes, however use "
"of `qubit_id` has been deprecated in favour of `fault_ids`. Please only supply "
"`fault_ids` as an edge attribute.")
if "fault_ids" not in attr and "qubit_id" in attr:
fault_ids = attr["qubit_id"] # Still accept qubit_id as well for now
else:
fault_ids = attr.get("fault_ids", set())
if isinstance(fault_ids, (int, np.integer)):
fault_ids = {int(fault_ids)} if fault_ids != -1 else set()
else:
try:
fault_ids = set(fault_ids)
if not all(isinstance(q, (int, np.integer)) for q in fault_ids):
raise ValueError("fault_ids must be a set of ints, not {}".format(fault_ids))
except:
raise ValueError(
"fault_ids property must be an int or a set of int"\
" (or convertible to a set), not {}".format(fault_ids))
weight = attr.get("weight", 1) # Default weight is 1 if not provided
e_prob = attr.get("error_probability", -1)
g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)
self.matching_graph = g

def load_from_check_matrix(self,
H: Union[scipy.sparse.spmatrix, np.ndarray, List[List[int]]],
spacelike_weights: Union[float, np.ndarray, List[float]] = None,
Expand Down Expand Up @@ -748,6 +819,28 @@ def to_networkx(self) -> nx.Graph:
is_boundary = i in boundary
G.nodes[i]['is_boundary'] = is_boundary
return G

def to_retworkx(self) -> rx.PyGraph:
"""Convert to retworkx graph
Returns a retworkx graph object corresponding to the matching graph. Each edge
payload is a ``dict`` with keys `fault_ids`, `weight` and `error_probability` and
each node has a ``dict`` payload with the key ``is_boundary`` and the value is
a boolean.
Returns
-------
retworkx.PyGraph
retworkx graph corresponding to the matching graph
"""
G = rx.PyGraph(multigraph=False)
G.add_nodes_from([{} for _ in range(self.num_nodes)])
G.extend_from_weighted_edge_list(self.edges())
boundary = self.boundary
for i in G.node_indices():
is_boundary = i in boundary
G[i]['is_boundary'] = is_boundary
return G

def draw(self) -> None:
"""Draw the matching graph using matplotlib
Expand Down
172 changes: 172 additions & 0 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from scipy.sparse import csc_matrix, csr_matrix
import pytest
import networkx as nx
import retworkx as rx
import matplotlib.pyplot as plt

from pymatching._cpp_mwpm import MatchingGraph
Expand Down Expand Up @@ -60,6 +61,21 @@ def test_boundary_from_networkx():
assert np.array_equal(m.decode(np.array([0,1,1,0])), np.array([0,0,1,0,0]))
assert np.array_equal(m.decode(np.array([0,0,1,0])), np.array([0,0,0,1,1]))

def test_boundary_from_retworkx():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(5)])
g.add_edge(4,0, dict(fault_ids=0))
g.add_edge(0,1, dict(fault_ids=1))
g.add_edge(1,2, dict(fault_ids=2))
g.add_edge(2,3, dict(fault_ids=3))
g.add_edge(3,4, dict(fault_ids=4))
g[4]['is_boundary'] = True
m = Matching(g)
assert m.boundary == {4}
assert np.array_equal(m.decode(np.array([1,0,0,0])), np.array([1,0,0,0,0]))
assert np.array_equal(m.decode(np.array([0,1,0,0])), np.array([1,1,0,0,0]))
assert np.array_equal(m.decode(np.array([0,1,1,0])), np.array([0,0,1,0,0]))
assert np.array_equal(m.decode(np.array([0,0,1,0])), np.array([0,0,0,1,1]))

def test_boundaries_from_networkx():
g = nx.Graph()
Expand All @@ -78,6 +94,23 @@ def test_boundaries_from_networkx():
assert np.array_equal(m.decode(np.array([0,0,1,1,0])), np.array([0,0,1,0,0]))
assert np.array_equal(m.decode(np.array([0,0,0,1,0])), np.array([0,0,0,1,1]))

def test_boundaries_from_retworkx():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(6)])
g.add_edge(0,1, dict(fault_ids=0))
g.add_edge(1,2, dict(fault_ids=1))
g.add_edge(2,3, dict(fault_ids=2))
g.add_edge(3,4, dict(fault_ids=3))
g.add_edge(4,5, dict(fault_ids=4))
g.add_edge(0,5, dict(fault_ids=-1, weight=0.0))
g.nodes()[0]['is_boundary'] = True
g.nodes()[5]['is_boundary'] = True
m = Matching(g)
assert m.boundary == {0, 5}
assert np.array_equal(m.decode(np.array([0,1,0,0,0,0])), np.array([1,0,0,0,0]))
assert np.array_equal(m.decode(np.array([0,0,1,0,0])), np.array([1,1,0,0,0]))
assert np.array_equal(m.decode(np.array([0,0,1,1,0])), np.array([0,0,1,0,0]))
assert np.array_equal(m.decode(np.array([0,0,0,1,0])), np.array([0,0,0,1,1]))

def test_nonzero_matrix_elements_not_one_raises_value_error():
H = csr_matrix(np.array([[0,1.01,1.01],[1.01,1.01,0]]))
Expand Down Expand Up @@ -200,6 +233,79 @@ def test_mwpm_from_networkx():
assert(m.matching_graph.shortest_path(0,2) == [0,2])


def test_unweighted_stabiliser_graph_from_retworkx():
w = rx.PyGraph()
w.add_nodes_from([{} for _ in range(6)])
w.add_edge(0, 1, dict(fault_ids=0, weight=7.0))
w.add_edge(0, 5, dict(fault_ids=1, weight=14.0))
w.add_edge(0, 2, dict(fault_ids=2, weight=9.0))
w.add_edge(1, 2, dict(fault_ids=-1, weight=10.0))
w.add_edge(1, 3, dict(fault_ids=3, weight=15.0))
w.add_edge(2, 5, dict(fault_ids=4, weight=2.0))
w.add_edge(2, 3, dict(fault_ids=-1, weight=11.0))
w.add_edge(3, 4, dict(fault_ids=5, weight=6.0))
w.add_edge(4, 5, dict(fault_ids=6, weight=9.0))
m = Matching(w)
assert(m.num_fault_ids == 7)
assert(m.num_detectors == 6)
assert(m.matching_graph.shortest_path(3, 5) == [3, 2, 5])
assert(m.matching_graph.distance(5, 0) == pytest.approx(11.0))
assert(np.array_equal(
m.decode(np.array([1,0,1,0,0,0])),
np.array([0,0,1,0,0,0,0]))
)
with pytest.raises(ValueError):
m.decode(np.array([1,1,0]))
with pytest.raises(ValueError):
m.decode(np.array([1,1,1,0,0,0]))
assert(np.array_equal(
m.decode(np.array([1,0,0,0,0,1])),
np.array([0,0,1,0,1,0,0]))
)
assert(np.array_equal(
m.decode(np.array([0,1,0,0,0,1])),
np.array([0,0,0,0,1,0,0]))
)


def test_mwpm_from_retworkx():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
g.add_edge(0, 1, dict(fault_ids=0))
g.add_edge(0, 2, dict(fault_ids=1))
g.add_edge(1, 2, dict(fault_ids=2))
m = Matching(g)
assert(isinstance(m.matching_graph, MatchingGraph))
assert(m.num_detectors == 3)
assert(m.num_fault_ids == 3)
assert(m.matching_graph.distance(0,2) == 1)
assert(m.matching_graph.shortest_path(0,2) == [0,2])

g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
g.add_edge(0, 1, {})
g.add_edge(0, 2, {})
g.add_edge(1, 2, {})
m = Matching(g)
assert(isinstance(m.matching_graph, MatchingGraph))
assert(m.num_detectors == 3)
assert(m.num_fault_ids == 0)
assert(m.matching_graph.distance(0,2) == 1)
assert(m.matching_graph.shortest_path(0,2) == [0,2])

g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(3)])
g.add_edge(0, 1, dict(weight=1.5))
g.add_edge(0, 2, dict(weight=1.7))
g.add_edge(1, 2, dict(weight=1.2))
m = Matching(g)
assert(isinstance(m.matching_graph, MatchingGraph))
assert(m.num_detectors == 3)
assert(m.num_fault_ids == 0)
assert(m.matching_graph.distance(0,2) == pytest.approx(1.7))
assert(m.matching_graph.shortest_path(0,2) == [0,2])


def test_repr():
g = nx.Graph()
g.add_edge(0, 1, fault_ids=0)
Expand Down Expand Up @@ -253,6 +359,47 @@ def test_qubit_id_accepted_via_networkx():
assert es == expected_edges


def test_matching_edges_from_retworkx():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1))
g.add_edge(1, 2, dict(fault_ids=1, weight=2.1, error_probability=0.2))
g.add_edge(2, 3, dict(fault_ids={2,3}, weight=0.9, error_probability=0.3))
g[0]['is_boundary'] = True
g[3]['is_boundary'] = True
g.add_edge(0, 3, dict(weight=0.0))
m = Matching(g)
es = list(m.edges())
expected_edges = [
(0,1,{'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}),
(1,2,{'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}),
(2,3,{'fault_ids': {2,3}, 'weight': 0.9, 'error_probability': 0.3}),
(0,3,{'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}),
]
print(es)
assert es == expected_edges


def test_qubit_id_accepted_via_retworkx():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1))
g.add_edge(1, 2, dict(qubit_id=1, weight=2.1, error_probability=0.2))
g.add_edge(2, 3, dict(qubit_id={2, 3}, weight=0.9, error_probability=0.3))
g[0]['is_boundary'] = True
g[3]['is_boundary'] = True
g.add_edge(0, 3, dict(weight=0.0))
m = Matching(g)
es = list(m.edges())
expected_edges = [
(0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}),
(1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}),
(2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}),
(0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}),
]
assert es == expected_edges


def test_qubit_id_accepted_using_add_edge():
m = Matching()
m.add_edge(0, 1, qubit_id=0)
Expand Down Expand Up @@ -304,6 +451,31 @@ def test_matching_to_networkx():
assert sorted(gedges) == sorted(g2edges)


def test_matching_to_retworkx():
g = rx.PyGraph()
g.add_nodes_from([{} for _ in range(4)])
g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1))
g.add_edge(1, 2, dict(fault_ids={1}, weight=2.1, error_probability=0.2))
g.add_edge(2, 3, dict(fault_ids={2,3}, weight=0.9, error_probability=0.3))
g[0]['is_boundary'] = True
g[3]['is_boundary'] = True
g.add_edge(0, 3, dict(weight=0.0))
m = Matching(g)

edge_0_3 = g.get_edge_data(0, 3)
edge_0_3['fault_ids'] = set()
edge_0_3['error_probability'] = -1.0
g[1]['is_boundary'] = False
g[2]['is_boundary'] = False

g2 = m.to_retworkx()

assert g.node_indices() == g2.node_indices()
gedges = [({s,t},d) for (s, t, d) in g.weighted_edge_list()]
g2edges = [({s,t},d) for (s, t, d) in g.weighted_edge_list()]
assert sorted(gedges) == sorted(g2edges)


def test_draw_matching():
g = nx.Graph()
g.add_edge(0, 1, fault_ids={0}, weight=1.1, error_probability=0.1)
Expand Down

0 comments on commit 2b67ffe

Please sign in to comment.