diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index c87a528e7..9b792a0fa 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -1,10 +1,10 @@ import Logic from 'logic-solver'; import { match } from 'ts-pattern'; import type { - CheckerConstraint, ComparisonConstraint, ComparisonTerm, LogicalConstraint, + PermissionCheckerConstraint, ValueConstraint, VariableConstraint, } from '../types'; @@ -22,7 +22,7 @@ export class ConstraintSolver { /** * Check the satisfiability of the given constraint. */ - checkSat(constraint: CheckerConstraint): boolean { + checkSat(constraint: PermissionCheckerConstraint): boolean { // reset state this.stringTable = []; this.variables = new Map(); @@ -46,7 +46,7 @@ export class ConstraintSolver { return !!solver.solve(); } - private buildFormula(constraint: CheckerConstraint): Logic.Formula { + private buildFormula(constraint: PermissionCheckerConstraint): Logic.Formula { return match(constraint) .when( (c): c is ValueConstraint => c.kind === 'value', @@ -100,11 +100,11 @@ export class ConstraintSolver { return Logic.not(this.buildFormula(constraint.children[0])); } - private isTrue(constraint: CheckerConstraint): unknown { + private isTrue(constraint: PermissionCheckerConstraint): unknown { return constraint.kind === 'value' && constraint.value === true; } - private isFalse(constraint: CheckerConstraint): unknown { + private isFalse(constraint: PermissionCheckerConstraint): unknown { return constraint.kind === 'value' && constraint.value === false; } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 997e727d5..7ce3a8987 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import deepmerge from 'deepmerge'; import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; import { P, match } from 'ts-pattern'; @@ -23,7 +24,7 @@ import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; -import type { CheckerConstraint } from '../types'; +import type { EntityCheckerFunc, PermissionCheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; @@ -152,8 +153,7 @@ export class PolicyProxyHandler implements Pr } const result = await this.modelClient[actionName](_args); - this.policyUtils.postProcessForRead(result, this.model, origArgs); - return result; + return this.policyUtils.postProcessForRead(result, this.model, origArgs); } //#endregion @@ -779,10 +779,27 @@ export class PolicyProxyHandler implements Pr } }; - const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => { + const _connectDisconnect = async ( + model: string, + args: any, + context: NestedWriteVisitorContext, + operation: 'connect' | 'disconnect' + ) => { if (context.field?.backLink) { const backLinkField = this.policyUtils.getModelField(model, context.field.backLink); if (backLinkField?.isRelationOwner) { + let uniqueFilter = args; + if (operation === 'disconnect') { + // disconnect filter is not unique, need to build a reversed query to + // locate the entity and use its id fields as unique filter + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const found = await db[model].findUnique({ + where: reversedQuery, + select: this.policyUtils.makeIdSelection(model), + }); + uniqueFilter = found && this.policyUtils.getIdFieldValues(model, found); + } + // update happens on the related model, require updatable, // translate args to foreign keys so field-level policies can be checked const checkArgs: any = {}; @@ -794,10 +811,15 @@ export class PolicyProxyHandler implements Pr } } } - await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs); - // register post-update check - await _registerPostUpdateCheck(model, args, args); + // `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist + if (uniqueFilter) { + // check for update + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs); + + // register post-update check + await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); + } } } }; @@ -970,14 +992,14 @@ export class PolicyProxyHandler implements Pr } }, - connect: async (model, args, context) => _connectDisconnect(model, args, context), + connect: async (model, args, context) => _connectDisconnect(model, args, context, 'connect'), connectOrCreate: async (model, args, context) => { // the where condition is already unique, so we can use it to check if the target exists const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect - await _connectDisconnect(model, args.where, context); + await _connectDisconnect(model, args.where, context, 'connect'); return true; } else { // create @@ -997,7 +1019,7 @@ export class PolicyProxyHandler implements Pr } }, - disconnect: async (model, args, context) => _connectDisconnect(model, args, context), + disconnect: async (model, args, context) => _connectDisconnect(model, args, context, 'disconnect'), set: async (model, args, context) => { // find the set of items to be replaced @@ -1012,10 +1034,10 @@ export class PolicyProxyHandler implements Pr const currentSet = await db[model].findMany(findCurrSetArgs); // register current set for update (foreign key) - await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context))); + await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, 'disconnect'))); // proceed with connecting the new set - await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context))); + await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, 'connect'))); }, delete: async (model, args, context) => { @@ -1160,48 +1182,78 @@ export class PolicyProxyHandler implements Pr args.data = this.validateUpdateInputSchema(this.model, args.data); - if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { - // use a transaction to do post-update checks - const postWriteChecks: PostWriteCheckRecord[] = []; - return this.queryUtils.transaction(this.prisma, async (tx) => { - // collect pre-update values - let select = this.policyUtils.makeIdSelection(this.model); - const preValueSelect = this.policyUtils.getPreValueSelect(this.model); - if (preValueSelect) { - select = { ...select, ...preValueSelect }; - } - const currentSetQuery = { select, where: args.where }; - this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); - - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); - } - const currentSet = await tx[this.model].findMany(currentSetQuery); - - postWriteChecks.push( - ...currentSet.map((preValue) => ({ - model: this.model, - operation: 'postUpdate' as PolicyOperationKind, - uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), - preValue: preValueSelect ? preValue : undefined, - })) - ); - - // proceed with the update - const result = await tx[this.model].updateMany(args); + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update'); - // run post-write checks - await this.runPostWriteChecks(postWriteChecks, tx); + const canProceedWithoutTransaction = + // no post-update rules + !this.policyUtils.hasAuthGuard(this.model, 'postUpdate') && + // no Zod schema + !this.policyUtils.getZodSchema(this.model) && + // no entity checker + !entityChecker; - return result; - }); - } else { + if (canProceedWithoutTransaction) { // proceed without a transaction if (this.shouldLogQuery) { this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); } return this.modelClient.updateMany(args); } + + // collect post-update checks + const postWriteChecks: PostWriteCheckRecord[] = []; + + return this.queryUtils.transaction(this.prisma, async (tx) => { + // collect pre-update values + let select = this.policyUtils.makeIdSelection(this.model); + const preValueSelect = this.policyUtils.getPreValueSelect(this.model); + if (preValueSelect) { + select = { ...select, ...preValueSelect }; + } + + // merge selection required for running additional checker + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update'); + if (entityChecker?.selector) { + select = deepmerge(select, entityChecker.selector); + } + + const currentSetQuery = { select, where: args.where }; + this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'update'); + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); + } + let candidates = await tx[this.model].findMany(currentSetQuery); + + if (entityChecker) { + // filter candidates with additional checker and build an id filter + const r = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func); + candidates = r.filteredCandidates; + + // merge id filter into update's where clause + args.where = args.where ? { AND: [args.where, r.idFilter] } : r.idFilter; + } + + postWriteChecks.push( + ...candidates.map((preValue) => ({ + model: this.model, + operation: 'postUpdate' as PolicyOperationKind, + uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), + preValue: preValueSelect ? preValue : undefined, + })) + ); + + // proceed with the update + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`); + } + const result = await tx[this.model].updateMany(args); + + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + + return result; + }); }); } @@ -1328,14 +1380,49 @@ export class PolicyProxyHandler implements Pr this.policyUtils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions - args = args ?? {}; + args = clone(args); this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); - // conduct the deletion - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'delete'); + if (entityChecker) { + // additional checker exists, need to run deletion inside a transaction + return this.queryUtils.transaction(this.prisma, async (tx) => { + // find the delete candidates, selecting id fields and fields needed for + // running the additional checker + let candidateSelect = this.policyUtils.makeIdSelection(this.model); + if (entityChecker.selector) { + candidateSelect = deepmerge(candidateSelect, entityChecker.selector); + } + + if (this.shouldLogQuery) { + this.logger.info( + `[policy] \`findMany\` ${this.model}: ${formatObject({ + where: args.where, + select: candidateSelect, + })}` + ); + } + const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect }); + + // build a ID filter based on id values filtered by the additional checker + const { idFilter } = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func); + + // merge the ID filter into the where clause + args.where = args.where ? { AND: [args.where, idFilter] } : idFilter; + + // finally, conduct the deletion with the combined where clause + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` in tx for ${this.model}:\n${formatObject(args)}`); + } + return tx[this.model].deleteMany(args); + }); + } else { + // conduct the deletion directly + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.deleteMany(args); } - return this.modelClient.deleteMany(args); }); } @@ -1469,7 +1556,7 @@ export class PolicyProxyHandler implements Pr if (args.where) { // combine runtime filters with generated constraints - const extraConstraints: CheckerConstraint[] = []; + const extraConstraints: PermissionCheckerConstraint[] = []; for (const [field, value] of Object.entries(args.where)) { if (value === undefined) { continue; @@ -1599,5 +1686,17 @@ export class PolicyProxyHandler implements Pr } } + private buildIdFilterWithEntityChecker(candidates: any[], entityChecker: EntityCheckerFunc) { + const filteredCandidates = candidates.filter((value) => entityChecker(value, { user: this.context?.user })); + const idFields = this.policyUtils.getIdFields(this.model); + let idFilter: any; + if (idFields.length === 1) { + idFilter = { [idFields[0].name]: { in: filteredCandidates.map((x) => x[idFields[0].name]) } }; + } else { + idFilter = { AND: filteredCandidates.map((x) => this.policyUtils.getIdFieldValues(this.model, x)) }; + } + return { filteredCandidates, idFilter }; + } + //#endregion } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 02bf87ebf..b76875d28 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1,18 +1,26 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import deepcopy from 'deepcopy'; +import deepmerge from 'deepmerge'; import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { ZodError } from 'zod'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason, PrismaErrorCode } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; -import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; +import { + AuthUser, + CrudContract, + DbClientContract, + PolicyCrudKind, + PolicyOperationKind, + QueryContext, +} from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { CheckerFunc, ModelPolicyDef, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; +import type { EntityChecker, ModelPolicyDef, PermissionCheckerFunc, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -272,7 +280,7 @@ export class PolicyUtil extends QueryUtils { */ getFieldOverrideReadAuthGuard(db: CrudContract, model: string, field: string) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.read?.overrideGuard?.[field]; + const guard = def.fieldLevel?.read?.[field]?.overrideGuard; if (guard === undefined) { // field access is denied by default in override mode @@ -292,7 +300,7 @@ export class PolicyUtil extends QueryUtils { */ getFieldUpdateAuthGuard(db: CrudContract, model: string, field: string) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.update?.guard?.[field]; + const guard = def.fieldLevel?.update?.[field]?.guard; if (guard === undefined) { // field access is allowed by default @@ -312,7 +320,7 @@ export class PolicyUtil extends QueryUtils { */ getFieldOverrideUpdateAuthGuard(db: CrudContract, model: string, field: string) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.update?.overrideGuard?.[field]; + const guard = def.fieldLevel?.update?.[field]?.overrideGuard; if (guard === undefined) { // field access is denied by default in override mode @@ -343,8 +351,13 @@ export class PolicyUtil extends QueryUtils { return false; } const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.[operation]?.overrideGuard; - return guard && Object.keys(guard).length > 0; + if (def.fieldLevel?.[operation]) { + return Object.values(def.fieldLevel[operation]).some( + (f) => f.overrideGuard !== undefined || f.overrideEntityChecker !== undefined + ); + } else { + return false; + } } /** @@ -551,7 +564,7 @@ export class PolicyUtil extends QueryUtils { /** * Gets checker constraints for the given model and operation. */ - getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { + getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { if (this.options.kinds && !this.options.kinds.includes('policy')) { // policy enhancement not enabled, return a constant true checker result return true; @@ -697,6 +710,8 @@ export class PolicyUtil extends QueryUtils { ); } + let entityChecker: EntityChecker | undefined; + if (operation === 'update' && args) { // merge field-level policy guards const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); @@ -710,33 +725,47 @@ export class PolicyUtil extends QueryUtils { }"`, CrudFailureReason.ACCESS_POLICY_VIOLATION ); - } else { - if (fieldUpdateGuard.guard) { - // merge field-level guard - guard = this.and(guard, fieldUpdateGuard.guard); - } + } - if (fieldUpdateGuard.overrideGuard) { - // merge field-level override guard - guard = this.or(guard, fieldUpdateGuard.overrideGuard); - } + if (fieldUpdateGuard.guard) { + // merge field-level guard with AND + guard = this.and(guard, fieldUpdateGuard.guard); + } + + if (fieldUpdateGuard.overrideGuard) { + // merge field-level override guard with OR + guard = this.or(guard, fieldUpdateGuard.overrideGuard); } + + // field-level entity checker + entityChecker = fieldUpdateGuard.entityChecker; } // Zod schema is to be checked for "create" and "postUpdate" const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; - if (this.isTrue(guard) && !schema) { + // combine field-level entity checker with model-level + const modelEntityChecker = this.getEntityChecker(model, operation); + entityChecker = this.combineEntityChecker(entityChecker, modelEntityChecker, 'and'); + + if (this.isTrue(guard) && !schema && !entityChecker) { // unconditionally allowed return; } - const select = schema + let select = schema ? // need to validate against schema, need to fetch all fields undefined : // only fetch id fields this.makeIdSelection(model); + if (entityChecker?.selector) { + if (!select) { + select = this.makeAllScalarFieldSelect(model); + } + select = { ...select, ...entityChecker.selector }; + } + let where = this.clone(uniqueFilter); // query args may have be of combined-id form, need to flatten it to call findFirst this.flattenGeneratedUniqueField(model, where); @@ -758,6 +787,20 @@ export class PolicyUtil extends QueryUtils { ); } + if (entityChecker) { + if (this.logger.enabled('info')) { + this.logger.info(`[policy] running entity checker on ${model} for ${operation}`); + } + if (!entityChecker.func(result, { user: this.user, preValue })) { + throw this.deniedByPolicy( + model, + operation, + `entity ${formatObject(uniqueFilter, false)} failed policy check`, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } + } + if (schema) { // TODO: push down schema check to the database const parseResult = schema.safeParse(result); @@ -777,6 +820,20 @@ export class PolicyUtil extends QueryUtils { } } + getEntityChecker(model: string, operation: PolicyOperationKind, field?: string) { + const def = this.getModelPolicyDef(model); + if (field) { + return def.fieldLevel?.[operation as 'read' | 'update']?.[field]?.entityChecker; + } else { + return def.modelLevel[operation].entityChecker; + } + } + + getUpdateOverrideEntityCheckerForField(model: string, field: string) { + const def = this.getModelPolicyDef(model); + return def.fieldLevel?.update?.[field]?.overrideEntityChecker; + } + private getFieldReadGuards(db: CrudContract, model: string, args: { select?: any; include?: any }) { const allFields = Object.values(getFields(this.modelMeta, model)); @@ -803,19 +860,20 @@ export class PolicyUtil extends QueryUtils { private getFieldUpdateGuards(db: CrudContract, model: string, args: any) { const allFieldGuards = []; const allOverrideFieldGuards = []; + let entityChecker: EntityChecker | undefined; - for (const [k, v] of Object.entries(args.data ?? args)) { - if (typeof v === 'undefined') { + for (const [field, value] of Object.entries(args.data ?? args)) { + if (typeof value === 'undefined') { continue; } - const field = resolveField(this.modelMeta, model, k); + const fieldInfo = resolveField(this.modelMeta, model, field); - if (field?.isDataModel) { + if (fieldInfo?.isDataModel) { // relation field update should be treated as foreign key update, // fetch and merge all foreign key guards - if (field.isRelationOwner && field.foreignKeyMapping) { - const foreignKeys = Object.values(field.foreignKeyMapping); + if (fieldInfo.isRelationOwner && fieldInfo.foreignKeyMapping) { + const foreignKeys = Object.values(fieldInfo.foreignKeyMapping); for (const fk of foreignKeys) { const fieldGuard = this.getFieldUpdateAuthGuard(db, model, fk); if (this.isFalse(fieldGuard)) { @@ -831,18 +889,26 @@ export class PolicyUtil extends QueryUtils { } } } else { - const fieldGuard = this.getFieldUpdateAuthGuard(db, model, k); + const fieldGuard = this.getFieldUpdateAuthGuard(db, model, field); if (this.isFalse(fieldGuard)) { - return { guard: fieldGuard, rejectedByField: k }; + return { guard: fieldGuard, rejectedByField: field }; } // add field guard allFieldGuards.push(fieldGuard); // add field override guard - const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, k); + const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, field); allOverrideFieldGuards.push(overrideFieldGuard); } + + // merge regular and override entity checkers with OR + let checker = this.getEntityChecker(model, 'update', field); + const overrideChecker = this.getUpdateOverrideEntityCheckerForField(model, field); + checker = this.combineEntityChecker(checker, overrideChecker, 'or'); + + // accumulate entity checker across fields + entityChecker = this.combineEntityChecker(entityChecker, checker, 'and'); } const allFieldsCombined = this.and(...allFieldGuards); @@ -853,6 +919,31 @@ export class PolicyUtil extends QueryUtils { guard: allFieldsCombined, overrideGuard: allOverrideFieldsCombined, rejectedByField: undefined, + entityChecker, + }; + } + + private combineEntityChecker( + left: EntityChecker | undefined, + right: EntityChecker | undefined, + combiner: 'and' | 'or' + ): EntityChecker | undefined { + if (!left) { + return right; + } + + if (!right) { + return left; + } + + const func = + combiner === 'and' + ? (entity: any, context: QueryContext) => left.func(entity, context) && right.func(entity, context) + : (entity: any, context: QueryContext) => left.func(entity, context) || right.func(entity, context); + + return { + func, + selector: deepmerge(left.selector ?? {}, right.selector ?? {}), }; } @@ -934,8 +1025,8 @@ export class PolicyUtil extends QueryUtils { } /** - * Injects field selection needed for checking field-level read policy into query args. - * @returns + * Injects field selection needed for checking field-level read policy check and evaluating + * entity checker into query args. */ injectReadCheckSelect(model: string, args: any) { // we need to recurse into relation fields before injecting the current level, because @@ -957,6 +1048,11 @@ export class PolicyUtil extends QueryUtils { this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } } + + const entityChecker = this.getEntityChecker(model, 'read'); + if (entityChecker?.selector) { + this.doInjectReadCheckSelect(model, args, { select: entityChecker.selector }); + } } private doInjectReadCheckSelect(model: string, args: any, input: any) { @@ -1074,19 +1170,36 @@ export class PolicyUtil extends QueryUtils { return def.modelLevel.postUpdate.preUpdateSelector; } + // get a merged selector object for all field-level read policies private getFieldReadCheckSelector(model: string) { const def = this.getModelPolicyDef(model); - return def.fieldLevel?.read?.selector; + let result: any = {}; + const fieldLevel = def.fieldLevel?.read; + if (fieldLevel) { + for (const def of Object.values(fieldLevel)) { + if (def.entityChecker?.selector) { + result = deepmerge(result, def.entityChecker.selector); + } + if (def.overrideEntityChecker?.selector) { + result = deepmerge(result, def.overrideEntityChecker.selector); + } + } + } + return Object.keys(result).length > 0 ? result : undefined; } private checkReadField(model: string, field: string, entity: any) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.read?.checker?.[field]; - if (guard === undefined) { + // combine regular and override field-level entity checkers with OR + const checker = def.fieldLevel?.read?.[field]?.entityChecker; + const overrideChecker = def.fieldLevel?.read?.[field]?.overrideEntityChecker; + const combinedChecker = this.combineEntityChecker(checker, overrideChecker, 'or'); + + if (combinedChecker === undefined) { return true; } else { - return guard(entity, { user: this.user }); + return combinedChecker.func(entity, { user: this.user }); } } @@ -1096,7 +1209,7 @@ export class PolicyUtil extends QueryUtils { private hasFieldLevelPolicy(model: string) { const def = this.getModelPolicyDef(model); - return !!def.fieldLevel?.read?.checker; + return Object.keys(def.fieldLevel?.read ?? {}).length > 0; } /** @@ -1119,7 +1232,7 @@ export class PolicyUtil extends QueryUtils { // preserve the original data as it may be needed for checking field-level readability, // while the "data" will be manipulated during traversal (deleting unreadable fields) const origData = this.clone(data); - this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); + return this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); } private doPostProcessForRead( @@ -1131,12 +1244,44 @@ export class PolicyUtil extends QueryUtils { path = '' ) { if (data === null || data === undefined) { - return; + return data; } - for (const [entityData, entityFullData] of zip(data, fullData)) { + let filteredData = data; + let filteredFullData = fullData; + + const entityChecker = this.getEntityChecker(model, 'read'); + if (entityChecker) { + if (Array.isArray(data)) { + filteredData = []; + filteredFullData = []; + for (const [entityData, entityFullData] of zip(data, fullData)) { + if (!entityChecker.func(entityData, { user: this.user })) { + if (this.shouldLogQuery) { + this.logger.info( + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to entity checker` + ); + } + } else { + filteredData.push(entityData); + filteredFullData.push(entityFullData); + } + } + } else { + if (!entityChecker.func(data, { user: this.user })) { + if (this.shouldLogQuery) { + this.logger.info( + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to entity checker` + ); + } + return null; + } + } + } + + for (const [entityData, entityFullData] of zip(filteredData, filteredFullData)) { if (typeof entityData !== 'object' || !entityData) { - return; + continue; } for (const [field, fieldData] of Object.entries(entityData)) { @@ -1192,7 +1337,7 @@ export class PolicyUtil extends QueryUtils { if (fieldInfo.isDataModel) { // recurse into nested fields const nextArgs = (queryArgs?.select ?? queryArgs?.include)?.[field]; - this.doPostProcessForRead( + const nestedResult = this.doPostProcessForRead( fieldData, fieldInfo.type, entityFullData[field], @@ -1200,9 +1345,16 @@ export class PolicyUtil extends QueryUtils { this.hasFieldLevelPolicy(fieldInfo.type), path ? path + '.' + field : field ); + if (nestedResult === undefined) { + delete entityData[field]; + } else { + entityData[field] = nestedResult; + } } } } + + return filteredData; } /** diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index aa14555b8..8aefcd8ed 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; -import type { CheckerContext, CrudContract, QueryContext } from '../types'; +import type { CrudContract, PermissionCheckerContext, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -24,10 +24,15 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +/** + * Function for checking an entity's data for permission + */ +export type EntityCheckerFunc = (input: any, context: QueryContext) => boolean; + /** * Function for checking if an operation is possibly allowed. */ -export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; +export type PermissionCheckerFunc = (context: PermissionCheckerContext) => PermissionCheckerConstraint; /** * Supported checker constraint checking value types. @@ -67,23 +72,17 @@ export type ComparisonConstraint = { */ export type LogicalConstraint = { kind: 'and' | 'or' | 'not'; - children: CheckerConstraint[]; + children: PermissionCheckerConstraint[]; }; /** * Operation allowability checking constraint */ -export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; - -/** - * Function for getting policy guard with a given context - */ -export type InputCheckFunc = (args: any, context: QueryContext) => boolean; - -/** - * Function for getting policy guard with a given context - */ -export type ReadFieldCheckFunc = (input: any, context: QueryContext) => boolean; +export type PermissionCheckerConstraint = + | ValueConstraint + | VariableConstraint + | ComparisonConstraint + | LogicalConstraint; /** * Policy definition @@ -128,6 +127,21 @@ export type ModelCrudDef = { postUpdate: ModelPostUpdateDef; }; +/** + * Information for checking entity data outside of Prisma + */ +export type EntityChecker = { + /** + * Checker function + */ + func: EntityCheckerFunc; + + /** + * Selector for fetching entity data + */ + selector?: object; +}; + /** * Common policy definition for a CRUD operation */ @@ -137,10 +151,18 @@ type ModelCrudCommon = { */ guard: PolicyFunc | boolean; + /** + * Additional checker function for checking policies outside of Prisma + */ + /** + * Additional checker function for checking policies outside of Prisma + */ + entityChecker?: EntityChecker; + /** * Permission checker function or a constant condition */ - permissionChecker?: CheckerFunc | boolean; + permissionChecker?: PermissionCheckerFunc | boolean; }; /** @@ -156,7 +178,7 @@ type ModelCreateDef = ModelCrudCommon & { * Create input validation function. Only generated when a create * can be approved or denied based on input values. */ - inputChecker?: InputCheckFunc | boolean; + inputChecker?: EntityCheckerFunc | boolean; }; /** @@ -172,8 +194,7 @@ type ModelDeleteDef = ModelCrudCommon; /** * Policy definition for post-update checking a model */ -type ModelPostUpdateDef = { - guard: PolicyFunc | boolean; +type ModelPostUpdateDef = Exclude & { preUpdateSelector?: object; }; @@ -184,37 +205,51 @@ type FieldCrudDef = { /** * Field-level read policy */ - read?: { - /** - * Selector for reading fields needed for evaluating the policy - */ - selector?: object; + read: Record; + + /** + * Field-level update policy + */ + update: Record; +}; - /** - * Field-level Prisma query guard - */ - checker?: Record; +type FieldReadDef = { + /** + * Entity checker + */ + entityChecker?: EntityChecker; - /** - * Field-level read override Prisma query guard - */ - overrideGuard?: Record; - }; + /** + * Field-level read override Prisma query guard + */ + overrideGuard?: PolicyFunc; /** - * Field-level update policy + * Entity checker for override policies + */ + overrideEntityChecker?: EntityChecker; +}; + +type FieldUpdateDef = { + /** + * Field-level update Prisma query guard + */ + guard?: PolicyFunc; + + /** + * Additional entity checker + */ + entityChecker?: EntityChecker; + + /** + * Field-level update override Prisma query guard + */ + overrideGuard?: PolicyFunc; + + /** + * Additional entity checker for override policies */ - update?: { - /** - * Field-level update Prisma query guard - */ - guard?: Record; - - /** - * Field-level update override Prisma query guard - */ - overrideGuard?: Record; - }; + overrideEntityChecker?: EntityChecker; }; /** diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 4c32480ba..b9497b7ee 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -62,7 +62,7 @@ export type QueryContext = { /** * Context for checking operation allowability. */ -export type CheckerContext = { +export type PermissionCheckerContext = { /** * Current user */ diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index d65e304dc..478db5ff7 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -1,6 +1,7 @@ import { AstNode, BinaryExpr, + DataModelAttribute, Expression, ExpressionType, isDataModel, @@ -13,7 +14,12 @@ import { isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; -import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; +import { + getAttributeArgLiteral, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, +} from '@zenstackhq/sdk'; import { ValidationAcceptor, streamAst } from 'langium'; import { findUpAst, getContainingDataModel } from '../../utils/ast-utils'; import { AstValidator } from '../types'; @@ -151,6 +157,7 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'incompatible operand types', { node: expr }); break; } + // not supported: // - foo.a == bar // - foo.user.id == userId @@ -169,10 +176,24 @@ export default class ExpressionValidator implements AstValidator { // foo.user.id == null // foo.user.id == EnumValue if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { - accept('error', 'comparison between fields of different models are not supported', { - node: expr, - }); - break; + const containingPolicyAttr = findUpAst( + expr, + (node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText) + ) as DataModelAttribute | undefined; + + if (containingPolicyAttr) { + const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); + if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) { + accept( + 'error', + 'comparison between fields of different models is not supported in model-level "read" rules', + { + node: expr, + } + ); + break; + } + } } } @@ -246,16 +267,6 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'collection predicate can only be used on an array of model type', { node: expr }); return; } - - // TODO: revisit this when we implement lambda inside collection predicate - const thisExpr = streamAst(expr).find(isThisExpr); - if (thisExpr) { - accept( - 'error', - 'using `this` in collection predicate is not supported. To compare entity identity, use id field comparison instead.', - { node: thisExpr } - ); - } } private isInValidationContext(node: AstNode) { diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 619543e44..0f65c76c0 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -31,6 +31,7 @@ import path from 'path'; import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { ConstraintTransformer } from './constraint-transformer'; import { + generateEntityCheckerFunction, generateNormalizedAuthRef, generateQueryGuardFunction, generateSelectForRules, @@ -85,8 +86,8 @@ export class PolicyGenerator { { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, - { name: 'type CheckerContext' }, - { name: 'type CheckerConstraint' }, + { name: 'type PermissionCheckerContext' }, + { name: 'type PermissionCheckerConstraint' }, ], moduleSpecifier: `${RUNTIME_PACKAGE}`, }); @@ -171,15 +172,16 @@ export class PolicyGenerator { // writes `inputChecker: [funcName]` for a given model private writeCreateInputChecker(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { - const allows = getPolicyExpressions(model, 'allow', 'create'); - const denies = getPolicyExpressions(model, 'deny', 'create'); - if (this.canCheckCreateBasedOnInput(model, allows, denies)) { - const inputCheckFunc = this.generateCreateInputCheckerFunction(model, allows, denies, sourceFile); + if (this.canCheckCreateBasedOnInput(model)) { + const inputCheckFunc = this.generateCreateInputCheckerFunction(model, sourceFile); writer.write(`inputChecker: ${inputCheckFunc.getName()!},`); } } - private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { + private canCheckCreateBasedOnInput(model: DataModel) { + const allows = getPolicyExpressions(model, 'allow', 'create', false, 'all'); + const denies = getPolicyExpressions(model, 'deny', 'create', false, 'all'); + return [...allows, ...denies].every((rule) => { return streamAst(rule).every((expr) => { if (isThisExpr(expr)) { @@ -216,13 +218,10 @@ export class PolicyGenerator { } // generates a function for checking "create" input - private generateCreateInputCheckerFunction( - model: DataModel, - allows: Expression[], - denies: Expression[], - sourceFile: SourceFile - ) { + private generateCreateInputCheckerFunction(model: DataModel, sourceFile: SourceFile) { const statements: (string | WriterFunction)[] = []; + const allows = getPolicyExpressions(model, 'allow', 'create'); + const denies = getPolicyExpressions(model, 'deny', 'create'); generateNormalizedAuthRef(model, allows, denies, statements); @@ -348,6 +347,52 @@ export class PolicyGenerator { if (kind !== 'postUpdate') { this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); } + + // write cross-model comparison rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(model, kind, writer, sourceFile, true); + } + + private writeEntityChecker( + target: DataModel | DataModelField, + kind: PolicyOperationKind, + writer: CodeBlockWriter, + sourceFile: SourceFile, + onlyCrossModelComparison = false, + forOverride = false + ) { + const allows = getPolicyExpressions( + target, + 'allow', + kind, + forOverride, + onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' + ); + const denies = getPolicyExpressions( + target, + 'deny', + kind, + forOverride, + onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' + ); + + if (allows.length === 0 && denies.length === 0) { + return; + } + + const model = isDataModel(target) ? target : (target.$container as DataModel); + const func = generateEntityCheckerFunction( + sourceFile, + model, + kind, + allows, + denies, + isDataModelField(target) ? target : undefined, + forOverride + ); + const selector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate') ?? {}; + const key = forOverride ? 'overrideEntityChecker' : 'entityChecker'; + writer.write(`${key}: { func: ${func.getName()!}, selector: ${JSON.stringify(selector)} },`); } // writes `guard: ...` for a given policy operation kind @@ -413,11 +458,10 @@ export class PolicyGenerator { // post-update counterpart if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { writer.write(`permissionChecker: false,`); - return; } else { writer.write(`permissionChecker: true,`); - return; } + return; } const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile); @@ -443,11 +487,11 @@ export class PolicyGenerator { const func = sourceFile.addFunction({ name: `${model.name}$checker$${kind}`, - returnType: 'CheckerConstraint', + returnType: 'PermissionCheckerConstraint', parameters: [ { name: 'context', - type: 'CheckerContext', + type: 'PermissionCheckerContext', }, ], statements, @@ -470,132 +514,93 @@ export class PolicyGenerator { } private writeFieldReadDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { - const fieldCheckers: Record = {}; - const overrideGuards: Record = {}; - const allFieldsAllows: Expression[] = []; - const allFieldsDenies: Expression[] = []; - - // generate field read checkers - for (const field of model.fields) { - const allows = getPolicyExpressions(field, 'allow', 'read'); - const denies = getPolicyExpressions(field, 'deny', 'read'); - if (denies.length === 0 && allows.length === 0) { - continue; - } - - allFieldsAllows.push(...allows); - allFieldsDenies.push(...denies); - - const guardFunc = this.generateFieldReadCheckerFunction(sourceFile, field, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - fieldCheckers[field.name] = guardFunc.getName()!; - - const overrideAllows = getPolicyExpressions(field, 'allow', 'read', true); - if (overrideAllows.length > 0) { - const denies = getPolicyExpressions(field, 'deny', 'read'); - const overrideGuardFunc = generateQueryGuardFunction( - sourceFile, - model, - 'read', - overrideAllows, - denies, - field, - true - ); - overrideGuards[field.name] = overrideGuardFunc.getName()!; - } - } + writer.writeLine('read:'); + writer.block(() => { + for (const field of model.fields) { + const policyAttrs = field.attributes.filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); - if (Object.keys(fieldCheckers).length > 0 || Object.keys(overrideGuards).length > 0) { - writer.write('read:'); - writer.block(() => { - if (Object.keys(fieldCheckers).length > 0) { - writer.write('checker:'); + if (policyAttrs.length === 0) { + continue; + } - // write checkers - writer.inlineBlock(() => { - Object.entries(fieldCheckers).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); + writer.write(`${field.name}:`); - // write field selector - const readFieldCheckSelect = generateSelectForRules([...allFieldsAllows, ...allFieldsDenies]); - if (readFieldCheckSelect) { - writer.write(`selector: ${JSON.stringify(readFieldCheckSelect)},`); + writer.block(() => { + // checker function + // write all field-level rules as entity checker function + this.writeEntityChecker(field, 'read', writer, sourceFile, false, false); + + const overrideAllows = getPolicyExpressions(field, 'allow', 'read', true); + if (overrideAllows.length > 0) { + // override guard function + const denies = getPolicyExpressions(field, 'deny', 'read'); + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'read', + overrideAllows, + denies, + field, + true + ); + writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + + // additional entity checker for override + this.writeEntityChecker(field, 'read', writer, sourceFile, false, true); } - } - - if (Object.keys(overrideGuards).length > 0) { - // write override guards - writer.write('overrideGuard:'); - writer.inlineBlock(() => { - Object.entries(overrideGuards).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); - } - }); - writer.writeLine(','); - } + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); } private writeFieldUpdateDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { - const guards: Record = {}; - const overrideGuards: Record = {}; - - for (const field of model.fields) { - const allows = getPolicyExpressions(field, 'allow', 'update'); - const denies = getPolicyExpressions(field, 'deny', 'update'); + writer.writeLine('update:'); + writer.block(() => { + for (const field of model.fields) { + const allows = getPolicyExpressions(field, 'allow', 'update'); + const denies = getPolicyExpressions(field, 'deny', 'update'); + const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); + + if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) { + continue; + } - if (denies.length === 0 && allows.length === 0) { - continue; - } + writer.write(`${field.name}:`); - const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); - guards[field.name] = guardFunc.getName()!; - - const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); - if (overrideAllows.length > 0) { - const overrideGuardFunc = generateQueryGuardFunction( - sourceFile, - model, - 'update', - overrideAllows, - denies, - field, - true - ); - overrideGuards[field.name] = overrideGuardFunc.getName()!; + writer.block(() => { + // guard + const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); + writer.write(`guard: ${guardFunc.getName()},`); + + // write cross-model comparison rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(field, 'update', writer, sourceFile, true, false); + + const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); + if (overrideAllows.length > 0) { + // override guard + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'update', + overrideAllows, + denies, + field, + true + ); + writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + + // write cross-model comparison override rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(field, 'update', writer, sourceFile, true, true); + } + }); + writer.writeLine(','); } - } - - if (Object.keys(guards).length > 0 || Object.keys(overrideGuards).length > 0) { - writer.write('update:'); - writer.block(() => { - if (Object.keys(guards).length > 0) { - writer.write('guard:'); - writer.inlineBlock(() => { - Object.entries(guards).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); - } - - if (Object.keys(overrideGuards).length > 0) { - writer.write('overrideGuard:'); - writer.inlineBlock(() => { - Object.entries(overrideGuards).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); - } - }); - } + }); + writer.writeLine(','); } private generateFieldReadCheckerFunction( diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index c8b75ffd8..f6f8bd801 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -11,6 +11,7 @@ import { getIdFields, getLiteral, isAuthInvocation, + isDataModelFieldReference, isEnumFieldReference, isFromStdlib, isFutureExpr, @@ -19,6 +20,7 @@ import { import { Enum, Model, + isBinaryExpr, isDataModel, isDataModelField, isExpression, @@ -30,10 +32,10 @@ import { type DataModelField, type Expression, } from '@zenstackhq/sdk/ast'; -import { streamAllContents, streamAst, streamContents } from 'langium'; +import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium'; import { SourceFile, WriterFunction } from 'ts-morph'; import { name } from '..'; -import { isCollectionPredicate } from '../../../utils/ast-utils'; +import { isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** @@ -43,7 +45,8 @@ export function getPolicyExpressions( target: DataModel | DataModelField, kind: PolicyKind, operation: PolicyOperationKind, - override = false + forOverride = false, + filter: 'all' | 'withoutCrossModelComparison' | 'onlyCrossModelComparison' = 'all' ) { const attributes = target.attributes; const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; @@ -52,12 +55,10 @@ export function getPolicyExpressions( return false; } - if (override) { - const overrideArg = getAttributeArg(attr, 'override'); - return overrideArg && getLiteral(overrideArg) === true; - } else { - return true; - } + const overrideArg = getAttributeArg(attr, 'override'); + const isOverride = !!overrideArg && getLiteral(overrideArg) === true; + + return (forOverride && isOverride) || (!forOverride && !isOverride); }); const checkOperation = operation === 'postUpdate' ? 'update' : operation; @@ -73,6 +74,12 @@ export function getPolicyExpressions( }) .map((attr) => attr.args[1].value); + if (filter === 'onlyCrossModelComparison') { + result = result.filter((expr) => hasCrossModelComparison(expr)); + } else if (filter === 'withoutCrossModelComparison') { + result = result.filter((expr) => !hasCrossModelComparison(expr)); + } + if (operation === 'update') { result = processUpdatePolicies(result, false); } else if (operation === 'postUpdate') { @@ -108,9 +115,14 @@ function processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { * Generates a "select" object that contains (recursively) fields referenced by the * given policy rules */ -export function generateSelectForRules(rules: Expression[], forAuthContext = false): object { +export function generateSelectForRules(rules: Expression[], forAuthContext = false, ignoreFutureReference = true) { const result: any = {}; const addPath = (path: string[]) => { + const thisIndex = path.lastIndexOf('$this'); + if (thisIndex >= 0) { + // drop everything before $this + path = path.slice(thisIndex + 1); + } let curr = result; path.forEach((seg, i) => { if (i === path.length - 1) { @@ -128,6 +140,10 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal // selection path const visit = (node: Expression): string[] | undefined => { if (isThisExpr(node)) { + return ['$this']; + } + + if (isFutureExpr(node)) { return []; } @@ -144,7 +160,7 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal return [node.member.$refText]; } - if (isFutureExpr(node.operand)) { + if (isFutureExpr(node.operand) && ignoreFutureReference) { // future().field is not subject to pre-update select return undefined; } @@ -183,13 +199,15 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal } } else if (isCollectionPredicate(expr)) { const path = visit(expr.left); + // recurse into RHS + const rhs = collectReferencePaths(expr.right); if (path) { - // recurse into RHS - const rhs = collectReferencePaths(expr.right); // combine path of LHS and RHS return rhs.map((r) => [...path, ...r]); } else { - return []; + // LHS is not rooted from the current model, + // only keep RHS items that contains '$this' + return rhs.filter((r) => r.includes('$this')); } } else if (isInvocationExpr(expr)) { // recurse into function arguments @@ -225,9 +243,12 @@ export function generateQueryGuardFunction( ) { const statements: (string | WriterFunction)[] = []; - generateNormalizedAuthRef(model, allows, denies, statements); + const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule)); + const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule)); - const hasFieldAccess = [...denies, ...allows].some((rule) => + generateNormalizedAuthRef(model, allowRules, denyRules, statements); + + const hasFieldAccess = [...denyRules, ...allowRules].some((rule) => streamAst(rule).some( (child) => // this.??? @@ -248,10 +269,10 @@ export function generateQueryGuardFunction( isPostGuard: kind === 'postUpdate', }); try { - denies.forEach((rule) => { + denyRules.forEach((rule) => { writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); }); - allows.forEach((rule) => { + allowRules.forEach((rule) => { writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); }); } catch (err) { @@ -267,12 +288,22 @@ export function generateQueryGuardFunction( // if there's no allow rule, for field-level rules, by default we allow writer.write(`return ${TRUE};`); } else { - // if there's any allow rule, we deny unless any allow rule evaluates to true - writer.write(`return ${FALSE};`); + if (allowRules.length < allows.length) { + writer.write(`return ${TRUE};`); + } else { + // if there's any allow rule, we deny unless any allow rule evaluates to true + writer.write(`return ${FALSE};`); + } } } else { - // for model-level rules, the default is always deny - writer.write(`return ${FALSE};`); + if (allowRules.length < allows.length) { + // some rules are filtered out here and will be generated as additional + // checker functions, so we allow here to avoid a premature denial + writer.write(`return ${TRUE};`); + } else { + // for model-level rules, the default is always deny unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); + } } }); } else { @@ -280,42 +311,42 @@ export function generateQueryGuardFunction( writer.write('return '); const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); const writeDenies = () => { - writer.conditionalWrite(denies.length > 1, '{ AND: ['); - denies.forEach((expr, i) => { + writer.conditionalWrite(denyRules.length > 1, '{ AND: ['); + denyRules.forEach((expr, i) => { writer.inlineBlock(() => { writer.write('NOT: '); exprWriter.write(expr); }); - writer.conditionalWrite(i !== denies.length - 1, ','); + writer.conditionalWrite(i !== denyRules.length - 1, ','); }); - writer.conditionalWrite(denies.length > 1, ']}'); + writer.conditionalWrite(denyRules.length > 1, ']}'); }; const writeAllows = () => { - writer.conditionalWrite(allows.length > 1, '{ OR: ['); - allows.forEach((expr, i) => { + writer.conditionalWrite(allowRules.length > 1, '{ OR: ['); + allowRules.forEach((expr, i) => { exprWriter.write(expr); - writer.conditionalWrite(i !== allows.length - 1, ','); + writer.conditionalWrite(i !== allowRules.length - 1, ','); }); - writer.conditionalWrite(allows.length > 1, ']}'); + writer.conditionalWrite(allowRules.length > 1, ']}'); }; - if (allows.length > 0 && denies.length > 0) { + if (allowRules.length > 0 && denyRules.length > 0) { // include both allow and deny rules writer.write('{ AND: ['); writeDenies(); writer.write(','); writeAllows(); writer.write(']}'); - } else if (denies.length > 0) { + } else if (denyRules.length > 0) { // only deny rules writeDenies(); - } else if (allows.length > 0) { + } else if (allowRules.length > 0) { // only allow rules writeAllows(); } else { - // disallow any operation - writer.write(`{ OR: [] }`); + // disallow any operation unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); } writer.write(';'); }); @@ -341,6 +372,59 @@ export function generateQueryGuardFunction( return func; } +export function generateEntityCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + allows: Expression[], + denies: Expression[], + forField?: DataModelField, + fieldOverride = false +) { + const statements: (string | WriterFunction)[] = []; + + generateNormalizedAuthRef(model, allows, denies, statements); + + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + thisExprContext: 'input', + fieldReferenceContext: 'input', + isPostGuard: kind === 'postUpdate', + futureRefContext: 'input', + }); + + denies.forEach((rule) => { + const compiled = transformer.transform(rule); + statements.push(`if (${compiled}) { return false; }`); + }); + + allows.forEach((rule) => { + const compiled = transformer.transform(rule); + statements.push(`if (${compiled}) { return true; }`); + }); + + // default: deny unless for 'postUpdate' + statements.push(kind === 'postUpdate' ? 'return true;' : 'return false;'); + + const func = sourceFile.addFunction({ + name: `$check_${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + returnType: 'any', + parameters: [ + { + name: 'input', + type: 'any', + }, + { + name: 'context', + type: 'QueryContext', + }, + ], + statements, + }); + + return func; +} + /** * Generates a normalized auth reference for the given policy rules */ @@ -384,3 +468,44 @@ export function isEnumReferenced(model: Model, decl: Enum): unknown { return false; }); } + +function hasCrossModelComparison(expr: Expression) { + return streamAst(expr).some((node) => { + if (isBinaryExpr(node) && ['==', '!=', '>', '<', '>=', '<=', 'in'].includes(node.operator)) { + const leftRoot = getSourceModelOfFieldAccess(node.left); + const rightRoot = getSourceModelOfFieldAccess(node.right); + if (leftRoot && rightRoot && leftRoot !== rightRoot) { + return true; + } + } + return false; + }); +} + +function getSourceModelOfFieldAccess(expr: Expression) { + if (isDataModel(expr.$resolvedType?.decl)) { + return expr.$resolvedType?.decl; + } + + // `this` reference + if (isThisExpr(expr)) { + return getContainerOfType(expr, isDataModel); + } + + // `future()` + if (isFutureInvocation(expr)) { + return getContainerOfType(expr, isDataModel); + } + + // direct field reference + if (isDataModelFieldReference(expr)) { + return (expr.target.ref as DataModelField).$container; + } + + // member access + if (isMemberAccessExpr(expr)) { + return getSourceModelOfFieldAccess(expr.operand); + } + + return undefined; +} diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 380836e21..b2ac1544b 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -699,7 +699,36 @@ describe('Attribute tests', () => { } `) - ).toContain('comparison between fields of different models are not supported'); + ).toContain('comparison between fields of different models is not supported in model-level "read" rules'); + + expect( + await loadModel(` + ${prelude} + model User { + id Int @id + lists List[] + todos Todo[] + } + + model List { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + todos Todo[] + } + + model Todo { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + list List @relation(fields: [listId], references: [id]) + listId Int + + @@allow('create', list.user.id == userId) + } + + `) + ).toBeTruthy(); expect( await loadModelWithError(` diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index e0778da51..e7dd6bf84 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -88,7 +88,7 @@ describe('Data Model Validation Tests', () => { @@allow('all', members?[this == auth()]) } `) - ).toMatchObject(errorLike('using `this` in collection predicate is not supported')); + ).toBeTruthy(); expect( await loadModel(` diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 8e33eb4a7..28ce1d345 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -34,6 +34,7 @@ type Options = { isPostGuard?: boolean; fieldReferenceContext?: string; thisExprContext?: string; + futureRefContext?: string; context: ExpressionContext; }; @@ -116,7 +117,9 @@ export class TypeScriptExpressionTransformer { if (this.options?.isPostGuard !== true) { throw new TypeScriptExpressionTransformerError(`future() is only supported in postUpdate rules`); } - return expr.member.ref.name; + return this.options.futureRefContext + ? `${this.options.futureRefContext}.${expr.member.ref.name}` + : expr.member.ref.name; } else { if (normalizeUndefined) { // normalize field access to null instead of undefined to avoid accidentally use undefined in filter @@ -449,7 +452,6 @@ export class TypeScriptExpressionTransformer { ...this.options, isPostGuard: false, fieldReferenceContext: '_item', - thisExprContext: '_item', }); const predicate = innerTransformer.transform(expr.right, normalizeUndefined); diff --git a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts new file mode 100644 index 000000000..1ebfaeba6 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts @@ -0,0 +1,823 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Cross-model field comparison', () => { + it('to-one relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + + @@allow('read', true) + @@allow('create,update,delete', age == profile.age) + @@deny('update', future().age < future().profile.age && age > 0) + } + + model Profile { + id Int @id + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.user.deleteMany(); + await prisma.profile.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // createMany + await expect( + db.user.createMany({ data: [{ id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }] }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.createMany({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + // await expect(db.user.findMany()).resolves.toHaveLength(1); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + // await expect(db.user.findMany()).resolves.toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 20 }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 15 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 18, profile: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 25 }); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(2); + await reset(); + + // updateMany + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + // non updatable + await expect(db.user.updateMany({ data: { age: 18 } })).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } } }); + // one of the two is updatable + await expect(db.user.updateMany({ data: { age: 30 } })).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 18 }); + await expect(prisma.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } } }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + }); + + it('nested inside to-one relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile? + age Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@allow('read', true) + @@allow('create,update,delete', user == null || age == user.age) + @@deny('update', future().user != null && future().age < future().user.age && age > 0) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.profile.deleteMany(); + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 }, include: { profile: true } })).resolves.toMatchObject({ + // age: 18, + // profile: expect.objectContaining({ age: 18 }), + // }); + // await expect(db.user.findMany({ include: { profile: true } })).resolves.toEqual( + // expect.arrayContaining([ + // expect.objectContaining({ + // age: 18, + // profile: expect.objectContaining({ age: 18 }), + // }), + // ]) + // ); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // let r = await db.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + // expect(r.profile).toBeUndefined(); + // r = await db.user.findMany({ include: { profile: true } }); + // expect(r[0].profile).toBeUndefined(); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) + ).toResolveTruthy(); + const r = await prisma.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + expect(r.profile).toMatchObject({ age: 20 }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 18 } } } }) + ).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 15 } } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) + ).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profile: { + upsert: { + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profile: { + upsert: { + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 18 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profile: { + upsert: { + create: { id: 2, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profile: { + upsert: { + create: { id: 2, age: 18 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { delete: true } } })).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { delete: true } } })).toResolveTruthy(); + await expect(await prisma.profile.findMany()).toHaveLength(0); + await reset(); + + // connect/disconnect + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { disconnect: true } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { disconnect: true } } })).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 25 } }); + await expect( + db.user.update({ where: { id: 2 }, data: { profile: { connect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.create({ data: { id: 3, age: 20 } }); + await expect(db.user.update({ where: { id: 3 }, data: { profile: { connect: { id: 1 } } } })).toResolveTruthy(); + await expect(prisma.profile.findFirst()).resolves.toMatchObject({ userId: 3 }); + await reset(); + }); + + it('to-many relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profiles Profile[] + age Int + + @@allow('read', true) + @@allow('create,update,delete', profiles![this.age == age]) + @@deny('update', future().profiles?[this.age < age]) + } + + model Profile { + id Int @id + age Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId Int + + @@allow('all', true) + } + `, + { preserveTsFiles: true } + ); + + const db = enhance(); + + const reset = async () => { + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: [{ id: 1, age: 20 }] } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profiles: { + createMany: { + data: [ + { id: 1, age: 18 }, + { id: 2, age: 20 }, + ], + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: [{ id: 1, age: 20 }] } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profiles: { + createMany: { + data: [ + { id: 1, age: 18 }, + { id: 2, age: 18 }, + ], + }, + }, + }, + }) + ).toResolveTruthy(); + await expect( + db.user.create({ + data: { id: 2, age: 18 }, + }) + ).toResolveTruthy(); + await reset(); + + // createMany + await expect( + db.user.createMany({ + data: [ + { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } }, + { id: 2, age: 18, profiles: { create: { id: 2, age: 20 } } }, + ], + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.createMany({ + data: [ + { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } }, + { id: 2, age: 19, profiles: { create: { id: 2, age: 19 } } }, + ], + }) + ).resolves.toEqual({ count: 2 }); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + // await expect(db.user.findMany()).resolves.toHaveLength(1); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + // await expect(db.user.findMany()).resolves.toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 20 }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 15 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 18, profiles: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 25 }); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(2); + await reset(); + + // updateMany + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + // non updatable + await expect(db.user.updateMany({ data: { age: 18 } })).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } } }); + // one of the two is updatable + await expect(db.user.updateMany({ data: { age: 30 } })).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 18 }); + await expect(prisma.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect(db.user.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } } }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + }); + + it('nested inside to-many relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profiles Profile[] + age Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@allow('read', true) + @@allow('create,update,delete', user == null || age == user.age) + @@deny('update', future().user != null && future().age < future().user.age && age > 0) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.profile.deleteMany(); + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 }, include: { profiles: true } })).resolves.toMatchObject({ + // age: 18, + // profiles: [expect.objectContaining({ age: 18 })], + // }); + // await expect(db.user.findMany({ include: { profiles: true } })).resolves.toEqual( + // expect.arrayContaining([ + // expect.objectContaining({ + // age: 18, + // profiles: [expect.objectContaining({ age: 18 })], + // }), + // ]) + // ); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // let r = await db.user.findUnique({ where: { id: 1 }, include: { profiles: true } }); + // expect(r.profiles).toHaveLength(0); + // r = await db.user.findMany({ include: { profiles: true } }); + // expect(r[0].profiles).toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 20 } } } }, + }) + ).toResolveTruthy(); + let r = await prisma.user.findUnique({ where: { id: 1 }, include: { profiles: true } }); + expect(r.profiles[0]).toMatchObject({ age: 20 }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 18 } } } }, + }) + ).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 15 } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 20 } } } }, + }) + ).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profiles: { + upsert: { + where: { id: 1 }, + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profiles: { + upsert: { + where: { id: 1 }, + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 18 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profiles: { + upsert: { + where: { id: 2 }, + create: { id: 2, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profiles: { + upsert: { + where: { id: 2 }, + create: { id: 2, age: 18 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { delete: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profiles: { delete: { id: 1 } } } })).toResolveTruthy(); + await expect(await prisma.profile.findMany()).toHaveLength(0); + await reset(); + + // connect/disconnect + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { disconnect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { disconnect: { id: 1 } } } }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 25 } }); + await expect( + db.user.update({ where: { id: 2 }, data: { profiles: { connect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.create({ data: { id: 3, age: 20 } }); + await expect( + db.user.update({ where: { id: 3 }, data: { profiles: { connect: { id: 1 } } } }) + ).toResolveTruthy(); + await expect(prisma.profile.findFirst()).resolves.toMatchObject({ userId: 3 }); + await reset(); + }); + + it('field-level simple', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('read', age == profile.age) @allow('update', age > profile.age) + level Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + // read + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r.age).toBeUndefined(); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(r.age).toBeUndefined(); + + // update + await expect(db.user.update({ where: { id: 1 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { level: 2 } })).toResolveTruthy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); + }); + + it('field-level read override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('read', age == profile.age, true) + level Int + } + + model Profile { + id Int @id + age Int + user User? + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r).toBeNull(); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(Object.keys(r).length).toBe(0); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(r).toMatchObject({ age: 20 }); + }); + + it('field-level update override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('update', age > profile.age, true) + level Int + @@allow('read', true) + } + + model Profile { + id Int @id + age Int + user User? + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { level: 2 } })).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); + }); + + it('with auth', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + permissions Permission[] + @@allow('all', true) + } + + model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + model String + level Int + @@allow('all', true) + } + + model Post { + id Int @id @default(autoincrement()) + title String + permission PostPermission? + + @@allow('read', true) + @@allow("create", auth().permissions?[model == 'Post' && level == this.permission.level]) + } + + model PostPermission { + id Int @id @default(autoincrement()) + post Post @relation(fields: [postId], references: [id]) + postId Int @unique + level Int + @@allow('all', true) + } + `, + { preserveTsFiles: true } + ); + + await expect(enhance().post.create({ data: { title: 'P1' } })).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Foo', level: 1 }] }).post.create({ data: { title: 'P1' } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Post', level: 1 }] }).post.create({ data: { title: 'P1' } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Post', level: 1 }] }).post.create({ + data: { title: 'P1', permission: { create: { level: 1 } } }, + }) + ).toResolveTruthy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index de778e8e8..0297116a0 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -915,6 +915,10 @@ describe('Policy: field-level policy', () => { data: { models: { connect: { id: 1 } } }, }) ).toBeRejectedByPolicy(); + await prisma.user.update({ + where: { id: 1 }, + data: { models: { connect: { id: 1 } } }, + }); await expect( db.user.update({ where: { id: 1 }, @@ -1015,6 +1019,10 @@ describe('Policy: field-level policy', () => { data: { model: { connect: { id: 1 } } }, }) ).toBeRejectedByPolicy(); + await prisma.user.update({ + where: { id: 1 }, + data: { model: { connect: { id: 1 } } }, + }); await expect( db.user.update({ where: { id: 1 },