Skip to content

Commit

Permalink
Merge pull request #14 from oscarhiggott/prevent-seg-fault
Browse files Browse the repository at this point in the history
Prevent Python interpreter from crashing when the blossom algorithm fails
  • Loading branch information
oscarhiggott committed Jul 3, 2021
2 parents ce5e913 + 394f585 commit 7f2b32f
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 32 deletions.
2 changes: 2 additions & 0 deletions src/pymatching/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ PYBIND11_MODULE(_cpp_mwpm, m) {
.def_readwrite("correction", &MatchingResult::correction)
.def_readwrite("weight", &MatchingResult::weight);

py::register_exception<BlossomFailureException>(m, "BlossomFailureException", PyExc_RuntimeError);

m.def("decode_match_neighbourhood", &LemonDecodeMatchNeighbourhood,
"sg"_a, "defects"_a, "num_neighbours"_a=20,
"return_weight"_a=false);
Expand Down
62 changes: 36 additions & 26 deletions src/pymatching/lemon_mwpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@

typedef lemon::ListGraph UGraph;
typedef UGraph::EdgeMap<double> LengthMap;
typedef lemon::MaxWeightedPerfectMatching<UGraph,LengthMap> MWPM;


const char * BlossomFailureException::what() const throw() {
return "The Lemon implementation of the blossom algorithm "
"(lemon::MaxWeightedPerfectMatching) "
"was unable to find a solution due to an error.";
}


class DefectGraph {
Expand All @@ -37,24 +45,20 @@ class DefectGraph {
void AddEdge(int i, int j, double weight);
UGraph g;
LengthMap length;
UGraph::NodeMap<int> node_map;
std::vector<UGraph::Node> node_list;
int num_nodes;
};

DefectGraph::DefectGraph(int num_nodes) : num_nodes(num_nodes),
length(g), node_map(g)
length(g)
{
for (int i=0; i<num_nodes; i++){
UGraph::Node x;
x = g.addNode();
node_map[x] = i;
node_list.push_back(x);
}
}

void DefectGraph::AddEdge(int i, int j, double weight){
UGraph::Edge e = g.addEdge(node_list[i], node_list[j]);
UGraph::Edge e = g.addEdge(g.nodeFromId(i), g.nodeFromId(j));
length[e] = weight;
}

Expand All @@ -74,21 +78,17 @@ MatchingResult LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects
defect_graph.AddEdge(i, j, -1.0*sg.SpaceTimeDistance(d(i), d(j)));
}
};
typedef lemon::MaxWeightedPerfectMatching<UGraph,LengthMap> MWPM;
MWPM pm(defect_graph.g, defect_graph.length);
pm.run();

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
} else {
matching_result.weight = -1.0;
bool success = pm.run();
if (!success){
throw BlossomFailureException();
}

int N = sg.GetNumQubits();
auto correction = new std::vector<int>(N, 0);
std::set<int> qids;
for (py::size_t i = 0; i<num_nodes; i++){
int j = defect_graph.node_map[pm.mate(defect_graph.node_list[i])];
int j = defect_graph.g.id(pm.mate(defect_graph.g.nodeFromId(i)));
if (i<j){
std::vector<int> path = sg.SpaceTimeShortestPath(d(i), d(j));
for (std::vector<int>::size_type k=0; k<path.size()-1; k++){
Expand All @@ -104,6 +104,13 @@ MatchingResult LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects

auto capsule = py::capsule(correction, [](void *correction) { delete reinterpret_cast<std::vector<int>*>(correction); });
auto corr = py::array_t<int>(correction->size(), correction->data(), capsule);

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
} else {
matching_result.weight = -1.0;
}

matching_result.correction = corr;
return matching_result;
}
Expand All @@ -129,18 +136,19 @@ MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const
++num_neighbours;
defect_graph = std::make_unique<DefectGraph>(num_defects);
std::vector<std::pair<int, double>> neighbours;
std::vector<std::set<int>> adj_list(num_defects);
int j;
bool is_in;
for (int i=0; i<num_defects; i++){
neighbours = sg.GetNearestNeighbours(d(i), num_neighbours, defect_id);
for (const auto &neighbour : neighbours){
j = defect_id[neighbour.first];
is_in = adj_list[i].find(j) != adj_list[i].end();
UGraph::Edge FoundEdge = lemon::findEdge(
defect_graph->g,
defect_graph->g.nodeFromId(i),
defect_graph->g.nodeFromId(j));
is_in = FoundEdge != lemon::INVALID;
if (!is_in && i!=j){
defect_graph->AddEdge(i, j, -1.0*neighbour.second);
adj_list[i].insert(j);
adj_list[j].insert(i);
}
}
}
Expand All @@ -151,14 +159,10 @@ MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const
throw std::runtime_error("Graph must have only one connected component");
}

typedef lemon::MaxWeightedPerfectMatching<UGraph,LengthMap> MWPM;
MWPM pm(defect_graph->g, defect_graph->length);
pm.run();

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
} else {
matching_result.weight = -1.0;
bool success = pm.run();
if (!success){
throw BlossomFailureException();
}

int N = sg.GetNumQubits();
Expand All @@ -176,7 +180,7 @@ MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const
while (remaining_defects.size() > 0){
i = *remaining_defects.begin();
remaining_defects.erase(remaining_defects.begin());
j = defect_graph->node_map[pm.mate(defect_graph->node_list[i])];
j = defect_graph->g.id(pm.mate(defect_graph->g.nodeFromId(i)));
remaining_defects.erase(j);
path = sg.GetPath(d(i), d(j));
for (std::vector<int>::size_type k=0; k<path.size()-1; k++){
Expand All @@ -190,6 +194,12 @@ MatchingResult LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph& sg, const
}
auto capsule = py::capsule(correction, [](void *correction) { delete reinterpret_cast<std::vector<int>*>(correction); });
auto corr = py::array_t<int>(correction->size(), correction->data(), capsule);

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
} else {
matching_result.weight = -1.0;
}

matching_result.correction = corr;
return matching_result;
Expand Down
7 changes: 7 additions & 0 deletions src/pymatching/lemon_mwpm.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
#include "stabiliser_graph.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <exception>


struct BlossomFailureException : public std::exception {
const char * what() const throw();
};


/**
* @brief A struct containing the output of the minimum weight perfect matching decoder.
Expand Down
64 changes: 59 additions & 5 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import matplotlib.cbook
import numpy as np
import networkx as nx
from scipy.sparse import csc_matrix, spmatrix, vstack
from scipy.sparse import csc_matrix

from pymatching._cpp_mwpm import (decode,
decode_match_neighbourhood,
WeightedStabiliserGraph)

WeightedStabiliserGraph,
BlossomFailureException)
# alias to let unittest mock decode_match_neighbourhood
_py_decode_match_neighbourhood = decode_match_neighbourhood


def _find_boundary_nodes(G):
Expand All @@ -44,7 +46,59 @@ def _find_boundary_nodes(G):
"""
return [i for i, attr in G.nodes(data=True)
if attr.get("is_boundary", False)]



def _local_matching(stabiliser_graph, defects, num_neighbours, return_weight=False,
num_attempts=10):
"""Local matching decoder
Find the local matching in `stabiliser_graph` with a syndrome defined by -1 measurements
at nodes in `defects`. Each defect node can be matched to one of the `num_neighbours`
nearest defects. This function uses the Lemon library's MaxWeightedPerfectMatching
implementation of the blossom algorithm. If Lemon fails to find a solution, this function
retries at most `num_attempts` times, increasing `num_neighbours` by 5 between each attempt.
Parameters
----------
stabiliser_graph : pymatching._cpp_mwpm.WeightedStabiliserGraph
The stabliser graph defining the matching problem
defects : numpy.ndarray of dtype int
The indices of the nodes corresponding to defects (-1 measurements in the syndrome)
num_neighbours : int
The number of neighbours to use in local matching
return_weight : bool, optional
If `return_weight==True`, the sum of the weights of the edges in the
minimum weight perfect matching is also returned. By default False
num_attempts : int, optional
Number of attempts to solve the matching problem if the blossom algorithm
fails to converge, by default 5
Returns
-------
numpy.ndarray
A 1D numpy array of ints giving the minimum-weight correction
operator. The number of elements equals the number of qubits,
and an element is 1 if the corresponding qubit should be flipped,
and otherwise 0.
float
Present only if `return_weight==True`.
The sum of the weights of the edges in the minimum-weight perfect
matching.
"""
if num_attempts < 1:
raise ValueError("`num_attempts` must be an integer >= 1")
attempts_remaining = num_attempts
while True:
try:
return _py_decode_match_neighbourhood(stabiliser_graph, defects, num_neighbours, return_weight)
except BlossomFailureException:
if attempts_remaining <= 1 or num_neighbours >= len(defects):
raise
else:
num_neighbours += 5
attempts_remaining -= 1


class Matching:
"""A class for constructing matching graphs and decoding using the minimum-weight perfect matching decoder
Expand Down Expand Up @@ -300,7 +354,7 @@ def decode(self, z, num_neighbours=20, return_weight=False):
if num_neighbours is None:
res = decode(self.stabiliser_graph, defects, return_weight)
else:
res = decode_match_neighbourhood(self.stabiliser_graph, defects, num_neighbours, return_weight)
res = _local_matching(self.stabiliser_graph, defects, num_neighbours, return_weight)
if return_weight:
return res.correction, res.weight
else:
Expand Down
34 changes: 33 additions & 1 deletion tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
import os

import pytest
from unittest.mock import patch
import numpy as np
from scipy.sparse import csc_matrix, load_npz, csr_matrix
import pytest
import networkx as nx
import matplotlib.pyplot as plt

from pymatching._cpp_mwpm import WeightedStabiliserGraph
from pymatching._cpp_mwpm import WeightedStabiliserGraph, BlossomFailureException
from pymatching import Matching
from pymatching.matching import _local_matching

TEST_DIR = dir_path = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -498,3 +500,33 @@ def test_draw_matching():
m = Matching(g)
plt.figure()
m.draw()


G = nx.Graph()
n = 100
for i in range(n):
G.add_edge(i, (i+1) % n, qubit_id=i)
M = Matching(G)
defects = np.array(list(range(n)))


def test_local_matching_raises_value_error():
with pytest.raises(ValueError):
for x in (-10, -5, 0):
_local_matching(M.stabiliser_graph, defects, 20, False, x)


@pytest.mark.parametrize("num_attempts", [1,3,5])
def test_local_matching_raises_blossom_error(num_attempts):
with patch('pymatching.matching._py_decode_match_neighbourhood') as mock_decode:
mock_decode.side_effect = BlossomFailureException
with pytest.raises(BlossomFailureException):
_local_matching(M.stabiliser_graph, defects, 20, False, num_attempts)
assert mock_decode.call_count == num_attempts


def test_local_matching_catches_blossom_errors():
with patch('pymatching.matching._py_decode_match_neighbourhood') as mock_decode:
mock_decode.side_effect = [BlossomFailureException]*3 + [0]
_local_matching(M.stabiliser_graph, defects, 20, False, 5)
assert mock_decode.call_count == 4

0 comments on commit 7f2b32f

Please sign in to comment.