From f296f83537b2dc9a5c2cd374f9830ef7ccf3e5d7 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:00:40 +0800 Subject: [PATCH 1/3] fix(delegate): relation selection --- .../runtime/src/client/crud/dialects/base.ts | 2 + .../runtime/src/client/crud/dialects/index.ts | 6 +- .../src/client/crud/dialects/postgresql.ts | 26 +- .../src/client/crud/dialects/sqlite.ts | 26 +- .../src/client/crud/operations/base.ts | 50 +-- packages/runtime/src/client/query-utils.ts | 17 +- .../runtime/test/client-api/delegate.test.ts | 314 ++++++++++-------- 7 files changed, 279 insertions(+), 162 deletions(-) diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 7d4e6fb2..93cbaefc 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -28,11 +28,13 @@ import { makeDefaultOrderBy, requireField, } from '../../query-utils'; +import type { BaseOperationHandler } from '../operations/base'; export abstract class BaseCrudDialect { constructor( protected readonly schema: Schema, protected readonly options: ClientOptions, + protected readonly handler: BaseOperationHandler, ) {} abstract get provider(): DataSourceProviderType; diff --git a/packages/runtime/src/client/crud/dialects/index.ts b/packages/runtime/src/client/crud/dialects/index.ts index 9d67009e..f8fa6506 100644 --- a/packages/runtime/src/client/crud/dialects/index.ts +++ b/packages/runtime/src/client/crud/dialects/index.ts @@ -1,6 +1,7 @@ import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { ClientOptions } from '../../options'; +import type { BaseOperationHandler } from '../operations/base'; import type { BaseCrudDialect } from './base'; import { PostgresCrudDialect } from './postgresql'; import { SqliteCrudDialect } from './sqlite'; @@ -8,9 +9,10 @@ import { SqliteCrudDialect } from './sqlite'; export function getCrudDialect( schema: Schema, options: ClientOptions, + handler: BaseOperationHandler, ): BaseCrudDialect { return match(schema.provider.type) - .with('sqlite', () => new SqliteCrudDialect(schema, options)) - .with('postgresql', () => new PostgresCrudDialect(schema, options)) + .with('sqlite', () => new SqliteCrudDialect(schema, options, handler)) + .with('postgresql', () => new PostgresCrudDialect(schema, options, handler)) .exhaustive(); } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index c73b4bb1..4d8a70db 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -9,10 +9,12 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import type { BuiltinType, FieldDef, GetModels, SchemaDef } from '../../../schema'; +import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; import { buildFieldRef, buildJoinPairs, + getDelegateDescendantModels, getIdFields, getManyToManyRelation, isRelationField, @@ -79,10 +81,18 @@ export class PostgresCrudDialect extends BaseCrudDiale // simple select by default let result = eb.selectFrom(`${relationModel} as ${joinTableName}`); + const joinBases: string[] = []; + // however if there're filter/orderBy/take/skip, // we need to build a subquery to handle them before aggregation result = eb.selectFrom(() => { - let subQuery = eb.selectFrom(`${relationModel}`).selectAll(); + let subQuery = eb.selectFrom(relationModel); + subQuery = this.handler.buildSelectAllFields( + relationModel, + subQuery, + typeof payload === 'object' ? payload?.omit : undefined, + joinBases, + ); if (payload && typeof payload === 'object') { if (payload.where) { @@ -200,6 +210,20 @@ export class PostgresCrudDialect extends BaseCrudDiale string | ExpressionWrapper | SelectQueryBuilder | RawBuilder > = []; + // TODO: descendant JSON shouldn't be joined and selected if none of its fields are selected + const descendantModels = getDelegateDescendantModels(this.schema, relationModel); + if (descendantModels.length > 0) { + // select all JSONs built from delegate descendants + objArgs.push( + ...descendantModels + .map((subModel) => [ + sql.lit(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`), + eb.ref(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`), + ]) + .flatMap((v) => v), + ); + } + if (payload === true || !payload.select) { // select all scalar fields objArgs.push( diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index c27cd7de..c18be56a 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -10,9 +10,11 @@ import { } from 'kysely'; import { match } from 'ts-pattern'; import type { BuiltinType, GetModels, SchemaDef } from '../../../schema'; +import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { FindArgs } from '../../crud-types'; import { buildFieldRef, + getDelegateDescendantModels, getIdFields, getManyToManyRelation, getRelationForeignKeyFieldPairs, @@ -75,7 +77,15 @@ export class SqliteCrudDialect extends BaseCrudDialect const subQueryName = `${parentName}$${relationField}`; let tbl = eb.selectFrom(() => { - let subQuery = eb.selectFrom(relationModel).selectAll(); + let subQuery = eb.selectFrom(relationModel); + + const joinBases: string[] = []; + subQuery = this.handler.buildSelectAllFields( + relationModel, + subQuery, + typeof payload === 'object' ? payload?.omit : undefined, + joinBases, + ); if (payload && typeof payload === 'object') { if (payload.where) { @@ -143,6 +153,20 @@ export class SqliteCrudDialect extends BaseCrudDialect type ArgsType = Expression | RawBuilder | SelectQueryBuilder; const objArgs: ArgsType[] = []; + // TODO: descendant JSON shouldn't be joined and selected if none of its fields are selected + const descendantModels = getDelegateDescendantModels(this.schema, relationModel); + if (descendantModels.length > 0) { + // select all JSONs built from delegate descendants + objArgs.push( + ...descendantModels + .map((subModel) => [ + sql.lit(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`), + eb.ref(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`), + ]) + .flatMap((v) => v), + ); + } + if (payload === true || !payload.select) { // select all scalar fields objArgs.push( diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index a4bb5c11..f22d8985 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -33,6 +33,7 @@ import { ensureArray, extractIdFields, flattenCompoundUniqueFilters, + getDelegateDescendantModels, getDiscriminatorField, getField, getIdFields, @@ -84,7 +85,7 @@ export abstract class BaseOperationHandler { protected readonly model: GetModels, protected readonly inputValidator: InputValidator, ) { - this.dialect = getCrudDialect(this.schema, this.client.$options); + this.dialect = getCrudDialect(this.schema, this.client.$options, this); } protected get schema() { @@ -183,19 +184,22 @@ export abstract class BaseOperationHandler { } } + // for deduplicating base joins + const joinedBases: string[] = []; + // select if (args && 'select' in args && args.select) { // select is mutually exclusive with omit - query = this.buildFieldSelection(model, query, args.select, model); + query = this.buildFieldSelection(model, query, args.select, model, joinedBases); } else { // include all scalar fields except those in omit - query = this.buildSelectAllScalarFields(model, query, (args as any)?.omit); + query = this.buildSelectAllFields(model, query, (args as any)?.omit, joinedBases); } // include if (args && 'include' in args && args.include) { // note that 'omit' is handled above already - query = this.buildFieldSelection(model, query, args.include, model); + query = this.buildFieldSelection(model, query, args.include, model, joinedBases); } if (args?.cursor) { @@ -246,9 +250,9 @@ export abstract class BaseOperationHandler { query: SelectQueryBuilder, selectOrInclude: Record, parentAlias: string, + joinedBases: string[], ) { let result = query; - const joinedBases: string[] = []; for (const [field, payload] of Object.entries(selectOrInclude)) { if (!payload) { @@ -262,12 +266,29 @@ export abstract class BaseOperationHandler { const fieldDef = this.requireField(model, field); if (!fieldDef.relation) { + // scalar field result = this.selectField(result, model, parentAlias, field, joinedBases); } else { if (!fieldDef.array && !fieldDef.optional && payload.where) { throw new QueryError(`Field "${field}" doesn't support filtering`); } - result = this.dialect.buildRelationSelection(result, model, field, parentAlias, payload); + if (fieldDef.originModel) { + // relation is inherited from a delegate base model, need to build a join + if (!joinedBases.includes(fieldDef.originModel)) { + joinedBases.push(fieldDef.originModel); + result = this.buildDelegateJoin(parentAlias, fieldDef.originModel, result); + } + result = this.dialect.buildRelationSelection( + result, + fieldDef.originModel, + field, + fieldDef.originModel, + payload, + ); + } else { + // regular relation + result = this.dialect.buildRelationSelection(result, model, field, parentAlias, payload); + } } } @@ -332,14 +353,14 @@ export abstract class BaseOperationHandler { return query; } - private buildSelectAllScalarFields( + buildSelectAllFields( model: string, query: SelectQueryBuilder, omit?: Record, + joinedBases: string[] = [], ) { const modelDef = this.requireModel(model); let result = query; - const joinedBases: string[] = []; for (const field of Object.keys(modelDef.fields)) { if (isRelationField(this.schema, model, field)) { @@ -352,7 +373,7 @@ export abstract class BaseOperationHandler { } // select all fields from delegate descendants and pack into a JSON field `$delegate$Model` - const descendants = this.getDelegateDescendantModels(model); + const descendants = getDelegateDescendantModels(this.schema, model); for (const subModel of descendants) { if (!joinedBases.includes(subModel.name)) { joinedBases.push(subModel.name); @@ -378,17 +399,6 @@ export abstract class BaseOperationHandler { return result; } - private getDelegateDescendantModels(model: string, collected: Set = new Set()): ModelDef[] { - const subModels = Object.values(this.schema.models).filter((m) => m.baseModel === model); - subModels.forEach((def) => { - if (!collected.has(def)) { - collected.add(def); - this.getDelegateDescendantModels(def.name, collected); - } - }); - return [...collected]; - } - private selectField( query: SelectQueryBuilder, model: string, diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 8c47b895..47b91a00 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -1,5 +1,5 @@ import type { ExpressionBuilder, ExpressionWrapper } from 'kysely'; -import { ExpressionUtils, type FieldDef, type GetModels, type SchemaDef } from '../schema'; +import { ExpressionUtils, type FieldDef, type GetModels, type ModelDef, type SchemaDef } from '../schema'; import type { OrderBy } from './crud-types'; import { InternalError, QueryError } from './errors'; import type { ClientOptions } from './options'; @@ -313,3 +313,18 @@ export function getDiscriminatorField(schema: SchemaDef, model: string) { } return discriminator.value.field; } + +export function getDelegateDescendantModels( + schema: SchemaDef, + model: string, + collected: Set = new Set(), +): ModelDef[] { + const subModels = Object.values(schema.models).filter((m) => m.baseModel === model); + subModels.forEach((def) => { + if (!collected.has(def)) { + collected.add(def); + getDelegateDescendantModels(schema, def.name, collected); + } + }); + return [...collected]; +} diff --git a/packages/runtime/test/client-api/delegate.test.ts b/packages/runtime/test/client-api/delegate.test.ts index b1e4fc83..eb55ae22 100644 --- a/packages/runtime/test/client-api/delegate.test.ts +++ b/packages/runtime/test/client-api/delegate.test.ts @@ -1,8 +1,12 @@ -import { describe, expect, it } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../utils'; -describe('Delegate model tests', () => { - const POLYMORPHIC_SCHEMA = ` +const DB_NAME = `client-api-delegate-tests`; + +describe.each([{ provider: 'sqlite' as const }, { provider: 'postgresql' as const }])( + 'Delegate model tests for $provider', + ({ provider }) => { + const POLYMORPHIC_SCHEMA = ` model User { id Int @id @default(autoincrement()) email String? @unique @@ -49,158 +53,194 @@ model Gallery { } `; - it('works with create', async () => { - const client = await createTestClient(POLYMORPHIC_SCHEMA, { - usePrismaPush: true, + let client: any; + + beforeEach(async () => { + client = await createTestClient(POLYMORPHIC_SCHEMA, { + usePrismaPush: true, + provider, + dbName: provider === 'postgresql' ? DB_NAME : undefined, + }); }); - // delegate model cannot be created directly - await expect( - client.video.create({ - data: { - duration: 100, - url: 'abc', - videoType: 'MyVideo', - }, - }), - ).rejects.toThrow('is a delegate'); + afterEach(async () => { + await client.$disconnect(); + }); - // create entity with two levels of delegation - await expect( - client.ratedVideo.create({ - data: { - duration: 100, - url: 'abc', - rating: 5, + it('works with create', async () => { + // delegate model cannot be created directly + await expect( + client.video.create({ + data: { + duration: 100, + url: 'abc', + videoType: 'MyVideo', + }, + }), + ).rejects.toThrow('is a delegate'); + + // create entity with two levels of delegation + await expect( + client.ratedVideo.create({ + data: { + duration: 100, + url: 'abc', + rating: 5, + }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + duration: 100, + url: 'abc', + rating: 5, + assetType: 'Video', + videoType: 'RatedVideo', + }); + + // create entity with relation + await expect( + client.ratedVideo.create({ + data: { + duration: 50, + url: 'bcd', + rating: 5, + user: { create: { email: 'u1@example.com' } }, + }, + include: { user: true }, + }), + ).resolves.toMatchObject({ + userId: expect.any(Number), + user: { + email: 'u1@example.com', }, - }), - ).resolves.toMatchObject({ - id: expect.any(Number), - duration: 100, - url: 'abc', - rating: 5, - assetType: 'Video', - videoType: 'RatedVideo', + }); + + // create entity with one level of delegation + await expect( + client.image.create({ + data: { + format: 'png', + gallery: { + create: {}, + }, + }, + }), + ).resolves.toMatchObject({ + id: expect.any(Number), + format: 'png', + galleryId: expect.any(Number), + assetType: 'Image', + }); }); - // create entity with relation - await expect( - client.ratedVideo.create({ + it('works with find', async () => { + const u = await client.user.create({ data: { - duration: 50, - url: 'bcd', - rating: 5, - user: { create: { email: 'u1@example.com' } }, + email: 'u1@example.com', }, - include: { user: true }, - }), - ).resolves.toMatchObject({ - userId: expect.any(Number), - user: { - email: 'u1@example.com', - }, - }); - - // create entity with one level of delegation - await expect( - client.image.create({ + }); + const v = await client.ratedVideo.create({ data: { - format: 'png', - gallery: { - create: {}, - }, + duration: 100, + url: 'abc', + rating: 5, + owner: { connect: { id: u.id } }, + user: { connect: { id: u.id } }, }, - }), - ).resolves.toMatchObject({ - id: expect.any(Number), - format: 'png', - galleryId: expect.any(Number), - assetType: 'Image', - }); - }); - - it('works with find', async () => { - const client = await createTestClient(POLYMORPHIC_SCHEMA, { - usePrismaPush: true, - log: ['query'], - }); + }); - const u = await client.user.create({ - data: { - email: 'u1@example.com', - }, - }); - const v = await client.ratedVideo.create({ - data: { + const ratedVideoContent = { + id: v.id, + createdAt: expect.any(Date), duration: 100, + rating: 5, + assetType: 'Video', + videoType: 'RatedVideo', + }; + + // include all base fields + await expect( + client.ratedVideo.findUnique({ + where: { id: v.id }, + include: { user: true, owner: true }, + }), + ).resolves.toMatchObject({ ...ratedVideoContent, user: expect.any(Object), owner: expect.any(Object) }); + + // select fields + await expect( + client.ratedVideo.findUnique({ + where: { id: v.id }, + select: { + id: true, + viewCount: true, + url: true, + rating: true, + }, + }), + ).resolves.toEqual({ + id: v.id, + viewCount: 0, url: 'abc', rating: 5, - user: { connect: { id: u.id } }, - }, - include: { user: true }, - }); + }); - const ratedVideoContent = { - id: v.id, - createdAt: expect.any(Date), - duration: 100, - rating: 5, - assetType: 'Video', - videoType: 'RatedVideo', - }; - - // include all base fields - await expect( - client.ratedVideo.findUnique({ + // omit fields + const r = await client.ratedVideo.findUnique({ where: { id: v.id }, - include: { user: true }, - }), - ).resolves.toMatchObject({ ...ratedVideoContent, user: expect.any(Object) }); - - // select fields - await expect( - client.ratedVideo.findUnique({ - where: { id: v.id }, - select: { - id: true, + omit: { viewCount: true, url: true, rating: true, }, - }), - ).resolves.toEqual({ - id: v.id, - viewCount: 0, - url: 'abc', - rating: 5, - }); - - // omit fields - const r = await client.ratedVideo.findUnique({ - where: { id: v.id }, - omit: { - viewCount: true, - url: true, - rating: true, - }, + }); + expect(r.viewCount).toBeUndefined(); + expect(r.url).toBeUndefined(); + expect(r.rating).toBeUndefined(); + expect(r.duration).toEqual(expect.any(Number)); + + // include all sub fields + await expect( + client.video.findUnique({ + where: { id: v.id }, + }), + ).resolves.toMatchObject(ratedVideoContent); + + // include all sub fields + await expect( + client.asset.findUnique({ + where: { id: v.id }, + }), + ).resolves.toMatchObject(ratedVideoContent); + + // find as a relation + await expect( + client.user.findUnique({ + where: { id: u.id }, + include: { assets: true, ratedVideos: true }, + }), + ).resolves.toMatchObject({ + assets: [ratedVideoContent], + ratedVideos: [ratedVideoContent], + }); + + // find as a relation with selection + await expect( + client.user.findUnique({ + where: { id: u.id }, + include: { + assets: { + select: { id: true, assetType: true }, + }, + ratedVideos: { + url: true, + rating: true, + }, + }, + }), + ).resolves.toMatchObject({ + assets: [{ id: v.id, assetType: 'Video' }], + ratedVideos: [{ url: 'abc', rating: 5 }], + }); }); - expect(r.viewCount).toBeUndefined(); - expect(r.url).toBeUndefined(); - expect(r.rating).toBeUndefined(); - expect(r.duration).toEqual(expect.any(Number)); - - // include all sub fields - await expect( - client.video.findUnique({ - where: { id: v.id }, - }), - ).resolves.toMatchObject(ratedVideoContent); - - // include all sub fields - await expect( - client.asset.findUnique({ - where: { id: v.id }, - }), - ).resolves.toMatchObject(ratedVideoContent); - }); -}); + }, +); From 2e3347e51f6d626536703436774da2081ceb4bc2 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:20:01 +0800 Subject: [PATCH 2/3] fixes --- .../runtime/src/client/crud/dialects/base.ts | 137 +++++++++++++++--- .../runtime/src/client/crud/dialects/index.ts | 6 +- .../src/client/crud/dialects/postgresql.ts | 2 +- .../src/client/crud/dialects/sqlite.ts | 2 +- .../src/client/crud/operations/base.ts | 94 +----------- 5 files changed, 128 insertions(+), 113 deletions(-) diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 93cbaefc..3ce8e0f7 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -5,6 +5,7 @@ import { match, P } from 'ts-pattern'; import type { BuiltinType, DataSourceProviderType, FieldDef, GetModels, SchemaDef } from '../../../schema'; import { enumerate } from '../../../utils/enumerate'; import type { OrArray } from '../../../utils/type-utils'; +import { DELEGATE_JOINED_FIELD_PREFIX } from '../../constants'; import type { BooleanFilter, BytesFilter, @@ -20,42 +21,30 @@ import { buildFieldRef, buildJoinPairs, flattenCompoundUniqueFilters, + getDelegateDescendantModels, getField, getIdFields, getManyToManyRelation, getRelationForeignKeyFieldPairs, isEnum, + isInheritedField, + isRelationField, makeDefaultOrderBy, requireField, + requireModel, } from '../../query-utils'; -import type { BaseOperationHandler } from '../operations/base'; export abstract class BaseCrudDialect { constructor( protected readonly schema: Schema, protected readonly options: ClientOptions, - protected readonly handler: BaseOperationHandler, ) {} - abstract get provider(): DataSourceProviderType; - transformPrimitive(value: unknown, _type: BuiltinType, _forArrayField: boolean) { return value; } - abstract buildRelationSelection( - query: SelectQueryBuilder, - model: string, - relationField: string, - parentAlias: string, - payload: true | FindArgs, true>, - ): SelectQueryBuilder; - - abstract buildSkipTake( - query: SelectQueryBuilder, - skip: number | undefined, - take: number | undefined, - ): SelectQueryBuilder; + // #region common query builders buildFilter( eb: ExpressionBuilder, @@ -790,6 +779,92 @@ export abstract class BaseCrudDialect { return result; } + buildSelectAllFields( + model: string, + query: SelectQueryBuilder, + omit?: Record, + joinedBases: string[] = [], + ) { + const modelDef = requireModel(this.schema, model); + let result = query; + + for (const field of Object.keys(modelDef.fields)) { + if (isRelationField(this.schema, model, field)) { + continue; + } + if (omit?.[field] === true) { + continue; + } + result = this.buildSelectField(result, model, model, field, joinedBases); + } + + // select all fields from delegate descendants and pack into a JSON field `$delegate$Model` + const descendants = getDelegateDescendantModels(this.schema, model); + for (const subModel of descendants) { + if (!joinedBases.includes(subModel.name)) { + joinedBases.push(subModel.name); + result = this.buildDelegateJoin(model, subModel.name, result); + } + result = result.select((eb) => { + const jsonObject: Record> = {}; + for (const field of Object.keys(subModel.fields)) { + if ( + isRelationField(this.schema, subModel.name, field) || + isInheritedField(this.schema, subModel.name, field) + ) { + continue; + } + jsonObject[field] = eb.ref(`${subModel.name}.${field}`); + } + return this.buildJsonObject(eb, jsonObject).as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`); + }); + } + + return result; + } + + buildSelectField( + query: SelectQueryBuilder, + model: string, + modelAlias: string, + field: string, + joinedBases: string[], + ) { + const fieldDef = requireField(this.schema, model, field); + + if (fieldDef.computed) { + // TODO: computed field from delegate base? + return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field)); + } else if (!fieldDef.originModel) { + // regular field + return query.select(sql.ref(`${modelAlias}.${field}`).as(field)); + } else { + // field from delegate base, build a join + let result = query; + if (!joinedBases.includes(fieldDef.originModel)) { + joinedBases.push(fieldDef.originModel); + result = this.buildDelegateJoin(model, fieldDef.originModel, result); + } + result = this.buildSelectField(result, fieldDef.originModel, fieldDef.originModel, field, joinedBases); + return result; + } + } + + buildDelegateJoin(thisModel: string, otherModel: string, query: SelectQueryBuilder) { + const idFields = getIdFields(this.schema, thisModel); + query = query.leftJoin(otherModel, (qb) => { + for (const idField of idFields) { + qb = qb.onRef(`${thisModel}.${idField}`, '=', `${otherModel}.${idField}`); + } + return qb; + }); + return query; + } + + // #endregion + + // #region utils + private negateSort(sort: SortOrder, negated: boolean) { return negated ? (sort === 'asc' ? 'desc' : 'asc') : sort; } @@ -844,6 +919,32 @@ export abstract class BaseCrudDialect { return eb.not(this.and(eb, ...args)); } + // #endregion + + // #region abstract methods + + abstract get provider(): DataSourceProviderType; + + /** + * Builds selection for a relation field. + */ + abstract buildRelationSelection( + query: SelectQueryBuilder, + model: string, + relationField: string, + parentAlias: string, + payload: true | FindArgs, true>, + ): SelectQueryBuilder; + + /** + * Builds skip and take clauses. + */ + abstract buildSkipTake( + query: SelectQueryBuilder, + skip: number | undefined, + take: number | undefined, + ): SelectQueryBuilder; + /** * Builds an Kysely expression that returns a JSON object for the given key-value pairs. */ @@ -879,4 +980,6 @@ export abstract class BaseCrudDialect { * Whether the dialect supports DISTINCT ON. */ abstract get supportsDistinctOn(): boolean; + + // #endregion } diff --git a/packages/runtime/src/client/crud/dialects/index.ts b/packages/runtime/src/client/crud/dialects/index.ts index f8fa6506..9d67009e 100644 --- a/packages/runtime/src/client/crud/dialects/index.ts +++ b/packages/runtime/src/client/crud/dialects/index.ts @@ -1,7 +1,6 @@ import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { ClientOptions } from '../../options'; -import type { BaseOperationHandler } from '../operations/base'; import type { BaseCrudDialect } from './base'; import { PostgresCrudDialect } from './postgresql'; import { SqliteCrudDialect } from './sqlite'; @@ -9,10 +8,9 @@ import { SqliteCrudDialect } from './sqlite'; export function getCrudDialect( schema: Schema, options: ClientOptions, - handler: BaseOperationHandler, ): BaseCrudDialect { return match(schema.provider.type) - .with('sqlite', () => new SqliteCrudDialect(schema, options, handler)) - .with('postgresql', () => new PostgresCrudDialect(schema, options, handler)) + .with('sqlite', () => new SqliteCrudDialect(schema, options)) + .with('postgresql', () => new PostgresCrudDialect(schema, options)) .exhaustive(); } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 4d8a70db..be273e81 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -87,7 +87,7 @@ export class PostgresCrudDialect extends BaseCrudDiale // we need to build a subquery to handle them before aggregation result = eb.selectFrom(() => { let subQuery = eb.selectFrom(relationModel); - subQuery = this.handler.buildSelectAllFields( + subQuery = this.buildSelectAllFields( relationModel, subQuery, typeof payload === 'object' ? payload?.omit : undefined, diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index c18be56a..2961b864 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -80,7 +80,7 @@ export class SqliteCrudDialect extends BaseCrudDialect let subQuery = eb.selectFrom(relationModel); const joinBases: string[] = []; - subQuery = this.handler.buildSelectAllFields( + subQuery = this.buildSelectAllFields( relationModel, subQuery, typeof payload === 'object' ? payload?.omit : undefined, diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index f22d8985..9ae6af01 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -22,7 +22,7 @@ import { ExpressionUtils, type GetModels, type ModelDef, type SchemaDef } from ' import { clone } from '../../../utils/clone'; import { enumerate } from '../../../utils/enumerate'; import { extractFields, fieldsToSelectObject } from '../../../utils/object-utils'; -import { CONTEXT_COMMENT_PREFIX, DELEGATE_JOINED_FIELD_PREFIX, NUMERIC_FIELD_TYPES } from '../../constants'; +import { CONTEXT_COMMENT_PREFIX, NUMERIC_FIELD_TYPES } from '../../constants'; import type { CRUD } from '../../contract'; import type { FindArgs, SelectIncludeOmit, SortOrder, WhereInput } from '../../crud-types'; import { InternalError, NotFoundError, QueryError } from '../../errors'; @@ -33,7 +33,6 @@ import { ensureArray, extractIdFields, flattenCompoundUniqueFilters, - getDelegateDescendantModels, getDiscriminatorField, getField, getIdFields, @@ -42,7 +41,6 @@ import { getModel, getRelationForeignKeyFieldPairs, isForeignKeyField, - isInheritedField, isRelationField, isScalarField, makeDefaultOrderBy, @@ -193,7 +191,7 @@ export abstract class BaseOperationHandler { query = this.buildFieldSelection(model, query, args.select, model, joinedBases); } else { // include all scalar fields except those in omit - query = this.buildSelectAllFields(model, query, (args as any)?.omit, joinedBases); + query = this.dialect.buildSelectAllFields(model, query, (args as any)?.omit, joinedBases); } // include @@ -267,7 +265,7 @@ export abstract class BaseOperationHandler { const fieldDef = this.requireField(model, field); if (!fieldDef.relation) { // scalar field - result = this.selectField(result, model, parentAlias, field, joinedBases); + result = this.dialect.buildSelectField(result, model, parentAlias, field, joinedBases); } else { if (!fieldDef.array && !fieldDef.optional && payload.where) { throw new QueryError(`Field "${field}" doesn't support filtering`); @@ -276,7 +274,7 @@ export abstract class BaseOperationHandler { // relation is inherited from a delegate base model, need to build a join if (!joinedBases.includes(fieldDef.originModel)) { joinedBases.push(fieldDef.originModel); - result = this.buildDelegateJoin(parentAlias, fieldDef.originModel, result); + result = this.dialect.buildDelegateJoin(parentAlias, fieldDef.originModel, result); } result = this.dialect.buildRelationSelection( result, @@ -353,90 +351,6 @@ export abstract class BaseOperationHandler { return query; } - buildSelectAllFields( - model: string, - query: SelectQueryBuilder, - omit?: Record, - joinedBases: string[] = [], - ) { - const modelDef = this.requireModel(model); - let result = query; - - for (const field of Object.keys(modelDef.fields)) { - if (isRelationField(this.schema, model, field)) { - continue; - } - if (omit?.[field] === true) { - continue; - } - result = this.selectField(result, model, model, field, joinedBases); - } - - // select all fields from delegate descendants and pack into a JSON field `$delegate$Model` - const descendants = getDelegateDescendantModels(this.schema, model); - for (const subModel of descendants) { - if (!joinedBases.includes(subModel.name)) { - joinedBases.push(subModel.name); - result = this.buildDelegateJoin(model, subModel.name, result); - } - result = result.select((eb) => { - const jsonObject: Record> = {}; - for (const field of Object.keys(subModel.fields)) { - if ( - isRelationField(this.schema, subModel.name, field) || - isInheritedField(this.schema, subModel.name, field) - ) { - continue; - } - jsonObject[field] = eb.ref(`${subModel.name}.${field}`); - } - return this.dialect - .buildJsonObject(eb, jsonObject) - .as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`); - }); - } - - return result; - } - - private selectField( - query: SelectQueryBuilder, - model: string, - modelAlias: string, - field: string, - joinedBases: string[], - ) { - const fieldDef = this.requireField(model, field); - - if (fieldDef.computed) { - // TODO: computed field from delegate base? - return query.select((eb) => buildFieldRef(this.schema, model, field, this.options, eb).as(field)); - } else if (!fieldDef.originModel) { - // regular field - return query.select(sql.ref(`${modelAlias}.${field}`).as(field)); - } else { - // field from delegate base, build a join - let result = query; - if (!joinedBases.includes(fieldDef.originModel)) { - joinedBases.push(fieldDef.originModel); - result = this.buildDelegateJoin(model, fieldDef.originModel, result); - } - result = this.selectField(result, fieldDef.originModel, fieldDef.originModel, field, joinedBases); - return result; - } - } - - private buildDelegateJoin(thisModel: string, otherModel: string, query: SelectQueryBuilder) { - const idFields = getIdFields(this.schema, thisModel); - query = query.leftJoin(otherModel, (qb) => { - for (const idField of idFields) { - qb = qb.onRef(`${thisModel}.${idField}`, '=', `${otherModel}.${idField}`); - } - return qb; - }); - return query; - } - private buildCursorFilter( model: string, query: SelectQueryBuilder, From 08833087a712dcb5bc2c956cfc45bd9b231b2739 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:20:29 +0800 Subject: [PATCH 3/3] update --- packages/runtime/src/client/crud/operations/base.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 9ae6af01..ca77245a 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -83,7 +83,7 @@ export abstract class BaseOperationHandler { protected readonly model: GetModels, protected readonly inputValidator: InputValidator, ) { - this.dialect = getCrudDialect(this.schema, this.client.$options, this); + this.dialect = getCrudDialect(this.schema, this.client.$options); } protected get schema() {