diff --git a/packages/stl-api-gen/package.json b/packages/stl-api-gen/package.json index 0f7d6344..1f37a577 100644 --- a/packages/stl-api-gen/package.json +++ b/packages/stl-api-gen/package.json @@ -23,6 +23,8 @@ "dependencies": { "@types/lodash": "^4.14.195", "chalk": "4.1.2", + "chokidar": "^3.5.3", + "commander": "^11.0.0", "lodash": "^4.17.21", "pkg-up": "3.1", "ts-morph": "^19.0.0", diff --git a/packages/stl-api-gen/src/endpointPreprocess.ts b/packages/stl-api-gen/src/endpointPreprocess.ts new file mode 100644 index 00000000..94de7847 --- /dev/null +++ b/packages/stl-api-gen/src/endpointPreprocess.ts @@ -0,0 +1,96 @@ +import * as tm from "ts-morph"; + +import { isSymbolStlMethod } from "./utils"; +import { getPropertyDeclaration } from "ts-to-zod/dist/convertType"; + +export type NodeType = [tm.Node, tm.Type]; + +export interface EndpointTypeInstance { + endpointPath: string; + callExpression: tm.Node, + query?: NodeType, + path?: NodeType, + body?: NodeType, + response?: NodeType, +} + +// call expression is the expression (...endpoint()) +export function handleEndpoint(callExpression: tm.CallExpression): EndpointTypeInstance | undefined { + // lhs is the value on which .endpoint is being called + const lhs = callExpression.getExpression(); + const endpointArgs = callExpression.getArguments(); + if (endpointArgs.length !== 1) return; + const [endpointArg] = endpointArgs; + const endpointType = endpointArg.getType(); + + const endpointProperty = endpointType.getProperty("endpoint"); + if (!endpointProperty) return; + const endpointPropertyType = + endpointProperty.getTypeAtLocation(callExpression); + const endpointPath = endpointPropertyType.getLiteralValue(); + if (typeof endpointPath !== "string") return; + + if (lhs instanceof tm.PropertyAccessExpression) { + const typesExpr = lhs.getExpression(); + // we want .endpoint to be called on stl.types<{...}>() + if (typesExpr instanceof tm.CallExpression) { + const typesReceiver = typesExpr.getExpression(); + const symbol = typesReceiver.getSymbol(); + if (!symbol || !isSymbolStlMethod(symbol)) return; + + if (symbol.getEscapedName() !== "types") return; + + // types() cannot be called with arguments + if (typesExpr.getArguments().length) return; + + const typeArguments = typesExpr.getTypeArguments(); + if (typeArguments.length !== 1) return; + const [typeRef] = typeArguments; + const schemaTypes = typeRef.getType(); + + if (!schemaTypes.isObject()) return; + + let queryNodeType; + let pathNodeType; + let bodyNodeType; + let responseNodeType; + + for (const property of schemaTypes.getProperties()) { + const name = property.getName(); + switch (name) { + case "query": + queryNodeType = propertyToNodeType(property, typesExpr); + break; + case "path": + pathNodeType = propertyToNodeType(property, typesExpr); + break; + case "body": + bodyNodeType = propertyToNodeType(property, typesExpr); + break; + case "response": + responseNodeType = propertyToNodeType(property, typesExpr); + break; + default: + // TODO: add diagnostic for ignored field + continue; + } + } + + return { + endpointPath, + callExpression, + query: queryNodeType, + path: pathNodeType, + body: bodyNodeType, + response: responseNodeType, + } + } + } +} + +function propertyToNodeType(property: tm.Symbol, location: tm.Node): NodeType { + const node = getPropertyDeclaration(property); + if (!node) throw new Error("internal error: invalid property encountered"); + return [node, property.getTypeAtLocation(location)]; + +} diff --git a/packages/stl-api-gen/src/index.ts b/packages/stl-api-gen/src/index.ts index 33d99810..94960243 100644 --- a/packages/stl-api-gen/src/index.ts +++ b/packages/stl-api-gen/src/index.ts @@ -21,15 +21,23 @@ import { import { generateFiles, generateImportStatements, + generatePath, } from "ts-to-zod/dist/generateFiles"; import { GenOptions, createGenerationConfig, } from "ts-to-zod/dist/filePathConfig"; + import { Watcher } from "./watch"; +import { + EndpointTypeInstance, + NodeType, + handleEndpoint, +} from "./endpointPreprocess"; import { program as argParser } from "commander"; +import { convertPathToImport, isSymbolStlMethod, mangleString } from "./utils"; // TODO: add dry run functionality? argParser.option("-w, --watch", "enables watch mode"); @@ -39,9 +47,25 @@ argParser.option( "." ); -const NODE_MODULES_GEN_PATH = "stl-api/gen/"; +const NODE_MODULES_GEN_PATH = "@stl-api/gen"; + +const Z_IMPORT_STATEMENT = factory.createImportDeclaration( + [], + factory.createImportClause( + false, + undefined, + factory.createNamedImports([ + factory.createImportSpecifier( + false, + undefined, + factory.createIdentifier("z") + ), + ]) + ), + factory.createStringLiteral("stainless") +); -interface MagicCallDiagnostics { +interface CallDiagnostics { line: number; column: number; filePath: string; @@ -58,7 +82,7 @@ async function main() { if (!packageJsonPath) { console.error( `Folder ${Path.relative( - "./", + ".", options.directory )} and its parent directories do not contain a package.json.` ); @@ -99,7 +123,7 @@ async function main() { if (watcher) { for await (const { path } of watcher.getEvents()) { console.clear(); - const relativePath = Path.relative("./", path); + const relativePath = Path.relative(".", path); console.log(`Found change in ${relativePath}.`); const succeeded = await evaluate( watcher.project, @@ -145,7 +169,10 @@ async function evaluate( ): Promise { const generationConfig = createGenerationConfig(generationOptions); - const callDiagnostics: MagicCallDiagnostics[] = []; + const callDiagnostics: CallDiagnostics[] = []; + const endpointCalls: Map = new Map(); + + // all of the stl.types calls found for (const file of project.getSourceFiles()) { const ctx = new ConvertTypeContext(baseCtx, file); @@ -161,18 +188,17 @@ async function evaluate( )) { const receiverExpression = callExpression.getExpression(); const symbol = receiverExpression.getSymbol(); - if (!symbol) continue; + if (!symbol || !isSymbolStlMethod(symbol)) continue; - const symbolDeclaration = symbol.getDeclarations()[0]; - if (!symbolDeclaration) continue; - const symbolDeclarationFile = symbolDeclaration - .getSourceFile() - .getFilePath(); + const methodName = symbol.getEscapedName(); - if ( - symbol.getEscapedName() !== "magic" || - symbolDeclarationFile.indexOf("stl.d.ts") < 0 - ) { + if (!(methodName === "magic" || methodName === "endpoint")) continue; + + if (methodName == "endpoint") { + const call = handleEndpoint(callExpression); + if (!call) continue; + const fileCalls = getOrInsert(endpointCalls, file, () => []); + fileCalls.push(call); continue; } @@ -204,17 +230,9 @@ async function evaluate( addDiagnostics(ctx, file, callExpression, callDiagnostics); } const name = symbol.getName(); - let as; const declaration = symbol.getDeclarations()[0]; // TODO factor out this logic in ts-to-zod and export a function - if (type.isEnum()) { - as = `__enum_${name}`; - } else if (type.isClass()) { - as = `__class_${name}`; - } else { - as = `__symbol_${name}`; - } - + const as = mangleTypeName(type, name); const declarationFilePath = declaration.getSourceFile().getFilePath(); if (declarationFilePath === file.getFilePath()) { @@ -297,7 +315,7 @@ async function evaluate( const fileImportDeclarations = file.getImportDeclarations(); for (const importDecl of fileImportDeclarations) { const sourcePath = importDecl.getModuleSpecifier().getLiteralValue(); - if (sourcePath.indexOf("stl-api/gen") == 0) { + if (sourcePath.indexOf(NODE_MODULES_GEN_PATH) == 0) { fileOperations.push(() => importDecl.remove()); } else if (sourcePath === "stainless") { for (const specifier of importDecl.getNamedImports()) { @@ -331,6 +349,155 @@ async function evaluate( fileOperations.forEach((op) => op()); } + const generatedFileContents = generateFiles(baseCtx, generationOptions); + + if (endpointCalls.size) { + const mapEntries = []; + + outer: for (const [file, calls] of endpointCalls) { + const ctx = new ConvertTypeContext(baseCtx, file); + for (const call of calls) { + const mangledName = mangleString(call.endpointPath); + const importExpression = factory.createCallExpression( + factory.createToken(ts.SyntaxKind.ImportKeyword) as ts.Expression, + undefined, + [ + factory.createStringLiteral( + convertPathToImport( + generatePath(file.getFilePath(), generationConfig) + ) + ), + ] + ); + const callExpression = factory.createCallExpression( + factory.createPropertyAccessExpression(importExpression, "then"), + undefined, + [ + factory.createArrowFunction( + undefined, + undefined, + [factory.createParameterDeclaration(undefined, undefined, "mod")], + undefined, + undefined, + factory.createPropertyAccessExpression( + factory.createIdentifier("mod"), + mangledName + ) + ), + ] + ); + const entry = factory.createPropertyAssignment( + factory.createStringLiteral(call.endpointPath), + callExpression + ); + mapEntries.push(entry); + const filex = generatePath(file.getFilePath(), generationConfig); + + const generatedStatements = getOrInsert( + generatedFileContents, + generatePath(file.getFilePath(), generationConfig), + () => [Z_IMPORT_STATEMENT] + ); + + const objectProperties = []; + + const requestTypes = [ + ["query", call.query], + ["path", call.path], + ["body", call.body], + ].filter(([_, type]) => type) as [string, NodeType][]; + for (const [name, nodeType] of requestTypes) { + const schemaExpression = convertEndpointType( + ctx, + call.callExpression, + callDiagnostics, + file, + nodeType[0], + nodeType[1] + ); + if (!schemaExpression) break outer; + objectProperties.push( + factory.createPropertyAssignment(name, schemaExpression) + ); + } + + let schemaExpression; + if (call.response) { + schemaExpression = convertEndpointType( + ctx, + call.callExpression, + callDiagnostics, + file, + call.response[0], + call.response[1] + ); + if (!schemaExpression) break outer; + } else { + // if no response type is provided, use the default schema z.void() to indicate no response + schemaExpression = factory.createCallExpression( + factory.createPropertyAccessExpression( + factory.createIdentifier("z"), + "void" + ), + [], + [] + ); + } + objectProperties.push( + factory.createPropertyAssignment("response", schemaExpression) + ); + + const variableDeclaration = factory.createVariableDeclarationList( + [ + factory.createVariableDeclaration( + mangledName, + undefined, + undefined, + factory.createObjectLiteralExpression(objectProperties) + ), + ], + ts.NodeFlags.Const + ); + + generatedStatements.push( + factory.createVariableStatement( + [factory.createToken(ts.SyntaxKind.ExportKeyword)], + variableDeclaration + ) + ); + } + } + + const mapConstant = factory.createVariableDeclarationList( + [ + factory.createVariableDeclaration( + "someName", + undefined, + undefined, + factory.createObjectLiteralExpression(mapEntries) + ), + ], + // ts.NodeFlags.Const + ); + const mapStatement = factory.createVariableStatement( + [factory.createToken(ts.SyntaxKind.ExportKeyword)], + mapConstant + ); + const mapSourceFile = factory.createSourceFile( + [mapStatement], + factory.createToken(ts.SyntaxKind.EndOfFileToken), + 0 + ); + + const genPath = Path.join(rootPath, "node_modules", NODE_MODULES_GEN_PATH); + const endpointMapGenPath = Path.join(genPath, "__endpointMap.js"); + await fs.promises.mkdir(genPath, { recursive: true }); + await fs.promises.writeFile( + endpointMapGenPath, + printer.printFile(mapSourceFile) + ); + } + if (callDiagnostics.length) { const output = []; let errorCount = 0; @@ -410,21 +577,25 @@ async function evaluate( } } - project.save(); - - const generatedFileContents = generateFiles(baseCtx, generationOptions); - - for (const [file, fileContents] of generatedFileContents) { + for (const [file, fileStatments] of generatedFileContents) { + console.log(`writing to file ${file}`); const fileDir = Path.dirname(file); // creates directory where to write file to, if it doesn't already exist await fs.promises.mkdir(fileDir, { recursive: true, }); + const sourceFile = factory.createSourceFile( + fileStatments, + factory.createToken(ts.SyntaxKind.EndOfFileToken), + 0 + ); + // write sourceFile to file - await fs.promises.writeFile(file, printer.printFile(fileContents)); + await fs.promises.writeFile(file, printer.printFile(sourceFile)); } + project.save(); return true; } @@ -451,7 +622,7 @@ function addDiagnostics( ctx: SchemaGenContext, file: tm.SourceFile, callExpression: tm.Node, - callDiagnostics: MagicCallDiagnostics[] + callDiagnostics: CallDiagnostics[] ) { if (ctx.diagnostics.size) { const { line, column } = file.getLineAndColumnAtPos( @@ -465,3 +636,66 @@ function addDiagnostics( }); } } + +function convertEndpointType( + ctx: ConvertTypeContext, + callExpression: tm.Node, + callDiagnostics: CallDiagnostics[], + diagnosticsFile: tm.SourceFile, + typeArgument: tm.Node, + type: tm.Type +): ts.Expression | undefined { + if ( + typeArgument instanceof tm.TypeReferenceNode && + typeArgument.getTypeArguments().length === 0 + ) { + const symbol = typeArgument.getTypeName().getSymbolOrThrow(); + try { + convertSymbol(ctx, symbol, { + variant: "node", + node: typeArgument, + }); + } catch (e) { + if (e instanceof ErrorAbort) return; + else throw e; + } finally { + addDiagnostics(ctx, diagnosticsFile, callExpression, callDiagnostics); + } + const name = symbol.getName(); + const as = mangleTypeName(type, name); + + return factory.createIdentifier(as || name); + } else { + try { + return convertType(ctx, type, { + variant: "node", + node: typeArgument, + }); + } catch (e) { + if (e instanceof ErrorAbort) return; + else throw e; + } finally { + addDiagnostics(ctx, diagnosticsFile, callExpression, callDiagnostics); + } + } +} + +// TODO: factor out to ts-to-zod? +function mangleTypeName(type: tm.Type, name: string): string { + if (type.isEnum()) { + return `__enum_${name}`; + } else if (type.isClass()) { + return `__class_${name}`; + } else { + return `__symbol_${name}`; + } +} + +function getOrInsert(map: Map, key: K, create: () => V): V { + let value = map.get(key); + if (!value) { + value = create(); + map.set(key, value); + } + return value; +} diff --git a/packages/stl-api-gen/src/utils.ts b/packages/stl-api-gen/src/utils.ts new file mode 100644 index 00000000..1315fe09 --- /dev/null +++ b/packages/stl-api-gen/src/utils.ts @@ -0,0 +1,40 @@ +import * as tm from "ts-morph"; +import Path from "path"; + +export function isSymbolStlMethod(symbol: tm.Symbol): boolean { + const symbolDeclaration = symbol.getDeclarations()[0]; + if (!symbolDeclaration) return false; + const symbolDeclarationFile = symbolDeclaration.getSourceFile().getFilePath(); + + return symbolDeclarationFile.indexOf("stl.d.ts") >= 0; +} + +export function mangleString(str: string): string { + const unicodeLetterRegex = /\p{L}/u; + const escapedStringBuilder = []; + for (const codePointString of str) { + if (codePointString === "/") { + escapedStringBuilder.push("$"); + } else if (unicodeLetterRegex.test(codePointString)) { + escapedStringBuilder.push(codePointString); + } else { + escapedStringBuilder.push(`u${codePointString.codePointAt(0)}`); + } + } + return escapedStringBuilder.join(""); +} + +function absoluteNodeModulesPath(path: string): string { + const nodeModulesPos = path.lastIndexOf("node_modules"); + if (nodeModulesPos >= 0) { + return path.substring(nodeModulesPos + 13); + } else return path; +} + +export function convertPathToImport(path: string): string { + const withAbsolute = absoluteNodeModulesPath(path); + // strip extension from path + // todo: we probably need to handle modules as well + const { dir, name } = Path.parse(withAbsolute); + return Path.join(dir, name); +} diff --git a/packages/ts-to-zod/src/__tests__/multiFileTestCase.ts b/packages/ts-to-zod/src/__tests__/multiFileTestCase.ts index e0d62efa..cde61653 100644 --- a/packages/ts-to-zod/src/__tests__/multiFileTestCase.ts +++ b/packages/ts-to-zod/src/__tests__/multiFileTestCase.ts @@ -1,4 +1,6 @@ import * as tm from "ts-morph"; +import { ts } from "ts-morph"; +const factory = ts.factory; import { SchemaGenContext, convertSymbol } from "../convertType"; import { testProject } from "./testProject"; import { generateFiles } from "../generateFiles"; @@ -47,9 +49,14 @@ export const multiFileTestCase = async (options: { }, rootPath, }; - for (const [file, sourceFile] of generateFiles(ctx, genOptions)) { + for (const [file, statements] of generateFiles(ctx, genOptions)) { const relativeFile = path.relative(rootPath, file); - result[relativeFile] = tm.ts.createPrinter().printFile(sourceFile); + const sourceFile = factory.createSourceFile( + statements, + factory.createToken(ts.SyntaxKind.EndOfFileToken), + 0 + ); + result[relativeFile] = ts.createPrinter().printFile(sourceFile); } return result; }; diff --git a/packages/ts-to-zod/src/convertType.ts b/packages/ts-to-zod/src/convertType.ts index c0f59007..7aa1b432 100644 --- a/packages/ts-to-zod/src/convertType.ts +++ b/packages/ts-to-zod/src/convertType.ts @@ -1347,7 +1347,8 @@ function getDeclarationOrThrow(symbol: tm.Symbol): tm.Node { } else return declaration; } -function getPropertyDeclaration(symbol: tm.Symbol): tm.Node | undefined { +// TODO: move to utils file +export function getPropertyDeclaration(symbol: tm.Symbol): tm.Node | undefined { for (const declaration of symbol.getDeclarations()) { if ( declaration instanceof tm.PropertyDeclaration || diff --git a/packages/ts-to-zod/src/generateFiles.ts b/packages/ts-to-zod/src/generateFiles.ts index 30dafaca..019c51bf 100644 --- a/packages/ts-to-zod/src/generateFiles.ts +++ b/packages/ts-to-zod/src/generateFiles.ts @@ -12,14 +12,11 @@ import { export function generateFiles( ctx: SchemaGenContext, options: GenOptions -): Map { +): Map { const outputMap = new Map(); const generationConfig = createGenerationConfig(options); for (const [path, info] of ctx.files.entries()) { - const generatedPath = generatePath({ - path, - ...generationConfig, - }); + const generatedPath = generatePath(path, generationConfig); const statements: ts.Statement[] = generateImportStatements( generationConfig, @@ -43,12 +40,7 @@ export function generateFiles( ); statements.push(variableStatement); } - const sourceFile = factory.createSourceFile( - statements, - factory.createToken(ts.SyntaxKind.EndOfFileToken), - 0 - ); - outputMap.set(generatedPath, sourceFile); + outputMap.set(generatedPath, statements); } return outputMap; } @@ -62,24 +54,10 @@ function relativeImportPath( return relativePath; } -function generatePath({ - /** Path of the file for which the schema is being generated */ - path, - /** The root path of the project. Usually the root of an npm package. */ - rootPath, - /** Base path where user file schemas should be generated in */ - basePath, - /** Base path where dependency file schemas should be generated in */ - baseDependenciesPath, - /** The suffix to append to file names, if specified */ - suffix, -}: { - path: string; - rootPath: string; - basePath: string; - baseDependenciesPath: string; - suffix?: string; -}): string { +export function generatePath( + path: string, + { basePath, baseDependenciesPath, rootPath, suffix }: GenerationConfig +): string { // set cwd to the root path for proper processing of relative paths // save old cwd to restore later // either basePath or baseDependenciesPath @@ -110,12 +88,7 @@ export function generateImportStatements( ([symbol, { importFromUserFile, sourceFile }]) => relativeImportPath( filePath, - importFromUserFile - ? sourceFile - : generatePath({ - path: sourceFile, - ...config, - }) + importFromUserFile ? sourceFile : generatePath(sourceFile, config) ) ); const zImportClause = factory.createImportClause(