From 15a03e9a170b0c60b00cb2271ecfda330a8952ac Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 25 May 2025 23:11:20 -0700 Subject: [PATCH] upsert support and fixes about "omit" --- TODO.md | 2 +- packages/runtime/src/client/client-impl.ts | 13 +++- packages/runtime/src/client/crud-types.ts | 58 ++++++++++---- .../src/client/crud/operations/base.ts | 9 ++- .../src/client/crud/operations/create.ts | 1 + .../src/client/crud/operations/delete.ts | 1 + .../src/client/crud/operations/update.ts | 47 ++++++++++- packages/runtime/src/client/crud/validator.ts | 38 +++++++-- .../runtime/test/client-api/create.test.ts | 27 +++++++ .../runtime/test/client-api/upsert.test.ts | 77 +++++++++++++++++++ 10 files changed, 245 insertions(+), 28 deletions(-) create mode 100644 packages/runtime/test/client-api/upsert.test.ts diff --git a/TODO.md b/TODO.md index ed5383c3..b3e6196e 100644 --- a/TODO.md +++ b/TODO.md @@ -43,7 +43,7 @@ - [x] Nested to-one - [ ] Delta update for numeric fields - [ ] Array update - - [ ] Upsert + - [x] Upsert - [x] Delete - [ ] Aggregation - [x] Count diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index 883473d9..4109c837 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -266,7 +266,7 @@ function createModelCrudHandler< args: unknown, handler: BaseOperationHandler, postProcess = false, - throwIfNotFound = false + throwIfNoResult = false ) => { return createDeferredPromise(async () => { let proceed = async ( @@ -275,7 +275,7 @@ function createModelCrudHandler< ) => { const _handler = tx ? handler.withClient(tx) : handler; const r = await _handler.handle(operation, _args ?? args); - if (!r && throwIfNotFound) { + if (!r && throwIfNoResult) { throw new NotFoundError(model); } let result: unknown; @@ -400,6 +400,15 @@ function createModelCrudHandler< ); }, + upsert: (args: unknown) => { + return createPromise( + 'upsert', + args, + new UpdateOperationHandler(client, model, inputValidator), + true + ); + }, + delete: (args: unknown) => { return createPromise( 'delete', diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index a8dd78ea..8ebcef2d 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -39,29 +39,35 @@ import type { ToKyselySchema } from './query-builder'; type DefaultModelResult< Schema extends SchemaDef, Model extends GetModels, + Omit = undefined, Optional = false, Array = false > = WrapType< { - [Key in NonRelationFields]: MapFieldType< - Schema, - Model, - Key - >; + [Key in NonRelationFields as Key extends keyof Omit + ? Omit[Key] extends true + ? never + : Key + : Key]: MapFieldType; }, Optional, Array >; type ModelSelectResult< - Select, Schema extends SchemaDef, - Model extends GetModels + Model extends GetModels, + Select, + Omit > = { [Key in keyof Select & GetFields as Select[Key] extends | false | undefined ? never + : Key extends keyof Omit + ? Omit[Key] extends true + ? never + : Key : Key]: Key extends ScalarFields ? MapFieldType : Key extends RelationFields @@ -80,6 +86,7 @@ type ModelSelectResult< : DefaultModelResult< Schema, RelationFieldType, + Omit, FieldIsOptional, FieldIsArray > @@ -89,18 +96,20 @@ type ModelSelectResult< export type ModelResult< Schema extends SchemaDef, Model extends GetModels, - Args extends SelectInclude = {}, + Args extends SelectIncludeOmit = {}, Optional = false, Array = false > = WrapType< Args extends { select: infer S; + omit?: infer O; } - ? ModelSelectResult + ? ModelSelectResult : Args extends { include: infer I; + omit?: infer O; } - ? DefaultModelResult & { + ? DefaultModelResult & { [Key in keyof I & RelationFields as I[Key] extends | false | undefined @@ -124,6 +133,8 @@ export type ModelResult< FieldIsArray >; } + : Args extends { omit: infer O } + ? DefaultModelResult : DefaultModelResult, Optional, Array @@ -311,7 +322,7 @@ type OmitFields> = { [Key in ScalarFields]?: true; }; -export type SelectInclude< +export type SelectIncludeOmit< Schema extends SchemaDef, Model extends GetModels, AllowCount extends boolean @@ -528,14 +539,14 @@ export type FindArgs< where?: Where; } : {}) & - SelectInclude; + SelectIncludeOmit; export type FindUniqueArgs< Schema extends SchemaDef, Model extends GetModels > = { where?: WhereUnique; -} & SelectInclude; +} & SelectIncludeOmit; //#endregion @@ -548,6 +559,7 @@ export type CreateArgs< data: CreateInput; select?: Select; include?: Include; + omit?: OmitFields; }; export type CreateManyArgs< @@ -559,7 +571,7 @@ export type CreateManyAndReturnArgs< Schema extends SchemaDef, Model extends GetModels > = CreateManyPayload & - Omit, 'include'>; + Omit, 'include'>; type OptionalWrap< Schema extends SchemaDef, @@ -678,6 +690,7 @@ export type UpdateArgs< where: WhereUnique; select?: Select; include?: Include; + omit?: OmitFields; }; export type UpdateManyArgs< @@ -689,6 +702,18 @@ export type UpdateManyArgs< limit?: number; }; +export type UpsertArgs< + Schema extends SchemaDef, + Model extends GetModels +> = { + create: CreateInput; + update: UpdateInput; + where: WhereUnique; + select?: Select; + include?: Include; + omit?: OmitFields; +}; + export type UpdateScalarInput< Schema extends SchemaDef, Model extends GetModels, @@ -761,6 +786,7 @@ export type DeleteArgs< where: WhereUnique; select?: Select; include?: Include; + omit?: OmitFields; }; export type DeleteManyArgs< @@ -1104,6 +1130,10 @@ export type ModelOperations< args: Subset> ): Promise; + upsert>( + args: SelectSubset> + ): Promise>; + delete>( args: SelectSubset> ): Promise>; diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 6b8ca7ee..3ff3821c 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -29,7 +29,7 @@ import { } from '../../../utils/object-utils'; import { CONTEXT_COMMENT_PREFIX } from '../../constants'; import type { CRUD } from '../../contract'; -import type { FindArgs, SelectInclude, Where } from '../../crud-types'; +import type { FindArgs, SelectIncludeOmit, Where } from '../../crud-types'; import { InternalError, NotFoundError, QueryError } from '../../errors'; import type { ToKysely } from '../../query-builder'; import { @@ -59,6 +59,7 @@ export type CrudOperation = | 'createManyAndReturn' | 'update' | 'updateMany' + | 'upsert' | 'delete' | 'deleteMany' | 'count' @@ -165,8 +166,10 @@ export abstract class BaseOperationHandler { // select if (args?.select) { + // select is mutually exclusive with omit query = this.buildFieldSelection(model, query, args?.select, model); } else { + // include all scalar fields except those in omit query = this.buildSelectAllScalarFields(model, query, args?.omit); } @@ -1783,7 +1786,7 @@ export abstract class BaseOperationHandler { protected trimResult( data: any, - args: SelectInclude, boolean> + args: SelectIncludeOmit, boolean> ) { if (!args.select) { return data; @@ -1796,7 +1799,7 @@ export abstract class BaseOperationHandler { protected needReturnRelations( model: string, - args: SelectInclude, boolean> + args: SelectIncludeOmit, boolean> ) { let returnRelation = false; diff --git a/packages/runtime/src/client/crud/operations/create.ts b/packages/runtime/src/client/crud/operations/create.ts index a782f24a..5ec8bf9f 100644 --- a/packages/runtime/src/client/crud/operations/create.ts +++ b/packages/runtime/src/client/crud/operations/create.ts @@ -45,6 +45,7 @@ export class CreateOperationHandler< return this.readUnique(tx, this.model, { select: args.select, include: args.include, + omit: args.omit, where: getIdValues(this.schema, this.model, createResult), }); }); diff --git a/packages/runtime/src/client/crud/operations/delete.ts b/packages/runtime/src/client/crud/operations/delete.ts index 3e642d10..6d712bbd 100644 --- a/packages/runtime/src/client/crud/operations/delete.ts +++ b/packages/runtime/src/client/crud/operations/delete.ts @@ -31,6 +31,7 @@ export class DeleteOperationHandler< const existing = await this.readUnique(this.kysely, this.model, { select: args.select, include: args.include, + omit: args.omit, where: args.where, }); if (!existing) { diff --git a/packages/runtime/src/client/crud/operations/update.ts b/packages/runtime/src/client/crud/operations/update.ts index 18c86ce5..7478b561 100644 --- a/packages/runtime/src/client/crud/operations/update.ts +++ b/packages/runtime/src/client/crud/operations/update.ts @@ -1,14 +1,14 @@ import { match } from 'ts-pattern'; import { RejectedByPolicyError } from '../../../plugins/policy/errors'; import type { GetModels, SchemaDef } from '../../../schema'; -import type { UpdateArgs, UpdateManyArgs } from '../../crud-types'; -import { BaseOperationHandler } from './base'; +import type { UpdateArgs, UpdateManyArgs, UpsertArgs } from '../../crud-types'; import { getIdValues } from '../../query-utils'; +import { BaseOperationHandler } from './base'; export class UpdateOperationHandler< Schema extends SchemaDef > extends BaseOperationHandler { - async handle(operation: 'update' | 'updateMany', args: unknown) { + async handle(operation: 'update' | 'updateMany' | 'upsert', args: unknown) { return match(operation) .with('update', () => this.runUpdate( @@ -20,6 +20,11 @@ export class UpdateOperationHandler< this.inputValidator.validateUpdateManyArgs(this.model, args) ) ) + .with('upsert', () => + this.runUpsert( + this.inputValidator.validateUpsertArgs(this.model, args) + ) + ) .exhaustive(); } @@ -34,6 +39,7 @@ export class UpdateOperationHandler< return this.readUnique(tx, this.model, { select: args.select, include: args.include, + omit: args.omit, where: getIdValues(this.schema, this.model, updated), }); }); @@ -59,4 +65,39 @@ export class UpdateOperationHandler< args.limit ); } + + private async runUpsert(args: UpsertArgs>) { + const result = await this.safeTransaction(async (tx) => { + let mutationResult = await this.update( + tx, + this.model, + args.where, + args.update, + undefined, + true, + false + ); + + if (!mutationResult) { + // non-existing, create + mutationResult = await this.create(tx, this.model, args.create); + } + + return this.readUnique(tx, this.model, { + select: args.select, + include: args.include, + omit: args.omit, + where: getIdValues(this.schema, this.model, mutationResult), + }); + }); + + if (!result) { + throw new RejectedByPolicyError( + this.model, + 'result is not allowed to be read back' + ); + } + + return result; + } } diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index 0232af80..d147e678 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -19,6 +19,7 @@ import { type FindArgs, type UpdateArgs, type UpdateManyArgs, + type UpsertArgs, } from '../crud-types'; import { InternalError, QueryError } from '../errors'; import { @@ -83,6 +84,14 @@ export class InputValidator { ); } + validateUpsertArgs(model: GetModels, args: unknown) { + return this.validate>>( + this.makeUpsertSchema(model), + 'upsert', + args + ); + } + validateDeleteArgs(model: GetModels, args: unknown) { return this.validate>>( this.makeDeleteSchema(model), @@ -576,13 +585,15 @@ export class InputValidator { private makeCreateSchema(model: string) { const dataSchema = this.makeCreateDataSchema(model, false); - return z + const schema = z .object({ data: dataSchema, - select: z.record(z.string(), z.any()).optional(), - include: z.record(z.string(), z.any()).optional(), + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), }) .strict(); + return this.refineForSelectIncludeMutuallyExclusive(schema); } private makeCreateManySchema(model: string) { @@ -595,6 +606,7 @@ export class InputValidator { .merge( z.object({ select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), omit: this.makeOmitSchema(model).optional(), }) ) @@ -896,8 +908,9 @@ export class InputValidator { .object({ where: this.makeWhereSchema(model, true), data: this.makeUpdateDataSchema(model), - select: z.record(z.string(), z.any()).optional(), - include: z.record(z.string(), z.any()).optional(), + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), }) .strict(); @@ -914,6 +927,21 @@ export class InputValidator { .strict(); } + private makeUpsertSchema(model: string) { + const schema = z + .object({ + where: this.makeWhereSchema(model, true), + create: this.makeCreateDataSchema(model, false), + update: this.makeUpdateDataSchema(model), + select: this.makeSelectSchema(model).optional(), + include: this.makeIncludeSchema(model).optional(), + omit: this.makeOmitSchema(model).optional(), + }) + .strict(); + + return this.refineForSelectIncludeMutuallyExclusive(schema); + } + private makeUpdateDataSchema( model: string, withoutFields: string[] = [], diff --git a/packages/runtime/test/client-api/create.test.ts b/packages/runtime/test/client-api/create.test.ts index 1b97ed17..40d89bfc 100644 --- a/packages/runtime/test/client-api/create.test.ts +++ b/packages/runtime/test/client-api/create.test.ts @@ -32,6 +32,33 @@ describe.each(createClientSpecs(PG_DB_NAME))( email: 'u1@test.com', name: 'name', }); + + const user2 = await client.user.create({ + data: { + email: 'u2@test.com', + name: 'name', + }, + omit: { name: true }, + }); + expect(user2.email).toBe('u2@test.com'); + expect((user2 as any).name).toBeUndefined(); + // @ts-expect-error + console.log(user2.name); + + const user3 = await client.user.create({ + data: { + email: 'u3@test.com', + name: 'name', + posts: { create: { title: 'Post1' } }, + }, + include: { posts: true }, + omit: { name: true }, + }); + expect(user3.email).toBe('u3@test.com'); + expect(user3.posts).toHaveLength(1); + expect((user3 as any).name).toBeUndefined(); + // @ts-expect-error + console.log(user3.name); }); it('works with nested relation one-to-one, owner side', async () => { diff --git a/packages/runtime/test/client-api/upsert.test.ts b/packages/runtime/test/client-api/upsert.test.ts new file mode 100644 index 00000000..8055d0e2 --- /dev/null +++ b/packages/runtime/test/client-api/upsert.test.ts @@ -0,0 +1,77 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import type { ClientContract } from '../../src/client'; +import { schema } from '../test-schema'; +import { createClientSpecs } from './client-specs'; + +const PG_DB_NAME = 'client-api-upsert-tests'; + +describe.each(createClientSpecs(PG_DB_NAME))( + 'Client upsert tests', + ({ createClient }) => { + let client: ClientContract; + + beforeEach(async () => { + client = await createClient(); + await client.$pushSchema(); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with toplevel upsert', async () => { + // create + await expect( + client.user.upsert({ + where: { id: '1' }, + create: { + id: '1', + email: 'u1@test.com', + name: 'New', + profile: { create: { bio: 'My bio' } }, + }, + update: { name: 'Foo' }, + include: { profile: true }, + }) + ).resolves.toMatchObject({ + id: '1', + name: 'New', + profile: { bio: 'My bio' }, + }); + + // update + await expect( + client.user.upsert({ + where: { id: '1' }, + create: { + id: '2', + email: 'u2@test.com', + name: 'New', + }, + update: { name: 'Updated' }, + include: { profile: true }, + }) + ).resolves.toMatchObject({ + id: '1', + name: 'Updated', + profile: { bio: 'My bio' }, + }); + + // id update + await expect( + client.user.upsert({ + where: { id: '1' }, + create: { + id: '2', + email: 'u2@test.com', + name: 'New', + }, + update: { id: '3' }, + }) + ).resolves.toMatchObject({ + id: '3', + name: 'Updated', + }); + }); + } +);