Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/mark3labs/advanced/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func main() {
)
streamableServer := mcpserver.NewStreamableHTTPServer(mcpServer, httpOpts...)

// Feature 4: WrapHandler - Auto Bearer token pre-check with 401
// Feature 4: WrapMCPEndpoint - Automatic 401 handling with CORS support
mcpHandler := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
Expand All @@ -64,7 +64,7 @@ func main() {
streamableServer.ServeHTTP(w, r)
}

mux.HandleFunc("/mcp", oauthServer.WrapHandlerFunc(mcpHandler))
mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(http.HandlerFunc(mcpHandler)))

// Add status endpoint (not OAuth protected)
mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
Expand Down
6 changes: 3 additions & 3 deletions examples/mark3labs/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func main() {
// export OKTA_DOMAIN="dev-12345.okta.com" (your Okta domain)
// export OKTA_AUDIENCE="api://my-mcp-server" (your API identifier)
// export SERVER_URL="https://mcp.example.com" (your server URL)
_, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{
oauthServer, oauthOption, err := mark3labs.WithOAuth(mux, &oauth.Config{
Provider: "okta",
Issuer: fmt.Sprintf("https://%s", getEnv("OKTA_DOMAIN", "dev-12345.okta.com")),
Audience: getEnv("OKTA_AUDIENCE", "api://my-mcp-server"),
Expand Down Expand Up @@ -51,13 +51,13 @@ func main() {
},
)

// 5. Setup MCP endpoint
// 5. Setup MCP endpoint with automatic 401 handling
streamableServer := mcpserver.NewStreamableHTTPServer(
mcpServer,
mcpserver.WithEndpointPath("/mcp"),
mcpserver.WithHTTPContextFunc(oauth.CreateHTTPContextFunc()),
)
mux.Handle("/mcp", streamableServer)
mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer))

// 6. Start server
// Note: PORT is the local bind port. If you change SERVER_URL port
Expand Down
4 changes: 4 additions & 0 deletions mark3labs/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import (
// })
// mcpServer := server.NewMCPServer("Server", "1.0.0", oauthOption)
//
// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...)
// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer))
//
// This function:
// - Creates OAuth server instance
// - Registers OAuth HTTP endpoints on mux
// - Returns server instance and middleware as server option
//
// The returned Server instance provides access to:
// - WrapMCPEndpoint() - Wrap /mcp endpoint with automatic 401 handling
// - WrapHandler() - Wrap HTTP handlers with OAuth token validation
// - GetHTTPServerOptions() - Get StreamableHTTPServer options
// - LogStartup() - Log OAuth endpoint information
Expand Down
57 changes: 51 additions & 6 deletions mcp/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mcp
import (
"fmt"
"net/http"
"strings"

"github.com/modelcontextprotocol/go-sdk/mcp"
oauth "github.com/tuannvm/oauth-mcp-proxy"
Expand Down Expand Up @@ -32,14 +33,18 @@ import (
// This function:
// - Creates OAuth server instance
// - Registers OAuth HTTP endpoints on mux
// - Wraps MCP StreamableHTTPHandler with OAuth token validation
// - Wraps MCP StreamableHTTPHandler with automatic 401 handling
// - Returns OAuth server and protected HTTP handler
//
// The returned handler automatically:
// - Returns 401 with WWW-Authenticate headers if Bearer token missing
// - Passes through OPTIONS requests (CORS pre-flight)
// - Rejects non-Bearer auth schemes (OAuth-only endpoint)
//
// The returned oauth.Server instance provides access to:
// - LogStartup() - Log OAuth endpoint information
// - Discovery URL helpers (GetCallbackURL, GetMetadataURL, etc.)
//
// The HTTP handler validates OAuth tokens before delegating to the MCP server.
// Tool handlers can access the authenticated user via oauth.GetUserFromContext(ctx).
func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*oauth.Server, http.Handler, error) {
oauthServer, err := oauth.NewServer(cfg)
Expand All @@ -54,24 +59,64 @@ func WithOAuth(mux *http.ServeMux, cfg *oauth.Config, mcpServer *mcp.Server) (*o
}, nil)

wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Pass through OPTIONS requests (CORS pre-flight)
if r.Method == http.MethodOptions {
mcpHandler.ServeHTTP(w, r)
return
}

// Check Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" || len(authHeader) < 7 || authHeader[:7] != "Bearer " {
http.Error(w, "Missing or invalid Authorization header", http.StatusUnauthorized)
authLower := strings.ToLower(authHeader)

// Return 401 if Bearer token missing
if authHeader == "" {
oauthServer.Return401(w)
return
}

token := authHeader[7:]
// Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec)
if !strings.HasPrefix(authLower, "bearer") {
// Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only)
oauthServer.Return401(w)
return
}

// Malformed Bearer token (no space after "Bearer")
if !strings.HasPrefix(authLower, "bearer ") {
oauthServer.Return401InvalidToken(w)
return
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔒 Letting non-Bearer Authorization headers fall through means anyone can send Authorization: Basic ... and reach the MCP handler without an OAuth token, undoing the previous 401 guard and creating an auth bypass. Please keep rejecting any scheme that isn't a properly formatted Bearer token before invoking the handler.

}

// Extract and validate token (safe slice operation)
const bearerPrefix = "Bearer "
if len(authHeader) < len(bearerPrefix)+1 {
oauthServer.Return401InvalidToken(w)
return
}
token := authHeader[len(bearerPrefix):]

// Clean any whitespace (e.g., "Bearer token ")
token = strings.TrimSpace(token)

// Validate token is not empty
if token == "" {
oauthServer.Return401InvalidToken(w)
return
}

user, err := oauthServer.ValidateTokenCached(r.Context(), token)
if err != nil {
http.Error(w, fmt.Sprintf("Authentication failed: %v", err), http.StatusUnauthorized)
oauthServer.Return401InvalidToken(w)
return
}

// Add token and user to context
ctx := oauth.WithOAuthToken(r.Context(), token)
ctx = oauth.WithUser(ctx, user)
r = r.WithContext(ctx)

// Pass to wrapped handler
mcpHandler.ServeHTTP(w, r)
})

Expand Down
9 changes: 6 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,13 @@ func OAuthMiddleware(validator provider.TokenValidator, enabled bool) func(serve
// via WithOAuthToken(). The OAuth middleware then retrieves it via GetOAuthToken().
func CreateHTTPContextFunc() func(context.Context, *http.Request) context.Context {
return func(ctx context.Context, r *http.Request) context.Context {
// Extract Bearer token from Authorization header
// Extract Bearer token from Authorization header (case-insensitive per OAuth 2.0 spec)
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token := strings.TrimPrefix(authHeader, "Bearer ")
authLower := strings.ToLower(authHeader)

if strings.HasPrefix(authLower, "bearer ") {
// Extract token (skip "Bearer " or "bearer " prefix)
token := authHeader[7:]
// Clean any whitespace
token = strings.TrimSpace(token)
ctx = WithOAuthToken(ctx, token)
Expand Down
101 changes: 101 additions & 0 deletions oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"

mcpserver "github.com/mark3labs/mcp-go/server"
Expand Down Expand Up @@ -312,6 +313,106 @@ func (s *Server) WrapHandlerFunc(next http.HandlerFunc) http.HandlerFunc {
return s.WrapHandler(next).ServeHTTP
}

// WrapMCPEndpoint wraps an MCP endpoint handler with automatic 401 handling.
// Returns 401 with WWW-Authenticate headers if Bearer token is missing or invalid.
//
// This method provides automatic OAuth discovery for MCP clients by:
// - Passing through OPTIONS requests (CORS pre-flight)
// - Rejecting non-Bearer auth schemes (OAuth-only endpoint)
// - Returning 401 with proper headers if Bearer token is missing/malformed
// - Extracting token to context and passing to wrapped handler
//
// Usage with mark3labs SDK:
//
// streamableServer := server.NewStreamableHTTPServer(mcpServer, ...)
// mux.HandleFunc("/mcp", oauthServer.WrapMCPEndpoint(streamableServer))
//
// For official SDK, use mcp.WithOAuth() which includes this automatically.
func (s *Server) WrapMCPEndpoint(handler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Pass through OPTIONS requests (CORS pre-flight)
if r.Method == http.MethodOptions {
handler.ServeHTTP(w, r)
return
}

// Check Authorization header
authHeader := r.Header.Get("Authorization")
authLower := strings.ToLower(authHeader)

// Return 401 if Bearer token missing
if authHeader == "" {
s.Return401(w)
return
}

// Check if it's a Bearer token (case-insensitive per OAuth 2.0 spec)
if !strings.HasPrefix(authLower, "bearer") {
// Reject non-Bearer schemes (OAuth endpoints require Bearer tokens only)
s.Return401(w)
return
}

// Malformed Bearer token (no space after "Bearer")
if !strings.HasPrefix(authLower, "bearer ") {
s.Return401InvalidToken(w)
return
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔒 Passing through non-Bearer Authorization headers here lets a client hit the MCP endpoint with Authorization: Basic ... and skip OAuth entirely, which regressed from the prior 401 guard. Please continue to reject any non-Bearer scheme before calling the wrapped handler.

}

// Extract token to context
contextFunc := CreateHTTPContextFunc()
ctx := contextFunc(r.Context(), r)
r = r.WithContext(ctx)

// Pass to wrapped handler
handler.ServeHTTP(w, r)
}
}

// Return401 writes a 401 response with WWW-Authenticate header.
// Used by WrapMCPEndpoint and can be called by adapters.
//
// Returns error code "invalid_request" per RFC 6750 §3.1 for missing tokens.
// Includes resource_metadata URL for OAuth discovery.
func (s *Server) Return401(w http.ResponseWriter) {
metadataURL := s.GetProtectedResourceMetadataURL()

// RFC 6750 compliant: all parameters in single Bearer header
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_request", error_description="Bearer token required", resource_metadata="%s"`,
metadataURL))
w.Header().Set("Content-Type", "application/json")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Logic This second WWW-Authenticate header omits the auth scheme, so clients treat it as invalid and ignore the resource_metadata hint. Please include it in the Bearer header (and do the same in Return401InvalidToken) so OAuth discovery works.

w.WriteHeader(http.StatusUnauthorized)

errorResponse := map[string]string{
"error": "invalid_request",
"error_description": "Bearer token required",
}
_ = json.NewEncoder(w).Encode(errorResponse)
}

// Return401InvalidToken writes a 401 response for invalid/expired tokens.
// Used when token validation fails (vs missing token).
//
// Returns error code "invalid_token" per RFC 6750 §3.1 for invalid tokens.
// Includes resource_metadata URL for OAuth discovery.
func (s *Server) Return401InvalidToken(w http.ResponseWriter) {
metadataURL := s.GetProtectedResourceMetadataURL()

// RFC 6750 compliant: all parameters in single Bearer header
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm="OAuth", error="invalid_token", error_description="Authentication failed", resource_metadata="%s"`,
metadataURL))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)

errorResponse := map[string]string{
"error": "invalid_token",
"error_description": "Authentication failed",
}
_ = json.NewEncoder(w).Encode(errorResponse)
}

// WithOAuth returns a server option that enables OAuth authentication
// This is the composable API for mcp-go v0.41.1
//
Expand Down
Loading