From e785bc71cdfe7f7996dfbf1dd2152887f5b5f76d Mon Sep 17 00:00:00 2001 From: Furkan Date: Fri, 5 Jun 2020 09:33:10 +0300 Subject: [PATCH] oauth2: add www-authenticate at userinfo endpoint (#1891) Closes #1827 --- oauth2/handler.go | 8 +++++++- oauth2/handler_test.go | 33 +++++++++++++++++++++++++-------- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/oauth2/handler.go b/oauth2/handler.go index 3f1b725826..254465aa85 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -272,12 +272,18 @@ func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) { session := NewSession("") tokenType, ar, err := h.r.OAuth2Provider().IntrospectToken(r.Context(), fosite.AccessTokenFromRequest(r), fosite.AccessToken, session) if err != nil { + rfcerr := fosite.ErrorToRFC6749Error(err) + if rfcerr.StatusCode() == http.StatusUnauthorized { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("error=%s,error_description=%s,error_hint=%s", rfcerr.Name, rfcerr.Description, rfcerr.Hint)) + } h.r.Writer().WriteError(w, r, err) return } if tokenType != fosite.AccessToken { - h.r.Writer().WriteErrorCode(w, r, http.StatusUnauthorized, errors.New("Only access tokens are allowed in the authorization header")) + errorDescription := "Only access tokens are allowed in the authorization header" + w.Header().Set("WWW-Authenticate", fmt.Sprintf("error_description=\"%s\"", errorDescription)) + h.r.Writer().WriteErrorCode(w, r, http.StatusUnauthorized, errors.New(errorDescription)) return } diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index a1fa09859e..5bd3c1f047 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -171,9 +171,10 @@ func TestUserinfo(t *testing.T) { defer ts.Close() for k, tc := range []struct { - setup func(t *testing.T) - check func(t *testing.T, body []byte) - expectStatusCode int + setup func(t *testing.T) + checkForSuccess func(t *testing.T, body []byte) + checkForUnauthorized func(t *testing.T, body []byte, header http.Header) + expectStatusCode int }{ { setup: func(t *testing.T) { @@ -187,6 +188,20 @@ func TestUserinfo(t *testing.T) { IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). Return(fosite.RefreshToken, nil, nil) }, + checkForUnauthorized: func(t *testing.T, body []byte, headers http.Header) { + assert.True(t, headers.Get("WWW-Authenticate") != "", "%s", headers) + }, + expectStatusCode: http.StatusUnauthorized, + }, + { + setup: func(t *testing.T) { + op.EXPECT(). + IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). + Return(fosite.AccessToken, nil, fosite.ErrRequestUnauthorized) + }, + checkForUnauthorized: func(t *testing.T, body []byte, headers http.Header) { + assert.True(t, headers.Get("WWW-Authenticate") != "", "%s", headers) + }, expectStatusCode: http.StatusUnauthorized, }, { @@ -214,7 +229,7 @@ func TestUserinfo(t *testing.T) { }) }, expectStatusCode: http.StatusOK, - check: func(t *testing.T, body []byte) { + checkForSuccess: func(t *testing.T, body []byte) { assert.True(t, strings.Contains(string(body), `"sub":"alice"`), "%s", body) }, }, @@ -243,7 +258,7 @@ func TestUserinfo(t *testing.T) { }) }, expectStatusCode: http.StatusOK, - check: func(t *testing.T, body []byte) { + checkForSuccess: func(t *testing.T, body []byte) { assert.False(t, strings.Contains(string(body), `"sub":"alice"`), "%s", body) assert.True(t, strings.Contains(string(body), `"sub":"another-alice"`), "%s", body) }, @@ -275,7 +290,7 @@ func TestUserinfo(t *testing.T) { }) }, expectStatusCode: http.StatusOK, - check: func(t *testing.T, body []byte) { + checkForSuccess: func(t *testing.T, body []byte) { assert.True(t, strings.Contains(string(body), `"sub":"alice"`), "%s", body) }, }, @@ -334,7 +349,7 @@ func TestUserinfo(t *testing.T) { }) }, expectStatusCode: http.StatusOK, - check: func(t *testing.T, body []byte) { + checkForSuccess: func(t *testing.T, body []byte) { claims, err := jwt2.Parse(string(body), func(token *jwt2.Token) (interface{}, error) { keys, err := reg.KeyManager().GetKeySet(context.Background(), x.OpenIDConnectKeyName) require.NoError(t, err) @@ -360,7 +375,9 @@ func TestUserinfo(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) if tc.expectStatusCode == http.StatusOK { - tc.check(t, body) + tc.checkForSuccess(t, body) + } else if tc.expectStatusCode == http.StatusUnauthorized { + tc.checkForUnauthorized(t, body, resp.Header) } }) }