diff --git a/packages/server/package.json b/packages/server/package.json index 84e47382..2f6dacdb 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -39,6 +39,16 @@ "default": "./dist/api.cjs" } }, + "./common": { + "import": { + "types": "./dist/common.d.ts", + "default": "./dist/common.js" + }, + "require": { + "types": "./dist/common.d.cts", + "default": "./dist/common.cjs" + } + }, "./express": { "import": { "types": "./dist/express.d.ts", @@ -118,6 +128,16 @@ "types": "./dist/tanstack-start.d.cts", "default": "./dist/tanstack-start.cjs" } + }, + "./types": { + "import": { + "types": "./dist/types.d.ts", + "default": "./dist/types.js" + }, + "require": { + "types": "./dist/types.d.cts", + "default": "./dist/types.cjs" + } } }, "dependencies": { diff --git a/packages/server/src/api/index.ts b/packages/server/src/api/index.ts index 09d9700e..e57e81b8 100644 --- a/packages/server/src/api/index.ts +++ b/packages/server/src/api/index.ts @@ -1,2 +1,3 @@ export { RestApiHandler, type RestApiHandlerOptions } from './rest'; export { RPCApiHandler, type RPCApiHandlerOptions } from './rpc'; +export * from './utils'; diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 536a45a5..4df8ee48 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -118,10 +118,10 @@ registerCustomSerializers(); */ export class RestApiHandler implements ApiHandler { // resource serializers - private serializers = new Map(); + protected serializers = new Map(); // error responses - private readonly errors: Record = { + protected readonly errors: Record = { unsupportedModel: { status: 404, title: 'Unsupported model type', @@ -200,10 +200,10 @@ export class RestApiHandler implements ApiHandler(\[[^[\]]+\])+)$/); + protected filterParamPattern = new RegExp(/^filter(?(\[[^[\]]+\])+)$/); // zod schema for payload of creating and updating a resource - private createUpdatePayloadSchema = z + protected createUpdatePayloadSchema = z .object({ data: z.object({ type: z.string(), @@ -225,16 +225,16 @@ export class RestApiHandler implements ApiHandler implements ApiHandler = {}; + protected typeMap: Record = {}; // divider used to separate compound ID fields - private idDivider; + protected idDivider; - private urlPatternMap: Record; - private modelNameMapping: Record; - private reverseModelNameMapping: Record; - private externalIdMapping: Record; + protected urlPatternMap: Record; + protected modelNameMapping: Record; + protected reverseModelNameMapping: Record; + protected externalIdMapping: Record; - constructor(private readonly options: RestApiHandlerOptions) { + constructor(protected readonly options: RestApiHandlerOptions) { this.idDivider = options.idDivider ?? DEFAULT_ID_DIVIDER; const segmentCharset = options.urlSegmentCharset ?? 'a-zA-Z0-9-_~ %'; @@ -283,7 +283,7 @@ export class RestApiHandler implements ApiHandler { + protected buildUrlPatternMap(urlSegmentNameCharset: string): Record { const options = { segmentValueCharset: urlSegmentNameCharset }; const buildPath = (segments: string[]) => { @@ -301,11 +301,11 @@ export class RestApiHandler implements ApiHandler implements ApiHandler { + protected handleGenericError(err: unknown): Response | PromiseLike { return this.makeError('unknownError', err instanceof Error ? `${err.message}\n${err.stack}` : 'Unknown error'); } - private async processSingleRead( + protected async processSingleRead( client: ClientContract, type: string, resourceId: string, @@ -528,7 +528,7 @@ export class RestApiHandler implements ApiHandler, type: string, resourceId: string, @@ -617,7 +617,7 @@ export class RestApiHandler implements ApiHandler, type: string, resourceId: string, @@ -683,7 +683,7 @@ export class RestApiHandler implements ApiHandler, type: string, query: Record | undefined, @@ -785,7 +785,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected buildPartialSelect(type: string, query: Record | undefined) { const selectFieldsQuery = query?.[`fields[${type}]`]; if (!selectFieldsQuery) { return { select: undefined, error: undefined }; @@ -812,11 +812,11 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler, type: string, _query: Record | undefined, @@ -931,7 +931,7 @@ export class RestApiHandler implements ApiHandler, type: string, _query: Record | undefined, @@ -1014,7 +1014,7 @@ export class RestApiHandler implements ApiHandler @@ -1024,7 +1024,7 @@ export class RestApiHandler implements ApiHandler, mode: 'create' | 'update' | 'delete', type: string, @@ -1119,7 +1119,7 @@ export class RestApiHandler implements ApiHandler, type: any, resourceId: string, @@ -1186,7 +1186,7 @@ export class RestApiHandler implements ApiHandler, type: any, resourceId: string): Promise { + protected async processDelete(client: ClientContract, type: any, resourceId: string): Promise { const typeInfo = this.getModelInfo(type); if (!typeInfo) { return this.makeUnsupportedModelError(type); @@ -1203,7 +1203,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler implements ApiHandler implements ApiHandler> = {}; for (const model of Object.keys(this.schema.models)) { @@ -1382,7 +1382,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler>) { + protected async serializeItems(model: string, items: unknown, options?: Partial>) { model = lowerCaseFirst(model); const serializer = this.serializers.get(model); if (!serializer) { @@ -1421,7 +1421,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler) { + protected replaceURLSearchParams(url: string, params: Record) { const r = new URL(url); for (const [key, value] of Object.entries(params)) { r.searchParams.set(key, value.toString()); @@ -1486,7 +1486,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler ({ ...acc, [curr.name]: true }), {}); } - private makeIdConnect(idFields: FieldDef[], id: string | number) { + protected makeIdConnect(idFields: FieldDef[], id: string | number) { if (idFields.length === 1) { return { [idFields[0]!.name]: this.coerce(idFields[0]!, id) }; } else { @@ -1535,20 +1535,20 @@ export class RestApiHandler implements ApiHandler idf.name).join(this.idDivider); } - private makeDefaultIdKey(idFields: FieldDef[]) { + protected makeDefaultIdKey(idFields: FieldDef[]) { // TODO: support `@@id` with custom name return idFields.map((idf) => idf.name).join(DEFAULT_ID_DIVIDER); } - private makeCompoundId(idFields: FieldDef[], item: any) { + protected makeCompoundId(idFields: FieldDef[], item: any) { return idFields.map((idf) => item[idf.name]).join(this.idDivider); } - private makeUpsertWhere(matchFields: any[], attributes: any, typeInfo: ModelInfo) { + protected makeUpsertWhere(matchFields: any[], attributes: any, typeInfo: ModelInfo) { const where = matchFields.reduce((acc: any, field: string) => { acc[field] = attributes[field] ?? null; return acc; @@ -1566,7 +1566,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler attr.name === '@json')) { try { @@ -1624,7 +1624,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected makeNormalizedUrl(path: string, query: Record | undefined) { const url = new URL(this.makeLinkUrl(path)); for (const [key, value] of Object.entries(query ?? {})) { if ( @@ -1642,7 +1642,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected getPagination(query: Record | undefined) { if (!query) { return { offset: 0, limit: this.options.pageSize ?? DEFAULT_PAGE_SIZE }; } @@ -1676,7 +1676,7 @@ export class RestApiHandler implements ApiHandler | undefined, ): { filter: any; error: any } { @@ -1780,7 +1780,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected buildSort(type: string, query: Record | undefined) { if (!query?.['sort']) { return { sort: undefined, error: undefined }; } @@ -1857,7 +1857,7 @@ export class RestApiHandler implements ApiHandler | undefined, @@ -1917,7 +1917,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler { return this.makeError('validationError', err.message, 422); @@ -2036,7 +2036,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler = { * RPC style API request handler that mirrors the ZenStackClient API */ export class RPCApiHandler implements ApiHandler { - constructor(private readonly options: RPCApiHandlerOptions) {} + constructor(protected readonly options: RPCApiHandlerOptions) {} get schema(): Schema { return this.options.schema; @@ -163,11 +163,11 @@ export class RPCApiHandler implements ApiHandler, model: string) { + protected isValidModel(client: ClientContract, model: string) { return Object.keys(client.$schema.models).some((m) => lowerCaseFirst(m) === lowerCaseFirst(model)); } - private makeBadInputErrorResponse(message: string) { + protected makeBadInputErrorResponse(message: string) { const resp = { status: 400, body: { error: { message } }, @@ -176,7 +176,7 @@ export class RPCApiHandler implements ApiHandler implements ApiHandler implements ApiHandler implements ApiHandler = { + method: string; + path: string; + query?: Record; + body?: unknown; + client: ClientContract; +}; + +class RecordingHandler implements ApiHandler { + constructor( + protected readonly schemaDef: SchemaDef, + protected readonly response: Response, + protected readonly logger?: (...args: any[]) => void, + ) {} + + readonly contexts: Array> = []; + + get schema(): SchemaDef { + return this.schemaDef; + } + + get log() { + return this.logger; + } + + async handleRequest(context: RequestContext): Promise { + this.contexts.push(context); + return this.response; + } +} + +class ThrowingHandler implements ApiHandler { + constructor(protected readonly schemaDef: SchemaDef, protected readonly logger: (...args: any[]) => void) {} + + get schema(): SchemaDef { + return this.schemaDef; + } + + get log() { + return this.logger; + } + + async handleRequest(): Promise { + throw new Error('adapter failure'); + } +} + +function createCustomAdapter( + options: CommonAdapterOptions, +): (request: AdapterRequest) => Promise { + return async (request) => { + const context: RequestContext = { + client: request.client, + method: request.method, + path: request.path, + query: request.query, + requestBody: request.body, + }; + + try { + return await options.apiHandler.handleRequest(context); + } catch (err) { + logInternalError(options.apiHandler.log, err); + throw err; + } + }; +} + +describe('Custom adapter test', () => { + const schema = {} as SchemaDef; + const client = { $schema: schema } as unknown as ClientContract; + + it('delegates to api handler', async () => { + const response: Response = { status: 201, body: { ok: true } }; + const handler = new RecordingHandler(schema, response); + const adapter = createCustomAdapter({ apiHandler: handler }); + + const result = await adapter({ + method: 'get', + path: '/something', + query: { foo: 'bar' }, + body: { value: 1 }, + client, + }); + + expect(result).toEqual(response); + expect(handler.contexts).toHaveLength(1); + const captured = handler.contexts[0]; + expect(captured.method).toBe('get'); + expect(captured.path).toBe('/something'); + expect(captured.query).toEqual({ foo: 'bar' }); + expect(captured.requestBody).toEqual({ value: 1 }); + expect(captured.client).toBe(client); + }); + + it('logs internal error when handler throws', async () => { + const logger = vi.fn(); + const handler = new ThrowingHandler(schema, logger); + const adapter = createCustomAdapter({ apiHandler: handler }); + + await expect( + adapter({ + method: 'post', + path: '/fail', + client, + }), + ).rejects.toThrow('adapter failure'); + expect(logger).toHaveBeenCalledTimes(1); + const call = logger.mock.calls[0]; + expect(call[0]).toBe('error'); + expect(call[1]).toContain('An unhandled error occurred while processing the request: Error: adapter failure'); + }); +}); diff --git a/packages/server/test/api/custom.test.ts b/packages/server/test/api/custom.test.ts new file mode 100644 index 00000000..d51d7b9f --- /dev/null +++ b/packages/server/test/api/custom.test.ts @@ -0,0 +1,62 @@ +import type { ClientContract } from '@zenstackhq/orm'; +import type { SchemaDef } from '@zenstackhq/orm/schema'; +import { Decimal } from 'decimal.js'; +import SuperJSON from 'superjson'; +import { describe, expect, it, vi } from 'vitest'; +import { log, registerCustomSerializers } from '../../src/api/utils'; +import { type ApiHandler, type LogConfig, type RequestContext, type Response } from '../../src/types'; + +class CustomApiHandler implements ApiHandler { + protected readonly handled: Array> = []; + + constructor(protected readonly schemaDef: SchemaDef, protected readonly logger: LogConfig) {} + + get schema(): SchemaDef { + return this.schemaDef; + } + + get log(): LogConfig { + return this.logger; + } + + get contexts(): ReadonlyArray> { + return this.handled; + } + + async handleRequest(context: RequestContext): Promise { + this.handled.push(context); + log(this.logger, 'info', () => `received ${context.method.toUpperCase()} ${context.path}`); + return { status: 202, body: { handled: true } }; + } +} + +describe('Custom API handler test', () => { + const schema = {} as SchemaDef; + const client = { $schema: schema } as unknown as ClientContract; + + it('allows building custom handlers with logging helpers', async () => { + const logger = vi.fn(); + const handler = new CustomApiHandler(schema, logger); + + const response = await handler.handleRequest({ + method: 'post', + path: '/custom', + query: { foo: 'bar' }, + requestBody: { value: 1 }, + client, + }); + + expect(response).toEqual({ status: 202, body: { handled: true } }); + expect(handler.contexts).toHaveLength(1); + expect(handler.contexts[0].query).toEqual({ foo: 'bar' }); + expect(logger).toHaveBeenCalledWith('info', 'received POST /custom', undefined); + }); + + it('provides serialization helpers for custom handlers', () => { + registerCustomSerializers(); + const serialized = SuperJSON.serialize({ value: new Decimal('3.14159') }); + const roundTripped = SuperJSON.deserialize(serialized) as { value: Decimal }; + expect(Decimal.isDecimal(roundTripped.value)).toBe(true); + expect(roundTripped.value.toString()).toBe('3.14159'); + }); +}); diff --git a/packages/server/test/api/rest.test.ts b/packages/server/test/api/rest.test.ts index b40ff604..32a39f8d 100644 --- a/packages/server/test/api/rest.test.ts +++ b/packages/server/test/api/rest.test.ts @@ -3163,4 +3163,79 @@ describe('REST server tests', () => { }); }); }); + + describe('REST server tests - handler extension', () => { + const schema = ` + model Post { + id String @id + title String + } + `; + + class CustomRestApiHandler extends RestApiHandler { + public readonly buildFilterCalls: Array<{ + type: string; + query: Record | undefined; + filter: unknown; + }> = []; + + protected override buildFilter( + type: string, + query: Record | undefined, + ) { + const result = super.buildFilter(type, query); + if (type !== 'post') { + this.buildFilterCalls.push({ type, query, filter: result.filter }); + return result; + } + + const baseFilter = + result.filter && typeof result.filter === 'object' && !Array.isArray(result.filter) + ? { ...(result.filter as Record) } + : {}; + + const modified = { + ...result, + filter: { + ...baseFilter, + title: 'second', + }, + }; + + this.buildFilterCalls.push({ type, query, filter: modified.filter }); + return modified; + } + } + + beforeEach(async () => { + client = await createTestClient(schema); + await client.post.create({ data: { id: 'post-first', title: 'first' } }); + await client.post.create({ data: { id: 'post-second', title: 'second' } }); + }); + + it('allows extending RestApiHandler to customize filtering', async () => { + const customHandler = new CustomRestApiHandler({ + schema: client.$schema, + endpoint: 'http://localhost/api', + }); + + const response = await customHandler.handleRequest({ + method: 'get', + path: '/post', + query: {}, + client, + }); + + expect(customHandler.buildFilterCalls).toHaveLength(1); + expect(customHandler.buildFilterCalls[0].type).toBe('post'); + expect(customHandler.buildFilterCalls[0].filter).toMatchObject({ title: 'second' }); + + expect(response.status).toBe(200); + const body = response.body as { + data: Array<{ attributes: { title: string } }>; + }; + expect(body.data).toHaveLength(1); + expect(body.data[0].attributes.title).toBe('second'); + }); + }); }); diff --git a/packages/server/test/api/rpc.test.ts b/packages/server/test/api/rpc.test.ts index 19e44ca0..2f50bf5d 100644 --- a/packages/server/test/api/rpc.test.ts +++ b/packages/server/test/api/rpc.test.ts @@ -508,6 +508,64 @@ describe('RPC API Handler Tests', () => { expect(r.data).toBeNull(); }); + it('allows extending RPCApiHandler to customize query unmarshalling', async () => { + await rawClient.post.deleteMany(); + await rawClient.user.deleteMany(); + + await rawClient.user.create({ + data: { + id: 'ext-user', + email: 'ext@example.com', + posts: { + create: [ + { id: 'ext-post-1', title: 'first', published: true }, + { id: 'ext-post-2', title: 'second', published: true }, + ], + }, + }, + }); + + class CustomHandler extends RPCApiHandler { + public readonly unmarshalCalls: Array<{ value: string; meta: string | undefined; result: unknown }> = []; + protected override unmarshalQ(value: string, meta: string | undefined) { + const result = super.unmarshalQ(value, meta); + this.unmarshalCalls.push({ value, meta, result }); + const asRecord = (result ?? {}) as Record; + const baseWhere = (asRecord.where ?? {}) as Record; + return { + ...asRecord, + where: { + ...baseWhere, + title: 'second', + }, + }; + } + } + + const handler = new CustomHandler({ schema: client.$schema }); + const callHandler = (args: Parameters[0]) => handler.handleRequest(args); + + const response = await callHandler({ + method: 'get', + path: '/post/findMany', + client: rawClient, + query: { + q: JSON.stringify({ where: {} }), + }, + }); + + expect(handler.unmarshalCalls).toHaveLength(1); + expect(handler.unmarshalCalls[0].value).toBeDefined(); + expect(handler.unmarshalCalls[0].result).toEqual({ where: {} }); + expect(response.status).toBe(200); + const responseBody = response.body as { data: Array<{ title: string }> }; + expect(responseBody.data).toHaveLength(1); + expect(responseBody.data[0].title).toBe('second'); + + await rawClient.post.deleteMany(); + await rawClient.user.deleteMany(); + }); + function makeHandler() { const handler = new RPCApiHandler({ schema: client.$schema }); return async (args: any) => { diff --git a/packages/server/tsup.config.ts b/packages/server/tsup.config.ts index 4c236d2f..70009f51 100644 --- a/packages/server/tsup.config.ts +++ b/packages/server/tsup.config.ts @@ -2,7 +2,9 @@ import { defineConfig } from 'tsup'; export default defineConfig({ entry: { + types: 'src/types.ts', api: 'src/api/index.ts', + common: 'src/adapter/common.ts', express: 'src/adapter/express/index.ts', next: 'src/adapter/next/index.ts', fastify: 'src/adapter/fastify/index.ts',