diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index cdf37fc4..d3bef43b 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -457,9 +457,9 @@ attribute @db.ByteA() @@@targetField([BytesField]) @@@prisma /** * Specifies the schema to use in a multi-schema PostgreSQL database. * - * @param name: The name of the database schema. + * @param map: The name of the database schema. */ -attribute @@schema(_ name: String) @@@prisma +attribute @@schema(_ map: String) @@@prisma ////////////////////////////////////////////// // Begin validation attributes and functions diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index 7d97c942..4cd9c584 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -31,10 +31,12 @@ import { type UpdateManyArgs, type UpsertArgs, } from '../../crud-types'; +import { createInternalError, createInvalidInputError } from '../../errors'; import { fieldHasDefaultValue, getDiscriminatorField, getEnum, + getTypeDef, getUniqueFields, requireField, requireModel, @@ -47,7 +49,6 @@ import { addNumberValidation, addStringValidation, } from './utils'; -import { createInternalError, createInvalidInputError } from '../../errors'; const schemaCache = new WeakMap>(); @@ -281,6 +282,8 @@ export class InputValidator { private makePrimitiveSchema(type: string, attributes?: AttributeApplication[]) { if (this.schema.typeDefs && type in this.schema.typeDefs) { return this.makeTypeDefSchema(type); + } else if (this.schema.enums && type in this.schema.enums) { + return this.makeEnumSchema(type); } else { return match(type) .with('String', () => @@ -314,6 +317,22 @@ export class InputValidator { } } + private makeEnumSchema(type: string) { + const key = stableStringify({ + type: 'enum', + name: type, + }); + let schema = this.getSchemaCache(key!); + if (schema) { + return schema; + } + const enumDef = getEnum(this.schema, type); + invariant(enumDef, `Enum "${type}" not found in schema`); + schema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); + this.setSchemaCache(key!, schema); + return schema; + } + private makeTypeDefSchema(type: string): z.ZodType { const key = stableStringify({ type: 'typedef', @@ -324,24 +343,22 @@ export class InputValidator { if (schema) { return schema; } - const typeDef = this.schema.typeDefs?.[type]; + const typeDef = getTypeDef(this.schema, type); invariant(typeDef, `Type definition "${type}" not found in schema`); - schema = z - .object( - Object.fromEntries( - Object.entries(typeDef.fields).map(([field, def]) => { - let fieldSchema = this.makePrimitiveSchema(def.type); - if (def.array) { - fieldSchema = fieldSchema.array(); - } - if (def.optional) { - fieldSchema = fieldSchema.optional(); - } - return [field, fieldSchema]; - }), - ), - ) - .passthrough(); + schema = z.looseObject( + Object.fromEntries( + Object.entries(typeDef.fields).map(([field, def]) => { + let fieldSchema = this.makePrimitiveSchema(def.type); + if (def.array) { + fieldSchema = fieldSchema.array(); + } + if (def.optional) { + fieldSchema = fieldSchema.optional(); + } + return [field, fieldSchema]; + }), + ), + ); this.setSchemaCache(key!, schema); return schema; } @@ -392,7 +409,7 @@ export class InputValidator { const enumDef = getEnum(this.schema, fieldDef.type); if (enumDef) { // enum - if (Object.keys(enumDef).length > 0) { + if (Object.keys(enumDef.values).length > 0) { fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional, withAggregations); } } else if (fieldDef.array) { @@ -427,7 +444,7 @@ export class InputValidator { const enumDef = getEnum(this.schema, def.type); if (enumDef) { // enum - if (Object.keys(enumDef).length > 0) { + if (Object.keys(enumDef.values).length > 0) { fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional, false); } else { fieldSchema = z.never(); @@ -493,7 +510,7 @@ export class InputValidator { } private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean) { - const baseSchema = z.enum(Object.keys(enumDef) as [string, ...string[]]); + const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); const components = this.makeCommonPrimitiveFilterComponents( baseSchema, optional, diff --git a/packages/orm/src/client/executor/name-mapper.ts b/packages/orm/src/client/executor/name-mapper.ts index 1f508b0b..2ad522f5 100644 --- a/packages/orm/src/client/executor/name-mapper.ts +++ b/packages/orm/src/client/executor/name-mapper.ts @@ -1,12 +1,17 @@ import { invariant } from '@zenstackhq/common-helpers'; import { AliasNode, + CaseWhenBuilder, ColumnNode, + ColumnUpdateNode, DeleteQueryNode, + expressionBuilder, + ExpressionWrapper, FromNode, IdentifierNode, InsertQueryNode, OperationNodeTransformer, + PrimitiveValueListNode, ReferenceNode, ReturningNode, SelectAllNode, @@ -14,10 +19,23 @@ import { SelectQueryNode, TableNode, UpdateQueryNode, + ValueListNode, + ValueNode, + ValuesNode, type OperationNode, + type SimpleReferenceExpressionNode, } from 'kysely'; -import type { FieldDef, ModelDef, SchemaDef } from '../../schema'; -import { extractFieldName, extractModelName, getModel, requireModel, stripAlias } from '../query-utils'; +import type { EnumDef, EnumField, FieldDef, ModelDef, SchemaDef } from '../../schema'; +import { + extractFieldName, + extractModelName, + getEnum, + getField, + getModel, + isEnum, + requireModel, + stripAlias, +} from '../query-utils'; type Scope = { model?: string; @@ -25,6 +43,8 @@ type Scope = { namesMapped?: boolean; // true means fields referring to this scope have their names already mapped }; +type SelectionNodeChild = SimpleReferenceExpressionNode | AliasNode | SelectAllNode; + export class QueryNameMapper extends OperationNodeTransformer { private readonly modelToTableMap = new Map(); private readonly fieldToColumnMap = new Map(); @@ -89,15 +109,27 @@ export class QueryNameMapper extends OperationNodeTransformer { return super.transformInsertQuery(node); } - return this.withScope( - { model: node.into.table.identifier.name }, - () => - ({ - ...super.transformInsertQuery(node), - // map table name - into: this.processTableRef(node.into!), - }) satisfies InsertQueryNode, - ); + const model = extractModelName(node.into); + invariant(model, 'InsertQueryNode must have a model name in the "into" clause'); + + return this.withScope({ model }, () => { + const baseResult = super.transformInsertQuery(node); + let values = baseResult.values; + if (node.columns && values) { + // process enum values with name mapping + values = this.processEnumMappingForColumns(model, node.columns, values); + } + return { + ...baseResult, + // map table name + into: this.processTableRef(node.into!), + values, + } satisfies InsertQueryNode; + }); + } + + private isOperationNode(value: unknown): value is OperationNode { + return !!value && typeof value === 'object' && 'kind' in value; } protected override transformReturning(node: ReturningNode) { @@ -158,9 +190,29 @@ export class QueryNameMapper extends OperationNodeTransformer { return super.transformUpdateQuery(node); } - return this.withScope({ model: innerTable.table.identifier.name, alias }, () => { + const model = extractModelName(innerTable); + invariant(model, 'UpdateQueryNode must have a model name in the "table" clause'); + + return this.withScope({ model, alias }, () => { + const baseResult = super.transformUpdateQuery(node); + + // process enum value mappings in update set values + const updates = baseResult.updates?.map((update, i) => { + if (ColumnNode.is(update.column)) { + // fetch original column that doesn't have name mapping applied + const origColumn = node.updates![i]!.column as ColumnNode; + return ColumnUpdateNode.create( + update.column, + this.processEnumMappingForValue(model, origColumn, update.value) as OperationNode, + ); + } else { + return update; + } + }); + return { - ...super.transformUpdateQuery(node), + ...baseResult, + updates, // map table name table: this.wrapAlias(this.processTableRef(innerTable), alias), }; @@ -204,42 +256,70 @@ export class QueryNameMapper extends OperationNodeTransformer { private processSelectQuerySelections(node: SelectQueryNode) { const selections: SelectionNode[] = []; for (const selection of node.selections ?? []) { + const processedSelections: { originalField?: string; selection: SelectionNode }[] = []; if (SelectAllNode.is(selection.selection)) { // expand `selectAll` to all fields with name mapping if the // inner-most scope is not already mapped - const scope = this.scopes[this.scopes.length - 1]; + const scope = this.requireCurrentScope(); if (scope?.model && !scope.namesMapped) { - selections.push(...this.createSelectAllFields(scope.model, scope.alias)); + // expand + processedSelections.push(...this.createSelectAllFields(scope.model, scope.alias)); } else { - selections.push(super.transformSelection(selection)); + // preserve + processedSelections.push({ + originalField: undefined, + selection: super.transformSelection(selection), + }); } } else if (ReferenceNode.is(selection.selection) || ColumnNode.is(selection.selection)) { // map column name and add/preserve alias const transformed = this.transformNode(selection.selection); + + // field name without applying name mapping + const originalField = extractFieldName(selection.selection); + if (AliasNode.is(transformed)) { // keep the alias if there's one - selections.push(SelectionNode.create(transformed)); + processedSelections.push({ originalField, selection: SelectionNode.create(transformed) }); } else { // otherwise use an alias to preserve the original field name - const origFieldName = extractFieldName(selection.selection); const fieldName = extractFieldName(transformed); - if (fieldName !== origFieldName) { - selections.push( - SelectionNode.create( + if (fieldName !== originalField) { + processedSelections.push({ + originalField, + selection: SelectionNode.create( this.wrapAlias( transformed, - origFieldName ? IdentifierNode.create(origFieldName) : undefined, + originalField ? IdentifierNode.create(originalField) : undefined, ), ), - ); + }); } else { - selections.push(SelectionNode.create(transformed)); + processedSelections.push({ + originalField, + selection: SelectionNode.create(transformed), + }); } } } else { - selections.push(super.transformSelection(selection)); + const { node: innerNode } = stripAlias(selection.selection); + processedSelections.push({ + originalField: extractFieldName(innerNode), + selection: super.transformSelection(selection), + }); } + + // process enum value mapping + const enumProcessedSelections = processedSelections.map(({ originalField, selection }) => { + if (!originalField) { + return selection; + } else { + return SelectionNode.create(this.processEnumSelection(selection.selection, originalField)); + } + }); + selections.push(...enumProcessedSelections); } + return selections; } @@ -320,7 +400,7 @@ export class QueryNameMapper extends OperationNodeTransformer { return this.createTableNode(mappedName, tableSchema); } - private getMappedName(def: ModelDef | FieldDef) { + private getMappedName(def: ModelDef | FieldDef | EnumField) { const mapAttr = def.attributes?.find((attr) => attr.name === '@@map' || attr.name === '@map'); if (mapAttr) { const nameArg = mapAttr.args?.find((arg) => arg.name === 'name'); @@ -393,7 +473,7 @@ export class QueryNameMapper extends OperationNodeTransformer { let schema = this.schema.provider.defaultSchema ?? 'public'; const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema'); if (schemaAttr) { - const nameArg = schemaAttr.args?.find((arg) => arg.name === 'name'); + const nameArg = schemaAttr.args?.find((arg) => arg.name === 'map'); if (nameArg && nameArg.value.kind === 'literal') { schema = nameArg.value.value as string; } @@ -411,9 +491,9 @@ export class QueryNameMapper extends OperationNodeTransformer { ); if (columnName !== fieldDef.name) { const aliased = AliasNode.create(columnRef, IdentifierNode.create(fieldDef.name)); - return SelectionNode.create(aliased); + return { originalField: fieldDef.name, selection: SelectionNode.create(aliased) }; } else { - return SelectionNode.create(columnRef); + return { originalField: fieldDef.name, selection: SelectionNode.create(columnRef) }; } }); } @@ -442,20 +522,28 @@ export class QueryNameMapper extends OperationNodeTransformer { return result; } - private processSelection(node: AliasNode | ColumnNode | ReferenceNode) { - let alias: string | undefined; - if (!AliasNode.is(node)) { - alias = extractFieldName(node); + private processSelection(node: SelectionNodeChild) { + const { alias, node: innerNode } = stripAlias(node); + const originalField = extractFieldName(innerNode); + let result = super.transformNode(node); + + if (originalField) { + // process enum value mapping + result = this.processEnumSelection(result, originalField); + } + + if (!AliasNode.is(result)) { + const addAlias = alias ?? (originalField ? IdentifierNode.create(originalField) : undefined); + if (addAlias) { + result = this.wrapAlias(result, addAlias); + } } - const result = super.transformNode(node); - return this.wrapAlias(result, alias ? IdentifierNode.create(alias) : undefined); + return result; } private processSelectAll(node: SelectAllNode) { - const scope = this.scopes[this.scopes.length - 1]; - invariant(scope); - - if (!scope.model || !this.hasMappedColumns(scope.model)) { + const scope = this.requireCurrentScope(); + if (!scope.model || !(this.hasMappedColumns(scope.model) || this.modelUsesEnumWithMappedValues(scope.model))) { // no name mapping needed, preserve the select all return super.transformSelectAll(node); } @@ -465,9 +553,13 @@ export class QueryNameMapper extends OperationNodeTransformer { return this.getModelFields(modelDef).map((fieldDef) => { const columnName = this.mapFieldName(modelDef.name, fieldDef.name); const columnRef = ReferenceNode.create(ColumnNode.create(columnName)); - return columnName !== fieldDef.name - ? this.wrapAlias(columnRef, IdentifierNode.create(fieldDef.name)) - : columnRef; + + // process enum value mapping + const enumProcessed = this.processEnumSelection(columnRef, fieldDef.name); + + return columnName !== fieldDef.name && !AliasNode.is(enumProcessed) + ? this.wrapAlias(enumProcessed, IdentifierNode.create(fieldDef.name)) + : enumProcessed; }); } @@ -475,5 +567,150 @@ export class QueryNameMapper extends OperationNodeTransformer { return schemaName ? TableNode.createWithSchema(schemaName, tableName) : TableNode.create(tableName); } + private requireCurrentScope() { + const scope = this.scopes[this.scopes.length - 1]; + invariant(scope, 'No scope available'); + return scope; + } + + // #endregion + + // #region enum value mapping + + private modelUsesEnumWithMappedValues(model: string) { + const modelDef = getModel(this.schema, model); + if (!modelDef) { + return false; + } + return this.getModelFields(modelDef).some((fieldDef) => { + const enumDef = getEnum(this.schema, fieldDef.type); + if (!enumDef) { + return false; + } + return Object.values(enumDef.fields ?? {}).some((f) => f.attributes?.some((attr) => attr.name === '@map')); + }); + } + + private getEnumValueMapping(enumDef: EnumDef) { + const mapping: Record = {}; + for (const [key, field] of Object.entries(enumDef.fields ?? {})) { + const mappedName = this.getMappedName(field); + if (mappedName) { + mapping[key] = mappedName; + } + } + return mapping; + } + + private processEnumMappingForColumns( + model: string, + columns: readonly ColumnNode[], + values: OperationNode, + ): OperationNode { + if (ValuesNode.is(values)) { + return ValuesNode.create( + values.values.map((valueItems) => { + if (PrimitiveValueListNode.is(valueItems)) { + return PrimitiveValueListNode.create( + this.processEnumMappingForValues(model, columns, valueItems.values), + ); + } else { + return ValueListNode.create( + this.processEnumMappingForValues(model, columns, valueItems.values) as OperationNode[], + ); + } + }), + ); + } else if (PrimitiveValueListNode.is(values)) { + return PrimitiveValueListNode.create(this.processEnumMappingForValues(model, columns, values.values)); + } else { + return values; + } + } + + private processEnumMappingForValues(model: string, columns: readonly ColumnNode[], values: readonly unknown[]) { + const result: unknown[] = []; + for (let i = 0; i < columns.length; i++) { + const value = values[i]; + if (value === null || value === undefined) { + result.push(value); + continue; + } + result.push(this.processEnumMappingForValue(model, columns[i]!, value)); + } + return result; + } + + private processEnumMappingForValue(model: string, column: ColumnNode, value: unknown) { + const fieldDef = getField(this.schema, model, column.column.name); + if (!fieldDef) { + return value; + } + if (!isEnum(this.schema, fieldDef.type)) { + return value; + } + + const enumDef = getEnum(this.schema, fieldDef.type); + if (!enumDef) { + return value; + } + + const enumValueMapping = this.getEnumValueMapping(enumDef); + if (this.isOperationNode(value) && ValueNode.is(value) && typeof value.value === 'string') { + const mappedValue = enumValueMapping[value.value]; + if (mappedValue) { + return ValueNode.create(mappedValue); + } + } else if (typeof value === 'string') { + const mappedValue = enumValueMapping[value]; + if (mappedValue) { + return mappedValue; + } + } + + return value; + } + + private processEnumSelection(selection: SelectionNodeChild, fieldName: string) { + const { alias, node } = stripAlias(selection); + const fieldScope = this.resolveFieldFromScopes(fieldName); + if (!fieldScope || !fieldScope.model) { + return selection; + } + const aliasName = alias && IdentifierNode.is(alias) ? alias.name : fieldName; + + const fieldDef = getField(this.schema, fieldScope.model, fieldName); + if (!fieldDef) { + return selection; + } + const enumDef = getEnum(this.schema, fieldDef.type); + if (!enumDef) { + return selection; + } + const enumValueMapping = this.getEnumValueMapping(enumDef); + if (Object.keys(enumValueMapping).length === 0) { + return selection; + } + + const eb = expressionBuilder(); + const caseBuilder = eb.case(); + let caseWhen: CaseWhenBuilder | undefined; + for (const [key, value] of Object.entries(enumValueMapping)) { + if (!caseWhen) { + caseWhen = caseBuilder.when(new ExpressionWrapper(node), '=', value).then(key); + } else { + caseWhen = caseWhen.when(new ExpressionWrapper(node), '=', value).then(key); + } + } + + // the explicit cast to "text" is needed to address postgres's case-when type inference issue + const finalExpr = caseWhen!.else(eb.cast(new ExpressionWrapper(node), 'text')).end(); + if (aliasName) { + return finalExpr.as(aliasName).toOperationNode() as SelectionNodeChild; + } else { + return finalExpr.toOperationNode() as SelectionNodeChild; + } + } + // #endregion } diff --git a/packages/orm/src/client/helpers/schema-db-pusher.ts b/packages/orm/src/client/helpers/schema-db-pusher.ts index 52fc462b..f04c14e1 100644 --- a/packages/orm/src/client/helpers/schema-db-pusher.ts +++ b/packages/orm/src/client/helpers/schema-db-pusher.ts @@ -23,7 +23,26 @@ export class SchemaDbPusher { await this.kysely.transaction().execute(async (tx) => { if (this.schema.enums && this.schema.provider.type === 'postgresql') { for (const [name, enumDef] of Object.entries(this.schema.enums)) { - const createEnum = tx.schema.createType(name).asEnum(Object.values(enumDef)); + let enumValues: string[]; + if (enumDef.fields) { + enumValues = Object.values(enumDef.fields).map((f) => { + const mapAttr = f.attributes?.find((a) => a.name === '@map'); + if (!mapAttr || !mapAttr.args?.[0]) { + return f.name; + } else { + const mappedName = ExpressionUtils.getLiteralValue(mapAttr.args[0].value); + invariant( + mappedName && typeof mappedName === 'string', + `Invalid @map attribute for enum field ${f.name}`, + ); + return mappedName; + } + }); + } else { + enumValues = Object.values(enumDef.values); + } + + const createEnum = tx.schema.createType(name).asEnum(enumValues); await createEnum.execute(); } } diff --git a/packages/schema/src/schema.ts b/packages/schema/src/schema.ts index ac214fa1..13fc90b9 100644 --- a/packages/schema/src/schema.ts +++ b/packages/schema/src/schema.ts @@ -98,7 +98,16 @@ export type BuiltinType = export type MappedBuiltinType = string | boolean | number | bigint | Decimal | Date; -export type EnumDef = Record; +export type EnumField = { + name: string; + attributes?: AttributeApplication[]; +}; + +export type EnumDef = { + fields?: Record; + values: Record; + attributes?: AttributeApplication[]; +}; export type TypeDefDef = { name: string; @@ -125,7 +134,9 @@ export type GetModel> export type GetEnums = keyof Schema['enums']; -export type GetEnum> = Schema['enums'][Enum]; +export type GetEnum> = Schema['enums'][Enum] extends EnumDef + ? Schema['enums'][Enum]['values'] + : never; export type GetTypeDefs = Extract; diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 564f5112..798b6dfe 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -947,9 +947,68 @@ export class TsSchemaGenerator { private createEnumObject(e: Enum) { return ts.factory.createObjectLiteralExpression( - e.fields.map((field) => - ts.factory.createPropertyAssignment(field.name, ts.factory.createStringLiteral(field.name)), - ), + [ + ts.factory.createPropertyAssignment( + 'values', + ts.factory.createObjectLiteralExpression( + e.fields.map((f) => + ts.factory.createPropertyAssignment(f.name, ts.factory.createStringLiteral(f.name)), + ), + true, + ), + ), + + // only generate `fields` if there are attributes on the fields + ...(e.fields.some((f) => f.attributes.length > 0) + ? [ + ts.factory.createPropertyAssignment( + 'fields', + ts.factory.createObjectLiteralExpression( + e.fields.map((field) => + ts.factory.createPropertyAssignment( + field.name, + ts.factory.createObjectLiteralExpression( + [ + ts.factory.createPropertyAssignment( + 'name', + ts.factory.createStringLiteral(field.name), + ), + ...(field.attributes.length > 0 + ? [ + ts.factory.createPropertyAssignment( + 'attributes', + ts.factory.createArrayLiteralExpression( + field.attributes?.map((attr) => + this.createAttributeObject(attr), + ) ?? [], + true, + ), + ), + ] + : []), + ], + true, + ), + ), + ), + true, + ), + ), + ] + : []), + + ...(e.attributes.length > 0 + ? [ + ts.factory.createPropertyAssignment( + 'attributes', + ts.factory.createArrayLiteralExpression( + e.attributes.map((attr) => this.createAttributeObject(attr)), + true, + ), + ), + ] + : []), + ], true, ); } @@ -1259,7 +1318,7 @@ export class TsSchemaGenerator { statements.push(typeDef); } - // generate: export const Enum = $schema.enums.Enum; + // generate: export const Enum = $schema.enums.Enum['values']; const enums = model.declarations.filter(isEnum); for (const e of enums) { let enumDecl = ts.factory.createVariableStatement( @@ -1272,10 +1331,13 @@ export class TsSchemaGenerator { undefined, ts.factory.createPropertyAccessExpression( ts.factory.createPropertyAccessExpression( - ts.factory.createIdentifier('$schema'), - ts.factory.createIdentifier('enums'), + ts.factory.createPropertyAccessExpression( + ts.factory.createIdentifier('$schema'), + ts.factory.createIdentifier('enums'), + ), + ts.factory.createIdentifier(e.name), ), - ts.factory.createIdentifier(e.name), + ts.factory.createIdentifier('values'), ), ), ], diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 3f22fead..569a67ab 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -33,6 +33,7 @@ const TEST_PG_CONFIG = { export type CreateTestClientOptions = Omit, 'dialect'> & { provider?: 'sqlite' | 'postgresql'; + schemaFile?: string; dbName?: string; usePrismaPush?: boolean; extraSourceFiles?: Record; @@ -44,7 +45,6 @@ export type CreateTestClientOptions = Omit( schema: Schema, options?: CreateTestClientOptions, - schemaFile?: string, ): Promise>; export async function createTestClient( schema: string, @@ -53,7 +53,6 @@ export async function createTestClient( export async function createTestClient( schema: Schema | string, options?: CreateTestClientOptions, - schemaFile?: string, ): Promise { let workDir = options?.workDir; let _schema: Schema; @@ -87,8 +86,8 @@ export async function createTestClient( }, }; workDir ??= createTestProject(); - if (schemaFile) { - let schemaContent = fs.readFileSync(schemaFile, 'utf-8'); + if (options?.schemaFile) { + let schemaContent = fs.readFileSync(options.schemaFile, 'utf-8'); if (dbUrl) { // replace `datasource db { }` section schemaContent = schemaContent.replace( @@ -124,7 +123,7 @@ export async function createTestClient( if (!options?.dbFile) { if (options?.usePrismaPush) { invariant( - typeof schema === 'string' || schemaFile, + typeof schema === 'string' || options?.schemaFile, 'a schema file must be provided when using prisma db push', ); if (!model) { diff --git a/samples/orm/zenstack/models.ts b/samples/orm/zenstack/models.ts index 2eb57fad..e2db380e 100644 --- a/samples/orm/zenstack/models.ts +++ b/samples/orm/zenstack/models.ts @@ -23,7 +23,7 @@ export type CommonFields = $TypeDefResult<$Schema, "CommonFields">; /** * User roles */ -export const Role = $schema.enums.Role; +export const Role = $schema.enums.Role.values; /** * User roles */ diff --git a/samples/orm/zenstack/schema.ts b/samples/orm/zenstack/schema.ts index 637ec687..4c9134b9 100644 --- a/samples/orm/zenstack/schema.ts +++ b/samples/orm/zenstack/schema.ts @@ -232,8 +232,10 @@ export const schema = { }, enums: { Role: { - ADMIN: "ADMIN", - USER: "USER" + values: { + ADMIN: "ADMIN", + USER: "USER" + } } }, authType: "User", diff --git a/tests/e2e/orm/client-api/delegate.test.ts b/tests/e2e/orm/client-api/delegate.test.ts index 1497f91b..9076c60d 100644 --- a/tests/e2e/orm/client-api/delegate.test.ts +++ b/tests/e2e/orm/client-api/delegate.test.ts @@ -8,13 +8,10 @@ describe('Delegate model tests ', () => { let client: ClientContract; beforeEach(async () => { - client = await createTestClient( - schema, - { - usePrismaPush: true, - }, - path.join(__dirname, '../schemas/delegate/schema.zmodel'), - ); + client = await createTestClient(schema, { + usePrismaPush: true, + schemaFile: path.join(__dirname, '../schemas/delegate/schema.zmodel'), + }); }); afterEach(async () => { diff --git a/tests/e2e/orm/client-api/name-mapping.test.ts b/tests/e2e/orm/client-api/name-mapping.test.ts index 5d9151e7..6f279114 100644 --- a/tests/e2e/orm/client-api/name-mapping.test.ts +++ b/tests/e2e/orm/client-api/name-mapping.test.ts @@ -1,18 +1,17 @@ +import type { ClientContract } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; import path from 'node:path'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; -import type { ClientContract } from '@zenstackhq/orm'; import { schema, type SchemaType } from '../schemas/name-mapping/schema'; -import { createTestClient } from '@zenstackhq/testtools'; describe('Name mapping tests', () => { let db: ClientContract; beforeEach(async () => { - db = await createTestClient( - schema, - { usePrismaPush: true }, - path.join(__dirname, '../schemas/name-mapping/schema.zmodel'), - ); + db = await createTestClient(schema, { + usePrismaPush: true, + schemaFile: path.join(__dirname, '../schemas/name-mapping/schema.zmodel'), + }); }); afterEach(async () => { @@ -34,6 +33,37 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ id: expect.any(Number), email: 'u1@test.com', + role: 'USER', // mapped enum value + }); + + let rawRead = await db.$qbRaw + .selectFrom('users') + .where('user_email', '=', 'u1@test.com') + .selectAll() + .executeTakeFirst(); + await expect(rawRead).toMatchObject({ + user_email: 'u1@test.com', + user_role: 'role_user', + }); + + await expect( + db.user.create({ + data: { + email: 'u1_1@test.com', + role: 'MODERATOR', // unmapped enum value + }, + }), + ).resolves.toMatchObject({ + role: 'MODERATOR', + }); + + rawRead = await db.$qbRaw + .selectFrom('users') + .where('user_email', '=', 'u1_1@test.com') + .selectAll() + .executeTakeFirst(); + await expect(rawRead).toMatchObject({ + user_role: 'MODERATOR', }); await expect( @@ -41,12 +71,22 @@ describe('Name mapping tests', () => { .insertInto('User') .values({ email: 'u2@test.com', + role: 'ADMIN', }) - .returning(['id', 'email']) + .returning(['id', 'email', 'role']) .executeTakeFirst(), ).resolves.toMatchObject({ id: expect.any(Number), email: 'u2@test.com', + role: 'ADMIN', + }); + rawRead = await db.$qbRaw + .selectFrom('users') + .where('user_email', '=', 'u2@test.com') + .selectAll() + .executeTakeFirst(); + await expect(rawRead).toMatchObject({ + user_role: 'role_admin', }); await expect( @@ -73,6 +113,7 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ id: expect.any(Number), email: 'u4@test.com', + role: 'USER', }); }); @@ -94,26 +135,45 @@ describe('Name mapping tests', () => { select: { id: true, email: true, + role: true, posts: { where: { title: { contains: 'Post1' } }, select: { title: true } }, }, }), ).resolves.toMatchObject({ id: expect.any(Number), email: 'u1@test.com', + role: 'USER', posts: [{ title: 'Post1' }], }); + // select all + await expect( + db.user.findFirst({ + where: { email: 'u1@test.com' }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + email: 'u1@test.com', + role: 'USER', + }); + await expect( db.$qb.selectFrom('User').selectAll().where('email', '=', 'u1@test.com').executeTakeFirst(), ).resolves.toMatchObject({ id: expect.any(Number), email: 'u1@test.com', + role: 'USER', }); await expect( - db.$qb.selectFrom('User').select(['User.email']).where('email', '=', 'u1@test.com').executeTakeFirst(), + db.$qb + .selectFrom('User') + .select(['User.email', 'User.role']) + .where('email', '=', 'u1@test.com') + .executeTakeFirst(), ).resolves.toMatchObject({ email: 'u1@test.com', + role: 'USER', }); await expect( @@ -172,6 +232,7 @@ describe('Name mapping tests', () => { where: { id: user.id }, data: { email: 'u2@test.com', + role: 'ADMIN', posts: { update: { where: { id: 1 }, @@ -184,21 +245,22 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ id: user.id, email: 'u2@test.com', + role: 'ADMIN', posts: [expect.objectContaining({ title: 'Post2' })], }); await expect( db.$qb .updateTable('User') - .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]) }) + .set({ email: (eb) => eb.fn('upper', [eb.ref('email')]), role: 'USER' }) .where('email', '=', 'u2@test.com') - .returning(['email']) + .returning(['email', 'role']) .executeTakeFirst(), - ).resolves.toMatchObject({ email: 'U2@TEST.COM' }); + ).resolves.toMatchObject({ email: 'U2@TEST.COM', role: 'USER' }); await expect( db.$qb.updateTable('User as u').set({ email: 'u3@test.com' }).returningAll().executeTakeFirst(), - ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com' }); + ).resolves.toMatchObject({ id: expect.any(Number), email: 'u3@test.com', role: 'USER' }); }); it('works with delete', async () => { @@ -229,6 +291,7 @@ describe('Name mapping tests', () => { ).resolves.toMatchObject({ email: 'u1@test.com', posts: [], + role: 'USER', }); }); @@ -236,6 +299,7 @@ describe('Name mapping tests', () => { await db.user.create({ data: { email: 'u1@test.com', + role: 'USER', posts: { create: [{ title: 'Post1' }, { title: 'Post2' }], }, @@ -245,6 +309,7 @@ describe('Name mapping tests', () => { await db.user.create({ data: { email: 'u2@test.com', + role: 'MODERATOR', posts: { create: [{ title: 'Post3' }], }, @@ -254,8 +319,9 @@ describe('Name mapping tests', () => { // Test ORM count operations await expect(db.user.count()).resolves.toBe(2); await expect(db.post.count()).resolves.toBe(3); - await expect(db.user.count({ select: { email: true } })).resolves.toMatchObject({ + await expect(db.user.count({ select: { email: true, role: true } })).resolves.toMatchObject({ email: 2, + role: 2, }); await expect(db.user.count({ where: { email: 'u1@test.com' } })).resolves.toBe(1); @@ -266,9 +332,11 @@ describe('Name mapping tests', () => { // Test Kysely count operations const r = await db.$qb .selectFrom('User') - .select((eb) => eb.fn.count('email').as('count')) + .select((eb) => eb.fn.count('email').as('email_count')) + .select((eb) => eb.fn.count('role').as('role_count')) .executeTakeFirst(); - await expect(Number(r?.count)).toBe(2); + await expect(Number(r?.email_count)).toBe(2); + await expect(Number(r?.role_count)).toBe(2); }); it('works with aggregate', async () => { @@ -276,6 +344,7 @@ describe('Name mapping tests', () => { data: { id: 1, email: 'u1@test.com', + role: 'USER', posts: { create: [ { id: 1, title: 'Post1' }, @@ -289,6 +358,7 @@ describe('Name mapping tests', () => { data: { id: 2, email: 'u2@test.com', + role: 'MODERATOR', posts: { create: [{ id: 3, title: 'Post3' }], }, @@ -296,8 +366,12 @@ describe('Name mapping tests', () => { }); // Test ORM aggregate operations - await expect(db.user.aggregate({ _count: { id: true, email: true } })).resolves.toMatchObject({ + await expect( + db.user.aggregate({ _count: { id: true, email: true }, _max: { role: true }, _min: { role: true } }), + ).resolves.toMatchObject({ _count: { id: 2, email: 2 }, + _max: { role: 'USER' }, + _min: { role: 'MODERATOR' }, }); await expect( @@ -342,6 +416,7 @@ describe('Name mapping tests', () => { data: { id: 1, email: 'u1@test.com', + role: 'USER', posts: { create: [ { id: 1, title: 'Post1' }, @@ -356,6 +431,7 @@ describe('Name mapping tests', () => { data: { id: 2, email: 'u2@test.com', + role: 'MODERATOR', posts: { create: [ { id: 4, title: 'Post4' }, @@ -389,6 +465,18 @@ describe('Name mapping tests', () => { ]), ); + const userGroupBy1 = await db.user.groupBy({ + by: ['role'], + _count: { id: true }, + }); + expect(userGroupBy1).toHaveLength(2); + expect(userGroupBy1).toEqual( + expect.arrayContaining([ + { role: 'USER', _count: { id: 2 } }, + { role: 'MODERATOR', _count: { id: 1 } }, + ]), + ); + const postGroupBy = await db.post.groupBy({ by: ['authorId'], _count: { id: true }, diff --git a/tests/e2e/orm/query-builder/query-builder.test.ts b/tests/e2e/orm/query-builder/query-builder.test.ts index 23f31e12..563118a4 100644 --- a/tests/e2e/orm/query-builder/query-builder.test.ts +++ b/tests/e2e/orm/query-builder/query-builder.test.ts @@ -1,7 +1,7 @@ import { createId } from '@paralleldrive/cuid2'; +import { createTestClient } from '@zenstackhq/testtools'; import { describe, expect, it } from 'vitest'; import { getSchema } from '../schemas/basic'; -import { createTestClient } from '@zenstackhq/testtools'; describe('Client API tests', () => { const schema = getSchema('sqlite'); diff --git a/tests/e2e/orm/schemas/basic/models.ts b/tests/e2e/orm/schemas/basic/models.ts index d532d7d4..be197879 100644 --- a/tests/e2e/orm/schemas/basic/models.ts +++ b/tests/e2e/orm/schemas/basic/models.ts @@ -12,5 +12,5 @@ export type Post = $ModelResult<$Schema, "Post">; export type Comment = $ModelResult<$Schema, "Comment">; export type Profile = $ModelResult<$Schema, "Profile">; export type CommonFields = $TypeDefResult<$Schema, "CommonFields">; -export const Role = $schema.enums.Role; +export const Role = $schema.enums.Role.values; export type Role = (typeof Role)[keyof typeof Role]; diff --git a/tests/e2e/orm/schemas/basic/schema.ts b/tests/e2e/orm/schemas/basic/schema.ts index 14e627b9..6339ab0a 100644 --- a/tests/e2e/orm/schemas/basic/schema.ts +++ b/tests/e2e/orm/schemas/basic/schema.ts @@ -275,8 +275,10 @@ export const schema = { }, enums: { Role: { - ADMIN: "ADMIN", - USER: "USER" + values: { + ADMIN: "ADMIN", + USER: "USER" + } } }, authType: "User", diff --git a/tests/e2e/orm/schemas/name-mapping/models.ts b/tests/e2e/orm/schemas/name-mapping/models.ts index 72654e58..944ad9cb 100644 --- a/tests/e2e/orm/schemas/name-mapping/models.ts +++ b/tests/e2e/orm/schemas/name-mapping/models.ts @@ -5,7 +5,9 @@ /* eslint-disable */ -import { type SchemaType as $Schema } from "./schema"; +import { schema as $schema, type SchemaType as $Schema } from "./schema"; import { type ModelResult as $ModelResult } from "@zenstackhq/orm"; export type User = $ModelResult<$Schema, "User">; export type Post = $ModelResult<$Schema, "Post">; +export const Role = $schema.enums.Role.values; +export type Role = (typeof Role)[keyof typeof Role]; diff --git a/tests/e2e/orm/schemas/name-mapping/schema.ts b/tests/e2e/orm/schemas/name-mapping/schema.ts index 5c27728b..97aa169a 100644 --- a/tests/e2e/orm/schemas/name-mapping/schema.ts +++ b/tests/e2e/orm/schemas/name-mapping/schema.ts @@ -27,6 +27,12 @@ export const schema = { unique: true, attributes: [{ name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("user_email") }] }, { name: "@unique" }] }, + role: { + name: "role", + type: "Role", + attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal("USER") }] }, { name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("user_role") }] }], + default: "USER" + }, posts: { name: "posts", type: "Post", @@ -82,6 +88,35 @@ export const schema = { } } }, + enums: { + Role: { + values: { + USER: "USER", + ADMIN: "ADMIN", + MODERATOR: "MODERATOR" + }, + fields: { + USER: { + name: "USER", + attributes: [ + { name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("role_user") }] } + ] + }, + ADMIN: { + name: "ADMIN", + attributes: [ + { name: "@map", args: [{ name: "name", value: ExpressionUtils.literal("role_admin") }] } + ] + }, + MODERATOR: { + name: "MODERATOR" + } + }, + attributes: [ + { name: "@@map", args: [{ name: "name", value: ExpressionUtils.literal("user_role") }] } + ] + } + }, authType: "User", plugins: {} } as const satisfies SchemaDef; diff --git a/tests/e2e/orm/schemas/name-mapping/schema.zmodel b/tests/e2e/orm/schemas/name-mapping/schema.zmodel index baddc94f..2dfbaeda 100644 --- a/tests/e2e/orm/schemas/name-mapping/schema.zmodel +++ b/tests/e2e/orm/schemas/name-mapping/schema.zmodel @@ -3,9 +3,17 @@ datasource db { url = "file:./dev.db" } +enum Role { + USER @map('role_user') + ADMIN @map('role_admin') + MODERATOR + @@map("user_role") +} + model User { id Int @id @default(autoincrement()) email String @map('user_email') @unique + role Role @default(USER) @map('user_role') posts Post[] @@map('users') } diff --git a/tests/e2e/orm/schemas/typing/models.ts b/tests/e2e/orm/schemas/typing/models.ts index b2fa673f..15eae9a9 100644 --- a/tests/e2e/orm/schemas/typing/models.ts +++ b/tests/e2e/orm/schemas/typing/models.ts @@ -15,7 +15,7 @@ export type Region = $ModelResult<$Schema, "Region">; export type Meta = $ModelResult<$Schema, "Meta">; export type Identity = $TypeDefResult<$Schema, "Identity">; export type IdentityProvider = $TypeDefResult<$Schema, "IdentityProvider">; -export const Role = $schema.enums.Role; +export const Role = $schema.enums.Role.values; export type Role = (typeof Role)[keyof typeof Role]; -export const Status = $schema.enums.Status; +export const Status = $schema.enums.Status.values; export type Status = (typeof Status)[keyof typeof Status]; diff --git a/tests/e2e/orm/schemas/typing/schema.ts b/tests/e2e/orm/schemas/typing/schema.ts index 10c4daa7..1a1212bd 100644 --- a/tests/e2e/orm/schemas/typing/schema.ts +++ b/tests/e2e/orm/schemas/typing/schema.ts @@ -328,13 +328,17 @@ export const schema = { }, enums: { Role: { - ADMIN: "ADMIN", - USER: "USER" + values: { + ADMIN: "ADMIN", + USER: "USER" + } }, Status: { - ACTIVE: "ACTIVE", - INACTIVE: "INACTIVE", - BANNED: "BANNED" + values: { + ACTIVE: "ACTIVE", + INACTIVE: "INACTIVE", + BANNED: "BANNED" + } } }, authType: "User",