diff --git a/packages/core/src/v3/apiClient/runStream.ts b/packages/core/src/v3/apiClient/runStream.ts index 217b7a5108..b52283eae9 100644 --- a/packages/core/src/v3/apiClient/runStream.ts +++ b/packages/core/src/v3/apiClient/runStream.ts @@ -236,6 +236,10 @@ export class SSEStreamSubscription implements StreamSubscription { // permanently. `404` (stream gone) and `410` (session closed) // are sensible defaults; tune per-caller for other 4xx. nonRetryableStatuses?: readonly number[]; + // Optional fetch override. Used by transports that need to route + // the SSE connect through a custom path (proxy, custom headers, + // tracing). Defaults to global `fetch`. + fetchClient?: typeof fetch; } ) { this.lastEventId = options.lastEventId; @@ -331,7 +335,8 @@ export class SSEStreamSubscription implements StreamSubscription { headers["Timeout-Seconds"] = this.options.timeoutInSeconds.toString(); } - const response = await fetch(this.url, { + const fetchClient = this.options.fetchClient ?? fetch; + const response = await fetchClient(this.url, { headers, signal: this.internalAbort.signal, }); diff --git a/packages/trigger-sdk/src/v3/ai.ts b/packages/trigger-sdk/src/v3/ai.ts index 81994a0368..64aa5f9a99 100644 --- a/packages/trigger-sdk/src/v3/ai.ts +++ b/packages/trigger-sdk/src/v3/ai.ts @@ -35,6 +35,7 @@ import { type TaskWithSchema, SESSION_IN_EVENT_ID_HEADER, TRIGGER_CONTROL_SUBTYPE, + generateJWT, type WriterStreamOptions, } from "@trigger.dev/core/v3"; import type { @@ -8411,6 +8412,32 @@ export type { InferChatClientData, InferChatUIMessage } from "./ai-shared.js"; /** * Options for {@link createChatStartSessionAction}. */ +/** + * Discriminator for per-endpoint `baseURL` / `fetch` callbacks on + * `createChatStartSessionAction`. + * + * - `"sessions"` — `POST /api/v1/sessions` (session create + first run trigger). + * - `"auth"` — `POST /api/v1/auth/jwt/claims` (only fired when + * `tokenTTL` is set; otherwise the publicAccessToken from session create + * is reused as-is). + */ +export type ChatStartSessionEndpoint = "sessions" | "auth"; + +export type ChatStartSessionEndpointContext = { + endpoint: ChatStartSessionEndpoint; + chatId: string; +}; + +export type ChatStartSessionBaseURLResolver = ( + ctx: ChatStartSessionEndpointContext +) => string; + +export type ChatStartSessionFetchOverride = ( + url: string, + init: RequestInit, + ctx: ChatStartSessionEndpointContext +) => Promise; + export type CreateChatStartSessionActionOptions = { /** TTL for the session-scoped public access token. @default "1h" */ tokenTTL?: string | number | Date; @@ -8419,6 +8446,21 @@ export type CreateChatStartSessionActionOptions = { * Per-call `params.triggerConfig` shallow-merges on top. */ triggerConfig?: Partial; + /** + * Override the Trigger.dev API base URL. String applies to both + * `/api/v1/sessions` and `/api/v1/auth/jwt/claims`; function picks per + * endpoint. When unset, falls back to `apiClientManager.baseURL` + * (typically the `TRIGGER_API_URL` env var). Set this to route session + * create through a trusted edge proxy that injects server-side signal + * into `basePayload.metadata` before forwarding upstream. + */ + baseURL?: string | ChatStartSessionBaseURLResolver; + /** + * Per-request fetch override. Receives the resolved URL, RequestInit, + * and endpoint context. Use for header injection, proxy routing, or + * custom retry. Applies to both session-create and JWT-claims POSTs. + */ + fetch?: ChatStartSessionFetchOverride; }; /** @@ -8542,13 +8584,26 @@ function createChatStartSessionAction( : {}), }; - const created = await sessions.start({ - type: "chat.agent", + const startBody = { + type: "chat.agent" as const, externalId: params.chatId, taskIdentifier: taskId, triggerConfig, metadata: params.metadata, - }); + }; + + const baseURLOption = options?.baseURL; + const fetchOverride = options?.fetch; + const hasOverride = baseURLOption !== undefined || fetchOverride !== undefined; + + const created: { id: string; runId: string; publicAccessToken: string } = hasOverride + ? await callSessionsCreateWithOverride({ + chatId: params.chatId, + body: startBody, + baseURLOption, + fetchOverride, + }) + : await sessions.start(startBody); // Session create returns a session PAT directly when called with a // start token, but when the SDK call goes via the secret key we still @@ -8556,13 +8611,20 @@ function createChatStartSessionAction( // re-minting here lets the customer override `tokenTTL`). const publicAccessToken = options?.tokenTTL !== undefined - ? await auth.createPublicToken({ - scopes: { - read: { sessions: params.chatId }, - write: { sessions: params.chatId }, - }, - expirationTime: options.tokenTTL, - }) + ? hasOverride + ? await mintPublicTokenWithOverride({ + chatId: params.chatId, + expirationTime: options.tokenTTL, + baseURLOption, + fetchOverride, + }) + : await auth.createPublicToken({ + scopes: { + read: { sessions: params.chatId }, + write: { sessions: params.chatId }, + }, + expirationTime: options.tokenTTL, + }) : created.publicAccessToken; return { @@ -8573,6 +8635,101 @@ function createChatStartSessionAction( }; } +function resolveChatStartBaseURL( + endpoint: ChatStartSessionEndpoint, + chatId: string, + option: string | ChatStartSessionBaseURLResolver | undefined +): string { + const fallback = apiClientManager.baseURL ?? "https://api.trigger.dev"; + const raw = + typeof option === "function" + ? option({ endpoint, chatId }) + : option ?? fallback; + return raw.replace(/\/$/, ""); +} + +function overrideRequestHeaders(accessToken: string): Record { + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + "x-trigger-source": "sdk", + }; + // Forward the preview-branch hint so override-mode requests land on the + // same env the standard ApiClient path would have routed to. Mirrors + // ApiClient.#getHeaders. Read from TRIGGER_PREVIEW_BRANCH / + // VERCEL_GIT_COMMIT_REF via apiClientManager.branchName. + if (apiClientManager.branchName) { + headers["x-trigger-branch"] = apiClientManager.branchName; + } + return headers; +} + +async function callSessionsCreateWithOverride(args: { + chatId: string; + body: { type: "chat.agent"; externalId: string; taskIdentifier: string; triggerConfig: SessionTriggerConfig; metadata?: Record }; + baseURLOption: string | ChatStartSessionBaseURLResolver | undefined; + fetchOverride: ChatStartSessionFetchOverride | undefined; +}): Promise<{ id: string; runId: string; publicAccessToken: string }> { + const accessToken = apiClientManager.accessToken; + if (!accessToken) { + throw new Error( + "chat.createStartSessionAction: no API access token configured. Set TRIGGER_SECRET_KEY or call apiClientManager.setGlobalAPIClientConfiguration before invoking the action." + ); + } + const ctx: ChatStartSessionEndpointContext = { endpoint: "sessions", chatId: args.chatId }; + const url = `${resolveChatStartBaseURL("sessions", args.chatId, args.baseURLOption)}/api/v1/sessions`; + const init: RequestInit = { + method: "POST", + headers: overrideRequestHeaders(accessToken), + body: JSON.stringify(args.body), + }; + const response = args.fetchOverride + ? await args.fetchOverride(url, init, ctx) + : await fetch(url, init); + if (!response.ok) { + const text = await response.text().catch(() => ""); + throw new Error(`sessions.start failed: ${response.status} ${text}`); + } + const json = (await response.json()) as { id: string; runId: string; publicAccessToken: string }; + return json; +} + +async function mintPublicTokenWithOverride(args: { + chatId: string; + expirationTime: string | number | Date; + baseURLOption: string | ChatStartSessionBaseURLResolver | undefined; + fetchOverride: ChatStartSessionFetchOverride | undefined; +}): Promise { + const accessToken = apiClientManager.accessToken; + if (!accessToken) { + throw new Error( + "chat.createStartSessionAction: no API access token configured for JWT mint." + ); + } + const ctx: ChatStartSessionEndpointContext = { endpoint: "auth", chatId: args.chatId }; + const url = `${resolveChatStartBaseURL("auth", args.chatId, args.baseURLOption)}/api/v1/auth/jwt/claims`; + const init: RequestInit = { + method: "POST", + headers: overrideRequestHeaders(accessToken), + }; + const response = args.fetchOverride + ? await args.fetchOverride(url, init, ctx) + : await fetch(url, init); + if (!response.ok) { + const text = await response.text().catch(() => ""); + throw new Error(`auth.createPublicToken failed: ${response.status} ${text}`); + } + const claims = (await response.json()) as Record; + return generateJWT({ + secretKey: accessToken, + payload: { + ...claims, + scopes: [`read:sessions:${args.chatId}`, `write:sessions:${args.chatId}`], + }, + expirationTime: args.expirationTime, + }); +} + export const chat = { /** Create a chat agent. See {@link chatAgent}. */ agent: chatAgent, diff --git a/packages/trigger-sdk/src/v3/chat-client.ts b/packages/trigger-sdk/src/v3/chat-client.ts index 40132a624e..98380f1e8b 100644 --- a/packages/trigger-sdk/src/v3/chat-client.ts +++ b/packages/trigger-sdk/src/v3/chat-client.ts @@ -20,7 +20,6 @@ import type { SessionTriggerConfig, Task } from "@trigger.dev/core/v3"; import type { ModelMessage, UIMessage, UIMessageChunk } from "ai"; import { readUIMessageStream } from "ai"; import { - ApiClient, apiClientManager, controlSubtype, SSEStreamSubscription, @@ -53,6 +52,26 @@ export type ChatSession = { lastEventId?: string; }; +/** + * Discriminator passed to per-endpoint `baseURL` and `fetch` callbacks on + * `AgentChat`. Same shape as the type on `TriggerChatTransport` — these + * mirror so customers can share a single resolver between the two clients. + */ +export type AgentChatEndpoint = "in" | "out"; + +export type AgentChatEndpointContext = { + endpoint: AgentChatEndpoint; + chatId: string; +}; + +export type AgentChatBaseURLResolver = (ctx: AgentChatEndpointContext) => string; + +export type AgentChatFetchOverride = ( + url: string, + init: RequestInit, + ctx: AgentChatEndpointContext +) => Promise; + export type AgentChatOptions = { /** The agent task ID to trigger. */ agent: string; @@ -89,6 +108,26 @@ export type AgentChatOptions = { * chat. Folded into `sessions.start({...triggerConfig})` body. */ triggerConfig?: SessionTriggerConfig; + /** + * Override the Trigger.dev API base URL for the chat's `.in/append` and + * `.out` SSE endpoints. String form applies to both; pass a function to + * pick per endpoint. Defaults to `apiClientManager.baseURL` (whatever + * `@trigger.dev/sdk` was configured with — typically `TRIGGER_API_URL` + * env var). + * + * Session creation (`POST /api/v1/sessions`) and token mint + * (`POST /api/v1/auth/jwt/claims`) still flow through + * `apiClientManager` — pass equivalent options to + * `chat.createStartSessionAction` if you need those routed too. + */ + baseURL?: string | AgentChatBaseURLResolver; + /** + * Optional per-request fetch override. Receives the resolved URL, the + * RequestInit, and endpoint context. Use this for header injection + * (tracing), proxy routing, or custom retries. Applies to both the + * `.in/append` POSTs and the `.out` SSE GET. + */ + fetch?: AgentChatFetchOverride; }; // ─── ChatStream ──────────────────────────────────────────────────── @@ -272,6 +311,8 @@ export class AgentChat { private readonly triggerConfigDefault: SessionTriggerConfig | undefined; private readonly onTriggered: AgentChatOptions["onTriggered"]; private readonly onTurnComplete: AgentChatOptions["onTurnComplete"]; + private readonly baseURLResolver: AgentChatBaseURLResolver; + private readonly fetchOverride: AgentChatFetchOverride | undefined; private state: SessionState; @@ -283,6 +324,11 @@ export class AgentChat { this.triggerConfigDefault = options.triggerConfig; this.onTriggered = options.onTriggered; this.onTurnComplete = options.onTurnComplete; + const baseURLOption = options.baseURL; + this.baseURLResolver = typeof baseURLOption === "function" + ? baseURLOption + : () => baseURLOption ?? apiClientManager.baseURL ?? "https://api.trigger.dev"; + this.fetchOverride = options.fetch; // Hydration: a non-empty `session` means the caller knows the // session already exists (started in a previous request). Mark @@ -378,12 +424,7 @@ export class AgentChat { metadata: this.clientData, } as ChatTaskWirePayload; - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ kind: "message", payload }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "message", payload })); return this.subscribeToSessionStream(options?.abortSignal); } @@ -404,15 +445,7 @@ export class AgentChat { }; try { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ - kind: "message", - payload, - }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "message", payload })); return true; } catch { return false; @@ -424,14 +457,7 @@ export class AgentChat { if (!this.state.started) return; this.state.skipToTurnComplete = true; - const api = this.createApiClient(); - await api - .appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ kind: "stop" }) - ) - .catch(() => {}); + await this.appendInputChunk(serializeInputChunk({ kind: "stop" })).catch(() => {}); } /** @@ -459,10 +485,7 @@ export class AgentChat { */ isFinal: boolean; }): Promise { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", + await this.appendInputChunk( serializeInputChunk({ kind: "handover", partialAssistantMessage: args.partialAssistantMessage, @@ -481,12 +504,7 @@ export class AgentChat { * surface. */ async sendHandoverSkip(): Promise { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ kind: "handover-skip" }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "handover-skip" })); } /** @@ -531,15 +549,7 @@ export class AgentChat { }; try { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", - serializeInputChunk({ - kind: "message", - payload, - }) - ); + await this.appendInputChunk(serializeInputChunk({ kind: "message", payload })); } catch { throw new Error("Failed to send action. The session may have ended."); } @@ -553,10 +563,7 @@ export class AgentChat { if (!this.state.started) return false; try { - const api = this.createApiClient(); - await api.appendToSessionStream( - this.chatId, - "in", + await this.appendInputChunk( serializeInputChunk({ kind: "message", payload: { @@ -582,10 +589,41 @@ export class AgentChat { // ─── Private ─────────────────────────────────────────────────── - private createApiClient(): ApiClient { - const baseURL = apiClientManager.baseURL ?? "https://api.trigger.dev"; + private resolveBaseURL(endpoint: AgentChatEndpoint): string { + return this.baseURLResolver({ endpoint, chatId: this.chatId }).replace(/\/$/, ""); + } + + private async doFetch( + ctx: AgentChatEndpointContext, + url: string, + init: RequestInit + ): Promise { + return this.fetchOverride ? this.fetchOverride(url, init, ctx) : fetch(url, init); + } + + private async appendInputChunk(body: string): Promise { const accessToken = apiClientManager.accessToken ?? ""; - return new ApiClient(baseURL, accessToken); + const ctx: AgentChatEndpointContext = { endpoint: "in", chatId: this.chatId }; + const url = `${this.resolveBaseURL("in")}/realtime/v1/sessions/${encodeURIComponent(this.chatId)}/in/append`; + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + "x-trigger-source": "sdk", + }; + const response = await this.doFetch(ctx, url, { method: "POST", headers, body }); + if (!response.ok) { + const text = await response.text().catch(() => ""); + // Match the error shape that ApiClient/zodfetch produced before the + // inline-POST refactor so callers inspecting `error.name === + // "TriggerApiError"` or `error.status` keep working. + const err = new Error(`appendToSessionStream failed: ${response.status} ${text}`) as Error & { + name: string; + status: number; + }; + err.name = "TriggerApiError"; + err.status = response.status; + throw err; + } } /** @@ -650,10 +688,33 @@ export class AgentChat { options?: { sendStopOnAbort?: boolean } ): ReadableStream { const state = this.state; - const baseURL = apiClientManager.baseURL ?? "https://api.trigger.dev"; const accessToken = apiClientManager.accessToken ?? ""; const onTurnComplete = this.onTurnComplete; const chatId = this.chatId; + const sseCtx: AgentChatEndpointContext = { endpoint: "out", chatId }; + const fetchOverride = this.fetchOverride; + const sseFetchClient: typeof fetch | undefined = fetchOverride + ? ((input, init) => { + if (typeof input === "string") { + return fetchOverride(input, init ?? {}, sseCtx); + } + if (input instanceof URL) { + return fetchOverride(input.toString(), init ?? {}, sseCtx); + } + // Request — preserve its url + intrinsic init, let any provided + // init override on top (matches fetch(Request, init) semantics). + return fetchOverride( + input.url, + { + method: input.method, + headers: input.headers, + signal: input.signal, + ...(init ?? {}), + }, + sseCtx + ); + }) as typeof fetch + : undefined; const internalAbort = new AbortController(); const combinedSignal = abortSignal @@ -666,14 +727,7 @@ export class AgentChat { () => { if (options?.sendStopOnAbort !== false) { state.skipToTurnComplete = true; - const api = new ApiClient(baseURL, accessToken); - api - .appendToSessionStream( - chatId, - "in", - serializeInputChunk({ kind: "stop" }) - ) - .catch(() => {}); + this.appendInputChunk(serializeInputChunk({ kind: "stop" })).catch(() => {}); } internalAbort.abort(); }, @@ -681,7 +735,7 @@ export class AgentChat { ); } - const streamUrl = `${baseURL}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; + const streamUrl = `${this.resolveBaseURL("out")}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; return new ReadableStream({ start: async (controller) => { @@ -693,6 +747,7 @@ export class AgentChat { signal: combinedSignal, timeoutInSeconds: this.streamTimeoutSeconds, lastEventId: state.lastEventId, + fetchClient: sseFetchClient, }); const sseStream = await subscription.subscribe(); const reader = sseStream.getReader(); diff --git a/packages/trigger-sdk/src/v3/chat.test.ts b/packages/trigger-sdk/src/v3/chat.test.ts index 5f50854ec4..6469f1ac86 100644 --- a/packages/trigger-sdk/src/v3/chat.test.ts +++ b/packages/trigger-sdk/src/v3/chat.test.ts @@ -609,6 +609,94 @@ describe("TriggerChatTransport", () => { expect(subscribe!).toContain("/realtime/v1/sessions/chat-by-chatid/out"); }); + it("functional baseURL dispatches per endpoint (in vs out)", async () => { + const requests: Array<{ url: string; ctxEndpoint: string | undefined }> = []; + global.fetch = vi.fn().mockImplementation(async (url: string | URL) => { + const urlStr = typeof url === "string" ? url : url.toString(); + requests.push({ url: urlStr, ctxEndpoint: undefined }); + if (isSessionStreamAppendUrl(urlStr)) return defaultAppendResponse(); + if (isSessionOutSubscribeUrl(urlStr)) return defaultSseResponse(); + throw new Error(`Unexpected URL: ${urlStr}`); + }); + + const baseURLFn = vi.fn(({ endpoint }: { endpoint: "in" | "out"; chatId: string }) => + endpoint === "out" + ? "https://stream.example.com" + : "https://api.example.com" + ); + + const transport = new TriggerChatTransport({ + task: "my-chat-task", + accessToken: () => "pat", + baseURL: baseURLFn, + sessions: { "chat-fn": { publicAccessToken: "p" } }, + }); + + const stream = await transport.sendMessages({ + trigger: "submit-message", + chatId: "chat-fn", + messageId: undefined, + messages: [createUserMessage("Hi")], + abortSignal: undefined, + }); + await drainChunks(stream); + + const appendCalls = baseURLFn.mock.calls.filter((c) => c[0].endpoint === "in"); + const outCalls = baseURLFn.mock.calls.filter((c) => c[0].endpoint === "out"); + expect(appendCalls.length).toBeGreaterThanOrEqual(1); + expect(outCalls.length).toBeGreaterThanOrEqual(1); + expect(appendCalls[0]![0].chatId).toBe("chat-fn"); + expect(outCalls[0]![0].chatId).toBe("chat-fn"); + + const append = requests.find((r) => isSessionStreamAppendUrl(r.url)); + const subscribe = requests.find((r) => isSessionOutSubscribeUrl(r.url)); + expect(append!.url.startsWith("https://api.example.com/")).toBe(true); + expect(subscribe!.url.startsWith("https://stream.example.com/")).toBe(true); + }); + + it("fetch override is invoked for both .in/append and .out SSE with endpoint ctx", async () => { + const fetchCalls: Array<{ url: string; endpoint: string; chatId: string }> = []; + + const customFetch = vi.fn( + async ( + url: string, + init: RequestInit, + ctx: { endpoint: "in" | "out"; chatId: string } + ) => { + fetchCalls.push({ url, endpoint: ctx.endpoint, chatId: ctx.chatId }); + if (isSessionStreamAppendUrl(url)) return defaultAppendResponse(); + if (isSessionOutSubscribeUrl(url)) return defaultSseResponse(); + throw new Error(`Unexpected URL: ${url}`); + } + ); + + global.fetch = vi.fn().mockRejectedValue(new Error("global fetch should not be called")); + + const transport = new TriggerChatTransport({ + task: "my-chat-task", + accessToken: () => "pat", + baseURL: "https://api.test.trigger.dev", + fetch: customFetch, + sessions: { "chat-fetch": { publicAccessToken: "p" } }, + }); + + const stream = await transport.sendMessages({ + trigger: "submit-message", + chatId: "chat-fetch", + messageId: undefined, + messages: [createUserMessage("Hi")], + abortSignal: undefined, + }); + await drainChunks(stream); + + const inCalls = fetchCalls.filter((c) => c.endpoint === "in"); + const outCalls = fetchCalls.filter((c) => c.endpoint === "out"); + expect(inCalls.length).toBeGreaterThanOrEqual(1); + expect(outCalls.length).toBeGreaterThanOrEqual(1); + expect(inCalls[0]!.chatId).toBe("chat-fetch"); + expect(outCalls[0]!.chatId).toBe("chat-fetch"); + }); + it("routes .out SSE through streamBaseURL while appends stay on baseURL", async () => { const requests: string[] = []; global.fetch = vi.fn().mockImplementation(async (url: string | URL) => { diff --git a/packages/trigger-sdk/src/v3/chat.ts b/packages/trigger-sdk/src/v3/chat.ts index a979b8f2b1..2aefc2bb80 100644 --- a/packages/trigger-sdk/src/v3/chat.ts +++ b/packages/trigger-sdk/src/v3/chat.ts @@ -25,7 +25,6 @@ import type { ChatTransport, UIMessage, UIMessageChunk, ChatRequestOptions } from "ai"; import { - ApiClient, controlSubtype, headerValue, PUBLIC_ACCESS_TOKEN_HEADER, @@ -38,6 +37,43 @@ import type { ChatInputChunk, ChatTaskWirePayload } from "./ai-shared.js"; const DEFAULT_BASE_URL = "https://api.trigger.dev"; const DEFAULT_STREAM_TIMEOUT_SECONDS = 120; +/** + * Discriminator passed to per-endpoint `baseURL` and `fetch` callbacks. + * + * - `"in"` — `POST /realtime/v1/sessions/{chatId}/in/append` (user messages, + * stops, actions). + * - `"out"` — `GET /realtime/v1/sessions/{chatId}/out` (SSE response stream). + * + * Other endpoints (`/api/v1/sessions`, `/api/v1/auth/jwt/claims`) are reached + * from the server-side `chat.createStartSessionAction` and `accessToken` + * callback, not the transport — they accept the same callback shape on their + * own option objects. + */ +export type ChatTransportEndpoint = "in" | "out"; + +/** Context passed to `baseURL` and `fetch` callbacks. */ +export type ChatTransportEndpointContext = { + endpoint: ChatTransportEndpoint; + chatId: string; +}; + +/** Resolver form of `baseURL` — return the base for the given endpoint. */ +export type ChatBaseURLResolver = (ctx: ChatTransportEndpointContext) => string; + +/** + * Per-request fetch override. Receives the fully-resolved URL and the + * RequestInit the transport would have used, plus endpoint context for + * routing decisions. Customers can rewrite the URL, inject headers, or + * delegate to a custom transport (e.g. a Cloudflare worker fronting + * `api.trigger.dev`). Must return a `Response` semantically equivalent to + * what `globalThis.fetch(url, init)` would have returned. + */ +export type ChatFetchOverride = ( + url: string, + init: RequestInit, + ctx: ChatTransportEndpointContext +) => Promise; + /** * Detect 401/403 from realtime/input-stream calls without relying on `instanceof` * (Vitest can load duplicate `@trigger.dev/core` copies, which breaks subclass checks). @@ -229,18 +265,45 @@ export type TriggerChatTransportOptions = { > ) => Promise; - /** Base URL for the Trigger.dev API. @default "https://api.trigger.dev" */ - baseURL?: string; + /** + * Base URL for the Trigger.dev API. Either a single string applied to every + * endpoint, or a function called per request that picks a base URL from the + * endpoint discriminator and chat ID. @default "https://api.trigger.dev" + * + * @example Route appends through a proxy, SSE direct: + * ```ts + * baseURL: ({ endpoint }) => + * endpoint === "out" ? "https://api.trigger.dev" : "https://proxy.example.com", + * ``` + */ + baseURL?: string | ChatBaseURLResolver; /** * Base URL for the SSE stream subscription only (`GET .../sessions/{chatId}/out`). - * Falls back to `baseURL` when unset. Set this to route the long-lived - * stream through a custom proxy (e.g. a Cloudflare worker capturing JA4 - * fingerprints for bot detection) while keeping append POSTs direct to - * `baseURL` to avoid an extra hop on every user message. + * @deprecated Pass a function for `baseURL` instead and branch on + * `endpoint === "out"`. `streamBaseURL` continues to work for backwards + * compatibility and wins over `baseURL` for the SSE endpoint when both + * are set. */ streamBaseURL?: string; + /** + * Optional per-request fetch override. Called with the resolved URL and the + * RequestInit the transport built, plus endpoint context. Use this to + * inject custom headers (e.g. distributed tracing), redirect via a proxy, + * or wrap fetch with retries/logging. + * + * @example Add a tracing header to every chat request: + * ```ts + * fetch: (url, init, ctx) => { + * init.headers = new Headers(init.headers); + * init.headers.set("traceparent", currentTraceparent()); + * return globalThis.fetch(url, init); + * }, + * ``` + */ + fetch?: ChatFetchOverride; + /** Additional headers included in every API request. */ headers?: Record; @@ -361,8 +424,8 @@ export class TriggerChatTransport implements ChatTransport { private readonly resolveStartSession: | ((params: StartSessionParams>) => Promise) | undefined; - private readonly baseURL: string; - private readonly streamBaseURL: string; + private readonly resolveBaseURLFn: ChatBaseURLResolver; + private readonly fetchOverride: ChatFetchOverride | undefined; private readonly extraHeaders: Record; private readonly streamTimeoutSeconds: number; private defaultMetadata: Record | undefined; @@ -383,8 +446,12 @@ export class TriggerChatTransport implements ChatTransport { this.resolveStartSession = options.startSession as | ((params: StartSessionParams>) => Promise) | undefined; - this.baseURL = options.baseURL ?? DEFAULT_BASE_URL; - this.streamBaseURL = options.streamBaseURL ?? this.baseURL; + const baseURLOption = options.baseURL ?? DEFAULT_BASE_URL; + const streamOverride = options.streamBaseURL; + this.resolveBaseURLFn = typeof baseURLOption === "function" + ? (ctx) => (ctx.endpoint === "out" && streamOverride ? streamOverride : baseURLOption(ctx)) + : (ctx) => (ctx.endpoint === "out" && streamOverride ? streamOverride : baseURLOption); + this.fetchOverride = options.fetch; this.extraHeaders = options.headers ?? {}; this.streamTimeoutSeconds = options.streamTimeoutSeconds ?? DEFAULT_STREAM_TIMEOUT_SECONDS; this.defaultMetadata = options.clientData; @@ -528,10 +595,9 @@ export class TriggerChatTransport implements ChatTransport { const state = await this.ensureSessionState(chatId); const sendChatMessage = async (token: string) => { - const apiClient = new ApiClient(this.baseURL, token); - await apiClient.appendToSessionStream( + await this.appendInputChunk( chatId, - "in", + token, this.serializeInputChunk({ kind: "message", payload: wirePayload }) ); }; @@ -708,10 +774,9 @@ export class TriggerChatTransport implements ChatTransport { }; const send = async (token: string) => { - const apiClient = new ApiClient(this.baseURL, token); - await apiClient.appendToSessionStream( + await this.appendInputChunk( chatId, - "in", + token, this.serializeInputChunk({ kind: "message", payload: wirePayload }) ); }; @@ -768,12 +833,7 @@ export class TriggerChatTransport implements ChatTransport { if (!state) return false; const send = async (token: string) => { - const api = new ApiClient(this.baseURL, token); - await api.appendToSessionStream( - chatId, - "in", - this.serializeInputChunk({ kind: "stop" }) - ); + await this.appendInputChunk(chatId, token, this.serializeInputChunk({ kind: "stop" })); }; try { @@ -822,8 +882,7 @@ export class TriggerChatTransport implements ChatTransport { const body = this.serializeInputChunk({ kind: "message", payload: wirePayload }); const send = async (token: string) => { - const apiClient = new ApiClient(this.baseURL, token); - await apiClient.appendToSessionStream(chatId, "in", body); + await this.appendInputChunk(chatId, token, body); }; await this.callWithAuthRetry(chatId, state, send); @@ -978,6 +1037,41 @@ export class TriggerChatTransport implements ChatTransport { * Run `op` with the session's stored PAT. On 401/403, refresh the PAT * via `accessToken` and retry once. Surfaces non-auth errors as-is. */ + private resolveBaseURL(ctx: ChatTransportEndpointContext): string { + const raw = this.resolveBaseURLFn(ctx); + return raw.replace(/\/$/, ""); + } + + private async doFetch( + ctx: ChatTransportEndpointContext, + url: string, + init: RequestInit + ): Promise { + return this.fetchOverride ? this.fetchOverride(url, init, ctx) : fetch(url, init); + } + + private async appendInputChunk(chatId: string, token: string, body: string): Promise { + const ctx: ChatTransportEndpointContext = { endpoint: "in", chatId }; + const url = `${this.resolveBaseURL(ctx)}/realtime/v1/sessions/${encodeURIComponent(chatId)}/in/append`; + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${token}`, + "x-trigger-source": "sdk", + ...this.extraHeaders, + }; + const response = await this.doFetch(ctx, url, { method: "POST", headers, body }); + if (!response.ok) { + const text = await response.text().catch(() => ""); + const err = new Error(`appendToSessionStream failed: ${response.status} ${text}`) as Error & { + name: string; + status: number; + }; + err.name = "TriggerApiError"; + err.status = response.status; + throw err; + } + } + private async callWithAuthRetry( chatId: string, state: ChatSessionState, @@ -1026,14 +1120,11 @@ export class TriggerChatTransport implements ChatTransport { () => { if (options?.sendStopOnAbort !== false) { state.skipToTurnComplete = true; - const api = new ApiClient(this.baseURL, state.publicAccessToken); - api - .appendToSessionStream( - chatId, - "in", - this.serializeInputChunk({ kind: "stop" }) - ) - .catch(() => {}); + this.appendInputChunk( + chatId, + state.publicAccessToken, + this.serializeInputChunk({ kind: "stop" }) + ).catch(() => {}); } internalAbort.abort(); }, @@ -1041,7 +1132,7 @@ export class TriggerChatTransport implements ChatTransport { ); } - const streamUrl = `${this.streamBaseURL}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; + const streamUrl = `${this.resolveBaseURL({ endpoint: "out", chatId })}/realtime/v1/sessions/${encodeURIComponent(chatId)}/out`; return new ReadableStream({ start: async (controller) => { @@ -1099,6 +1190,31 @@ export class TriggerChatTransport implements ChatTransport { })() : () => {}; + const sseCtx: ChatTransportEndpointContext = { endpoint: "out", chatId }; + const fetchOverride = this.fetchOverride; + const sseFetchClient: typeof fetch | undefined = fetchOverride + ? ((input, init) => { + if (typeof input === "string") { + return fetchOverride(input, init ?? {}, sseCtx); + } + if (input instanceof URL) { + return fetchOverride(input.toString(), init ?? {}, sseCtx); + } + // Request — preserve its url + intrinsic init, let any + // provided init override on top (matches fetch(Request, init) + // semantics). + return fetchOverride( + input.url, + { + method: input.method, + headers: input.headers, + signal: input.signal, + ...(init ?? {}), + }, + sseCtx + ); + }) as typeof fetch + : undefined; const connectSseOnce = async (token: string) => { const subscription = new SSEStreamSubscription(streamUrl, { headers: { @@ -1113,6 +1229,7 @@ export class TriggerChatTransport implements ChatTransport { // keepalive) arrives in 60s, force reconnect. Sized // generously over typical agent thinking pauses. stallTimeoutMs: 60_000, + fetchClient: sseFetchClient, }); currentSubscription = subscription; const sseStream = await subscription.subscribe(); diff --git a/references/ai-chat/cf-worker/.gitignore b/references/ai-chat/cf-worker/.gitignore new file mode 100644 index 0000000000..8619bbe6b2 --- /dev/null +++ b/references/ai-chat/cf-worker/.gitignore @@ -0,0 +1,3 @@ +.wrangler/ +node_modules/ +*.log diff --git a/references/ai-chat/cf-worker/README.md b/references/ai-chat/cf-worker/README.md new file mode 100644 index 0000000000..8c9a733d73 --- /dev/null +++ b/references/ai-chat/cf-worker/README.md @@ -0,0 +1,33 @@ +# cf-trust-test worker + +A minimal Cloudflare Worker that demonstrates the trusted-edge-signals pattern from [`docs/ai-chat/patterns/trusted-edge-signals`](../../../docs/ai-chat/patterns/trusted-edge-signals.mdx). The worker sits in front of the Trigger.dev API, intercepts the two body-write paths (`POST /api/v1/sessions` and `POST /realtime/v1/sessions/{id}/in/append`), and injects a server-trusted `__cf` namespace into the wire payload's `metadata` field. Everything else (SSE, auth, dashboard) passes through untouched. + +Pairs with the `cfTrustTestAgent` (task id `cf-trust-test`) defined in `src/trigger/chat.ts`, which declares the `__cf` namespace in its `clientDataSchema` and echoes the values back so the round-trip is visible in the streamed response. + +## Run it + +```bash +# In references/ai-chat/cf-worker +pnpm install +pnpm run dev # serves on http://localhost:8787, proxies to TRIGGER_API_UPSTREAM +``` + +Point the Next.js reference app at the worker by setting `TRIGGER_API_URL` and `NEXT_PUBLIC_TRIGGER_API_URL` to `http://localhost:8787` in `references/ai-chat/.env`. Then start trigger-dev and Next.js as usual. + +`wrangler dev` populates `request.cf` with the developer's real Cloudflare edge metadata even in local mode; the worker falls back to hardcoded sample values if `request.cf` is unset. + +## Wire-up for `.out` SSE direct (optional) + +By default the reference app routes every request through `NEXT_PUBLIC_TRIGGER_API_URL`, so SSE also flows through the worker. To skip the worker on the long-lived `.out` channel — which gives no body-mutation benefit and adds one extra edge hop per reconnect — switch the transport's `baseURL` to the function form: + +```ts +const transport = useTriggerChatTransport({ + // ... + baseURL: ({ endpoint }) => + endpoint === "out" + ? "https://api.trigger.dev" + : process.env.NEXT_PUBLIC_TRIGGER_API_URL!, +}); +``` + +See [`docs/ai-chat/patterns/trusted-edge-signals`](../../../docs/ai-chat/patterns/trusted-edge-signals.mdx) for the full design — threat model, agent-side schema, deploy considerations. diff --git a/references/ai-chat/cf-worker/package.json b/references/ai-chat/cf-worker/package.json new file mode 100644 index 0000000000..3e1f8debe9 --- /dev/null +++ b/references/ai-chat/cf-worker/package.json @@ -0,0 +1,14 @@ +{ + "name": "cf-trust-test-worker", + "version": "0.0.0", + "private": true, + "type": "module", + "scripts": { + "dev": "wrangler dev", + "deploy": "wrangler deploy" + }, + "devDependencies": { + "@cloudflare/workers-types": "4.20240909.0", + "wrangler": "3.78.0" + } +} diff --git a/references/ai-chat/cf-worker/src/index.ts b/references/ai-chat/cf-worker/src/index.ts new file mode 100644 index 0000000000..1e449153d1 --- /dev/null +++ b/references/ai-chat/cf-worker/src/index.ts @@ -0,0 +1,124 @@ +/** + * cf-trust-test proxy. Validates that a trusted edge proxy can inject a + * namespaced metadata field (`__cf`) into trigger.dev's chat session-create + * and follow-up message wire payloads — and that the trigger.dev server passes + * it through to the agent untouched. + * + * Local dev: `wrangler dev` exposes the worker on http://localhost:8787 and + * forwards to TRIGGER_API_UPSTREAM. With `wrangler dev --remote` the worker + * runs on the CF edge and `request.cf` is populated with real signals; the + * --local default leaves request.cf undefined, so we fall back to hardcoded + * trust values that prove the plumbing without depending on a real CF edge. + */ + +export interface Env { + TRIGGER_API_UPSTREAM: string; +} + +type CfTrustData = { + botScore: number; + ja4: string; + asn: number; + country: string; +}; + +function readCfTrustData(request: Request): CfTrustData { + const cf = (request as Request & { cf?: Record }).cf; + const bm = (cf?.botManagement as Record | undefined) ?? undefined; + return { + botScore: (bm?.score as number | undefined) ?? 95, + ja4: (bm?.ja4 as string | undefined) ?? "t13d1715h2_5b57614c22b0_5c2c4ed3e2d9", + asn: (cf?.asn as number | undefined) ?? 13335, + country: (cf?.country as string | undefined) ?? "US", + }; +} + +function withCors(response: Response, request: Request): Response { + const headers = new Headers(response.headers); + const origin = request.headers.get("origin") ?? "*"; + const reqHeaders = request.headers.get("access-control-request-headers"); + headers.set("Access-Control-Allow-Origin", origin); + headers.set("Vary", "Origin"); + headers.set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, PATCH, DELETE"); + if (reqHeaders) headers.set("Access-Control-Allow-Headers", reqHeaders); + headers.set("Access-Control-Expose-Headers", "*"); + headers.set("Access-Control-Allow-Credentials", "true"); + return new Response(response.body, { status: response.status, statusText: response.statusText, headers }); +} + +function handlePreflight(request: Request): Response { + return withCors(new Response(null, { status: 204 }), request); +} + +function setCfNamespace( + metadata: Record | undefined, + cf: CfTrustData +): Record { + const stripped: Record = { ...(metadata ?? {}) }; + delete stripped.__cf; + return { ...stripped, __cf: cf }; +} + +async function rewriteSessionCreateBody(body: string, cf: CfTrustData): Promise { + const parsed = JSON.parse(body) as Record; + const triggerConfig = (parsed.triggerConfig as Record | undefined) ?? {}; + const basePayload = (triggerConfig.basePayload as Record | undefined) ?? {}; + const metadata = basePayload.metadata as Record | undefined; + parsed.triggerConfig = { + ...triggerConfig, + basePayload: { ...basePayload, metadata: setCfNamespace(metadata, cf) }, + }; + return JSON.stringify(parsed); +} + +async function rewriteAppendBody(body: string, cf: CfTrustData): Promise { + let parsed: Record; + try { + parsed = JSON.parse(body) as Record; + } catch { + return body; + } + if (parsed.kind !== "message") return body; + const payload = (parsed.payload as Record | undefined) ?? {}; + const metadata = payload.metadata as Record | undefined; + parsed.payload = { ...payload, metadata: setCfNamespace(metadata, cf) }; + return JSON.stringify(parsed); +} + +export default { + async fetch(request: Request, env: Env): Promise { + if (request.method === "OPTIONS") return handlePreflight(request); + + const upstream = new URL(env.TRIGGER_API_UPSTREAM); + const incoming = new URL(request.url); + const target = new URL(incoming.pathname + incoming.search, upstream); + + const cf = readCfTrustData(request); + const isAppend = + request.method === "POST" && + /^\/realtime\/v1\/sessions\/[^/]+\/in\/append$/.test(incoming.pathname); + const isSessionsCreate = + request.method === "POST" && incoming.pathname === "/api/v1/sessions"; + + let body: BodyInit | null = null; + if (request.method !== "GET" && request.method !== "HEAD") { + const raw = await request.text(); + if (isSessionsCreate && raw) body = await rewriteSessionCreateBody(raw, cf); + else if (isAppend && raw) body = await rewriteAppendBody(raw, cf); + else body = raw; + } + + const headers = new Headers(request.headers); + headers.delete("host"); + headers.delete("content-length"); + + const upstreamResponse = await fetch(target.toString(), { + method: request.method, + headers, + body, + redirect: "manual", + }); + + return withCors(upstreamResponse, request); + }, +}; diff --git a/references/ai-chat/cf-worker/tsconfig.json b/references/ai-chat/cf-worker/tsconfig.json new file mode 100644 index 0000000000..7d45444bae --- /dev/null +++ b/references/ai-chat/cf-worker/tsconfig.json @@ -0,0 +1,14 @@ +{ + "compilerOptions": { + "target": "es2022", + "module": "es2022", + "moduleResolution": "bundler", + "lib": ["es2022"], + "types": ["@cloudflare/workers-types"], + "strict": true, + "noEmit": true, + "esModuleInterop": true, + "skipLibCheck": true + }, + "include": ["src/**/*.ts"] +} diff --git a/references/ai-chat/cf-worker/wrangler.toml b/references/ai-chat/cf-worker/wrangler.toml new file mode 100644 index 0000000000..e62a10cc1b --- /dev/null +++ b/references/ai-chat/cf-worker/wrangler.toml @@ -0,0 +1,10 @@ +name = "cf-trust-test-worker" +main = "src/index.ts" +compatibility_date = "2024-09-23" +compatibility_flags = ["nodejs_compat"] + +[vars] +TRIGGER_API_UPSTREAM = "https://api.trigger.dev" + +[dev] +port = 8787 diff --git a/references/ai-chat/src/app/actions.ts b/references/ai-chat/src/app/actions.ts index 0ef650cfc8..4586aa3eeb 100644 --- a/references/ai-chat/src/app/actions.ts +++ b/references/ai-chat/src/app/actions.ts @@ -8,6 +8,7 @@ import type { aiChatRaw, aiChatSession, upgradeTestAgent, + cfTrustTestAgent, } from "@/trigger/chat"; import type { ChatUiMessage } from "@/lib/chat-tools-schemas"; import { prisma } from "@/lib/prisma"; @@ -20,7 +21,8 @@ export type ChatReferenceTaskId = | "ai-chat-hydrated" | "ai-chat-raw" | "ai-chat-session" - | "upgrade-test"; + | "upgrade-test" + | "cf-trust-test"; function isChatReferenceTaskId(id: string): id is ChatReferenceTaskId { return ( @@ -28,7 +30,8 @@ function isChatReferenceTaskId(id: string): id is ChatReferenceTaskId { id === "ai-chat-hydrated" || id === "ai-chat-raw" || id === "ai-chat-session" || - id === "upgrade-test" + id === "upgrade-test" || + id === "cf-trust-test" ); } @@ -38,7 +41,8 @@ type TaskIdentifierForChat = | (typeof aiChatHydrated)["id"] | (typeof aiChatRaw)["id"] | (typeof aiChatSession)["id"] - | (typeof upgradeTestAgent)["id"]; + | (typeof upgradeTestAgent)["id"] + | (typeof cfTrustTestAgent)["id"]; /** * Server-mediated start: creates the Session row + triggers the first @@ -70,6 +74,7 @@ const startActionByTaskId: Record< "ai-chat-raw": startChatSessionFor("ai-chat-raw"), "ai-chat-session": startChatSessionFor("ai-chat-session"), "upgrade-test": startChatSessionFor("upgrade-test"), + "cf-trust-test": startChatSessionFor("cf-trust-test"), }; export async function startChatSession(input: { diff --git a/references/ai-chat/src/components/chat-sidebar.tsx b/references/ai-chat/src/components/chat-sidebar.tsx index e036eebc71..9707b61ac3 100644 --- a/references/ai-chat/src/components/chat-sidebar.tsx +++ b/references/ai-chat/src/components/chat-sidebar.tsx @@ -118,6 +118,7 @@ export function ChatSidebar({ +