Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
});

if (expr.op === '!') {
predicateFilter = logicalNot(predicateFilter);
predicateFilter = logicalNot(this.dialect, predicateFilter);
}

const count = FunctionNode.create('count', [ValueNode.createImmediate(1)]);
Expand Down Expand Up @@ -305,7 +305,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
private _unary(expr: UnaryExpression, context: ExpressionTransformerContext<Schema>) {
// only '!' operator for now
invariant(expr.op === '!', 'only "!" operator is supported');
return logicalNot(this.transform(expr.operand, context));
return logicalNot(this.dialect, this.transform(expr.operand, context));
}

private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {
Expand Down
43 changes: 29 additions & 14 deletions packages/runtime/src/plugins/policy/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,50 +50,65 @@ export function conjunction<Schema extends SchemaDef>(
dialect: BaseCrudDialect<Schema>,
nodes: OperationNode[],
): OperationNode {
if (nodes.length === 0) {
return trueNode(dialect);
}
if (nodes.length === 1) {
return nodes[0]!;
}
if (nodes.some(isFalseNode)) {
return falseNode(dialect);
}
const items = nodes.filter((n) => !isTrueNode(n));
if (items.length === 0) {
return trueNode(dialect);
}
return items.reduce((acc, node) =>
OrNode.is(node)
? AndNode.create(acc, ParensNode.create(node)) // wraps parentheses
: AndNode.create(acc, node),
);
return items.reduce((acc, node) => AndNode.create(wrapParensIf(acc, OrNode.is), wrapParensIf(node, OrNode.is)));
}

export function disjunction<Schema extends SchemaDef>(
dialect: BaseCrudDialect<Schema>,
nodes: OperationNode[],
): OperationNode {
if (nodes.length === 0) {
return falseNode(dialect);
}
if (nodes.length === 1) {
return nodes[0]!;
}
if (nodes.some(isTrueNode)) {
return trueNode(dialect);
}
const items = nodes.filter((n) => !isFalseNode(n));
if (items.length === 0) {
return falseNode(dialect);
}
return items.reduce((acc, node) =>
AndNode.is(node)
? OrNode.create(acc, ParensNode.create(node)) // wraps parentheses
: OrNode.create(acc, node),
);
return items.reduce((acc, node) => OrNode.create(wrapParensIf(acc, AndNode.is), wrapParensIf(node, AndNode.is)));
}

/**
* Negates a logical expression.
*/
export function logicalNot(node: OperationNode): OperationNode {
export function logicalNot<Schema extends SchemaDef>(
dialect: BaseCrudDialect<Schema>,
node: OperationNode,
): OperationNode {
if (isTrueNode(node)) {
return falseNode(dialect);
}
if (isFalseNode(node)) {
return trueNode(dialect);
}
return UnaryOperationNode.create(
OperatorNode.create('not'),
AndNode.is(node) || OrNode.is(node)
? ParensNode.create(node) // wraps parentheses
: node,
wrapParensIf(node, (n) => AndNode.is(n) || OrNode.is(n)),
);
}

function wrapParensIf(node: OperationNode, predicate: (node: OperationNode) => boolean): OperationNode {
return predicate(node) ? ParensNode.create(node) : node;
}

/**
* Builds an expression node that checks if a node is true.
*/
Expand Down
9 changes: 8 additions & 1 deletion packages/runtime/test/policy/crud/create.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ model Foo {
);
await expect(db.foo.create({ data: { x: 0 } })).toBeRejectedByPolicy();
await expect(db.foo.create({ data: { x: 1 } })).resolves.toMatchObject({ x: 1 });

await expect(
db.$qb.insertInto('Foo').values({ x: 0 }).returningAll().executeTakeFirst(),
).toBeRejectedByPolicy();
await expect(
db.$qb.insertInto('Foo').values({ x: 1 }).returningAll().executeTakeFirst(),
).resolves.toMatchObject({ x: 1 });
});

it('works with this scalar member check', async () => {
Expand Down Expand Up @@ -66,7 +73,7 @@ model Foo {
id Int @id @default(autoincrement())
x Int
@@deny('create', x <= 0)
@@allow('create', x > 1)
@@allow('create', x <= 0 || x > 1)
@@allow('read', true)
}
`,
Expand Down
123 changes: 123 additions & 0 deletions packages/runtime/test/policy/crud/update.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import { describe, expect, it } from 'vitest';
import { createPolicyTestClient } from '../utils';

describe('Update policy tests', () => {
it('works with scalar field check', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('update', x > 0)
@@allow('create,read', true)
}
`,
);

await db.foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
await db.foo.create({ data: { id: 2, x: 1 } });
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 });

await expect(
db.$qb.updateTable('Foo').set({ x: 1 }).where('id', '=', 1).executeTakeFirst(),
).resolves.toMatchObject({ numUpdatedRows: 0n });
await expect(
db.$qb.updateTable('Foo').set({ x: 3 }).where('id', '=', 2).returningAll().execute(),
).resolves.toMatchObject([{ id: 2, x: 3 }]);
});

it('works with this scalar member check', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('update', this.x > 0)
@@allow('create,read', true)
}
`,
);

await db.foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
await db.foo.create({ data: { id: 2, x: 1 } });
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 });
});

it('denies by default', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@allow('create,read', true)
}
`,
);

await db.foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
});

it('works with deny rule', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@deny('update', x <= 0)
@@allow('create,read,update', true)
}
`,
);
await db.foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
await db.foo.create({ data: { id: 2, x: 1 } });
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).resolves.toMatchObject({ x: 2 });
});

it('works with mixed allow and deny rules', async () => {
const db = await createPolicyTestClient(
`
model Foo {
id Int @id
x Int
@@deny('update', x <= 0)
@@allow('update', x <= 0 || x > 1)
@@allow('create,read', true)
}
`,
);

await db.foo.create({ data: { id: 1, x: 0 } });
await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).toBeRejectedNotFound();
await db.foo.create({ data: { id: 2, x: 1 } });
await expect(db.foo.update({ where: { id: 2 }, data: { x: 2 } })).toBeRejectedNotFound();
await db.foo.create({ data: { id: 3, x: 2 } });
await expect(db.foo.update({ where: { id: 3 }, data: { x: 3 } })).resolves.toMatchObject({ x: 3 });
});

it('works with auth check', async () => {
const db = await createPolicyTestClient(
`
type Auth {
x Int
@@auth
}

model Foo {
id Int @id
x Int
@@allow('update', x == auth().x)
@@allow('create,read', true)
}
`,
);
await db.foo.create({ data: { id: 1, x: 1 } });
await expect(db.$setAuth({ x: 0 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).toBeRejectedNotFound();
await expect(db.$setAuth({ x: 1 }).foo.update({ where: { id: 1 }, data: { x: 2 } })).resolves.toMatchObject({
x: 2,
});
});
});