diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 1b6c873..9e530e0 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -222,7 +222,11 @@ func (stmt *Stmt) SQL() string { } func (stmt *Stmt) ExpandedSQL() string { - return C.GoString(C.sqlite3_expanded_sql(stmt.stmt.ptr())) + // sqlite3_expanded_sql returns a string obtained by sqlite3_malloc, which + // must be freed after use. + cstr := C.sqlite3_expanded_sql(stmt.stmt.ptr()) + defer C.sqlite3_free(unsafe.Pointer(cstr)) + return C.GoString(cstr) } func (stmt *Stmt) Reset() error { diff --git a/sqlite_test.go b/sqlite_test.go index f0c3230..a9f85d5 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1539,3 +1539,32 @@ func TestConnLogger_read_tx(t *testing.T) { } } } + +func TestExpandedSQL(t *testing.T) { + ctx := context.Background() + connector := Connector("file:"+t.TempDir()+"/test.db", nil, nil) + sqlConn, err := connector.Connect(ctx) + if err != nil { + t.Fatalf("Connect: %v", err) + } + conn := sqlConn.(*conn) + + sqlStmt, err := conn.PrepareContext(ctx, "SELECT ? + ?") + if err != nil { + t.Fatalf("PrepareContext: %v", err) + } + stmt, ok := sqlStmt.(*stmt) + if !ok { + t.Fatalf("not a *stmt: %#v", stmt) + } + if err := stmt.bindAll([]driver.NamedValue{ + {Ordinal: 1, Value: 6}, + {Ordinal: 2, Value: 7}, + }); err != nil { + t.Errorf("bindAll: %v", err) + } + + if got, want := stmt.stmt.ExpandedSQL(), "SELECT 6 + 7"; got != want { + t.Errorf("wrong sql: got %q, want %q", got, want) + } +}