diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 269a0f9c..cdf37fc4 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -454,12 +454,12 @@ attribute @db.JsonB() @@@targetField([JsonField]) @@@prisma attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma -// /** -// * Specifies the schema to use in a multi-schema database. https://www.prisma.io/docs/guides/database/multi-schema. -// * -// * @param: The name of the database schema. -// */ -// attribute @@schema(_ name: String) @@@prisma +/** + * Specifies the schema to use in a multi-schema PostgreSQL database. + * + * @param name: The name of the database schema. + */ +attribute @@schema(_ name: String) @@@prisma ////////////////////////////////////////////// // Begin validation attributes and functions diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index aa0a0a77..a4321c40 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -1,3 +1,4 @@ +import { invariant } from '@zenstackhq/common-helpers'; import { AstUtils, type ValidationAcceptor } from 'langium'; import pluralize from 'pluralize'; import type { BinaryExpr, DataModel, Expression } from '../ast'; @@ -13,9 +14,13 @@ import { ReferenceExpr, isArrayExpr, isAttribute, + isConfigArrayExpr, isDataField, isDataModel, + isDataSource, isEnum, + isLiteralExpr, + isModel, isReferenceExpr, isTypeDef, } from '../generated/ast'; @@ -332,6 +337,28 @@ export default class AttributeApplicationValidator implements AstValidator f.name === 'schemas'); + if (schemas && isConfigArrayExpr(schemas.value)) { + found = schemas.value.items.some((item) => isLiteralExpr(item) && item.value === schemaName); + } + if (!found) { + accept('error', `Schema "${schemaName}" is not defined in the datasource`, { + node: attr, + }); + } + } + } + private validatePolicyKinds( kind: string, candidates: string[], diff --git a/packages/language/src/validators/datasource-validator.ts b/packages/language/src/validators/datasource-validator.ts index 84302785..9f6abd64 100644 --- a/packages/language/src/validators/datasource-validator.ts +++ b/packages/language/src/validators/datasource-validator.ts @@ -1,6 +1,6 @@ import type { ValidationAcceptor } from 'langium'; import { SUPPORTED_PROVIDERS } from '../constants'; -import { DataSource, isInvocationExpr } from '../generated/ast'; +import { DataSource, isConfigArrayExpr, isInvocationExpr, isLiteralExpr } from '../generated/ast'; import { getStringLiteral } from '../utils'; import { validateDuplicatedDeclarations, type AstValidator } from './common'; @@ -12,7 +12,6 @@ export default class DataSourceValidator implements AstValidator { validateDuplicatedDeclarations(ds, ds.fields, accept); this.validateProvider(ds, accept); this.validateUrl(ds, accept); - this.validateRelationMode(ds, accept); } private validateProvider(ds: DataSource, accept: ValidationAcceptor) { @@ -24,20 +23,45 @@ export default class DataSourceValidator implements AstValidator { return; } - const value = getStringLiteral(provider.value); - if (!value) { + const providerValue = getStringLiteral(provider.value); + if (!providerValue) { accept('error', '"provider" must be set to a string literal', { node: provider.value, }); - } else if (!SUPPORTED_PROVIDERS.includes(value)) { + } else if (!SUPPORTED_PROVIDERS.includes(providerValue)) { accept( 'error', - `Provider "${value}" is not supported. Choose from ${SUPPORTED_PROVIDERS.map((p) => '"' + p + '"').join( - ' | ', - )}.`, + `Provider "${providerValue}" is not supported. Choose from ${SUPPORTED_PROVIDERS.map( + (p) => '"' + p + '"', + ).join(' | ')}.`, { node: provider.value }, ); } + + const defaultSchemaField = ds.fields.find((f) => f.name === 'defaultSchema'); + if (defaultSchemaField && providerValue !== 'postgresql') { + accept('error', '"defaultSchema" is only supported for "postgresql" provider', { + node: defaultSchemaField, + }); + } + + const schemasField = ds.fields.find((f) => f.name === 'schemas'); + if (schemasField) { + if (providerValue !== 'postgresql') { + accept('error', '"schemas" is only supported for "postgresql" provider', { + node: schemasField, + }); + } + const schemasValue = schemasField.value; + if ( + !isConfigArrayExpr(schemasValue) || + !schemasValue.items.every((e) => isLiteralExpr(e) && typeof getStringLiteral(e) === 'string') + ) { + accept('error', '"schemas" must be an array of string literals', { + node: schemasField, + }); + } + } } private validateUrl(ds: DataSource, accept: ValidationAcceptor) { @@ -53,14 +77,4 @@ export default class DataSourceValidator implements AstValidator { }); } } - - private validateRelationMode(ds: DataSource, accept: ValidationAcceptor) { - const field = ds.fields.find((f) => f.name === 'relationMode'); - if (field) { - const val = getStringLiteral(field.value); - if (!val || !['foreignKeys', 'prisma'].includes(val)) { - accept('error', '"relationMode" must be set to "foreignKeys" or "prisma"', { node: field.value }); - } - } - } } diff --git a/packages/orm/src/client/executor/name-mapper.ts b/packages/orm/src/client/executor/name-mapper.ts index dcea8152..1f508b0b 100644 --- a/packages/orm/src/client/executor/name-mapper.ts +++ b/packages/orm/src/client/executor/name-mapper.ts @@ -129,10 +129,9 @@ export class QueryNameMapper extends OperationNodeTransformer { mappedTableName = this.mapTableName(scope.model); } } - return ReferenceNode.create( ColumnNode.create(mappedFieldName), - mappedTableName ? TableNode.create(mappedTableName) : undefined, + mappedTableName ? this.createTableNode(mappedTableName, undefined) : undefined, ); } else { // no name mapping needed @@ -316,7 +315,9 @@ export class QueryNameMapper extends OperationNodeTransformer { if (!TableNode.is(node)) { return super.transformNode(node); } - return TableNode.create(this.mapTableName(node.table.identifier.name)); + const mappedName = this.mapTableName(node.table.identifier.name); + const tableSchema = this.getTableSchema(node.table.identifier.name); + return this.createTableNode(mappedName, tableSchema); } private getMappedName(def: ModelDef | FieldDef) { @@ -362,8 +363,9 @@ export class QueryNameMapper extends OperationNodeTransformer { const modelName = innerNode.table.identifier.name; const mappedName = this.mapTableName(modelName); const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined); + const tableSchema = this.getTableSchema(modelName); return { - node: this.wrapAlias(TableNode.create(mappedName), finalAlias), + node: this.wrapAlias(this.createTableNode(mappedName, tableSchema), finalAlias), scope: { alias: alias ?? IdentifierNode.create(modelName), model: modelName, @@ -384,6 +386,21 @@ export class QueryNameMapper extends OperationNodeTransformer { } } + private getTableSchema(model: string) { + if (this.schema.provider.type !== 'postgresql') { + return undefined; + } + let schema = this.schema.provider.defaultSchema ?? 'public'; + const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema'); + if (schemaAttr) { + const nameArg = schemaAttr.args?.find((arg) => arg.name === 'name'); + if (nameArg && nameArg.value.kind === 'literal') { + schema = nameArg.value.value as string; + } + } + return schema; + } + private createSelectAllFields(model: string, alias: OperationNode | undefined) { const modelDef = requireModel(this.schema, model); return this.getModelFields(modelDef).map((fieldDef) => { @@ -454,5 +471,9 @@ export class QueryNameMapper extends OperationNodeTransformer { }); } + private createTableNode(tableName: string, schemaName: string | undefined) { + return schemaName ? TableNode.createWithSchema(schemaName, tableName) : TableNode.create(tableName); + } + // #endregion } diff --git a/packages/orm/src/client/executor/zenstack-query-executor.ts b/packages/orm/src/client/executor/zenstack-query-executor.ts index e53552c5..06f8d133 100644 --- a/packages/orm/src/client/executor/zenstack-query-executor.ts +++ b/packages/orm/src/client/executor/zenstack-query-executor.ts @@ -55,7 +55,10 @@ export class ZenStackQueryExecutor extends DefaultQuer ) { super(compiler, adapter, connectionProvider, plugins); - if (this.schemaHasMappedNames(client.$schema)) { + if ( + client.$schema.provider.type === 'postgresql' || // postgres queries need to be schema-qualified + this.schemaHasMappedNames(client.$schema) + ) { this.nameMapper = new QueryNameMapper(client.$schema); } } diff --git a/packages/schema/src/schema.ts b/packages/schema/src/schema.ts index 5dc9efc4..ac214fa1 100644 --- a/packages/schema/src/schema.ts +++ b/packages/schema/src/schema.ts @@ -5,6 +5,7 @@ export type DataSourceProviderType = 'sqlite' | 'postgresql'; export type DataSourceProvider = { type: DataSourceProviderType; + defaultSchema?: string; }; export type SchemaDef = { diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 61bcb40a..564f5112 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -236,8 +236,20 @@ export class TsSchemaGenerator { private createProviderObject(model: Model): ts.Expression { const dsProvider = this.getDataSourceProvider(model); + const defaultSchema = this.getDataSourceDefaultSchema(model); + return ts.factory.createObjectLiteralExpression( - [ts.factory.createPropertyAssignment('type', ts.factory.createStringLiteral(dsProvider.type))], + [ + ts.factory.createPropertyAssignment('type', ts.factory.createStringLiteral(dsProvider)), + ...(defaultSchema + ? [ + ts.factory.createPropertyAssignment( + 'defaultSchema', + ts.factory.createStringLiteral(defaultSchema), + ), + ] + : []), + ], true, ); } @@ -621,9 +633,26 @@ export class TsSchemaGenerator { invariant(dataSource, 'No data source found in the model'); const providerExpr = dataSource.fields.find((f) => f.name === 'provider')?.value; - invariant(isLiteralExpr(providerExpr), 'Provider must be a literal'); - const type = providerExpr.value as string; - return { type }; + invariant( + isLiteralExpr(providerExpr) && typeof providerExpr.value === 'string', + 'Provider must be a string literal', + ); + return providerExpr.value as string; + } + + private getDataSourceDefaultSchema(model: Model) { + const dataSource = model.declarations.find(isDataSource); + invariant(dataSource, 'No data source found in the model'); + + const defaultSchemaExpr = dataSource.fields.find((f) => f.name === 'defaultSchema')?.value; + if (!defaultSchemaExpr) { + return undefined; + } + invariant( + isLiteralExpr(defaultSchemaExpr) && typeof defaultSchemaExpr.value === 'string', + 'Default schema must be a string literal', + ); + return defaultSchemaExpr.value as string; } private getFieldMappedDefault( diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 1cfd1d41..3f22fead 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -1,8 +1,8 @@ import { invariant } from '@zenstackhq/common-helpers'; import type { Model } from '@zenstackhq/language/ast'; -import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { ZenStackClient, type ClientContract, type ClientOptions } from '@zenstackhq/orm'; import type { SchemaDef } from '@zenstackhq/orm/schema'; +import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import SQLite from 'better-sqlite3'; import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; @@ -59,7 +59,6 @@ export async function createTestClient( let _schema: Schema; const provider = options?.provider ?? getTestDbProvider() ?? 'sqlite'; const dbName = options?.dbName ?? getTestDbName(provider); - const dbUrl = provider === 'sqlite' ? `file:${dbName}` @@ -68,13 +67,14 @@ export async function createTestClient( let model: Model | undefined; if (typeof schema === 'string') { - const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles); + const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles, undefined); workDir = generated.workDir; model = generated.model; // replace schema's provider _schema = { ...generated.schema, provider: { + ...generated.schema.provider, type: provider, }, } as Schema; diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index c805cb95..1ecb015c 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -32,6 +32,11 @@ datasource db { .exhaustive(); } +function replacePlaceholders(schemaText: string, provider: 'sqlite' | 'postgresql', dbUrl: string | undefined) { + const url = dbUrl ?? (provider === 'sqlite' ? 'file:./test.db' : 'postgres://postgres:postgres@localhost:5432/db'); + return schemaText.replace(/\$DB_URL/g, url).replace(/\$PROVIDER/g, provider); +} + export async function generateTsSchema( schemaText: string, provider: 'sqlite' | 'postgresql' = 'sqlite', @@ -43,7 +48,10 @@ export async function generateTsSchema( const zmodelPath = path.join(workDir, 'schema.zmodel'); const noPrelude = schemaText.includes('datasource '); - fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`); + fs.writeFileSync( + zmodelPath, + `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${replacePlaceholders(schemaText, provider, dbUrl)}`, + ); const result = await loadDocumentWithPlugins(zmodelPath); if (!result.success) { diff --git a/tests/e2e/orm/client-api/pg-custom-schema.test.ts b/tests/e2e/orm/client-api/pg-custom-schema.test.ts new file mode 100644 index 00000000..7b9252ce --- /dev/null +++ b/tests/e2e/orm/client-api/pg-custom-schema.test.ts @@ -0,0 +1,196 @@ +import { ORMError } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Postgres custom schema support', () => { + it('defaults to public schema for ORM queries', async () => { + const foundSchema = { create: false, read: false, update: false, delete: false }; + const db = await createTestClient( + ` +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"public"."foo"')) { + sql.includes('insert') && (foundSchema.create = true); + sql.includes('select') && (foundSchema.read = true); + sql.includes('update') && (foundSchema.update = true); + sql.includes('delete') && (foundSchema.delete = true); + } + }, + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + await expect(db.foo.findFirst()).toResolveTruthy(); + await expect(db.foo.update({ where: { id: 1 }, data: { name: 'updated' } })).toResolveTruthy(); + await expect(db.foo.delete({ where: { id: 1 } })).toResolveTruthy(); + + expect(foundSchema).toEqual({ create: true, read: true, update: true, delete: true }); + }); + + it('defaults to public schema for QB queries', async () => { + const foundSchema = { create: false, read: false, update: false, delete: false }; + const db = await createTestClient( + ` +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"public"."foo"')) { + sql.includes('insert') && (foundSchema.create = true); + sql.includes('select') && (foundSchema.read = true); + sql.includes('update') && (foundSchema.update = true); + sql.includes('delete') && (foundSchema.delete = true); + } + }, + }, + ); + + await expect(db.$qb.insertInto('Foo').values({ id: 1, name: 'test' }).execute()).toResolveTruthy(); + await expect(db.$qb.selectFrom('Foo').selectAll().executeTakeFirst()).toResolveTruthy(); + await expect( + db.$qb.updateTable('Foo').set({ name: 'updated' }).where('id', '=', 1).execute(), + ).toResolveTruthy(); + await expect(db.$qb.deleteFrom('Foo').where('id', '=', 1).execute()).toResolveTruthy(); + + expect(foundSchema).toEqual({ create: true, read: true, update: true, delete: true }); + }); + + it('supports changing default schema', async () => { + const db = await createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'mySchema' +} + +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).rejects.toSatisfy( + (e) => e instanceof ORMError && !!e.dbErrorMessage?.includes('relation "mySchema.Foo" does not exist'), + ); + + await db.$disconnect(); + + const db1 = await createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'public' +} + +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + }, + ); + + await expect(db1.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + }); + + it('supports custom schemas', async () => { + let fooQueriesVerified = false; + let barQueriesVerified = false; + + const db = await createTestClient( + ` +datasource db { + provider = '$PROVIDER' + schemas = ['public', 'mySchema'] + url = '$DB_URL' +} + +model Foo { + id Int @id + name String + @@schema('mySchema') +} + +model Bar { + id Int @id + name String + @@schema('public') +} +`, + { + provider: 'postgresql', + usePrismaPush: true, + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"myschema"."foo"')) { + fooQueriesVerified = true; + } + if (sql.includes('"public"."bar"')) { + barQueriesVerified = true; + } + }, + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + await expect(db.bar.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + + expect(fooQueriesVerified).toBe(true); + expect(barQueriesVerified).toBe(true); + }); + + it('rejects using schema for non-postgresql providers', async () => { + await expect( + createTestClient( + ` +datasource db { + provider = 'sqlite' + defaultSchema = 'mySchema' +} + +model Foo { + id Int @id + name String +} +`, + ), + ).rejects.toThrow('only supported for "postgresql" provider'); + }); + + it('rejects using schema not defined in datasource', async () => { + await expect( + createTestClient( + ` +datasource db { + provider = 'postgresql' + schemas = ['public'] +} + +model Foo { + id Int @id + name String + @@schema('mySchema') +} +`, + ), + ).rejects.toThrow('Schema "mySchema" is not defined in the datasource'); + }); +});