diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 9c35406383bd..7e35572a4ce7 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,5 @@ import { useDisclosure } from "@nextui-org/react"; -import React, { useEffect, useState } from "react"; +import React, { useEffect } from "react"; import { Toaster } from "react-hot-toast"; import CogTooth from "#/assets/cog-tooth"; import ChatInterface from "#/components/chat/ChatInterface"; @@ -8,15 +8,13 @@ import { Container, Orientation } from "#/components/Resizable"; import Workspace from "#/components/Workspace"; import LoadPreviousSessionModal from "#/components/modals/load-previous-session/LoadPreviousSessionModal"; import SettingsModal from "#/components/modals/settings/SettingsModal"; -import { fetchMsgTotal } from "#/services/session"; -import Socket from "#/services/socket"; -import { ResFetchMsgTotal } from "#/types/ResponseType"; import "./App.css"; import AgentControlBar from "./components/AgentControlBar"; import AgentStatusBar from "./components/AgentStatusBar"; import Terminal from "./components/terminal/Terminal"; -import { initializeAgent } from "./services/agent"; -import { settingsAreUpToDate } from "./services/settings"; +import Session from "#/services/session"; +import { getToken } from "#/services/auth"; +import { settingsAreUpToDate } from "#/services/settings"; interface Props { setSettingOpen: (isOpen: boolean) => void; @@ -43,8 +41,6 @@ function Controls({ setSettingOpen }: Props): JSX.Element { let initOnce = false; function App(): JSX.Element { - const [isWarned, setIsWarned] = useState(false); - const { isOpen: settingsModalIsOpen, onOpen: onSettingsModalOpen, @@ -57,31 +53,18 @@ function App(): JSX.Element { onOpenChange: onLoadPreviousSessionModalOpenChange, } = useDisclosure(); - const getMsgTotal = () => { - if (isWarned) return; - fetchMsgTotal() - .then((data: ResFetchMsgTotal) => { - if (data.msg_total > 0) { - onLoadPreviousSessionModalOpen(); - setIsWarned(true); - } - }) - .catch(); - }; - useEffect(() => { if (initOnce) return; initOnce = true; if (!settingsAreUpToDate()) { onSettingsModalOpen(); + } else if (getToken()) { + onLoadPreviousSessionModalOpen(); } else { - initializeAgent(); + Session.startNewSession(); } - Socket.registerCallback("open", [getMsgTotal]); - - getMsgTotal(); // eslint-disable-next-line react-hooks/exhaustive-deps }, []); diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts deleted file mode 100644 index 6c0f0d840c8d..000000000000 --- a/frontend/src/api/index.ts +++ /dev/null @@ -1,9 +0,0 @@ -export async function fetchModels() { - const response = await fetch(`/api/litellm-models`); - return response.json(); -} - -export async function fetchAgents() { - const response = await fetch(`/api/agents`); - return response.json(); -} diff --git a/frontend/src/components/AgentControlBar.tsx b/frontend/src/components/AgentControlBar.tsx index ad87bceb01bb..32eddb3eb31f 100644 --- a/frontend/src/components/AgentControlBar.tsx +++ b/frontend/src/components/AgentControlBar.tsx @@ -5,7 +5,6 @@ import ArrowIcon from "#/assets/arrow"; import PauseIcon from "#/assets/pause"; import PlayIcon from "#/assets/play"; import { changeAgentState } from "#/services/agentStateService"; -import { clearMsgs } from "#/services/session"; import store, { RootState } from "#/store"; import AgentState from "#/types/AgentState"; import { clearMessages } from "#/state/chatSlice"; @@ -73,7 +72,6 @@ function AgentControlBar() { } if (action === AgentState.STOPPED) { - clearMsgs().then().catch(); store.dispatch(clearMessages()); } else { setIsLoading(true); @@ -86,7 +84,6 @@ function AgentControlBar() { useEffect(() => { if (curAgentState === desiredState) { if (curAgentState === AgentState.STOPPED) { - clearMsgs().then().catch(); store.dispatch(clearMessages()); } setIsLoading(false); diff --git a/frontend/src/components/CodeEditor.tsx b/frontend/src/components/CodeEditor.tsx index ee7e0f551b95..a14b8650fd25 100644 --- a/frontend/src/components/CodeEditor.tsx +++ b/frontend/src/components/CodeEditor.tsx @@ -1,33 +1,24 @@ import Editor, { Monaco } from "@monaco-editor/react"; import { Tab, Tabs } from "@nextui-org/react"; import type { editor } from "monaco-editor"; -import React, { useMemo, useState } from "react"; +import React, { useMemo } from "react"; import { useTranslation } from "react-i18next"; import { VscCode } from "react-icons/vsc"; -import { useDispatch, useSelector } from "react-redux"; +import { useSelector } from "react-redux"; import { I18nKey } from "#/i18n/declaration"; -import { selectFile } from "#/services/fileService"; -import { setCode } from "#/state/codeSlice"; import { RootState } from "#/store"; import FileExplorer from "./file-explorer/FileExplorer"; -import { CodeEditorContext } from "./CodeEditorContext"; function CodeEditor(): JSX.Element { const { t } = useTranslation(); - const [selectedFileAbsolutePath, setSelectedFileAbsolutePath] = useState(""); - const selectedFileName = useMemo(() => { - const paths = selectedFileAbsolutePath.split("/"); - return paths[paths.length - 1]; - }, [selectedFileAbsolutePath]); - const codeEditorContext = useMemo( - () => ({ selectedFileAbsolutePath }), - [selectedFileAbsolutePath], - ); - - const dispatch = useDispatch(); const code = useSelector((state: RootState) => state.code.code); const activeFilepath = useSelector((state: RootState) => state.code.path); + const selectedFileName = useMemo(() => { + const paths = activeFilepath.split("/"); + return paths[paths.length - 1]; + }, [activeFilepath]); + const handleEditorDidMount = ( editor: editor.IStandaloneCodeEditor, monaco: Monaco, @@ -46,57 +37,44 @@ function CodeEditor(): JSX.Element { monaco.editor.setTheme("my-theme"); }; - const updateCode = async () => { - const newCode = await selectFile(activeFilepath); - setSelectedFileAbsolutePath(activeFilepath); - dispatch(setCode(newCode)); - }; - - React.useEffect(() => { - // FIXME: we can probably move this out of the component and into state/service - if (activeFilepath) updateCode(); - }, [activeFilepath]); - return (
- - -
- - +
+ + + +
+ {selectedFileName === "" ? ( +
+ + {t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)} +
+ ) : ( + - -
- {selectedFileName === "" ? ( -
- - {t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)} -
- ) : ( - - )} -
+ )}
- +
); } diff --git a/frontend/src/components/CodeEditorContext.ts b/frontend/src/components/CodeEditorContext.ts deleted file mode 100644 index 06340a1e91d9..000000000000 --- a/frontend/src/components/CodeEditorContext.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { createContext } from "react"; - -export const CodeEditorContext = createContext({ - selectedFileAbsolutePath: "", -}); diff --git a/frontend/src/components/chat/ChatInterface.test.tsx b/frontend/src/components/chat/ChatInterface.test.tsx index 865484473892..570c5df3267d 100644 --- a/frontend/src/components/chat/ChatInterface.test.tsx +++ b/frontend/src/components/chat/ChatInterface.test.tsx @@ -5,7 +5,7 @@ import { act } from "react-dom/test-utils"; import userEvent from "@testing-library/user-event"; import { renderWithProviders } from "test-utils"; import ChatInterface from "./ChatInterface"; -import Socket from "#/services/socket"; +import Session from "#/services/session"; import ActionType from "#/types/ActionType"; import { addAssistantMessage } from "#/state/chatSlice"; import AgentState from "#/types/AgentState"; @@ -15,16 +15,17 @@ vi.mock("#/hooks/useTyping", () => ({ useTyping: vi.fn((text: string) => text), })); -const socketSpy = vi.spyOn(Socket, "send"); +const sessionSpy = vi.spyOn(Session, "send"); +vi.spyOn(Session, "isConnected").mockImplementation(() => true); // This is for the scrollview ref in Chat.tsx // TODO: Move this into test setup HTMLElement.prototype.scrollTo = vi.fn(() => {}); describe("ChatInterface", () => { - it("should render the messages and input", () => { + it("should render empty message list and input", () => { renderWithProviders(); - expect(screen.queryAllByTestId("message")).toHaveLength(1); // initial welcome message only + expect(screen.queryAllByTestId("message")).toHaveLength(0); }); it("should render the new message the user has typed", async () => { @@ -65,7 +66,7 @@ describe("ChatInterface", () => { expect(screen.getByText("Hello to you!")).toBeInTheDocument(); }); - it("should send the a start event to the Socket", () => { + it("should send the a start event to the Session", () => { renderWithProviders(, { preloadedState: { agent: { @@ -83,10 +84,10 @@ describe("ChatInterface", () => { action: ActionType.MESSAGE, args: { content: "my message" }, }; - expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event)); + expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event)); }); - it("should send the a user message event to the Socket", () => { + it("should send the a user message event to the Session", () => { renderWithProviders(, { preloadedState: { agent: { @@ -104,7 +105,7 @@ describe("ChatInterface", () => { action: ActionType.MESSAGE, args: { content: "my message" }, }; - expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event)); + expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event)); }); it("should disable the user input if agent is not initialized", () => { diff --git a/frontend/src/components/chat/ChatInterface.tsx b/frontend/src/components/chat/ChatInterface.tsx index 0f82eeb5103c..183ff99c70b4 100644 --- a/frontend/src/components/chat/ChatInterface.tsx +++ b/frontend/src/components/chat/ChatInterface.tsx @@ -10,7 +10,7 @@ import Chat from "./Chat"; import { RootState } from "#/store"; import AgentState from "#/types/AgentState"; import { sendChatMessage } from "#/services/chatService"; -import { addUserMessage } from "#/state/chatSlice"; +import { addUserMessage, addAssistantMessage } from "#/state/chatSlice"; import { I18nKey } from "#/i18n/declaration"; import { useScrollToBottom } from "#/hooks/useScrollToBottom"; @@ -58,6 +58,12 @@ function ChatInterface() { const { scrollDomToBottom, onChatBodyScroll, hitBottom } = useScrollToBottom(scrollRef); + React.useEffect(() => { + if (curAgentState === AgentState.INIT && messages.length === 0) { + dispatch(addAssistantMessage(t(I18nKey.CHAT_INTERFACE$INITIAL_MESSAGE))); + } + }, [curAgentState]); + return (
diff --git a/frontend/src/components/file-explorer/FileExplorer.test.tsx b/frontend/src/components/file-explorer/FileExplorer.test.tsx index 274e16e0814e..e0494017052a 100644 --- a/frontend/src/components/file-explorer/FileExplorer.test.tsx +++ b/frontend/src/components/file-explorer/FileExplorer.test.tsx @@ -7,8 +7,9 @@ import { describe, it, expect, vi, Mock } from "vitest"; import FileExplorer from "./FileExplorer"; import { uploadFiles, listFiles } from "#/services/fileService"; import toast from "#/utils/toast"; +import AgentState from "#/types/AgentState"; -const toastSpy = vi.spyOn(toast, "stickyError"); +const toastSpy = vi.spyOn(toast, "error"); vi.mock("../../services/fileService", async () => ({ listFiles: vi.fn(async (path: string = "/") => { @@ -42,7 +43,13 @@ describe("FileExplorer", () => { it.todo("should render an empty workspace"); it.only("should refetch the workspace when clicking the refresh button", async () => { - const { getByText } = renderWithProviders(); + const { getByText } = renderWithProviders(, { + preloadedState: { + agent: { + curAgentState: AgentState.RUNNING, + }, + }, + }); await waitFor(() => { expect(getByText("folder1")).toBeInTheDocument(); expect(getByText("file2.ts")).toBeInTheDocument(); diff --git a/frontend/src/components/file-explorer/FileExplorer.tsx b/frontend/src/components/file-explorer/FileExplorer.tsx index 5202186b335a..9b3d6d49e15d 100644 --- a/frontend/src/components/file-explorer/FileExplorer.tsx +++ b/frontend/src/components/file-explorer/FileExplorer.tsx @@ -5,14 +5,16 @@ import { IoIosRefresh, IoIosCloudUpload, } from "react-icons/io"; -import { useDispatch } from "react-redux"; +import { useDispatch, useSelector } from "react-redux"; import { IoFileTray } from "react-icons/io5"; import { twMerge } from "tailwind-merge"; +import AgentState from "#/types/AgentState"; import { setRefreshID } from "#/state/codeSlice"; import { listFiles, uploadFiles } from "#/services/fileService"; import IconButton from "../IconButton"; import ExplorerTree from "./ExplorerTree"; import toast from "#/utils/toast"; +import { RootState } from "#/store"; interface ExplorerActionsProps { onRefresh: () => void; @@ -87,6 +89,7 @@ function FileExplorer() { const [isHidden, setIsHidden] = React.useState(false); const [isDragging, setIsDragging] = React.useState(false); const [files, setFiles] = React.useState([]); + const { curAgentState } = useSelector((state: RootState) => state.agent); const fileInputRef = React.useRef(null); const dispatch = useDispatch(); @@ -95,6 +98,12 @@ function FileExplorer() { }; const refreshWorkspace = async () => { + if ( + curAgentState === AgentState.LOADING || + curAgentState === AgentState.STOPPED + ) { + return; + } dispatch(setRefreshID(Math.random())); setFiles(await listFiles("/")); }; @@ -104,7 +113,7 @@ function FileExplorer() { await uploadFiles(toAdd); await refreshWorkspace(); } catch (error) { - toast.stickyError("ws", "Error uploading file"); + toast.error("ws", "Error uploading file"); } }; @@ -112,7 +121,9 @@ function FileExplorer() { (async () => { await refreshWorkspace(); })(); + }, [curAgentState]); + React.useEffect(() => { const enableDragging = () => { setIsDragging(true); }; @@ -130,6 +141,10 @@ function FileExplorer() { }; }, []); + if (!files.length) { + return null; + } + return (
{isDragging && ( diff --git a/frontend/src/components/file-explorer/TreeNode.tsx b/frontend/src/components/file-explorer/TreeNode.tsx index 9852003642e2..df22093ede63 100644 --- a/frontend/src/components/file-explorer/TreeNode.tsx +++ b/frontend/src/components/file-explorer/TreeNode.tsx @@ -4,9 +4,8 @@ import { twMerge } from "tailwind-merge"; import { RootState } from "#/store"; import FolderIcon from "../FolderIcon"; import FileIcon from "../FileIcons"; -import { listFiles } from "#/services/fileService"; -import { setActiveFilepath } from "#/state/codeSlice"; -import { CodeEditorContext } from "../CodeEditorContext"; +import { listFiles, selectFile } from "#/services/fileService"; +import { setCode, setActiveFilepath } from "#/state/codeSlice"; interface TitleProps { name: string; @@ -36,8 +35,8 @@ interface TreeNodeProps { function TreeNode({ path, defaultOpen = false }: TreeNodeProps) { const [isOpen, setIsOpen] = React.useState(defaultOpen); const [children, setChildren] = React.useState(null); - const { selectedFileAbsolutePath } = React.useContext(CodeEditorContext); const refreshID = useSelector((state: RootState) => state.code.refreshID); + const activeFilepath = useSelector((state: RootState) => state.code.path); const dispatch = useDispatch(); @@ -60,10 +59,12 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) { refreshChildren(); }, [refreshID, isOpen]); - const handleClick = () => { + const handleClick = async () => { if (isDirectory) { setIsOpen((prev) => !prev); } else { + const newCode = await selectFile(path); + dispatch(setCode(newCode)); dispatch(setActiveFilepath(path)); } }; @@ -72,7 +73,7 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
({ - fetchMsgsMock: vi.fn(), -})); - -vi.mock("../../../services/session", async (importOriginal) => ({ - ...(await importOriginal<typeof import("../../../services/session")>()), - clearMsgs: vi.fn(), - fetchMsgs: mocks.fetchMsgsMock.mockResolvedValue({ - messages: [ - { - id: "1", - role: "user", - payload: { type: "action" }, - }, - { - id: "2", - role: "assistant", - payload: { type: "observation" }, - }, - ], - }), -})); - -vi.mock("../../../services/chatService", async (importOriginal) => ({ - ...(await importOriginal<typeof import("../../../services/chatService")>()), - addChatMessageFromEvent: vi.fn(), -})); - vi.mock("../../../services/actions", async (importOriginal) => ({ ...(await importOriginal<typeof import("../../../services/actions")>()), handleAssistantMessage: vi.fn(), })); -vi.mock("../../../utils/toast", () => ({ - default: { - stickyError: vi.fn(), - }, -})); +vi.spyOn(Session, "isConnected").mockImplementation(() => true); +const restoreOrStartNewSessionSpy = vi.spyOn( + Session, + "restoreOrStartNewSession", +); describe("LoadPreviousSession", () => { afterEach(() => { @@ -75,7 +44,6 @@ describe("LoadPreviousSession", () => { userEvent.click(startNewSessionButton); }); - expect(clearMsgs).toHaveBeenCalledTimes(1); // modal should close right after clearing messages expect(onOpenChangeMock).toHaveBeenCalledWith(false); }); @@ -93,36 +61,9 @@ describe("LoadPreviousSession", () => { }); await waitFor(() => { - expect(fetchMsgs).toHaveBeenCalledTimes(1); - expect(addChatMessageFromEvent).toHaveBeenCalledTimes(1); - expect(handleAssistantMessage).toHaveBeenCalledTimes(1); + expect(restoreOrStartNewSessionSpy).toHaveBeenCalledTimes(1); }); // modal should close right after fetching messages expect(onOpenChangeMock).toHaveBeenCalledWith(false); }); - - it("should show an error toast if there is an error fetching the session", async () => { - mocks.fetchMsgsMock.mockRejectedValue(new Error("Get messages failed.")); - - render(<LoadPreviousSessionModal isOpen onOpenChange={vi.fn} />); - - const resumeSessionButton = screen.getByRole("button", { - name: RESUME_SESSION_BUTTON_LABEL_KEY, - }); - - act(() => { - userEvent.click(resumeSessionButton); - }); - - await waitFor(async () => { - await expect(() => fetchMsgs()).rejects.toThrow(); - expect(handleAssistantMessage).not.toHaveBeenCalled(); - expect(addChatMessageFromEvent).not.toHaveBeenCalled(); - // error toast should be shown - expect(toast.stickyError).toHaveBeenCalledWith( - "ws", - "Error fetching the session", - ); - }); - }); }); diff --git a/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx b/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx index 2f29f05f10b5..7e3fd1140bf2 100644 --- a/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx +++ b/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx @@ -1,11 +1,8 @@ import React from "react"; import { useTranslation } from "react-i18next"; import { I18nKey } from "#/i18n/declaration"; -import { handleAssistantMessage } from "#/services/actions"; -import { addChatMessageFromEvent } from "#/services/chatService"; -import { clearMsgs, fetchMsgs } from "#/services/session"; -import toast from "#/utils/toast"; import BaseModal from "../base-modal/BaseModal"; +import Session from "#/services/session"; interface LoadPreviousSessionModalProps { isOpen: boolean; @@ -18,28 +15,6 @@ function LoadPreviousSessionModal({ }: LoadPreviousSessionModalProps) { const { t } = useTranslation(); - const onStartNewSession = async () => { - await clearMsgs(); - }; - - const onResumeSession = async () => { - try { - const { messages } = await fetchMsgs(); - - messages.forEach((message) => { - if (message.role === "user") { - addChatMessageFromEvent(message.payload); - } - - if (message.role === "assistant") { - handleAssistantMessage(message.payload); - } - }); - } catch (error) { - toast.stickyError("ws", "Error fetching the session"); - } - }; - return ( <BaseModal isOpen={isOpen} @@ -50,13 +25,13 @@ function LoadPreviousSessionModal({ { label: t(I18nKey.LOAD_SESSION$RESUME_SESSION_MODAL_ACTION_LABEL), className: "bg-primary rounded-lg", - action: onResumeSession, + action: Session.restoreOrStartNewSession, closeAfterAction: true, }, { label: t(I18nKey.LOAD_SESSION$START_NEW_SESSION_MODAL_ACTION_LABEL), className: "bg-neutral-500 rounded-lg", - action: onStartNewSession, + action: Session.startNewSession, closeAfterAction: true, }, ]} diff --git a/frontend/src/components/modals/settings/SettingsModal.test.tsx b/frontend/src/components/modals/settings/SettingsModal.test.tsx index f9e82103006d..f5a2a6a40b15 100644 --- a/frontend/src/components/modals/settings/SettingsModal.test.tsx +++ b/frontend/src/components/modals/settings/SettingsModal.test.tsx @@ -11,12 +11,14 @@ import { saveSettings, getDefaultSettings, } from "#/services/settings"; -import { initializeAgent } from "#/services/agent"; -import { fetchAgents, fetchModels } from "#/api"; +import Session from "#/services/session"; +import { fetchAgents, fetchModels } from "#/services/options"; import SettingsModal from "./SettingsModal"; const toastSpy = vi.spyOn(toast, "settingsChanged"); const i18nSpy = vi.spyOn(i18next, "changeLanguage"); +const startNewSessionSpy = vi.spyOn(Session, "startNewSession"); +vi.spyOn(Session, "isConnected").mockImplementation(() => true); vi.mock("#/services/settings", async (importOriginal) => ({ ...(await importOriginal<typeof import("#/services/settings")>()), @@ -35,12 +37,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({ saveSettings: vi.fn(), })); -vi.mock("#/services/agent", async () => ({ - initializeAgent: vi.fn(), -})); - -vi.mock("#/api", async (importOriginal) => ({ - ...(await importOriginal<typeof import("#/api")>()), +vi.mock("#/services/options", async (importOriginal) => ({ + ...(await importOriginal<typeof import("#/services/options")>()), fetchModels: vi .fn() .mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])), @@ -162,7 +160,7 @@ describe("SettingsModal", () => { userEvent.click(saveButton); }); - expect(initializeAgent).toHaveBeenCalled(); + expect(startNewSessionSpy).toHaveBeenCalled(); }); it("should display a toast for every change", async () => { diff --git a/frontend/src/components/modals/settings/SettingsModal.tsx b/frontend/src/components/modals/settings/SettingsModal.tsx index e68c0e153343..38bca6454b4f 100644 --- a/frontend/src/components/modals/settings/SettingsModal.tsx +++ b/frontend/src/components/modals/settings/SettingsModal.tsx @@ -3,10 +3,10 @@ import i18next from "i18next"; import React, { useEffect } from "react"; import { useTranslation } from "react-i18next"; import { useSelector } from "react-redux"; -import { fetchAgents, fetchModels } from "#/api"; +import { fetchAgents, fetchModels } from "#/services/options"; import { AvailableLanguages } from "#/i18n"; import { I18nKey } from "#/i18n/declaration"; -import { initializeAgent } from "#/services/agent"; +import Session from "#/services/session"; import { RootState } from "../../../store"; import AgentState from "../../../types/AgentState"; import { @@ -100,7 +100,7 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) { const updatedSettings = getSettingsDifference(settings); saveSettings(settings); i18next.changeLanguage(settings.LANGUAGE); - initializeAgent(); // reinitialize the agent with the new settings + Session.startNewSession(); const sensitiveKeys = ["LLM_API_KEY"]; diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index b203a330cdba..4af256fc0fc5 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -266,12 +266,12 @@ "en": "Please stop the agent before editing these settings." }, "LOAD_SESSION$MODAL_TITLE": { - "en": "Unfinished Session Detected", - "zh-CN": "检测到有未完成的会话", - "zh-TW": "偵測到未完成的會話" + "en": "Return to existing session?", + "zh-CN": "是否继续未完成的会话?", + "zh-TW": "是否繼續未完成的會話?" }, "LOAD_SESSION$MODAL_CONTENT": { - "en": "You seem to have an unfinished task. Would you like to pick up where you left off or start fresh?", + "en": "You seem to have an ongoing session. Would you like to pick up where you left off, or start fresh?", "zh-CN": "您似乎有一个未完成的任务。您想继续之前的工作还是重新开始?", "zh-TW": "您似乎有一個未完成的任務。您想從上次離開的地方繼續還是重新開始?" }, diff --git a/frontend/src/services/agent.test.ts b/frontend/src/services/agent.test.ts deleted file mode 100644 index 0075f4dbdd8d..000000000000 --- a/frontend/src/services/agent.test.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { describe, expect, it, vi } from "vitest"; - -import ActionType from "#/types/ActionType"; -import { initializeAgent } from "./agent"; -import { Settings, saveSettings } from "./settings"; -import Socket from "./socket"; - -const sendSpy = vi.spyOn(Socket, "send"); - -describe("initializeAgent", () => { - it("Should initialize the agent with the current settings", () => { - const settings: Settings = { - LLM_MODEL: "llm_value", - AGENT: "agent_value", - LANGUAGE: "language_value", - LLM_API_KEY: "sk-...", - }; - - const event = { - action: ActionType.INIT, - args: settings, - }; - - saveSettings(settings); - initializeAgent(); - - expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event)); - }); -}); diff --git a/frontend/src/services/agent.ts b/frontend/src/services/agent.ts deleted file mode 100644 index dd8b7a6bfded..000000000000 --- a/frontend/src/services/agent.ts +++ /dev/null @@ -1,14 +0,0 @@ -import ActionType from "#/types/ActionType"; -import { getSettings } from "./settings"; -import Socket from "./socket"; - -/** - * Initialize the agent with the current settings. - * @param settings - The new settings. - */ -export const initializeAgent = () => { - const settings = getSettings(); - const event = { action: ActionType.INIT, args: settings }; - const eventString = JSON.stringify(event); - Socket.send(eventString); -}; diff --git a/frontend/src/services/agentStateService.ts b/frontend/src/services/agentStateService.ts index 8704dc7097fd..4df4f1bfc256 100644 --- a/frontend/src/services/agentStateService.ts +++ b/frontend/src/services/agentStateService.ts @@ -1,7 +1,6 @@ import ActionType from "#/types/ActionType"; import AgentState from "#/types/AgentState"; -import Socket from "./socket"; -import { initializeAgent } from "./agent"; +import Session from "./session"; const INIT_DELAY = 1000; @@ -10,10 +9,10 @@ export function changeAgentState(state: AgentState): void { action: ActionType.CHANGE_AGENT_STATE, args: { agent_state: state }, }); - Socket.send(eventString); + Session.send(eventString); if (state === AgentState.STOPPED) { setTimeout(() => { - initializeAgent(); + Session.startNewSession(); }, INIT_DELAY); } } diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts new file mode 100644 index 000000000000..bff6fc92293a --- /dev/null +++ b/frontend/src/services/api.ts @@ -0,0 +1,58 @@ +import { getToken } from "./auth"; +import toast from "#/utils/toast"; + +const WAIT_FOR_AUTH_DELAY_MS = 500; + +export async function request( + url: string, + optionsIn: RequestInit = {}, + disableToast: boolean = false, + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ +): Promise<any> { + const options = JSON.parse(JSON.stringify(optionsIn)); + + const onFail = (msg: string) => { + if (!disableToast) { + toast.error("api", msg); + } + throw new Error(msg); + }; + + const needsAuth = !url.startsWith("/api/options/"); + const token = getToken(); + if (!token && needsAuth) { + return new Promise((resolve) => { + setTimeout(() => { + resolve(request(url, optionsIn, disableToast)); + }, WAIT_FOR_AUTH_DELAY_MS); + }); + } + if (token) { + options.headers = { + ...(options.headers || {}), + Authorization: `Bearer ${token}`, + }; + } + + let response = null; + try { + response = await fetch(url, options); + } catch (e) { + onFail(`Error fetching ${url}`); + } + if (response?.status && response?.status >= 400) { + onFail( + `${response.status} error while fetching ${url}: ${response?.statusText}`, + ); + } + if (!response?.ok) { + onFail(`Error fetching ${url}: ${response?.statusText}`); + } + + try { + return await (response && response.json()); + } catch (e) { + onFail(`Error parsing JSON from ${url}`); + } + return null; +} diff --git a/frontend/src/services/auth.test.ts b/frontend/src/services/auth.test.ts index b21843acfcae..64850c37917a 100644 --- a/frontend/src/services/auth.test.ts +++ b/frontend/src/services/auth.test.ts @@ -1,18 +1,5 @@ -import * as jose from "jose"; import type { Mock } from "vitest"; -import { fetchToken, validateToken, getToken } from "./auth"; - -vi.mock("jose", () => ({ - decodeJwt: vi.fn(), -})); - -// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/) -global.fetch = vi.fn(() => - Promise.resolve({ - status: 200, - json: () => Promise.resolve({ token: "newToken" }), - }), -) as Mock; +import { getToken } from "./auth"; Storage.prototype.getItem = vi.fn(); Storage.prototype.setItem = vi.fn(); @@ -22,66 +9,12 @@ describe("Auth Service", () => { vi.clearAllMocks(); }); - describe("fetchToken", () => { + describe("getToken", () => { it("should fetch and return a token", async () => { - const data = await fetchToken(); - + (Storage.prototype.getItem as Mock).mockReturnValue("newToken"); + const data = await getToken(); expect(localStorage.getItem).toHaveBeenCalledWith("token"); // Used to set Authorization header - expect(data).toEqual({ token: "newToken" }); - expect(fetch).toHaveBeenCalledWith(`/api/auth`, { - headers: expect.any(Headers), - }); - }); - - it("throws an error if response status is not 200", async () => { - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ status: 401 }), - ); - await expect(fetchToken()).rejects.toThrow("Get token failed."); - }); - }); - - describe("validateToken", () => { - it("returns true for a valid token", () => { - (jose.decodeJwt as Mock).mockReturnValue({ sid: "123" }); - expect(validateToken("validToken")).toBe(true); - }); - - it("returns false for an invalid token", () => { - (jose.decodeJwt as Mock).mockReturnValue({}); - expect(validateToken("invalidToken")).toBe(false); - }); - - it("returns false when decodeJwt throws", () => { - (jose.decodeJwt as Mock).mockImplementation(() => { - throw new Error("Invalid token"); - }); - expect(validateToken("badToken")).toBe(false); - }); - }); - - describe("getToken", () => { - it("returns existing valid token from localStorage", async () => { - (jose.decodeJwt as Mock).mockReturnValue({ sid: "123" }); - (Storage.prototype.getItem as Mock).mockReturnValue("existingToken"); - - const token = await getToken(); - expect(token).toBe("existingToken"); - }); - - it("fetches, validates, and stores a new token when existing token is invalid", async () => { - (jose.decodeJwt as Mock) - .mockReturnValueOnce({}) - .mockReturnValueOnce({ sid: "123" }); - - const token = await getToken(); - expect(token).toBe("newToken"); - expect(localStorage.setItem).toHaveBeenCalledWith("token", "newToken"); - }); - - it("throws an error when fetched token is invalid", async () => { - (jose.decodeJwt as Mock).mockReturnValue({}); - await expect(getToken()).rejects.toThrow("Token validation failed."); + expect(data).toEqual("newToken"); }); }); }); diff --git a/frontend/src/services/auth.ts b/frontend/src/services/auth.ts index 29fd75db65b8..a7d8cfa490b9 100644 --- a/frontend/src/services/auth.ts +++ b/frontend/src/services/auth.ts @@ -1,44 +1,13 @@ -import * as jose from "jose"; -import { ResFetchToken } from "#/types/ResponseType"; +const TOKEN_KEY = "token"; -const fetchToken = async (): Promise<ResFetchToken> => { - const headers = new Headers({ - "Content-Type": "application/json", - Authorization: `Bearer ${localStorage.getItem("token")}`, - }); - const response = await fetch(`/api/auth`, { headers }); - if (response.status !== 200) { - throw new Error("Get token failed."); - } - const data: ResFetchToken = await response.json(); - return data; -}; +const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? ""; -export const validateToken = (token: string): boolean => { - try { - const claims = jose.decodeJwt(token); - return !(claims.sid === undefined || claims.sid === ""); - } catch (error) { - return false; - } +const clearToken = (): void => { + localStorage.removeItem(TOKEN_KEY); }; -const getToken = async (): Promise<string> => { - const token = localStorage.getItem("token") ?? ""; - if (validateToken(token)) { - return token; - } - - const data = await fetchToken(); - if (data.token === undefined || data.token === "") { - throw new Error("Get token failed."); - } - const newToken = data.token; - if (validateToken(newToken)) { - localStorage.setItem("token", newToken); - return newToken; - } - throw new Error("Token validation failed."); +const setToken = (token: string): void => { + localStorage.setItem(TOKEN_KEY, token); }; -export { getToken, fetchToken }; +export { getToken, setToken, clearToken }; diff --git a/frontend/src/services/chatService.ts b/frontend/src/services/chatService.ts index 52cb4821e143..af1ab45ce86b 100644 --- a/frontend/src/services/chatService.ts +++ b/frontend/src/services/chatService.ts @@ -1,28 +1,8 @@ -import store from "#/store"; import ActionType from "#/types/ActionType"; -import { SocketMessage } from "#/types/ResponseType"; -import { ActionMessage } from "#/types/Message"; -import Socket from "./socket"; -import { addUserMessage } from "#/state/chatSlice"; +import Session from "./session"; export function sendChatMessage(message: string): void { const event = { action: ActionType.MESSAGE, args: { content: message } }; const eventString = JSON.stringify(event); - Socket.send(eventString); -} - -export function addChatMessageFromEvent(event: string | SocketMessage): void { - try { - let data: ActionMessage; - if (typeof event === "string") { - data = JSON.parse(event); - } else { - data = event as ActionMessage; - } - if (data && data.args && data.args.task) { - store.dispatch(addUserMessage(data.args.task)); - } - } catch (error) { - // - } + Session.send(eventString); } diff --git a/frontend/src/services/fileService.ts b/frontend/src/services/fileService.ts index 21e8f7f054d2..f12a0e499bca 100644 --- a/frontend/src/services/fileService.ts +++ b/frontend/src/services/fileService.ts @@ -1,9 +1,7 @@ +import { request } from "./api"; + export async function selectFile(file: string): Promise<string> { - const res = await fetch(`/api/select-file?file=${file}`); - const data = await res.json(); - if (res.status !== 200) { - throw new Error(data.error); - } + const data = await request(`/api/select-file?file=${file}`); return data.code as string; } @@ -13,20 +11,13 @@ export async function uploadFiles(files: FileList) { formData.append("files", files[i]); } - const res = await fetch("/api/upload-files", { + await request("/api/upload-files", { method: "POST", body: formData, }); - - const data = await res.json(); - - if (res.status !== 200) { - throw new Error(data.error || "Failed to upload files."); - } } export async function listFiles(path: string = "/"): Promise<string[]> { - const res = await fetch(`/api/list-files?path=${path}`); - const data = await res.json(); + const data = await request(`/api/list-files?path=${path}`); return data as string[]; } diff --git a/frontend/src/services/options.ts b/frontend/src/services/options.ts new file mode 100644 index 000000000000..e3216be55d79 --- /dev/null +++ b/frontend/src/services/options.ts @@ -0,0 +1,9 @@ +import { request } from "./api"; + +export async function fetchModels() { + return request(`/api/options/models`); +} + +export async function fetchAgents() { + return request(`/api/options/agents`); +} diff --git a/frontend/src/services/session.test.ts b/frontend/src/services/session.test.ts index 3ef95854804e..bf23b06bf5a0 100644 --- a/frontend/src/services/session.test.ts +++ b/frontend/src/services/session.test.ts @@ -1,127 +1,36 @@ -import type { Mock } from "vitest"; -import { - ResDelMsg, - ResFetchMsg, - ResFetchMsgTotal, - ResFetchMsgs, -} from "../types/ResponseType"; -import { clearMsgs, fetchMsgTotal, fetchMsgs } from "./session"; - -// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/) -global.fetch = vi.fn(); -Storage.prototype.getItem = vi.fn(); - -describe("Session Service", () => { - beforeEach(() => { - vi.clearAllMocks(); - }); - - afterEach(() => { - // Used to set Authorization header - expect(localStorage.getItem).toHaveBeenCalledWith("token"); - }); - - describe("fetchMsgTotal", () => { - it("should fetch and return message total", async () => { - const expectedResult: ResFetchMsgTotal = { - msg_total: 10, - }; - - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ - status: 200, - json: () => Promise.resolve(expectedResult), - }), - ); - - const data = await fetchMsgTotal(); - - expect(fetch).toHaveBeenCalledWith(`/api/messages/total`, { - headers: expect.any(Headers), - }); - - expect(data).toEqual(expectedResult); - }); - - it("throws an error if response status is not 200", async () => { - // NOTE: The current implementation ONLY handles 200 status; - // this means throwing even with a status of 201, 204, etc. - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ status: 401 }), - ); - - await expect(fetchMsgTotal()).rejects.toThrow( - "Get message total failed.", - ); - }); - }); - - describe("fetchMsgs", () => { - it("should fetch and return messages", async () => { - const expectedResult: ResFetchMsgs = { - messages: [ - { - id: "1", - role: "user", - payload: {} as ResFetchMsg["payload"], - }, - ], - }; - - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ - status: 200, - json: () => Promise.resolve(expectedResult), - }), - ); - - const data = await fetchMsgs(); - - expect(fetch).toHaveBeenCalledWith(`/api/messages`, { - headers: expect.any(Headers), - }); - - expect(data).toEqual(expectedResult); - }); - - it("throws an error if response status is not 200", async () => { - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ status: 401 }), - ); - - await expect(fetchMsgs()).rejects.toThrow("Get messages failed."); - }); +import { describe, expect, it, vi } from "vitest"; + +import ActionType from "#/types/ActionType"; +import { Settings, saveSettings } from "./settings"; +import Session from "./session"; + +const sendSpy = vi.spyOn(Session, "send"); +const setupSpy = vi + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ + .spyOn(Session as any, "_setupSocket") + .mockImplementation(() => { + /* eslint-disable-next-line @typescript-eslint/dot-notation */ + Session["_initializeAgent"](); // use key syntax to fix complaint about private fn }); - describe("clearMsgs", () => { - it("should clear messages", async () => { - const expectedResult: ResDelMsg = { - ok: "true", - }; - - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ - status: 200, - json: () => Promise.resolve(expectedResult), - }), - ); - - const data = await clearMsgs(); - - expect(fetch).toHaveBeenCalledWith(`/api/messages`, { - method: "DELETE", - headers: expect.any(Headers), - }); - - expect(data).toEqual(expectedResult); - }); - - it("throws an error if response status is not 200", async () => { - (fetch as Mock).mockImplementationOnce(() => - Promise.resolve({ status: 401 }), - ); - - await expect(clearMsgs()).rejects.toThrow("Delete messages failed."); - }); +describe("startNewSession", () => { + it("Should start a new session with the current settings", () => { + const settings: Settings = { + LLM_MODEL: "llm_value", + AGENT: "agent_value", + LANGUAGE: "language_value", + LLM_API_KEY: "sk-...", + }; + + const event = { + action: ActionType.INIT, + args: settings, + }; + + saveSettings(settings); + Session.startNewSession(); + + expect(setupSpy).toHaveBeenCalledTimes(1); + expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event)); }); }); diff --git a/frontend/src/services/session.ts b/frontend/src/services/session.ts index 91b6e98fe1e6..5b889952a699 100644 --- a/frontend/src/services/session.ts +++ b/frontend/src/services/session.ts @@ -1,49 +1,165 @@ -import { - ResDelMsg, - ResFetchMsgs, - ResFetchMsgTotal, -} from "../types/ResponseType"; - -const fetchMsgTotal = async (): Promise<ResFetchMsgTotal> => { - const headers = new Headers({ - "Content-Type": "application/json", - Authorization: `Bearer ${localStorage.getItem("token")}`, - }); - const response = await fetch(`/api/messages/total`, { headers }); - if (response.status !== 200) { - throw new Error("Get message total failed."); - } - const data: ResFetchMsgTotal = await response.json(); - return data; -}; - -const fetchMsgs = async (): Promise<ResFetchMsgs> => { - const headers = new Headers({ - "Content-Type": "application/json", - Authorization: `Bearer ${localStorage.getItem("token")}`, - }); - const response = await fetch(`/api/messages`, { headers }); - if (response.status !== 200) { - throw new Error("Get messages failed."); - } - const data: ResFetchMsgs = await response.json(); - return data; -}; - -const clearMsgs = async (): Promise<ResDelMsg> => { - const headers = new Headers({ - "Content-Type": "application/json", - Authorization: `Bearer ${localStorage.getItem("token")}`, - }); - const response = await fetch(`/api/messages`, { - method: "DELETE", - headers, - }); - if (response.status !== 200) { - throw new Error("Delete messages failed."); - } - const data: ResDelMsg = await response.json(); - return data; -}; - -export { fetchMsgTotal, fetchMsgs, clearMsgs }; +import toast from "#/utils/toast"; +import { handleAssistantMessage } from "./actions"; +import { getToken, setToken, clearToken } from "./auth"; +import ActionType from "#/types/ActionType"; +import { getSettings } from "./settings"; + +class Session { + private static _socket: WebSocket | null = null; + + // callbacks contain a list of callable functions + // event: function, like: + // open: [function1, function2] + // message: [function1, function2] + private static callbacks: { + [K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[]; + } = { + open: [], + message: [], + error: [], + close: [], + }; + + private static _connecting = false; + + private static _disconnecting = false; + + public static restoreOrStartNewSession() { + const token = getToken(); + if (Session.isConnected()) { + Session.disconnect(); + } + Session._connect(token); + } + + public static startNewSession() { + clearToken(); + Session.restoreOrStartNewSession(); + } + + private static _initializeAgent = () => { + const settings = getSettings(); + const event = { action: ActionType.INIT, args: settings }; + const eventString = JSON.stringify(event); + Session.send(eventString); + }; + + private static _connect(token: string = ""): void { + if (Session.isConnected()) return; + Session._connecting = true; + + const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; + const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`; + Session._socket = new WebSocket(WS_URL); + Session._setupSocket(); + } + + private static _setupSocket(): void { + if (!Session._socket) { + throw new Error("Socket is not initialized."); + } + Session._socket.onopen = (e) => { + toast.success("ws", "Connected to server."); + Session._connecting = false; + Session._initializeAgent(); + Session.callbacks.open?.forEach((callback) => { + callback(e); + }); + }; + + Session._socket.onmessage = (e) => { + let data = null; + try { + data = JSON.parse(e.data); + } catch (err) { + // TODO: report the error + console.error("Error parsing JSON data", err); + return; + } + if (data.error && data.error_code === 401) { + clearToken(); + } else if (data.token) { + setToken(data.token); + } else { + handleAssistantMessage(data); + } + }; + + Session._socket.onerror = () => { + const msg = "Connection failed. Retry..."; + toast.error("ws", msg); + }; + + Session._socket.onclose = () => { + if (!Session._disconnecting) { + setTimeout(() => { + Session.restoreOrStartNewSession(); + }, 3000); // Reconnect after 3 seconds + } + Session._disconnecting = false; + }; + } + + static isConnected(): boolean { + return ( + Session._socket !== null && Session._socket.readyState === WebSocket.OPEN + ); + } + + static disconnect(): void { + Session._disconnecting = true; + if (Session._socket) { + Session._socket.close(); + } + Session._socket = null; + } + + static send(message: string): void { + if (Session._connecting) { + setTimeout(() => Session.send(message), 1000); + return; + } + if (!Session.isConnected()) { + throw new Error("Not connected to server."); + } + + if (Session.isConnected()) { + Session._socket?.send(message); + } else { + const msg = "Connection failed. Retry..."; + toast.error("ws", msg); + } + } + + static addEventListener( + event: string, + callback: (e: MessageEvent) => void, + ): void { + Session._socket?.addEventListener( + event as keyof WebSocketEventMap, + callback as ( + this: WebSocket, + ev: WebSocketEventMap[keyof WebSocketEventMap], + ) => never, + ); + } + + static removeEventListener( + event: string, + listener: (e: Event) => void, + ): void { + Session._socket?.removeEventListener(event, listener); + } + + static registerCallback<K extends keyof WebSocketEventMap>( + event: K, + callbacks: ((data: WebSocketEventMap[K]) => void)[], + ): void { + if (Session.callbacks[event] === undefined) { + return; + } + Session.callbacks[event].push(...callbacks); + } +} + +export default Session; diff --git a/frontend/src/services/socket.ts b/frontend/src/services/socket.ts deleted file mode 100644 index 981c25978371..000000000000 --- a/frontend/src/services/socket.ts +++ /dev/null @@ -1,129 +0,0 @@ -// import { toast } from "sonner"; -import toast from "#/utils/toast"; -import { handleAssistantMessage } from "./actions"; -import { getToken } from "./auth"; - -class Socket { - private static _socket: WebSocket | null = null; - - // callbacks contain a list of callable functions - // event: function, like: - // open: [function1, function2] - // message: [function1, function2] - private static callbacks: { - [K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[]; - } = { - open: [], - message: [], - error: [], - close: [], - }; - - private static initializing = false; - - public static tryInitialize(): void { - if (Socket.initializing) return; - Socket.initializing = true; - getToken() - .then((token) => { - Socket._initialize(token); - }) - .catch(() => { - const msg = `Connection failed. Retry...`; - toast.stickyError("ws", msg); - - setTimeout(() => { - this.tryInitialize(); - }, 1500); - }); - } - - private static _initialize(token: string): void { - if (Socket.isConnected()) return; - - const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; - const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`; - Socket._socket = new WebSocket(WS_URL); - - Socket._socket.onopen = (e) => { - toast.stickySuccess("ws", "Connected to server."); - Socket.initializing = false; - Socket.callbacks.open?.forEach((callback) => { - callback(e); - }); - }; - - Socket._socket.onmessage = (e) => { - handleAssistantMessage(e.data); - }; - - Socket._socket.onerror = () => { - const msg = "Connection failed. Retry..."; - toast.stickyError("ws", msg); - }; - - Socket._socket.onclose = () => { - // Reconnect after a delay - setTimeout(() => { - Socket.tryInitialize(); - }, 3000); // Reconnect after 3 seconds - }; - } - - static isConnected(): boolean { - return ( - Socket._socket !== null && Socket._socket.readyState === WebSocket.OPEN - ); - } - - static send(message: string): void { - if (!Socket.isConnected()) { - Socket.tryInitialize(); - } - if (Socket.initializing) { - setTimeout(() => Socket.send(message), 1000); - return; - } - - if (Socket.isConnected()) { - Socket._socket?.send(message); - } else { - const msg = "Connection failed. Retry..."; - toast.stickyError("ws", msg); - } - } - - static addEventListener( - event: string, - callback: (e: MessageEvent) => void, - ): void { - Socket._socket?.addEventListener( - event as keyof WebSocketEventMap, - callback as ( - this: WebSocket, - ev: WebSocketEventMap[keyof WebSocketEventMap], - ) => never, - ); - } - - static removeEventListener( - event: string, - listener: (e: Event) => void, - ): void { - Socket._socket?.removeEventListener(event, listener); - } - - static registerCallback<K extends keyof WebSocketEventMap>( - event: K, - callbacks: ((data: WebSocketEventMap[K]) => void)[], - ): void { - if (Socket.callbacks[event] === undefined) { - return; - } - Socket.callbacks[event].push(...callbacks); - } -} - -Socket.tryInitialize(); - -export default Socket; diff --git a/frontend/src/services/taskService.ts b/frontend/src/services/taskService.ts index 88b877abf323..fa84444e294d 100644 --- a/frontend/src/services/taskService.ts +++ b/frontend/src/services/taskService.ts @@ -1,3 +1,5 @@ +import { request } from "./api"; + export type Task = { id: string; goal: string; @@ -14,14 +16,6 @@ export enum TaskState { } export async function getRootTask(): Promise<Task | undefined> { - const headers = new Headers({ - "Content-Type": "application/json", - Authorization: `Bearer ${localStorage.getItem("token")}`, - }); - const res = await fetch("/api/root_task", { headers }); - if (res.status !== 200 && res.status !== 204) { - return undefined; - } - const data = (await res.json()) as Task; - return data; + const res = await request("/api/root_task"); + return res as Task; } diff --git a/frontend/src/state/chatSlice.ts b/frontend/src/state/chatSlice.ts index 1438103f9126..757806e1e048 100644 --- a/frontend/src/state/chatSlice.ts +++ b/frontend/src/state/chatSlice.ts @@ -3,13 +3,7 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit"; type SliceState = { messages: Message[] }; const initialState: SliceState = { - messages: [ - { - content: - "Hi! I'm OpenDevin, an AI Software Engineer. What would you like to build with me today?", - sender: "assistant", - }, - ], + messages: [], }; export const chatSlice = createSlice({ diff --git a/frontend/src/types/ResponseType.tsx b/frontend/src/types/ResponseType.tsx index 63d710e1eeed..b635d78c3370 100644 --- a/frontend/src/types/ResponseType.tsx +++ b/frontend/src/types/ResponseType.tsx @@ -1,41 +1,5 @@ import { ActionMessage, ObservationMessage } from "./Message"; -type Role = "user" | "assistant"; - -interface ResConfigurations { - [key: string]: string | boolean | number; -} - -interface ResFetchToken { - token: string; -} - -interface ResFetchMsgTotal { - msg_total: number; -} - -interface ResFetchMsg { - id: string; - role: Role; - payload: SocketMessage; -} - -interface ResFetchMsgs { - messages: ResFetchMsg[]; -} - -interface ResDelMsg { - ok: string; -} - type SocketMessage = ActionMessage | ObservationMessage; -export { - type ResConfigurations, - type ResFetchToken, - type ResFetchMsgTotal, - type ResFetchMsg, - type ResFetchMsgs, - type ResDelMsg, - type SocketMessage, -}; +export { type SocketMessage }; diff --git a/frontend/src/utils/toast.tsx b/frontend/src/utils/toast.tsx index 08debd157475..132b3497c28b 100644 --- a/frontend/src/utils/toast.tsx +++ b/frontend/src/utils/toast.tsx @@ -3,15 +3,10 @@ import toast from "react-hot-toast"; const idMap = new Map<string, string>(); export default { - stickyError: (id: string, msg: string) => { + error: (id: string, msg: string) => { if (idMap.has(id)) return; // prevent duplicate toast - const toastId = toast.loading(msg, { - // icon: "👏", - // style: { - // borderRadius: "10px", - // background: "#333", - // color: "#fff", - // }, + const toastId = toast(msg, { + duration: 4000, style: { background: "#ef4444", color: "#fff", @@ -24,12 +19,13 @@ export default { }); idMap.set(id, toastId); }, - stickySuccess: (id: string, msg: string) => { + success: (id: string, msg: string) => { const toastId = idMap.get(id); if (toastId === undefined) return; if (toastId) { toast.success(msg, { id: toastId, + duration: 4000, style: { background: "#333", color: "#fff", diff --git a/opendevin/const/guide_url.py b/opendevin/const/guide_url.py deleted file mode 100644 index 80bb3cfa985c..000000000000 --- a/opendevin/const/guide_url.py +++ /dev/null @@ -1 +0,0 @@ -TROUBLESHOOTING_URL = 'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting' diff --git a/opendevin/controller/agent_controller.py b/opendevin/controller/agent_controller.py index cb374f9a8021..13db2501c5a2 100644 --- a/opendevin/controller/agent_controller.py +++ b/opendevin/controller/agent_controller.py @@ -243,6 +243,9 @@ async def _step(self): def get_state(self): return self.state + def set_state(self, state: State): + self.state = state + def _is_stuck(self): # check if delegate stuck if self.delegate and self.delegate._is_stuck(): diff --git a/opendevin/core/config.py b/opendevin/core/config.py index aba7708c8f97..1de2b78a5a04 100644 --- a/opendevin/core/config.py +++ b/opendevin/core/config.py @@ -3,6 +3,7 @@ import os import pathlib import platform +import uuid from dataclasses import dataclass, field, fields, is_dataclass from types import UnionType from typing import Any, ClassVar, get_args, get_origin @@ -173,6 +174,7 @@ class AppConfig(metaclass=Singleton): sandbox_user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000 sandbox_timeout: int = 120 github_token: str | None = None + jwt_secret: str = uuid.uuid4().hex debug: bool = False enable_auto_lint: bool = ( False # once enabled, OpenDevin would lint files after editing diff --git a/opendevin/core/const/guide_url.py b/opendevin/core/const/guide_url.py new file mode 100644 index 000000000000..7ec5e6c908ac --- /dev/null +++ b/opendevin/core/const/guide_url.py @@ -0,0 +1,3 @@ +TROUBLESHOOTING_URL = ( + 'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting' +) diff --git a/opendevin/runtime/docker/exec_box.py b/opendevin/runtime/docker/exec_box.py index dd283bf87551..6af621d2517e 100644 --- a/opendevin/runtime/docker/exec_box.py +++ b/opendevin/runtime/docker/exec_box.py @@ -10,8 +10,8 @@ import docker -from opendevin.const.guide_url import TROUBLESHOOTING_URL from opendevin.core.config import config +from opendevin.core.const.guide_url import TROUBLESHOOTING_URL from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError from opendevin.core.logger import opendevin_logger as logger from opendevin.core.schema import CancellableStream diff --git a/opendevin/runtime/docker/ssh_box.py b/opendevin/runtime/docker/ssh_box.py index 9328deed913d..94fa04c5c383 100644 --- a/opendevin/runtime/docker/ssh_box.py +++ b/opendevin/runtime/docker/ssh_box.py @@ -12,8 +12,8 @@ import docker from pexpect import exceptions, pxssh -from opendevin.const.guide_url import TROUBLESHOOTING_URL from opendevin.core.config import config +from opendevin.core.const.guide_url import TROUBLESHOOTING_URL from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError from opendevin.core.logger import opendevin_logger as logger from opendevin.core.schema import CancellableStream diff --git a/opendevin/runtime/e2b/filestore.py b/opendevin/runtime/e2b/filestore.py new file mode 100644 index 000000000000..0494dde21a51 --- /dev/null +++ b/opendevin/runtime/e2b/filestore.py @@ -0,0 +1,18 @@ +from opendevin.storage.files import FileStore + + +class E2BFileStore(FileStore): + def __init__(self, filesystem): + self.filesystem = filesystem + + def write(self, path: str, contents: str) -> None: + self.filesystem.write(path, contents) + + def read(self, path: str) -> str: + return self.filesystem.read(path) + + def list(self, path: str) -> list[str]: + return self.filesystem.list(path) + + def delete(self, path: str) -> None: + self.filesystem.delete(path) diff --git a/opendevin/runtime/e2b/runtime.py b/opendevin/runtime/e2b/runtime.py index 5542ae63af12..1a0710bf91ff 100644 --- a/opendevin/runtime/e2b/runtime.py +++ b/opendevin/runtime/e2b/runtime.py @@ -13,6 +13,7 @@ from opendevin.runtime.server.files import insert_lines, read_lines from opendevin.runtime.server.runtime import ServerRuntime +from .filestore import E2BFileStore from .sandbox import E2BSandbox @@ -26,25 +27,25 @@ def __init__( super().__init__(event_stream, sid, sandbox) if not isinstance(self.sandbox, E2BSandbox): raise ValueError('E2BRuntime requires an E2BSandbox') - self.filesystem = self.sandbox.filesystem + self.file_store = E2BFileStore(self.sandbox.filesystem) async def read(self, action: FileReadAction) -> Observation: - content = self.filesystem.read(action.path) + content = self.file_store.read(action.path) lines = read_lines(content.split('\n'), action.start, action.end) code_view = ''.join(lines) return FileReadObservation(code_view, path=action.path) async def write(self, action: FileWriteAction) -> Observation: if action.start == 0 and action.end == -1: - self.filesystem.write(action.path, action.content) + self.file_store.write(action.path, action.content) return FileWriteObservation(content='', path=action.path) - files = self.filesystem.list(action.path) + files = self.file_store.list(action.path) if action.path in files: - all_lines = self.filesystem.read(action.path) + all_lines = self.file_store.read(action.path).split('\n') new_file = insert_lines( action.content.split('\n'), all_lines, action.start, action.end ) - self.filesystem.write(action.path, ''.join(new_file)) + self.file_store.write(action.path, ''.join(new_file)) return FileWriteObservation('', path=action.path) else: # FIXME: we should create a new file here diff --git a/opendevin/runtime/runtime.py b/opendevin/runtime/runtime.py index 9a52fbff9392..79a7120e52f3 100644 --- a/opendevin/runtime/runtime.py +++ b/opendevin/runtime/runtime.py @@ -31,6 +31,7 @@ ) from opendevin.runtime.browser.browser_env import BrowserEnv from opendevin.runtime.plugins import PluginRequirement +from opendevin.storage import FileStore, InMemoryFileStore def create_sandbox(sid: str = 'default', sandbox_type: str = 'exec') -> Sandbox: @@ -55,6 +56,7 @@ class Runtime: """ sid: str + file_store: FileStore def __init__( self, @@ -70,6 +72,7 @@ def __init__( self.sandbox = sandbox self._is_external_sandbox = True self.browser = BrowserEnv() + self.file_store = InMemoryFileStore() self.event_stream = event_stream self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event) self._bg_task = asyncio.create_task(self._start_background_observation_loop()) diff --git a/opendevin/runtime/server/runtime.py b/opendevin/runtime/server/runtime.py index b399ff9badea..8f04984d591e 100644 --- a/opendevin/runtime/server/runtime.py +++ b/opendevin/runtime/server/runtime.py @@ -1,3 +1,4 @@ +from opendevin.core.config import config from opendevin.events.action import ( AgentRecallAction, BrowseInteractiveAction, @@ -15,13 +16,25 @@ NullObservation, Observation, ) +from opendevin.events.stream import EventStream +from opendevin.runtime import Sandbox from opendevin.runtime.runtime import Runtime +from opendevin.storage.local import LocalFileStore from .browse import browse from .files import read_file, write_file class ServerRuntime(Runtime): + def __init__( + self, + event_stream: EventStream, + sid: str = 'default', + sandbox: Sandbox | None = None, + ): + super().__init__(event_stream, sid, sandbox) + self.file_store = LocalFileStore(config.workspace_base) + async def run(self, action: CmdRunAction) -> Observation: return self._run_command(action.command, background=action.background) @@ -71,10 +84,12 @@ async def run_ipython(self, action: IPythonRunCellAction) -> Observation: return IPythonRunCellObservation(content=output, code=action.code) async def read(self, action: FileReadAction) -> Observation: + # TODO: use self.file_store working_dir = self.sandbox.get_working_directory() return await read_file(action.path, working_dir, action.start, action.end) async def write(self, action: FileWriteAction) -> Observation: + # TODO: use self.file_store working_dir = self.sandbox.get_working_directory() return await write_file( action.path, working_dir, action.content, action.start, action.end diff --git a/opendevin/server/agent/__init__.py b/opendevin/server/agent/__init__.py deleted file mode 100644 index bfb0bc1f1a92..000000000000 --- a/opendevin/server/agent/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .manager import AgentManager - -agent_manager = AgentManager() - -__all__ = ['AgentManager', 'agent_manager'] diff --git a/opendevin/server/agent/agent.py b/opendevin/server/agent/agent.py deleted file mode 100644 index a897815e9600..000000000000 --- a/opendevin/server/agent/agent.py +++ /dev/null @@ -1,161 +0,0 @@ -from typing import Optional - -from agenthub.codeact_agent.codeact_agent import CodeActAgent -from opendevin.const.guide_url import TROUBLESHOOTING_URL -from opendevin.controller import AgentController -from opendevin.controller.agent import Agent -from opendevin.core.config import config -from opendevin.core.logger import opendevin_logger as logger -from opendevin.core.schema import ActionType, AgentState, ConfigType -from opendevin.events.action import ( - ChangeAgentStateAction, - NullAction, -) -from opendevin.events.event import Event -from opendevin.events.observation import ( - NullObservation, -) -from opendevin.events.serialization.action import action_from_dict -from opendevin.events.serialization.event import event_to_dict -from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber -from opendevin.llm.llm import LLM -from opendevin.runtime import DockerSSHBox -from opendevin.runtime.e2b.runtime import E2BRuntime -from opendevin.runtime.runtime import Runtime -from opendevin.runtime.server.runtime import ServerRuntime -from opendevin.server.session import session_manager - - -class AgentUnit: - """Represents a session with an agent. - - Attributes: - controller: The AgentController instance for controlling the agent. - """ - - sid: str - event_stream: EventStream - controller: Optional[AgentController] = None - runtime: Optional[Runtime] = None - - def __init__(self, sid): - """Initializes a new instance of the Session class.""" - self.sid = sid - self.event_stream = EventStream(sid) - self.event_stream.subscribe(EventStreamSubscriber.SERVER, self.on_event) - if config.runtime == 'server': - logger.info('Using server runtime') - self.runtime = ServerRuntime(self.event_stream, sid) - elif config.runtime == 'e2b': - logger.info('Using E2B runtime') - self.runtime = E2BRuntime(self.event_stream, sid) - - async def send_error(self, message): - """Sends an error message to the client. - - Args: - message: The error message to send. - """ - await session_manager.send_error(self.sid, message) - - async def send_message(self, message): - """Sends a message to the client. - - Args: - message: The message to send. - """ - await session_manager.send_message(self.sid, message) - - async def send(self, data): - """Sends data to the client. - - Args: - data: The data to send. - """ - await session_manager.send(self.sid, data) - - async def dispatch(self, action: str | None, data: dict): - """Dispatches actions to the agent from the client.""" - if action is None: - await self.send_error('Invalid action') - return - - if action == ActionType.INIT: - await self.create_controller(data) - await self.event_stream.add_event( - ChangeAgentStateAction(AgentState.INIT), EventSource.USER - ) - return - - action_dict = data.copy() - action_dict['action'] = action - action_obj = action_from_dict(action_dict) - await self.event_stream.add_event(action_obj, EventSource.USER) - - async def create_controller(self, start_event: dict): - """Creates an AgentController instance. - - Args: - start_event: The start event data (optional). - """ - args = { - key: value - for key, value in start_event.get('args', {}).items() - if value != '' - } # remove empty values, prevent FE from sending empty strings - agent_cls = args.get(ConfigType.AGENT, config.agent.name) - model = args.get(ConfigType.LLM_MODEL, config.llm.model) - api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key) - api_base = config.llm.base_url - max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations) - max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars) - - logger.info(f'Creating agent {agent_cls} using LLM {model}') - llm = LLM(model=model, api_key=api_key, base_url=api_base) - agent = Agent.get_cls(agent_cls)(llm) - if isinstance(agent, CodeActAgent): - if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox): - logger.warning( - 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.' - ) - # Initializing plugins into the runtime - assert self.runtime is not None, 'Runtime is not initialized' - self.runtime.init_sandbox_plugins(agent.sandbox_plugins) - - if self.controller is not None: - await self.controller.close() - try: - self.controller = AgentController( - sid=self.sid, - event_stream=self.event_stream, - agent=agent, - max_iterations=int(max_iterations), - max_chars=int(max_chars), - ) - except Exception as e: - logger.exception(f'Error creating controller: {e}') - await self.send_error( - f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..' - ) - return - - async def on_event(self, event: Event): - """Callback function for agent events. - - Args: - event: The agent event (Observation or Action). - """ - if isinstance(event, NullAction): - return - if isinstance(event, NullObservation): - return - if event.source == 'agent' and not isinstance( - event, (NullAction, NullObservation) - ): - await self.send(event_to_dict(event)) - - async def close(self): - if self.controller is not None: - await self.controller.close() - if self.runtime is not None: - self.runtime.close() diff --git a/opendevin/server/agent/manager.py b/opendevin/server/agent/manager.py deleted file mode 100644 index 5f9593ef6428..000000000000 --- a/opendevin/server/agent/manager.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio, atexit - -from opendevin.core.logger import opendevin_logger as logger -from opendevin.server.session import session_manager - -from .agent import AgentUnit - - -class AgentManager: - sid_to_agent: dict[str, 'AgentUnit'] = {} - - def __init__(self): - atexit.register(self.close) - - def register_agent(self, sid: str): - """Registers a new agent. - - Args: - sid: The session ID of the agent. - """ - if sid not in self.sid_to_agent: - self.sid_to_agent[sid] = AgentUnit(sid) - return - - # TODO: confirm whether the agent is alive - - async def dispatch(self, sid: str, action: str | None, data: dict): - """Dispatches actions to the agent from the client.""" - if sid not in self.sid_to_agent: - # self.register_agent(sid) # auto-register agent, may be opened later - logger.error(f'Agent not registered: {sid}') - await session_manager.send_error(sid, 'Agent not registered') - return - - await self.sid_to_agent[sid].dispatch(action, data) - - def close(self): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._close()) - - async def _close(self): - logger.info(f'Closing {len(self.sid_to_agent)} agent(s)...') - for sid, agent in self.sid_to_agent.items(): - await agent.close() diff --git a/opendevin/server/auth/auth.py b/opendevin/server/auth/auth.py index 98507e8f81ea..36cadd98428a 100644 --- a/opendevin/server/auth/auth.py +++ b/opendevin/server/auth/auth.py @@ -1,12 +1,9 @@ -import os - import jwt from jwt.exceptions import InvalidTokenError +from opendevin.core.config import config from opendevin.core.logger import opendevin_logger as logger -JWT_SECRET = os.getenv('JWT_SECRET', '5ecRe7') - def get_sid_from_token(token: str) -> str: """ @@ -20,7 +17,7 @@ def get_sid_from_token(token: str) -> str: """ try: # Decode the JWT using the specified secret and algorithm - payload = jwt.decode(token, JWT_SECRET, algorithms=['HS256']) + payload = jwt.decode(token, config.jwt_secret, algorithms=['HS256']) # Ensure the payload contains 'sid' if 'sid' in payload: @@ -41,4 +38,4 @@ def sign_token(payload: dict[str, object]) -> str: # "sid": sid, # # "exp": datetime.now(timezone.utc) + timedelta(minutes=15), # } - return jwt.encode(payload, JWT_SECRET, algorithm='HS256') + return jwt.encode(payload, config.jwt_secret, algorithm='HS256') diff --git a/opendevin/server/listen.py b/opendevin/server/listen.py index cf21cbfa2deb..5bdeb15b8199 100644 --- a/opendevin/server/listen.py +++ b/opendevin/server/listen.py @@ -1,26 +1,25 @@ -import os -import shutil import uuid import warnings -from pathlib import Path with warnings.catch_warnings(): warnings.simplefilter('ignore') import litellm -from fastapi import Depends, FastAPI, Request, Response, UploadFile, WebSocket, status +from fastapi import FastAPI, Request, Response, UploadFile, WebSocket, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, RedirectResponse -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import HTTPBearer from fastapi.staticfiles import StaticFiles import agenthub # noqa F401 (we import this to get the agents registered) from opendevin.controller.agent import Agent from opendevin.core.config import config from opendevin.core.logger import opendevin_logger as logger +from opendevin.events.action import ChangeAgentStateAction, NullAction +from opendevin.events.observation import AgentStateChangedObservation, NullObservation +from opendevin.events.serialization import event_to_dict from opendevin.llm import bedrock -from opendevin.server.agent import agent_manager from opendevin.server.auth import get_sid_from_token, sign_token -from opendevin.server.session import message_stack, session_manager +from opendevin.server.session import session_manager app = FastAPI() app.add_middleware( @@ -34,6 +33,45 @@ security_scheme = HTTPBearer() +@app.middleware('http') +async def attach_session(request: Request, call_next): + if request.url.path.startswith('/api/options/') or not request.url.path.startswith( + '/api/' + ): + response = await call_next(request) + return response + + if not request.headers.get('Authorization'): + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Missing Authorization header'}, + ) + return response + + auth_token = request.headers.get('Authorization') + if 'Bearer' in auth_token: + auth_token = auth_token.split('Bearer')[1].strip() + + request.state.sid = get_sid_from_token(auth_token) + if request.state.sid == '': + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={'error': 'Invalid token'}, + ) + return response + + request.state.session = session_manager.get_session(request.state.sid) + if request.state.session is None: + response = JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'Session not found'}, + ) + return response + + response = await call_next(request) + return response + + # This endpoint receives events from the client (i.e. the browser) @app.websocket('/ws') async def websocket_endpoint(websocket: WebSocket): @@ -99,16 +137,41 @@ async def websocket_endpoint(websocket: WebSocket): ``` """ await websocket.accept() - sid = get_sid_from_token(websocket.query_params.get('token') or '') - if sid == '': - logger.error('Failed to decode token') - return - session_manager.add_session(sid, websocket) - agent_manager.register_agent(sid) - await session_manager.loop_recv(sid, agent_manager.dispatch) + session = None + if websocket.query_params.get('token'): + token = websocket.query_params.get('token') + sid = get_sid_from_token(token) + + if sid == '': + await websocket.send_json({'error': 'Invalid token', 'error_code': 401}) + await websocket.close() + return + else: + sid = str(uuid.uuid4()) + token = sign_token({'sid': sid}) + + session = session_manager.add_or_restart_session(sid, websocket) + await websocket.send_json({'token': token, 'status': 'ok'}) -@app.get('/api/litellm-models') + last_event_id = -1 + if websocket.query_params.get('last_event_id'): + last_event_id = int(websocket.query_params.get('last_event_id')) + for event in session.agent_session.event_stream.get_events( + start_id=last_event_id + 1 + ): + if isinstance(event, NullAction) or isinstance(event, NullObservation): + continue + if isinstance(event, ChangeAgentStateAction) or isinstance( + event, AgentStateChangedObservation + ): + continue + await websocket.send_json(event_to_dict(event)) + + await session.loop_recv() + + +@app.get('/api/options/models') async def get_litellm_models(): """ Get all models supported by LiteLLM. @@ -128,7 +191,7 @@ async def get_litellm_models(): return list(set(model_list)) -@app.get('/api/agents') +@app.get('/api/options/agents') async def get_agents(): """ Get all agents supported by LiteLLM. @@ -142,89 +205,6 @@ async def get_agents(): return agents -@app.get('/api/auth') -async def get_token( - credentials: HTTPAuthorizationCredentials = Depends(security_scheme), -): - """ - Generate a JWT for authentication when starting a WebSocket connection. This endpoint checks if valid credentials - are provided and uses them to get a session ID. If no valid credentials are provided, it generates a new session ID. - - To obtain an authentication token: - ```sh - curl -H "Authorization: Bearer 5ecRe7" http://localhost:3000/api/auth - ``` - **Note:** If `JWT_SECRET` is set, use its value instead of `5ecRe7`. - """ - if credentials and credentials.credentials: - sid = get_sid_from_token(credentials.credentials) - if not sid: - sid = str(uuid.uuid4()) - logger.info( - f'Invalid or missing credentials, generating new session ID: {sid}' - ) - else: - sid = str(uuid.uuid4()) - logger.info(f'No credentials provided, generating new session ID: {sid}') - - token = sign_token({'sid': sid}) - return {'token': token, 'status': 'ok'} - - -@app.get('/api/messages') -async def get_messages( - credentials: HTTPAuthorizationCredentials = Depends(security_scheme), -): - """ - Get messages. - - To get messages: - ```sh - curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages - ``` - """ - data = [] - sid = get_sid_from_token(credentials.credentials) - if sid != '': - data = message_stack.get_messages(sid) - - return {'messages': data} - - -@app.get('/api/messages/total') -async def get_message_total( - credentials: HTTPAuthorizationCredentials = Depends(security_scheme), -): - """ - Get total message count. - - To get the total message count: - ```sh - curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages/total - ``` - """ - sid = get_sid_from_token(credentials.credentials) - return {'msg_total': message_stack.get_message_total(sid)} - - -@app.delete('/api/messages') -async def del_messages( - credentials: HTTPAuthorizationCredentials = Depends(security_scheme), -): - """ - Delete messages. - - To delete messages: - ```sh - - curl -X DELETE -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages - ``` - """ - sid = get_sid_from_token(credentials.credentials) - message_stack.del_messages(sid) - return {'ok': True} - - @app.get('/api/list-files') def list_files(request: Request, path: str = '/'): """ @@ -235,27 +215,25 @@ def list_files(request: Request, path: str = '/'): curl http://localhost:3000/api/list-files ``` """ - if path.startswith('/'): - path = path[1:] - abs_path = os.path.join(config.workspace_base, path) + if not request.state.session.agent_session.runtime: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={'error': 'Runtime not yet initialized'}, + ) + try: - files = os.listdir(abs_path) + return request.state.session.agent_session.runtime.file_store.list(path) except Exception as e: - logger.error(f'Error listing files: {e}', exc_info=False) + logger.error(f'Error refreshing files: {e}', exc_info=False) + error_msg = f'Error refreshing files: {e}' return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content={'error': 'Path not found'}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={'error': error_msg}, ) - files = [os.path.join(path, f) for f in files] - files = [ - f + '/' if os.path.isdir(os.path.join(config.workspace_base, f)) else f - for f in files - ] - return files @app.get('/api/select-file') -def select_file(file: str): +def select_file(file: str, request: Request): """ Select a file. @@ -265,12 +243,7 @@ def select_file(file: str): ``` """ try: - workspace_base = config.workspace_base - file_path = Path(workspace_base, file) - # The following will check if the file is within the workspace base and throw an exception if not - file_path.resolve().relative_to(Path(workspace_base).resolve()) - with open(file_path, 'r') as selected_file: - content = selected_file.read() + content = request.state.session.agent_session.runtime.file_store.read(file) except Exception as e: logger.error(f'Error opening file {file}: {e}', exc_info=False) error_msg = f'Error opening file: {e}' @@ -282,7 +255,7 @@ def select_file(file: str): @app.post('/api/upload-files') -async def upload_files(files: list[UploadFile]): +async def upload_file(request: Request, files: list[UploadFile]): """ Upload files to the workspace. @@ -292,13 +265,11 @@ async def upload_files(files: list[UploadFile]): ``` """ try: - workspace_base = config.workspace_base for file in files: - file_path = Path(workspace_base, file.filename) - # The following will check if the file is within the workspace base and throw an exception if not - file_path.resolve().relative_to(Path(workspace_base).resolve()) - with open(file_path, 'wb') as buffer: - shutil.copyfileobj(file.file, buffer) + file_contents = await file.read() + request.state.session.agent_session.runtime.file_store.write( + file.filename, file_contents + ) except Exception as e: logger.error(f'Error saving files: {e}', exc_info=True) return JSONResponse( @@ -309,9 +280,7 @@ async def upload_files(files: list[UploadFile]): @app.get('/api/root_task') -def get_root_task( - credentials: HTTPAuthorizationCredentials = Depends(security_scheme), -): +def get_root_task(request: Request): """ Get root_task. @@ -320,9 +289,7 @@ def get_root_task( curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task ``` """ - sid = get_sid_from_token(credentials.credentials) - agent = agent_manager.sid_to_agent[sid] - controller = agent.controller + controller = request.state.session.agent_session.controller if controller is not None: state = controller.get_state() if state: diff --git a/opendevin/server/mock/listen.py b/opendevin/server/mock/listen.py index 5a4d36a13d83..517702f506d1 100644 --- a/opendevin/server/mock/listen.py +++ b/opendevin/server/mock/listen.py @@ -33,7 +33,7 @@ def read_root(): return {'message': 'This is a mock server'} -@app.get('/api/litellm-models') +@app.get('/api/options/models') def read_llm_models(): return [ 'gpt-4', @@ -43,7 +43,7 @@ def read_llm_models(): ] -@app.get('/api/agents') +@app.get('/api/options/agents') def read_llm_agents(): return [ 'MonologueAgent', @@ -52,16 +52,6 @@ def read_llm_agents(): ] -@app.get('/api/messages') -async def get_messages(): - return {'messages': []} - - -@app.get('/api/messages/total') -async def get_message_total(): - return {'msg_total': 0} - - @app.get('/api/list-files') def refresh_files(): return ['hello_world.py'] diff --git a/opendevin/server/session/__init__.py b/opendevin/server/session/__init__.py index 405429cd07cd..391952d6aa31 100644 --- a/opendevin/server/session/__init__.py +++ b/opendevin/server/session/__init__.py @@ -1,5 +1,4 @@ from .manager import SessionManager -from .msg_stack import message_stack from .session import Session session_manager = SessionManager() diff --git a/opendevin/server/session/agent.py b/opendevin/server/session/agent.py new file mode 100644 index 000000000000..aacb791bda2d --- /dev/null +++ b/opendevin/server/session/agent.py @@ -0,0 +1,117 @@ +from typing import Optional + +from agenthub.codeact_agent.codeact_agent import CodeActAgent +from opendevin.controller import AgentController +from opendevin.controller.agent import Agent +from opendevin.controller.state.state import State +from opendevin.core.config import config +from opendevin.core.logger import opendevin_logger as logger +from opendevin.core.schema import ConfigType +from opendevin.events.stream import EventStream +from opendevin.llm.llm import LLM +from opendevin.runtime import DockerSSHBox +from opendevin.runtime.e2b.runtime import E2BRuntime +from opendevin.runtime.runtime import Runtime +from opendevin.runtime.server.runtime import ServerRuntime + + +class AgentSession: + """Represents a session with an agent. + + Attributes: + controller: The AgentController instance for controlling the agent. + """ + + sid: str + event_stream: EventStream + controller: Optional[AgentController] = None + runtime: Optional[Runtime] = None + _closed: bool = False + + def __init__(self, sid): + """Initializes a new instance of the Session class.""" + self.sid = sid + self.event_stream = EventStream(sid) + + async def start(self, start_event: dict): + """Starts the agent session. + + Args: + start_event: The start event data (optional). + """ + if self.controller or self.runtime: + raise Exception( + 'Session already started. You need to close this session and start a new one.' + ) + await self._create_runtime() + await self._create_controller(start_event) + + async def close(self): + if self._closed: + return + if self.controller is not None: + end_state = self.controller.get_state() + end_state.save_to_session(self.sid) + await self.controller.close() + if self.runtime is not None: + self.runtime.close() + self._closed = True + + async def _create_runtime(self): + if self.runtime is not None: + raise Exception('Runtime already created') + if config.runtime == 'server': + logger.info('Using server runtime') + self.runtime = ServerRuntime(self.event_stream, self.sid) + elif config.runtime == 'e2b': + logger.info('Using E2B runtime') + self.runtime = E2BRuntime(self.event_stream, self.sid) + else: + raise Exception( + f'Runtime not defined in config, or is invalid: {config.runtime}' + ) + + async def _create_controller(self, start_event: dict): + """Creates an AgentController instance. + + Args: + start_event: The start event data (optional). + """ + if self.controller is not None: + raise Exception('Controller already created') + if self.runtime is None: + raise Exception('Runtime must be initialized before the agent controller') + args = { + key: value + for key, value in start_event.get('args', {}).items() + if value != '' + } # remove empty values, prevent FE from sending empty strings + agent_cls = args.get(ConfigType.AGENT, config.agent.name) + model = args.get(ConfigType.LLM_MODEL, config.llm.model) + api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key) + api_base = config.llm.base_url + max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations) + max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars) + + logger.info(f'Creating agent {agent_cls} using LLM {model}') + llm = LLM(model=model, api_key=api_key, base_url=api_base) + agent = Agent.get_cls(agent_cls)(llm) + if isinstance(agent, CodeActAgent): + if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox): + logger.warning( + 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.' + ) + self.runtime.init_sandbox_plugins(agent.sandbox_plugins) + + self.controller = AgentController( + sid=self.sid, + event_stream=self.event_stream, + agent=agent, + max_iterations=int(max_iterations), + max_chars=int(max_chars), + ) + try: + agent_state = State.restore_from_session(self.sid) + self.controller.set_state(agent_state) + except Exception as e: + print('Error restoring state', e) diff --git a/opendevin/server/session/manager.py b/opendevin/server/session/manager.py index 73bb0dcee9ba..cae6cae17e07 100644 --- a/opendevin/server/session/manager.py +++ b/opendevin/server/session/manager.py @@ -1,20 +1,12 @@ import asyncio -import atexit -import json -import os import time -from typing import Callable from fastapi import WebSocket from opendevin.core.logger import opendevin_logger as logger -from .msg_stack import message_stack from .session import Session -CACHE_DIR = os.getenv('CACHE_DIR', 'cache') -SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json') - class SessionManager: _sessions: dict[str, Session] = {} @@ -22,30 +14,21 @@ class SessionManager: session_timeout: int = 600 def __init__(self): - self._load_sessions() - atexit.register(self.close) asyncio.create_task(self._cleanup_sessions()) - def add_session(self, sid: str, ws_conn: WebSocket): - if sid not in self._sessions: - self._sessions[sid] = Session(sid=sid, ws=ws_conn) - return - self._sessions[sid].update_connection(ws_conn) + def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session: + if sid in self._sessions: + asyncio.create_task(self._sessions[sid].close()) + self._sessions[sid] = Session(sid=sid, ws=ws_conn) + return self._sessions[sid] - async def loop_recv(self, sid: str, dispatch: Callable): - print(f'Starting loop_recv for sid: {sid}') - """Starts listening for messages from the client.""" + def get_session(self, sid: str) -> Session | None: if sid not in self._sessions: - return - await self._sessions[sid].loop_recv(dispatch) - - def close(self): - logger.info('Saving sessions...') - self._save_sessions() + return None + return self._sessions.get(sid) async def send(self, sid: str, data: dict[str, object]) -> bool: """Sends data to the client.""" - message_stack.add_message(sid, 'assistant', data) if sid not in self._sessions: return False return await self._sessions[sid].send(data) @@ -58,33 +41,6 @@ async def send_message(self, sid: str, message: str) -> bool: """Sends a message to the client.""" return await self.send(sid, {'message': message}) - def _save_sessions(self): - data = {} - for sid, conn in self._sessions.items(): - data[sid] = { - 'sid': conn.sid, - 'last_active_ts': conn.last_active_ts, - 'is_alive': conn.is_alive, - } - if not os.path.exists(CACHE_DIR): - os.makedirs(CACHE_DIR) - with open(SESSION_CACHE_FILE, 'w+') as file: - json.dump(data, file) - - def _load_sessions(self): - try: - with open(SESSION_CACHE_FILE, 'r') as file: - data = json.load(file) - for sid, sdata in data.items(): - conn = Session(sid, None) - ok = conn.load_from_data(sdata) - if ok: - self._sessions[sid] = conn - except FileNotFoundError: - pass - except json.decoder.JSONDecodeError: - pass - async def _cleanup_sessions(self): while True: current_time = time.time() diff --git a/opendevin/server/session/msg_stack.py b/opendevin/server/session/msg_stack.py deleted file mode 100644 index 6e0862af4426..000000000000 --- a/opendevin/server/session/msg_stack.py +++ /dev/null @@ -1,114 +0,0 @@ -import asyncio -import atexit -import json -import os -import uuid - -from opendevin.core.logger import opendevin_logger as logger -from opendevin.core.schema.action import ActionType - -CACHE_DIR = os.getenv('CACHE_DIR', 'cache') -MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json') - - -class Message: - id: str = str(uuid.uuid4()) - role: str # "user"| "assistant" - payload: dict[str, object] - - def __init__(self, role: str, payload: dict[str, object]): - self.role = role - self.payload = payload - - def to_dict(self): - return {'id': self.id, 'role': self.role, 'payload': self.payload} - - @classmethod - def from_dict(cls, data: dict): - m = cls(data['role'], data['payload']) - m.id = data['id'] - return m - - -class MessageStack: - _messages: dict[str, list[Message]] = {} - - def __init__(self): - self._load_messages() - atexit.register(self.close) - - def close(self): - logger.info('Saving messages...') - self._save_messages() - - def add_message(self, sid: str, role: str, message: dict[str, object]): - if sid not in self._messages: - self._messages[sid] = [] - self._messages[sid].append(Message(role, message)) - - def del_messages(self, sid: str): - if sid not in self._messages: - return - del self._messages[sid] - asyncio.create_task(self._del_messages(sid)) - - def get_messages(self, sid: str) -> list[dict[str, object]]: - if sid not in self._messages: - return [] - return [msg.to_dict() for msg in self._messages[sid]] - - def get_message_total(self, sid: str) -> int: - if sid not in self._messages: - return 0 - cnt = 0 - for msg in self._messages[sid]: - # Ignore assistant init message for now. - if 'action' in msg.payload and msg.payload['action'] in [ - ActionType.INIT, - ActionType.CHANGE_AGENT_STATE, - ]: - continue - cnt += 1 - return cnt - - def _save_messages(self): - if not os.path.exists(CACHE_DIR): - os.makedirs(CACHE_DIR) - data = {} - for sid, msgs in self._messages.items(): - data[sid] = [msg.to_dict() for msg in msgs] - with open(MSG_CACHE_FILE, 'w+') as file: - json.dump(data, file) - - def _load_messages(self): - try: - with open(MSG_CACHE_FILE, 'r') as file: - data = json.load(file) - for sid, msgs in data.items(): - self._messages[sid] = [Message.from_dict(msg) for msg in msgs] - except FileNotFoundError: - pass - except json.decoder.JSONDecodeError: - pass - - async def _del_messages(self, del_sid: str): - logger.info('Deleting messages...') - try: - with open(MSG_CACHE_FILE, 'r+') as file: - data = json.load(file) - new_data = {} - for sid, msgs in data.items(): - if sid != del_sid: - new_data[sid] = msgs - # Move the file pointer to the beginning of the file to overwrite the original contents - file.seek(0) - # clean previous content - file.truncate() - json.dump(new_data, file) - except FileNotFoundError: - pass - except json.decoder.JSONDecodeError: - pass - - -message_stack = MessageStack() diff --git a/opendevin/server/session/session.py b/opendevin/server/session/session.py index 0ca5c6d968de..18df617f11df 100644 --- a/opendevin/server/session/session.py +++ b/opendevin/server/session/session.py @@ -1,11 +1,19 @@ +import asyncio import time -from typing import Callable from fastapi import WebSocket, WebSocketDisconnect +from opendevin.core.const.guide_url import TROUBLESHOOTING_URL from opendevin.core.logger import opendevin_logger as logger +from opendevin.core.schema import AgentState +from opendevin.core.schema.action import ActionType +from opendevin.events.action import ChangeAgentStateAction, NullAction +from opendevin.events.event import Event +from opendevin.events.observation import AgentStateChangedObservation, NullObservation +from opendevin.events.serialization import EventSource, event_from_dict, event_to_dict +from opendevin.events.stream import EventStreamSubscriber -from .msg_stack import message_stack +from .agent import AgentSession DEL_DELT_SEC = 60 * 60 * 5 @@ -15,13 +23,22 @@ class Session: websocket: WebSocket | None last_active_ts: int = 0 is_alive: bool = True + agent_session: AgentSession def __init__(self, sid: str, ws: WebSocket | None): self.sid = sid self.websocket = ws self.last_active_ts = int(time.time()) + self.agent_session = AgentSession(sid) + self.agent_session.event_stream.subscribe( + EventStreamSubscriber.SERVER, self.on_event + ) - async def loop_recv(self, dispatch: Callable): + async def close(self): + self.is_alive = False + await self.agent_session.close() + + async def loop_recv(self): try: if self.websocket is None: return @@ -31,24 +48,62 @@ async def loop_recv(self, dispatch: Callable): except ValueError: await self.send_error('Invalid JSON') continue - - message_stack.add_message(self.sid, 'user', data) - action = data.get('action', None) - await dispatch(self.sid, action, data) + await self.dispatch(data) except WebSocketDisconnect: - self.is_alive = False + await self.close() logger.info('WebSocket disconnected, sid: %s', self.sid) except RuntimeError as e: - # WebSocket is not connected - if 'WebSocket is not connected' in str(e): - self.is_alive = False + await self.close() logger.exception('Error in loop_recv: %s', e) + async def _initialize_agent(self, data: dict): + await self.agent_session.event_stream.add_event( + ChangeAgentStateAction(AgentState.LOADING), EventSource.USER + ) + await self.agent_session.event_stream.add_event( + AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT + ) + try: + await self.agent_session.start(data) + except Exception as e: + logger.exception(f'Error creating controller: {e}') + await self.send_error( + f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..' + ) + return + await self.agent_session.event_stream.add_event( + ChangeAgentStateAction(AgentState.INIT), EventSource.USER + ) + + async def on_event(self, event: Event): + """Callback function for agent events. + + Args: + event: The agent event (Observation or Action). + """ + if isinstance(event, NullAction): + return + if isinstance(event, NullObservation): + return + if event.source == EventSource.AGENT and not isinstance( + event, (NullAction, NullObservation) + ): + await self.send(event_to_dict(event)) + + async def dispatch(self, data: dict): + action = data.get('action', '') + if action == ActionType.INIT: + await self._initialize_agent(data) + return + event = event_from_dict(data.copy()) + await self.agent_session.event_stream.add_event(event, EventSource.USER) + async def send(self, data: dict[str, object]) -> bool: try: if self.websocket is None or not self.is_alive: return False await self.websocket.send_json(data) + await asyncio.sleep(0.001) # This flushes the data to the client self.last_active_ts = int(time.time()) return True except WebSocketDisconnect: diff --git a/opendevin/storage/local.py b/opendevin/storage/local.py index 8cc530599b6d..27791b1e68e7 100644 --- a/opendevin/storage/local.py +++ b/opendevin/storage/local.py @@ -28,7 +28,9 @@ def read(self, path: str) -> str: def list(self, path: str) -> list[str]: full_path = self.get_full_path(path) - return [os.path.join(path, f) for f in os.listdir(full_path)] + files = [os.path.join(path, f) for f in os.listdir(full_path)] + files = [f + '/' if os.path.isdir(self.get_full_path(f)) else f for f in files] + return files def delete(self, path: str) -> None: full_path = self.get_full_path(path) diff --git a/opendevin/storage/memory.py b/opendevin/storage/memory.py index 45d4ce100f03..ea797ba7c699 100644 --- a/opendevin/storage/memory.py +++ b/opendevin/storage/memory.py @@ -30,6 +30,8 @@ def list(self, path: str) -> list[str]: files.append(file) else: dir_path = os.path.join(path, parts[0]) + if not dir_path.endswith('/'): + dir_path += '/' if dir_path not in files: files.append(dir_path) return files diff --git a/tests/unit/test_storage.py b/tests/unit/test_storage.py index 272b648b7e44..dc173ff8fe59 100644 --- a/tests/unit/test_storage.py +++ b/tests/unit/test_storage.py @@ -54,10 +54,8 @@ def test_deep_list(setup_env): store.write('foo/bar/baz.txt', 'Hello, world!') store.write('foo/bar/qux.txt', 'Hello, world!') store.write('foo/bar/quux.txt', 'Hello, world!') - assert store.list('') == ['foo'], 'Expected foo, got {} for class {}'.format( - store.list(''), store.__class__ - ) - assert store.list('foo') == ['foo/bar'] + assert store.list('') == ['foo/'], f'for class {store.__class__}' + assert store.list('foo') == ['foo/bar/'] assert ( store.list('foo/bar').sort() == ['foo/bar/baz.txt', 'foo/bar/qux.txt', 'foo/bar/quux.txt'].sort()