From ebab8832756d6ae68fcb0b3c217bc9aa557fef29 Mon Sep 17 00:00:00 2001 From: Laleh Beni Date: Wed, 18 Jun 2025 13:50:46 -0700 Subject: [PATCH 1/3] update the format --- .github/workflows/ci.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0e06251..c22f167 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,5 +35,21 @@ jobs: disk-cache: ${{ github.workflow }} repository-cache: true + - name: Install clang-format (Linux) + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install -y clang-format + + - name: Install clang-format (macOS) + if: runner.os == 'macOS' + run: | + brew install clang-format + + - name: Check formatting + run: | + clang-format --version + files=$(git ls-files '*.cc' '*.h' '*.cpp' '*.hpp') + [ -z "$files" ] || clang-format --dry-run --Werror $files - name: Bazel tests run: bazel test src:all From e7f89e4300a20aca923775f59e3ac077566136c2 Mon Sep 17 00:00:00 2001 From: Laleh Beni Date: Wed, 18 Jun 2025 14:27:42 -0700 Subject: [PATCH 2/3] update the clang format --- .clang-format | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.clang-format b/.clang-format index 81413c1..6a6633f 100644 --- a/.clang-format +++ b/.clang-format @@ -1,7 +1,6 @@ BasedOnStyle: Google IndentWidth: 2 -ColumnLimit: 80 +ColumnLimit: 100 AllowShortFunctionsOnASingleLine: Empty PointerAlignment: Left -Standard: C++20 SortIncludes: true From e1b3ff480934c82a9d8b66f825a2c97b4c1d1918 Mon Sep 17 00:00:00 2001 From: Laleh Beni Date: Wed, 18 Jun 2025 14:35:55 -0700 Subject: [PATCH 3/3] fix the format --- src/common.cc | 42 +++----- src/common.h | 19 ++-- src/common.pybind.h | 24 ++--- src/common.test.cc | 28 ++---- src/simplex.cc | 75 ++++++-------- src/simplex.h | 7 +- src/simplex.pybind.h | 26 ++--- src/simplex_main.cc | 149 +++++++++++----------------- src/tesseract.cc | 122 ++++++++++------------- src/tesseract.h | 12 +-- src/tesseract.perf.cc | 47 ++++----- src/tesseract.pybind.h | 46 ++++----- src/tesseract.test.cc | 50 ++++------ src/tesseract_main.cc | 216 +++++++++++++++++------------------------ src/test_data.h | 8 +- src/utils.cc | 30 ++---- src/utils.h | 15 +-- 17 files changed, 364 insertions(+), 552 deletions(-) diff --git a/src/common.cc b/src/common.cc index fb5439d..cc97a8a 100644 --- a/src/common.cc +++ b/src/common.cc @@ -52,12 +52,10 @@ common::Error::Error(const stim::DemInstruction& error) { } std::string common::Error::str() { - return "Error{cost=" + std::to_string(likelihood_cost) + - ", symptom=" + symptom.str() + "}"; + return "Error{cost=" + std::to_string(likelihood_cost) + ", symptom=" + symptom.str() + "}"; } -std::vector common::Symptom::as_dem_instruction_targets() - const { +std::vector common::Symptom::as_dem_instruction_targets() const { std::vector targets; for (int d : detectors) { targets.push_back(stim::DemTarget::relative_detector_id(d)); @@ -72,8 +70,7 @@ std::vector common::Symptom::as_dem_instruction_targets() return targets; } -stim::DetectorErrorModel common::merge_identical_errors( - const stim::DetectorErrorModel& dem) { +stim::DetectorErrorModel common::merge_identical_errors(const stim::DetectorErrorModel& dem) { stim::DetectorErrorModel out_dem; // Map to track the distinct symptoms @@ -89,11 +86,9 @@ stim::DetectorErrorModel common::merge_identical_errors( // Merge with existing error with the same symptom (if applicable) if (errors_by_symptom.find(error.symptom) != errors_by_symptom.end()) { double p0 = errors_by_symptom[error.symptom].probability; - error.probability = - p0 * (1 - error.probability) + (1 - p0) * error.probability; + error.probability = p0 * (1 - error.probability) + (1 - p0) * error.probability; } - error.likelihood_cost = - -1 * std::log(error.probability / (1 - error.probability)); + error.likelihood_cost = -1 * std::log(error.probability / (1 - error.probability)); errors_by_symptom[error.symptom] = error; break; } @@ -106,9 +101,9 @@ stim::DetectorErrorModel common::merge_identical_errors( } } for (const auto& it : errors_by_symptom) { - out_dem.append_error_instruction( - it.second.probability, it.second.symptom.as_dem_instruction_targets(), - /*tag=*/""); + out_dem.append_error_instruction(it.second.probability, + it.second.symptom.as_dem_instruction_targets(), + /*tag=*/""); } return out_dem; } @@ -136,19 +131,17 @@ stim::DetectorErrorModel common::remove_zero_probability_errors( return out_dem; } -stim::DetectorErrorModel common::dem_from_counts( - stim::DetectorErrorModel& orig_dem, const std::vector& error_counts, - size_t num_shots) { +stim::DetectorErrorModel common::dem_from_counts(stim::DetectorErrorModel& orig_dem, + const std::vector& error_counts, + size_t num_shots) { if (orig_dem.count_errors() != error_counts.size()) { throw std::invalid_argument( "Error hits array must be the same size as the number of errors in the " "original DEM."); } - for (const stim::DemInstruction& instruction : - orig_dem.flattened().instructions) { - if (instruction.type == stim::DemInstructionType::DEM_ERROR && - instruction.arg_data[0] == 0) { + for (const stim::DemInstruction& instruction : orig_dem.flattened().instructions) { + if (instruction.type == stim::DemInstructionType::DEM_ERROR && instruction.arg_data[0] == 0) { throw std::invalid_argument( "dem_from_counts requires DEMs without zero-probability errors. Use" " remove_zero_probability_errors first."); @@ -157,17 +150,14 @@ stim::DetectorErrorModel common::dem_from_counts( stim::DetectorErrorModel out_dem; size_t ei = 0; - for (const stim::DemInstruction& instruction : - orig_dem.flattened().instructions) { + for (const stim::DemInstruction& instruction : orig_dem.flattened().instructions) { switch (instruction.type) { case stim::DemInstructionType::DEM_SHIFT_DETECTORS: assert(false && "unreachable"); break; case stim::DemInstructionType::DEM_ERROR: { - double est_probability = - double(error_counts.at(ei)) / double(num_shots); - out_dem.append_error_instruction(est_probability, - instruction.target_data, /*tag=*/""); + double est_probability = double(error_counts.at(ei)) / double(num_shots); + out_dem.append_error_instruction(est_probability, instruction.target_data, /*tag=*/""); ++ei; break; } diff --git a/src/common.h b/src/common.h index 754ec28..2395d96 100644 --- a/src/common.h +++ b/src/common.h @@ -51,11 +51,9 @@ struct Error { Symptom symptom; std::vector dets_array; Error() = default; - Error(double likelihood_cost, std::vector& detectors, - ObservablesMask observables, std::vector& dets_array) - : likelihood_cost(likelihood_cost), - symptom{detectors, observables}, - dets_array(dets_array) {} + Error(double likelihood_cost, std::vector& detectors, ObservablesMask observables, + std::vector& dets_array) + : likelihood_cost(likelihood_cost), symptom{detectors, observables}, dets_array(dets_array) {} Error(double likelihood_cost, double probability, std::vector& detectors, ObservablesMask observables, std::vector& dets_array) : likelihood_cost(likelihood_cost), @@ -68,21 +66,18 @@ struct Error { // Makes a new (flattened) dem where identical error mechanisms have been // merged. -stim::DetectorErrorModel merge_identical_errors( - const stim::DetectorErrorModel& dem); +stim::DetectorErrorModel merge_identical_errors(const stim::DetectorErrorModel& dem); // Returns a copy of the given error model with any zero-probability DEM_ERROR // instructions removed. -stim::DetectorErrorModel remove_zero_probability_errors( - const stim::DetectorErrorModel& dem); +stim::DetectorErrorModel remove_zero_probability_errors(const stim::DetectorErrorModel& dem); // Makes a new dem where the probabilities of errors are estimated from the // fraction of shots they were used in. // Throws std::invalid_argument if `orig_dem` contains zero-probability errors; // call remove_zero_probability_errors first. -stim::DetectorErrorModel dem_from_counts( - stim::DetectorErrorModel& orig_dem, const std::vector& error_counts, - size_t num_shots); +stim::DetectorErrorModel dem_from_counts(stim::DetectorErrorModel& orig_dem, + const std::vector& error_counts, size_t num_shots); } // namespace common diff --git a/src/common.pybind.h b/src/common.pybind.h index 04c4e40..5facc9a 100644 --- a/src/common.pybind.h +++ b/src/common.pybind.h @@ -32,8 +32,7 @@ void add_common_module(py::module &root) { py::class_(m, "Symptom") .def(py::init, common::ObservablesMask>(), - py::arg("detectors") = std::vector(), - py::arg("observables") = 0) + py::arg("detectors") = std::vector(), py::arg("observables") = 0) .def_readwrite("detectors", &common::Symptom::detectors) .def_readwrite("observables", &common::Symptom::observables) .def("__str__", &common::Symptom::str) @@ -51,25 +50,22 @@ void add_common_module(py::module &root) { .def_readwrite("symptom", &common::Error::symptom) .def("__str__", &common::Error::str) .def(py::init<>()) - .def(py::init &, common::ObservablesMask, - std::vector &>(), - py::arg("likelihood_cost"), py::arg("detectors"), - py::arg("observables"), py::arg("dets_array")) + .def(py::init &, common::ObservablesMask, std::vector &>(), + py::arg("likelihood_cost"), py::arg("detectors"), py::arg("observables"), + py::arg("dets_array")) .def(py::init &, common::ObservablesMask, std::vector &>(), - py::arg("likelihood_cost"), py::arg("probability"), - py::arg("detectors"), py::arg("observables"), py::arg("dets_array")) + py::arg("likelihood_cost"), py::arg("probability"), py::arg("detectors"), + py::arg("observables"), py::arg("dets_array")) .def(py::init([](stim_pybind::ExposedDemInstruction edi) { return new common::Error(edi.as_dem_instruction()); }), py::arg("error")); - m.def("merge_identical_errors", &common::merge_identical_errors, - py::arg("dem")); - m.def("remove_zero_probability_errors", - &common::remove_zero_probability_errors, py::arg("dem")); - m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"), - py::arg("error_counts"), py::arg("num_shots")); + m.def("merge_identical_errors", &common::merge_identical_errors, py::arg("dem")); + m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors, py::arg("dem")); + m.def("dem_from_counts", &common::dem_from_counts, py::arg("orig_dem"), py::arg("error_counts"), + py::arg("num_shots")); } #endif diff --git a/src/common.test.cc b/src/common.test.cc index a3a92df..349c013 100644 --- a/src/common.test.cc +++ b/src/common.test.cc @@ -38,9 +38,7 @@ TEST(common, DemFromCountsRejectsZeroProbabilityErrors) { std::vector counts{1, 7, 4}; size_t num_shots = 10; - EXPECT_THROW({ - common::dem_from_counts(dem, counts, num_shots); - }, std::invalid_argument); + EXPECT_THROW({ common::dem_from_counts(dem, counts, num_shots); }, std::invalid_argument); stim::DetectorErrorModel cleaned = common::remove_zero_probability_errors(dem); stim::DetectorErrorModel out_dem = @@ -50,11 +48,9 @@ TEST(common, DemFromCountsRejectsZeroProbabilityErrors) { ASSERT_EQ(out_dem.count_errors(), 2); ASSERT_GE(flat.instructions.size(), 2); - EXPECT_EQ(flat.instructions[0].type, - stim::DemInstructionType::DEM_ERROR); + EXPECT_EQ(flat.instructions[0].type, stim::DemInstructionType::DEM_ERROR); EXPECT_NEAR(flat.instructions[0].arg_data[0], 0.1, 1e-9); - ASSERT_EQ(flat.instructions[1].type, - stim::DemInstructionType::DEM_ERROR); + ASSERT_EQ(flat.instructions[1].type, stim::DemInstructionType::DEM_ERROR); EXPECT_NEAR(flat.instructions[1].arg_data[0], 0.4, 1e-9); } @@ -68,18 +64,15 @@ TEST(common, DemFromCountsSimpleTwoErrors) { std::vector counts{5, 7}; size_t num_shots = 20; - stim::DetectorErrorModel out_dem = - common::dem_from_counts(dem, counts, num_shots); + stim::DetectorErrorModel out_dem = common::dem_from_counts(dem, counts, num_shots); auto flat = out_dem.flattened(); ASSERT_EQ(out_dem.count_errors(), 2); ASSERT_GE(flat.instructions.size(), 2); - EXPECT_EQ(flat.instructions[0].type, - stim::DemInstructionType::DEM_ERROR); + EXPECT_EQ(flat.instructions[0].type, stim::DemInstructionType::DEM_ERROR); EXPECT_NEAR(flat.instructions[0].arg_data[0], 0.25, 1e-9); - EXPECT_EQ(flat.instructions[1].type, - stim::DemInstructionType::DEM_ERROR); + EXPECT_EQ(flat.instructions[1].type, stim::DemInstructionType::DEM_ERROR); EXPECT_NEAR(flat.instructions[1].arg_data[0], 0.35, 1e-9); } @@ -93,15 +86,12 @@ TEST(common, RemoveZeroProbabilityErrors) { detector(0, 0, 0) D2 )DEM"); - stim::DetectorErrorModel cleaned = - common::remove_zero_probability_errors(dem); + stim::DetectorErrorModel cleaned = common::remove_zero_probability_errors(dem); EXPECT_EQ(cleaned.count_errors(), 2); auto flat = cleaned.flattened(); - ASSERT_EQ(flat.instructions[0].type, - stim::DemInstructionType::DEM_ERROR); + ASSERT_EQ(flat.instructions[0].type, stim::DemInstructionType::DEM_ERROR); EXPECT_NEAR(flat.instructions[0].arg_data[0], 0.1, 1e-9); - ASSERT_EQ(flat.instructions[1].type, - stim::DemInstructionType::DEM_ERROR); + ASSERT_EQ(flat.instructions[1].type, stim::DemInstructionType::DEM_ERROR); EXPECT_NEAR(flat.instructions[1].arg_data[0], 0.2, 1e-9); } diff --git a/src/simplex.cc b/src/simplex.cc index cc361dd..c8c5f3b 100644 --- a/src/simplex.cc +++ b/src/simplex.cc @@ -22,7 +22,7 @@ constexpr size_t T_COORD = 2; std::string SimplexConfig::str() { - auto & self = *this; + auto& self = *this; std::stringstream ss; ss << "SimplexConfig("; ss << "dem=" << "DetectorErrorModel_Object" << ", "; @@ -35,8 +35,7 @@ std::string SimplexConfig::str() { SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) { config.dem = common::remove_zero_probability_errors(config.dem); std::vector detector_t_coords(config.dem.count_detectors()); - for (const stim::DemInstruction& instruction : - config.dem.flattened().instructions) { + for (const stim::DemInstruction& instruction : config.dem.flattened().instructions) { switch (instruction.type) { case stim::DemInstructionType::DEM_SHIFT_DETECTORS: assert(false && "unreachable"); @@ -47,15 +46,13 @@ SimplexDecoder::SimplexDecoder(SimplexConfig _config) : config(_config) { break; } case stim::DemInstructionType::DEM_DETECTOR: - detector_t_coords[instruction.target_data[0].val()] = - instruction.arg_data[T_COORD]; + detector_t_coords[instruction.target_data[0].val()] = instruction.arg_data[T_COORD]; break; default: assert(false && "unreachable"); } } - std::map> start_time_to_errors_map, - end_time_to_errors_map; + std::map> start_time_to_errors_map, end_time_to_errors_map; std::set times; for (size_t ei = 0; ei < errors.size(); ++ei) { double min_error_time = std::numeric_limits::max(); @@ -195,8 +192,7 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { bool solution_empty = true; while (t1 < start_time_to_errors.size() or solution_empty) { - for (size_t step = 0; step < config.window_slide_length && - t1 < start_time_to_errors.size(); + for (size_t step = 0; step < config.window_slide_length && t1 < start_time_to_errors.size(); ++step) { add_costs_for_time(t1); ++t1; @@ -208,8 +204,8 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { // Pass the model to HiGHS *return_status = highs->passModel(*model); if (*return_status != HighsStatus::kOk) { - std::cerr << "Error: passModel failed with status: " - << highsStatusToString(*return_status) << std::endl; + std::cerr << "Error: passModel failed with status: " << highsStatusToString(*return_status) + << std::endl; } assert(*return_status == HighsStatus::kOk); @@ -226,14 +222,14 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { // Solve the model *return_status = highs->run(); if (*return_status != HighsStatus::kOk) { - std::cerr << "Error: run failed with status: " - << highsStatusToString(*return_status) << std::endl; + std::cerr << "Error: run failed with status: " << highsStatusToString(*return_status) + << std::endl; // Write out the model in mps format for debugging HighsStatus write_return_status = writeModelAsMps(highs->getOptions(), "bad_shot.mps", *model, /*free_format=*/true); - std::cerr << "Write return had status: " - << highsStatusToString(write_return_status) << std::endl; + std::cerr << "Write return had status: " << highsStatusToString(write_return_status) + << std::endl; assert(write_return_status == HighsStatus::kOk or write_return_status == HighsStatus::kWarning); } @@ -242,19 +238,13 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { if (config.verbose) { // Get the solution information const HighsInfo& info = highs->getInfo(); - std::cout << "Simplex iteration count: " << info.simplex_iteration_count - << std::endl; - std::cout << "Objective function value: " - << info.objective_function_value << std::endl; + std::cout << "Simplex iteration count: " << info.simplex_iteration_count << std::endl; + std::cout << "Objective function value: " << info.objective_function_value << std::endl; std::cout << "Primal solution status: " - << highs->solutionStatusToString(info.primal_solution_status) - << std::endl; + << highs->solutionStatusToString(info.primal_solution_status) << std::endl; std::cout << "Dual solution status: " - << highs->solutionStatusToString(info.dual_solution_status) - << std::endl; - std::cout << "Basis: " - << highs->basisValidityToString(info.basis_validity) - << std::endl; + << highs->solutionStatusToString(info.dual_solution_status) << std::endl; + std::cout << "Basis: " << highs->basisValidityToString(info.basis_validity) << std::endl; } // Get the model status @@ -270,8 +260,7 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { assert(!solution.hasUndefined()); solution_empty = false; - for (size_t step = 0; - step < config.window_slide_length && t0 < end_time_to_errors.size(); + for (size_t step = 0; step < config.window_slide_length && t0 < end_time_to_errors.size(); ++step) { // Freeze all errors at time slice t0 to their current values, and // increment t0 @@ -300,24 +289,17 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { if (config.verbose) { // Get the solution information const HighsInfo& info = highs->getInfo(); - std::cout << "Simplex iteration count: " << info.simplex_iteration_count - << std::endl; - std::cout << "Objective function value: " << info.objective_function_value - << std::endl; + std::cout << "Simplex iteration count: " << info.simplex_iteration_count << std::endl; + std::cout << "Objective function value: " << info.objective_function_value << std::endl; std::cout << "Primal solution status: " - << highs->solutionStatusToString(info.primal_solution_status) - << std::endl; + << highs->solutionStatusToString(info.primal_solution_status) << std::endl; std::cout << "Dual solution status: " - << highs->solutionStatusToString(info.dual_solution_status) - << std::endl; - std::cout << "Basis: " - << highs->basisValidityToString(info.basis_validity) - << std::endl; + << highs->solutionStatusToString(info.dual_solution_status) << std::endl; + std::cout << "Basis: " << highs->basisValidityToString(info.basis_validity) << std::endl; } // Get the model status - [[maybe_unused]] const HighsModelStatus& model_status = - highs->getModelStatus(); + [[maybe_unused]] const HighsModelStatus& model_status = highs->getModelStatus(); assert(model_status == HighsModelStatus::kOptimal); } @@ -335,8 +317,7 @@ void SimplexDecoder::decode_to_errors(const std::vector& detections) { } } -double SimplexDecoder::cost_from_errors( - const std::vector& predicted_errors) { +double SimplexDecoder::cost_from_errors(const std::vector& predicted_errors) { double total_cost = 0; // Iterate over all errors and add to the mask for (size_t ei : predicted_errors_buffer) { @@ -355,15 +336,13 @@ common::ObservablesMask SimplexDecoder::mask_from_errors( return mask; } -common::ObservablesMask SimplexDecoder::decode( - const std::vector& detections) { +common::ObservablesMask SimplexDecoder::decode(const std::vector& detections) { decode_to_errors(detections); return mask_from_errors(predicted_errors_buffer); } -void SimplexDecoder::decode_shots( - std::vector& shots, - std::vector& obs_predicted) { +void SimplexDecoder::decode_shots(std::vector& shots, + std::vector& obs_predicted) { obs_predicted.resize(shots.size()); for (size_t i = 0; i < shots.size(); ++i) { obs_predicted[i] = decode(shots[i].hits); diff --git a/src/simplex.h b/src/simplex.h index 91402dc..bdbd13b 100644 --- a/src/simplex.h +++ b/src/simplex.h @@ -29,7 +29,9 @@ struct SimplexConfig { size_t window_length = 0; size_t window_slide_length = 0; bool verbose = false; - bool windowing_enabled() { return (window_length != 0); } + bool windowing_enabled() { + return (window_length != 0); + } std::string str(); }; @@ -60,8 +62,7 @@ struct SimplexDecoder { void decode_to_errors(const std::vector& detections); // Returns the bitwise XOR of all the observables bitmasks of all errors in // the predicted errors buffer. - common::ObservablesMask mask_from_errors( - const std::vector& predicted_errors); + common::ObservablesMask mask_from_errors(const std::vector& predicted_errors); // Returns the sum of the likelihood costs (minus-log-likelihood-ratios) of // all errors in the predicted errors buffer. double cost_from_errors(const std::vector& predicted_errors); diff --git a/src/simplex.pybind.h b/src/simplex.pybind.h index a69e968..32422c3 100644 --- a/src/simplex.pybind.h +++ b/src/simplex.pybind.h @@ -25,14 +25,13 @@ namespace py = pybind11; void add_simplex_module(py::module &root) { - auto m = root.def_submodule( - "simplex", "Module containing the SimplexDecoder and related methods"); + auto m = + root.def_submodule("simplex", "Module containing the SimplexDecoder and related methods"); py::class_(m, "SimplexConfig") - .def(py::init(), - py::arg("dem"), py::arg("parallelize") = false, - py::arg("window_length") = 0, py::arg("window_slide_length") = 0, - py::arg("verbose") = false) + .def(py::init(), py::arg("dem"), + py::arg("parallelize") = false, py::arg("window_length") = 0, + py::arg("window_slide_length") = 0, py::arg("verbose") = false) .def_readwrite("dem", &SimplexConfig::dem) .def_readwrite("parallelize", &SimplexConfig::parallelize) .def_readwrite("window_length", &SimplexConfig::window_length) @@ -47,20 +46,15 @@ void add_simplex_module(py::module &root) { .def_readwrite("errors", &SimplexDecoder::errors) .def_readwrite("num_detectors", &SimplexDecoder::num_detectors) .def_readwrite("num_observables", &SimplexDecoder::num_observables) - .def_readwrite("predicted_errors_buffer", - &SimplexDecoder::predicted_errors_buffer) + .def_readwrite("predicted_errors_buffer", &SimplexDecoder::predicted_errors_buffer) .def_readwrite("error_masks", &SimplexDecoder::error_masks) - .def_readwrite("start_time_to_errors", - &SimplexDecoder::start_time_to_errors) + .def_readwrite("start_time_to_errors", &SimplexDecoder::start_time_to_errors) .def_readwrite("end_time_to_errors", &SimplexDecoder::end_time_to_errors) .def_readonly("low_confidence_flag", &SimplexDecoder::low_confidence_flag) .def("init_ilp", &SimplexDecoder::init_ilp) - .def("decode_to_errors", &SimplexDecoder::decode_to_errors, - py::arg("detections")) - .def("mask_from_errors", &SimplexDecoder::mask_from_errors, - py::arg("predicted_errors")) - .def("cost_from_errors", &SimplexDecoder::cost_from_errors, - py::arg("predicted_errors")) + .def("decode_to_errors", &SimplexDecoder::decode_to_errors, py::arg("detections")) + .def("mask_from_errors", &SimplexDecoder::mask_from_errors, py::arg("predicted_errors")) + .def("cost_from_errors", &SimplexDecoder::cost_from_errors, py::arg("predicted_errors")) .def("decode", &SimplexDecoder::decode, py::arg("detections")); } #endif diff --git a/src/simplex_main.cc b/src/simplex_main.cc index e4cb9eb..9b4e95d 100644 --- a/src/simplex_main.cc +++ b/src/simplex_main.cc @@ -74,14 +74,12 @@ struct Args { bool print_stats = false; bool has_observables() { - return append_observables || !obs_in_fname.empty() || - (sample_num_shots > 0); + return append_observables || !obs_in_fname.empty() || (sample_num_shots > 0); } void validate() { if (circuit_path.empty() and dem_path.empty()) { - throw std::invalid_argument( - "Must provide at least one of --circuit or --dem"); + throw std::invalid_argument("Must provide at least one of --circuit or --dem"); } int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty()); @@ -89,23 +87,18 @@ struct Args { throw std::invalid_argument("Requires exactly 1 source of shots."); } if (!in_fname.empty() and in_format.empty()) { - throw std::invalid_argument( - "If --in is provided, must also specify --in-format."); + throw std::invalid_argument("If --in is provided, must also specify --in-format."); } if (!out_fname.empty() and out_format.empty()) { - throw std::invalid_argument( - "If --out is provided, must also specify --out-format."); + throw std::invalid_argument("If --out is provided, must also specify --out-format."); } - if (!in_format.empty() && - !stim::format_name_to_enum_map().contains(in_format)) { + if (!in_format.empty() && !stim::format_name_to_enum_map().contains(in_format)) { throw std::invalid_argument("Invalid format: " + in_format); } - if (!obs_in_format.empty() && - !stim::format_name_to_enum_map().contains(obs_in_format)) { + if (!obs_in_format.empty() && !stim::format_name_to_enum_map().contains(obs_in_format)) { throw std::invalid_argument("Invalid format: " + obs_in_format); } - if (!out_format.empty() && - !stim::format_name_to_enum_map().contains(out_format)) { + if (!out_format.empty() && !stim::format_name_to_enum_map().contains(out_format)) { throw std::invalid_argument("Invalid format: " + out_format); } if (!obs_in_fname.empty() and in_fname.empty()) { @@ -122,8 +115,7 @@ struct Args { } if (shot_range_begin or shot_range_end) { if (shot_range_end < shot_range_begin) { - throw std::invalid_argument( - "Provided shot range must have end >= begin."); + throw std::invalid_argument("Provided shot range must have end >= begin."); } } if ((window_length != 0) != (window_slide_length != 0)) { @@ -132,8 +124,7 @@ struct Args { "length > 0 is provided."); } if (window_slide_length > window_length) { - throw std::invalid_argument( - "Must have window_slide_length <= window_length"); + throw std::invalid_argument("Must have window_slide_length <= window_length"); } if (sample_num_shots > 0 and circuit_path.empty()) { throw std::invalid_argument("Cannot sample shots without a circuit."); @@ -181,8 +172,8 @@ struct Args { assert(!circuit_path.empty()); std::mt19937_64 rng(sample_seed); size_t num_detectors = circuit.count_detectors(); - const auto [dets, obs] = stim::sample_batch_detection_events<64>( - circuit, sample_num_shots, rng); + const auto [dets, obs] = + stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); stim::simd_bit_table<64> obs_T = obs.transposed(); shots.resize(sample_num_shots); for (size_t k = 0; k < sample_num_shots; k++) { @@ -201,8 +192,7 @@ struct Args { if (!shots_file) { throw std::invalid_argument("Could not open the file: " + in_fname); } - stim::FileFormatData shots_in_format = - stim::format_name_to_enum_map().at(in_format); + stim::FileFormatData shots_in_format = stim::format_name_to_enum_map().at(in_format); auto reader = stim::MeasureRecordReader::make( shots_file, shots_in_format.id, 0, config.dem.count_detectors(), append_observables * config.dem.count_observables()); @@ -223,12 +213,9 @@ struct Args { if (!obs_file) { throw std::invalid_argument("Could not open the file: " + obs_in_fname); } - stim::FileFormatData shots_obs_in_format = - stim::format_name_to_enum_map().at(obs_in_format); - auto obs_reader = - stim::MeasureRecordReader::make( - obs_file, shots_obs_in_format.id, 0, 0, - config.dem.count_observables()); + stim::FileFormatData shots_obs_in_format = stim::format_name_to_enum_map().at(obs_in_format); + auto obs_reader = stim::MeasureRecordReader::make( + obs_file, shots_obs_in_format.id, 0, 0, config.dem.count_observables()); stim::SparseShot sparse_shot; sparse_shot.clear(); size_t num_obs_shots = 0; @@ -250,24 +237,21 @@ struct Args { if (shot_range_begin or shot_range_end) { assert(shot_range_end >= shot_range_begin); if (shot_range_end > shots.size()) { - throw std::invalid_argument( - "Shot range end is past end of shots array."); + throw std::invalid_argument("Shot range end is past end of shots array."); } - std::vector shots_in_range( - shots.begin() + shot_range_begin, shots.begin() + shot_range_end); + std::vector shots_in_range(shots.begin() + shot_range_begin, + shots.begin() + shot_range_end); std::swap(shots_in_range, shots); } if (!out_fname.empty()) { // Create a writer instance to write the predicted obs to a file - stim::FileFormatData predictions_out_format = - stim::format_name_to_enum_map().at(out_format); + stim::FileFormatData predictions_out_format = stim::format_name_to_enum_map().at(out_format); FILE* predictions_file = stdout; if (out_fname != "-") { predictions_file = fopen(out_fname.c_str(), "w"); } - writer = stim::MeasureRecordWriter::make(predictions_file, - predictions_out_format.id); + writer = stim::MeasureRecordWriter::make(predictions_file, predictions_out_format.id); writer->begin_result_type('L'); // TODO: ensure the fclose happens after all predictions are written to // the writer. @@ -284,12 +268,8 @@ int main(int argc, char* argv[]) { std::cout.precision(16); argparse::ArgumentParser program("simplex"); Args args; - program.add_argument("--circuit") - .help("Stim circuit file path") - .store_into(args.circuit_path); - program.add_argument("--dem") - .help("Stim dem file path") - .store_into(args.dem_path); + program.add_argument("--circuit").help("Stim circuit file path").store_into(args.circuit_path); + program.add_argument("--dem").help("Stim dem file path").store_into(args.dem_path); program.add_argument("--no-merge-errors") .help("If provided, will not merge identical error mechanisms.") .store_into(args.no_merge_errors); @@ -332,8 +312,7 @@ int main(int argc, char* argv[]) { .default_value(size_t(0)) .store_into(args.shot_range_end); program.add_argument("--in") - .help( - "File to read detection events (and possibly observable flips) from") + .help("File to read detection events (and possibly observable flips) from") .metavar("filename") .default_value(std::string("")) .store_into(args.in_fname); @@ -345,14 +324,11 @@ int main(int argc, char* argv[]) { in_formats += key; } program.add_argument("--in-format", "--in_format") - .help("Format of the file to read detection events from (" + in_formats + - ")") + .help("Format of the file to read detection events from (" + in_formats + ")") .metavar(in_formats) .default_value(std::string("")) .store_into(args.in_format); - program - .add_argument("--in-includes-appended-observables", - "--in_includes_appended_observables") + program.add_argument("--in-includes-appended-observables", "--in_includes_appended_observables") .help( "If present, assumes that the observable flips are appended to the " "end of each shot.") @@ -375,8 +351,7 @@ int main(int argc, char* argv[]) { .default_value(std::string("")) .store_into(args.out_fname); program.add_argument("--out-format") - .help("Format of the file to write observable flip predictions to (" + - in_formats + ")") + .help("Format of the file to write observable flip predictions to (" + in_formats + ")") .metavar(in_formats) .default_value(std::string("")) .store_into(args.out_format); @@ -456,42 +431,38 @@ int main(int argc, char* argv[]) { // After this value returns to 0, we know that no further shots will // transition to finished. ++num_worker_threads_active; - decoder_threads.push_back(std::thread( - [&config, &next_unclaimed_shot, &shots, &obs_predicted, &cost_predicted, - &decoding_time_seconds, &finished, &error_use_totals, &has_obs, - &worker_threads_please_terminate, &num_worker_threads_active]() { - SimplexDecoder decoder(config); - std::vector error_use(config.dem.count_errors()); - for (size_t shot; !worker_threads_please_terminate and - ((shot = next_unclaimed_shot++) < shots.size());) { - auto start_time = std::chrono::high_resolution_clock::now(); - decoder.decode_to_errors(shots[shot].hits); - auto stop_time = std::chrono::high_resolution_clock::now(); - decoding_time_seconds[shot] = - std::chrono::duration_cast( - stop_time - start_time) - .count() / - 1e6; - obs_predicted[shot] = - decoder.mask_from_errors(decoder.predicted_errors_buffer); - cost_predicted[shot] = - decoder.cost_from_errors(decoder.predicted_errors_buffer); - if (!has_obs or - shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { - // Only count the error uses for shots that did not have a logical - // error, if we know the obs flips. - for (size_t ei : decoder.predicted_errors_buffer) { - ++error_use[ei]; - } - } - finished[shot] = true; + decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted, + &cost_predicted, &decoding_time_seconds, &finished, + &error_use_totals, &has_obs, + &worker_threads_please_terminate, + &num_worker_threads_active]() { + SimplexDecoder decoder(config); + std::vector error_use(config.dem.count_errors()); + for (size_t shot; + !worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) { + auto start_time = std::chrono::high_resolution_clock::now(); + decoder.decode_to_errors(shots[shot].hits); + auto stop_time = std::chrono::high_resolution_clock::now(); + decoding_time_seconds[shot] = + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + obs_predicted[shot] = decoder.mask_from_errors(decoder.predicted_errors_buffer); + cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { + // Only count the error uses for shots that did not have a logical + // error, if we know the obs flips. + for (size_t ei : decoder.predicted_errors_buffer) { + ++error_use[ei]; } - // Add the error counts to the total - for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) { - error_use_totals[ei] += error_use[ei]; - } - --num_worker_threads_active; - })); + } + finished[shot] = true; + } + // Add the error counts to the total + for (size_t ei = 0; ei < config.dem.count_errors(); ++ei) { + error_use_totals[ei] += error_use[ei]; + } + --num_worker_threads_active; + })); } size_t num_errors = 0; double total_time_seconds = 0; @@ -522,8 +493,7 @@ int main(int argc, char* argv[]) { total_time_seconds += decoding_time_seconds[shot]; if (args.print_stats) { - std::cout << "num_shots = " << (shot + 1) - << " num_errors = " << num_errors + std::cout << "num_shots = " << (shot + 1) << " num_errors = " << num_errors << " total_time_seconds = " << total_time_seconds << std::endl; std::cout << "cost = " << cost_predicted[shot] << std::endl; std::cout.flush(); @@ -538,8 +508,7 @@ int main(int argc, char* argv[]) { } if (!args.dem_out_fname.empty()) { - std::vector counts(error_use_totals.begin(), - error_use_totals.end()); + std::vector counts(error_use_totals.begin(), error_use_totals.end()); size_t num_usage_dem_shots = shot; if (has_obs) { // When we know the obs, we only count non-error shots. diff --git a/src/tesseract.cc b/src/tesseract.cc index 0c1099c..e634407 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -18,32 +18,27 @@ #include #include -namespace -{ - - template - std::ostream &operator<<(std::ostream &os, const std::vector &vec) - { - os << "["; - bool is_first = true; - for (auto &x : vec) - { - if (!is_first) - { - os << ", "; - } - is_first = false; - os << x; +namespace { + +template +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + os << "["; + bool is_first = true; + for (auto& x : vec) { + if (!is_first) { + os << ", "; } - os << "]"; - return os; + is_first = false; + os << x; } + os << "]"; + return os; +} -}; +}; // namespace -std::string TesseractConfig::str() -{ - auto &config = *this; +std::string TesseractConfig::str() { + auto& config = *this; std::stringstream ss; ss << "TesseractConfig("; ss << "dem=DetectorErrorModel_Object" << ", "; @@ -61,10 +56,9 @@ bool Node::operator>(const Node& other) const { return cost > other.cost || (cost == other.cost && num_dets < other.num_dets); } -std::string Node::str() -{ +std::string Node::str() { std::stringstream ss; - auto &self = *this; + auto& self = *this; ss << "Node("; ss << "errs=" << self.errs << ", "; ss << "dets=" << self.dets << ", "; @@ -75,7 +69,7 @@ std::string Node::str() } std::string QNode::str() { - auto & self = *this; + auto& self = *this; std::stringstream ss; ss << "QNode("; ss << "cost=" << self.cost << ", "; @@ -84,8 +78,7 @@ std::string QNode::str() { return ss.str(); } -double TesseractDecoder::get_detcost(size_t d, - const std::vector& blocked_errs, +double TesseractDecoder::get_detcost(size_t d, const std::vector& blocked_errs, const std::vector& det_counts) const { double min_cost = INF; for (size_t ei : d2e[d]) { @@ -134,8 +127,7 @@ void TesseractDecoder::initialize_structures(size_t num_detectors) { eneighbors.resize(num_errors); std::vector> edets_sets(edets.size()); for (size_t ei = 0; ei < edets.size(); ++ei) { - edets_sets[ei] = - std::unordered_set(edets[ei].begin(), edets[ei].end()); + edets_sets[ei] = std::unordered_set(edets[ei].begin(), edets[ei].end()); } for (size_t ei = 0; ei < num_errors; ++ei) { std::set neighbor_set; @@ -170,8 +162,7 @@ struct VectorCharHash { } }; -void TesseractDecoder::decode_to_errors( - const std::vector& detections) { +void TesseractDecoder::decode_to_errors(const std::vector& detections) { std::vector best_errors; double best_cost = std::numeric_limits::max(); assert(config.det_orders.size()); @@ -187,16 +178,14 @@ void TesseractDecoder::decode_to_errors( best_cost = this_cost; } if (config.verbose) { - std::cout << "for det_order " << det_order << " beam " << beam - << " got low confidence " << low_confidence_flag - << " and cost " << this_cost << " and obs_mask " + std::cout << "for det_order " << det_order << " beam " << beam << " got low confidence " + << low_confidence_flag << " and cost " << this_cost << " and obs_mask " << mask_from_errors(predicted_errors_buffer) << ". Best cost so far: " << best_cost << std::endl; } } } else { - for (size_t det_order = 0; det_order < config.det_orders.size(); - ++det_order) { + for (size_t det_order = 0; det_order < config.det_orders.size(); ++det_order) { decode_to_errors(detections, det_order); double this_cost = cost_from_errors(predicted_errors_buffer); if (!low_confidence_flag && this_cost < best_cost) { @@ -204,11 +193,9 @@ void TesseractDecoder::decode_to_errors( best_cost = this_cost; } if (config.verbose) { - std::cout << "for det_order " << det_order << " beam " - << config.det_beam << " got low confidence " - << low_confidence_flag << " and cost " << this_cost - << " and obs_mask " - << mask_from_errors(predicted_errors_buffer) + std::cout << "for det_order " << det_order << " beam " << config.det_beam + << " got low confidence " << low_confidence_flag << " and cost " << this_cost + << " and obs_mask " << mask_from_errors(predicted_errors_buffer) << ". Best cost so far: " << best_cost << std::endl; } } @@ -222,8 +209,7 @@ bool QNode::operator>(const QNode& other) const { return cost > other.cost || (cost == other.cost && num_dets < other.num_dets); } -void TesseractDecoder::to_node(const QNode& qnode, - const std::vector& shot_dets, +void TesseractDecoder::to_node(const QNode& qnode, const std::vector& shot_dets, size_t det_order, Node& node) const { node.cost = qnode.cost; node.errs = qnode.errs; @@ -260,8 +246,7 @@ void TesseractDecoder::to_node(const QNode& qnode, } } -void TesseractDecoder::decode_to_errors(const std::vector& detections, - size_t det_order) { +void TesseractDecoder::decode_to_errors(const std::vector& detections, size_t det_order) { size_t det_beam = config.det_beam; predicted_errors_buffer.clear(); low_confidence_flag = false; @@ -271,9 +256,7 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, } std::priority_queue, std::greater> pq; - std::unordered_map, VectorCharHash>> - discovered_dets; + std::unordered_map, VectorCharHash>> discovered_dets; size_t min_num_dets = detections.size(); std::vector errs; @@ -340,18 +323,15 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, if (node.num_dets > max_num_dets) continue; - if (config.no_revisit_dets && - !discovered_dets[node.num_dets].insert(node.dets).second) { + if (config.no_revisit_dets && !discovered_dets[node.num_dets].insert(node.dets).second) { continue; } if (config.verbose) { std::cout.precision(13); - std::cout << "len(pq) = " << pq.size() - << " num_pq_pushed = " << num_pq_pushed << std::endl; - std::cout << "num_dets = " << node.num_dets - << " max_num_dets = " << max_num_dets << " cost = " << node.cost - << std::endl; + std::cout << "len(pq) = " << pq.size() << " num_pq_pushed = " << num_pq_pushed << std::endl; + std::cout << "num_dets = " << node.num_dets << " max_num_dets = " << max_num_dets + << " cost = " << node.cost << std::endl; std::cout << "activated_errors = "; for (size_t oei : node.errs) { std::cout << oei << ", "; @@ -458,31 +438,32 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, } if (config.no_revisit_dets && - discovered_dets[next_num_dets].find(next_dets) != - discovered_dets[next_num_dets].end()) { + discovered_dets[next_num_dets].find(next_dets) != discovered_dets[next_num_dets].end()) { continue; } for (int d : edets[ei]) { if (node.dets[d]) { if (det_costs[d] == -1) { - det_costs[d] = - get_detcost(d, node.blocked_errs, det_counts); + det_costs[d] = get_detcost(d, node.blocked_errs, det_counts); } next_cost -= det_costs[d]; } else { - next_cost += get_detcost(d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts); + next_cost += get_detcost( + d, + config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, + next_det_counts); } } for (size_t od : eneighbors[ei]) { if (!node.dets[od] || !next_dets[od]) continue; if (det_costs[od] == -1) { - det_costs[od] = - get_detcost(od, node.blocked_errs, det_counts); + det_costs[od] = get_detcost(od, node.blocked_errs, det_counts); } next_cost -= det_costs[od]; - next_cost += - get_detcost(od, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts); + next_cost += get_detcost( + od, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, + next_det_counts); } if (next_cost == INF) { @@ -509,8 +490,7 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, return; } -double TesseractDecoder::cost_from_errors( - const std::vector& predicted_errors) { +double TesseractDecoder::cost_from_errors(const std::vector& predicted_errors) { double total_cost = 0; // Iterate over all errors and add to the mask for (size_t ei : predicted_errors_buffer) { @@ -529,15 +509,13 @@ common::ObservablesMask TesseractDecoder::mask_from_errors( return mask; } -common::ObservablesMask TesseractDecoder::decode( - const std::vector& detections) { +common::ObservablesMask TesseractDecoder::decode(const std::vector& detections) { decode_to_errors(detections); return mask_from_errors(predicted_errors_buffer); } -void TesseractDecoder::decode_shots( - std::vector& shots, - std::vector& obs_predicted) { +void TesseractDecoder::decode_shots(std::vector& shots, + std::vector& obs_predicted) { obs_predicted.resize(shots.size()); for (size_t i = 0; i < shots.size(); ++i) { obs_predicted[i] = decode(shots[i].hits); diff --git a/src/tesseract.h b/src/tesseract.h index cd290eb..c5d2b31 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -18,8 +18,8 @@ #include #include #include -#include #include +#include #include "common.h" #include "stim.h" @@ -73,13 +73,11 @@ struct TesseractDecoder { // Clears the predicted_errors_buffer and fills it with the decoded errors for // these detection events, using a specified detector ordering index. - void decode_to_errors(const std::vector& detections, - size_t det_order); + void decode_to_errors(const std::vector& detections, size_t det_order); // Returns the bitwise XOR of all the observables bitmasks of all errors in // the predicted errors buffer. - common::ObservablesMask mask_from_errors( - const std::vector& predicted_errors); + common::ObservablesMask mask_from_errors(const std::vector& predicted_errors); // Returns the sum of the likelihood costs (minus-log-likelihood-ratios) of // all errors in the predicted errors buffer. @@ -105,8 +103,8 @@ struct TesseractDecoder { void initialize_structures(size_t num_detectors); double get_detcost(size_t d, const std::vector& blocked_errs, const std::vector& det_counts) const; - void to_node(const QNode& qnode, const std::vector& shot_dets, - size_t det_order, Node& node) const; + void to_node(const QNode& qnode, const std::vector& shot_dets, size_t det_order, + Node& node) const; }; #endif // TESSERACT_DECODER_H \ No newline at end of file diff --git a/src/tesseract.perf.cc b/src/tesseract.perf.cc index 38e49e1..4ec53a5 100644 --- a/src/tesseract.perf.cc +++ b/src/tesseract.perf.cc @@ -23,8 +23,7 @@ constexpr uint64_t test_data_seed = 752024; template -void benchmark_decoder(Decoder& decoder, stim::Circuit& circuit, - size_t num_shots) { +void benchmark_decoder(Decoder& decoder, stim::Circuit& circuit, size_t num_shots) { // Sample data std::vector shots; sample_shots(test_data_seed, circuit, num_shots, shots); @@ -37,10 +36,8 @@ void benchmark_decoder(Decoder& decoder, stim::Circuit& circuit, auto benchmark_func = [&]() { for (size_t shot = 0; shot < num_shots; ++shot) { decoder.decode_to_errors(shots[shot].hits); - common::ObservablesMask obs = - decoder.mask_from_errors(decoder.predicted_errors_buffer); - num_errors += (!decoder.low_confidence_flag and - (obs != shots[shot].obs_mask_as_u64())); + common::ObservablesMask obs = decoder.mask_from_errors(decoder.predicted_errors_buffer); + num_errors += (!decoder.low_confidence_flag and (obs != shots[shot].obs_mask_as_u64())); num_low_confidence += decoder.low_confidence_flag; total_num_errors_used += decoder.predicted_errors_buffer.size(); ++num_decoded; @@ -52,15 +49,11 @@ void benchmark_decoder(Decoder& decoder, stim::Circuit& circuit, do { benchmark_func(); auto end_time = std::chrono::steady_clock::now(); - num_milliseconds = - std::chrono::duration(end_time - start_time) - .count(); + num_milliseconds = std::chrono::duration(end_time - start_time).count(); } while (num_milliseconds < 1000.0); - std::cout << (num_milliseconds / num_decoded) << " milliseconds per shot " - << num_decoded << " shots " << num_low_confidence - << " low confidence " << num_errors << " errors " - << " total_num_errors_used = " << total_num_errors_used - << std::endl; + std::cout << (num_milliseconds / num_decoded) << " milliseconds per shot " << num_decoded + << " shots " << num_low_confidence << " low confidence " << num_errors << " errors " + << " total_num_errors_used = " << total_num_errors_used << std::endl; } void benchmark_tesseract(std::string circuit_path, size_t num_shots) { @@ -70,13 +63,12 @@ void benchmark_tesseract(std::string circuit_path, size_t num_shots) { } stim::Circuit circuit = stim::Circuit::from_file(file); fclose(file); - stim::DetectorErrorModel dem = - stim::ErrorAnalyzer::circuit_to_detector_error_model( - circuit, /*decompose_errors=*/false, /*fold_loops=*/true, - /*allow_gauge_detectors=*/true, - /*approximate_disjoint_errors_threshold=*/1, - /*ignore_decomposition_failures=*/false, - /*block_decomposition_from_introducing_remnant_edges=*/false); + stim::DetectorErrorModel dem = stim::ErrorAnalyzer::circuit_to_detector_error_model( + circuit, /*decompose_errors=*/false, /*fold_loops=*/true, + /*allow_gauge_detectors=*/true, + /*approximate_disjoint_errors_threshold=*/1, + /*ignore_decomposition_failures=*/false, + /*block_decomposition_from_introducing_remnant_edges=*/false); dem = common::remove_zero_probability_errors(dem); TesseractConfig config{dem}; config.det_beam = 20; @@ -93,13 +85,12 @@ void benchmark_simplex(std::string circuit_path, size_t num_shots) { } stim::Circuit circuit = stim::Circuit::from_file(file); fclose(file); - stim::DetectorErrorModel dem = - stim::ErrorAnalyzer::circuit_to_detector_error_model( - circuit, /*decompose_errors=*/false, /*fold_loops=*/true, - /*allow_gauge_detectors=*/true, - /*approximate_disjoint_errors_threshold=*/1, - /*ignore_decomposition_failures=*/false, - /*block_decomposition_from_introducing_remnant_edges=*/false); + stim::DetectorErrorModel dem = stim::ErrorAnalyzer::circuit_to_detector_error_model( + circuit, /*decompose_errors=*/false, /*fold_loops=*/true, + /*allow_gauge_detectors=*/true, + /*approximate_disjoint_errors_threshold=*/1, + /*ignore_decomposition_failures=*/false, + /*block_decomposition_from_introducing_remnant_edges=*/false); dem = common::remove_zero_probability_errors(dem); SimplexConfig config{dem}; config.parallelize = true; diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 5c57a2e..f428709 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -24,20 +24,16 @@ namespace py = pybind11; void add_tesseract_module(py::module &root) { - auto m = root.def_submodule("tesseract", - "Module containing the tesseract algorithm"); + auto m = root.def_submodule("tesseract", "Module containing the tesseract algorithm"); m.attr("INF_DET_BEAM") = INF_DET_BEAM; py::class_(m, "TesseractConfig") - .def(py::init>, double>(), - py::arg("dem"), py::arg("det_beam") = INF_DET_BEAM, - py::arg("beam_climbing") = false, py::arg("no_revisit_dets") = false, - py::arg("at_most_two_errors_per_detector") = false, - py::arg("verbose") = false, - py::arg("pqlimit") = std::numeric_limits::max(), - py::arg("det_orders") = std::vector>(), - py::arg("det_penalty") = 0.0) + .def(py::init>, double>(), + py::arg("dem"), py::arg("det_beam") = INF_DET_BEAM, py::arg("beam_climbing") = false, + py::arg("no_revisit_dets") = false, py::arg("at_most_two_errors_per_detector") = false, + py::arg("verbose") = false, py::arg("pqlimit") = std::numeric_limits::max(), + py::arg("det_orders") = std::vector>(), py::arg("det_penalty") = 0.0) .def_readwrite("dem", &TesseractConfig::dem) .def_readwrite("det_beam", &TesseractConfig::det_beam) .def_readwrite("no_revisit_dets", &TesseractConfig::no_revisit_dets) @@ -50,11 +46,9 @@ void add_tesseract_module(py::module &root) { .def("__str__", &TesseractConfig::str); py::class_(m, "Node") - .def(py::init, std::vector, double, size_t, - std::vector>(), - py::arg("errs") = std::vector(), - py::arg("dets") = std::vector(), py::arg("cost") = 0.0, - py::arg("num_dets") = 0, + .def(py::init, std::vector, double, size_t, std::vector>(), + py::arg("errs") = std::vector(), py::arg("dets") = std::vector(), + py::arg("cost") = 0.0, py::arg("num_dets") = 0, py::arg("blocked_errs") = std::vector()) .def_readwrite("errs", &Node::errs) .def_readwrite("dets", &Node::dets) @@ -65,9 +59,8 @@ void add_tesseract_module(py::module &root) { .def("__str__", &Node::str); py::class_(m, "QNode") - .def(py::init>(), - py::arg("cost") = 0.0, py::arg("num_dets") = 0, - py::arg("errs") = std::vector()) + .def(py::init>(), py::arg("cost") = 0.0, + py::arg("num_dets") = 0, py::arg("errs") = std::vector()) .def_readwrite("cost", &QNode::cost) .def_readwrite("num_dets", &QNode::num_dets) .def_readwrite("errs", &QNode::errs) @@ -77,22 +70,17 @@ void add_tesseract_module(py::module &root) { py::class_(m, "TesseractDecoder") .def(py::init(), py::arg("config")) .def("decode_to_errors", - py::overload_cast &>( - &TesseractDecoder::decode_to_errors), + py::overload_cast &>(&TesseractDecoder::decode_to_errors), py::arg("detections")) .def("decode_to_errors", py::overload_cast &, size_t>( &TesseractDecoder::decode_to_errors), py::arg("detections"), py::arg("det_order")) - .def("mask_from_errors", &TesseractDecoder::mask_from_errors, - py::arg("predicted_errors")) - .def("cost_from_errors", &TesseractDecoder::cost_from_errors, - py::arg("predicted_errors")) + .def("mask_from_errors", &TesseractDecoder::mask_from_errors, py::arg("predicted_errors")) + .def("cost_from_errors", &TesseractDecoder::cost_from_errors, py::arg("predicted_errors")) .def("decode", &TesseractDecoder::decode, py::arg("detections")) - .def_readwrite("low_confidence_flag", - &TesseractDecoder::low_confidence_flag) - .def_readwrite("predicted_errors_buffer", - &TesseractDecoder::predicted_errors_buffer) + .def_readwrite("low_confidence_flag", &TesseractDecoder::low_confidence_flag) + .def_readwrite("predicted_errors_buffer", &TesseractDecoder::predicted_errors_buffer) .def_readwrite("det_beam", &TesseractDecoder::det_beam) .def_readwrite("errors", &TesseractDecoder::errors); } diff --git a/src/tesseract.test.cc b/src/tesseract.test.cc index a1775c7..308c64f 100644 --- a/src/tesseract.test.cc +++ b/src/tesseract.test.cc @@ -14,8 +14,8 @@ #include "tesseract.h" -#include #include +#include #include "gtest/gtest.h" #include "simplex.h" @@ -24,8 +24,7 @@ constexpr uint64_t test_data_seed = 752024; -bool simplex_test_compare(stim::DetectorErrorModel& dem, - std::vector& shots) { +bool simplex_test_compare(stim::DetectorErrorModel& dem, std::vector& shots) { TesseractConfig tesseract_config{dem}; TesseractDecoder tesseract_decoder(tesseract_config); @@ -34,21 +33,19 @@ bool simplex_test_compare(stim::DetectorErrorModel& dem, for (size_t shot = 0; shot < shots.size(); shot++) { tesseract_decoder.decode_to_errors(shots[shot].hits); - double tesseract_cost = tesseract_decoder.cost_from_errors( - tesseract_decoder.predicted_errors_buffer); + double tesseract_cost = + tesseract_decoder.cost_from_errors(tesseract_decoder.predicted_errors_buffer); if (tesseract_decoder.low_confidence_flag) { // Simplex c++ does not yet support undecodable shots -- i.e. detection // event configurations with no error solution. std::cout << "not decoding shot " << shot - << " with simplex because Tesseract found no solution" - << std::endl; + << " with simplex because Tesseract found no solution" << std::endl; continue; } simplex_decoder.decode_to_errors(shots[shot].hits); - double simplex_cost = simplex_decoder.cost_from_errors( - simplex_decoder.predicted_errors_buffer); + double simplex_cost = simplex_decoder.cost_from_errors(simplex_decoder.predicted_errors_buffer); // If there is a mismatch in weights, print diagnostic information if (std::abs(tesseract_cost - simplex_cost) > EPSILON) { @@ -59,8 +56,7 @@ bool simplex_test_compare(stim::DetectorErrorModel& dem, std::cout << std::endl; std::cout << "Error: For shot " << shot << " tesseract got solution with cost:" << tesseract_cost - << " simplex got solution with cost: " << simplex_cost - << std::endl; + << " simplex got solution with cost: " << simplex_cost << std::endl; std::cout << "tesseract used errors "; for (size_t ei : tesseract_decoder.predicted_errors_buffer) { std::cout << ei << ", "; @@ -81,12 +77,10 @@ bool simplex_test_compare(stim::DetectorErrorModel& dem, TEST(tesseract, Tesseract_simplex_test) { bool long_tests = std::getenv("TESSERACT_LONG_TESTS") != nullptr; - auto p_errs = long_tests ? std::vector{0.001f, 0.003f, 0.005f} - : std::vector{0.003f}; - auto distances = long_tests ? std::vector{3, 5, 7} - : std::vector{3}; - auto rounds = long_tests ? std::vector{2, 5, 10} - : std::vector{2}; + auto p_errs = + long_tests ? std::vector{0.001f, 0.003f, 0.005f} : std::vector{0.003f}; + auto distances = long_tests ? std::vector{3, 5, 7} : std::vector{3}; + auto rounds = long_tests ? std::vector{2, 5, 10} : std::vector{2}; size_t base_shots = long_tests ? 1000 : 100; for (float p_err : p_errs) { @@ -94,23 +88,20 @@ TEST(tesseract, Tesseract_simplex_test) { for (const size_t num_rounds : rounds) { const size_t num_shots = base_shots / num_rounds / distance; std::cout << "p_err = " << p_err << " distance = " << distance - << " num_rounds = " << num_rounds - << " num_shots = " << num_shots << std::endl; + << " num_rounds = " << num_rounds << " num_shots = " << num_shots << std::endl; stim::CircuitGenParameters params(num_rounds, /*distance=*/distance, /*task=*/"rotated_memory_x"); params.after_clifford_depolarization = p_err; params.before_round_data_depolarization = p_err; params.before_measure_flip_probability = p_err; params.after_reset_flip_probability = p_err; - stim::Circuit circuit = - stim::generate_surface_code_circuit(params).circuit; - stim::DetectorErrorModel dem = - stim::ErrorAnalyzer::circuit_to_detector_error_model( - circuit, /*decompose_errors=*/false, /*fold_loops=*/true, - /*allow_gauge_detectors=*/true, - /*approximate_disjoint_errors_threshold=*/1, - /*ignore_decomposition_failures=*/false, - /*block_decomposition_from_introducing_remnant_edges=*/false); + stim::Circuit circuit = stim::generate_surface_code_circuit(params).circuit; + stim::DetectorErrorModel dem = stim::ErrorAnalyzer::circuit_to_detector_error_model( + circuit, /*decompose_errors=*/false, /*fold_loops=*/true, + /*allow_gauge_detectors=*/true, + /*approximate_disjoint_errors_threshold=*/1, + /*ignore_decomposition_failures=*/false, + /*block_decomposition_from_introducing_remnant_edges=*/false); for (bool merge_errors : {true, false}) { stim::DetectorErrorModel new_dem = dem; if (merge_errors) { @@ -190,8 +181,7 @@ TEST(tesseract, Tesseract_simplex_DEM_exhaustive_test) { ASSERT_LE(num_detectors, 64); // Try all possible dets sets on num_detectors detectors std::vector shots; - for (uint64_t bitstring = 0; bitstring < (1ULL << num_detectors); - ++bitstring) { + for (uint64_t bitstring = 0; bitstring < (1ULL << num_detectors); ++bitstring) { stim::SparseShot shot; for (size_t d = 0; d < num_detectors; ++d) { if (bitstring & (1 << (num_detectors - d - 1))) { diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 496a7c0..4f6ffdf 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include -#include -#include -#include #include +#include +#include #include #include "common.h" @@ -79,14 +79,12 @@ struct Args { bool print_stats = false; bool has_observables() { - return append_observables || !obs_in_fname.empty() || - (sample_num_shots > 0); + return append_observables || !obs_in_fname.empty() || (sample_num_shots > 0); } void validate() { if (circuit_path.empty() and dem_path.empty()) { - throw std::invalid_argument( - "Must provide at least one of --circuit or --dem"); + throw std::invalid_argument("Must provide at least one of --circuit or --dem"); } int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty()); @@ -94,23 +92,18 @@ struct Args { throw std::invalid_argument("Requires exactly 1 source of shots."); } if (!in_fname.empty() and in_format.empty()) { - throw std::invalid_argument( - "If --in is provided, must also specify --in-format."); + throw std::invalid_argument("If --in is provided, must also specify --in-format."); } if (!out_fname.empty() and out_format.empty()) { - throw std::invalid_argument( - "If --out is provided, must also specify --out-format."); + throw std::invalid_argument("If --out is provided, must also specify --out-format."); } - if (!in_format.empty() && - !stim::format_name_to_enum_map().contains(in_format)) { + if (!in_format.empty() && !stim::format_name_to_enum_map().contains(in_format)) { throw std::invalid_argument("Invalid format: " + in_format); } - if (!obs_in_format.empty() && - !stim::format_name_to_enum_map().contains(obs_in_format)) { + if (!obs_in_format.empty() && !stim::format_name_to_enum_map().contains(obs_in_format)) { throw std::invalid_argument("Invalid format: " + obs_in_format); } - if (!out_format.empty() && - !stim::format_name_to_enum_map().contains(out_format)) { + if (!out_format.empty() && !stim::format_name_to_enum_map().contains(out_format)) { throw std::invalid_argument("Invalid format: " + out_format); } if (!obs_in_fname.empty() and in_fname.empty()) { @@ -127,8 +120,7 @@ struct Args { } if (shot_range_begin or shot_range_end) { if (shot_range_end < shot_range_begin) { - throw std::invalid_argument( - "Provided shot range must have end >= begin."); + throw std::invalid_argument("Provided shot range must have end >= begin."); } } if (sample_num_shots > 0 and circuit_path.empty()) { @@ -182,8 +174,7 @@ struct Args { std::mt19937_64 rng(det_order_seed); std::normal_distribution dist(/*mean=*/0, /*stddev=*/1); - std::vector> detector_coords = - get_detector_coords(config.dem); + std::vector> detector_coords = get_detector_coords(config.dem); if (verbose) { for (size_t d = 0; d < detector_coords.size(); ++d) { std::cout << "Detector D" << d << " coordinate ("; @@ -244,8 +235,7 @@ struct Args { // of the indices. for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { config.det_orders.emplace_back(); - std::iota(config.det_orders.back().begin(), - config.det_orders.front().end(), 0); + std::iota(config.det_orders.back().begin(), config.det_orders.front().end(), 0); } } else { // Use the coordinates to order the detectors based on a random @@ -260,16 +250,14 @@ struct Args { for (size_t i = 0; i < detector_coords.size(); ++i) { inner_products[i] = 0; for (size_t j = 0; j < orientation_vector.size(); ++j) { - inner_products[i] += - detector_coords[i][j] * orientation_vector[j]; + inner_products[i] += detector_coords[i][j] * orientation_vector[j]; } } std::vector perm(config.dem.count_detectors()); std::iota(perm.begin(), perm.end(), 0); - std::sort(perm.begin(), perm.end(), - [&](const size_t& i, const size_t& j) { - return inner_products[i] > inner_products[j]; - }); + std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) { + return inner_products[i] > inner_products[j]; + }); // Invert the permutation std::vector inv_perm(config.dem.count_detectors()); for (size_t i = 0; i < perm.size(); ++i) { @@ -285,8 +273,8 @@ struct Args { assert(!circuit_path.empty()); std::mt19937_64 rng(sample_seed); size_t num_detectors = circuit.count_detectors(); - const auto [dets, obs] = stim::sample_batch_detection_events<64>( - circuit, sample_num_shots, rng); + const auto [dets, obs] = + stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); stim::simd_bit_table<64> obs_T = obs.transposed(); shots.resize(sample_num_shots); for (size_t k = 0; k < sample_num_shots; k++) { @@ -305,8 +293,7 @@ struct Args { if (!shots_file) { throw std::invalid_argument("Could not open the file: " + in_fname); } - stim::FileFormatData shots_in_format = - stim::format_name_to_enum_map().at(in_format); + stim::FileFormatData shots_in_format = stim::format_name_to_enum_map().at(in_format); auto reader = stim::MeasureRecordReader::make( shots_file, shots_in_format.id, 0, config.dem.count_detectors(), append_observables * config.dem.count_observables()); @@ -327,12 +314,9 @@ struct Args { if (!obs_file) { throw std::invalid_argument("Could not open the file: " + obs_in_fname); } - stim::FileFormatData shots_obs_in_format = - stim::format_name_to_enum_map().at(obs_in_format); - auto obs_reader = - stim::MeasureRecordReader::make( - obs_file, shots_obs_in_format.id, 0, 0, - config.dem.count_observables()); + stim::FileFormatData shots_obs_in_format = stim::format_name_to_enum_map().at(obs_in_format); + auto obs_reader = stim::MeasureRecordReader::make( + obs_file, shots_obs_in_format.id, 0, 0, config.dem.count_observables()); stim::SparseShot sparse_shot; sparse_shot.clear(); size_t num_obs_shots = 0; @@ -354,24 +338,21 @@ struct Args { if (shot_range_begin or shot_range_end) { assert(shot_range_end >= shot_range_begin); if (shot_range_end > shots.size()) { - throw std::invalid_argument( - "Shot range end is past end of shots array."); + throw std::invalid_argument("Shot range end is past end of shots array."); } - std::vector shots_in_range( - shots.begin() + shot_range_begin, shots.begin() + shot_range_end); + std::vector shots_in_range(shots.begin() + shot_range_begin, + shots.begin() + shot_range_end); std::swap(shots_in_range, shots); } if (!out_fname.empty()) { // Create a writer instance to write the predicted obs to a file - stim::FileFormatData predictions_out_format = - stim::format_name_to_enum_map().at(out_format); + stim::FileFormatData predictions_out_format = stim::format_name_to_enum_map().at(out_format); FILE* predictions_file = stdout; if (out_fname != "-") { predictions_file = fopen(out_fname.c_str(), "w"); } - writer = stim::MeasureRecordWriter::make(predictions_file, - predictions_out_format.id); + writer = stim::MeasureRecordWriter::make(predictions_file, predictions_out_format.id); writer->begin_result_type('L'); // TODO: ensure the fclose happens after all predictions are written to // the writer. @@ -390,18 +371,13 @@ int main(int argc, char* argv[]) { std::cout.precision(16); argparse::ArgumentParser program("tesseract"); Args args; - program.add_argument("--circuit") - .help("Stim circuit file path") - .store_into(args.circuit_path); - program.add_argument("--dem") - .help("Stim dem file path") - .store_into(args.dem_path); + program.add_argument("--circuit").help("Stim circuit file path").store_into(args.circuit_path); + program.add_argument("--dem").help("Stim dem file path").store_into(args.dem_path); program.add_argument("--no-merge-errors") .help("If provided, will not merge identical error mechanisms.") .store_into(args.no_merge_errors); program.add_argument("--num-det-orders") - .help( - "Number of ways to orient the manifold when reordering the detectors") + .help("Number of ways to orient the manifold when reordering the detectors") .metavar("N") .default_value(size_t(1)) .store_into(args.num_det_orders); @@ -455,8 +431,7 @@ int main(int argc, char* argv[]) { .default_value(size_t(0)) .store_into(args.shot_range_end); program.add_argument("--in") - .help( - "File to read detection events (and possibly observable flips) from") + .help("File to read detection events (and possibly observable flips) from") .metavar("filename") .default_value(std::string("")) .store_into(args.in_fname); @@ -468,14 +443,11 @@ int main(int argc, char* argv[]) { in_formats += key; } program.add_argument("--in-format", "--in_format") - .help("Format of the file to read detection events from (" + in_formats + - ")") + .help("Format of the file to read detection events from (" + in_formats + ")") .metavar(in_formats) .default_value(std::string("")) .store_into(args.in_format); - program - .add_argument("--in-includes-appended-observables", - "--in_includes_appended_observables") + program.add_argument("--in-includes-appended-observables", "--in_includes_appended_observables") .help( "If present, assumes that the observable flips are appended to the " "end of each shot.") @@ -498,8 +470,7 @@ int main(int argc, char* argv[]) { .default_value(std::string("")) .store_into(args.out_fname); program.add_argument("--out-format") - .help("Format of the file to write observable flip predictions to (" + - in_formats + ")") + .help("Format of the file to write observable flip predictions to (" + in_formats + ")") .metavar(in_formats) .default_value(std::string("")) .store_into(args.out_format); @@ -585,44 +556,39 @@ int main(int argc, char* argv[]) { // After this value returns to 0, we know that no further shots will // transition to finished. ++num_worker_threads_active; - decoder_threads.push_back(std::thread( - [&config, &next_unclaimed_shot, &shots, &obs_predicted, &cost_predicted, - &decoding_time_seconds, &low_confidence, &finished, &error_use_totals, - &has_obs, &worker_threads_please_terminate, - &num_worker_threads_active]() { - TesseractDecoder decoder(config); - std::vector error_use(config.dem.count_errors()); - for (size_t shot; !worker_threads_please_terminate and - ((shot = next_unclaimed_shot++) < shots.size());) { - auto start_time = std::chrono::high_resolution_clock::now(); - decoder.decode_to_errors(shots[shot].hits); - auto stop_time = std::chrono::high_resolution_clock::now(); - decoding_time_seconds[shot] = - std::chrono::duration_cast( - stop_time - start_time) - .count() / - 1e6; - obs_predicted[shot] = - decoder.mask_from_errors(decoder.predicted_errors_buffer); - low_confidence[shot] = decoder.low_confidence_flag; - cost_predicted[shot] = - decoder.cost_from_errors(decoder.predicted_errors_buffer); - if (!has_obs or - shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { - // Only count the error uses for shots that did not have a logical - // error, if we know the obs flips. - for (size_t ei : decoder.predicted_errors_buffer) { - ++error_use[ei]; - } - } - finished[shot] = true; + decoder_threads.push_back(std::thread([&config, &next_unclaimed_shot, &shots, &obs_predicted, + &cost_predicted, &decoding_time_seconds, &low_confidence, + &finished, &error_use_totals, &has_obs, + &worker_threads_please_terminate, + &num_worker_threads_active]() { + TesseractDecoder decoder(config); + std::vector error_use(config.dem.count_errors()); + for (size_t shot; + !worker_threads_please_terminate and ((shot = next_unclaimed_shot++) < shots.size());) { + auto start_time = std::chrono::high_resolution_clock::now(); + decoder.decode_to_errors(shots[shot].hits); + auto stop_time = std::chrono::high_resolution_clock::now(); + decoding_time_seconds[shot] = + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + obs_predicted[shot] = decoder.mask_from_errors(decoder.predicted_errors_buffer); + low_confidence[shot] = decoder.low_confidence_flag; + cost_predicted[shot] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + if (!has_obs or shots[shot].obs_mask_as_u64() == obs_predicted[shot]) { + // Only count the error uses for shots that did not have a logical + // error, if we know the obs flips. + for (size_t ei : decoder.predicted_errors_buffer) { + ++error_use[ei]; } - // Add the error counts to the total - for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { - error_use_totals[ei] += error_use[ei]; - } - --num_worker_threads_active; - })); + } + finished[shot] = true; + } + // Add the error counts to the total + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) { + error_use_totals[ei] += error_use[ei]; + } + --num_worker_threads_active; + })); } size_t num_errors = 0; size_t num_low_confidence = 0; @@ -658,10 +624,9 @@ int main(int argc, char* argv[]) { total_time_seconds += decoding_time_seconds[shot]; if (args.print_stats) { - std::cout << "num_shots = " << (shot + 1) - << " num_low_confidence = " << num_low_confidence - << " num_errors = " << num_errors - << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "num_shots = " << (shot + 1) << " num_low_confidence = " << num_low_confidence + << " num_errors = " << num_errors << " total_time_seconds = " << total_time_seconds + << std::endl; std::cout << "cost = " << cost_predicted[shot] << std::endl; std::cout.flush(); } @@ -675,8 +640,7 @@ int main(int argc, char* argv[]) { } if (!args.dem_out_fname.empty()) { - std::vector counts(error_use_totals.begin(), - error_use_totals.end()); + std::vector counts(error_use_totals.begin(), error_use_totals.end()); size_t num_usage_dem_shots = shot; if (has_obs) { // When we know the obs, we only count non-error shots. @@ -693,25 +657,25 @@ int main(int argc, char* argv[]) { bool print_final_stats = true; if (!args.stats_out_fname.empty()) { - nlohmann::json stats_json = {{"circuit_path", args.circuit_path}, - {"dem_path", args.dem_path}, - {"max_errors", args.max_errors}, - {"sample_seed", args.sample_seed}, - {"at_most_two_errors_per_detector", - args.at_most_two_errors_per_detector}, - {"det_beam", args.det_beam}, - {"det_penalty", args.det_penalty}, - {"beam_climbing", args.beam_climbing}, - {"no_revisit_dets", args.no_revisit_dets}, - {"pqlimit", args.pqlimit}, - {"num_det_orders", args.num_det_orders}, - {"det_order_seed", args.det_order_seed}, - {"total_time_seconds", total_time_seconds}, - {"num_errors", num_errors}, - {"num_low_confidence", num_low_confidence}, - {"num_shots", shot}, - {"num_threads", args.num_threads}, - {"sample_num_shots", args.sample_num_shots}}; + nlohmann::json stats_json = { + {"circuit_path", args.circuit_path}, + {"dem_path", args.dem_path}, + {"max_errors", args.max_errors}, + {"sample_seed", args.sample_seed}, + {"at_most_two_errors_per_detector", args.at_most_two_errors_per_detector}, + {"det_beam", args.det_beam}, + {"det_penalty", args.det_penalty}, + {"beam_climbing", args.beam_climbing}, + {"no_revisit_dets", args.no_revisit_dets}, + {"pqlimit", args.pqlimit}, + {"num_det_orders", args.num_det_orders}, + {"det_order_seed", args.det_order_seed}, + {"total_time_seconds", total_time_seconds}, + {"num_errors", num_errors}, + {"num_low_confidence", num_low_confidence}, + {"num_shots", shot}, + {"num_threads", args.num_threads}, + {"sample_num_shots", args.sample_num_shots}}; if (args.stats_out_fname == "-") { std::cout << stats_json << std::endl; diff --git a/src/test_data.h b/src/test_data.h index 6910607..137f57e 100644 --- a/src/test_data.h +++ b/src/test_data.h @@ -19,8 +19,12 @@ #include "stim.h" -std::vector get_small_test_circuits() { return {}; } +std::vector get_small_test_circuits() { + return {}; +} -std::vector get_large_test_circuits() { return {}; } +std::vector get_large_test_circuits() { + return {}; +} #endif // TESSERACT_TEST_DATA_H diff --git a/src/utils.cc b/src/utils.cc index f3081d3..c4c6edd 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -25,8 +25,7 @@ #include "common.h" #include "stim.h" -std::vector> get_detector_coords( - stim::DetectorErrorModel& dem) { +std::vector> get_detector_coords(stim::DetectorErrorModel& dem) { std::vector> detector_coords; for (const stim::DemInstruction& instruction : dem.flattened().instructions) { switch (instruction.type) { @@ -51,8 +50,7 @@ std::vector> get_detector_coords( return detector_coords; } -std::vector> build_detector_graph( - const stim::DetectorErrorModel& dem) { +std::vector> build_detector_graph(const stim::DetectorErrorModel& dem) { size_t num_detectors = dem.count_detectors(); std::vector> neighbors(num_detectors); for (const stim::DemInstruction& instruction : dem.flattened().instructions) { @@ -81,11 +79,9 @@ std::vector> build_detector_graph( return neighbors; } -bool sampling_from_dem(uint64_t seed, size_t num_shots, - stim::DetectorErrorModel dem, +bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem, std::vector& shots) { - stim::DemSampler sampler(dem, std::mt19937_64{seed}, - num_shots); + stim::DemSampler sampler(dem, std::mt19937_64{seed}, num_shots); sampler.resample(false); shots.resize(0); shots.resize(num_shots); @@ -112,13 +108,11 @@ bool sampling_from_dem(uint64_t seed, size_t num_shots, return true; } -void sample_shots(uint64_t sample_seed, stim::Circuit& circuit, - size_t sample_num_shots, +void sample_shots(uint64_t sample_seed, stim::Circuit& circuit, size_t sample_num_shots, std::vector& shots) { std::mt19937_64 rng(sample_seed); size_t num_detectors = circuit.count_detectors(); - const auto [dets, obs] = - stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); + const auto [dets, obs] = stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); stim::simd_bit_table<64> obs_T = obs.transposed(); shots.resize(sample_num_shots); for (size_t k = 0; k < sample_num_shots; k++) { @@ -131,24 +125,20 @@ void sample_shots(uint64_t sample_seed, stim::Circuit& circuit, } } -std::vector get_errors_from_dem( - const stim::DetectorErrorModel& dem) { +std::vector get_errors_from_dem(const stim::DetectorErrorModel& dem) { std::vector errors; for (const stim::DemInstruction& instruction : dem.instructions) { // Ignore zero-probability errors - if (instruction.type == stim::DemInstructionType::DEM_ERROR and - instruction.arg_data[0] > 0) + if (instruction.type == stim::DemInstructionType::DEM_ERROR and instruction.arg_data[0] > 0) errors.emplace_back(instruction); } return errors; } -std::vector get_files_recursive( - const std::string& directory_path) { +std::vector get_files_recursive(const std::string& directory_path) { std::vector file_paths; try { - for (const auto& entry : - std::filesystem::recursive_directory_iterator(directory_path)) { + for (const auto& entry : std::filesystem::recursive_directory_iterator(directory_path)) { if (std::filesystem::is_regular_file(entry)) { file_paths.push_back(entry.path().string()); } diff --git a/src/utils.h b/src/utils.h index 52cd787..5fe82be 100644 --- a/src/utils.h +++ b/src/utils.h @@ -28,26 +28,21 @@ constexpr const double EPSILON = 1e-7; -std::vector> get_detector_coords( - stim::DetectorErrorModel& dem); +std::vector> get_detector_coords(stim::DetectorErrorModel& dem); // Builds an adjacency list graph where two detectors share an edge iff an error // in the model activates them both. -std::vector> build_detector_graph( - const stim::DetectorErrorModel& dem); +std::vector> build_detector_graph(const stim::DetectorErrorModel& dem); const double INF = std::numeric_limits::infinity(); -bool sampling_from_dem(uint64_t seed, size_t num_shots, - stim::DetectorErrorModel dem, +bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem, std::vector& shots); -void sample_shots(uint64_t sample_seed, stim::Circuit& circuit, - size_t sample_num_shots, +void sample_shots(uint64_t sample_seed, stim::Circuit& circuit, size_t sample_num_shots, std::vector& shots); -std::vector get_errors_from_dem( - const stim::DetectorErrorModel& dem); +std::vector get_errors_from_dem(const stim::DetectorErrorModel& dem); std::vector get_files_recursive(const std::string& directory_path);