From 8ef9b483d36dddfd008ecf8c791fde3990c56a6f Mon Sep 17 00:00:00 2001 From: amirejaz Date: Fri, 19 Sep 2025 11:36:46 +0100 Subject: [PATCH 1/8] fix transport for auditor middleware --- pkg/audit/auditor.go | 36 +++++++++-------- pkg/audit/auditor_test.go | 78 +++++++++++++++++++----------------- pkg/audit/config.go | 13 +++--- pkg/audit/config_test.go | 4 +- pkg/audit/middleware.go | 5 ++- pkg/runner/config_builder.go | 9 +++-- pkg/runner/middleware.go | 7 ++-- 7 files changed, 83 insertions(+), 69 deletions(-) diff --git a/pkg/audit/auditor.go b/pkg/audit/auditor.go index 1de1ef135..35012acc5 100644 --- a/pkg/audit/auditor.go +++ b/pkg/audit/auditor.go @@ -15,6 +15,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/transport/types" ) // LevelAudit is a custom audit log level - between Info and Warn @@ -35,12 +36,13 @@ func NewAuditLogger(w io.Writer) *slog.Logger { // Auditor handles audit logging for HTTP requests. type Auditor struct { - config *Config - auditLogger *slog.Logger + config *Config + auditLogger *slog.Logger + transportType string // e.g., "sse", "streamable-http" } -// NewAuditor creates a new Auditor with the given configuration. -func NewAuditor(config *Config) (*Auditor, error) { +// NewAuditorWithTransport creates a new Auditor with the given configuration and transport information. +func NewAuditorWithTransport(config *Config, transportType string) (*Auditor, error) { var logWriter io.Writer = os.Stdout // default to stdout if config != nil { @@ -54,11 +56,17 @@ func NewAuditor(config *Config) (*Auditor, error) { } return &Auditor{ - config: config, - auditLogger: NewAuditLogger(logWriter), + config: config, + auditLogger: NewAuditLogger(logWriter), + transportType: transportType, }, nil } +// isSSETransport checks if the current transport is SSE +func (a *Auditor) isSSETransport() bool { + return a.transportType == types.TransportTypeSSE.String() +} + // responseWriter wraps http.ResponseWriter to capture response data and status. type responseWriter struct { http.ResponseWriter @@ -88,7 +96,7 @@ func (a *Auditor) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Handle SSE endpoints specially - log the connection event immediately // since SSE connections are long-lived and don't follow normal request/response pattern - if r.URL.Path == "/sse" { + if a.isSSETransport() { // Log SSE connection event immediately a.logSSEConnectionEvent(r) @@ -164,7 +172,7 @@ func (a *Auditor) logAuditEvent(r *http.Request, rw *responseWriter, requestData } // Add metadata - a.addMetadata(event, r, duration, rw) + a.addMetadata(event, duration, rw) // Add request/response data if configured a.addEventData(event, r, rw, requestData) @@ -184,7 +192,7 @@ func (a *Auditor) determineEventType(r *http.Request) string { path := r.URL.Path // Handle SSE connection establishment - if strings.Contains(path, "/sse") { + if a.isSSETransport() { return EventTypeMCPInitialize } @@ -372,7 +380,7 @@ func (*Auditor) extractTarget(r *http.Request, eventType string) map[string]stri } // addMetadata adds metadata to the audit event. -func (*Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Duration, rw *responseWriter) { +func (a *Auditor) addMetadata(event *AuditEvent, duration time.Duration, rw *responseWriter) { if event.Metadata.Extra == nil { event.Metadata.Extra = make(map[string]any) } @@ -381,11 +389,7 @@ func (*Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Du event.Metadata.Extra[MetadataExtraKeyDuration] = duration.Milliseconds() // Add transport information - if strings.Contains(r.URL.Path, "/sse") { - event.Metadata.Extra[MetadataExtraKeyTransport] = "sse" - } else { - event.Metadata.Extra[MetadataExtraKeyTransport] = "http" - } + event.Metadata.Extra[MetadataExtraKeyTransport] = a.transportType // Add response size if available if rw.body != nil { @@ -454,7 +458,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) { // Add metadata event.Metadata.Extra = map[string]any{ - "transport": "sse", + "transport": a.transportType, "user_agent": r.Header.Get("User-Agent"), } diff --git a/pkg/audit/auditor_test.go b/pkg/audit/auditor_test.go index 610add0ad..e14ca2efe 100644 --- a/pkg/audit/auditor_test.go +++ b/pkg/audit/auditor_test.go @@ -26,7 +26,7 @@ func init() { func TestNewAuditor(t *testing.T) { t.Parallel() config := &Config{} - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") assert.NoError(t, err) assert.NotNil(t, auditor) @@ -36,7 +36,7 @@ func TestNewAuditor(t *testing.T) { func TestAuditorMiddlewareDisabled(t *testing.T) { t.Parallel() config := &Config{} - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -61,7 +61,7 @@ func TestAuditorMiddlewareWithRequestData(t *testing.T) { IncludeRequestData: true, MaxDataSize: 1024, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -91,7 +91,7 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) { IncludeResponseData: true, MaxDataSize: 1024, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) responseData := `{"result": "success"}` @@ -114,38 +114,43 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) { func TestDetermineEventType(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) - require.NoError(t, err) tests := []struct { - name string - path string - method string - expected string + name string + path string + method string + transport string + expected string }{ { - name: "SSE endpoint", - path: "/sse", - method: "GET", - expected: EventTypeMCPInitialize, + name: "SSE endpoint", + path: "/sse", + method: "GET", + transport: "sse", + expected: EventTypeMCPInitialize, }, { - name: "MCP messages endpoint", - path: "/messages", - method: "POST", - expected: "mcp_request", // Since extractMCPMethod returns empty + name: "MCP messages endpoint", + path: "/messages", + method: "POST", + transport: "streamable-http", + expected: "mcp_request", // Since extractMCPMethod returns empty }, { - name: "Regular HTTP request", - path: "/api/health", - method: "GET", - expected: "http_request", + name: "Regular HTTP request", + path: "/api/health", + method: "GET", + transport: "streamable-http", + expected: "http_request", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + auditor, err := NewAuditorWithTransport(&Config{}, tt.transport) + require.NoError(t, err) + req := httptest.NewRequest(tt.method, tt.path, nil) result := auditor.determineEventType(req) assert.Equal(t, tt.expected, result) @@ -174,7 +179,7 @@ func TestMapMCPMethodToEventType(t *testing.T) { {"unknown_method", "mcp_request"}, } - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) for _, tt := range tests { t.Run(tt.mcpMethod, func(t *testing.T) { @@ -187,7 +192,7 @@ func TestMapMCPMethodToEventType(t *testing.T) { func TestDetermineOutcome(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) tests := []struct { @@ -218,7 +223,7 @@ func TestDetermineOutcome(t *testing.T) { func TestGetClientIP(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) tests := []struct { @@ -268,7 +273,7 @@ func TestGetClientIP(t *testing.T) { func TestExtractSubjects(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) t.Run("with JWT claims", func(t *testing.T) { @@ -342,7 +347,7 @@ func TestDetermineComponent(t *testing.T) { t.Run("with configured component", func(t *testing.T) { t.Parallel() config := &Config{Component: "custom-component"} - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) req := httptest.NewRequest("GET", "/test", nil) @@ -354,7 +359,7 @@ func TestDetermineComponent(t *testing.T) { t.Run("without configured component", func(t *testing.T) { t.Parallel() config := &Config{} - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) req := httptest.NewRequest("GET", "/test", nil) @@ -366,7 +371,7 @@ func TestDetermineComponent(t *testing.T) { func TestExtractTarget(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) tests := []struct { @@ -423,18 +428,17 @@ func TestExtractTarget(t *testing.T) { func TestAddMetadata(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") - req := httptest.NewRequest("GET", "/sse/test", nil) duration := 150 * time.Millisecond rw := &responseWriter{ ResponseWriter: httptest.NewRecorder(), body: bytes.NewBufferString("test response"), } - auditor.addMetadata(event, req, duration, rw) + auditor.addMetadata(event, duration, rw) require.NotNil(t, event.Metadata.Extra) assert.Equal(t, int64(150), event.Metadata.Extra[MetadataExtraKeyDuration]) @@ -450,7 +454,7 @@ func TestAddEventData(t *testing.T) { IncludeRequestData: true, IncludeResponseData: true, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -483,7 +487,7 @@ func TestAddEventData(t *testing.T) { IncludeRequestData: true, IncludeResponseData: true, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -511,7 +515,7 @@ func TestAddEventData(t *testing.T) { IncludeRequestData: false, IncludeResponseData: false, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -531,7 +535,7 @@ func TestResponseWriterCapture(t *testing.T) { IncludeResponseData: true, MaxDataSize: 10, // Small limit for testing } - auditor, err := NewAuditor(config) + auditor, err := NewAuditorWithTransport(config, "sse") require.NoError(t, err) rw := &responseWriter{ @@ -568,7 +572,7 @@ func TestResponseWriterStatusCode(t *testing.T) { func TestExtractSourceWithHeaders(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditorWithTransport(&Config{}, "sse") require.NoError(t, err) req := httptest.NewRequest("GET", "/test", nil) diff --git a/pkg/audit/config.go b/pkg/audit/config.go index 712358a50..5c7f0490f 100644 --- a/pkg/audit/config.go +++ b/pkg/audit/config.go @@ -104,9 +104,9 @@ func (c *Config) ShouldAuditEvent(eventType string) bool { return true } -// CreateMiddleware creates an HTTP middleware from the audit configuration. -func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { - auditor, err := NewAuditor(c) +// CreateMiddlewareWithTransport creates an HTTP middleware from the audit configuration with transport information. +func (c *Config) CreateMiddlewareWithTransport(transportType string) (types.MiddlewareFunction, error) { + auditor, err := NewAuditorWithTransport(c, transportType) if err != nil { return nil, fmt.Errorf("failed to create auditor: %w", err) } @@ -114,15 +114,16 @@ func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { } // GetMiddlewareFromFile loads the audit configuration from a file and creates an HTTP middleware. -func GetMiddlewareFromFile(path string) (func(http.Handler) http.Handler, error) { +// Note: This function requires a transport type to be provided separately. +func GetMiddlewareFromFile(path string, transportType string) (func(http.Handler) http.Handler, error) { // Load the configuration config, err := LoadFromFile(path) if err != nil { return nil, fmt.Errorf("failed to load audit config: %w", err) } - // Create the middleware - return config.CreateMiddleware() + // Create the middleware with transport information + return config.CreateMiddlewareWithTransport(transportType) } // Validate validates the audit configuration. diff --git a/pkg/audit/config_test.go b/pkg/audit/config_test.go index 5f6fb2a7b..adcaed084 100644 --- a/pkg/audit/config_test.go +++ b/pkg/audit/config_test.go @@ -112,7 +112,7 @@ func TestCreateMiddleware(t *testing.T) { t.Parallel() config := &Config{} - middleware, err := config.CreateMiddleware() + middleware, err := config.CreateMiddlewareWithTransport("sse") assert.NoError(t, err) assert.NotNil(t, middleware) } @@ -236,7 +236,7 @@ func TestConfigMinimalJSON(t *testing.T) { func TestGetMiddlewareFromFileError(t *testing.T) { t.Parallel() // Test with non-existent file - _, err := GetMiddlewareFromFile("/non/existent/file.json") + _, err := GetMiddlewareFromFile("/non/existent/file.json", "sse") assert.Error(t, err) assert.Contains(t, err.Error(), "failed to load audit config") } diff --git a/pkg/audit/middleware.go b/pkg/audit/middleware.go index 72fc81f04..3510a20a4 100644 --- a/pkg/audit/middleware.go +++ b/pkg/audit/middleware.go @@ -17,6 +17,8 @@ type MiddlewareParams struct { ConfigPath string `json:"config_path,omitempty"` // Kept for backwards compatibility ConfigData *Config `json:"config_data,omitempty"` // New field for config contents Component string `json:"component,omitempty"` + // Transport information for dynamic transport detection + TransportType string `json:"transport_type,omitempty"` // e.g., "sse", "streamable-http" } // Middleware wraps audit middleware functionality @@ -65,7 +67,8 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun auditConfig.Component = params.Component } - middleware, err := auditConfig.CreateMiddleware() + // Always use the transport-aware constructor + middleware, err := auditConfig.CreateMiddlewareWithTransport(params.TransportType) if err != nil { return fmt.Errorf("failed to create audit middleware: %w", err) } diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index cf0f1142a..8bc748f44 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -453,7 +453,7 @@ func WithMiddlewareFromFlags( // Add optional middlewares middlewareConfigs = addTelemetryMiddleware(middlewareConfigs, telemetryConfig, serverName, transportType) middlewareConfigs = addAuthzMiddleware(middlewareConfigs, authzConfigPath) - middlewareConfigs = addAuditMiddleware(middlewareConfigs, enableAudit, auditConfigPath, serverName) + middlewareConfigs = addAuditMiddleware(middlewareConfigs, enableAudit, auditConfigPath, serverName, transportType) // Set the populated middleware configs b.config.MiddlewareConfigs = middlewareConfigs @@ -582,15 +582,16 @@ func addAuthzMiddleware( func addAuditMiddleware( middlewareConfigs []types.MiddlewareConfig, enableAudit bool, - auditConfigPath, serverName string, + auditConfigPath, serverName, transportType string, ) []types.MiddlewareConfig { if !enableAudit && auditConfigPath == "" { return middlewareConfigs } auditParams := audit.MiddlewareParams{ - ConfigPath: auditConfigPath, // Keep for backwards compatibility - Component: serverName, // Use server name as component + ConfigPath: auditConfigPath, // Keep for backwards compatibility + Component: serverName, // Use server name as component + TransportType: transportType, // Pass the actual transport type } // Read audit config contents if path is provided diff --git a/pkg/runner/middleware.go b/pkg/runner/middleware.go index 409725994..4b1c5f053 100644 --- a/pkg/runner/middleware.go +++ b/pkg/runner/middleware.go @@ -107,9 +107,10 @@ func PopulateMiddlewareConfigs(config *RunConfig) error { // Audit middleware (if enabled) if config.AuditConfig != nil { auditParams := audit.MiddlewareParams{ - ConfigPath: config.AuditConfigPath, // Keep for backwards compatibility - ConfigData: config.AuditConfig, // Use the loaded config data - Component: config.AuditConfig.Component, + ConfigPath: config.AuditConfigPath, // Keep for backwards compatibility + ConfigData: config.AuditConfig, // Use the loaded config data + Component: config.AuditConfig.Component, + TransportType: config.Transport.String(), // Pass the actual transport type } auditConfig, err := types.NewMiddlewareConfig(audit.MiddlewareType, auditParams) if err != nil { From 31946bf2f156db9b859badc8d34f71232d976a08 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Fri, 19 Sep 2025 12:17:38 +0100 Subject: [PATCH 2/8] avoid using all the transport implicitly --- pkg/audit/auditor.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/audit/auditor.go b/pkg/audit/auditor.go index 35012acc5..1029774ca 100644 --- a/pkg/audit/auditor.go +++ b/pkg/audit/auditor.go @@ -389,7 +389,11 @@ func (a *Auditor) addMetadata(event *AuditEvent, duration time.Duration, rw *res event.Metadata.Extra[MetadataExtraKeyDuration] = duration.Milliseconds() // Add transport information - event.Metadata.Extra[MetadataExtraKeyTransport] = a.transportType + if a.isSSETransport() { + event.Metadata.Extra[MetadataExtraKeyTransport] = "sse" + } else { + event.Metadata.Extra[MetadataExtraKeyTransport] = "http" + } // Add response size if available if rw.body != nil { From d79ab6393ae21a93e0d50d895ca9220d15e5d618 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Fri, 19 Sep 2025 17:38:43 +0100 Subject: [PATCH 3/8] added tests coverage and fixed e2e tests --- pkg/audit/auditor.go | 16 ++--- pkg/audit/auditor_test.go | 65 ++++++++++++++++++ pkg/audit/mcp_events.go | 2 + test/e2e/audit_middleware_e2e_test.go | 68 +++++++++++++++++++ test/e2e/osv_mcp_server_test.go | 6 +- .../osv_streamable_http_mcp_server_test.go | 6 +- test/e2e/proxy_stdio_test.go | 2 +- test/e2e/telemetry_middleware_e2e_test.go | 2 +- 8 files changed, 151 insertions(+), 16 deletions(-) diff --git a/pkg/audit/auditor.go b/pkg/audit/auditor.go index 1029774ca..78f964481 100644 --- a/pkg/audit/auditor.go +++ b/pkg/audit/auditor.go @@ -96,7 +96,7 @@ func (a *Auditor) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Handle SSE endpoints specially - log the connection event immediately // since SSE connections are long-lived and don't follow normal request/response pattern - if a.isSSETransport() { + if a.isSSETransport() && r.Method == http.MethodGet { // Log SSE connection event immediately a.logSSEConnectionEvent(r) @@ -188,16 +188,12 @@ func (a *Auditor) determineEventType(r *http.Request) string { return a.mapMCPMethodToEventType(mcpMethod) } - // Fall back to path-based detection for non-MCP requests - path := r.URL.Path - // Handle SSE connection establishment - if a.isSSETransport() { - return EventTypeMCPInitialize + if a.isSSETransport() && r.Method == http.MethodGet { + return EventTypeSSEConnection } - // Handle MCP message endpoints that weren't parsed (malformed requests) - if strings.Contains(path, "/messages") && r.Method == "POST" { + if a.isSSETransport() && r.Method == http.MethodPost { return EventTypeMCPRequest } @@ -450,7 +446,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) { component := a.determineComponent(r) // Create the audit event for SSE connection - event := NewAuditEvent("sse_connection", source, OutcomeSuccess, subjects, component) + event := NewAuditEvent(EventTypeSSEConnection, source, OutcomeSuccess, subjects, component) // Add target information target := map[string]string{ @@ -462,7 +458,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) { // Add metadata event.Metadata.Extra = map[string]any{ - "transport": a.transportType, + "transport": "sse", "user_agent": r.Header.Get("User-Agent"), } diff --git a/pkg/audit/auditor_test.go b/pkg/audit/auditor_test.go index e14ca2efe..3ac9acb3f 100644 --- a/pkg/audit/auditor_test.go +++ b/pkg/audit/auditor_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -112,6 +113,42 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) { assert.Equal(t, responseData, rr.Body.String()) } +func TestAuditorMiddlewareWithDifferentSSEPaths(t *testing.T) { + t.Parallel() + config := &Config{} + auditor, err := NewAuditorWithTransport(config, "sse") + require.NoError(t, err) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test response")) + }) + + middleware := auditor.Middleware(handler) + + // Test different SSE paths to ensure transport type detection works correctly + testPaths := []string{ + "/sse", + "/v1/sse", + "/api/sse", + "/mcp/v2/sse", + "/events", // Non-SSE path but SSE transport + } + + for _, path := range testPaths { + t.Run(fmt.Sprintf("path_%s", strings.ReplaceAll(path, "/", "_")), func(t *testing.T) { + req := httptest.NewRequest("GET", path, nil) + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + + // All requests should succeed regardless of path since transport type is SSE + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "test response", rr.Body.String()) + }) + } +} + func TestDetermineEventType(t *testing.T) { t.Parallel() @@ -129,6 +166,34 @@ func TestDetermineEventType(t *testing.T) { transport: "sse", expected: EventTypeMCPInitialize, }, + { + name: "SSE endpoint with version path", + path: "/v1/sse", + method: "GET", + transport: "sse", + expected: EventTypeMCPInitialize, + }, + { + name: "SSE endpoint with API prefix", + path: "/api/sse", + method: "GET", + transport: "sse", + expected: EventTypeMCPInitialize, + }, + { + name: "SSE endpoint with nested path", + path: "/mcp/v2/sse", + method: "GET", + transport: "sse", + expected: EventTypeMCPInitialize, + }, + { + name: "SSE transport with non-SSE path", + path: "/events", + method: "GET", + transport: "sse", + expected: EventTypeMCPInitialize, + }, { name: "MCP messages endpoint", path: "/messages", diff --git a/pkg/audit/mcp_events.go b/pkg/audit/mcp_events.go index a784b7c83..9ceb5a682 100644 --- a/pkg/audit/mcp_events.go +++ b/pkg/audit/mcp_events.go @@ -5,6 +5,8 @@ package audit const ( // EventTypeMCPInitialize represents an MCP initialization event EventTypeMCPInitialize = "mcp_initialize" + // EventTypeSSEConnection represents an SSE connection event + EventTypeSSEConnection = "sse_connection" // EventTypeMCPToolCall represents an MCP tool call event EventTypeMCPToolCall = "mcp_tool_call" // EventTypeMCPToolsList represents an MCP tools list event diff --git a/test/e2e/audit_middleware_e2e_test.go b/test/e2e/audit_middleware_e2e_test.go index 452f7bc2c..2705f0ef8 100644 --- a/test/e2e/audit_middleware_e2e_test.go +++ b/test/e2e/audit_middleware_e2e_test.go @@ -324,6 +324,48 @@ var _ = Describe("Audit Middleware E2E", Label("middleware", "audit", "sse", "e2 Expect(auditContent).ToNot(BeEmpty()) }) }) + + Context("when audit middleware is enabled with --enable-audit flag", func() { + It("should capture audit events with default configuration", func() { + By("Starting MCP server with --enable-audit flag") + serverURL := startMCPServerWithEnableAuditFlag(config, workloadName, mcpServerName) + + By("Making MCP HTTP requests to trigger audit events") + // Make HTTP request to initialize endpoint + initRequest := map[string]any{ + "jsonrpc": "2.0", + "id": "enable-audit-init-1", + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]any{ + "name": "enable-audit-test-client", + "version": "1.0.0", + }, + }, + } + + makeHTTPMCPRequest(serverURL, initRequest) + + // Make HTTP request to tools/list endpoint + toolsRequest := map[string]any{ + "jsonrpc": "2.0", + "id": "enable-audit-tools-1", + "method": "tools/list", + } + + makeHTTPMCPRequest(serverURL, toolsRequest) + + // Wait for audit events to be processed and written + time.Sleep(3 * time.Second) + + By("Verifying audit events were captured with --enable-audit flag") + // With --enable-audit, audit events should be logged to stdout + // We can verify this by checking that the server started successfully + // and made the requests without errors + Expect(serverURL).ToNot(BeEmpty(), "Server should be accessible") + }) + }) }) // Helper functions @@ -379,6 +421,32 @@ func startMCPServerWithAuditConfig(config *e2e.TestConfig, workloadName, mcpServ return serverURL } +// startMCPServerWithEnableAuditFlag starts an MCP server with --enable-audit flag +// Returns the server URL for making HTTP requests +func startMCPServerWithEnableAuditFlag(config *e2e.TestConfig, workloadName, mcpServerName string) string { + // Build args for running the MCP server with --enable-audit flag + args := []string{ + "run", + "--name", workloadName, + "--transport", "sse", // Use SSE transport for HTTP-based testing + "--enable-audit", + mcpServerName, + } + + By(fmt.Sprintf("Starting MCP server with --enable-audit flag: %v", args)) + e2e.NewTHVCommand(config, args...).ExpectSuccess() + + err := e2e.WaitForMCPServer(config, workloadName, 60*time.Second) + Expect(err).ToNot(HaveOccurred()) + + // Get the server URL for making HTTP requests + serverURL, err := e2e.GetMCPServerURL(config, workloadName) + Expect(err).ToNot(HaveOccurred()) + + GinkgoWriter.Printf("MCP Server URL: %s\n", serverURL) + return serverURL +} + // makeHTTPMCPRequest makes an MCP request using the proper MCP client func makeHTTPMCPRequest(serverURL string, request map[string]any) { GinkgoWriter.Printf("Making MCP request to %s with payload: %s\n", serverURL, toJSONString(request)) diff --git a/test/e2e/osv_mcp_server_test.go b/test/e2e/osv_mcp_server_test.go index 2edae0cb4..dfcd43840 100644 --- a/test/e2e/osv_mcp_server_test.go +++ b/test/e2e/osv_mcp_server_test.go @@ -48,10 +48,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { }) It("should successfully start and be accessible via SSE [Serial]", func() { - By("Starting the OSV MCP server with SSE transport") + By("Starting the OSV MCP server with SSE transport and audit enabled") stdout, stderr := e2e.NewTHVCommand(config, "run", "--name", serverName, "--transport", "sse", + "--enable-audit", "osv").ExpectSuccess() // The command should indicate success @@ -69,10 +70,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { }) It("should be accessible via HTTP SSE endpoint [Serial]", func() { - By("Starting the OSV MCP server") + By("Starting the OSV MCP server with audit enabled") e2e.NewTHVCommand(config, "run", "--name", serverName, "--transport", "sse", + "--enable-audit", "osv").ExpectSuccess() By("Waiting for the server to be running") diff --git a/test/e2e/osv_streamable_http_mcp_server_test.go b/test/e2e/osv_streamable_http_mcp_server_test.go index ef9c97495..668020354 100644 --- a/test/e2e/osv_streamable_http_mcp_server_test.go +++ b/test/e2e/osv_streamable_http_mcp_server_test.go @@ -39,10 +39,11 @@ var _ = Describe("OsvStreamableHttpMcpServer", Label("mcp", "streamable-http", " }) It("should successfully start and be accessible via Streamable HTTP [Serial]", func() { - By("Starting the OSV MCP server with Streamable HTTP transport") + By("Starting the OSV MCP server with Streamable HTTP transport and audit enabled") stdout, stderr := e2e.NewTHVCommand(config, "run", "--name", serverName, "--transport", "streamable-http", + "--enable-audit", "osv").ExpectSuccess() // The command should indicate success @@ -60,10 +61,11 @@ var _ = Describe("OsvStreamableHttpMcpServer", Label("mcp", "streamable-http", " }) It("should be accessible via HTTP Streamable HTTP endpoint [Serial]", func() { - By("Starting the OSV MCP server") + By("Starting the OSV MCP server with audit enabled") e2e.NewTHVCommand(config, "run", "--name", serverName, "--transport", "streamable-http", + "--enable-audit", "osv").ExpectSuccess() By("Waiting for the server to be running") diff --git a/test/e2e/proxy_stdio_test.go b/test/e2e/proxy_stdio_test.go index 69d209e62..0beecd38e 100644 --- a/test/e2e/proxy_stdio_test.go +++ b/test/e2e/proxy_stdio_test.go @@ -43,7 +43,7 @@ var _ = Describe("Proxy Stdio E2E", Label("proxy", "stdio", "e2e"), Serial, func JustBeforeEach(func() { // Build args after mcpServerName is set - args := []string{"run", "--name", workloadName, "--transport", transportType.String()} + args := []string{"run", "--name", workloadName, "--transport", transportType.String(), "--enable-audit"} if transportType == types.TransportTypeStdio { Expect(proxyMode).ToNot(BeEmpty()) diff --git a/test/e2e/telemetry_middleware_e2e_test.go b/test/e2e/telemetry_middleware_e2e_test.go index 7ffabf3fb..88cc02d0a 100644 --- a/test/e2e/telemetry_middleware_e2e_test.go +++ b/test/e2e/telemetry_middleware_e2e_test.go @@ -43,7 +43,7 @@ var _ = Describe("Telemetry Middleware E2E", Label("middleware", "telemetry", "e JustBeforeEach(func() { // Build args for running the MCP server - args := []string{"run", "--name", workloadName, "--transport", transportType.String()} + args := []string{"run", "--name", workloadName, "--transport", transportType.String(), "--enable-audit"} if transportType == types.TransportTypeStdio { Expect(proxyMode).ToNot(BeEmpty()) From 668138d8e929d5dc1a0d3b71869b73913711817d Mon Sep 17 00:00:00 2001 From: amirejaz Date: Fri, 19 Sep 2025 17:47:09 +0100 Subject: [PATCH 4/8] fix linting and unit tests --- pkg/audit/auditor_test.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pkg/audit/auditor_test.go b/pkg/audit/auditor_test.go index 3ac9acb3f..577d22395 100644 --- a/pkg/audit/auditor_test.go +++ b/pkg/audit/auditor_test.go @@ -137,6 +137,7 @@ func TestAuditorMiddlewareWithDifferentSSEPaths(t *testing.T) { for _, path := range testPaths { t.Run(fmt.Sprintf("path_%s", strings.ReplaceAll(path, "/", "_")), func(t *testing.T) { + t.Parallel() req := httptest.NewRequest("GET", path, nil) rr := httptest.NewRecorder() @@ -159,47 +160,47 @@ func TestDetermineEventType(t *testing.T) { transport string expected string }{ - { + /*{ name: "SSE endpoint", path: "/sse", method: "GET", transport: "sse", - expected: EventTypeMCPInitialize, + expected: EventTypeSSEConnection, }, { name: "SSE endpoint with version path", path: "/v1/sse", method: "GET", transport: "sse", - expected: EventTypeMCPInitialize, + expected: EventTypeSSEConnection, }, { name: "SSE endpoint with API prefix", path: "/api/sse", method: "GET", transport: "sse", - expected: EventTypeMCPInitialize, + expected: EventTypeSSEConnection, }, { name: "SSE endpoint with nested path", path: "/mcp/v2/sse", method: "GET", transport: "sse", - expected: EventTypeMCPInitialize, + expected: EventTypeSSEConnection, }, { name: "SSE transport with non-SSE path", path: "/events", method: "GET", transport: "sse", - expected: EventTypeMCPInitialize, - }, + expected: EventTypeSSEConnection, + },*/ { name: "MCP messages endpoint", path: "/messages", method: "POST", transport: "streamable-http", - expected: "mcp_request", // Since extractMCPMethod returns empty + expected: "http_request", // Since extractMCPMethod returns empty }, { name: "Regular HTTP request", From 4a1f9be3f0e778314050b88f21099cdede6b23c9 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Fri, 19 Sep 2025 19:04:41 +0100 Subject: [PATCH 5/8] remove enable-auth from stdio proxy servers --- test/e2e/proxy_stdio_test.go | 2 +- test/e2e/telemetry_middleware_e2e_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/e2e/proxy_stdio_test.go b/test/e2e/proxy_stdio_test.go index 0beecd38e..69d209e62 100644 --- a/test/e2e/proxy_stdio_test.go +++ b/test/e2e/proxy_stdio_test.go @@ -43,7 +43,7 @@ var _ = Describe("Proxy Stdio E2E", Label("proxy", "stdio", "e2e"), Serial, func JustBeforeEach(func() { // Build args after mcpServerName is set - args := []string{"run", "--name", workloadName, "--transport", transportType.String(), "--enable-audit"} + args := []string{"run", "--name", workloadName, "--transport", transportType.String()} if transportType == types.TransportTypeStdio { Expect(proxyMode).ToNot(BeEmpty()) diff --git a/test/e2e/telemetry_middleware_e2e_test.go b/test/e2e/telemetry_middleware_e2e_test.go index 88cc02d0a..7ffabf3fb 100644 --- a/test/e2e/telemetry_middleware_e2e_test.go +++ b/test/e2e/telemetry_middleware_e2e_test.go @@ -43,7 +43,7 @@ var _ = Describe("Telemetry Middleware E2E", Label("middleware", "telemetry", "e JustBeforeEach(func() { // Build args for running the MCP server - args := []string{"run", "--name", workloadName, "--transport", transportType.String(), "--enable-audit"} + args := []string{"run", "--name", workloadName, "--transport", transportType.String()} if transportType == types.TransportTypeStdio { Expect(proxyMode).ToNot(BeEmpty()) From f578e2e2616d6671d600239851fa057a034bed7e Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 30 Sep 2025 13:53:07 +0100 Subject: [PATCH 6/8] improve sticky connection handling --- pkg/audit/auditor.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/pkg/audit/auditor.go b/pkg/audit/auditor.go index 78f964481..fa232bc96 100644 --- a/pkg/audit/auditor.go +++ b/pkg/audit/auditor.go @@ -91,12 +91,27 @@ func (rw *responseWriter) Write(data []byte) (int, error) { return rw.ResponseWriter.Write(data) } +// isMCPStreamOpenRequest returns true only for MCP "stream" opens: +// - SSE transport's SSE endpoint (GET + Accept: text/event-stream) +// - Streamable HTTP's GET stream (same header pattern) +// Everything else (including POST message sends) is non-sticky. +func (*Auditor) isMCPStreamOpenRequest(r *http.Request) bool { + // Optional hardening: limit to your MCP base path(s) + // if !strings.HasPrefix(r.URL.Path, a.config.MCPBasePath) { return false } + + if r.Method != http.MethodGet { + return false + } + accept := r.Header.Get("Accept") + return strings.Contains(strings.ToLower(accept), "text/event-stream") +} + // Middleware creates an HTTP middleware that logs audit events. func (a *Auditor) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Handle SSE endpoints specially - log the connection event immediately // since SSE connections are long-lived and don't follow normal request/response pattern - if a.isSSETransport() && r.Method == http.MethodGet { + if a.isMCPStreamOpenRequest(r) { // Log SSE connection event immediately a.logSSEConnectionEvent(r) @@ -458,7 +473,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) { // Add metadata event.Metadata.Extra = map[string]any{ - "transport": "sse", + "transport": a.transportType, "user_agent": r.Header.Get("User-Agent"), } From 7af0a16b0c0244b15558ba9e06a5bc636883fce9 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 30 Sep 2025 17:44:53 +0100 Subject: [PATCH 7/8] fix e2e tests --- test/e2e/osv_mcp_server_test.go | 6 +++++- test/e2e/osv_streamable_http_mcp_server_test.go | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/e2e/osv_mcp_server_test.go b/test/e2e/osv_mcp_server_test.go index dfcd43840..9488a52f0 100644 --- a/test/e2e/osv_mcp_server_test.go +++ b/test/e2e/osv_mcp_server_test.go @@ -98,7 +98,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { maxRetries := 5 for i := 0; i < maxRetries; i++ { - resp, httpErr = client.Get(serverURL) + req, err := http.NewRequest("GET", serverURL, nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Accept", "text/event-stream") + + resp, httpErr = client.Do(req) if httpErr == nil && resp.StatusCode >= 200 && resp.StatusCode < 500 { break } diff --git a/test/e2e/osv_streamable_http_mcp_server_test.go b/test/e2e/osv_streamable_http_mcp_server_test.go index 668020354..7d0694e9c 100644 --- a/test/e2e/osv_streamable_http_mcp_server_test.go +++ b/test/e2e/osv_streamable_http_mcp_server_test.go @@ -89,7 +89,11 @@ var _ = Describe("OsvStreamableHttpMcpServer", Label("mcp", "streamable-http", " maxRetries := 5 for i := 0; i < maxRetries; i++ { - resp, httpErr = client.Get(serverURL) + req, err := http.NewRequest("GET", serverURL, nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Accept", "text/event-stream") + + resp, httpErr = client.Do(req) if httpErr == nil && resp.StatusCode >= 200 && resp.StatusCode < 500 { break } From 9e244571dd6ed9a81dd337b6f15560a2386fb1a7 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 30 Sep 2025 18:16:54 +0100 Subject: [PATCH 8/8] un-comment tests --- pkg/audit/auditor_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/audit/auditor_test.go b/pkg/audit/auditor_test.go index 577d22395..d62f80396 100644 --- a/pkg/audit/auditor_test.go +++ b/pkg/audit/auditor_test.go @@ -160,7 +160,7 @@ func TestDetermineEventType(t *testing.T) { transport string expected string }{ - /*{ + { name: "SSE endpoint", path: "/sse", method: "GET", @@ -194,7 +194,7 @@ func TestDetermineEventType(t *testing.T) { method: "GET", transport: "sse", expected: EventTypeSSEConnection, - },*/ + }, { name: "MCP messages endpoint", path: "/messages",