Skip to content

Commit

Permalink
fix: identity sessions list response includes pagination headers (#2763)
Browse files Browse the repository at this point in the history
Closes #2762
  • Loading branch information
brahmlower committed Oct 13, 2022
1 parent d8514b5 commit 0c2efa2
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 14 deletions.
20 changes: 15 additions & 5 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,16 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s
}

// ListSessionsByIdentity retrieves sessions for an identity from the store.
func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables session.Expandables) ([]*session.Session, error) {
func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables session.Expandables) ([]*session.Session, int64, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListSessionsByIdentity")
defer span.End()

s := make([]*session.Session, 0)
t := int64(0)
nid := p.NetworkID(ctx)

if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
q := c.Where("identity_id = ? AND nid = ?", iID, nid).Paginate(page, perPage)
q := c.Where("identity_id = ? AND nid = ?", iID, nid)
if except != uuid.Nil {
q = q.Where("id != ?", except)
}
Expand All @@ -72,7 +73,16 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
if len(expandables) > 0 {
q = q.Eager(expandables.ToEager()...)
}
if err := q.All(&s); err != nil {

// Get the total count of matching items
total, err := q.Count(new(session.Session))
if err != nil {
return sqlcon.HandleError(err)
}
t = int64(total)

// Get the paginated list of matching items
if err := q.Paginate(page, perPage).All(&s); err != nil {
return sqlcon.HandleError(err)
}

Expand All @@ -88,10 +98,10 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
}
return nil
}); err != nil {
return nil, err
return nil, 0, err
}

return s, nil
return s, t, nil
}

// UpsertSession creates a session if not found else updates.
Expand Down
7 changes: 5 additions & 2 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/pkg/errors"

"github.com/ory/x/decoderx"
"github.com/ory/x/urlx"

"github.com/ory/herodot"

Expand Down Expand Up @@ -304,12 +305,13 @@ func (h *Handler) adminListIdentitySessions(w http.ResponseWriter, r *http.Reque
}

page, perPage := x.ParsePagination(r)
sess, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), iID, active, page, perPage, uuid.Nil, ExpandEverything)
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), iID, active, page, perPage, uuid.Nil, ExpandEverything)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

x.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, perPage)
h.r.Writer().Write(w, r, sess)
}

Expand Down Expand Up @@ -448,12 +450,13 @@ func (h *Handler) listSessions(w http.ResponseWriter, r *http.Request, _ httprou
}

page, perPage := x.ParsePagination(r)
sess, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID, ExpandEverything)
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID, ExpandEverything)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

x.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, perPage)
h.r.Writer().Write(w, r, sess)
}

Expand Down
89 changes: 88 additions & 1 deletion session/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -472,6 +473,61 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
require.Equal(t, http.StatusNotFound, res.StatusCode)
})

t.Run("case=should return pagination headers on list response", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
i := identity.NewIdentity("")
require.NoError(t, reg.IdentityManager().Create(ctx, i))

numSessions := 5
numSessionsActive := 2

sess := make([]Session, numSessions)
for j := range sess {
require.NoError(t, faker.FakeData(&sess[j]))
sess[j].Identity = i
if j < numSessionsActive {
sess[j].Active = true
} else {
sess[j].Active = false
}
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
}

for _, tc := range []struct {
activeOnly string
expectedTotalCount int
}{
{
activeOnly: "true",
expectedTotalCount: numSessionsActive,
},
{
activeOnly: "false",
expectedTotalCount: numSessions - numSessionsActive,
},
{
activeOnly: "",
expectedTotalCount: numSessions,
},
} {
t.Run(fmt.Sprintf("active=%#v", tc.activeOnly), func(t *testing.T) {
reqURL := ts.URL + "/admin/identities/" + i.ID.String() + "/sessions"
if tc.activeOnly != "" {
reqURL += "?active=" + tc.activeOnly
}
req, _ := http.NewRequest("GET", reqURL, nil)
res, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
require.NoError(t, err)
require.Equal(t, tc.expectedTotalCount, totalCount)
require.NotEqual(t, "", res.Header.Get("Link"))
})
}
})

t.Run("case=should respect active on list", func(t *testing.T) {
client := testhelpers.NewClientWithCookies(t)
i := identity.NewIdentity("")
Expand Down Expand Up @@ -559,6 +615,36 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
}
}

t.Run("case=list should return pagination headers", func(t *testing.T) {
client, i, _ := setup(t)

numSessions := 5
numSessionsActive := 2

sess := make([]Session, numSessions)
for j := range sess {
require.NoError(t, faker.FakeData(&sess[j]))
sess[j].Identity = i
if j < numSessionsActive {
sess[j].Active = true
} else {
sess[j].Active = false
}
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
}

reqURL := ts.URL + "/sessions"
req, _ := http.NewRequest("GET", reqURL, nil)
res, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
require.NoError(t, err)
require.Equal(t, numSessionsActive, totalCount)
require.NotEqual(t, "", res.Header.Get("Link"))
})

t.Run("case=should return 200 and number after invalidating all other sessions", func(t *testing.T) {
client, i, currSess := setup(t)

Expand Down Expand Up @@ -601,9 +687,10 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusNoContent, res.StatusCode)

actualOthers, err := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandNothing)
actualOthers, total, err := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandNothing)
require.NoError(t, err)
require.Len(t, actualOthers, 3)
require.Equal(t, int64(3), total)

for _, s := range actualOthers {
if s.ID == others[0].ID {
Expand Down
2 changes: 1 addition & 1 deletion session/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Persister interface {
GetSession(ctx context.Context, sid uuid.UUID, expandables Expandables) (*Session, error)

// ListSessionsByIdentity retrieves sessions for an identity from the store.
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables Expandables) ([]*Session, error)
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables Expandables) ([]*Session, int64, error)

// UpsertSession inserts or updates a session into / in the store.
UpsertSession(ctx context.Context, s *Session) error
Expand Down
15 changes: 10 additions & 5 deletions session/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,11 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
actual, err := p.ListSessionsByIdentity(ctx, i.ID, tc.active, 1, 10, tc.except, session.ExpandEverything)
actual, total, err := p.ListSessionsByIdentity(ctx, i.ID, tc.active, 1, 10, tc.except, session.ExpandEverything)
require.NoError(t, err)

require.Equal(t, len(tc.expected), len(actual))
require.Equal(t, int64(len(tc.expected)), total)
for _, es := range tc.expected {
found := false
for _, as := range actual {
Expand All @@ -197,8 +198,9 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {

t.Run("other network", func(t *testing.T) {
_, other := testhelpers.NewNetwork(t, ctx, p)
actual, err := other.ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
actual, total, err := other.ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
require.NoError(t, err)
require.Equal(t, int64(0), total)
assert.Len(t, actual, 0)
})
})
Expand Down Expand Up @@ -322,9 +324,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
require.NoError(t, err)
assert.Equal(t, 1, n)

actual, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
actual, total, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
require.NoError(t, err)
require.Len(t, actual, 2)
require.Equal(t, int64(2), total)

if actual[0].ID == sessions[0].ID {
assert.True(t, actual[0].Active)
Expand All @@ -335,9 +338,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
assert.False(t, actual[0].Active)
}

otherIdentitiesSessions, err := p.ListSessionsByIdentity(ctx, sessions[2].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
otherIdentitiesSessions, total, err := p.ListSessionsByIdentity(ctx, sessions[2].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
require.NoError(t, err)
require.Len(t, actual, 2)
require.Equal(t, int64(2), total)

for _, s := range otherIdentitiesSessions {
assert.True(t, s.Active)
Expand Down Expand Up @@ -369,9 +373,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {

require.NoError(t, p.RevokeSession(ctx, sessions[0].IdentityID, sessions[0].ID))

actual, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
actual, total, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
require.NoError(t, err)
require.Len(t, actual, 2)
require.Equal(t, int64(2), total)

if actual[0].ID == sessions[0].ID {
assert.False(t, actual[0].Active)
Expand Down

0 comments on commit 0c2efa2

Please sign in to comment.