From 6ce9a57908cbdd42de42d8ea66be63d96d3d7221 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 4 Aug 2025 22:06:26 +0800 Subject: [PATCH 1/2] fix: sqlite createMany issue with mismatching columns --- packages/runtime/src/client/contract.ts | 108 +++++++++--------- .../runtime/src/client/crud/dialects/base.ts | 5 + .../src/client/crud/dialects/postgresql.ts | 4 + .../src/client/crud/dialects/sqlite.ts | 4 + .../src/client/crud/operations/base.ts | 33 +++++- .../test/client-api/default-values.test.ts | 32 +++++- 6 files changed, 124 insertions(+), 62 deletions(-) diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index f9182845..48577a32 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -59,7 +59,7 @@ export type ClientContract = { * Executes a prepared raw query and returns the number of affected rows. * @example * ``` - * const result = await client.$executeRaw`UPDATE User SET cool = ${true} WHERE email = ${'user@email.com'};` + * const result = await db.$executeRaw`UPDATE User SET cool = ${true} WHERE email = ${'user@email.com'};` * ``` */ $executeRaw(query: TemplateStringsArray, ...values: any[]): ZenStackPromise; @@ -69,7 +69,7 @@ export type ClientContract = { * This method is susceptible to SQL injections. * @example * ``` - * const result = await client.$executeRawUnsafe('UPDATE User SET cool = $1 WHERE email = $2 ;', true, 'user@email.com') + * const result = await db.$executeRawUnsafe('UPDATE User SET cool = $1 WHERE email = $2 ;', true, 'user@email.com') * ``` */ $executeRawUnsafe(query: string, ...values: any[]): ZenStackPromise; @@ -78,7 +78,7 @@ export type ClientContract = { * Performs a prepared raw query and returns the `SELECT` data. * @example * ``` - * const result = await client.$queryRaw`SELECT * FROM User WHERE id = ${1} OR email = ${'user@email.com'};` + * const result = await db.$queryRaw`SELECT * FROM User WHERE id = ${1} OR email = ${'user@email.com'};` * ``` */ $queryRaw(query: TemplateStringsArray, ...values: any[]): ZenStackPromise; @@ -88,7 +88,7 @@ export type ClientContract = { * This method is susceptible to SQL injections. * @example * ``` - * const result = await client.$queryRawUnsafe('SELECT * FROM User WHERE id = $1 OR email = $2;', 1, 'user@email.com') + * const result = await db.$queryRawUnsafe('SELECT * FROM User WHERE id = $1 OR email = $2;', 1, 'user@email.com') * ``` */ $queryRawUnsafe(query: string, ...values: any[]): ZenStackPromise; @@ -225,17 +225,17 @@ export type ModelOperations` * * // omit fields - * await client.user.findMany({ + * await db.user.findMany({ * omit: { * name: true, * } * }); // result: `Array<{ id: number; email: string; ... }>` * * // include relations (and all scalar fields) - * await client.user.findMany({ + * await db.user.findMany({ * include: { * posts: true, * } * }); // result: `Array<{ ...; posts: Post[] }>` * * // include relations with filter - * await client.user.findMany({ + * await db.user.findMany({ * include: { * posts: { * where: { @@ -268,14 +268,14 @@ export type ModelOperations>( - args?: SelectSubset>, + args: SelectSubset>, ): ZenStackPromise> | null>; /** @@ -319,7 +319,7 @@ export type ModelOperations>( - args?: SelectSubset>, + args: SelectSubset>, ): ZenStackPromise>>; /** @@ -350,12 +350,12 @@ export type ModelOperations` * * // limit the number of updated entities - * await client.user.updateManyAndReturn({ + * await db.user.updateManyAndReturn({ * where: { email: { endsWith: '@zenstack.dev' } }, * data: { role: 'ADMIN' }, * limit: 10 @@ -628,7 +628,7 @@ export type ModelOperations` * * // group by multiple fields - * await client.profile.groupBy({ + * await db.profile.groupBy({ * by: ['country', 'city'], * _count: true * }); // result: `Array<{ country: string, city: string, _count: number }>` * * // group by with sorting, the `orderBy` fields must be in the `by` list - * await client.profile.groupBy({ + * await db.profile.groupBy({ * by: 'country', * orderBy: { country: 'desc' } * }); * * // group by with having (post-aggregation filter), the `having` fields must * // be in the `by` list - * await client.profile.groupBy({ + * await db.profile.groupBy({ * by: 'country', * having: { country: 'US' } * }); diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index a8e5ee39..369c1539 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -967,5 +967,10 @@ export abstract class BaseCrudDialect { */ abstract get supportsDistinctOn(): boolean; + /** + * Whether the dialect support inserting with `DEFAULT` as field value. + */ + abstract get supportInsertWithDefault(): boolean; + // #endregion } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index f91c7aad..f3408820 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -342,4 +342,8 @@ export class PostgresCrudDialect extends BaseCrudDiale return `ARRAY[${values.map((v) => (typeof v === 'string' ? `'${v}'` : v))}]`; } } + + override get supportInsertWithDefault() { + return true; + } } diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 3d74f82c..695795ab 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -282,4 +282,8 @@ export class SqliteCrudDialect extends BaseCrudDialect override buildArrayLiteralSQL(_values: unknown[]): string { throw new Error('SQLite does not support array literals'); } + + override get supportInsertWithDefault() { + return false; + } } diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 12e25891..a163f934 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -519,13 +519,12 @@ export abstract class BaseOperationHandler { const createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); + // let createdEntity: any; // try { // createdEntity = await this.executeQueryTakeFirst(kysely, query, 'create'); // } catch (err) { // const { sql, parameters } = query.compile(); - // throw new QueryError( - // `Error during create: ${err}, sql: ${sql}, parameters: ${parameters}` - // ); + // throw new QueryError(`Error during create: ${err}, sql: ${sql}, parameters: ${parameters}`); // } if (Object.keys(postCreateRelations).length > 0) { @@ -871,6 +870,34 @@ export abstract class BaseOperationHandler { return this.fillGeneratedValues(modelDef, newItem); }); + if (!this.dialect.supportInsertWithDefault) { + // if the dialect doesn't support `DEFAULT` as insert field values, + // we need to double check if data rows have mismatching fields, and + // if so, make sure all fields have default value filled if not provided + const allPassedFields = createData.reduce((acc, item) => { + Object.keys(item).forEach((field) => { + if (!acc.includes(field)) { + acc.push(field); + } + }); + return acc; + }, [] as string[]); + for (const item of createData) { + for (const field of allPassedFields) { + if (!(field in item)) { + const fieldDef = this.requireField(model, field); + if (fieldDef.default !== undefined && typeof fieldDef.default !== 'object') { + item[field] = this.dialect.transformPrimitive( + fieldDef.default, + fieldDef.type as BuiltinType, + !!fieldDef.array, + ); + } + } + } + } + } + if (modelDef.baseModel) { if (input.skipDuplicates) { // TODO: simulate createMany with create in this case diff --git a/packages/runtime/test/client-api/default-values.test.ts b/packages/runtime/test/client-api/default-values.test.ts index 13e49177..1e257a10 100644 --- a/packages/runtime/test/client-api/default-values.test.ts +++ b/packages/runtime/test/client-api/default-values.test.ts @@ -15,10 +15,14 @@ const schema = { Model: { name: 'Model', fields: { + id: { + name: 'id', + type: 'Int', + id: true, + }, uuid: { name: 'uuid', type: 'String', - id: true, default: ExpressionUtils.call('uuid'), }, uuid7: { @@ -56,10 +60,15 @@ const schema = { type: 'DateTime', default: ExpressionUtils.call('now'), }, + bool: { + name: 'bool', + type: 'Boolean', + default: false, + }, }, - idFields: ['uuid'], + idFields: ['id'], uniqueFields: { - uuid: { type: 'String' }, + id: { type: 'Int' }, }, }, }, @@ -67,13 +76,13 @@ const schema = { } as const satisfies SchemaDef; describe('default values tests', () => { - it('supports generators', async () => { + it('supports defaults', async () => { const client = new ZenStackClient(schema, { dialect: new SqliteDialect({ database: new SQLite(':memory:') }), }); await client.$pushSchema(); - const entity = await client.model.create({ data: {} }); + const entity = await client.model.create({ data: { id: 1 } }); expect(entity.uuid).toSatisfy(isValidUuid); expect(entity.uuid7).toSatisfy(isValidUuid); expect(entity.cuid).toSatisfy(isCuid); @@ -82,5 +91,18 @@ describe('default values tests', () => { expect(entity.nanoid8).toSatisfy((id) => id.length === 8); expect(entity.ulid).toSatisfy(isValidUlid); expect(entity.dt).toBeInstanceOf(Date); + + // some fields are set but some use default + await expect( + client.model.createMany({ + data: [{ id: 2 }, { id: 3, bool: true }], + }), + ).toResolveTruthy(); + await expect(client.model.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ + bool: false, + }); + await expect(client.model.findUnique({ where: { id: 3 } })).resolves.toMatchObject({ + bool: true, + }); }); }); From 14ff08118af722613192d4aa5807e58b5e262365 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 4 Aug 2025 22:11:36 +0800 Subject: [PATCH 2/2] perf improvement --- packages/runtime/src/client/crud/operations/base.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index a163f934..bc43e2af 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -883,10 +883,17 @@ export abstract class BaseOperationHandler { return acc; }, [] as string[]); for (const item of createData) { + if (Object.keys(item).length === allPassedFields.length) { + continue; + } for (const field of allPassedFields) { if (!(field in item)) { const fieldDef = this.requireField(model, field); - if (fieldDef.default !== undefined && typeof fieldDef.default !== 'object') { + if ( + fieldDef.default !== undefined && + fieldDef.default !== null && + typeof fieldDef.default !== 'object' + ) { item[field] = this.dialect.transformPrimitive( fieldDef.default, fieldDef.type as BuiltinType,