diff --git a/packages/runtime/src/enhancements/model-meta.ts b/packages/runtime/src/enhancements/model-meta.ts index 83eef9a64..953ff5ba9 100644 --- a/packages/runtime/src/enhancements/model-meta.ts +++ b/packages/runtime/src/enhancements/model-meta.ts @@ -28,7 +28,7 @@ export function getDefaultModelMeta(): ModelMeta { * Resolves a model field to its metadata. Returns undefined if not found. */ export function resolveField(modelMeta: ModelMeta, model: string, field: string): FieldInfo | undefined { - return modelMeta.fields[lowerCaseFirst(model)][field]; + return modelMeta.fields[lowerCaseFirst(model)]?.[field]; } /** diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index e16008299..76cedd03e 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -19,6 +19,7 @@ import { getFields, resolveField } from '../model-meta'; import { NestedWriteVisitorContext } from '../nested-write-vistor'; import type { InputCheckFunc, ModelMeta, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; import { + enumerate, formatObject, getIdFields, getModelFields, @@ -54,14 +55,28 @@ export class PolicyUtil { * Creates a conjunction of a list of query conditions. */ and(...conditions: (boolean | object | undefined)[]): object { - return this.reduce({ AND: conditions }); + const filtered = conditions.filter((c) => c !== undefined); + if (filtered.length === 0) { + return this.makeTrue(); + } else if (filtered.length === 1) { + return this.reduce(filtered[0]); + } else { + return this.reduce({ AND: filtered }); + } } /** * Creates a disjunction of a list of query conditions. */ or(...conditions: (boolean | object | undefined)[]): object { - return this.reduce({ OR: conditions }); + const filtered = conditions.filter((c) => c !== undefined); + if (filtered.length === 0) { + return this.makeFalse(); + } else if (filtered.length === 1) { + return this.reduce(filtered[0]); + } else { + return this.reduce({ OR: filtered }); + } } /** @@ -116,48 +131,75 @@ export class PolicyUtil { return this.makeFalse(); } - if ('AND' in condition && Array.isArray(condition.AND)) { - const children = condition.AND.map((c: any) => this.reduce(c)).filter( - (c) => c !== undefined && !this.isTrue(c) - ); - if (children.length === 0) { - return this.makeTrue(); - } else if (children.some((c) => this.isFalse(c))) { - return this.makeFalse(); - } else if (children.length === 1) { - return children[0]; - } else { - return { AND: children }; - } + if (condition === null) { + return condition; } - if ('OR' in condition && Array.isArray(condition.OR)) { - const children = condition.OR.map((c: any) => this.reduce(c)).filter( - (c) => c !== undefined && !this.isFalse(c) - ); - if (children.length === 0) { - return this.makeFalse(); - } else if (children.some((c) => this.isTrue(c))) { - return this.makeTrue(); - } else if (children.length === 1) { - return children[0]; - } else { - return { OR: children }; + const result: any = {}; + for (const [key, value] of Object.entries(condition)) { + if (value === null || value === undefined) { + result[key] = value; + continue; } - } - if ('NOT' in condition && condition.NOT !== null && typeof condition.NOT === 'object') { - const child = this.reduce(condition.NOT); - if (this.isTrue(child)) { - return this.makeFalse(); - } else if (this.isFalse(child)) { - return this.makeTrue(); - } else { - return { NOT: child }; + switch (key) { + case 'AND': { + const children = enumerate(value) + .map((c: any) => this.reduce(c)) + .filter((c) => c !== undefined && !this.isTrue(c)); + if (children.length === 0) { + result[key] = []; // true + } else if (children.some((c) => this.isFalse(c))) { + result['OR'] = []; // false + } else { + if (!this.isTrue({ AND: result[key] })) { + // use AND only if it's not already true + result[key] = !Array.isArray(value) && children.length === 1 ? children[0] : children; + } + } + break; + } + + case 'OR': { + const children = enumerate(value) + .map((c: any) => this.reduce(c)) + .filter((c) => c !== undefined && !this.isFalse(c)); + if (children.length === 0) { + result[key] = []; // false + } else if (children.some((c) => this.isTrue(c))) { + result['AND'] = []; // true + } else { + if (!this.isFalse({ OR: result[key] })) { + // use OR only if it's not already false + result[key] = !Array.isArray(value) && children.length === 1 ? children[0] : children; + } + } + break; + } + + case 'NOT': { + result[key] = this.reduce(value); + break; + } + + default: { + const booleanKeys = ['AND', 'OR', 'NOT', 'is', 'isNot', 'none', 'every', 'some']; + if ( + typeof value === 'object' && + value && + // recurse only if the value has at least one boolean key + Object.keys(value).some((k) => booleanKeys.includes(k)) + ) { + result[key] = this.reduce(value); + } else { + result[key] = value; + } + break; + } } } - return condition; + return result; } //#endregion @@ -349,18 +391,18 @@ export class PolicyUtil { operation: PolicyOperationKind ) { const guard = this.getAuthGuard(db, fieldInfo.type, operation); + + // is|isNot and flat fields conditions are mutually exclusive + if (payload.is || payload.isNot) { if (payload.is) { this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation); - // turn "is" into: { is: { AND: [ originalIs, guard ] } - payload.is = this.and(payload.is, guard); } if (payload.isNot) { this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation); - // turn "isNot" into: { isNot: { AND: [ originalIsNot, { NOT: guard } ] } } - payload.isNot = this.and(payload.isNot, this.not(guard)); - delete payload.isNot; } + // merge guard with existing "is": { is: [originalIs, guard] } + payload.is = this.and(payload.is, guard); } else { this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation); // turn direct conditions into: { is: { AND: [ originalConditions, guard ] } } @@ -1062,7 +1104,6 @@ export class PolicyUtil { throw new Error('invalid where clause'); } - extra = this.reduce(extra); if (this.isTrue(extra)) { return; } diff --git a/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts b/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts new file mode 100644 index 000000000..1654fba96 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/query-reduction.test.ts @@ -0,0 +1,147 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('With Policy: query reduction', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('test query reduction', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + role String @default("User") + posts Post[] + private Boolean @default(false) + age Int + + @@allow('all', auth() == this) + @@allow('read', !private) + } + + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + title String + published Boolean @default(false) + viewCount Int @default(0) + + @@allow('all', auth() == user) + @@allow('read', published) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + role: 'User', + age: 18, + posts: { + create: [ + { id: 1, title: 'Post 1' }, + { id: 2, title: 'Post 2', published: true }, + ], + }, + }, + }); + await prisma.user.create({ + data: { + id: 2, + role: 'Admin', + age: 28, + private: true, + posts: { + create: [{ id: 3, title: 'Post 3', viewCount: 100 }], + }, + }, + }); + + const dbUser1 = withPolicy({ id: 1 }); + const dbUser2 = withPolicy({ id: 2 }); + + await expect( + dbUser1.user.findMany({ + where: { id: 2, AND: { age: { gt: 20 } } }, + }) + ).resolves.toHaveLength(0); + + await expect( + dbUser2.user.findMany({ + where: { id: 2, AND: { age: { gt: 20 } } }, + }) + ).resolves.toHaveLength(1); + + await expect( + dbUser1.user.findMany({ + where: { + AND: { age: { gt: 10 } }, + OR: [{ age: { gt: 25 } }, { age: { lt: 20 } }], + NOT: { private: true }, + }, + }) + ).resolves.toHaveLength(1); + + await expect( + dbUser2.user.findMany({ + where: { + AND: { age: { gt: 10 } }, + OR: [{ age: { gt: 25 } }, { age: { lt: 20 } }], + NOT: { private: true }, + }, + }) + ).resolves.toHaveLength(1); + + // to-many relation query + await expect( + dbUser1.user.findMany({ + where: { posts: { some: { published: true } } }, + }) + ).resolves.toHaveLength(1); + await expect( + dbUser1.user.findMany({ + where: { posts: { some: { AND: [{ published: true }, { viewCount: { gt: 0 } }] } } }, + }) + ).resolves.toHaveLength(0); + await expect( + dbUser2.user.findMany({ + where: { posts: { some: { AND: [{ published: false }, { viewCount: { gt: 0 } }] } } }, + }) + ).resolves.toHaveLength(1); + await expect( + dbUser1.user.findMany({ + where: { posts: { every: { published: true } } }, + }) + ).resolves.toHaveLength(0); + await expect( + dbUser1.user.findMany({ + where: { posts: { none: { published: true } } }, + }) + ).resolves.toHaveLength(0); + + // to-one relation query + await expect( + dbUser1.post.findMany({ + where: { user: { role: 'Admin' } }, + }) + ).resolves.toHaveLength(0); + await expect( + dbUser1.post.findMany({ + where: { user: { is: { role: 'Admin' } } }, + }) + ).resolves.toHaveLength(0); + await expect( + dbUser1.post.findMany({ + where: { user: { isNot: { role: 'User' } } }, + }) + ).resolves.toHaveLength(0); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts b/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts index e77b27792..7c26bc854 100644 --- a/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts +++ b/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts @@ -206,7 +206,7 @@ describe('With Policy: relation one-to-one filter', () => { }, }, }) - ).toResolveTruthy(); + ).toResolveFalsy(); // m1 with m2 and m3 await db.m1.create({ @@ -257,7 +257,7 @@ describe('With Policy: relation one-to-one filter', () => { }, }, }) - ).toResolveTruthy(); + ).toResolveFalsy(); }); it('direct object filter', async () => { diff --git a/tests/integration/tests/regression/issue-689.test.ts b/tests/integration/tests/regression/issue-689.test.ts new file mode 100644 index 000000000..32687abca --- /dev/null +++ b/tests/integration/tests/regression/issue-689.test.ts @@ -0,0 +1,71 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Regression: issue 689', () => { + it('regression', async () => { + const { prisma, enhance } = await loadSchema( + ` + model UserRole { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + role String + + @@allow('all', true) + } + + model User { + id Int @id @default(autoincrement()) + userRole UserRole[] + deleted Boolean @default(false) + + @@allow('create,read', true) + @@allow('all', auth() == this) + @@allow('all', userRole?[user == auth() && 'Admin' == role]) + @@allow('read', userRole?[user == auth()]) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + userRole: { + create: [ + { id: 1, role: 'Admin' }, + { id: 2, role: 'Student' }, + ], + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + userRole: { + connect: { id: 1 }, + }, + }, + }); + + const c1 = await prisma.user.count({ + where: { + userRole: { + some: { role: 'Student' }, + }, + NOT: { deleted: true }, + }, + }); + + const db = enhance(); + const c2 = await db.user.count({ + where: { + userRole: { + some: { role: 'Student' }, + }, + NOT: { deleted: true }, + }, + }); + + expect(c1).toEqual(c2); + }); +});