From 314fd0143d77277626b8ee7bb1a12bf9945c7e42 Mon Sep 17 00:00:00 2001 From: Sam Willcocks Date: Fri, 12 Jul 2019 14:19:21 +0100 Subject: [PATCH] Add discord OAuth support Adds support for Discord as an OAuth provider. Fetches user info and also the guilds (servers) that user is a member of, and returns the IDs of said servers as the Groups userinfo attribute. This allows, in conjunction with userfile, to authenticate users based on their membership of discord servers. --- README.md | 5 +- oauth2/discord.go | 114 ++++++++++++++++++++++++++++++++++++++++ oauth2/discord_test.go | 67 +++++++++++++++++++++++ oauth2/provider_test.go | 7 ++- 4 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 oauth2/discord.go create mode 100644 oauth2/discord_test.go diff --git a/README.md b/README.md index e4008b58..cdd18e3b 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ The following providers (login backends) are supported. * Bitbucket login * Facebook login * Gitlab login + * Discord login ## Questions @@ -68,6 +69,7 @@ _Note for Caddy users_: Not all parameters are available in Caddy. See the table | -bitbucket | value | | X | OAuth config in the form: client_id=..,client_secret=..[,scope=..][,redirect_uri=..] | | -facebook | value | | X | OAuth config in the form: client_id=..,client_secret=..[,scope=..][,redirect_uri=..] | | -gitlab | value | | X | OAuth config in the form: client_id=..,client_secret=..[,scope=..,][redirect_uri=..] | +| -discord | value | | X | OAuth config in the form: client_id=..,client_secret=..[,scope=..,][redirect_uri=..] | | -host | string | "localhost" | - | Host to listen on | | -htpasswd | value | | X | Htpasswd login backend opts: file=/path/to/pwdfile | | -jwt-expiry | go duration | 24h | X | Expiry duration for the JWT token, e.g. 2h or 3h30m | @@ -314,6 +316,7 @@ Currently the following OAuth provider is supported: * Bitbucket * Facebook * Gitlab +* Discord An OAuth provider supports the following parameters: @@ -392,7 +395,7 @@ below the claim attribute are written into the token. The following attributes c * `origin` - the provider or backend name (all backends) * `email` - the mail address (the OAuth provider) * `domain` - the domain (Google only) -* `groups` - the full path string of user groups enclosed in an array (Gitlab only) +* `groups` - the full path string of user groups enclosed in an array (Gitlab/Discord only) Example: * The user bob will become the `"role": "superAdmin"`, when authenticating with htpasswd file diff --git a/oauth2/discord.go b/oauth2/discord.go new file mode 100644 index 00000000..1f8f2862 --- /dev/null +++ b/oauth2/discord.go @@ -0,0 +1,114 @@ +package oauth2 + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" + + "github.com/tarent/loginsrv/model" +) + +var discordAPI = "https://discordapp.com/api" +var discordCDN = "https://cdn.discordapp.com" + +func init() { + RegisterProvider(providerDiscord) +} + +// DiscordUser is used for parsing the github response +type DiscordUser struct { + ID string `json:"id,omitempty"` + Username string `json:"username,omitempty"` + Discriminator string `json:"discriminator,omitempty"` + AvatarHash string `json:"avatar,omitempty"` + MFAEnabled bool `json:"mfa_enabled,omitempty"` + Locale string `json:"locale,omitempty"` + Verified bool `json:"verified,omitempty"` + Email string `json:"email,omitempty"` + Flags int `json:"flags,omitempty"` + PremiumType int `json:"premium_type,omitempty"` +} + +// DiscordGuild is a partial guild object returned by the /user/guilds endpoint +type DiscordGuild struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + IconHash string `json:"icon,omitempty"` + Owner bool `json:"owner,omitempty"` + Permissions int `json:"permissions,omitempty"` +} + +func discordAPIRequest(endpoint, token string) ([]byte, error) { + req, err := http.NewRequest("GET", fmt.Sprintf("%v/%v", discordAPI, endpoint), nil) + if err != nil { + return nil, fmt.Errorf("create request: %v", err) + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if !strings.Contains(resp.Header.Get("Content-Type"), "application/json") { + return nil, fmt.Errorf("wrong content-type: %v", resp.Header.Get("Content-Type")) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("got http status %v", resp.StatusCode) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading body: %v", err) + } + + return b, nil +} + +var providerDiscord = Provider{ + Name: "discord", + AuthURL: "https://discordapp.com/api/oauth2/authorize?prompt=none", + TokenURL: "https://discordapp.com/api/oauth2/token", + DefaultScopes: "identify email guilds", + GetUserInfo: func(token TokenInfo) (model.UserInfo, string, error) { + du := DiscordUser{} + dg := []DiscordGuild{} + // Get user info + raw, err := discordAPIRequest("/users/@me", token.AccessToken) + if err != nil { + return model.UserInfo{}, "", fmt.Errorf("error getting discord user info: %v", err) + } + err = json.Unmarshal(raw, &du) + if err != nil { + return model.UserInfo{}, "", fmt.Errorf("error parsing discord get user info: %v", err) + } + + // Get user's guilds (servers) + raw, err = discordAPIRequest("/users/@me/guilds", token.AccessToken) + if err != nil { + return model.UserInfo{}, "", fmt.Errorf("error getting discord user guilds: %v", err) + } + err = json.Unmarshal(raw, &dg) + if err != nil { + return model.UserInfo{}, "", fmt.Errorf("error parsing discord guilds: %v", err) + } + + var guilds []string + for _, g := range dg { + guilds = append(guilds, g.ID) + } + + return model.UserInfo{ + Sub: fmt.Sprintf("%v#%v", du.Username, du.Discriminator), + Picture: fmt.Sprintf("%v/avatars/%v/%v.png", discordCDN, du.ID, du.AvatarHash), + Name: du.Username, + Email: du.Email, + Origin: "discord", + Groups: guilds, + }, string(raw), nil + }, +} diff --git a/oauth2/discord_test.go b/oauth2/discord_test.go new file mode 100644 index 00000000..1ae45a6b --- /dev/null +++ b/oauth2/discord_test.go @@ -0,0 +1,67 @@ +package oauth2 + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + . "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +var discordTestUserResponse = `{ + "id": "80351110224678912", + "username": "Nelly", + "discriminator": "1337", + "avatar": "8342729096ea3675442027381ff50dfe", + "verified": true, + "email": "nelly@discordapp.com", + "flags": 64, + "premium_type": 1 + }` + +var discordTestUserGuildsResponse = `[ + { + "id": "80351110224678912", + "name": "1337 Krew", + "icon": "8342729096ea3675442027381ff50dfe", + "owner": true, + "permissions": 36953089 + } +]` + +type DiscordTestSuite struct { + suite.Suite + Server *httptest.Server +} + +func (suite *DiscordTestSuite) SetupTest() { + r := mux.NewRouter() + + r.HandleFunc("/users/@me", func(w http.ResponseWriter, r *http.Request) { + suite.Equal(r.Header.Get("Authentication"), "Bearer secret") + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write([]byte(discordTestUserResponse)) + }) + + r.HandleFunc("/users/@me/guilds", func(w http.ResponseWriter, r *http.Request) { + suite.Equal(r.Header.Get("Authentication"), "Bearer secret") + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write([]byte(discordTestUserGuildsResponse)) + }) + + suite.Server = httptest.NewServer(r) +} + +func (suite *DiscordTestSuite) Test_Discord_getUserInfo(t *testing.T) { + discordAPI = suite.Server.URL + + u, rawJSON, err := providerDiscord.GetUserInfo(TokenInfo{AccessToken: "secret"}) + NoError(t, err) + Equal(t, "Nelly#1337", u.Sub) + Equal(t, "nelly@discordapp.com", u.Email) + Equal(t, "Nelly", u.Name) + Equal(t, []string{"80351110224678912"}, u.Groups) + Equal(t, discordTestUserResponse, rawJSON) +} diff --git a/oauth2/provider_test.go b/oauth2/provider_test.go index 465a6586..8eae627d 100644 --- a/oauth2/provider_test.go +++ b/oauth2/provider_test.go @@ -27,11 +27,16 @@ func Test_ProviderRegistration(t *testing.T) { NotNil(t, gitlab) True(t, exist) + discord, exist := GetProvider("discord") + NotNil(t, discord) + True(t, exist) + list := ProviderList() - Equal(t, 5, len(list)) + Equal(t, 6, len(list)) Contains(t, list, "github") Contains(t, list, "google") Contains(t, list, "bitbucket") Contains(t, list, "facebook") Contains(t, list, "gitlab") + Contains(t, list, "discord") }