From 429438a014eaae4fbb3effb1733219a1a3efe6bb Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 28 Jul 2025 16:04:31 +0800 Subject: [PATCH 1/2] feat: count and aggregate for delegate models --- .../runtime/src/client/crud/dialects/base.ts | 13 +- .../src/client/crud/operations/aggregate.ts | 29 +++- .../src/client/crud/operations/count.ts | 25 +++- .../src/client/executor/name-mapper.ts | 2 +- .../runtime/test/client-api/delegate.test.ts | 141 ++++++++++++++++++ 5 files changed, 194 insertions(+), 16 deletions(-) diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index c253b905..a8e5ee39 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -818,9 +818,13 @@ export abstract class BaseCrudDialect { return result; } - buildSelectField(query: SelectQueryBuilder, model: string, modelAlias: string, field: string) { + buildSelectField( + query: SelectQueryBuilder, + model: string, + modelAlias: string, + field: string, + ): SelectQueryBuilder { const fieldDef = requireField(this.schema, model, field); - if (fieldDef.computed) { // TODO: computed field from delegate base? return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field)); @@ -828,10 +832,7 @@ export abstract class BaseCrudDialect { // regular field return query.select(sql.ref(`${modelAlias}.${field}`).as(field)); } else { - // field from delegate base, build a join - let result = query; - result = this.buildSelectField(result, fieldDef.originModel, fieldDef.originModel, field); - return result; + return this.buildSelectField(query, fieldDef.originModel, fieldDef.originModel, field); } } diff --git a/packages/runtime/src/client/crud/operations/aggregate.ts b/packages/runtime/src/client/crud/operations/aggregate.ts index 2bcd2014..de061545 100644 --- a/packages/runtime/src/client/crud/operations/aggregate.ts +++ b/packages/runtime/src/client/crud/operations/aggregate.ts @@ -1,3 +1,4 @@ +import type { ExpressionBuilder } from 'kysely'; import { sql } from 'kysely'; import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; @@ -15,12 +16,32 @@ export class AggregateOperationHandler extends BaseOpe let query = this.kysely.selectFrom((eb) => { // nested query for filtering and pagination - // where - let subQuery = eb - .selectFrom(this.model) - .selectAll(this.model as any) // TODO: check typing + // table and where + let subQuery = this.dialect + .buildSelectModel(eb as ExpressionBuilder, this.model) .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); + // select fields: collect fields from aggregation body + const selectedFields: string[] = []; + for (const [key, value] of Object.entries(parsedArgs)) { + if (key.startsWith('_') && value && typeof value === 'object') { + // select fields + Object.entries(value) + .filter(([, val]) => val === true) + .forEach(([field]) => { + if (!selectedFields.includes(field)) selectedFields.push(field); + }); + } + } + if (selectedFields.length > 0) { + for (const field of selectedFields) { + subQuery = this.dialect.buildSelectField(subQuery, this.model, this.model, field); + } + } else { + // if no field is explicitly selected, just do a `select 1` so `_count` works + subQuery = subQuery.select(() => eb.lit(1).as('_all')); + } + // skip & take const skip = parsedArgs?.skip; let take = parsedArgs?.take; diff --git a/packages/runtime/src/client/crud/operations/count.ts b/packages/runtime/src/client/crud/operations/count.ts index e44a5897..fc22c2ec 100644 --- a/packages/runtime/src/client/crud/operations/count.ts +++ b/packages/runtime/src/client/crud/operations/count.ts @@ -1,3 +1,4 @@ +import type { ExpressionBuilder } from 'kysely'; import { sql } from 'kysely'; import type { SchemaDef } from '../../../schema'; import { BaseOperationHandler } from './base'; @@ -9,15 +10,29 @@ export class CountOperationHandler extends BaseOperati // parse args const parsedArgs = this.inputValidator.validateCountArgs(this.model, normalizedArgs); + const subQueryName = '$sub'; let query = this.kysely.selectFrom((eb) => { // nested query for filtering and pagination - let subQuery = eb - .selectFrom(this.model) - .selectAll() + + let subQuery = this.dialect + .buildSelectModel(eb as ExpressionBuilder, this.model) .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); + + if (parsedArgs?.select && typeof parsedArgs.select === 'object') { + // select fields + for (const [key, value] of Object.entries(parsedArgs.select)) { + if (key !== '_all' && value === true) { + subQuery = this.dialect.buildSelectField(subQuery, this.model, this.model, key); + } + } + } else { + // no field selection, just build a `select 1` + subQuery = subQuery.select(() => eb.lit(1).as('_all')); + } + subQuery = this.dialect.buildSkipTake(subQuery, parsedArgs?.skip, parsedArgs?.take); - return subQuery.as('$sub'); + return subQuery.as(subQueryName); }); if (parsedArgs?.select && typeof parsedArgs.select === 'object') { @@ -26,7 +41,7 @@ export class CountOperationHandler extends BaseOperati Object.keys(parsedArgs.select!).map((key) => key === '_all' ? eb.cast(eb.fn.countAll(), 'integer').as('_all') - : eb.cast(eb.fn.count(sql.ref(`$sub.${key}`)), 'integer').as(key), + : eb.cast(eb.fn.count(sql.ref(`${subQueryName}.${key}`)), 'integer').as(key), ), ); diff --git a/packages/runtime/src/client/executor/name-mapper.ts b/packages/runtime/src/client/executor/name-mapper.ts index d5a893d5..814c1ba7 100644 --- a/packages/runtime/src/client/executor/name-mapper.ts +++ b/packages/runtime/src/client/executor/name-mapper.ts @@ -263,7 +263,7 @@ export class QueryNameMapper extends OperationNodeTransformer { model = model ?? this.currentModel; const modelDef = requireModel(this.schema, model!); const scalarFields = Object.entries(modelDef.fields) - .filter(([, fieldDef]) => !fieldDef.relation && !fieldDef.computed) + .filter(([, fieldDef]) => !fieldDef.relation && !fieldDef.computed && !fieldDef.originModel) .map(([fieldName]) => fieldName); return scalarFields; } diff --git a/packages/runtime/test/client-api/delegate.test.ts b/packages/runtime/test/client-api/delegate.test.ts index 791cbf45..79f92163 100644 --- a/packages/runtime/test/client-api/delegate.test.ts +++ b/packages/runtime/test/client-api/delegate.test.ts @@ -1070,5 +1070,146 @@ model Gallery { await expect(client.asset.findMany()).toResolveWithLength(1); }); }); + + describe('Delegate aggregation tests', () => { + beforeEach(async () => { + const u = await client.user.create({ + data: { + id: 1, + email: 'u1@example.com', + }, + }); + await client.ratedVideo.create({ + data: { + id: 1, + viewCount: 0, + duration: 100, + url: 'v1', + rating: 5, + owner: { connect: { id: u.id } }, + user: { connect: { id: u.id } }, + comments: { create: [{ content: 'c1' }, { content: 'c2' }] }, + }, + }); + await client.ratedVideo.create({ + data: { + id: 2, + viewCount: 2, + duration: 200, + url: 'v2', + rating: 3, + }, + }); + }); + + it('works with count', async () => { + await expect( + client.ratedVideo.count({ + where: { rating: 5 }, + }), + ).resolves.toEqual(1); + await expect( + client.ratedVideo.count({ + where: { duration: 100 }, + }), + ).resolves.toEqual(1); + await expect( + client.ratedVideo.count({ + where: { viewCount: 1 }, + }), + ).resolves.toEqual(1); + + await expect( + client.video.count({ + where: { duration: 100 }, + }), + ).resolves.toEqual(1); + await expect( + client.asset.count({ + where: { viewCount: { gt: 0 } }, + }), + ).resolves.toEqual(1); + + // field selection + await expect( + client.ratedVideo.count({ + select: { _all: true, viewCount: true, url: true, rating: true }, + }), + ).resolves.toMatchObject({ + _all: 2, + viewCount: 2, + url: 2, + rating: 2, + }); + await expect( + client.video.count({ + select: { _all: true, viewCount: true, url: true }, + }), + ).resolves.toMatchObject({ + _all: 2, + viewCount: 2, + url: 2, + }); + await expect( + client.asset.count({ + select: { _all: true, viewCount: true }, + }), + ).resolves.toMatchObject({ + _all: 2, + viewCount: 2, + }); + }); + + it('works with aggregate', async () => { + await expect( + client.ratedVideo.aggregate({ + where: { viewCount: { gte: 0 }, duration: { gt: 0 }, rating: { gt: 0 } }, + _avg: { viewCount: true, duration: true, rating: true }, + _count: true, + }), + ).resolves.toMatchObject({ + _avg: { + viewCount: 1, + duration: 150, + rating: 4, + }, + _count: 2, + }); + await expect( + client.video.aggregate({ + where: { viewCount: { gte: 0 }, duration: { gt: 0 } }, + _avg: { viewCount: true, duration: true }, + _count: true, + }), + ).resolves.toMatchObject({ + _avg: { + viewCount: 1, + duration: 150, + }, + _count: 2, + }); + await expect( + client.asset.aggregate({ + where: { viewCount: { gte: 0 } }, + _avg: { viewCount: true }, + _count: true, + }), + ).resolves.toMatchObject({ + _avg: { + viewCount: 1, + }, + _count: 2, + }); + + // just count + await expect( + client.ratedVideo.aggregate({ + _count: true, + }), + ).resolves.toMatchObject({ + _count: 2, + }); + }); + }); }, ); From 7707eed32ccd1d04fd8a41780ee905b3b14347fa Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 28 Jul 2025 16:27:46 +0800 Subject: [PATCH 2/2] fixes --- packages/runtime/src/client/crud/operations/aggregate.ts | 1 + packages/runtime/test/client-api/delegate.test.ts | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/runtime/src/client/crud/operations/aggregate.ts b/packages/runtime/src/client/crud/operations/aggregate.ts index de061545..5d309dda 100644 --- a/packages/runtime/src/client/crud/operations/aggregate.ts +++ b/packages/runtime/src/client/crud/operations/aggregate.ts @@ -27,6 +27,7 @@ export class AggregateOperationHandler extends BaseOpe if (key.startsWith('_') && value && typeof value === 'object') { // select fields Object.entries(value) + .filter(([field]) => field !== '_all') .filter(([, val]) => val === true) .forEach(([field]) => { if (!selectedFields.includes(field)) selectedFields.push(field); diff --git a/packages/runtime/test/client-api/delegate.test.ts b/packages/runtime/test/client-api/delegate.test.ts index 79f92163..51749af2 100644 --- a/packages/runtime/test/client-api/delegate.test.ts +++ b/packages/runtime/test/client-api/delegate.test.ts @@ -1115,7 +1115,7 @@ model Gallery { ).resolves.toEqual(1); await expect( client.ratedVideo.count({ - where: { viewCount: 1 }, + where: { viewCount: 2 }, }), ).resolves.toEqual(1);