Skip to content

Commit

Permalink
Merge pull request #23 from oscarhiggott/negative-weights
Browse files Browse the repository at this point in the history
Support negative edge weights
  • Loading branch information
oscarhiggott committed Dec 24, 2021
2 parents ba2d861 + 73c1e41 commit 7d799d1
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 35 deletions.
16 changes: 8 additions & 8 deletions docs/toric-code-example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/pymatching/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ PYBIND11_MODULE(_cpp_mwpm, m) {
.def(py::init<>(), u8R"(
Initialises a WeightedEdgeData object
)")
.def(py::init<std::set<int>, double, double, bool>(),
"fault_ids"_a, "weight"_a, "error_probability"_a, "has_error_probability"_a, u8R"(
.def(py::init<std::set<int>, double, double, bool, bool>(),
"fault_ids"_a, "weight"_a, "error_probability"_a, "has_error_probability"_a, "weight_is_negative"_a, u8R"(
Initialises a WeightedEdgeData object
Parameters
Expand Down
36 changes: 34 additions & 2 deletions src/pymatching/lemon_mwpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ MatchingResult ExactMatching(
}
defects_set.insert(d(i));
}

for (auto s : graph.negative_edge_syndrome){
auto s_it = defects_set.find(s);
if (s_it != defects_set.end()){
defects_set.erase(s_it);
} else {
defects_set.insert(s);
}
}

graph.FlipBoundaryNodesIfNeeded(defects_set);

std::vector<int> defects_vec(defects_set.begin(), defects_set.end());
Expand Down Expand Up @@ -156,11 +166,17 @@ MatchingResult ExactMatching(
}
}

for (auto fid : graph.negative_edge_fault_ids){
if ((fid != -1) && (fid >= 0) && (fid < N)){
(*correction)[fid] = ((*correction)[fid] + 1) % 2;
}
}

auto capsule = py::capsule(correction, [](void *correction) { delete reinterpret_cast<std::vector<int>*>(correction); });
auto corr = py::array_t<int>(correction->size(), correction->data(), capsule);

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
matching_result.weight = -1*pm.matchingWeight() + graph.negative_weight_sum;
} else {
matching_result.weight = -1.0;
}
Expand Down Expand Up @@ -224,6 +240,15 @@ MatchingResult LemonDecodeMatchNeighbourhood(
}
}

for (auto s : graph.negative_edge_syndrome){
auto s_it = defects_set.find(s);
if (s_it != defects_set.end()){
defects_set.erase(s_it);
} else {
defects_set.insert(s);
}
}

graph.FlipBoundaryNodesIfNeeded(defects_set);

std::vector<int> defects_vec(defects_set.begin(), defects_set.end());
Expand Down Expand Up @@ -286,11 +311,18 @@ MatchingResult LemonDecodeMatchNeighbourhood(
}
}
}

for (auto fid : graph.negative_edge_fault_ids){
if ((fid != -1) && (fid >= 0) && (fid < N)){
(*correction)[fid] = ((*correction)[fid] + 1) % 2;
}
}

auto capsule = py::capsule(correction, [](void *correction) { delete reinterpret_cast<std::vector<int>*>(correction); });
auto corr = py::array_t<int>(correction->size(), correction->data(), capsule);

if (return_weight) {
matching_result.weight = -1*pm.matchingWeight();
matching_result.weight = -1*pm.matchingWeight() + graph.negative_weight_sum;
} else {
matching_result.weight = -1.0;
}
Expand Down
4 changes: 0 additions & 4 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,6 @@ def load_from_networkx(self, graph: nx.Graph) -> None:
" (or convertible to a set), not {}".format(fault_ids))
all_fault_ids = all_fault_ids | fault_ids
weight = attr.get("weight", 1) # Default weight is 1 if not provided
if weight < 0:
raise ValueError("Weights cannot be negative.")
e_prob = attr.get("error_probability", -1)
g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)
self.matching_graph = g
Expand Down Expand Up @@ -409,8 +407,6 @@ def load_from_check_matrix(self,

if weights.shape[0] != num_fault_ids:
raise ValueError("Weights array must have num_fault_ids elements")
if np.any(weights < 0.):
raise ValueError("All weights must be non-negative.")

timelike_weights = 1.0 if timelike_weights is None else timelike_weights
if isinstance(timelike_weights, (int, float, np.integer, np.floating)):
Expand Down
47 changes: 40 additions & 7 deletions src/pymatching/matching_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cstdint>
#include <iostream>
#include <limits>
#include <cmath>
#include "rand_gen.h"


Expand All @@ -34,9 +35,11 @@ WeightedEdgeData::WeightedEdgeData(
std::set<int> fault_ids,
double weight,
double error_probability,
bool has_error_probability
bool has_error_probability,
bool weight_is_negative
): fault_ids(fault_ids), weight(weight),
error_probability(error_probability), has_error_probability(has_error_probability) {}
error_probability(error_probability), has_error_probability(has_error_probability),
weight_is_negative(weight_is_negative) {}


std::string set_repr(std::set<int> x) {
Expand Down Expand Up @@ -67,7 +70,8 @@ std::string WeightedEdgeData::repr() const {

MatchingGraph::MatchingGraph()
: all_edges_have_error_probabilities(true),
connected_components_need_updating(true) {
connected_components_need_updating(true),
negative_weight_sum(0.0) {
wgraph_t sgraph = wgraph_t();
this->matching_graph = sgraph;
}
Expand All @@ -78,7 +82,8 @@ MatchingGraph::MatchingGraph(
std::set<int>& boundary)
: all_edges_have_error_probabilities(true),
boundary(boundary),
connected_components_need_updating(true) {
connected_components_need_updating(true),
negative_weight_sum(0.0) {
wgraph_t sgraph = wgraph_t(num_detectors+boundary.size());
this->matching_graph = sgraph;
}
Expand All @@ -104,8 +109,8 @@ void MatchingGraph::AddEdge(
throw std::invalid_argument("This edge already exists in the graph. "
"Parallel edges are not supported.");
}
if (weight < 0){
throw std::invalid_argument("Edge weights must be non-negative");
if (std::signbit(weight)){
HandleNewNegativeWeightEdge(node1, node2, weight, fault_ids);
}
if (!has_error_probability){
all_edges_have_error_probabilities = false;
Expand All @@ -114,16 +119,41 @@ void MatchingGraph::AddEdge(
connected_components_need_updating = true;
WeightedEdgeData data;
data.fault_ids = fault_ids;
data.weight = weight;
data.weight = std::abs(weight);
data.error_probability = error_probability;
data.has_error_probability = has_error_probability;
data.weight_is_negative = std::signbit(weight);
boost::add_edge(
n1,
n2,
data,
matching_graph);
}


void MatchingGraph::HandleNewNegativeWeightEdge(int u, int v, double weight, std::set<int> &fault_ids){
assert(std::signbit(weight));
negative_weight_sum += weight;

for (auto fid : fault_ids){
if (negative_edge_fault_ids.find(fid) != negative_edge_fault_ids.end()){
negative_edge_fault_ids.erase(fid);
} else {
negative_edge_fault_ids.insert(fid);
}
}

for (auto node : {u, v}){
if (negative_edge_syndrome.find(node) != negative_edge_syndrome.end()){
negative_edge_syndrome.erase(node);
} else {
negative_edge_syndrome.insert(node);
}
}

}


void MatchingGraph::ComputeAllPairsShortestPaths(){
int n = boost::num_vertices(matching_graph);
all_distances.clear();
Expand Down Expand Up @@ -445,6 +475,9 @@ std::vector<std::tuple<int,int,WeightedEdgeData>> MatchingGraph::GetEdges() cons
WeightedEdgeData edata = matching_graph[*eit];
int s = boost::source(*eit, matching_graph);
int t = boost::target(*eit, matching_graph);
if (edata.weight_is_negative) {
edata.weight = -1 * edata.weight;
}
std::tuple<int,int,WeightedEdgeData> edge = std::make_tuple(s, t, edata);
edges.push_back(edge);
}
Expand Down
8 changes: 7 additions & 1 deletion src/pymatching/matching_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ struct WeightedEdgeData {
double weight;
double error_probability;
bool has_error_probability;
bool weight_is_negative;
WeightedEdgeData();
WeightedEdgeData(
std::set<int> fault_ids,
double weight,
double error_probability,
bool has_error_probability
bool has_error_probability,
bool weight_is_negative
);
std::string repr() const;
};
Expand Down Expand Up @@ -230,6 +232,10 @@ class MatchingGraph{
bool AllEdgesHaveErrorProbabilities() const;
void FlipBoundaryNodesIfNeeded(std::set<int> &defects);
std::string repr() const;
void HandleNewNegativeWeightEdge(int u, int v, double weight, std::set<int> &fault_ids);
std::set<int> negative_edge_syndrome;
std::set<int> negative_edge_fault_ids;
double negative_weight_sum;
private:
/**
* @brief The indices of the boundary nodes
Expand Down
11 changes: 0 additions & 11 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,6 @@ def test_too_many_checks_per_qubit_raises_value_error():
Matching(H)


def test_negative_weight_raises_value_error():
g = nx.Graph()
g.add_edge(0,1,weight=-1)
with pytest.raises(ValueError):
Matching(g)
with pytest.raises(ValueError):
Matching(csr_matrix([[1,1,0],[0,1,1]]), spacelike_weights=np.array([1,1,-1]))


def test_wrong_check_matrix_type_raises_type_error():
with pytest.raises(TypeError):
Matching("test")
Expand Down Expand Up @@ -138,8 +129,6 @@ def test_weighted_mwpm_from_array():
assert m.matching_graph.distance(1, 2) == 2.
with pytest.raises(ValueError):
m = Matching(H, spacelike_weights=np.array([1.]))
with pytest.raises(ValueError):
m = Matching(H, spacelike_weights=np.array([1., -2.]))


def test_unweighted_stabiliser_graph_from_networkx():
Expand Down
55 changes: 55 additions & 0 deletions tests/test_negative_weghts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import numpy as np
import networkx as nx

from pymatching import Matching


@pytest.mark.parametrize("nn", (None, 30))
def test_negative_weight_repetition_code(nn):
m = Matching()
m.add_edge(0, 1, 0, -1)
m.add_edge(1, 2, 1, -1)
m.add_edge(2, 3, 2, -1)
m.add_edge(3, 4, 3, -1)
m.add_edge(4, 5, 4, -1)
m.add_edge(5, 0, 5, -1)
c, w = m.decode([0, 1, 1, 0, 0, 0], return_weight=True, num_neighbours=nn)
assert np.array_equal(c, np.array([1, 0, 1, 1, 1, 1]))
assert w == -5


@pytest.mark.parametrize("nn", (None, 30))
def test_isolated_negative_weight(nn):
m = Matching()
m.add_edge(0, 1, 0, 1)
m.add_edge(1, 2, 1, -10)
m.add_edge(2, 3, 2, 1)
m.add_edge(3, 0, 3, 1)
c, w = m.decode([0, 1, 1, 0], return_weight=True, num_neighbours=nn)
assert np.array_equal(c, np.array([0, 1, 0, 0]))
assert w == -10


@pytest.mark.parametrize("nn", (None, 30))
def test_negative_and_positive_in_matching(nn):
g = nx.Graph()
g.add_edge(0, 1, fault_ids=0, weight=1)
g.add_edge(1, 2, fault_ids=1, weight=-10)
g.add_edge(2, 3, fault_ids=2, weight=1)
g.add_edge(3, 0, fault_ids=3, weight=1)
m = Matching(g)
c, w = m.decode([0, 1, 0, 1], return_weight=True, num_neighbours=nn)
assert np.array_equal(c, np.array([0, 1, 1, 0]))
assert w == -9


def test_negative_weight_edge_returned():
m = Matching()
m.add_edge(0, 1, weight=0.5, error_probability=0.3)
m.add_edge(1, 2, weight=0.5, error_probability=0.3, fault_ids=0)
m.add_edge(2, 3, weight=-0.5, error_probability=0.7, fault_ids={1, 2})
expected = [(0, 1, {'fault_ids': set(), 'weight': 0.5, 'error_probability': 0.3}),
(1, 2, {'fault_ids': {0}, 'weight': 0.5, 'error_probability': 0.3}),
(2, 3, {'fault_ids': {1, 2}, 'weight': -0.5, 'error_probability': 0.7})]
assert m.edges() == expected

0 comments on commit 7d799d1

Please sign in to comment.