Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for LSTM Diplopia issue #3476

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions src/lstm/recodebeam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {

static const char *kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};

// Minimum score (key) that a network output must exceed for its code to
// count as a likely 'real' character when detecting possible diplopia.
stweil marked this conversation as resolved.
Show resolved Hide resolved
static const float kMinDiplopiaKey = 0.25f;

// Prints debug details of the node.
void RecodeNode::Print(int null_char, const UNICHARSET &unicharset, int depth) const {
if (code == null_char) {
Expand All @@ -65,6 +70,9 @@ RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress &recoder, int null_char
, beam_size_(0)
, top_code_(-1)
, second_code_(-1)
, in_possible_diplopia_(false)
, first_diplopia_code_(-1)
, second_diplopia_code_(-1)
stweil marked this conversation as resolved.
Show resolved Hide resolved
, dict_(dict)
, space_delimited_(true)
, is_simple_text_(simple_text)
Expand Down Expand Up @@ -182,7 +190,7 @@ RecodeBeamSearch::combineSegmentedTimesteps(

void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts, std::vector<int> *ends,
std::vector<int> *char_bounds_, int maxWidth) {
char_bounds_->push_back(0);
char_bounds_->push_back((*starts)[0]);
for (int i = 0; i < ends->size(); ++i) {
int middle = ((*starts)[i + 1] - (*ends)[i]) / 2;
char_bounds_->push_back((*ends)[i] + middle);
Expand Down Expand Up @@ -567,8 +575,8 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(const std::vector<const RecodeNod
}
rating -= cert;
}
starts.push_back(t);
if (t < width) {
starts.push_back(t);
int unichar_id = best_nodes[t]->unichar_id;
if (unichar_id == UNICHAR_SPACE && !certs->empty() && best_nodes[t]->permuter != NO_PERM) {
// All the rating and certainty go on the previous character except
Expand All @@ -582,16 +590,18 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds(const std::vector<const RecodeNod
}
unichar_ids->push_back(unichar_id);
xcoords->push_back(t);
do {
double cert = best_nodes[t++]->certainty;
t++;
while (t < width && best_nodes[t]->duplicate) {
double cert = best_nodes[t]->certainty;
// Special-case NO-PERM space to forget the certainty of the previous
// nulls. See long comment in ContinueContext.
if (cert < certainty ||
(unichar_id == UNICHAR_SPACE && best_nodes[t - 1]->permuter == NO_PERM)) {
certainty = cert;
}
rating -= cert;
} while (t < width && best_nodes[t]->duplicate);
t++;
}
ends.push_back(t);
certs->push_back(certainty);
ratings->push_back(rating);
Expand Down Expand Up @@ -657,20 +667,48 @@ void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs, int to
}
}
}
float top_key = 0.0f;
float second_key = 0.0f;
bool found_first_code = false;
bool found_second_code = false;
while (!top_heap_.empty()) {
TopPair entry;
top_heap_.Pop(&entry);
if (in_possible_diplopia_ && entry.data() == first_diplopia_code_) {
found_first_code = true;
}
if (in_possible_diplopia_ && entry.data() == second_diplopia_code_) {
found_second_code = true;
}
if (top_heap_.size() > 1) {
top_n_flags_[entry.data()] = TN_TOPN;
} else {
top_n_flags_[entry.data()] = TN_TOP2;
if (top_heap_.empty()) {
top_code_ = entry.data();
top_key = entry.key();
} else {
second_code_ = entry.data();
second_key = entry.key();
}
}
}
// Update the possible-diplopia state: if one is in progress, determine
// whether it has ended; then, if none is in progress, determine whether
// one starts at this timestep.
stweil marked this conversation as resolved.
Show resolved Hide resolved
if (in_possible_diplopia_) {
if (!found_first_code && !found_second_code) {
in_possible_diplopia_ = false;
first_diplopia_code_ = -1;
second_diplopia_code_ = -1;
}
}
if (!in_possible_diplopia_) {
if (top_code_ != null_char_ && second_code_ != null_char_ && top_key > kMinDiplopiaKey && second_key > kMinDiplopiaKey) {
in_possible_diplopia_ = true;
first_diplopia_code_ = top_code_;
second_diplopia_code_ = second_code_;
}
}
top_n_flags_[null_char_] = TN_TOP2;
}

Expand Down Expand Up @@ -1138,6 +1176,10 @@ void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
if (UpdateHeapIfMatched(&node, heap)) {
return;
}
// check to see if node is possible diplopia
stweil marked this conversation as resolved.
Show resolved Hide resolved
if (!AddToHeapIsAllowed(&node)) {
return;
}
RecodePair entry(score, node);
heap->Push(&entry);
ASSERT_HOST(entry.data().dawgs == nullptr);
Expand Down Expand Up @@ -1189,6 +1231,21 @@ bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node, RecodeHeap *hea
return false;
}

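// Determines if new_node can be added to the heap. While in a possible
// diplopia, rejects a node whose code and whose prev's code are the two
// diplopia codes (in either order); otherwise allows it.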
stweil marked this conversation as resolved.
Show resolved Hide resolved
bool RecodeBeamSearch::AddToHeapIsAllowed(RecodeNode *new_node) {
  // Nothing is filtered unless a possible diplopia is currently tracked.
  if (!in_possible_diplopia_) {
    return true;
  }
  // A node without a predecessor cannot form a first/second code pair.
  const RecodeNode *parent = new_node->prev;
  if (parent == nullptr) {
    return true;
  }
  // Reject the node when it and its predecessor carry the two diplopia
  // codes, whichever of the two comes first.
  const bool first_then_second =
      parent->code == first_diplopia_code_ && new_node->code == second_diplopia_code_;
  const bool second_then_first =
      parent->code == second_diplopia_code_ && new_node->code == first_diplopia_code_;
  return !(first_then_second || second_then_first);
}

// Computes and returns the code-hash for the given code and prev.
uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup, const RecodeNode *prev) const {
uint64_t hash = prev == nullptr ? 0 : prev->code_hash;
Expand Down
8 changes: 8 additions & 0 deletions src/lstm/recodebeam.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ class TESS_API RecodeBeamSearch {
// Searches the heap for an entry matching new_node, and updates the entry
// with reshuffle if needed. Returns true if there was a match.
bool UpdateHeapIfMatched(RecodeNode *new_node, RecodeHeap *heap);
// Determines if new_node can be added to the heap for the current beam.
// Returns false only if we are in a possible diplopia situation and
// new_node and its prev carry the two diplopia codes (in either order).
stweil marked this conversation as resolved.
Show resolved Hide resolved
bool AddToHeapIsAllowed(RecodeNode *new_node);
// Computes and returns the code-hash for the given code and prev.
uint64_t ComputeCodeHash(int code, bool dup, const RecodeNode *prev) const;
// Backtracks to extract the best path through the lattice that was built
Expand Down Expand Up @@ -423,6 +426,11 @@ class TESS_API RecodeBeamSearch {
// True if the input is simple text, ie adjacent equal chars are not to be
// eliminated.
bool is_simple_text_;
// Variables used in tracking possible diplopia case
stweil marked this conversation as resolved.
Show resolved Hide resolved
stweil marked this conversation as resolved.
Show resolved Hide resolved
// Refer to ComputeTopN routine for use of these variables
stweil marked this conversation as resolved.
Show resolved Hide resolved
bool in_possible_diplopia_;
int first_diplopia_code_;
int second_diplopia_code_;
stweil marked this conversation as resolved.
Show resolved Hide resolved
// The encoded (class label) of the null/reject character.
int null_char_;
};
Expand Down