
Commit 867a3f9

react/use-chat: fix client side function calling with stream data protocol (#473)
MaxLeiter committed Aug 18, 2023
1 parent b84af19 commit 867a3f9
Showing 3 changed files with 122 additions and 34 deletions.
7 changes: 7 additions & 0 deletions .changeset/new-eggs-turn.md
@@ -0,0 +1,7 @@
---
'ai': patch
---

Fix client-side function calling (#467, #469)

add Completion type from the `openai` SDK to openai-stream (#472)
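
For context, client-side function calling in this version of the SDK is driven by the `experimental_onFunctionCall` option of `useChat`. A minimal sketch of the wiring this changeset fixes, assuming a hypothetical `get_current_weather` function (handler body and message contents are illustrative, not part of this commit):

```ts
import { useChat } from 'ai/react'
import type { Message } from 'ai'
import { nanoid } from 'nanoid'

export function Chat() {
  const { messages, input, handleInputChange, handleSubmit } = useChat({
    api: '/api/chat',
    // Invoked when the assistant streams a function_call instead of text.
    experimental_onFunctionCall: async (chatMessages, functionCall) => {
      if (functionCall.name === 'get_current_weather') {
        // Arguments arrive from the model as a JSON string.
        const args = JSON.parse(functionCall.arguments ?? '{}')
        const functionMessage: Message = {
          id: nanoid(),
          name: functionCall.name,
          role: 'function',
          content: JSON.stringify({ city: args.city, temperature: 21 })
        }
        // Returning a ChatRequest re-sends the conversation, including
        // the function result, to the API route.
        return { messages: [...chatMessages, functionMessage] }
      }
    }
  })

  // Render messages and the input form as usual.
  return null
}
```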
58 changes: 32 additions & 26 deletions packages/core/react/use-chat.ts
@@ -147,15 +147,6 @@ const getStreamedResponse = async (
const decode = createChunkDecoder(isComplexMode)
let responseMessages: Message[] = []

// TODO-STREAMDATA: Remove this once Strem Data is not experimental
let streamedResponse = ''
const replyId = nanoid()
let responseMessage: Message = {
id: replyId,
createdAt,
content: '',
role: 'assistant'
}
// END TODO-STREAMDATA
let responseData: any = []
type PrefixMap = {
@@ -198,8 +189,28 @@ const getStreamedResponse = async (
}
}

let functionCallMessage: Message | null = null

if (type === 'function_call') {
prefixMap['function_call'] = value

let functionCall = prefixMap['function_call']
// Ensure it hasn't been parsed
if (functionCall && typeof functionCall === 'string') {
const parsedFunctionCall: CreateChatCompletionRequestMessage.FunctionCall =
JSON.parse(functionCall as string).function_call

functionCallMessage = {
id: nanoid(),
role: 'assistant',
content: '',
function_call: parsedFunctionCall,
name: parsedFunctionCall.name,
createdAt
}

prefixMap['function_call'] = functionCallMessage as any
}
}

if (type === 'data') {
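
Concretely, the `function_call` value handled in the branch above arrives over the stream data protocol as a JSON string wrapping the OpenAI function call — the exact shape that `chunkToText` in openai-stream.ts serializes (shown further down). A small sketch of the parse step, with a hypothetical payload:

```ts
// Hypothetical wire value for a `function_call` stream part:
const wireValue =
  '{"function_call": {"name": "get_current_weather", "arguments": "{\\"city\\": \\"Berlin\\"}"}}'

// The hook unwraps the inner object, as in the branch above:
const parsedFunctionCall = JSON.parse(wireValue).function_call
// => { name: 'get_current_weather', arguments: '{"city": "Berlin"}' }
```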
@@ -213,28 +224,14 @@

const data = prefixMap['data']
const responseMessage = prefixMap['text']
let functionCall = prefixMap['function_call']
let functionCallMessage: Message | null = null
if (functionCall) {
const parsedFunctionCall: CreateChatCompletionRequestMessage.FunctionCall =
JSON.parse(functionCall as string).function_call

functionCallMessage = {
id: nanoid(),
role: 'function',
content: '',
name: parsedFunctionCall.name,
createdAt
}
}

// We add function calls and response messages to the messages[], but data is its own thing
const merged = [functionCallMessage, responseMessage].filter(
Boolean
) as Message[]

mutate([...chatRequest.messages, ...merged], false)
mutateStreamData([...(existingData || []), ...(data || [])])
mutateStreamData([...(existingData || []), ...(data || [])], false)

// The request has been aborted, stop reading the stream.
if (abortControllerRef.current === null) {
@@ -248,7 +245,6 @@
if (onFinish && type === 'text') {
onFinish(item as Message)
}

if (type === 'data') {
responseData.push(item)
} else {
@@ -257,6 +253,16 @@
}
return { messages: responseMessages, data: responseData }
} else {
// TODO-STREAMDATA: Remove this once Strem Data is not experimental
let streamedResponse = ''
const replyId = nanoid()
let responseMessage: Message = {
id: replyId,
createdAt,
content: '',
role: 'assistant'
}

// TODO-STREAMDATA: Remove this once Strem Data is not experimental
while (true) {
const { done, value } = await reader.read()
@@ -417,7 +423,7 @@ export function useChat({
}
} else {
const streamedResponseMessage = messagesAndDataOrJustMessage
// TODO-STREAMDATA: Remove this once Strem Data is not experimental
// TODO-STREAMDATA: Remove this once Stream Data is not experimental
if (
streamedResponseMessage.function_call === undefined ||
typeof streamedResponseMessage.function_call === 'string'
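
In the simple (non-stream-data) branch above, `Message.function_call` can be either the raw JSON string while chunks are still arriving or the parsed object once complete, which is why the condition guards on both. A hedged illustration of the two shapes (values are made up):

```ts
import type { Message } from 'ai'

// Mid-stream: function_call may still be the accumulating raw string.
const streaming: Partial<Message> = {
  role: 'assistant',
  content: '',
  function_call: '{"function_call": {"name": "get_current_weather", "arguments": "'
}

// Finished: function_call has been parsed into an object.
const complete: Partial<Message> = {
  role: 'assistant',
  content: '',
  function_call: { name: 'get_current_weather', arguments: '{"city": "Berlin"}' }
}
```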
91 changes: 83 additions & 8 deletions packages/core/streams/openai-stream.ts
@@ -45,7 +45,7 @@ export type OpenAIStreamCallbacks = AIStreamCallbacksAndOptions & {
functionCallResult: JSONValue
) => CreateMessage[]
) => Promise<
Response | undefined | void | string | AsyncIterable<ChatCompletionChunk>
Response | undefined | void | string | AsyncIterableOpenAIStreamReturnTypes
>
}

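The widened return type above means the server-side handler may return another streamed OpenAI response directly. A minimal route-handler sketch, assuming the openai v4 SDK and a hypothetical `get_current_weather` function (schema omitted):

```ts
import OpenAI from 'openai'
import { OpenAIStream, StreamingTextResponse } from 'ai'

const openai = new OpenAI()

export async function POST(req: Request) {
  const { messages } = await req.json()

  const response = await openai.chat.completions.create({
    model: 'gpt-3.5-turbo-0613',
    stream: true,
    messages,
    functions: [/* hypothetical get_current_weather schema */]
  })

  const stream = OpenAIStream(response, {
    experimental_onFunctionCall: async (
      { name, arguments: args },
      createFunctionCallMessages
    ) => {
      if (name === 'get_current_weather') {
        // Hypothetical result; fed back to the model as a function message.
        const result = { city: args.city, temperature: 21, unit: 'celsius' }
        const newMessages = createFunctionCallMessages(result)
        // Returning another streamed completion continues the conversation.
        return openai.chat.completions.create({
          model: 'gpt-3.5-turbo-0613',
          stream: true,
          messages: [...messages, ...newMessages]
        })
      }
    }
  })

  return new StreamingTextResponse(stream)
}
```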
@@ -102,6 +102,53 @@ interface FunctionCall {
name?: string
}

/**
* https://github.com/openai/openai-node/blob/3ec43ee790a2eb6a0ccdd5f25faa23251b0f9b8e/src/resources/completions.ts#L28C1-L64C1
* Completions API. Streamed and non-streamed responses are the same.
*/
interface Completion {
/**
* A unique identifier for the completion.
*/
id: string

/**
* The list of completion choices the model generated for the input prompt.
*/
choices: Array<CompletionChoice>

/**
* The Unix timestamp of when the completion was created.
*/
created: number

/**
* The model used for completion.
*/
model: string

/**
* The object type, which is always "text_completion"
*/
object: string
}

interface CompletionChoice {
/**
* The reason the model stopped generating tokens. This will be `stop` if the model
* hit a natural stop point or a provided stop sequence, or `length` if the maximum
* number of tokens specified in the request was reached.
*/
finish_reason: 'stop' | 'length'

index: number

// edited: Removed CompletionChoice.logProbs and replaced with any
logprobs: any | null

text: string
}
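
For reference, a final chunk from the streamed completions endpoint matching the `Completion` shape above might look like this (values are illustrative):

```ts
const chunk: Completion = {
  id: 'cmpl-7abc123',
  object: 'text_completion',
  created: 1692300000,
  model: 'text-davinci-003',
  choices: [{ text: 'Hello!', index: 0, logprobs: null, finish_reason: 'stop' }]
}
```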

/**
* Creates a parser function for processing the OpenAI stream data.
* The parser extracts and trims text content from the JSON data. This parser
@@ -112,7 +159,7 @@
function parseOpenAIStream(): (data: string) => string | void {
const extract = chunkToText()
return data => {
return extract(JSON.parse(data) as ChatCompletionChunk)
return extract(JSON.parse(data) as OpenAIStreamReturnTypes)
}
}

@@ -121,15 +168,15 @@ function parseOpenAIStream(): (data: string) => string | void {
* the same as the old Response body interface with an included SSE parser
* doing the parsing for us.
*/
async function* streamable(stream: AsyncIterable<ChatCompletionChunk>) {
async function* streamable(stream: AsyncIterableOpenAIStreamReturnTypes) {
const extract = chunkToText()
for await (const chunk of stream) {
const text = extract(chunk)
if (text) yield text
}
}

function chunkToText(): (chunk: ChatCompletionChunk) => string | void {
function chunkToText(): (chunk: OpenAIStreamReturnTypes) => string | void {
const trimStartOfStream = trimStartOfStreamHelper()
let isFunctionStreamingIn: boolean
return json => {
@@ -219,10 +266,16 @@ function chunkToText(): (chunk: ChatCompletionChunk) => string | void {
}
}
*/
if (json.choices[0]?.delta?.function_call?.name) {
if (
isChatCompletionChunk(json) &&
json.choices[0]?.delta?.function_call?.name
) {
isFunctionStreamingIn = true
return `{"function_call": {"name": "${json.choices[0]?.delta?.function_call.name}", "arguments": "`
} else if (json.choices[0]?.delta?.function_call?.arguments) {
} else if (
isChatCompletionChunk(json) &&
json.choices[0]?.delta?.function_call?.arguments
) {
const argumentChunk: string =
json.choices[0].delta.function_call.arguments

@@ -246,16 +299,38 @@ function chunkToText(): (chunk: ChatCompletionChunk) => string | void {
}

const text = trimStartOfStream(
json.choices[0]?.delta?.content ?? (json.choices[0] as any)?.text ?? ''
isChatCompletionChunk(json) && json.choices[0].delta.content
? json.choices[0].delta.content
: isCompletion(json)
? json.choices[0].text
: ''
)
return text
}
}

const __internal__OpenAIFnMessagesSymbol = Symbol('internal_openai_fn_messages')

type AsyncIterableOpenAIStreamReturnTypes =
| AsyncIterable<ChatCompletionChunk>
| AsyncIterable<Completion>

type ExtractType<T> = T extends AsyncIterable<infer U> ? U : never

type OpenAIStreamReturnTypes = ExtractType<AsyncIterableOpenAIStreamReturnTypes>

function isChatCompletionChunk(
data: OpenAIStreamReturnTypes
): data is ChatCompletionChunk {
return 'choices' in data && 'delta' in data.choices[0]
}

function isCompletion(data: OpenAIStreamReturnTypes): data is Completion {
return 'choices' in data && 'text' in data.choices[0]
}

export function OpenAIStream(
res: Response | AsyncIterable<ChatCompletionChunk>,
res: Response | AsyncIterableOpenAIStreamReturnTypes,
callbacks?: OpenAIStreamCallbacks
): ReadableStream {
// Annotate the internal `messages` property for recursive function calls
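
With the widened `res` parameter, `OpenAIStream` now accepts the openai v4 SDK's streamed completions as well as chat completions. A minimal sketch (model name illustrative):

```ts
import OpenAI from 'openai'
import { OpenAIStream, StreamingTextResponse } from 'ai'

const openai = new OpenAI()

export async function POST(req: Request) {
  const { prompt } = await req.json()

  // AsyncIterable<Completion> from the legacy completions endpoint is now
  // a valid input alongside AsyncIterable<ChatCompletionChunk>.
  const response = await openai.completions.create({
    model: 'text-davinci-003',
    prompt,
    stream: true
  })

  return new StreamingTextResponse(OpenAIStream(response))
}
```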
