diff --git a/cmd/root.go b/cmd/root.go index 99b6a45f..5629c690 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "os" "strings" "tinyauth/internal/bootstrap" "tinyauth/internal/config" @@ -14,15 +15,16 @@ import ( ) type rootCmd struct { - root *cobra.Command - cmd *cobra.Command - - viper *viper.Viper + root *cobra.Command + cmd *cobra.Command + viper *viper.Viper + aclFlags map[string]string } func newRootCmd() *rootCmd { return &rootCmd{ - viper: viper.New(), + viper: viper.New(), + aclFlags: make(map[string]string), } } @@ -32,6 +34,9 @@ func (c *rootCmd) Register() { Short: "The simplest way to protect your apps with a login screen", Long: `Tinyauth is a simple authentication middleware that adds a simple login screen or OAuth with Google, Github or any other provider to all of your docker apps.`, Run: c.run, + FParseErrWhitelist: cobra.FParseErrWhitelist{ + UnknownFlags: true, + }, } c.viper.AutomaticEnv() @@ -116,7 +121,7 @@ func (c *rootCmd) run(cmd *cobra.Command, args []string) { log.Warn().Msg("Log level set to trace, this will log sensitive information!") } - app := bootstrap.NewBootstrapApp(conf) + app := bootstrap.NewBootstrapApp(conf, c.aclFlags) err = app.Setup() if err != nil { @@ -126,6 +131,8 @@ func (c *rootCmd) run(cmd *cobra.Command, args []string) { func Run() { rootCmd := newRootCmd() + rootCmd.aclFlags = utils.ExtractACLFlags(os.Args[1:]) + rootCmd.Register() root := rootCmd.GetCmd() diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index fdbd3827..583bef48 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -37,13 +37,15 @@ type Service interface { } type BootstrapApp struct { - config config.Config - uuid string + config config.Config + aclFlags map[string]string + uuid string } -func NewBootstrapApp(config config.Config) *BootstrapApp { +func NewBootstrapApp(config config.Config, aclFlags map[string]string) *BootstrapApp { return &BootstrapApp{ - config: config, + config: config, + aclFlags: aclFlags, } } @@ -140,6 +142,7 @@ func (app *BootstrapApp) Setup() error { // Create services dockerService := service.NewDockerService() aclsService := service.NewAccessControlsService(dockerService) + aclsService.SetACLFlags(app.aclFlags) authService := service.NewAuthService(authConfig, dockerService, ldapService, database) oauthBrokerService := service.NewOAuthBrokerService(oauthProviders) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index cde27e50..f5d5bcf1 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -4,70 +4,39 @@ import ( "os" "strings" "tinyauth/internal/config" - "tinyauth/internal/utils/decoders" + "tinyauth/internal/utils" "github.com/rs/zerolog/log" ) type AccessControlsService struct { - docker *DockerService - envACLs config.Apps + docker *DockerService + envACLs config.Apps + aclFlags map[string]string } func NewAccessControlsService(docker *DockerService) *AccessControlsService { return &AccessControlsService{ - docker: docker, + docker: docker, + aclFlags: make(map[string]string), } } -func (acls *AccessControlsService) Init() error { - acls.envACLs = config.Apps{} - env := os.Environ() - appEnvVars := []string{} - - for _, e := range env { - if strings.HasPrefix(e, "TINYAUTH_APPS_") { - appEnvVars = append(appEnvVars, e) - } - } - - err := acls.loadEnvACLs(appEnvVars) - - if err != nil { - return err - } - - return nil +func (acls *AccessControlsService) SetACLFlags(flags map[string]string) { + acls.aclFlags = flags } -func (acls *AccessControlsService) loadEnvACLs(appEnvVars []string) error { - if len(appEnvVars) == 0 { - return nil - } - - envAcls := map[string]string{} - - for _, e := range appEnvVars { - parts := strings.SplitN(e, "=", 2) - if len(parts) != 2 { - continue - } - - // Normalize key, this should use the same normalization logic as in utils/decoders/decoders.go - key := parts[0] - key = strings.ToLower(key) - key = strings.ReplaceAll(key, "_", ".") - value := parts[1] - envAcls[key] = value - } +func (acls *AccessControlsService) Init() error { + env := os.Environ() - apps, err := decoders.DecodeLabels(envAcls) + apps, err := utils.GetACLsConfig(env, acls.aclFlags) if err != nil { return err } acls.envACLs = apps + return nil } diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 76044c95..77e2facc 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -208,3 +208,53 @@ func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[st // Return combined providers return providers, nil } + +func GetACLsConfig(env []string, flagsMap map[string]string) (config.Apps, error) { + apps := config.Apps{Apps: make(map[string]config.App)} + + envMap := make(map[string]string) + + for _, e := range env { + pair := strings.SplitN(e, "=", 2) + if len(pair) == 2 { + envMap[pair[0]] = pair[1] + } + } + + envApps, err := decoders.DecodeACLEnv[config.Apps](envMap, "apps") + + if err != nil { + return config.Apps{}, err + } + + if envApps.Apps != nil { + maps.Copy(apps.Apps, envApps.Apps) + } + + flagApps, err := decoders.DecodeACLFlags[config.Apps](flagsMap, "apps") + + if err != nil { + return config.Apps{}, err + } + + if flagApps.Apps != nil { + maps.Copy(apps.Apps, flagApps.Apps) + } + + return apps, nil +} + +func ExtractACLFlags(args []string) map[string]string { + aclFlags := make(map[string]string) + + for _, arg := range args { + if strings.HasPrefix(arg, "--apps-") || strings.HasPrefix(arg, "--tinyauth-apps-") { + pair := strings.SplitN(arg[2:], "=", 2) + if len(pair) == 2 { + aclFlags[pair[0]] = pair[1] + } + } + } + + return aclFlags +} diff --git a/internal/utils/decoders/decoders.go b/internal/utils/decoders/decoders.go index 28b72fb3..2f0c7ccb 100644 --- a/internal/utils/decoders/decoders.go +++ b/internal/utils/decoders/decoders.go @@ -7,6 +7,71 @@ import ( "github.com/stoewer/go-strcase" ) +func ParsePath(parts []string, idx int, t reflect.Type) []string { + if idx >= len(parts) { + return []string{} + } + + if t.Kind() == reflect.Map { + + if idx >= len(parts) { + return []string{} + } + + elemType := t.Elem() + keyEndIdx := idx + 1 + + if elemType.Kind() == reflect.Struct { + for i := idx + 1; i < len(parts); i++ { + found := false + + for j := 0; j < elemType.NumField(); j++ { + field := elemType.Field(j) + if strings.EqualFold(parts[i], field.Name) { + keyEndIdx = i + found = true + break + } + } + + if found { + break + } + } + } + + keyParts := parts[idx:keyEndIdx] + keyName := strings.ToLower(strings.Join(keyParts, "_")) + + rest := ParsePath(parts, keyEndIdx, elemType) + result := append([]string{keyName}, rest...) + return result + } + + if t.Kind() == reflect.Struct { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type.Kind() == reflect.Map { + rest := ParsePath(parts, idx, field.Type) + if len(rest) > 0 { + return rest + } + } + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if strings.EqualFold(parts[idx], field.Name) { + rest := ParsePath(parts, idx+1, field.Type) + result := append([]string{strings.ToLower(field.Name)}, rest...) + return result + } + } + } + + return []string{} +} + func normalizeKeys[T any](input map[string]string, root string, sep string) map[string]string { knownKeys := getKnownKeys[T]() normalized := make(map[string]string) @@ -74,3 +139,57 @@ func getKnownKeys[T any]() []string { return keys } + +func normalizeACLKeys[T any](input map[string]string, root string, sep string) map[string]string { + normalized := make(map[string]string) + var t T + rootType := reflect.TypeOf(t) + + for k, v := range input { + parts := strings.Split(strings.ToLower(k), sep) + + if len(parts) < 2 { + continue + } + + // Two cases: + // 1. Keys starting with "tinyauth" (env vars): tinyauth_apps_... + // 2. Keys starting with root directly (flags): apps-... + startIdx := 0 + if parts[0] == "tinyauth" { + if len(parts) < 3 { + continue + } + if parts[1] != root { + continue + } + startIdx = 2 // Skip "tinyauth" and root + } else if parts[0] == root { + startIdx = 1 // Skip root only + } else { + continue + } + + if startIdx < len(parts) { + parsedParts := ParsePath(parts[startIdx:], 0, rootType) + + if len(parsedParts) == 0 { + continue + } + + final := "tinyauth." + root + + for _, part := range parsedParts { + if strings.Contains(part, "_") { + final += "." + part + } else { + final += "." + strcase.LowerCamelCase(part) + } + } + + normalized[final] = v + } + } + + return normalized +} diff --git a/internal/utils/decoders/env_decoder.go b/internal/utils/decoders/env_decoder.go index 532ec648..0132adb0 100644 --- a/internal/utils/decoders/env_decoder.go +++ b/internal/utils/decoders/env_decoder.go @@ -17,3 +17,17 @@ func DecodeEnv[T any, C any](env map[string]string, subName string) (T, error) { return result, nil } + +func DecodeACLEnv[T any](env map[string]string, subName string) (T, error) { + var result T + + normalized := normalizeACLKeys[T](env, subName, "_") + + err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName) + + if err != nil { + return result, err + } + + return result, nil +} diff --git a/internal/utils/decoders/flags_decoder.go b/internal/utils/decoders/flags_decoder.go index 0aae2341..72b623e5 100644 --- a/internal/utils/decoders/flags_decoder.go +++ b/internal/utils/decoders/flags_decoder.go @@ -21,6 +21,21 @@ func DecodeFlags[T any, C any](flags map[string]string, subName string) (T, erro return result, nil } +func DecodeACLFlags[T any](flags map[string]string, subName string) (T, error) { + var result T + + filtered := filterFlags(flags) + normalized := normalizeACLKeys[T](filtered, subName, "-") + + err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName) + + if err != nil { + return result, err + } + + return result, nil +} + func filterFlags(flags map[string]string) map[string]string { filtered := make(map[string]string) for k, v := range flags {