From efc795d4156f85689ff50815b5ed92e57dc3e1d6 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Sat, 23 Aug 2025 18:54:31 +0000 Subject: [PATCH 1/3] feat: enhance OAuth providers with organization and workspace support - Add GitHub organization and team-based authorization - Add Google Workspace domain-based authorization - Consolidate authentication flow by combining user retrieval and authorization - Add comprehensive test coverage for OAuth providers - Add utilities for better error handling - Improve session management with proper cookie settings BREAKING CHANGE: Authorization interface changed from separate GetUserID/Authorization calls to combined Authorization method --- main.go | 22 ++++ pkg/auth/auth.go | 50 ++------ pkg/auth/auth_test.go | 11 +- pkg/auth/github.go | 107 ++++++++++++++--- pkg/auth/github_test.go | 176 ++++++++++++++++++++++++++++ pkg/auth/google.go | 59 ++++++---- pkg/auth/google_test.go | 191 +++++++++++++++++++++++++++++++ pkg/auth/interface.go | 5 +- pkg/auth/main_test.go | 13 +++ pkg/auth/mock.go | 36 ++---- pkg/auth/oidc.go | 28 +++-- pkg/auth/oidc_test.go | 248 ++++++++++++++++++++++++++++++++++++++++ pkg/idp/idp_test.go | 3 +- pkg/mcp-proxy/main.go | 13 ++- pkg/utils/must.go | 8 ++ 15 files changed, 839 insertions(+), 131 deletions(-) create mode 100644 pkg/auth/github_test.go create mode 100644 pkg/auth/google_test.go create mode 100644 pkg/auth/main_test.go create mode 100644 pkg/auth/oidc_test.go create mode 100644 pkg/utils/must.go diff --git a/main.go b/main.go index b0ef87e..65fb1b1 100644 --- a/main.go +++ b/main.go @@ -37,9 +37,11 @@ func main() { var googleClientID string var googleClientSecret string var googleAllowedUsers string + var googleAllowedWorkspaces string var githubClientID string var githubClientSecret string var githubAllowedUsers string + var githubAllowedOrgs string var oidcConfigurationURL string var oidcClientID string var oidcClientSecret string @@ -63,6 +65,14 @@ func main() { } } + var googleAllowedWorkspacesList []string + if googleAllowedWorkspaces != "" { + googleAllowedWorkspacesList = strings.Split(googleAllowedWorkspaces, ",") + for i := range googleAllowedWorkspacesList { + googleAllowedWorkspacesList[i] = strings.TrimSpace(googleAllowedWorkspacesList[i]) + } + } + var githubAllowedUsersList []string if githubAllowedUsers != "" { githubAllowedUsersList = strings.Split(githubAllowedUsers, ",") @@ -71,6 +81,14 @@ func main() { } } + var githubAllowedOrgsList []string + if githubAllowedOrgs != "" { + githubAllowedOrgsList = strings.Split(githubAllowedOrgs, ",") + for i := range githubAllowedOrgsList { + githubAllowedOrgsList[i] = strings.TrimSpace(githubAllowedOrgsList[i]) + } + } + var oidcAllowedUsersList []string if oidcAllowedUsers != "" { oidcAllowedUsersList = strings.Split(oidcAllowedUsers, ",") @@ -110,9 +128,11 @@ func main() { googleClientID, googleClientSecret, googleAllowedUsersList, + googleAllowedWorkspacesList, githubClientID, githubClientSecret, githubAllowedUsersList, + githubAllowedOrgsList, oidcConfigurationURL, oidcClientID, oidcClientSecret, @@ -144,11 +164,13 @@ func main() { rootCmd.Flags().StringVar(&googleClientID, "google-client-id", getEnvWithDefault("GOOGLE_CLIENT_ID", ""), "Google OAuth client ID") rootCmd.Flags().StringVar(&googleClientSecret, "google-client-secret", getEnvWithDefault("GOOGLE_CLIENT_SECRET", ""), "Google OAuth client secret") rootCmd.Flags().StringVar(&googleAllowedUsers, "google-allowed-users", getEnvWithDefault("GOOGLE_ALLOWED_USERS", ""), "Comma-separated list of allowed Google users (emails)") + rootCmd.Flags().StringVar(&googleAllowedWorkspaces, "google-allowed-workspaces", getEnvWithDefault("GOOGLE_ALLOWED_WORKSPACES", ""), "Comma-separated list of allowed Google workspaces") // GitHub OAuth configuration rootCmd.Flags().StringVar(&githubClientID, "github-client-id", getEnvWithDefault("GITHUB_CLIENT_ID", ""), "GitHub OAuth client ID") rootCmd.Flags().StringVar(&githubClientSecret, "github-client-secret", getEnvWithDefault("GITHUB_CLIENT_SECRET", ""), "GitHub OAuth client secret") rootCmd.Flags().StringVar(&githubAllowedUsers, "github-allowed-users", getEnvWithDefault("GITHUB_ALLOWED_USERS", ""), "Comma-separated list of allowed GitHub users (usernames)") + rootCmd.Flags().StringVar(&githubAllowedOrgs, "github-allowed-orgs", getEnvWithDefault("GITHUB_ALLOWED_ORGS", ""), "Comma-separated list of allowed GitHub organizations. You can also restrict access to specific teams using the format `Org:Team`") // OIDC configuration rootCmd.Flags().StringVar(&oidcConfigurationURL, "oidc-configuration-url", getEnvWithDefault("OIDC_CONFIGURATION_URL", ""), "OIDC configuration URL") diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 07c0a3a..9fbb2b7 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -61,8 +61,7 @@ const ( PasswordProvider = "password" PasswordUserID = "password_user" - SessionKeyProvider = "provider" - SessionKeyUserID = "user_id" + SessionKeyAuthorized = "authorized" SessionKeyRedirectURL = "redirect_url" SessionKeyOAuthState = "oauth_state" ) @@ -84,22 +83,16 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { a.renderError(c, err) return } - userID, err := provider.GetUserID(c, token) - if err != nil { - a.renderError(c, err) - return - } - ok, err := provider.Authorization(userID) + ok, user, err := provider.Authorization(c, token) if err != nil { a.renderError(c, err) return } if !ok { - a.renderUnauthorized(c, userID, provider.Name()) + a.renderUnauthorized(c, user, provider.Name()) return } - session.Set(SessionKeyProvider, provider.Name()) - session.Set(SessionKeyUserID, userID) + session.Set(SessionKeyAuthorized, true) redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL != nil { session.Delete(SessionKeyRedirectURL) @@ -124,7 +117,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { a.renderError(c, err) return } - url, err := provider.AuthCodeURL(c, state) + url, err := provider.AuthCodeURL(state) if err != nil { a.renderError(c, err) return @@ -183,8 +176,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { } session := sessions.Default(c) - session.Set(SessionKeyProvider, PasswordProvider) - session.Set(SessionKeyUserID, PasswordUserID) + session.Set(SessionKeyAuthorized, true) redirectURL := session.Get(SessionKeyRedirectURL) if redirectURL != nil { session.Delete(SessionKeyRedirectURL) @@ -203,8 +195,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) { func (a *AuthRouter) handleLogout(c *gin.Context) { session := sessions.Default(c) - session.Delete(SessionKeyProvider) - session.Delete(SessionKeyUserID) + session.Delete(SessionKeyAuthorized) if err := session.Save(); err != nil { a.renderError(c, err) return @@ -215,9 +206,8 @@ func (a *AuthRouter) handleLogout(c *gin.Context) { func (a *AuthRouter) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { session := sessions.Default(c) - providerName := session.Get(SessionKeyProvider) - userID := session.Get(SessionKeyUserID) - if providerName == nil || userID == nil { + authorized := session.Get(SessionKeyAuthorized) + if authorized == nil { session.Set(SessionKeyRedirectURL, c.Request.URL.String()) if err := session.Save(); err != nil { a.renderError(c, err) @@ -227,25 +217,9 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc { return } - // Allow password authentication - if providerName.(string) == PasswordProvider { - c.Next() - return - } - - p := a.getProvider(providerName.(string)) - if p == nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unknown provider"}) - return - } - ok, err := p.Authorization(userID.(string)) - if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authorization failed"}) - return - } - if !ok { - a.renderUnauthorized(c, userID.(string), providerName.(string)) - c.Abort() + if !authorized.(bool) { + // not expected + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) return } c.Next() diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index d962cf5..48b8b49 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -15,7 +15,6 @@ import ( ) func setupTestRouter(authRouter *AuthRouter) *gin.Engine { - gin.SetMode(gin.TestMode) router := gin.New() // Setup session middleware @@ -84,10 +83,9 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().Name().Return("test").AnyTimes() mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() - mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil) + mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil) - mockProvider.EXPECT().Authorization("test-user").Return(true, nil).AnyTimes() + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil) // Create AuthRouter authRouter, err := NewAuthRouter(nil, mockProvider) @@ -146,10 +144,9 @@ func TestAuthenticationFlow(t *testing.T) { mockProvider.EXPECT().Name().Return("test").AnyTimes() mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes() mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes() - mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil) + mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil) mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil) - mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil) - mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil).AnyTimes() + mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil) // Create AuthRouter authRouter, err := NewAuthRouter(nil, mockProvider) diff --git a/pkg/auth/github.go b/pkg/auth/github.go index d71f0fa..f146f38 100644 --- a/pkg/auth/github.go +++ b/pkg/auth/github.go @@ -5,34 +5,53 @@ import ( "encoding/json" "errors" "net/url" + "slices" + "strings" "github.com/gin-gonic/gin" + "github.com/sigbit/mcp-auth-proxy/pkg/utils" "golang.org/x/oauth2" "golang.org/x/oauth2/github" ) type githubProvider struct { + endpoint string oauth2 oauth2.Config allowedUsers []string + allowedOrgs []string } -func NewGithubProvider(clientID, clientSecret, externalURL string, allowedUsers []string) (Provider, error) { +func NewGithubProvider(clientID, clientSecret, externalURL string, allowedUsers []string, allowedOrgs []string) (Provider, error) { r, err := url.JoinPath(externalURL, GitHubCallbackEndpoint) if err != nil { return nil, err } + scopes := []string{} + if len(allowedOrgs) > 0 { + scopes = append(scopes, "read:org") + } return &githubProvider{ + endpoint: "https://api.github.com", oauth2: oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, RedirectURL: r, - Scopes: []string{"user:email"}, + Scopes: scopes, Endpoint: github.Endpoint, }, allowedUsers: allowedUsers, + allowedOrgs: allowedOrgs, }, nil } +func (p *githubProvider) SetApiEndpoint(u string) { + p.endpoint = u +} + +func (p *githubProvider) SetOAuth2Endpoint(cfg oauth2.Endpoint) { + p.oauth2.Endpoint = cfg +} + func (p *githubProvider) Name() string { return "GitHub" } @@ -49,7 +68,7 @@ func (p *githubProvider) AuthURL() string { return GitHubAuthEndpoint } -func (p *githubProvider) AuthCodeURL(c *gin.Context, state string) (string, error) { +func (p *githubProvider) AuthCodeURL(state string) (string, error) { authURL := p.oauth2.AuthCodeURL(state) return authURL, nil } @@ -66,37 +85,87 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, return token, nil } -func (p *githubProvider) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) { +func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { client := p.oauth2.Client(ctx, token) - resp, err := client.Get("https://api.github.com/user") + resp, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user"))) if err != nil { - return "", err + return false, "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from GitHub API: " + resp.Status) } defer resp.Body.Close() var userInfo struct { - ID uint64 `json:"id"` Login string `json:"login"` - Name string `json:"name"` - Email string `json:"email"` } if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return "", err + return false, "", err } - return userInfo.Login, nil -} + if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 { + return true, userInfo.Login, nil + } -func (p *githubProvider) Authorization(userid string) (bool, error) { - if len(p.allowedUsers) == 0 { - return true, nil + if slices.Contains(p.allowedUsers, userInfo.Login) { + return true, userInfo.Login, nil + } + + allowedOrgTeams := []string{} + allowedOrgs := []string{} + for _, allowedOrg := range p.allowedOrgs { + if strings.Contains(allowedOrg, ":") { + allowedOrgTeams = append(allowedOrgTeams, allowedOrg) + } else { + allowedOrgs = append(allowedOrgs, allowedOrg) + } } - for _, allowedUser := range p.allowedUsers { - if allowedUser == userid { - return true, nil + if len(allowedOrgs) > 0 { + resp, err = client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs"))) + if err != nil { + return false, "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from GitHub API: " + resp.Status) + } + defer resp.Body.Close() + var orgInfo []struct { + Login string `json:"login"` + } + if err := json.NewDecoder(resp.Body).Decode(&orgInfo); err != nil { + return false, "", err + } + for _, o := range orgInfo { + if slices.Contains(allowedOrgs, o.Login) { + return true, userInfo.Login, nil + } + } + } + if len(allowedOrgTeams) > 0 { + resp, err = client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams"))) + if err != nil { + return false, "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from GitHub API: " + resp.Status) + } + defer resp.Body.Close() + var teamInfo []struct { + Organization struct { + Login string `json:"login"` + } `json:"organization"` + Slug string `json:"slug"` + } + if err := json.NewDecoder(resp.Body).Decode(&teamInfo); err != nil { + return false, "", err + } + for _, team := range teamInfo { + if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) { + return true, userInfo.Login, nil + } } } - return false, nil + return false, userInfo.Login, nil } diff --git a/pkg/auth/github_test.go b/pkg/auth/github_test.go new file mode 100644 index 0000000..f3e0eeb --- /dev/null +++ b/pkg/auth/github_test.go @@ -0,0 +1,176 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +const ( + TestGitHubClientID = "test-client-id" + TestGitHubClientSecret = "test-client-secret" + TestGitHubExternalURL = "http://localhost:8080" +) + +func setupGitHubTest(allowedUsers, allowedOrgs []string) (Provider, gin.IRoutes) { + p, _ := NewGithubProvider(TestGitHubClientID, TestGitHubClientSecret, TestGitHubExternalURL, allowedUsers, allowedOrgs) + + gh := gin.New() + gh.POST("/login/oauth/access_token", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "access_token": "test-access-token", + }) + }) + tsgh := httptest.NewServer(gh) + gp := p.(*githubProvider) + gp.SetOAuth2Endpoint(oauth2.Endpoint{ + AuthURL: tsgh.URL + "/login/oauth/authorize", + TokenURL: tsgh.URL + "/login/oauth/access_token", + DeviceAuthURL: tsgh.URL + "/login/device/code", + }) + + ghapi := gin.New() + tsghapi := httptest.NewServer(ghapi) + gp.SetApiEndpoint(tsghapi.URL) + + return p, ghapi +} + +func TestGitHubProvider(t *testing.T) { + p, _ := setupGitHubTest([]string{}, []string{}) + require.Equal(t, p.Name(), "GitHub") + require.Equal(t, p.Type(), "github") + require.Equal(t, p.RedirectURL(), GitHubCallbackEndpoint) + require.Equal(t, p.AuthURL(), GitHubAuthEndpoint) + + // check AuthCodeURL + authCodeURL, err := p.AuthCodeURL("test-state") + require.NoError(t, err) + require.NotEmpty(t, authCodeURL) + authCodeURLObj, err := url.Parse(authCodeURL) + require.NoError(t, err) + require.Equal(t, authCodeURLObj.Path, "/login/oauth/authorize") + require.Equal(t, authCodeURLObj.Query().Get("client_id"), TestGitHubClientID) + require.Equal(t, authCodeURLObj.Query().Get("redirect_uri"), TestGitHubExternalURL+"/.auth/github/callback") + require.Equal(t, authCodeURLObj.Query().Get("response_type"), "code") + require.Equal(t, authCodeURLObj.Query().Get("state"), "test-state") + + // Check Exchange + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest("GET", "/?state=test-state&code=test-code", nil) + c.Request = req + _, err = p.Exchange(c, "invalid-state") + require.Error(t, err) + token, err := p.Exchange(c, "test-state") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, token.AccessToken, "test-access-token") +} + +func TestGitHubProviderAuthorization(t *testing.T) { + tc := []struct { + name string + allowedUsers []string + allowedOrgs []string + userResp string + teamsResp string + orgsResp string + expect bool + }{ + { + name: "allow all users", + allowedUsers: []string{}, + allowedOrgs: []string{}, + userResp: `{"login": "user1"}`, + orgsResp: `[]`, + teamsResp: `[]`, + expect: true, + }, + { + name: "allow single user", + allowedUsers: []string{"user1", "user2"}, + allowedOrgs: []string{}, + userResp: `{"login": "user1"}`, + orgsResp: `[]`, + teamsResp: `[]`, + expect: true, + }, + { + name: "deny single user", + allowedUsers: []string{"user1"}, + allowedOrgs: []string{}, + userResp: `{"login": "user2"}`, + orgsResp: `[]`, + teamsResp: `[]`, + expect: false, + }, + { + name: "allow by org", + allowedUsers: []string{}, + allowedOrgs: []string{"org1"}, + userResp: `{"login": "user1"}`, + orgsResp: `[{"login": "org1"}]`, + teamsResp: `[]`, + expect: true, + }, + { + name: "deny by org", + allowedUsers: []string{}, + allowedOrgs: []string{"org1"}, + userResp: `{"login": "user1"}`, + orgsResp: `[{"login": "org2"}]`, + teamsResp: `[]`, + expect: false, + }, + { + name: "allow by team", + allowedUsers: []string{}, + allowedOrgs: []string{"org1:team1"}, + userResp: `{"login": "user1"}`, + orgsResp: `[]`, + teamsResp: `[{"organization": {"login": "org1"}, "slug": "team1"}]`, + expect: true, + }, + { + name: "deny by team", + allowedUsers: []string{}, + allowedOrgs: []string{"org1:team1"}, + userResp: `{"login": "user1"}`, + orgsResp: `[]`, + teamsResp: `[{"organization": {"login": "org1"}, "slug": "team2"}]`, + expect: false, + }, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + p, ghapi := setupGitHubTest(tt.allowedUsers, tt.allowedOrgs) + userResp := tt.userResp + orgsResp := tt.orgsResp + teamsResp := tt.teamsResp + expect := tt.expect + + ghapi.GET("/user", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(userResp)) + }) + ghapi.GET("/user/orgs", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(orgsResp)) + }) + ghapi.GET("/user/teams", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(teamsResp)) + }) + + // Call the Authorization method + ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.NoError(t, err) + require.Equal(t, expect, ok) + }) + } +} diff --git a/pkg/auth/google.go b/pkg/auth/google.go index 1fd223f..ccc46cf 100644 --- a/pkg/auth/google.go +++ b/pkg/auth/google.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "net/url" + "slices" "github.com/gin-gonic/gin" "golang.org/x/oauth2" @@ -12,16 +13,19 @@ import ( ) type googleProvider struct { - oauth2 oauth2.Config - allowedUsers []string + userinfoEndpoint string + oauth2 oauth2.Config + allowedUsers []string + allowedWorkspaces []string } -func NewGoogleProvider(externalURL, clientID, clientSecret string, allowedUsers []string) (Provider, error) { +func NewGoogleProvider(externalURL, clientID, clientSecret string, allowedUsers []string, allowedWorkspaces []string) (Provider, error) { r, err := url.JoinPath(externalURL, GoogleCallbackEndpoint) if err != nil { return nil, err } return &googleProvider{ + userinfoEndpoint: "https://openidconnect.googleapis.com/v1/userinfo", oauth2: oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, @@ -29,10 +33,19 @@ func NewGoogleProvider(externalURL, clientID, clientSecret string, allowedUsers Scopes: []string{"openid profile email"}, Endpoint: google.Endpoint, }, - allowedUsers: allowedUsers, + allowedUsers: allowedUsers, + allowedWorkspaces: allowedWorkspaces, }, nil } +func (p *googleProvider) SetUserinfoEndpoint(u string) { + p.userinfoEndpoint = u +} + +func (p *googleProvider) SetOAuth2Endpoint(cfg oauth2.Endpoint) { + p.oauth2.Endpoint = cfg +} + func (p *googleProvider) Name() string { return "Google" } @@ -45,8 +58,12 @@ func (p *googleProvider) RedirectURL() string { return GoogleCallbackEndpoint } -func (p *googleProvider) AuthCodeURL(c *gin.Context, state string) (string, error) { - authURL := p.oauth2.AuthCodeURL(state) +func (p *googleProvider) AuthCodeURL(state string) (string, error) { + opts := []oauth2.AuthCodeOption{} + if len(p.allowedUsers) == 0 && len(p.allowedWorkspaces) == 1 { + opts = append(opts, oauth2.SetAuthURLParam("hd", p.allowedWorkspaces[0])) + } + authURL := p.oauth2.AuthCodeURL(state, opts...) return authURL, nil } @@ -66,11 +83,14 @@ func (p *googleProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, return token, nil } -func (p *googleProvider) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) { +func (p *googleProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { client := p.oauth2.Client(ctx, token) - resp, err := client.Get("https://openidconnect.googleapis.com/v1/userinfo") + resp, err := client.Get(p.userinfoEndpoint) if err != nil { - return "", err + return false, "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from Google API: " + resp.Status) } defer resp.Body.Close() @@ -78,24 +98,23 @@ func (p *googleProvider) GetUserID(ctx context.Context, token *oauth2.Token) (st Sub string `json:"sub"` Name string `json:"name"` Email string `json:"email"` + HD string `json:"hd"` } if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return "", err + return false, "", err } - return userInfo.Email, nil -} + if len(p.allowedUsers) == 0 && len(p.allowedWorkspaces) == 0 { + return true, userInfo.Email, nil + } -func (p *googleProvider) Authorization(userid string) (bool, error) { - if len(p.allowedUsers) == 0 { - return true, nil + if slices.Contains(p.allowedUsers, userInfo.Email) { + return true, userInfo.Email, nil } - for _, allowedUser := range p.allowedUsers { - if allowedUser == userid { - return true, nil - } + if slices.Contains(p.allowedWorkspaces, userInfo.HD) { + return true, userInfo.Email, nil } - return false, nil + return false, userInfo.Email, nil } diff --git a/pkg/auth/google_test.go b/pkg/auth/google_test.go new file mode 100644 index 0000000..ae792ae --- /dev/null +++ b/pkg/auth/google_test.go @@ -0,0 +1,191 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +const ( + TestGoogleClientID = "test-client-id" + TestGoogleClientSecret = "test-client-secret" + TestGoogleExternalURL = "http://localhost:8080" +) + +func setupGoogleTest(allowedUsers, allowedWorkspaces []string) (Provider, gin.IRoutes) { + p, _ := NewGoogleProvider(TestGoogleExternalURL, TestGoogleClientID, TestGoogleClientSecret, allowedUsers, allowedWorkspaces) + + goog := gin.New() + goog.POST("/token", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "access_token": "test-access-token", + }) + }) + tsgoog := httptest.NewServer(goog) + gp := p.(*googleProvider) + gp.SetOAuth2Endpoint(oauth2.Endpoint{ + AuthURL: tsgoog.URL + "/auth", + TokenURL: tsgoog.URL + "/token", + }) + + googapi := gin.New() + tsgoogapi := httptest.NewServer(googapi) + gp.SetUserinfoEndpoint(tsgoogapi.URL + "/userinfo") + + return p, googapi +} + +func TestGoogleProvider(t *testing.T) { + p, _ := setupGoogleTest([]string{}, []string{}) + require.Equal(t, p.Name(), "Google") + require.Equal(t, p.Type(), "google") + require.Equal(t, p.RedirectURL(), GoogleCallbackEndpoint) + require.Equal(t, p.AuthURL(), GoogleAuthEndpoint) + + authCodeURL, err := p.AuthCodeURL("test-state") + require.NoError(t, err) + require.NotEmpty(t, authCodeURL) + authCodeURLObj, err := url.Parse(authCodeURL) + require.NoError(t, err) + require.Equal(t, authCodeURLObj.Path, "/auth") + require.Equal(t, authCodeURLObj.Query().Get("client_id"), TestGoogleClientID) + require.Equal(t, authCodeURLObj.Query().Get("redirect_uri"), TestGoogleExternalURL+"/.auth/google/callback") + require.Equal(t, authCodeURLObj.Query().Get("response_type"), "code") + require.Equal(t, authCodeURLObj.Query().Get("state"), "test-state") + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest("GET", "/?state=test-state&code=test-code", nil) + c.Request = req + _, err = p.Exchange(c, "invalid-state") + require.Error(t, err) + token, err := p.Exchange(c, "test-state") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, token.AccessToken, "test-access-token") +} + +func TestGoogleProviderWithWorkspace(t *testing.T) { + p, _ := setupGoogleTest([]string{}, []string{"example.com"}) + + authCodeURL, err := p.AuthCodeURL("test-state") + require.NoError(t, err) + require.NotEmpty(t, authCodeURL) + authCodeURLObj, err := url.Parse(authCodeURL) + require.NoError(t, err) + require.Equal(t, authCodeURLObj.Query().Get("hd"), "example.com") +} + +func TestGoogleProviderAuthorization(t *testing.T) { + tc := []struct { + name string + allowedUsers []string + allowedWorkspaces []string + userResp string + expect bool + }{ + { + name: "allow all users", + allowedUsers: []string{}, + allowedWorkspaces: []string{}, + userResp: `{"sub": "12345", "name": "Test User", "email": "user1@gmail.com"}`, + expect: true, + }, + { + name: "allow single user", + allowedUsers: []string{"user1@gmail.com", "user2@gmail.com"}, + allowedWorkspaces: []string{}, + userResp: `{"sub": "12345", "name": "Test User", "email": "user1@gmail.com"}`, + expect: true, + }, + { + name: "deny single user", + allowedUsers: []string{"user1@gmail.com"}, + allowedWorkspaces: []string{}, + userResp: `{"sub": "12345", "name": "Test User", "email": "user2@gmail.com"}`, + expect: false, + }, + { + name: "allow by workspace", + allowedUsers: []string{}, + allowedWorkspaces: []string{"example.com"}, + userResp: `{"sub": "12345", "name": "Test User", "email": "user1@example.com", "hd": "example.com"}`, + expect: true, + }, + { + name: "deny by workspace", + allowedUsers: []string{}, + allowedWorkspaces: []string{"example.com"}, + userResp: `{"sub": "12345", "name": "Test User", "email": "user1@gmail.com"}`, + expect: false, + }, + { + name: "deny by other workspace", + allowedUsers: []string{}, + allowedWorkspaces: []string{"example.com"}, + userResp: `{"sub": "12345", "name": "Test User", "email": "test@other.com", "hd": "other.com"}`, + expect: false, + }, + { + name: "allow user without workspace domain", + allowedUsers: []string{}, + allowedWorkspaces: []string{}, + userResp: `{"sub": "12345", "name": "Test User", "email": "test@gmail.com"}`, + expect: true, + }, + { + name: "allow specific user with workspace", + allowedUsers: []string{"test@example.com"}, + allowedWorkspaces: []string{"other.com"}, + userResp: `{"sub": "12345", "name": "Test User", "email": "test@example.com", "hd": "example.com"}`, + expect: true, + }, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + p, googapi := setupGoogleTest(tt.allowedUsers, tt.allowedWorkspaces) + + googapi.GET("/userinfo", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(tt.userResp)) + }) + + ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.NoError(t, err) + require.Equal(t, tt.expect, ok) + }) + } +} + +func TestGoogleProviderAuthorizationAPIError(t *testing.T) { + p, googapi := setupGoogleTest([]string{}, []string{}) + + googapi.GET("/userinfo", func(c *gin.Context) { + c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"}) + }) + + ok, user, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.Error(t, err) + require.False(t, ok) + require.Empty(t, user) + require.Contains(t, err.Error(), "failed to get user info from Google API") +} + +func TestGoogleProviderAuthorizationInvalidJSON(t *testing.T) { + p, googapi := setupGoogleTest([]string{}, []string{}) + + googapi.GET("/userinfo", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(`invalid json`)) + }) + + ok, user, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.Error(t, err) + require.False(t, ok) + require.Empty(t, user) +} diff --git a/pkg/auth/interface.go b/pkg/auth/interface.go index 1cc14eb..5aebc70 100644 --- a/pkg/auth/interface.go +++ b/pkg/auth/interface.go @@ -13,8 +13,7 @@ type Provider interface { Type() string RedirectURL() string AuthURL() string - AuthCodeURL(c *gin.Context, state string) (string, error) + AuthCodeURL(state string) (string, error) Exchange(c *gin.Context, state string) (*oauth2.Token, error) - GetUserID(ctx context.Context, token *oauth2.Token) (string, error) - Authorization(userid string) (bool, error) + Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) } diff --git a/pkg/auth/main_test.go b/pkg/auth/main_test.go new file mode 100644 index 0000000..0591966 --- /dev/null +++ b/pkg/auth/main_test.go @@ -0,0 +1,13 @@ +package auth + +import ( + "os" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestMain(m *testing.M) { + gin.SetMode(gin.TestMode) + os.Exit(m.Run()) +} diff --git a/pkg/auth/mock.go b/pkg/auth/mock.go index f556e90..0a031eb 100644 --- a/pkg/auth/mock.go +++ b/pkg/auth/mock.go @@ -43,18 +43,18 @@ func (m *MockProvider) EXPECT() *MockProviderMockRecorder { } // AuthCodeURL mocks base method. -func (m *MockProvider) AuthCodeURL(c *gin.Context, state string) (string, error) { +func (m *MockProvider) AuthCodeURL(state string) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AuthCodeURL", c, state) + ret := m.ctrl.Call(m, "AuthCodeURL", state) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // AuthCodeURL indicates an expected call of AuthCodeURL. -func (mr *MockProviderMockRecorder) AuthCodeURL(c, state any) *gomock.Call { +func (mr *MockProviderMockRecorder) AuthCodeURL(state any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthCodeURL", reflect.TypeOf((*MockProvider)(nil).AuthCodeURL), c, state) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthCodeURL", reflect.TypeOf((*MockProvider)(nil).AuthCodeURL), state) } // AuthURL mocks base method. @@ -72,18 +72,19 @@ func (mr *MockProviderMockRecorder) AuthURL() *gomock.Call { } // Authorization mocks base method. -func (m *MockProvider) Authorization(userid string) (bool, error) { +func (m *MockProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Authorization", userid) + ret := m.ctrl.Call(m, "Authorization", ctx, token) ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // Authorization indicates an expected call of Authorization. -func (mr *MockProviderMockRecorder) Authorization(userid any) *gomock.Call { +func (mr *MockProviderMockRecorder) Authorization(ctx, token any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authorization", reflect.TypeOf((*MockProvider)(nil).Authorization), userid) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authorization", reflect.TypeOf((*MockProvider)(nil).Authorization), ctx, token) } // Exchange mocks base method. @@ -101,21 +102,6 @@ func (mr *MockProviderMockRecorder) Exchange(c, state any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exchange", reflect.TypeOf((*MockProvider)(nil).Exchange), c, state) } -// GetUserID mocks base method. -func (m *MockProvider) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserID", ctx, token) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetUserID indicates an expected call of GetUserID. -func (mr *MockProviderMockRecorder) GetUserID(ctx, token any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserID", reflect.TypeOf((*MockProvider)(nil).GetUserID), ctx, token) -} - // Name mocks base method. func (m *MockProvider) Name() string { m.ctrl.T.Helper() diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index 5f54ed2..767d86e 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -6,6 +6,7 @@ import ( "errors" "net/http" "net/url" + "slices" "github.com/gin-gonic/gin" "github.com/mattn/go-jsonpointer" @@ -75,7 +76,7 @@ func (p *oidcProvider) AuthURL() string { return OIDCAuthEndpoint } -func (p *oidcProvider) AuthCodeURL(c *gin.Context, state string) (string, error) { +func (p *oidcProvider) AuthCodeURL(state string) (string, error) { authURL := p.oauth2.AuthCodeURL(state) return authURL, nil } @@ -92,36 +93,33 @@ func (p *oidcProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, er return token, nil } -func (p *oidcProvider) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) { +func (p *oidcProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { client := p.oauth2.Client(ctx, token) resp, err := client.Get(p.userInfoURL) if err != nil { - return "", err + return false, "", err } defer resp.Body.Close() var obj any if err := json.NewDecoder(resp.Body).Decode(&obj); err != nil { - return "", err + return false, "", err } v, err := jsonpointer.Get(obj, p.userIDField) if err != nil { - return "", err + return false, "", err } userID, ok := v.(string) if !ok { - return "", errors.New("user ID field is not a string") + return false, "", errors.New("user ID field is not a string") } - return userID, nil -} -func (p *oidcProvider) Authorization(userid string) (bool, error) { if len(p.allowedUsers) == 0 { - return true, nil + return true, userID, nil } - for _, allowedUser := range p.allowedUsers { - if allowedUser == userid { - return true, nil - } + + if slices.Contains(p.allowedUsers, userID) { + return true, userID, nil } - return false, nil + + return false, userID, nil } diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go new file mode 100644 index 0000000..f4a0101 --- /dev/null +++ b/pkg/auth/oidc_test.go @@ -0,0 +1,248 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +const ( + TestOIDCClientID = "test-oidc-client-id" + TestOIDCClientSecret = "test-oidc-client-secret" + TestOIDCExternalURL = "http://localhost:8080" + TestOIDCProviderName = "TestOIDC" + TestOIDCUserIDField = "/sub" +) + +func setupOIDCTest(allowedUsers []string, userIDField string) (Provider, gin.IRoutes, gin.IRoutes, *httptest.Server) { + // Setup OIDC configuration server + configServer := gin.New() + configServer.GET("/.well-known/openid_configuration", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "authorization_endpoint": "http://localhost:8080/auth", + "token_endpoint": "http://localhost:8080/token", + "userinfo_endpoint": "http://localhost:8080/userinfo", + }) + }) + tsConfig := httptest.NewServer(configServer) + + if userIDField == "" { + userIDField = TestOIDCUserIDField + } + + p, err := NewOIDCProvider( + tsConfig.URL+"/.well-known/openid_configuration", + []string{"openid", "profile"}, + userIDField, + TestOIDCProviderName, + TestOIDCExternalURL, + TestOIDCClientID, + TestOIDCClientSecret, + allowedUsers, + ) + if err != nil { + panic(err) + } + + // Setup OAuth2 token endpoint + oauth := gin.New() + oauth.POST("/token", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "access_token": "test-access-token", + "token_type": "Bearer", + }) + }) + tsOAuth := httptest.NewServer(oauth) + + // Setup userinfo endpoint + userinfo := gin.New() + tsUserinfo := httptest.NewServer(userinfo) + + // Override endpoints in provider + op := p.(*oidcProvider) + op.oauth2.Endpoint = oauth2.Endpoint{ + AuthURL: tsOAuth.URL + "/auth", + TokenURL: tsOAuth.URL + "/token", + } + op.userInfoURL = tsUserinfo.URL + "/userinfo" + + return p, oauth, userinfo, tsConfig +} + +func TestOIDCProvider(t *testing.T) { + p, _, _, tsConfig := setupOIDCTest([]string{}, "") + defer tsConfig.Close() + + require.Equal(t, p.Name(), TestOIDCProviderName) + require.Equal(t, p.Type(), "oidc") + require.Equal(t, p.RedirectURL(), OIDCCallbackEndpoint) + require.Equal(t, p.AuthURL(), OIDCAuthEndpoint) + + // Check AuthCodeURL + authCodeURL, err := p.AuthCodeURL("test-state") + require.NoError(t, err) + require.NotEmpty(t, authCodeURL) + authCodeURLObj, err := url.Parse(authCodeURL) + require.NoError(t, err) + require.Equal(t, authCodeURLObj.Query().Get("client_id"), TestOIDCClientID) + require.Equal(t, authCodeURLObj.Query().Get("redirect_uri"), TestOIDCExternalURL+"/.auth/oidc/callback") + require.Equal(t, authCodeURLObj.Query().Get("response_type"), "code") + require.Equal(t, authCodeURLObj.Query().Get("state"), "test-state") + require.Contains(t, authCodeURLObj.Query().Get("scope"), "openid") + + // Check Exchange + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest("GET", "/?state=test-state&code=test-code", nil) + c.Request = req + _, err = p.Exchange(c, "invalid-state") + require.Error(t, err) + token, err := p.Exchange(c, "test-state") + require.NoError(t, err) + require.NotNil(t, token) + require.Equal(t, token.AccessToken, "test-access-token") +} + +func TestOIDCProviderAuthorization(t *testing.T) { + tc := []struct { + name string + allowedUsers []string + userIDField string + userResp string + expect bool + }{ + { + name: "allow all users", + allowedUsers: []string{}, + userIDField: "/sub", + userResp: `{"sub": "user1", "name": "Test User"}`, + expect: true, + }, + { + name: "allow single user", + allowedUsers: []string{"user1", "user2"}, + userIDField: "/sub", + userResp: `{"sub": "user1", "name": "Test User"}`, + expect: true, + }, + { + name: "deny single user", + allowedUsers: []string{"user1"}, + userIDField: "/sub", + userResp: `{"sub": "user2", "name": "Test User"}`, + expect: false, + }, + { + name: "custom user ID field", + allowedUsers: []string{"test@example.com"}, + userIDField: "/email", + userResp: `{"sub": "user1", "email": "test@example.com", "name": "Test User"}`, + expect: true, + }, + { + name: "nested user ID field", + allowedUsers: []string{"user1"}, + userIDField: "/profile/username", + userResp: `{"sub": "123", "profile": {"username": "user1", "display_name": "Test User"}}`, + expect: true, + }, + } + + for _, tt := range tc { + t.Run(tt.name, func(t *testing.T) { + p, _, userinfo, tsConfig := setupOIDCTest(tt.allowedUsers, tt.userIDField) + defer tsConfig.Close() + + userinfo.GET("/userinfo", func(c *gin.Context) { + c.Data(http.StatusOK, "application/json", []byte(tt.userResp)) + }) + + // Call the Authorization method + ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.NoError(t, err) + require.Equal(t, tt.expect, ok) + }) + } +} + +func TestOIDCProviderErrors(t *testing.T) { + t.Run("invalid configuration URL", func(t *testing.T) { + _, err := NewOIDCProvider( + "http://invalid-url/.well-known/openid_configuration", + []string{"openid"}, + "/sub", + "TestOIDC", + TestOIDCExternalURL, + TestOIDCClientID, + TestOIDCClientSecret, + []string{}, + ) + require.Error(t, err) + }) + + t.Run("invalid JSON in configuration", func(t *testing.T) { + configServer := gin.New() + configServer.GET("/.well-known/openid_configuration", func(c *gin.Context) { + c.String(http.StatusOK, "invalid json") + }) + tsConfig := httptest.NewServer(configServer) + defer tsConfig.Close() + + _, err := NewOIDCProvider( + tsConfig.URL+"/.well-known/openid_configuration", + []string{"openid"}, + "/sub", + "TestOIDC", + TestOIDCExternalURL, + TestOIDCClientID, + TestOIDCClientSecret, + []string{}, + ) + require.Error(t, err) + }) + + t.Run("missing user ID field", func(t *testing.T) { + p, _, userinfo, tsConfig := setupOIDCTest([]string{}, "/missing_field") + defer tsConfig.Close() + + userinfo.GET("/userinfo", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"sub": "user1"}) + }) + + ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.Error(t, err) + require.False(t, ok) + }) + + t.Run("non-string user ID field", func(t *testing.T) { + p, _, userinfo, tsConfig := setupOIDCTest([]string{}, "/sub") + defer tsConfig.Close() + + userinfo.GET("/userinfo", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"sub": 12345}) + }) + + ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.Error(t, err) + require.False(t, ok) + }) + + t.Run("userinfo endpoint error", func(t *testing.T) { + p, _, userinfo, tsConfig := setupOIDCTest([]string{}, "/sub") + defer tsConfig.Close() + + userinfo.GET("/userinfo", func(c *gin.Context) { + c.JSON(http.StatusInternalServerError, gin.H{"error": "server error"}) + }) + + ok, _, err := p.Authorization(context.Background(), &oauth2.Token{AccessToken: "test-access-token"}) + require.Error(t, err) + require.False(t, ok) + }) +} diff --git a/pkg/idp/idp_test.go b/pkg/idp/idp_test.go index 100d684..51ee643 100644 --- a/pkg/idp/idp_test.go +++ b/pkg/idp/idp_test.go @@ -53,8 +53,7 @@ func setupTestServer(t *testing.T) (*httptest.Server, repository.Repository, str // Mock auth middleware that always passes router.Use(func(c *gin.Context) { session := sessions.Default(c) - session.Set(auth.SessionKeyProvider, auth.PasswordProvider) - session.Set(auth.SessionKeyUserID, auth.PasswordUserID) + session.Set(auth.SessionKeyAuthorized, true) err := session.Save() if err != nil { c.JSON(500, gin.H{"error": "Failed to save session"}) diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 76c9c36..086fd40 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -45,9 +45,11 @@ func Run( googleClientID string, googleClientSecret string, googleAllowedUsers []string, + googleAllowedWorkspaces []string, githubClientID string, githubClientSecret string, githubAllowedUsers []string, + githubAllowedOrgs []string, oidcConfigurationURL string, oidcClientID string, oidcClientSecret string, @@ -139,7 +141,7 @@ func Run( // Add Google provider if configured if googleClientID != "" && googleClientSecret != "" { - googleProvider, err := auth.NewGoogleProvider(externalURL, googleClientID, googleClientSecret, googleAllowedUsers) + googleProvider, err := auth.NewGoogleProvider(externalURL, googleClientID, googleClientSecret, googleAllowedUsers, googleAllowedWorkspaces) if err != nil { return fmt.Errorf("failed to create Google provider: %w", err) } @@ -148,7 +150,7 @@ func Run( // Add GitHub provider if configured if githubClientID != "" && githubClientSecret != "" { - githubProvider, err := auth.NewGithubProvider(githubClientID, githubClientSecret, externalURL, githubAllowedUsers) + githubProvider, err := auth.NewGithubProvider(githubClientID, githubClientSecret, externalURL, githubAllowedUsers, githubAllowedOrgs) if err != nil { return fmt.Errorf("failed to create GitHub provider: %w", err) } @@ -207,6 +209,13 @@ func Run( router.Use(ginzap.Ginzap(logger, time.RFC3339, true)) router.Use(ginzap.RecoveryWithZap(logger, true)) store := cookie.NewStore(secret) + store.Options(sessions.Options{ + Path: "/", + MaxAge: 600, + HttpOnly: true, + Secure: false, + SameSite: http.SameSiteLaxMode, + }) router.Use(sessions.Sessions("session", store)) authRouter.SetupRoutes(router) idpRouter.SetupRoutes(router) diff --git a/pkg/utils/must.go b/pkg/utils/must.go new file mode 100644 index 0000000..897c561 --- /dev/null +++ b/pkg/utils/must.go @@ -0,0 +1,8 @@ +package utils + +func Must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} From dfce241b5e8be9822a3713948e1448b430d73c68 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Sat, 23 Aug 2025 18:57:04 +0000 Subject: [PATCH 2/3] refactor: remove unused getProvider method from AuthRouter Remove dead code that was not being used anywhere in the codebase. --- pkg/auth/auth.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 9fbb2b7..d637331 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -132,15 +132,6 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) { } } -func (a *AuthRouter) getProvider(name string) Provider { - for _, provider := range a.providers { - if provider.Name() == name { - return provider - } - } - return nil -} - func (a *AuthRouter) handleLogin(c *gin.Context) { if c.Request.Method == "POST" { a.handleLoginPost(c) From a944f70a2560056aa1d93f41d03297d59ee20a63 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Sun, 24 Aug 2025 04:19:02 +0000 Subject: [PATCH 3/3] refactor: rename response variables in GitHub OAuth for clarity Renamed resp variables to resp1, resp2, resp3 to avoid variable shadowing and improve code readability in the GitHub OAuth authorization flow. --- pkg/auth/github.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pkg/auth/github.go b/pkg/auth/github.go index f146f38..d931917 100644 --- a/pkg/auth/github.go +++ b/pkg/auth/github.go @@ -87,19 +87,19 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token, func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) { client := p.oauth2.Client(ctx, token) - resp, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user"))) + resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user"))) if err != nil { return false, "", err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp.Status) + if resp1.StatusCode < 200 || resp1.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status) } - defer resp.Body.Close() + defer resp1.Body.Close() var userInfo struct { Login string `json:"login"` } - if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil { return false, "", err } @@ -122,18 +122,18 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) } if len(allowedOrgs) > 0 { - resp, err = client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs"))) + resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs"))) if err != nil { return false, "", err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp.Status) + if resp2.StatusCode < 200 || resp2.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status) } - defer resp.Body.Close() + defer resp2.Body.Close() var orgInfo []struct { Login string `json:"login"` } - if err := json.NewDecoder(resp.Body).Decode(&orgInfo); err != nil { + if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil { return false, "", err } for _, o := range orgInfo { @@ -143,21 +143,21 @@ func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) } } if len(allowedOrgTeams) > 0 { - resp, err = client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams"))) + resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams"))) if err != nil { return false, "", err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return false, "", errors.New("failed to get user info from GitHub API: " + resp.Status) + if resp3.StatusCode < 200 || resp3.StatusCode >= 300 { + return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status) } - defer resp.Body.Close() + defer resp3.Body.Close() var teamInfo []struct { Organization struct { Login string `json:"login"` } `json:"organization"` Slug string `json:"slug"` } - if err := json.NewDecoder(resp.Body).Decode(&teamInfo); err != nil { + if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil { return false, "", err } for _, team := range teamInfo {