From 72d8c1fdc7aed5b4aedca92ee97f5cd1d186489b Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 30 Sep 2025 14:48:08 +0100 Subject: [PATCH 1/6] Add middleware to swap the downstream ticket for the upstream ticket Implements HTTP middleware that automatically exchanges downstream authentication tokens for backend-specific tokens using RFC 8693 OAuth 2.0 Token Exchange. The middleware extracts subject tokens from authenticated requests and replaces them with exchanged tokens, supporting two injection strategies: replacing the Authorization header or adding a custom header while preserving the original. Fixes: #2065 --- pkg/auth/tokenexchange/exchange.go | 12 +- pkg/auth/tokenexchange/exchange_test.go | 20 +- pkg/auth/tokenexchange/middleware.go | 219 ++++++++ pkg/auth/tokenexchange/middleware_test.go | 595 ++++++++++++++++++++++ 4 files changed, 830 insertions(+), 16 deletions(-) create mode 100644 pkg/auth/tokenexchange/middleware.go create mode 100644 pkg/auth/tokenexchange/middleware_test.go diff --git a/pkg/auth/tokenexchange/exchange.go b/pkg/auth/tokenexchange/exchange.go index 9eafbe3b5..4786a6db8 100644 --- a/pkg/auth/tokenexchange/exchange.go +++ b/pkg/auth/tokenexchange/exchange.go @@ -158,8 +158,8 @@ func (c clientAuthentication) String() string { c.ClientID, clientSecret) } -// Config holds the configuration for token exchange. -type Config struct { +// ExchangeConfig holds the configuration for token exchange. +type ExchangeConfig struct { // TokenURL is the OAuth 2.0 token endpoint URL TokenURL string @@ -185,8 +185,8 @@ type Config struct { HTTPClient *http.Client } -// Validate checks if the Config contains all required fields. -func (c *Config) Validate() error { +// Validate checks if the ExchangeConfig contains all required fields. +func (c *ExchangeConfig) Validate() error { if c.TokenURL == "" { return fmt.Errorf("TokenURL is required") } @@ -211,7 +211,7 @@ func (c *Config) Validate() error { // tokenSource implements oauth2.TokenSource for token exchange. type tokenSource struct { ctx context.Context - conf *Config + conf *ExchangeConfig } // Token implements oauth2.TokenSource interface. @@ -281,7 +281,7 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) { } // TokenSource returns an oauth2.TokenSource that performs token exchange. -func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { +func (c *ExchangeConfig) TokenSource(ctx context.Context) oauth2.TokenSource { return &tokenSource{ ctx: ctx, conf: c, diff --git a/pkg/auth/tokenexchange/exchange_test.go b/pkg/auth/tokenexchange/exchange_test.go index 54bf3ac49..655bd0ae5 100644 --- a/pkg/auth/tokenexchange/exchange_test.go +++ b/pkg/auth/tokenexchange/exchange_test.go @@ -125,7 +125,7 @@ func TestTokenSource_Token_Success(t *testing.T) { defer server.Close() // Create config with test server - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -166,7 +166,7 @@ func TestTokenSource_Token_WithRefreshToken(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -198,7 +198,7 @@ func TestTokenSource_Token_NoExpiry(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -221,7 +221,7 @@ func TestTokenSource_Token_SubjectTokenProviderError(t *testing.T) { t.Parallel() providerErr := errors.New("failed to get token from provider") - config := &Config{ + config := &ExchangeConfig{ TokenURL: "https://example.com/token", ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -251,7 +251,7 @@ func TestTokenSource_Token_ContextCancellation(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -800,7 +800,7 @@ func TestSubjectTokenProvider_Variants(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -1036,10 +1036,10 @@ func TestExchangeToken_ScopeArray(t *testing.T) { } // TestConfig_TokenSource tests that TokenSource creates a valid tokenSource. -func TestConfig_TokenSource(t *testing.T) { +func TestExchangeConfig_TokenSource(t *testing.T) { t.Parallel() - config := &Config{ + config := &ExchangeConfig{ TokenURL: "https://example.com/token", ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -1175,14 +1175,14 @@ func TestClientAuthentication_Fields(t *testing.T) { } // TestConfig_Fields tests Config struct fields. -func TestConfig_Fields(t *testing.T) { +func TestExchangeConfig_Fields(t *testing.T) { t.Parallel() provider := func() (string, error) { return "token", nil } - config := &Config{ + config := &ExchangeConfig{ TokenURL: "https://example.com/token", ClientID: "test-client-id", ClientSecret: "test-client-secret", diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go new file mode 100644 index 000000000..f8a72ee4e --- /dev/null +++ b/pkg/auth/tokenexchange/middleware.go @@ -0,0 +1,219 @@ +package tokenexchange + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/golang-jwt/jwt/v5" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +// Middleware type constant +const ( + MiddlewareType = "tokenexchange" +) + +// Header injection strategy constants +const ( + // HeaderStrategyReplace replaces the Authorization header with the exchanged token + HeaderStrategyReplace = "replace" + // HeaderStrategyCustom adds the exchanged token to a custom header + HeaderStrategyCustom = "custom" +) + +// MiddlewareParams represents the parameters for token exchange middleware +type MiddlewareParams struct { + TokenExchangeConfig *Config `json:"token_exchange_config,omitempty"` +} + +// Config holds configuration for token exchange middleware +type Config struct { + // TokenURL is the OAuth 2.0 token endpoint URL + TokenURL string `json:"token_url"` + + // ClientID is the OAuth 2.0 client identifier + ClientID string `json:"client_id"` + + // ClientSecret is the OAuth 2.0 client secret + ClientSecret string `json:"client_secret"` + + // Audience is the target audience for the exchanged token + Audience string `json:"audience"` + + // Scope is the scope to request for the exchanged token + Scope string `json:"scope,omitempty"` + + // HeaderStrategy determines how to inject the token + // Valid values: HeaderStrategyReplace (default), HeaderStrategyCustom + HeaderStrategy string `json:"header_strategy,omitempty"` + + // ExternalTokenHeaderName is the name of the custom header to use when HeaderStrategy is "custom" + ExternalTokenHeaderName string `json:"external_token_header_name,omitempty"` +} + +// Middleware wraps token exchange middleware functionality +type Middleware struct { + middleware types.MiddlewareFunction +} + +// Handler returns the middleware function used by the proxy. +func (m *Middleware) Handler() types.MiddlewareFunction { + return m.middleware +} + +// Close cleans up any resources used by the middleware. +func (*Middleware) Close() error { + // Token exchange middleware doesn't need cleanup + return nil +} + +// CreateMiddleware factory function for token exchange middleware +func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error { + var params MiddlewareParams + if err := json.Unmarshal(config.Parameters, ¶ms); err != nil { + return fmt.Errorf("failed to unmarshal token exchange middleware parameters: %w", err) + } + + // If no token exchange config provided, skip middleware + if params.TokenExchangeConfig == nil { + logger.Debug("No token exchange config provided, skipping token exchange middleware") + return nil + } + + // Validate configuration + if err := validateTokenExchangeConfig(params.TokenExchangeConfig); err != nil { + return fmt.Errorf("invalid token exchange configuration: %w", err) + } + + middleware := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig) + + tokenExchangeMw := &Middleware{ + middleware: middleware, + } + + // Add middleware to runner + runner.AddMiddleware(tokenExchangeMw) + + return nil +} + +// validateTokenExchangeConfig validates the token exchange configuration +func validateTokenExchangeConfig(config *Config) error { + if config.HeaderStrategy == HeaderStrategyCustom && config.ExternalTokenHeaderName == "" { + return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom) + } + + if config.HeaderStrategy != "" && + config.HeaderStrategy != HeaderStrategyReplace && + config.HeaderStrategy != HeaderStrategyCustom { + return fmt.Errorf("invalid header_strategy: %s (valid values: '%s', '%s')", + config.HeaderStrategy, HeaderStrategyReplace, HeaderStrategyCustom) + } + + return nil +} + +// injectToken handles token injection based on the configured strategy +func injectToken(r *http.Request, token string, config Config) error { + strategy := config.HeaderStrategy + if strategy == "" { + strategy = HeaderStrategyReplace // Default to replace for backwards compatibility + } + + switch strategy { + case HeaderStrategyReplace: + logger.Debugf("Token exchange successful, replacing Authorization header") + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + + case HeaderStrategyCustom: + if config.ExternalTokenHeaderName == "" { + return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom) + } + logger.Debugf("Token exchange successful, adding token to custom header: %s", config.ExternalTokenHeaderName) + r.Header.Set(config.ExternalTokenHeaderName, fmt.Sprintf("Bearer %s", token)) + return nil + + default: + return fmt.Errorf("unsupported header_strategy: %s (valid values: '%s', '%s')", + strategy, HeaderStrategyReplace, HeaderStrategyCustom) + } +} + +// CreateTokenExchangeMiddlewareFromClaims creates a middleware that uses token claims +// from the auth middleware to perform token exchange. +// This is a public function for direct usage in proxy commands. +func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunction { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get claims from the auth middleware + claims, ok := r.Context().Value(auth.ClaimsContextKey{}).(jwt.MapClaims) + if !ok { + logger.Debug("No claims found in context, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + // Extract the original token from the Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { + logger.Debug("No valid Bearer token found, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + subjectToken := strings.TrimPrefix(authHeader, "Bearer ") + if subjectToken == "" { + logger.Debug("Empty Bearer token, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + // Log some claim information for debugging + if sub, exists := claims["sub"]; exists { + logger.Debugf("Performing token exchange for subject: %v", sub) + } + + // Build scopes array + scopes := []string{} + if config.Scope != "" { + scopes = strings.Split(config.Scope, " ") + } + + // Create RFC-8693 token exchange config with subject token provider + exchangeConfig := &ExchangeConfig{ + TokenURL: config.TokenURL, + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + Audience: config.Audience, + Scopes: scopes, + SubjectTokenProvider: func() (string, error) { + return subjectToken, nil + }, + } + + // Get token from token source + tokenSource := exchangeConfig.TokenSource(r.Context()) + exchangedToken, err := tokenSource.Token() + if err != nil { + logger.Warnf("Token exchange failed: %v", err) + http.Error(w, "Token exchange failed", http.StatusUnauthorized) + return + } + + // Inject the exchanged token into the request + if err := injectToken(r, exchangedToken.AccessToken, config); err != nil { + logger.Warnf("Failed to inject token: %v", err) + http.Error(w, "Token injection failed", http.StatusInternalServerError) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go new file mode 100644 index 000000000..5f8257f5d --- /dev/null +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -0,0 +1,595 @@ +package tokenexchange + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/transport/types/mocks" +) + +// TestValidateTokenExchangeConfig tests configuration validation. +func TestValidateTokenExchangeConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *Config + expectError bool + errorMsg string + }{ + { + name: "valid replace strategy explicit", + config: &Config{ + HeaderStrategy: HeaderStrategyReplace, + }, + expectError: false, + }, + { + name: "valid custom strategy with header name", + config: &Config{ + HeaderStrategy: HeaderStrategyCustom, + ExternalTokenHeaderName: "X-Upstream-Token", + }, + expectError: false, + }, + { + name: "valid empty strategy defaults to replace", + config: &Config{ + HeaderStrategy: "", + }, + expectError: false, + }, + { + name: "invalid custom strategy missing header name", + config: &Config{ + HeaderStrategy: HeaderStrategyCustom, + }, + expectError: true, + errorMsg: "external_token_header_name must be specified", + }, + { + name: "invalid strategy name", + config: &Config{ + HeaderStrategy: "invalid-strategy", + }, + expectError: true, + errorMsg: "invalid header_strategy", + }, + { + name: "unknown strategy", + config: &Config{ + HeaderStrategy: "query-param", + }, + expectError: true, + errorMsg: "invalid header_strategy", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := validateTokenExchangeConfig(tt.config) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestInjectToken tests the token injection strategies. +func TestInjectToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config Config + originalAuthHeader string + newToken string + expectError bool + errorMsg string + expectedAuthHeader string + expectedCustomHeader string + customHeaderName string + }{ + { + name: "replace strategy replaces Authorization header", + config: Config{ + HeaderStrategy: HeaderStrategyReplace, + }, + originalAuthHeader: "Bearer original-token", + newToken: "new-token", + expectError: false, + expectedAuthHeader: "Bearer new-token", + }, + { + name: "empty strategy defaults to replace", + config: Config{ + HeaderStrategy: "", + }, + originalAuthHeader: "Bearer original-token", + newToken: "new-token", + expectError: false, + expectedAuthHeader: "Bearer new-token", + }, + { + name: "custom strategy preserves original and adds custom header", + config: Config{ + HeaderStrategy: HeaderStrategyCustom, + ExternalTokenHeaderName: "X-Upstream-Token", + }, + originalAuthHeader: "Bearer original-token", + newToken: "new-token", + expectError: false, + expectedAuthHeader: "Bearer original-token", + expectedCustomHeader: "Bearer new-token", + customHeaderName: "X-Upstream-Token", + }, + { + name: "custom strategy with different header name", + config: Config{ + HeaderStrategy: HeaderStrategyCustom, + ExternalTokenHeaderName: "X-External-Auth", + }, + originalAuthHeader: "Bearer original-token", + newToken: "exchanged-token", + expectError: false, + expectedAuthHeader: "Bearer original-token", + expectedCustomHeader: "Bearer exchanged-token", + customHeaderName: "X-External-Auth", + }, + { + name: "custom strategy missing header name fails", + config: Config{ + HeaderStrategy: HeaderStrategyCustom, + }, + newToken: "new-token", + expectError: true, + errorMsg: "external_token_header_name must be specified", + }, + { + name: "unsupported strategy fails", + config: Config{ + HeaderStrategy: "unsupported-strategy", + }, + newToken: "new-token", + expectError: true, + errorMsg: "unsupported header_strategy", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tt.originalAuthHeader != "" { + req.Header.Set("Authorization", tt.originalAuthHeader) + } + + err := injectToken(req, tt.newToken, tt.config) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedAuthHeader, req.Header.Get("Authorization")) + if tt.customHeaderName != "" { + assert.Equal(t, tt.expectedCustomHeader, req.Header.Get(tt.customHeaderName)) + } + } + }) + } +} + +// TestCreateTokenExchangeMiddlewareFromClaims_Success tests successful token exchange flow. +func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headerStrategy string + customHeaderName string + scopes string + expectedAuthHeader string + expectedCustomHeader string + expectedScopesReceived string + }{ + { + name: "replace strategy", + headerStrategy: HeaderStrategyReplace, + expectedAuthHeader: "Bearer exchanged-token", + expectedScopesReceived: "", + }, + { + name: "custom strategy", + headerStrategy: HeaderStrategyCustom, + customHeaderName: "X-Upstream-Token", + expectedAuthHeader: "Bearer original-token", + expectedCustomHeader: "Bearer exchanged-token", + expectedScopesReceived: "", + }, + { + name: "with scopes", + headerStrategy: HeaderStrategyReplace, + scopes: "read write admin", + expectedAuthHeader: "Bearer exchanged-token", + expectedScopesReceived: "read write admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var receivedScopes string + + // Create mock OAuth server + exchangeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tt.expectedScopesReceived != "" { + _ = r.ParseForm() + receivedScopes = r.Form.Get("scope") + } + + resp := response{ + AccessToken: "exchanged-token", + TokenType: "Bearer", + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + ExpiresIn: 3600, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer exchangeServer.Close() + + config := Config{ + TokenURL: exchangeServer.URL, + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + Audience: "https://api.example.com", + Scope: tt.scopes, + HeaderStrategy: tt.headerStrategy, + ExternalTokenHeaderName: tt.customHeaderName, + } + + middleware := CreateTokenExchangeMiddlewareFromClaims(config) + + // Test handler verifies token injection + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, tt.expectedAuthHeader, r.Header.Get("Authorization")) + if tt.customHeaderName != "" { + assert.Equal(t, tt.expectedCustomHeader, r.Header.Get(tt.customHeaderName)) + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with claims and token + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer original-token") + claims := jwt.MapClaims{ + "sub": "user123", + "aud": "test-audience", + } + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + req = req.WithContext(ctx) + + // Execute middleware + rec := httptest.NewRecorder() + handler := middleware(testHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + if tt.expectedScopesReceived != "" { + assert.Equal(t, tt.expectedScopesReceived, receivedScopes) + } + }) + } +} + +// TestCreateTokenExchangeMiddlewareFromClaims_PassThrough tests cases where middleware passes through. +func TestCreateTokenExchangeMiddlewareFromClaims_PassThrough(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupReq func(*http.Request) *http.Request + description string + }{ + { + name: "no claims in context", + setupReq: func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer original-token") + return req + }, + description: "should pass through without token exchange", + }, + { + name: "no Authorization header", + setupReq: func(req *http.Request) *http.Request { + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + return req.WithContext(ctx) + }, + description: "should pass through without token exchange", + }, + { + name: "non-Bearer token", + setupReq: func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + return req.WithContext(ctx) + }, + description: "should pass through with non-Bearer auth", + }, + { + name: "empty Bearer token", + setupReq: func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer ") + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + return req.WithContext(ctx) + }, + description: "should pass through with empty Bearer token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + config := Config{ + TokenURL: "https://example.com/token", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + } + + middleware := CreateTokenExchangeMiddlewareFromClaims(config) + + handlerCalled := false + testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handlerCalled = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req = tt.setupReq(req) + + rec := httptest.NewRecorder() + handler := middleware(testHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, tt.description) + assert.True(t, handlerCalled, "handler should be called") + }) + } +} + +// TestCreateTokenExchangeMiddlewareFromClaims_Failures tests error scenarios. +func TestCreateTokenExchangeMiddlewareFromClaims_Failures(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serverResponse func(w http.ResponseWriter, r *http.Request) + headerStrategy string + customHeaderName string + expectedStatusCode int + expectedBodyMsg string + }{ + { + name: "token exchange returns 401", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_client"}`)) + }, + headerStrategy: HeaderStrategyReplace, + expectedStatusCode: http.StatusUnauthorized, + expectedBodyMsg: "Token exchange failed", + }, + { + name: "token exchange returns 500", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"server_error"}`)) + }, + headerStrategy: HeaderStrategyReplace, + expectedStatusCode: http.StatusUnauthorized, + expectedBodyMsg: "Token exchange failed", + }, + { + name: "invalid injection config", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + resp := response{ + AccessToken: "exchanged-token", + TokenType: "Bearer", + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + }, + headerStrategy: HeaderStrategyCustom, + customHeaderName: "", // Missing header name causes injection failure + expectedStatusCode: http.StatusInternalServerError, + expectedBodyMsg: "Token injection failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + exchangeServer := httptest.NewServer(http.HandlerFunc(tt.serverResponse)) + defer exchangeServer.Close() + + config := Config{ + TokenURL: exchangeServer.URL, + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + HeaderStrategy: tt.headerStrategy, + ExternalTokenHeaderName: tt.customHeaderName, + } + + middleware := CreateTokenExchangeMiddlewareFromClaims(config) + + testHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("handler should not be called on failure") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer original-token") + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + handler := middleware(testHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, tt.expectedStatusCode, rec.Code) + assert.Contains(t, rec.Body.String(), tt.expectedBodyMsg) + }) + } +} + +// TestCreateMiddleware tests the factory function. +func TestCreateMiddleware(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params MiddlewareParams + expectError bool + errorMsg string + expectAddMiddleware bool + }{ + { + name: "valid config creates middleware", + params: MiddlewareParams{ + TokenExchangeConfig: &Config{ + TokenURL: "https://example.com/token", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + HeaderStrategy: HeaderStrategyReplace, + }, + }, + expectError: false, + expectAddMiddleware: true, + }, + { + name: "nil config skips middleware creation", + params: MiddlewareParams{ + TokenExchangeConfig: nil, + }, + expectError: false, + expectAddMiddleware: false, + }, + { + name: "invalid config fails validation", + params: MiddlewareParams{ + TokenExchangeConfig: &Config{ + HeaderStrategy: HeaderStrategyCustom, + // Missing ExternalTokenHeaderName + }, + }, + expectError: true, + errorMsg: "invalid token exchange configuration", + expectAddMiddleware: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + + if tt.expectAddMiddleware { + mockRunner.EXPECT().AddMiddleware(gomock.Any()).Do(func(mw types.Middleware) { + _, ok := mw.(*Middleware) + assert.True(t, ok, "Expected middleware to be of type *tokenexchange.Middleware") + }) + } + + paramsJSON, err := json.Marshal(tt.params) + require.NoError(t, err) + + config := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: paramsJSON, + } + + err = CreateMiddleware(config, mockRunner) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestCreateMiddleware_InvalidJSON tests error handling for malformed parameters. +func TestCreateMiddleware_InvalidJSON(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + + config := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: []byte(`{invalid json}`), + } + + err := CreateMiddleware(config, mockRunner) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal token exchange middleware parameters") +} + +// TestMiddleware_Methods tests the Middleware struct methods. +func TestMiddleware_Methods(t *testing.T) { + t.Parallel() + + middlewareFunc := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } + + mw := &Middleware{ + middleware: middlewareFunc, + } + + // Test Handler returns the function + handler := mw.Handler() + assert.NotNil(t, handler) + + // Test Close returns no error + err := mw.Close() + assert.NoError(t, err) +} From 05309f07c1ea7011f9bb1f83c295482647fd2b24 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 8 Oct 2025 12:34:01 +0100 Subject: [PATCH 2/6] review feedback: Change scopes in Config to []strings --- pkg/auth/tokenexchange/middleware.go | 12 +++--------- pkg/auth/tokenexchange/middleware_test.go | 8 +++++--- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go index f8a72ee4e..28118cb59 100644 --- a/pkg/auth/tokenexchange/middleware.go +++ b/pkg/auth/tokenexchange/middleware.go @@ -45,8 +45,8 @@ type Config struct { // Audience is the target audience for the exchanged token Audience string `json:"audience"` - // Scope is the scope to request for the exchanged token - Scope string `json:"scope,omitempty"` + // Scopes is the list of scopes to request for the exchanged token + Scopes []string `json:"scopes,omitempty"` // HeaderStrategy determines how to inject the token // Valid values: HeaderStrategyReplace (default), HeaderStrategyCustom @@ -179,19 +179,13 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunc logger.Debugf("Performing token exchange for subject: %v", sub) } - // Build scopes array - scopes := []string{} - if config.Scope != "" { - scopes = strings.Split(config.Scope, " ") - } - // Create RFC-8693 token exchange config with subject token provider exchangeConfig := &ExchangeConfig{ TokenURL: config.TokenURL, ClientID: config.ClientID, ClientSecret: config.ClientSecret, Audience: config.Audience, - Scopes: scopes, + Scopes: config.Scopes, SubjectTokenProvider: func() (string, error) { return subjectToken, nil }, diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go index 5f8257f5d..481768281 100644 --- a/pkg/auth/tokenexchange/middleware_test.go +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -205,7 +205,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { name string headerStrategy string customHeaderName string - scopes string + scopes []string expectedAuthHeader string expectedCustomHeader string expectedScopesReceived string @@ -213,6 +213,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { { name: "replace strategy", headerStrategy: HeaderStrategyReplace, + scopes: nil, expectedAuthHeader: "Bearer exchanged-token", expectedScopesReceived: "", }, @@ -220,6 +221,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { name: "custom strategy", headerStrategy: HeaderStrategyCustom, customHeaderName: "X-Upstream-Token", + scopes: nil, expectedAuthHeader: "Bearer original-token", expectedCustomHeader: "Bearer exchanged-token", expectedScopesReceived: "", @@ -227,7 +229,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { { name: "with scopes", headerStrategy: HeaderStrategyReplace, - scopes: "read write admin", + scopes: []string{"read", "write", "admin"}, expectedAuthHeader: "Bearer exchanged-token", expectedScopesReceived: "read write admin", }, @@ -263,7 +265,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { ClientID: "test-client-id", ClientSecret: "test-client-secret", Audience: "https://api.example.com", - Scope: tt.scopes, + Scopes: tt.scopes, HeaderStrategy: tt.headerStrategy, ExternalTokenHeaderName: tt.customHeaderName, } From 5e2ac18306472e00acb9956b96e7e257de102478 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 8 Oct 2025 16:44:40 +0100 Subject: [PATCH 3/6] review feedback: Make the strategy selection a closure called by the middleware handler --- pkg/auth/tokenexchange/middleware.go | 57 ++++++++++++++++------- pkg/auth/tokenexchange/middleware_test.go | 22 ++++++++- 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go index 28118cb59..f714b0b8b 100644 --- a/pkg/auth/tokenexchange/middleware.go +++ b/pkg/auth/tokenexchange/middleware.go @@ -118,30 +118,31 @@ func validateTokenExchangeConfig(config *Config) error { return nil } -// injectToken handles token injection based on the configured strategy -func injectToken(r *http.Request, token string, config Config) error { - strategy := config.HeaderStrategy - if strategy == "" { - strategy = HeaderStrategyReplace // Default to replace for backwards compatibility - } +// injectionFunc is a function that injects a token into an HTTP request +type injectionFunc func(*http.Request, string) error - switch strategy { - case HeaderStrategyReplace: +// createReplaceInjector creates an injection function that replaces the Authorization header +func createReplaceInjector() injectionFunc { + return func(r *http.Request, token string) error { logger.Debugf("Token exchange successful, replacing Authorization header") r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) return nil + } +} - case HeaderStrategyCustom: - if config.ExternalTokenHeaderName == "" { +// createCustomInjector creates an injection function that adds the token to a custom header +func createCustomInjector(headerName string) injectionFunc { + // Validate header name at creation time + if headerName == "" { + return func(_ *http.Request, _ string) error { return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom) } - logger.Debugf("Token exchange successful, adding token to custom header: %s", config.ExternalTokenHeaderName) - r.Header.Set(config.ExternalTokenHeaderName, fmt.Sprintf("Bearer %s", token)) - return nil + } - default: - return fmt.Errorf("unsupported header_strategy: %s (valid values: '%s', '%s')", - strategy, HeaderStrategyReplace, HeaderStrategyCustom) + return func(r *http.Request, token string) error { + logger.Debugf("Token exchange successful, adding token to custom header: %s", headerName) + r.Header.Set(headerName, fmt.Sprintf("Bearer %s", token)) + return nil } } @@ -149,6 +150,26 @@ func injectToken(r *http.Request, token string, config Config) error { // from the auth middleware to perform token exchange. // This is a public function for direct usage in proxy commands. func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunction { + // Determine injection strategy at startup time + strategy := config.HeaderStrategy + if strategy == "" { + strategy = HeaderStrategyReplace // Default to replace for backwards compatibility + } + + var injectToken injectionFunc + switch strategy { + case HeaderStrategyReplace: + injectToken = createReplaceInjector() + case HeaderStrategyCustom: + injectToken = createCustomInjector(config.ExternalTokenHeaderName) + default: + // For invalid strategies, create a function that returns an error + injectToken = func(_ *http.Request, _ string) error { + return fmt.Errorf("unsupported header_strategy: %s (valid values: '%s', '%s')", + strategy, HeaderStrategyReplace, HeaderStrategyCustom) + } + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get claims from the auth middleware @@ -200,8 +221,8 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunc return } - // Inject the exchanged token into the request - if err := injectToken(r, exchangedToken.AccessToken, config); err != nil { + // Inject the exchanged token into the request using the pre-selected strategy + if err := injectToken(r, exchangedToken.AccessToken); err != nil { logger.Warnf("Failed to inject token: %v", err) http.Error(w, "Token injection failed", http.StatusInternalServerError) return diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go index 481768281..eeaaabf8d 100644 --- a/pkg/auth/tokenexchange/middleware_test.go +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -3,6 +3,7 @@ package tokenexchange import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -181,7 +182,26 @@ func TestInjectToken(t *testing.T) { req.Header.Set("Authorization", tt.originalAuthHeader) } - err := injectToken(req, tt.newToken, tt.config) + // Create the injector function based on the strategy (mimics CreateTokenExchangeMiddlewareFromClaims) + strategy := tt.config.HeaderStrategy + if strategy == "" { + strategy = HeaderStrategyReplace + } + + var injectToken injectionFunc + switch strategy { + case HeaderStrategyReplace: + injectToken = createReplaceInjector() + case HeaderStrategyCustom: + injectToken = createCustomInjector(tt.config.ExternalTokenHeaderName) + default: + injectToken = func(_ *http.Request, _ string) error { + return fmt.Errorf("unsupported header_strategy: %s (valid values: '%s', '%s')", + strategy, HeaderStrategyReplace, HeaderStrategyCustom) + } + } + + err := injectToken(req, tt.newToken) if tt.expectError { require.Error(t, err) From 82f579c23f2c2debcf306307cb7f20fdc2e37d0b Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 8 Oct 2025 16:50:53 +0100 Subject: [PATCH 4/6] review feedback: move exhcnageConfig outside the handler to CreateTokenExchangeMiddlewareFromClaims --- pkg/auth/tokenexchange/middleware.go | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go index f714b0b8b..f500236b1 100644 --- a/pkg/auth/tokenexchange/middleware.go +++ b/pkg/auth/tokenexchange/middleware.go @@ -170,6 +170,16 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunc } } + // Create base exchange config at startup time with all static fields + baseExchangeConfig := ExchangeConfig{ + TokenURL: config.TokenURL, + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + Audience: config.Audience, + Scopes: config.Scopes, + // SubjectTokenProvider will be set per request + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get claims from the auth middleware @@ -200,16 +210,10 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunc logger.Debugf("Performing token exchange for subject: %v", sub) } - // Create RFC-8693 token exchange config with subject token provider - exchangeConfig := &ExchangeConfig{ - TokenURL: config.TokenURL, - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - Audience: config.Audience, - Scopes: config.Scopes, - SubjectTokenProvider: func() (string, error) { - return subjectToken, nil - }, + // Create a copy of the base config with the request-specific subject token + exchangeConfig := baseExchangeConfig + exchangeConfig.SubjectTokenProvider = func() (string, error) { + return subjectToken, nil } // Get token from token source From 20b61d493fc4e1e8c37490474538ca33c828f747 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 8 Oct 2025 17:00:26 +0100 Subject: [PATCH 5/6] throw an error instead of nil in case the middleware is misconfigured --- pkg/auth/tokenexchange/middleware.go | 5 ++--- pkg/auth/tokenexchange/middleware_test.go | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go index f500236b1..6f3a7f15f 100644 --- a/pkg/auth/tokenexchange/middleware.go +++ b/pkg/auth/tokenexchange/middleware.go @@ -79,10 +79,9 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("failed to unmarshal token exchange middleware parameters: %w", err) } - // If no token exchange config provided, skip middleware + // Token exchange config is required when this middleware type is specified if params.TokenExchangeConfig == nil { - logger.Debug("No token exchange config provided, skipping token exchange middleware") - return nil + return fmt.Errorf("token exchange configuration is required but not provided") } // Validate configuration diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go index eeaaabf8d..d8018fedd 100644 --- a/pkg/auth/tokenexchange/middleware_test.go +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -516,11 +516,12 @@ func TestCreateMiddleware(t *testing.T) { expectAddMiddleware: true, }, { - name: "nil config skips middleware creation", + name: "nil config returns error", params: MiddlewareParams{ TokenExchangeConfig: nil, }, - expectError: false, + expectError: true, + errorMsg: "token exchange configuration is required", expectAddMiddleware: false, }, { From 98c4962fc37daf5718c0ce8a4e40047c06c179ad Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 8 Oct 2025 22:44:08 +0100 Subject: [PATCH 6/6] Review feedback: Make CreateTokenExchangeMiddlewareFromClaims return middleware, err --- pkg/auth/tokenexchange/middleware.go | 18 ++++++++++-------- pkg/auth/tokenexchange/middleware_test.go | 9 ++++++--- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go index 6f3a7f15f..1dc245da4 100644 --- a/pkg/auth/tokenexchange/middleware.go +++ b/pkg/auth/tokenexchange/middleware.go @@ -2,6 +2,7 @@ package tokenexchange import ( "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -26,6 +27,8 @@ const ( HeaderStrategyCustom = "custom" ) +var errUnknownStrategy = errors.New("unknown token injection strategy") + // MiddlewareParams represents the parameters for token exchange middleware type MiddlewareParams struct { TokenExchangeConfig *Config `json:"token_exchange_config,omitempty"` @@ -89,7 +92,10 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("invalid token exchange configuration: %w", err) } - middleware := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig) + if err != nil { + return fmt.Errorf("invalid token exchange middleware config: %w", err) + } tokenExchangeMw := &Middleware{ middleware: middleware, @@ -148,7 +154,7 @@ func createCustomInjector(headerName string) injectionFunc { // CreateTokenExchangeMiddlewareFromClaims creates a middleware that uses token claims // from the auth middleware to perform token exchange. // This is a public function for direct usage in proxy commands. -func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunction { +func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFunction, error) { // Determine injection strategy at startup time strategy := config.HeaderStrategy if strategy == "" { @@ -162,11 +168,7 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunc case HeaderStrategyCustom: injectToken = createCustomInjector(config.ExternalTokenHeaderName) default: - // For invalid strategies, create a function that returns an error - injectToken = func(_ *http.Request, _ string) error { - return fmt.Errorf("unsupported header_strategy: %s (valid values: '%s', '%s')", - strategy, HeaderStrategyReplace, HeaderStrategyCustom) - } + return nil, fmt.Errorf("%w: invalid header injection strategy %s", errUnknownStrategy, strategy) } // Create base exchange config at startup time with all static fields @@ -233,5 +235,5 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) types.MiddlewareFunc next.ServeHTTP(w, r) }) - } + }, nil } diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go index d8018fedd..f4fd3439c 100644 --- a/pkg/auth/tokenexchange/middleware_test.go +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -290,7 +290,8 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { ExternalTokenHeaderName: tt.customHeaderName, } - middleware := CreateTokenExchangeMiddlewareFromClaims(config) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + require.NoError(t, err) // Test handler verifies token injection testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -382,7 +383,8 @@ func TestCreateTokenExchangeMiddlewareFromClaims_PassThrough(t *testing.T) { ClientSecret: "test-client-secret", } - middleware := CreateTokenExchangeMiddlewareFromClaims(config) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + require.NoError(t, err) handlerCalled := false testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -469,7 +471,8 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Failures(t *testing.T) { ExternalTokenHeaderName: tt.customHeaderName, } - middleware := CreateTokenExchangeMiddlewareFromClaims(config) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + require.NoError(t, err) testHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { t.Fatal("handler should not be called on failure")