Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import type { AuthType } from '../schema/auth';
import type { UnwrapTuplePromises } from '../utils/type-utils';
import type { ClientConstructor, ClientContract, ModelOperations, TransactionIsolationLevel } from './contract';
import { AggregateOperationHandler } from './crud/operations/aggregate';
import type { CrudOperation } from './crud/operations/base';
import type { AllCrudOperation, CoreCrudOperation } from './crud/operations/base';
import { BaseOperationHandler } from './crud/operations/base';
import { CountOperationHandler } from './crud/operations/count';
import { CreateOperationHandler } from './crud/operations/create';
Expand Down Expand Up @@ -351,7 +351,8 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
resultProcessor: ResultProcessor<Schema>,
): ModelOperations<Schema, Model> {
const createPromise = (
operation: CrudOperation,
operation: CoreCrudOperation,
nominalOperation: AllCrudOperation,
args: unknown,
handler: BaseOperationHandler<Schema>,
postProcess = false,
Expand Down Expand Up @@ -383,7 +384,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
onQuery({
client,
model,
operation,
operation: nominalOperation,
// reflect the latest override if provided
args: _args,
// ensure inner overrides are propagated to the previous proceed
Expand All @@ -400,6 +401,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
return {
findUnique: (args: unknown) => {
return createPromise(
'findUnique',
'findUnique',
args,
new FindOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -410,6 +412,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
findUniqueOrThrow: (args: unknown) => {
return createPromise(
'findUnique',
'findUniqueOrThrow',
args,
new FindOperationHandler<Schema>(client, model, inputValidator),
true,
Expand All @@ -419,6 +422,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

findFirst: (args: unknown) => {
return createPromise(
'findFirst',
'findFirst',
args,
new FindOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -429,6 +433,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel
findFirstOrThrow: (args: unknown) => {
return createPromise(
'findFirst',
'findFirstOrThrow',
args,
new FindOperationHandler<Schema>(client, model, inputValidator),
true,
Expand All @@ -438,6 +443,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

findMany: (args: unknown) => {
return createPromise(
'findMany',
'findMany',
args,
new FindOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -447,6 +453,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

create: (args: unknown) => {
return createPromise(
'create',
'create',
args,
new CreateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -456,6 +463,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

createMany: (args: unknown) => {
return createPromise(
'createMany',
'createMany',
args,
new CreateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -465,6 +473,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

createManyAndReturn: (args: unknown) => {
return createPromise(
'createManyAndReturn',
'createManyAndReturn',
args,
new CreateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -474,6 +483,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

update: (args: unknown) => {
return createPromise(
'update',
'update',
args,
new UpdateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -483,6 +493,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

updateMany: (args: unknown) => {
return createPromise(
'updateMany',
'updateMany',
args,
new UpdateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -492,6 +503,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

updateManyAndReturn: (args: unknown) => {
return createPromise(
'updateManyAndReturn',
'updateManyAndReturn',
args,
new UpdateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -501,6 +513,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

upsert: (args: unknown) => {
return createPromise(
'upsert',
'upsert',
args,
new UpdateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -510,6 +523,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

delete: (args: unknown) => {
return createPromise(
'delete',
'delete',
args,
new DeleteOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -519,6 +533,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

deleteMany: (args: unknown) => {
return createPromise(
'deleteMany',
'deleteMany',
args,
new DeleteOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -528,6 +543,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

count: (args: unknown) => {
return createPromise(
'count',
'count',
args,
new CountOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -537,6 +553,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

aggregate: (args: unknown) => {
return createPromise(
'aggregate',
'aggregate',
args,
new AggregateOperationHandler<Schema>(client, model, inputValidator),
Expand All @@ -546,6 +563,7 @@ function createModelCrudHandler<Schema extends SchemaDef, Model extends GetModel

groupBy: (args: unknown) => {
return createPromise(
'groupBy',
'groupBy',
args,
new GroupByOperationHandler<Schema>(client, model, inputValidator),
Expand Down
6 changes: 3 additions & 3 deletions packages/runtime/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import { getCrudDialect } from '../dialects';
import type { BaseCrudDialect } from '../dialects/base';
import { InputValidator } from '../validator';

export type CrudOperation =
export type CoreCrudOperation =
| 'findMany'
| 'findUnique'
| 'findFirst'
Expand All @@ -68,7 +68,7 @@ export type CrudOperation =
| 'aggregate'
| 'groupBy';

export type AllCrudOperation = CrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow';
export type AllCrudOperation = CoreCrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow';

export type FromRelationContext<Schema extends SchemaDef> = {
model: GetModels<Schema>;
Expand Down Expand Up @@ -99,7 +99,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return this.client.$qb;
}

abstract handle(operation: CrudOperation, args: any): Promise<unknown>;
abstract handle(operation: CoreCrudOperation, args: any): Promise<unknown>;

withClient(client: ClientContract<Schema>) {
return new (this.constructor as new (...args: any[]) => this)(client, this.model, this.inputValidator);
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/client/crud/operations/find.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import type { GetModels, SchemaDef } from '../../../schema';
import type { FindArgs } from '../../crud-types';
import { BaseOperationHandler, type CrudOperation } from './base';
import { BaseOperationHandler, type CoreCrudOperation } from './base';

export class FindOperationHandler<Schema extends SchemaDef> extends BaseOperationHandler<Schema> {
async handle(operation: CrudOperation, args: unknown, validateArgs = true): Promise<unknown> {
async handle(operation: CoreCrudOperation, args: unknown, validateArgs = true): Promise<unknown> {
// normalize args to strip `undefined` fields
const normalizedArgs = this.normalizeArgs(args);

Expand Down
88 changes: 43 additions & 45 deletions packages/runtime/src/client/crud/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -627,30 +627,32 @@ export class InputValidator<Schema extends SchemaDef> {
}

private makeRelationSelectIncludeSchema(fieldDef: FieldDef) {
return z.union([
z.boolean(),
z.strictObject({
...(fieldDef.array || fieldDef.optional
? {
// to-many relations and optional to-one relations are filterable
where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(),
}
: {}),
select: z.lazy(() => this.makeSelectSchema(fieldDef.type)).optional(),
include: z.lazy(() => this.makeIncludeSchema(fieldDef.type)).optional(),
omit: z.lazy(() => this.makeOmitSchema(fieldDef.type)).optional(),
...(fieldDef.array
? {
// to-many relations can be ordered, skipped, taken, and cursor-located
orderBy: z.lazy(() => this.makeOrderBySchema(fieldDef.type, true, false)).optional(),
skip: this.makeSkipSchema().optional(),
take: this.makeTakeSchema().optional(),
cursor: this.makeCursorSchema(fieldDef.type).optional(),
distinct: this.makeDistinctSchema(fieldDef.type).optional(),
}
: {}),
}),
]);
let objSchema: z.ZodType = z.strictObject({
...(fieldDef.array || fieldDef.optional
? {
// to-many relations and optional to-one relations are filterable
where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(),
}
: {}),
select: z.lazy(() => this.makeSelectSchema(fieldDef.type)).optional(),
include: z.lazy(() => this.makeIncludeSchema(fieldDef.type)).optional(),
omit: z.lazy(() => this.makeOmitSchema(fieldDef.type)).optional(),
...(fieldDef.array
? {
// to-many relations can be ordered, skipped, taken, and cursor-located
orderBy: z.lazy(() => this.makeOrderBySchema(fieldDef.type, true, false)).optional(),
skip: this.makeSkipSchema().optional(),
take: this.makeTakeSchema().optional(),
cursor: this.makeCursorSchema(fieldDef.type).optional(),
distinct: this.makeDistinctSchema(fieldDef.type).optional(),
}
: {}),
});

objSchema = this.refineForSelectIncludeMutuallyExclusive(objSchema);
objSchema = this.refineForSelectOmitMutuallyExclusive(objSchema);

return z.union([z.boolean(), objSchema]);
}

private makeOmitSchema(model: string) {
Expand Down Expand Up @@ -742,7 +744,7 @@ export class InputValidator<Schema extends SchemaDef> {

private makeCreateSchema(model: string) {
const dataSchema = this.makeCreateDataSchema(model, false);
const schema = z.object({
const schema = z.strictObject({
data: dataSchema,
select: this.makeSelectSchema(model).optional(),
include: this.makeIncludeSchema(model).optional(),
Expand All @@ -757,12 +759,10 @@ export class InputValidator<Schema extends SchemaDef> {

private makeCreateManyAndReturnSchema(model: string) {
const base = this.makeCreateManyDataSchema(model, []);
const result = base.merge(
z.strictObject({
select: this.makeSelectSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
}),
);
const result = base.extend({
select: this.makeSelectSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
});
return this.refineForSelectOmitMutuallyExclusive(result).optional();
}

Expand Down Expand Up @@ -986,7 +986,7 @@ export class InputValidator<Schema extends SchemaDef> {
const whereSchema = this.makeWhereSchema(model, true);
const createSchema = this.makeCreateDataSchema(model, false, withoutFields);
return this.orArray(
z.object({
z.strictObject({
where: whereSchema,
create: createSchema,
}),
Expand All @@ -995,7 +995,7 @@ export class InputValidator<Schema extends SchemaDef> {
}

private makeCreateManyDataSchema(model: string, withoutFields: string[]) {
return z.object({
return z.strictObject({
data: this.makeCreateDataSchema(model, true, withoutFields, true),
skipDuplicates: z.boolean().optional(),
});
Expand All @@ -1006,7 +1006,7 @@ export class InputValidator<Schema extends SchemaDef> {
// #region Update

private makeUpdateSchema(model: string) {
const schema = z.object({
const schema = z.strictObject({
where: this.makeWhereSchema(model, true),
data: this.makeUpdateDataSchema(model),
select: this.makeSelectSchema(model).optional(),
Expand All @@ -1017,7 +1017,7 @@ export class InputValidator<Schema extends SchemaDef> {
}

private makeUpdateManySchema(model: string) {
return z.object({
return z.strictObject({
where: this.makeWhereSchema(model, false).optional(),
data: this.makeUpdateDataSchema(model, [], true),
limit: z.int().nonnegative().optional(),
Expand All @@ -1026,17 +1026,15 @@ export class InputValidator<Schema extends SchemaDef> {

private makeUpdateManyAndReturnSchema(model: string) {
const base = this.makeUpdateManySchema(model);
const result = base.merge(
z.strictObject({
select: this.makeSelectSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
}),
);
const result = base.extend({
select: this.makeSelectSchema(model).optional(),
omit: this.makeOmitSchema(model).optional(),
});
return this.refineForSelectOmitMutuallyExclusive(result);
}

private makeUpsertSchema(model: string) {
const schema = z.object({
const schema = z.strictObject({
where: this.makeWhereSchema(model, true),
create: this.makeCreateDataSchema(model, false),
update: this.makeUpdateDataSchema(model),
Expand Down Expand Up @@ -1148,7 +1146,7 @@ export class InputValidator<Schema extends SchemaDef> {
// #region Delete

private makeDeleteSchema(model: GetModels<Schema>) {
const schema = z.object({
const schema = z.strictObject({
where: this.makeWhereSchema(model, true),
select: this.makeSelectSchema(model).optional(),
include: this.makeIncludeSchema(model).optional(),
Expand Down Expand Up @@ -1187,7 +1185,7 @@ export class InputValidator<Schema extends SchemaDef> {
const modelDef = requireModel(this.schema, model);
return z.union([
z.literal(true),
z.object({
z.strictObject({
_all: z.literal(true).optional(),
...Object.keys(modelDef.fields).reduce(
(acc, field) => {
Expand Down Expand Up @@ -1257,7 +1255,7 @@ export class InputValidator<Schema extends SchemaDef> {
const modelDef = requireModel(this.schema, model);
const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation);

let schema = z.object({
let schema = z.strictObject({
where: this.makeWhereSchema(model, false).optional(),
orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(),
by: this.orArray(z.enum(nonRelationFields), true),
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/client/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export function definePlugin<Schema extends SchemaDef>(plugin: RuntimePlugin<Sch
return plugin;
}

export { type CrudOperation } from './crud/operations/base';
export { type CoreCrudOperation as CrudOperation } from './crud/operations/base';

// #region OnQuery hooks

Expand Down
2 changes: 1 addition & 1 deletion packages/zod/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export function makeSelectSchema<Schema extends SchemaDef, Model extends GetMode
schema: Schema,
model: Model,
) {
return z.object(mapFields(schema, model)) as SelectSchema<Schema, typeof model>;
return z.strictObject(mapFields(schema, model)) as SelectSchema<Schema, typeof model>;
}

function mapFields<Schema extends SchemaDef>(schema: Schema, model: GetModels<Schema>): any {
Expand Down