From 646128c87338ea0b11c2ae868006dfe5ba03d4bb Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 8 Jun 2025 20:57:03 -0700 Subject: [PATCH 1/2] feat: many-to-many relation (sqlite only) --- TODO.md | 14 +- packages/cli/src/actions/generate.ts | 3 +- packages/cli/test/ts-schema-gen.test.ts | 2 +- packages/runtime/package.json | 1 + packages/runtime/src/client/client-impl.ts | 18 + packages/runtime/src/client/contract.ts | 5 + .../runtime/src/client/crud/dialects/base.ts | 103 ++- .../src/client/crud/dialects/sqlite.ts | 72 +- .../src/client/crud/operations/base.ts | 829 ++++++++++++------ packages/runtime/src/client/query-utils.ts | 34 + packages/runtime/src/schema/schema.ts | 1 + .../runtime/test/client-api/relation.test.ts | 720 +++++++++++++++ .../runtime/test/policy/todo-sample.test.ts | 3 +- packages/runtime/test/utils.ts | 56 +- packages/sdk/src/index.ts | 1 + .../{cli => sdk}/src/prisma/indent-string.ts | 0 .../{cli => sdk}/src/prisma/prisma-builder.ts | 0 .../src/prisma/prisma-schema-generator.ts | 2 +- packages/sdk/src/ts-schema-generator.ts | 40 +- packages/sdk/tsconfig.json | 2 +- packages/testtools/package.json | 3 +- packages/testtools/src/schema.ts | 6 +- pnpm-lock.yaml | 81 +- samples/blog/zenstack/schema.ts | 2 +- 24 files changed, 1649 insertions(+), 349 deletions(-) create mode 100644 packages/runtime/test/client-api/relation.test.ts rename packages/{cli => sdk}/src/prisma/indent-string.ts (100%) rename packages/{cli => sdk}/src/prisma/prisma-builder.ts (100%) rename packages/{cli => sdk}/src/prisma/prisma-schema-generator.ts (99%) diff --git a/TODO.md b/TODO.md index f2393796..1d75312e 100644 --- a/TODO.md +++ b/TODO.md @@ -36,7 +36,7 @@ - [x] Sorting - [x] Pagination - [x] Distinct - - [ ] Update + - [x] Update - [x] Input validation - [x] Top-level - [x] Nested to-many @@ -44,6 +44,7 @@ - [x] Incremental update for numeric fields - [x] Array update - [x] Upsert + - [ ] Implement with "on conflict" - [x] Delete - [x] Aggregation - [x] Count @@ -54,22 +55,23 @@ - [x] Computed fields - [ ] Prisma client extension - [ ] Misc + - [ ] Cache validation schemas - [ ] Compound ID - [ ] Cross field comparison - - [ ] Many-to-many relation - - [ ] Cache validation schemas + - [x] Many-to-many relation + - [ ] Empty AND/OR/NOT behavior - [?] Logging - - [ ] Error system + - [?] Error system - [x] Custom table name - [x] Custom field name - - [ ] Empty AND/OR/NOT behavior - [?] Strict undefined check - [ ] Access Policy - [ ] Short-circuit pre-create check for scalar-field only policies + - [ ] Inject "replace into" + - [ ] Inject "on conflict do update" - [ ] Polymorphism - [x] Migration - [ ] Databases - [x] SQLite - [x] PostgreSQL - [ ] Schema - - [ ] MySQL diff --git a/packages/cli/src/actions/generate.ts b/packages/cli/src/actions/generate.ts index edd01a8b..b5c3ed1b 100644 --- a/packages/cli/src/actions/generate.ts +++ b/packages/cli/src/actions/generate.ts @@ -1,11 +1,10 @@ import { isPlugin, LiteralExpr, type Model } from '@zenstackhq/language/ast'; import type { CliGenerator } from '@zenstackhq/runtime/client'; -import { TsSchemaGenerator } from '@zenstackhq/sdk'; +import { PrismaSchemaGenerator, TsSchemaGenerator } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'node:fs'; import path from 'node:path'; import invariant from 'tiny-invariant'; -import { PrismaSchemaGenerator } from '../prisma/prisma-schema-generator'; import { getSchemaFile, loadSchemaDocument } from './action-utils'; type Options = { diff --git a/packages/cli/test/ts-schema-gen.test.ts b/packages/cli/test/ts-schema-gen.test.ts index 4ca693cb..cd2efd41 100644 --- a/packages/cli/test/ts-schema-gen.test.ts +++ b/packages/cli/test/ts-schema-gen.test.ts @@ -4,7 +4,7 @@ import { describe, expect, it } from 'vitest'; describe('TypeScript schema generation tests', () => { it('generates correct data models', async () => { - const schema = await generateTsSchema(` + const { schema } = await generateTsSchema(` model User { id String @id @default(uuid()) name String diff --git a/packages/runtime/package.json b/packages/runtime/package.json index d5a6b9e7..278c5ae1 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -113,6 +113,7 @@ "@types/tmp": "^0.2.6", "@zenstackhq/language": "workspace:*", "@zenstackhq/testtools": "workspace:*", + "@zenstackhq/sdk": "workspace:*", "tmp": "^0.2.3" } } diff --git a/packages/runtime/src/client/client-impl.ts b/packages/runtime/src/client/client-impl.ts index 8509487d..5a3c77b0 100644 --- a/packages/runtime/src/client/client-impl.ts +++ b/packages/runtime/src/client/client-impl.ts @@ -1,5 +1,6 @@ import { DefaultConnectionProvider, + DefaultQueryExecutor, Kysely, Log, PostgresDialect, @@ -47,6 +48,7 @@ export const ZenStackClient = function ( export class ClientImpl { private kysely: ToKysely; + private kyselyRaw: ToKysely; public readonly $options: ClientOptions; public readonly $schema: Schema; readonly kyselyProps: KyselyProps; @@ -77,6 +79,7 @@ export class ClientImpl { new DefaultConnectionProvider(baseClient.kyselyProps.driver) ), }; + this.kyselyRaw = baseClient.kyselyRaw; } else { const dialect = this.getKyselyDialect(); const driver = new ZenStackDriver( @@ -103,6 +106,17 @@ export class ClientImpl { driver, executor, }; + + // raw kysely instance with default executor + this.kyselyRaw = new Kysely({ + ...this.kyselyProps, + executor: new DefaultQueryExecutor( + compiler, + adapter, + connectionProvider, + [] + ), + }); } this.kysely = new Kysely(this.kyselyProps); @@ -114,6 +128,10 @@ export class ClientImpl { return this.kysely; } + public get $qbRaw() { + return this.kyselyRaw; + } + private getKyselyDialect() { return match(this.schema.provider.type) .with('sqlite', () => this.makeSqliteKyselyDialect()) diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index fa96fa5f..bb920347 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -37,6 +37,11 @@ export type ClientContract = { */ readonly $qb: ToKysely; + /** + * The raw Kysely query builder without any ZenStack enhancements. + */ + readonly $qbRaw: ToKysely; + /** * Starts a transaction. */ diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index f89fbf15..58d0ac6d 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -32,6 +32,8 @@ import { buildFieldRef, buildJoinPairs, getField, + getIdFields, + getManyToManyRelation, getRelationForeignKeyFieldPairs, isEnum, makeDefaultOrderBy, @@ -68,18 +70,18 @@ export abstract class BaseCrudDialect { eb: ExpressionBuilder, model: string, modelAlias: string, - where: object | undefined + where: boolean | object | undefined ) { - let result = this.true(eb); - - if (where === undefined) { - return result; + if (where === true || where === undefined) { + return this.true(eb); } - if (where === null || typeof where !== 'object') { - throw new InternalError('impossible null as filter'); + if (where === false) { + return this.false(eb); } + let result = this.true(eb); + for (const [key, payload] of Object.entries(where)) { if (payload === undefined) { continue; @@ -148,7 +150,12 @@ export abstract class BaseCrudDialect { } // call expression builder and combine the results - if ('$expr' in where && typeof where['$expr'] === 'function') { + if ( + typeof where === 'object' && + where !== null && + '$expr' in where && + typeof where['$expr'] === 'function' + ) { result = this.and(eb, result, where['$expr'](eb)); } @@ -356,45 +363,67 @@ export abstract class BaseCrudDialect { fieldDef: FieldDef, payload: any ) { - const relationModel = fieldDef.type; - - const relationKeyPairs = getRelationForeignKeyFieldPairs( - this.schema, - model, - field - ); - // null check needs to be converted to fk "is null" checks if (payload === null) { return eb(sql.ref(`${table}.${field}`), 'is', null); } + const relationModel = fieldDef.type; + const buildPkFkWhereRefs = (eb: ExpressionBuilder) => { - let r = this.true(eb); - for (const { fk, pk } of relationKeyPairs.keyPairs) { - if (relationKeyPairs.ownedByModel) { - r = this.and( - eb, - r, - eb( - sql.ref(`${table}.${fk}`), - '=', - sql.ref(`${relationModel}.${pk}`) - ) - ); - } else { - r = this.and( - eb, - r, - eb( - sql.ref(`${table}.${pk}`), + const m2m = getManyToManyRelation(this.schema, model, field); + if (m2m) { + // many-to-many relation + const modelIdField = getIdFields(this.schema, model)[0]!; + const relationIdField = getIdFields( + this.schema, + relationModel + )[0]!; + return eb( + sql.ref(`${relationModel}.${relationIdField}`), + 'in', + eb + .selectFrom(m2m.joinTable) + .select(`${m2m.joinTable}.${m2m.otherFkName}`) + .whereRef( + sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`), '=', - sql.ref(`${relationModel}.${fk}`) + sql.ref(`${table}.${modelIdField}`) ) - ); + ); + } else { + const relationKeyPairs = getRelationForeignKeyFieldPairs( + this.schema, + model, + field + ); + + let result = this.true(eb); + for (const { fk, pk } of relationKeyPairs.keyPairs) { + if (relationKeyPairs.ownedByModel) { + result = this.and( + eb, + result, + eb( + sql.ref(`${table}.${fk}`), + '=', + sql.ref(`${relationModel}.${pk}`) + ) + ); + } else { + result = this.and( + eb, + result, + eb( + sql.ref(`${table}.${pk}`), + '=', + sql.ref(`${relationModel}.${fk}`) + ) + ); + } } + return result; } - return r; }; let result = this.true(eb); diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 2bc6320b..62ea65db 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -7,12 +7,15 @@ import { type RawBuilder, type SelectQueryBuilder, } from 'kysely'; +import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { BuiltinType, GetModels } from '../../../schema/schema'; import type { FindArgs } from '../../crud-types'; import { buildFieldRef, + getIdFields, + getManyToManyRelation, getRelationForeignKeyFieldPairs, requireField, requireModel, @@ -117,28 +120,63 @@ export class SqliteCrudDialect< } // join conditions - const { keyPairs, ownedByModel } = getRelationForeignKeyFieldPairs( + + const m2m = getManyToManyRelation( this.schema, model, relationField ); - keyPairs.forEach(({ fk, pk }) => { - if (ownedByModel) { - // the parent model owns the fk - subQuery = subQuery.whereRef( - `${relationModel}.${pk}`, - '=', - `${parentName}.${fk}` - ); - } else { - // the relation side owns the fk - subQuery = subQuery.whereRef( - `${relationModel}.${fk}`, - '=', - `${parentName}.${pk}` + if (m2m) { + // many-to-many relation + const parentIds = getIdFields(this.schema, model); + const relationIds = getIdFields(this.schema, relationModel); + invariant( + parentIds.length === 1, + 'many-to-many relation must have exactly one id field' + ); + invariant( + relationIds.length === 1, + 'many-to-many relation must have exactly one id field' + ); + subQuery = subQuery.where( + eb( + eb.ref(`${relationModel}.${relationIds[0]}`), + 'in', + eb + .selectFrom(m2m.joinTable) + .select(`${m2m.joinTable}.${m2m.otherFkName}`) + .whereRef( + `${parentName}.${parentIds[0]}`, + '=', + `${m2m.joinTable}.${m2m.parentFkName}` + ) + ) + ); + } else { + const { keyPairs, ownedByModel } = + getRelationForeignKeyFieldPairs( + this.schema, + model, + relationField ); - } - }); + keyPairs.forEach(({ fk, pk }) => { + if (ownedByModel) { + // the parent model owns the fk + subQuery = subQuery.whereRef( + `${relationModel}.${pk}`, + '=', + `${parentName}.${fk}` + ); + } else { + // the relation side owns the fk + subQuery = subQuery.whereRef( + `${relationModel}.${fk}`, + '=', + `${parentName}.${pk}` + ); + } + }); + } return subQuery.as(subQueryName); }); diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 9d4c698e..10d28798 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -45,6 +45,7 @@ import { getField, getIdFields, getIdValues, + getManyToManyRelation, getModel, getRelationForeignKeyFieldPairs, isForeignKeyField, @@ -492,43 +493,54 @@ export abstract class BaseOperationHandler { let parentUpdateTask: ((entity: any) => Promise) | undefined = undefined; + let m2m: ReturnType = undefined; + if (fromRelation) { - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + m2m = getManyToManyRelation( this.schema, - fromRelation?.model ?? '', - fromRelation?.field ?? '' + fromRelation.model, + fromRelation.field ); + if (!m2m) { + // many-to-many relations are handled after create + const { ownedByModel, keyPairs } = + getRelationForeignKeyFieldPairs( + this.schema, + fromRelation?.model ?? '', + fromRelation?.field ?? '' + ); - if (!ownedByModel) { - // assign fks from parent - const parentFkFields = this.buildFkAssignments( - fromRelation.model, - fromRelation.field, - fromRelation.ids - ); - Object.assign(createFields, parentFkFields); - } else { - parentUpdateTask = (entity) => { - const query = kysely - .updateTable(fromRelation.model) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: entity[pk], - }), - {} as any + if (!ownedByModel) { + // assign fks from parent + const parentFkFields = this.buildFkAssignments( + fromRelation.model, + fromRelation.field, + fromRelation.ids + ); + Object.assign(createFields, parentFkFields); + } else { + parentUpdateTask = (entity) => { + const query = kysely + .updateTable(fromRelation.model) + .set( + keyPairs.reduce( + (acc, { fk, pk }) => ({ + ...acc, + [fk]: entity[pk], + }), + {} as any + ) ) - ) - .where((eb) => eb.and(fromRelation.ids)) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }) - ); - return query.execute(); - }; + .where((eb) => eb.and(fromRelation.ids)) + .modifyEnd( + this.makeContextComment({ + model: fromRelation.model, + operation: 'update', + }) + ); + return query.execute(); + }; + } } } @@ -559,7 +571,9 @@ export abstract class BaseOperationHandler { ); } } else { + const subM2M = getManyToManyRelation(this.schema, model, field); if ( + !subM2M && fieldDef.relation?.fields && fieldDef.relation?.references ) { @@ -623,6 +637,21 @@ export abstract class BaseOperationHandler { await Promise.all(relationPromises); } + if (fromRelation && m2m) { + // connect many-to-many relation + await this.handleManyToManyRelation( + kysely, + 'connect', + fromRelation.model, + fromRelation.field, + fromRelation.ids, + m2m.otherModel, + m2m.otherField, + createdEntity, + m2m.joinTable + ); + } + // finally update parent if needed if (parentUpdateTask) { await parentUpdateTask(createdEntity); @@ -666,6 +695,103 @@ export abstract class BaseOperationHandler { return parentFkFields; } + private async handleManyToManyRelation< + Action extends 'connect' | 'disconnect' + >( + kysely: ToKysely, + action: Action, + leftModel: string, + leftField: string, + leftEntity: any, + rightModel: string, + rightField: string, + rightEntity: any, + joinTable: string + ): Promise< + Action extends 'connect' + ? UpdateResult | undefined + : DeleteResult | undefined + > { + const sortedRecords = [ + { + model: leftModel, + field: leftField, + entity: leftEntity, + }, + { + model: rightModel, + field: rightField, + entity: rightEntity, + }, + ].sort((a, b) => a.model.localeCompare(b.model)); + + const firstIds = getIdFields(this.schema, sortedRecords[0]!.model); + const secondIds = getIdFields(this.schema, sortedRecords[1]!.model); + invariant( + firstIds.length === 1, + 'many-to-many relation must have exactly one id field' + ); + invariant( + secondIds.length === 1, + 'many-to-many relation must have exactly one id field' + ); + + // Prisma's convention for many-to-many: fk fields are named "A" and "B" + if (action === 'connect') { + const result = await kysely + .insertInto(joinTable as any) + .values({ + A: sortedRecords[0]!.entity[firstIds[0]!], + B: sortedRecords[1]!.entity[secondIds[0]!], + } as any) + .onConflict((oc) => oc.columns(['A', 'B'] as any).doNothing()) + .execute(); + return result[0] as any; + } else { + const eb = expressionBuilder(); + const result = await kysely + .deleteFrom(joinTable as any) + .where( + eb( + `${joinTable}.A`, + '=', + sortedRecords[0]!.entity[firstIds[0]!] + ) + ) + .where( + eb( + `${joinTable}.B`, + '=', + sortedRecords[1]!.entity[secondIds[0]!] + ) + ) + .execute(); + return result[0] as any; + } + } + + private resetManyToManyRelation( + kysely: ToKysely, + model: GetModels, + field: string, + parentIds: any + ) { + invariant( + Object.keys(parentIds).length === 1, + 'parentIds must have exactly one field' + ); + const parentId = Object.values(parentIds)[0]!; + + const m2m = getManyToManyRelation(this.schema, model, field); + invariant(m2m, 'not a many-to-many relation'); + + const eb = expressionBuilder(); + return kysely + .deleteFrom(m2m.joinTable as any) + .where(eb(`${m2m.joinTable}.${m2m.parentFkName}`, '=', parentId)) + .execute(); + } + private async processOwnedRelation( kysely: ToKysely, relationField: FieldDef, @@ -1006,28 +1132,48 @@ export abstract class BaseOperationHandler { } const parentWhere: any = {}; + let m2m: ReturnType = undefined; + if (fromRelation) { - // merge foreign key conditions from the relation - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + m2m = getManyToManyRelation( this.schema, fromRelation.model, fromRelation.field ); - if (ownedByModel) { - const fromEntity = await this.readUnique( - kysely, - fromRelation.model as GetModels, - { - where: fromRelation.ids, + if (!m2m) { + // merge foreign key conditions from the relation + const { ownedByModel, keyPairs } = + getRelationForeignKeyFieldPairs( + this.schema, + fromRelation.model, + fromRelation.field + ); + if (ownedByModel) { + const fromEntity = await this.readUnique( + kysely, + fromRelation.model as GetModels, + { + where: fromRelation.ids, + } + ); + for (const { fk, pk } of keyPairs) { + parentWhere[pk] = fromEntity[fk]; + } + } else { + for (const { fk, pk } of keyPairs) { + parentWhere[fk] = fromRelation.ids[pk]; } - ); - for (const { fk, pk } of keyPairs) { - parentWhere[pk] = fromEntity[fk]; } } else { - for (const { fk, pk } of keyPairs) { - parentWhere[fk] = fromRelation.ids[pk]; - } + // many-to-many relation, filter for parent with "some" + const fromRelationFieldDef = this.requireField( + fromRelation.model, + fromRelation.field + ); + invariant(fromRelationFieldDef.relation?.opposite); + parentWhere[fromRelationFieldDef.relation.opposite] = { + some: fromRelation.ids, + }; } } @@ -1592,100 +1738,151 @@ export abstract class BaseOperationHandler { return; } - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + const m2m = getManyToManyRelation( this.schema, fromRelation.model, fromRelation.field ); - let updateResult: UpdateResult; - - if (ownedByModel) { - // set parent fk directly - invariant(_data.length === 1, 'only one entity can be connected'); - const target = await this.readUnique(kysely, model, { - where: _data[0], + if (m2m) { + // handle many-to-many relation + const actions = _data.map(async (d) => { + const ids = await this.getEntityIds(kysely, model, d); + return this.handleManyToManyRelation( + kysely, + 'connect', + fromRelation.model, + fromRelation.field, + fromRelation.ids, + m2m.otherModel!, + m2m.otherField!, + ids, + m2m.joinTable + ); }); - if (!target) { + const results = await Promise.all(actions); + + // validate connect result + if (_data.length > results.filter((r) => !!r).length) { throw new NotFoundError(model); } - const query = kysely - .updateTable(fromRelation.model) - .where((eb) => eb.and(fromRelation.ids)) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: target[pk], - }), - {} as any - ) - ) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }) - ); - updateResult = await query.executeTakeFirstOrThrow(); } else { - // disconnect current if it's a one-one relation - const relationFieldDef = this.requireField( + const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + this.schema, fromRelation.model, fromRelation.field ); + let updateResult: UpdateResult; - if (!relationFieldDef.array) { + if (ownedByModel) { + // set parent fk directly + invariant( + _data.length === 1, + 'only one entity can be connected' + ); + const target = await this.readUnique(kysely, model, { + where: _data[0], + }); + if (!target) { + throw new NotFoundError(model); + } const query = kysely - .updateTable(model) - .where((eb) => - eb.and( - keyPairs.map(({ fk, pk }) => - eb(sql.ref(fk), '=', fromRelation.ids[pk]) - ) + .updateTable(fromRelation.model) + .where((eb) => eb.and(fromRelation.ids)) + .set( + keyPairs.reduce( + (acc, { fk, pk }) => ({ + ...acc, + [fk]: target[pk], + }), + {} as any ) ) + .modifyEnd( + this.makeContextComment({ + model: fromRelation.model, + operation: 'update', + }) + ); + updateResult = await query.executeTakeFirstOrThrow(); + } else { + // disconnect current if it's a one-one relation + const relationFieldDef = this.requireField( + fromRelation.model, + fromRelation.field + ); + + if (!relationFieldDef.array) { + const query = kysely + .updateTable(model) + .where((eb) => + eb.and( + keyPairs.map(({ fk, pk }) => + eb(sql.ref(fk), '=', fromRelation.ids[pk]) + ) + ) + ) + .set( + keyPairs.reduce( + (acc, { fk }) => ({ ...acc, [fk]: null }), + {} as any + ) + ) + .modifyEnd( + this.makeContextComment({ + model: fromRelation.model, + operation: 'update', + }) + ); + await query.execute(); + } + + // connect + const query = kysely + .updateTable(model) + .where((eb) => eb.or(_data.map((d) => eb.and(d)))) .set( keyPairs.reduce( - (acc, { fk }) => ({ ...acc, [fk]: null }), + (acc, { fk, pk }) => ({ + ...acc, + [fk]: fromRelation.ids[pk], + }), {} as any ) ) .modifyEnd( this.makeContextComment({ - model: fromRelation.model, + model, operation: 'update', }) ); - await query.execute(); + updateResult = await query.executeTakeFirstOrThrow(); } - // connect - const query = kysely - .updateTable(model) - .where((eb) => eb.or(_data.map((d) => eb.and(d)))) - .set( - keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: fromRelation.ids[pk], - }), - {} as any - ) - ) - .modifyEnd( - this.makeContextComment({ - model, - operation: 'update', - }) - ); - updateResult = await query.executeTakeFirstOrThrow(); + // validate connect result + if (_data.length > updateResult.numUpdatedRows) { + // some entities were not connected + throw new NotFoundError(model); + } } + } - // validate connect result - if (_data.length > updateResult.numUpdatedRows) { - // some entities were not connected - throw new NotFoundError(model); + private getEntityIds( + kysely: ToKysely, + model: GetModels, + uniqueFilter: any + ) { + const idFields = getIdFields(this.schema, model); + if ( + idFields.every( + (f) => f in uniqueFilter && uniqueFilter[f] !== undefined + ) + ) { + return uniqueFilter; } + + return this.readUnique(kysely, model, { + where: uniqueFilter, + }); } protected async connectOrCreateRelation( @@ -1743,71 +1940,100 @@ export abstract class BaseOperationHandler { return; } - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + const m2m = getManyToManyRelation( this.schema, fromRelation.model, fromRelation.field ); - - let updateResult: UpdateResult; - - if (ownedByModel) { - // set parent fk directly - invariant( - disconnectConditions.length === 1, - 'only one entity can be disconnected' - ); - const target = await this.readUnique(kysely, model, { - where: - disconnectConditions[0] === true - ? {} - : disconnectConditions[0], + if (m2m) { + // handle many-to-many relation + const actions = disconnectConditions.map(async (d) => { + const ids = await this.getEntityIds(kysely, model, d); + return this.handleManyToManyRelation( + kysely, + 'disconnect', + fromRelation.model, + fromRelation.field, + fromRelation.ids, + m2m.otherModel, + m2m.otherField, + ids, + m2m.joinTable + ); }); - if (!target) { + const results = await Promise.all(actions); + + // validate disconnect result + if (expectedUpdateCount > results.filter((r) => !!r).length) { throw new NotFoundError(model); } - const query = kysely - .updateTable(fromRelation.model) - .where((eb) => eb.and(fromRelation.ids)) - .set( - keyPairs.reduce( - (acc, { fk }) => ({ ...acc, [fk]: null }), - {} as any - ) - ) - .modifyEnd( - this.makeContextComment({ - model: fromRelation.model, - operation: 'update', - }) - ); - updateResult = await query.executeTakeFirstOrThrow(); } else { - // disconnect - const query = kysely - .updateTable(model) - .where((eb) => - eb.or(disconnectConditions.map((d) => eb.and(d))) - ) - .set( - keyPairs.reduce( - (acc, { fk }) => ({ ...acc, [fk]: null }), - {} as any - ) - ) - .modifyEnd( - this.makeContextComment({ - model, - operation: 'update', - }) + let updateResult: UpdateResult; + + const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + this.schema, + fromRelation.model, + fromRelation.field + ); + + if (ownedByModel) { + // set parent fk directly + invariant( + disconnectConditions.length === 1, + 'only one entity can be disconnected' ); - updateResult = await query.executeTakeFirstOrThrow(); - } + const target = await this.readUnique(kysely, model, { + where: + disconnectConditions[0] === true + ? {} + : disconnectConditions[0], + }); + if (!target) { + throw new NotFoundError(model); + } + const query = kysely + .updateTable(fromRelation.model) + .where((eb) => eb.and(fromRelation.ids)) + .set( + keyPairs.reduce( + (acc, { fk }) => ({ ...acc, [fk]: null }), + {} as any + ) + ) + .modifyEnd( + this.makeContextComment({ + model: fromRelation.model, + operation: 'update', + }) + ); + updateResult = await query.executeTakeFirstOrThrow(); + } else { + // disconnect + const query = kysely + .updateTable(model) + .where((eb) => + eb.or(disconnectConditions.map((d) => eb.and(d))) + ) + .set( + keyPairs.reduce( + (acc, { fk }) => ({ ...acc, [fk]: null }), + {} as any + ) + ) + .modifyEnd( + this.makeContextComment({ + model, + operation: 'update', + }) + ); + updateResult = await query.executeTakeFirstOrThrow(); + } - // validate connect result - if (expectedUpdateCount > updateResult.numUpdatedRows!) { - // some entities were not connected - throw new NotFoundError(model); + // validate disconnect result + if (expectedUpdateCount > updateResult.numUpdatedRows!) { + // some entities were not connected + throw new NotFoundError(model); + } } } @@ -1818,62 +2044,80 @@ export abstract class BaseOperationHandler { fromRelation: FromRelationContext ) { const _data = enumerate(data); - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + + const m2m = getManyToManyRelation( this.schema, fromRelation.model, fromRelation.field ); - if (ownedByModel) { - throw new InternalError( - 'relation can only be set from the non-owning side' + if (m2m) { + // handle many-to-many relation + + // reset for the parent + await this.resetManyToManyRelation( + kysely, + fromRelation.model, + fromRelation.field, + fromRelation.ids ); - } - const fkConditions = keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: fromRelation.ids[pk], - }), - {} as any - ); + // connect new entities + const actions = _data.map(async (d) => { + const ids = await this.getEntityIds(kysely, model, d); + return this.handleManyToManyRelation( + kysely, + 'connect', + fromRelation.model, + fromRelation.field, + fromRelation.ids, + m2m.otherModel, + m2m.otherField, + ids, + m2m.joinTable + ); + }); + const results = await Promise.all(actions); - // disconnect - const query = kysely - .updateTable(model) - .where((eb) => - eb.and([ - // match parent - eb.and(fkConditions), - // exclude entities to be connected - eb.not(eb.or(_data.map((d) => eb.and(d)))), - ]) - ) - .set( - keyPairs.reduce( - (acc, { fk }) => ({ ...acc, [fk]: null }), - {} as any - ) - ) - .modifyEnd( - this.makeContextComment({ - model, - operation: 'update', - }) + // validate connect result + if (_data.length > results.filter((r) => !!r).length) { + throw new NotFoundError(model); + } + } else { + const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + this.schema, + fromRelation.model, + fromRelation.field ); - await query.execute(); - // connect - if (_data.length > 0) { + if (ownedByModel) { + throw new InternalError( + 'relation can only be set from the non-owning side' + ); + } + + const fkConditions = keyPairs.reduce( + (acc, { fk, pk }) => ({ + ...acc, + [fk]: fromRelation.ids[pk], + }), + {} as any + ); + + // disconnect const query = kysely .updateTable(model) - .where((eb) => eb.or(_data.map((d) => eb.and(d)))) + .where((eb) => + eb.and([ + // match parent + eb.and(fkConditions), + // exclude entities to be connected + eb.not(eb.or(_data.map((d) => eb.and(d)))), + ]) + ) .set( keyPairs.reduce( - (acc, { fk, pk }) => ({ - ...acc, - [fk]: fromRelation.ids[pk], - }), + (acc, { fk }) => ({ ...acc, [fk]: null }), {} as any ) ) @@ -1883,16 +2127,38 @@ export abstract class BaseOperationHandler { operation: 'update', }) ); - const r = await query.executeTakeFirstOrThrow(); + await query.execute(); - // validate result - if (_data.length > r.numUpdatedRows!) { - // some entities were not connected - throw new NotFoundError(model); + // connect + if (_data.length > 0) { + const query = kysely + .updateTable(model) + .where((eb) => eb.or(_data.map((d) => eb.and(d)))) + .set( + keyPairs.reduce( + (acc, { fk, pk }) => ({ + ...acc, + [fk]: fromRelation.ids[pk], + }), + {} as any + ) + ) + .modifyEnd( + this.makeContextComment({ + model, + operation: 'update', + }) + ); + const r = await query.executeTakeFirstOrThrow(); + + // validate result + if (_data.length > r.numUpdatedRows!) { + // some entities were not connected + throw new NotFoundError(model); + } } } } - protected async deleteRelation( kysely: ToKysely, model: GetModels, @@ -1917,67 +2183,108 @@ export abstract class BaseOperationHandler { expectedDeleteCount = deleteConditions.length; } - const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + let deleteResult: { count: number }; + const m2m = getManyToManyRelation( this.schema, fromRelation.model, fromRelation.field ); - let deleteResult: DeleteResult; - if (ownedByModel) { - const fromEntity = await this.readUnique( + if (m2m) { + // handle many-to-many relation + const fieldDef = this.requireField( + fromRelation.model, + fromRelation.field + ); + invariant(fieldDef.relation?.opposite); + + deleteResult = await this.delete( kysely, - fromRelation.model as GetModels, + model, { - where: fromRelation.ids, - } + AND: [ + { + [fieldDef.relation.opposite]: { + some: fromRelation.ids, + }, + }, + { + OR: deleteConditions, + }, + ], + }, + false ); - if (!fromEntity) { - throw new NotFoundError(model); - } - const query = kysely - .deleteFrom(model) - .where((eb) => - eb.and([ - eb.and( - keyPairs.map(({ fk, pk }) => - eb(sql.ref(pk), '=', fromEntity[fk]) - ) - ), - eb.or(deleteConditions.map((d) => eb.and(d))), - ]) - ) - .modifyEnd( - this.makeContextComment({ - model, - operation: 'delete', - }) - ); - deleteResult = await query.executeTakeFirstOrThrow(); } else { - const query = kysely - .deleteFrom(model) - .where((eb) => - eb.and([ - eb.and( - keyPairs.map(({ fk, pk }) => - eb(sql.ref(fk), '=', fromRelation.ids[pk]) - ) - ), - eb.or(deleteConditions.map((d) => eb.and(d))), - ]) - ) - .modifyEnd( - this.makeContextComment({ model, operation: 'delete' }) + const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs( + this.schema, + fromRelation.model, + fromRelation.field + ); + + if (ownedByModel) { + const fromEntity = await this.readUnique( + kysely, + fromRelation.model as GetModels, + { + where: fromRelation.ids, + } + ); + if (!fromEntity) { + throw new NotFoundError(model); + } + + const fieldDef = this.requireField( + fromRelation.model, + fromRelation.field ); - deleteResult = await query.executeTakeFirstOrThrow(); + invariant(fieldDef.relation?.opposite); + deleteResult = await this.delete( + kysely, + model, + { + AND: [ + { + // filter for parent + [fieldDef.relation.opposite]: + Object.fromEntries( + keyPairs.map(({ fk, pk }) => [ + fk, + fromEntity[pk], + ]) + ), + }, + { + OR: deleteConditions, + }, + ], + }, + false + ); + } else { + deleteResult = await this.delete( + kysely, + model, + { + AND: [ + Object.fromEntries( + keyPairs.map(({ fk, pk }) => [ + fk, + fromRelation.ids[pk], + ]) + ), + { + OR: deleteConditions, + }, + ], + }, + false + ); + } } // validate result - if ( - throwForNotFound && - expectedDeleteCount > deleteResult.numDeletedRows - ) { + if (throwForNotFound && expectedDeleteCount > deleteResult.count) { // some entities were not deleted throw new NotFoundError(model); } diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 62d696bc..96ee2bfb 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -263,6 +263,40 @@ export function makeDefaultOrderBy( ); } +export function getManyToManyRelation( + schema: SchemaDef, + model: string, + field: string +) { + const fieldDef = requireField(schema, model, field); + if (!fieldDef.array || !fieldDef.relation?.opposite) { + return undefined; + } + const oppositeFieldDef = requireField( + schema, + fieldDef.type, + fieldDef.relation.opposite + ); + if (oppositeFieldDef.array) { + // Prisma's convention for many-to-many relation: + // - model are sorted alphabetically by name + // - join table is named _To, unless an explicit name is provided by `@relation` + // - foreign keys are named A and B (based on the order of the model) + const sortedModelNames = [model, fieldDef.type].sort(); + return { + parentFkName: sortedModelNames[0] === model ? 'A' : 'B', + otherModel: fieldDef.type, + otherField: fieldDef.relation.opposite, + otherFkName: sortedModelNames[0] === fieldDef.type ? 'A' : 'B', + joinTable: fieldDef.relation.name + ? `_${fieldDef.relation.name}` + : `_${sortedModelNames[0]}To${sortedModelNames[1]}`, + }; + } else { + return undefined; + } +} + export function ensureArray(value: T | T[]): T[] { if (Array.isArray(value)) { return value; diff --git a/packages/runtime/src/schema/schema.ts b/packages/runtime/src/schema/schema.ts index c504a976..2d96d239 100644 --- a/packages/runtime/src/schema/schema.ts +++ b/packages/runtime/src/schema/schema.ts @@ -50,6 +50,7 @@ export type CascadeAction = | 'SetDefault'; export type RelationInfo = { + name?: string; fields?: string[]; references?: string[]; opposite?: string; diff --git a/packages/runtime/test/client-api/relation.test.ts b/packages/runtime/test/client-api/relation.test.ts new file mode 100644 index 00000000..2760ed49 --- /dev/null +++ b/packages/runtime/test/client-api/relation.test.ts @@ -0,0 +1,720 @@ +import { beforeEach, describe, expect, it } from 'vitest'; +import { createTestClient } from '../utils'; + +describe('Relation tests', () => { + it('works with unnamed one-to-one relation', async () => { + const client = await createTestClient(` + model User { + id Int @id @default(autoincrement()) + name String + profile Profile? + } + + model Profile { + id Int @id @default(autoincrement()) + age Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + `); + + await expect( + client.user.create({ + data: { + name: 'User', + profile: { create: { age: 20 } }, + }, + include: { profile: true }, + }) + ).resolves.toMatchObject({ + name: 'User', + profile: { age: 20 }, + }); + }); + + it('works with named one-to-one relation', async () => { + const client = await createTestClient(` + model User { + id Int @id @default(autoincrement()) + name String + profile1 Profile? @relation('profile1') + profile2 Profile? @relation('profile2') + } + + model Profile { + id Int @id @default(autoincrement()) + age Int + user1 User? @relation('profile1', fields: [userId1], references: [id]) + user2 User? @relation('profile2', fields: [userId2], references: [id]) + userId1 Int? @unique + userId2 Int? @unique + } + `); + + await expect( + client.user.create({ + data: { + name: 'User', + profile1: { create: { age: 20 } }, + profile2: { create: { age: 21 } }, + }, + include: { profile1: true, profile2: true }, + }) + ).resolves.toMatchObject({ + name: 'User', + profile1: { age: 20 }, + profile2: { age: 21 }, + }); + }); + + it('works with unnamed one-to-many relation', async () => { + const client = await createTestClient(` + model User { + id Int @id @default(autoincrement()) + name String + posts Post[] + } + + model Post { + id Int @id @default(autoincrement()) + title String + user User @relation(fields: [userId], references: [id]) + userId Int + } + `); + + await expect( + client.user.create({ + data: { + name: 'User', + posts: { + create: [{ title: 'Post 1' }, { title: 'Post 2' }], + }, + }, + include: { posts: true }, + }) + ).resolves.toMatchObject({ + name: 'User', + posts: [ + expect.objectContaining({ title: 'Post 1' }), + expect.objectContaining({ title: 'Post 2' }), + ], + }); + }); + + it('works with named one-to-many relation', async () => { + const client = await createTestClient(` + model User { + id Int @id @default(autoincrement()) + name String + posts1 Post[] @relation('userPosts1') + posts2 Post[] @relation('userPosts2') + } + + model Post { + id Int @id @default(autoincrement()) + title String + user1 User? @relation('userPosts1', fields: [userId1], references: [id]) + user2 User? @relation('userPosts2', fields: [userId2], references: [id]) + userId1 Int? + userId2 Int? + } + `); + + await expect( + client.user.create({ + data: { + name: 'User', + posts1: { + create: [{ title: 'Post 1' }, { title: 'Post 2' }], + }, + posts2: { + create: [{ title: 'Post 3' }, { title: 'Post 4' }], + }, + }, + include: { posts1: true, posts2: true }, + }) + ).resolves.toMatchObject({ + name: 'User', + posts1: [ + expect.objectContaining({ title: 'Post 1' }), + expect.objectContaining({ title: 'Post 2' }), + ], + posts2: [ + expect.objectContaining({ title: 'Post 3' }), + expect.objectContaining({ title: 'Post 4' }), + ], + }); + }); + + it('works with explicit many-to-many relation', async () => { + const client = await createTestClient(` + model User { + id Int @id @default(autoincrement()) + name String + tags UserTag[] + } + + model Tag { + id Int @id @default(autoincrement()) + name String + users UserTag[] + } + + model UserTag { + id Int @id @default(autoincrement()) + userId Int + tagId Int + user User @relation(fields: [userId], references: [id]) + tag Tag @relation(fields: [tagId], references: [id]) + @@unique([userId, tagId]) + } + `); + + await client.user.create({ data: { id: 1, name: 'User1' } }); + await client.user.create({ data: { id: 2, name: 'User2' } }); + await client.tag.create({ data: { id: 1, name: 'Tag1' } }); + await client.tag.create({ data: { id: 2, name: 'Tag2' } }); + + await client.userTag.create({ data: { userId: 1, tagId: 1 } }); + await client.userTag.create({ data: { userId: 1, tagId: 2 } }); + await client.userTag.create({ data: { userId: 2, tagId: 1 } }); + + await expect( + client.user.findMany({ + include: { tags: { include: { tag: true } } }, + }) + ).resolves.toMatchObject([ + expect.objectContaining({ + name: 'User1', + tags: [ + expect.objectContaining({ + tag: expect.objectContaining({ name: 'Tag1' }), + }), + expect.objectContaining({ + tag: expect.objectContaining({ name: 'Tag2' }), + }), + ], + }), + expect.objectContaining({ + name: 'User2', + tags: [ + expect.objectContaining({ + tag: expect.objectContaining({ name: 'Tag1' }), + }), + ], + }), + ]); + }); + + describe('Implicit many-to-many relation', () => { + let client: any; + + beforeEach(async () => { + client = await createTestClient( + ` + model User { + id Int @id @default(autoincrement()) + name String + profile Profile? + tags Tag[] + } + + model Tag { + id Int @id @default(autoincrement()) + name String + users User[] + } + + model Profile { + id Int @id @default(autoincrement()) + age Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + `, + { dbName: 'file:./dev.db', usePrismaPush: true } + ); + }); + + it('works with find', async () => { + await client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { id: 1, name: 'Tag1' }, + { id: 2, name: 'Tag2' }, + ], + }, + profile: { + create: { + id: 1, + age: 20, + }, + }, + }, + }); + + await client.user.create({ + data: { + id: 2, + name: 'User2', + }, + }); + + // include without filter + await expect( + client.user.findFirst({ + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ name: 'Tag1' }), + expect.objectContaining({ name: 'Tag2' }), + ], + }); + + await expect( + client.profile.findFirst({ + include: { + user: { + include: { tags: true }, + }, + }, + }) + ).resolves.toMatchObject({ + user: expect.objectContaining({ + tags: [ + expect.objectContaining({ name: 'Tag1' }), + expect.objectContaining({ name: 'Tag2' }), + ], + }), + }); + + await expect( + client.user.findUnique({ + where: { id: 2 }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [], + }); + + // include with filter + await expect( + client.user.findFirst({ + where: { id: 1 }, + include: { tags: { where: { name: 'Tag1' } } }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' })], + }); + + // filter with m2m + await expect( + client.user.findMany({ + where: { tags: { some: { name: 'Tag1' } } }, + }) + ).resolves.toEqual([ + expect.objectContaining({ + name: 'User1', + }), + ]); + await expect( + client.user.findMany({ + where: { tags: { none: { name: 'Tag1' } } }, + }) + ).resolves.toEqual([ + expect.objectContaining({ + name: 'User2', + }), + ]); + }); + + it('works with create', async () => { + // create + await expect( + client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { + id: 1, + name: 'Tag1', + }, + { + id: 2, + name: 'Tag2', + }, + ], + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ name: 'Tag1' }), + expect.objectContaining({ name: 'Tag2' }), + ], + }); + + // connect + await expect( + client.user.create({ + data: { + id: 2, + name: 'User2', + tags: { connect: { id: 1 } }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' })], + }); + + // connectOrCreate + await expect( + client.user.create({ + data: { + id: 3, + name: 'User3', + tags: { + connectOrCreate: { + where: { id: 1 }, + create: { id: 1, name: 'Tag1' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 1, name: 'Tag1' })], + }); + + await expect( + client.user.create({ + data: { + id: 4, + name: 'User4', + tags: { + connectOrCreate: { + where: { id: 3 }, + create: { id: 3, name: 'Tag3' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 3, name: 'Tag3' })], + }); + }); + + it('works with update', async () => { + // create + await client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { + id: 1, + name: 'Tag1', + }, + ], + }, + }, + include: { tags: true }, + }); + + // create + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + create: [ + { + id: 2, + name: 'Tag2', + }, + ], + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + ], + }); + + await client.tag.create({ + data: { + id: 3, + name: 'Tag3', + }, + }); + + // connect + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { connect: { id: 3 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + // connecting a connected entity is no-op + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { connect: { id: 3 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + + // disconnect + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { disconnect: { id: 3 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + ], + }); + + await expect( + client.$qbRaw + .selectFrom('_TagToUser') + .selectAll() + .where('B', '=', 1) // user id + .where('A', '=', 3) // tag id + .execute() + ).resolves.toHaveLength(0); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { set: [{ id: 2 }, { id: 3 }] } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + + // update - not found + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + update: { + where: { id: 1 }, + data: { name: 'Tag1-updated' }, + }, + }, + }, + }) + ).toBeRejectedNotFound(); + + // update - found + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + update: { + where: { id: 2 }, + data: { name: 'Tag2-updated' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 2, name: 'Tag2-updated' }), + expect.objectContaining({ id: 3, name: 'Tag3' }), + ], + }); + + // updateMany + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + updateMany: { + where: { id: { not: 2 } }, + data: { name: 'Tag3-updated' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 2, name: 'Tag2-updated' }), + expect.objectContaining({ id: 3, name: 'Tag3-updated' }), + ], + }); + + await expect( + client.tag.findUnique({ where: { id: 1 } }) + ).resolves.toMatchObject({ + name: 'Tag1', + }); + + // upsert - update + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + upsert: { + where: { id: 3 }, + create: { id: 3, name: 'Tag4' }, + update: { name: 'Tag3-updated-1' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 2, name: 'Tag2-updated' }), + expect.objectContaining({ id: 3, name: 'Tag3-updated-1' }), + ], + }); + + // upsert - create + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + upsert: { + where: { id: 4 }, + create: { id: 4, name: 'Tag4' }, + update: { name: 'Tag4' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: expect.arrayContaining([ + expect.objectContaining({ id: 4, name: 'Tag4' }), + ]), + }); + + // delete - not found + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { delete: { id: 1 } } }, + }) + ).toBeRejectedNotFound(); + + // delete - found + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { delete: { id: 2 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 3 }), + expect.objectContaining({ id: 4 }), + ], + }); + await expect( + client.tag.findUnique({ where: { id: 2 } }) + ).toResolveNull(); + + // deleteMany + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { deleteMany: { id: { in: [1, 2, 3] } } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 4 })], + }); + await expect( + client.tag.findUnique({ where: { id: 3 } }) + ).toResolveNull(); + await expect( + client.tag.findUnique({ where: { id: 1 } }) + ).toResolveTruthy(); + }); + + it('works with delete', async () => { + await client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { id: 1, name: 'Tag1' }, + { id: 2, name: 'Tag2' }, + ], + }, + }, + }); + + // cascade from tag + await client.tag.delete({ + where: { id: 1 }, + }); + await expect( + client.user.findUnique({ + where: { id: 1 }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 2 })], + }); + + // cascade from user + await client.user.delete({ + where: { id: 1 }, + }); + await expect( + client.tag.findUnique({ + where: { id: 2 }, + include: { users: true }, + }) + ).resolves.toMatchObject({ + users: [], + }); + }); + }); +}); diff --git a/packages/runtime/test/policy/todo-sample.test.ts b/packages/runtime/test/policy/todo-sample.test.ts index 2d278d02..922801b7 100644 --- a/packages/runtime/test/policy/todo-sample.test.ts +++ b/packages/runtime/test/policy/todo-sample.test.ts @@ -8,9 +8,10 @@ describe('todo sample tests', () => { let schema: SchemaDef; beforeAll(async () => { - schema = await generateTsSchemaFromFile( + const r = await generateTsSchemaFromFile( path.join(__dirname, '../schemas/todo.zmodel') ); + schema = r.schema; }); it('works with user CRUD', async () => { diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index e71f3d68..4e87db46 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -1,5 +1,10 @@ +import { loadDocument } from '@zenstackhq/language'; +import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import { generateTsSchema } from '@zenstackhq/testtools'; import Sqlite from 'better-sqlite3'; +import { execSync } from 'node:child_process'; +import fs from 'node:fs'; +import path from 'node:path'; import { Client as PGClient, Pool } from 'pg'; import invariant from 'tiny-invariant'; import { ZenStackClient } from '../src/client'; @@ -58,6 +63,7 @@ export type CreateTestClientOptions = ClientOptions & { provider?: 'sqlite' | 'postgresql'; dbName?: string; + usePrismaPush?: boolean; }; export async function createTestClient( @@ -72,16 +78,44 @@ export async function createTestClient( schema: Schema | string, options?: CreateTestClientOptions ): Promise { - let _schema = - typeof schema === 'string' - ? ((await generateTsSchema( - schema, - options?.provider, - options?.dbName - )) as Schema) - : schema; + let workDir: string | undefined; + let _schema: Schema; - const { plugins, ...rest } = options ?? {}; + if (typeof schema === 'string') { + const generated = await generateTsSchema( + schema, + options?.provider, + options?.dbName + ); + workDir = generated.workDir; + _schema = generated.schema as Schema; + } else { + _schema = schema; + } + + if (options?.usePrismaPush) { + invariant(typeof schema === 'string', 'schema must be a string'); + invariant(workDir, 'workDir is required'); + const r = await loadDocument(path.resolve(workDir, 'schema.zmodel')); + if (!r.success) { + throw new Error(r.errors.join('\n')); + } + const prismaSchema = new PrismaSchemaGenerator(r.model); + const prismaSchemaText = await prismaSchema.generate(); + fs.writeFileSync( + path.resolve(workDir, 'schema.prisma'), + prismaSchemaText + ); + execSync( + 'npx prisma db push --schema ./schema.prisma --skip-generate', + { + cwd: workDir!, + stdio: 'inherit', + } + ); + } + + const { plugins, usePrismaPush, ...rest } = options ?? {}; let client = new ZenStackClient(_schema, rest as ClientOptions); @@ -94,7 +128,9 @@ export async function createTestClient( await pgClient.end(); } - await client.$pushSchema(); + if (!usePrismaPush) { + await client.$pushSchema(); + } if (options?.plugins) { for (const plugin of options.plugins) { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index f7f7bca8..fccc7ce8 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -1,4 +1,5 @@ import * as ModelUtils from './model-utils'; +export { PrismaSchemaGenerator } from './prisma/prisma-schema-generator'; export * from './ts-schema-generator'; export * from './zmodel-code-generator'; export { ModelUtils }; diff --git a/packages/cli/src/prisma/indent-string.ts b/packages/sdk/src/prisma/indent-string.ts similarity index 100% rename from packages/cli/src/prisma/indent-string.ts rename to packages/sdk/src/prisma/indent-string.ts diff --git a/packages/cli/src/prisma/prisma-builder.ts b/packages/sdk/src/prisma/prisma-builder.ts similarity index 100% rename from packages/cli/src/prisma/prisma-builder.ts rename to packages/sdk/src/prisma/prisma-builder.ts diff --git a/packages/cli/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts similarity index 99% rename from packages/cli/src/prisma/prisma-schema-generator.ts rename to packages/sdk/src/prisma/prisma-schema-generator.ts index 62744403..f0bcda10 100644 --- a/packages/cli/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -32,7 +32,7 @@ import { import { AstUtils } from 'langium'; import { match, P } from 'ts-pattern'; -import { ModelUtils, ZModelCodeGenerator } from '@zenstackhq/sdk'; +import { ModelUtils, ZModelCodeGenerator } from '..'; import { AttributeArgValue, ModelField, diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index fb27cb1d..d5b60df1 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -769,6 +769,16 @@ export class TsSchemaGenerator { ); } + const relationName = this.getRelationName(field); + if (relationName) { + relationFields.push( + ts.factory.createPropertyAssignment( + 'name', + ts.factory.createStringLiteral(relationName) + ) + ); + } + const relation = getAttribute(field, '@relation'); if (relation) { for (const arg of relation.args) { @@ -843,15 +853,39 @@ export class TsSchemaGenerator { const sourceModel = field.$container as DataModel; const targetModel = field.type.reference.ref as DataModel; - + const relationName = this.getRelationName(field); for (const otherField of targetModel.fields) { if (otherField === field) { // backlink field is never self continue; } if (otherField.type.reference?.ref === sourceModel) { - // TODO: named relation - return otherField; + if (relationName) { + // if relation has a name, the opposite side must match + const otherRelationName = this.getRelationName(otherField); + if (otherRelationName === relationName) { + return otherField; + } + } else { + return otherField; + } + } + } + return undefined; + } + + private getRelationName(field: DataModelField) { + const relation = getAttribute(field, '@relation'); + if (relation) { + const nameArg = relation.args.find( + (arg) => arg.$resolvedParam.name === 'name' + ); + if (nameArg) { + invariant( + isLiteralExpr(nameArg.value), + 'name must be a literal' + ); + return nameArg.value.value as string; } } return undefined; diff --git a/packages/sdk/tsconfig.json b/packages/sdk/tsconfig.json index b2b15c85..dc25b0a5 100644 --- a/packages/sdk/tsconfig.json +++ b/packages/sdk/tsconfig.json @@ -4,5 +4,5 @@ "outDir": "dist", "noUnusedLocals": false }, - "include": ["src/**/*.ts", "test/**/*.ts"] + "include": ["src/**/*.ts"] } diff --git a/packages/testtools/package.json b/packages/testtools/package.json index f83ba902..9996bc59 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -35,7 +35,8 @@ "glob": "^11.0.2", "tmp": "^0.2.3", "ts-pattern": "^5.7.1", - "typescript": "^5.8.3" + "typescript": "^5.8.3", + "prisma": "^6.9.0" }, "peerDependencies": { "better-sqlite3": "^11.8.1", diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index ebfb9494..0eab224e 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -1,10 +1,10 @@ import type { SchemaDef } from '@zenstackhq/runtime/schema'; import { TsSchemaGenerator } from '@zenstackhq/sdk'; +import { glob } from 'glob'; import { execSync } from 'node:child_process'; import fs from 'node:fs'; import path from 'node:path'; import tmp from 'tmp'; -import { glob } from 'glob'; import { match } from 'ts-pattern'; function makePrelude(provider: 'sqlite' | 'postgresql', dbName?: string) { @@ -13,7 +13,7 @@ function makePrelude(provider: 'sqlite' | 'postgresql', dbName?: string) { return ` datasource db { provider = 'sqlite' - url = ':memory:' + url = '${dbName ?? ':memory:'}' } `; }) @@ -87,7 +87,7 @@ export async function generateTsSchema( // load the schema module const module = await import(path.join(workDir, 'schema.js')); - return module.schema as SchemaDef; + return { workDir, schema: module.schema as SchemaDef }; } export function generateTsSchemaFromFile(filePath: string) { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 471155e9..d01c62e4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -16,7 +16,7 @@ importers: version: 4.1.5 tsup: specifier: ^8.3.5 - version: 8.3.5(@swc/core@1.10.15)(postcss@8.5.1)(tsx@4.19.2)(typescript@5.7.3) + version: 8.3.5(@swc/core@1.10.15)(jiti@2.4.2)(postcss@8.5.1)(tsx@4.19.2)(typescript@5.7.3) tsx: specifier: ^4.19.2 version: 4.19.2 @@ -182,6 +182,9 @@ importers: '@zenstackhq/language': specifier: workspace:* version: link:../language + '@zenstackhq/sdk': + specifier: workspace:* + version: link:../sdk '@zenstackhq/testtools': specifier: workspace:* version: link:../testtools @@ -249,6 +252,9 @@ importers: pg: specifier: ^8.13.1 version: 8.13.1 + prisma: + specifier: ^6.9.0 + version: 6.9.0(typescript@5.8.3) tmp: specifier: ^0.2.3 version: 0.2.3 @@ -822,21 +828,39 @@ packages: '@prisma/config@6.5.0': resolution: {integrity: sha512-sOH/2Go9Zer67DNFLZk6pYOHj+rumSb0VILgltkoxOjYnlLqUpHPAN826vnx8HigqnOCxj9LRhT6U7uLiIIWgw==} + '@prisma/config@6.9.0': + resolution: {integrity: sha512-Wcfk8/lN3WRJd5w4jmNQkUwhUw0eksaU/+BlAJwPQKW10k0h0LC9PD/6TQFmqKVbHQL0vG2z266r0S1MPzzhbA==} + '@prisma/debug@6.5.0': resolution: {integrity: sha512-fc/nusYBlJMzDmDepdUtH9aBsJrda2JNErP9AzuHbgUEQY0/9zQYZdNlXmKoIWENtio+qarPNe/+DQtrX5kMcQ==} + '@prisma/debug@6.9.0': + resolution: {integrity: sha512-bFeur/qi/Q+Mqk4JdQ3R38upSYPebv5aOyD1RKywVD+rAMLtRkmTFn28ZuTtVOnZHEdtxnNOCH+bPIeSGz1+Fg==} + '@prisma/engines-version@6.5.0-73.173f8d54f8d52e692c7e27e72a88314ec7aeff60': resolution: {integrity: sha512-iK3EmiVGFDCmXjSpdsKGNqy9hOdLnvYBrJB61far/oP03hlIxrb04OWmDjNTwtmZ3UZdA5MCvI+f+3k2jPTflQ==} + '@prisma/engines-version@6.9.0-10.81e4af48011447c3cc503a190e86995b66d2a28e': + resolution: {integrity: sha512-Qp9gMoBHgqhKlrvumZWujmuD7q4DV/gooEyPCLtbkc13EZdSz2RsGUJ5mHb3RJgAbk+dm6XenqG7obJEhXcJ6Q==} + '@prisma/engines@6.5.0': resolution: {integrity: sha512-FVPQYHgOllJklN9DUyujXvh3hFJCY0NX86sDmBErLvoZjy2OXGiZ5FNf3J/C4/RZZmCypZBYpBKEhx7b7rEsdw==} + '@prisma/engines@6.9.0': + resolution: {integrity: sha512-im0X0bwDLA0244CDf8fuvnLuCQcBBdAGgr+ByvGfQY9wWl6EA+kRGwVk8ZIpG65rnlOwtaWIr/ZcEU5pNVvq9g==} + '@prisma/fetch-engine@6.5.0': resolution: {integrity: sha512-3LhYA+FXP6pqY8FLHCjewyE8pGXXJ7BxZw2rhPq+CZAhvflVzq4K8Qly3OrmOkn6wGlz79nyLQdknyCG2HBTuA==} + '@prisma/fetch-engine@6.9.0': + resolution: {integrity: sha512-PMKhJdl4fOdeE3J3NkcWZ+tf3W6rx3ht/rLU8w4SXFRcLhd5+3VcqY4Kslpdm8osca4ej3gTfB3+cSk5pGxgFg==} + '@prisma/get-platform@6.5.0': resolution: {integrity: sha512-xYcvyJwNMg2eDptBYFqFLUCfgi+wZLcj6HDMsj0Qw0irvauG4IKmkbywnqwok0B+k+W+p+jThM2DKTSmoPCkzw==} + '@prisma/get-platform@6.9.0': + resolution: {integrity: sha512-/B4n+5V1LI/1JQcHp+sUpyRT1bBgZVPHbsC4lt4/19Xp4jvNIVcq5KYNtQDk5e/ukTSjo9PZVAxxy9ieFtlpTQ==} + '@rollup/rollup-android-arm-eabi@4.30.1': resolution: {integrity: sha512-pSWY+EVt3rJ9fQ3IqlrEUtXh3cGqGtPDH1FQlNZehO2yYxCHEX1SPsz1M//NXwYfbTlcKr9WObLnJX9FsS9K1Q==} cpu: [arm] @@ -1862,6 +1886,10 @@ packages: resolution: {integrity: sha512-9DDdhb5j6cpeitCbvLO7n7J4IxnbM6hoF6O1g4HQ5TfhvvKN8ywDM7668ZhMHRqVmxqhps/F6syWK2KcPxYlkw==} engines: {node: 20 || >=22} + jiti@2.4.2: + resolution: {integrity: sha512-rg9zJN+G4n2nfJl5MW3BMygZX56zKPNVEYYqq7adpmMh4Jn2QNEwhvQlFy6jPVdcod7txZtKHWnyZiA3a0zP7A==} + hasBin: true + joycon@3.1.1: resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} engines: {node: '>=10'} @@ -2297,6 +2325,16 @@ packages: typescript: optional: true + prisma@6.9.0: + resolution: {integrity: sha512-resJAwMyZREC/I40LF6FZ6rZTnlrlrYrb63oW37Gq+U+9xHwbyMSPJjKtM7VZf3gTO86t/Oyz+YeSXr3CmAY1Q==} + engines: {node: '>=18.18'} + hasBin: true + peerDependencies: + typescript: '>=5.1.0' + peerDependenciesMeta: + typescript: + optional: true + pump@3.0.2: resolution: {integrity: sha512-tUPXtzlGM8FE3P0ZL6DVs/3P58k9nk8/jZeQCurTJylQA8qFYzHFfhBJkuqyE0FifOsQ0uKWekiZ5g8wtr28cw==} @@ -3247,10 +3285,18 @@ snapshots: transitivePeerDependencies: - supports-color + '@prisma/config@6.9.0': + dependencies: + jiti: 2.4.2 + '@prisma/debug@6.5.0': {} + '@prisma/debug@6.9.0': {} + '@prisma/engines-version@6.5.0-73.173f8d54f8d52e692c7e27e72a88314ec7aeff60': {} + '@prisma/engines-version@6.9.0-10.81e4af48011447c3cc503a190e86995b66d2a28e': {} + '@prisma/engines@6.5.0': dependencies: '@prisma/debug': 6.5.0 @@ -3258,16 +3304,33 @@ snapshots: '@prisma/fetch-engine': 6.5.0 '@prisma/get-platform': 6.5.0 + '@prisma/engines@6.9.0': + dependencies: + '@prisma/debug': 6.9.0 + '@prisma/engines-version': 6.9.0-10.81e4af48011447c3cc503a190e86995b66d2a28e + '@prisma/fetch-engine': 6.9.0 + '@prisma/get-platform': 6.9.0 + '@prisma/fetch-engine@6.5.0': dependencies: '@prisma/debug': 6.5.0 '@prisma/engines-version': 6.5.0-73.173f8d54f8d52e692c7e27e72a88314ec7aeff60 '@prisma/get-platform': 6.5.0 + '@prisma/fetch-engine@6.9.0': + dependencies: + '@prisma/debug': 6.9.0 + '@prisma/engines-version': 6.9.0-10.81e4af48011447c3cc503a190e86995b66d2a28e + '@prisma/get-platform': 6.9.0 + '@prisma/get-platform@6.5.0': dependencies: '@prisma/debug': 6.5.0 + '@prisma/get-platform@6.9.0': + dependencies: + '@prisma/debug': 6.9.0 + '@rollup/rollup-android-arm-eabi@4.30.1': optional: true @@ -4432,6 +4495,8 @@ snapshots: dependencies: '@isaacs/cliui': 8.0.2 + jiti@2.4.2: {} + joycon@3.1.1: {} js-yaml@4.1.0: @@ -4772,10 +4837,11 @@ snapshots: possible-typed-array-names@1.1.0: {} - postcss-load-config@6.0.1(postcss@8.5.1)(tsx@4.19.2): + postcss-load-config@6.0.1(jiti@2.4.2)(postcss@8.5.1)(tsx@4.19.2): dependencies: lilconfig: 3.1.3 optionalDependencies: + jiti: 2.4.2 postcss: 8.5.1 tsx: 4.19.2 @@ -4844,6 +4910,13 @@ snapshots: transitivePeerDependencies: - supports-color + prisma@6.9.0(typescript@5.8.3): + dependencies: + '@prisma/config': 6.9.0 + '@prisma/engines': 6.9.0 + optionalDependencies: + typescript: 5.8.3 + pump@3.0.2: dependencies: end-of-stream: 1.4.4 @@ -5229,7 +5302,7 @@ snapshots: ts-pattern@5.7.1: {} - tsup@8.3.5(@swc/core@1.10.15)(postcss@8.5.1)(tsx@4.19.2)(typescript@5.7.3): + tsup@8.3.5(@swc/core@1.10.15)(jiti@2.4.2)(postcss@8.5.1)(tsx@4.19.2)(typescript@5.7.3): dependencies: bundle-require: 5.1.0(esbuild@0.24.2) cac: 6.7.14 @@ -5239,7 +5312,7 @@ snapshots: esbuild: 0.24.2 joycon: 3.1.1 picocolors: 1.1.1 - postcss-load-config: 6.0.1(postcss@8.5.1)(tsx@4.19.2) + postcss-load-config: 6.0.1(jiti@2.4.2)(postcss@8.5.1)(tsx@4.19.2) resolve-from: 5.0.0 rollup: 4.30.1 source-map: 0.8.0-beta.0 diff --git a/samples/blog/zenstack/schema.ts b/samples/blog/zenstack/schema.ts index 994c5309..d4ba702d 100644 --- a/samples/blog/zenstack/schema.ts +++ b/samples/blog/zenstack/schema.ts @@ -10,7 +10,7 @@ import { toDialectConfig } from "@zenstackhq/runtime/utils/sqlite-utils"; export const schema = { provider: { type: "sqlite", - dialectConfigProvider: function (): any { + dialectConfigProvider: function () { return toDialectConfig("./dev.db", typeof __dirname !== 'undefined' ? __dirname : path.dirname(url.fileURLToPath(import.meta.url))); } }, From 34689609b87c388d22c23aa64ee0391fab9a100b Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sun, 8 Jun 2025 22:24:20 -0700 Subject: [PATCH 2/2] postgres support --- .../src/client/crud/dialects/postgresql.ts | 65 +- .../src/client/helpers/schema-db-pusher.ts | 22 + .../runtime/test/client-api/relation.test.ts | 972 ++++++++++-------- packages/runtime/test/utils.ts | 22 +- 4 files changed, 610 insertions(+), 471 deletions(-) diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 857f45e7..72d40932 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -6,6 +6,7 @@ import { type RawBuilder, type SelectQueryBuilder, } from 'kysely'; +import invariant from 'tiny-invariant'; import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import type { BuiltinType, FieldDef, GetModels } from '../../../schema/schema'; @@ -13,6 +14,8 @@ import type { FindArgs } from '../../crud-types'; import { buildFieldRef, buildJoinPairs, + getIdFields, + getManyToManyRelation, requireField, requireModel, } from '../../query-utils'; @@ -129,21 +132,61 @@ export class PostgresCrudDialect< } // add join conditions - const joinPairs = buildJoinPairs( + + const m2m = getManyToManyRelation( this.schema, model, - parentName, - relationField, - relationModel + relationField ); - subQuery = subQuery.where((eb) => - this.and( - eb, - ...joinPairs.map(([left, right]) => - eb(sql.ref(left), '=', sql.ref(right)) + + if (m2m) { + // many-to-many relation + const parentIds = getIdFields(this.schema, model); + const relationIds = getIdFields( + this.schema, + relationModel + ); + invariant( + parentIds.length === 1, + 'many-to-many relation must have exactly one id field' + ); + invariant( + relationIds.length === 1, + 'many-to-many relation must have exactly one id field' + ); + subQuery = subQuery.where( + eb( + eb.ref(`${relationModel}.${relationIds[0]}`), + 'in', + eb + .selectFrom(m2m.joinTable) + .select( + `${m2m.joinTable}.${m2m.otherFkName}` + ) + .whereRef( + `${parentName}.${parentIds[0]}`, + '=', + `${m2m.joinTable}.${m2m.parentFkName}` + ) ) - ) - ); + ); + } else { + const joinPairs = buildJoinPairs( + this.schema, + model, + parentName, + relationField, + relationModel + ); + subQuery = subQuery.where((eb) => + this.and( + eb, + ...joinPairs.map(([left, right]) => + eb(sql.ref(left), '=', sql.ref(right)) + ) + ) + ); + } return subQuery.as(joinTableName); }); diff --git a/packages/runtime/src/client/helpers/schema-db-pusher.ts b/packages/runtime/src/client/helpers/schema-db-pusher.ts index d0e752a6..8032e84c 100644 --- a/packages/runtime/src/client/helpers/schema-db-pusher.ts +++ b/packages/runtime/src/client/helpers/schema-db-pusher.ts @@ -165,6 +165,13 @@ export class SchemaDbPusher { col = col.notNull(); } + if ( + this.isAutoIncrement(fieldDef) && + this.schema.provider.type === 'sqlite' + ) { + col = col.autoIncrement(); + } + return col; } ); @@ -177,6 +184,13 @@ export class SchemaDbPusher { : 'text'; } + if ( + this.isAutoIncrement(fieldDef) && + this.schema.provider.type === 'postgresql' + ) { + return 'serial'; + } + const type = fieldDef.type as BuiltinType; let result = match(type) .with('String', () => 'text') @@ -201,6 +215,14 @@ export class SchemaDbPusher { } } + private isAutoIncrement(fieldDef: FieldDef) { + return ( + fieldDef.default && + Expression.isCall(fieldDef.default) && + fieldDef.default.function === 'autoincrement' + ); + } + private addForeignKeyConstraint( table: CreateTableBuilder, model: GetModels, diff --git a/packages/runtime/test/client-api/relation.test.ts b/packages/runtime/test/client-api/relation.test.ts index 2760ed49..da688b14 100644 --- a/packages/runtime/test/client-api/relation.test.ts +++ b/packages/runtime/test/client-api/relation.test.ts @@ -1,9 +1,23 @@ -import { beforeEach, describe, expect, it } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { createTestClient } from '../utils'; -describe('Relation tests', () => { +const TEST_DB = 'client-api-relation-test'; + +describe.each([ + { + provider: 'sqlite' as const, + }, + { provider: 'postgresql' as const }, +])('Relation tests for $provider', ({ provider }) => { + let client: any; + + afterEach(async () => { + await client?.$disconnect(); + }); + it('works with unnamed one-to-one relation', async () => { - const client = await createTestClient(` + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -16,7 +30,12 @@ describe('Relation tests', () => { user User @relation(fields: [userId], references: [id]) userId Int @unique } - `); + `, + { + provider, + dbName: TEST_DB, + } + ); await expect( client.user.create({ @@ -33,7 +52,8 @@ describe('Relation tests', () => { }); it('works with named one-to-one relation', async () => { - const client = await createTestClient(` + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -49,7 +69,12 @@ describe('Relation tests', () => { userId1 Int? @unique userId2 Int? @unique } - `); + `, + { + provider, + dbName: TEST_DB, + } + ); await expect( client.user.create({ @@ -68,7 +93,8 @@ describe('Relation tests', () => { }); it('works with unnamed one-to-many relation', async () => { - const client = await createTestClient(` + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -81,7 +107,12 @@ describe('Relation tests', () => { user User @relation(fields: [userId], references: [id]) userId Int } - `); + `, + { + provider, + dbName: TEST_DB, + } + ); await expect( client.user.create({ @@ -103,7 +134,8 @@ describe('Relation tests', () => { }); it('works with named one-to-many relation', async () => { - const client = await createTestClient(` + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -119,7 +151,12 @@ describe('Relation tests', () => { userId1 Int? userId2 Int? } - `); + `, + { + provider, + dbName: TEST_DB, + } + ); await expect( client.user.create({ @@ -148,7 +185,8 @@ describe('Relation tests', () => { }); it('works with explicit many-to-many relation', async () => { - const client = await createTestClient(` + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String @@ -169,7 +207,12 @@ describe('Relation tests', () => { tag Tag @relation(fields: [tagId], references: [id]) @@unique([userId, tagId]) } - `); + `, + { + provider, + dbName: TEST_DB, + } + ); await client.user.create({ data: { id: 1, name: 'User1' } }); await client.user.create({ data: { id: 2, name: 'User2' } }); @@ -207,23 +250,27 @@ describe('Relation tests', () => { ]); }); - describe('Implicit many-to-many relation', () => { - let client: any; - - beforeEach(async () => { - client = await createTestClient( - ` + describe.each([{ relationName: undefined }, { relationName: 'myM2M' }])( + 'Implicit many-to-many relation ($relationName)', + ({ relationName }) => { + beforeEach(async () => { + client = await createTestClient( + ` model User { id Int @id @default(autoincrement()) name String profile Profile? - tags Tag[] + tags Tag[] ${ + relationName ? `@relation("${relationName}")` : '' + } } model Tag { id Int @id @default(autoincrement()) name String - users User[] + users User[] ${ + relationName ? `@relation("${relationName}")` : '' + } } model Profile { @@ -233,488 +280,513 @@ describe('Relation tests', () => { userId Int @unique } `, - { dbName: 'file:./dev.db', usePrismaPush: true } - ); - }); - - it('works with find', async () => { - await client.user.create({ - data: { - id: 1, - name: 'User1', - tags: { - create: [ - { id: 1, name: 'Tag1' }, - { id: 2, name: 'Tag2' }, - ], - }, - profile: { - create: { - id: 1, - age: 20, - }, - }, - }, - }); - - await client.user.create({ - data: { - id: 2, - name: 'User2', - }, + { + provider, + dbName: + provider === 'sqlite' ? 'file:./dev.db' : TEST_DB, + usePrismaPush: true, + } + ); }); - // include without filter - await expect( - client.user.findFirst({ - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ name: 'Tag1' }), - expect.objectContaining({ name: 'Tag2' }), - ], - }); - - await expect( - client.profile.findFirst({ - include: { - user: { - include: { tags: true }, - }, - }, - }) - ).resolves.toMatchObject({ - user: expect.objectContaining({ - tags: [ - expect.objectContaining({ name: 'Tag1' }), - expect.objectContaining({ name: 'Tag2' }), - ], - }), - }); - - await expect( - client.user.findUnique({ - where: { id: 2 }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [], - }); - - // include with filter - await expect( - client.user.findFirst({ - where: { id: 1 }, - include: { tags: { where: { name: 'Tag1' } } }, - }) - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ name: 'Tag1' })], - }); - - // filter with m2m - await expect( - client.user.findMany({ - where: { tags: { some: { name: 'Tag1' } } }, - }) - ).resolves.toEqual([ - expect.objectContaining({ - name: 'User1', - }), - ]); - await expect( - client.user.findMany({ - where: { tags: { none: { name: 'Tag1' } } }, - }) - ).resolves.toEqual([ - expect.objectContaining({ - name: 'User2', - }), - ]); - }); - - it('works with create', async () => { - // create - await expect( - client.user.create({ + it('works with find', async () => { + await client.user.create({ data: { id: 1, name: 'User1', tags: { create: [ - { - id: 1, - name: 'Tag1', - }, - { - id: 2, - name: 'Tag2', - }, + { id: 1, name: 'Tag1' }, + { id: 2, name: 'Tag2' }, ], }, + profile: { + create: { + id: 1, + age: 20, + }, + }, }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ name: 'Tag1' }), - expect.objectContaining({ name: 'Tag2' }), - ], - }); + }); - // connect - await expect( - client.user.create({ + await client.user.create({ data: { id: 2, name: 'User2', - tags: { connect: { id: 1 } }, }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ name: 'Tag1' })], - }); + }); + + // include without filter + await expect( + client.user.findFirst({ + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ name: 'Tag1' }), + expect.objectContaining({ name: 'Tag2' }), + ], + }); - // connectOrCreate - await expect( - client.user.create({ - data: { - id: 3, - name: 'User3', - tags: { - connectOrCreate: { - where: { id: 1 }, - create: { id: 1, name: 'Tag1' }, + await expect( + client.profile.findFirst({ + include: { + user: { + include: { tags: true }, }, }, - }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 1, name: 'Tag1' })], + }) + ).resolves.toMatchObject({ + user: expect.objectContaining({ + tags: [ + expect.objectContaining({ name: 'Tag1' }), + expect.objectContaining({ name: 'Tag2' }), + ], + }), + }); + + await expect( + client.user.findUnique({ + where: { id: 2 }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [], + }); + + // include with filter + await expect( + client.user.findFirst({ + where: { id: 1 }, + include: { tags: { where: { name: 'Tag1' } } }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' })], + }); + + // filter with m2m + await expect( + client.user.findMany({ + where: { tags: { some: { name: 'Tag1' } } }, + }) + ).resolves.toEqual([ + expect.objectContaining({ + name: 'User1', + }), + ]); + await expect( + client.user.findMany({ + where: { tags: { none: { name: 'Tag1' } } }, + }) + ).resolves.toEqual([ + expect.objectContaining({ + name: 'User2', + }), + ]); }); - await expect( - client.user.create({ - data: { - id: 4, - name: 'User4', - tags: { - connectOrCreate: { - where: { id: 3 }, - create: { id: 3, name: 'Tag3' }, + it('works with create', async () => { + // create + await expect( + client.user.create({ + data: { + id: 1, + name: 'User1', + tags: { + create: [ + { + id: 1, + name: 'Tag1', + }, + { + id: 2, + name: 'Tag2', + }, + ], }, }, - }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 3, name: 'Tag3' })], - }); - }); - - it('works with update', async () => { - // create - await client.user.create({ - data: { - id: 1, - name: 'User1', - tags: { - create: [ - { - id: 1, - name: 'Tag1', + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ name: 'Tag1' }), + expect.objectContaining({ name: 'Tag2' }), + ], + }); + + // connect + await expect( + client.user.create({ + data: { + id: 2, + name: 'User2', + tags: { connect: { id: 1 } }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ name: 'Tag1' })], + }); + + // connectOrCreate + await expect( + client.user.create({ + data: { + id: 3, + name: 'User3', + tags: { + connectOrCreate: { + where: { id: 1 }, + create: { id: 1, name: 'Tag1' }, + }, }, - ], - }, - }, - include: { tags: true }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 1, name: 'Tag1' })], + }); + + await expect( + client.user.create({ + data: { + id: 4, + name: 'User4', + tags: { + connectOrCreate: { + where: { id: 3 }, + create: { id: 3, name: 'Tag3' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 3, name: 'Tag3' })], + }); }); - // create - await expect( - client.user.update({ - where: { id: 1 }, + it('works with update', async () => { + // create + await client.user.create({ data: { + id: 1, + name: 'User1', tags: { create: [ { - id: 2, - name: 'Tag2', + id: 1, + name: 'Tag1', }, ], }, }, include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - ], - }); - - await client.tag.create({ - data: { - id: 3, - name: 'Tag3', - }, - }); - - // connect - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { connect: { id: 3 } } }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - expect.objectContaining({ id: 3 }), - ], - }); - // connecting a connected entity is no-op - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { connect: { id: 3 } } }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - expect.objectContaining({ id: 3 }), - ], - }); - - // disconnect - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { disconnect: { id: 3 } } }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 1 }), - expect.objectContaining({ id: 2 }), - ], - }); - - await expect( - client.$qbRaw - .selectFrom('_TagToUser') - .selectAll() - .where('B', '=', 1) // user id - .where('A', '=', 3) // tag id - .execute() - ).resolves.toHaveLength(0); - - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { set: [{ id: 2 }, { id: 3 }] } }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 2 }), - expect.objectContaining({ id: 3 }), - ], - }); - - // update - not found - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - update: { - where: { id: 1 }, - data: { name: 'Tag1-updated' }, + }); + + // create + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + create: [ + { + id: 2, + name: 'Tag2', + }, + ], }, }, - }, - }) - ).toBeRejectedNotFound(); + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + ], + }); - // update - found - await expect( - client.user.update({ - where: { id: 1 }, + await client.tag.create({ data: { - tags: { - update: { - where: { id: 2 }, - data: { name: 'Tag2-updated' }, + id: 3, + name: 'Tag3', + }, + }); + + // connect + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { connect: { id: 3 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + // connecting a connected entity is no-op + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { connect: { id: 3 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + + // disconnect + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { disconnect: { id: 3 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 1 }), + expect.objectContaining({ id: 2 }), + ], + }); + + await expect( + client.$qbRaw + .selectFrom( + relationName ? `_${relationName}` : '_TagToUser' + ) + .selectAll() + .where('B', '=', 1) // user id + .where('A', '=', 3) // tag id + .execute() + ).resolves.toHaveLength(0); + + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { set: [{ id: 2 }, { id: 3 }] } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 2 }), + expect.objectContaining({ id: 3 }), + ], + }); + + // update - not found + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + update: { + where: { id: 1 }, + data: { name: 'Tag1-updated' }, + }, }, }, - }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 2, name: 'Tag2-updated' }), - expect.objectContaining({ id: 3, name: 'Tag3' }), - ], - }); - - // updateMany - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - updateMany: { - where: { id: { not: 2 } }, - data: { name: 'Tag3-updated' }, + }) + ).toBeRejectedNotFound(); + + // update - found + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + update: { + where: { id: 2 }, + data: { name: 'Tag2-updated' }, + }, }, }, - }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 2, name: 'Tag2-updated' }), - expect.objectContaining({ id: 3, name: 'Tag3-updated' }), - ], - }); - - await expect( - client.tag.findUnique({ where: { id: 1 } }) - ).resolves.toMatchObject({ - name: 'Tag1', - }); - - // upsert - update - await expect( - client.user.update({ - where: { id: 1 }, - data: { - tags: { - upsert: { - where: { id: 3 }, - create: { id: 3, name: 'Tag4' }, - update: { name: 'Tag3-updated-1' }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: expect.arrayContaining([ + expect.objectContaining({ + id: 2, + name: 'Tag2-updated', + }), + expect.objectContaining({ id: 3, name: 'Tag3' }), + ]), + }); + + // updateMany + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + updateMany: { + where: { id: { not: 2 } }, + data: { name: 'Tag3-updated' }, + }, }, }, - }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 2, name: 'Tag2-updated' }), - expect.objectContaining({ id: 3, name: 'Tag3-updated-1' }), - ], + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ + id: 2, + name: 'Tag2-updated', + }), + expect.objectContaining({ + id: 3, + name: 'Tag3-updated', + }), + ], + }); + + await expect( + client.tag.findUnique({ where: { id: 1 } }) + ).resolves.toMatchObject({ + name: 'Tag1', + }); + + // upsert - update + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + upsert: { + where: { id: 3 }, + create: { id: 3, name: 'Tag4' }, + update: { name: 'Tag3-updated-1' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ + id: 2, + name: 'Tag2-updated', + }), + expect.objectContaining({ + id: 3, + name: 'Tag3-updated-1', + }), + ], + }); + + // upsert - create + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { + upsert: { + where: { id: 4 }, + create: { id: 4, name: 'Tag4' }, + update: { name: 'Tag4' }, + }, + }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: expect.arrayContaining([ + expect.objectContaining({ id: 4, name: 'Tag4' }), + ]), + }); + + // delete - not found + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { delete: { id: 1 } } }, + }) + ).toBeRejectedNotFound(); + + // delete - found + await expect( + client.user.update({ + where: { id: 1 }, + data: { tags: { delete: { id: 2 } } }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [ + expect.objectContaining({ id: 3 }), + expect.objectContaining({ id: 4 }), + ], + }); + await expect( + client.tag.findUnique({ where: { id: 2 } }) + ).toResolveNull(); + + // deleteMany + await expect( + client.user.update({ + where: { id: 1 }, + data: { + tags: { deleteMany: { id: { in: [1, 2, 3] } } }, + }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 4 })], + }); + await expect( + client.tag.findUnique({ where: { id: 3 } }) + ).toResolveNull(); + await expect( + client.tag.findUnique({ where: { id: 1 } }) + ).toResolveTruthy(); }); - // upsert - create - await expect( - client.user.update({ - where: { id: 1 }, + it('works with delete', async () => { + await client.user.create({ data: { + id: 1, + name: 'User1', tags: { - upsert: { - where: { id: 4 }, - create: { id: 4, name: 'Tag4' }, - update: { name: 'Tag4' }, - }, + create: [ + { id: 1, name: 'Tag1' }, + { id: 2, name: 'Tag2' }, + ], }, }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: expect.arrayContaining([ - expect.objectContaining({ id: 4, name: 'Tag4' }), - ]), - }); - - // delete - not found - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { delete: { id: 1 } } }, - }) - ).toBeRejectedNotFound(); - - // delete - found - await expect( - client.user.update({ - where: { id: 1 }, - data: { tags: { delete: { id: 2 } } }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [ - expect.objectContaining({ id: 3 }), - expect.objectContaining({ id: 4 }), - ], - }); - await expect( - client.tag.findUnique({ where: { id: 2 } }) - ).toResolveNull(); + }); - // deleteMany - await expect( - client.user.update({ + // cascade from tag + await client.tag.delete({ where: { id: 1 }, - data: { tags: { deleteMany: { id: { in: [1, 2, 3] } } } }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 4 })], - }); - await expect( - client.tag.findUnique({ where: { id: 3 } }) - ).toResolveNull(); - await expect( - client.tag.findUnique({ where: { id: 1 } }) - ).toResolveTruthy(); - }); - - it('works with delete', async () => { - await client.user.create({ - data: { - id: 1, - name: 'User1', - tags: { - create: [ - { id: 1, name: 'Tag1' }, - { id: 2, name: 'Tag2' }, - ], - }, - }, - }); - - // cascade from tag - await client.tag.delete({ - where: { id: 1 }, - }); - await expect( - client.user.findUnique({ + }); + await expect( + client.user.findUnique({ + where: { id: 1 }, + include: { tags: true }, + }) + ).resolves.toMatchObject({ + tags: [expect.objectContaining({ id: 2 })], + }); + + // cascade from user + await client.user.delete({ where: { id: 1 }, - include: { tags: true }, - }) - ).resolves.toMatchObject({ - tags: [expect.objectContaining({ id: 2 })], - }); - - // cascade from user - await client.user.delete({ - where: { id: 1 }, + }); + await expect( + client.tag.findUnique({ + where: { id: 2 }, + include: { users: true }, + }) + ).resolves.toMatchObject({ + users: [], + }); }); - await expect( - client.tag.findUnique({ - where: { id: 2 }, - include: { users: true }, - }) - ).resolves.toMatchObject({ - users: [], - }); - }); - }); + } + ); }); diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index 4e87db46..f3508ab1 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -107,27 +107,29 @@ export async function createTestClient( prismaSchemaText ); execSync( - 'npx prisma db push --schema ./schema.prisma --skip-generate', + 'npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { cwd: workDir!, stdio: 'inherit', } ); + } else { + if (options?.provider === 'postgresql') { + invariant(options?.dbName, 'dbName is required'); + const pgClient = new PGClient(TEST_PG_CONFIG); + await pgClient.connect(); + await pgClient.query( + `DROP DATABASE IF EXISTS "${options!.dbName}"` + ); + await pgClient.query(`CREATE DATABASE "${options!.dbName}"`); + await pgClient.end(); + } } const { plugins, usePrismaPush, ...rest } = options ?? {}; let client = new ZenStackClient(_schema, rest as ClientOptions); - if (options?.provider === 'postgresql') { - invariant(options?.dbName, 'dbName is required'); - const pgClient = new PGClient(TEST_PG_CONFIG); - await pgClient.connect(); - await pgClient.query(`DROP DATABASE IF EXISTS "${options!.dbName}"`); - await pgClient.query(`CREATE DATABASE "${options!.dbName}"`); - await pgClient.end(); - } - if (!usePrismaPush) { await client.$pushSchema(); }