diff --git a/packages/restate-sdk/src/context_impl.ts b/packages/restate-sdk/src/context_impl.ts index 982bca9d..cc6654cf 100644 --- a/packages/restate-sdk/src/context_impl.ts +++ b/packages/restate-sdk/src/context_impl.ts @@ -44,7 +44,6 @@ import { } from "./types/errors.js"; import type { Client, SendClient } from "./types/rpc.js"; import { - defaultSerde, HandlerKind, makeRpcCallProxy, makeRpcSendProxy, @@ -96,6 +95,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { private readonly outputPump: OutputPump; private readonly runClosuresTracker: RunClosuresTracker; readonly promisesExecutor: PromisesExecutor; + readonly defaultSerde: Serde; constructor( readonly coreVm: vm.WasmVM, @@ -108,6 +108,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { inputReader: ReadableStreamDefaultReader, outputWriter: WritableStreamDefaultWriter, readonly journalValueCodec: JournalValueCodec, + defaultSerde?: Serde, private readonly asTerminalError?: (error: any) => TerminalError | undefined ) { this.rand = new RandImpl(input.random_seed, () => { @@ -131,6 +132,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { this.runClosuresTracker, this.promiseExecutorErrorCallback.bind(this) ); + this.defaultSerde = defaultSerde ?? serde.json; } cancel(invocationId: InvocationId): void { @@ -146,7 +148,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { WasmCommandType.AttachInvocation, () => {}, (vm) => vm.sys_attach_invocation(invocationId), - SuccessWithSerde(serde ?? defaultSerde(), this.journalValueCodec), + SuccessWithSerde(serde ?? this.defaultSerde, this.journalValueCodec), Failure ); } @@ -173,7 +175,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { () => {}, (vm) => vm.sys_get_state(name), VoidAsNull, - SuccessWithSerde(serde ?? defaultSerde(), this.journalValueCodec) + SuccessWithSerde(serde ?? this.defaultSerde, this.journalValueCodec) ); } @@ -191,7 +193,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { WasmCommandType.SetState, () => this.journalValueCodec.encode( - (serde ?? defaultSerde()).serialize(value) + (serde ?? this.defaultSerde).serialize(value) ), (vm, bytes) => vm.sys_set_state(name, bytes) ); @@ -344,21 +346,36 @@ export class ContextImpl implements ObjectContext, WorkflowContext { } serviceClient({ name }: ServiceDefinitionFrom): Client> { - return makeRpcCallProxy((call) => this.genericCall(call), name); + return makeRpcCallProxy( + (call) => this.genericCall(call), + this.defaultSerde, + + name + ); } objectClient( { name }: VirtualObjectDefinitionFrom, key: string ): Client> { - return makeRpcCallProxy((call) => this.genericCall(call), name, key); + return makeRpcCallProxy( + (call) => this.genericCall(call), + this.defaultSerde, + name, + key + ); } workflowClient( { name }: WorkflowDefinitionFrom, key: string ): Client> { - return makeRpcCallProxy((call) => this.genericCall(call), name, key); + return makeRpcCallProxy( + (call) => this.genericCall(call), + this.defaultSerde, + name, + key + ); } public serviceSendClient( @@ -367,6 +384,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { ): SendClient> { return makeRpcSendProxy( (send) => this.genericSend(send), + this.defaultSerde, name, undefined, opts?.delay @@ -380,6 +398,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { ): SendClient> { return makeRpcSendProxy( (send) => this.genericSend(send), + this.defaultSerde, name, key, opts?.delay @@ -393,6 +412,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { ): SendClient> { return makeRpcSendProxy( (send) => this.genericSend(send), + this.defaultSerde, name, key, opts?.delay @@ -412,7 +432,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { nameOrAction, actionSecondParameter ); - const serde = options?.serde ?? defaultSerde(); + const serde = options?.serde ?? this.defaultSerde ?? this.defaultSerde; // Prepare the handle let handle: number; @@ -586,7 +606,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext { awakeable.handle, completeSignalPromiseUsing( VoidAsUndefined, - SuccessWithSerde(serde, this.journalValueCodec), + SuccessWithSerde(serde ?? this.defaultSerde, this.journalValueCodec), Failure ) ), @@ -606,8 +626,8 @@ export class ContextImpl implements ObjectContext, WorkflowContext { } else { value = payload !== undefined - ? defaultSerde().serialize(payload) - : defaultSerde().serialize(null); + ? this.defaultSerde.serialize(payload) + : this.defaultSerde.serialize(null); } return this.journalValueCodec.encode(value); }, @@ -793,7 +813,7 @@ class DurablePromiseImpl implements DurablePromise { private readonly name: string, serde?: Serde ) { - this.serde = serde ?? defaultSerde(); + this.serde = serde ?? (this.ctx.defaultSerde as unknown as Serde); } then( @@ -981,7 +1001,7 @@ const VoidAsUndefined: Completer = (value, prom) => { }; function SuccessWithSerde( - serde?: Serde, + serde: Serde, journalCodec?: JournalValueCodec, transform?: (success: T) => U ): Completer { @@ -995,12 +1015,7 @@ function SuccessWithSerde( } else { buffer = value.Success; } - let val: T; - if (serde) { - val = serde.deserialize(buffer); - } else { - val = defaultSerde().deserialize(buffer); - } + let val = serde.deserialize(buffer); if (transform) { val = transform(val); } diff --git a/packages/restate-sdk/src/endpoint/components.ts b/packages/restate-sdk/src/endpoint/components.ts index bb236c25..9607e5e3 100644 --- a/packages/restate-sdk/src/endpoint/components.ts +++ b/packages/restate-sdk/src/endpoint/components.ts @@ -21,7 +21,8 @@ import type { WorkflowOptions, } from "../types/rpc.js"; import { HandlerKind } from "../types/rpc.js"; -import { millisOrDurationToMillis } from "@restatedev/restate-sdk-core"; +import type { Serde } from "@restatedev/restate-sdk-core"; +import { millisOrDurationToMillis, serde } from "@restatedev/restate-sdk-core"; // // Interfaces @@ -44,17 +45,22 @@ export interface ComponentHandler { // Service // -function handlerInputDiscovery(handler: HandlerWrapper): d.InputPayload { +function handlerInputDiscovery( + handler: HandlerWrapper, + defaultSerde: Serde +): d.InputPayload { + const serde = handler.inputSerde ?? defaultSerde; + let contentType = undefined; let jsonSchema = undefined; - if (handler.inputSerde.jsonSchema) { - jsonSchema = handler.inputSerde.jsonSchema; - contentType = handler.accept ?? handler.inputSerde.contentType; + if (serde.jsonSchema) { + jsonSchema = serde.jsonSchema; + contentType = handler.accept ?? serde.contentType; } else if (handler.accept) { contentType = handler.accept; - } else if (handler.inputSerde.contentType) { - contentType = handler.inputSerde.contentType; + } else if (serde.contentType) { + contentType = serde.contentType; } else { // no input information return {}; @@ -67,20 +73,20 @@ function handlerInputDiscovery(handler: HandlerWrapper): d.InputPayload { }; } -function handlerOutputDiscovery(handler: HandlerWrapper): d.OutputPayload { +function handlerOutputDiscovery( + handler: HandlerWrapper, + defaultSerde: Serde +): d.OutputPayload { + const serde = handler.outputSerde ?? defaultSerde; + let contentType = undefined; let jsonSchema = undefined; - if (handler.outputSerde.jsonSchema) { - jsonSchema = handler.outputSerde.jsonSchema; - contentType = - handler.contentType ?? - handler.outputSerde.contentType ?? - "application/json"; - } else if (handler.contentType) { - contentType = handler.contentType; - } else if (handler.outputSerde.contentType) { - contentType = handler.outputSerde.contentType; + if (serde.jsonSchema) { + jsonSchema = serde.jsonSchema; + contentType = serde.contentType ?? "application/json"; + } else if (serde.contentType) { + contentType = serde.contentType; } else { // no input information return { setContentTypeIfEmpty: false }; @@ -116,7 +122,10 @@ export class ServiceComponent implements Component { ([name, handler]) => { return { name, - ...commonHandlerOptions(handler.handlerWrapper), + ...commonHandlerOptions( + handler.handlerWrapper, + this.options?.serde ?? serde.json + ), } satisfies d.Handler; } ); @@ -188,7 +197,10 @@ export class VirtualObjectComponent implements Component { return { name, ty: handler.kind() === HandlerKind.EXCLUSIVE ? "EXCLUSIVE" : "SHARED", - ...commonHandlerOptions(handler.handlerWrapper), + ...commonHandlerOptions( + handler.handlerWrapper, + this.options?.serde ?? serde.json + ), } satisfies d.Handler; } ); @@ -263,7 +275,10 @@ export class WorkflowComponent implements Component { this.options?.workflowRetention !== undefined ? millisOrDurationToMillis(this.options?.workflowRetention) : undefined, - ...commonHandlerOptions(handler.handlerWrapper), + ...commonHandlerOptions( + handler.handlerWrapper, + this.options?.serde ?? serde.json + ), } satisfies d.Handler; } ); @@ -382,10 +397,13 @@ function commonServiceOptions( }; } -function commonHandlerOptions(wrapper: HandlerWrapper) { +function commonHandlerOptions( + wrapper: HandlerWrapper, + defaultSerde: Serde +) { return { - input: handlerInputDiscovery(wrapper), - output: handlerOutputDiscovery(wrapper), + input: handlerInputDiscovery(wrapper, defaultSerde), + output: handlerOutputDiscovery(wrapper, defaultSerde), journalRetention: wrapper.journalRetention !== undefined ? millisOrDurationToMillis(wrapper.journalRetention) diff --git a/packages/restate-sdk/src/endpoint/handlers/generic.ts b/packages/restate-sdk/src/endpoint/handlers/generic.ts index 1a45036c..f776aec0 100644 --- a/packages/restate-sdk/src/endpoint/handlers/generic.ts +++ b/packages/restate-sdk/src/endpoint/handlers/generic.ts @@ -386,6 +386,7 @@ export class GenericHandler implements RestateHandler { inputReader, outputWriter, journalValueCodec, + service.options?.serde, service.options?.asTerminalError ); diff --git a/packages/restate-sdk/src/types/rpc.ts b/packages/restate-sdk/src/types/rpc.ts index cb1eb2db..3732f6dc 100644 --- a/packages/restate-sdk/src/types/rpc.ts +++ b/packages/restate-sdk/src/types/rpc.ts @@ -36,7 +36,6 @@ import { type WorkflowSharedHandler, type Serde, type Duration, - serde, } from "@restatedev/restate-sdk-core"; import { ensureError, TerminalError } from "./errors.js"; @@ -168,12 +167,9 @@ function optsFromArgs(args: unknown[]): { }; } -export const defaultSerde = (): Serde => { - return serde.json as Serde; -}; - export const makeRpcCallProxy = ( genericCall: (call: GenericCall) => Promise, + defaultSerde: Serde, service: string, key?: string ): T => { @@ -184,10 +180,10 @@ export const makeRpcCallProxy = ( const method = prop as string; return (...args: unknown[]) => { const { parameter, opts } = optsFromArgs(args); - const requestSerde = opts?.input ?? defaultSerde(); + const requestSerde = opts?.input ?? defaultSerde; const responseSerde = (opts as ClientCallOptions | undefined)?.output ?? - defaultSerde(); + defaultSerde; return genericCall({ service, method, @@ -208,6 +204,7 @@ export const makeRpcCallProxy = ( export const makeRpcSendProxy = ( genericSend: (send: GenericSend) => void, + defaultSerde: Serde, service: string, key?: string, legacyDelay?: number @@ -219,7 +216,7 @@ export const makeRpcSendProxy = ( const method = prop as string; return (...args: unknown[]) => { const { parameter, opts } = optsFromArgs(args); - const requestSerde = opts?.input ?? defaultSerde(); + const requestSerde = opts?.input ?? defaultSerde; const delay = legacyDelay ?? (opts as ClientSendOptions | undefined)?.delay; @@ -400,9 +397,6 @@ export class HandlerWrapper { | ObjectHandlerOpts | WorkflowHandlerOpts ): HandlerWrapper { - const inputSerde: Serde = opts?.input ?? defaultSerde(); - const outputSerde: Serde = opts?.output ?? defaultSerde(); - // we must create here a copy of the handler // to be able to reuse the original handler in other places. // like for example the same logic but under different routes. @@ -413,8 +407,8 @@ export class HandlerWrapper { return new HandlerWrapper( kind, handlerCopy, - inputSerde, - outputSerde, + opts?.input, + opts?.output, opts?.accept, opts?.description, opts?.metadata, @@ -435,15 +429,12 @@ export class HandlerWrapper { return handler[HANDLER_SYMBOL] as HandlerWrapper | undefined; } - public readonly accept?: string; - public readonly contentType?: string; - private constructor( public readonly kind: HandlerKind, private handler: Function, - public readonly inputSerde: Serde, - public readonly outputSerde: Serde, - accept?: string, + public readonly inputSerde?: Serde, + public readonly outputSerde?: Serde, + public readonly accept?: string, public readonly description?: string, public readonly metadata?: Record, public readonly idempotencyRetention?: Duration | number, @@ -454,19 +445,16 @@ export class HandlerWrapper { public readonly enableLazyState?: boolean, public readonly retryPolicy?: RetryPolicy, public readonly asTerminalError?: (error: any) => TerminalError | undefined - ) { - this.accept = accept !== undefined ? accept : inputSerde.contentType; - this.contentType = outputSerde.contentType; - } + ) {} bindInstance(t: unknown) { this.handler = this.handler.bind(t) as Function; } - async invoke(context: unknown, input: Uint8Array) { + async invoke(context: { defaultSerde: Serde }, input: Uint8Array) { let req: unknown; try { - req = this.inputSerde.deserialize(input); + req = (this.inputSerde ?? context.defaultSerde).deserialize(input); } catch (e) { const error = ensureError(e); throw new TerminalError(`Failed to deserialize input: ${error.message}`, { @@ -474,7 +462,7 @@ export class HandlerWrapper { }); } const res: unknown = await this.handler(context, req); - return this.outputSerde.serialize(res); + return (this.outputSerde ?? context.defaultSerde).serialize(res); } /** @@ -887,6 +875,13 @@ export type ServiceOptions = { * ``` */ asTerminalError?: (error: any) => TerminalError | undefined; + + /** + * Default serde to use for requests, responses, state, side effects, awakeables, promises. Used when no other serde is specified. + * + * If not provided, defaults to `serde.json`. + */ + serde?: Serde; }; /**