Skip to content

Commit

Permalink
Adaptively ensure matching graph is connected even for small num_neig…
Browse files Browse the repository at this point in the history
…hbours
  • Loading branch information
oscarhiggott committed Nov 27, 2020
1 parent 078b819 commit f52ad20
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 64 deletions.
103 changes: 59 additions & 44 deletions src/pymatching/lemon_mwpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,57 @@
typedef lemon::ListGraph UGraph;
typedef UGraph::EdgeMap<double> LengthMap;

py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects){
if (!sg.HasComputedAllPairsShortestPaths()){
sg.ComputeAllPairsShortestPaths();
}
auto d = defects.unchecked<1>();
int num_nodes = d.shape(0);

UGraph g;
LengthMap length(g);
UGraph::NodeMap<int> node_map(g);
std::vector<UGraph::Node> node_list;
class DefectGraph {
public:
DefectGraph(int num_nodes);
void AddEdge(int i, int j, double weight);
UGraph g;
LengthMap length;
UGraph::NodeMap<int> node_map;
std::vector<UGraph::Node> node_list;
int num_nodes;
};

DefectGraph::DefectGraph(int num_nodes) : num_nodes(num_nodes),
length(g), node_map(g)
{
for (int i=0; i<num_nodes; i++){
UGraph::Node x;
x = g.addNode();
node_map[x] = i;
node_list.push_back(x);
}
}

void DefectGraph::AddEdge(int i, int j, double weight){
UGraph::Edge e = g.addEdge(node_list[i], node_list[j]);
length[e] = weight;
}


py::array_t<std::uint8_t> LemonDecode(IStabiliserGraph& sg, const py::array_t<int>& defects){
if (!sg.HasComputedAllPairsShortestPaths()){
sg.ComputeAllPairsShortestPaths();
}
auto d = defects.unchecked<1>();
int num_nodes = d.shape(0);

DefectGraph defect_graph(num_nodes);

for (py::size_t i = 0; i<num_nodes; i++){
for (py::size_t j=i+1; j<num_nodes; j++){
UGraph::Edge e = g.addEdge(node_list[i], node_list[j]);
length[e] = -1.0*sg.SpaceTimeDistance(d(i), d(j));
defect_graph.AddEdge(i, j, -1.0*sg.SpaceTimeDistance(d(i), d(j)));
}
};
typedef lemon::MaxWeightedPerfectMatching<UGraph,LengthMap> MWPM;
MWPM pm(g, length);
MWPM pm(defect_graph.g, defect_graph.length);
pm.run();
int N = sg.GetNumQubits();
auto correction = new std::vector<int>(N, 0);
std::set<int> qids;
for (py::size_t i = 0; i<num_nodes; i++){
int j = node_map[pm.mate(node_list[i])];
int j = defect_graph.node_map[pm.mate(defect_graph.node_list[i])];
if (i<j){
std::vector<int> path = sg.SpaceTimeShortestPath(d(i), d(j));
for (std::vector<int>::size_type k=0; k<path.size()-1; k++){
Expand Down Expand Up @@ -90,43 +109,38 @@ py::array_t<std::uint8_t> LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph&
for (int i=0; i<num_defects; i++){
defect_id[d(i)] = i;
}

num_neighbours = std::min(num_neighbours, num_defects-1) + 1;

UGraph g;
LengthMap length(g);
UGraph::NodeMap<int> node_map(g);
std::vector<UGraph::Node> node_list;
for (int i=0; i<num_defects; i++){
UGraph::Node x;
x = g.addNode();
node_map[x] = i;
node_list.push_back(x);
}

std::vector<std::pair<int, double>> neighbours;
std::vector<std::set<int>> adj_list(num_defects);
int j;
bool is_in;
for (int i=0; i<num_defects; i++){
neighbours = sg.GetNearestNeighbours(d(i), num_neighbours, defect_id);
for (const auto &neighbour : neighbours){
j = defect_id[neighbour.first];
is_in = adj_list[i].find(j) != adj_list[i].end();
if (!is_in && i!=j){
UGraph::Edge e = g.addEdge(node_list[i], node_list[j]);
length[e] = -1.0*neighbour.second;
adj_list[i].insert(j);
adj_list[j].insert(i);
--num_neighbours;
bool is_connected = false;
std::unique_ptr<DefectGraph> defect_graph;
while (!is_connected && num_neighbours < num_defects){
++num_neighbours;
defect_graph = std::make_unique<DefectGraph>(num_defects);
std::vector<std::pair<int, double>> neighbours;
std::vector<std::set<int>> adj_list(num_defects);
int j;
bool is_in;
for (int i=0; i<num_defects; i++){
neighbours = sg.GetNearestNeighbours(d(i), num_neighbours, defect_id);
for (const auto &neighbour : neighbours){
j = defect_id[neighbour.first];
is_in = adj_list[i].find(j) != adj_list[i].end();
if (!is_in && i!=j){
defect_graph->AddEdge(i, j, -1.0*neighbour.second);
adj_list[i].insert(j);
adj_list[j].insert(i);
}
}
}
is_connected = lemon::connected(defect_graph->g);
}

if (!lemon::connected(g)){
if (!lemon::connected(defect_graph->g)){
throw std::runtime_error("Graph must have only one connected component");
}

typedef lemon::MaxWeightedPerfectMatching<UGraph,LengthMap> MWPM;
MWPM pm(g, length);
MWPM pm(defect_graph->g, defect_graph->length);
pm.run();

int N = sg.GetNumQubits();
Expand All @@ -139,11 +153,12 @@ py::array_t<std::uint8_t> LemonDecodeMatchNeighbourhood(WeightedStabiliserGraph&

std::vector<int> path;
int i;
int j;
std::set<int> qids;
while (remaining_defects.size() > 0){
i = *remaining_defects.begin();
remaining_defects.erase(remaining_defects.begin());
j = node_map[pm.mate(node_list[i])];
j = defect_graph->node_map[pm.mate(defect_graph->node_list[i])];
remaining_defects.erase(j);
path = sg.GetPath(d(i), d(j));
for (std::vector<int>::size_type k=0; k<path.size()-1; k++){
Expand Down
7 changes: 1 addition & 6 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ def decode(self, z, num_neighbours=20):
then the local matching decoder in the Appendix of
https://arxiv.org/abs/2010.09626 is used, and `num_neighbours`
corresponds to the parameter `m` in the paper. It is recommended
to leave `num_neighbours` set to 20. Setting num_neighbours to
less than than 10 is numerically unstable, and not permitted.
to leave `num_neighbours` set to at least 20.
If `num_neighbours=None`, then instead full matching is
performed, with the all-pairs shortest paths precomputed and
cached the first time it is used. Since full matching is more
Expand All @@ -262,10 +261,6 @@ def decode(self, z, num_neighbours=20):
and otherwise 0.
"""
if num_neighbours is not None and num_neighbours < 10:
raise ValueError("num_neighbours can be either None, or an"
" integer greater than or equal to 10, "
f"not {num_neighbours}.")
try:
z = np.array(z, dtype=np.uint8)
except:
Expand Down
30 changes: 16 additions & 14 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,20 +372,6 @@ def test_wrong_connected_components_raises_value_error():
assert m.stabiliser_graph.get_num_connected_components() == 1


def test_small_num_neighbours_raises_value_error():
m = Matching(np.array([
[1,1,0,0],
[0,1,1,0],
[0,0,1,1]
]))
min_num_neighbours = 10
for i in range(min_num_neighbours):
with pytest.raises(ValueError):
m.decode([0,1,1], num_neighbours=i)
for i in range(min_num_neighbours, 2*min_num_neighbours):
m.decode([0,1,1], num_neighbours=i)


def test_high_qubit_id_raises_value_error():
g = nx.Graph()
g.add_edge(0,1,qubit_id=1)
Expand Down Expand Up @@ -439,3 +425,19 @@ def test_matching_correct():
z[15] = 1
assert np.array_equal(m.decode(z, num_neighbours=20).nonzero()[0], np.array([0,4,12,16,23]))
assert np.array_equal(m.decode(z, num_neighbours=None).nonzero()[0], np.array([0,4,12,16,23]))


@pytest.mark.parametrize("cluster_size", range(3,10,2))
def test_local_matching_connected(cluster_size):
g = nx.Graph()
qid = 0
for i in range(cluster_size):
g.add_edge(i, i+1, weight=1.0, qubit_id=qid)
qid += 1
g.add_edge(cluster_size, cluster_size+1, weight=2*cluster_size, qubit_id=qid)
qid += 1
for i in range(cluster_size+1, 2*cluster_size + 1):
g.add_edge(i, i+1, weight=1.0, qubit_id=qid)
qid += 1
m = Matching(g)
m.decode([1]*(cluster_size+1)*2, num_neighbours=cluster_size)

0 comments on commit f52ad20

Please sign in to comment.