Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- [ ] format
- [ ] db seed
- [ ] ZModel
- [ ] Import
- [ ] View support
- [ ] ORM
- [x] Create
Expand Down Expand Up @@ -80,8 +81,8 @@
- [ ] Strict undefined checks
- [ ] DbNull vs JsonNull
- [ ] Benchmark
- [ ] Plugin
- [ ] Post-mutation hooks should be called after transaction is committed
- [x] Plugin
- [x] Post-mutation hooks should be called after transaction is committed
- [x] TypeDef and mixin
- [ ] Strongly typed JSON
- [x] Polymorphism
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/client/crud/operations/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
}
}

const result = await query.executeTakeFirstOrThrow();
const result = await this.executeQuery(this.kysely, query, 'aggregate');
const ret: any = {};

// postprocess result to convert flat fields into nested objects
for (const [key, value] of Object.entries(result as object)) {
for (const [key, value] of Object.entries(result.rows[0] as object)) {
if (key === '_count') {
ret[key] = value;
continue;
Expand Down
73 changes: 50 additions & 23 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import {
ExpressionWrapper,
sql,
UpdateResult,
type Compilable,
type IsolationLevel,
type Expression as KyselyExpression,
type QueryResult,
type SelectQueryBuilder,
} from 'kysely';
import { nanoid } from 'nanoid';
Expand Down Expand Up @@ -125,7 +127,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return getField(this.schema, model, field);
}

protected exists(kysely: ToKysely<Schema>, model: GetModels<Schema>, filter: any): Promise<unknown | undefined> {
protected async exists(
kysely: ToKysely<Schema>,
model: GetModels<Schema>,
filter: any,
): Promise<unknown | undefined> {
const idFields = getIdFields(this.schema, model);
const _filter = flattenCompoundUniqueFilters(this.schema, model, filter);
const query = kysely
Expand All @@ -134,7 +140,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
.select(idFields.map((f) => kysely.dynamic.ref(f)))
.limit(1)
.modifyEnd(this.makeContextComment({ model, operation: 'read' }));
return query.executeTakeFirst();
return this.executeQueryTakeFirst(kysely, query, 'exists');
}

protected async read(
Expand Down Expand Up @@ -444,7 +450,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
return query.execute();
return this.executeQuery(kysely, query, 'update');
};
}
}
Expand Down Expand Up @@ -511,10 +517,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}),
);

const createdEntity = await query.executeTakeFirst();
const createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create');

// try {
// createdEntity = await query.executeTakeFirst();
// createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create');
// } catch (err) {
// const { sql, parameters } = query.compile();
// throw new QueryError(
Expand Down Expand Up @@ -893,8 +899,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
);

if (!returnData) {
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numInsertedOrUpdatedRows) } as Result;
const result = await this.executeQuery(kysely, query, 'createMany');
return { count: Number(result.numAffectedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
Expand Down Expand Up @@ -1160,10 +1166,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}),
);

const updatedEntity = await query.executeTakeFirst();
const updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update');

// try {
// updatedEntity = await query.executeTakeFirst();
// updatedEntity = await this.executeQueryTakeFirst(kysely, query, 'update');
// } catch (err) {
// const { sql, parameters } = query.compile();
// throw new QueryError(
Expand Down Expand Up @@ -1401,8 +1407,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
query = query.modifyEnd(this.makeContextComment({ model, operation: 'update' }));

if (!returnData) {
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numUpdatedRows) } as Result;
const result = await this.executeQuery(kysely, query, 'update');
return { count: Number(result.numAffectedRows) } as Result;
} else {
const idFields = getIdFields(this.schema, model);
const result = await query.returning(idFields as any).execute();
Expand Down Expand Up @@ -1636,7 +1642,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
fromRelation.model,
fromRelation.field,
);
let updateResult: UpdateResult;
let updateResult: QueryResult<unknown>;

if (ownedByModel) {
// set parent fk directly
Expand Down Expand Up @@ -1665,7 +1671,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
updateResult = await query.executeTakeFirstOrThrow();
updateResult = await this.executeQuery(kysely, query, 'connect');
} else {
// disconnect current if it's a one-one relation
const relationFieldDef = this.requireField(fromRelation.model, fromRelation.field);
Expand All @@ -1681,7 +1687,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
await query.execute();
await this.executeQuery(kysely, query, 'disconnect');
}

// connect
Expand All @@ -1703,11 +1709,11 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
updateResult = await query.executeTakeFirstOrThrow();
updateResult = await this.executeQuery(kysely, query, 'connect');
}

// validate connect result
if (_data.length > updateResult.numUpdatedRows) {
if (_data.length > updateResult.numAffectedRows!) {
// some entities were not connected
throw new NotFoundError(model);
}
Expand Down Expand Up @@ -1821,7 +1827,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
await query.executeTakeFirstOrThrow();
await this.executeQuery(kysely, query, 'disconnect');
} else {
// disconnect
const query = kysely
Expand All @@ -1841,7 +1847,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
await query.executeTakeFirstOrThrow();
await this.executeQuery(kysely, query, 'disconnect');
}
}
}
Expand Down Expand Up @@ -1920,7 +1926,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
await query.execute();
await this.executeQuery(kysely, query, 'disconnect');

// connect
if (_data.length > 0) {
Expand All @@ -1942,10 +1948,10 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
operation: 'update',
}),
);
const r = await query.executeTakeFirstOrThrow();
const r = await this.executeQuery(kysely, query, 'connect');

// validate result
if (_data.length > r.numUpdatedRows!) {
if (_data.length > r.numAffectedRows!) {
// some entities were not connected
throw new NotFoundError(model);
}
Expand Down Expand Up @@ -2109,8 +2115,8 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
await this.processDelegateRelationDelete(kysely, modelDef, where, limit);

query = query.modifyEnd(this.makeContextComment({ model, operation: 'delete' }));
const result = await query.executeTakeFirstOrThrow();
return { count: Number(result.numDeletedRows) };
const result = await this.executeQuery(kysely, query, 'delete');
return { count: Number(result.numAffectedRows) };
}

private async processDelegateRelationDelete(
Expand Down Expand Up @@ -2240,4 +2246,25 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}
}
}

protected makeQueryId(operation: string) {
return { queryId: `${operation}-${createId()}` };
}

protected executeQuery(kysely: ToKysely<Schema>, query: Compilable, operation: string) {
return kysely.executeQuery(query.compile(), this.makeQueryId(operation));
}

protected async executeQueryTakeFirst(kysely: ToKysely<Schema>, query: Compilable, operation: string) {
const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation));
return result.rows[0];
}

protected async executeQueryTakeFirstOrThrow(kysely: ToKysely<Schema>, query: Compilable, operation: string) {
const result = await kysely.executeQuery(query.compile(), this.makeQueryId(operation));
if (result.rows.length === 0) {
throw new QueryError('No rows found');
}
return result.rows[0];
}
}
8 changes: 4 additions & 4 deletions packages/runtime/src/client/crud/operations/count.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ export class CountOperationHandler<Schema extends SchemaDef> extends BaseOperati
: eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key),
),
);

return query.executeTakeFirstOrThrow();
const result = await this.executeQuery(this.kysely, query, 'count');
return result.rows[0];
} else {
// simple count all
query = query.select((eb) => eb.cast(eb.fn.countAll(), 'integer').as('count'));
const result = await query.executeTakeFirstOrThrow();
return (result as any).count as number;
const result = await this.executeQuery(this.kysely, query, 'count');
return (result.rows[0] as any).count as number;
}
}
}
2 changes: 1 addition & 1 deletion packages/runtime/src/client/crud/operations/delete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export class DeleteOperationHandler<Schema extends SchemaDef> extends BaseOperat

// TODO: avoid using transaction for simple delete
await this.safeTransaction(async (tx) => {
const result = await this.delete(tx, this.model, args.where, undefined);
const result = await this.delete(tx, this.model, args.where);
if (result.count === 0) {
throw new NotFoundError(this.model);
}
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/client/crud/operations/group-by.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
}
}

const result = await query.execute();
return result.map((row) => this.postProcessRow(row));
const result = await this.executeQuery(this.kysely, query, 'groupBy');
return result.rows.map((row) => this.postProcessRow(row));
}

private postProcessRow(row: any) {
Expand Down
44 changes: 35 additions & 9 deletions packages/runtime/src/client/executor/zenstack-driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import type { CompiledQuery, DatabaseConnection, Driver, Log, QueryResult, Trans
export class ZenStackDriver implements Driver {
readonly #driver: Driver;
readonly #log: Log;
txConnection: DatabaseConnection | undefined;

#initPromise?: Promise<void>;
#initDone: boolean;
#destroyPromise?: Promise<void>;
#connections = new WeakSet<DatabaseConnection>();
#txConnections = new WeakMap<DatabaseConnection, Array<() => Promise<unknown>>>();

constructor(driver: Driver, log: Log) {
this.#initDone = false;
Expand Down Expand Up @@ -67,23 +67,33 @@ export class ZenStackDriver implements Driver {

async beginTransaction(connection: DatabaseConnection, settings: TransactionSettings): Promise<void> {
const result = await this.#driver.beginTransaction(connection, settings);
this.txConnection = connection;
this.#txConnections.set(connection, []);
return result;
}

commitTransaction(connection: DatabaseConnection): Promise<void> {
async commitTransaction(connection: DatabaseConnection): Promise<void> {
try {
return this.#driver.commitTransaction(connection);
} finally {
this.txConnection = undefined;
const result = await this.#driver.commitTransaction(connection);
const callbacks = this.#txConnections.get(connection);
// delete from the map immediately to avoid accidental re-triggering
this.#txConnections.delete(connection);
if (callbacks) {
for (const callback of callbacks) {
await callback();
}
}
return result;
} catch (err) {
this.#txConnections.delete(connection);
throw err;
}
}

rollbackTransaction(connection: DatabaseConnection): Promise<void> {
async rollbackTransaction(connection: DatabaseConnection): Promise<void> {
try {
return this.#driver.rollbackTransaction(connection);
return await this.#driver.rollbackTransaction(connection);
} finally {
this.txConnection = undefined;
this.#txConnections.delete(connection);
}
}

Expand Down Expand Up @@ -175,6 +185,22 @@ export class ZenStackDriver implements Driver {
#calculateDurationMillis(startTime: number): number {
return performanceNow() - startTime;
}

isTransactionConnection(connection: DatabaseConnection): boolean {
return this.#txConnections.has(connection);
}

registerTransactionCommitCallback(connection: DatabaseConnection, callback: () => Promise<unknown>): void {
if (!this.#txConnections.has(connection)) {
return;
}
const callbacks = this.#txConnections.get(connection);
if (callbacks) {
callbacks.push(callback);
} else {
this.#txConnections.set(connection, [callback]);
}
}
}

export function performanceNow() {
Expand Down
Loading