diff --git a/.env.example b/.env.example index 42b5b248..cd29653c 100644 --- a/.env.example +++ b/.env.example @@ -91,6 +91,8 @@ TINYAUTH_APPS_name_LDAP_GROUPS= # Comma-separated list of allowed OAuth domains. TINYAUTH_OAUTH_WHITELIST= +# Path to the OAuth whitelist file. +TINYAUTH_OAUTH_WHITELISTFILE= # The OAuth provider to use for automatic redirection. TINYAUTH_OAUTH_AUTOREDIRECT= # OAuth client ID. diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 3879c05e..0290be4e 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -30,6 +30,7 @@ type BootstrapApp struct { redirectCookieName string oauthSessionCookieName string users []config.User + oauthWhitelist []string oauthProviders map[string]config.OAuthServiceConfig configuredProviders []controller.Provider oidcClients []config.OIDCClientConfig @@ -71,6 +72,13 @@ func (app *BootstrapApp) Setup() error { app.context.users = users + oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) + if err != nil { + return err + } + + app.context.oauthWhitelist = oauthWhitelist + // Setup OAuth providers app.context.oauthProviders = app.config.OAuth.Providers diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 9c5806b9..3d6b5711 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -70,7 +70,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er authService := service.NewAuthService(service.AuthServiceConfig{ Users: app.context.users, - OauthWhitelist: app.config.OAuth.Whitelist, + OauthWhitelist: app.context.oauthWhitelist, SessionExpiry: app.config.Auth.SessionExpiry, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SecureCookie: app.config.Auth.SecureCookie, diff --git a/internal/config/config.go b/internal/config/config.go index 1bf64af4..ecfae946 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -159,6 +159,7 @@ type IPConfig struct { type OAuthConfig struct { Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` + WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` } diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go index 8a629adc..d6725b4d 100644 --- a/internal/utils/string_utils.go +++ b/internal/utils/string_utils.go @@ -28,3 +28,41 @@ func CoalesceToString(value any) string { return "" } } + +func ParseNonEmptyLines(contents string) []string { + lines := make([]string, 0) + + for line := range strings.SplitSeq(contents, "\n") { + lineTrimmed := strings.TrimSpace(line) + if lineTrimmed == "" { + continue + } + lines = append(lines, lineTrimmed) + } + + return lines +} + +func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) { + values := make([]string, 0, len(valuesCfg)) + + for _, value := range valuesCfg { + valueTrimmed := strings.TrimSpace(value) + if valueTrimmed == "" { + continue + } + values = append(values, valueTrimmed) + } + + if valuesPath == "" { + return values, nil + } + + contents, err := ReadFile(valuesPath) + if err != nil { + return []string{}, err + } + + values = append(values, ParseNonEmptyLines(contents)...) + return values, nil +} diff --git a/internal/utils/string_utils_test.go b/internal/utils/string_utils_test.go index 1db3bf17..2d03bf01 100644 --- a/internal/utils/string_utils_test.go +++ b/internal/utils/string_utils_test.go @@ -1,6 +1,7 @@ package utils_test import ( + "os" "testing" "github.com/tinyauthapp/tinyauth/internal/utils" @@ -57,3 +58,33 @@ func TestCompileUserEmail(t *testing.T) { // Test with invalid email assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com")) } + +func TestParseNonEmptyLines(t *testing.T) { + lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n") + + assert.DeepEqual(t, []string{"first@example.com", "second@example.com"}, lines) +} + +func TestGetStringList(t *testing.T) { + file, err := os.Create("/tmp/tinyauth_list_test_file") + assert.NilError(t, err) + + _, err = file.WriteString(" third@example.com \n\n fourth@example.com \n") + assert.NilError(t, err) + + err = file.Close() + assert.NilError(t, err) + defer os.Remove("/tmp/tinyauth_list_test_file") + + values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file") + assert.NilError(t, err) + assert.DeepEqual(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values) + + values, err = utils.GetStringList(nil, "") + assert.NilError(t, err) + assert.DeepEqual(t, []string{}, values) + + values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file") + assert.ErrorContains(t, err, "no such file or directory") + assert.DeepEqual(t, []string{}, values) +} diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go index d80c655d..f3d67b5a 100644 --- a/internal/utils/user_utils.go +++ b/internal/utils/user_utils.go @@ -34,32 +34,9 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut } func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) { - var usersStr []string - - if len(usersCfg) == 0 && usersPath == "" { - return []config.User{}, nil - } - - if len(usersCfg) > 0 { - usersStr = append(usersStr, usersCfg...) - } - - if usersPath != "" { - contents, err := ReadFile(usersPath) - - if err != nil { - return []config.User{}, err - } - - lines := strings.SplitSeq(contents, "\n") - - for line := range lines { - lineTrimmed := strings.TrimSpace(line) - if lineTrimmed == "" { - continue - } - usersStr = append(usersStr, lineTrimmed) - } + usersStr, err := GetStringList(usersCfg, usersPath) + if err != nil { + return []config.User{}, err } return ParseUsers(usersStr, userAttributes)