Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
- [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
- [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
- [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base)

Alternatively, raw .onnx files can be loaded through the UserDefinedEmbeddingModel struct (for "bring your own" text embedding models).

Expand Down
53 changes: 48 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ pub enum EmbeddingModel {
ParaphraseMLMiniLML12V2,
/// v1.5 release of the small Chinese model
BGESmallZHV15,
/// Small model of multilingual E5 Text Embeddings
MultilingualE5Small,
/// Base model of multilingual E5 Text Embeddings
MultilingualE5Base,
// Large model is something wrong, model.onnx size is only 546kB
// /// Large model of multilingual E5 Text Embeddings
// MultilingualE5Large,
}

impl Display for EmbeddingModel {
Expand Down Expand Up @@ -191,6 +198,7 @@ pub struct TokenizerFiles {
pub struct TextEmbedding {
tokenizer: Tokenizer,
session: Session,
need_token_type_ids: bool,
}

impl TextEmbedding {
Expand Down Expand Up @@ -266,7 +274,15 @@ impl TextEmbedding {

/// Private method to return an instance
fn new(tokenizer: Tokenizer, session: Session) -> Self {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sourcery logo suggestion (docstrings): Please update the docstring for function: TextEmbedding::new

Reason for update: The function signature and internal logic have changed to include a new field need_token_type_ids.

Suggested new docstring:

/// Private method to return an instance, determining if `token_type_ids` are needed based on the session inputs.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Self { tokenizer, session }
let need_token_type_ids = session
.inputs
.iter()
.any(|input| input.name == "token_type_ids");
Self {
tokenizer,
session,
need_token_type_ids,
}
}
/// Return the TextEmbedding model's directory from cache or remote retrieval
fn retrieve_model(
Expand Down Expand Up @@ -456,6 +472,25 @@ impl TextEmbedding {
description: String::from("v1.5 release of the small Chinese model"),
model_code: String::from("Xenova/bge-small-zh-v1.5"),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Small,
dim: 384,
description: String::from("Small model of multilingual E5 Text Embeddings"),
model_code: String::from("intfloat/multilingual-e5-small"),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Base,
dim: 768,
description: String::from("Base model of multilingual E5 Text Embeddings"),
model_code: String::from("intfloat/multilingual-e5-base"),
},
// something wrong in MultilingualE5Large, model.onnx size is only 546kB
// ModelInfo {
// model: EmbeddingModel::MultilingualE5Large,
// dim: 1024,
// description: String::from("Large model of multilingual E5 Text Embeddings"),
// model_code: String::from("intfloat/multilingual-e5-large"),
// },
];

// TODO: Use when out in stable
Expand Down Expand Up @@ -528,11 +563,16 @@ impl TextEmbedding {
let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;

let outputs = self.session.run(ort::inputs![
let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
"attention_mask" => Value::from_array(attention_mask_array)?,
"token_type_ids" => Value::from_array(token_type_ids_array)?,
]?)?;
]?;
if self.need_token_type_ids {
session_inputs
.insert("token_type_ids", Value::from_array(token_type_ids_array)?);
}

let outputs = self.session.run(session_inputs)?;

// Extract and normalize embeddings
let output_data = outputs["last_hidden_state"].extract_tensor::<f32>()?;
Expand Down Expand Up @@ -667,9 +707,12 @@ mod tests {
.unwrap();

// Skip "nomic-ai/nomic-embed-text-v1" model for now as it has a different folder structure
// Also skip Xenova/bge-small-zh-v1.5" for the same reason
// Also skip "Xenova/bge-small-zh-v1.5", "intfloat/multilingual-e5-small",
// and "intfloat/multilingual-e5-base" for the same reason
if supported_model.model_code == "nomic-ai/nomic-embed-text-v1"
|| supported_model.model_code == "Xenova/bge-small-zh-v1.5"
|| supported_model.model_code == "intfloat/multilingual-e5-small"
|| supported_model.model_code == "intfloat/multilingual-e5-base"
{
continue;
}
Expand Down