Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test-and-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ jobs:
with:
path: |
models/
key: ${{ runner.os }}-models-v1-${{ hashFiles('tools/make/models.mk') }}
key: ${{ runner.os }}-models-v2-${{ hashFiles('tools/make/models.mk') }}
restore-keys: |
${{ runner.os }}-models-v1-
${{ runner.os }}-models-v2-
continue-on-error: true # Don't fail the job if caching fails

- name: Check go mod tidy
Expand Down
34 changes: 34 additions & 0 deletions candle-binding/src/classifiers/lora/intent_lora.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,40 @@ impl IntentLoRAClassifier {
})
}

/// Classify intent and return (class_index, confidence, intent_label) for FFI
pub fn classify_with_index(&self, text: &str) -> Result<(usize, f32, String)> {
// Use real BERT model for classification
let (predicted_class, confidence) =
self.bert_classifier.classify_text(text).map_err(|e| {
let unified_err = model_error!(
ModelErrorType::LoRA,
"intent classification",
format!("Classification failed: {}", e),
text
);
candle_core::Error::from(unified_err)
})?;

// Map class index to intent label - fail if class not found
let intent = if predicted_class < self.intent_labels.len() {
self.intent_labels[predicted_class].clone()
} else {
let unified_err = model_error!(
ModelErrorType::LoRA,
"intent classification",
format!(
"Invalid class index {} not found in labels (max: {})",
predicted_class,
self.intent_labels.len()
),
text
);
return Err(candle_core::Error::from(unified_err));
};

Ok((predicted_class, confidence, intent))
}

/// Parallel classification for multiple texts using rayon
///
/// # Performance
Expand Down
14 changes: 13 additions & 1 deletion candle-binding/src/core/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,19 @@ impl DualPathTokenizer for UnifiedTokenizer {
let encoding = tokenizer
.encode(text, self.config.add_special_tokens)
.map_err(E::msg)?;
Ok(self.encoding_to_result(&encoding))

// Explicitly enforce max_length truncation for LoRA models
// This is a safety check to ensure we never exceed the model's position embedding size
let mut result = self.encoding_to_result(&encoding);
let max_len = self.config.max_length;
if result.token_ids.len() > max_len {
result.token_ids.truncate(max_len);
result.token_ids_u32.truncate(max_len);
result.attention_mask.truncate(max_len);
result.tokens.truncate(max_len);
}

Ok(result)
}

fn tokenize_batch_smart(
Expand Down
31 changes: 28 additions & 3 deletions candle-binding/src/ffi/classify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::BertClassifier;
use std::ffi::{c_char, CStr};
use std::sync::{Arc, OnceLock};

use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
use crate::ffi::init::{LORA_INTENT_CLASSIFIER, PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
// Import DeBERTa classifier for jailbreak detection
use super::init::DEBERTA_JAILBREAK_CLASSIFIER;

Expand Down Expand Up @@ -693,7 +693,32 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati
Err(_) => return default_result,
}
};
// Use TraditionalBertClassifier for Candle BERT text classification

// Try LoRA intent classifier first (preferred for higher accuracy)
if let Some(classifier) = LORA_INTENT_CLASSIFIER.get() {
let classifier = classifier.clone();
match classifier.classify_with_index(text) {
Ok((class_idx, confidence, ref intent)) => {
// Allocate C string for intent label
let label_ptr = unsafe { allocate_c_string(intent) };

return ClassificationResult {
predicted_class: class_idx as i32,
confidence,
label: label_ptr,
};
}
Err(e) => {
eprintln!(
"LoRA intent classifier error: {}, falling back to Traditional BERT",
e
);
// Don't return - fall through to Traditional BERT classifier
}
}
}

// Fallback to Traditional BERT classifier
if let Some(classifier) = TRADITIONAL_BERT_CLASSIFIER.get() {
let classifier = classifier.clone();
match classifier.classify_text(text) {
Expand All @@ -717,7 +742,7 @@ pub extern "C" fn classify_candle_bert_text(text: *const c_char) -> Classificati
}
}
} else {
println!("TraditionalBertClassifier not initialized - call init_bert_classifier first");
println!("No classifier initialized - call init_candle_bert_classifier first");
ClassificationResult {
predicted_class: -1,
confidence: 0.0,
Expand Down
55 changes: 42 additions & 13 deletions candle-binding/src/ffi/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ pub static PARALLEL_LORA_ENGINE: OnceLock<
pub static LORA_TOKEN_CLASSIFIER: OnceLock<
Arc<crate::classifiers::lora::token_lora::LoRATokenClassifier>,
> = OnceLock::new();
// LoRA intent classifier for sequence classification
pub static LORA_INTENT_CLASSIFIER: OnceLock<
Arc<crate::classifiers::lora::intent_lora::IntentLoRAClassifier>,
> = OnceLock::new();

/// Model type detection for intelligent routing
#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -604,28 +608,53 @@ pub extern "C" fn init_candle_bert_classifier(
num_classes: i32,
use_cpu: bool,
) -> bool {
// Migrated from lib.rs:1555-1578
let model_path = unsafe {
match CStr::from_ptr(model_path).to_str() {
Ok(s) => s,
Err(_) => return false,
}
};

// Initialize TraditionalBertClassifier
match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new(
model_path,
num_classes as usize,
use_cpu,
) {
Ok(_classifier) => {
// Store in global static (would need to add this to the lazy_static block)
// Intelligent model type detection (same as token classifier)
let model_type = detect_model_type(model_path);

true
match model_type {
ModelType::LoRA => {
// Check if already initialized
if LORA_INTENT_CLASSIFIER.get().is_some() {
return true; // Already initialized, return success
}

// Route to LoRA intent classifier initialization
match crate::classifiers::lora::intent_lora::IntentLoRAClassifier::new(
model_path, use_cpu,
) {
Ok(classifier) => LORA_INTENT_CLASSIFIER.set(Arc::new(classifier)).is_ok(),
Err(e) => {
eprintln!(
" ERROR: Failed to initialize LoRA intent classifier: {}",
e
);
false
}
}
}
Err(e) => {
eprintln!("Failed to initialize Candle BERT classifier: {}", e);
false
ModelType::Traditional => {
// Initialize TraditionalBertClassifier
match crate::model_architectures::traditional::bert::TraditionalBertClassifier::new(
model_path,
num_classes as usize,
use_cpu,
) {
Ok(_classifier) => {
// Store in global static (would need to add this to the lazy_static block)
true
}
Err(e) => {
eprintln!("Failed to initialize Candle BERT classifier: {}", e);
false
}
}
}
}
}
Expand Down
22 changes: 20 additions & 2 deletions candle-binding/src/model_architectures/lora/bert_lora.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,18 @@ impl HighPerformanceBertClassifier {

// Load tokenizer
let tokenizer_path = Path::new(model_path).join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

// Configure truncation to max 512 tokens (BERT's position embedding limit)
use tokenizers::TruncationParams;
tokenizer
.with_truncation(Some(TruncationParams {
max_length: 512,
..Default::default()
}))
.map_err(E::msg)?;

// Load model weights
let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
Path::new(model_path).join("model.safetensors")
Expand Down Expand Up @@ -690,9 +699,18 @@ impl HighPerformanceBertTokenClassifier {

// Load tokenizer
let tokenizer_path = Path::new(model_path).join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

// Configure truncation to max 512 tokens (BERT's position embedding limit)
use tokenizers::TruncationParams;
tokenizer
.with_truncation(Some(TruncationParams {
max_length: 512,
..Default::default()
}))
.map_err(E::msg)?;

// Load model weights
let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
Path::new(model_path).join("model.safetensors")
Expand Down
11 changes: 5 additions & 6 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,14 @@ model_config:
# Classifier configuration
classifier:
category_model:
model_id: "models/category_classifier_modernbert-base_model"
use_modernbert: true
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
threshold: 0.6
use_cpu: true
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
pii_model:
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
use_modernbert: true
threshold: 0.7
model_id: "models/lora_pii_detector_bert-base-uncased_model"
use_modernbert: false
threshold: 0.9
use_cpu: true
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

Expand Down
8 changes: 5 additions & 3 deletions deploy/helm/semantic-router/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Declare variables to be passed into your templates.

# Global settings
global:

Check warning on line 6 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

6:1 [document-start] missing document start "---"
# -- Namespace for all resources (if not specified, uses Release.Namespace)
namespace: ""

Expand Down Expand Up @@ -47,7 +47,7 @@

# Pod security context
podSecurityContext: {}
# fsGroup: 2000

Check warning on line 50 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

50:3 [comments-indentation] comment not indented like content

# Container security context
securityContext:
Expand Down Expand Up @@ -100,7 +100,7 @@
className: ""
# -- Ingress annotations
annotations: {}
# kubernetes.io/ingress.class: nginx

Check warning on line 103 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

103:5 [comments-indentation] comment not indented like content
# kubernetes.io/tls-acme: "true"
# -- Ingress hosts configuration
hosts:
Expand Down Expand Up @@ -159,6 +159,8 @@
repo: Qwen/Qwen3-Embedding-0.6B
- name: all-MiniLM-L12-v2
repo: sentence-transformers/all-MiniLM-L12-v2
- name: lora_intent_classifier_bert-base-uncased_model
repo: LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model
- name: category_classifier_modernbert-base_model
repo: LLM-Semantic-Router/category_classifier_modernbert-base_model
- name: pii_classifier_modernbert-base_model
Expand All @@ -166,7 +168,7 @@
- name: jailbreak_classifier_modernbert-base_model
repo: LLM-Semantic-Router/jailbreak_classifier_modernbert-base_model
- name: pii_classifier_modernbert-base_presidio_token_model
repo: LLM-Semantic-Router/pii_classifier_modernbert-base_presidio_token_model

Check failure on line 171 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

171:81 [line-length] line too long (83 > 80 characters)
# LoRA PII detector (for auto-detection feature)
- name: lora_pii_detector_bert-base-uncased_model
repo: LLM-Semantic-Router/lora_pii_detector_bert-base-uncased_model
Expand Down Expand Up @@ -232,7 +234,7 @@
size: 10Gi
# -- Annotations for PVC
annotations: {}
# -- Existing claim name (if provided, will use existing PVC instead of creating new one)

Check failure on line 237 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

237:81 [line-length] line too long (91 > 80 characters)
existingClaim: ""

# Application configuration
Expand Down Expand Up @@ -267,22 +269,22 @@
model_id: "models/jailbreak_classifier_modernbert-base_model"
threshold: 0.7
use_cpu: true
jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json"

Check failure on line 272 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

272:81 [line-length] line too long (107 > 80 characters)

# Classifier configuration
classifier:
category_model:
model_id: "models/category_classifier_modernbert-base_model"
use_modernbert: true
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
use_modernbert: false # Use LoRA intent classifier with auto-detection
threshold: 0.6
use_cpu: true
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"

Check failure on line 281 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

281:81 [line-length] line too long (106 > 80 characters)
pii_model:
model_id: "models/pii_classifier_modernbert-base_presidio_token_model"
use_modernbert: true
threshold: 0.7
use_cpu: true
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

Check failure on line 287 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

287:81 [line-length] line too long (106 > 80 characters)

# Reasoning families
reasoning_families:
Expand Down Expand Up @@ -313,7 +315,7 @@
detailed_goroutine_tracking: true
high_resolution_timing: false
sample_rate: 1.0
duration_buckets: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30]

Check failure on line 318 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

318:81 [line-length] line too long (94 > 80 characters)
size_buckets: [1, 2, 5, 10, 20, 50, 100, 200]

# Observability configuration
Expand Down Expand Up @@ -351,7 +353,7 @@
enum: ["celsius", "fahrenheit"]
description: "Temperature unit"
required: ["location"]
description: "Get current weather information, temperature, conditions, forecast for any location, city, or place. Check weather today, now, current conditions, temperature, rain, sun, cloudy, hot, cold, storm, snow"

Check failure on line 356 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

356:81 [line-length] line too long (220 > 80 characters)
category: "weather"
tags: ["weather", "temperature", "forecast", "climate"]
- tool:
Expand All @@ -370,7 +372,7 @@
description: "Number of results to return"
default: 5
required: ["query"]
description: "Search the internet, web search, find information online, browse web content, lookup, research, google, find answers, discover, investigate"

Check failure on line 375 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

375:81 [line-length] line too long (158 > 80 characters)
category: "search"
tags: ["search", "web", "internet", "information", "browse"]
- tool:
Expand All @@ -385,7 +387,7 @@
type: "string"
description: "Mathematical expression to evaluate"
required: ["expression"]
description: "Calculate mathematical expressions, solve math problems, arithmetic operations, compute numbers, addition, subtraction, multiplication, division, equations, formula"

Check failure on line 390 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

390:81 [line-length] line too long (183 > 80 characters)
category: "math"
tags: ["math", "calculation", "arithmetic", "compute", "numbers"]
- tool:
Expand All @@ -406,7 +408,7 @@
type: "string"
description: "Email body content"
required: ["to", "subject", "body"]
description: "Send email messages, email communication, contact people via email, mail, message, correspondence, notify, inform"

Check failure on line 411 in deploy/helm/semantic-router/values.yaml

View workflow job for this annotation

GitHub Actions / Run Validation Script

411:81 [line-length] line too long (132 > 80 characters)
category: "communication"
tags: ["email", "send", "communication", "message", "contact"]
- tool:
Expand Down
6 changes: 3 additions & 3 deletions deploy/kubernetes/aibrix/semantic-router-values/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,11 @@ config:
# Classifier configuration
classifier:
category_model:
model_id: "models/category_classifier_modernbert-base_model"
use_modernbert: true
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
use_modernbert: false # Use LoRA intent classifier with auto-detection
threshold: 0.6
use_cpu: true
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
pii_model:
# Support both traditional (modernbert) and LoRA-based PII detection
# When model_type is "auto", the system will auto-detect LoRA configuration
Expand Down
8 changes: 4 additions & 4 deletions e2e/profiles/ai-gateway/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -461,17 +461,17 @@ config:
# Classifier configuration
classifier:
category_model:
model_id: "models/category_classifier_modernbert-base_model"
use_modernbert: true
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
use_modernbert: false # Use LoRA intent classifier with auto-detection
threshold: 0.6
use_cpu: true
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
pii_model:
# Support both traditional (modernbert) and LoRA-based PII detection
# When model_type is "auto", the system will auto-detect LoRA configuration
model_id: "models/lora_pii_detector_bert-base-uncased_model"
use_modernbert: false # Use LoRA PII model with auto-detection
threshold: 0.7
threshold: 0.9
use_cpu: true
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

Expand Down
8 changes: 4 additions & 4 deletions e2e/profiles/dynamic-config/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ config:

classifier:
category_model:
model_id: "models/category_classifier_modernbert-base_model"
use_modernbert: true
model_id: "models/lora_intent_classifier_bert-base-uncased_model"
use_modernbert: false # Use LoRA intent classifier with auto-detection
threshold: 0.6
use_cpu: true
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
category_mapping_path: "models/lora_intent_classifier_bert-base-uncased_model/category_mapping.json"
pii_model:
model_id: "models/lora_pii_detector_bert-base-uncased_model"
use_modernbert: false # Use LoRA PII model with auto-detection
threshold: 0.7
threshold: 0.9
use_cpu: true
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"

Expand Down
Loading
Loading