Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cgosqlite/cgosqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite
return &Stmt{db: db, stmt: cStmtFromPtr(cstmt)}, remainingQuery, nil
}

func (db *DB) DisableFunction(name string, numArgs int) error {
return errCode(C.ts_sqlite3_disable_function(db.db, C.CString(name), C.int(numArgs)))
}

func (stmt *Stmt) DBHandle() sqliteh.DB {
cdb := C.sqlite3_db_handle(stmt.stmt.ptr())
if cdb != nil {
Expand Down
4 changes: 4 additions & 0 deletions cgosqlite/cgosqlite.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,7 @@ static double ts_sqlite3_column_double(handle_sqlite3_stmt stmt, int iCol) {
static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_int64((sqlite3_stmt*)(stmt), iCol);
}

static int ts_sqlite3_disable_function(sqlite3 *db, const char *zFunctionName, int nArg) {
return sqlite3_create_function(db, zFunctionName, nArg, SQLITE_ANY, NULL, NULL, NULL, NULL);
}
12 changes: 12 additions & 0 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,18 @@ func Checkpoint(sqlconn SQLConn, dbName string, mode sqliteh.Checkpoint) (numFra
return numFrames, numFramesCheckpointed, err
}

// DisableFunction disables the named function on the given connection.
// numArgs must match the number of args of the function to be disabled.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written, this will also let you "create" new previously non-existing no-op functions. That's harmless, but someone who misspells the function name they wanted to disable could easily be confused as there's no error to say "that doesn't exist".

Is it practical to look up the requested function first, and use that to report an error if someone tries to "disable" a function that isn't there? (E.g., we could query select name, narg from pragma_function_list where name = 'fname'). That would also have the incidental benefit that we wouldn't need to know the argument count a priori.

Not positive it's worth it, but this should be uncommon enough that it might be worth the extra step.

func DisableFunction(sqlconn SQLConn, name string, numArgs int) error {
return sqlconn.Raw(func(driverConn any) error {
c, ok := driverConn.(*conn)
if !ok {
return fmt.Errorf("sqlite.DisableFunction: sql.Conn is not the sqlite driver: %T", driverConn)
}
return c.db.DisableFunction(name, numArgs)
})
}

// WithPersist makes a ctx instruct the sqlite driver to persist a prepared query.
//
// This should be used with recurring queries to avoid constant parsing and
Expand Down
24 changes: 24 additions & 0 deletions sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1330,3 +1330,27 @@ func TestRegression(t *testing.T) {
t.Log("OK") // Reaching here at all means we didn't panic.
})
}

func TestDisableFunction(t *testing.T) {
db := openTestDB(t)

conn, err := db.Conn(context.Background())
if err != nil {
t.Fatal(err)
}
defer conn.Close()

ctx := context.Background()

if _, err := conn.ExecContext(ctx, "SELECT LOWER('Hi')"); err != nil {
t.Fatal("Attempting to use the LOWER function before disabling should have been allowed")
}

if err := DisableFunction(conn, "lower", 1); err != nil {
t.Fatal(err)
}

if _, err := conn.ExecContext(ctx, "SELECT LOWER('Hi')"); err == nil {
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
}
}
5 changes: 5 additions & 0 deletions sqliteh/sqliteh.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ type DB interface {
//
// If hook is nil, the hook is removed.
SetWALHook(hook func(dbName string, pages int))
// DisableFunction disables an existing function (including built-ins) using
// sqlite3_create_function. The name and numArgs must match the existing
// function's signature.
//
Comment on lines +63 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit

Suggested change
// DisableFunction disables an existing function (including built-ins) using
// sqlite3_create_function. The name and numArgs must match the existing
// function's signature.
//
// DisableFunction disables an existing function (including built-ins) using
// sqlite3_create_function. The name and numArgs must match the existing
// function's signature.

It is probably also worth noting that we "disable" it by replacing that function name with a stub that does nothing, rather than making it report an error. That might be relevant to someone debugging.

DisableFunction(name string, numArgs int) error
}

// Stmt is an sqlite3_stmt* database connection object.
Expand Down
Loading