Skip to content

Commit

Permalink
Parse OAuth auth/token URLs, preserving parameters
Browse files Browse the repository at this point in the history
Parse the provided Auth/Token URLs, resulting in an error at start if
they are invalid.
This allows a provider's auth URL to include additional parameters,
which is needed to implement Discord auth.
  • Loading branch information
wlcx committed Jul 12, 2019
1 parent 1c0c36b commit 19b6fd5
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 25 deletions.
14 changes: 12 additions & 2 deletions oauth2/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,20 @@ func (manager *Manager) AddConfig(providerName string, opts map[string]string) e
return fmt.Errorf("no provider for name %v", providerName)
}

authURL, err := url.Parse(p.AuthURL)
if err != nil {
return fmt.Errorf("parse auth URL: %s", err)
}

tokenURL, err := url.Parse(p.TokenURL)
if err != nil {
return fmt.Errorf("parse token URL: %s", err)
}

cfg := Config{
Provider: p,
AuthURL: p.AuthURL,
TokenURL: p.TokenURL,
AuthURL: authURL,
TokenURL: tokenURL,
}

clientID, exist := opts["client_id"]
Expand Down
4 changes: 2 additions & 2 deletions oauth2/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func Test_Manager_Positive_Flow(t *testing.T) {
expectedConfig := Config{
ClientID: "client42",
ClientSecret: "secret",
AuthURL: exampleProvider.AuthURL,
TokenURL: exampleProvider.TokenURL,
AuthURL: mustParseURL(exampleProvider.AuthURL),
TokenURL: mustParseURL(exampleProvider.TokenURL),
RedirectURI: "http://localhost",
Scope: "email other",
Provider: exampleProvider,
Expand Down
34 changes: 22 additions & 12 deletions oauth2/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ type Config struct {
ClientSecret string

// The oauth authentication url to redirect to
AuthURL string
AuthURL *url.URL

// The url for token exchange
TokenURL string
TokenURL *url.URL

// RedirectURL is the URL to redirect users going through
// the OAuth flow, after the resource owner's URLs.
Expand Down Expand Up @@ -68,23 +68,25 @@ const defaultTimeout = 5 * time.Second
// StartFlow by redirecting the user to the login provider.
// A state parameter to protect against cross-site request forgery attacks is randomly generated and stored in a cookie
func StartFlow(cfg Config, w http.ResponseWriter) {
values := make(url.Values)
values.Set("client_id", cfg.ClientID)
values.Set("scope", cfg.Scope)
values.Set("redirect_uri", cfg.RedirectURI)
values.Set("response_type", "code")
// Add auth parameters to the URL, preserving any existing parameters
query := cfg.AuthURL.Query()
query.Set("client_id", cfg.ClientID)
query.Set("scope", cfg.Scope)
query.Set("redirect_uri", cfg.RedirectURI)
query.Set("response_type", "code")

// set and store the state param
values.Set("state", randStringBytes(15))
query.Set("state", randStringBytes(15))
http.SetCookie(w, &http.Cookie{
Name: stateCookieName,
MaxAge: 60 * 10, // 10 minutes
Value: values.Get("state"),
Value: query.Get("state"),
HttpOnly: true,
})

targetURL := cfg.AuthURL + "?" + values.Encode()
w.Header().Set("Location", targetURL)
redirectURL := *cfg.AuthURL
redirectURL.RawQuery = query.Encode()
w.Header().Set("Location", redirectURL.String())
w.WriteHeader(http.StatusFound)
}

Expand Down Expand Up @@ -116,7 +118,7 @@ func getAccessToken(cfg Config, state, code string) (TokenInfo, error) {
values.Set("redirect_uri", cfg.RedirectURI)
values.Set("grant_type", "authorization_code")

r, _ := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(values.Encode()))
r, _ := http.NewRequest("POST", cfg.TokenURL.String(), strings.NewReader(values.Encode()))
cntx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
r.WithContext(cntx)
Expand Down Expand Up @@ -163,3 +165,11 @@ func randStringBytes(n int) string {
}
return string(b)
}

func mustParseURL(rawurl string) (theURL *url.URL) {
theURL, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return
}
18 changes: 9 additions & 9 deletions oauth2/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
var testConfig = Config{
ClientID: "client42",
ClientSecret: "secret",
AuthURL: "http://auth-provider/auth",
TokenURL: "http://auth-provider/token",
AuthURL: mustParseURL("http://auth-provider/auth?extraparam=foo"),
TokenURL: mustParseURL("http://auth-provider/token"),
RedirectURI: "http://localhost/callback",
Scope: "email other",
}
Expand All @@ -32,8 +32,8 @@ func Test_StartFlow(t *testing.T) {
Equal(t, stateCookieName, strings.Split(cHeader, "=")[0])
state := strings.Split(cHeader, "=")[1]

expectedLocation := fmt.Sprintf("%v?client_id=%v&redirect_uri=%v&response_type=code&scope=%v&state=%v",
testConfig.AuthURL,
expectedLocation := fmt.Sprintf("%v?client_id=%v&extraparam=foo&redirect_uri=%v&response_type=code&scope=%v&state=%v",
strings.Split(testConfig.AuthURL.String(), "?")[0], // Everything preceeding the querystring
testConfig.ClientID,
url.QueryEscape(testConfig.RedirectURI),
"email+other",
Expand All @@ -59,7 +59,7 @@ func Test_Authenticate(t *testing.T) {
defer server.Close()

testConfigCopy := testConfig
testConfigCopy.TokenURL = server.URL
testConfigCopy.TokenURL = mustParseURL(server.URL)

request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil)
request.Header.Set("Cookie", "oauthState=theState")
Expand All @@ -85,7 +85,7 @@ func Test_Authenticate_CodeExchangeError(t *testing.T) {
defer server.Close()

testConfigCopy := testConfig
testConfigCopy.TokenURL = server.URL
testConfigCopy.TokenURL = mustParseURL(server.URL)

request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil)
request.Header.Set("Cookie", "oauthState=theState")
Expand Down Expand Up @@ -152,7 +152,7 @@ func Test_Authentication_Provider500(t *testing.T) {
defer server.Close()

testConfigCopy := testConfig
testConfigCopy.TokenURL = server.URL
testConfigCopy.TokenURL = mustParseURL(server.URL)

request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil)
request.Header.Set("Cookie", "oauthState=theState")
Expand All @@ -167,7 +167,7 @@ func Test_Authentication_Provider500(t *testing.T) {
func Test_Authentication_ProviderNetworkError(t *testing.T) {

testConfigCopy := testConfig
testConfigCopy.TokenURL = "http://localhost:12345678"
testConfigCopy.TokenURL = mustParseURL("http://localhost:12345678")

request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil)
request.Header.Set("Cookie", "oauthState=theState")
Expand All @@ -190,7 +190,7 @@ func Test_Authentication_TokenParseError(t *testing.T) {
defer server.Close()

testConfigCopy := testConfig
testConfigCopy.TokenURL = server.URL
testConfigCopy.TokenURL = mustParseURL(server.URL)

request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil)
request.Header.Set("Cookie", "oauthState=theState")
Expand Down

0 comments on commit 19b6fd5

Please sign in to comment.