" from Authorization header
+ // and adds it to the request context. OAuth middleware then validates it.
+ streamableServer := mcpserver.NewStreamableHTTPServer(
+ mcpServer,
+ mcpserver.WithEndpointPath("/mcp"), // MCP endpoint path
+ mcpserver.WithHTTPContextFunc(oauth.CreateHTTPContextFunc()), // Token extraction
+ )
+
+ mux.Handle("/mcp", streamableServer)
+
+ // Step 6: Generate a test token (for HMAC provider testing)
+ // In production with OIDC providers (Okta/Google/Azure), clients get tokens
+ // from the OAuth provider directly. This is just for local testing.
+ testToken := generateTestToken(&oauth.Config{
+ Issuer: "https://test.example.com",
+ Audience: "api://simple-server",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ })
+
+ log.Println("📋 Testing Instructions:")
+ log.Println()
+ log.Println("1. Server is starting on http://localhost:8080")
+ log.Println()
+ log.Println("2. Test OAuth metadata endpoint:")
+ log.Println(" curl http://localhost:8080/.well-known/oauth-authorization-server")
+ log.Println()
+ log.Println("3. Call the 'hello' tool with authentication:")
+ log.Printf(" curl -X POST http://localhost:8080/mcp \\\n")
+ log.Printf(" -H 'Authorization: Bearer %s' \\\n", testToken[:50]+"...")
+ log.Printf(" -H 'Content-Type: application/json' \\\n")
+ log.Printf(" -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"hello\",\"arguments\":{}}}'\n")
+ log.Println()
+ log.Println("4. Try without token (should fail with authentication error)")
+ log.Println()
+
+ log.Println("🚀 Server starting on http://localhost:8080")
+ log.Println()
+ if err := http.ListenAndServe(":8080", mux); err != nil {
+ log.Fatalf("Server failed: %v", err)
+ }
+}
+
+// generateTestToken creates a valid JWT token for testing HMAC provider.
+// In production with OIDC providers (Okta, Google, Azure), clients obtain tokens
+// from the OAuth provider's authorization server, not from your code.
+func generateTestToken(cfg *oauth.Config) string {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user-123", // Subject: unique user identifier
+ "email": "test@example.com", // User's email address
+ "preferred_username": "testuser", // Username (optional)
+ "aud": cfg.Audience, // Must match Config.Audience!
+ "iss": cfg.Issuer, // Must match Config.Issuer
+ "exp": time.Now().Add(time.Hour).Unix(), // Token expires in 1 hour
+ "iat": time.Now().Unix(), // Issued at (now)
+ })
+
+ // Sign with secret (must match Config.JWTSecret)
+ tokenString, err := token.SignedString(cfg.JWTSecret)
+ if err != nil {
+ log.Fatalf("Failed to sign token: %v", err)
+ }
+
+ return tokenString
+}
diff --git a/fixed_redirect_test.go b/fixed_redirect_test.go
new file mode 100644
index 0000000..a4ef20a
--- /dev/null
+++ b/fixed_redirect_test.go
@@ -0,0 +1,102 @@
+package oauth
+
+import (
+ "crypto/rand"
+ "testing"
+)
+
+func TestFixedRedirectModeLocalhostOnly(t *testing.T) {
+ key := make([]byte, 32)
+ _, _ = rand.Read(key)
+
+ tests := []struct {
+ name string
+ clientURI string
+ shouldPass bool
+ expectedError string
+ }{
+ {
+ name: "HTTP localhost allowed",
+ clientURI: "http://localhost:8080/callback",
+ shouldPass: true,
+ },
+ {
+ name: "HTTP 127.0.0.1 allowed",
+ clientURI: "http://127.0.0.1:3000/callback",
+ shouldPass: true,
+ },
+ {
+ name: "HTTP IPv6 localhost allowed",
+ clientURI: "http://[::1]:9000/callback",
+ shouldPass: true,
+ },
+ {
+ name: "HTTPS localhost allowed",
+ clientURI: "https://localhost/callback",
+ shouldPass: true,
+ },
+ {
+ name: "HTTPS production domain rejected",
+ clientURI: "https://evil.com/callback",
+ shouldPass: false,
+ expectedError: "Fixed redirect mode only allows localhost",
+ },
+ {
+ name: "HTTP production domain rejected",
+ clientURI: "http://evil.com/callback",
+ shouldPass: false,
+ expectedError: "HTTPS required for non-localhost",
+ },
+ {
+ name: "localhost subdomain rejected",
+ clientURI: "https://localhost.evil.com/callback",
+ shouldPass: false,
+ expectedError: "Fixed redirect mode only allows localhost",
+ },
+ {
+ name: "URI with fragment rejected",
+ clientURI: "http://localhost:8080/callback#fragment",
+ shouldPass: false,
+ expectedError: "must not contain fragment",
+ },
+ {
+ name: "Custom scheme rejected",
+ clientURI: "custom://localhost:8080/callback",
+ shouldPass: false,
+ expectedError: "Invalid redirect_uri scheme",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ isLocalhost := isLocalhostURI(tt.clientURI)
+
+ if tt.shouldPass && !isLocalhost {
+ t.Errorf("Expected localhost detection to pass for %s", tt.clientURI)
+ }
+
+ if !tt.shouldPass && isLocalhost && tt.expectedError != "must not contain fragment" && tt.expectedError != "Invalid redirect_uri scheme" {
+ t.Errorf("Expected localhost detection to fail for %s", tt.clientURI)
+ }
+
+ t.Logf("URI: %s, isLocalhost: %v, shouldPass: %v", tt.clientURI, isLocalhost, tt.shouldPass)
+ })
+ }
+}
+
+func TestFixedRedirectModeSecurityModel(t *testing.T) {
+ t.Log("Fixed Redirect Mode Security Model:")
+ t.Log("- Single OAUTH_REDIRECT_URI configured (no commas)")
+ t.Log("- Server uses fixed URI to communicate with OAuth provider")
+ t.Log("- Client redirect URIs MUST be localhost for security")
+ t.Log("- HMAC-signed state prevents redirect URI tampering")
+ t.Log("")
+ t.Log("Attack Prevention:")
+ t.Log("1. Open Redirect → Localhost-only restriction prevents external redirects")
+ t.Log("2. State Tampering → HMAC signature verification prevents modification")
+ t.Log("3. Code Theft → PKCE prevents token exchange without code_verifier")
+ t.Log("4. HTTP Exposure → HTTPS required for non-localhost URIs")
+ t.Log("")
+ t.Log("Use Case: Development tools (MCP Inspector) running on localhost")
+ t.Log("Production: Use allowlist mode instead")
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..9974e17
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,23 @@
+module github.com/tuannvm/oauth-mcp-proxy
+
+go 1.25.1
+
+require (
+ github.com/coreos/go-oidc/v3 v3.16.0
+ github.com/golang-jwt/jwt/v5 v5.3.0
+ github.com/mark3labs/mcp-go v0.41.1
+ golang.org/x/oauth2 v0.32.0
+)
+
+require (
+ github.com/bahlo/generic-list-go v0.2.0 // indirect
+ github.com/buger/jsonparser v1.1.1 // indirect
+ github.com/go-jose/go-jose/v4 v4.1.3 // indirect
+ github.com/google/uuid v1.6.0 // indirect
+ github.com/invopop/jsonschema v0.13.0 // indirect
+ github.com/mailru/easyjson v0.7.7 // indirect
+ github.com/spf13/cast v1.8.0 // indirect
+ github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
+ github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..f4e3114
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,47 @@
+github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
+github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
+github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
+github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
+github.com/coreos/go-oidc/v3 v3.16.0 h1:qRQUCFstKpXwmEjDQTIbyY/5jF00+asXzSkmkoa/mow=
+github.com/coreos/go-oidc/v3 v3.16.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
+github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
+github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
+github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
+github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
+github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
+github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
+github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
+github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
+github.com/mark3labs/mcp-go v0.41.1 h1:w78eWfiQam2i8ICL7AL0WFiq7KHNJQ6UB53ZVtH4KGA=
+github.com/mark3labs/mcp-go v0.41.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/spf13/cast v1.8.0 h1:gEN9K4b8Xws4EX0+a0reLmhq8moKn7ntRlQYgjPeCDk=
+github.com/spf13/cast v1.8.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
+github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
+github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
+github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
+golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
+golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/handlers.go b/handlers.go
new file mode 100644
index 0000000..a1fec89
--- /dev/null
+++ b/handlers.go
@@ -0,0 +1,812 @@
+package oauth
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha256"
+ "crypto/tls"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/coreos/go-oidc/v3/oidc"
+ "golang.org/x/oauth2"
+)
+
+// OAuth2Handler handles OAuth2 flows using the standard library
+type OAuth2Handler struct {
+ config *OAuth2Config
+ oauth2Config *oauth2.Config
+ logger Logger
+}
+
+// GetConfig returns the OAuth2 configuration
+func (h *OAuth2Handler) GetConfig() *OAuth2Config {
+ return h.config
+}
+
+// OAuth2Config holds OAuth2 configuration
+type OAuth2Config struct {
+ Enabled bool
+ Mode string // "native" or "proxy"
+ Provider string
+ RedirectURIs string
+
+ // OIDC configuration
+ Issuer string
+ Audience string
+ ClientID string
+ ClientSecret string
+
+ // Server configuration
+ MCPHost string
+ MCPPort string
+ Scheme string
+
+ // MCPURL is the full URL of the MCP server, used for the resource endpoint in the OAuth 2.0 Protected Resource Metadata endpoint
+ MCPURL string
+
+ // Server version
+ Version string
+
+ // State signing key for integrity protection
+ stateSigningKey []byte
+}
+
+// NewOAuth2Handler creates a new OAuth2 handler using the standard library
+func NewOAuth2Handler(cfg *OAuth2Config, logger Logger) *OAuth2Handler {
+ if logger == nil {
+ logger = &defaultLogger{}
+ }
+
+ var endpoint oauth2.Endpoint
+
+ // Use OIDC discovery for supported providers, fallback to hardcoded for others
+ switch cfg.Provider {
+ case "okta", "google", "azure":
+ // Use OIDC discovery to get correct endpoints
+ if discoveredEndpoint, err := discoverOIDCEndpoints(cfg.Issuer); err != nil {
+ logger.Error("OIDC discovery failed for %s provider. Using Okta-style fallback endpoints which may not work for all providers: %v", cfg.Provider, err)
+ // Fallback to Okta-style endpoints as they're most common
+ endpoint = oauth2.Endpoint{
+ AuthURL: cfg.Issuer + "/oauth2/v1/authorize",
+ TokenURL: cfg.Issuer + "/oauth2/v1/token",
+ }
+ } else {
+ endpoint = discoveredEndpoint
+ }
+ default:
+ // For HMAC and unknown providers, use hardcoded endpoints
+ endpoint = oauth2.Endpoint{
+ AuthURL: cfg.Issuer + "/oauth2/v1/authorize",
+ TokenURL: cfg.Issuer + "/oauth2/v1/token",
+ }
+ }
+
+ oauth2Config := &oauth2.Config{
+ ClientID: cfg.ClientID,
+ ClientSecret: cfg.ClientSecret,
+ Endpoint: endpoint,
+ Scopes: []string{"openid", "profile", "email"},
+ }
+
+ // Log client configuration type for debugging
+ if cfg.ClientSecret == "" {
+ logger.Info("Configuring public client (no client secret)")
+ } else {
+ logger.Info("Configuring confidential client (with client secret)")
+ }
+
+ // Initialize state signing key
+ if len(cfg.stateSigningKey) == 0 {
+ logger.Warn("No state signing key configured, generating random key (will not persist across restarts)")
+ key := make([]byte, 32)
+ if _, err := rand.Read(key); err != nil {
+ logger.Error("Failed to generate state signing key: %v", err)
+ // Use a deterministic fallback (not ideal, but better than nothing)
+ cfg.stateSigningKey = []byte("insecure-fallback-key-please-configure-JWT_SECRET")
+ logger.Warn("Using insecure fallback key. Please configure JWT_SECRET environment variable.")
+ } else {
+ cfg.stateSigningKey = key
+ }
+ }
+
+ return &OAuth2Handler{
+ config: cfg,
+ oauth2Config: oauth2Config,
+ logger: logger,
+ }
+}
+
+// discoverOIDCEndpoints uses OIDC discovery to get the correct authorization and token endpoints
+func discoverOIDCEndpoints(issuer string) (oauth2.Endpoint, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ // Configure HTTP client with appropriate timeouts and TLS settings
+ httpClient := &http.Client{
+ Timeout: 10 * time.Second,
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: false, // Verify TLS certificates
+ MinVersion: tls.VersionTLS12,
+ },
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ MaxIdleConns: 10,
+ MaxIdleConnsPerHost: 2,
+ },
+ }
+
+ // Create OIDC provider with custom HTTP client
+ provider, err := oidc.NewProvider(
+ oidc.ClientContext(ctx, httpClient),
+ issuer,
+ )
+ if err != nil {
+ return oauth2.Endpoint{}, fmt.Errorf("failed to discover OIDC provider: %w", err)
+ }
+
+ // Return the discovered endpoint
+ return provider.Endpoint(), nil
+}
+
+// NewOAuth2ConfigFromConfig creates OAuth2 config from generic Config
+func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
+ mcpHost := getEnv("MCP_HOST", "localhost")
+ mcpPort := getEnv("MCP_PORT", "8080")
+
+ // Determine scheme based on HTTPS configuration
+ scheme := "http"
+ if getEnv("HTTPS_CERT_FILE", "") != "" && getEnv("HTTPS_KEY_FILE", "") != "" {
+ scheme = "https"
+ }
+
+ // Use ServerURL from config if provided, otherwise build from env vars
+ mcpURL := cfg.ServerURL
+ if mcpURL == "" {
+ mcpURL = getEnv("MCP_URL", fmt.Sprintf("%s://%s:%s", scheme, mcpHost, mcpPort))
+ }
+
+ return &OAuth2Config{
+ Enabled: true,
+ Mode: cfg.Mode,
+ Provider: cfg.Provider,
+ RedirectURIs: cfg.RedirectURIs,
+ Issuer: cfg.Issuer,
+ Audience: cfg.Audience,
+ ClientID: cfg.ClientID,
+ ClientSecret: cfg.ClientSecret,
+ MCPHost: mcpHost,
+ MCPPort: mcpPort,
+ MCPURL: mcpURL,
+ Scheme: scheme,
+ Version: version,
+ stateSigningKey: cfg.JWTSecret,
+ }
+}
+
+// HandleJWKS handles the JWKS endpoint for proxy mode
+func (h *OAuth2Handler) HandleJWKS(w http.ResponseWriter, r *http.Request) {
+ // Defense in depth: Check OAuth mode
+ if h.config.Mode == "native" {
+ http.Error(w, "JWKS endpoint disabled in native mode", http.StatusNotFound)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "public, max-age=300") // Cache for 5 minutes
+
+ if r.Method != "GET" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Proxy JWKS from upstream OAuth provider
+ var jwksURL string
+ switch h.config.Provider {
+ case "okta":
+ // Use Okta's standard JWKS path
+ jwksURL = fmt.Sprintf("%s/oauth2/v1/keys", h.config.Issuer)
+ case "google":
+ jwksURL = "https://www.googleapis.com/oauth2/v3/certs"
+ case "azure":
+ jwksURL = fmt.Sprintf("%s/discovery/v2.0/keys", h.config.Issuer)
+ case "hmac":
+ // HMAC doesn't use JWKS, return empty key set
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"keys":[]}`))
+ return
+ default:
+ http.Error(w, "JWKS not supported for this provider", http.StatusNotImplemented)
+ return
+ }
+
+ // Create HTTP client with timeout
+ client := &http.Client{
+ Timeout: 10 * time.Second,
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
+ },
+ }
+
+ // Fetch JWKS from upstream provider
+ resp, err := client.Get(jwksURL)
+ if err != nil {
+ h.logger.Error("OAuth2: Failed to fetch JWKS from %s: %v", jwksURL, err)
+ http.Error(w, "Failed to fetch JWKS", http.StatusBadGateway)
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusOK {
+ h.logger.Error("OAuth2: JWKS endpoint returned status %d", resp.StatusCode)
+ http.Error(w, "JWKS endpoint error", http.StatusBadGateway)
+ return
+ }
+
+ // Copy response headers
+ w.Header().Set("Content-Type", resp.Header.Get("Content-Type"))
+ w.WriteHeader(http.StatusOK)
+
+ // Copy response body
+ if _, err := io.Copy(w, resp.Body); err != nil {
+ h.logger.Error("OAuth2: Failed to proxy JWKS response: %v", err)
+ }
+}
+
+// HandleAuthorize handles OAuth2 authorization requests with PKCE
+func (h *OAuth2Handler) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
+ // Defense in depth: Check OAuth mode
+ if h.config.Mode == "native" {
+ http.Error(w, "OAuth proxy disabled in native mode", http.StatusNotFound)
+ return
+ }
+ if r.Method != "GET" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Extract query parameters
+ query := r.URL.Query()
+
+ // PKCE parameters from client
+ codeChallenge := query.Get("code_challenge")
+ codeChallengeMethod := query.Get("code_challenge_method")
+ clientRedirectURI := query.Get("redirect_uri")
+ state := query.Get("state")
+ clientID := query.Get("client_id")
+
+ h.logger.Info("OAuth2: Authorization request - client_id: %s, redirect_uri: %s, code_challenge: %s",
+ clientID, clientRedirectURI, truncateString(codeChallenge, 10))
+
+ // Determine redirect URI strategy based on configuration
+ var redirectURI string
+ hasFixedRedirect := h.config.RedirectURIs != "" && !strings.Contains(h.config.RedirectURIs, ",")
+
+ if hasFixedRedirect {
+ // Fixed redirect mode: Use server's redirect URI to OAuth provider, proxy back to client
+ redirectURI = strings.TrimSpace(h.config.RedirectURIs)
+ h.logger.Info("OAuth2: Fixed redirect mode - using server URI: %s (will proxy to client: %s)", redirectURI, clientRedirectURI)
+
+ // Validate client redirect URI format and security
+ if clientRedirectURI == "" {
+ h.logger.Warn("SECURITY: Missing client redirect URI")
+ http.Error(w, "Missing redirect_uri", http.StatusBadRequest)
+ return
+ }
+
+ parsedURI, err := url.Parse(clientRedirectURI)
+ if err != nil {
+ h.logger.Warn("SECURITY: Invalid client redirect URI format: %s", clientRedirectURI)
+ http.Error(w, "Invalid redirect_uri format", http.StatusBadRequest)
+ return
+ }
+
+ // Additional security checks for client redirect URI
+ if parsedURI.Scheme != "http" && parsedURI.Scheme != "https" {
+ h.logger.Warn("SECURITY: Invalid redirect URI scheme: %s (must be http or https)", parsedURI.Scheme)
+ http.Error(w, "Invalid redirect_uri scheme", http.StatusBadRequest)
+ return
+ }
+
+ // Enforce HTTPS for non-localhost URIs
+ if parsedURI.Scheme == "http" && !isLocalhostURI(clientRedirectURI) {
+ h.logger.Warn("SECURITY: HTTP redirect URI not allowed for non-localhost: %s", clientRedirectURI)
+ http.Error(w, "HTTPS required for non-localhost redirect_uri", http.StatusBadRequest)
+ return
+ }
+
+ // Prevent fragment in redirect URI (OAuth 2.0 spec)
+ if parsedURI.Fragment != "" {
+ h.logger.Warn("SECURITY: Redirect URI contains fragment: %s", clientRedirectURI)
+ http.Error(w, "redirect_uri must not contain fragment", http.StatusBadRequest)
+ return
+ }
+
+ // Security: For fixed redirect mode, only allow localhost or loopback addresses
+ // This prevents open redirect attacks while still supporting development tools
+ if !isLocalhostURI(clientRedirectURI) {
+ h.logger.Warn("SECURITY: Fixed redirect mode only allows localhost URIs, rejecting: %s from %s", clientRedirectURI, r.RemoteAddr)
+ http.Error(w, "Fixed redirect mode only allows localhost redirect URIs for security. Use allowlist mode for production.", http.StatusBadRequest)
+ return
+ }
+
+ h.logger.Info("OAuth2: Validated localhost redirect URI for proxy: %s", clientRedirectURI)
+ } else if h.config.RedirectURIs != "" {
+ // Allowlist mode: Client's URI must be in allowlist, used directly (no proxy)
+ if !h.isValidRedirectURI(clientRedirectURI) {
+ h.logger.Warn("SECURITY: Redirect URI not in allowlist: %s from %s", clientRedirectURI, r.RemoteAddr)
+ http.Error(w, "Invalid redirect_uri", http.StatusBadRequest)
+ return
+ }
+ redirectURI = clientRedirectURI
+ h.logger.Info("OAuth2: Allowlist mode - using client URI from allowlist: %s", redirectURI)
+ } else {
+ // No configuration: Reject for security
+ h.logger.Warn("SECURITY: No redirect URIs configured, rejecting: %s from %s", clientRedirectURI, r.RemoteAddr)
+ http.Error(w, "Invalid redirect_uri", http.StatusBadRequest)
+ return
+ }
+
+ // Update OAuth2 config with redirect URI
+ h.oauth2Config.RedirectURL = redirectURI
+
+ // For fixed redirect mode, create signed state with client redirect URI
+ actualState := state
+ if hasFixedRedirect {
+ // Create state data with redirect URI
+ stateData := map[string]string{
+ "state": state,
+ "redirect": clientRedirectURI,
+ }
+
+ // Sign state for integrity protection
+ signedState, err := h.signState(stateData)
+ if err != nil {
+ h.logger.Error("OAuth2: Failed to sign state: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ actualState = signedState
+ h.logger.Info("OAuth2: Signed state for proxy callback (length: %d)", len(signedState))
+ }
+
+ // Create authorization URL
+ authURL := h.oauth2Config.AuthCodeURL(actualState, oauth2.AccessTypeOffline)
+
+ // Add PKCE parameters to the URL if provided
+ if codeChallenge != "" {
+ parsedURL, err := url.Parse(authURL)
+ if err != nil {
+ h.logger.Error("OAuth2: Failed to parse auth URL: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ query := parsedURL.Query()
+ query.Set("code_challenge", codeChallenge)
+ query.Set("code_challenge_method", codeChallengeMethod)
+
+ parsedURL.RawQuery = query.Encode()
+ authURL = parsedURL.String()
+ }
+
+ h.logger.Info("OAuth2: Redirecting to authorization URL: %s", authURL)
+ http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
+}
+
+// HandleCallback handles OAuth2 callback
+func (h *OAuth2Handler) HandleCallback(w http.ResponseWriter, r *http.Request) {
+ // Defense in depth: Check OAuth mode
+ if h.config.Mode == "native" {
+ http.Error(w, "OAuth proxy disabled in native mode", http.StatusNotFound)
+ return
+ }
+
+ if r.Method != "GET" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Extract parameters
+ code := r.URL.Query().Get("code")
+ state := r.URL.Query().Get("state")
+ errorParam := r.URL.Query().Get("error")
+
+ h.logger.Info("OAuth2: Callback received - code: %s, state: %s, error: %s",
+ truncateString(code, 10), state, errorParam)
+
+ // Handle OAuth errors
+ if errorParam != "" {
+ errorDesc := r.URL.Query().Get("error_description")
+ h.logger.Error("OAuth2: Authorization error: %s - %s", errorParam, errorDesc)
+ http.Error(w, fmt.Sprintf("Authorization failed: %s", errorDesc), http.StatusBadRequest)
+ return
+ }
+
+ if code == "" {
+ h.logger.Error("OAuth2: No authorization code received")
+ http.Error(w, "No authorization code received", http.StatusBadRequest)
+ return
+ }
+
+ // If using fixed redirect URI, handle proxy callback
+ if h.config.RedirectURIs != "" && !strings.Contains(h.config.RedirectURIs, ",") {
+ // Verify and decode signed state parameter
+ stateData, err := h.verifyState(state)
+ if err != nil {
+ h.logger.Warn("SECURITY: State verification failed: %v", err)
+ http.Error(w, "Invalid state parameter", http.StatusBadRequest)
+ return
+ }
+
+ // Extract original state and redirect URI
+ originalState, hasState := stateData["state"]
+ originalRedirectURI, hasRedirect := stateData["redirect"]
+
+ if hasState && hasRedirect {
+ // Re-validate redirect URI for defense in depth
+ // Even though state is HMAC-signed, validate the redirect URI is localhost
+ if !isLocalhostURI(originalRedirectURI) {
+ h.logger.Warn("SECURITY: Callback redirect URI is not localhost (possible key compromise): %s", originalRedirectURI)
+ http.Error(w, "Invalid redirect URI in state", http.StatusBadRequest)
+ return
+ }
+
+ h.logger.Info("OAuth2: State verified, proxying callback to localhost client: %s", originalRedirectURI)
+
+ // Build proxy callback URL
+ proxyURL := fmt.Sprintf("%s?code=%s&state=%s", originalRedirectURI, code, originalState)
+ http.Redirect(w, r, proxyURL, http.StatusFound)
+ return
+ }
+
+ h.logger.Error("OAuth2: State missing required fields")
+ http.Error(w, "Invalid state format", http.StatusBadRequest)
+ return
+ }
+
+ // For non-fixed redirect mode or as fallback, show success page
+ h.showSuccessPage(w, code, state)
+}
+
+// HandleToken handles OAuth2 token exchange
+func (h *OAuth2Handler) HandleToken(w http.ResponseWriter, r *http.Request) {
+ // Defense in depth: Check OAuth mode
+ if h.config.Mode == "native" {
+ http.Error(w, "OAuth proxy disabled in native mode", http.StatusNotFound)
+ return
+ }
+
+ // Add CORS headers for browser-based MCP clients
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Authorization, *")
+ w.Header().Set("Access-Control-Max-Age", "86400")
+
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ h.logger.Info("OAuth2: Token exchange request from %s", r.RemoteAddr)
+
+ // Parse form data
+ if err := r.ParseForm(); err != nil {
+ h.logger.Error("OAuth2: Failed to parse form: %v", err)
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ // Extract parameters
+ grantType := r.FormValue("grant_type")
+ code := r.FormValue("code")
+ clientRedirectURI := r.FormValue("redirect_uri")
+ clientID := r.FormValue("client_id")
+ codeVerifier := r.FormValue("code_verifier")
+
+ h.logger.Info("OAuth2: Token request - grant_type: %s, client_id: %s, redirect_uri: %s, code: %s",
+ grantType, clientID, clientRedirectURI, truncateString(code, 10))
+
+ // Validate parameters
+ if code == "" {
+ h.logger.Error("OAuth2: Missing authorization code")
+ http.Error(w, "Missing authorization code", http.StatusBadRequest)
+ return
+ }
+
+ if grantType != "authorization_code" {
+ h.logger.Error("OAuth2: Unsupported grant type: %s", grantType)
+ http.Error(w, "Unsupported grant type", http.StatusBadRequest)
+ return
+ }
+
+ // Set redirect URI for token exchange
+ redirectURI := clientRedirectURI
+ if h.config.RedirectURIs != "" && !strings.Contains(h.config.RedirectURIs, ",") {
+ redirectURI = strings.TrimSpace(h.config.RedirectURIs)
+ h.logger.Info("OAuth2: Token exchange using fixed redirect URI: %s", redirectURI)
+ }
+
+ h.oauth2Config.RedirectURL = redirectURI
+
+ // For PKCE, we need to manually add the code_verifier to the token exchange
+ // Since oauth2 library doesn't support PKCE directly, we'll use a custom approach
+ ctx := context.Background()
+
+ // Create custom HTTP client for token exchange with PKCE
+ if codeVerifier != "" {
+ // Create a custom client that adds code_verifier to the token request
+ customClient := &http.Client{
+ Transport: &pkceTransport{
+ base: http.DefaultTransport,
+ codeVerifier: codeVerifier,
+ },
+ }
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, customClient)
+ }
+
+ // Exchange code for tokens
+ token, err := h.oauth2Config.Exchange(ctx, code)
+ if err != nil {
+ h.logger.Error("OAuth2: Token exchange failed: %v", err)
+ http.Error(w, "Token exchange failed", http.StatusInternalServerError)
+ return
+ }
+
+ h.logger.Info("OAuth2: Token exchange successful")
+
+ // Build response
+ response := map[string]interface{}{
+ "access_token": token.AccessToken,
+ "token_type": token.TokenType,
+ "expires_in": int(time.Until(token.Expiry).Seconds()),
+ }
+
+ // Add optional fields
+ if token.RefreshToken != "" {
+ response["refresh_token"] = token.RefreshToken
+ }
+
+ // Add ID token if present
+ if idToken, ok := token.Extra("id_token").(string); ok {
+ response["id_token"] = idToken
+ }
+
+ // Add scope if present
+ if scope, ok := token.Extra("scope").(string); ok {
+ response["scope"] = scope
+ }
+
+ // Send response
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "no-store")
+ w.Header().Set("Pragma", "no-cache")
+ w.WriteHeader(http.StatusOK)
+
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ h.logger.Error("OAuth2: Failed to encode token response: %v", err)
+ }
+}
+
+// showSuccessPage displays a success page after OAuth completion
+func (h *OAuth2Handler) showSuccessPage(w http.ResponseWriter, code, state string) {
+ // Log authorization details server-side (truncated for security)
+ h.logger.Info("OAuth2: Authorization successful - code: %s, state: %s",
+ truncateString(code, 10), truncateString(state, 10))
+
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.Header().Set("Cache-Control", "no-store")
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.WriteHeader(http.StatusOK)
+ _, _ = fmt.Fprintf(w, `
+
+
+
+
+ OAuth2 Success
+
+
+ Authentication Successful!
+ You have been successfully authenticated.
+ You can now close this window and return to your application.
+
+ `)
+}
+
+// truncateString safely truncates a string for logging
+func truncateString(s string, maxLen int) string {
+ if len(s) <= maxLen {
+ return s
+ }
+ return s[:maxLen] + "..."
+}
+
+// pkceTransport adds PKCE code_verifier to token exchange requests
+type pkceTransport struct {
+ base http.RoundTripper
+ codeVerifier string
+}
+
+// RoundTrip implements the RoundTripper interface
+func (p *pkceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Only modify POST requests to token endpoint
+ if req.Method == "POST" && strings.Contains(req.URL.Path, "/token") {
+ // Read the existing body
+ defer func() {
+ if closeErr := req.Body.Close(); closeErr != nil {
+ // Note: pkceTransport doesn't have access to h.logger, using standard log
+ log.Printf("Warning: failed to close request body: %v", closeErr)
+ }
+ }()
+ body, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the form data
+ values, err := url.ParseQuery(string(body))
+ if err != nil {
+ return nil, err
+ }
+
+ // Add code_verifier if not already present
+ if values.Get("code_verifier") == "" && p.codeVerifier != "" {
+ values.Set("code_verifier", p.codeVerifier)
+ }
+
+ // Create new body with code_verifier
+ newBody := strings.NewReader(values.Encode())
+ req.Body = io.NopCloser(newBody)
+ req.ContentLength = int64(len(values.Encode()))
+ }
+
+ return p.base.RoundTrip(req)
+}
+
+// getEnv gets environment variable with default value
+func getEnv(key, def string) string {
+ if v, ok := os.LookupEnv(key); ok {
+ return v
+ }
+ return def
+}
+
+// signState signs state data with HMAC-SHA256 for integrity protection
+func (h *OAuth2Handler) signState(stateData map[string]string) (string, error) {
+ // Create deterministic string for signing
+ dataToSign := ""
+ if state, ok := stateData["state"]; ok {
+ dataToSign += "state=" + state + "&"
+ }
+ if redirect, ok := stateData["redirect"]; ok {
+ dataToSign += "redirect=" + redirect
+ }
+
+ // Create HMAC signature
+ mac := hmac.New(sha256.New, h.config.stateSigningKey)
+ mac.Write([]byte(dataToSign))
+ signature := hex.EncodeToString(mac.Sum(nil))
+
+ // Add signature to state data
+ stateData["sig"] = signature
+ signedData, err := json.Marshal(stateData)
+ if err != nil {
+ return "", fmt.Errorf("failed to marshal signed state: %w", err)
+ }
+
+ // Base64 encode for URL safety
+ return base64.URLEncoding.EncodeToString(signedData), nil
+}
+
+// verifyState verifies and decodes HMAC-signed state parameter
+func (h *OAuth2Handler) verifyState(encodedState string) (map[string]string, error) {
+ // Base64 decode
+ decodedState, err := base64.URLEncoding.DecodeString(encodedState)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode state: %w", err)
+ }
+
+ // Unmarshal state data
+ var stateData map[string]string
+ if err := json.Unmarshal(decodedState, &stateData); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal state: %w", err)
+ }
+
+ // Extract signature
+ receivedSig, ok := stateData["sig"]
+ if !ok {
+ return nil, fmt.Errorf("state missing signature")
+ }
+ delete(stateData, "sig") // Remove for verification
+
+ // Recalculate signature using same deterministic approach
+ dataToSign := ""
+ if state, ok := stateData["state"]; ok {
+ dataToSign += "state=" + state + "&"
+ }
+ if redirect, ok := stateData["redirect"]; ok {
+ dataToSign += "redirect=" + redirect
+ }
+
+ mac := hmac.New(sha256.New, h.config.stateSigningKey)
+ mac.Write([]byte(dataToSign))
+ expectedSig := hex.EncodeToString(mac.Sum(nil))
+
+ // Verify signature using constant-time comparison
+ if !hmac.Equal([]byte(receivedSig), []byte(expectedSig)) {
+ return nil, fmt.Errorf("invalid state signature - possible tampering detected")
+ }
+
+ return stateData, nil
+}
+
+// isLocalhostURI checks if URI is localhost for development
+func isLocalhostURI(uri string) bool {
+ parsedURI, err := url.Parse(uri)
+ if err != nil {
+ return false
+ }
+
+ hostname := strings.ToLower(parsedURI.Hostname())
+ return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1"
+}
+
+// isValidRedirectURI validates redirect URI against allowlist for security
+func (h *OAuth2Handler) isValidRedirectURI(uri string) bool {
+ if h.config.RedirectURIs == "" {
+ // No redirect URIs configured - reject all redirects for security
+ h.logger.Warn("WARNING: No OAuth redirect URIs configured, rejecting redirect: %s", uri)
+ return false
+ }
+
+ // Parse allowlist
+ allowedURIs := strings.Split(h.config.RedirectURIs, ",")
+ for _, allowed := range allowedURIs {
+ allowed = strings.TrimSpace(allowed)
+ if allowed != "" && uri == allowed {
+ return true
+ }
+ }
+
+ return false
+}
+
+// validateOAuthParams performs basic input validation to prevent abuse
+func (h *OAuth2Handler) validateOAuthParams(r *http.Request) error {
+ // Basic length validation to prevent abuse
+ if code := r.FormValue("code"); len(code) > 512 {
+ return fmt.Errorf("invalid code parameter length")
+ }
+ if state := r.FormValue("state"); len(state) > 256 {
+ return fmt.Errorf("invalid state parameter length")
+ }
+ if challenge := r.FormValue("code_challenge"); len(challenge) > 256 {
+ return fmt.Errorf("invalid code_challenge parameter length")
+ }
+ return nil
+}
+
+// addSecurityHeaders adds essential security headers for OAuth endpoints
+func (h *OAuth2Handler) addSecurityHeaders(w http.ResponseWriter) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("X-Frame-Options", "DENY")
+ w.Header().Set("Cache-Control", "no-store, no-cache, max-age=0")
+ w.Header().Set("Pragma", "no-cache")
+}
diff --git a/integration_test.go b/integration_test.go
new file mode 100644
index 0000000..9079323
--- /dev/null
+++ b/integration_test.go
@@ -0,0 +1,351 @@
+package oauth
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/mark3labs/mcp-go/mcp"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+ "github.com/tuannvm/oauth-mcp-proxy/provider"
+)
+
+// TestIntegration validates core architecture and integration.
+// Tests:
+// - provider/ package isolation
+// - Config conversion (root → provider)
+// - Server struct with instance-scoped state
+// - Middleware integration with MCP server
+// - Backward compatibility (User type re-export)
+func TestIntegration(t *testing.T) {
+ t.Run("ProviderPackageIsolation", func(t *testing.T) {
+ // Test that provider package has its own Config/User/Logger types
+ // and doesn't import root package
+
+ cfg := &provider.Config{
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ validator := &provider.HMACValidator{}
+ if err := validator.Initialize(cfg); err != nil {
+ t.Fatalf("provider.HMACValidator.Initialize failed: %v", err)
+ }
+
+ // Create test token
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "email": "test@example.com",
+ "aud": cfg.Audience,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, _ := token.SignedString(cfg.JWTSecret)
+
+ // Validate token using provider package directly
+ user, err := validator.ValidateToken(context.Background(), tokenString)
+ if err != nil {
+ t.Fatalf("ValidateToken failed: %v", err)
+ }
+
+ if user.Subject != "test-user" {
+ t.Errorf("Expected subject 'test-user', got '%s'", user.Subject)
+ }
+
+ t.Logf("✅ provider package works independently")
+ })
+
+ t.Run("ConfigConversion", func(t *testing.T) {
+ // Test root Config → provider.Config conversion
+
+ rootCfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ ClientID: "",
+ ServerURL: "",
+ RedirectURIs: "",
+ }
+
+ // createValidator converts root Config → provider.Config
+ validator, err := createValidator(rootCfg, &defaultLogger{})
+ if err != nil {
+ t.Fatalf("createValidator failed: %v", err)
+ }
+
+ // Validator should be initialized and ready
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "email": "test@example.com",
+ "aud": rootCfg.Audience,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, _ := token.SignedString(rootCfg.JWTSecret)
+
+ user, err := validator.ValidateToken(context.Background(), tokenString)
+ if err != nil {
+ t.Fatalf("ValidateToken after conversion failed: %v", err)
+ }
+
+ if user.Subject != "test-user" {
+ t.Errorf("Expected subject 'test-user', got '%s'", user.Subject)
+ }
+
+ t.Logf("✅ Config conversion works correctly")
+ })
+
+ t.Run("ServerInstanceScoped", func(t *testing.T) {
+ // Test that Server struct has instance-scoped cache (not global)
+
+ cfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ // Create two servers with same config
+ server1, err := NewServer(cfg)
+ if err != nil {
+ t.Fatalf("NewServer failed: %v", err)
+ }
+
+ server2, err := NewServer(cfg)
+ if err != nil {
+ t.Fatalf("NewServer failed: %v", err)
+ }
+
+ // Verify they have different cache instances
+ if server1.cache == server2.cache {
+ t.Errorf("Server instances share same cache (should be instance-scoped)")
+ }
+
+ t.Logf("✅ Server has instance-scoped cache")
+ })
+
+ t.Run("MiddlewareIntegration", func(t *testing.T) {
+ // Test complete middleware integration with MCP server
+
+ cfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ server, err := NewServer(cfg)
+ if err != nil {
+ t.Fatalf("NewServer failed: %v", err)
+ }
+
+ // Get middleware
+ middleware := server.Middleware()
+
+ // Create test MCP server
+ mcpServer := mcpserver.NewMCPServer("Test Server", "1.0.0")
+
+ // Handler that checks user context
+ var capturedUser *User
+ testHandler := func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ user, ok := GetUserFromContext(ctx)
+ if ok {
+ capturedUser = user
+ }
+ return mcp.NewToolResultText("ok"), nil
+ }
+
+ // Wrap with middleware
+ protectedHandler := middleware(testHandler)
+
+ // Add to MCP server
+ mcpServer.AddTool(
+ mcp.Tool{
+ Name: "test",
+ Description: "Test tool",
+ InputSchema: mcp.ToolInputSchema{
+ Type: "object",
+ Properties: map[string]interface{}{},
+ },
+ },
+ protectedHandler,
+ )
+
+ // Generate token
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user-123",
+ "email": "test@example.com",
+ "preferred_username": "testuser",
+ "aud": cfg.Audience,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, _ := token.SignedString(cfg.JWTSecret)
+
+ // Create context with token
+ ctx := WithOAuthToken(context.Background(), tokenString)
+
+ // Call protected handler
+ result, err := protectedHandler(ctx, mcp.CallToolRequest{})
+
+ if err != nil {
+ t.Fatalf("Protected handler failed: %v", err)
+ }
+
+ if result == nil {
+ t.Fatal("Expected result, got nil")
+ }
+
+ // Verify user was extracted
+ if capturedUser == nil {
+ t.Fatal("User was not extracted from context")
+ }
+
+ if capturedUser.Subject != "test-user-123" {
+ t.Errorf("Expected subject 'test-user-123', got '%s'", capturedUser.Subject)
+ }
+
+ if capturedUser.Email != "test@example.com" {
+ t.Errorf("Expected email 'test@example.com', got '%s'", capturedUser.Email)
+ }
+
+ if capturedUser.Username != "testuser" {
+ t.Errorf("Expected username 'testuser', got '%s'", capturedUser.Username)
+ }
+
+ t.Logf("✅ Middleware integration works end-to-end")
+ })
+
+ t.Run("UserTypeReexport", func(t *testing.T) {
+ // Test that User type is re-exported from root for backward compatibility
+
+ var rootUser *User
+ var providerUser *provider.User
+
+ // Should be assignable (type alias)
+ rootUser = &User{
+ Subject: "test",
+ Username: "test",
+ Email: "test@example.com",
+ }
+
+ providerUser = rootUser // Should compile (type alias)
+
+ if providerUser.Subject != "test" {
+ t.Errorf("Type alias not working correctly")
+ }
+
+ t.Logf("✅ User type re-export works (backward compatible)")
+ })
+}
+
+// TestValidatorIntegration validates provider package validator integration.
+// Tests HMAC and OIDC validators work correctly with the provider package.
+func TestValidatorIntegration(t *testing.T) {
+ t.Run("HMACValidator", func(t *testing.T) {
+ cfg := &provider.Config{
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ v := &provider.HMACValidator{}
+ if err := v.Initialize(cfg); err != nil {
+ t.Fatalf("Initialize failed: %v", err)
+ }
+
+ // Valid token
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "user123",
+ "email": "user@example.com",
+ "aud": cfg.Audience,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, _ := token.SignedString(cfg.JWTSecret)
+
+ user, err := v.ValidateToken(context.Background(), tokenString)
+ if err != nil {
+ t.Fatalf("ValidateToken failed: %v", err)
+ }
+
+ if user.Subject != "user123" {
+ t.Errorf("Expected subject 'user123', got '%s'", user.Subject)
+ }
+
+ t.Logf("✅ HMACValidator works in provider package")
+ })
+
+ t.Run("OIDCValidator_DirectTest", func(t *testing.T) {
+ // Test OIDCValidator audience validation logic directly
+ _ = &provider.OIDCValidator{}
+
+ testCases := []struct {
+ name string
+ claims jwt.MapClaims
+ audience string
+ expectErr bool
+ }{
+ {
+ name: "valid string audience",
+ claims: jwt.MapClaims{
+ "aud": "api://test",
+ "sub": "user123",
+ },
+ audience: "api://test",
+ expectErr: false,
+ },
+ {
+ name: "invalid string audience",
+ claims: jwt.MapClaims{
+ "aud": "api://wrong",
+ "sub": "user123",
+ },
+ audience: "api://test",
+ expectErr: true,
+ },
+ {
+ name: "valid array audience",
+ claims: jwt.MapClaims{
+ "aud": []interface{}{"api://test", "api://other"},
+ "sub": "user123",
+ },
+ audience: "api://test",
+ expectErr: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Use reflection to set audience (private field)
+ // This is just for testing the validateAudience logic
+ cfg := &provider.Config{
+ Provider: "okta",
+ Issuer: "https://test.okta.com",
+ Audience: tc.audience,
+ }
+
+ // OIDCValidator would normally be initialized with provider
+ // Here we're just testing config initialization
+ v := &provider.OIDCValidator{}
+ err := v.Initialize(cfg)
+ // Expected to fail (no real OIDC provider), but config structure is valid
+ _ = err
+
+ t.Logf("✅ OIDCValidator config structure accepted")
+ })
+ }
+ })
+}
diff --git a/logger.go b/logger.go
new file mode 100644
index 0000000..48e1987
--- /dev/null
+++ b/logger.go
@@ -0,0 +1,46 @@
+package oauth
+
+import "log"
+
+// Logger interface for pluggable logging.
+// Implement this interface to integrate oauth-mcp-proxy with your application's
+// logging system (e.g., zap, logrus, slog). If not provided in Config, a default
+// logger using log.Printf will be used.
+//
+// Example:
+//
+// type MyLogger struct{ logger *zap.Logger }
+// func (l *MyLogger) Info(msg string, args ...interface{}) {
+// l.logger.Sugar().Infof(msg, args...)
+// }
+// // ... implement Debug, Warn, Error
+//
+// cfg := &oauth.Config{
+// Provider: "okta",
+// Logger: &MyLogger{logger: zapLogger},
+// }
+type Logger interface {
+ Debug(msg string, args ...interface{}) // Debug-level logging for detailed troubleshooting
+ Info(msg string, args ...interface{}) // Info-level logging for normal OAuth operations
+ Warn(msg string, args ...interface{}) // Warn-level logging for security violations
+ Error(msg string, args ...interface{}) // Error-level logging for OAuth failures
+}
+
+// defaultLogger implements Logger using standard library log
+type defaultLogger struct{}
+
+func (l *defaultLogger) Debug(msg string, args ...interface{}) {
+ log.Printf("[DEBUG] "+msg, args...)
+}
+
+func (l *defaultLogger) Info(msg string, args ...interface{}) {
+ log.Printf("[INFO] "+msg, args...)
+}
+
+func (l *defaultLogger) Warn(msg string, args ...interface{}) {
+ log.Printf("[WARN] "+msg, args...)
+}
+
+func (l *defaultLogger) Error(msg string, args ...interface{}) {
+ log.Printf("[ERROR] "+msg, args...)
+}
diff --git a/metadata.go b/metadata.go
new file mode 100644
index 0000000..81c4e56
--- /dev/null
+++ b/metadata.go
@@ -0,0 +1,322 @@
+package oauth
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+)
+
+// HandleMetadata handles the legacy OAuth metadata endpoint for MCP compliance
+func (h *OAuth2Handler) HandleMetadata(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "public, max-age=300") // Cache for 5 minutes
+
+ if r.Method != "GET" {
+ http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Return OAuth metadata based on configuration
+ if !h.config.Enabled {
+ w.WriteHeader(http.StatusOK)
+ _, _ = fmt.Fprintf(w, `{
+ "oauth_enabled": false,
+ "authentication_methods": ["none"],
+ "mcp_version": "1.0.0"
+ }`)
+ return
+ }
+
+ // Create provider-specific metadata
+ metadata := map[string]interface{}{
+ "oauth_enabled": true,
+ "authentication_methods": []string{"bearer_token"},
+ "token_types": []string{"JWT"},
+ "token_validation": "server_side",
+ "supported_flows": []string{"claude_code", "mcp_remote"},
+ "mcp_version": "1.0.0",
+ "server_version": h.config.Version,
+ "provider": h.config.Provider,
+
+ // Add OIDC discovery fields for MCP client compatibility
+ "issuer": h.config.MCPURL,
+ "authorization_endpoint": fmt.Sprintf("%s/oauth/authorize", h.config.MCPURL),
+ "token_endpoint": fmt.Sprintf("%s/oauth/token", h.config.MCPURL),
+ "registration_endpoint": fmt.Sprintf("%s/oauth/register", h.config.MCPURL),
+ "response_types_supported": []string{"code"},
+ "response_modes_supported": []string{"query"},
+ "grant_types_supported": []string{"authorization_code"},
+ }
+
+ // Add provider-specific metadata
+ switch h.config.Provider {
+ case "hmac":
+ metadata["validation_method"] = "hmac_sha256"
+ metadata["signature_algorithm"] = "HS256"
+ metadata["requires_secret"] = true
+ case "okta", "google", "azure":
+ metadata["validation_method"] = "oidc_jwks"
+ metadata["signature_algorithm"] = "RS256"
+ metadata["requires_secret"] = false
+ if h.config.Issuer != "" {
+ metadata["issuer"] = h.config.Issuer
+ metadata["jwks_uri"] = h.config.Issuer + "/.well-known/jwks.json"
+ }
+ if h.config.Audience != "" {
+ metadata["audience"] = h.config.Audience
+ }
+ }
+
+ // Encode and send response
+ w.WriteHeader(http.StatusOK)
+ if err := json.NewEncoder(w).Encode(metadata); err != nil {
+ h.logger.Error("OAuth2: Error encoding metadata: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+// HandleAuthorizationServerMetadata handles the standard OAuth 2.0 Authorization Server Metadata endpoint
+func (h *OAuth2Handler) HandleAuthorizationServerMetadata(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "public, max-age=300") // Cache for 5 minutes
+ // Add CORS headers for browser-based MCP clients like MCP Inspector
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Authorization, *")
+ w.Header().Set("Access-Control-Max-Age", "86400")
+
+ switch r.Method {
+ case "OPTIONS", "HEAD":
+ w.WriteHeader(http.StatusOK)
+ return
+ case "GET":
+ // Continue to metadata response
+ default:
+ http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Return OAuth 2.0 Authorization Server Metadata (RFC 8414)
+ metadata := h.GetAuthorizationServerMetadata()
+
+ // Encode and send response
+ w.WriteHeader(http.StatusOK)
+ if err := json.NewEncoder(w).Encode(metadata); err != nil {
+ h.logger.Error("OAuth2: Error encoding Authorization Server metadata: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+// HandleProtectedResourceMetadata handles the OAuth 2.0 Protected Resource Metadata endpoint
+func (h *OAuth2Handler) HandleProtectedResourceMetadata(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "public, max-age=300") // Cache for 5 minutes
+
+ if r.Method != "GET" {
+ http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Return OAuth 2.0 Protected Resource Metadata (RFC 9728)
+ // Point to authorization server based on mode
+ var authServer string
+ if h.config.Mode == "proxy" {
+ // Proxy mode: MCP server acts as authorization server
+ authServer = h.config.MCPURL
+ } else {
+ // Native mode: Point directly to OAuth provider
+ authServer = h.config.Issuer
+ }
+
+ metadata := map[string]interface{}{
+ "resource": h.config.MCPURL,
+ "authorization_servers": []string{authServer},
+ "bearer_methods_supported": []string{"header"},
+ "resource_signing_alg_values_supported": []string{"RS256"},
+ "resource_documentation": fmt.Sprintf("%s/docs", h.config.MCPURL),
+ "resource_policy_uri": fmt.Sprintf("%s/policy", h.config.MCPURL),
+ "resource_tos_uri": fmt.Sprintf("%s/tos", h.config.MCPURL),
+ }
+
+ // Encode and send response
+ w.WriteHeader(http.StatusOK)
+ if err := json.NewEncoder(w).Encode(metadata); err != nil {
+ h.logger.Error("OAuth2: Error encoding Protected Resource metadata: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+// HandleRegister handles OAuth dynamic client registration for mcp-remote
+func (h *OAuth2Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
+ // Add CORS headers for browser-based MCP clients
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Authorization, *")
+ w.Header().Set("Access-Control-Max-Age", "86400")
+
+ if r.Method == "OPTIONS" {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Parse the registration request
+ var regRequest map[string]interface{}
+ if err := json.NewDecoder(r.Body).Decode(®Request); err != nil {
+ h.logger.Error("OAuth2: Failed to parse registration request: %v", err)
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
+ return
+ }
+
+ h.logger.Info("OAuth2: Registration request: %+v", regRequest)
+
+ // Accept any client registration from mcp-remote
+ // Return our pre-configured client_id
+ response := map[string]interface{}{
+ "client_id": h.config.ClientID,
+ "client_secret": "", // Public client, no secret
+ "client_id_issued_at": time.Now().Unix(),
+ "grant_types": []string{"authorization_code", "refresh_token"},
+ "response_types": []string{"code"},
+ "token_endpoint_auth_method": "none",
+ "application_type": "native",
+ "client_name": regRequest["client_name"],
+ }
+
+ // Allow clients to register their own redirect URIs (needed for mcp-remote)
+ if redirectUris, ok := regRequest["redirect_uris"]; ok {
+ response["redirect_uris"] = redirectUris
+ h.logger.Info("OAuth2: Registration allowing client redirect URIs: %v", redirectUris)
+ } else if h.config.RedirectURIs != "" && !strings.Contains(h.config.RedirectURIs, ",") {
+ // Fallback to fixed redirect URI if no client URIs provided (single URI only)
+ trimmedURI := strings.TrimSpace(h.config.RedirectURIs)
+ response["redirect_uris"] = []string{trimmedURI}
+ h.logger.Info("OAuth2: Registration response using fixed redirect URI: %s", trimmedURI)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ h.logger.Error("OAuth2: Failed to encode registration response: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+// HandleCallbackRedirect handles the /callback redirect for Claude Code compatibility
+func (h *OAuth2Handler) HandleCallbackRedirect(w http.ResponseWriter, r *http.Request) {
+ // Preserve all query parameters when redirecting
+ redirectURL := "/oauth/callback"
+ if r.URL.RawQuery != "" {
+ redirectURL += "?" + r.URL.RawQuery
+ }
+ http.Redirect(w, r, redirectURL, http.StatusFound)
+}
+
+// HandleOIDCDiscovery handles the OIDC discovery endpoint for MCP client compatibility
+func (h *OAuth2Handler) HandleOIDCDiscovery(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "public, max-age=300")
+
+ if r.Method != "GET" {
+ http.Error(w, `{"error":"Method not allowed"}`, http.StatusMethodNotAllowed)
+ return
+ }
+
+ h.logger.Info("OAuth2: OIDC discovery request from %s", r.RemoteAddr)
+
+ // Return OIDC Discovery metadata with existing /oauth/ endpoints
+ metadata := map[string]interface{}{
+ "issuer": h.config.MCPURL,
+ "authorization_endpoint": fmt.Sprintf("%s/oauth/authorize", h.config.MCPURL),
+ "token_endpoint": fmt.Sprintf("%s/oauth/token", h.config.MCPURL),
+ "registration_endpoint": fmt.Sprintf("%s/oauth/register", h.config.MCPURL),
+ "response_types_supported": []string{"code"},
+ "response_modes_supported": []string{"query"},
+ "grant_types_supported": []string{"authorization_code"},
+ "token_endpoint_auth_methods_supported": []string{"none"},
+ "code_challenge_methods_supported": []string{"plain", "S256"},
+ "subject_types_supported": []string{"public"},
+ "scopes_supported": []string{"openid", "profile", "email"},
+ }
+
+ // Add provider-specific fields
+ if h.config.Audience != "" {
+ metadata["audience"] = h.config.Audience
+ }
+
+ // Add provider-specific signing algorithm information
+ switch h.config.Provider {
+ case "hmac":
+ metadata["id_token_signing_alg_values_supported"] = []string{"HS256"}
+ case "okta", "google", "azure":
+ metadata["id_token_signing_alg_values_supported"] = []string{"RS256"}
+ metadata["jwks_uri"] = fmt.Sprintf("%s/.well-known/jwks.json", h.config.MCPURL)
+ }
+
+ h.logger.Info("OAuth2: Returning OIDC discovery metadata for issuer: %s", h.config.MCPURL)
+
+ w.WriteHeader(http.StatusOK)
+ if err := json.NewEncoder(w).Encode(metadata); err != nil {
+ h.logger.Error("OAuth2: Error encoding OIDC discovery metadata: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+}
+
+// GetAuthorizationServerMetadata returns the OAuth 2.0 Authorization Server Metadata
+// with conditional responses based on OAuth mode
+func (h *OAuth2Handler) GetAuthorizationServerMetadata() map[string]interface{} {
+ var metadata map[string]interface{}
+
+ if h.config.Mode == "native" {
+ // Native mode: Point to OAuth provider directly
+ metadata = map[string]interface{}{
+ "issuer": h.config.Issuer, // OAuth provider issuer
+ "response_types_supported": []string{"code"},
+ "response_modes_supported": []string{"query"},
+ "grant_types_supported": []string{"authorization_code"},
+ "token_endpoint_auth_methods_supported": []string{"none"},
+ "code_challenge_methods_supported": []string{"plain", "S256"},
+ "scopes_supported": []string{"openid", "profile", "email"},
+ }
+
+ // Add provider-specific endpoints
+ switch h.config.Provider {
+ case "okta":
+ metadata["authorization_endpoint"] = fmt.Sprintf("%s/oauth2/v1/authorize", h.config.Issuer)
+ metadata["token_endpoint"] = fmt.Sprintf("%s/oauth2/v1/token", h.config.Issuer)
+ metadata["registration_endpoint"] = fmt.Sprintf("%s/oauth2/v1/clients", h.config.Issuer)
+ metadata["jwks_uri"] = fmt.Sprintf("%s/oauth2/v1/keys", h.config.Issuer)
+ case "google":
+ metadata["authorization_endpoint"] = "https://accounts.google.com/o/oauth2/v2/auth"
+ metadata["token_endpoint"] = "https://oauth2.googleapis.com/token"
+ metadata["jwks_uri"] = "https://www.googleapis.com/oauth2/v3/certs"
+ case "azure":
+ metadata["authorization_endpoint"] = fmt.Sprintf("%s/oauth2/v2.0/authorize", h.config.Issuer)
+ metadata["token_endpoint"] = fmt.Sprintf("%s/oauth2/v2.0/token", h.config.Issuer)
+ metadata["jwks_uri"] = fmt.Sprintf("%s/discovery/v2.0/keys", h.config.Issuer)
+ }
+ } else {
+ // Proxy mode: Point to MCP server endpoints
+ metadata = map[string]interface{}{
+ "issuer": h.config.MCPURL,
+ "authorization_endpoint": fmt.Sprintf("%s/oauth/authorize", h.config.MCPURL),
+ "token_endpoint": fmt.Sprintf("%s/oauth/token", h.config.MCPURL),
+ "registration_endpoint": fmt.Sprintf("%s/oauth/register", h.config.MCPURL),
+ "jwks_uri": fmt.Sprintf("%s/.well-known/jwks.json", h.config.MCPURL),
+ "response_types_supported": []string{"code"},
+ "response_modes_supported": []string{"query"},
+ "grant_types_supported": []string{"authorization_code"},
+ "token_endpoint_auth_methods_supported": []string{"none"},
+ "code_challenge_methods_supported": []string{"plain", "S256"},
+ "scopes_supported": []string{"openid", "profile", "email"},
+ }
+ }
+
+ return metadata
+}
diff --git a/metadata_test.go b/metadata_test.go
new file mode 100644
index 0000000..62d0a6d
--- /dev/null
+++ b/metadata_test.go
@@ -0,0 +1,231 @@
+package oauth
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestGetAuthorizationServerMetadata(t *testing.T) {
+ tests := []struct {
+ name string
+ mode string
+ provider string
+ issuer string
+ mcpURL string
+ checkFields []string
+ }{
+ {
+ name: "Native mode with Okta",
+ mode: "native",
+ provider: "okta",
+ issuer: "https://dev.okta.com",
+ mcpURL: "https://mcp.example.com",
+ checkFields: []string{"issuer", "authorization_endpoint", "token_endpoint", "jwks_uri"},
+ },
+ {
+ name: "Native mode with Google",
+ mode: "native",
+ provider: "google",
+ issuer: "https://accounts.google.com",
+ mcpURL: "https://mcp.example.com",
+ checkFields: []string{"issuer", "authorization_endpoint", "token_endpoint", "jwks_uri"},
+ },
+ {
+ name: "Proxy mode",
+ mode: "proxy",
+ provider: "okta",
+ issuer: "https://dev.okta.com",
+ mcpURL: "https://mcp.example.com",
+ checkFields: []string{"issuer", "authorization_endpoint", "token_endpoint", "jwks_uri"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ config := &OAuth2Config{
+ Mode: tt.mode,
+ Provider: tt.provider,
+ Issuer: tt.issuer,
+ MCPURL: tt.mcpURL,
+ }
+ handler := &OAuth2Handler{config: config, logger: &defaultLogger{}}
+
+ metadata := handler.GetAuthorizationServerMetadata()
+
+ // Check that required fields are present
+ for _, field := range tt.checkFields {
+ if _, exists := metadata[field]; !exists {
+ t.Errorf("Missing required field: %s", field)
+ }
+ }
+
+ // Verify mode-specific behavior
+ issuer := metadata["issuer"].(string)
+ authEndpoint := metadata["authorization_endpoint"].(string)
+
+ if tt.mode == "native" {
+ // Native mode should point to OAuth provider
+ if issuer != tt.issuer {
+ t.Errorf("Native mode issuer = %s, expected %s", issuer, tt.issuer)
+ }
+ if tt.provider == "okta" {
+ expectedAuth := tt.issuer + "/oauth2/v1/authorize"
+ if authEndpoint != expectedAuth {
+ t.Errorf("Native mode auth endpoint = %s, expected %s", authEndpoint, expectedAuth)
+ }
+ }
+ } else {
+ // Proxy mode should point to MCP server
+ if issuer != tt.mcpURL {
+ t.Errorf("Proxy mode issuer = %s, expected %s", issuer, tt.mcpURL)
+ }
+ expectedAuth := tt.mcpURL + "/oauth/authorize"
+ if authEndpoint != expectedAuth {
+ t.Errorf("Proxy mode auth endpoint = %s, expected %s", authEndpoint, expectedAuth)
+ }
+ }
+ })
+ }
+}
+
+func TestHandleAuthorizationServerMetadata(t *testing.T) {
+ config := &OAuth2Config{
+ Mode: "native",
+ Provider: "okta",
+ Issuer: "https://dev.okta.com",
+ MCPURL: "https://mcp.example.com",
+ }
+ handler := &OAuth2Handler{config: config, logger: &defaultLogger{}}
+
+ tests := []struct {
+ name string
+ method string
+ expectedStatus int
+ }{
+ {"GET request", "GET", http.StatusOK},
+ {"HEAD request", "HEAD", http.StatusOK},
+ {"OPTIONS request", "OPTIONS", http.StatusOK},
+ {"POST request", "POST", http.StatusMethodNotAllowed},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest(tt.method, "/.well-known/oauth-authorization-server", nil)
+
+ handler.HandleAuthorizationServerMetadata(recorder, req)
+
+ if recorder.Code != tt.expectedStatus {
+ t.Errorf("Status = %d, expected %d", recorder.Code, tt.expectedStatus)
+ }
+
+ // Check CORS headers are present
+ if origin := recorder.Header().Get("Access-Control-Allow-Origin"); origin != "*" {
+ t.Errorf("CORS Allow-Origin = %s, expected *", origin)
+ }
+
+ if tt.method == "GET" {
+ // Verify JSON response
+ var metadata map[string]interface{}
+ if err := json.Unmarshal(recorder.Body.Bytes(), &metadata); err != nil {
+ t.Errorf("Failed to parse JSON response: %v", err)
+ }
+
+ if issuer, exists := metadata["issuer"]; !exists {
+ t.Errorf("Missing issuer field in metadata")
+ } else if issuer != config.Issuer {
+ t.Errorf("Issuer = %s, expected %s", issuer, config.Issuer)
+ }
+ }
+ })
+ }
+}
+
+func TestHandleProtectedResourceMetadata(t *testing.T) {
+ config := &OAuth2Config{
+ Issuer: "https://dev.okta.com",
+ MCPURL: "https://mcp.example.com",
+ }
+ handler := &OAuth2Handler{config: config, logger: &defaultLogger{}}
+
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
+
+ handler.HandleProtectedResourceMetadata(recorder, req)
+
+ if recorder.Code != http.StatusOK {
+ t.Errorf("Status = %d, expected %d", recorder.Code, http.StatusOK)
+ }
+
+ var metadata map[string]interface{}
+ if err := json.Unmarshal(recorder.Body.Bytes(), &metadata); err != nil {
+ t.Errorf("Failed to parse JSON response: %v", err)
+ }
+
+ // Check required fields
+ if resource := metadata["resource"]; resource != config.MCPURL {
+ t.Errorf("Resource = %s, expected %s", resource, config.MCPURL)
+ }
+
+ authServers, exists := metadata["authorization_servers"]
+ if !exists {
+ t.Errorf("Missing authorization_servers field")
+ } else {
+ servers := authServers.([]interface{})
+ if len(servers) != 1 || servers[0] != config.Issuer {
+ t.Errorf("Authorization servers = %v, expected [%s]", servers, config.Issuer)
+ }
+ }
+}
+
+func TestHandleOIDCDiscovery(t *testing.T) {
+ config := &OAuth2Config{
+ MCPURL: "https://mcp.example.com",
+ Provider: "okta",
+ Audience: "https://api.example.com",
+ }
+ handler := &OAuth2Handler{
+ config: config,
+ logger: &defaultLogger{},
+ }
+
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "/.well-known/openid_configuration", nil)
+
+ handler.HandleOIDCDiscovery(recorder, req)
+
+ if recorder.Code != http.StatusOK {
+ t.Errorf("Status = %d, expected %d", recorder.Code, http.StatusOK)
+ }
+
+ var metadata map[string]interface{}
+ if err := json.Unmarshal(recorder.Body.Bytes(), &metadata); err != nil {
+ t.Errorf("Failed to parse JSON response: %v", err)
+ }
+
+ // Check required OIDC fields
+ requiredFields := []string{
+ "issuer",
+ "authorization_endpoint",
+ "token_endpoint",
+ "response_types_supported",
+ "subject_types_supported",
+ "id_token_signing_alg_values_supported",
+ }
+
+ for _, field := range requiredFields {
+ if _, exists := metadata[field]; !exists {
+ t.Errorf("Missing required OIDC field: %s", field)
+ }
+ }
+
+ if issuer := metadata["issuer"]; issuer != config.MCPURL {
+ t.Errorf("OIDC issuer = %s, expected %s", issuer, config.MCPURL)
+ }
+
+ if audience := metadata["audience"]; audience != config.Audience {
+ t.Errorf("OIDC audience = %s, expected %s", audience, config.Audience)
+ }
+}
diff --git a/middleware.go b/middleware.go
new file mode 100644
index 0000000..d4bce2b
--- /dev/null
+++ b/middleware.go
@@ -0,0 +1,252 @@
+package oauth
+
+import (
+ "context"
+ "crypto/sha256"
+ "fmt"
+ "log"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/tuannvm/oauth-mcp-proxy/provider"
+)
+
+// Re-export User from provider for backwards compatibility
+type User = provider.User
+
+// Context keys
+type contextKey string
+
+const (
+ oauthTokenKey contextKey = "oauth_token"
+ userContextKey contextKey = "user"
+)
+
+// TokenCache stores validated tokens to avoid re-validation
+type TokenCache struct {
+ mu sync.RWMutex
+ cache map[string]*CachedToken
+}
+
+// CachedToken represents a cached token validation result
+type CachedToken struct {
+ User *User
+ ExpiresAt time.Time
+}
+
+// WithOAuthToken adds an OAuth token to the context
+func WithOAuthToken(ctx context.Context, token string) context.Context {
+ return context.WithValue(ctx, oauthTokenKey, token)
+}
+
+// GetOAuthToken extracts an OAuth token from the context
+func GetOAuthToken(ctx context.Context) (string, bool) {
+ token, ok := ctx.Value(oauthTokenKey).(string)
+ return token, ok
+}
+
+// getCachedToken retrieves a cached token validation result
+func (tc *TokenCache) getCachedToken(tokenHash string) (*CachedToken, bool) {
+ tc.mu.RLock()
+
+ cached, exists := tc.cache[tokenHash]
+ if !exists {
+ tc.mu.RUnlock()
+ return nil, false
+ }
+
+ // Check if token is expired
+ if time.Now().After(cached.ExpiresAt) {
+ tc.mu.RUnlock()
+ // Schedule expired token deletion in a separate operation
+ go tc.deleteExpiredToken(tokenHash)
+ return nil, false
+ }
+
+ tc.mu.RUnlock()
+ return cached, true
+}
+
+// deleteExpiredToken safely deletes an expired token from the cache
+func (tc *TokenCache) deleteExpiredToken(tokenHash string) {
+ tc.mu.Lock()
+ defer tc.mu.Unlock()
+
+ // Double-check if token is still expired before deleting
+ if cached, exists := tc.cache[tokenHash]; exists && time.Now().After(cached.ExpiresAt) {
+ delete(tc.cache, tokenHash)
+ }
+}
+
+// setCachedToken stores a token validation result
+func (tc *TokenCache) setCachedToken(tokenHash string, user *User, expiresAt time.Time) {
+ tc.mu.Lock()
+ defer tc.mu.Unlock()
+
+ tc.cache[tokenHash] = &CachedToken{
+ User: user,
+ ExpiresAt: expiresAt,
+ }
+}
+
+// Middleware returns an authentication middleware for MCP tools.
+// Validates OAuth tokens, caches results, and adds authenticated user to context.
+//
+// The middleware:
+// 1. Extracts OAuth token from context (set by CreateHTTPContextFunc)
+// 2. Checks token cache (5-minute TTL)
+// 3. Validates token using configured provider if not cached
+// 4. Adds User to context via userContextKey
+// 5. Passes request to tool handler with authenticated context
+//
+// Use GetUserFromContext(ctx) in tool handlers to access authenticated user.
+//
+// Note: WithOAuth() returns this middleware wrapped as mcpserver.ServerOption.
+// Only call directly if using NewServer() for advanced use cases.
+func (s *Server) Middleware() func(server.ToolHandlerFunc) server.ToolHandlerFunc {
+ return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
+ return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ // Extract token from context (set by HTTP middleware)
+ tokenString, ok := GetOAuthToken(ctx)
+ if !ok {
+ s.logger.Info("No token found in context for tool: %s", req.Params.Name)
+ return nil, fmt.Errorf("authentication required: missing OAuth token")
+ }
+
+ // Create token hash for caching
+ tokenHash := fmt.Sprintf("%x", sha256.Sum256([]byte(tokenString)))
+
+ // Check cache first
+ if cached, exists := s.cache.getCachedToken(tokenHash); exists {
+ s.logger.Info("Using cached authentication for tool: %s (user: %s)", req.Params.Name, cached.User.Username)
+ ctx = context.WithValue(ctx, userContextKey, cached.User)
+ return next(ctx, req)
+ }
+
+ // Log token hash for debugging (prevents sensitive data exposure)
+ tokenHashFull := fmt.Sprintf("%x", sha256.Sum256([]byte(tokenString)))
+ tokenHashPreview := tokenHashFull[:16] + "..."
+ s.logger.Info("Validating token for tool %s (hash: %s)", req.Params.Name, tokenHashPreview)
+
+ // Validate token using configured provider (with request context for timeout/cancellation)
+ user, err := s.validator.ValidateToken(ctx, tokenString)
+ if err != nil {
+ s.logger.Error("Token validation failed for tool %s: %v", req.Params.Name, err)
+ return nil, fmt.Errorf("authentication failed: %w", err)
+ }
+
+ // Cache the validation result (expire in 5 minutes)
+ expiresAt := time.Now().Add(5 * time.Minute)
+ s.cache.setCachedToken(tokenHash, user, expiresAt)
+
+ // Add user to context for downstream handlers
+ ctx = context.WithValue(ctx, userContextKey, user)
+ s.logger.Info("Authenticated user %s for tool: %s (cached for 5 minutes)", user.Username, req.Params.Name)
+
+ return next(ctx, req)
+ }
+ }
+}
+
+// OAuthMiddleware creates an authentication middleware (legacy function for compatibility).
+//
+// Deprecated: Use WithOAuth() for new code. This function creates a temporary
+// Server instance for each call and doesn't support custom logging. Kept for
+// backward compatibility only.
+//
+// Modern usage:
+//
+// oauthOption, _ := oauth.WithOAuth(mux, &oauth.Config{...})
+// mcpServer := server.NewMCPServer("name", "1.0.0", oauthOption)
+func OAuthMiddleware(validator provider.TokenValidator, enabled bool) func(server.ToolHandlerFunc) server.ToolHandlerFunc {
+ // Create a temporary server for legacy compatibility
+ cache := &TokenCache{cache: make(map[string]*CachedToken)}
+ s := &Server{
+ validator: validator,
+ cache: cache,
+ logger: &defaultLogger{},
+ }
+
+ if !enabled {
+ // Return passthrough middleware
+ return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
+ return next
+ }
+ }
+
+ return s.Middleware()
+}
+
+// validateJWT is deprecated - use provider-based validation instead
+
+// GetUserFromContext extracts the authenticated user from context.
+// Returns the User and true if authentication succeeded, or nil and false otherwise.
+//
+// Example:
+//
+// func toolHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+// user, ok := oauth.GetUserFromContext(ctx)
+// if !ok {
+// return nil, fmt.Errorf("authentication required")
+// }
+// // Use user.Subject, user.Email, user.Username
+// return mcp.NewToolResultText("Hello, " + user.Username), nil
+// }
+func GetUserFromContext(ctx context.Context) (*User, bool) {
+ user, ok := ctx.Value(userContextKey).(*User)
+ return user, ok
+}
+
+// CreateHTTPContextFunc creates an HTTP context function that extracts OAuth tokens
+// from Authorization headers. Use with mcpserver.WithHTTPContextFunc() to enable
+// token extraction from HTTP requests.
+//
+// Example:
+//
+// streamableServer := mcpserver.NewStreamableHTTPServer(
+// mcpServer,
+// mcpserver.WithHTTPContextFunc(oauth.CreateHTTPContextFunc()),
+// )
+//
+// This extracts "Bearer " from Authorization header and adds it to context
+// via WithOAuthToken(). The OAuth middleware then retrieves it via GetOAuthToken().
+func CreateHTTPContextFunc() func(context.Context, *http.Request) context.Context {
+ return func(ctx context.Context, r *http.Request) context.Context {
+ // Extract Bearer token from Authorization header
+ authHeader := r.Header.Get("Authorization")
+ if strings.HasPrefix(authHeader, "Bearer ") {
+ token := strings.TrimPrefix(authHeader, "Bearer ")
+ // Clean any whitespace
+ token = strings.TrimSpace(token)
+ ctx = WithOAuthToken(ctx, token)
+ log.Printf("OAuth: Token extracted from request (length: %d)", len(token))
+ } else if authHeader != "" {
+ preview := authHeader
+ if len(authHeader) > 30 {
+ preview = authHeader[:30] + "..."
+ }
+ log.Printf("OAuth: Invalid Authorization header format: %s", preview)
+ }
+ return ctx
+ }
+}
+
+// CreateRequestAuthHook creates a server-level authentication hook for all MCP requests.
+//
+// Deprecated: This function cannot propagate context changes due to its signature limitation.
+// Use WithOAuth() instead, which properly handles context propagation via tool-level middleware.
+//
+// This function is a no-op that always returns nil. Authentication happens at the tool level
+// via Server.Middleware() which can properly propagate the authenticated user in context.
+func CreateRequestAuthHook(validator provider.TokenValidator) func(context.Context, interface{}, interface{}) error {
+ return func(ctx context.Context, id interface{}, message interface{}) error {
+ // This hook cannot propagate context changes due to its signature limitation.
+ // Authentication is handled by tool-level middleware instead.
+ log.Printf("OAuth: Server-level auth hook called for request ID: %v (using tool-level middleware)", id)
+ return nil // Always succeed - actual auth is done at tool level
+ }
+}
diff --git a/middleware_compatibility_test.go b/middleware_compatibility_test.go
new file mode 100644
index 0000000..bd7cbd9
--- /dev/null
+++ b/middleware_compatibility_test.go
@@ -0,0 +1,226 @@
+package oauth
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/mark3labs/mcp-go/mcp"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+)
+
+// TestMCPGoMiddlewareCompatibility validates mcp-go v0.41.1 middleware integration
+func TestMCPGoMiddlewareCompatibility(t *testing.T) {
+ t.Run("WithToolHandlerMiddleware_ServerWide", func(t *testing.T) {
+ // This test validates that our middleware works with mcp-go v0.41.1's
+ // WithToolHandlerMiddleware option (server-wide middleware)
+
+ // 1. Create OAuth server
+ cfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ oauthServer, err := NewServer(cfg)
+ if err != nil {
+ t.Fatalf("NewServer failed: %v", err)
+ }
+
+ // 2. Create MCP server with OAuth middleware (server-wide)
+ // This is the CORRECT pattern for mcp-go v0.41.1
+ mcpServer := mcpserver.NewMCPServer("Test Server", "1.0.0",
+ mcpserver.WithToolHandlerMiddleware(oauthServer.Middleware()),
+ )
+
+ // 3. Verify server was created successfully
+ if mcpServer == nil {
+ t.Fatal("MCP server creation failed")
+ }
+
+ // 4. Add a tool (middleware automatically applies)
+ toolCalled := false
+ var capturedCtx context.Context
+ mcpServer.AddTool(
+ mcp.Tool{
+ Name: "test_tool",
+ Description: "Test tool",
+ },
+ func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ toolCalled = true
+ capturedCtx = ctx
+
+ // Verify user was added to context by middleware
+ user, ok := GetUserFromContext(ctx)
+ if !ok {
+ return nil, fmt.Errorf("user not found in context")
+ }
+ if user.Subject != "test-user-123" {
+ return nil, fmt.Errorf("expected subject 'test-user-123', got '%s'", user.Subject)
+ }
+
+ return mcp.NewToolResultText("success"), nil
+ },
+ )
+
+ // 5. Manually test the middleware directly
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user-123",
+ "email": "test@example.com",
+ "preferred_username": "testuser",
+ "aud": cfg.Audience,
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, _ := token.SignedString(cfg.JWTSecret)
+ ctx := WithOAuthToken(context.Background(), tokenString)
+
+ // Get the middleware and apply it to a test handler
+ middleware := oauthServer.Middleware()
+ testHandler := middleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ toolCalled = true
+ capturedCtx = ctx
+ return mcp.NewToolResultText("ok"), nil
+ })
+
+ // Call the wrapped handler
+ result, err := testHandler(ctx, mcp.CallToolRequest{})
+ if err != nil {
+ t.Fatalf("Middleware handler failed: %v", err)
+ }
+
+ if !toolCalled {
+ t.Error("Tool was not called")
+ }
+
+ if result == nil {
+ t.Fatal("Expected result, got nil")
+ }
+
+ // Verify user is in context
+ if capturedCtx != nil {
+ user, ok := GetUserFromContext(capturedCtx)
+ if !ok {
+ t.Error("User not found in captured context")
+ }
+ if user != nil && user.Subject != "test-user-123" {
+ t.Errorf("Expected subject 'test-user-123', got '%s'", user.Subject)
+ }
+ }
+
+ t.Logf("✅ WithToolHandlerMiddleware compatible with mcp-go v0.41.1")
+ t.Logf(" - Middleware applied server-wide")
+ t.Logf(" - OAuth validation successful")
+ t.Logf(" - User context propagated to tool")
+ })
+
+ t.Run("MiddlewareCompilationCheck", func(t *testing.T) {
+ // Test that server creation with middleware compiles correctly
+
+ cfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ oauthServer, _ := NewServer(cfg)
+
+ // This is the key test: server creation with middleware should compile
+ mcpServer := mcpserver.NewMCPServer("Test Server", "1.0.0",
+ mcpserver.WithToolHandlerMiddleware(oauthServer.Middleware()),
+ )
+
+ if mcpServer == nil {
+ t.Fatal("Server creation failed")
+ }
+
+ // Add multiple tools to verify middleware applies to all
+ for _, toolName := range []string{"tool1", "tool2", "tool3"} {
+ mcpServer.AddTool(
+ mcp.Tool{
+ Name: toolName,
+ Description: "Test tool",
+ },
+ func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return mcp.NewToolResultText("ok"), nil
+ },
+ )
+ }
+
+ t.Logf("✅ Server-wide middleware compilation successful")
+ t.Logf(" - 3 tools added, all protected by middleware")
+ })
+
+ t.Run("MiddlewareRejectsInvalidToken", func(t *testing.T) {
+ // Test that middleware rejects invalid tokens
+
+ cfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ oauthServer, _ := NewServer(cfg)
+
+ // Get middleware and test directly
+ middleware := oauthServer.Middleware()
+
+ toolCalled := false
+ wrappedHandler := middleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ toolCalled = true
+ return mcp.NewToolResultText("should not reach here"), nil
+ })
+
+ // Try with invalid token
+ ctx := WithOAuthToken(context.Background(), "invalid-token")
+
+ _, err := wrappedHandler(ctx, mcp.CallToolRequest{})
+
+ // Should fail
+ if err == nil {
+ t.Error("Expected authentication error, got nil")
+ }
+
+ if toolCalled {
+ t.Error("Tool should not be called with invalid token")
+ }
+
+ t.Logf("✅ Middleware correctly rejects invalid tokens")
+ t.Logf(" - Error: %v", err)
+ })
+}
+
+// TestMiddlewareSignatureCompatibility validates the middleware function signature
+func TestMiddlewareSignatureCompatibility(t *testing.T) {
+ // This test ensures our Server.Middleware() returns the correct type
+ // for mcp-go v0.41.1's WithToolHandlerMiddleware
+
+ cfg := &Config{
+ Mode: "native",
+ Provider: "hmac",
+ Issuer: "https://test.example.com",
+ Audience: "api://test",
+ JWTSecret: []byte("test-secret-key-must-be-32-bytes-long!"),
+ }
+
+ server, _ := NewServer(cfg)
+
+ // Get middleware
+ middleware := server.Middleware()
+
+ // Type assertion: should be func(ToolHandlerFunc) ToolHandlerFunc
+ // If this compiles, the signature is correct
+ var _ = middleware
+
+ t.Logf("✅ Middleware signature is compatible with mcp-go v0.41.1")
+ t.Logf(" Type: func(server.ToolHandlerFunc) server.ToolHandlerFunc")
+}
diff --git a/oauth.go b/oauth.go
new file mode 100644
index 0000000..ce74374
--- /dev/null
+++ b/oauth.go
@@ -0,0 +1,124 @@
+package oauth
+
+import (
+ "fmt"
+ "net/http"
+
+ mcpserver "github.com/mark3labs/mcp-go/server"
+ "github.com/tuannvm/oauth-mcp-proxy/provider"
+)
+
+// Server represents an OAuth authentication server instance.
+// Each Server maintains its own token cache and validator, allowing
+// multiple independent OAuth configurations in the same application.
+//
+// Create using NewServer(). Access middleware via Middleware() and
+// register HTTP endpoints via RegisterHandlers().
+type Server struct {
+ config *Config
+ validator provider.TokenValidator
+ cache *TokenCache
+ handler *OAuth2Handler
+ logger Logger
+}
+
+// NewServer creates a new OAuth server with the given configuration.
+// Validates configuration, initializes provider-specific token validator,
+// and creates instance-scoped token cache.
+//
+// Example:
+//
+// server, err := oauth.NewServer(&oauth.Config{
+// Provider: "okta",
+// Issuer: "https://company.okta.com",
+// Audience: "api://my-server",
+// })
+//
+// Most users should use WithOAuth() instead, which wraps NewServer()
+// and automatically registers handlers and middleware.
+func NewServer(cfg *Config) (*Server, error) {
+ // Validate configuration
+ if err := cfg.Validate(); err != nil {
+ return nil, fmt.Errorf("invalid config: %w", err)
+ }
+
+ // Use default logger if not provided
+ logger := cfg.Logger
+ if logger == nil {
+ logger = &defaultLogger{}
+ }
+
+ // Create validator with logger
+ validator, err := createValidator(cfg, logger)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create validator: %w", err)
+ }
+
+ // Create instance-scoped cache
+ cache := &TokenCache{
+ cache: make(map[string]*CachedToken),
+ }
+
+ // Create OAuth handler with logger
+ handler := CreateOAuth2Handler(cfg, "1.0.0", logger)
+
+ return &Server{
+ config: cfg,
+ validator: validator,
+ cache: cache,
+ handler: handler,
+ logger: logger,
+ }, nil
+}
+
+// RegisterHandlers registers OAuth HTTP endpoints on the provided mux.
+// Endpoints registered:
+// - /.well-known/oauth-authorization-server - OAuth 2.0 metadata (RFC 8414)
+// - /.well-known/oauth-protected-resource - Resource metadata
+// - /.well-known/jwks.json - JWKS keys
+// - /.well-known/openid-configuration - OIDC discovery
+// - /oauth/authorize - Authorization endpoint (proxy mode)
+// - /oauth/callback - Callback handler (proxy mode)
+// - /oauth/token - Token exchange (proxy mode)
+//
+// Note: WithOAuth() calls this automatically. Only call directly if using
+// NewServer() for advanced use cases.
+func (s *Server) RegisterHandlers(mux *http.ServeMux) {
+ mux.HandleFunc("/.well-known/oauth-authorization-server", s.handler.HandleAuthorizationServerMetadata)
+ mux.HandleFunc("/.well-known/oauth-protected-resource", s.handler.HandleProtectedResourceMetadata)
+ mux.HandleFunc("/.well-known/jwks.json", s.handler.HandleJWKS)
+ mux.HandleFunc("/oauth/authorize", s.handler.HandleAuthorize)
+ mux.HandleFunc("/oauth/callback", s.handler.HandleCallback)
+ mux.HandleFunc("/oauth/token", s.handler.HandleToken)
+ mux.HandleFunc("/.well-known/openid-configuration", s.handler.HandleOIDCDiscovery)
+}
+
+// WithOAuth returns a server option that enables OAuth authentication
+// This is the composable API for mcp-go v0.41.1
+//
+// Usage:
+//
+// mux := http.NewServeMux()
+// oauthOption, err := oauth.WithOAuth(mux, &oauth.Config{...})
+// mcpServer := server.NewMCPServer("Server", "1.0.0", oauthOption)
+//
+// This function:
+// - Creates OAuth server internally
+// - Registers OAuth HTTP endpoints on mux
+// - Returns middleware as server option
+//
+// Note: You must also configure HTTPContextFunc to extract the OAuth token
+// from HTTP headers. Use CreateHTTPContextFunc() helper.
+func WithOAuth(mux *http.ServeMux, cfg *Config) (mcpserver.ServerOption, error) {
+ // Create OAuth server
+ oauthServer, err := NewServer(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create OAuth server: %w", err)
+ }
+
+ // Register HTTP handlers
+ oauthServer.RegisterHandlers(mux)
+
+ // Return middleware as server option
+ return mcpserver.WithToolHandlerMiddleware(oauthServer.Middleware()), nil
+}
diff --git a/provider/provider.go b/provider/provider.go
new file mode 100644
index 0000000..53007e6
--- /dev/null
+++ b/provider/provider.go
@@ -0,0 +1,333 @@
+package provider
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/coreos/go-oidc/v3/oidc"
+ "github.com/golang-jwt/jwt/v5"
+)
+
+// User represents an authenticated user
+type User struct {
+ Username string
+ Email string
+ Subject string
+}
+
+// Logger interface for pluggable logging
+type Logger interface {
+ Debug(msg string, args ...interface{})
+ Info(msg string, args ...interface{})
+ Warn(msg string, args ...interface{})
+ Error(msg string, args ...interface{})
+}
+
+// Config holds OAuth configuration (subset needed by provider)
+type Config struct {
+ Provider string
+ Issuer string
+ Audience string
+ JWTSecret []byte
+ Logger Logger
+}
+
+// TokenValidator interface for OAuth token validation
+type TokenValidator interface {
+ ValidateToken(ctx context.Context, token string) (*User, error)
+ Initialize(cfg *Config) error
+}
+
+// HMACValidator validates JWT tokens using HMAC-SHA256 (backward compatibility)
+type HMACValidator struct {
+ secret string
+ audience string
+ secretOnce sync.Once
+}
+
+// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
+type OIDCValidator struct {
+ verifier *oidc.IDTokenVerifier
+ provider *oidc.Provider
+ audience string
+ logger Logger
+}
+
+// Initialize sets up the HMAC validator with JWT secret and audience
+func (v *HMACValidator) Initialize(cfg *Config) error {
+ v.secretOnce.Do(func() {
+ v.secret = string(cfg.JWTSecret)
+ v.audience = cfg.Audience
+ })
+
+ if v.secret == "" {
+ return fmt.Errorf("JWT_SECRET is required for HMAC provider")
+ }
+
+ if v.audience == "" {
+ return fmt.Errorf("JWT audience is required for HMAC provider")
+ }
+
+ return nil
+}
+
+// ValidateToken validates JWT token using HMAC-SHA256
+func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
+ // Note: ctx parameter accepted for interface compliance, but HMAC validation is local-only (no I/O)
+ // Remove Bearer prefix if present
+ tokenString = strings.TrimPrefix(tokenString, "Bearer ")
+
+ // Parse and validate JWT with signature verification
+ token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
+ // Validate signing method
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+ return []byte(v.secret), nil
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse and validate token: %w", err)
+ }
+
+ if !token.Valid {
+ return nil, fmt.Errorf("invalid token")
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ return nil, fmt.Errorf("invalid token claims")
+ }
+
+ // Validate required claims including audience
+ if err := validateTokenClaims(claims); err != nil {
+ return nil, fmt.Errorf("token validation failed: %w", err)
+ }
+
+ // Validate audience claim for security
+ if err := v.validateAudience(claims); err != nil {
+ return nil, fmt.Errorf("audience validation failed: %w", err)
+ }
+
+ // Extract user information
+ user := &User{
+ Subject: getStringClaim(claims, "sub"),
+ Username: getStringClaim(claims, "preferred_username"),
+ Email: getStringClaim(claims, "email"),
+ }
+
+ if user.Subject == "" {
+ return nil, fmt.Errorf("missing subject in token")
+ }
+
+ return user, nil
+}
+
+// validateAudience validates the audience claim matches the expected value
+func (v *HMACValidator) validateAudience(claims jwt.MapClaims) error {
+ // Extract audience claim (can be string or []string)
+ audClaim, exists := claims["aud"]
+ if !exists {
+ return fmt.Errorf("missing audience claim")
+ }
+
+ // Handle string audience
+ if audStr, ok := audClaim.(string); ok {
+ if audStr != v.audience {
+ return fmt.Errorf("invalid audience: expected %s, got %s", v.audience, audStr)
+ }
+ return nil
+ }
+
+ // Handle array of audiences
+ if audArray, ok := audClaim.([]interface{}); ok {
+ for _, aud := range audArray {
+ if audStr, ok := aud.(string); ok && audStr == v.audience {
+ return nil
+ }
+ }
+ return fmt.Errorf("invalid audience: expected %s not found in audience list", v.audience)
+ }
+
+ return fmt.Errorf("invalid audience claim type")
+}
+
+// Initialize sets up the OIDC validator with provider discovery
+func (v *OIDCValidator) Initialize(cfg *Config) error {
+ if cfg.Issuer == "" {
+ return fmt.Errorf("OIDC issuer is required for OIDC provider")
+ }
+ if cfg.Audience == "" {
+ return fmt.Errorf("OIDC audience is required for OIDC provider")
+ }
+
+ // Use standard library context with timeout
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // Configure HTTP client with appropriate timeouts and TLS settings
+ httpClient := &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: false, // Verify TLS certificates
+ MinVersion: tls.VersionTLS12,
+ },
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ MaxIdleConns: 100,
+ MaxIdleConnsPerHost: 10,
+ },
+ }
+
+ // Create OIDC provider with custom HTTP client
+ provider, err := oidc.NewProvider(
+ oidc.ClientContext(ctx, httpClient),
+ cfg.Issuer,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to initialize OIDC provider: %w", err)
+ }
+
+ // Configure token verifier with required validation settings
+ verifier := provider.Verifier(&oidc.Config{
+ ClientID: cfg.Audience, // Note: go-oidc uses ClientID field for audience validation - see https://github.com/coreos/go-oidc/blob/v3/oidc/verify.go#L85
+ SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256},
+ SkipClientIDCheck: false, // Always validate if ClientID is provided
+ SkipExpiryCheck: false, // Verify expiration
+ SkipIssuerCheck: false, // Verify issuer
+ })
+
+ v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience)
+
+ v.provider = provider
+ v.verifier = verifier
+ v.audience = cfg.Audience
+ return nil
+}
+
+// ValidateToken validates JWT token using OIDC/JWKS
+func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (*User, error) {
+ // Remove Bearer prefix if present
+ tokenString = strings.TrimPrefix(tokenString, "Bearer ")
+
+ // Use incoming context with timeout for OIDC provider call
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ // go-oidc handles RSA signature validation, JWKS fetching, and key rotation
+ idToken, err := v.verifier.Verify(ctx, tokenString)
+ if err != nil {
+ return nil, fmt.Errorf("token verification failed: %w", err)
+ }
+
+ // Extract claims from verified token
+ var claims struct {
+ Subject string `json:"sub"`
+ PreferredUsername string `json:"preferred_username"`
+ Email string `json:"email"`
+ EmailVerified bool `json:"email_verified,omitempty"`
+ Name string `json:"name,omitempty"`
+ // Standard OIDC claims are validated by go-oidc:
+ // - iss (issuer)
+ // - aud (audience)
+ // - exp (expiration)
+ // - iat (issued at)
+ // - nbf (not before)
+ }
+
+ if err := idToken.Claims(&claims); err != nil {
+ return nil, fmt.Errorf("failed to extract claims: %w", err)
+ }
+
+ // Extract raw claims for audience validation
+ var rawClaims jwt.MapClaims
+ if err := idToken.Claims(&rawClaims); err != nil {
+ return nil, fmt.Errorf("failed to extract raw claims: %w", err)
+ }
+
+ // Validate audience claim for security (explicit check)
+ if err := v.validateAudience(rawClaims); err != nil {
+ return nil, fmt.Errorf("audience validation failed: %w", err)
+ }
+
+ return &User{
+ Subject: claims.Subject,
+ Username: claims.PreferredUsername,
+ Email: claims.Email,
+ }, nil
+}
+
+// validateAudience validates the audience claim matches the expected value for OIDC tokens
+func (v *OIDCValidator) validateAudience(claims jwt.MapClaims) error {
+ // Extract audience claim (can be string or []string)
+ audClaim, exists := claims["aud"]
+ if !exists {
+ return fmt.Errorf("missing audience claim")
+ }
+
+ // Handle string audience
+ if audStr, ok := audClaim.(string); ok {
+ if audStr != v.audience {
+ return fmt.Errorf("invalid audience: expected %s, got %s", v.audience, audStr)
+ }
+ return nil
+ }
+
+ // Handle array of audiences
+ if audArray, ok := audClaim.([]interface{}); ok {
+ for _, aud := range audArray {
+ if audStr, ok := aud.(string); ok && audStr == v.audience {
+ return nil
+ }
+ }
+ return fmt.Errorf("invalid audience: expected %s not found in audience list", v.audience)
+ }
+
+ return fmt.Errorf("invalid audience claim type")
+}
+
+// validateTokenClaims validates standard JWT claims
+func validateTokenClaims(claims jwt.MapClaims) error {
+ // Validate expiration
+ if exp, ok := claims["exp"]; ok {
+ if expTime, ok := exp.(float64); ok {
+ if time.Now().Unix() > int64(expTime) {
+ return fmt.Errorf("token expired")
+ }
+ }
+ }
+
+ // Validate not before
+ if nbf, ok := claims["nbf"]; ok {
+ if nbfTime, ok := nbf.(float64); ok {
+ if time.Now().Unix() < int64(nbfTime) {
+ return fmt.Errorf("token not yet valid")
+ }
+ }
+ }
+
+ // Validate issued at (should not be in the future)
+ if iat, ok := claims["iat"]; ok {
+ if iatTime, ok := iat.(float64); ok {
+ if time.Now().Unix() < int64(iatTime) {
+ return fmt.Errorf("token issued in the future")
+ }
+ }
+ }
+
+ return nil
+}
+
+// getStringClaim safely extracts a string claim
+func getStringClaim(claims jwt.MapClaims, key string) string {
+ if val, ok := claims[key].(string); ok {
+ return val
+ }
+ return ""
+}
diff --git a/provider/provider_test.go b/provider/provider_test.go
new file mode 100644
index 0000000..7f0bd25
--- /dev/null
+++ b/provider/provider_test.go
@@ -0,0 +1,373 @@
+package provider
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+// TestHMACValidator_AudienceValidation tests JWT audience validation
+func TestHMACValidator_AudienceValidation(t *testing.T) {
+ // Test configuration
+ cfg := &Config{
+ JWTSecret: []byte("test-secret-key-for-hmac-validation"),
+ Audience: "test-service-audience",
+ }
+
+ validator := &HMACValidator{}
+ err := validator.Initialize(cfg)
+ if err != nil {
+ t.Fatalf("Failed to initialize validator: %v", err)
+ }
+
+ t.Run("ValidAudience", func(t *testing.T) {
+ // Create token with correct audience
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "aud": "test-service-audience",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ "email": "test@example.com",
+ })
+
+ tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ user, err := validator.ValidateToken(context.Background(), tokenString)
+ if err != nil {
+ t.Errorf("Expected valid token to pass, got error: %v", err)
+ }
+
+ if user == nil || user.Subject != "test-user" {
+ t.Errorf("Expected valid user, got: %+v", user)
+ }
+ })
+
+ t.Run("InvalidAudience", func(t *testing.T) {
+ // Create token with wrong audience
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "aud": "wrong.audience.com", // Wrong audience
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ _, err = validator.ValidateToken(context.Background(), tokenString)
+ if err == nil {
+ t.Error("Expected token with wrong audience to fail validation")
+ }
+
+ if err != nil && err.Error() != "audience validation failed: invalid audience: expected test-service-audience, got wrong.audience.com" {
+ t.Errorf("Expected specific audience error, got: %v", err)
+ }
+ })
+
+ t.Run("MissingAudience", func(t *testing.T) {
+ // Create token without audience
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ _, err = validator.ValidateToken(context.Background(), tokenString)
+ if err == nil {
+ t.Error("Expected token without audience to fail validation")
+ }
+
+ if err != nil && err.Error() != "audience validation failed: missing audience claim" {
+ t.Errorf("Expected missing audience error, got: %v", err)
+ }
+ })
+
+ t.Run("AudienceArray", func(t *testing.T) {
+ // Create token with audience as array (valid)
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "aud": []string{"test-service-audience", "other.service.com"}, // Array with correct audience
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ user, err := validator.ValidateToken(context.Background(), tokenString)
+ if err != nil {
+ t.Errorf("Expected token with correct audience in array to pass, got error: %v", err)
+ }
+
+ if user == nil || user.Subject != "test-user" {
+ t.Errorf("Expected valid user, got: %+v", user)
+ }
+ })
+
+ t.Run("AudienceArrayInvalid", func(t *testing.T) {
+ // Create token with audience array not containing expected audience
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "test-user",
+ "aud": []string{"wrong.service.com", "other.service.com"}, // Array without correct audience
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, err := token.SignedString([]byte(cfg.JWTSecret))
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ _, err = validator.ValidateToken(context.Background(), tokenString)
+ if err == nil {
+ t.Error("Expected token with wrong audience array to fail validation")
+ }
+
+ if err != nil && err.Error() != "audience validation failed: invalid audience: expected test-service-audience not found in audience list" {
+ t.Errorf("Expected specific audience array error, got: %v", err)
+ }
+ })
+}
+
+// TestHMACValidator_InitializationValidation tests validator initialization
+func TestHMACValidator_InitializationValidation(t *testing.T) {
+ t.Run("MissingSecret", func(t *testing.T) {
+ cfg := &Config{
+ JWTSecret: []byte(""), // Missing secret
+ Audience: "test-service-audience",
+ }
+
+ validator := &HMACValidator{}
+ err := validator.Initialize(cfg)
+
+ if err == nil {
+ t.Error("Expected initialization to fail with missing secret")
+ }
+
+ if err != nil && err.Error() != "JWT_SECRET is required for HMAC provider" {
+ t.Errorf("Expected specific secret error, got: %v", err)
+ }
+ })
+
+ t.Run("MissingAudience", func(t *testing.T) {
+ cfg := &Config{
+ JWTSecret: []byte("test-secret"),
+ Audience: "", // Missing audience
+ }
+
+ validator := &HMACValidator{}
+ err := validator.Initialize(cfg)
+
+ if err == nil {
+ t.Error("Expected initialization to fail with missing audience")
+ }
+
+ if err != nil && err.Error() != "JWT audience is required for HMAC provider" {
+ t.Errorf("Expected specific audience error, got: %v", err)
+ }
+ })
+
+ t.Run("ValidConfiguration", func(t *testing.T) {
+ cfg := &Config{
+ JWTSecret: []byte("test-secret"),
+ Audience: "test-service-audience",
+ }
+
+ validator := &HMACValidator{}
+ err := validator.Initialize(cfg)
+
+ if err != nil {
+ t.Errorf("Expected valid configuration to succeed, got error: %v", err)
+ }
+
+ if validator.secret != "test-secret" {
+ t.Errorf("Expected secret to be set correctly")
+ }
+
+ if validator.audience != "test-service-audience" {
+ t.Errorf("Expected audience to be set correctly")
+ }
+ })
+}
+
+// TestHMACValidator_SecurityValidation tests that the vulnerability is fixed
+func TestHMACValidator_SecurityValidation(t *testing.T) {
+ // This test specifically validates that the vulnerability described in PE-7429 is fixed
+
+ t.Run("RejectCrossServiceToken", func(t *testing.T) {
+ cfg := &Config{
+ JWTSecret: []byte("test-secret"),
+ Audience: "test-service-audience",
+ }
+
+ validator := &HMACValidator{}
+ err := validator.Initialize(cfg)
+ if err != nil {
+ t.Fatalf("Failed to initialize validator: %v", err)
+ }
+
+ // Simulate a token from another service (different audience)
+ crossServiceToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "sub": "cross-service-user",
+ "aud": "other.service.com", // Different service audience
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ "iss": "company.okta.com", // Same issuer
+ })
+
+ tokenString, err := crossServiceToken.SignedString([]byte(cfg.JWTSecret))
+ if err != nil {
+ t.Fatalf("Failed to sign cross-service token: %v", err)
+ }
+
+ // This should FAIL - the vulnerability would allow this to pass
+ _, err = validator.ValidateToken(context.Background(), tokenString)
+ if err == nil {
+ t.Error("SECURITY VULNERABILITY: Cross-service token was accepted! This should fail.")
+ }
+
+ // Verify it fails for the correct reason (audience validation)
+ if err != nil && !strings.Contains(err.Error(), "audience validation failed") {
+ t.Errorf("Token failed for wrong reason. Expected audience validation failure, got: %v", err)
+ }
+ })
+}
+
+// TestOIDCValidator_AudienceValidation tests OIDC JWT audience validation
+func TestOIDCValidator_AudienceValidation(t *testing.T) {
+ // Test the validateAudience method directly since OIDC provider setup requires external services
+ validator := &OIDCValidator{
+ audience: "test-service-audience",
+ }
+
+ tests := []struct {
+ name string
+ claims jwt.MapClaims
+ expectErr bool
+ errMsg string
+ }{
+ {
+ name: "valid audience string",
+ claims: jwt.MapClaims{
+ "aud": "test-service-audience",
+ "sub": "user123",
+ },
+ expectErr: false,
+ },
+ {
+ name: "invalid audience string",
+ claims: jwt.MapClaims{
+ "aud": "wrong.audience.com",
+ "sub": "user123",
+ },
+ expectErr: true,
+ errMsg: "invalid audience: expected test-service-audience, got wrong.audience.com",
+ },
+ {
+ name: "missing audience claim",
+ claims: jwt.MapClaims{
+ "sub": "user123",
+ },
+ expectErr: true,
+ errMsg: "missing audience claim",
+ },
+ {
+ name: "valid audience array",
+ claims: jwt.MapClaims{
+ "aud": []interface{}{"test-service-audience", "other.service.com"},
+ "sub": "user123",
+ },
+ expectErr: false,
+ },
+ {
+ name: "invalid audience array",
+ claims: jwt.MapClaims{
+ "aud": []interface{}{"wrong.service.com", "other.service.com"},
+ "sub": "user123",
+ },
+ expectErr: true,
+ errMsg: "invalid audience: expected test-service-audience not found in audience list",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.validateAudience(tt.claims)
+
+ if tt.expectErr {
+ if err == nil {
+ t.Errorf("Expected error but got none")
+ } else if tt.errMsg != "" && err.Error() != tt.errMsg {
+ t.Errorf("Expected error message '%s', got '%s'", tt.errMsg, err.Error())
+ }
+ } else {
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ }
+ })
+ }
+}
+
+// TestOIDCValidator_InitializationValidation tests OIDC initialization validation
+func TestOIDCValidator_InitializationValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ config *Config
+ expectError bool
+ errorMsg string
+ }{
+ {
+ name: "missing issuer",
+ config: &Config{
+ Issuer: "",
+ Audience: "test-audience",
+ },
+ expectError: true,
+ errorMsg: "OIDC issuer is required for OIDC provider",
+ },
+ {
+ name: "missing audience",
+ config: &Config{
+ Issuer: "https://example.com",
+ Audience: "",
+ },
+ expectError: true,
+ errorMsg: "OIDC audience is required for OIDC provider",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := &OIDCValidator{}
+ err := validator.Initialize(tt.config)
+
+ if tt.expectError {
+ if err == nil {
+ t.Errorf("Expected error but got none")
+ } else if tt.errorMsg != "" && err.Error() != tt.errorMsg {
+ t.Errorf("Expected error message '%s', got '%s'", tt.errorMsg, err.Error())
+ }
+ } else {
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ }
+ })
+ }
+}
diff --git a/security_scenarios_test.go b/security_scenarios_test.go
new file mode 100644
index 0000000..a04ccdd
--- /dev/null
+++ b/security_scenarios_test.go
@@ -0,0 +1,134 @@
+package oauth
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "net/url"
+ "testing"
+)
+
+func TestSecurityScenarios(t *testing.T) {
+ key := make([]byte, 32)
+ _, _ = rand.Read(key)
+
+ handler := &OAuth2Handler{
+ config: &OAuth2Config{
+ stateSigningKey: key,
+ RedirectURIs: "https://mcp-server.com/oauth/callback",
+ },
+ }
+
+ t.Run("Attack: State tampering to redirect to attacker site", func(t *testing.T) {
+ // Attacker obtains valid signed state
+ stateData := map[string]string{
+ "state": "legitimate-state",
+ "redirect": "https://legitimate-client.com/callback",
+ }
+
+ signedState, err := handler.signState(stateData)
+ if err != nil {
+ t.Fatalf("Failed to sign state: %v", err)
+ }
+
+ // Attacker tries to decode and modify the redirect URI
+ decoded, _ := base64.URLEncoding.DecodeString(signedState)
+ var tamperedData map[string]string
+ _ = json.Unmarshal(decoded, &tamperedData)
+
+ // Change redirect to evil site
+ tamperedData["redirect"] = "https://evil.com/steal-codes"
+
+ // Re-encode (but signature is now invalid)
+ tamperedJSON, _ := json.Marshal(tamperedData)
+ tamperedState := base64.URLEncoding.EncodeToString(tamperedJSON)
+
+ // Try to verify tampered state
+ _, err = handler.verifyState(tamperedState)
+
+ // Should fail due to invalid signature
+ if err == nil {
+ t.Error("SECURITY FAILURE: Tampered state was accepted!")
+ } else {
+ t.Logf("✓ Security working: Tampered state rejected: %v", err)
+ }
+ })
+
+ t.Run("Attack: Remove signature from state", func(t *testing.T) {
+ // Create unsigned state without signature
+ unsignedData := map[string]string{
+ "state": "some-state",
+ "redirect": "https://evil.com/callback",
+ }
+
+ unsignedJSON, _ := json.Marshal(unsignedData)
+ unsignedState := base64.URLEncoding.EncodeToString(unsignedJSON)
+
+ // Try to verify unsigned state
+ _, err := handler.verifyState(unsignedState)
+
+ if err == nil {
+ t.Error("SECURITY FAILURE: Unsigned state was accepted!")
+ } else {
+ t.Logf("✓ Security working: Unsigned state rejected: %v", err)
+ }
+ })
+
+ t.Run("Attack: Replay state from different session", func(t *testing.T) {
+ // Sign state with one handler
+ stateData := map[string]string{
+ "state": "session-1",
+ "redirect": "https://client.com/callback",
+ }
+ signedState, _ := handler.signState(stateData)
+
+ // Create new handler with different key (simulates different server/restart)
+ newKey := make([]byte, 32)
+ _, _ = rand.Read(newKey)
+
+ newHandler := &OAuth2Handler{
+ config: &OAuth2Config{
+ stateSigningKey: newKey,
+ },
+ }
+
+ // Try to use old state with new handler
+ _, err := newHandler.verifyState(signedState)
+
+ if err == nil {
+ t.Error("SECURITY FAILURE: State from different key was accepted!")
+ } else {
+ t.Logf("✓ Security working: Cross-session state rejected: %v", err)
+ }
+ })
+}
+
+func TestHTTPSEnforcementForNonLocalhost(t *testing.T) {
+ tests := []struct {
+ name string
+ uri string
+ shouldFail bool
+ }{
+ {"HTTP localhost allowed", "http://localhost:8080/callback", false},
+ {"HTTP 127.0.0.1 allowed", "http://127.0.0.1:3000/callback", false},
+ {"HTTPS production allowed", "https://example.com/callback", false},
+ {"HTTP production rejected", "http://example.com/callback", true},
+ {"HTTP subdomain rejected", "http://app.example.com/callback", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ isLocalhost := isLocalhostURI(tt.uri)
+ parsed, _ := url.Parse(tt.uri)
+
+ requiresHTTPS := !isLocalhost && parsed.Scheme == "http"
+
+ if requiresHTTPS && !tt.shouldFail {
+ t.Errorf("HTTP non-localhost should be rejected but test expects pass: %s", tt.uri)
+ }
+ if !requiresHTTPS && tt.shouldFail {
+ t.Errorf("URI should be allowed but test expects fail: %s", tt.uri)
+ }
+ })
+ }
+}
diff --git a/security_test.go b/security_test.go
new file mode 100644
index 0000000..bd40ab6
--- /dev/null
+++ b/security_test.go
@@ -0,0 +1,267 @@
+package oauth
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestRedirectURIValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ allowedRedirects string
+ testURI string
+ expected bool
+ }{
+ {
+ name: "No allowlist configured - reject all",
+ allowedRedirects: "",
+ testURI: "https://client.example.com/callback",
+ expected: false,
+ },
+ {
+ name: "Single URI match",
+ allowedRedirects: "https://client.example.com/callback",
+ testURI: "https://client.example.com/callback",
+ expected: true,
+ },
+ {
+ name: "Multiple URIs - first match",
+ allowedRedirects: "https://client1.example.com/callback,https://client2.example.com/callback",
+ testURI: "https://client1.example.com/callback",
+ expected: true,
+ },
+ {
+ name: "Multiple URIs - second match",
+ allowedRedirects: "https://client1.example.com/callback,https://client2.example.com/callback",
+ testURI: "https://client2.example.com/callback",
+ expected: true,
+ },
+ {
+ name: "No match",
+ allowedRedirects: "https://client1.example.com/callback",
+ testURI: "https://malicious.example.com/callback",
+ expected: false,
+ },
+ {
+ name: "Partial match rejected",
+ allowedRedirects: "https://client.example.com/callback",
+ testURI: "https://client.example.com/callback/evil",
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ config := &OAuth2Config{
+ RedirectURIs: tt.allowedRedirects,
+ }
+ handler := &OAuth2Handler{config: config, logger: &defaultLogger{}}
+
+ result := handler.isValidRedirectURI(tt.testURI)
+ if result != tt.expected {
+ t.Errorf("isValidRedirectURI(%q) = %v, expected %v", tt.testURI, result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestOAuthParameterValidation(t *testing.T) {
+ handler := &OAuth2Handler{logger: &defaultLogger{}}
+
+ tests := []struct {
+ name string
+ params map[string]string
+ expectError bool
+ errorMsg string
+ }{
+ {
+ name: "Valid parameters",
+ params: map[string]string{
+ "code": "valid_code_123",
+ "state": "valid_state",
+ "code_challenge": "valid_challenge",
+ },
+ expectError: false,
+ },
+ {
+ name: "Code too long",
+ params: map[string]string{
+ "code": strings.Repeat("a", 513), // 513 characters
+ },
+ expectError: true,
+ errorMsg: "invalid code parameter length",
+ },
+ {
+ name: "State too long",
+ params: map[string]string{
+ "state": strings.Repeat("a", 257), // 257 characters
+ },
+ expectError: true,
+ errorMsg: "invalid state parameter length",
+ },
+ {
+ name: "Code challenge too long",
+ params: map[string]string{
+ "code_challenge": strings.Repeat("a", 257), // 257 characters
+ },
+ expectError: true,
+ errorMsg: "invalid code_challenge parameter length",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a form request with the test parameters
+ values := make([]string, 0, len(tt.params)*2)
+ for key, value := range tt.params {
+ values = append(values, key, value)
+ }
+
+ req := httptest.NewRequest("POST", "/test", strings.NewReader(""))
+ req.Form = make(map[string][]string)
+ for i := 0; i < len(values); i += 2 {
+ req.Form[values[i]] = []string{values[i+1]}
+ }
+
+ err := handler.validateOAuthParams(req)
+
+ if tt.expectError {
+ if err == nil {
+ t.Errorf("Expected error but got none")
+ } else if !strings.Contains(err.Error(), tt.errorMsg) {
+ t.Errorf("Expected error containing %q, got %q", tt.errorMsg, err.Error())
+ }
+ } else {
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ }
+ })
+ }
+}
+
+func TestSecurityHeaders(t *testing.T) {
+ handler := &OAuth2Handler{logger: &defaultLogger{}}
+ recorder := httptest.NewRecorder()
+
+ handler.addSecurityHeaders(recorder)
+
+ expectedHeaders := map[string]string{
+ "X-Content-Type-Options": "nosniff",
+ "X-Frame-Options": "DENY",
+ "Cache-Control": "no-store, no-cache, max-age=0",
+ "Pragma": "no-cache",
+ }
+
+ for header, expectedValue := range expectedHeaders {
+ actualValue := recorder.Header().Get(header)
+ if actualValue != expectedValue {
+ t.Errorf("Header %s = %q, expected %q", header, actualValue, expectedValue)
+ }
+ }
+}
+
+func TestHTTPSEnforcementInHandlers(t *testing.T) {
+ config := &OAuth2Config{
+ Mode: "proxy",
+ }
+ handler := &OAuth2Handler{config: config, logger: &defaultLogger{}}
+
+ endpoints := []struct {
+ name string
+ handler func(http.ResponseWriter, *http.Request)
+ }{
+ {"HandleAuthorize", handler.HandleAuthorize},
+ {"HandleCallback", handler.HandleCallback},
+ {"HandleToken", handler.HandleToken},
+ }
+
+ for _, endpoint := range endpoints {
+ t.Run("Native mode blocks "+endpoint.name, func(t *testing.T) {
+ // Test native mode blocking
+ nativeConfig := &OAuth2Config{Mode: "native"}
+ nativeHandler := &OAuth2Handler{config: nativeConfig}
+
+ var testHandler func(http.ResponseWriter, *http.Request)
+ switch endpoint.name {
+ case "HandleAuthorize":
+ testHandler = nativeHandler.HandleAuthorize
+ case "HandleCallback":
+ testHandler = nativeHandler.HandleCallback
+ case "HandleToken":
+ testHandler = nativeHandler.HandleToken
+ }
+
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "/test", nil)
+
+ testHandler(recorder, req)
+
+ if recorder.Code != http.StatusNotFound {
+ t.Errorf("%s in native mode should return 404, got %d", endpoint.name, recorder.Code)
+ }
+
+ body := recorder.Body.String()
+ if !strings.Contains(body, "OAuth proxy disabled in native mode") {
+ t.Errorf("%s should return OAuth proxy disabled message", endpoint.name)
+ }
+ })
+ }
+}
+
+func TestJWKSProxyMode(t *testing.T) {
+ tests := []struct {
+ name string
+ mode string
+ provider string
+ expected int
+ }{
+ {
+ name: "Native mode blocks JWKS",
+ mode: "native",
+ provider: "okta",
+ expected: http.StatusNotFound,
+ },
+ {
+ name: "HMAC provider returns empty JWKS",
+ mode: "proxy",
+ provider: "hmac",
+ expected: http.StatusOK,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ config := &OAuth2Config{
+ Mode: tt.mode,
+ Provider: tt.provider,
+ }
+ handler := &OAuth2Handler{config: config, logger: &defaultLogger{}}
+
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
+
+ handler.HandleJWKS(recorder, req)
+
+ if recorder.Code != tt.expected {
+ t.Errorf("Expected status %d, got %d", tt.expected, recorder.Code)
+ }
+
+ if tt.mode == "native" {
+ body := recorder.Body.String()
+ if !strings.Contains(body, "JWKS endpoint disabled in native mode") {
+ t.Errorf("Should return JWKS disabled message in native mode")
+ }
+ }
+
+ if tt.provider == "hmac" && tt.mode == "proxy" {
+ body := recorder.Body.String()
+ if body != `{"keys":[]}` {
+ t.Errorf("HMAC provider should return empty JWKS, got %s", body)
+ }
+ }
+ })
+ }
+}
diff --git a/state_test.go b/state_test.go
new file mode 100644
index 0000000..2df4bc3
--- /dev/null
+++ b/state_test.go
@@ -0,0 +1,158 @@
+package oauth
+
+import (
+ "crypto/rand"
+ "testing"
+)
+
+func TestStateSigningAndVerification(t *testing.T) {
+ // Create handler with signing key
+ key := make([]byte, 32)
+ _, _ = rand.Read(key)
+
+ handler := &OAuth2Handler{
+ config: &OAuth2Config{
+ stateSigningKey: key,
+ },
+ }
+
+ tests := []struct {
+ name string
+ stateData map[string]string
+ expectError bool
+ tamper bool
+ }{
+ {
+ name: "Valid state with both fields",
+ stateData: map[string]string{
+ "state": "abc123",
+ "redirect": "https://example.com/callback",
+ },
+ expectError: false,
+ },
+ {
+ name: "Valid state with localhost redirect",
+ stateData: map[string]string{
+ "state": "xyz789",
+ "redirect": "http://localhost:8080/callback",
+ },
+ expectError: false,
+ },
+ {
+ name: "State with special characters",
+ stateData: map[string]string{
+ "state": "state-with-dashes_and_underscores",
+ "redirect": "https://example.com/callback?foo=bar&baz=qux",
+ },
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Sign state
+ signed, err := handler.signState(tt.stateData)
+ if err != nil {
+ t.Fatalf("Failed to sign state: %v", err)
+ }
+
+ // Verify state
+ verified, err := handler.verifyState(signed)
+ if tt.expectError && err == nil {
+ t.Error("Expected error but got none")
+ }
+ if !tt.expectError && err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ // Check data integrity
+ if !tt.expectError {
+ if verified["state"] != tt.stateData["state"] {
+ t.Errorf("State mismatch: got %s, want %s", verified["state"], tt.stateData["state"])
+ }
+ if verified["redirect"] != tt.stateData["redirect"] {
+ t.Errorf("Redirect mismatch: got %s, want %s", verified["redirect"], tt.stateData["redirect"])
+ }
+ }
+ })
+ }
+}
+
+func TestStateTamperingDetection(t *testing.T) {
+ // Create handler with signing key
+ key := make([]byte, 32)
+ _, _ = rand.Read(key)
+
+ handler := &OAuth2Handler{
+ config: &OAuth2Config{
+ stateSigningKey: key,
+ },
+ }
+
+ // Create and sign valid state
+ stateData := map[string]string{
+ "state": "original",
+ "redirect": "https://good.com/callback",
+ }
+
+ signed, err := handler.signState(stateData)
+ if err != nil {
+ t.Fatalf("Failed to sign state: %v", err)
+ }
+
+ // Verify the original signed state works correctly
+ _, err = handler.verifyState(signed)
+ if err != nil {
+ t.Logf("Good: Original state verification works: %v", err)
+ }
+
+ // Now create a handler with different key
+ differentKey := make([]byte, 32)
+ _, _ = rand.Read(differentKey)
+
+ handler2 := &OAuth2Handler{
+ config: &OAuth2Config{
+ stateSigningKey: differentKey,
+ },
+ }
+
+ // Try to verify with different key (should fail)
+ _, err = handler2.verifyState(signed)
+ if err == nil {
+ t.Error("Expected verification to fail with different key, but it succeeded")
+ } else {
+ t.Logf("Good: Verification failed with different key: %v", err)
+ }
+
+ // Test with completely invalid base64
+ _, err = handler.verifyState("not-valid-base64!!!")
+ if err == nil {
+ t.Error("Expected verification to fail with invalid base64")
+ }
+}
+
+func TestLocalhostDetection(t *testing.T) {
+ tests := []struct {
+ name string
+ uri string
+ expected bool
+ }{
+ {"HTTP localhost", "http://localhost:8080/callback", true},
+ {"HTTPS localhost", "https://localhost/callback", true},
+ {"HTTP 127.0.0.1", "http://127.0.0.1:3000/callback", true},
+ {"HTTPS 127.0.0.1", "https://127.0.0.1/callback", true},
+ {"IPv6 localhost", "http://[::1]:8080/callback", true},
+ {"Non-localhost domain", "http://example.com/callback", false},
+ {"Non-localhost subdomain", "https://localhost.example.com/callback", false},
+ {"Invalid URI", "not-a-valid-uri", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := isLocalhostURI(tt.uri)
+ if result != tt.expected {
+ t.Errorf("isLocalhostURI(%q) = %v, expected %v", tt.uri, result, tt.expected)
+ }
+ })
+ }
+}