Skip to content

Commit

Permalink
fix(oidc): github auth
Browse files Browse the repository at this point in the history
  • Loading branch information
fiftin committed Sep 16, 2023
1 parent b8c2080 commit a70688f
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 53 deletions.
142 changes: 91 additions & 51 deletions api/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,61 @@ func generateStateOauthCookie(w http.ResponseWriter) string {
return oauthState
}

type oidcClaimResult struct {
username string
name string
email string
}

func claimOidcToken(idToken *oidc.IDToken, provider util.OidcProvider) (res oidcClaimResult, err error) {

claims := make(map[string]interface{})
if err = idToken.Claims(&claims); err != nil {
return
}

var ok bool

res.email, ok = claims[provider.EmailClaim].(string)
if !ok {
err = fmt.Errorf("claim '%s' missing from id_token or not a string", provider.EmailClaim)
return
}

res.username = getUsernameFromEmail(res.email)

res.name, ok = claims[provider.NameClaim].(string)
if !ok || res.name == "" {
res.name = getProfileNameFromEmail(res.email)
}

return
}

func extractUsernameFromEmail(email string) string {
parts := strings.Split(email, "@")
if len(parts) > 0 {
return parts[0]
}
return ""
}

func getUsernameFromEmail(email string) string {
username := extractUsernameFromEmail(email)
suffix := util.RandString(12)
return username + "_" + suffix
}

func getProfileNameFromEmail(email string) string {
username := extractUsernameFromEmail(email)

runes := []rune(username)

runes[0] = []rune(strings.ToUpper(string(runes[0])))[0]

return string(runes)
}

func oidcRedirect(w http.ResponseWriter, r *http.Request) {
pid := mux.Vars(r)["provider"]
oauthState, err := r.Cookie("oauthstate")
Expand All @@ -380,93 +435,78 @@ func oidcRedirect(w http.ResponseWriter, r *http.Request) {
return
}

if r.FormValue("state") != oauthState.Value {
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}

ctx := context.Background()
_oidc, oauth, err := getOidcProvider(pid, ctx)
if err != nil {
log.Error(err.Error())
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}

provider, ok := util.Config.OidcProviders[pid]
if !ok {
log.Error(fmt.Errorf("no such provider: %s", pid))
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}

verifier := _oidc.Verifier(&oidc.Config{ClientID: oauth.ClientID})

if r.FormValue("state") != oauthState.Value {
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}
code := r.URL.Query().Get("code")

oauth2Token, err := oauth.Exchange(ctx, r.URL.Query().Get("code"))
oauth2Token, err := oauth.Exchange(ctx, code)
if err != nil {
log.Error(err.Error())
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}

var claims oidcClaimResult

// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
log.Error(fmt.Errorf("id_token is missing in token response"))
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}

// Parse and verify ID Token payload.
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
log.Error(err.Error())
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}
if ok && rawIDToken != "" {
var idToken *oidc.IDToken
// Parse and verify ID Token payload.
idToken, err = verifier.Verify(ctx, rawIDToken)

// Extract custom claims
claims := make(map[string]interface{})
if err := idToken.Claims(&claims); err != nil {
log.Error(err.Error())
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}
if err == nil {
claims, err = claimOidcToken(idToken, provider)
}
} else {
var userInfo *oidc.UserInfo
userInfo, err = _oidc.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))

//if len(provider.UsernameClaim) == 0 {
// provider.UsernameClaim = "preferred_username"
//}
usernameClaim, ok := claims[provider.UsernameClaim].(string)
if !ok {
log.Error(fmt.Errorf("claim '%s' missing from id_token or not a string", provider.UsernameClaim))
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}
if err == nil {
claims.email = userInfo.Email
claims.username = getUsernameFromEmail(claims.email)

//if len(provider.NameClaim) == 0 {
// provider.NameClaim = "preferred_username"
//}
nameClaim, ok := claims[provider.NameClaim].(string)
if !ok {
log.Error(fmt.Errorf("claim '%s' missing from id_token or not a string", provider.NameClaim))
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
if userInfo.Profile != "" {
claims.name = userInfo.Profile
} else {
claims.name = getProfileNameFromEmail(claims.email)
}
}
}

//if len(provider.EmailClaim) == 0 {
// provider.EmailClaim = "email"
//}
emailClaim, ok := claims[provider.EmailClaim].(string)
if !ok {
log.Error(fmt.Errorf("claim '%s' missing from id_token or not a string", provider.EmailClaim))
if err != nil {
log.Error(err.Error())
http.Redirect(w, r, "/auth/login", http.StatusTemporaryRedirect)
return
}

user, err := helpers.Store(r).GetUserByLoginOrEmail(usernameClaim, emailClaim)
user, err := helpers.Store(r).GetUserByLoginOrEmail("", claims.email) // ignore username because it creates a lot of problems
if err != nil {
user = db.User{
Username: usernameClaim,
Name: nameClaim,
Email: emailClaim,
Username: claims.username,
Name: claims.name,
Email: claims.email,
External: true,
}
user, err = helpers.Store(r).CreateUserWithoutPassword(user)
Expand Down
4 changes: 2 additions & 2 deletions util/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type oidcEndpoint struct {
Algorithms []string `json:"algorithms"`
}

type oidcProvider struct {
type OidcProvider struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
RedirectURL string `json:"redirect_url"`
Expand Down Expand Up @@ -164,7 +164,7 @@ type ConfigType struct {
SlackUrl string `json:"slack_url" env:"SEMAPHORE_SLACK_URL"`

// oidc settings
OidcProviders map[string]oidcProvider `json:"oidc_providers"`
OidcProviders map[string]OidcProvider `json:"oidc_providers"`

// task concurrency
MaxParallelTasks int `json:"max_parallel_tasks" rule:"^[0-9]{1,10}$" env:"SEMAPHORE_MAX_PARALLEL_TASKS"`
Expand Down

0 comments on commit a70688f

Please sign in to comment.