diff --git a/Sources/PureSQL/Connection.swift b/Sources/PureSQL/Connection.swift index cace0a9..7bb5e26 100644 --- a/Sources/PureSQL/Connection.swift +++ b/Sources/PureSQL/Connection.swift @@ -16,10 +16,22 @@ public protocol Connection: Sendable { /// Cancels the observation for the given subscriber func cancel(subscriber: DatabaseSubscriber) + /// Begins a transaction and passes it to the `execute` function. + /// If no error is thrown the changes are automatically commited. + /// If an error is thrown the changes are rolled back. func begin( _ kind: Transaction.Kind, execute: @Sendable (borrowing Transaction) throws -> Output ) async throws -> Output + + /// Gets a raw connection to the database and allows for direct + /// SQL access. No transaction is automatically started. + /// + /// This is likely not the API you want, and should just use `begin`. + func withConnection( + isWrite: Bool, + execute: @Sendable (borrowing RawConnection) throws -> Output + ) async throws -> Output } /// A no operation database connection that does nothing. @@ -36,6 +48,13 @@ public struct NoopConnection: Connection { ) async throws -> Output { try execute(Transaction(connection: NoopRawConnection(), kind: kind)) } + + public func withConnection( + isWrite: Bool, + execute: @Sendable (borrowing RawConnection) throws -> Output + ) async throws -> Output { + try execute(NoopRawConnection()) + } } /// A type that has a database connection. @@ -62,4 +81,11 @@ public extension ConnectionWrapper { ) async throws -> Output { try await connection.begin(kind, execute: execute) } + + func withConnection( + isWrite: Bool, + execute: @Sendable (borrowing RawConnection) throws -> Output + ) async throws -> Output { + try await connection.withConnection(isWrite: isWrite, execute: execute) + } } diff --git a/Sources/PureSQL/ConnectionPool.swift b/Sources/PureSQL/ConnectionPool.swift index 909d47d..b5eaf5f 100644 --- a/Sources/PureSQL/ConnectionPool.swift +++ b/Sources/PureSQL/ConnectionPool.swift @@ -36,7 +36,8 @@ public actor ConnectionPool: Sendable { public init( path: String, limit: Int, - migrations: [String] + migrations: [String], + runMigrations: Bool = true ) throws { guard limit > 0 else { throw SQLError.poolCannotHaveZeroConnections @@ -50,7 +51,11 @@ public actor ConnectionPool: Sendable { // Turn on WAL mode try connection.execute(sql: "PRAGMA journal_mode=WAL;") - try MigrationRunner.execute(migrations: migrations, connection: connection) + + if runMigrations { + try MigrationRunner.execute(migrations: migrations, connection: connection) + } + self.availableConnections = [connection] } @@ -63,26 +68,32 @@ public actor ConnectionPool: Sendable { private func begin( _ kind: Transaction.Kind ) async throws(SQLError) -> sending Transaction { - // Writes must be exclusive, make sure to wait on any pending writes. - if kind == .write { - await writeLock.lock() - } - - return try await Transaction(connection: getConnection(), kind: kind) + return try await Transaction( + connection: getConnection(isWrite: kind == .write), + kind: kind + ) } /// Gives the connection back to the pool. - private func reclaim(connection: RawConnection, kind: Transaction.Kind) async { + private func reclaim( + connection: RawConnection, + isWrite: Bool + ) async { availableConnections.append(connection) alertAnyWaitersOfAvailableConnection() - if kind == .write { + if isWrite { await writeLock.unlock() } } /// Will get, wait or create a connection to the database - private func getConnection() async throws(SQLError) -> RawConnection { + private func getConnection(isWrite: Bool) async throws(SQLError) -> RawConnection { + // Writes must be exclusive, make sure to wait on any pending writes. + if isWrite { + await writeLock.lock() + } + guard availableConnections.isEmpty else { // Have an available connection, just use it return availableConnections.removeLast() @@ -169,11 +180,22 @@ extension ConnectionPool: Connection { do { let output = try await execute(tx) - await reclaim(connection: conn, kind: kind) + await reclaim(connection: conn, isWrite: kind == .write) return output } catch { - await reclaim(connection: conn, kind: kind) + await reclaim(connection: conn, isWrite: kind == .write) throw error } } + + /// Gets a connection to the database. No tx is started. + public func withConnection( + isWrite: Bool, + execute: @Sendable (borrowing RawConnection) throws -> Output + ) async throws -> Output { + let conn = try await getConnection(isWrite: isWrite) + let output = try execute(conn) + await reclaim(connection: conn, isWrite: isWrite) + return output + } } diff --git a/Sources/PureSQL/Database.swift b/Sources/PureSQL/Database.swift index cc211ea..14c7147 100644 --- a/Sources/PureSQL/Database.swift +++ b/Sources/PureSQL/Database.swift @@ -48,13 +48,15 @@ public extension Database { try ConnectionPool( path: path, limit: config.maxConnectionCount, - migrations: Self.sanitizedMigrations + migrations: Self.sanitizedMigrations, + runMigrations: config.autoMigrate ) } else { try ConnectionPool( path: ":memory:", limit: 1, - migrations: Self.sanitizedMigrations + migrations: Self.sanitizedMigrations, + runMigrations: config.autoMigrate ) } @@ -65,6 +67,20 @@ public extension Database { static func inMemory(adapters: Adapters) throws -> Self { return try Self(config: DatabaseConfig(path: nil), adapters: adapters) } + + /// Runs the migrations up to and including the `maxMigration`. + /// + /// The `maxMigration` number is not equal to the filename, but + /// rather the zero based index. + func migrate(upTo maxMigration: Int? = nil) async throws { + try await connection.withConnection(isWrite: true) { conn in + try MigrationRunner.execute( + migrations: Self.sanitizedMigrations, + connection: conn, + upTo: maxMigration + ) + } + } } diff --git a/Sources/PureSQL/DatabaseConfig.swift b/Sources/PureSQL/DatabaseConfig.swift index 73de103..119f0ca 100644 --- a/Sources/PureSQL/DatabaseConfig.swift +++ b/Sources/PureSQL/DatabaseConfig.swift @@ -15,12 +15,17 @@ public struct DatabaseConfig { /// In memory databases will be overriden to `1` regardless /// of the input public var maxConnectionCount: Int + /// If `true` the migrations will run when the connection is opened. + /// Default is `true`. + public var autoMigrate: Bool public init( path: String?, - maxConnectionCount: Int = 5 + maxConnectionCount: Int = 5, + autoMigrate: Bool = true ) { self.path = path self.maxConnectionCount = maxConnectionCount + self.autoMigrate = autoMigrate } } diff --git a/Sources/PureSQL/Migration.swift b/Sources/PureSQL/Migration.swift index a9483be..5160ffa 100644 --- a/Sources/PureSQL/Migration.swift +++ b/Sources/PureSQL/Migration.swift @@ -9,8 +9,12 @@ enum MigrationRunner { static let migrationTableName = "__puresqlMigration" - static func execute(migrations: [String], connection: SQLiteConnection) throws { - let previouslyRunMigrations = try runMigrations(connection: connection) + static func execute( + migrations: [String], + connection: RawConnection, + upTo maxMigration: Int? = nil + ) throws { + let previouslyRunMigrations = try getRanMigrations(connection: connection) let lastMigration = previouslyRunMigrations.last ?? Int.min let pendingMigrations = migrations.enumerated() @@ -18,6 +22,8 @@ enum MigrationRunner { .filter { $0.number > lastMigration } for (number, migration) in pendingMigrations { + if let maxMigration, number > maxMigration { return } + // Run each migration in it's own transaction. let tx = try Transaction(connection: connection, kind: .write) @@ -50,7 +56,7 @@ enum MigrationRunner { } /// Creates the migrations table and gets the last migration that ran. - private static func runMigrations(connection: SQLiteConnection) throws -> [Int] { + private static func getRanMigrations(connection: RawConnection) throws -> [Int] { let tx = try Transaction(connection: connection, kind: .write) // Create the migration table if need be. diff --git a/Sources/PureSQL/SQLiteConnection.swift b/Sources/PureSQL/SQLiteConnection.swift index 4f5dd9f..5f68f94 100644 --- a/Sources/PureSQL/SQLiteConnection.swift +++ b/Sources/PureSQL/SQLiteConnection.swift @@ -9,8 +9,12 @@ import Collections import Foundation import SQLite3 -protocol RawConnection: Sendable { +/// Represents a raw connection to the SQLite database +public protocol RawConnection: Sendable { + /// Initializes a SQLite prepared statement func prepare(sql: String) throws(SQLError) -> OpaquePointer + /// Executes the SQL statement. + /// Equivalent to `sqlite3_exec` func execute(sql: String) throws(SQLError) } diff --git a/Tests/PureSQLTests/MigrationRunnerTests.swift b/Tests/PureSQLTests/MigrationRunnerTests.swift index 0b31bc4..59a8f19 100644 --- a/Tests/PureSQLTests/MigrationRunnerTests.swift +++ b/Tests/PureSQLTests/MigrationRunnerTests.swift @@ -108,6 +108,33 @@ struct MigrationRunnerTests: ~Copyable { #expect(tables.contains("foo")) } + @Test func canRunMigrationsUpToCertainNumber() async throws { + let migrations = [ + "CREATE TABLE foo (bar INTEGER);", + "CREATE TABLE bar (baz TEXT);", + ] + + try MigrationRunner.execute( + migrations: migrations, + connection: connection, + upTo: 0 // Dont run the last migration + ) + + let onlyFirstMigration = try tableNames() + #expect(onlyFirstMigration.contains("foo")) + #expect(!onlyFirstMigration.contains("bar")) + + try MigrationRunner.execute( + migrations: migrations, + connection: connection, + upTo: nil // Now run all migrations + ) + + let allMigrations = try tableNames() + #expect(allMigrations.contains("foo")) + #expect(allMigrations.contains("bar")) + } + private func runMigrations() throws -> [Int] { return try query("SELECT * FROM \(MigrationRunner.migrationTableName) ORDER BY number ASC") { try $0.fetchAll() } }