Skip to content

Commit

Permalink
fix: ignore CSRF for session extension on public route
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas authored and aeneasr committed Aug 19, 2022
1 parent 576f9c0 commit 866b472
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
h.r.CSRFHandler().IgnorePath(RouteWhoami)
h.r.CSRFHandler().IgnorePath(RouteCollection)
h.r.CSRFHandler().IgnoreGlob(RouteCollection + "/*")
h.r.CSRFHandler().IgnoreGlob(RouteCollection + "/*/extend")
h.r.CSRFHandler().IgnoreGlob(AdminRouteIdentity + "/*/sessions")

for _, m := range []string{http.MethodGet, http.MethodHead, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodTrace} {
Expand Down
29 changes: 20 additions & 9 deletions session/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,21 +687,22 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {

func TestHandlerRefreshSessionBySessionID(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
_, ts, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
publicServer, adminServer, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)

// set this intermediate because kratos needs some valid url for CRUDE operations
conf.MustSet(config.ViperKeyPublicBaseURL, "http://example.com")
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
conf.MustSet(config.ViperKeyPublicBaseURL, ts.URL)
conf.MustSet(config.ViperKeyPublicBaseURL, adminServer.URL)

i := identity.NewIdentity("")
require.NoError(t, reg.IdentityManager().Create(context.Background(), i))
s := &Session{Identity: i, ExpiresAt: time.Now().Add(5 * time.Minute)}
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s))

t.Run("case=should return 200 after refreshing one session", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
i := identity.NewIdentity("")
require.NoError(t, reg.IdentityManager().Create(context.Background(), i))
s := &Session{Identity: i, ExpiresAt: time.Now().Add(5 * time.Minute)}
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s))

req, _ := http.NewRequest("PATCH", ts.URL+"/admin/sessions/"+s.ID.String()+"/extend", nil)
req, _ := http.NewRequest("PATCH", adminServer.URL+"/admin/sessions/"+s.ID.String()+"/extend", nil)
res, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
Expand All @@ -712,7 +713,7 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {

t.Run("case=should return 400 when bad UUID is sent", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
req, _ := http.NewRequest("PATCH", ts.URL+"/admin/sessions/BADUUID/extend", nil)
req, _ := http.NewRequest("PATCH", adminServer.URL+"/admin/sessions/BADUUID/extend", nil)
res, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusBadRequest, res.StatusCode)
Expand All @@ -721,9 +722,19 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {
t.Run("case=should return 404 when calling with missing UUID", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
someID, _ := uuid.NewV4()
req, _ := http.NewRequest("PATCH", ts.URL+"/admin/sessions/"+someID.String()+"/extend", nil)
req, _ := http.NewRequest("PATCH", adminServer.URL+"/admin/sessions/"+someID.String()+"/extend", nil)
res, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusNotFound, res.StatusCode)
})

t.Run("case=should return 404 when calling puplic server", func(t *testing.T) {
req := x.NewTestHTTPRequest(t, "PATCH", publicServer.URL+"/sessions/"+s.ID.String()+"/extend", nil)

res, err := publicServer.Client().Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, res.StatusCode)
body := ioutilx.MustReadAll(res.Body)
assert.NotEqual(t, gjson.GetBytes(body, "error.id").String(), "security_csrf_violation")
})
}

0 comments on commit 866b472

Please sign in to comment.