Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion BREAKINGCHANGES.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
1. `auth()` cannot be directly compared with a relation anymore
2.
2. `update` and `delete` policy rejection throws `NotFoundError`
2 changes: 2 additions & 0 deletions NEW-FEATURES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Cross-field comparison (for read and mutations)
- Custom policy functions
2 changes: 2 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@
- [x] Custom table name
- [x] Custom field name
- [ ] Access Policy
- [ ] Short-circuit pre-create check for scalar-field only policies
- [ ] Polymorphism
- [x] Migration
- [ ] Databases
- [x] SQLite
- [x] PostgreSQL
- [ ] Schema
- [ ] MySQL
70 changes: 0 additions & 70 deletions packages/language/src/validators/expression-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ export default class ExpressionValidator implements AstValidator<Expression> {
});
}

this.validateCrossModelFieldComparison(expr, accept);
break;
}

Expand Down Expand Up @@ -164,10 +163,6 @@ export default class ExpressionValidator implements AstValidator<Expression> {
node: expr,
});
}

if (expr.operator !== '&&' && expr.operator !== '||') {
this.validateCrossModelFieldComparison(expr, accept);
}
break;
}

Expand Down Expand Up @@ -196,10 +191,6 @@ export default class ExpressionValidator implements AstValidator<Expression> {
break;
}

if (!this.validateCrossModelFieldComparison(expr, accept)) {
break;
}

if (
(expr.left.$resolvedType?.nullable &&
isNullExpr(expr.right)) ||
Expand Down Expand Up @@ -289,67 +280,6 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}
}

private validateCrossModelFieldComparison(
expr: BinaryExpr,
accept: ValidationAcceptor
) {
// not supported in "read" rules:
// - foo.a == bar
// - foo.user.id == userId
// except:
// - future().userId == userId
if (
(isMemberAccessExpr(expr.left) &&
isDataModelField(expr.left.member.ref) &&
expr.left.member.ref.$container !=
AstUtils.getContainerOfType(expr, isDataModel)) ||
(isMemberAccessExpr(expr.right) &&
isDataModelField(expr.right.member.ref) &&
expr.right.member.ref.$container !=
AstUtils.getContainerOfType(expr, isDataModel))
) {
// foo.user.id == auth().id
// foo.user.id == "123"
// foo.user.id == null
// foo.user.id == EnumValue
if (
!(
this.isNotModelFieldExpr(expr.left) ||
this.isNotModelFieldExpr(expr.right)
)
) {
const containingPolicyAttr = findUpAst(
expr,
(node) =>
isDataModelAttribute(node) &&
['@@allow', '@@deny'].includes(node.decl.$refText)
) as DataModelAttribute | undefined;

if (containingPolicyAttr) {
const operation = getAttributeArgLiteral<string>(
containingPolicyAttr,
'operation'
);
if (
operation?.split(',').includes('all') ||
operation?.split(',').includes('read')
) {
accept(
'error',
'comparison between fields of different models is not supported in model-level "read" rules',
{
node: expr,
}
);
return false;
}
}
}
}

return true;
}

private validateCollectionPredicate(
expr: BinaryExpr,
accept: ValidationAcceptor
Expand Down
58 changes: 3 additions & 55 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import { AstUtils, type AstNode, type ValidationAcceptor } from 'langium';
import { match, P } from 'ts-pattern';
import { ExpressionContext } from '../constants';
import {
Argument,
DataModel,
Expand All @@ -7,26 +10,19 @@ import {
FunctionDecl,
FunctionParam,
InvocationExpr,
isArrayExpr,
isDataModel,
isDataModelAttribute,
isDataModelFieldAttribute,
isLiteralExpr,
} from '../generated/ast';
import { match, P } from 'ts-pattern';
import {
getFieldReference,
getFunctionExpressionContext,
getLiteral,
isCheckInvocation,
isDataModelFieldReference,
isEnumFieldReference,
isFromStdlib,
typeAssignable,
} from '../utils';
import type { AstValidator } from './common';
import { AstUtils, type AstNode, type ValidationAcceptor } from 'langium';
import { ExpressionContext } from '../constants';

// a registry of function handlers marked with @func
const invocationCheckers = new Map<string, PropertyDescriptor>();
Expand Down Expand Up @@ -128,54 +124,6 @@ export default class FunctionInvocationValidator
}
);
}
} else if (
funcAllowedContext.includes(ExpressionContext.AccessPolicy) ||
funcAllowedContext.includes(ExpressionContext.ValidationRule)
) {
// filter operation functions validation

// first argument must refer to a model field
const firstArg = expr.args?.[0]?.value;
if (firstArg) {
if (!getFieldReference(firstArg)) {
accept(
'error',
'first argument must be a field reference',
{ node: firstArg }
);
}
}

// second argument must be a literal or array of literal
const secondArg = expr.args?.[1]?.value;
if (
secondArg &&
// literal
!isLiteralExpr(secondArg) &&
// enum field
!isEnumFieldReference(secondArg) &&
// TODO: revisit this
// `auth()...` expression
// !isAuthOrAuthMemberAccess(secondArg) &&
// array of literal/enum
!(
isArrayExpr(secondArg) &&
secondArg.items.every(
(item) =>
isLiteralExpr(item) ||
isEnumFieldReference(item)
// || isAuthOrAuthMemberAccess(item)
)
)
) {
accept(
'error',
'second argument must be a literal, an enum, an expression starting with `auth().`, or an array of them',
{
node: secondArg,
}
);
}
}
}

Expand Down
5 changes: 2 additions & 3 deletions packages/runtime/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
"dependencies": {
"@paralleldrive/cuid2": "^2.2.2",
"decimal.js": "^10.4.3",
"decimal.js-light": "^2.5.1",
"kysely": "^0.27.5",
"nanoid": "^5.0.9",
"tiny-invariant": "^1.3.3",
Expand All @@ -90,8 +89,8 @@
"@types/better-sqlite3": "^7.0.0",
"@types/pg": "^8.0.0",
"@types/tmp": "^0.2.6",
"tmp": "^0.2.3",
"@zenstackhq/language": "workspace:*",
"@zenstackhq/testtools": "workspace:*"
"@zenstackhq/testtools": "workspace:*",
"tmp": "^0.2.3"
}
}
6 changes: 6 additions & 0 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import type { RuntimePlugin } from './plugin';
import { createDeferredPromise } from './promise';
import type { ToKysely } from './query-builder';
import { ResultProcessor } from './result-processor';
import * as BuiltinFunctions from './functions';

/**
* Creates a new ZenStack client instance.
Expand Down Expand Up @@ -58,6 +59,11 @@ export class ClientImpl<Schema extends SchemaDef> {
this.$schema = schema;
this.$options = options ?? ({} as ClientOptions<Schema>);

this.$options.functions = {
...BuiltinFunctions,
...this.$options.functions,
};

// here we use kysely's props constructor so we can pass a custom query executor
if (baseClient) {
this.kyselyProps = {
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { Decimal } from 'decimal.js-light';
import type { Decimal } from 'decimal.js';
import {
type AuthType,
type GetModels,
Expand Down
8 changes: 8 additions & 0 deletions packages/runtime/src/client/crud-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,14 @@ export type DateTimeFilter<Nullable extends boolean> =
| NullableIf<Date | string, Nullable>
| CommonPrimitiveFilter<Date | string, 'DateTime', Nullable>;

export type BytesFilter<Nullable extends boolean> =
| NullableIf<Uint8Array | Buffer, Nullable>
| {
equals?: NullableIf<Uint8Array, Nullable>;
in?: Uint8Array[];
notIn?: Uint8Array[];
not?: BytesFilter<Nullable>;
};
export type BooleanFilter<Nullable extends boolean> =
| NullableIf<boolean, Nullable>
| {
Expand Down
42 changes: 41 additions & 1 deletion packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ import { sql, type SelectQueryBuilder } from 'kysely';
import invariant from 'tiny-invariant';
import { match, P } from 'ts-pattern';
import type { GetModels, SchemaDef } from '../../../schema';
import type { BuiltinType, FieldDef } from '../../../schema/schema';
import type {
BuiltinType,
DataSourceProviderType,
FieldDef,
} from '../../../schema/schema';
import { enumerate } from '../../../utils/enumerate';
import { isPlainObject } from '../../../utils/is-plain-object';
import type {
BooleanFilter,
BytesFilter,
DateTimeFilter,
FindArgs,
SortOrder,
Expand All @@ -36,6 +41,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
protected readonly options: ClientOptions<Schema>
) {}

abstract get provider(): DataSourceProviderType;

transformPrimitive(value: unknown, _type: BuiltinType) {
return value;
}
Expand Down Expand Up @@ -496,6 +503,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
.with('DateTime', () =>
this.buildDateTimeFilter(eb, table, field, payload)
)
.with('Bytes', () =>
this.buildBytesFilter(eb, table, field, payload)
)
.exhaustive();
}

Expand Down Expand Up @@ -745,6 +755,31 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return this.and(eb, ...conditions);
}

private buildBytesFilter(
eb: ExpressionBuilder<any, any>,
table: string,
field: string,
payload: BytesFilter<true>
) {
const conditions = this.buildStandardFilter(
eb,
'Bytes',
payload,
sql.ref(`${table}.${field}`),
(value) => this.transformPrimitive(value, 'Bytes'),
(value) =>
this.buildBytesFilter(
eb,
table,
field,
value as BytesFilter<true>
),
true,
['equals', 'in', 'notIn', 'not']
);
return this.and(eb, ...conditions.conditions);
}

private buildEnumFilter(
eb: ExpressionBuilder<any, any>,
table: string,
Expand Down Expand Up @@ -948,6 +983,11 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
value: Record<string, Expression<unknown>>
): ExpressionWrapper<any, any, unknown>;

abstract buildArrayLength(
eb: ExpressionBuilder<any, any>,
array: Expression<unknown>
): ExpressionWrapper<any, any, number>;

get supportsUpdateWithLimit() {
return true;
}
Expand Down
11 changes: 11 additions & 0 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import { BaseCrudDialect } from './base';
export class PostgresCrudDialect<
Schema extends SchemaDef
> extends BaseCrudDialect<Schema> {
override get provider() {
return 'postgresql' as const;
}

override transformPrimitive(value: unknown, type: BuiltinType) {
return match(type)
.with('DateTime', () =>
Expand Down Expand Up @@ -324,4 +328,11 @@ export class PostgresCrudDialect<
override get supportsUpdateWithLimit(): boolean {
return false;
}

override buildArrayLength(
eb: ExpressionBuilder<any, any>,
array: Expression<unknown>
): ExpressionWrapper<any, any, number> {
return eb.fn('array_length', [array]);
}
}
Loading