Skip to content

Commit

Permalink
refactor: use rfc compliant error formating
Browse files Browse the repository at this point in the history
BREAKING CHANGE: This patch removes fields `error_hint`, `error_debug` from error responses. To use the legacy error format where these fields are included, set `UseLegacyErrorFormat` to true in your compose config or directly on the `Fosite` struct. If `UseLegacyErrorFormat` is set, the `error_description` no longer merges `error_hint` nor `error_debug` messages which reverts a change introduced in `v0.33.0`. Instead, `error_hint` and `error_debug` are included and the merged message can be constructed from those fields.
  • Loading branch information
aeneasr committed Nov 16, 2020
1 parent de5c8f9 commit edbbda3
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 92 deletions.
2 changes: 1 addition & 1 deletion access_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (f *Fosite) writeJsonError(rw http.ResponseWriter, err error) {
rw.Header().Set("Cache-Control", "no-store")
rw.Header().Set("Pragma", "no-cache")

rfcerr := ErrorToRFC6749Error(err)
rfcerr := ErrorToRFC6749Error(err).WithLegacyFormat(f.UseLegacyErrorFormat)
if !f.SendDebugMessagesToClients {
rfcerr = rfcerr.Sanitize()
}
Expand Down
49 changes: 30 additions & 19 deletions access_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,22 @@ func TestWriteAccessError_RFC6749(t *testing.T) {
code string
debug bool
expectDebugMessage string
includeExtraFields bool
}{
{ErrInvalidRequest.WithDebug("some-debug"), "invalid_request", true, "some-debug"},
{ErrInvalidRequest.WithDebugf("some-debug-%d", 1234), "invalid_request", true, "some-debug-1234"},
{ErrInvalidRequest.WithDebug("some-debug"), "invalid_request", false, "some-debug"},
{ErrInvalidClient.WithDebug("some-debug"), "invalid_client", false, "some-debug"},
{ErrInvalidGrant.WithDebug("some-debug"), "invalid_grant", false, "some-debug"},
{ErrInvalidScope.WithDebug("some-debug"), "invalid_scope", false, "some-debug"},
{ErrUnauthorizedClient.WithDebug("some-debug"), "unauthorized_client", false, "some-debug"},
{ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", false, "some-debug"},
{ErrInvalidRequest.WithDebug("some-debug"), "invalid_request", true, "some-debug", true},
{ErrInvalidRequest.WithDebugf("some-debug-%d", 1234), "invalid_request", true, "some-debug-1234", true},
{ErrInvalidRequest.WithDebug("some-debug"), "invalid_request", false, "some-debug", true},
{ErrInvalidClient.WithDebug("some-debug"), "invalid_client", false, "some-debug", true},
{ErrInvalidGrant.WithDebug("some-debug"), "invalid_grant", false, "some-debug", true},
{ErrInvalidScope.WithDebug("some-debug"), "invalid_scope", false, "some-debug", true},
{ErrUnauthorizedClient.WithDebug("some-debug"), "unauthorized_client", false, "some-debug", true},
{ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", false, "some-debug", true},
{ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", false, "some-debug", false},
{ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", true, "some-debug", false},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
f.SendDebugMessagesToClients = c.debug
f.UseLegacyErrorFormat = c.includeExtraFields

rw := httptest.NewRecorder()
f.WriteAccessError(rw, nil, c.err)
Expand All @@ -88,19 +92,26 @@ func TestWriteAccessError_RFC6749(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, c.code, params.Error)
assert.Equal(t, c.err.HintField, params.Hint)

expectDescription := c.err.DescriptionField
if c.err.HintField != "" {
expectDescription += " " + c.err.HintField
}

if !c.debug {
assert.Equal(t, expectDescription, params.Description)
if !c.includeExtraFields {
assert.Empty(t, params.Debug)
assert.Empty(t, params.Hint)
assert.Contains(t, params.Description, c.err.DescriptionField)
assert.Contains(t, params.Description, c.err.HintField)

if c.debug {
assert.Contains(t, params.Description, c.err.DebugField)
} else {
assert.NotContains(t, params.Description, c.err.DebugField)
}
} else {
assert.Equal(t, expectDescription+" "+c.expectDebugMessage, params.Description)
assert.Equal(t, c.expectDebugMessage, params.Debug)
assert.EqualValues(t, c.err.DescriptionField, params.Description)
assert.EqualValues(t, c.err.HintField, params.Hint)

if !c.debug {
assert.Empty(t, params.Debug)
} else {
assert.EqualValues(t, c.err.DebugField, params.Debug)
}
}
})
}
Expand Down
18 changes: 9 additions & 9 deletions authorize_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest
rw.Header().Set("Cache-Control", "no-store")
rw.Header().Set("Pragma", "no-cache")

rfcerr := ErrorToRFC6749Error(err)
rfcerr := ErrorToRFC6749Error(err).WithLegacyFormat(f.UseLegacyErrorFormat)
if !f.SendDebugMessagesToClients {
rfcerr = rfcerr.Sanitize()
}
Expand Down Expand Up @@ -60,26 +60,26 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest
// The endpoint URI MUST NOT include a fragment component.
redirectURI.Fragment = ""

query := rfcerr.ToValues()
query.Add("state", ar.GetState())
errors := rfcerr.ToValues()
errors.Set("state", ar.GetState())

var redirectURIString string
if ar.GetResponseMode() == ResponseModeFormPost {
rw.Header().Add("Content-Type", "text/html;charset=UTF-8")
WriteAuthorizeFormPostResponse(redirectURI.String(), query, GetPostFormHTMLTemplate(*f), rw)
rw.Header().Set("Content-Type", "text/html;charset=UTF-8")
WriteAuthorizeFormPostResponse(redirectURI.String(), errors, GetPostFormHTMLTemplate(*f), rw)
return
} else if ar.GetResponseMode() == ResponseModeFragment {
redirectURIString = redirectURI.String() + "#" + query.Encode()
redirectURIString = redirectURI.String() + "#" + errors.Encode()
} else {
for key, values := range redirectURI.Query() {
for _, value := range values {
query.Add(key, value)
errors.Add(key, value)
}
}
redirectURI.RawQuery = query.Encode()
redirectURI.RawQuery = errors.Encode()
redirectURIString = redirectURI.String()
}

rw.Header().Add("Location", redirectURIString)
rw.Header().Set("Location", redirectURIString)
rw.WriteHeader(http.StatusFound)
}
Loading

0 comments on commit edbbda3

Please sign in to comment.