From f92115507b793a11df7e182f474e5679c5b93b34 Mon Sep 17 00:00:00 2001
From: KOUNOIKE Yuusuke
Date: Mon, 8 Apr 2024 06:57:41 +0000
Subject: [PATCH 1/3] add Multilingual E5 models

---
 src/lib.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 52 insertions(+), 7 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 886e3d1..7844266 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -93,6 +93,13 @@ pub enum EmbeddingModel {
     ParaphraseMLMiniLML12V2,
     /// v1.5 release of the small Chinese model
     BGESmallZHV15,
+    /// Small model of multilingual E5 Text Embeddings
+    MultilingualE5Small,
+    /// Base model of multilingual E5 Text Embeddings
+    MultilingualE5Base,
+    // Something is wrong with the large model: its model.onnx is only 546 kB
+    // /// Large model of multilingual E5 Text Embeddings
+    // MultilingualE5Large,
 }
 
 impl Display for EmbeddingModel {
@@ -191,6 +198,7 @@ pub struct TokenizerFiles {
 pub struct TextEmbedding {
     tokenizer: Tokenizer,
     session: Session,
+    need_token_type_ids: bool,
 }
 
 impl TextEmbedding {
@@ -266,7 +274,15 @@ impl TextEmbedding {
     /// Private method to return an instance
     fn new(tokenizer: Tokenizer, session: Session) -> Self {
-        Self { tokenizer, session }
+        let need_token_type_ids = session
+            .inputs
+            .iter()
+            .any(|input| input.name == "token_type_ids");
+        Self {
+            tokenizer,
+            session,
+            need_token_type_ids,
+        }
     }
 
     /// Return the TextEmbedding model's directory from cache or remote retrieval
     fn retrieve_model(
@@ -456,6 +472,25 @@ impl TextEmbedding {
             description: String::from("v1.5 release of the small Chinese model"),
             model_code: String::from("Xenova/bge-small-zh-v1.5"),
         },
+        ModelInfo {
+            model: EmbeddingModel::MultilingualE5Small,
+            dim: 384,
+            description: String::from("Small model of multilingual E5 Text Embeddings"),
+            model_code: String::from("intfloat/multilingual-e5-small"),
+        },
+        ModelInfo {
+            model: EmbeddingModel::MultilingualE5Base,
+            dim: 768,
+            description: String::from("Base model of multilingual E5 Text Embeddings"),
+            model_code: String::from("intfloat/multilingual-e5-base"),
+        },
+        // Something is wrong with MultilingualE5Large: its model.onnx is only 546 kB
+        // ModelInfo {
+        //     model: EmbeddingModel::MultilingualE5Large,
+        //     dim: 1024,
+        //     description: String::from("Large model of multilingual E5 Text Embeddings"),
+        //     model_code: String::from("intfloat/multilingual-e5-large"),
+        // },
     ];
 
     // TODO: Use when out in stable
@@ -528,11 +563,18 @@ impl TextEmbedding {
         let token_type_ids_array =
             Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
 
-        let outputs = self.session.run(ort::inputs![
-            "input_ids" => Value::from_array(inputs_ids_array)?,
-            "attention_mask" => Value::from_array(attention_mask_array)?,
-            "token_type_ids" => Value::from_array(token_type_ids_array)?,
-        ]?)?;
+        let outputs = if self.need_token_type_ids {
+            self.session.run(ort::inputs![
+                "input_ids" => Value::from_array(inputs_ids_array)?,
+                "attention_mask" => Value::from_array(attention_mask_array)?,
+                "token_type_ids" => Value::from_array(token_type_ids_array)?,
+            ]?)?
+        } else {
+            self.session.run(ort::inputs![
+                "input_ids" => Value::from_array(inputs_ids_array)?,
+                "attention_mask" => Value::from_array(attention_mask_array)?,
+            ]?)?
+        };
 
         // Extract and normalize embeddings
         let output_data = outputs["last_hidden_state"].extract_tensor::<f32>()?;
@@ -667,9 +709,12 @@ mod tests {
             .unwrap();
 
         // Skip "nomic-ai/nomic-embed-text-v1" model for now as it has a different folder structure
-        // Also skip Xenova/bge-small-zh-v1.5" for the same reason
+        // Also skip "Xenova/bge-small-zh-v1.5", "intfloat/multilingual-e5-small",
+        // and "intfloat/multilingual-e5-base" for the same reason
         if supported_model.model_code == "nomic-ai/nomic-embed-text-v1"
             || supported_model.model_code == "Xenova/bge-small-zh-v1.5"
+            || supported_model.model_code == "intfloat/multilingual-e5-small"
+            || supported_model.model_code == "intfloat/multilingual-e5-base"
         {
            continue;
        }
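The core of this first patch is the `need_token_type_ids` probe in `fn new`: the multilingual E5 ONNX exports do not declare a `token_type_ids` graph input, and feeding an undeclared input to `Session::run` is an error, so the wrapper records at load time whether the graph expects one. A minimal standalone sketch of the same check, assuming the `ort` 2.x API the crate already uses (`Session::inputs` listing the graph's declared inputs):

    use ort::Session;

    /// True when the loaded ONNX graph declares a `token_type_ids` input
    /// (BERT-style models). XLM-RoBERTa-based exports such as the
    /// multilingual E5 models omit it, so it must not be fed at run time.
    fn needs_token_type_ids(session: &Session) -> bool {
        session
            .inputs
            .iter()
            .any(|input| input.name == "token_type_ids")
    }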
From 9211da04137bfae98b42d776120e9c17e648617c Mon Sep 17 00:00:00 2001
From: KOUNOIKE Yuusuke
Date: Mon, 8 Apr 2024 10:51:16 +0000
Subject: [PATCH 2/3] fix session inputs creation

---
 src/lib.rs | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 7844266..bb15f3f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -563,18 +563,16 @@ impl TextEmbedding {
         let token_type_ids_array =
             Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
 
-        let outputs = if self.need_token_type_ids {
-            self.session.run(ort::inputs![
-                "input_ids" => Value::from_array(inputs_ids_array)?,
-                "attention_mask" => Value::from_array(attention_mask_array)?,
-                "token_type_ids" => Value::from_array(token_type_ids_array)?,
-            ]?)?
-        } else {
-            self.session.run(ort::inputs![
-                "input_ids" => Value::from_array(inputs_ids_array)?,
-                "attention_mask" => Value::from_array(attention_mask_array)?,
-            ]?)?
-        };
+        let mut session_inputs = ort::inputs![
+            "input_ids" => Value::from_array(inputs_ids_array)?,
+            "attention_mask" => Value::from_array(attention_mask_array)?,
+        ]?;
+        if self.need_token_type_ids {
+            session_inputs
+                .insert("token_type_ids", Value::from_array(token_type_ids_array)?);
+        }
+
+        let outputs = self.session.run(session_inputs)?;
 
         // Extract and normalize embeddings
         let output_data = outputs["last_hidden_state"].extract_tensor::<f32>()?;

From b657560d5d5590f5c47af0e85dff7dfc87f6de59 Mon Sep 17 00:00:00 2001
From: KOUNOIKE Yuusuke
Date: Mon, 8 Apr 2024 10:51:59 +0000
Subject: [PATCH 3/3] add multilingual E5 models to README

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index dbda90e..5e8351c 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,8 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
 - [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
 - [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
 - [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
+- [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
+- [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base)
 
 Alternatively, raw .onnx files can be loaded through the UserDefinedEmbeddingModel struct (for "bring your own" text embedding models).
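For reference, a usage sketch of the newly added variants, assuming fastembed's `InitOptions`/`TextEmbedding::try_new` API at the time of this series and an `anyhow`-style error type; the `query: `/`passage: ` prefixes are the E5 models' own input convention, not something this series adds:

    use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

    fn main() -> anyhow::Result<()> {
        // Load the small multilingual E5 model added by this series.
        let model = TextEmbedding::try_new(InitOptions {
            model_name: EmbeddingModel::MultilingualE5Small,
            show_download_progress: true,
            ..Default::default()
        })?;

        // E5 models expect "query: "/"passage: " prefixes on their inputs.
        let embeddings = model.embed(
            vec![
                "query: capital of France".to_string(),
                "passage: Paris is the capital of France.".to_string(),
            ],
            None,
        )?;

        // MultilingualE5Small is 384-dimensional (matching dim: 384 above).
        assert_eq!(embeddings[0].len(), 384);
        Ok(())
    }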