From 7bc1b2e90b81785bfa9097decd1eac39bcb0761e Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 8 May 2025 12:53:28 -0700 Subject: [PATCH 1/2] WIP --- packages/cli/package.json | 4 +- packages/cli/src/actions/generate.ts | 4 +- .../cli/src/prisma/prisma-schema-generator.ts | 16 +- packages/cli/test/ts-schema-gen.test.ts | 2 +- packages/cli/tsup.config.ts | 2 +- packages/language/res/stdlib.zmodel | 12 + packages/language/src/index.ts | 23 +- packages/language/src/utils.ts | 40 +- .../src/validators/expression-validator.ts | 43 +- packages/language/src/zmodel-linker.ts | 41 +- packages/language/src/zmodel-scope.ts | 123 ++-- packages/runtime/package.json | 8 +- packages/runtime/src/client/client-impl.ts | 24 + packages/runtime/src/client/contract.ts | 16 + .../runtime/src/client/crud/dialects/base.ts | 10 +- .../src/client/crud/operations/base.ts | 25 +- .../src/client/crud/operations/create.ts | 66 +- packages/runtime/src/client/errors.ts | 6 + .../src/client/executor/name-mapper.ts | 26 +- .../executor/zenstack-query-executor.ts | 4 +- .../src/client/helpers/schema-db-pusher.ts | 4 +- packages/runtime/src/client/options.ts | 6 +- packages/runtime/src/client/plugin.ts | 8 +- packages/runtime/src/client/query-utils.ts | 2 +- .../src/plugins/policy/column-collector.ts | 21 + .../plugins/policy/expression-transformer.ts | 596 ++++++++++++++---- .../runtime/src/plugins/policy/generator.ts | 84 --- packages/runtime/src/plugins/policy/index.ts | 1 - .../runtime/src/plugins/policy/options.ts | 16 - packages/runtime/src/plugins/policy/plugin.ts | 28 +- .../runtime/src/plugins/policy/plugin.zmodel | 12 - .../src/plugins/policy/policy-handler.ts | 524 +++++++++++++++ .../src/plugins/policy/policy-transformer.ts | 136 ---- packages/runtime/src/plugins/policy/types.ts | 21 +- packages/runtime/src/plugins/policy/utils.ts | 153 +++++ packages/runtime/src/schema/expression.ts | 64 +- packages/runtime/src/schema/schema.ts | 26 +- .../utils/default-operation-node-visitor.ts | 415 ++++++++++++ .../runtime/test/client-api/delete.test.ts | 2 +- .../runtime/test/client-api/update.test.ts | 44 +- .../runtime/test/policy/todo-sample.test.ts | 131 ++++ packages/runtime/test/schemas/todo.zmodel | 153 +++++ packages/runtime/test/test-schema.ts | 114 ++-- packages/runtime/test/vitest-ext.ts | 34 +- packages/runtime/test/vitest.d.ts | 3 +- packages/sdk/package.json | 42 ++ packages/sdk/src/index.ts | 4 + .../src/zmodel => sdk/src}/model-utils.ts | 23 +- .../zmodel => sdk/src}/ts-schema-generator.ts | 259 ++++++-- .../src}/zmodel-code-generator.ts | 4 +- packages/sdk/tsconfig.json | 8 + packages/sdk/tsup.config.ts | 13 + packages/testtools/package.json | 46 ++ packages/testtools/src/index.ts | 1 + .../test/utils.ts => testtools/src/schema.ts} | 24 +- packages/testtools/tsup.config.ts | 13 + pnpm-lock.yaml | 158 ++++- samples/blog/zenstack/schema.ts | 8 +- samples/blog/zenstack/schema.zmodel | 4 - 59 files changed, 2898 insertions(+), 802 deletions(-) create mode 100644 packages/runtime/src/plugins/policy/column-collector.ts delete mode 100644 packages/runtime/src/plugins/policy/generator.ts delete mode 100644 packages/runtime/src/plugins/policy/options.ts create mode 100644 packages/runtime/src/plugins/policy/policy-handler.ts delete mode 100644 packages/runtime/src/plugins/policy/policy-transformer.ts create mode 100644 packages/runtime/src/plugins/policy/utils.ts create mode 100644 packages/runtime/src/utils/default-operation-node-visitor.ts create mode 100644 packages/runtime/test/policy/todo-sample.test.ts create mode 100644 packages/runtime/test/schemas/todo.zmodel create mode 100644 packages/sdk/package.json create mode 100644 packages/sdk/src/index.ts rename packages/{cli/src/zmodel => sdk/src}/model-utils.ts (94%) rename packages/{cli/src/zmodel => sdk/src}/ts-schema-generator.ts (84%) rename packages/{cli/src/zmodel => sdk/src}/zmodel-code-generator.ts (99%) create mode 100644 packages/sdk/tsconfig.json create mode 100644 packages/sdk/tsup.config.ts create mode 100644 packages/testtools/package.json create mode 100644 packages/testtools/src/index.ts rename packages/{cli/test/utils.ts => testtools/src/schema.ts} (67%) create mode 100644 packages/testtools/tsup.config.ts diff --git a/packages/cli/package.json b/packages/cli/package.json index d8555b57..64f6af11 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -27,9 +27,10 @@ "pack": "pnpm pack" }, "dependencies": { - "@types/node": "^20.12.7", + "@types/node": "^18.0.0", "@zenstackhq/language": "workspace:*", "@zenstackhq/runtime": "workspace:*", + "@zenstackhq/sdk": "workspace:*", "async-exit-hook": "^2.0.1", "colors": "1.4.0", "commander": "^8.3.0", @@ -43,6 +44,7 @@ "typescript": "^5.0.0" }, "devDependencies": { + "@zenstackhq/testtools": "workspace:*", "@types/async-exit-hook": "^2.0.0", "@types/better-sqlite3": "^7.6.13", "@types/semver": "^7.3.13", diff --git a/packages/cli/src/actions/generate.ts b/packages/cli/src/actions/generate.ts index 92c14121..855746bb 100644 --- a/packages/cli/src/actions/generate.ts +++ b/packages/cli/src/actions/generate.ts @@ -1,11 +1,11 @@ import { isPlugin, LiteralExpr, type Model } from '@zenstackhq/language/ast'; import type { CliGenerator } from '@zenstackhq/runtime/client'; +import { TsSchemaGenerator } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'node:fs'; import path from 'node:path'; import invariant from 'tiny-invariant'; import { PrismaSchemaGenerator } from '../prisma/prisma-schema-generator'; -import { TsSchemaGenerator } from '../zmodel/ts-schema-generator'; import { getSchemaFile, loadSchemaDocument } from './action-utils'; type Options = { @@ -25,7 +25,7 @@ export async function run(options: Options) { // generate TS schema const tsSchemaFile = path.join(outputPath, 'schema.ts'); - await new TsSchemaGenerator().generate(schemaFile, tsSchemaFile); + await new TsSchemaGenerator().generate(schemaFile, [], tsSchemaFile); await runPlugins(model, outputPath, tsSchemaFile); diff --git a/packages/cli/src/prisma/prisma-schema-generator.ts b/packages/cli/src/prisma/prisma-schema-generator.ts index e7c29ade..62744403 100644 --- a/packages/cli/src/prisma/prisma-schema-generator.ts +++ b/packages/cli/src/prisma/prisma-schema-generator.ts @@ -32,12 +32,7 @@ import { import { AstUtils } from 'langium'; import { match, P } from 'ts-pattern'; -import { - hasAttribute, - isDelegateModel, - isIdField, -} from '../zmodel/model-utils'; -import { ZModelCodeGenerator } from '../zmodel/zmodel-code-generator'; +import { ModelUtils, ZModelCodeGenerator } from '@zenstackhq/sdk'; import { AttributeArgValue, ModelField, @@ -165,7 +160,7 @@ export class PrismaSchemaGenerator { ? prisma.addView(decl.name) : prisma.addModel(decl.name); for (const field of decl.fields) { - if (hasAttribute(field, '@computed')) { + if (ModelUtils.hasAttribute(field, '@computed')) { continue; // skip computed fields } // TODO: exclude fields inherited from delegate @@ -274,7 +269,7 @@ export class PrismaSchemaGenerator { (attr) => // when building physical schema, exclude `@default` for id fields inherited from delegate base !( - isIdField(field) && + ModelUtils.isIdField(field) && this.isInheritedFromDelegate(field) && attr.decl.$refText === '@default' ) @@ -360,7 +355,10 @@ export class PrismaSchemaGenerator { } private isInheritedFromDelegate(field: DataModelField) { - return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom); + return ( + field.$inheritedFrom && + ModelUtils.isDelegateModel(field.$inheritedFrom) + ); } private makeFieldAttribute(attr: DataModelFieldAttribute) { diff --git a/packages/cli/test/ts-schema-gen.test.ts b/packages/cli/test/ts-schema-gen.test.ts index c9dddb46..fbf8414a 100644 --- a/packages/cli/test/ts-schema-gen.test.ts +++ b/packages/cli/test/ts-schema-gen.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from 'vitest'; -import { generateTsSchema } from './utils'; +import { generateTsSchema } from '@zenstackhq/testtools'; describe('TypeScript schema generation tests', () => { it('generates correct data models', async () => { diff --git a/packages/cli/tsup.config.ts b/packages/cli/tsup.config.ts index 67517d13..2496f3ea 100644 --- a/packages/cli/tsup.config.ts +++ b/packages/cli/tsup.config.ts @@ -9,5 +9,5 @@ export default defineConfig({ sourcemap: true, clean: true, dts: true, - format: ['esm'], + format: ['esm', 'cjs'], }); diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index cced265c..363d818b 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -707,3 +707,15 @@ attribute @json() @@@targetField([TypeDefField]) * Marks a field to be computed. */ attribute @computed() + +/** + * Gets the current login user. + */ +function auth(): Any { +} @@@expressionContext([DefaultValue, AccessPolicy]) + +/** + * Used to specify the model for resolving `auth()` function call in access policies. A Zmodel + * can have at most one model with this attribute. By default, the model named "User" is used. + */ +attribute @@auth() @@@supportTypeDef diff --git a/packages/language/src/index.ts b/packages/language/src/index.ts index a2d49035..f2dc07a3 100644 --- a/packages/language/src/index.ts +++ b/packages/language/src/index.ts @@ -18,7 +18,8 @@ export class DocumentLoadError extends Error { } export async function loadDocument( - fileName: string + fileName: string, + pluginModelFiles: string[] = [] ): Promise< | { success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] } @@ -55,16 +56,28 @@ export async function loadDocument( ) ); - const langiumDocuments = services.shared.workspace.LangiumDocuments; + // load plugin model files + const pluginDocs = await Promise.all( + pluginModelFiles.map((file) => + services.shared.workspace.LangiumDocuments.getOrCreateDocument( + URI.file(path.resolve(file)) + ) + ) + ); + // load the document + const langiumDocuments = services.shared.workspace.LangiumDocuments; const document = await langiumDocuments.getOrCreateDocument( URI.file(path.resolve(fileName)) ); // build the document together with standard library, plugin modules, and imported documents - await services.shared.workspace.DocumentBuilder.build([stdLib, document], { - validation: true, - }); + await services.shared.workspace.DocumentBuilder.build( + [stdLib, ...pluginDocs, document], + { + validation: true, + } + ); const diagnostics = langiumDocuments.all .flatMap((doc) => diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index f4653692..39353803 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -76,13 +76,13 @@ export function isFromStdlib(node: AstNode) { ); } -// export function isAuthInvocation(node: AstNode) { -// return ( -// isInvocationExpr(node) && -// node.function.ref?.name === 'auth' && -// isFromStdlib(node.function.ref) -// ); -// } +export function isAuthInvocation(node: AstNode) { + return ( + isInvocationExpr(node) && + node.function.ref?.name === 'auth' && + isFromStdlib(node.function.ref) + ); +} /** * Try getting string value from a potential string literal expression @@ -161,12 +161,12 @@ export function mapBuiltinTypeToExpressionType( } } -// export function isAuthOrAuthMemberAccess(expr: Expression): boolean { -// return ( -// isAuthInvocation(expr) || -// (isMemberAccessExpr(expr) && isAuthOrAuthMemberAccess(expr.operand)) -// ); -// } +export function isAuthOrAuthMemberAccess(expr: Expression): boolean { + return ( + isAuthInvocation(expr) || + (isMemberAccessExpr(expr) && isAuthOrAuthMemberAccess(expr.operand)) + ); +} export function isEnumFieldReference(node: AstNode): node is ReferenceExpr { return isReferenceExpr(node) && isEnumField(node.target.ref); @@ -598,13 +598,13 @@ export function getAllDeclarationsIncludingImports( return model.declarations.concat(...imports.map((imp) => imp.declarations)); } -// export function getAuthDecl(decls: (DataModel | TypeDef)[]) { -// let authModel = decls.find((m) => hasAttribute(m, '@@auth')); -// if (!authModel) { -// authModel = decls.find((m) => m.name === 'User'); -// } -// return authModel; -// } +export function getAuthDecl(decls: (DataModel | TypeDef)[]) { + let authModel = decls.find((m) => hasAttribute(m, '@@auth')); + if (!authModel) { + authModel = decls.find((m) => m.name === 'User'); + } + return authModel; +} export function isFutureInvocation(node: AstNode) { return ( diff --git a/packages/language/src/validators/expression-validator.ts b/packages/language/src/validators/expression-validator.ts index 27abaf69..ea1d1291 100644 --- a/packages/language/src/validators/expression-validator.ts +++ b/packages/language/src/validators/expression-validator.ts @@ -19,6 +19,8 @@ import { import { findUpAst, getAttributeArgLiteral, + isAuthInvocation, + isAuthOrAuthMemberAccess, isDataModelFieldReference, isEnumFieldReference, typeAssignable, @@ -32,18 +34,17 @@ export default class ExpressionValidator implements AstValidator { validate(expr: Expression, accept: ValidationAcceptor): void { // deal with a few cases where reference resolution fail silently if (!expr.$resolvedType) { - // TODO: revisit this - // if (isAuthInvocation(expr)) { - // // check was done at link time - // accept( - // 'error', - // 'auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found', - // { node: expr } - // ); - // } else { - - const hasReferenceResolutionError = AstUtils.streamAst(expr).some( - (node) => { + if (isAuthInvocation(expr)) { + // check was done at link time + accept( + 'error', + 'auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found', + { node: expr } + ); + } else { + const hasReferenceResolutionError = AstUtils.streamAst( + expr + ).some((node) => { if (isMemberAccessExpr(node)) { return !!node.member.error; } @@ -51,15 +52,14 @@ export default class ExpressionValidator implements AstValidator { return !!node.target.error; } return false; - } - ); - if (!hasReferenceResolutionError) { - // report silent errors not involving linker errors - accept('error', 'Expression cannot be resolved', { - node: expr, }); + if (!hasReferenceResolutionError) { + // report silent errors not involving linker errors + accept('error', 'Expression cannot be resolved', { + node: expr, + }); + } } - // } } // extra validations by expression type @@ -379,9 +379,8 @@ export default class ExpressionValidator implements AstValidator { isEnumFieldReference(expr) || // null isNullExpr(expr) || - // TODO: revise cross-model field comparison - // // `auth()` access - // isAuthOrAuthMemberAccess(expr) || + // `auth()` access + isAuthOrAuthMemberAccess(expr) || // array (isArrayExpr(expr) && expr.items.every((item) => this.isNotModelFieldExpr(item))) diff --git a/packages/language/src/zmodel-linker.ts b/packages/language/src/zmodel-linker.ts index 50f10bc5..d6d57944 100644 --- a/packages/language/src/zmodel-linker.ts +++ b/packages/language/src/zmodel-linker.ts @@ -55,8 +55,10 @@ import { } from './ast'; import { getAllLoadedAndReachableDataModelsAndTypeDefs, + getAuthDecl, getContainingDataModel, getModelFieldsWithBases, + isAuthInvocation, isFutureExpr, isMemberContainer, mapBuiltinTypeToExpressionType, @@ -360,23 +362,20 @@ export class ZModelLinker extends DefaultLinker { // eslint-disable-next-line @typescript-eslint/ban-types const funcDecl = node.function.ref as FunctionDecl; - // TODO: revisit this - // if (isAuthInvocation(node)) { - // // auth() function is resolved against all loaded and reachable documents + if (isAuthInvocation(node)) { + // auth() function is resolved against all loaded and reachable documents - // // get all data models from loaded and reachable documents - // const allDecls = getAllLoadedAndReachableDataModelsAndTypeDefs( - // this.langiumDocuments(), - // AstUtils.getContainerOfType(node, isDataModel) - // ); - - // const authDecl = getAuthDecl(allDecls); - // if (authDecl) { - // node.$resolvedType = { decl: authDecl, nullable: true }; - // } - // } else + // get all data models from loaded and reachable documents + const allDecls = getAllLoadedAndReachableDataModelsAndTypeDefs( + this.langiumDocuments(), + AstUtils.getContainerOfType(node, isDataModel) + ); - if (isFutureExpr(node)) { + const authDecl = getAuthDecl(allDecls); + if (authDecl) { + node.$resolvedType = { decl: authDecl, nullable: true }; + } + } else if (isFutureExpr(node)) { // future() function is resolved to current model node.$resolvedType = { decl: getContainingDataModel(node) }; } else { @@ -413,13 +412,11 @@ export class ZModelLinker extends DefaultLinker { // member access is resolved only in the context of the operand type if (node.member.ref) { this.resolveToDeclaredType(node, node.member.ref.type); - - // TODO: revisit this - // if (node.$resolvedType && isAuthInvocation(node.operand)) { - // // member access on auth() function is nullable - // // because user may not have provided all fields - // node.$resolvedType.nullable = true; - // } + if (node.$resolvedType && isAuthInvocation(node.operand)) { + // member access on auth() function is nullable + // because user may not have provided all fields + node.$resolvedType.nullable = true; + } } } } diff --git a/packages/language/src/zmodel-scope.ts b/packages/language/src/zmodel-scope.ts index 977a6f8d..7e47d728 100644 --- a/packages/language/src/zmodel-scope.ts +++ b/packages/language/src/zmodel-scope.ts @@ -33,7 +33,9 @@ import { import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; import { getAllLoadedAndReachableDataModelsAndTypeDefs, + getAuthDecl, getModelFieldsWithBases, + isAuthInvocation, isCollectionPredicate, isFutureInvocation, resolveImportUri, @@ -213,11 +215,10 @@ export class ZModelScopeProvider extends DefaultScopeProvider { .when(isInvocationExpr, (operand) => { // deal with member access from `auth()` and `future() - // TODO: generalize it - // if (isAuthInvocation(operand)) { - // // resolve to `User` or `@@auth` decl - // return this.createScopeForAuth(node, globalScope); - // } + if (isAuthInvocation(operand)) { + // resolve to `User` or `@@auth` decl + return this.createScopeForAuth(node, globalScope); + } if (isFutureInvocation(operand)) { // resolve `future()` to the containing model @@ -244,51 +245,48 @@ export class ZModelScopeProvider extends DefaultScopeProvider { // const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); const allowTypeDefScope = false; - return ( - match(collection) - .when(isReferenceExpr, (expr) => { - // collection is a reference - model or typedef field - const ref = expr.target.ref; - if (isDataModelField(ref) || isTypeDefField(ref)) { - return this.createScopeForContainer( - ref.type.reference?.ref, - globalScope, - allowTypeDefScope - ); - } - return EMPTY_SCOPE; - }) - .when(isMemberAccessExpr, (expr) => { - // collection is a member access, it can only be resolved to a model or typedef field - const ref = expr.member.ref; - if (isDataModelField(ref) || isTypeDefField(ref)) { - return this.createScopeForContainer( - ref.type.reference?.ref, - globalScope, - allowTypeDefScope - ); - } + return match(collection) + .when(isReferenceExpr, (expr) => { + // collection is a reference - model or typedef field + const ref = expr.target.ref; + if (isDataModelField(ref) || isTypeDefField(ref)) { + return this.createScopeForContainer( + ref.type.reference?.ref, + globalScope, + allowTypeDefScope + ); + } + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (expr) => { + // collection is a member access, it can only be resolved to a model or typedef field + const ref = expr.member.ref; + if (isDataModelField(ref) || isTypeDefField(ref)) { + return this.createScopeForContainer( + ref.type.reference?.ref, + globalScope, + allowTypeDefScope + ); + } + return EMPTY_SCOPE; + }) + .when(isInvocationExpr, (expr) => { + const returnTypeDecl = + expr.function.ref?.returnType.reference?.ref; + if (isDataModel(returnTypeDecl)) { + return this.createScopeForContainer( + returnTypeDecl, + globalScope, + allowTypeDefScope + ); + } else { return EMPTY_SCOPE; - }) - .when(isInvocationExpr, (expr) => { - const returnTypeDecl = - expr.function.ref?.returnType.reference?.ref; - if (isDataModel(returnTypeDecl)) { - return this.createScopeForContainer( - returnTypeDecl, - globalScope, - allowTypeDefScope - ); - } else { - return EMPTY_SCOPE; - } - }) - // TODO: generalize it - // .when(isAuthInvocation, (expr) => { - // return this.createScopeForAuth(expr, globalScope); - // }) - .otherwise(() => EMPTY_SCOPE) - ); + } + }) + .when(isAuthInvocation, (expr) => { + return this.createScopeForAuth(expr, globalScope); + }) + .otherwise(() => EMPTY_SCOPE); } private createScopeForContainingModel(node: AstNode, globalScope: Scope) { @@ -317,21 +315,20 @@ export class ZModelScopeProvider extends DefaultScopeProvider { } } - // TODO: revisit this - // private createScopeForAuth(node: AstNode, globalScope: Scope) { - // // get all data models and type defs from loaded and reachable documents - // const decls = getAllLoadedAndReachableDataModelsAndTypeDefs( - // this.services.shared.workspace.LangiumDocuments, - // AstUtils.getContainerOfType(node, isDataModel) - // ); + private createScopeForAuth(node: AstNode, globalScope: Scope) { + // get all data models and type defs from loaded and reachable documents + const decls = getAllLoadedAndReachableDataModelsAndTypeDefs( + this.services.shared.workspace.LangiumDocuments, + AstUtils.getContainerOfType(node, isDataModel) + ); - // const authDecl = getAuthDecl(decls); - // if (authDecl) { - // return this.createScopeForContainer(authDecl, globalScope, true); - // } else { - // return EMPTY_SCOPE; - // } - // } + const authDecl = getAuthDecl(decls); + if (authDecl) { + return this.createScopeForContainer(authDecl, globalScope, true); + } else { + return EMPTY_SCOPE; + } + } } function getCollectionPredicateContext(node: AstNode) { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index e09bfcfb..36da9608 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -74,7 +74,6 @@ "uuid": "^11.0.5" }, "peerDependencies": { - "@zenstackhq/language": "workspace:*", "better-sqlite3": "^11.8.1", "pg": "^8.13.1", "zod": "^3.0.0" @@ -85,15 +84,14 @@ }, "pg": { "optional": true - }, - "@zenstackhq/language": { - "optional": true } }, "devDependencies": { "@types/better-sqlite3": "^7.0.0", "@types/pg": "^8.0.0", "@types/tmp": "^0.2.6", - "tmp": "^0.2.3" + "tmp": "^0.2.3", + "@zenstackhq/language": "workspace:*", + "@zenstackhq/testtools": "workspace:*" } } diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index 8f4d4bc6..b94703fc 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -30,6 +30,7 @@ import type { RuntimePlugin } from './plugin'; import { createDeferredPromise } from './promise'; import type { ToKysely } from './query-builder'; import { ResultProcessor } from './result-processor'; +import type { AuthType } from '../schema/schema'; /** * Creates a new ZenStack client instance. @@ -47,6 +48,7 @@ export class ClientImpl { public readonly $options: ClientOptions; public readonly $schema: Schema; readonly kyselyProps: KyselyProps; + private auth: AuthType | undefined; constructor( private readonly schema: Schema, @@ -190,6 +192,28 @@ export class ClientImpl { } as ClientOptions; return new ClientImpl(this.schema, newOptions, this); } + + $unuseAll() { + const newOptions = { + ...this.options, + plugins: [] as RuntimePlugin[], + } as ClientOptions; + return new ClientImpl(this.schema, newOptions, this); + } + + $setAuth(auth: AuthType) { + const newClient = new ClientImpl( + this.schema, + this.$options, + this + ); + newClient.auth = auth; + return newClient; + } + + get $auth() { + return this.auth; + } } function createClientProxy( diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 873bf1d6..4ccfd989 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -1,5 +1,6 @@ import type { Decimal } from 'decimal.js-light'; import { + type AuthType, type GetModels, type ProcedureDef, type SchemaDef, @@ -21,6 +22,16 @@ export type ClientContract = { */ readonly $options: ClientOptions; + /** + * The current user identity. + */ + get $auth(): AuthType | undefined; + + /** + * Sets the current user identity. + */ + $setAuth(auth: AuthType | undefined): ClientContract; + /** * The Kysely query builder instance. */ @@ -38,6 +49,11 @@ export type ClientContract = { */ $use(plugin: RuntimePlugin): ClientContract; + /** + * Returns a new client with all plugins removed. + */ + $unuseAll(): ClientContract; + /** * Disconnects the underlying Kysely instance from the database. */ diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 561b9737..d36bfead 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -546,7 +546,10 @@ export abstract class BaseCrudDialect { rhs === null ? eb(lhs, 'is', null) : eb(lhs, '=', rhs) ) .with('in', () => { - invariant(Array.isArray(rhs)); + invariant( + Array.isArray(rhs), + 'right hand side must be an array' + ); if (rhs.length === 0) { return this.false(eb); } else { @@ -554,7 +557,10 @@ export abstract class BaseCrudDialect { } }) .with('notIn', () => { - invariant(Array.isArray(rhs)); + invariant( + Array.isArray(rhs), + 'right hand side must be an array' + ); if (rhs.length === 0) { return this.true(eb); } else { diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 1567f0cf..a8b66cb3 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -436,21 +436,22 @@ export abstract class BaseOperationHandler { } const updatedData = this.fillGeneratedValues(modelDef, createFields); + const idFields = getIdFields(this.schema, model); const query = kysely .insertInto(model) .values(updatedData) - .returningAll(); - - let createdEntity: any; - - try { - createdEntity = await query.executeTakeFirst(); - } catch (err) { - const { sql, parameters } = query.compile(); - throw new QueryError( - `Error during create: ${err}, sql: ${sql}, parameters: ${parameters}` - ); - } + .returning(idFields as any); + + const createdEntity = await query.executeTakeFirst(); + + // try { + // createdEntity = await query.executeTakeFirst(); + // } catch (err) { + // const { sql, parameters } = query.compile(); + // throw new QueryError( + // `Error during create: ${err}, sql: ${sql}, parameters: ${parameters}` + // ); + // } if (Object.keys(postCreateRelations).length > 0) { // process nested creates that need to happen after the current entity is created diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index 33db9969..d9696e9e 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -1,7 +1,8 @@ import { match } from 'ts-pattern'; import type { GetModels, SchemaDef } from '../../../schema'; import type { CreateArgs, CreateManyArgs } from '../../crud-types'; -import { getIdValues, requireField } from '../../query-utils'; +import { RejectedByPolicyError } from '../../errors'; +import { getIdValues } from '../../query-utils'; import { BaseOperationHandler } from './base'; export class CreateOperationHandler< @@ -26,47 +27,36 @@ export class CreateOperationHandler< } private async runCreate(args: CreateArgs>) { - const hasRelationCreate = Object.keys(args.data).some( - (f) => !!requireField(this.schema, this.model, f).relation - ); - - const returnRelations = this.needReturnRelations(this.model, args); - let result: any; - if (hasRelationCreate || returnRelations) { - // employ a transaction - try { - result = await this.kysely - .transaction() - .setIsolationLevel('repeatable read') - .execute(async (tx) => { - const createResult = await this.create( - tx, + try { + result = await this.kysely + .transaction() + .setIsolationLevel('repeatable read') + .execute(async (tx) => { + const createResult = await this.create( + tx, + this.model, + args.data + ); + return this.readUnique(tx, this.model, { + select: args.select, + include: args.include, + where: getIdValues( + this.schema, this.model, - args.data - ); - return this.readUnique(tx, this.model, { - select: args.select, - include: args.include, - where: getIdValues( - this.schema, - this.model, - createResult - ), - }); + createResult + ), }); - } catch (err) { - // console.error(err); - throw err; - } - } else { - // simple create - const createResult = await this.create( - this.kysely, - this.model, - args.data + }); + } catch (err) { + // console.error(err); + throw err; + } + + if (!result) { + throw new RejectedByPolicyError( + `result is not allowed to be read back` ); - result = this.trimResult(createResult, args); } return result; diff --git a/packages/runtime/src/client/errors.ts b/packages/runtime/src/client/errors.ts index f58a32f8..91731ba0 100644 --- a/packages/runtime/src/client/errors.ts +++ b/packages/runtime/src/client/errors.ts @@ -15,3 +15,9 @@ export class NotFoundError extends Error { super(`Entity not found for model "${model}"`); } } + +export class RejectedByPolicyError extends Error { + constructor(reason?: string) { + super(reason ?? `Operation rejected by policy`); + } +} diff --git a/packages/runtime/src/client/executor/name-mapper.ts b/packages/runtime/src/client/executor/name-mapper.ts index c26abe9e..47a658ac 100644 --- a/packages/runtime/src/client/executor/name-mapper.ts +++ b/packages/runtime/src/client/executor/name-mapper.ts @@ -90,24 +90,24 @@ export class QueryNameMapper extends OperationNodeTransformer { } protected override transformSelectQuery(node: SelectQueryNode) { - this.currentModel = undefined; - if ( - node.from?.froms && - node.from.froms.length === 1 && - node.from.froms[0] - ) { - const from = node.from.froms[0]; - if (TableNode.is(from)) { - this.currentModel = from.table.identifier.name; - } else if (AliasNode.is(from) && TableNode.is(from.node)) { - this.currentModel = from.node.table.identifier.name; - } - } else { + if (!node.from?.froms || node.from.froms.length === 0) { + return super.transformSelectQuery(node); + } + + if (node.from.froms.length > 1) { throw new InternalError( `SelectQueryNode must have a single table in from clause` ); } + this.currentModel = undefined; + const from = node.from.froms[0]!; + if (TableNode.is(from)) { + this.currentModel = from.table.identifier.name; + } else if (AliasNode.is(from) && TableNode.is(from.node)) { + this.currentModel = from.node.table.identifier.name; + } + const selections = node.selections ? this.transformSelections(node.selections, node) : node.selections; diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index a85d1514..ea7f1eac 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -113,6 +113,7 @@ export class ZenStackQueryExecutor< // trim the result to the original query node if (oldQueryNode !== queryNode) { + // TODO: trim the result to the original query node } return result; @@ -158,11 +159,12 @@ export class ZenStackQueryExecutor< return proceed(queryNode); } - private async proceedQuery(query: RootOperationNode, queryId: QueryId) { + private proceedQuery(query: RootOperationNode, queryId: QueryId) { // run built-in transformers const finalQuery = this.nameMapper.transformNode(query); const compiled = this.compileQuery(finalQuery); + return this.driver.txConnection ? super .withConnectionProvider( diff --git a/packages/runtime/src/client/helpers/schema-db-pusher.ts b/packages/runtime/src/client/helpers/schema-db-pusher.ts index b84bf5bc..b1240ae4 100644 --- a/packages/runtime/src/client/helpers/schema-db-pusher.ts +++ b/packages/runtime/src/client/helpers/schema-db-pusher.ts @@ -101,7 +101,7 @@ export class SchemaDbPusher { modelDef: ModelDef ) { for (const [key, value] of Object.entries(modelDef.uniqueFields)) { - invariant(typeof value === 'object'); + invariant(typeof value === 'object', 'expecting an object'); if ('type' in value) { // uni-field constraint, check if it's already defined at field level const fieldDef = modelDef.fields[key]!; @@ -193,7 +193,7 @@ export class SchemaDbPusher { fieldName: string, fieldDef: FieldDef ) { - invariant(fieldDef.relation); + invariant(fieldDef.relation, 'field must be a relation'); if (!fieldDef.relation.fields || !fieldDef.relation.references) { // not fk side diff --git a/packages/runtime/src/client/options.ts b/packages/runtime/src/client/options.ts index ba9a7f62..561c779e 100644 --- a/packages/runtime/src/client/options.ts +++ b/packages/runtime/src/client/options.ts @@ -62,7 +62,11 @@ export type ComputedFieldsOptions = { }; export type HasComputedFields = - keyof ComputedFieldsOptions extends never ? false : true; + string extends GetModels + ? false + : keyof ComputedFieldsOptions extends never + ? false + : true; export type ProceduresOptions = Schema extends { procedures: Record; diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 7ccc26d4..e2f6f3a1 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -98,15 +98,17 @@ export type OnKyselyQueryTransactionCallback = ( proceed: ProceedKyselyQueryFunction ) => Promise>; +export type OnKyselyQueryTransaction = ( + callback: OnKyselyQueryTransactionCallback +) => Promise>; + export type OnKyselyQueryArgs = { kysely: ToKysely; schema: SchemaDef; client: ClientContract; query: RootOperationNode; proceed: ProceedKyselyQueryFunction; - transaction: ( - callback: OnKyselyQueryTransactionCallback - ) => Promise>; + transaction: OnKyselyQueryTransaction; }; export type ProceedKyselyQueryFunction = ( diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index b512eaab..270cbd24 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -41,7 +41,7 @@ export function getIdFields( model: GetModels ) { const modelDef = requireModel(schema, model); - return modelDef?.idFields; + return modelDef?.idFields as GetModels[]; } export function requireIdFields(schema: SchemaDef, model: string) { diff --git a/packages/runtime/src/plugins/policy/column-collector.ts b/packages/runtime/src/plugins/policy/column-collector.ts new file mode 100644 index 00000000..8b2a9f77 --- /dev/null +++ b/packages/runtime/src/plugins/policy/column-collector.ts @@ -0,0 +1,21 @@ +import type { ColumnNode, OperationNode } from 'kysely'; +import { DefaultOperationNodeVisitor } from '../../utils/default-operation-node-visitor'; + +/** + * Collects all column names from a query. + */ +export class ColumnCollector extends DefaultOperationNodeVisitor { + private columns: string[] = []; + + collect(node: OperationNode) { + this.columns = []; + this.visitNode(node); + return this.columns; + } + + protected override visitColumn(node: ColumnNode): void { + if (!this.columns.includes(node.column.name)) { + this.columns.push(node.column.name); + } + } +} diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 0f8c87b7..164bb4c6 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -1,39 +1,45 @@ import { - AndNode, + AliasNode, BinaryOperationNode, ColumnNode, + FromNode, + FunctionNode, + IdentifierNode, OperatorNode, - OrNode, ReferenceNode, + SelectionNode, + SelectQueryNode, + TableNode, ValueNode, + WhereNode, type OperationNode, } from 'kysely'; import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; +import type { FieldDef } from '../../../dist/schema'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base'; -import { QueryError } from '../../client/errors'; +import { InternalError, QueryError } from '../../client/errors'; import type { ClientOptions } from '../../client/options'; import { - getIdFields, getRelationForeignKeyFieldPairs, requireField, } from '../../client/query-utils'; -import type { CallExpression, SchemaDef } from '../../schema'; +import type { CallExpression, FieldExpression, SchemaDef } from '../../schema'; import { Expression, type BinaryExpression, type BinaryOperator, - type FieldReferenceExpression, type LiteralExpression, + type MemberExpression, type UnaryExpression, } from '../../schema/expression'; import type { BuiltinType, GetModels } from '../../schema/schema'; -import type { PolicyOptions } from './options'; -import type { SchemaPolicy } from './types'; +import { conjunction, disjunction, logicalNot, trueNode } from './utils'; export type ExpressionTransformerContext = { model: GetModels; + thisEntity?: Record; }; // a registry of expression handlers marked with @expr @@ -54,19 +60,23 @@ function expr(kind: Expression['kind']) { } export class ExpressionTransformer { - private readonly options: PolicyOptions; private readonly dialect: BaseCrudDialect; - private readonly schemaPolicy: SchemaPolicy; constructor( private readonly schema: Schema, private readonly clientOptions: ClientOptions, - options: PolicyOptions + private readonly auth: unknown | undefined ) { - this.options = options; this.dialect = getCrudDialect(this.schema, this.clientOptions); - invariant(this.schema.plugins['policy']); - this.schemaPolicy = this.schema.plugins['policy'] as SchemaPolicy; + } + + get authType() { + if (!this.schema.authType) { + throw new InternalError( + 'Schema does not have an "authType" specified' + ); + } + return this.schema.authType; } transform( @@ -77,122 +87,282 @@ export class ExpressionTransformer { if (!handler) { throw new Error(`Unsupported expression kind: ${expression.kind}`); } - return handler.value.call(this, expression, context); } @expr('literal') + // @ts-ignore private _literal(expr: LiteralExpression) { - return ValueNode.create(expr.value); + return this.transformValue( + expr.value, + typeof expr.value === 'string' + ? 'String' + : typeof expr.value === 'boolean' + ? 'Boolean' + : 'Int' + ); } - @expr('ref') - private _ref(expr: FieldReferenceExpression) { - return ReferenceNode.create(ColumnNode.create(expr.field)); + @expr('field') + // @ts-ignore + private _field( + expr: FieldExpression, + context: ExpressionTransformerContext + ) { + const fieldDef = requireField(this.schema, context.model, expr.field); + if (!fieldDef.relation) { + if (context.thisEntity) { + return context.thisEntity[expr.field]; + } else { + return ColumnNode.create(expr.field); + } + } else { + return this._relation( + context.model, + expr.field, + fieldDef.type, + context + ); + } } @expr('null') + // @ts-ignore private _null() { return ValueNode.create(null); } @expr('binary') + // @ts-ignore private _binary( expr: BinaryExpression, context: ExpressionTransformerContext ) { + if (expr.op === '&&') { + return conjunction(this.dialect, [ + this.transform(expr.left, context), + this.transform(expr.right, context), + ]); + } else if (expr.op === '||') { + return disjunction(this.dialect, [ + this.transform(expr.left, context), + this.transform(expr.right, context), + ]); + } + if (this.isAuthCall(expr.left) || this.isAuthCall(expr.right)) { return this.transformAuthBinary(expr); } + const op = expr.op; + + if (op === '?' || op === '!' || op === '^') { + return this.transformCollectionPredicate(expr, context); + } + + const left = this.transform(expr.left, context); + const right = this.transform(expr.right, context); + + if (this.isNullNode(right)) { + invariant( + expr.op === '==' || expr.op === '!=', + 'Comparison with null must be "==" or "!="' + ); + return expr.op === '==' + ? BinaryOperationNode.create( + left, + OperatorNode.create('is'), + right + ) + : BinaryOperationNode.create( + left, + OperatorNode.create('is not'), + right + ); + } else if (this.isNullNode(left)) { + invariant( + expr.op === '==' || expr.op === '!=', + 'Comparison with null must be "==" or "!="' + ); + return expr.op === '==' + ? BinaryOperationNode.create( + right, + OperatorNode.create('is'), + left + ) + : BinaryOperationNode.create( + right, + OperatorNode.create('is not'), + left + ); + } + return BinaryOperationNode.create( - this.transform(expr.left, context), - this.transformOperator(expr.op), - this.transform(expr.right, context) + left, + this.transformOperator(op), + right ); } - private transformAuthBinary(expr: BinaryExpression) { - if (expr.op !== '==' && expr.op !== '!=') { - throw new Error(`Unsupported operator for auth call: ${expr.op}`); - } - let other: Expression; - if (this.isAuthCall(expr.left)) { - other = expr.right; - } else { - other = expr.left; - } + private isNullNode(node: OperationNode) { + return ValueNode.is(node) && node.value === null; + } - if (Expression.isNull(other)) { - return this.transformValue( - expr.op === '==' ? !this.options.auth : !!this.options.auth, - 'Boolean' - ); - } else if (Expression.isThis(other)) { - const idFields = getIdFields( + private transformCollectionPredicate( + expr: BinaryExpression, + context: ExpressionTransformerContext + ) { + invariant( + expr.op === '?' || expr.op === '!' || expr.op === '^', + 'expected "?" or "!" or "^" operator' + ); + + const left = this.transform(expr.left, context); + + invariant( + Expression.isFieldExpr(expr.left) || + Expression.isMemberExpr(expr.left), + 'left operand must be field or member access' + ); + + let newContextModel: string; + if (Expression.isFieldExpr(expr.left)) { + const fieldDef = requireField( this.schema, - this.schemaPolicy.authModel - ); - return this.buildAuthFieldComparison( - idFields.map((f) => ({ authField: f, tableField: f })), - expr.op + context.model, + expr.left.field ); - } else if (Expression.isRef(other)) { - const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs( + newContextModel = fieldDef.type; + } else { + invariant(Expression.isFieldExpr(expr.left.receiver)); + const fieldDef = requireField( this.schema, - other.model, - other.field + context.model, + expr.left.receiver.field ); - - if (ownedByModel) { - return this.buildAuthFieldComparison( - keyPairs.map( - ({ fk, pk }) => ({ authField: pk, tableField: fk }), - expr.op - ), - expr.op + newContextModel = fieldDef.type; + for (const member of expr.left.members) { + const memberDef = requireField( + this.schema, + newContextModel, + member ); - } else { - throw new Error('Todo: join relation'); + newContextModel = memberDef.type; } - } else { - throw new Error('Unsupported expression'); } - } - private buildAuthFieldComparison( - fields: Array<{ authField: string; tableField: string }>, - op: '==' | '!=' - ) { - if (op === '==') { - return this.buildAndNode( - fields.map(({ authField, tableField }) => + let filter = this.transform(expr.right, { + ...context, + model: newContextModel as GetModels, + thisEntity: undefined, + }); + + if (expr.op === '!') { + filter = logicalNot(filter); + } + + invariant( + SelectQueryNode.is(left), + 'expected left operand to be select query' + ); + + const count = FunctionNode.create('count', [ValueNode.create(1)]); + const finalSelectQuery = this.updateInnerMostSelectQuery( + left, + filter, + match(expr.op) + .with('?', () => BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(tableField)), - this.transformOperator(op), - this.transformAuthFieldSelect(authField) + count, + OperatorNode.create('>'), + ValueNode.create(0) ) ) - ); - } else { - return this.buildOrNode( - fields.map(({ authField, tableField }) => + .with('!', () => + BinaryOperationNode.create( + count, + OperatorNode.create('='), + ValueNode.create(0) + ) + ) + .with('^', () => BinaryOperationNode.create( - ReferenceNode.create(ColumnNode.create(tableField)), - this.transformOperator('!='), - this.transformAuthFieldSelect(authField) + count, + OperatorNode.create('='), + ValueNode.create(0) ) ) + .exhaustive() + ); + + return finalSelectQuery; + } + + private updateInnerMostSelectQuery( + node: SelectQueryNode, + where: OperationNode, + selection: OperationNode + ): SelectQueryNode { + if (!node.selections || node.selections.length === 0) { + return { + ...node, + selections: [ + SelectionNode.create( + AliasNode.create(selection, IdentifierNode.create('$t')) + ), + ], + where: WhereNode.create( + node.where + ? conjunction(this.dialect, [node.where.where, where]) + : where + ), + }; + } else { + invariant( + node.selections.length === 1, + 'expected exactly one selection' + ); + const currSelection = node.selections[0]!; + invariant( + AliasNode.is(currSelection.selection), + 'expected alias node' + ); + const alias = currSelection.selection.alias; + const inner = currSelection.selection.node; + invariant(SelectQueryNode.is(inner), 'expected select query node'); + const newInner = this.updateInnerMostSelectQuery( + inner, + where, + selection ); + return { + ...node, + selections: [ + SelectionNode.create(AliasNode.create(newInner, alias)), + ], + }; } } - private transformAuthFieldSelect(field: string): OperationNode { - return this.transformValue( - this.options.auth?.[field as keyof typeof this.options.auth], - requireField(this.schema, this.schemaPolicy.authModel!, field) - .type as BuiltinType - ); + private transformAuthBinary(expr: BinaryExpression) { + if (expr.op !== '==' && expr.op !== '!=') { + throw new Error(`Unsupported operator for auth call: ${expr.op}`); + } + let other: Expression; + if (this.isAuthCall(expr.left)) { + other = expr.right; + } else { + other = expr.left; + } + + if (Expression.isNull(other)) { + return this.transformValue( + expr.op === '==' ? !this.auth : !!this.auth, + 'Boolean' + ); + } else { + throw new Error('Unsupported expression'); + } } private transformValue(value: unknown, type: BuiltinType): OperationNode { @@ -201,44 +371,22 @@ export class ExpressionTransformer { ); } - private buildAndNode(nodes: OperationNode[]) { - if (nodes.length === 0) { - throw new Error('Expected at least one node'); - } - if (nodes.length === 1) { - return nodes[0]; - } - const initial = nodes.shift()!; - return nodes.reduce( - (prev, curr) => AndNode.create(prev, curr), - initial - ); - } - - private buildOrNode(nodes: OperationNode[]) { - if (nodes.length === 0) { - throw new Error('Expected at least one node'); - } - if (nodes.length === 1) { - return nodes[0]; - } - const initial = nodes.shift()!; - return nodes.reduce((prev, curr) => OrNode.create(prev, curr), initial); - } - @expr('unary') + // @ts-ignore private _unary( expr: UnaryExpression, context: ExpressionTransformerContext ) { + // only '!' operator for now + invariant(expr.op === '!', 'only "!" operator is supported'); return BinaryOperationNode.create( this.transform(expr.operand, context), this.transformOperator('!='), - ValueNode.create(true) + trueNode(this.dialect) ); } - private transformOperator(op: BinaryOperator) { + private transformOperator(op: Exclude) { const mappedOp = match(op) .with('==', () => '=' as const) .otherwise(() => op); @@ -246,14 +394,232 @@ export class ExpressionTransformer { } @expr('call') - private _call( - expr: CallExpression, - context: ExpressionTransformerContext - ) { + // @ts-ignore + private _call(expr: CallExpression) { throw new QueryError(`Unknown function: ${expr.function}`); } private isAuthCall(value: unknown): value is CallExpression { return Expression.isCall(value) && value.function === 'auth'; } + + @expr('member') + // @ts-ignore + private _member( + expr: MemberExpression, + context: ExpressionTransformerContext + ) { + // auth() member access + if (this.isAuthCall(expr.receiver)) { + return this.valueMemberAccess(this.auth, expr, this.authType); + } + + invariant( + Expression.isFieldExpr(expr.receiver), + 'expect receiver to be field expression' + ); + + const receiver = this.transform(expr.receiver, context); + invariant( + SelectQueryNode.is(receiver), + 'expected receiver to be select query' + ); + + // relation member access + const receiverField = requireField( + this.schema, + context.model, + expr.receiver.field + ); + + // traverse forward to collect member types + const memberFields: { fromModel: string; fieldDef: FieldDef }[] = []; + let currType = receiverField.type; + for (const member of expr.members) { + const fieldDef = requireField(this.schema, currType, member); + memberFields.push({ fieldDef, fromModel: currType }); + currType = fieldDef.type; + } + + let currNode: SelectQueryNode | ColumnNode | undefined = undefined; + const innerContext = { ...context, thisEntity: undefined }; + + for (let i = expr.members.length - 1; i >= 0; i--) { + const member = expr.members[i]!; + const { fieldDef, fromModel } = memberFields[i]!; + if (fieldDef.relation) { + const relation = this._relation( + fromModel, + member, + fieldDef.type, + innerContext + ); + if (currNode) { + invariant( + SelectQueryNode.is(currNode), + 'expected select query node' + ); + currNode = { + ...(currNode as SelectQueryNode), + selections: [ + SelectionNode.create( + AliasNode.create( + relation, + IdentifierNode.create(member) + ) + ), + ], + }; + } else { + currNode = relation; + } + } else { + invariant( + i === expr.members.length - 1, + 'plain field access must be the last segment' + ); + if (currNode) { + invariant( + SelectQueryNode.is(currNode), + 'expected select query node' + ); + currNode = { + ...(currNode as SelectQueryNode), + selections: [ + SelectionNode.create(ColumnNode.create(member)), + ], + }; + } else { + currNode = ColumnNode.create(member); + } + } + } + + return { + ...receiver, + selections: [ + SelectionNode.create( + AliasNode.create(currNode!, IdentifierNode.create('$t')) + ), + ], + }; + } + + private valueMemberAccess( + receiver: any, + expr: MemberExpression, + receiverType: string + ) { + if (!receiver) { + return ValueNode.create(null); + } + + if (expr.members.length !== 1) { + throw new Error(`Only single member access is supported`); + } + + const field = expr.members[0]!; + const fieldDef = requireField(this.schema, receiverType, field); + const fieldValue = receiver[field] ?? null; + return this.transformValue(fieldValue, fieldDef.type as BuiltinType); + } + + // @expr('relation') + // @ts-ignore + private _relation( + fromModel: string, + field: string, + relationModel: string, + context: ExpressionTransformerContext + ): SelectQueryNode { + const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs( + this.schema, + fromModel, + field + ); + + if (context.thisEntity) { + let condition: OperationNode; + if (ownedByModel) { + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create( + ColumnNode.create(pk), + TableNode.create(relationModel) + ), + OperatorNode.create('='), + context.thisEntity![fk]! + ) + ) + ); + } else { + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create( + ColumnNode.create(fk), + TableNode.create(relationModel) + ), + OperatorNode.create('='), + context.thisEntity![pk]! + ) + ) + ); + } + + return { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(relationModel)]), + where: WhereNode.create(condition), + }; + } else { + let condition: OperationNode; + if (ownedByModel) { + // `fromModel` owns the fk + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create( + ColumnNode.create(fk), + TableNode.create(fromModel) + ), + OperatorNode.create('='), + ReferenceNode.create( + ColumnNode.create(pk), + TableNode.create(relationModel) + ) + ) + ) + ); + } else { + // `relationModel` owns the fk + condition = conjunction( + this.dialect, + keyPairs.map(({ fk, pk }) => + BinaryOperationNode.create( + ReferenceNode.create( + ColumnNode.create(pk), + TableNode.create(fromModel) + ), + OperatorNode.create('='), + ReferenceNode.create( + ColumnNode.create(fk), + TableNode.create(relationModel) + ) + ) + ) + ); + } + + return { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(relationModel)]), + where: WhereNode.create(condition), + }; + } + } } diff --git a/packages/runtime/src/plugins/policy/generator.ts b/packages/runtime/src/plugins/policy/generator.ts deleted file mode 100644 index 2019af75..00000000 --- a/packages/runtime/src/plugins/policy/generator.ts +++ /dev/null @@ -1,84 +0,0 @@ -import { isDataModel, type Model } from '@zenstackhq/language/ast'; -import fs from 'node:fs'; -import ts from 'typescript'; -import type { CliGenerator } from '../../client/plugin'; - -export const generate: CliGenerator = (context) => { - const source = fs.readFileSync(context.tsSchemaFile, 'utf-8'); - const sourceFile = ts.createSourceFile( - context.tsSchemaFile, - source, - ts.ScriptTarget.Latest, - true - ); - - const transformer: ts.TransformerFactory = ( - ctx: ts.TransformationContext - ) => { - return (rootNode: ts.SourceFile) => { - function generateForPlugin(node: ts.PropertyAssignment) { - const initializer = - node.initializer as ts.ObjectLiteralExpression; - return ts.factory.updatePropertyAssignment( - node, - node.name, - ts.factory.updateObjectLiteralExpression(initializer, [ - ...initializer.properties, - makePolicyProperty(), - ]) - ); - } - - function makePolicyProperty(): ts.PropertyAssignment { - return ts.factory.createPropertyAssignment( - 'policy', - ts.factory.createObjectLiteralExpression([ - ts.factory.createPropertyAssignment( - 'authModel', - ts.factory.createStringLiteral( - getAuthModelName(context.model) - ) - ), - ]) - ); - } - - const visitor: ts.Visitor = (node) => { - if ( - ts.isPropertyAssignment(node) && - node.name.getText() === 'plugins' && - ts.isObjectLiteralExpression(node.initializer) - ) { - return generateForPlugin(node); - } - return ts.visitEachChild(node, visitor, ctx); - }; - - return ts.visitNode(rootNode, visitor) as ts.SourceFile; - }; - }; - - const result = ts.transform(sourceFile, [transformer]); - const printer = ts.createPrinter(); - const transformedSource = printer.printFile(result.transformed[0]!); - fs.writeFileSync(context.tsSchemaFile, transformedSource); -}; - -function getAuthModelName(model: Model) { - let found = model.declarations.find( - (d) => - isDataModel(d) && - d.attributes.some((attr) => attr.decl.$refText === '@@auth') - ); - if (!found) { - found = model.declarations.find( - (d) => isDataModel(d) && d.name === 'User' - ); - } - if (!found) { - throw new Error( - `@@auth model not found, please add @@auth to your model or create a User model` - ); - } - return found.name; -} diff --git a/packages/runtime/src/plugins/policy/index.ts b/packages/runtime/src/plugins/policy/index.ts index caf4d89f..1110b645 100644 --- a/packages/runtime/src/plugins/policy/index.ts +++ b/packages/runtime/src/plugins/policy/index.ts @@ -1,2 +1 @@ export * from './plugin'; -export { generate as default } from './generator'; diff --git a/packages/runtime/src/plugins/policy/options.ts b/packages/runtime/src/plugins/policy/options.ts deleted file mode 100644 index c3f93957..00000000 --- a/packages/runtime/src/plugins/policy/options.ts +++ /dev/null @@ -1,16 +0,0 @@ -import type { ModelResult } from '../../client'; -import type { GetModels, SchemaDef } from '../../schema'; - -export type Auth = Schema['plugins'] extends { - policy: object; -} - ? Schema['plugins']['policy'] extends { authModel: infer AuthModel } - ? AuthModel extends GetModels - ? Partial> - : never - : never - : never; - -export type PolicyOptions = { - auth?: Auth; -}; diff --git a/packages/runtime/src/plugins/policy/plugin.ts b/packages/runtime/src/plugins/policy/plugin.ts index b929c3af..ed731d87 100644 --- a/packages/runtime/src/plugins/policy/plugin.ts +++ b/packages/runtime/src/plugins/policy/plugin.ts @@ -3,18 +3,11 @@ import { type RuntimePlugin, } from '../../client/plugin'; import type { SchemaDef } from '../../schema'; -import type { Auth, PolicyOptions } from './options'; -import { PolicyTransformer } from './policy-transformer'; +import { PolicyHandler } from './policy-handler'; export class PolicyPlugin implements RuntimePlugin { - private readonly options: PolicyOptions; - - constructor(options?: PolicyOptions) { - this.options = options ?? {}; - } - get id() { return 'policy'; } @@ -27,16 +20,13 @@ export class PolicyPlugin return 'Enforces access policies defined in the schema.'; } - onKyselyQuery({ proceed, query, client }: OnKyselyQueryArgs) { - const transformer = new PolicyTransformer(client, this.options); - const transformedQuery = transformer.transformNode(query); - return proceed(transformedQuery); - } - - setAuth(auth: Auth) { - return new PolicyPlugin({ - ...this.options, - auth, - }); + onKyselyQuery({ + query, + client, + proceed, + transaction, + }: OnKyselyQueryArgs) { + const handler = new PolicyHandler(client); + return handler.handle(query, proceed, transaction); } } diff --git a/packages/runtime/src/plugins/policy/plugin.zmodel b/packages/runtime/src/plugins/policy/plugin.zmodel index f3b916b2..3ff7d9c0 100644 --- a/packages/runtime/src/plugins/policy/plugin.zmodel +++ b/packages/runtime/src/plugins/policy/plugin.zmodel @@ -1,15 +1,3 @@ -/** - * Gets the current login user. - */ -function auth(): Any { -} @@@expressionContext([DefaultValue, AccessPolicy]) - -/** - * Used to specify the model for resolving `auth()` function call in access policies. A Zmodel - * can have at most one model with this attribute. By default, the model named "User" is used. - */ -attribute @@auth() @@@supportTypeDef - /** * Defines an access policy that allows a set of operations when the given condition is true. * diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts new file mode 100644 index 00000000..dea3a831 --- /dev/null +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -0,0 +1,524 @@ +import { + AliasNode, + BinaryOperationNode, + ColumnNode, + DeleteQueryNode, + FromNode, + IdentifierNode, + InsertQueryNode, + OperationNodeTransformer, + OperatorNode, + PrimitiveValueListNode, + ReturningNode, + SelectionNode, + SelectQueryNode, + TableNode, + UpdateQueryNode, + ValueListNode, + ValueNode, + ValuesNode, + WhereNode, + type OperationNode, + type QueryResult, + type RootOperationNode, +} from 'kysely'; +import invariant from 'tiny-invariant'; +import { match } from 'ts-pattern'; +import type { ClientContract } from '../../client'; +import { getCrudDialect } from '../../client/crud/dialects'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import { InternalError, RejectedByPolicyError } from '../../client/errors'; +import type { + OnKyselyQueryTransaction, + ProceedKyselyQueryFunction, +} from '../../client/plugin'; +import { getIdFields, requireModel } from '../../client/query-utils'; +import { Expression, type GetModels, type SchemaDef } from '../../schema'; +import { ColumnCollector } from './column-collector'; +import { ExpressionTransformer } from './expression-transformer'; +import type { Policy, PolicyOperation } from './types'; +import { + buildIsFalse, + conjunction, + disjunction, + falseNode, + getTableName, +} from './utils'; + +export type CrudQueryNode = + | SelectQueryNode + | InsertQueryNode + | UpdateQueryNode + | DeleteQueryNode; + +export type MutationQueryNode = + | InsertQueryNode + | UpdateQueryNode + | DeleteQueryNode; + +export class PolicyHandler< + Schema extends SchemaDef +> extends OperationNodeTransformer { + private readonly dialect: BaseCrudDialect; + + constructor(private readonly client: ClientContract) { + super(); + this.dialect = getCrudDialect( + this.client.$schema, + this.client.$options + ); + } + + get kysely() { + return this.client.$qb; + } + + async handle( + node: RootOperationNode, + proceed: ProceedKyselyQueryFunction, + transaction: OnKyselyQueryTransaction + ) { + if (!this.isCrudQueryNode(node)) { + // non CRUD queries are not allowed + throw new RejectedByPolicyError('non CRUD queries are not allowed'); + } + + let mutationRequiresTransaction = false; + + if (InsertQueryNode.is(node)) { + const constCondition = this.tryGetConstantPolicy( + this.getMutationModel(node), + 'create' + ); + if (constCondition === false) { + throw new RejectedByPolicyError(); + } else if (constCondition === undefined) { + mutationRequiresTransaction = true; + } + } + + if (!this.isMutationQueryNode(node)) { + // transform and proceed read without transaction + return proceed(this.transformNode(node)); + } + + if (!mutationRequiresTransaction && !node.returning) { + // transform and proceed mutation without transaction + return proceed(this.transformNode(node)); + } + + let readBackError = false; + + // transform and post-process in a transaction + const result = await transaction(async (txProceed) => { + if (InsertQueryNode.is(node)) { + await this.enforcePreCreatePolicy(node, txProceed); + } + const transformedNode = this.transformNode(node); + const result = await txProceed(transformedNode); + + if (!InsertQueryNode.is(node) || !this.onlyReturningId(node)) { + const readBackResult = await this.processReadBack( + node, + result, + txProceed + ); + if (readBackResult.rows.length !== result.rows.length) { + readBackError = true; + } + return readBackResult; + } else { + return result; + } + }); + + if (readBackError) { + throw new RejectedByPolicyError( + 'result is not allowed to be read back' + ); + } + + return result; + } + + private onlyReturningId(node: InsertQueryNode) { + if (!node.returning) { + return true; + } + const idFields = getIdFields( + this.client.$schema, + this.getMutationModel(node) + ); + const collector = new ColumnCollector(); + const selectedColumns = collector.collect(node.returning); + return selectedColumns.every((c) => idFields.includes(c)); + } + + private async enforcePreCreatePolicy( + node: InsertQueryNode, + proceed: ProceedKyselyQueryFunction + ) { + if (!node.columns || !node.values) { + return; + } + + const thisEntity: Record = {}; + const values = this.unwrapCreateValues(node.values); + for (let i = 0; i < node.columns?.length; i++) { + thisEntity[node.columns![i]!.column.name] = values[i]!; + } + + const model = this.getMutationModel(node); + const filter = this.buildPolicyFilter(model, 'create', thisEntity); + const preCreateCheck: SelectQueryNode = { + kind: 'SelectQueryNode', + selections: [ + SelectionNode.create( + AliasNode.create( + filter, + IdentifierNode.create('$condition') + ) + ), + ], + }; + const result = await proceed(preCreateCheck); + if (!(result.rows[0] as any)?.$condition) { + throw new RejectedByPolicyError(); + } + } + + private unwrapCreateValues(node: OperationNode): readonly OperationNode[] { + if (ValuesNode.is(node)) { + if (node.values.length === 1 && this.isValueList(node.values[0]!)) { + return this.unwrapCreateValues(node.values[0]!); + } else { + return node.values; + } + } else if (PrimitiveValueListNode.is(node)) { + return node.values.map((v) => ValueNode.create(v)); + } else { + throw new InternalError( + `Unexpected node kind: ${node.kind} for unwrapping create values` + ); + } + } + + private isValueList(node: OperationNode) { + return ValueListNode.is(node) || PrimitiveValueListNode.is(node); + } + + private tryGetConstantPolicy( + model: GetModels, + operation: PolicyOperation + ) { + const policies = this.getModelPolicies(model, operation); + if (!policies.some((p) => p.kind === 'allow')) { + // no allow -> unconditional deny + return false; + } else if ( + // unconditional deny + policies.some( + (p) => p.kind === 'deny' && this.isTrueExpr(p.condition) + ) + ) { + return false; + } else if ( + // unconditional allow + !policies.some((p) => p.kind === 'deny') && + policies.some( + (p) => p.kind === 'allow' && this.isTrueExpr(p.condition) + ) + ) { + return true; + } else { + return undefined; + } + } + + private isTrueExpr(expr: Expression) { + return Expression.isLiteral(expr) && expr.value === true; + } + + private async processReadBack( + node: CrudQueryNode, + result: QueryResult, + proceed: ProceedKyselyQueryFunction + ) { + if ( + InsertQueryNode.is(node) || + UpdateQueryNode.is(node) || + DeleteQueryNode.is(node) + ) { + if (node.returning) { + // do a select (with policy) in place of returning + const table = this.getMutationModel(node); + if (!table) { + throw new InternalError( + `Unable to get table name for query node: ${node}` + ); + } + + const idConditions = this.buildIdConditions(table, result.rows); + const policyFilter = this.buildPolicyFilter(table, 'read'); + + const select: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(table)]), + where: WhereNode.create( + conjunction(this.dialect, [idConditions, policyFilter]) + ), + selections: node.returning.selections, + }; + const selectResult = await proceed(select); + return selectResult; + } else { + return result; + } + } + + return result; + } + + private buildIdConditions(table: string, rows: any[]): OperationNode { + const idFields = getIdFields(this.client.$schema, table); + return disjunction( + this.dialect, + rows.map((row) => + conjunction( + this.dialect, + idFields.map((field) => + BinaryOperationNode.create( + ColumnNode.create(field), + OperatorNode.create('='), + ValueNode.create(row[field]) + ) + ) + ) + ) + ); + } + + private getMutationModel( + node: InsertQueryNode | UpdateQueryNode | DeleteQueryNode + ) { + return match(node) + .when( + InsertQueryNode.is, + (node) => getTableName(node.into) as GetModels + ) + .when( + UpdateQueryNode.is, + (node) => getTableName(node.table) as GetModels + ) + .when( + DeleteQueryNode.is, + (node) => getTableName(node.from) as GetModels + ) + .exhaustive(); + } + + private isCrudQueryNode(node: RootOperationNode): node is CrudQueryNode { + return ( + SelectQueryNode.is(node) || + InsertQueryNode.is(node) || + UpdateQueryNode.is(node) || + DeleteQueryNode.is(node) + ); + } + + private isMutationQueryNode( + node: RootOperationNode + ): node is MutationQueryNode { + return ( + InsertQueryNode.is(node) || + UpdateQueryNode.is(node) || + DeleteQueryNode.is(node) + ); + } + + private buildPolicyFilter( + model: GetModels, + operation: PolicyOperation, + thisEntity?: Record + ) { + const policies = this.getModelPolicies(model, operation); + if (policies.length === 0) { + return falseNode(this.dialect); + } + + const allows = policies + .filter((policy) => policy.kind === 'allow') + .map((policy) => + this.transformPolicyCondition(model, policy, thisEntity) + ); + + const denies = policies + .filter((policy) => policy.kind === 'deny') + .map((policy) => + this.transformPolicyCondition(model, policy, thisEntity) + ); + + let combinedPolicy: OperationNode; + + if (allows.length === 0) { + // constant false + combinedPolicy = ValueNode.create( + this.dialect.transformPrimitive(false, 'Boolean') + ); + } else { + // or(...allows) + combinedPolicy = disjunction(this.dialect, allows); + + // and(...!denies) + if (denies.length !== 0) { + const combinedDenies = conjunction( + this.dialect, + denies.map((d) => buildIsFalse(d, this.dialect)) + ); + // or(...allows) && and(...!denies) + combinedPolicy = conjunction(this.dialect, [ + combinedPolicy, + combinedDenies, + ]); + } + } + return combinedPolicy; + } + + protected override transformSelectQuery(node: SelectQueryNode) { + let whereNode = node.where; + + node.from?.froms.forEach((from) => { + let modelName = this.extractTableName(from); + const filter = this.buildPolicyFilter(modelName, 'read'); + whereNode = WhereNode.create( + whereNode?.where + ? conjunction(this.dialect, [whereNode.where, filter]) + : filter + ); + }); + + const baseResult = super.transformSelectQuery({ + ...node, + where: undefined, + }); + + return { + ...baseResult, + where: whereNode, + }; + } + + protected override transformInsertQuery(node: InsertQueryNode) { + const result = super.transformInsertQuery(node); + if (!node.returning) { + return result; + } + if (this.onlyReturningId(node)) { + return result; + } else { + // only return ID fields, that's enough for reading back the inserted row + const idFields = getIdFields( + this.client.$schema, + this.getMutationModel(node) + ); + return { + ...result, + returning: ReturningNode.create( + idFields.map((field) => + SelectionNode.create(ColumnNode.create(field)) + ) + ), + }; + } + } + + protected override transformUpdateQuery(node: UpdateQueryNode) { + const result = super.transformUpdateQuery(node); + if (!node.returning) { + return result; + } + return { + ...result, + returning: ReturningNode.create([SelectionNode.createSelectAll()]), + }; + } + + protected override transformDeleteQuery(node: DeleteQueryNode) { + const result = super.transformDeleteQuery(node); + if (!node.returning) { + return result; + } + return { + ...result, + returning: ReturningNode.create([SelectionNode.createSelectAll()]), + }; + } + + private extractTableName(from: OperationNode): GetModels { + if (TableNode.is(from)) { + return from.table.identifier.name as GetModels; + } + if (AliasNode.is(from)) { + return this.extractTableName(from.node); + } else { + throw new Error(`Unexpected "from" node kind: ${from.kind}`); + } + } + + private transformPolicyCondition( + model: GetModels, + policy: Policy, + thisEntity?: Record + ) { + return new ExpressionTransformer( + this.client.$schema, + this.client.$options, + this.client.$auth + ).transform(policy.condition, { model, thisEntity }); + } + + private getModelPolicies(modelName: string, operation: PolicyOperation) { + const modelDef = requireModel(this.client.$schema, modelName); + const result: Policy[] = []; + + const extractOperations = (expr: Expression) => { + invariant(Expression.isLiteral(expr), 'expecting a literal'); + invariant( + typeof expr.value === 'string', + 'expecting a string literal' + ); + return expr.value + .split(',') + .filter((v) => !!v) + .map((v) => v.trim()) as PolicyOperation[]; + }; + + if (modelDef.attributes) { + result.push( + ...modelDef.attributes + .filter( + (attr) => + attr.name === '@@allow' || attr.name === '@@deny' + ) + .map( + (attr) => + ({ + kind: + attr.name === '@@allow' ? 'allow' : 'deny', + operations: extractOperations( + attr.args![0]!.value + ), + condition: attr.args![1]!.value, + } as const) + ) + .filter( + (policy) => + policy.operations.includes('all') || + policy.operations.includes(operation) + ) + ); + } + return result; + } +} diff --git a/packages/runtime/src/plugins/policy/policy-transformer.ts b/packages/runtime/src/plugins/policy/policy-transformer.ts deleted file mode 100644 index 97306ed2..00000000 --- a/packages/runtime/src/plugins/policy/policy-transformer.ts +++ /dev/null @@ -1,136 +0,0 @@ -import { - AliasNode, - AndNode, - OperationNodeTransformer, - OrNode, - SelectQueryNode, - TableNode, - UnaryOperationNode, - ValueNode, - WhereNode, - type OperationNode, -} from 'kysely'; -import type { ClientContract } from '../../client'; -import { getCrudDialect } from '../../client/crud/dialects'; -import type { BaseCrudDialect } from '../../client/crud/dialects/base'; -import { requireModel } from '../../client/query-utils'; -import type { GetModels, SchemaDef } from '../../schema'; -import type { Policy } from '../../schema/schema'; -import { ExpressionTransformer } from './expression-transformer'; -import type { PolicyOptions } from './options'; - -export class PolicyTransformer< - Schema extends SchemaDef -> extends OperationNodeTransformer { - private readonly dialect: BaseCrudDialect; - - constructor( - private readonly client: ClientContract, - private readonly options: PolicyOptions - ) { - super(); - this.dialect = getCrudDialect( - this.client.$schema, - this.client.$options - ); - } - - protected override transformSelectQuery(node: SelectQueryNode) { - let whereNode = node.where; - - node.from?.froms.forEach((from) => { - let modelName = this.extractTableName(from); - const policies = this.getModelPolicies(modelName); - if (policies && policies.length > 0) { - const combinedPolicy = this.buildPolicyFilterNode( - modelName as GetModels, - policies - ); - whereNode = WhereNode.create( - whereNode?.where - ? AndNode.create(whereNode.where, combinedPolicy) - : combinedPolicy - ); - } - }); - - const baseResult = super.transformSelectQuery({ - ...node, - where: undefined, - }); - - return { - ...baseResult, - where: whereNode, - }; - } - - private buildPolicyFilterNode( - model: GetModels, - policies: Policy[] - ) { - const allows = policies - .filter((policy) => policy.kind === 'allow') - .map((policy) => this.buildPolicyWhere(model, policy)); - - const denies = policies - .filter((policy) => policy.kind === 'deny') - .map((policy) => this.buildPolicyWhere(model, policy)); - - let combinedPolicy: OperationNode; - - if (allows.length === 0) { - // constant false - combinedPolicy = ValueNode.create( - this.dialect.transformPrimitive(false, 'Boolean') - ); - } else { - // or(...allows) - combinedPolicy = allows.reduce((prev, curr, i) => - i === 0 ? curr : OrNode.create(prev, curr) - ); - - // and(...!denies) - if (denies.length !== 0) { - const combinedDenies = denies.reduce((prev, curr, i) => - i === 0 - ? UnaryOperationNode.create(ValueNode.create('!'), curr) - : AndNode.create( - prev, - UnaryOperationNode.create( - ValueNode.create('!'), - curr - ) - ) - ); - - // or(...allows) && and(...!denies) - combinedPolicy = AndNode.create(combinedPolicy, combinedDenies); - } - } - return combinedPolicy; - } - - private extractTableName(from: OperationNode): string { - if (TableNode.is(from)) { - return from.table.identifier.name; - } - if (AliasNode.is(from)) { - return this.extractTableName(from.node); - } else { - throw new Error(`Unexpected "from" node kind: ${from.kind}`); - } - } - - private buildPolicyWhere(model: GetModels, policy: Policy) { - return new ExpressionTransformer( - this.client.$schema, - this.client.$options, - this.options - ).transform(policy.expression, { model }); - } - - private getModelPolicies(modelName: string) { - return requireModel(this.client.$schema, modelName).policies; - } -} diff --git a/packages/runtime/src/plugins/policy/types.ts b/packages/runtime/src/plugins/policy/types.ts index b6ffcbfc..6b7f91cf 100644 --- a/packages/runtime/src/plugins/policy/types.ts +++ b/packages/runtime/src/plugins/policy/types.ts @@ -1,3 +1,20 @@ -export type SchemaPolicy = { - authModel: string; +import type { Expression } from '../../schema'; + +/** + * Access policy kind. + */ +export type PolicyKind = 'allow' | 'deny'; + +/** + * Access policy operation. + */ +export type PolicyOperation = 'create' | 'read' | 'update' | 'delete' | 'all'; + +/** + * Access policy definition. + */ +export type Policy = { + kind: PolicyKind; + operations: readonly PolicyOperation[]; + condition: Expression; }; diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts new file mode 100644 index 00000000..4e2e1fd5 --- /dev/null +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -0,0 +1,153 @@ +import type { OperationNode } from 'kysely'; +import { + AliasNode, + AndNode, + BinaryOperationNode, + OperatorNode, + OrNode, + ParensNode, + ReferenceNode, + TableNode, + UnaryOperationNode, + ValueNode, +} from 'kysely'; +import type { BaseCrudDialect } from '../../client/crud/dialects/base'; +import type { SchemaDef } from '../../schema'; + +/** + * Creates a `true` value node. + */ +export function trueNode( + dialect: BaseCrudDialect +) { + return ValueNode.create(dialect.transformPrimitive(true, 'Boolean')); +} + +/** + * Creates a `false` value node. + */ +export function falseNode( + dialect: BaseCrudDialect +) { + return ValueNode.create(dialect.transformPrimitive(false, 'Boolean')); +} + +/** + * Checks if a node is a truthy value node. + */ +export function isTrueNode(node: OperationNode): boolean { + return ValueNode.is(node) && (node.value === true || node.value === 1); +} + +/** + * Checks if a node is a falsy value node. + */ +export function isFalseNode(node: OperationNode): boolean { + return ValueNode.is(node) && (node.value === false || node.value === 0); +} + +/** + * Builds a logical conjunction of a list of nodes. + */ +export function conjunction( + dialect: BaseCrudDialect, + nodes: OperationNode[] +): OperationNode { + if (nodes.some(isFalseNode)) { + return falseNode(dialect); + } + const items = nodes.filter((n) => !isTrueNode(n)); + if (items.length === 0) { + return trueNode(dialect); + } + return items.reduce((acc, node) => + OrNode.is(node) + ? AndNode.create(acc, ParensNode.create(node)) // wraps parentheses + : AndNode.create(acc, node) + ); +} + +export function disjunction( + dialect: BaseCrudDialect, + nodes: OperationNode[] +): OperationNode { + if (nodes.some(isTrueNode)) { + return trueNode(dialect); + } + const items = nodes.filter((n) => !isFalseNode(n)); + if (items.length === 0) { + return falseNode(dialect); + } + return items.reduce((acc, node) => + AndNode.is(node) + ? OrNode.create(acc, ParensNode.create(node)) // wraps parentheses + : OrNode.create(acc, node) + ); +} + +/** + * Negates a logical expression. + */ +export function logicalNot(node: OperationNode): OperationNode { + return UnaryOperationNode.create( + OperatorNode.create('not'), + AndNode.is(node) || OrNode.is(node) + ? ParensNode.create(node) // wraps parentheses + : node + ); +} + +/** + * Builds an expression node that checks if a node is true. + */ +export function buildIsTrue( + node: OperationNode, + dialect: BaseCrudDialect +) { + if (isTrueNode(node)) { + return trueNode(dialect); + } else if (isFalseNode(node)) { + return falseNode(dialect); + } + return BinaryOperationNode.create( + node, + OperatorNode.create('='), + trueNode(dialect) + ); +} + +/** + * Builds an expression node that checks if a node is false. + */ +export function buildIsFalse( + node: OperationNode, + dialect: BaseCrudDialect +) { + if (isFalseNode(node)) { + return trueNode(dialect); + } else if (isTrueNode(node)) { + return falseNode(dialect); + } + return BinaryOperationNode.create( + node, + OperatorNode.create('='), + falseNode(dialect) + ); +} + +/** + * Gets the table name from a node. + */ +export function getTableName(node: OperationNode | undefined) { + if (!node) { + return node; + } + if (TableNode.is(node)) { + return node.table.identifier.name; + } else if (AliasNode.is(node)) { + return getTableName(node.node); + } else if (ReferenceNode.is(node) && node.table) { + return getTableName(node.table); + } + return undefined; +} diff --git a/packages/runtime/src/schema/expression.ts b/packages/runtime/src/schema/expression.ts index 18e26b6a..d8236864 100644 --- a/packages/runtime/src/schema/expression.ts +++ b/packages/runtime/src/schema/expression.ts @@ -1,8 +1,8 @@ export type Expression = | LiteralExpression | ArrayExpression - | FieldReferenceExpression - | MemberAccessExpression + | FieldExpression + | MemberExpression | CallExpression | UnaryExpression | BinaryExpression @@ -19,16 +19,15 @@ export type ArrayExpression = { items: Expression[]; }; -export type FieldReferenceExpression = { - kind: 'ref'; - model: string; +export type FieldExpression = { + kind: 'field'; field: string; }; -export type MemberAccessExpression = { +export type MemberExpression = { kind: 'member'; - object: Expression; - property: string; + receiver: Expression; + members: string[]; }; export type UnaryExpression = { @@ -67,7 +66,10 @@ export type BinaryOperator = | '<' | '<=' | '>' - | '>='; + | '>=' + | '?' + | '!' + | '^'; export const Expression = { literal: (value: string | number | boolean): LiteralExpression => { @@ -113,23 +115,30 @@ export const Expression = { }; }, - _this: (): ThisExpression => { + field: (field: string): FieldExpression => { return { - kind: 'this', + kind: 'field', + field, }; }, - _null: (): NullExpression => { + member: (receiver: Expression, members: string[]): MemberExpression => { return { - kind: 'null', + kind: 'member', + receiver: receiver, + members, }; }, - ref: (model: string, field: string): FieldReferenceExpression => { + _this: (): ThisExpression => { return { - kind: 'ref', - model, - field, + kind: 'this', + }; + }, + + _null: (): NullExpression => { + return { + kind: 'null', }; }, @@ -162,9 +171,6 @@ export const Expression = { isArray: (value: unknown): value is ArrayExpression => Expression.is(value, 'array'), - isRef: (value: unknown): value is FieldReferenceExpression => - Expression.is(value, 'ref'), - isCall: (value: unknown): value is CallExpression => Expression.is(value, 'call'), @@ -173,4 +179,22 @@ export const Expression = { isThis: (value: unknown): value is ThisExpression => Expression.is(value, 'this'), + + isUnaryExpr: (value: unknown): value is UnaryExpression => + Expression.is(value, 'unary'), + + isBinaryExpr: (value: unknown): value is BinaryExpression => + Expression.is(value, 'binary'), + + isFieldExpr: (value: unknown): value is FieldExpression => + Expression.is(value, 'field'), + + isMemberExpr: (value: unknown): value is MemberExpression => + Expression.is(value, 'member'), + + isCallExpr: (value: unknown): value is CallExpression => + Expression.is(value, 'call'), + + isThisExpr: (value: unknown): value is ThisExpression => + Expression.is(value, 'this'), }; diff --git a/packages/runtime/src/schema/schema.ts b/packages/runtime/src/schema/schema.ts index 03cd6ac6..d740563c 100644 --- a/packages/runtime/src/schema/schema.ts +++ b/packages/runtime/src/schema/schema.ts @@ -1,4 +1,5 @@ import type Decimal from 'decimal.js'; +import type { ModelResult } from '../client'; import type { Expression } from './expression'; export type DataSourceProviderType = 'sqlite' | 'postgresql'; @@ -14,6 +15,7 @@ export type SchemaDef = { enums?: Record; plugins: Record; procedures?: Record; + authType?: GetModels; }; export type ModelDef = { @@ -27,7 +29,6 @@ export type ModelDef = { | Record> >; idFields: string[]; - policies?: Policy[]; computedFields?: Record; }; @@ -41,22 +42,6 @@ export type AttributeArg = { value: Expression; }; -export type PolicyKind = 'allow' | 'deny'; - -export type PolicyOperation = - | 'create' - | 'read' - | 'update' - | 'post-update' - | 'delete' - | 'all'; - -export type Policy = { - kind: PolicyKind; - operations: PolicyOperation[]; - expression: Expression; -}; - export type CascadeAction = | 'SetNull' | 'Cascade' @@ -270,4 +255,11 @@ export type FieldIsRelationArray< ? FieldIsArray : false; +export type AuthType = + string extends GetModels + ? Record + : Schema['authType'] extends GetModels + ? Partial> + : never; + //#endregion diff --git a/packages/runtime/src/utils/default-operation-node-visitor.ts b/packages/runtime/src/utils/default-operation-node-visitor.ts new file mode 100644 index 00000000..8881b0ee --- /dev/null +++ b/packages/runtime/src/utils/default-operation-node-visitor.ts @@ -0,0 +1,415 @@ +import { + AddColumnNode, + AddConstraintNode, + AddIndexNode, + AggregateFunctionNode, + AliasNode, + AlterColumnNode, + AlterTableNode, + AndNode, + BinaryOperationNode, + CaseNode, + CastNode, + CheckConstraintNode, + ColumnDefinitionNode, + ColumnNode, + ColumnUpdateNode, + CommonTableExpressionNameNode, + CommonTableExpressionNode, + CreateIndexNode, + CreateSchemaNode, + CreateTableNode, + CreateTypeNode, + CreateViewNode, + DataTypeNode, + DefaultInsertValueNode, + DefaultValueNode, + DeleteQueryNode, + DropColumnNode, + DropConstraintNode, + DropIndexNode, + DropSchemaNode, + DropTableNode, + DropTypeNode, + DropViewNode, + ExplainNode, + FetchNode, + ForeignKeyConstraintNode, + FromNode, + FunctionNode, + GeneratedNode, + GroupByItemNode, + GroupByNode, + HavingNode, + IdentifierNode, + InsertQueryNode, + JoinNode, + JSONOperatorChainNode, + JSONPathLegNode, + JSONPathNode, + JSONReferenceNode, + LimitNode, + ListNode, + MatchedNode, + MergeQueryNode, + ModifyColumnNode, + OffsetNode, + OnConflictNode, + OnDuplicateKeyNode, + OnNode, + OperationNodeVisitor, + OperatorNode, + OrderByItemNode, + OrderByNode, + OrNode, + OutputNode, + OverNode, + ParensNode, + PartitionByItemNode, + PartitionByNode, + PrimitiveValueListNode, + RawNode, + ReferenceNode, + ReferencesNode, + RenameColumnNode, + ReturningNode, + SchemableIdentifierNode, + SelectAllNode, + SelectionNode, + SelectModifierNode, + SelectQueryNode, + SetOperationNode, + TableNode, + TopNode, + TupleNode, + UnaryOperationNode, + UniqueConstraintNode, + UpdateQueryNode, + UsingNode, + ValueListNode, + ValueNode, + ValuesNode, + WhenNode, + WhereNode, + WithNode, + type OperationNode, + type PrimaryKeyConstraintNode, +} from 'kysely'; + +export class DefaultOperationNodeVisitor extends OperationNodeVisitor { + protected defaultVisit(node: OperationNode) { + Object.values(node).forEach((value) => { + if (!value) { + return; + } + if (Array.isArray(value)) { + value.forEach((el) => this.defaultVisit(el)); + } + if ( + typeof value === 'object' && + 'kind' in value && + typeof value.kind === 'string' + ) { + this.visitNode(value); + } + }); + } + + protected override visitSelectQuery(node: SelectQueryNode): void { + this.defaultVisit(node); + } + protected override visitSelection(node: SelectionNode): void { + this.defaultVisit(node); + } + protected override visitColumn(node: ColumnNode): void { + this.defaultVisit(node); + } + protected override visitAlias(node: AliasNode): void { + this.defaultVisit(node); + } + protected override visitTable(node: TableNode): void { + this.defaultVisit(node); + } + protected override visitFrom(node: FromNode): void { + this.defaultVisit(node); + } + protected override visitReference(node: ReferenceNode): void { + this.defaultVisit(node); + } + protected override visitAnd(node: AndNode): void { + this.defaultVisit(node); + } + protected override visitOr(node: OrNode): void { + this.defaultVisit(node); + } + protected override visitValueList(node: ValueListNode): void { + this.defaultVisit(node); + } + protected override visitParens(node: ParensNode): void { + this.defaultVisit(node); + } + protected override visitJoin(node: JoinNode): void { + this.defaultVisit(node); + } + protected override visitRaw(node: RawNode): void { + this.defaultVisit(node); + } + protected override visitWhere(node: WhereNode): void { + this.defaultVisit(node); + } + protected override visitInsertQuery(node: InsertQueryNode): void { + this.defaultVisit(node); + } + protected override visitDeleteQuery(node: DeleteQueryNode): void { + this.defaultVisit(node); + } + protected override visitReturning(node: ReturningNode): void { + this.defaultVisit(node); + } + protected override visitCreateTable(node: CreateTableNode): void { + this.defaultVisit(node); + } + protected override visitAddColumn(node: AddColumnNode): void { + this.defaultVisit(node); + } + protected override visitColumnDefinition(node: ColumnDefinitionNode): void { + this.defaultVisit(node); + } + protected override visitDropTable(node: DropTableNode): void { + this.defaultVisit(node); + } + protected override visitOrderBy(node: OrderByNode): void { + this.defaultVisit(node); + } + protected override visitOrderByItem(node: OrderByItemNode): void { + this.defaultVisit(node); + } + protected override visitGroupBy(node: GroupByNode): void { + this.defaultVisit(node); + } + protected override visitGroupByItem(node: GroupByItemNode): void { + this.defaultVisit(node); + } + protected override visitUpdateQuery(node: UpdateQueryNode): void { + this.defaultVisit(node); + } + protected override visitColumnUpdate(node: ColumnUpdateNode): void { + this.defaultVisit(node); + } + protected override visitLimit(node: LimitNode): void { + this.defaultVisit(node); + } + protected override visitOffset(node: OffsetNode): void { + this.defaultVisit(node); + } + protected override visitOnConflict(node: OnConflictNode): void { + this.defaultVisit(node); + } + protected override visitOnDuplicateKey(node: OnDuplicateKeyNode): void { + this.defaultVisit(node); + } + protected override visitCheckConstraint(node: CheckConstraintNode): void { + this.defaultVisit(node); + } + protected override visitDataType(node: DataTypeNode): void { + this.defaultVisit(node); + } + protected override visitSelectAll(node: SelectAllNode): void { + this.defaultVisit(node); + } + protected override visitIdentifier(node: IdentifierNode): void { + this.defaultVisit(node); + } + protected override visitSchemableIdentifier( + node: SchemableIdentifierNode + ): void { + this.defaultVisit(node); + } + protected override visitValue(node: ValueNode): void { + this.defaultVisit(node); + } + protected override visitPrimitiveValueList( + node: PrimitiveValueListNode + ): void { + this.defaultVisit(node); + } + protected override visitOperator(node: OperatorNode): void { + this.defaultVisit(node); + } + protected override visitCreateIndex(node: CreateIndexNode): void { + this.defaultVisit(node); + } + protected override visitDropIndex(node: DropIndexNode): void { + this.defaultVisit(node); + } + protected override visitList(node: ListNode): void { + this.defaultVisit(node); + } + protected override visitPrimaryKeyConstraint( + node: PrimaryKeyConstraintNode + ): void { + this.defaultVisit(node); + } + protected override visitUniqueConstraint(node: UniqueConstraintNode): void { + this.defaultVisit(node); + } + protected override visitReferences(node: ReferencesNode): void { + this.defaultVisit(node); + } + protected override visitWith(node: WithNode): void { + this.defaultVisit(node); + } + protected override visitCommonTableExpression( + node: CommonTableExpressionNode + ): void { + this.defaultVisit(node); + } + protected override visitCommonTableExpressionName( + node: CommonTableExpressionNameNode + ): void { + this.defaultVisit(node); + } + protected override visitHaving(node: HavingNode): void { + this.defaultVisit(node); + } + protected override visitCreateSchema(node: CreateSchemaNode): void { + this.defaultVisit(node); + } + protected override visitDropSchema(node: DropSchemaNode): void { + this.defaultVisit(node); + } + protected override visitAlterTable(node: AlterTableNode): void { + this.defaultVisit(node); + } + protected override visitDropColumn(node: DropColumnNode): void { + this.defaultVisit(node); + } + protected override visitRenameColumn(node: RenameColumnNode): void { + this.defaultVisit(node); + } + protected override visitAlterColumn(node: AlterColumnNode): void { + this.defaultVisit(node); + } + protected override visitModifyColumn(node: ModifyColumnNode): void { + this.defaultVisit(node); + } + protected override visitAddConstraint(node: AddConstraintNode): void { + this.defaultVisit(node); + } + protected override visitDropConstraint(node: DropConstraintNode): void { + this.defaultVisit(node); + } + protected override visitForeignKeyConstraint( + node: ForeignKeyConstraintNode + ): void { + this.defaultVisit(node); + } + protected override visitCreateView(node: CreateViewNode): void { + this.defaultVisit(node); + } + protected override visitDropView(node: DropViewNode): void { + this.defaultVisit(node); + } + protected override visitGenerated(node: GeneratedNode): void { + this.defaultVisit(node); + } + protected override visitDefaultValue(node: DefaultValueNode): void { + this.defaultVisit(node); + } + protected override visitOn(node: OnNode): void { + this.defaultVisit(node); + } + protected override visitValues(node: ValuesNode): void { + this.defaultVisit(node); + } + protected override visitSelectModifier(node: SelectModifierNode): void { + this.defaultVisit(node); + } + protected override visitCreateType(node: CreateTypeNode): void { + this.defaultVisit(node); + } + protected override visitDropType(node: DropTypeNode): void { + this.defaultVisit(node); + } + protected override visitExplain(node: ExplainNode): void { + this.defaultVisit(node); + } + protected override visitDefaultInsertValue( + node: DefaultInsertValueNode + ): void { + this.defaultVisit(node); + } + protected override visitAggregateFunction( + node: AggregateFunctionNode + ): void { + this.defaultVisit(node); + } + protected override visitOver(node: OverNode): void { + this.defaultVisit(node); + } + protected override visitPartitionBy(node: PartitionByNode): void { + this.defaultVisit(node); + } + protected override visitPartitionByItem(node: PartitionByItemNode): void { + this.defaultVisit(node); + } + protected override visitSetOperation(node: SetOperationNode): void { + this.defaultVisit(node); + } + protected override visitBinaryOperation(node: BinaryOperationNode): void { + this.defaultVisit(node); + } + protected override visitUnaryOperation(node: UnaryOperationNode): void { + this.defaultVisit(node); + } + protected override visitUsing(node: UsingNode): void { + this.defaultVisit(node); + } + protected override visitFunction(node: FunctionNode): void { + this.defaultVisit(node); + } + protected override visitCase(node: CaseNode): void { + this.defaultVisit(node); + } + protected override visitWhen(node: WhenNode): void { + this.defaultVisit(node); + } + protected override visitJSONReference(node: JSONReferenceNode): void { + this.defaultVisit(node); + } + protected override visitJSONPath(node: JSONPathNode): void { + this.defaultVisit(node); + } + protected override visitJSONPathLeg(node: JSONPathLegNode): void { + this.defaultVisit(node); + } + protected override visitJSONOperatorChain( + node: JSONOperatorChainNode + ): void { + this.defaultVisit(node); + } + protected override visitTuple(node: TupleNode): void { + this.defaultVisit(node); + } + protected override visitMergeQuery(node: MergeQueryNode): void { + this.defaultVisit(node); + } + protected override visitMatched(node: MatchedNode): void { + this.defaultVisit(node); + } + protected override visitAddIndex(node: AddIndexNode): void { + this.defaultVisit(node); + } + protected override visitCast(node: CastNode): void { + this.defaultVisit(node); + } + protected override visitFetch(node: FetchNode): void { + this.defaultVisit(node); + } + protected override visitTop(node: TopNode): void { + this.defaultVisit(node); + } + protected override visitOutput(node: OutputNode): void { + this.defaultVisit(node); + } +} diff --git a/packages/runtime/test/client-api/delete.test.ts b/packages/runtime/test/client-api/delete.test.ts index 130f8e4b..f84ef6ac 100644 --- a/packages/runtime/test/client-api/delete.test.ts +++ b/packages/runtime/test/client-api/delete.test.ts @@ -32,7 +32,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( client.user.delete({ where: { id: '2' }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // found await expect( diff --git a/packages/runtime/test/client-api/update.test.ts b/packages/runtime/test/client-api/update.test.ts index 5bd5db3f..18ec0afd 100644 --- a/packages/runtime/test/client-api/update.test.ts +++ b/packages/runtime/test/client-api/update.test.ts @@ -30,7 +30,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( where: { id: 'not-found' }, data: { name: 'Foo' }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // empty data await expect( @@ -286,7 +286,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { comments: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // set multiple await expect( @@ -359,7 +359,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { comments: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // connect multiple await expect( @@ -516,7 +516,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { comments: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // multiple await expect( @@ -574,7 +574,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( data: { comments: { delete: { id: '4' } } }, include: { comments: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); await expect(client.comment.findMany()).toResolveWithLength(3); // non-existing @@ -584,7 +584,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( data: { comments: { delete: { id: '5' } } }, include: { comments: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); await expect(client.comment.findMany()).toResolveWithLength(3); // multiple @@ -796,7 +796,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // transaction fails as a whole await expect( client.comment.findUnique({ where: { id: '1' } }) @@ -823,7 +823,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // transaction fails as a whole await expect( client.comment.findUnique({ where: { id: '1' } }) @@ -1240,7 +1240,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { profile: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); it('works with nested to-one relation connectOrCreate', async () => { @@ -1367,7 +1367,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { profile: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); it('works with nested to-one relation update', async () => { @@ -1423,7 +1423,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // not connected const user2 = await createUser(client, 'u2@example.com', {}); @@ -1436,7 +1436,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); it('works with nested to-one relation upsert', async () => { @@ -1558,7 +1558,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // not connected await client.profile.create({ @@ -1573,7 +1573,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // non-existing await client.user.update({ @@ -1593,7 +1593,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); }); @@ -1674,7 +1674,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { user: true }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); it('works with nested to-one owning relation connectOrCreate', async () => { @@ -1790,7 +1790,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // null relation await expect( @@ -1874,7 +1874,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // not connected const profile2 = await client.profile.create({ @@ -1889,7 +1889,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); it('works with nested to-one owning relation upsert', async () => { @@ -2015,7 +2015,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // not connected await client.user.create({ @@ -2030,7 +2030,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); // non-existing await client.profile.update({ @@ -2050,7 +2050,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toRejectNotFound(); + ).toBeRejectNotFound(); }); }); } diff --git a/packages/runtime/test/policy/todo-sample.test.ts b/packages/runtime/test/policy/todo-sample.test.ts new file mode 100644 index 00000000..aa43cb62 --- /dev/null +++ b/packages/runtime/test/policy/todo-sample.test.ts @@ -0,0 +1,131 @@ +import { generateTsSchemaFromFile } from '@zenstackhq/testtools'; +import { beforeAll, describe, expect, it } from 'vitest'; +import { ZenStackClient } from '../../src'; +import type { SchemaDef } from '../../src/schema'; +import { PolicyPlugin } from '../../src/plugins/policy'; + +describe('Todo sample', () => { + let schema: SchemaDef; + + beforeAll(async () => { + schema = await generateTsSchemaFromFile('../schemas/todo.zmodel'); + }); + + it('works with user CRUD', async () => { + const user1 = { + id: 'user1', + email: 'user1@zenstack.dev', + name: 'User 1', + }; + const user2 = { + id: 'user2', + email: 'user2@zenstack.dev', + name: 'User 2', + }; + + const client: any = new ZenStackClient(schema, { log: ['query'] }); + await client.$pushSchema(); + + const anonDb: any = client.$use(new PolicyPlugin()); + + const user1Db = anonDb.$setAuth({ id: user1.id }); + const user2Db = anonDb.$setAuth({ id: user2.id }); + + // create user1 + // create should succeed but result can't be read back anonymously + await expect(anonDb.user.create({ data: user1 })).toBeRejectedByPolicy([ + 'result is not allowed to be read back', + ]); + await expect( + user1Db.user.findUnique({ where: { id: user1.id } }) + ).toResolveTruthy(); + await expect( + user2Db.user.findUnique({ where: { id: user1.id } }) + ).toResolveNull(); + + // create user2 + await expect( + anonDb.user.create({ data: user2 }) + ).toBeRejectedByPolicy(); + await expect(client.user.count()).resolves.toBe(2); + + // find with user1 should only get user1 + const r = await user1Db.user.findMany(); + expect(r).toHaveLength(1); + expect(r[0]).toEqual(expect.objectContaining(user1)); + + // get user2 as user1 + await expect( + user1Db.user.findUnique({ where: { id: user2.id } }) + ).toResolveNull(); + + await expect( + user1Db.space.create({ + data: { + id: 'space1', + name: 'Space 1', + slug: 'space1', + owner: { connect: { id: user1.id } }, + members: { + create: { + user: { connect: { id: user1.id } }, + role: 'ADMIN', + }, + }, + }, + }) + ).toResolveTruthy(); + + // user2 can't add himself into space1 by setting himself as admin + // because "create" check is done before entity is created + await expect( + user2Db.spaceUser.create({ + data: { + spaceId: 'space1', + userId: user2.id, + role: 'ADMIN', + }, + }) + ).toBeRejectedByPolicy(); + + // user1 can add user2 as a member + await expect( + user1Db.spaceUser.create({ + data: { spaceId: 'space1', userId: user2.id, role: 'USER' }, + }) + ).toResolveTruthy(); + + // now both user1 and user2 should be visible + await expect(user1Db.user.findMany()).resolves.toHaveLength(2); + await expect(user2Db.user.findMany()).resolves.toHaveLength(2); + + // // update user2 as user1 + // await expect( + // user2Db.user.update({ + // where: { id: user1.id }, + // data: { name: 'hello' }, + // }) + // ).toBeRejectedByPolicy(); + + // // update user1 as user1 + // await expect( + // user1Db.user.update({ + // where: { id: user1.id }, + // data: { name: 'hello' }, + // }) + // ).toResolveTruthy(); + + // // delete user2 as user1 + // await expect( + // user1Db.user.delete({ where: { id: user2.id } }) + // ).toBeRejectedByPolicy(); + + // // delete user1 as user1 + // await expect( + // user1Db.user.delete({ where: { id: user1.id } }) + // ).toResolveTruthy(); + // await expect( + // user1Db.user.findUnique({ where: { id: user1.id } }) + // ).toResolveNull(); + }); +}); diff --git a/packages/runtime/test/schemas/todo.zmodel b/packages/runtime/test/schemas/todo.zmodel new file mode 100644 index 00000000..535488ac --- /dev/null +++ b/packages/runtime/test/schemas/todo.zmodel @@ -0,0 +1,153 @@ +/* +* Sample model for a collaborative Todo app +*/ + +datasource db { + provider = 'sqlite' + url = 'file:./test.db' +} + +generator js { + provider = 'prisma-client-js' +} + +/* + * Model for a space in which users can collaborate on Lists and Todos + */ +model Space { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + name String @length(4, 50) + slug String @unique @length(4, 16) + owner User? @relation(fields: [ownerId], references: [id]) + ownerId String? + members SpaceUser[] + lists List[] + + // require login + @@deny('all', auth() == null) + + // everyone can create a space + @@allow('create', true) + + // any user in the space can read the space + @@allow('read', members?[userId == auth().id]) + + // space admin can update and delete + @@allow('update,delete', members?[userId == auth().id && role == 'ADMIN']) +} + +/* + * Model representing membership of a user in a space + */ +model SpaceUser { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + space Space @relation(fields: [spaceId], references: [id], onDelete: Cascade) + spaceId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + role String + @@unique([userId, spaceId]) + + // require login + @@deny('all', auth() == null) + + // space admin can create/update/delete + @@allow('create,update,delete', space.ownerId == auth().id || space.members?[userId == auth().id && role == 'ADMIN']) + + // user can read entries for spaces which he's a member of + @@allow('read', space.members?[userId == auth().id]) +} + +/* + * Model for a user + */ +model User { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + email String @unique @email + password String? @password @omit + emailVerified DateTime? + name String? + bio String @ignore + ownedSpaces Space[] + spaces SpaceUser[] + image String? @url + lists List[] + todos Todo[] + + // can be created by anyone, even not logged in + @@allow('create', true) + + // can be read by users sharing any space + @@allow('read', spaces?[space.members?[userId == auth().id]]) + + // full access by oneself + @@allow('all', auth().id == id) +} + +/* + * Model for a Todo list + */ +model List { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + space Space @relation(fields: [spaceId], references: [id], onDelete: Cascade) + spaceId String + owner User @relation(fields: [ownerId], references: [id], onDelete: Cascade) + ownerId String + title String @length(1, 100) + private Boolean @default(false) + todos Todo[] + revision Int @default(0) + + // require login + @@deny('all', auth() == null) + + // can be read by owner or space members (only if not private) + @@allow('read', ownerId == auth().id || (space.members?[userId == auth().id] && !private)) + + // when create, owner must be set to current user, and user must be in the space + @@allow('create', ownerId == auth().id && space.members?[userId == auth().id]) + + // when create, owner must be set to current user, and user must be in the space + // update is not allowed to change owner + @@allow('update', ownerId == auth().id && space.members?[userId == auth().id] + // TODO: future() support + // && future().ownerId == ownerId + ) + + // can be deleted by owner + @@allow('delete', ownerId == auth().id) +} + +/* + * Model for a single Todo + */ +model Todo { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + owner User @relation(fields: [ownerId], references: [id], onDelete: Cascade) + ownerId String + list List @relation(fields: [listId], references: [id], onDelete: Cascade) + listId String + title String @length(1, 100) + completedAt DateTime? + + // require login + @@deny('all', auth() == null) + + // owner has full access, also space members have full access (if the parent List is not private) + @@allow('all', list.ownerId == auth().id) + @@allow('all', list.space.members?[userId == auth().id] && !list.private) + + // TODO: future() support + // // update is not allowed to change owner + // @@deny('update', future().owner != owner) +} diff --git a/packages/runtime/test/test-schema.ts b/packages/runtime/test/test-schema.ts index eed89ec6..c514794a 100644 --- a/packages/runtime/test/test-schema.ts +++ b/packages/runtime/test/test-schema.ts @@ -95,26 +95,42 @@ export const schema = { id: { type: 'String' }, email: { type: 'String' }, }, - policies: [ + attributes: [ // @@allow('all', auth() == this) { - kind: 'allow', - operations: ['all'], - expression: Expression.binary( - Expression.call('auth'), - '==', - Expression._this() - ), + name: '@@allow', + args: [ + { + name: 'operation', + value: Expression.literal('all'), + }, + { + name: 'condition', + value: Expression.binary( + Expression.call('auth'), + '==', + Expression._this() + ), + }, + ], }, // @@allow('read', auth() != null) { - kind: 'allow', - operations: ['read'], - expression: Expression.binary( - Expression.call('auth'), - '!=', - Expression._null() - ), + name: '@@allow', + args: [ + { + name: 'operation', + value: Expression.literal('read'), + }, + { + name: 'condition', + value: Expression.binary( + Expression.call('auth'), + '!=', + Expression._null() + ), + }, + ], }, ], }, @@ -170,31 +186,56 @@ export const schema = { uniqueFields: { id: { type: 'String' }, }, - policies: [ + attributes: [ // @@deny('all', auth() == null) { - kind: 'deny', - operations: ['all'], - expression: Expression.binary( - Expression.call('auth'), - '==', - Expression._null() - ), + name: '@@deny', + args: [ + { + name: 'operation', + value: Expression.literal('all'), + }, + { + name: 'condition', + value: Expression.binary( + Expression.call('auth'), + '==', + Expression._null() + ), + }, + ], }, // @@allow('all', auth() == author) { - kind: 'allow', - operations: ['all'], - expression: Expression.binary( - Expression.call('auth'), - '==', - Expression.ref('Post', 'author') - ), + name: '@@allow', + args: [ + { + name: 'operation', + value: Expression.literal('all'), + }, + { + name: 'condition', + value: Expression.binary( + Expression.call('auth'), + '==', + Expression.field('author') + ), + }, + ], }, + // @@allow('read', published) { - kind: 'allow', - operations: ['read'], - expression: Expression.ref('Post', 'published'), + name: '@@allow', + args: [ + { + name: 'operation', + value: Expression.literal('read'), + }, + { + name: 'condition', + value: Expression.field('published'), + }, + ], }, ], }, @@ -272,17 +313,14 @@ export const schema = { }, }, }, + authType: 'User', enums: { Role: { ADMIN: 'ADMIN', USER: 'USER', }, }, - plugins: { - policy: { - authModel: 'User', - }, - }, + plugins: {}, } as const satisfies SchemaDef; export function getSchema( diff --git a/packages/runtime/test/vitest-ext.ts b/packages/runtime/test/vitest-ext.ts index 8c23395d..ede7a086 100644 --- a/packages/runtime/test/vitest-ext.ts +++ b/packages/runtime/test/vitest-ext.ts @@ -1,5 +1,5 @@ import { expect } from 'vitest'; -import { NotFoundError } from '../src/client/errors'; +import { NotFoundError, RejectedByPolicyError } from '../src/client/errors'; function isPromise(value: any) { return ( @@ -67,7 +67,7 @@ expect.extend({ }; }, - async toRejectNotFound(received: Promise) { + async toBeRejectedNotFound(received: Promise) { if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } @@ -81,4 +81,34 @@ expect.extend({ pass: false, }; }, + + async toBeRejectedByPolicy( + received: Promise, + expectedMessages?: string[] + ) { + if (!isPromise(received)) { + return { message: () => 'a promise is expected', pass: false }; + } + try { + await received; + } catch (err) { + if (expectedMessages && err instanceof RejectedByPolicyError) { + const message = err.message || ''; + for (const m of expectedMessages) { + if (!message.includes(m)) { + return { + message: () => + `expected message not found in error: ${m}, got message: ${message}`, + pass: false, + }; + } + } + } + return expectError(err, RejectedByPolicyError); + } + return { + message: () => `expected PolicyError, got no error`, + pass: false, + }; + }, }); diff --git a/packages/runtime/test/vitest.d.ts b/packages/runtime/test/vitest.d.ts index 1fab2e73..b4be622b 100644 --- a/packages/runtime/test/vitest.d.ts +++ b/packages/runtime/test/vitest.d.ts @@ -5,7 +5,8 @@ interface CustomMatchers { toResolveFalsy: () => Promise; toResolveNull: () => Promise; toResolveWithLength: (length: number) => Promise; - toRejectNotFound: () => Promise; + toBeRejectNotFound: () => Promise; + toBeRejectedByPolicy: (expectedMessages?: string[]) => Promise; } declare module 'vitest' { diff --git a/packages/sdk/package.json b/packages/sdk/package.json new file mode 100644 index 00000000..e3b5884e --- /dev/null +++ b/packages/sdk/package.json @@ -0,0 +1,42 @@ +{ + "name": "@zenstackhq/sdk", + "version": "3.0.0-alpha.1", + "description": "ZenStack SDK", + "type": "module", + "scripts": { + "build": "tsup-node", + "watch": "tsup-node --watch", + "test": "vitest", + "pack": "pnpm pack" + }, + "keywords": [], + "author": "ZenStack Team", + "license": "MIT", + "files": [ + "dist" + ], + "exports": { + ".": { + "import": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + }, + "require": { + "types": "./dist/index.d.cts", + "default": "./dist/index.cjs" + } + } + }, + "dependencies": { + "@zenstackhq/language": "workspace:*", + "langium": "~3.3.0", + "tiny-invariant": "^1.3.3", + "tmp": "^0.2.3", + "ts-pattern": "^5.7.0", + "typescript": "^5.8.3" + }, + "devDependencies": { + "@types/node": "^18.0.0", + "@types/tmp": "^0.2.6" + } +} diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts new file mode 100644 index 00000000..f7f7bca8 --- /dev/null +++ b/packages/sdk/src/index.ts @@ -0,0 +1,4 @@ +import * as ModelUtils from './model-utils'; +export * from './ts-schema-generator'; +export * from './zmodel-code-generator'; +export { ModelUtils }; diff --git a/packages/cli/src/zmodel/model-utils.ts b/packages/sdk/src/model-utils.ts similarity index 94% rename from packages/cli/src/zmodel/model-utils.ts rename to packages/sdk/src/model-utils.ts index 8d8dd083..5bd4609a 100644 --- a/packages/cli/src/zmodel/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -1,7 +1,6 @@ import { isArrayExpr, isDataModel, - isInvocationExpr, isLiteralExpr, isModel, isReferenceExpr, @@ -200,14 +199,6 @@ export function isUniqueField(field: DataModelField) { return false; } -// export function isAuthInvocation(node: AstNode) { -// return ( -// isInvocationExpr(node) && -// node.function.ref?.name === 'auth' && -// isFromStdlib(node.function.ref) -// ); -// } - export function isFromStdlib(node: AstNode) { const model = getContainingModel(node); return ( @@ -230,3 +221,17 @@ export function resolved(ref: Reference): T { } return ref.ref; } + +export function getAuthDecl(model: Model) { + let found = model.declarations.find( + (d) => + isDataModel(d) && + d.attributes.some((attr) => attr.decl.$refText === '@@auth') + ); + if (!found) { + found = model.declarations.find( + (d) => isDataModel(d) && d.name === 'User' + ); + } + return found; +} diff --git a/packages/cli/src/zmodel/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts similarity index 84% rename from packages/cli/src/zmodel/ts-schema-generator.ts rename to packages/sdk/src/ts-schema-generator.ts index 8dca2af0..ae601c0f 100644 --- a/packages/cli/src/zmodel/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -2,6 +2,7 @@ import { loadDocument } from '@zenstackhq/language'; import { ArrayExpr, AttributeArg, + BinaryExpr, DataModel, DataModelAttribute, DataModelField, @@ -10,20 +11,27 @@ import { Expression, InvocationExpr, isArrayExpr, + isBinaryExpr, isDataModel, + isDataModelField, isDataSource, isEnum, isEnumField, isInvocationExpr, isLiteralExpr, + isMemberAccessExpr, + isNullExpr, isProcedure, isReferenceExpr, + isThisExpr, + isUnaryExpr, LiteralExpr, + MemberAccessExpr, Procedure, ReferenceExpr, + UnaryExpr, type Model, } from '@zenstackhq/language/ast'; -import colors from 'colors'; import fs from 'node:fs'; import path from 'node:path'; import invariant from 'tiny-invariant'; @@ -31,30 +39,24 @@ import { match } from 'ts-pattern'; import * as ts from 'typescript'; import { getAttribute, + getAuthDecl, hasAttribute, isIdField, isUniqueField, } from './model-utils'; export class TsSchemaGenerator { - public async generate(schemaFile: string, outputFile: string) { - const loaded = await loadDocument(schemaFile); + public async generate( + schemaFile: string, + pluginModelFiles: string[], + outputFile: string + ) { + const loaded = await loadDocument(schemaFile, pluginModelFiles); if (!loaded.success) { - console.error(colors.red('Error loading schema:')); - loaded.errors.forEach((error) => - console.error(colors.red(`- ${error}`)) - ); - return; + throw new Error(`Error loading schema:${loaded.errors.join('\n')}`); } const { model, warnings } = loaded; - if (warnings.length > 0) { - console.warn(colors.yellow('Warnings:')); - warnings.forEach((warning) => - console.warn(colors.yellow(`- ${warning}`)) - ); - } - const statements: ts.Statement[] = []; this.generateSchemaStatements(model, statements); @@ -77,6 +79,8 @@ export class TsSchemaGenerator { fs.mkdirSync(path.dirname(outputFile), { recursive: true }); fs.writeFileSync(outputFile, result); + + return { model, warnings }; } private generateSchemaStatements(model: Model, statements: ts.Statement[]) { @@ -242,6 +246,17 @@ export class TsSchemaGenerator { ); } + // authType + const authType = getAuthDecl(model); + if (authType) { + properties.push( + ts.factory.createPropertyAssignment( + 'authType', + this.createLiteralNode(authType.name) + ) + ); + } + // procedures const procedures = model.declarations.filter(isProcedure); if (procedures.length > 0) { @@ -284,7 +299,10 @@ export class TsSchemaGenerator { private createModelsObject(model: Model) { return ts.factory.createObjectLiteralExpression( model.declarations - .filter(isDataModel) + .filter( + (d): d is DataModel => + isDataModel(d) && !hasAttribute(d, '@@ignore') + ) .map((dm) => ts.factory.createPropertyAssignment( dm.name, @@ -301,12 +319,14 @@ export class TsSchemaGenerator { ts.factory.createPropertyAssignment( 'fields', ts.factory.createObjectLiteralExpression( - dm.fields.map((field) => - ts.factory.createPropertyAssignment( - field.name, - this.createDataModelFieldObject(field) - ) - ), + dm.fields + .filter((field) => !hasAttribute(field, '@ignore')) + .map((field) => + ts.factory.createPropertyAssignment( + field.name, + this.createDataModelFieldObject(field) + ) + ), true ) ), @@ -319,7 +339,8 @@ export class TsSchemaGenerator { ts.factory.createArrayLiteralExpression( dm.attributes.map((attr) => this.createAttributeObject(attr) - ) + ), + true ) ), ] @@ -553,17 +574,6 @@ export class TsSchemaGenerator { return ts.factory.createObjectLiteralExpression(objectFields, true); } - private getTableName(dm: DataModel) { - const mapping = dm.attributes.find( - (attr) => attr.decl.$refText === '@map' - ); - if (mapping) { - return (mapping.args[0]?.value as LiteralExpr).value as string; - } else { - return dm.name; - } - } - private getDataSourceProvider(model: Model) { const dataSource = model.declarations.find(isDataSource); invariant(dataSource, 'No data source found in the model'); @@ -701,7 +711,8 @@ export class TsSchemaGenerator { arg.name === 'fields' && isArrayExpr(arg.value) && arg.value.items.some( - (el) => isLiteralExpr(el) && el.value === field.name + (el) => + isReferenceExpr(el) && el.target.ref === field ) ) { result.push(f.name); @@ -770,22 +781,53 @@ export class TsSchemaGenerator { if (!fieldNames) { continue; } - properties.push( - ts.factory.createPropertyAssignment( - fieldNames.join('_'), - ts.factory.createObjectLiteralExpression( - fieldNames.map((field) => { - const f = dm.fields.find( - (f) => f.name === field - )!; - return ts.factory.createPropertyAssignment( + + if (fieldNames.length === 1) { + // single-field unique + const fieldDef = dm.fields.find( + (f) => f.name === fieldNames[0] + )!; + properties.push( + ts.factory.createPropertyAssignment( + fieldNames[0]!, + ts.factory.createObjectLiteralExpression([ + ts.factory.createPropertyAssignment( 'type', - ts.factory.createStringLiteral(f.type.type!) - ); - }) + ts.factory.createStringLiteral( + fieldDef.type.type! + ) + ), + ]) ) - ) - ); + ); + } else { + // multi-field unique + properties.push( + ts.factory.createPropertyAssignment( + fieldNames.join('_'), + ts.factory.createObjectLiteralExpression( + fieldNames.map((field) => { + const fieldDef = dm.fields.find( + (f) => f.name === field + )!; + return ts.factory.createPropertyAssignment( + field, + ts.factory.createObjectLiteralExpression( + [ + ts.factory.createPropertyAssignment( + 'type', + ts.factory.createStringLiteral( + fieldDef.type.type! + ) + ), + ] + ) + ); + }) + ) + ) + ); + } } } @@ -819,8 +861,10 @@ export class TsSchemaGenerator { } } - private createLiteralNode(arg: string | number | boolean): any { - return typeof arg === 'string' + private createLiteralNode(arg: string | number | boolean | null): any { + return arg === null + ? ts.factory.createNull() + : typeof arg === 'string' ? ts.factory.createStringLiteral(arg) : typeof arg === 'number' ? ts.factory.createNumericLiteral(arg) @@ -1107,6 +1151,13 @@ export class TsSchemaGenerator { .when(isInvocationExpr, (expr) => this.createCallExpression(expr)) .when(isReferenceExpr, (expr) => this.createRefExpression(expr)) .when(isArrayExpr, (expr) => this.createArrayExpression(expr)) + .when(isUnaryExpr, (expr) => this.createUnaryExpression(expr)) + .when(isBinaryExpr, (expr) => this.createBinaryExpression(expr)) + .when(isMemberAccessExpr, (expr) => + this.createMemberExpression(expr) + ) + .when(isNullExpr, () => this.createNullExpression()) + .when(isThisExpr, () => this.createThisExpression()) .otherwise(() => { throw new Error( `Unsupported attribute arg value: ${value.$type}` @@ -1114,6 +1165,87 @@ export class TsSchemaGenerator { }); } + private createThisExpression() { + return ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression._this'), + undefined, + [] + ); + } + + private createMemberExpression(expr: MemberAccessExpr) { + const members: string[] = []; + + // turn nested member access expression into a flat list of members + let current: Expression = expr; + while (isMemberAccessExpr(current)) { + members.unshift(current.member.$refText); + current = current.operand; + } + const receiver = current; + + const args = [ + this.createExpression(receiver), + ts.factory.createArrayLiteralExpression( + members.map((m) => ts.factory.createStringLiteral(m)) + ), + ]; + + // if (isDataModel(expr.$resolvedType?.decl)) { + // const operandModel = expr.operand.$resolvedType?.decl! as DataModel; + // const relationModel = expr.$resolvedType.decl; + // args.push( + // ts.factory.createObjectLiteralExpression([ + // ts.factory.createPropertyAssignment( + // 'fromModel', + // ts.factory.createStringLiteral(operandModel.name) + // ), + // ts.factory.createPropertyAssignment( + // 'relationModel', + // ts.factory.createStringLiteral(relationModel.name) + // ), + // ]) + // ); + // } + + return ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression.member'), + undefined, + args + ); + } + + private createNullExpression() { + return ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression._null'), + undefined, + [] + ); + } + + private createBinaryExpression(expr: BinaryExpr) { + return ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression.binary'), + undefined, + [ + this.createExpression(expr.left), + this.createLiteralNode(expr.operator), + this.createExpression(expr.right), + ] + ); + } + + private createUnaryExpression(expr: UnaryExpr) { + return ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression.unary'), + undefined, + [ + this.createLiteralNode(expr.operator), + this.createExpression(expr.operand), + ] + ); + } + private createArrayExpression(expr: ArrayExpr): any { return ts.factory.createCallExpression( ts.factory.createIdentifier('Expression.array'), @@ -1127,15 +1259,22 @@ export class TsSchemaGenerator { } private createRefExpression(expr: ReferenceExpr): any { - const target = expr.target.ref!; - return ts.factory.createCallExpression( - ts.factory.createIdentifier('Expression.ref'), - undefined, - [ - ts.factory.createStringLiteral(target.$container.name), - ts.factory.createStringLiteral(target.name), - ] - ); + if (isDataModelField(expr.target.ref)) { + return ts.factory.createCallExpression( + ts.factory.createIdentifier('Expression.field'), + undefined, + [this.createLiteralNode(expr.target.$refText)] + ); + } else if (isEnumField(expr.target.ref)) { + return this.createLiteralExpression( + 'StringLiteral', + expr.target.$refText + ); + } else { + throw new Error( + `Unsupported reference type: ${expr.target.$refText}` + ); + } } private createCallExpression(expr: InvocationExpr) { diff --git a/packages/cli/src/zmodel/zmodel-code-generator.ts b/packages/sdk/src/zmodel-code-generator.ts similarity index 99% rename from packages/cli/src/zmodel/zmodel-code-generator.ts rename to packages/sdk/src/zmodel-code-generator.ts index 2d2edde7..4cf6cebe 100644 --- a/packages/cli/src/zmodel/zmodel-code-generator.ts +++ b/packages/sdk/src/zmodel-code-generator.ts @@ -61,8 +61,8 @@ const generationHandlers = new Map(); // generation handler decorator function gen(name: string) { return function ( - target: unknown, - propertyKey: string, + _target: unknown, + _propertyKey: string, descriptor: PropertyDescriptor ) { if (!generationHandlers.get(name)) { diff --git a/packages/sdk/tsconfig.json b/packages/sdk/tsconfig.json new file mode 100644 index 00000000..b2b15c85 --- /dev/null +++ b/packages/sdk/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "../../tsconfig.json", + "compilerOptions": { + "outDir": "dist", + "noUnusedLocals": false + }, + "include": ["src/**/*.ts", "test/**/*.ts"] +} diff --git a/packages/sdk/tsup.config.ts b/packages/sdk/tsup.config.ts new file mode 100644 index 00000000..5a74a9dd --- /dev/null +++ b/packages/sdk/tsup.config.ts @@ -0,0 +1,13 @@ +import { defineConfig } from 'tsup'; + +export default defineConfig({ + entry: { + index: 'src/index.ts', + }, + outDir: 'dist', + splitting: false, + sourcemap: true, + clean: true, + dts: true, + format: ['cjs', 'esm'], +}); diff --git a/packages/testtools/package.json b/packages/testtools/package.json new file mode 100644 index 00000000..9575caab --- /dev/null +++ b/packages/testtools/package.json @@ -0,0 +1,46 @@ +{ + "name": "@zenstackhq/testtools", + "version": "3.0.0-alpha.1", + "description": "ZenStack Test Tools", + "type": "module", + "scripts": { + "build": "tsup-node", + "watch": "tsup-node --watch", + "test": "vitest", + "pack": "pnpm pack" + }, + "keywords": [], + "author": "ZenStack Team", + "license": "MIT", + "files": [ + "dist" + ], + "exports": { + ".": { + "import": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + }, + "require": { + "types": "./dist/index.d.cts", + "default": "./dist/index.cjs" + } + } + }, + "dependencies": { + "@types/node": "^18.0.0", + "@zenstackhq/language": "workspace:*", + "@zenstackhq/runtime": "workspace:*", + "@zenstackhq/sdk": "workspace:*", + "glob": "^11.0.2", + "tmp": "^0.2.3", + "typescript": "^5.8.3" + }, + "peerDependencies": { + "better-sqlite3": "^11.8.1", + "pg": "^8.13.1" + }, + "devDependencies": { + "@types/tmp": "^0.2.6" + } +} diff --git a/packages/testtools/src/index.ts b/packages/testtools/src/index.ts new file mode 100644 index 00000000..e27a6e2f --- /dev/null +++ b/packages/testtools/src/index.ts @@ -0,0 +1 @@ +export * from './schema'; diff --git a/packages/cli/test/utils.ts b/packages/testtools/src/schema.ts similarity index 67% rename from packages/cli/test/utils.ts rename to packages/testtools/src/schema.ts index 34ed0601..3780e2df 100644 --- a/packages/cli/test/utils.ts +++ b/packages/testtools/src/schema.ts @@ -1,9 +1,10 @@ +import type { SchemaDef } from '@zenstackhq/runtime/schema'; +import { TsSchemaGenerator } from '@zenstackhq/sdk'; import { execSync } from 'node:child_process'; import fs from 'node:fs'; import path from 'node:path'; import tmp from 'tmp'; -import { TsSchemaGenerator } from '../src/zmodel/ts-schema-generator'; -import type { SchemaDef } from '../../runtime/src/schema'; +import { glob } from 'glob'; const ZMODEL_PRELUDE = ` datasource db { @@ -12,14 +13,22 @@ datasource db { } `; -export async function generateTsSchema(schemaText: string) { +export async function generateTsSchema(schemaText: string, noPrelude = false) { const { name: workDir } = tmp.dirSync({ unsafeCleanup: true }); console.log(`Working directory: ${workDir}`); const zmodelPath = path.join(workDir, 'schema.zmodel'); - fs.writeFileSync(zmodelPath, `${ZMODEL_PRELUDE}\n\n${schemaText}`); + fs.writeFileSync( + zmodelPath, + `${noPrelude ? '' : ZMODEL_PRELUDE}\n\n${schemaText}` + ); + + const pluginModelFiles = glob.sync( + path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel') + ); + const generator = new TsSchemaGenerator(); const tsPath = path.join(workDir, 'schema.ts'); - await generator.generate(zmodelPath, tsPath); + await generator.generate(zmodelPath, pluginModelFiles, tsPath); fs.symlinkSync( path.join(__dirname, '../node_modules'), @@ -59,3 +68,8 @@ export async function generateTsSchema(schemaText: string) { const module = await import(path.join(workDir, 'schema.js')); return module.schema as SchemaDef; } + +export function generateTsSchemaFromFile(filePath: string) { + const schemaText = fs.readFileSync(filePath, 'utf8'); + return generateTsSchema(schemaText, true); +} diff --git a/packages/testtools/tsup.config.ts b/packages/testtools/tsup.config.ts new file mode 100644 index 00000000..5a74a9dd --- /dev/null +++ b/packages/testtools/tsup.config.ts @@ -0,0 +1,13 @@ +import { defineConfig } from 'tsup'; + +export default defineConfig({ + entry: { + index: 'src/index.ts', + }, + outDir: 'dist', + splitting: false, + sourcemap: true, + clean: true, + dts: true, + format: ['cjs', 'esm'], +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c0a6b68f..fc2416c1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -33,14 +33,17 @@ importers: packages/cli: dependencies: '@types/node': - specifier: ^20.12.7 - version: 20.17.24 + specifier: ^18.0.0 + version: 18.19.71 '@zenstackhq/language': specifier: workspace:* version: link:../language '@zenstackhq/runtime': specifier: workspace:* version: link:../runtime + '@zenstackhq/sdk': + specifier: workspace:* + version: link:../sdk async-exit-hook: specifier: ^2.0.1 version: 2.0.1 @@ -81,6 +84,9 @@ importers: '@types/tmp': specifier: ^0.2.6 version: 0.2.6 + '@zenstackhq/testtools': + specifier: workspace:* + version: link:../testtools better-sqlite3: specifier: ^11.8.1 version: 11.8.1 @@ -127,9 +133,6 @@ importers: '@paralleldrive/cuid2': specifier: ^2.2.2 version: 2.2.2 - '@zenstackhq/language': - specifier: workspace:* - version: link:../language better-sqlite3: specifier: ^11.8.1 version: 11.8.1 @@ -176,9 +179,43 @@ importers: '@types/tmp': specifier: ^0.2.6 version: 0.2.6 + '@zenstackhq/language': + specifier: workspace:* + version: link:../language + '@zenstackhq/testtools': + specifier: workspace:* + version: link:../testtools + tmp: + specifier: ^0.2.3 + version: 0.2.3 + + packages/sdk: + dependencies: + '@zenstackhq/language': + specifier: workspace:* + version: link:../language + langium: + specifier: ~3.3.0 + version: 3.3.0 + tiny-invariant: + specifier: ^1.3.3 + version: 1.3.3 tmp: specifier: ^0.2.3 version: 0.2.3 + ts-pattern: + specifier: ^5.7.0 + version: 5.7.0 + typescript: + specifier: ^5.8.3 + version: 5.8.3 + devDependencies: + '@types/node': + specifier: ^18.0.0 + version: 18.19.71 + '@types/tmp': + specifier: ^0.2.6 + version: 0.2.6 packages/tanstack-query: dependencies: @@ -189,6 +226,40 @@ importers: specifier: workspace:* version: link:../runtime + packages/testtools: + dependencies: + '@types/node': + specifier: ^18.0.0 + version: 18.19.71 + '@zenstackhq/language': + specifier: workspace:* + version: link:../language + '@zenstackhq/runtime': + specifier: workspace:* + version: link:../runtime + '@zenstackhq/sdk': + specifier: workspace:* + version: link:../sdk + better-sqlite3: + specifier: ^11.8.1 + version: 11.8.1 + glob: + specifier: ^11.0.2 + version: 11.0.2 + pg: + specifier: ^8.13.1 + version: 8.13.1 + tmp: + specifier: ^0.2.3 + version: 0.2.3 + typescript: + specifier: ^5.8.3 + version: 5.8.3 + devDependencies: + '@types/tmp': + specifier: ^0.2.6 + version: 0.2.6 + packages/zod: dependencies: '@zenstackhq/runtime': @@ -218,7 +289,7 @@ importers: version: 7.6.12 prisma: specifier: ^6.0.0 - version: 6.5.0(typescript@5.7.3) + version: 6.5.0(typescript@5.8.3) packages: @@ -1569,6 +1640,11 @@ packages: resolution: {integrity: sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==} hasBin: true + glob@11.0.2: + resolution: {integrity: sha512-YT7U7Vye+t5fZ/QMkBFrTJ7ZQxInIUjwyAjVj84CYXqgBdv30MFUPGnBR6sQaVq6Is15wYJUsnzTuWaGRBhBAQ==} + engines: {node: 20 || >=22} + hasBin: true + glob@7.2.3: resolution: {integrity: sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==} deprecated: Glob versions prior to v9 are no longer supported @@ -1782,6 +1858,10 @@ packages: jackspeak@3.4.3: resolution: {integrity: sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==} + jackspeak@4.1.0: + resolution: {integrity: sha512-9DDdhb5j6cpeitCbvLO7n7J4IxnbM6hoF6O1g4HQ5TfhvvKN8ywDM7668ZhMHRqVmxqhps/F6syWK2KcPxYlkw==} + engines: {node: 20 || >=22} + joycon@3.1.1: resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} engines: {node: '>=10'} @@ -1872,6 +1952,10 @@ packages: lru-cache@10.4.3: resolution: {integrity: sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==} + lru-cache@11.1.0: + resolution: {integrity: sha512-QIXZUBJUx+2zHUdQujWejBkcD9+cs94tLn0+YL8UrCh+D5sCXZ4c7LaEH48pNwRY3MLDgqUFyhlCyjJPf1WP0A==} + engines: {node: 20 || >=22} + magic-string@0.30.17: resolution: {integrity: sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==} @@ -1899,6 +1983,10 @@ packages: resolution: {integrity: sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==} engines: {node: '>=10'} + minimatch@10.0.1: + resolution: {integrity: sha512-ethXTt3SGGR+95gudmqJ1eNhRO7eGEGIgYA9vnPatK4/etz2MEVDno5GMCibdMTuBMyElzIlgxMna3K94XDIDQ==} + engines: {node: 20 || >=22} + minimatch@3.1.2: resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==} @@ -2037,6 +2125,10 @@ packages: resolution: {integrity: sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==} engines: {node: '>=16 || 14 >=14.18'} + path-scurry@2.0.0: + resolution: {integrity: sha512-ypGJsmGtdXUOeM5u93TyeIEfEhM6s+ljAhrk5vAvSx8uyY/02OvrZnA0YNGUrPXfpJMgI1ODd3nwz8Npx4O4cg==} + engines: {node: 20 || >=22} + path-type@3.0.0: resolution: {integrity: sha512-T2ZUsdZFHgA3u4e5PfPbjd7HDDpxPnQb5jN0SrDsjNSuVXHJqtwTnWqG0B1jZrgmJ/7lj1EmVIByWt1gxGkWvg==} engines: {node: '>=4'} @@ -2538,6 +2630,9 @@ packages: ts-pattern@5.6.0: resolution: {integrity: sha512-SL8u60X5+LoEy9tmQHWCdPc2hhb2pKI6I1tU5Jue3v8+iRqZdcT3mWPwKKJy1fMfky6uha82c8ByHAE8PMhKHw==} + ts-pattern@5.7.0: + resolution: {integrity: sha512-0/FvIG4g3kNkYgbNwBBW5pZBkfpeYQnH+2AA3xmjkCAit/DSDPKmgwC3fKof4oYUq6gupClVOJlFl+939VRBMg==} + tsup@8.3.5: resolution: {integrity: sha512-Tunf6r6m6tnZsG9GYWndg0z8dEV7fD733VBFzFJ5Vcm1FtlXB8xBD/rtrBi2a3YKEV7hHtxiZtW5EAVADoe1pA==} engines: {node: '>=18'} @@ -2633,6 +2728,11 @@ packages: engines: {node: '>=14.17'} hasBin: true + typescript@5.8.3: + resolution: {integrity: sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==} + engines: {node: '>=14.17'} + hasBin: true + ulid@3.0.0: resolution: {integrity: sha512-yvZYdXInnJve6LdlPIuYmURdS2NP41ZoF4QW7SXwbUKYt53+0eDAySO+rGSvM2O/ciuB/G+8N7GQrZ1mCJpuqw==} hasBin: true @@ -3286,7 +3386,7 @@ snapshots: '@types/better-sqlite3@7.6.13': dependencies: - '@types/node': 20.17.24 + '@types/node': 18.19.71 '@types/estree@1.0.6': {} @@ -3299,10 +3399,11 @@ snapshots: '@types/node@20.17.24': dependencies: undici-types: 6.19.8 + optional: true '@types/pg@8.11.11': dependencies: - '@types/node': 20.17.24 + '@types/node': 18.19.71 pg-protocol: 1.7.0 pg-types: 4.0.2 @@ -4104,6 +4205,15 @@ snapshots: package-json-from-dist: 1.0.1 path-scurry: 1.11.1 + glob@11.0.2: + dependencies: + foreground-child: 3.3.0 + jackspeak: 4.1.0 + minimatch: 10.0.1 + minipass: 7.1.2 + package-json-from-dist: 1.0.1 + path-scurry: 2.0.0 + glob@7.2.3: dependencies: fs.realpath: 1.0.0 @@ -4314,6 +4424,10 @@ snapshots: optionalDependencies: '@pkgjs/parseargs': 0.11.0 + jackspeak@4.1.0: + dependencies: + '@isaacs/cliui': 8.0.2 + joycon@3.1.1: {} js-yaml@4.1.0: @@ -4404,6 +4518,8 @@ snapshots: lru-cache@10.4.3: {} + lru-cache@11.1.0: {} + magic-string@0.30.17: dependencies: '@jridgewell/sourcemap-codec': 1.5.0 @@ -4423,6 +4539,10 @@ snapshots: mimic-response@3.1.0: {} + minimatch@10.0.1: + dependencies: + brace-expansion: 2.0.1 + minimatch@3.1.2: dependencies: brace-expansion: 1.1.11 @@ -4568,6 +4688,11 @@ snapshots: lru-cache: 10.4.3 minipass: 7.1.2 + path-scurry@2.0.0: + dependencies: + lru-cache: 11.1.0 + minipass: 7.1.2 + path-type@3.0.0: dependencies: pify: 3.0.0 @@ -4703,6 +4828,16 @@ snapshots: transitivePeerDependencies: - supports-color + prisma@6.5.0(typescript@5.8.3): + dependencies: + '@prisma/config': 6.5.0 + '@prisma/engines': 6.5.0 + optionalDependencies: + fsevents: 2.3.3 + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + pump@3.0.2: dependencies: end-of-stream: 1.4.4 @@ -5084,6 +5219,8 @@ snapshots: ts-pattern@5.6.0: {} + ts-pattern@5.7.0: {} + tsup@8.3.5(@swc/core@1.10.15)(postcss@8.5.1)(tsx@4.19.2)(typescript@5.7.3): dependencies: bundle-require: 5.1.0(esbuild@0.24.2) @@ -5193,6 +5330,8 @@ snapshots: typescript@5.7.3: {} + typescript@5.8.3: {} + ulid@3.0.0: {} unbox-primitive@1.1.0: @@ -5204,7 +5343,8 @@ snapshots: undici-types@5.26.5: {} - undici-types@6.19.8: {} + undici-types@6.19.8: + optional: true universalify@2.0.1: {} diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index 9b8f9ffc..a28a3c36 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -2,6 +2,7 @@ // DO NOT MODIFY THIS FILE // // This file is automatically generated by ZenStack CLI and should not be manually updated. // ////////////////////////////////////////////////////////////////////////////////////////////// + import { type SchemaDef, type OperandExpression, Expression } from "@zenstackhq/runtime/schema"; import path from "node:path"; import url from "node:url"; @@ -11,8 +12,8 @@ export const schema = { type: "sqlite", dialectConfigProvider: function (): any { return { database: new SQLite(path.resolve(typeof __dirname !== 'undefined' - ? __dirname - : path.dirname(url.fileURLToPath(import.meta.url)), "./dev.db")) }; + ? __dirname + : path.dirname(url.fileURLToPath(import.meta.url)), "./dev.db")) }; } }, models: { @@ -160,6 +161,7 @@ export const schema = { USER: "USER" } }, + authType: "User", procedures: { signUp: { params: [ @@ -180,6 +182,6 @@ export const schema = { mutation: true } }, - plugins: { policy: { authModel: "User" } } + plugins: {} } as const satisfies SchemaDef; export type SchemaType = typeof schema; diff --git a/samples/blog/zenstack/schema.zmodel b/samples/blog/zenstack/schema.zmodel index f0c03eeb..5f1cd815 100644 --- a/samples/blog/zenstack/schema.zmodel +++ b/samples/blog/zenstack/schema.zmodel @@ -3,10 +3,6 @@ datasource db { url = 'file:./dev.db' } -plugin policy { - provider = '@core/policy' -} - enum Role { ADMIN USER From 3db450a2ea894216217541b5f6532d2797434b5e Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 8 May 2025 18:34:29 -0700 Subject: [PATCH 2/2] got todo sample running with policies --- packages/cli/test/ts-schema-gen.test.ts | 11 +- .../src/client/crud/operations/base.ts | 15 +++ .../src/client/crud/operations/create.ts | 36 ++---- .../src/client/crud/operations/delete.ts | 44 +++---- .../src/client/crud/operations/update.ts | 33 +++-- packages/runtime/src/client/errors.ts | 6 - .../executor/zenstack-query-executor.ts | 15 ++- packages/runtime/src/plugins/policy/errors.ts | 8 ++ .../plugins/policy/expression-transformer.ts | 12 +- .../src/plugins/policy/policy-handler.ts | 119 ++++++++++-------- packages/runtime/src/plugins/policy/utils.ts | 8 +- .../runtime/test/client-api/delete.test.ts | 2 +- .../runtime/test/client-api/update.test.ts | 44 +++---- .../test/plugin/kysely-on-query.test.ts | 53 ++------ .../test/plugin/mutation-hooks.test.ts | 4 +- packages/runtime/test/policy/read.test.ts | 34 ++--- .../runtime/test/policy/todo-sample.test.ts | 61 ++++----- packages/runtime/test/schemas/todo.zmodel | 2 +- packages/runtime/test/test-schema.ts | 12 +- packages/runtime/test/vitest-ext.ts | 3 +- packages/runtime/test/vitest.d.ts | 2 +- 21 files changed, 248 insertions(+), 276 deletions(-) create mode 100644 packages/runtime/src/plugins/policy/errors.ts diff --git a/packages/cli/test/ts-schema-gen.test.ts b/packages/cli/test/ts-schema-gen.test.ts index fbf8414a..6eb47db3 100644 --- a/packages/cli/test/ts-schema-gen.test.ts +++ b/packages/cli/test/ts-schema-gen.test.ts @@ -144,8 +144,7 @@ model Post { kind: 'array', items: [ { - kind: 'ref', - model: 'Post', + kind: 'field', field: 'authorId', }, ], @@ -157,8 +156,7 @@ model Post { kind: 'array', items: [ { - kind: 'ref', - model: 'User', + kind: 'field', field: 'id', }, ], @@ -167,9 +165,8 @@ model Post { { name: 'onDelete', value: { - kind: 'ref', - model: 'ReferentialAction', - field: 'Cascade', + kind: 'literal', + value: 'Cascade', }, }, ], diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index a8b66cb3..dd081e09 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -865,6 +865,7 @@ export abstract class BaseOperationHandler { this.dialect.buildFilter(eb, model, model, combinedWhere) ) .set(updateFields) + // TODO: return selectively .returningAll(); let updatedEntity: any; @@ -1555,6 +1556,7 @@ export abstract class BaseOperationHandler { const query = kysely .deleteFrom(model) .where((eb) => this.dialect.buildFilter(eb, model, model, where)) + // TODO: return selectively .$if(returnData, (qb) => qb.returningAll()); // const result = await this.queryExecutor.execute(kysely, query); @@ -1607,4 +1609,17 @@ export abstract class BaseOperationHandler { } return returnRelation; } + + protected async safeTransaction( + callback: (tx: ToKysely) => Promise + ) { + if (this.kysely.isTransaction) { + return callback(this.kysely); + } else { + return this.kysely + .transaction() + .setIsolationLevel('repeatable read') + .execute(callback); + } + } } diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index d9696e9e..033bf529 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -1,7 +1,7 @@ import { match } from 'ts-pattern'; +import { RejectedByPolicyError } from '../../../plugins/policy/errors'; import type { GetModels, SchemaDef } from '../../../schema'; import type { CreateArgs, CreateManyArgs } from '../../crud-types'; -import { RejectedByPolicyError } from '../../errors'; import { getIdValues } from '../../query-utils'; import { BaseOperationHandler } from './base'; @@ -27,31 +27,15 @@ export class CreateOperationHandler< } private async runCreate(args: CreateArgs>) { - let result: any; - try { - result = await this.kysely - .transaction() - .setIsolationLevel('repeatable read') - .execute(async (tx) => { - const createResult = await this.create( - tx, - this.model, - args.data - ); - return this.readUnique(tx, this.model, { - select: args.select, - include: args.include, - where: getIdValues( - this.schema, - this.model, - createResult - ), - }); - }); - } catch (err) { - // console.error(err); - throw err; - } + // TODO: avoid using transaction for simple create + const result = await this.safeTransaction(async (tx) => { + const createResult = await this.create(tx, this.model, args.data); + return this.readUnique(tx, this.model, { + select: args.select, + include: args.include, + where: getIdValues(this.schema, this.model, createResult), + }); + }); if (!result) { throw new RejectedByPolicyError( diff --git a/packages/runtime/src/client/crud/operations/delete.ts b/packages/runtime/src/client/crud/operations/delete.ts index 7b658f51..3e642d10 100644 --- a/packages/runtime/src/client/crud/operations/delete.ts +++ b/packages/runtime/src/client/crud/operations/delete.ts @@ -28,34 +28,24 @@ export class DeleteOperationHandler< async runDelete( args: DeleteArgs> ) { - const returnRelations = this.needReturnRelations(this.model, args); - - if (returnRelations) { - // employ a transaction - return this.kysely.transaction().execute(async (tx) => { - const existing = await this.readUnique(tx, this.model, { - select: args.select, - include: args.include, - where: args.where, - }); - if (!existing) { - throw new NotFoundError(this.model); - } - await this.delete(tx, this.model, args.where, false); - return existing; - }); - } else { - const result = await this.delete( - this.kysely, - this.model, - args.where, - true - ); - if ((result as unknown[]).length < 1) { - throw new NotFoundError(this.model); - } - return this.trimResult(result[0], args); + const existing = await this.readUnique(this.kysely, this.model, { + select: args.select, + include: args.include, + where: args.where, + }); + if (!existing) { + throw new NotFoundError(this.model); + } + const result = await this.delete( + this.kysely, + this.model, + args.where, + false + ); + if (result.count === 0) { + throw new NotFoundError(this.model); } + return existing; } async runDeleteMany( diff --git a/packages/runtime/src/client/crud/operations/update.ts b/packages/runtime/src/client/crud/operations/update.ts index f96b713c..e5dc1fc3 100644 --- a/packages/runtime/src/client/crud/operations/update.ts +++ b/packages/runtime/src/client/crud/operations/update.ts @@ -33,26 +33,23 @@ export class UpdateOperationHandler< if (hasRelationUpdate) { // employ a transaction try { - result = await this.kysely - .transaction() - .setIsolationLevel('repeatable read') - .execute(async (tx) => { - const updateResult = await this.update( - tx, + result = await this.safeTransaction(async (tx) => { + const updateResult = await this.update( + tx, + this.model, + args.where, + args.data + ); + return this.readUnique(tx, this.model, { + select: args.select, + include: args.include, + where: getIdValues( + this.schema, this.model, - args.where, - args.data - ); - return this.readUnique(tx, this.model, { - select: args.select, - include: args.include, - where: getIdValues( - this.schema, - this.model, - updateResult - ), - }); + updateResult + ), }); + }); } catch (err) { // console.error(err); throw err; diff --git a/packages/runtime/src/client/errors.ts b/packages/runtime/src/client/errors.ts index 91731ba0..f58a32f8 100644 --- a/packages/runtime/src/client/errors.ts +++ b/packages/runtime/src/client/errors.ts @@ -15,9 +15,3 @@ export class NotFoundError extends Error { super(`Entity not found for model "${model}"`); } } - -export class RejectedByPolicyError extends Error { - constructor(reason?: string) { - super(reason ?? `Operation rejected by policy`); - } -} diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index ea7f1eac..9cfd7017 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -7,6 +7,7 @@ import { Kysely, ReturningNode, SelectionNode, + SelectQueryNode, SingleConnectionProvider, UpdateQueryNode, WhereNode, @@ -17,7 +18,6 @@ import { type QueryCompiler, type QueryResult, type RootOperationNode, - type SelectQueryNode, type TableNode, } from 'kysely'; import { nanoid } from 'nanoid'; @@ -61,6 +61,15 @@ export class ZenStackQueryExecutor< return this.client.$options; } + private isCrudQueryNode(node: RootOperationNode) { + return ( + SelectQueryNode.is(node) || + InsertQueryNode.is(node) || + UpdateQueryNode.is(node) || + DeleteQueryNode.is(node) + ); + } + override async executeQuery( compiledQuery: CompiledQuery, queryId: QueryId @@ -111,7 +120,6 @@ export class ZenStackQueryExecutor< mutationInterceptionInfo ); - // trim the result to the original query node if (oldQueryNode !== queryNode) { // TODO: trim the result to the original query node } @@ -162,9 +170,7 @@ export class ZenStackQueryExecutor< private proceedQuery(query: RootOperationNode, queryId: QueryId) { // run built-in transformers const finalQuery = this.nameMapper.transformNode(query); - const compiled = this.compileQuery(finalQuery); - return this.driver.txConnection ? super .withConnectionProvider( @@ -415,7 +421,6 @@ export class ZenStackQueryExecutor< } plugin.afterEntityMutation({ - // context: this.queryContext, model: this.getMutationModel(queryNode), action: mutationInterceptionInfo.action, queryNode, diff --git a/packages/runtime/src/plugins/policy/errors.ts b/packages/runtime/src/plugins/policy/errors.ts new file mode 100644 index 00000000..ae707e74 --- /dev/null +++ b/packages/runtime/src/plugins/policy/errors.ts @@ -0,0 +1,8 @@ +/** + * Error thrown when an operation is rejected by access policy. + */ +export class RejectedByPolicyError extends Error { + constructor(reason?: string) { + super(reason ?? `Operation rejected by policy`); + } +} diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index 164bb4c6..cf36e092 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -129,7 +129,7 @@ export class ExpressionTransformer { @expr('null') // @ts-ignore private _null() { - return ValueNode.create(null); + return ValueNode.createImmediate(null); } @expr('binary') @@ -275,21 +275,21 @@ export class ExpressionTransformer { BinaryOperationNode.create( count, OperatorNode.create('>'), - ValueNode.create(0) + ValueNode.createImmediate(0) ) ) .with('!', () => BinaryOperationNode.create( count, OperatorNode.create('='), - ValueNode.create(0) + ValueNode.createImmediate(0) ) ) .with('^', () => BinaryOperationNode.create( count, OperatorNode.create('='), - ValueNode.create(0) + ValueNode.createImmediate(0) ) ) .exhaustive() @@ -361,7 +361,7 @@ export class ExpressionTransformer { 'Boolean' ); } else { - throw new Error('Unsupported expression'); + throw new Error('Unsupported binary expression with `auth()`'); } } @@ -511,7 +511,7 @@ export class ExpressionTransformer { receiverType: string ) { if (!receiver) { - return ValueNode.create(null); + return ValueNode.createImmediate(null); } if (expr.members.length !== 1) { diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index dea3a831..6dcce1b7 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -27,7 +27,7 @@ import { match } from 'ts-pattern'; import type { ClientContract } from '../../client'; import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base'; -import { InternalError, RejectedByPolicyError } from '../../client/errors'; +import { InternalError } from '../../client/errors'; import type { OnKyselyQueryTransaction, ProceedKyselyQueryFunction, @@ -35,6 +35,7 @@ import type { import { getIdFields, requireModel } from '../../client/query-utils'; import { Expression, type GetModels, type SchemaDef } from '../../schema'; import { ColumnCollector } from './column-collector'; +import { RejectedByPolicyError } from './errors'; import { ExpressionTransformer } from './expression-transformer'; import type { Policy, PolicyOperation } from './types'; import { @@ -83,11 +84,18 @@ export class PolicyHandler< throw new RejectedByPolicyError('non CRUD queries are not allowed'); } + if (!this.isMutationQueryNode(node)) { + // transform and proceed read without transaction + return proceed(this.transformNode(node)); + } + let mutationRequiresTransaction = false; + const mutationModel = this.getMutationModel(node); if (InsertQueryNode.is(node)) { + // reject create if unconditional deny const constCondition = this.tryGetConstantPolicy( - this.getMutationModel(node), + mutationModel, 'create' ); if (constCondition === false) { @@ -97,11 +105,6 @@ export class PolicyHandler< } } - if (!this.isMutationQueryNode(node)) { - // transform and proceed read without transaction - return proceed(this.transformNode(node)); - } - if (!mutationRequiresTransaction && !node.returning) { // transform and proceed mutation without transaction return proceed(this.transformNode(node)); @@ -244,39 +247,35 @@ export class PolicyHandler< result: QueryResult, proceed: ProceedKyselyQueryFunction ) { - if ( - InsertQueryNode.is(node) || - UpdateQueryNode.is(node) || - DeleteQueryNode.is(node) - ) { - if (node.returning) { - // do a select (with policy) in place of returning - const table = this.getMutationModel(node); - if (!table) { - throw new InternalError( - `Unable to get table name for query node: ${node}` - ); - } + if (result.rows.length === 0) { + return result; + } - const idConditions = this.buildIdConditions(table, result.rows); - const policyFilter = this.buildPolicyFilter(table, 'read'); - - const select: SelectQueryNode = { - kind: 'SelectQueryNode', - from: FromNode.create([TableNode.create(table)]), - where: WhereNode.create( - conjunction(this.dialect, [idConditions, policyFilter]) - ), - selections: node.returning.selections, - }; - const selectResult = await proceed(select); - return selectResult; - } else { - return result; - } + if (!this.isMutationQueryNode(node) || !node.returning) { + return result; } - return result; + // do a select (with policy) in place of returning + const table = this.getMutationModel(node); + if (!table) { + throw new InternalError( + `Unable to get table name for query node: ${node}` + ); + } + + const idConditions = this.buildIdConditions(table, result.rows); + const policyFilter = this.buildPolicyFilter(table, 'read'); + + const select: SelectQueryNode = { + kind: 'SelectQueryNode', + from: FromNode.create([TableNode.create(table)]), + where: WhereNode.create( + conjunction(this.dialect, [idConditions, policyFilter]) + ), + selections: node.returning.selections, + }; + const selectResult = await proceed(select); + return selectResult; } private buildIdConditions(table: string, rows: any[]): OperationNode { @@ -301,7 +300,7 @@ export class PolicyHandler< private getMutationModel( node: InsertQueryNode | UpdateQueryNode | DeleteQueryNode ) { - return match(node) + const r = match(node) .when( InsertQueryNode.is, (node) => getTableName(node.into) as GetModels @@ -310,11 +309,21 @@ export class PolicyHandler< UpdateQueryNode.is, (node) => getTableName(node.table) as GetModels ) - .when( - DeleteQueryNode.is, - (node) => getTableName(node.from) as GetModels - ) + .when(DeleteQueryNode.is, (node) => { + if (node.from.froms.length !== 1) { + throw new InternalError( + 'Only one from table is supported for delete' + ); + } + return getTableName(node.from.froms[0]) as GetModels; + }) .exhaustive(); + if (!r) { + throw new InternalError( + `Unable to get table name for query node: ${node}` + ); + } + return r; } private isCrudQueryNode(node: RootOperationNode): node is CrudQueryNode { @@ -362,9 +371,7 @@ export class PolicyHandler< if (allows.length === 0) { // constant false - combinedPolicy = ValueNode.create( - this.dialect.transformPrimitive(false, 'Boolean') - ); + combinedPolicy = falseNode(this.dialect); } else { // or(...allows) combinedPolicy = disjunction(this.dialect, allows); @@ -435,23 +442,29 @@ export class PolicyHandler< protected override transformUpdateQuery(node: UpdateQueryNode) { const result = super.transformUpdateQuery(node); - if (!node.returning) { - return result; - } + const mutationModel = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, 'update'); return { ...result, - returning: ReturningNode.create([SelectionNode.createSelectAll()]), + where: WhereNode.create( + result.where + ? conjunction(this.dialect, [result.where.where, filter]) + : filter + ), }; } protected override transformDeleteQuery(node: DeleteQueryNode) { const result = super.transformDeleteQuery(node); - if (!node.returning) { - return result; - } + const mutationModel = this.getMutationModel(node); + const filter = this.buildPolicyFilter(mutationModel, 'update'); return { ...result, - returning: ReturningNode.create([SelectionNode.createSelectAll()]), + where: WhereNode.create( + result.where + ? conjunction(this.dialect, [result.where.where, filter]) + : filter + ), }; } diff --git a/packages/runtime/src/plugins/policy/utils.ts b/packages/runtime/src/plugins/policy/utils.ts index 4e2e1fd5..541082b3 100644 --- a/packages/runtime/src/plugins/policy/utils.ts +++ b/packages/runtime/src/plugins/policy/utils.ts @@ -20,7 +20,9 @@ import type { SchemaDef } from '../../schema'; export function trueNode( dialect: BaseCrudDialect ) { - return ValueNode.create(dialect.transformPrimitive(true, 'Boolean')); + return ValueNode.createImmediate( + dialect.transformPrimitive(true, 'Boolean') + ); } /** @@ -29,7 +31,9 @@ export function trueNode( export function falseNode( dialect: BaseCrudDialect ) { - return ValueNode.create(dialect.transformPrimitive(false, 'Boolean')); + return ValueNode.createImmediate( + dialect.transformPrimitive(false, 'Boolean') + ); } /** diff --git a/packages/runtime/test/client-api/delete.test.ts b/packages/runtime/test/client-api/delete.test.ts index f84ef6ac..393dddfb 100644 --- a/packages/runtime/test/client-api/delete.test.ts +++ b/packages/runtime/test/client-api/delete.test.ts @@ -32,7 +32,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( client.user.delete({ where: { id: '2' }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // found await expect( diff --git a/packages/runtime/test/client-api/update.test.ts b/packages/runtime/test/client-api/update.test.ts index 18ec0afd..ccf4b3b2 100644 --- a/packages/runtime/test/client-api/update.test.ts +++ b/packages/runtime/test/client-api/update.test.ts @@ -30,7 +30,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( where: { id: 'not-found' }, data: { name: 'Foo' }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // empty data await expect( @@ -286,7 +286,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { comments: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // set multiple await expect( @@ -359,7 +359,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { comments: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // connect multiple await expect( @@ -516,7 +516,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { comments: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // multiple await expect( @@ -574,7 +574,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( data: { comments: { delete: { id: '4' } } }, include: { comments: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); await expect(client.comment.findMany()).toResolveWithLength(3); // non-existing @@ -584,7 +584,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( data: { comments: { delete: { id: '5' } } }, include: { comments: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); await expect(client.comment.findMany()).toResolveWithLength(3); // multiple @@ -796,7 +796,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // transaction fails as a whole await expect( client.comment.findUnique({ where: { id: '1' } }) @@ -823,7 +823,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // transaction fails as a whole await expect( client.comment.findUnique({ where: { id: '1' } }) @@ -1240,7 +1240,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { profile: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); it('works with nested to-one relation connectOrCreate', async () => { @@ -1367,7 +1367,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { profile: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); it('works with nested to-one relation update', async () => { @@ -1423,7 +1423,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // not connected const user2 = await createUser(client, 'u2@example.com', {}); @@ -1436,7 +1436,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); it('works with nested to-one relation upsert', async () => { @@ -1558,7 +1558,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // not connected await client.profile.create({ @@ -1573,7 +1573,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // non-existing await client.user.update({ @@ -1593,7 +1593,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); }); @@ -1674,7 +1674,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, include: { user: true }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); it('works with nested to-one owning relation connectOrCreate', async () => { @@ -1790,7 +1790,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // null relation await expect( @@ -1874,7 +1874,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // not connected const profile2 = await client.profile.create({ @@ -1889,7 +1889,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); it('works with nested to-one owning relation upsert', async () => { @@ -2015,7 +2015,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // not connected await client.user.create({ @@ -2030,7 +2030,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); // non-existing await client.profile.update({ @@ -2050,7 +2050,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }, }) - ).toBeRejectNotFound(); + ).toBeRejectedNotFound(); }); }); } diff --git a/packages/runtime/test/plugin/kysely-on-query.test.ts b/packages/runtime/test/plugin/kysely-on-query.test.ts index 2ce8ff74..bd5e6762 100644 --- a/packages/runtime/test/plugin/kysely-on-query.test.ts +++ b/packages/runtime/test/plugin/kysely-on-query.test.ts @@ -22,11 +22,9 @@ describe('Kysely onQuery tests', () => { const client = _client.$use({ id: 'test-plugin', onKyselyQuery(args) { - called = true; - expect(args).toMatchObject({ - query: expect.objectContaining({ kind: 'InsertQueryNode' }), - proceed: expect.any(Function), - }); + if (args.query.kind === 'InsertQueryNode') { + called = true; + } return args.proceed(args.query); }, }); @@ -63,6 +61,9 @@ describe('Kysely onQuery tests', () => { const client = _client.$use({ id: 'test-plugin', onKyselyQuery({ proceed, query }) { + if (query.kind !== 'InsertQueryNode') { + return proceed(query); + } const valueList = [ ...( ((query as InsertQueryNode).values as ValuesNode) @@ -122,42 +123,6 @@ describe('Kysely onQuery tests', () => { }); }); - it('can partially succeed without a transaction', async () => { - const client = _client.$use({ - id: 'test-plugin', - async onKyselyQuery({ kysely, proceed, query }) { - if (query.kind !== 'InsertQueryNode') { - return proceed(query); - } - - const result = await proceed(query); - - // create a post for the user - const now = new Date().toISOString(); - const createPost = kysely.insertInto('Post').values({ - id: '1', - title: 'Post1', - authorId: 'none-exist', - updatedAt: now, - }); - await proceed(createPost.toOperationNode()); - - return result; - }, - }); - - await expect( - client.user.create({ - data: { id: '1', email: 'u1@test.com' }, - }) - ).rejects.toThrow(); - - await expect(client.user.findFirst()).resolves.toMatchObject({ - email: 'u1@test.com', - }); - await expect(client.post.findFirst()).toResolveNull(); - }); - it('rolls back on error when a transaction is used', async () => { const client = _client.$use({ id: 'test-plugin', @@ -202,6 +167,9 @@ describe('Kysely onQuery tests', () => { .$use({ id: 'test-plugin', onKyselyQuery({ proceed, query }) { + if (query.kind !== 'InsertQueryNode') { + return proceed(query); + } called1 = true; const valueList = [ ...( @@ -224,6 +192,9 @@ describe('Kysely onQuery tests', () => { .$use({ id: 'test-plugin2', onKyselyQuery({ proceed, query }) { + if (query.kind !== 'InsertQueryNode') { + return proceed(query); + } called2 = true; const valueList = [ ...( diff --git a/packages/runtime/test/plugin/mutation-hooks.test.ts b/packages/runtime/test/plugin/mutation-hooks.test.ts index b321c868..790b03cb 100644 --- a/packages/runtime/test/plugin/mutation-hooks.test.ts +++ b/packages/runtime/test/plugin/mutation-hooks.test.ts @@ -307,7 +307,8 @@ describe('Entity lifecycle tests', () => { expect(post2Intercepted).toBe(true); }); - it('proceeds with mutation even when hooks throw', async () => { + // TODO: revisit mutation hooks and transactions + it.skip('proceeds with mutation even when hooks throw', async () => { let userIntercepted = false; const client = _client.$use({ @@ -330,6 +331,7 @@ describe('Entity lifecycle tests', () => { expect(userIntercepted).toBe(true); expect(gotError).toBe(true); + console.log(await client.user.findMany()); await expect(client.user.findMany()).toResolveWithLength(1); }); diff --git a/packages/runtime/test/policy/read.test.ts b/packages/runtime/test/policy/read.test.ts index 97028dac..12ec87d0 100644 --- a/packages/runtime/test/policy/read.test.ts +++ b/packages/runtime/test/policy/read.test.ts @@ -1,8 +1,8 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { type ClientContract } from '../../src/client'; +import { PolicyPlugin } from '../../src/plugins/policy/plugin'; import { createClientSpecs } from '../client-api/client-specs'; import { schema } from '../test-schema'; -import { PolicyPlugin } from '../../src/plugins/policy/plugin'; const PG_DB_NAME = 'policy-read-tests'; @@ -28,24 +28,13 @@ describe.each(createClientSpecs(PG_DB_NAME))( }); // anonymous auth context by default - const policyPlugin = new PolicyPlugin(); - - const anonClient = client.$use(policyPlugin); + const anonClient = client.$use(new PolicyPlugin()); await expect(anonClient.user.findFirst()).toResolveNull(); - const authClient = client.$use( - // switch auth context - policyPlugin.setAuth({ - id: user.id, - }) - ); + const authClient = anonClient.$setAuth({ + id: user.id, + }); await expect(authClient.user.findFirst()).resolves.toEqual(user); - - const authClient1 = client.$use( - // set auth context when creating the plugin - new PolicyPlugin({ auth: { id: user.id } }) - ); - await expect(authClient1.user.findFirst()).resolves.toEqual(user); }); it('works with ORM API nested', async () => { @@ -63,17 +52,14 @@ describe.each(createClientSpecs(PG_DB_NAME))( }, }); - const otherUserClient = client.$use( - new PolicyPlugin({ auth: { id: '2' } }) - ); + const anonClient = client.$use(new PolicyPlugin()); + const otherUserClient = anonClient.$setAuth({ id: '2' }); const r = await otherUserClient.user.findFirst({ include: { posts: true }, }); expect(r?.posts).toHaveLength(0); - const authClient = client.$use( - new PolicyPlugin({ auth: { id: '1' } }) - ); + const authClient = anonClient.$setAuth({ id: '1' }); const r1 = await authClient.user.findFirst({ include: { posts: true }, }); @@ -92,9 +78,7 @@ describe.each(createClientSpecs(PG_DB_NAME))( anonClient.$qb.selectFrom('User').selectAll().executeTakeFirst() ).toResolveFalsy(); - const authClient = client.$use( - new PolicyPlugin({ auth: { id: user.id } }) - ); + const authClient = anonClient.$setAuth({ id: user.id }); const foundUser = await authClient.$qb .selectFrom('User') .selectAll() diff --git a/packages/runtime/test/policy/todo-sample.test.ts b/packages/runtime/test/policy/todo-sample.test.ts index aa43cb62..e5f7b2f4 100644 --- a/packages/runtime/test/policy/todo-sample.test.ts +++ b/packages/runtime/test/policy/todo-sample.test.ts @@ -3,12 +3,15 @@ import { beforeAll, describe, expect, it } from 'vitest'; import { ZenStackClient } from '../../src'; import type { SchemaDef } from '../../src/schema'; import { PolicyPlugin } from '../../src/plugins/policy'; +import path from 'node:path'; describe('Todo sample', () => { let schema: SchemaDef; beforeAll(async () => { - schema = await generateTsSchemaFromFile('../schemas/todo.zmodel'); + schema = await generateTsSchemaFromFile( + path.join(__dirname, '../schemas/todo.zmodel') + ); }); it('works with user CRUD', async () => { @@ -99,33 +102,33 @@ describe('Todo sample', () => { await expect(user1Db.user.findMany()).resolves.toHaveLength(2); await expect(user2Db.user.findMany()).resolves.toHaveLength(2); - // // update user2 as user1 - // await expect( - // user2Db.user.update({ - // where: { id: user1.id }, - // data: { name: 'hello' }, - // }) - // ).toBeRejectedByPolicy(); - - // // update user1 as user1 - // await expect( - // user1Db.user.update({ - // where: { id: user1.id }, - // data: { name: 'hello' }, - // }) - // ).toResolveTruthy(); - - // // delete user2 as user1 - // await expect( - // user1Db.user.delete({ where: { id: user2.id } }) - // ).toBeRejectedByPolicy(); - - // // delete user1 as user1 - // await expect( - // user1Db.user.delete({ where: { id: user1.id } }) - // ).toResolveTruthy(); - // await expect( - // user1Db.user.findUnique({ where: { id: user1.id } }) - // ).toResolveNull(); + // update user2 as user1 + await expect( + user2Db.user.update({ + where: { id: user1.id }, + data: { name: 'hello' }, + }) + ).toBeRejectedNotFound(); + + // update user1 as user1 + await expect( + user1Db.user.update({ + where: { id: user1.id }, + data: { name: 'hello' }, + }) + ).toResolveTruthy(); + + // delete user2 as user1 + await expect( + user1Db.user.delete({ where: { id: user2.id } }) + ).toBeRejectedNotFound(); + + // delete user1 as user1 + await expect( + user1Db.user.delete({ where: { id: user1.id } }) + ).toResolveTruthy(); + await expect( + user1Db.user.findUnique({ where: { id: user1.id } }) + ).toResolveNull(); }); }); diff --git a/packages/runtime/test/schemas/todo.zmodel b/packages/runtime/test/schemas/todo.zmodel index 535488ac..fd87427b 100644 --- a/packages/runtime/test/schemas/todo.zmodel +++ b/packages/runtime/test/schemas/todo.zmodel @@ -20,7 +20,7 @@ model Space { updatedAt DateTime @updatedAt name String @length(4, 50) slug String @unique @length(4, 16) - owner User? @relation(fields: [ownerId], references: [id]) + owner User? @relation(fields: [ownerId], references: [id], onDelete: Cascade) ownerId String? members SpaceUser[] lists List[] diff --git a/packages/runtime/test/test-schema.ts b/packages/runtime/test/test-schema.ts index c514794a..446890bf 100644 --- a/packages/runtime/test/test-schema.ts +++ b/packages/runtime/test/test-schema.ts @@ -107,9 +107,11 @@ export const schema = { { name: 'condition', value: Expression.binary( - Expression.call('auth'), + Expression.member(Expression.call('auth'), [ + 'id', + ]), '==', - Expression._this() + Expression.field('id') ), }, ], @@ -216,9 +218,11 @@ export const schema = { { name: 'condition', value: Expression.binary( - Expression.call('auth'), + Expression.member(Expression.call('auth'), [ + 'id', + ]), '==', - Expression.field('author') + Expression.field('authorId') ), }, ], diff --git a/packages/runtime/test/vitest-ext.ts b/packages/runtime/test/vitest-ext.ts index ede7a086..6f423169 100644 --- a/packages/runtime/test/vitest-ext.ts +++ b/packages/runtime/test/vitest-ext.ts @@ -1,5 +1,6 @@ import { expect } from 'vitest'; -import { NotFoundError, RejectedByPolicyError } from '../src/client/errors'; +import { NotFoundError } from '../src/client/errors'; +import { RejectedByPolicyError } from '../src/plugins/policy/errors'; function isPromise(value: any) { return ( diff --git a/packages/runtime/test/vitest.d.ts b/packages/runtime/test/vitest.d.ts index b4be622b..b547127c 100644 --- a/packages/runtime/test/vitest.d.ts +++ b/packages/runtime/test/vitest.d.ts @@ -5,7 +5,7 @@ interface CustomMatchers { toResolveFalsy: () => Promise; toResolveNull: () => Promise; toResolveWithLength: (length: number) => Promise; - toBeRejectNotFound: () => Promise; + toBeRejectedNotFound: () => Promise; toBeRejectedByPolicy: (expectedMessages?: string[]) => Promise; }