diff --git a/deploy/helm/semantic-router/values.yaml b/deploy/helm/semantic-router/values.yaml index fa5f052c8..1ca81118b 100644 --- a/deploy/helm/semantic-router/values.yaml +++ b/deploy/helm/semantic-router/values.yaml @@ -167,6 +167,9 @@ initContainer: repo: LLM-Semantic-Router/jailbreak_classifier_modernbert-base_model - name: pii_classifier_modernbert-base_presidio_token_model repo: LLM-Semantic-Router/pii_classifier_modernbert-base_presidio_token_model + # LoRA PII detector (for auto-detection feature) + - name: lora_pii_detector_bert-base-uncased_model + repo: LLM-Semantic-Router/lora_pii_detector_bert-base-uncased_model # Autoscaling configuration diff --git a/deploy/kubernetes/aibrix/semantic-router-values/values.yaml b/deploy/kubernetes/aibrix/semantic-router-values/values.yaml index ffb17a1d4..ec1d43537 100644 --- a/deploy/kubernetes/aibrix/semantic-router-values/values.yaml +++ b/deploy/kubernetes/aibrix/semantic-router-values/values.yaml @@ -437,8 +437,10 @@ config: use_cpu: true category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" pii_model: - model_id: "models/pii_classifier_modernbert-base_presidio_token_model" - use_modernbert: true + # Support both traditional (modernbert) and LoRA-based PII detection + # When model_type is "auto", the system will auto-detect LoRA configuration + model_id: "models/lora_pii_detector_bert-base-uncased_model" + use_modernbert: false # Use LoRA PII model with auto-detection threshold: 0.7 use_cpu: true pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" diff --git a/e2e/profiles/ai-gateway/values.yaml b/e2e/profiles/ai-gateway/values.yaml index 8523cdeb3..ed67b6eaf 100644 --- a/e2e/profiles/ai-gateway/values.yaml +++ b/e2e/profiles/ai-gateway/values.yaml @@ -467,8 +467,10 @@ config: use_cpu: true category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" pii_model: - model_id: "models/pii_classifier_modernbert-base_presidio_token_model" - use_modernbert: true + # Support both traditional (modernbert) and LoRA-based PII detection + # When model_type is "auto", the system will auto-detect LoRA configuration + model_id: "models/lora_pii_detector_bert-base-uncased_model" + use_modernbert: false # Use LoRA PII model with auto-detection threshold: 0.7 use_cpu: true pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" diff --git a/e2e/profiles/dynamic-config/values.yaml b/e2e/profiles/dynamic-config/values.yaml index af8d9ee71..b14a2b92c 100644 --- a/e2e/profiles/dynamic-config/values.yaml +++ b/e2e/profiles/dynamic-config/values.yaml @@ -48,8 +48,8 @@ config: use_cpu: true category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" pii_model: - model_id: "models/pii_classifier_modernbert-base_presidio_token_model" - use_modernbert: true + model_id: "models/lora_pii_detector_bert-base-uncased_model" + use_modernbert: false # Use LoRA PII model with auto-detection threshold: 0.7 use_cpu: true pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" diff --git a/src/semantic-router/pkg/classification/classifier.go b/src/semantic-router/pkg/classification/classifier.go index 9e737a12b..8cfb8ee41 100644 --- a/src/semantic-router/pkg/classification/classifier.go +++ b/src/semantic-router/pkg/classification/classifier.go @@ -140,35 +140,55 @@ func createJailbreakInference(useModernBERT bool) JailbreakInference { } type PIIInitializer interface { - Init(modelID string, useCPU bool) error + Init(modelID string, useCPU bool, numClasses int) error } -type ModernBertPIIInitializer struct{} +type PIIInitializerImpl struct { + usedModernBERT bool // Track which init path succeeded for inference routing +} + +func (c *PIIInitializerImpl) Init(modelID string, useCPU bool, numClasses int) error { + // Try auto-detecting Candle BERT init first - checks for lora_config.json + // This enables LoRA PII models when available + success := candle_binding.InitCandleBertTokenClassifier(modelID, numClasses, useCPU) + if success { + c.usedModernBERT = false + logging.Infof("Initialized PII token classifier with auto-detection (LoRA or Traditional BERT)") + return nil + } -func (c *ModernBertPIIInitializer) Init(modelID string, useCPU bool) error { + // Fallback to ModernBERT-specific init for backward compatibility + // This handles models with incomplete configs (missing hidden_act, etc.) + logging.Infof("Auto-detection failed, falling back to ModernBERT PII initializer") err := candle_binding.InitModernBertPIITokenClassifier(modelID, useCPU) if err != nil { - return err + return fmt.Errorf("failed to initialize PII token classifier (both auto-detect and ModernBERT): %w", err) } - logging.Infof("Initialized ModernBERT PII token classifier for entity detection") + c.usedModernBERT = true + logging.Infof("Initialized ModernBERT PII token classifier (fallback mode)") return nil } -// createPIIInitializer creates the appropriate PII initializer (currently only ModernBERT) -func createPIIInitializer() PIIInitializer { return &ModernBertPIIInitializer{} } +// createPIIInitializer creates the PII initializer (auto-detecting) +func createPIIInitializer() PIIInitializer { + return &PIIInitializerImpl{} +} type PIIInference interface { ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) } -type ModernBertPIIInference struct{} +type PIIInferenceImpl struct{} -func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) { - return candle_binding.ClassifyModernBertPIITokens(text, configPath) +func (c *PIIInferenceImpl) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) { + // Auto-detecting inference - uses whichever classifier was initialized (LoRA or Traditional) + return candle_binding.ClassifyCandleBertTokens(text) } -// createPIIInference creates the appropriate PII inference (currently only ModernBERT) -func createPIIInference() PIIInference { return &ModernBertPIIInference{} } +// createPIIInference creates the PII inference (auto-detecting) +func createPIIInference() PIIInference { + return &PIIInferenceImpl{} +} // JailbreakDetection represents the result of jailbreak analysis for a piece of content type JailbreakDetection struct { @@ -348,7 +368,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p // Add in-tree classifier if configured if cfg.CategoryModel.ModelID != "" { - options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.UseModernBERT), createCategoryInference(cfg.UseModernBERT))) + options = append(options, withCategory(categoryMapping, createCategoryInitializer(cfg.CategoryModel.UseModernBERT), createCategoryInference(cfg.CategoryModel.UseModernBERT))) } // Add MCP classifier if configured @@ -509,7 +529,8 @@ func (c *Classifier) initializePIIClassifier() error { return fmt.Errorf("not enough PII types for classification, need at least 2, got %d", numPIIClasses) } - return c.piiInitializer.Init(c.Config.PIIModel.ModelID, c.Config.PIIModel.UseCPU) + // Pass numClasses to support auto-detection + return c.piiInitializer.Init(c.Config.PIIModel.ModelID, c.Config.PIIModel.UseCPU, numPIIClasses) } // EvaluateAllRules evaluates all rule types and returns matched rule names diff --git a/src/semantic-router/pkg/classification/classifier_test.go b/src/semantic-router/pkg/classification/classifier_test.go index 42f0f3320..080897f5b 100644 --- a/src/semantic-router/pkg/classification/classifier_test.go +++ b/src/semantic-router/pkg/classification/classifier_test.go @@ -287,7 +287,7 @@ var _ = Describe("jailbreak detection", func() { type MockPIIInitializer struct{ InitError error } -func (m *MockPIIInitializer) Init(_ string, useCPU bool) error { return m.InitError } +func (m *MockPIIInitializer) Init(_ string, useCPU bool, numClasses int) error { return m.InitError } type MockPIIInferenceResponse struct { classifyTokensResult candle_binding.TokenClassificationResult diff --git a/src/semantic-router/pkg/extproc/extproc_test.go b/src/semantic-router/pkg/extproc/extproc_test.go index fc320e746..e8b9fab76 100644 --- a/src/semantic-router/pkg/extproc/extproc_test.go +++ b/src/semantic-router/pkg/extproc/extproc_test.go @@ -2030,6 +2030,8 @@ var _ = Describe("Caching Functionality", func() { BeforeEach(func() { cfg = CreateTestConfig() cfg.Enabled = true + // Disable PII detection for caching tests (not needed and avoids model loading issues) + cfg.InlineModels.Classifier.PIIModel.ModelID = "" var err error router, err = CreateTestRouter(cfg)