Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
- [x] Filtering
- [x] Sorting
- [x] Pagination
- [ ] Distinct
- [x] Distinct
- [ ] Update
- [x] Input validation
- [x] Top-level
Expand All @@ -52,7 +52,7 @@
- [ ] Extensions
- [x] Query builder API
- [x] Computed fields
- [?] Prisma client extension
- [ ] Prisma client extension
- [ ] Misc
- [ ] Compound ID
- [ ] Cross field comparison
Expand Down
7 changes: 6 additions & 1 deletion packages/runtime/src/client/crud-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ export type SelectIncludeOmit<
omit?: OmitFields<Schema, Model>;
};

type Distinct<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
distinct?: OrArray<NonRelationFields<Schema, Model>>;
};

type Select<
Schema extends SchemaDef,
Model extends GetModels<Schema>,
Expand Down Expand Up @@ -565,7 +569,8 @@ export type FindArgs<
where?: WhereInput<Schema, Model>;
}
: {}) &
SelectIncludeOmit<Schema, Model, Collection>;
SelectIncludeOmit<Schema, Model, Collection> &
Distinct<Schema, Model>;

export type FindUniqueArgs<
Schema extends SchemaDef,
Expand Down
6 changes: 3 additions & 3 deletions packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

abstract buildArrayLiteralSQL(values: unknown[]): string;

get supportsUpdateWithLimit() {
return true;
}
abstract get supportsUpdateWithLimit(): boolean;

abstract get supportsDistinctOn(): boolean;
}
4 changes: 4 additions & 0 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ export class PostgresCrudDialect<
return false;
}

override get supportsDistinctOn(): boolean {
return true;
}

override buildArrayLength(
eb: ExpressionBuilder<any, any>,
array: Expression<unknown>
Expand Down
4 changes: 4 additions & 0 deletions packages/runtime/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ export class SqliteCrudDialect<
return false;
}

override get supportsDistinctOn() {
return false;
}

override buildArrayLength(
eb: ExpressionBuilder<any, any>,
array: Expression<unknown>
Expand Down
33 changes: 32 additions & 1 deletion packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import type { ToKysely } from '../../query-builder';
import {
buildFieldRef,
buildJoinPairs,
ensureArray,
getField,
getIdFields,
getIdValues,
Expand Down Expand Up @@ -154,6 +155,21 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
// skip && take
query = this.dialect.buildSkipTake(query, args?.skip, args?.take);

let inMemoryDistinct: string[] | undefined = undefined;

// distinct
if (args?.distinct) {
const distinct = ensureArray(args.distinct);
if (this.dialect.supportsDistinctOn) {
query = query.distinctOn(
distinct.map((f: any) => sql.ref(`${model}.${f}`))
);
} else {
// in-memory distinct after fetching all results
inMemoryDistinct = distinct;
}
}

// orderBy
if (args?.orderBy) {
query = this.dialect.buildOrderBy(
Expand Down Expand Up @@ -188,7 +204,22 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
);

try {
return await query.execute();
let result = await query.execute();
if (inMemoryDistinct) {
const distinctResult: Record<string, unknown>[] = [];
const seen = new Set<string>();
for (const r of result as any[]) {
const key = JSON.stringify(
inMemoryDistinct.map((f) => r[f])
)!;
if (!seen.has(key)) {
distinctResult.push(r);
seen.add(key);
}
}
result = distinctResult;
}
return result;
} catch (err) {
const { sql, parameters } = query.compile();
throw new QueryError(
Expand Down
6 changes: 5 additions & 1 deletion packages/runtime/src/client/crud/operations/create.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ export class CreateOperationHandler<
select: args.select,
include: args.include,
omit: args.omit,
where: getIdValues(this.schema, this.model, createResult),
where: getIdValues(
this.schema,
this.model,
createResult
) as any,
});
});

Expand Down
21 changes: 15 additions & 6 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ export class InputValidator<Schema extends SchemaDef> {
fields['select'] = this.makeSelectSchema(model).optional();
fields['include'] = this.makeIncludeSchema(model).optional();
fields['omit'] = this.makeOmitSchema(model).optional();
fields['distinct'] = this.makeDistinctSchema(model).optional();

if (collection) {
fields['skip'] = z.number().int().nonnegative().optional();
Expand Down Expand Up @@ -192,7 +193,7 @@ export class InputValidator<Schema extends SchemaDef> {
.otherwise(() => z.unknown());
}

protected makeWhereSchema(
private makeWhereSchema(
model: string,
unique: boolean,
withoutRelationFields = false
Expand Down Expand Up @@ -344,7 +345,7 @@ export class InputValidator<Schema extends SchemaDef> {
]);
}

protected makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) {
private makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) {
return match(type)
.with('String', () => this.makeStringFilterSchema(optional))
.with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) =>
Expand Down Expand Up @@ -447,7 +448,7 @@ export class InputValidator<Schema extends SchemaDef> {
);
}

protected makeSelectSchema(model: string) {
private makeSelectSchema(model: string) {
const modelDef = requireModel(this.schema, model);
const fields: Record<string, ZodSchema> = {};
for (const field of Object.keys(modelDef.fields)) {
Expand Down Expand Up @@ -510,7 +511,7 @@ export class InputValidator<Schema extends SchemaDef> {
return z.object(fields).strict();
}

protected makeOmitSchema(model: string) {
private makeOmitSchema(model: string) {
const modelDef = requireModel(this.schema, model);
const fields: Record<string, ZodSchema> = {};
for (const field of Object.keys(modelDef.fields)) {
Expand All @@ -522,7 +523,7 @@ export class InputValidator<Schema extends SchemaDef> {
return z.object(fields).strict();
}

protected makeIncludeSchema(model: string) {
private makeIncludeSchema(model: string) {
const modelDef = requireModel(this.schema, model);
const fields: Record<string, ZodSchema> = {};
for (const field of Object.keys(modelDef.fields)) {
Expand Down Expand Up @@ -556,7 +557,7 @@ export class InputValidator<Schema extends SchemaDef> {
return z.object(fields).strict();
}

protected makeOrderBySchema(
private makeOrderBySchema(
model: string,
withRelation: boolean,
WithAggregation: boolean
Expand Down Expand Up @@ -617,6 +618,14 @@ export class InputValidator<Schema extends SchemaDef> {
return z.object(fields);
}

private makeDistinctSchema(model: string) {
const modelDef = requireModel(this.schema, model);
const nonRelationFields = Object.keys(modelDef.fields).filter(
(field) => !modelDef.fields[field]?.relation
);
return this.orArray(z.enum(nonRelationFields as any), true);
}

// #endregion

// #region Create
Expand Down
8 changes: 8 additions & 0 deletions packages/runtime/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,11 @@ export function buildJoinPairs(
}
});
}

export function ensureArray<T>(value: T | T[]): T[] {
if (Array.isArray(value)) {
return value;
} else {
return [value];
}
}
2 changes: 1 addition & 1 deletion packages/runtime/src/schema/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ export type GetEnum<
export type GetFields<
Schema extends SchemaDef,
Model extends GetModels<Schema>
> = keyof Schema['models'][Model]['fields'];
> = Extract<keyof GetModel<Schema, Model>['fields'], string>;

export type GetField<
Schema extends SchemaDef,
Expand Down
42 changes: 42 additions & 0 deletions packages/runtime/test/client-api/find.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,48 @@ describe.each(createClientSpecs(PG_DB_NAME))(
).resolves.toMatchObject(user2);
});

it('works with distinct', async () => {
await createUser(client, 'u1@test.com', {
name: 'Admin1',
role: 'ADMIN',
});
await createUser(client, 'u3@test.com', {
name: 'User',
role: 'USER',
});
await createUser(client, 'u2@test.com', {
name: 'Admin2',
role: 'ADMIN',
});
await createUser(client, 'u4@test.com', {
name: 'User',
role: 'USER',
});

// single field distinct
let r = await client.user.findMany({ distinct: ['role'] });
expect(r).toHaveLength(2);
expect(r).toEqual(
expect.arrayContaining([
expect.objectContaining({ role: 'ADMIN' }),
expect.objectContaining({ role: 'USER' }),
])
);

// multiple fields distinct
r = await client.user.findMany({
distinct: ['role', 'name'],
});
expect(r).toHaveLength(3);
expect(r).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'Admin1', role: 'ADMIN' }),
expect.objectContaining({ name: 'Admin2', role: 'ADMIN' }),
expect.objectContaining({ name: 'User', role: 'USER' }),
])
);
});

it('works with unique finds', async () => {
let r = await client.user.findUnique({ where: { id: 'none' } });
expect(r).toBeNull();
Expand Down