Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions packages/cli/test/ts-schema-gen.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,63 @@ type Address with Base {
},
});
});

it('merges fields and attributes from base models', async () => {
const { schema } = await generateTsSchema(`
model Base {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
type String
@@delegate(type)
}

model User extends Base {
email String @unique
}
`);
expect(schema).toMatchObject({
models: {
Base: {
fields: {
id: {
type: 'String',
id: true,
default: expect.objectContaining({ function: 'uuid', kind: 'call' }),
},
createdAt: {
type: 'DateTime',
default: expect.objectContaining({ function: 'now', kind: 'call' }),
},
updatedAt: { type: 'DateTime', updatedAt: true },
type: { type: 'String' },
},
attributes: [
{
name: '@@delegate',
args: [{ name: 'discriminator', value: { kind: 'field', field: 'type' } }],
},
],
isDelegate: true,
},
User: {
baseModel: 'Base',
fields: {
id: { type: 'String' },
createdAt: {
type: 'DateTime',
default: expect.objectContaining({ function: 'now', kind: 'call' }),
originModel: 'Base',
},
updatedAt: { type: 'DateTime', updatedAt: true, originModel: 'Base' },
type: { type: 'String', originModel: 'Base' },
email: { type: 'String' },
},
uniqueFields: expect.objectContaining({
email: { type: 'String' },
}),
},
},
});
});
});
10 changes: 1 addition & 9 deletions packages/language/src/ast.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { AstNode } from 'langium';
import { AbstractDeclaration, BinaryExpr, DataModel, type ExpressionType } from './generated/ast';
import { AbstractDeclaration, BinaryExpr, DataField, DataModel, type ExpressionType } from './generated/ast';

export type { AstNode, Reference } from 'langium';
export * from './generated/ast';
Expand Down Expand Up @@ -46,14 +46,6 @@ declare module './ast' {
$resolvedParam?: AttributeParam;
}

interface DataField {
$inheritedFrom?: DataModel;
}

interface DataModelAttribute {
$inheritedFrom?: DataModel;
}

export interface DataModel {
/**
* All fields including those marked with `@ignore`
Expand Down
41 changes: 33 additions & 8 deletions packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,6 @@ export function resolved<T extends AstNode>(ref: Reference<T>): T {
return ref.ref;
}

export function getModelFieldsWithBases(model: DataModel, includeDelegate = true) {
return [...model.fields, ...getRecursiveBases(model, includeDelegate).flatMap((base) => base.fields)];
}

export function getRecursiveBases(
decl: DataModel | TypeDef,
includeDelegate = true,
Expand Down Expand Up @@ -533,22 +529,51 @@ export function isMemberContainer(node: unknown): node is DataModel | TypeDef {
return isDataModel(node) || isTypeDef(node);
}

export function getAllFields(decl: DataModel | TypeDef, includeIgnored = false): DataField[] {
export function getAllFields(
decl: DataModel | TypeDef,
includeIgnored = false,
seen: Set<DataModel | TypeDef> = new Set(),
): DataField[] {
if (seen.has(decl)) {
return [];
}
seen.add(decl);

const fields: DataField[] = [];
for (const mixin of decl.mixins) {
invariant(mixin.ref, `Mixin ${mixin.$refText} is not resolved`);
fields.push(...getAllFields(mixin.ref));
fields.push(...getAllFields(mixin.ref, includeIgnored, seen));
}

if (isDataModel(decl) && decl.baseModel) {
invariant(decl.baseModel.ref, `Base model ${decl.baseModel.$refText} is not resolved`);
fields.push(...getAllFields(decl.baseModel.ref, includeIgnored, seen));
}

fields.push(...decl.fields.filter((f) => includeIgnored || !hasAttribute(f, '@ignore')));
return fields;
}

export function getAllAttributes(decl: DataModel | TypeDef): DataModelAttribute[] {
export function getAllAttributes(
decl: DataModel | TypeDef,
seen: Set<DataModel | TypeDef> = new Set(),
): DataModelAttribute[] {
if (seen.has(decl)) {
return [];
}
seen.add(decl);

const attributes: DataModelAttribute[] = [];
for (const mixin of decl.mixins) {
invariant(mixin.ref, `Mixin ${mixin.$refText} is not resolved`);
attributes.push(...getAllAttributes(mixin.ref));
attributes.push(...getAllAttributes(mixin.ref, seen));
}

if (isDataModel(decl) && decl.baseModel) {
invariant(decl.baseModel.ref, `Base model ${decl.baseModel.$refText} is not resolved`);
attributes.push(...getAllAttributes(decl.baseModel.ref, seen));
}

attributes.push(...decl.attributes);
return attributes;
}
38 changes: 33 additions & 5 deletions packages/language/src/validators/datamodel-validator.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { invariant } from '@zenstackhq/common-helpers';
import { AstUtils, type AstNode, type DiagnosticInfo, type ValidationAcceptor } from 'langium';
import { IssueCodes, SCALAR_TYPES } from '../constants';
import {
Expand All @@ -16,8 +17,8 @@ import {
} from '../generated/ast';
import {
getAllAttributes,
getAllFields,
getLiteral,
getModelFieldsWithBases,
getModelIdFields,
getModelUniqueFields,
getUniqueFields,
Expand All @@ -32,7 +33,7 @@ import { validateDuplicatedDeclarations, type AstValidator } from './common';
*/
export default class DataModelValidator implements AstValidator<DataModel> {
validate(dm: DataModel, accept: ValidationAcceptor): void {
validateDuplicatedDeclarations(dm, getModelFieldsWithBases(dm), accept);
validateDuplicatedDeclarations(dm, getAllFields(dm), accept);
this.validateAttributes(dm, accept);
this.validateFields(dm, accept);
if (dm.mixins.length > 0) {
Expand All @@ -42,7 +43,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
}

private validateFields(dm: DataModel, accept: ValidationAcceptor) {
const allFields = getModelFieldsWithBases(dm);
const allFields = getAllFields(dm);
const idFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id'));
const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique'));
const modelLevelIds = getModelIdFields(dm);
Expand Down Expand Up @@ -266,7 +267,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
const oppositeModel = field.type.reference!.ref! as DataModel;

// Use name because the current document might be updated
let oppositeFields = getModelFieldsWithBases(oppositeModel, false).filter(
let oppositeFields = getAllFields(oppositeModel, false).filter(
(f) =>
f !== field && // exclude self in case of self relation
f.type.reference?.ref?.name === contextModel.name,
Expand Down Expand Up @@ -438,11 +439,38 @@ export default class DataModelValidator implements AstValidator<DataModel> {
if (!model.baseModel) {
return;
}
if (model.baseModel.ref && !isDelegateModel(model.baseModel.ref)) {

invariant(model.baseModel.ref, 'baseModel must be resolved');

// check if the base model is a delegate model
if (!isDelegateModel(model.baseModel.ref)) {
accept('error', `Model ${model.baseModel.$refText} cannot be extended because it's not a delegate model`, {
node: model,
property: 'baseModel',
});
return;
}

// check for cyclic inheritance
const seen: DataModel[] = [];
const todo = [model.baseModel.ref];
while (todo.length > 0) {
const current = todo.shift()!;
if (seen.includes(current)) {
accept(
'error',
`Cyclic inheritance detected: ${seen.map((m) => m.name).join(' -> ')} -> ${current.name}`,
{
node: model,
},
);
return;
}
seen.push(current);
if (current.baseModel) {
invariant(current.baseModel.ref, 'baseModel must be resolved');
todo.push(current.baseModel.ref);
}
}
}

Expand Down
9 changes: 4 additions & 5 deletions packages/language/src/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import {
AttributeParam,
BinaryExpr,
BooleanLiteral,
DataModel,
DataField,
DataFieldType,
DataModel,
Enum,
EnumField,
type ExpressionType,
Expand All @@ -43,19 +43,19 @@ import {
UnaryExpr,
isArrayExpr,
isBooleanLiteral,
isDataModel,
isDataField,
isDataFieldType,
isDataModel,
isEnum,
isNumberLiteral,
isReferenceExpr,
isStringLiteral,
} from './ast';
import {
getAllFields,
getAllLoadedAndReachableDataModelsAndTypeDefs,
getAuthDecl,
getContainingDataModel,
getModelFieldsWithBases,
isAuthInvocation,
isFutureExpr,
isMemberContainer,
Expand Down Expand Up @@ -397,8 +397,7 @@ export class ZModelLinker extends DefaultLinker {
const transitiveDataModel = attrAppliedOn.type.reference?.ref as DataModel;
if (transitiveDataModel) {
// resolve references in the context of the transitive data model
const scopeProvider = (name: string) =>
getModelFieldsWithBases(transitiveDataModel).find((f) => f.name === name);
const scopeProvider = (name: string) => getAllFields(transitiveDataModel).find((f) => f.name === name);
if (isArrayExpr(node.value)) {
node.value.items.forEach((item) => {
if (isReferenceExpr(item)) {
Expand Down
6 changes: 3 additions & 3 deletions packages/language/src/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import { match } from 'ts-pattern';
import {
BinaryExpr,
MemberAccessExpr,
isDataModel,
isDataField,
isDataModel,
isEnumField,
isInvocationExpr,
isMemberAccessExpr,
Expand All @@ -31,9 +31,9 @@ import {
} from './ast';
import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants';
import {
getAllFields,
getAllLoadedAndReachableDataModelsAndTypeDefs,
getAuthDecl,
getModelFieldsWithBases,
getRecursiveBases,
isAuthInvocation,
isCollectionPredicate,
Expand Down Expand Up @@ -231,7 +231,7 @@ export class ZModelScopeProvider extends DefaultScopeProvider {

private createScopeForContainer(node: AstNode | undefined, globalScope: Scope, includeTypeDefScope = false) {
if (isDataModel(node)) {
return this.createScopeForNodes(getModelFieldsWithBases(node), globalScope);
return this.createScopeForNodes(getAllFields(node), globalScope);
} else if (includeTypeDefScope && isTypeDef(node)) {
return this.createScopeForNodes(node.fields, globalScope);
} else {
Expand Down
Loading