Skip to content

Commit

Permalink
feat(op): Add response_mode: form_post (#551)
Browse files Browse the repository at this point in the history
* feat(op): Add response_mode: form_post

* Fix to parse the template ahead of time

* Fix to render the template in a buffer

* Remove unnecessary import

* Fix test

* Fix example client setting

* Make sure the client not to reuse the content of the response

* Fix error handling

* Add the response_mode param

* Allow implicit flow in the example app

* feat(rp): allow form_post in code exchange callback handler

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
  • Loading branch information
ay4toh5i and muhlemmer committed Mar 5, 2024
1 parent fc743a6 commit 5ef597b
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 11 deletions.
15 changes: 14 additions & 1 deletion example/client/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func main() {
issuer := os.Getenv("ISSUER")
port := os.Getenv("PORT")
scopes := strings.Split(os.Getenv("SCOPES"), " ")
responseMode := os.Getenv("RESPONSE_MODE")

redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
Expand Down Expand Up @@ -77,12 +78,24 @@ func main() {
return uuid.New().String()
}

urlOptions := []rp.URLParamOpt{
rp.WithPromptURLParam("Welcome back!"),
}

if responseMode != "" {
urlOptions = append(urlOptions, rp.WithResponseModeURLParam(oidc.ResponseMode(responseMode)))
}

// register the AuthURLHandler at your preferred path.
// the AuthURLHandler creates the auth request and redirects the user to the auth server.
// including state handling with secure cookie and the possibility to use PKCE.
// Prompts can optionally be set to inform the server of
// any messages that need to be prompted back to the user.
http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!")))
http.Handle("/login", rp.AuthURLHandler(
state,
provider,
urlOptions...,
))

// for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
Expand Down
4 changes: 2 additions & 2 deletions example/server/storage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
applicationType: op.ApplicationTypeWeb,
authMethod: oidc.AuthMethodBasic,
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode, oidc.ResponseTypeIDTokenOnly, oidc.ResponseTypeIDToken},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange},
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
devMode: true,
idTokenUserinfoClaimsAssertion: false,
clockSkew: 0,
}
Expand Down
4 changes: 3 additions & 1 deletion example/server/storage/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type AuthRequest struct {
UserID string
Scopes []string
ResponseType oidc.ResponseType
ResponseMode oidc.ResponseMode
Nonce string
CodeChallenge *OIDCCodeChallenge

Expand Down Expand Up @@ -100,7 +101,7 @@ func (a *AuthRequest) GetResponseType() oidc.ResponseType {
}

func (a *AuthRequest) GetResponseMode() oidc.ResponseMode {
return "" // we won't handle response mode in this example
return a.ResponseMode
}

func (a *AuthRequest) GetScopes() []string {
Expand Down Expand Up @@ -154,6 +155,7 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques
UserID: userID,
Scopes: authReq.Scopes,
ResponseType: authReq.ResponseType,
ResponseMode: authReq.ResponseMode,
Nonce: authReq.Nonce,
CodeChallenge: &OIDCCodeChallenge{
Challenge: authReq.CodeChallenge,
Expand Down
7 changes: 3 additions & 4 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,8 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
return
}
params := r.URL.Query()
if params.Get("error") != "" {
rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state)
if errValue := r.FormValue("error"); errValue != "" {
rp.ErrorHandler()(w, r, errValue, r.FormValue("error_description"), state)
return
}
codeOpts := make([]CodeExchangeOpt, len(urlParam))
Expand All @@ -521,7 +520,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
}
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
}
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
tokens, err := CodeExchange[C](r.Context(), r.FormValue("code"), rp, codeOpts...)
if err != nil {
unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)
return
Expand Down
1 change: 1 addition & 0 deletions pkg/oidc/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (

ResponseModeQuery ResponseMode = "query"
ResponseModeFragment ResponseMode = "fragment"
ResponseModeFormPost ResponseMode = "form_post"

// PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages.
// An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed
Expand Down
62 changes: 62 additions & 0 deletions pkg/op/auth_request.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package op

import (
"bytes"
"context"
_ "embed"
"errors"
"fmt"
"html/template"
"log/slog"
"net"
"net/http"
Expand Down Expand Up @@ -464,6 +467,17 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
Code: code,
State: authReq.GetState(),
}

if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
err := AuthResponseFormPost(w, authReq.GetRedirectURI(), &codeResponse, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
return
}

return
}

callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
Expand All @@ -484,6 +498,17 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
AuthRequestError(w, r, authReq, err, authorizer)
return
}

if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
err := AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
return
}

return
}

callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
Expand Down Expand Up @@ -535,6 +560,43 @@ func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, respons
return mergeQueryParams(uri, params), nil
}

//go:embed form_post.html.tmpl
var formPostHtmlTemplate string

var formPostTmpl = template.Must(template.New("form_post").Parse(formPostHtmlTemplate))

// AuthResponseFormPost responds a html page that automatically submits the form which contains the auth response parameters
func AuthResponseFormPost(res http.ResponseWriter, redirectURI string, response any, encoder httphelper.Encoder) error {
values := make(map[string][]string)
err := encoder.Encode(response, values)
if err != nil {
return oidc.ErrServerError().WithParent(err)
}

params := &struct {
RedirectURI string
Params any
}{
RedirectURI: redirectURI,
Params: values,
}

var buf bytes.Buffer
err = formPostTmpl.Execute(&buf, params)
if err != nil {
return oidc.ErrServerError().WithParent(err)
}

res.Header().Set("Cache-Control", "no-store")
res.WriteHeader(http.StatusOK)
_, err = buf.WriteTo(res)
if err != nil {
return oidc.ErrServerError().WithParent(err)
}

return nil
}

func setFragment(uri *url.URL, params url.Values) string {
uri.Fragment = params.Encode()
return uri.String()
Expand Down
35 changes: 32 additions & 3 deletions pkg/op/auth_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1027,9 +1027,10 @@ func TestAuthResponseCode(t *testing.T) {
authorizer func(*testing.T) op.Authorizer
}
type res struct {
wantCode int
wantLocationHeader string
wantBody string
wantCode int
wantLocationHeader string
wantCacheControlHeader string
wantBody string
}
tests := []struct {
name string
Expand Down Expand Up @@ -1111,6 +1112,33 @@ func TestAuthResponseCode(t *testing.T) {
wantBody: "",
},
},
{
name: "success form_post",
args: args{
authReq: &storage.AuthRequest{
ID: "id1",
CallbackURI: "https://example.com/callback",
TransferState: "state1",
ResponseMode: "form_post",
},
authorizer: func(t *testing.T) op.Authorizer {
ctrl := gomock.NewController(t)
storage := mock.NewMockStorage(ctrl)
storage.EXPECT().SaveAuthCode(context.Background(), "id1", "id1")

authorizer := mock.NewMockAuthorizer(ctrl)
authorizer.EXPECT().Storage().Return(storage)
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
return authorizer
},
},
res: res{
wantCode: http.StatusOK,
wantCacheControlHeader: "no-store",
wantBody: "<!doctype html>\n<html>\n<head><meta charset=\"UTF-8\" /></head>\n<body onload=\"javascript:document.forms[0].submit()\">\n<form method=\"post\" action=\"https://example.com/callback\">\n<input type=\"hidden\" name=\"state\" value=\"state1\"/>\n<input type=\"hidden\" name=\"code\" value=\"id1\" />\n\n\n\n\n</form>\n</body>\n</html>",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -1121,6 +1149,7 @@ func TestAuthResponseCode(t *testing.T) {
defer resp.Body.Close()
assert.Equal(t, tt.res.wantCode, resp.StatusCode)
assert.Equal(t, tt.res.wantLocationHeader, resp.Header.Get("Location"))
assert.Equal(t, tt.res.wantCacheControlHeader, resp.Header.Get("Cache-Control"))
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, tt.res.wantBody, string(body))
Expand Down
14 changes: 14 additions & 0 deletions pkg/op/form_post.html.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<!doctype html>
<html>
<head><meta charset="UTF-8" /></head>
<body onload="javascript:document.forms[0].submit()">
<form method="post" action="{{ .RedirectURI }}">
{{with .Params.state}}<input type="hidden" name="state" value="{{ index . 0 }}"/>{{end}}
{{with .Params.code}}<input type="hidden" name="code" value="{{ index . 0 }}" />{{end}}
{{with .Params.id_token}}<input type="hidden" name="id_token" value="{{ index . 0 }}"/>{{end}}
{{with .Params.access_token}}<input type="hidden" name="access_token" value="{{ index . 0 }}" />{{end}}
{{with .Params.token_type}}<input type="hidden" name="token_type" value="{{ index . 0 }}" />{{end}}
{{with .Params.expires_in}}<input type="hidden" name="expires_in" value="{{ index . 0 }}" />{{end}}
</form>
</body>
</html>

0 comments on commit 5ef597b

Please sign in to comment.