diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index 56bf941..df6cf2d 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -27,10 +27,11 @@ Complete reference for all MCP Auth Proxy configuration options. #### Password Authentication -| Option | Environment Variable | Default | Description | -| ----------------- | -------------------- | ------- | ------------------------------------------------------------------- | -| `--password` | `PASSWORD` | - | Plain text password for authentication (will be hashed with bcrypt) | -| `--password-hash` | `PASSWORD_HASH` | - | Bcrypt hash of password for authentication | +| Option | Environment Variable | Default | Description | +| --------------------------- | ------------------------- | ------- | -------------------------------------------------------------------------------------------- | +| `--no-provider-auto-select` | `NO_PROVIDER_AUTO_SELECT` | `false` | Disable auto-redirect when only one OAuth/OIDC provider is configured and no password is set | +| `--password` | `PASSWORD` | - | Plain text password for authentication (will be hashed with bcrypt) | +| `--password-hash` | `PASSWORD_HASH` | - | Bcrypt hash of password for authentication | #### Google OAuth diff --git a/main.go b/main.go index a369213..6ed74d1 100644 --- a/main.go +++ b/main.go @@ -89,6 +89,7 @@ func main() { var oidcProviderName string var oidcAllowedUsers string var oidcAllowedUsersGlob string + var noProviderAutoSelect bool var password string var passwordHash string var proxyBearerToken string @@ -195,6 +196,7 @@ func main() { oidcProviderName, oidcAllowedUsersList, oidcAllowedUsersGlobList, + noProviderAutoSelect, password, passwordHash, trustedProxiesList, @@ -239,6 +241,7 @@ func main() { rootCmd.Flags().StringVar(&oidcAllowedUsersGlob, "oidc-allowed-users-glob", getEnvWithDefault("OIDC_ALLOWED_USERS_GLOB", ""), "Comma-separated list of glob patterns for allowed OIDC users") // Password authentication + rootCmd.Flags().BoolVar(&noProviderAutoSelect, "no-provider-auto-select", getEnvBoolWithDefault("NO_PROVIDER_AUTO_SELECT", false), "Disable auto-redirect when only one OAuth/OIDC provider is configured and no password is set") rootCmd.Flags().StringVar(&password, "password", getEnvWithDefault("PASSWORD", ""), "Plain text password for authentication (will be hashed with bcrypt)") rootCmd.Flags().StringVar(&passwordHash, "password-hash", getEnvWithDefault("PASSWORD_HASH", ""), "Bcrypt hash of password for authentication") diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index d637331..339ff7e 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -21,9 +21,12 @@ type AuthRouter struct { loginTemplate *template.Template unauthorizedTemplate *template.Template errorTemplate *template.Template + // When true, do not auto-redirect to the sole provider even if + // there is only one provider and no password is set. + noProviderAutoSelect bool } -func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, error) { +func NewAuthRouter(passwordHash []string, noProviderAutoSelect bool, providers ...Provider) (*AuthRouter, error) { tmpl, err := template.ParseFS(templateFS, "templates/login.html") if err != nil { return nil, err @@ -45,6 +48,7 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e loginTemplate: tmpl, unauthorizedTemplate: unauthorizedTmpl, errorTemplate: errorTmpl, + noProviderAutoSelect: noProviderAutoSelect, }, nil } @@ -137,6 +141,11 @@ func (a *AuthRouter) handleLogin(c *gin.Context) { a.handleLoginPost(c) return } + // Auto-redirect to the sole provider if enabled and no password is set + if !a.noProviderAutoSelect && len(a.passwordHash) == 0 && len(a.providers) == 1 { + c.Redirect(http.StatusFound, a.providers[0].AuthURL()) + return + } a.renderLogin(c, "") } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 48b8b49..e1dc8e2 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -53,8 +53,8 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() - // Create AuthRouter - authRouter, err := NewAuthRouter(nil, mockProvider) + // Create AuthRouter (auto-select enabled by default) + authRouter, err := NewAuthRouter(nil, false, mockProvider) require.NoError(t, err) router := setupTestRouter(authRouter) @@ -88,7 +88,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil) // Create AuthRouter - authRouter, err := NewAuthRouter(nil, mockProvider) + authRouter, err := NewAuthRouter(nil, false, mockProvider) require.NoError(t, err) router := setupTestRouter(authRouter) @@ -149,7 +149,7 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil) // Create AuthRouter - authRouter, err := NewAuthRouter(nil, mockProvider) + authRouter, err := NewAuthRouter(nil, false, mockProvider) require.NoError(t, err) router := setupTestRouter(authRouter) @@ -187,3 +187,95 @@ func TestAuthenticationFlow(t *testing.T) { require.Equal(t, "/.auth/login", location) }) } + +func TestLoginAutoRedirect(t *testing.T) { + t.Run("Auto-redirects when single provider and no password", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockProvider := NewMockProvider(ctrl) + mockProvider.EXPECT().Name().Return("test").AnyTimes() + mockProvider.EXPECT().Type().Return("test").AnyTimes() + mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() + mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + + authRouter, err := NewAuthRouter(nil, false, mockProvider) + require.NoError(t, err) + + router := gin.New() + store := memstore.NewStore([]byte("test-secret")) + router.Use(sessions.Sessions("session", store)) + authRouter.SetupRoutes(router) + + server := httptest.NewServer(router) + defer server.Close() + + client := setupClient() + resp, err := client.Get(server.URL + LoginEndpoint) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusFound, resp.StatusCode) + location := resp.Header.Get("Location") + require.Equal(t, "/.auth/test", location) + }) + + t.Run("Does not redirect when disabled", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockProvider := NewMockProvider(ctrl) + mockProvider.EXPECT().Name().Return("test").AnyTimes() + mockProvider.EXPECT().Type().Return("test").AnyTimes() + mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() + mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + + authRouter, err := NewAuthRouter(nil, true, mockProvider) + require.NoError(t, err) + + router := gin.New() + store := memstore.NewStore([]byte("test-secret")) + router.Use(sessions.Sessions("session", store)) + authRouter.SetupRoutes(router) + + server := httptest.NewServer(router) + defer server.Close() + + client := setupClient() + resp, err := client.Get(server.URL + LoginEndpoint) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Does not redirect when password configured", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockProvider := NewMockProvider(ctrl) + mockProvider.EXPECT().Name().Return("test").AnyTimes() + mockProvider.EXPECT().Type().Return("test").AnyTimes() + mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() + mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + + // Non-empty passwordHash slice disables auto-select + authRouter, err := NewAuthRouter([]string{"dummy"}, false, mockProvider) + require.NoError(t, err) + + router := gin.New() + store := memstore.NewStore([]byte("test-secret")) + router.Use(sessions.Sessions("session", store)) + authRouter.SetupRoutes(router) + + server := httptest.NewServer(router) + defer server.Close() + + client := setupClient() + resp, err := client.Get(server.URL + LoginEndpoint) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + }) +} diff --git a/pkg/idp/idp_test.go b/pkg/idp/idp_test.go index 51ee643..6d220e3 100644 --- a/pkg/idp/idp_test.go +++ b/pkg/idp/idp_test.go @@ -64,7 +64,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str }) // Create auth router and IDP router - authRouter, err := auth.NewAuthRouter([]string{}) + authRouter, err := auth.NewAuthRouter([]string{}, false) require.NoError(t, err) logger, _ := zap.NewDevelopment() diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index c88fbbb..7e4a16f 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -57,6 +57,7 @@ func Run( oidcProviderName string, oidcAllowedUsers []string, oidcAllowedUsersGlob []string, + noProviderAutoSelect bool, password string, passwordHash string, trustedProxy []string, @@ -201,7 +202,7 @@ func Run( passwordHashes = append(passwordHashes, passwordHash) } - authRouter, err := auth.NewAuthRouter(passwordHashes, providers...) + authRouter, err := auth.NewAuthRouter(passwordHashes, noProviderAutoSelect, providers...) if err != nil { return fmt.Errorf("failed to create auth router: %w", err) }