From a1b44f32b0de4c240adf588f25432375803a3ae0 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 31 Aug 2025 22:48:56 +0800 Subject: [PATCH] fix: stricter type def validation and compound unique field fix --- TODO.md | 2 +- .../src/validators/typedef-validator.ts | 7 +- packages/language/test/mixin.test.ts | 15 ++ .../test/policy/client-extensions.test.ts | 10 +- packages/runtime/test/policy/mixin.test.ts | 92 +++++++++ .../test/policy/multi-field-unique.test.ts | 179 ++++++++++++++++++ packages/sdk/src/ts-schema-generator.ts | 13 +- 7 files changed, 309 insertions(+), 9 deletions(-) create mode 100644 packages/runtime/test/policy/mixin.test.ts create mode 100644 packages/runtime/test/policy/multi-field-unique.test.ts diff --git a/TODO.md b/TODO.md index 35edf349..8ccc6729 100644 --- a/TODO.md +++ b/TODO.md @@ -7,6 +7,7 @@ - [x] init - [x] validate - [ ] format + - [ ] repl - [x] plugin mechanism - [x] built-in plugins - [x] typescript @@ -82,7 +83,6 @@ - [x] Error system - [x] Custom table name - [x] Custom field name - - [ ] Strict undefined checks - [ ] DbNull vs JsonNull - [ ] Migrate to tsdown - [ ] Benchmark diff --git a/packages/language/src/validators/typedef-validator.ts b/packages/language/src/validators/typedef-validator.ts index d029d8ba..6ad35b0b 100644 --- a/packages/language/src/validators/typedef-validator.ts +++ b/packages/language/src/validators/typedef-validator.ts @@ -1,5 +1,5 @@ import type { ValidationAcceptor } from 'langium'; -import type { DataField, TypeDef } from '../generated/ast'; +import { isDataModel, type DataField, type TypeDef } from '../generated/ast'; import { validateAttributeApplication } from './attribute-application-validator'; import { validateDuplicatedDeclarations, type AstValidator } from './common'; @@ -22,6 +22,11 @@ export default class TypeDefValidator implements AstValidator { } private validateField(field: DataField, accept: ValidationAcceptor): void { + if (isDataModel(field.type.reference?.ref)) { + accept('error', 'Type field cannot be a relation', { + node: field.type, + }); + } field.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); } } diff --git a/packages/language/test/mixin.test.ts b/packages/language/test/mixin.test.ts index 3832148d..8e7bcd0a 100644 --- a/packages/language/test/mixin.test.ts +++ b/packages/language/test/mixin.test.ts @@ -106,4 +106,19 @@ describe('Mixin Tests', () => { 'can only be applied once', ); }); + + it('does not allow relation fields in type', async () => { + await loadSchemaWithError( + ` + model User { + id Int @id @default(autoincrement()) + } + + type T { + u User + } + `, + 'Type field cannot be a relation', + ); + }); }); diff --git a/packages/runtime/test/policy/client-extensions.test.ts b/packages/runtime/test/policy/client-extensions.test.ts index f1f916b4..1f725172 100644 --- a/packages/runtime/test/policy/client-extensions.test.ts +++ b/packages/runtime/test/policy/client-extensions.test.ts @@ -22,7 +22,7 @@ describe('client extensions tests for policies', () => { await rawDb.model.create({ data: { x: 2, y: 300 } }); const ext = definePlugin({ - id: 'prisma-extension-queryOverride', + id: 'queryOverride', onQuery: async ({ args, proceed }: any) => { args = args ?? {}; args.where = { ...args.where, y: { lt: 300 } }; @@ -53,7 +53,7 @@ describe('client extensions tests for policies', () => { await rawDb.model.create({ data: { x: 2, y: 300 } }); const ext = definePlugin({ - id: 'prisma-extension-queryOverride', + id: 'queryOverride', onQuery: async ({ args, proceed }: any) => { args = args ?? {}; args.where = { ...args.where, y: { lt: 300 } }; @@ -84,7 +84,7 @@ describe('client extensions tests for policies', () => { await rawDb.model.create({ data: { x: 2, y: 300 } }); const ext = definePlugin({ - id: 'prisma-extension-queryOverride', + id: 'queryOverride', onQuery: async ({ args, proceed }: any) => { args = args ?? {}; args.where = { ...args.where, y: { lt: 300 } }; @@ -115,7 +115,7 @@ describe('client extensions tests for policies', () => { await rawDb.model.create({ data: { x: 2, y: 300 } }); const ext = definePlugin({ - id: 'prisma-extension-queryOverride', + id: 'queryOverride', onQuery: async ({ args, proceed }: any) => { args = args ?? {}; args.where = { ...args.where, y: { lt: 300 } }; @@ -144,7 +144,7 @@ describe('client extensions tests for policies', () => { await rawDb.model.create({ data: { value: 1 } }); const ext = definePlugin({ - id: 'prisma-extension-resultMutation', + id: 'resultMutation', onQuery: async ({ args, proceed }: any) => { const r: any = await proceed(args); for (let i = 0; i < r.length; i++) { diff --git a/packages/runtime/test/policy/mixin.test.ts b/packages/runtime/test/policy/mixin.test.ts new file mode 100644 index 00000000..247e9864 --- /dev/null +++ b/packages/runtime/test/policy/mixin.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; + +describe('Abstract models', () => { + it('connect test1', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? @relation(fields: [profileId], references: [id]) + profileId Int? @unique + + @@allow('create,read', true) + @@allow('update', auth().id == 1) + } + + type BaseProfile { + id Int @id @default(autoincrement()) + + @@allow('all', true) + } + + model Profile with BaseProfile { + name String + user User? + } + `, + ); + + const dbUser2 = db.$setAuth({ id: 2 }); + const user = await dbUser2.user.create({ data: { id: 1 } }); + const profile = await dbUser2.profile.create({ data: { id: 1, name: 'John' } }); + await expect( + dbUser2.profile.update({ where: { id: 1 }, data: { user: { connect: { id: user.id } } } }), + ).toBeRejectedNotFound(); + await expect( + dbUser2.user.update({ where: { id: 1 }, data: { profile: { connect: { id: profile.id } } } }), + ).toBeRejectedNotFound(); + + const dbUser1 = db.$setAuth({ id: 1 }); + await expect( + dbUser1.profile.update({ where: { id: 1 }, data: { user: { connect: { id: user.id } } } }), + ).toResolveTruthy(); + await expect( + dbUser1.user.update({ where: { id: 1 }, data: { profile: { connect: { id: profile.id } } } }), + ).toResolveTruthy(); + }); + + it('connect test2', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + + @@allow('all', true) + } + + type BaseProfile { + id Int @id @default(autoincrement()) + + @@allow('create,read', true) + @@allow('update', auth().id == 1) + } + + model Profile with BaseProfile { + name String + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + } + `, + ); + + const dbUser2 = db.$setAuth({ id: 2 }); + const user = await dbUser2.user.create({ data: { id: 1 } }); + const profile = await dbUser2.profile.create({ data: { id: 1, name: 'John' } }); + await expect( + dbUser2.profile.update({ where: { id: 1 }, data: { user: { connect: { id: user.id } } } }), + ).toBeRejectedNotFound(); + await expect( + dbUser2.user.update({ where: { id: 1 }, data: { profile: { connect: { id: profile.id } } } }), + ).toBeRejectedNotFound(); + + const dbUser1 = db.$setAuth({ id: 1 }); + await expect( + dbUser1.profile.update({ where: { id: 1 }, data: { user: { connect: { id: user.id } } } }), + ).toResolveTruthy(); + await expect( + dbUser1.user.update({ where: { id: 1 }, data: { profile: { connect: { id: profile.id } } } }), + ).toResolveTruthy(); + }); +}); diff --git a/packages/runtime/test/policy/multi-field-unique.test.ts b/packages/runtime/test/policy/multi-field-unique.test.ts new file mode 100644 index 00000000..029bdaeb --- /dev/null +++ b/packages/runtime/test/policy/multi-field-unique.test.ts @@ -0,0 +1,179 @@ +import path from 'path'; +import { afterEach, beforeAll, describe, expect, it } from 'vitest'; +import { createPolicyTestClient } from './utils'; +import { QueryError } from '../../src'; + +describe('With Policy: multi-field unique', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('toplevel crud test unnamed constraint', async () => { + const db = await createPolicyTestClient( + ` + model Model { + id String @id @default(uuid()) + a String + b String + x Int + @@unique([a, b]) + + @@allow('all', x > 0) + @@deny('update', x > 1) + } + `, + ); + + await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 1 } })).toResolveTruthy(); + await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 2 } })).rejects.toThrow(QueryError); + await expect(db.model.create({ data: { a: 'a2', b: 'b2', x: 0 } })).toBeRejectedByPolicy(); + + await expect(db.model.findUnique({ where: { a_b: { a: 'a1', b: 'b1' } } })).toResolveTruthy(); + await expect(db.model.findUnique({ where: { a_b: { a: 'a1', b: 'b2' } } })).toResolveFalsy(); + await expect(db.model.update({ where: { a_b: { a: 'a1', b: 'b1' } }, data: { x: 2 } })).toResolveTruthy(); + await expect(db.model.update({ where: { a_b: { a: 'a1', b: 'b1' } }, data: { x: 0 } })).toBeRejectedNotFound(); + + await expect(db.model.delete({ where: { a_b: { a: 'a1', b: 'b1' } } })).toResolveTruthy(); + }); + + it('toplevel crud test named constraint', async () => { + const db = await createPolicyTestClient( + ` + model Model { + id String @id @default(uuid()) + a String + b String + x Int + @@unique([a, b], name: 'myconstraint') + + @@allow('all', x > 0) + @@deny('update', x > 1) + } + `, + ); + + await expect(db.model.create({ data: { a: 'a1', b: 'b1', x: 1 } })).toResolveTruthy(); + await expect(db.model.findUnique({ where: { myconstraint: { a: 'a1', b: 'b1' } } })).toResolveTruthy(); + await expect(db.model.findUnique({ where: { myconstraint: { a: 'a1', b: 'b2' } } })).toResolveFalsy(); + await expect( + db.model.update({ where: { myconstraint: { a: 'a1', b: 'b1' } }, data: { x: 2 } }), + ).toResolveTruthy(); + await expect( + db.model.update({ where: { myconstraint: { a: 'a1', b: 'b1' } }, data: { x: 0 } }), + ).toBeRejectedNotFound(); + await expect(db.model.delete({ where: { myconstraint: { a: 'a1', b: 'b1' } } })).toResolveTruthy(); + }); + + it('nested crud test', async () => { + const db = await createPolicyTestClient( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + a String + b String + x Int + m1 M1 @relation(fields: [m1Id], references: [id]) + m1Id String + + @@unique([a, b]) + @@allow('all', x > 0) + } + `, + ); + + await expect(db.m1.create({ data: { id: '1', m2: { create: { a: 'a1', b: 'b1', x: 1 } } } })).toResolveTruthy(); + await expect(db.m1.create({ data: { id: '2', m2: { create: { a: 'a1', b: 'b1', x: 2 } } } })).rejects.toThrow( + QueryError, + ); + await expect( + db.m1.create({ data: { id: '3', m2: { create: { a: 'a1', b: 'b2', x: 0 } } } }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + connectOrCreate: { + where: { a_b: { a: 'a1', b: 'b1' } }, + create: { a: 'a1', b: 'b1', x: 2 }, + }, + }, + }, + }), + ).toResolveTruthy(); + await expect(db.m2.count()).resolves.toBe(1); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + connectOrCreate: { + where: { a_b: { a: 'a1', b: 'b2' } }, + create: { a: 'a1', b: 'b2', x: 2 }, + }, + }, + }, + }), + ).toResolveTruthy(); + await expect(db.m2.count()).resolves.toBe(2); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + connectOrCreate: { + where: { a_b: { a: 'a2', b: 'b2' } }, + create: { a: 'a2', b: 'b2', x: 0 }, + }, + }, + }, + }), + ).toBeRejectedByPolicy(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { + where: { a_b: { a: 'a1', b: 'b2' } }, + data: { x: 3 }, + }, + }, + }, + }), + ).toResolveTruthy(); + await expect(db.m2.findUnique({ where: { a_b: { a: 'a1', b: 'b2' } } })).resolves.toEqual( + expect.objectContaining({ x: 3 }), + ); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + delete: { + a_b: { a: 'a1', b: 'b1' }, + }, + }, + }, + }), + ).toResolveTruthy(); + await expect(db.m2.count()).resolves.toBe(1); + }); +}); diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 7484e700..d2e8ba64 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -795,14 +795,14 @@ export class TsSchemaGenerator { ); } else { // multi-field unique - const key = fieldNames.join('_'); + const key = this.getCompoundUniqueKey(attr, fieldNames); if (seenKeys.has(key)) { continue; } seenKeys.add(key); properties.push( ts.factory.createPropertyAssignment( - fieldNames.join('_'), + key, ts.factory.createObjectLiteralExpression( fieldNames.map((field) => { const fieldDef = allFields.find((f) => f.name === field)!; @@ -826,6 +826,15 @@ export class TsSchemaGenerator { return ts.factory.createObjectLiteralExpression(properties, true); } + private getCompoundUniqueKey(attr: DataModelAttribute, fieldNames: string[]) { + const nameArg = attr.args.find((arg) => arg.$resolvedParam.name === 'name'); + if (nameArg && isLiteralExpr(nameArg.value)) { + return nameArg.value.value as string; + } else { + return fieldNames.join('_'); + } + } + private generateFieldTypeLiteral(field: DataField): ts.Expression { invariant( field.type.type || field.type.reference || field.type.unsupported,