diff --git a/packages/cli/package.json b/packages/cli/package.json index 6d2a6f9d..328d569c 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -44,12 +44,12 @@ }, "devDependencies": { "@types/better-sqlite3": "^7.6.13", - "@types/tmp": "^0.2.6", + "@types/tmp": "catalog:", "@zenstackhq/eslint-config": "workspace:*", "@zenstackhq/runtime": "workspace:*", "@zenstackhq/testtools": "workspace:*", "@zenstackhq/typescript-config": "workspace:*", "better-sqlite3": "^11.8.1", - "tmp": "^0.2.3" + "tmp": "catalog:" } } diff --git a/packages/cli/src/actions/action-utils.ts b/packages/cli/src/actions/action-utils.ts index b11b671b..a6e4ec2d 100644 --- a/packages/cli/src/actions/action-utils.ts +++ b/packages/cli/src/actions/action-utils.ts @@ -1,4 +1,5 @@ import { loadDocument } from '@zenstackhq/language'; +import { isDataSource } from '@zenstackhq/language/ast'; import { PrismaSchemaGenerator } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'node:fs'; @@ -41,6 +42,9 @@ export async function loadSchemaDocument(schemaFile: string) { }); throw new CliError('Failed to load schema'); } + loadResult.warnings.forEach((warn) => { + console.warn(colors.yellow(warn)); + }); return loadResult.model; } @@ -54,6 +58,9 @@ export function handleSubProcessError(err: unknown) { export async function generateTempPrismaSchema(zmodelPath: string, folder?: string) { const model = await loadSchemaDocument(zmodelPath); + if (!model.declarations.some(isDataSource)) { + throw new CliError('Schema must define a datasource'); + } const prismaSchema = await new PrismaSchemaGenerator(model).generate(); if (!folder) { folder = path.dirname(zmodelPath); diff --git a/packages/cli/test/ts-schema-gen.test.ts b/packages/cli/test/ts-schema-gen.test.ts index cd34de58..18c1e7d9 100644 --- a/packages/cli/test/ts-schema-gen.test.ts +++ b/packages/cli/test/ts-schema-gen.test.ts @@ -1,5 +1,7 @@ import { ExpressionUtils } from '@zenstackhq/runtime/schema'; -import { generateTsSchema } from '@zenstackhq/testtools'; +import { createTestProject, generateTsSchema, generateTsSchemaInPlace } from '@zenstackhq/testtools'; +import fs from 'node:fs'; +import path from 'node:path'; import { describe, expect, it } from 'vitest'; describe('TypeScript schema generation tests', () => { @@ -325,4 +327,37 @@ model User extends Base { }, }); }); + + it('merges all declarations from imported modules', async () => { + const workDir = createTestProject(); + fs.writeFileSync( + path.join(workDir, 'a.zmodel'), + ` + enum Role { + Admin + User + } + `, + ); + fs.writeFileSync( + path.join(workDir, 'b.zmodel'), + ` + import './a' + + datasource db { + provider = 'sqlite' + url = 'file:./test.db' + } + + model User { + id Int @id + role Role + } + `, + ); + + const { schema } = await generateTsSchemaInPlace(path.join(workDir, 'b.zmodel')); + expect(schema.enums).toMatchObject({ Role: expect.any(Object) }); + expect(schema.models).toMatchObject({ User: expect.any(Object) }); + }); }); diff --git a/packages/language/package.json b/packages/language/package.json index 90606145..1505e239 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -62,7 +62,9 @@ "@zenstackhq/eslint-config": "workspace:*", "@zenstackhq/typescript-config": "workspace:*", "@zenstackhq/common-helpers": "workspace:*", - "langium-cli": "catalog:" + "langium-cli": "catalog:", + "tmp": "catalog:", + "@types/tmp": "catalog:" }, "volta": { "node": "18.19.1", diff --git a/packages/language/src/index.ts b/packages/language/src/index.ts index 3c9b23f5..fdd3b544 100644 --- a/packages/language/src/index.ts +++ b/packages/language/src/index.ts @@ -1,11 +1,12 @@ -import { URI } from 'langium'; +import { isAstNode, URI, type LangiumDocument, type LangiumDocuments, type Mutable } from 'langium'; import { NodeFileSystem } from 'langium/node'; import fs from 'node:fs'; import path from 'node:path'; import { fileURLToPath } from 'node:url'; -import type { Model } from './ast'; +import { isDataSource, type AstNode, type Model } from './ast'; import { STD_LIB_MODULE_NAME } from './constants'; import { createZModelLanguageServices } from './module'; +import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils'; export function createZModelServices() { return createZModelLanguageServices(NodeFileSystem); @@ -60,8 +61,15 @@ export async function loadDocument( const langiumDocuments = services.shared.workspace.LangiumDocuments; const document = await langiumDocuments.getOrCreateDocument(URI.file(path.resolve(fileName))); + // load imports + const importedURIs = await loadImports(document, langiumDocuments); + const importedDocuments: LangiumDocument[] = []; + for (const uri of importedURIs) { + importedDocuments.push(await langiumDocuments.getOrCreateDocument(uri)); + } + // build the document together with standard library, plugin modules, and imported documents - await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document], { + await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], { validation: true, }); @@ -95,6 +103,27 @@ export async function loadDocument( }; } + const model = document.parseResult.value as Model; + + // merge all declarations into the main document + const imported = mergeImportsDeclarations(langiumDocuments, model); + + // remove imported documents + imported.forEach((model) => { + langiumDocuments.deleteDocument(model.$document!.uri); + services.shared.workspace.IndexManager.remove(model.$document!.uri); + }); + + // extra validation after merging imported declarations + const additionalErrors = validationAfterImportMerge(model); + if (additionalErrors.length > 0) { + return { + success: false, + errors: additionalErrors, + warnings, + }; + } + return { success: true, model: document.parseResult.value as Model, @@ -102,4 +131,72 @@ export async function loadDocument( }; } +async function loadImports(document: LangiumDocument, documents: LangiumDocuments, uris: Set = new Set()) { + const uriString = document.uri.toString(); + if (!uris.has(uriString)) { + uris.add(uriString); + const model = document.parseResult.value as Model; + for (const imp of model.imports) { + const importedModel = resolveImport(documents, imp); + if (importedModel) { + const importedDoc = getDocument(importedModel); + await loadImports(importedDoc, documents, uris); + } + } + } + return Array.from(uris) + .filter((x) => uriString != x) + .map((e) => URI.parse(e)); +} + +function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) { + const importedModels = resolveTransitiveImports(documents, model); + + const importedDeclarations = importedModels.flatMap((m) => m.declarations); + model.declarations.push(...importedDeclarations); + + // remove import directives + model.imports = []; + + // fix $container, $containerIndex, and $containerProperty + linkContentToContainer(model); + + return importedModels; +} + +function linkContentToContainer(node: AstNode): void { + for (const [name, value] of Object.entries(node)) { + if (!name.startsWith('$')) { + if (Array.isArray(value)) { + value.forEach((item, index) => { + if (isAstNode(item)) { + (item as Mutable).$container = node; + (item as Mutable).$containerProperty = name; + (item as Mutable).$containerIndex = index; + } + }); + } else if (isAstNode(value)) { + (value as Mutable).$container = node; + (value as Mutable).$containerProperty = name; + } + } + } +} + +function validationAfterImportMerge(model: Model) { + const errors: string[] = []; + const dataSources = model.declarations.filter((d) => isDataSource(d)); + if (dataSources.length > 1) { + errors.push('Validation error: Multiple datasource declarations are not allowed'); + } + + // at most one `@@auth` model + const decls = getDataModelAndTypeDefs(model, true); + const authDecls = decls.filter((d) => hasAttribute(d, '@@auth')); + if (authDecls.length > 1) { + errors.push('Validation error: Multiple `@@auth` declarations are not allowed'); + } + return errors; +} + export * from './module'; diff --git a/packages/language/src/utils.ts b/packages/language/src/utils.ts index adb4f78f..06677192 100644 --- a/packages/language/src/utils.ts +++ b/packages/language/src/utils.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { AstUtils, URI, type AstNode, type LangiumDocuments, type Reference } from 'langium'; +import { AstUtils, URI, type AstNode, type LangiumDocument, type LangiumDocuments, type Reference } from 'langium'; import fs from 'node:fs'; import path from 'path'; import { STD_LIB_MODULE_NAME, type ExpressionContext } from './constants'; @@ -413,38 +413,13 @@ export function resolveImport(documents: LangiumDocuments, imp: ModelImport) { } export function resolveImportUri(imp: ModelImport) { - if (!imp.path) return undefined; // This will return true if imp.path is undefined, null, or an empty string (""). - - if (!imp.path.endsWith('.zmodel')) { - imp.path += '.zmodel'; - } - - if ( - !imp.path.startsWith('.') && // Respect relative paths - !path.isAbsolute(imp.path) // Respect Absolute paths - ) { - // use the current model's path as the search context - const contextPath = imp.$container.$document - ? path.dirname(imp.$container.$document.uri.fsPath) - : process.cwd(); - imp.path = findNodeModulesFile(imp.path, contextPath) ?? imp.path; + if (!imp.path) { + return undefined; } - const doc = AstUtils.getDocument(imp); const dir = path.dirname(doc.uri.fsPath); - return URI.file(path.resolve(dir, imp.path)); -} - -export function findNodeModulesFile(name: string, cwd: string = process.cwd()) { - if (!name) return undefined; - try { - // Use require.resolve to find the module/file. The paths option allows specifying the directory to start from. - const resolvedPath = require.resolve(name, { paths: [cwd] }); - return resolvedPath; - } catch { - // If require.resolve fails to find the module/file, it will throw an error. - return undefined; - } + const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`; + return URI.file(path.resolve(dir, importPath)); } /** @@ -577,3 +552,28 @@ export function getAllAttributes( attributes.push(...decl.attributes); return attributes; } + +/** + * Retrieve the document in which the given AST node is contained. A reference to the document is + * usually held by the root node of the AST. + * + * @throws an error if the node is not contained in a document. + */ +export function getDocument(node: AstNode): LangiumDocument { + const rootNode = findRootNode(node); + const result = rootNode.$document; + if (!result) { + throw new Error('AST node has no document.'); + } + return result as LangiumDocument; +} + +/** + * Returns the root node of the given AST node by following the `$container` references. + */ +export function findRootNode(node: AstNode): AstNode { + while (node.$container) { + node = node.$container; + } + return node; +} diff --git a/packages/language/src/validators/schema-validator.ts b/packages/language/src/validators/schema-validator.ts index 5d856380..69d5a801 100644 --- a/packages/language/src/validators/schema-validator.ts +++ b/packages/language/src/validators/schema-validator.ts @@ -47,8 +47,8 @@ export default class SchemaValidator implements AstValidator { private validateImports(model: Model, accept: ValidationAcceptor) { model.imports.forEach((imp) => { const importedModel = resolveImport(this.documents, imp); - const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`; if (!importedModel) { + const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`; accept('error', `Cannot find model file ${importPath}`, { node: imp, }); diff --git a/packages/language/test/import.test.ts b/packages/language/test/import.test.ts new file mode 100644 index 00000000..48cec382 --- /dev/null +++ b/packages/language/test/import.test.ts @@ -0,0 +1,101 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import fs from 'node:fs'; +import path from 'node:path'; +import tmp from 'tmp'; +import { describe, expect, it } from 'vitest'; +import { loadDocument } from '../src'; +import { DataModel, isDataModel } from '../src/ast'; + +describe('Import tests', () => { + it('merges declarations', async () => { + const { name } = tmp.dirSync(); + fs.writeFileSync( + path.join(name, 'a.zmodel'), + ` +model A { + id Int @id + name String +} + `, + ); + fs.writeFileSync( + path.join(name, 'b.zmodel'), + ` +import './a' +model B { + id Int @id +} + `, + ); + + const model = await expectLoaded(path.join(name, 'b.zmodel')); + expect(model.declarations.filter(isDataModel)).toHaveLength(2); + expect(model.imports).toHaveLength(0); + }); + + it('resolves imported symbols', async () => { + const { name } = tmp.dirSync(); + fs.writeFileSync( + path.join(name, 'a.zmodel'), + ` +enum Role { + Admin + User +} + `, + ); + fs.writeFileSync( + path.join(name, 'b.zmodel'), + ` +import './a' +model User { + id Int @id + role Role +} +`, + ); + + const model = await expectLoaded(path.join(name, 'b.zmodel')); + expect((model.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('Role'); + }); + + it('supports cyclic imports', async () => { + const { name } = tmp.dirSync(); + fs.writeFileSync( + path.join(name, 'a.zmodel'), + ` +import './b' +model A { + id Int @id + b B? +} + `, + ); + fs.writeFileSync( + path.join(name, 'b.zmodel'), + ` +import './a' +model B { + id Int @id + a A @relation(fields: [aId], references: [id]) + aId Int @unique +} +`, + ); + + const modelB = await expectLoaded(path.join(name, 'b.zmodel')); + expect((modelB.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('A'); + const modelA = await expectLoaded(path.join(name, 'a.zmodel')); + expect((modelA.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('B'); + }); + + async function expectLoaded(file: string) { + const result = await loadDocument(file); + if (!result.success) { + console.error('Errors:', result.errors); + throw new Error(`Failed to load document from ${file}`); + } + invariant(result.success); + return result.model; + } +}); diff --git a/packages/runtime/test/client-api/import.test.ts b/packages/runtime/test/client-api/import.test.ts new file mode 100644 index 00000000..98e43f77 --- /dev/null +++ b/packages/runtime/test/client-api/import.test.ts @@ -0,0 +1,71 @@ +import { createTestProject, generateTsSchemaInPlace } from '@zenstackhq/testtools'; +import fs from 'node:fs'; +import path from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { createTestClient } from '../utils'; + +describe('Import tests', () => { + it('works with imported models', async () => { + const workDir = createTestProject(); + + fs.writeFileSync( + path.join(workDir, 'user.zmodel'), + ` + import './post' + model User { + id Int @id @default(autoincrement()) + email String + posts Post[] + } + `, + ); + fs.writeFileSync( + path.join(workDir, 'post.zmodel'), + ` + import './user' + + model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId Int + } + `, + ); + fs.writeFileSync( + path.join(workDir, 'main.zmodel'), + ` + import './user' + import './post' + + datasource db { + provider = "sqlite" + url = "file:./dev.db" + } + `, + ); + + const { schema } = await generateTsSchemaInPlace(path.join(workDir, 'main.zmodel')); + const client: any = await createTestClient(schema); + + await expect( + client.user.create({ + data: { + id: 1, + email: 'u1@test.com', + posts: { + create: { title: 'Post1' }, + }, + }, + include: { posts: true }, + }), + ).resolves.toMatchObject({ + email: 'u1@test.com', + posts: [ + expect.objectContaining({ + title: 'Post1', + }), + ], + }); + }); +}); diff --git a/packages/runtime/test/utils.ts b/packages/runtime/test/utils.ts index 449e82a5..515dcaa5 100644 --- a/packages/runtime/test/utils.ts +++ b/packages/runtime/test/utils.ts @@ -62,6 +62,7 @@ export type CreateTestClientOptions = Omit; + workDir?: string; }; export async function createTestClient( @@ -78,7 +79,7 @@ export async function createTestClient( options?: CreateTestClientOptions, schemaFile?: string, ): Promise { - let workDir: string | undefined; + let workDir = options?.workDir; let _schema: Schema; const provider = options?.provider ?? 'sqlite'; @@ -114,7 +115,7 @@ export async function createTestClient( type: provider, }, }; - workDir = await createTestProject(); + workDir ??= createTestProject(); if (schemaFile) { let schemaContent = fs.readFileSync(schemaFile, 'utf-8'); if (dbUrl) { @@ -131,7 +132,9 @@ export async function createTestClient( } } - console.log(`Work directory: ${workDir}`); + if (workDir) { + console.log(`Work directory: ${workDir}`); + } const { plugins, ...rest } = options ?? {}; const _options: ClientOptions = { @@ -148,7 +151,7 @@ export async function createTestClient( const prismaSchemaText = await prismaSchema.generate(); fs.writeFileSync(path.resolve(workDir, 'schema.prisma'), prismaSchemaText); execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { - cwd: workDir!, + cwd: workDir, stdio: 'inherit', }); } else { diff --git a/packages/testtools/package.json b/packages/testtools/package.json index f6d4b247..b84edbc9 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -31,7 +31,7 @@ "@zenstackhq/language": "workspace:*", "@zenstackhq/sdk": "workspace:*", "glob": "^11.0.2", - "tmp": "^0.2.3", + "tmp": "catalog:", "ts-pattern": "catalog:", "prisma": "catalog:", "typescript": "catalog:" @@ -41,7 +41,7 @@ "pg": "^8.13.1" }, "devDependencies": { - "@types/tmp": "^0.2.6", + "@types/tmp": "catalog:", "@zenstackhq/eslint-config": "workspace:*", "@zenstackhq/typescript-config": "workspace:*" } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 0acb0b87..48c43b97 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -54,6 +54,10 @@ export async function generateTsSchema( } // compile the generated TS schema + return compileAndLoad(workDir); +} + +async function compileAndLoad(workDir: string) { execSync('npx tsc', { cwd: workDir, stdio: 'inherit', @@ -68,3 +72,12 @@ export function generateTsSchemaFromFile(filePath: string) { const schemaText = fs.readFileSync(filePath, 'utf8'); return generateTsSchema(schemaText); } + +export async function generateTsSchemaInPlace(schemaPath: string) { + const workDir = path.dirname(schemaPath); + const pluginModelFiles = glob.sync(path.resolve(__dirname, '../../runtime/src/plugins/**/plugin.zmodel')); + + const generator = new TsSchemaGenerator(); + await generator.generate(schemaPath, pluginModelFiles, workDir); + return compileAndLoad(workDir); +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 835af938..6fa11704 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -6,6 +6,9 @@ settings: catalogs: default: + '@types/tmp': + specifier: ^0.2.6 + version: 0.2.6 kysely: specifier: ^0.27.6 version: 0.27.6 @@ -18,6 +21,9 @@ catalogs: prisma: specifier: ^6.0.0 version: 6.9.0 + tmp: + specifier: ^0.2.3 + version: 0.2.3 ts-pattern: specifier: ^5.7.1 version: 5.7.1 @@ -106,7 +112,7 @@ importers: specifier: ^7.6.13 version: 7.6.13 '@types/tmp': - specifier: ^0.2.6 + specifier: 'catalog:' version: 0.2.6 '@zenstackhq/eslint-config': specifier: workspace:* @@ -124,7 +130,7 @@ importers: specifier: ^11.8.1 version: 11.8.1 tmp: - specifier: ^0.2.3 + specifier: 'catalog:' version: 0.2.3 packages/common-helpers: @@ -197,6 +203,9 @@ importers: '@types/pluralize': specifier: ^0.0.33 version: 0.0.33 + '@types/tmp': + specifier: 'catalog:' + version: 0.2.6 '@zenstackhq/common-helpers': specifier: workspace:* version: link:../common-helpers @@ -209,6 +218,9 @@ importers: langium-cli: specifier: 'catalog:' version: 3.5.0 + tmp: + specifier: 'catalog:' + version: 0.2.3 packages/runtime: dependencies: @@ -342,7 +354,7 @@ importers: specifier: 'catalog:' version: 6.9.0(typescript@5.8.3) tmp: - specifier: ^0.2.3 + specifier: 'catalog:' version: 0.2.3 ts-pattern: specifier: 'catalog:' @@ -352,7 +364,7 @@ importers: version: 5.8.3 devDependencies: '@types/tmp': - specifier: ^0.2.6 + specifier: 'catalog:' version: 0.2.6 '@zenstackhq/eslint-config': specifier: workspace:* diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index fcdfd392..0c30abb5 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -11,3 +11,5 @@ catalog: langium-cli: 3.5.0 ts-pattern: ^5.7.1 typescript: ^5.0.0 + tmp: ^0.2.3 + '@types/tmp': ^0.2.6