Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions src/semantic-router/pkg/services/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package services
import (
"fmt"
"os"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -35,9 +36,9 @@ func NewClassificationService(classifier *classification.Classifier, config *con
}

// NewUnifiedClassificationService creates a new service with unified classifier
func NewUnifiedClassificationService(unifiedClassifier *classification.UnifiedClassifier, config *config.RouterConfig) *ClassificationService {
func NewUnifiedClassificationService(unifiedClassifier *classification.UnifiedClassifier, legacyClassifier *classification.Classifier, config *config.RouterConfig) *ClassificationService {
service := &ClassificationService{
classifier: nil, // Legacy classifier not used
classifier: legacyClassifier,
unifiedClassifier: unifiedClassifier,
config: config,
}
Expand All @@ -54,16 +55,69 @@ func NewClassificationServiceWithAutoDiscovery(config *config.RouterConfig) (*Cl
observability.Debugf("Debug: Attempting to discover models in: ./models")

// Always try to auto-discover and initialize unified classifier for batch processing
unifiedClassifier, err := classification.AutoInitializeUnifiedClassifier("./models")
// Use model path from config, fallback to "./models" if not specified
modelsPath := "./models"
if config != nil && config.Classifier.CategoryModel.ModelID != "" {
// Extract the models directory from the model path
// e.g., "models/category_classifier_modernbert-base_model" -> "models"
if idx := strings.Index(config.Classifier.CategoryModel.ModelID, "/"); idx > 0 {
modelsPath = config.Classifier.CategoryModel.ModelID[:idx]
}
}
unifiedClassifier, ucErr := classification.AutoInitializeUnifiedClassifier(modelsPath)
if ucErr != nil {
observability.Infof("Unified classifier auto-discovery failed: %v", ucErr)
}
// create legacy classifier
Copy link
Preview

Copilot AI Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment should start with a capital letter: 'Create legacy classifier'.

Suggested change
// create legacy classifier
// Create legacy classifier

Copilot uses AI. Check for mistakes.

legacyClassifier, lcErr := createLegacyClassifier(config)
if lcErr != nil {
observability.Warnf("Legacy classifier initialization failed: %v", lcErr)
}
if unifiedClassifier == nil && legacyClassifier == nil {
observability.Warnf("No classifier initialized. Using placeholder service.")
}
return NewUnifiedClassificationService(unifiedClassifier, legacyClassifier, config), nil
}

// createLegacyClassifier creates a legacy classifier with proper model loading
func createLegacyClassifier(config *config.RouterConfig) (*classification.Classifier, error) {
// Load category mapping
var categoryMapping *classification.CategoryMapping
if config.Classifier.CategoryModel.CategoryMappingPath != "" {
var err error
categoryMapping, err = classification.LoadCategoryMapping(config.Classifier.CategoryModel.CategoryMappingPath)
if err != nil {
return nil, fmt.Errorf("failed to load category mapping: %w", err)
}
}

// Load PII mapping
var piiMapping *classification.PIIMapping
if config.Classifier.PIIModel.PIIMappingPath != "" {
var err error
piiMapping, err = classification.LoadPIIMapping(config.Classifier.PIIModel.PIIMappingPath)
if err != nil {
return nil, fmt.Errorf("failed to load PII mapping: %w", err)
}
}

// Load jailbreak mapping
var jailbreakMapping *classification.JailbreakMapping
if config.PromptGuard.JailbreakMappingPath != "" {
var err error
jailbreakMapping, err = classification.LoadJailbreakMapping(config.PromptGuard.JailbreakMappingPath)
if err != nil {
return nil, fmt.Errorf("failed to load jailbreak mapping: %w", err)
}
}

// Create classifier
classifier, err := classification.NewClassifier(config, categoryMapping, piiMapping, jailbreakMapping)
if err != nil {
// Log the discovery failure but don't fail - fall back to legacy processing
observability.Infof("Unified classifier auto-discovery failed: %v. Using legacy processing.", err)
return NewClassificationService(nil, config), nil
return nil, fmt.Errorf("failed to create classifier: %w", err)
}

// Success! Create service with unified classifier
observability.Infof("Unified classifier auto-discovered and initialized. Using batch processing.")
return NewUnifiedClassificationService(unifiedClassifier, config), nil
return classifier, nil
}

// GetGlobalClassificationService returns the global classification service instance
Expand Down
Loading