diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 872b913..a3fae26 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -203,6 +203,10 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite return &Stmt{db: db, stmt: cStmtFromPtr(cstmt)}, remainingQuery, nil } +func (db *DB) DisableFunction(name string, numArgs int) error { + return errCode(C.ts_sqlite3_disable_function(db.db, C.CString(name), C.int(numArgs))) +} + func (stmt *Stmt) DBHandle() sqliteh.DB { cdb := C.sqlite3_db_handle(stmt.stmt.ptr()) if cdb != nil { diff --git a/cgosqlite/cgosqlite.h b/cgosqlite/cgosqlite.h index 7e88483..a58c380 100644 --- a/cgosqlite/cgosqlite.h +++ b/cgosqlite/cgosqlite.h @@ -146,3 +146,7 @@ static double ts_sqlite3_column_double(handle_sqlite3_stmt stmt, int iCol) { static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol) { return sqlite3_column_int64((sqlite3_stmt*)(stmt), iCol); } + +static int ts_sqlite3_disable_function(sqlite3 *db, const char *zFunctionName, int nArg) { + return sqlite3_create_function(db, zFunctionName, nArg, SQLITE_ANY, NULL, NULL, NULL, NULL); +} \ No newline at end of file diff --git a/sqlite.go b/sqlite.go index 4dbee0f..7e592b2 100644 --- a/sqlite.go +++ b/sqlite.go @@ -1036,6 +1036,18 @@ func Checkpoint(sqlconn SQLConn, dbName string, mode sqliteh.Checkpoint) (numFra return numFrames, numFramesCheckpointed, err } +// DisableFunction disables the named function on the given connection. +// numArgs must match the number of args of the function to be disabled. +func DisableFunction(sqlconn SQLConn, name string, numArgs int) error { + return sqlconn.Raw(func(driverConn any) error { + c, ok := driverConn.(*conn) + if !ok { + return fmt.Errorf("sqlite.DisableFunction: sql.Conn is not the sqlite driver: %T", driverConn) + } + return c.db.DisableFunction(name, numArgs) + }) +} + // WithPersist makes a ctx instruct the sqlite driver to persist a prepared query. // // This should be used with recurring queries to avoid constant parsing and diff --git a/sqlite_test.go b/sqlite_test.go index 361c04f..4011d9b 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1330,3 +1330,27 @@ func TestRegression(t *testing.T) { t.Log("OK") // Reaching here at all means we didn't panic. }) } + +func TestDisableFunction(t *testing.T) { + db := openTestDB(t) + + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + ctx := context.Background() + + if _, err := conn.ExecContext(ctx, "SELECT LOWER('Hi')"); err != nil { + t.Fatal("Attempting to use the LOWER function before disabling should have been allowed") + } + + if err := DisableFunction(conn, "lower", 1); err != nil { + t.Fatal(err) + } + + if _, err := conn.ExecContext(ctx, "SELECT LOWER('Hi')"); err == nil { + t.Fatal("Attempting to use the LOWER function after disabling should have failed") + } +} diff --git a/sqliteh/sqliteh.go b/sqliteh/sqliteh.go index 3210908..ab9290e 100644 --- a/sqliteh/sqliteh.go +++ b/sqliteh/sqliteh.go @@ -60,6 +60,11 @@ type DB interface { // // If hook is nil, the hook is removed. SetWALHook(hook func(dbName string, pages int)) + // DisableFunction disables an existing function (including built-ins) using + // sqlite3_create_function. The name and numArgs must match the existing + // function's signature. + // + DisableFunction(name string, numArgs int) error } // Stmt is an sqlite3_stmt* database connection object.