From 801d34c25d781505d4d55a1d78359f888b9b24b7 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 14 May 2025 23:00:15 -0700 Subject: [PATCH] more test cases pass for policy --- BREAKINGCHANGES.md | 2 +- NEW-FEATURES.md | 2 + TODO.md | 2 + .../src/validators/expression-validator.ts | 70 ---- .../function-invocation-validator.ts | 58 +-- packages/runtime/package.json | 5 +- packages/runtime/src/client/client-impl.ts | 6 + packages/runtime/src/client/contract.ts | 2 +- packages/runtime/src/client/crud-types.ts | 8 + .../runtime/src/client/crud/dialects/base.ts | 42 +- .../src/client/crud/dialects/postgresql.ts | 11 + .../src/client/crud/dialects/sqlite.ts | 15 + .../src/client/crud/operations/base.ts | 72 +++- .../src/client/crud/operations/update.ts | 61 +-- packages/runtime/src/client/crud/validator.ts | 44 +- packages/runtime/src/client/functions.ts | 121 ++++++ .../src/client/helpers/schema-db-pusher.ts | 3 + packages/runtime/src/client/options.ts | 16 + .../runtime/src/client/result-processor.ts | 33 ++ .../plugins/policy/expression-transformer.ts | 156 ++++++-- .../src/plugins/policy/policy-handler.ts | 111 +++++- packages/runtime/src/schema/expression.ts | 3 +- packages/runtime/src/schema/schema.ts | 3 +- ...oviders.test.ts => default-values.test.ts} | 7 +- .../runtime/test/client-api/filter.test.ts | 2 + .../test/client-api/name-mapping.test.ts | 4 +- .../test/client-api/type-coverage.test.ts | 45 ++- .../runtime/test/client-api/update.test.ts | 112 +++--- .../test/plugin/kysely-on-query.test.ts | 2 +- packages/runtime/test/policy/auth.test.ts | 3 +- .../test/policy/connect-disconnect.test.ts | 376 ++++++++++++++++++ .../policy/create-many-and-return.test.ts | 92 +++++ .../cross-model-field-comparison.test.ts | 221 ++++++++++ .../test/policy/policy-functions.test.ts | 238 +++++++++++ .../runtime/test/policy/todo-sample.test.ts | 65 ++- pnpm-lock.yaml | 8 - 36 files changed, 1691 insertions(+), 330 deletions(-) create mode 100644 NEW-FEATURES.md create mode 100644 packages/runtime/src/client/functions.ts rename packages/runtime/test/client-api/{default-value-providers.test.ts => default-values.test.ts} (91%) create mode 100644 packages/runtime/test/policy/connect-disconnect.test.ts create mode 100644 packages/runtime/test/policy/create-many-and-return.test.ts create mode 100644 packages/runtime/test/policy/cross-model-field-comparison.test.ts create mode 100644 packages/runtime/test/policy/policy-functions.test.ts diff --git a/BREAKINGCHANGES.md b/BREAKINGCHANGES.md index afae50ec..79068ab3 100644 --- a/BREAKINGCHANGES.md +++ b/BREAKINGCHANGES.md @@ -1,2 +1,2 @@ 1. `auth()` cannot be directly compared with a relation anymore -2. +2. `update` and `delete` policy rejection throws `NotFoundError` diff --git a/NEW-FEATURES.md b/NEW-FEATURES.md new file mode 100644 index 00000000..639fdbbf --- /dev/null +++ b/NEW-FEATURES.md @@ -0,0 +1,2 @@ +- Cross-field comparison (for read and mutations) +- Custom policy functions diff --git a/TODO.md b/TODO.md index 116d7dbd..ed5383c3 100644 --- a/TODO.md +++ b/TODO.md @@ -63,9 +63,11 @@ - [x] Custom table name - [x] Custom field name - [ ] Access Policy + - [ ] Short-circuit pre-create check for scalar-field only policies - [ ] Polymorphism - [x] Migration - [ ] Databases - [x] SQLite - [x] PostgreSQL + - [ ] Schema - [ ] MySQL diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index ea1d1291..df34eeb8 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -90,7 +90,6 @@ export default class ExpressionValidator implements AstValidator { }); } - this.validateCrossModelFieldComparison(expr, accept); break; } @@ -164,10 +163,6 @@ export default class ExpressionValidator implements AstValidator { node: expr, }); } - - if (expr.operator !== '&&' && expr.operator !== '||') { - this.validateCrossModelFieldComparison(expr, accept); - } break; } @@ -196,10 +191,6 @@ export default class ExpressionValidator implements AstValidator { break; } - if (!this.validateCrossModelFieldComparison(expr, accept)) { - break; - } - if ( (expr.left.$resolvedType?.nullable && isNullExpr(expr.right)) || @@ -289,67 +280,6 @@ export default class ExpressionValidator implements AstValidator { } } - private validateCrossModelFieldComparison( - expr: BinaryExpr, - accept: ValidationAcceptor - ) { - // not supported in "read" rules: - // - foo.a == bar - // - foo.user.id == userId - // except: - // - future().userId == userId - if ( - (isMemberAccessExpr(expr.left) && - isDataModelField(expr.left.member.ref) && - expr.left.member.ref.$container != - AstUtils.getContainerOfType(expr, isDataModel)) || - (isMemberAccessExpr(expr.right) && - isDataModelField(expr.right.member.ref) && - expr.right.member.ref.$container != - AstUtils.getContainerOfType(expr, isDataModel)) - ) { - // foo.user.id == auth().id - // foo.user.id == "123" - // foo.user.id == null - // foo.user.id == EnumValue - if ( - !( - this.isNotModelFieldExpr(expr.left) || - this.isNotModelFieldExpr(expr.right) - ) - ) { - const containingPolicyAttr = findUpAst( - expr, - (node) => - isDataModelAttribute(node) && - ['@@allow', '@@deny'].includes(node.decl.$refText) - ) as DataModelAttribute | undefined; - - if (containingPolicyAttr) { - const operation = getAttributeArgLiteral( - containingPolicyAttr, - 'operation' - ); - if ( - operation?.split(',').includes('all') || - operation?.split(',').includes('read') - ) { - accept( - 'error', - 'comparison between fields of different models is not supported in model-level "read" rules', - { - node: expr, - } - ); - return false; - } - } - } - } - - return true; - } - private validateCollectionPredicate( expr: BinaryExpr, accept: ValidationAcceptor diff --git a/packages/language/src/validators/function-invocation-validator.ts b/packages/language/src/validators/function-invocation-validator.ts index bd0e9c77..f9f1e4de 100644 --- a/packages/language/src/validators/function-invocation-validator.ts +++ b/packages/language/src/validators/function-invocation-validator.ts @@ -1,3 +1,6 @@ +import { AstUtils, type AstNode, type ValidationAcceptor } from 'langium'; +import { match, P } from 'ts-pattern'; +import { ExpressionContext } from '../constants'; import { Argument, DataModel, @@ -7,26 +10,19 @@ import { FunctionDecl, FunctionParam, InvocationExpr, - isArrayExpr, isDataModel, isDataModelAttribute, isDataModelFieldAttribute, - isLiteralExpr, } from '../generated/ast'; -import { match, P } from 'ts-pattern'; import { - getFieldReference, getFunctionExpressionContext, getLiteral, isCheckInvocation, isDataModelFieldReference, - isEnumFieldReference, isFromStdlib, typeAssignable, } from '../utils'; import type { AstValidator } from './common'; -import { AstUtils, type AstNode, type ValidationAcceptor } from 'langium'; -import { ExpressionContext } from '../constants'; // a registry of function handlers marked with @func const invocationCheckers = new Map(); @@ -128,54 +124,6 @@ export default class FunctionInvocationValidator } ); } - } else if ( - funcAllowedContext.includes(ExpressionContext.AccessPolicy) || - funcAllowedContext.includes(ExpressionContext.ValidationRule) - ) { - // filter operation functions validation - - // first argument must refer to a model field - const firstArg = expr.args?.[0]?.value; - if (firstArg) { - if (!getFieldReference(firstArg)) { - accept( - 'error', - 'first argument must be a field reference', - { node: firstArg } - ); - } - } - - // second argument must be a literal or array of literal - const secondArg = expr.args?.[1]?.value; - if ( - secondArg && - // literal - !isLiteralExpr(secondArg) && - // enum field - !isEnumFieldReference(secondArg) && - // TODO: revisit this - // `auth()...` expression - // !isAuthOrAuthMemberAccess(secondArg) && - // array of literal/enum - !( - isArrayExpr(secondArg) && - secondArg.items.every( - (item) => - isLiteralExpr(item) || - isEnumFieldReference(item) - // || isAuthOrAuthMemberAccess(item) - ) - ) - ) { - accept( - 'error', - 'second argument must be a literal, an enum, an expression starting with `auth().`, or an array of them', - { - node: secondArg, - } - ); - } } } diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 36da9608..0ee43fd2 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -64,7 +64,6 @@ "dependencies": { "@paralleldrive/cuid2": "^2.2.2", "decimal.js": "^10.4.3", - "decimal.js-light": "^2.5.1", "kysely": "^0.27.5", "nanoid": "^5.0.9", "tiny-invariant": "^1.3.3", @@ -90,8 +89,8 @@ "@types/better-sqlite3": "^7.0.0", "@types/pg": "^8.0.0", "@types/tmp": "^0.2.6", - "tmp": "^0.2.3", "@zenstackhq/language": "workspace:*", - "@zenstackhq/testtools": "workspace:*" + "@zenstackhq/testtools": "workspace:*", + "tmp": "^0.2.3" } } diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index 3bf20f07..883473d9 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -31,6 +31,7 @@ import type { RuntimePlugin } from './plugin'; import { createDeferredPromise } from './promise'; import type { ToKysely } from './query-builder'; import { ResultProcessor } from './result-processor'; +import * as BuiltinFunctions from './functions'; /** * Creates a new ZenStack client instance. @@ -58,6 +59,11 @@ export class ClientImpl { this.$schema = schema; this.$options = options ?? ({} as ClientOptions); + this.$options.functions = { + ...BuiltinFunctions, + ...this.$options.functions, + }; + // here we use kysely's props constructor so we can pass a custom query executor if (baseClient) { this.kyselyProps = { diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 4ccfd989..1989b204 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -1,4 +1,4 @@ -import type { Decimal } from 'decimal.js-light'; +import type { Decimal } from 'decimal.js'; import { type AuthType, type GetModels, diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 57489d85..a8dd78ea 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -232,6 +232,14 @@ export type DateTimeFilter = | NullableIf | CommonPrimitiveFilter; +export type BytesFilter = + | NullableIf + | { + equals?: NullableIf; + in?: Uint8Array[]; + notIn?: Uint8Array[]; + not?: BytesFilter; + }; export type BooleanFilter = | NullableIf | { diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index d36bfead..96e5d257 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -9,11 +9,16 @@ import { sql, type SelectQueryBuilder } from 'kysely'; import invariant from 'tiny-invariant'; import { match, P } from 'ts-pattern'; import type { GetModels, SchemaDef } from '../../../schema'; -import type { BuiltinType, FieldDef } from '../../../schema/schema'; +import type { + BuiltinType, + DataSourceProviderType, + FieldDef, +} from '../../../schema/schema'; import { enumerate } from '../../../utils/enumerate'; import { isPlainObject } from '../../../utils/is-plain-object'; import type { BooleanFilter, + BytesFilter, DateTimeFilter, FindArgs, SortOrder, @@ -36,6 +41,8 @@ export abstract class BaseCrudDialect { protected readonly options: ClientOptions ) {} + abstract get provider(): DataSourceProviderType; + transformPrimitive(value: unknown, _type: BuiltinType) { return value; } @@ -496,6 +503,9 @@ export abstract class BaseCrudDialect { .with('DateTime', () => this.buildDateTimeFilter(eb, table, field, payload) ) + .with('Bytes', () => + this.buildBytesFilter(eb, table, field, payload) + ) .exhaustive(); } @@ -745,6 +755,31 @@ export abstract class BaseCrudDialect { return this.and(eb, ...conditions); } + private buildBytesFilter( + eb: ExpressionBuilder, + table: string, + field: string, + payload: BytesFilter + ) { + const conditions = this.buildStandardFilter( + eb, + 'Bytes', + payload, + sql.ref(`${table}.${field}`), + (value) => this.transformPrimitive(value, 'Bytes'), + (value) => + this.buildBytesFilter( + eb, + table, + field, + value as BytesFilter + ), + true, + ['equals', 'in', 'notIn', 'not'] + ); + return this.and(eb, ...conditions.conditions); + } + private buildEnumFilter( eb: ExpressionBuilder, table: string, @@ -948,6 +983,11 @@ export abstract class BaseCrudDialect { value: Record> ): ExpressionWrapper; + abstract buildArrayLength( + eb: ExpressionBuilder, + array: Expression + ): ExpressionWrapper; + get supportsUpdateWithLimit() { return true; } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 9e2fd679..ec8cf22b 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -21,6 +21,10 @@ import { BaseCrudDialect } from './base'; export class PostgresCrudDialect< Schema extends SchemaDef > extends BaseCrudDialect { + override get provider() { + return 'postgresql' as const; + } + override transformPrimitive(value: unknown, type: BuiltinType) { return match(type) .with('DateTime', () => @@ -324,4 +328,11 @@ export class PostgresCrudDialect< override get supportsUpdateWithLimit(): boolean { return false; } + + override buildArrayLength( + eb: ExpressionBuilder, + array: Expression + ): ExpressionWrapper { + return eb.fn('array_length', [array]); + } } diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 85bada8b..7119be54 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -1,4 +1,5 @@ import { + ExpressionWrapper, sql, type Expression, type ExpressionBuilder, @@ -16,10 +17,15 @@ import { requireModel, } from '../../query-utils'; import { BaseCrudDialect } from './base'; +import type Decimal from 'decimal.js'; export class SqliteCrudDialect< Schema extends SchemaDef > extends BaseCrudDialect { + override get provider() { + return 'sqlite' as const; + } + override transformPrimitive(value: unknown, type: BuiltinType) { if (value === undefined) { return value; @@ -30,6 +36,8 @@ export class SqliteCrudDialect< .with('DateTime', () => value instanceof Date ? value.toISOString() : value ) + .with('Decimal', () => (value as Decimal).toString()) + .with('Bytes', () => Buffer.from(value as Uint8Array)) .otherwise(() => value); } @@ -256,4 +264,11 @@ export class SqliteCrudDialect< override get supportsUpdateWithLimit() { return false; } + + override buildArrayLength( + eb: ExpressionBuilder, + array: Expression + ): ExpressionWrapper { + return eb.fn('json_array_length', [array]); + } } diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 418b5cba..35115288 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -742,13 +742,24 @@ export abstract class BaseOperationHandler { } const createData = enumerate(input.data).map((item) => { + const newItem: any = {}; + for (const [name, value] of Object.entries(item)) { + const fieldDef = this.requireField(model, name); + invariant( + !fieldDef.relation, + 'createMany does not support relations' + ); + newItem[name] = this.dialect.transformPrimitive( + value, + fieldDef.type as BuiltinType + ); + } if (fromRelation) { - item = { ...item }; for (const { fk, pk } of relationKeyPairs) { - item[fk] = fromRelation.ids[pk]; + newItem[fk] = fromRelation.ids[pk]; } } - return this.fillGeneratedValues(modelDef, item); + return this.fillGeneratedValues(modelDef, newItem); }); const query = kysely @@ -782,7 +793,11 @@ export abstract class BaseOperationHandler { values[field] = generated; } } else if (fields[field]?.updatedAt) { - values[field] = new Date().toISOString(); + // TODO: should this work at kysely level instead? + values[field] = this.dialect.transformPrimitive( + new Date(), + 'DateTime' + ); } } } @@ -873,7 +888,22 @@ export abstract class BaseOperationHandler { : parentWhere; } - if (Object.keys(data).length === 0) { + // fill in automatically updated fields + const modelDef = this.requireModel(model); + let finalData = data; + for (const [fieldName, fieldDef] of Object.entries(modelDef.fields)) { + if (fieldDef.updatedAt) { + if (finalData === data) { + finalData = clone(data); + } + finalData[fieldName] = this.dialect.transformPrimitive( + new Date(), + 'DateTime' + ); + } + } + + if (Object.keys(finalData).length === 0) { // update without data, simply return const r = await this.readUnique(kysely, model, { where: combinedWhere, @@ -887,14 +917,14 @@ export abstract class BaseOperationHandler { const updateFields: any = {}; let thisEntity: any = undefined; - for (const field in data) { + for (const field in finalData) { const fieldDef = this.requireField(model, field); if ( isScalarField(this.schema, model, field) || isForeignKeyField(this.schema, model, field) ) { updateFields[field] = this.dialect.transformPrimitive( - data[field], + finalData[field], fieldDef.type as BuiltinType ); } else { @@ -922,7 +952,7 @@ export abstract class BaseOperationHandler { field, fieldDef, thisEntity, - data[field], + finalData[field], throwIfNotFound ); } @@ -935,25 +965,25 @@ export abstract class BaseOperationHandler { (await this.readUnique(kysely, model, { where: combinedWhere })) ); } else { + const idFields = getIdFields(this.schema, model); const query = kysely .updateTable(model) .where((eb) => this.dialect.buildFilter(eb, model, model, combinedWhere) ) .set(updateFields) - // TODO: return selectively - .returningAll(); - - let updatedEntity: any; - - try { - updatedEntity = await query.executeTakeFirst(); - } catch (err) { - const { sql, parameters } = query.compile(); - throw new QueryError( - `Error during update: ${err}, sql: ${sql}, parameters: ${parameters}` - ); - } + .returning(idFields as any); + + const updatedEntity = await query.executeTakeFirst(); + + // try { + // updatedEntity = await query.executeTakeFirst(); + // } catch (err) { + // const { sql, parameters } = query.compile(); + // throw new QueryError( + // `Error during update: ${err}, sql: ${sql}, parameters: ${parameters}` + // ); + // } if (!updatedEntity) { if (throwIfNotFound) { diff --git a/packages/runtime/src/client/crud/operations/update.ts b/packages/runtime/src/client/crud/operations/update.ts index e5dc1fc3..7101e11d 100644 --- a/packages/runtime/src/client/crud/operations/update.ts +++ b/packages/runtime/src/client/crud/operations/update.ts @@ -1,8 +1,9 @@ import { match } from 'ts-pattern'; +import { RejectedByPolicyError } from '../../../plugins/policy/errors'; import type { GetModels, SchemaDef } from '../../../schema'; import type { UpdateArgs, UpdateManyArgs } from '../../crud-types'; -import { getIdValues, requireField } from '../../query-utils'; import { BaseOperationHandler } from './base'; +import { getIdValues } from '../../query-utils'; export class UpdateOperationHandler< Schema extends SchemaDef @@ -23,54 +24,24 @@ export class UpdateOperationHandler< } private async runUpdate(args: UpdateArgs>) { - const hasRelationUpdate = Object.keys(args.data).some( - (f) => !!requireField(this.schema, this.model, f).relation - ); - - const returnRelations = this.needReturnRelations(this.model, args); - - let result: any; - if (hasRelationUpdate) { - // employ a transaction - try { - result = await this.safeTransaction(async (tx) => { - const updateResult = await this.update( - tx, - this.model, - args.where, - args.data - ); - return this.readUnique(tx, this.model, { - select: args.select, - include: args.include, - where: getIdValues( - this.schema, - this.model, - updateResult - ), - }); - }); - } catch (err) { - // console.error(err); - throw err; - } - } else { - // simple update - const updateResult = await this.update( - this.kysely, + const result = await this.safeTransaction(async (tx) => { + const updated = await this.update( + tx, this.model, args.where, args.data ); - if (returnRelations) { - result = await this.readUnique(this.kysely, this.model, { - select: args.select, - include: args.include, - where: getIdValues(this.schema, this.model, updateResult), - }); - } else { - result = this.trimResult(updateResult, args); - } + return this.readUnique(tx, this.model, { + select: args.select, + include: args.include, + where: getIdValues(this.schema, this.model, updated), + }); + }); + + if (!result) { + throw new RejectedByPolicyError( + 'result is not allowed to be read back' + ); } return result; diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 20f543bd..0232af80 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -1,3 +1,4 @@ +import Decimal from 'decimal.js'; import { match, P } from 'ts-pattern'; import { z, ZodSchema } from 'zod'; import type { @@ -164,9 +165,12 @@ export class InputValidator { .with('Int', () => z.number()) .with('Float', () => z.number()) .with('Boolean', () => z.boolean()) - .with('BigInt', () => z.string()) - .with('Decimal', () => z.string()) + .with('BigInt', () => z.union([z.number(), z.bigint()])) + .with('Decimal', () => + z.union([z.number(), z.instanceof(Decimal), z.string()]) + ) .with('DateTime', () => z.union([z.date(), z.string().datetime()])) + .with('Bytes', () => z.instanceof(Uint8Array)) .otherwise(() => z.unknown()); } @@ -325,11 +329,15 @@ export class InputValidator { protected makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) { return match(type) .with('String', () => this.makeStringFilterSchema(optional)) - .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), () => - this.makeNumberFilterSchema(optional) + .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) => + this.makeNumberFilterSchema( + this.makePrimitiveSchema(type), + optional + ) ) .with('Boolean', () => this.makeBooleanFilterSchema(optional)) .with('DateTime', () => this.makeDateTimeFilterSchema(optional)) + .with('Bytes', () => this.makeBytesFilterSchema(optional)) .exhaustive(); } @@ -353,6 +361,24 @@ export class InputValidator { ]); } + private makeBytesFilterSchema(optional: boolean): ZodSchema { + const baseSchema = z.instanceof(Uint8Array); + const components = this.makeCommonPrimitiveFilterComponents( + baseSchema, + optional, + () => z.instanceof(Uint8Array) + ); + return z.union([ + this.nullableIf(baseSchema, optional), + z.object({ + equals: components.equals, + in: components.in, + notIn: components.notIn, + not: components.not, + }), + ]); + } + private makeCommonPrimitiveFilterComponents( baseSchema: ZodSchema, optional: boolean, @@ -388,10 +414,12 @@ export class InputValidator { ]); } - private makeNumberFilterSchema(optional: boolean): ZodSchema { - const base = z.union([z.number(), z.bigint()]); - return this.makeCommonPrimitiveFilterSchema(base, optional, () => - z.lazy(() => this.makeNumberFilterSchema(optional)) + private makeNumberFilterSchema( + baseSchema: ZodSchema, + optional: boolean + ): ZodSchema { + return this.makeCommonPrimitiveFilterSchema(baseSchema, optional, () => + z.lazy(() => this.makeNumberFilterSchema(baseSchema, optional)) ); } diff --git a/packages/runtime/src/client/functions.ts b/packages/runtime/src/client/functions.ts new file mode 100644 index 00000000..130bcf9b --- /dev/null +++ b/packages/runtime/src/client/functions.ts @@ -0,0 +1,121 @@ +import { sql, type Expression, type ExpressionBuilder } from 'kysely'; +import type { ZModelFunction } from './options'; +import type { BaseCrudDialect } from './crud/dialects/base'; +import { match } from 'ts-pattern'; + +// TODO: migrate default value generation functions to here too + +export const contains: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[] +) => { + const [field, search, caseInsensitive = false] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + if (!search) { + throw new Error('"search" parameter is required'); + } + const searchExpr = eb.fn('CONCAT', [sql.lit('%'), search, sql.lit('%')]); + return eb(field, caseInsensitive ? 'ilike' : 'like', searchExpr); +}; + +export const search: ZModelFunction = ( + _eb: ExpressionBuilder, + _args: Expression[] +) => { + throw new Error(`"search" function is not implemented yet`); +}; + +export const startsWith: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[] +) => { + const [field, search] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + if (!search) { + throw new Error('"search" parameter is required'); + } + return eb(field, 'like', eb.fn('CONCAT', [search, sql.lit('%')])); +}; + +export const endsWith: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[] +) => { + const [field, search] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + if (!search) { + throw new Error('"search" parameter is required'); + } + return eb(field, 'like', eb.fn('CONCAT', [sql.lit('%'), search])); +}; + +export const has: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[] +) => { + const [field, search] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + if (!search) { + throw new Error('"search" parameter is required'); + } + return eb(field, '@>', [search]); +}; + +export const hasEvery: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[] +) => { + const [field, search] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + if (!search) { + throw new Error('"search" parameter is required'); + } + return eb(field, '@>', search); +}; + +export const hasSome: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[] +) => { + const [field, search] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + if (!search) { + throw new Error('"search" parameter is required'); + } + return eb(field, '&&', search); +}; + +export const isEmpty: ZModelFunction = ( + eb: ExpressionBuilder, + args: Expression[], + dialect: BaseCrudDialect +) => { + const [field] = args; + if (!field) { + throw new Error('"field" parameter is required'); + } + return eb(dialect.buildArrayLength(eb, field), '=', sql.lit(0)); +}; + +export const now: ZModelFunction = ( + eb: ExpressionBuilder, + _args: Expression[], + dialect: BaseCrudDialect +) => { + return match(dialect.provider) + .with('postgresql', () => eb.fn('now')) + .with('sqlite', () => sql.raw('CURRENT_TIMESTAMP')) + .exhaustive(); +}; diff --git a/packages/runtime/src/client/helpers/schema-db-pusher.ts b/packages/runtime/src/client/helpers/schema-db-pusher.ts index 565f24a1..547d705f 100644 --- a/packages/runtime/src/client/helpers/schema-db-pusher.ts +++ b/packages/runtime/src/client/helpers/schema-db-pusher.ts @@ -186,6 +186,9 @@ export class SchemaDbPusher { .with('BigInt', () => 'bigint') .with('Decimal', () => 'decimal') .with('DateTime', () => 'timestamp') + .with('Bytes', () => + this.schema.provider.type === 'postgresql' ? 'bytea' : 'blob' + ) .otherwise(() => { throw new Error(`Unsupported field type: ${type}`); }); diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index 561c779e..57986d30 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -1,4 +1,5 @@ import type { + Expression, ExpressionBuilder, KyselyConfig, PostgresDialectConfig, @@ -14,6 +15,7 @@ import type { } from '../schema/schema'; import type { PrependParameter } from '../utils/type-utils'; import type { ClientContract, ProcedureFunc } from './contract'; +import type { BaseCrudDialect } from './crud/dialects/base'; import type { RuntimePlugin } from './plugin'; import type { ToKyselySchema } from './query-builder'; @@ -24,12 +26,26 @@ type DialectConfig = ? Optional : never; +export type ZModelFunction = ( + eb: ExpressionBuilder, keyof ToKyselySchema>, + args: Expression[], + dialect: BaseCrudDialect +) => Expression; + export type ClientOptions = { /** * Database dialect configuration. */ dialectConfig?: DialectConfig; + /** + * Custom functions. + */ + functions?: Record>; + + /** + * Plugins. + */ plugins?: RuntimePlugin[]; /** diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index 5649b690..e737e818 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -1,3 +1,5 @@ +import Decimal from 'decimal.js'; +import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; import type { FieldDef, GetModels, SchemaDef } from '../schema'; import type { BuiltinType } from '../schema/schema'; @@ -73,9 +75,36 @@ export class ResultProcessor { return match(type) .with('Boolean', () => this.transformBoolean(value)) .with('DateTime', () => this.transformDate(value)) + .with('Bytes', () => this.transformBytes(value)) + .with('Decimal', () => this.transformDecimal(value)) + .with('BigInt', () => this.transformBigInt(value)) .otherwise(() => value); } + private transformDecimal(value: unknown) { + if (value instanceof Decimal) { + return value; + } + invariant( + typeof value === 'string' || + typeof value === 'number' || + value instanceof Decimal, + `Expected string, number or Decimal, got ${typeof value}` + ); + return new Decimal(value); + } + + private transformBigInt(value: unknown) { + if (typeof value === 'bigint') { + return value; + } + invariant( + typeof value === 'string' || typeof value === 'number', + `Expected string or number, got ${typeof value}` + ); + return BigInt(value); + } + private transformBoolean(value: unknown) { return !!value; } @@ -89,4 +118,8 @@ export class ResultProcessor { return value; } } + + private transformBytes(value: unknown) { + return Buffer.isBuffer(value) ? Uint8Array.from(value) : value; + } } diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 3821b8a2..18456b61 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -1,7 +1,9 @@ +import type { OperandExpression } from 'kysely'; import { AliasNode, BinaryOperationNode, ColumnNode, + expressionBuilder, FromNode, FunctionNode, IdentifierNode, @@ -10,8 +12,10 @@ import { SelectionNode, SelectQueryNode, TableNode, + ValueListNode, ValueNode, WhereNode, + type ExpressionBuilder, type OperationNode, } from 'kysely'; import invariant from 'tiny-invariant'; @@ -25,7 +29,12 @@ import { getRelationForeignKeyFieldPairs, requireField, } from '../../client/query-utils'; -import type { CallExpression, FieldExpression, SchemaDef } from '../../schema'; +import type { + ArrayExpression, + CallExpression, + FieldExpression, + SchemaDef, +} from '../../schema'; import { Expression, type BinaryExpression, @@ -41,7 +50,8 @@ import { conjunction, disjunction, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { model: GetModels; alias?: string; - thisEntity?: Record; + thisEntity?: Record; + auth?: any; }; // a registry of expression handlers marked with @expr @@ -105,6 +115,17 @@ export class ExpressionTransformer { ); } + @expr('array') + // @ts-ignore + private _array( + expr: ArrayExpression, + context: ExpressionTransformerContext + ) { + return ValueListNode.create( + expr.items.map((item) => this.transform(item, context)) + ); + } + @expr('field') // @ts-ignore private _field( @@ -119,7 +140,11 @@ export class ExpressionTransformer { return this.createColumnRef(expr.field, context); } } else { - return this._relation(expr.field, fieldDef.type, context); + return this.transformRelationAccess( + expr.field, + fieldDef.type, + context + ); } } @@ -160,11 +185,23 @@ export class ExpressionTransformer { const left = this.transform(expr.left, context); const right = this.transform(expr.right, context); - if (this.isNullNode(right)) { + if (op === 'in') { invariant( - expr.op === '==' || expr.op === '!=', - 'Comparison with null must be "==" or "!="' + ValueListNode.is(right), + '"in" operation requires right operand to be a value list' ); + if (this.isNullNode(left)) { + return this.transformValue(false, 'Boolean'); + } else { + return BinaryOperationNode.create( + left, + OperatorNode.create('in'), + right + ); + } + } + + if (this.isNullNode(right)) { return expr.op === '==' ? BinaryOperationNode.create( left, @@ -177,20 +214,16 @@ export class ExpressionTransformer { right ); } else if (this.isNullNode(left)) { - invariant( - expr.op === '==' || expr.op === '!=', - 'Comparison with null must be "==" or "!="' - ); return expr.op === '==' ? BinaryOperationNode.create( right, OperatorNode.create('is'), - left + ValueNode.createImmediate(null) ) : BinaryOperationNode.create( right, OperatorNode.create('is not'), - left + ValueNode.createImmediate(null) ); } @@ -201,10 +234,6 @@ export class ExpressionTransformer { ); } - private isNullNode(node: OperationNode) { - return ValueNode.is(node) && node.value === null; - } - private transformCollectionPredicate( expr: BinaryExpression, context: ExpressionTransformerContext @@ -371,7 +400,7 @@ export class ExpressionTransformer { } } - private transformValue(value: unknown, type: BuiltinType): OperationNode { + private transformValue(value: unknown, type: BuiltinType) { return ValueNode.create( this.dialect.transformPrimitive(value, type) ?? null ); @@ -401,16 +430,65 @@ export class ExpressionTransformer { @expr('call') // @ts-ignore - private _call(expr: CallExpression) { - throw new QueryError(`Unknown function: ${expr.function}`); + private _call( + expr: CallExpression, + context: ExpressionTransformerContext + ) { + const result = this.transformCall(expr, context); + return result.toOperationNode(); } - private isAuthCall(value: unknown): value is CallExpression { - return Expression.isCall(value) && value.function === 'auth'; + private transformCall( + expr: CallExpression, + context: ExpressionTransformerContext + ) { + const func = this.clientOptions.functions?.[expr.function]; + if (!func) { + throw new QueryError(`Function not implemented: ${expr.function}`); + } + const eb = expressionBuilder(); + return func( + eb, + (expr.args ?? []).map((arg) => + this.transformCallArg(eb, arg, context) + ), + this.dialect + ); } - private isAuthMember(expr: Expression): boolean { - return Expression.isMember(expr) && this.isAuthCall(expr.receiver); + private transformCallArg( + eb: ExpressionBuilder, + arg: Expression, + context: ExpressionTransformerContext + ): OperandExpression { + if (Expression.isLiteral(arg)) { + return eb.val(arg.value); + } + + if (Expression.isField(arg)) { + return context.thisEntity + ? eb.val(context.thisEntity[arg.field]?.value) + : eb.ref(arg.field); + } + + if (Expression.isCall(arg)) { + return this.transformCall(arg, context); + } + + if (this.isAuthMember(arg)) { + const valNode = this.valueMemberAccess( + context.auth, + arg as MemberExpression, + this.authType + ); + return valNode ? eb.val(valNode.value) : eb.val(null); + } + + // TODO + // if (Expression.isMember(arg)) { + // } + + throw new InternalError(`Unsupported argument expression: ${arg.kind}`); } @expr('member') @@ -460,12 +538,16 @@ export class ExpressionTransformer { const { fieldDef, fromModel } = memberFields[i]!; if (fieldDef.relation) { - const relation = this._relation(member, fieldDef.type, { - ...context, - model: fromModel as GetModels, - alias: undefined, - thisEntity: undefined, - }); + const relation = this.transformRelationAccess( + member, + fieldDef.type, + { + ...context, + model: fromModel as GetModels, + alias: undefined, + thisEntity: undefined, + } + ); if (currNode) { invariant( SelectQueryNode.is(currNode), @@ -536,9 +618,7 @@ export class ExpressionTransformer { return this.transformValue(fieldValue, fieldDef.type as BuiltinType); } - // @expr('relation') - // @ts-ignore - private _relation( + private transformRelationAccess( field: string, relationModel: string, context: ExpressionTransformerContext @@ -644,4 +724,16 @@ export class ExpressionTransformer { TableNode.create(context.alias ?? context.model) ); } + + private isAuthCall(value: unknown): value is CallExpression { + return Expression.isCall(value) && value.function === 'auth'; + } + + private isAuthMember(expr: Expression) { + return Expression.isMember(expr) && this.isAuthCall(expr.receiver); + } + + private isNullNode(node: OperationNode) { + return ValueNode.is(node) && node.value === null; + } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index 3afca2b9..fd7adbc3 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -14,7 +14,6 @@ import { SelectQueryNode, TableNode, UpdateQueryNode, - ValueListNode, ValueNode, ValuesNode, WhereNode, @@ -32,8 +31,13 @@ import type { OnKyselyQueryTransaction, ProceedKyselyQueryFunction, } from '../../client/plugin'; -import { getIdFields, requireModel } from '../../client/query-utils'; +import { + getIdFields, + requireField, + requireModel, +} from '../../client/query-utils'; import { Expression, type GetModels, type SchemaDef } from '../../schema'; +import type { BuiltinType } from '../../schema/schema'; import { ColumnCollector } from './column-collector'; import { RejectedByPolicyError } from './errors'; import { ExpressionTransformer } from './expression-transformer'; @@ -120,7 +124,7 @@ export class PolicyHandler< const transformedNode = this.transformNode(node); const result = await txProceed(transformedNode); - if (!InsertQueryNode.is(node) || !this.onlyReturningId(node)) { + if (!this.onlyReturningId(node)) { const readBackResult = await this.processReadBack( node, result, @@ -144,7 +148,7 @@ export class PolicyHandler< return result; } - private onlyReturningId(node: InsertQueryNode) { + private onlyReturningId(node: MutationQueryNode) { if (!node.returning) { return true; } @@ -165,13 +169,34 @@ export class PolicyHandler< return; } - const thisEntity: Record = {}; - const values = this.unwrapCreateValues(node.values); - for (let i = 0; i < node.columns?.length; i++) { - thisEntity[node.columns![i]!.column.name] = values[i]!; + const model = this.getMutationModel(node); + const fields = node.columns.map((c) => c.column.name); + const valueRows = this.unwrapCreateValueRows( + node.values, + model, + fields + ); + for (const values of valueRows) { + await this.enforcePreCreatePolicyForOne( + model, + fields, + values, + proceed + ); + } + } + + private async enforcePreCreatePolicyForOne( + model: GetModels, + fields: string[], + values: ValueNode[], + proceed: ProceedKyselyQueryFunction + ) { + const thisEntity: Record = {}; + for (let i = 0; i < fields.length; i++) { + thisEntity[fields[i]!] = values[i]!; } - const model = this.getMutationModel(node); const filter = this.buildPolicyFilter( model, undefined, @@ -195,15 +220,17 @@ export class PolicyHandler< } } - private unwrapCreateValues(node: OperationNode): readonly OperationNode[] { + private unwrapCreateValueRows( + node: OperationNode, + model: GetModels, + fields: string[] + ) { if (ValuesNode.is(node)) { - if (node.values.length === 1 && this.isValueList(node.values[0]!)) { - return this.unwrapCreateValues(node.values[0]!); - } else { - return node.values; - } + return node.values.map((v) => + this.unwrapCreateValueRow(v.values, model, fields) + ); } else if (PrimitiveValueListNode.is(node)) { - return node.values.map((v) => ValueNode.create(v)); + return [this.unwrapCreateValueRow(node.values, model, fields)]; } else { throw new InternalError( `Unexpected node kind: ${node.kind} for unwrapping create values` @@ -211,8 +238,45 @@ export class PolicyHandler< } } - private isValueList(node: OperationNode) { - return ValueListNode.is(node) || PrimitiveValueListNode.is(node); + private unwrapCreateValueRow( + data: readonly unknown[], + model: GetModels, + fields: string[] + ) { + invariant( + data.length === fields.length, + 'data length must match fields length' + ); + const result: ValueNode[] = []; + for (let i = 0; i < data.length; i++) { + const item = data[i]!; + const fieldDef = requireField( + this.client.$schema, + model, + fields[i]! + ); + if (typeof item === 'object' && item && 'kind' in item) { + invariant(item.kind === 'ValueNode', 'expecting a ValueNode'); + result.push( + ValueNode.create( + this.dialect.transformPrimitive( + (item as ValueNode).value, + fieldDef.type as BuiltinType + ) + ) + ); + } else { + result.push( + ValueNode.create( + this.dialect.transformPrimitive( + item, + fieldDef.type as BuiltinType + ) + ) + ); + } + } + return result; } private tryGetConstantPolicy( @@ -354,7 +418,7 @@ export class PolicyHandler< model: GetModels, alias: string | undefined, operation: PolicyOperation, - thisEntity?: Record + thisEntity?: Record ) { const policies = this.getModelPolicies(model, operation); if (policies.length === 0) { @@ -513,13 +577,18 @@ export class PolicyHandler< model: GetModels, alias: string | undefined, policy: Policy, - thisEntity?: Record + thisEntity?: Record ) { return new ExpressionTransformer( this.client.$schema, this.client.$options, this.client.$auth - ).transform(policy.condition, { model, alias, thisEntity }); + ).transform(policy.condition, { + model, + alias, + thisEntity, + auth: this.client.$auth, + }); } private getModelPolicies(modelName: string, operation: PolicyOperation) { diff --git a/packages/runtime/src/schema/expression.ts b/packages/runtime/src/schema/expression.ts index ce6f91ff..e2936fb9 100644 --- a/packages/runtime/src/schema/expression.ts +++ b/packages/runtime/src/schema/expression.ts @@ -69,7 +69,8 @@ export type BinaryOperator = | '>=' | '?' | '!' - | '^'; + | '^' + | 'in'; export const Expression = { literal: (value: string | number | boolean): LiteralExpression => { diff --git a/packages/runtime/src/schema/schema.ts b/packages/runtime/src/schema/schema.ts index 3dd1d330..19a149a4 100644 --- a/packages/runtime/src/schema/schema.ts +++ b/packages/runtime/src/schema/schema.ts @@ -86,7 +86,8 @@ export type BuiltinType = | 'Float' | 'BigInt' | 'Decimal' - | 'DateTime'; + | 'DateTime' + | 'Bytes'; export type MappedBuiltinType = | string diff --git a/packages/runtime/test/client-api/default-value-providers.test.ts b/packages/runtime/test/client-api/default-values.test.ts similarity index 91% rename from packages/runtime/test/client-api/default-value-providers.test.ts rename to packages/runtime/test/client-api/default-values.test.ts index 1fbed5d8..79bce819 100644 --- a/packages/runtime/test/client-api/default-value-providers.test.ts +++ b/packages/runtime/test/client-api/default-values.test.ts @@ -46,6 +46,10 @@ const schema = { type: 'String', default: Expression.call('ulid'), }, + dt: { + type: 'DateTime', + default: Expression.call('now'), + }, }, idFields: ['uuid'], uniqueFields: { @@ -56,7 +60,7 @@ const schema = { plugins: {}, } as const satisfies SchemaDef; -describe('Default Value Providers', () => { +describe('default values tests', () => { it('supports generators', async () => { const client = new ZenStackClient(schema); await client.$pushSchema(); @@ -69,5 +73,6 @@ describe('Default Value Providers', () => { expect(entity.nanoid).toSatisfy((id) => id.length >= 21); expect(entity.nanoid8).toSatisfy((id) => id.length === 8); expect(entity.ulid).toSatisfy(isValidUlid); + expect(entity.dt).toBeInstanceOf(Date); }); }); diff --git a/packages/runtime/test/client-api/filter.test.ts b/packages/runtime/test/client-api/filter.test.ts index aa2c5b10..3ec61ea9 100644 --- a/packages/runtime/test/client-api/filter.test.ts +++ b/packages/runtime/test/client-api/filter.test.ts @@ -549,5 +549,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }) ).toResolveTruthy(); }); + + // TODO: filter for bigint, decimal, bytes } ); diff --git a/packages/runtime/test/client-api/name-mapping.test.ts b/packages/runtime/test/client-api/name-mapping.test.ts index f6de4c9d..2631ca6c 100644 --- a/packages/runtime/test/client-api/name-mapping.test.ts +++ b/packages/runtime/test/client-api/name-mapping.test.ts @@ -59,7 +59,7 @@ describe('Name mapping tests', () => { } as const satisfies SchemaDef; it('works with model and implicit field mapping', async () => { - const client = new ZenStackClient(schema, { log: ['query'] }); + const client = new ZenStackClient(schema); await client.$pushSchema(); const r1 = await client.foo.create({ data: { id: '1', x: 1 }, @@ -89,7 +89,7 @@ describe('Name mapping tests', () => { }); it('works with explicit field mapping', async () => { - const client = new ZenStackClient(schema, { log: ['query'] }); + const client = new ZenStackClient(schema); await client.$pushSchema(); const r1 = await client.foo.create({ data: { id: '1', x: 1 }, diff --git a/packages/runtime/test/client-api/type-coverage.test.ts b/packages/runtime/test/client-api/type-coverage.test.ts index 242b107e..9ad08705 100644 --- a/packages/runtime/test/client-api/type-coverage.test.ts +++ b/packages/runtime/test/client-api/type-coverage.test.ts @@ -1,3 +1,44 @@ -import { describe } from 'vitest'; +import Decimal from 'decimal.js'; +import { describe, expect, it } from 'vitest'; +import { createTestClient } from '../utils'; -describe.skip('Type coverage', () => {}); +describe('zmodel type coverage tests', () => { + it('supports all types', async () => { + const db = await createTestClient( + ` + model Foo { + id String @id @default(cuid()) + + String String + Int Int + BigInt BigInt + DateTime DateTime + Float Float + Decimal Decimal + Boolean Boolean + Bytes Bytes + + @@allow('all', true) + } + ` + ); + + const date = new Date(); + const data = { + id: '1', + String: 'string', + Int: 100, + BigInt: BigInt(9007199254740991), + DateTime: date, + Float: 1.23, + Decimal: new Decimal(1.2345), + Boolean: true, + Bytes: new Uint8Array([1, 2, 3, 4]), + }; + + await db.foo.create({ data }); + + const r = await db.foo.findUnique({ where: { id: '1' } }); + expect(r.Bytes).toEqual(data.Bytes); + }); +}); diff --git a/packages/runtime/test/client-api/update.test.ts b/packages/runtime/test/client-api/update.test.ts index ccf4b3b2..1be6cac3 100644 --- a/packages/runtime/test/client-api/update.test.ts +++ b/packages/runtime/test/client-api/update.test.ts @@ -24,6 +24,8 @@ describe.each(createClientSpecs(PG_DB_NAME))( it('works with toplevel update', async () => { const user = await createUser(client, 'u1@test.com'); + expect(user.updatedAt).toBeInstanceOf(Date); + // not found await expect( client.user.update({ @@ -33,24 +35,30 @@ describe.each(createClientSpecs(PG_DB_NAME))( ).toBeRejectedNotFound(); // empty data - await expect( - client.user.update({ - where: { id: user.id }, - data: {}, - }) - ).resolves.toEqual(user); + let updated = await client.user.update({ + where: { id: user.id }, + data: {}, + }); + expect(updated).toMatchObject({ + email: user.email, + name: user.name, + }); + expect(updated.updatedAt.getTime()).toBeGreaterThan( + user.updatedAt.getTime() + ); // id as filter - await expect( - client.user.update({ - where: { id: user.id }, - data: { email: 'u2.test.com', name: 'Foo' }, - }) - ).resolves.toEqual({ - ...user, + updated = await client.user.update({ + where: { id: user.id }, + data: { email: 'u2.test.com', name: 'Foo' }, + }); + expect(updated).toMatchObject({ email: 'u2.test.com', name: 'Foo', }); + expect(updated.updatedAt.getTime()).toBeGreaterThan( + user.updatedAt.getTime() + ); // non-id unique as filter await expect( @@ -58,8 +66,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( where: { email: 'u2.test.com' }, data: { email: 'u2.test.com', name: 'Bar' }, }) - ).resolves.toEqual({ - ...user, + ).resolves.toMatchObject({ email: 'u2.test.com', name: 'Bar', }); @@ -1141,7 +1148,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).resolves.toMatchObject(post); + ).toResolveTruthy(); // not updated await expect( client.comment.findUnique({ where: { id: '4' } }) @@ -1162,7 +1169,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).resolves.toMatchObject(post); + ).toResolveTruthy(); }); }); @@ -1962,43 +1969,44 @@ describe.each(createClientSpecs(PG_DB_NAME))( client.user.findUnique({ where: { id: '1' } }) ).toResolveTruthy(); - // true - await expect( - client.profile.update({ - where: { id: profile.id }, - data: { - user: { - delete: true, - }, - }, - include: { user: true }, - }) - ).toResolveNull(); // cascade delete - await expect( - client.user.findUnique({ where: { id: '1' } }) - ).toResolveNull(); + // TODO: how to return for cascade delete? + // await expect( + // client.profile.update({ + // where: { id: profile.id }, + // data: { + // user: { + // delete: true, + // }, + // }, + // include: { user: true }, + // }) + // ).toResolveNull(); // cascade delete + // await expect( + // client.user.findUnique({ where: { id: '1' } }) + // ).toResolveNull(); + await client.user.delete({ where: { id: '1' } }); // with filter - profile = await client.profile.create({ - data: { - bio: 'Bio', - user: { create: { id: '1', email: 'u1@test.com' } }, - }, - }); - await expect( - client.profile.update({ - where: { id: profile.id }, - data: { - user: { - delete: { id: '1' }, - }, - }, - include: { user: true }, - }) - ).toResolveNull(); - await expect( - client.user.findUnique({ where: { id: '1' } }) - ).toResolveNull(); + // profile = await client.profile.create({ + // data: { + // bio: 'Bio', + // user: { create: { id: '1', email: 'u1@test.com' } }, + // }, + // }); + // await expect( + // client.profile.update({ + // where: { id: profile.id }, + // data: { + // user: { + // delete: { id: '1' }, + // }, + // }, + // include: { user: true }, + // }) + // ).toResolveNull(); + // await expect( + // client.user.findUnique({ where: { id: '1' } }) + // ).toResolveNull(); // null relation profile = await client.profile.create({ diff --git a/packages/runtime/test/plugin/kysely-on-query.test.ts b/packages/runtime/test/plugin/kysely-on-query.test.ts index bd5e6762..736a6dc8 100644 --- a/packages/runtime/test/plugin/kysely-on-query.test.ts +++ b/packages/runtime/test/plugin/kysely-on-query.test.ts @@ -13,7 +13,7 @@ describe('Kysely onQuery tests', () => { let _client: ClientContract; beforeEach(async () => { - _client = new ZenStackClient(schema, { log: ['query'] }); + _client = new ZenStackClient(schema); await _client.$pushSchema(); }); diff --git a/packages/runtime/test/policy/auth.test.ts b/packages/runtime/test/policy/auth.test.ts index 28cdd63e..5eabaee0 100644 --- a/packages/runtime/test/policy/auth.test.ts +++ b/packages/runtime/test/policy/auth.test.ts @@ -231,8 +231,7 @@ model Post { @@allow('all', true) } - `, - { log: ['query'] } + ` ); const rawDb = db.$unuseAll(); diff --git a/packages/runtime/test/policy/connect-disconnect.test.ts b/packages/runtime/test/policy/connect-disconnect.test.ts new file mode 100644 index 00000000..20191779 --- /dev/null +++ b/packages/runtime/test/policy/connect-disconnect.test.ts @@ -0,0 +1,376 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('connect and disconnect tests', () => { + const modelToMany = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + value Int @default(0) + + @@deny('read', value < 0) + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1? @relation(fields: [m1Id], references:[id]) + m1Id String? + m3 M3[] + + @@allow('read,create', true) + @@allow('update', !deleted) + } + + model M3 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m2 M2? @relation(fields: [m2Id], references:[id]) + m2Id String? + + @@allow('read,create', true) + @@allow('update', !deleted) + } + `; + + it('works with top-level to-many', async () => { + const db = await createPolicyTestClient(modelToMany); + const rawDb = db.$unuseAll(); + + // m1-1 -> m2-1 + await db.m2.create({ data: { id: 'm2-1', value: 1, deleted: false } }); + await db.m1.create({ + data: { + id: 'm1-1', + m2: { + connect: { id: 'm2-1' }, + }, + }, + }); + // mark m2-1 deleted + await rawDb.m2.update({ + where: { id: 'm2-1' }, + data: { deleted: true }, + }); + // disconnect denied because of violation of m2's update rule + await expect( + db.m1.update({ + where: { id: 'm1-1' }, + data: { + m2: { + disconnect: { id: 'm2-1' }, + }, + }, + }) + ).toBeRejectedNotFound(); + // reset m2-1 delete + await rawDb.m2.update({ + where: { id: 'm2-1' }, + data: { deleted: false }, + }); + // disconnect allowed + await db.m1.update({ + where: { id: 'm1-1' }, + data: { + m2: { + disconnect: { id: 'm2-1' }, + }, + }, + }); + + // connect during create denied + await db.m2.create({ data: { id: 'm2-2', value: 1, deleted: true } }); + await expect( + db.m1.create({ + data: { + m2: { + connect: { id: 'm2-2' }, + }, + }, + }) + ).toBeRejectedNotFound(); + + // mixed create and connect + await db.m2.create({ data: { id: 'm2-3', value: 1, deleted: false } }); + await db.m1.create({ + data: { + m2: { + connect: { id: 'm2-3' }, + create: { value: 1, deleted: false }, + }, + }, + }); + + await db.m2.create({ data: { id: 'm2-4', value: 1, deleted: true } }); + await expect( + db.m1.create({ + data: { + m2: { + connect: { id: 'm2-4' }, + create: { value: 1, deleted: false }, + }, + }, + }) + ).toBeRejectedNotFound(); + + // connectOrCreate + await db.m1.create({ + data: { + m2: { + connectOrCreate: { + where: { id: 'm2-5' }, + create: { value: 1 }, + }, + }, + }, + }); + + await db.m2.create({ data: { id: 'm2-6', value: 1, deleted: true } }); + await expect( + db.m1.create({ + data: { + m2: { + connectOrCreate: { + where: { id: 'm2-6' }, + create: { value: 1 }, + }, + }, + }, + }) + ).toBeRejectedNotFound(); + }); + + it('works with nested to-many', async () => { + const db = await createPolicyTestClient(modelToMany); + + await db.m3.create({ data: { id: 'm3-1', value: 1, deleted: false } }); + await expect( + db.m1.create({ + data: { + id: 'm1-1', + m2: { + create: { + value: 1, + m3: { connect: { id: 'm3-1' } }, + }, + }, + }, + }) + ).toResolveTruthy(); + + await db.m3.create({ data: { id: 'm3-2', value: 1, deleted: true } }); + await expect( + db.m1.create({ + data: { + m2: { + create: { + value: 1, + m3: { connect: { id: 'm3-2' } }, + }, + }, + }, + }) + ).toBeRejectedNotFound(); + }); + + const modelToOne = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1? @relation(fields: [m1Id], references:[id]) + m1Id String? @unique + + @@allow('read,create', true) + @@allow('update', !deleted) + } + `; + + it('works with to-one', async () => { + const db = await createPolicyTestClient(modelToOne); + const rawDb = db.$unuseAll(); + + await db.m2.create({ data: { id: 'm2-1', value: 1, deleted: false } }); + await db.m1.create({ + data: { + id: 'm1-1', + m2: { + connect: { id: 'm2-1' }, + }, + }, + }); + await rawDb.m2.update({ + where: { id: 'm2-1' }, + data: { deleted: true }, + }); + await expect( + db.m1.update({ + where: { id: 'm1-1' }, + data: { + m2: { + disconnect: { id: 'm2-1' }, + }, + }, + }) + ).toBeRejectedNotFound(); + await rawDb.m2.update({ + where: { id: 'm2-1' }, + data: { deleted: false }, + }); + await db.m1.update({ + where: { id: 'm1-1' }, + data: { + m2: { + disconnect: true, + }, + }, + }); + + await db.m2.create({ data: { id: 'm2-2', value: 1, deleted: true } }); + await expect( + db.m1.create({ + data: { + m2: { + connect: { id: 'm2-2' }, + }, + }, + }) + ).toBeRejectedNotFound(); + + // connectOrCreate + await db.m1.create({ + data: { + m2: { + connectOrCreate: { + where: { id: 'm2-3' }, + create: { value: 1 }, + }, + }, + }, + }); + + await db.m2.create({ data: { id: 'm2-4', value: 1, deleted: true } }); + await expect( + db.m1.create({ + data: { + m2: { + connectOrCreate: { + where: { id: 'm2-4' }, + create: { value: 1 }, + }, + }, + }, + }) + ).toBeRejectedNotFound(); + }); + + const modelImplicitManyToMany = ` + model M1 { + id String @id @default(uuid()) + value Int @default(0) + m2 M2[] + + @@deny('read', value < 0) + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1[] + + @@deny('read', value < 0) + @@allow('read,create', true) + @@allow('update', !deleted) + } + `; + + // TODO: many-to-many support + it.skip('works with implicit many-to-many', async () => { + const db = await createPolicyTestClient(modelImplicitManyToMany); + const rawDb = db.$unuseAll(); + + await rawDb.m1.create({ data: { id: 'm1-2', value: 1 } }); + await rawDb.m2.create({ + data: { id: 'm2-2', value: 1, deleted: true }, + }); + // m2-2 not updatable + await expect( + db.m1.update({ + where: { id: 'm1-2' }, + data: { m2: { connect: { id: 'm2-2' } } }, + }) + ).toBeRejectedByPolicy(); + }); + + const modelExplicitManyToMany = ` + model M1 { + id String @id @default(uuid()) + value Int @default(0) + m2 M1OnM2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + deleted Boolean @default(false) + m1 M1OnM2[] + + @@allow('read,create', true) + } + + model M1OnM2 { + m1 M1 @relation(fields: [m1Id], references: [id]) + m1Id String + m2 M2 @relation(fields: [m2Id], references: [id]) + m2Id String + + @@id([m1Id, m2Id]) + @@allow('read', true) + @@allow('create', !m2.deleted) + } + `; + + // TODO: many-to-many support + it.skip('works with explicit many-to-many', async () => { + const db = await createPolicyTestClient(modelExplicitManyToMany); + const rawDb = db.$unuseAll(); + + await rawDb.m1.create({ data: { id: 'm1-1', value: 1 } }); + await rawDb.m2.create({ data: { id: 'm2-1', value: 1 } }); + await expect( + db.m1OnM2.create({ + data: { + m1: { connect: { id: 'm1-1' } }, + m2: { connect: { id: 'm2-1' } }, + }, + }) + ).toResolveTruthy(); + + await rawDb.m1.create({ data: { id: 'm1-2', value: 1 } }); + await rawDb.m2.create({ + data: { id: 'm2-2', value: 1, deleted: true }, + }); + await expect( + db.m1OnM2.create({ + data: { + m1: { connect: { id: 'm1-2' } }, + m2: { connect: { id: 'm2-2' } }, + }, + }) + ).toBeRejectedByPolicy(); + }); +}); diff --git a/packages/runtime/test/policy/create-many-and-return.test.ts b/packages/runtime/test/policy/create-many-and-return.test.ts new file mode 100644 index 00000000..b0b110aa --- /dev/null +++ b/packages/runtime/test/policy/create-many-and-return.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('createManyAndReturn tests', () => { + it('works with model-level policies', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + level Int + + @@allow('read', level > 0) + } + + model Post { + id Int @id @default(autoincrement()) + title String + published Boolean @default(false) + userId Int + user User @relation(fields: [userId], references: [id]) + + @@allow('read', published) + @@allow('create', contains(title, 'hello')) + } + ` + ); + const rawDb = db.$unuseAll(); + + await rawDb.user.createMany({ + data: [ + { id: 1, level: 1 }, + { id: 2, level: 0 }, + ], + }); + + // create rule violation + await expect( + db.post.createManyAndReturn({ + data: [{ title: 'foo', userId: 1 }], + }) + ).toBeRejectedByPolicy(); + + // success + let r = await db.post.createManyAndReturn({ + data: [{ id: 1, title: 'hello1', userId: 1, published: true }], + }); + expect(r.length).toBe(1); + + // read-back check, only one result is readable + await expect( + db.post.createManyAndReturn({ + data: [ + { id: 2, title: 'hello2', userId: 1, published: true }, + { id: 3, title: 'hello3', userId: 1, published: false }, + ], + }) + ).toResolveWithLength(1); + // two are created indeed + await expect(rawDb.post.findMany()).resolves.toHaveLength(3); + }); + + // TODO: field-level policies support + it.skip('field-level policies', async () => { + const db = await createPolicyTestClient( + ` + model Post { + id Int @id @default(autoincrement()) + title String @allow('read', published) + published Boolean @default(false) + + @@allow('all', true) + } + ` + ); + const rawDb = db.$unuseAll(); + // create should succeed but one result's title field can't be read back + const r = await db.post.createManyAndReturn({ + data: [ + { title: 'post1', published: true }, + { title: 'post2', published: false }, + ], + }); + + expect(r.length).toBe(2); + expect(r[0].title).toBeTruthy(); + expect(r[1].title).toBeUndefined(); + + // check posts are created + await expect(rawDb.post.findMany()).resolves.toHaveLength(2); + }); +}); diff --git a/packages/runtime/test/policy/cross-model-field-comparison.test.ts b/packages/runtime/test/policy/cross-model-field-comparison.test.ts new file mode 100644 index 00000000..ad495209 --- /dev/null +++ b/packages/runtime/test/policy/cross-model-field-comparison.test.ts @@ -0,0 +1,221 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('cross-model field comparison tests', () => { + it('works with to-one relation', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + + @@allow('all', age == profile.age) + @@deny('update', age > 100) + } + + model Profile { + id Int @id + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const rawDb = db.$unuseAll(); + + const reset = async () => { + await rawDb.user.deleteMany(); + await rawDb.profile.deleteMany(); + }; + + // create + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profile: { create: { id: 1, age: 20 } }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + rawDb.user.findUnique({ where: { id: 1 } }) + ).toResolveNull(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profile: { create: { id: 1, age: 18 } }, + }, + }) + ).toResolveTruthy(); + await expect( + rawDb.user.findUnique({ where: { id: 1 } }) + ).toResolveTruthy(); + await reset(); + + // createMany + const profile = await rawDb.profile.create({ + data: { id: 1, age: 20 }, + }); + await expect( + db.user.createMany({ + data: [{ id: 1, age: 18, profileId: profile.id }], + }) + ).toBeRejectedByPolicy(); + await expect( + rawDb.user.findUnique({ where: { id: 1 } }) + ).toResolveNull(); + await expect( + db.user.createMany({ + data: { id: 1, age: 20, profileId: profile.id }, + }) + ).toResolveTruthy(); + await expect( + rawDb.user.findUnique({ where: { id: 1 } }) + ).toResolveTruthy(); + await reset(); + + // read + await rawDb.user.create({ + data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } }, + }); + await expect( + db.user.findUnique({ where: { id: 1 } }) + ).toResolveTruthy(); + await expect(db.user.findMany()).resolves.toHaveLength(1); + await rawDb.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect(db.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // update + await rawDb.user.create({ + data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } }, + }); + // update should succeed but read back is rejected + await expect( + db.user.update({ where: { id: 1 }, data: { age: 20 } }) + ).toBeRejectedByPolicy(); + await expect( + rawDb.user.findUnique({ where: { id: 1 } }) + ).resolves.toMatchObject({ age: 20 }); + await expect( + db.user.update({ where: { id: 1 }, data: { age: 18 } }) + ).toBeRejectedNotFound(); + await reset(); + + // // post update + // await rawDb.user.create({ + // data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } }, + // }); + // await expect( + // db.user.update({ where: { id: 1 }, data: { age: 15 } }) + // ).toBeRejectedByPolicy(); + // await expect( + // db.user.update({ where: { id: 1 }, data: { age: 20 } }) + // ).toResolveTruthy(); + // await reset(); + + // TODO: upsert support + // // upsert + // await rawDb.user.create({ + // data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }, + // }); + // await expect( + // db.user.upsert({ + // where: { id: 1 }, + // create: { id: 1, age: 25 }, + // update: { age: 25 }, + // }) + // ).toBeRejectedByPolicy(); + // await expect( + // db.user.upsert({ + // where: { id: 2 }, + // create: { + // id: 2, + // age: 18, + // profile: { create: { id: 2, age: 25 } }, + // }, + // update: { age: 25 }, + // }) + // ).toBeRejectedByPolicy(); + // await rawDb.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect( + // db.user.upsert({ + // where: { id: 1 }, + // create: { id: 1, age: 25 }, + // update: { age: 25 }, + // }) + // ).toResolveTruthy(); + // await expect( + // rawDb.user.findUnique({ where: { id: 1 } }) + // ).resolves.toMatchObject({ age: 25 }); + // await expect( + // db.user.upsert({ + // where: { id: 2 }, + // create: { + // id: 2, + // age: 25, + // profile: { create: { id: 2, age: 25 } }, + // }, + // update: { age: 25 }, + // }) + // ).toResolveTruthy(); + // await expect(rawDb.user.findMany()).resolves.toHaveLength(2); + // await reset(); + + // updateMany + await rawDb.user.create({ + data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }, + }); + // non updatable + await expect( + db.user.updateMany({ data: { age: 18 } }) + ).resolves.toMatchObject({ count: 0 }); + await rawDb.user.create({ + data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } }, + }); + // one of the two is updatable + await expect( + db.user.updateMany({ data: { age: 30 } }) + ).resolves.toMatchObject({ count: 1 }); + await expect( + rawDb.user.findUnique({ where: { id: 1 } }) + ).resolves.toMatchObject({ age: 18 }); + await expect( + rawDb.user.findUnique({ where: { id: 2 } }) + ).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await rawDb.user.create({ + data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }, + }); + await expect( + db.user.delete({ where: { id: 1 } }) + ).toBeRejectedNotFound(); + await expect(rawDb.user.findMany()).resolves.toHaveLength(1); + await rawDb.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(rawDb.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await rawDb.user.create({ + data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }, + }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await rawDb.user.create({ + data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } }, + }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(rawDb.user.findMany()).resolves.toHaveLength(1); + }); +}); diff --git a/packages/runtime/test/policy/policy-functions.test.ts b/packages/runtime/test/policy/policy-functions.test.ts new file mode 100644 index 00000000..de49b6d1 --- /dev/null +++ b/packages/runtime/test/policy/policy-functions.test.ts @@ -0,0 +1,238 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('policy functions tests', () => { + it('supports contains with case-sensitive field', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', contains(string, 'a')) + } + ` + ); + + await expect( + db.foo.create({ data: { string: 'bcd' } }) + ).toBeRejectedByPolicy(); + await expect( + db.foo.create({ data: { string: 'bac' } }) + ).toResolveTruthy(); + }); + + it('supports contains with case-sensitive non-field', async () => { + const db = await createPolicyTestClient( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', contains(auth().name, 'a')) + } + ` + ); + + await expect(db.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id: 'user1', name: 'bcd' }).foo.create({ data: {} }) + ).toBeRejectedByPolicy(); + await expect( + db.$setAuth({ id: 'user1', name: 'bac' }).foo.create({ data: {} }) + ).toResolveTruthy(); + }); + + it('supports contains with auth()', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', contains(string, auth().name)) + } + ` + ); + + // 'abc' contains null + await expect( + anonDb.foo.create({ data: { string: 'abc' } }) + ).toResolveTruthy(); + const db = anonDb.$setAuth({ id: '1', name: 'a' }); + await expect( + db.foo.create({ data: { string: 'bcd' } }) + ).toBeRejectedByPolicy(); + await expect( + db.foo.create({ data: { string: 'bac' } }) + ).toResolveTruthy(); + }); + + it('supports startsWith with field', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', startsWith(string, 'a')) + } + ` + ); + + await expect( + db.foo.create({ data: { string: 'bac' } }) + ).toBeRejectedByPolicy(); + await expect( + db.foo.create({ data: { string: 'abc' } }) + ).toResolveTruthy(); + }); + + it('supports startsWith with non-field', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', startsWith(auth().name, 'a')) + } + ` + ); + + await expect(anonDb.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(anonDb.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect( + anonDb + .$setAuth({ id: 'user1', name: 'abc' }) + .foo.create({ data: {} }) + ).toResolveTruthy(); + }); + + it('supports endsWith with field', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', endsWith(string, 'a')) + } + ` + ); + + await expect( + db.foo.create({ data: { string: 'bac' } }) + ).toBeRejectedByPolicy(); + await expect( + db.foo.create({ data: { string: 'bca' } }) + ).toResolveTruthy(); + }); + + it('supports endsWith with non-field', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', endsWith(auth().name, 'a')) + } + ` + ); + + await expect(anonDb.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect( + anonDb + .$setAuth({ id: 'user1', name: 'bac' }) + .foo.create({ data: {} }) + ).toBeRejectedByPolicy(); + await expect( + anonDb + .$setAuth({ id: 'user1', name: 'bca' }) + .foo.create({ data: {} }) + ).toResolveTruthy(); + }); + + it('supports in with field', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + string String + @@allow('all', string in ['a', 'b']) + } + ` + ); + + await expect( + db.foo.create({ data: { string: 'c' } }) + ).toBeRejectedByPolicy(); + await expect( + db.foo.create({ data: { string: 'b' } }) + ).toResolveTruthy(); + }); + + it('supports in with non-field', async () => { + const anonDb = await createPolicyTestClient( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', auth().name in ['abc', 'bcd']) + } + ` + ); + + await expect(anonDb.foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect( + anonDb + .$setAuth({ id: 'user1', name: 'abd' }) + .foo.create({ data: {} }) + ).toBeRejectedByPolicy(); + await expect( + anonDb + .$setAuth({ id: 'user1', name: 'abc' }) + .foo.create({ data: {} }) + ).toResolveTruthy(); + }); + + it('supports now', async () => { + const db = await createPolicyTestClient( + ` + model Foo { + id String @id @default(cuid()) + dt DateTime @default(now()) + @@allow('create,read', true) + @@allow('update', now() >= dt) + } + ` + ); + + const now = new Date(); + + const created = await db.foo.create({ + data: { id: '1', dt: new Date(now.getTime() + 1000) }, + }); + console.log(created); + + // violates `dt <= now()` + await expect( + db.foo.update({ where: { id: '1' }, data: { dt: now } }) + ).toBeRejectedNotFound(); + }); +}); diff --git a/packages/runtime/test/policy/todo-sample.test.ts b/packages/runtime/test/policy/todo-sample.test.ts index e09d416d..2d278d02 100644 --- a/packages/runtime/test/policy/todo-sample.test.ts +++ b/packages/runtime/test/policy/todo-sample.test.ts @@ -4,7 +4,7 @@ import { beforeAll, describe, expect, it } from 'vitest'; import type { SchemaDef } from '../../src/schema'; import { createPolicyTestClient } from './utils'; -describe('Todo sample', () => { +describe('todo sample tests', () => { let schema: SchemaDef; beforeAll(async () => { @@ -389,7 +389,7 @@ describe('Todo sample', () => { }); it('works with relation queries', async () => { - const anonDb = await createPolicyTestClient(schema, { log: ['query'] }); + const anonDb = await createPolicyTestClient(schema); await createSpaceAndUsers(anonDb.$unuseAll()); const user1Db = anonDb.$setAuth({ id: user1.id }); @@ -426,6 +426,67 @@ describe('Todo sample', () => { }); expect(r1.lists).toHaveLength(1); }); + + // TODO: `future()` support + it.skip('works with post-update checks', async () => { + const anonDb = await createPolicyTestClient(schema); + await createSpaceAndUsers(anonDb.$unuseAll()); + + const user1Db = anonDb.$setAuth({ id: user1.id }); + + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + todos: { + create: { + id: 'todo1', + title: 'Todo 1', + owner: { connect: { id: user1.id } }, + }, + }, + }, + }); + + // change list's owner + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + // change todo's owner + await expect( + user1Db.todo.update({ + where: { id: 'todo1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + // nested change todo's owner + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + todos: { + update: { + where: { id: 'todo1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + }); }); const user1 = { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index fc2416c1..e65835c0 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -139,9 +139,6 @@ importers: decimal.js: specifier: ^10.4.3 version: 10.4.3 - decimal.js-light: - specifier: ^2.5.1 - version: 2.5.1 kysely: specifier: ^0.27.5 version: 0.27.6 @@ -1368,9 +1365,6 @@ packages: supports-color: optional: true - decimal.js-light@2.5.1: - resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==} - decimal.js@10.4.3: resolution: {integrity: sha512-VBBaLc1MgL5XpzgIP7ny5Z6Nx3UrRkIViUkPUdtl9aya5amy3De1gsUUSB1g3+3sExYNjCAsAznmukyxCb1GRA==} @@ -3773,8 +3767,6 @@ snapshots: dependencies: ms: 2.1.3 - decimal.js-light@2.5.1: {} - decimal.js@10.4.3: {} decompress-response@6.0.0: