feat: support reconnecting websocket #257

Closed
wants to merge 6 commits
54 changes: 33 additions & 21 deletions packages/solara-widget-manager/src/manager.ts
@@ -54,8 +54,7 @@ const WIDGET_MIMETYPE = 'application/vnd.jupyter.widget-view+json';
/**
* Time (in ms) after which we consider the control comm target not responding.
*/
export const CONTROL_COMM_TIMEOUT = 500;

export const CONTROL_COMM_TIMEOUT = 2000;
/**
* A custom widget manager to render widgets with Voila
*/
@@ -78,11 +77,15 @@ export class WidgetManager extends JupyterLabManager {
);
this._registerWidgets();
this._loader = requireLoader;
const commId = base.uuid();
const kernel = context.sessionContext?.session?.kernel;
this.connectControlComm();
if (!kernel) {
throw new Error('No current kernel');
}
}
async connectControlComm() {
const commId = base.uuid();
const kernel = this.context.sessionContext?.session?.kernel;
this.controlComm = kernel.createComm('solara.control', commId);
this.controlCommHandler = {
onMsg: (msg) => {
@@ -107,31 +110,40 @@
}
async check() {
// checks if the app is still valid (e.g. the server restarted and lost the widget state)
const okPromise = new Promise((resolve, reject) => {
this.controlCommHandler = {
onMsg: (msg) => {
// if we are connected to the same kernel, we'll get a reply instantly
// however, if we are connected to a new kernel, we rely on the timeout
// so we create a new comm every time.

const kernel = this.context.sessionContext?.session?.kernel;
const commId = base.uuid();
const controlComm = kernel.createComm('solara.control', commId);
controlComm.open({}, {}, [])
try {
return await new Promise((resolve, reject) => {
controlComm.onMsg = (msg) => {
const data = msg['content']['data'];
if (data.method === 'finished') {
resolve(data.ok);
if (data.method === 'check') {
if (data.ok === true) {
resolve({ ok: true, message: data.message });
} else {
resolve({ ok: false, message: data.message });
}
}
else {
reject(data.error);
reject({ ok: false, message: "unexpected message" });
}
},
onClose: () => {
}
controlComm.onClose = () => {
console.error("closed solara control comm")
reject()
reject({ ok: false, message: "closed solara control comm" });
}
};
setTimeout(() => {
reject('timeout');
}, CONTROL_COMM_TIMEOUT);
});
this.controlComm.send({ method: 'check' });
try {
return await okPromise;
setTimeout(() => {
reject('timeout');
}, CONTROL_COMM_TIMEOUT);
controlComm.send({ method: 'check' });
});
} catch (e) {
return false;
return { ok: false, message: e };
}
}

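The check() method above is one side of a small request/reply handshake on a `solara.control` comm: the frontend opens a fresh comm, sends `{method: "check"}`, and either resolves with the server's `{method: "check", ok, message}` reply or rejects after CONTROL_COMM_TIMEOUT (now 2000 ms). Below is a minimal sketch of the same reply-or-timeout pattern in Python; the names `send_check`, `send`, and `replies` are illustrative stand-ins for the comm machinery, not part of this PR.

import asyncio

CONTROL_COMM_TIMEOUT_S = 2.0  # mirrors CONTROL_COMM_TIMEOUT (2000 ms) above

async def send_check(send, replies: "asyncio.Queue[dict]") -> dict:
    # send {'method': 'check'} and wait for the reply, or give up after the timeout
    send({"method": "check"})
    try:
        data = await asyncio.wait_for(replies.get(), CONTROL_COMM_TIMEOUT_S)
    except asyncio.TimeoutError:
        return {"ok": False, "message": "timeout"}
    if data.get("method") != "check":
        return {"ok": False, "message": "unexpected message"}
    return {"ok": bool(data.get("ok")), "message": data.get("message")}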
56 changes: 50 additions & 6 deletions solara/server/app.py
@@ -259,7 +259,10 @@ def reload(self):
@dataclasses.dataclass
class AppContext:
id: str
session_id: str
kernel: kernel.Kernel
# some object to identify the thread/eventloop/callstack that created the context
owner: Any = dataclasses.field(default_factory=lambda: object())
control_sockets: List[WebSocket] = dataclasses.field(default_factory=list)
# this is the 'private' version of the normally global ipywidgets.Widgets.widget dict
# see patch.py
@@ -274,6 +277,13 @@ class AppContext:
reload: Callable = lambda: None
state: Any = None
container: Optional[DOMWidget] = None
# did the client miss any messages?
missed_messages: bool = False
# was this page session created from a reconnected kernel?
# this information is used to determine whether a reconnected websocket
# is connected to the same worker or to a new one. If it is connected to
# a new worker, we failed to reconnect the app.
reconnected_kernel: bool = False

def display(self, *args):
print(args) # noqa
@@ -303,6 +313,7 @@ def close(self):
# what if we reference each other
# import gc
# gc.collect()
self.kernel.close()
if self.id in contexts:
del contexts[self.id]

@@ -340,6 +351,7 @@ def create_dummy_context():

app_context = AppContext(
id="dummy",
session_id="dummy-id",
kernel=kernel.Kernel(),
)
return app_context
@@ -491,6 +503,13 @@ def on_msg(msg):
comm.send({"method": "finished", "widget_id": context.container._model_id})
elif method == "check":
context = get_current_context()
if context.missed_messages:
comm.send({"method": "check", "ok": False, "message": "Missed messages"})
else:
if context.reconnected_kernel:
comm.send({"method": "check", "ok": True, "message": "All fine"})
else:
comm.send({"method": "check", "ok": False, "message": "Not reconnected"})
elif method == "reload":
assert app is not None
context = get_current_context()
@@ -501,6 +520,11 @@ def on_msg(msg):

comm.on_msg(on_msg)

def on_close(msg):
logger.info("solara control comm closed: %r", msg)

comm.on_close(on_close)

def reload():
# we don't reload the app ourself, we send a message to the client
# this ensures that we don't run code of any client that for some reason is connected
@@ -517,17 +541,37 @@ def register_solara_comm_target(kernel: Kernel):
kernel.comm_manager.register_target("solara.control", solara_comm_target)


def initialize_virtual_kernel(context_id: str, websocket: websocket.WebsocketWrapper):
kernel = Kernel()
logger.info("new virtual kernel: %s", context_id)
context = contexts[context_id] = AppContext(id=context_id, kernel=kernel, control_sockets=[], widgets={}, templates={})
def initialize_virtual_kernel(session_id: str, context_id: str, websocket: websocket.WebsocketWrapper):
if context_id in contexts:
logger.info("reusing virtual kernel: %s", context_id)
context = contexts[context_id]
if context.session_id != session_id:
logger.critical("Session id mismatch when reusing kernel (hack attempt?): %s != %s", context.session_id, session_id)
websocket.send_text("Session id mismatch when reusing kernel (hack attempt?)")
websocket.close()
raise ValueError("Session id mismatch")
kernel = context.kernel
context.reconnected_kernel = True
else:
kernel = Kernel()
logger.info("new virtual kernel: %s", context_id)
context = contexts[context_id] = AppContext(id=context_id, session_id=session_id, kernel=kernel, control_sockets=[], widgets={}, templates={})
context.reconnected_kernel = False
with context:
widgets.register_comm_target(kernel)
register_solara_comm_target(kernel)
with context:
widgets.register_comm_target(kernel)
register_solara_comm_target(kernel)
assert kernel is Kernel.instance()
kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell")
kernel.control_stream = WebsocketStreamWrapper(websocket, "control")
kernel.session.websockets.add(websocket)
context.missed_messages = kernel.session.has_dropped_messages
if not kernel.session.has_dropped_messages:
if kernel.session.message_queue:
logger.info("Sending messages from queue (due to reconnect)")
for message in kernel.session.message_queue:
kernel.session.send_websockets(message)
kernel.session.message_queue.clear()


from . import patch # noqa
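The `check` branch added to on_msg above reduces to a simple rule: a reconnect is only reported as successful when no messages were dropped and the websocket reached the same worker, so the existing kernel could be reused. A minimal sketch of that rule as a standalone function (the name `check_reply` is illustrative, not part of the PR):

def check_reply(missed_messages: bool, reconnected_kernel: bool) -> dict:
    # mirrors the on_msg 'check' branch: ok only for a clean reconnect
    if missed_messages:
        return {"method": "check", "ok": False, "message": "Missed messages"}
    if reconnected_kernel:
        return {"method": "check", "ok": True, "message": "All fine"}
    return {"method": "check", "ok": False, "message": "Not reconnected"}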
64 changes: 53 additions & 11 deletions solara/server/kernel.py
@@ -6,7 +6,7 @@
import warnings
from binascii import b2a_base64
from datetime import datetime
from typing import Set
from typing import List, Set

import ipykernel
import ipykernel.kernelbase
@@ -16,6 +16,7 @@
from zmq.eventloop.zmqstream import ZMQStream

import solara
from solara.server import app
from solara.server.shell import SolaraInteractiveShell

from . import settings, websocket
@@ -204,19 +205,20 @@ def flush(self, *ignore):
pass


def send_websockets(websockets: Set[websocket.WebsocketWrapper], binary_msg):
for ws in list(websockets):
try:
ws.send(binary_msg)
except: # noqa
# in case of any issue, we simply remove it from the list
websockets.remove(ws)


class SessionWebsocket(session.Session):
def __init__(self, *args, **kwargs):
super(SessionWebsocket, self).__init__(*args, **kwargs)
self.websockets: Set[websocket.WebsocketWrapper] = set() # map from .. msg id to websocket?
self.message_queue: List[bytes] = []
self.has_dropped_messages = False
self.reconnect_buffer_length_bytes = solara.util.parse_size(settings.page_session.queue_size)

def close(self):
for ws in list(self.websockets):
try:
ws.close()
except: # noqa
pass

def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None):
try:
@@ -238,10 +240,46 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None):
if settings.main.use_pdb:
pdb.post_mortem()
raise
send_websockets(self.websockets, wire_message)
self.send_websockets(wire_message)
except Exception as e:
logger.exception("Error sending message: %s", e)

def send_websockets(self, binary_msg):
for ws in list(self.websockets):
try:
ws.send(binary_msg)
except: # noqa
# in case of any issue, we simply remove it from the list
# logger.exception("Error sending websocket message: %s", binary_msg)
try:
self.websockets.remove(ws)
except KeyError:
# this can happen when the websocket was already removed (e.g. by a concurrent send)
pass
# if we dropped messages, we will not store missed messages anymore
if not self.has_dropped_messages:
self.message_queue.append(binary_msg)
message_queue_bytes = sum(len(m) for m in self.message_queue)
logger.info("Message queue size: %s (max size is %s)", message_queue_bytes, self.reconnect_buffer_length_bytes)
if message_queue_bytes > self.reconnect_buffer_length_bytes:
self.message_queue.clear()
logger.info("Clearing message queue, too many bytes: %s", message_queue_bytes)
self.has_dropped_messages = True
# our current strategy is to close the page session / kernel
# once the message queue grows too large.
# in the future we may want a strategy that recovers the widget
# state, like Voila does. However, we currently do not believe that
# recovery strategy is stable (e.g. while the state is being collected,
# a thread could mutate it, giving the frontend an invalid state)
app_context = app.get_current_context()
# if we close the page session, we also close the Reacton context
# but we are most likely already in it, which will give a deadlock
# app_context.close()
# instead, we just spawn a thread
import threading

threading.Thread(target=app_context.close).start()


class Kernel(ipykernel.kernelbase.Kernel):
implementation = "solara"
@@ -299,3 +337,7 @@ def set_parent(self, ident, parent, channel="shell"):
super().set_parent(ident, parent, channel)
if channel == "shell":
self.shell.set_parent(parent)

def close(self):
# called when the PageSession/AppContext is closed
self.session.close()
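
SessionWebsocket now doubles as a bounded replay buffer: every wire message is appended to message_queue until the total size exceeds reconnect_buffer_length_bytes, at which point the queue is cleared, has_dropped_messages is set, and the page session is closed from a separate thread; on a clean reconnect the queued messages are replayed by initialize_virtual_kernel. A self-contained sketch of that buffering policy (the ReconnectBuffer class is illustrative only, not part of the PR):

from typing import Callable, List

class ReconnectBuffer:
    # illustrative sketch of the policy in SessionWebsocket.send_websockets
    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.messages: List[bytes] = []
        self.dropped = False

    def append(self, wire_message: bytes) -> None:
        if self.dropped:
            return  # once messages were dropped, stop buffering entirely
        self.messages.append(wire_message)
        if sum(len(m) for m in self.messages) > self.max_bytes:
            # too large to replay reliably: drop everything and mark the session
            self.messages.clear()
            self.dropped = True

    def replay(self, send: Callable[[bytes], None]) -> None:
        # called on reconnect when nothing was dropped
        for message in self.messages:
            send(message)
        self.messages.clear()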
45 changes: 0 additions & 45 deletions solara/server/pyodide.py

This file was deleted.

2 changes: 1 addition & 1 deletion solara/server/reload.py
@@ -145,7 +145,7 @@ def start(self):
# it can happen that an import is done at runtime, that we miss (could be in a thread)
# so we always reload all modules except the ignore_modules
self.ignore_modules = set(sys.modules)
logger.info("Ignoring reloading modules: %s", self.ignore_modules)
logger.debug("Ignoring reloading modules: %s", self.ignore_modules)
self._first = False

def _on_change(self, name):