Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions internal/controller/context_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ type ContextControllerConfig struct {
}

type ContextController struct {
Config ContextControllerConfig
Router *gin.RouterGroup
config ContextControllerConfig
router *gin.RouterGroup
}

func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
return &ContextController{
Config: config,
Router: router,
config: config,
router: router,
}
}

func (controller *ContextController) SetupRoutes() {
contextGroup := controller.Router.Group("/context")
contextGroup := controller.router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler)
}
Expand Down Expand Up @@ -91,18 +91,18 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
}

func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, _ := url.Parse(controller.Config.AppURL) // no need to check error, validated on startup
appUrl, _ := url.Parse(controller.config.AppURL) // no need to check error, validated on startup
Comment thread
steveiliop56 marked this conversation as resolved.

c.JSON(200, AppContextResponse{
Status: 200,
Message: "Success",
ConfiguredProviders: controller.Config.ConfiguredProviders,
Title: controller.Config.Title,
GenericName: controller.Config.GenericName,
ConfiguredProviders: controller.config.ConfiguredProviders,
Title: controller.config.Title,
GenericName: controller.config.GenericName,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
RootDomain: controller.Config.RootDomain,
ForgotPasswordMessage: controller.Config.ForgotPasswordMessage,
BackgroundImage: controller.Config.BackgroundImage,
OAuthAutoRedirect: controller.Config.OAuthAutoRedirect,
RootDomain: controller.config.RootDomain,
ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
BackgroundImage: controller.config.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
})
}
8 changes: 4 additions & 4 deletions internal/controller/health_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ package controller
import "github.com/gin-gonic/gin"

type HealthController struct {
Router *gin.RouterGroup
router *gin.RouterGroup
}

func NewHealthController(router *gin.RouterGroup) *HealthController {
return &HealthController{
Router: router,
router: router,
}
}

func (controller *HealthController) SetupRoutes() {
controller.Router.GET("/health", controller.healthHandler)
controller.Router.HEAD("/health", controller.healthHandler)
controller.router.GET("/health", controller.healthHandler)
controller.router.HEAD("/health", controller.healthHandler)
}

func (controller *HealthController) healthHandler(c *gin.Context) {
Expand Down
78 changes: 42 additions & 36 deletions internal/controller/oauth_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,23 @@ type OAuthControllerConfig struct {
}

type OAuthController struct {
Config OAuthControllerConfig
Router *gin.RouterGroup
Auth *service.AuthService
Broker *service.OAuthBrokerService
config OAuthControllerConfig
router *gin.RouterGroup
auth *service.AuthService
broker *service.OAuthBrokerService
}

func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
return &OAuthController{
Config: config,
Router: router,
Auth: auth,
Broker: broker,
config: config,
router: router,
auth: auth,
broker: broker,
}
}

func (controller *OAuthController) SetupRoutes() {
oauthGroup := controller.Router.Group("/oauth")
oauthGroup := controller.router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
}
Expand All @@ -61,7 +61,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}

service, exists := controller.Broker.GetService(req.Provider)
service, exists := controller.broker.GetService(req.Provider)

if !exists {
log.Warn().Msgf("OAuth provider not found: %s", req.Provider)
Expand All @@ -74,13 +74,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {

state := service.GenerateState()
authURL := service.GetAuthURL(state)
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true)

redirectURI := c.Query("redirect_uri")

if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.RootDomain) {
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.config.RootDomain) {
log.Debug().Msg("Setting redirect URI cookie")
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true)
c.SetCookie(controller.config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true)
}

c.JSON(200, gin.H{
Expand All @@ -104,58 +104,58 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}

state := c.Query("state")
csrfCookie, err := c.Cookie(controller.Config.CSRFCookieName)
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)

if err != nil || state != csrfCookie {
log.Warn().Err(err).Msg("CSRF token mismatch or cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true)

code := c.Query("code")
service, exists := controller.Broker.GetService(req.Provider)
service, exists := controller.broker.GetService(req.Provider)

if !exists {
log.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

err = service.VerifyCode(code)
if err != nil {
log.Error().Err(err).Msg("Failed to verify OAuth code")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

user, err := controller.Broker.GetUser(req.Provider)
user, err := controller.broker.GetUser(req.Provider)

if err != nil {
log.Error().Err(err).Msg("Failed to get user from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

if user.Email == "" {
log.Error().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

if !controller.Auth.IsEmailWhitelisted(user.Email) {
if !controller.auth.IsEmailWhitelisted(user.Email) {
queries, err := query.Values(config.UnauthorizedQuery{
Username: user.Email,
})

if err != nil {
log.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
return
}

Expand All @@ -169,29 +169,35 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}

var usename string
var username string

if user.PreferredUsername != "" {
log.Debug().Msg("Using preferred username from OAuth provider")
usename = user.PreferredUsername
username = user.PreferredUsername
} else {
log.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
usename = strings.Replace(user.Email, "@", "_", -1)
username = strings.Replace(user.Email, "@", "_", -1)
}

controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: usename,
err = controller.auth.CreateSessionCookie(c, &config.SessionCookie{
Username: username,
Name: name,
Email: user.Email,
Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups),
})

redirectURI, err := c.Cookie(controller.Config.RedirectCookieName)
if err != nil {
log.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

redirectURI, err := c.Cookie(controller.config.RedirectCookieName)

if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.RootDomain) {
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.RootDomain) {
log.Debug().Msg("No redirect URI cookie found, redirecting to app root")
c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL)
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
return
}

Expand All @@ -201,10 +207,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {

if err != nil {
log.Error().Err(err).Msg("Failed to encode redirect URI query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}

c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
c.SetCookie(controller.config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
}
Loading
Loading