Skip to content

Commit

Permalink
Add experimental_StreamingReactResponse (#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
shuding committed Oct 26, 2023
1 parent 6aff786 commit 699552d
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .changeset/mighty-carrots-grab.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

add experimental_StreamingReactResponse
83 changes: 69 additions & 14 deletions packages/core/react/use-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import type {
ChatRequestOptions,
FunctionCall,
} from '../shared/types';

import type {
ReactResponseRow,
experimental_StreamingReactResponse,
} from '../streams/streaming-react-response';
export type { Message, CreateMessage, UseChatOptions };

export type UseChatHelpers = {
Expand Down Expand Up @@ -67,8 +72,12 @@ export type UseChatHelpers = {
data?: any;
};

type StreamingReactResponseAction = (payload: {
messages: Message[];
}) => Promise<experimental_StreamingReactResponse>;

const getStreamedResponse = async (
api: string,
api: string | StreamingReactResponseAction,
chatRequest: ChatRequest,
mutate: KeyedMutator<Message[]>,
mutateStreamData: KeyedMutator<any[]>,
Expand All @@ -85,21 +94,65 @@ const getStreamedResponse = async (
const previousMessages = messagesRef.current;
mutate(chatRequest.messages, false);

const constructedMessagesPayload = sendExtraMessageFields
? chatRequest.messages
: chatRequest.messages.map(({ role, content, name, function_call }) => ({
role,
content,
...(name !== undefined && { name }),
...(function_call !== undefined && {
function_call: function_call,
}),
}));

if (typeof api !== 'string') {
// In this case, we are handling a Server Action. No complex mode handling needed.

const replyId = nanoid();
const createdAt = new Date();
let responseMessage: Message = {
id: replyId,
createdAt,
content: '',
role: 'assistant',
};

async function readRow(promise: Promise<ReactResponseRow>) {
const { content, ui, next } = await promise;

// TODO: Handle function calls.
responseMessage['content'] = content;
responseMessage['ui'] = await ui;

mutate([...chatRequest.messages, { ...responseMessage }], false);

if (next) {
await readRow(next);
}
}

try {
const promise = api({
messages: constructedMessagesPayload as Message[],
}) as Promise<ReactResponseRow>;
await readRow(promise);
} catch (e) {
// Restore the previous messages if the request fails.
mutate(previousMessages, false);
throw e;
}

if (onFinish) {
onFinish(responseMessage);
}

return responseMessage;
}

const res = await fetch(api, {
method: 'POST',
body: JSON.stringify({
messages: sendExtraMessageFields
? chatRequest.messages
: chatRequest.messages.map(
({ role, content, name, function_call }) => ({
role,
content,
...(name !== undefined && { name }),
...(function_call !== undefined && {
function_call: function_call,
}),
}),
),
messages: constructedMessagesPayload,
...extraMetadataRef.current.body,
...chatRequest.options?.body,
...(chatRequest.functions !== undefined && {
Expand Down Expand Up @@ -346,7 +399,9 @@ export function useChat({
credentials,
headers,
body,
}: UseChatOptions = {}): UseChatHelpers {
}: Omit<UseChatOptions, 'api'> & {
api?: string | StreamingReactResponseAction;
} = {}): UseChatHelpers {
// Generate a unique id for the chat if not provided.
const hookId = useId();
const chatId = id || hookId;
Expand Down
1 change: 1 addition & 0 deletions packages/core/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export interface Message {
id: string;
createdAt?: Date;
content: string;
ui?: string | JSX.Element | JSX.Element[] | null | undefined;
role: 'system' | 'user' | 'assistant' | 'function';
/**
* If the message has a role of `function`, the `name` field is the name of the function.
Expand Down
1 change: 1 addition & 0 deletions packages/core/streams/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ export * from './replicate-stream';
export * from '../shared/types';
export * from '../shared/utils';
export * from './stream-data';
export * from './streaming-react-response';
80 changes: 80 additions & 0 deletions packages/core/streams/streaming-react-response.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* This is a naive implementation of the streaming React response API.
* Currently, it can carry the original raw content, data payload and a special
* UI payload and stream them via "rows" (nested promises).
* It must be used inside Server Actions so Flight can encode the React elements.
*
* It is naive as unlike the StreamingTextResponse, it does not send the diff
* between the rows, but flushing the full payload on each row.
*/

import { createChunkDecoder } from '../shared/utils';

type UINode = string | JSX.Element | JSX.Element[] | null | undefined;

type Payload = {
ui: UINode | Promise<UINode>;
content: string;
};

export type ReactResponseRow = Payload & {
next: null | Promise<ReactResponseRow>;
};

/**
* A utility class for streaming React responses.
*/
export class experimental_StreamingReactResponse {
constructor(
res: ReadableStream,
options?: {
ui?: (message: { content: string }) => UINode | Promise<UINode>;
},
) {
let resolveFunc: (row: ReactResponseRow) => void = () => {};
let next = new Promise<ReactResponseRow>(resolve => {
resolveFunc = resolve;
});

let content = '';

const decode = createChunkDecoder();
const reader = res.getReader();
async function readChunk() {
const { done, value } = await reader.read();
if (!done) {
content += decode(value);
}

// TODO: Handle generators. With this current implementation we can support
// synchronous and asynchronous UIs.
// TODO: Handle function calls.
const ui = options?.ui?.({ content }) || content;

const payload: Payload = {
ui,
content,
};

const resolvePrevious = resolveFunc;
const nextRow = done
? null
: new Promise<ReactResponseRow>(resolve => {
resolveFunc = resolve;
});
resolvePrevious({
next: nextRow,
...payload,
});

if (done) {
return;
}

await readChunk();
}
readChunk();

return next;
}
}

0 comments on commit 699552d

Please sign in to comment.