Skip to content

Commit

Permalink
feat: implement tests and anti-csrf for API settings flows
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 25, 2020
1 parent a4e3bc5 commit 8b8b6e5
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 78 deletions.
3 changes: 3 additions & 0 deletions selfservice/strategy/password/.schema/settings.schema.json
Expand Up @@ -6,6 +6,9 @@
"password"
],
"properties": {
"csrf_token": {
"type": "string"
},
"password": {
"type": "string",
"minLength": 1
Expand Down
18 changes: 9 additions & 9 deletions selfservice/strategy/password/login_test.go
Expand Up @@ -179,7 +179,7 @@ func TestCompleteLogin(t *testing.T) {
t.Run("should show the error ui because the request is malformed", func(t *testing.T) {
run := func(t *testing.T, isAPI bool) (string, *http.Response) {
lr := nlr(0, isAPI)
res, body := fakeRequest(t, lr, isAPI, "14=)=!(%)$/ZP()GHIÖ", nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
res, body := fakeRequest(t, lr, isAPI, "14=)=!(%)$/ZP()GHIÖ", nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))

assert.Equal(t, lr.ID.String(), gjson.GetBytes(body, "id").String(), "%s", body)
assert.Equal(t, "/action", gjson.GetBytes(body, "methods.password.config.action").String(), "%s", body)
Expand All @@ -202,7 +202,7 @@ func TestCompleteLogin(t *testing.T) {
t.Run("should show the error ui because the request id missing", func(t *testing.T) {
run := func(t *testing.T, isAPI bool) (*http.Response, []byte) {
lr := nlr(time.Minute, isAPI)
return fakeRequest(t, lr, isAPI, url.Values{}.Encode(), pointerx.String(""), nil, expectStatusCode(isAPI, http.StatusBadRequest))
return fakeRequest(t, lr, isAPI, url.Values{}.Encode(), pointerx.String(""), nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))
}

t.Run("type=browser", func(t *testing.T) {
Expand All @@ -225,7 +225,7 @@ func TestCompleteLogin(t *testing.T) {
t.Run("should return an error because the request does not exist", func(t *testing.T) {
run := func(t *testing.T, isAPI bool, payload string) (*http.Response, []byte) {
lr := nlr(0, isAPI)
return fakeRequest(t, lr, isAPI, payload, pointerx.String(x.NewUUID().String()), nil, expectStatusCode(isAPI, http.StatusNotFound))
return fakeRequest(t, lr, isAPI, payload, pointerx.String(x.NewUUID().String()), nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusNotFound))
}

t.Run("type=browser", func(t *testing.T) {
Expand Down Expand Up @@ -277,7 +277,7 @@ func TestCompleteLogin(t *testing.T) {
t.Run("should return an error because the credentials are invalid (user does not exist)", func(t *testing.T) {
run := func(t *testing.T, isAPI bool, payload string) *http.Response {
lr := nlr(time.Hour, isAPI)
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))
assert.Equal(t, lr.ID.String(), gjson.GetBytes(body, "id").String(), "%s", body)
assert.Equal(t, "/action", gjson.GetBytes(body, "methods.password.config.action").String())
assert.Equal(t, text.NewErrorValidationInvalidCredentials().Text, gjson.GetBytes(body, "methods.password.config.messages.0.text").String())
Expand Down Expand Up @@ -317,7 +317,7 @@ func TestCompleteLogin(t *testing.T) {
t.Run("should return an error because no identifier is set", func(t *testing.T) {
run := func(t *testing.T, isAPI bool, payload string) *http.Response {
lr := nlr(time.Hour, isAPI)
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))

// Let's ensure that the payload is being propagated properly.
assert.Equal(t, lr.ID.String(), gjson.GetBytes(body, "id").String())
Expand Down Expand Up @@ -345,7 +345,7 @@ func TestCompleteLogin(t *testing.T) {
t.Run("should return an error because no password is set", func(t *testing.T) {
run := func(t *testing.T, isAPI bool, payload string) *http.Response {
lr := nlr(time.Hour, isAPI)
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))

// Let's ensure that the payload is being propagated properly.
assert.Equal(t, lr.ID.String(), gjson.GetBytes(body, "id").String())
Expand Down Expand Up @@ -387,7 +387,7 @@ func TestCompleteLogin(t *testing.T) {
}

lr := nlr(time.Hour, isAPI)
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))

assert.Equal(t, lr.ID.String(), gjson.GetBytes(body, "id").String())
assert.Equal(t, "/action", gjson.GetBytes(body, "methods.password.config.action").String())
Expand Down Expand Up @@ -426,7 +426,7 @@ func TestCompleteLogin(t *testing.T) {
}

lr := nlr(time.Hour, isAPI)
return fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusOK))
return fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusOK))
}

t.Run("type=browser", func(t *testing.T) {
Expand Down Expand Up @@ -578,7 +578,7 @@ func TestCompleteLogin(t *testing.T) {
Identifier: identifier})
}

res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
res, body := fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))
if isAPI {
require.Contains(t, res.Request.URL.Path, password.RouteLogin)
checkFormContent(t, body, "identifier", "password")
Expand Down
16 changes: 8 additions & 8 deletions selfservice/strategy/password/registration_test.go
Expand Up @@ -143,7 +143,7 @@ func TestRegistration(t *testing.T) {
t.Run("case=should show the error ui because the request payload is malformed", func(t *testing.T) {
run := func(t *testing.T, isAPI bool) (*registration.Flow, []byte, *http.Response) {
rr := newRegistrationRequest(t, time.Minute, isAPI)
body, res := makeRequest(t, rr.ID, isAPI, "14=)=!(%)$/ZP()GHIÖ", expectStatusCode(isAPI, http.StatusBadRequest))
body, res := makeRequest(t, rr.ID, isAPI, "14=)=!(%)$/ZP()GHIÖ", expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))
return rr, body, res
}

Expand All @@ -166,7 +166,7 @@ func TestRegistration(t *testing.T) {
run := func(t *testing.T, isAPI bool) ([]byte, *http.Response) {
_ = newRegistrationRequest(t, time.Minute, isAPI)
uuidDesNotExistInStore := x.NewUUID()
return makeRequest(t, uuidDesNotExistInStore, isAPI, "", expectStatusCode(isAPI, http.StatusNotFound))
return makeRequest(t, uuidDesNotExistInStore, isAPI, "", expectStatusCodeBrowserOKOr(isAPI, http.StatusNotFound))
}

t.Run("type=api", func(t *testing.T) {
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestRegistration(t *testing.T) {
t.Run("case=should return an error because the password failed validation", func(t *testing.T) {
run := func(t *testing.T, isAPI bool, payload string) *http.Response {
rr := newRegistrationRequest(t, time.Minute, isAPI)
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCode(isAPI, http.StatusBadRequest))
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))
assert.Equal(t, rr.ID.String(), gjson.GetBytes(body, "id").String(), "%s", body)
assert.Equal(t, "/action", gjson.GetBytes(body, "methods.password.config.action").String(), "%s", body)
checkFormContent(t, body, "password", "csrf_token", "traits.username", "traits.foobar")
Expand All @@ -235,7 +235,7 @@ func TestRegistration(t *testing.T) {
t.Run("case=should return an error because not passing validation", func(t *testing.T) {
run := func(t *testing.T, isAPI bool, payload string) *http.Response {
rr := newRegistrationRequest(t, time.Minute, isAPI)
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCode(isAPI, http.StatusBadRequest))
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))
assert.Equal(t, rr.ID.String(), gjson.GetBytes(body, "id").String(), "%s", body)
assert.Equal(t, "/action", gjson.GetBytes(body, "methods.password.config.action").String(), "%s", body)
checkFormContent(t, body, "password", "csrf_token", "traits.username", "traits.foobar")
Expand All @@ -261,7 +261,7 @@ func TestRegistration(t *testing.T) {
viper.Set(configuration.ViperKeyDefaultIdentitySchemaURL, "file://./stub/missing-identifier.schema.json")
run := func(t *testing.T, isAPI bool, payload string) ([]byte, *http.Response) {
rr := newRegistrationRequest(t, time.Minute, isAPI)
return makeRequest(t, rr.ID, isAPI, payload, expectStatusCode(isAPI, http.StatusInternalServerError))
return makeRequest(t, rr.ID, isAPI, payload, expectStatusCodeBrowserOKOr(isAPI, http.StatusInternalServerError))
}

t.Run("type=api", func(t *testing.T) {
Expand Down Expand Up @@ -313,7 +313,7 @@ func TestRegistration(t *testing.T) {

run := func(t *testing.T, isAPI bool, payload string) ([]byte, *http.Response) {
rr := newRegistrationRequest(t, time.Minute, isAPI)
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCode(isAPI, http.StatusInternalServerError))
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCodeBrowserOKOr(isAPI, http.StatusInternalServerError))
return body, res
}

Expand Down Expand Up @@ -437,7 +437,7 @@ func TestRegistration(t *testing.T) {
}

require.NoError(t, reg.RegistrationFlowPersister().CreateRegistrationFlow(context.Background(), rr))
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCode(isAPI, http.StatusBadRequest))
body, res := makeRequest(t, rr.ID, isAPI, payload, expectStatusCodeBrowserOKOr(isAPI, http.StatusBadRequest))

assert.Equal(t, rr.ID.String(), gjson.GetBytes(body, "id").String(), "%s", body)
assert.Equal(t, "/action", gjson.GetBytes(body, "methods.password.config.action").String(), "%s", body)
Expand Down Expand Up @@ -532,7 +532,7 @@ func TestRegistration(t *testing.T) {
false, payload, http.StatusOK, jar)

body2, res2 := makeRequestWithCookieJar(t, newRegistrationRequest(t, time.Minute, false).ID,
false, payload, expectStatusCode(false, http.StatusBadRequest), jar)
false, payload, expectStatusCodeBrowserOKOr(false, http.StatusBadRequest), jar)

assert.Contains(t, res1.Request.URL.String(), redirTS.URL+"/registration-return-ts")
assert.Contains(t, res2.Request.URL.String(), redirTS.URL+"/default-return-to")
Expand Down
34 changes: 27 additions & 7 deletions selfservice/strategy/password/settings.go
Expand Up @@ -26,6 +26,7 @@ const (
)

func (s *Strategy) RegisterSettingsRoutes(router *x.RouterPublic) {
s.d.CSRFHandler().ExemptPath(RouteSettings)
router.POST(RouteSettings, s.submitSettingsFlow)
router.GET(RouteSettings, s.submitSettingsFlow)
}
Expand All @@ -36,24 +37,38 @@ func (s *Strategy) SettingsStrategyID() string {

// swagger:parameters completeSelfServiceSettingsFlowWithPasswordMethod
type completeSelfServiceSettingsFlowWithPasswordMethod struct {
// in: body
Payload SettingsFlowPayload

// Flow is flow ID.
//
// in: query
Flow string `json:"flow"`
}

type SettingsFlowPayload struct {
// Password is the updated password
//
// type: string
// in: body
// required: true
Password string `json:"password"`

// CSRFToken is the anti-CSRF token
//
// type: string
CSRFToken string `json:"csrf_token"`

// Flow is flow ID.
//
// in: query
// swagger:ignore
Flow string `json:"flow"`
}

func (p *completeSelfServiceSettingsFlowWithPasswordMethod) GetFlowID() uuid.UUID {
func (p *SettingsFlowPayload) GetFlowID() uuid.UUID {
return x.ParseUUID(p.Flow)
}

func (p *completeSelfServiceSettingsFlowWithPasswordMethod) SetFlowID(rid uuid.UUID) {
func (p *SettingsFlowPayload) SetFlowID(rid uuid.UUID) {
p.Flow = rid.String()
}

Expand All @@ -78,7 +93,7 @@ func (p *completeSelfServiceSettingsFlowWithPasswordMethod) SetFlowID(rid uuid.U
// 302: emptyResponse
// 500: genericError
func (s *Strategy) submitSettingsFlow(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
var p completeSelfServiceSettingsFlowWithPasswordMethod
var p SettingsFlowPayload
ctxUpdate, err := settings.PrepareUpdate(s.d, w, r, settings.ContinuityKey(s.SettingsStrategyID()), &p)
if errors.Is(err, settings.ErrContinuePreviousAction) {
s.continueSettingsFlow(w, r, ctxUpdate, &p)
Expand Down Expand Up @@ -112,8 +127,13 @@ func (s *Strategy) decodeSettingsFlow(r *http.Request, dest interface{}) error {

func (s *Strategy) continueSettingsFlow(
w http.ResponseWriter, r *http.Request,
ctxUpdate *settings.UpdateContext, p *completeSelfServiceSettingsFlowWithPasswordMethod,
ctxUpdate *settings.UpdateContext, p *SettingsFlowPayload,
) {
if err := flow.VerifyRequest(r,ctxUpdate.Flow.Type,s.d.GenerateCSRFToken,p.CSRFToken); err != nil {
s.handleSettingsError(w, r, ctxUpdate, p, err)
return
}

if ctxUpdate.Session.AuthenticatedAt.Add(s.c.SelfServiceFlowSettingsPrivilegedSessionMaxAge()).Before(time.Now()) {
s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(settings.ErrRequestNeedsReAuthentication))
return
Expand Down Expand Up @@ -176,7 +196,7 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, _ *identity.Identity,
return nil
}

func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *completeSelfServiceSettingsFlowWithPasswordMethod, err error) {
func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, p *SettingsFlowPayload, err error) {
// Do not pause flow if the flow type is an API flow as we can't save cookies in those flows.
if errors.Is(err, settings.ErrRequestNeedsReAuthentication) && ctxUpdate.Flow != nil && ctxUpdate.Flow.Type == flow.TypeBrowser {
if err := s.d.ContinuityManager().Pause(r.Context(), w, r,
Expand Down

0 comments on commit 8b8b6e5

Please sign in to comment.