Skip to content

Commit

Permalink
sqlite: don't return persistent queries to two callers
Browse files Browse the repository at this point in the history
This test case used to fail with:

    === RUN   TestPrepareReuse
        sqlite_test.go:1129: rows2: num=2, want 1
    --- FAIL: TestPrepareReuse (0.00s)

This indicates that we were erroneously returning the same "persisted"
query twice, which resulted in the two sets of Rows returned affecting
each other. Instead, only persist queries in c.stmts on Close, after
they've been reset and are ready for re-use.

Updates #73

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
  • Loading branch information
andrew-d committed Jun 7, 2023
1 parent 2d70ae2 commit d904a38
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 3 deletions.
22 changes: 19 additions & 3 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt
query = strings.TrimSpace(query)
if s := c.stmts[query]; s != nil {
s.prepCtx = ctx

// don't hand the same statement out twice; this is re-added on s.Close
delete(c.stmts, query)
return s, nil
}
if c.tracer != nil {
Expand Down Expand Up @@ -234,10 +237,12 @@ func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt
return s, nil
}

// NOTE: don't add the statement to c.stmts here, since we could return
// it to another caller before Close is called; it's added to the
// c.stmts map on Close.
if c.stmts == nil {
c.stmts = make(map[string]*stmt)
}
c.stmts[query] = s
return s, nil
}

Expand Down Expand Up @@ -391,8 +396,19 @@ func (s *stmt) NumInput() int {
}

func (s *stmt) Close() error {
if s.persist {
return s.reserr("Stmt.Close", s.resetAndClear())
// We return this statement to the conn only if it's persistent, and
// only if there's not already a statement with the same query already
// cached there.
shouldPersist := s.persist
if _, alreadyPersisted := s.conn.stmts[s.query]; alreadyPersisted {
shouldPersist = false
}
if shouldPersist {
err := s.reserr("Stmt.Close", s.resetAndClear())
if err == nil {
s.conn.stmts[s.query] = s
}
return err
}
return s.reserr("Stmt.Close", s.stmt.Finalize())
}
Expand Down
74 changes: 74 additions & 0 deletions sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,77 @@ func BenchmarkBeginTxNoop(b *testing.B) {
// TODO(crawshaw): test TextMarshaler
// TODO(crawshaw): test named types
// TODO(crawshaw): check coverage

// This tests for https://github.com/tailscale/sqlite/issues/73
func TestPrepareReuse(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlConn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}

// TODO(andrew): this deadlocks; why?
//defer sqlConn.Close()

// Insert a bunch of values into a table that we'll query to get
// multiple rows back.
err = ExecScript(sqlConn,
`BEGIN;
CREATE TABLE t (c);
INSERT INTO t VALUES (1), (2), (3), (4);
COMMIT;`)
if err != nil {
t.Fatal(err)
}

ctx = WithPersist(ctx)

// Calling PrepareContext twice in a row used to return the same
// statement to both callers.
const query = "SELECT c FROM t;"
stmt1, err := sqlConn.PrepareContext(ctx, query)
if err != nil {
t.Fatal(err)
}
stmt2, err := sqlConn.PrepareContext(ctx, query)
if err != nil {
t.Fatal(err)
}

rows1, err := stmt1.QueryContext(ctx)
if err != nil {
t.Fatal(err)
}
rows2, err := stmt2.QueryContext(ctx)
if err != nil {
t.Fatal(err)
}

assertResult := func(prefix string, rows *sql.Rows, want int) {
var num int
if err := rows.Scan(&num); err != nil {
t.Fatalf(prefix+"Scan: %v", err)
}
if num != want {
t.Fatalf(prefix+"num=%d, want %d", num, want)
}
}

// Each set of rows should get a full copy of the query results; if
// these are incorrectly shared, then advancing one Rows will change
// the results from the other.
for i := 0; i < 4; i++ {
if !rows1.Next() {
t.Fatalf("[1] pass %d: Next=false", i)
}
if !rows2.Next() {
t.Fatalf("[2] pass %d: Next=false", i)
}

// rows2 should be different from row1 and should return a
// different set of values.
assertResult("rows1: ", rows1, i+1)
assertResult("rows2: ", rows2, i+1)
}
}

0 comments on commit d904a38

Please sign in to comment.