diff --git a/examples/mark3labs/advanced/main.go b/examples/mark3labs/advanced/main.go index 35ef7f7..df5a949 100644 --- a/examples/mark3labs/advanced/main.go +++ b/examples/mark3labs/advanced/main.go @@ -50,7 +50,7 @@ func main() { ) streamableServer := mcpserver.NewStreamableHTTPServer(mcpServer, httpOpts...) - // Feature 4: WrapHandler - Auto Bearer token pre-check with 401 + // Feature 4: WrapMCPEndpoint - Automatic 401 handling with CORS support mcpHandler := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") @@ -64,7 +64,7 @@ func main() { streamableServer.ServeHTTP(w, r) } - mux.HandleFunc("/mcp", oauthServer.WrapHandlerFunc(mcpHandler)) + mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(http.HandlerFunc(mcpHandler))) // Add status endpoint (not OAuth protected) mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { diff --git a/examples/mark3labs/simple/main.go b/examples/mark3labs/simple/main.go index 0a0ebd5..3913115 100644 --- a/examples/mark3labs/simple/main.go +++ b/examples/mark3labs/simple/main.go @@ -22,7 +22,7 @@ func main() { // export OKTA_DOMAIN="dev-12345.okta.com" (your Okta domain) // export OKTA_AUDIENCE="api://my-mcp-server" (your API identifier) // export SERVER_URL="https://mcp.example.com" (your server URL) - _, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{ + oauthServer, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{ Provider: "okta", Issuer: fmt.Sprintf("https://%s", getEnv("OKTA_DOMAIN", "dev-12345.okta.com")), Audience: getEnv("OKTA_AUDIENCE", "api://my-mcp-server"), @@ -51,13 +51,13 @@ func main() { }, ) - // 5. Setup MCP endpoint + // 5. Setup MCP endpoint with automatic 401 handling streamableServer := mcpserver.NewStreamableHTTPServer( mcpServer, mcpserver.WithEndpointPath("/mcp"), mcpserver.WithHTTPContextFunc(oauth.CreateHTTPContextFunc()), ) - mux.Handle("/mcp", streamableServer) + mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) // 6. Start server // Note: PORT is the local bind port. If you change SERVER_URL port diff --git a/mark3labs/oauth.go b/mark3labs/oauth.go index f31969d..3a6320e 100644 --- a/mark3labs/oauth.go +++ b/mark3labs/oauth.go @@ -23,12 +23,16 @@ import ( // }) // mcpServer := server.NewMCPServer("Server", "1.0.0", oauthOption) // +// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...) +// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) +// // This function: // - Creates OAuth server instance // - Registers OAuth HTTP endpoints on mux // - Returns server instance and middleware as server option // // The returned Server instance provides access to: +// - WrapMCPEndpoint() - Wrap /mcp endpoint with automatic 401 handling // - WrapHandler() - Wrap HTTP handlers with OAuth token validation // - GetHTTPServerOptions() - Get StreamableHTTPServer options // - LogStartup() - Log OAuth endpoint information diff --git a/mcp/oauth.go b/mcp/oauth.go index e8b3f5f..9a90911 100644 --- a/mcp/oauth.go +++ b/mcp/oauth.go @@ -3,6 +3,7 @@ package mcp import ( "fmt" "net/http" + "strings" "github.com/modelcontextprotocol/go-sdk/mcp" oauth "github.com/tuannvm/oauth-mcp-proxy" @@ -32,14 +33,18 @@ import ( // This function: // - Creates OAuth server instance // - Registers OAuth HTTP endpoints on mux -// - Wraps MCP StreamableHTTPHandler with OAuth token validation +// - Wraps MCP StreamableHTTPHandler with automatic 401 handling // - Returns OAuth server and protected HTTP handler // +// The returned handler automatically: +// - Returns 401 with WWW-Authenticate headers if Bearer token missing +// - Passes through OPTIONS requests (CORS pre-flight) +// - Rejects non-Bearer auth schemes (OAuth-only endpoint) +// // The returned oauth.Server instance provides access to: // - LogStartup() - Log OAuth endpoint information // - Discovery URL helpers (GetCallbackURL, GetMetadataURL, etc.) // -// The HTTP handler validates OAuth tokens before delegating to the MCP server. // Tool handlers can access the authenticated user via oauth.GetUserFromContext(ctx). func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*oauth.Server, http.Handler, error) { oauthServer, err := oauth.NewServer(cfg) @@ -54,24 +59,64 @@ func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*o }, nil) wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Pass through OPTIONS requests (CORS pre-flight) + if r.Method == http.MethodOptions { + mcpHandler.ServeHTTP(w, r) + return + } + + // Check Authorization header authHeader := r.Header.Get("Authorization") - if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " { - http.Error(w, "Missing or invalid Authorization header", http.StatusUnauthorized) + authLower := strings.ToLower(authHeader) + + // Return 401 if Bearer token missing + if authHeader == "" { + oauthServer.Return401(w) return } - token := authHeader[7:] + // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) + if !strings.HasPrefix(authLower, "bearer") { + // Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only) + oauthServer.Return401(w) + return + } + + // Malformed Bearer token (no space after "Bearer") + if !strings.HasPrefix(authLower, "bearer ") { + oauthServer.Return401InvalidToken(w) + return + } + + // Extract and validate token (safe slice operation) + const bearerPrefix = "Bearer " + if len(authHeader) < len(bearerPrefix)+1 { + oauthServer.Return401InvalidToken(w) + return + } + token := authHeader[len(bearerPrefix):] + + // Clean any whitespace (e.g., "Bearer token ") + token = strings.TrimSpace(token) + + // Validate token is not empty + if token == "" { + oauthServer.Return401InvalidToken(w) + return + } user, err := oauthServer.ValidateTokenCached(r.Context(), token) if err != nil { - http.Error(w, fmt.Sprintf("Authentication failed: %v", err), http.StatusUnauthorized) + oauthServer.Return401InvalidToken(w) return } + // Add token and user to context ctx := oauth.WithOAuthToken(r.Context(), token) ctx = oauth.WithUser(ctx, user) r = r.WithContext(ctx) + // Pass to wrapped handler mcpHandler.ServeHTTP(w, r) }) diff --git a/middleware.go b/middleware.go index e82870d..3df55f1 100644 --- a/middleware.go +++ b/middleware.go @@ -119,10 +119,13 @@ func OAuthMiddleware(validator provider.TokenValidator, enabled bool) func(serve // via WithOAuthToken(). The OAuth middleware then retrieves it via GetOAuthToken(). func CreateHTTPContextFunc() func(context.Context, *http.Request) context.Context { return func(ctx context.Context, r *http.Request) context.Context { - // Extract Bearer token from Authorization header + // Extract Bearer token from Authorization header (case-insensitive per OAuth 2.0 spec) authHeader := r.Header.Get("Authorization") - if strings.HasPrefix(authHeader, "Bearer ") { - token := strings.TrimPrefix(authHeader, "Bearer ") + authLower := strings.ToLower(authHeader) + + if strings.HasPrefix(authLower, "bearer ") { + // Extract token (skip "Bearer " or "bearer " prefix) + token := authHeader[7:] // Clean any whitespace token = strings.TrimSpace(token) ctx = WithOAuthToken(ctx, token) diff --git a/oauth.go b/oauth.go index 4d66b86..027d96c 100644 --- a/oauth.go +++ b/oauth.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "time" mcpserver "github.com/mark3labs/mcp-go/server" @@ -312,6 +313,106 @@ func (s *Server) WrapHandlerFunc(next http.HandlerFunc) http.HandlerFunc { return s.WrapHandler(next).ServeHTTP } +// WrapMCPEndpoint wraps an MCP endpoint handler with automatic 401 handling. +// Returns 401 with WWW-Authenticate headers if Bearer token is missing or invalid. +// +// This method provides automatic OAuth discovery for MCP clients by: +// - Passing through OPTIONS requests (CORS pre-flight) +// - Rejecting non-Bearer auth schemes (OAuth-only endpoint) +// - Returning 401 with proper headers if Bearer token is missing/malformed +// - Extracting token to context and passing to wrapped handler +// +// Usage with mark3labs SDK: +// +// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...) +// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer)) +// +// For official SDK, use mcp.WithOAuth() which includes this automatically. +func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Pass through OPTIONS requests (CORS pre-flight) + if r.Method == http.MethodOptions { + handler.ServeHTTP(w, r) + return + } + + // Check Authorization header + authHeader := r.Header.Get("Authorization") + authLower := strings.ToLower(authHeader) + + // Return 401 if Bearer token missing + if authHeader == "" { + s.Return401(w) + return + } + + // Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec) + if !strings.HasPrefix(authLower, "bearer") { + // Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only) + s.Return401(w) + return + } + + // Malformed Bearer token (no space after "Bearer") + if !strings.HasPrefix(authLower, "bearer ") { + s.Return401InvalidToken(w) + return + } + + // Extract token to context + contextFunc := CreateHTTPContextFunc() + ctx := contextFunc(r.Context(), r) + r = r.WithContext(ctx) + + // Pass to wrapped handler + handler.ServeHTTP(w, r) + } +} + +// Return401 writes a 401 response with WWW-Authenticate header. +// Used by WrapMCPEndpoint and can be called by adapters. +// +// Returns error code "invalid_request" per RFC 6750 §3.1 for missing tokens. +// Includes resource_metadata URL for OAuth discovery. +func (s *Server) Return401(w http.ResponseWriter) { + metadataURL := s.GetProtectedResourceMetadataURL() + + // RFC 6750 compliant: all parameters in single Bearer header + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required", resource_metadata="%s"`, + metadataURL)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + errorResponse := map[string]string{ + "error": "invalid_request", + "error_description": "Bearer token required", + } + _ = json.NewEncoder(w).Encode(errorResponse) +} + +// Return401InvalidToken writes a 401 response for invalid/expired tokens. +// Used when token validation fails (vs missing token). +// +// Returns error code "invalid_token" per RFC 6750 §3.1 for invalid tokens. +// Includes resource_metadata URL for OAuth discovery. +func (s *Server) Return401InvalidToken(w http.ResponseWriter) { + metadataURL := s.GetProtectedResourceMetadataURL() + + // RFC 6750 compliant: all parameters in single Bearer header + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`, + metadataURL)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + errorResponse := map[string]string{ + "error": "invalid_token", + "error_description": "Authentication failed", + } + _ = json.NewEncoder(w).Encode(errorResponse) +} + // WithOAuth returns a server option that enables OAuth authentication // This is the composable API for mcp-go v0.41.1 //