From a9e739b2a8394b935f9dfe1f140ddd417a6bba84 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:07:45 -0700 Subject: [PATCH] fix(policy): logical combination issue and more tests for update --- .../plugins/policy/expression-transformer.ts | 4 +- packages/runtime/src/plugins/policy/utils.ts | 43 ++++-- .../runtime/test/policy/crud/create.test.ts | 9 +- .../runtime/test/policy/crud/update.test.ts | 123 ++++++++++++++++++ 4 files changed, 162 insertions(+), 17 deletions(-) create mode 100644 packages/runtime/test/policy/crud/update.test.ts diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index ecc4df1f..68af823b 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -231,7 +231,7 @@ export class ExpressionTransformer { }); if (expr.op === '!') { - predicateFilter = logicalNot(predicateFilter); + predicateFilter = logicalNot(this.dialect, predicateFilter); } const count = FunctionNode.create('count', [ValueNode.createImmediate(1)]); @@ -305,7 +305,7 @@ export class ExpressionTransformer { private _unary(expr: UnaryExpression, context: ExpressionTransformerContext) { // only '!' operator for now invariant(expr.op === '!', 'only "!" operator is supported'); - return logicalNot(this.transform(expr.operand, context)); + return logicalNot(this.dialect, this.transform(expr.operand, context)); } private transformOperator(op: Exclude) { diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index 3c2e641d..a86b4857 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -50,6 +50,12 @@ export function conjunction( dialect: BaseCrudDialect, nodes: OperationNode[], ): OperationNode { + if (nodes.length === 0) { + return trueNode(dialect); + } + if (nodes.length === 1) { + return nodes[0]!; + } if (nodes.some(isFalseNode)) { return falseNode(dialect); } @@ -57,17 +63,19 @@ export function conjunction( if (items.length === 0) { return trueNode(dialect); } - return items.reduce((acc, node) => - OrNode.is(node) - ? AndNode.create(acc, ParensNode.create(node)) // wraps parentheses - : AndNode.create(acc, node), - ); + return items.reduce((acc, node) => AndNode.create(wrapParensIf(acc, OrNode.is), wrapParensIf(node, OrNode.is))); } export function disjunction( dialect: BaseCrudDialect, nodes: OperationNode[], ): OperationNode { + if (nodes.length === 0) { + return falseNode(dialect); + } + if (nodes.length === 1) { + return nodes[0]!; + } if (nodes.some(isTrueNode)) { return trueNode(dialect); } @@ -75,25 +83,32 @@ export function disjunction( if (items.length === 0) { return falseNode(dialect); } - return items.reduce((acc, node) => - AndNode.is(node) - ? OrNode.create(acc, ParensNode.create(node)) // wraps parentheses - : OrNode.create(acc, node), - ); + return items.reduce((acc, node) => OrNode.create(wrapParensIf(acc, AndNode.is), wrapParensIf(node, AndNode.is))); } /** * Negates a logical expression. */ -export function logicalNot(node: OperationNode): OperationNode { +export function logicalNot( + dialect: BaseCrudDialect, + node: OperationNode, +): OperationNode { + if (isTrueNode(node)) { + return falseNode(dialect); + } + if (isFalseNode(node)) { + return trueNode(dialect); + } return UnaryOperationNode.create( OperatorNode.create('not'), - AndNode.is(node) || OrNode.is(node) - ? ParensNode.create(node) // wraps parentheses - : node, + wrapParensIf(node, (n) => AndNode.is(n) || OrNode.is(n)), ); } +function wrapParensIf(node: OperationNode, predicate: (node: OperationNode) => boolean): OperationNode { + return predicate(node) ? ParensNode.create(node) : node; +} + /** * Builds an expression node that checks if a node is true. */ diff --git a/packages/runtime/test/policy/crud/create.test.ts b/packages/runtime/test/policy/crud/create.test.ts index a9bacb01..be8c82da 100644 --- a/packages/runtime/test/policy/crud/create.test.ts +++ b/packages/runtime/test/policy/crud/create.test.ts @@ -15,6 +15,13 @@ model Foo { ); await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy(); await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + + await expect( + db.$qb.insertInto('Foo').values({ x: 0 }).returningAll().executeTakeFirst(), + ).toBeRejectedByPolicy(); + await expect( + db.$qb.insertInto('Foo').values({ x: 1 }).returningAll().executeTakeFirst(), + ).resolves.toMatchObject({ x: 1 }); }); it('works with this scalar member check', async () => { @@ -66,7 +73,7 @@ model Foo { id Int @id @default(autoincrement()) x Int @@deny('create', x <= 0) - @@allow('create', x > 1) + @@allow('create', x <= 0 || x > 1) @@allow('read', true) } `, diff --git a/packages/runtime/test/policy/crud/update.test.ts b/packages/runtime/test/policy/crud/update.test.ts new file mode 100644 index 00000000..eb5735a2 --- /dev/null +++ b/packages/runtime/test/policy/crud/update.test.ts @@ -0,0 +1,123 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Update policy tests', () => { + it('works with scalar field check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('update', x > 0) + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + + await expect( + db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(), + ).resolves.toMatchObject({ numUpdatedRows: 0n }); + await expect( + db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(), + ).resolves.toMatchObject([{ id: 2, x: 3 }]); + }); + + it('works with this scalar member check', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('update', this.x > 0) + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('denies by default', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + }); + + it('works with deny rule', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@deny('update', x <= 0) + @@allow('create,read,update', true) +} +`, + ); + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 }); + }); + + it('works with mixed allow and deny rules', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@deny('update', x <= 0) + @@allow('update', x <= 0 || x > 1) + @@allow('create,read', true) +} +`, + ); + + await db.foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 2, x: 1 } }); + await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).toBeRejectedNotFound(); + await db.foo.create({ data: { id: 3, x: 2 } }); + await expect(db.foo.update({ where: { id: 3 }, data: { x: 3 } })).resolves.toMatchObject({ x: 3 }); + }); + + it('works with auth check', async () => { + const db = await createPolicyTestClient( + ` +type Auth { + x Int + @@auth +} + +model Foo { + id Int @id + x Int + @@allow('update', x == auth().x) + @@allow('create,read', true) +} +`, + ); + await db.foo.create({ data: { id: 1, x: 1 } }); + await expect(db.$setAuth({ x: 0 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedNotFound(); + await expect(db.$setAuth({ x: 1 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject({ + x: 2, + }); + }); +});