Skip to content
46 changes: 35 additions & 11 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
)
from reflex.utils.imports import ImportVar
from reflex.utils.misc import run_in_thread
from reflex.utils.token_manager import TokenManager
from reflex.utils.token_manager import RedisTokenManager, TokenManager
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send

if TYPE_CHECKING:
Expand Down Expand Up @@ -2033,11 +2033,13 @@ def __init__(self, namespace: str, app: App):
self._token_manager = TokenManager.create()

@property
def token_to_sid(self) -> dict[str, str]:
def token_to_sid(self) -> Mapping[str, str]:
"""Get token to SID mapping for backward compatibility.

Note: this mapping is read-only.

Returns:
The token to SID mapping dict.
The token to SID mapping.
"""
# For backward compatibility, expose the underlying dict
return self._token_manager.token_to_sid
Expand All @@ -2059,6 +2061,9 @@ async def on_connect(self, sid: str, environ: dict):
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
if isinstance(self._token_manager, RedisTokenManager):
# Make sure this instance is watching for updates from other instances.
self._token_manager.ensure_lost_and_found_task(self.emit_update)
query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", ""))
token_list = query_params.get("token", [])
if token_list:
Expand All @@ -2072,11 +2077,14 @@ async def on_connect(self, sid: str, environ: dict):
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
)

def on_disconnect(self, sid: str):
def on_disconnect(self, sid: str) -> asyncio.Task | None:
"""Event for when the websocket disconnects.

Args:
sid: The Socket.IO session id.

Returns:
An asyncio Task for cleaning up the token, or None.
"""
# Get token before cleaning up
disconnect_token = self.sid_to_token.get(sid)
Expand All @@ -2091,6 +2099,8 @@ def on_disconnect(self, sid: str):
lambda t: t.exception()
and console.error(f"Token cleanup error: {t.exception()}")
)
return task
return None

async def emit_update(self, update: StateUpdate, token: str) -> None:
"""Emit an update to the client.
Expand All @@ -2100,16 +2110,30 @@ async def emit_update(self, update: StateUpdate, token: str) -> None:
token: The client token (tab) associated with the event.
"""
client_token, _ = _split_substate_key(token)
sid = self.token_to_sid.get(client_token)
if sid is None:
# If the sid is None, we are not connected to a client. Prevent sending
# updates to all clients.
console.warn(f"Attempting to send delta to disconnected client {token!r}")
socket_record = self._token_manager.token_to_socket.get(client_token)
if (
socket_record is None
or socket_record.instance_id != self._token_manager.instance_id
):
if isinstance(self._token_manager, RedisTokenManager):
# The socket belongs to another instance of the app, send it to the lost and found.
if not await self._token_manager.emit_lost_and_found(
client_token, update
):
console.warn(
f"Failed to send delta to lost and found for client {token!r}"
)
else:
# If the socket record is None, we are not connected to a client. Prevent sending
# updates to all clients.
console.warn(
f"Attempting to send delta to disconnected client {token!r}"
)
return
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update, to=sid),
name=f"reflex_emit_event|{token}|{sid}|{time.time()}",
self.emit(str(constants.SocketEvent.EVENT), update, to=socket_record.sid),
name=f"reflex_emit_event|{token}|{socket_record.sid}|{time.time()}",
)

async def on_event(self, sid: str, data: Any):
Expand Down
2 changes: 1 addition & 1 deletion reflex/istate/manager/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class StateManagerRedis(StateManager):
# The keyspace subscription string when redis is waiting for lock to be released.
_redis_notify_keyspace_events: str = dataclasses.field(
default="K" # Enable keyspace notifications (target a particular key)
"$" # For String commands (like setting keys)
"g" # For generic commands (DEL, EXPIRE, etc)
"x" # For expired events
"e" # For evicted events (i.e. maxmemory exceeded)
Expand All @@ -76,7 +77,6 @@ class StateManagerRedis(StateManager):
_redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
default_factory=lambda: {
b"del",
b"expire",
b"expired",
b"evicted",
}
Expand Down
Loading