Skip to content

Commit

Permalink
fix: resolves a bug that prevents sessions from expiring (#612)
Browse files Browse the repository at this point in the history
Closes #611
  • Loading branch information
aeneasr committed Jul 30, 2020
1 parent 562cfc4 commit 86b281a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 6 deletions.
12 changes: 8 additions & 4 deletions session/manager_http.go
Expand Up @@ -99,13 +99,17 @@ func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (*S
}

se, err := s.r.SessionPersister().GetSession(ctx, x.ParseUUID(sid))
if err != nil && (err.Error() == herodot.ErrNotFound.Error() ||
err.Error() == sqlcon.ErrNoRows.Error()) {
return nil, errors.WithStack(ErrNoActiveSessionFound)
} else if err != nil {
if err != nil {
if errors.Is(err, herodot.ErrNotFound) || errors.Is(err, sqlcon.ErrNoRows) {
return nil, errors.WithStack(ErrNoActiveSessionFound)
}
return nil, err
}

if se.ExpiresAt.Before(time.Now()) {
return nil, errors.WithStack(ErrNoActiveSessionFound)
}

se.Identity = se.Identity.CopyWithoutCredentials()

return se, nil
Expand Down
73 changes: 71 additions & 2 deletions session/manager_http_test.go
Expand Up @@ -2,14 +2,22 @@ package session_test

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/viper"

"github.com/ory/kratos/driver/configuration"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
)
Expand All @@ -27,13 +35,74 @@ func (f *mockCSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request
}

func TestManagerHTTP(t *testing.T) {
t.Run("method=SaveToRequest", func(t *testing.T) {
t.Run("case=regenerate csrf on principal change", func(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)

mock := new(mockCSRFHandler)
reg.WithCSRFHandler(mock)

require.NoError(t, reg.SessionManager().SaveToRequest(context.Background(), httptest.NewRecorder(), new(http.Request), new(session.Session)))
assert.Equal(t, 1, mock.c)
})

t.Run("suite=lifecycle", func(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)

viper.Set(configuration.ViperKeySelfServiceLoginUI, "https://www.ory.sh")
viper.Set(configuration.ViperKeyDefaultIdentitySchemaURL, "file://./stub/fake-session.schema.json")

var s *session.Session
rp := x.NewRouterPublic()
rp.GET("/session/set", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
require.NoError(t, reg.SessionManager().CreateToRequest(r.Context(), w, r, s))
w.WriteHeader(http.StatusOK)
})

rp.GET("/session/get", func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
sess, err := reg.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
reg.Writer().WriteError(w, r, err)
return
}
reg.Writer().Write(w, r, sess)
})

pts := httptest.NewServer(x.NewTestCSRFHandler(rp, reg))
t.Cleanup(pts.Close)
viper.Set(configuration.ViperKeyPublicBaseURL, pts.URL)
reg.RegisterPublicRoutes(rp)

t.Run("case=valid", func(t *testing.T) {
viper.Set(configuration.ViperKeySessionLifespan, "1m")

i := identity.Identity{Traits: []byte("{}")}
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
s = session.NewSession(&i, conf, time.Now())

c := testhelpers.NewClientWithCookies(t)
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")

res, err := c.Get(pts.URL + "/session/get")
require.NoError(t, err)
assert.EqualValues(t, http.StatusOK, res.StatusCode)
})

t.Run("case=expired", func(t *testing.T) {
viper.Set(configuration.ViperKeySessionLifespan, "1ns")

i := identity.Identity{Traits: []byte("{}")}
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
s = session.NewSession(&i, conf, time.Now())

c := testhelpers.NewClientWithCookies(t)
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")

time.Sleep(time.Nanosecond * 2)

res, err := c.Get(pts.URL + "/session/get")
require.NoError(t, err)
assert.EqualValues(t, http.StatusUnauthorized, res.StatusCode)
})
})

}

0 comments on commit 86b281a

Please sign in to comment.