From 70e664e447441ef2e9533d7f1b330e6e66374025 Mon Sep 17 00:00:00 2001 From: Adrien Ball Date: Wed, 6 Mar 2019 16:14:34 +0100 Subject: [PATCH] Leverage entity scopes of each intent in deterministic intent parser --- .../deterministic_intent_parser.rs | 244 ++++++++++++++---- 1 file changed, 189 insertions(+), 55 deletions(-) diff --git a/src/intent_parser/deterministic_intent_parser.rs b/src/intent_parser/deterministic_intent_parser.rs index b4a49cdc..b2798962 100644 --- a/src/intent_parser/deterministic_intent_parser.rs +++ b/src/intent_parser/deterministic_intent_parser.rs @@ -5,8 +5,7 @@ use std::path::Path; use std::str::FromStr; use std::sync::Arc; -use failure::ResultExt; -use itertools::Itertools; +use failure::{format_err, ResultExt}; use regex::{Regex, RegexBuilder}; use snips_nlu_ontology::{BuiltinEntityKind, IntentClassifierResult, Language}; use snips_nlu_utils::language::Language as NluUtilsLanguage; @@ -25,13 +24,14 @@ use crate::utils::{ }; use super::{internal_parsing_result, IntentParser, InternalParsingResult}; +use itertools::Itertools; pub struct DeterministicIntentParser { language: Language, regexes_per_intent: HashMap>, group_names_to_slot_names: HashMap, slot_names_to_entities: HashMap>, - builtin_scope: Vec, + entity_scopes: HashMap, Vec)>, ignore_stop_words: bool, shared_resources: Arc, } @@ -60,22 +60,35 @@ impl DeterministicIntentParser { shared_resources: Arc, ) -> Result { let language = Language::from_str(&model.language_code)?; - let builtin_scope = model + let entity_scopes = model .slot_names_to_entities .iter() - .flat_map(|(_, mapping)| { - mapping + .map(|(intent, mapping)| { + let builtin_entities = mapping .iter() .flat_map(|(_, entity)| BuiltinEntityKind::from_identifier(entity).ok()) + .unique() + .collect(); + let custom_entities = mapping + .iter() + .flat_map(|(_, entity)| { + if BuiltinEntityKind::from_identifier(entity).is_ok() { + None + } else { + Some(entity.to_string()) + } + }) + .unique() + .collect(); + (intent.to_string(), (builtin_entities, custom_entities)) }) - .unique() .collect(); Ok(DeterministicIntentParser { language, regexes_per_intent: compile_regexes_per_intent(model.patterns)?, group_names_to_slot_names: model.group_names_to_slot_names, slot_names_to_entities: model.slot_names_to_entities, - builtin_scope, + entity_scopes, ignore_stop_words: model.config.ignore_stop_words, shared_resources, }) @@ -145,59 +158,65 @@ impl DeterministicIntentParser { top_n: usize, intents: Option<&[&str]>, ) -> Result> { - let opt_intents_set: Option> = - intents.map(|intent_list| intent_list.iter().cloned().collect()); - let builtin_entities = self - .shared_resources - .builtin_entity_parser - .extract_entities(input, Some(self.builtin_scope.as_ref()), true)? - .into_iter() - .map(|entity| entity.into()); - - let custom_entities = self - .shared_resources - .custom_entity_parser - .extract_entities(input, None)? - .into_iter() - .map(|entity| entity.into()); - - let mut matched_entities: Vec = vec![]; - matched_entities.extend(builtin_entities); - matched_entities.extend(custom_entities); - - let (ranges_mapping, formatted_input) = - replace_entities(input, matched_entities, get_entity_placeholder); let cleaned_input = self.preprocess_text(input); - let cleaned_formatted_input = self.preprocess_text(&*formatted_input); - let mut results = vec![]; - for (intent, regexes) in self.regexes_per_intent.iter() { + let intents_set: HashSet<&str> = intents + .map(|intent_list| intent_list.iter().map(|intent| *intent).collect()) + .unwrap_or_else(|| { + self.slot_names_to_entities + .keys() + .map(|intent| &**intent) + .collect() + }); + let filtered_entity_scopes = self + .entity_scopes + .iter() + .filter(|(intent, _)| intents_set.contains(&***intent)); + + for (intent, (builtin_scope, custom_scope)) in filtered_entity_scopes { if results.len() == top_n { break; } - if !opt_intents_set - .as_ref() - .map(|intents_set| intents_set.contains(&**intent)) - .unwrap_or(true) - { - continue; - } - for regex in regexes { - if let Some(matching_result_formatted) = self.get_matching_result( - input, - &*cleaned_formatted_input, - regex, - intent, - Some(&ranges_mapping), - ) { - results.push(matching_result_formatted); - } else if let Some(matching_result) = - self.get_matching_result(input, &*cleaned_input, regex, intent, None) - { - results.push(matching_result); - } - } + + let builtin_entities = self + .shared_resources + .builtin_entity_parser + .extract_entities(input, Some(builtin_scope.as_ref()), true)? + .into_iter() + .map(|entity| entity.into()); + + let custom_entities = self + .shared_resources + .custom_entity_parser + .extract_entities(input, Some(custom_scope.as_ref()))? + .into_iter() + .map(|entity| entity.into()); + + let mut matched_entities: Vec = vec![]; + matched_entities.extend(builtin_entities); + matched_entities.extend(custom_entities); + + let (ranges_mapping, formatted_input) = + replace_entities(input, matched_entities, get_entity_placeholder); + let cleaned_formatted_input = self.preprocess_text(&*formatted_input); + self.regexes_per_intent + .get(intent) + .ok_or_else(|| format_err!("No associated regexes for intent '{}'", intent))? + .iter() + .find_map(|regex| { + self.get_matching_result( + input, + &*cleaned_formatted_input, + regex, + intent, + Some(&ranges_mapping), + ) + .or_else(|| { + self.get_matching_result(input, &*cleaned_input, regex, intent, None) + }) + }) + .map(|matching_result_formatted| results.push(matching_result_formatted)); } Ok(results) } @@ -363,6 +382,8 @@ mod tests { use crate::testutils::*; use super::*; + use crate::entity_parser::builtin_entity_parser::BuiltinEntityParser; + use crate::entity_parser::custom_entity_parser::CustomEntityParser; fn test_configuration() -> DeterministicParserModel { DeterministicParserModel { @@ -611,6 +632,119 @@ mod tests { assert_eq!(intent, expected_intent); } + #[test] + fn test_parse_intent_with_entities_from_different_intents() { + // Given + let text = "Send 10 dollars to John at the wall"; + + #[derive(Default)] + pub struct MyMockedBuiltinEntityParser; + + impl BuiltinEntityParser for MyMockedBuiltinEntityParser { + fn extract_entities( + &self, + sentence: &str, + filter_entity_kinds: Option<&[BuiltinEntityKind]>, + _use_cache: bool, + ) -> Result> { + let mocked_builtin_entity_number = BuiltinEntity { + value: "10".to_string(), + range: 5..7, + entity: SlotValue::Number(NumberValue { value: 10. }), + entity_kind: BuiltinEntityKind::Number, + }; + let mocked_builtin_entity_money = BuiltinEntity { + value: "10 dollars".to_string(), + range: 5..15, + entity: SlotValue::AmountOfMoney(AmountOfMoneyValue { + value: 10., + precision: Precision::Exact, + unit: Some("dollars".to_string()), + }), + entity_kind: BuiltinEntityKind::AmountOfMoney, + }; + if sentence != "Send 10 dollars to John at the wall" { + return Ok(vec![]); + } + Ok(filter_entity_kinds + .map(|entity_kinds| { + let mut entities = vec![]; + if entity_kinds.contains(&mocked_builtin_entity_number.entity_kind) { + entities.push(mocked_builtin_entity_number.clone()) + }; + if entity_kinds.contains(&mocked_builtin_entity_money.entity_kind) { + entities.push(mocked_builtin_entity_money.clone()) + }; + entities + }) + .unwrap_or_else(|| { + vec![mocked_builtin_entity_number, mocked_builtin_entity_money] + })) + } + } + + #[derive(Default)] + pub struct MyMockedCustomEntityParser; + + impl CustomEntityParser for MyMockedCustomEntityParser { + fn extract_entities( + &self, + sentence: &str, + filter_entity_kinds: Option<&[String]>, + ) -> Result> { + let mocked_custom_entity_1 = CustomEntity { + value: "John".to_string(), + resolved_value: "John".to_string(), + range: 19..23, + entity_identifier: "dummy_entity_1".to_string(), + }; + let mocked_custom_entity_2 = CustomEntity { + value: "the wall".to_string(), + resolved_value: "the wall".to_string(), + range: 27..35, + entity_identifier: "dummy_entity_2".to_string(), + }; + if sentence != "Send 10 dollars to John at the wall" { + return Ok(vec![]); + } + Ok(filter_entity_kinds + .map(|entity_kinds| { + let mut entities = vec![]; + if entity_kinds.contains(&mocked_custom_entity_1.entity_identifier) { + entities.push(mocked_custom_entity_1.clone()) + }; + if entity_kinds.contains(&mocked_custom_entity_2.entity_identifier) { + entities.push(mocked_custom_entity_2.clone()) + }; + entities + }) + .unwrap_or_else(|| vec![mocked_custom_entity_1, mocked_custom_entity_2])) + } + } + + let my_mocked_builtin_entity_parser = MyMockedBuiltinEntityParser {}; + let my_mocked_custom_entity_parser = MyMockedCustomEntityParser {}; + + let shared_resources = SharedResourcesBuilder::default() + .builtin_entity_parser(my_mocked_builtin_entity_parser) + .custom_entity_parser(my_mocked_custom_entity_parser) + .build(); + let parser = + DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources)) + .unwrap(); + + // When + let intent = parser.parse(text, None).unwrap().intent; + + // Then + let expected_intent = IntentClassifierResult { + intent_name: Some("dummy_intent_3".to_string()), + confidence_score: 1.0, + }; + + assert_eq!(intent, expected_intent); + } + #[test] fn test_parse_utterance_with_duplicated_slot_name() { // Given