-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce useAssistant (experimental). (#728)
Co-authored-by: Safi Nettah <nettah.safi@protonmail.com> Co-authored-by: Max Leiter <max.leiter@vercel.com>
- Loading branch information
1 parent
fd7d445
commit 69ca8f5
Showing
13 changed files
with
628 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
'ai': patch | ||
--- | ||
|
||
ai/react: add experimental_useAssistant hook and experimental_AssistantResponse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Home Automation Assistant Example | ||
|
||
## Setup | ||
|
||
### Create OpenAI Assistant | ||
|
||
[OpenAI Assistant Website](https://platform.openai.com/assistants) | ||
|
||
Create a new assistant. Enable Code interpreter. Add the following functions and instructions to the assistant. | ||
|
||
Then add the assistant id to the `.env.local` file as `ASSISTANT_ID=your-assistant-id`. | ||
|
||
### Instructions | ||
|
||
``` | ||
You are an assistant with access to a home automation system. You can get and set the temperature in the bedroom, home office, living room, kitchen and bathroom. | ||
The system uses temperature in Celsius. If the user requests Fahrenheit, you should convert the temperature to Fahrenheit. | ||
``` | ||
|
||
### getRoomTemperature function | ||
|
||
```json | ||
{ | ||
"name": "getRoomTemperature", | ||
"description": "Get the temperature in a room", | ||
"parameters": { | ||
"type": "object", | ||
"properties": { | ||
"room": { | ||
"type": "string", | ||
"enum": ["bedroom", "home office", "living room", "kitchen", "bathroom"] | ||
} | ||
}, | ||
"required": ["room"] | ||
} | ||
} | ||
``` | ||
|
||
### setRoomTemperature function | ||
|
||
```json | ||
{ | ||
"name": "setRoomTemperature", | ||
"description": "Set the temperature in a room", | ||
"parameters": { | ||
"type": "object", | ||
"properties": { | ||
"room": { | ||
"type": "string", | ||
"enum": ["bedroom", "home office", "living room", "kitchen", "bathroom"] | ||
}, | ||
"temperature": { "type": "number" } | ||
}, | ||
"required": ["room", "temperature"] | ||
} | ||
} | ||
``` | ||
|
||
## Run | ||
|
||
1. Run `pnpm run dev` in `examples/next-openai` | ||
2. Go to http://localhost:3000/assistant |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import { experimental_AssistantResponse } from 'ai'; | ||
import OpenAI from 'openai'; | ||
import { MessageContentText } from 'openai/resources/beta/threads/messages/messages'; | ||
|
||
// Create an OpenAI API client (that's edge friendly!) | ||
const openai = new OpenAI({ | ||
apiKey: process.env.OPENAI_API_KEY || '', | ||
}); | ||
|
||
// IMPORTANT! Set the runtime to edge | ||
export const runtime = 'edge'; | ||
|
||
const homeTemperatures = { | ||
bedroom: 20, | ||
'home office': 21, | ||
'living room': 21, | ||
kitchen: 22, | ||
bathroom: 23, | ||
}; | ||
|
||
export async function POST(req: Request) { | ||
// Parse the request body | ||
const input: { | ||
threadId: string | null; | ||
message: string; | ||
} = await req.json(); | ||
|
||
// Create a thread if needed | ||
const threadId = input.threadId ?? (await openai.beta.threads.create({})).id; | ||
|
||
// Add a message to the thread | ||
const createdMessage = await openai.beta.threads.messages.create(threadId, { | ||
role: 'user', | ||
content: input.message, | ||
}); | ||
|
||
return experimental_AssistantResponse( | ||
{ threadId, messageId: createdMessage.id }, | ||
async ({ threadId, sendMessage }) => { | ||
// Run the assistant on the thread | ||
const run = await openai.beta.threads.runs.create(threadId, { | ||
assistant_id: | ||
process.env.ASSISTANT_ID ?? | ||
(() => { | ||
throw new Error('ASSISTANT_ID is not set'); | ||
})(), | ||
}); | ||
|
||
async function waitForRun(run: OpenAI.Beta.Threads.Runs.Run) { | ||
// Poll for status change | ||
while (run.status === 'queued' || run.status === 'in_progress') { | ||
// delay for 500ms: | ||
await new Promise(resolve => setTimeout(resolve, 500)); | ||
|
||
run = await openai.beta.threads.runs.retrieve(threadId!, run.id); | ||
} | ||
|
||
// Check the run status | ||
if ( | ||
run.status === 'cancelled' || | ||
run.status === 'cancelling' || | ||
run.status === 'failed' || | ||
run.status === 'expired' | ||
) { | ||
throw new Error(run.status); | ||
} | ||
|
||
if (run.status === 'requires_action') { | ||
if (run.required_action?.type === 'submit_tool_outputs') { | ||
const tool_outputs = | ||
run.required_action.submit_tool_outputs.tool_calls.map( | ||
toolCall => { | ||
const parameters = JSON.parse(toolCall.function.arguments); | ||
|
||
switch (toolCall.function.name) { | ||
case 'getRoomTemperature': { | ||
const temperature = | ||
homeTemperatures[ | ||
parameters.room as keyof typeof homeTemperatures | ||
]; | ||
|
||
return { | ||
tool_call_id: toolCall.id, | ||
output: temperature.toString(), | ||
}; | ||
} | ||
|
||
case 'setRoomTemperature': { | ||
homeTemperatures[ | ||
parameters.room as keyof typeof homeTemperatures | ||
] = parameters.temperature; | ||
|
||
return { | ||
tool_call_id: toolCall.id, | ||
output: `temperature set successfully`, | ||
}; | ||
} | ||
|
||
default: | ||
throw new Error( | ||
`Unknown tool call function: ${toolCall.function.name}`, | ||
); | ||
} | ||
}, | ||
); | ||
|
||
run = await openai.beta.threads.runs.submitToolOutputs( | ||
threadId!, | ||
run.id, | ||
{ tool_outputs }, | ||
); | ||
|
||
await waitForRun(run); | ||
} | ||
} | ||
} | ||
|
||
await waitForRun(run); | ||
|
||
// Get new thread messages (after our message) | ||
const responseMessages = ( | ||
await openai.beta.threads.messages.list(threadId, { | ||
after: createdMessage.id, | ||
order: 'asc', | ||
}) | ||
).data; | ||
|
||
// Send the messages | ||
for (const message of responseMessages) { | ||
sendMessage({ | ||
id: message.id, | ||
role: 'assistant', | ||
content: message.content.filter( | ||
content => content.type === 'text', | ||
) as Array<MessageContentText>, | ||
}); | ||
} | ||
}, | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
'use client'; | ||
|
||
import { Message, experimental_useAssistant as useAssistant } from 'ai/react'; | ||
import { useEffect, useRef } from 'react'; | ||
|
||
const roleToColorMap: Record<Message['role'], string> = { | ||
system: 'red', | ||
user: 'black', | ||
function: 'blue', | ||
assistant: 'green', | ||
}; | ||
|
||
export default function Chat() { | ||
const { status, messages, input, submitMessage, handleInputChange, error } = | ||
useAssistant({ | ||
api: '/api/assistant', | ||
}); | ||
|
||
// When status changes to accepting messages, focus the input: | ||
const inputRef = useRef<HTMLInputElement>(null); | ||
useEffect(() => { | ||
if (status === 'awaiting_message') { | ||
inputRef.current?.focus(); | ||
} | ||
}, [status]); | ||
|
||
return ( | ||
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch"> | ||
{error != null && ( | ||
<div className="relative bg-red-500 text-white px-6 py-4 rounded-md"> | ||
<span className="block sm:inline"> | ||
Error: {(error as any).toString()} | ||
</span> | ||
</div> | ||
)} | ||
|
||
{messages.map((m: Message) => ( | ||
<div | ||
key={m.id} | ||
className="whitespace-pre-wrap" | ||
style={{ color: roleToColorMap[m.role] }} | ||
> | ||
<strong>{`${m.role}: `}</strong> | ||
{m.content} | ||
<br /> | ||
<br /> | ||
</div> | ||
))} | ||
|
||
{status === 'in_progress' && ( | ||
<div className="h-8 w-full max-w-md p-2 mb-8 bg-gray-300 dark:bg-gray-600 rounded-lg animate-pulse" /> | ||
)} | ||
|
||
<form onSubmit={submitMessage}> | ||
<input | ||
ref={inputRef} | ||
disabled={status !== 'awaiting_message'} | ||
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl" | ||
value={input} | ||
placeholder="What is the temperature in the living room?" | ||
onChange={handleInputChange} | ||
/> | ||
</form> | ||
</div> | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
export * from './use-chat'; | ||
export * from './use-completion'; | ||
export * from './use-assistant'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
/* eslint-disable react-hooks/rules-of-hooks */ | ||
|
||
import { useState } from 'react'; | ||
import { processMessageStream } from '../shared/process-message-stream'; | ||
import { Message } from '../shared/types'; | ||
import { parseStreamPart } from '../shared/stream-parts'; | ||
|
||
export type AssistantStatus = 'in_progress' | 'awaiting_message'; | ||
|
||
export function experimental_useAssistant({ | ||
api, | ||
threadId: threadIdParam, | ||
}: { | ||
api: string; | ||
threadId?: string | undefined; | ||
}) { | ||
const [messages, setMessages] = useState<Message[]>([]); | ||
const [input, setInput] = useState(''); | ||
const [threadId, setThreadId] = useState<string | undefined>(undefined); | ||
const [status, setStatus] = useState<AssistantStatus>('awaiting_message'); | ||
const [error, setError] = useState<unknown | undefined>(undefined); | ||
|
||
const handleInputChange = (e: any) => { | ||
setInput(e.target.value); | ||
}; | ||
|
||
const submitMessage = async (e: any) => { | ||
e.preventDefault(); | ||
|
||
if (input === '') { | ||
return; | ||
} | ||
|
||
setStatus('in_progress'); | ||
|
||
setMessages(messages => [ | ||
...messages, | ||
{ id: '', role: 'user', content: input }, | ||
]); | ||
|
||
setInput(''); | ||
|
||
const result = await fetch(api, { | ||
method: 'POST', | ||
headers: { 'Content-Type': 'application/json' }, | ||
body: JSON.stringify({ | ||
// always use user-provided threadId when available: | ||
threadId: threadIdParam ?? threadId ?? null, | ||
message: input, | ||
}), | ||
}); | ||
|
||
if (result.body == null) { | ||
throw new Error('The response body is empty.'); | ||
} | ||
|
||
await processMessageStream(result.body.getReader(), (message: string) => { | ||
try { | ||
const { type, value } = parseStreamPart(message); | ||
|
||
switch (type) { | ||
case 'assistant_message': { | ||
// append message: | ||
setMessages(messages => [ | ||
...messages, | ||
{ | ||
id: value.id, | ||
role: value.role, | ||
content: value.content[0].text.value, | ||
}, | ||
]); | ||
break; | ||
} | ||
|
||
case 'assistant_control_data': { | ||
setThreadId(value.threadId); | ||
|
||
// set id of last message: | ||
setMessages(messages => { | ||
const lastMessage = messages[messages.length - 1]; | ||
lastMessage.id = value.messageId; | ||
return [...messages.slice(0, messages.length - 1), lastMessage]; | ||
}); | ||
|
||
break; | ||
} | ||
|
||
case 'error': { | ||
setError(value); | ||
break; | ||
} | ||
} | ||
} catch (error) { | ||
setError(error); | ||
} | ||
}); | ||
|
||
setStatus('awaiting_message'); | ||
}; | ||
|
||
return { | ||
messages, | ||
input, | ||
handleInputChange, | ||
submitMessage, | ||
status, | ||
error, | ||
}; | ||
} |
Oops, something went wrong.