Skip to content

Commit

Permalink
Reduce locks of RwLocks for language models
Browse files Browse the repository at this point in the history
  • Loading branch information
pemistahl committed May 24, 2023
1 parent a93ef8c commit c0f7e71
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 58 deletions.
201 changes: 144 additions & 57 deletions src/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ use crate::constant::{
use crate::json::load_json;
use crate::language::Language;
use crate::model::{TestDataLanguageModel, TrainingDataLanguageModel};
use crate::ngram::NgramRef;
use crate::result::DetectionResult;

type LazyLanguageModelMap = Lazy<RwLock<HashMap<Language, HashMap<String, f64>>>>;
type StaticLanguageModelMap = &'static RwLock<HashMap<Language, HashMap<String, f64>>>;
type LanguageModelArray<'a> = [Option<&'a HashMap<Language, HashMap<String, f64>>>; 5];

static UNIGRAM_MODELS: LazyLanguageModelMap = Lazy::new(|| RwLock::new(HashMap::new()));
static BIGRAM_MODELS: LazyLanguageModelMap = Lazy::new(|| RwLock::new(HashMap::new()));
Expand Down Expand Up @@ -658,42 +658,108 @@ impl LanguageDetector {
}
}

/// Loads every language model for ngram lengths `1..=ngram_length` and the
/// given `filtered_languages`, then invokes `callback_handler` with an array
/// of read-locked references to those models.
///
/// Each static `RwLock` is read-locked exactly once per call and the guard is
/// held only for the duration of the callback, instead of locking once per
/// ngram lookup.
fn get_language_models<R>(
    &self,
    ngram_length: usize,
    filtered_languages: &HashSet<Language>,
    callback_handler: impl FnOnce(LanguageModelArray) -> R,
) -> R {
    // The five static model maps, indexed by ngram length - 1.
    let model_maps = [
        self.unigram_language_models,
        self.bigram_language_models,
        self.trigram_language_models,
        self.quadrigram_language_models,
        self.fivegram_language_models,
    ];

    let mut read_guards = [None, None, None, None, None];

    for (index, &model_map) in model_maps.iter().enumerate() {
        let current_ngram_length = index + 1;
        if current_ngram_length > ngram_length {
            break;
        }
        // Make sure the models for all requested languages are present
        // before taking the read lock, so the lock below is the only one
        // acquired here per map.
        for language in filtered_languages {
            self.load_language_models(model_map, language, current_ngram_length);
        }
        read_guards[index] = Some(model_map.read().unwrap());
    }

    // Hand the callback plain `&HashMap` references derived from the guards;
    // the guards themselves stay alive until this function returns.
    callback_handler([
        read_guards[0].as_deref(),
        read_guards[1].as_deref(),
        read_guards[2].as_deref(),
        read_guards[3].as_deref(),
        read_guards[4].as_deref(),
    ])
}

fn look_up_language_models(
&self,
words: &[String],
ngram_length: usize,
filtered_languages: &HashSet<Language>,
) -> (HashMap<Language, f64>, Option<HashMap<Language, u32>>) {
let test_data_model = TestDataLanguageModel::from(words, ngram_length);
let probabilities =
self.compute_language_probabilities(&test_data_model, filtered_languages);
let unigram_counts = if ngram_length == 1 {
let languages = probabilities.keys().collect_vec();
let intersected_languages = if !languages.is_empty() {
filtered_languages
.iter()
.cloned()
.filter(|it| languages.contains(&it))
.collect()

self.get_language_models(ngram_length, filtered_languages, |language_models| {
let probabilities = self.compute_language_probabilities(
&test_data_model,
filtered_languages,
&language_models,
);

let unigram_counts = if ngram_length == 1 {
let languages = probabilities.keys().collect_vec();
let intersected_languages = if !languages.is_empty() {
filtered_languages
.iter()
.cloned()
.filter(|it| languages.contains(&it))
.collect()
} else {
filtered_languages.clone()
};
Some(self.count_unigrams(
&test_data_model,
&intersected_languages,
language_models[0].unwrap(),
))
} else {
filtered_languages.clone()
None
};
Some(self.count_unigrams(&test_data_model, &intersected_languages))
} else {
None
};

(probabilities, unigram_counts)
(probabilities, unigram_counts)
})
}

fn compute_language_probabilities(
&self,
model: &TestDataLanguageModel,
filtered_languages: &HashSet<Language>,
language_models: &LanguageModelArray,
) -> HashMap<Language, f64> {
let mut probabilities = hashmap!();
for language in filtered_languages.iter() {
let sum = self.compute_sum_of_ngram_probabilities(language, model);
let sum = self.compute_sum_of_ngram_probabilities(language, model, language_models);
if sum < 0.0 {
probabilities.insert(language.clone(), sum);
}
Expand Down Expand Up @@ -746,11 +812,22 @@ impl LanguageDetector {
&self,
language: &Language,
ngram_model: &TestDataLanguageModel,
language_models: &LanguageModelArray,
) -> f64 {
let models = [
language_models[0].as_ref().and_then(|m| m.get(language)),
language_models[1].as_ref().and_then(|m| m.get(language)),
language_models[2].as_ref().and_then(|m| m.get(language)),
language_models[3].as_ref().and_then(|m| m.get(language)),
language_models[4].as_ref().and_then(|m| m.get(language)),
];
let mut sum = 0.0;
for ngrams in ngram_model.ngrams.iter() {
for ngram in ngrams {
let probability = self.look_up_ngram_probability(language, ngram);
let probability = models[ngram.char_count - 1]
.and_then(|m| m.get(ngram.value))
.copied()
.unwrap_or(0.0);

if probability > 0.0 {
sum += probability.ln();
Expand All @@ -761,38 +838,26 @@ impl LanguageDetector {
sum
}

// Pre-change implementation (removed by this commit): looks up the probability
// of a single ngram for a language, selecting the static model map that
// matches the ngram's character count.
// NOTE(review): this takes the RwLock read lock once per ngram lookup, which
// is exactly the per-lookup locking the commit eliminates.
fn look_up_ngram_probability(&self, language: &Language, ngram: &NgramRef) -> f64 {
// Pick the static model map for this ngram length (1..=5 supported).
let ngram_length = ngram.value.chars().count();
let language_models = match ngram_length {
5 => self.fivegram_language_models,
4 => self.quadrigram_language_models,
3 => self.trigram_language_models,
2 => self.bigram_language_models,
1 => self.unigram_language_models,
0 => panic!("zerogram detected"),
_ => panic!(
"unsupported ngram length detected: {}",
ngram.value.chars().count()
),
};

// Lazily load the model for this language before reading.
self.load_language_models(language_models, language, ngram_length);

// Missing language model or unknown ngram both yield probability 0.0.
match language_models.read().unwrap().get(language) {
Some(model) => *model.get(ngram.value).unwrap_or(&0.0),
None => 0.0,
}
}

fn count_unigrams(
&self,
unigram_model: &TestDataLanguageModel,
filtered_languages: &HashSet<Language>,
language_models: &HashMap<Language, HashMap<String, f64>>,
) -> HashMap<Language, u32> {
let mut unigram_counts = HashMap::new();
for language in filtered_languages.iter() {
let model = match language_models.get(language) {
Some(model) => model,
None => continue,
};

for unigrams in unigram_model.ngrams.iter() {
if self.look_up_ngram_probability(language, unigrams.get(0).unwrap()) > 0.0 {
let probability = model
.get(unigrams.get(0).unwrap().value)
.copied()
.unwrap_or(0.0);

if probability > 0.0 {
self.increment_counter(&mut unigram_counts, language.clone());
}
}
Expand Down Expand Up @@ -938,6 +1003,7 @@ mod tests {
use rstest::*;

use crate::language::Language::*;
use crate::ngram::NgramRef;

use super::*;

Expand Down Expand Up @@ -1257,23 +1323,28 @@ mod tests {
ngram: &str,
expected_probability: f64,
) {
let probability = detector_for_english_and_german
.look_up_ngram_probability(&language, &NgramRef::new(ngram));
let ngram_length = ngram.chars().count();
let probability = detector_for_english_and_german.get_language_models(
ngram_length,
&hashset!(language.clone()),
|language_models| {
language_models[ngram_length - 1]
.unwrap()
.get(&language)
.unwrap()
.get(ngram)
.copied()
.unwrap_or(0.0)
},
);

assert_eq!(
probability, expected_probability,
"expected probability {} for language '{:?}' and ngram '{}', got {}",
expected_probability, language, ngram, probability
);
}

// Verifies that looking up an empty (zero-length) ngram panics with the
// message "zerogram detected".
// NOTE(review): exercises look_up_ngram_probability, which this commit
// removes — the test is deleted along with it.
#[rstest]
#[should_panic(expected = "zerogram detected")]
fn assert_ngram_probability_lookup_does_not_work_for_zerogram(
detector_for_english_and_german: LanguageDetector,
) {
detector_for_english_and_german.look_up_ngram_probability(&English, &NgramRef::new(""));
}

#[rstest(
test_data_model,
expected_sum_of_probabilities,
Expand All @@ -1297,8 +1368,17 @@ mod tests {
test_data_model: TestDataLanguageModel,
expected_sum_of_probabilities: f64,
) {
let sum_of_probabilities = detector_for_english_and_german
.compute_sum_of_ngram_probabilities(&English, &test_data_model);
let sum_of_probabilities = detector_for_english_and_german.get_language_models(
5,
&hashset!(English),
|language_models| {
detector_for_english_and_german.compute_sum_of_ngram_probabilities(
&English,
&test_data_model,
&language_models,
)
},
);

assert!(
approx_eq!(
Expand Down Expand Up @@ -1345,8 +1425,15 @@ mod tests {
test_data_model: TestDataLanguageModel,
expected_probabilities: HashMap<Language, f64>,
) {
let probabilities = detector_for_english_and_german
.compute_language_probabilities(&test_data_model, &hashset!(English, German));
let languages = hashset!(English, German);
let probabilities =
detector_for_english_and_german.get_language_models(5, &languages, |language_models| {
detector_for_english_and_german.compute_language_probabilities(
&test_data_model,
&languages,
&language_models,
)
});

for (language, probability) in probabilities {
let expected_probability = expected_probabilities[&language];
Expand Down
2 changes: 1 addition & 1 deletion src/ngram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<'de> Deserialize<'de> for Ngram {
/// A borrowed view of an ngram: the string slice plus its precomputed
/// character count.
///
/// The diff interleave in this span kept both the old private `char_count`
/// line and the new `pub(crate)` one, producing a duplicate field; the
/// post-commit version is kept here.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(crate) struct NgramRef<'a> {
    pub(crate) value: &'a str,
    // Made pub(crate) so the detector can index its per-length model array
    // by `char_count - 1` without recounting characters.
    pub(crate) char_count: usize,
}

impl<'a> NgramRef<'a> {
Expand Down

0 comments on commit c0f7e71

Please sign in to comment.