diff --git a/packages/server/src/express/middleware.ts b/packages/server/src/express/middleware.ts index 4056cb0c0..3c97de74e 100644 --- a/packages/server/src/express/middleware.ts +++ b/packages/server/src/express/middleware.ts @@ -14,6 +14,14 @@ export interface MiddlewareOptions extends AdapterBaseOptions { * Callback for getting a PrismaClient for the given request */ getPrisma: (req: Request, res: Response) => unknown | Promise; + /** + * This option is used to enable/disable the option to manage the response + * by the middleware. If set to true, the middleware will not send the + * response and the user will be responsible for sending the response. + * + * Defaults to false; + */ + manageCustomResponse?: boolean; } /** @@ -30,13 +38,18 @@ const factory = (options: MiddlewareOptions): Handler => { const requestHandler = options.handler || RPCAPIHandler(); const useSuperJson = options.useSuperJson === true; - return async (request, response) => { + return async (request, response, next) => { const prisma = (await options.getPrisma(request, response)) as DbClientContract; + const { manageCustomResponse } = options; + + if (manageCustomResponse && !prisma) { + throw new Error('unable to get prisma from request context'); + } + if (!prisma) { - response + return response .status(500) .json(marshalToObject({ message: 'unable to get prisma from request context' }, useSuperJson)); - return; } let query: Record = {}; @@ -53,8 +66,10 @@ const factory = (options: MiddlewareOptions): Handler => { } query = buildUrlQuery(rawQuery, useSuperJson); } catch { - response.status(400).json(marshalToObject({ message: 'invalid query parameters' }, useSuperJson)); - return; + if (manageCustomResponse) { + throw new Error('invalid query parameters'); + } + return response.status(400).json(marshalToObject({ message: 'invalid query parameters' }, useSuperJson)); } try { @@ -68,9 +83,19 @@ const factory = (options: MiddlewareOptions): Handler => { zodSchemas, logger: options.logger, }); - response.status(r.status).json(marshalToObject(r.body, useSuperJson)); + if (manageCustomResponse) { + response.locals = { + status: r.status, + body: r.body, + }; + return next(); + } + return response.status(r.status).json(marshalToObject(r.body, useSuperJson)); } catch (err) { - response + if (manageCustomResponse) { + throw err; + } + return response .status(500) .json(marshalToObject({ message: `An unhandled error occurred: ${err}` }, useSuperJson)); } diff --git a/packages/server/tests/adapter/express.test.ts b/packages/server/tests/adapter/express.test.ts index df96ec7a2..0a964ef41 100644 --- a/packages/server/tests/adapter/express.test.ts +++ b/packages/server/tests/adapter/express.test.ts @@ -251,3 +251,30 @@ describe('Express adapter tests - rest handler', () => { expect(await prisma.user.findMany()).toHaveLength(0); }); }); + +describe('Express adapter tests - rest handler with customMiddleware', () => { + it('run middleware', async () => { + const { prisma, zodSchemas, modelMeta } = await loadSchema(schema); + + const app = express(); + app.use(bodyParser.json()); + app.use( + '/api', + ZenStackMiddleware({ + getPrisma: () => prisma, + modelMeta, + zodSchemas, + handler: RESTAPIHandler({ endpoint: 'http://localhost/api' }), + manageCustomResponse: true, + }) + ); + + app.use((req, res) => { + res.status(res.locals.status).json({ message: res.locals.body }); + }); + + const r = await request(app).get(makeUrl('/api/post/1')); + expect(r.status).toBe(404); + expect(r.body.message).toHaveProperty('errors'); + }); +});