Skip to content

Commit

Permalink
refactor: change pagination to use keyset pagination
Browse files Browse the repository at this point in the history
The page token now is the last ID of the previous page.  This enables faster queries and more stable pagination.
NOTE: in case an integration modified page tokens to control pagination, this change will break the integration. Page tokens are opaque strings and should never be messed with.
  • Loading branch information
zepatrik committed Aug 1, 2022
1 parent a102cee commit 7b861c9
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 62 deletions.
27 changes: 15 additions & 12 deletions internal/check/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ func TestEngine(t *testing.T) {

e := check.NewEngine(reg)

for i, user := range users {
for _, user := range users {
t.Run("user="+user.String(), func(t *testing.T) {
allowed, err := e.SubjectIsAllowed(ctx, &relationtuple.RelationTuple{
Namespace: namesp,
Expand All @@ -466,19 +466,22 @@ func TestEngine(t *testing.T) {
}, 0)
require.NoError(t, err)
assert.True(t, allowed)

// pagination assertions
if i >= pageSize {
assert.Len(t, reg.RequestedPages, 2)
// reset requested pages for next iteration
reg.RequestedPages = nil
} else {
assert.Len(t, reg.RequestedPages, 1)
// reset requested pages for next iteration
reg.RequestedPages = nil
}
})
}

require.Len(t, reg.RequestedPages, 6)
var firstPage int
otherPages := make([]string, 0, 2)
for _, page := range reg.RequestedPages {
if page == "" {
firstPage++
} else {
otherPages = append(otherPages, page)
}
}
assert.Equal(t, 4, firstPage)
require.Len(t, otherPages, 2)
assert.Equal(t, otherPages[0], otherPages[1])
})

t.Run("case=wide tuple graph", func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestMigrations(t *testing.T) {
})

t.Run("suite=uuid_migrations", func(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
defer cancel()
require.NoError(t, tm.Down(ctx, -1))

Expand Down
17 changes: 10 additions & 7 deletions internal/persistence/sql/pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"testing"

"github.com/gofrs/uuid"

"github.com/stretchr/testify/assert"

"github.com/ory/keto/internal/persistence"
Expand All @@ -14,28 +16,29 @@ import (
func TestPaginationToken(t *testing.T) {
t.Parallel()

ids := x.UUIDs(3)
for i, tc := range []struct {
size int
token string
expectedErr error
expectedPage int
expectedLastID uuid.UUID
expectedPerPage int
}{
{
size: 10,
token: "10",
expectedPage: 10,
token: ids[0].String(),
expectedLastID: ids[0],
expectedPerPage: 10,
},
{
size: 0,
token: "15",
expectedPage: 15,
token: ids[1].String(),
expectedLastID: ids[1],
expectedPerPage: defaultPageSize,
},
{
size: 0,
token: "-15",
token: "foobar",
expectedErr: persistence.ErrMalformedPageToken,
expectedPerPage: defaultPageSize,
},
Expand All @@ -45,7 +48,7 @@ func TestPaginationToken(t *testing.T) {

assert.True(t, errors.Is(err, tc.expectedErr))
assert.Equal(t, tc.expectedPerPage, pagination.PerPage)
assert.Equal(t, tc.expectedPage, pagination.Page)
assert.Equal(t, tc.expectedLastID, pagination.LastID)
})
}
}
15 changes: 7 additions & 8 deletions internal/persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package sql
import (
"context"
"embed"
"fmt"
"reflect"
"strconv"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
Expand All @@ -24,7 +22,8 @@ type (
nid uuid.UUID
}
internalPagination struct {
Page, PerPage int
PerPage int
LastID uuid.UUID
}
dependencies interface {
x.LoggerProvider
Expand Down Expand Up @@ -108,19 +107,19 @@ func internalPaginationFromOptions(opts ...x.PaginationOptionSetter) (*internalP

func (p *internalPagination) parsePageToken(t string) error {
if t == "" {
p.Page = 1
p.LastID = uuid.Nil
return nil
}

i, err := strconv.ParseUint(t, 10, 32)
i, err := uuid.FromString(t)
if err != nil {
return errors.WithStack(persistence.ErrMalformedPageToken)
}

p.Page = int(i)
p.LastID = i
return nil
}

func (p *internalPagination) encodeNextPageToken() string {
return fmt.Sprintf("%d", p.Page+1)
func (p *internalPagination) encodeNextPageToken(lastID uuid.UUID) string {
return lastID.String()
}
15 changes: 10 additions & 5 deletions internal/persistence/sql/relationtuples.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ func (p *Persister) GetRelationTuples(ctx context.Context, query *relationtuple.
}

sqlQuery := p.QueryWithNetwork(ctx).
Order("nid, namespace, object, relation, subject_id, subject_set_namespace, subject_set_object, subject_set_relation, commit_time").
Paginate(pagination.Page, pagination.PerPage)
Order("shard_id, nid").
Where("shard_id > ?", pagination.LastID).
Limit(pagination.PerPage + 1)

err = p.whereQuery(ctx, sqlQuery, query)
if err != nil {
Expand All @@ -221,10 +222,14 @@ func (p *Persister) GetRelationTuples(ctx context.Context, query *relationtuple.
if err := sqlQuery.All(&res); err != nil {
return nil, "", sqlcon.HandleError(err)
}
if len(res) == 0 {
return make([]*relationtuple.RelationTuple, 0, 0), "", nil
}

nextPageToken := pagination.encodeNextPageToken()
if sqlQuery.Paginator.Page >= sqlQuery.Paginator.TotalPages {
nextPageToken = ""
var nextPageToken string
if len(res) > pagination.PerPage {
res = res[:len(res)-1]
nextPageToken = pagination.encodeNextPageToken(res[len(res)-1].ID)
}

internalRes := make([]*relationtuple.RelationTuple, 0, len(res))
Expand Down
55 changes: 26 additions & 29 deletions internal/relationtuple/read_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,37 +166,34 @@ func TestReadHandlers(t *testing.T) {
relationtuple.MapAndWriteTuples(t, reg, tuples...)

var firstResp ketoapi.GetResponse
t.Run("case=first page", func(t *testing.T) {
resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{
"namespace": {nspace.Name},
"page_size": {"1"},
}.Encode())
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

require.NoError(t, json.NewDecoder(resp.Body).Decode(&firstResp))
require.Len(t, firstResp.RelationTuples, 1)
assert.Contains(t, tuples, firstResp.RelationTuples[0])
assert.NotEqual(t, "", firstResp.NextPageToken)
})

t.Run("case=second page", func(t *testing.T) {
resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{
"namespace": {nspace.Name},
"page_size": {"1"},
"page_token": {firstResp.NextPageToken},
}.Encode())
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{
"namespace": {nspace.Name},
"page_size": {"1"},
}.Encode())
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

require.NoError(t, json.NewDecoder(resp.Body).Decode(&firstResp))
require.Len(t, firstResp.RelationTuples, 1)
assert.Contains(t, tuples, firstResp.RelationTuples[0])
assert.NotEqual(t, "", firstResp.NextPageToken)

// second page
resp, err = ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{
"namespace": {nspace.Name},
"page_size": {"1"},
"page_token": {firstResp.NextPageToken},
}.Encode())
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

secondResp := ketoapi.GetResponse{}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&secondResp))
require.Len(t, secondResp.RelationTuples, 1)
secondResp := ketoapi.GetResponse{}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&secondResp))
require.Len(t, secondResp.RelationTuples, 1)

assert.NotEqual(t, firstResp.RelationTuples, secondResp.RelationTuples)
assert.Contains(t, tuples, secondResp.RelationTuples[0])
assert.Equal(t, "", secondResp.NextPageToken)
})
assert.NotEqual(t, firstResp.RelationTuples, secondResp.RelationTuples)
assert.Contains(t, tuples, secondResp.RelationTuples[0])
assert.Equal(t, "", secondResp.NextPageToken)
})

t.Run("case=returs bad request on invalid page size", func(t *testing.T) {
Expand Down

0 comments on commit 7b861c9

Please sign in to comment.