From 4000758ba1eb063cd648173b68e0f7d766ad751d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 10 Apr 2023 16:41:55 -0700 Subject: [PATCH] fix: wrap generated trpc routes with error handling --- packages/plugins/trpc/src/generator.ts | 130 +++++++++++++++++- packages/plugins/trpc/src/helpers.ts | 38 ++--- .../access-policy/policy-guard-generator.ts | 12 +- packages/schema/src/plugins/plugin-utils.ts | 1 - packages/sdk/src/constants.ts | 5 + packages/server/src/express/middleware.ts | 3 +- packages/server/src/fastify/plugin.ts | 3 +- tests/integration/test-run/package-lock.json | 6 +- 8 files changed, 156 insertions(+), 42 deletions(-) diff --git a/packages/plugins/trpc/src/generator.ts b/packages/plugins/trpc/src/generator.ts index 1ef8873e1..26f535a60 100644 --- a/packages/plugins/trpc/src/generator.ts +++ b/packages/plugins/trpc/src/generator.ts @@ -1,14 +1,20 @@ import { DMMF } from '@prisma/generator-helper'; -import { PluginError, PluginOptions } from '@zenstackhq/sdk'; +import { CrudFailureReason, PluginError, PluginOptions, RUNTIME_PACKAGE } from '@zenstackhq/sdk'; import { Model } from '@zenstackhq/sdk/ast'; +import { camelCase } from 'change-case'; import { promises as fs } from 'fs'; import path from 'path'; -import { generate as PrismaZodGenerator } from './zod/generator'; -import { generateProcedure, generateRouterSchemaImports, getInputTypeByOpName, resolveModelsComments } from './helpers'; +import { Project } from 'ts-morph'; +import { + generateHelperImport, + generateProcedure, + generateRouterSchemaImports, + getInputTypeByOpName, + resolveModelsComments, +} from './helpers'; import { project } from './project'; import removeDir from './utils/removeDir'; -import { camelCase } from 'change-case'; -import { Project } from 'ts-morph'; +import { generate as PrismaZodGenerator } from './zod/generator'; export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.Document) { let outDir = options.output as string; @@ -33,6 +39,13 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. const hiddenModels: string[] = []; resolveModelsComments(models, hiddenModels); + createAppRouter(outDir, modelOperations, hiddenModels); + createHelper(outDir); + + await project.save(); +} + +function createAppRouter(outDir: string, modelOperations: DMMF.ModelMapping[], hiddenModels: string[]) { const appRouter = project.createSourceFile(path.resolve(outDir, 'routers', `index.ts`), undefined, { overwrite: true, }); @@ -110,7 +123,6 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. }); appRouter.formatText(); - await project.save(); } function generateModelCreateRouter( @@ -133,6 +145,7 @@ function generateModelCreateRouter( ]); generateRouterSchemaImports(modelRouter, model); + generateHelperImport(modelRouter); modelRouter .addFunction({ @@ -162,3 +175,108 @@ function generateModelCreateRouter( modelRouter.formatText(); } + +function createHelper(outDir: string) { + const sf = project.createSourceFile(path.resolve(outDir, 'helper.ts'), undefined, { + overwrite: true, + }); + + sf.addStatements(`import { TRPCError } from '@trpc/server';`); + sf.addStatements(`import { isPrismaClientKnownRequestError } from '${RUNTIME_PACKAGE}';`); + + const checkMutate = sf.addFunction({ + name: 'checkMutate', + typeParameters: [{ name: 'T' }], + parameters: [ + { + name: 'promise', + type: 'Promise', + }, + ], + isAsync: true, + isExported: true, + returnType: 'Promise', + }); + + checkMutate.setBodyText( + `try { + return await promise; + } catch (err: any) { + if (isPrismaClientKnownRequestError(err)) { + if (err.code === 'P2004') { + if (err.meta?.reason === '${CrudFailureReason.RESULT_NOT_READABLE}') { + // unable to readback data + return undefined; + } else { + // rejected by policy + throw new TRPCError({ + code: 'FORBIDDEN', + message: err.message, + cause: err, + }); + } + } else { + // request error + throw new TRPCError({ + code: 'BAD_REQUEST', + message: err.message, + cause: err, + }); + } + } else { + throw err; + } + } + ` + ); + checkMutate.formatText(); + + const checkRead = sf.addFunction({ + name: 'checkRead', + typeParameters: [{ name: 'T' }], + parameters: [ + { + name: 'promise', + type: 'Promise', + }, + ], + isAsync: true, + isExported: true, + returnType: 'Promise', + }); + + checkRead.setBodyText( + `try { + return await promise; + } catch (err: any) { + if (isPrismaClientKnownRequestError(err)) { + if (err.code === 'P2004') { + // rejected by policy + throw new TRPCError({ + code: 'FORBIDDEN', + message: err.message, + cause: err, + }); + } else if (err.code === 'P2025') { + // not found + throw new TRPCError({ + code: 'NOT_FOUND', + message: err.message, + cause: err, + }); + } else { + // request error + throw new TRPCError({ + code: 'BAD_REQUEST', + message: err.message, + cause: err, + }) + } + } else { + throw err; + } + } + ` + ); + checkRead.formatText(); +} diff --git a/packages/plugins/trpc/src/helpers.ts b/packages/plugins/trpc/src/helpers.ts index b7be4dc29..3972c79e0 100644 --- a/packages/plugins/trpc/src/helpers.ts +++ b/packages/plugins/trpc/src/helpers.ts @@ -1,22 +1,7 @@ import { DMMF } from '@prisma/generator-helper'; -import { CrudFailureReason } from '@zenstackhq/sdk'; import { CodeBlockWriter, SourceFile } from 'ts-morph'; import { uncapitalizeFirstLetter } from './utils/uncapitalizeFirstLetter'; -export const generatetRPCImport = (sourceFile: SourceFile) => { - sourceFile.addImportDeclaration({ - moduleSpecifier: '@trpc/server', - namespaceImport: 'trpc', - }); -}; - -export const generateRouterImport = (sourceFile: SourceFile, modelNamePlural: string, modelNameCamelCase: string) => { - sourceFile.addImportDeclaration({ - moduleSpecifier: `./${modelNameCamelCase}.router`, - namedImports: [`${modelNamePlural}Router`], - }); -}; - export function generateProcedure( writer: CodeBlockWriter, opType: string, @@ -29,24 +14,15 @@ export function generateProcedure( if (procType === 'query') { writer.write(` - ${opType}: procedure.input(${typeName}).query(({ctx, input}) => db(ctx).${uncapitalizeFirstLetter( + ${opType}: procedure.input(${typeName}).query(({ctx, input}) => checkRead(db(ctx).${uncapitalizeFirstLetter( modelName - )}.${prismaMethod}(input)), + )}.${prismaMethod}(input))), `); } else if (procType === 'mutation') { writer.write(` - ${opType}: procedure.input(${typeName}).mutation(async ({ctx, input}) => { - try { - return await db(ctx).${uncapitalizeFirstLetter(modelName)}.${prismaMethod}(input); - } catch (err: any) { - if (err.code === 'P2004' && err.meta?.reason === '${CrudFailureReason.RESULT_NOT_READABLE}') { - // unable to readback data - return undefined; - } else { - throw err; - } - } - }), + ${opType}: procedure.input(${typeName}).mutation(async ({ctx, input}) => checkMutate(db(ctx).${uncapitalizeFirstLetter( + modelName + )}.${prismaMethod}(input))), `); } } @@ -55,6 +31,10 @@ export function generateRouterSchemaImports(sourceFile: SourceFile, name: string sourceFile.addStatements(`import { ${name}Schema } from '../schemas/${name}.schema';`); } +export function generateHelperImport(sourceFile: SourceFile) { + sourceFile.addStatements(`import { checkRead, checkMutate } from '../helper';`); +} + export const getInputTypeByOpName = (opName: string, modelName: string) => { let inputType; switch (opName) { diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index a42c6560c..c1979d590 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -14,7 +14,15 @@ import { Model, } from '@zenstackhq/language/ast'; import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; -import { getDataModels, getLiteral, GUARD_FIELD_NAME, PluginError, PluginOptions, resolved } from '@zenstackhq/sdk'; +import { + getDataModels, + getLiteral, + GUARD_FIELD_NAME, + PluginError, + PluginOptions, + resolved, + RUNTIME_PACKAGE, +} from '@zenstackhq/sdk'; import { camelCase } from 'change-case'; import { streamAllContents } from 'langium'; import path from 'path'; @@ -22,7 +30,7 @@ import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind } fro import { name } from '.'; import { isFromStdlib } from '../../language-server/utils'; import { analyzePolicies, getIdFields } from '../../utils/ast-utils'; -import { ALL_OPERATION_KINDS, getDefaultOutputFolder, RUNTIME_PACKAGE } from '../plugin-utils'; +import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils'; import { ExpressionWriter } from './expression-writer'; import { isFutureExpr } from './utils'; import { ZodSchemaGenerator } from './zod-schema-generator'; diff --git a/packages/schema/src/plugins/plugin-utils.ts b/packages/schema/src/plugins/plugin-utils.ts index 7d3cc1320..d10be7c94 100644 --- a/packages/schema/src/plugins/plugin-utils.ts +++ b/packages/schema/src/plugins/plugin-utils.ts @@ -2,7 +2,6 @@ import type { PolicyOperationKind } from '@zenstackhq/runtime'; import fs from 'fs'; import path from 'path'; -export const RUNTIME_PACKAGE = '@zenstackhq/runtime'; export const ALL_OPERATION_KINDS: PolicyOperationKind[] = ['create', 'update', 'postUpdate', 'read', 'delete']; /** diff --git a/packages/sdk/src/constants.ts b/packages/sdk/src/constants.ts index 038a7cebe..da7d620e4 100644 --- a/packages/sdk/src/constants.ts +++ b/packages/sdk/src/constants.ts @@ -22,3 +22,8 @@ export enum CrudFailureReason { */ RESULT_NOT_READABLE = 'RESULT_NOT_READABLE', } + +/** + * @zenstackhq/runtime package name + */ +export const RUNTIME_PACKAGE = '@zenstackhq/runtime'; diff --git a/packages/server/src/express/middleware.ts b/packages/server/src/express/middleware.ts index 54d6e86be..214b30b64 100644 --- a/packages/server/src/express/middleware.ts +++ b/packages/server/src/express/middleware.ts @@ -19,7 +19,8 @@ export interface MiddlewareOptions { logger?: LoggerConfig; /** - * Zod schemas for validating request input. Pass `true` to load from standard location (need to enable `@core/zod` plugin in schema.zmodel). + * Zod schemas for validating request input. Pass `true` to load from standard location + * (need to enable `@core/zod` plugin in schema.zmodel) or omit to disable input validation. */ zodSchemas?: ModelZodSchema | boolean; } diff --git a/packages/server/src/fastify/plugin.ts b/packages/server/src/fastify/plugin.ts index 4999b5d8a..c22d64a29 100644 --- a/packages/server/src/fastify/plugin.ts +++ b/packages/server/src/fastify/plugin.ts @@ -25,7 +25,8 @@ export interface PluginOptions { logger?: LoggerConfig; /** - * Zod schemas for validating request input. Pass `true` to load from standard location (need to enable `@core/zod` plugin in schema.zmodel). + * Zod schemas for validating request input. Pass `true` to load from standard location + * (need to enable `@core/zod` plugin in schema.zmodel) or omit to disable input validation. */ zodSchemas?: ModelZodSchema | boolean; } diff --git a/tests/integration/test-run/package-lock.json b/tests/integration/test-run/package-lock.json index 12f5787fc..e72e11b92 100644 --- a/tests/integration/test-run/package-lock.json +++ b/tests/integration/test-run/package-lock.json @@ -188,7 +188,8 @@ "vscode-languageserver": "^8.0.2", "vscode-languageserver-textdocument": "^1.0.7", "vscode-uri": "^3.0.6", - "zod": "^3.19.1" + "zod": "^3.19.1", + "zod-validation-error": "^0.2.1" }, "bin": { "zenstack": "bin/cli" @@ -455,7 +456,8 @@ "vscode-languageserver": "^8.0.2", "vscode-languageserver-textdocument": "^1.0.7", "vscode-uri": "^3.0.6", - "zod": "^3.19.1" + "zod": "^3.19.1", + "zod-validation-error": "^0.2.1" } }, "zod": {