diff --git a/.gitignore b/.gitignore index aa43b2f..d73a712 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ - /target **/*.rs.bk Cargo.lock diff --git a/CHANGELOG.md b/CHANGELOG.md index ebad102..2136d07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog All notable changes to this project will be documented in this file. +### [Unreleased] +### Added +- Add `resolved_value_id` property to `EntityValue` and `ParsedValue` structs [#34](https://github.com/snipsco/gazetteer-entity-parser/pull/34) + ## [0.7.0] - 2019-04-16 ### Added - Add API to prepend entity values [#31](https://github.com/snipsco/gazetteer-entity-parser/pull/31) diff --git a/Cargo.toml b/Cargo.toml index 87981e7..5f3a45c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "gazetteer-entity-parser" version = "0.7.0" authors = ["Alaa Saade "] +edition = "2018" [profile.bench] debug = true diff --git a/README.rst b/README.rst index 8f5a48d..feb518e 100644 --- a/README.rst +++ b/README.rst @@ -21,22 +21,27 @@ Example .add_value(EntityValue { raw_value: "king of pop".to_string(), resolved_value: "Michael Jackson".to_string(), + resolved_value_id: None, }) .add_value(EntityValue { raw_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: Some("artist_id_42".to_string()), }) .add_value(EntityValue { raw_value: "the fab four".to_string(), resolved_value: "The Beatles".to_string(), + resolved_value_id: None, }) .add_value(EntityValue { raw_value: "queen of soul".to_string(), resolved_value: "Aretha Franklin".to_string(), + resolved_value_id: None, }) .add_value(EntityValue { raw_value: "the red hot chili peppers".to_string(), resolved_value: "The Red Hot Chili Peppers".to_string(), + resolved_value_id: None, }) .minimum_tokens_ratio(2. / 3.) .build() @@ -50,12 +55,14 @@ Example raw_value: "the stones".to_string(), matched_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: Some("artist_id_42".to_string()), range: 25..35, }, ParsedValue { raw_value: "fab four".to_string(), matched_value: "the fab four".to_string(), resolved_value: "The Beatles".to_string(), + resolved_value_id: None, range: 52..60, }]); } diff --git a/benches/bench_parser.rs b/benches/bench_parser.rs index 7ab5ef4..96e41d9 100644 --- a/benches/bench_parser.rs +++ b/benches/bench_parser.rs @@ -81,6 +81,7 @@ fn generate_random_gazetteer( .take(nb_entity_values) .map(|string| EntityValue { resolved_value: string.to_lowercase(), + resolved_value_id: None, raw_value: string, }) .collect(); @@ -106,11 +107,11 @@ fn generate_random_parser( } fn get_low_redundancy_parser() -> (Parser, RandomStringGenerator) { - generate_random_parser(10000, 100000, 10, 0.5, 50) + generate_random_parser(10_000, 100_000, 10, 0.5, 50) } fn get_high_redundancy_parser() -> (Parser, RandomStringGenerator) { - generate_random_parser(100, 100000, 5, 0.5, 50) + generate_random_parser(100, 100_000, 5, 0.5, 50) } fn parsing_low_redundancy(c: &mut Criterion) { diff --git a/examples/entity_parsing_from_scratch.rs b/examples/entity_parsing_from_scratch.rs index 103592a..617027d 100644 --- a/examples/entity_parsing_from_scratch.rs +++ b/examples/entity_parsing_from_scratch.rs @@ -1,5 +1,3 @@ -extern crate gazetteer_entity_parser; - use gazetteer_entity_parser::*; fn main() { @@ -7,22 +5,27 @@ fn main() { .add_value(EntityValue { raw_value: "king of pop".to_string(), resolved_value: "Michael Jackson".to_string(), + resolved_value_id: None, }) .add_value(EntityValue { raw_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: Some("id42".to_string()), }) .add_value(EntityValue { raw_value: "the fab four".to_string(), resolved_value: "The Beatles".to_string(), + resolved_value_id: None, }) .add_value(EntityValue { raw_value: "queen of soul".to_string(), resolved_value: "Aretha Franklin".to_string(), + resolved_value_id: None, }) .add_value(EntityValue { raw_value: "the red hot chili peppers".to_string(), resolved_value: "The Red Hot Chili Peppers".to_string(), + resolved_value_id: None, }) .minimum_tokens_ratio(2. / 3.) .build() @@ -30,18 +33,23 @@ fn main() { let sentence = "My favourite artists are the stones and fab four"; let extracted_entities = parser.run(sentence).unwrap(); - assert_eq!(extracted_entities, - vec![ - ParsedValue { - raw_value: "the stones".to_string(), - resolved_value: "The Rolling Stones".to_string(), - range: 25..35, - matched_value: "the rolling stones".to_string() - }, - ParsedValue { - raw_value: "fab four".to_string(), - resolved_value: "The Beatles".to_string(), - range: 40..48, - matched_value: "the fab four".to_string(), - }]); + assert_eq!( + extracted_entities, + vec![ + ParsedValue { + raw_value: "the stones".to_string(), + resolved_value: "The Rolling Stones".to_string(), + range: 25..35, + matched_value: "the rolling stones".to_string(), + resolved_value_id: Some("id42".to_string()), + }, + ParsedValue { + raw_value: "fab four".to_string(), + resolved_value: "The Beatles".to_string(), + range: 40..48, + matched_value: "the fab four".to_string(), + resolved_value_id: None, + } + ] + ); } diff --git a/examples/interactive_parsing_cli.rs b/examples/interactive_parsing_cli.rs index ef0690d..dcd6140 100644 --- a/examples/interactive_parsing_cli.rs +++ b/examples/interactive_parsing_cli.rs @@ -1,20 +1,20 @@ -extern crate clap; -extern crate serde_json; -extern crate gazetteer_entity_parser; - -use clap::{Arg, App}; use std::io; use std::io::Write; + +use clap::{App, Arg}; + use gazetteer_entity_parser::Parser; fn main() { let matches = App::new("gazetteer-entity-parser-demo") .about("Interactive CLI for parsing gazetteer entities") - .arg(Arg::with_name("PARSER_DIR") - .required(true) - .takes_value(true) - .index(1) - .help("path to the parser directory")) + .arg( + Arg::with_name("PARSER_DIR") + .required(true) + .takes_value(true) + .index(1) + .help("path to the parser directory"), + ) .get_matches(); let parser_dir = matches.value_of("PARSER_DIR").unwrap(); diff --git a/src/data.rs b/src/data.rs index 3c6b5fe..e3d0c17 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,11 +1,14 @@ use std::result::Result; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_derive::*; /// Struct representing the value of an entity to be added to the parser #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] pub struct EntityValue { pub resolved_value: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub resolved_value_id: Option, pub raw_value: String, } diff --git a/src/lib.rs b/src/lib.rs index 2c29a31..c346bf6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,22 +17,27 @@ //! // We fill the gazetteer with artists, sorted by popularity //! gazetteer.add(EntityValue { //! resolved_value: "The Rolling Stones".to_string(), +//! resolved_value_id: Some("id42".to_string()), //! raw_value: "the rolling stones".to_string(), //! }); //! gazetteer.add(EntityValue { //! resolved_value: "The Strokes".to_string(), +//! resolved_value_id: None, //! raw_value: "the strokes".to_string(), //! }); //! gazetteer.add(EntityValue { //! resolved_value: "The Hives".to_string(), +//! resolved_value_id: None, //! raw_value: "the hives".to_string(), //! }); //! gazetteer.add(EntityValue { //! resolved_value: "Jacques Brel".to_string(), +//! resolved_value_id: None, //! raw_value: "jacques brel".to_string(), //! }); //! gazetteer.add(EntityValue { //! resolved_value: "Daniel Brel".to_string(), +//! resolved_value_id: None, //! raw_value: "daniel brel".to_string(), //! }); //! @@ -53,6 +58,7 @@ //! vec![ParsedValue { //! raw_value: "the stones".to_string(), //! resolved_value: "The Rolling Stones".to_string(), +//! resolved_value_id: Some("id42".to_string()), //! matched_value: "the rolling stones".to_string(), //! range: 20..30, //! }] @@ -65,21 +71,14 @@ //! vec![ParsedValue { //! raw_value: "brel".to_string(), //! resolved_value: "Jacques Brel".to_string(), +//! resolved_value_id: None, //! matched_value: "jacques brel".to_string(), //! range: 20..24, //! }] //! ); //!``` -#[macro_use] -extern crate failure; -extern crate fnv; -extern crate rmp_serde as rmps; -extern crate serde; -extern crate serde_json; - -#[macro_use] -extern crate serde_derive; +#![allow(clippy::range_plus_one, clippy::float_cmp)] mod constants; mod data; diff --git a/src/parser.rs b/src/parser.rs index 5f21f29..ad23a1c 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,19 +1,22 @@ -use constants::*; -use data::EntityValue; -use errors::*; -use failure::ResultExt; -use fnv::{FnvHashMap as HashMap, FnvHashSet as HashSet}; -use rmps::{from_read, Serializer}; -use serde::Serialize; -use serde_json; use std::cmp::Ordering; use std::collections::hash_map::Entry; use std::collections::{BTreeSet, BinaryHeap}; use std::fs; use std::ops::Range; use std::path::Path; -use symbol_table::{ResolvedSymbolTable, TokenSymbolTable}; -use utils::{check_threshold, whitespace_tokenizer}; + +use failure::{format_err, ResultExt}; +use fnv::{FnvHashMap as HashMap, FnvHashSet as HashSet}; +use rmp_serde::{from_read, Serializer}; +use serde::Serialize; +use serde_derive::*; +use serde_json; + +use crate::constants::*; +use crate::data::EntityValue; +use crate::errors::*; +use crate::symbol_table::{ResolvedSymbolTable, TokenSymbolTable}; +use crate::utils::{check_threshold, whitespace_tokenizer}; /// Struct representing the parser. The Parser will match the longest possible contiguous /// substrings of a query that match partial entity values. The order in which the values are @@ -79,19 +82,18 @@ impl PossibleMatch { } impl Ord for PossibleMatch { + #[allow(clippy::if_same_then_else)] fn cmp(&self, other: &PossibleMatch) -> Ordering { if self.n_consumed_tokens < other.n_consumed_tokens { Ordering::Less } else if self.n_consumed_tokens > other.n_consumed_tokens { Ordering::Greater + } else if self.raw_value_length < other.raw_value_length { + Ordering::Greater + } else if self.raw_value_length > other.raw_value_length { + Ordering::Less } else { - if self.raw_value_length < other.raw_value_length { - Ordering::Greater - } else if self.raw_value_length > other.raw_value_length { - Ordering::Less - } else { - other.rank.cmp(&self.rank) - } + other.rank.cmp(&self.rank) } } } @@ -108,6 +110,7 @@ impl PartialOrd for PossibleMatch { #[derive(Debug, PartialEq, Eq, Serialize)] pub struct ParsedValue { pub resolved_value: String, + pub resolved_value_id: Option, // character-level pub range: Range, pub raw_value: String, @@ -146,7 +149,7 @@ impl Parser { // we duplicate it to allow several raw values to map to it let res_value_idx = self .resolved_symbol_table - .add_symbol(entity_value.resolved_value); + .add_symbol(entity_value.resolved_value, entity_value.resolved_value_id); for (_, token) in whitespace_tokenizer(&entity_value.raw_value) { let token_idx = self.tokens_symbol_table.add_symbol(token); @@ -170,13 +173,10 @@ impl Parser { /// Prepend a list of entity values to the parser and update the ranks accordingly pub fn prepend_values(&mut self, entity_values: Vec) { // update rank of previous values - for res_val in self.resolved_symbol_table.get_all_indices() { - self.resolved_value_to_tokens - .entry(*res_val) - .and_modify(|(rank, _)| *rank += entity_values.len() as u32); - } + self.shift_ranks(entity_values.len() as u32); + for (rank, entity_value) in entity_values.into_iter().enumerate() { - self.add_value(entity_value.clone(), rank as u32); + self.add_value(entity_value, rank as u32); } // Update the stop words and edge cases @@ -206,7 +206,7 @@ impl Parser { let mut tokens_with_counts = self .token_to_resolved_values .iter() - .map(|(idx, res_values)| (idx.clone(), res_values.len())) + .map(|(idx, res_values)| (*idx, res_values.len())) .collect::>(); tokens_with_counts.sort_by_key(|&(_, count)| -(count as i32)); @@ -226,7 +226,7 @@ impl Parser { self.stop_words.insert(tok_idx); self.token_to_resolved_values .entry(tok_idx) - .or_insert_with(|| BTreeSet::new()); + .or_insert_with(BTreeSet::new); } } @@ -243,7 +243,11 @@ impl Parser { pub fn get_stop_words(&self) -> HashSet { self.stop_words .iter() - .flat_map(|idx| self.tokens_symbol_table.find_index(idx).cloned()) + .flat_map(|idx| { + self.tokens_symbol_table + .find_index(*idx) + .map(|sym| sym.to_string()) + }) .collect() } @@ -251,7 +255,11 @@ impl Parser { pub fn get_edge_cases(&self) -> HashSet { self.edge_cases .iter() - .flat_map(|idx| self.resolved_symbol_table.find_index(idx).cloned()) + .flat_map(|idx| { + self.resolved_symbol_table + .find_index(*idx) + .map(|resolved_symbol| resolved_symbol.value.to_string()) + }) .collect() } @@ -286,7 +294,7 @@ impl Parser { for val in &self.injected_values { for res_val in self.resolved_symbol_table.remove_symbol(&val) { let (_, tokens) = self - .get_tokens_from_resolved_value(&res_val) + .get_tokens_from_resolved_value(res_val) .with_context(|_| { format_err!("Error when retrieving tokens of resolved value '{}'", val) })? @@ -315,26 +323,22 @@ impl Parser { } } for tok_idx in tokens_marked_for_removal { - self.tokens_symbol_table.remove_index(&tok_idx); + self.tokens_symbol_table.remove_index(tok_idx); self.token_to_resolved_values.remove(&tok_idx); } } if prepend { // update rank of previous values - let n_new_values = new_values.len() as u32; - for res_val in self.resolved_symbol_table.get_all_indices() { - self.resolved_value_to_tokens - .entry(*res_val) - .and_modify(|(rank, _)| *rank += n_new_values); - } + self.shift_ranks(new_values.len() as u32); } - let new_start_rank = match prepend { + let new_start_rank = if prepend { // we inject new values from rank 0 to n_new_values - 1 - true => 0, + 0 + } else { // we inject new values from the current last rank onwards - false => self.resolved_value_to_tokens.len(), + self.resolved_value_to_tokens.len() } as u32; for (rank, entity_value) in new_values.into_iter().enumerate() { @@ -393,9 +397,9 @@ impl Parser { impl Parser { /// get resolved value - fn get_tokens_from_resolved_value(&self, resolved_value: &u32) -> Result<&(u32, Vec)> { + fn get_tokens_from_resolved_value(&self, resolved_value: u32) -> Result<&(u32, Vec)> { self.resolved_value_to_tokens - .get(resolved_value) + .get(&resolved_value) .ok_or_else(|| { format_err!( "Cannot find resolved value index {} in `resolved_value_to_tokens`", @@ -405,8 +409,8 @@ impl Parser { } /// get resolved values from token - fn get_resolved_values_from_token(&self, token: &u32) -> Result<&BTreeSet> { - self.token_to_resolved_values.get(token).ok_or_else(|| { + fn get_resolved_values_from_token(&self, token: u32) -> Result<&BTreeSet> { + self.token_to_resolved_values.get(&token).ok_or_else(|| { format_err!( "Cannot find token index {} in `token_to_resolved_values`", token @@ -414,6 +418,13 @@ impl Parser { }) } + /// Shift the ranks of all the resolved values + fn shift_ranks(&mut self, shift: u32) { + for (rank, _) in self.resolved_value_to_tokens.values_mut() { + *rank += shift + } + } + /// get the underlying matched value associated to a `PossibleMatch` fn get_matched_value(&self, possible_match: &PossibleMatch) -> Result { Ok(self @@ -427,8 +438,8 @@ impl Parser { })? .1 .iter() - .flat_map(|token_idx| self.tokens_symbol_table.find_index(token_idx)) - .map(|token_string| token_string.as_str()) + .flat_map(|token_idx| self.tokens_symbol_table.find_index(*token_idx)) + .map(|token_string| token_string) .collect::>() .join(" ")) } @@ -448,7 +459,7 @@ impl Parser { let mut skipped_tokens: HashMap, u32)> = HashMap::default(); for (token_idx, (range, token)) in whitespace_tokenizer(input).enumerate() { if let Some(value) = self.tokens_symbol_table.find_symbol(&token) { - let res_vals_from_token = self.get_resolved_values_from_token(&value)?; + let res_vals_from_token = self.get_resolved_values_from_token(*value)?; if res_vals_from_token.is_empty() { continue; } @@ -522,6 +533,7 @@ impl Parser { })) } + #[allow(clippy::too_many_arguments)] fn update_or_insert_possible_match( &self, value: u32, @@ -545,17 +557,16 @@ impl Parser { )?; } Entry::Vacant(entry) => { - self.insert_new_possible_match( + if let Some(new_possible_match) = self.insert_new_possible_match( res_val, value, range, token_idx, threshold, &skipped_tokens, - )? - .map(|new_possible_match| { + )? { entry.insert(new_possible_match); - }); + } } } Ok(()) @@ -568,18 +579,20 @@ impl Parser { value: u32, range: Range, threshold: f32, - ref mut matches_heap: &mut BinaryHeap, + matches_heap: &mut BinaryHeap, ) -> Result<()> { - let (rank, otokens) = - self.get_tokens_from_resolved_value(&possible_match.resolved_value)?; + let (rank, otokens) = self.get_tokens_from_resolved_value(possible_match.resolved_value)?; if token_idx == possible_match.last_token_in_input + 1 { // Grow the last Possible Match // Find the next token in the resolved value that matches the // input token - for otoken_idx in (possible_match.last_token_in_resolution + 1)..otokens.len() { - let otok = otokens[otoken_idx]; - if value == otok { + for (otoken_idx, otok) in otokens + .iter() + .enumerate() + .skip(possible_match.last_token_in_resolution + 1) + { + if value == *otok { possible_match.range.end = range.end; possible_match.n_consumed_tokens += 1; possible_match.last_token_in_input = token_idx; @@ -628,7 +641,7 @@ impl Parser { threshold: f32, skipped_tokens: &HashMap, u32)>, ) -> Result> { - let (rank, otokens) = self.get_tokens_from_resolved_value(&res_val)?; + let (rank, otokens) = self.get_tokens_from_resolved_value(res_val)?; let last_token_in_resolution = otokens.iter().position(|e| *e == value).ok_or_else(|| { format_err!("Missing token {} from list {:?}", value, otokens.clone()) @@ -736,7 +749,7 @@ impl Parser { possible_match.tokens_range.start <= **idx && possible_match.tokens_range.end > **idx }) - .map(|idx| *idx) + .cloned() .collect(); if !overlapping_tokens.is_empty() { @@ -757,6 +770,15 @@ impl Parser { } continue; } + let resolved_symbol = self + .resolved_symbol_table + .find_index(possible_match.resolved_value) + .ok_or_else(|| { + format_err!( + "Missing key for resolved value {}", + possible_match.resolved_value + ) + })?; parsing.push(ParsedValue { range: possible_match.range.clone(), @@ -765,16 +787,8 @@ impl Parser { .skip(possible_match.range.start) .take(possible_match.range.len()) .collect(), - resolved_value: self - .resolved_symbol_table - .find_index(&possible_match.resolved_value) - .cloned() - .ok_or_else(|| { - format_err!( - "Missing key for resolved value {}", - possible_match.resolved_value - ) - })?, + resolved_value: resolved_symbol.value.to_string(), + resolved_value_id: resolved_symbol.identifier.map(|id| id.to_string()), matched_value: self.get_matched_value(&possible_match)?, }); for idx in tokens_range_start..tokens_range_end { @@ -802,16 +816,19 @@ mod tests { extern crate mio_httpc; extern crate tempfile; - use self::mio_httpc::CallBuilder; - use self::tempfile::tempdir; + use std::time::Instant; + + use failure::ResultExt; + + use crate::data::EntityValue; + use crate::data::Gazetteer; + use crate::parser_builder::ParserBuilder; + #[allow(unused_imports)] use super::*; - #[allow(unused_imports)] - use data::EntityValue; - use data::Gazetteer; - use failure::ResultExt; - use parser_builder::ParserBuilder; - use std::time::Instant; + + use self::mio_httpc::CallBuilder; + use self::tempfile::tempdir; #[test] fn test_serialization_deserialization() { @@ -820,14 +837,17 @@ mod tests { gazetteer.add(EntityValue { resolved_value: "The Flying Stones".to_string(), raw_value: "the flying stones".to_string(), + resolved_value_id: None, }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), raw_value: "the rolling stones".to_string(), + resolved_value_id: None, }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), raw_value: "the stones".to_string(), + resolved_value_id: None, }); let parser = ParserBuilder::default() .minimum_tokens_ratio(0.5) @@ -867,19 +887,23 @@ mod tests { gazetteer.add(EntityValue { resolved_value: "The Flying Stones".to_string(), raw_value: "the flying stones".to_string(), + resolved_value_id: None, }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), raw_value: "the rolling stones".to_string(), + resolved_value_id: None, }); gazetteer.add(EntityValue { resolved_value: "The Stones Rolling".to_string(), raw_value: "the stones rolling".to_string(), + resolved_value_id: None, }); gazetteer.add(EntityValue { resolved_value: "The Stones".to_string(), raw_value: "the stones".to_string(), + resolved_value_id: None, }); let mut parser = ParserBuilder::default() @@ -907,6 +931,7 @@ mod tests { vec![ParsedValue { raw_value: "the rolling".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..31, }] @@ -922,6 +947,7 @@ mod tests { vec![ParsedValue { raw_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..38, }] @@ -937,6 +963,7 @@ mod tests { vec![ParsedValue { raw_value: "the stones rolling".to_string(), resolved_value: "The Stones Rolling".to_string(), + resolved_value_id: None, matched_value: "the stones rolling".to_string(), range: 20..38, }] @@ -950,6 +977,7 @@ mod tests { vec![ParsedValue { raw_value: "the stones".to_string(), resolved_value: "The Stones".to_string(), + resolved_value_id: None, matched_value: "the stones".to_string(), range: 20..30, }] @@ -969,6 +997,7 @@ mod tests { vec![ParsedValue { raw_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 26..44, }] @@ -983,6 +1012,7 @@ mod tests { vec![ParsedValue { raw_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 30..48, }] @@ -994,18 +1024,22 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Blink-182".to_string(), + resolved_value_id: None, raw_value: "blink one eight two".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Je Suis Animal".to_string(), + resolved_value_id: None, raw_value: "je suis animal".to_string(), }); @@ -1022,12 +1056,14 @@ mod tests { ParsedValue { raw_value: "je".to_string(), resolved_value: "Je Suis Animal".to_string(), + resolved_value_id: None, matched_value: "je suis animal".to_string(), range: 0..2, }, ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..34, }, @@ -1041,12 +1077,14 @@ mod tests { ParsedValue { raw_value: "je".to_string(), resolved_value: "Je Suis Animal".to_string(), + resolved_value_id: None, matched_value: "je suis animal".to_string(), range: 0..2, }, ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 22..36, }, @@ -1062,12 +1100,14 @@ mod tests { ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..34, }, ParsedValue { raw_value: "blink eight".to_string(), resolved_value: "Blink-182".to_string(), + resolved_value_id: None, matched_value: "blink one eight two".to_string(), range: 39..50, }, @@ -1083,14 +1123,17 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "Blink-182".to_string(), + resolved_value_id: None, raw_value: "blink one eight two".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Blink-182".to_string(), + resolved_value_id: None, raw_value: "blink 182".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Blink-182".to_string(), + resolved_value_id: None, raw_value: "blink".to_string(), }); @@ -1106,6 +1149,7 @@ mod tests { vec![ParsedValue { raw_value: "blink 182".to_string(), resolved_value: "Blink-182".to_string(), + resolved_value_id: None, matched_value: "blink 182".to_string(), range: 16..25, }] @@ -1118,6 +1162,7 @@ mod tests { vec![ParsedValue { raw_value: "blink".to_string(), resolved_value: "Blink-182".to_string(), + resolved_value_id: None, matched_value: "blink".to_string(), range: 16..21, }] @@ -1130,6 +1175,71 @@ mod tests { vec![ParsedValue { raw_value: "one eight two".to_string(), resolved_value: "Blink-182".to_string(), + resolved_value_id: None, + matched_value: "blink one eight two".to_string(), + range: 16..29, + }] + ); + } + + #[test] + fn test_parser_with_resolved_value_ids() { + let mut gazetteer = Gazetteer::default(); + gazetteer.add(EntityValue { + resolved_value: "Blink-182".to_string(), + resolved_value_id: Some("42".to_string()), + raw_value: "blink one eight two".to_string(), + }); + gazetteer.add(EntityValue { + resolved_value: "Blink-182".to_string(), + resolved_value_id: Some("43".to_string()), + raw_value: "blink 182".to_string(), + }); + gazetteer.add(EntityValue { + resolved_value: "Blink-182".to_string(), + resolved_value_id: Some("44".to_string()), + raw_value: "blink".to_string(), + }); + + let mut parser = ParserBuilder::default() + .minimum_tokens_ratio(0.0) + .gazetteer(gazetteer) + .build() + .unwrap(); + + let mut parsed = parser.run("let's listen to blink 182").unwrap(); + assert_eq!( + parsed, + vec![ParsedValue { + raw_value: "blink 182".to_string(), + resolved_value: "Blink-182".to_string(), + resolved_value_id: Some("43".to_string()), + matched_value: "blink 182".to_string(), + range: 16..25, + }] + ); + + parser.set_threshold(1.0); + parsed = parser.run("let's listen to blink").unwrap(); + assert_eq!( + parsed, + vec![ParsedValue { + raw_value: "blink".to_string(), + resolved_value: "Blink-182".to_string(), + resolved_value_id: Some("44".to_string()), + matched_value: "blink".to_string(), + range: 16..21, + }] + ); + + parser.set_threshold(3.0 / 4.0); + parsed = parser.run("let's listen to one eight two").unwrap(); + assert_eq!( + parsed, + vec![ParsedValue { + raw_value: "one eight two".to_string(), + resolved_value: "Blink-182".to_string(), + resolved_value_id: Some("42".to_string()), matched_value: "blink one eight two".to_string(), range: 16..29, }] @@ -1141,22 +1251,27 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "Jacques Brel".to_string(), + resolved_value_id: None, raw_value: "jacques brel".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Daniel Brel".to_string(), + resolved_value_id: None, raw_value: "daniel brel".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Jacques".to_string(), + resolved_value_id: None, raw_value: "jacques".to_string(), }); let parser = ParserBuilder::default() @@ -1172,6 +1287,7 @@ mod tests { vec![ParsedValue { raw_value: "the stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 16..26, }] @@ -1183,6 +1299,7 @@ mod tests { vec![ParsedValue { raw_value: "brel".to_string(), resolved_value: "Jacques Brel".to_string(), + resolved_value_id: None, matched_value: "jacques brel".to_string(), range: 16..20, }] @@ -1195,6 +1312,7 @@ mod tests { vec![ParsedValue { raw_value: "the flying stones".to_string(), resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, matched_value: "the flying stones".to_string(), range: 16..33, }] @@ -1206,6 +1324,7 @@ mod tests { vec![ParsedValue { raw_value: "daniel brel".to_string(), resolved_value: "Daniel Brel".to_string(), + resolved_value_id: None, matched_value: "daniel brel".to_string(), range: 16..27, }] @@ -1217,6 +1336,7 @@ mod tests { vec![ParsedValue { raw_value: "jacques".to_string(), resolved_value: "Jacques".to_string(), + resolved_value_id: None, matched_value: "jacques".to_string(), range: 16..23, }] @@ -1228,10 +1348,12 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "Jacques Brel".to_string(), + resolved_value_id: None, raw_value: "jacques brel".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); let mut parser = ParserBuilder::default() @@ -1248,6 +1370,7 @@ mod tests { vec![ParsedValue { raw_value: "brel".to_string(), resolved_value: "Jacques Brel".to_string(), + resolved_value_id: None, matched_value: "jacques brel".to_string(), range: 16..20, }] @@ -1256,10 +1379,12 @@ mod tests { let values_to_prepend = vec![ EntityValue { resolved_value: "Daniel Brel".to_string(), + resolved_value_id: None, raw_value: "daniel brel".to_string(), }, EntityValue { resolved_value: "Eric Brel".to_string(), + resolved_value_id: None, raw_value: "eric brel".to_string(), }, ]; @@ -1272,6 +1397,7 @@ mod tests { vec![ParsedValue { raw_value: "brel".to_string(), resolved_value: "Daniel Brel".to_string(), + resolved_value_id: None, matched_value: "daniel brel".to_string(), range: 16..20, }] @@ -1283,6 +1409,7 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); @@ -1303,6 +1430,7 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "Quand est-ce ?".to_string(), + resolved_value_id: None, raw_value: "quand est -ce".to_string(), }); let parser = ParserBuilder::default() @@ -1316,6 +1444,7 @@ mod tests { parsed, vec![ParsedValue { resolved_value: "Quand est-ce ?".to_string(), + resolved_value_id: None, range: 4..13, raw_value: "quand est".to_string(), matched_value: "quand est -ce".to_string(), @@ -1328,6 +1457,7 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); @@ -1342,6 +1472,7 @@ mod tests { parsed, vec![ParsedValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, range: 8..18, raw_value: "the stones".to_string(), matched_value: "the rolling stones".to_string(), @@ -1354,22 +1485,27 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Blink-182".to_string(), + resolved_value_id: None, raw_value: "blink one eight two".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Je Suis Animal".to_string(), + resolved_value_id: None, raw_value: "je suis animal".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Les Enfoirés".to_string(), + resolved_value_id: None, raw_value: "les enfoirés".to_string(), }); @@ -1385,6 +1521,7 @@ mod tests { vec![ ParsedValue { resolved_value: "Les Enfoirés".to_string(), + resolved_value_id: None, range: 16..19, raw_value: "les".to_string(), matched_value: "les enfoirés".to_string(), @@ -1392,6 +1529,7 @@ mod tests { ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..34, }, @@ -1406,11 +1544,13 @@ mod tests { ParsedValue { raw_value: "je".to_string(), resolved_value: "Je Suis Animal".to_string(), + resolved_value_id: None, matched_value: "je suis animal".to_string(), range: 0..2, }, ParsedValue { resolved_value: "Les Enfoirés".to_string(), + resolved_value_id: None, matched_value: "les enfoirés".to_string(), range: 16..19, raw_value: "les".to_string(), @@ -1418,6 +1558,7 @@ mod tests { ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..34, }, @@ -1431,6 +1572,7 @@ mod tests { vec![ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..34, }] @@ -1442,6 +1584,7 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); @@ -1463,6 +1606,7 @@ mod tests { vec![ParsedValue { raw_value: "the rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 8..26, }] @@ -1474,6 +1618,7 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); @@ -1485,6 +1630,7 @@ mod tests { let new_values = vec![EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }]; @@ -1499,6 +1645,7 @@ mod tests { vec![ParsedValue { raw_value: "flying stones".to_string(), resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, matched_value: "the flying stones".to_string(), range: 20..33, }] @@ -1511,6 +1658,7 @@ mod tests { vec![ParsedValue { raw_value: "the stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 16..26, }] @@ -1526,6 +1674,7 @@ mod tests { vec![ParsedValue { raw_value: "flying stones".to_string(), resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, matched_value: "the flying stones".to_string(), range: 20..33, }] @@ -1538,6 +1687,7 @@ mod tests { vec![ParsedValue { raw_value: "the stones".to_string(), resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, matched_value: "the flying stones".to_string(), range: 16..26, }] @@ -1549,6 +1699,7 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); @@ -1560,6 +1711,7 @@ mod tests { let new_values_1 = vec![EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }]; @@ -1568,6 +1720,7 @@ mod tests { // Test injection from vanilla let new_values_2 = vec![EntityValue { resolved_value: "Queens Of The Stone Age".to_string(), + resolved_value_id: None, raw_value: "queens of the stone age".to_string(), }]; @@ -1589,6 +1742,7 @@ mod tests { vec![ParsedValue { raw_value: "queens the stone age".to_string(), resolved_value: "Queens Of The Stone Age".to_string(), + resolved_value_id: None, matched_value: "queens of the stone age".to_string(), range: 16..36, }] @@ -1600,11 +1754,7 @@ mod tests { .is_empty()); assert!(parser.tokens_symbol_table.find_symbol("flying").is_none()); assert!(!parser.token_to_resolved_values.contains_key(&flying_idx)); - assert!(!parser - .token_to_resolved_values - .get(&stones_idx) - .unwrap() - .contains(&flying_stones_idx)); + assert!(!parser.token_to_resolved_values[&stones_idx].contains(&flying_stones_idx)); assert!(!parser .resolved_value_to_tokens .contains_key(&flying_stones_idx)); @@ -1615,10 +1765,12 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Stones".to_string(), + resolved_value_id: None, raw_value: "the stones".to_string(), }); @@ -1650,10 +1802,12 @@ mod tests { let new_values = vec![ EntityValue { resolved_value: "Rolling".to_string(), + resolved_value_id: None, raw_value: "rolling".to_string(), }, EntityValue { resolved_value: "Rolling Two".to_string(), + resolved_value_id: None, raw_value: "rolling two".to_string(), }, ]; @@ -1679,26 +1833,32 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "Black And White".to_string(), + resolved_value_id: None, raw_value: "black and white".to_string(), }); gazetteer.add(EntityValue { resolved_value: "Album".to_string(), + resolved_value_id: None, raw_value: "album".to_string(), }); gazetteer.add(EntityValue { resolved_value: "The Black and White Album".to_string(), + resolved_value_id: None, raw_value: "the black and white album".to_string(), }); gazetteer.add(EntityValue { resolved_value: "1 2 3 4".to_string(), + resolved_value_id: None, raw_value: "one two three four".to_string(), }); gazetteer.add(EntityValue { resolved_value: "3 4 5 6".to_string(), + resolved_value_id: None, raw_value: "three four five six".to_string(), }); gazetteer.add(EntityValue { resolved_value: "6 7".to_string(), + resolved_value_id: None, raw_value: "six seven".to_string(), }); @@ -1716,6 +1876,7 @@ mod tests { vec![ParsedValue { raw_value: "black and white album".to_string(), resolved_value: "The Black and White Album".to_string(), + resolved_value_id: None, matched_value: "the black and white album".to_string(), range: 19..40, }] @@ -1727,6 +1888,7 @@ mod tests { vec![ParsedValue { raw_value: "one two three four".to_string(), resolved_value: "1 2 3 4".to_string(), + resolved_value_id: None, matched_value: "one two three four".to_string(), range: 0..18, }] @@ -1740,12 +1902,14 @@ mod tests { ParsedValue { raw_value: "one two three four".to_string(), resolved_value: "1 2 3 4".to_string(), + resolved_value_id: None, matched_value: "one two three four".to_string(), range: 5..23, }, ParsedValue { raw_value: "five six".to_string(), resolved_value: "3 4 5 6".to_string(), + resolved_value_id: None, matched_value: "three four five six".to_string(), range: 24..32, }, @@ -1761,12 +1925,14 @@ mod tests { ParsedValue { raw_value: "one two three four".to_string(), resolved_value: "1 2 3 4".to_string(), + resolved_value_id: None, matched_value: "one two three four".to_string(), range: 5..23, }, ParsedValue { raw_value: "six seven".to_string(), resolved_value: "6 7".to_string(), + resolved_value_id: None, matched_value: "six seven".to_string(), range: 29..38, }, @@ -1777,7 +1943,7 @@ mod tests { #[test] #[ignore] fn real_world_gazetteer_parser() { - let (_, body) = CallBuilder::get().max_response(20000000).timeout_ms(60000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/artist_gazetteer_formatted.json").unwrap().exec().unwrap(); + let (_, body) = CallBuilder::get().max_response(20_000_000).timeout_ms(60_000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/artist_gazetteer_formatted.json").unwrap().exec().unwrap(); let data: Vec = serde_json::from_reader(&*body).unwrap(); let gaz = Gazetteer { data }; @@ -1795,6 +1961,7 @@ mod tests { vec![ParsedValue { raw_value: "rolling stones".to_string(), resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, matched_value: "the rolling stones".to_string(), range: 20..34, }] @@ -1807,12 +1974,13 @@ mod tests { vec![ParsedValue { raw_value: "bowie".to_string(), resolved_value: "David Bowie".to_string(), + resolved_value_id: None, matched_value: "david bowie".to_string(), range: 16..21, }] ); - let (_, body) = CallBuilder::get().max_response(20000000).timeout_ms(60000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/album_gazetteer_formatted.json").unwrap().exec().unwrap(); + let (_, body) = CallBuilder::get().max_response(20_000_000).timeout_ms(60_000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/album_gazetteer_formatted.json").unwrap().exec().unwrap(); let data: Vec = serde_json::from_reader(&*body).unwrap(); let gaz = Gazetteer { data }; @@ -1832,6 +2000,7 @@ mod tests { vec![ParsedValue { raw_value: "black and white album".to_string(), resolved_value: "The Black and White Album".to_string(), + resolved_value_id: None, matched_value: "black and white album".to_string(), range: 19..40, }] @@ -1845,6 +2014,7 @@ mod tests { vec![ParsedValue { raw_value: "dark side of the moon".to_string(), resolved_value: "Dark Side of the Moon".to_string(), + resolved_value_id: None, matched_value: "dark side of the moon".to_string(), range: 16..37, }] @@ -1860,12 +2030,14 @@ mod tests { ParsedValue { raw_value: "je veux".to_string(), resolved_value: "Je veux du bonheur".to_string(), + resolved_value_id: None, matched_value: "je veux du bonheur".to_string(), range: 0..7, }, ParsedValue { raw_value: "dark side of the moon".to_string(), resolved_value: "Dark Side of the Moon".to_string(), + resolved_value_id: None, matched_value: "dark side of the moon".to_string(), range: 16..37, }, @@ -1877,7 +2049,7 @@ mod tests { #[ignore] fn test_real_word_injection() { // Real-world artist gazetteer - let (_, body) = CallBuilder::get().max_response(20000000).timeout_ms(100000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/artist_gazetteer_formatted.json").unwrap().exec().unwrap(); + let (_, body) = CallBuilder::get().max_response(20_000_000).timeout_ms(100_000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/artist_gazetteer_formatted.json").unwrap().exec().unwrap(); let data: Vec = serde_json::from_reader(&*body).unwrap(); let album_gaz = Gazetteer { data }; @@ -1896,7 +2068,7 @@ mod tests { .unwrap(); // Get 10k values from the album gazetter to inject in the album parser - let (_, body) = CallBuilder::get().max_response(20000000).timeout_ms(100000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/album_gazetteer_formatted.json").unwrap().exec().unwrap(); + let (_, body) = CallBuilder::get().max_response(20_000_000).timeout_ms(100_000).url("https://s3.amazonaws.com/snips/nlu-lm/test/gazetteer-entity-parser/album_gazetteer_formatted.json").unwrap().exec().unwrap(); let mut new_values: Vec = serde_json::from_reader(&*body).unwrap(); new_values.truncate(10000); @@ -1910,6 +2082,7 @@ mod tests { vec![ParsedValue { raw_value: "hans knappertsbusch".to_string(), resolved_value: "Hans Knappertsbusch".to_string(), + resolved_value_id: None, matched_value: "hans knappertsbusch".to_string(), range: 16..35, }] @@ -1926,6 +2099,7 @@ mod tests { vec![ParsedValue { raw_value: "hans knappertsbusch conducts".to_string(), resolved_value: "Hans Knappertsbusch conducts".to_string(), + resolved_value_id: None, matched_value: "hans knappertsbusch conducts".to_string(), range: 16..44, }] diff --git a/src/parser_builder.rs b/src/parser_builder.rs index 11131b8..aeb4394 100644 --- a/src/parser_builder.rs +++ b/src/parser_builder.rs @@ -1,7 +1,10 @@ -use data::Gazetteer; -use errors::*; -use parser::Parser; -use EntityValue; +use failure::format_err; +use serde_derive::*; + +use crate::data::Gazetteer; +use crate::errors::*; +use crate::parser::Parser; +use crate::EntityValue; /// Struct exposing a builder allowing to configure and build a Parser #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] @@ -70,48 +73,56 @@ impl ParserBuilder { /// Instantiate a Parser from the ParserBuilder pub fn build(self) -> Result { if self.threshold < 0.0 || self.threshold > 1.0 { - return Err( - format_err!("Invalid value for threshold ({}), it must be between 0.0 and 1.0", - self.threshold)) + return Err(format_err!( + "Invalid value for threshold ({}), it must be between 0.0 and 1.0", + self.threshold + )); } - let mut parser = self.gazetteer.data - .into_iter() - .enumerate() - .fold(Parser::default(), |mut parser, (rank, entity_value)| { + let mut parser = self.gazetteer.data.into_iter().enumerate().fold( + Parser::default(), + |mut parser, (rank, entity_value)| { parser.add_value(entity_value, rank as u32); parser - }); + }, + ); parser.set_threshold(self.threshold); - parser.set_stop_words(self.n_gazetteer_stop_words.unwrap_or(0), - self.additional_stop_words); + parser.set_stop_words( + self.n_gazetteer_stop_words.unwrap_or(0), + self.additional_stop_words, + ); Ok(parser) } } #[cfg(test)] mod tests { - use super::*; - use data::EntityValue; use serde_json; + use super::*; + #[test] fn test_parser_builder_using_gazetteer() { // Given let entity_values = vec![ EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }, EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }, EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the stones".to_string(), - } + }, ]; - let gazetteer = Gazetteer { data: entity_values }; + let gazetteer = Gazetteer { + data: entity_values, + }; let builder = ParserBuilder::default() .minimum_tokens_ratio(0.5) .gazetteer(gazetteer.clone()) @@ -119,9 +130,7 @@ mod tests { .additional_stop_words(vec!["hello".to_string()]); // When - let parser_from_builder = builder - .build() - .unwrap(); + let parser_from_builder = builder.build().unwrap(); // Then let mut expected_parser = Parser::default(); @@ -137,25 +146,30 @@ mod tests { #[test] fn test_parser_builder_using_extended_gazetteer() { // Given - let entity_values_1 = vec![ - EntityValue { - resolved_value: "The Flying Stones".to_string(), - raw_value: "the flying stones".to_string(), - } - ]; + let entity_values_1 = vec![EntityValue { + resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, + raw_value: "the flying stones".to_string(), + }]; let entity_values_2 = vec![ EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }, EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the stones".to_string(), - } + }, ]; - let gazetteer_1 = Gazetteer { data: entity_values_1.clone() }; - let gazetteer_2 = Gazetteer { data: entity_values_2.clone() }; + let gazetteer_1 = Gazetteer { + data: entity_values_1.clone(), + }; + let gazetteer_2 = Gazetteer { + data: entity_values_2.clone(), + }; let builder = ParserBuilder::default() .minimum_tokens_ratio(0.5) .gazetteer(gazetteer_1) @@ -164,9 +178,7 @@ mod tests { .additional_stop_words(vec!["hello".to_string()]); // When - let parser_from_builder = builder - .build() - .unwrap(); + let parser_from_builder = builder.build().unwrap(); // Then let mut expected_parser = Parser::default(); @@ -188,16 +200,19 @@ mod tests { let entity_values = vec![ EntityValue { resolved_value: "The Flying Stones".to_string(), + resolved_value_id: None, raw_value: "the flying stones".to_string(), }, EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the rolling stones".to_string(), }, EntityValue { resolved_value: "The Rolling Stones".to_string(), + resolved_value_id: None, raw_value: "the stones".to_string(), - } + }, ]; let builder = ParserBuilder::default() .minimum_tokens_ratio(0.5) @@ -208,9 +223,7 @@ mod tests { .additional_stop_words(vec!["hello".to_string()]); // When - let parser_from_builder = builder - .build() - .unwrap(); + let parser_from_builder = builder.build().unwrap(); // Then let mut expected_parser = Parser::default(); @@ -230,6 +243,11 @@ mod tests { { "resolved_value": "yala", "raw_value": "yolo" + }, + { + "resolved_value": "Value With Id", + "resolved_value_id": "42", + "raw_value": "value with id" } ], "threshold": 0.6, @@ -242,8 +260,14 @@ mod tests { let mut gazetteer = Gazetteer::default(); gazetteer.add(EntityValue { resolved_value: "yala".to_string(), + resolved_value_id: None, raw_value: "yolo".to_string(), }); + gazetteer.add(EntityValue { + resolved_value: "Value With Id".to_string(), + resolved_value_id: Some("42".to_string()), + raw_value: "value with id".to_string(), + }); let builder = ParserBuilder::default() .minimum_tokens_ratio(0.6) .gazetteer(gazetteer) diff --git a/src/symbol_table.rs b/src/symbol_table.rs index 4efdef1..da90bbd 100644 --- a/src/symbol_table.rs +++ b/src/symbol_table.rs @@ -1,9 +1,10 @@ /// Implementation of a symbol table that /// - always maps a given index to a single string /// - allows mapping a string to several indices - use std::collections::BTreeMap; +use serde_derive::*; + #[derive(PartialEq, Eq, Debug, Default, Serialize, Deserialize)] pub struct TokenSymbolTable { string_to_index: BTreeMap, @@ -16,7 +17,7 @@ impl TokenSymbolTable { pub fn add_symbol(&mut self, symbol: String) -> u32 { self.string_to_index .get(&symbol) - .map(|idx| *idx) + .cloned() .unwrap_or_else(|| { let symbol_index = self.available_index; self.available_index += 1; @@ -31,43 +32,64 @@ impl TokenSymbolTable { } /// Find the unique symbol corresponding to an index in the symbol table - pub fn find_index(&self, idx: &u32) -> Option<&String> { + pub fn find_index(&self, idx: u32) -> Option<&str> { self.string_to_index .iter() - .find(|(_, sym_idx)| *sym_idx == idx) - .map(|(symbol, _)| symbol) + .find(|(_, sym_idx)| **sym_idx == idx) + .map(|(symbol, _)| &**symbol) } /// Remove the unique symbol corresponding to an index in the symbol table - pub fn remove_index(&mut self, idx: &u32) -> Option { - let symbol = self.find_index(idx).cloned(); - symbol.and_then(|symbol| - self.string_to_index - .remove(&symbol) - .map(|_| symbol)) + pub fn remove_index(&mut self, idx: u32) -> Option { + self.find_index(idx) + .map(|sym| sym.to_string()) + .and_then(|symbol| { + self.string_to_index + .remove(symbol.as_str()) + .map(|_| symbol.to_string()) + }) } } #[derive(PartialEq, Eq, Debug, Default, Serialize, Deserialize)] pub struct ResolvedSymbolTable { index_to_resolved: BTreeMap, + index_to_resolved_id: BTreeMap, available_index: u32, } +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct ResolvedSymbol<'a> { + /// Resolved value + pub value: &'a str, + /// Identifier optionally associated to the resolved value + pub identifier: Option<&'a str>, +} + impl ResolvedSymbolTable { - /// Add a symbol to the symbol table. If the symbol already exists, this will - /// generate a new index to allow the symbol to be duplicated in the symbol table + /// Add a symbol to the symbol table, along with its optional identifier. If the symbol already + /// exists, this will generate a new index to allow the symbol to be duplicated in the symbol + /// table /// Returns the newly generated corresponding index - pub fn add_symbol(&mut self, symbol: String) -> u32 { + pub fn add_symbol(&mut self, symbol: String, symbol_id: Option) -> u32 { let available_index = self.available_index; self.index_to_resolved.insert(available_index, symbol); + if let Some(id) = symbol_id { + self.index_to_resolved_id.insert(available_index, id); + } self.available_index += 1; available_index } /// Find a symbol from its index - pub fn find_index(&self, index: &u32) -> Option<&String> { - self.index_to_resolved.get(index) + pub fn find_index(&self, index: u32) -> Option { + self.index_to_resolved.get(&index).map(|symbol| { + let symbol_id = self.index_to_resolved_id.get(&index); + ResolvedSymbol { + value: &**symbol, + identifier: symbol_id.map(|id| &**id), + } + }) } /// Find all the indices corresponding to a single symbol @@ -82,16 +104,96 @@ impl ResolvedSymbolTable { /// Remove a symbol and all its linked indices from the symbol table pub fn remove_symbol(&mut self, symbol: &str) -> Vec { let indices = self.find_symbol(symbol); - indices.into_iter() - .flat_map(|idx| - self.index_to_resolved - .remove(&idx) - .map(|_| idx)) + indices + .into_iter() + .flat_map(|idx| { + self.index_to_resolved_id.remove(&idx); + self.index_to_resolved.remove(&idx).map(|_| idx) + }) .collect() } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_add_in_token_symbol_table() { + // Given + let mut symtable = TokenSymbolTable::default(); + + // When + let symbol = "hello"; + let index = symtable.add_symbol(symbol.to_string()); + + // Then + assert_eq!(Some(&index), symtable.find_symbol(symbol)); + assert_eq!(Some(symbol), symtable.find_index(index)); + } + + #[test] + fn test_remove_in_token_symbol_table() { + // Given + let mut symtable = TokenSymbolTable::default(); + let index_hello = symtable.add_symbol("hello".to_string()); + symtable.add_symbol("world".to_string()); + + // When + let removed_hello = symtable.remove_index(index_hello); + let hello_sym = symtable.find_index(index_hello); + let hello_idx = symtable.find_symbol("hello"); + + // Then + assert_eq!(Some("hello".to_string()), removed_hello); + assert_eq!(None, hello_sym); + assert_eq!(None, hello_idx); + } + + #[test] + fn test_add_in_resolved_symbol_table() { + // Given + let mut symtable = ResolvedSymbolTable::default(); + + // When + let symbol = "hello"; + let index_1 = symtable.add_symbol(symbol.to_string(), Some("id_42".to_string())); + let index_2 = symtable.add_symbol(symbol.to_string(), Some("id_43".to_string())); + + // Then + assert_eq!(vec![index_1, index_2], symtable.find_symbol(symbol)); + assert_eq!( + Some(ResolvedSymbol { + value: "hello", + identifier: Some("id_42") + }), + symtable.find_index(index_1) + ); + assert_eq!( + Some(ResolvedSymbol { + value: "hello", + identifier: Some("id_43") + }), + symtable.find_index(index_2) + ); + } + + #[test] + fn test_remove_in_resolved_symbol_table() { + // Given + let mut symtable = ResolvedSymbolTable::default(); + let index_hello_1 = symtable.add_symbol("hello".to_string(), Some("42".to_string())); + let index_hello_2 = symtable.add_symbol("hello".to_string(), Some("43".to_string())); + symtable.add_symbol("world".to_string(), None); + + // When + let removed_hello_indices = symtable.remove_symbol("hello"); + let hello_sym = symtable.find_index(index_hello_1); + let hello_indices = symtable.find_symbol("hello"); - /// Get a vec of all the integer values used to represent the symbols in the symbol table - pub fn get_all_indices(&self) -> Vec<&u32> { - self.index_to_resolved.keys().collect() + // Then + assert_eq!(vec![index_hello_1, index_hello_2], removed_hello_indices); + assert_eq!(None, hello_sym); + assert_eq!(Vec::::new(), hello_indices); } } diff --git a/update_version.sh b/update_version.sh index 2efad30..8901fde 100755 --- a/update_version.sh +++ b/update_version.sh @@ -4,13 +4,13 @@ set -e NEW_VERSION=$1 -if [ -z $NEW_VERSION ]; then +if [[ -z ${NEW_VERSION} ]]; then echo "Usage: $0 NEW_VERSION" exit 1 fi SPLIT_VERSION=( ${NEW_VERSION//./ } ) -if [ ${#SPLIT_VERSION[@]} -ne 3 ]; then +if [[ ${#SPLIT_VERSION[@]} -ne 3 ]]; then echo "Version number is invalid (must be of the form x.y.z)" exit 1 fi