Skip to content

Commit

Permalink
Added option to return weight of matching
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarhiggott committed Apr 19, 2021
1 parent e069dc5 commit 3d1904a
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 14 deletions.
11 changes: 9 additions & 2 deletions src/pymatching/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ PYBIND11_MODULE(_cpp_mwpm, m) {
m.def("set_seed", &set_seed, "s"_a);
m.def("rand_float", &rand_float, "from"_a, "to"_a);

m.def("decode_match_neighbourhood", &LemonDecodeMatchNeighbourhood);
m.def("decode", &LemonDecode);
py::class_<MatchingResult>(m, "MatchingResult")
.def_readwrite("correction", &MatchingResult::correction)
.def_readwrite("weight", &MatchingResult::weight);

m.def("decode_match_neighbourhood", &LemonDecodeMatchNeighbourhood,
"sg"_a, "defects"_a, "num_neighbours"_a=20,
"return_weight"_a=false);
m.def("decode", &LemonDecode, "sg"_a, "defects"_a,
"return_weight"_a=false);
}
29 changes: 25 additions & 4 deletions src/pymatching/lemon_mwpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ void DefectGraph::AddEdge(int i, int j, double weight){
}


py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects){
MatchingResult LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects, bool return_weight){
MatchingResult matching_result;
if (!sg.HasComputedAllPairsShortestPaths()){
sg.ComputeAllPairsShortestPaths();
}
Expand All @@ -76,6 +77,13 @@ py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<in
typedef lemon::MaxWeightedPerfectMatching<UGraph,LengthMap> MWPM;
MWPM pm(defect_graph.g, defect_graph.length);
pm.run();

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
} else {
matching_result.weight = -1.0;
}

int N = sg.GetNumQubits();
auto correction = new std::vector<int>(N, 0);
std::set<int> qids;
Expand All @@ -95,11 +103,15 @@ py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<in
}

auto capsule = py::capsule(correction, [](void *correction) { delete reinterpret_cast<std::vector<int>*>(correction); });
return py::array_t<int>(correction->size(), correction->data(), capsule);
auto corr = py::array_t<int>(correction->size(), correction->data(), capsule);
matching_result.correction = corr;
return matching_result;
}


py::array_t<std::uint8_t> LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const py::array_t<int>& defects, int num_neighbours){
MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const py::array_t<int>& defects,
int num_neighbours, bool return_weight){
MatchingResult matching_result;
auto d = defects.unchecked<1>();
int num_defects = d.shape(0);

Expand Down Expand Up @@ -143,6 +155,12 @@ py::array_t<std::uint8_t> LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph&
MWPM pm(defect_graph->g, defect_graph->length);
pm.run();

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
} else {
matching_result.weight = -1.0;
}

int N = sg.GetNumQubits();
auto correction = new std::vector<int>(N, 0);

Expand Down Expand Up @@ -171,5 +189,8 @@ py::array_t<std::uint8_t> LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph&
}
}
auto capsule = py::capsule(correction, [](void *correction) { delete reinterpret_cast<std::vector<int>*>(correction); });
return py::array_t<int>(correction->size(), correction->data(), capsule);
auto corr = py::array_t<int>(correction->size(), correction->data(), capsule);

matching_result.correction = corr;
return matching_result;
}
33 changes: 29 additions & 4 deletions src/pymatching/lemon_mwpm.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,28 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

/**
* @brief A struct containing the output of the minimum weight perfect matching decoder.
* Contains the correction corresponding to the solution, as well as the total weight
* of the solution (with the latter set to -1 if not requested).
*
*/
struct MatchingResult {
/**
* @brief The correction operator corresponding to the minimum-weight perfect matching.
* correction[i] is 1 if the ith qubit is flipped and correction[i] is 0 otherwise.
*
*/
py::array_t<std::uint8_t> correction;
/**
* @brief The total weight of the edges in the minimum-weight perfect matching.
* If the weight is not requested by the decoder (return_weight=false), then
* weight=-1.
*
*/
double weight;
};

/**
* @brief Given a stabiliser graph sg and a vector `defects` of indices of nodes that have a -1 syndrome,
* find the find the minimum weight perfect matching in the complete graph with nodes in the defects
Expand All @@ -29,9 +51,10 @@
*
* @param sg A stabiliser graph
* @param defects The indices of nodes that are associated with a -1 syndrome
* @return py::array_t<std::uint8_t> The noise vector for the minimum-weight perfect matching
* @return MatchingResult A struct containing the correction vector for the minimum-weight perfect matching and the matching weight.
* The matching weight is set to -1 if it is not requested.
*/
py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects);
MatchingResult LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects, bool return_weight=false);
/**
* @brief Given a stabiliser graph `sg`, a vector `defects` of indices of nodes that have a -1 syndrome and
* a chosen `num_neighbours`, find the minimum weight perfect matching in the graph V where each defect node
Expand All @@ -43,6 +66,8 @@ py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<in
* @param sg A stabiliser graph
* @param defects The indices of nodes that are associated with a -1 syndrome
* @param num_neighbours The number of closest defects to connect each defect to in the matching graph
* @return py::array_t<std::uint8_t> The noise vector for the minimum-weight perfect matching
* @return MatchingResult A struct containing the correction vector for the minimum-weight perfect matching and the matching weight.
* The matching weight is set to -1 if it is not requested.
*/
py::array_t<std::uint8_t> LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const py::array_t<int>& defects, int num_neighbours);
MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const py::array_t<int>& defects,
int num_neighbours, bool return_weight=false);
27 changes: 24 additions & 3 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def boundary(self):
"""
return self.stabiliser_graph.get_boundary()

def decode(self, z, num_neighbours=20):
def decode(self, z, num_neighbours=20, return_weight=False):
"""Decode the syndrome `z` using minimum-weight perfect matching
If the parity of `z` is odd, then the first boundary node in
Expand Down Expand Up @@ -256,15 +256,32 @@ def decode(self, z, num_neighbours=20):
with more than around 10,000 nodes, and is only faster than
local matching for matching graphs with less than around 1,000
nodes. By default 20
return_weight : bool, optional
If return_weight=True, the sum of the weights of the edges in the
minimum weight perfect matching is also returned. By default False
Returns
-------
**(If return_weight == False)**
numpy.ndarray
A 1D numpy array of ints giving the minimum-weight correction
operator. The number of elements equals the number of qubits,
and an element is 1 if the corresponding qubit should be flipped,
and otherwise 0.
**(If return_weight == True)**
numpy.ndarray
A 1D numpy array of ints giving the minimum-weight correction
operator. The number of elements equals the number of qubits,
and an element is 1 if the corresponding qubit should be flipped,
and otherwise 0.
float
The sum of the weights of the edges in the minimum-weight perfect
matching.
"""
try:
z = np.array(z, dtype=np.uint8)
Expand All @@ -290,9 +307,13 @@ def decode(self, z, num_neighbours=20):
"if no boundary vertex is given.")
defects = np.setxor1d(defects, np.array(self.boundary[0:1]))
if num_neighbours is None:
return decode(self.stabiliser_graph, defects)
res = decode(self.stabiliser_graph, defects, return_weight)
else:
res = decode_match_neighbourhood(self.stabiliser_graph, defects, num_neighbours, return_weight)
if return_weight:
return res.correction, res.weight
else:
return decode_match_neighbourhood(self.stabiliser_graph, defects, num_neighbours)
return res.correction

def add_noise(self):
"""Add noise by flipping edges in the stabiliser graph with
Expand Down
34 changes: 34 additions & 0 deletions tests/test_matching_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
from scipy.sparse import csr_matrix
import pytest

from pymatching import Matching


def repetition_code(n):
row_ind, col_ind = zip(*((i, j) for i in range(n) for j in (i, (i+1)%n)))
data = np.ones(2*n, dtype=np.uint8)
return csr_matrix((data, (row_ind, col_ind)))


weight_fixtures = [
(10, 20),
(10, None),
(15, 10),
(15, None),
(20, 1),
(20, None)
]


@pytest.mark.parametrize("n,num_neighbours", weight_fixtures)
def test_matching_weight(n, num_neighbours):
p = 0.4
H = repetition_code(n)
noise = np.random.rand(n) < p
weights = np.random.rand(n)
s = H@noise % 2
m = Matching(H, spacelike_weights=weights)
corr, weight = m.decode(s, num_neighbours=num_neighbours, return_weight=True)
expected_weight = np.sum(weights[corr==1])
assert expected_weight == pytest.approx(weight)
2 changes: 1 addition & 1 deletion tests/test_nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def test_dijkstra_path():
@pytest.mark.parametrize("defects,num_neighbours,correction", neighbour_match_fixtures)
def test_neighbourhood_matching(defects,num_neighbours,correction):
assert (np.array_equal(decode_match_neighbourhood(m.stabiliser_graph,
np.array(defects), num_neighbours), np.array(correction)))
np.array(defects), num_neighbours, False).correction, np.array(correction)))

0 comments on commit 3d1904a

Please sign in to comment.