diff --git a/internal/auth/factory.go b/internal/auth/factory.go index 27d1bd5..b2b7cc2 100644 --- a/internal/auth/factory.go +++ b/internal/auth/factory.go @@ -17,21 +17,21 @@ import ( func NewAuthMiddleware( ctx context.Context, cfg *config.AuthConfig, - validatorFactory ValidatorFactory, + factory validatorFactory, ) (func(http.Handler) http.Handler, http.Handler, error) { // Handle nil config - defaults to anonymous // TODO: switch to non-anonymous once the whole branch is merged if cfg == nil { logger.Infof("auth: anonymous mode (no auth config)") - return AnonymousMiddleware, nil, nil + return anonymousMiddleware, nil, nil } switch cfg.Mode { case config.AuthModeAnonymous, "": logger.Infof("auth: anonymous mode") - return AnonymousMiddleware, nil, nil + return anonymousMiddleware, nil, nil case config.AuthModeOAuth: - return createOAuthMiddleware(ctx, cfg, validatorFactory) + return createOAuthMiddleware(ctx, cfg, factory) default: return nil, nil, fmt.Errorf("unsupported auth mode: %s", cfg.Mode) } @@ -41,7 +41,7 @@ func NewAuthMiddleware( func createOAuthMiddleware( ctx context.Context, cfg *config.AuthConfig, - validatorFactory ValidatorFactory, + factory validatorFactory, ) (func(http.Handler) http.Handler, http.Handler, error) { if cfg.OAuth == nil { return nil, nil, errors.New("oauth configuration is required for oauth mode") @@ -75,7 +75,7 @@ func createOAuthMiddleware( issuerURLs[i] = p.IssuerURL } - m, err := NewMultiProviderMiddleware(ctx, providers, oauth.ResourceURL, oauth.Realm, validatorFactory) + m, err := newMultiProviderMiddleware(ctx, providers, oauth.ResourceURL, oauth.Realm, factory) if err != nil { return nil, nil, fmt.Errorf("failed to create multi-provider middleware: %w", err) } @@ -91,7 +91,7 @@ func createOAuthMiddleware( return m.Middleware, handler, nil } -// AnonymousMiddleware is a no-op middleware that passes requests through without authentication. -func AnonymousMiddleware(next http.Handler) http.Handler { +// anonymousMiddleware is a no-op middleware that passes requests through without authentication. +func anonymousMiddleware(next http.Handler) http.Handler { return next } diff --git a/internal/auth/factory_test.go b/internal/auth/factory_test.go index 7603440..270cc95 100644 --- a/internal/auth/factory_test.go +++ b/internal/auth/factory_test.go @@ -23,14 +23,14 @@ func TestNewAuthMiddleware(t *testing.T) { ctrl := gomock.NewController(t) // Mock validator factory for OAuth tests - mockValidatorFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) { - return mocks.NewMockTokenValidatorInterface(ctrl), nil + mockValidatorFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) { + return mocks.NewMocktokenValidatorInterface(ctrl), nil } tests := []struct { name string config *config.AuthConfig - validatorFactory ValidatorFactory + validatorFactory validatorFactory wantErr string wantHandler bool // whether handler should be non-nil }{ @@ -119,9 +119,9 @@ func TestNewAuthMiddleware_ClientSecretFile(t *testing.T) { ctrl := gomock.NewController(t) var capturedConfig thvauth.TokenValidatorConfig - capturingFactory := func(_ context.Context, cfg thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) { + capturingFactory := func(_ context.Context, cfg thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) { capturedConfig = cfg - return mocks.NewMockTokenValidatorInterface(ctrl), nil + return mocks.NewMocktokenValidatorInterface(ctrl), nil } // Create temp file with secret @@ -154,8 +154,8 @@ func TestNewAuthMiddleware_ClientSecretFile(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - mockFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) { - return mocks.NewMockTokenValidatorInterface(ctrl), nil + mockFactory := func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) { + return mocks.NewMocktokenValidatorInterface(ctrl), nil } cfg := &config.AuthConfig{ @@ -187,7 +187,7 @@ func TestAnonymousMiddleware(t *testing.T) { w.WriteHeader(http.StatusOK) }) - wrapped := AnonymousMiddleware(handler) + wrapped := anonymousMiddleware(handler) req := httptest.NewRequest("GET", "/test", nil) rr := httptest.NewRecorder() diff --git a/internal/auth/integration_test.go b/internal/auth/integration_test.go index 22e5257..b73cd3e 100644 --- a/internal/auth/integration_test.go +++ b/internal/auth/integration_test.go @@ -131,7 +131,7 @@ type middlewareTestConfig struct { } // setupMiddleware creates a multiProviderMiddleware with standard test configuration. -func setupMiddleware(t *testing.T, jwksServer *testJWKSServer, cfg middlewareTestConfig) *MultiProviderMiddleware { +func setupMiddleware(t *testing.T, jwksServer *testJWKSServer, cfg middlewareTestConfig) *multiProviderMiddleware { t.Helper() ctx := context.Background() @@ -148,7 +148,7 @@ func setupMiddleware(t *testing.T, jwksServer *testJWKSServer, cfg middlewareTes }, } - middleware, err := NewMultiProviderMiddleware(ctx, providers, cfg.resourceURL, cfg.realm, DefaultValidatorFactory) + middleware, err := newMultiProviderMiddleware(ctx, providers, cfg.resourceURL, cfg.realm, DefaultValidatorFactory) require.NoError(t, err) return middleware @@ -167,7 +167,7 @@ func executeAuthRequest(t *testing.T, handler http.Handler, path, token string) } // setupMultiProviderMiddleware creates a multiProviderMiddleware with multiple test JWKS servers. -func setupMultiProviderMiddleware(t *testing.T, servers ...*testJWKSServer) *MultiProviderMiddleware { +func setupMultiProviderMiddleware(t *testing.T, servers ...*testJWKSServer) *multiProviderMiddleware { t.Helper() ctx := context.Background() @@ -185,7 +185,7 @@ func setupMultiProviderMiddleware(t *testing.T, servers ...*testJWKSServer) *Mul } } - middleware, err := NewMultiProviderMiddleware(ctx, providers, "", "", DefaultValidatorFactory) + middleware, err := newMultiProviderMiddleware(ctx, providers, "", "", DefaultValidatorFactory) require.NoError(t, err) return middleware diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 7f22336..93c80a3 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -13,31 +13,22 @@ import ( "github.com/stacklok/toolhive/pkg/logger" ) -// Domain errors for authentication -var ( - // ErrAllProvidersFailed indicates all providers failed during sequential fallback - ErrAllProvidersFailed = errors.New("all providers failed to validate token") - - // ErrMissingToken indicates the authorization header is missing - ErrMissingToken = errors.New("authorization header missing") - - // ErrInvalidTokenFormat indicates the token format is invalid (not Bearer) - ErrInvalidTokenFormat = errors.New("invalid bearer token format") -) +// errAllProvidersFailed indicates all providers failed during sequential fallback +var errAllProvidersFailed = errors.New("all providers failed to validate token") // RFC 6750 Section 3 error codes const ( - // ErrorCodeInvalidRequest indicates the request is missing a required parameter, + // errorCodeInvalidRequest indicates the request is missing a required parameter, // includes an unsupported parameter or parameter value, or is otherwise malformed. - ErrorCodeInvalidRequest = "invalid_request" + errorCodeInvalidRequest = "invalid_request" - // ErrorCodeInvalidToken indicates the access token provided is expired, revoked, + // errorCodeInvalidToken indicates the access token provided is expired, revoked, // malformed, or invalid for other reasons. - ErrorCodeInvalidToken = "invalid_token" + errorCodeInvalidToken = "invalid_token" ) -// ValidationResult contains the outcome of token validation -type ValidationResult struct { +// validationResult contains the outcome of token validation +type validationResult struct { // Provider is the name of the provider that validated the token Provider string @@ -45,72 +36,72 @@ type ValidationResult struct { Error error // Errors contains all errors from sequential fallback (for debugging) - Errors []ProviderError + Errors []providerError } -// ProviderError pairs a provider name with its validation error -type ProviderError struct { +// providerError pairs a provider name with its validation error +type providerError struct { Provider string Error error } -// NamedValidator pairs a validator with its provider metadata -type NamedValidator struct { +// namedValidator pairs a validator with its provider metadata +type namedValidator struct { Name string - Validator TokenValidatorInterface + Validator tokenValidatorInterface } -// DefaultRealm is the default protection space identifier -const DefaultRealm = "mcp-registry" +// defaultRealm is the default protection space identifier +const defaultRealm = "mcp-registry" -// ValidatorFactory creates token validators from configuration. -type ValidatorFactory func(ctx context.Context, cfg auth.TokenValidatorConfig) (TokenValidatorInterface, error) +// validatorFactory creates token validators from configuration. +type validatorFactory func(ctx context.Context, cfg auth.TokenValidatorConfig) (tokenValidatorInterface, error) // DefaultValidatorFactory uses the real ToolHive token validator. -var DefaultValidatorFactory ValidatorFactory = func( +var DefaultValidatorFactory validatorFactory = func( ctx context.Context, cfg auth.TokenValidatorConfig, -) (TokenValidatorInterface, error) { +) (tokenValidatorInterface, error) { return auth.NewTokenValidator(ctx, cfg) } -// MultiProviderMiddleware handles authentication with multiple OAuth/OIDC providers. -type MultiProviderMiddleware struct { - validators []NamedValidator +// multiProviderMiddleware handles authentication with multiple OAuth/OIDC providers. +type multiProviderMiddleware struct { + validators []namedValidator resourceURL string realm string } -// NewMultiProviderMiddleware creates a new multi-provider authentication middleware. -func NewMultiProviderMiddleware( +// newMultiProviderMiddleware creates a new multi-provider authentication middleware. +func newMultiProviderMiddleware( ctx context.Context, providers []providerConfig, resourceURL string, realm string, - validatorFactory ValidatorFactory, -) (*MultiProviderMiddleware, error) { + factory validatorFactory, +) (*multiProviderMiddleware, error) { if len(providers) == 0 { return nil, errors.New("at least one provider must be configured") } // Apply default realm if not specified if realm == "" { - realm = DefaultRealm + realm = defaultRealm } - m := &MultiProviderMiddleware{ - validators: make([]NamedValidator, 0, len(providers)), + m := &multiProviderMiddleware{ + validators: make([]namedValidator, 0, len(providers)), resourceURL: resourceURL, realm: realm, } for _, pc := range providers { - validator, err := validatorFactory(ctx, pc.ValidatorConfig) + validator, err := factory(ctx, pc.ValidatorConfig) if err != nil { return nil, fmt.Errorf("failed to create validator for provider %q: %w", pc.Name, err) } - nv := NamedValidator{ + nv := namedValidator{ Name: pc.Name, Validator: validator, } @@ -121,19 +112,19 @@ func NewMultiProviderMiddleware( } // Middleware returns an HTTP middleware function that performs authentication. -func (m *MultiProviderMiddleware) Middleware(next http.Handler) http.Handler { +func (m *multiProviderMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token, err := auth.ExtractBearerToken(r) if err != nil { logger.Debugf("auth: token extraction failed: %v", err) - m.writeError(w, http.StatusUnauthorized, ErrorCodeInvalidRequest, "missing or malformed authorization header") + m.writeError(w, http.StatusUnauthorized, errorCodeInvalidRequest, "missing or malformed authorization header") return } result := m.validateToken(r.Context(), token) if result.Error != nil { logger.Debugf("auth: token validation failed: %v", result.Error) - m.writeError(w, http.StatusUnauthorized, ErrorCodeInvalidToken, "token validation failed") + m.writeError(w, http.StatusUnauthorized, errorCodeInvalidToken, "token validation failed") return } @@ -144,13 +135,13 @@ func (m *MultiProviderMiddleware) Middleware(next http.Handler) http.Handler { } // validateToken attempts to validate the token by iterating through providers sequentially. -func (m *MultiProviderMiddleware) validateToken(ctx context.Context, token string) ValidationResult { - providerErrors := make([]ProviderError, 0, len(m.validators)) +func (m *multiProviderMiddleware) validateToken(ctx context.Context, token string) validationResult { + providerErrors := make([]providerError, 0, len(m.validators)) for _, nv := range m.validators { _, err := nv.Validator.ValidateToken(ctx, token) if err != nil { - providerErrors = append(providerErrors, ProviderError{ + providerErrors = append(providerErrors, providerError{ Provider: nv.Name, Error: err, }) @@ -158,14 +149,14 @@ func (m *MultiProviderMiddleware) validateToken(ctx context.Context, token strin continue } - return ValidationResult{ + return validationResult{ Provider: nv.Name, Errors: providerErrors, } } - return ValidationResult{ - Error: ErrAllProvidersFailed, + return validationResult{ + Error: errAllProvidersFailed, Errors: providerErrors, } } @@ -186,8 +177,8 @@ func sanitizeHeaderValue(s string) string { } // writeError writes a JSON error response with RFC 6750 compliant WWW-Authenticate header. -// The errorCode parameter should be one of the RFC 6750 error codes (invalid_request, invalid_token). -func (m *MultiProviderMiddleware) writeError(w http.ResponseWriter, status int, errorCode, description string) { +// The errCode parameter should be one of the RFC 6750 error codes (invalid_request, invalid_token). +func (m *multiProviderMiddleware) writeError(w http.ResponseWriter, status int, errCode, description string) { w.Header().Set("Content-Type", "application/json") // Sanitize values to prevent header injection @@ -197,11 +188,11 @@ func (m *MultiProviderMiddleware) writeError(w http.ResponseWriter, status int, // Build WWW-Authenticate header with error codes per RFC 6750 Section 3 wwwAuth := fmt.Sprintf(`Bearer realm="%s", error="%s", error_description="%s"`, - realm, errorCode, sanitizedDescription) + realm, errCode, sanitizedDescription) if resourceURL != "" { wwwAuth = fmt.Sprintf( `Bearer realm="%s", error="%s", error_description="%s", resource_metadata="%s/.well-known/oauth-protected-resource"`, - realm, errorCode, sanitizedDescription, resourceURL) + realm, errCode, sanitizedDescription, resourceURL) } w.Header().Set("WWW-Authenticate", wwwAuth) w.WriteHeader(status) diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go index fb23894..4884ec7 100644 --- a/internal/auth/middleware_test.go +++ b/internal/auth/middleware_test.go @@ -30,7 +30,7 @@ func singleProviderConfig() []providerConfig { func TestNewMultiProviderMiddleware_EmptyProviders(t *testing.T) { t.Parallel() - _, err := NewMultiProviderMiddleware(context.Background(), nil, "", "", DefaultValidatorFactory) + _, err := newMultiProviderMiddleware(context.Background(), nil, "", "", DefaultValidatorFactory) require.Error(t, err) assert.Contains(t, err.Error(), "at least one provider must be configured") } @@ -41,7 +41,7 @@ func TestMultiProviderMiddleware_Middleware(t *testing.T) { tests := []struct { name string authHeader string - setupMock func(*mocks.MockTokenValidatorInterface) + setupMock func(*mocks.MocktokenValidatorInterface) wantStatus int wantCalled bool }{ @@ -66,7 +66,7 @@ func TestMultiProviderMiddleware_Middleware(t *testing.T) { { name: "valid token", authHeader: "Bearer valid-token", - setupMock: func(m *mocks.MockTokenValidatorInterface) { + setupMock: func(m *mocks.MocktokenValidatorInterface) { m.EXPECT().ValidateToken(gomock.Any(), "valid-token"). Return(map[string]any{"sub": "user"}, nil) }, @@ -76,7 +76,7 @@ func TestMultiProviderMiddleware_Middleware(t *testing.T) { { name: "invalid token", authHeader: "Bearer bad-token", - setupMock: func(m *mocks.MockTokenValidatorInterface) { + setupMock: func(m *mocks.MocktokenValidatorInterface) { m.EXPECT().ValidateToken(gomock.Any(), "bad-token"). Return(nil, errors.New("validation failed")) }, @@ -90,17 +90,17 @@ func TestMultiProviderMiddleware_Middleware(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - mockValidator := mocks.NewMockTokenValidatorInterface(ctrl) + mockValidator := mocks.NewMocktokenValidatorInterface(ctrl) if tt.setupMock != nil { tt.setupMock(mockValidator) } - m, err := NewMultiProviderMiddleware( + m, err := newMultiProviderMiddleware( context.Background(), singleProviderConfig(), "https://api.example.com", "", - func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) { + func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) { return mockValidator, nil }, ) @@ -138,8 +138,8 @@ func TestMultiProviderMiddleware_SequentialFallback(t *testing.T) { ctrl := gomock.NewController(t) // Two different providers with different issuers - keycloakMock := mocks.NewMockTokenValidatorInterface(ctrl) - googleMock := mocks.NewMockTokenValidatorInterface(ctrl) + keycloakMock := mocks.NewMocktokenValidatorInterface(ctrl) + googleMock := mocks.NewMocktokenValidatorInterface(ctrl) // First provider (Keycloak) fails, second (Google) succeeds gomock.InOrder( @@ -170,14 +170,14 @@ func TestMultiProviderMiddleware_SequentialFallback(t *testing.T) { // Factory returns the correct mock based on call order callIdx := 0 - validatorMocks := []*mocks.MockTokenValidatorInterface{keycloakMock, googleMock} + validatorMocks := []*mocks.MocktokenValidatorInterface{keycloakMock, googleMock} - m, err := NewMultiProviderMiddleware( + m, err := newMultiProviderMiddleware( context.Background(), providers, "", "", - func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) { + func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) { mock := validatorMocks[callIdx] callIdx++ return mock, nil @@ -275,16 +275,16 @@ func TestMultiProviderMiddleware_WWWAuthenticate(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - mockValidator := mocks.NewMockTokenValidatorInterface(ctrl) + mockValidator := mocks.NewMocktokenValidatorInterface(ctrl) mockValidator.EXPECT().ValidateToken(gomock.Any(), gomock.Any()). Return(nil, errors.New("fail")).AnyTimes() - m, err := NewMultiProviderMiddleware( + m, err := newMultiProviderMiddleware( context.Background(), singleProviderConfig(), tt.resourceURL, tt.realm, - func(_ context.Context, _ thvauth.TokenValidatorConfig) (TokenValidatorInterface, error) { + func(_ context.Context, _ thvauth.TokenValidatorConfig) (tokenValidatorInterface, error) { return mockValidator, nil }, ) diff --git a/internal/auth/mocks/mock_validator.go b/internal/auth/mocks/mock_validator.go index 50a7986..7834698 100644 --- a/internal/auth/mocks/mock_validator.go +++ b/internal/auth/mocks/mock_validator.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -destination=mocks/mock_validator.go -package=mocks -source=validator.go TokenValidatorInterface +// mockgen -destination=mocks/mock_validator.go -package=mocks -source=validator.go tokenValidatorInterface // // Package mocks is a generated GoMock package. @@ -17,32 +17,32 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockTokenValidatorInterface is a mock of TokenValidatorInterface interface. -type MockTokenValidatorInterface struct { +// MocktokenValidatorInterface is a mock of tokenValidatorInterface interface. +type MocktokenValidatorInterface struct { ctrl *gomock.Controller - recorder *MockTokenValidatorInterfaceMockRecorder + recorder *MocktokenValidatorInterfaceMockRecorder isgomock struct{} } -// MockTokenValidatorInterfaceMockRecorder is the mock recorder for MockTokenValidatorInterface. -type MockTokenValidatorInterfaceMockRecorder struct { - mock *MockTokenValidatorInterface +// MocktokenValidatorInterfaceMockRecorder is the mock recorder for MocktokenValidatorInterface. +type MocktokenValidatorInterfaceMockRecorder struct { + mock *MocktokenValidatorInterface } -// NewMockTokenValidatorInterface creates a new mock instance. -func NewMockTokenValidatorInterface(ctrl *gomock.Controller) *MockTokenValidatorInterface { - mock := &MockTokenValidatorInterface{ctrl: ctrl} - mock.recorder = &MockTokenValidatorInterfaceMockRecorder{mock} +// NewMocktokenValidatorInterface creates a new mock instance. +func NewMocktokenValidatorInterface(ctrl *gomock.Controller) *MocktokenValidatorInterface { + mock := &MocktokenValidatorInterface{ctrl: ctrl} + mock.recorder = &MocktokenValidatorInterfaceMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTokenValidatorInterface) EXPECT() *MockTokenValidatorInterfaceMockRecorder { +func (m *MocktokenValidatorInterface) EXPECT() *MocktokenValidatorInterfaceMockRecorder { return m.recorder } // ValidateToken mocks base method. -func (m *MockTokenValidatorInterface) ValidateToken(ctx context.Context, token string) (jwt.MapClaims, error) { +func (m *MocktokenValidatorInterface) ValidateToken(ctx context.Context, token string) (jwt.MapClaims, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ValidateToken", ctx, token) ret0, _ := ret[0].(jwt.MapClaims) @@ -51,7 +51,7 @@ func (m *MockTokenValidatorInterface) ValidateToken(ctx context.Context, token s } // ValidateToken indicates an expected call of ValidateToken. -func (mr *MockTokenValidatorInterfaceMockRecorder) ValidateToken(ctx, token any) *gomock.Call { +func (mr *MocktokenValidatorInterfaceMockRecorder) ValidateToken(ctx, token any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockTokenValidatorInterface)(nil).ValidateToken), ctx, token) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MocktokenValidatorInterface)(nil).ValidateToken), ctx, token) } diff --git a/internal/auth/validator.go b/internal/auth/validator.go index 3224f52..9e86b61 100644 --- a/internal/auth/validator.go +++ b/internal/auth/validator.go @@ -1,6 +1,6 @@ package auth -//go:generate mockgen -destination=mocks/mock_validator.go -package=mocks -source=validator.go TokenValidatorInterface +//go:generate mockgen -destination=mocks/mock_validator.go -package=mocks -source=validator.go tokenValidatorInterface import ( "context" @@ -8,8 +8,8 @@ import ( "github.com/golang-jwt/jwt/v5" ) -// TokenValidatorInterface abstracts token validation for testability. +// tokenValidatorInterface abstracts token validation for testability. // This allows mocking the toolhive auth.TokenValidator in tests. -type TokenValidatorInterface interface { +type tokenValidatorInterface interface { ValidateToken(ctx context.Context, token string) (jwt.MapClaims, error) }