From 5bf4fba03a77e9ad4ee57555e9729a09afbf94c5 Mon Sep 17 00:00:00 2001 From: Tommy Nguyen Date: Mon, 3 Nov 2025 14:24:45 -0800 Subject: [PATCH 1/3] fix(mcp/oauth): Add WrapMCPEndpoint for automatic 401 handling Signed-off-by: Tommy Nguyen --- examples/mark3labs/advanced/main.go | 4 +- examples/mark3labs/simple/main.go | 6 +- mark3labs/oauth.go | 4 ++ mcp/oauth.go | 55 ++++++++++++++-- oauth.go | 99 +++++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 11 deletions(-) diff --git a/examples/mark3labs/advanced/main.go b/examples/mark3labs/advanced/main.go index 35ef7f7..df5a949 100644 --- a/examples/mark3labs/advanced/main.go +++ b/examples/mark3labs/advanced/main.go @@ -50,7 +50,7 @@ func main() { ) streamableServer := mcpserver.NewStreamableHTTPServer(mcpServer, httpOpts...) - // Feature 4: WrapHandler - Auto Bearer token pre-check with 401 + // Feature 4: WrapMCPEndpoint - Automatic 401 handling with CORS support mcpHandler := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") @@ -64,7 +64,7 @@ func main() { streamableServer.ServeHTTP(w, r) } - mux.HandleFunc("/mcp", oauthServer.WrapHandlerFunc(mcpHandler)) + mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(http.HandlerFunc(mcpHandler))) // Add status endpoint (not OAuth protected) mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { diff --git a/examples/mark3labs/simple/main.go b/examples/mark3labs/simple/main.go index 0a0ebd5..3913115 100644 --- a/examples/mark3labs/simple/main.go +++ b/examples/mark3labs/simple/main.go @@ -22,7 +22,7 @@ func main() { // export OKTA_DOMAIN="dev-12345.okta.com" (your Okta domain) // export OKTA_AUDIENCE="api://my-mcp-server" (your API identifier) // export SERVER_URL="https://mcp.example.com" (your server URL) - _, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{ + oauthServer, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{ Provider: "okta", Issuer: fmt.Sprintf("https://%s", getEnv("OKTA_DOMAIN", "dev-12345.okta.com")), Audience: getEnv("OKTA_AUDIENCE", "api://my-mcp-server"), @@ -51,13 +51,13 @@ func main() { }, ) - // 5. Setup MCP endpoint + // 5. Setup MCP endpoint with automatic 401 handling streamableServer := mcpserver.NewStreamableHTTPServer( mcpServer, mcpserver.WithEndpointPath("/mcp"), mcpserver.WithHTTPContextFunc(oauth.CreateHTTPContextFunc()), ) - mux.Handle("/mcp", streamableServer) + mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) // 6. Start server // Note: PORT is the local bind port. If you change SERVER_URL port diff --git a/mark3labs/oauth.go b/mark3labs/oauth.go index f31969d..3a6320e 100644 --- a/mark3labs/oauth.go +++ b/mark3labs/oauth.go @@ -23,12 +23,16 @@ import ( // }) // mcpServer := server.NewMCPServer("Server", "1.0.0", oauthOption) // +// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...) +// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) +// // This function: // - Creates OAuth server instance // - Registers OAuth HTTP endpoints on mux // - Returns server instance and middleware as server option // // The returned Server instance provides access to: +// - WrapMCPEndpoint() - Wrap /mcp endpoint with automatic 401 handling // - WrapHandler() - Wrap HTTP handlers with OAuth token validation // - GetHTTPServerOptions() - Get StreamableHTTPServer options // - LogStartup() - Log OAuth endpoint information diff --git a/mcp/oauth.go b/mcp/oauth.go index e8b3f5f..62123e9 100644 --- a/mcp/oauth.go +++ b/mcp/oauth.go @@ -3,6 +3,7 @@ package mcp import ( "fmt" "net/http" + "strings" "github.com/modelcontextprotocol/go-sdk/mcp" oauth "github.com/tuannvm/oauth-mcp-proxy" @@ -32,14 +33,18 @@ import ( // This function: // - Creates OAuth server instance // - Registers OAuth HTTP endpoints on mux -// - Wraps MCP StreamableHTTPHandler with OAuth token validation +// - Wraps MCP StreamableHTTPHandler with automatic 401 handling // - Returns OAuth server and protected HTTP handler // +// The returned handler automatically: +// - Returns 401 with WWW-Authenticate headers if Bearer token missing +// - Passes through OPTIONS requests (CORS pre-flight) +// - Passes through non-Bearer auth schemes (e.g., Basic auth) +// // The returned oauth.Server instance provides access to: // - LogStartup() - Log OAuth endpoint information // - Discovery URL helpers (GetCallbackURL, GetMetadataURL, etc.) // -// The HTTP handler validates OAuth tokens before delegating to the MCP server. // Tool handlers can access the authenticated user via oauth.GetUserFromContext(ctx). func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*oauth.Server, http.Handler, error) { oauthServer, err := oauth.NewServer(cfg) @@ -54,24 +59,62 @@ func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*o }, nil) wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Pass through OPTIONS requests (CORS pre-flight) + if r.Method == http.MethodOptions { + mcpHandler.ServeHTTP(w, r) + return + } + + // Check Authorization header authHeader := r.Header.Get("Authorization") - if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " { - http.Error(w, "Missing or invalid Authorization header", http.StatusUnauthorized) + authLower := strings.ToLower(authHeader) + + // Return 401 if Bearer token missing + if authHeader == "" { + oauthServer.Return401(w) return } - token := authHeader[7:] + // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) + if strings.HasPrefix(authLower, "bearer") { + // Malformed Bearer token (no space after "Bearer") + if !strings.HasPrefix(authLower, "bearer ") { + oauthServer.Return401InvalidToken(w) + return + } + // Valid Bearer format, continue to validation + } else { + // Pass through non-Bearer schemes (e.g., Basic auth) + mcpHandler.ServeHTTP(w, r) + return + } + + // Extract and validate token (safe slice operation) + const bearerPrefix = "Bearer " + if len(authHeader) < len(bearerPrefix)+1 { + oauthServer.Return401InvalidToken(w) + return + } + token := authHeader[len(bearerPrefix):] + + // Validate token is not empty + if token == "" { + oauthServer.Return401InvalidToken(w) + return + } user, err := oauthServer.ValidateTokenCached(r.Context(), token) if err != nil { - http.Error(w, fmt.Sprintf("Authentication failed: %v", err), http.StatusUnauthorized) + oauthServer.Return401InvalidToken(w) return } + // Add token and user to context ctx := oauth.WithOAuthToken(r.Context(), token) ctx = oauth.WithUser(ctx, user) r = r.WithContext(ctx) + // Pass to wrapped handler mcpHandler.ServeHTTP(w, r) }) diff --git a/oauth.go b/oauth.go index 4d66b86..f46fae7 100644 --- a/oauth.go +++ b/oauth.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "time" mcpserver "github.com/mark3labs/mcp-go/server" @@ -312,6 +313,104 @@ func (s *Server) WrapHandlerFunc(next http.HandlerFunc) http.HandlerFunc { return s.WrapHandler(next).ServeHTTP } +// WrapMCPEndpoint wraps an MCP endpoint handler with automatic 401 handling. +// Returns 401 with WWW-Authenticate headers if Bearer token is missing. +// +// This method provides automatic OAuth discovery for MCP clients by: +// - Passing through OPTIONS requests (CORS pre-flight) +// - Passing through non-Bearer auth schemes (e.g., Basic auth) +// - Returning 401 with proper headers if Bearer token is missing +// - Extracting token to context and passing to wrapped handler +// +// Usage with mark3labs SDK: +// +// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...) +// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) +// +// For official SDK, use mcp.WithOAuth() which includes this automatically. +func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Pass through OPTIONS requests (CORS pre-flight) + if r.Method == http.MethodOptions { + handler.ServeHTTP(w, r) + return + } + + // Check Authorization header + authHeader := r.Header.Get("Authorization") + authLower := strings.ToLower(authHeader) + + // Return 401 if Bearer token missing + if authHeader == "" { + s.Return401(w) + return + } + + // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) + if strings.HasPrefix(authLower, "bearer") { + // Malformed Bearer token (no space after "Bearer") + if !strings.HasPrefix(authLower, "bearer ") { + s.Return401(w) + return + } + // Valid Bearer format, extract to context + // (validation happens in downstream middleware) + } else { + // Pass through non-Bearer schemes (e.g., Basic auth) + handler.ServeHTTP(w, r) + return + } + + // Extract token to context + contextFunc := CreateHTTPContextFunc() + ctx := contextFunc(r.Context(), r) + r = r.WithContext(ctx) + + // Pass to wrapped handler + handler.ServeHTTP(w, r) + } +} + +// Return401 writes a 401 response with WWW-Authenticate headers. +// Used by WrapMCPEndpoint and can be called by adapters. +// +// Returns error code "invalid_request" per RFC 6750 §3.1 for missing tokens. +// Includes resource_metadata URL for OAuth discovery. +func (s *Server) Return401(w http.ResponseWriter) { + metadataURL := s.GetProtectedResourceMetadataURL() + + w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required"`) + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + errorResponse := map[string]string{ + "error": "invalid_request", + "error_description": "Bearer token required", + } + _ = json.NewEncoder(w).Encode(errorResponse) +} + +// Return401InvalidToken writes a 401 response for invalid/expired tokens. +// Used when token validation fails (vs missing token). +// +// Returns error code "invalid_token" per RFC 6750 §3.1 for invalid tokens. +// Includes resource_metadata URL for OAuth discovery. +func (s *Server) Return401InvalidToken(w http.ResponseWriter) { + metadataURL := s.GetProtectedResourceMetadataURL() + + w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`) + w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + errorResponse := map[string]string{ + "error": "invalid_token", + "error_description": "Authentication failed", + } + _ = json.NewEncoder(w).Encode(errorResponse) +} + // WithOAuth returns a server option that enables OAuth authentication // This is the composable API for mcp-go v0.41.1 // From 7131d3b7cba40b031d5ba72b8c6f90fbddddc1e5 Mon Sep 17 00:00:00 2001 From: Tommy Nguyen Date: Mon, 3 Nov 2025 15:57:27 -0800 Subject: [PATCH 2/3] fix(oauth): trim whitespace from Bearer tokens in headers Signed-off-by: Tommy Nguyen --- mcp/oauth.go | 3 +++ middleware.go | 9 ++++++--- oauth.go | 16 ++++++++++------ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/mcp/oauth.go b/mcp/oauth.go index 62123e9..94dbe6e 100644 --- a/mcp/oauth.go +++ b/mcp/oauth.go @@ -97,6 +97,9 @@ func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*o } token := authHeader[len(bearerPrefix):] + // Clean any whitespace (e.g., "Bearer token ") + token = strings.TrimSpace(token) + // Validate token is not empty if token == "" { oauthServer.Return401InvalidToken(w) diff --git a/middleware.go b/middleware.go index e82870d..3df55f1 100644 --- a/middleware.go +++ b/middleware.go @@ -119,10 +119,13 @@ func OAuthMiddleware(validator provider.TokenValidator, enabled bool) func(serve // via WithOAuthToken(). The OAuth middleware then retrieves it via GetOAuthToken(). func CreateHTTPContextFunc() func(context.Context, *http.Request) context.Context { return func(ctx context.Context, r *http.Request) context.Context { - // Extract Bearer token from Authorization header + // Extract Bearer token from Authorization header (case-insensitive per OAuth 2.0 spec) authHeader := r.Header.Get("Authorization") - if strings.HasPrefix(authHeader, "Bearer ") { - token := strings.TrimPrefix(authHeader, "Bearer ") + authLower := strings.ToLower(authHeader) + + if strings.HasPrefix(authLower, "bearer ") { + // Extract token (skip "Bearer " or "bearer " prefix) + token := authHeader[7:] // Clean any whitespace token = strings.TrimSpace(token) ctx = WithOAuthToken(ctx, token) diff --git a/oauth.go b/oauth.go index f46fae7..1df1f52 100644 --- a/oauth.go +++ b/oauth.go @@ -350,7 +350,7 @@ func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { if strings.HasPrefix(authLower, "bearer") { // Malformed Bearer token (no space after "Bearer") if !strings.HasPrefix(authLower, "bearer ") { - s.Return401(w) + s.Return401InvalidToken(w) return } // Valid Bearer format, extract to context @@ -371,7 +371,7 @@ func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { } } -// Return401 writes a 401 response with WWW-Authenticate headers. +// Return401 writes a 401 response with WWW-Authenticate header. // Used by WrapMCPEndpoint and can be called by adapters. // // Returns error code "invalid_request" per RFC 6750 §3.1 for missing tokens. @@ -379,8 +379,10 @@ func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { func (s *Server) Return401(w http.ResponseWriter) { metadataURL := s.GetProtectedResourceMetadataURL() - w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required"`) - w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL)) + // RFC 6750 compliant: all parameters in single Bearer header + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required", resource_metadata="%s"`, + metadataURL)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) @@ -399,8 +401,10 @@ func (s *Server) Return401(w http.ResponseWriter) { func (s *Server) Return401InvalidToken(w http.ResponseWriter) { metadataURL := s.GetProtectedResourceMetadataURL() - w.Header().Add("WWW-Authenticate", `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed"`) - w.Header().Add("WWW-Authenticate", fmt.Sprintf(`resource_metadata="%s"`, metadataURL)) + // RFC 6750 compliant: all parameters in single Bearer header + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`, + metadataURL)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) From e0b23544f9fbe2f9d769378370db6609be05097c Mon Sep 17 00:00:00 2001 From: Tommy Nguyen Date: Mon, 3 Nov 2025 16:30:13 -0800 Subject: [PATCH 3/3] fix(oauth): reject non-Bearer auth schemes and handle malformed tokens Signed-off-by: Tommy Nguyen --- mcp/oauth.go | 21 ++++++++++----------- oauth.go | 26 ++++++++++++-------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/mcp/oauth.go b/mcp/oauth.go index 94dbe6e..9a90911 100644 --- a/mcp/oauth.go +++ b/mcp/oauth.go @@ -39,7 +39,7 @@ import ( // The returned handler automatically: // - Returns 401 with WWW-Authenticate headers if Bearer token missing // - Passes through OPTIONS requests (CORS pre-flight) -// - Passes through non-Bearer auth schemes (e.g., Basic auth) +// - Rejects non-Bearer auth schemes (OAuth-only endpoint) // // The returned oauth.Server instance provides access to: // - LogStartup() - Log OAuth endpoint information @@ -76,16 +76,15 @@ func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*o } // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) - if strings.HasPrefix(authLower, "bearer") { - // Malformed Bearer token (no space after "Bearer") - if !strings.HasPrefix(authLower, "bearer ") { - oauthServer.Return401InvalidToken(w) - return - } - // Valid Bearer format, continue to validation - } else { - // Pass through non-Bearer schemes (e.g., Basic auth) - mcpHandler.ServeHTTP(w, r) + if !strings.HasPrefix(authLower, "bearer") { + // Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only) + oauthServer.Return401(w) + return + } + + // Malformed Bearer token (no space after "Bearer") + if !strings.HasPrefix(authLower, "bearer ") { + oauthServer.Return401InvalidToken(w) return } diff --git a/oauth.go b/oauth.go index 1df1f52..027d96c 100644 --- a/oauth.go +++ b/oauth.go @@ -314,12 +314,12 @@ func (s *Server) WrapHandlerFunc(next http.HandlerFunc) http.HandlerFunc { } // WrapMCPEndpoint wraps an MCP endpoint handler with automatic 401 handling. -// Returns 401 with WWW-Authenticate headers if Bearer token is missing. +// Returns 401 with WWW-Authenticate headers if Bearer token is missing or invalid. // // This method provides automatic OAuth discovery for MCP clients by: // - Passing through OPTIONS requests (CORS pre-flight) -// - Passing through non-Bearer auth schemes (e.g., Basic auth) -// - Returning 401 with proper headers if Bearer token is missing +// - Rejecting non-Bearer auth schemes (OAuth-only endpoint) +// - Returning 401 with proper headers if Bearer token is missing/malformed // - Extracting token to context and passing to wrapped handler // // Usage with mark3labs SDK: @@ -347,17 +347,15 @@ func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { } // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) - if strings.HasPrefix(authLower, "bearer") { - // Malformed Bearer token (no space after "Bearer") - if !strings.HasPrefix(authLower, "bearer ") { - s.Return401InvalidToken(w) - return - } - // Valid Bearer format, extract to context - // (validation happens in downstream middleware) - } else { - // Pass through non-Bearer schemes (e.g., Basic auth) - handler.ServeHTTP(w, r) + if !strings.HasPrefix(authLower, "bearer") { + // Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only) + s.Return401(w) + return + } + + // Malformed Bearer token (no space after "Bearer") + if !strings.HasPrefix(authLower, "bearer ") { + s.Return401InvalidToken(w) return }