diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 80206c3..4d7d8fc 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -253,11 +253,11 @@ func (stmt *Stmt) StartTimer() { } func (stmt *Stmt) ColumnDatabaseName(col int) string { - return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_database_name(stmt.stmt.ptr(), C.int(col))))) + return internStringFromCString(C.sqlite3_column_database_name(stmt.stmt.ptr(), C.int(col))) } func (stmt *Stmt) ColumnTableName(col int) string { - return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_table_name(stmt.stmt.ptr(), C.int(col))))) + return internStringFromCString(C.sqlite3_column_table_name(stmt.stmt.ptr(), C.int(col))) } func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) { @@ -330,11 +330,7 @@ func (stmt *Stmt) BindParameterCount() int { } func (stmt *Stmt) BindParameterName(col int) string { - cstr := C.sqlite3_bind_parameter_name(stmt.stmt.ptr(), C.int(col)) - if cstr == nil { - return "" - } - return C.GoString(cstr) + return internStringFromCString(C.sqlite3_bind_parameter_name(stmt.stmt.ptr(), C.int(col))) } func (stmt *Stmt) BindParameterIndex(name string) int { @@ -357,7 +353,7 @@ func (stmt *Stmt) ColumnCount() int { } func (stmt *Stmt) ColumnName(col int) string { - return C.GoString(C.sqlite3_column_name(stmt.stmt.ptr(), C.int(col))) + return internStringFromCString(C.sqlite3_column_name(stmt.stmt.ptr(), C.int(col))) } func (stmt *Stmt) ColumnText(col int) string { @@ -395,15 +391,15 @@ func (stmt *Stmt) ColumnDeclType(col int) string { if cstr == nil { return "" } - clen := C.strlen(cstr) - b := (*[1 << 30]byte)(unsafe.Pointer(cstr))[0:clen] + bstr := (*byte)(unsafe.Pointer(cstr)) + clen := findnull(bstr) if stmt.db.declTypes == nil { stmt.db.declTypes = make(map[string]string) } - if res, found := stmt.db.declTypes[string(b)]; found { + if res, found := stmt.db.declTypes[unsafe.String(bstr, clen)]; found { return res } - res := string(b) + res := C.GoStringN(cstr, C.int(clen)) stmt.db.declTypes[res] = res return res } @@ -415,18 +411,29 @@ func errCode(code C.int) error { return sqliteh.CodeAsError(sqliteh.Code(code)) // internCache contains interned strings. var internCache sync.Map // string => string (key == value) -// stringFromBytes returns string(b), interned into a map forever. It's meant +// internStringFromBytes returns string(b), interned into a map forever. It's meant // for use on hot, small strings from closed set (like database or table or // column names) where it doesn't matter if it leaks forever. -func stringFromBytes(b []byte) string { +func internStringFromBytes(b []byte) string { if len(b) == 0 { return "" } - v, _ := internCache.Load(unsafe.String(&b[0], len(b))) + return internStringFromPtr((*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) +} + +func internStringFromPtr(p *C.char, n C.int) string { + if n == 0 { + return "" + } + v, _ := internCache.Load(unsafe.String((*byte)(unsafe.Pointer(p)), int(n))) if s, ok := v.(string); ok { return s } - s := string(b) + s := C.GoStringN(p, n) internCache.Store(s, s) return s } + +func internStringFromCString(p *C.char) string { + return internStringFromPtr(p, C.int(findnull((*byte)(unsafe.Pointer(p))))) +} diff --git a/cgosqlite/stubs.go b/cgosqlite/stubs.go new file mode 100644 index 0000000..1bea848 --- /dev/null +++ b/cgosqlite/stubs.go @@ -0,0 +1,10 @@ +package cgosqlite + +import _ "unsafe" + +// findnull exposes the runtime.findnull function to the cgosqlite package, this +// is a wide instruction optimized page by page null byte search aka fast +// strlen. +// +//go:linkname findnull runtime.findnull +func findnull(*byte) int diff --git a/cgosqlite/walcallback.go b/cgosqlite/walcallback.go index e3fd0d5..06a174f 100644 --- a/cgosqlite/walcallback.go +++ b/cgosqlite/walcallback.go @@ -20,7 +20,7 @@ func walCallbackGo(db *C.sqlite3, dbNameC *C.char, dbNameLen C.int, pages C.int) } dbNameB := unsafe.Slice((*byte)(unsafe.Pointer(dbNameC)), dbNameLen) - dbName := stringFromBytes(dbNameB) + dbName := internStringFromBytes(dbNameB) hook(dbName, int(pages)) return C.int(0) // result's kinda useless }