Skip to content

Commit

Permalink
Refactor local and exact matching
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarhiggott committed Jul 25, 2021
1 parent 96a1787 commit 3b11b48
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 113 deletions.
8 changes: 4 additions & 4 deletions src/pymatching/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ PYBIND11_MODULE(_cpp_mwpm, m) {

py::register_exception<BlossomFailureException>(m, "BlossomFailureException", PyExc_RuntimeError);

m.def("decode_match_neighbourhood", &LemonDecodeMatchNeighbourhood,
"sg"_a, "defects"_a, "num_neighbours"_a=20,
"return_weight"_a=false);
m.def("decode", &LemonDecode, "sg"_a, "defects"_a,
m.def("local_matching", &LocalMatching,
"sg"_a, "defects"_a, "num_neighbours"_a=30,
"return_weight"_a=false, "max_attempts"_a=10);
m.def("exact_matching", &LemonDecode, "sg"_a, "defects"_a,
"return_weight"_a=false);
}
51 changes: 44 additions & 7 deletions src/pymatching/lemon_mwpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,58 @@ MatchingResult LemonDecode(
}


MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const py::array_t<int>& defects,
int num_neighbours, bool return_weight){
MatchingResult matching_result;
MatchingResult LocalMatching(
WeightedStabiliserGraph& sg,
const py::array_t<int>& defects,
int num_neighbours,
bool return_weight,
int max_attempts
){
if (num_neighbours <= 0){
throw std::invalid_argument("num_neighbours must be greater than zero");
}
auto d = defects.unchecked<1>();
std::set<int> defects_set;
for (int i=0; i<d.shape(0); i++) {
defects_set.insert(d(i));
}
int num_attempts = 0;
while (true) {
try{
return LemonDecodeMatchNeighbourhood(
sg,
defects_set,
num_neighbours,
return_weight
);
} catch (BlossomFailureException& e) {
num_attempts++;
if (num_neighbours >= defects_set.size() || num_attempts >= max_attempts){
throw;
} else {
num_neighbours *= 2;
}
}
}
}


MatchingResult LemonDecodeMatchNeighbourhood(
WeightedStabiliserGraph& sg,
std::set<int>& defects_set,
int num_neighbours,
bool return_weight
){
MatchingResult matching_result;

int num_nodes = sg.GetNumNodes();

std::set<int> defects_set;
for (int i=0; i<d.shape(0); i++) {
if (d(i) >= num_nodes){
for (auto d : defects_set){
if (d >= num_nodes){
throw std::invalid_argument(
"Defect id must be less than the number of nodes in the matching graph"
);
}
defects_set.insert(d(i));
}

sg.FlipBoundaryNodesIfNeeded(defects_set);
Expand Down
12 changes: 10 additions & 2 deletions src/pymatching/lemon_mwpm.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,13 @@ MatchingResult LemonDecode(WeightedStabiliserGraph& sg, const py::array_t<int>&
* @return MatchingResult A struct containing the correction vector for the minimum-weight perfect matching and the matching weight.
* The matching weight is set to -1 if it is not requested.
*/
MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const py::array_t<int>& defects,
int num_neighbours, bool return_weight=false);
MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, std::set<int>& defects,
int num_neighbours=30, bool return_weight=false);

MatchingResult LocalMatching(
WeightedStabiliserGraph& sg,
const py::array_t<int>& defects,
int num_neighbours=30,
bool return_weight=false,
int max_attempts=10
);
57 changes: 4 additions & 53 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@
import networkx as nx
from scipy.sparse import csc_matrix

from pymatching._cpp_mwpm import (decode,
decode_match_neighbourhood,
WeightedStabiliserGraph,
BlossomFailureException)
# alias to let unittest mock decode_match_neighbourhood
_py_decode_match_neighbourhood = decode_match_neighbourhood
from pymatching._cpp_mwpm import (exact_matching, local_matching,
WeightedStabiliserGraph)


def _find_boundary_nodes(G):
Expand All @@ -48,51 +44,6 @@ def _find_boundary_nodes(G):
if attr.get("is_boundary", False)}


def _local_matching(stabiliser_graph, defects, num_neighbours, return_weight=False):
"""Local matching decoder
Find the local matching in `stabiliser_graph` with a syndrome defined by -1 measurements
at nodes in `defects`. Each defect node can be matched to one of the `num_neighbours`
nearest defects. If the graph does not have a perfect matching, this function
doubles `num_neighbours` and retries until `num_neighbours==len(defects)`.
Parameters
----------
stabiliser_graph : pymatching._cpp_mwpm.WeightedStabiliserGraph
The stabliser graph defining the matching problem
defects : numpy.ndarray of dtype int
The indices of the nodes corresponding to defects (-1 measurements in the syndrome)
num_neighbours : int
The number of neighbours to use in local matching
return_weight : bool, optional
If `return_weight==True`, the sum of the weights of the edges in the
minimum weight perfect matching is also returned. By default False
Returns
-------
numpy.ndarray
A 1D numpy array of ints giving the minimum-weight correction
operator. The number of elements equals the number of qubits,
and an element is 1 if the corresponding qubit should be flipped,
and otherwise 0.
float
Present only if `return_weight==True`.
The sum of the weights of the edges in the minimum-weight perfect
matching.
"""
if num_neighbours < 0:
raise ValueError("num_neighbours must be a positive integer")
while True:
try:
return _py_decode_match_neighbourhood(stabiliser_graph, defects, num_neighbours, return_weight)
except BlossomFailureException:
if num_neighbours >= len(defects):
raise
else:
num_neighbours *= 2


class Matching:
"""A class for constructing matching graphs and decoding using the minimum-weight perfect matching decoder
Expand Down Expand Up @@ -402,9 +353,9 @@ def decode(self, z, num_neighbours=30, return_weight=False):
else:
raise ValueError("The shape ({}) of the syndrome vector z is not valid.".format(z.shape))
if num_neighbours is None:
res = decode(self.stabiliser_graph, defects, return_weight)
res = exact_matching(self.stabiliser_graph, defects, return_weight)
else:
res = _local_matching(self.stabiliser_graph, defects, num_neighbours, return_weight)
res = local_matching(self.stabiliser_graph, defects, num_neighbours, return_weight)
if return_weight:
return res.correction, res.weight
else:
Expand Down
84 changes: 43 additions & 41 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,14 @@

import os

import pytest
from unittest.mock import patch
import numpy as np
from scipy.sparse import load_npz, csr_matrix
import pytest
import networkx as nx

from pymatching._cpp_mwpm import (BlossomFailureException, decode_match_neighbourhood,
decode)
from pymatching._cpp_mwpm import (BlossomFailureException, local_matching,
exact_matching)
from pymatching import Matching
from pymatching.matching import _local_matching

TEST_DIR = dir_path = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -204,20 +201,41 @@ def test_matching_correct():
assert np.array_equal(m.decode(z, num_neighbours=None).nonzero()[0], np.array([0,4,12,16,23]))


@pytest.mark.parametrize("cluster_size", range(3, 10, 2))
@pytest.mark.parametrize("cluster_size", [2, 6, 10])
def test_local_matching_clusters(cluster_size):
g = nx.Graph()
qid = 0
for i in range(cluster_size):
g.add_edge(i, i+1, weight=1.0, qubit_id=qid)
qid += 1
g.add_edge(cluster_size, cluster_size+1, weight=2*cluster_size, qubit_id=qid)
g.add_edge(cluster_size, cluster_size+1, weight=3*cluster_size, qubit_id=qid)
qid += 1
for i in range(cluster_size+1, 2*cluster_size + 1):
g.add_edge(i, i+1, weight=1.0, qubit_id=qid)
qid += 1
m = Matching(g)
m.decode([1]*(cluster_size+1)*2, num_neighbours=cluster_size)
m.decode(np.ones((cluster_size+1)*2), num_neighbours=cluster_size)
for i in range(1, cluster_size+1):
with pytest.raises(BlossomFailureException):
local_matching(
m.stabiliser_graph,
np.arange(m.num_detectors),
num_neighbours=i,
max_attempts=1
)
for i in range(cluster_size+1, 2*cluster_size):
local_matching(
m.stabiliser_graph,
np.arange(m.num_detectors),
num_neighbours=i,
max_attempts=1
)
local_matching(
m.stabiliser_graph,
np.arange(m.num_detectors),
num_neighbours=1,
max_attempts=int(np.log2(cluster_size+1))+2
)


G = nx.Graph()
Expand All @@ -231,23 +249,7 @@ def test_local_matching_clusters(cluster_size):
def test_local_matching_raises_value_error():
with pytest.raises(ValueError):
for x in (-10, -5, 0):
_local_matching(M.stabiliser_graph, defects, x, False)


@pytest.mark.parametrize("num_neighbours", [2, 5, 20, 50])
def test_local_matching_raises_blossom_error(num_neighbours):
with patch('pymatching.matching._py_decode_match_neighbourhood') as mock_decode:
mock_decode.side_effect = BlossomFailureException
with pytest.raises(BlossomFailureException):
_local_matching(M.stabiliser_graph, defects, num_neighbours, False)
assert mock_decode.call_count == 1+int(np.ceil(np.log2(n)-np.log2(num_neighbours)))


def test_local_matching_catches_blossom_errors():
with patch('pymatching.matching._py_decode_match_neighbourhood') as mock_decode:
mock_decode.side_effect = [BlossomFailureException]*3 + [None]
_local_matching(M.stabiliser_graph, defects, 2, False)
assert mock_decode.call_count == 4
local_matching(M.stabiliser_graph, defects, x, False)


def test_decoding_large_defect_id_raises_value_error():
Expand All @@ -256,8 +258,8 @@ def test_decoding_large_defect_id_raises_value_error():
g.add_edge(1, 2)
m = Matching(g)
with pytest.raises(ValueError):
decode_match_neighbourhood(m.stabiliser_graph, np.array([1, 4]))
decode(m.stabiliser_graph, np.array([1, 4]))
local_matching(m.stabiliser_graph, np.array([1, 4]))
exact_matching(m.stabiliser_graph, np.array([1, 4]))


def test_decode_with_odd_number_of_defects():
Expand All @@ -267,18 +269,18 @@ def test_decode_with_odd_number_of_defects():
g.add_edge(2, 0)
m = Matching(g)
with pytest.raises(ValueError):
decode_match_neighbourhood(m.stabiliser_graph, np.array([1]))
local_matching(m.stabiliser_graph, np.array([1]))
with pytest.raises(ValueError):
decode(m.stabiliser_graph, np.array([1]))
exact_matching(m.stabiliser_graph, np.array([1]))
g.nodes[2]['is_boundary'] = True
m2 = Matching(g)
decode_match_neighbourhood(m2.stabiliser_graph, np.array([1]))
decode(m2.stabiliser_graph, np.array([1]))
local_matching(m2.stabiliser_graph, np.array([1]))
exact_matching(m2.stabiliser_graph, np.array([1]))
g.nodes[2]['is_boundary'] = False
g.nodes[1]['is_boundary'] = True
m3 = Matching(g)
decode_match_neighbourhood(m3.stabiliser_graph, np.array([1]))
decode(m3.stabiliser_graph, np.array([1]))
local_matching(m3.stabiliser_graph, np.array([1]))
exact_matching(m3.stabiliser_graph, np.array([1]))


def test_decode_with_multiple_components():
Expand All @@ -293,22 +295,22 @@ def test_decode_with_multiple_components():
m = Matching(g)
for z in (np.array([0]), np.arange(6)):
with pytest.raises(ValueError):
decode_match_neighbourhood(m.stabiliser_graph, z)
local_matching(m.stabiliser_graph, z)
with pytest.raises(ValueError):
decode(m.stabiliser_graph, z)
exact_matching(m.stabiliser_graph, z)

g.nodes[0]['is_boundary'] = True
m2 = Matching(g)
decode_match_neighbourhood(m2.stabiliser_graph, np.array([1]))
decode(m2.stabiliser_graph, np.array([1]))
local_matching(m2.stabiliser_graph, np.array([1]))
exact_matching(m2.stabiliser_graph, np.array([1]))
for z in (np.arange(6), np.array([3])):
with pytest.raises(ValueError):
decode_match_neighbourhood(m2.stabiliser_graph, z)
local_matching(m2.stabiliser_graph, z)
with pytest.raises(ValueError):
decode(m2.stabiliser_graph, z)
exact_matching(m2.stabiliser_graph, z)

g.nodes[4]['is_boundary'] = True
m3 = Matching(g)
for z in (np.array([0]), np.arange(6), np.array([3]), np.array([1,3])):
decode_match_neighbourhood(m3.stabiliser_graph, z)
decode(m3.stabiliser_graph, z)
local_matching(m3.stabiliser_graph, z)
exact_matching(m3.stabiliser_graph, z)
13 changes: 7 additions & 6 deletions tests/test_nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import pytest

from pymatching import Matching
from pymatching._cpp_mwpm import decode_match_neighbourhood
from pymatching._cpp_mwpm import local_matching


g = nx.Graph()
g.add_edge(0,1, qubit_id=0, weight=0.3)
Expand All @@ -32,7 +33,7 @@

def test_dijkstra_nearest_neighbour_nodes():
assert (set(m.stabiliser_graph.get_nearest_neighbours(2, 3, [0,0,1,-1,2,-1]))
== {(2, 0.0), (1, 0.1), (4, 0.2)})
== {(2, 0.0), (1, 0.1), (4, 0.2)})


def test_dijkstra_path():
Expand All @@ -42,13 +43,13 @@ def test_dijkstra_path():


neighbour_match_fixtures = [
([1,3], 3, [0,1,1,0,0]),
([0,1,2,3,4,5], 2, [1,0,1,0,1]),
([0,1,2,3,4,5], 10, [1,0,1,0,1])
([1, 3], 3, [0, 1, 1, 0, 0]),
([0, 1, 2, 3, 4, 5], 3, [1, 0, 1, 0, 1]),
([0, 1, 2, 3, 4, 5], 11, [1, 0, 1, 0, 1])
]


@pytest.mark.parametrize("defects,num_neighbours,correction", neighbour_match_fixtures)
def test_neighbourhood_matching(defects,num_neighbours,correction):
assert (np.array_equal(decode_match_neighbourhood(m.stabiliser_graph,
assert (np.array_equal(local_matching(m.stabiliser_graph,
np.array(defects), num_neighbours, False).correction, np.array(correction)))

0 comments on commit 3b11b48

Please sign in to comment.