diff --git a/app/api/generate/route.ts b/app/api/generate/route.ts index a0a9677736..7d34526f70 100644 --- a/app/api/generate/route.ts +++ b/app/api/generate/route.ts @@ -1,7 +1,6 @@ import { OpenAIStream, openai } from "@/lib/openai"; import { getServerSession } from "@/lib/session/get-server-session"; - -export const runtime = "edge"; +import { prisma } from "@/lib/prisma"; export async function POST(req: Request) { const json = await req.json(); @@ -20,7 +19,25 @@ export async function POST(req: Request) { const stream = await OpenAIStream(res); - return new Response(stream, { + let fullResponse = ""; + const decoder = new TextDecoder(); + const saveToPrisma = new TransformStream({ + transform: async (chunk, controller) => { + controller.enqueue(chunk); + fullResponse += decoder.decode(chunk); + }, + flush: async () => { + await prisma.chat.upsert({ + where: { + id: json.id, + }, + create: json, + update: json, + }); + }, + }); + + return new Response(stream.pipeThrough(saveToPrisma), { status: 200, headers: { "Content-Type": "text/event-stream" }, }); diff --git a/app/chat.tsx b/app/chat.tsx index 0ce7861bff..e3634b522e 100644 --- a/app/chat.tsx +++ b/app/chat.tsx @@ -13,7 +13,7 @@ export interface ChatProps { } export function Chat({ - id: _id, + id, // create, messages, }: ChatProps) { @@ -22,7 +22,7 @@ export function Chat({ const { isLoading, messageList, appendUserMessage, reloadLastMessage } = usePrompt({ messages, - _id, + id, // onCreate: (id: string) => { // router.push(`/chat/${id}`); // }, diff --git a/app/use-prompt.tsx b/app/use-prompt.tsx index d3150947fb..9bf0a1165b 100644 --- a/app/use-prompt.tsx +++ b/app/use-prompt.tsx @@ -4,10 +4,10 @@ import { nanoid } from "@/lib/utils"; export function usePrompt({ messages = [], - _id, + id, }: { messages?: Message[]; - _id: string | undefined | null; + id: string | undefined; }) { const [isLoading, setIsLoading] = useState(false); const [messageList, setMessageList] = useState(messages); @@ -23,83 +23,87 @@ export function usePrompt({ messageListRef.current = messageList; }, [messageList]); - const appendUserMessage = useCallback(async (content: string | Message) => { - // Prevent multiple requests at once - if (isLoadingRef.current) return; - - const userMsg = - typeof content === "string" - ? ({ id: nanoid(10), role: "user", content } as Message) - : content; - const assMsg = { - id: nanoid(10), - role: "assistant", - content: "", - } as Message; - const messageListSnapshot = messageListRef.current; - - // Reset output - setIsLoading(true); - - try { - // Set user input immediately - setMessageList([...messageListSnapshot, userMsg]); - - // If streaming, we need to use fetchEventSource directly - const response = await fetch(`/api/generate`, { - method: "POST", - body: JSON.stringify({ - messages: [...messageListSnapshot, userMsg].map((m) => ({ - role: m.role, - content: m.content, - })), - }), - headers: { "Content-Type": "application/json" }, - }); - // This data is a ReadableStream - const data = response.body; - if (!data) { - return; - } + const appendUserMessage = useCallback( + async (content: string | Message) => { + // Prevent multiple requests at once + if (isLoadingRef.current) return; + + const userMsg = + typeof content === "string" + ? ({ id: nanoid(10), role: "user", content } as Message) + : content; + const assMsg = { + id: nanoid(10), + role: "assistant", + content: "", + } as Message; + const messageListSnapshot = messageListRef.current; + + // Reset output + setIsLoading(true); + + try { + // Set user input immediately + setMessageList([...messageListSnapshot, userMsg]); + + // If streaming, we need to use fetchEventSource directly + const response = await fetch(`/api/generate`, { + method: "POST", + body: JSON.stringify({ + id: id || nanoid(10), + messages: [...messageListSnapshot, userMsg].map((m) => ({ + role: m.role, + content: m.content, + })), + }), + headers: { "Content-Type": "application/json" }, + }); + // This data is a ReadableStream + const data = response.body; + if (!data) { + return; + } - const reader = data.getReader(); - const decoder = new TextDecoder(); - let done = false; - let accumulatedValue = ""; // Variable to accumulate chunks - - while (!done) { - const { value, done: doneReading } = await reader.read(); - done = doneReading; - const chunkValue = decoder.decode(value); - accumulatedValue += chunkValue; // Accumulate the chunk value - - // Check if the accumulated value contains the delimiter - const delimiter = "\n"; - const chunks = accumulatedValue.split(delimiter); - - // Process all chunks except the last one (which may be incomplete) - while (chunks.length > 1) { - const chunkToDispatch = chunks.shift(); // Get the first chunk - if (chunkToDispatch && chunkToDispatch.length > 0) { - const chunk = JSON.parse(chunkToDispatch); - assMsg.content += chunk; - setMessageList([...messageListSnapshot, userMsg, assMsg]); + const reader = data.getReader(); + const decoder = new TextDecoder(); + let done = false; + let accumulatedValue = ""; // Variable to accumulate chunks + + while (!done) { + const { value, done: doneReading } = await reader.read(); + done = doneReading; + const chunkValue = decoder.decode(value); + accumulatedValue += chunkValue; // Accumulate the chunk value + + // Check if the accumulated value contains the delimiter + const delimiter = "\n"; + const chunks = accumulatedValue.split(delimiter); + + // Process all chunks except the last one (which may be incomplete) + while (chunks.length > 1) { + const chunkToDispatch = chunks.shift(); // Get the first chunk + if (chunkToDispatch && chunkToDispatch.length > 0) { + const chunk = JSON.parse(chunkToDispatch); + assMsg.content += chunk; + setMessageList([...messageListSnapshot, userMsg, assMsg]); + } } - } - // The last chunk may be incomplete, so keep it in the accumulated value - accumulatedValue = chunks[0]; - } + // The last chunk may be incomplete, so keep it in the accumulated value + accumulatedValue = chunks[0]; + } - // Process any remaining accumulated value after the loop is done - if (accumulatedValue.length > 0) { - assMsg.content += accumulatedValue; - setMessageList([...messageListSnapshot, userMsg, assMsg]); + // Process any remaining accumulated value after the loop is done + if (accumulatedValue.length > 0) { + assMsg.content += accumulatedValue; + setMessageList([...messageListSnapshot, userMsg, assMsg]); + } + } finally { + setIsLoading(false); } - } finally { - setIsLoading(false); - } - }, []); + }, + [id] + ); const reloadLastMessage = useCallback(async () => { // Prevent multiple requests at once