Skip to content

Commit

Permalink
fix session state after resuming (OpenDevin#1999)
Browse files Browse the repository at this point in the history
* fix state resuming

* fix session reconnection

* fix lint
  • Loading branch information
rbren authored and super-dainiu committed May 23, 2024
1 parent 862f96e commit 61d7e9a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 18 deletions.
22 changes: 17 additions & 5 deletions frontend/src/services/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { getSettings } from "./settings";
class Session {
private static _socket: WebSocket | null = null;

private static _latest_event_id: number = -1;

// callbacks contain a list of callable functions
// event: function, like:
// open: [function1, function2]
Expand All @@ -25,11 +27,10 @@ class Session {
private static _disconnecting = false;

public static restoreOrStartNewSession() {
const token = getToken();
if (Session.isConnected()) {
Session.disconnect();
}
Session._connect(token);
Session._connect();
}

public static startNewSession() {
Expand All @@ -44,13 +45,20 @@ class Session {
Session.send(eventString);
};

private static _connect(token: string = ""): void {
private static _connect(): void {
if (Session.isConnected()) return;
Session._connecting = true;

const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
Session._socket = new WebSocket(WS_URL);
let wsURL = `${protocol}//${window.location.host}/ws`;
const token = getToken();
if (token) {
wsURL += `?token=${token}`;
if (Session._latest_event_id !== -1) {
wsURL += `&latest_event_id=${Session._latest_event_id}`;
}
}
Session._socket = new WebSocket(wsURL);
Session._setupSocket();
}

Expand All @@ -77,10 +85,14 @@ class Session {
return;
}
if (data.error && data.error_code === 401) {
Session._latest_event_id = -1;
clearToken();
} else if (data.token) {
setToken(data.token);
} else {
if (data.id !== undefined) {
Session._latest_event_id = data.id;
}
handleAssistantMessage(data);
}
};
Expand Down
5 changes: 5 additions & 0 deletions opendevin/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ async def set_agent_state_to(self, new_state: AgentState):
logger.info(
f'Setting agent({type(self.agent).__name__}) state from {self.state.agent_state} to {new_state}'
)

if new_state == self.state.agent_state:
return

Expand All @@ -169,6 +170,10 @@ async def set_agent_state_to(self, new_state: AgentState):
AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
)

if new_state == AgentState.INIT and self.state.resume_state:
await self.set_agent_state_to(self.state.resume_state)
self.state.resume_state = None

def get_agent_state(self):
"""Returns the current state of the agent task."""
return self.state.agent_state
Expand Down
13 changes: 13 additions & 0 deletions opendevin/controller/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
)
from opendevin.storage import get_file_store

RESUMABLE_STATES = [
AgentState.RUNNING,
AgentState.PAUSED,
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
]


@dataclass
class State:
Expand All @@ -31,6 +38,7 @@ class State:
outputs: dict = field(default_factory=dict)
error: str | None = None
agent_state: AgentState = AgentState.LOADING
resume_state: AgentState | None = None
metrics: Metrics = Metrics()

def save_to_session(self, sid: str):
Expand All @@ -53,6 +61,11 @@ def restore_from_session(sid: str) -> 'State':
except Exception as e:
logger.error(f'Failed to restore state from session: {e}')
raise e
if state.agent_state in RESUMABLE_STATES:
state.resume_state = state.agent_state
else:
state.resume_state = None
state.agent_state = AgentState.LOADING
return state

def get_current_user_intent(self):
Expand Down
19 changes: 10 additions & 9 deletions opendevin/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@ def _get_id_from_filename(self, filename: str) -> int:
return int(filename.split('/')[-1].split('.')[0])

def get_events(self, start_id=0, end_id=None) -> Iterable[Event]:
try:
events = self._file_store.list(f'sessions/{self.sid}/events')
except FileNotFoundError:
return
for event_str in events:
id = self._get_id_from_filename(event_str)
if start_id <= id and (end_id is None or id <= end_id):
event = self.get_event(id)
yield event
event_id = start_id
while True:
if end_id is not None and event_id > end_id:
break
try:
event = self.get_event(event_id)
except FileNotFoundError:
break
yield event
event_id += 1

def get_event(self, id: int) -> Event:
filename = self._get_filename_for_id(id)
Expand Down
8 changes: 4 additions & 4 deletions opendevin/server/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ async def websocket_endpoint(websocket: WebSocket):
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': token, 'status': 'ok'})

last_event_id = -1
if websocket.query_params.get('last_event_id'):
last_event_id = int(websocket.query_params.get('last_event_id'))
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
for event in session.agent_session.event_stream.get_events(
start_id=last_event_id + 1
start_id=latest_event_id + 1
):
if isinstance(event, NullAction) or isinstance(event, NullObservation):
continue
Expand Down

0 comments on commit 61d7e9a

Please sign in to comment.