Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Sources/PureSQL/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,22 @@ public protocol Connection: Sendable {
/// Cancels the observation for the given subscriber
func cancel(subscriber: DatabaseSubscriber)

/// Begins a transaction and passes it to the `execute` function.
/// If no error is thrown the changes are automatically commited.
/// If an error is thrown the changes are rolled back.
func begin<Output>(
_ kind: Transaction.Kind,
execute: @Sendable (borrowing Transaction) throws -> Output
) async throws -> Output

/// Gets a raw connection to the database and allows for direct
/// SQL access. No transaction is automatically started.
///
/// This is likely not the API you want, and should just use `begin`.
func withConnection<Output>(
isWrite: Bool,
execute: @Sendable (borrowing RawConnection) throws -> Output
) async throws -> Output
}

/// A no operation database connection that does nothing.
Expand All @@ -36,6 +48,13 @@ public struct NoopConnection: Connection {
) async throws -> Output {
try execute(Transaction(connection: NoopRawConnection(), kind: kind))
}

public func withConnection<Output>(
isWrite: Bool,
execute: @Sendable (borrowing RawConnection) throws -> Output
) async throws -> Output {
try execute(NoopRawConnection())
}
}

/// A type that has a database connection.
Expand All @@ -62,4 +81,11 @@ public extension ConnectionWrapper {
) async throws -> Output {
try await connection.begin(kind, execute: execute)
}

func withConnection<Output>(
isWrite: Bool,
execute: @Sendable (borrowing RawConnection) throws -> Output
) async throws -> Output {
try await connection.withConnection(isWrite: isWrite, execute: execute)
}
}
48 changes: 35 additions & 13 deletions Sources/PureSQL/ConnectionPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public actor ConnectionPool: Sendable {
public init(
path: String,
limit: Int,
migrations: [String]
migrations: [String],
runMigrations: Bool = true
) throws {
guard limit > 0 else {
throw SQLError.poolCannotHaveZeroConnections
Expand All @@ -50,7 +51,11 @@ public actor ConnectionPool: Sendable {

// Turn on WAL mode
try connection.execute(sql: "PRAGMA journal_mode=WAL;")
try MigrationRunner.execute(migrations: migrations, connection: connection)

if runMigrations {
try MigrationRunner.execute(migrations: migrations, connection: connection)
}

self.availableConnections = [connection]
}

Expand All @@ -63,26 +68,32 @@ public actor ConnectionPool: Sendable {
private func begin(
_ kind: Transaction.Kind
) async throws(SQLError) -> sending Transaction {
// Writes must be exclusive, make sure to wait on any pending writes.
if kind == .write {
await writeLock.lock()
}

return try await Transaction(connection: getConnection(), kind: kind)
return try await Transaction(
connection: getConnection(isWrite: kind == .write),
kind: kind
)
}

/// Gives the connection back to the pool.
private func reclaim(connection: RawConnection, kind: Transaction.Kind) async {
private func reclaim(
connection: RawConnection,
isWrite: Bool
) async {
availableConnections.append(connection)
alertAnyWaitersOfAvailableConnection()

if kind == .write {
if isWrite {
await writeLock.unlock()
}
}

/// Will get, wait or create a connection to the database
private func getConnection() async throws(SQLError) -> RawConnection {
private func getConnection(isWrite: Bool) async throws(SQLError) -> RawConnection {
// Writes must be exclusive, make sure to wait on any pending writes.
if isWrite {
await writeLock.lock()
}

guard availableConnections.isEmpty else {
// Have an available connection, just use it
return availableConnections.removeLast()
Expand Down Expand Up @@ -169,11 +180,22 @@ extension ConnectionPool: Connection {

do {
let output = try await execute(tx)
await reclaim(connection: conn, kind: kind)
await reclaim(connection: conn, isWrite: kind == .write)
return output
} catch {
await reclaim(connection: conn, kind: kind)
await reclaim(connection: conn, isWrite: kind == .write)
throw error
}
}

/// Gets a connection to the database. No tx is started.
public func withConnection<Output: Sendable>(
isWrite: Bool,
execute: @Sendable (borrowing RawConnection) throws -> Output
) async throws -> Output {
let conn = try await getConnection(isWrite: isWrite)
let output = try execute(conn)
await reclaim(connection: conn, isWrite: isWrite)
return output
}
}
20 changes: 18 additions & 2 deletions Sources/PureSQL/Database.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ public extension Database {
try ConnectionPool(
path: path,
limit: config.maxConnectionCount,
migrations: Self.sanitizedMigrations
migrations: Self.sanitizedMigrations,
runMigrations: config.autoMigrate
)
} else {
try ConnectionPool(
path: ":memory:",
limit: 1,
migrations: Self.sanitizedMigrations
migrations: Self.sanitizedMigrations,
runMigrations: config.autoMigrate
)
}

Expand All @@ -65,6 +67,20 @@ public extension Database {
static func inMemory(adapters: Adapters) throws -> Self {
return try Self(config: DatabaseConfig(path: nil), adapters: adapters)
}

/// Runs the migrations up to and including the `maxMigration`.
///
/// The `maxMigration` number is not equal to the filename, but
/// rather the zero based index.
func migrate(upTo maxMigration: Int? = nil) async throws {
try await connection.withConnection(isWrite: true) { conn in
try MigrationRunner.execute(
migrations: Self.sanitizedMigrations,
connection: conn,
upTo: maxMigration
)
}
}
}


Expand Down
7 changes: 6 additions & 1 deletion Sources/PureSQL/DatabaseConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ public struct DatabaseConfig {
/// In memory databases will be overriden to `1` regardless
/// of the input
public var maxConnectionCount: Int
/// If `true` the migrations will run when the connection is opened.
/// Default is `true`.
public var autoMigrate: Bool

public init(
path: String?,
maxConnectionCount: Int = 5
maxConnectionCount: Int = 5,
autoMigrate: Bool = true
) {
self.path = path
self.maxConnectionCount = maxConnectionCount
self.autoMigrate = autoMigrate
}
}
12 changes: 9 additions & 3 deletions Sources/PureSQL/Migration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,21 @@
enum MigrationRunner {
static let migrationTableName = "__puresqlMigration"

static func execute(migrations: [String], connection: SQLiteConnection) throws {
let previouslyRunMigrations = try runMigrations(connection: connection)
static func execute(
migrations: [String],
connection: RawConnection,
upTo maxMigration: Int? = nil
) throws {
let previouslyRunMigrations = try getRanMigrations(connection: connection)
let lastMigration = previouslyRunMigrations.last ?? Int.min

let pendingMigrations = migrations.enumerated()
.map { (number: $0.offset, migration: $0.element) }
.filter { $0.number > lastMigration }

for (number, migration) in pendingMigrations {
if let maxMigration, number > maxMigration { return }

// Run each migration in it's own transaction.
let tx = try Transaction(connection: connection, kind: .write)

Expand Down Expand Up @@ -50,7 +56,7 @@ enum MigrationRunner {
}

/// Creates the migrations table and gets the last migration that ran.
private static func runMigrations(connection: SQLiteConnection) throws -> [Int] {
private static func getRanMigrations(connection: RawConnection) throws -> [Int] {
let tx = try Transaction(connection: connection, kind: .write)

// Create the migration table if need be.
Expand Down
6 changes: 5 additions & 1 deletion Sources/PureSQL/SQLiteConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ import Collections
import Foundation
import SQLite3

protocol RawConnection: Sendable {
/// Represents a raw connection to the SQLite database
public protocol RawConnection: Sendable {
/// Initializes a SQLite prepared statement
func prepare(sql: String) throws(SQLError) -> OpaquePointer
/// Executes the SQL statement.
/// Equivalent to `sqlite3_exec`
func execute(sql: String) throws(SQLError)
}

Expand Down
27 changes: 27 additions & 0 deletions Tests/PureSQLTests/MigrationRunnerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,33 @@ struct MigrationRunnerTests: ~Copyable {
#expect(tables.contains("foo"))
}

@Test func canRunMigrationsUpToCertainNumber() async throws {
let migrations = [
"CREATE TABLE foo (bar INTEGER);",
"CREATE TABLE bar (baz TEXT);",
]

try MigrationRunner.execute(
migrations: migrations,
connection: connection,
upTo: 0 // Dont run the last migration
)

let onlyFirstMigration = try tableNames()
#expect(onlyFirstMigration.contains("foo"))
#expect(!onlyFirstMigration.contains("bar"))

try MigrationRunner.execute(
migrations: migrations,
connection: connection,
upTo: nil // Now run all migrations
)

let allMigrations = try tableNames()
#expect(allMigrations.contains("foo"))
#expect(allMigrations.contains("bar"))
}

private func runMigrations() throws -> [Int] {
return try query("SELECT * FROM \(MigrationRunner.migrationTableName) ORDER BY number ASC") { try $0.fetchAll() }
}
Expand Down