diff --git a/Sources/Compiler/Config.swift b/Sources/Compiler/Config.swift index a605763..c3fa830 100644 --- a/Sources/Compiler/Config.swift +++ b/Sources/Compiler/Config.swift @@ -14,6 +14,7 @@ public struct Config: Codable { public let output: String? public let databaseName: String? public let additionalImports: [String]? + public let tableNamePattern: String? struct NotFoundError: Error, CustomStringConvertible { var description: String { "Config does not exist" } diff --git a/Sources/Compiler/Gen/Language.swift b/Sources/Compiler/Gen/Language.swift index 1b48d51..68a718f 100644 --- a/Sources/Compiler/Gen/Language.swift +++ b/Sources/Compiler/Gen/Language.swift @@ -13,6 +13,8 @@ import Foundation public protocol Language { init(options: GenerationOptions) + var options: GenerationOptions { get } + var boolName: String { get } /// A list of types that have builtin adapters supplied by the library. @@ -186,8 +188,12 @@ extension Language { } private func model(for table: Table) -> GeneratedModel { - GeneratedModel( - name: table.name.name.capitalizedFirst, + var name = table.name.name.capitalizedFirst + if let pattern = options.tableNamePattern { + name = String(format: pattern, name) + } + return GeneratedModel( + name: name, fields: table.columns.reduce(into: [:]) { fields, column in let name = column.key.description let type = column.value.type @@ -365,15 +371,18 @@ public struct GenerationOptions: Sendable { public var databaseName: String public var imports: [String] public var createDirectoryIfNeeded: Bool + public var tableNamePattern: String? public init( databaseName: String, imports: [String] = [], - createDirectoryIfNeeded: Bool = true + createDirectoryIfNeeded: Bool = true, + tableNamePattern: String? = nil ) { self.databaseName = databaseName self.imports = imports self.createDirectoryIfNeeded = createDirectoryIfNeeded + self.tableNamePattern = tableNamePattern } } diff --git a/Sources/Compiler/Gen/SwiftLanguage.swift b/Sources/Compiler/Gen/SwiftLanguage.swift index a41fb52..22534d5 100644 --- a/Sources/Compiler/Gen/SwiftLanguage.swift +++ b/Sources/Compiler/Gen/SwiftLanguage.swift @@ -6,7 +6,7 @@ // public struct SwiftLanguage: Language { - let options: GenerationOptions + public let options: GenerationOptions private var writer = SourceWriter() public init(options: GenerationOptions) { diff --git a/Sources/PureSQLCLI/GenerateCommand.swift b/Sources/PureSQLCLI/GenerateCommand.swift index 668a646..acf999f 100644 --- a/Sources/PureSQLCLI/GenerateCommand.swift +++ b/Sources/PureSQLCLI/GenerateCommand.swift @@ -45,7 +45,8 @@ struct GenerateCommand: AsyncParsableCommand { let options = GenerationOptions( databaseName: config.databaseName ?? "DB", imports: config.additionalImports ?? [], - createDirectoryIfNeeded: !skipDirectoryCreate + createDirectoryIfNeeded: !skipDirectoryCreate, + tableNamePattern: config.tableNamePattern ) try await generate( diff --git a/Tests/CompilerTests/Gen/SwiftWithPattern.output b/Tests/CompilerTests/Gen/SwiftWithPattern.output new file mode 100644 index 0000000..cdc64d0 --- /dev/null +++ b/Tests/CompilerTests/Gen/SwiftWithPattern.output @@ -0,0 +1,526 @@ +import Foundation +import PureSQL + +struct BarRecord: Hashable, Sendable, RowDecodable { + let intPk: Int + let barNotNullText: String + + static let nonOptionalIndices: [Int32] = [0, 1] + + init( + row: borrowing PureSQL.Row, + startingAt start: Int32 + ) throws(PureSQL.SQLError) { + self.intPk = try row.value(at: start + 0) + self.barNotNullText = try row.value(at: start + 1) + } + + init( + intPk: Int, + barNotNullText: String + ) { + self.intPk = intPk + self.barNotNullText = barNotNullText + } +} + +struct FooRecord: Hashable, Sendable, RowDecodableWithAdapters { + let intPk: Int + let textNotNull: String + let textNullable: String? + let dateWithAdapterNotNull: Date + let dateWithAdapterNullable: Date? + let dateWithCustomAdapter: Date? + let generatedColumn: String + + static let nonOptionalIndices: [Int32] = [0, 1, 3, 6] + + init( + row: borrowing PureSQL.Row, + startingAt start: Int32, + adapters: DB.Adapters + ) throws(PureSQL.SQLError) { + self.intPk = try row.value(at: start + 0) + self.textNotNull = try row.value(at: start + 1) + self.textNullable = try row.value(at: start + 2) + self.dateWithAdapterNotNull = try row.value(at: start + 3, using: adapters.date, storage: Int.self) + self.dateWithAdapterNullable = try row.optionalValue(at: start + 4, using: adapters.date, storage: Int.self) + self.dateWithCustomAdapter = try row.optionalValue(at: start + 5, using: adapters.customDate, storage: String.self) + self.generatedColumn = try row.value(at: start + 6) + } + + init( + intPk: Int, + textNotNull: String, + textNullable: String?, + dateWithAdapterNotNull: Date, + dateWithAdapterNullable: Date?, + dateWithCustomAdapter: Date?, + generatedColumn: String + ) { + self.intPk = intPk + self.textNotNull = textNotNull + self.textNullable = textNullable + self.dateWithAdapterNotNull = dateWithAdapterNotNull + self.dateWithAdapterNullable = dateWithAdapterNullable + self.dateWithCustomAdapter = dateWithCustomAdapter + self.generatedColumn = generatedColumn + } +} + +struct InsertFooReturningFooInput: Hashable, Sendable { + let textNotNull: String + let textNullable: String? + let dateWithAdapterNotNull: Date + let dateWithAdapterNullable: Date? + let dateWithCustomAdapter: Date? +} + +struct InsertFooReturningFooOutput: Hashable, Sendable, RowDecodableWithAdapters { + let intPk: Int + let textNotNull: String + let textNullable: String? + let dateWithAdapterNotNull: Date + let dateWithAdapterNullable: Date? + let dateWithCustomAdapter: Date? + let generatedColumn: String + + static let nonOptionalIndices: [Int32] = [] + + init( + row: borrowing PureSQL.Row, + startingAt start: Int32, + adapters: DB.Adapters + ) throws(PureSQL.SQLError) { + self.intPk = try row.value(at: start + 0) + self.textNotNull = try row.value(at: start + 1) + self.textNullable = try row.value(at: start + 2) + self.dateWithAdapterNotNull = try row.value(at: start + 3, using: adapters.date, storage: Int.self) + self.dateWithAdapterNullable = try row.optionalValue(at: start + 4, using: adapters.date, storage: Int.self) + self.dateWithCustomAdapter = try row.optionalValue(at: start + 5, using: adapters.customDate, storage: String.self) + self.generatedColumn = try row.value(at: start + 6) + } + + init( + intPk: Int, + textNotNull: String, + textNullable: String?, + dateWithAdapterNotNull: Date, + dateWithAdapterNullable: Date?, + dateWithCustomAdapter: Date?, + generatedColumn: String + ) { + self.intPk = intPk + self.textNotNull = textNotNull + self.textNullable = textNullable + self.dateWithAdapterNotNull = dateWithAdapterNotNull + self.dateWithAdapterNullable = dateWithAdapterNullable + self.dateWithCustomAdapter = dateWithCustomAdapter + self.generatedColumn = generatedColumn + } +} + +struct InsertBarReturningIntPkInput: Hashable, Sendable { + let customNameIntPk: Int + let barNotNullText: String +} + +struct InsertBarReturningExtraColumnInput: Hashable, Sendable { + let intPk: Int + let barNotNullText: String +} + +struct InsertBarReturningExtraColumnOutput: Hashable, Sendable, RowDecodable { + let intPk: Int + let barNotNullText: String + let columnAfter: Int + + static let nonOptionalIndices: [Int32] = [] + + init( + row: borrowing PureSQL.Row, + startingAt start: Int32 + ) throws(PureSQL.SQLError) { + self.intPk = try row.value(at: start + 0) + self.barNotNullText = try row.value(at: start + 1) + self.columnAfter = try row.value(at: start + 2) + } + + init( + intPk: Int, + barNotNullText: String, + columnAfter: Int + ) { + self.intPk = intPk + self.barNotNullText = barNotNullText + self.columnAfter = columnAfter + } +} + +@dynamicMemberLookup +struct HasEmbeddedFooOutput: Hashable, Sendable, RowDecodableWithAdapters { + let foo: FooRecord + let shouldBeNullable: String? + + static let nonOptionalIndices: [Int32] = [] + + init( + row: borrowing PureSQL.Row, + startingAt start: Int32, + adapters: DB.Adapters + ) throws(PureSQL.SQLError) { + self.foo = try row.embedded(at: start + 0, adapters: adapters) + self.shouldBeNullable = try row.value(at: start + 7) + } + + init( + foo: FooRecord, + shouldBeNullable: String? + ) { + self.foo = foo + self.shouldBeNullable = shouldBeNullable + } + + subscript(dynamicMember dynamicMember: KeyPath) -> Value { + self.foo[keyPath: dynamicMember] + } +} + +struct BothColumnsShouldNotBeNullableOutput: Hashable, Sendable, RowDecodable { + let f: Int + let b: Int + + static let nonOptionalIndices: [Int32] = [] + + init( + row: borrowing PureSQL.Row, + startingAt start: Int32 + ) throws(PureSQL.SQLError) { + self.f = try row.value(at: start + 0) + self.b = try row.value(at: start + 1) + } + + init( + f: Int, + b: Int + ) { + self.f = f + self.b = b + } +} + +struct SelectWithManyInputsInput: Hashable, Sendable { + let intPk: Int + let textNotNull: String +} + +struct InputContainsArrayInput: Hashable, Sendable { + let intPks: [Int] + let barNotNullText: String +} + +struct QueriesQueries: ConnectionWrapper, Sendable { + let connection: any Connection + var insertFooReturningFoo: any InsertFooReturningFooQuery + var insertBarReturningIntPk: any InsertBarReturningIntPkQuery + var insertBarReturningExtraColumn: any InsertBarReturningExtraColumnQuery + var selectSingleFoo: any SelectSingleFooQuery + var hasEmbeddedFoo: any HasEmbeddedFooQuery + var bothColumnsShouldNotBeNullable: any BothColumnsShouldNotBeNullableQuery + var selectWithManyInputs: any SelectWithManyInputsQuery + var inputIsArray: any InputIsArrayQuery + var inputContainsArray: any InputContainsArrayQuery + + static func noop( + insertFooReturningFoo: any InsertFooReturningFooQuery = Queries.Fail(), + insertBarReturningIntPk: any InsertBarReturningIntPkQuery = Queries.Just(0), + insertBarReturningExtraColumn: any InsertBarReturningExtraColumnQuery = Queries.Fail(), + selectSingleFoo: any SelectSingleFooQuery = Queries.Just(), + hasEmbeddedFoo: any HasEmbeddedFooQuery = Queries.Just(), + bothColumnsShouldNotBeNullable: any BothColumnsShouldNotBeNullableQuery = Queries.Just(), + selectWithManyInputs: any SelectWithManyInputsQuery = Queries.Just(), + inputIsArray: any InputIsArrayQuery = Queries.Just(), + inputContainsArray: any InputContainsArrayQuery = Queries.Just() + ) -> QueriesQueries { + QueriesQueries( + connection: NoopConnection(), + insertFooReturningFoo: insertFooReturningFoo, + insertBarReturningIntPk: insertBarReturningIntPk, + insertBarReturningExtraColumn: insertBarReturningExtraColumn, + selectSingleFoo: selectSingleFoo, + hasEmbeddedFoo: hasEmbeddedFoo, + bothColumnsShouldNotBeNullable: bothColumnsShouldNotBeNullable, + selectWithManyInputs: selectWithManyInputs, + inputIsArray: inputIsArray, + inputContainsArray: inputContainsArray + ) + } + + static func live(connection: Connection, adapters: DB.Adapters) -> QueriesQueries { + return QueriesQueries( + connection: connection, + insertFooReturningFoo: DatabaseQuery( + .write, + in: connection, + watchingTables: ["foo"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + INSERT INTO foo + (textNotNull, textNullable, dateWithAdapterNotNull, dateWithAdapterNullable, dateWithCustomAdapter) + VALUES (?, ?, ?, ?, ?) + RETURNING * + """, + transaction: tx + ) + try statement.bind(value: input.textNotNull) + try statement.bind(value: input.textNullable) + try statement.bind(value: input.dateWithAdapterNotNull, using: adapters.date, as: Int.self) + try statement.bind(value: input.dateWithAdapterNullable, using: adapters.date, as: Int.self) + try statement.bind(value: input.dateWithCustomAdapter, using: adapters.customDate, as: String.self) + return try statement.fetchOne(adapters: adapters) + }, + insertBarReturningIntPk: DatabaseQuery( + .write, + in: connection, + watchingTables: ["bar"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + INSERT INTO bar (intPk, barNotNullText) VALUES (:customNameIntPk, ?) RETURNING intPk + """, + transaction: tx + ) + try statement.bind(value: input.customNameIntPk) + try statement.bind(value: input.barNotNullText) + return try statement.fetchOne() + }, + insertBarReturningExtraColumn: DatabaseQuery( + .write, + in: connection, + watchingTables: ["bar"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + INSERT INTO bar VALUES (?, ?) RETURNING *, 123 AS columnAfter + """, + transaction: tx + ) + try statement.bind(value: input.intPk) + try statement.bind(value: input.barNotNullText) + return try statement.fetchOne() + }, + selectSingleFoo: DatabaseQuery( + .read, + in: connection, + watchingTables: ["foo"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + SELECT * FROM foo WHERE intPk = ? + """, + transaction: tx + ) + try statement.bind(value: input) + return try statement.fetchOne(adapters: adapters) + }, + hasEmbeddedFoo: DatabaseQuery( + .read, + in: connection, + watchingTables: ["bar","foo"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + SELECT foo.*, bar.barNotNullText AS shouldBeNullable + FROM foo + LEFT OUTER JOIN bar ON foo.intPk = bar.intPk + WHERE foo.intPk = ? + """, + transaction: tx + ) + try statement.bind(value: input) + return try statement.fetchAll(adapters: adapters) + }, + bothColumnsShouldNotBeNullable: DatabaseQuery<(), [BothColumnsShouldNotBeNullableOutput]>( + .read, + in: connection, + watchingTables: ["bar","foo"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + SELECT foo.intPk AS f, bar.intPk AS b FROM foo + INNER JOIN bar ON foo.intPk = bar.intPk + """, + transaction: tx + ) + return try statement.fetchAll() + }, + selectWithManyInputs: DatabaseQuery( + .read, + in: connection, + watchingTables: ["foo"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + SELECT * FROM foo WHERE intPk = ? AND textNotNull = ? + """, + transaction: tx + ) + try statement.bind(value: input.intPk) + try statement.bind(value: input.textNotNull) + return try statement.fetchOne(adapters: adapters) + }, + inputIsArray: DatabaseQuery<[Int], ()>( + .write, + in: connection, + watchingTables: ["bar"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + DELETE FROM bar WHERE intPk IN (\(input.sqlQuestionMarks)) + """, + transaction: tx + ) + for element in input { + try statement.bind(value: element) + } + _ = try statement.step() + }, + inputContainsArray: DatabaseQuery( + .write, + in: connection, + watchingTables: ["bar"] + ) { input, tx in + let statement = try PureSQL.Statement( + """ + DELETE FROM bar WHERE intPk IN (\(input.intPks.sqlQuestionMarks))AND barNotNullText = ? + """, + transaction: tx + ) + for element in input.intPks { + try statement.bind(value: element) + } + try statement.bind(value: input.barNotNullText) + _ = try statement.step() + } + ) + } + +} + +struct DB: Database{ + let connection: any PureSQL.Connection + let adapters: Adapters + + struct Adapters: PureSQL.Adapters { + let customDate: AnyDatabaseValueAdapter + + init( + customDate: any DatabaseValueAdapter + ) { + self.customDate = AnyDatabaseValueAdapter(customDate) + } + } + + static var migrations: [String] { + return [ + """ + CREATE TABLE foo ( + intPk INTEGER PRIMARY KEY AUTOINCREMENT, + textNotNull TEXT NOT NULL, + textNullable TEXT, + dateWithAdapterNotNull INTEGER NOT NULL, + dateWithAdapterNullable INTEGER , + dateWithCustomAdapter TEXT , + generatedColumn TEXT NOT NULL GENERATED ALWAYS AS ('a-good-prefix ' || textNotNull) + ); + CREATE TABLE bar ( + intPk INTEGER PRIMARY KEY, + barNotNullText TEXT NOT NULL + );; + """ + ] + } + var queriesQueries: QueriesQueries { + QueriesQueries.live(connection: connection, adapters: adapters) + } +} + +typealias InsertFooReturningFooQuery = Query +extension Query where Input == InsertFooReturningFooInput { + func execute(textNotNull: String, textNullable: String?, dateWithAdapterNotNull: Date, dateWithAdapterNullable: Date?, dateWithCustomAdapter: Date?) async throws -> Output { + try await execute(InsertFooReturningFooInput(textNotNull: textNotNull, textNullable: textNullable, dateWithAdapterNotNull: dateWithAdapterNotNull, dateWithAdapterNullable: dateWithAdapterNullable, dateWithCustomAdapter: dateWithCustomAdapter)) + } + + func execute(textNotNull: String, textNullable: String?, dateWithAdapterNotNull: Date, dateWithAdapterNullable: Date?, dateWithCustomAdapter: Date?, tx: borrowing Transaction) throws -> Output { + try execute(InsertFooReturningFooInput(textNotNull: textNotNull, textNullable: textNullable, dateWithAdapterNotNull: dateWithAdapterNotNull, dateWithAdapterNullable: dateWithAdapterNullable, dateWithCustomAdapter: dateWithCustomAdapter), tx: tx) + } + + func observe(textNotNull: String, textNullable: String?, dateWithAdapterNotNull: Date, dateWithAdapterNullable: Date?, dateWithCustomAdapter: Date?) -> QueryStream { + observe(InsertFooReturningFooInput(textNotNull: textNotNull, textNullable: textNullable, dateWithAdapterNotNull: dateWithAdapterNotNull, dateWithAdapterNullable: dateWithAdapterNullable, dateWithCustomAdapter: dateWithCustomAdapter)) + } +} + +typealias InsertBarReturningIntPkQuery = Query +extension Query where Input == InsertBarReturningIntPkInput { + func execute(customNameIntPk: Int, barNotNullText: String) async throws -> Output { + try await execute(InsertBarReturningIntPkInput(customNameIntPk: customNameIntPk, barNotNullText: barNotNullText)) + } + + func execute(customNameIntPk: Int, barNotNullText: String, tx: borrowing Transaction) throws -> Output { + try execute(InsertBarReturningIntPkInput(customNameIntPk: customNameIntPk, barNotNullText: barNotNullText), tx: tx) + } + + func observe(customNameIntPk: Int, barNotNullText: String) -> QueryStream { + observe(InsertBarReturningIntPkInput(customNameIntPk: customNameIntPk, barNotNullText: barNotNullText)) + } +} + +typealias InsertBarReturningExtraColumnQuery = Query +extension Query where Input == InsertBarReturningExtraColumnInput { + func execute(intPk: Int, barNotNullText: String) async throws -> Output { + try await execute(InsertBarReturningExtraColumnInput(intPk: intPk, barNotNullText: barNotNullText)) + } + + func execute(intPk: Int, barNotNullText: String, tx: borrowing Transaction) throws -> Output { + try execute(InsertBarReturningExtraColumnInput(intPk: intPk, barNotNullText: barNotNullText), tx: tx) + } + + func observe(intPk: Int, barNotNullText: String) -> QueryStream { + observe(InsertBarReturningExtraColumnInput(intPk: intPk, barNotNullText: barNotNullText)) + } +} + +typealias SelectSingleFooQuery = Query +typealias HasEmbeddedFooQuery = Query +typealias BothColumnsShouldNotBeNullableQuery = Query<(), [BothColumnsShouldNotBeNullableOutput]> +typealias SelectWithManyInputsQuery = Query +extension Query where Input == SelectWithManyInputsInput { + func execute(intPk: Int, textNotNull: String) async throws -> Output { + try await execute(SelectWithManyInputsInput(intPk: intPk, textNotNull: textNotNull)) + } + + func execute(intPk: Int, textNotNull: String, tx: borrowing Transaction) throws -> Output { + try execute(SelectWithManyInputsInput(intPk: intPk, textNotNull: textNotNull), tx: tx) + } + + func observe(intPk: Int, textNotNull: String) -> QueryStream { + observe(SelectWithManyInputsInput(intPk: intPk, textNotNull: textNotNull)) + } +} + +typealias InputIsArrayQuery = Query<[Int], ()> +typealias InputContainsArrayQuery = Query +extension Query where Input == InputContainsArrayInput { + func execute(intPks: [Int], barNotNullText: String) async throws -> Output { + try await execute(InputContainsArrayInput(intPks: intPks, barNotNullText: barNotNullText)) + } + + func execute(intPks: [Int], barNotNullText: String, tx: borrowing Transaction) throws -> Output { + try execute(InputContainsArrayInput(intPks: intPks, barNotNullText: barNotNullText), tx: tx) + } + + func observe(intPks: [Int], barNotNullText: String) -> QueryStream { + observe(InputContainsArrayInput(intPks: intPks, barNotNullText: barNotNullText)) + } +} diff --git a/Tests/CompilerTests/GenTests.swift b/Tests/CompilerTests/GenTests.swift index e9cfee4..924e48c 100644 --- a/Tests/CompilerTests/GenTests.swift +++ b/Tests/CompilerTests/GenTests.swift @@ -12,7 +12,10 @@ import Foundation @Suite struct GenTests { - @Test func generation() throws { + @Test(arguments: [ + ("Swift", GenerationOptions(databaseName: "DB")), + ("SwiftWithPattern", GenerationOptions(databaseName: "DB", tableNamePattern: "%@Record")), + ]) func generation(args: (outputFile: String, options: GenerationOptions)) throws { var compiler = Compiler() let migrations = try compiler.compile(migration: load(file: "Migrations")) let queries = try compiler.compile(queries: load(file: "Queries")) @@ -27,14 +30,14 @@ struct GenTests { guard migrations.1.isEmpty && queries.1.isEmpty else { return } - let language = SwiftLanguage(options: GenerationOptions(databaseName: "DB")) + let language = SwiftLanguage(options: args.options) let rawOutput = try language.generate( migrations: [migrations.0.map(\.sanitizedSource).joined(separator: "\n\n")], queries: [("Queries", queries.0)], schema: compiler.schema ) - - let expected = try load(file: "Swift", ext: "output") + + let expected = try load(file: args.outputFile, ext: "output") .split(separator: "\n") .filter{ !$0.isEmpty }