From f89e9f458f401be3e56bc91aeaab1ba328a3777b Mon Sep 17 00:00:00 2001 From: Kevin Galligan Date: Sun, 26 Jan 2025 14:11:55 -0500 Subject: [PATCH 1/3] SqlDelight internal to PowerSync api --- .../DatabaseDriverFactory.android.kt | 2 +- .../com/powersync/DatabaseDriverFactory.kt | 2 +- .../kotlin/com/powersync/PsSqlDriver.kt | 12 ++++++------ .../com/powersync/bucket/BucketStorageImpl.kt | 2 +- .../com/powersync/db/PowerSyncDatabaseImpl.kt | 1 - .../kotlin/com/powersync/db/Queries.kt | 1 - .../kotlin/com/powersync/db/SqlCursor.kt | 13 +++++++++++++ .../db/internal/InternalDatabaseImpl.kt | 10 +++++----- .../db/internal/PowerSyncTransaction.kt | 2 +- .../powersync/db/internal/SqlCursorWrapper.kt | 19 +++++++++++++++++++ .../powersync/DatabaseDriverFactory.ios.kt | 2 +- .../powersync/DatabaseDriverFactory.jvm.kt | 2 +- 12 files changed, 49 insertions(+), 19 deletions(-) create mode 100644 core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt create mode 100644 core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt diff --git a/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt b/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt index ba5ad0a1..47cb9594 100644 --- a/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt +++ b/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt @@ -32,7 +32,7 @@ public actual class DatabaseDriverFactory( } } - public actual fun createDriver( + internal actual fun createDriver( scope: CoroutineScope, dbFilename: String, ): PsSqlDriver { diff --git a/core/src/commonMain/kotlin/com/powersync/DatabaseDriverFactory.kt b/core/src/commonMain/kotlin/com/powersync/DatabaseDriverFactory.kt index dfe1c07d..ab469ff0 100644 --- a/core/src/commonMain/kotlin/com/powersync/DatabaseDriverFactory.kt +++ b/core/src/commonMain/kotlin/com/powersync/DatabaseDriverFactory.kt @@ -4,7 +4,7 @@ import kotlinx.coroutines.CoroutineScope @Suppress("EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING") public expect class DatabaseDriverFactory { - public fun createDriver( + internal fun createDriver( scope: CoroutineScope, dbFilename: String, ): PsSqlDriver diff --git a/core/src/commonMain/kotlin/com/powersync/PsSqlDriver.kt b/core/src/commonMain/kotlin/com/powersync/PsSqlDriver.kt index 7e1a4052..a1ec0bfb 100644 --- a/core/src/commonMain/kotlin/com/powersync/PsSqlDriver.kt +++ b/core/src/commonMain/kotlin/com/powersync/PsSqlDriver.kt @@ -9,7 +9,7 @@ import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch -public class PsSqlDriver( +internal class PsSqlDriver( private val driver: SqlDriver, private val scope: CoroutineScope, ) : SqlDriver by driver { @@ -19,21 +19,21 @@ public class PsSqlDriver( // In-memory buffer to store table names before flushing private val pendingUpdates = mutableSetOf() - public fun updateTable(tableName: String) { + fun updateTable(tableName: String) { pendingUpdates.add(tableName) } - public fun clearTableUpdates() { + fun clearTableUpdates() { pendingUpdates.clear() } // Flows on table updates - public fun tableUpdates(): Flow> = tableUpdatesFlow.asSharedFlow() + fun tableUpdates(): Flow> = tableUpdatesFlow.asSharedFlow() // Flows on table updates containing a specific table - public fun updatesOnTable(tableName: String): Flow = tableUpdates().filter { it.contains(tableName) }.map { } + fun updatesOnTable(tableName: String): Flow = tableUpdates().filter { it.contains(tableName) }.map { } - public fun fireTableUpdates() { + fun fireTableUpdates() { val updates = pendingUpdates.toList() if (updates.isEmpty()) { return diff --git a/core/src/commonMain/kotlin/com/powersync/bucket/BucketStorageImpl.kt b/core/src/commonMain/kotlin/com/powersync/bucket/BucketStorageImpl.kt index 4113f12e..1255e22c 100644 --- a/core/src/commonMain/kotlin/com/powersync/bucket/BucketStorageImpl.kt +++ b/core/src/commonMain/kotlin/com/powersync/bucket/BucketStorageImpl.kt @@ -1,8 +1,8 @@ package com.powersync.bucket -import app.cash.sqldelight.db.SqlCursor import co.touchlab.kermit.Logger import co.touchlab.stately.concurrency.AtomicBoolean +import com.powersync.db.SqlCursor import com.powersync.db.crud.CrudEntry import com.powersync.db.crud.CrudRow import com.powersync.db.internal.InternalDatabase diff --git a/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt b/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt index 0a69d49b..cefd591f 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/PowerSyncDatabaseImpl.kt @@ -1,6 +1,5 @@ package com.powersync.db -import app.cash.sqldelight.db.SqlCursor import co.touchlab.kermit.Logger import com.powersync.DatabaseDriverFactory import com.powersync.PowerSyncDatabase diff --git a/core/src/commonMain/kotlin/com/powersync/db/Queries.kt b/core/src/commonMain/kotlin/com/powersync/db/Queries.kt index ba580d6d..90c72e19 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/Queries.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/Queries.kt @@ -1,6 +1,5 @@ package com.powersync.db -import app.cash.sqldelight.db.SqlCursor import com.powersync.db.internal.PowerSyncTransaction import kotlinx.coroutines.flow.Flow diff --git a/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt b/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt new file mode 100644 index 00000000..5edffbb6 --- /dev/null +++ b/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt @@ -0,0 +1,13 @@ +package com.powersync.db + +public interface SqlCursor { + public fun getBoolean(index: Int): Boolean? + + public fun getBytes(index: Int): ByteArray? + + public fun getDouble(index: Int): Double? + + public fun getLong(index: Int): Long? + + public fun getString(index: Int): String? +} \ No newline at end of file diff --git a/core/src/commonMain/kotlin/com/powersync/db/internal/InternalDatabaseImpl.kt b/core/src/commonMain/kotlin/com/powersync/db/internal/InternalDatabaseImpl.kt index 59549f1a..59ae6d57 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/internal/InternalDatabaseImpl.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/internal/InternalDatabaseImpl.kt @@ -5,10 +5,10 @@ import app.cash.sqldelight.Query import app.cash.sqldelight.coroutines.asFlow import app.cash.sqldelight.coroutines.mapToList import app.cash.sqldelight.db.QueryResult -import app.cash.sqldelight.db.SqlCursor import app.cash.sqldelight.db.SqlPreparedStatement import com.persistence.PowersyncQueries import com.powersync.PsSqlDriver +import com.powersync.db.SqlCursor import com.powersync.persistence.PsDatabase import com.powersync.utils.JsonUtil import kotlinx.coroutines.CoroutineScope @@ -184,8 +184,8 @@ internal class InternalDatabaseImpl( parameters: Int = 0, binders: (SqlPreparedStatement.() -> Unit)? = null, ): ExecutableQuery = - object : ExecutableQuery(mapper) { - override fun execute(mapper: (SqlCursor) -> QueryResult): QueryResult = + object : ExecutableQuery(wrapperMapper(mapper)) { + override fun execute(mapper: (app.cash.sqldelight.db.SqlCursor) -> QueryResult): QueryResult = driver.executeQuery(null, query, mapper, parameters, binders) } @@ -196,8 +196,8 @@ internal class InternalDatabaseImpl( binders: (SqlPreparedStatement.() -> Unit)? = null, tables: Set = setOf(), ): Query = - object : Query(mapper) { - override fun execute(mapper: (SqlCursor) -> QueryResult): QueryResult = + object : Query(wrapperMapper(mapper)) { + override fun execute(mapper: (app.cash.sqldelight.db.SqlCursor) -> QueryResult): QueryResult = driver.executeQuery(null, query, mapper, parameters, binders) override fun addListener(listener: Listener) { diff --git a/core/src/commonMain/kotlin/com/powersync/db/internal/PowerSyncTransaction.kt b/core/src/commonMain/kotlin/com/powersync/db/internal/PowerSyncTransaction.kt index 7d696aef..56d80138 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/internal/PowerSyncTransaction.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/internal/PowerSyncTransaction.kt @@ -1,6 +1,6 @@ package com.powersync.db.internal -import app.cash.sqldelight.db.SqlCursor +import com.powersync.db.SqlCursor public interface PowerSyncTransaction { public fun execute( diff --git a/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt b/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt new file mode 100644 index 00000000..e8232007 --- /dev/null +++ b/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt @@ -0,0 +1,19 @@ +package com.powersync.db.internal + +import app.cash.sqldelight.db.SqlCursor + +internal class SqlCursorWrapper(val realCursor: SqlCursor):com.powersync.db.SqlCursor { + override fun getBoolean(index: Int): Boolean? = realCursor.getBoolean(index) + + override fun getBytes(index: Int): ByteArray? = realCursor.getBytes(index) + + override fun getDouble(index: Int): Double? = realCursor.getDouble(index) + + override fun getLong(index: Int): Long? = realCursor.getLong(index) + + override fun getString(index: Int): String? = realCursor.getString(index) +} + +internal fun wrapperMapper(mapper:(com.powersync.db.SqlCursor)->T):(SqlCursor)->T{ + return {realCursor -> mapper(SqlCursorWrapper(realCursor))} +} \ No newline at end of file diff --git a/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt b/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt index 9c53e43e..48386ac7 100644 --- a/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt +++ b/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt @@ -46,7 +46,7 @@ public actual class DatabaseDriverFactory { } } - public actual fun createDriver( + internal actual fun createDriver( scope: CoroutineScope, dbFilename: String, ): PsSqlDriver { diff --git a/core/src/jvmMain/kotlin/com/powersync/DatabaseDriverFactory.jvm.kt b/core/src/jvmMain/kotlin/com/powersync/DatabaseDriverFactory.jvm.kt index 97a674a4..4010e507 100644 --- a/core/src/jvmMain/kotlin/com/powersync/DatabaseDriverFactory.jvm.kt +++ b/core/src/jvmMain/kotlin/com/powersync/DatabaseDriverFactory.jvm.kt @@ -29,7 +29,7 @@ public actual class DatabaseDriverFactory { } } - public actual fun createDriver( + internal actual fun createDriver( scope: CoroutineScope, dbFilename: String, ): PsSqlDriver { From 3ab4cb6598bf8279b4a248a7591dd96ff11464ad Mon Sep 17 00:00:00 2001 From: Kevin Galligan Date: Wed, 29 Jan 2025 10:52:16 -0500 Subject: [PATCH 2/3] Import driver code --- .../DatabaseDriverFactory.android.kt | 2 +- .../kotlin/com/powersync/db/SqlCursor.kt | 14 +- .../powersync/db/internal/SqlCursorWrapper.kt | 34 +- .../powersync/DatabaseDriverFactory.ios.kt | 4 +- .../com/powersync/PSJdbcSqliteDriver.kt | 2 +- gradle/libs.versions.toml | 3 + persistence/build.gradle.kts | 1 + .../persistence/driver/AndroidSqliteDriver.kt | 358 +++++++++++++++ .../persistence/driver/ColNamesSqlCursor.kt | 9 + .../powersync/persistence/driver/Borrowed.kt | 6 + .../persistence/driver/NativeSqlDatabase.kt | 433 ++++++++++++++++++ .../com/powersync/persistence/driver/Pool.kt | 123 +++++ .../persistence/driver/SqliterSqlCursor.kt | 34 ++ .../persistence/driver/SqliterStatement.kt | 42 ++ .../persistence/driver/util/PoolLock.kt | 89 ++++ .../driver/JdbcPreparedStatement.kt | 163 +++++++ 16 files changed, 1309 insertions(+), 8 deletions(-) create mode 100644 persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt create mode 100644 persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt create mode 100644 persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt create mode 100644 persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt create mode 100644 persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt create mode 100644 persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt create mode 100644 persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt create mode 100644 persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt create mode 100644 persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt diff --git a/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt b/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt index 47cb9594..85d3c512 100644 --- a/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt +++ b/core/src/androidMain/kotlin/com/powersync/DatabaseDriverFactory.android.kt @@ -2,8 +2,8 @@ package com.powersync import android.content.Context import androidx.sqlite.db.SupportSQLiteDatabase -import app.cash.sqldelight.driver.android.AndroidSqliteDriver import com.powersync.db.internal.InternalSchema +import com.powersync.persistence.driver.AndroidSqliteDriver import io.requery.android.database.sqlite.RequerySQLiteOpenHelperFactory import io.requery.android.database.sqlite.SQLiteCustomExtension import kotlinx.coroutines.CoroutineScope diff --git a/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt b/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt index 5edffbb6..74e6bf90 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/SqlCursor.kt @@ -10,4 +10,16 @@ public interface SqlCursor { public fun getLong(index: Int): Long? public fun getString(index: Int): String? -} \ No newline at end of file + + public fun columnName(index: Int): String? + + public val columnCount: Int + + public val columnNames: Map +} + +public fun SqlCursor.getBoolean(name: String): Boolean? = columnNames[name]?.let { getBoolean(it) } +public fun SqlCursor.getBytes(name: String): ByteArray? = columnNames[name]?.let { getBytes(it) } +public fun SqlCursor.getDouble(name: String): Double? = columnNames[name]?.let { getDouble(it) } +public fun SqlCursor.getLong(name: String): Long? = columnNames[name]?.let { getLong(it) } +public fun SqlCursor.getString(name: String): String? = columnNames[name]?.let { getString(it) } \ No newline at end of file diff --git a/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt b/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt index e8232007..90cb8e0c 100644 --- a/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt +++ b/core/src/commonMain/kotlin/com/powersync/db/internal/SqlCursorWrapper.kt @@ -1,8 +1,9 @@ package com.powersync.db.internal import app.cash.sqldelight.db.SqlCursor +import com.powersync.persistence.driver.ColNamesSqlCursor -internal class SqlCursorWrapper(val realCursor: SqlCursor):com.powersync.db.SqlCursor { +internal class SqlCursorWrapper(val realCursor: ColNamesSqlCursor) : com.powersync.db.SqlCursor { override fun getBoolean(index: Int): Boolean? = realCursor.getBoolean(index) override fun getBytes(index: Int): ByteArray? = realCursor.getBytes(index) @@ -12,8 +13,35 @@ internal class SqlCursorWrapper(val realCursor: SqlCursor):com.powersync.db.SqlC override fun getLong(index: Int): Long? = realCursor.getLong(index) override fun getString(index: Int): String? = realCursor.getString(index) + + override fun columnName(index: Int): String? = realCursor.columnName(index) + + override val columnCount: Int + get() = realCursor.columnCount + + override val columnNames: Map by lazy { + val map = HashMap(this.columnCount) + for (i in 0 until columnCount) { + val key = columnName(i) + if (key == null) { + continue + } + if (map.containsKey(key)) { + var index = 1 + val basicKey = "$key&JOIN" + var finalKey = basicKey + index + while (map.containsKey(finalKey)) { + finalKey = basicKey + ++index + } + map[finalKey] = i + } else { + map[key] = i + } + } + map + } } -internal fun wrapperMapper(mapper:(com.powersync.db.SqlCursor)->T):(SqlCursor)->T{ - return {realCursor -> mapper(SqlCursorWrapper(realCursor))} +internal fun wrapperMapper(mapper: (com.powersync.db.SqlCursor) -> T): (SqlCursor) -> T { + return { realCursor -> mapper(SqlCursorWrapper(realCursor as ColNamesSqlCursor)) } } \ No newline at end of file diff --git a/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt b/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt index 48386ac7..cec18926 100644 --- a/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt +++ b/core/src/iosMain/kotlin/com/powersync/DatabaseDriverFactory.ios.kt @@ -1,10 +1,10 @@ package com.powersync -import app.cash.sqldelight.driver.native.NativeSqliteDriver -import app.cash.sqldelight.driver.native.wrapConnection import co.touchlab.sqliter.DatabaseConfiguration import co.touchlab.sqliter.DatabaseConnection import com.powersync.db.internal.InternalSchema +import com.powersync.persistence.driver.NativeSqliteDriver +import com.powersync.persistence.driver.wrapConnection import com.powersync.sqlite.core.init_powersync_sqlite_extension import com.powersync.sqlite.core.sqlite3_commit_hook import com.powersync.sqlite.core.sqlite3_rollback_hook diff --git a/core/src/jvmMain/kotlin/com/powersync/PSJdbcSqliteDriver.kt b/core/src/jvmMain/kotlin/com/powersync/PSJdbcSqliteDriver.kt index a043cd51..518299a5 100644 --- a/core/src/jvmMain/kotlin/com/powersync/PSJdbcSqliteDriver.kt +++ b/core/src/jvmMain/kotlin/com/powersync/PSJdbcSqliteDriver.kt @@ -8,7 +8,7 @@ import app.cash.sqldelight.db.SqlCursor import app.cash.sqldelight.db.SqlDriver import app.cash.sqldelight.db.SqlPreparedStatement import app.cash.sqldelight.db.SqlSchema -import app.cash.sqldelight.driver.jdbc.JdbcPreparedStatement +import com.powersync.persistence.driver.JdbcPreparedStatement import org.sqlite.SQLiteConnection import java.nio.file.Path import java.sql.DriverManager diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d3bb5953..2150fcb9 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -26,6 +26,7 @@ junit = "4.13.2" compose = "1.6.11" compose-preview = "1.7.2" +androidxSqlite = "2.4.0" # plugins android-gradle-plugin = "8.5.1" @@ -88,6 +89,8 @@ stately-concurrency = { module = "co.touchlab:stately-concurrency", version.ref supabase-client = { module = "io.github.jan-tennert.supabase:postgrest-kt", version.ref = "supabase" } supabase-auth = { module = "io.github.jan-tennert.supabase:auth-kt", version.ref = "supabase" } +androidx-sqliteFramework = { module = "androidx.sqlite:sqlite-framework", version.ref = "androidxSqlite" } + # Sample - Android androidx-core = { group = "androidx.core", name = "core-ktx", version.ref = "androidx-core" } androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "androidx-appcompat" } diff --git a/persistence/build.gradle.kts b/persistence/build.gradle.kts index 4997bd45..40f5d701 100644 --- a/persistence/build.gradle.kts +++ b/persistence/build.gradle.kts @@ -31,6 +31,7 @@ kotlin { api(libs.sqldelight.driver.android) api(libs.powersync.sqlite.core.android) api(libs.requery.sqlite.android) + implementation(libs.androidx.sqliteFramework) } jvmMain.dependencies { diff --git a/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt b/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt new file mode 100644 index 00000000..70433a17 --- /dev/null +++ b/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt @@ -0,0 +1,358 @@ +package com.powersync.persistence.driver + +import android.content.Context +import android.database.AbstractWindowedCursor +import android.database.Cursor +import android.database.CursorWindow +import android.os.Build +import android.util.LruCache +import androidx.annotation.DoNotInline +import androidx.annotation.RequiresApi +import androidx.sqlite.db.SupportSQLiteDatabase +import androidx.sqlite.db.SupportSQLiteOpenHelper +import androidx.sqlite.db.SupportSQLiteProgram +import androidx.sqlite.db.SupportSQLiteQuery +import androidx.sqlite.db.SupportSQLiteStatement +import androidx.sqlite.db.framework.FrameworkSQLiteOpenHelperFactory +import app.cash.sqldelight.Query +import app.cash.sqldelight.Transacter +import app.cash.sqldelight.db.AfterVersion +import app.cash.sqldelight.db.QueryResult +import app.cash.sqldelight.db.SqlCursor +import app.cash.sqldelight.db.SqlDriver +import app.cash.sqldelight.db.SqlPreparedStatement +import app.cash.sqldelight.db.SqlSchema +import com.powersync.persistence.driver.Api28Impl.setWindowSize + +import kotlin.collections.forEach +import kotlin.collections.getOrPut +import kotlin.io.use +import kotlin.let + +private const val DEFAULT_CACHE_SIZE = 20 + +public class AndroidSqliteDriver private constructor( + private val openHelper: SupportSQLiteOpenHelper? = null, + database: SupportSQLiteDatabase? = null, + private val cacheSize: Int, + private val windowSizeBytes: Long? = null, +) : SqlDriver { + init { + require((openHelper != null) xor (database != null)) + } + + private val transactions = ThreadLocal() + private val database by lazy { + openHelper?.writableDatabase ?: database!! + } + + public constructor( + openHelper: SupportSQLiteOpenHelper, + ) : this(openHelper = openHelper, database = null, cacheSize = DEFAULT_CACHE_SIZE, windowSizeBytes = null) + + /** + * @param [cacheSize] The number of compiled sqlite statements to keep in memory per connection. + * Defaults to 20. + * @param [useNoBackupDirectory] Sets whether to use a no backup directory or not. + * @param [windowSizeBytes] Size of cursor window in bytes, per [CursorWindow] (Android 28+ only), or null to use the default. + */ + @JvmOverloads + public constructor( + schema: SqlSchema>, + context: Context, + name: String? = null, + factory: SupportSQLiteOpenHelper.Factory = FrameworkSQLiteOpenHelperFactory(), + callback: SupportSQLiteOpenHelper.Callback = AndroidSqliteDriver.Callback(schema), + cacheSize: Int = DEFAULT_CACHE_SIZE, + useNoBackupDirectory: Boolean = false, + windowSizeBytes: Long? = null, + ) : this( + database = null, + openHelper = factory.create( + SupportSQLiteOpenHelper.Configuration.builder(context) + .callback(callback) + .name(name) + .noBackupDirectory(useNoBackupDirectory) + .build(), + ), + cacheSize = cacheSize, + windowSizeBytes = windowSizeBytes, + ) + + @JvmOverloads + public constructor( + database: SupportSQLiteDatabase, + cacheSize: Int = DEFAULT_CACHE_SIZE, + windowSizeBytes: Long? = null, + ) : this(openHelper = null, database = database, cacheSize = cacheSize, windowSizeBytes = windowSizeBytes) + + private val statements = object : LruCache(cacheSize) { + override fun entryRemoved( + evicted: Boolean, + key: Int, + oldValue: AndroidStatement, + newValue: AndroidStatement?, + ) { + if (evicted) oldValue.close() + } + } + + private val listeners = linkedMapOf>() + + override fun addListener(vararg queryKeys: String, listener: Query.Listener) { + synchronized(listeners) { + queryKeys.forEach { + listeners.getOrPut(it, { linkedSetOf() }).add(listener) + } + } + } + + override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { + synchronized(listeners) { + queryKeys.forEach { + listeners[it]?.remove(listener) + } + } + } + + override fun notifyListeners(vararg queryKeys: String) { + val listenersToNotify = linkedSetOf() + synchronized(listeners) { + queryKeys.forEach { listeners[it]?.let(listenersToNotify::addAll) } + } + listenersToNotify.forEach(Query.Listener::queryResultsChanged) + } + + override fun newTransaction(): QueryResult { + val enclosing = transactions.get() + val transaction = Transaction(enclosing) + transactions.set(transaction) + + if (enclosing == null) { + database.beginTransactionNonExclusive() + } + + return QueryResult.Value(transaction) + } + + override fun currentTransaction(): Transacter.Transaction? = transactions.get() + + internal inner class Transaction( + override val enclosingTransaction: Transacter.Transaction?, + ) : Transacter.Transaction() { + override fun endTransaction(successful: Boolean): QueryResult { + if (enclosingTransaction == null) { + if (successful) { + database.setTransactionSuccessful() + database.endTransaction() + } else { + database.endTransaction() + } + } + transactions.set(enclosingTransaction) + return QueryResult.Unit + } + } + + private fun execute( + identifier: Int?, + createStatement: () -> AndroidStatement, + binders: (SqlPreparedStatement.() -> Unit)?, + result: AndroidStatement.() -> T, + ): QueryResult.Value { + var statement: AndroidStatement? = null + if (identifier != null) { + statement = statements.remove(identifier) + } + if (statement == null) { + statement = createStatement() + } + try { + if (binders != null) { + statement.binders() + } + return QueryResult.Value(statement.result()) + } finally { + if (identifier != null) { + statements.put(identifier, statement)?.close() + } else { + statement.close() + } + } + } + + override fun execute( + identifier: Int?, + sql: String, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult = execute(identifier, { AndroidPreparedStatement(database.compileStatement(sql)) }, binders, { execute() }) + + override fun executeQuery( + identifier: Int?, + sql: String, + mapper: (SqlCursor) -> QueryResult, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult.Value = execute(identifier, { AndroidQuery(sql, database, parameters, windowSizeBytes) }, binders) { executeQuery(mapper) } + + override fun close() { + statements.evictAll() + return openHelper?.close() ?: database.close() + } + + public open class Callback( + private val schema: SqlSchema>, + private vararg val callbacks: AfterVersion, + ) : SupportSQLiteOpenHelper.Callback( + if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), + ) { + + override fun onCreate(db: SupportSQLiteDatabase) { + schema.create(AndroidSqliteDriver(openHelper = null, database = db, cacheSize = 1)) + } + + override fun onUpgrade( + db: SupportSQLiteDatabase, + oldVersion: Int, + newVersion: Int, + ) { + schema.migrate( + AndroidSqliteDriver(openHelper = null, database = db, cacheSize = 1), + oldVersion.toLong(), + newVersion.toLong(), + *callbacks, + ) + } + } +} + +internal interface AndroidStatement : SqlPreparedStatement { + fun execute(): Long + fun executeQuery(mapper: (SqlCursor) -> QueryResult): R + fun close() +} + +private class AndroidPreparedStatement( + private val statement: SupportSQLiteStatement, +) : AndroidStatement { + override fun bindBytes(index: Int, bytes: ByteArray?) { + if (bytes == null) statement.bindNull(index + 1) else statement.bindBlob(index + 1, bytes) + } + + override fun bindLong(index: Int, long: Long?) { + if (long == null) statement.bindNull(index + 1) else statement.bindLong(index + 1, long) + } + + override fun bindDouble(index: Int, double: Double?) { + if (double == null) statement.bindNull(index + 1) else statement.bindDouble(index + 1, double) + } + + override fun bindString(index: Int, string: String?) { + if (string == null) statement.bindNull(index + 1) else statement.bindString(index + 1, string) + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + if (boolean == null) { + statement.bindNull(index + 1) + } else { + statement.bindLong(index + 1, if (boolean) 1L else 0L) + } + } + + override fun executeQuery(mapper: (SqlCursor) -> QueryResult): R = throw UnsupportedOperationException() + + override fun execute(): Long { + return statement.executeUpdateDelete().toLong() + } + + override fun close() { + statement.close() + } +} + +private class AndroidQuery( + override val sql: String, + private val database: SupportSQLiteDatabase, + override val argCount: Int, + private val windowSizeBytes: Long?, +) : SupportSQLiteQuery, + AndroidStatement { + private val binds = MutableList<((SupportSQLiteProgram) -> Unit)?>(argCount) { null } + + override fun bindBytes(index: Int, bytes: ByteArray?) { + binds[index] = { if (bytes == null) it.bindNull(index + 1) else it.bindBlob(index + 1, bytes) } + } + + override fun bindLong(index: Int, long: Long?) { + binds[index] = { if (long == null) it.bindNull(index + 1) else it.bindLong(index + 1, long) } + } + + override fun bindDouble(index: Int, double: Double?) { + binds[index] = { if (double == null) it.bindNull(index + 1) else it.bindDouble(index + 1, double) } + } + + override fun bindString(index: Int, string: String?) { + binds[index] = { if (string == null) it.bindNull(index + 1) else it.bindString(index + 1, string) } + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + binds[index] = { + if (boolean == null) { + it.bindNull(index + 1) + } else { + it.bindLong(index + 1, if (boolean) 1L else 0L) + } + } + } + + override fun execute() = throw UnsupportedOperationException() + + override fun executeQuery(mapper: (SqlCursor) -> QueryResult): R { + return database.query(this) + .use { cursor -> mapper(AndroidCursor(cursor, windowSizeBytes)).value } + } + + override fun bindTo(statement: SupportSQLiteProgram) { + for (action in binds) { + action!!(statement) + } + } + + override fun toString() = sql + + override fun close() {} +} + +private class AndroidCursor( + private val cursor: Cursor, + windowSizeBytes: Long?, +) : ColNamesSqlCursor { + init { + if ( + Build.VERSION.SDK_INT >= Build.VERSION_CODES.P && + windowSizeBytes != null && + cursor is AbstractWindowedCursor + ) { + cursor.setWindowSize(windowSizeBytes) + } + } + + override fun next(): QueryResult.Value = QueryResult.Value(cursor.moveToNext()) + override fun getString(index: Int) = if (cursor.isNull(index)) null else cursor.getString(index) + override fun getLong(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index) + override fun getBytes(index: Int) = if (cursor.isNull(index)) null else cursor.getBlob(index) + override fun getDouble(index: Int) = if (cursor.isNull(index)) null else cursor.getDouble(index) + override fun getBoolean(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index) == 1L + override fun columnName(index: Int): String? = cursor.getColumnName(index) + override val columnCount: Int = cursor.columnCount +} + +@RequiresApi(Build.VERSION_CODES.P) +private object Api28Impl { + @JvmStatic + @DoNotInline + fun AbstractWindowedCursor.setWindowSize(windowSizeBytes: Long) { + window = CursorWindow(null, windowSizeBytes) + } +} diff --git a/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt b/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt new file mode 100644 index 00000000..2add8ffa --- /dev/null +++ b/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt @@ -0,0 +1,9 @@ +package com.powersync.persistence.driver + +import app.cash.sqldelight.db.SqlCursor + +public interface ColNamesSqlCursor: SqlCursor { + public fun columnName(index: Int): String? + + public val columnCount: Int +} \ No newline at end of file diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt new file mode 100644 index 00000000..72858220 --- /dev/null +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt @@ -0,0 +1,6 @@ +package com.powersync.persistence.driver + +internal interface Borrowed { + val value: T + fun release() +} diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt new file mode 100644 index 00000000..6c6797f7 --- /dev/null +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt @@ -0,0 +1,433 @@ +package com.powersync.persistence.driver + +import app.cash.sqldelight.Query +import app.cash.sqldelight.Transacter +import app.cash.sqldelight.db.AfterVersion +import app.cash.sqldelight.db.Closeable +import app.cash.sqldelight.db.QueryResult +import app.cash.sqldelight.db.SqlCursor +import app.cash.sqldelight.db.SqlDriver +import app.cash.sqldelight.db.SqlPreparedStatement +import app.cash.sqldelight.db.SqlSchema +import app.cash.sqldelight.internal.currentThreadId +import co.touchlab.sqliter.DatabaseConfiguration +import co.touchlab.sqliter.DatabaseConnection +import co.touchlab.sqliter.DatabaseManager +import co.touchlab.sqliter.Statement +import co.touchlab.sqliter.createDatabaseManager +import co.touchlab.sqliter.withStatement +import co.touchlab.stately.concurrency.ThreadLocalRef +import co.touchlab.stately.concurrency.value +import com.powersync.persistence.driver.util.PoolLock + +public sealed class ConnectionWrapper : SqlDriver { + internal abstract fun accessConnection( + readOnly: Boolean, + block: ThreadConnection.() -> R, + ): R + + private fun accessStatement( + readOnly: Boolean, + identifier: Int?, + sql: String, + binders: (SqlPreparedStatement.() -> Unit)?, + block: (Statement) -> R, + ): R { + println("accessStatement: $sql") + return accessConnection(readOnly) { + println("accessStatement: 1") + val statement = useStatement(identifier, sql) + println("accessStatement: 2") + try { + println("accessStatement: 3") + if (binders != null) { + SqliterStatement(statement).binders() + } + println("accessStatement: 4") + val blockResult = block(statement) + println("accessStatement: 5") + blockResult + } finally { + statement.resetStatement() + clearIfNeeded(identifier, statement) + } + } + } + + final override fun execute( + identifier: Int?, + sql: String, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult = QueryResult.Value( + accessStatement(false, identifier, sql, binders) { statement -> + statement.executeUpdateDelete().toLong() + }, + ) + + final override fun executeQuery( + identifier: Int?, + sql: String, + mapper: (SqlCursor) -> QueryResult, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult = accessStatement(true, identifier, sql, binders) { statement -> + mapper(SqliterSqlCursor(statement.query())) + } +} + +/** + * Native driver implementation. + * + * The driver creates two connection pools, which default to 1 connection maximum. There is a reader pool, which + * handles all query requests outside of a transaction. The other pool is the transaction pool, which handles + * all transactions and write requests outside of a transaction. + * + * When a transaction is started, that thread is aligned with a transaction pool connection. Attempting a write or + * starting another transaction, if no connections are available, will cause the caller to wait. + * + * You can have multiple connections in the transaction pool, but this would only be useful for read transactions. Writing + * from multiple connections in an overlapping manner can be problematic. + * + * Aligning a transaction to a thread means you cannot operate on a single transaction from multiple threads. + * However, it would be difficult to find a use case where this would be desirable or safe. Currently, the native + * implementation of kotlinx.coroutines does not use thread pooling. When that changes, we'll need a way to handle + * transaction/connection alignment similar to what the Android/JVM driver implemented. + * + * https://medium.com/androiddevelopers/threading-models-in-coroutines-and-android-sqlite-api-6cab11f7eb90 + * + * To use SqlDelight during create/upgrade processes, you can alternatively wrap a real connection + * with wrapConnection. + * + * SqlPreparedStatement instances also do not point to real resources until either execute or + * executeQuery is called. The SqlPreparedStatement structure also maintains a thread-aligned + * instance which accumulates bind calls. Those are replayed on a real SQLite statement instance + * when execute or executeQuery is called. This avoids race conditions with bind calls. + */ +public class NativeSqliteDriver( + private val databaseManager: DatabaseManager, + maxReaderConnections: Int = 1, +) : ConnectionWrapper(), + SqlDriver { + public constructor( + configuration: DatabaseConfiguration, + maxReaderConnections: Int = 1, + ) : this( + databaseManager = createDatabaseManager(configuration), + maxReaderConnections = maxReaderConnections, + ) + + /** + * @param onConfiguration Callback to hook into [DatabaseConfiguration] creation. + */ + public constructor( + schema: SqlSchema>, + name: String, + maxReaderConnections: Int = 1, + onConfiguration: (DatabaseConfiguration) -> DatabaseConfiguration = { it }, + vararg callbacks: AfterVersion, + ) : this( + configuration = DatabaseConfiguration( + name = name, + version = if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), + create = { connection -> wrapConnection(connection) { schema.create(it) } }, + upgrade = { connection, oldVersion, newVersion -> + wrapConnection(connection) { schema.migrate(it, oldVersion.toLong(), newVersion.toLong(), *callbacks) } + }, + ).let(onConfiguration), + maxReaderConnections = maxReaderConnections, + ) + + // A pool of reader connections used by all operations not in a transaction + private val transactionPool: Pool + internal val readerPool: Pool + + // Once a transaction is started and connection borrowed, it will be here, but only for that + // thread + private val borrowedConnectionThread = ThreadLocalRef>() + private val listeners = mutableMapOf>() + private val lock = PoolLock(reentrant = true) + + init { + if (databaseManager.configuration.isEphemeral) { + // Single connection for transactions + transactionPool = Pool(1) { + ThreadConnection(databaseManager.createMultiThreadedConnection()) { _ -> + borrowedConnectionThread.let { + it.get()?.release() + it.value = null + } + } + } + + readerPool = transactionPool + } else { + // Single connection for transactions + transactionPool = Pool(1) { + ThreadConnection(databaseManager.createMultiThreadedConnection()) { _ -> + borrowedConnectionThread.let { + it.get()?.release() + it.value = null + } + } + } + + readerPool = Pool(maxReaderConnections) { + val connection = databaseManager.createMultiThreadedConnection() + connection.withStatement("PRAGMA query_only = 1") { execute() } // Ensure read only + ThreadConnection(connection) { + throw UnsupportedOperationException("Should never be in a transaction") + } + } + } + } + + override fun addListener(vararg queryKeys: String, listener: Query.Listener) { + lock.withLock { + queryKeys.forEach { + listeners.getOrPut(it) { mutableSetOf() }.add(listener) + } + } + } + + override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { + lock.withLock { + queryKeys.forEach { + listeners.get(it)?.remove(listener) + } + } + } + + override fun notifyListeners(vararg queryKeys: String) { + val listenersToNotify = mutableSetOf() + lock.withLock { + queryKeys.forEach { key -> listeners.get(key)?.let { listenersToNotify.addAll(it) } } + } + listenersToNotify.forEach(Query.Listener::queryResultsChanged) + } + + override fun currentTransaction(): Transacter.Transaction? { + println("currentTransaction() thread id: ${currentThreadId()}") + return borrowedConnectionThread.get()?.value?.transaction?.value + } + + override fun newTransaction(): QueryResult { + println("newTransaction() thread id: ${currentThreadId()}") + val alreadyBorrowed = borrowedConnectionThread.get() + val transaction = if (alreadyBorrowed == null) { + val borrowed = transactionPool.borrowEntry() + + try { + val trans = borrowed.value.newTransaction() + + borrowedConnectionThread.value = borrowed + trans + } catch (e: Throwable) { + // Unlock on failure. + borrowed.release() + throw e + } + } else { + alreadyBorrowed.value.newTransaction() + } + + return QueryResult.Value(transaction) + } + + /** + * If we're in a transaction, then I have a connection. Otherwise use shared. + */ + override fun accessConnection( + readOnly: Boolean, + block: ThreadConnection.() -> R, + ): R { + println("accessConnection() thread id: ${currentThreadId()}") + val mine = borrowedConnectionThread.get() + println("accessConnection() with connection $mine. Thread id: ${currentThreadId()}") + return if (readOnly) { + // Code intends to read, which doesn't need to block + if (mine != null) { + mine.value.block() + } else { + println("accessConnection() before readerPool") + val conn = readerPool.access(block) + println("accessConnection() after readerPool") + conn + } + } else { + // Code intends to write, for which we're managing locks in code + if (mine != null) { + mine.value.block() + } else { + transactionPool.access(block) + } + } + } + + override fun close() { + transactionPool.close() + readerPool.close() + } +} + +/** + * Helper function to create an in-memory driver. In-memory drivers have a single connection, so + * concurrent access will be block + */ +public fun inMemoryDriver(schema: SqlSchema>): NativeSqliteDriver = NativeSqliteDriver( + DatabaseConfiguration( + name = null, + inMemory = true, + version = if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), + create = { connection -> + wrapConnection(connection) { schema.create(it) } + }, + upgrade = { connection, oldVersion, newVersion -> + wrapConnection(connection) { schema.migrate(it, oldVersion.toLong(), newVersion.toLong()) } + }, + ), +) + +/** + * Sqliter's DatabaseConfiguration takes lambda arguments for it's create and upgrade operations, + * which each take a DatabaseConnection argument. Use wrapConnection to have SqlDelight access this + * passed connection and avoid the pooling that the full SqlDriver instance performs. + * + * Note that queries created during this operation will be cleaned up. If holding onto a cursor from + * a wrap call, it will no longer be viable. + */ +public fun wrapConnection( + connection: DatabaseConnection, + block: (SqlDriver) -> Unit, +) { + val conn = SqliterWrappedConnection(ThreadConnection(connection) {}) + try { + block(conn) + } finally { + conn.close() + } +} + +/** + * SqlDriverConnection that wraps a Sqliter connection. Useful for migration tasks, or if you + * don't want the polling. + */ +internal class SqliterWrappedConnection( + private val threadConnection: ThreadConnection, +) : ConnectionWrapper(), + SqlDriver { + override fun currentTransaction(): Transacter.Transaction? = threadConnection.transaction.value + + override fun newTransaction(): QueryResult = QueryResult.Value(threadConnection.newTransaction()) + + override fun accessConnection( + readOnly: Boolean, + block: ThreadConnection.() -> R, + ): R = threadConnection.block() + + override fun addListener(vararg queryKeys: String, listener: Query.Listener) { + // No-op + } + + override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { + // No-op + } + + override fun notifyListeners(vararg queryKeys: String) { + // No-op + } + + override fun close() { + threadConnection.cleanUp() + } +} + +/** + * Wraps and manages a "real" database connection. + * + * SQLite statements are specific to connections, and must be finalized explicitly. Cursors are + * backed by a statement resource, so we keep links to open cursors to allow us to close them out + * properly in cases where the user does not. + */ +internal class ThreadConnection( + private val connection: DatabaseConnection, + private val onEndTransaction: (ThreadConnection) -> Unit, +) : Closeable { + internal val transaction = ThreadLocalRef() + private val closed: Boolean + get() = connection.closed + + private val statementCache = mutableMapOf() + + fun useStatement(identifier: Int?, sql: String): Statement { + return if (identifier != null) { + statementCache.getOrPut(identifier) { + connection.createStatement(sql) + } + } else { + connection.createStatement(sql) + } + } + + fun clearIfNeeded(identifier: Int?, statement: Statement) { + if (identifier == null || closed) { + statement.finalizeStatement() + } + } + + fun newTransaction(): Transacter.Transaction { + val enclosing = transaction.value + + // Create here, in case we bomb... + if (enclosing == null) { + connection.beginTransaction() + } + + val trans = Transaction(enclosing) + transaction.value = trans + + return trans + } + + /** + * This should only be called directly from wrapConnection. Clean resources without actually closing + * the underlying connection. + */ + internal fun cleanUp() { + statementCache.values.forEach { it: Statement -> + it.finalizeStatement() + } + } + + override fun close() { + cleanUp() + connection.close() + } + + private inner class Transaction( + override val enclosingTransaction: Transacter.Transaction?, + ) : Transacter.Transaction() { + + override fun endTransaction(successful: Boolean): QueryResult { + transaction.value = enclosingTransaction + + if (enclosingTransaction == null) { + try { + if (successful) { + connection.setTransactionSuccessful() + } + + connection.endTransaction() + } finally { + // Release if we have + onEndTransaction(this@ThreadConnection) + } + } + return QueryResult.Unit + } + } +} + +private inline val DatabaseConfiguration.isEphemeral: Boolean get() { + return inMemory || (name?.isEmpty() == true && extendedConfig.basePath?.isEmpty() == true) +} diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt new file mode 100644 index 00000000..d94b2d5e --- /dev/null +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt @@ -0,0 +1,123 @@ +package com.powersync.persistence.driver + +import app.cash.sqldelight.db.Closeable +import co.touchlab.stately.concurrency.AtomicBoolean +import com.powersync.persistence.driver.util.PoolLock +import kotlin.concurrent.AtomicReference + +/** + * A shared pool of connections. Borrowing is blocking when all connections are in use, and the pool has reached its + * designated capacity. + */ +internal class Pool(internal val capacity: Int, private val producer: () -> T) { + /** + * Hold a list of active connections. If it is null, it means the MultiPool has been closed. + */ + private val entriesRef = AtomicReference?>(listOf()) + private val poolLock = PoolLock() + + /** + * For test purposes only + */ + internal fun entryCount(): Int = poolLock.withLock { + entriesRef.value?.size ?: 0 + } + + fun borrowEntry(): Borrowed { + val snapshot = entriesRef.value ?: throw ClosedMultiPoolException + + // Fastpath: Borrow the first available entry. + val firstAvailable = snapshot.firstOrNull { it.tryToAcquire() } + + if (firstAvailable != null) { + return firstAvailable.asBorrowed(poolLock) + } + + // Slowpath: Create a new entry if capacity limit has not been reached, or wait for the next available entry. + val nextAvailable = poolLock.withLock { + // Reload the list since it could've been updated by other threads concurrently. + val entries = entriesRef.value ?: throw ClosedMultiPoolException + + if (entries.count() < capacity) { + // Capacity hasn't been reached — create a new entry to serve this call. + val newEntry = Entry(producer()) + val done = newEntry.tryToAcquire() + check(done) + + entriesRef.value = (entries + listOf(newEntry)) + return@withLock newEntry + } else { + // Capacity is reached — wait for the next available entry. + return@withLock loopForConditionalResult { + // Reload the list, since the thread can be suspended here while the list of entries has been modified. + val innerEntries = entriesRef.value ?: throw ClosedMultiPoolException + innerEntries.firstOrNull { it.tryToAcquire() } + } + } + } + + return nextAvailable.asBorrowed(poolLock) + } + + fun access(action: (T) -> R): R { + val borrowed = borrowEntry() + return try { + println("before access, capacity: $capacity") + val result = action(borrowed.value) + println("after access") + result + } finally { + borrowed.release() + } + } + + fun close() { + if (!poolLock.close()) { + return + } + + val entries = entriesRef.value + val done = entriesRef.compareAndSet(entries, null) + check(done) + + entries?.forEach { it.value.close() } + } + + inner class Entry(val value: T) { + val isAvailable = AtomicBoolean(true) + + fun tryToAcquire(): Boolean = isAvailable.compareAndSet(expected = true, new = false) + + fun asBorrowed(poolLock: PoolLock): Borrowed = object : Borrowed { + override val value: T + get() = this@Entry.value + + override fun release() { + /** + * Mark-as-available should be done before signalling blocked threads via [PoolLock.notifyConditionChanged], + * since the happens-before relationship guarantees the woken thread to see the + * available entry (if not having been taken by other threads during the wake-up lead time). + */ + + val done = isAvailable.compareAndSet(expected = false, new = true) + check(done) + + // While signalling blocked threads does not require locking, doing so avoids a subtle race + // condition in which: + // + // 1. a [loopForConditionalResult] iteration in [borrowEntry] slow path is happening concurrently; + // 2. the iteration fails to see the atomic `isAvailable = true` above; + // 3. we signal availability here but it is a no-op due to no waiting blocker; and finally + // 4. the iteration entered an indefinite blocking wait, not being aware of us having signalled availability here. + // + // By acquiring the pool lock first, signalling cannot happen concurrently with the loop + // iterations in [borrowEntry], thus eliminating the race condition. + poolLock.withLock { + poolLock.notifyConditionChanged() + } + } + } + } +} + +private val ClosedMultiPoolException get() = IllegalStateException("Attempt to access a closed MultiPool.") diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt new file mode 100644 index 00000000..b1c36090 --- /dev/null +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt @@ -0,0 +1,34 @@ +package com.powersync.persistence.driver + +import app.cash.sqldelight.db.QueryResult +import app.cash.sqldelight.db.SqlCursor +import co.touchlab.sqliter.Cursor +import co.touchlab.sqliter.getBytesOrNull +import co.touchlab.sqliter.getDoubleOrNull +import co.touchlab.sqliter.getLongOrNull +import co.touchlab.sqliter.getStringOrNull + +/** + * Wrapper for cursor calls. Cursors point to real SQLite statements, so we need to be careful with + * them. If dev closes the outer structure, this will get closed as well, which means it could start + * throwing errors if you're trying to access it. + */ +internal class SqliterSqlCursor(private val cursor: Cursor) : ColNamesSqlCursor { + override fun getBytes(index: Int): ByteArray? = cursor.getBytesOrNull(index) + + override fun getDouble(index: Int): Double? = cursor.getDoubleOrNull(index) + + override fun getLong(index: Int): Long? = cursor.getLongOrNull(index) + + override fun getString(index: Int): String? = cursor.getStringOrNull(index) + + override fun getBoolean(index: Int): Boolean? { + return (cursor.getLongOrNull(index) ?: return null) == 1L + } + + override fun columnName(index: Int): String? = cursor.columnName(index) + + override val columnCount: Int = cursor.columnCount + + override fun next(): QueryResult.Value = QueryResult.Value(cursor.next()) +} diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt new file mode 100644 index 00000000..44f8a5c1 --- /dev/null +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt @@ -0,0 +1,42 @@ +package com.powersync.persistence.driver + +import app.cash.sqldelight.db.SqlPreparedStatement +import co.touchlab.sqliter.Statement +import co.touchlab.sqliter.bindBlob +import co.touchlab.sqliter.bindDouble +import co.touchlab.sqliter.bindLong +import co.touchlab.sqliter.bindString + +/** + * @param [recycle] A function which recycles any resources this statement is backed by. + */ +internal class SqliterStatement( + private val statement: Statement, +) : SqlPreparedStatement { + override fun bindBytes(index: Int, bytes: ByteArray?) { + statement.bindBlob(index + 1, bytes) + } + + override fun bindLong(index: Int, long: Long?) { + statement.bindLong(index + 1, long) + } + + override fun bindDouble(index: Int, double: Double?) { + statement.bindDouble(index + 1, double) + } + + override fun bindString(index: Int, string: String?) { + statement.bindString(index + 1, string) + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + statement.bindLong( + index + 1, + when (boolean) { + null -> null + true -> 1L + false -> 0L + }, + ) + } +} diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt new file mode 100644 index 00000000..a68cc553 --- /dev/null +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt @@ -0,0 +1,89 @@ +package com.powersync.persistence.driver.util + +import co.touchlab.stately.concurrency.AtomicBoolean +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.alloc +import kotlinx.cinterop.free +import kotlinx.cinterop.nativeHeap +import kotlinx.cinterop.ptr +import platform.posix.pthread_cond_destroy +import platform.posix.pthread_cond_init +import platform.posix.pthread_cond_signal +import platform.posix.pthread_cond_t +import platform.posix.pthread_cond_wait +import platform.posix.pthread_mutex_destroy +import platform.posix.pthread_mutex_init +import platform.posix.pthread_mutex_lock +import platform.posix.pthread_mutex_t +import platform.posix.pthread_mutex_unlock +import platform.posix.pthread_mutexattr_destroy +import platform.posix.pthread_mutexattr_init +import platform.posix.pthread_mutexattr_settype +import platform.posix.pthread_mutexattr_t + +@OptIn(ExperimentalForeignApi::class) +internal class PoolLock constructor(reentrant: Boolean = false) { + private val isActive = AtomicBoolean(true) + + private val attr = nativeHeap.alloc() + .apply { + pthread_mutexattr_init(ptr) + if (reentrant) { + pthread_mutexattr_settype(ptr, platform.posix.PTHREAD_MUTEX_RECURSIVE) + } + } + private val mutex = nativeHeap.alloc() + .apply { pthread_mutex_init(ptr, attr.ptr) } + private val cond = nativeHeap.alloc() + .apply { pthread_cond_init(ptr, null) } + + fun withLock( + action: CriticalSection.() -> R, + ): R { + check(isActive.value) + pthread_mutex_lock(mutex.ptr) + + val result: R + + try { + result = action(CriticalSection()) + } finally { + pthread_mutex_unlock(mutex.ptr) + } + + return result + } + + fun notifyConditionChanged() { + pthread_cond_signal(cond.ptr) + } + + fun close(): Boolean { + if (isActive.compareAndSet(expected = true, new = false)) { + pthread_cond_destroy(cond.ptr) + pthread_mutex_destroy(mutex.ptr) + pthread_mutexattr_destroy(attr.ptr) + nativeHeap.free(cond) + nativeHeap.free(mutex) + nativeHeap.free(attr) + return true + } + + return false + } + + inner class CriticalSection { + fun loopForConditionalResult(block: () -> R?): R { + check(isActive.value) + + var result = block() + + while (result == null) { + pthread_cond_wait(cond.ptr, mutex.ptr) + result = block() + } + + return result + } + } +} \ No newline at end of file diff --git a/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt b/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt new file mode 100644 index 00000000..014e480a --- /dev/null +++ b/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt @@ -0,0 +1,163 @@ +package com.powersync.persistence.driver + +import app.cash.sqldelight.db.QueryResult +import app.cash.sqldelight.db.SqlCursor +import app.cash.sqldelight.db.SqlPreparedStatement +import java.math.BigDecimal +import java.sql.PreparedStatement +import java.sql.ResultSet +import java.sql.Types + +/** +* Binds the parameter to [preparedStatement] by calling [bindString], [bindLong] or similar. +* After binding, [execute] executes the query without a result, while [executeQuery] returns [JdbcCursor]. +*/ +public class JdbcPreparedStatement( + private val preparedStatement: PreparedStatement, +) : SqlPreparedStatement { + override fun bindBytes(index: Int, bytes: ByteArray?) { + preparedStatement.setBytes(index + 1, bytes) + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + if (boolean == null) { + preparedStatement.setNull(index + 1, Types.BOOLEAN) + } else { + preparedStatement.setBoolean(index + 1, boolean) + } + } + + public fun bindByte(index: Int, byte: Byte?) { + if (byte == null) { + preparedStatement.setNull(index + 1, Types.TINYINT) + } else { + preparedStatement.setByte(index + 1, byte) + } + } + + public fun bindShort(index: Int, short: Short?) { + if (short == null) { + preparedStatement.setNull(index + 1, Types.SMALLINT) + } else { + preparedStatement.setShort(index + 1, short) + } + } + + public fun bindInt(index: Int, int: Int?) { + if (int == null) { + preparedStatement.setNull(index + 1, Types.INTEGER) + } else { + preparedStatement.setInt(index + 1, int) + } + } + + override fun bindLong(index: Int, long: Long?) { + if (long == null) { + preparedStatement.setNull(index + 1, Types.BIGINT) + } else { + preparedStatement.setLong(index + 1, long) + } + } + + public fun bindFloat(index: Int, float: Float?) { + if (float == null) { + preparedStatement.setNull(index + 1, Types.REAL) + } else { + preparedStatement.setFloat(index + 1, float) + } + } + + override fun bindDouble(index: Int, double: Double?) { + if (double == null) { + preparedStatement.setNull(index + 1, Types.DOUBLE) + } else { + preparedStatement.setDouble(index + 1, double) + } + } + + public fun bindBigDecimal(index: Int, decimal: BigDecimal?) { + preparedStatement.setBigDecimal(index + 1, decimal) + } + + public fun bindObject(index: Int, obj: Any?) { + if (obj == null) { + preparedStatement.setNull(index + 1, Types.OTHER) + } else { + preparedStatement.setObject(index + 1, obj) + } + } + + public fun bindObject(index: Int, obj: Any?, type: Int) { + if (obj == null) { + preparedStatement.setNull(index + 1, type) + } else { + preparedStatement.setObject(index + 1, obj, type) + } + } + + override fun bindString(index: Int, string: String?) { + preparedStatement.setString(index + 1, string) + } + + public fun bindDate(index: Int, date: java.sql.Date?) { + preparedStatement.setDate(index, date) + } + + public fun bindTime(index: Int, date: java.sql.Time?) { + preparedStatement.setTime(index, date) + } + + public fun bindTimestamp(index: Int, timestamp: java.sql.Timestamp?) { + preparedStatement.setTimestamp(index, timestamp) + } + + public fun executeQuery(mapper: (SqlCursor) -> R): R { + try { + return preparedStatement.executeQuery() + .use { resultSet -> mapper(JdbcCursor(resultSet)) } + } finally { + preparedStatement.close() + } + } + + public fun execute(): Long { + return if (preparedStatement.execute()) { + // returned true so this is a result set return type. + 0L + } else { + preparedStatement.updateCount.toLong() + } + } +} + +/** + * Iterate each row in [resultSet] and map the columns to Kotlin classes by calling [getString], [getLong] etc. + * Use [next] to retrieve the next row and [close] to close the connection. + */ +internal class JdbcCursor(val resultSet: ResultSet) : ColNamesSqlCursor { + override fun getString(index: Int): String? = resultSet.getString(index + 1) + override fun getBytes(index: Int): ByteArray? = resultSet.getBytes(index + 1) + override fun getBoolean(index: Int): Boolean? = getAtIndex(index, resultSet::getBoolean) + override fun columnName(index: Int): String? = resultSet.metaData.getColumnName(index) + override val columnCount: Int = resultSet.metaData.columnCount + + fun getByte(index: Int): Byte? = getAtIndex(index, resultSet::getByte) + fun getShort(index: Int): Short? = getAtIndex(index, resultSet::getShort) + fun getInt(index: Int): Int? = getAtIndex(index, resultSet::getInt) + override fun getLong(index: Int): Long? = getAtIndex(index, resultSet::getLong) + fun getFloat(index: Int): Float? = getAtIndex(index, resultSet::getFloat) + override fun getDouble(index: Int): Double? = getAtIndex(index, resultSet::getDouble) + fun getBigDecimal(index: Int): BigDecimal? = resultSet.getBigDecimal(index + 1) + inline fun getObject(index: Int): T? = resultSet.getObject(index + 1, T::class.java) + fun getDate(index: Int): java.sql.Date? = resultSet.getDate(index) + fun getTime(index: Int): java.sql.Time? = resultSet.getTime(index) + fun getTimestamp(index: Int): java.sql.Timestamp? = resultSet.getTimestamp(index) + + @Suppress("UNCHECKED_CAST") + fun getArray(index: Int) = getAtIndex(index, resultSet::getArray)?.array as Array? + + private fun getAtIndex(index: Int, converter: (Int) -> T): T? = + converter(index + 1).takeUnless { resultSet.wasNull() } + + override fun next(): QueryResult.Value = QueryResult.Value(resultSet.next()) +} From 44494d2c8fa98fefc46a04821504ef551b2c6acb Mon Sep 17 00:00:00 2001 From: Kevin Galligan Date: Sun, 2 Feb 2025 23:34:53 -0500 Subject: [PATCH 3/3] Clean up formatting --- .../persistence/driver/AndroidSqliteDriver.kt | 564 ++++++++-------- .../persistence/driver/ColNamesSqlCursor.kt | 2 +- .../powersync/persistence/driver/Borrowed.kt | 4 +- .../persistence/driver/NativeSqlDatabase.kt | 604 +++++++++--------- .../com/powersync/persistence/driver/Pool.kt | 183 +++--- .../persistence/driver/SqliterSqlCursor.kt | 21 +- .../persistence/driver/SqliterStatement.kt | 46 +- .../persistence/driver/util/PoolLock.kt | 100 +-- .../driver/JdbcPreparedStatement.kt | 280 ++++---- 9 files changed, 894 insertions(+), 910 deletions(-) diff --git a/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt b/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt index 70433a17..3390c9e3 100644 --- a/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt +++ b/persistence/src/androidMain/kotlin/com/powersync/persistence/driver/AndroidSqliteDriver.kt @@ -24,335 +24,335 @@ import app.cash.sqldelight.db.SqlPreparedStatement import app.cash.sqldelight.db.SqlSchema import com.powersync.persistence.driver.Api28Impl.setWindowSize -import kotlin.collections.forEach -import kotlin.collections.getOrPut -import kotlin.io.use -import kotlin.let - private const val DEFAULT_CACHE_SIZE = 20 public class AndroidSqliteDriver private constructor( - private val openHelper: SupportSQLiteOpenHelper? = null, - database: SupportSQLiteDatabase? = null, - private val cacheSize: Int, - private val windowSizeBytes: Long? = null, + private val openHelper: SupportSQLiteOpenHelper? = null, + database: SupportSQLiteDatabase? = null, + private val cacheSize: Int, + private val windowSizeBytes: Long? = null, ) : SqlDriver { - init { - require((openHelper != null) xor (database != null)) - } - - private val transactions = ThreadLocal() - private val database by lazy { - openHelper?.writableDatabase ?: database!! - } - - public constructor( - openHelper: SupportSQLiteOpenHelper, - ) : this(openHelper = openHelper, database = null, cacheSize = DEFAULT_CACHE_SIZE, windowSizeBytes = null) - - /** - * @param [cacheSize] The number of compiled sqlite statements to keep in memory per connection. - * Defaults to 20. - * @param [useNoBackupDirectory] Sets whether to use a no backup directory or not. - * @param [windowSizeBytes] Size of cursor window in bytes, per [CursorWindow] (Android 28+ only), or null to use the default. - */ - @JvmOverloads - public constructor( - schema: SqlSchema>, - context: Context, - name: String? = null, - factory: SupportSQLiteOpenHelper.Factory = FrameworkSQLiteOpenHelperFactory(), - callback: SupportSQLiteOpenHelper.Callback = AndroidSqliteDriver.Callback(schema), - cacheSize: Int = DEFAULT_CACHE_SIZE, - useNoBackupDirectory: Boolean = false, - windowSizeBytes: Long? = null, - ) : this( - database = null, - openHelper = factory.create( - SupportSQLiteOpenHelper.Configuration.builder(context) - .callback(callback) - .name(name) - .noBackupDirectory(useNoBackupDirectory) - .build(), - ), - cacheSize = cacheSize, - windowSizeBytes = windowSizeBytes, - ) - - @JvmOverloads - public constructor( - database: SupportSQLiteDatabase, - cacheSize: Int = DEFAULT_CACHE_SIZE, - windowSizeBytes: Long? = null, - ) : this(openHelper = null, database = database, cacheSize = cacheSize, windowSizeBytes = windowSizeBytes) - - private val statements = object : LruCache(cacheSize) { - override fun entryRemoved( - evicted: Boolean, - key: Int, - oldValue: AndroidStatement, - newValue: AndroidStatement?, - ) { - if (evicted) oldValue.close() + init { + require((openHelper != null) xor (database != null)) } - } - - private val listeners = linkedMapOf>() - - override fun addListener(vararg queryKeys: String, listener: Query.Listener) { - synchronized(listeners) { - queryKeys.forEach { - listeners.getOrPut(it, { linkedSetOf() }).add(listener) - } - } - } - - override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { - synchronized(listeners) { - queryKeys.forEach { - listeners[it]?.remove(listener) - } - } - } - - override fun notifyListeners(vararg queryKeys: String) { - val listenersToNotify = linkedSetOf() - synchronized(listeners) { - queryKeys.forEach { listeners[it]?.let(listenersToNotify::addAll) } - } - listenersToNotify.forEach(Query.Listener::queryResultsChanged) - } - - override fun newTransaction(): QueryResult { - val enclosing = transactions.get() - val transaction = Transaction(enclosing) - transactions.set(transaction) - - if (enclosing == null) { - database.beginTransactionNonExclusive() + + private val transactions = ThreadLocal() + private val database by lazy { + openHelper?.writableDatabase ?: database!! } - return QueryResult.Value(transaction) - } + public constructor( + openHelper: SupportSQLiteOpenHelper, + ) : this(openHelper = openHelper, database = null, cacheSize = DEFAULT_CACHE_SIZE, windowSizeBytes = null) + + /** + * @param [cacheSize] The number of compiled sqlite statements to keep in memory per connection. + * Defaults to 20. + * @param [useNoBackupDirectory] Sets whether to use a no backup directory or not. + * @param [windowSizeBytes] Size of cursor window in bytes, per [CursorWindow] (Android 28+ only), or null to use the default. + */ + @JvmOverloads + public constructor( + schema: SqlSchema>, + context: Context, + name: String? = null, + factory: SupportSQLiteOpenHelper.Factory = FrameworkSQLiteOpenHelperFactory(), + callback: SupportSQLiteOpenHelper.Callback = AndroidSqliteDriver.Callback(schema), + cacheSize: Int = DEFAULT_CACHE_SIZE, + useNoBackupDirectory: Boolean = false, + windowSizeBytes: Long? = null, + ) : this( + database = null, + openHelper = factory.create( + SupportSQLiteOpenHelper.Configuration.builder(context) + .callback(callback) + .name(name) + .noBackupDirectory(useNoBackupDirectory) + .build(), + ), + cacheSize = cacheSize, + windowSizeBytes = windowSizeBytes, + ) + + @JvmOverloads + public constructor( + database: SupportSQLiteDatabase, + cacheSize: Int = DEFAULT_CACHE_SIZE, + windowSizeBytes: Long? = null, + ) : this(openHelper = null, database = database, cacheSize = cacheSize, windowSizeBytes = windowSizeBytes) + + private val statements = object : LruCache(cacheSize) { + override fun entryRemoved( + evicted: Boolean, + key: Int, + oldValue: AndroidStatement, + newValue: AndroidStatement?, + ) { + if (evicted) oldValue.close() + } + } - override fun currentTransaction(): Transacter.Transaction? = transactions.get() + private val listeners = linkedMapOf>() - internal inner class Transaction( - override val enclosingTransaction: Transacter.Transaction?, - ) : Transacter.Transaction() { - override fun endTransaction(successful: Boolean): QueryResult { - if (enclosingTransaction == null) { - if (successful) { - database.setTransactionSuccessful() - database.endTransaction() - } else { - database.endTransaction() + override fun addListener(vararg queryKeys: String, listener: Query.Listener) { + synchronized(listeners) { + queryKeys.forEach { + listeners.getOrPut(it, { linkedSetOf() }).add(listener) + } } - } - transactions.set(enclosingTransaction) - return QueryResult.Unit } - } - - private fun execute( - identifier: Int?, - createStatement: () -> AndroidStatement, - binders: (SqlPreparedStatement.() -> Unit)?, - result: AndroidStatement.() -> T, - ): QueryResult.Value { - var statement: AndroidStatement? = null - if (identifier != null) { - statement = statements.remove(identifier) + + override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { + synchronized(listeners) { + queryKeys.forEach { + listeners[it]?.remove(listener) + } + } } - if (statement == null) { - statement = createStatement() + + override fun notifyListeners(vararg queryKeys: String) { + val listenersToNotify = linkedSetOf() + synchronized(listeners) { + queryKeys.forEach { listeners[it]?.let(listenersToNotify::addAll) } + } + listenersToNotify.forEach(Query.Listener::queryResultsChanged) } - try { - if (binders != null) { - statement.binders() - } - return QueryResult.Value(statement.result()) - } finally { - if (identifier != null) { - statements.put(identifier, statement)?.close() - } else { - statement.close() - } + + override fun newTransaction(): QueryResult { + val enclosing = transactions.get() + val transaction = Transaction(enclosing) + transactions.set(transaction) + + if (enclosing == null) { + database.beginTransactionNonExclusive() + } + + return QueryResult.Value(transaction) } - } - - override fun execute( - identifier: Int?, - sql: String, - parameters: Int, - binders: (SqlPreparedStatement.() -> Unit)?, - ): QueryResult = execute(identifier, { AndroidPreparedStatement(database.compileStatement(sql)) }, binders, { execute() }) - - override fun executeQuery( - identifier: Int?, - sql: String, - mapper: (SqlCursor) -> QueryResult, - parameters: Int, - binders: (SqlPreparedStatement.() -> Unit)?, - ): QueryResult.Value = execute(identifier, { AndroidQuery(sql, database, parameters, windowSizeBytes) }, binders) { executeQuery(mapper) } - - override fun close() { - statements.evictAll() - return openHelper?.close() ?: database.close() - } - - public open class Callback( - private val schema: SqlSchema>, - private vararg val callbacks: AfterVersion, - ) : SupportSQLiteOpenHelper.Callback( - if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), - ) { - - override fun onCreate(db: SupportSQLiteDatabase) { - schema.create(AndroidSqliteDriver(openHelper = null, database = db, cacheSize = 1)) + + override fun currentTransaction(): Transacter.Transaction? = transactions.get() + + internal inner class Transaction( + override val enclosingTransaction: Transacter.Transaction?, + ) : Transacter.Transaction() { + override fun endTransaction(successful: Boolean): QueryResult { + if (enclosingTransaction == null) { + if (successful) { + database.setTransactionSuccessful() + database.endTransaction() + } else { + database.endTransaction() + } + } + transactions.set(enclosingTransaction) + return QueryResult.Unit + } + } + + private fun execute( + identifier: Int?, + createStatement: () -> AndroidStatement, + binders: (SqlPreparedStatement.() -> Unit)?, + result: AndroidStatement.() -> T, + ): QueryResult.Value { + var statement: AndroidStatement? = null + if (identifier != null) { + statement = statements.remove(identifier) + } + if (statement == null) { + statement = createStatement() + } + try { + if (binders != null) { + statement.binders() + } + return QueryResult.Value(statement.result()) + } finally { + if (identifier != null) { + statements.put(identifier, statement)?.close() + } else { + statement.close() + } + } } - override fun onUpgrade( - db: SupportSQLiteDatabase, - oldVersion: Int, - newVersion: Int, + override fun execute( + identifier: Int?, + sql: String, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult = + execute(identifier, { AndroidPreparedStatement(database.compileStatement(sql)) }, binders, { execute() }) + + override fun executeQuery( + identifier: Int?, + sql: String, + mapper: (SqlCursor) -> QueryResult, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult.Value = execute( + identifier, + { AndroidQuery(sql, database, parameters, windowSizeBytes) }, + binders + ) { executeQuery(mapper) } + + override fun close() { + statements.evictAll() + return openHelper?.close() ?: database.close() + } + + public open class Callback( + private val schema: SqlSchema>, + private vararg val callbacks: AfterVersion, + ) : SupportSQLiteOpenHelper.Callback( + if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), ) { - schema.migrate( - AndroidSqliteDriver(openHelper = null, database = db, cacheSize = 1), - oldVersion.toLong(), - newVersion.toLong(), - *callbacks, - ) + + override fun onCreate(db: SupportSQLiteDatabase) { + schema.create(AndroidSqliteDriver(openHelper = null, database = db, cacheSize = 1)) + } + + override fun onUpgrade( + db: SupportSQLiteDatabase, + oldVersion: Int, + newVersion: Int, + ) { + schema.migrate( + AndroidSqliteDriver(openHelper = null, database = db, cacheSize = 1), + oldVersion.toLong(), + newVersion.toLong(), + *callbacks, + ) + } } - } } internal interface AndroidStatement : SqlPreparedStatement { - fun execute(): Long - fun executeQuery(mapper: (SqlCursor) -> QueryResult): R - fun close() + fun execute(): Long + fun executeQuery(mapper: (SqlCursor) -> QueryResult): R + fun close() } private class AndroidPreparedStatement( - private val statement: SupportSQLiteStatement, + private val statement: SupportSQLiteStatement, ) : AndroidStatement { - override fun bindBytes(index: Int, bytes: ByteArray?) { - if (bytes == null) statement.bindNull(index + 1) else statement.bindBlob(index + 1, bytes) - } - - override fun bindLong(index: Int, long: Long?) { - if (long == null) statement.bindNull(index + 1) else statement.bindLong(index + 1, long) - } - - override fun bindDouble(index: Int, double: Double?) { - if (double == null) statement.bindNull(index + 1) else statement.bindDouble(index + 1, double) - } - - override fun bindString(index: Int, string: String?) { - if (string == null) statement.bindNull(index + 1) else statement.bindString(index + 1, string) - } - - override fun bindBoolean(index: Int, boolean: Boolean?) { - if (boolean == null) { - statement.bindNull(index + 1) - } else { - statement.bindLong(index + 1, if (boolean) 1L else 0L) + override fun bindBytes(index: Int, bytes: ByteArray?) { + if (bytes == null) statement.bindNull(index + 1) else statement.bindBlob(index + 1, bytes) } - } - override fun executeQuery(mapper: (SqlCursor) -> QueryResult): R = throw UnsupportedOperationException() + override fun bindLong(index: Int, long: Long?) { + if (long == null) statement.bindNull(index + 1) else statement.bindLong(index + 1, long) + } - override fun execute(): Long { - return statement.executeUpdateDelete().toLong() - } + override fun bindDouble(index: Int, double: Double?) { + if (double == null) statement.bindNull(index + 1) else statement.bindDouble(index + 1, double) + } + + override fun bindString(index: Int, string: String?) { + if (string == null) statement.bindNull(index + 1) else statement.bindString(index + 1, string) + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + if (boolean == null) { + statement.bindNull(index + 1) + } else { + statement.bindLong(index + 1, if (boolean) 1L else 0L) + } + } - override fun close() { - statement.close() - } + override fun executeQuery(mapper: (SqlCursor) -> QueryResult): R = throw UnsupportedOperationException() + + override fun execute(): Long { + return statement.executeUpdateDelete().toLong() + } + + override fun close() { + statement.close() + } } private class AndroidQuery( - override val sql: String, - private val database: SupportSQLiteDatabase, - override val argCount: Int, - private val windowSizeBytes: Long?, + override val sql: String, + private val database: SupportSQLiteDatabase, + override val argCount: Int, + private val windowSizeBytes: Long?, ) : SupportSQLiteQuery, - AndroidStatement { - private val binds = MutableList<((SupportSQLiteProgram) -> Unit)?>(argCount) { null } - - override fun bindBytes(index: Int, bytes: ByteArray?) { - binds[index] = { if (bytes == null) it.bindNull(index + 1) else it.bindBlob(index + 1, bytes) } - } - - override fun bindLong(index: Int, long: Long?) { - binds[index] = { if (long == null) it.bindNull(index + 1) else it.bindLong(index + 1, long) } - } - - override fun bindDouble(index: Int, double: Double?) { - binds[index] = { if (double == null) it.bindNull(index + 1) else it.bindDouble(index + 1, double) } - } - - override fun bindString(index: Int, string: String?) { - binds[index] = { if (string == null) it.bindNull(index + 1) else it.bindString(index + 1, string) } - } - - override fun bindBoolean(index: Int, boolean: Boolean?) { - binds[index] = { - if (boolean == null) { - it.bindNull(index + 1) - } else { - it.bindLong(index + 1, if (boolean) 1L else 0L) - } + AndroidStatement { + private val binds = MutableList<((SupportSQLiteProgram) -> Unit)?>(argCount) { null } + + override fun bindBytes(index: Int, bytes: ByteArray?) { + binds[index] = { if (bytes == null) it.bindNull(index + 1) else it.bindBlob(index + 1, bytes) } + } + + override fun bindLong(index: Int, long: Long?) { + binds[index] = { if (long == null) it.bindNull(index + 1) else it.bindLong(index + 1, long) } + } + + override fun bindDouble(index: Int, double: Double?) { + binds[index] = { if (double == null) it.bindNull(index + 1) else it.bindDouble(index + 1, double) } + } + + override fun bindString(index: Int, string: String?) { + binds[index] = { if (string == null) it.bindNull(index + 1) else it.bindString(index + 1, string) } + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + binds[index] = { + if (boolean == null) { + it.bindNull(index + 1) + } else { + it.bindLong(index + 1, if (boolean) 1L else 0L) + } + } } - } - override fun execute() = throw UnsupportedOperationException() + override fun execute() = throw UnsupportedOperationException() - override fun executeQuery(mapper: (SqlCursor) -> QueryResult): R { - return database.query(this) - .use { cursor -> mapper(AndroidCursor(cursor, windowSizeBytes)).value } - } + override fun executeQuery(mapper: (SqlCursor) -> QueryResult): R { + return database.query(this) + .use { cursor -> mapper(AndroidCursor(cursor, windowSizeBytes)).value } + } - override fun bindTo(statement: SupportSQLiteProgram) { - for (action in binds) { - action!!(statement) + override fun bindTo(statement: SupportSQLiteProgram) { + for (action in binds) { + action!!(statement) + } } - } - override fun toString() = sql + override fun toString() = sql - override fun close() {} + override fun close() {} } private class AndroidCursor( - private val cursor: Cursor, - windowSizeBytes: Long?, + private val cursor: Cursor, + windowSizeBytes: Long?, ) : ColNamesSqlCursor { - init { - if ( - Build.VERSION.SDK_INT >= Build.VERSION_CODES.P && - windowSizeBytes != null && - cursor is AbstractWindowedCursor - ) { - cursor.setWindowSize(windowSizeBytes) + init { + if ( + Build.VERSION.SDK_INT >= Build.VERSION_CODES.P && + windowSizeBytes != null && + cursor is AbstractWindowedCursor + ) { + cursor.setWindowSize(windowSizeBytes) + } } - } - - override fun next(): QueryResult.Value = QueryResult.Value(cursor.moveToNext()) - override fun getString(index: Int) = if (cursor.isNull(index)) null else cursor.getString(index) - override fun getLong(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index) - override fun getBytes(index: Int) = if (cursor.isNull(index)) null else cursor.getBlob(index) - override fun getDouble(index: Int) = if (cursor.isNull(index)) null else cursor.getDouble(index) - override fun getBoolean(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index) == 1L - override fun columnName(index: Int): String? = cursor.getColumnName(index) - override val columnCount: Int = cursor.columnCount + + override fun next(): QueryResult.Value = QueryResult.Value(cursor.moveToNext()) + override fun getString(index: Int) = if (cursor.isNull(index)) null else cursor.getString(index) + override fun getLong(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index) + override fun getBytes(index: Int) = if (cursor.isNull(index)) null else cursor.getBlob(index) + override fun getDouble(index: Int) = if (cursor.isNull(index)) null else cursor.getDouble(index) + override fun getBoolean(index: Int) = if (cursor.isNull(index)) null else cursor.getLong(index) == 1L + override fun columnName(index: Int): String? = cursor.getColumnName(index) + override val columnCount: Int = cursor.columnCount } @RequiresApi(Build.VERSION_CODES.P) private object Api28Impl { - @JvmStatic - @DoNotInline - fun AbstractWindowedCursor.setWindowSize(windowSizeBytes: Long) { - window = CursorWindow(null, windowSizeBytes) - } + @JvmStatic + @DoNotInline + fun AbstractWindowedCursor.setWindowSize(windowSizeBytes: Long) { + window = CursorWindow(null, windowSizeBytes) + } } diff --git a/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt b/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt index 2add8ffa..a04903c5 100644 --- a/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt +++ b/persistence/src/commonMain/kotlin/com/powersync/persistence/driver/ColNamesSqlCursor.kt @@ -2,7 +2,7 @@ package com.powersync.persistence.driver import app.cash.sqldelight.db.SqlCursor -public interface ColNamesSqlCursor: SqlCursor { +public interface ColNamesSqlCursor : SqlCursor { public fun columnName(index: Int): String? public val columnCount: Int diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt index 72858220..7f39efb2 100644 --- a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Borrowed.kt @@ -1,6 +1,6 @@ package com.powersync.persistence.driver internal interface Borrowed { - val value: T - fun release() + val value: T + fun release() } diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt index 6c6797f7..b6f935cb 100644 --- a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/NativeSqlDatabase.kt @@ -21,59 +21,52 @@ import co.touchlab.stately.concurrency.value import com.powersync.persistence.driver.util.PoolLock public sealed class ConnectionWrapper : SqlDriver { - internal abstract fun accessConnection( - readOnly: Boolean, - block: ThreadConnection.() -> R, - ): R - - private fun accessStatement( - readOnly: Boolean, - identifier: Int?, - sql: String, - binders: (SqlPreparedStatement.() -> Unit)?, - block: (Statement) -> R, - ): R { - println("accessStatement: $sql") - return accessConnection(readOnly) { - println("accessStatement: 1") - val statement = useStatement(identifier, sql) - println("accessStatement: 2") - try { - println("accessStatement: 3") - if (binders != null) { - SqliterStatement(statement).binders() + internal abstract fun accessConnection( + readOnly: Boolean, + block: ThreadConnection.() -> R, + ): R + + private fun accessStatement( + readOnly: Boolean, + identifier: Int?, + sql: String, + binders: (SqlPreparedStatement.() -> Unit)?, + block: (Statement) -> R, + ): R { + return accessConnection(readOnly) { + val statement = useStatement(identifier, sql) + try { + if (binders != null) { + SqliterStatement(statement).binders() + } + block(statement) + } finally { + statement.resetStatement() + clearIfNeeded(identifier, statement) + } } - println("accessStatement: 4") - val blockResult = block(statement) - println("accessStatement: 5") - blockResult - } finally { - statement.resetStatement() - clearIfNeeded(identifier, statement) - } } - } - - final override fun execute( - identifier: Int?, - sql: String, - parameters: Int, - binders: (SqlPreparedStatement.() -> Unit)?, - ): QueryResult = QueryResult.Value( - accessStatement(false, identifier, sql, binders) { statement -> - statement.executeUpdateDelete().toLong() - }, - ) - - final override fun executeQuery( - identifier: Int?, - sql: String, - mapper: (SqlCursor) -> QueryResult, - parameters: Int, - binders: (SqlPreparedStatement.() -> Unit)?, - ): QueryResult = accessStatement(true, identifier, sql, binders) { statement -> - mapper(SqliterSqlCursor(statement.query())) - } + + final override fun execute( + identifier: Int?, + sql: String, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult = QueryResult.Value( + accessStatement(false, identifier, sql, binders) { statement -> + statement.executeUpdateDelete().toLong() + }, + ) + + final override fun executeQuery( + identifier: Int?, + sql: String, + mapper: (SqlCursor) -> QueryResult, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)?, + ): QueryResult = accessStatement(true, identifier, sql, binders) { statement -> + mapper(SqliterSqlCursor(statement.query())) + } } /** @@ -105,169 +98,162 @@ public sealed class ConnectionWrapper : SqlDriver { * when execute or executeQuery is called. This avoids race conditions with bind calls. */ public class NativeSqliteDriver( - private val databaseManager: DatabaseManager, - maxReaderConnections: Int = 1, -) : ConnectionWrapper(), - SqlDriver { - public constructor( - configuration: DatabaseConfiguration, - maxReaderConnections: Int = 1, - ) : this( - databaseManager = createDatabaseManager(configuration), - maxReaderConnections = maxReaderConnections, - ) - - /** - * @param onConfiguration Callback to hook into [DatabaseConfiguration] creation. - */ - public constructor( - schema: SqlSchema>, - name: String, + private val databaseManager: DatabaseManager, maxReaderConnections: Int = 1, - onConfiguration: (DatabaseConfiguration) -> DatabaseConfiguration = { it }, - vararg callbacks: AfterVersion, - ) : this( - configuration = DatabaseConfiguration( - name = name, - version = if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), - create = { connection -> wrapConnection(connection) { schema.create(it) } }, - upgrade = { connection, oldVersion, newVersion -> - wrapConnection(connection) { schema.migrate(it, oldVersion.toLong(), newVersion.toLong(), *callbacks) } - }, - ).let(onConfiguration), - maxReaderConnections = maxReaderConnections, - ) - - // A pool of reader connections used by all operations not in a transaction - private val transactionPool: Pool - internal val readerPool: Pool - - // Once a transaction is started and connection borrowed, it will be here, but only for that - // thread - private val borrowedConnectionThread = ThreadLocalRef>() - private val listeners = mutableMapOf>() - private val lock = PoolLock(reentrant = true) - - init { - if (databaseManager.configuration.isEphemeral) { - // Single connection for transactions - transactionPool = Pool(1) { - ThreadConnection(databaseManager.createMultiThreadedConnection()) { _ -> - borrowedConnectionThread.let { - it.get()?.release() - it.value = null - } - } - } - - readerPool = transactionPool - } else { - // Single connection for transactions - transactionPool = Pool(1) { - ThreadConnection(databaseManager.createMultiThreadedConnection()) { _ -> - borrowedConnectionThread.let { - it.get()?.release() - it.value = null - } - } - } - - readerPool = Pool(maxReaderConnections) { - val connection = databaseManager.createMultiThreadedConnection() - connection.withStatement("PRAGMA query_only = 1") { execute() } // Ensure read only - ThreadConnection(connection) { - throw UnsupportedOperationException("Should never be in a transaction") - } - } +) : ConnectionWrapper(), + SqlDriver { + public constructor( + configuration: DatabaseConfiguration, + maxReaderConnections: Int = 1, + ) : this( + databaseManager = createDatabaseManager(configuration), + maxReaderConnections = maxReaderConnections, + ) + + /** + * @param onConfiguration Callback to hook into [DatabaseConfiguration] creation. + */ + public constructor( + schema: SqlSchema>, + name: String, + maxReaderConnections: Int = 1, + onConfiguration: (DatabaseConfiguration) -> DatabaseConfiguration = { it }, + vararg callbacks: AfterVersion, + ) : this( + configuration = DatabaseConfiguration( + name = name, + version = if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), + create = { connection -> wrapConnection(connection) { schema.create(it) } }, + upgrade = { connection, oldVersion, newVersion -> + wrapConnection(connection) { schema.migrate(it, oldVersion.toLong(), newVersion.toLong(), *callbacks) } + }, + ).let(onConfiguration), + maxReaderConnections = maxReaderConnections, + ) + + // A pool of reader connections used by all operations not in a transaction + private val transactionPool: Pool + internal val readerPool: Pool + + // Once a transaction is started and connection borrowed, it will be here, but only for that + // thread + private val borrowedConnectionThread = ThreadLocalRef>() + private val listeners = mutableMapOf>() + private val lock = PoolLock(reentrant = true) + + init { + if (databaseManager.configuration.isEphemeral) { + // Single connection for transactions + transactionPool = Pool(1) { + ThreadConnection(databaseManager.createMultiThreadedConnection()) { _ -> + borrowedConnectionThread.let { + it.get()?.release() + it.value = null + } + } + } + + readerPool = transactionPool + } else { + // Single connection for transactions + transactionPool = Pool(1) { + ThreadConnection(databaseManager.createMultiThreadedConnection()) { _ -> + borrowedConnectionThread.let { + it.get()?.release() + it.value = null + } + } + } + + readerPool = Pool(maxReaderConnections) { + val connection = databaseManager.createMultiThreadedConnection() + connection.withStatement("PRAGMA query_only = 1") { execute() } // Ensure read only + ThreadConnection(connection) { + throw UnsupportedOperationException("Should never be in a transaction") + } + } + } + } + + override fun addListener(vararg queryKeys: String, listener: Query.Listener) { + lock.withLock { + queryKeys.forEach { + listeners.getOrPut(it) { mutableSetOf() }.add(listener) + } + } } - } - override fun addListener(vararg queryKeys: String, listener: Query.Listener) { - lock.withLock { - queryKeys.forEach { - listeners.getOrPut(it) { mutableSetOf() }.add(listener) - } + override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { + lock.withLock { + queryKeys.forEach { + listeners.get(it)?.remove(listener) + } + } } - } - override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { - lock.withLock { - queryKeys.forEach { - listeners.get(it)?.remove(listener) - } + override fun notifyListeners(vararg queryKeys: String) { + val listenersToNotify = mutableSetOf() + lock.withLock { + queryKeys.forEach { key -> listeners.get(key)?.let { listenersToNotify.addAll(it) } } + } + listenersToNotify.forEach(Query.Listener::queryResultsChanged) } - } - override fun notifyListeners(vararg queryKeys: String) { - val listenersToNotify = mutableSetOf() - lock.withLock { - queryKeys.forEach { key -> listeners.get(key)?.let { listenersToNotify.addAll(it) } } + override fun currentTransaction(): Transacter.Transaction? { + return borrowedConnectionThread.get()?.value?.transaction?.value } - listenersToNotify.forEach(Query.Listener::queryResultsChanged) - } - - override fun currentTransaction(): Transacter.Transaction? { - println("currentTransaction() thread id: ${currentThreadId()}") - return borrowedConnectionThread.get()?.value?.transaction?.value - } - - override fun newTransaction(): QueryResult { - println("newTransaction() thread id: ${currentThreadId()}") - val alreadyBorrowed = borrowedConnectionThread.get() - val transaction = if (alreadyBorrowed == null) { - val borrowed = transactionPool.borrowEntry() - - try { - val trans = borrowed.value.newTransaction() - - borrowedConnectionThread.value = borrowed - trans - } catch (e: Throwable) { - // Unlock on failure. - borrowed.release() - throw e - } - } else { - alreadyBorrowed.value.newTransaction() + + override fun newTransaction(): QueryResult { + val alreadyBorrowed = borrowedConnectionThread.get() + val transaction = if (alreadyBorrowed == null) { + val borrowed = transactionPool.borrowEntry() + + try { + val trans = borrowed.value.newTransaction() + + borrowedConnectionThread.value = borrowed + trans + } catch (e: Throwable) { + // Unlock on failure. + borrowed.release() + throw e + } + } else { + alreadyBorrowed.value.newTransaction() + } + + return QueryResult.Value(transaction) } - return QueryResult.Value(transaction) - } - - /** - * If we're in a transaction, then I have a connection. Otherwise use shared. - */ - override fun accessConnection( - readOnly: Boolean, - block: ThreadConnection.() -> R, - ): R { - println("accessConnection() thread id: ${currentThreadId()}") - val mine = borrowedConnectionThread.get() - println("accessConnection() with connection $mine. Thread id: ${currentThreadId()}") - return if (readOnly) { - // Code intends to read, which doesn't need to block - if (mine != null) { - mine.value.block() - } else { - println("accessConnection() before readerPool") - val conn = readerPool.access(block) - println("accessConnection() after readerPool") - conn - } - } else { - // Code intends to write, for which we're managing locks in code - if (mine != null) { - mine.value.block() - } else { - transactionPool.access(block) - } + /** + * If we're in a transaction, then I have a connection. Otherwise use shared. + */ + override fun accessConnection( + readOnly: Boolean, + block: ThreadConnection.() -> R, + ): R { + val mine = borrowedConnectionThread.get() + return if (readOnly) { + // Code intends to read, which doesn't need to block + if (mine != null) { + mine.value.block() + } else { + readerPool.access(block) + } + } else { + // Code intends to write, for which we're managing locks in code + if (mine != null) { + mine.value.block() + } else { + transactionPool.access(block) + } + } } - } - override fun close() { - transactionPool.close() - readerPool.close() - } + override fun close() { + transactionPool.close() + readerPool.close() + } } /** @@ -275,17 +261,17 @@ public class NativeSqliteDriver( * concurrent access will be block */ public fun inMemoryDriver(schema: SqlSchema>): NativeSqliteDriver = NativeSqliteDriver( - DatabaseConfiguration( - name = null, - inMemory = true, - version = if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), - create = { connection -> - wrapConnection(connection) { schema.create(it) } - }, - upgrade = { connection, oldVersion, newVersion -> - wrapConnection(connection) { schema.migrate(it, oldVersion.toLong(), newVersion.toLong()) } - }, - ), + DatabaseConfiguration( + name = null, + inMemory = true, + version = if (schema.version > Int.MAX_VALUE) error("Schema version is larger than Int.MAX_VALUE: ${schema.version}.") else schema.version.toInt(), + create = { connection -> + wrapConnection(connection) { schema.create(it) } + }, + upgrade = { connection, oldVersion, newVersion -> + wrapConnection(connection) { schema.migrate(it, oldVersion.toLong(), newVersion.toLong()) } + }, + ), ) /** @@ -297,15 +283,15 @@ public fun inMemoryDriver(schema: SqlSchema>): NativeSql * a wrap call, it will no longer be viable. */ public fun wrapConnection( - connection: DatabaseConnection, - block: (SqlDriver) -> Unit, + connection: DatabaseConnection, + block: (SqlDriver) -> Unit, ) { - val conn = SqliterWrappedConnection(ThreadConnection(connection) {}) - try { - block(conn) - } finally { - conn.close() - } + val conn = SqliterWrappedConnection(ThreadConnection(connection) {}) + try { + block(conn) + } finally { + conn.close() + } } /** @@ -313,33 +299,34 @@ public fun wrapConnection( * don't want the polling. */ internal class SqliterWrappedConnection( - private val threadConnection: ThreadConnection, + private val threadConnection: ThreadConnection, ) : ConnectionWrapper(), - SqlDriver { - override fun currentTransaction(): Transacter.Transaction? = threadConnection.transaction.value + SqlDriver { + override fun currentTransaction(): Transacter.Transaction? = threadConnection.transaction.value - override fun newTransaction(): QueryResult = QueryResult.Value(threadConnection.newTransaction()) + override fun newTransaction(): QueryResult = + QueryResult.Value(threadConnection.newTransaction()) - override fun accessConnection( - readOnly: Boolean, - block: ThreadConnection.() -> R, - ): R = threadConnection.block() + override fun accessConnection( + readOnly: Boolean, + block: ThreadConnection.() -> R, + ): R = threadConnection.block() - override fun addListener(vararg queryKeys: String, listener: Query.Listener) { - // No-op - } + override fun addListener(vararg queryKeys: String, listener: Query.Listener) { + // No-op + } - override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { - // No-op - } + override fun removeListener(vararg queryKeys: String, listener: Query.Listener) { + // No-op + } - override fun notifyListeners(vararg queryKeys: String) { - // No-op - } + override fun notifyListeners(vararg queryKeys: String) { + // No-op + } - override fun close() { - threadConnection.cleanUp() - } + override fun close() { + threadConnection.cleanUp() + } } /** @@ -350,84 +337,85 @@ internal class SqliterWrappedConnection( * properly in cases where the user does not. */ internal class ThreadConnection( - private val connection: DatabaseConnection, - private val onEndTransaction: (ThreadConnection) -> Unit, + private val connection: DatabaseConnection, + private val onEndTransaction: (ThreadConnection) -> Unit, ) : Closeable { - internal val transaction = ThreadLocalRef() - private val closed: Boolean - get() = connection.closed - - private val statementCache = mutableMapOf() - - fun useStatement(identifier: Int?, sql: String): Statement { - return if (identifier != null) { - statementCache.getOrPut(identifier) { - connection.createStatement(sql) - } - } else { - connection.createStatement(sql) + internal val transaction = ThreadLocalRef() + private val closed: Boolean + get() = connection.closed + + private val statementCache = mutableMapOf() + + fun useStatement(identifier: Int?, sql: String): Statement { + return if (identifier != null) { + statementCache.getOrPut(identifier) { + connection.createStatement(sql) + } + } else { + connection.createStatement(sql) + } } - } - fun clearIfNeeded(identifier: Int?, statement: Statement) { - if (identifier == null || closed) { - statement.finalizeStatement() + fun clearIfNeeded(identifier: Int?, statement: Statement) { + if (identifier == null || closed) { + statement.finalizeStatement() + } } - } - fun newTransaction(): Transacter.Transaction { - val enclosing = transaction.value + fun newTransaction(): Transacter.Transaction { + val enclosing = transaction.value - // Create here, in case we bomb... - if (enclosing == null) { - connection.beginTransaction() - } + // Create here, in case we bomb... + if (enclosing == null) { + connection.beginTransaction() + } + + val trans = Transaction(enclosing) + transaction.value = trans - val trans = Transaction(enclosing) - transaction.value = trans + return trans + } - return trans - } + /** + * This should only be called directly from wrapConnection. Clean resources without actually closing + * the underlying connection. + */ + internal fun cleanUp() { + statementCache.values.forEach { it: Statement -> + it.finalizeStatement() + } + } - /** - * This should only be called directly from wrapConnection. Clean resources without actually closing - * the underlying connection. - */ - internal fun cleanUp() { - statementCache.values.forEach { it: Statement -> - it.finalizeStatement() + override fun close() { + cleanUp() + connection.close() } - } - - override fun close() { - cleanUp() - connection.close() - } - - private inner class Transaction( - override val enclosingTransaction: Transacter.Transaction?, - ) : Transacter.Transaction() { - - override fun endTransaction(successful: Boolean): QueryResult { - transaction.value = enclosingTransaction - - if (enclosingTransaction == null) { - try { - if (successful) { - connection.setTransactionSuccessful() - } - - connection.endTransaction() - } finally { - // Release if we have - onEndTransaction(this@ThreadConnection) + + private inner class Transaction( + override val enclosingTransaction: Transacter.Transaction?, + ) : Transacter.Transaction() { + + override fun endTransaction(successful: Boolean): QueryResult { + transaction.value = enclosingTransaction + + if (enclosingTransaction == null) { + try { + if (successful) { + connection.setTransactionSuccessful() + } + + connection.endTransaction() + } finally { + // Release if we have + onEndTransaction(this@ThreadConnection) + } + } + return QueryResult.Unit } - } - return QueryResult.Unit } - } } -private inline val DatabaseConfiguration.isEphemeral: Boolean get() { - return inMemory || (name?.isEmpty() == true && extendedConfig.basePath?.isEmpty() == true) -} +private inline val DatabaseConfiguration.isEphemeral: Boolean + get() { + return inMemory || (name?.isEmpty() == true && extendedConfig.basePath?.isEmpty() == true) + } diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt index d94b2d5e..78c1f785 100644 --- a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/Pool.kt @@ -10,114 +10,111 @@ import kotlin.concurrent.AtomicReference * designated capacity. */ internal class Pool(internal val capacity: Int, private val producer: () -> T) { - /** - * Hold a list of active connections. If it is null, it means the MultiPool has been closed. - */ - private val entriesRef = AtomicReference?>(listOf()) - private val poolLock = PoolLock() - - /** - * For test purposes only - */ - internal fun entryCount(): Int = poolLock.withLock { - entriesRef.value?.size ?: 0 - } - - fun borrowEntry(): Borrowed { - val snapshot = entriesRef.value ?: throw ClosedMultiPoolException - - // Fastpath: Borrow the first available entry. - val firstAvailable = snapshot.firstOrNull { it.tryToAcquire() } - - if (firstAvailable != null) { - return firstAvailable.asBorrowed(poolLock) + /** + * Hold a list of active connections. If it is null, it means the MultiPool has been closed. + */ + private val entriesRef = AtomicReference?>(listOf()) + private val poolLock = PoolLock() + + /** + * For test purposes only + */ + internal fun entryCount(): Int = poolLock.withLock { + entriesRef.value?.size ?: 0 } - // Slowpath: Create a new entry if capacity limit has not been reached, or wait for the next available entry. - val nextAvailable = poolLock.withLock { - // Reload the list since it could've been updated by other threads concurrently. - val entries = entriesRef.value ?: throw ClosedMultiPoolException + fun borrowEntry(): Borrowed { + val snapshot = entriesRef.value ?: throw ClosedMultiPoolException - if (entries.count() < capacity) { - // Capacity hasn't been reached — create a new entry to serve this call. - val newEntry = Entry(producer()) - val done = newEntry.tryToAcquire() - check(done) + // Fastpath: Borrow the first available entry. + val firstAvailable = snapshot.firstOrNull { it.tryToAcquire() } - entriesRef.value = (entries + listOf(newEntry)) - return@withLock newEntry - } else { - // Capacity is reached — wait for the next available entry. - return@withLock loopForConditionalResult { - // Reload the list, since the thread can be suspended here while the list of entries has been modified. - val innerEntries = entriesRef.value ?: throw ClosedMultiPoolException - innerEntries.firstOrNull { it.tryToAcquire() } + if (firstAvailable != null) { + return firstAvailable.asBorrowed(poolLock) } - } - } - return nextAvailable.asBorrowed(poolLock) - } - - fun access(action: (T) -> R): R { - val borrowed = borrowEntry() - return try { - println("before access, capacity: $capacity") - val result = action(borrowed.value) - println("after access") - result - } finally { - borrowed.release() - } - } + // Slowpath: Create a new entry if capacity limit has not been reached, or wait for the next available entry. + val nextAvailable = poolLock.withLock { + // Reload the list since it could've been updated by other threads concurrently. + val entries = entriesRef.value ?: throw ClosedMultiPoolException + + if (entries.count() < capacity) { + // Capacity hasn't been reached — create a new entry to serve this call. + val newEntry = Entry(producer()) + val done = newEntry.tryToAcquire() + check(done) + + entriesRef.value = (entries + listOf(newEntry)) + return@withLock newEntry + } else { + // Capacity is reached — wait for the next available entry. + return@withLock loopForConditionalResult { + // Reload the list, since the thread can be suspended here while the list of entries has been modified. + val innerEntries = entriesRef.value ?: throw ClosedMultiPoolException + innerEntries.firstOrNull { it.tryToAcquire() } + } + } + } - fun close() { - if (!poolLock.close()) { - return + return nextAvailable.asBorrowed(poolLock) } - val entries = entriesRef.value - val done = entriesRef.compareAndSet(entries, null) - check(done) - - entries?.forEach { it.value.close() } - } - - inner class Entry(val value: T) { - val isAvailable = AtomicBoolean(true) - - fun tryToAcquire(): Boolean = isAvailable.compareAndSet(expected = true, new = false) - - fun asBorrowed(poolLock: PoolLock): Borrowed = object : Borrowed { - override val value: T - get() = this@Entry.value + fun access(action: (T) -> R): R { + val borrowed = borrowEntry() + return try { + action(borrowed.value) + } finally { + borrowed.release() + } + } - override fun release() { - /** - * Mark-as-available should be done before signalling blocked threads via [PoolLock.notifyConditionChanged], - * since the happens-before relationship guarantees the woken thread to see the - * available entry (if not having been taken by other threads during the wake-up lead time). - */ + fun close() { + if (!poolLock.close()) { + return + } - val done = isAvailable.compareAndSet(expected = false, new = true) + val entries = entriesRef.value + val done = entriesRef.compareAndSet(entries, null) check(done) - // While signalling blocked threads does not require locking, doing so avoids a subtle race - // condition in which: - // - // 1. a [loopForConditionalResult] iteration in [borrowEntry] slow path is happening concurrently; - // 2. the iteration fails to see the atomic `isAvailable = true` above; - // 3. we signal availability here but it is a no-op due to no waiting blocker; and finally - // 4. the iteration entered an indefinite blocking wait, not being aware of us having signalled availability here. - // - // By acquiring the pool lock first, signalling cannot happen concurrently with the loop - // iterations in [borrowEntry], thus eliminating the race condition. - poolLock.withLock { - poolLock.notifyConditionChanged() + entries?.forEach { it.value.close() } + } + + inner class Entry(val value: T) { + val isAvailable = AtomicBoolean(true) + + fun tryToAcquire(): Boolean = isAvailable.compareAndSet(expected = true, new = false) + + fun asBorrowed(poolLock: PoolLock): Borrowed = object : Borrowed { + override val value: T + get() = this@Entry.value + + override fun release() { + /** + * Mark-as-available should be done before signalling blocked threads via [PoolLock.notifyConditionChanged], + * since the happens-before relationship guarantees the woken thread to see the + * available entry (if not having been taken by other threads during the wake-up lead time). + */ + + val done = isAvailable.compareAndSet(expected = false, new = true) + check(done) + + // While signalling blocked threads does not require locking, doing so avoids a subtle race + // condition in which: + // + // 1. a [loopForConditionalResult] iteration in [borrowEntry] slow path is happening concurrently; + // 2. the iteration fails to see the atomic `isAvailable = true` above; + // 3. we signal availability here but it is a no-op due to no waiting blocker; and finally + // 4. the iteration entered an indefinite blocking wait, not being aware of us having signalled availability here. + // + // By acquiring the pool lock first, signalling cannot happen concurrently with the loop + // iterations in [borrowEntry], thus eliminating the race condition. + poolLock.withLock { + poolLock.notifyConditionChanged() + } + } } - } } - } } private val ClosedMultiPoolException get() = IllegalStateException("Attempt to access a closed MultiPool.") diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt index b1c36090..727913ce 100644 --- a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterSqlCursor.kt @@ -1,7 +1,6 @@ package com.powersync.persistence.driver import app.cash.sqldelight.db.QueryResult -import app.cash.sqldelight.db.SqlCursor import co.touchlab.sqliter.Cursor import co.touchlab.sqliter.getBytesOrNull import co.touchlab.sqliter.getDoubleOrNull @@ -14,21 +13,21 @@ import co.touchlab.sqliter.getStringOrNull * throwing errors if you're trying to access it. */ internal class SqliterSqlCursor(private val cursor: Cursor) : ColNamesSqlCursor { - override fun getBytes(index: Int): ByteArray? = cursor.getBytesOrNull(index) + override fun getBytes(index: Int): ByteArray? = cursor.getBytesOrNull(index) - override fun getDouble(index: Int): Double? = cursor.getDoubleOrNull(index) + override fun getDouble(index: Int): Double? = cursor.getDoubleOrNull(index) - override fun getLong(index: Int): Long? = cursor.getLongOrNull(index) + override fun getLong(index: Int): Long? = cursor.getLongOrNull(index) - override fun getString(index: Int): String? = cursor.getStringOrNull(index) + override fun getString(index: Int): String? = cursor.getStringOrNull(index) - override fun getBoolean(index: Int): Boolean? { - return (cursor.getLongOrNull(index) ?: return null) == 1L - } + override fun getBoolean(index: Int): Boolean? { + return (cursor.getLongOrNull(index) ?: return null) == 1L + } - override fun columnName(index: Int): String? = cursor.columnName(index) + override fun columnName(index: Int): String? = cursor.columnName(index) - override val columnCount: Int = cursor.columnCount + override val columnCount: Int = cursor.columnCount - override fun next(): QueryResult.Value = QueryResult.Value(cursor.next()) + override fun next(): QueryResult.Value = QueryResult.Value(cursor.next()) } diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt index 44f8a5c1..f78f2ef1 100644 --- a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/SqliterStatement.kt @@ -11,32 +11,32 @@ import co.touchlab.sqliter.bindString * @param [recycle] A function which recycles any resources this statement is backed by. */ internal class SqliterStatement( - private val statement: Statement, + private val statement: Statement, ) : SqlPreparedStatement { - override fun bindBytes(index: Int, bytes: ByteArray?) { - statement.bindBlob(index + 1, bytes) - } + override fun bindBytes(index: Int, bytes: ByteArray?) { + statement.bindBlob(index + 1, bytes) + } - override fun bindLong(index: Int, long: Long?) { - statement.bindLong(index + 1, long) - } + override fun bindLong(index: Int, long: Long?) { + statement.bindLong(index + 1, long) + } - override fun bindDouble(index: Int, double: Double?) { - statement.bindDouble(index + 1, double) - } + override fun bindDouble(index: Int, double: Double?) { + statement.bindDouble(index + 1, double) + } - override fun bindString(index: Int, string: String?) { - statement.bindString(index + 1, string) - } + override fun bindString(index: Int, string: String?) { + statement.bindString(index + 1, string) + } - override fun bindBoolean(index: Int, boolean: Boolean?) { - statement.bindLong( - index + 1, - when (boolean) { - null -> null - true -> 1L - false -> 0L - }, - ) - } + override fun bindBoolean(index: Int, boolean: Boolean?) { + statement.bindLong( + index + 1, + when (boolean) { + null -> null + true -> 1L + false -> 0L + }, + ) + } } diff --git a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt index a68cc553..152210f8 100644 --- a/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt +++ b/persistence/src/iosMain/kotlin/com/powersync/persistence/driver/util/PoolLock.kt @@ -23,67 +23,67 @@ import platform.posix.pthread_mutexattr_t @OptIn(ExperimentalForeignApi::class) internal class PoolLock constructor(reentrant: Boolean = false) { - private val isActive = AtomicBoolean(true) - - private val attr = nativeHeap.alloc() - .apply { - pthread_mutexattr_init(ptr) - if (reentrant) { - pthread_mutexattr_settype(ptr, platform.posix.PTHREAD_MUTEX_RECURSIVE) - } - } - private val mutex = nativeHeap.alloc() - .apply { pthread_mutex_init(ptr, attr.ptr) } - private val cond = nativeHeap.alloc() - .apply { pthread_cond_init(ptr, null) } + private val isActive = AtomicBoolean(true) - fun withLock( - action: CriticalSection.() -> R, - ): R { - check(isActive.value) - pthread_mutex_lock(mutex.ptr) + private val attr = nativeHeap.alloc() + .apply { + pthread_mutexattr_init(ptr) + if (reentrant) { + pthread_mutexattr_settype(ptr, platform.posix.PTHREAD_MUTEX_RECURSIVE) + } + } + private val mutex = nativeHeap.alloc() + .apply { pthread_mutex_init(ptr, attr.ptr) } + private val cond = nativeHeap.alloc() + .apply { pthread_cond_init(ptr, null) } - val result: R + fun withLock( + action: CriticalSection.() -> R, + ): R { + check(isActive.value) + pthread_mutex_lock(mutex.ptr) - try { - result = action(CriticalSection()) - } finally { - pthread_mutex_unlock(mutex.ptr) - } + val result: R - return result - } + try { + result = action(CriticalSection()) + } finally { + pthread_mutex_unlock(mutex.ptr) + } - fun notifyConditionChanged() { - pthread_cond_signal(cond.ptr) - } + return result + } - fun close(): Boolean { - if (isActive.compareAndSet(expected = true, new = false)) { - pthread_cond_destroy(cond.ptr) - pthread_mutex_destroy(mutex.ptr) - pthread_mutexattr_destroy(attr.ptr) - nativeHeap.free(cond) - nativeHeap.free(mutex) - nativeHeap.free(attr) - return true + fun notifyConditionChanged() { + pthread_cond_signal(cond.ptr) } - return false - } + fun close(): Boolean { + if (isActive.compareAndSet(expected = true, new = false)) { + pthread_cond_destroy(cond.ptr) + pthread_mutex_destroy(mutex.ptr) + pthread_mutexattr_destroy(attr.ptr) + nativeHeap.free(cond) + nativeHeap.free(mutex) + nativeHeap.free(attr) + return true + } + + return false + } - inner class CriticalSection { - fun loopForConditionalResult(block: () -> R?): R { - check(isActive.value) + inner class CriticalSection { + fun loopForConditionalResult(block: () -> R?): R { + check(isActive.value) - var result = block() + var result = block() - while (result == null) { - pthread_cond_wait(cond.ptr, mutex.ptr) - result = block() - } + while (result == null) { + pthread_cond_wait(cond.ptr, mutex.ptr) + result = block() + } - return result + return result + } } - } } \ No newline at end of file diff --git a/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt b/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt index 014e480a..59ac7d44 100644 --- a/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt +++ b/persistence/src/jvmMain/kotlin/com/powersync/persistence/driver/JdbcPreparedStatement.kt @@ -9,125 +9,125 @@ import java.sql.ResultSet import java.sql.Types /** -* Binds the parameter to [preparedStatement] by calling [bindString], [bindLong] or similar. -* After binding, [execute] executes the query without a result, while [executeQuery] returns [JdbcCursor]. -*/ + * Binds the parameter to [preparedStatement] by calling [bindString], [bindLong] or similar. + * After binding, [execute] executes the query without a result, while [executeQuery] returns [JdbcCursor]. + */ public class JdbcPreparedStatement( - private val preparedStatement: PreparedStatement, + private val preparedStatement: PreparedStatement, ) : SqlPreparedStatement { - override fun bindBytes(index: Int, bytes: ByteArray?) { - preparedStatement.setBytes(index + 1, bytes) - } - - override fun bindBoolean(index: Int, boolean: Boolean?) { - if (boolean == null) { - preparedStatement.setNull(index + 1, Types.BOOLEAN) - } else { - preparedStatement.setBoolean(index + 1, boolean) - } - } - - public fun bindByte(index: Int, byte: Byte?) { - if (byte == null) { - preparedStatement.setNull(index + 1, Types.TINYINT) - } else { - preparedStatement.setByte(index + 1, byte) - } - } - - public fun bindShort(index: Int, short: Short?) { - if (short == null) { - preparedStatement.setNull(index + 1, Types.SMALLINT) - } else { - preparedStatement.setShort(index + 1, short) - } - } - - public fun bindInt(index: Int, int: Int?) { - if (int == null) { - preparedStatement.setNull(index + 1, Types.INTEGER) - } else { - preparedStatement.setInt(index + 1, int) - } - } - - override fun bindLong(index: Int, long: Long?) { - if (long == null) { - preparedStatement.setNull(index + 1, Types.BIGINT) - } else { - preparedStatement.setLong(index + 1, long) - } - } - - public fun bindFloat(index: Int, float: Float?) { - if (float == null) { - preparedStatement.setNull(index + 1, Types.REAL) - } else { - preparedStatement.setFloat(index + 1, float) - } - } - - override fun bindDouble(index: Int, double: Double?) { - if (double == null) { - preparedStatement.setNull(index + 1, Types.DOUBLE) - } else { - preparedStatement.setDouble(index + 1, double) - } - } - - public fun bindBigDecimal(index: Int, decimal: BigDecimal?) { - preparedStatement.setBigDecimal(index + 1, decimal) - } - - public fun bindObject(index: Int, obj: Any?) { - if (obj == null) { - preparedStatement.setNull(index + 1, Types.OTHER) - } else { - preparedStatement.setObject(index + 1, obj) - } - } - - public fun bindObject(index: Int, obj: Any?, type: Int) { - if (obj == null) { - preparedStatement.setNull(index + 1, type) - } else { - preparedStatement.setObject(index + 1, obj, type) - } - } - - override fun bindString(index: Int, string: String?) { - preparedStatement.setString(index + 1, string) - } - - public fun bindDate(index: Int, date: java.sql.Date?) { - preparedStatement.setDate(index, date) - } - - public fun bindTime(index: Int, date: java.sql.Time?) { - preparedStatement.setTime(index, date) - } - - public fun bindTimestamp(index: Int, timestamp: java.sql.Timestamp?) { - preparedStatement.setTimestamp(index, timestamp) - } - - public fun executeQuery(mapper: (SqlCursor) -> R): R { - try { - return preparedStatement.executeQuery() - .use { resultSet -> mapper(JdbcCursor(resultSet)) } - } finally { - preparedStatement.close() + override fun bindBytes(index: Int, bytes: ByteArray?) { + preparedStatement.setBytes(index + 1, bytes) + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + if (boolean == null) { + preparedStatement.setNull(index + 1, Types.BOOLEAN) + } else { + preparedStatement.setBoolean(index + 1, boolean) + } + } + + public fun bindByte(index: Int, byte: Byte?) { + if (byte == null) { + preparedStatement.setNull(index + 1, Types.TINYINT) + } else { + preparedStatement.setByte(index + 1, byte) + } + } + + public fun bindShort(index: Int, short: Short?) { + if (short == null) { + preparedStatement.setNull(index + 1, Types.SMALLINT) + } else { + preparedStatement.setShort(index + 1, short) + } + } + + public fun bindInt(index: Int, int: Int?) { + if (int == null) { + preparedStatement.setNull(index + 1, Types.INTEGER) + } else { + preparedStatement.setInt(index + 1, int) + } + } + + override fun bindLong(index: Int, long: Long?) { + if (long == null) { + preparedStatement.setNull(index + 1, Types.BIGINT) + } else { + preparedStatement.setLong(index + 1, long) + } + } + + public fun bindFloat(index: Int, float: Float?) { + if (float == null) { + preparedStatement.setNull(index + 1, Types.REAL) + } else { + preparedStatement.setFloat(index + 1, float) + } + } + + override fun bindDouble(index: Int, double: Double?) { + if (double == null) { + preparedStatement.setNull(index + 1, Types.DOUBLE) + } else { + preparedStatement.setDouble(index + 1, double) + } + } + + public fun bindBigDecimal(index: Int, decimal: BigDecimal?) { + preparedStatement.setBigDecimal(index + 1, decimal) + } + + public fun bindObject(index: Int, obj: Any?) { + if (obj == null) { + preparedStatement.setNull(index + 1, Types.OTHER) + } else { + preparedStatement.setObject(index + 1, obj) + } + } + + public fun bindObject(index: Int, obj: Any?, type: Int) { + if (obj == null) { + preparedStatement.setNull(index + 1, type) + } else { + preparedStatement.setObject(index + 1, obj, type) + } + } + + override fun bindString(index: Int, string: String?) { + preparedStatement.setString(index + 1, string) + } + + public fun bindDate(index: Int, date: java.sql.Date?) { + preparedStatement.setDate(index, date) + } + + public fun bindTime(index: Int, date: java.sql.Time?) { + preparedStatement.setTime(index, date) + } + + public fun bindTimestamp(index: Int, timestamp: java.sql.Timestamp?) { + preparedStatement.setTimestamp(index, timestamp) + } + + public fun executeQuery(mapper: (SqlCursor) -> R): R { + try { + return preparedStatement.executeQuery() + .use { resultSet -> mapper(JdbcCursor(resultSet)) } + } finally { + preparedStatement.close() + } } - } - - public fun execute(): Long { - return if (preparedStatement.execute()) { - // returned true so this is a result set return type. - 0L - } else { - preparedStatement.updateCount.toLong() + + public fun execute(): Long { + return if (preparedStatement.execute()) { + // returned true so this is a result set return type. + 0L + } else { + preparedStatement.updateCount.toLong() + } } - } } /** @@ -135,29 +135,29 @@ public class JdbcPreparedStatement( * Use [next] to retrieve the next row and [close] to close the connection. */ internal class JdbcCursor(val resultSet: ResultSet) : ColNamesSqlCursor { - override fun getString(index: Int): String? = resultSet.getString(index + 1) - override fun getBytes(index: Int): ByteArray? = resultSet.getBytes(index + 1) - override fun getBoolean(index: Int): Boolean? = getAtIndex(index, resultSet::getBoolean) - override fun columnName(index: Int): String? = resultSet.metaData.getColumnName(index) - override val columnCount: Int = resultSet.metaData.columnCount - - fun getByte(index: Int): Byte? = getAtIndex(index, resultSet::getByte) - fun getShort(index: Int): Short? = getAtIndex(index, resultSet::getShort) - fun getInt(index: Int): Int? = getAtIndex(index, resultSet::getInt) - override fun getLong(index: Int): Long? = getAtIndex(index, resultSet::getLong) - fun getFloat(index: Int): Float? = getAtIndex(index, resultSet::getFloat) - override fun getDouble(index: Int): Double? = getAtIndex(index, resultSet::getDouble) - fun getBigDecimal(index: Int): BigDecimal? = resultSet.getBigDecimal(index + 1) - inline fun getObject(index: Int): T? = resultSet.getObject(index + 1, T::class.java) - fun getDate(index: Int): java.sql.Date? = resultSet.getDate(index) - fun getTime(index: Int): java.sql.Time? = resultSet.getTime(index) - fun getTimestamp(index: Int): java.sql.Timestamp? = resultSet.getTimestamp(index) - - @Suppress("UNCHECKED_CAST") - fun getArray(index: Int) = getAtIndex(index, resultSet::getArray)?.array as Array? - - private fun getAtIndex(index: Int, converter: (Int) -> T): T? = - converter(index + 1).takeUnless { resultSet.wasNull() } - - override fun next(): QueryResult.Value = QueryResult.Value(resultSet.next()) + override fun getString(index: Int): String? = resultSet.getString(index + 1) + override fun getBytes(index: Int): ByteArray? = resultSet.getBytes(index + 1) + override fun getBoolean(index: Int): Boolean? = getAtIndex(index, resultSet::getBoolean) + override fun columnName(index: Int): String? = resultSet.metaData.getColumnName(index) + override val columnCount: Int = resultSet.metaData.columnCount + + fun getByte(index: Int): Byte? = getAtIndex(index, resultSet::getByte) + fun getShort(index: Int): Short? = getAtIndex(index, resultSet::getShort) + fun getInt(index: Int): Int? = getAtIndex(index, resultSet::getInt) + override fun getLong(index: Int): Long? = getAtIndex(index, resultSet::getLong) + fun getFloat(index: Int): Float? = getAtIndex(index, resultSet::getFloat) + override fun getDouble(index: Int): Double? = getAtIndex(index, resultSet::getDouble) + fun getBigDecimal(index: Int): BigDecimal? = resultSet.getBigDecimal(index + 1) + inline fun getObject(index: Int): T? = resultSet.getObject(index + 1, T::class.java) + fun getDate(index: Int): java.sql.Date? = resultSet.getDate(index) + fun getTime(index: Int): java.sql.Time? = resultSet.getTime(index) + fun getTimestamp(index: Int): java.sql.Timestamp? = resultSet.getTimestamp(index) + + @Suppress("UNCHECKED_CAST") + fun getArray(index: Int) = getAtIndex(index, resultSet::getArray)?.array as Array? + + private fun getAtIndex(index: Int, converter: (Int) -> T): T? = + converter(index + 1).takeUnless { resultSet.wasNull() } + + override fun next(): QueryResult.Value = QueryResult.Value(resultSet.next()) }