diff --git a/BREAKINGCHANGES.md b/BREAKINGCHANGES.md index 79068ab3..03adfae1 100644 --- a/BREAKINGCHANGES.md +++ b/BREAKINGCHANGES.md @@ -1,2 +1,3 @@ 1. `auth()` cannot be directly compared with a relation anymore 2. `update` and `delete` policy rejection throws `NotFoundError` +3. non-optional to-one relation doesn't automatically filter parent read when evaluating access policies diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 7119be54..10be424c 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -1,3 +1,4 @@ +import type Decimal from 'decimal.js'; import { ExpressionWrapper, sql, @@ -17,7 +18,6 @@ import { requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base'; -import type Decimal from 'decimal.js'; export class SqliteCrudDialect< Schema extends SchemaDef @@ -123,11 +123,11 @@ export class SqliteCrudDialect< } tbl = tbl.select(() => { - const objArgs: Array< + type ArgsType = | Expression | RawBuilder - | SelectQueryBuilder - > = []; + | SelectQueryBuilder; + const objArgs: ArgsType[] = []; if (payload === true || !payload.select) { // select all scalar fields @@ -156,18 +156,36 @@ export class SqliteCrudDialect< } else if (payload.select) { // select specific fields objArgs.push( - ...Object.entries(payload.select) + ...Object.entries(payload.select) .filter(([, value]) => value) - .map(([field]) => [ - sql.lit(field), - buildFieldRef( + .map(([field, value]) => { + const fieldDef = requireField( this.schema, relationModel, - field, - this.options, - eb - ), - ]) + field + ); + if (fieldDef.relation) { + const subJson = this.buildRelationJSON( + relationModel as GetModels, + eb, + field, + `${parentName}$${relationField}`, + value + ); + return [sql.lit(field), subJson as ArgsType]; + } else { + return [ + sql.lit(field), + buildFieldRef( + this.schema, + relationModel, + field, + this.options, + eb + ) as ArgsType, + ]; + } + }) .flatMap((v) => v) ); } diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index 03dd8449..a782f24a 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -51,6 +51,7 @@ export class CreateOperationHandler< if (!result) { throw new RejectedByPolicyError( + this.model, `result is not allowed to be read back` ); } diff --git a/packages/runtime/src/client/crud/operations/update.ts b/packages/runtime/src/client/crud/operations/update.ts index 7101e11d..18c86ce5 100644 --- a/packages/runtime/src/client/crud/operations/update.ts +++ b/packages/runtime/src/client/crud/operations/update.ts @@ -40,6 +40,7 @@ export class UpdateOperationHandler< if (!result) { throw new RejectedByPolicyError( + this.model, 'result is not allowed to be read back' ); } diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index abf20c55..c1b8c4a7 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -26,7 +26,7 @@ import type { PromiseType } from 'utility-types'; import type { GetModels, SchemaDef } from '../../schema'; import type { ClientImpl } from '../client-impl'; import type { ClientContract } from '../contract'; -import { InternalError } from '../errors'; +import { InternalError, QueryError } from '../errors'; import type { MutationInterceptionFilterResult, OnKyselyQueryTransactionCallback, @@ -158,17 +158,23 @@ export class ZenStackQueryExecutor< return proceed(queryNode); } - private proceedQuery(query: RootOperationNode, queryId: QueryId) { + private async proceedQuery(query: RootOperationNode, queryId: QueryId) { // run built-in transformers const finalQuery = this.nameMapper.transformNode(query); const compiled = this.compileQuery(finalQuery); - return this.driver.txConnection - ? super - .withConnectionProvider( - new SingleConnectionProvider(this.driver.txConnection) - ) - .executeQuery(compiled, queryId) - : super.executeQuery(compiled, queryId); + try { + return this.driver.txConnection + ? await super + .withConnectionProvider( + new SingleConnectionProvider(this.driver.txConnection) + ) + .executeQuery(compiled, queryId) + : await super.executeQuery(compiled, queryId); + } catch (err) { + throw new QueryError( + `Policy: failed to execute query: ${err}, sql: ${compiled.sql}, parameters: ${compiled.parameters}` + ); + } } private isMutationNode(queryNode: RootOperationNode) { diff --git a/packages/runtime/src/plugins/policy/errors.ts b/packages/runtime/src/plugins/policy/errors.ts index ae707e74..0c0c85f9 100644 --- a/packages/runtime/src/plugins/policy/errors.ts +++ b/packages/runtime/src/plugins/policy/errors.ts @@ -2,7 +2,12 @@ * Error thrown when an operation is rejected by access policy. */ export class RejectedByPolicyError extends Error { - constructor(reason?: string) { - super(reason ?? `Operation rejected by policy`); + constructor( + public readonly model: string | undefined, + public readonly reason?: string + ) { + super( + reason ?? `Operation rejected by policy${model ? ': ' + model : ''}` + ); } } diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 157b5137..3144e1f7 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -20,6 +20,7 @@ import { } from 'kysely'; import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; +import type { CRUD } from '../../client/contract'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base'; import { InternalError, QueryError } from '../../client/errors'; @@ -45,7 +46,6 @@ import { import type { BuiltinType, FieldDef, GetModels } from '../../schema/schema'; import { ExpressionEvaluator } from './expression-evaluator'; import { conjunction, disjunction, logicalNot, trueNode } from './utils'; -import type { CRUD } from '../../client/contract'; export type ExpressionTransformerContext = { model: GetModels; @@ -53,6 +53,8 @@ export type ExpressionTransformerContext = { operation: CRUD; thisEntity?: Record; auth?: any; + memberFilter?: OperationNode; + memberSelect?: SelectionNode; }; // a registry of expression handlers marked with @expr @@ -141,12 +143,33 @@ export class ExpressionTransformer { return this.createColumnRef(expr.field, context); } } else { - return this.transformRelationAccess( + const { memberFilter, memberSelect, ...restContext } = context; + const relation = this.transformRelationAccess( expr.field, fieldDef.type, - context + restContext ); + return { + ...relation, + where: this.mergeWhere(relation.where, memberFilter), + selections: memberSelect ? [memberSelect] : relation.selections, + }; + } + } + + private mergeWhere( + where: WhereNode | undefined, + memberFilter: OperationNode | undefined + ) { + if (!where) { + return WhereNode.create(memberFilter ?? trueNode(this.dialect)); + } + if (!memberFilter) { + return where; } + return WhereNode.create( + conjunction(this.dialect, [where.where, memberFilter]) + ); } @expr('null') @@ -251,8 +274,6 @@ export class ExpressionTransformer { return this.transformValue(value, 'Boolean'); } - const left = this.transform(expr.left, context); - invariant( Expression.isField(expr.left) || Expression.isMember(expr.left), 'left operand must be field or member access' @@ -284,7 +305,7 @@ export class ExpressionTransformer { } } - let filter = this.transform(expr.right, { + let predicateFilter = this.transform(expr.right, { ...context, model: newContextModel as GetModels, alias: undefined, @@ -292,92 +313,44 @@ export class ExpressionTransformer { }); if (expr.op === '!') { - filter = logicalNot(filter); + predicateFilter = logicalNot(predicateFilter); } - invariant( - SelectQueryNode.is(left), - 'expected left operand to be select query' - ); - const count = FunctionNode.create('count', [ ValueNode.createImmediate(1), ]); - const finalSelectQuery = this.updateInnerMostSelectQuery( - left, - filter, - match(expr.op) - .with('?', () => - BinaryOperationNode.create( - count, - OperatorNode.create('>'), - ValueNode.createImmediate(0) - ) + + const predicateResult = match(expr.op) + .with('?', () => + BinaryOperationNode.create( + count, + OperatorNode.create('>'), + ValueNode.createImmediate(0) ) - .with('!', () => - BinaryOperationNode.create( - count, - OperatorNode.create('='), - ValueNode.createImmediate(0) - ) + ) + .with('!', () => + BinaryOperationNode.create( + count, + OperatorNode.create('='), + ValueNode.createImmediate(0) ) - .with('^', () => - BinaryOperationNode.create( - count, - OperatorNode.create('='), - ValueNode.createImmediate(0) - ) + ) + .with('^', () => + BinaryOperationNode.create( + count, + OperatorNode.create('='), + ValueNode.createImmediate(0) ) - .exhaustive() - ); + ) + .exhaustive(); - return finalSelectQuery; - } - - private updateInnerMostSelectQuery( - node: SelectQueryNode, - where: OperationNode, - selection: OperationNode - ): SelectQueryNode { - if (!node.selections || node.selections.length === 0) { - return { - ...node, - selections: [ - SelectionNode.create( - AliasNode.create(selection, IdentifierNode.create('$t')) - ), - ], - where: WhereNode.create( - node.where - ? conjunction(this.dialect, [node.where.where, where]) - : where - ), - }; - } else { - invariant( - node.selections.length === 1, - 'expected exactly one selection' - ); - const currSelection = node.selections[0]!; - invariant( - AliasNode.is(currSelection.selection), - 'expected alias node' - ); - const alias = currSelection.selection.alias; - const inner = currSelection.selection.node; - invariant(SelectQueryNode.is(inner), 'expected select query node'); - const newInner = this.updateInnerMostSelectQuery( - inner, - where, - selection - ); - return { - ...node, - selections: [ - SelectionNode.create(AliasNode.create(newInner, alias)), - ], - }; - } + return this.transform(expr.left, { + ...context, + memberSelect: SelectionNode.create( + AliasNode.create(predicateResult, IdentifierNode.create('$t')) + ), + memberFilter: predicateFilter, + }); } private transformAuthBinary(expr: BinaryExpression) { @@ -512,7 +485,9 @@ export class ExpressionTransformer { 'expect receiver to be field expression' ); - const receiver = this.transform(expr.receiver, context); + const { memberFilter, memberSelect, ...restContext } = context; + + const receiver = this.transform(expr.receiver, restContext); invariant( SelectQueryNode.is(receiver), 'expected receiver to be select query' @@ -546,12 +521,13 @@ export class ExpressionTransformer { member, fieldDef.type, { - ...context, + ...restContext, model: fromModel as GetModels, alias: undefined, thisEntity: undefined, } ); + if (currNode) { invariant( SelectQueryNode.is(currNode), @@ -563,33 +539,32 @@ export class ExpressionTransformer { SelectionNode.create( AliasNode.create( currNode, - IdentifierNode.create(member) + IdentifierNode.create(expr.members[i + 1]!) ) ), ], }; } else { - currNode = relation; + // inner most member, merge with member filter from the context + currNode = { + ...relation, + where: this.mergeWhere(relation.where, memberFilter), + selections: memberSelect + ? [memberSelect] + : relation.selections, + }; } } else { invariant( i === expr.members.length - 1, 'plain field access must be the last segment' ); + invariant( + !currNode, + 'plain field access must be the last segment' + ); - const columnRef = ColumnNode.create(member); - if (currNode) { - invariant( - SelectQueryNode.is(currNode), - 'expected select query node' - ); - currNode = { - ...(currNode as SelectQueryNode), - selections: [SelectionNode.create(columnRef)], - }; - } else { - currNode = columnRef; - } + currNode = ColumnNode.create(member); } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index b1c834ff..c467e998 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -86,7 +86,10 @@ export class PolicyHandler< ) { if (!this.isCrudQueryNode(node)) { // non CRUD queries are not allowed - throw new RejectedByPolicyError('non CRUD queries are not allowed'); + throw new RejectedByPolicyError( + undefined, + 'non-CRUD queries are not allowed' + ); } if (!this.isMutationQueryNode(node)) { @@ -104,7 +107,7 @@ export class PolicyHandler< 'create' ); if (constCondition === false) { - throw new RejectedByPolicyError(); + throw new RejectedByPolicyError(mutationModel); } else if (constCondition === undefined) { mutationRequiresTransaction = true; } @@ -142,6 +145,7 @@ export class PolicyHandler< if (readBackError) { throw new RejectedByPolicyError( + mutationModel, 'result is not allowed to be read back' ); } @@ -217,7 +221,7 @@ export class PolicyHandler< }; const result = await proceed(preCreateCheck); if (!(result.rows[0] as any)?.$condition) { - throw new RejectedByPolicyError(); + throw new RejectedByPolicyError(model); } } diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index 541082b3..7b689641 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -3,6 +3,7 @@ import { AliasNode, AndNode, BinaryOperationNode, + FunctionNode, OperatorNode, OrNode, ParensNode, @@ -133,7 +134,8 @@ export function buildIsFalse( return falseNode(dialect); } return BinaryOperationNode.create( - node, + // coalesce so null is treated as false + FunctionNode.create('coalesce', [node, falseNode(dialect)]), OperatorNode.create('='), falseNode(dialect) ); diff --git a/packages/runtime/test/policy/deep-nested.test.ts b/packages/runtime/test/policy/deep-nested.test.ts new file mode 100644 index 00000000..ab5ef628 --- /dev/null +++ b/packages/runtime/test/policy/deep-nested.test.ts @@ -0,0 +1,672 @@ +import { beforeEach, describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('deep nested operations tests', () => { + const model = ` + // M1 - M2 - M3 + // -* M4 + model M1 { + myId String @id @default(cuid()) + m2 M2? + value Int @default(0) + + @@allow('all', true) + @@deny('create', m2.m4?[value == 100]) + @@deny('update', m2.m4?[value == 101]) + @@deny('read', value == 100) + } + + model M2 { + id Int @id @default(autoincrement()) + value Int + m1 M1 @relation(fields: [m1Id], references: [myId], onDelete: Cascade) + m1Id String @unique + + m3 M3? + m4 M4[] + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1) + @@allow('delete', value > 2) + } + + model M3 { + id String @id @default(cuid()) + value Int + m2 M2 @relation(fields: [m2Id], references: [id], onDelete: Cascade) + m2Id Int @unique + + @@allow('read', true) + @@allow('create', value > 10) + @@allow('update', value > 1) + @@allow('delete', value > 2) + @@deny('read', value == 200) + } + + model M4 { + id String @id @default(cuid()) + value Int + m2 M2? @relation(fields: [m2Id], references: [id], onDelete: Cascade) + m2Id Int? + + @@unique([m2Id, value]) + + @@allow('read', true) + @@allow('create', value > 20) + @@allow('update', value > 21) + @@allow('delete', value > 22) + @@deny('read', value == 200) + } + `; + + let db: any; + let rawDb: any; + + beforeEach(async () => { + db = await createPolicyTestClient(model); + rawDb = db.$unuseAll(); + }); + + it('works with nested read', async () => { + await rawDb.m1.create({ + data: { + myId: '1', + m2: { + create: { + value: 1, + m3: { + create: { id: '3-1', value: 31 }, + }, + m4: { + create: [{ value: 41 }, { value: 42 }], + }, + }, + }, + }, + }); + // all readable + let r = await db.m1.findUnique({ + where: { myId: '1' }, + include: { m2: { include: { m3: true, m4: true } } }, + }); + expect(r.m2.m3).toBeTruthy(); + expect(r.m2.m4).toHaveLength(2); + r = await db.m3.findUnique({ + where: { id: '3-1' }, + include: { m2: { include: { m1: true } } }, + }); + expect(r.m2.m1).toBeTruthy(); + + await rawDb.m1.create({ + data: { + myId: '2', + m2: { + create: { + value: 1, + m3: { + create: { value: 200 }, + }, + m4: { + create: [{ value: 22 }, { value: 200 }], + }, + }, + }, + }, + }); + // check filtered + r = await db.m1.findUnique({ + where: { myId: '2' }, + include: { m2: { include: { m3: true, m4: true } } }, + }); + expect(r.m2.m3).toBeNull(); + expect(r.m2.m4).toHaveLength(1); + + await rawDb.m1.create({ + data: { + myId: '3', + value: 100, + m2: { + create: { + value: 1, + m3: { + create: { id: '3-2', value: 31 }, + }, + }, + }, + }, + }); + // m1 is not readable + r = await db.m3.findUnique({ + where: { id: '3-2' }, + include: { m2: { include: { m1: true } } }, + }); + expect(r.m2.m1).toBeNull(); + }); + + // TODO: should nested create be allowed if it's self-consistent as a whole? + it.skip('works with nested create', async () => { + await expect( + db.m1.create({ + data: { + myId: '1', + m2: { + create: { + value: 1, + m3: { + create: { + id: 'm3-1', + value: 11, + }, + }, + m4: { + create: [ + { id: 'm4-1', value: 22 }, + { id: 'm4-2', value: 23 }, + ], + }, + }, + }, + }, + }) + ).toResolveTruthy(); + + const r = await db.m1.create({ + include: { m2: { include: { m3: true, m4: true } } }, + data: { + myId: '2', + m2: { + create: { + value: 2, + m3: { + connect: { + id: 'm3-1', + }, + }, + m4: { + connect: [{ id: 'm4-1' }], + connectOrCreate: [ + { + where: { id: 'm4-2' }, + create: { id: 'm4-new', value: 24 }, + }, + { + where: { id: 'm4-3' }, + create: { id: 'm4-3', value: 25 }, + }, + ], + }, + }, + }, + }, + }); + expect(r.m2.m3.id).toBe('m3-1'); + expect(r.m2.m4[0].id).toBe('m4-1'); + expect(r.m2.m4[1].id).toBe('m4-2'); + expect(r.m2.m4[2].id).toBe('m4-3'); + + // deep create violation + await expect( + db.m1.create({ + data: { + m2: { + create: { + value: 1, + m4: { + create: [{ value: 20 }, { value: 22 }], + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + // deep create violation due to deep policy + await expect( + db.m1.create({ + data: { + m2: { + create: { + value: 1, + m4: { + create: { value: 100 }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + // deep connect violation via deep policy: @@deny('create', m2.m4?[value == 100]) + await db.m4.create({ + data: { + id: 'm4-value-100', + value: 100, + }, + }); + await expect( + db.m1.create({ + data: { + m2: { + create: { + value: 1, + m4: { + connect: { id: 'm4-value-100' }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + // create read-back filter: M4 @@deny('read', value == 200) + const r1 = await db.m1.create({ + include: { m2: { include: { m4: true } } }, + data: { + m2: { + create: { + value: 1, + m4: { + create: [{ value: 200 }, { value: 201 }], + }, + }, + }, + }, + }); + expect(r1.m2.m4).toHaveLength(1); + + // create read-back filtering: M3 @@deny('read', value == 200) + const r2 = await db.m1.create({ + include: { m2: { include: { m3: true } } }, + data: { + m2: { + create: { + value: 1, + m3: { + create: { value: 200 }, + }, + }, + }, + }, + }); + expect(r2.m2.m3).toBeNull(); + }); + + it('works with nested update', async () => { + await db.m1.create({ + data: { myId: '1' }, + }); + + // success + await expect( + db.m1.update({ + where: { myId: '1' }, + include: { m2: { include: { m3: true, m4: true } } }, + data: { + m2: { + create: { + value: 2, + m3: { + create: { id: 'm3-1', value: 11 }, + }, + m4: { + create: [ + { id: 'm4-1', value: 22 }, + { id: 'm4-2', value: 23 }, + ], + }, + }, + }, + }, + }) + ).toResolveTruthy(); + + // deep update with connect/disconnect/delete success + await db.m4.create({ + data: { + id: 'm4-3', + value: 24, + }, + }); + const r = await db.m1.update({ + where: { myId: '1' }, + include: { m2: { include: { m4: true } } }, + data: { + m2: { + update: { + m4: { + connect: [{ id: 'm4-3' }], + disconnect: { id: 'm4-1' }, + delete: { id: 'm4-2' }, + }, + }, + }, + }, + }); + expect(r.m2.m4).toHaveLength(1); + expect(r.m2.m4[0].id).toBe('m4-3'); + + // reconnect m14-1, create m14-2 + await expect( + db.m1.update({ + where: { myId: '1' }, + include: { m2: { include: { m4: true } } }, + data: { + m2: { + update: { + m4: { + connect: [{ id: 'm4-1' }], + create: { id: 'm4-2', value: 23 }, + }, + }, + }, + }, + }) + ).toResolveTruthy(); + + // deep update violation + await expect( + db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + create: { value: 20 }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + // deep update violation via deep policy: @@deny('update', m2.m4?[value == 101]) + await db.m1.create({ + data: { + myId: '2', + m2: { + create: { + value: 2, + m4: { + create: { id: 'm4-101', value: 101 }, + }, + }, + }, + }, + }); + await expect( + db.m1.update({ + where: { myId: '2' }, + data: { value: 1 }, + }) + ).toBeRejectedNotFound(); + + // update read-back filter: M4 @@deny('read', value == 200) + const r1 = await db.m1.update({ + where: { myId: '1' }, + include: { m2: { include: { m4: true } } }, + data: { + m2: { + update: { + m4: { + update: { + where: { id: 'm4-1' }, + data: { value: 200 }, + }, + }, + }, + }, + }, + }); + expect(r1.m2.m4).toHaveLength(2); + expect(r1.m2.m4).not.toContain(expect.objectContaining({ id: 'm4-1' })); + + // update read-back rejection: M3 @@deny('read', value == 200) + const r2 = await db.m1.update({ + where: { myId: '1' }, + include: { m2: { include: { m3: true } } }, + data: { + m2: { + update: { + m3: { + update: { value: 200 }, + }, + }, + }, + }, + }); + expect(r2.m2.m3).toBeNull(); + }); + + it('works with nested createMany/updateMany/deleteMany', async () => { + await db.m1.create({ + data: { + myId: '1', + m2: { + create: { + id: 1, + value: 2, + }, + }, + }, + }); + + await db.m1.create({ + data: { + myId: '2', + m2: { + create: { + id: 2, + value: 2, + }, + }, + }, + }); + + // createMany with duplicate + await expect( + db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + createMany: { + data: [ + { id: 'm4-1', value: 21 }, + { id: 'm4-1', value: 22 }, + ], + }, + }, + }, + }, + }, + }) + ).rejects.toThrow('constraint failed'); + + // createMany skip duplicate + await db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + createMany: { + skipDuplicates: true, + data: [ + { id: 'm4-1', value: 21 }, // should be created + { id: 'm4-1', value: 211 }, // should be skipped + { id: 'm4-2', value: 22 }, // should be created + ], + }, + }, + }, + }, + }, + }); + await expect(db.m4.findMany()).resolves.toHaveLength(2); + + // createMany skip duplicate with compound unique involving fk + await db.m1.update({ + where: { myId: '2' }, + data: { + m2: { + update: { + m4: { + createMany: { + skipDuplicates: true, + data: [ + { id: 'm4-3', value: 21 }, // should be created + { id: 'm4-4', value: 21 }, // should be skipped + ], + }, + }, + }, + }, + }, + }); + const allM4 = await db.m4.findMany({ select: { value: true } }); + await expect(allM4).toHaveLength(3); + await expect(allM4).toEqual( + expect.arrayContaining([ + { value: 21 }, + { value: 21 }, + { value: 22 }, + ]) + ); + + // updateMany, filtered out by policy + await db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + updateMany: { + where: { + id: 'm4-1', + }, + data: { + value: 210, + }, + }, + }, + }, + }, + }, + }); + await expect( + db.m4.findUnique({ where: { id: 'm4-1' } }) + ).resolves.toMatchObject({ value: 21 }); + await expect( + db.m4.findUnique({ where: { id: 'm4-2' } }) + ).resolves.toMatchObject({ value: 22 }); + + // updateMany, success + await db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + updateMany: { + where: { + id: 'm4-2', + }, + data: { + value: 220, + }, + }, + }, + }, + }, + }, + }); + await expect( + db.m4.findUnique({ where: { id: 'm4-1' } }) + ).resolves.toMatchObject({ value: 21 }); + await expect( + db.m4.findUnique({ where: { id: 'm4-2' } }) + ).resolves.toMatchObject({ value: 220 }); + + // deleteMany, filtered out by policy + await db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + deleteMany: { + id: 'm4-1', + }, + }, + }, + }, + }, + }); + await expect(db.m4.findMany()).resolves.toHaveLength(3); + + // deleteMany, success + await db.m1.update({ + where: { myId: '1' }, + data: { + m2: { + update: { + m4: { + deleteMany: { + id: 'm4-2', + }, + }, + }, + }, + }, + }); + await expect(db.m4.findMany()).resolves.toHaveLength(2); + }); + + it('works with returning relation when deleting', async () => { + await db.m1.create({ + data: { + myId: '1', + m2: { + create: { + value: 1, + m4: { + create: [{ value: 200 }, { value: 22 }], + }, + }, + }, + }, + }); + + // delete read-back filtered: M4 @@deny('read', value == 200) + const r = await db.m1.delete({ + where: { myId: '1' }, + include: { m2: { select: { m4: true } } }, + }); + expect(r.m2.m4).toHaveLength(1); + + await expect(db.m4.findMany()).resolves.toHaveLength(0); + + await db.m1.create({ + data: { + myId: '2', + m2: { + create: { + value: 1, + m3: { + create: { value: 200 }, + }, + }, + }, + }, + }); + + // delete read-back filtered: M3 @@deny('read', value == 200) + const r1 = await db.m1.delete({ + where: { myId: '2' }, + include: { m2: { select: { m3: { select: { id: true } } } } }, + }); + expect(r1.m2.m3).toBeNull(); + }); +}); diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index a28a3c36..84eb5329 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -23,12 +23,12 @@ export const schema = { type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: Expression.call("cuid") }] }], - default: { call: "cuid" } + default: Expression.call("cuid") }, createdAt: { type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: Expression.call("now") }] }], - default: { call: "now" } + default: Expression.call("now") }, updatedAt: { type: "DateTime", @@ -51,7 +51,7 @@ export const schema = { }, role: { type: "Role", - attributes: [{ name: "@default", args: [{ name: "value", value: Expression.ref("Role", "USER") }] }], + attributes: [{ name: "@default", args: [{ name: "value", value: Expression.literal("USER") }] }], default: "USER" }, posts: { @@ -82,7 +82,7 @@ export const schema = { type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: Expression.call("cuid") }] }], - default: { call: "cuid" } + default: Expression.call("cuid") }, bio: { type: "String", @@ -95,14 +95,17 @@ export const schema = { user: { type: "User", optional: true, - attributes: [{ name: "@relation", args: [{ name: "fields", value: Expression.array([Expression.ref("Profile", "userId")]) }, { name: "references", value: Expression.array([Expression.ref("User", "id")]) }] }], + attributes: [{ name: "@relation", args: [{ name: "fields", value: Expression.array([Expression.field("userId")]) }, { name: "references", value: Expression.array([Expression.field("id")]) }] }], relation: { opposite: "profile", fields: ["userId"], references: ["id"] } }, userId: { type: "String", unique: true, optional: true, - attributes: [{ name: "@unique" }] + attributes: [{ name: "@unique" }], + foreignKeyFor: [ + "user" + ] } }, idFields: ["id"], @@ -117,12 +120,12 @@ export const schema = { type: "String", id: true, attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: Expression.call("cuid") }] }], - default: { call: "cuid" } + default: Expression.call("cuid") }, createdAt: { type: "DateTime", attributes: [{ name: "@default", args: [{ name: "value", value: Expression.call("now") }] }], - default: { call: "now" } + default: Expression.call("now") }, updatedAt: { type: "DateTime", @@ -142,11 +145,14 @@ export const schema = { }, author: { type: "User", - attributes: [{ name: "@relation", args: [{ name: "fields", value: Expression.array([Expression.ref("Post", "authorId")]) }, { name: "references", value: Expression.array([Expression.ref("User", "id")]) }] }], + attributes: [{ name: "@relation", args: [{ name: "fields", value: Expression.array([Expression.field("authorId")]) }, { name: "references", value: Expression.array([Expression.field("id")]) }] }], relation: { opposite: "posts", fields: ["authorId"], references: ["id"] } }, authorId: { - type: "String" + type: "String", + foreignKeyFor: [ + "author" + ] } }, idFields: ["id"],