Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 67 additions & 69 deletions candle-binding/src/ffi/classify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@ use crate::ffi::memory::{
allocate_pii_result_array, allocate_security_result_array,
};
use crate::ffi::types::*;
use crate::BertClassifier;
use lazy_static::lazy_static;
use std::ffi::{c_char, CStr};
use std::sync::{Arc, Mutex};

use crate::classifiers::unified::DualPathUnifiedClassifier;
use crate::model_architectures::traditional::bert::{
TRADITIONAL_BERT_CLASSIFIER, TRADITIONAL_BERT_TOKEN_CLASSIFIER,
};
Expand All @@ -25,9 +19,13 @@ use crate::model_architectures::traditional::modernbert::{
TRADITIONAL_MODERNBERT_PII_CLASSIFIER, TRADITIONAL_MODERNBERT_TOKEN_CLASSIFIER,
};
use crate::model_architectures::traits::TaskType;
use crate::BertClassifier;
use lazy_static::lazy_static;
use std::ffi::{c_char, CStr};
use std::sync::{Arc, Mutex};
extern crate lazy_static;

use crate::ffi::init::PARALLEL_LORA_ENGINE;
use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};

/// Load id2label mapping from model config.json file
/// Returns HashMap mapping class index (as string) to label name
Expand All @@ -42,8 +40,9 @@ pub fn load_id2label_from_config(

// Global state for classification using dual-path architecture
lazy_static! {
static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
// Legacy classifiers for backward compatibility
// NOTE: UNIFIED_CLASSIFIER is defined in ffi/init.rs and re-exported
// We import it here to avoid duplicate definitions
// Legacy classifiers for backward compatibility (still needed for old API paths)
static ref BERT_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
static ref BERT_PII_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
static ref BERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
Expand Down Expand Up @@ -654,71 +653,70 @@ pub extern "C" fn classify_batch_with_lora(
}

let start_time = std::time::Instant::now();
let engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
match engine_guard.as_ref() {
Some(engine) => {
let text_refs: Vec<&str> = text_vec.iter().map(|s| s.as_ref()).collect();
match engine.parallel_classify(&text_refs) {
Ok(parallel_result) => {
let _processing_time_ms = start_time.elapsed().as_millis() as f32;

// Allocate C arrays for LoRA results
let intent_results_ptr =
unsafe { allocate_lora_intent_array(&parallel_result.intent_results) };
let pii_results_ptr =
unsafe { allocate_lora_pii_array(&parallel_result.pii_results) };
let security_results_ptr =
unsafe { allocate_lora_security_array(&parallel_result.security_results) };

LoRABatchResult {
intent_results: intent_results_ptr,
pii_results: pii_results_ptr,
security_results: security_results_ptr,
batch_size: texts_count as i32,
avg_confidence: {
let mut total_confidence = 0.0f32;
let mut count = 0;

// Sum intent confidences
for intent in &parallel_result.intent_results {
total_confidence += intent.confidence;
count += 1;
}

// Sum PII confidences
for pii in &parallel_result.pii_results {
total_confidence += pii.confidence;
count += 1;
}

// Sum security confidences
for security in &parallel_result.security_results {
total_confidence += security.confidence;
count += 1;
}

if count > 0 {
total_confidence / count as f32
} else {
0.0
}
},
// Optimization: Clone Arc to minimize lock holding time
// Lock is only held during the clone operation (~nanoseconds), not during inference
let engine: Arc<crate::classifiers::lora::parallel_engine::ParallelLoRAEngine> = {
let engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
match engine_guard.as_ref() {
Some(e) => e.clone(),
None => {
eprintln!("PARALLEL_LORA_ENGINE not initialized");
return default_result;
}
}
}; // Lock is released here immediately after clone

// Now perform inference without holding the lock (allows concurrent requests)
let text_refs: Vec<&str> = text_vec.iter().map(|s| s.as_ref()).collect();
match engine.parallel_classify(&text_refs) {
Ok(parallel_result) => {
let _processing_time_ms = start_time.elapsed().as_millis() as f32;

// Allocate C arrays for LoRA results
let intent_results_ptr =
unsafe { allocate_lora_intent_array(&parallel_result.intent_results) };
let pii_results_ptr = unsafe { allocate_lora_pii_array(&parallel_result.pii_results) };
let security_results_ptr =
unsafe { allocate_lora_security_array(&parallel_result.security_results) };

LoRABatchResult {
intent_results: intent_results_ptr,
pii_results: pii_results_ptr,
security_results: security_results_ptr,
batch_size: texts_count as i32,
avg_confidence: {
let mut total_confidence = 0.0f32;
let mut count = 0;

// Sum intent confidences
for intent in &parallel_result.intent_results {
total_confidence += intent.confidence;
count += 1;
}
}
Err(e) => {
println!("LoRA parallel classification failed: {}", e);
LoRABatchResult {
intent_results: std::ptr::null_mut(),
pii_results: std::ptr::null_mut(),
security_results: std::ptr::null_mut(),
batch_size: 0,
avg_confidence: 0.0,

// Sum PII confidences
for pii in &parallel_result.pii_results {
total_confidence += pii.confidence;
count += 1;
}
}

// Sum security confidences
for security in &parallel_result.security_results {
total_confidence += security.confidence;
count += 1;
}

if count > 0 {
total_confidence / count as f32
} else {
0.0
}
},
}
}
None => {
println!("ParallelLoRAEngine not initialized - call init function first");
Err(e) => {
println!("LoRA parallel classification failed: {}", e);
LoRABatchResult {
intent_results: std::ptr::null_mut(),
pii_results: std::ptr::null_mut(),
Expand Down
13 changes: 7 additions & 6 deletions candle-binding/src/ffi/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ lazy_static! {
static ref BERT_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
static ref BERT_PII_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
static ref BERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
// Unified classifier for dual-path architecture
static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<crate::classifiers::unified::DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
// Parallel LoRA engine for high-performance classification
pub static ref PARALLEL_LORA_ENGINE: Arc<Mutex<Option<crate::classifiers::lora::parallel_engine::ParallelLoRAEngine>>> = Arc::new(Mutex::new(None));
// Unified classifier for dual-path architecture (exported for use in classify.rs)
pub static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<crate::classifiers::unified::DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
// Parallel LoRA engine for high-performance classification (primary path for LoRA models)
// Wrapped in Arc for cheap cloning and concurrent access
pub static ref PARALLEL_LORA_ENGINE: Arc<Mutex<Option<Arc<crate::classifiers::lora::parallel_engine::ParallelLoRAEngine>>>> = Arc::new(Mutex::new(None));
// LoRA token classifier for token-level classification
pub static ref LORA_TOKEN_CLASSIFIER: Arc<Mutex<Option<crate::classifiers::lora::token_lora::LoRATokenClassifier>>> = Arc::new(Mutex::new(None));
}
Expand Down Expand Up @@ -719,9 +720,9 @@ pub extern "C" fn init_lora_unified_classifier(
use_cpu,
) {
Ok(engine) => {
// Store in global static variable
// Store in global static variable (wrapped in Arc for efficient cloning)
let mut engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
*engine_guard = Some(engine);
*engine_guard = Some(Arc::new(engine));
true
}
Err(e) => {
Expand Down
Loading