Description
Bug
Using all-mpnet-base-v2 (or any model without a token_type_ids input) fails with:
Error: ONNX Runtime error: Invalid input name: token_type_ids
Even when manually working around this, the dimension is always reported as 384 regardless of the actual model output.
Root Cause
Two issues in examples/onnx-embeddings/src/model.rs:
1. Hardcoded dimension (line 190)
// Default embedding dimension (will be determined at runtime from actual output)
// Most sentence-transformers models output 384 dimensions
let dimension = 384;
all-mpnet-base-v2 produces 768-dim vectors, but this is always set to 384.
2. Unconditional token_type_ids (lines 239-250)
let token_type_ids_tensor = Tensor::from_array((
    vec![batch_size, seq_length],
    token_type_ids.to_vec().into_boxed_slice(),
))
.map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
let inputs = vec![
    ("input_ids", input_ids_tensor.into_dyn()),
    ("attention_mask", attention_mask_tensor.into_dyn()),
    ("token_type_ids", token_type_ids_tensor.into_dyn()),
];
all-mpnet-base-v2 only accepts input_ids and attention_mask. The code should check self.info.input_names before including token_type_ids.
Fix
Dimension: Use PretrainedModel::dimension() or detect from model name/output shape instead of hardcoding 384.
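One way the "detect from output shape" option could look, as a minimal sketch rather than the project's actual code. The helper name embedding_dim_from_shape is hypothetical, and it assumes the output tensor's shape is available as a slice of i64 with dynamic axes reported as a non-positive value such as -1:

```rust
/// Hypothetical helper (not part of the existing code base): derive the
/// embedding dimension from an output tensor shape.
///
/// Works for pooled outputs shaped [batch, hidden] as well as token-level
/// outputs shaped [batch, seq_len, hidden], because the hidden size is the
/// last axis in both layouts.
///
/// Returns None when the shape is empty or the last axis is dynamic
/// (assumed here to be reported as a value <= 0), so the caller can fall
/// back to another source such as PretrainedModel::dimension().
pub fn embedding_dim_from_shape(shape: &[i64]) -> Option<usize> {
    match shape.last() {
        Some(&dim) if dim > 0 => Some(dim as usize),
        _ => None,
    }
}
```

With this, a shape of [1, 128, 768] (all-mpnet-base-v2) yields Some(768) and [1, 128, 384] (all-MiniLM-*) yields Some(384). A dynamic last axis yields None, in which case the dimension could instead be read from the first real inference result.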
token_type_ids: Conditionally include based on the model's actual inputs:
let has_token_type_ids = self.info.input_names.iter().any(|n| n == "token_type_ids");
let mut inputs = vec![
    ("input_ids", input_ids_tensor.into_dyn()),
    ("attention_mask", attention_mask_tensor.into_dyn()),
];
if has_token_type_ids {
    let token_type_ids_tensor = Tensor::from_array((
        vec![batch_size, seq_length],
        token_type_ids.to_vec().into_boxed_slice(),
    ))
    .map_err(|e| EmbeddingError::invalid_model(e.to_string()))?;
    inputs.push(("token_type_ids", token_type_ids_tensor.into_dyn()));
}
Affected Models
Any model that doesn't use token_type_ids: all-mpnet-base-v2 and potentially others. Only all-MiniLM-* variants happen to work because they accept all three inputs.
Environment
- macOS (Apple Silicon M2 Ultra)
- ort 2.x
- feat/ruvector-postgres-v2 branch