diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index d5f65a46..be1f2dff 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -14,7 +14,7 @@ import type { AuthType } from '../schema/auth'; import type { UnwrapTuplePromises } from '../utils/type-utils'; import type { ClientConstructor, ClientContract, ModelOperations, TransactionIsolationLevel } from './contract'; import { AggregateOperationHandler } from './crud/operations/aggregate'; -import type { CrudOperation } from './crud/operations/base'; +import type { AllCrudOperation, CoreCrudOperation } from './crud/operations/base'; import { BaseOperationHandler } from './crud/operations/base'; import { CountOperationHandler } from './crud/operations/count'; import { CreateOperationHandler } from './crud/operations/create'; @@ -351,7 +351,8 @@ function createModelCrudHandler, ): ModelOperations { const createPromise = ( - operation: CrudOperation, + operation: CoreCrudOperation, + nominalOperation: AllCrudOperation, args: unknown, handler: BaseOperationHandler, postProcess = false, @@ -383,7 +384,7 @@ function createModelCrudHandler { return createPromise( + 'findUnique', 'findUnique', args, new FindOperationHandler(client, model, inputValidator), @@ -410,6 +412,7 @@ function createModelCrudHandler { return createPromise( 'findUnique', + 'findUniqueOrThrow', args, new FindOperationHandler(client, model, inputValidator), true, @@ -419,6 +422,7 @@ function createModelCrudHandler { return createPromise( + 'findFirst', 'findFirst', args, new FindOperationHandler(client, model, inputValidator), @@ -429,6 +433,7 @@ function createModelCrudHandler { return createPromise( 'findFirst', + 'findFirstOrThrow', args, new FindOperationHandler(client, model, inputValidator), true, @@ -438,6 +443,7 @@ function createModelCrudHandler { return createPromise( + 'findMany', 'findMany', args, new FindOperationHandler(client, model, inputValidator), @@ -447,6 +453,7 @@ function createModelCrudHandler { return createPromise( + 'create', 'create', args, new CreateOperationHandler(client, model, inputValidator), @@ -456,6 +463,7 @@ function createModelCrudHandler { return createPromise( + 'createMany', 'createMany', args, new CreateOperationHandler(client, model, inputValidator), @@ -465,6 +473,7 @@ function createModelCrudHandler { return createPromise( + 'createManyAndReturn', 'createManyAndReturn', args, new CreateOperationHandler(client, model, inputValidator), @@ -474,6 +483,7 @@ function createModelCrudHandler { return createPromise( + 'update', 'update', args, new UpdateOperationHandler(client, model, inputValidator), @@ -483,6 +493,7 @@ function createModelCrudHandler { return createPromise( + 'updateMany', 'updateMany', args, new UpdateOperationHandler(client, model, inputValidator), @@ -492,6 +503,7 @@ function createModelCrudHandler { return createPromise( + 'updateManyAndReturn', 'updateManyAndReturn', args, new UpdateOperationHandler(client, model, inputValidator), @@ -501,6 +513,7 @@ function createModelCrudHandler { return createPromise( + 'upsert', 'upsert', args, new UpdateOperationHandler(client, model, inputValidator), @@ -510,6 +523,7 @@ function createModelCrudHandler { return createPromise( + 'delete', 'delete', args, new DeleteOperationHandler(client, model, inputValidator), @@ -519,6 +533,7 @@ function createModelCrudHandler { return createPromise( + 'deleteMany', 'deleteMany', args, new DeleteOperationHandler(client, model, inputValidator), @@ -528,6 +543,7 @@ function createModelCrudHandler { return createPromise( + 'count', 'count', args, new CountOperationHandler(client, model, inputValidator), @@ -537,6 +553,7 @@ function createModelCrudHandler { return createPromise( + 'aggregate', 'aggregate', args, new AggregateOperationHandler(client, model, inputValidator), @@ -546,6 +563,7 @@ function createModelCrudHandler { return createPromise( + 'groupBy', 'groupBy', args, new GroupByOperationHandler(client, model, inputValidator), diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index f11ad80c..9765ea59 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -51,7 +51,7 @@ import { getCrudDialect } from '../dialects'; import type { BaseCrudDialect } from '../dialects/base'; import { InputValidator } from '../validator'; -export type CrudOperation = +export type CoreCrudOperation = | 'findMany' | 'findUnique' | 'findFirst' @@ -68,7 +68,7 @@ export type CrudOperation = | 'aggregate' | 'groupBy'; -export type AllCrudOperation = CrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow'; +export type AllCrudOperation = CoreCrudOperation | 'findUniqueOrThrow' | 'findFirstOrThrow'; export type FromRelationContext = { model: GetModels; @@ -99,7 +99,7 @@ export abstract class BaseOperationHandler { return this.client.$qb; } - abstract handle(operation: CrudOperation, args: any): Promise; + abstract handle(operation: CoreCrudOperation, args: any): Promise; withClient(client: ClientContract) { return new (this.constructor as new (...args: any[]) => this)(client, this.model, this.inputValidator); diff --git a/packages/runtime/src/client/crud/operations/find.ts b/packages/runtime/src/client/crud/operations/find.ts index 77bbb615..ef2b60be 100644 --- a/packages/runtime/src/client/crud/operations/find.ts +++ b/packages/runtime/src/client/crud/operations/find.ts @@ -1,9 +1,9 @@ import type { GetModels, SchemaDef } from '../../../schema'; import type { FindArgs } from '../../crud-types'; -import { BaseOperationHandler, type CrudOperation } from './base'; +import { BaseOperationHandler, type CoreCrudOperation } from './base'; export class FindOperationHandler extends BaseOperationHandler { - async handle(operation: CrudOperation, args: unknown, validateArgs = true): Promise { + async handle(operation: CoreCrudOperation, args: unknown, validateArgs = true): Promise { // normalize args to strip `undefined` fields const normalizedArgs = this.normalizeArgs(args); diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index be83d102..eb9ecf21 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -627,30 +627,32 @@ export class InputValidator { } private makeRelationSelectIncludeSchema(fieldDef: FieldDef) { - return z.union([ - z.boolean(), - z.strictObject({ - ...(fieldDef.array || fieldDef.optional - ? { - // to-many relations and optional to-one relations are filterable - where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(), - } - : {}), - select: z.lazy(() => this.makeSelectSchema(fieldDef.type)).optional(), - include: z.lazy(() => this.makeIncludeSchema(fieldDef.type)).optional(), - omit: z.lazy(() => this.makeOmitSchema(fieldDef.type)).optional(), - ...(fieldDef.array - ? { - // to-many relations can be ordered, skipped, taken, and cursor-located - orderBy: z.lazy(() => this.makeOrderBySchema(fieldDef.type, true, false)).optional(), - skip: this.makeSkipSchema().optional(), - take: this.makeTakeSchema().optional(), - cursor: this.makeCursorSchema(fieldDef.type).optional(), - distinct: this.makeDistinctSchema(fieldDef.type).optional(), - } - : {}), - }), - ]); + let objSchema: z.ZodType = z.strictObject({ + ...(fieldDef.array || fieldDef.optional + ? { + // to-many relations and optional to-one relations are filterable + where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(), + } + : {}), + select: z.lazy(() => this.makeSelectSchema(fieldDef.type)).optional(), + include: z.lazy(() => this.makeIncludeSchema(fieldDef.type)).optional(), + omit: z.lazy(() => this.makeOmitSchema(fieldDef.type)).optional(), + ...(fieldDef.array + ? { + // to-many relations can be ordered, skipped, taken, and cursor-located + orderBy: z.lazy(() => this.makeOrderBySchema(fieldDef.type, true, false)).optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + cursor: this.makeCursorSchema(fieldDef.type).optional(), + distinct: this.makeDistinctSchema(fieldDef.type).optional(), + } + : {}), + }); + + objSchema = this.refineForSelectIncludeMutuallyExclusive(objSchema); + objSchema = this.refineForSelectOmitMutuallyExclusive(objSchema); + + return z.union([z.boolean(), objSchema]); } private makeOmitSchema(model: string) { @@ -742,7 +744,7 @@ export class InputValidator { private makeCreateSchema(model: string) { const dataSchema = this.makeCreateDataSchema(model, false); - const schema = z.object({ + const schema = z.strictObject({ data: dataSchema, select: this.makeSelectSchema(model).optional(), include: this.makeIncludeSchema(model).optional(), @@ -757,12 +759,10 @@ export class InputValidator { private makeCreateManyAndReturnSchema(model: string) { const base = this.makeCreateManyDataSchema(model, []); - const result = base.merge( - z.strictObject({ - select: this.makeSelectSchema(model).optional(), - omit: this.makeOmitSchema(model).optional(), - }), - ); + const result = base.extend({ + select: this.makeSelectSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }); return this.refineForSelectOmitMutuallyExclusive(result).optional(); } @@ -986,7 +986,7 @@ export class InputValidator { const whereSchema = this.makeWhereSchema(model, true); const createSchema = this.makeCreateDataSchema(model, false, withoutFields); return this.orArray( - z.object({ + z.strictObject({ where: whereSchema, create: createSchema, }), @@ -995,7 +995,7 @@ export class InputValidator { } private makeCreateManyDataSchema(model: string, withoutFields: string[]) { - return z.object({ + return z.strictObject({ data: this.makeCreateDataSchema(model, true, withoutFields, true), skipDuplicates: z.boolean().optional(), }); @@ -1006,7 +1006,7 @@ export class InputValidator { // #region Update private makeUpdateSchema(model: string) { - const schema = z.object({ + const schema = z.strictObject({ where: this.makeWhereSchema(model, true), data: this.makeUpdateDataSchema(model), select: this.makeSelectSchema(model).optional(), @@ -1017,7 +1017,7 @@ export class InputValidator { } private makeUpdateManySchema(model: string) { - return z.object({ + return z.strictObject({ where: this.makeWhereSchema(model, false).optional(), data: this.makeUpdateDataSchema(model, [], true), limit: z.int().nonnegative().optional(), @@ -1026,17 +1026,15 @@ export class InputValidator { private makeUpdateManyAndReturnSchema(model: string) { const base = this.makeUpdateManySchema(model); - const result = base.merge( - z.strictObject({ - select: this.makeSelectSchema(model).optional(), - omit: this.makeOmitSchema(model).optional(), - }), - ); + const result = base.extend({ + select: this.makeSelectSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }); return this.refineForSelectOmitMutuallyExclusive(result); } private makeUpsertSchema(model: string) { - const schema = z.object({ + const schema = z.strictObject({ where: this.makeWhereSchema(model, true), create: this.makeCreateDataSchema(model, false), update: this.makeUpdateDataSchema(model), @@ -1148,7 +1146,7 @@ export class InputValidator { // #region Delete private makeDeleteSchema(model: GetModels) { - const schema = z.object({ + const schema = z.strictObject({ where: this.makeWhereSchema(model, true), select: this.makeSelectSchema(model).optional(), include: this.makeIncludeSchema(model).optional(), @@ -1187,7 +1185,7 @@ export class InputValidator { const modelDef = requireModel(this.schema, model); return z.union([ z.literal(true), - z.object({ + z.strictObject({ _all: z.literal(true).optional(), ...Object.keys(modelDef.fields).reduce( (acc, field) => { @@ -1257,7 +1255,7 @@ export class InputValidator { const modelDef = requireModel(this.schema, model); const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation); - let schema = z.object({ + let schema = z.strictObject({ where: this.makeWhereSchema(model, false).optional(), orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(), by: this.orArray(z.enum(nonRelationFields), true), diff --git a/packages/runtime/src/client/plugin.ts b/packages/runtime/src/client/plugin.ts index 99ee2d92..b8d6e314 100644 --- a/packages/runtime/src/client/plugin.ts +++ b/packages/runtime/src/client/plugin.ts @@ -46,7 +46,7 @@ export function definePlugin(plugin: RuntimePlugin; + return z.strictObject(mapFields(schema, model)) as SelectSchema; } function mapFields(schema: Schema, model: GetModels): any {