diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 698dcd364..e11379cdf 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -467,7 +467,7 @@ export class PolicyProxyHandler implements Pr // Validates the given create payload against Zod schema if any private validateCreateInputSchema(model: string, data: any) { const schema = this.utils.getZodSchema(model, 'create'); - if (schema) { + if (schema && data) { const parseResult = schema.safeParse(data); if (!parseResult.success) { throw this.utils.deniedByPolicy( @@ -496,11 +496,18 @@ export class PolicyProxyHandler implements Pr args = this.utils.clone(args); - // do static input validation and check if post-create checks are needed + // go through create items, statically check input to determine if post-create + // check is needed, and also validate zod schema let needPostCreateCheck = false; for (const item of enumerate(args.data)) { + const validationResult = this.validateCreateInputSchema(this.model, item); + if (validationResult !== item) { + this.utils.replace(item, validationResult); + } + const inputCheck = this.utils.checkInputGuard(this.model, item, 'create'); if (inputCheck === false) { + // unconditionally deny throw this.utils.deniedByPolicy( this.model, 'create', @@ -508,14 +515,10 @@ export class PolicyProxyHandler implements Pr CrudFailureReason.ACCESS_POLICY_VIOLATION ); } else if (inputCheck === true) { - const r = this.validateCreateInputSchema(this.model, item); - if (r !== item) { - this.utils.replace(item, r); - } + // unconditionally allow } else if (inputCheck === undefined) { // static policy check is not possible, need to do post-create check needPostCreateCheck = true; - break; } } @@ -786,7 +789,13 @@ export class PolicyProxyHandler implements Pr // check if the update actually writes to this model let thisModelUpdate = false; - const updatePayload: any = (args as any).data ?? args; + const updatePayload = (args as any).data ?? args; + + const validatedPayload = this.validateUpdateInputSchema(model, updatePayload); + if (validatedPayload !== updatePayload) { + this.utils.replace(updatePayload, validatedPayload); + } + if (updatePayload) { for (const key of Object.keys(updatePayload)) { const field = resolveField(this.modelMeta, model, key); @@ -857,6 +866,8 @@ export class PolicyProxyHandler implements Pr ); } + args.data = this.validateUpdateInputSchema(model, args.data); + const updateGuard = this.utils.getAuthGuard(db, model, 'update'); if (this.utils.isTrue(updateGuard) || this.utils.isFalse(updateGuard)) { // injects simple auth guard into where clause @@ -917,7 +928,10 @@ export class PolicyProxyHandler implements Pr await _registerPostUpdateCheck(model, uniqueFilter); // convert upsert to update - context.parent.update = { where: args.where, data: args.update }; + context.parent.update = { + where: args.where, + data: this.validateUpdateInputSchema(model, args.update), + }; delete context.parent.upsert; // continue visiting the new payload @@ -1016,6 +1030,37 @@ export class PolicyProxyHandler implements Pr return { result, postWriteChecks }; } + // Validates the given update payload against Zod schema if any + private validateUpdateInputSchema(model: string, data: any) { + const schema = this.utils.getZodSchema(model, 'update'); + if (schema && data) { + // update payload can contain non-literal fields, like: + // { x: { increment: 1 } } + // we should only validate literal fields + + const literalData = Object.entries(data).reduce( + (acc, [k, v]) => ({ ...acc, ...(typeof v !== 'object' ? { [k]: v } : {}) }), + {} + ); + + const parseResult = schema.safeParse(literalData); + if (!parseResult.success) { + throw this.utils.deniedByPolicy( + model, + 'update', + `input failed validation: ${fromZodError(parseResult.error)}`, + CrudFailureReason.DATA_VALIDATION_VIOLATION, + parseResult.error + ); + } + + // schema may have transformed field values, use it to overwrite the original data + return { ...data, ...parseResult.data }; + } else { + return data; + } + } + private isUnsafeMutate(model: string, args: any) { if (!args) { return false; @@ -1046,6 +1091,8 @@ export class PolicyProxyHandler implements Pr args = this.utils.clone(args); this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); + args.data = this.validateUpdateInputSchema(this.model, args.data); + if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) { // use a transaction to do post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index 55b9c5cee..bb505ca55 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -45,7 +45,7 @@ describe('With Policy: field validation', () => { id String @id @default(cuid()) user User @relation(fields: [userId], references: [id]) userId String - slug String @regex("^[0-9a-zA-Z]{4,16}$") + slug String @regex("^[0-9a-zA-Z]{4,16}$") @lower @@allow('all', true) } @@ -508,50 +508,104 @@ describe('With Policy: field validation', () => { }, }); - await expect( - db.userData.create({ - data: { - userId: '1', - a: 1, - b: 0, - c: -1, - d: 0, - text1: 'abc123', - text2: 'def', - text3: 'aaa', - text4: 'abcab', - text6: ' AbC ', - text7: 'abc', + let ud = await db.userData.create({ + data: { + userId: '1', + a: 1, + b: 0, + c: -1, + d: 0, + text1: 'abc123', + text2: 'def', + text3: 'aaa', + text4: 'abcab', + text6: ' AbC ', + text7: 'abc', + }, + }); + expect(ud).toMatchObject({ text6: 'abc', text7: 'ABC' }); + + ud = await db.userData.update({ + where: { id: ud.id }, + data: { + text4: 'xyz', + text6: ' bCD ', + text7: 'bcd', + }, + }); + expect(ud).toMatchObject({ text4: 'xyz', text6: 'bcd', text7: 'BCD' }); + + let u = await db.user.create({ + data: { + id: '2', + password: 'abc123!@#', + email: 'who@myorg.com', + handle: 'user2', + userData: { + create: { + a: 1, + b: 0, + c: -1, + d: 0, + text1: 'abc123', + text2: 'def', + text3: 'aaa', + text4: 'abcab', + text6: ' AbC ', + text7: 'abc', + }, }, - }) - ).resolves.toMatchObject({ text6: 'abc', text7: 'ABC' }); + }, + include: { userData: true }, + }); + expect(u.userData).toMatchObject({ + text6: 'abc', + text7: 'ABC', + }); - await expect( - db.user.create({ - data: { - id: '2', - password: 'abc123!@#', - email: 'who@myorg.com', - handle: 'user2', - userData: { - create: { - a: 1, - b: 0, - c: -1, - d: 0, - text1: 'abc123', - text2: 'def', - text3: 'aaa', - text4: 'abcab', - text6: ' AbC ', - text7: 'abc', - }, + u = await db.user.update({ + where: { id: u.id }, + data: { + userData: { + update: { + data: { text4: 'xyz', text6: ' bCD ', text7: 'bcd' }, }, }, - include: { userData: true }, - }) - ).resolves.toMatchObject({ - userData: expect.objectContaining({ text6: 'abc', text7: 'ABC' }), + }, + include: { userData: true }, + }); + expect(u.userData).toMatchObject({ text4: 'xyz', text6: 'bcd', text7: 'BCD' }); + + // upsert create + u = await db.user.update({ + where: { id: u.id }, + data: { + tasks: { + upsert: { + where: { id: 'unknown' }, + create: { slug: 'SLUG1' }, + update: {}, + }, + }, + }, + include: { tasks: true }, + }); + expect(u.tasks[0]).toMatchObject({ slug: 'slug1' }); + + // upsert update + u = await db.user.update({ + where: { id: u.id }, + data: { + tasks: { + upsert: { + where: { id: u.tasks[0].id }, + create: {}, + update: { slug: 'SLUG2' }, + }, + }, + }, + include: { tasks: true }, }); + expect(u.tasks[0]).toMatchObject({ slug: 'slug2' }); }); }); diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index 126c038fa..6a329a739 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -144,12 +144,15 @@ describe('With Policy: refactor tests', () => { // read back check await expect( anonDb.user.create({ - data: { id: 1, email: 'user1@zenstack.dev' }, + data: { id: 1, email: 'User1@zenstack.dev' }, }) ).rejects.toThrow(/not allowed to be read back/); // success - await expect(user1Db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(user1Db.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ + // email to lower + email: 'user1@zenstack.dev', + }); // nested creation failure await expect( @@ -202,7 +205,7 @@ describe('With Policy: refactor tests', () => { posts: { create: { id: 2, - title: 'Post 2', + title: ' Post 2 ', published: true, comments: { create: { @@ -213,8 +216,14 @@ describe('With Policy: refactor tests', () => { }, }, }, + include: { posts: true }, }) - ).toResolveTruthy(); + ).resolves.toMatchObject({ + posts: expect.arrayContaining([ + // title is trimmed + expect.objectContaining({ title: 'Post 2' }), + ]), + }); // create with connect: posts await expect( @@ -389,7 +398,7 @@ describe('With Policy: refactor tests', () => { data: [ { id: 7, title: 'Post 7.1' }, { id: 7, title: 'Post 7.2' }, - { id: 8, title: 'Post 8' }, + { id: 8, title: ' Post 8 ' }, ], skipDuplicates: true, }, @@ -400,7 +409,10 @@ describe('With Policy: refactor tests', () => { // success await expect(adminDb.user.findUnique({ where: { id: 7 } })).toResolveTruthy(); await expect(adminDb.post.findUnique({ where: { id: 7 } })).toResolveTruthy(); - await expect(adminDb.post.findUnique({ where: { id: 8 } })).toResolveTruthy(); + await expect(adminDb.post.findUnique({ where: { id: 8 } })).resolves.toMatchObject({ + // title is trimmed + title: 'Post 8', + }); }); it('createMany', async () => { @@ -412,11 +424,18 @@ describe('With Policy: refactor tests', () => { await expect( user1Db.post.createMany({ data: [ - { id: 1, title: 'Post 1', authorId: 1 }, + { id: 1, title: ' Post 1 ', authorId: 1 }, { id: 2, title: 'Post 2', authorId: 1 }, ], }) - ).resolves.toMatchObject({ count: 2 }); + ).toResolveTruthy(); + + await expect(user1Db.post.findMany()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ title: 'Post 1' }), // title is trimmed + expect.objectContaining({ title: 'Post 2' }), + ]) + ); // unique constraint violation await expect( @@ -502,8 +521,8 @@ describe('With Policy: refactor tests', () => { user2Db.user.update({ where: { id: 1 }, data: { email: 'user2@zenstack.dev' } }) ).toBeRejectedByPolicy(); await expect( - adminDb.user.update({ where: { id: 1 }, data: { email: 'user1-nice@zenstack.dev' } }) - ).toResolveTruthy(); + adminDb.user.update({ where: { id: 1 }, data: { email: 'User1-nice@zenstack.dev' } }) + ).resolves.toMatchObject({ email: 'user1-nice@zenstack.dev' }); // update nested profile await expect( @@ -561,9 +580,10 @@ describe('With Policy: refactor tests', () => { await expect( user1Db.user.update({ where: { id: 1 }, - data: { posts: { update: { where: { id: 1 }, data: { published: false } } } }, + data: { posts: { update: { where: { id: 1 }, data: { title: ' New ', published: false } } } }, + include: { posts: true }, }) - ).toResolveTruthy(); + ).resolves.toMatchObject({ posts: expect.arrayContaining([expect.objectContaining({ title: 'New' })]) }); // update nested comment prevent update of toplevel await expect( @@ -588,23 +608,24 @@ describe('With Policy: refactor tests', () => { await expect(adminDb.comment.findFirst({ where: { content: 'Comment 2 updated' } })).toResolveFalsy(); // update with create - await expect( - user1Db.user.update({ - where: { id: 1 }, - data: { - posts: { - create: { - id: 3, - title: 'Post 3', - published: true, - comments: { - create: { author: { connect: { id: 1 } }, content: 'Comment 3' }, - }, + const r1 = await user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + create: { + id: 3, + title: 'Post 3', + published: true, + comments: { + create: { author: { connect: { id: 1 } }, content: ' Comment 3 ' }, }, }, }, - }) - ).toResolveTruthy(); + }, + include: { posts: { include: { comments: true } } }, + }); + expect(r1.posts[r1.posts.length - 1].comments[0].content).toEqual('Comment 3'); + await expect( user1Db.user.update({ where: { id: 1 }, @@ -636,7 +657,7 @@ describe('With Policy: refactor tests', () => { posts: { createMany: { data: [ - { id: 4, title: 'Post 4' }, + { id: 4, title: ' Post 4 ' }, { id: 5, title: 'Post 5' }, ], }, @@ -644,6 +665,7 @@ describe('With Policy: refactor tests', () => { }, }) ).toResolveTruthy(); + await expect(user1Db.post.findUnique({ where: { id: 4 } })).resolves.toMatchObject({ title: 'Post 4' }); await expect( user1Db.user.update({ include: { posts: true }, @@ -723,12 +745,13 @@ describe('With Policy: refactor tests', () => { posts: { update: { where: { id: 1 }, - data: { title: 'Post1-1' }, + data: { title: ' Post1-1' }, }, }, }, }) ).toResolveTruthy(); + await expect(user1Db.post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ title: 'Post1-1' }); await expect( user1Db.user.update({ where: { id: 1 }, @@ -799,14 +822,14 @@ describe('With Policy: refactor tests', () => { posts: { upsert: { where: { id: 1 }, - update: { title: 'Post 1-1' }, // update + update: { title: ' Post 2' }, // update create: { id: 7, title: 'Post 1' }, }, }, }, }) ).toResolveTruthy(); - await expect(user1Db.post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ title: 'Post 1-1' }); + await expect(user1Db.post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ title: 'Post 2' }); await expect( user1Db.user.update({ where: { id: 1 }, @@ -815,7 +838,7 @@ describe('With Policy: refactor tests', () => { upsert: { where: { id: 7 }, update: { title: 'Post 7-1' }, - create: { id: 7, title: 'Post 7' }, // create + create: { id: 7, title: ' Post 7' }, // create }, }, }, @@ -1094,9 +1117,10 @@ describe('With Policy: refactor tests', () => { ).toBeRejectedByPolicy(); await expect( user1Db.post.updateMany({ - data: { title: 'My post' }, + data: { title: ' My post' }, }) ).resolves.toMatchObject({ count: 2 }); + await expect(user1Db.post.findFirst()).resolves.toMatchObject({ title: 'My post' }); }); it('delete single', async () => { diff --git a/tests/integration/tests/schema/refactor-pg.zmodel b/tests/integration/tests/schema/refactor-pg.zmodel index f52f36c98..d0b4579e1 100644 --- a/tests/integration/tests/schema/refactor-pg.zmodel +++ b/tests/integration/tests/schema/refactor-pg.zmodel @@ -5,7 +5,7 @@ enum Role { model User { id Int @id @default(autoincrement()) - email String @unique @email + email String @unique @email @lower role Role @default(USER) profile Profile? posts Post[] @@ -52,7 +52,7 @@ model Image { model Post { id Int @id @default(autoincrement()) - title String @length(1, 8) + title String @length(1, 8) @trim published Boolean @default(false) comments Comment[] author User @relation(fields: [authorId], references: [id], onDelete: Cascade) @@ -67,7 +67,7 @@ model Post { model Comment { id Int @id @default(autoincrement()) - content String + content String @trim author User @relation(fields: [authorId], references: [id], onDelete: Cascade) authorId Int