diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 65d0d32b..6d5cae1b 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -1009,10 +1009,7 @@ export abstract class BaseOperationHandler { throw new QueryError(`Relation update not allowed for field "${field}"`); } if (!thisEntity) { - thisEntity = await this.readUnique(kysely, model, { - where: combinedWhere, - select: this.makeIdSelect(model), - }); + thisEntity = await this.getEntityIds(kysely, model, combinedWhere); if (!thisEntity) { if (throwIfNotFound) { throw new NotFoundError(model); diff --git a/packages/runtime/src/plugins/policy/expression-transformer.ts b/packages/runtime/src/plugins/policy/expression-transformer.ts index f6e35a55..2c39cbb7 100644 --- a/packages/runtime/src/plugins/policy/expression-transformer.ts +++ b/packages/runtime/src/plugins/policy/expression-transformer.ts @@ -25,7 +25,7 @@ import { getCrudDialect } from '../../client/crud/dialects'; import type { BaseCrudDialect } from '../../client/crud/dialects/base-dialect'; import { InternalError, QueryError } from '../../client/errors'; import type { ClientOptions } from '../../client/options'; -import { getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; +import { getIdFields, getModel, getRelationForeignKeyFieldPairs, requireField } from '../../client/query-utils'; import type { BinaryExpression, BinaryOperator, @@ -111,7 +111,6 @@ export class ExpressionTransformer { } @expr('field') - // @ts-expect-error private _field(expr: FieldExpression, context: ExpressionTransformerContext) { const fieldDef = requireField(this.schema, context.model, expr.field); if (!fieldDef.relation) { @@ -162,8 +161,9 @@ export class ExpressionTransformer { return this.transformCollectionPredicate(expr, context); } - const left = this.transform(expr.left, context); - const right = this.transform(expr.right, context); + const { normalizedLeft, normalizedRight } = this.normalizeBinaryOperationOperands(expr, context); + const left = this.transform(normalizedLeft, context); + const right = this.transform(normalizedRight, context); if (op === 'in') { if (this.isNullNode(left)) { @@ -195,6 +195,22 @@ export class ExpressionTransformer { return BinaryOperationNode.create(left, this.transformOperator(op), right); } + private normalizeBinaryOperationOperands(expr: BinaryExpression, context: ExpressionTransformerContext) { + let normalizedLeft: Expression = expr.left; + if (this.isRelationField(expr.left, context.model)) { + invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field'); + const idFields = getIdFields(this.schema, context.model); + normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!); + } + let normalizedRight: Expression = expr.right; + if (this.isRelationField(expr.right, context.model)) { + invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field'); + const idFields = getIdFields(this.schema, context.model); + normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!); + } + return { normalizedLeft, normalizedRight }; + } + private transformCollectionPredicate(expr: BinaryExpression, context: ExpressionTransformerContext) { invariant(expr.op === '?' || expr.op === '!' || expr.op === '^', 'expected "?" or "!" or "^" operator'); @@ -211,11 +227,15 @@ export class ExpressionTransformer { ); let newContextModel: string; - if (ExpressionUtils.isField(expr.left)) { - const fieldDef = requireField(this.schema, context.model, expr.left.field); + const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.model); + if (fieldDef) { + invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`); newContextModel = fieldDef.type; } else { - invariant(ExpressionUtils.isField(expr.left.receiver)); + invariant( + ExpressionUtils.isMember(expr.left) && ExpressionUtils.isField(expr.left.receiver), + 'left operand must be member access with field receiver', + ); const fieldDef = requireField(this.schema, context.model, expr.left.receiver.field); newContextModel = fieldDef.type; for (const member of expr.left.members) { @@ -396,16 +416,14 @@ export class ExpressionTransformer { if (ExpressionUtils.isThis(expr.receiver)) { if (expr.members.length === 1) { - // optimize for the simple this.scalar case - const fieldDef = requireField(this.schema, context.model, expr.members[0]!); - invariant(!fieldDef.relation, 'this.relation access should have been transformed into relation access'); - return this.createColumnRef(expr.members[0]!, restContext); + // `this.relation` case, equivalent to field access + return this._field(ExpressionUtils.field(expr.members[0]!), context); + } else { + // transform the first segment into a relation access, then continue with the rest of the members + const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!); + receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); + members = expr.members.slice(1); } - - // transform the first segment into a relation access, then continue with the rest of the members - const firstMemberFieldDef = requireField(this.schema, context.model, expr.members[0]!); - receiver = this.transformRelationAccess(expr.members[0]!, firstMemberFieldDef.type, restContext); - members = expr.members.slice(1); } else { receiver = this.transform(expr.receiver, restContext); } @@ -559,4 +577,23 @@ export class ExpressionTransformer { return conditions.reduce((acc, condition) => ExpressionUtils.binary(acc, '&&', condition)); } } + + private isRelationField(expr: Expression, model: GetModels) { + const fieldDef = this.getFieldDefFromFieldRef(expr, model); + return !!fieldDef?.relation; + } + + private getFieldDefFromFieldRef(expr: Expression, model: GetModels): FieldDef | undefined { + if (ExpressionUtils.isField(expr)) { + return requireField(this.schema, model, expr.field); + } else if ( + ExpressionUtils.isMember(expr) && + expr.members.length === 1 && + ExpressionUtils.isThis(expr.receiver) + ) { + return requireField(this.schema, model, expr.members[0]!); + } else { + return undefined; + } + } } diff --git a/packages/runtime/src/plugins/policy/policy-handler.ts b/packages/runtime/src/plugins/policy/policy-handler.ts index fcbabe43..6d018980 100644 --- a/packages/runtime/src/plugins/policy/policy-handler.ts +++ b/packages/runtime/src/plugins/policy/policy-handler.ts @@ -102,6 +102,7 @@ export class PolicyHandler extends OperationNodeTransf } return readBackResult; } else { + // reading id fields bypasses policy return result; } diff --git a/packages/runtime/test/policy/crud/read.test.ts b/packages/runtime/test/policy/crud/read.test.ts new file mode 100644 index 00000000..a0e42815 --- /dev/null +++ b/packages/runtime/test/policy/crud/read.test.ts @@ -0,0 +1,248 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from '../utils'; + +describe('Read policy tests', () => { + describe('Find tests', () => { + it('works with top-level find', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create', true) + @@allow('read', x > 0) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, x: 0 } }); + await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveNull(); + + await db.$unuseAll().foo.update({ where: { id: 1 }, data: { x: 1 } }); + await expect(db.foo.findUnique({ where: { id: 1 } })).toResolveTruthy(); + }); + + it('works with mutation read-back', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + x Int + @@allow('create,update', true) + @@allow('read', x > 0) +} +`, + ); + + await expect(db.foo.create({ data: { id: 1, x: 0 } })).toBeRejectedByPolicy(); + await expect(db.$unuseAll().foo.count()).resolves.toBe(1); + await expect(db.foo.update({ where: { id: 1 }, data: { x: 1 } })).resolves.toMatchObject({ x: 1 }); + }); + + it('works with to-one relation optional owner-side read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? @relation(fields: [barId], references: [id]) + barId Int? @unique + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null }); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ + id: 1, + bar: { id: 1 }, + }); + }); + + // TODO: check if we should be consistent with v2 and filter out the parent entity + // if a non-optional child relation is included but not readable + it('works with to-one relation non-optional owner-side read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar @relation(fields: [barId], references: [id]) + barId Int @unique + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null }); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ + id: 1, + bar: { id: 1 }, + }); + }); + + it('works with to-one relation non-owner-side read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int @unique + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ id: 1, bar: null }); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findFirst({ include: { bar: true } })).resolves.toMatchObject({ + id: 1, + bar: { id: 1 }, + }); + }); + + it('works with to-many relation read', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bars Bar[] + @@allow('all', true) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('create,update', true) + @@allow('read', y > 0) +} +`, + ); + + await db.foo.create({ + data: { + id: 1, + bars: { + create: [ + { id: 1, y: 0 }, + { id: 2, y: 1 }, + ], + }, + }, + }); + await expect(db.foo.findFirst({ include: { bars: true } })).resolves.toMatchObject({ + id: 1, + bars: [{ id: 2 }], + }); + }); + + it('works with filtered by to-one relation field', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? @relation(fields: [barId], references: [id]) + barId Int? @unique + @@allow('create', true) + @@allow('read', bar.y > 0) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('all', true) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(0); + await db.bar.update({ where: { id: 1 }, data: { y: 1 } }); + await expect(db.foo.findMany()).resolves.toHaveLength(1); + }); + + it('works with filtered by to-one relation non-null', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bar Bar? @relation(fields: [barId], references: [id]) + barId Int? @unique + @@allow('create,update', true) + @@allow('read', bar != null) + @@allow('read', this.bar != null) +} + +model Bar { + id Int @id + y Int + foo Foo? + @@allow('all', true) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1 } }); + await expect(db.foo.findMany()).resolves.toHaveLength(0); + await db.foo.update({ where: { id: 1 }, data: { bar: { create: { id: 1, y: 0 } } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(1); + }); + + it('works with filtered by to-many relation', async () => { + const db = await createPolicyTestClient( + ` +model Foo { + id Int @id + bars Bar[] + @@allow('create,update', true) + @@allow('read', bars?[y > 0]) + @@allow('read', this.bars?[y > 0]) +} + +model Bar { + id Int @id + y Int + foo Foo? @relation(fields: [fooId], references: [id]) + fooId Int? + @@allow('all', true) +} +`, + ); + + await db.$unuseAll().foo.create({ data: { id: 1, bars: { create: [{ id: 1, y: 0 }] } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(0); + await db.foo.update({ where: { id: 1 }, data: { bars: { create: { id: 2, y: 1 } } } }); + await expect(db.foo.findMany()).resolves.toHaveLength(1); + }); + }); +});