From d904a3872ded278efa0299b13dd7e01714c42d80 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Wed, 7 Jun 2023 19:12:01 -0400 Subject: [PATCH] sqlite: don't return persistent queries to two callers This test case used to fail with: === RUN TestPrepareReuse sqlite_test.go:1129: rows2: num=2, want 1 --- FAIL: TestPrepareReuse (0.00s) This indicates that we were erroneously returning the same "persisted" query twice, which resulted in the two sets of Rows returned affecting each other. Instead, only persist queries in c.stmts on Close, after they've been reset and are ready for re-use. Updates #73 Signed-off-by: Andrew Dunham --- sqlite.go | 22 +++++++++++++-- sqlite_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 3 deletions(-) diff --git a/sqlite.go b/sqlite.go index e049e3c..f5a2a1d 100644 --- a/sqlite.go +++ b/sqlite.go @@ -192,6 +192,9 @@ func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt query = strings.TrimSpace(query) if s := c.stmts[query]; s != nil { s.prepCtx = ctx + + // don't hand the same statement out twice; this is re-added on s.Close + delete(c.stmts, query) return s, nil } if c.tracer != nil { @@ -234,10 +237,12 @@ func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt return s, nil } + // NOTE: don't add the statement to c.stmts here, since we could return + // it to another caller before Close is called; it's added to the + // c.stmts map on Close. if c.stmts == nil { c.stmts = make(map[string]*stmt) } - c.stmts[query] = s return s, nil } @@ -391,8 +396,19 @@ func (s *stmt) NumInput() int { } func (s *stmt) Close() error { - if s.persist { - return s.reserr("Stmt.Close", s.resetAndClear()) + // We return this statement to the conn only if it's persistent, and + // only if there's not already a statement with the same query already + // cached there. + shouldPersist := s.persist + if _, alreadyPersisted := s.conn.stmts[s.query]; alreadyPersisted { + shouldPersist = false + } + if shouldPersist { + err := s.reserr("Stmt.Close", s.resetAndClear()) + if err == nil { + s.conn.stmts[s.query] = s + } + return err } return s.reserr("Stmt.Close", s.stmt.Finalize()) } diff --git a/sqlite_test.go b/sqlite_test.go index 2070b39..2939d66 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1072,3 +1072,77 @@ func BenchmarkBeginTxNoop(b *testing.B) { // TODO(crawshaw): test TextMarshaler // TODO(crawshaw): test named types // TODO(crawshaw): check coverage + +// This tests for https://github.com/tailscale/sqlite/issues/73 +func TestPrepareReuse(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + sqlConn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + // TODO(andrew): this deadlocks; why? + //defer sqlConn.Close() + + // Insert a bunch of values into a table that we'll query to get + // multiple rows back. + err = ExecScript(sqlConn, + `BEGIN; + CREATE TABLE t (c); + INSERT INTO t VALUES (1), (2), (3), (4); + COMMIT;`) + if err != nil { + t.Fatal(err) + } + + ctx = WithPersist(ctx) + + // Calling PrepareContext twice in a row used to return the same + // statement to both callers. + const query = "SELECT c FROM t;" + stmt1, err := sqlConn.PrepareContext(ctx, query) + if err != nil { + t.Fatal(err) + } + stmt2, err := sqlConn.PrepareContext(ctx, query) + if err != nil { + t.Fatal(err) + } + + rows1, err := stmt1.QueryContext(ctx) + if err != nil { + t.Fatal(err) + } + rows2, err := stmt2.QueryContext(ctx) + if err != nil { + t.Fatal(err) + } + + assertResult := func(prefix string, rows *sql.Rows, want int) { + var num int + if err := rows.Scan(&num); err != nil { + t.Fatalf(prefix+"Scan: %v", err) + } + if num != want { + t.Fatalf(prefix+"num=%d, want %d", num, want) + } + } + + // Each set of rows should get a full copy of the query results; if + // these are incorrectly shared, then advancing one Rows will change + // the results from the other. + for i := 0; i < 4; i++ { + if !rows1.Next() { + t.Fatalf("[1] pass %d: Next=false", i) + } + if !rows2.Next() { + t.Fatalf("[2] pass %d: Next=false", i) + } + + // rows2 should be different from row1 and should return a + // different set of values. + assertResult("rows1: ", rows1, i+1) + assertResult("rows2: ", rows2, i+1) + } +}