Skip to content

Commit

Permalink
Load from stim files (#44)
Browse files Browse the repository at this point in the history
* Add method to load stim circuits, as well as stim circuits and dems from file paths. Fixes #42

* linting

Co-authored-by: Oscar Higgott <oscarhiggott@users.noreply.github.com>
  • Loading branch information
oscarhiggott and oscarhiggott committed Nov 1, 2022
1 parent 7b3aef5 commit ab2e510
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 12 deletions.
106 changes: 104 additions & 2 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,18 +1154,120 @@ def from_detector_error_model(model: 'stim.DetectorErrorModel') -> 'pymatching.M
Parameters
----------
model
model : stim.DetectorErrorModel
A stim DetectorErrorModel, with all error mechanisms either graphlike, or decomposed into graphlike
error mechanisms
Returns
-------
pymatching.Matching
A `pymatching.Matching` object representing the edge-like fault mechanisms in `model`
A `pymatching.Matching` object representing the graphlike error mechanisms in `model`
Examples
--------
>>> import stim
>>> import pymatching
>>> circuit = stim.Circuit.generated("surface_code:rotated_memory_x",
... distance=5,
... rounds=5,
... after_clifford_depolarization=0.005)
>>> model = circuit.detector_error_model(decompose_errors=True)
>>> matching = pymatching.Matching.from_detector_error_model(model)
>>> matching
<pymatching.Matching object with 120 detectors, 0 boundary nodes, and 502 edges>
"""
m = Matching()
m._load_from_detector_error_model(model)
return m

@staticmethod
def from_detector_error_model_file(dem_path: str) -> 'pymatching.Matching':
"""
Construct a `pymatching.Matching` by loading from a stim DetectorErrorModel file path.
Parameters
----------
dem_path : str
The path of the detector error model file
Returns
-------
pymatching.Matching
A `pymatching.Matching` object representing the graphlike error mechanisms in the stim DetectorErrorModel
in the file `dem_path`
"""
m = Matching()
m._matching_graph = _cpp_pm.detector_error_model_file_to_matching_graph(dem_path)
return m

@staticmethod
def from_stim_circuit(circuit: 'stim.Circuit') -> 'pymatching.Matching':
"""
Constructs a `pymatching.Matching` object by loading from a `stim.Circuit`
Parameters
----------
circuit : stim.Circuit
A stim circuit containing error mechanisms that are all either graphlike, or decomposable into
graphlike error mechanisms
Returns
-------
pymatching.Matching
A `pymatching.Matching` object representing the graphlike error mechanisms in `circuit`, with any hyperedge
error mechanisms decomposed into graphlike error mechanisms. Parallel edges are merged using
`merge_strategy="independent"`.
Examples
--------
>>> import stim
>>> import pymatching
>>> circuit = stim.Circuit.generated("surface_code:rotated_memory_x",
... distance=5,
... rounds=5,
... after_clifford_depolarization=0.005)
>>> matching = pymatching.Matching.from_stim_circuit(circuit)
>>> matching
<pymatching.Matching object with 120 detectors, 0 boundary nodes, and 502 edges>
"""
try:
import stim
except ImportError: # pragma no cover
raise TypeError(
f"`circuit` must be a `stim.Circuit. Instead, got: {type(circuit)}.`"
"The 'stim' package also isn't installed and is required for this method. \n"
"To install stim using pip, run `pip install stim`."
)
if not isinstance(circuit, stim.Circuit):
raise TypeError(f"`circuit` must be a `stim.Circuit`. Instead, got {type(circuit)}")
m = Matching()
m._matching_graph = _cpp_pm.detector_error_model_to_matching_graph(
str(circuit.detector_error_model(decompose_errors=True))
)
return m

@staticmethod
def from_stim_circuit_file(stim_circuit_path: str) -> 'pymatching.Matching':
"""
Construct a `pymatching.Matching` by loading from a stim circuit file path.
Parameters
----------
stim_circuit_path : str
The path of the stim circuit file
Returns
-------
pymatching.Matching
A `pymatching.Matching` object representing the graphlike error mechanisms in the stim circuit
in the file `stim_circuit_path`, with any hyperedge error mechanisms decomposed into graphlike error
mechanisms. Parallel edges are merged using `merge_strategy="independent"`.
"""
m = Matching()
m._matching_graph = _cpp_pm.stim_circuit_file_to_matching_graph(stim_circuit_path)
return m

def _load_from_detector_error_model(self, model: 'stim.DetectorErrorModel') -> None:
try:
import stim
Expand Down
28 changes: 23 additions & 5 deletions src/pymatching/sparse_blossom/driver/user_graph.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_<pm::UserGrap

// Reserve all-zeros predictions array
size_t num_observable_bytes =
bit_packed_predictions ?(self.get_num_observables() + 7) >> 3 : self.get_num_observables();
bit_packed_predictions ? (self.get_num_observables() + 7) >> 3 : self.get_num_observables();
py::array_t<uint8_t> predictions = py::array_t<uint8_t>(shots.shape(0) * num_observable_bytes);
predictions[py::make_tuple(py::ellipsis())] = 0; // Initialise to 0
py::buffer_info buff = predictions.request();
Expand Down Expand Up @@ -287,9 +287,7 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_<pm::UserGrap
pm::total_weight_int solution_weight = 0;
if (bit_packed_predictions) {
std::fill(temp_predictions.begin(), temp_predictions.end(), 0);
pm::decode_detection_events(
mwpm, detection_events, temp_predictions.data(), solution_weight
);
pm::decode_detection_events(mwpm, detection_events, temp_predictions.data(), solution_weight);
// bitpack the predictions
for (size_t k = 0; k < temp_predictions.size(); k++) {
size_t arr_idx = k >> 3;
Expand Down Expand Up @@ -397,7 +395,27 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_<pm::UserGrap
auto dem = stim::DetectorErrorModel(dem_string);
return pm::detector_error_model_to_user_graph(dem);
});

m.def("detector_error_model_file_to_matching_graph", [](const char *dem_path) {
FILE *file = fopen(dem_path, "r");
if (file == nullptr) {
std::stringstream msg;
msg << "Failed to open '" << dem_path << "'";
throw std::invalid_argument(msg.str());
}
auto dem = stim::DetectorErrorModel::from_file(file);
return pm::detector_error_model_to_user_graph(dem);
});
m.def("stim_circuit_file_to_matching_graph", [](const char *stim_circuit_path) {
FILE *file = fopen(stim_circuit_path, "r");
if (file == nullptr) {
std::stringstream msg;
msg << "Failed to open '" << stim_circuit_path << "'";
throw std::invalid_argument(msg.str());
}
auto circuit = stim::Circuit::from_file(file);
auto dem = stim::ErrorAnalyzer::circuit_to_detector_error_model(circuit, true, true, false, 0, false, false);
return pm::detector_error_model_to_user_graph(dem);
});
m.def(
"sparse_column_check_matrix_to_matching_graph",
[](const py::object &check_matrix,
Expand Down
5 changes: 5 additions & 0 deletions tests/matching/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pathlib
import os

THIS_DIR = pathlib.Path(__file__).parent.resolve()
DATA_DIR = os.path.join(pathlib.Path(THIS_DIR).parent.parent.absolute(), "data")
4 changes: 1 addition & 3 deletions tests/matching/decode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import numpy as np
from scipy.sparse import csc_matrix
import pytest
Expand All @@ -22,8 +21,7 @@
import pymatching
from pymatching import Matching

THIS_DIR = pathlib.Path(__file__).parent.resolve()
DATA_DIR = os.path.join(pathlib.Path(THIS_DIR).parent.parent.absolute(), "data")
from .config import DATA_DIR


def repetition_code(n):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from pymatching.matching import Matching
from .config import DATA_DIR


def test_load_from_dem():
def test_load_from_stim_objects():
stim = pytest.importorskip("stim")
c = stim.Circuit.generated("surface_code:rotated_memory_x", distance=5, rounds=5,
after_clifford_depolarization=0.01,
Expand All @@ -33,14 +36,39 @@ def test_load_from_dem():
assert m2.num_detectors == dem.num_detectors
assert m2.num_fault_ids == dem.num_observables
assert m2.num_edges == 502
m3 = Matching.from_stim_circuit(c)
assert m3.num_detectors == dem.num_detectors
assert m3.num_fault_ids == dem.num_observables
assert m3.num_edges == 502


def test_load_from_stim_files():
circuit_path = os.path.join(DATA_DIR, "negative_weight_circuit.stim")
m = Matching.from_stim_circuit_file(circuit_path)
assert m.num_detectors == 2
assert m.num_edges == 2
assert m.num_fault_ids == 1
dem_path = os.path.join(DATA_DIR, "negative_weight_circuit.dem")
m2 = Matching.from_detector_error_model_file(dem_path)
assert m2.edges() == m.edges()
with pytest.raises(ValueError):
Matching.from_stim_circuit_file("fake_filename.stim")
with pytest.raises(ValueError):
Matching.from_detector_error_model_file("fake_filename.dem")
with pytest.raises(ValueError):
Matching.from_stim_circuit_file(dem_path)
with pytest.raises(IndexError):
Matching.from_detector_error_model_file(circuit_path)


def test_load_from_dem_wrong_type_raises_type_error():
def test_load_from_stim_wrong_type_raises_type_error():
stim = pytest.importorskip("stim")
c = stim.Circuit.generated("surface_code:rotated_memory_x", distance=3, rounds=1,
after_clifford_depolarization=0.01)
with pytest.raises(TypeError):
Matching.from_detector_error_model(c)
with pytest.raises(TypeError):
Matching.from_stim_circuit(c.detector_error_model(decompose_errors=True))


def test_load_from_dem_without_stim_raises_type_error():
Expand All @@ -49,3 +77,5 @@ def test_load_from_dem_without_stim_raises_type_error():
except ImportError:
with pytest.raises(TypeError):
Matching.from_detector_error_model("test")
with pytest.raises(TypeError):
Matching.from_stim_circuit("test")

0 comments on commit ab2e510

Please sign in to comment.