Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
- [ ] Short-circuit pre-create check for scalar-field only policies
- [x] Inject "on conflict do update"
- [x] `check` function
- [ ] Accessing tables not in the schema
- [x] Migration
- [ ] Databases
- [x] SQLite
Expand Down
8 changes: 4 additions & 4 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ attribute @@@deprecated(_ message: String)
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be allowed.
*/
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)
attribute @@allow(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that allows the annotated field to be read or updated.
Expand All @@ -684,7 +684,7 @@ attribute @allow(_ operation: String @@@completionHint(["'create'", "'read'", "'
* @param operation: comma-separated list of "create", "read", "update", "delete". Use "all" to denote all operations.
* @param condition: a boolean expression that controls if the operation should be denied.
*/
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'delete'", "'all'"]), _ condition: Boolean)
attribute @@deny(_ operation: String @@@completionHint(["'create'", "'read'", "'update'", "'post-update'","'delete'", "'all'"]), _ condition: Boolean)

/**
* Defines an access policy that denies the annotated field to be read or updated.
Expand All @@ -705,8 +705,8 @@ function check(field: Any, operation: String?): Boolean {
} @@@expressionContext([AccessPolicy])

/**
* Gets entities value before an update. Only valid when used in a "update" policy rule.
* Gets entity's value before an update. Only valid when used in a "post-update" policy rule.
*/
function future(): Any {
function before(): Any {
} @@@expressionContext([AccessPolicy])

8 changes: 2 additions & 6 deletions packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,6 @@ export function isRelationshipField(field: DataField) {
return isDataModel(field.type.reference?.ref);
}

export function isFutureExpr(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref);
}

export function isDelegateModel(node: AstNode) {
return isDataModel(node) && hasAttribute(node, '@@delegate');
}
Expand Down Expand Up @@ -450,8 +446,8 @@ export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
return authModel;
}

export function isFutureInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref);
export function isBeforeInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'before' && isFromStdlib(node.function.ref);
}

export function isCollectionPredicate(node: AstNode): node is BinaryExpr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import {
getAllAttributes,
getStringLiteral,
isAuthOrAuthMemberAccess,
isBeforeInvocation,
isCollectionPredicate,
isDataFieldReference,
isDelegateModel,
isFutureExpr,
isRelationshipField,
mapBuiltinTypeToExpressionType,
resolved,
Expand Down Expand Up @@ -166,13 +166,20 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
});
return;
}
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept);
this.validatePolicyKinds(kind, ['create', 'read', 'update', 'post-update', 'delete', 'all'], attr, accept);

if ((kind === 'create' || kind === 'all') && attr.args[1]?.value) {
// "create" rules cannot access non-owned relations because the entity does not exist yet, so
// there can't possibly be a fk that points to it
this.rejectNonOwnedRelationInExpression(attr.args[1].value, accept);
}

if (kind !== 'post-update' && attr.args[1]?.value) {
const beforeCall = AstUtils.streamAst(attr.args[1]?.value).find(isBeforeInvocation);
if (beforeCall) {
accept('error', `"before()" is only allowed in "post-update" policy rules`, { node: beforeCall });
}
}
}

private rejectNonOwnedRelationInExpression(expr: Expression, accept: ValidationAcceptor) {
Expand Down Expand Up @@ -251,8 +258,8 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
const kindItems = this.validatePolicyKinds(kind, ['read', 'update', 'all'], attr, accept);

const expr = attr.args[1]?.value;
if (expr && AstUtils.streamAst(expr).some((node) => isFutureExpr(node))) {
accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr });
if (expr && AstUtils.streamAst(expr).some((node) => isBeforeInvocation(node))) {
accept('error', `"before()" is not allowed in field-level policy rules`, { node: expr });
}

// 'update' rules are not allowed for relation fields
Expand Down
11 changes: 11 additions & 0 deletions packages/language/src/validators/expression-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import {
isNullExpr,
isReferenceExpr,
isThisExpr,
MemberAccessExpr,
type ExpressionType,
} from '../generated/ast';

import {
findUpAst,
isAuthInvocation,
isAuthOrAuthMemberAccess,
isBeforeInvocation,
isDataFieldReference,
isEnumFieldReference,
typeAssignable,
Expand Down Expand Up @@ -59,12 +61,21 @@ export default class ExpressionValidator implements AstValidator<Expression> {

// extra validations by expression type
switch (expr.$type) {
case 'MemberAccessExpr':
this.validateMemberAccessExpr(expr, accept);
break;
case 'BinaryExpr':
this.validateBinaryExpr(expr, accept);
break;
}
}

private validateMemberAccessExpr(expr: MemberAccessExpr, accept: ValidationAcceptor) {
if (isBeforeInvocation(expr.operand) && isDataModel(expr.$resolvedType?.decl)) {
accept('error', 'relation fields cannot be accessed from `before()`', { node: expr });
}
}

private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) {
switch (expr.operator) {
case 'in': {
Expand Down
6 changes: 3 additions & 3 deletions packages/language/src/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ import {
getAuthDecl,
getContainingDataModel,
isAuthInvocation,
isFutureExpr,
isBeforeInvocation,
isMemberContainer,
mapBuiltinTypeToExpressionType,
} from './utils';
Expand Down Expand Up @@ -292,8 +292,8 @@ export class ZModelLinker extends DefaultLinker {
if (authDecl) {
node.$resolvedType = { decl: authDecl, nullable: true };
}
} else if (isFutureExpr(node)) {
// future() function is resolved to current model
} else if (isBeforeInvocation(node)) {
// before() function is resolved to current model
node.$resolvedType = { decl: getContainingDataModel(node) };
} else {
this.resolveToDeclaredType(node, funcDecl.returnType);
Expand Down
6 changes: 3 additions & 3 deletions packages/language/src/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {
getRecursiveBases,
isAuthInvocation,
isCollectionPredicate,
isFutureInvocation,
isBeforeInvocation,
resolveImportUri,
} from './utils';

Expand Down Expand Up @@ -170,8 +170,8 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
return this.createScopeForAuth(node, globalScope);
}

if (isFutureInvocation(operand)) {
// resolve `future()` to the containing model
if (isBeforeInvocation(operand)) {
// resolve `before()` to the containing model
return this.createScopeForContainingModel(node, globalScope);
}
return EMPTY_SCOPE;
Expand Down
23 changes: 23 additions & 0 deletions packages/language/test/attribute-application.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { describe, it } from 'vitest';
import { loadSchemaWithError } from './utils';

describe('Attribute application validation tests', () => {
it('rejects before in non-post-update policies', async () => {
await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model Foo {
id Int @id @default(autoincrement())
x Int
@@allow('all', true)
@@deny('update', before(x) > 2)
}
`,
`"before()" is only allowed in "post-update" policy rules`,
);
});
});
10 changes: 10 additions & 0 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,21 @@ export interface ClientConstructor {
*/
export type CRUD = 'create' | 'read' | 'update' | 'delete';

/**
* Extended CRUD operations including 'post-update'.
*/
export type CRUD_EXT = CRUD | 'post-update';

/**
* CRUD operations.
*/
export const CRUD = ['create', 'read', 'update', 'delete'] as const;

/**
* Extended CRUD operations including 'post-update'.
*/
export const CRUD_EXT = [...CRUD, 'post-update'] as const;

//#region Model operations

export type AllModelOperations<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
Expand Down
5 changes: 3 additions & 2 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1296,8 +1296,9 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return { count: Number(result.numAffectedRows) } as Result;
} else {
const idFields = requireIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
return result as Result;
const finalQuery = query.returning(idFields as any);
const result = await this.executeQuery(kysely, finalQuery, 'update');
return result.rows as Result;
}
}

Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/client/options.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Dialect, Expression, ExpressionBuilder, KyselyConfig } from 'kysely';
import type { GetModel, GetModels, ProcedureDef, SchemaDef } from '../schema';
import type { PrependParameter } from '../utils/type-utils';
import type { ClientContract, CRUD, ProcedureFunc } from './contract';
import type { ClientContract, CRUD_EXT, ProcedureFunc } from './contract';
import type { BaseCrudDialect } from './crud/dialects/base-dialect';
import type { RuntimePlugin } from './plugin';
import type { ToKyselySchema } from './query-builder';
Expand Down Expand Up @@ -30,7 +30,7 @@ export type ZModelFunctionContext<Schema extends SchemaDef> = {
/**
* The CRUD operation being performed
*/
operation: CRUD;
operation: CRUD_EXT;
};

export type ZModelFunction<Schema extends SchemaDef> = (
Expand Down
20 changes: 14 additions & 6 deletions packages/runtime/src/plugins/policy/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
type OperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
import type { ClientContract, CRUD } from '../../client/contract';
import type { ClientContract, CRUD_EXT } from '../../client/contract';
import { getCrudDialect } from '../../client/crud/dialects';
import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect';
import { InternalError, QueryError } from '../../client/errors';
Expand Down Expand Up @@ -50,13 +50,12 @@ import {
type SchemaDef,
} from '../../schema';
import { ExpressionEvaluator } from './expression-evaluator';
import { conjunction, disjunction, falseNode, logicalNot, trueNode } from './utils';
import { conjunction, disjunction, falseNode, isBeforeInvocation, logicalNot, trueNode } from './utils';

export type ExpressionTransformerContext<Schema extends SchemaDef> = {
model: GetModels<Schema>;
alias?: string;
operation: CRUD;
auth?: any;
operation: CRUD_EXT;
memberFilter?: OperationNode;
memberSelect?: SelectionNode;
};
Expand Down Expand Up @@ -439,7 +438,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
}

if (this.isAuthMember(arg)) {
const valNode = this.valueMemberAccess(context.auth, arg as MemberExpression, this.authType);
const valNode = this.valueMemberAccess(this.auth, arg as MemberExpression, this.authType);
return valNode ? eb.val(valNode.value) : eb.val(null);
}

Expand All @@ -453,11 +452,20 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
@expr('member')
// @ts-ignore
private _member(expr: MemberExpression, context: ExpressionTransformerContext<Schema>) {
// auth() member access
// `auth()` member access
if (this.isAuthCall(expr.receiver)) {
return this.valueMemberAccess(this.auth, expr, this.authType);
}

// `before()` member access
if (isBeforeInvocation(expr.receiver)) {
// policy handler creates a join table named `$before` using entity value before update,
// we can directly reference the column from there
invariant(context.operation === 'post-update', 'before() can only be used in post-update policy');
invariant(expr.members.length === 1, 'before() can only be followed by a scalar field access');
return ReferenceNode.create(ColumnNode.create(expr.members[0]!), TableNode.create('$before'));
}

invariant(
ExpressionUtils.isField(expr.receiver) || ExpressionUtils.isThis(expr.receiver),
'expect receiver to be field expression or "this"',
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/plugins/policy/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ export class PolicyPlugin<Schema extends SchemaDef> implements RuntimePlugin<Sch
};
}

onKyselyQuery({ query, client, proceed /*, transaction*/ }: OnKyselyQueryArgs<Schema>) {
onKyselyQuery({ query, client, proceed }: OnKyselyQueryArgs<Schema>) {
const handler = new PolicyHandler<Schema>(client);
return handler.handle(query, proceed /*, transaction*/);
return handler.handle(query, proceed);
}
}
Loading