Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion packages/runtime/src/client/crud-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,12 @@ export type SelectInput<
[Key in NonRelationFields<Schema, Model>]?: true;
} & (AllowRelation extends true ? IncludeInput<Schema, Model> : {}) & // relation fields
// relation count
(AllowCount extends true ? { _count?: SelectCount<Schema, Model> } : {});
(AllowCount extends true
? // _count is only allowed if the model has to-many relations
HasToManyRelations<Schema, Model> extends true
? { _count?: SelectCount<Schema, Model> }
: {}
: {});

type SelectCount<Schema extends SchemaDef, Model extends GetModels<Schema>> =
| true
Expand Down Expand Up @@ -1181,4 +1186,10 @@ type NonOwnedRelationFields<Schema extends SchemaDef, Model extends GetModels<Sc
: Key]: true;
};

type HasToManyRelations<Schema extends SchemaDef, Model extends GetModels<Schema>> = keyof {
[Key in RelationFields<Schema, Model> as FieldIsArray<Schema, Model, Key> extends true ? Key : never]: true;
} extends never
? false
: true;

// #endregion
50 changes: 50 additions & 0 deletions packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,56 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
return query;
}

buildCountJson(model: string, eb: ExpressionBuilder<any, any>, parentAlias: string, payload: any) {
const modelDef = requireModel(this.schema, model);
const toManyRelations = Object.entries(modelDef.fields).filter(([, field]) => field.relation && field.array);

const selections =
payload === true
? {
select: toManyRelations.reduce(
(acc, [field]) => {
acc[field] = true;
return acc;
},
{} as Record<string, boolean>,
),
}
: payload;

const jsonObject: Record<string, Expression<any>> = {};

for (const [field, value] of Object.entries(selections.select)) {
const fieldDef = requireField(this.schema, model, field);
const fieldModel = fieldDef.type;
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);

// build a nested query to count the number of records in the relation
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));

// join conditions
for (const [left, right] of joinPairs) {
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
}

// merge _count filter
if (
value &&
typeof value === 'object' &&
'where' in value &&
value.where &&
typeof value.where === 'object'
) {
const filter = this.buildFilter(eb, fieldModel, fieldModel, value.where);
fieldCountQuery = fieldCountQuery.where(filter);
}

jsonObject[field] = fieldCountQuery;
}

return this.buildJsonObject(eb, jsonObject);
}

// #endregion

// #region utils
Expand Down
30 changes: 20 additions & 10 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
relationField: string,
eb: ExpressionBuilder<any, any>,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
parentName: string,
parentAlias: string,
) {
const relationModelDef = requireModel(this.schema, relationModel);
const objArgs: Array<
Expand Down Expand Up @@ -238,14 +238,24 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
objArgs.push(
...Object.entries(payload.select)
.filter(([, value]) => value)
.map(([field]) => {
const fieldDef = requireField(this.schema, relationModel, field);
const fieldValue = fieldDef.relation
? // reference the synthesized JSON field
eb.ref(`${parentName}$${relationField}$${field}.$j`)
: // reference a plain field
buildFieldRef(this.schema, relationModel, field, this.options, eb);
return [sql.lit(field), fieldValue];
.map(([field, value]) => {
if (field === '_count') {
const subJson = this.buildCountJson(
relationModel as GetModels<Schema>,
eb,
`${parentAlias}$${relationField}`,
value,
);
return [sql.lit(field), subJson];
} else {
const fieldDef = requireField(this.schema, relationModel, field);
const fieldValue = fieldDef.relation
? // reference the synthesized JSON field
eb.ref(`${parentAlias}$${relationField}$${field}.$j`)
: // reference a plain field
buildFieldRef(this.schema, relationModel, field, this.options, eb);
return [sql.lit(field), fieldValue];
}
})
.flatMap((v) => v),
);
Expand All @@ -259,7 +269,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
.map(([field]) => [
sql.lit(field),
// reference the synthesized JSON field
eb.ref(`${parentName}$${relationField}$${field}.$j`),
eb.ref(`${parentAlias}$${relationField}$${field}.$j`),
])
.flatMap((v) => v),
);
Expand Down
42 changes: 26 additions & 16 deletions packages/runtime/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
model: string,
eb: ExpressionBuilder<any, any>,
relationField: string,
parentName: string,
parentAlias: string,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
) {
const relationFieldDef = requireField(this.schema, model, relationField);
const relationModel = relationFieldDef.type as GetModels<Schema>;
const relationModelDef = requireModel(this.schema, relationModel);

const subQueryName = `${parentName}$${relationField}`;
const subQueryName = `${parentAlias}$${relationField}`;

let tbl = eb.selectFrom(() => {
let subQuery = this.buildSelectModel(eb, relationModel);
Expand Down Expand Up @@ -129,18 +129,18 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(`${parentName}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`),
.whereRef(`${parentAlias}.${parentIds[0]}`, '=', `${m2m.joinTable}.${m2m.parentFkName}`),
),
);
} else {
const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs(this.schema, model, relationField);
keyPairs.forEach(({ fk, pk }) => {
if (ownedByModel) {
// the parent model owns the fk
subQuery = subQuery.whereRef(`${relationModel}.${pk}`, '=', `${parentName}.${fk}`);
subQuery = subQuery.whereRef(`${relationModel}.${pk}`, '=', `${parentAlias}.${fk}`);
} else {
// the relation side owns the fk
subQuery = subQuery.whereRef(`${relationModel}.${fk}`, '=', `${parentName}.${pk}`);
subQuery = subQuery.whereRef(`${relationModel}.${fk}`, '=', `${parentAlias}.${pk}`);
}
});
}
Expand Down Expand Up @@ -183,21 +183,31 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
...Object.entries<any>(payload.select)
.filter(([, value]) => value)
.map(([field, value]) => {
const fieldDef = requireField(this.schema, relationModel, field);
if (fieldDef.relation) {
const subJson = this.buildRelationJSON(
if (field === '_count') {
const subJson = this.buildCountJson(
relationModel as GetModels<Schema>,
eb,
field,
`${parentName}$${relationField}`,
`${parentAlias}$${relationField}`,
value,
);
return [sql.lit(field), subJson as ArgsType];
return [sql.lit(field), subJson];
} else {
return [
sql.lit(field),
buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType,
];
const fieldDef = requireField(this.schema, relationModel, field);
if (fieldDef.relation) {
const subJson = this.buildRelationJSON(
relationModel as GetModels<Schema>,
eb,
field,
`${parentAlias}$${relationField}`,
value,
);
return [sql.lit(field), subJson];
} else {
return [
sql.lit(field),
buildFieldRef(this.schema, relationModel, field, this.options, eb) as ArgsType,
];
}
}
})
.flatMap((v) => v),
Expand All @@ -214,7 +224,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
relationModel as GetModels<Schema>,
eb,
field,
`${parentName}$${relationField}`,
`${parentAlias}$${relationField}`,
value,
);
return [sql.lit(field), subJson];
Expand Down
53 changes: 1 addition & 52 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
UpdateResult,
type Compilable,
type IsolationLevel,
type Expression as KyselyExpression,
type QueryResult,
type SelectQueryBuilder,
} from 'kysely';
Expand All @@ -31,7 +30,6 @@ import { InternalError, NotFoundError, QueryError } from '../../errors';
import type { ToKysely } from '../../query-builder';
import {
buildFieldRef,
buildJoinPairs,
ensureArray,
extractIdFields,
flattenCompoundUniqueFilters,
Expand Down Expand Up @@ -298,56 +296,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
parentAlias: string,
payload: any,
) {
const modelDef = requireModel(this.schema, model);
const toManyRelations = Object.entries(modelDef.fields).filter(([, field]) => field.relation && field.array);

const selections =
payload === true
? {
select: toManyRelations.reduce(
(acc, [field]) => {
acc[field] = true;
return acc;
},
{} as Record<string, boolean>,
),
}
: payload;

const eb = expressionBuilder<any, any>();
const jsonObject: Record<string, KyselyExpression<any>> = {};

for (const [field, value] of Object.entries(selections.select)) {
const fieldDef = requireField(this.schema, model, field);
const fieldModel = fieldDef.type;
const joinPairs = buildJoinPairs(this.schema, model, parentAlias, field, fieldModel);

// build a nested query to count the number of records in the relation
let fieldCountQuery = eb.selectFrom(fieldModel).select(eb.fn.countAll().as(`_count$${field}`));

// join conditions
for (const [left, right] of joinPairs) {
fieldCountQuery = fieldCountQuery.whereRef(left, '=', right);
}

// merge _count filter
if (
value &&
typeof value === 'object' &&
'where' in value &&
value.where &&
typeof value.where === 'object'
) {
const filter = this.dialect.buildFilter(eb, fieldModel, fieldModel, value.where);
fieldCountQuery = fieldCountQuery.where(filter);
}

jsonObject[field] = fieldCountQuery;
}

query = query.select((eb) => this.dialect.buildJsonObject(eb, jsonObject).as('_count'));

return query;
return query.select((eb) => this.dialect.buildCountJson(model, eb, parentAlias, payload).as('_count'));
}

private buildCursorFilter(
Expand Down
6 changes: 3 additions & 3 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -544,17 +544,17 @@ export class InputValidator<Schema extends SchemaDef> {
}
}

const toManyRelations = Object.entries(modelDef.fields).filter(([, value]) => value.relation && value.array);
const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array);

if (toManyRelations.length > 0) {
fields['_count'] = z
.union([
z.literal(true),
z.object(
toManyRelations.reduce(
(acc, [name, fieldDef]) => ({
(acc, fieldDef) => ({
...acc,
[name]: z
[fieldDef.name]: z
.union([
z.boolean(),
z.object({
Expand Down
16 changes: 8 additions & 8 deletions packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,23 @@ export function getRelationForeignKeyFieldPairs(schema: SchemaDef, model: string
}

export function isScalarField(schema: SchemaDef, model: string, field: string): boolean {
const fieldDef = requireField(schema, model, field);
return !fieldDef.relation && !fieldDef.foreignKeyFor;
const fieldDef = getField(schema, model, field);
return !fieldDef?.relation && !fieldDef?.foreignKeyFor;
}

export function isForeignKeyField(schema: SchemaDef, model: string, field: string): boolean {
const fieldDef = requireField(schema, model, field);
return !!fieldDef.foreignKeyFor;
const fieldDef = getField(schema, model, field);
return !!fieldDef?.foreignKeyFor;
}

export function isRelationField(schema: SchemaDef, model: string, field: string): boolean {
const fieldDef = requireField(schema, model, field);
return !!fieldDef.relation;
const fieldDef = getField(schema, model, field);
return !!fieldDef?.relation;
}

export function isInheritedField(schema: SchemaDef, model: string, field: string): boolean {
const fieldDef = requireField(schema, model, field);
return !!fieldDef.originModel;
const fieldDef = getField(schema, model, field);
return !!fieldDef?.originModel;
}

export function getUniqueFields(schema: SchemaDef, model: string) {
Expand Down
25 changes: 25 additions & 0 deletions packages/runtime/test/client-api/find.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,31 @@ describe.each(createClientSpecs(PG_DB_NAME))('Client find tests for $provider',
_count: { posts: 2 },
});

await expect(
client.user.findUnique({
where: { id: user1.id },
select: {
id: true,
posts: {
select: { _count: true },
},
},
}),
).resolves.toMatchObject({
id: user1.id,
posts: [{ _count: { comments: 0 } }, { _count: { comments: 0 } }],
});

client.comment.findFirst({
// @ts-expect-error Comment has no to-many relations to count
select: { _count: true },
});

client.post.findFirst({
// @ts-expect-error Comment has no to-many relations to count
select: { comments: { _count: true } },
});

await expect(
client.user.findUnique({
where: { id: user1.id },
Expand Down