Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve deterministic intent parser #126

Merged
merged 1 commit into from Mar 6, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
244 changes: 189 additions & 55 deletions src/intent_parser/deterministic_intent_parser.rs
Expand Up @@ -5,8 +5,7 @@ use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;

use failure::ResultExt;
use itertools::Itertools;
use failure::{format_err, ResultExt};
use regex::{Regex, RegexBuilder};
use snips_nlu_ontology::{BuiltinEntityKind, IntentClassifierResult, Language};
use snips_nlu_utils::language::Language as NluUtilsLanguage;
Expand All @@ -25,13 +24,14 @@ use crate::utils::{
};

use super::{internal_parsing_result, IntentParser, InternalParsingResult};
use itertools::Itertools;

/// Intent parser that matches the input text against per-intent regular
/// expressions, after entities found in the text have been replaced by
/// placeholders.
///
/// NOTE(review): this listing looks like a diff view in which both the
/// pre-change field (`builtin_scope`) and the post-change field
/// (`entity_scopes`) are shown; confirm which of the two exists in the
/// actual source file.
pub struct DeterministicIntentParser {
/// Language parsed from the model's `language_code`.
language: Language,
/// Compiled regexes for each intent, built from the model's `patterns`.
regexes_per_intent: HashMap<IntentName, Vec<Regex>>,
/// Taken as-is from the model; presumably maps regex capture group names
/// to slot names (verify against the pattern-matching code).
group_names_to_slot_names: HashMap<String, SlotName>,
/// For each intent, the entity name associated with each slot name.
slot_names_to_entities: HashMap<IntentName, HashMap<SlotName, EntityName>>,
/// Builtin entity kinds collected across all intents and used to restrict
/// builtin entity extraction.
/// NOTE(review): appears to be superseded by `entity_scopes` -- confirm.
builtin_scope: Vec<BuiltinEntityKind>,
/// For each intent, the pair (builtin entity kinds, custom entity names)
/// derived from `slot_names_to_entities`; used to restrict entity
/// extraction to the entities that the intent actually references.
entity_scopes: HashMap<IntentName, (Vec<BuiltinEntityKind>, Vec<EntityName>)>,
/// Copied from the model configuration (`config.ignore_stop_words`).
ignore_stop_words: bool,
/// Shared resources giving access to the builtin and custom entity parsers.
shared_resources: Arc<SharedResources>,
}
Expand Down Expand Up @@ -60,22 +60,35 @@ impl DeterministicIntentParser {
shared_resources: Arc<SharedResources>,
) -> Result<Self> {
let language = Language::from_str(&model.language_code)?;
let builtin_scope = model
let entity_scopes = model
.slot_names_to_entities
.iter()
.flat_map(|(_, mapping)| {
mapping
.map(|(intent, mapping)| {
let builtin_entities = mapping
.iter()
.flat_map(|(_, entity)| BuiltinEntityKind::from_identifier(entity).ok())
.unique()
.collect();
let custom_entities = mapping
.iter()
.flat_map(|(_, entity)| {
if BuiltinEntityKind::from_identifier(entity).is_ok() {
None
} else {
Some(entity.to_string())
}
})
.unique()
.collect();
(intent.to_string(), (builtin_entities, custom_entities))
})
.unique()
.collect();
Ok(DeterministicIntentParser {
language,
regexes_per_intent: compile_regexes_per_intent(model.patterns)?,
group_names_to_slot_names: model.group_names_to_slot_names,
slot_names_to_entities: model.slot_names_to_entities,
builtin_scope,
entity_scopes,
ignore_stop_words: model.config.ignore_stop_words,
shared_resources,
})
Expand Down Expand Up @@ -145,59 +158,65 @@ impl DeterministicIntentParser {
top_n: usize,
intents: Option<&[&str]>,
) -> Result<Vec<InternalParsingResult>> {
let opt_intents_set: Option<HashSet<&str>> =
intents.map(|intent_list| intent_list.iter().cloned().collect());
let builtin_entities = self
.shared_resources
.builtin_entity_parser
.extract_entities(input, Some(self.builtin_scope.as_ref()), true)?
.into_iter()
.map(|entity| entity.into());

let custom_entities = self
.shared_resources
.custom_entity_parser
.extract_entities(input, None)?
.into_iter()
.map(|entity| entity.into());

let mut matched_entities: Vec<MatchedEntity> = vec![];
matched_entities.extend(builtin_entities);
matched_entities.extend(custom_entities);

let (ranges_mapping, formatted_input) =
replace_entities(input, matched_entities, get_entity_placeholder);
let cleaned_input = self.preprocess_text(input);
let cleaned_formatted_input = self.preprocess_text(&*formatted_input);

let mut results = vec![];

for (intent, regexes) in self.regexes_per_intent.iter() {
let intents_set: HashSet<&str> = intents
.map(|intent_list| intent_list.iter().map(|intent| *intent).collect())
.unwrap_or_else(|| {
self.slot_names_to_entities
.keys()
.map(|intent| &**intent)
.collect()
});
let filtered_entity_scopes = self
.entity_scopes
.iter()
.filter(|(intent, _)| intents_set.contains(&***intent));

for (intent, (builtin_scope, custom_scope)) in filtered_entity_scopes {
if results.len() == top_n {
break;
}
if !opt_intents_set
.as_ref()
.map(|intents_set| intents_set.contains(&**intent))
.unwrap_or(true)
{
continue;
}
for regex in regexes {
if let Some(matching_result_formatted) = self.get_matching_result(
input,
&*cleaned_formatted_input,
regex,
intent,
Some(&ranges_mapping),
) {
results.push(matching_result_formatted);
} else if let Some(matching_result) =
self.get_matching_result(input, &*cleaned_input, regex, intent, None)
{
results.push(matching_result);
}
}

let builtin_entities = self
.shared_resources
.builtin_entity_parser
.extract_entities(input, Some(builtin_scope.as_ref()), true)?
.into_iter()
.map(|entity| entity.into());

let custom_entities = self
.shared_resources
.custom_entity_parser
.extract_entities(input, Some(custom_scope.as_ref()))?
.into_iter()
.map(|entity| entity.into());

let mut matched_entities: Vec<MatchedEntity> = vec![];
matched_entities.extend(builtin_entities);
matched_entities.extend(custom_entities);

let (ranges_mapping, formatted_input) =
replace_entities(input, matched_entities, get_entity_placeholder);
let cleaned_formatted_input = self.preprocess_text(&*formatted_input);
self.regexes_per_intent
.get(intent)
.ok_or_else(|| format_err!("No associated regexes for intent '{}'", intent))?
.iter()
.find_map(|regex| {
self.get_matching_result(
input,
&*cleaned_formatted_input,
regex,
intent,
Some(&ranges_mapping),
)
.or_else(|| {
self.get_matching_result(input, &*cleaned_input, regex, intent, None)
})
})
.map(|matching_result_formatted| results.push(matching_result_formatted));
}
Ok(results)
}
Expand Down Expand Up @@ -363,6 +382,8 @@ mod tests {
use crate::testutils::*;

use super::*;
use crate::entity_parser::builtin_entity_parser::BuiltinEntityParser;
use crate::entity_parser::custom_entity_parser::CustomEntityParser;

fn test_configuration() -> DeterministicParserModel {
DeterministicParserModel {
Expand Down Expand Up @@ -611,6 +632,119 @@ mod tests {
assert_eq!(intent, expected_intent);
}

/// Checks that parsing succeeds when the entity parsers are able to return
/// entities belonging to several intents: both mocked parsers honour the
/// scope filter they receive, and the utterance is expected to resolve to
/// `dummy_intent_3`.
#[test]
fn test_parse_intent_with_entities_from_different_intents() {
    // Given
    // The only utterance the mocked entity parsers react to.
    const UTTERANCE: &str = "Send 10 dollars to John at the wall";

    let text = UTTERANCE;

    /// Builtin entity parser stub that knows a number and an amount of money.
    #[derive(Default)]
    pub struct MyMockedBuiltinEntityParser;

    impl BuiltinEntityParser for MyMockedBuiltinEntityParser {
        fn extract_entities(
            &self,
            sentence: &str,
            filter_entity_kinds: Option<&[BuiltinEntityKind]>,
            _use_cache: bool,
        ) -> Result<Vec<BuiltinEntity>> {
            // Any other sentence yields no entity at all.
            if sentence != UTTERANCE {
                return Ok(vec![]);
            }

            // Candidates are listed in the order they must be returned:
            // the number first, then the amount of money.
            let candidates = vec![
                BuiltinEntity {
                    value: "10".to_string(),
                    range: 5..7,
                    entity: SlotValue::Number(NumberValue { value: 10. }),
                    entity_kind: BuiltinEntityKind::Number,
                },
                BuiltinEntity {
                    value: "10 dollars".to_string(),
                    range: 5..15,
                    entity: SlotValue::AmountOfMoney(AmountOfMoneyValue {
                        value: 10.,
                        precision: Precision::Exact,
                        unit: Some("dollars".to_string()),
                    }),
                    entity_kind: BuiltinEntityKind::AmountOfMoney,
                },
            ];

            // Without a filter every candidate is returned; otherwise only
            // the candidates whose kind is part of the requested scope.
            let entities: Vec<BuiltinEntity> = match filter_entity_kinds {
                Some(kinds) => candidates
                    .into_iter()
                    .filter(|candidate| kinds.contains(&candidate.entity_kind))
                    .collect(),
                None => candidates,
            };
            Ok(entities)
        }
    }

    /// Custom entity parser stub that knows two custom entities.
    #[derive(Default)]
    pub struct MyMockedCustomEntityParser;

    impl CustomEntityParser for MyMockedCustomEntityParser {
        fn extract_entities(
            &self,
            sentence: &str,
            filter_entity_kinds: Option<&[String]>,
        ) -> Result<Vec<CustomEntity>> {
            // Any other sentence yields no entity at all.
            if sentence != UTTERANCE {
                return Ok(vec![]);
            }

            // Candidates are listed in the order they must be returned.
            let candidates = vec![
                CustomEntity {
                    value: "John".to_string(),
                    resolved_value: "John".to_string(),
                    range: 19..23,
                    entity_identifier: "dummy_entity_1".to_string(),
                },
                CustomEntity {
                    value: "the wall".to_string(),
                    resolved_value: "the wall".to_string(),
                    range: 27..35,
                    entity_identifier: "dummy_entity_2".to_string(),
                },
            ];

            // Same scoping rule as the builtin stub, keyed on the entity
            // identifier instead of the builtin kind.
            let entities: Vec<CustomEntity> = match filter_entity_kinds {
                Some(identifiers) => candidates
                    .into_iter()
                    .filter(|candidate| identifiers.contains(&candidate.entity_identifier))
                    .collect(),
                None => candidates,
            };
            Ok(entities)
        }
    }

    let shared_resources = SharedResourcesBuilder::default()
        .builtin_entity_parser(MyMockedBuiltinEntityParser::default())
        .custom_entity_parser(MyMockedCustomEntityParser::default())
        .build();
    let parser =
        DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources)).unwrap();

    // When
    let parsing_result = parser.parse(text, None).unwrap();
    let intent = parsing_result.intent;

    // Then
    let expected_intent = IntentClassifierResult {
        intent_name: Some("dummy_intent_3".to_string()),
        confidence_score: 1.0,
    };

    assert_eq!(intent, expected_intent);
}

#[test]
fn test_parse_utterance_with_duplicated_slot_name() {
// Given
Expand Down