diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index abe011f4..20fd4be9 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -393,7 +393,12 @@ export type SelectInput< [Key in NonRelationFields]?: true; } & (AllowRelation extends true ? IncludeInput : {}) & // relation fields // relation count - (AllowCount extends true ? { _count?: SelectCount } : {}); + (AllowCount extends true + ? // _count is only allowed if the model has to-many relations + HasToManyRelations extends true + ? { _count?: SelectCount } + : {} + : {}); type SelectCount> = | true @@ -1181,4 +1186,10 @@ type NonOwnedRelationFields> = keyof { + [Key in RelationFields as FieldIsArray extends true ? Key : never]: true; +} extends never + ? false + : true; + // #endregion diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 369c1539..c1bc7660 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -847,6 +847,56 @@ export abstract class BaseCrudDialect { return query; } + buildCountJson(model: string, eb: ExpressionBuilder, parentAlias: string, payload: any) { + const modelDef = requireModel(this.schema, model); + const toManyRelations = Object.entries(modelDef.fields).filter(([, field]) => field.relation && field.array); + + const selections = + payload === true + ? { + select: toManyRelations.reduce( + (acc, [field]) => { + acc[field] = true; + return acc; + }, + {} as Record, + ), + } + : payload; + + const jsonObject: Record> = {}; + + for (const [field, value] of Object.entries(selections.select)) { + const fieldDef = requireField(this.schema, model, field); + const fieldModel = fieldDef.type; + const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); + + // build a nested query to count the number of records in the relation + let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); + + // join conditions + for (const [left, right] of joinPairs) { + fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); + } + + // merge _count filter + if ( + value && + typeof value === 'object' && + 'where' in value && + value.where && + typeof value.where === 'object' + ) { + const filter = this.buildFilter(eb, fieldModel, fieldModel, value.where); + fieldCountQuery = fieldCountQuery.where(filter); + } + + jsonObject[field] = fieldCountQuery; + } + + return this.buildJsonObject(eb, jsonObject); + } + // #endregion // #region utils diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index f3408820..5cb9c5de 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -200,7 +200,7 @@ export class PostgresCrudDialect extends BaseCrudDiale relationField: string, eb: ExpressionBuilder, payload: true | FindArgs, true>, - parentName: string, + parentAlias: string, ) { const relationModelDef = requireModel(this.schema, relationModel); const objArgs: Array< @@ -238,14 +238,24 @@ export class PostgresCrudDialect extends BaseCrudDiale objArgs.push( ...Object.entries(payload.select) .filter(([, value]) => value) - .map(([field]) => { - const fieldDef = requireField(this.schema, relationModel, field); - const fieldValue = fieldDef.relation - ? // reference the synthesized JSON field - eb.ref(`${parentName}$${relationField}$${field}.$j`) - : // reference a plain field - buildFieldRef(this.schema, relationModel, field, this.options, eb); - return [sql.lit(field), fieldValue]; + .map(([field, value]) => { + if (field === '_count') { + const subJson = this.buildCountJson( + relationModel as GetModels, + eb, + `${parentAlias}$${relationField}`, + value, + ); + return [sql.lit(field), subJson]; + } else { + const fieldDef = requireField(this.schema, relationModel, field); + const fieldValue = fieldDef.relation + ? // reference the synthesized JSON field + eb.ref(`${parentAlias}$${relationField}$${field}.$j`) + : // reference a plain field + buildFieldRef(this.schema, relationModel, field, this.options, eb); + return [sql.lit(field), fieldValue]; + } }) .flatMap((v) => v), ); @@ -259,7 +269,7 @@ export class PostgresCrudDialect extends BaseCrudDiale .map(([field]) => [ sql.lit(field), // reference the synthesized JSON field - eb.ref(`${parentName}$${relationField}$${field}.$j`), + eb.ref(`${parentAlias}$${relationField}$${field}.$j`), ]) .flatMap((v) => v), ); diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 695795ab..9277af48 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -67,14 +67,14 @@ export class SqliteCrudDialect extends BaseCrudDialect model: string, eb: ExpressionBuilder, relationField: string, - parentName: string, + parentAlias: string, payload: true | FindArgs, true>, ) { const relationFieldDef = requireField(this.schema, model, relationField); const relationModel = relationFieldDef.type as GetModels; const relationModelDef = requireModel(this.schema, relationModel); - const subQueryName = `${parentName}$${relationField}`; + const subQueryName = `${parentAlias}$${relationField}`; let tbl = eb.selectFrom(() => { let subQuery = this.buildSelectModel(eb, relationModel); @@ -129,7 +129,7 @@ export class SqliteCrudDialect extends BaseCrudDialect eb .selectFrom(m2m.joinTable) .select(`${m2m.joinTable}.${m2m.otherFkName}`) - .whereRef(`${parentName}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`), + .whereRef(`${parentAlias}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`), ), ); } else { @@ -137,10 +137,10 @@ export class SqliteCrudDialect extends BaseCrudDialect keyPairs.forEach(({ fk, pk }) => { if (ownedByModel) { // the parent model owns the fk - subQuery = subQuery.whereRef(`${relationModel}.${pk}`, '=', `${parentName}.${fk}`); + subQuery = subQuery.whereRef(`${relationModel}.${pk}`, '=', `${parentAlias}.${fk}`); } else { // the relation side owns the fk - subQuery = subQuery.whereRef(`${relationModel}.${fk}`, '=', `${parentName}.${pk}`); + subQuery = subQuery.whereRef(`${relationModel}.${fk}`, '=', `${parentAlias}.${pk}`); } }); } @@ -183,21 +183,31 @@ export class SqliteCrudDialect extends BaseCrudDialect ...Object.entries(payload.select) .filter(([, value]) => value) .map(([field, value]) => { - const fieldDef = requireField(this.schema, relationModel, field); - if (fieldDef.relation) { - const subJson = this.buildRelationJSON( + if (field === '_count') { + const subJson = this.buildCountJson( relationModel as GetModels, eb, - field, - `${parentName}$${relationField}`, + `${parentAlias}$${relationField}`, value, ); - return [sql.lit(field), subJson as ArgsType]; + return [sql.lit(field), subJson]; } else { - return [ - sql.lit(field), - buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType, - ]; + const fieldDef = requireField(this.schema, relationModel, field); + if (fieldDef.relation) { + const subJson = this.buildRelationJSON( + relationModel as GetModels, + eb, + field, + `${parentAlias}$${relationField}`, + value, + ); + return [sql.lit(field), subJson]; + } else { + return [ + sql.lit(field), + buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType, + ]; + } } }) .flatMap((v) => v), @@ -214,7 +224,7 @@ export class SqliteCrudDialect extends BaseCrudDialect relationModel as GetModels, eb, field, - `${parentName}$${relationField}`, + `${parentAlias}$${relationField}`, value, ); return [sql.lit(field), subJson]; diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index bc43e2af..58fe3759 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -8,7 +8,6 @@ import { UpdateResult, type Compilable, type IsolationLevel, - type Expression as KyselyExpression, type QueryResult, type SelectQueryBuilder, } from 'kysely'; @@ -31,7 +30,6 @@ import { InternalError, NotFoundError, QueryError } from '../../errors'; import type { ToKysely } from '../../query-builder'; import { buildFieldRef, - buildJoinPairs, ensureArray, extractIdFields, flattenCompoundUniqueFilters, @@ -298,56 +296,7 @@ export abstract class BaseOperationHandler { parentAlias: string, payload: any, ) { - const modelDef = requireModel(this.schema, model); - const toManyRelations = Object.entries(modelDef.fields).filter(([, field]) => field.relation && field.array); - - const selections = - payload === true - ? { - select: toManyRelations.reduce( - (acc, [field]) => { - acc[field] = true; - return acc; - }, - {} as Record, - ), - } - : payload; - - const eb = expressionBuilder(); - const jsonObject: Record> = {}; - - for (const [field, value] of Object.entries(selections.select)) { - const fieldDef = requireField(this.schema, model, field); - const fieldModel = fieldDef.type; - const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel); - - // build a nested query to count the number of records in the relation - let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`)); - - // join conditions - for (const [left, right] of joinPairs) { - fieldCountQuery = fieldCountQuery.whereRef(left, '=', right); - } - - // merge _count filter - if ( - value && - typeof value === 'object' && - 'where' in value && - value.where && - typeof value.where === 'object' - ) { - const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where); - fieldCountQuery = fieldCountQuery.where(filter); - } - - jsonObject[field] = fieldCountQuery; - } - - query = query.select((eb) => this.dialect.buildJsonObject(eb, jsonObject).as('_count')); - - return query; + return query.select((eb) => this.dialect.buildCountJson(model, eb, parentAlias, payload).as('_count')); } private buildCursorFilter( diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index c4c7a9d1..32ab09cb 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -544,7 +544,7 @@ export class InputValidator { } } - const toManyRelations = Object.entries(modelDef.fields).filter(([, value]) => value.relation && value.array); + const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); if (toManyRelations.length > 0) { fields['_count'] = z @@ -552,9 +552,9 @@ export class InputValidator { z.literal(true), z.object( toManyRelations.reduce( - (acc, [name, fieldDef]) => ({ + (acc, fieldDef) => ({ ...acc, - [name]: z + [fieldDef.name]: z .union([ z.boolean(), z.object({ diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index c4cd78f5..7302933b 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -97,23 +97,23 @@ export function getRelationForeignKeyFieldPairs(schema: SchemaDef, model: string } export function isScalarField(schema: SchemaDef, model: string, field: string): boolean { - const fieldDef = requireField(schema, model, field); - return !fieldDef.relation && !fieldDef.foreignKeyFor; + const fieldDef = getField(schema, model, field); + return !fieldDef?.relation && !fieldDef?.foreignKeyFor; } export function isForeignKeyField(schema: SchemaDef, model: string, field: string): boolean { - const fieldDef = requireField(schema, model, field); - return !!fieldDef.foreignKeyFor; + const fieldDef = getField(schema, model, field); + return !!fieldDef?.foreignKeyFor; } export function isRelationField(schema: SchemaDef, model: string, field: string): boolean { - const fieldDef = requireField(schema, model, field); - return !!fieldDef.relation; + const fieldDef = getField(schema, model, field); + return !!fieldDef?.relation; } export function isInheritedField(schema: SchemaDef, model: string, field: string): boolean { - const fieldDef = requireField(schema, model, field); - return !!fieldDef.originModel; + const fieldDef = getField(schema, model, field); + return !!fieldDef?.originModel; } export function getUniqueFields(schema: SchemaDef, model: string) { diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index e1a05be1..3cb85495 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -832,6 +832,31 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider', _count: { posts: 2 }, }); + await expect( + client.user.findUnique({ + where: { id: user1.id }, + select: { + id: true, + posts: { + select: { _count: true }, + }, + }, + }), + ).resolves.toMatchObject({ + id: user1.id, + posts: [{ _count: { comments: 0 } }, { _count: { comments: 0 } }], + }); + + client.comment.findFirst({ + // @ts-expect-error Comment has no to-many relations to count + select: { _count: true }, + }); + + client.post.findFirst({ + // @ts-expect-error Comment has no to-many relations to count + select: { comments: { _count: true } }, + }); + await expect( client.user.findUnique({ where: { id: user1.id },