diff --git a/src/adapter/bun/index.ts b/src/adapter/bun/index.ts index a2ec4db7..ec65a25b 100644 --- a/src/adapter/bun/index.ts +++ b/src/adapter/bun/index.ts @@ -323,6 +323,14 @@ export const BunAdapter: ElysiaAdapter = { normalize: app.config.normalize }) + const validateUpgradeData = getSchemaValidator(options.upgradeData, { + // @ts-expect-error private property + modules: app.definitions.typebox, + // @ts-expect-error private property + models: app.definitions.type as Record, + normalize: app.config.normalize + }) + app.route( 'WS', path as any, @@ -368,6 +376,12 @@ export const BunAdapter: ElysiaAdapter = { let _id: string | undefined + let _beforeHandleData: any + if (typeof options.beforeHandle === 'function') { + const result = options.beforeHandle(context) + _beforeHandleData = result instanceof Promise ? await result : result + } + const errorHandlers = [ ...(Array.isArray(options.error) ? options.error @@ -413,11 +427,22 @@ export const BunAdapter: ElysiaAdapter = { options.pong?.(data) }, open(ws: ServerWebSocket) { + if (validateUpgradeData?.Check(_beforeHandleData) === false) { + return void ws.send( + new ValidationError( + 'upgradeData', + validateUpgradeData, + _beforeHandleData + ).message as string + ) + } + try { handleResponse( ws, options.open?.( - new ElysiaWS(ws, context as any) + new ElysiaWS(ws, context as any), + _beforeHandleData as any ) ) } catch (error) { diff --git a/src/index.ts b/src/index.ts index 933aa6dc..6b7bc614 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5384,7 +5384,8 @@ export default class Elysia< MergeSchema > >, - const Macro extends Metadata['macro'] + const Macro extends Metadata['macro'], + const UpgradeDataSchema extends TSchema, >( path: Path, options: WSLocalHook< @@ -5396,7 +5397,8 @@ export default class Elysia< Volatile['resolve'] & MacroToContext }, - Macro + Macro, + UpgradeDataSchema > ): Elysia< BasePath, diff --git a/src/ws/types.ts b/src/ws/types.ts index ab415d4c..0eccbce0 100644 --- a/src/ws/types.ts +++ b/src/ws/types.ts @@ -1,3 +1,5 @@ +import { TSchema } from '@sinclair/typebox' + import type { ElysiaWS } from './index' import { WebSocketHandler } from './bun' @@ -17,7 +19,8 @@ import { Prettify, RouteSchema, SingletonBase, - TransformHandler + TransformHandler, + UnwrapSchema } from '../types' type TypedWebSocketMethod = @@ -33,10 +36,12 @@ export type FlattenResponse = interface TypedWebSocketHandler< in out Context, - in out Route extends RouteSchema = {} + in out Route extends RouteSchema = {}, + in out UpgradeDataSchema extends unknown = unknown > extends Omit, TypedWebSocketMethod> { open?( - ws: ElysiaWS & { body: never }> + ws: ElysiaWS & { body: never }>, + data: UpgradeDataSchema ): MaybePromise | void> message?( ws: ElysiaWS, @@ -118,13 +123,14 @@ export type WSParseHandler = ( message: unknown ) => MaybePromise -export type AnyWSLocalHook = WSLocalHook +export type AnyWSLocalHook = WSLocalHook export type WSLocalHook< LocalSchema extends InputSchema, Schema extends RouteSchema, Singleton extends SingletonBase, - Macro extends MetadataBase['macro'] + Macro extends MetadataBase['macro'], + UpgradeDataSchema extends TSchema, > = Prettify & (LocalSchema extends any ? LocalSchema : Prettify) & { detail?: DocumentDecoration @@ -134,6 +140,11 @@ export type WSLocalHook< upgrade?: Record | ((context: Context) => unknown) parse?: MaybeArray> + /** + * Upgrade data's value + */ + upgradeData?: UpgradeDataSchema; + /** * Transform context's value */ @@ -163,5 +174,6 @@ export type WSLocalHook< Omit, 'body'> & { body: never }, - Schema + Schema, + UnwrapSchema > diff --git a/test/ws/message.test.ts b/test/ws/message.test.ts index bdb324e5..4c46af16 100644 --- a/test/ws/message.test.ts +++ b/test/ws/message.test.ts @@ -487,4 +487,38 @@ describe('WebSocket message', () => { await wsClosed(ws) app.stop() }) + + it('should call beforeHandle hook', async () => { + const app = new Elysia() + .ws('/ws', { + upgradeData: t.Object({ + hello: t.String() + }), + beforeHandle() { + return { + hello: 'world' + } + }, + open(ws, data) { + ws.send(data.hello) + } + }) + .listen(0) + + const ws = newWebsocket(app.server!) + + await wsOpen(ws) + + const message = wsMessage(ws) + + ws.send('Hello!') + + const { data } = await message + + expect(data).toBe('world') + + await wsClosed(ws) + + app.stop() + }) })