Skip to content

Commit

Permalink
Replacing assignation by typing for websocket_handshake (#2273)
Browse files Browse the repository at this point in the history
* Replacing assignation by typing for `websocket_handshake`

Related to #2272

* Fix some type hinting issues

* Cleanup websocket handchake response concat

* Optimize concat encoding

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
  • Loading branch information
cnicodeme and ahopkins authored Oct 27, 2021
1 parent 645310c commit 71cc30e
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions sanic/server/protocols/websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Optional, Sequence
from typing import TYPE_CHECKING, Optional, Sequence, cast

from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection
from websockets.typing import Subprotocol

from sanic.exceptions import ServerError
from sanic.log import error_logger
Expand All @@ -15,13 +16,6 @@


class WebSocketProtocol(HttpProtocol):

websocket: Optional[WebsocketImplProtocol]
websocket_timeout: float
websocket_max_size = Optional[int]
websocket_ping_interval = Optional[float]
websocket_ping_timeout = Optional[float]

def __init__(
self,
*args,
Expand All @@ -35,7 +29,7 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self.websocket = None
self.websocket: Optional[WebsocketImplProtocol] = None
self.websocket_timeout = websocket_timeout
self.websocket_max_size = websocket_max_size
if websocket_max_queue is not None and websocket_max_queue > 0:
Expand Down Expand Up @@ -109,14 +103,22 @@ def close_if_idle(self):
return super().close_if_idle()

async def websocket_handshake(
self, request, subprotocols=Optional[Sequence[str]]
self, request, subprotocols: Optional[Sequence[str]] = None
):
# let the websockets package do the handshake with the client
try:
if subprotocols is not None:
# subprotocols can be a set or frozenset,
# but ServerConnection needs a list
subprotocols = list(subprotocols)
subprotocols = cast(
Optional[Sequence[Subprotocol]],
list(
[
Subprotocol(subprotocol)
for subprotocol in subprotocols
]
),
)
ws_conn = ServerConnection(
max_size=self.websocket_max_size,
subprotocols=subprotocols,
Expand All @@ -131,21 +133,18 @@ async def websocket_handshake(
)
raise ServerError(msg, status_code=500)
if 100 <= resp.status_code <= 299:
rbody = "".join(
[
"HTTP/1.1 ",
str(resp.status_code),
" ",
resp.reason_phrase,
"\r\n",
]
)
rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
first_line = (
f"HTTP/1.1 {resp.status_code} {resp.reason_phrase}\r\n"
).encode()
rbody = bytearray(first_line)
rbody += (
"".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
).encode()
rbody += b"\r\n"
if resp.body is not None:
rbody += f"\r\n{resp.body}\r\n\r\n"
else:
rbody += "\r\n"
await super().send(rbody.encode())
rbody += resp.body
rbody += b"\r\n\r\n"
await super().send(rbody)
else:
raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol(
Expand Down

0 comments on commit 71cc30e

Please sign in to comment.