Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type AuthRouter struct {
providers []Provider
loginTemplate *template.Template
unauthorizedTemplate *template.Template
errorTemplate *template.Template
}

func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, error) {
Expand All @@ -33,11 +34,17 @@ func NewAuthRouter(passwordHash []string, providers ...Provider) (*AuthRouter, e
return nil, err
}

errorTmpl, err := template.ParseFS(templateFS, "templates/error.html")
if err != nil {
return nil, err
}

return &AuthRouter{
passwordHash: passwordHash,
providers: providers,
loginTemplate: tmpl,
unauthorizedTemplate: unauthorizedTmpl,
errorTemplate: errorTmpl,
}, nil
}

Expand Down Expand Up @@ -69,22 +76,22 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
session := sessions.Default(c)
state := session.Get(SessionKeyOAuthState)
if state == nil {
c.Error(errors.New("OAuth state is missing"))
a.renderError(c, errors.New("OAuth state is missing"))
return
}
token, err := provider.Exchange(c, state.(string))
if err != nil {
c.Error(err)
a.renderError(c, err)
return
}
userID, err := provider.GetUserID(c, token)
if err != nil {
c.Error(err)
a.renderError(c, err)
return
}
ok, err := provider.Authorization(userID)
if err != nil {
c.Error(err)
a.renderError(c, err)
return
}
if !ok {
Expand All @@ -97,7 +104,10 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
if redirectURL != nil {
session.Delete(SessionKeyRedirectURL)
}
session.Save()
if err := session.Save(); err != nil {
a.renderError(c, err)
return
}

if redirectURL == nil {
c.Redirect(http.StatusFound, "/")
Expand All @@ -111,16 +121,19 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {

state, err := utils.GenerateState()
if err != nil {
c.Error(err)
a.renderError(c, err)
return
}
url, err := provider.AuthCodeURL(c, state)
if err != nil {
c.Error(err)
a.renderError(c, err)
return
}
session.Set(SessionKeyOAuthState, state)
session.Save()
if err := session.Save(); err != nil {
a.renderError(c, err)
return
}
c.Redirect(http.StatusFound, url)
})
}
Expand Down Expand Up @@ -176,7 +189,10 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
if redirectURL != nil {
session.Delete(SessionKeyRedirectURL)
}
session.Save()
if err := session.Save(); err != nil {
a.renderError(c, err)
return
}

if redirectURL == nil {
c.Redirect(http.StatusFound, "/")
Expand All @@ -189,7 +205,10 @@ func (a *AuthRouter) handleLogout(c *gin.Context) {
session := sessions.Default(c)
session.Delete(SessionKeyProvider)
session.Delete(SessionKeyUserID)
session.Save()
if err := session.Save(); err != nil {
a.renderError(c, err)
return
}
c.Redirect(http.StatusFound, LoginEndpoint)
}

Expand All @@ -200,7 +219,10 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
userID := session.Get(SessionKeyUserID)
if providerName == nil || userID == nil {
session.Set(SessionKeyRedirectURL, c.Request.URL.String())
session.Save()
if err := session.Save(); err != nil {
a.renderError(c, err)
return
}
c.Redirect(http.StatusFound, LoginEndpoint)
return
}
Expand Down Expand Up @@ -241,6 +263,10 @@ type unauthorizedTemplateData struct {
Provider string
}

type errorTemplateData struct {
ErrorMessage string
}

func (a *AuthRouter) renderLogin(c *gin.Context, passwordError string) {
data := loginTemplateData{
Providers: a.providers,
Expand Down Expand Up @@ -271,3 +297,16 @@ func (a *AuthRouter) renderUnauthorized(c *gin.Context, userID, providerName str
return
}
}

func (a *AuthRouter) renderError(c *gin.Context, err error) {
data := errorTemplateData{
ErrorMessage: err.Error(),
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.Status(http.StatusInternalServerError)
if templateErr := a.errorTemplate.Execute(c.Writer, data); templateErr != nil {
c.AbortWithError(http.StatusInternalServerError, templateErr)
Copy link

Copilot AI Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using c.AbortWithError() after already setting the status and writing to the response can cause issues. Since the template execution failed, the response may be partially written. Consider using c.String(http.StatusInternalServerError, "Internal Server Error") instead to ensure a clean error response.

Suggested change
c.AbortWithError(http.StatusInternalServerError, templateErr)
c.String(http.StatusInternalServerError, "Internal Server Error")

Copilot uses AI. Check for mistakes.
return
}
c.Abort()
Copy link

Copilot AI Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The c.Abort() call is unnecessary here since the response has already been completed. The function should return after successfully executing the template. This call could interfere with the response that was just sent.

Suggested change
c.Abort()

Copilot uses AI. Check for mistakes.
}
73 changes: 73 additions & 0 deletions pkg/auth/templates/error.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Error - MCP Auth Proxy</title>
<style>
body {
font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
margin: 0;
padding: 0;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}
.error-container {
background: white;
padding: 2.5rem;
border-radius: 12px;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.12);
text-align: center;
min-width: 350px;
max-width: 500px;
width: 100%;
}
.error-icon {
font-size: 4rem;
color: #e74c3c;
margin-bottom: 1rem;
}
h1 {
color: #333;
margin-bottom: 1.5rem;
font-size: 2rem;
font-weight: 600;
}
.error-message {
color: #666;
margin-bottom: 2rem;
font-size: 1.1rem;
line-height: 1.6;
}
.error-details {
background: #f8f9fa;
border: 1px solid #e9ecef;
border-radius: 8px;
padding: 1rem;
margin-bottom: 2rem;
font-family: monospace;
color: #e74c3c;
font-size: 0.9rem;
text-align: left;
word-wrap: break-word;
}
</style>
</head>
<body>
<div class="error-container">
<div class="error-icon">⚠️</div>
<h1>Error</h1>
<div class="error-message">
An error occurred while processing your request.
</div>
{{if .ErrorMessage}}
<div class="error-details">
{{.ErrorMessage}}
</div>
{{end}}
</div>
</body>
</html>
4 changes: 0 additions & 4 deletions pkg/mcp-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,6 @@ func Run(
router.Use(ginzap.Ginzap(logger, time.RFC3339, true))
router.Use(ginzap.RecoveryWithZap(logger, true))
store := cookie.NewStore(secret)
store.Options(sessions.Options{
MaxAge: 3600,
HttpOnly: true,
})
router.Use(sessions.Sessions("session", store))
authRouter.SetupRoutes(router)
idpRouter.SetupRoutes(router)
Expand Down