From a9c21e2e3d640192b7d50c70df6a2540a71c25b1 Mon Sep 17 00:00:00 2001
From: bitliu
Date: Wed, 27 Aug 2025 20:16:07 +0800
Subject: [PATCH] feat: support reasoning mode

Signed-off-by: bitliu
---
 config/config.yaml                            | 442 ++++++++++--------
 config/envoy.yaml                             |   8 +-
 src/semantic-router/pkg/config/config.go      |   8 +-
 .../pkg/extproc/endpoint_selection_test.go    |   6 +-
 src/semantic-router/pkg/extproc/processor.go  |  14 +-
 .../pkg/extproc/reason_mode_config_test.go    | 309 ++++++++++++
 .../pkg/extproc/reason_mode_selector.go       | 161 +++++++
 .../pkg/extproc/reasoning_integration_test.go | 205 ++++++++
 .../pkg/extproc/request_handler.go            | 173 +++----
 src/semantic-router/pkg/extproc/router.go     |   3 +
 src/semantic-router/pkg/extproc/utils.go      |   2 +
 11 files changed, 1031 insertions(+), 300 deletions(-)
 create mode 100644 src/semantic-router/pkg/extproc/reason_mode_config_test.go
 create mode 100644 src/semantic-router/pkg/extproc/reason_mode_selector.go
 create mode 100644 src/semantic-router/pkg/extproc/reasoning_integration_test.go

diff --git a/config/config.yaml b/config/config.yaml
index 45a6a970..5d98b83e 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -1,208 +1,236 @@
-bert_model:
-  model_id: sentence-transformers/all-MiniLM-L12-v2
-  threshold: 0.6
-  use_cpu: true
-semantic_cache:
-  enabled: true
-  similarity_threshold: 0.8
-  max_entries: 1000
-  ttl_seconds: 3600
-tools:
-  enabled: true # Set to true to enable automatic tool selection
-  top_k: 3 # Number of most relevant tools to select
-  similarity_threshold: 0.2 # Threshold for tool similarity
-  tools_db_path: "config/tools_db.json"
-  fallback_to_empty: true # If true, return no tools on failure; if false, return error
-prompt_guard:
-  enabled: true
-  use_modernbert: true
-  model_id: "models/jailbreak_classifier_modernbert-base_model"
-  threshold: 0.7
-  use_cpu: true
-  jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json"
-gpu_config:
-  flops: 312000000000000 # 312e12 fp16
-  hbm: 2000000000000 # 2e12 (2 TB/s)
-  description: "A100-80G" # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
-
-# vLLM Endpoints Configuration - supports multiple endpoints, each can serve multiple models
-vllm_endpoints:
-  - name: "endpoint1"
-    address: "192.168.12.90"
-    port: 11434
-    models:
-      - "phi4"
-      - "gemma3:27b"
-    weight: 1 # Load balancing weight
-    health_check_path: "/health" # Optional health check endpoint
-  - name: "endpoint2"
-    address: "192.168.12.91"
-    port: 11434
-    models:
-      - "mistral-small3.1"
-    weight: 1
-    health_check_path: "/health"
-  - name: "endpoint3"
-    address: "192.168.12.92"
-    port: 11434
-    models:
-      - "phi4" # Same model can be served by multiple endpoints for redundancy
-      - "mistral-small3.1"
-    weight: 2 # Higher weight for more powerful endpoint
-
-model_config:
-  phi4:
-    param_count: 14000000000 # 14B parameters https://huggingface.co/microsoft/phi-4
-    batch_size: 512.0 # vLLM default batch size
-    context_size: 16384.0 # based on https://huggingface.co/microsoft/phi-4
-    pii_policy:
-      allow_by_default: false # Deny all PII by default
-      pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types
-    # Specify which endpoints can serve this model (optional - if not specified, uses all endpoints that list this model)
-    preferred_endpoints: ["endpoint1", "endpoint3"]
-  gemma3:27b:
-    param_count: 27000000000 # 27B parameters (base version)
-    batch_size: 512.0
-    context_size: 16384.0
-    pii_policy:
-      allow_by_default: false # Deny all PII by default
-      pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types
-    preferred_endpoints: ["endpoint1"]
-  "mistral-small3.1":
-    param_count: 22000000000
-    batch_size: 512.0
-    context_size: 16384.0
-    pii_policy:
-      allow_by_default: false # Deny all PII by default
-      pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types
-    preferred_endpoints: ["endpoint2", "endpoint3"]
-
-# Classifier configuration for text classification
-classifier:
-  category_model:
-    model_id: "models/category_classifier_modernbert-base_model" #TODO: Use local model for now before the code can download the entire model from huggingface
-    use_modernbert: true
-    threshold: 0.6
-    use_cpu: true
-    category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
-  pii_model:
-    model_id: "models/pii_classifier_modernbert-base_presidio_token_model" #TODO: Use local model for now before the code can download the entire model from huggingface
-    use_modernbert: true
-    threshold: 0.7
-    use_cpu: true
-    pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
-  load_aware: false
-categories:
-- name: business
-  model_scores:
-  - model: phi4
-    score: 0.8
-  - model: gemma3:27b
-    score: 0.4
-  - model: mistral-small3.1
-    score: 0.2
-- name: law
-  model_scores:
-  - model: gemma3:27b
-    score: 0.8
-  - model: phi4
-    score: 0.6
-  - model: mistral-small3.1
-    score: 0.4
-- name: psychology
-  model_scores:
-  - model: mistral-small3.1
-    score: 0.6
-  - model: gemma3:27b
-    score: 0.4
-  - model: phi4
-    score: 0.4
-- name: biology
-  model_scores:
-  - model: mistral-small3.1
-    score: 0.8
-  - model: gemma3:27b
-    score: 0.6
-  - model: phi4
-    score: 0.2
-- name: chemistry
-  model_scores:
-  - model: mistral-small3.1
-    score: 0.8
-  - model: gemma3:27b
-    score: 0.6
-  - model: phi4
-    score: 0.6
-- name: history
-  model_scores:
-  - model: mistral-small3.1
-    score: 0.8
-  - model: phi4
-    score: 0.6
-  - model: gemma3:27b
-    score: 0.4
-- name: other
-  model_scores:
-  - model: gemma3:27b
-    score: 0.8
-  - model: phi4
-    score: 0.6
-  - model: mistral-small3.1
-    score: 0.6
-- name: health
-  model_scores:
-  - model: gemma3:27b
-    score: 0.8
-  - model: phi4
-    score: 0.8
-  - model: mistral-small3.1
-    score: 0.6
-- name: economics
-  model_scores:
-  - model: gemma3:27b
-    score: 0.8
-  - model: mistral-small3.1
-    score: 0.8
-  - model: phi4
-    score: 0.0
-- name: math
-  model_scores:
-  - model: phi4
-    score: 1.0
-  - model: mistral-small3.1
-    score: 0.8
-  - model: gemma3:27b
-    score: 0.6
-- name: physics
-  model_scores:
-  - model: gemma3:27b
-    score: 0.4
-  - model: phi4
-    score: 0.4
-  - model: mistral-small3.1
-    score: 0.4
-- name: computer science
-  model_scores:
-  - model: gemma3:27b
-    score: 0.6
-  - model: mistral-small3.1
-    score: 0.6
-  - model: phi4
-    score: 0.0
-- name: philosophy
-  model_scores:
-  - model: phi4
-    score: 0.6
-  - model: gemma3:27b
-    score: 0.2
-  - model: mistral-small3.1
-    score: 0.2
-- name: engineering
-  model_scores:
-  - model: gemma3:27b
-    score: 0.6
-  - model: mistral-small3.1
-    score: 0.6
-  - model: phi4
-    score: 0.2
+bert_model:
+  model_id: sentence-transformers/all-MiniLM-L12-v2
+  threshold: 0.6
+  use_cpu: true
+semantic_cache:
+  enabled: true
+  similarity_threshold: 0.8
+  max_entries: 1000
+  ttl_seconds: 3600
+tools:
+  enabled: true # Set to true to enable automatic tool selection
+  top_k: 3 # Number of most relevant tools to select
+  similarity_threshold: 0.2 # Threshold for tool similarity
+  tools_db_path: "config/tools_db.json"
+  fallback_to_empty: true # If true, return no tools on failure; if false, return error
+prompt_guard:
+  enabled: true
+  use_modernbert: true
+  model_id: "models/jailbreak_classifier_modernbert-base_model"
+  threshold: 0.7
+  use_cpu: true
+  jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json"
+gpu_config:
+  flops: 312000000000000 # 312e12 fp16
+  hbm: 2000000000000 # 2e12 (2 TB/s)
+  description: "A100-80G" # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+
+# vLLM Endpoints Configuration - supports multiple endpoints, each can serve multiple models
+vllm_endpoints:
+  - name: "endpoint1"
+    address: "192.168.12.90"
+    port: 11434
+    models:
+      - "phi4"
+      - "gemma3:27b"
+    weight: 1 # Load balancing weight
+    health_check_path: "/health" # Optional health check endpoint
+  - name: "endpoint2"
+    address: "192.168.12.91"
+    port: 11434
+    models:
+      - "mistral-small3.1"
+    weight: 1
+    health_check_path: "/health"
+  - name: "endpoint3"
+    address: "192.168.12.92"
+    port: 11434
+    models:
+      - "phi4" # Same model can be served by multiple endpoints for redundancy
+      - "mistral-small3.1"
+    weight: 2 # Higher weight for more powerful endpoint
+
+model_config:
+  phi4:
+    param_count: 14000000000 # 14B parameters https://huggingface.co/microsoft/phi-4
+    batch_size: 512.0 # vLLM default batch size
+    context_size: 16384.0 # based on https://huggingface.co/microsoft/phi-4
+    pii_policy:
+      allow_by_default: false # Deny all PII by default
+      pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types
+    # Specify which endpoints can serve this model (optional - if not specified, uses all endpoints that list this model)
+    preferred_endpoints: ["endpoint1", "endpoint3"]
+  gemma3:27b:
+    param_count: 27000000000 # 27B parameters (base version)
+    batch_size: 512.0
+    context_size: 16384.0
+    pii_policy:
+      allow_by_default: false # Deny all PII by default
+      pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types
+    preferred_endpoints: ["endpoint1"]
+  "mistral-small3.1":
+    param_count: 22000000000
+    batch_size: 512.0
+    context_size: 16384.0
+    pii_policy:
+      allow_by_default: false # Deny all PII by default
+      pii_types_allowed: ["EMAIL_ADDRESS", "PERSON", "GPE", "PHONE_NUMBER"] # Only allow these specific PII types
+    preferred_endpoints: ["endpoint2", "endpoint3"]
+
+# Classifier configuration for text classification
+classifier:
+  category_model:
+    model_id: "models/category_classifier_modernbert-base_model" #TODO: Use local model for now before the code can download the entire model from huggingface
+    use_modernbert: true
+    threshold: 0.6
+    use_cpu: true
+    category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
+  pii_model:
+    model_id: "models/pii_classifier_modernbert-base_presidio_token_model" #TODO: Use local model for now before the code can download the entire model from huggingface
+    use_modernbert: true
+    threshold: 0.7
+    use_cpu: true
+    pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
+  load_aware: false
+categories:
+- name: business
+  use_reasoning: false
+  reasoning_description: "Business content is typically conversational"
+  model_scores:
+  - model: phi4
+    score: 0.8
+  - model: gemma3:27b
+    score: 0.4
+  - model: mistral-small3.1
+    score: 0.2
+- name: law
+  use_reasoning: false
+  reasoning_description: "Legal content is typically explanatory"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.8
+  - model: phi4
+    score: 0.6
+  - model: mistral-small3.1
+    score: 0.4
+- name: psychology
+  use_reasoning: false
+  reasoning_description: "Psychology content is usually explanatory"
+  model_scores:
+  - model: mistral-small3.1
+    score: 0.6
+  - model: gemma3:27b
+    score: 0.4
+  - model: phi4
+    score: 0.4
+- name: biology
+  use_reasoning: true
+  reasoning_description: "Biological processes benefit from structured analysis"
+  model_scores:
+  - model: mistral-small3.1
+    score: 0.8
+  - model: gemma3:27b
+    score: 0.6
+  - model: phi4
+    score: 0.2
+- name: chemistry
+  use_reasoning: true
+  reasoning_description: "Chemical reactions and formulas require systematic thinking"
+  model_scores:
+  - model: mistral-small3.1
+    score: 0.8
+  - model: gemma3:27b
+    score: 0.6
+  - model: phi4
+    score: 0.6
+- name: history
+  use_reasoning: false
+  reasoning_description: "Historical content is narrative-based"
+  model_scores:
+  - model: mistral-small3.1
+    score: 0.8
+  - model: phi4
+    score: 0.6
+  - model: gemma3:27b
+    score: 0.4
+- name: other
+  use_reasoning: false
+  reasoning_description: "General content doesn't require reasoning"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.8
+  - model: phi4
+    score: 0.6
+  - model: mistral-small3.1
+    score: 0.6
+- name: health
+  use_reasoning: false
+  reasoning_description: "Health information is typically informational"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.8
+  - model: phi4
+    score: 0.8
+  - model: mistral-small3.1
+    score: 0.6
+- name: economics
+  use_reasoning: false
+  reasoning_description: "Economic discussions are usually explanatory"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.8
+  - model: mistral-small3.1
+    score: 0.8
+  - model: phi4
+    score: 0.0
+- name: math
+  use_reasoning: true
+  reasoning_description: "Mathematical problems require step-by-step reasoning"
+  model_scores:
+  - model: phi4
+    score: 1.0
+  - model: mistral-small3.1
+    score: 0.8
+  - model: gemma3:27b
+    score: 0.6
+- name: physics
+  use_reasoning: true
+  reasoning_description: "Physics concepts need logical analysis"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.4
+  - model: phi4
+    score: 0.4
+  - model: mistral-small3.1
+    score: 0.4
+- name: computer science
+  use_reasoning: true
+  reasoning_description: "Programming and algorithms need logical reasoning"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.6
+  - model: mistral-small3.1
+    score: 0.6
+  - model: phi4
+    score: 0.0
+- name: philosophy
+  use_reasoning: false
+  reasoning_description: "Philosophical discussions are conversational"
+  model_scores:
+  - model: phi4
+    score: 0.6
+  - model: gemma3:27b
+    score: 0.2
+  - model: mistral-small3.1
+    score: 0.2
+- name: engineering
+  use_reasoning: true
+  reasoning_description: "Engineering problems require systematic problem-solving"
+  model_scores:
+  - model: gemma3:27b
+    score: 0.6
+  - model: mistral-small3.1
+    score: 0.6
+  - model: phi4
+    score: 0.2
 default_model: mistral-small3.1
\ No newline at end of file
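
The substantive change in this file is the pair of new per-category fields (the whole file is rewritten in the hunk, likely due to line-ending churn). A minimal category entry using the new fields looks like this (sketch; the name, description, and scores are illustrative):

    categories:
    - name: math
      use_reasoning: true
      reasoning_description: "Mathematical problems require step-by-step reasoning"
      model_scores:
      - model: phi4
        score: 1.0
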
selected_endpoint: "%REQ(X-SEMANTIC-DESTINATION-ENDPOINT)%" route_config: name: local_route virtual_hosts: @@ -42,7 +42,7 @@ static_resources: - match: prefix: "/" headers: - - name: "x-gateway-destination-endpoint" + - name: "x-semantic-destination-endpoint" string_match: exact: "endpoint1" route: @@ -51,7 +51,7 @@ static_resources: - match: prefix: "/" headers: - - name: "x-gateway-destination-endpoint" + - name: "x-semantic-destination-endpoint" string_match: exact: "endpoint2" route: @@ -60,7 +60,7 @@ static_resources: - match: prefix: "/" headers: - - name: "x-gateway-destination-endpoint" + - name: "x-semantic-destination-endpoint" string_match: exact: "endpoint3" route: diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index e22c7922..d89fc189 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -213,9 +213,11 @@ type ModelScore struct { } type Category struct { - Name string `yaml:"name"` - Description string `yaml:"description,omitempty"` - ModelScores []ModelScore `yaml:"model_scores"` + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + UseReasoning bool `yaml:"use_reasoning"` + ReasoningDescription string `yaml:"reasoning_description,omitempty"` + ModelScores []ModelScore `yaml:"model_scores"` } var ( diff --git a/src/semantic-router/pkg/extproc/endpoint_selection_test.go b/src/semantic-router/pkg/extproc/endpoint_selection_test.go index f3baa581..bb179679 100644 --- a/src/semantic-router/pkg/extproc/endpoint_selection_test.go +++ b/src/semantic-router/pkg/extproc/endpoint_selection_test.go @@ -75,7 +75,7 @@ var _ = Describe("Endpoint Selection", func() { var modelHeaderFound bool for _, header := range headerMutation.SetHeaders { - if header.Header.Key == "x-gateway-destination-endpoint" { + if header.Header.Key == "x-semantic-destination-endpoint" { endpointHeaderFound = true // Should be one of the configured endpoints Expect(header.Header.Value).To(BeElementOf("test-endpoint1", "test-endpoint2")) @@ -139,7 +139,7 @@ var _ = Describe("Endpoint Selection", func() { var selectedEndpoint string for _, header := range headerMutation.SetHeaders { - if header.Header.Key == "x-gateway-destination-endpoint" { + if header.Header.Key == "x-semantic-destination-endpoint" { endpointHeaderFound = true selectedEndpoint = header.Header.Value break @@ -198,7 +198,7 @@ var _ = Describe("Endpoint Selection", func() { var selectedEndpoint string for _, header := range headerMutation.SetHeaders { - if header.Header.Key == "x-gateway-destination-endpoint" { + if header.Header.Key == "x-semantic-destination-endpoint" { endpointHeaderFound = true selectedEndpoint = header.Header.Value break diff --git a/src/semantic-router/pkg/extproc/processor.go b/src/semantic-router/pkg/extproc/processor.go index 4fd76bb3..99d98e02 100644 --- a/src/semantic-router/pkg/extproc/processor.go +++ b/src/semantic-router/pkg/extproc/processor.go @@ -52,22 +52,30 @@ func (r *OpenAIRouter) Process(stream ext_proc.ExternalProcessor_ProcessServer) switch v := req.Request.(type) { case *ext_proc.ProcessingRequest_RequestHeaders: + log.Printf("DEBUG: Processing request headers") response, err := r.handleRequestHeaders(v, ctx) if err != nil { + log.Printf("ERROR: handleRequestHeaders failed: %v", err) return err } - if err := sendResponse(stream, response, "header"); err != nil { + if err := sendResponse(stream, response, "request header"); err != nil { + log.Printf("ERROR: sendResponse for headers 
failed: %v", err) return err } + log.Printf("DEBUG: Request headers processed successfully") case *ext_proc.ProcessingRequest_RequestBody: + log.Printf("DEBUG: Processing request body - THIS IS WHERE ROUTING HAPPENS") response, err := r.handleRequestBody(v, ctx) if err != nil { + log.Printf("ERROR: handleRequestBody failed: %v", err) return err } - if err := sendResponse(stream, response, "body"); err != nil { + if err := sendResponse(stream, response, "request body"); err != nil { + log.Printf("ERROR: sendResponse for body failed: %v", err) return err } + log.Printf("DEBUG: Request body processed successfully") case *ext_proc.ProcessingRequest_ResponseHeaders: response, err := r.handleResponseHeaders(v) @@ -105,5 +113,7 @@ func (r *OpenAIRouter) Process(stream ext_proc.ExternalProcessor_ProcessServer) return err } } + + log.Printf("DEBUG: Finished processing message, continuing to next...") } } diff --git a/src/semantic-router/pkg/extproc/reason_mode_config_test.go b/src/semantic-router/pkg/extproc/reason_mode_config_test.go new file mode 100644 index 00000000..0b9d572d --- /dev/null +++ b/src/semantic-router/pkg/extproc/reason_mode_config_test.go @@ -0,0 +1,309 @@ +package extproc + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/vllm-project/semantic-router/semantic-router/pkg/config" +) + +// TestReasoningModeConfiguration demonstrates how the reasoning mode works with the new config-based approach +func TestReasoningModeConfiguration(t *testing.T) { + fmt.Println("=== Configuration-Based Reasoning Mode Test ===") + + // Create a mock configuration for testing + cfg := &config.RouterConfig{ + Categories: []config.Category{ + { + Name: "math", + UseReasoning: true, + ReasoningDescription: "Mathematical problems require step-by-step reasoning", + }, + { + Name: "business", + UseReasoning: false, + ReasoningDescription: "Business content is typically conversational", + }, + { + Name: "biology", + UseReasoning: true, + ReasoningDescription: "Biological processes benefit from structured analysis", + }, + }, + } + + fmt.Printf("Loaded configuration with %d categories\n\n", len(cfg.Categories)) + + // Display reasoning configuration for each category + fmt.Println("--- Reasoning Mode Configuration ---") + for _, category := range cfg.Categories { + reasoningStatus := "DISABLED" + if category.UseReasoning { + reasoningStatus = "ENABLED" + } + + fmt.Printf("Category: %-15s | Reasoning: %-8s | %s\n", + category.Name, reasoningStatus, category.ReasoningDescription) + } + + // Test queries with expected categories + testQueries := []struct { + query string + category string + }{ + {"What is the derivative of x^2 + 3x + 1?", "math"}, + {"Implement a binary search algorithm in Python", "computer science"}, + {"Explain the process of photosynthesis", "biology"}, + {"Write a business plan for a coffee shop", "business"}, + {"Tell me about World War II", "history"}, + {"What are Newton's laws of motion?", "physics"}, + {"How does chemical bonding work?", "chemistry"}, + {"Design a bridge structure", "engineering"}, + } + + fmt.Printf("\n--- Test Query Reasoning Decisions ---\n") + for _, test := range testQueries { + // Find the category configuration + var useReasoning bool + var reasoningDesc string + var found bool + + for _, category := range cfg.Categories { + if strings.EqualFold(category.Name, test.category) { + useReasoning = category.UseReasoning + reasoningDesc = category.ReasoningDescription + found = true + break + } + } + + if !found { + fmt.Printf("Query: 
%s\n", test.query) + fmt.Printf(" Expected Category: %s (NOT FOUND IN CONFIG)\n", test.category) + fmt.Printf(" Reasoning: DISABLED (default)\n\n") + continue + } + + reasoningStatus := "DISABLED" + if useReasoning { + reasoningStatus = "ENABLED" + } + + fmt.Printf("Query: %s\n", test.query) + fmt.Printf(" Category: %s\n", test.category) + fmt.Printf(" Reasoning: %s - %s\n", reasoningStatus, reasoningDesc) + + // // Generate example request body + // messages := []map[string]string{ + // {"role": "system", "content": "You are an AI assistant"}, + // {"role": "user", "content": test.query}, + // } + + // requestBody := buildRequestBody("deepseek-v31", messages, useReasoning, true) + + // Show key differences in request + if useReasoning { + fmt.Printf(" Request includes: chat_template_kwargs: {thinking: true}\n") + } else { + fmt.Printf(" Request: Standard mode (no reasoning)\n") + } + fmt.Println() + } + + // Show example configuration section + fmt.Println("--- Example Config.yaml Section ---") + fmt.Print(` +categories: +- name: math + use_reasoning: true + reasoning_description: "Mathematical problems require step-by-step reasoning" + model_scores: + - model: phi4 + score: 1.0 + +- name: business + use_reasoning: false + reasoning_description: "Business content is typically conversational" + model_scores: + - model: phi4 + score: 0.8 +`) +} + +// GetReasoningConfigurationSummary returns a summary of the reasoning configuration +func GetReasoningConfigurationSummary(cfg *config.RouterConfig) map[string]interface{} { + summary := make(map[string]interface{}) + + reasoningEnabled := 0 + reasoningDisabled := 0 + + categoriesWithReasoning := []string{} + categoriesWithoutReasoning := []string{} + + for _, category := range cfg.Categories { + if category.UseReasoning { + reasoningEnabled++ + categoriesWithReasoning = append(categoriesWithReasoning, category.Name) + } else { + reasoningDisabled++ + categoriesWithoutReasoning = append(categoriesWithoutReasoning, category.Name) + } + } + + summary["total_categories"] = len(cfg.Categories) + summary["reasoning_enabled_count"] = reasoningEnabled + summary["reasoning_disabled_count"] = reasoningDisabled + summary["categories_with_reasoning"] = categoriesWithReasoning + summary["categories_without_reasoning"] = categoriesWithoutReasoning + + return summary +} + +// DemonstrateConfigurationUsage shows how to use the configuration-based reasoning +func DemonstrateConfigurationUsage() { + fmt.Println("=== Configuration Usage Example ===") + fmt.Println() + + fmt.Println("1. Configure reasoning in config.yaml:") + fmt.Print(` +categories: +- name: math + use_reasoning: true + reasoning_description: "Mathematical problems require step-by-step reasoning" + +- name: creative_writing + use_reasoning: false + reasoning_description: "Creative content flows better without structured reasoning" +`) + + fmt.Println("\n2. Use in Go code:") + fmt.Print(` +// The reasoning decision now comes from configuration +useReasoning := router.shouldUseReasoningMode(query) + +// Build request with appropriate reasoning mode +requestBody := buildRequestBody(model, messages, useReasoning, stream) +`) + + fmt.Println("\n3. 
Benefits of configuration-based approach:") + fmt.Println(" - Easy to modify reasoning settings without code changes") + fmt.Println(" - Consistent with existing category configuration") + fmt.Println(" - Supports different reasoning strategies per category") + fmt.Println(" - Can be updated at runtime by reloading configuration") + fmt.Println(" - Documentation is embedded in the config file") +} + +// TestAddReasoningModeToRequestBody tests the addReasoningModeToRequestBody function +func TestAddReasoningModeToRequestBody(t *testing.T) { + fmt.Println("=== Testing addReasoningModeToRequestBody Function ===") + + // Create a mock router with minimal config + router := &OpenAIRouter{} + + // Test case 1: Basic request body + originalRequest := map[string]interface{}{ + "model": "phi4", + "messages": []map[string]interface{}{ + {"role": "user", "content": "What is 2 + 2?"}, + }, + "stream": false, + } + + originalBody, err := json.Marshal(originalRequest) + if err != nil { + fmt.Printf("Error marshaling original request: %v\n", err) + return + } + + fmt.Printf("Original request body:\n%s\n\n", string(originalBody)) + + // Add reasoning mode + modifiedBody, err := router.addReasoningModeToRequestBody(originalBody) + if err != nil { + fmt.Printf("Error adding reasoning mode: %v\n", err) + return + } + + fmt.Printf("Modified request body with reasoning mode:\n%s\n\n", string(modifiedBody)) + + // Verify the modification + var modifiedRequest map[string]interface{} + if err := json.Unmarshal(modifiedBody, &modifiedRequest); err != nil { + fmt.Printf("Error unmarshaling modified request: %v\n", err) + return + } + + // Check if chat_template_kwargs was added + if chatTemplateKwargs, exists := modifiedRequest["chat_template_kwargs"]; exists { + if kwargs, ok := chatTemplateKwargs.(map[string]interface{}); ok { + if thinking, hasThinking := kwargs["thinking"]; hasThinking { + if thinkingBool, isBool := thinking.(bool); isBool && thinkingBool { + fmt.Println("✅ SUCCESS: chat_template_kwargs with thinking: true was correctly added") + } else { + fmt.Printf("❌ ERROR: thinking value is not true, got: %v\n", thinking) + } + } else { + fmt.Println("❌ ERROR: thinking field not found in chat_template_kwargs") + } + } else { + fmt.Printf("❌ ERROR: chat_template_kwargs is not a map, got: %T\n", chatTemplateKwargs) + } + } else { + fmt.Println("❌ ERROR: chat_template_kwargs not found in modified request") + } + + // Test case 2: Request with existing fields + fmt.Println("\n--- Test Case 2: Request with existing fields ---") + complexRequest := map[string]interface{}{ + "model": "phi4", + "messages": []map[string]interface{}{ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Solve x^2 + 5x + 6 = 0"}, + }, + "stream": true, + "temperature": 0.7, + "max_tokens": 1000, + } + + complexBody, err := json.Marshal(complexRequest) + if err != nil { + fmt.Printf("Error marshaling complex request: %v\n", err) + return + } + + modifiedComplexBody, err := router.addReasoningModeToRequestBody(complexBody) + if err != nil { + fmt.Printf("Error adding reasoning mode to complex request: %v\n", err) + return + } + + var modifiedComplexRequest map[string]interface{} + if err := json.Unmarshal(modifiedComplexBody, &modifiedComplexRequest); err != nil { + fmt.Printf("Error unmarshaling modified complex request: %v\n", err) + return + } + + // Verify all original fields are preserved + originalFields := []string{"model", "messages", "stream", "temperature", "max_tokens"} + 
allFieldsPreserved := true + for _, field := range originalFields { + if _, exists := modifiedComplexRequest[field]; !exists { + fmt.Printf("❌ ERROR: Original field '%s' was lost\n", field) + allFieldsPreserved = false + } + } + + if allFieldsPreserved { + fmt.Println("✅ SUCCESS: All original fields preserved") + } + + // Verify chat_template_kwargs was added + if _, exists := modifiedComplexRequest["chat_template_kwargs"]; exists { + fmt.Println("✅ SUCCESS: chat_template_kwargs added to complex request") + fmt.Printf("Final modified request:\n%s\n", string(modifiedComplexBody)) + } else { + fmt.Println("❌ ERROR: chat_template_kwargs not added to complex request") + } +} diff --git a/src/semantic-router/pkg/extproc/reason_mode_selector.go b/src/semantic-router/pkg/extproc/reason_mode_selector.go new file mode 100644 index 00000000..2d23e31e --- /dev/null +++ b/src/semantic-router/pkg/extproc/reason_mode_selector.go @@ -0,0 +1,161 @@ +package extproc + +import ( + "encoding/json" + "fmt" + "log" + "strings" +) + +// shouldUseReasoningMode determines if reasoning mode should be enabled based on the query category +func (r *OpenAIRouter) shouldUseReasoningMode(query string) bool { + // Get the category for this query using the existing classification system + categoryName := r.findCategoryForClassification(query) + + // If no category was determined (empty string), default to no reasoning + if categoryName == "" { + log.Printf("No category determined for query, defaulting to no reasoning mode") + return false + } + + // Normalize category name for consistent lookup + normalizedCategory := strings.ToLower(strings.TrimSpace(categoryName)) + + // Look up the category in the configuration + for _, category := range r.Config.Categories { + if strings.EqualFold(category.Name, normalizedCategory) { + reasoningStatus := "DISABLED" + if category.UseReasoning { + reasoningStatus = "ENABLED" + } + log.Printf("Reasoning mode decision: Category '%s' → %s", + categoryName, reasoningStatus) + return category.UseReasoning + } + } + + // If category not found in config, default to no reasoning + log.Printf("Category '%s' not found in configuration, defaulting to no reasoning mode", categoryName) + return false +} + +// buildRequestBody builds the request body for vLLM inference server with proper reasoning mode settings +func buildRequestBody(model string, messages []map[string]string, useReasoning bool, stream bool) map[string]interface{} { + requestBody := map[string]interface{}{ + "model": model, + "messages": messages, + "stream": stream, + } + + // Add chat template kwargs if reasoning is enabled + if useReasoning { + requestBody["chat_template_kwargs"] = getChatTemplateKwargs(true) + log.Printf("Added reasoning mode to request for model: %s", model) + } else { + log.Printf("Using standard mode (no reasoning) for model: %s", model) + } + + return requestBody +} + +// getChatTemplateKwargs returns the appropriate chat template kwargs based on reasoning mode and streaming +func getChatTemplateKwargs(useReasoning bool) map[string]interface{} { + if useReasoning { + return map[string]interface{}{ + "thinking": true, + } + } + return nil +} + +// addReasoningModeToRequestBody adds chat_template_kwargs to the JSON request body +func (r *OpenAIRouter) addReasoningModeToRequestBody(requestBody []byte) ([]byte, error) { + // Parse the JSON request body + var requestMap map[string]interface{} + if err := json.Unmarshal(requestBody, &requestMap); err != nil { + return nil, fmt.Errorf("failed to parse request body: 
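
The file above is a printed walkthrough rather than an assertion suite, so it is most useful run on its own with verbose output. Assuming the module layout in this patch, an invocation from the repository root would look roughly like:

    go test -v -run 'TestReasoningModeConfiguration|TestAddReasoningModeToRequestBody' ./src/semantic-router/pkg/extproc/
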
%w", err) + } + + // Add chat_template_kwargs for reasoning mode + requestMap["chat_template_kwargs"] = getChatTemplateKwargs(true) + + // Get the model name for logging + model := "unknown" + if modelValue, ok := requestMap["model"]; ok { + if modelStr, ok := modelValue.(string); ok { + model = modelStr + } + } + + log.Printf("Added reasoning mode (thinking: true) to request for model: %s", model) + + // Serialize back to JSON + modifiedBody, err := json.Marshal(requestMap) + if err != nil { + return nil, fmt.Errorf("failed to serialize modified request: %w", err) + } + + return modifiedBody, nil +} + +// logReasoningConfiguration logs the reasoning mode configuration for all categories during startup +func (r *OpenAIRouter) logReasoningConfiguration() { + if len(r.Config.Categories) == 0 { + log.Printf("No categories configured for reasoning mode") + return + } + + reasoningEnabled := []string{} + reasoningDisabled := []string{} + + for _, category := range r.Config.Categories { + if category.UseReasoning { + reasoningEnabled = append(reasoningEnabled, category.Name) + } else { + reasoningDisabled = append(reasoningDisabled, category.Name) + } + } + + log.Printf("Reasoning configuration - Total categories: %d", len(r.Config.Categories)) + + if len(reasoningEnabled) > 0 { + log.Printf("Reasoning ENABLED for categories (%d): %v", len(reasoningEnabled), reasoningEnabled) + } + + if len(reasoningDisabled) > 0 { + log.Printf("Reasoning DISABLED for categories (%d): %v", len(reasoningDisabled), reasoningDisabled) + } +} + +// ClassifyAndDetermineReasoningMode performs category classification and returns both the best model and reasoning mode setting +func (r *OpenAIRouter) ClassifyAndDetermineReasoningMode(query string) (string, bool) { + // Get the best model using existing logic + bestModel := r.classifyAndSelectBestModel(query) + + // Determine if reasoning mode should be used + useReasoning := r.shouldUseReasoningMode(query) + + reasoningStatus := "disabled" + if useReasoning { + reasoningStatus = "enabled" + } + log.Printf("Model selection complete: model=%s, reasoning=%s", bestModel, reasoningStatus) + + return bestModel, useReasoning +} + +// LogReasoningConfigurationSummary provides a compact summary of reasoning configuration +func (r *OpenAIRouter) LogReasoningConfigurationSummary() { + if len(r.Config.Categories) == 0 { + return + } + + enabledCount := 0 + for _, category := range r.Config.Categories { + if category.UseReasoning { + enabledCount++ + } + } + + log.Printf("Reasoning mode summary: %d/%d categories have reasoning enabled", enabledCount, len(r.Config.Categories)) +} diff --git a/src/semantic-router/pkg/extproc/reasoning_integration_test.go b/src/semantic-router/pkg/extproc/reasoning_integration_test.go new file mode 100644 index 00000000..36b7e768 --- /dev/null +++ b/src/semantic-router/pkg/extproc/reasoning_integration_test.go @@ -0,0 +1,205 @@ +package extproc + +import ( + "encoding/json" + "testing" + + "github.com/vllm-project/semantic-router/semantic-router/pkg/config" +) + +// TestReasoningModeIntegration tests the complete reasoning mode integration +func TestReasoningModeIntegration(t *testing.T) { + // Create a mock router with reasoning configuration + cfg := &config.RouterConfig{ + Categories: []config.Category{ + { + Name: "math", + UseReasoning: true, + ReasoningDescription: "Mathematical problems require step-by-step reasoning", + }, + { + Name: "business", + UseReasoning: false, + ReasoningDescription: "Business content is typically conversational", + }, 
diff --git a/src/semantic-router/pkg/extproc/reasoning_integration_test.go b/src/semantic-router/pkg/extproc/reasoning_integration_test.go
new file mode 100644
index 00000000..36b7e768
--- /dev/null
+++ b/src/semantic-router/pkg/extproc/reasoning_integration_test.go
@@ -0,0 +1,205 @@
+package extproc
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
+)
+
+// TestReasoningModeIntegration tests the complete reasoning mode integration
+func TestReasoningModeIntegration(t *testing.T) {
+	// Create a mock router with reasoning configuration
+	cfg := &config.RouterConfig{
+		Categories: []config.Category{
+			{
+				Name:                 "math",
+				UseReasoning:         true,
+				ReasoningDescription: "Mathematical problems require step-by-step reasoning",
+			},
+			{
+				Name:                 "business",
+				UseReasoning:         false,
+				ReasoningDescription: "Business content is typically conversational",
+			},
+		},
+	}
+
+	router := &OpenAIRouter{
+		Config: cfg,
+	}
+
+	// Test case 1: Math query should enable reasoning (when classifier works)
+	t.Run("Math query enables reasoning", func(t *testing.T) {
+		mathQuery := "What is the derivative of x^2 + 3x + 1?"
+
+		// Since we don't have the actual classifier, this will return false
+		// But we can test the configuration logic directly
+		useReasoning := router.shouldUseReasoningMode(mathQuery)
+
+		// Without a working classifier, this should be false
+		expectedReasoning := false
+
+		if useReasoning != expectedReasoning {
+			t.Errorf("Expected reasoning mode %v for math query without classifier, got %v", expectedReasoning, useReasoning)
+		}
+
+		// Test the configuration logic directly
+		mathCategory := cfg.Categories[0] // math category
+		if !mathCategory.UseReasoning {
+			t.Error("Math category should have UseReasoning set to true in configuration")
+		}
+	})
+
+	// Test case 2: Business query should not enable reasoning
+	t.Run("Business query disables reasoning", func(t *testing.T) {
+		businessQuery := "Write a business plan for a coffee shop"
+
+		useReasoning := router.shouldUseReasoningMode(businessQuery)
+
+		// Should be false because classifier returns empty (no category found)
+		if useReasoning != false {
+			t.Errorf("Expected reasoning mode false for business query, got %v", useReasoning)
+		}
+	})
+
+	// Test case 3: Test addReasoningModeToRequestBody function
+	t.Run("addReasoningModeToRequestBody adds correct fields", func(t *testing.T) {
+		originalRequest := map[string]interface{}{
+			"model": "phi4",
+			"messages": []map[string]interface{}{
+				{"role": "user", "content": "What is 2 + 2?"},
+			},
+			"stream": false,
+		}
+
+		originalBody, err := json.Marshal(originalRequest)
+		if err != nil {
+			t.Fatalf("Failed to marshal original request: %v", err)
+		}
+
+		modifiedBody, err := router.addReasoningModeToRequestBody(originalBody)
+		if err != nil {
+			t.Fatalf("Failed to add reasoning mode: %v", err)
+		}
+
+		var modifiedRequest map[string]interface{}
+		if err := json.Unmarshal(modifiedBody, &modifiedRequest); err != nil {
+			t.Fatalf("Failed to unmarshal modified request: %v", err)
+		}
+
+		// Check if chat_template_kwargs was added
+		chatTemplateKwargs, exists := modifiedRequest["chat_template_kwargs"]
+		if !exists {
+			t.Error("chat_template_kwargs not found in modified request")
+		}
+
+		// Check if thinking: true was set
+		if kwargs, ok := chatTemplateKwargs.(map[string]interface{}); ok {
+			if thinking, hasThinking := kwargs["thinking"]; hasThinking {
+				if thinkingBool, isBool := thinking.(bool); !isBool || !thinkingBool {
+					t.Errorf("Expected thinking: true, got %v", thinking)
+				}
+			} else {
+				t.Error("thinking field not found in chat_template_kwargs")
+			}
+		} else {
+			t.Errorf("chat_template_kwargs is not a map, got %T", chatTemplateKwargs)
+		}
+
+		// Verify original fields are preserved
+		originalFields := []string{"model", "messages", "stream"}
+		for _, field := range originalFields {
+			if _, exists := modifiedRequest[field]; !exists {
+				t.Errorf("Original field '%s' was lost", field)
+			}
+		}
+	})
+
+	// Test case 4: Test getChatTemplateKwargs function
+	t.Run("getChatTemplateKwargs returns correct values", func(t *testing.T) {
+		// Test with reasoning enabled
+		kwargs := getChatTemplateKwargs(true)
+		if kwargs == nil {
+			t.Error("Expected non-nil kwargs for reasoning enabled")
+		}
+
+		if thinking, ok := kwargs["thinking"]; !ok || thinking != true {
+			t.Errorf("Expected thinking: true, got %v", thinking)
+		}
+
+		// Test with reasoning disabled
+		kwargs = getChatTemplateKwargs(false)
+		if kwargs != nil {
+			t.Errorf("Expected nil kwargs for reasoning disabled, got %v", kwargs)
+		}
+	})
+
+	// Test case 5: Test empty query handling
+	t.Run("Empty query defaults to no reasoning", func(t *testing.T) {
+		useReasoning := router.shouldUseReasoningMode("")
+		if useReasoning != false {
+			t.Errorf("Expected reasoning mode false for empty query, got %v", useReasoning)
+		}
+	})
+
+	// Test case 6: Test unknown category handling
+	t.Run("Unknown category defaults to no reasoning", func(t *testing.T) {
+		unknownQuery := "This is some unknown category query"
+		useReasoning := router.shouldUseReasoningMode(unknownQuery)
+		if useReasoning != false {
+			t.Errorf("Expected reasoning mode false for unknown category, got %v", useReasoning)
+		}
+	})
+}
+
+// TestReasoningModeConfigurationValidation tests the configuration validation
+func TestReasoningModeConfigurationValidation(t *testing.T) {
+	testCases := []struct {
+		name     string
+		category config.Category
+		expected bool
+	}{
+		{
+			name: "Math category with reasoning enabled",
+			category: config.Category{
+				Name:                 "math",
+				UseReasoning:         true,
+				ReasoningDescription: "Mathematical problems require step-by-step reasoning",
+			},
+			expected: true,
+		},
+		{
+			name: "Business category with reasoning disabled",
+			category: config.Category{
+				Name:                 "business",
+				UseReasoning:         false,
+				ReasoningDescription: "Business content is typically conversational",
+			},
+			expected: false,
+		},
+		{
+			name: "Science category with reasoning enabled",
+			category: config.Category{
+				Name:                 "science",
+				UseReasoning:         true,
+				ReasoningDescription: "Scientific concepts benefit from structured analysis",
+			},
+			expected: true,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			if tc.category.UseReasoning != tc.expected {
+				t.Errorf("Expected UseReasoning %v for %s, got %v",
+					tc.expected, tc.category.Name, tc.category.UseReasoning)
+			}
+
+			// Verify description is not empty
+			if tc.category.ReasoningDescription == "" {
+				t.Errorf("ReasoningDescription should not be empty for category %s", tc.category.Name)
+			}
+		})
+	}
+}
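
Taken together with the selector, the integration surface is small. A caller that wants both the routed model and the reasoning decision in one step would use it roughly like this (sketch; assumes a fully constructed *OpenAIRouter and elides error handling):

    // One call yields the routed model plus the per-category reasoning flag
    model, useReasoning := router.ClassifyAndDetermineReasoningMode(query)

    // Build the upstream body; chat_template_kwargs is attached only when
    // useReasoning is true (see buildRequestBody in reason_mode_selector.go)
    messages := []map[string]string{{"role": "user", "content": query}}
    body := buildRequestBody(model, messages, useReasoning, false)

The request_handler changes below wire the same decision into the ext_proc body phase instead, mutating the serialized request just before it is forwarded upstream.
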
diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go
index b0476f0a..398f4741 100644
--- a/src/semantic-router/pkg/extproc/request_handler.go
+++ b/src/semantic-router/pkg/extproc/request_handler.go
@@ -117,20 +117,28 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
 
 	// Store headers for later use
 	headers := v.RequestHeaders.Headers
+	log.Printf("Processing %d request headers", len(headers.Headers))
 	for _, h := range headers.Headers {
-		ctx.Headers[h.Key] = h.Value
+		// Use RawValue instead of Value for header values
+		headerValue := string(h.RawValue)
+
+		ctx.Headers[h.Key] = headerValue
 
 		// Store request ID if present
 		if strings.ToLower(h.Key) == "x-request-id" {
-			ctx.RequestID = h.Value
+			ctx.RequestID = headerValue
 		}
 	}
 
+	// Headers will be set in body phase to avoid conflicts (body phase replaces header phase)
+	log.Printf("DEBUG: [Header Phase] Skipping header setting - will be handled in body phase to avoid conflicts")
+
 	// Allow the request to continue
 	response := &ext_proc.ProcessingResponse{
 		Response: &ext_proc.ProcessingResponse_RequestHeaders{
 			RequestHeaders: &ext_proc.HeadersResponse{
 				Response: &ext_proc.CommonResponse{
 					Status: ext_proc.CommonResponse_CONTINUE,
+					// No HeaderMutation - will be handled in body phase
 				},
 			},
 		},
@@ -329,6 +337,10 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
 
 	log.Printf("Routing to model: %s", matchedModel)
 
+	// Check reasoning mode for this category
+	useReasoning := r.shouldUseReasoningMode(userContent)
+	log.Printf("Reasoning mode decision for this query: %v", useReasoning)
+
 	// Track the model load for the selected model
 	r.Classifier.IncrementModelLoad(matchedModel)
@@ -357,6 +369,15 @@
 		return nil, status.Errorf(codes.Internal, "error serializing modified request: %v", err)
 	}
 
+	// Add reasoning mode to the request body if needed
+	if useReasoning {
+		modifiedBody, err = r.addReasoningModeToRequestBody(modifiedBody)
+		if err != nil {
+			log.Printf("Error adding reasoning mode to request: %v", err)
+			return nil, status.Errorf(codes.Internal, "error adding reasoning mode: %v", err)
+		}
+	}
+
 	// Create body mutation with the modified body
 	bodyMutation := &ext_proc.BodyMutation{
 		Mutation: &ext_proc.BodyMutation_Body{
@@ -364,36 +385,38 @@
 		},
 	}
 
-	// Create header mutation to remove content-length and add endpoint selection
-	headerMutation := &ext_proc.HeaderMutation{
-		RemoveHeaders: []string{"content-length"},
-	}
-
-	// Add endpoint and model selection headers
+	// Create header mutation with content-length removal AND all necessary routing headers
+	// (body phase HeaderMutation replaces header phase completely)
+	log.Printf("DEBUG: Creating headers - selectedEndpoint='%s', actualModel='%s'", selectedEndpoint, actualModel)
+	setHeaders := []*core.HeaderValueOption{}
 	if selectedEndpoint != "" {
-		if headerMutation.SetHeaders == nil {
-			headerMutation.SetHeaders = make([]*core.HeaderValueOption, 0)
-		}
-		headerMutation.SetHeaders = append(headerMutation.SetHeaders, &core.HeaderValueOption{
+		setHeaders = append(setHeaders, &core.HeaderValueOption{
 			Header: &core.HeaderValue{
-				Key:   "x-gateway-destination-endpoint",
+				Key:   "x-semantic-destination-endpoint",
 				Value: selectedEndpoint,
 			},
 		})
 	}
+	if actualModel != "" {
+		setHeaders = append(setHeaders, &core.HeaderValueOption{
+			Header: &core.HeaderValue{
+				Key:   "x-selected-model",
+				Value: actualModel,
+			},
+		})
+	}
+	log.Printf("DEBUG: Created headers array with %d headers", len(setHeaders))
 
-	// Always add the selected model header for logging/debugging
-	if headerMutation.SetHeaders == nil {
-		headerMutation.SetHeaders = make([]*core.HeaderValueOption, 0)
+	headerMutation := &ext_proc.HeaderMutation{
+		RemoveHeaders: []string{"content-length"},
+		SetHeaders:    setHeaders,
 	}
-	headerMutation.SetHeaders = append(headerMutation.SetHeaders, &core.HeaderValueOption{
-		Header: &core.HeaderValue{
-			Key:   "x-selected-model",
-			Value: actualModel,
-		},
-	})
 
-	// Set the response with both mutations
+	log.Printf("DEBUG: Body phase - removing content-length AND setting headers (Authorization left untouched):")
+	log.Printf("DEBUG: selectedEndpoint = '%s'", selectedEndpoint)
+	log.Printf("DEBUG: actualModel = '%s'", actualModel)
+
+	// Set the response with body mutation and content-length removal
 	response = &ext_proc.ProcessingResponse{
 		Response: &ext_proc.ProcessingResponse_RequestBody{
 			RequestBody: &ext_proc.BodyResponse{
@@ -437,41 +460,8 @@
 	// Save the actual model that will be used for token tracking
 	ctx.RequestModel = actualModel
 
-	// If we have an endpoint selected, we need to set headers
-	if selectedEndpoint != "" {
-		var headerMutation *ext_proc.HeaderMutation
-
-		// Get existing header mutation or create a new one
-		if response.GetRequestBody().GetResponse().GetHeaderMutation() != nil {
-			headerMutation = response.GetRequestBody().GetResponse().GetHeaderMutation()
-		} else {
-			headerMutation = &ext_proc.HeaderMutation{}
-		}
-
-		// Initialize SetHeaders if nil
-		if headerMutation.SetHeaders == nil {
-			headerMutation.SetHeaders = make([]*core.HeaderValueOption, 0)
-		}
-
-		// Add endpoint selection header
-		headerMutation.SetHeaders = append(headerMutation.SetHeaders, &core.HeaderValueOption{
-			Header: &core.HeaderValue{
-				Key:   "x-gateway-destination-endpoint",
-				Value: selectedEndpoint,
-			},
-		})
-
-		// Add selected model header for logging/debugging
-		headerMutation.SetHeaders = append(headerMutation.SetHeaders, &core.HeaderValueOption{
-			Header: &core.HeaderValue{
-				Key:   "x-selected-model",
-				Value: actualModel,
-			},
-		})
-
-		// Update the response with header mutation
-		response.GetRequestBody().GetResponse().HeaderMutation = headerMutation
-	}
+	// Endpoint selection headers already set in the body-phase header mutation above - no additional logic needed
+	log.Printf("DEBUG: Endpoint selection complete - all headers set in the body-phase mutation")
 
 	// Handle tool selection based on tool_choice field
 	if err := r.handleToolSelection(openAIRequest, userContent, nonUserMessages, &response, ctx); err != nil {
@@ -525,7 +515,7 @@
 		if r.Config.Tools.FallbackToEmpty {
 			log.Printf("Tool selection failed, falling back to no tools: %v", err)
 			openAIRequest.Tools = nil
-			return r.updateRequestWithTools(openAIRequest, response)
+			return r.updateRequestWithTools(openAIRequest, response, ctx)
 		}
 		return err
 	}
@@ -558,11 +548,11 @@
 		log.Printf("Auto-selected %d tools for query: %s", len(selectedTools), classificationText)
 	}
 
-	return r.updateRequestWithTools(openAIRequest, response)
+	return r.updateRequestWithTools(openAIRequest, response, ctx)
 }
 
 // updateRequestWithTools updates the request body with the selected tools
-func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompletionNewParams, response **ext_proc.ProcessingResponse) error {
+func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompletionNewParams, response **ext_proc.ProcessingResponse, ctx *RequestContext) error {
 	// Re-serialize the request with modified tools
 	modifiedBody, err := serializeOpenAIRequest(openAIRequest)
 	if err != nil {
@@ -576,33 +566,54 @@
 		},
 	}
 
-	// Create or get existing header mutation
-	var headerMutation *ext_proc.HeaderMutation
-	if (*response).GetRequestBody().GetResponse().GetHeaderMutation() != nil {
-		headerMutation = (*response).GetRequestBody().GetResponse().GetHeaderMutation()
-	} else {
-		headerMutation = &ext_proc.HeaderMutation{}
+	// Create header mutation with content-length removal AND all necessary routing headers
+	// (body phase HeaderMutation replaces header phase completely)
+
+	// Get the headers that should have been set in the main routing
+	var selectedEndpoint, actualModel string
+
+	// These should be available from the existing response
+	if (*response).GetRequestBody() != nil && (*response).GetRequestBody().GetResponse() != nil &&
+		(*response).GetRequestBody().GetResponse().GetHeaderMutation() != nil &&
+		(*response).GetRequestBody().GetResponse().GetHeaderMutation().GetSetHeaders() != nil {
+		for _, header := range (*response).GetRequestBody().GetResponse().GetHeaderMutation().GetSetHeaders() {
+			switch header.Header.Key {
+			case "x-semantic-destination-endpoint":
+				selectedEndpoint = header.Header.Value
+			case "x-selected-model":
+				actualModel = header.Header.Value
+			}
+		}
 	}
 
-	// Add content-length removal if not already present
-	if headerMutation.RemoveHeaders == nil {
-		headerMutation.RemoveHeaders = make([]string, 0)
+	setHeaders := []*core.HeaderValueOption{}
+	if selectedEndpoint != "" {
+		setHeaders = append(setHeaders, &core.HeaderValueOption{
+			Header: &core.HeaderValue{
+				Key:   "x-semantic-destination-endpoint",
+				Value: selectedEndpoint,
+			},
+		})
 	}
-
-	// Check if content-length is already in the remove list
-	hasContentLength := false
-	for _, header := range headerMutation.RemoveHeaders {
-		if strings.ToLower(header) == "content-length" {
-			hasContentLength = true
-			break
-		}
+	if actualModel != "" {
+		setHeaders = append(setHeaders, &core.HeaderValueOption{
+			Header: &core.HeaderValue{
+				Key:   "x-selected-model",
+				Value: actualModel,
+			},
+		})
 	}
 
-	if !hasContentLength {
-		headerMutation.RemoveHeaders = append(headerMutation.RemoveHeaders, "content-length")
+	// Intentionally do not mutate Authorization header here
+
+	headerMutation := &ext_proc.HeaderMutation{
+		RemoveHeaders: []string{"content-length"},
+		SetHeaders:    setHeaders,
 	}
 
-	// Update the response with both mutations
+	log.Printf("DEBUG: Tool selection - removing content-length AND preserving routing headers (body phase replaces header phase)")
+
+	// Update the response with body mutation and content-length removal
 	*response = &ext_proc.ProcessingResponse{
 		Response: &ext_proc.ProcessingResponse_RequestBody{
 			RequestBody: &ext_proc.BodyResponse{
diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go
index b49bbfc0..eb358999 100644
--- a/src/semantic-router/pkg/extproc/router.go
+++ b/src/semantic-router/pkg/extproc/router.go
@@ -150,6 +150,9 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
 		pendingRequests: make(map[string][]byte),
 	}
 
+	// Log reasoning configuration after router is created
+	router.logReasoningConfiguration()
+
 	return router, nil
 }
diff --git a/src/semantic-router/pkg/extproc/utils.go b/src/semantic-router/pkg/extproc/utils.go
index 7727b9c8..6b4bfc62 100644
--- a/src/semantic-router/pkg/extproc/utils.go
+++ b/src/semantic-router/pkg/extproc/utils.go
@@ -8,6 +8,8 @@ import (
 
 // sendResponse sends a response with proper error handling and logging
 func sendResponse(stream ext_proc.ExternalProcessor_ProcessServer, response *ext_proc.ProcessingResponse, msgType string) error {
+	log.Printf("Sending [%s] response: %+v", msgType, response)
+
 	if err := stream.Send(response); err != nil {
 		log.Printf("Error sending %s response: %v", msgType, err)