diff --git a/sqlite.go b/sqlite.go index e049e3c..f5a2a1d 100644 --- a/sqlite.go +++ b/sqlite.go @@ -192,6 +192,9 @@ func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt query = strings.TrimSpace(query) if s := c.stmts[query]; s != nil { s.prepCtx = ctx + + // don't hand the same statement out twice; this is re-added on s.Close + delete(c.stmts, query) return s, nil } if c.tracer != nil { @@ -234,10 +237,12 @@ func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt return s, nil } + // NOTE: don't add the statement to c.stmts here, since we could return + // it to another caller before Close is called; it's added to the + // c.stmts map on Close. if c.stmts == nil { c.stmts = make(map[string]*stmt) } - c.stmts[query] = s return s, nil } @@ -391,8 +396,19 @@ func (s *stmt) NumInput() int { } func (s *stmt) Close() error { - if s.persist { - return s.reserr("Stmt.Close", s.resetAndClear()) + // We return this statement to the conn only if it's persistent, and + // only if there's not already a statement with the same query already + // cached there. + shouldPersist := s.persist + if _, alreadyPersisted := s.conn.stmts[s.query]; alreadyPersisted { + shouldPersist = false + } + if shouldPersist { + err := s.reserr("Stmt.Close", s.resetAndClear()) + if err == nil { + s.conn.stmts[s.query] = s + } + return err } return s.reserr("Stmt.Close", s.stmt.Finalize()) } diff --git a/sqlite_test.go b/sqlite_test.go index 2070b39..2939d66 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1072,3 +1072,77 @@ func BenchmarkBeginTxNoop(b *testing.B) { // TODO(crawshaw): test TextMarshaler // TODO(crawshaw): test named types // TODO(crawshaw): check coverage + +// This tests for https://github.com/tailscale/sqlite/issues/73 +func TestPrepareReuse(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + sqlConn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + // TODO(andrew): this deadlocks; why? + //defer sqlConn.Close() + + // Insert a bunch of values into a table that we'll query to get + // multiple rows back. + err = ExecScript(sqlConn, + `BEGIN; + CREATE TABLE t (c); + INSERT INTO t VALUES (1), (2), (3), (4); + COMMIT;`) + if err != nil { + t.Fatal(err) + } + + ctx = WithPersist(ctx) + + // Calling PrepareContext twice in a row used to return the same + // statement to both callers. + const query = "SELECT c FROM t;" + stmt1, err := sqlConn.PrepareContext(ctx, query) + if err != nil { + t.Fatal(err) + } + stmt2, err := sqlConn.PrepareContext(ctx, query) + if err != nil { + t.Fatal(err) + } + + rows1, err := stmt1.QueryContext(ctx) + if err != nil { + t.Fatal(err) + } + rows2, err := stmt2.QueryContext(ctx) + if err != nil { + t.Fatal(err) + } + + assertResult := func(prefix string, rows *sql.Rows, want int) { + var num int + if err := rows.Scan(&num); err != nil { + t.Fatalf(prefix+"Scan: %v", err) + } + if num != want { + t.Fatalf(prefix+"num=%d, want %d", num, want) + } + } + + // Each set of rows should get a full copy of the query results; if + // these are incorrectly shared, then advancing one Rows will change + // the results from the other. + for i := 0; i < 4; i++ { + if !rows1.Next() { + t.Fatalf("[1] pass %d: Next=false", i) + } + if !rows2.Next() { + t.Fatalf("[2] pass %d: Next=false", i) + } + + // rows2 should be different from row1 and should return a + // different set of values. + assertResult("rows1: ", rows1, i+1) + assertResult("rows2: ", rows2, i+1) + } +}