diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py index 87881b84de..8ff19da668 100644 --- a/sanic/server/websockets/connection.py +++ b/sanic/server/websockets/connection.py @@ -45,7 +45,7 @@ async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: await self._send(message) - async def recv(self, *args, **kwargs) -> Optional[str]: + async def recv(self, *args, **kwargs) -> Optional[Union[str, bytes]]: message = await self._receive() if message["type"] == "websocket.receive": @@ -53,7 +53,7 @@ async def recv(self, *args, **kwargs) -> Optional[str]: return message["text"] except KeyError: try: - return message["bytes"].decode() + return message["bytes"] except KeyError: raise InvalidUsage("Bad ASGI message received") elif message["type"] == "websocket.disconnect": diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 0c76a67f3c..fe9ff306e2 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -342,7 +342,7 @@ async def test_websocket_send(send, receive, message_stack): @pytest.mark.asyncio -async def test_websocket_receive(send, receive, message_stack): +async def test_websocket_text_receive(send, receive, message_stack): msg = {"text": "hello", "type": "websocket.receive"} message_stack.append(msg) @@ -351,6 +351,15 @@ async def test_websocket_receive(send, receive, message_stack): assert text == msg["text"] +@pytest.mark.asyncio +async def test_websocket_bytes_receive(send, receive, message_stack): + msg = {"bytes": b"hello", "type": "websocket.receive"} + message_stack.append(msg) + + ws = WebSocketConnection(send, receive) + data = await ws.receive() + + assert data == msg["bytes"] @pytest.mark.asyncio async def test_websocket_accept_with_no_subprotocols(