diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 9c35406383bd..7e35572a4ce7 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -1,5 +1,5 @@
import { useDisclosure } from "@nextui-org/react";
-import React, { useEffect, useState } from "react";
+import React, { useEffect } from "react";
import { Toaster } from "react-hot-toast";
import CogTooth from "#/assets/cog-tooth";
import ChatInterface from "#/components/chat/ChatInterface";
@@ -8,15 +8,13 @@ import { Container, Orientation } from "#/components/Resizable";
import Workspace from "#/components/Workspace";
import LoadPreviousSessionModal from "#/components/modals/load-previous-session/LoadPreviousSessionModal";
import SettingsModal from "#/components/modals/settings/SettingsModal";
-import { fetchMsgTotal } from "#/services/session";
-import Socket from "#/services/socket";
-import { ResFetchMsgTotal } from "#/types/ResponseType";
import "./App.css";
import AgentControlBar from "./components/AgentControlBar";
import AgentStatusBar from "./components/AgentStatusBar";
import Terminal from "./components/terminal/Terminal";
-import { initializeAgent } from "./services/agent";
-import { settingsAreUpToDate } from "./services/settings";
+import Session from "#/services/session";
+import { getToken } from "#/services/auth";
+import { settingsAreUpToDate } from "#/services/settings";
interface Props {
setSettingOpen: (isOpen: boolean) => void;
@@ -43,8 +41,6 @@ function Controls({ setSettingOpen }: Props): JSX.Element {
let initOnce = false;
function App(): JSX.Element {
- const [isWarned, setIsWarned] = useState(false);
-
const {
isOpen: settingsModalIsOpen,
onOpen: onSettingsModalOpen,
@@ -57,31 +53,18 @@ function App(): JSX.Element {
onOpenChange: onLoadPreviousSessionModalOpenChange,
} = useDisclosure();
- const getMsgTotal = () => {
- if (isWarned) return;
- fetchMsgTotal()
- .then((data: ResFetchMsgTotal) => {
- if (data.msg_total > 0) {
- onLoadPreviousSessionModalOpen();
- setIsWarned(true);
- }
- })
- .catch();
- };
-
useEffect(() => {
if (initOnce) return;
initOnce = true;
if (!settingsAreUpToDate()) {
onSettingsModalOpen();
+ } else if (getToken()) {
+ onLoadPreviousSessionModalOpen();
} else {
- initializeAgent();
+ Session.startNewSession();
}
- Socket.registerCallback("open", [getMsgTotal]);
-
- getMsgTotal();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
deleted file mode 100644
index 6c0f0d840c8d..000000000000
--- a/frontend/src/api/index.ts
+++ /dev/null
@@ -1,9 +0,0 @@
-export async function fetchModels() {
- const response = await fetch(`/api/litellm-models`);
- return response.json();
-}
-
-export async function fetchAgents() {
- const response = await fetch(`/api/agents`);
- return response.json();
-}
diff --git a/frontend/src/components/AgentControlBar.tsx b/frontend/src/components/AgentControlBar.tsx
index ad87bceb01bb..32eddb3eb31f 100644
--- a/frontend/src/components/AgentControlBar.tsx
+++ b/frontend/src/components/AgentControlBar.tsx
@@ -5,7 +5,6 @@ import ArrowIcon from "#/assets/arrow";
import PauseIcon from "#/assets/pause";
import PlayIcon from "#/assets/play";
import { changeAgentState } from "#/services/agentStateService";
-import { clearMsgs } from "#/services/session";
import store, { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import { clearMessages } from "#/state/chatSlice";
@@ -73,7 +72,6 @@ function AgentControlBar() {
}
if (action === AgentState.STOPPED) {
- clearMsgs().then().catch();
store.dispatch(clearMessages());
} else {
setIsLoading(true);
@@ -86,7 +84,6 @@ function AgentControlBar() {
useEffect(() => {
if (curAgentState === desiredState) {
if (curAgentState === AgentState.STOPPED) {
- clearMsgs().then().catch();
store.dispatch(clearMessages());
}
setIsLoading(false);
diff --git a/frontend/src/components/CodeEditor.tsx b/frontend/src/components/CodeEditor.tsx
index ee7e0f551b95..a14b8650fd25 100644
--- a/frontend/src/components/CodeEditor.tsx
+++ b/frontend/src/components/CodeEditor.tsx
@@ -1,33 +1,24 @@
import Editor, { Monaco } from "@monaco-editor/react";
import { Tab, Tabs } from "@nextui-org/react";
import type { editor } from "monaco-editor";
-import React, { useMemo, useState } from "react";
+import React, { useMemo } from "react";
import { useTranslation } from "react-i18next";
import { VscCode } from "react-icons/vsc";
-import { useDispatch, useSelector } from "react-redux";
+import { useSelector } from "react-redux";
import { I18nKey } from "#/i18n/declaration";
-import { selectFile } from "#/services/fileService";
-import { setCode } from "#/state/codeSlice";
import { RootState } from "#/store";
import FileExplorer from "./file-explorer/FileExplorer";
-import { CodeEditorContext } from "./CodeEditorContext";
function CodeEditor(): JSX.Element {
const { t } = useTranslation();
- const [selectedFileAbsolutePath, setSelectedFileAbsolutePath] = useState("");
- const selectedFileName = useMemo(() => {
- const paths = selectedFileAbsolutePath.split("/");
- return paths[paths.length - 1];
- }, [selectedFileAbsolutePath]);
- const codeEditorContext = useMemo(
- () => ({ selectedFileAbsolutePath }),
- [selectedFileAbsolutePath],
- );
-
- const dispatch = useDispatch();
const code = useSelector((state: RootState) => state.code.code);
const activeFilepath = useSelector((state: RootState) => state.code.path);
+ const selectedFileName = useMemo(() => {
+ const paths = activeFilepath.split("/");
+ return paths[paths.length - 1];
+ }, [activeFilepath]);
+
const handleEditorDidMount = (
editor: editor.IStandaloneCodeEditor,
monaco: Monaco,
@@ -46,57 +37,44 @@ function CodeEditor(): JSX.Element {
monaco.editor.setTheme("my-theme");
};
- const updateCode = async () => {
- const newCode = await selectFile(activeFilepath);
- setSelectedFileAbsolutePath(activeFilepath);
- dispatch(setCode(newCode));
- };
-
- React.useEffect(() => {
- // FIXME: we can probably move this out of the component and into state/service
- if (activeFilepath) updateCode();
- }, [activeFilepath]);
-
return (
-
-
-
-
-
+
+
+
+
+
+ {selectedFileName === "" ? (
+
+
+ {t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
+
+ ) : (
+
-
-
- {selectedFileName === "" ? (
-
-
- {t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
-
- ) : (
-
- )}
-
+ )}
-
+
);
}
diff --git a/frontend/src/components/CodeEditorContext.ts b/frontend/src/components/CodeEditorContext.ts
deleted file mode 100644
index 06340a1e91d9..000000000000
--- a/frontend/src/components/CodeEditorContext.ts
+++ /dev/null
@@ -1,5 +0,0 @@
-import { createContext } from "react";
-
-export const CodeEditorContext = createContext({
- selectedFileAbsolutePath: "",
-});
diff --git a/frontend/src/components/chat/ChatInterface.test.tsx b/frontend/src/components/chat/ChatInterface.test.tsx
index 865484473892..570c5df3267d 100644
--- a/frontend/src/components/chat/ChatInterface.test.tsx
+++ b/frontend/src/components/chat/ChatInterface.test.tsx
@@ -5,7 +5,7 @@ import { act } from "react-dom/test-utils";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import ChatInterface from "./ChatInterface";
-import Socket from "#/services/socket";
+import Session from "#/services/session";
import ActionType from "#/types/ActionType";
import { addAssistantMessage } from "#/state/chatSlice";
import AgentState from "#/types/AgentState";
@@ -15,16 +15,17 @@ vi.mock("#/hooks/useTyping", () => ({
useTyping: vi.fn((text: string) => text),
}));
-const socketSpy = vi.spyOn(Socket, "send");
+const sessionSpy = vi.spyOn(Session, "send");
+vi.spyOn(Session, "isConnected").mockImplementation(() => true);
// This is for the scrollview ref in Chat.tsx
// TODO: Move this into test setup
HTMLElement.prototype.scrollTo = vi.fn(() => {});
describe("ChatInterface", () => {
- it("should render the messages and input", () => {
+ it("should render empty message list and input", () => {
renderWithProviders();
- expect(screen.queryAllByTestId("message")).toHaveLength(1); // initial welcome message only
+ expect(screen.queryAllByTestId("message")).toHaveLength(0);
});
it("should render the new message the user has typed", async () => {
@@ -65,7 +66,7 @@ describe("ChatInterface", () => {
expect(screen.getByText("Hello to you!")).toBeInTheDocument();
});
- it("should send the a start event to the Socket", () => {
+ it("should send the a start event to the Session", () => {
renderWithProviders(, {
preloadedState: {
agent: {
@@ -83,10 +84,10 @@ describe("ChatInterface", () => {
action: ActionType.MESSAGE,
args: { content: "my message" },
};
- expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
+ expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
- it("should send the a user message event to the Socket", () => {
+ it("should send the a user message event to the Session", () => {
renderWithProviders(, {
preloadedState: {
agent: {
@@ -104,7 +105,7 @@ describe("ChatInterface", () => {
action: ActionType.MESSAGE,
args: { content: "my message" },
};
- expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
+ expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
it("should disable the user input if agent is not initialized", () => {
diff --git a/frontend/src/components/chat/ChatInterface.tsx b/frontend/src/components/chat/ChatInterface.tsx
index 0f82eeb5103c..183ff99c70b4 100644
--- a/frontend/src/components/chat/ChatInterface.tsx
+++ b/frontend/src/components/chat/ChatInterface.tsx
@@ -10,7 +10,7 @@ import Chat from "./Chat";
import { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import { sendChatMessage } from "#/services/chatService";
-import { addUserMessage } from "#/state/chatSlice";
+import { addUserMessage, addAssistantMessage } from "#/state/chatSlice";
import { I18nKey } from "#/i18n/declaration";
import { useScrollToBottom } from "#/hooks/useScrollToBottom";
@@ -58,6 +58,12 @@ function ChatInterface() {
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
useScrollToBottom(scrollRef);
+ React.useEffect(() => {
+ if (curAgentState === AgentState.INIT && messages.length === 0) {
+ dispatch(addAssistantMessage(t(I18nKey.CHAT_INTERFACE$INITIAL_MESSAGE)));
+ }
+ }, [curAgentState]);
+
return (
diff --git a/frontend/src/components/file-explorer/FileExplorer.test.tsx b/frontend/src/components/file-explorer/FileExplorer.test.tsx
index 274e16e0814e..e0494017052a 100644
--- a/frontend/src/components/file-explorer/FileExplorer.test.tsx
+++ b/frontend/src/components/file-explorer/FileExplorer.test.tsx
@@ -7,8 +7,9 @@ import { describe, it, expect, vi, Mock } from "vitest";
import FileExplorer from "./FileExplorer";
import { uploadFiles, listFiles } from "#/services/fileService";
import toast from "#/utils/toast";
+import AgentState from "#/types/AgentState";
-const toastSpy = vi.spyOn(toast, "stickyError");
+const toastSpy = vi.spyOn(toast, "error");
vi.mock("../../services/fileService", async () => ({
listFiles: vi.fn(async (path: string = "/") => {
@@ -42,7 +43,13 @@ describe("FileExplorer", () => {
it.todo("should render an empty workspace");
it.only("should refetch the workspace when clicking the refresh button", async () => {
- const { getByText } = renderWithProviders(
);
+ const { getByText } = renderWithProviders(
, {
+ preloadedState: {
+ agent: {
+ curAgentState: AgentState.RUNNING,
+ },
+ },
+ });
await waitFor(() => {
expect(getByText("folder1")).toBeInTheDocument();
expect(getByText("file2.ts")).toBeInTheDocument();
diff --git a/frontend/src/components/file-explorer/FileExplorer.tsx b/frontend/src/components/file-explorer/FileExplorer.tsx
index 5202186b335a..9b3d6d49e15d 100644
--- a/frontend/src/components/file-explorer/FileExplorer.tsx
+++ b/frontend/src/components/file-explorer/FileExplorer.tsx
@@ -5,14 +5,16 @@ import {
IoIosRefresh,
IoIosCloudUpload,
} from "react-icons/io";
-import { useDispatch } from "react-redux";
+import { useDispatch, useSelector } from "react-redux";
import { IoFileTray } from "react-icons/io5";
import { twMerge } from "tailwind-merge";
+import AgentState from "#/types/AgentState";
import { setRefreshID } from "#/state/codeSlice";
import { listFiles, uploadFiles } from "#/services/fileService";
import IconButton from "../IconButton";
import ExplorerTree from "./ExplorerTree";
import toast from "#/utils/toast";
+import { RootState } from "#/store";
interface ExplorerActionsProps {
onRefresh: () => void;
@@ -87,6 +89,7 @@ function FileExplorer() {
const [isHidden, setIsHidden] = React.useState(false);
const [isDragging, setIsDragging] = React.useState(false);
const [files, setFiles] = React.useState
([]);
+ const { curAgentState } = useSelector((state: RootState) => state.agent);
const fileInputRef = React.useRef(null);
const dispatch = useDispatch();
@@ -95,6 +98,12 @@ function FileExplorer() {
};
const refreshWorkspace = async () => {
+ if (
+ curAgentState === AgentState.LOADING ||
+ curAgentState === AgentState.STOPPED
+ ) {
+ return;
+ }
dispatch(setRefreshID(Math.random()));
setFiles(await listFiles("/"));
};
@@ -104,7 +113,7 @@ function FileExplorer() {
await uploadFiles(toAdd);
await refreshWorkspace();
} catch (error) {
- toast.stickyError("ws", "Error uploading file");
+ toast.error("ws", "Error uploading file");
}
};
@@ -112,7 +121,9 @@ function FileExplorer() {
(async () => {
await refreshWorkspace();
})();
+ }, [curAgentState]);
+ React.useEffect(() => {
const enableDragging = () => {
setIsDragging(true);
};
@@ -130,6 +141,10 @@ function FileExplorer() {
};
}, []);
+ if (!files.length) {
+ return null;
+ }
+
return (
{isDragging && (
diff --git a/frontend/src/components/file-explorer/TreeNode.tsx b/frontend/src/components/file-explorer/TreeNode.tsx
index 9852003642e2..df22093ede63 100644
--- a/frontend/src/components/file-explorer/TreeNode.tsx
+++ b/frontend/src/components/file-explorer/TreeNode.tsx
@@ -4,9 +4,8 @@ import { twMerge } from "tailwind-merge";
import { RootState } from "#/store";
import FolderIcon from "../FolderIcon";
import FileIcon from "../FileIcons";
-import { listFiles } from "#/services/fileService";
-import { setActiveFilepath } from "#/state/codeSlice";
-import { CodeEditorContext } from "../CodeEditorContext";
+import { listFiles, selectFile } from "#/services/fileService";
+import { setCode, setActiveFilepath } from "#/state/codeSlice";
interface TitleProps {
name: string;
@@ -36,8 +35,8 @@ interface TreeNodeProps {
function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
const [isOpen, setIsOpen] = React.useState(defaultOpen);
const [children, setChildren] = React.useState
(null);
- const { selectedFileAbsolutePath } = React.useContext(CodeEditorContext);
const refreshID = useSelector((state: RootState) => state.code.refreshID);
+ const activeFilepath = useSelector((state: RootState) => state.code.path);
const dispatch = useDispatch();
@@ -60,10 +59,12 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
refreshChildren();
}, [refreshID, isOpen]);
- const handleClick = () => {
+ const handleClick = async () => {
if (isDirectory) {
setIsOpen((prev) => !prev);
} else {
+ const newCode = await selectFile(path);
+ dispatch(setCode(newCode));
dispatch(setActiveFilepath(path));
}
};
@@ -72,7 +73,7 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
({
- fetchMsgsMock: vi.fn(),
-}));
-
-vi.mock("../../../services/session", async (importOriginal) => ({
- ...(await importOriginal()),
- clearMsgs: vi.fn(),
- fetchMsgs: mocks.fetchMsgsMock.mockResolvedValue({
- messages: [
- {
- id: "1",
- role: "user",
- payload: { type: "action" },
- },
- {
- id: "2",
- role: "assistant",
- payload: { type: "observation" },
- },
- ],
- }),
-}));
-
-vi.mock("../../../services/chatService", async (importOriginal) => ({
- ...(await importOriginal()),
- addChatMessageFromEvent: vi.fn(),
-}));
-
vi.mock("../../../services/actions", async (importOriginal) => ({
...(await importOriginal()),
handleAssistantMessage: vi.fn(),
}));
-vi.mock("../../../utils/toast", () => ({
- default: {
- stickyError: vi.fn(),
- },
-}));
+vi.spyOn(Session, "isConnected").mockImplementation(() => true);
+const restoreOrStartNewSessionSpy = vi.spyOn(
+ Session,
+ "restoreOrStartNewSession",
+);
describe("LoadPreviousSession", () => {
afterEach(() => {
@@ -75,7 +44,6 @@ describe("LoadPreviousSession", () => {
userEvent.click(startNewSessionButton);
});
- expect(clearMsgs).toHaveBeenCalledTimes(1);
// modal should close right after clearing messages
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
});
@@ -93,36 +61,9 @@ describe("LoadPreviousSession", () => {
});
await waitFor(() => {
- expect(fetchMsgs).toHaveBeenCalledTimes(1);
- expect(addChatMessageFromEvent).toHaveBeenCalledTimes(1);
- expect(handleAssistantMessage).toHaveBeenCalledTimes(1);
+ expect(restoreOrStartNewSessionSpy).toHaveBeenCalledTimes(1);
});
// modal should close right after fetching messages
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
});
-
- it("should show an error toast if there is an error fetching the session", async () => {
- mocks.fetchMsgsMock.mockRejectedValue(new Error("Get messages failed."));
-
- render();
-
- const resumeSessionButton = screen.getByRole("button", {
- name: RESUME_SESSION_BUTTON_LABEL_KEY,
- });
-
- act(() => {
- userEvent.click(resumeSessionButton);
- });
-
- await waitFor(async () => {
- await expect(() => fetchMsgs()).rejects.toThrow();
- expect(handleAssistantMessage).not.toHaveBeenCalled();
- expect(addChatMessageFromEvent).not.toHaveBeenCalled();
- // error toast should be shown
- expect(toast.stickyError).toHaveBeenCalledWith(
- "ws",
- "Error fetching the session",
- );
- });
- });
});
diff --git a/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx b/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx
index 2f29f05f10b5..7e3fd1140bf2 100644
--- a/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx
+++ b/frontend/src/components/modals/load-previous-session/LoadPreviousSessionModal.tsx
@@ -1,11 +1,8 @@
import React from "react";
import { useTranslation } from "react-i18next";
import { I18nKey } from "#/i18n/declaration";
-import { handleAssistantMessage } from "#/services/actions";
-import { addChatMessageFromEvent } from "#/services/chatService";
-import { clearMsgs, fetchMsgs } from "#/services/session";
-import toast from "#/utils/toast";
import BaseModal from "../base-modal/BaseModal";
+import Session from "#/services/session";
interface LoadPreviousSessionModalProps {
isOpen: boolean;
@@ -18,28 +15,6 @@ function LoadPreviousSessionModal({
}: LoadPreviousSessionModalProps) {
const { t } = useTranslation();
- const onStartNewSession = async () => {
- await clearMsgs();
- };
-
- const onResumeSession = async () => {
- try {
- const { messages } = await fetchMsgs();
-
- messages.forEach((message) => {
- if (message.role === "user") {
- addChatMessageFromEvent(message.payload);
- }
-
- if (message.role === "assistant") {
- handleAssistantMessage(message.payload);
- }
- });
- } catch (error) {
- toast.stickyError("ws", "Error fetching the session");
- }
- };
-
return (
true);
vi.mock("#/services/settings", async (importOriginal) => ({
...(await importOriginal()),
@@ -35,12 +37,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
saveSettings: vi.fn(),
}));
-vi.mock("#/services/agent", async () => ({
- initializeAgent: vi.fn(),
-}));
-
-vi.mock("#/api", async (importOriginal) => ({
- ...(await importOriginal()),
+vi.mock("#/services/options", async (importOriginal) => ({
+ ...(await importOriginal()),
fetchModels: vi
.fn()
.mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
@@ -162,7 +160,7 @@ describe("SettingsModal", () => {
userEvent.click(saveButton);
});
- expect(initializeAgent).toHaveBeenCalled();
+ expect(startNewSessionSpy).toHaveBeenCalled();
});
it("should display a toast for every change", async () => {
diff --git a/frontend/src/components/modals/settings/SettingsModal.tsx b/frontend/src/components/modals/settings/SettingsModal.tsx
index e68c0e153343..38bca6454b4f 100644
--- a/frontend/src/components/modals/settings/SettingsModal.tsx
+++ b/frontend/src/components/modals/settings/SettingsModal.tsx
@@ -3,10 +3,10 @@ import i18next from "i18next";
import React, { useEffect } from "react";
import { useTranslation } from "react-i18next";
import { useSelector } from "react-redux";
-import { fetchAgents, fetchModels } from "#/api";
+import { fetchAgents, fetchModels } from "#/services/options";
import { AvailableLanguages } from "#/i18n";
import { I18nKey } from "#/i18n/declaration";
-import { initializeAgent } from "#/services/agent";
+import Session from "#/services/session";
import { RootState } from "../../../store";
import AgentState from "../../../types/AgentState";
import {
@@ -100,7 +100,7 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
const updatedSettings = getSettingsDifference(settings);
saveSettings(settings);
i18next.changeLanguage(settings.LANGUAGE);
- initializeAgent(); // reinitialize the agent with the new settings
+ Session.startNewSession();
const sensitiveKeys = ["LLM_API_KEY"];
diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json
index b203a330cdba..4af256fc0fc5 100644
--- a/frontend/src/i18n/translation.json
+++ b/frontend/src/i18n/translation.json
@@ -266,12 +266,12 @@
"en": "Please stop the agent before editing these settings."
},
"LOAD_SESSION$MODAL_TITLE": {
- "en": "Unfinished Session Detected",
- "zh-CN": "检测到有未完成的会话",
- "zh-TW": "偵測到未完成的會話"
+ "en": "Return to existing session?",
+ "zh-CN": "是否继续未完成的会话?",
+ "zh-TW": "是否繼續未完成的會話?"
},
"LOAD_SESSION$MODAL_CONTENT": {
- "en": "You seem to have an unfinished task. Would you like to pick up where you left off or start fresh?",
+ "en": "You seem to have an ongoing session. Would you like to pick up where you left off, or start fresh?",
"zh-CN": "您似乎有一个未完成的任务。您想继续之前的工作还是重新开始?",
"zh-TW": "您似乎有一個未完成的任務。您想從上次離開的地方繼續還是重新開始?"
},
diff --git a/frontend/src/services/agent.test.ts b/frontend/src/services/agent.test.ts
deleted file mode 100644
index 0075f4dbdd8d..000000000000
--- a/frontend/src/services/agent.test.ts
+++ /dev/null
@@ -1,29 +0,0 @@
-import { describe, expect, it, vi } from "vitest";
-
-import ActionType from "#/types/ActionType";
-import { initializeAgent } from "./agent";
-import { Settings, saveSettings } from "./settings";
-import Socket from "./socket";
-
-const sendSpy = vi.spyOn(Socket, "send");
-
-describe("initializeAgent", () => {
- it("Should initialize the agent with the current settings", () => {
- const settings: Settings = {
- LLM_MODEL: "llm_value",
- AGENT: "agent_value",
- LANGUAGE: "language_value",
- LLM_API_KEY: "sk-...",
- };
-
- const event = {
- action: ActionType.INIT,
- args: settings,
- };
-
- saveSettings(settings);
- initializeAgent();
-
- expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
- });
-});
diff --git a/frontend/src/services/agent.ts b/frontend/src/services/agent.ts
deleted file mode 100644
index dd8b7a6bfded..000000000000
--- a/frontend/src/services/agent.ts
+++ /dev/null
@@ -1,14 +0,0 @@
-import ActionType from "#/types/ActionType";
-import { getSettings } from "./settings";
-import Socket from "./socket";
-
-/**
- * Initialize the agent with the current settings.
- * @param settings - The new settings.
- */
-export const initializeAgent = () => {
- const settings = getSettings();
- const event = { action: ActionType.INIT, args: settings };
- const eventString = JSON.stringify(event);
- Socket.send(eventString);
-};
diff --git a/frontend/src/services/agentStateService.ts b/frontend/src/services/agentStateService.ts
index 8704dc7097fd..4df4f1bfc256 100644
--- a/frontend/src/services/agentStateService.ts
+++ b/frontend/src/services/agentStateService.ts
@@ -1,7 +1,6 @@
import ActionType from "#/types/ActionType";
import AgentState from "#/types/AgentState";
-import Socket from "./socket";
-import { initializeAgent } from "./agent";
+import Session from "./session";
const INIT_DELAY = 1000;
@@ -10,10 +9,10 @@ export function changeAgentState(state: AgentState): void {
action: ActionType.CHANGE_AGENT_STATE,
args: { agent_state: state },
});
- Socket.send(eventString);
+ Session.send(eventString);
if (state === AgentState.STOPPED) {
setTimeout(() => {
- initializeAgent();
+ Session.startNewSession();
}, INIT_DELAY);
}
}
diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts
new file mode 100644
index 000000000000..bff6fc92293a
--- /dev/null
+++ b/frontend/src/services/api.ts
@@ -0,0 +1,58 @@
+import { getToken } from "./auth";
+import toast from "#/utils/toast";
+
+const WAIT_FOR_AUTH_DELAY_MS = 500;
+
+export async function request(
+ url: string,
+ optionsIn: RequestInit = {},
+ disableToast: boolean = false,
+ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
+): Promise {
+ const options = JSON.parse(JSON.stringify(optionsIn));
+
+ const onFail = (msg: string) => {
+ if (!disableToast) {
+ toast.error("api", msg);
+ }
+ throw new Error(msg);
+ };
+
+ const needsAuth = !url.startsWith("/api/options/");
+ const token = getToken();
+ if (!token && needsAuth) {
+ return new Promise((resolve) => {
+ setTimeout(() => {
+ resolve(request(url, optionsIn, disableToast));
+ }, WAIT_FOR_AUTH_DELAY_MS);
+ });
+ }
+ if (token) {
+ options.headers = {
+ ...(options.headers || {}),
+ Authorization: `Bearer ${token}`,
+ };
+ }
+
+ let response = null;
+ try {
+ response = await fetch(url, options);
+ } catch (e) {
+ onFail(`Error fetching ${url}`);
+ }
+ if (response?.status && response?.status >= 400) {
+ onFail(
+ `${response.status} error while fetching ${url}: ${response?.statusText}`,
+ );
+ }
+ if (!response?.ok) {
+ onFail(`Error fetching ${url}: ${response?.statusText}`);
+ }
+
+ try {
+ return await (response && response.json());
+ } catch (e) {
+ onFail(`Error parsing JSON from ${url}`);
+ }
+ return null;
+}
diff --git a/frontend/src/services/auth.test.ts b/frontend/src/services/auth.test.ts
index b21843acfcae..64850c37917a 100644
--- a/frontend/src/services/auth.test.ts
+++ b/frontend/src/services/auth.test.ts
@@ -1,18 +1,5 @@
-import * as jose from "jose";
import type { Mock } from "vitest";
-import { fetchToken, validateToken, getToken } from "./auth";
-
-vi.mock("jose", () => ({
- decodeJwt: vi.fn(),
-}));
-
-// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/)
-global.fetch = vi.fn(() =>
- Promise.resolve({
- status: 200,
- json: () => Promise.resolve({ token: "newToken" }),
- }),
-) as Mock;
+import { getToken } from "./auth";
Storage.prototype.getItem = vi.fn();
Storage.prototype.setItem = vi.fn();
@@ -22,66 +9,12 @@ describe("Auth Service", () => {
vi.clearAllMocks();
});
- describe("fetchToken", () => {
+ describe("getToken", () => {
it("should fetch and return a token", async () => {
- const data = await fetchToken();
-
+ (Storage.prototype.getItem as Mock).mockReturnValue("newToken");
+ const data = await getToken();
expect(localStorage.getItem).toHaveBeenCalledWith("token"); // Used to set Authorization header
- expect(data).toEqual({ token: "newToken" });
- expect(fetch).toHaveBeenCalledWith(`/api/auth`, {
- headers: expect.any(Headers),
- });
- });
-
- it("throws an error if response status is not 200", async () => {
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({ status: 401 }),
- );
- await expect(fetchToken()).rejects.toThrow("Get token failed.");
- });
- });
-
- describe("validateToken", () => {
- it("returns true for a valid token", () => {
- (jose.decodeJwt as Mock).mockReturnValue({ sid: "123" });
- expect(validateToken("validToken")).toBe(true);
- });
-
- it("returns false for an invalid token", () => {
- (jose.decodeJwt as Mock).mockReturnValue({});
- expect(validateToken("invalidToken")).toBe(false);
- });
-
- it("returns false when decodeJwt throws", () => {
- (jose.decodeJwt as Mock).mockImplementation(() => {
- throw new Error("Invalid token");
- });
- expect(validateToken("badToken")).toBe(false);
- });
- });
-
- describe("getToken", () => {
- it("returns existing valid token from localStorage", async () => {
- (jose.decodeJwt as Mock).mockReturnValue({ sid: "123" });
- (Storage.prototype.getItem as Mock).mockReturnValue("existingToken");
-
- const token = await getToken();
- expect(token).toBe("existingToken");
- });
-
- it("fetches, validates, and stores a new token when existing token is invalid", async () => {
- (jose.decodeJwt as Mock)
- .mockReturnValueOnce({})
- .mockReturnValueOnce({ sid: "123" });
-
- const token = await getToken();
- expect(token).toBe("newToken");
- expect(localStorage.setItem).toHaveBeenCalledWith("token", "newToken");
- });
-
- it("throws an error when fetched token is invalid", async () => {
- (jose.decodeJwt as Mock).mockReturnValue({});
- await expect(getToken()).rejects.toThrow("Token validation failed.");
+ expect(data).toEqual("newToken");
});
});
});
diff --git a/frontend/src/services/auth.ts b/frontend/src/services/auth.ts
index 29fd75db65b8..a7d8cfa490b9 100644
--- a/frontend/src/services/auth.ts
+++ b/frontend/src/services/auth.ts
@@ -1,44 +1,13 @@
-import * as jose from "jose";
-import { ResFetchToken } from "#/types/ResponseType";
+const TOKEN_KEY = "token";
-const fetchToken = async (): Promise => {
- const headers = new Headers({
- "Content-Type": "application/json",
- Authorization: `Bearer ${localStorage.getItem("token")}`,
- });
- const response = await fetch(`/api/auth`, { headers });
- if (response.status !== 200) {
- throw new Error("Get token failed.");
- }
- const data: ResFetchToken = await response.json();
- return data;
-};
+const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? "";
-export const validateToken = (token: string): boolean => {
- try {
- const claims = jose.decodeJwt(token);
- return !(claims.sid === undefined || claims.sid === "");
- } catch (error) {
- return false;
- }
+const clearToken = (): void => {
+ localStorage.removeItem(TOKEN_KEY);
};
-const getToken = async (): Promise => {
- const token = localStorage.getItem("token") ?? "";
- if (validateToken(token)) {
- return token;
- }
-
- const data = await fetchToken();
- if (data.token === undefined || data.token === "") {
- throw new Error("Get token failed.");
- }
- const newToken = data.token;
- if (validateToken(newToken)) {
- localStorage.setItem("token", newToken);
- return newToken;
- }
- throw new Error("Token validation failed.");
+const setToken = (token: string): void => {
+ localStorage.setItem(TOKEN_KEY, token);
};
-export { getToken, fetchToken };
+export { getToken, setToken, clearToken };
diff --git a/frontend/src/services/chatService.ts b/frontend/src/services/chatService.ts
index 52cb4821e143..af1ab45ce86b 100644
--- a/frontend/src/services/chatService.ts
+++ b/frontend/src/services/chatService.ts
@@ -1,28 +1,8 @@
-import store from "#/store";
import ActionType from "#/types/ActionType";
-import { SocketMessage } from "#/types/ResponseType";
-import { ActionMessage } from "#/types/Message";
-import Socket from "./socket";
-import { addUserMessage } from "#/state/chatSlice";
+import Session from "./session";
export function sendChatMessage(message: string): void {
const event = { action: ActionType.MESSAGE, args: { content: message } };
const eventString = JSON.stringify(event);
- Socket.send(eventString);
-}
-
-export function addChatMessageFromEvent(event: string | SocketMessage): void {
- try {
- let data: ActionMessage;
- if (typeof event === "string") {
- data = JSON.parse(event);
- } else {
- data = event as ActionMessage;
- }
- if (data && data.args && data.args.task) {
- store.dispatch(addUserMessage(data.args.task));
- }
- } catch (error) {
- //
- }
+ Session.send(eventString);
}
diff --git a/frontend/src/services/fileService.ts b/frontend/src/services/fileService.ts
index 21e8f7f054d2..f12a0e499bca 100644
--- a/frontend/src/services/fileService.ts
+++ b/frontend/src/services/fileService.ts
@@ -1,9 +1,7 @@
+import { request } from "./api";
+
export async function selectFile(file: string): Promise {
- const res = await fetch(`/api/select-file?file=${file}`);
- const data = await res.json();
- if (res.status !== 200) {
- throw new Error(data.error);
- }
+ const data = await request(`/api/select-file?file=${file}`);
return data.code as string;
}
@@ -13,20 +11,13 @@ export async function uploadFiles(files: FileList) {
formData.append("files", files[i]);
}
- const res = await fetch("/api/upload-files", {
+ await request("/api/upload-files", {
method: "POST",
body: formData,
});
-
- const data = await res.json();
-
- if (res.status !== 200) {
- throw new Error(data.error || "Failed to upload files.");
- }
}
export async function listFiles(path: string = "/"): Promise {
- const res = await fetch(`/api/list-files?path=${path}`);
- const data = await res.json();
+ const data = await request(`/api/list-files?path=${path}`);
return data as string[];
}
diff --git a/frontend/src/services/options.ts b/frontend/src/services/options.ts
new file mode 100644
index 000000000000..e3216be55d79
--- /dev/null
+++ b/frontend/src/services/options.ts
@@ -0,0 +1,9 @@
+import { request } from "./api";
+
+export async function fetchModels() {
+ return request(`/api/options/models`);
+}
+
+export async function fetchAgents() {
+ return request(`/api/options/agents`);
+}
diff --git a/frontend/src/services/session.test.ts b/frontend/src/services/session.test.ts
index 3ef95854804e..bf23b06bf5a0 100644
--- a/frontend/src/services/session.test.ts
+++ b/frontend/src/services/session.test.ts
@@ -1,127 +1,36 @@
-import type { Mock } from "vitest";
-import {
- ResDelMsg,
- ResFetchMsg,
- ResFetchMsgTotal,
- ResFetchMsgs,
-} from "../types/ResponseType";
-import { clearMsgs, fetchMsgTotal, fetchMsgs } from "./session";
-
-// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/)
-global.fetch = vi.fn();
-Storage.prototype.getItem = vi.fn();
-
-describe("Session Service", () => {
- beforeEach(() => {
- vi.clearAllMocks();
- });
-
- afterEach(() => {
- // Used to set Authorization header
- expect(localStorage.getItem).toHaveBeenCalledWith("token");
- });
-
- describe("fetchMsgTotal", () => {
- it("should fetch and return message total", async () => {
- const expectedResult: ResFetchMsgTotal = {
- msg_total: 10,
- };
-
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({
- status: 200,
- json: () => Promise.resolve(expectedResult),
- }),
- );
-
- const data = await fetchMsgTotal();
-
- expect(fetch).toHaveBeenCalledWith(`/api/messages/total`, {
- headers: expect.any(Headers),
- });
-
- expect(data).toEqual(expectedResult);
- });
-
- it("throws an error if response status is not 200", async () => {
- // NOTE: The current implementation ONLY handles 200 status;
- // this means throwing even with a status of 201, 204, etc.
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({ status: 401 }),
- );
-
- await expect(fetchMsgTotal()).rejects.toThrow(
- "Get message total failed.",
- );
- });
- });
-
- describe("fetchMsgs", () => {
- it("should fetch and return messages", async () => {
- const expectedResult: ResFetchMsgs = {
- messages: [
- {
- id: "1",
- role: "user",
- payload: {} as ResFetchMsg["payload"],
- },
- ],
- };
-
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({
- status: 200,
- json: () => Promise.resolve(expectedResult),
- }),
- );
-
- const data = await fetchMsgs();
-
- expect(fetch).toHaveBeenCalledWith(`/api/messages`, {
- headers: expect.any(Headers),
- });
-
- expect(data).toEqual(expectedResult);
- });
-
- it("throws an error if response status is not 200", async () => {
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({ status: 401 }),
- );
-
- await expect(fetchMsgs()).rejects.toThrow("Get messages failed.");
- });
+import { describe, expect, it, vi } from "vitest";
+
+import ActionType from "#/types/ActionType";
+import { Settings, saveSettings } from "./settings";
+import Session from "./session";
+
+const sendSpy = vi.spyOn(Session, "send");
+const setupSpy = vi
+ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
+ .spyOn(Session as any, "_setupSocket")
+ .mockImplementation(() => {
+ /* eslint-disable-next-line @typescript-eslint/dot-notation */
+ Session["_initializeAgent"](); // use key syntax to fix complaint about private fn
});
- describe("clearMsgs", () => {
- it("should clear messages", async () => {
- const expectedResult: ResDelMsg = {
- ok: "true",
- };
-
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({
- status: 200,
- json: () => Promise.resolve(expectedResult),
- }),
- );
-
- const data = await clearMsgs();
-
- expect(fetch).toHaveBeenCalledWith(`/api/messages`, {
- method: "DELETE",
- headers: expect.any(Headers),
- });
-
- expect(data).toEqual(expectedResult);
- });
-
- it("throws an error if response status is not 200", async () => {
- (fetch as Mock).mockImplementationOnce(() =>
- Promise.resolve({ status: 401 }),
- );
-
- await expect(clearMsgs()).rejects.toThrow("Delete messages failed.");
- });
+describe("startNewSession", () => {
+ it("Should start a new session with the current settings", () => {
+ const settings: Settings = {
+ LLM_MODEL: "llm_value",
+ AGENT: "agent_value",
+ LANGUAGE: "language_value",
+ LLM_API_KEY: "sk-...",
+ };
+
+ const event = {
+ action: ActionType.INIT,
+ args: settings,
+ };
+
+ saveSettings(settings);
+ Session.startNewSession();
+
+ expect(setupSpy).toHaveBeenCalledTimes(1);
+ expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
});
diff --git a/frontend/src/services/session.ts b/frontend/src/services/session.ts
index 91b6e98fe1e6..5b889952a699 100644
--- a/frontend/src/services/session.ts
+++ b/frontend/src/services/session.ts
@@ -1,49 +1,165 @@
-import {
- ResDelMsg,
- ResFetchMsgs,
- ResFetchMsgTotal,
-} from "../types/ResponseType";
-
-const fetchMsgTotal = async (): Promise => {
- const headers = new Headers({
- "Content-Type": "application/json",
- Authorization: `Bearer ${localStorage.getItem("token")}`,
- });
- const response = await fetch(`/api/messages/total`, { headers });
- if (response.status !== 200) {
- throw new Error("Get message total failed.");
- }
- const data: ResFetchMsgTotal = await response.json();
- return data;
-};
-
-const fetchMsgs = async (): Promise => {
- const headers = new Headers({
- "Content-Type": "application/json",
- Authorization: `Bearer ${localStorage.getItem("token")}`,
- });
- const response = await fetch(`/api/messages`, { headers });
- if (response.status !== 200) {
- throw new Error("Get messages failed.");
- }
- const data: ResFetchMsgs = await response.json();
- return data;
-};
-
-const clearMsgs = async (): Promise => {
- const headers = new Headers({
- "Content-Type": "application/json",
- Authorization: `Bearer ${localStorage.getItem("token")}`,
- });
- const response = await fetch(`/api/messages`, {
- method: "DELETE",
- headers,
- });
- if (response.status !== 200) {
- throw new Error("Delete messages failed.");
- }
- const data: ResDelMsg = await response.json();
- return data;
-};
-
-export { fetchMsgTotal, fetchMsgs, clearMsgs };
+import toast from "#/utils/toast";
+import { handleAssistantMessage } from "./actions";
+import { getToken, setToken, clearToken } from "./auth";
+import ActionType from "#/types/ActionType";
+import { getSettings } from "./settings";
+
+class Session {
+ private static _socket: WebSocket | null = null;
+
+ // callbacks contain a list of callable functions
+ // event: function, like:
+ // open: [function1, function2]
+ // message: [function1, function2]
+ private static callbacks: {
+ [K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[];
+ } = {
+ open: [],
+ message: [],
+ error: [],
+ close: [],
+ };
+
+ private static _connecting = false;
+
+ private static _disconnecting = false;
+
+ public static restoreOrStartNewSession() {
+ const token = getToken();
+ if (Session.isConnected()) {
+ Session.disconnect();
+ }
+ Session._connect(token);
+ }
+
+ public static startNewSession() {
+ clearToken();
+ Session.restoreOrStartNewSession();
+ }
+
+ private static _initializeAgent = () => {
+ const settings = getSettings();
+ const event = { action: ActionType.INIT, args: settings };
+ const eventString = JSON.stringify(event);
+ Session.send(eventString);
+ };
+
+ private static _connect(token: string = ""): void {
+ if (Session.isConnected()) return;
+ Session._connecting = true;
+
+ const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
+ const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
+ Session._socket = new WebSocket(WS_URL);
+ Session._setupSocket();
+ }
+
+ private static _setupSocket(): void {
+ if (!Session._socket) {
+ throw new Error("Socket is not initialized.");
+ }
+ Session._socket.onopen = (e) => {
+ toast.success("ws", "Connected to server.");
+ Session._connecting = false;
+ Session._initializeAgent();
+ Session.callbacks.open?.forEach((callback) => {
+ callback(e);
+ });
+ };
+
+ Session._socket.onmessage = (e) => {
+ let data = null;
+ try {
+ data = JSON.parse(e.data);
+ } catch (err) {
+ // TODO: report the error
+ console.error("Error parsing JSON data", err);
+ return;
+ }
+ if (data.error && data.error_code === 401) {
+ clearToken();
+ } else if (data.token) {
+ setToken(data.token);
+ } else {
+ handleAssistantMessage(data);
+ }
+ };
+
+ Session._socket.onerror = () => {
+ const msg = "Connection failed. Retry...";
+ toast.error("ws", msg);
+ };
+
+ Session._socket.onclose = () => {
+ if (!Session._disconnecting) {
+ setTimeout(() => {
+ Session.restoreOrStartNewSession();
+ }, 3000); // Reconnect after 3 seconds
+ }
+ Session._disconnecting = false;
+ };
+ }
+
+ static isConnected(): boolean {
+ return (
+ Session._socket !== null && Session._socket.readyState === WebSocket.OPEN
+ );
+ }
+
+ static disconnect(): void {
+ Session._disconnecting = true;
+ if (Session._socket) {
+ Session._socket.close();
+ }
+ Session._socket = null;
+ }
+
+ static send(message: string): void {
+ if (Session._connecting) {
+ setTimeout(() => Session.send(message), 1000);
+ return;
+ }
+ if (!Session.isConnected()) {
+ throw new Error("Not connected to server.");
+ }
+
+ if (Session.isConnected()) {
+ Session._socket?.send(message);
+ } else {
+ const msg = "Connection failed. Retry...";
+ toast.error("ws", msg);
+ }
+ }
+
+ static addEventListener(
+ event: string,
+ callback: (e: MessageEvent) => void,
+ ): void {
+ Session._socket?.addEventListener(
+ event as keyof WebSocketEventMap,
+ callback as (
+ this: WebSocket,
+ ev: WebSocketEventMap[keyof WebSocketEventMap],
+ ) => never,
+ );
+ }
+
+ static removeEventListener(
+ event: string,
+ listener: (e: Event) => void,
+ ): void {
+ Session._socket?.removeEventListener(event, listener);
+ }
+
+ static registerCallback(
+ event: K,
+ callbacks: ((data: WebSocketEventMap[K]) => void)[],
+ ): void {
+ if (Session.callbacks[event] === undefined) {
+ return;
+ }
+ Session.callbacks[event].push(...callbacks);
+ }
+}
+
+export default Session;
diff --git a/frontend/src/services/socket.ts b/frontend/src/services/socket.ts
deleted file mode 100644
index 981c25978371..000000000000
--- a/frontend/src/services/socket.ts
+++ /dev/null
@@ -1,129 +0,0 @@
-// import { toast } from "sonner";
-import toast from "#/utils/toast";
-import { handleAssistantMessage } from "./actions";
-import { getToken } from "./auth";
-
-class Socket {
- private static _socket: WebSocket | null = null;
-
- // callbacks contain a list of callable functions
- // event: function, like:
- // open: [function1, function2]
- // message: [function1, function2]
- private static callbacks: {
- [K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[];
- } = {
- open: [],
- message: [],
- error: [],
- close: [],
- };
-
- private static initializing = false;
-
- public static tryInitialize(): void {
- if (Socket.initializing) return;
- Socket.initializing = true;
- getToken()
- .then((token) => {
- Socket._initialize(token);
- })
- .catch(() => {
- const msg = `Connection failed. Retry...`;
- toast.stickyError("ws", msg);
-
- setTimeout(() => {
- this.tryInitialize();
- }, 1500);
- });
- }
-
- private static _initialize(token: string): void {
- if (Socket.isConnected()) return;
-
- const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
- const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
- Socket._socket = new WebSocket(WS_URL);
-
- Socket._socket.onopen = (e) => {
- toast.stickySuccess("ws", "Connected to server.");
- Socket.initializing = false;
- Socket.callbacks.open?.forEach((callback) => {
- callback(e);
- });
- };
-
- Socket._socket.onmessage = (e) => {
- handleAssistantMessage(e.data);
- };
-
- Socket._socket.onerror = () => {
- const msg = "Connection failed. Retry...";
- toast.stickyError("ws", msg);
- };
-
- Socket._socket.onclose = () => {
- // Reconnect after a delay
- setTimeout(() => {
- Socket.tryInitialize();
- }, 3000); // Reconnect after 3 seconds
- };
- }
-
- static isConnected(): boolean {
- return (
- Socket._socket !== null && Socket._socket.readyState === WebSocket.OPEN
- );
- }
-
- static send(message: string): void {
- if (!Socket.isConnected()) {
- Socket.tryInitialize();
- }
- if (Socket.initializing) {
- setTimeout(() => Socket.send(message), 1000);
- return;
- }
-
- if (Socket.isConnected()) {
- Socket._socket?.send(message);
- } else {
- const msg = "Connection failed. Retry...";
- toast.stickyError("ws", msg);
- }
- }
-
- static addEventListener(
- event: string,
- callback: (e: MessageEvent) => void,
- ): void {
- Socket._socket?.addEventListener(
- event as keyof WebSocketEventMap,
- callback as (
- this: WebSocket,
- ev: WebSocketEventMap[keyof WebSocketEventMap],
- ) => never,
- );
- }
-
- static removeEventListener(
- event: string,
- listener: (e: Event) => void,
- ): void {
- Socket._socket?.removeEventListener(event, listener);
- }
-
- static registerCallback(
- event: K,
- callbacks: ((data: WebSocketEventMap[K]) => void)[],
- ): void {
- if (Socket.callbacks[event] === undefined) {
- return;
- }
- Socket.callbacks[event].push(...callbacks);
- }
-}
-
-Socket.tryInitialize();
-
-export default Socket;
diff --git a/frontend/src/services/taskService.ts b/frontend/src/services/taskService.ts
index 88b877abf323..fa84444e294d 100644
--- a/frontend/src/services/taskService.ts
+++ b/frontend/src/services/taskService.ts
@@ -1,3 +1,5 @@
+import { request } from "./api";
+
export type Task = {
id: string;
goal: string;
@@ -14,14 +16,6 @@ export enum TaskState {
}
export async function getRootTask(): Promise {
- const headers = new Headers({
- "Content-Type": "application/json",
- Authorization: `Bearer ${localStorage.getItem("token")}`,
- });
- const res = await fetch("/api/root_task", { headers });
- if (res.status !== 200 && res.status !== 204) {
- return undefined;
- }
- const data = (await res.json()) as Task;
- return data;
+ const res = await request("/api/root_task");
+ return res as Task;
}
diff --git a/frontend/src/state/chatSlice.ts b/frontend/src/state/chatSlice.ts
index 1438103f9126..757806e1e048 100644
--- a/frontend/src/state/chatSlice.ts
+++ b/frontend/src/state/chatSlice.ts
@@ -3,13 +3,7 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
type SliceState = { messages: Message[] };
const initialState: SliceState = {
- messages: [
- {
- content:
- "Hi! I'm OpenDevin, an AI Software Engineer. What would you like to build with me today?",
- sender: "assistant",
- },
- ],
+ messages: [],
};
export const chatSlice = createSlice({
diff --git a/frontend/src/types/ResponseType.tsx b/frontend/src/types/ResponseType.tsx
index 63d710e1eeed..b635d78c3370 100644
--- a/frontend/src/types/ResponseType.tsx
+++ b/frontend/src/types/ResponseType.tsx
@@ -1,41 +1,5 @@
import { ActionMessage, ObservationMessage } from "./Message";
-type Role = "user" | "assistant";
-
-interface ResConfigurations {
- [key: string]: string | boolean | number;
-}
-
-interface ResFetchToken {
- token: string;
-}
-
-interface ResFetchMsgTotal {
- msg_total: number;
-}
-
-interface ResFetchMsg {
- id: string;
- role: Role;
- payload: SocketMessage;
-}
-
-interface ResFetchMsgs {
- messages: ResFetchMsg[];
-}
-
-interface ResDelMsg {
- ok: string;
-}
-
type SocketMessage = ActionMessage | ObservationMessage;
-export {
- type ResConfigurations,
- type ResFetchToken,
- type ResFetchMsgTotal,
- type ResFetchMsg,
- type ResFetchMsgs,
- type ResDelMsg,
- type SocketMessage,
-};
+export { type SocketMessage };
diff --git a/frontend/src/utils/toast.tsx b/frontend/src/utils/toast.tsx
index 08debd157475..132b3497c28b 100644
--- a/frontend/src/utils/toast.tsx
+++ b/frontend/src/utils/toast.tsx
@@ -3,15 +3,10 @@ import toast from "react-hot-toast";
const idMap = new Map();
export default {
- stickyError: (id: string, msg: string) => {
+ error: (id: string, msg: string) => {
if (idMap.has(id)) return; // prevent duplicate toast
- const toastId = toast.loading(msg, {
- // icon: "👏",
- // style: {
- // borderRadius: "10px",
- // background: "#333",
- // color: "#fff",
- // },
+ const toastId = toast(msg, {
+ duration: 4000,
style: {
background: "#ef4444",
color: "#fff",
@@ -24,12 +19,13 @@ export default {
});
idMap.set(id, toastId);
},
- stickySuccess: (id: string, msg: string) => {
+ success: (id: string, msg: string) => {
const toastId = idMap.get(id);
if (toastId === undefined) return;
if (toastId) {
toast.success(msg, {
id: toastId,
+ duration: 4000,
style: {
background: "#333",
color: "#fff",
diff --git a/opendevin/const/guide_url.py b/opendevin/const/guide_url.py
deleted file mode 100644
index 80bb3cfa985c..000000000000
--- a/opendevin/const/guide_url.py
+++ /dev/null
@@ -1 +0,0 @@
-TROUBLESHOOTING_URL = 'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting'
diff --git a/opendevin/controller/agent_controller.py b/opendevin/controller/agent_controller.py
index cb374f9a8021..13db2501c5a2 100644
--- a/opendevin/controller/agent_controller.py
+++ b/opendevin/controller/agent_controller.py
@@ -243,6 +243,9 @@ async def _step(self):
def get_state(self):
return self.state
+ def set_state(self, state: State):
+ self.state = state
+
def _is_stuck(self):
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():
diff --git a/opendevin/core/config.py b/opendevin/core/config.py
index aba7708c8f97..1de2b78a5a04 100644
--- a/opendevin/core/config.py
+++ b/opendevin/core/config.py
@@ -3,6 +3,7 @@
import os
import pathlib
import platform
+import uuid
from dataclasses import dataclass, field, fields, is_dataclass
from types import UnionType
from typing import Any, ClassVar, get_args, get_origin
@@ -173,6 +174,7 @@ class AppConfig(metaclass=Singleton):
sandbox_user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
sandbox_timeout: int = 120
github_token: str | None = None
+ jwt_secret: str = uuid.uuid4().hex
debug: bool = False
enable_auto_lint: bool = (
False # once enabled, OpenDevin would lint files after editing
diff --git a/opendevin/core/const/guide_url.py b/opendevin/core/const/guide_url.py
new file mode 100644
index 000000000000..7ec5e6c908ac
--- /dev/null
+++ b/opendevin/core/const/guide_url.py
@@ -0,0 +1,3 @@
+TROUBLESHOOTING_URL = (
+ 'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting'
+)
diff --git a/opendevin/runtime/docker/exec_box.py b/opendevin/runtime/docker/exec_box.py
index dd283bf87551..6af621d2517e 100644
--- a/opendevin/runtime/docker/exec_box.py
+++ b/opendevin/runtime/docker/exec_box.py
@@ -10,8 +10,8 @@
import docker
-from opendevin.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.config import config
+from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import CancellableStream
diff --git a/opendevin/runtime/docker/ssh_box.py b/opendevin/runtime/docker/ssh_box.py
index 9328deed913d..94fa04c5c383 100644
--- a/opendevin/runtime/docker/ssh_box.py
+++ b/opendevin/runtime/docker/ssh_box.py
@@ -12,8 +12,8 @@
import docker
from pexpect import exceptions, pxssh
-from opendevin.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.config import config
+from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import CancellableStream
diff --git a/opendevin/runtime/e2b/filestore.py b/opendevin/runtime/e2b/filestore.py
new file mode 100644
index 000000000000..0494dde21a51
--- /dev/null
+++ b/opendevin/runtime/e2b/filestore.py
@@ -0,0 +1,18 @@
+from opendevin.storage.files import FileStore
+
+
+class E2BFileStore(FileStore):
+ def __init__(self, filesystem):
+ self.filesystem = filesystem
+
+ def write(self, path: str, contents: str) -> None:
+ self.filesystem.write(path, contents)
+
+ def read(self, path: str) -> str:
+ return self.filesystem.read(path)
+
+ def list(self, path: str) -> list[str]:
+ return self.filesystem.list(path)
+
+ def delete(self, path: str) -> None:
+ self.filesystem.delete(path)
diff --git a/opendevin/runtime/e2b/runtime.py b/opendevin/runtime/e2b/runtime.py
index 5542ae63af12..1a0710bf91ff 100644
--- a/opendevin/runtime/e2b/runtime.py
+++ b/opendevin/runtime/e2b/runtime.py
@@ -13,6 +13,7 @@
from opendevin.runtime.server.files import insert_lines, read_lines
from opendevin.runtime.server.runtime import ServerRuntime
+from .filestore import E2BFileStore
from .sandbox import E2BSandbox
@@ -26,25 +27,25 @@ def __init__(
super().__init__(event_stream, sid, sandbox)
if not isinstance(self.sandbox, E2BSandbox):
raise ValueError('E2BRuntime requires an E2BSandbox')
- self.filesystem = self.sandbox.filesystem
+ self.file_store = E2BFileStore(self.sandbox.filesystem)
async def read(self, action: FileReadAction) -> Observation:
- content = self.filesystem.read(action.path)
+ content = self.file_store.read(action.path)
lines = read_lines(content.split('\n'), action.start, action.end)
code_view = ''.join(lines)
return FileReadObservation(code_view, path=action.path)
async def write(self, action: FileWriteAction) -> Observation:
if action.start == 0 and action.end == -1:
- self.filesystem.write(action.path, action.content)
+ self.file_store.write(action.path, action.content)
return FileWriteObservation(content='', path=action.path)
- files = self.filesystem.list(action.path)
+ files = self.file_store.list(action.path)
if action.path in files:
- all_lines = self.filesystem.read(action.path)
+ all_lines = self.file_store.read(action.path).split('\n')
new_file = insert_lines(
action.content.split('\n'), all_lines, action.start, action.end
)
- self.filesystem.write(action.path, ''.join(new_file))
+ self.file_store.write(action.path, ''.join(new_file))
return FileWriteObservation('', path=action.path)
else:
# FIXME: we should create a new file here
diff --git a/opendevin/runtime/runtime.py b/opendevin/runtime/runtime.py
index 9a52fbff9392..79a7120e52f3 100644
--- a/opendevin/runtime/runtime.py
+++ b/opendevin/runtime/runtime.py
@@ -31,6 +31,7 @@
)
from opendevin.runtime.browser.browser_env import BrowserEnv
from opendevin.runtime.plugins import PluginRequirement
+from opendevin.storage import FileStore, InMemoryFileStore
def create_sandbox(sid: str = 'default', sandbox_type: str = 'exec') -> Sandbox:
@@ -55,6 +56,7 @@ class Runtime:
"""
sid: str
+ file_store: FileStore
def __init__(
self,
@@ -70,6 +72,7 @@ def __init__(
self.sandbox = sandbox
self._is_external_sandbox = True
self.browser = BrowserEnv()
+ self.file_store = InMemoryFileStore()
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self._bg_task = asyncio.create_task(self._start_background_observation_loop())
diff --git a/opendevin/runtime/server/runtime.py b/opendevin/runtime/server/runtime.py
index b399ff9badea..8f04984d591e 100644
--- a/opendevin/runtime/server/runtime.py
+++ b/opendevin/runtime/server/runtime.py
@@ -1,3 +1,4 @@
+from opendevin.core.config import config
from opendevin.events.action import (
AgentRecallAction,
BrowseInteractiveAction,
@@ -15,13 +16,25 @@
NullObservation,
Observation,
)
+from opendevin.events.stream import EventStream
+from opendevin.runtime import Sandbox
from opendevin.runtime.runtime import Runtime
+from opendevin.storage.local import LocalFileStore
from .browse import browse
from .files import read_file, write_file
class ServerRuntime(Runtime):
+ def __init__(
+ self,
+ event_stream: EventStream,
+ sid: str = 'default',
+ sandbox: Sandbox | None = None,
+ ):
+ super().__init__(event_stream, sid, sandbox)
+ self.file_store = LocalFileStore(config.workspace_base)
+
async def run(self, action: CmdRunAction) -> Observation:
return self._run_command(action.command, background=action.background)
@@ -71,10 +84,12 @@ async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
return IPythonRunCellObservation(content=output, code=action.code)
async def read(self, action: FileReadAction) -> Observation:
+ # TODO: use self.file_store
working_dir = self.sandbox.get_working_directory()
return await read_file(action.path, working_dir, action.start, action.end)
async def write(self, action: FileWriteAction) -> Observation:
+ # TODO: use self.file_store
working_dir = self.sandbox.get_working_directory()
return await write_file(
action.path, working_dir, action.content, action.start, action.end
diff --git a/opendevin/server/agent/__init__.py b/opendevin/server/agent/__init__.py
deleted file mode 100644
index bfb0bc1f1a92..000000000000
--- a/opendevin/server/agent/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .manager import AgentManager
-
-agent_manager = AgentManager()
-
-__all__ = ['AgentManager', 'agent_manager']
diff --git a/opendevin/server/agent/agent.py b/opendevin/server/agent/agent.py
deleted file mode 100644
index a897815e9600..000000000000
--- a/opendevin/server/agent/agent.py
+++ /dev/null
@@ -1,161 +0,0 @@
-from typing import Optional
-
-from agenthub.codeact_agent.codeact_agent import CodeActAgent
-from opendevin.const.guide_url import TROUBLESHOOTING_URL
-from opendevin.controller import AgentController
-from opendevin.controller.agent import Agent
-from opendevin.core.config import config
-from opendevin.core.logger import opendevin_logger as logger
-from opendevin.core.schema import ActionType, AgentState, ConfigType
-from opendevin.events.action import (
- ChangeAgentStateAction,
- NullAction,
-)
-from opendevin.events.event import Event
-from opendevin.events.observation import (
- NullObservation,
-)
-from opendevin.events.serialization.action import action_from_dict
-from opendevin.events.serialization.event import event_to_dict
-from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
-from opendevin.llm.llm import LLM
-from opendevin.runtime import DockerSSHBox
-from opendevin.runtime.e2b.runtime import E2BRuntime
-from opendevin.runtime.runtime import Runtime
-from opendevin.runtime.server.runtime import ServerRuntime
-from opendevin.server.session import session_manager
-
-
-class AgentUnit:
- """Represents a session with an agent.
-
- Attributes:
- controller: The AgentController instance for controlling the agent.
- """
-
- sid: str
- event_stream: EventStream
- controller: Optional[AgentController] = None
- runtime: Optional[Runtime] = None
-
- def __init__(self, sid):
- """Initializes a new instance of the Session class."""
- self.sid = sid
- self.event_stream = EventStream(sid)
- self.event_stream.subscribe(EventStreamSubscriber.SERVER, self.on_event)
- if config.runtime == 'server':
- logger.info('Using server runtime')
- self.runtime = ServerRuntime(self.event_stream, sid)
- elif config.runtime == 'e2b':
- logger.info('Using E2B runtime')
- self.runtime = E2BRuntime(self.event_stream, sid)
-
- async def send_error(self, message):
- """Sends an error message to the client.
-
- Args:
- message: The error message to send.
- """
- await session_manager.send_error(self.sid, message)
-
- async def send_message(self, message):
- """Sends a message to the client.
-
- Args:
- message: The message to send.
- """
- await session_manager.send_message(self.sid, message)
-
- async def send(self, data):
- """Sends data to the client.
-
- Args:
- data: The data to send.
- """
- await session_manager.send(self.sid, data)
-
- async def dispatch(self, action: str | None, data: dict):
- """Dispatches actions to the agent from the client."""
- if action is None:
- await self.send_error('Invalid action')
- return
-
- if action == ActionType.INIT:
- await self.create_controller(data)
- await self.event_stream.add_event(
- ChangeAgentStateAction(AgentState.INIT), EventSource.USER
- )
- return
-
- action_dict = data.copy()
- action_dict['action'] = action
- action_obj = action_from_dict(action_dict)
- await self.event_stream.add_event(action_obj, EventSource.USER)
-
- async def create_controller(self, start_event: dict):
- """Creates an AgentController instance.
-
- Args:
- start_event: The start event data (optional).
- """
- args = {
- key: value
- for key, value in start_event.get('args', {}).items()
- if value != ''
- } # remove empty values, prevent FE from sending empty strings
- agent_cls = args.get(ConfigType.AGENT, config.agent.name)
- model = args.get(ConfigType.LLM_MODEL, config.llm.model)
- api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key)
- api_base = config.llm.base_url
- max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
- max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars)
-
- logger.info(f'Creating agent {agent_cls} using LLM {model}')
- llm = LLM(model=model, api_key=api_key, base_url=api_base)
- agent = Agent.get_cls(agent_cls)(llm)
- if isinstance(agent, CodeActAgent):
- if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
- logger.warning(
- 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
- )
- # Initializing plugins into the runtime
- assert self.runtime is not None, 'Runtime is not initialized'
- self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
-
- if self.controller is not None:
- await self.controller.close()
- try:
- self.controller = AgentController(
- sid=self.sid,
- event_stream=self.event_stream,
- agent=agent,
- max_iterations=int(max_iterations),
- max_chars=int(max_chars),
- )
- except Exception as e:
- logger.exception(f'Error creating controller: {e}')
- await self.send_error(
- f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
- )
- return
-
- async def on_event(self, event: Event):
- """Callback function for agent events.
-
- Args:
- event: The agent event (Observation or Action).
- """
- if isinstance(event, NullAction):
- return
- if isinstance(event, NullObservation):
- return
- if event.source == 'agent' and not isinstance(
- event, (NullAction, NullObservation)
- ):
- await self.send(event_to_dict(event))
-
- async def close(self):
- if self.controller is not None:
- await self.controller.close()
- if self.runtime is not None:
- self.runtime.close()
diff --git a/opendevin/server/agent/manager.py b/opendevin/server/agent/manager.py
deleted file mode 100644
index 5f9593ef6428..000000000000
--- a/opendevin/server/agent/manager.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import asyncio, atexit
-
-from opendevin.core.logger import opendevin_logger as logger
-from opendevin.server.session import session_manager
-
-from .agent import AgentUnit
-
-
-class AgentManager:
- sid_to_agent: dict[str, 'AgentUnit'] = {}
-
- def __init__(self):
- atexit.register(self.close)
-
- def register_agent(self, sid: str):
- """Registers a new agent.
-
- Args:
- sid: The session ID of the agent.
- """
- if sid not in self.sid_to_agent:
- self.sid_to_agent[sid] = AgentUnit(sid)
- return
-
- # TODO: confirm whether the agent is alive
-
- async def dispatch(self, sid: str, action: str | None, data: dict):
- """Dispatches actions to the agent from the client."""
- if sid not in self.sid_to_agent:
- # self.register_agent(sid) # auto-register agent, may be opened later
- logger.error(f'Agent not registered: {sid}')
- await session_manager.send_error(sid, 'Agent not registered')
- return
-
- await self.sid_to_agent[sid].dispatch(action, data)
-
- def close(self):
- try:
- loop = asyncio.get_event_loop()
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- loop.run_until_complete(self._close())
-
- async def _close(self):
- logger.info(f'Closing {len(self.sid_to_agent)} agent(s)...')
- for sid, agent in self.sid_to_agent.items():
- await agent.close()
diff --git a/opendevin/server/auth/auth.py b/opendevin/server/auth/auth.py
index 98507e8f81ea..36cadd98428a 100644
--- a/opendevin/server/auth/auth.py
+++ b/opendevin/server/auth/auth.py
@@ -1,12 +1,9 @@
-import os
-
import jwt
from jwt.exceptions import InvalidTokenError
+from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
-JWT_SECRET = os.getenv('JWT_SECRET', '5ecRe7')
-
def get_sid_from_token(token: str) -> str:
"""
@@ -20,7 +17,7 @@ def get_sid_from_token(token: str) -> str:
"""
try:
# Decode the JWT using the specified secret and algorithm
- payload = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
+ payload = jwt.decode(token, config.jwt_secret, algorithms=['HS256'])
# Ensure the payload contains 'sid'
if 'sid' in payload:
@@ -41,4 +38,4 @@ def sign_token(payload: dict[str, object]) -> str:
# "sid": sid,
# # "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
# }
- return jwt.encode(payload, JWT_SECRET, algorithm='HS256')
+ return jwt.encode(payload, config.jwt_secret, algorithm='HS256')
diff --git a/opendevin/server/listen.py b/opendevin/server/listen.py
index cf21cbfa2deb..5bdeb15b8199 100644
--- a/opendevin/server/listen.py
+++ b/opendevin/server/listen.py
@@ -1,26 +1,25 @@
-import os
-import shutil
import uuid
import warnings
-from pathlib import Path
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
-from fastapi import Depends, FastAPI, Request, Response, UploadFile, WebSocket, status
+from fastapi import FastAPI, Request, Response, UploadFile, WebSocket, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
-from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
import agenthub # noqa F401 (we import this to get the agents registered)
from opendevin.controller.agent import Agent
from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
+from opendevin.events.action import ChangeAgentStateAction, NullAction
+from opendevin.events.observation import AgentStateChangedObservation, NullObservation
+from opendevin.events.serialization import event_to_dict
from opendevin.llm import bedrock
-from opendevin.server.agent import agent_manager
from opendevin.server.auth import get_sid_from_token, sign_token
-from opendevin.server.session import message_stack, session_manager
+from opendevin.server.session import session_manager
app = FastAPI()
app.add_middleware(
@@ -34,6 +33,45 @@
security_scheme = HTTPBearer()
+@app.middleware('http')
+async def attach_session(request: Request, call_next):
+ if request.url.path.startswith('/api/options/') or not request.url.path.startswith(
+ '/api/'
+ ):
+ response = await call_next(request)
+ return response
+
+ if not request.headers.get('Authorization'):
+ response = JSONResponse(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ content={'error': 'Missing Authorization header'},
+ )
+ return response
+
+ auth_token = request.headers.get('Authorization')
+ if 'Bearer' in auth_token:
+ auth_token = auth_token.split('Bearer')[1].strip()
+
+ request.state.sid = get_sid_from_token(auth_token)
+ if request.state.sid == '':
+ response = JSONResponse(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ content={'error': 'Invalid token'},
+ )
+ return response
+
+ request.state.session = session_manager.get_session(request.state.sid)
+ if request.state.session is None:
+ response = JSONResponse(
+ status_code=status.HTTP_404_NOT_FOUND,
+ content={'error': 'Session not found'},
+ )
+ return response
+
+ response = await call_next(request)
+ return response
+
+
# This endpoint receives events from the client (i.e. the browser)
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
@@ -99,16 +137,41 @@ async def websocket_endpoint(websocket: WebSocket):
```
"""
await websocket.accept()
- sid = get_sid_from_token(websocket.query_params.get('token') or '')
- if sid == '':
- logger.error('Failed to decode token')
- return
- session_manager.add_session(sid, websocket)
- agent_manager.register_agent(sid)
- await session_manager.loop_recv(sid, agent_manager.dispatch)
+ session = None
+ if websocket.query_params.get('token'):
+ token = websocket.query_params.get('token')
+ sid = get_sid_from_token(token)
+
+ if sid == '':
+ await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
+ await websocket.close()
+ return
+ else:
+ sid = str(uuid.uuid4())
+ token = sign_token({'sid': sid})
+
+ session = session_manager.add_or_restart_session(sid, websocket)
+ await websocket.send_json({'token': token, 'status': 'ok'})
-@app.get('/api/litellm-models')
+ last_event_id = -1
+ if websocket.query_params.get('last_event_id'):
+ last_event_id = int(websocket.query_params.get('last_event_id'))
+ for event in session.agent_session.event_stream.get_events(
+ start_id=last_event_id + 1
+ ):
+ if isinstance(event, NullAction) or isinstance(event, NullObservation):
+ continue
+ if isinstance(event, ChangeAgentStateAction) or isinstance(
+ event, AgentStateChangedObservation
+ ):
+ continue
+ await websocket.send_json(event_to_dict(event))
+
+ await session.loop_recv()
+
+
+@app.get('/api/options/models')
async def get_litellm_models():
"""
Get all models supported by LiteLLM.
@@ -128,7 +191,7 @@ async def get_litellm_models():
return list(set(model_list))
-@app.get('/api/agents')
+@app.get('/api/options/agents')
async def get_agents():
"""
Get all agents supported by LiteLLM.
@@ -142,89 +205,6 @@ async def get_agents():
return agents
-@app.get('/api/auth')
-async def get_token(
- credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
-):
- """
- Generate a JWT for authentication when starting a WebSocket connection. This endpoint checks if valid credentials
- are provided and uses them to get a session ID. If no valid credentials are provided, it generates a new session ID.
-
- To obtain an authentication token:
- ```sh
- curl -H "Authorization: Bearer 5ecRe7" http://localhost:3000/api/auth
- ```
- **Note:** If `JWT_SECRET` is set, use its value instead of `5ecRe7`.
- """
- if credentials and credentials.credentials:
- sid = get_sid_from_token(credentials.credentials)
- if not sid:
- sid = str(uuid.uuid4())
- logger.info(
- f'Invalid or missing credentials, generating new session ID: {sid}'
- )
- else:
- sid = str(uuid.uuid4())
- logger.info(f'No credentials provided, generating new session ID: {sid}')
-
- token = sign_token({'sid': sid})
- return {'token': token, 'status': 'ok'}
-
-
-@app.get('/api/messages')
-async def get_messages(
- credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
-):
- """
- Get messages.
-
- To get messages:
- ```sh
- curl -H "Authorization: Bearer " http://localhost:3000/api/messages
- ```
- """
- data = []
- sid = get_sid_from_token(credentials.credentials)
- if sid != '':
- data = message_stack.get_messages(sid)
-
- return {'messages': data}
-
-
-@app.get('/api/messages/total')
-async def get_message_total(
- credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
-):
- """
- Get total message count.
-
- To get the total message count:
- ```sh
- curl -H "Authorization: Bearer " http://localhost:3000/api/messages/total
- ```
- """
- sid = get_sid_from_token(credentials.credentials)
- return {'msg_total': message_stack.get_message_total(sid)}
-
-
-@app.delete('/api/messages')
-async def del_messages(
- credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
-):
- """
- Delete messages.
-
- To delete messages:
- ```sh
-
- curl -X DELETE -H "Authorization: Bearer " http://localhost:3000/api/messages
- ```
- """
- sid = get_sid_from_token(credentials.credentials)
- message_stack.del_messages(sid)
- return {'ok': True}
-
-
@app.get('/api/list-files')
def list_files(request: Request, path: str = '/'):
"""
@@ -235,27 +215,25 @@ def list_files(request: Request, path: str = '/'):
curl http://localhost:3000/api/list-files
```
"""
- if path.startswith('/'):
- path = path[1:]
- abs_path = os.path.join(config.workspace_base, path)
+ if not request.state.session.agent_session.runtime:
+ return JSONResponse(
+ status_code=status.HTTP_404_NOT_FOUND,
+ content={'error': 'Runtime not yet initialized'},
+ )
+
try:
- files = os.listdir(abs_path)
+ return request.state.session.agent_session.runtime.file_store.list(path)
except Exception as e:
- logger.error(f'Error listing files: {e}', exc_info=False)
+ logger.error(f'Error refreshing files: {e}', exc_info=False)
+ error_msg = f'Error refreshing files: {e}'
return JSONResponse(
- status_code=status.HTTP_404_NOT_FOUND,
- content={'error': 'Path not found'},
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ content={'error': error_msg},
)
- files = [os.path.join(path, f) for f in files]
- files = [
- f + '/' if os.path.isdir(os.path.join(config.workspace_base, f)) else f
- for f in files
- ]
- return files
@app.get('/api/select-file')
-def select_file(file: str):
+def select_file(file: str, request: Request):
"""
Select a file.
@@ -265,12 +243,7 @@ def select_file(file: str):
```
"""
try:
- workspace_base = config.workspace_base
- file_path = Path(workspace_base, file)
- # The following will check if the file is within the workspace base and throw an exception if not
- file_path.resolve().relative_to(Path(workspace_base).resolve())
- with open(file_path, 'r') as selected_file:
- content = selected_file.read()
+ content = request.state.session.agent_session.runtime.file_store.read(file)
except Exception as e:
logger.error(f'Error opening file {file}: {e}', exc_info=False)
error_msg = f'Error opening file: {e}'
@@ -282,7 +255,7 @@ def select_file(file: str):
@app.post('/api/upload-files')
-async def upload_files(files: list[UploadFile]):
+async def upload_file(request: Request, files: list[UploadFile]):
"""
Upload files to the workspace.
@@ -292,13 +265,11 @@ async def upload_files(files: list[UploadFile]):
```
"""
try:
- workspace_base = config.workspace_base
for file in files:
- file_path = Path(workspace_base, file.filename)
- # The following will check if the file is within the workspace base and throw an exception if not
- file_path.resolve().relative_to(Path(workspace_base).resolve())
- with open(file_path, 'wb') as buffer:
- shutil.copyfileobj(file.file, buffer)
+ file_contents = await file.read()
+ request.state.session.agent_session.runtime.file_store.write(
+ file.filename, file_contents
+ )
except Exception as e:
logger.error(f'Error saving files: {e}', exc_info=True)
return JSONResponse(
@@ -309,9 +280,7 @@ async def upload_files(files: list[UploadFile]):
@app.get('/api/root_task')
-def get_root_task(
- credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
-):
+def get_root_task(request: Request):
"""
Get root_task.
@@ -320,9 +289,7 @@ def get_root_task(
curl -H "Authorization: Bearer " http://localhost:3000/api/root_task
```
"""
- sid = get_sid_from_token(credentials.credentials)
- agent = agent_manager.sid_to_agent[sid]
- controller = agent.controller
+ controller = request.state.session.agent_session.controller
if controller is not None:
state = controller.get_state()
if state:
diff --git a/opendevin/server/mock/listen.py b/opendevin/server/mock/listen.py
index 5a4d36a13d83..517702f506d1 100644
--- a/opendevin/server/mock/listen.py
+++ b/opendevin/server/mock/listen.py
@@ -33,7 +33,7 @@ def read_root():
return {'message': 'This is a mock server'}
-@app.get('/api/litellm-models')
+@app.get('/api/options/models')
def read_llm_models():
return [
'gpt-4',
@@ -43,7 +43,7 @@ def read_llm_models():
]
-@app.get('/api/agents')
+@app.get('/api/options/agents')
def read_llm_agents():
return [
'MonologueAgent',
@@ -52,16 +52,6 @@ def read_llm_agents():
]
-@app.get('/api/messages')
-async def get_messages():
- return {'messages': []}
-
-
-@app.get('/api/messages/total')
-async def get_message_total():
- return {'msg_total': 0}
-
-
@app.get('/api/list-files')
def refresh_files():
return ['hello_world.py']
diff --git a/opendevin/server/session/__init__.py b/opendevin/server/session/__init__.py
index 405429cd07cd..391952d6aa31 100644
--- a/opendevin/server/session/__init__.py
+++ b/opendevin/server/session/__init__.py
@@ -1,5 +1,4 @@
from .manager import SessionManager
-from .msg_stack import message_stack
from .session import Session
session_manager = SessionManager()
diff --git a/opendevin/server/session/agent.py b/opendevin/server/session/agent.py
new file mode 100644
index 000000000000..aacb791bda2d
--- /dev/null
+++ b/opendevin/server/session/agent.py
@@ -0,0 +1,117 @@
+from typing import Optional
+
+from agenthub.codeact_agent.codeact_agent import CodeActAgent
+from opendevin.controller import AgentController
+from opendevin.controller.agent import Agent
+from opendevin.controller.state.state import State
+from opendevin.core.config import config
+from opendevin.core.logger import opendevin_logger as logger
+from opendevin.core.schema import ConfigType
+from opendevin.events.stream import EventStream
+from opendevin.llm.llm import LLM
+from opendevin.runtime import DockerSSHBox
+from opendevin.runtime.e2b.runtime import E2BRuntime
+from opendevin.runtime.runtime import Runtime
+from opendevin.runtime.server.runtime import ServerRuntime
+
+
+class AgentSession:
+ """Represents a session with an agent.
+
+ Attributes:
+ controller: The AgentController instance for controlling the agent.
+ """
+
+ sid: str
+ event_stream: EventStream
+ controller: Optional[AgentController] = None
+ runtime: Optional[Runtime] = None
+ _closed: bool = False
+
+ def __init__(self, sid):
+ """Initializes a new instance of the Session class."""
+ self.sid = sid
+ self.event_stream = EventStream(sid)
+
+ async def start(self, start_event: dict):
+ """Starts the agent session.
+
+ Args:
+ start_event: The start event data (optional).
+ """
+ if self.controller or self.runtime:
+ raise Exception(
+ 'Session already started. You need to close this session and start a new one.'
+ )
+ await self._create_runtime()
+ await self._create_controller(start_event)
+
+ async def close(self):
+ if self._closed:
+ return
+ if self.controller is not None:
+ end_state = self.controller.get_state()
+ end_state.save_to_session(self.sid)
+ await self.controller.close()
+ if self.runtime is not None:
+ self.runtime.close()
+ self._closed = True
+
+ async def _create_runtime(self):
+ if self.runtime is not None:
+ raise Exception('Runtime already created')
+ if config.runtime == 'server':
+ logger.info('Using server runtime')
+ self.runtime = ServerRuntime(self.event_stream, self.sid)
+ elif config.runtime == 'e2b':
+ logger.info('Using E2B runtime')
+ self.runtime = E2BRuntime(self.event_stream, self.sid)
+ else:
+ raise Exception(
+ f'Runtime not defined in config, or is invalid: {config.runtime}'
+ )
+
+ async def _create_controller(self, start_event: dict):
+ """Creates an AgentController instance.
+
+ Args:
+ start_event: The start event data (optional).
+ """
+ if self.controller is not None:
+ raise Exception('Controller already created')
+ if self.runtime is None:
+ raise Exception('Runtime must be initialized before the agent controller')
+ args = {
+ key: value
+ for key, value in start_event.get('args', {}).items()
+ if value != ''
+ } # remove empty values, prevent FE from sending empty strings
+ agent_cls = args.get(ConfigType.AGENT, config.agent.name)
+ model = args.get(ConfigType.LLM_MODEL, config.llm.model)
+ api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key)
+ api_base = config.llm.base_url
+ max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
+ max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars)
+
+ logger.info(f'Creating agent {agent_cls} using LLM {model}')
+ llm = LLM(model=model, api_key=api_key, base_url=api_base)
+ agent = Agent.get_cls(agent_cls)(llm)
+ if isinstance(agent, CodeActAgent):
+ if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
+ logger.warning(
+ 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
+ )
+ self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
+
+ self.controller = AgentController(
+ sid=self.sid,
+ event_stream=self.event_stream,
+ agent=agent,
+ max_iterations=int(max_iterations),
+ max_chars=int(max_chars),
+ )
+ try:
+ agent_state = State.restore_from_session(self.sid)
+ self.controller.set_state(agent_state)
+ except Exception as e:
+ print('Error restoring state', e)
diff --git a/opendevin/server/session/manager.py b/opendevin/server/session/manager.py
index 73bb0dcee9ba..cae6cae17e07 100644
--- a/opendevin/server/session/manager.py
+++ b/opendevin/server/session/manager.py
@@ -1,20 +1,12 @@
import asyncio
-import atexit
-import json
-import os
import time
-from typing import Callable
from fastapi import WebSocket
from opendevin.core.logger import opendevin_logger as logger
-from .msg_stack import message_stack
from .session import Session
-CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
-SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json')
-
class SessionManager:
_sessions: dict[str, Session] = {}
@@ -22,30 +14,21 @@ class SessionManager:
session_timeout: int = 600
def __init__(self):
- self._load_sessions()
- atexit.register(self.close)
asyncio.create_task(self._cleanup_sessions())
- def add_session(self, sid: str, ws_conn: WebSocket):
- if sid not in self._sessions:
- self._sessions[sid] = Session(sid=sid, ws=ws_conn)
- return
- self._sessions[sid].update_connection(ws_conn)
+ def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
+ if sid in self._sessions:
+ asyncio.create_task(self._sessions[sid].close())
+ self._sessions[sid] = Session(sid=sid, ws=ws_conn)
+ return self._sessions[sid]
- async def loop_recv(self, sid: str, dispatch: Callable):
- print(f'Starting loop_recv for sid: {sid}')
- """Starts listening for messages from the client."""
+ def get_session(self, sid: str) -> Session | None:
if sid not in self._sessions:
- return
- await self._sessions[sid].loop_recv(dispatch)
-
- def close(self):
- logger.info('Saving sessions...')
- self._save_sessions()
+ return None
+ return self._sessions.get(sid)
async def send(self, sid: str, data: dict[str, object]) -> bool:
"""Sends data to the client."""
- message_stack.add_message(sid, 'assistant', data)
if sid not in self._sessions:
return False
return await self._sessions[sid].send(data)
@@ -58,33 +41,6 @@ async def send_message(self, sid: str, message: str) -> bool:
"""Sends a message to the client."""
return await self.send(sid, {'message': message})
- def _save_sessions(self):
- data = {}
- for sid, conn in self._sessions.items():
- data[sid] = {
- 'sid': conn.sid,
- 'last_active_ts': conn.last_active_ts,
- 'is_alive': conn.is_alive,
- }
- if not os.path.exists(CACHE_DIR):
- os.makedirs(CACHE_DIR)
- with open(SESSION_CACHE_FILE, 'w+') as file:
- json.dump(data, file)
-
- def _load_sessions(self):
- try:
- with open(SESSION_CACHE_FILE, 'r') as file:
- data = json.load(file)
- for sid, sdata in data.items():
- conn = Session(sid, None)
- ok = conn.load_from_data(sdata)
- if ok:
- self._sessions[sid] = conn
- except FileNotFoundError:
- pass
- except json.decoder.JSONDecodeError:
- pass
-
async def _cleanup_sessions(self):
while True:
current_time = time.time()
diff --git a/opendevin/server/session/msg_stack.py b/opendevin/server/session/msg_stack.py
deleted file mode 100644
index 6e0862af4426..000000000000
--- a/opendevin/server/session/msg_stack.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import asyncio
-import atexit
-import json
-import os
-import uuid
-
-from opendevin.core.logger import opendevin_logger as logger
-from opendevin.core.schema.action import ActionType
-
-CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
-MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
-
-
-class Message:
- id: str = str(uuid.uuid4())
- role: str # "user"| "assistant"
- payload: dict[str, object]
-
- def __init__(self, role: str, payload: dict[str, object]):
- self.role = role
- self.payload = payload
-
- def to_dict(self):
- return {'id': self.id, 'role': self.role, 'payload': self.payload}
-
- @classmethod
- def from_dict(cls, data: dict):
- m = cls(data['role'], data['payload'])
- m.id = data['id']
- return m
-
-
-class MessageStack:
- _messages: dict[str, list[Message]] = {}
-
- def __init__(self):
- self._load_messages()
- atexit.register(self.close)
-
- def close(self):
- logger.info('Saving messages...')
- self._save_messages()
-
- def add_message(self, sid: str, role: str, message: dict[str, object]):
- if sid not in self._messages:
- self._messages[sid] = []
- self._messages[sid].append(Message(role, message))
-
- def del_messages(self, sid: str):
- if sid not in self._messages:
- return
- del self._messages[sid]
- asyncio.create_task(self._del_messages(sid))
-
- def get_messages(self, sid: str) -> list[dict[str, object]]:
- if sid not in self._messages:
- return []
- return [msg.to_dict() for msg in self._messages[sid]]
-
- def get_message_total(self, sid: str) -> int:
- if sid not in self._messages:
- return 0
- cnt = 0
- for msg in self._messages[sid]:
- # Ignore assistant init message for now.
- if 'action' in msg.payload and msg.payload['action'] in [
- ActionType.INIT,
- ActionType.CHANGE_AGENT_STATE,
- ]:
- continue
- cnt += 1
- return cnt
-
- def _save_messages(self):
- if not os.path.exists(CACHE_DIR):
- os.makedirs(CACHE_DIR)
- data = {}
- for sid, msgs in self._messages.items():
- data[sid] = [msg.to_dict() for msg in msgs]
- with open(MSG_CACHE_FILE, 'w+') as file:
- json.dump(data, file)
-
- def _load_messages(self):
- try:
- with open(MSG_CACHE_FILE, 'r') as file:
- data = json.load(file)
- for sid, msgs in data.items():
- self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
- except FileNotFoundError:
- pass
- except json.decoder.JSONDecodeError:
- pass
-
- async def _del_messages(self, del_sid: str):
- logger.info('Deleting messages...')
- try:
- with open(MSG_CACHE_FILE, 'r+') as file:
- data = json.load(file)
- new_data = {}
- for sid, msgs in data.items():
- if sid != del_sid:
- new_data[sid] = msgs
- # Move the file pointer to the beginning of the file to overwrite the original contents
- file.seek(0)
- # clean previous content
- file.truncate()
- json.dump(new_data, file)
- except FileNotFoundError:
- pass
- except json.decoder.JSONDecodeError:
- pass
-
-
-message_stack = MessageStack()
diff --git a/opendevin/server/session/session.py b/opendevin/server/session/session.py
index 0ca5c6d968de..18df617f11df 100644
--- a/opendevin/server/session/session.py
+++ b/opendevin/server/session/session.py
@@ -1,11 +1,19 @@
+import asyncio
import time
-from typing import Callable
from fastapi import WebSocket, WebSocketDisconnect
+from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.logger import opendevin_logger as logger
+from opendevin.core.schema import AgentState
+from opendevin.core.schema.action import ActionType
+from opendevin.events.action import ChangeAgentStateAction, NullAction
+from opendevin.events.event import Event
+from opendevin.events.observation import AgentStateChangedObservation, NullObservation
+from opendevin.events.serialization import EventSource, event_from_dict, event_to_dict
+from opendevin.events.stream import EventStreamSubscriber
-from .msg_stack import message_stack
+from .agent import AgentSession
DEL_DELT_SEC = 60 * 60 * 5
@@ -15,13 +23,22 @@ class Session:
websocket: WebSocket | None
last_active_ts: int = 0
is_alive: bool = True
+ agent_session: AgentSession
def __init__(self, sid: str, ws: WebSocket | None):
self.sid = sid
self.websocket = ws
self.last_active_ts = int(time.time())
+ self.agent_session = AgentSession(sid)
+ self.agent_session.event_stream.subscribe(
+ EventStreamSubscriber.SERVER, self.on_event
+ )
- async def loop_recv(self, dispatch: Callable):
+ async def close(self):
+ self.is_alive = False
+ await self.agent_session.close()
+
+ async def loop_recv(self):
try:
if self.websocket is None:
return
@@ -31,24 +48,62 @@ async def loop_recv(self, dispatch: Callable):
except ValueError:
await self.send_error('Invalid JSON')
continue
-
- message_stack.add_message(self.sid, 'user', data)
- action = data.get('action', None)
- await dispatch(self.sid, action, data)
+ await self.dispatch(data)
except WebSocketDisconnect:
- self.is_alive = False
+ await self.close()
logger.info('WebSocket disconnected, sid: %s', self.sid)
except RuntimeError as e:
- # WebSocket is not connected
- if 'WebSocket is not connected' in str(e):
- self.is_alive = False
+ await self.close()
logger.exception('Error in loop_recv: %s', e)
+ async def _initialize_agent(self, data: dict):
+ await self.agent_session.event_stream.add_event(
+ ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
+ )
+ await self.agent_session.event_stream.add_event(
+ AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
+ )
+ try:
+ await self.agent_session.start(data)
+ except Exception as e:
+ logger.exception(f'Error creating controller: {e}')
+ await self.send_error(
+ f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
+ )
+ return
+ await self.agent_session.event_stream.add_event(
+ ChangeAgentStateAction(AgentState.INIT), EventSource.USER
+ )
+
+ async def on_event(self, event: Event):
+ """Callback function for agent events.
+
+ Args:
+ event: The agent event (Observation or Action).
+ """
+ if isinstance(event, NullAction):
+ return
+ if isinstance(event, NullObservation):
+ return
+ if event.source == EventSource.AGENT and not isinstance(
+ event, (NullAction, NullObservation)
+ ):
+ await self.send(event_to_dict(event))
+
+ async def dispatch(self, data: dict):
+ action = data.get('action', '')
+ if action == ActionType.INIT:
+ await self._initialize_agent(data)
+ return
+ event = event_from_dict(data.copy())
+ await self.agent_session.event_stream.add_event(event, EventSource.USER)
+
async def send(self, data: dict[str, object]) -> bool:
try:
if self.websocket is None or not self.is_alive:
return False
await self.websocket.send_json(data)
+ await asyncio.sleep(0.001) # This flushes the data to the client
self.last_active_ts = int(time.time())
return True
except WebSocketDisconnect:
diff --git a/opendevin/storage/local.py b/opendevin/storage/local.py
index 8cc530599b6d..27791b1e68e7 100644
--- a/opendevin/storage/local.py
+++ b/opendevin/storage/local.py
@@ -28,7 +28,9 @@ def read(self, path: str) -> str:
def list(self, path: str) -> list[str]:
full_path = self.get_full_path(path)
- return [os.path.join(path, f) for f in os.listdir(full_path)]
+ files = [os.path.join(path, f) for f in os.listdir(full_path)]
+ files = [f + '/' if os.path.isdir(self.get_full_path(f)) else f for f in files]
+ return files
def delete(self, path: str) -> None:
full_path = self.get_full_path(path)
diff --git a/opendevin/storage/memory.py b/opendevin/storage/memory.py
index 45d4ce100f03..ea797ba7c699 100644
--- a/opendevin/storage/memory.py
+++ b/opendevin/storage/memory.py
@@ -30,6 +30,8 @@ def list(self, path: str) -> list[str]:
files.append(file)
else:
dir_path = os.path.join(path, parts[0])
+ if not dir_path.endswith('/'):
+ dir_path += '/'
if dir_path not in files:
files.append(dir_path)
return files
diff --git a/tests/unit/test_storage.py b/tests/unit/test_storage.py
index 272b648b7e44..dc173ff8fe59 100644
--- a/tests/unit/test_storage.py
+++ b/tests/unit/test_storage.py
@@ -54,10 +54,8 @@ def test_deep_list(setup_env):
store.write('foo/bar/baz.txt', 'Hello, world!')
store.write('foo/bar/qux.txt', 'Hello, world!')
store.write('foo/bar/quux.txt', 'Hello, world!')
- assert store.list('') == ['foo'], 'Expected foo, got {} for class {}'.format(
- store.list(''), store.__class__
- )
- assert store.list('foo') == ['foo/bar']
+ assert store.list('') == ['foo/'], f'for class {store.__class__}'
+ assert store.list('foo') == ['foo/bar/']
assert (
store.list('foo/bar').sort()
== ['foo/bar/baz.txt', 'foo/bar/qux.txt', 'foo/bar/quux.txt'].sort()