diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 299d9d0..d07942c 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -18,7 +18,7 @@ var templateFS embed.FS type AuthRouter struct { passwordHash []string providers []Provider - template *template.Template + loginTemplate *template.Template unauthorizedTemplate *template.Template } @@ -36,7 +36,7 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e return &AuthRouter{ passwordHash: passwordHash, providers: providers, - template: tmpl, + loginTemplate: tmpl, unauthorizedTemplate: unauthorizedTmpl, }, nil } @@ -82,11 +82,28 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { c.Error(err) return } + ok, err := provider.Authorization(userID) + if err != nil { + c.Error(err) + return + } + if !ok { + a.renderUnauthorized(c, userID, provider.Name()) + return + } session.Set(SessionKeyProvider, provider.Name()) session.Set(SessionKeyUserID, userID) - session.Save() redirectURL := session.Get(SessionKeyRedirectURL) - c.Redirect(http.StatusFound, redirectURL.(string)) + if redirectURL != nil { + session.Delete(SessionKeyRedirectURL) + } + session.Save() + + if redirectURL == nil { + c.Redirect(http.StatusFound, "/") + } else { + c.Redirect(http.StatusFound, redirectURL.(string)) + } }) router.GET(provider.AuthURL(), func(c *gin.Context) { @@ -118,29 +135,12 @@ func (a *AuthRouter) getProvider(name string) Provider { return nil } -type templateData struct { - Providers []Provider - HasPassword bool - PasswordError string -} - func (a *AuthRouter) handleLogin(c *gin.Context) { if c.Request.Method == "POST" { a.handleLoginPost(c) return } - - data := templateData{ - Providers: a.providers, - HasPassword: len(a.passwordHash) > 0, - PasswordError: "", - } - - c.Header("Content-Type", "text/html; charset=utf-8") - if err := a.template.Execute(c.Writer, data); err != nil { - c.AbortWithError(http.StatusInternalServerError, err) - return - } + a.renderLogin(c, "") } func (a *AuthRouter) handleLoginPost(c *gin.Context) { @@ -165,39 +165,32 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { } if errorMessage != "" { - data := templateData{ - Providers: a.providers, - HasPassword: len(a.passwordHash) > 0, - PasswordError: errorMessage, - } - - c.Header("Content-Type", "text/html; charset=utf-8") - c.Status(http.StatusBadRequest) - if err := a.template.Execute(c.Writer, data); err != nil { - c.AbortWithError(http.StatusInternalServerError, err) - return - } + a.renderLogin(c, errorMessage) return } session := sessions.Default(c) session.Set(SessionKeyProvider, PasswordProvider) session.Set(SessionKeyUserID, PasswordUserID) + redirectURL := session.Get(SessionKeyRedirectURL) + if redirectURL != nil { + session.Delete(SessionKeyRedirectURL) + } session.Save() - redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL == nil { c.Redirect(http.StatusFound, "/") - return + } else { + c.Redirect(http.StatusFound, redirectURL.(string)) } - c.Redirect(http.StatusFound, redirectURL.(string)) } func (a *AuthRouter) handleLogout(c *gin.Context) { session := sessions.Default(c) - session.Clear() + session.Delete(SessionKeyProvider) + session.Delete(SessionKeyUserID) session.Save() - c.String(http.StatusOK, "Logged out") + c.Redirect(http.StatusFound, LoginEndpoint) } func (a *AuthRouter) RequireAuth() gin.HandlerFunc { @@ -229,22 +222,52 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc { return } if !ok { - data := struct { - UserID string - Provider string - }{ - UserID: userID.(string), - Provider: providerName.(string), - } - c.Header("Content-Type", "text/html; charset=utf-8") - c.Status(http.StatusForbidden) - if err := a.unauthorizedTemplate.Execute(c.Writer, data); err != nil { - c.AbortWithError(http.StatusInternalServerError, err) - return - } + a.renderUnauthorized(c, userID.(string), providerName.(string)) c.Abort() return } c.Next() } } + +type loginTemplateData struct { + Providers []Provider + HasPassword bool + PasswordError string +} + +type unauthorizedTemplateData struct { + UserID string + Provider string +} + +func (a *AuthRouter) renderLogin(c *gin.Context, passwordError string) { + data := loginTemplateData{ + Providers: a.providers, + HasPassword: len(a.passwordHash) > 0, + PasswordError: passwordError, + } + c.Header("Content-Type", "text/html; charset=utf-8") + if passwordError != "" { + c.Status(http.StatusBadRequest) + } else { + c.Status(http.StatusOK) + } + if err := a.loginTemplate.Execute(c.Writer, data); err != nil { + c.AbortWithError(http.StatusInternalServerError, err) + return + } +} + +func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName string) { + data := unauthorizedTemplateData{ + UserID: userID, + Provider: providerName, + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.Status(http.StatusForbidden) + if err := a.unauthorizedTemplate.Execute(c.Writer, data); err != nil { + c.AbortWithError(http.StatusInternalServerError, err) + return + } +} diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 2d7d870..d962cf5 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -79,10 +79,15 @@ func TestAuthenticationFlow(t *testing.T) { defer ctrl.Finish() // Create mock provider + mockToken := &oauth2.Token{AccessToken: "test-token"} mockProvider := NewMockProvider(ctrl) mockProvider.EXPECT().Name().Return("test").AnyTimes() mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil) + mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) + mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil) + mockProvider.EXPECT().Authorization("test-user").Return(true, nil).AnyTimes() // Create AuthRouter authRouter, err := NewAuthRouter(nil, mockProvider) @@ -103,8 +108,6 @@ func TestAuthenticationFlow(t *testing.T) { require.Equal(t, http.StatusFound, resp.StatusCode) // Step 2: Start authentication - mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil) - resp, err = client.Get(server.URL + "/.auth/test") require.NoError(t, err) resp.Body.Close() @@ -115,23 +118,15 @@ func TestAuthenticationFlow(t *testing.T) { require.Equal(t, "https://example.com/oauth", location) // Step 3: Handle callback - mockToken := &oauth2.Token{AccessToken: "test-token"} - mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil) - resp, err = client.Get(server.URL + "/.auth/test/callback") require.NoError(t, err) resp.Body.Close() require.Equal(t, http.StatusFound, resp.StatusCode) - - // Verify redirect to root location = resp.Header.Get("Location") require.Equal(t, "/", location) // Step 4: Access after authentication - mockProvider.EXPECT().Authorization("test-user").Return(true, nil) - resp, err = client.Get(server.URL + "/") if err != nil { t.Fatalf("Request failed: %v", err) @@ -146,10 +141,15 @@ func TestAuthenticationFlow(t *testing.T) { defer ctrl.Finish() // Create mock provider + mockToken := &oauth2.Token{AccessToken: "test-token"} mockProvider := NewMockProvider(ctrl) mockProvider.EXPECT().Name().Return("test").AnyTimes() mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() + mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil) + mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) + mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil) + mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil).AnyTimes() // Create AuthRouter authRouter, err := NewAuthRouter(nil, mockProvider) @@ -167,30 +167,26 @@ func TestAuthenticationFlow(t *testing.T) { resp.Body.Close() // Step 2: Start authentication - mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil) - resp, err = client.Get(server.URL + "/.auth/test") require.NoError(t, err) resp.Body.Close() // Step 3: Complete authentication - mockToken := &oauth2.Token{AccessToken: "test-token"} - mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil) - resp, err = client.Get(server.URL + "/.auth/test/callback") require.NoError(t, err) resp.Body.Close() - // Step 4: Test access when authorization fails - mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + // Step 4: Test access when authorization fails resp, err = client.Get(server.URL + "/") if err != nil { t.Fatalf("Request failed: %v", err) } defer resp.Body.Close() - require.Equal(t, http.StatusForbidden, resp.StatusCode) + require.Equal(t, http.StatusFound, resp.StatusCode) + location := resp.Header.Get("Location") + require.Equal(t, "/.auth/login", location) }) }