Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ResponseModeHandler to support custom response modes #592

Merged
merged 2 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions authorize_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest
rw.Header().Set("Cache-Control", "no-store")
rw.Header().Set("Pragma", "no-cache")

if f.ResponseModeHandler().ResponseModes().Has(ar.GetResponseMode()) {
f.ResponseModeHandler().WriteAuthorizeError(rw, ar, err)
return
}

rfcerr := ErrorToRFC6749Error(err).WithLegacyFormat(f.UseLegacyErrorFormat).WithExposeDebug(f.SendDebugMessagesToClients)
if !ar.IsRedirectURIValid() {
rw.Header().Set("Content-Type", "application/json;charset=UTF-8")
Expand Down
3 changes: 2 additions & 1 deletion authorize_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestWriteAuthorizeError(t *testing.T) {
err: ErrInvalidGrant,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) {
req.EXPECT().IsRedirectURIValid().Return(false)
req.EXPECT().GetResponseMode().Return(ResponseModeDefault)
rw.EXPECT().Header().Times(3).Return(header)
rw.EXPECT().WriteHeader(http.StatusBadRequest)
rw.EXPECT().Write(gomock.Any())
Expand Down Expand Up @@ -427,7 +428,7 @@ func TestWriteAuthorizeError(t *testing.T) {
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
req.EXPECT().GetResponseTypes().AnyTimes().Return(Arguments([]string{"token"}))
req.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(1)
req.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(2)
rw.EXPECT().Header().Times(3).Return(header)
rw.EXPECT().Write(gomock.Any()).AnyTimes()
},
Expand Down
5 changes: 5 additions & 0 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) e
case string(ResponseModeFormPost):
request.ResponseMode = ResponseModeFormPost
default:
rm := ResponseModeType(responseMode)
if f.ResponseModeHandler().ResponseModes().Has(rm) {
request.ResponseMode = ResponseModeType(rm)
break
}
return errorsx.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", responseMode))
}

Expand Down
7 changes: 6 additions & 1 deletion authorize_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ
wh.Set("Pragma", "no-cache")

redir := ar.GetRedirectURI()
switch ar.GetResponseMode() {
switch rm := ar.GetResponseMode(); rm {
case ResponseModeFormPost:
//form_post
rw.Header().Add("Content-Type", "text/html;charset=UTF-8")
Expand All @@ -60,6 +60,11 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ
URLSetFragment(redir, resp.GetParameters())
sendRedirect(redir.String(), rw)
return
default:
if f.ResponseModeHandler().ResponseModes().Has(rm) {
f.ResponseModeHandler().WriteAuthorizeResponse(rw, ar, resp)
return
}
}
}

Expand Down
1 change: 1 addition & 0 deletions compose/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func Compose(config *Config, storage interface{}, strategy interface{}, hasher f
MinParameterEntropy: config.GetMinParameterEntropy(),
UseLegacyErrorFormat: config.UseLegacyErrorFormat,
ClientAuthenticationStrategy: config.GetClientAuthenticationStrategy(),
ResponseModeHandlerExtension: config.ResponseModeHandlerExtension,
}

for _, factory := range factories {
Expand Down
3 changes: 3 additions & 0 deletions compose/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ type Config struct {

// ClientAuthenticationStrategy indicates the Strategy to authenticate client requests
ClientAuthenticationStrategy fosite.ClientAuthenticationStrategy

// ResponseModeHandlerExtension provides a handler for custom response modes
ResponseModeHandlerExtension fosite.ResponseModeHandler
}

// GetScopeStrategy returns the scope strategy to be used. Defaults to glob scope strategy.
Expand Down
11 changes: 11 additions & 0 deletions fosite.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ type Fosite struct {

// ClientAuthenticationStrategy provides an extension point to plug a strategy to authenticate clients
ClientAuthenticationStrategy ClientAuthenticationStrategy

ResponseModeHandlerExtension ResponseModeHandler
}

const MinParameterEntropy = 8
Expand All @@ -125,3 +127,12 @@ func (f *Fosite) GetMinParameterEntropy() int {
return f.MinParameterEntropy
}
}

var defaultResponseModeHandler = &DefaultResponseModeHandler{}

func (f *Fosite) ResponseModeHandler() ResponseModeHandler {
if f.ResponseModeHandlerExtension == nil {
return defaultResponseModeHandler
}
return f.ResponseModeHandlerExtension
}
109 changes: 79 additions & 30 deletions integration/authorize_form_post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package integration_test
import (
"fmt"
"net/http"
"net/url"
"strings"
"testing"

Expand All @@ -40,6 +41,15 @@ import (
"github.com/ory/fosite/compose"
)

type formPostTestCase struct {
description string
setup func()
check checkFunc
responseType string
}

type checkFunc func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string)

func TestAuthorizeFormPostResponseMode(t *testing.T) {
session := &defaultSession{
DefaultSession: &openid.DefaultSession{
Expand All @@ -49,7 +59,8 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
Headers: &jwt.Headers{},
},
}
f := compose.ComposeAllEnabled(new(compose.Config), fositeStore, []byte("some-secret-thats-random-some-secret-thats-random-"), internal.MustRSAKey())
config := &compose.Config{ResponseModeHandlerExtension: &decoratedFormPostResponse{}}
f := compose.ComposeAllEnabled(config, fositeStore, []byte("some-secret-thats-random-some-secret-thats-random-"), internal.MustRSAKey())
ts := mockServer(t, f, session)
defer ts.Close()

Expand All @@ -58,26 +69,21 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
defaultClient.RedirectURIs[0] = ts.URL + "/callback"
responseModeClient := &fosite.DefaultResponseModeClient{
DefaultClient: defaultClient,
ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeFormPost},
ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeFormPost, fosite.ResponseModeFormPost, "decorated_form_post"},
}
fositeStore.Clients["response-mode-client"] = responseModeClient
oauthClient.ClientID = "response-mode-client"

var state string
for k, c := range []struct {
description string
setup func()
check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string)
responseType string
}{
for k, c := range []formPostTestCase{
{
description: "implicit grant #1 test with form_post",
responseType: "id_token%20token",
setup: func() {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, token.TokenType)
assert.NotEmpty(t, token.AccessToken)
Expand All @@ -92,7 +98,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, iDToken)
},
Expand All @@ -103,7 +109,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
setup: func() {
state = "12345678901234567890"
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
},
Expand All @@ -115,7 +121,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
assert.NotEmpty(t, token.TokenType)
Expand All @@ -130,7 +136,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
assert.NotEmpty(t, iDToken)
Expand All @@ -146,7 +152,7 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
state = "12345678901234567890"
oauthClient.Scopes = []string{"openid"}
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, code)
assert.NotEmpty(t, iDToken)
Expand All @@ -158,27 +164,70 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) {
setup: func() {
state = "12345678901234567890"
},
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) {
check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
assert.EqualValues(t, state, stateFromServer)
assert.NotEmpty(t, err["ErrorField"])
assert.NotEmpty(t, err["DescriptionField"])
},
},
} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) {
c.setup()
authURL := strings.Replace(oauthClient.AuthCodeURL(state, goauth.SetAuthURLParam("response_mode", "form_post"), goauth.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return errors.New("Dont follow redirects")
},
}
resp, err := client.Get(authURL)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
code, state, token, iDToken, _, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body)
require.NoError(t, err)
c.check(t, state, code, iDToken, token, errResp)
})
// Test canonical form_post
t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), testFormPost(&state, false, c, oauthClient, "form_post"))

// Test decorated form_post response
c.check = decorateCheck(c.check)
t.Run(fmt.Sprintf("case=%d/description=decorated_%s", k, c.description), testFormPost(&state, true, c, oauthClient, "decorated_form_post"))
}
}

func testFormPost(state *string, customResponse bool, c formPostTestCase, oauthClient *goauth.Config, responseMode string) func(t *testing.T) {
return func(t *testing.T) {
c.setup()
authURL := strings.Replace(oauthClient.AuthCodeURL(*state, goauth.SetAuthURLParam("response_mode", responseMode), goauth.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return errors.New("Dont follow redirects")
},
}
resp, err := client.Get(authURL)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
code, state, token, iDToken, cparam, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body)
require.NoError(t, err)
c.check(t, state, code, iDToken, token, cparam, errResp)
}
}

func decorateCheck(cf checkFunc) checkFunc {
return func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, cparam url.Values, err map[string]string) {
cf(t, stateFromServer, code, token, iDToken, cparam, err)
if len(err) > 0 {
assert.Contains(t, cparam, "custom_err_param")
return
}
assert.Contains(t, cparam, "custom_param")
}
}

// This test type provides an example implementation
// of a custom response mode handler.
// In this case it decorates the `form_post` response mode
// with some additional custom parameters
type decoratedFormPostResponse struct {
}

func (m *decoratedFormPostResponse) ResponseModes() fosite.ResponseModeTypes {
return fosite.ResponseModeTypes{"decorated_form_post"}
}
func (m *decoratedFormPostResponse) WriteAuthorizeResponse(rw http.ResponseWriter, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) {
rw.Header().Add("Content-Type", "text/html;charset=UTF-8")
resp.AddParameter("custom_param", "foo")
fosite.WriteAuthorizeFormPostResponse(ar.GetRedirectURI().String(), resp.GetParameters(), fosite.GetPostFormHTMLTemplate(fosite.Fosite{}), rw)
}
func (m *decoratedFormPostResponse) WriteAuthorizeError(rw http.ResponseWriter, ar fosite.AuthorizeRequester, err error) {
rfcerr := fosite.ErrorToRFC6749Error(err)
errors := rfcerr.ToValues()
errors.Set("state", ar.GetState())
errors.Add("custom_err_param", "bar")
fosite.WriteAuthorizeFormPostResponse(ar.GetRedirectURI().String(), errors, fosite.GetPostFormHTMLTemplate(fosite.Fosite{}), rw)
}
47 changes: 47 additions & 0 deletions response_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package fosite

import "net/http"

// ResponseModeHandler provides a contract for handling custom response modes
type ResponseModeHandler interface {
// ResponseModes returns a set of supported response modes handled
// by the interface implementation.
//
// In an authorize request with any of the provide response modes
// methods `WriteAuthorizeResponse` and `WriteAuthorizeError` will be
// invoked to write the successful or error authorization responses respectively.
ResponseModes() ResponseModeTypes

// WriteAuthorizeResponse writes successful responses
//
// Following headers are expected to be set by default:
// header.Set("Cache-Control", "no-store")
// header.Set("Pragma", "no-cache")
WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder)

// WriteAuthorizeError writes error responses
//
// Following headers are expected to be set by default:
// header.Set("Cache-Control", "no-store")
// header.Set("Pragma", "no-cache")
WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequester, err error)
}

type ResponseModeTypes []ResponseModeType

func (rs ResponseModeTypes) Has(item ResponseModeType) bool {
for _, r := range rs {
if r == item {
return true
}
}
return false
}

type DefaultResponseModeHandler struct{}

func (d *DefaultResponseModeHandler) ResponseModes() ResponseModeTypes { return nil }
func (d *DefaultResponseModeHandler) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) {
}
func (d *DefaultResponseModeHandler) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequester, err error) {
}