Skip to content

Commit

Permalink
Merge branch 'tim/SessionClientDisconnectedError' into tim/StreamlitR…
Browse files Browse the repository at this point in the history
…untime

* tim/SessionClientDisconnectedError:
  better docstring
  typo
  SessionClientDisconnectedError, with test
  WIP
  * Remove staticmethod decorator from __call__ method of SingletonAPI to avoid mypy bug python/mypy#7781 (streamlit#4981)
  • Loading branch information
tconkling committed Jul 15, 2022
2 parents ba0df18 + 88266e1 commit dc12af3
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
10 changes: 6 additions & 4 deletions lib/streamlit/caching/singleton_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,24 @@ class SingletonAPI:

# Bare decorator usage
@overload
@staticmethod
def __call__(func: F) -> F:
def __call__(self, func: F) -> F:
...

# Decorator with arguments
@overload
@staticmethod
def __call__(
self,
*,
show_spinner: bool = True,
suppress_st_warning=False,
) -> Callable[[F], F]:
...

@staticmethod
# __call__ should be a static method, but there's a mypy bug that
# breaks type checking for overloaded static functions:
# https://github.com/python/mypy/issues/7781
def __call__(
self,
func: Optional[F] = None,
*,
show_spinner: bool = True,
Expand Down
19 changes: 16 additions & 3 deletions lib/streamlit/web/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,21 @@
SCRIPT_RUN_CHECK_TIMEOUT = 60


class SessionClientDisconnectedError(Exception):
"""Raised by operations on a disconnected SessionClient."""

pass


class SessionClient(Protocol):
"""Interface for sending data to a session's client."""

def write_forward_msg(self, msg: ForwardMsg) -> None:
"""Deliver a ForwardMsg to the client."""
"""Deliver a ForwardMsg to the client.
If the SessionClient has been disconnected, it should raise a
SessionClientDisconnectedError.
"""


class SessionInfo:
Expand Down Expand Up @@ -515,7 +525,7 @@ async def _loop_coroutine(
for msg in msg_list:
try:
self._send_message(session_info, msg)
except tornado.websocket.WebSocketClosedError:
except SessionClientDisconnectedError:
self._close_app_session(session_info.session.id)
await asyncio.sleep(0)
await asyncio.sleep(0)
Expand Down Expand Up @@ -722,7 +732,10 @@ def check_origin(self, origin: str) -> bool:

def write_forward_msg(self, msg: ForwardMsg) -> None:
"""Send a ForwardMsg to the browser."""
self.write_message(serialize_forward_msg(msg), binary=True)
try:
self.write_message(serialize_forward_msg(msg), binary=True)
except tornado.websocket.WebSocketClosedError as e:
raise SessionClientDisconnectedError from e

def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
# Extract user info from the X-Streamlit-User header
Expand Down
42 changes: 42 additions & 0 deletions lib/tests/streamlit/web/server/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,48 @@ async def test_orphaned_upload_file_deletion(self):
[],
)

@tornado.testing.gen_test
async def test_send_message_to_disconnected_websocket(self):
"""Sending a message to a disconnected SessionClient raises an error.
We should gracefully handle the error by cleaning up the session.
"""
with patch(
"streamlit.web.server.server.LocalSourcesWatcher"
), self._patch_app_session():
await self.start_server_loop()
await self.ws_connect()

# Get the server's socket and session for this client
session_info = list(self.server._session_info_by_id.values())[0]

with patch.object(
session_info.session, "flush_browser_queue"
) as flush_browser_queue, patch.object(
session_info.client, "write_message"
) as ws_write_message:
# Patch flush_browser_queue to simulate a pending message.
flush_browser_queue.return_value = [_create_dataframe_msg([1, 2, 3])]

# Patch the session's WebsocketHandler to raise a
# WebSocketClosedError when we write to it.
ws_write_message.side_effect = tornado.websocket.WebSocketClosedError()

# Tick the server. Our session's browser_queue will be flushed,
# and the Websocket client's write_message will be called,
# raising our WebSocketClosedError.
while not flush_browser_queue.called:
self.server._need_send_data.set()
await asyncio.sleep(0)

flush_browser_queue.assert_called_once()
ws_write_message.assert_called_once()

# Our session should have been removed from the server as
# a result of the WebSocketClosedError.
self.assertIsNone(
self.server._get_session_info(session_info.session.id)
)


class ServerUtilsTest(unittest.TestCase):
def test_is_url_from_allowed_origins_allowed_domains(self):
Expand Down

0 comments on commit dc12af3

Please sign in to comment.