Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions app/api/generate/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { OpenAIStream, openai } from "@/lib/openai";
import { getServerSession } from "@/lib/session/get-server-session";

export const runtime = "edge";
import { prisma } from "@/lib/prisma";

export async function POST(req: Request) {
const json = await req.json();
Expand All @@ -20,7 +19,25 @@ export async function POST(req: Request) {

const stream = await OpenAIStream(res);

return new Response(stream, {
let fullResponse = "";
const decoder = new TextDecoder();
const saveToPrisma = new TransformStream({
transform: async (chunk, controller) => {
controller.enqueue(chunk);
fullResponse += decoder.decode(chunk);
},
flush: async () => {
await prisma.chat.upsert({
where: {
id: json.id,
},
create: json,
update: json,
});
},
});

return new Response(stream.pipeThrough(saveToPrisma), {
status: 200,
headers: { "Content-Type": "text/event-stream" },
});
Expand Down
4 changes: 2 additions & 2 deletions app/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export interface ChatProps {
}

export function Chat({
id: _id,
id,
// create,
messages,
}: ChatProps) {
Expand All @@ -22,7 +22,7 @@ export function Chat({
const { isLoading, messageList, appendUserMessage, reloadLastMessage } =
usePrompt({
messages,
_id,
id,
// onCreate: (id: string) => {
// router.push(`/chat/${id}`);
// },
Expand Down
152 changes: 78 additions & 74 deletions app/use-prompt.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import { nanoid } from "@/lib/utils";

export function usePrompt({
messages = [],
_id,
id,
}: {
messages?: Message[];
_id: string | undefined | null;
id: string | undefined;
}) {
const [isLoading, setIsLoading] = useState(false);
const [messageList, setMessageList] = useState(messages);
Expand All @@ -23,83 +23,87 @@ export function usePrompt({
messageListRef.current = messageList;
}, [messageList]);

const appendUserMessage = useCallback(async (content: string | Message) => {
// Prevent multiple requests at once
if (isLoadingRef.current) return;

const userMsg =
typeof content === "string"
? ({ id: nanoid(10), role: "user", content } as Message)
: content;
const assMsg = {
id: nanoid(10),
role: "assistant",
content: "",
} as Message;
const messageListSnapshot = messageListRef.current;

// Reset output
setIsLoading(true);

try {
// Set user input immediately
setMessageList([...messageListSnapshot, userMsg]);

// If streaming, we need to use fetchEventSource directly
const response = await fetch(`/api/generate`, {
method: "POST",
body: JSON.stringify({
messages: [...messageListSnapshot, userMsg].map((m) => ({
role: m.role,
content: m.content,
})),
}),
headers: { "Content-Type": "application/json" },
});
// This data is a ReadableStream
const data = response.body;
if (!data) {
return;
}
const appendUserMessage = useCallback(
async (content: string | Message) => {
// Prevent multiple requests at once
if (isLoadingRef.current) return;

const userMsg =
typeof content === "string"
? ({ id: nanoid(10), role: "user", content } as Message)
: content;
const assMsg = {
id: nanoid(10),
role: "assistant",
content: "",
} as Message;
const messageListSnapshot = messageListRef.current;

// Reset output
setIsLoading(true);

try {
// Set user input immediately
setMessageList([...messageListSnapshot, userMsg]);

// If streaming, we need to use fetchEventSource directly
const response = await fetch(`/api/generate`, {
method: "POST",
body: JSON.stringify({
id: id || nanoid(10),
messages: [...messageListSnapshot, userMsg].map((m) => ({
role: m.role,
content: m.content,
})),
}),
headers: { "Content-Type": "application/json" },
});
// This data is a ReadableStream
const data = response.body;
if (!data) {
return;
}

const reader = data.getReader();
const decoder = new TextDecoder();
let done = false;
let accumulatedValue = ""; // Variable to accumulate chunks

while (!done) {
const { value, done: doneReading } = await reader.read();
done = doneReading;
const chunkValue = decoder.decode(value);
accumulatedValue += chunkValue; // Accumulate the chunk value

// Check if the accumulated value contains the delimiter
const delimiter = "\n";
const chunks = accumulatedValue.split(delimiter);

// Process all chunks except the last one (which may be incomplete)
while (chunks.length > 1) {
const chunkToDispatch = chunks.shift(); // Get the first chunk
if (chunkToDispatch && chunkToDispatch.length > 0) {
const chunk = JSON.parse(chunkToDispatch);
assMsg.content += chunk;
setMessageList([...messageListSnapshot, userMsg, assMsg]);
const reader = data.getReader();
const decoder = new TextDecoder();
let done = false;
let accumulatedValue = ""; // Variable to accumulate chunks

while (!done) {
const { value, done: doneReading } = await reader.read();
done = doneReading;
const chunkValue = decoder.decode(value);
accumulatedValue += chunkValue; // Accumulate the chunk value

// Check if the accumulated value contains the delimiter
const delimiter = "\n";
const chunks = accumulatedValue.split(delimiter);

// Process all chunks except the last one (which may be incomplete)
while (chunks.length > 1) {
const chunkToDispatch = chunks.shift(); // Get the first chunk
if (chunkToDispatch && chunkToDispatch.length > 0) {
const chunk = JSON.parse(chunkToDispatch);
assMsg.content += chunk;
setMessageList([...messageListSnapshot, userMsg, assMsg]);
}
}
}

// The last chunk may be incomplete, so keep it in the accumulated value
accumulatedValue = chunks[0];
}
// The last chunk may be incomplete, so keep it in the accumulated value
accumulatedValue = chunks[0];
}

// Process any remaining accumulated value after the loop is done
if (accumulatedValue.length > 0) {
assMsg.content += accumulatedValue;
setMessageList([...messageListSnapshot, userMsg, assMsg]);
// Process any remaining accumulated value after the loop is done
if (accumulatedValue.length > 0) {
assMsg.content += accumulatedValue;
setMessageList([...messageListSnapshot, userMsg, assMsg]);
}
} finally {
setIsLoading(false);
}
} finally {
setIsLoading(false);
}
}, []);
},
[id]
);

const reloadLastMessage = useCallback(async () => {
// Prevent multiple requests at once
Expand Down