diff --git a/src/tesseract.cc b/src/tesseract.cc index f3d2460..a2c1691 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -74,12 +74,12 @@ std::string Node::str() { ss << "Node("; ss << "errors=" << self.errors << ", "; ss << "cost=" << self.cost << ", "; - ss << "num_detectors=" << self.num_detectors << ", "; + ss << "num_dets=" << self.num_dets << ", "; return ss.str(); } bool Node::operator>(const Node& other) const { - return cost > other.cost || (cost == other.cost && num_detectors < other.num_detectors); + return cost > other.cost || (cost == other.cost && num_dets < other.num_dets); } double TesseractDecoder::get_detcost( @@ -293,27 +293,27 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, return; } - size_t min_num_detectors = detections.size(); - size_t max_num_detectors = min_num_detectors + detector_beam; + size_t min_num_dets = detections.size(); + size_t max_num_dets = min_num_dets + detector_beam; std::vector next_errors; boost::dynamic_bitset<> next_detectors; std::vector next_detector_cost_tuples; - pq.push({initial_cost, min_num_detectors, std::vector()}); + pq.push({initial_cost, min_num_dets, std::vector()}); size_t num_pq_pushed = 1; while (!pq.empty()) { const Node node = pq.top(); pq.pop(); - if (node.num_detectors > max_num_detectors) continue; + if (node.num_dets > max_num_dets) continue; boost::dynamic_bitset<> detectors = initial_detectors; std::vector detector_cost_tuples(num_errors); flip_detectors_and_block_errors(detector_order, node.errors, detectors, detector_cost_tuples); - if (node.num_detectors == 0) { + if (node.num_dets == 0) { if (config.create_visualization) { visualizer.add_activated_errors(node.errors); visualizer.add_activated_detectors(detectors, num_detectors); @@ -339,7 +339,7 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, return; } - if (config.no_revisit_dets && !visited_detectors[node.num_detectors].insert(detectors).second) + if (config.no_revisit_dets && !visited_detectors[node.num_dets].insert(detectors).second) continue; if (config.create_visualization) { @@ -349,9 +349,8 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, if (config.verbose) { std::cout.precision(13); std::cout << "len(pq) = " << pq.size() << " num_pq_pushed = " << num_pq_pushed << std::endl; - std::cout << "num_detectors = " << node.num_detectors - << " max_num_detectors = " << max_num_detectors << " cost = " << node.cost - << std::endl; + std::cout << "num_dets = " << node.num_dets << " max_num_dets = " << max_num_dets + << " cost = " << node.cost << std::endl; std::cout << "activated_errors = "; for (size_t oei : node.errors) { std::cout << oei << ", "; @@ -366,14 +365,14 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::cout << std::endl; } - if (node.num_detectors < min_num_detectors) { - min_num_detectors = node.num_detectors; + if (node.num_dets < min_num_dets) { + min_num_dets = node.num_dets; if (config.no_revisit_dets) { - for (size_t i = min_num_detectors + detector_beam + 1; i <= max_num_detectors; ++i) { + for (size_t i = min_num_dets + detector_beam + 1; i <= max_num_dets; ++i) { visited_detectors[i].clear(); } } - max_num_detectors = std::min(max_num_detectors, min_num_detectors + detector_beam); + max_num_dets = std::min(max_num_dets, min_num_dets + detector_beam); } for (size_t d = 0; d < num_detectors; ++d) { @@ -415,21 +414,21 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, next_detector_cost_tuples[ei].error_blocked = 1; double next_cost = node.cost + errors[ei].likelihood_cost; - size_t next_num_detectors = node.num_detectors; + size_t next_num_dets = node.num_dets; for (int d : edets[ei]) { next_detectors[d] = !next_detectors[d]; int fired = next_detectors[d] ? 1 : -1; - next_num_detectors += fired; + next_num_dets += fired; for (int oei : d2e[d]) { next_detector_cost_tuples[oei].detectors_count += fired; } } - if (next_num_detectors > max_num_detectors) continue; + if (next_num_dets > max_num_dets) continue; - if (config.no_revisit_dets && visited_detectors[next_num_detectors].find(next_detectors) != - visited_detectors[next_num_detectors].end()) + if (config.no_revisit_dets && visited_detectors[next_num_dets].find(next_detectors) != + visited_detectors[next_num_dets].end()) continue; for (int d : edets[ei]) { @@ -454,7 +453,7 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, if (next_cost == INF) continue; - pq.push({next_cost, next_num_detectors, next_errors}); + pq.push({next_cost, next_num_dets, next_errors}); ++num_pq_pushed; if (num_pq_pushed > config.pqlimit) { diff --git a/src/tesseract.h b/src/tesseract.h index 62d6ed6..0eeb3ea 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -50,7 +50,8 @@ struct TesseractConfig { class Node { public: double cost; - size_t num_detectors; + // The number of activated detectors (dets for short) at this node + size_t num_dets; std::vector errors; bool operator>(const Node& other) const; @@ -118,4 +119,4 @@ struct TesseractDecoder { std::vector& detector_cost_tuples) const; }; -#endif // TESSERACT_DECODER_H \ No newline at end of file +#endif // TESSERACT_DECODER_H diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 85603be..72f92e5 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -205,21 +205,20 @@ void add_tesseract_module(py::module& root) { This is used internally by the decoder to track decoding progress. )pbdoc") .def(py::init>(), py::arg("cost") = 0.0, - py::arg("num_detectors") = 0, py::arg("errors") = std::vector(), R"pbdoc( + py::arg("num_dets") = 0, py::arg("errors") = std::vector(), R"pbdoc( The constructor for the `Node` class. Parameters ---------- cost : float, default=0.0 The cost of the path to this node. - num_detectors : int, default=0 + num_dets : int, default=0 The number of detectors this search node has. errors : list[int], default=empty The list of error indices this search node has. )pbdoc") .def_readwrite("cost", &Node::cost, "The cost of the node.") - .def_readwrite("num_detectors", &Node::num_detectors, - "The number of detectors this search node has.") + .def_readwrite("num_dets", &Node::num_dets, "The number of detectors this search node has.") .def_readwrite("errors", &Node::errors, "The list of error indices this search node has.") .def(py::self > py::self, "Comparison operator for nodes based on cost. This is necessary to prioritize "