From 678134f34f2cde95001d2d96ceeee1c7e5d617a8 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:00:21 +0800 Subject: [PATCH 1/3] refactor: simplify zmodel linking by improving scope computation; make AST cloning from base models more robust --- packages/language/src/ast.ts | 20 +- packages/schema/src/cli/cli-util.ts | 2 +- .../validator/datamodel-validator.ts | 40 ++-- .../src/language-server/validator/utils.ts | 4 +- .../src/language-server/zmodel-code-action.ts | 11 +- .../src/language-server/zmodel-linker.ts | 65 +----- .../src/language-server/zmodel-scope.ts | 213 +++++++++++------- packages/schema/src/utils/ast-utils.ts | 74 ++++-- .../validation/attribute-validation.test.ts | 5 +- packages/schema/tests/utils.ts | 4 +- packages/testtools/src/model.ts | 4 +- 11 files changed, 249 insertions(+), 193 deletions(-) diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index c8637115a..86dd55bed 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -1,7 +1,8 @@ -import { AbstractDeclaration, ExpressionType, BinaryExpr } from './generated/ast'; +import { AstNode } from 'langium'; +import { AbstractDeclaration, BinaryExpr, DataModel, ExpressionType } from './generated/ast'; -export * from './generated/ast'; export { AstNode, Reference } from 'langium'; +export * from './generated/ast'; /** * Shape of type resolution result: an expression type or reference to a declaration @@ -44,18 +45,19 @@ declare module './generated/ast' { $resolvedParam?: AttributeParam; } - interface DataModel { - /** - * Resolved fields, include inherited fields - */ - $resolvedFields: Array; + interface DataModelField { + $inheritedFrom?: DataModel; } - interface DataModelField { - $isInherited?: boolean; + interface DataModelAttribute { + $inheritedFrom?: DataModel; } } +export interface InheritableNode extends AstNode { + $inheritedFrom?: DataModel; +} + declare module 'langium' { export interface AstNode { /** diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 000e92ca7..2cfa18fcb 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -89,7 +89,7 @@ export async function loadDocument(fileName: string): Promise { validationAfterMerge(model); - mergeBaseModel(model); + mergeBaseModel(model, services.references.Linker); return model; } diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index ce1886f5e..dd03e6cbf 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -8,6 +8,7 @@ import { } from '@zenstackhq/language/ast'; import { analyzePolicies, getLiteral, getModelIdFields, getModelUniqueFields } from '@zenstackhq/sdk'; import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; +import { getModelFieldsWithBases } from '../../utils/ast-utils'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getUniqueFields } from '../utils'; @@ -20,16 +21,15 @@ import { validateDuplicatedDeclarations } from './utils'; export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { this.validateBaseAbstractModel(dm, accept); - validateDuplicatedDeclarations(dm.$resolvedFields, accept); + validateDuplicatedDeclarations(getModelFieldsWithBases(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); } private validateFields(dm: DataModel, accept: ValidationAcceptor) { - const idFields = dm.$resolvedFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); - const uniqueFields = dm.$resolvedFields.filter((f) => - f.attributes.find((attr) => attr.decl.ref?.name === '@unique') - ); + const allFields = getModelFieldsWithBases(dm); + const idFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); + const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique')); const modelLevelIds = getModelIdFields(dm); const modelUniqueFields = getModelUniqueFields(dm); @@ -42,7 +42,7 @@ export default class DataModelValidator implements AstValidator { const { allows, denies, hasFieldValidation } = analyzePolicies(dm); if (allows.length > 0 || denies.length > 0 || hasFieldValidation) { // TODO: relax this requirement to require only @unique fields - // when access policies or field valdaition is used, require an @id field + // when access policies or field validation is used, require an @id field accept( 'error', 'Model must include a field with @id or @unique attribute, or a model-level @@id or @@unique attribute to use access policies', @@ -74,10 +74,10 @@ export default class DataModelValidator implements AstValidator { dm.fields.forEach((field) => this.validateField(field, accept)); if (!dm.isAbstract) { - dm.$resolvedFields + allFields .filter((x) => isDataModel(x.type.reference?.ref)) .forEach((y) => { - this.validateRelationField(y, accept); + this.validateRelationField(dm, y, accept); }); } } @@ -194,7 +194,7 @@ export default class DataModelValidator implements AstValidator { // points back const oppositeModel = field.type.reference?.ref as DataModel; if (oppositeModel) { - const oppositeModelFields = oppositeModel.$resolvedFields as DataModelField[]; + const oppositeModelFields = getModelFieldsWithBases(oppositeModel); for (const oppositeField of oppositeModelFields) { // find the opposite relation with the matching name const relAttr = oppositeField.attributes.find((a) => a.decl.ref?.name === '@relation'); @@ -213,7 +213,7 @@ export default class DataModelValidator implements AstValidator { return false; } - private validateRelationField(field: DataModelField, accept: ValidationAcceptor) { + private validateRelationField(contextModel: DataModel, field: DataModelField, accept: ValidationAcceptor) { const thisRelation = this.parseRelation(field, accept); if (!thisRelation.valid) { return; @@ -223,8 +223,8 @@ export default class DataModelValidator implements AstValidator { const oppositeModel = field.type.reference!.ref! as DataModel; // Use name because the current document might be updated - let oppositeFields = oppositeModel.$resolvedFields.filter( - (f) => f.type.reference?.ref?.name === field.$container.name + let oppositeFields = getModelFieldsWithBases(oppositeModel).filter( + (f) => f.type.reference?.ref?.name === contextModel.name ); oppositeFields = oppositeFields.filter((f) => { const fieldRel = this.parseRelation(f); @@ -232,13 +232,13 @@ export default class DataModelValidator implements AstValidator { }); if (oppositeFields.length === 0) { - const node = field.$isInherited ? field.$container : field; - const info: DiagnosticInfo = { node, code: IssueCodes.MissingOppositeRelation }; + const info: DiagnosticInfo = { + node: field, + code: IssueCodes.MissingOppositeRelation, + }; info.property = 'name'; - // use cstNode because the field might be inherited from parent model - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const container = field.$cstNode!.element.$container as DataModel; + const container = field.$container; const relationFieldDocUri = getDocument(container).textDocument.uri; const relationDataModelName = container.name; @@ -247,20 +247,20 @@ export default class DataModelValidator implements AstValidator { relationFieldName: field.name, relationDataModelName, relationFieldDocUri, - dataModelName: field.$container.name, + dataModelName: contextModel.name, }; info.data = data; accept( 'error', - `The relation field "${field.name}" on model "${field.$container.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, + `The relation field "${field.name}" on model "${contextModel.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, info ); return; } else if (oppositeFields.length > 1) { oppositeFields - .filter((x) => !x.$isInherited) + .filter((x) => !x.$inheritedFrom) .forEach((f) => { if (this.isSelfRelation(f)) { // self relations are partial diff --git a/packages/schema/src/language-server/validator/utils.ts b/packages/schema/src/language-server/validator/utils.ts index 50e2263d7..340f471b8 100644 --- a/packages/schema/src/language-server/validator/utils.ts +++ b/packages/schema/src/language-server/validator/utils.ts @@ -33,8 +33,8 @@ export function validateDuplicatedDeclarations( for (const [name, decls] of Object.entries(groupByName)) { if (decls.length > 1) { let errorField = decls[1]; - if (decls[0].$type === 'DataModelField') { - const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$isInherited); + if (isDataModelField(decls[0])) { + const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$inheritedFrom); if (nonInheritedFields.length > 0) { errorField = nonInheritedFields.slice(-1)[0]; } diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts index aace4d0fe..5b6a6c95a 100644 --- a/packages/schema/src/language-server/zmodel-code-action.ts +++ b/packages/schema/src/language-server/zmodel-code-action.ts @@ -2,18 +2,19 @@ import { DataModel, DataModelField, Model, isDataModel } from '@zenstackhq/langu import { AstReflection, CodeActionProvider, - getDocument, IndexManager, LangiumDocument, LangiumDocuments, LangiumServices, MaybePromise, + getDocument, } from 'langium'; import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver'; +import { getModelFieldsWithBases } from '../utils/ast-utils'; import { IssueCodes } from './constants'; -import { ZModelFormatter } from './zmodel-formatter'; import { MissingOppositeRelationData } from './validator/datamodel-validator'; +import { ZModelFormatter } from './zmodel-formatter'; export class ZModelCodeActionProvider implements CodeActionProvider { protected readonly reflection: AstReflection; @@ -92,8 +93,8 @@ export class ZModelCodeActionProvider implements CodeActionProvider { let newText = ''; if (fieldAstNode.type.array) { - //post Post[] - const idField = container.$resolvedFields.find((f) => + // post Post[] + const idField = getModelFieldsWithBases(container).find((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id') ) as DataModelField; @@ -111,7 +112,7 @@ export class ZModelCodeActionProvider implements CodeActionProvider { const idFieldName = idField.name; const referenceIdFieldName = fieldName + this.upperCaseFirstLetter(idFieldName); - if (!oppositeModel.$resolvedFields.find((f) => f.name === referenceIdFieldName)) { + if (!getModelFieldsWithBases(oppositeModel).find((f) => f.name === referenceIdFieldName)) { referenceField = '\n' + indent + `${referenceIdFieldName} ${idField.type.type}`; } diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index ef97cf4b6..ccfd9411a 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -55,8 +55,8 @@ import { CancellationToken } from 'vscode-jsonrpc'; import { getAllDeclarationsFromImports, getContainingDataModel, + getModelFieldsWithBases, isAuthInvocation, - isCollectionPredicate, } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; @@ -261,26 +261,9 @@ export class ZModelLinker extends DefaultLinker { } private resolveReference(node: ReferenceExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.linkReference(node, 'target', document, extraScopes); - node.args.forEach((arg) => this.resolve(arg, document, extraScopes)); + this.resolveDefault(node, document, extraScopes); if (node.target.ref) { - // if the reference is inside the RHS of a collection predicate, it cannot be resolve to a field - // not belonging to the collection's model type - - const collectionPredicateContext = this.getCollectionPredicateContextDataModel(node); - if ( - // inside a collection predicate RHS - collectionPredicateContext && - // current ref expr is resolved to a field - isDataModelField(node.target.ref) && - // the resolved field doesn't belong to the collection predicate's operand's type - node.target.ref.$container !== collectionPredicateContext - ) { - this.unresolvableRefExpr(node); - return; - } - // resolve type if (node.target.ref.$type === EnumField) { this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container); @@ -290,26 +273,6 @@ export class ZModelLinker extends DefaultLinker { } } - private getCollectionPredicateContextDataModel(node: ReferenceExpr) { - let curr: AstNode | undefined = node; - while (curr) { - if ( - curr.$container && - // parent is a collection predicate - isCollectionPredicate(curr.$container) && - // the collection predicate's LHS is resolved to a DataModel - isDataModel(curr.$container.left.$resolvedType?.decl) && - // current node is the RHS - curr.$containerProperty === 'right' - ) { - // return the resolved type of LHS - return curr.$container.left.$resolvedType?.decl; - } - curr = curr.$container; - } - return undefined; - } - private resolveArray(node: ArrayExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { node.items.forEach((item) => this.resolve(item, document, extraScopes)); @@ -372,14 +335,11 @@ export class ZModelLinker extends DefaultLinker { document: LangiumDocument, extraScopes: ScopeProvider[] ) { - this.resolve(node.operand, document, extraScopes); + this.resolveDefault(node, document, extraScopes); const operandResolved = node.operand.$resolvedType; if (operandResolved && !operandResolved.array && isDataModel(operandResolved.decl)) { - const modelDecl = operandResolved.decl as DataModel; - const provider = (name: string) => modelDecl.$resolvedFields.find((f) => f.name === name); // member access is resolved only in the context of the operand type - this.linkReference(node, 'member', document, [provider], true); if (node.member.ref) { this.resolveToDeclaredType(node, node.member.ref.type); @@ -393,20 +353,10 @@ export class ZModelLinker extends DefaultLinker { } private resolveCollectionPredicate(node: BinaryExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.resolve(node.left, document, extraScopes); + this.resolveDefault(node, document, extraScopes); const resolvedType = node.left.$resolvedType; if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) { - const dataModelDecl = resolvedType.decl; - const provider = (name: string) => { - if (name === 'this') { - return dataModelDecl; - } else { - return dataModelDecl.$resolvedFields.find((f) => f.name === name); - } - }; - extraScopes = [provider, ...extraScopes]; - this.resolve(node.right, document, extraScopes); this.resolveToBuiltinTypeOrDecl(node, 'Boolean'); } else { // error is reported in validation pass @@ -460,10 +410,11 @@ export class ZModelLinker extends DefaultLinker { // // In model B, the attribute argument "myId" is resolved to the field "myId" in model A - const transtiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; - if (transtiveDataModel) { + const transitiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; + if (transitiveDataModel) { // resolve references in the context of the transitive data model - const scopeProvider = (name: string) => transtiveDataModel.$resolvedFields.find((f) => f.name === name); + const scopeProvider = (name: string) => + getModelFieldsWithBases(transitiveDataModel).find((f) => f.name === name); if (isArrayExpr(node.value)) { node.value.items.forEach((item) => { if (isReferenceExpr(item)) { diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index 8eda869e8..f0d346b36 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -1,7 +1,6 @@ import { - DataModel, + BinaryExpr, MemberAccessExpr, - Model, isDataModel, isDataModelField, isEnumField, @@ -9,6 +8,7 @@ import { isMemberAccessExpr, isModel, isReferenceExpr, + isThisExpr, } from '@zenstackhq/language/ast'; import { getAuthModel, getDataModels } from '@zenstackhq/sdk'; import { @@ -19,7 +19,6 @@ import { EMPTY_SCOPE, LangiumDocument, LangiumServices, - Mutable, PrecomputedScopes, ReferenceInfo, Scope, @@ -30,8 +29,16 @@ import { stream, streamAllContents, } from 'langium'; +import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { resolveImportUri } from '../utils/ast-utils'; +import { + getModelFieldsWithBases, + getRecursiveBases, + isAuthInvocation, + isCollectionPredicate, + isFutureInvocation, + resolveImportUri, +} from '../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; /** @@ -66,49 +73,18 @@ export class ZModelScopeComputation extends DefaultScopeComputation { return result; } - override computeLocalScopes( - document: LangiumDocument, - cancelToken?: CancellationToken | undefined - ): Promise { - const result = super.computeLocalScopes(document, cancelToken); - - //the $resolvedFields would be used in Linking stage for all the documents - //so we need to set it at the end of the scope computation - this.resolveBaseModels(document); - return result; - } - - private resolveBaseModels(document: LangiumDocument) { - const model = document.parseResult.value as Model; - - model.declarations.forEach((decl) => { - if (decl.$type === 'DataModel') { - const dataModel = decl as DataModel; - dataModel.$resolvedFields = [...dataModel.fields]; - this.getRecursiveSuperTypes(dataModel).forEach((superType) => { - superType.fields.forEach((field) => { - const cloneField = Object.assign({}, field); - cloneField.$isInherited = true; - const mutable = cloneField as Mutable; - // update container - mutable.$container = dataModel; - dataModel.$resolvedFields.push(cloneField); - }); - }); - } - }); - } + override processNode(node: AstNode, document: LangiumDocument, scopes: PrecomputedScopes) { + super.processNode(node, document, scopes); - private getRecursiveSuperTypes(dataModel: DataModel): DataModel[] { - const result: DataModel[] = []; - dataModel.superTypes.forEach((superType) => { - const superTypeDecl = superType.ref; - if (superTypeDecl) { - result.push(superTypeDecl); - result.push(...this.getRecursiveSuperTypes(superTypeDecl)); + if (isDataModel(node)) { + // add base fields to the scope recursively + const bases = getRecursiveBases(node); + for (const base of bases) { + for (const field of base.fields) { + scopes.add(node, this.descriptions.createDescription(field, this.nameProvider.getName(field))); + } } - }); - return result; + } } } @@ -140,50 +116,129 @@ export class ZModelScopeProvider extends DefaultScopeProvider { override getScope(context: ReferenceInfo): Scope { if (isMemberAccessExpr(context.container) && context.container.operand && context.property === 'member') { - return this.getMemberAccessScope(context.container); + return this.getMemberAccessScope(context); + } + + if (isReferenceExpr(context.container) && context.property === 'target') { + // when reference expression is resolved inside a collection predicate, the scope is the collection + const containerCollectionPredicate = getCollectionPredicateContext(context.container); + if (containerCollectionPredicate) { + return this.getCollectionPredicateScope(context, containerCollectionPredicate); + } } + return super.getScope(context); } - private getMemberAccessScope(node: MemberAccessExpr) { - if (isReferenceExpr(node.operand)) { - // scope to target model's fields - const ref = node.operand.target.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - if (isDataModel(targetModel)) { - return this.createScopeForNodes(targetModel.fields); + private getMemberAccessScope(context: ReferenceInfo) { + const referenceType = this.reflection.getReferenceType(context); + const globalScope = this.getGlobalScope(referenceType, context); + const node = context.container as MemberAccessExpr; + + return match(node.operand) + .when(isReferenceExpr, (operand) => { + // operand is a reference, it can only be a model field + const ref = operand.target.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); } - } - } else if (isMemberAccessExpr(node.operand)) { - // scope to target model's fields - const ref = node.operand.member.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - if (isDataModel(targetModel)) { - return this.createScopeForNodes(targetModel.fields); + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (operand) => { + // operand is a member access, it must be resolved to a + const ref = operand.member.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); } - } - } else if (isInvocationExpr(node.operand)) { - // deal with member access from `auth()` and `future() - const funcName = node.operand.function.$refText; - if (funcName === 'auth') { - // resolve to `User` or `@@auth` model - const model = getContainerOfType(node, isModel); - if (model) { - const authModel = getAuthModel(getDataModels(model)); - if (authModel) { - return this.createScopeForNodes(authModel.fields); - } + return EMPTY_SCOPE; + }) + .when(isThisExpr, () => { + // operand is `this`, resolve to the containing model + return this.createScopeForContainingModel(node, globalScope); + }) + .when(isInvocationExpr, (operand) => { + // deal with member access from `auth()` and `future() + if (isAuthInvocation(operand)) { + // resolve to `User` or `@@auth` model + return this.createScopeForAuthModel(node, globalScope); } - } - if (funcName === 'future') { - const thisModel = getContainerOfType(node, isDataModel); - if (thisModel) { - return this.createScopeForNodes(thisModel.fields); + if (isFutureInvocation(operand)) { + // resolve `future()` to the containing model + return this.createScopeForContainingModel(node, globalScope); } + return EMPTY_SCOPE; + }) + .otherwise(() => EMPTY_SCOPE); + } + + private getCollectionPredicateScope(context: ReferenceInfo, collectionPredicate: BinaryExpr) { + const referenceType = this.reflection.getReferenceType(context); + const globalScope = this.getGlobalScope(referenceType, context); + const collection = collectionPredicate.left; + + return match(collection) + .when(isReferenceExpr, (expr) => { + // collection is a reference, it can only be a model field + const ref = expr.target.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); + } + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (expr) => { + // collection is a member access, it can only be resolved to a model field + const ref = expr.member.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); + } + return EMPTY_SCOPE; + }) + .when(isAuthInvocation, (expr) => { + return this.createScopeForAuthModel(expr, globalScope); + }) + .otherwise(() => EMPTY_SCOPE); + } + + private createScopeForContainingModel(node: AstNode, globalScope: Scope) { + const model = getContainerOfType(node, isDataModel); + if (model) { + return this.createScopeForNodes(model.fields, globalScope); + } else { + return EMPTY_SCOPE; + } + } + + private createScopeForModel(node: AstNode | undefined, globalScope: Scope) { + if (isDataModel(node)) { + return this.createScopeForNodes(getModelFieldsWithBases(node), globalScope); + } else { + return EMPTY_SCOPE; + } + } + + private createScopeForAuthModel(node: AstNode, globalScope: Scope) { + const model = getContainerOfType(node, isModel); + if (model) { + const authModel = getAuthModel(getDataModels(model)); + if (authModel) { + return this.createScopeForNodes(authModel.fields, globalScope); } } return EMPTY_SCOPE; } } + +function getCollectionPredicateContext(node: AstNode) { + let curr: AstNode | undefined = node; + while (curr) { + if (curr.$container && isCollectionPredicate(curr.$container) && curr.$containerProperty === 'right') { + return curr.$container; + } + curr = curr.$container; + } + return undefined; +} diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 661f14b26..3956a58bf 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -3,6 +3,7 @@ import { DataModel, DataModelField, Expression, + InheritableNode, isArrayExpr, isBinaryExpr, isDataModel, @@ -16,7 +17,17 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import { isFromStdlib } from '@zenstackhq/sdk'; -import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; +import { + AstNode, + copyAstNode, + CstNode, + getContainerOfType, + getDocument, + LangiumDocuments, + Linker, + Mutable, + Reference, +} from 'langium'; import { URI, Utils } from 'vscode-uri'; export function extractDataModelsWithAllowRules(model: Model): DataModel[] { @@ -25,7 +36,16 @@ export function extractDataModelsWithAllowRules(model: Model): DataModel[] { ) as DataModel[]; } -export function mergeBaseModel(model: Model) { +type BuildReference = ( + node: AstNode, + property: string, + refNode: CstNode | undefined, + refText: string +) => Reference; + +export function mergeBaseModel(model: Model, linker: Linker) { + const buildReference = linker.buildReference.bind(linker); + model.declarations .filter((x) => x.$type === 'DataModel') .forEach((decl) => { @@ -33,12 +53,15 @@ export function mergeBaseModel(model: Model) { dataModel.fields = dataModel.superTypes // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => updateContainer(superType.ref!.fields, dataModel)) + .flatMap((superType) => superType.ref!.fields) + .map((f) => cloneAst(f, dataModel, buildReference)) .concat(dataModel.fields); dataModel.attributes = dataModel.superTypes // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => updateContainer(superType.ref!.attributes, dataModel)) + // .flatMap((superType) => updateContainer(superType.ref!.attributes, dataModel)) + .flatMap((superType) => superType.ref!.attributes) + .map((attr) => cloneAst(attr, dataModel, buildReference)) .concat(dataModel.attributes); }); @@ -46,18 +69,20 @@ export function mergeBaseModel(model: Model) { model.declarations = model.declarations.filter((x) => !(x.$type == 'DataModel' && x.isAbstract)); } -function updateContainer(nodes: T[], container: AstNode): Mutable[] { - return nodes.map((node) => { - const cloneField = Object.assign({}, node); - const mutable = cloneField as Mutable; - // update container - mutable.$container = container; - return mutable; - }); +// deep clone an AST, relink references, and set its container +function cloneAst( + node: T, + newContainer: AstNode, + buildReference: BuildReference +): Mutable { + const clone = copyAstNode(node, buildReference) as Mutable; + clone.$container = newContainer; + clone.$inheritedFrom = getContainerOfType(node, isDataModel); + return clone; } export function getIdFields(dataModel: DataModel) { - const fieldLevelId = dataModel.$resolvedFields.find((f) => + const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') ); if (fieldLevelId) { @@ -83,6 +108,10 @@ export function isAuthInvocation(node: AstNode) { return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } +export function isFutureInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +} + export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { return expr.target.ref; @@ -157,7 +186,6 @@ export function isCollectionPredicate(node: AstNode): node is BinaryExpr { return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator); } - export function getContainingDataModel(node: Expression): DataModel | undefined { let curr: AstNode | undefined = node.$container; while (curr) { @@ -167,4 +195,20 @@ export function getContainingDataModel(node: Expression): DataModel | undefined curr = curr.$container; } return undefined; -} \ No newline at end of file +} + +export function getModelFieldsWithBases(model: DataModel) { + return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)]; +} + +export function getRecursiveBases(dataModel: DataModel): DataModel[] { + const result: DataModel[] = []; + dataModel.superTypes.forEach((superType) => { + const baseDecl = superType.ref; + if (baseDecl) { + result.push(baseDecl); + result.push(...getRecursiveBases(baseDecl)); + } + }); + return result; +} diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 8b7886334..757e158a0 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1059,11 +1059,14 @@ describe('Attribute tests', () => { model A { id String @id x Int + b B @relation(references: [id], fields: [bId]) + bId String @unique } model B { id String @id - a A + a A? + aId String @unique @@allow('all', a?[x > 0]) } `) diff --git a/packages/schema/tests/utils.ts b/packages/schema/tests/utils.ts index f88aae6e2..4dcd45170 100644 --- a/packages/schema/tests/utils.ts +++ b/packages/schema/tests/utils.ts @@ -16,7 +16,7 @@ export class SchemaLoadingError extends Error { export async function loadModel(content: string, validate = true, verbose = true, mergeBase = true) { const { name: docPath } = tmp.fileSync({ postfix: '.zmodel' }); fs.writeFileSync(docPath, content); - const { shared } = createZModelServices(NodeFileSystem); + const { shared, ZModel } = createZModelServices(NodeFileSystem); const stdLib = shared.workspace.LangiumDocuments.getOrCreateDocument( URI.file(path.resolve(__dirname, '../../schema/src/res/stdlib.zmodel')) ); @@ -52,7 +52,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; if (mergeBase) { - mergeBaseModel(model); + mergeBaseModel(model, ZModel.references.Linker); } return model; diff --git a/packages/testtools/src/model.ts b/packages/testtools/src/model.ts index 4be8a1613..29b15467d 100644 --- a/packages/testtools/src/model.ts +++ b/packages/testtools/src/model.ts @@ -16,7 +16,7 @@ export class SchemaLoadingError extends Error { export async function loadModel(content: string, validate = true, verbose = true) { const { name: docPath } = tmp.fileSync({ postfix: '.zmodel' }); fs.writeFileSync(docPath, content); - const { shared } = createZModelServices(NodeFileSystem); + const { shared, ZModel } = createZModelServices(NodeFileSystem); const stdLib = shared.workspace.LangiumDocuments.getOrCreateDocument( URI.file(path.resolve(__dirname, '../../schema/src/res/stdlib.zmodel')) ); @@ -51,7 +51,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; - mergeBaseModel(model); + mergeBaseModel(model, ZModel.references.Linker); return model; } From 29815803ce9edc2dca569a42d57f331edffe9d48 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 23 Jan 2024 19:31:57 +0800 Subject: [PATCH 2/3] fix tests --- packages/schema/src/utils/ast-utils.ts | 42 +++++++++++++++++-- .../tests/regression/issues.test.ts | 28 ++++++------- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 3956a58bf..2af62f4f2 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -19,11 +19,14 @@ import { import { isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, - copyAstNode, CstNode, + GenericAstNode, getContainerOfType, getDocument, + isAstNode, + isReference, LangiumDocuments, + linkContentToContainer, Linker, Mutable, Reference, @@ -59,14 +62,13 @@ export function mergeBaseModel(model: Model, linker: Linker) { dataModel.attributes = dataModel.superTypes // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - // .flatMap((superType) => updateContainer(superType.ref!.attributes, dataModel)) .flatMap((superType) => superType.ref!.attributes) .map((attr) => cloneAst(attr, dataModel, buildReference)) .concat(dataModel.attributes); }); // remove abstract models - model.declarations = model.declarations.filter((x) => !(x.$type == 'DataModel' && x.isAbstract)); + model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract)); } // deep clone an AST, relink references, and set its container @@ -77,10 +79,44 @@ function cloneAst( ): Mutable { const clone = copyAstNode(node, buildReference) as Mutable; clone.$container = newContainer; + clone.$containerProperty = node.$containerProperty; + clone.$containerIndex = node.$containerIndex; clone.$inheritedFrom = getContainerOfType(node, isDataModel); return clone; } +// this function is copied from Langium's ast-utils, but copying $resolvedType as well +function copyAstNode(node: T, buildReference: BuildReference): T { + const copy: GenericAstNode = { $type: node.$type, $resolvedType: node.$resolvedType }; + + for (const [name, value] of Object.entries(node)) { + if (!name.startsWith('$')) { + if (isAstNode(value)) { + copy[name] = copyAstNode(value, buildReference); + } else if (isReference(value)) { + copy[name] = buildReference(copy, name, value.$refNode, value.$refText); + } else if (Array.isArray(value)) { + const copiedArray: unknown[] = []; + for (const element of value) { + if (isAstNode(element)) { + copiedArray.push(copyAstNode(element, buildReference)); + } else if (isReference(element)) { + copiedArray.push(buildReference(copy, name, element.$refNode, element.$refText)); + } else { + copiedArray.push(element); + } + } + copy[name] = copiedArray; + } else { + copy[name] = value; + } + } + } + + linkContentToContainer(copy); + return copy as unknown as T; +} + export function getIdFields(dataModel: DataModel) { const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') diff --git a/tests/integration/tests/regression/issues.test.ts b/tests/integration/tests/regression/issues.test.ts index 4ade85c8c..7c2ca94cd 100644 --- a/tests/integration/tests/regression/issues.test.ts +++ b/tests/integration/tests/regression/issues.test.ts @@ -327,9 +327,9 @@ model User { // can be created by anyone, even not logged in @@allow('create', true) // can be read by users in the same organization - @@allow('read', orgs?[members?[auth() == this]]) + @@allow('read', orgs?[members?[auth().id == id]]) // full access by oneself - @@allow('all', auth() == this) + @@allow('all', auth().id == id) } model Organization { @@ -343,7 +343,7 @@ model Organization { // everyone can create a organization @@allow('create', true) // any user in the organization can read the organization - @@allow('read', members?[auth() == this]) + @@allow('read', members?[auth().id == id]) } abstract model organizationBaseEntity { @@ -359,15 +359,15 @@ abstract model organizationBaseEntity { groups Group[] // when create, owner must be set to current user, and user must be in the organization - @@allow('create', owner == auth() && org.members?[this == auth()]) + @@allow('create', owner == auth() && org.members?[id == auth().id]) // only the owner can update it and is not allowed to change the owner - @@allow('update', owner == auth() && org.members?[this == auth()] && future().owner == owner) + @@allow('update', owner == auth() && org.members?[id == auth().id] && future().owner == owner) // allow owner to read @@allow('read', owner == auth()) // allow shared group members to read it - @@allow('read', groups?[users?[this == auth()]]) + @@allow('read', groups?[users?[id == auth().id]]) // allow organization to access if public - @@allow('read', isPublic && org.members?[this == auth()]) + @@allow('read', isPublic && org.members?[id == auth().id]) // can not be read if deleted @@deny('all', isDeleted == true) } @@ -394,7 +394,7 @@ model Group { orgId String // group is shared by organization - @@allow('all', org.members?[auth() == this]) + @@allow('all', org.members?[auth().id == id]) } ` ); @@ -616,7 +616,7 @@ model Organization { // everyone can create a organization @@allow('create', true) // any user in the organization can read the organization - @@allow('read', members?[auth() == this]) + @@allow('read', members?[auth().id == id]) } abstract model organizationBaseEntity { @@ -632,15 +632,15 @@ abstract model organizationBaseEntity { groups Group[] // when create, owner must be set to current user, and user must be in the organization - @@allow('create', owner == auth() && org.members?[this == auth()]) + @@allow('create', owner == auth() && org.members?[id == auth().id]) // only the owner can update it and is not allowed to change the owner - @@allow('update', owner == auth() && org.members?[this == auth()] && future().owner == owner) + @@allow('update', owner == auth() && org.members?[id == auth().id] && future().owner == owner) // allow owner to read @@allow('read', owner == auth()) // allow shared group members to read it - @@allow('read', groups?[users?[this == auth()]]) + @@allow('read', groups?[users?[id == auth().id]]) // allow organization to access if public - @@allow('read', isPublic && org.members?[this == auth()]) + @@allow('read', isPublic && org.members?[id == auth().id]) // can not be read if deleted @@deny('all', isDeleted == true) } @@ -667,7 +667,7 @@ model Group { orgId String // group is shared by organization - @@allow('all', org.members?[auth() == this]) + @@allow('all', org.members?[auth().id == id]) } ` ); From e95aaaf4db0b4f2be7b6514b557d9a84adc44fa8 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 23 Jan 2024 22:12:17 +0800 Subject: [PATCH 3/3] fix tests --- .../src/language-server/zmodel-scope.ts | 2 +- .../tests/regression/issue-925.test.ts | 20 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index f0d346b36..0c848c58d 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -223,7 +223,7 @@ export class ZModelScopeProvider extends DefaultScopeProvider { private createScopeForAuthModel(node: AstNode, globalScope: Scope) { const model = getContainerOfType(node, isModel); if (model) { - const authModel = getAuthModel(getDataModels(model)); + const authModel = getAuthModel(getDataModels(model, true)); if (authModel) { return this.createScopeForNodes(authModel.fields, globalScope); } diff --git a/tests/integration/tests/regression/issue-925.test.ts b/tests/integration/tests/regression/issue-925.test.ts index 34b1ac434..b19d9d615 100644 --- a/tests/integration/tests/regression/issue-925.test.ts +++ b/tests/integration/tests/regression/issue-925.test.ts @@ -1,7 +1,7 @@ -import { loadModelWithError } from '@zenstackhq/testtools'; +import { loadModel, loadModelWithError } from '@zenstackhq/testtools'; describe('Regression: issue 925', () => { - it('member reference from this', async () => { + it('member reference without using this', async () => { await expect( loadModelWithError( ` @@ -10,7 +10,7 @@ describe('Regression: issue 925', () => { company Company[] test Int - @@allow('read', auth().company?[staff?[companyId == this.test]]) + @@allow('read', auth().company?[staff?[companyId == test]]) } model Company { @@ -32,19 +32,18 @@ describe('Regression: issue 925', () => { } ` ) - ).resolves.toContain("Could not resolve reference to DataModelField named 'test'."); + ).resolves.toContain("Could not resolve reference to ReferenceTarget named 'test'."); }); - it('simple reference', async () => { - await expect( - loadModelWithError( - ` + it('reference with this', async () => { + await loadModel( + ` model User { id Int @id @default(autoincrement()) company Company[] test Int - @@allow('read', auth().company?[staff?[companyId == test]]) + @@allow('read', auth().company?[staff?[companyId == this.test]]) } model Company { @@ -65,7 +64,6 @@ describe('Regression: issue 925', () => { @@allow('read', true) } ` - ) - ).resolves.toContain("Could not resolve reference to ReferenceTarget named 'test'."); + ); }); });