From ace22acdb2bc750dd8c6b3d12304e9342b6cc779 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 13:28:30 +0300 Subject: [PATCH 01/17] wip: add middlewares --- internal/hooks/hooks.go | 144 ---------------------- internal/middleware/context_middleware.go | 143 +++++++++++++++++++++ internal/middleware/ui_middlware.go | 66 ++++++++++ internal/middleware/zerolog_middleware.go | 62 ++++++++++ internal/server/server.go | 81 +----------- internal/types/config.go | 5 - 6 files changed, 276 insertions(+), 225 deletions(-) delete mode 100644 internal/hooks/hooks.go create mode 100644 internal/middleware/context_middleware.go create mode 100644 internal/middleware/ui_middlware.go create mode 100644 internal/middleware/zerolog_middleware.go diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go deleted file mode 100644 index 3083b98e..00000000 --- a/internal/hooks/hooks.go +++ /dev/null @@ -1,144 +0,0 @@ -package hooks - -import ( - "fmt" - "strings" - "tinyauth/internal/auth" - "tinyauth/internal/oauth" - "tinyauth/internal/providers" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -type Hooks struct { - Config types.HooksConfig - Auth *auth.Auth - Providers *providers.Providers -} - -func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { - return &Hooks{ - Config: config, - Auth: auth, - Providers: providers, - } -} - -func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { - cookie, err := hooks.Auth.GetSessionCookie(c) - var provider *oauth.OAuth - - if err != nil { - log.Error().Err(err).Msg("Failed to get session cookie") - goto basic - } - - if cookie.TotpPending { - log.Debug().Msg("Totp pending") - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - TotpPending: true, - } - } - - if cookie.Provider == "username" { - log.Debug().Msg("Provider is username") - - userSearch := hooks.Auth.SearchUser(cookie.Username) - - if userSearch.Type == "unknown" { - log.Warn().Str("username", cookie.Username).Msg("User does not exist") - goto basic - } - - log.Debug().Str("type", userSearch.Type).Msg("User exists") - - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - IsLoggedIn: true, - Provider: "username", - } - } - - log.Debug().Msg("Provider is not username") - - provider = hooks.Providers.GetProvider(cookie.Provider) - - if provider != nil { - log.Debug().Msg("Provider exists") - - if !hooks.Auth.EmailWhitelisted(cookie.Email) { - log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted") - hooks.Auth.DeleteSessionCookie(c) - goto basic - } - - log.Debug().Msg("Email is whitelisted") - - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - IsLoggedIn: true, - OAuth: true, - Provider: cookie.Provider, - OAuthGroups: cookie.OAuthGroups, - } - } - -basic: - log.Debug().Msg("Trying basic auth") - - basic := hooks.Auth.GetBasicAuth(c) - - if basic != nil { - log.Debug().Msg("Got basic auth") - - userSearch := hooks.Auth.SearchUser(basic.Username) - - if userSearch.Type == "unkown" { - log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist") - return types.UserContext{} - } - - if !hooks.Auth.VerifyUser(userSearch, basic.Password) { - log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect") - return types.UserContext{} - } - - if userSearch.Type == "ldap" { - log.Debug().Msg("User is LDAP") - - return types.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), - IsLoggedIn: true, - Provider: "basic", - TotpEnabled: false, - } - } - - user := hooks.Auth.GetLocalUser(basic.Username) - - return types.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), - IsLoggedIn: true, - Provider: "basic", - TotpEnabled: user.TotpSecret != "", - } - - } - - return types.UserContext{} -} diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go new file mode 100644 index 00000000..c4b76823 --- /dev/null +++ b/internal/middleware/context_middleware.go @@ -0,0 +1,143 @@ +package middlewares + +import ( + "fmt" + "strings" + "tinyauth/internal/auth" + "tinyauth/internal/providers" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" +) + +type ContextMiddlewareConfig struct { + Domain string +} + +type ContextMiddleware struct { + Config ContextMiddlewareConfig + Auth *auth.Auth + Providers *providers.Providers +} + +func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware { + return &ContextMiddleware{ + Config: config, + Auth: auth, + Providers: providers, + } +} + +func (m *ContextMiddleware) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + cookie, err := m.Auth.GetSessionCookie(c) + + if err != nil { + goto basic + } + + if cookie.TotpPending { + c.Set("context", &types.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: "username", + TotpPending: true, + TotpEnabled: true, + }) + c.Next() + return + } + + switch cookie.Provider { + case "username": + userSearch := m.Auth.SearchUser(cookie.Username) + + if userSearch.Type == "unknown" { + goto basic + } + + c.Set("context", &types.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: "username", + IsLoggedIn: true, + }) + c.Next() + return + default: + provider := m.Providers.GetProvider(cookie.Provider) + + if provider == nil { + goto basic + } + + if !m.Auth.EmailWhitelisted(cookie.Email) { + m.Auth.DeleteSessionCookie(c) + goto basic + } + + c.Set("context", &types.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: cookie.Provider, + OAuthGroups: cookie.OAuthGroups, + IsLoggedIn: true, + OAuth: true, + }) + c.Next() + return + } + + basic: + basic := m.Auth.GetBasicAuth(c) + + if basic == nil { + c.Next() + return + } + + userSearch := m.Auth.SearchUser(basic.Username) + + if userSearch.Type == "unknown" { + c.Next() + return + } + + if !m.Auth.VerifyUser(userSearch, basic.Password) { + c.Next() + return + } + + switch userSearch.Type { + case "local": + user := m.Auth.GetLocalUser(basic.Username) + + c.Set("context", &types.UserContext{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain), + Provider: "basic", + IsLoggedIn: true, + TotpEnabled: user.TotpSecret != "", + }) + c.Next() + return + case "ldap": + c.Set("context", &types.UserContext{ + Username: basic.Username, + Name: utils.Capitalize(basic.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain), + Provider: "basic", + IsLoggedIn: true, + }) + c.Next() + return + } + + c.Next() + } +} diff --git a/internal/middleware/ui_middlware.go b/internal/middleware/ui_middlware.go new file mode 100644 index 00000000..2a68782e --- /dev/null +++ b/internal/middleware/ui_middlware.go @@ -0,0 +1,66 @@ +package middlewares + +import ( + "io/fs" + "net/http" + "os" + "strings" + "tinyauth/internal/assets" + + "github.com/gin-gonic/gin" +) + +type UIMiddleware struct { + UIFS fs.FS + UIFileServer http.Handler + ResourcesFileServer http.Handler +} + +func NewUIMiddleware() (*UIMiddleware, error) { + ui, err := fs.Sub(assets.Assets, "dist") + + if err != nil { + return nil, err + } + + uiFileServer := http.FileServer(http.FS(ui)) + resourcesFileServer := http.FileServer(http.Dir("/data/resources")) + + return &UIMiddleware{ + UIFS: ui, + UIFileServer: uiFileServer, + ResourcesFileServer: resourcesFileServer, + }, nil +} + +func (m UIMiddleware) Middlware() gin.HandlerFunc { + return func(c *gin.Context) { + switch strings.Split(c.Request.URL.Path, "/")[1] { + case "api": + c.Next() + return + case "resources": + _, err := os.Stat("/data/resources/" + strings.TrimPrefix(c.Request.URL.Path, "/resources/")) + + if os.IsNotExist(err) { + c.Status(404) + c.Abort() + return + } + + m.ResourcesFileServer.ServeHTTP(c.Writer, c.Request) + c.Abort() + return + default: + _, err := fs.Stat(m.UIFS, strings.TrimPrefix(c.Request.URL.Path, "/")) + + if os.IsNotExist(err) { + c.Request.URL.Path = "/" + } + + m.UIFileServer.ServeHTTP(c.Writer, c.Request) + c.Abort() + return + } + } +} diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go new file mode 100644 index 00000000..bca9a12f --- /dev/null +++ b/internal/middleware/zerolog_middleware.go @@ -0,0 +1,62 @@ +package middlewares + +import ( + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +var ( + loggerSkipPathsPrefix = []string{ + "GET /api/healthcheck", + "HEAD /api/healthcheck", + "GET /favicon.ico", + } +) + +type ZerologMiddleware struct{} + +func NewZerologMiddleware() *ZerologMiddleware { + return &ZerologMiddleware{} +} + +func (m ZerologMiddleware) logPath(path string) bool { + for _, prefix := range loggerSkipPathsPrefix { + if strings.HasPrefix(path, prefix) { + return false + } + } + return true +} + +func (m ZerologMiddleware) Middlware() gin.HandlerFunc { + return func(c *gin.Context) { + tStart := time.Now() + + c.Next() + + code := c.Writer.Status() + address := c.Request.RemoteAddr + clientIP := c.ClientIP() + method := c.Request.Method + path := c.Request.URL.Path + + latency := time.Since(tStart).String() + + // logPath check if the path should be logged normally or with debug + if m.logPath(method + " " + path) { + switch { + case code >= 200 && code < 300: + log.Info().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 300 && code < 400: + log.Warn().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 400: + log.Error().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request") + } + } else { + log.Debug().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 88260322..a3820743 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,12 +2,6 @@ package server import ( "fmt" - "io/fs" - "net/http" - "os" - "strings" - "time" - "tinyauth/internal/assets" "tinyauth/internal/handlers" "tinyauth/internal/types" @@ -21,52 +15,17 @@ type Server struct { Router *gin.Engine } -var ( - loggerSkipPathsPrefix = []string{ - "GET /api/healthcheck", - "HEAD /api/healthcheck", - "GET /favicon.ico", - } -) - -func logPath(path string) bool { - for _, prefix := range loggerSkipPathsPrefix { - if strings.HasPrefix(path, prefix) { - return false - } - } - return true +type Middlware interface { + Middlware() gin.HandlerFunc } -func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) { - gin.SetMode(gin.ReleaseMode) - - log.Debug().Msg("Setting up router") +func NewServer(config types.ServerConfig, handlers *handlers.Handlers, middlewares []Middlware) (*Server, error) { router := gin.New() - router.Use(zerolog()) - log.Debug().Msg("Setting up assets") - dist, err := fs.Sub(assets.Assets, "dist") - if err != nil { - return nil, err + for _, middleware := range middlewares { + router.Use(middleware.Middlware()) } - log.Debug().Msg("Setting up file server") - fileServer := http.FileServer(http.FS(dist)) - - // UI middleware - router.Use(func(c *gin.Context) { - // If not an API request, serve the UI - if !strings.HasPrefix(c.Request.URL.Path, "/api") { - _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) - if os.IsNotExist(err) { - c.Request.URL.Path = "/" - } - fileServer.ServeHTTP(c.Writer, c.Request) - c.Abort() - } - }) - // Proxy routes router.GET("/api/auth/:proxy", handlers.ProxyHandler) @@ -98,33 +57,3 @@ func (s *Server) Start() error { log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server") return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port)) } - -// zerolog is a middleware for gin that logs requests using zerolog -func zerolog() gin.HandlerFunc { - return func(c *gin.Context) { - tStart := time.Now() - - c.Next() - - code := c.Writer.Status() - address := c.Request.RemoteAddr - method := c.Request.Method - path := c.Request.URL.Path - - latency := time.Since(tStart).String() - - // logPath check if the path should be logged normally or with debug - if logPath(method + " " + path) { - switch { - case code >= 200 && code < 300: - log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - case code >= 300 && code < 400: - log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - case code >= 400: - log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - } - } else { - log.Debug().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - } - } -} diff --git a/internal/types/config.go b/internal/types/config.go index b53e0536..54ab6c8c 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -95,11 +95,6 @@ type AuthConfig struct { EncryptionSecret string } -// HooksConfig is the configuration for the hooks service -type HooksConfig struct { - Domain string -} - // OAuthLabels is a list of labels that can be used in a tinyauth protected container type OAuthLabels struct { Whitelist string From e1d8ce3cb5468d4ca875040d6a85d7a6fabf0572 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 14:19:52 +0300 Subject: [PATCH 02/17] refactor: use context fom middleware in handlers --- air.toml | 2 +- cmd/root.go | 25 +++++++++++++-------- internal/handlers/context.go | 24 ++++++++++++++++++-- internal/handlers/handlers.go | 22 +++++++++++++----- internal/handlers/proxy.go | 23 ++++++++++++++++--- internal/handlers/user.go | 22 ++++++++++++++++-- internal/middleware/context_middleware.go | 10 ++++++++- internal/middleware/ui_middlware.go | 27 ++++++++++++++--------- internal/middleware/zerolog_middleware.go | 14 +++++++++--- internal/server/server.go | 15 +++++++++---- internal/types/config.go | 15 ------------- 11 files changed, 142 insertions(+), 57 deletions(-) diff --git a/air.toml b/air.toml index 7505b79a..f84163bc 100644 --- a/air.toml +++ b/air.toml @@ -4,7 +4,7 @@ tmp_dir = "tmp" [build] pre_cmd = ["mkdir -p internal/assets/dist", "echo 'backend running' > internal/assets/dist/index.html", "go install github.com/go-delve/delve/cmd/dlv@v1.25.0"] cmd = "CGO_ENABLED=0 go build -gcflags=\"all=-N -l\" -o tmp/tinyauth ." -bin = "/go/bin/dlv --listen :4000 --headless=true --api-version=2 --accept-multiclient --log=true exec tmp/tinyauth --continue" +bin = "/go/bin/dlv --listen :4000 --headless=true --api-version=2 --accept-multiclient --log=true exec tmp/tinyauth --continue --check-go-version=false" include_ext = ["go"] exclude_dir = ["internal/assets/dist"] exclude_regex = [".*_test\\.go"] diff --git a/cmd/root.go b/cmd/root.go index f96ec6bc..927b375a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -10,8 +10,8 @@ import ( "tinyauth/internal/constants" "tinyauth/internal/docker" "tinyauth/internal/handlers" - "tinyauth/internal/hooks" "tinyauth/internal/ldap" + "tinyauth/internal/middleware" "tinyauth/internal/providers" "tinyauth/internal/server" "tinyauth/internal/types" @@ -84,7 +84,7 @@ var rootCmd = &cobra.Command{ AppURL: config.AppURL, } - handlersConfig := types.HandlersConfig{ + handlersConfig := handlers.HandlersConfig{ AppURL: config.AppURL, DisableContinue: config.DisableContinue, Title: config.Title, @@ -116,10 +116,6 @@ var rootCmd = &cobra.Command{ EncryptionSecret: encryptionSecret, } - hooksConfig := types.HooksConfig{ - Domain: domain, - } - var ldapService *ldap.LDAP if config.LdapAddress != "" { @@ -151,9 +147,20 @@ var rootCmd = &cobra.Command{ HandleError(err, "Failed to initialize docker") auth := auth.NewAuth(authConfig, docker, ldapService) providers := providers.NewProviders(oauthConfig) - hooks := hooks.NewHooks(hooksConfig, auth, providers) - handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) - srv, err := server.NewServer(serverConfig, handlers) + handlers := handlers.NewHandlers(handlersConfig, auth, providers, docker) + + // Setup the middlewares + var middlewares []server.Middleware + + contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ + Domain: domain, + }, auth, providers) + uiMiddleware := middleware.NewUIMiddleware() + zerologMiddleware := middleware.NewZerologMiddleware() + + middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) + + srv, err := server.NewServer(serverConfig, handlers, middlewares) HandleError(err, "Failed to create server") // Start up diff --git a/internal/handlers/context.go b/internal/handlers/context.go index d0fff5e5..0bbe3923 100644 --- a/internal/handlers/context.go +++ b/internal/handlers/context.go @@ -37,8 +37,28 @@ func (h *Handlers) AppContextHandler(c *gin.Context) { func (h *Handlers) UserContextHandler(c *gin.Context) { log.Debug().Msg("Getting user context") - // Create user context using hooks - userContext := h.Hooks.UseUserContext(c) + // Get user context from middleware + userContextValue, exists := c.Get("context") + + if !exists { + c.JSON(200, types.UserContextResponse{ + Status: 200, + Message: "Unauthorized", + IsLoggedIn: false, + }) + return + } + + userContext, ok := userContextValue.(*types.UserContext) + + if !ok { + c.JSON(200, types.UserContextResponse{ + Status: 200, + Message: "Unauthorized", + IsLoggedIn: false, + }) + return + } userContextResponse := types.UserContextResponse{ Status: 200, diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 0e8ebe22..e24f7fab 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -3,26 +3,36 @@ package handlers import ( "tinyauth/internal/auth" "tinyauth/internal/docker" - "tinyauth/internal/hooks" "tinyauth/internal/providers" - "tinyauth/internal/types" "github.com/gin-gonic/gin" ) +type HandlersConfig struct { + AppURL string + Domain string + CookieSecure bool + DisableContinue bool + GenericName string + Title string + ForgotPasswordMessage string + BackgroundImage string + OAuthAutoRedirect string + CsrfCookieName string + RedirectCookieName string +} + type Handlers struct { - Config types.HandlersConfig + Config HandlersConfig Auth *auth.Auth - Hooks *hooks.Hooks Providers *providers.Providers Docker *docker.Docker } -func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers, docker *docker.Docker) *Handlers { +func NewHandlers(config HandlersConfig, auth *auth.Auth, providers *providers.Providers, docker *docker.Docker) *Handlers { return &Handlers{ Config: config, Auth: auth, - Hooks: hooks, Providers: providers, Docker: docker, } diff --git a/internal/handlers/proxy.go b/internal/handlers/proxy.go index fd87fd16..c9d234e0 100644 --- a/internal/handlers/proxy.go +++ b/internal/handlers/proxy.go @@ -146,7 +146,24 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { return } - userContext := h.Hooks.UseUserContext(c) + var userContext *types.UserContext + + userContextValue, exists := c.Get("context") + + if !exists { + userContext = &types.UserContext{ + IsLoggedIn: false, + } + } else { + var ok bool + userContext, ok = userContextValue.(*types.UserContext) + + if !ok { + userContext = &types.UserContext{ + IsLoggedIn: false, + } + } + } // If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth if userContext.Provider == "basic" && userContext.TotpEnabled { @@ -158,7 +175,7 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { log.Debug().Msg("Authenticated") // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx - appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) + appAllowed := h.Auth.ResourceAllowed(c, *userContext, labels) log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") @@ -195,7 +212,7 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { } if userContext.OAuth { - groupOk := h.Auth.OAuthGroup(c, userContext, labels) + groupOk := h.Auth.OAuthGroup(c, *userContext, labels) log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") diff --git a/internal/handlers/user.go b/internal/handlers/user.go index 91d0fef5..86a18eea 100644 --- a/internal/handlers/user.go +++ b/internal/handlers/user.go @@ -141,7 +141,25 @@ func (h *Handlers) TOTPHandler(c *gin.Context) { log.Debug().Msg("Checking totp") // Get user context - userContext := h.Hooks.UseUserContext(c) + userContextValue, exists := c.Get("context") + + if !exists { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + userContext, ok := userContextValue.(*types.UserContext) + + if !ok { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } // Check if we have a user if userContext.Username == "" { @@ -157,7 +175,7 @@ func (h *Handlers) TOTPHandler(c *gin.Context) { user := h.Auth.GetLocalUser(userContext.Username) // Check if totp is correct - ok := totp.Validate(totpReq.Code, user.TotpSecret) + ok = totp.Validate(totpReq.Code, user.TotpSecret) if !ok { log.Debug().Msg("Totp incorrect") diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index c4b76823..ead4879b 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -1,4 +1,4 @@ -package middlewares +package middleware import ( "fmt" @@ -29,6 +29,14 @@ func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, provi } } +func (m *ContextMiddleware) Init() error { + return nil +} + +func (m *ContextMiddleware) Name() string { + return "ContextMiddleware" +} + func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { cookie, err := m.Auth.GetSessionCookie(c) diff --git a/internal/middleware/ui_middlware.go b/internal/middleware/ui_middlware.go index 2a68782e..22f8ca23 100644 --- a/internal/middleware/ui_middlware.go +++ b/internal/middleware/ui_middlware.go @@ -1,4 +1,4 @@ -package middlewares +package middleware import ( "io/fs" @@ -16,24 +16,29 @@ type UIMiddleware struct { ResourcesFileServer http.Handler } -func NewUIMiddleware() (*UIMiddleware, error) { +func NewUIMiddleware() *UIMiddleware { + return &UIMiddleware{} +} + +func (m *UIMiddleware) Init() error { ui, err := fs.Sub(assets.Assets, "dist") if err != nil { - return nil, err + return nil } - uiFileServer := http.FileServer(http.FS(ui)) - resourcesFileServer := http.FileServer(http.Dir("/data/resources")) + m.UIFS = ui + m.UIFileServer = http.FileServer(http.FS(ui)) + m.ResourcesFileServer = http.FileServer(http.Dir("/data/resources")) + + return nil +} - return &UIMiddleware{ - UIFS: ui, - UIFileServer: uiFileServer, - ResourcesFileServer: resourcesFileServer, - }, nil +func (m *UIMiddleware) Name() string { + return "UIMiddleware" } -func (m UIMiddleware) Middlware() gin.HandlerFunc { +func (m *UIMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { switch strings.Split(c.Request.URL.Path, "/")[1] { case "api": diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index bca9a12f..79c5d706 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -1,4 +1,4 @@ -package middlewares +package middleware import ( "strings" @@ -22,7 +22,15 @@ func NewZerologMiddleware() *ZerologMiddleware { return &ZerologMiddleware{} } -func (m ZerologMiddleware) logPath(path string) bool { +func (m *ZerologMiddleware) Init() error { + return nil +} + +func (m *ZerologMiddleware) Name() string { + return "ZerologMiddleware" +} + +func (m *ZerologMiddleware) logPath(path string) bool { for _, prefix := range loggerSkipPathsPrefix { if strings.HasPrefix(path, prefix) { return false @@ -31,7 +39,7 @@ func (m ZerologMiddleware) logPath(path string) bool { return true } -func (m ZerologMiddleware) Middlware() gin.HandlerFunc { +func (m *ZerologMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { tStart := time.Now() diff --git a/internal/server/server.go b/internal/server/server.go index a3820743..78150feb 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,15 +15,22 @@ type Server struct { Router *gin.Engine } -type Middlware interface { - Middlware() gin.HandlerFunc +type Middleware interface { + Middleware() gin.HandlerFunc + Init() error + Name() string } -func NewServer(config types.ServerConfig, handlers *handlers.Handlers, middlewares []Middlware) (*Server, error) { +func NewServer(config types.ServerConfig, handlers *handlers.Handlers, middlewares []Middleware) (*Server, error) { router := gin.New() for _, middleware := range middlewares { - router.Use(middleware.Middlware()) + log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") + err := middleware.Init() + if err != nil { + return nil, fmt.Errorf("failed to initialize middleware %s: %w", middleware.Name(), err) + } + router.Use(middleware.Middleware()) } // Proxy routes diff --git a/internal/types/config.go b/internal/types/config.go index 54ab6c8c..4b32ad98 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -44,21 +44,6 @@ type Config struct { LdapSearchFilter string `mapstructure:"ldap-search-filter"` } -// Server configuration -type HandlersConfig struct { - AppURL string - Domain string - CookieSecure bool - DisableContinue bool - GenericName string - Title string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - CsrfCookieName string - RedirectCookieName string -} - // OAuthConfig is the configuration for the providers type OAuthConfig struct { GithubClientId string From dfdc656145a12556c9ad7ea1b9100c237706115f Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 16:40:06 +0300 Subject: [PATCH 03/17] refactor: use controller approach in handlers --- cmd/root.go | 96 ++++-- frontend/src/context/app-context.tsx | 2 +- frontend/src/context/user-context.tsx | 2 +- frontend/src/pages/logout-page.tsx | 2 +- frontend/src/pages/totp-page.tsx | 2 +- go.mod | 1 + go.sum | 2 + internal/controller/context_controller.go | 102 ++++++ internal/controller/health_controller.go | 24 ++ internal/controller/oauth_controller.go | 185 ++++++++++ internal/controller/proxy_controller.go | 281 +++++++++++++++ internal/controller/user_controller.go | 216 ++++++++++++ internal/handlers/context.go | 84 ----- internal/handlers/handlers.go | 46 --- internal/handlers/handlers_test.go | 394 ---------------------- internal/handlers/oauth.go | 223 ------------ internal/handlers/proxy.go | 299 ---------------- internal/handlers/user.go | 215 ------------ internal/server/server.go | 66 ---- internal/types/api.go | 62 ---- internal/types/config.go | 6 - internal/types/types.go | 11 + internal/utils/utils.go | 17 + 23 files changed, 910 insertions(+), 1428 deletions(-) create mode 100644 internal/controller/context_controller.go create mode 100644 internal/controller/health_controller.go create mode 100644 internal/controller/oauth_controller.go create mode 100644 internal/controller/proxy_controller.go create mode 100644 internal/controller/user_controller.go delete mode 100644 internal/handlers/context.go delete mode 100644 internal/handlers/handlers.go delete mode 100644 internal/handlers/handlers_test.go delete mode 100644 internal/handlers/oauth.go delete mode 100644 internal/handlers/proxy.go delete mode 100644 internal/handlers/user.go delete mode 100644 internal/server/server.go delete mode 100644 internal/types/api.go diff --git a/cmd/root.go b/cmd/root.go index 927b375a..8dadd5d0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,22 +8,28 @@ import ( userCmd "tinyauth/cmd/user" "tinyauth/internal/auth" "tinyauth/internal/constants" + "tinyauth/internal/controller" "tinyauth/internal/docker" - "tinyauth/internal/handlers" "tinyauth/internal/ldap" "tinyauth/internal/middleware" "tinyauth/internal/providers" - "tinyauth/internal/server" "tinyauth/internal/types" "tinyauth/internal/utils" - "github.com/go-playground/validator/v10" + "github.com/gin-gonic/gin" + "github.com/go-playground/validator" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/spf13/viper" ) +type Middleware interface { + Middleware() gin.HandlerFunc + Init() error + Name() string +} + var rootCmd = &cobra.Command{ Use: "tinyauth", Short: "The simplest way to protect your apps with a login screen.", @@ -84,25 +90,6 @@ var rootCmd = &cobra.Command{ AppURL: config.AppURL, } - handlersConfig := handlers.HandlersConfig{ - AppURL: config.AppURL, - DisableContinue: config.DisableContinue, - Title: config.Title, - GenericName: config.GenericName, - CookieSecure: config.CookieSecure, - Domain: domain, - ForgotPasswordMessage: config.FogotPasswordMessage, - BackgroundImage: config.BackgroundImage, - OAuthAutoRedirect: config.OAuthAutoRedirect, - CsrfCookieName: csrfCookieName, - RedirectCookieName: redirectCookieName, - } - - serverConfig := types.ServerConfig{ - Port: config.Port, - Address: config.Address, - } - authConfig := types.AuthConfig{ Users: users, OauthWhitelist: config.OAuthWhitelist, @@ -147,10 +134,15 @@ var rootCmd = &cobra.Command{ HandleError(err, "Failed to initialize docker") auth := auth.NewAuth(authConfig, docker, ldapService) providers := providers.NewProviders(oauthConfig) - handlers := handlers.NewHandlers(handlersConfig, auth, providers, docker) + + // Create the engine + engine := gin.New() + + // Create the group + router := engine.Group("/api") // Setup the middlewares - var middlewares []server.Middleware + var middlewares []Middleware contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ Domain: domain, @@ -160,12 +152,58 @@ var rootCmd = &cobra.Command{ middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) - srv, err := server.NewServer(serverConfig, handlers, middlewares) - HandleError(err, "Failed to create server") + for _, middleware := range middlewares { + log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") + err := middleware.Init() + HandleError(err, fmt.Sprintf("Failed to initialize middleware %s", middleware.Name())) + router.Use(middleware.Middleware()) + } + + // Create configured providers + var configuredProviders []string + + configuredProviders = append(configuredProviders, providers.GetConfiguredProviders()...) + + if auth.UserAuthConfigured() { + configuredProviders = append(configuredProviders, "username") + } + + // Create controllers + contextController := controller.NewContextController(controller.ContextControllerConfig{ + ConfiguredProviders: configuredProviders, + DisableContinue: config.DisableContinue, + Title: config.Title, + GenericName: config.GenericName, + Domain: domain, + ForgotPasswordMessage: config.FogotPasswordMessage, + BackgroundImage: config.BackgroundImage, + OAuthAutoRedirect: config.OAuthAutoRedirect, + }, router) + contextController.SetupRoutes() + + healthController := controller.NewHealthController(router) + healthController.SetupRoutes() + + oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ + AppURL: config.AppURL, + SecureCookie: config.CookieSecure, + CSRFCookieName: csrfCookieName, + RedirectCookieName: redirectCookieName, + }, router, auth, providers) + oauthController.SetupRoutes() + + proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ + AppURL: config.AppURL, + }, router, docker, auth) + proxyController.SetupRoutes() + + userController := controller.NewUserController(controller.UserControllerConfig{ + Domain: domain, + }, router, auth) + userController.SetupRoutes() - // Start up - err = srv.Start() - HandleError(err, "Failed to start server") + // Run server + engine.Run(fmt.Sprintf("%s:%d", config.Address, config.Port)) }, } diff --git a/frontend/src/context/app-context.tsx b/frontend/src/context/app-context.tsx index 13abf50d..8f76c119 100644 --- a/frontend/src/context/app-context.tsx +++ b/frontend/src/context/app-context.tsx @@ -15,7 +15,7 @@ export const AppContextProvider = ({ }) => { const { isFetching, data, error } = useSuspenseQuery({ queryKey: ["app"], - queryFn: () => axios.get("/api/app").then((res) => res.data), + queryFn: () => axios.get("/api/context/app").then((res) => res.data), }); if (error && !isFetching) { diff --git a/frontend/src/context/user-context.tsx b/frontend/src/context/user-context.tsx index 43b3c005..a3cfeaa2 100644 --- a/frontend/src/context/user-context.tsx +++ b/frontend/src/context/user-context.tsx @@ -15,7 +15,7 @@ export const UserContextProvider = ({ }) => { const { isFetching, data, error } = useSuspenseQuery({ queryKey: ["user"], - queryFn: () => axios.get("/api/user").then((res) => res.data), + queryFn: () => axios.get("/api/context/user").then((res) => res.data), }); if (error && !isFetching) { diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 8c285002..30b2af8c 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -26,7 +26,7 @@ export const LogoutPage = () => { const { t } = useTranslation(); const logoutMutation = useMutation({ - mutationFn: () => axios.post("/api/logout"), + mutationFn: () => axios.post("/api/user/logout"), mutationKey: ["logout"], onSuccess: () => { toast.success(t("logoutSuccessTitle"), { diff --git a/frontend/src/pages/totp-page.tsx b/frontend/src/pages/totp-page.tsx index e04fb2f4..7d4ebad1 100644 --- a/frontend/src/pages/totp-page.tsx +++ b/frontend/src/pages/totp-page.tsx @@ -32,7 +32,7 @@ export const TotpPage = () => { const redirectUri = searchParams.get("redirect_uri"); const totpMutation = useMutation({ - mutationFn: (values: TotpSchema) => axios.post("/api/totp", values), + mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationKey: ["totp"], onSuccess: () => { toast.success(t("totpSuccessTitle"), { diff --git a/go.mod b/go.mod index 0a6f8852..8388b2a3 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator v9.31.0+incompatible github.com/goccy/go-json v0.10.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect diff --git a/go.sum b/go.sum index dabff47e..b43990cb 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= +github.com/go-playground/validator v9.31.0+incompatible/go.mod h1:yrEkQXlcI+PugkyDjY2bRrL/UBU4f3rvrgkN3V8JEig= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go new file mode 100644 index 00000000..c7dfccfe --- /dev/null +++ b/internal/controller/context_controller.go @@ -0,0 +1,102 @@ +package controller + +import ( + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" +) + +type UserContextResponse struct { + Status int `json:"status"` + Message string `json:"message"` + IsLoggedIn bool `json:"isLoggedIn"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` + Provider string `json:"provider"` + Oauth bool `json:"oauth"` + TotpPending bool `json:"totpPending"` +} + +type AppContextResponse struct { + Status int `json:"status"` + Message string `json:"message"` + ConfiguredProviders []string `json:"configuredProviders"` + DisableContinue bool `json:"disableContinue"` + Title string `json:"title"` + GenericName string `json:"genericName"` + Domain string `json:"domain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` +} + +type ContextControllerConfig struct { + ConfiguredProviders []string + DisableContinue bool + Title string + GenericName string + Domain string + ForgotPasswordMessage string + BackgroundImage string + OAuthAutoRedirect string +} + +type ContextController struct { + Config ContextControllerConfig + Router *gin.RouterGroup +} + +func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { + return &ContextController{ + Config: config, + Router: router, + } +} + +func (controller *ContextController) SetupRoutes() { + contextGroup := controller.Router.Group("/context") + contextGroup.GET("/user", controller.userContextHandler) + contextGroup.GET("/app", controller.appContextHandler) +} + +func (controller *ContextController) userContextHandler(c *gin.Context) { + context, err := utils.GetContext(c) + + userContext := UserContextResponse{ + Status: 200, + Message: "Success", + IsLoggedIn: context.IsLoggedIn, + Username: context.Username, + Name: context.Name, + Email: context.Email, + Provider: context.Provider, + Oauth: context.OAuth, + TotpPending: context.TotpPending, + } + + if err != nil { + userContext.Status = 401 + userContext.Message = "Unauthorized" + userContext.IsLoggedIn = false + c.JSON(200, userContext) + return + } + + c.JSON(200, userContext) +} + +func (controller *ContextController) appContextHandler(c *gin.Context) { + c.JSON(200, AppContextResponse{ + Status: 200, + Message: "Success", + ConfiguredProviders: controller.Config.ConfiguredProviders, + DisableContinue: controller.Config.DisableContinue, + Title: controller.Config.Title, + GenericName: controller.Config.GenericName, + Domain: controller.Config.Domain, + ForgotPasswordMessage: controller.Config.ForgotPasswordMessage, + BackgroundImage: controller.Config.BackgroundImage, + OAuthAutoRedirect: controller.Config.OAuthAutoRedirect, + }) +} diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go new file mode 100644 index 00000000..2330fb17 --- /dev/null +++ b/internal/controller/health_controller.go @@ -0,0 +1,24 @@ +package controller + +import "github.com/gin-gonic/gin" + +type HealthController struct { + Router *gin.RouterGroup +} + +func NewHealthController(router *gin.RouterGroup) *HealthController { + return &HealthController{ + Router: router, + } +} + +func (controller *HealthController) SetupRoutes() { + controller.Router.GET("/health", controller.healthHandler) +} + +func (controller *HealthController) healthHandler(c *gin.Context) { + c.JSON(200, gin.H{ + "status": "ok", + "message": "Healthy", + }) +} diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go new file mode 100644 index 00000000..63b63229 --- /dev/null +++ b/internal/controller/oauth_controller.go @@ -0,0 +1,185 @@ +package controller + +import ( + "fmt" + "net/http" + "strings" + "time" + "tinyauth/internal/auth" + "tinyauth/internal/providers" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" +) + +type OAuthRequest struct { + Provider string `uri:"provider" binding:"required"` +} + +type OAuthControllerConfig struct { + CSRFCookieName string + RedirectCookieName string + SecureCookie bool + AppURL string +} + +type OAuthController struct { + Config OAuthControllerConfig + Router *gin.RouterGroup + Auth *auth.Auth + Providers *providers.Providers +} + +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController { + return &OAuthController{ + Config: config, + Router: router, + Auth: auth, + Providers: providers, + } +} + +func (controller *OAuthController) SetupRoutes() { + oauthGroup := controller.Router.Group("/oauth") + oauthGroup.GET("/url/:provider", controller.oauthURLHandler) + oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) +} + +func (controller *OAuthController) oauthURLHandler(c *gin.Context) { + var req OAuthRequest + + err := c.BindUri(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + provider := controller.Providers.GetProvider(req.Provider) + + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + state := provider.GenerateState() + authURL := provider.GetAuthURL(state) + c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + + redirectURI := c.Query("redirect_uri") + + if redirectURI != "" { + c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + "url": authURL, + }) +} + +func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { + var req OAuthRequest + + err := c.BindUri(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + state := c.Query("state") + csrfCookie, err := c.Cookie(controller.Config.CSRFCookieName) + + if err != nil || state != csrfCookie { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + + code := c.Query("code") + provider := controller.Providers.GetProvider(req.Provider) + + if provider == nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + _, err = provider.ExchangeToken(code) + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + user, err := controller.Providers.GetUser(req.Provider) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if user.Email == "" { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if !controller.Auth.EmailWhitelisted(user.Email) { + queries, err := query.Values(types.UnauthorizedQuery{ + Username: user.Email, + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + var name string + + if user.Name != "" { + name = user.Name + } else { + name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) + } + + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Email, + Name: name, + Email: user.Email, + Provider: req.Provider, + OAuthGroups: utils.CoalesceToString(user.Groups), + }) + + redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) + return + } + + queries, err := query.Values(types.RedirectQuery{ + RedirectURI: redirectURI, + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) +} diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go new file mode 100644 index 00000000..f8476f01 --- /dev/null +++ b/internal/controller/proxy_controller.go @@ -0,0 +1,281 @@ +package controller + +import ( + "fmt" + "net/http" + "strings" + "tinyauth/internal/auth" + "tinyauth/internal/docker" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" +) + +type Proxy struct { + Proxy string `uri:"proxy" binding:"required"` +} + +type ProxyControllerConfig struct { + AppURL string +} + +type ProxyController struct { + Config ProxyControllerConfig + Router *gin.RouterGroup + Docker *docker.Docker + Auth *auth.Auth +} + +func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController { + return &ProxyController{ + Config: config, + Router: router, + Docker: docker, + Auth: auth, + } +} + +func (controller *ProxyController) SetupRoutes() { + proxyGroup := controller.Router.Group("/api/auth") + proxyGroup.GET("/:proxy", controller.proxyHandler) +} + +func (controller *ProxyController) proxyHandler(c *gin.Context) { + var req Proxy + + err := c.BindUri(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + + uri := c.Request.Header.Get("X-Forwarded-Uri") + proto := c.Request.Header.Get("X-Forwarded-Proto") + host := c.Request.Header.Get("X-Forwarded-Host") + + hostWithoutPort := strings.Split(host, ":")[0] + id := strings.Split(hostWithoutPort, ".")[0] + + labels, err := controller.Docker.GetLabels(id, hostWithoutPort) + + if err != nil { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + clientIP := c.ClientIP() + + if controller.Auth.BypassedIP(labels, clientIP) { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + if !controller.Auth.CheckIP(labels, clientIP) { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + IP: clientIP, + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + authEnabled, err := controller.Auth.AuthEnabled(uri, labels) + + if err != nil { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if !authEnabled { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + var userContext types.UserContext + + context, err := utils.GetContext(c) + + if err != nil { + userContext = types.UserContext{ + IsLoggedIn: false, + } + } else { + userContext = context + } + + if userContext.Provider == "basic" && userContext.TotpEnabled { + userContext.IsLoggedIn = false + } + + if userContext.IsLoggedIn { + appAllowed := controller.Auth.ResourceAllowed(c, userContext, labels) + + if !appAllowed { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + queries, err := query.Values(types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + }) + + if userContext.OAuth { + queries.Set("username", userContext.Username) + } else { + queries.Set("username", userContext.Email) + } + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + if userContext.OAuth { + groupOK := controller.Auth.OAuthGroup(c, userContext, labels) + + if !groupOK { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + queries, err := query.Values(types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + GroupErr: true, + }) + + if userContext.OAuth { + queries.Set("username", userContext.Username) + } else { + queries.Set("username", userContext.Email) + } + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + } + + c.Header("Authorization", c.Request.Header.Get("Authorization")) + c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) + c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) + c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) + c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(types.RedirectQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) +} diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go new file mode 100644 index 00000000..e017826a --- /dev/null +++ b/internal/controller/user_controller.go @@ -0,0 +1,216 @@ +package controller + +import ( + "fmt" + "strings" + "tinyauth/internal/auth" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" +) + +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type TotpRequest struct { + Code string `json:"code"` +} + +type UserControllerConfig struct { + Domain string +} + +type UserController struct { + Config UserControllerConfig + Router *gin.RouterGroup + Auth *auth.Auth +} + +func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController { + return &UserController{ + Config: config, + Router: router, + Auth: auth, + } +} + +func (controller *UserController) SetupRoutes() { + userGroup := controller.Router.Group("/user") + userGroup.POST("/login", controller.loginHandler) + userGroup.POST("/logout", controller.logoutHandler) + userGroup.POST("/totp", controller.totpHandler) +} + +func (controller *UserController) loginHandler(c *gin.Context) { + var req LoginRequest + + err := c.BindJSON(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + clientIP := c.ClientIP() + + rateIdentifier := req.Username + + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + + if isLocked { + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + userSearch := controller.Auth.SearchUser(req.Username) + + if userSearch.Type == "" { + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + if !controller.Auth.VerifyUser(userSearch, req.Password) { + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + controller.Auth.RecordLoginAttempt(rateIdentifier, true) + + if userSearch.Type == "local" { + user := controller.Auth.GetLocalUser(userSearch.Username) + + if user.TotpSecret != "" { + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(req.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), + Provider: "username", + TotpPending: true, + }) + + c.JSON(200, gin.H{ + "status": 200, + "message": "TOTP required", + "totpPending": true, + }) + return + } + } + + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: req.Username, + Name: utils.Capitalize(req.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), + Provider: "username", + }) + + c.JSON(200, gin.H{ + "status": 200, + "message": "Login successful", + }) +} + +func (controller *UserController) logoutHandler(c *gin.Context) { + controller.Auth.DeleteSessionCookie(c) + c.JSON(200, gin.H{ + "status": 200, + "message": "Logout successful", + }) +} + +func (controller *UserController) totpHandler(c *gin.Context) { + var req TotpRequest + + err := c.BindJSON(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + context, err := utils.GetContext(c) + + if err != nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + if !context.IsLoggedIn { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + clientIP := c.ClientIP() + + rateIdentifier := context.Username + + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + + if isLocked { + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + user := controller.Auth.GetLocalUser(context.Username) + + ok := totp.Validate(req.Code, user.TotpSecret) + + if !ok { + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + controller.Auth.RecordLoginAttempt(rateIdentifier, true) + + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), + Provider: "username", + }) + + c.JSON(200, gin.H{ + "status": 200, + "message": "Login successful", + }) +} diff --git a/internal/handlers/context.go b/internal/handlers/context.go deleted file mode 100644 index 0bbe3923..00000000 --- a/internal/handlers/context.go +++ /dev/null @@ -1,84 +0,0 @@ -package handlers - -import ( - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) AppContextHandler(c *gin.Context) { - log.Debug().Msg("Getting app context") - - // Get configured providers - configuredProviders := h.Providers.GetConfiguredProviders() - - // We have username/password configured so add it to our providers - if h.Auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Return app context - appContext := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: configuredProviders, - DisableContinue: h.Config.DisableContinue, - Title: h.Config.Title, - GenericName: h.Config.GenericName, - Domain: h.Config.Domain, - ForgotPasswordMessage: h.Config.ForgotPasswordMessage, - BackgroundImage: h.Config.BackgroundImage, - OAuthAutoRedirect: h.Config.OAuthAutoRedirect, - } - c.JSON(200, appContext) -} - -func (h *Handlers) UserContextHandler(c *gin.Context) { - log.Debug().Msg("Getting user context") - - // Get user context from middleware - userContextValue, exists := c.Get("context") - - if !exists { - c.JSON(200, types.UserContextResponse{ - Status: 200, - Message: "Unauthorized", - IsLoggedIn: false, - }) - return - } - - userContext, ok := userContextValue.(*types.UserContext) - - if !ok { - c.JSON(200, types.UserContextResponse{ - Status: 200, - Message: "Unauthorized", - IsLoggedIn: false, - }) - return - } - - userContextResponse := types.UserContextResponse{ - Status: 200, - IsLoggedIn: userContext.IsLoggedIn, - Username: userContext.Username, - Name: userContext.Name, - Email: userContext.Email, - Provider: userContext.Provider, - Oauth: userContext.OAuth, - TotpPending: userContext.TotpPending, - } - - // If we are not logged in we set the status to 401 else we set it to 200 - if !userContext.IsLoggedIn { - log.Debug().Msg("Unauthorized") - userContextResponse.Message = "Unauthorized" - } else { - log.Debug().Interface("userContext", userContext).Msg("Authenticated") - userContextResponse.Message = "Authenticated" - } - - c.JSON(200, userContextResponse) -} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go deleted file mode 100644 index e24f7fab..00000000 --- a/internal/handlers/handlers.go +++ /dev/null @@ -1,46 +0,0 @@ -package handlers - -import ( - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/providers" - - "github.com/gin-gonic/gin" -) - -type HandlersConfig struct { - AppURL string - Domain string - CookieSecure bool - DisableContinue bool - GenericName string - Title string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - CsrfCookieName string - RedirectCookieName string -} - -type Handlers struct { - Config HandlersConfig - Auth *auth.Auth - Providers *providers.Providers - Docker *docker.Docker -} - -func NewHandlers(config HandlersConfig, auth *auth.Auth, providers *providers.Providers, docker *docker.Docker) *Handlers { - return &Handlers{ - Config: config, - Auth: auth, - Providers: providers, - Docker: docker, - } -} - -func (h *Handlers) HealthcheckHandler(c *gin.Context) { - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - }) -} diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go deleted file mode 100644 index 279534d6..00000000 --- a/internal/handlers/handlers_test.go +++ /dev/null @@ -1,394 +0,0 @@ -package handlers_test - -import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "reflect" - "strings" - "testing" - "time" - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/handlers" - "tinyauth/internal/hooks" - "tinyauth/internal/providers" - "tinyauth/internal/server" - "tinyauth/internal/types" - - "github.com/magiconair/properties/assert" - "github.com/pquerna/otp/totp" -) - -// Simple server config -var serverConfig = types.ServerConfig{ - Port: 8080, - Address: "0.0.0.0", -} - -// Simple handlers config -var handlersConfig = types.HandlersConfig{ - AppURL: "http://localhost:8080", - Domain: "localhost", - DisableContinue: false, - CookieSecure: false, - Title: "Tinyauth", - GenericName: "Generic", - ForgotPasswordMessage: "Message", - CsrfCookieName: "tinyauth-csrf", - RedirectCookieName: "tinyauth-redirect", - BackgroundImage: "https://example.com/image.png", - OAuthAutoRedirect: "none", -} - -// Simple auth config -var authConfig = types.AuthConfig{ - Users: types.Users{}, - OauthWhitelist: "", - HMACSecret: "4bZ9K.*:;zH=,9zG!meUxu.B5-S[7.V.", // Complex on purpose - EncryptionSecret: "\\:!R(u[Sbv6ZLm.7es)H|OqH4y}0u\\rj", - CookieSecure: false, - SessionExpiry: 3600, - LoginTimeout: 0, - LoginMaxRetries: 0, - SessionCookieName: "tinyauth-session", - Domain: "localhost", -} - -// Simple hooks config -var hooksConfig = types.HooksConfig{ - Domain: "localhost", -} - -// Cookie -var cookie string - -// User -var user = types.User{ - Username: "user", - Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass -} - -// Initialize the server for tests -func getServer(t *testing.T) *server.Server { - // Create services - authConfig.Users = types.Users{ - { - Username: user.Username, - Password: user.Password, - TotpSecret: user.TotpSecret, - }, - } - docker, err := docker.NewDocker() - if err != nil { - t.Fatalf("Failed to create docker client: %v", err) - } - auth := auth.NewAuth(authConfig, nil, nil) - providers := providers.NewProviders(types.OAuthConfig{}) - hooks := hooks.NewHooks(hooksConfig, auth, providers) - handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) - - // Create server - srv, err := server.NewServer(serverConfig, handlers) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - return srv -} - -func TestLogin(t *testing.T) { - t.Log("Testing login") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - user := types.LoginRequest{ - Username: "user", - Password: "pass", - } - - json, err := json.Marshal(user) - if err != nil { - t.Fatalf("Error marshalling json: %v", err) - } - - req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(json))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - cookies := recorder.Result().Cookies() - - if len(cookies) == 0 { - t.Fatalf("Cookie not set") - } - - // Set the cookie for further tests - cookie = cookies[0].Value -} - -func TestAppContext(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing app context") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/app", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - // Set the cookie from the previous test - req.AddCookie(&http.Cookie{ - Name: "tinyauth", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - body, err := io.ReadAll(recorder.Body) - if err != nil { - t.Fatalf("Error getting body: %v", err) - } - - var app types.AppContext - - err = json.Unmarshal(body, &app) - if err != nil { - t.Fatalf("Error unmarshalling body: %v", err) - } - - expected := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: []string{"username"}, - DisableContinue: false, - Title: "Tinyauth", - GenericName: "Generic", - ForgotPasswordMessage: "Message", - BackgroundImage: "https://example.com/image.png", - OAuthAutoRedirect: "none", - Domain: "localhost", - } - - // We should get the username back - if !reflect.DeepEqual(app, expected) { - t.Fatalf("Expected %v, got %v", expected, app) - } -} - -func TestUserContext(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing user context") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/user", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - body, err := io.ReadAll(recorder.Body) - if err != nil { - t.Fatalf("Error getting body: %v", err) - } - - type User struct { - Username string `json:"username"` - } - - var user User - - err = json.Unmarshal(body, &user) - if err != nil { - t.Fatalf("Error unmarshalling body: %v", err) - } - - // We should get the user back - if user.Username != "user" { - t.Fatalf("Expected user, got %s", user.Username) - } -} - -func TestLogout(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing logout") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("POST", "/api/logout", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Check if the cookie is different (means the cookie is gone) - if recorder.Result().Cookies()[0].Value == cookie { - t.Fatalf("Cookie not flushed") - } -} - -func TestAuth(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing auth endpoint") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/auth/traefik", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.Header.Set("Accept", "text/html") - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusTemporaryRedirect) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/traefik", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) -} - -func TestTOTP(t *testing.T) { - t.Log("Testing TOTP") - - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: "Tinyauth", - AccountName: user.Username, - }) - if err != nil { - t.Fatalf("Failed to generate TOTP secret: %v", err) - } - - secret := key.Secret() - - user.TotpSecret = secret - - srv := getServer(t) - - user := types.LoginRequest{ - Username: "user", - Password: "pass", - } - - loginJson, err := json.Marshal(user) - if err != nil { - t.Fatalf("Error marshalling json: %v", err) - } - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(loginJson))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Set the cookie for next test - cookie = recorder.Result().Cookies()[0].Value - - code, err := totp.GenerateCode(secret, time.Now()) - if err != nil { - t.Fatalf("Failed to generate TOTP code: %v", err) - } - - totpRequest := types.TotpRequest{ - Code: code, - } - - totpJson, err := json.Marshal(totpRequest) - if err != nil { - t.Fatalf("Error marshalling TOTP request: %v", err) - } - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("POST", "/api/totp", strings.NewReader(string(totpJson))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) -} diff --git a/internal/handlers/oauth.go b/internal/handlers/oauth.go deleted file mode 100644 index 13c3a474..00000000 --- a/internal/handlers/oauth.go +++ /dev/null @@ -1,223 +0,0 @@ -package handlers - -import ( - "fmt" - "net/http" - "strings" - "time" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) OAuthURLHandler(c *gin.Context) { - var request types.OAuthRequest - - err := c.BindUri(&request) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got OAuth request") - - // Check if provider exists - provider := h.Providers.GetProvider(request.Provider) - - if provider == nil { - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", - }) - return - } - - log.Debug().Str("provider", request.Provider).Msg("Got provider") - - // Create state - state := provider.GenerateState() - - // Get auth URL - authURL := provider.GetAuthURL(state) - - log.Debug().Msg("Got auth URL") - - // Set CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - - // Get redirect URI - redirectURI := c.Query("redirect_uri") - - // Set redirect cookie if redirect URI is provided - if redirectURI != "" { - log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") - c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - } - - // Return auth URL - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - "url": authURL, - }) -} - -func (h *Handlers) OAuthCallbackHandler(c *gin.Context) { - var providerName types.OAuthRequest - - err := c.BindUri(&providerName) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") - - // Get state - state := c.Query("state") - - // Get CSRF cookie - csrfCookie, err := c.Cookie(h.Config.CsrfCookieName) - - if err != nil { - log.Debug().Msg("No CSRF cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie") - - // Check if CSRF cookie is valid - if csrfCookie != state { - log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Clean up CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Get code - code := c.Query("code") - - log.Debug().Msg("Got code") - - // Get provider - provider := h.Providers.GetProvider(providerName.Provider) - - if provider == nil { - c.Redirect(http.StatusTemporaryRedirect, "/not-found") - return - } - - log.Debug().Str("provider", providerName.Provider).Msg("Got provider") - - // Exchange token (authenticates user) - _, err = provider.ExchangeToken(code) - if err != nil { - log.Error().Err(err).Msg("Failed to exchange token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got token") - - // Get user - user, err := h.Providers.GetUser(providerName.Provider) - if err != nil { - log.Error().Err(err).Msg("Failed to get user") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("user", user).Msg("Got user") - - // Check that email is not empty - if user.Email == "" { - log.Error().Msg("Email is empty") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Email is not whitelisted - if !h.Auth.EmailWhitelisted(user.Email) { - log.Warn().Str("email", user.Email).Msg("Email not whitelisted") - queries, err := query.Values(types.UnauthorizedQuery{ - Username: user.Email, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - } - - log.Debug().Msg("Email whitelisted") - - // Get username - var username string - - if user.PreferredUsername != "" { - username = user.PreferredUsername - } else { - username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) - } - - // Get name - var name string - - if user.Name != "" { - name = user.Name - } else { - name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) - } - - // Create session cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: username, - Name: name, - Email: user.Email, - Provider: providerName.Provider, - OAuthGroups: utils.CoalesceToString(user.Groups), - }) - - // Check if we have a redirect URI - redirectCookie, err := c.Cookie(h.Config.RedirectCookieName) - - if err != nil { - log.Debug().Msg("No redirect cookie") - c.Redirect(http.StatusTemporaryRedirect, h.Config.AppURL) - return - } - - log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: redirectCookie, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got redirect query") - - // Clean up redirect cookie - c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Redirect to continue with the redirect URI - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) -} diff --git a/internal/handlers/proxy.go b/internal/handlers/proxy.go deleted file mode 100644 index c9d234e0..00000000 --- a/internal/handlers/proxy.go +++ /dev/null @@ -1,299 +0,0 @@ -package handlers - -import ( - "fmt" - "net/http" - "strings" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) ProxyHandler(c *gin.Context) { - var proxy types.Proxy - - err := c.BindUri(&proxy) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) - isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") - - if isBrowser { - log.Debug().Msg("Request is most likely coming from a browser") - } else { - log.Debug().Msg("Request is most likely not coming from a browser") - } - - log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") - - uri := c.Request.Header.Get("X-Forwarded-Uri") - proto := c.Request.Header.Get("X-Forwarded-Proto") - host := c.Request.Header.Get("X-Forwarded-Host") - - hostPortless := strings.Split(host, ":")[0] // *lol* - id := strings.Split(hostPortless, ".")[0] - - labels, err := h.Docker.GetLabels(id, hostPortless) - if err != nil { - log.Error().Err(err).Msg("Failed to get container labels") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("labels", labels).Msg("Got labels") - - ip := c.ClientIP() - - if h.Auth.BypassedIP(labels, ip) { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - if !h.Auth.CheckIP(labels, ip) { - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(403, gin.H{ - "status": 403, - "message": "Forbidden", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - IP: ip, - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - authEnabled, err := h.Auth.AuthEnabled(uri, labels) - if err != nil { - log.Error().Err(err).Msg("Failed to check if app is allowed") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - if !authEnabled { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - - return - } - - var userContext *types.UserContext - - userContextValue, exists := c.Get("context") - - if !exists { - userContext = &types.UserContext{ - IsLoggedIn: false, - } - } else { - var ok bool - userContext, ok = userContextValue.(*types.UserContext) - - if !ok { - userContext = &types.UserContext{ - IsLoggedIn: false, - } - } - } - - // If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth - if userContext.Provider == "basic" && userContext.TotpEnabled { - log.Warn().Str("username", userContext.Username).Msg("User has totp enabled, disabling basic auth") - userContext.IsLoggedIn = false - } - - if userContext.IsLoggedIn { - log.Debug().Msg("Authenticated") - - // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx - appAllowed := h.Auth.ResourceAllowed(c, *userContext, labels) - - log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") - - if !appAllowed { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - if userContext.OAuth { - groupOk := h.Auth.OAuthGroup(c, *userContext, labels) - - log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") - - if !groupOk { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - GroupErr: true, - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - } - - c.Header("Authorization", c.Request.Header.Get("Authorization")) - c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) - c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) - c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - - parsedHeaders := utils.ParseHeaders(labels.Headers) - for key, value := range parsedHeaders { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - // The user is not logged in - log.Debug().Msg("Unauthorized") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode())) -} diff --git a/internal/handlers/user.go b/internal/handlers/user.go deleted file mode 100644 index 86a18eea..00000000 --- a/internal/handlers/user.go +++ /dev/null @@ -1,215 +0,0 @@ -package handlers - -import ( - "fmt" - "strings" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) LoginHandler(c *gin.Context) { - var login types.LoginRequest - - err := c.BindJSON(&login) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got login request") - - clientIP := c.ClientIP() - - // Create an identifier for rate limiting (username or IP if username doesn't exist yet) - rateIdentifier := login.Username - if rateIdentifier == "" { - rateIdentifier = clientIP - } - - // Check if the account is locked due to too many failed attempts - locked, remainingTime := h.Auth.IsAccountLocked(rateIdentifier) - if locked { - log.Warn().Str("identifier", rateIdentifier).Int("remaining_seconds", remainingTime).Msg("Account is locked due to too many failed login attempts") - c.JSON(429, gin.H{ - "status": 429, - "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), - }) - return - } - - // Search for a user based on username - log.Debug().Interface("username", login.Username).Msg("Searching for user") - - userSearch := h.Auth.SearchUser(login.Username) - - // User does not exist - if userSearch.Type == "" { - log.Debug().Str("username", login.Username).Msg("User not found") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Got user") - - // Check if password is correct - if !h.Auth.VerifyUser(userSearch, login.Password) { - log.Debug().Str("username", login.Username).Msg("Password incorrect") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Password correct, checking totp") - - // Record successful login attempt (will reset failed attempt counter) - h.Auth.RecordLoginAttempt(rateIdentifier, true) - - // Check if user is using TOTP - if userSearch.Type == "local" { - // Get local user - localUser := h.Auth.GetLocalUser(login.Username) - - // Check if TOTP is enabled - if localUser.TotpSecret != "" { - log.Debug().Msg("Totp enabled") - - // Set totp pending cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - TotpPending: true, - }) - - // Return totp required - c.JSON(200, gin.H{ - "status": 200, - "message": "Waiting for totp", - "totpPending": true, - }) - return - } - } - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - "totpPending": false, - }) -} - -func (h *Handlers) TOTPHandler(c *gin.Context) { - var totpReq types.TotpRequest - - err := c.BindJSON(&totpReq) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Checking totp") - - // Get user context - userContextValue, exists := c.Get("context") - - if !exists { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - userContext, ok := userContextValue.(*types.UserContext) - - if !ok { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Check if we have a user - if userContext.Username == "" { - log.Debug().Msg("No user context") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Get user - user := h.Auth.GetLocalUser(userContext.Username) - - // Check if totp is correct - ok = totp.Validate(totpReq.Code, user.TotpSecret) - - if !ok { - log.Debug().Msg("Totp incorrect") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Totp correct") - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: user.Username, - Name: utils.Capitalize(user.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - }) -} - -func (h *Handlers) LogoutHandler(c *gin.Context) { - log.Debug().Msg("Cleaning up redirect cookie") - - h.Auth.DeleteSessionCookie(c) - - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged out", - }) -} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 78150feb..00000000 --- a/internal/server/server.go +++ /dev/null @@ -1,66 +0,0 @@ -package server - -import ( - "fmt" - "tinyauth/internal/handlers" - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -type Server struct { - Config types.ServerConfig - Handlers *handlers.Handlers - Router *gin.Engine -} - -type Middleware interface { - Middleware() gin.HandlerFunc - Init() error - Name() string -} - -func NewServer(config types.ServerConfig, handlers *handlers.Handlers, middlewares []Middleware) (*Server, error) { - router := gin.New() - - for _, middleware := range middlewares { - log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") - err := middleware.Init() - if err != nil { - return nil, fmt.Errorf("failed to initialize middleware %s: %w", middleware.Name(), err) - } - router.Use(middleware.Middleware()) - } - - // Proxy routes - router.GET("/api/auth/:proxy", handlers.ProxyHandler) - - // Auth routes - router.POST("/api/login", handlers.LoginHandler) - router.POST("/api/totp", handlers.TOTPHandler) - router.POST("/api/logout", handlers.LogoutHandler) - - // Context routes - router.GET("/api/app", handlers.AppContextHandler) - router.GET("/api/user", handlers.UserContextHandler) - - // OAuth routes - router.GET("/api/oauth/url/:provider", handlers.OAuthURLHandler) - router.GET("/api/oauth/callback/:provider", handlers.OAuthCallbackHandler) - - // App routes - router.GET("/api/healthcheck", handlers.HealthcheckHandler) - router.HEAD("/api/healthcheck", handlers.HealthcheckHandler) - - return &Server{ - Config: config, - Handlers: handlers, - Router: router, - }, nil -} - -func (s *Server) Start() error { - log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server") - return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port)) -} diff --git a/internal/types/api.go b/internal/types/api.go deleted file mode 100644 index fbf8bf77..00000000 --- a/internal/types/api.go +++ /dev/null @@ -1,62 +0,0 @@ -package types - -// LoginQuery is the query parameters for the login endpoint -type LoginQuery struct { - RedirectURI string `url:"redirect_uri"` -} - -// LoginRequest is the request body for the login endpoint -type LoginRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -// OAuthRequest is the request for the OAuth endpoint -type OAuthRequest struct { - Provider string `uri:"provider" binding:"required"` -} - -// UnauthorizedQuery is the query parameters for the unauthorized endpoint -type UnauthorizedQuery struct { - Username string `url:"username"` - Resource string `url:"resource"` - GroupErr bool `url:"groupErr"` - IP string `url:"ip"` -} - -// Proxy is the uri parameters for the proxy endpoint -type Proxy struct { - Proxy string `uri:"proxy" binding:"required"` -} - -// User Context response is the response for the user context endpoint -type UserContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - IsLoggedIn bool `json:"isLoggedIn"` - Username string `json:"username"` - Name string `json:"name"` - Email string `json:"email"` - Provider string `json:"provider"` - Oauth bool `json:"oauth"` - TotpPending bool `json:"totpPending"` -} - -// App Context is the response for the app context endpoint -type AppContext struct { - Status int `json:"status"` - Message string `json:"message"` - ConfiguredProviders []string `json:"configuredProviders"` - DisableContinue bool `json:"disableContinue"` - Title string `json:"title"` - GenericName string `json:"genericName"` - Domain string `json:"domain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` -} - -// Totp request is the request for the totp endpoint -type TotpRequest struct { - Code string `json:"code"` -} diff --git a/internal/types/config.go b/internal/types/config.go index 4b32ad98..dfb9e987 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -60,12 +60,6 @@ type OAuthConfig struct { AppURL string } -// ServerConfig is the configuration for the server -type ServerConfig struct { - Port int - Address string -} - // AuthConfig is the configuration for the auth service type AuthConfig struct { Users Users diff --git a/internal/types/types.go b/internal/types/types.go index 2c40ae55..1cb6bed8 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -57,3 +57,14 @@ type LoginAttempt struct { LastAttempt time.Time LockedUntil time.Time } + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 39b1518f..8c2f4ea3 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -13,6 +13,7 @@ import ( "strings" "tinyauth/internal/types" + "github.com/gin-gonic/gin" "github.com/traefik/paerser/parser" "golang.org/x/crypto/hkdf" @@ -348,3 +349,19 @@ func CoalesceToString(value any) string { return "" } } + +func GetContext(c *gin.Context) (types.UserContext, error) { + userContextValue, exists := c.Get("context") + + if !exists { + return types.UserContext{}, errors.New("no user context in request") + } + + userContext, ok := userContextValue.(*types.UserContext) + + if !ok { + return types.UserContext{}, errors.New("invalid user context in request") + } + + return *userContext, nil +} From 44f35af3bfac7afed84ea2f587a2e2e41f586998 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 17:50:34 +0300 Subject: [PATCH 04/17] refactor: move oauth providers into services (non-working) --- internal/auth/auth_test.go | 146 ----------------- internal/{types => config}/config.go | 124 +++++++++----- internal/constants/constants.go | 19 --- internal/oauth/oauth.go | 71 -------- internal/providers/generic.go | 37 ----- internal/providers/github.go | 102 ------------ internal/providers/google.go | 56 ------- internal/providers/providers.go | 154 ------------------ .../{auth/auth.go => service/auth_service.go} | 87 +++++----- .../docker.go => service/docker_service.go} | 34 ++-- internal/service/generic_oauth_service.go | 114 +++++++++++++ internal/service/github_oauth_service.go | 144 ++++++++++++++++ internal/service/google_oauth_service.go | 106 ++++++++++++ .../{ldap/ldap.go => service/ldap_service.go} | 58 ++++--- internal/types/types.go | 70 -------- 15 files changed, 544 insertions(+), 778 deletions(-) delete mode 100644 internal/auth/auth_test.go rename internal/{types => config}/config.go (67%) delete mode 100644 internal/constants/constants.go delete mode 100644 internal/oauth/oauth.go delete mode 100644 internal/providers/generic.go delete mode 100644 internal/providers/github.go delete mode 100644 internal/providers/google.go delete mode 100644 internal/providers/providers.go rename internal/{auth/auth.go => service/auth_service.go} (83%) rename internal/{docker/docker.go => service/docker_service.go} (76%) create mode 100644 internal/service/generic_oauth_service.go create mode 100644 internal/service/github_oauth_service.go create mode 100644 internal/service/google_oauth_service.go rename internal/{ldap/ldap.go => service/ldap_service.go} (64%) delete mode 100644 internal/types/types.go diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go deleted file mode 100644 index 1ab73294..00000000 --- a/internal/auth/auth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -package auth_test - -import ( - "testing" - "time" - "tinyauth/internal/auth" - "tinyauth/internal/types" -) - -var config = types.AuthConfig{ - Users: types.Users{}, - OauthWhitelist: "", - SessionExpiry: 3600, -} - -func TestLoginRateLimiting(t *testing.T) { - // Initialize a new auth service with 3 max retries and 5 seconds timeout - config.LoginMaxRetries = 3 - config.LoginTimeout = 5 - authService := auth.NewAuth(config, nil, nil) - - // Test identifier - identifier := "test_user" - - // Test successful login - should not lock account - t.Log("Testing successful login") - - authService.RecordLoginAttempt(identifier, true) - locked, _ := authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should not be locked after successful login") - } - - // Test 2 failed attempts - should not lock account yet - t.Log("Testing 2 failed login attempts") - - authService.RecordLoginAttempt(identifier, false) - authService.RecordLoginAttempt(identifier, false) - locked, _ = authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should not be locked after only 2 failed attempts") - } - - // Add one more failed attempt (total 3) - should lock account with maxRetries=3 - t.Log("Testing 3 failed login attempts") - authService.RecordLoginAttempt(identifier, false) - locked, remainingTime := authService.IsAccountLocked(identifier) - - if !locked { - t.Fatalf("Account should be locked after reaching max retries") - } - if remainingTime <= 0 || remainingTime > 5 { - t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime) - } - - // Test reset after waiting for timeout - use 1 second timeout for fast testing - t.Log("Testing unlocking after timeout") - - // Reinitialize auth service with a shorter timeout for testing - config.LoginTimeout = 1 - config.LoginMaxRetries = 3 - authService = auth.NewAuth(config, nil, nil) - - // Add enough failed attempts to lock the account - for i := 0; i < 3; i++ { - authService.RecordLoginAttempt(identifier, false) - } - - // Verify it's locked - locked, _ = authService.IsAccountLocked(identifier) - if !locked { - t.Fatalf("Account should be locked initially") - } - - // Wait a bit and verify it gets unlocked after timeout - time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout - locked, _ = authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should be unlocked after timeout period") - } - - // Test disabled rate limiting - t.Log("Testing disabled rate limiting") - config.LoginMaxRetries = 0 - config.LoginTimeout = 0 - authService = auth.NewAuth(config, nil, nil) - - for i := 0; i < 10; i++ { - authService.RecordLoginAttempt(identifier, false) - } - - locked, _ = authService.IsAccountLocked(identifier) - if locked { - t.Fatalf("Account should not be locked when rate limiting is disabled") - } -} - -func TestConcurrentLoginAttempts(t *testing.T) { - // Initialize a new auth service with 2 max retries and 5 seconds timeout - config.LoginMaxRetries = 2 - config.LoginTimeout = 5 - authService := auth.NewAuth(config, nil, nil) - - // Test multiple identifiers - identifiers := []string{"user1", "user2", "user3"} - - // Test that locking one identifier doesn't affect others - t.Log("Testing multiple identifiers") - - // Add enough failed attempts to lock first user (2 attempts with maxRetries=2) - authService.RecordLoginAttempt(identifiers[0], false) - authService.RecordLoginAttempt(identifiers[0], false) - - // Check if first user is locked - locked, _ := authService.IsAccountLocked(identifiers[0]) - if !locked { - t.Fatalf("User1 should be locked after reaching max retries") - } - - // Check that other users are not affected - for i := 1; i < len(identifiers); i++ { - locked, _ := authService.IsAccountLocked(identifiers[i]) - if locked { - t.Fatalf("User%d should not be locked", i+1) - } - } - - // Test successful login after failed attempts (but before lock) - t.Log("Testing successful login after failed attempts but before lock") - - // One failed attempt for user2 - authService.RecordLoginAttempt(identifiers[1], false) - - // Successful login should reset the counter - authService.RecordLoginAttempt(identifiers[1], true) - - // Now try a failed login again - should not be locked as counter was reset - authService.RecordLoginAttempt(identifiers[1], false) - locked, _ = authService.IsAccountLocked(identifiers[1]) - if locked { - t.Fatalf("User2 should not be locked after successful login reset") - } -} diff --git a/internal/types/config.go b/internal/config/config.go similarity index 67% rename from internal/types/config.go rename to internal/config/config.go index dfb9e987..48dc47f0 100644 --- a/internal/types/config.go +++ b/internal/config/config.go @@ -1,6 +1,22 @@ -package types +package config + +import "time" + +type Claims struct { + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` +} + +var Version = "development" +var CommitHash = "n/a" +var BuildTimestamp = "n/a" + +var SessionCookieName = "tinyauth-session" +var CsrfCookieName = "tinyauth-csrf" +var RedirectCookieName = "tinyauth-redirect" -// Config is the configuration for the tinyauth server type Config struct { Port int `mapstructure:"port" validate:"required"` Address string `validate:"required,ip4_addr" mapstructure:"address"` @@ -44,62 +60,27 @@ type Config struct { LdapSearchFilter string `mapstructure:"ldap-search-filter"` } -// OAuthConfig is the configuration for the providers -type OAuthConfig struct { - GithubClientId string - GithubClientSecret string - GoogleClientId string - GoogleClientSecret string - GenericClientId string - GenericClientSecret string - GenericScopes []string - GenericAuthURL string - GenericTokenURL string - GenericUserURL string - GenericSkipSSL bool - AppURL string -} - -// AuthConfig is the configuration for the auth service -type AuthConfig struct { - Users Users - OauthWhitelist string - SessionExpiry int - CookieSecure bool - Domain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - HMACSecret string - EncryptionSecret string -} - -// OAuthLabels is a list of labels that can be used in a tinyauth protected container type OAuthLabels struct { Whitelist string Groups string } -// Basic auth labels for a tinyauth protected container type BasicLabels struct { Username string Password PassowrdLabels } -// PassowrdLabels is a struct that contains the password labels for a tinyauth protected container type PassowrdLabels struct { Plain string File string } -// IP labels for a tinyauth protected container type IPLabels struct { Allow []string Block []string Bypass []string } -// Labels is a struct that contains the labels for a tinyauth protected container type Labels struct { Users string Allowed string @@ -110,12 +91,65 @@ type Labels struct { IP IPLabels } -// Ldap config is a struct that contains the configuration for the LDAP service -type LdapConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string +type OAuthServiceConfig struct { + ClientID string + ClientSecret string + Scopes []string + RedirectURL string + AuthURL string + TokenURL string + UserinfoURL string + InsecureSkipVerify bool + Name string +} + +type User struct { + Username string + Password string + TotpSecret string +} + +type UserSearch struct { + Username string + Type string // local, ldap or unknown +} + +type Users []User + +type SessionCookie struct { + Username string + Name string + Email string + Provider string + TotpPending bool + OAuthGroups string +} + +type UserContext struct { + Username string + Name string + Email string + IsLoggedIn bool + OAuth bool + Provider string + TotpPending bool + OAuthGroups string + TotpEnabled bool +} + +type LoginAttempt struct { + FailedAttempts int + LastAttempt time.Time + LockedUntil time.Time +} + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` } diff --git a/internal/constants/constants.go b/internal/constants/constants.go deleted file mode 100644 index d6f64fab..00000000 --- a/internal/constants/constants.go +++ /dev/null @@ -1,19 +0,0 @@ -package constants - -// Claims are the OIDC supported claims (prefered username is included for convinience) -type Claims struct { - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Groups any `json:"groups"` -} - -// Version information -var Version = "development" -var CommitHash = "n/a" -var BuildTimestamp = "n/a" - -// Base cookie names -var SessionCookieName = "tinyauth-session" -var CsrfCookieName = "tinyauth-csrf" -var RedirectCookieName = "tinyauth-redirect" diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go deleted file mode 100644 index 9529fce5..00000000 --- a/internal/oauth/oauth.go +++ /dev/null @@ -1,71 +0,0 @@ -package oauth - -import ( - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "net/http" - - "golang.org/x/oauth2" -) - -type OAuth struct { - Config oauth2.Config - Context context.Context - Token *oauth2.Token - Verifier string -} - -func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecureSkipVerify, - MinVersion: tls.VersionTLS12, - }, - } - - httpClient := &http.Client{ - Transport: transport, - } - - ctx := context.Background() - - // Set the HTTP client in the context - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - verifier := oauth2.GenerateVerifier() - - return &OAuth{ - Config: config, - Context: ctx, - Verifier: verifier, - } -} - -func (oauth *OAuth) GetAuthURL(state string) string { - return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) -} - -func (oauth *OAuth) ExchangeToken(code string) (string, error) { - token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) - - if err != nil { - return "", err - } - - // Set and return the token - oauth.Token = token - return oauth.Token.AccessToken, nil -} - -func (oauth *OAuth) GetClient() *http.Client { - return oauth.Config.Client(oauth.Context, oauth.Token) -} - -func (oauth *OAuth) GenerateState() string { - b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) - return state -} diff --git a/internal/providers/generic.go b/internal/providers/generic.go deleted file mode 100644 index 200f7c4b..00000000 --- a/internal/providers/generic.go +++ /dev/null @@ -1,37 +0,0 @@ -package providers - -import ( - "encoding/json" - "io" - "net/http" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get(url) - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got response from generic provider") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read body from generic provider") - - err = json.Unmarshal(body, &user) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed user from generic provider") - return user, nil -} diff --git a/internal/providers/github.go b/internal/providers/github.go deleted file mode 100644 index 67f85104..00000000 --- a/internal/providers/github.go +++ /dev/null @@ -1,102 +0,0 @@ -package providers - -import ( - "encoding/json" - "errors" - "io" - "net/http" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -// Response for the github email endpoint -type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` -} - -// Response for the github user endpoint -type GithubUserInfoResponse struct { - Login string `json:"login"` - Name string `json:"name"` -} - -// The scopes required for the github provider -func GithubScopes() []string { - return []string{"user:email", "read:user"} -} - -func GetGithubUser(client *http.Client) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get("https://api.github.com/user") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got user response from github") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read user body from github") - - var userInfo GithubUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - res, err = client.Get("https://api.github.com/user/emails") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got email response from github") - - body, err = io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read email body from github") - - var emails GithubEmailResponse - - err = json.Unmarshal(body, &emails) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed emails from github") - - // Find and return the primary email - for _, email := range emails { - if email.Primary { - log.Debug().Str("email", email.Email).Msg("Found primary email") - user.Email = email.Email - break - } - } - - if len(emails) == 0 { - return user, errors.New("no emails found") - } - - // Use first available email if no primary email was found - if user.Email == "" { - log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email") - user.Email = emails[0].Email - } - - user.PreferredUsername = userInfo.Login - user.Name = userInfo.Name - - return user, nil -} diff --git a/internal/providers/google.go b/internal/providers/google.go deleted file mode 100644 index e794beec..00000000 --- a/internal/providers/google.go +++ /dev/null @@ -1,56 +0,0 @@ -package providers - -import ( - "encoding/json" - "io" - "net/http" - "strings" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -// Response for the google user endpoint -type GoogleUserInfoResponse struct { - Email string `json:"email"` - Name string `json:"name"` -} - -// The scopes required for the google provider -func GoogleScopes() []string { - return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} -} - -func GetGoogleUser(client *http.Client) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got response from google") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read body from google") - - var userInfo GoogleUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed user from google") - - user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] - user.Name = userInfo.Name - user.Email = userInfo.Email - - return user, nil -} diff --git a/internal/providers/providers.go b/internal/providers/providers.go deleted file mode 100644 index 7af127ea..00000000 --- a/internal/providers/providers.go +++ /dev/null @@ -1,154 +0,0 @@ -package providers - -import ( - "fmt" - "tinyauth/internal/constants" - "tinyauth/internal/oauth" - "tinyauth/internal/types" - - "github.com/rs/zerolog/log" - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -type Providers struct { - Config types.OAuthConfig - Github *oauth.OAuth - Google *oauth.OAuth - Generic *oauth.OAuth -} - -func NewProviders(config types.OAuthConfig) *Providers { - providers := &Providers{ - Config: config, - } - - if config.GithubClientId != "" && config.GithubClientSecret != "" { - log.Info().Msg("Initializing Github OAuth") - providers.Github = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GithubClientId, - ClientSecret: config.GithubClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL), - Scopes: GithubScopes(), - Endpoint: endpoints.GitHub, - }, false) - } - - if config.GoogleClientId != "" && config.GoogleClientSecret != "" { - log.Info().Msg("Initializing Google OAuth") - providers.Google = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GoogleClientId, - ClientSecret: config.GoogleClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL), - Scopes: GoogleScopes(), - Endpoint: endpoints.Google, - }, false) - } - - if config.GenericClientId != "" && config.GenericClientSecret != "" { - log.Info().Msg("Initializing Generic OAuth") - providers.Generic = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GenericClientId, - ClientSecret: config.GenericClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL), - Scopes: config.GenericScopes, - Endpoint: oauth2.Endpoint{ - AuthURL: config.GenericAuthURL, - TokenURL: config.GenericTokenURL, - }, - }, config.GenericSkipSSL) - } - - return providers -} - -func (providers *Providers) GetProvider(provider string) *oauth.OAuth { - switch provider { - case "github": - return providers.Github - case "google": - return providers.Google - case "generic": - return providers.Generic - default: - return nil - } -} - -func (providers *Providers) GetUser(provider string) (constants.Claims, error) { - var user constants.Claims - - // Get the user from the provider - switch provider { - case "github": - if providers.Github == nil { - log.Debug().Msg("Github provider not configured") - return user, nil - } - - client := providers.Github.GetClient() - - log.Debug().Msg("Got client from github") - - user, err := GetGithubUser(client) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from github") - - return user, nil - case "google": - if providers.Google == nil { - log.Debug().Msg("Google provider not configured") - return user, nil - } - - client := providers.Google.GetClient() - - log.Debug().Msg("Got client from google") - - user, err := GetGoogleUser(client) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from google") - - return user, nil - case "generic": - if providers.Generic == nil { - log.Debug().Msg("Generic provider not configured") - return user, nil - } - - client := providers.Generic.GetClient() - - log.Debug().Msg("Got client from generic") - - user, err := GetGenericUser(client, providers.Config.GenericUserURL) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from generic") - - return user, nil - default: - return user, nil - } -} - -func (provider *Providers) GetConfiguredProviders() []string { - providers := []string{} - if provider.Github != nil { - providers = append(providers, "github") - } - if provider.Google != nil { - providers = append(providers, "google") - } - if provider.Generic != nil { - providers = append(providers, "generic") - } - return providers -} diff --git a/internal/auth/auth.go b/internal/service/auth_service.go similarity index 83% rename from internal/auth/auth.go rename to internal/service/auth_service.go index 3f18419a..ebbd1ad9 100644 --- a/internal/auth/auth.go +++ b/internal/service/auth_service.go @@ -1,4 +1,4 @@ -package auth +package service import ( "fmt" @@ -6,8 +6,6 @@ import ( "strings" "sync" "time" - "tinyauth/internal/docker" - "tinyauth/internal/ldap" "tinyauth/internal/types" "tinyauth/internal/utils" @@ -17,35 +15,50 @@ import ( "golang.org/x/crypto/bcrypt" ) -type Auth struct { - Config types.AuthConfig - Docker *docker.Docker +type AuthServiceConfig struct { + Users types.Users + OauthWhitelist string + SessionExpiry int + CookieSecure bool + Domain string + LoginTimeout int + LoginMaxRetries int + SessionCookieName string + HMACSecret string + EncryptionSecret string +} + +type AuthService struct { + Config AuthServiceConfig + Docker *DockerService LoginAttempts map[string]*types.LoginAttempt LoginMutex sync.RWMutex Store *sessions.CookieStore - LDAP *ldap.LDAP + LDAP *LdapService } -func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth { - // Setup cookie store and create the auth service - store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret)) - store.Options = &sessions.Options{ - Path: "/", - MaxAge: config.SessionExpiry, - Secure: config.CookieSecure, - HttpOnly: true, - Domain: fmt.Sprintf(".%s", config.Domain), - } - return &Auth{ +func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService { + return &AuthService{ Config: config, Docker: docker, LoginAttempts: make(map[string]*types.LoginAttempt), - Store: store, LDAP: ldap, } } -func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { +func (auth *AuthService) Init() error { + store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret)) + store.Options = &sessions.Options{ + Path: "/", + MaxAge: auth.Config.SessionExpiry, + Secure: auth.Config.CookieSecure, + HttpOnly: true, + Domain: fmt.Sprintf(".%s", auth.Config.Domain), + } + return nil +} + +func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) // If there was an error getting the session, it might be invalid so let's clear it and retry @@ -62,7 +75,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { return session, nil } -func (auth *Auth) SearchUser(username string) types.UserSearch { +func (auth *AuthService) SearchUser(username string) types.UserSearch { log.Debug().Str("username", username).Msg("Searching for user") // Check local users first @@ -93,7 +106,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { } } -func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { +func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { // Authenticate the user based on the type switch search.Type { case "local": @@ -131,7 +144,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { return false } -func (auth *Auth) GetLocalUser(username string) types.User { +func (auth *AuthService) GetLocalUser(username string) types.User { // Loop through users and return the user if the username matches log.Debug().Str("username", username).Msg("Searching for local user") @@ -146,11 +159,11 @@ func (auth *Auth) GetLocalUser(username string) types.User { return types.User{} } -func (auth *Auth) CheckPassword(user types.User, password string) bool { +func (auth *AuthService) CheckPassword(user types.User, password string) bool { return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } -func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { +func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { auth.LoginMutex.RLock() defer auth.LoginMutex.RUnlock() @@ -176,7 +189,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { return false, 0 } -func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { +func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { // Skip if rate limiting is not configured if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { return @@ -212,11 +225,11 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { } } -func (auth *Auth) EmailWhitelisted(email string) bool { +func (auth *AuthService) EmailWhitelisted(email string) bool { return utils.CheckFilter(auth.Config.OauthWhitelist, email) } -func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { +func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { log.Debug().Msg("Creating session cookie") session, err := auth.GetSession(c) @@ -252,7 +265,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) return nil } -func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { +func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { log.Debug().Msg("Deleting session cookie") session, err := auth.GetSession(c) @@ -275,7 +288,7 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { return nil } -func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { +func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { log.Debug().Msg("Getting session cookie") session, err := auth.GetSession(c) @@ -319,12 +332,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) }, nil } -func (auth *Auth) UserAuthConfigured() bool { +func (auth *AuthService) UserAuthConfigured() bool { // If there are users or LDAP is configured, return true return len(auth.Config.Users) > 0 || auth.LDAP != nil } -func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { if context.OAuth { log.Debug().Msg("Checking OAuth whitelist") return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) @@ -334,7 +347,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab return utils.CheckFilter(labels.Users, context.Username) } -func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { if labels.OAuth.Groups == "" { return true } @@ -361,7 +374,7 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t return false } -func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { +func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { // If the label is empty, auth is enabled if labels.Allowed == "" { return true, nil @@ -385,7 +398,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { return true, nil } -func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { +func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { username, password, ok := c.Request.BasicAuth() if !ok { return nil @@ -396,7 +409,7 @@ func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { } } -func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { +func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { // Check if the IP is in block list for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) @@ -433,7 +446,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { return true } -func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { +func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { // For every IP in the bypass list, check if the IP matches for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) diff --git a/internal/docker/docker.go b/internal/service/docker_service.go similarity index 76% rename from internal/docker/docker.go rename to internal/service/docker_service.go index f5a04681..f067d7f2 100644 --- a/internal/docker/docker.go +++ b/internal/service/docker_service.go @@ -1,9 +1,9 @@ -package docker +package service import ( "context" "strings" - "tinyauth/internal/types" + "tinyauth/internal/config" "tinyauth/internal/utils" container "github.com/docker/docker/api/types/container" @@ -11,27 +11,27 @@ import ( "github.com/rs/zerolog/log" ) -type Docker struct { +type DockerService struct { Client *client.Client Context context.Context } -func NewDocker() (*Docker, error) { +func NewDockerService() *DockerService { + return &DockerService{} +} + +func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { - return nil, err + return err } ctx := context.Background() client.NegotiateAPIVersion(ctx) - - return &Docker{ - Client: client, - Context: ctx, - }, nil + return nil } -func (docker *Docker) GetContainers() ([]container.Summary, error) { +func (docker *DockerService) GetContainers() ([]container.Summary, error) { containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) if err != nil { return nil, err @@ -39,7 +39,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) { return containers, nil } -func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { +func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) { inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) if err != nil { return container.InspectResponse{}, err @@ -47,17 +47,17 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes return inspect, nil } -func (docker *Docker) DockerConnected() bool { +func (docker *DockerService) DockerConnected() bool { _, err := docker.Client.Ping(docker.Context) return err == nil } -func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) { +func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { isConnected := docker.DockerConnected() if !isConnected { log.Debug().Msg("Docker not connected, returning empty labels") - return types.Labels{}, nil + return config.Labels{}, nil } log.Debug().Msg("Getting containers") @@ -65,7 +65,7 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) containers, err := docker.GetContainers() if err != nil { log.Error().Err(err).Msg("Error getting containers") - return types.Labels{}, err + return config.Labels{}, err } for _, container := range containers { @@ -98,5 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) } log.Debug().Msg("No matching container found, returning empty labels") - return types.Labels{}, nil + return config.Labels{}, nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go new file mode 100644 index 00000000..9bd6a8ee --- /dev/null +++ b/internal/service/generic_oauth_service.go @@ -0,0 +1,114 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/tls" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +type GenericOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string + InsecureSkipVerify bool + ServiceName string + UserinfoURL string +} + +func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { + return &GenericOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + }, + }, + InsecureSkipVerify: config.InsecureSkipVerify, + ServiceName: config.Name, + UserinfoURL: config.UserinfoURL, + } +} + +func (generic *GenericOAuthService) Init() error { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: generic.InsecureSkipVerify, + MinVersion: tls.VersionTLS12, + }, + } + + httpClient := &http.Client{ + Transport: transport, + } + + ctx := context.Background() + + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + generic.Context = ctx + generic.Verifier = verifier + return nil +} + +func (generic *GenericOAuthService) Name() string { + return generic.ServiceName +} + +func (generic *GenericOAuthService) GenerateState() string { + b := make([]byte, 128) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + return state +} + +func (generic *GenericOAuthService) GetAuthURL(state string) string { + return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier)) +} + +func (generic *GenericOAuthService) VerifyCode(code string) error { + token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) + + if err != nil { + return nil + } + + generic.Token = token + return nil +} + +func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := generic.Config.Client(generic.Context, generic.Token) + + res, err := client.Get(generic.UserinfoURL) + if err != nil { + return user, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return user, err + } + + err = json.Unmarshal(body, &user) + if err != nil { + return user, err + } + + return user, nil +} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go new file mode 100644 index 00000000..57d8391e --- /dev/null +++ b/internal/service/github_oauth_service.go @@ -0,0 +1,144 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +var GithubOAuthScopes = []string{"user:email", "read:user"} + +type GithubEmailResponse []struct { + Email string `json:"email"` + Primary bool `json:"primary"` +} + +type GithubUserInfoResponse struct { + Login string `json:"login"` + Name string `json:"name"` +} + +type GithubOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { + return &GithubOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: GithubOAuthScopes, + }, + } +} + +func (github *GithubOAuthService) Init() error { + httpClient := &http.Client{} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + github.Context = ctx + github.Verifier = verifier + return nil +} + +func (github *GithubOAuthService) Name() string { + return "github" +} + +func (github *GithubOAuthService) GenerateState() string { + b := make([]byte, 128) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + return state +} + +func (github *GithubOAuthService) GetAuthURL(state string) string { + return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier)) +} + +func (github *GithubOAuthService) VerifyCode(code string) error { + token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) + + if err != nil { + return nil + } + + github.Token = token + return nil +} + +func (github *GithubOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := github.Config.Client(github.Context, github.Token) + + res, err := client.Get("https://api.github.com/user") + if err != nil { + return user, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return user, err + } + + var userInfo GithubUserInfoResponse + + err = json.Unmarshal(body, &userInfo) + if err != nil { + return user, err + } + + res, err = client.Get("https://api.github.com/user/emails") + if err != nil { + return user, err + } + defer res.Body.Close() + + body, err = io.ReadAll(res.Body) + if err != nil { + return user, err + } + + var emails GithubEmailResponse + + err = json.Unmarshal(body, &emails) + if err != nil { + return user, err + } + + for _, email := range emails { + if email.Primary { + user.Email = email.Email + break + } + } + + if len(emails) == 0 { + return user, errors.New("no emails found") + } + + // Use first available email if no primary email was found + if user.Email == "" { + user.Email = emails[0].Email + } + + user.PreferredUsername = userInfo.Login + user.Name = userInfo.Name + + return user, nil +} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go new file mode 100644 index 00000000..2d86a566 --- /dev/null +++ b/internal/service/google_oauth_service.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "strings" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} + +type GoogleUserInfoResponse struct { + Email string `json:"email"` + Name string `json:"name"` +} + +type GoogleOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { + return &GoogleOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: GoogleOAuthScopes, + }, + } +} + +func (google *GoogleOAuthService) Init() error { + httpClient := &http.Client{} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + google.Context = ctx + google.Verifier = verifier + return nil +} + +func (google *GoogleOAuthService) Name() string { + return "google" +} + +func (oauth *GoogleOAuthService) GenerateState() string { + b := make([]byte, 128) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + return state +} + +func (google *GoogleOAuthService) GetAuthURL(state string) string { + return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier)) +} + +func (google *GoogleOAuthService) VerifyCode(code string) error { + token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) + + if err != nil { + return nil + } + + google.Token = token + return nil +} + +func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := google.Config.Client(google.Context, google.Token) + + res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") + if err != nil { + return config.Claims{}, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return config.Claims{}, err + } + + var userInfo GoogleUserInfoResponse + + err = json.Unmarshal(body, &userInfo) + if err != nil { + return config.Claims{}, err + } + + user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] + user.Name = userInfo.Name + user.Email = userInfo.Email + + return user, nil +} diff --git a/internal/ldap/ldap.go b/internal/service/ldap_service.go similarity index 64% rename from internal/ldap/ldap.go rename to internal/service/ldap_service.go index 61578d76..805e2f72 100644 --- a/internal/ldap/ldap.go +++ b/internal/service/ldap_service.go @@ -1,30 +1,40 @@ -package ldap +package service import ( "context" "crypto/tls" "fmt" "time" - "tinyauth/internal/types" "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" "github.com/rs/zerolog/log" ) -type LDAP struct { - Config types.LdapConfig +type LdapServiceConfig struct { + Address string + BindDN string + BindPassword string + BaseDN string + Insecure bool + SearchFilter string +} + +type LdapService struct { + Config LdapServiceConfig Conn *ldapgo.Conn } -func NewLDAP(config types.LdapConfig) (*LDAP, error) { - ldap := &LDAP{ +func NewLdapService(config LdapServiceConfig) *LdapService { + return &LdapService{ Config: config, } +} +func (ldap *LdapService) Init() error { _, err := ldap.connect() if err != nil { - return nil, fmt.Errorf("failed to connect to LDAP server: %w", err) + return fmt.Errorf("failed to connect to LDAP server: %w", err) } go func() { @@ -41,13 +51,13 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { } }() - return ldap, nil + return nil } -func (l *LDAP) connect() (*ldapgo.Conn, error) { +func (ldap *LdapService) connect() (*ldapgo.Conn, error) { log.Debug().Msg("Connecting to LDAP server") - conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: l.Config.Insecure, + conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.Config.Insecure, MinVersion: tls.VersionTLS12, })) if err != nil { @@ -55,30 +65,30 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) { } log.Debug().Msg("Binding to LDAP server") - err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) + err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword) if err != nil { return nil, err } // Set and return the connection - l.Conn = conn + ldap.Conn = conn return conn, nil } -func (l *LDAP) Search(username string) (string, error) { +func (ldap *LdapService) Search(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - l.Config.BaseDN, + ldap.Config.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, nil, ) - searchResult, err := l.Conn.Search(searchRequest) + searchResult, err := ldap.Conn.Search(searchRequest) if err != nil { return "", err } @@ -91,15 +101,15 @@ func (l *LDAP) Search(username string) (string, error) { return userDN, nil } -func (l *LDAP) Bind(userDN string, password string) error { - err := l.Conn.Bind(userDN, password) +func (ldap *LdapService) Bind(userDN string, password string) error { + err := ldap.Conn.Bind(userDN, password) if err != nil { return err } return nil } -func (l *LDAP) heartbeat() error { +func (ldap *LdapService) heartbeat() error { log.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( @@ -110,7 +120,7 @@ func (l *LDAP) heartbeat() error { nil, ) - _, err := l.Conn.Search(searchRequest) + _, err := ldap.Conn.Search(searchRequest) if err != nil { return err } @@ -119,7 +129,7 @@ func (l *LDAP) heartbeat() error { return nil } -func (l *LDAP) reconnect() error { +func (ldap *LdapService) reconnect() error { log.Info().Msg("Reconnecting to LDAP server") exp := backoff.NewExponentialBackOff() @@ -129,8 +139,8 @@ func (l *LDAP) reconnect() error { exp.Reset() operation := func() (*ldapgo.Conn, error) { - l.Conn.Close() - conn, err := l.connect() + ldap.Conn.Close() + conn, err := ldap.connect() if err != nil { return nil, nil } diff --git a/internal/types/types.go b/internal/types/types.go deleted file mode 100644 index 1cb6bed8..00000000 --- a/internal/types/types.go +++ /dev/null @@ -1,70 +0,0 @@ -package types - -import ( - "time" - "tinyauth/internal/oauth" -) - -// User is the struct for a user -type User struct { - Username string - Password string - TotpSecret string -} - -// UserSearch is the response of the get user -type UserSearch struct { - Username string - Type string // "local", "ldap" or empty -} - -// Users is a list of users -type Users []User - -// OAuthProviders is the struct for the OAuth providers -type OAuthProviders struct { - Github *oauth.OAuth - Google *oauth.OAuth - Microsoft *oauth.OAuth -} - -// SessionCookie is the cookie for the session (exculding the expiry) -type SessionCookie struct { - Username string - Name string - Email string - Provider string - TotpPending bool - OAuthGroups string -} - -// UserContext is the context for the user -type UserContext struct { - Username string - Name string - Email string - IsLoggedIn bool - OAuth bool - Provider string - TotpPending bool - OAuthGroups string - TotpEnabled bool -} - -// LoginAttempt tracks information about login attempts for rate limiting -type LoginAttempt struct { - FailedAttempts int - LastAttempt time.Time - LockedUntil time.Time -} - -type UnauthorizedQuery struct { - Username string `url:"username"` - Resource string `url:"resource"` - GroupErr bool `url:"groupErr"` - IP string `url:"ip"` -} - -type RedirectQuery struct { - RedirectURI string `url:"redirect_uri"` -} From dbadb096b4a863c64e53aee6978211fd9fb15246 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 19:33:52 +0300 Subject: [PATCH 05/17] feat: create oauth broker service --- internal/config/config.go | 11 - internal/controller/oauth_controller.go | 45 +- internal/controller/proxy_controller.go | 23 +- internal/controller/user_controller.go | 14 +- internal/middleware/context_middleware.go | 33 +- internal/service/auth_service.go | 60 +-- internal/service/generic_oauth_service.go | 6 - internal/service/github_oauth_service.go | 4 - internal/service/google_oauth_service.go | 4 - internal/service/oauth_broker_service.go | 76 +++ internal/utils/utils.go | 43 +- internal/utils/utils_test.go | 548 ---------------------- 12 files changed, 184 insertions(+), 683 deletions(-) create mode 100644 internal/service/oauth_broker_service.go delete mode 100644 internal/utils/utils_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 48dc47f0..5584d0e8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,5 @@ package config -import "time" - type Claims struct { Name string `json:"name"` Email string `json:"email"` @@ -100,7 +98,6 @@ type OAuthServiceConfig struct { TokenURL string UserinfoURL string InsecureSkipVerify bool - Name string } type User struct { @@ -114,8 +111,6 @@ type UserSearch struct { Type string // local, ldap or unknown } -type Users []User - type SessionCookie struct { Username string Name string @@ -137,12 +132,6 @@ type UserContext struct { TotpEnabled bool } -type LoginAttempt struct { - FailedAttempts int - LastAttempt time.Time - LockedUntil time.Time -} - type UnauthorizedQuery struct { Username string `url:"username"` Resource string `url:"resource"` diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 63b63229..0178af6c 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -5,9 +5,8 @@ import ( "net/http" "strings" "time" - "tinyauth/internal/auth" - "tinyauth/internal/providers" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -26,18 +25,18 @@ type OAuthControllerConfig struct { } type OAuthController struct { - Config OAuthControllerConfig - Router *gin.RouterGroup - Auth *auth.Auth - Providers *providers.Providers + Config OAuthControllerConfig + Router *gin.RouterGroup + Auth *service.AuthService + Broker *service.OAuthBrokerService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController { +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { return &OAuthController{ - Config: config, - Router: router, - Auth: auth, - Providers: providers, + Config: config, + Router: router, + Auth: auth, + Broker: broker, } } @@ -59,9 +58,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - provider := controller.Providers.GetProvider(req.Provider) + service, exists := controller.Broker.GetService(req.Provider) - if provider == nil { + if !exists { c.JSON(404, gin.H{ "status": 404, "message": "Not Found", @@ -69,8 +68,8 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - state := provider.GenerateState() - authURL := provider.GetAuthURL(state) + state := service.GenerateState() + authURL := service.GetAuthURL(state) c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) redirectURI := c.Query("redirect_uri") @@ -109,20 +108,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) code := c.Query("code") - provider := controller.Providers.GetProvider(req.Provider) + service, exists := controller.Broker.GetService(req.Provider) - if provider == nil { + if !exists { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } - _, err = provider.ExchangeToken(code) + err = service.VerifyCode(code) if err != nil { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } - user, err := controller.Providers.GetUser(req.Provider) + user, err := controller.Broker.GetUser(req.Provider) if err != nil { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) @@ -135,7 +134,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } if !controller.Auth.EmailWhitelisted(user.Email) { - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Username: user.Email, }) @@ -156,7 +155,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Email, Name: name, Email: user.Email, @@ -171,7 +170,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - queries, err := query.Values(types.RedirectQuery{ + queries, err := query.Values(config.RedirectQuery{ RedirectURI: redirectURI, }) diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index f8476f01..ced09bff 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -4,9 +4,8 @@ import ( "fmt" "net/http" "strings" - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -24,11 +23,11 @@ type ProxyControllerConfig struct { type ProxyController struct { Config ProxyControllerConfig Router *gin.RouterGroup - Docker *docker.Docker - Auth *auth.Auth + Docker *service.DockerService + Auth *service.AuthService } -func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController { +func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController { return &ProxyController{ Config: config, Router: router, @@ -109,7 +108,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(host, ".")[0], IP: clientIP, }) @@ -157,12 +156,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - var userContext types.UserContext + var userContext config.UserContext context, err := utils.GetContext(c) if err != nil { - userContext = types.UserContext{ + userContext = config.UserContext{ IsLoggedIn: false, } } else { @@ -185,7 +184,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(host, ".")[0], }) @@ -216,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(host, ".")[0], GroupErr: true, }) @@ -268,7 +267,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.RedirectQuery{ + queries, err := query.Values(config.RedirectQuery{ RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), }) diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index e017826a..77bb6f33 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -3,8 +3,8 @@ package controller import ( "fmt" "strings" - "tinyauth/internal/auth" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -27,10 +27,10 @@ type UserControllerConfig struct { type UserController struct { Config UserControllerConfig Router *gin.RouterGroup - Auth *auth.Auth + Auth *service.AuthService } -func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController { +func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { return &UserController{ Config: config, Router: router, @@ -101,7 +101,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { user := controller.Auth.GetLocalUser(userSearch.Username) if user.TotpSecret != "" { - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), @@ -118,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: req.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), @@ -202,7 +202,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { controller.Auth.RecordLoginAttempt(rateIdentifier, true) - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index ead4879b..a83d4659 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -3,9 +3,8 @@ package middleware import ( "fmt" "strings" - "tinyauth/internal/auth" - "tinyauth/internal/providers" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -16,16 +15,16 @@ type ContextMiddlewareConfig struct { } type ContextMiddleware struct { - Config ContextMiddlewareConfig - Auth *auth.Auth - Providers *providers.Providers + Config ContextMiddlewareConfig + Auth *service.AuthService + Broker *service.OAuthBrokerService } -func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware { +func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { return &ContextMiddleware{ - Config: config, - Auth: auth, - Providers: providers, + Config: config, + Auth: auth, + Broker: broker, } } @@ -46,7 +45,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { } if cookie.TotpPending { - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: cookie.Username, Name: cookie.Name, Email: cookie.Email, @@ -66,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { goto basic } - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: cookie.Username, Name: cookie.Name, Email: cookie.Email, @@ -76,9 +75,9 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return default: - provider := m.Providers.GetProvider(cookie.Provider) + _, exists := m.Broker.GetService(cookie.Provider) - if provider == nil { + if !exists { goto basic } @@ -87,7 +86,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { goto basic } - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: cookie.Username, Name: cookie.Name, Email: cookie.Email, @@ -124,7 +123,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { case "local": user := m.Auth.GetLocalUser(basic.Username) - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain), @@ -135,7 +134,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return case "ldap": - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: basic.Username, Name: utils.Capitalize(basic.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain), diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ebbd1ad9..46bad066 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -6,7 +6,7 @@ import ( "strings" "sync" "time" - "tinyauth/internal/types" + "tinyauth/internal/config" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -15,8 +15,14 @@ import ( "golang.org/x/crypto/bcrypt" ) +type LoginAttempt struct { + FailedAttempts int + LastAttempt time.Time + LockedUntil time.Time +} + type AuthServiceConfig struct { - Users types.Users + Users []config.User OauthWhitelist string SessionExpiry int CookieSecure bool @@ -31,7 +37,7 @@ type AuthServiceConfig struct { type AuthService struct { Config AuthServiceConfig Docker *DockerService - LoginAttempts map[string]*types.LoginAttempt + LoginAttempts map[string]*LoginAttempt LoginMutex sync.RWMutex Store *sessions.CookieStore LDAP *LdapService @@ -41,7 +47,7 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS return &AuthService{ Config: config, Docker: docker, - LoginAttempts: make(map[string]*types.LoginAttempt), + LoginAttempts: make(map[string]*LoginAttempt), LDAP: ldap, } } @@ -75,13 +81,13 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { return session, nil } -func (auth *AuthService) SearchUser(username string) types.UserSearch { +func (auth *AuthService) SearchUser(username string) config.UserSearch { log.Debug().Str("username", username).Msg("Searching for user") // Check local users first if auth.GetLocalUser(username).Username != "" { log.Debug().Str("username", username).Msg("Found local user") - return types.UserSearch{ + return config.UserSearch{ Username: username, Type: "local", } @@ -93,20 +99,20 @@ func (auth *AuthService) SearchUser(username string) types.UserSearch { userDN, err := auth.LDAP.Search(username) if err != nil { log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") - return types.UserSearch{} + return config.UserSearch{} } - return types.UserSearch{ + return config.UserSearch{ Username: userDN, Type: "ldap", } } - return types.UserSearch{ + return config.UserSearch{ Type: "unknown", } } -func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { +func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { // Authenticate the user based on the type switch search.Type { case "local": @@ -144,7 +150,7 @@ func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bo return false } -func (auth *AuthService) GetLocalUser(username string) types.User { +func (auth *AuthService) GetLocalUser(username string) config.User { // Loop through users and return the user if the username matches log.Debug().Str("username", username).Msg("Searching for local user") @@ -156,10 +162,10 @@ func (auth *AuthService) GetLocalUser(username string) types.User { // If no user found, return an empty user log.Warn().Str("username", username).Msg("Local user not found") - return types.User{} + return config.User{} } -func (auth *AuthService) CheckPassword(user types.User, password string) bool { +func (auth *AuthService) CheckPassword(user config.User, password string) bool { return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } @@ -201,7 +207,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { // Get current attempt record or create a new one attempt, exists := auth.LoginAttempts[identifier] if !exists { - attempt = &types.LoginAttempt{} + attempt = &LoginAttempt{} auth.LoginAttempts[identifier] = attempt } @@ -229,7 +235,7 @@ func (auth *AuthService) EmailWhitelisted(email string) bool { return utils.CheckFilter(auth.Config.OauthWhitelist, email) } -func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { +func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error { log.Debug().Msg("Creating session cookie") session, err := auth.GetSession(c) @@ -288,13 +294,13 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { return nil } -func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { +func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) { log.Debug().Msg("Getting session cookie") session, err := auth.GetSession(c) if err != nil { log.Error().Err(err).Msg("Failed to get session") - return types.SessionCookie{}, err + return config.SessionCookie{}, err } log.Debug().Msg("Got session") @@ -311,18 +317,18 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { log.Warn().Msg("Session cookie is invalid") auth.DeleteSessionCookie(c) - return types.SessionCookie{}, nil + return config.SessionCookie{}, nil } // If the session cookie has expired, delete it if time.Now().Unix() > expiry { log.Warn().Msg("Session cookie expired") auth.DeleteSessionCookie(c) - return types.SessionCookie{}, nil + return config.SessionCookie{}, nil } log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") - return types.SessionCookie{ + return config.SessionCookie{ Username: username, Name: name, Email: email, @@ -337,7 +343,7 @@ func (auth *AuthService) UserAuthConfigured() bool { return len(auth.Config.Users) > 0 || auth.LDAP != nil } -func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool { if context.OAuth { log.Debug().Msg("Checking OAuth whitelist") return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) @@ -347,7 +353,7 @@ func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserConte return utils.CheckFilter(labels.Users, context.Username) } -func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool { if labels.OAuth.Groups == "" { return true } @@ -374,7 +380,7 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, l return false } -func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { +func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) { // If the label is empty, auth is enabled if labels.Allowed == "" { return true, nil @@ -398,18 +404,18 @@ func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, err return true, nil } -func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { +func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { username, password, ok := c.Request.BasicAuth() if !ok { return nil } - return &types.User{ + return &config.User{ Username: username, Password: password, } } -func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { +func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { // Check if the IP is in block list for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) @@ -446,7 +452,7 @@ func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { return true } -func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { +func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool { // For every IP in the bypass list, check if the IP matches for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index 9bd6a8ee..c68d150f 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -19,7 +19,6 @@ type GenericOAuthService struct { Token *oauth2.Token Verifier string InsecureSkipVerify bool - ServiceName string UserinfoURL string } @@ -36,7 +35,6 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi }, }, InsecureSkipVerify: config.InsecureSkipVerify, - ServiceName: config.Name, UserinfoURL: config.UserinfoURL, } } @@ -63,10 +61,6 @@ func (generic *GenericOAuthService) Init() error { return nil } -func (generic *GenericOAuthService) Name() string { - return generic.ServiceName -} - func (generic *GenericOAuthService) GenerateState() string { b := make([]byte, 128) rand.Read(b) diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 57d8391e..a8c13345 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -54,10 +54,6 @@ func (github *GithubOAuthService) Init() error { return nil } -func (github *GithubOAuthService) Name() string { - return "github" -} - func (github *GithubOAuthService) GenerateState() string { b := make([]byte, 128) rand.Read(b) diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 2d86a566..6d9eaed8 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -49,10 +49,6 @@ func (google *GoogleOAuthService) Init() error { return nil } -func (google *GoogleOAuthService) Name() string { - return "google" -} - func (oauth *GoogleOAuthService) GenerateState() string { b := make([]byte, 128) rand.Read(b) diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go new file mode 100644 index 00000000..6b5b1e6f --- /dev/null +++ b/internal/service/oauth_broker_service.go @@ -0,0 +1,76 @@ +package service + +import ( + "errors" + "tinyauth/internal/config" + + "github.com/rs/zerolog/log" +) + +type OAuthService interface { + Init() error + GenerateState() string + GetAuthURL(state string) string + VerifyCode(code string) error + Userinfo() (config.Claims, error) +} + +type OAuthBrokerService struct { + Services map[string]OAuthService + Configs map[string]config.OAuthServiceConfig +} + +func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { + return &OAuthBrokerService{ + Services: make(map[string]OAuthService), + Configs: configs, + } +} + +func (broker *OAuthBrokerService) Init() error { + for name, cfg := range broker.Configs { + switch name { + case "github": + service := NewGithubOAuthService(cfg) + broker.Services[name] = service + case "google": + service := NewGoogleOAuthService(cfg) + broker.Services[name] = service + default: + service := NewGenericOAuthService(cfg) + broker.Services[name] = service + } + } + + for name, service := range broker.Services { + err := service.Init() + if err != nil { + log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name) + return err + } + log.Info().Msgf("Initialized OAuth service: %s", name) + } + + return nil +} + +func (broker *OAuthBrokerService) GetConfiguredServices() []string { + services := make([]string, 0, len(broker.Services)) + for name := range broker.Services { + services = append(services, name) + } + return services +} + +func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) { + service, exists := broker.Services[name] + return service, exists +} + +func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) { + oauthService, exists := broker.Services[service] + if !exists { + return config.Claims{}, errors.New("oauth service not found") + } + return oauthService.Userinfo() +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 8c2f4ea3..67b904f6 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -11,7 +11,7 @@ import ( "os" "regexp" "strings" - "tinyauth/internal/types" + "tinyauth/internal/config" "github.com/gin-gonic/gin" "github.com/traefik/paerser/parser" @@ -22,21 +22,21 @@ import ( ) // Parses a list of comma separated users in a struct -func ParseUsers(users string) (types.Users, error) { +func ParseUsers(users string) ([]config.User, error) { log.Debug().Msg("Parsing users") - var usersParsed types.Users + var usersParsed []config.User userList := strings.Split(users, ",") if len(userList) == 0 { - return types.Users{}, errors.New("invalid user format") + return []config.User{}, errors.New("invalid user format") } for _, user := range userList { parsed, err := ParseUser(user) if err != nil { - return types.Users{}, err + return []config.User{}, err } usersParsed = append(usersParsed, parsed) } @@ -107,11 +107,11 @@ func GetSecret(conf string, file string) string { } // Get the users from the config or file -func GetUsers(conf string, file string) (types.Users, error) { +func GetUsers(conf string, file string) ([]config.User, error) { var users string if conf == "" && file == "" { - return types.Users{}, nil + return []config.User{}, nil } if conf != "" { @@ -152,23 +152,18 @@ func ParseHeaders(headers []string) map[string]string { } // Get labels parses a map of labels into a struct with only the needed labels -func GetLabels(labels map[string]string) (types.Labels, error) { - var labelsParsed types.Labels +func GetLabels(labels map[string]string) (config.Labels, error) { + var labelsParsed config.Labels err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") if err != nil { log.Error().Err(err).Msg("Error parsing labels") - return types.Labels{}, err + return config.Labels{}, err } return labelsParsed, nil } -// Check if any of the OAuth providers are configured based on the client id and secret -func OAuthConfigured(config types.Config) bool { - return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") -} - // Filter helper function func Filter[T any](slice []T, test func(T) bool) (res []T) { for _, value := range slice { @@ -180,7 +175,7 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) { } // Parse user -func ParseUser(user string) (types.User, error) { +func ParseUser(user string) (config.User, error) { if strings.Contains(user, "$$") { user = strings.ReplaceAll(user, "$$", "$") } @@ -188,23 +183,23 @@ func ParseUser(user string) (types.User, error) { userSplit := strings.Split(user, ":") if len(userSplit) < 2 || len(userSplit) > 3 { - return types.User{}, errors.New("invalid user format") + return config.User{}, errors.New("invalid user format") } for _, userPart := range userSplit { if strings.TrimSpace(userPart) == "" { - return types.User{}, errors.New("invalid user format") + return config.User{}, errors.New("invalid user format") } } if len(userSplit) == 2 { - return types.User{ + return config.User{ Username: strings.TrimSpace(userSplit[0]), Password: strings.TrimSpace(userSplit[1]), }, nil } - return types.User{ + return config.User{ Username: strings.TrimSpace(userSplit[0]), Password: strings.TrimSpace(userSplit[1]), TotpSecret: strings.TrimSpace(userSplit[2]), @@ -350,17 +345,17 @@ func CoalesceToString(value any) string { } } -func GetContext(c *gin.Context) (types.UserContext, error) { +func GetContext(c *gin.Context) (config.UserContext, error) { userContextValue, exists := c.Get("context") if !exists { - return types.UserContext{}, errors.New("no user context in request") + return config.UserContext{}, errors.New("no user context in request") } - userContext, ok := userContextValue.(*types.UserContext) + userContext, ok := userContextValue.(*config.UserContext) if !ok { - return types.UserContext{}, errors.New("invalid user context in request") + return config.UserContext{}, errors.New("invalid user context in request") } return *userContext, nil diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go deleted file mode 100644 index 5ae7e897..00000000 --- a/internal/utils/utils_test.go +++ /dev/null @@ -1,548 +0,0 @@ -package utils_test - -import ( - "fmt" - "os" - "reflect" - "testing" - "tinyauth/internal/types" - "tinyauth/internal/utils" -) - -func TestParseUsers(t *testing.T) { - t.Log("Testing parse users with a valid string") - - users := "user1:pass1,user2:pass2" - expected := types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err := utils.ParseUsers(users) - if err != nil { - t.Fatalf("Error parsing users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestGetUpperDomain(t *testing.T) { - t.Log("Testing get upper domain with a valid url") - - url := "https://sub1.sub2.domain.com:8080" - expected := "sub2.domain.com" - - result, err := utils.GetUpperDomain(url) - if err != nil { - t.Fatalf("Error getting root url: %v", err) - } - - if expected != result { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestReadFile(t *testing.T) { - t.Log("Creating a test file") - - err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - t.Log("Testing read file with a valid file") - - data, err := utils.ReadFile("/tmp/test.txt") - if err != nil { - t.Fatalf("Error reading file: %v", err) - } - - if data != "test" { - t.Fatalf("Expected test, got %v", data) - } - - t.Log("Cleaning up test file") - - err = os.Remove("/tmp/test.txt") - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestParseFileToLine(t *testing.T) { - t.Log("Testing parse file to line with a valid string") - - content := "\nuser1:pass1\nuser2:pass2\n" - expected := "user1:pass1,user2:pass2" - - result := utils.ParseFileToLine(content) - - if expected != result { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestGetSecret(t *testing.T) { - t.Log("Testing get secret with an empty config and file") - - conf := "" - file := "/tmp/test.txt" - expected := "test" - - err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - result := utils.GetSecret(conf, file) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get secret with an empty file and a valid config") - - result = utils.GetSecret(expected, "") - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get secret with both a valid config and file") - - result = utils.GetSecret(expected, file) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Cleaning up test file") - - err = os.Remove(file) - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestGetUsers(t *testing.T) { - t.Log("Testing get users with a config and no file") - - conf := "user1:pass1,user2:pass2" - file := "" - expected := types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err := utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get users with a file and no config") - - conf = "" - file = "/tmp/test.txt" - expected = types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - result, err = utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get users with both a config and file") - - conf = "user3:pass3" - expected = types.Users{ - { - Username: "user3", - Password: "pass3", - }, - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err = utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Cleaning up test file") - - err = os.Remove(file) - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestGetLabels(t *testing.T) { - t.Log("Testing get labels with a valid map") - - labels := map[string]string{ - "tinyauth.users": "user1,user2", - "tinyauth.oauth.whitelist": "/regex/", - "tinyauth.allowed": "random", - "tinyauth.headers": "X-Header=value", - "tinyauth.oauth.groups": "group1,group2", - } - - expected := types.Labels{ - Users: "user1,user2", - Allowed: "random", - Headers: []string{"X-Header=value"}, - OAuth: types.OAuthLabels{ - Whitelist: "/regex/", - Groups: "group1,group2", - }, - } - - result, err := utils.GetLabels(labels) - if err != nil { - t.Fatalf("Error getting labels: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseUser(t *testing.T) { - t.Log("Testing parse user with a valid user") - - user := "user:pass:secret" - expected := types.User{ - Username: "user", - Password: "pass", - TotpSecret: "secret", - } - - result, err := utils.ParseUser(user) - if err != nil { - t.Fatalf("Error parsing user: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse user with an escaped user") - - user = "user:p$$ass$$:secret" - expected = types.User{ - Username: "user", - Password: "p$ass$", - TotpSecret: "secret", - } - - result, err = utils.ParseUser(user) - if err != nil { - t.Fatalf("Error parsing user: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse user with an invalid user") - - user = "user::pass" - - _, err = utils.ParseUser(user) - if err == nil { - t.Fatalf("Expected error parsing user") - } -} - -func TestCheckFilter(t *testing.T) { - t.Log("Testing check filter with a comma separated list") - - filter := "user1,user2,user3" - str := "user1" - expected := true - - result := utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with a regex filter") - - filter = "/^user[0-9]+$/" - str = "user1" - expected = true - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with an empty filter") - - filter = "" - str = "user1" - expected = true - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with an invalid regex filter") - - filter = "/^user[0-9+$/" - str = "user1" - expected = false - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with a non matching list") - - filter = "user1,user2,user3" - str = "user4" - expected = false - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestSanitizeHeader(t *testing.T) { - t.Log("Testing sanitize header with a valid string") - - str := "X-Header=value" - expected := "X-Header=value" - - result := utils.SanitizeHeader(str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing sanitize header with an invalid string") - - str = "X-Header=val\nue" - expected = "X-Header=value" - - result = utils.SanitizeHeader(str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseHeaders(t *testing.T) { - t.Log("Testing parse headers with a valid string") - - headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"} - expected := map[string]string{ - "X-Header1": "value1", - "X-Header2": "value2", - } - - result := utils.ParseHeaders(headers) - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse headers with an invalid string") - - headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"} - expected = map[string]string{"X-Header3": "value3"} - - result = utils.ParseHeaders(headers) - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseSecretFile(t *testing.T) { - t.Log("Testing parse secret file with a valid file") - - content := "\n\n \n\n\n secret \n\n \n " - expected := "secret" - - result := utils.ParseSecretFile(content) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestFilterIP(t *testing.T) { - t.Log("Testing filter IP with an IP and a valid CIDR") - - ip := "10.10.10.10" - filter := "10.10.10.0/24" - expected := true - - result, err := utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and a valid IP") - - filter = "10.10.10.10" - expected = true - - result, err = utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and an non matching CIDR") - - filter = "10.10.15.0/24" - expected = false - - result, err = utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with a non matching IP and a valid CIDR") - - filter = "10.10.10.11" - expected = false - - result, err = utils.FilterIP(filter, ip) - - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and an invalid CIDR") - - filter = "10.../83" - - _, err = utils.FilterIP(filter, ip) - if err == nil { - t.Fatalf("Expected error filtering IP") - } -} - -func TestDeriveKey(t *testing.T) { - t.Log("Testing the derive key function") - - master := "master" - info := "info" - expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl" - - result, err := utils.DeriveKey(master, info) - - if err != nil { - t.Fatalf("Error deriving key: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestCoalesceToString(t *testing.T) { - t.Log("Testing coalesce to string with a string") - - value := any("test") - expected := "test" - - result := utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing coalesce to string with a slice of strings") - - value = []any{any("test1"), any("test2"), any(123)} - expected = "test1,test2" - - result = utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing coalesce to string with an unsupported type") - - value = 12345 - expected = "" - - result = utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} From 659d3561e0031c462ea40a67c9cc5b6e8b7c0b40 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:03:06 +0300 Subject: [PATCH 06/17] refactor: use a boostrap service to bootstrap the app --- .env.example | 4 +- cmd/root.go | 340 ++++++----------------- cmd/version.go | 8 +- internal/bootstrap/bootstrap_app.go | 246 ++++++++++++++++ internal/config/config.go | 7 +- internal/service/auth_service.go | 7 +- internal/service/github_oauth_service.go | 2 + internal/service/google_oauth_service.go | 2 + internal/utils/utils.go | 20 ++ 9 files changed, 366 insertions(+), 270 deletions(-) create mode 100644 internal/bootstrap/bootstrap_app.go diff --git a/.env.example b/.env.example index 8edde7b7..0f43bf04 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,7 @@ SECRET_FILE=app_secret_file APP_URL=http://localhost:3000 USERS=your_user_password_hash USERS_FILE=users_file -COOKIE_SECURE=false +SECURE_COOKIE=false GITHUB_CLIENT_ID=github_client_id GITHUB_CLIENT_SECRET=github_client_secret GITHUB_CLIENT_SECRET_FILE=github_client_secret_file @@ -25,7 +25,7 @@ GENERIC_NAME=My OAuth SESSION_EXPIRY=7200 LOGIN_TIMEOUT=300 LOGIN_MAX_RETRIES=5 -LOG_LEVEL=0 +LOG_LEVEL=debug APP_TITLE=Tinyauth SSO FORGOT_PASSWORD_MESSAGE=Some message about resetting the password OAUTH_AUTO_REDIRECT=none diff --git a/cmd/root.go b/cmd/root.go index 8dadd5d0..898c27f4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,22 +1,13 @@ package cmd import ( - "errors" - "fmt" "strings" totpCmd "tinyauth/cmd/totp" userCmd "tinyauth/cmd/user" - "tinyauth/internal/auth" - "tinyauth/internal/constants" - "tinyauth/internal/controller" - "tinyauth/internal/docker" - "tinyauth/internal/ldap" - "tinyauth/internal/middleware" - "tinyauth/internal/providers" - "tinyauth/internal/types" + "tinyauth/internal/bootstrap" + "tinyauth/internal/config" "tinyauth/internal/utils" - "github.com/gin-gonic/gin" "github.com/go-playground/validator" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -24,197 +15,51 @@ import ( "github.com/spf13/viper" ) -type Middleware interface { - Middleware() gin.HandlerFunc - Init() error - Name() string -} - var rootCmd = &cobra.Command{ Use: "tinyauth", Short: "The simplest way to protect your apps with a login screen.", Long: `Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.`, Run: func(cmd *cobra.Command, args []string) { - var config types.Config - err := viper.Unmarshal(&config) - HandleError(err, "Failed to parse config") - - // Check if secrets have a file associated with them - config.Secret = utils.GetSecret(config.Secret, config.SecretFile) - config.GithubClientSecret = utils.GetSecret(config.GithubClientSecret, config.GithubClientSecretFile) - config.GoogleClientSecret = utils.GetSecret(config.GoogleClientSecret, config.GoogleClientSecretFile) - config.GenericClientSecret = utils.GetSecret(config.GenericClientSecret, config.GenericClientSecretFile) - - validator := validator.New() - err = validator.Struct(config) - HandleError(err, "Failed to validate config") - - log.Logger = log.Level(zerolog.Level(config.LogLevel)) - log.Info().Str("version", strings.TrimSpace(constants.Version)).Msg("Starting tinyauth") - - log.Info().Msg("Parsing users") - users, err := utils.GetUsers(config.Users, config.UsersFile) - HandleError(err, "Failed to parse users") - - log.Debug().Msg("Getting domain") - domain, err := utils.GetUpperDomain(config.AppURL) - HandleError(err, "Failed to get upper domain") - log.Info().Str("domain", domain).Msg("Using domain for cookie store") + var conf config.Config - cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) - sessionCookieName := fmt.Sprintf("%s-%s", constants.SessionCookieName, cookieId) - csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId) - redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId) - - log.Debug().Msg("Deriving HMAC and encryption secrets") - - hmacSecret, err := utils.DeriveKey(config.Secret, "hmac") - HandleError(err, "Failed to derive HMAC secret") - - encryptionSecret, err := utils.DeriveKey(config.Secret, "encryption") - HandleError(err, "Failed to derive encryption secret") - - // Split the config into service-specific sub-configs - oauthConfig := types.OAuthConfig{ - GithubClientId: config.GithubClientId, - GithubClientSecret: config.GithubClientSecret, - GoogleClientId: config.GoogleClientId, - GoogleClientSecret: config.GoogleClientSecret, - GenericClientId: config.GenericClientId, - GenericClientSecret: config.GenericClientSecret, - GenericScopes: strings.Split(config.GenericScopes, ","), - GenericAuthURL: config.GenericAuthURL, - GenericTokenURL: config.GenericTokenURL, - GenericUserURL: config.GenericUserURL, - GenericSkipSSL: config.GenericSkipSSL, - AppURL: config.AppURL, + err := viper.Unmarshal(&conf) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse config") } - authConfig := types.AuthConfig{ - Users: users, - OauthWhitelist: config.OAuthWhitelist, - CookieSecure: config.CookieSecure, - SessionExpiry: config.SessionExpiry, - Domain: domain, - LoginTimeout: config.LoginTimeout, - LoginMaxRetries: config.LoginMaxRetries, - SessionCookieName: sessionCookieName, - HMACSecret: hmacSecret, - EncryptionSecret: encryptionSecret, - } - - var ldapService *ldap.LDAP + // Check if secrets have a file associated with them + conf.Secret = utils.GetSecret(conf.Secret, conf.SecretFile) + conf.GithubClientSecret = utils.GetSecret(conf.GithubClientSecret, conf.GithubClientSecretFile) + conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) + conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) - if config.LdapAddress != "" { - log.Info().Msg("Using LDAP for authentication") - ldapConfig := types.LdapConfig{ - Address: config.LdapAddress, - BindDN: config.LdapBindDN, - BindPassword: config.LdapBindPassword, - BaseDN: config.LdapBaseDN, - Insecure: config.LdapInsecure, - SearchFilter: config.LdapSearchFilter, - } - ldapService, err = ldap.NewLDAP(ldapConfig) - if err != nil { - log.Error().Err(err).Msg("Failed to initialize LDAP service, disabling LDAP authentication") - ldapService = nil - } - } else { - log.Info().Msg("LDAP not configured, using local users or OAuth") - } + validator := validator.New() - // Check if we have a source of users - if len(users) == 0 && !utils.OAuthConfigured(config) && ldapService == nil { - HandleError(errors.New("err no users"), "Unable to find a source of users") + err = validator.Struct(conf) + if err != nil { + log.Fatal().Err(err).Msg("Invalid config") } - // Setup the services - docker, err := docker.NewDocker() - HandleError(err, "Failed to initialize docker") - auth := auth.NewAuth(authConfig, docker, ldapService) - providers := providers.NewProviders(oauthConfig) + log.Logger = log.Level(zerolog.Level(utils.GetLogLevel(conf.LogLevel))) + log.Info().Str("version", strings.TrimSpace(config.Version)).Msg("Starting tinyauth") - // Create the engine - engine := gin.New() + // Create bootstrap app + app := bootstrap.NewBootstrapApp(conf) - // Create the group - router := engine.Group("/api") + // Run + err = app.Setup() - // Setup the middlewares - var middlewares []Middleware - - contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - Domain: domain, - }, auth, providers) - uiMiddleware := middleware.NewUIMiddleware() - zerologMiddleware := middleware.NewZerologMiddleware() - - middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) - - for _, middleware := range middlewares { - log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") - err := middleware.Init() - HandleError(err, fmt.Sprintf("Failed to initialize middleware %s", middleware.Name())) - router.Use(middleware.Middleware()) + if err != nil { + log.Fatal().Err(err).Msg("Failed to setup app") } - // Create configured providers - var configuredProviders []string - - configuredProviders = append(configuredProviders, providers.GetConfiguredProviders()...) - - if auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Create controllers - contextController := controller.NewContextController(controller.ContextControllerConfig{ - ConfiguredProviders: configuredProviders, - DisableContinue: config.DisableContinue, - Title: config.Title, - GenericName: config.GenericName, - Domain: domain, - ForgotPasswordMessage: config.FogotPasswordMessage, - BackgroundImage: config.BackgroundImage, - OAuthAutoRedirect: config.OAuthAutoRedirect, - }, router) - contextController.SetupRoutes() - - healthController := controller.NewHealthController(router) - healthController.SetupRoutes() - - oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: config.AppURL, - SecureCookie: config.CookieSecure, - CSRFCookieName: csrfCookieName, - RedirectCookieName: redirectCookieName, - }, router, auth, providers) - oauthController.SetupRoutes() - - proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ - AppURL: config.AppURL, - }, router, docker, auth) - proxyController.SetupRoutes() - - userController := controller.NewUserController(controller.UserControllerConfig{ - Domain: domain, - }, router, auth) - userController.SetupRoutes() - - // Run server - engine.Run(fmt.Sprintf("%s:%d", config.Address, config.Port)) }, } func Execute() { err := rootCmd.Execute() - HandleError(err, "Failed to execute root command") -} - -func HandleError(err error, msg string) { if err != nil { - log.Fatal().Err(err).Msg(msg) + log.Fatal().Err(err).Msg("Failed to execute command") } } @@ -224,85 +69,66 @@ func init() { viper.AutomaticEnv() - rootCmd.Flags().Int("port", 3000, "Port to run the server on.") - rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") - rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") - rootCmd.Flags().String("secret-file", "", "Path to a file containing the secret.") - rootCmd.Flags().String("app-url", "", "The tinyauth URL.") - rootCmd.Flags().String("users", "", "Comma separated list of users in the format username:hash.") - rootCmd.Flags().String("users-file", "", "Path to a file containing users in the format username:hash.") - rootCmd.Flags().Bool("cookie-secure", false, "Send cookie over secure connection only.") - rootCmd.Flags().String("github-client-id", "", "Github OAuth client ID.") - rootCmd.Flags().String("github-client-secret", "", "Github OAuth client secret.") - rootCmd.Flags().String("github-client-secret-file", "", "Github OAuth client secret file.") - rootCmd.Flags().String("google-client-id", "", "Google OAuth client ID.") - rootCmd.Flags().String("google-client-secret", "", "Google OAuth client secret.") - rootCmd.Flags().String("google-client-secret-file", "", "Google OAuth client secret file.") - rootCmd.Flags().String("generic-client-id", "", "Generic OAuth client ID.") - rootCmd.Flags().String("generic-client-secret", "", "Generic OAuth client secret.") - rootCmd.Flags().String("generic-client-secret-file", "", "Generic OAuth client secret file.") - rootCmd.Flags().String("generic-scopes", "", "Generic OAuth scopes.") - rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") - rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") - rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") - rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") - rootCmd.Flags().Bool("generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider.") - rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") - rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") - rootCmd.Flags().String("oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)") - rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") - rootCmd.Flags().Int("login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable).") - rootCmd.Flags().Int("login-max-retries", 5, "Maximum login attempts before timeout (0 to disable).") - rootCmd.Flags().Int("log-level", 1, "Log level.") - rootCmd.Flags().String("app-title", "Tinyauth", "Title of the app.") - rootCmd.Flags().String("forgot-password-message", "", "Message to show on the forgot password page.") - rootCmd.Flags().String("background-image", "/background.jpg", "Background image URL for the login page.") - rootCmd.Flags().String("ldap-address", "", "LDAP server address (e.g. ldap://localhost:389).") - rootCmd.Flags().String("ldap-bind-dn", "", "LDAP bind DN (e.g. uid=user,dc=example,dc=com).") - rootCmd.Flags().String("ldap-bind-password", "", "LDAP bind password.") - rootCmd.Flags().String("ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com).") - rootCmd.Flags().Bool("ldap-insecure", false, "Skip certificate verification for the LDAP server.") - rootCmd.Flags().String("ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup.") + configOptions := []struct { + name string + defaultVal any + description string + }{ + {"port", 3000, "Port to run the server on."}, + {"address", "0.0.0.0", "Address to bind the server to."}, + {"secret", "", "Secret to use for the cookie."}, + {"secret-file", "", "Path to a file containing the secret."}, + {"app-url", "", "The Tinyauth URL."}, + {"users", "", "Comma separated list of users in the format username:hash."}, + {"users-file", "", "Path to a file containing users in the format username:hash."}, + {"cookie-secure", false, "Send cookie over secure connection only."}, + {"github-client-id", "", "Github OAuth client ID."}, + {"github-client-secret", "", "Github OAuth client secret."}, + {"github-client-secret-file", "", "Github OAuth client secret file."}, + {"google-client-id", "", "Google OAuth client ID."}, + {"google-client-secret", "", "Google OAuth client secret."}, + {"google-client-secret-file", "", "Google OAuth client secret file."}, + {"generic-client-id", "", "Generic OAuth client ID."}, + {"generic-client-secret", "", "Generic OAuth client secret."}, + {"generic-client-secret-file", "", "Generic OAuth client secret file."}, + {"generic-scopes", "", "Generic OAuth scopes."}, + {"generic-auth-url", "", "Generic OAuth auth URL."}, + {"generic-token-url", "", "Generic OAuth token URL."}, + {"generic-user-url", "", "Generic OAuth user info URL."}, + {"generic-name", "Generic", "Generic OAuth provider name."}, + {"generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider."}, + {"disable-continue", false, "Disable continue screen and redirect to app directly."}, + {"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, + {"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, + {"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, + {"login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable)."}, + {"login-max-retries", 5, "Maximum login attempts before timeout (0 to disable)."}, + {"log-level", "info", "Log level."}, + {"app-title", "Tinyauth", "Title of the app."}, + {"forgot-password-message", "", "Message to show on the forgot password page."}, + {"background-image", "/background.jpg", "Background image URL for the login page."}, + {"ldap-address", "", "LDAP server address (e.g. ldap://localhost:389)."}, + {"ldap-bind-dn", "", "LDAP bind DN (e.g. uid=user,dc=example,dc=com)."}, + {"ldap-bind-password", "", "LDAP bind password."}, + {"ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com)."}, + {"ldap-insecure", false, "Skip certificate verification for the LDAP server."}, + {"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, + } + + for _, opt := range configOptions { + switch v := opt.defaultVal.(type) { + case bool: + rootCmd.Flags().Bool(opt.name, v, opt.description) + case int: + rootCmd.Flags().Int(opt.name, v, opt.description) + case string: + rootCmd.Flags().String(opt.name, v, opt.description) + } - viper.BindEnv("port", "PORT") - viper.BindEnv("address", "ADDRESS") - viper.BindEnv("secret", "SECRET") - viper.BindEnv("secret-file", "SECRET_FILE") - viper.BindEnv("app-url", "APP_URL") - viper.BindEnv("users", "USERS") - viper.BindEnv("users-file", "USERS_FILE") - viper.BindEnv("cookie-secure", "COOKIE_SECURE") - viper.BindEnv("github-client-id", "GITHUB_CLIENT_ID") - viper.BindEnv("github-client-secret", "GITHUB_CLIENT_SECRET") - viper.BindEnv("github-client-secret-file", "GITHUB_CLIENT_SECRET_FILE") - viper.BindEnv("google-client-id", "GOOGLE_CLIENT_ID") - viper.BindEnv("google-client-secret", "GOOGLE_CLIENT_SECRET") - viper.BindEnv("google-client-secret-file", "GOOGLE_CLIENT_SECRET_FILE") - viper.BindEnv("generic-client-id", "GENERIC_CLIENT_ID") - viper.BindEnv("generic-client-secret", "GENERIC_CLIENT_SECRET") - viper.BindEnv("generic-client-secret-file", "GENERIC_CLIENT_SECRET_FILE") - viper.BindEnv("generic-scopes", "GENERIC_SCOPES") - viper.BindEnv("generic-auth-url", "GENERIC_AUTH_URL") - viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") - viper.BindEnv("generic-user-url", "GENERIC_USER_URL") - viper.BindEnv("generic-name", "GENERIC_NAME") - viper.BindEnv("generic-skip-ssl", "GENERIC_SKIP_SSL") - viper.BindEnv("disable-continue", "DISABLE_CONTINUE") - viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") - viper.BindEnv("oauth-auto-redirect", "OAUTH_AUTO_REDIRECT") - viper.BindEnv("session-expiry", "SESSION_EXPIRY") - viper.BindEnv("log-level", "LOG_LEVEL") - viper.BindEnv("app-title", "APP_TITLE") - viper.BindEnv("login-timeout", "LOGIN_TIMEOUT") - viper.BindEnv("login-max-retries", "LOGIN_MAX_RETRIES") - viper.BindEnv("forgot-password-message", "FORGOT_PASSWORD_MESSAGE") - viper.BindEnv("background-image", "BACKGROUND_IMAGE") - viper.BindEnv("ldap-address", "LDAP_ADDRESS") - viper.BindEnv("ldap-bind-dn", "LDAP_BIND_DN") - viper.BindEnv("ldap-bind-password", "LDAP_BIND_PASSWORD") - viper.BindEnv("ldap-base-dn", "LDAP_BASE_DN") - viper.BindEnv("ldap-insecure", "LDAP_INSECURE") - viper.BindEnv("ldap-search-filter", "LDAP_SEARCH_FILTER") + // Create uppercase env var name + envVar := strings.ReplaceAll(strings.ToUpper(opt.name), "-", "_") + viper.BindEnv(opt.name, envVar) + } viper.BindPFlags(rootCmd.Flags()) } diff --git a/cmd/version.go b/cmd/version.go index ffbd6fce..2a1827b7 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -2,7 +2,7 @@ package cmd import ( "fmt" - "tinyauth/internal/constants" + "tinyauth/internal/config" "github.com/spf13/cobra" ) @@ -12,9 +12,9 @@ var versionCmd = &cobra.Command{ Short: "Print the version number of Tinyauth", Long: `All software has versions. This is Tinyauth's`, Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("Version: %s\n", constants.Version) - fmt.Printf("Commit Hash: %s\n", constants.CommitHash) - fmt.Printf("Build Timestamp: %s\n", constants.BuildTimestamp) + fmt.Printf("Version: %s\n", config.Version) + fmt.Printf("Commit Hash: %s\n", config.CommitHash) + fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp) }, } diff --git a/internal/bootstrap/bootstrap_app.go b/internal/bootstrap/bootstrap_app.go new file mode 100644 index 00000000..2dfc61df --- /dev/null +++ b/internal/bootstrap/bootstrap_app.go @@ -0,0 +1,246 @@ +package bootstrap + +import ( + "fmt" + "strings" + "tinyauth/internal/config" + "tinyauth/internal/controller" + "tinyauth/internal/middleware" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +type Controller interface { + SetupRoutes() +} + +type Middleware interface { + Middleware() gin.HandlerFunc + Init() error + Name() string +} + +type Service interface { + Init() error +} + +type BootstrapApp struct { + Config config.Config +} + +func NewBootstrapApp(config config.Config) *BootstrapApp { + return &BootstrapApp{ + Config: config, + } +} + +func (app *BootstrapApp) Setup() error { + // Parse users + users, err := utils.GetUsers(app.Config.Users, app.Config.UsersFile) + + if err != nil { + return err + } + + // Get domain + domain, err := utils.GetUpperDomain(app.Config.AppURL) + + if err != nil { + return err + } + + // Cookie names + cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) + sessionCookieName := fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) + csrfCookieName := fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) + redirectCookieName := fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) + + // Secrets + encryptionSecret, err := utils.DeriveKey(app.Config.Secret, "encryption") + + if err != nil { + return err + } + + hmacSecret, err := utils.DeriveKey(app.Config.Secret, "hmac") + + if err != nil { + return err + } + + // Create configs + authConfig := service.AuthServiceConfig{ + Users: users, + OauthWhitelist: app.Config.OAuthWhitelist, + SessionExpiry: app.Config.SessionExpiry, + SecureCookie: app.Config.SecureCookie, + Domain: domain, + LoginTimeout: app.Config.LoginTimeout, + LoginMaxRetries: app.Config.LoginMaxRetries, + SessionCookieName: sessionCookieName, + HMACSecret: hmacSecret, + EncryptionSecret: encryptionSecret, + } + + // Setup services + var ldapService *service.LdapService + + if app.Config.LdapAddress != "" { + ldapConfig := service.LdapServiceConfig{ + Address: app.Config.LdapAddress, + BindDN: app.Config.LdapBindDN, + BindPassword: app.Config.LdapBindPassword, + BaseDN: app.Config.LdapBaseDN, + Insecure: app.Config.LdapInsecure, + SearchFilter: app.Config.LdapSearchFilter, + } + + ldapService = service.NewLdapService(ldapConfig) + + err := ldapService.Init() + + if err != nil { + ldapService = nil + } + } + + dockerService := service.NewDockerService() + authService := service.NewAuthService(authConfig, dockerService, ldapService) + oauthBrokerService := service.NewOAuthBrokerService(app.getOAuthBrokerConfig()) + + // Initialize services + services := []Service{ + dockerService, + authService, + oauthBrokerService, + } + + for _, svc := range services { + if svc != nil { + err := svc.Init() + if err != nil { + return err + } + } + } + + // Configured providers + var configuredProviders []string + + if authService.UserAuthConfigured() || ldapService != nil { + configuredProviders = append(configuredProviders, "username") + } + + configuredProviders = append(configuredProviders, oauthBrokerService.GetConfiguredServices()...) + + if len(configuredProviders) == 0 { + return fmt.Errorf("no authentication providers configured") + } + + // Create engine + engine := gin.New() + router := engine.Group("/api") + + // Create middlewares + var middlewares []Middleware + + contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ + Domain: domain, + }, authService, oauthBrokerService) + + uiMiddleware := middleware.NewUIMiddleware() + zerologMiddleware := middleware.NewZerologMiddleware() + + middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) + + for _, middleware := range middlewares { + log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") + err := middleware.Init() + if err != nil { + return fmt.Errorf("failed to initialize %s middleware: %w", middleware.Name(), err) + } + router.Use(middleware.Middleware()) + } + + // Create controllers + contextController := controller.NewContextController(controller.ContextControllerConfig{ + ConfiguredProviders: configuredProviders, + DisableContinue: app.Config.DisableContinue, + Title: app.Config.Title, + GenericName: app.Config.GenericName, + Domain: domain, + ForgotPasswordMessage: app.Config.FogotPasswordMessage, + BackgroundImage: app.Config.BackgroundImage, + OAuthAutoRedirect: app.Config.OAuthAutoRedirect, + }, router) + + oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ + AppURL: app.Config.AppURL, + SecureCookie: app.Config.SecureCookie, + CSRFCookieName: csrfCookieName, + RedirectCookieName: redirectCookieName, + }, router, authService, oauthBrokerService) + + proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ + AppURL: app.Config.AppURL, + }, router, dockerService, authService) + + userController := controller.NewUserController(controller.UserControllerConfig{ + Domain: domain, + }, router, authService) + + healthController := controller.NewHealthController(router) + + // Setup routes + controller := []Controller{ + contextController, + oauthController, + proxyController, + userController, + healthController, + } + + for _, ctrl := range controller { + log.Debug().Msgf("Setting up %T routes", ctrl) + ctrl.SetupRoutes() + } + + // Start server + address := fmt.Sprintf("%s:%d", app.Config.Address, app.Config.Port) + log.Info().Msgf("Starting server on %s", address) + if err := engine.Run(address); err != nil { + log.Fatal().Err(err).Msg("Failed to start server") + } + + return nil +} + +// Temporary +func (app *BootstrapApp) getOAuthBrokerConfig() map[string]config.OAuthServiceConfig { + return map[string]config.OAuthServiceConfig{ + "google": { + ClientID: app.Config.GoogleClientId, + ClientSecret: app.Config.GoogleClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", app.Config.AppURL), + }, + "github": { + ClientID: app.Config.GithubClientId, + ClientSecret: app.Config.GithubClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", app.Config.AppURL), + }, + "generic": { + ClientID: app.Config.GenericClientId, + ClientSecret: app.Config.GenericClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", app.Config.AppURL), + Scopes: strings.Split(app.Config.GenericScopes, ","), + AuthURL: app.Config.GenericAuthURL, + TokenURL: app.Config.GenericTokenURL, + UserinfoURL: app.Config.GenericUserURL, + InsecureSkipVerify: app.Config.GenericSkipSSL, + }, + } + +} diff --git a/internal/config/config.go b/internal/config/config.go index 5584d0e8..655b61a6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,7 +12,7 @@ var CommitHash = "n/a" var BuildTimestamp = "n/a" var SessionCookieName = "tinyauth-session" -var CsrfCookieName = "tinyauth-csrf" +var CSRFCookieName = "tinyauth-csrf" var RedirectCookieName = "tinyauth-redirect" type Config struct { @@ -23,7 +23,7 @@ type Config struct { AppURL string `validate:"required,url" mapstructure:"app-url"` Users string `mapstructure:"users"` UsersFile string `mapstructure:"users-file"` - CookieSecure bool `mapstructure:"cookie-secure"` + SecureCookie bool `mapstructure:"secure-cookie"` GithubClientId string `mapstructure:"github-client-id"` GithubClientSecret string `mapstructure:"github-client-secret"` GithubClientSecretFile string `mapstructure:"github-client-secret-file"` @@ -43,9 +43,8 @@ type Config struct { OAuthWhitelist string `mapstructure:"oauth-whitelist"` OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` SessionExpiry int `mapstructure:"session-expiry"` - LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"` + LogLevel string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` Title string `mapstructure:"app-title"` - EnvFile string `mapstructure:"env-file"` LoginTimeout int `mapstructure:"login-timeout"` LoginMaxRetries int `mapstructure:"login-max-retries"` FogotPasswordMessage string `mapstructure:"forgot-password-message"` diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 46bad066..e91f98a1 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -25,7 +25,7 @@ type AuthServiceConfig struct { Users []config.User OauthWhitelist string SessionExpiry int - CookieSecure bool + SecureCookie bool Domain string LoginTimeout int LoginMaxRetries int @@ -57,10 +57,11 @@ func (auth *AuthService) Init() error { store.Options = &sessions.Options{ Path: "/", MaxAge: auth.Config.SessionExpiry, - Secure: auth.Config.CookieSecure, + Secure: auth.Config.SecureCookie, HttpOnly: true, Domain: fmt.Sprintf(".%s", auth.Config.Domain), } + auth.Store = store return nil } @@ -70,7 +71,7 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying") - c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) + c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) if err != nil { log.Error().Err(err).Msg("Failed to get session") diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index a8c13345..2f9e27f8 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -11,6 +11,7 @@ import ( "tinyauth/internal/config" "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" ) var GithubOAuthScopes = []string{"user:email", "read:user"} @@ -39,6 +40,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, Scopes: GithubOAuthScopes, + Endpoint: endpoints.GitHub, }, } } diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 6d9eaed8..776aeca7 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -11,6 +11,7 @@ import ( "tinyauth/internal/config" "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" ) var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} @@ -34,6 +35,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, Scopes: GoogleOAuthScopes, + Endpoint: endpoints.Google, }, } } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 67b904f6..7181a26e 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -18,6 +18,7 @@ import ( "golang.org/x/crypto/hkdf" "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -360,3 +361,22 @@ func GetContext(c *gin.Context) (config.UserContext, error) { return *userContext, nil } + +func GetLogLevel(level string) zerolog.Level { + switch strings.ToLower(level) { + case "debug": + return zerolog.DebugLevel + case "info": + return zerolog.InfoLevel + case "warn": + return zerolog.WarnLevel + case "error": + return zerolog.ErrorLevel + case "fatal": + return zerolog.FatalLevel + case "panic": + return zerolog.PanicLevel + default: + return zerolog.InfoLevel + } +} From 04213836a1cbb85c0de088941694c553a63e4cca Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:13:53 +0300 Subject: [PATCH 07/17] refactor: split utils into smaller files --- internal/utils/fs_utils.go | 17 ++ internal/utils/header_utils.go | 29 +++ internal/utils/other_utils.go | 95 ++++++++ internal/utils/sec_utils.go | 124 +++++++++++ internal/utils/string_utils.go | 30 +++ internal/utils/user_utils.go | 82 +++++++ internal/utils/utils.go | 382 --------------------------------- 7 files changed, 377 insertions(+), 382 deletions(-) create mode 100644 internal/utils/fs_utils.go create mode 100644 internal/utils/header_utils.go create mode 100644 internal/utils/other_utils.go create mode 100644 internal/utils/sec_utils.go create mode 100644 internal/utils/string_utils.go create mode 100644 internal/utils/user_utils.go delete mode 100644 internal/utils/utils.go diff --git a/internal/utils/fs_utils.go b/internal/utils/fs_utils.go new file mode 100644 index 00000000..8b9f28bf --- /dev/null +++ b/internal/utils/fs_utils.go @@ -0,0 +1,17 @@ +package utils + +import "os" + +func ReadFile(file string) (string, error) { + _, err := os.Stat(file) + if err != nil { + return "", err + } + + data, err := os.ReadFile(file) + if err != nil { + return "", err + } + + return string(data), nil +} diff --git a/internal/utils/header_utils.go b/internal/utils/header_utils.go new file mode 100644 index 00000000..1192de56 --- /dev/null +++ b/internal/utils/header_utils.go @@ -0,0 +1,29 @@ +package utils + +import ( + "strings" +) + +func ParseHeaders(headers []string) map[string]string { + headerMap := make(map[string]string) + for _, header := range headers { + split := strings.SplitN(header, "=", 2) + if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { + continue + } + key := SanitizeHeader(strings.TrimSpace(split[0])) + value := SanitizeHeader(strings.TrimSpace(split[1])) + headerMap[key] = value + } + return headerMap +} + +func SanitizeHeader(header string) string { + return strings.Map(func(r rune) rune { + // Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) + if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { + return r + } + return -1 + }, header) +} diff --git a/internal/utils/other_utils.go b/internal/utils/other_utils.go new file mode 100644 index 00000000..17167254 --- /dev/null +++ b/internal/utils/other_utils.go @@ -0,0 +1,95 @@ +package utils + +import ( + "errors" + "net/url" + "strings" + "tinyauth/internal/config" + + "github.com/gin-gonic/gin" + "github.com/traefik/paerser/parser" + + "github.com/rs/zerolog" +) + +// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) +func GetUpperDomain(urlSrc string) (string, error) { + urlParsed, err := url.Parse(urlSrc) + if err != nil { + return "", err + } + + urlSplitted := strings.Split(urlParsed.Hostname(), ".") + urlFinal := strings.Join(urlSplitted[1:], ".") + + return urlFinal, nil +} + +func ParseFileToLine(content string) string { + lines := strings.Split(content, "\n") + users := make([]string, 0) + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + users = append(users, strings.TrimSpace(line)) + } + + return strings.Join(users, ",") +} + +func GetLabels(labels map[string]string) (config.Labels, error) { + var labelsParsed config.Labels + + err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") + if err != nil { + return config.Labels{}, err + } + + return labelsParsed, nil +} + +func Filter[T any](slice []T, test func(T) bool) (res []T) { + for _, value := range slice { + if test(value) { + res = append(res, value) + } + } + return res +} + +func GetContext(c *gin.Context) (config.UserContext, error) { + userContextValue, exists := c.Get("context") + + if !exists { + return config.UserContext{}, errors.New("no user context in request") + } + + userContext, ok := userContextValue.(*config.UserContext) + + if !ok { + return config.UserContext{}, errors.New("invalid user context in request") + } + + return *userContext, nil +} + +func GetLogLevel(level string) zerolog.Level { + switch strings.ToLower(level) { + case "debug": + return zerolog.DebugLevel + case "info": + return zerolog.InfoLevel + case "warn": + return zerolog.WarnLevel + case "error": + return zerolog.ErrorLevel + case "fatal": + return zerolog.FatalLevel + case "panic": + return zerolog.PanicLevel + default: + return zerolog.InfoLevel + } +} diff --git a/internal/utils/sec_utils.go b/internal/utils/sec_utils.go new file mode 100644 index 00000000..4e9e1874 --- /dev/null +++ b/internal/utils/sec_utils.go @@ -0,0 +1,124 @@ +package utils + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + "net" + "regexp" + "strings" + + "github.com/google/uuid" + "golang.org/x/crypto/hkdf" +) + +func GetSecret(conf string, file string) string { + if conf == "" && file == "" { + return "" + } + + if conf != "" { + return conf + } + + contents, err := ReadFile(file) + if err != nil { + return "" + } + + return ParseSecretFile(contents) +} + +func ParseSecretFile(contents string) string { + lines := strings.Split(contents, "\n") + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + return strings.TrimSpace(line) + } + + return "" +} + +func GetBasicAuth(username string, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func DeriveKey(secret string, info string) (string, error) { + hash := sha256.New + hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice + key := make([]byte, 24) + + _, err := io.ReadFull(hkdf, key) + if err != nil { + return "", err + } + + if bytes.Equal(key, make([]byte, 24)) { + return "", errors.New("derived key is empty") + } + + encodedKey := base64.StdEncoding.EncodeToString(key) + return encodedKey, nil +} + +func FilterIP(filter string, ip string) (bool, error) { + ipAddr := net.ParseIP(ip) + + if strings.Contains(filter, "/") { + _, cidr, err := net.ParseCIDR(filter) + if err != nil { + return false, err + } + return cidr.Contains(ipAddr), nil + } + + ipFilter := net.ParseIP(filter) + if ipFilter == nil { + return false, errors.New("invalid IP address in filter") + } + + if ipFilter.Equal(ipAddr) { + return true, nil + } + + return false, nil +} + +func CheckFilter(filter string, str string) bool { + if len(strings.TrimSpace(filter)) == 0 { + return true + } + + if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { + re, err := regexp.Compile(filter[1 : len(filter)-1]) + if err != nil { + return false + } + + if re.MatchString(str) { + return true + } + } + + filterSplit := strings.Split(filter, ",") + + for _, item := range filterSplit { + if strings.TrimSpace(item) == str { + return true + } + } + + return false +} + +func GenerateIdentifier(str string) string { + uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) + uuidString := uuid.String() + return strings.Split(uuidString, "-")[0] +} diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go new file mode 100644 index 00000000..8a629adc --- /dev/null +++ b/internal/utils/string_utils.go @@ -0,0 +1,30 @@ +package utils + +import ( + "strings" +) + +func Capitalize(str string) string { + if len(str) == 0 { + return "" + } + return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) +} + +func CoalesceToString(value any) string { + switch v := value.(type) { + case []any: + strs := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + strs = append(strs, str) + continue + } + } + return strings.Join(strs, ",") + case string: + return v + default: + return "" + } +} diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go new file mode 100644 index 00000000..bfcec495 --- /dev/null +++ b/internal/utils/user_utils.go @@ -0,0 +1,82 @@ +package utils + +import ( + "errors" + "strings" + "tinyauth/internal/config" +) + +func ParseUsers(users string) ([]config.User, error) { + var usersParsed []config.User + + userList := strings.Split(users, ",") + + if len(userList) == 0 { + return []config.User{}, errors.New("invalid user format") + } + + for _, user := range userList { + parsed, err := ParseUser(user) + if err != nil { + return []config.User{}, err + } + usersParsed = append(usersParsed, parsed) + } + + return usersParsed, nil +} + +func GetUsers(conf string, file string) ([]config.User, error) { + var users string + + if conf == "" && file == "" { + return []config.User{}, nil + } + + if conf != "" { + users += conf + } + + if file != "" { + contents, err := ReadFile(file) + if err == nil { + if users != "" { + users += "," + } + users += ParseFileToLine(contents) + } + } + + return ParseUsers(users) +} + +func ParseUser(user string) (config.User, error) { + if strings.Contains(user, "$$") { + user = strings.ReplaceAll(user, "$$", "$") + } + + userSplit := strings.Split(user, ":") + + if len(userSplit) < 2 || len(userSplit) > 3 { + return config.User{}, errors.New("invalid user format") + } + + for _, userPart := range userSplit { + if strings.TrimSpace(userPart) == "" { + return config.User{}, errors.New("invalid user format") + } + } + + if len(userSplit) == 2 { + return config.User{ + Username: strings.TrimSpace(userSplit[0]), + Password: strings.TrimSpace(userSplit[1]), + }, nil + } + + return config.User{ + Username: strings.TrimSpace(userSplit[0]), + Password: strings.TrimSpace(userSplit[1]), + TotpSecret: strings.TrimSpace(userSplit[2]), + }, nil +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go deleted file mode 100644 index 7181a26e..00000000 --- a/internal/utils/utils.go +++ /dev/null @@ -1,382 +0,0 @@ -package utils - -import ( - "bytes" - "crypto/sha256" - "encoding/base64" - "errors" - "io" - "net" - "net/url" - "os" - "regexp" - "strings" - "tinyauth/internal/config" - - "github.com/gin-gonic/gin" - "github.com/traefik/paerser/parser" - "golang.org/x/crypto/hkdf" - - "github.com/google/uuid" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -// Parses a list of comma separated users in a struct -func ParseUsers(users string) ([]config.User, error) { - log.Debug().Msg("Parsing users") - - var usersParsed []config.User - - userList := strings.Split(users, ",") - - if len(userList) == 0 { - return []config.User{}, errors.New("invalid user format") - } - - for _, user := range userList { - parsed, err := ParseUser(user) - if err != nil { - return []config.User{}, err - } - usersParsed = append(usersParsed, parsed) - } - - log.Debug().Msg("Parsed users") - return usersParsed, nil -} - -// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetUpperDomain(urlSrc string) (string, error) { - urlParsed, err := url.Parse(urlSrc) - if err != nil { - return "", err - } - - urlSplitted := strings.Split(urlParsed.Hostname(), ".") - urlFinal := strings.Join(urlSplitted[1:], ".") - - return urlFinal, nil -} - -// Reads a file and returns the contents -func ReadFile(file string) (string, error) { - _, err := os.Stat(file) - if err != nil { - return "", err - } - - data, err := os.ReadFile(file) - if err != nil { - return "", err - } - - return string(data), nil -} - -// Parses a file into a comma separated list of users -func ParseFileToLine(content string) string { - lines := strings.Split(content, "\n") - users := make([]string, 0) - - for _, line := range lines { - if strings.TrimSpace(line) == "" { - continue - } - users = append(users, strings.TrimSpace(line)) - } - - return strings.Join(users, ",") -} - -// Get the secret from the config or file -func GetSecret(conf string, file string) string { - if conf == "" && file == "" { - return "" - } - - if conf != "" { - return conf - } - - contents, err := ReadFile(file) - if err != nil { - return "" - } - - return ParseSecretFile(contents) -} - -// Get the users from the config or file -func GetUsers(conf string, file string) ([]config.User, error) { - var users string - - if conf == "" && file == "" { - return []config.User{}, nil - } - - if conf != "" { - log.Debug().Msg("Using users from config") - users += conf - } - - if file != "" { - contents, err := ReadFile(file) - if err == nil { - log.Debug().Msg("Using users from file") - if users != "" { - users += "," - } - users += ParseFileToLine(contents) - } - } - - return ParseUsers(users) -} - -// Parse the headers in a map[string]string format -func ParseHeaders(headers []string) map[string]string { - headerMap := make(map[string]string) - - for _, header := range headers { - split := strings.SplitN(header, "=", 2) - if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { - log.Warn().Str("header", header).Msg("Invalid header format, skipping") - continue - } - key := SanitizeHeader(strings.TrimSpace(split[0])) - value := SanitizeHeader(strings.TrimSpace(split[1])) - headerMap[key] = value - } - - return headerMap -} - -// Get labels parses a map of labels into a struct with only the needed labels -func GetLabels(labels map[string]string) (config.Labels, error) { - var labelsParsed config.Labels - - err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") - if err != nil { - log.Error().Err(err).Msg("Error parsing labels") - return config.Labels{}, err - } - - return labelsParsed, nil -} - -// Filter helper function -func Filter[T any](slice []T, test func(T) bool) (res []T) { - for _, value := range slice { - if test(value) { - res = append(res, value) - } - } - return res -} - -// Parse user -func ParseUser(user string) (config.User, error) { - if strings.Contains(user, "$$") { - user = strings.ReplaceAll(user, "$$", "$") - } - - userSplit := strings.Split(user, ":") - - if len(userSplit) < 2 || len(userSplit) > 3 { - return config.User{}, errors.New("invalid user format") - } - - for _, userPart := range userSplit { - if strings.TrimSpace(userPart) == "" { - return config.User{}, errors.New("invalid user format") - } - } - - if len(userSplit) == 2 { - return config.User{ - Username: strings.TrimSpace(userSplit[0]), - Password: strings.TrimSpace(userSplit[1]), - }, nil - } - - return config.User{ - Username: strings.TrimSpace(userSplit[0]), - Password: strings.TrimSpace(userSplit[1]), - TotpSecret: strings.TrimSpace(userSplit[2]), - }, nil -} - -// Parse secret file -func ParseSecretFile(contents string) string { - lines := strings.Split(contents, "\n") - - for _, line := range lines { - if strings.TrimSpace(line) == "" { - continue - } - return strings.TrimSpace(line) - } - - return "" -} - -// Check if a string matches a regex or if it is included in a comma separated list -func CheckFilter(filter string, str string) bool { - if len(strings.TrimSpace(filter)) == 0 { - return true - } - - if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { - re, err := regexp.Compile(filter[1 : len(filter)-1]) - if err != nil { - log.Error().Err(err).Msg("Error compiling regex") - return false - } - - if re.MatchString(str) { - return true - } - } - - filterSplit := strings.Split(filter, ",") - - for _, item := range filterSplit { - if strings.TrimSpace(item) == str { - return true - } - } - - return false -} - -// Capitalize just the first letter of a string -func Capitalize(str string) string { - if len(str) == 0 { - return "" - } - return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) -} - -// Sanitize header removes all control characters from a string -func SanitizeHeader(header string) string { - return strings.Map(func(r rune) rune { - // Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) - if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { - return r - } - return -1 - }, header) -} - -// Generate a static identifier from a string -func GenerateIdentifier(str string) string { - uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) - uuidString := uuid.String() - log.Debug().Str("uuid", uuidString).Msg("Generated UUID") - return strings.Split(uuidString, "-")[0] -} - -// Get a basic auth header from a username and password -func GetBasicAuth(username string, password string) string { - auth := username + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) -} - -// Check if an IP is contained in a CIDR range/matches a single IP -func FilterIP(filter string, ip string) (bool, error) { - ipAddr := net.ParseIP(ip) - - if strings.Contains(filter, "/") { - _, cidr, err := net.ParseCIDR(filter) - if err != nil { - return false, err - } - return cidr.Contains(ipAddr), nil - } - - ipFilter := net.ParseIP(filter) - if ipFilter == nil { - return false, errors.New("invalid IP address in filter") - } - - if ipFilter.Equal(ipAddr) { - return true, nil - } - - return false, nil -} - -func DeriveKey(secret string, info string) (string, error) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice - key := make([]byte, 24) - - _, err := io.ReadFull(hkdf, key) - if err != nil { - return "", err - } - - if bytes.Equal(key, make([]byte, 24)) { - return "", errors.New("derived key is empty") - } - - encodedKey := base64.StdEncoding.EncodeToString(key) - return encodedKey, nil -} - -func CoalesceToString(value any) string { - switch v := value.(type) { - case []any: - log.Debug().Msg("Coalescing []any to string") - strs := make([]string, 0, len(v)) - for _, item := range v { - if str, ok := item.(string); ok { - strs = append(strs, str) - continue - } - log.Warn().Interface("item", item).Msg("Item in []any is not a string, skipping") - } - return strings.Join(strs, ",") - case string: - return v - default: - log.Warn().Interface("value", value).Interface("type", v).Msg("Unsupported type, returning empty string") - return "" - } -} - -func GetContext(c *gin.Context) (config.UserContext, error) { - userContextValue, exists := c.Get("context") - - if !exists { - return config.UserContext{}, errors.New("no user context in request") - } - - userContext, ok := userContextValue.(*config.UserContext) - - if !ok { - return config.UserContext{}, errors.New("invalid user context in request") - } - - return *userContext, nil -} - -func GetLogLevel(level string) zerolog.Level { - switch strings.ToLower(level) { - case "debug": - return zerolog.DebugLevel - case "info": - return zerolog.InfoLevel - case "warn": - return zerolog.WarnLevel - case "error": - return zerolog.ErrorLevel - case "fatal": - return zerolog.FatalLevel - case "panic": - return zerolog.PanicLevel - default: - return zerolog.InfoLevel - } -} From 6418cbe2bad3425382a03b1a11e741ec5d7ab2be Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:16:26 +0300 Subject: [PATCH 08/17] refactor: use more clear name for frontend assets --- internal/assets/assets.go | 4 ++-- internal/bootstrap/{bootstrap_app.go => app_bootstrap.go} | 0 internal/middleware/ui_middlware.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename internal/bootstrap/{bootstrap_app.go => app_bootstrap.go} (100%) diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 69188674..0b572e0c 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -4,7 +4,7 @@ import ( "embed" ) -// UI assets +// Frontend assets // //go:embed dist -var Assets embed.FS +var FontendAssets embed.FS diff --git a/internal/bootstrap/bootstrap_app.go b/internal/bootstrap/app_bootstrap.go similarity index 100% rename from internal/bootstrap/bootstrap_app.go rename to internal/bootstrap/app_bootstrap.go diff --git a/internal/middleware/ui_middlware.go b/internal/middleware/ui_middlware.go index 22f8ca23..b0fabded 100644 --- a/internal/middleware/ui_middlware.go +++ b/internal/middleware/ui_middlware.go @@ -21,7 +21,7 @@ func NewUIMiddleware() *UIMiddleware { } func (m *UIMiddleware) Init() error { - ui, err := fs.Sub(assets.Assets, "dist") + ui, err := fs.Sub(assets.FontendAssets, "dist") if err != nil { return nil From 304c920b7b0006e8ab3613ff7361534cfc40b6e9 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:20:00 +0300 Subject: [PATCH 09/17] feat: allow customizability of resources dir --- cmd/root.go | 1 + internal/bootstrap/app_bootstrap.go | 4 +++- internal/config/config.go | 1 + internal/middleware/ui_middlware.go | 15 +++++++++++---- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 898c27f4..8e0245d4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -113,6 +113,7 @@ func init() { {"ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com)."}, {"ldap-insecure", false, "Skip certificate verification for the LDAP server."}, {"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, + {"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."}, } for _, opt := range configOptions { diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 2dfc61df..f452f25a 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -151,7 +151,9 @@ func (app *BootstrapApp) Setup() error { Domain: domain, }, authService, oauthBrokerService) - uiMiddleware := middleware.NewUIMiddleware() + uiMiddleware := middleware.NewUIMiddleware(middleware.UIMiddlewareConfig{ + ResourcesDir: app.Config.ResourcesDir, + }) zerologMiddleware := middleware.NewZerologMiddleware() middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) diff --git a/internal/config/config.go b/internal/config/config.go index 655b61a6..48961d63 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -55,6 +55,7 @@ type Config struct { LdapBaseDN string `mapstructure:"ldap-base-dn"` LdapInsecure bool `mapstructure:"ldap-insecure"` LdapSearchFilter string `mapstructure:"ldap-search-filter"` + ResourcesDir string `mapstructure:"resources-dir"` } type OAuthLabels struct { diff --git a/internal/middleware/ui_middlware.go b/internal/middleware/ui_middlware.go index b0fabded..cd886b4d 100644 --- a/internal/middleware/ui_middlware.go +++ b/internal/middleware/ui_middlware.go @@ -10,14 +10,21 @@ import ( "github.com/gin-gonic/gin" ) +type UIMiddlewareConfig struct { + ResourcesDir string +} + type UIMiddleware struct { + Config UIMiddlewareConfig UIFS fs.FS UIFileServer http.Handler ResourcesFileServer http.Handler } -func NewUIMiddleware() *UIMiddleware { - return &UIMiddleware{} +func NewUIMiddleware(config UIMiddlewareConfig) *UIMiddleware { + return &UIMiddleware{ + Config: config, + } } func (m *UIMiddleware) Init() error { @@ -29,7 +36,7 @@ func (m *UIMiddleware) Init() error { m.UIFS = ui m.UIFileServer = http.FileServer(http.FS(ui)) - m.ResourcesFileServer = http.FileServer(http.Dir("/data/resources")) + m.ResourcesFileServer = http.FileServer(http.Dir(m.Config.ResourcesDir)) return nil } @@ -45,7 +52,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { c.Next() return case "resources": - _, err := os.Stat("/data/resources/" + strings.TrimPrefix(c.Request.URL.Path, "/resources/")) + _, err := os.Stat(m.Config.ResourcesDir + strings.TrimPrefix(c.Request.URL.Path, "/resources/")) if os.IsNotExist(err) { c.Status(404) From cb8022af9191791414fd4f32b77009fc789db304 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:25:17 +0300 Subject: [PATCH 10/17] fix: fix typo in ui middleware --- internal/middleware/{ui_middlware.go => ui_middleware.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename internal/middleware/{ui_middlware.go => ui_middleware.go} (100%) diff --git a/internal/middleware/ui_middlware.go b/internal/middleware/ui_middleware.go similarity index 100% rename from internal/middleware/ui_middlware.go rename to internal/middleware/ui_middleware.go From 03af18fd153f4e6f7db8e8684e656b1841808ffa Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:31:57 +0300 Subject: [PATCH 11/17] fix: validate resource file paths in ui middleware --- internal/middleware/ui_middleware.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index cd886b4d..0ce139be 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -4,6 +4,7 @@ import ( "io/fs" "net/http" "os" + "path/filepath" "strings" "tinyauth/internal/assets" @@ -52,7 +53,15 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { c.Next() return case "resources": - _, err := os.Stat(m.Config.ResourcesDir + strings.TrimPrefix(c.Request.URL.Path, "/resources/")) + requestFilePath := m.Config.ResourcesDir + strings.TrimPrefix(c.Request.URL.Path, "/resources/") + + if !filepath.IsLocal(requestFilePath) { + c.Status(404) + c.Abort() + return + } + + _, err := os.Stat(requestFilePath) if os.IsNotExist(err) { c.Status(404) From 645c555cf036d1a86cedca3c55bba5fdba84a1d0 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 12:22:10 +0300 Subject: [PATCH 12/17] refactor: move resource handling to a controller --- .gitignore | 5 ++- cmd/root.go | 1 + frontend/vite.config.ts | 5 +++ internal/bootstrap/app_bootstrap.go | 36 ++++++++++------ internal/controller/health_controller.go | 1 + internal/controller/resources_controller.go | 32 +++++++++++++++ internal/middleware/context_middleware.go | 4 -- internal/middleware/ui_middleware.go | 41 +++---------------- internal/middleware/zerolog_middleware.go | 4 -- .../utils/{other_utils.go => app_utils.go} | 12 ------ .../utils/{header_utils.go => label_utils.go} | 14 +++++++ .../utils/{sec_utils.go => security_utils.go} | 0 main.go | 2 +- 13 files changed, 86 insertions(+), 71 deletions(-) create mode 100644 internal/controller/resources_controller.go rename internal/utils/{other_utils.go => app_utils.go} (81%) rename internal/utils/{header_utils.go => label_utils.go} (63%) rename internal/utils/{sec_utils.go => security_utils.go} (100%) diff --git a/.gitignore b/.gitignore index 0100a134..cb79b93b 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,7 @@ secret* tmp # version files -internal/assets/version \ No newline at end of file +internal/assets/version + +# data directory +data \ No newline at end of file diff --git a/cmd/root.go b/cmd/root.go index 8e0245d4..2b0c172e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -33,6 +33,7 @@ var rootCmd = &cobra.Command{ conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) + // Validate config validator := validator.New() err = validator.Struct(conf) diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 07e6e7e6..f391a49d 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -19,6 +19,11 @@ export default defineConfig({ changeOrigin: true, rewrite: (path) => path.replace(/^\/api/, ""), }, + "/resources": { + target: "http://tinyauth-backend:3000/resources", + changeOrigin: true, + rewrite: (path) => path.replace(/^\/resources/, ""), + }, }, allowedHosts: true, }, diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index f452f25a..54838d1e 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -20,7 +20,6 @@ type Controller interface { type Middleware interface { Middleware() gin.HandlerFunc Init() error - Name() string } type Service interface { @@ -103,6 +102,7 @@ func (app *BootstrapApp) Setup() error { err := ldapService.Init() if err != nil { + log.Warn().Err(err).Msg("Failed to initialize LDAP service, continuing without LDAP") ldapService = nil } } @@ -120,6 +120,7 @@ func (app *BootstrapApp) Setup() error { for _, svc := range services { if svc != nil { + log.Debug().Str("service", fmt.Sprintf("%T", svc)).Msg("Initializing service") err := svc.Init() if err != nil { return err @@ -142,7 +143,13 @@ func (app *BootstrapApp) Setup() error { // Create engine engine := gin.New() - router := engine.Group("/api") + + if config.Version != "development" { + gin.SetMode(gin.ReleaseMode) + } + + router := engine.Group("/") + apiRouter := router.Group("/api") // Create middlewares var middlewares []Middleware @@ -151,18 +158,16 @@ func (app *BootstrapApp) Setup() error { Domain: domain, }, authService, oauthBrokerService) - uiMiddleware := middleware.NewUIMiddleware(middleware.UIMiddlewareConfig{ - ResourcesDir: app.Config.ResourcesDir, - }) + uiMiddleware := middleware.NewUIMiddleware() zerologMiddleware := middleware.NewZerologMiddleware() middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) for _, middleware := range middlewares { - log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") + log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware") err := middleware.Init() if err != nil { - return fmt.Errorf("failed to initialize %s middleware: %w", middleware.Name(), err) + return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err) } router.Use(middleware.Middleware()) } @@ -177,24 +182,28 @@ func (app *BootstrapApp) Setup() error { ForgotPasswordMessage: app.Config.FogotPasswordMessage, BackgroundImage: app.Config.BackgroundImage, OAuthAutoRedirect: app.Config.OAuthAutoRedirect, - }, router) + }, apiRouter) oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ AppURL: app.Config.AppURL, SecureCookie: app.Config.SecureCookie, CSRFCookieName: csrfCookieName, RedirectCookieName: redirectCookieName, - }, router, authService, oauthBrokerService) + }, apiRouter, authService, oauthBrokerService) proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ AppURL: app.Config.AppURL, - }, router, dockerService, authService) + }, apiRouter, dockerService, authService) userController := controller.NewUserController(controller.UserControllerConfig{ Domain: domain, - }, router, authService) + }, apiRouter, authService) + + resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ + ResourcesDir: app.Config.ResourcesDir, + }, router) - healthController := controller.NewHealthController(router) + healthController := controller.NewHealthController(apiRouter) // Setup routes controller := []Controller{ @@ -203,10 +212,11 @@ func (app *BootstrapApp) Setup() error { proxyController, userController, healthController, + resourcesController, } for _, ctrl := range controller { - log.Debug().Msgf("Setting up %T routes", ctrl) + log.Debug().Msgf("Setting up %T controller", ctrl) ctrl.SetupRoutes() } diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 2330fb17..842b3d30 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -14,6 +14,7 @@ func NewHealthController(router *gin.RouterGroup) *HealthController { func (controller *HealthController) SetupRoutes() { controller.Router.GET("/health", controller.healthHandler) + controller.Router.HEAD("/health", controller.healthHandler) } func (controller *HealthController) healthHandler(c *gin.Context) { diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go new file mode 100644 index 00000000..f0c20096 --- /dev/null +++ b/internal/controller/resources_controller.go @@ -0,0 +1,32 @@ +package controller + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +type ResourcesControllerConfig struct { + ResourcesDir string +} + +type ResourcesController struct { + Config ResourcesControllerConfig + Router *gin.RouterGroup +} + +func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { + return &ResourcesController{ + Config: config, + Router: router, + } +} + +func (controller *ResourcesController) SetupRoutes() { + controller.Router.GET("/resources/*resource", controller.resourcesHandler) +} + +func (controller *ResourcesController) resourcesHandler(c *gin.Context) { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(controller.Config.ResourcesDir))) + fileServer.ServeHTTP(c.Writer, c.Request) +} diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index a83d4659..62f2a643 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -32,10 +32,6 @@ func (m *ContextMiddleware) Init() error { return nil } -func (m *ContextMiddleware) Name() string { - return "ContextMiddleware" -} - func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { cookie, err := m.Auth.GetSessionCookie(c) diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 0ce139be..6c03e4ff 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -4,28 +4,19 @@ import ( "io/fs" "net/http" "os" - "path/filepath" "strings" "tinyauth/internal/assets" "github.com/gin-gonic/gin" ) -type UIMiddlewareConfig struct { - ResourcesDir string -} - type UIMiddleware struct { - Config UIMiddlewareConfig - UIFS fs.FS - UIFileServer http.Handler - ResourcesFileServer http.Handler + UIFS fs.FS + UIFileServer http.Handler } -func NewUIMiddleware(config UIMiddlewareConfig) *UIMiddleware { - return &UIMiddleware{ - Config: config, - } +func NewUIMiddleware() *UIMiddleware { + return &UIMiddleware{} } func (m *UIMiddleware) Init() error { @@ -37,15 +28,10 @@ func (m *UIMiddleware) Init() error { m.UIFS = ui m.UIFileServer = http.FileServer(http.FS(ui)) - m.ResourcesFileServer = http.FileServer(http.Dir(m.Config.ResourcesDir)) return nil } -func (m *UIMiddleware) Name() string { - return "UIMiddleware" -} - func (m *UIMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { switch strings.Split(c.Request.URL.Path, "/")[1] { @@ -53,24 +39,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { c.Next() return case "resources": - requestFilePath := m.Config.ResourcesDir + strings.TrimPrefix(c.Request.URL.Path, "/resources/") - - if !filepath.IsLocal(requestFilePath) { - c.Status(404) - c.Abort() - return - } - - _, err := os.Stat(requestFilePath) - - if os.IsNotExist(err) { - c.Status(404) - c.Abort() - return - } - - m.ResourcesFileServer.ServeHTTP(c.Writer, c.Request) - c.Abort() + c.Next() return default: _, err := fs.Stat(m.UIFS, strings.TrimPrefix(c.Request.URL.Path, "/")) diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index 79c5d706..95f5821f 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -26,10 +26,6 @@ func (m *ZerologMiddleware) Init() error { return nil } -func (m *ZerologMiddleware) Name() string { - return "ZerologMiddleware" -} - func (m *ZerologMiddleware) logPath(path string) bool { for _, prefix := range loggerSkipPathsPrefix { if strings.HasPrefix(path, prefix) { diff --git a/internal/utils/other_utils.go b/internal/utils/app_utils.go similarity index 81% rename from internal/utils/other_utils.go rename to internal/utils/app_utils.go index 17167254..1ed8d4c7 100644 --- a/internal/utils/other_utils.go +++ b/internal/utils/app_utils.go @@ -7,7 +7,6 @@ import ( "tinyauth/internal/config" "github.com/gin-gonic/gin" - "github.com/traefik/paerser/parser" "github.com/rs/zerolog" ) @@ -39,17 +38,6 @@ func ParseFileToLine(content string) string { return strings.Join(users, ",") } -func GetLabels(labels map[string]string) (config.Labels, error) { - var labelsParsed config.Labels - - err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") - if err != nil { - return config.Labels{}, err - } - - return labelsParsed, nil -} - func Filter[T any](slice []T, test func(T) bool) (res []T) { for _, value := range slice { if test(value) { diff --git a/internal/utils/header_utils.go b/internal/utils/label_utils.go similarity index 63% rename from internal/utils/header_utils.go rename to internal/utils/label_utils.go index 1192de56..a01685b3 100644 --- a/internal/utils/header_utils.go +++ b/internal/utils/label_utils.go @@ -2,8 +2,22 @@ package utils import ( "strings" + "tinyauth/internal/config" + + "github.com/traefik/paerser/parser" ) +func GetLabels(labels map[string]string) (config.Labels, error) { + var labelsParsed config.Labels + + err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") + if err != nil { + return config.Labels{}, err + } + + return labelsParsed, nil +} + func ParseHeaders(headers []string) map[string]string { headerMap := make(map[string]string) for _, header := range headers { diff --git a/internal/utils/sec_utils.go b/internal/utils/security_utils.go similarity index 100% rename from internal/utils/sec_utils.go rename to internal/utils/security_utils.go diff --git a/main.go b/main.go index 27792d81..eac789e8 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,6 @@ import ( ) func main() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.FatalLevel) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger().Level(zerolog.FatalLevel) cmd.Execute() } From 77296daef30d7cbea57de1ed2faa6dff90629b9f Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 12:45:24 +0300 Subject: [PATCH 13/17] feat: add some logging --- internal/controller/context_controller.go | 2 + internal/controller/oauth_controller.go | 15 ++++++ internal/controller/proxy_controller.go | 30 ++++++++++++ internal/controller/user_controller.go | 21 ++++++++ internal/middleware/context_middleware.go | 12 +++++ internal/service/auth_service.go | 60 +++++------------------ internal/service/docker_service.go | 15 ++---- internal/service/ldap_service.go | 4 +- 8 files changed, 98 insertions(+), 61 deletions(-) diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index c7dfccfe..c7570f0e 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -4,6 +4,7 @@ import ( "tinyauth/internal/utils" "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" ) type UserContextResponse struct { @@ -76,6 +77,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { } if err != nil { + log.Debug().Err(err).Msg("No user context found in request") userContext.Status = 401 userContext.Message = "Unauthorized" userContext.IsLoggedIn = false diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 0178af6c..025db1bd 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -11,6 +11,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/rs/zerolog/log" ) type OAuthRequest struct { @@ -51,6 +52,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -61,6 +63,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { service, exists := controller.Broker.GetService(req.Provider) if !exists { + log.Warn().Msgf("OAuth provider not found: %s", req.Provider) c.JSON(404, gin.H{ "status": 404, "message": "Not Found", @@ -75,6 +78,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { redirectURI := c.Query("redirect_uri") if redirectURI != "" { + log.Debug().Msg("Setting redirect URI cookie") c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) } @@ -90,6 +94,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -101,6 +106,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { csrfCookie, err := c.Cookie(controller.Config.CSRFCookieName) if err != nil || state != csrfCookie { + log.Warn().Err(err).Msg("CSRF token mismatch or cookie missing") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } @@ -111,12 +117,14 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { service, exists := controller.Broker.GetService(req.Provider) if !exists { + log.Warn().Msgf("OAuth provider not found: %s", req.Provider) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } err = service.VerifyCode(code) if err != nil { + log.Error().Err(err).Msg("Failed to verify OAuth code") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } @@ -124,11 +132,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { user, err := controller.Broker.GetUser(req.Provider) if err != nil { + log.Error().Err(err).Msg("Failed to get user from OAuth provider") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } if user.Email == "" { + log.Error().Msg("OAuth provider did not return an email") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } @@ -139,6 +149,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { }) if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } @@ -150,8 +161,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { var name string if user.Name != "" { + log.Debug().Msg("Using name from OAuth provider") name = user.Name } else { + log.Debug().Msg("No name from OAuth provider, using pseudo name") name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } @@ -166,6 +179,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) if err != nil { + log.Debug().Msg("No redirect URI cookie found, redirecting to app root") c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) return } @@ -175,6 +189,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { }) if err != nil { + log.Error().Err(err).Msg("Failed to encode redirect URI query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index ced09bff..9515d329 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/rs/zerolog/log" ) type Proxy struct { @@ -46,6 +47,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -55,6 +57,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + if isBrowser { + log.Debug().Msg("Request identified as (most likely) coming from a browser") + } else { + log.Debug().Msg("Request identified as (most likely) coming from a non-browser client") + } + uri := c.Request.Header.Get("X-Forwarded-Uri") proto := c.Request.Header.Get("X-Forwarded-Proto") host := c.Request.Header.Get("X-Forwarded-Host") @@ -65,6 +73,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { labels, err := controller.Docker.GetLabels(id, hostWithoutPort) if err != nil { + log.Error().Err(err).Msg("Failed to get labels from Docker") + if req.Proxy == "nginx" || !isBrowser { c.JSON(500, gin.H{ "status": 500, @@ -85,10 +95,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { headers := utils.ParseHeaders(labels.Headers) for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) } @@ -114,6 +126,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) } @@ -124,6 +137,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { authEnabled, err := controller.Auth.AuthEnabled(uri, labels) if err != nil { + log.Error().Err(err).Msg("Failed to check if auth is enabled for resource") + if req.Proxy == "nginx" || !isBrowser { c.JSON(500, gin.H{ "status": 500, @@ -137,15 +152,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if !authEnabled { + log.Debug().Msg("Authentication disabled for resource, allowing access") + c.Header("Authorization", c.Request.Header.Get("Authorization")) headers := utils.ParseHeaders(labels.Headers) for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) } @@ -161,6 +180,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { context, err := utils.GetContext(c) if err != nil { + log.Debug().Msg("No user context found in request, treating as not logged in") userContext = config.UserContext{ IsLoggedIn: false, } @@ -169,6 +189,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if userContext.Provider == "basic" && userContext.TotpEnabled { + log.Debug().Msg("User has TOTP enabled, denying basic auth access") userContext.IsLoggedIn = false } @@ -176,6 +197,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { appAllowed := controller.Auth.ResourceAllowed(c, userContext, labels) if !appAllowed { + log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User not allowed to access resource") + if req.Proxy == "nginx" || !isBrowser { c.JSON(403, gin.H{ "status": 403, @@ -195,6 +218,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } @@ -207,6 +231,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { groupOK := controller.Auth.OAuthGroup(c, userContext, labels) if !groupOK { + log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements") + if req.Proxy == "nginx" || !isBrowser { c.JSON(403, gin.H{ "status": 403, @@ -227,6 +253,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } @@ -245,10 +272,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { headers := utils.ParseHeaders(labels.Headers) for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) } @@ -272,6 +301,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { + log.Error().Err(err).Msg("Failed to encode redirect URI query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 77bb6f33..7f307e37 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" + "github.com/rs/zerolog/log" ) type LoginRequest struct { @@ -50,6 +51,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { err := c.BindJSON(&req) if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -65,9 +67,12 @@ func (controller *UserController) loginHandler(c *gin.Context) { rateIdentifier = clientIP } + log.Debug().Str("username", req.Username).Str("ip", clientIP).Msg("Login attempt") + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) if isLocked { + log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed login attempts") c.JSON(429, gin.H{ "status": 429, "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), @@ -78,6 +83,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { userSearch := controller.Auth.SearchUser(req.Username) if userSearch.Type == "" { + log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("User not found") controller.Auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ "status": 401, @@ -87,6 +93,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } if !controller.Auth.VerifyUser(userSearch, req.Password) { + log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("Invalid password") controller.Auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ "status": 401, @@ -95,12 +102,16 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + log.Info().Str("username", req.Username).Str("ip", clientIP).Msg("Login successful") + controller.Auth.RecordLoginAttempt(rateIdentifier, true) if userSearch.Type == "local" { user := controller.Auth.GetLocalUser(userSearch.Username) if user.TotpSecret != "" { + log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(req.Username), @@ -132,6 +143,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } func (controller *UserController) logoutHandler(c *gin.Context) { + log.Debug().Msg("Logout request received") controller.Auth.DeleteSessionCookie(c) c.JSON(200, gin.H{ "status": 200, @@ -144,6 +156,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { err := c.BindJSON(&req) if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -154,6 +167,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { context, err := utils.GetContext(c) if err != nil { + log.Error().Err(err).Msg("Failed to get user context") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -162,6 +176,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { } if !context.IsLoggedIn { + log.Warn().Msg("TOTP attempt without being logged in") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -177,9 +192,12 @@ func (controller *UserController) totpHandler(c *gin.Context) { rateIdentifier = clientIP } + log.Debug().Str("username", context.Username).Str("ip", clientIP).Msg("TOTP verification attempt") + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) if isLocked { + log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed TOTP attempts") c.JSON(429, gin.H{ "status": 429, "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), @@ -192,6 +210,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { ok := totp.Validate(req.Code, user.TotpSecret) if !ok { + log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Invalid TOTP code") controller.Auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ "status": 401, @@ -200,6 +219,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } + log.Info().Str("username", context.Username).Str("ip", clientIP).Msg("TOTP verification successful") + controller.Auth.RecordLoginAttempt(rateIdentifier, true) controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 62f2a643..e11f80ca 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -8,6 +8,7 @@ import ( "tinyauth/internal/utils" "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" ) type ContextMiddlewareConfig struct { @@ -37,6 +38,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { cookie, err := m.Auth.GetSessionCookie(c) if err != nil { + log.Debug().Err(err).Msg("No valid session cookie found") goto basic } @@ -58,6 +60,8 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { userSearch := m.Auth.SearchUser(cookie.Username) if userSearch.Type == "unknown" { + log.Debug().Msg("User from session cookie not found") + m.Auth.DeleteSessionCookie(c) goto basic } @@ -74,10 +78,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { _, exists := m.Broker.GetService(cookie.Provider) if !exists { + log.Debug().Msg("OAuth provider from session cookie not found") goto basic } if !m.Auth.EmailWhitelisted(cookie.Email) { + log.Debug().Msg("Email from session cookie not whitelisted") m.Auth.DeleteSessionCookie(c) goto basic } @@ -99,6 +105,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { basic := m.Auth.GetBasicAuth(c) if basic == nil { + log.Debug().Msg("No basic auth provided") c.Next() return } @@ -106,17 +113,21 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { userSearch := m.Auth.SearchUser(basic.Username) if userSearch.Type == "unknown" { + log.Debug().Msg("User from basic auth not found") c.Next() return } if !m.Auth.VerifyUser(userSearch, basic.Password) { + log.Debug().Msg("Invalid password for basic auth user") c.Next() return } switch userSearch.Type { case "local": + log.Debug().Msg("Basic auth user is local") + user := m.Auth.GetLocalUser(basic.Username) c.Set("context", &config.UserContext{ @@ -130,6 +141,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return case "ldap": + log.Debug().Msg("Basic auth user is LDAP") c.Set("context", &config.UserContext{ Username: basic.Username, Name: utils.Capitalize(basic.Username), diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index e91f98a1..8c91e790 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -61,6 +61,7 @@ func (auth *AuthService) Init() error { HttpOnly: true, Domain: fmt.Sprintf(".%s", auth.Config.Domain), } + auth.Store = store return nil } @@ -70,11 +71,10 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { - log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying") + log.Debug().Err(err).Msg("Error getting session, clearing cookie and retrying") c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return nil, err } } @@ -83,25 +83,21 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { } func (auth *AuthService) SearchUser(username string) config.UserSearch { - log.Debug().Str("username", username).Msg("Searching for user") - - // Check local users first if auth.GetLocalUser(username).Username != "" { - log.Debug().Str("username", username).Msg("Found local user") return config.UserSearch{ Username: username, Type: "local", } } - // If no user found, check LDAP if auth.LDAP != nil { - log.Debug().Str("username", username).Msg("Checking LDAP for user") userDN, err := auth.LDAP.Search(username) + if err != nil { - log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") + log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") return config.UserSearch{} } + return config.UserSearch{ Username: userDN, Type: "ldap", @@ -114,54 +110,42 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch { } func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { - // Authenticate the user based on the type switch search.Type { case "local": - // If local user, get the user and check the password user := auth.GetLocalUser(search.Username) return auth.CheckPassword(user, password) case "ldap": - // If LDAP is configured, bind to the LDAP server with the user DN and password if auth.LDAP != nil { - log.Debug().Str("username", search.Username).Msg("Binding to LDAP for user authentication") - err := auth.LDAP.Bind(search.Username, password) if err != nil { log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") return false } - // Rebind with the service account to reset the connection err = auth.LDAP.Bind(auth.LDAP.Config.BindDN, auth.LDAP.Config.BindPassword) if err != nil { log.Error().Err(err).Msg("Failed to rebind with service account after user authentication") return false } - log.Debug().Str("username", search.Username).Msg("LDAP authentication successful") return true } default: - log.Warn().Str("type", search.Type).Msg("Unknown user type for authentication") + log.Debug().Str("type", search.Type).Msg("Unknown user type for authentication") return false } - // If no user found or authentication failed, return false log.Warn().Str("username", search.Username).Msg("User authentication failed") return false } func (auth *AuthService) GetLocalUser(username string) config.User { - // Loop through users and return the user if the username matches - log.Debug().Str("username", username).Msg("Searching for local user") - for _, user := range auth.Config.Users { if user.Username == username { return user } } - // If no user found, return an empty user log.Warn().Str("username", username).Msg("Local user not found") return config.User{} } @@ -237,16 +221,11 @@ func (auth *AuthService) EmailWhitelisted(email string) bool { } func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error { - log.Debug().Msg("Creating session cookie") - session, err := auth.GetSession(c) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return err } - log.Debug().Msg("Setting session cookie") - var sessionExpiry int if data.TotpPending { @@ -265,7 +244,6 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio err = session.Save(c.Request, c.Writer) if err != nil { - log.Error().Err(err).Msg("Failed to save session") return err } @@ -273,11 +251,8 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio } func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { - log.Debug().Msg("Deleting session cookie") - session, err := auth.GetSession(c) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return err } @@ -288,7 +263,6 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { err = session.Save(c.Request, c.Writer) if err != nil { - log.Error().Err(err).Msg("Failed to save session") return err } @@ -296,16 +270,11 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { } func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) { - log.Debug().Msg("Getting session cookie") - session, err := auth.GetSession(c) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return config.SessionCookie{}, err } - log.Debug().Msg("Got session") - username, usernameOk := session.Values["username"].(string) email, emailOk := session.Values["email"].(string) name, nameOk := session.Values["name"].(string) @@ -328,7 +297,6 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, return config.SessionCookie{}, nil } - log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") return config.SessionCookie{ Username: username, Name: name, @@ -359,7 +327,6 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, return true } - // Check if we are using the generic oauth provider if context.Provider != "generic" { log.Debug().Msg("Not using generic provider, skipping group check") return true @@ -371,7 +338,6 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, // For every group check if it is in the required groups for _, group := range oauthGroups { if utils.CheckFilter(labels.OAuth.Groups, group) { - log.Debug().Str("group", group).Msg("Group is in required groups") return true } } @@ -387,12 +353,9 @@ func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, er return true, nil } - // Compile regex regex, err := regexp.Compile(labels.Allowed) - // If there is an error, invalid regex, auth enabled if err != nil { - log.Error().Err(err).Msg("Invalid regex") return true, err } @@ -408,6 +371,7 @@ func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, er func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { username, password, ok := c.Request.BasicAuth() if !ok { + log.Debug().Msg("No basic auth provided") return nil } return &config.User{ @@ -421,11 +385,11 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) if err != nil { - log.Error().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { - log.Warn().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") + log.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") return false } } @@ -434,7 +398,7 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { for _, allowed := range labels.IP.Allow { res, err := utils.FilterIP(allowed, ip) if err != nil { - log.Error().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { @@ -445,7 +409,7 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { // If not in allowed range and allowed range is not empty, deny access if len(labels.IP.Allow) > 0 { - log.Warn().Str("ip", ip).Msg("IP not in allow list, denying access") + log.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } @@ -458,7 +422,7 @@ func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool { for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - log.Error().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index f067d7f2..448a9742 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -6,6 +6,8 @@ import ( "tinyauth/internal/config" "tinyauth/internal/utils" + "slices" + container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" "github.com/rs/zerolog/log" @@ -60,11 +62,8 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.Labels return config.Labels{}, nil } - log.Debug().Msg("Getting containers") - containers, err := docker.GetContainers() if err != nil { - log.Error().Err(err).Msg("Error getting containers") return config.Labels{}, err } @@ -75,8 +74,6 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.Labels continue } - log.Debug().Str("id", inspect.ID).Msg("Getting labels for container") - labels, err := utils.GetLabels(inspect.Config.Labels) if err != nil { log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping") @@ -84,11 +81,9 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.Labels } // Check if the container matches the ID or domain - for _, lDomain := range labels.Domain { - if lDomain == domain { - log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") - return labels, nil - } + if slices.Contains(labels.Domain, domain) { + log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") + return labels, nil } if strings.TrimPrefix(inspect.Name, "/") == app { diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 805e2f72..503432f1 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -55,7 +55,6 @@ func (ldap *LdapService) Init() error { } func (ldap *LdapService) connect() (*ldapgo.Conn, error) { - log.Debug().Msg("Connecting to LDAP server") conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ InsecureSkipVerify: ldap.Config.Insecure, MinVersion: tls.VersionTLS12, @@ -64,7 +63,6 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { return nil, err } - log.Debug().Msg("Binding to LDAP server") err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword) if err != nil { return nil, err @@ -94,7 +92,7 @@ func (ldap *LdapService) Search(username string) (string, error) { } if len(searchResult.Entries) != 1 { - return "", fmt.Errorf("err multiple or no entries found for user %s", username) + return "", fmt.Errorf("multiple or no entries found for user %s", username) } userDN := searchResult.Entries[0].DN From 8435cbe4340f8024e28c2c63c9bf8fc8a411959f Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 13:17:10 +0300 Subject: [PATCH 14/17] fix: configure middlewares before groups --- internal/bootstrap/app_bootstrap.go | 11 ++++++----- internal/controller/proxy_controller.go | 2 +- internal/service/docker_service.go | 3 +++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 54838d1e..4401172b 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -148,9 +148,6 @@ func (app *BootstrapApp) Setup() error { gin.SetMode(gin.ReleaseMode) } - router := engine.Group("/") - apiRouter := router.Group("/api") - // Create middlewares var middlewares []Middleware @@ -169,9 +166,13 @@ func (app *BootstrapApp) Setup() error { if err != nil { return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err) } - router.Use(middleware.Middleware()) + engine.Use(middleware.Middleware()) } + // Create routers + mainRouter := engine.Group("/") + apiRouter := engine.Group("/api") + // Create controllers contextController := controller.NewContextController(controller.ContextControllerConfig{ ConfiguredProviders: configuredProviders, @@ -201,7 +202,7 @@ func (app *BootstrapApp) Setup() error { resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ ResourcesDir: app.Config.ResourcesDir, - }, router) + }, mainRouter) healthController := controller.NewHealthController(apiRouter) diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 9515d329..ae7a1013 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -38,7 +38,7 @@ func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, d } func (controller *ProxyController) SetupRoutes() { - proxyGroup := controller.Router.Group("/api/auth") + proxyGroup := controller.Router.Group("/auth") proxyGroup.GET("/:proxy", controller.proxyHandler) } diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 448a9742..41eb07c9 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -30,6 +30,9 @@ func (docker *DockerService) Init() error { ctx := context.Background() client.NegotiateAPIVersion(ctx) + + docker.Client = client + docker.Context = ctx return nil } From d3c40bb366b7839fed0b95aa9e76ce1d4de7cc32 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 13:21:36 +0300 Subject: [PATCH 15/17] fix: use correct api path in login mutation --- frontend/src/pages/login-page.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index 4828b383..53f183f1 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -65,7 +65,7 @@ export const LoginPage = () => { }); const loginMutation = useMutation({ - mutationFn: (values: LoginSchema) => axios.post("/api/login", values), + mutationFn: (values: LoginSchema) => axios.post("/api/user/login", values), mutationKey: ["login"], onSuccess: (data) => { if (data.data.totpPending) { From a5e1ae096bc355d78024e9cb88ae5fa66a7512b4 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 14:31:09 +0300 Subject: [PATCH 16/17] fix: coderabbit suggestions --- cmd/root.go | 2 +- internal/assets/assets.go | 2 +- internal/bootstrap/app_bootstrap.go | 5 ++- internal/config/config.go | 4 +- internal/controller/oauth_controller.go | 7 +-- internal/controller/proxy_controller.go | 9 ++-- internal/controller/resources_controller.go | 22 ++++++--- internal/controller/user_controller.go | 39 +++++++++++++--- internal/middleware/context_middleware.go | 1 + internal/middleware/ui_middleware.go | 4 +- internal/middleware/zerolog_middleware.go | 4 +- internal/service/auth_service.go | 4 +- internal/service/generic_oauth_service.go | 7 ++- internal/service/github_oauth_service.go | 29 ++++++++++-- internal/service/google_oauth_service.go | 7 ++- internal/utils/app_utils.go | 50 ++++++++++++++++++--- internal/utils/label_utils.go | 5 +++ internal/utils/user_utils.go | 22 ++++++--- main.go | 2 +- 19 files changed, 178 insertions(+), 47 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 2b0c172e..7ed86e49 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -82,7 +82,7 @@ func init() { {"app-url", "", "The Tinyauth URL."}, {"users", "", "Comma separated list of users in the format username:hash."}, {"users-file", "", "Path to a file containing users in the format username:hash."}, - {"cookie-secure", false, "Send cookie over secure connection only."}, + {"secure-cookie", false, "Send cookie over secure connection only."}, {"github-client-id", "", "Github OAuth client ID."}, {"github-client-secret", "", "Github OAuth client secret."}, {"github-client-secret-file", "", "Github OAuth client secret file."}, diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 0b572e0c..df6e61f1 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -7,4 +7,4 @@ import ( // Frontend assets // //go:embed dist -var FontendAssets embed.FS +var FrontendAssets embed.FS diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 4401172b..594c575f 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -164,13 +164,13 @@ func (app *BootstrapApp) Setup() error { log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware") err := middleware.Init() if err != nil { - return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err) + return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err) } engine.Use(middleware.Middleware()) } // Create routers - mainRouter := engine.Group("/") + mainRouter := engine.Group("") apiRouter := engine.Group("/api") // Create controllers @@ -190,6 +190,7 @@ func (app *BootstrapApp) Setup() error { SecureCookie: app.Config.SecureCookie, CSRFCookieName: csrfCookieName, RedirectCookieName: redirectCookieName, + Domain: domain, }, apiRouter, authService, oauthBrokerService) proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ diff --git a/internal/config/config.go b/internal/config/config.go index 48961d63..5d4dba86 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,10 +65,10 @@ type OAuthLabels struct { type BasicLabels struct { Username string - Password PassowrdLabels + Password PasswordLabels } -type PassowrdLabels struct { +type PasswordLabels struct { Plain string File string } diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 025db1bd..9802ea17 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -23,6 +23,7 @@ type OAuthControllerConfig struct { RedirectCookieName string SecureCookie bool AppURL string + Domain string } type OAuthController struct { @@ -77,7 +78,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { redirectURI := c.Query("redirect_uri") - if redirectURI != "" { + if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { log.Debug().Msg("Setting redirect URI cookie") c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) } @@ -178,7 +179,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) - if err != nil { + if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { log.Debug().Msg("No redirect URI cookie found, redirecting to app root") c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) return @@ -195,5 +196,5 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) } diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index ae7a1013..348be65b 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -128,6 +128,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return } c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) @@ -212,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if userContext.OAuth { - queries.Set("username", userContext.Username) - } else { queries.Set("username", userContext.Email) + } else { + queries.Set("username", userContext.Username) } if err != nil { @@ -247,9 +248,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if userContext.OAuth { - queries.Set("username", userContext.Username) - } else { queries.Set("username", userContext.Email) + } else { + queries.Set("username", userContext.Username) } if err != nil { diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index f0c20096..56bae87d 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -11,14 +11,18 @@ type ResourcesControllerConfig struct { } type ResourcesController struct { - Config ResourcesControllerConfig - Router *gin.RouterGroup + Config ResourcesControllerConfig + Router *gin.RouterGroup + FileServer http.Handler } func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir))) + return &ResourcesController{ - Config: config, - Router: router, + Config: config, + Router: router, + FileServer: fileServer, } } @@ -27,6 +31,12 @@ func (controller *ResourcesController) SetupRoutes() { } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { - fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(controller.Config.ResourcesDir))) - fileServer.ServeHTTP(c.Writer, c.Request) + if controller.Config.ResourcesDir == "" { + c.JSON(404, gin.H{ + "status": 404, + "message": "Resources not found", + }) + return + } + controller.FileServer.ServeHTTP(c.Writer, c.Request) } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 7f307e37..72e22d87 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -112,7 +112,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { if user.TotpSecret != "" { log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), @@ -120,6 +120,15 @@ func (controller *UserController) loginHandler(c *gin.Context) { TotpPending: true, }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + c.JSON(200, gin.H{ "status": 200, "message": "TOTP required", @@ -129,13 +138,22 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: req.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), Provider: "username", }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", @@ -144,7 +162,9 @@ func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) { log.Debug().Msg("Logout request received") + controller.Auth.DeleteSessionCookie(c) + c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", @@ -175,8 +195,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - if !context.IsLoggedIn { - log.Warn().Msg("TOTP attempt without being logged in") + if !context.TotpPending { + log.Warn().Msg("TOTP attempt without a pending TOTP session") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -223,13 +243,22 @@ func (controller *UserController) totpHandler(c *gin.Context) { controller.Auth.RecordLoginAttempt(rateIdentifier, true) - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), Provider: "username", }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index e11f80ca..58e53e15 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -79,6 +79,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { if !exists { log.Debug().Msg("OAuth provider from session cookie not found") + m.Auth.DeleteSessionCookie(c) goto basic } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 6c03e4ff..dcfaa35b 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -20,10 +20,10 @@ func NewUIMiddleware() *UIMiddleware { } func (m *UIMiddleware) Init() error { - ui, err := fs.Sub(assets.FontendAssets, "dist") + ui, err := fs.Sub(assets.FrontendAssets, "dist") if err != nil { - return nil + return err } m.UIFS = ui diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index 95f5821f..877ad4c8 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -10,8 +10,8 @@ import ( var ( loggerSkipPathsPrefix = []string{ - "GET /api/healthcheck", - "HEAD /api/healthcheck", + "GET /api/health", + "HEAD /api/health", "GET /favicon.ico", } ) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 8c91e790..29f2dd16 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -71,9 +71,9 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { - log.Debug().Err(err).Msg("Error getting session, clearing cookie and retrying") + log.Debug().Err(err).Msg("Error getting session, creating a new one") c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) - session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) + session, err = auth.Store.New(c.Request, auth.Config.SessionCookieName) if err != nil { return nil, err } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index c68d150f..a09fd933 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "tinyauth/internal/config" @@ -76,7 +77,7 @@ func (generic *GenericOAuthService) VerifyCode(code string) error { token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) if err != nil { - return nil + return err } generic.Token = token @@ -94,6 +95,10 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err := io.ReadAll(res.Body) if err != nil { return user, err diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 2f9e27f8..4df4444b 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "io" "net/http" "tinyauth/internal/config" @@ -71,7 +72,7 @@ func (github *GithubOAuthService) VerifyCode(code string) error { token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) if err != nil { - return nil + return err } github.Token = token @@ -83,12 +84,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { client := github.Config.Client(github.Context, github.Token) - res, err := client.Get("https://api.github.com/user") + req, err := http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + return user, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err := client.Do(req) if err != nil { return user, err } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err := io.ReadAll(res.Body) if err != nil { return user, err @@ -101,12 +113,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { return user, err } - res, err = client.Get("https://api.github.com/user/emails") + req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil) + if err != nil { + return user, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err = client.Do(req) if err != nil { return user, err } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err = io.ReadAll(res.Body) if err != nil { return user, err diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 776aeca7..4f738e78 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "strings" @@ -66,7 +67,7 @@ func (google *GoogleOAuthService) VerifyCode(code string) error { token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) if err != nil { - return nil + return err } google.Token = token @@ -84,6 +85,10 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err := io.ReadAll(res.Body) if err != nil { return config.Claims{}, err diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 1ed8d4c7..85a87542 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -2,6 +2,7 @@ package utils import ( "errors" + "net" "net/url" "strings" "tinyauth/internal/config" @@ -12,16 +13,25 @@ import ( ) // Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetUpperDomain(urlSrc string) (string, error) { - urlParsed, err := url.Parse(urlSrc) +func GetUpperDomain(appUrl string) (string, error) { + appUrlParsed, err := url.Parse(appUrl) if err != nil { return "", err } - urlSplitted := strings.Split(urlParsed.Hostname(), ".") - urlFinal := strings.Join(urlSplitted[1:], ".") + host := appUrlParsed.Hostname() - return urlFinal, nil + if netIP := net.ParseIP(host); netIP != nil { + return "", errors.New("IP addresses are not allowed") + } + + urlParts := strings.Split(host, ".") + + if len(urlParts) < 2 { + return "", errors.New("invalid domain, must be at least second level domain") + } + + return strings.Join(urlParts[1:], "."), nil } func ParseFileToLine(content string) string { @@ -63,8 +73,38 @@ func GetContext(c *gin.Context) (config.UserContext, error) { return *userContext, nil } +func IsRedirectSafe(redirectURL string, domain string) bool { + if redirectURL == "" { + return false + } + + parsedURL, err := url.Parse(redirectURL) + + if err != nil { + return false + } + + if !parsedURL.IsAbs() { + return false + } + + upper, err := GetUpperDomain(redirectURL) + + if err != nil { + return false + } + + if upper != domain { + return false + } + + return true +} + func GetLogLevel(level string) zerolog.Level { switch strings.ToLower(level) { + case "trace": + return zerolog.TraceLevel case "debug": return zerolog.DebugLevel case "info": diff --git a/internal/utils/label_utils.go b/internal/utils/label_utils.go index a01685b3..f10092df 100644 --- a/internal/utils/label_utils.go +++ b/internal/utils/label_utils.go @@ -1,6 +1,7 @@ package utils import ( + "net/http" "strings" "tinyauth/internal/config" @@ -26,6 +27,10 @@ func ParseHeaders(headers []string) map[string]string { continue } key := SanitizeHeader(strings.TrimSpace(split[0])) + if strings.ContainsAny(key, " \t") { + continue + } + key = http.CanonicalHeaderKey(key) value := SanitizeHeader(strings.TrimSpace(split[1])) headerMap[key] = value } diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go index bfcec495..0044db4a 100644 --- a/internal/utils/user_utils.go +++ b/internal/utils/user_utils.go @@ -9,6 +9,12 @@ import ( func ParseUsers(users string) ([]config.User, error) { var usersParsed []config.User + users = strings.TrimSpace(users) + + if users == "" { + return []config.User{}, nil + } + userList := strings.Split(users, ",") if len(userList) == 0 { @@ -16,7 +22,10 @@ func ParseUsers(users string) ([]config.User, error) { } for _, user := range userList { - parsed, err := ParseUser(user) + if strings.TrimSpace(user) == "" { + continue + } + parsed, err := ParseUser(strings.TrimSpace(user)) if err != nil { return []config.User{}, err } @@ -39,12 +48,13 @@ func GetUsers(conf string, file string) ([]config.User, error) { if file != "" { contents, err := ReadFile(file) - if err == nil { - if users != "" { - users += "," - } - users += ParseFileToLine(contents) + if err != nil { + return []config.User{}, err + } + if users != "" { + users += "," } + users += ParseFileToLine(contents) } return ParseUsers(users) diff --git a/main.go b/main.go index eac789e8..8126e9ed 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,6 @@ import ( ) func main() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger().Level(zerolog.FatalLevel) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger() cmd.Execute() } From a1b6ecdd5dba49a0df64038e512c9ff252bbc0b5 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 14:49:55 +0300 Subject: [PATCH 17/17] fix: further coderabbit suggestions --- cmd/root.go | 6 +++--- internal/controller/oauth_controller.go | 8 ++++---- internal/controller/user_controller.go | 4 ++-- internal/service/auth_service.go | 3 +++ internal/service/generic_oauth_service.go | 8 ++++++-- internal/service/github_oauth_service.go | 8 ++++++-- internal/service/google_oauth_service.go | 8 ++++++-- internal/service/ldap_service.go | 2 +- internal/utils/security_utils.go | 4 ++-- 9 files changed, 33 insertions(+), 18 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 7ed86e49..ef5733e1 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,7 +8,7 @@ import ( "tinyauth/internal/config" "tinyauth/internal/utils" - "github.com/go-playground/validator" + "github.com/go-playground/validator/v10" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -34,9 +34,9 @@ var rootCmd = &cobra.Command{ conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) // Validate config - validator := validator.New() + v := validator.New() - err = validator.Struct(conf) + err = v.Struct(conf) if err != nil { log.Fatal().Err(err).Msg("Invalid config") } diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 9802ea17..aa3289bb 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -74,13 +74,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { state := service.GenerateState() authURL := service.GetAuthURL(state) - c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) redirectURI := c.Query("redirect_uri") if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { log.Debug().Msg("Setting redirect URI cookie") - c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) } c.JSON(200, gin.H{ @@ -112,7 +112,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) code := c.Query("code") service, exists := controller.Broker.GetService(req.Provider) @@ -195,6 +195,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 72e22d87..f7f7c9e6 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -49,7 +49,7 @@ func (controller *UserController) SetupRoutes() { func (controller *UserController) loginHandler(c *gin.Context) { var req LoginRequest - err := c.BindJSON(&req) + err := c.ShouldBindJSON(&req) if err != nil { log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ @@ -174,7 +174,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) totpHandler(c *gin.Context) { var req TotpRequest - err := c.BindJSON(&req) + err := c.ShouldBindJSON(&req) if err != nil { log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 29f2dd16..10d49e79 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -266,6 +266,9 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { return err } + // Clear the cookie in the browser + c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) + return nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index a09fd933..c16384db 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "time" "tinyauth/internal/config" "golang.org/x/oauth2" @@ -64,8 +65,11 @@ func (generic *GenericOAuthService) Init() error { func (generic *GenericOAuthService) GenerateState() string { b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) return state } diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 4df4444b..7f8466b9 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "time" "tinyauth/internal/config" "golang.org/x/oauth2" @@ -59,8 +60,11 @@ func (github *GithubOAuthService) Init() error { func (github *GithubOAuthService) GenerateState() string { b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) return state } diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 4f738e78..1605a855 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "time" "tinyauth/internal/config" "golang.org/x/oauth2" @@ -54,8 +55,11 @@ func (google *GoogleOAuthService) Init() error { func (oauth *GoogleOAuthService) GenerateState() string { b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) return state } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 503432f1..8576c4d7 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -140,7 +140,7 @@ func (ldap *LdapService) reconnect() error { ldap.Conn.Close() conn, err := ldap.connect() if err != nil { - return nil, nil + return nil, err } return conn, nil } diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 4e9e1874..a0319008 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -101,7 +101,7 @@ func CheckFilter(filter string, str string) bool { return false } - if re.MatchString(str) { + if re.MatchString(strings.TrimSpace(str)) { return true } } @@ -109,7 +109,7 @@ func CheckFilter(filter string, str string) bool { filterSplit := strings.Split(filter, ",") for _, item := range filterSplit { - if strings.TrimSpace(item) == str { + if strings.TrimSpace(item) == strings.TrimSpace(str) { return true } }