diff --git a/package.json b/package.json index 7630ddc9..0230dc0c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-v3", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack", "packageManager": "pnpm@10.20.0", "scripts": { diff --git a/packages/cli/package.json b/packages/cli/package.json index 6fe96eee..3f2187eb 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack CLI", "description": "FullStack database toolkit with built-in access control and automatic API generation.", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "type": "module", "author": { "name": "ZenStack Team" diff --git a/packages/clients/tanstack-query/package.json b/packages/clients/tanstack-query/package.json index e8bb62a3..ea0cfee4 100644 --- a/packages/clients/tanstack-query/package.json +++ b/packages/clients/tanstack-query/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/tanstack-query", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "TanStack Query Client for consuming ZenStack v3's CRUD service", "main": "index.js", "type": "module", diff --git a/packages/common-helpers/package.json b/packages/common-helpers/package.json index 338c5ed7..421b8306 100644 --- a/packages/common-helpers/package.json +++ b/packages/common-helpers/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/common-helpers", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack Common Helpers", "type": "module", "scripts": { diff --git a/packages/config/eslint-config/package.json b/packages/config/eslint-config/package.json index bbaa866e..411d7649 100644 --- a/packages/config/eslint-config/package.json +++ b/packages/config/eslint-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/eslint-config", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "type": "module", "private": true, "license": "MIT" diff --git a/packages/config/typescript-config/package.json b/packages/config/typescript-config/package.json index 9d069661..8854ce61 100644 --- a/packages/config/typescript-config/package.json +++ b/packages/config/typescript-config/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/typescript-config", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "private": true, "license": "MIT" } diff --git a/packages/config/vitest-config/package.json b/packages/config/vitest-config/package.json index 0dad87b3..7c479951 100644 --- a/packages/config/vitest-config/package.json +++ b/packages/config/vitest-config/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/vitest-config", "type": "module", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "private": true, "license": "MIT", "exports": { diff --git a/packages/create-zenstack/package.json b/packages/create-zenstack/package.json index c1012c86..2ca35a0e 100644 --- a/packages/create-zenstack/package.json +++ b/packages/create-zenstack/package.json @@ -1,6 +1,6 @@ { "name": "create-zenstack", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "Create a new ZenStack project", "type": "module", "scripts": { diff --git a/packages/dialects/sql.js/package.json b/packages/dialects/sql.js/package.json index d577277f..74028a95 100644 --- a/packages/dialects/sql.js/package.json +++ b/packages/dialects/sql.js/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/kysely-sql-js", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "Kysely dialect for sql.js", "type": "module", "scripts": { diff --git a/packages/language/package.json b/packages/language/package.json index afb000e5..7bc7408d 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/language", "description": "ZenStack ZModel language specification", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "license": "MIT", "author": "ZenStack Team", "files": [ diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 269a0f9c..d3bef43b 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -454,12 +454,12 @@ attribute @db.JsonB() @@@targetField([JsonField]) @@@prisma attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma -// /** -// * Specifies the schema to use in a multi-schema database. https://www.prisma.io/docs/guides/database/multi-schema. -// * -// * @param: The name of the database schema. -// */ -// attribute @@schema(_ name: String) @@@prisma +/** + * Specifies the schema to use in a multi-schema PostgreSQL database. + * + * @param map: The name of the database schema. + */ +attribute @@schema(_ map: String) @@@prisma ////////////////////////////////////////////// // Begin validation attributes and functions diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index 894c6fc7..885a11d6 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -5,6 +5,7 @@ import path from 'node:path'; import { fileURLToPath, pathToFileURL } from 'node:url'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME, type ExpressionContext } from './constants'; import { + InternalAttribute, isArrayExpr, isBinaryExpr, isConfigArrayExpr, @@ -173,7 +174,7 @@ export function getRecursiveBases( bases.forEach((base) => { // avoid using .ref since this function can be called before linking const baseDecl = decl.$container.declarations.find( - (d): d is TypeDef | DataModel => isTypeDef(d) || (isDataModel(d) && d.name === base.$refText), + (d): d is TypeDef | DataModel => (isTypeDef(d) || isDataModel(d)) && d.name === base.$refText, ); if (baseDecl) { if (!includeDelegate && isDelegateModel(baseDecl)) { @@ -321,8 +322,15 @@ function getArray(expr: Expression | ConfigExpr | undefined) { return isArrayExpr(expr) || isConfigArrayExpr(expr) ? expr.items : undefined; } +export function getAttributeArg( + attr: DataModelAttribute | DataFieldAttribute | InternalAttribute, + name: string, +): Expression | undefined { + return attr.args.find((arg) => arg.$resolvedParam?.name === name)?.value; +} + export function getAttributeArgLiteral( - attr: DataModelAttribute | DataFieldAttribute, + attr: DataModelAttribute | DataFieldAttribute | InternalAttribute, name: string, ): T | undefined { for (const arg of attr.args) { diff --git a/packages/language/src/validators/attribute-application-validator.ts b/packages/language/src/validators/attribute-application-validator.ts index 80fa1668..a4321c40 100644 --- a/packages/language/src/validators/attribute-application-validator.ts +++ b/packages/language/src/validators/attribute-application-validator.ts @@ -1,3 +1,4 @@ +import { invariant } from '@zenstackhq/common-helpers'; import { AstUtils, type ValidationAcceptor } from 'langium'; import pluralize from 'pluralize'; import type { BinaryExpr, DataModel, Expression } from '../ast'; @@ -13,14 +14,19 @@ import { ReferenceExpr, isArrayExpr, isAttribute, + isConfigArrayExpr, isDataField, isDataModel, + isDataSource, isEnum, + isLiteralExpr, + isModel, isReferenceExpr, isTypeDef, } from '../generated/ast'; import { getAllAttributes, + getAttributeArg, getStringLiteral, hasAttribute, isAuthOrAuthMemberAccess, @@ -291,7 +297,7 @@ export default class AttributeApplicationValidator implements AstValidator f.name === 'schemas'); + if (schemas && isConfigArrayExpr(schemas.value)) { + found = schemas.value.items.some((item) => isLiteralExpr(item) && item.value === schemaName); + } + if (!found) { + accept('error', `Schema "${schemaName}" is not defined in the datasource`, { + node: attr, + }); + } + } + } + private validatePolicyKinds( kind: string, candidates: string[], diff --git a/packages/language/src/validators/datasource-validator.ts b/packages/language/src/validators/datasource-validator.ts index 84302785..b667d2b2 100644 --- a/packages/language/src/validators/datasource-validator.ts +++ b/packages/language/src/validators/datasource-validator.ts @@ -1,6 +1,6 @@ import type { ValidationAcceptor } from 'langium'; import { SUPPORTED_PROVIDERS } from '../constants'; -import { DataSource, isInvocationExpr } from '../generated/ast'; +import { DataSource, isConfigArrayExpr, isInvocationExpr, isLiteralExpr } from '../generated/ast'; import { getStringLiteral } from '../utils'; import { validateDuplicatedDeclarations, type AstValidator } from './common'; @@ -12,7 +12,6 @@ export default class DataSourceValidator implements AstValidator { validateDuplicatedDeclarations(ds, ds.fields, accept); this.validateProvider(ds, accept); this.validateUrl(ds, accept); - this.validateRelationMode(ds, accept); } private validateProvider(ds: DataSource, accept: ValidationAcceptor) { @@ -24,20 +23,63 @@ export default class DataSourceValidator implements AstValidator { return; } - const value = getStringLiteral(provider.value); - if (!value) { + const providerValue = getStringLiteral(provider.value); + if (!providerValue) { accept('error', '"provider" must be set to a string literal', { node: provider.value, }); - } else if (!SUPPORTED_PROVIDERS.includes(value)) { + } else if (!SUPPORTED_PROVIDERS.includes(providerValue)) { accept( 'error', - `Provider "${value}" is not supported. Choose from ${SUPPORTED_PROVIDERS.map((p) => '"' + p + '"').join( - ' | ', - )}.`, + `Provider "${providerValue}" is not supported. Choose from ${SUPPORTED_PROVIDERS.map( + (p) => '"' + p + '"', + ).join(' | ')}.`, { node: provider.value }, ); } + + const defaultSchemaField = ds.fields.find((f) => f.name === 'defaultSchema'); + let defaultSchemaValue: string | undefined; + if (defaultSchemaField) { + if (providerValue !== 'postgresql') { + accept('error', '"defaultSchema" is only supported for "postgresql" provider', { + node: defaultSchemaField, + }); + } + + defaultSchemaValue = getStringLiteral(defaultSchemaField.value); + if (!defaultSchemaValue) { + accept('error', '"defaultSchema" must be a string literal', { + node: defaultSchemaField.value, + }); + } + } + + const schemasField = ds.fields.find((f) => f.name === 'schemas'); + if (schemasField) { + if (providerValue !== 'postgresql') { + accept('error', '"schemas" is only supported for "postgresql" provider', { + node: schemasField, + }); + } + const schemasValue = schemasField.value; + if ( + !isConfigArrayExpr(schemasValue) || + !schemasValue.items.every((e) => isLiteralExpr(e) && typeof getStringLiteral(e) === 'string') + ) { + accept('error', '"schemas" must be an array of string literals', { + node: schemasField, + }); + } else if ( + // validate `defaultSchema` is included in `schemas` + defaultSchemaValue && + !schemasValue.items.some((e) => getStringLiteral(e) === defaultSchemaValue) + ) { + accept('error', `"${defaultSchemaValue}" must be included in the "schemas" array`, { + node: schemasField, + }); + } + } } private validateUrl(ds: DataSource, accept: ValidationAcceptor) { @@ -53,14 +95,4 @@ export default class DataSourceValidator implements AstValidator { }); } } - - private validateRelationMode(ds: DataSource, accept: ValidationAcceptor) { - const field = ds.fields.find((f) => f.name === 'relationMode'); - if (field) { - const val = getStringLiteral(field.value); - if (!val || !['foreignKeys', 'prisma'].includes(val)) { - accept('error', '"relationMode" must be set to "foreignKeys" or "prisma"', { node: field.value }); - } - } - } } diff --git a/packages/orm/package.json b/packages/orm/package.json index 3adc96dc..fb1975f5 100644 --- a/packages/orm/package.json +++ b/packages/orm/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/orm", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack ORM", "type": "module", "scripts": { @@ -67,6 +67,7 @@ "@zenstackhq/common-helpers": "workspace:*", "decimal.js": "catalog:", "json-stable-stringify": "^1.3.0", + "kysely": "catalog:", "nanoid": "^5.0.9", "toposort": "^2.0.2", "ts-pattern": "catalog:", @@ -76,7 +77,6 @@ }, "peerDependencies": { "better-sqlite3": "catalog:", - "kysely": "catalog:", "pg": "catalog:", "zod": "catalog:" }, diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index 7d97c942..4cd9c584 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -31,10 +31,12 @@ import { type UpdateManyArgs, type UpsertArgs, } from '../../crud-types'; +import { createInternalError, createInvalidInputError } from '../../errors'; import { fieldHasDefaultValue, getDiscriminatorField, getEnum, + getTypeDef, getUniqueFields, requireField, requireModel, @@ -47,7 +49,6 @@ import { addNumberValidation, addStringValidation, } from './utils'; -import { createInternalError, createInvalidInputError } from '../../errors'; const schemaCache = new WeakMap>(); @@ -281,6 +282,8 @@ export class InputValidator { private makePrimitiveSchema(type: string, attributes?: AttributeApplication[]) { if (this.schema.typeDefs && type in this.schema.typeDefs) { return this.makeTypeDefSchema(type); + } else if (this.schema.enums && type in this.schema.enums) { + return this.makeEnumSchema(type); } else { return match(type) .with('String', () => @@ -314,6 +317,22 @@ export class InputValidator { } } + private makeEnumSchema(type: string) { + const key = stableStringify({ + type: 'enum', + name: type, + }); + let schema = this.getSchemaCache(key!); + if (schema) { + return schema; + } + const enumDef = getEnum(this.schema, type); + invariant(enumDef, `Enum "${type}" not found in schema`); + schema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); + this.setSchemaCache(key!, schema); + return schema; + } + private makeTypeDefSchema(type: string): z.ZodType { const key = stableStringify({ type: 'typedef', @@ -324,24 +343,22 @@ export class InputValidator { if (schema) { return schema; } - const typeDef = this.schema.typeDefs?.[type]; + const typeDef = getTypeDef(this.schema, type); invariant(typeDef, `Type definition "${type}" not found in schema`); - schema = z - .object( - Object.fromEntries( - Object.entries(typeDef.fields).map(([field, def]) => { - let fieldSchema = this.makePrimitiveSchema(def.type); - if (def.array) { - fieldSchema = fieldSchema.array(); - } - if (def.optional) { - fieldSchema = fieldSchema.optional(); - } - return [field, fieldSchema]; - }), - ), - ) - .passthrough(); + schema = z.looseObject( + Object.fromEntries( + Object.entries(typeDef.fields).map(([field, def]) => { + let fieldSchema = this.makePrimitiveSchema(def.type); + if (def.array) { + fieldSchema = fieldSchema.array(); + } + if (def.optional) { + fieldSchema = fieldSchema.optional(); + } + return [field, fieldSchema]; + }), + ), + ); this.setSchemaCache(key!, schema); return schema; } @@ -392,7 +409,7 @@ export class InputValidator { const enumDef = getEnum(this.schema, fieldDef.type); if (enumDef) { // enum - if (Object.keys(enumDef).length > 0) { + if (Object.keys(enumDef.values).length > 0) { fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional, withAggregations); } } else if (fieldDef.array) { @@ -427,7 +444,7 @@ export class InputValidator { const enumDef = getEnum(this.schema, def.type); if (enumDef) { // enum - if (Object.keys(enumDef).length > 0) { + if (Object.keys(enumDef.values).length > 0) { fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional, false); } else { fieldSchema = z.never(); @@ -493,7 +510,7 @@ export class InputValidator { } private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean) { - const baseSchema = z.enum(Object.keys(enumDef) as [string, ...string[]]); + const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); const components = this.makeCommonPrimitiveFilterComponents( baseSchema, optional, diff --git a/packages/orm/src/client/executor/name-mapper.ts b/packages/orm/src/client/executor/name-mapper.ts index dcea8152..2ad522f5 100644 --- a/packages/orm/src/client/executor/name-mapper.ts +++ b/packages/orm/src/client/executor/name-mapper.ts @@ -1,12 +1,17 @@ import { invariant } from '@zenstackhq/common-helpers'; import { AliasNode, + CaseWhenBuilder, ColumnNode, + ColumnUpdateNode, DeleteQueryNode, + expressionBuilder, + ExpressionWrapper, FromNode, IdentifierNode, InsertQueryNode, OperationNodeTransformer, + PrimitiveValueListNode, ReferenceNode, ReturningNode, SelectAllNode, @@ -14,10 +19,23 @@ import { SelectQueryNode, TableNode, UpdateQueryNode, + ValueListNode, + ValueNode, + ValuesNode, type OperationNode, + type SimpleReferenceExpressionNode, } from 'kysely'; -import type { FieldDef, ModelDef, SchemaDef } from '../../schema'; -import { extractFieldName, extractModelName, getModel, requireModel, stripAlias } from '../query-utils'; +import type { EnumDef, EnumField, FieldDef, ModelDef, SchemaDef } from '../../schema'; +import { + extractFieldName, + extractModelName, + getEnum, + getField, + getModel, + isEnum, + requireModel, + stripAlias, +} from '../query-utils'; type Scope = { model?: string; @@ -25,6 +43,8 @@ type Scope = { namesMapped?: boolean; // true means fields referring to this scope have their names already mapped }; +type SelectionNodeChild = SimpleReferenceExpressionNode | AliasNode | SelectAllNode; + export class QueryNameMapper extends OperationNodeTransformer { private readonly modelToTableMap = new Map(); private readonly fieldToColumnMap = new Map(); @@ -89,15 +109,27 @@ export class QueryNameMapper extends OperationNodeTransformer { return super.transformInsertQuery(node); } - return this.withScope( - { model: node.into.table.identifier.name }, - () => - ({ - ...super.transformInsertQuery(node), - // map table name - into: this.processTableRef(node.into!), - }) satisfies InsertQueryNode, - ); + const model = extractModelName(node.into); + invariant(model, 'InsertQueryNode must have a model name in the "into" clause'); + + return this.withScope({ model }, () => { + const baseResult = super.transformInsertQuery(node); + let values = baseResult.values; + if (node.columns && values) { + // process enum values with name mapping + values = this.processEnumMappingForColumns(model, node.columns, values); + } + return { + ...baseResult, + // map table name + into: this.processTableRef(node.into!), + values, + } satisfies InsertQueryNode; + }); + } + + private isOperationNode(value: unknown): value is OperationNode { + return !!value && typeof value === 'object' && 'kind' in value; } protected override transformReturning(node: ReturningNode) { @@ -129,10 +161,9 @@ export class QueryNameMapper extends OperationNodeTransformer { mappedTableName = this.mapTableName(scope.model); } } - return ReferenceNode.create( ColumnNode.create(mappedFieldName), - mappedTableName ? TableNode.create(mappedTableName) : undefined, + mappedTableName ? this.createTableNode(mappedTableName, undefined) : undefined, ); } else { // no name mapping needed @@ -159,9 +190,29 @@ export class QueryNameMapper extends OperationNodeTransformer { return super.transformUpdateQuery(node); } - return this.withScope({ model: innerTable.table.identifier.name, alias }, () => { + const model = extractModelName(innerTable); + invariant(model, 'UpdateQueryNode must have a model name in the "table" clause'); + + return this.withScope({ model, alias }, () => { + const baseResult = super.transformUpdateQuery(node); + + // process enum value mappings in update set values + const updates = baseResult.updates?.map((update, i) => { + if (ColumnNode.is(update.column)) { + // fetch original column that doesn't have name mapping applied + const origColumn = node.updates![i]!.column as ColumnNode; + return ColumnUpdateNode.create( + update.column, + this.processEnumMappingForValue(model, origColumn, update.value) as OperationNode, + ); + } else { + return update; + } + }); + return { - ...super.transformUpdateQuery(node), + ...baseResult, + updates, // map table name table: this.wrapAlias(this.processTableRef(innerTable), alias), }; @@ -205,42 +256,70 @@ export class QueryNameMapper extends OperationNodeTransformer { private processSelectQuerySelections(node: SelectQueryNode) { const selections: SelectionNode[] = []; for (const selection of node.selections ?? []) { + const processedSelections: { originalField?: string; selection: SelectionNode }[] = []; if (SelectAllNode.is(selection.selection)) { // expand `selectAll` to all fields with name mapping if the // inner-most scope is not already mapped - const scope = this.scopes[this.scopes.length - 1]; + const scope = this.requireCurrentScope(); if (scope?.model && !scope.namesMapped) { - selections.push(...this.createSelectAllFields(scope.model, scope.alias)); + // expand + processedSelections.push(...this.createSelectAllFields(scope.model, scope.alias)); } else { - selections.push(super.transformSelection(selection)); + // preserve + processedSelections.push({ + originalField: undefined, + selection: super.transformSelection(selection), + }); } } else if (ReferenceNode.is(selection.selection) || ColumnNode.is(selection.selection)) { // map column name and add/preserve alias const transformed = this.transformNode(selection.selection); + + // field name without applying name mapping + const originalField = extractFieldName(selection.selection); + if (AliasNode.is(transformed)) { // keep the alias if there's one - selections.push(SelectionNode.create(transformed)); + processedSelections.push({ originalField, selection: SelectionNode.create(transformed) }); } else { // otherwise use an alias to preserve the original field name - const origFieldName = extractFieldName(selection.selection); const fieldName = extractFieldName(transformed); - if (fieldName !== origFieldName) { - selections.push( - SelectionNode.create( + if (fieldName !== originalField) { + processedSelections.push({ + originalField, + selection: SelectionNode.create( this.wrapAlias( transformed, - origFieldName ? IdentifierNode.create(origFieldName) : undefined, + originalField ? IdentifierNode.create(originalField) : undefined, ), ), - ); + }); } else { - selections.push(SelectionNode.create(transformed)); + processedSelections.push({ + originalField, + selection: SelectionNode.create(transformed), + }); } } } else { - selections.push(super.transformSelection(selection)); + const { node: innerNode } = stripAlias(selection.selection); + processedSelections.push({ + originalField: extractFieldName(innerNode), + selection: super.transformSelection(selection), + }); } + + // process enum value mapping + const enumProcessedSelections = processedSelections.map(({ originalField, selection }) => { + if (!originalField) { + return selection; + } else { + return SelectionNode.create(this.processEnumSelection(selection.selection, originalField)); + } + }); + selections.push(...enumProcessedSelections); } + return selections; } @@ -316,10 +395,12 @@ export class QueryNameMapper extends OperationNodeTransformer { if (!TableNode.is(node)) { return super.transformNode(node); } - return TableNode.create(this.mapTableName(node.table.identifier.name)); + const mappedName = this.mapTableName(node.table.identifier.name); + const tableSchema = this.getTableSchema(node.table.identifier.name); + return this.createTableNode(mappedName, tableSchema); } - private getMappedName(def: ModelDef | FieldDef) { + private getMappedName(def: ModelDef | FieldDef | EnumField) { const mapAttr = def.attributes?.find((attr) => attr.name === '@@map' || attr.name === '@map'); if (mapAttr) { const nameArg = mapAttr.args?.find((arg) => arg.name === 'name'); @@ -362,8 +443,9 @@ export class QueryNameMapper extends OperationNodeTransformer { const modelName = innerNode.table.identifier.name; const mappedName = this.mapTableName(modelName); const finalAlias = alias ?? (mappedName !== modelName ? IdentifierNode.create(modelName) : undefined); + const tableSchema = this.getTableSchema(modelName); return { - node: this.wrapAlias(TableNode.create(mappedName), finalAlias), + node: this.wrapAlias(this.createTableNode(mappedName, tableSchema), finalAlias), scope: { alias: alias ?? IdentifierNode.create(modelName), model: modelName, @@ -384,6 +466,21 @@ export class QueryNameMapper extends OperationNodeTransformer { } } + private getTableSchema(model: string) { + if (this.schema.provider.type !== 'postgresql') { + return undefined; + } + let schema = this.schema.provider.defaultSchema ?? 'public'; + const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema'); + if (schemaAttr) { + const nameArg = schemaAttr.args?.find((arg) => arg.name === 'map'); + if (nameArg && nameArg.value.kind === 'literal') { + schema = nameArg.value.value as string; + } + } + return schema; + } + private createSelectAllFields(model: string, alias: OperationNode | undefined) { const modelDef = requireModel(this.schema, model); return this.getModelFields(modelDef).map((fieldDef) => { @@ -394,9 +491,9 @@ export class QueryNameMapper extends OperationNodeTransformer { ); if (columnName !== fieldDef.name) { const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name)); - return SelectionNode.create(aliased); + return { originalField: fieldDef.name, selection: SelectionNode.create(aliased) }; } else { - return SelectionNode.create(columnRef); + return { originalField: fieldDef.name, selection: SelectionNode.create(columnRef) }; } }); } @@ -425,20 +522,28 @@ export class QueryNameMapper extends OperationNodeTransformer { return result; } - private processSelection(node: AliasNode | ColumnNode | ReferenceNode) { - let alias: string | undefined; - if (!AliasNode.is(node)) { - alias = extractFieldName(node); + private processSelection(node: SelectionNodeChild) { + const { alias, node: innerNode } = stripAlias(node); + const originalField = extractFieldName(innerNode); + let result = super.transformNode(node); + + if (originalField) { + // process enum value mapping + result = this.processEnumSelection(result, originalField); } - const result = super.transformNode(node); - return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined); + + if (!AliasNode.is(result)) { + const addAlias = alias ?? (originalField ? IdentifierNode.create(originalField) : undefined); + if (addAlias) { + result = this.wrapAlias(result, addAlias); + } + } + return result; } private processSelectAll(node: SelectAllNode) { - const scope = this.scopes[this.scopes.length - 1]; - invariant(scope); - - if (!scope.model || !this.hasMappedColumns(scope.model)) { + const scope = this.requireCurrentScope(); + if (!scope.model || !(this.hasMappedColumns(scope.model) || this.modelUsesEnumWithMappedValues(scope.model))) { // no name mapping needed, preserve the select all return super.transformSelectAll(node); } @@ -448,11 +553,164 @@ export class QueryNameMapper extends OperationNodeTransformer { return this.getModelFields(modelDef).map((fieldDef) => { const columnName = this.mapFieldName(modelDef.name, fieldDef.name); const columnRef = ReferenceNode.create(ColumnNode.create(columnName)); - return columnName !== fieldDef.name - ? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name)) - : columnRef; + + // process enum value mapping + const enumProcessed = this.processEnumSelection(columnRef, fieldDef.name); + + return columnName !== fieldDef.name && !AliasNode.is(enumProcessed) + ? this.wrapAlias(enumProcessed, IdentifierNode.create(fieldDef.name)) + : enumProcessed; }); } + private createTableNode(tableName: string, schemaName: string | undefined) { + return schemaName ? TableNode.createWithSchema(schemaName, tableName) : TableNode.create(tableName); + } + + private requireCurrentScope() { + const scope = this.scopes[this.scopes.length - 1]; + invariant(scope, 'No scope available'); + return scope; + } + + // #endregion + + // #region enum value mapping + + private modelUsesEnumWithMappedValues(model: string) { + const modelDef = getModel(this.schema, model); + if (!modelDef) { + return false; + } + return this.getModelFields(modelDef).some((fieldDef) => { + const enumDef = getEnum(this.schema, fieldDef.type); + if (!enumDef) { + return false; + } + return Object.values(enumDef.fields ?? {}).some((f) => f.attributes?.some((attr) => attr.name === '@map')); + }); + } + + private getEnumValueMapping(enumDef: EnumDef) { + const mapping: Record = {}; + for (const [key, field] of Object.entries(enumDef.fields ?? {})) { + const mappedName = this.getMappedName(field); + if (mappedName) { + mapping[key] = mappedName; + } + } + return mapping; + } + + private processEnumMappingForColumns( + model: string, + columns: readonly ColumnNode[], + values: OperationNode, + ): OperationNode { + if (ValuesNode.is(values)) { + return ValuesNode.create( + values.values.map((valueItems) => { + if (PrimitiveValueListNode.is(valueItems)) { + return PrimitiveValueListNode.create( + this.processEnumMappingForValues(model, columns, valueItems.values), + ); + } else { + return ValueListNode.create( + this.processEnumMappingForValues(model, columns, valueItems.values) as OperationNode[], + ); + } + }), + ); + } else if (PrimitiveValueListNode.is(values)) { + return PrimitiveValueListNode.create(this.processEnumMappingForValues(model, columns, values.values)); + } else { + return values; + } + } + + private processEnumMappingForValues(model: string, columns: readonly ColumnNode[], values: readonly unknown[]) { + const result: unknown[] = []; + for (let i = 0; i < columns.length; i++) { + const value = values[i]; + if (value === null || value === undefined) { + result.push(value); + continue; + } + result.push(this.processEnumMappingForValue(model, columns[i]!, value)); + } + return result; + } + + private processEnumMappingForValue(model: string, column: ColumnNode, value: unknown) { + const fieldDef = getField(this.schema, model, column.column.name); + if (!fieldDef) { + return value; + } + if (!isEnum(this.schema, fieldDef.type)) { + return value; + } + + const enumDef = getEnum(this.schema, fieldDef.type); + if (!enumDef) { + return value; + } + + const enumValueMapping = this.getEnumValueMapping(enumDef); + if (this.isOperationNode(value) && ValueNode.is(value) && typeof value.value === 'string') { + const mappedValue = enumValueMapping[value.value]; + if (mappedValue) { + return ValueNode.create(mappedValue); + } + } else if (typeof value === 'string') { + const mappedValue = enumValueMapping[value]; + if (mappedValue) { + return mappedValue; + } + } + + return value; + } + + private processEnumSelection(selection: SelectionNodeChild, fieldName: string) { + const { alias, node } = stripAlias(selection); + const fieldScope = this.resolveFieldFromScopes(fieldName); + if (!fieldScope || !fieldScope.model) { + return selection; + } + const aliasName = alias && IdentifierNode.is(alias) ? alias.name : fieldName; + + const fieldDef = getField(this.schema, fieldScope.model, fieldName); + if (!fieldDef) { + return selection; + } + const enumDef = getEnum(this.schema, fieldDef.type); + if (!enumDef) { + return selection; + } + const enumValueMapping = this.getEnumValueMapping(enumDef); + if (Object.keys(enumValueMapping).length === 0) { + return selection; + } + + const eb = expressionBuilder(); + const caseBuilder = eb.case(); + let caseWhen: CaseWhenBuilder | undefined; + for (const [key, value] of Object.entries(enumValueMapping)) { + if (!caseWhen) { + caseWhen = caseBuilder.when(new ExpressionWrapper(node), '=', value).then(key); + } else { + caseWhen = caseWhen.when(new ExpressionWrapper(node), '=', value).then(key); + } + } + + // the explicit cast to "text" is needed to address postgres's case-when type inference issue + const finalExpr = caseWhen!.else(eb.cast(new ExpressionWrapper(node), 'text')).end(); + if (aliasName) { + return finalExpr.as(aliasName).toOperationNode() as SelectionNodeChild; + } else { + return finalExpr.toOperationNode() as SelectionNodeChild; + } + } + // #endregion } diff --git a/packages/orm/src/client/executor/zenstack-query-executor.ts b/packages/orm/src/client/executor/zenstack-query-executor.ts index e53552c5..06f8d133 100644 --- a/packages/orm/src/client/executor/zenstack-query-executor.ts +++ b/packages/orm/src/client/executor/zenstack-query-executor.ts @@ -55,7 +55,10 @@ export class ZenStackQueryExecutor extends DefaultQuer ) { super(compiler, adapter, connectionProvider, plugins); - if (this.schemaHasMappedNames(client.$schema)) { + if ( + client.$schema.provider.type === 'postgresql' || // postgres queries need to be schema-qualified + this.schemaHasMappedNames(client.$schema) + ) { this.nameMapper = new QueryNameMapper(client.$schema); } } diff --git a/packages/orm/src/client/helpers/schema-db-pusher.ts b/packages/orm/src/client/helpers/schema-db-pusher.ts index 52fc462b..f04c14e1 100644 --- a/packages/orm/src/client/helpers/schema-db-pusher.ts +++ b/packages/orm/src/client/helpers/schema-db-pusher.ts @@ -23,7 +23,26 @@ export class SchemaDbPusher { await this.kysely.transaction().execute(async (tx) => { if (this.schema.enums && this.schema.provider.type === 'postgresql') { for (const [name, enumDef] of Object.entries(this.schema.enums)) { - const createEnum = tx.schema.createType(name).asEnum(Object.values(enumDef)); + let enumValues: string[]; + if (enumDef.fields) { + enumValues = Object.values(enumDef.fields).map((f) => { + const mapAttr = f.attributes?.find((a) => a.name === '@map'); + if (!mapAttr || !mapAttr.args?.[0]) { + return f.name; + } else { + const mappedName = ExpressionUtils.getLiteralValue(mapAttr.args[0].value); + invariant( + mappedName && typeof mappedName === 'string', + `Invalid @map attribute for enum field ${f.name}`, + ); + return mappedName; + } + }); + } else { + enumValues = Object.values(enumDef.values); + } + + const createEnum = tx.schema.createType(name).asEnum(enumValues); await createEnum.execute(); } } diff --git a/packages/plugins/policy/package.json b/packages/plugins/policy/package.json index c374d135..2bebee97 100644 --- a/packages/plugins/policy/package.json +++ b/packages/plugins/policy/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/plugin-policy", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack Policy Plugin", "type": "module", "scripts": { diff --git a/packages/schema/package.json b/packages/schema/package.json index a8ea28ed..b80e9429 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/schema", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack Runtime Schema", "type": "module", "scripts": { diff --git a/packages/schema/src/schema.ts b/packages/schema/src/schema.ts index 5dc9efc4..13fc90b9 100644 --- a/packages/schema/src/schema.ts +++ b/packages/schema/src/schema.ts @@ -5,6 +5,7 @@ export type DataSourceProviderType = 'sqlite' | 'postgresql'; export type DataSourceProvider = { type: DataSourceProviderType; + defaultSchema?: string; }; export type SchemaDef = { @@ -97,7 +98,16 @@ export type BuiltinType = export type MappedBuiltinType = string | boolean | number | bigint | Decimal | Date; -export type EnumDef = Record; +export type EnumField = { + name: string; + attributes?: AttributeApplication[]; +}; + +export type EnumDef = { + fields?: Record; + values: Record; + attributes?: AttributeApplication[]; +}; export type TypeDefDef = { name: string; @@ -124,7 +134,9 @@ export type GetModel> export type GetEnums = keyof Schema['enums']; -export type GetEnum> = Schema['enums'][Enum]; +export type GetEnum> = Schema['enums'][Enum] extends EnumDef + ? Schema['enums'][Enum]['values'] + : never; export type GetTypeDefs = Extract; diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 83ca0c46..2348bdc1 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack SDK", "type": "module", "scripts": { diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index 3f3ba823..45ffed3c 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -19,6 +19,7 @@ import { InvocationExpr, isArrayExpr, isDataModel, + isDataSource, isInvocationExpr, isLiteralExpr, isNullExpr, @@ -29,9 +30,14 @@ import { Model, NumberLiteral, StringLiteral, - type AstNode, } from '@zenstackhq/language/ast'; -import { getAllAttributes, getAllFields, isAuthInvocation, isDelegateModel } from '@zenstackhq/language/utils'; +import { + getAllAttributes, + getAllFields, + getStringLiteral, + isAuthInvocation, + isDelegateModel, +} from '@zenstackhq/language/utils'; import { AstUtils } from 'langium'; import { match } from 'ts-pattern'; import { ModelUtils } from '..'; @@ -58,6 +64,9 @@ import { // Here we use a conservative value that should work for most cases, and truncate names if needed const IDENTIFIER_NAME_MAX_LENGTH = 50 - DELEGATE_AUX_RELATION_PREFIX.length; +// Datasource fields that only exist in ZModel but not in Prisma schema +const NON_PRISMA_DATASOURCE_FIELDS = ['defaultSchema']; + /** * Generates Prisma schema file */ @@ -101,10 +110,12 @@ export class PrismaSchemaGenerator { } private generateDataSource(prisma: PrismaModel, dataSource: DataSource) { - const fields: SimpleField[] = dataSource.fields.map((f) => ({ - name: f.name, - text: this.configExprToText(f.value), - })); + const fields: SimpleField[] = dataSource.fields + .filter((f) => !NON_PRISMA_DATASOURCE_FIELDS.includes(f.name)) + .map((f) => ({ + name: f.name, + text: this.configExprToText(f.value), + })); prisma.addDataSource(dataSource.name, fields); } @@ -171,13 +182,27 @@ export class PrismaSchemaGenerator { } } - const allAttributes = getAllAttributes(decl); - for (const attr of allAttributes.filter( + const allAttributes = getAllAttributes(decl).filter( (attr) => this.isPrismaAttribute(attr) && !this.isInheritedMapAttribute(attr, decl), - )) { + ); + + for (const attr of allAttributes) { this.generateContainerAttribute(model, attr); } + if ( + this.datasourceHasSchemasSetting(decl.$container) && + !allAttributes.some((attr) => attr.decl.ref?.name === '@@schema') + ) { + // if the datasource declared `schemas` and no @@schema attribute is defined, add a default one + model.addAttribute('@@schema', [ + new PrismaAttributeArg( + undefined, + new PrismaAttributeArgValue('String', this.getDefaultPostgresSchemaName(decl.$container)), + ), + ]); + } + // user defined comments pass-through decl.comments.forEach((c) => model.addComment(c)); @@ -188,6 +213,20 @@ export class PrismaSchemaGenerator { this.generateDelegateRelationForConcrete(model, decl); } + private getDatasourceField(zmodel: Model, fieldName: string) { + const dataSource = zmodel.declarations.find(isDataSource); + return dataSource?.fields.find((f) => f.name === fieldName); + } + + private datasourceHasSchemasSetting(zmodel: Model) { + return !!this.getDatasourceField(zmodel, 'schemas'); + } + + private getDefaultPostgresSchemaName(zmodel: Model) { + const defaultSchemaField = this.getDatasourceField(zmodel, 'defaultSchema'); + return getStringLiteral(defaultSchemaField?.value) ?? 'public'; + } + private isInheritedMapAttribute(attr: DataModelAttribute, contextModel: DataModel) { if (attr.$container === contextModel) { return false; @@ -206,7 +245,7 @@ export class PrismaSchemaGenerator { private getUnsupportedFieldType(fieldType: DataFieldType) { if (fieldType.unsupported) { - const value = this.getStringLiteral(fieldType.unsupported.value); + const value = getStringLiteral(fieldType.unsupported.value); if (value) { return `Unsupported("${value}")`; } else { @@ -217,10 +256,6 @@ export class PrismaSchemaGenerator { } } - private getStringLiteral(node: AstNode | undefined): string | undefined { - return isStringLiteral(node) ? node.value : undefined; - } - private generateModelField(model: PrismaDataModel, field: DataField, contextModel: DataModel, addToFront = false) { let fieldType: string | undefined; diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 821e0bd6..798b6dfe 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -35,7 +35,7 @@ import { UnaryExpr, type Model, } from '@zenstackhq/language/ast'; -import { getAllAttributes, getAllFields, isDataFieldReference } from '@zenstackhq/language/utils'; +import { getAllAttributes, getAllFields, getAttributeArg, isDataFieldReference } from '@zenstackhq/language/utils'; import fs from 'node:fs'; import path from 'node:path'; import { match } from 'ts-pattern'; @@ -236,8 +236,20 @@ export class TsSchemaGenerator { private createProviderObject(model: Model): ts.Expression { const dsProvider = this.getDataSourceProvider(model); + const defaultSchema = this.getDataSourceDefaultSchema(model); + return ts.factory.createObjectLiteralExpression( - [ts.factory.createPropertyAssignment('type', ts.factory.createStringLiteral(dsProvider.type))], + [ + ts.factory.createPropertyAssignment('type', ts.factory.createStringLiteral(dsProvider)), + ...(defaultSchema + ? [ + ts.factory.createPropertyAssignment( + 'defaultSchema', + ts.factory.createStringLiteral(defaultSchema), + ), + ] + : []), + ], true, ); } @@ -621,9 +633,26 @@ export class TsSchemaGenerator { invariant(dataSource, 'No data source found in the model'); const providerExpr = dataSource.fields.find((f) => f.name === 'provider')?.value; - invariant(isLiteralExpr(providerExpr), 'Provider must be a literal'); - const type = providerExpr.value as string; - return { type }; + invariant( + isLiteralExpr(providerExpr) && typeof providerExpr.value === 'string', + 'Provider must be a string literal', + ); + return providerExpr.value as string; + } + + private getDataSourceDefaultSchema(model: Model) { + const dataSource = model.declarations.find(isDataSource); + invariant(dataSource, 'No data source found in the model'); + + const defaultSchemaExpr = dataSource.fields.find((f) => f.name === 'defaultSchema')?.value; + if (!defaultSchemaExpr) { + return undefined; + } + invariant( + isLiteralExpr(defaultSchemaExpr) && typeof defaultSchemaExpr.value === 'string', + 'Default schema must be a string literal', + ); + return defaultSchemaExpr.value as string; } private getFieldMappedDefault( @@ -840,7 +869,11 @@ export class TsSchemaGenerator { const seenKeys = new Set(); for (const attr of allAttributes) { if (attr.decl.$refText === '@@id' || attr.decl.$refText === '@@unique') { - const fieldNames = this.getReferenceNames(attr.args[0]!.value); + const fieldsArg = getAttributeArg(attr, 'fields'); + if (!fieldsArg) { + continue; + } + const fieldNames = this.getReferenceNames(fieldsArg); if (!fieldNames) { continue; } @@ -914,9 +947,68 @@ export class TsSchemaGenerator { private createEnumObject(e: Enum) { return ts.factory.createObjectLiteralExpression( - e.fields.map((field) => - ts.factory.createPropertyAssignment(field.name, ts.factory.createStringLiteral(field.name)), - ), + [ + ts.factory.createPropertyAssignment( + 'values', + ts.factory.createObjectLiteralExpression( + e.fields.map((f) => + ts.factory.createPropertyAssignment(f.name, ts.factory.createStringLiteral(f.name)), + ), + true, + ), + ), + + // only generate `fields` if there are attributes on the fields + ...(e.fields.some((f) => f.attributes.length > 0) + ? [ + ts.factory.createPropertyAssignment( + 'fields', + ts.factory.createObjectLiteralExpression( + e.fields.map((field) => + ts.factory.createPropertyAssignment( + field.name, + ts.factory.createObjectLiteralExpression( + [ + ts.factory.createPropertyAssignment( + 'name', + ts.factory.createStringLiteral(field.name), + ), + ...(field.attributes.length > 0 + ? [ + ts.factory.createPropertyAssignment( + 'attributes', + ts.factory.createArrayLiteralExpression( + field.attributes?.map((attr) => + this.createAttributeObject(attr), + ) ?? [], + true, + ), + ), + ] + : []), + ], + true, + ), + ), + ), + true, + ), + ), + ] + : []), + + ...(e.attributes.length > 0 + ? [ + ts.factory.createPropertyAssignment( + 'attributes', + ts.factory.createArrayLiteralExpression( + e.attributes.map((attr) => this.createAttributeObject(attr)), + true, + ), + ), + ] + : []), + ], true, ); } @@ -1226,7 +1318,7 @@ export class TsSchemaGenerator { statements.push(typeDef); } - // generate: export const Enum = $schema.enums.Enum; + // generate: export const Enum = $schema.enums.Enum['values']; const enums = model.declarations.filter(isEnum); for (const e of enums) { let enumDecl = ts.factory.createVariableStatement( @@ -1239,10 +1331,13 @@ export class TsSchemaGenerator { undefined, ts.factory.createPropertyAccessExpression( ts.factory.createPropertyAccessExpression( - ts.factory.createIdentifier('$schema'), - ts.factory.createIdentifier('enums'), + ts.factory.createPropertyAccessExpression( + ts.factory.createIdentifier('$schema'), + ts.factory.createIdentifier('enums'), + ), + ts.factory.createIdentifier(e.name), ), - ts.factory.createIdentifier(e.name), + ts.factory.createIdentifier('values'), ), ), ], diff --git a/packages/server/package.json b/packages/server/package.json index af6eebb4..f08bb23b 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack automatic CRUD API handlers and server adapters", "type": "module", "scripts": { diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 653427b9..7ae44ce3 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "ZenStack Test Tools", "type": "module", "scripts": { diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index f6ea4b8d..569a67ab 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -1,8 +1,8 @@ import { invariant } from '@zenstackhq/common-helpers'; import type { Model } from '@zenstackhq/language/ast'; -import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { ZenStackClient, type ClientContract, type ClientOptions } from '@zenstackhq/orm'; import type { SchemaDef } from '@zenstackhq/orm/schema'; +import { PolicyPlugin } from '@zenstackhq/plugin-policy'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import SQLite from 'better-sqlite3'; import { PostgresDialect, SqliteDialect, type LogEvent } from 'kysely'; @@ -33,17 +33,18 @@ const TEST_PG_CONFIG = { export type CreateTestClientOptions = Omit, 'dialect'> & { provider?: 'sqlite' | 'postgresql'; + schemaFile?: string; dbName?: string; usePrismaPush?: boolean; extraSourceFiles?: Record; workDir?: string; debug?: boolean; + dbFile?: string; }; export async function createTestClient( schema: Schema, options?: CreateTestClientOptions, - schemaFile?: string, ): Promise>; export async function createTestClient( schema: string, @@ -52,14 +53,11 @@ export async function createTestClient( export async function createTestClient( schema: Schema | string, options?: CreateTestClientOptions, - schemaFile?: string, ): Promise { let workDir = options?.workDir; let _schema: Schema; const provider = options?.provider ?? getTestDbProvider() ?? 'sqlite'; - const dbName = options?.dbName ?? getTestDbName(provider); - const dbUrl = provider === 'sqlite' ? `file:${dbName}` @@ -68,13 +66,14 @@ export async function createTestClient( let model: Model | undefined; if (typeof schema === 'string') { - const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles); + const generated = await generateTsSchema(schema, provider, dbUrl, options?.extraSourceFiles, undefined); workDir = generated.workDir; model = generated.model; // replace schema's provider _schema = { ...generated.schema, provider: { + ...generated.schema.provider, type: provider, }, } as Schema; @@ -87,8 +86,8 @@ export async function createTestClient( }, }; workDir ??= createTestProject(); - if (schemaFile) { - let schemaContent = fs.readFileSync(schemaFile, 'utf-8'); + if (options?.schemaFile) { + let schemaContent = fs.readFileSync(options.schemaFile, 'utf-8'); if (dbUrl) { // replace `datasource db { }` section schemaContent = schemaContent.replace( @@ -108,35 +107,48 @@ export async function createTestClient( console.log(`Work directory: ${workDir}`); } + // copy db file to workDir if specified + if (options?.dbFile) { + if (provider !== 'sqlite') { + throw new Error('dbFile option is only supported for sqlite provider'); + } + fs.copyFileSync(options.dbFile, path.join(workDir, dbName)); + } + const { plugins, ...rest } = options ?? {}; const _options: ClientOptions = { ...rest, } as ClientOptions; - if (options?.usePrismaPush) { - invariant(typeof schema === 'string' || schemaFile, 'a schema file must be provided when using prisma db push'); - if (!model) { - const r = await loadDocumentWithPlugins(path.join(workDir, 'schema.zmodel')); - if (!r.success) { - throw new Error(r.errors.join('\n')); + if (!options?.dbFile) { + if (options?.usePrismaPush) { + invariant( + typeof schema === 'string' || options?.schemaFile, + 'a schema file must be provided when using prisma db push', + ); + if (!model) { + const r = await loadDocumentWithPlugins(path.join(workDir, 'schema.zmodel')); + if (!r.success) { + throw new Error(r.errors.join('\n')); + } + model = r.model; + } + const prismaSchema = new PrismaSchemaGenerator(model); + const prismaSchemaText = await prismaSchema.generate(); + fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText); + execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { + cwd: workDir, + stdio: 'ignore', + }); + } else { + if (provider === 'postgresql') { + invariant(dbName, 'dbName is required'); + const pgClient = new PGClient(TEST_PG_CONFIG); + await pgClient.connect(); + await pgClient.query(`DROP DATABASE IF EXISTS "${dbName}"`); + await pgClient.query(`CREATE DATABASE "${dbName}"`); + await pgClient.end(); } - model = r.model; - } - const prismaSchema = new PrismaSchemaGenerator(model); - const prismaSchemaText = await prismaSchema.generate(); - fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText); - execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { - cwd: workDir, - stdio: 'ignore', - }); - } else { - if (provider === 'postgresql') { - invariant(dbName, 'dbName is required'); - const pgClient = new PGClient(TEST_PG_CONFIG); - await pgClient.connect(); - await pgClient.query(`DROP DATABASE IF EXISTS "${dbName}"`); - await pgClient.query(`CREATE DATABASE "${dbName}"`); - await pgClient.end(); } } @@ -155,7 +167,7 @@ export async function createTestClient( let client = new ZenStackClient(_schema, _options); - if (!options?.usePrismaPush) { + if (!options?.usePrismaPush && !options?.dbFile) { await client.$pushSchema(); } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index c805cb95..1ecb015c 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -32,6 +32,11 @@ datasource db { .exhaustive(); } +function replacePlaceholders(schemaText: string, provider: 'sqlite' | 'postgresql', dbUrl: string | undefined) { + const url = dbUrl ?? (provider === 'sqlite' ? 'file:./test.db' : 'postgres://postgres:postgres@localhost:5432/db'); + return schemaText.replace(/\$DB_URL/g, url).replace(/\$PROVIDER/g, provider); +} + export async function generateTsSchema( schemaText: string, provider: 'sqlite' | 'postgresql' = 'sqlite', @@ -43,7 +48,10 @@ export async function generateTsSchema( const zmodelPath = path.join(workDir, 'schema.zmodel'); const noPrelude = schemaText.includes('datasource '); - fs.writeFileSync(zmodelPath, `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${schemaText}`); + fs.writeFileSync( + zmodelPath, + `${noPrelude ? '' : makePrelude(provider, dbUrl)}\n\n${replacePlaceholders(schemaText, provider, dbUrl)}`, + ); const result = await loadDocumentWithPlugins(zmodelPath); if (!result.success) { diff --git a/packages/zod/package.json b/packages/zod/package.json index fac1d39e..36dace1a 100644 --- a/packages/zod/package.json +++ b/packages/zod/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/zod", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "", "type": "module", "main": "index.js", diff --git a/samples/next.js/package.json b/samples/next.js/package.json index 65b6ebb5..6c65bc1f 100644 --- a/samples/next.js/package.json +++ b/samples/next.js/package.json @@ -1,6 +1,6 @@ { "name": "next.js", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "private": true, "scripts": { "generate": "zen generate --lite", diff --git a/samples/orm/package.json b/samples/orm/package.json index fe66c962..9e7c56a3 100644 --- a/samples/orm/package.json +++ b/samples/orm/package.json @@ -1,6 +1,6 @@ { "name": "sample-blog", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "description": "", "main": "index.js", "private": true, diff --git a/samples/orm/zenstack/models.ts b/samples/orm/zenstack/models.ts index 2eb57fad..e2db380e 100644 --- a/samples/orm/zenstack/models.ts +++ b/samples/orm/zenstack/models.ts @@ -23,7 +23,7 @@ export type CommonFields = $TypeDefResult<$Schema, "CommonFields">; /** * User roles */ -export const Role = $schema.enums.Role; +export const Role = $schema.enums.Role.values; /** * User roles */ diff --git a/samples/orm/zenstack/schema.ts b/samples/orm/zenstack/schema.ts index 637ec687..4c9134b9 100644 --- a/samples/orm/zenstack/schema.ts +++ b/samples/orm/zenstack/schema.ts @@ -232,8 +232,10 @@ export const schema = { }, enums: { Role: { - ADMIN: "ADMIN", - USER: "USER" + values: { + ADMIN: "ADMIN", + USER: "USER" + } } }, authType: "User", diff --git a/tests/e2e/orm/client-api/delegate.test.ts b/tests/e2e/orm/client-api/delegate.test.ts index 1497f91b..9076c60d 100644 --- a/tests/e2e/orm/client-api/delegate.test.ts +++ b/tests/e2e/orm/client-api/delegate.test.ts @@ -8,13 +8,10 @@ describe('Delegate model tests ', () => { let client: ClientContract; beforeEach(async () => { - client = await createTestClient( - schema, - { - usePrismaPush: true, - }, - path.join(__dirname, '../schemas/delegate/schema.zmodel'), - ); + client = await createTestClient(schema, { + usePrismaPush: true, + schemaFile: path.join(__dirname, '../schemas/delegate/schema.zmodel'), + }); }); afterEach(async () => { diff --git a/tests/e2e/orm/client-api/name-mapping.test.ts b/tests/e2e/orm/client-api/name-mapping.test.ts index 5d9151e7..6f279114 100644 --- a/tests/e2e/orm/client-api/name-mapping.test.ts +++ b/tests/e2e/orm/client-api/name-mapping.test.ts @@ -1,18 +1,17 @@ +import type { ClientContract } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; import path from 'node:path'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; -import type { ClientContract } from '@zenstackhq/orm'; import { schema, type SchemaType } from '../schemas/name-mapping/schema'; -import { createTestClient } from '@zenstackhq/testtools'; describe('Name mapping tests', () => { let db: ClientContract; beforeEach(async () => { - db = await createTestClient( - schema, - { usePrismaPush: true }, - path.join(__dirname, '../schemas/name-mapping/schema.zmodel'), - ); + db = await createTestClient(schema, { + usePrismaPush: true, + schemaFile: path.join(__dirname, '../schemas/name-mapping/schema.zmodel'), + }); }); afterEach(async () => { @@ -34,6 +33,37 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ id: expect.any(Number), email: 'u1@test.com', + role: 'USER', // mapped enum value + }); + + let rawRead = await db.$qbRaw + .selectFrom('users') + .where('user_email', '=', 'u1@test.com') + .selectAll() + .executeTakeFirst(); + await expect(rawRead).toMatchObject({ + user_email: 'u1@test.com', + user_role: 'role_user', + }); + + await expect( + db.user.create({ + data: { + email: 'u1_1@test.com', + role: 'MODERATOR', // unmapped enum value + }, + }), + ).resolves.toMatchObject({ + role: 'MODERATOR', + }); + + rawRead = await db.$qbRaw + .selectFrom('users') + .where('user_email', '=', 'u1_1@test.com') + .selectAll() + .executeTakeFirst(); + await expect(rawRead).toMatchObject({ + user_role: 'MODERATOR', }); await expect( @@ -41,12 +71,22 @@ describe('Name mapping tests', () => { .insertInto('User') .values({ email: 'u2@test.com', + role: 'ADMIN', }) - .returning(['id', 'email']) + .returning(['id', 'email', 'role']) .executeTakeFirst(), ).resolves.toMatchObject({ id: expect.any(Number), email: 'u2@test.com', + role: 'ADMIN', + }); + rawRead = await db.$qbRaw + .selectFrom('users') + .where('user_email', '=', 'u2@test.com') + .selectAll() + .executeTakeFirst(); + await expect(rawRead).toMatchObject({ + user_role: 'role_admin', }); await expect( @@ -73,6 +113,7 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ id: expect.any(Number), email: 'u4@test.com', + role: 'USER', }); }); @@ -94,26 +135,45 @@ describe('Name mapping tests', () => { select: { id: true, email: true, + role: true, posts: { where: { title: { contains: 'Post1' } }, select: { title: true } }, }, }), ).resolves.toMatchObject({ id: expect.any(Number), email: 'u1@test.com', + role: 'USER', posts: [{ title: 'Post1' }], }); + // select all + await expect( + db.user.findFirst({ + where: { email: 'u1@test.com' }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u1@test.com', + role: 'USER', + }); + await expect( db.$qb.selectFrom('User').selectAll().where('email', '=', 'u1@test.com').executeTakeFirst(), ).resolves.toMatchObject({ id: expect.any(Number), email: 'u1@test.com', + role: 'USER', }); await expect( - db.$qb.selectFrom('User').select(['User.email']).where('email', '=', 'u1@test.com').executeTakeFirst(), + db.$qb + .selectFrom('User') + .select(['User.email', 'User.role']) + .where('email', '=', 'u1@test.com') + .executeTakeFirst(), ).resolves.toMatchObject({ email: 'u1@test.com', + role: 'USER', }); await expect( @@ -172,6 +232,7 @@ describe('Name mapping tests', () => { where: { id: user.id }, data: { email: 'u2@test.com', + role: 'ADMIN', posts: { update: { where: { id: 1 }, @@ -184,21 +245,22 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ id: user.id, email: 'u2@test.com', + role: 'ADMIN', posts: [expect.objectContaining({ title: 'Post2' })], }); await expect( db.$qb .updateTable('User') - .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]) }) + .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]), role: 'USER' }) .where('email', '=', 'u2@test.com') - .returning(['email']) + .returning(['email', 'role']) .executeTakeFirst(), - ).resolves.toMatchObject({ email: 'U2@TEST.COM' }); + ).resolves.toMatchObject({ email: 'U2@TEST.COM', role: 'USER' }); await expect( db.$qb.updateTable('User as u').set({ email: 'u3@test.com' }).returningAll().executeTakeFirst(), - ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com' }); + ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com', role: 'USER' }); }); it('works with delete', async () => { @@ -229,6 +291,7 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ email: 'u1@test.com', posts: [], + role: 'USER', }); }); @@ -236,6 +299,7 @@ describe('Name mapping tests', () => { await db.user.create({ data: { email: 'u1@test.com', + role: 'USER', posts: { create: [{ title: 'Post1' }, { title: 'Post2' }], }, @@ -245,6 +309,7 @@ describe('Name mapping tests', () => { await db.user.create({ data: { email: 'u2@test.com', + role: 'MODERATOR', posts: { create: [{ title: 'Post3' }], }, @@ -254,8 +319,9 @@ describe('Name mapping tests', () => { // Test ORM count operations await expect(db.user.count()).resolves.toBe(2); await expect(db.post.count()).resolves.toBe(3); - await expect(db.user.count({ select: { email: true } })).resolves.toMatchObject({ + await expect(db.user.count({ select: { email: true, role: true } })).resolves.toMatchObject({ email: 2, + role: 2, }); await expect(db.user.count({ where: { email: 'u1@test.com' } })).resolves.toBe(1); @@ -266,9 +332,11 @@ describe('Name mapping tests', () => { // Test Kysely count operations const r = await db.$qb .selectFrom('User') - .select((eb) => eb.fn.count('email').as('count')) + .select((eb) => eb.fn.count('email').as('email_count')) + .select((eb) => eb.fn.count('role').as('role_count')) .executeTakeFirst(); - await expect(Number(r?.count)).toBe(2); + await expect(Number(r?.email_count)).toBe(2); + await expect(Number(r?.role_count)).toBe(2); }); it('works with aggregate', async () => { @@ -276,6 +344,7 @@ describe('Name mapping tests', () => { data: { id: 1, email: 'u1@test.com', + role: 'USER', posts: { create: [ { id: 1, title: 'Post1' }, @@ -289,6 +358,7 @@ describe('Name mapping tests', () => { data: { id: 2, email: 'u2@test.com', + role: 'MODERATOR', posts: { create: [{ id: 3, title: 'Post3' }], }, @@ -296,8 +366,12 @@ describe('Name mapping tests', () => { }); // Test ORM aggregate operations - await expect(db.user.aggregate({ _count: { id: true, email: true } })).resolves.toMatchObject({ + await expect( + db.user.aggregate({ _count: { id: true, email: true }, _max: { role: true }, _min: { role: true } }), + ).resolves.toMatchObject({ _count: { id: 2, email: 2 }, + _max: { role: 'USER' }, + _min: { role: 'MODERATOR' }, }); await expect( @@ -342,6 +416,7 @@ describe('Name mapping tests', () => { data: { id: 1, email: 'u1@test.com', + role: 'USER', posts: { create: [ { id: 1, title: 'Post1' }, @@ -356,6 +431,7 @@ describe('Name mapping tests', () => { data: { id: 2, email: 'u2@test.com', + role: 'MODERATOR', posts: { create: [ { id: 4, title: 'Post4' }, @@ -389,6 +465,18 @@ describe('Name mapping tests', () => { ]), ); + const userGroupBy1 = await db.user.groupBy({ + by: ['role'], + _count: { id: true }, + }); + expect(userGroupBy1).toHaveLength(2); + expect(userGroupBy1).toEqual( + expect.arrayContaining([ + { role: 'USER', _count: { id: 2 } }, + { role: 'MODERATOR', _count: { id: 1 } }, + ]), + ); + const postGroupBy = await db.post.groupBy({ by: ['authorId'], _count: { id: true }, diff --git a/tests/e2e/orm/client-api/pg-custom-schema.test.ts b/tests/e2e/orm/client-api/pg-custom-schema.test.ts new file mode 100644 index 00000000..4308e864 --- /dev/null +++ b/tests/e2e/orm/client-api/pg-custom-schema.test.ts @@ -0,0 +1,261 @@ +import { ORMError } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Postgres custom schema support', () => { + it('defaults to public schema for ORM queries', async () => { + const foundSchema = { create: false, read: false, update: false, delete: false }; + const db = await createTestClient( + ` +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"public"."foo"')) { + sql.includes('insert') && (foundSchema.create = true); + sql.includes('select') && (foundSchema.read = true); + sql.includes('update') && (foundSchema.update = true); + sql.includes('delete') && (foundSchema.delete = true); + } + }, + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + await expect(db.foo.findFirst()).toResolveTruthy(); + await expect(db.foo.update({ where: { id: 1 }, data: { name: 'updated' } })).toResolveTruthy(); + await expect(db.foo.delete({ where: { id: 1 } })).toResolveTruthy(); + + expect(foundSchema).toEqual({ create: true, read: true, update: true, delete: true }); + }); + + it('defaults to public schema for QB queries', async () => { + const foundSchema = { create: false, read: false, update: false, delete: false }; + const db = await createTestClient( + ` +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"public"."foo"')) { + sql.includes('insert') && (foundSchema.create = true); + sql.includes('select') && (foundSchema.read = true); + sql.includes('update') && (foundSchema.update = true); + sql.includes('delete') && (foundSchema.delete = true); + } + }, + }, + ); + + await expect(db.$qb.insertInto('Foo').values({ id: 1, name: 'test' }).execute()).toResolveTruthy(); + await expect(db.$qb.selectFrom('Foo').selectAll().executeTakeFirst()).toResolveTruthy(); + await expect( + db.$qb.updateTable('Foo').set({ name: 'updated' }).where('id', '=', 1).execute(), + ).toResolveTruthy(); + await expect(db.$qb.deleteFrom('Foo').where('id', '=', 1).execute()).toResolveTruthy(); + + expect(foundSchema).toEqual({ create: true, read: true, update: true, delete: true }); + }); + + it('supports changing default schema', async () => { + const db = await createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'mySchema' +} + +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).rejects.toSatisfy( + (e) => e instanceof ORMError && !!e.dbErrorMessage?.includes('relation "mySchema.Foo" does not exist'), + ); + + await db.$disconnect(); + + const db1 = await createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'public' +} + +model Foo { + id Int @id + name String +} +`, + { + provider: 'postgresql', + }, + ); + + await expect(db1.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + }); + + it('supports custom schemas', async () => { + let fooQueriesVerified = false; + let barQueriesVerified = false; + + const db = await createTestClient( + ` +datasource db { + provider = '$PROVIDER' + schemas = ['public', 'mySchema'] + url = '$DB_URL' +} + +model Foo { + id Int @id + name String + @@schema('mySchema') +} + +model Bar { + id Int @id + name String + @@schema('public') +} +`, + { + provider: 'postgresql', + usePrismaPush: true, + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"myschema"."foo"')) { + fooQueriesVerified = true; + } + if (sql.includes('"public"."bar"')) { + barQueriesVerified = true; + } + }, + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + await expect(db.bar.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + + expect(fooQueriesVerified).toBe(true); + expect(barQueriesVerified).toBe(true); + }); + + it('rejects using schema for non-postgresql providers', async () => { + await expect( + createTestClient( + ` +datasource db { + provider = 'sqlite' + defaultSchema = 'mySchema' +} + +model Foo { + id Int @id + name String +} +`, + ), + ).rejects.toThrow('only supported for "postgresql" provider'); + }); + + it('rejects using schema not defined in datasource', async () => { + await expect( + createTestClient( + ` +datasource db { + provider = 'postgresql' + schemas = ['public'] +} + +model Foo { + id Int @id + name String + @@schema('mySchema') +} +`, + ), + ).rejects.toThrow('Schema "mySchema" is not defined in the datasource'); + }); + + it('requires defaultSchema to be included in schemas', async () => { + await expect( + createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'mySchema' + schemas = ['public'] +} + +model Foo { + id Int @id + name String +} +`, + ), + ).rejects.toThrow('"mySchema" must be included in the "schemas" array'); + }); + + it('allows specifying schema only on a few models', async () => { + let fooQueriesVerified = false; + let barQueriesVerified = false; + + const db = await createTestClient( + ` +datasource db { + provider = 'postgresql' + defaultSchema = 'somedefault' + schemas = ['mySchema', 'somedefault'] + url = '$DB_URL' +} + +model Foo { + id Int @id + name String + @@schema('mySchema') +} + +model Bar { + id Int @id + name String +} +`, + { + provider: 'postgresql', + usePrismaPush: true, + log: (event) => { + const sql = event.query.sql.toLowerCase(); + if (sql.includes('"myschema"."foo"')) { + fooQueriesVerified = true; + } + if (sql.includes('"somedefault"."bar"')) { + barQueriesVerified = true; + } + }, + }, + ); + + await expect(db.foo.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + await expect(db.bar.create({ data: { id: 1, name: 'test' } })).toResolveTruthy(); + + expect(fooQueriesVerified).toBe(true); + expect(barQueriesVerified).toBe(true); + }); +}); diff --git a/tests/e2e/orm/query-builder/query-builder.test.ts b/tests/e2e/orm/query-builder/query-builder.test.ts index 23f31e12..563118a4 100644 --- a/tests/e2e/orm/query-builder/query-builder.test.ts +++ b/tests/e2e/orm/query-builder/query-builder.test.ts @@ -1,7 +1,7 @@ import { createId } from '@paralleldrive/cuid2'; +import { createTestClient } from '@zenstackhq/testtools'; import { describe, expect, it } from 'vitest'; import { getSchema } from '../schemas/basic'; -import { createTestClient } from '@zenstackhq/testtools'; describe('Client API tests', () => { const schema = getSchema('sqlite'); diff --git a/tests/e2e/orm/schemas/basic/models.ts b/tests/e2e/orm/schemas/basic/models.ts index d532d7d4..be197879 100644 --- a/tests/e2e/orm/schemas/basic/models.ts +++ b/tests/e2e/orm/schemas/basic/models.ts @@ -12,5 +12,5 @@ export type Post = $ModelResult<$Schema, "Post">; export type Comment = $ModelResult<$Schema, "Comment">; export type Profile = $ModelResult<$Schema, "Profile">; export type CommonFields = $TypeDefResult<$Schema, "CommonFields">; -export const Role = $schema.enums.Role; +export const Role = $schema.enums.Role.values; export type Role = (typeof Role)[keyof typeof Role]; diff --git a/tests/e2e/orm/schemas/basic/schema.ts b/tests/e2e/orm/schemas/basic/schema.ts index 14e627b9..6339ab0a 100644 --- a/tests/e2e/orm/schemas/basic/schema.ts +++ b/tests/e2e/orm/schemas/basic/schema.ts @@ -275,8 +275,10 @@ export const schema = { }, enums: { Role: { - ADMIN: "ADMIN", - USER: "USER" + values: { + ADMIN: "ADMIN", + USER: "USER" + } } }, authType: "User", diff --git a/tests/e2e/orm/schemas/name-mapping/models.ts b/tests/e2e/orm/schemas/name-mapping/models.ts index 72654e58..944ad9cb 100644 --- a/tests/e2e/orm/schemas/name-mapping/models.ts +++ b/tests/e2e/orm/schemas/name-mapping/models.ts @@ -5,7 +5,9 @@ /* eslint-disable */ -import { type SchemaType as $Schema } from "./schema"; +import { schema as $schema, type SchemaType as $Schema } from "./schema"; import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; +export const Role = $schema.enums.Role.values; +export type Role = (typeof Role)[keyof typeof Role]; diff --git a/tests/e2e/orm/schemas/name-mapping/schema.ts b/tests/e2e/orm/schemas/name-mapping/schema.ts index 5c27728b..97aa169a 100644 --- a/tests/e2e/orm/schemas/name-mapping/schema.ts +++ b/tests/e2e/orm/schemas/name-mapping/schema.ts @@ -27,6 +27,12 @@ export const schema = { unique: true, attributes: [{ name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("user_email") }] }, { name: "@unique" }] }, + role: { + name: "role", + type: "Role", + attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal("USER") }] }, { name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("user_role") }] }], + default: "USER" + }, posts: { name: "posts", type: "Post", @@ -82,6 +88,35 @@ export const schema = { } } }, + enums: { + Role: { + values: { + USER: "USER", + ADMIN: "ADMIN", + MODERATOR: "MODERATOR" + }, + fields: { + USER: { + name: "USER", + attributes: [ + { name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("role_user") }] } + ] + }, + ADMIN: { + name: "ADMIN", + attributes: [ + { name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("role_admin") }] } + ] + }, + MODERATOR: { + name: "MODERATOR" + } + }, + attributes: [ + { name: "@@map", args: [{ name: "name", value: ExpressionUtils.literal("user_role") }] } + ] + } + }, authType: "User", plugins: {} } as const satisfies SchemaDef; diff --git a/tests/e2e/orm/schemas/name-mapping/schema.zmodel b/tests/e2e/orm/schemas/name-mapping/schema.zmodel index baddc94f..2dfbaeda 100644 --- a/tests/e2e/orm/schemas/name-mapping/schema.zmodel +++ b/tests/e2e/orm/schemas/name-mapping/schema.zmodel @@ -3,9 +3,17 @@ datasource db { url = "file:./dev.db" } +enum Role { + USER @map('role_user') + ADMIN @map('role_admin') + MODERATOR + @@map("user_role") +} + model User { id Int @id @default(autoincrement()) email String @map('user_email') @unique + role Role @default(USER) @map('user_role') posts Post[] @@map('users') } diff --git a/tests/e2e/orm/schemas/typing/models.ts b/tests/e2e/orm/schemas/typing/models.ts index b2fa673f..15eae9a9 100644 --- a/tests/e2e/orm/schemas/typing/models.ts +++ b/tests/e2e/orm/schemas/typing/models.ts @@ -15,7 +15,7 @@ export type Region = $ModelResult<$Schema, "Region">; export type Meta = $ModelResult<$Schema, "Meta">; export type Identity = $TypeDefResult<$Schema, "Identity">; export type IdentityProvider = $TypeDefResult<$Schema, "IdentityProvider">; -export const Role = $schema.enums.Role; +export const Role = $schema.enums.Role.values; export type Role = (typeof Role)[keyof typeof Role]; -export const Status = $schema.enums.Status; +export const Status = $schema.enums.Status.values; export type Status = (typeof Status)[keyof typeof Status]; diff --git a/tests/e2e/orm/schemas/typing/schema.ts b/tests/e2e/orm/schemas/typing/schema.ts index 10c4daa7..1a1212bd 100644 --- a/tests/e2e/orm/schemas/typing/schema.ts +++ b/tests/e2e/orm/schemas/typing/schema.ts @@ -328,13 +328,17 @@ export const schema = { }, enums: { Role: { - ADMIN: "ADMIN", - USER: "USER" + values: { + ADMIN: "ADMIN", + USER: "USER" + } }, Status: { - ACTIVE: "ACTIVE", - INACTIVE: "INACTIVE", - BANNED: "BANNED" + values: { + ACTIVE: "ACTIVE", + INACTIVE: "INACTIVE", + BANNED: "BANNED" + } } }, authType: "User", diff --git a/tests/e2e/package.json b/tests/e2e/package.json index d6564bff..de46284e 100644 --- a/tests/e2e/package.json +++ b/tests/e2e/package.json @@ -1,6 +1,6 @@ { "name": "e2e", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/package.json b/tests/regression/package.json index 18b957e5..d59f39e5 100644 --- a/tests/regression/package.json +++ b/tests/regression/package.json @@ -1,6 +1,6 @@ { "name": "regression", - "version": "3.0.0-beta.20", + "version": "3.0.0-beta.21", "private": true, "type": "module", "scripts": { diff --git a/tests/regression/test/v2-migrated/issue-2283/.gitignore b/tests/regression/test/v2-migrated/issue-2283/.gitignore new file mode 100644 index 00000000..78254b4c --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-2283/.gitignore @@ -0,0 +1 @@ +!*.db \ No newline at end of file diff --git a/tests/regression/test/v2-migrated/issue-2283/dev.db b/tests/regression/test/v2-migrated/issue-2283/dev.db new file mode 100644 index 00000000..8eab9f73 Binary files /dev/null and b/tests/regression/test/v2-migrated/issue-2283/dev.db differ diff --git a/tests/regression/test/v2-migrated/issue-2283/regression.test.ts b/tests/regression/test/v2-migrated/issue-2283/regression.test.ts new file mode 100644 index 00000000..e1fb6a61 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-2283/regression.test.ts @@ -0,0 +1,703 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import path from 'path'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue 2283', () => { + it('regression', async () => { + const db: any = await createPolicyTestClient( + ` +// Base models +type Base { + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt() +} + +type BaseWithCuid with Base { + id String @id @default(cuid()) +} + +type Publishable { + published Boolean @default(false) +} + +// Media models +model Image with BaseWithCuid { + storageRef String + displayName String? + width Int + height Int + size BigInt + + // Relations + userProfiles UserProfile[] + labProfiles LabProfile[] + contents Content[] + modules Module[] + classes Class[] + + @@allow('all', true) +} + +model Video with BaseWithCuid { + storageRef String + displayName String? + durationMillis Int + width Int? + height Int? + size BigInt + + // Relations + previewForContent Content[] + previewForModule Module[] + classes Class[] + + @@allow('all', true) +} + +// User models +model User with Base { + id String @id @default(uuid()) + email String @unique + displayName String? + + profile UserProfile? + labs UserLabJoin[] + ownedLabs Lab[] + + @@allow('all', true) +} + +model UserProfile with BaseWithCuid { + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String @unique + bio String? + instagram String? + profilePhoto Image? @relation(fields: [profilePhotoId], references: [id], onDelete: SetNull) + profilePhotoId String? + + @@allow('all', true) +} + +// Lab models +model Lab with BaseWithCuid, Publishable { + name String + profile LabProfile? + owners User[] + community UserLabJoin[] + roles Role[] + privileges Privilege[] + content Content[] + permissions LabPermission[] + + @@allow('create', auth() != null) + @@allow('read', owners?[id == auth().id] || published) + @@allow('update', + owners?[id == auth().id] + || + community?[ + userLabRoles?[ + userId == auth().id + && + role.privileges?[ + privilege.labPermissions?[ + type == "ALLOW_ADMINISTRATION" + ] + ] + ] + ] + ) + @@allow('delete', owners?[id == auth().id]) +} + +model LabProfile with BaseWithCuid { + lab Lab @relation(fields: [labId], references: [id], onDelete: Cascade) + labId String @unique + bio String? + instagram String? + profilePhoto Image? @relation(fields: [profilePhotoId], references: [id], onDelete: SetNull) + profilePhotoId String? + slug String? @unique + + @@allow('read', check(lab, "read")) + @@allow('create', lab.owners?[id == auth().id]) + @@allow('update', check(lab, "update")) + @@allow('delete', check(lab, "delete")) +} + +// User-Lab relationship +model UserLabJoin with Base { + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + lab Lab @relation(fields: [labId], references: [id], onDelete: Restrict) + labId String + userLabRoles UserLabRole[] + + @@id(name: "userLabJoinId", [userId, labId]) + + @@allow('create', auth().id == userId) + @@allow('update', auth().id == userId) + @@allow('read', true) + @@allow('delete', auth().id == userId) +} + +// Role and Permission models +model Role with BaseWithCuid { + name String + shortDescription String? + longDescription String? + lab Lab @relation(fields: [labId], references: [id], onDelete: Cascade) + labId String + userLabRoles UserLabRole[] + privileges RolePrivilegeJoin[] + public Boolean @default(false) + priority Int @default(0) + isTeamRole Boolean @default(false) + + @@unique([labId, id]) + @@unique([name, labId]) + + @@allow('read', + auth().id != null + && + ( + userLabRoles?[userId == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + && + labId == this.labId + ] + ] + || + lab.owners?[id == auth().id] + ) + ) + @@allow('create', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + && + privilege.labId == this.labId + ] + ] + ] + ) + ) + @@allow('update', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + && + privilege.labId == this.labId + ] + ] + ] + ) + ) + @@allow('delete', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + && + privilege.labId == this.labId + ] + ] + ] + ) + ) +} + +model UserLabRole with Base { + userLabJoin UserLabJoin @relation(fields: [userId, labId], references: [userId, labId], onDelete: Cascade) + userId String + labId String + role Role @relation(fields: [labId, roleId], references: [labId, id], onDelete: Cascade) + roleId String + expiresAt DateTime? + + @@id(name: "userLabRoleId", [userId, labId, roleId]) + + @@allow('read', auth().id != null) + @@allow('create', + auth().id != null + && + ( + userLabJoin.lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('update', + auth().id != null + && + ( + userLabJoin.lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('delete', + auth().id != null + && + ( + userLabJoin.lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) +} + +model Privilege with BaseWithCuid { + name String + longDescription String? + shortDescription String + lab Lab @relation(fields: [labId], references: [id], onDelete: Cascade) + labId String + roles RolePrivilegeJoin[] + labPermissions LabPermission[] + public Boolean @default(false) + + @@unique([name, labId]) + + @@allow('read', auth().id != null) + @@allow('create', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('update', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('delete', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) +} + +model LabPermission with BaseWithCuid { + name String + lab Lab @relation(fields: [labId], references: [id], onDelete: Cascade) + labId String + privileges Privilege[] + type String + + @@unique([name, labId]) + + @@allow('read', auth().id != null) + @@allow('create', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('update', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('delete', + auth().id != null + && + ( + lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) +} + +model RolePrivilegeJoin with Base { + role Role @relation(fields: [roleId], references: [id], onDelete: Cascade) + roleId String + privilege Privilege @relation(fields: [privilegeId], references: [id], onDelete: Cascade) + privilegeId String + order Int? + + @@id(name: "rolePrivilegeJoinId", [roleId, privilegeId]) + + @@allow('read', auth().id != null) + @@allow('create', + auth().id != null + && + ( + role.lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('update', + auth().id != null + && + ( + role.lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) + @@allow('delete', + auth().id != null + && + ( + role.lab.owners?[id == auth().id] + || + auth().labs?[ + userLabRoles?[ + role.privileges?[ + privilege.labPermissions?[type == "ALLOW_ADMINISTRATION"] + ] + ] + ] + ) + ) +} + +// Content models +model Content with BaseWithCuid { + lab Lab @relation(fields: [labId], references: [id], onDelete: Cascade) + labId String + name String + shortDescription String? + longDescription String? + thumbnail Image? @relation(fields: [thumbnailId], references: [id]) + thumbnailId String? + modules Module[] + published Boolean + previewVideo Video? @relation(fields: [previewVideoId], references: [id]) + previewVideoId String? + order Int + + @@unique([labId, order]) + + @@allow('read', + lab.owners?[id == auth().id] + || + lab.community?[ + userId == auth().id + && + userLabRoles?[ + labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[ + type in ["ALLOW_ADMINISTRATION"] + ] + ] + ] + ] + || + published == true + ) + @@allow('create', + lab.owners?[id == auth().id] + || + lab.community?[ + userId == auth().id + && + userLabRoles?[ + labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[ + type in ["ALLOW_ADMINISTRATION"] + ] + ] + ] + ] + ) + @@allow('update', + lab.owners?[id == auth().id] + || + lab.community?[ + userId == auth().id + && + userLabRoles?[ + labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[ + type in ["ALLOW_ADMINISTRATION"] + ] + ] + ] + ] + ) + @@allow('delete', + lab.owners?[id == auth().id] + || + lab.community?[ + userId == auth().id + && + userLabRoles?[ + labId == this.labId + && + role.privileges?[ + privilege.labPermissions?[ + type in ["ALLOW_ADMINISTRATION"] + ] + ] + ] + ] + ) +} + +model Module with BaseWithCuid { + name String + shortDescription String? + longDescription String? + thumbnail Image? @relation(fields: [thumbnailId], references: [id]) + thumbnailId String? + content Content @relation(fields: [contentId], references: [id], onDelete: Restrict) + contentId String + classes Class[] + order Int + published Boolean + category String? + previewVideo Video? @relation(fields: [previewVideoId], references: [id]) + previewVideoId String? + + @@unique([order, category, contentId]) + + @@allow('read', + content.lab.owners?[id == auth().id] + || + content.lab.permissions?[ + privileges?[ + roles?[ + role.userLabRoles?[ + userId == auth().id + ] + ] + && + labPermissions?[ + type in ["ALLOW_ADMINISTRATION"] + ] + ] + ] + || + ( + check(content, 'read') + && + published == true + ) + ) + @@allow('create', check(content, 'create')) + @@allow('update', check(content, 'update')) + @@allow('delete', check(content, 'delete')) +} + +model Class with BaseWithCuid { + name String + shortDescription String? + longDescription String? + thumbnail Image? @relation(fields: [thumbnailId], references: [id]) + thumbnailId String? + module Module @relation(fields: [moduleId], references: [id], onDelete: Restrict) + moduleId String + order Int + published Boolean + video Video? @relation(fields: [videoId], references: [id]) + videoId String? + category String? + + @@unique([order, category, moduleId]) + + @@allow('read', check(module, 'read')) + @@allow('create', check(module, 'create')) + @@allow('update', check(module, 'update')) + @@allow('delete', check(module, 'delete')) +} +`, + { + provider: 'sqlite', + dbFile: path.join(__dirname, 'dev.db'), + }, + ); + + const r = await db.labProfile.findUnique({ + where: { + slug: 'test-lab-slug', + lab: { + published: true, + }, + }, + select: { + lab: { + select: { + id: true, + name: true, + content: { + where: { + published: true, + }, + select: { + id: true, + name: true, + modules: { + select: { + id: true, + name: true, + classes: { + select: { + id: true, + name: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }); + + expect(r).toMatchObject({ + lab: expect.objectContaining({ + name: 'Test Lab', + content: [ + expect.objectContaining({ + name: 'Test Course', + modules: [ + expect.objectContaining({ + name: 'Test Module', + classes: [ + expect.objectContaining({ + name: 'Test Class', + }), + ], + }), + ], + }), + ], + }), + }); + expect(r.lab.content[0].modules[0].classes[0].module).toBeUndefined(); + }); +});