Skip to content
Merged
57 changes: 38 additions & 19 deletions pkg/audit/auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/transport/types"
)

// LevelAudit is a custom audit log level - between Info and Warn
Expand All @@ -35,12 +36,13 @@ func NewAuditLogger(w io.Writer) *slog.Logger {

// Auditor handles audit logging for HTTP requests.
type Auditor struct {
config *Config
auditLogger *slog.Logger
config *Config
auditLogger *slog.Logger
transportType string // e.g., "sse", "streamable-http"
}

// NewAuditor creates a new Auditor with the given configuration.
func NewAuditor(config *Config) (*Auditor, error) {
// NewAuditorWithTransport creates a new Auditor with the given configuration and transport information.
func NewAuditorWithTransport(config *Config, transportType string) (*Auditor, error) {
var logWriter io.Writer = os.Stdout // default to stdout

if config != nil {
Expand All @@ -54,11 +56,17 @@ func NewAuditor(config *Config) (*Auditor, error) {
}

return &Auditor{
config: config,
auditLogger: NewAuditLogger(logWriter),
config: config,
auditLogger: NewAuditLogger(logWriter),
transportType: transportType,
}, nil
}

// isSSETransport checks if the current transport is SSE
func (a *Auditor) isSSETransport() bool {
return a.transportType == types.TransportTypeSSE.String()
}

// responseWriter wraps http.ResponseWriter to capture response data and status.
type responseWriter struct {
http.ResponseWriter
Expand All @@ -83,12 +91,27 @@ func (rw *responseWriter) Write(data []byte) (int, error) {
return rw.ResponseWriter.Write(data)
}

// isMCPStreamOpenRequest returns true only for MCP "stream" opens:
// - SSE transport's SSE endpoint (GET + Accept: text/event-stream)
// - Streamable HTTP's GET stream (same header pattern)
// Everything else (including POST message sends) is non-sticky.
func (*Auditor) isMCPStreamOpenRequest(r *http.Request) bool {
// Optional hardening: limit to your MCP base path(s)
// if !strings.HasPrefix(r.URL.Path, a.config.MCPBasePath) { return false }

if r.Method != http.MethodGet {
return false
}
accept := r.Header.Get("Accept")
return strings.Contains(strings.ToLower(accept), "text/event-stream")
}

// Middleware creates an HTTP middleware that logs audit events.
func (a *Auditor) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle SSE endpoints specially - log the connection event immediately
// since SSE connections are long-lived and don't follow normal request/response pattern
if r.URL.Path == "/sse" {
if a.isMCPStreamOpenRequest(r) {
// Log SSE connection event immediately
a.logSSEConnectionEvent(r)

Expand Down Expand Up @@ -164,7 +187,7 @@ func (a *Auditor) logAuditEvent(r *http.Request, rw *responseWriter, requestData
}

// Add metadata
a.addMetadata(event, r, duration, rw)
a.addMetadata(event, duration, rw)

// Add request/response data if configured
a.addEventData(event, r, rw, requestData)
Expand All @@ -180,16 +203,12 @@ func (a *Auditor) determineEventType(r *http.Request) string {
return a.mapMCPMethodToEventType(mcpMethod)
}

// Fall back to path-based detection for non-MCP requests
path := r.URL.Path

// Handle SSE connection establishment
if strings.Contains(path, "/sse") {
return EventTypeMCPInitialize
if a.isSSETransport() && r.Method == http.MethodGet {
return EventTypeSSEConnection
}

// Handle MCP message endpoints that weren't parsed (malformed requests)
if strings.Contains(path, "/messages") && r.Method == "POST" {
if a.isSSETransport() && r.Method == http.MethodPost {
return EventTypeMCPRequest
}

Expand Down Expand Up @@ -372,7 +391,7 @@ func (*Auditor) extractTarget(r *http.Request, eventType string) map[string]stri
}

// addMetadata adds metadata to the audit event.
func (*Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Duration, rw *responseWriter) {
func (a *Auditor) addMetadata(event *AuditEvent, duration time.Duration, rw *responseWriter) {
if event.Metadata.Extra == nil {
event.Metadata.Extra = make(map[string]any)
}
Expand All @@ -381,7 +400,7 @@ func (*Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Du
event.Metadata.Extra[MetadataExtraKeyDuration] = duration.Milliseconds()

// Add transport information
if strings.Contains(r.URL.Path, "/sse") {
if a.isSSETransport() {
event.Metadata.Extra[MetadataExtraKeyTransport] = "sse"
} else {
event.Metadata.Extra[MetadataExtraKeyTransport] = "http"
Expand Down Expand Up @@ -442,7 +461,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) {
component := a.determineComponent(r)

// Create the audit event for SSE connection
event := NewAuditEvent("sse_connection", source, OutcomeSuccess, subjects, component)
event := NewAuditEvent(EventTypeSSEConnection, source, OutcomeSuccess, subjects, component)

// Add target information
target := map[string]string{
Expand All @@ -454,7 +473,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) {

// Add metadata
event.Metadata.Extra = map[string]any{
"transport": "sse",
"transport": a.transportType,
"user_agent": r.Header.Get("User-Agent"),
}

Expand Down
Loading
Loading