Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ func main() {
var googleClientID string
var googleClientSecret string
var googleAllowedUsers string
var googleAllowedWorkspaces string
var githubClientID string
var githubClientSecret string
var githubAllowedUsers string
var githubAllowedOrgs string
var oidcConfigurationURL string
var oidcClientID string
var oidcClientSecret string
Expand All @@ -63,6 +65,14 @@ func main() {
}
}

var googleAllowedWorkspacesList []string
if googleAllowedWorkspaces != "" {
googleAllowedWorkspacesList = strings.Split(googleAllowedWorkspaces, ",")
for i := range googleAllowedWorkspacesList {
googleAllowedWorkspacesList[i] = strings.TrimSpace(googleAllowedWorkspacesList[i])
}
}

var githubAllowedUsersList []string
if githubAllowedUsers != "" {
githubAllowedUsersList = strings.Split(githubAllowedUsers, ",")
Expand All @@ -71,6 +81,14 @@ func main() {
}
}

var githubAllowedOrgsList []string
if githubAllowedOrgs != "" {
githubAllowedOrgsList = strings.Split(githubAllowedOrgs, ",")
for i := range githubAllowedOrgsList {
githubAllowedOrgsList[i] = strings.TrimSpace(githubAllowedOrgsList[i])
}
}

var oidcAllowedUsersList []string
if oidcAllowedUsers != "" {
oidcAllowedUsersList = strings.Split(oidcAllowedUsers, ",")
Expand Down Expand Up @@ -110,9 +128,11 @@ func main() {
googleClientID,
googleClientSecret,
googleAllowedUsersList,
googleAllowedWorkspacesList,
githubClientID,
githubClientSecret,
githubAllowedUsersList,
githubAllowedOrgsList,
oidcConfigurationURL,
oidcClientID,
oidcClientSecret,
Expand Down Expand Up @@ -144,11 +164,13 @@ func main() {
rootCmd.Flags().StringVar(&googleClientID, "google-client-id", getEnvWithDefault("GOOGLE_CLIENT_ID", ""), "Google OAuth client ID")
rootCmd.Flags().StringVar(&googleClientSecret, "google-client-secret", getEnvWithDefault("GOOGLE_CLIENT_SECRET", ""), "Google OAuth client secret")
rootCmd.Flags().StringVar(&googleAllowedUsers, "google-allowed-users", getEnvWithDefault("GOOGLE_ALLOWED_USERS", ""), "Comma-separated list of allowed Google users (emails)")
rootCmd.Flags().StringVar(&googleAllowedWorkspaces, "google-allowed-workspaces", getEnvWithDefault("GOOGLE_ALLOWED_WORKSPACES", ""), "Comma-separated list of allowed Google workspaces")

// GitHub OAuth configuration
rootCmd.Flags().StringVar(&githubClientID, "github-client-id", getEnvWithDefault("GITHUB_CLIENT_ID", ""), "GitHub OAuth client ID")
rootCmd.Flags().StringVar(&githubClientSecret, "github-client-secret", getEnvWithDefault("GITHUB_CLIENT_SECRET", ""), "GitHub OAuth client secret")
rootCmd.Flags().StringVar(&githubAllowedUsers, "github-allowed-users", getEnvWithDefault("GITHUB_ALLOWED_USERS", ""), "Comma-separated list of allowed GitHub users (usernames)")
rootCmd.Flags().StringVar(&githubAllowedOrgs, "github-allowed-orgs", getEnvWithDefault("GITHUB_ALLOWED_ORGS", ""), "Comma-separated list of allowed GitHub organizations. You can also restrict access to specific teams using the format `Org:Team`")

// OIDC configuration
rootCmd.Flags().StringVar(&oidcConfigurationURL, "oidc-configuration-url", getEnvWithDefault("OIDC_CONFIGURATION_URL", ""), "OIDC configuration URL")
Expand Down
59 changes: 12 additions & 47 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ const (
PasswordProvider = "password"
PasswordUserID = "password_user"

SessionKeyProvider = "provider"
SessionKeyUserID = "user_id"
SessionKeyAuthorized = "authorized"
SessionKeyRedirectURL = "redirect_url"
SessionKeyOAuthState = "oauth_state"
)
Expand All @@ -84,22 +83,16 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
a.renderError(c, err)
return
}
userID, err := provider.GetUserID(c, token)
if err != nil {
a.renderError(c, err)
return
}
ok, err := provider.Authorization(userID)
ok, user, err := provider.Authorization(c, token)
if err != nil {
a.renderError(c, err)
return
}
if !ok {
a.renderUnauthorized(c, userID, provider.Name())
a.renderUnauthorized(c, user, provider.Name())
return
}
session.Set(SessionKeyProvider, provider.Name())
session.Set(SessionKeyUserID, userID)
session.Set(SessionKeyAuthorized, true)
redirectURL := session.Get(SessionKeyRedirectURL)
if redirectURL != nil {
session.Delete(SessionKeyRedirectURL)
Expand All @@ -124,7 +117,7 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
a.renderError(c, err)
return
}
url, err := provider.AuthCodeURL(c, state)
url, err := provider.AuthCodeURL(state)
if err != nil {
a.renderError(c, err)
return
Expand All @@ -139,15 +132,6 @@ func (a *AuthRouter) SetupRoutes(router gin.IRouter) {
}
}

func (a *AuthRouter) getProvider(name string) Provider {
for _, provider := range a.providers {
if provider.Name() == name {
return provider
}
}
return nil
}

func (a *AuthRouter) handleLogin(c *gin.Context) {
if c.Request.Method == "POST" {
a.handleLoginPost(c)
Expand Down Expand Up @@ -183,8 +167,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {
}

session := sessions.Default(c)
session.Set(SessionKeyProvider, PasswordProvider)
session.Set(SessionKeyUserID, PasswordUserID)
session.Set(SessionKeyAuthorized, true)
redirectURL := session.Get(SessionKeyRedirectURL)
if redirectURL != nil {
session.Delete(SessionKeyRedirectURL)
Expand All @@ -203,8 +186,7 @@ func (a *AuthRouter) handleLoginPost(c *gin.Context) {

func (a *AuthRouter) handleLogout(c *gin.Context) {
session := sessions.Default(c)
session.Delete(SessionKeyProvider)
session.Delete(SessionKeyUserID)
session.Delete(SessionKeyAuthorized)
if err := session.Save(); err != nil {
a.renderError(c, err)
return
Expand All @@ -215,9 +197,8 @@ func (a *AuthRouter) handleLogout(c *gin.Context) {
func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
providerName := session.Get(SessionKeyProvider)
userID := session.Get(SessionKeyUserID)
if providerName == nil || userID == nil {
authorized := session.Get(SessionKeyAuthorized)
if authorized == nil {
session.Set(SessionKeyRedirectURL, c.Request.URL.String())
if err := session.Save(); err != nil {
a.renderError(c, err)
Expand All @@ -227,25 +208,9 @@ func (a *AuthRouter) RequireAuth() gin.HandlerFunc {
return
}

// Allow password authentication
if providerName.(string) == PasswordProvider {
c.Next()
return
}

p := a.getProvider(providerName.(string))
if p == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unknown provider"})
return
}
ok, err := p.Authorization(userID.(string))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authorization failed"})
return
}
if !ok {
a.renderUnauthorized(c, userID.(string), providerName.(string))
c.Abort()
if !authorized.(bool) {
// not expected
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
c.Next()
Expand Down
11 changes: 4 additions & 7 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
)

func setupTestRouter(authRouter *AuthRouter) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()

// Setup session middleware
Expand Down Expand Up @@ -84,10 +83,9 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().Name().Return("test").AnyTimes()
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("test-user", nil)
mockProvider.EXPECT().Authorization("test-user").Return(true, nil).AnyTimes()
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(true, "authorized_user", nil)

// Create AuthRouter
authRouter, err := NewAuthRouter(nil, mockProvider)
Expand Down Expand Up @@ -146,10 +144,9 @@ func TestAuthenticationFlow(t *testing.T) {
mockProvider.EXPECT().Name().Return("test").AnyTimes()
mockProvider.EXPECT().AuthURL().Return("/.auth/test").AnyTimes()
mockProvider.EXPECT().RedirectURL().Return("/.auth/test/callback").AnyTimes()
mockProvider.EXPECT().AuthCodeURL(gomock.Any(), gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().AuthCodeURL(gomock.Any()).Return("https://example.com/oauth", nil)
mockProvider.EXPECT().Exchange(gomock.Any(), gomock.Any()).Return(mockToken, nil)
mockProvider.EXPECT().GetUserID(gomock.Any(), mockToken).Return("unauthorized-user", nil)
mockProvider.EXPECT().Authorization("unauthorized-user").Return(false, nil).AnyTimes()
mockProvider.EXPECT().Authorization(gomock.Any(), mockToken).Return(false, "unauthorized_user", nil)

// Create AuthRouter
authRouter, err := NewAuthRouter(nil, mockProvider)
Expand Down
111 changes: 90 additions & 21 deletions pkg/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,53 @@ import (
"encoding/json"
"errors"
"net/url"
"slices"
"strings"

"github.com/gin-gonic/gin"
"github.com/sigbit/mcp-auth-proxy/pkg/utils"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
)

type githubProvider struct {
endpoint string
oauth2 oauth2.Config
allowedUsers []string
allowedOrgs []string
}

func NewGithubProvider(clientID, clientSecret, externalURL string, allowedUsers []string) (Provider, error) {
func NewGithubProvider(clientID, clientSecret, externalURL string, allowedUsers []string, allowedOrgs []string) (Provider, error) {
r, err := url.JoinPath(externalURL, GitHubCallbackEndpoint)
if err != nil {
return nil, err
}
scopes := []string{}
if len(allowedOrgs) > 0 {
scopes = append(scopes, "read:org")
}
return &githubProvider{
endpoint: "https://api.github.com",
oauth2: oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: r,
Scopes: []string{"user:email"},
Scopes: scopes,
Endpoint: github.Endpoint,
},
allowedUsers: allowedUsers,
allowedOrgs: allowedOrgs,
}, nil
}

func (p *githubProvider) SetApiEndpoint(u string) {
p.endpoint = u
}

func (p *githubProvider) SetOAuth2Endpoint(cfg oauth2.Endpoint) {
p.oauth2.Endpoint = cfg
}

func (p *githubProvider) Name() string {
return "GitHub"
}
Expand All @@ -49,7 +68,7 @@ func (p *githubProvider) AuthURL() string {
return GitHubAuthEndpoint
}

func (p *githubProvider) AuthCodeURL(c *gin.Context, state string) (string, error) {
func (p *githubProvider) AuthCodeURL(state string) (string, error) {
authURL := p.oauth2.AuthCodeURL(state)
return authURL, nil
}
Expand All @@ -66,37 +85,87 @@ func (p *githubProvider) Exchange(c *gin.Context, state string) (*oauth2.Token,
return token, nil
}

func (p *githubProvider) GetUserID(ctx context.Context, token *oauth2.Token) (string, error) {
func (p *githubProvider) Authorization(ctx context.Context, token *oauth2.Token) (bool, string, error) {
client := p.oauth2.Client(ctx, token)
resp, err := client.Get("https://api.github.com/user")
resp1, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user")))
if err != nil {
return "", err
return false, "", err
}
if resp1.StatusCode < 200 || resp1.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from GitHub API: " + resp1.Status)
}
defer resp.Body.Close()
defer resp1.Body.Close()

var userInfo struct {
ID uint64 `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", err
if err := json.NewDecoder(resp1.Body).Decode(&userInfo); err != nil {
return false, "", err
}

return userInfo.Login, nil
}
if len(p.allowedUsers) == 0 && len(p.allowedOrgs) == 0 {
return true, userInfo.Login, nil
}

func (p *githubProvider) Authorization(userid string) (bool, error) {
if len(p.allowedUsers) == 0 {
return true, nil
if slices.Contains(p.allowedUsers, userInfo.Login) {
return true, userInfo.Login, nil
}

allowedOrgTeams := []string{}
allowedOrgs := []string{}
for _, allowedOrg := range p.allowedOrgs {
if strings.Contains(allowedOrg, ":") {
allowedOrgTeams = append(allowedOrgTeams, allowedOrg)
} else {
allowedOrgs = append(allowedOrgs, allowedOrg)
}
}

for _, allowedUser := range p.allowedUsers {
if allowedUser == userid {
return true, nil
if len(allowedOrgs) > 0 {
resp2, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/orgs")))
if err != nil {
return false, "", err
}
if resp2.StatusCode < 200 || resp2.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from GitHub API: " + resp2.Status)
}
defer resp2.Body.Close()
var orgInfo []struct {
Login string `json:"login"`
}
if err := json.NewDecoder(resp2.Body).Decode(&orgInfo); err != nil {
return false, "", err
}
for _, o := range orgInfo {
if slices.Contains(allowedOrgs, o.Login) {
return true, userInfo.Login, nil
}
}
}
if len(allowedOrgTeams) > 0 {
resp3, err := client.Get(utils.Must(url.JoinPath(p.endpoint, "/user/teams")))
if err != nil {
return false, "", err
}
if resp3.StatusCode < 200 || resp3.StatusCode >= 300 {
return false, "", errors.New("failed to get user info from GitHub API: " + resp3.Status)
}
defer resp3.Body.Close()
var teamInfo []struct {
Organization struct {
Login string `json:"login"`
} `json:"organization"`
Slug string `json:"slug"`
}
if err := json.NewDecoder(resp3.Body).Decode(&teamInfo); err != nil {
return false, "", err
}
for _, team := range teamInfo {
if slices.Contains(allowedOrgTeams, team.Organization.Login+":"+team.Slug) {
return true, userInfo.Login, nil
}
}
}

return false, nil
return false, userInfo.Login, nil
}
Loading