From 1ea34748e05e6dd563f7b780cc1130d4c50b635f Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 12 Jul 2017 17:43:46 +0200 Subject: [PATCH] do not query again if result is less than limit If a query returned less rows than the limit, it would make an additional query and then it would mark the batcher as EOF. This change takes care of that case and makes sure no additional query is made. Signed-off-by: Miguel Molina --- batcher.go | 15 +++++++++++++-- batcher_test.go | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/batcher.go b/batcher.go index 4262ca1..2e708d6 100644 --- a/batcher.go +++ b/batcher.go @@ -52,7 +52,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer } func (r *batchQueryRunner) next() (Record, error) { - if r.eof { + if r.eof && len(r.records) == 0 { return nil, errNoMoreRows } @@ -63,7 +63,7 @@ func (r *batchQueryRunner) next() (Record, error) { ) limit := r.q.GetLimit() - if limit <= 0 || limit > uint64(r.total) { + if limit == 0 || limit > uint64(r.total) { records, err = r.loadNextBatch() if err != nil { return nil, err @@ -75,6 +75,17 @@ func (r *batchQueryRunner) next() (Record, error) { return nil, errNoMoreRows } + batchSize := r.q.GetBatchSize() + if batchSize > 0 && batchSize < limit { + if uint64(len(records)) < batchSize { + r.eof = true + } + } else if limit > 0 { + if uint64(len(records)) < limit { + r.eof = true + } + } + r.total += len(records) r.records = records[1:] return records[0], nil diff --git a/batcher_test.go b/batcher_test.go index f706e40..bccc25c 100644 --- a/batcher_test.go +++ b/batcher_test.go @@ -54,7 +54,7 @@ func TestBatcherLimit(t *testing.T) { q.BatchSize(2) q.Limit(5) r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) - runner := newBatchQueryRunner(ModelSchema, squirrel.NewStmtCacher(db), q) + runner := newBatchQueryRunner(ModelSchema, store.proxy, q) rs := NewBatchingResultSet(runner) var count int @@ -66,3 +66,42 @@ func TestBatcherLimit(t *testing.T) { r.NoError(err) r.Equal(5, count) } + +func TestBatcherNoExtraQueryIfLessThanLimit(t *testing.T) { + r := require.New(t) + db, err := openTestDB() + r.NoError(err) + setupTables(t, db) + defer db.Close() + defer teardownTables(t, db) + + store := NewStore(db) + for i := 0; i < 4; i++ { + m := newModel("foo", "bar", 1) + r.NoError(store.Insert(ModelSchema, m)) + + for i := 0; i < 4; i++ { + r.NoError(store.Insert(RelSchema, newRel(m.GetID(), fmt.Sprint(i)))) + } + } + + q := NewBaseQuery(ModelSchema) + q.Limit(6) + r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) + var queries int + proxy := store.DebugWith(func(_ string, _ ...interface{}) { + queries++ + }).proxy + runner := newBatchQueryRunner(ModelSchema, proxy, q) + rs := NewBatchingResultSet(runner) + + var count int + for rs.Next() { + _, err := rs.Get(nil) + r.NoError(err) + count++ + } + r.NoError(err) + r.Equal(4, count) + r.Equal(2, queries) +}