From 934ff692b7b5c6b69264358f7bec0866784cb5ed Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 1 Jun 2025 22:32:59 -0700 Subject: [PATCH] feat: select distinct --- TODO.md | 4 +- packages/runtime/src/client/crud-types.ts | 7 +++- .../runtime/src/client/crud/dialects/base.ts | 6 +-- .../src/client/crud/dialects/postgresql.ts | 4 ++ .../src/client/crud/dialects/sqlite.ts | 4 ++ .../src/client/crud/operations/base.ts | 33 ++++++++++++++- .../src/client/crud/operations/create.ts | 6 ++- packages/runtime/src/client/crud/validator.ts | 21 +++++++--- packages/runtime/src/client/query-utils.ts | 8 ++++ packages/runtime/src/schema/schema.ts | 2 +- packages/runtime/test/client-api/find.test.ts | 42 +++++++++++++++++++ 11 files changed, 122 insertions(+), 15 deletions(-) diff --git a/TODO.md b/TODO.md index 55d64a6c..84a13f6a 100644 --- a/TODO.md +++ b/TODO.md @@ -35,7 +35,7 @@ - [x] Filtering - [x] Sorting - [x] Pagination - - [ ] Distinct + - [x] Distinct - [ ] Update - [x] Input validation - [x] Top-level @@ -52,7 +52,7 @@ - [ ] Extensions - [x] Query builder API - [x] Computed fields - - [?] Prisma client extension + - [ ] Prisma client extension - [ ] Misc - [ ] Compound ID - [ ] Cross field comparison diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index 410a8e96..4a6a48a8 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -358,6 +358,10 @@ export type SelectIncludeOmit< omit?: OmitFields; }; +type Distinct> = { + distinct?: OrArray>; +}; + type Select< Schema extends SchemaDef, Model extends GetModels, @@ -565,7 +569,8 @@ export type FindArgs< where?: WhereInput; } : {}) & - SelectIncludeOmit; + SelectIncludeOmit & + Distinct; export type FindUniqueArgs< Schema extends SchemaDef, diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 230023f0..573107f6 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -1025,7 +1025,7 @@ export abstract class BaseCrudDialect { abstract buildArrayLiteralSQL(values: unknown[]): string; - get supportsUpdateWithLimit() { - return true; - } + abstract get supportsUpdateWithLimit(): boolean; + + abstract get supportsDistinctOn(): boolean; } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index a33b4cca..71b94187 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -329,6 +329,10 @@ export class PostgresCrudDialect< return false; } + override get supportsDistinctOn(): boolean { + return true; + } + override buildArrayLength( eb: ExpressionBuilder, array: Expression diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 59697514..b9e88647 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -283,6 +283,10 @@ export class SqliteCrudDialect< return false; } + override get supportsDistinctOn() { + return false; + } + override buildArrayLength( eb: ExpressionBuilder, array: Expression diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 1792fb9e..22bc7f90 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -35,6 +35,7 @@ import type { ToKysely } from '../../query-builder'; import { buildFieldRef, buildJoinPairs, + ensureArray, getField, getIdFields, getIdValues, @@ -154,6 +155,21 @@ export abstract class BaseOperationHandler { // skip && take query = this.dialect.buildSkipTake(query, args?.skip, args?.take); + let inMemoryDistinct: string[] | undefined = undefined; + + // distinct + if (args?.distinct) { + const distinct = ensureArray(args.distinct); + if (this.dialect.supportsDistinctOn) { + query = query.distinctOn( + distinct.map((f: any) => sql.ref(`${model}.${f}`)) + ); + } else { + // in-memory distinct after fetching all results + inMemoryDistinct = distinct; + } + } + // orderBy if (args?.orderBy) { query = this.dialect.buildOrderBy( @@ -188,7 +204,22 @@ export abstract class BaseOperationHandler { ); try { - return await query.execute(); + let result = await query.execute(); + if (inMemoryDistinct) { + const distinctResult: Record[] = []; + const seen = new Set(); + for (const r of result as any[]) { + const key = JSON.stringify( + inMemoryDistinct.map((f) => r[f]) + )!; + if (!seen.has(key)) { + distinctResult.push(r); + seen.add(key); + } + } + result = distinctResult; + } + return result; } catch (err) { const { sql, parameters } = query.compile(); throw new QueryError( diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index 5ec8bf9f..2093843f 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -46,7 +46,11 @@ export class CreateOperationHandler< select: args.select, include: args.include, omit: args.omit, - where: getIdValues(this.schema, this.model, createResult), + where: getIdValues( + this.schema, + this.model, + createResult + ) as any, }); }); diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 88ed433f..029b3fdf 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -157,6 +157,7 @@ export class InputValidator { fields['select'] = this.makeSelectSchema(model).optional(); fields['include'] = this.makeIncludeSchema(model).optional(); fields['omit'] = this.makeOmitSchema(model).optional(); + fields['distinct'] = this.makeDistinctSchema(model).optional(); if (collection) { fields['skip'] = z.number().int().nonnegative().optional(); @@ -192,7 +193,7 @@ export class InputValidator { .otherwise(() => z.unknown()); } - protected makeWhereSchema( + private makeWhereSchema( model: string, unique: boolean, withoutRelationFields = false @@ -344,7 +345,7 @@ export class InputValidator { ]); } - protected makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) { + private makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) { return match(type) .with('String', () => this.makeStringFilterSchema(optional)) .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) => @@ -447,7 +448,7 @@ export class InputValidator { ); } - protected makeSelectSchema(model: string) { + private makeSelectSchema(model: string) { const modelDef = requireModel(this.schema, model); const fields: Record = {}; for (const field of Object.keys(modelDef.fields)) { @@ -510,7 +511,7 @@ export class InputValidator { return z.object(fields).strict(); } - protected makeOmitSchema(model: string) { + private makeOmitSchema(model: string) { const modelDef = requireModel(this.schema, model); const fields: Record = {}; for (const field of Object.keys(modelDef.fields)) { @@ -522,7 +523,7 @@ export class InputValidator { return z.object(fields).strict(); } - protected makeIncludeSchema(model: string) { + private makeIncludeSchema(model: string) { const modelDef = requireModel(this.schema, model); const fields: Record = {}; for (const field of Object.keys(modelDef.fields)) { @@ -556,7 +557,7 @@ export class InputValidator { return z.object(fields).strict(); } - protected makeOrderBySchema( + private makeOrderBySchema( model: string, withRelation: boolean, WithAggregation: boolean @@ -617,6 +618,14 @@ export class InputValidator { return z.object(fields); } + private makeDistinctSchema(model: string) { + const modelDef = requireModel(this.schema, model); + const nonRelationFields = Object.keys(modelDef.fields).filter( + (field) => !modelDef.fields[field]?.relation + ); + return this.orArray(z.enum(nonRelationFields as any), true); + } + // #endregion // #region Create diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 270cbd24..19621323 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -249,3 +249,11 @@ export function buildJoinPairs( } }); } + +export function ensureArray(value: T | T[]): T[] { + if (Array.isArray(value)) { + return value; + } else { + return [value]; + } +} diff --git a/packages/runtime/src/schema/schema.ts b/packages/runtime/src/schema/schema.ts index 19a149a4..c504a976 100644 --- a/packages/runtime/src/schema/schema.ts +++ b/packages/runtime/src/schema/schema.ts @@ -121,7 +121,7 @@ export type GetEnum< export type GetFields< Schema extends SchemaDef, Model extends GetModels -> = keyof Schema['models'][Model]['fields']; +> = Extract['fields'], string>; export type GetField< Schema extends SchemaDef, diff --git a/packages/runtime/test/client-api/find.test.ts b/packages/runtime/test/client-api/find.test.ts index 4e568f4b..2f42bf00 100644 --- a/packages/runtime/test/client-api/find.test.ts +++ b/packages/runtime/test/client-api/find.test.ts @@ -154,6 +154,48 @@ describe.each(createClientSpecs(PG_DB_NAME))( ).resolves.toMatchObject(user2); }); + it('works with distinct', async () => { + await createUser(client, 'u1@test.com', { + name: 'Admin1', + role: 'ADMIN', + }); + await createUser(client, 'u3@test.com', { + name: 'User', + role: 'USER', + }); + await createUser(client, 'u2@test.com', { + name: 'Admin2', + role: 'ADMIN', + }); + await createUser(client, 'u4@test.com', { + name: 'User', + role: 'USER', + }); + + // single field distinct + let r = await client.user.findMany({ distinct: ['role'] }); + expect(r).toHaveLength(2); + expect(r).toEqual( + expect.arrayContaining([ + expect.objectContaining({ role: 'ADMIN' }), + expect.objectContaining({ role: 'USER' }), + ]) + ); + + // multiple fields distinct + r = await client.user.findMany({ + distinct: ['role', 'name'], + }); + expect(r).toHaveLength(3); + expect(r).toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'Admin1', role: 'ADMIN' }), + expect.objectContaining({ name: 'Admin2', role: 'ADMIN' }), + expect.objectContaining({ name: 'User', role: 'USER' }), + ]) + ); + }); + it('works with unique finds', async () => { let r = await client.user.findUnique({ where: { id: 'none' } }); expect(r).toBeNull();