Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connection state #2671

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
11 changes: 10 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from reflex.state import (
BaseState,
RouterData,
SessionStatus,
State,
StateManager,
StateUpdate,
Expand Down Expand Up @@ -1138,6 +1139,9 @@ async def process(
# assignment will recurse into substates and force recalculation of
# dependent ComputedVar (dynamic route variables)
state.router_data = router_data
if state.router:
state.router.update(router_data)
else:
benedikt-bartscher marked this conversation as resolved.
Show resolved Hide resolved
state.router = RouterData(router_data)

# Preprocess the event.
Expand Down Expand Up @@ -1326,7 +1330,7 @@ def on_connect(self, sid, environ):
"""
pass

def on_disconnect(self, sid):
async def on_disconnect(self, sid):
"""Event for when the websocket disconnects.

Args:
Expand All @@ -1335,6 +1339,11 @@ def on_disconnect(self, sid):
disconnect_token = self.sid_to_token.pop(sid, None)
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
else:
return

async with self.app.state_manager.modify_state(disconnect_token) as state:
state.router.session.status = SessionStatus.DISCONNECTED

async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.
Expand Down
44 changes: 36 additions & 8 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import asyncio
import contextlib
import copy
import datetime
import enum
import functools
import inspect
import os
Expand Down Expand Up @@ -120,24 +122,42 @@ def __init__(self, router_data: Optional[dict] = None):
self.params = router_data.get(constants.RouteVar.QUERY, {})


class SessionStatus(enum.Enum):
"""The status of the session."""

INITIAL = "initial"
CONNECTED = "connected"
DISCONNECTED = "disconnected"
RECONNECTED = "reconnected"


class SessionData(Base):
"""An object containing session data."""

client_token: str = ""
client_ip: str = ""
session_id: str = ""
status: SessionStatus = SessionStatus.INITIAL
# also represents disconnected_at if status is DISCONNECTED
last_event: datetime.datetime = datetime.datetime.now()

def __init__(self, router_data: Optional[dict] = None):
"""Initalize the SessionData object based on router_data.
def update(self, router_data: Optional[dict] = None):
"""Update the session data based on the router_data.

Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
self.last_event = datetime.datetime.now()
if not router_data:
return
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
new_session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
if self.session_id and new_session_id and self.session_id != new_session_id:
self.status = SessionStatus.RECONNECTED
else:
self.status = SessionStatus.CONNECTED
self.session_id = new_session_id


class RouterData(Base):
Expand All @@ -154,7 +174,15 @@ def __init__(self, router_data: Optional[dict] = None):
router_data: the router_data dict.
"""
super().__init__()
self.session = SessionData(router_data)
self.update(router_data)

def update(self, router_data: Optional[dict] = None):
"""Update the router data based on the router_data.

Args:
router_data: the router_data dict.
"""
self.session.update(router_data)
self.headers = HeaderData(router_data)
self.page = PageData(router_data)

Expand Down
Loading