From 563bc0ed113d1641270608fed3ba6b70186fd274 Mon Sep 17 00:00:00 2001
From: OneZero-Y
Date: Thu, 23 Oct 2025 10:53:07 +0800
Subject: [PATCH] fix: Fix duplicate UNIFIED_CLASSIFIER definition and optimize lock contention

- Remove duplicate UNIFIED_CLASSIFIER global state
- Optimize PARALLEL_LORA_ENGINE lock contention by using Arc clone

Signed-off-by: OneZero-Y
---
 candle-binding/src/ffi/classify.rs | 136 ++++++++++++++---------------
 candle-binding/src/ffi/init.rs     |  13 +--
 2 files changed, 74 insertions(+), 75 deletions(-)

diff --git a/candle-binding/src/ffi/classify.rs b/candle-binding/src/ffi/classify.rs
index 264c8e14..c42b0631 100644
--- a/candle-binding/src/ffi/classify.rs
+++ b/candle-binding/src/ffi/classify.rs
@@ -11,12 +11,6 @@ use crate::ffi::memory::{
     allocate_pii_result_array, allocate_security_result_array,
 };
 use crate::ffi::types::*;
-use crate::BertClassifier;
-use lazy_static::lazy_static;
-use std::ffi::{c_char, CStr};
-use std::sync::{Arc, Mutex};
-
-use crate::classifiers::unified::DualPathUnifiedClassifier;
 use crate::model_architectures::traditional::bert::{
     TRADITIONAL_BERT_CLASSIFIER, TRADITIONAL_BERT_TOKEN_CLASSIFIER,
 };
@@ -25,9 +19,13 @@ use crate::model_architectures::traditional::modernbert::{
     TRADITIONAL_MODERNBERT_PII_CLASSIFIER, TRADITIONAL_MODERNBERT_TOKEN_CLASSIFIER,
 };
 use crate::model_architectures::traits::TaskType;
+use crate::BertClassifier;
+use lazy_static::lazy_static;
+use std::ffi::{c_char, CStr};
+use std::sync::{Arc, Mutex};
 extern crate lazy_static;
 
-use crate::ffi::init::PARALLEL_LORA_ENGINE;
+use crate::ffi::init::{PARALLEL_LORA_ENGINE, UNIFIED_CLASSIFIER};
 
 /// Load id2label mapping from model config.json file
 /// Returns HashMap mapping class index (as string) to label name
@@ -42,8 +40,9 @@ pub fn load_id2label_from_config(
 
 // Global state for classification using dual-path architecture
 lazy_static! {
-    static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
-    // Legacy classifiers for backward compatibility
+    // NOTE: UNIFIED_CLASSIFIER is defined in ffi/init.rs and re-exported
+    // We import it here to avoid duplicate definitions
+    // Legacy classifiers for backward compatibility (still needed for old API paths)
     static ref BERT_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
     static ref BERT_PII_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
     static ref BERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
@@ -654,71 +653,70 @@ pub extern "C" fn classify_batch_with_lora(
     }
 
     let start_time = std::time::Instant::now();
-    let engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
-    match engine_guard.as_ref() {
-        Some(engine) => {
-            let text_refs: Vec<&str> = text_vec.iter().map(|s| s.as_ref()).collect();
-            match engine.parallel_classify(&text_refs) {
-                Ok(parallel_result) => {
-                    let _processing_time_ms = start_time.elapsed().as_millis() as f32;
-
-                    // Allocate C arrays for LoRA results
-                    let intent_results_ptr =
-                        unsafe { allocate_lora_intent_array(&parallel_result.intent_results) };
-                    let pii_results_ptr =
-                        unsafe { allocate_lora_pii_array(&parallel_result.pii_results) };
-                    let security_results_ptr =
-                        unsafe { allocate_lora_security_array(&parallel_result.security_results) };
-                    LoRABatchResult {
-                        intent_results: intent_results_ptr,
-                        pii_results: pii_results_ptr,
-                        security_results: security_results_ptr,
-                        batch_size: texts_count as i32,
-                        avg_confidence: {
-                            let mut total_confidence = 0.0f32;
-                            let mut count = 0;
-
-                            // Sum intent confidences
-                            for intent in &parallel_result.intent_results {
-                                total_confidence += intent.confidence;
-                                count += 1;
-                            }
-
-                            // Sum PII confidences
-                            for pii in &parallel_result.pii_results {
-                                total_confidence += pii.confidence;
-                                count += 1;
-                            }
-
-                            // Sum security confidences
-                            for security in &parallel_result.security_results {
-                                total_confidence += security.confidence;
-                                count += 1;
-                            }
-
-                            if count > 0 {
-                                total_confidence / count as f32
-                            } else {
-                                0.0
-                            }
-                        },
+    // Optimization: Clone Arc to minimize lock holding time
+    // Lock is only held during the clone operation (~nanoseconds), not during inference
+    let engine: Arc<ParallelLoRAEngine> = {
+        let engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
+        match engine_guard.as_ref() {
+            Some(e) => e.clone(),
+            None => {
+                eprintln!("PARALLEL_LORA_ENGINE not initialized");
+                return default_result;
+            }
+        }
+    }; // Lock is released here immediately after clone
+
+    // Now perform inference without holding the lock (allows concurrent requests)
+    let text_refs: Vec<&str> = text_vec.iter().map(|s| s.as_ref()).collect();
+    match engine.parallel_classify(&text_refs) {
+        Ok(parallel_result) => {
+            let _processing_time_ms = start_time.elapsed().as_millis() as f32;
+
+            // Allocate C arrays for LoRA results
+            let intent_results_ptr =
+                unsafe { allocate_lora_intent_array(&parallel_result.intent_results) };
+            let pii_results_ptr = unsafe { allocate_lora_pii_array(&parallel_result.pii_results) };
+            let security_results_ptr =
+                unsafe { allocate_lora_security_array(&parallel_result.security_results) };
+
+            LoRABatchResult {
+                intent_results: intent_results_ptr,
+                pii_results: pii_results_ptr,
+                security_results: security_results_ptr,
+                batch_size: texts_count as i32,
+                avg_confidence: {
+                    let mut total_confidence = 0.0f32;
+                    let mut count = 0;
+
+                    // Sum intent confidences
+                    for intent in &parallel_result.intent_results {
+                        total_confidence += intent.confidence;
+                        count += 1;
                     }
-                }
-                Err(e) => {
-                    println!("LoRA parallel classification failed: {}", e);
-                    LoRABatchResult {
-                        intent_results: std::ptr::null_mut(),
-                        pii_results: std::ptr::null_mut(),
-                        security_results: std::ptr::null_mut(),
-                        batch_size: 0,
-                        avg_confidence: 0.0,
+
+                    // Sum PII confidences
+                    for pii in &parallel_result.pii_results {
+                        total_confidence += pii.confidence;
+                        count += 1;
                     }
-                }
+
+                    // Sum security confidences
+                    for security in &parallel_result.security_results {
+                        total_confidence += security.confidence;
+                        count += 1;
+                    }
+
+                    if count > 0 {
+                        total_confidence / count as f32
+                    } else {
+                        0.0
+                    }
+                },
             }
         }
-        None => {
-            println!("ParallelLoRAEngine not initialized - call init function first");
+        Err(e) => {
+            println!("LoRA parallel classification failed: {}", e);
             LoRABatchResult {
                 intent_results: std::ptr::null_mut(),
                 pii_results: std::ptr::null_mut(),
                 security_results: std::ptr::null_mut(),
                 batch_size: 0,
                 avg_confidence: 0.0,
diff --git a/candle-binding/src/ffi/init.rs b/candle-binding/src/ffi/init.rs
index f2fd276f..1bfc2f8a 100644
--- a/candle-binding/src/ffi/init.rs
+++ b/candle-binding/src/ffi/init.rs
@@ -17,10 +17,11 @@ lazy_static! {
     static ref BERT_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
     static ref BERT_PII_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
     static ref BERT_JAILBREAK_CLASSIFIER: Arc<Mutex<Option<BertClassifier>>> = Arc::new(Mutex::new(None));
-    // Unified classifier for dual-path architecture
-    static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
-    // Parallel LoRA engine for high-performance classification
-    pub static ref PARALLEL_LORA_ENGINE: Arc<Mutex<Option<ParallelLoRAEngine>>> = Arc::new(Mutex::new(None));
+    // Unified classifier for dual-path architecture (exported for use in classify.rs)
+    pub static ref UNIFIED_CLASSIFIER: Arc<Mutex<Option<DualPathUnifiedClassifier>>> = Arc::new(Mutex::new(None));
+    // Parallel LoRA engine for high-performance classification (primary path for LoRA models)
+    // Wrapped in Arc for cheap cloning and concurrent access
+    pub static ref PARALLEL_LORA_ENGINE: Arc<Mutex<Option<Arc<ParallelLoRAEngine>>>> = Arc::new(Mutex::new(None));
     // LoRA token classifier for token-level classification
     pub static ref LORA_TOKEN_CLASSIFIER: Arc<Mutex<Option<LoRATokenClassifier>>> = Arc::new(Mutex::new(None));
 }
@@ -719,9 +720,9 @@ pub extern "C" fn init_lora_unified_classifier(
         use_cpu,
     ) {
         Ok(engine) => {
-            // Store in global static variable
+            // Store in global static variable (wrapped in Arc for efficient cloning)
            let mut engine_guard = PARALLEL_LORA_ENGINE.lock().unwrap();
-            *engine_guard = Some(engine);
+            *engine_guard = Some(Arc::new(engine));
             true
         }
         Err(e) => {
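
Note on the second bullet of the commit message: the sketch below is a minimal, self-contained illustration of the "clone the Arc under a short lock, run the slow work outside it" pattern, not code from candle-binding. The names Engine, classify, classify_all, slot, and main are invented for this example; the actual change stores an Arc<ParallelLoRAEngine> inside the PARALLEL_LORA_ENGINE global as shown in the diff above.

use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

// Placeholder engine type, invented for this sketch only.
struct Engine;

impl Engine {
    fn classify(&self, text: &str) -> usize {
        // Stand-in for slow model inference.
        thread::sleep(Duration::from_millis(50));
        text.len()
    }
}

// Mirrors the Mutex<Option<Arc<...>>> layout the patch uses for PARALLEL_LORA_ENGINE.
fn classify_all(slot: &Mutex<Option<Arc<Engine>>>, texts: &[&str]) -> Option<Vec<usize>> {
    // Critical section: just long enough to clone the Arc (one atomic ref-count bump).
    let engine: Arc<Engine> = {
        let guard = slot.lock().unwrap();
        guard.as_ref()?.clone()
    }; // guard dropped here, so other callers can take the lock immediately

    // The slow part runs without the lock, so concurrent calls are not serialized.
    Some(texts.iter().map(|t| engine.classify(t)).collect())
}

fn main() {
    let slot = Arc::new(Mutex::new(Some(Arc::new(Engine))));

    // Two threads classify at the same time; each holds the lock only for the clone.
    let handles: Vec<_> = (0..2)
        .map(|i| {
            let slot = Arc::clone(&slot);
            thread::spawn(move || {
                let out = classify_all(&slot, &["hello", "world"]);
                println!("thread {i}: {out:?}");
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }
}

Cloning the Arc costs a single atomic increment, so the mutex is held for nanoseconds even when inference itself takes milliseconds, which is the lock-contention win the patch is after.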