Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
- [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
- [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
- [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base)

Alternatively, raw .onnx files can be loaded through the UserDefinedEmbeddingModel struct (for "bring your own" text embedding models).

Expand Down
53 changes: 48 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ pub enum EmbeddingModel {
ParaphraseMLMiniLML12V2,
/// v1.5 release of the small Chinese model
BGESmallZHV15,
/// Small model of multilingual E5 Text Embeddings
MultilingualE5Small,
/// Base model of multilingual E5 Text Embeddings
MultilingualE5Base,
// Something is wrong with the Large model: the published model.onnx is only 546kB
// (far too small to be valid), so the variant is disabled until the upstream file is fixed.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code_clarification): Clarify the comment about the large model issue.

The comment about the large model being 'something wrong' is vague. It would be helpful to specify what the issue is, whether it's a temporary or permanent problem, and any planned steps to resolve it.

Suggested change
// Large model is something wrong, model.onnx size is only 546kB
// The Large model of multilingual E5 Text Embeddings appears to be incorrect due to its unusually small size (546kB). This issue is currently under investigation to determine if it's a file corruption or a misconfiguration. Updates or fixes will be applied once the problem is fully diagnosed.

// /// Large model of multilingual E5 Text Embeddings
// MultilingualE5Large,
}

impl Display for EmbeddingModel {
Expand Down Expand Up @@ -191,6 +198,7 @@ pub struct TokenizerFiles {
pub struct TextEmbedding {
tokenizer: Tokenizer,
session: Session,
need_token_type_ids: bool,
}

impl TextEmbedding {
Expand Down Expand Up @@ -266,7 +274,15 @@ impl TextEmbedding {

/// Private method to return an instance

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (docstrings): Please update the docstring for function: TextEmbedding::new

Reason for update: Initialization logic has changed to include a new field based on session inputs.

Suggested new docstring:

/// Private method to return an instance, initializing `need_token_type_ids` based on session inputs.

fn new(tokenizer: Tokenizer, session: Session) -> Self {
Self { tokenizer, session }
let need_token_type_ids = session
.inputs
.iter()
.any(|input| input.name == "token_type_ids");
Comment on lines +277 to +280

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code_refinement): Consider initializing 'need_token_type_ids' directly in the struct declaration.

Initializing 'need_token_type_ids' directly in the struct declaration could simplify the 'new' method and improve readability.

Suggested change
let need_token_type_ids = session
.inputs
.iter()
.any(|input| input.name == "token_type_ids");
Self {
tokenizer,
session,
need_token_type_ids: session.inputs.iter().any(|input| input.name == "token_type_ids")
}

Self {
tokenizer,
session,
need_token_type_ids,
}
}
/// Return the TextEmbedding model's directory from cache or remote retrieval
fn retrieve_model(
Expand Down Expand Up @@ -456,6 +472,25 @@ impl TextEmbedding {
description: String::from("v1.5 release of the small Chinese model"),
model_code: String::from("Xenova/bge-small-zh-v1.5"),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Small,
dim: 384,
description: String::from("Small model of multilingual E5 Text Embeddings"),
model_code: String::from("intfloat/multilingual-e5-small"),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Base,
dim: 768,
description: String::from("Base model of multilingual E5 Text Embeddings"),
model_code: String::from("intfloat/multilingual-e5-base"),
},
// MultilingualE5Large is disabled: the published model.onnx is only 546kB,
// which indicates a corrupted or misconfigured upload; re-enable once fixed upstream.
// ModelInfo {
// model: EmbeddingModel::MultilingualE5Large,
// dim: 1024,
// description: String::from("Large model of multilingual E5 Text Embeddings"),
// model_code: String::from("intfloat/multilingual-e5-large"),
// },
];

// TODO: Use when out in stable
Expand Down Expand Up @@ -528,11 +563,16 @@ impl TextEmbedding {
let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;

let outputs = self.session.run(ort::inputs![
let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
"attention_mask" => Value::from_array(attention_mask_array)?,
"token_type_ids" => Value::from_array(token_type_ids_array)?,
]?)?;
]?;
if self.need_token_type_ids {
session_inputs
.insert("token_type_ids", Value::from_array(token_type_ids_array)?);
}

let outputs = self.session.run(session_inputs)?;

// Extract and normalize embeddings
let output_data = outputs["last_hidden_state"].extract_tensor::<f32>()?;
Expand Down Expand Up @@ -667,9 +707,12 @@ mod tests {
.unwrap();

// Skip "nomic-ai/nomic-embed-text-v1" model for now as it has a different folder structure
// Also skip Xenova/bge-small-zh-v1.5" for the same reason
// Also skip "Xenova/bge-small-zh-v1.5", "intfloat/multilingual-e5-small",
// and "intfloat/multilingual-e5-base" for the same reason
if supported_model.model_code == "nomic-ai/nomic-embed-text-v1"
|| supported_model.model_code == "Xenova/bge-small-zh-v1.5"
|| supported_model.model_code == "intfloat/multilingual-e5-small"
|| supported_model.model_code == "intfloat/multilingual-e5-base"
{
continue;
}
Expand Down