Commit e5ca67b

fix: avoid closing a kernel twice by using a mutex
A kernel could be closed multiple times if the websocket was closed at the same time as the close beacon was sent. In CI this caused a KeyError in custom_storage.py when the kernel cleanup function was called twice. Page state management is now protected by a mutex, so that it is thread safe.
1 parent 2501c83 commit e5ca67b
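
For context, the bug is a classic check-then-act race: the websocket disconnect handler and the close-beacon handler can both call the kernel's `close()` at nearly the same time, both observe it as not yet closed, and both run the cleanup callbacks. A minimal standalone sketch of the guard this commit introduces (illustrative class, not the actual solara code):

```python
import threading

class KernelSketch:
    """Minimal sketch of the idempotent-close pattern; not the real VirtualKernelContext."""

    def __init__(self):
        self.closed_event = threading.Event()
        # RLock, so close() can be re-entered on the same thread, e.g. from an on-close callback
        self.lock = threading.RLock()

    def close(self):
        with self.lock:
            # without the lock, two threads (websocket close + close beacon) can
            # both pass this check and run the cleanup below twice
            if self.closed_event.is_set():
                return
            # ... run on-close callbacks, free kernel resources ...
            self.closed_event.set()

kernel = KernelSketch()
# both handlers can now call close() concurrently; cleanup runs exactly once
threads = [threading.Thread(target=kernel.close) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```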

File tree

3 files changed, +117 -75 lines changed


solara/server/kernel_context.py

Lines changed: 110 additions & 60 deletions
```diff
@@ -73,6 +73,7 @@ class VirtualKernelContext:
     _last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None
     closed_event: threading.Event = dataclasses.field(default_factory=threading.Event)
     _on_close_callbacks: List[Callable[[], None]] = dataclasses.field(default_factory=list)
+    lock: threading.RLock = dataclasses.field(default_factory=threading.RLock)
 
     def __post_init__(self):
         with self:
@@ -108,14 +109,15 @@ def __exit__(self, *args):
         current_context[key] = local.kernel_context_stack.pop()
 
     def close(self):
-        if self.closed_event.is_set():
-            logger.error("Tried to close a kernel context that is already closed: %s", self.id)
-            return
-        logger.info("Shut down virtual kernel: %s", self.id)
-        with self:
+        with self, self.lock:
+            for key in self.page_status:
+                self.page_status[key] = PageStatus.CLOSED
+            if self.closed_event.is_set():
+                logger.error("Tried to close a kernel context that is already closed: %s", self.id)
+                return
+            logger.info("Shut down virtual kernel: %s", self.id)
             for f in reversed(self._on_close_callbacks):
                 f()
-        with self:
             if self.app_object is not None:
                 if isinstance(self.app_object, reacton.core._RenderContext):
                     try:
@@ -128,9 +130,9 @@ def close(self):
             # import gc
             # gc.collect()
             self.kernel.session.close()
-        if self.id in contexts:
-            del contexts[self.id]
-        self.closed_event.set()
+            if self.id in contexts:
+                del contexts[self.id]
+            self.closed_event.set()
 
     def _state_reset(self):
         state_directory = Path(".") / "states"
@@ -157,77 +159,125 @@ def state_save(self, state_directory: os.PathLike):
 
     def page_connect(self, page_id: str):
         logger.info("Connect page %s for kernel %s", page_id, self.id)
-        assert self.page_status.get(page_id) != PageStatus.CLOSED, "cannot connect with the same page_id after a close"
-        self.page_status[page_id] = PageStatus.CONNECTED
-        if self._last_kernel_cull_task:
-            self._last_kernel_cull_task.cancel()
-
-    def page_disconnect(self, page_id: str) -> "asyncio.Future[None]":
-        """Signal that a page has disconnected, and schedule a kernel cull if needed.
-
-        During the kernel reconnect window, we will keep the kernel alive, even if all pages have disconnected.
-
-        Returns a future that is set when the kernel cull is done.
-        The scheduled kernel cull can be cancelled when a new page connects, a new disconnect is scheduled,
-        or a page if explicitly closed.
-        """
-        logger.info("Disconnect page %s for kernel %s", page_id, self.id)
-        future: "asyncio.Future[None]" = asyncio.Future()
-        self.page_status[page_id] = PageStatus.DISCONNECTED
-        current_event_loop = asyncio.get_event_loop()
+        with self.lock:
+            if self.closed_event.is_set():
+                raise RuntimeError("Cannot connect a page to a closed kernel")
+            if page_id in self.page_status and self.page_status.get(page_id) == PageStatus.CLOSED:
+                raise RuntimeError("Cannot connect a page that is already closed")
+            self.page_status[page_id] = PageStatus.CONNECTED
+            if self._last_kernel_cull_task:
+                logger.info("Cancelling previous kernel cull task for virtual kernel %s", self.id)
+                self._last_kernel_cull_task.cancel()
 
+    def _bump_kernel_cull(self):
         async def kernel_cull():
             try:
                 cull_timeout_sleep_seconds = solara.util.parse_timedelta(solara.server.settings.kernel.cull_timeout)
                 logger.info("Scheduling kernel cull, will wait for max %s before shutting down the virtual kernel %s", cull_timeout_sleep_seconds, self.id)
                 await asyncio.sleep(cull_timeout_sleep_seconds)
-                has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
-                if has_connected_pages:
-                    logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
-                else:
-                    logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
-                    self.close()
-                current_event_loop.call_soon_threadsafe(future.set_result, None)
+                logger.info("Timeout reached, checking if we should be shutting down virtual kernel %s", self.id)
+                with self.lock:
+                    has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
+                    if has_connected_pages:
+                        logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
+                    else:
+                        logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
+                        self.close()
+                    if current_event_loop is not None and future is not None:
+                        current_event_loop.call_soon_threadsafe(future.set_result, None)
             except asyncio.CancelledError:
-                if sys.version_info >= (3, 9):
-                    current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
-                else:
-                    current_event_loop.call_soon_threadsafe(future.cancel)
+                if current_event_loop is not None and future is not None:
+                    if sys.version_info >= (3, 9):
+                        current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
+                    else:
+                        current_event_loop.call_soon_threadsafe(future.cancel)
                 raise
 
-        has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
-        if not has_connected_pages:
-            # when we have no connected pages, we will schedule a kernel cull
+        async def create_task():
+            task = asyncio.create_task(kernel_cull())
+            # create a reference to the task so we can cancel it later
+            self._last_kernel_cull_task = task
+            await task
+
+        with self.lock:
+            future: "Optional[asyncio.Future[None]]" = None
+            current_event_loop: Optional[asyncio.AbstractEventLoop] = None
+            try:
+                future = asyncio.Future()
+                current_event_loop = asyncio.get_event_loop()
+            except RuntimeError:
+                pass
             if self._last_kernel_cull_task:
+                logger.info("Cancelling previous kernel cull tas for virtual kernel %s", self.id)
                 self._last_kernel_cull_task.cancel()
 
-            async def create_task():
-                task = asyncio.create_task(kernel_cull())
-                # create a reference to the task so we can cancel it later
-                self._last_kernel_cull_task = task
-                await task
-
+            logger.info("Scheduling kernel cull for virtual kernel %s", self.id)
             asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop)
-        else:
-            future.set_result(None)
-        return future
+            return future
+
+    def page_disconnect(self, page_id: str) -> "Optional[asyncio.Future[None]]":
+        """Signal that a page has disconnected, and schedule a kernel cull if needed.
+
+        During the kernel reconnect window, we will keep the kernel alive, even if all pages have disconnected.
+
+        Will return a future that is set when the kernel cull is done, when an event loop is available.
+        The scheduled kernel cull can be cancelled when a new page connects, a new disconnect is scheduled,
+        or a page if explicitly closed.
+        """
+
+        logger.info("Disconnect page %s for kernel %s", page_id, self.id)
+        future: "asyncio.Future[None]" = asyncio.Future()
+        with self.lock:
+            if self.page_status[page_id] == PageStatus.CLOSED:
+                # this happens when the close beackon call happens before the websocket disconnect
+                logger.info("Page %s already closed for kernel %s", page_id, self.id)
+                future.set_result(None)
+                return future
+            assert self.page_status[page_id] == PageStatus.CONNECTED, "cannot disconnect a page that is in state: %r" % self.page_status[page_id]
+            self.page_status[page_id] = PageStatus.DISCONNECTED
+            has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
+            if not has_connected_pages:
+                # when we have no connected pages, we will schedule a kernel cull
+                future = self._bump_kernel_cull()
+            else:
+                logger.info("Still have connected pages, do nothing for kernel %s", self.id)
+                future.set_result(None)
+            return future
 
     def page_close(self, page_id: str):
-        """Signal that a page has closed, and close the context if needed.
+        """Signal that a page has closed, close the context if needed and schedule a kernel cull if needed.
 
         Closing the browser tab or a page navigation means an explicit close, which is
         different from a websocket/page disconnect, which we might want to recover from.
 
         """
-        self.page_status[page_id] = PageStatus.CLOSED
-        logger.info("Close page %s for kernel %s", page_id, self.id)
-        has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
-        has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values()
-        if not (has_connected_pages or has_disconnected_pages):
-            logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
-            if self._last_kernel_cull_task:
-                self._last_kernel_cull_task.cancel()
-            self.close()
+        future: "Optional[asyncio.Future[None]]" = None
+
+        try:
+            future = asyncio.Future()
+        except RuntimeError:
+            pass
+        else:
+            future.set_result(None)
+        with self.lock:
+            if self.page_status[page_id] == PageStatus.CLOSED:
+                logger.info("Page %s already closed for kernel %s", page_id, self.id)
+                return
+            self.page_status[page_id] = PageStatus.CLOSED
+            logger.info("Close page %s for kernel %s", page_id, self.id)
+            has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
+            has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values()
+            # if we have disconnected pages, we may have cancelled the kernel cull task
+            # if we still have connected pages, it will go to a disconnected state again
+            # which will also trigger a new kernel cull
+            if has_disconnected_pages:
+                future = self._bump_kernel_cull()
+            if not (has_connected_pages or has_disconnected_pages):
+                logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
+                self.close()
+            else:
+                logger.info("Still have connected or disconnected pages, keeping virtual kernel %s alive", self.id)
            return future
 
 
 try:
```
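The new `_bump_kernel_cull` helper factors out a pattern used by both `page_disconnect` and `page_close`: the cull coroutine runs on a dedicated keep-alive event loop, it is scheduled from arbitrary threads via `asyncio.run_coroutine_threadsafe`, and a reference to the task is retained so a later `page_connect` (or the next bump) can cancel it. A simplified sketch of that scheduling pattern, with `keep_alive_loop` standing in for solara's module-level `keep_alive_event_loop` and the timings made up:

```python
import asyncio
import threading
import time
from typing import Optional

# stand-in for solara's keep_alive_event_loop: a loop running in a daemon
# thread, independent of whichever thread handles the web request
keep_alive_loop = asyncio.new_event_loop()
threading.Thread(target=keep_alive_loop.run_forever, daemon=True).start()

_last_cull_task: Optional["asyncio.Task[None]"] = None

async def _kernel_cull(timeout: float) -> None:
    await asyncio.sleep(timeout)
    print("cull timeout reached, shutting down the virtual kernel")

def bump_kernel_cull(timeout: float = 0.2) -> None:
    """Cancel any pending cull and schedule a fresh one (sketch of _bump_kernel_cull)."""

    async def create_task() -> None:
        global _last_cull_task
        if _last_cull_task is not None:
            _last_cull_task.cancel()
        # keep a reference to the task so it can be cancelled later
        _last_cull_task = asyncio.create_task(_kernel_cull(timeout))

    # safe to call from any thread, e.g. a websocket disconnect handler
    asyncio.run_coroutine_threadsafe(create_task(), keep_alive_loop)

bump_kernel_cull()
bump_kernel_cull()  # cancels the first cull and restarts the countdown
time.sleep(0.5)     # only one "shutting down" line is printed
```
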
tests/integration/lifecycle_test.py

Lines changed: 1 addition & 10 deletions
```diff
@@ -1,4 +1,3 @@
-import threading
 from pathlib import Path
 from typing import cast
 
@@ -81,12 +80,4 @@ def test_kernel_lifecycle_close_while_disconnected(
     page_session.locator("text=Clicks-1").click()
     page_session.locator("text=Clicks-2").wait_for()
     page_session.goto("about:blank")
-    # give a bit of time to make sure the cull task is started
-    page_session.wait_for_timeout(100)
-    cull_task_2 = context._last_kernel_cull_task
-    assert cull_task_2 is not None
-    # we can't mix do async, so we hook up an event to the Future
-    event = threading.Event()
-    cull_task_2.add_done_callback(lambda x: event.set())
-    event.wait()
-    assert context.closed_event.is_set()
+    assert context.closed_event.wait(timeout=20)
```
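The simplified assertion works because `closed_event` is a `threading.Event`, and `Event.wait(timeout=...)` returns `True` only if the event was set within the timeout; a hung cleanup now fails the test after 20 seconds instead of blocking forever on `event.wait()`. For illustration:

```python
import threading

closed_event = threading.Event()
# simulate the kernel closing shortly after the page navigates away
threading.Timer(0.1, closed_event.set).start()

# True if the event was set within the timeout, False otherwise
assert closed_event.wait(timeout=20)
```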

tests/unit/lifecycle_test.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -84,20 +84,21 @@ async def test_kernel_lifecycle_close_single(close_first, short_cull_timeout):
 async def test_kernel_lifecycle_close_while_disconnected(close_first, short_cull_timeout):
     # a reconnect should be possible within the reconnect window
     websocket = Mock()
-    context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket)
+    context = kernel_context.initialize_virtual_kernel(f"session-id-1-{close_first}", f"kernel-id-1-{close_first}", websocket)
     context.page_connect("page-id-1")
     cull_task_1 = context.page_disconnect("page-id-1")
     await asyncio.sleep(0.1)
     # after 0.1 we connect again, but close it directly
     context.page_connect("page-id-2")
     if close_first:
-        context.page_close("page-id-2")
+        cull_task_2 = context.page_close("page-id-2")
         await asyncio.sleep(0.01)
-        cull_task_2 = context.page_disconnect("page-id-2")
+        context.page_disconnect("page-id-2")
     else:
-        cull_task_2 = context.page_disconnect("page-id-2")
+        context.page_disconnect("page-id-2")
         await asyncio.sleep(0.01)
-        context.page_close("page-id-2")
+        cull_task_2 = context.page_close("page-id-2")
+    assert cull_task_2 is not None
     assert not context.closed_event.is_set()
     await asyncio.sleep(0.15)
     # but even though we closed, the first page is still in the disconnected state
```
