Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,23 +255,33 @@ export class ClientImpl<Schema extends SchemaDef> {
}

$use(plugin: RuntimePlugin<Schema>) {
const newOptions = {
// tsc perf
const newPlugins: RuntimePlugin<Schema>[] = [...(this.$options.plugins ?? []), plugin];
const newOptions: ClientOptions<Schema> = {
...this.options,
plugins: [...(this.options.plugins ?? []), plugin],
plugins: newPlugins,
};
return new ClientImpl<Schema>(this.schema, newOptions, this);
}

$unuse(pluginId: string) {
const newOptions = {
// tsc perf
const newPlugins: RuntimePlugin<Schema>[] = [];
for (const plugin of this.options.plugins ?? []) {
if (plugin.id !== pluginId) {
newPlugins.push(plugin);
}
}
const newOptions: ClientOptions<Schema> = {
...this.options,
plugins: this.options.plugins?.filter((p) => p.id !== pluginId),
plugins: newPlugins,
};
return new ClientImpl<Schema>(this.schema, newOptions, this);
}

$unuseAll() {
const newOptions = {
// tsc perf
const newOptions: ClientOptions<Schema> = {
...this.options,
plugins: [] as RuntimePlugin<Schema>[],
};
Expand Down Expand Up @@ -388,7 +398,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
for (const plugin of plugins) {
if (plugin.onQuery && typeof plugin.onQuery === 'object') {
// for each model key or "$allModels"
for (const [_model, modelHooks] of Object.entries(plugin.onQuery)) {
for (const [_model, modelHooks] of Object.entries<any>(plugin.onQuery)) {
if (_model === lowerCaseFirst(model) || _model === '$allModels') {
if (modelHooks && typeof modelHooks === 'object') {
// for each operation key or "$allOperations"
Expand Down
13 changes: 8 additions & 5 deletions packages/runtime/src/client/executor/zenstack-query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import type { GetModels, SchemaDef } from '../../schema';
import { type ClientImpl } from '../client-impl';
import type { ClientContract } from '../contract';
import { InternalError, QueryError } from '../errors';
import type { MutationInterceptionFilterResult } from '../plugin';
import type { MutationInterceptionFilterResult, OnKyselyQueryCallback } from '../plugin';
import { QueryNameMapper } from './name-mapper';
import type { ZenStackDriver } from './zenstack-driver';

Expand Down Expand Up @@ -113,10 +113,13 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
// return this.executeWithTransaction(() => callback(p));
// };

const hooks =
this.options.plugins
?.filter((plugin) => typeof plugin.onKyselyQuery === 'function')
.map((plugin) => plugin.onKyselyQuery!.bind(plugin)) ?? [];
const hooks: OnKyselyQueryCallback<Schema>[] = [];
// tsc perf
for (const plugin of this.client.$options.plugins ?? []) {
if (plugin.onKyselyQuery) {
hooks.push(plugin.onKyselyQuery.bind(plugin));
}
}

for (const hook of hooks) {
const _proceed = proceed;
Expand Down
6 changes: 5 additions & 1 deletion packages/runtime/src/client/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ export type OnKyselyQueryArgs<Schema extends SchemaDef> = {

export type ProceedKyselyQueryFunction = (query: RootOperationNode) => Promise<QueryResult<any>>;

export type OnKyselyQueryCallback<Schema extends SchemaDef> = (
args: OnKyselyQueryArgs<Schema>,
) => Promise<QueryResult<UnknownRow>>;

/**
* ZenStack runtime plugin.
*/
Expand Down Expand Up @@ -123,7 +127,7 @@ export interface RuntimePlugin<Schema extends SchemaDef = SchemaDef> {
/**
* Intercepts a Kysely query.
*/
onKyselyQuery?: (args: OnKyselyQueryArgs<Schema>) => Promise<QueryResult<UnknownRow>>;
onKyselyQuery?: OnKyselyQueryCallback<Schema>;

/**
* This callback determines whether a mutation should be intercepted, and if so,
Expand Down
16 changes: 8 additions & 8 deletions packages/runtime/test/client-api/client-specs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ export function createClientSpecs(dbName: string, logQueries = false, providers:
{
provider: 'sqlite' as const,
schema: getSchema('sqlite'),
createClient: async () => {
const client = await makeSqliteClient(getSchema('sqlite'), {
createClient: async (): Promise<ClientContract<typeof schema>> => {
// tsc perf
return makeSqliteClient<any>(getSchema('sqlite'), {
log: logQueries ? logger('sqlite') : undefined,
});
return client as ClientContract<typeof schema>;
}) as unknown as ClientContract<typeof schema>;
},
},
]
Expand All @@ -29,11 +29,11 @@ export function createClientSpecs(dbName: string, logQueries = false, providers:
{
provider: 'postgresql' as const,
schema: getSchema('postgresql'),
createClient: async () => {
const client = await makePostgresClient(getSchema('postgresql'), dbName, {
createClient: async (): Promise<ClientContract<typeof schema>> => {
// tsc perf
return makePostgresClient<any>(getSchema('postgresql'), dbName, {
log: logQueries ? logger('postgresql') : undefined,
});
return client as unknown as ClientContract<typeof schema>;
}) as unknown as ClientContract<typeof schema>;
},
},
]
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/test/policy/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ export async function createPolicyTestClient<Schema extends SchemaDef>(
{
...options,
plugins: [new PolicyPlugin()],
} as CreateTestClientOptions<SchemaDef>,
} as any,
);
}
4 changes: 2 additions & 2 deletions packages/runtime/test/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type PostgresSchema = SchemaDef & { provider: { type: 'postgresql' } };
export async function makeSqliteClient<Schema extends SqliteSchema>(
schema: Schema,
extraOptions?: Partial<ClientOptions<Schema>>,
) {
): Promise<ClientContract<Schema>> {
const client = new ZenStackClient(schema, {
...extraOptions,
dialectConfig: { database: new SQLite(':memory:') },
Expand All @@ -37,7 +37,7 @@ export async function makePostgresClient<Schema extends PostgresSchema>(
schema: Schema,
dbName: string,
extraOptions?: Partial<ClientOptions<Schema>>,
) {
): Promise<ClientContract<Schema>> {
invariant(dbName, 'dbName is required');
const pgClient = new PGClient(TEST_PG_CONFIG);
await pgClient.connect();
Expand Down