diff --git a/packages/cli/test/ts-schema-gen.test.ts b/packages/cli/test/ts-schema-gen.test.ts index 2ec04048..cd34de58 100644 --- a/packages/cli/test/ts-schema-gen.test.ts +++ b/packages/cli/test/ts-schema-gen.test.ts @@ -266,4 +266,63 @@ type Address with Base { }, }); }); + + it('merges fields and attributes from base models', async () => { + const { schema } = await generateTsSchema(` +model Base { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + type String + @@delegate(type) +} + +model User extends Base { + email String @unique +} + `); + expect(schema).toMatchObject({ + models: { + Base: { + fields: { + id: { + type: 'String', + id: true, + default: expect.objectContaining({ function: 'uuid', kind: 'call' }), + }, + createdAt: { + type: 'DateTime', + default: expect.objectContaining({ function: 'now', kind: 'call' }), + }, + updatedAt: { type: 'DateTime', updatedAt: true }, + type: { type: 'String' }, + }, + attributes: [ + { + name: '@@delegate', + args: [{ name: 'discriminator', value: { kind: 'field', field: 'type' } }], + }, + ], + isDelegate: true, + }, + User: { + baseModel: 'Base', + fields: { + id: { type: 'String' }, + createdAt: { + type: 'DateTime', + default: expect.objectContaining({ function: 'now', kind: 'call' }), + originModel: 'Base', + }, + updatedAt: { type: 'DateTime', updatedAt: true, originModel: 'Base' }, + type: { type: 'String', originModel: 'Base' }, + email: { type: 'String' }, + }, + uniqueFields: expect.objectContaining({ + email: { type: 'String' }, + }), + }, + }, + }); + }); }); diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index c301eac4..71d31d4d 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -1,5 +1,5 @@ import type { AstNode } from 'langium'; -import { AbstractDeclaration, BinaryExpr, DataModel, type ExpressionType } from './generated/ast'; +import { AbstractDeclaration, BinaryExpr, DataField, DataModel, type ExpressionType } from './generated/ast'; export type { AstNode, Reference } from 'langium'; export * from './generated/ast'; @@ -46,14 +46,6 @@ declare module './ast' { $resolvedParam?: AttributeParam; } - interface DataField { - $inheritedFrom?: DataModel; - } - - interface DataModelAttribute { - $inheritedFrom?: DataModel; - } - export interface DataModel { /** * All fields including those marked with `@ignore` diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index f661bbc5..adb4f78f 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -161,10 +161,6 @@ export function resolved(ref: Reference): T { return ref.ref; } -export function getModelFieldsWithBases(model: DataModel, includeDelegate = true) { - return [...model.fields, ...getRecursiveBases(model, includeDelegate).flatMap((base) => base.fields)]; -} - export function getRecursiveBases( decl: DataModel | TypeDef, includeDelegate = true, @@ -533,22 +529,51 @@ export function isMemberContainer(node: unknown): node is DataModel | TypeDef { return isDataModel(node) || isTypeDef(node); } -export function getAllFields(decl: DataModel | TypeDef, includeIgnored = false): DataField[] { +export function getAllFields( + decl: DataModel | TypeDef, + includeIgnored = false, + seen: Set = new Set(), +): DataField[] { + if (seen.has(decl)) { + return []; + } + seen.add(decl); + const fields: DataField[] = []; for (const mixin of decl.mixins) { invariant(mixin.ref, `Mixin ${mixin.$refText} is not resolved`); - fields.push(...getAllFields(mixin.ref)); + fields.push(...getAllFields(mixin.ref, includeIgnored, seen)); + } + + if (isDataModel(decl) && decl.baseModel) { + invariant(decl.baseModel.ref, `Base model ${decl.baseModel.$refText} is not resolved`); + fields.push(...getAllFields(decl.baseModel.ref, includeIgnored, seen)); } + fields.push(...decl.fields.filter((f) => includeIgnored || !hasAttribute(f, '@ignore'))); return fields; } -export function getAllAttributes(decl: DataModel | TypeDef): DataModelAttribute[] { +export function getAllAttributes( + decl: DataModel | TypeDef, + seen: Set = new Set(), +): DataModelAttribute[] { + if (seen.has(decl)) { + return []; + } + seen.add(decl); + const attributes: DataModelAttribute[] = []; for (const mixin of decl.mixins) { invariant(mixin.ref, `Mixin ${mixin.$refText} is not resolved`); - attributes.push(...getAllAttributes(mixin.ref)); + attributes.push(...getAllAttributes(mixin.ref, seen)); + } + + if (isDataModel(decl) && decl.baseModel) { + invariant(decl.baseModel.ref, `Base model ${decl.baseModel.$refText} is not resolved`); + attributes.push(...getAllAttributes(decl.baseModel.ref, seen)); } + attributes.push(...decl.attributes); return attributes; } diff --git a/packages/language/src/validators/datamodel-validator.ts b/packages/language/src/validators/datamodel-validator.ts index 5d7f9919..49c8dfc7 100644 --- a/packages/language/src/validators/datamodel-validator.ts +++ b/packages/language/src/validators/datamodel-validator.ts @@ -1,3 +1,4 @@ +import { invariant } from '@zenstackhq/common-helpers'; import { AstUtils, type AstNode, type DiagnosticInfo, type ValidationAcceptor } from 'langium'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { @@ -16,8 +17,8 @@ import { } from '../generated/ast'; import { getAllAttributes, + getAllFields, getLiteral, - getModelFieldsWithBases, getModelIdFields, getModelUniqueFields, getUniqueFields, @@ -32,7 +33,7 @@ import { validateDuplicatedDeclarations, type AstValidator } from './common'; */ export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { - validateDuplicatedDeclarations(dm, getModelFieldsWithBases(dm), accept); + validateDuplicatedDeclarations(dm, getAllFields(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); if (dm.mixins.length > 0) { @@ -42,7 +43,7 @@ export default class DataModelValidator implements AstValidator { } private validateFields(dm: DataModel, accept: ValidationAcceptor) { - const allFields = getModelFieldsWithBases(dm); + const allFields = getAllFields(dm); const idFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique')); const modelLevelIds = getModelIdFields(dm); @@ -266,7 +267,7 @@ export default class DataModelValidator implements AstValidator { const oppositeModel = field.type.reference!.ref! as DataModel; // Use name because the current document might be updated - let oppositeFields = getModelFieldsWithBases(oppositeModel, false).filter( + let oppositeFields = getAllFields(oppositeModel, false).filter( (f) => f !== field && // exclude self in case of self relation f.type.reference?.ref?.name === contextModel.name, @@ -438,11 +439,38 @@ export default class DataModelValidator implements AstValidator { if (!model.baseModel) { return; } - if (model.baseModel.ref && !isDelegateModel(model.baseModel.ref)) { + + invariant(model.baseModel.ref, 'baseModel must be resolved'); + + // check if the base model is a delegate model + if (!isDelegateModel(model.baseModel.ref)) { accept('error', `Model ${model.baseModel.$refText} cannot be extended because it's not a delegate model`, { node: model, property: 'baseModel', }); + return; + } + + // check for cyclic inheritance + const seen: DataModel[] = []; + const todo = [model.baseModel.ref]; + while (todo.length > 0) { + const current = todo.shift()!; + if (seen.includes(current)) { + accept( + 'error', + `Cyclic inheritance detected: ${seen.map((m) => m.name).join(' -> ')} -> ${current.name}`, + { + node: model, + }, + ); + return; + } + seen.push(current); + if (current.baseModel) { + invariant(current.baseModel.ref, 'baseModel must be resolved'); + todo.push(current.baseModel.ref); + } } } diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 1867b368..65a2cb84 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -20,9 +20,9 @@ import { AttributeParam, BinaryExpr, BooleanLiteral, - DataModel, DataField, DataFieldType, + DataModel, Enum, EnumField, type ExpressionType, @@ -43,19 +43,19 @@ import { UnaryExpr, isArrayExpr, isBooleanLiteral, - isDataModel, isDataField, isDataFieldType, + isDataModel, isEnum, isNumberLiteral, isReferenceExpr, isStringLiteral, } from './ast'; import { + getAllFields, getAllLoadedAndReachableDataModelsAndTypeDefs, getAuthDecl, getContainingDataModel, - getModelFieldsWithBases, isAuthInvocation, isFutureExpr, isMemberContainer, @@ -397,8 +397,7 @@ export class ZModelLinker extends DefaultLinker { const transitiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; if (transitiveDataModel) { // resolve references in the context of the transitive data model - const scopeProvider = (name: string) => - getModelFieldsWithBases(transitiveDataModel).find((f) => f.name === name); + const scopeProvider = (name: string) => getAllFields(transitiveDataModel).find((f) => f.name === name); if (isArrayExpr(node.value)) { node.value.items.forEach((item) => { if (isReferenceExpr(item)) { diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index e95ac0b7..e2b58f02 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -19,8 +19,8 @@ import { match } from 'ts-pattern'; import { BinaryExpr, MemberAccessExpr, - isDataModel, isDataField, + isDataModel, isEnumField, isInvocationExpr, isMemberAccessExpr, @@ -31,9 +31,9 @@ import { } from './ast'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; import { + getAllFields, getAllLoadedAndReachableDataModelsAndTypeDefs, getAuthDecl, - getModelFieldsWithBases, getRecursiveBases, isAuthInvocation, isCollectionPredicate, @@ -231,7 +231,7 @@ export class ZModelScopeProvider extends DefaultScopeProvider { private createScopeForContainer(node: AstNode | undefined, globalScope: Scope, includeTypeDefScope = false) { if (isDataModel(node)) { - return this.createScopeForNodes(getModelFieldsWithBases(node), globalScope); + return this.createScopeForNodes(getAllFields(node), globalScope); } else if (includeTypeDefScope && isTypeDef(node)) { return this.createScopeForNodes(node.fields, globalScope); } else { diff --git a/packages/language/test/delegate.test.ts b/packages/language/test/delegate.test.ts new file mode 100644 index 00000000..185be2bc --- /dev/null +++ b/packages/language/test/delegate.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, it } from 'vitest'; +import { DataModel } from '../src/ast'; +import { loadSchema, loadSchemaWithError } from './utils'; + +describe('Delegate Tests', () => { + it('supports inheriting from delegate', async () => { + const model = await loadSchema(` + model A { + id Int @id @default(autoincrement()) + x String + @@delegate(x) + } + + model B extends A { + y String + } + `); + const a = model.declarations.find((d) => d.name === 'A') as DataModel; + expect(a.baseModel).toBeUndefined(); + const b = model.declarations.find((d) => d.name === 'B') as DataModel; + expect(b.baseModel?.ref).toBe(a); + }); + + it('rejects inheriting from non-delegate models', async () => { + await loadSchemaWithError( + ` + model A { + id Int @id @default(autoincrement()) + x String + } + + model B extends A { + y String + } + `, + 'not a delegate model', + ); + }); + + it('can detect cyclic inherits', async () => { + await loadSchemaWithError( + ` + model A extends B { + x String + @@delegate(x) + } + + model B extends A { + y String + @@delegate(y) + } + `, + 'cyclic', + ); + }); + + it('can detect duplicated fields from base model', async () => { + await loadSchemaWithError( + ` + model A { + id String @id + x String + @@delegate(x) + } + + model B extends A { + x String + } + `, + 'duplicated', + ); + }); + + it('can detect duplicated attributes from base model', async () => { + await loadSchemaWithError( + ` + model A { + id String @id + x String + @@id([x]) + @@delegate(x) + } + + model B extends A { + y String + @@id([y]) + } + `, + 'can only be applied once', + ); + }); +}); diff --git a/packages/runtime/src/client/constants.ts b/packages/runtime/src/client/constants.ts index c80a247a..746cb900 100644 --- a/packages/runtime/src/client/constants.ts +++ b/packages/runtime/src/client/constants.ts @@ -12,3 +12,8 @@ export const NUMERIC_FIELD_TYPES = ['Int', 'Float', 'BigInt', 'Decimal']; * Client API methods that are not supported in transactions. */ export const TRANSACTION_UNSUPPORTED_METHODS = ['$transaction', '$disconnect', '$use'] as const; + +/** + * Prefix for JSON field used to store joined delegate rows. + */ +export const DELEGATE_JOINED_FIELD_PREFIX = '$delegate$'; diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index daa154e9..7d4e6fb2 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -578,7 +578,7 @@ export abstract class BaseCrudDialect { private buildNumberFilter( eb: ExpressionBuilder, model: string, - table: string, + modelAlias: string, field: string, type: BuiltinType, payload: any, @@ -587,9 +587,9 @@ export abstract class BaseCrudDialect { eb, type, payload, - buildFieldRef(this.schema, model, field, this.options, eb), + buildFieldRef(this.schema, model, field, this.options, eb, modelAlias), (value) => this.transformPrimitive(value, type, false), - (value) => this.buildNumberFilter(eb, model, table, field, type, value), + (value) => this.buildNumberFilter(eb, model, modelAlias, field, type, value), ); return this.and(eb, ...conditions); } diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 64e4efee..a4bb5c11 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -22,7 +22,7 @@ import { ExpressionUtils, type GetModels, type ModelDef, type SchemaDef } from ' import { clone } from '../../../utils/clone'; import { enumerate } from '../../../utils/enumerate'; import { extractFields, fieldsToSelectObject } from '../../../utils/object-utils'; -import { CONTEXT_COMMENT_PREFIX, NUMERIC_FIELD_TYPES } from '../../constants'; +import { CONTEXT_COMMENT_PREFIX, DELEGATE_JOINED_FIELD_PREFIX, NUMERIC_FIELD_TYPES } from '../../constants'; import type { CRUD } from '../../contract'; import type { FindArgs, SelectIncludeOmit, SortOrder, WhereInput } from '../../crud-types'; import { InternalError, NotFoundError, QueryError } from '../../errors'; @@ -31,7 +31,9 @@ import { buildFieldRef, buildJoinPairs, ensureArray, + extractIdFields, flattenCompoundUniqueFilters, + getDiscriminatorField, getField, getIdFields, getIdValues, @@ -39,6 +41,7 @@ import { getModel, getRelationForeignKeyFieldPairs, isForeignKeyField, + isInheritedField, isRelationField, isScalarField, makeDefaultOrderBy, @@ -245,6 +248,7 @@ export abstract class BaseOperationHandler { parentAlias: string, ) { let result = query; + const joinedBases: string[] = []; for (const [field, payload] of Object.entries(selectOrInclude)) { if (!payload) { @@ -258,7 +262,7 @@ export abstract class BaseOperationHandler { const fieldDef = this.requireField(model, field); if (!fieldDef.relation) { - result = this.selectField(result, model, parentAlias, field); + result = this.selectField(result, model, parentAlias, field, joinedBases); } else { if (!fieldDef.array && !fieldDef.optional && payload.where) { throw new QueryError(`Field "${field}" doesn't support filtering`); @@ -334,21 +338,95 @@ export abstract class BaseOperationHandler { omit?: Record, ) { const modelDef = this.requireModel(model); - return Object.keys(modelDef.fields) - .filter((f) => !isRelationField(this.schema, model, f)) - .filter((f) => omit?.[f] !== true) - .reduce((acc, f) => this.selectField(acc, model, model, f), query); + let result = query; + const joinedBases: string[] = []; + + for (const field of Object.keys(modelDef.fields)) { + if (isRelationField(this.schema, model, field)) { + continue; + } + if (omit?.[field] === true) { + continue; + } + result = this.selectField(result, model, model, field, joinedBases); + } + + // select all fields from delegate descendants and pack into a JSON field `$delegate$Model` + const descendants = this.getDelegateDescendantModels(model); + for (const subModel of descendants) { + if (!joinedBases.includes(subModel.name)) { + joinedBases.push(subModel.name); + result = this.buildDelegateJoin(model, subModel.name, result); + } + result = result.select((eb) => { + const jsonObject: Record> = {}; + for (const field of Object.keys(subModel.fields)) { + if ( + isRelationField(this.schema, subModel.name, field) || + isInheritedField(this.schema, subModel.name, field) + ) { + continue; + } + jsonObject[field] = eb.ref(`${subModel.name}.${field}`); + } + return this.dialect + .buildJsonObject(eb, jsonObject) + .as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`); + }); + } + + return result; + } + + private getDelegateDescendantModels(model: string, collected: Set = new Set()): ModelDef[] { + const subModels = Object.values(this.schema.models).filter((m) => m.baseModel === model); + subModels.forEach((def) => { + if (!collected.has(def)) { + collected.add(def); + this.getDelegateDescendantModels(def.name, collected); + } + }); + return [...collected]; } - private selectField(query: SelectQueryBuilder, model: string, modelAlias: string, field: string) { + private selectField( + query: SelectQueryBuilder, + model: string, + modelAlias: string, + field: string, + joinedBases: string[], + ) { const fieldDef = this.requireField(model, field); - if (!fieldDef.computed) { + + if (fieldDef.computed) { + // TODO: computed field from delegate base? + return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field)); + } else if (!fieldDef.originModel) { + // regular field return query.select(sql.ref(`${modelAlias}.${field}`).as(field)); } else { - return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field)); + // field from delegate base, build a join + let result = query; + if (!joinedBases.includes(fieldDef.originModel)) { + joinedBases.push(fieldDef.originModel); + result = this.buildDelegateJoin(model, fieldDef.originModel, result); + } + result = this.selectField(result, fieldDef.originModel, fieldDef.originModel, field, joinedBases); + return result; } } + private buildDelegateJoin(thisModel: string, otherModel: string, query: SelectQueryBuilder) { + const idFields = getIdFields(this.schema, thisModel); + query = query.leftJoin(otherModel, (qb) => { + for (const idField of idFields) { + qb = qb.onRef(`${thisModel}.${idField}`, '=', `${otherModel}.${idField}`); + } + return qb; + }); + return query; + } + private buildCursorFilter( model: string, query: SelectQueryBuilder, @@ -399,7 +477,7 @@ export abstract class BaseOperationHandler { fromRelation?: FromRelationContext, ): Promise { const modelDef = this.requireModel(model); - const createFields: any = {}; + let createFields: any = {}; let parentUpdateTask: ((entity: any) => Promise) | undefined = undefined; let m2m: ReturnType = undefined; @@ -489,6 +567,12 @@ export abstract class BaseOperationHandler { } } + // create delegate base model entity + if (modelDef.baseModel) { + const baseCreateResult = await this.processBaseModelCreate(kysely, modelDef.baseModel, createFields, model); + createFields = baseCreateResult.remainingFields; + } + const updatedData = this.fillGeneratedValues(modelDef, createFields); const idFields = getIdFields(this.schema, model); const query = kysely @@ -547,6 +631,33 @@ export abstract class BaseOperationHandler { return createdEntity; } + private async processBaseModelCreate(kysely: ToKysely, model: string, createFields: any, forModel: string) { + const thisCreateFields: any = {}; + const remainingFields: any = {}; + + Object.entries(createFields).forEach(([field, value]) => { + const fieldDef = this.getField(model, field); + if (fieldDef) { + thisCreateFields[field] = value; + } else { + remainingFields[field] = value; + } + }); + + const discriminatorField = getDiscriminatorField(this.schema, model); + invariant(discriminatorField, `Base model "${model}" must have a discriminator field`); + thisCreateFields[discriminatorField] = forModel; + + // create base model entity + const createResult = await this.create(kysely, model as GetModels, thisCreateFields); + + // copy over id fields from base model + const idValues = extractIdFields(createResult, this.schema, model); + Object.assign(remainingFields, idValues); + + return { baseEntity: createResult, remainingFields }; + } + private buildFkAssignments(model: string, relationField: string, entity: any) { const parentFkFields: any = {}; @@ -848,7 +959,11 @@ export abstract class BaseOperationHandler { private fillGeneratedValues(modelDef: ModelDef, data: object) { const fields = modelDef.fields; const values: any = clone(data); - for (const field in fields) { + for (const [field, fieldDef] of Object.entries(fields)) { + if (fieldDef.originModel) { + // skip fields from delegate base + continue; + } if (!(field in data)) { if (typeof fields[field]?.default === 'object' && 'kind' in fields[field].default) { const generated = this.evalGenerator(fields[field].default); diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index bc15bb36..e097d475 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -4,9 +4,15 @@ import type { GetModels, SchemaDef } from '../../../schema'; import type { CreateArgs, CreateManyAndReturnArgs, CreateManyArgs, WhereInput } from '../../crud-types'; import { getIdValues } from '../../query-utils'; import { BaseOperationHandler } from './base'; +import { QueryError } from '../../errors'; export class CreateOperationHandler extends BaseOperationHandler { async handle(operation: 'create' | 'createMany' | 'createManyAndReturn', args: unknown | undefined) { + const modelDef = this.requireModel(this.model); + if (modelDef.isDelegate) { + throw new QueryError(`Model "${this.model}" is a delegate and cannot be created directly.`); + } + // normalize args to strip `undefined` fields const normalizedArgs = this.normalizeArgs(args); diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index cad8e953..c4c7a9d1 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -3,7 +3,7 @@ import Decimal from 'decimal.js'; import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; import { z, ZodType } from 'zod'; -import type { BuiltinType, EnumDef, FieldDef, GetModels, SchemaDef } from '../../schema'; +import { type BuiltinType, type EnumDef, type FieldDef, type GetModels, type SchemaDef } from '../../schema'; import { NUMERIC_FIELD_TYPES } from '../constants'; import { type AggregateArgs, @@ -21,7 +21,15 @@ import { type UpsertArgs, } from '../crud-types'; import { InputValidationError, InternalError, QueryError } from '../errors'; -import { fieldHasDefaultValue, getEnum, getModel, getUniqueFields, requireField, requireModel } from '../query-utils'; +import { + fieldHasDefaultValue, + getDiscriminatorField, + getEnum, + getModel, + getUniqueFields, + requireField, + requireModel, +} from '../query-utils'; type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; @@ -705,6 +713,11 @@ export class InputValidator { return; } + if (this.isDelegateDiscriminator(fieldDef)) { + // discriminator field is auto-assigned + return; + } + if (fieldDef.relation) { if (withoutRelationFields) { return; @@ -791,6 +804,15 @@ export class InputValidator { } } + private isDelegateDiscriminator(fieldDef: FieldDef) { + if (!fieldDef.originModel) { + // not inherited from a delegate + return false; + } + const discriminatorField = getDiscriminatorField(this.schema, fieldDef.originModel); + return discriminatorField === fieldDef.name; + } + private makeRelationManipulationSchema(fieldDef: FieldDef, withoutFields: string[], mode: 'create' | 'update') { const fieldType = fieldDef.type; const array = !!fieldDef.array; diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 2f341673..8c47b895 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -1,5 +1,5 @@ import type { ExpressionBuilder, ExpressionWrapper } from 'kysely'; -import type { FieldDef, GetModels, SchemaDef } from '../schema'; +import { ExpressionUtils, type FieldDef, type GetModels, type SchemaDef } from '../schema'; import type { OrderBy } from './crud-types'; import { InternalError, QueryError } from './errors'; import type { ClientOptions } from './options'; @@ -111,6 +111,11 @@ export function isRelationField(schema: SchemaDef, model: string, field: string) return !!fieldDef.relation; } +export function isInheritedField(schema: SchemaDef, model: string, field: string): boolean { + const fieldDef = requireField(schema, model, field); + return !!fieldDef.originModel; +} + export function getUniqueFields(schema: SchemaDef, model: string) { const modelDef = requireModel(schema, model); const result: Array< @@ -276,3 +281,35 @@ export function safeJSONStringify(value: unknown) { } }); } + +export function extractFields(object: any, fields: string[]) { + return fields.reduce((acc: any, field) => { + if (field in object) { + acc[field] = object[field]; + } + return acc; + }, {}); +} + +export function extractIdFields(entity: any, schema: SchemaDef, model: string) { + const idFields = getIdFields(schema, model); + return idFields.reduce((acc: any, field) => { + if (field in entity) { + acc[field] = entity[field]; + } + return acc; + }, {}); +} + +export function getDiscriminatorField(schema: SchemaDef, model: string) { + const modelDef = requireModel(schema, model); + const delegateAttr = modelDef.attributes?.find((attr) => attr.name === '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const discriminator = delegateAttr.args?.find((arg) => arg.name === 'discriminator'); + if (!discriminator || !ExpressionUtils.isField(discriminator.value)) { + throw new InternalError(`Discriminator field not defined for model "${model}"`); + } + return discriminator.value.field; +} diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index a43e4648..c7aa230a 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -2,7 +2,8 @@ import { invariant } from '@zenstackhq/common-helpers'; import Decimal from 'decimal.js'; import { match } from 'ts-pattern'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../schema'; -import { ensureArray, getField } from './query-utils'; +import { DELEGATE_JOINED_FIELD_PREFIX } from './constants'; +import { ensureArray, getField, getIdValues } from './query-utils'; export class ResultProcessor { constructor(private readonly schema: Schema) {} @@ -38,6 +39,29 @@ export class ResultProcessor { continue; } + if (key.startsWith(DELEGATE_JOINED_FIELD_PREFIX)) { + // merge delegate descendant fields + if (value) { + // descendant fields are packed as JSON + const subRow = this.transformJson(value); + + // process the sub-row + const subModel = key.slice(DELEGATE_JOINED_FIELD_PREFIX.length) as GetModels; + const idValues = getIdValues(this.schema, subModel, subRow); + if (Object.values(idValues).some((v) => v === null || v === undefined)) { + // if the row doesn't have a valid id, the joined row doesn't exist + delete data[key]; + continue; + } + const processedSubRow = this.processRow(subRow, subRow); + + // merge the sub-row into the main row + Object.assign(data, processedSubRow); + } + delete data[key]; + continue; + } + const fieldDef = getField(this.schema, model, key); if (!fieldDef) { continue; diff --git a/packages/runtime/test/client-api/default-values.test.ts b/packages/runtime/test/client-api/default-values.test.ts index d5c93fbd..cafedad1 100644 --- a/packages/runtime/test/client-api/default-values.test.ts +++ b/packages/runtime/test/client-api/default-values.test.ts @@ -12,37 +12,46 @@ const schema = { }, models: { Model: { + name: 'Model', fields: { uuid: { + name: 'uuid', type: 'String', id: true, default: ExpressionUtils.call('uuid'), }, uuid7: { + name: 'uuid7', type: 'String', default: ExpressionUtils.call('uuid', [ExpressionUtils.literal(7)]), }, cuid: { + name: 'cuid', type: 'String', default: ExpressionUtils.call('cuid'), }, cuid2: { + name: 'cuid2', type: 'String', default: ExpressionUtils.call('cuid', [ExpressionUtils.literal(2)]), }, nanoid: { + name: 'nanoid', type: 'String', default: ExpressionUtils.call('nanoid'), }, nanoid8: { + name: 'nanoid8', type: 'String', default: ExpressionUtils.call('nanoid', [ExpressionUtils.literal(8)]), }, ulid: { + name: 'ulid', type: 'String', default: ExpressionUtils.call('ulid'), }, dt: { + name: 'dt', type: 'DateTime', default: ExpressionUtils.call('now'), }, diff --git a/packages/runtime/test/client-api/delegate.test.ts b/packages/runtime/test/client-api/delegate.test.ts new file mode 100644 index 00000000..b1e4fc83 --- /dev/null +++ b/packages/runtime/test/client-api/delegate.test.ts @@ -0,0 +1,206 @@ +import { describe, expect, it } from 'vitest'; +import { createTestClient } from '../utils'; + +describe('Delegate model tests', () => { + const POLYMORPHIC_SCHEMA = ` +model User { + id Int @id @default(autoincrement()) + email String? @unique + level Int @default(0) + assets Asset[] + ratedVideos RatedVideo[] @relation('direct') +} + +model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + viewCount Int @default(0) + owner User? @relation(fields: [ownerId], references: [id]) + ownerId Int? + assetType String + + @@delegate(assetType) +} + +model Video extends Asset { + duration Int + url String + videoType String + + @@delegate(videoType) +} + +model RatedVideo extends Video { + rating Int + user User? @relation(name: 'direct', fields: [userId], references: [id]) + userId Int? +} + +model Image extends Asset { + format String + gallery Gallery? @relation(fields: [galleryId], references: [id]) + galleryId Int? +} + +model Gallery { + id Int @id @default(autoincrement()) + images Image[] +} +`; + + it('works with create', async () => { + const client = await createTestClient(POLYMORPHIC_SCHEMA, { + usePrismaPush: true, + }); + + // delegate model cannot be created directly + await expect( + client.video.create({ + data: { + duration: 100, + url: 'abc', + videoType: 'MyVideo', + }, + }), + ).rejects.toThrow('is a delegate'); + + // create entity with two levels of delegation + await expect( + client.ratedVideo.create({ + data: { + duration: 100, + url: 'abc', + rating: 5, + }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + duration: 100, + url: 'abc', + rating: 5, + assetType: 'Video', + videoType: 'RatedVideo', + }); + + // create entity with relation + await expect( + client.ratedVideo.create({ + data: { + duration: 50, + url: 'bcd', + rating: 5, + user: { create: { email: 'u1@example.com' } }, + }, + include: { user: true }, + }), + ).resolves.toMatchObject({ + userId: expect.any(Number), + user: { + email: 'u1@example.com', + }, + }); + + // create entity with one level of delegation + await expect( + client.image.create({ + data: { + format: 'png', + gallery: { + create: {}, + }, + }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + format: 'png', + galleryId: expect.any(Number), + assetType: 'Image', + }); + }); + + it('works with find', async () => { + const client = await createTestClient(POLYMORPHIC_SCHEMA, { + usePrismaPush: true, + log: ['query'], + }); + + const u = await client.user.create({ + data: { + email: 'u1@example.com', + }, + }); + const v = await client.ratedVideo.create({ + data: { + duration: 100, + url: 'abc', + rating: 5, + user: { connect: { id: u.id } }, + }, + include: { user: true }, + }); + + const ratedVideoContent = { + id: v.id, + createdAt: expect.any(Date), + duration: 100, + rating: 5, + assetType: 'Video', + videoType: 'RatedVideo', + }; + + // include all base fields + await expect( + client.ratedVideo.findUnique({ + where: { id: v.id }, + include: { user: true }, + }), + ).resolves.toMatchObject({ ...ratedVideoContent, user: expect.any(Object) }); + + // select fields + await expect( + client.ratedVideo.findUnique({ + where: { id: v.id }, + select: { + id: true, + viewCount: true, + url: true, + rating: true, + }, + }), + ).resolves.toEqual({ + id: v.id, + viewCount: 0, + url: 'abc', + rating: 5, + }); + + // omit fields + const r = await client.ratedVideo.findUnique({ + where: { id: v.id }, + omit: { + viewCount: true, + url: true, + rating: true, + }, + }); + expect(r.viewCount).toBeUndefined(); + expect(r.url).toBeUndefined(); + expect(r.rating).toBeUndefined(); + expect(r.duration).toEqual(expect.any(Number)); + + // include all sub fields + await expect( + client.video.findUnique({ + where: { id: v.id }, + }), + ).resolves.toMatchObject(ratedVideoContent); + + // include all sub fields + await expect( + client.asset.findUnique({ + where: { id: v.id }, + }), + ).resolves.toMatchObject(ratedVideoContent); + }); +}); diff --git a/packages/runtime/test/client-api/mixin.test.ts b/packages/runtime/test/client-api/mixin.test.ts index 23655f05..ffbdbf2f 100644 --- a/packages/runtime/test/client-api/mixin.test.ts +++ b/packages/runtime/test/client-api/mixin.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it } from 'vitest'; import { createTestClient } from '../utils'; -describe('Client API Mixins', () => { +describe('Mixin tests', () => { it('includes fields and attributes from mixins', async () => { const schema = ` type TimeStamped { diff --git a/packages/runtime/test/client-api/name-mapping.test.ts b/packages/runtime/test/client-api/name-mapping.test.ts index 7c7ca42d..ded45ad0 100644 --- a/packages/runtime/test/client-api/name-mapping.test.ts +++ b/packages/runtime/test/client-api/name-mapping.test.ts @@ -10,13 +10,16 @@ describe('Name mapping tests', () => { }, models: { Foo: { + name: 'Foo', fields: { id: { + name: 'id', type: 'String', id: true, default: ExpressionUtils.call('uuid'), }, x: { + name: 'x', type: 'Int', attributes: [ { diff --git a/packages/runtime/test/test-schema/schema.ts b/packages/runtime/test/test-schema/schema.ts index db61c902..e2c8fb52 100644 --- a/packages/runtime/test/test-schema/schema.ts +++ b/packages/runtime/test/test-schema/schema.ts @@ -12,43 +12,52 @@ export const schema = { }, models: { User: { + name: "User", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, email: { + name: "email", type: "String", unique: true, attributes: [{ name: "@unique" }] }, name: { + name: "name", type: "String", optional: true }, role: { + name: "role", type: "Role", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal("USER") }] }], default: "USER" }, posts: { + name: "posts", type: "Post", array: true, relation: { opposite: "author" } }, profile: { + name: "profile", type: "Profile", optional: true, relation: { opposite: "user" } @@ -65,47 +74,57 @@ export const schema = { } }, Post: { + name: "Post", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, title: { + name: "title", type: "String" }, content: { + name: "content", type: "String", optional: true }, published: { + name: "published", type: "Boolean", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal(false) }] }], default: false }, author: { + name: "author", type: "User", attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("authorId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }, { name: "onUpdate", value: ExpressionUtils.literal("Cascade") }, { name: "onDelete", value: ExpressionUtils.literal("Cascade") }] }], relation: { opposite: "posts", fields: ["authorId"], references: ["id"], onUpdate: "Cascade", onDelete: "Cascade" } }, authorId: { + name: "authorId", type: "String", foreignKeyFor: [ "author" ] }, comments: { + name: "comments", type: "Comment", array: true, relation: { opposite: "post" } @@ -122,33 +141,40 @@ export const schema = { } }, Comment: { + name: "Comment", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, content: { + name: "content", type: "String" }, post: { + name: "post", type: "Post", optional: true, attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("postId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }, { name: "onUpdate", value: ExpressionUtils.literal("Cascade") }, { name: "onDelete", value: ExpressionUtils.literal("Cascade") }] }], relation: { opposite: "comments", fields: ["postId"], references: ["id"], onUpdate: "Cascade", onDelete: "Cascade" } }, postId: { + name: "postId", type: "String", optional: true, foreignKeyFor: [ @@ -162,37 +188,45 @@ export const schema = { } }, Profile: { + name: "Profile", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, bio: { + name: "bio", type: "String" }, age: { + name: "age", type: "Int", optional: true }, user: { + name: "user", type: "User", optional: true, attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("userId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }, { name: "onUpdate", value: ExpressionUtils.literal("Cascade") }, { name: "onDelete", value: ExpressionUtils.literal("Cascade") }] }], relation: { opposite: "profile", fields: ["userId"], references: ["id"], onUpdate: "Cascade", onDelete: "Cascade" } }, userId: { + name: "userId", type: "String", unique: true, optional: true, @@ -211,18 +245,22 @@ export const schema = { }, typeDefs: { CommonFields: { + name: "CommonFields", fields: { id: { + name: "id", type: "String", attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] diff --git a/packages/runtime/test/typing/schema.ts b/packages/runtime/test/typing/schema.ts index 49bf584e..de56a9dc 100644 --- a/packages/runtime/test/typing/schema.ts +++ b/packages/runtime/test/typing/schema.ts @@ -12,52 +12,63 @@ export const schema = { }, models: { User: { + name: "User", fields: { id: { + name: "id", type: "Int", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], default: ExpressionUtils.call("autoincrement") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, name: { + name: "name", type: "String" }, email: { + name: "email", type: "String", unique: true, attributes: [{ name: "@unique" }] }, role: { + name: "role", type: "Role", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal("USER") }] }], default: "USER" }, posts: { + name: "posts", type: "Post", array: true, relation: { opposite: "author" } }, profile: { + name: "profile", type: "Profile", optional: true, relation: { opposite: "user" } }, postCount: { + name: "postCount", type: "Int", attributes: [{ name: "@computed" }], computed: true }, identity: { + name: "identity", type: "Identity", optional: true, attributes: [{ name: "@json" }] @@ -75,36 +86,44 @@ export const schema = { } }, Post: { + name: "Post", fields: { id: { + name: "id", type: "Int", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], default: ExpressionUtils.call("autoincrement") }, title: { + name: "title", type: "String" }, content: { + name: "content", type: "String" }, author: { + name: "author", type: "User", attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("authorId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], relation: { opposite: "posts", fields: ["authorId"], references: ["id"] } }, authorId: { + name: "authorId", type: "Int", foreignKeyFor: [ "author" ] }, tags: { + name: "tags", type: "Tag", array: true, relation: { opposite: "posts" } }, meta: { + name: "meta", type: "Meta", optional: true, relation: { opposite: "post" } @@ -116,23 +135,28 @@ export const schema = { } }, Profile: { + name: "Profile", fields: { id: { + name: "id", type: "Int", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], default: ExpressionUtils.call("autoincrement") }, age: { + name: "age", type: "Int" }, region: { + name: "region", type: "Region", optional: true, attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("regionCountry"), ExpressionUtils.field("regionCity")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("country"), ExpressionUtils.field("city")]) }] }], relation: { opposite: "profiles", fields: ["regionCountry", "regionCity"], references: ["country", "city"] } }, regionCountry: { + name: "regionCountry", type: "String", optional: true, foreignKeyFor: [ @@ -140,6 +164,7 @@ export const schema = { ] }, regionCity: { + name: "regionCity", type: "String", optional: true, foreignKeyFor: [ @@ -147,11 +172,13 @@ export const schema = { ] }, user: { + name: "user", type: "User", attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("userId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], relation: { opposite: "profile", fields: ["userId"], references: ["id"] } }, userId: { + name: "userId", type: "Int", unique: true, attributes: [{ name: "@unique" }], @@ -167,17 +194,21 @@ export const schema = { } }, Tag: { + name: "Tag", fields: { id: { + name: "id", type: "Int", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], default: ExpressionUtils.call("autoincrement") }, name: { + name: "name", type: "String" }, posts: { + name: "posts", type: "Post", array: true, relation: { opposite: "tags" } @@ -189,20 +220,25 @@ export const schema = { } }, Region: { + name: "Region", fields: { country: { + name: "country", type: "String", id: true }, city: { + name: "city", type: "String", id: true }, zip: { + name: "zip", type: "String", optional: true }, profiles: { + name: "profiles", type: "Profile", array: true, relation: { opposite: "region" } @@ -217,25 +253,31 @@ export const schema = { } }, Meta: { + name: "Meta", fields: { id: { + name: "id", type: "Int", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("autoincrement") }] }], default: ExpressionUtils.call("autoincrement") }, reviewed: { + name: "reviewed", type: "Boolean" }, published: { + name: "published", type: "Boolean" }, post: { + name: "post", type: "Post", attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("postId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], relation: { opposite: "meta", fields: ["postId"], references: ["id"] } }, postId: { + name: "postId", type: "Int", unique: true, attributes: [{ name: "@unique" }], @@ -253,19 +295,24 @@ export const schema = { }, typeDefs: { Identity: { + name: "Identity", fields: { providers: { + name: "providers", type: "IdentityProvider", array: true } } }, IdentityProvider: { + name: "IdentityProvider", fields: { id: { + name: "id", type: "String" }, name: { + name: "name", type: "String", optional: true } diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index 14532938..3ab4a01e 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -109,3 +109,14 @@ export function getAuthDecl(model: Model) { } return found; } + +export function getIdFields(dm: DataModel) { + return getAllFields(dm) + .filter((f) => isIdField(f, dm)) + .map((f) => f.name); +} + +/** + * Prefix for auxiliary relation fields generated for delegated models + */ +export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux'; diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index aa4c9172..116ee872 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -1,3 +1,4 @@ +import { lowerCaseFirst } from '@zenstackhq/common-helpers'; import { AttributeArg, BooleanLiteral, @@ -16,6 +17,7 @@ import { GeneratorDecl, InvocationExpr, isArrayExpr, + isDataModel, isInvocationExpr, isLiteralExpr, isModel, @@ -29,12 +31,13 @@ import { StringLiteral, type AstNode, } from '@zenstackhq/language/ast'; +import { getAllAttributes, getAllFields, isDelegateModel } from '@zenstackhq/language/utils'; import { AstUtils } from 'langium'; import { match } from 'ts-pattern'; - -import { getAllAttributes, getAllFields } from '@zenstackhq/language/utils'; import { ModelUtils, ZModelCodeGenerator } from '..'; +import { DELEGATE_AUX_RELATION_PREFIX, getIdFields } from '../model-utils'; import { + AttributeArgValue, ModelFieldType, AttributeArg as PrismaAttributeArg, AttributeArgValue as PrismaAttributeArgValue, @@ -51,6 +54,10 @@ import { type SimpleField, } from './prisma-builder'; +// Some database providers like postgres and mysql have default limit to the length of identifiers +// Here we use a conservative value that should work for most cases, and truncate names if needed +const IDENTIFIER_NAME_MAX_LENGTH = 50 - DELEGATE_AUX_RELATION_PREFIX.length; + /** * Generates Prisma schema file */ @@ -62,6 +69,9 @@ export class PrismaSchemaGenerator { `; + // a mapping from full names to shortened names + private shortNameMap = new Map(); + constructor(private readonly zmodel: Model) {} async generate() { @@ -155,8 +165,10 @@ export class PrismaSchemaGenerator { if (ModelUtils.hasAttribute(field, '@computed')) { continue; // skip computed fields } - // TODO: exclude fields inherited from delegate - this.generateModelField(model, field, decl); + // exclude non-id fields inherited from delegate + if (ModelUtils.isIdField(field, decl) || !this.isInheritedFromDelegate(field, decl)) { + this.generateModelField(model, field, decl); + } } const allAttributes = getAllAttributes(decl); @@ -167,21 +179,11 @@ export class PrismaSchemaGenerator { // user defined comments pass-through decl.comments.forEach((c) => model.addComment(c)); - // TODO: delegate model handling - // // physical: generate relation fields on base models linking to concrete models - // this.generateDelegateRelationForBase(model, decl); - - // TODO: delegate model handling - // // physical: generate reverse relation fields on concrete models - // this.generateDelegateRelationForConcrete(model, decl); + // generate relation fields on base models linking to concrete models + this.generateDelegateRelationForBase(model, decl); - // TODO: delegate model handling - // // logical: expand relations on other models that reference delegated models to concrete models - // this.expandPolymorphicRelations(model, decl); - - // TODO: delegate model handling - // // logical: ensure relations inherited from delegate models - // this.ensureRelationsInheritedFromDelegate(model, decl); + // generate reverse relation fields on concrete models + this.generateDelegateRelationForConcrete(model, decl); } private isPrismaAttribute(attr: DataModelAttribute | DataFieldAttribute) { @@ -247,7 +249,7 @@ export class PrismaSchemaGenerator { // when building physical schema, exclude `@default` for id fields inherited from delegate base !( ModelUtils.isIdField(field, contextModel) && - this.isInheritedFromDelegate(field) && + this.isInheritedFromDelegate(field, contextModel) && attr.decl.$refText === '@default' ), ) @@ -276,8 +278,8 @@ export class PrismaSchemaGenerator { return !!model && !!model.$document && model.$document.uri.path.endsWith('plugin.zmodel'); } - private isInheritedFromDelegate(field: DataField) { - return field.$inheritedFrom && ModelUtils.isDelegateModel(field.$inheritedFrom); + private isInheritedFromDelegate(field: DataField, contextModel: DataModel) { + return field.$container !== contextModel && ModelUtils.isDelegateModel(field.$container); } private makeFieldAttribute(attr: DataFieldAttribute) { @@ -375,4 +377,100 @@ export class PrismaSchemaGenerator { const docs = [...field.comments]; _enum.addField(field.name, attributes, docs); } + + private generateDelegateRelationForBase(model: PrismaDataModel, decl: DataModel) { + if (!isDelegateModel(decl)) { + return; + } + + // collect concrete models inheriting this model + const concreteModels = this.getConcreteModels(decl); + + // generate an optional relation field in delegate base model to each concrete model + concreteModels.forEach((concrete) => { + const auxName = this.truncate(`${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}`); + model.addField(auxName, new ModelFieldType(concrete.name, false, true)); + }); + } + + private generateDelegateRelationForConcrete(model: PrismaDataModel, concreteDecl: DataModel) { + // generate a relation field for each delegated base model + const base = concreteDecl.baseModel?.ref; + if (!base) { + return; + } + + const idFields = getIdFields(base); + + // add relation fields + const relationField = this.truncate(`${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(base.name)}`); + model.addField(relationField, base.name, [ + new PrismaFieldAttribute('@relation', [ + new PrismaAttributeArg( + 'fields', + new AttributeArgValue( + 'Array', + idFields.map( + (idField) => new AttributeArgValue('FieldReference', new PrismaFieldReference(idField)), + ), + ), + ), + new PrismaAttributeArg( + 'references', + new AttributeArgValue( + 'Array', + idFields.map( + (idField) => new AttributeArgValue('FieldReference', new PrismaFieldReference(idField)), + ), + ), + ), + new PrismaAttributeArg( + 'onDelete', + new AttributeArgValue('FieldReference', new PrismaFieldReference('Cascade')), + ), + new PrismaAttributeArg( + 'onUpdate', + new AttributeArgValue('FieldReference', new PrismaFieldReference('Cascade')), + ), + ]), + ]); + } + + private getConcreteModels(dataModel: DataModel): DataModel[] { + if (!isDelegateModel(dataModel)) { + return []; + } + return dataModel.$container.declarations.filter( + (d): d is DataModel => isDataModel(d) && d !== dataModel && d.baseModel?.ref === dataModel, + ); + } + + private truncate(name: string) { + if (name.length <= IDENTIFIER_NAME_MAX_LENGTH) { + return name; + } + + const existing = this.shortNameMap.get(name); + if (existing) { + return existing; + } + + const baseName = name.slice(0, IDENTIFIER_NAME_MAX_LENGTH); + let index = 0; + let shortName = `${baseName}_${index}`; + + while (true) { + const conflict = Array.from(this.shortNameMap.values()).find((v) => v === shortName); + if (!conflict) { + this.shortNameMap.set(name, shortName); + break; + } + + // try next index + index++; + shortName = `${baseName}_${index}`; + } + + return shortName; + } } diff --git a/packages/sdk/src/schema/schema.ts b/packages/sdk/src/schema/schema.ts index 208024a8..d7a38f9e 100644 --- a/packages/sdk/src/schema/schema.ts +++ b/packages/sdk/src/schema/schema.ts @@ -18,6 +18,8 @@ export type SchemaDef = { }; export type ModelDef = { + name: string; + baseModel?: string; fields: Record; attributes?: AttributeApplication[]; uniqueFields: Record< @@ -29,6 +31,7 @@ export type ModelDef = { >; idFields: string[]; computedFields?: Record; + isDelegate?: boolean; }; export type AttributeApplication = { @@ -53,6 +56,7 @@ export type RelationInfo = { }; export type FieldDef = { + name: string; type: string; id?: boolean; array?: boolean; @@ -64,6 +68,7 @@ export type FieldDef = { relation?: RelationInfo; foreignKeyFor?: string[]; computed?: boolean; + originModel?: string; }; export type ProcedureParam = { name: string; type: string; optional?: boolean }; @@ -91,6 +96,7 @@ export type MappedBuiltinType = string | boolean | number | bigint | Decimal | D export type EnumDef = Record; export type TypeDefDef = { + name: string; fields: Record; attributes?: AttributeApplication[]; }; diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 91e04112..e5202b32 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -42,7 +42,15 @@ import path from 'node:path'; import { match } from 'ts-pattern'; import * as ts from 'typescript'; import { ModelUtils } from '.'; -import { getAttribute, getAuthDecl, hasAttribute, isUniqueField } from './model-utils'; +import { + getAttribute, + getAuthDecl, + getIdFields, + hasAttribute, + isDelegateModel, + isIdField, + isUniqueField, +} from './model-utils'; export class TsSchemaGenerator { public async generate(schemaFile: string, pluginModelFiles: string[], outputDir: string) { @@ -213,9 +221,28 @@ export class TsSchemaGenerator { private createDataModelObject(dm: DataModel) { const allFields = getAllFields(dm); - const allAttributes = getAllAttributes(dm); + const allAttributes = getAllAttributes(dm).filter((attr) => { + // exclude `@@delegate` attribute from base model + if (attr.decl.$refText === '@@delegate' && attr.$container !== dm) { + return false; + } + return true; + }); const fields: ts.PropertyAssignment[] = [ + // name + ts.factory.createPropertyAssignment('name', ts.factory.createStringLiteral(dm.name)), + + // baseModel + ...(dm.baseModel + ? [ + ts.factory.createPropertyAssignment( + 'baseModel', + ts.factory.createStringLiteral(dm.baseModel.$refText), + ), + ] + : []), + // fields ts.factory.createPropertyAssignment( 'fields', @@ -244,12 +271,17 @@ export class TsSchemaGenerator { ts.factory.createPropertyAssignment( 'idFields', ts.factory.createArrayLiteralExpression( - this.getIdFields(dm).map((idField) => ts.factory.createStringLiteral(idField)), + getIdFields(dm).map((idField) => ts.factory.createStringLiteral(idField)), ), ), // uniqueFields ts.factory.createPropertyAssignment('uniqueFields', this.createUniqueFieldsObject(dm)), + + // isDelegate + ...(isDelegateModel(dm) + ? [ts.factory.createPropertyAssignment('isDelegate', ts.factory.createTrue())] + : []), ]; const computedFields = dm.fields.filter((f) => hasAttribute(f, '@computed')); @@ -268,6 +300,9 @@ export class TsSchemaGenerator { const allAttributes = getAllAttributes(td); const fields: ts.PropertyAssignment[] = [ + // name + ts.factory.createPropertyAssignment('name', ts.factory.createStringLiteral(td.name)), + // fields ts.factory.createPropertyAssignment( 'fields', @@ -344,7 +379,28 @@ export class TsSchemaGenerator { } private createDataFieldObject(field: DataField, contextModel: DataModel | undefined) { - const objectFields = [ts.factory.createPropertyAssignment('type', this.generateFieldTypeLiteral(field))]; + const objectFields = [ + // name + ts.factory.createPropertyAssignment('name', ts.factory.createStringLiteral(field.name)), + // type + ts.factory.createPropertyAssignment('type', this.generateFieldTypeLiteral(field)), + ]; + + if ( + contextModel && + // id fields are duplicated in inherited models + !isIdField(field, contextModel) && + field.$container !== contextModel && + isDelegateModel(field.$container) + ) { + // field is inherited from delegate + objectFields.push( + ts.factory.createPropertyAssignment( + 'originModel', + ts.factory.createStringLiteral(field.$container.name), + ), + ); + } if (contextModel && ModelUtils.isIdField(field, contextModel)) { objectFields.push(ts.factory.createPropertyAssignment('id', ts.factory.createTrue())); @@ -668,12 +724,6 @@ export class TsSchemaGenerator { return undefined; } - private getIdFields(dm: DataModel) { - return getAllFields(dm) - .filter((f) => ModelUtils.isIdField(f, dm)) - .map((f) => f.name); - } - private createUniqueFieldsObject(dm: DataModel) { const properties: ts.PropertyAssignment[] = []; diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index c1ad9a73..64515be7 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -12,48 +12,58 @@ export const schema = { }, models: { User: { + name: "User", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, email: { + name: "email", type: "String", unique: true, attributes: [{ name: "@unique" }] }, name: { + name: "name", type: "String", optional: true }, postCount: { + name: "postCount", type: "Int", attributes: [{ name: "@computed" }], computed: true }, role: { + name: "role", type: "Role", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal("USER") }] }], default: "USER" }, posts: { + name: "posts", type: "Post", array: true, relation: { opposite: "author" } }, profile: { + name: "profile", type: "Profile", optional: true, relation: { opposite: "user" } @@ -71,38 +81,46 @@ export const schema = { } }, Profile: { + name: "Profile", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, bio: { + name: "bio", type: "String", optional: true }, age: { + name: "age", type: "Int", optional: true }, user: { + name: "user", type: "User", optional: true, attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("userId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], relation: { opposite: "profile", fields: ["userId"], references: ["id"] } }, userId: { + name: "userId", type: "String", unique: true, optional: true, @@ -119,40 +137,49 @@ export const schema = { } }, Post: { + name: "Post", fields: { id: { + name: "id", type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }] }, title: { + name: "title", type: "String" }, content: { + name: "content", type: "String" }, published: { + name: "published", type: "Boolean", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.literal(false) }] }], default: false }, author: { + name: "author", type: "User", attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array([ExpressionUtils.field("authorId")]) }, { name: "references", value: ExpressionUtils.array([ExpressionUtils.field("id")]) }] }], relation: { opposite: "posts", fields: ["authorId"], references: ["id"] } }, authorId: { + name: "authorId", type: "String", foreignKeyFor: [ "author" @@ -167,18 +194,22 @@ export const schema = { }, typeDefs: { CommonFields: { + name: "CommonFields", fields: { id: { + name: "id", type: "String", attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], default: ExpressionUtils.call("cuid") }, createdAt: { + name: "createdAt", type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: ExpressionUtils.call("now") }] }], default: ExpressionUtils.call("now") }, updatedAt: { + name: "updatedAt", type: "DateTime", updatedAt: true, attributes: [{ name: "@updatedAt" }]