diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index d07942c..07c0a3a 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -20,6 +20,7 @@ type AuthRouter struct { providers []Provider loginTemplate *template.Template unauthorizedTemplate *template.Template + errorTemplate *template.Template } func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, error) { @@ -33,11 +34,17 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e return nil, err } + errorTmpl, err := template.ParseFS(templateFS, "templates/error.html") + if err != nil { + return nil, err + } + return &AuthRouter{ passwordHash: passwordHash, providers: providers, loginTemplate: tmpl, unauthorizedTemplate: unauthorizedTmpl, + errorTemplate: errorTmpl, }, nil } @@ -69,22 +76,22 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { session := sessions.Default(c) state := session.Get(SessionKeyOAuthState) if state == nil { - c.Error(errors.New("OAuth state is missing")) + a.renderError(c, errors.New("OAuth state is missing")) return } token, err := provider.Exchange(c, state.(string)) if err != nil { - c.Error(err) + a.renderError(c, err) return } userID, err := provider.GetUserID(c, token) if err != nil { - c.Error(err) + a.renderError(c, err) return } ok, err := provider.Authorization(userID) if err != nil { - c.Error(err) + a.renderError(c, err) return } if !ok { @@ -97,7 +104,10 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { if redirectURL != nil { session.Delete(SessionKeyRedirectURL) } - session.Save() + if err := session.Save(); err != nil { + a.renderError(c, err) + return + } if redirectURL == nil { c.Redirect(http.StatusFound, "/") @@ -111,16 +121,19 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { state, err := utils.GenerateState() if err != nil { - c.Error(err) + a.renderError(c, err) return } url, err := provider.AuthCodeURL(c, state) if err != nil { - c.Error(err) + a.renderError(c, err) return } session.Set(SessionKeyOAuthState, state) - session.Save() + if err := session.Save(); err != nil { + a.renderError(c, err) + return + } c.Redirect(http.StatusFound, url) }) } @@ -176,7 +189,10 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { if redirectURL != nil { session.Delete(SessionKeyRedirectURL) } - session.Save() + if err := session.Save(); err != nil { + a.renderError(c, err) + return + } if redirectURL == nil { c.Redirect(http.StatusFound, "/") @@ -189,7 +205,10 @@ func (a *AuthRouter) handleLogout(c *gin.Context) { session := sessions.Default(c) session.Delete(SessionKeyProvider) session.Delete(SessionKeyUserID) - session.Save() + if err := session.Save(); err != nil { + a.renderError(c, err) + return + } c.Redirect(http.StatusFound, LoginEndpoint) } @@ -200,7 +219,10 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc { userID := session.Get(SessionKeyUserID) if providerName == nil || userID == nil { session.Set(SessionKeyRedirectURL, c.Request.URL.String()) - session.Save() + if err := session.Save(); err != nil { + a.renderError(c, err) + return + } c.Redirect(http.StatusFound, LoginEndpoint) return } @@ -241,6 +263,10 @@ type unauthorizedTemplateData struct { Provider string } +type errorTemplateData struct { + ErrorMessage string +} + func (a *AuthRouter) renderLogin(c *gin.Context, passwordError string) { data := loginTemplateData{ Providers: a.providers, @@ -271,3 +297,16 @@ func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName str return } } + +func (a *AuthRouter) renderError(c *gin.Context, err error) { + data := errorTemplateData{ + ErrorMessage: err.Error(), + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.Status(http.StatusInternalServerError) + if templateErr := a.errorTemplate.Execute(c.Writer, data); templateErr != nil { + c.AbortWithError(http.StatusInternalServerError, templateErr) + return + } + c.Abort() +} diff --git a/pkg/auth/templates/error.html b/pkg/auth/templates/error.html new file mode 100644 index 0000000..d78ed58 --- /dev/null +++ b/pkg/auth/templates/error.html @@ -0,0 +1,73 @@ + + +
+ + +