Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions internal/auth/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ import (
func NewAuthMiddleware(
ctx context.Context,
cfg *config.AuthConfig,
validatorFactory ValidatorFactory,
factory validatorFactory,
) (func(http.Handler) http.Handler, http.Handler, error) {
// Handle nil config - defaults to anonymous
// TODO: switch to non-anonymous once the whole branch is merged
if cfg == nil {
logger.Infof("auth: anonymous mode (no auth config)")
return AnonymousMiddleware, nil, nil
return anonymousMiddleware, nil, nil
}

switch cfg.Mode {
case config.AuthModeAnonymous, "":
logger.Infof("auth: anonymous mode")
return AnonymousMiddleware, nil, nil
return anonymousMiddleware, nil, nil
case config.AuthModeOAuth:
return createOAuthMiddleware(ctx, cfg, validatorFactory)
return createOAuthMiddleware(ctx, cfg, factory)
default:
return nil, nil, fmt.Errorf("unsupported auth mode: %s", cfg.Mode)
}
Expand All @@ -41,7 +41,7 @@ func NewAuthMiddleware(
func createOAuthMiddleware(
ctx context.Context,
cfg *config.AuthConfig,
validatorFactory ValidatorFactory,
factory validatorFactory,
) (func(http.Handler) http.Handler, http.Handler, error) {
if cfg.OAuth == nil {
return nil, nil, errors.New("oauth configuration is required for oauth mode")
Expand Down Expand Up @@ -75,7 +75,7 @@ func createOAuthMiddleware(
issuerURLs[i] = p.IssuerURL
}

m, err := NewMultiProviderMiddleware(ctx, providers, oauth.ResourceURL, oauth.Realm, validatorFactory)
m, err := newMultiProviderMiddleware(ctx, providers, oauth.ResourceURL, oauth.Realm, factory)
if err != nil {
return nil, nil, fmt.Errorf("failed to create multi-provider middleware: %w", err)
}
Expand All @@ -91,7 +91,7 @@ func createOAuthMiddleware(
return m.Middleware, handler, nil
}

// AnonymousMiddleware is a no-op middleware that passes requests through without authentication.
func AnonymousMiddleware(next http.Handler) http.Handler {
// anonymousMiddleware is a no-op middleware that passes requests through without authentication.
func anonymousMiddleware(next http.Handler) http.Handler {
return next
}
16 changes: 8 additions & 8 deletions internal/auth/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ func TestNewAuthMiddleware(t *testing.T) {
ctrl := gomock.NewController(t)

// Mock validator factory for OAuth tests
mockValidatorFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) {
return mocks.NewMockTokenValidatorInterface(ctrl), nil
mockValidatorFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) {
return mocks.NewMocktokenValidatorInterface(ctrl), nil
}

tests := []struct {
name string
config *config.AuthConfig
validatorFactory ValidatorFactory
validatorFactory validatorFactory
wantErr string
wantHandler bool // whether handler should be non-nil
}{
Expand Down Expand Up @@ -119,9 +119,9 @@ func TestNewAuthMiddleware_ClientSecretFile(t *testing.T) {

ctrl := gomock.NewController(t)
var capturedConfig thvauth.TokenValidatorConfig
capturingFactory := func(_ context.Context, cfg thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) {
capturingFactory := func(_ context.Context, cfg thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) {
capturedConfig = cfg
return mocks.NewMockTokenValidatorInterface(ctrl), nil
return mocks.NewMocktokenValidatorInterface(ctrl), nil
}

// Create temp file with secret
Expand Down Expand Up @@ -154,8 +154,8 @@ func TestNewAuthMiddleware_ClientSecretFile(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
mockFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) {
return mocks.NewMockTokenValidatorInterface(ctrl), nil
mockFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) {
return mocks.NewMocktokenValidatorInterface(ctrl), nil
}

cfg := &config.AuthConfig{
Expand Down Expand Up @@ -187,7 +187,7 @@ func TestAnonymousMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK)
})

wrapped := AnonymousMiddleware(handler)
wrapped := anonymousMiddleware(handler)

req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
Expand Down
8 changes: 4 additions & 4 deletions internal/auth/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ type middlewareTestConfig struct {
}

// setupMiddleware creates a multiProviderMiddleware with standard test configuration.
func setupMiddleware(t *testing.T, jwksServer *testJWKSServer, cfg middlewareTestConfig) *MultiProviderMiddleware {
func setupMiddleware(t *testing.T, jwksServer *testJWKSServer, cfg middlewareTestConfig) *multiProviderMiddleware {
t.Helper()

ctx := context.Background()
Expand All @@ -148,7 +148,7 @@ func setupMiddleware(t *testing.T, jwksServer *testJWKSServer, cfg middlewareTes
},
}

middleware, err := NewMultiProviderMiddleware(ctx, providers, cfg.resourceURL, cfg.realm, DefaultValidatorFactory)
middleware, err := newMultiProviderMiddleware(ctx, providers, cfg.resourceURL, cfg.realm, DefaultValidatorFactory)
require.NoError(t, err)

return middleware
Expand All @@ -167,7 +167,7 @@ func executeAuthRequest(t *testing.T, handler http.Handler, path, token string)
}

// setupMultiProviderMiddleware creates a multiProviderMiddleware with multiple test JWKS servers.
func setupMultiProviderMiddleware(t *testing.T, servers ...*testJWKSServer) *MultiProviderMiddleware {
func setupMultiProviderMiddleware(t *testing.T, servers ...*testJWKSServer) *multiProviderMiddleware {
t.Helper()

ctx := context.Background()
Expand All @@ -185,7 +185,7 @@ func setupMultiProviderMiddleware(t *testing.T, servers ...*testJWKSServer) *Mul
}
}

middleware, err := NewMultiProviderMiddleware(ctx, providers, "", "", DefaultValidatorFactory)
middleware, err := newMultiProviderMiddleware(ctx, providers, "", "", DefaultValidatorFactory)
require.NoError(t, err)

return middleware
Expand Down
99 changes: 45 additions & 54 deletions internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,104 +13,95 @@ import (
"github.com/stacklok/toolhive/pkg/logger"
)

// Domain errors for authentication
var (
// ErrAllProvidersFailed indicates all providers failed during sequential fallback
ErrAllProvidersFailed = errors.New("all providers failed to validate token")

// ErrMissingToken indicates the authorization header is missing
ErrMissingToken = errors.New("authorization header missing")

// ErrInvalidTokenFormat indicates the token format is invalid (not Bearer)
ErrInvalidTokenFormat = errors.New("invalid bearer token format")
)
// errAllProvidersFailed indicates all providers failed during sequential fallback
var errAllProvidersFailed = errors.New("all providers failed to validate token")

// RFC 6750 Section 3 error codes
const (
// ErrorCodeInvalidRequest indicates the request is missing a required parameter,
// errorCodeInvalidRequest indicates the request is missing a required parameter,
// includes an unsupported parameter or parameter value, or is otherwise malformed.
ErrorCodeInvalidRequest = "invalid_request"
errorCodeInvalidRequest = "invalid_request"

// ErrorCodeInvalidToken indicates the access token provided is expired, revoked,
// errorCodeInvalidToken indicates the access token provided is expired, revoked,
// malformed, or invalid for other reasons.
ErrorCodeInvalidToken = "invalid_token"
errorCodeInvalidToken = "invalid_token"
)

// ValidationResult contains the outcome of token validation
type ValidationResult struct {
// validationResult contains the outcome of token validation
type validationResult struct {
// Provider is the name of the provider that validated the token
Provider string

// Error is set if validation failed
Error error

// Errors contains all errors from sequential fallback (for debugging)
Errors []ProviderError
Errors []providerError
}

// ProviderError pairs a provider name with its validation error
type ProviderError struct {
// providerError pairs a provider name with its validation error
type providerError struct {
Provider string
Error error
}

// NamedValidator pairs a validator with its provider metadata
type NamedValidator struct {
// namedValidator pairs a validator with its provider metadata
type namedValidator struct {
Name string
Validator TokenValidatorInterface
Validator tokenValidatorInterface
}

// DefaultRealm is the default protection space identifier
const DefaultRealm = "mcp-registry"
// defaultRealm is the default protection space identifier
const defaultRealm = "mcp-registry"

// ValidatorFactory creates token validators from configuration.
type ValidatorFactory func(ctx context.Context, cfg auth.TokenValidatorConfig) (TokenValidatorInterface, error)
// validatorFactory creates token validators from configuration.
type validatorFactory func(ctx context.Context, cfg auth.TokenValidatorConfig) (tokenValidatorInterface, error)

// DefaultValidatorFactory uses the real ToolHive token validator.
var DefaultValidatorFactory ValidatorFactory = func(
var DefaultValidatorFactory validatorFactory = func(
ctx context.Context,
cfg auth.TokenValidatorConfig,
) (TokenValidatorInterface, error) {
) (tokenValidatorInterface, error) {
return auth.NewTokenValidator(ctx, cfg)
}

// MultiProviderMiddleware handles authentication with multiple OAuth/OIDC providers.
type MultiProviderMiddleware struct {
validators []NamedValidator
// multiProviderMiddleware handles authentication with multiple OAuth/OIDC providers.
type multiProviderMiddleware struct {
validators []namedValidator
resourceURL string
realm string
}

// NewMultiProviderMiddleware creates a new multi-provider authentication middleware.
func NewMultiProviderMiddleware(
// newMultiProviderMiddleware creates a new multi-provider authentication middleware.
func newMultiProviderMiddleware(
ctx context.Context,
providers []providerConfig,
resourceURL string,
realm string,
validatorFactory ValidatorFactory,
) (*MultiProviderMiddleware, error) {
factory validatorFactory,
) (*multiProviderMiddleware, error) {
if len(providers) == 0 {
return nil, errors.New("at least one provider must be configured")
}

// Apply default realm if not specified
if realm == "" {
realm = DefaultRealm
realm = defaultRealm
}

m := &MultiProviderMiddleware{
validators: make([]NamedValidator, 0, len(providers)),
m := &multiProviderMiddleware{
validators: make([]namedValidator, 0, len(providers)),
resourceURL: resourceURL,
realm: realm,
}

for _, pc := range providers {
validator, err := validatorFactory(ctx, pc.ValidatorConfig)
validator, err := factory(ctx, pc.ValidatorConfig)
if err != nil {
return nil, fmt.Errorf("failed to create validator for provider %q: %w", pc.Name, err)
}

nv := NamedValidator{
nv := namedValidator{
Name: pc.Name,
Validator: validator,
}
Expand All @@ -121,19 +112,19 @@ func NewMultiProviderMiddleware(
}

// Middleware returns an HTTP middleware function that performs authentication.
func (m *MultiProviderMiddleware) Middleware(next http.Handler) http.Handler {
func (m *multiProviderMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := auth.ExtractBearerToken(r)
if err != nil {
logger.Debugf("auth: token extraction failed: %v", err)
m.writeError(w, http.StatusUnauthorized, ErrorCodeInvalidRequest, "missing or malformed authorization header")
m.writeError(w, http.StatusUnauthorized, errorCodeInvalidRequest, "missing or malformed authorization header")
return
}

result := m.validateToken(r.Context(), token)
if result.Error != nil {
logger.Debugf("auth: token validation failed: %v", result.Error)
m.writeError(w, http.StatusUnauthorized, ErrorCodeInvalidToken, "token validation failed")
m.writeError(w, http.StatusUnauthorized, errorCodeInvalidToken, "token validation failed")
return
}

Expand All @@ -144,28 +135,28 @@ func (m *MultiProviderMiddleware) Middleware(next http.Handler) http.Handler {
}

// validateToken attempts to validate the token by iterating through providers sequentially.
func (m *MultiProviderMiddleware) validateToken(ctx context.Context, token string) ValidationResult {
providerErrors := make([]ProviderError, 0, len(m.validators))
func (m *multiProviderMiddleware) validateToken(ctx context.Context, token string) validationResult {
providerErrors := make([]providerError, 0, len(m.validators))

for _, nv := range m.validators {
_, err := nv.Validator.ValidateToken(ctx, token)
if err != nil {
providerErrors = append(providerErrors, ProviderError{
providerErrors = append(providerErrors, providerError{
Provider: nv.Name,
Error: err,
})
logger.Debugf("auth: provider %q failed to validate token: %v", nv.Name, err)
continue
}

return ValidationResult{
return validationResult{
Provider: nv.Name,
Errors: providerErrors,
}
}

return ValidationResult{
Error: ErrAllProvidersFailed,
return validationResult{
Error: errAllProvidersFailed,
Errors: providerErrors,
}
}
Expand All @@ -186,8 +177,8 @@ func sanitizeHeaderValue(s string) string {
}

// writeError writes a JSON error response with RFC 6750 compliant WWW-Authenticate header.
// The errorCode parameter should be one of the RFC 6750 error codes (invalid_request, invalid_token).
func (m *MultiProviderMiddleware) writeError(w http.ResponseWriter, status int, errorCode, description string) {
// The errCode parameter should be one of the RFC 6750 error codes (invalid_request, invalid_token).
func (m *multiProviderMiddleware) writeError(w http.ResponseWriter, status int, errCode, description string) {
w.Header().Set("Content-Type", "application/json")

// Sanitize values to prevent header injection
Expand All @@ -197,11 +188,11 @@ func (m *MultiProviderMiddleware) writeError(w http.ResponseWriter, status int,

// Build WWW-Authenticate header with error codes per RFC 6750 Section 3
wwwAuth := fmt.Sprintf(`Bearer realm="%s", error="%s", error_description="%s"`,
realm, errorCode, sanitizedDescription)
realm, errCode, sanitizedDescription)
if resourceURL != "" {
wwwAuth = fmt.Sprintf(
`Bearer realm="%s", error="%s", error_description="%s", resource_metadata="%s/.well-known/oauth-protected-resource"`,
realm, errorCode, sanitizedDescription, resourceURL)
realm, errCode, sanitizedDescription, resourceURL)
}
w.Header().Set("WWW-Authenticate", wwwAuth)
w.WriteHeader(status)
Expand Down
Loading
Loading