Skip to content

Commit

Permalink
oauth2: add www-authenticate at userinfo endpoint (#1891)
Browse files Browse the repository at this point in the history
Closes  #1827
  • Loading branch information
bayansar committed Jun 5, 2020
1 parent f0609ad commit e785bc7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
8 changes: 7 additions & 1 deletion oauth2/handler.go
Expand Up @@ -272,12 +272,18 @@ func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) {
session := NewSession("")
tokenType, ar, err := h.r.OAuth2Provider().IntrospectToken(r.Context(), fosite.AccessTokenFromRequest(r), fosite.AccessToken, session)
if err != nil {
rfcerr := fosite.ErrorToRFC6749Error(err)
if rfcerr.StatusCode() == http.StatusUnauthorized {
w.Header().Set("WWW-Authenticate", fmt.Sprintf("error=%s,error_description=%s,error_hint=%s", rfcerr.Name, rfcerr.Description, rfcerr.Hint))
}
h.r.Writer().WriteError(w, r, err)
return
}

if tokenType != fosite.AccessToken {
h.r.Writer().WriteErrorCode(w, r, http.StatusUnauthorized, errors.New("Only access tokens are allowed in the authorization header"))
errorDescription := "Only access tokens are allowed in the authorization header"
w.Header().Set("WWW-Authenticate", fmt.Sprintf("error_description=\"%s\"", errorDescription))
h.r.Writer().WriteErrorCode(w, r, http.StatusUnauthorized, errors.New(errorDescription))
return
}

Expand Down
33 changes: 25 additions & 8 deletions oauth2/handler_test.go
Expand Up @@ -171,9 +171,10 @@ func TestUserinfo(t *testing.T) {
defer ts.Close()

for k, tc := range []struct {
setup func(t *testing.T)
check func(t *testing.T, body []byte)
expectStatusCode int
setup func(t *testing.T)
checkForSuccess func(t *testing.T, body []byte)
checkForUnauthorized func(t *testing.T, body []byte, header http.Header)
expectStatusCode int
}{
{
setup: func(t *testing.T) {
Expand All @@ -187,6 +188,20 @@ func TestUserinfo(t *testing.T) {
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
Return(fosite.RefreshToken, nil, nil)
},
checkForUnauthorized: func(t *testing.T, body []byte, headers http.Header) {
assert.True(t, headers.Get("WWW-Authenticate") != "", "%s", headers)
},
expectStatusCode: http.StatusUnauthorized,
},
{
setup: func(t *testing.T) {
op.EXPECT().
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
Return(fosite.AccessToken, nil, fosite.ErrRequestUnauthorized)
},
checkForUnauthorized: func(t *testing.T, body []byte, headers http.Header) {
assert.True(t, headers.Get("WWW-Authenticate") != "", "%s", headers)
},
expectStatusCode: http.StatusUnauthorized,
},
{
Expand Down Expand Up @@ -214,7 +229,7 @@ func TestUserinfo(t *testing.T) {
})
},
expectStatusCode: http.StatusOK,
check: func(t *testing.T, body []byte) {
checkForSuccess: func(t *testing.T, body []byte) {
assert.True(t, strings.Contains(string(body), `"sub":"alice"`), "%s", body)
},
},
Expand Down Expand Up @@ -243,7 +258,7 @@ func TestUserinfo(t *testing.T) {
})
},
expectStatusCode: http.StatusOK,
check: func(t *testing.T, body []byte) {
checkForSuccess: func(t *testing.T, body []byte) {
assert.False(t, strings.Contains(string(body), `"sub":"alice"`), "%s", body)
assert.True(t, strings.Contains(string(body), `"sub":"another-alice"`), "%s", body)
},
Expand Down Expand Up @@ -275,7 +290,7 @@ func TestUserinfo(t *testing.T) {
})
},
expectStatusCode: http.StatusOK,
check: func(t *testing.T, body []byte) {
checkForSuccess: func(t *testing.T, body []byte) {
assert.True(t, strings.Contains(string(body), `"sub":"alice"`), "%s", body)
},
},
Expand Down Expand Up @@ -334,7 +349,7 @@ func TestUserinfo(t *testing.T) {
})
},
expectStatusCode: http.StatusOK,
check: func(t *testing.T, body []byte) {
checkForSuccess: func(t *testing.T, body []byte) {
claims, err := jwt2.Parse(string(body), func(token *jwt2.Token) (interface{}, error) {
keys, err := reg.KeyManager().GetKeySet(context.Background(), x.OpenIDConnectKeyName)
require.NoError(t, err)
Expand All @@ -360,7 +375,9 @@ func TestUserinfo(t *testing.T) {
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
if tc.expectStatusCode == http.StatusOK {
tc.check(t, body)
tc.checkForSuccess(t, body)
} else if tc.expectStatusCode == http.StatusUnauthorized {
tc.checkForUnauthorized(t, body, resp.Header)
}
})
}
Expand Down

0 comments on commit e785bc7

Please sign in to comment.