diff --git a/TODO.md b/TODO.md index faededb7..270a15aa 100644 --- a/TODO.md +++ b/TODO.md @@ -11,6 +11,7 @@ - [ ] format - [ ] db seed - [ ] ZModel + - [ ] Import - [ ] View support - [ ] ORM - [x] Create @@ -80,8 +81,8 @@ - [ ] Strict undefined checks - [ ] DbNull vs JsonNull - [ ] Benchmark -- [ ] Plugin - - [ ] Post-mutation hooks should be called after transaction is committed +- [x] Plugin + - [x] Post-mutation hooks should be called after transaction is committed - [x] TypeDef and mixin - [ ] Strongly typed JSON - [x] Polymorphism diff --git a/packages/runtime/src/client/crud/operations/aggregate.ts b/packages/runtime/src/client/crud/operations/aggregate.ts index 5d309dda..13cb8b8e 100644 --- a/packages/runtime/src/client/crud/operations/aggregate.ts +++ b/packages/runtime/src/client/crud/operations/aggregate.ts @@ -112,11 +112,11 @@ export class AggregateOperationHandler extends BaseOpe } } - const result = await query.executeTakeFirstOrThrow(); + const result = await this.executeQuery(this.kysely, query, 'aggregate'); const ret: any = {}; // postprocess result to convert flat fields into nested objects - for (const [key, value] of Object.entries(result as object)) { + for (const [key, value] of Object.entries(result.rows[0] as object)) { if (key === '_count') { ret[key] = value; continue; diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 926bb907..12e25891 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -6,8 +6,10 @@ import { ExpressionWrapper, sql, UpdateResult, + type Compilable, type IsolationLevel, type Expression as KyselyExpression, + type QueryResult, type SelectQueryBuilder, } from 'kysely'; import { nanoid } from 'nanoid'; @@ -125,7 +127,11 @@ export abstract class BaseOperationHandler { return getField(this.schema, model, field); } - protected exists(kysely: ToKysely, model: GetModels, filter: any): Promise { + protected async exists( + kysely: ToKysely, + model: GetModels, + filter: any, + ): Promise { const idFields = getIdFields(this.schema, model); const _filter = flattenCompoundUniqueFilters(this.schema, model, filter); const query = kysely @@ -134,7 +140,7 @@ export abstract class BaseOperationHandler { .select(idFields.map((f) => kysely.dynamic.ref(f))) .limit(1) .modifyEnd(this.makeContextComment({ model, operation: 'read' })); - return query.executeTakeFirst(); + return this.executeQueryTakeFirst(kysely, query, 'exists'); } protected async read( @@ -444,7 +450,7 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - return query.execute(); + return this.executeQuery(kysely, query, 'update'); }; } } @@ -511,10 +517,10 @@ export abstract class BaseOperationHandler { }), ); - const createdEntity = await query.executeTakeFirst(); + const createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); // try { - // createdEntity = await query.executeTakeFirst(); + // createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); // } catch (err) { // const { sql, parameters } = query.compile(); // throw new QueryError( @@ -893,8 +899,8 @@ export abstract class BaseOperationHandler { ); if (!returnData) { - const result = await query.executeTakeFirstOrThrow(); - return { count: Number(result.numInsertedOrUpdatedRows) } as Result; + const result = await this.executeQuery(kysely, query, 'createMany'); + return { count: Number(result.numAffectedRows) } as Result; } else { const idFields = getIdFields(this.schema, model); const result = await query.returning(idFields as any).execute(); @@ -1160,10 +1166,10 @@ export abstract class BaseOperationHandler { }), ); - const updatedEntity = await query.executeTakeFirst(); + const updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update'); // try { - // updatedEntity = await query.executeTakeFirst(); + // updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update'); // } catch (err) { // const { sql, parameters } = query.compile(); // throw new QueryError( @@ -1401,8 +1407,8 @@ export abstract class BaseOperationHandler { query = query.modifyEnd(this.makeContextComment({ model, operation: 'update' })); if (!returnData) { - const result = await query.executeTakeFirstOrThrow(); - return { count: Number(result.numUpdatedRows) } as Result; + const result = await this.executeQuery(kysely, query, 'update'); + return { count: Number(result.numAffectedRows) } as Result; } else { const idFields = getIdFields(this.schema, model); const result = await query.returning(idFields as any).execute(); @@ -1636,7 +1642,7 @@ export abstract class BaseOperationHandler { fromRelation.model, fromRelation.field, ); - let updateResult: UpdateResult; + let updateResult: QueryResult; if (ownedByModel) { // set parent fk directly @@ -1665,7 +1671,7 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - updateResult = await query.executeTakeFirstOrThrow(); + updateResult = await this.executeQuery(kysely, query, 'connect'); } else { // disconnect current if it's a one-one relation const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field); @@ -1681,7 +1687,7 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - await query.execute(); + await this.executeQuery(kysely, query, 'disconnect'); } // connect @@ -1703,11 +1709,11 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - updateResult = await query.executeTakeFirstOrThrow(); + updateResult = await this.executeQuery(kysely, query, 'connect'); } // validate connect result - if (_data.length > updateResult.numUpdatedRows) { + if (_data.length > updateResult.numAffectedRows!) { // some entities were not connected throw new NotFoundError(model); } @@ -1821,7 +1827,7 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - await query.executeTakeFirstOrThrow(); + await this.executeQuery(kysely, query, 'disconnect'); } else { // disconnect const query = kysely @@ -1841,7 +1847,7 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - await query.executeTakeFirstOrThrow(); + await this.executeQuery(kysely, query, 'disconnect'); } } } @@ -1920,7 +1926,7 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - await query.execute(); + await this.executeQuery(kysely, query, 'disconnect'); // connect if (_data.length > 0) { @@ -1942,10 +1948,10 @@ export abstract class BaseOperationHandler { operation: 'update', }), ); - const r = await query.executeTakeFirstOrThrow(); + const r = await this.executeQuery(kysely, query, 'connect'); // validate result - if (_data.length > r.numUpdatedRows!) { + if (_data.length > r.numAffectedRows!) { // some entities were not connected throw new NotFoundError(model); } @@ -2109,8 +2115,8 @@ export abstract class BaseOperationHandler { await this.processDelegateRelationDelete(kysely, modelDef, where, limit); query = query.modifyEnd(this.makeContextComment({ model, operation: 'delete' })); - const result = await query.executeTakeFirstOrThrow(); - return { count: Number(result.numDeletedRows) }; + const result = await this.executeQuery(kysely, query, 'delete'); + return { count: Number(result.numAffectedRows) }; } private async processDelegateRelationDelete( @@ -2240,4 +2246,25 @@ export abstract class BaseOperationHandler { } } } + + protected makeQueryId(operation: string) { + return { queryId: `${operation}-${createId()}` }; + } + + protected executeQuery(kysely: ToKysely, query: Compilable, operation: string) { + return kysely.executeQuery(query.compile(), this.makeQueryId(operation)); + } + + protected async executeQueryTakeFirst(kysely: ToKysely, query: Compilable, operation: string) { + const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation)); + return result.rows[0]; + } + + protected async executeQueryTakeFirstOrThrow(kysely: ToKysely, query: Compilable, operation: string) { + const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation)); + if (result.rows.length === 0) { + throw new QueryError('No rows found'); + } + return result.rows[0]; + } } diff --git a/packages/runtime/src/client/crud/operations/count.ts b/packages/runtime/src/client/crud/operations/count.ts index fc22c2ec..8c11af3a 100644 --- a/packages/runtime/src/client/crud/operations/count.ts +++ b/packages/runtime/src/client/crud/operations/count.ts @@ -44,13 +44,13 @@ export class CountOperationHandler extends BaseOperati : eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key), ), ); - - return query.executeTakeFirstOrThrow(); + const result = await this.executeQuery(this.kysely, query, 'count'); + return result.rows[0]; } else { // simple count all query = query.select((eb) => eb.cast(eb.fn.countAll(), 'integer').as('count')); - const result = await query.executeTakeFirstOrThrow(); - return (result as any).count as number; + const result = await this.executeQuery(this.kysely, query, 'count'); + return (result.rows[0] as any).count as number; } } } diff --git a/packages/runtime/src/client/crud/operations/delete.ts b/packages/runtime/src/client/crud/operations/delete.ts index a33c2179..3ed17ce0 100644 --- a/packages/runtime/src/client/crud/operations/delete.ts +++ b/packages/runtime/src/client/crud/operations/delete.ts @@ -30,7 +30,7 @@ export class DeleteOperationHandler extends BaseOperat // TODO: avoid using transaction for simple delete await this.safeTransaction(async (tx) => { - const result = await this.delete(tx, this.model, args.where, undefined); + const result = await this.delete(tx, this.model, args.where); if (result.count === 0) { throw new NotFoundError(this.model); } diff --git a/packages/runtime/src/client/crud/operations/group-by.ts b/packages/runtime/src/client/crud/operations/group-by.ts index f309bf06..009fa3b5 100644 --- a/packages/runtime/src/client/crud/operations/group-by.ts +++ b/packages/runtime/src/client/crud/operations/group-by.ts @@ -108,8 +108,8 @@ export class GroupByOperationHandler extends BaseOpera } } - const result = await query.execute(); - return result.map((row) => this.postProcessRow(row)); + const result = await this.executeQuery(this.kysely, query, 'groupBy'); + return result.rows.map((row) => this.postProcessRow(row)); } private postProcessRow(row: any) { diff --git a/packages/runtime/src/client/executor/zenstack-driver.ts b/packages/runtime/src/client/executor/zenstack-driver.ts index 651c3eaf..9a0a32c3 100644 --- a/packages/runtime/src/client/executor/zenstack-driver.ts +++ b/packages/runtime/src/client/executor/zenstack-driver.ts @@ -6,12 +6,12 @@ import type { CompiledQuery, DatabaseConnection, Driver, Log, QueryResult, Trans export class ZenStackDriver implements Driver { readonly #driver: Driver; readonly #log: Log; - txConnection: DatabaseConnection | undefined; #initPromise?: Promise; #initDone: boolean; #destroyPromise?: Promise; #connections = new WeakSet(); + #txConnections = new WeakMap Promise>>(); constructor(driver: Driver, log: Log) { this.#initDone = false; @@ -67,23 +67,33 @@ export class ZenStackDriver implements Driver { async beginTransaction(connection: DatabaseConnection, settings: TransactionSettings): Promise { const result = await this.#driver.beginTransaction(connection, settings); - this.txConnection = connection; + this.#txConnections.set(connection, []); return result; } - commitTransaction(connection: DatabaseConnection): Promise { + async commitTransaction(connection: DatabaseConnection): Promise { try { - return this.#driver.commitTransaction(connection); - } finally { - this.txConnection = undefined; + const result = await this.#driver.commitTransaction(connection); + const callbacks = this.#txConnections.get(connection); + // delete from the map immediately to avoid accidental re-triggering + this.#txConnections.delete(connection); + if (callbacks) { + for (const callback of callbacks) { + await callback(); + } + } + return result; + } catch (err) { + this.#txConnections.delete(connection); + throw err; } } - rollbackTransaction(connection: DatabaseConnection): Promise { + async rollbackTransaction(connection: DatabaseConnection): Promise { try { - return this.#driver.rollbackTransaction(connection); + return await this.#driver.rollbackTransaction(connection); } finally { - this.txConnection = undefined; + this.#txConnections.delete(connection); } } @@ -175,6 +185,22 @@ export class ZenStackDriver implements Driver { #calculateDurationMillis(startTime: number): number { return performanceNow() - startTime; } + + isTransactionConnection(connection: DatabaseConnection): boolean { + return this.#txConnections.has(connection); + } + + registerTransactionCommitCallback(connection: DatabaseConnection, callback: () => Promise): void { + if (!this.#txConnections.has(connection)) { + return; + } + const callbacks = this.#txConnections.get(connection); + if (callbacks) { + callbacks.push(callback); + } else { + this.#txConnections.set(connection, [callback]); + } + } } export function performanceNow() { diff --git a/packages/runtime/src/client/executor/zenstack-query-executor.ts b/packages/runtime/src/client/executor/zenstack-query-executor.ts index 97cdd9cb..454c0a82 100644 --- a/packages/runtime/src/client/executor/zenstack-query-executor.ts +++ b/packages/runtime/src/client/executor/zenstack-query-executor.ts @@ -10,6 +10,7 @@ import { UpdateQueryNode, WhereNode, type ConnectionProvider, + type DatabaseConnection, type DialectAdapter, type KyselyPlugin, type OperationNode, @@ -25,7 +26,7 @@ import type { GetModels, SchemaDef } from '../../schema'; import { type ClientImpl } from '../client-impl'; import type { ClientContract } from '../contract'; import { InternalError, QueryError } from '../errors'; -import type { MutationInterceptionFilterResult, OnKyselyQueryCallback } from '../plugin'; +import type { AfterEntityMutationCallback, MutationInterceptionFilterResult, OnKyselyQueryCallback } from '../plugin'; import { QueryNameMapper } from './name-mapper'; import type { ZenStackDriver } from './zenstack-driver'; @@ -54,7 +55,7 @@ export class ZenStackQueryExecutor extends DefaultQuer return this.client.$options; } - override async executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) { + override async executeQuery(compiledQuery: CompiledQuery, _queryId: QueryId) { let queryNode = compiledQuery.query; let mutationInterceptionInfo: Awaited>; if (this.isMutationNode(queryNode) && this.hasMutationHooks) { @@ -67,14 +68,14 @@ export class ZenStackQueryExecutor extends DefaultQuer await this.callBeforeMutationHooks(queryNode, mutationInterceptionInfo); } - // TODO: make sure insert and delete return rows + // TODO: make sure insert and update return rows const oldQueryNode = queryNode; if ( - (InsertQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode)) && + (InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode)) && mutationInterceptionInfo?.loadAfterMutationEntity ) { // need to make sure the query node has "returnAll" - // for insert and delete queries + // for insert and update queries queryNode = { ...queryNode, returning: ReturningNode.create([SelectionNode.createSelectAll()]), @@ -84,18 +85,23 @@ export class ZenStackQueryExecutor extends DefaultQuer // proceed with the query with kysely interceptors // if the query is a raw query, we need to carry over the parameters const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined; - const result = await this.proceedQueryWithKyselyInterceptors(queryNode, queryParams, queryId); + const result = await this.proceedQueryWithKyselyInterceptors(queryNode, queryParams); // call after mutation hooks if (this.isMutationNode(queryNode)) { - await this.callAfterQueryInterceptionFilters(result, queryNode, mutationInterceptionInfo); + await this.callAfterMutationHooks( + result.result, + queryNode, + mutationInterceptionInfo, + result.connection, + ); } if (oldQueryNode !== queryNode) { // TODO: trim the result to the original query node } - return result; + return result.result; }; return task(); @@ -104,14 +110,8 @@ export class ZenStackQueryExecutor extends DefaultQuer private proceedQueryWithKyselyInterceptors( queryNode: RootOperationNode, parameters: readonly unknown[] | undefined, - queryId: QueryId, ) { - let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId); - - // TODO: transactional hooks - // const makeTx = (p: typeof proceed) => (callback: OnKyselyQueryTransactionCallback) => { - // return this.executeWithTransaction(() => callback(p)); - // }; + let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters); const hooks: OnKyselyQueryCallback[] = []; // tsc perf @@ -123,23 +123,30 @@ export class ZenStackQueryExecutor extends DefaultQuer for (const hook of hooks) { const _proceed = proceed; - proceed = (query: RootOperationNode) => { - return hook!({ + proceed = async (query: RootOperationNode) => { + let connection: DatabaseConnection | undefined; + const _p = async (q: RootOperationNode) => { + const r = await _proceed(q); + // carry over the database connection returned by the original executor + connection = r.connection; + return r.result; + }; + + const hookResult = await hook!({ client: this.client as ClientContract, schema: this.client.$schema, kysely: this.kysely, query, - proceed: _proceed, - // TODO: transactional hooks - // transaction: makeTx(_proceed), + proceed: _p, }); + return { result: hookResult, connection: connection! }; }; } return proceed(queryNode); } - private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: QueryId) { + private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined) { // run built-in transformers const finalQuery = this.nameMapper.transformNode(query); let compiled = this.compileQuery(finalQuery); @@ -148,14 +155,10 @@ export class ZenStackQueryExecutor extends DefaultQuer } try { - return await super.executeQuery(compiled, queryId); - - // TODO: transaction hooks - // return this.driver.txConnection - // ? await super - // .withConnectionProvider(new SingleConnectionProvider(this.driver.txConnection)) - // .executeQuery(compiled, queryId) - // : await super.executeQuery(compiled, queryId); + return await this.provideConnection(async (connection) => { + const result = await connection.executeQuery(compiled); + return { result, connection }; + }); } catch (err) { let message = `Failed to execute query: ${err}, sql: ${compiled.sql}`; if (this.options.debug) { @@ -201,6 +204,7 @@ export class ZenStackQueryExecutor extends DefaultQuer [plugin, ...this.plugins], ); } + override withoutPlugins() { return new ZenStackQueryExecutor( this.client, @@ -310,10 +314,11 @@ export class ZenStackQueryExecutor extends DefaultQuer } if (this.options.plugins) { + const mutationModel = this.getMutationModel(queryNode); for (const plugin of this.options.plugins) { if (plugin.beforeEntityMutation) { await plugin.beforeEntityMutation({ - model: this.getMutationModel(queryNode), + model: mutationModel, action: mutationInterceptionInfo.action, queryNode, entities: mutationInterceptionInfo.beforeMutationEntities, @@ -323,39 +328,59 @@ export class ZenStackQueryExecutor extends DefaultQuer } } - private async callAfterQueryInterceptionFilters( + private async callAfterMutationHooks( queryResult: QueryResult, queryNode: OperationNode, mutationInterceptionInfo: Awaited>, + connection: DatabaseConnection, ) { if (!mutationInterceptionInfo?.intercept) { return; } - if (this.options.plugins) { - const mutationModel = this.getMutationModel(queryNode); - for (const plugin of this.options.plugins) { - if (plugin.afterEntityMutation) { - let afterMutationEntities: Record[] | undefined = undefined; - if (mutationInterceptionInfo.loadAfterMutationEntity) { - if (UpdateQueryNode.is(queryNode)) { - afterMutationEntities = await this.loadEntities( - mutationModel, - mutationInterceptionInfo.where, - ); - } else { - afterMutationEntities = queryResult.rows as Record[]; - } - } - - await plugin.afterEntityMutation({ - model: this.getMutationModel(queryNode), + const hooks: AfterEntityMutationCallback[] = []; + // tsc perf + for (const plugin of this.options.plugins ?? []) { + if (plugin.afterEntityMutation) { + hooks.push(plugin.afterEntityMutation.bind(plugin)); + } + } + if (hooks.length === 0) { + return; + } + + const mutationModel = this.getMutationModel(queryNode); + const inTransaction = this.driver.isTransactionConnection(connection); + + for (const hook of hooks) { + let afterMutationEntities: Record[] | undefined = undefined; + if (mutationInterceptionInfo.loadAfterMutationEntity) { + if (InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode)) { + afterMutationEntities = queryResult.rows as Record[]; + } + } + + const action = async () => { + try { + await hook({ + model: mutationModel, action: mutationInterceptionInfo.action, queryNode, beforeMutationEntities: mutationInterceptionInfo.beforeMutationEntities, afterMutationEntities, }); + } catch (err) { + console.error(`Error in afterEntityMutation hook for model "${mutationModel}": ${err}`); } + }; + + if (inTransaction) { + // if we're in a transaction, the after mutation hooks should be triggered after the transaction is committed, + // only register a callback here + this.driver.registerTransactionCommitCallback(connection, action); + } else { + // otherwise trigger the hooks immediately + await action(); } } } diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 04320935..3b006862 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -100,6 +100,18 @@ export type OnKyselyQueryCallback = ( args: OnKyselyQueryArgs, ) => Promise>; +export type MutationInterceptionFilter = ( + args: MutationHooksArgs, +) => MaybePromise; + +export type BeforeEntityMutationCallback = ( + args: PluginBeforeEntityMutationArgs, +) => MaybePromise; + +export type AfterEntityMutationCallback = ( + args: PluginAfterEntityMutationArgs, +) => MaybePromise; + /** * ZenStack runtime plugin. */ @@ -133,14 +145,14 @@ export interface RuntimePlugin { * This callback determines whether a mutation should be intercepted, and if so, * what data should be loaded before and after the mutation. */ - mutationInterceptionFilter?: (args: MutationHooksArgs) => MaybePromise; + mutationInterceptionFilter?: MutationInterceptionFilter; /** * Called before an entity is mutated. * @param args.entity Only available if `loadBeforeMutationEntity` is set to true in the * return value of {@link RuntimePlugin.mutationInterceptionFilter}. */ - beforeEntityMutation?: (args: PluginBeforeEntityMutationArgs) => MaybePromise; + beforeEntityMutation?: BeforeEntityMutationCallback; /** * Called after an entity is mutated. @@ -149,7 +161,7 @@ export interface RuntimePlugin { * @param args.afterMutationEntity Only available if `loadAfterMutationEntity` is set to true in the * return value of {@link RuntimePlugin.mutationInterceptionFilter}. */ - afterEntityMutation?: (args: PluginAfterEntityMutationArgs) => MaybePromise; + afterEntityMutation?: AfterEntityMutationCallback; } type OnQueryHooks = { diff --git a/packages/runtime/test/plugin/kysely-on-query.test.ts b/packages/runtime/test/plugin/kysely-on-query.test.ts index 90c47bb8..2e750b6e 100644 --- a/packages/runtime/test/plugin/kysely-on-query.test.ts +++ b/packages/runtime/test/plugin/kysely-on-query.test.ts @@ -112,43 +112,6 @@ describe('Kysely onQuery tests', () => { }); }); - // TODO: revisit transactions - // it('rolls back on error when a transaction is used', async () => { - // const client = _client.$use({ - // id: 'test-plugin', - // async onKyselyQuery({ kysely, proceed, transaction, query }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - - // return transaction(async (txProceed) => { - // const result = await txProceed(query); - - // // create a post for the user - // const now = new Date().toISOString(); - // const createPost = kysely.insertInto('Post').values({ - // id: '1', - // title: 'Post1', - // authorId: 'none-exist', - // updatedAt: now, - // }); - // await txProceed(createPost.toOperationNode()); - - // return result; - // }); - // }, - // }); - - // await expect( - // client.user.create({ - // data: { id: '1', email: 'u1@test.com' }, - // }), - // ).rejects.toThrow('constraint failed'); - - // await expect(client.user.findFirst()).toResolveNull(); - // await expect(client.post.findFirst()).toResolveNull(); - // }); - it('works with multiple interceptors', async () => { let called1 = false; let called2 = false; @@ -205,107 +168,6 @@ describe('Kysely onQuery tests', () => { await expect(called2).toBe(true); }); - // TODO: revisit transactions - // it('works with multiple transactional interceptors - order 1', async () => { - // let called1 = false; - // let called2 = false; - - // const client = _client - // .$use({ - // id: 'test-plugin', - // async onKyselyQuery({ query, proceed }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called1 = true; - // await proceed(query); - // throw new Error('test error'); - // }, - // }) - // .$use({ - // id: 'test-plugin2', - // onKyselyQuery({ query, proceed, transaction }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called2 = true; - // return transaction(async (txProceed) => { - // const valueList = [ - // ...(((query as InsertQueryNode).values as ValuesNode).values[0] as PrimitiveValueListNode) - // .values, - // ]; - // valueList[0] = 'u2@test.com'; - // valueList[1] = 'Marvin1'; - // const newQuery = InsertQueryNode.cloneWith(query as InsertQueryNode, { - // values: ValuesNode.create([PrimitiveValueListNode.create(valueList)]), - // }); - // return txProceed(newQuery); - // }); - // }, - // }); - - // await expect( - // client.user.create({ - // data: { email: 'u1@test.com', name: 'Marvin' }, - // }), - // ).rejects.toThrow('test error'); - - // await expect(called1).toBe(true); - // await expect(called2).toBe(true); - // await expect(client.user.findFirst()).toResolveNull(); - // }); - - // TODO: revisit transactions - // it('works with multiple transactional interceptors - order 2', async () => { - // let called1 = false; - // let called2 = false; - - // const client = _client - // .$use({ - // id: 'test-plugin', - // async onKyselyQuery({ query, proceed, transaction }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called1 = true; - - // return transaction(async (txProceed) => { - // await txProceed(query); - // throw new Error('test error'); - // }); - // }, - // }) - // .$use({ - // id: 'test-plugin2', - // onKyselyQuery({ query, proceed }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called2 = true; - // const valueList = [ - // ...(((query as InsertQueryNode).values as ValuesNode).values[0] as PrimitiveValueListNode) - // .values, - // ]; - // valueList[0] = 'u2@test.com'; - // valueList[1] = 'Marvin1'; - // const newQuery = InsertQueryNode.cloneWith(query as InsertQueryNode, { - // values: ValuesNode.create([PrimitiveValueListNode.create(valueList)]), - // }); - // return proceed(newQuery); - // }, - // }); - - // await expect( - // client.user.create({ - // data: { email: 'u1@test.com', name: 'Marvin' }, - // }), - // ).rejects.toThrow('test error'); - - // await expect(called1).toBe(true); - // await expect(called2).toBe(true); - // await expect(client.user.findFirst()).toResolveNull(); - // }); - it('works with multiple interceptors with outer transaction', async () => { let called1 = false; let called2 = false; @@ -354,130 +216,6 @@ describe('Kysely onQuery tests', () => { await expect(called2).toBe(true); await expect(client.user.findFirst()).toResolveNull(); }); - - // TODO: revisit transactions - // it('works with nested transactional interceptors success', async () => { - // let called1 = false; - // let called2 = false; - - // const client = _client - // .$use({ - // id: 'test-plugin', - // onKyselyQuery({ query, proceed, transaction }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called1 = true; - // return transaction(async (txProceed) => { - // const valueList = [ - // ...(((query as InsertQueryNode).values as ValuesNode).values[0] as PrimitiveValueListNode) - // .values, - // ]; - // valueList[1] = 'Marvin2'; - // const newQuery = InsertQueryNode.cloneWith(query as InsertQueryNode, { - // values: ValuesNode.create([PrimitiveValueListNode.create(valueList)]), - // }); - // return txProceed(newQuery); - // }); - // }, - // }) - // .$use({ - // id: 'test-plugin2', - // onKyselyQuery({ query, proceed, transaction }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called2 = true; - // return transaction(async (txProceed) => { - // const valueList = [ - // ...(((query as InsertQueryNode).values as ValuesNode).values[0] as PrimitiveValueListNode) - // .values, - // ]; - // valueList[0] = 'u2@test.com'; - // valueList[1] = 'Marvin1'; - // const newQuery = InsertQueryNode.cloneWith(query as InsertQueryNode, { - // values: ValuesNode.create([PrimitiveValueListNode.create(valueList)]), - // }); - // return txProceed(newQuery); - // }); - // }, - // }); - - // await expect( - // client.user.create({ - // data: { email: 'u1@test.com', name: 'Marvin' }, - // }), - // ).resolves.toMatchObject({ - // email: 'u2@test.com', - // name: 'Marvin2', - // }); - // await expect(called1).toBe(true); - // await expect(called2).toBe(true); - // }); - - // TODO: revisit transactions - // it('works with nested transactional interceptors roll back', async () => { - // let called1 = false; - // let called2 = false; - - // const client = _client - // .$use({ - // id: 'test-plugin', - // onKyselyQuery({ kysely, query, proceed, transaction }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called1 = true; - // return transaction(async (txProceed) => { - // const valueList = [ - // ...(((query as InsertQueryNode).values as ValuesNode).values[0] as PrimitiveValueListNode) - // .values, - // ]; - // valueList[1] = 'Marvin2'; - // const newQuery = InsertQueryNode.cloneWith(query as InsertQueryNode, { - // values: ValuesNode.create([PrimitiveValueListNode.create(valueList)]), - // }); - // const result = await txProceed(newQuery); - - // // create a post for the user - // await txProceed(createPost(kysely, result)); - - // throw new Error('test error'); - // }); - // }, - // }) - // .$use({ - // id: 'test-plugin2', - // onKyselyQuery({ query, proceed, transaction }) { - // if (query.kind !== 'InsertQueryNode') { - // return proceed(query); - // } - // called2 = true; - // return transaction(async (txProceed) => { - // const valueList = [ - // ...(((query as InsertQueryNode).values as ValuesNode).values[0] as PrimitiveValueListNode) - // .values, - // ]; - // valueList[0] = 'u2@test.com'; - // valueList[1] = 'Marvin1'; - // const newQuery = InsertQueryNode.cloneWith(query as InsertQueryNode, { - // values: ValuesNode.create([PrimitiveValueListNode.create(valueList)]), - // }); - // return txProceed(newQuery); - // }); - // }, - // }); - - // await expect( - // client.user.create({ - // data: { email: 'u1@test.com', name: 'Marvin' }, - // }), - // ).rejects.toThrow('test error'); - // await expect(called1).toBe(true); - // await expect(called2).toBe(true); - // await expect(client.user.findFirst()).toResolveNull(); - // await expect(client.post.findFirst()).toResolveNull(); - // }); }); function createPost(kysely: Kysely, userRows: QueryResult) { diff --git a/packages/runtime/test/plugin/mutation-hooks.test.ts b/packages/runtime/test/plugin/mutation-hooks.test.ts index 08023e28..557fd728 100644 --- a/packages/runtime/test/plugin/mutation-hooks.test.ts +++ b/packages/runtime/test/plugin/mutation-hooks.test.ts @@ -300,57 +300,97 @@ describe('Entity lifecycle tests', () => { expect(post2Intercepted).toBe(true); }); - // // TODO: revisit mutation hooks and transactions - // it.skip('proceeds with mutation even when hooks throw', async () => { - // let userIntercepted = false; - - // const client = _client.$use({ - // id: 'test', - // afterEntityMutation() { - // userIntercepted = true; - // throw new Error('trigger error'); - // }, - // }); - - // let gotError = false; - // try { - // await client.user.create({ - // data: { email: 'u1@test.com' }, - // }); - // } catch (err) { - // gotError = true; - // expect((err as Error).message).toContain('trigger error'); - // } - - // expect(userIntercepted).toBe(true); - // expect(gotError).toBe(true); - // console.log(await client.user.findMany()); - // await expect(client.user.findMany()).toResolveWithLength(1); - // }); - - it('rolls back when hooks throw if transaction is used', async () => { - let userIntercepted = false; + it('does not affect the database operation if an afterEntityMutation hook throws', async () => { + let intercepted = false; const client = _client.$use({ id: 'test', afterEntityMutation() { - userIntercepted = true; + intercepted = true; throw new Error('trigger rollback'); }, }); - let gotError = false; + await client.user.create({ + data: { email: 'u1@test.com' }, + }); + + expect(intercepted).toBe(true); + await expect(client.user.findMany()).toResolveWithLength(1); + }); + + it('does not trigger afterEntityMutation hook if a transaction is rolled back', async () => { + let intercepted = false; + + const client = _client.$use({ + id: 'test', + afterEntityMutation() { + intercepted = true; + }, + }); + try { - await client.user.create({ - data: { email: 'u1@test.com' }, + await client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, + }); + throw new Error('trigger rollback'); }); - } catch (err) { - gotError = true; - expect((err as Error).message).toContain('trigger rollback'); + } catch { + // noop } - expect(userIntercepted).toBe(true); - expect(gotError).toBe(true); await expect(client.user.findMany()).toResolveWithLength(0); + expect(intercepted).toBe(false); + }); + + it('triggers multiple afterEntityMutation hooks for multiple mutations', async () => { + const triggered: any[] = []; + + const client = _client.$use({ + id: 'test', + mutationInterceptionFilter: () => { + return { + intercept: true, + loadBeforeMutationEntity: true, + loadAfterMutationEntity: true, + }; + }, + afterEntityMutation(args) { + triggered.push(args); + }, + }); + + await client.$transaction(async (tx) => { + await tx.user.create({ + data: { email: 'u1@test.com' }, + }); + await tx.user.update({ + where: { email: 'u1@test.com' }, + data: { email: 'u2@test.com' }, + }); + await tx.user.delete({ where: { email: 'u2@test.com' } }); + }); + + expect(triggered).toEqual([ + expect.objectContaining({ + action: 'create', + model: 'User', + beforeMutationEntities: undefined, + afterMutationEntities: [expect.objectContaining({ email: 'u1@test.com' })], + }), + expect.objectContaining({ + action: 'update', + model: 'User', + beforeMutationEntities: [expect.objectContaining({ email: 'u1@test.com' })], + afterMutationEntities: [expect.objectContaining({ email: 'u2@test.com' })], + }), + expect.objectContaining({ + action: 'delete', + model: 'User', + beforeMutationEntities: [expect.objectContaining({ email: 'u2@test.com' })], + afterMutationEntities: undefined, + }), + ]); }); }); diff --git a/packages/runtime/test/plugin/query-lifecycle.test.ts b/packages/runtime/test/plugin/query-lifecycle.test.ts index 70658c55..106b55af 100644 --- a/packages/runtime/test/plugin/query-lifecycle.test.ts +++ b/packages/runtime/test/plugin/query-lifecycle.test.ts @@ -254,40 +254,6 @@ describe('Query interception tests', () => { ).toResolveTruthy(); }); - // TODO: revisit transactional hooks - it.skip('rolls back the effect with transaction', async () => { - let hooksCalled = false; - const client = _client.$use({ - id: 'test-plugin', - onQuery: { - user: { - create: async (ctx) => { - hooksCalled = true; - return ctx.client.$transaction(async (_tx) => { - await ctx.query(ctx.args /*, tx*/); - throw new Error('trigger error'); - }); - }, - }, - }, - }); - - try { - await client.user.create({ - data: { id: '1', email: 'u1@test.com' }, - }); - } catch { - // no-op - } - - expect(hooksCalled).toBe(true); - await expect( - _client.user.findFirst({ - where: { id: '1' }, - }), - ).toResolveNull(); - }); - it('supports plugin encapsulation', async () => { const user = await _client.user.create({ data: { email: 'u1@test.com' },