diff --git a/oauth2/manager.go b/oauth2/manager.go index cb15584c..02aa12a0 100644 --- a/oauth2/manager.go +++ b/oauth2/manager.go @@ -95,10 +95,20 @@ func (manager *Manager) AddConfig(providerName string, opts map[string]string) e return fmt.Errorf("no provider for name %v", providerName) } + authURL, err := url.Parse(p.AuthURL) + if err != nil { + return fmt.Errorf("parse auth URL: %s", err) + } + + tokenURL, err := url.Parse(p.TokenURL) + if err != nil { + return fmt.Errorf("parse token URL: %s", err) + } + cfg := Config{ Provider: p, - AuthURL: p.AuthURL, - TokenURL: p.TokenURL, + AuthURL: authURL, + TokenURL: tokenURL, } clientID, exist := opts["client_id"] diff --git a/oauth2/manager_test.go b/oauth2/manager_test.go index 9c893a62..3dbd57a5 100644 --- a/oauth2/manager_test.go +++ b/oauth2/manager_test.go @@ -34,8 +34,8 @@ func Test_Manager_Positive_Flow(t *testing.T) { expectedConfig := Config{ ClientID: "client42", ClientSecret: "secret", - AuthURL: exampleProvider.AuthURL, - TokenURL: exampleProvider.TokenURL, + AuthURL: mustParseURL(exampleProvider.AuthURL), + TokenURL: mustParseURL(exampleProvider.TokenURL), RedirectURI: "http://localhost", Scope: "email other", Provider: exampleProvider, diff --git a/oauth2/oauth.go b/oauth2/oauth.go index 1120fe24..501b8e12 100644 --- a/oauth2/oauth.go +++ b/oauth2/oauth.go @@ -26,10 +26,10 @@ type Config struct { ClientSecret string // The oauth authentication url to redirect to - AuthURL string + AuthURL *url.URL // The url for token exchange - TokenURL string + TokenURL *url.URL // RedirectURL is the URL to redirect users going through // the OAuth flow, after the resource owner's URLs. @@ -68,23 +68,25 @@ const defaultTimeout = 5 * time.Second // StartFlow by redirecting the user to the login provider. // A state parameter to protect against cross-site request forgery attacks is randomly generated and stored in a cookie func StartFlow(cfg Config, w http.ResponseWriter) { - values := make(url.Values) - values.Set("client_id", cfg.ClientID) - values.Set("scope", cfg.Scope) - values.Set("redirect_uri", cfg.RedirectURI) - values.Set("response_type", "code") + // Add auth parameters to the URL, preserving any existing parameters + query := cfg.AuthURL.Query() + query.Set("client_id", cfg.ClientID) + query.Set("scope", cfg.Scope) + query.Set("redirect_uri", cfg.RedirectURI) + query.Set("response_type", "code") // set and store the state param - values.Set("state", randStringBytes(15)) + query.Set("state", randStringBytes(15)) http.SetCookie(w, &http.Cookie{ Name: stateCookieName, MaxAge: 60 * 10, // 10 minutes - Value: values.Get("state"), + Value: query.Get("state"), HttpOnly: true, }) - targetURL := cfg.AuthURL + "?" + values.Encode() - w.Header().Set("Location", targetURL) + redirectURL := *cfg.AuthURL + redirectURL.RawQuery = query.Encode() + w.Header().Set("Location", redirectURL.String()) w.WriteHeader(http.StatusFound) } @@ -116,7 +118,7 @@ func getAccessToken(cfg Config, state, code string) (TokenInfo, error) { values.Set("redirect_uri", cfg.RedirectURI) values.Set("grant_type", "authorization_code") - r, _ := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(values.Encode())) + r, _ := http.NewRequest("POST", cfg.TokenURL.String(), strings.NewReader(values.Encode())) cntx, cancel := context.WithTimeout(context.Background(), defaultTimeout) defer cancel() r.WithContext(cntx) @@ -163,3 +165,11 @@ func randStringBytes(n int) string { } return string(b) } + +func mustParseURL(rawurl string) (theURL *url.URL) { + theURL, err := url.Parse(rawurl) + if err != nil { + panic(err) + } + return +} diff --git a/oauth2/oauth_test.go b/oauth2/oauth_test.go index c6aadc56..5a944912 100644 --- a/oauth2/oauth_test.go +++ b/oauth2/oauth_test.go @@ -15,8 +15,8 @@ import ( var testConfig = Config{ ClientID: "client42", ClientSecret: "secret", - AuthURL: "http://auth-provider/auth", - TokenURL: "http://auth-provider/token", + AuthURL: mustParseURL("http://auth-provider/auth?extraparam=foo"), + TokenURL: mustParseURL("http://auth-provider/token"), RedirectURI: "http://localhost/callback", Scope: "email other", } @@ -32,8 +32,8 @@ func Test_StartFlow(t *testing.T) { Equal(t, stateCookieName, strings.Split(cHeader, "=")[0]) state := strings.Split(cHeader, "=")[1] - expectedLocation := fmt.Sprintf("%v?client_id=%v&redirect_uri=%v&response_type=code&scope=%v&state=%v", - testConfig.AuthURL, + expectedLocation := fmt.Sprintf("%v?client_id=%v&extraparam=foo&redirect_uri=%v&response_type=code&scope=%v&state=%v", + strings.Split(testConfig.AuthURL.String(), "?")[0], // Everything preceeding the querystring testConfig.ClientID, url.QueryEscape(testConfig.RedirectURI), "email+other", @@ -59,7 +59,7 @@ func Test_Authenticate(t *testing.T) { defer server.Close() testConfigCopy := testConfig - testConfigCopy.TokenURL = server.URL + testConfigCopy.TokenURL = mustParseURL(server.URL) request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) request.Header.Set("Cookie", "oauthState=theState") @@ -85,7 +85,7 @@ func Test_Authenticate_CodeExchangeError(t *testing.T) { defer server.Close() testConfigCopy := testConfig - testConfigCopy.TokenURL = server.URL + testConfigCopy.TokenURL = mustParseURL(server.URL) request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) request.Header.Set("Cookie", "oauthState=theState") @@ -152,7 +152,7 @@ func Test_Authentication_Provider500(t *testing.T) { defer server.Close() testConfigCopy := testConfig - testConfigCopy.TokenURL = server.URL + testConfigCopy.TokenURL = mustParseURL(server.URL) request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) request.Header.Set("Cookie", "oauthState=theState") @@ -167,7 +167,7 @@ func Test_Authentication_Provider500(t *testing.T) { func Test_Authentication_ProviderNetworkError(t *testing.T) { testConfigCopy := testConfig - testConfigCopy.TokenURL = "http://localhost:12345678" + testConfigCopy.TokenURL = mustParseURL("http://localhost:12345678") request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) request.Header.Set("Cookie", "oauthState=theState") @@ -190,7 +190,7 @@ func Test_Authentication_TokenParseError(t *testing.T) { defer server.Close() testConfigCopy := testConfig - testConfigCopy.TokenURL = server.URL + testConfigCopy.TokenURL = mustParseURL(server.URL) request, _ := http.NewRequest("GET", testConfig.RedirectURI, nil) request.Header.Set("Cookie", "oauthState=theState")