Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pkg/auth/tokenexchange/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ func (c clientAuthentication) String() string {
c.ClientID, clientSecret)
}

// Config holds the configuration for token exchange.
type Config struct {
// ExchangeConfig holds the configuration for token exchange.
type ExchangeConfig struct {
// TokenURL is the OAuth 2.0 token endpoint URL
TokenURL string

Expand All @@ -185,8 +185,8 @@ type Config struct {
HTTPClient *http.Client
}

// Validate checks if the Config contains all required fields.
func (c *Config) Validate() error {
// Validate checks if the ExchangeConfig contains all required fields.
func (c *ExchangeConfig) Validate() error {
if c.TokenURL == "" {
return fmt.Errorf("TokenURL is required")
}
Expand All @@ -211,7 +211,7 @@ func (c *Config) Validate() error {
// tokenSource implements oauth2.TokenSource for token exchange.
type tokenSource struct {
ctx context.Context
conf *Config
conf *ExchangeConfig
}

// Token implements oauth2.TokenSource interface.
Expand Down Expand Up @@ -281,7 +281,7 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) {
}

// TokenSource returns an oauth2.TokenSource that performs token exchange.
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
func (c *ExchangeConfig) TokenSource(ctx context.Context) oauth2.TokenSource {
return &tokenSource{
ctx: ctx,
conf: c,
Expand Down
20 changes: 10 additions & 10 deletions pkg/auth/tokenexchange/exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestTokenSource_Token_Success(t *testing.T) {
defer server.Close()

// Create config with test server
config := &Config{
config := &ExchangeConfig{
TokenURL: server.URL,
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down Expand Up @@ -166,7 +166,7 @@ func TestTokenSource_Token_WithRefreshToken(t *testing.T) {
}))
defer server.Close()

config := &Config{
config := &ExchangeConfig{
TokenURL: server.URL,
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down Expand Up @@ -198,7 +198,7 @@ func TestTokenSource_Token_NoExpiry(t *testing.T) {
}))
defer server.Close()

config := &Config{
config := &ExchangeConfig{
TokenURL: server.URL,
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand All @@ -221,7 +221,7 @@ func TestTokenSource_Token_SubjectTokenProviderError(t *testing.T) {
t.Parallel()

providerErr := errors.New("failed to get token from provider")
config := &Config{
config := &ExchangeConfig{
TokenURL: "https://example.com/token",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down Expand Up @@ -251,7 +251,7 @@ func TestTokenSource_Token_ContextCancellation(t *testing.T) {
}))
defer server.Close()

config := &Config{
config := &ExchangeConfig{
TokenURL: server.URL,
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down Expand Up @@ -800,7 +800,7 @@ func TestSubjectTokenProvider_Variants(t *testing.T) {
}))
defer server.Close()

config := &Config{
config := &ExchangeConfig{
TokenURL: server.URL,
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down Expand Up @@ -1036,10 +1036,10 @@ func TestExchangeToken_ScopeArray(t *testing.T) {
}

// TestConfig_TokenSource tests that TokenSource creates a valid tokenSource.
func TestConfig_TokenSource(t *testing.T) {
func TestExchangeConfig_TokenSource(t *testing.T) {
t.Parallel()

config := &Config{
config := &ExchangeConfig{
TokenURL: "https://example.com/token",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down Expand Up @@ -1175,14 +1175,14 @@ func TestClientAuthentication_Fields(t *testing.T) {
}

// TestConfig_Fields tests Config struct fields.
func TestConfig_Fields(t *testing.T) {
func TestExchangeConfig_Fields(t *testing.T) {
t.Parallel()

provider := func() (string, error) {
return "token", nil
}

config := &Config{
config := &ExchangeConfig{
TokenURL: "https://example.com/token",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Expand Down
239 changes: 239 additions & 0 deletions pkg/auth/tokenexchange/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package tokenexchange

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"

"github.com/golang-jwt/jwt/v5"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/transport/types"
)

// Middleware type constant
const (
MiddlewareType = "tokenexchange"
)

// Header injection strategy constants
const (
// HeaderStrategyReplace replaces the Authorization header with the exchanged token
HeaderStrategyReplace = "replace"
// HeaderStrategyCustom adds the exchanged token to a custom header
HeaderStrategyCustom = "custom"
)

var errUnknownStrategy = errors.New("unknown token injection strategy")

// MiddlewareParams represents the parameters for token exchange middleware
type MiddlewareParams struct {
TokenExchangeConfig *Config `json:"token_exchange_config,omitempty"`
}

// Config holds configuration for token exchange middleware
type Config struct {
// TokenURL is the OAuth 2.0 token endpoint URL
TokenURL string `json:"token_url"`

// ClientID is the OAuth 2.0 client identifier
ClientID string `json:"client_id"`

// ClientSecret is the OAuth 2.0 client secret
ClientSecret string `json:"client_secret"`

// Audience is the target audience for the exchanged token
Audience string `json:"audience"`

// Scopes is the list of scopes to request for the exchanged token
Scopes []string `json:"scopes,omitempty"`

// HeaderStrategy determines how to inject the token
// Valid values: HeaderStrategyReplace (default), HeaderStrategyCustom
HeaderStrategy string `json:"header_strategy,omitempty"`

// ExternalTokenHeaderName is the name of the custom header to use when HeaderStrategy is "custom"
ExternalTokenHeaderName string `json:"external_token_header_name,omitempty"`
}

// Middleware wraps token exchange middleware functionality
type Middleware struct {
middleware types.MiddlewareFunction
}

// Handler returns the middleware function used by the proxy.
func (m *Middleware) Handler() types.MiddlewareFunction {
return m.middleware
}

// Close cleans up any resources used by the middleware.
func (*Middleware) Close() error {
// Token exchange middleware doesn't need cleanup
return nil
}

// CreateMiddleware factory function for token exchange middleware
func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error {
var params MiddlewareParams
if err := json.Unmarshal(config.Parameters, &params); err != nil {
return fmt.Errorf("failed to unmarshal token exchange middleware parameters: %w", err)
}

// Token exchange config is required when this middleware type is specified
if params.TokenExchangeConfig == nil {
return fmt.Errorf("token exchange configuration is required but not provided")
}

// Validate configuration
if err := validateTokenExchangeConfig(params.TokenExchangeConfig); err != nil {
return fmt.Errorf("invalid token exchange configuration: %w", err)
}

middleware, err := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig)
if err != nil {
return fmt.Errorf("invalid token exchange middleware config: %w", err)
}

tokenExchangeMw := &Middleware{
middleware: middleware,
}

// Add middleware to runner
runner.AddMiddleware(tokenExchangeMw)

return nil
}

// validateTokenExchangeConfig validates the token exchange configuration
func validateTokenExchangeConfig(config *Config) error {
if config.HeaderStrategy == HeaderStrategyCustom && config.ExternalTokenHeaderName == "" {
return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom)
}

if config.HeaderStrategy != "" &&
config.HeaderStrategy != HeaderStrategyReplace &&
config.HeaderStrategy != HeaderStrategyCustom {
return fmt.Errorf("invalid header_strategy: %s (valid values: '%s', '%s')",
config.HeaderStrategy, HeaderStrategyReplace, HeaderStrategyCustom)
}

return nil
}

// injectionFunc is a function that injects a token into an HTTP request
type injectionFunc func(*http.Request, string) error

// createReplaceInjector creates an injection function that replaces the Authorization header
func createReplaceInjector() injectionFunc {
return func(r *http.Request, token string) error {
logger.Debugf("Token exchange successful, replacing Authorization header")
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return nil
}
}

// createCustomInjector creates an injection function that adds the token to a custom header
func createCustomInjector(headerName string) injectionFunc {
// Validate header name at creation time
if headerName == "" {
return func(_ *http.Request, _ string) error {
return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom)
}
}

return func(r *http.Request, token string) error {
logger.Debugf("Token exchange successful, adding token to custom header: %s", headerName)
r.Header.Set(headerName, fmt.Sprintf("Bearer %s", token))
return nil
}
}

// CreateTokenExchangeMiddlewareFromClaims creates a middleware that uses token claims
// from the auth middleware to perform token exchange.
// This is a public function for direct usage in proxy commands.
func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFunction, error) {
// Determine injection strategy at startup time
strategy := config.HeaderStrategy
if strategy == "" {
strategy = HeaderStrategyReplace // Default to replace for backwards compatibility
}

var injectToken injectionFunc
switch strategy {
case HeaderStrategyReplace:
injectToken = createReplaceInjector()
case HeaderStrategyCustom:
injectToken = createCustomInjector(config.ExternalTokenHeaderName)
default:
return nil, fmt.Errorf("%w: invalid header injection strategy %s", errUnknownStrategy, strategy)
}

// Create base exchange config at startup time with all static fields
baseExchangeConfig := ExchangeConfig{
TokenURL: config.TokenURL,
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Audience: config.Audience,
Scopes: config.Scopes,
// SubjectTokenProvider will be set per request
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get claims from the auth middleware
claims, ok := r.Context().Value(auth.ClaimsContextKey{}).(jwt.MapClaims)
if !ok {
logger.Debug("No claims found in context, proceeding without token exchange")
next.ServeHTTP(w, r)
return
}

// Extract the original token from the Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
logger.Debug("No valid Bearer token found, proceeding without token exchange")
next.ServeHTTP(w, r)
return
}

subjectToken := strings.TrimPrefix(authHeader, "Bearer ")
if subjectToken == "" {
logger.Debug("Empty Bearer token, proceeding without token exchange")
next.ServeHTTP(w, r)
return
}

// Log some claim information for debugging
if sub, exists := claims["sub"]; exists {
logger.Debugf("Performing token exchange for subject: %v", sub)
}

// Create a copy of the base config with the request-specific subject token
exchangeConfig := baseExchangeConfig
exchangeConfig.SubjectTokenProvider = func() (string, error) {
return subjectToken, nil
}

// Get token from token source
tokenSource := exchangeConfig.TokenSource(r.Context())
exchangedToken, err := tokenSource.Token()
if err != nil {
logger.Warnf("Token exchange failed: %v", err)
http.Error(w, "Token exchange failed", http.StatusUnauthorized)
return
}

// Inject the exchanged token into the request using the pre-selected strategy
if err := injectToken(r, exchangedToken.AccessToken); err != nil {
logger.Warnf("Failed to inject token: %v", err)
http.Error(w, "Token injection failed", http.StatusInternalServerError)
return
}

next.ServeHTTP(w, r)
})
}, nil
}
Loading